2.1.6 线性回归示例代码

在安装完成相应的Python包后就可以正式地对线性回归进行建模与求解了,详细完整的代码见Book/Chapter02/01_house_price_train.py文件。

1.导入包

首先需要将用到的相关Python包进行导入,代码如下:

2.制作样本(数据集)

制作用于训练模型的数据集,代码如下:

在上述代码中,第4行的作用是在真实的房价中加入一定的噪声(误差)。

3.定义模型并求解

通过sklearn中的LinearRegression类来对线性回归模型的参数进行求解与预测,代码如下:

在上述代码中np.reshape(x,(-1,1))表示把x变成[n,1]的形状,至于n到底是多少,将通过np.reshape函数自己推导出。例如x的shape为[4,5],如果想把a改成[2,10]形状,则可以使用a.reshape([2,10]),或者使用a.reshape([2,-1])进行形状的变换。

4.运行结果

最后,调用定义好的函数运行程序,并输出最后训练得到的参数结果,代码如下:

可以发现,其中参数w=7.97、b=-154.32,这就意味着h(x)=7.97x-154.32。在这之后,便可以通过h(x)来对新的输入进行预测了。同时,还能够根据求解后的模型画出对应拟合出的直线,如图2-6所示。

到此,便完成了对于线性回归第一阶段的学习。