【机器学习】LSTM神经网络实现中国人口预测(1)

1,750 阅读6分钟

本文主要对LSTM神经网络实现中国人口预测项目的数据处理部分进行讲解,主要包括数据基本处理,可视化,相关性分析,缺失值及异常值检测等。

1. 项目简介

本项目将使用PaddlePaddle框架进行机器学习实战,根据指定数据集(中国人口数据集等)使用Paddle框架搭建LSTM神经网络,包括数据预处理、模型构建、模型训练、模型预测、预测结果可视化等。

  • 我们将根据中国人口数据集中的多个特征(features),例如:出生人口(万)、中国人均GPA(美元计)、中国性别比例(按照女生=100)、自然增长率(%)等8个特征字段,预测中国未来总人口(万人)这1个标签字段。属于多输入,单输出LSTM神经网路预测范畴。
  • LSTM算法是一种重要的目前使用最多的时间序列算法,是一种特殊的RNN(Recurrent Neural Network,循环神经网络),能够学习长期的依赖关系。

2. 数据集介绍

本项目使用的数据集为中国人口预测数据集,包含10个字段,其中8个特征字段,1个标签字段,1个行索引字段,数据集各字段对应的数据类型如下表所示:

  • 有些字段为int64类型,需要经过相关的数据处理,才可传入模型进行训练。
  • 数据包含50条样本,因此应该合理确定训练数据、测试数据和验证数据
年份出生人口(万)总人口(万人)中国人均GPA(美元计)中国性别比例(按照女生=100)自然增长率(%)城镇人口(城镇+乡村=100)乡村人口美元兑换人民币汇率中国就业人口(万人)
int64int64int64int64float64float64float64float64float64int64

3. 实战演练

3.1 环境准备

%matplotlib inline

# 导入 paddle
import paddle
import paddle.nn.functional as F

# 导入其他模块
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")
print(paddle.__version__)
import paddle.nn as nn

3.2 数据处理

3.2.1 导入数据

population = pd.read_csv("data/data140190/人口.csv")
population.head()

运行结果如下所示:

image.png

3.2.2 查看各字段类型

查看数据类型,因为训练时可能需要转换数据类型,才能传入神经网络。

population.dtypes

image.png

3.3 数据可视化

3.3.1 特征(features)折线图

绘制出各个特征与年份索引之间的折线图,进行初步观察

因为数据集中包含中文字段,想要能够在绘图中正常显示中文,需要进行如下设定:

from pylab import mpl
from matplotlib.font_manager import FontProperties
myfont=FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF',size=12)
sns.

绘制折线图:


from pylab import mpl
from matplotlib.font_manager import FontProperties
myfont=FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF',size=12)
sns.set(font=myfont.get_name())




titles = [
    "出生人口(万)",
    "总人口(万人)",
    "中国人均GPA(美元计)",
    "中国性别比例(按照女生=100)",
    "自然增长率(%)",
    "城镇人口(城镇+乡村=100)",
    "乡村人口",
    "美元兑换人民币汇率",
    "中国就业人口(万人)",
]

feature_keys = [
    "出生人口(万)",
    "总人口(万人)",
    "中国人均GPA(美元计)",
    "中国性别比例(按照女生=100)",
    "自然增长率(%)",
    "城镇人口(城镇+乡村=100)",
    "乡村人口",
    "美元兑换人民币汇率",
    "中国就业人口(万人)",
]

colors = [
    "blue",
    "chocolate",
    "green",
    "red",
    "purple",
    "brown",
    "darkblue",
    "black",
    "magenta",
]

date_time_key = "年份"


def show_raw_visualization(data):
    time_data = data[date_time_key]
    fig, axes = plt.subplots(
        nrows=3, ncols=3, figsize=(15, 15), dpi=100, facecolor="w", edgecolor="k"
    )
    for i in range(len(feature_keys)):
        key = feature_keys[i]
        c = colors[i % (len(colors))]
        t_data = data[key]
        t_data.index = time_data
        t_data.head()
        
        ax = t_data.plot(
            ax=axes[i // 3, i % 3],
            color=c,
            title="{}".format(titles[i], key),
            rot=25,
        )
        ax.legend([titles[i]])
    plt.tight_layout()


show_raw_visualization(population)

部分运行结果如下图所示:

  • 我们想通过机器学习(搭建LSTM神经网络)的手段对总人口变量进行预测
  • 因此观察总人口变量的变化趋势折线图,可以发现总人口在时间段内的变化比较有规律,所以适用LSTM神经网络进行解决

image.png

3.3.2 箱型图

查看部分数据的分布情况,下面抽取了出生人口(万)、总人口(万人)、中国人均GPA(美元计)、中国就业人口(万人)这四个字段进行箱型图展示。


from pylab import mpl
from matplotlib.font_manager import FontProperties
myfont=FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF',size=12)
sns.set(font=myfont.get_name())
plt.figure(figsize=(15,8),dpi=100)
plt.subplot(1,4,1)
sns.boxplot(y="出生人口(万)", data=population, saturation=0.9)
plt.subplot(1,4,2)
sns.boxplot(y="总人口(万人)", data=population, saturation=0.9)
plt.subplot(1,4,3)
sns.boxplot(y="中国人均GPA(美元计)", data=population, saturation=0.9)
plt.subplot(1,4,4)
sns.boxplot(y="中国就业人口(万人)", data=population, saturation=0.9)
plt.tight_layout()

运行结果如下图所示:

  • 箱型图在数据处理中通常用来观察离群点(大于上下界异常点)
  • 因为本数据都是搜集到真实数据,因此大于上下界的样本不做剔除

image.png

3.3.3 相关性分析

相关性分析在数据处理中用来查看变量两两之间的相关性

  • 分析特征之间的相关性:可以考虑将两个相关性较强的特征选择一个进行保留,因为本数据集特征字段不是很多,就不考虑剔除了。
  • 分析目标值与特征之间的相关性:发现目标值与特征间(包括年份)的相关性都比较强,因此年份也可以作为特征传入模型进行训练。
corr = population.corr()
# 调用热力图绘制相关性关系
plt.figure(figsize=(10,10),dpi=100)
sns.heatmap(corr, square=True, linewidths=0.1, annot=True)

运行结果如下图所示:

image.png

3.4 缺失值及重复值

3.4.1 重复值检测

查看特征中是否包含重复值,返回false说明没有重复值。无需剔除。

population.duplicated().any()

3.4.2 缺失值检测

查看是否有缺失的样本。返回True说明无缺失值,无需进行额外处理。

pd.notnull(population).all()

3.4.3 转换字段类型

我们需要将int类型字段转化为float类型的字段。

  • 首先使用如下语句查看数据集各个字段类型。
  • 接下来我们将int64转为float64并替换原数据字段
population['出生人口(万)'] = population['出生人口(万)'].astype('float64')
population['总人口(万人)'] = population['总人口(万人)'].astype('float64')
population['中国人均GPA(美元计)'] = population['中国人均GPA(美元计)'].astype('float64')
population['中国就业人口(万人)'] = population['中国就业人口(万人)'].astype('float64')

4. 总结

本文主要完成了项目的数据处理部分,包括基本的统计学分析,数据可视化,异常值以及缺失值处理,下一节将进行数据预处理以及神经网络的搭建,从而预测中国人口。

本文正在参加「金石计划 . 瓜分6万现金大奖」