使用三层简单的神经网络:拟合二次函数、多项式、sin、指数函数

589 阅读1分钟

使用最简单的自定义三层网络,仅用3个节点,尝试拟合各种函数。效果还行,肯定还有更多优化的余地,仅供入门参考。

包引入

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import MinMaxScaler

函数定义


function_select = 5

def myfun(x):
    functions = {
        1: np.power(x-7,2), # 二次函数
        2: np.sin(x), # sin
        3: np.sign(x), # signum
        4: np.exp(x), # 指数
        5: np.power(x,3) - 3*np.power(x,2) + 5, # 多项式
        6: 1+np.power(x,2)/4000-np.cos(x) # 格里旺克函数
    }
    return functions.get(function_select)

构建模型

activation_function = 'tanh'

def build_model(train_data, labels, units, epochs):
    #print(train_data.shape)
    model = keras.Sequential()
    model.add(keras.layers.Dense(units, input_dim=train_data.shape[1], kernel_initializer='he_normal', activation=activation_function))
    model.add(keras.layers.Dense(1, kernel_initializer='he_normal', activation='linear'))
    # Compile model
    model.compile(optimizer='adam', 
                  loss='mse',
                  metrics=['mse'])
    
    # 训练模型
    model.fit(train_data, labels, epochs=epochs, batch_size=50, verbose=0)
    return model

训练数据

batch_size = 20
x_train = np.linspace(-10, 10, num=300).reshape(-1,1)

# 计算真实样本y
y_train = myfun(x_train)
# 正规化
x_scaler = MinMaxScaler(feature_range=(-1, 1))
y_scaler = MinMaxScaler(feature_range=(-1, 1))
x_scaled = x_scaler.fit_transform(x_train)
y_scaled = y_scaler.fit_transform(y_train)

开始训练模型

units = 3
epochs = 2000
model_best = build_model(train_data=x_scaled, labels=y_scaled, units=units, epochs=epochs)

测试

#测试集
x_eval = np.linspace(-8, 5, num=40).reshape(-1,1)
x_eval_scaled = x_scaler.transform(x_eval)

result = model_best.predict(x_eval_scaled, batch_size=50)
predictions = y_scaler.inverse_transform(result)

画图

fig = plt.figure(1, figsize=(20,10))
ax = fig.add_subplot(1, 2, 1)

plt.plot(x_eval, predictions, '.', color='red', linewidth=2.0)
plt.plot(x_eval, myfun(x_eval), '-', color='blue', linewidth=1.0)
plt.plot(x_train, myfun(x_train), '-', color='gray', linewidth=1.0)

ax = fig.add_subplot(1, 2, 2)
plt.plot(x_eval, np.abs(predictions-myfun(x_eval)), '-', label='output', color='firebrick', linewidth=2.0)
plt.show()

多项式

二次函数

sin

注意:针对sin的特点,你需要将训练数据的密集度调整好,否则可能不能很好的拟合。