李宏毅老师PM2.5预测作业白话

356 阅读11分钟

机器学习——PM2.5预测白话

项目说明,本项目是李宏毅老师在飞桨授权课程的作业解析
课程 传送门
该项目AiStudio项目 传送门
数据集 传送门

本项目仅用于参考,提供思路和想法并非标准答案!请谨慎抄袭!

作业1-PM2.5预测

项目描述

  • 本次作业的资料是从行政院环境环保署空气品质监测网所下载的观测资料。
  • 希望大家能在本作业实现 linear regression 预测出 PM2.5 的数值。

数据集介绍

  • 本次作业使用丰原站的观测记录,分成 train set 跟 test set,train set 是丰原站每个月的前 20 天所有资料。test set 则是从丰原站剩下的资料中取样出来。
  • train.csv: 每个月前 20 天的完整资料。
  • test.csv : 从剩下的资料当中取样出连续的 10 小时为一笔,前九小时的所有观测数据当作 feature,第十小时的 PM2.5 当作 answer。一共取出 240 笔不重複的 test data,请根据 feature 预测这 240 笔的 PM2.5。
  • Data 含有 18 项观测数据 AMB_TEMP, CH4, CO, NHMC, NO, NO2, NOx, O3, PM10, PM2.5, RAINFALL, RH, SO2, THC, WD_HR, WIND_DIREC, WIND_SPEED, WS_HR。

项目要求

  • 请手动实现 linear regression,方法限使用 gradient descent。
  • 禁止使用 numpy.linalg.lstsq

数据准备

环境配置/安装

数据解读

对数据进行理解和了解后数据如图:
在这里插入图片描述
横向分别是24小时的数据值
竖向是12个月、每月20天、每天18种数据

!pip install --upgrade pandas
Looking in indexes: https://mirror.baidu.com/pypi/simple/
Requirement already up-to-date: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (1.2.3)
Requirement already satisfied, skipping upgrade: pytz>=2017.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas) (2019.3)
Requirement already satisfied, skipping upgrade: python-dateutil>=2.7.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas) (2.8.0)
Requirement already satisfied, skipping upgrade: numpy>=1.16.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas) (1.20.1)
Requirement already satisfied, skipping upgrade: six>=1.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas) (1.15.0)

导入需要的包,读取训练集

import numpy as np
import pandas as pd
data = pd.read_csv('work/hw1_data/train.csv', encoding = 'big5')  # 使用'big5'进行编码
print(data)  # 查看数据
print(data.shape)  # 查看数据大小
         0     1     2     3     4     5     6     7     8     9  ...    14  \
0       14    14    14    13    12    12    12    12    15    17  ...    22   
1      1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  ...   1.8   
2     0.51  0.41  0.39  0.37  0.35   0.3  0.37  0.47  0.78  0.74  ...  0.37   
3      0.2  0.15  0.13  0.12  0.11  0.06   0.1  0.13  0.26  0.23  ...   0.1   
4      0.9   0.6   0.5   1.7   1.8   1.5   1.9   2.2   6.6   7.9  ...   2.5   
...    ...   ...   ...   ...   ...   ...   ...   ...   ...   ...  ...   ...   
4315   1.8   1.8   1.8   1.8   1.8   1.7   1.7   1.8   1.8   1.8  ...   1.8   
4316    46    13    61    44    55    68    66    70    66    85  ...    59   
4317    36    55    72   327    74    52    59    83   106   105  ...    18   
4318   1.9   2.4   1.9   2.8   2.3   1.9   2.1   3.7   2.8   3.8  ...   2.3   
4319   0.7   0.8   1.8     1   1.9   1.7   2.1     2     2   1.7  ...   1.3   

        15    16    17    18    19    20    21    22    23  
0       22    21    19    17    16    15    15    15    15  
1      1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  
2     0.37  0.47  0.69  0.56  0.45  0.38  0.35  0.36  0.32  
3     0.13  0.14  0.23  0.18  0.12   0.1  0.09   0.1  0.08  
4      2.2   2.5   2.3   2.1   1.9   1.5   1.6   1.8   1.5  
...    ...   ...   ...   ...   ...   ...   ...   ...   ...  
4315   1.8     2   2.1     2   1.9   1.9   1.9     2     2  
4316   308   327    21   100   109   108   114   108   109  
4317   311    52    54   121    97   107   118   100   105  
4318   2.6   1.3     1   1.5     1   1.7   1.5     2     2  
4319   1.7   0.7   0.4   1.1   1.4   1.3   1.6   1.8     2  

[4320 rows x 24 columns]
(4320, 24)

取需要的数值部分,将 ‘RAINFALL’ 栏位全部补 0。

对数据进行查看后发现有缺失的数据,对缺失数据进行处理,填补缺失值。
在这里插入图片描述

读取的数据中有部分解释性的内容,我们不需要,可以进行提取直接忽略
在这里插入图片描述
data.iloc[:,:]该函数用于处理数据,把我们需要的部分进行切割获取
data[data == 'xxx'] = 0 把xxx的内容替换成0

