
Implementing Univariate Linear Regression

2014-06-27 17:51

Program flow

1. Load the data

data = load('ex1data1.txt');
X = data(:, 1); y = data(:, 2);
m = length(y); % number of training examples
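
If ex1data1.txt is not at hand, a minimal sketch like the following builds a hypothetical stand-in dataset in the same two-column layout (population in column 1, profit in column 2), so the later steps can still be run. The names pop and profit and the formula that generates the values are assumptions for illustration only. They are not the course data.

```octave
% Hypothetical stand-in for ex1data1.txt: 20 deterministic points roughly
% following profit = -4 + 1.2 * population (a sine wiggle replaces real noise).
pop    = linspace(5, 22, 20)';               % column 1: population in 10,000s
profit = -4 + 1.2 * pop + 0.5 * sin(pop);    % column 2: profit in $10,000s
data   = [pop, profit];                      % same 2-column layout as the data file

X = data(:, 1); y = data(:, 2);
m = length(y);                               % number of training examples (20 here)
```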

2. Scatter plot of the data: plotData(X, y)

figure;                                  % open a new figure window
plot(X, y, 'rx', 'MarkerSize', 10);      % training data as red crosses
ylabel('Profit in $10,000s');
xlabel('Population of City in 10,000s');

3. Gradient descent

Compute the cost function J: computeCost(X, y, theta)
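
For reference, the model is the hypothesis h_θ(x) = θ0 + θ1·x, and computeCost evaluates half the mean squared error. Once X carries a leading column of ones, the sum over the m examples collapses into a single inner product. The code computes (y - Xθ)ᵀ(y - Xθ), which is the same quantity because the sign flip squares away.

```latex
h_\theta(x) = \theta_0 + \theta_1 x,
\qquad
J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right)^2
          = \frac{1}{2m}\,(X\theta - y)^{\top}(X\theta - y)
```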

X = [ones(m, 1), data(:,1)];   % add a column of ones to X for the intercept theta_0
theta = zeros(2, 1);           % initialize fitting parameters

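% Body of computeCost(X, y, theta)
y1 = X * theta;                       % predictions h_theta(x) for all m examples
J = (y - y1)' * (y - y1) / (2*m);     % half the mean squared error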
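
A quick way to gain confidence in the vectorized cost is to evaluate it on data whose answer is known by hand. The sketch below uses three hypothetical points lying exactly on y = x. With θ = [0; 0] the cost is (1 + 4 + 9)/(2·3) = 14/6, and with θ = [0; 1] the fit is perfect, so J = 0. It also cross-checks the matrix form against an explicit loop over the examples. The anonymous function cost and the *_toy names are illustrative and not part of the original program.

```octave
% Hypothetical toy data: three points lying exactly on y = x
X_toy = [1 1; 1 2; 1 3];      % first column of ones multiplies theta_0
y_toy = [1; 2; 3];
m_toy = length(y_toy);

% Same vectorized formula as computeCost, written as an anonymous function
cost = @(X, y, theta) (X*theta - y)' * (X*theta - y) / (2 * length(y));

J_zero = cost(X_toy, y_toy, [0; 0]);   % (1 + 4 + 9) / (2*3) = 14/6
J_fit  = cost(X_toy, y_toy, [0; 1]);   % perfect fit, so 0

% Cross-check with the explicit summation form of J at theta = [0; 0]
J_loop = 0;
for i = 1:m_toy
  J_loop = J_loop + (X_toy(i,:) * [0; 0] - y_toy(i))^2;
end
J_loop = J_loop / (2 * m_toy);
```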

Set the number of iterations and the learning rate

iterations = 1500;   % number of gradient descent steps
alpha = 0.01;        % learning rate

Update θ0 and θ1 iteratively
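
The loop below implements the batch gradient-descent update for the two parameters. Both right-hand sides are evaluated with the current θ before either parameter is overwritten, which is why the code uses temp1 and temp2. Written out, and in the equivalent vector form:

```latex
\begin{aligned}
\theta_0 &:= \theta_0 - \alpha\,\frac{1}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right) \\
\theta_1 &:= \theta_1 - \alpha\,\frac{1}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right)x^{(i)} \\
\theta   &:= \theta - \frac{\alpha}{m}\,X^{\top}(X\theta - y)
\end{aligned}
```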

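J_history = zeros(iterations, 1);   % cost after each iteration

for iter = 1:iterations

  temp1 = theta(1) - alpha * sum(X*theta - y) / m;          % update for theta_0
  temp2 = theta(2) - alpha * (X*theta - y)' * X(:,2) / m;   % update for theta_1
  theta = [temp1; temp2];   % assign both at once (simultaneous update)

  J_history(iter) = computeCost(X, y, theta);   % computeCost.m wraps the cost code above

end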
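
The two element-wise updates can be collapsed into the single expression theta = theta - (alpha/m) * X' * (X*theta - y). The sketch below runs both forms side by side on hypothetical toy data lying exactly on y = 1 + 2x. It also records the cost after every step, as J_history does above, so you can check that J never increases. An increasing J is typically a sign that alpha is too large. The toy data, alpha_toy = 0.1 and all *_toy / theta_a / theta_b names are assumptions for illustration.

```octave
% Hypothetical toy data lying exactly on y = 1 + 2x
X_toy = [ones(4,1), (1:4)'];
y_toy = 1 + 2 * (1:4)';
m_toy = length(y_toy);
alpha_toy = 0.1;
n_iters = 1500;

theta_a = zeros(2, 1);      % element-wise update, as in the loop above
theta_b = zeros(2, 1);      % vectorized update
J_hist = zeros(n_iters, 1);

for iter = 1:n_iters
  err = X_toy * theta_a - y_toy;
  theta_a = [theta_a(1) - alpha_toy * sum(err) / m_toy;
             theta_a(2) - alpha_toy * (err' * X_toy(:,2)) / m_toy];

  % One line: theta := theta - (alpha/m) * X' * (X*theta - y)
  theta_b = theta_b - (alpha_toy / m_toy) * X_toy' * (X_toy * theta_b - y_toy);

  % Cost after this step (same formula as computeCost)
  r = X_toy * theta_b - y_toy;
  J_hist(iter) = r' * r / (2 * m_toy);
end
```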

4. Plot the line determined by θ and print predictions

hold on;                   % draw on top of the scatter plot from step 2
plot(X(:,2), X*theta, '-')
legend('Training data', 'Linear regression')
hold off;

% Predict values for population sizes of 35,000 and 70,000
predict1 = [1, 3.5] *theta;
fprintf('For population = 35,000, we predict a profit of %f\n',...
predict1*10000);
predict2 = [1, 7] * theta;
fprintf('For population = 70,000, we predict a profit of %f\n',...
predict2*10000);
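
Each prediction is just the row [1, x] times θ, with the population x expressed in units of 10,000. The result is scaled back to dollars by multiplying by 10,000. As a sketch, both queries can be answered with one matrix product. theta_example below holds hypothetical values chosen so the arithmetic is easy to follow (-3 + 1.2·3.5 = 1.2 and -3 + 1.2·7 = 5.4). They are not the parameters learned from ex1data1.txt.

```octave
theta_example = [-3; 1.2];                       % hypothetical parameters, NOT the fitted values
populations = [3.5; 7];                          % populations in units of 10,000
Xq = [ones(size(populations)), populations];     % prepend the intercept column
profits = (Xq * theta_example) * 10000;          % predictions scaled back to dollars

for k = 1:numel(populations)
  fprintf('For population = %d, we predict a profit of %f\n', ...
          populations(k) * 10000, profits(k));
end
```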



5. Visualize the cost function J

% Grid over which we will calculate J
theta0_vals = linspace(-10, 10, 100);
theta1_vals = linspace(-1, 4, 100);

% initialize J_vals to a matrix of 0's
J_vals = zeros(length(theta0_vals), length(theta1_vals));

% Fill out J_vals
for i = 1:length(theta0_vals)
for j = 1:length(theta1_vals)
t = [theta0_vals(i); theta1_vals(j)];
J_vals(i,j) = computeCost(X, y, t);
end
end

% Because of the way meshgrids work in the surf command, we need to
% transpose J_vals before calling surf, or else the axes will be flipped
J_vals = J_vals';
% Surface plot
figure;
surf(theta0_vals, theta1_vals, J_vals)
xlabel('\theta_0'); ylabel('\theta_1');
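
The transpose matters because surf(x, y, Z) expects Z to have one row per y value and one column per x value. After J_vals = J_vals', entry J_vals(j, i) belongs to theta1_vals(j) and theta0_vals(i). The sketch below repeats the grid computation on hypothetical toy data lying exactly on y = 1 + 2x and reads the grid minimum back out with that row/column convention. It should land at θ = (1, 2). The grids use 101 points instead of 100 so that 1 and 2 fall exactly on grid nodes. The grid sizes and the t0 / t1 / Jg names are assumptions for this illustration.

```octave
% Hypothetical toy data lying exactly on y = 1 + 2x
X_toy = [ones(4,1), (1:4)'];
y_toy = 1 + 2 * (1:4)';
m_toy = length(y_toy);

t0 = linspace(-10, 10, 101);   % step 0.2, contains theta_0 = 1
t1 = linspace(-1, 4, 101);     % step 0.05, contains theta_1 = 2
Jg = zeros(length(t0), length(t1));
for i = 1:length(t0)
  for j = 1:length(t1)
    r = X_toy * [t0(i); t1(j)] - y_toy;
    Jg(i, j) = r' * r / (2 * m_toy);
  end
end
Jg = Jg';   % now rows follow theta_1 and columns follow theta_0, as surf expects

[Jmin, idx] = min(Jg(:));
[r_idx, c_idx] = ind2sub(size(Jg), idx);
theta_grid = [t0(c_idx); t1(r_idx)];   % column index -> theta_0, row index -> theta_1
```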



6. Contour plot of J

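figure;   % new window, so the contour does not draw over the surface plot
% Plot J_vals as 20 contour levels spaced logarithmically between 0.01 and 1000
contour(theta0_vals, theta1_vals, J_vals, logspace(-2, 3, 20))
xlabel('\theta_0'); ylabel('\theta_1');
hold on;
plot(theta(1), theta(2), 'rx', 'MarkerSize', 10, 'LineWidth', 2);   % mark the gradient-descent solution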
