第三章:简单线性模型Py...

195 阅读1分钟

话不多说,直接放上自己写的代码

这是y=ax+b的简单实现,直接求导获取最值:

多元的类似,只是一个是数,一个是矩阵运算而已



  • from
    numpy
    import
    *



  • import
    random



  • """




  • dfdgffhhghghr




  • """




  • #一维线性方程的线性回归实现,通过求导获得代价函数的最值




  • def loadDataSet2():



  • ###数据导入 自己创一个拟合y=2x+1的数据




  • dataArr=[]



  • labelArr=[]



  • # myArr=[[0,1],[1,3],[2,4.8],[4,9.3],[-2,-2.8],[-3,-5.2],[1.5,5],[2.8,4],[-3.5,-4]]




  • for
    i
    in
    range(
    20
    ):



  • x=random.uniform(
    -5
    ,
    5
    )



  • dataArr.append(round(x,
    2
    ))



  • labelArr.append(
    2
    *x+
    1
    -random.uniform(
    -1
    ,
    1
    ))







  • #for itemArr in myArr:




  • #dataArr.append(float(itemArr[0]))




  • # labelArr.append(float(itemArr[1]))




  • return
    dataArr,labelArr







  • def calLeastSquareMethod(dataIn,labels):



  • length=len(dataIn)
    #数据个数




  • sumX=
    0.0
    ;sumXX=
    0.0




  • for
    data
    in
    dataIn:



  • sumX+=data



  • sumXX+=data*data



  • avgX=sumX/length



  • s=
    0.0




  • for
    i
    in
    range(length):



  • s+=labels*(dataIn-avgX)



  • weight=s/(sumXX-sumX*sumX/length)







  • s2=
    0.0




  • for
    i
    in
    range(length):



  • s2+=labels-weight*dataIn



  • b=s2/length



  • return
    weight,b







  • def plotBestFit2(weight,b):



  • ###画出最佳拟合直线




  • import
    matplotlib.pyplot
    as
    plt



  • print(weight,b)



  • # dataArr=array(dataMatrix) #矩阵转换为数组




  • n=len(dataArr)



  • xcord1=[];ycord1=[]



  • #xcord2=[];ycord2=[]








  • # for i in range(n):




  • # if(int(labelMatrix)==1):




  • # xcord1.append(dataArr[i,1]);ycord1.append(dataArr[i,2])




  • # else:




  • # xcord2.append(dataArr[i,1]);ycord2.append(dataArr[i,2])




  • for
    i
    in
    range(n):



  • xcord1.append(dataArr);



  • ycord1.append(labelArr)



  • fig=plt.figure()



  • #在子图中画出样本点




  • ax=fig.add_subplot(
    111
    )



  • ax.scatter(xcord1,ycord1,s=
    30
    ,c=
    'red'
    ,marker=
    's'
    )



  • #ax.scatter(xcord2,ycord2,s=30,c='green')








  • #画出拟合直线




  • x=arange(
    -5.0
    ,
    5.0
    ,
    0.1
    )



  • y=weight*x+b







  • ax.plot(x,y)



  • plt.xlabel(
    'X'
    );plt.ylabel(
    'Y'
    )



  • plt.show()







  • dataArr,labelArr=loadDataSet2()



  • weight,b=calLeastSquareMethod(dataArr,labelArr)



  • plotBestFit2(weight,b)

更多免费技术资料可关注:annalin1203