data = data.iloc[:, 3:]  # 从列表的第4路项开始取(不要那些没有意义的数字)
print(data)  #查看数据
print(data.shape)  #查看数据大小
print(type(data))
data[data == 'NR'] = 0  # 把'NR'项装换成0
raw_data = data.to_numpy()  # 把数据转换成numpy数组
         0     1     2     3     4     5     6     7     8     9  ...    14  \
0       14    14    14    13    12    12    12    12    15    17  ...    22   
1      1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  ...   1.8   
2     0.51  0.41  0.39  0.37  0.35   0.3  0.37  0.47  0.78  0.74  ...  0.37   
3      0.2  0.15  0.13  0.12  0.11  0.06   0.1  0.13  0.26  0.23  ...   0.1   
4      0.9   0.6   0.5   1.7   1.8   1.5   1.9   2.2   6.6   7.9  ...   2.5   
...    ...   ...   ...   ...   ...   ...   ...   ...   ...   ...  ...   ...   
4315   1.8   1.8   1.8   1.8   1.8   1.7   1.7   1.8   1.8   1.8  ...   1.8   
4316    46    13    61    44    55    68    66    70    66    85  ...    59   
4317    36    55    72   327    74    52    59    83   106   105  ...    18   
4318   1.9   2.4   1.9   2.8   2.3   1.9   2.1   3.7   2.8   3.8  ...   2.3   
4319   0.7   0.8   1.8     1   1.9   1.7   2.1     2     2   1.7  ...   1.3   

        15    16    17    18    19    20    21    22    23  
0       22    21    19    17    16    15    15    15    15  
1      1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  
2     0.37  0.47  0.69  0.56  0.45  0.38  0.35  0.36  0.32  
3     0.13  0.14  0.23  0.18  0.12   0.1  0.09   0.1  0.08  
4      2.2   2.5   2.3   2.1   1.9   1.5   1.6   1.8   1.5  
...    ...   ...   ...   ...   ...   ...   ...   ...   ...  
4315   1.8     2   2.1     2   1.9   1.9   1.9     2     2  
4316   308   327    21   100   109   108   114   108   109  
4317   311    52    54   121    97   107   118   100   105  
4318   2.6   1.3     1   1.5     1   1.7   1.5     2     2  
4319   1.7   0.7   0.4   1.1   1.4   1.3   1.6   1.8     2  

[4320 rows x 24 columns]
(4320, 24)
<class 'pandas.core.frame.DataFrame'>
print(raw_data.shape)  # 查看数组大小
print(type(raw_data))  # 查看类型
(4320, 24)
<class 'numpy.ndarray'>

将原始 4320 * 24 的资料依照每个月分重组成 12 个 18 (features) * 480 (hours) 的资料。

在这里插入图片描述

从原先的24*(18*20*12)转换成12*18*(20*24)

month_data = {}
for month in range(12):
    sample = np.empty([18, 480])  # 新建np数组大小是[18, 480]内容随机
    for day in range(20):
        sample[:, day * 24 : (day + 1) * 24] = raw_data[18 * (20 * month + day) : 18 * (20 * month + day + 1), :]
    month_data[month] = sample
# print(len(month_data),len(month_data[0]))  # 大小查看
# print(month_data)  # 数据查看
print(month_data[month])
print(month_data[month].shape)
[[ 23.    23.    23.   ...  13.    13.    13.  ]
 [  1.6    1.7    1.7  ...   1.8    1.8    1.8 ]
 [  0.22   0.2    0.18 ...   0.51   0.57   0.56]
 ...
 [ 93.    50.    99.   ... 118.   100.   105.  ]
 [  1.8    2.1    3.2  ...   1.5    2.     2.  ]
 [  1.3    0.9    1.   ...   1.6    1.8    2.  ]]
(18, 480)

在这里插入图片描述

每个月会有 480hrs,每 9 小时形成一个 data,每个月会有 471 个 data,故总资料数为 471 * 12 笔,而每笔 data 有 9 * 18 的 features (一小时 18 个 features * 9 小时)。
  • 471次/月 * 12个月就是我们得到的数据量,而每次的数据量是9小时 * 18种数据
对应的 target 则有 471 * 12 个(第 10 个小时的 PM2.5)
  • target是下一个时刻的PM2.5的值

解析:
一个月的数据就是20天 * 24小时 共计 480小时的数据
按照9小时一组进行处理应该得到 480-9+1== 472
但是最后一组数据是没有最后对应的y的值的
所以:是472 - 1 == 471

  • 为什么是9次一组不是10次???
    题目要求:前九小时的所有观测数据当作 feature,第十小时的 PM2.5 当作 answer。
    根据要求9次数据为训练集第10次的为比较值。
  • 为什么不10个一组然后再做数据处理???
    这个建议不是不行,但是如果10个一组就少了一点点的训练量
    这个具体看自己理解。
x = np.empty([12 * 471, 18 * 9], dtype = float)  
y = np.empty([12 * 471, 1], dtype = float)
for month in range(12):
    for day in range(20):
        for hour in range(24):
            if day == 19 and hour > 14:
                continue
            x[month * 471 + day * 24 + hour, :] = month_data[month][:,day * 24 + hour : day * 24 + hour + 9].reshape(1, -1) #vector dim:18*9 (9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9)
            y[month * 471 + day * 24 + hour, 0] = month_data[month][9, day * 24 + hour + 9] #value
