本文主要对LSTM神经网络实现中国人口预测项目的数据处理部分进行讲解,主要包括数据基本处理,可视化,相关性分析,缺失值及异常值检测等。
1. 项目简介
本项目将使用PaddlePaddle框架进行机器学习实战,根据指定数据集(中国人口数据集等)使用Paddle框架搭建LSTM神经网络,包括数据预处理、模型构建、模型训练、模型预测、预测结果可视化等。
- 我们将根据中国人口数据集中的多个特征(features),例如:出生人口(万)、中国人均GPA(美元计)、中国性别比例(按照女生=100)、自然增长率(%)等8个特征字段,预测中国未来总人口(万人)这1个标签字段。属于多输入,单输出LSTM神经网路预测范畴。
- LSTM算法是一种重要的目前使用最多的时间序列算法,是一种特殊的RNN(Recurrent Neural Network,循环神经网络),能够学习长期的依赖关系。
2. 数据集介绍
本项目使用的数据集为中国人口预测数据集,包含10个字段,其中8个特征字段,1个标签字段,1个行索引字段,数据集各字段对应的数据类型如下表所示:
- 有些字段为int64类型,需要经过相关的数据处理,才可传入模型进行训练。
- 数据包含50条样本,因此应该合理确定训练数据、测试数据和验证数据
年份 | 出生人口(万) | 总人口(万人) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
---|---|---|---|---|---|---|---|---|---|
int64 | int64 | int64 | int64 | float64 | float64 | float64 | float64 | float64 | int64 |
3. 实战演练
3.1 环境准备
%matplotlib inline
# 导入 paddle
import paddle
import paddle.nn.functional as F
# 导入其他模块
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")
print(paddle.__version__)
import paddle.nn as nn
3.2 数据处理
3.2.1 导入数据
population = pd.read_csv("data/data140190/人口.csv")
population.head()
运行结果如下所示:
3.2.2 查看各字段类型
查看数据类型,因为训练时可能需要转换数据类型,才能传入神经网络。
population.dtypes
3.3 数据可视化
3.3.1 特征(features)折线图
绘制出各个特征与年份索引之间的折线图,进行初步观察
因为数据集中包含中文字段,想要能够在绘图中正常显示中文,需要进行如下设定:
from pylab import mpl
from matplotlib.font_manager import FontProperties
myfont=FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF',size=12)
sns.
绘制折线图:
from pylab import mpl
from matplotlib.font_manager import FontProperties
myfont=FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF',size=12)
sns.set(font=myfont.get_name())
titles = [
"出生人口(万)",
"总人口(万人)",
"中国人均GPA(美元计)",
"中国性别比例(按照女生=100)",
"自然增长率(%)",
"城镇人口(城镇+乡村=100)",
"乡村人口",
"美元兑换人民币汇率",
"中国就业人口(万人)",
]
feature_keys = [
"出生人口(万)",
"总人口(万人)",
"中国人均GPA(美元计)",
"中国性别比例(按照女生=100)",
"自然增长率(%)",
"城镇人口(城镇+乡村=100)",
"乡村人口",
"美元兑换人民币汇率",
"中国就业人口(万人)",
]
colors = [
"blue",
"chocolate",
"green",
"red",
"purple",
"brown",
"darkblue",
"black",
"magenta",
]
date_time_key = "年份"
def show_raw_visualization(data):
time_data = data[date_time_key]
fig, axes = plt.subplots(
nrows=3, ncols=3, figsize=(15, 15), dpi=100, facecolor="w", edgecolor="k"
)
for i in range(len(feature_keys)):
key = feature_keys[i]
c = colors[i % (len(colors))]
t_data = data[key]
t_data.index = time_data
t_data.head()
ax = t_data.plot(
ax=axes[i // 3, i % 3],
color=c,
title="{}".format(titles[i], key),
rot=25,
)
ax.legend([titles[i]])
plt.tight_layout()
show_raw_visualization(population)
部分运行结果如下图所示:
- 我们想通过机器学习(搭建LSTM神经网络)的手段对总人口变量进行预测
- 因此观察总人口变量的变化趋势折线图,可以发现总人口在时间段内的变化比较有规律,所以适用LSTM神经网络进行解决
3.3.2 箱型图
查看部分数据的分布情况,下面抽取了出生人口(万)、总人口(万人)、中国人均GPA(美元计)、中国就业人口(万人)这四个字段进行箱型图展示。
from pylab import mpl
from matplotlib.font_manager import FontProperties
myfont=FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF',size=12)
sns.set(font=myfont.get_name())
plt.figure(figsize=(15,8),dpi=100)
plt.subplot(1,4,1)
sns.boxplot(y="出生人口(万)", data=population, saturation=0.9)
plt.subplot(1,4,2)
sns.boxplot(y="总人口(万人)", data=population, saturation=0.9)
plt.subplot(1,4,3)
sns.boxplot(y="中国人均GPA(美元计)", data=population, saturation=0.9)
plt.subplot(1,4,4)
sns.boxplot(y="中国就业人口(万人)", data=population, saturation=0.9)
plt.tight_layout()
运行结果如下图所示:
- 箱型图在数据处理中通常用来观察离群点(大于上下界异常点)
- 因为本数据都是搜集到真实数据,因此大于上下界的样本不做剔除
3.3.3 相关性分析
相关性分析在数据处理中用来查看变量两两之间的相关性
- 分析特征之间的相关性:可以考虑将两个相关性较强的特征选择一个进行保留,因为本数据集特征字段不是很多,就不考虑剔除了。
- 分析目标值与特征之间的相关性:发现目标值与特征间(包括年份)的相关性都比较强,因此年份也可以作为特征传入模型进行训练。
corr = population.corr()
# 调用热力图绘制相关性关系
plt.figure(figsize=(10,10),dpi=100)
sns.heatmap(corr, square=True, linewidths=0.1, annot=True)
运行结果如下图所示:
3.4 缺失值及重复值
3.4.1 重复值检测
查看特征中是否包含重复值,返回false说明没有重复值。无需剔除。
population.duplicated().any()
3.4.2 缺失值检测
查看是否有缺失的样本。返回True说明无缺失值,无需进行额外处理。
pd.notnull(population).all()
3.4.3 转换字段类型
我们需要将int类型字段转化为float类型的字段。
- 首先使用如下语句查看数据集各个字段类型。
- 接下来我们将int64转为float64并替换原数据字段
population['出生人口(万)'] = population['出生人口(万)'].astype('float64')
population['总人口(万人)'] = population['总人口(万人)'].astype('float64')
population['中国人均GPA(美元计)'] = population['中国人均GPA(美元计)'].astype('float64')
population['中国就业人口(万人)'] = population['中国就业人口(万人)'].astype('float64')
4. 总结
本文主要完成了项目的数据处理部分,包括基本的统计学分析,数据可视化,异常值以及缺失值处理,下一节将进行数据预处理以及神经网络的搭建,从而预测中国人口。
本文正在参加「金石计划 . 瓜分6万现金大奖」