print(x)
print(x.shape)
print(y)
print(y.shape)
[[14.  14.  14.  ...  2.   2.   0.5]
 [14.  14.  13.  ...  2.   0.5  0.3]
 [14.  13.  12.  ...  0.5  0.3  0.8]
 ...
 [17.  18.  19.  ...  1.1  1.4  1.3]
 [18.  19.  18.  ...  1.4  1.3  1.6]
 [19.  18.  17.  ...  1.3  1.6  1.8]]
(5652, 162)
[[30.]
 [41.]
 [44.]
 ...
 [17.]
 [24.]
 [29.]]
(5652, 1)

求值的离散程度

计算x值的方差和平均值
然后计算每一个值的离散程度
np.mean(x, axis = 0):平均值
np.std(x, axis = 0): 方差
离散程度 = 平均差/方差

mean_x = np.mean(x, axis = 0)  # 18 * 9 
print(mean_x.shape)
std_x = np.std(x, axis = 0)  # 18 * 9 
print(std_x.shape)
for i in range(len(x)):  # 12 * 471
    for j in range(len(x[0])):  # 18 * 9 
        if std_x[j] != 0:
            x[i][j] = (x[i][j] - mean_x[j]) / std_x[j]
x.shape
(162,)
(162,)





(5652, 162)

数据集生成

把数据集安装一定的比例进行区分,一部分生成训练集一部分生成测试集(建议8:2)
注:查看对应的数据是否一样!

import math
x_train_set = x[: math.floor(len(x) * 0.8), :]
y_train_set = y[: math.floor(len(y) * 0.8), :]
x_validation = x[math.floor(len(x) * 0.8): , :]
y_validation = y[math.floor(len(y) * 0.8): , :]
print(x_train_set)
print(y_train_set)
print(x_validation)
print(y_validation)
print(len(x_train_set))
print(len(y_train_set))
print(len(x_validation))
print(len(y_validation))
[[-1.35825331 -1.35883937 -1.359222   ...  0.26650729  0.2656797  -1.14082131]
 [-1.35825331 -1.35883937 -1.51819928 ...  0.26650729 -1.13963133  -1.32832904]
 [-1.35825331 -1.51789368 -1.67717656 ... -1.13923451 -1.32700613  -0.85955971]
 ...
 [ 0.86929969  0.70886668  0.38952809 ...  1.39110073  0.2656797  -0.39079039]
 [ 0.71018876  0.39075806  0.07157353 ...  0.26650729 -0.39013211  -0.39079039]
 [ 0.3919669   0.07264944  0.07157353 ... -0.38950555 -0.39013211  -0.85955971]]
[[30.]
 [41.]
 [44.]
 ...
 [ 7.]
 [ 5.]
 [14.]]
[[ 0.07374504  0.07264944  0.07157353 ... -0.38950555 -0.85856912  -0.57829812]
 [ 0.07374504  0.07264944  0.23055081 ... -0.85808615 -0.57750692   0.54674825]
 [ 0.07374504  0.23170375  0.23055081 ... -0.57693779  0.54674191  -0.1095288 ]
 ...
 [-0.88092053 -0.72262212 -0.56433559 ... -0.57693779 -0.29644471  -0.39079039]
 [-0.7218096  -0.56356781 -0.72331287 ... -0.29578943 -0.39013211  -0.1095288 ]
 [-0.56269867 -0.72262212 -0.88229015 ... -0.38950555 -0.10906991   0.07797893]]
[[13.]
 [24.]
 [22.]
 ...
 [17.]
 [24.]
 [29.]]
4521
4521
1131
1131

因为常数项的存在,所以 dimension (dim) 需要多加一栏;eps 项是避免 adagrad 的分母为 0 而加的极小数值。

每一个 dimension (dim) 会对应到各自的 gradient, weight (w),透过一次次的 iteration (iter_time) 学习。

采用均方根误差

dim = 18 * 9 + 1  # 18个数据*9次+1(常量)个
w = np.zeros([dim, 1])  # 生成数据是0的数组
x = np.concatenate((np.ones([12 * 471, 1]), x), axis = 1).astype(float)  # 拼接1和x数组
learning_rate = 100  # 学习率
iter_time = 1000  # 学习次数
adagrad = np.zeros([dim, 1])  # 生成数据是0的数组
eps = 0.0000000001  
for t in range(iter_time):
    loss = np.sqrt(np.sum(np.power(np.dot(x, w) - y, 2))/471/12)  # rmse
    if(t%100==0):  # 100轮输出
        print(str(t) + ":" + str(loss))
    gradient = 2 * np.dot(x.transpose(), np.dot(x, w) - y)  # dim*1
    adagrad += gradient ** 2
    w = w - learning_rate * gradient / np.sqrt(adagrad + eps)
np.save('weight.npy', w) # 保存文件
w
0:27.071214829194115
100:33.78905859777454
200:19.913751298197095
300:13.531068193689693
400:10.645466158446172
500:9.277353455475065
600:8.518042045956502
700:8.014061987588425
800:7.636756824775692
900:7.336563740371125





array([[ 2.13740269e+01],
       [ 3.58888909e+00],
       [ 4.56386323e+00],
       [ 2.16307023e+00],
       [-6.58545223e+00],
       [-3.38885580e+01],
       [ 3.22235518e+01],
       [ 3.49340354e+00],
       [-4.60308671e+00],
       [-1.02374754e+00],
       [-3.96791501e-01],
       [-1.06908800e-01],
       [ 2.22488184e-01],
       [ 8.99634117e-02],
       [ 1.31243105e-01],
       [ 2.15894989e-02],
       [-1.52867263e-01],
       [ 4.54087776e-02],
       [ 5.20999235e-01],
       [ 1.60824213e-01],
       [-3.17709451e-02],
       [ 1.28529025e-02],
       [-1.76839437e-01],
       [ 1.71241371e-01],
       [-1.31190032e-01],
       [-3.51614451e-02],
       [ 1.00826192e-01],
       [ 3.45018257e-01],
       [ 4.00130315e-02],
       [ 2.54331382e-02],
       [-5.04425219e-01],
       [ 3.71483018e-01],
       [ 8.46357671e-01],
       [-8.11920428e-01],
       [-8.00217575e-02],
       [ 1.52737711e-01],
       [ 2.64915130e-01],
       [-5.19860416e-02],
       [-2.51988315e-01],
       [ 3.85246517e-01],
       [ 1.65431451e-01],
       [-7.83633314e-02],
       [-2.89457231e-01],
       [ 1.77615023e-01],
       [ 3.22506948e-01],
       [-4.59955256e-01],
       [-3.48635358e-02],
       [-5.81764363e-01],
       [-6.43394528e-02],
       [-6.32876949e-01],
       [ 6.36624507e-02],
       [ 8.31592506e-02],
       [-4.45157961e-01],
       [-2.34526366e-01],
       [ 9.86608594e-01],
       [ 2.65230652e-01],
       [ 3.51938093e-02],
       [ 3.07464334e-01],
       [-1.04311239e-01],
       [-6.49166901e-02],
       [ 2.11224757e-01],
       [-2.43159815e-01],
       [-1.31285604e-01],
       [ 1.09045810e+00],
       [-3.97913710e-02],
       [ 9.19563678e-01],
       [-9.44824150e-01],
       [-5.04137735e-01],
       [ 6.81272939e-01],
       [-1.34494828e+00],
       [-2.68009542e-01],
       [ 4.36204342e-02],
       [ 1.89619513e+00],
       [-3.41873873e-01],
       [ 1.89162461e-01],
       [ 1.73251268e-02],
       [ 3.14431930e-01],
       [-3.40828467e-01],
       [ 4.92385651e-01],
       [ 9.29634214e-02],
       [-4.50983589e-01],
       [ 1.47456584e+00],
       [-3.03417236e-02],
       [ 7.71229328e-02],
       [ 6.38314494e-01],
       [-7.93287087e-01],
       [ 8.82877506e-01],
       [ 3.18965610e+00],
       [-5.75671706e+00],
       [ 1.60748945e+00],
       [ 1.36142440e+01],
       [ 1.50029111e-01],
       [-4.78389603e-02],
       [-6.29463755e-02],
       [-2.85383032e-02],
       [-3.01562821e-01],
       [ 4.12058013e-01],
       [-6.77534154e-02],
       [-1.00985479e-01],
       [-1.68972973e-01],
       [ 1.64093233e+00],
       [ 1.89670371e+00],
       [ 3.94713816e-01],
       [-4.71231449e+00],
       [-7.42760774e+00],
       [ 6.19781936e+00],
       [ 3.53986244e+00],
       [-9.56245861e-01],
       [-1.04372792e+00],
       [-4.92863713e-01],
       [ 6.31608790e-01],
       [-4.85175956e-01],
       [ 2.58400216e-01],
       [ 9.43846795e-02],
       [-1.29323184e-01],
       [-3.81235287e-01],
       [ 3.86819479e-01],
       [ 4.04211627e-01],
       [ 3.75568914e-01],
       [ 1.83512261e-01],
       [-8.01417708e-02],
       [-3.10188597e-01],
       [-3.96124612e-01],
       [ 3.66227853e-01],
       [ 1.79488593e-01],
       [-3.14477051e-01],
       [-2.37611443e-01],
       [ 3.97076104e-02],
       [ 1.38775912e-01],
       [-3.84015069e-02],
       [-5.47557119e-02],
       [ 4.19975207e-01],
       [ 4.46120687e-01],
       [-4.31074826e-01],
       [-8.74450768e-02],
       [-5.69534264e-02],
       [-7.23980157e-02],
       [-1.39880128e-02],
       [ 1.40489658e-01],
       [-2.44952334e-01],
       [ 1.83646770e-01],
       [-1.64135512e-01],
       [-7.41216452e-02],
       [-9.71414213e-02],
       [ 1.98829041e-02],
       [-4.46965919e-01],
       [-2.63440959e-01],
       [ 1.52924043e-01],
       [ 6.52532847e-02],
       [ 7.06818266e-01],
       [ 9.73757051e-02],
       [-3.35687787e-01],
       [-2.26559165e-01],
       [-3.00117086e-01],
       [ 1.24185231e-01],
       [ 4.18872344e-01],
       [-2.51891946e-01],
       [-1.29095731e-01],
       [-5.57512471e-01],
       [ 8.76239582e-02],
       [ 3.02594902e-01],
       [-4.23463160e-01],
       [ 4.89922051e-01]])

加载 test data,并且以相似于训练资料预先处理和特徵萃取的方式处理,使 test data 形成 240 个维度为 18 * 9 + 1 的资料。

# 对tast_data做同等处理
testdata = pd.read_csv('work/hw1_data/test.csv', header = None, encoding = 'big5')
test_data = testdata.iloc[:, 2:]
test_data[test_data == 'NR'] = 0
test_data = test_data.to_numpy()
test_x = np.empty([240, 18*9], dtype = float)
for i in range(240):
    test_x[i, :] = test_data[18 * i: 18* (i + 1), :].reshape(1, -1)
for i in range(len(test_x)):
    for j in range(len(test_x[0])):
        if std_x[j] != 0:
            test_x[i][j] = (test_x[i][j] - mean_x[j]) / std_x[j]
test_x = np.concatenate((np.ones([240, 1]), test_x), axis = 1).astype(float)
test_x
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/ipykernel_launcher.py:3: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/pandas/core/frame.py:3215: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._where(-key, value, inplace=True)





array([[ 1.        , -0.24447681, -0.24545919, ..., -0.67065391,
        -1.04594393,  0.07797893],
       [ 1.        , -1.35825331, -1.51789368, ...,  0.17279117,
        -0.10906991, -0.48454426],
       [ 1.        ,  1.5057434 ,  1.34508393, ..., -1.32666675,
        -1.04594393, -0.57829812],
       ...,
       [ 1.        ,  0.3919669 ,  0.54981237, ...,  0.26650729,
        -0.20275731,  1.20302531],
       [ 1.        , -1.8355861 , -1.8360023 , ..., -1.04551839,
        -1.13963133, -1.14082131],
       [ 1.        , -1.35825331, -1.35883937, ...,  2.98427476,
         3.26367657,  1.76554849]])

有了 weight 和测试资料就可以预测 target。

np.dot(test_x, w):预测!

w = np.load('work/weight.npy')  # 读取文档
ans_y = np.dot(test_x, w)  # 预测
ans_y
array([[ 5.17496040e+00],
       [ 1.83062143e+01],
       [ 2.04912181e+01],
       [ 1.15239429e+01],
       [ 2.66160568e+01],
       [ 2.05313481e+01],
       [ 2.19065510e+01],
       [ 3.17364687e+01],
       [ 1.33916741e+01],
       [ 6.44564665e+01],
       [ 2.02645688e+01],
       [ 1.53585761e+01],
       [ 6.85894728e+01],
       [ 4.84281137e+01],
       [ 1.87023338e+01],
       [ 1.01885957e+01],
       [ 3.07403629e+01],
       [ 7.11322178e+01],
       [-4.13051739e+00],
       [ 1.82356940e+01],
       [ 3.85789223e+01],
       [ 7.13115197e+01],
       [ 7.41034816e+00],
       [ 1.87179553e+01],
       [ 1.49372503e+01],
       [ 3.67197367e+01],
       [ 1.79616970e+01],
       [ 7.57894629e+01],
       [ 1.23093102e+01],
       [ 5.62953517e+01],
       [ 2.51131609e+01],
       [ 4.61024867e+00],
       [ 2.48377055e+00],
       [ 2.47594223e+01],
       [ 3.04802805e+01],
       [ 3.84639307e+01],
       [ 4.42023106e+01],
       [ 3.00868360e+01],
       [ 4.04736750e+01],
       [ 2.92264799e+01],
       [ 5.60645605e+00],
       [ 3.86660161e+01],
       [ 3.46102134e+01],
       [ 4.83896975e+01],
       [ 1.47572477e+01],
       [ 3.44668201e+01],
       [ 2.74831069e+01],
       [ 1.20008794e+01],
       [ 2.13780362e+01],
       [ 2.85444031e+01],
       [ 2.01655138e+01],
       [ 1.07966781e+01],
       [ 2.21710358e+01],
       [ 5.34462631e+01],
       [ 1.22195811e+01],
       [ 4.33009685e+01],
       [ 3.21823351e+01],
       [ 2.25672175e+01],
       [ 5.67395142e+01],
       [ 2.07450529e+01],
       [ 1.50288546e+01],
       [ 3.98553016e+01],
       [ 1.29753407e+01],
       [ 5.17416596e+01],
       [ 1.87833696e+01],
       [ 1.23487528e+01],
       [ 1.56336237e+01],
       [-5.88714707e-02],
       [ 4.15080111e+01],
       [ 3.15487475e+01],
       [ 1.86042512e+01],
       [ 3.74768197e+01],
       [ 5.65203907e+01],
       [ 6.58787719e+00],
       [ 1.22293397e+01],
       [ 5.20369640e+00],
       [ 4.79273751e+01],
       [ 1.30207057e+01],
       [ 1.71103017e+01],
       [ 2.06032345e+01],
       [ 2.12844816e+01],
       [ 3.86929353e+01],
       [ 3.00207167e+01],
       [ 8.87674067e+01],
       [ 3.59847002e+01],
       [ 2.67569136e+01],
       [ 2.39635168e+01],
       [ 3.27472428e+01],
       [ 2.21890438e+01],
       [ 2.09921589e+01],
       [ 2.95559943e+01],
       [ 4.09921689e+01],
       [ 8.62511781e+00],
       [ 3.23214718e+01],
       [ 4.65980444e+01],
       [ 2.28840708e+01],
       [ 3.15181297e+01],
       [ 1.11982335e+01],
       [ 2.85274366e+01],
       [ 2.91150680e-01],
       [ 1.79669611e+01],
       [ 2.71241639e+01],
       [ 1.13982328e+01],
       [ 1.64264269e+01],
       [ 2.34252610e+01],
       [ 4.06160827e+01],
       [ 2.58641250e+01],
       [ 5.42273695e+00],
       [ 1.07949211e+01],
       [ 7.28621369e+01],
       [ 4.80228371e+01],
       [ 1.57468083e+01],
       [ 2.46704106e+01],
       [ 1.28277933e+01],
       [ 1.01580576e+01],
       [ 2.72692233e+01],
       [ 2.92087386e+01],
       [ 8.83533962e+00],
       [ 2.00510881e+01],
       [ 2.02123337e+01],
       [ 7.99060093e+01],
       [ 1.80616143e+01],
       [ 3.05428093e+01],
       [ 2.59807924e+01],
       [ 5.21257727e+00],
       [ 3.03556973e+01],
       [ 7.76832289e+00],
       [ 1.53282683e+01],
       [ 2.26663657e+01],
       [ 6.27420542e+01],
       [ 1.89507804e+01],
       [ 1.90763556e+01],
       [ 6.13715741e+01],
       [ 1.58845621e+01],
       [ 1.34094181e+01],
       [ 8.48772484e-01],
       [ 7.83499672e+00],
       [ 5.70128290e+01],
       [ 2.56079968e+01],
       [ 4.96170473e+00],
       [ 3.64148790e+01],
       [ 2.87900067e+01],
       [ 4.91941210e+01],
       [ 4.03068699e+01],
       [ 1.33161806e+01],
       [ 2.76610119e+01],
       [ 1.71580275e+01],
       [ 4.96872626e+01],
       [ 2.30302723e+01],
       [ 3.92409365e+01],
       [ 1.31967539e+01],
       [ 5.94889370e+00],
       [ 2.58216090e+01],
       [ 8.25863421e+00],
       [ 1.91463205e+01],
       [ 4.31824865e+01],
       [ 6.71784358e+00],
       [ 3.38696152e+01],
       [ 1.53699378e+01],
       [ 1.69390450e+01],
       [ 3.78853368e+01],
       [ 1.92024845e+01],
       [ 9.05950472e+00],
       [ 1.02833996e+01],
       [ 4.86724471e+01],
       [ 3.05877162e+01],
       [ 2.47740990e+00],
       [ 1.28116039e+01],
       [ 7.03247898e+01],
       [ 1.48409677e+01],
       [ 6.88655876e+01],
       [ 4.27419924e+01],
       [ 2.40002615e+01],
       [ 2.34207249e+01],
       [ 6.16721244e+01],
       [ 2.54942028e+01],
       [ 1.90048098e+01],
       [ 3.48866829e+01],
       [ 9.40231340e+00],
       [ 2.95200113e+01],
       [ 1.45739659e+01],
       [ 9.12556314e+00],
       [ 5.28125840e+01],
       [ 4.50395380e+01],
       [ 1.74524347e+01],
       [ 3.84939353e+01],
       [ 2.70389191e+01],
       [ 6.55817097e+01],
       [ 7.03730638e+00],
       [ 5.27144771e+01],
       [ 3.82064593e+01],
       [ 2.11698011e+01],
       [ 3.02475569e+01],
       [ 2.71442299e+00],
       [ 1.99329326e+01],
       [-3.41333234e+00],
       [ 3.24459994e+01],
       [ 1.05829730e+01],
       [ 2.17752257e+01],
       [ 6.24652921e+01],
       [ 2.41329437e+01],
       [ 2.62012396e+01],
       [ 6.37444772e+01],
       [ 2.83429777e+00],
       [ 1.43792470e+01],
       [ 9.36985073e+00],
       [ 9.88116661e+00],
       [ 3.49494536e+00],
       [ 1.22608049e+02],
       [ 2.10835130e+01],
       [ 1.75322206e+01],
       [ 2.01830983e+01],
       [ 3.63931322e+01],
       [ 3.49351512e+01],
       [ 1.88303127e+01],
       [ 3.83445555e+01],
       [ 7.79166341e+01],
       [ 1.79532355e+00],
       [ 1.34458279e+01],
       [ 3.61311556e+01],
       [ 1.51504035e+01],
       [ 1.29418483e+01],
       [ 1.13125241e+02],
       [ 1.52246047e+01],
       [ 1.48240260e+01],
       [ 5.92673537e+01],
       [ 1.05836953e+01],
       [ 2.09930626e+01],
       [ 9.78936588e+00],
       [ 4.77118001e+00],
       [ 4.79278069e+01],
       [ 1.23994384e+01],
       [ 4.81464766e+01],
       [ 4.04663804e+01],
       [ 1.69405903e+01],
       [ 4.12665445e+01],
       [ 6.90278920e+01],
       [ 4.03462492e+01],
       [ 1.43137440e+01],
       [ 1.57707266e+01]])

保存预测文件

import csv
with open('work/submit.csv', mode='w', newline='') as submit_file:
    csv_writer = csv.writer(submit_file)
    header = ['id', 'value']
    print(header)
    csv_writer.writerow(header)
    for i in range(240):
        row = ['id_' + str(i), ans_y[i][0]]
        csv_writer.writerow(row)
        print(row)
['id', 'value']
['id_0', 5.174960398984736]
['id_1', 18.30621425352788]
['id_2', 20.491218094180528]
['id_3', 11.523942869805332]
['id_4', 26.616056752306132]
['id_5', 20.531348081761223]
['id_6', 21.906551018797376]
['id_7', 31.736468747068834]
['id_8', 13.391674055111736]
['id_9', 64.45646650291954]
['id_10', 20.264568836159437]
['id_11', 15.35857607736122]
['id_12', 68.58947276926726]
['id_13', 48.428113747457196]
['id_14', 18.702333824193218]
['id_15', 10.188595737466716]
['id_16', 30.74036285982044]
['id_17', 71.13221776355108]
['id_18', -4.130517391262444]
['id_19', 18.235694016428695]
['id_20', 38.578922275007756]
['id_21', 71.31151972531332]
['id_22', 7.410348162634086]
['id_23', 18.717955330321395]
['id_24', 14.93725026008458]
['id_25', 36.719736694705325]
['id_26', 17.96169700566271]
['id_27', 75.78946287210539]
['id_28', 12.309310248614484]
['id_29', 56.2953517396496]
['id_30', 25.113160865661484]
['id_31', 4.610248674094053]
['id_32', 2.4837705545150315]
['id_33', 24.75942226132128]
['id_34', 30.480280465591196]
['id_35', 38.46393074642666]
['id_36', 44.20231060933005]
['id_37', 30.08683601986601]
['id_38', 40.47367501574008]
['id_39', 29.22647990231738]
['id_40', 5.606456054343949]
['id_41', 38.666016078789596]
['id_42', 34.61021343187721]
['id_43', 48.38969750738482]
['id_44', 14.757247666944172]
['id_45', 34.46682011087208]
['id_46', 27.48310687418436]
['id_47', 12.000879378154043]
['id_48', 21.378036151603794]
['id_49', 28.54440309166328]
['id_50', 20.16551381841159]
['id_51', 10.796678149746501]
['id_52', 22.171035755750125]
['id_53', 53.446263109352266]
['id_54', 12.21958112161002]
['id_55', 43.30096845517155]
['id_56', 32.1823351032854]
['id_57', 22.5672175145708]
['id_58', 56.73951416554704]
['id_59', 20.745052945295473]
['id_60', 15.028854557473265]
['id_61', 39.8553015903851]
['id_62', 12.975340680728284]
['id_63', 51.74165959283004]
['id_64', 18.783369632539877]
['id_65', 12.348752842777712]
['id_66', 15.633623653541925]
['id_67', -0.05887147068500154]
['id_68', 41.50801107307596]
['id_69', 31.548747530656026]
['id_70', 18.604251157547075]
['id_71', 37.4768197248807]
['id_72', 56.52039065762305]
['id_73', 6.58787719352195]
['id_74', 12.229339737435051]
['id_75', 5.203696404134638]
['id_76', 47.92737510380059]
['id_77', 13.020705685594661]
['id_78', 17.110301693903597]
['id_79', 20.603234531002048]
['id_80', 21.284481560784613]
['id_81', 38.69293529051181]
['id_82', 30.020716675725847]
['id_83', 88.76740666723548]
['id_84', 35.984700239668264]
['id_85', 26.756913553477187]
['id_86', 23.963516843564403]
['id_87', 32.747242828083074]
['id_88', 22.18904375531994]
['id_89', 20.992158853626545]
['id_90', 29.555994316645446]
['id_91', 40.99216886651781]
['id_92', 8.625117809911558]
['id_93', 32.3214718088779]
['id_94', 46.59804436536759]
['id_95', 22.88407082672354]
['id_96', 31.518129728251655]
['id_97', 11.19823347976612]
['id_98', 28.527436642529608]
['id_99', 0.2911506800896443]
['id_100', 17.96696107953969]
['id_101', 27.124163929470143]
['id_102', 11.398232780652847]
['id_103', 16.426426865673527]
['id_104', 23.42526104692219]
['id_105', 40.6160826705684]
['id_106', 25.8641250265604]
['id_107', 5.422736951672389]
['id_108', 10.794921122256104]
['id_109', 72.86213692992126]
['id_110', 48.022837059481375]
['id_111', 15.746808276902996]
['id_112', 24.67041061417795]
['id_113', 12.827793326536716]
['id_114', 10.158057570240526]
['id_115', 27.269223342020982]
['id_116', 29.208738577932458]
['id_117', 8.835339619930767]
['id_118', 20.05108813712978]
['id_119', 20.212333743764248]
['id_120', 79.9060092987056]
['id_121', 18.061614288263595]
['id_122', 30.542809341304345]
['id_123', 25.98079237772804]
['id_124', 5.212577268164767]
['id_125', 30.355697305856214]
['id_126', 7.768322888914637]
['id_127', 15.328268255393336]
['id_128', 22.66636571769797]
['id_129', 62.742054211090085]
['id_130', 18.950780367987996]
['id_131', 19.076355630838545]
['id_132', 61.37157409163711]
['id_133', 15.884562052629718]
['id_134', 13.409418077705558]
['id_135', 0.8487724836112842]
['id_136', 7.834996717304126]
['id_137', 57.01282901179679]
['id_138', 25.607996751813804]
['id_139', 4.9617047292420855]
['id_140', 36.414879039062775]
['id_141', 28.790006721975917]
['id_142', 49.19412096197634]
['id_143', 40.3068698557345]
['id_144', 13.316180593982658]
['id_145', 27.661011875229164]
['id_146', 17.158027524366766]
['id_147', 49.68726256929682]
['id_148', 23.03027229160478]
['id_149', 39.240936524842766]
['id_150', 13.19675388941254]
['id_151', 5.948893701039413]
['id_152', 25.82160897630425]
['id_153', 8.258634214291634]
['id_154', 19.146320517225597]
['id_155', 43.18248652651674]
['id_156', 6.717843578093033]
['id_157', 33.869615246810646]
['id_158', 15.3699378469818]
['id_159', 16.939044973551923]
['id_160', 37.88533679463485]
['id_161', 19.202484541054467]
['id_162', 9.059504715654725]
['id_163', 10.283399610648509]
['id_164', 48.672447125698284]
['id_165', 30.58771621323082]
['id_166', 2.4774098975321657]
['id_167', 12.811603937805932]
['id_168', 70.32478980976464]
['id_169', 14.840967694067068]
['id_170', 68.8655875667886]
['id_171', 42.74199244486634]
['id_172', 24.000261542920168]
['id_173', 23.420724860321446]
['id_174', 61.672124435682356]
['id_175', 25.494202845059192]
['id_176', 19.004809786869096]
['id_177', 34.88668288189683]
['id_178', 9.40231339837975]
['id_179', 29.520011314408027]
['id_180', 14.573965885700483]
['id_181', 9.125563143203598]
['id_182', 52.81258399813187]
['id_183', 45.03953799438962]
['id_184', 17.452434679183295]
['id_185', 38.49393527971433]
['id_186', 27.03891909264382]
['id_187', 65.58170967424583]
['id_188', 7.0373063807695795]
['id_189', 52.71447713411572]
['id_190', 38.20645933704977]
['id_191', 21.16980105955784]
['id_192', 30.247556879488393]
['id_193', 2.714422989716304]
['id_194', 19.93293258764082]
['id_195', -3.413332337603944]
['id_196', 32.44599940281316]
['id_197', 10.582973029979941]
['id_198', 21.77522570725845]
['id_199', 62.465292065677886]
['id_200', 24.13294368731649]
['id_201', 26.201239647400964]
['id_202', 63.74447723440287]
['id_203', 2.83429777412905]
['id_204', 14.37924698697884]
['id_205', 9.369850731753894]
['id_206', 9.881166613595411]
['id_207', 3.4949453589721426]
['id_208', 122.6080493792178]
['id_209', 21.083513014480573]
['id_210', 17.53222059945511]
['id_211', 20.183098344597003]
['id_212', 36.39313221228185]
['id_213', 34.93515120529068]
['id_214', 18.83031266145864]
['id_215', 38.34455552272332]
['id_216', 77.91663413807038]
['id_217', 1.7953235508882215]
['id_218', 13.445827939135775]
['id_219', 36.131155590412135]
['id_220', 15.150403498166307]
['id_221', 12.941848334417926]
['id_222', 113.12524093786391]
['id_223', 15.224604677934382]
['id_224', 14.824025968612034]
['id_225', 59.267353688540446]
['id_226', 10.583695290718481]
['id_227', 20.993062563532174]
['id_228', 9.789365880830381]
['id_229', 4.77118000870597]
['id_230', 47.92780690481291]
['id_231', 12.399438394751039]
['id_232', 48.14647656264414]
['id_233', 40.46638039656415]
['id_234', 16.94059027033294]
['id_235', 41.26654448941875]
['id_236', 69.02789203372899]
['id_237', 40.34624924412241]
['id_238', 14.313743982871129]
['id_239', 15.770726634219828]

作者简介

作者:三岁
经历:自学python,现在混迹于paddle社区,希望和大家一起从基础走起,一起学习Paddle
csdn地址:blog.csdn.net/weixin_4562…
我在AI Studio上获得钻石等级,点亮7个徽章,来互关呀~ aistudio.baidu.com/aistudio/pe…

传说中的飞桨社区最菜代码人,让我们一起努力!
记住:三岁出品必是精品 (不要脸系列