使用最简单的自定义三层网络(输入层、仅含 3 个节点的隐藏层、输出层)尝试拟合各种一元函数。效果尚可,但仍有较大的优化余地,仅供入门参考。
包引入
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import MinMaxScaler
函数定义
# Key of the target function to fit; must be one of the keys of _TARGET_FUNCTIONS.
function_select = 5

# Candidate target functions, keyed by selector. They are stored as callables so
# that only the selected one is evaluated per call (no wasted work, and no
# spurious overflow warnings from e.g. np.exp on inputs that are never used).
_TARGET_FUNCTIONS = {
    1: lambda x: np.power(x - 7, 2),                        # quadratic
    2: np.sin,                                              # sine
    3: np.sign,                                             # signum
    4: np.exp,                                              # exponential
    5: lambda x: np.power(x, 3) - 3 * np.power(x, 2) + 5,   # cubic polynomial
    6: lambda x: 1 + np.power(x, 2) / 4000 - np.cos(x),     # 1-D Griewank function
}


def myfun(x, select=None):
    """Evaluate the selected target function element-wise at ``x``.

    Candidates (keys of ``_TARGET_FUNCTIONS``):
        1: (x - 7)^2
        2: sin(x)
        3: sign(x)
        4: exp(x)
        5: x^3 - 3x^2 + 5
        6: 1 + x^2 / 4000 - cos(x)   (1-D Griewank function)

    Args:
        x: scalar or NumPy array; the result has the same shape.
        select: optional key that overrides the module-level ``function_select``.

    Returns:
        The selected function applied element-wise to ``x``.

    Raises:
        ValueError: if the selected key is not one of the candidates
            (previously this silently returned ``None``).
    """
    key = function_select if select is None else select
    try:
        target = _TARGET_FUNCTIONS[key]
    except KeyError:
        raise ValueError(
            f"unknown function selection {key!r}; "
            f"expected one of {sorted(_TARGET_FUNCTIONS)}"
        ) from None
    return target(x)
构建模型
# Activation of the hidden layer in every model built by build_model.
activation_function = 'tanh'


def build_model(train_data, labels, units, epochs, batch_size=50):
    """Build, compile and train a one-hidden-layer regression network.

    Architecture: Input -> Dense(units, activation_function) -> Dense(1, linear),
    compiled with the Adam optimizer and mean-squared-error loss.

    Args:
        train_data: 2-D array of training inputs; ``train_data.shape[1]`` sets
            the input width.
        labels: training targets, one row per row of ``train_data``.
        units: number of hidden units.
        epochs: number of training epochs.
        batch_size: mini-batch size passed to ``fit`` (default 50, the value
            that used to be hard-coded).

    Returns:
        The trained ``keras.Sequential`` model.
    """
    model = keras.Sequential()
    # Declare the input shape with an explicit Input layer; passing `input_dim`
    # to a Dense layer is discouraged (and warns) in recent Keras versions.
    model.add(keras.Input(shape=(train_data.shape[1],)))
    # NOTE(review): he_normal is designed for ReLU-family activations; the usual
    # choice for tanh is glorot_uniform. Kept as-is to preserve current results.
    model.add(keras.layers.Dense(units, kernel_initializer='he_normal', activation=activation_function))
    model.add(keras.layers.Dense(1, kernel_initializer='he_normal', activation='linear'))
    model.compile(optimizer='adam',
                  loss='mse',
                  metrics=['mse'])
    # Train silently; only the fitted model is returned.
    model.fit(train_data, labels, epochs=epochs, batch_size=batch_size, verbose=0)
    return model
训练数据
# NOTE(review): this batch size appears unused in the script; the training call
# below does not pass it to build_model.
batch_size = 20
# 300 evenly spaced training inputs on [-10, 10], laid out as a (300, 1) column.
x_train = np.linspace(-10, 10, num=300)[:, np.newaxis]
# Ground-truth targets from the currently selected target function.
y_train = myfun(x_train)
# Two independent min-max scalers mapping inputs and targets into [-1, 1].
x_scaler, y_scaler = (MinMaxScaler(feature_range=(-1, 1)) for _ in range(2))
x_scaled = x_scaler.fit(x_train).transform(x_train)
y_scaled = y_scaler.fit(y_train).transform(y_train)
开始训练模型
# Hidden-layer width and number of training epochs.
units, epochs = 3, 2000
# Train on the scaled data; the resulting model is evaluated below.
model_best = build_model(
    units=units,
    epochs=epochs,
    train_data=x_scaled,
    labels=y_scaled,
)
测试
# Evaluation inputs: 40 points on [-8, 5], a sub-range of the training interval.
x_eval = np.linspace(-8, 5, num=40)[:, np.newaxis]
# Apply the same input scaling that was fitted on the training data.
x_eval_scaled = x_scaler.transform(x_eval)
# The network predicts in the scaled target space...
result = model_best.predict(x_eval_scaled, batch_size=50)
# ...so map the output back to the original y units.
predictions = y_scaler.inverse_transform(result)
画图
# Left panel: predictions vs. the true function; right panel: absolute error.
fig = plt.figure(1, figsize=(20, 10))

ax = fig.add_subplot(1, 2, 1)
# Draw the full-range target first. Lines are painted in insertion order, so the
# gray curve drawn last used to cover the (identical) blue evaluation-range curve.
ax.plot(x_train, myfun(x_train), '-', color='gray', linewidth=1.0, label='target (training range)')
ax.plot(x_eval, myfun(x_eval), '-', color='blue', linewidth=1.0, label='target (evaluation range)')
# Marker-only format: a linewidth would have no effect here, so none is passed.
ax.plot(x_eval, predictions, '.', color='red', label='prediction')
ax.set_title('Model fit')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()

ax = fig.add_subplot(1, 2, 2)
ax.plot(x_eval, np.abs(predictions - myfun(x_eval)), '-', label='absolute error', color='firebrick', linewidth=2.0)
ax.set_title('|prediction - target|')
ax.set_xlabel('x')
# Without a legend call, the curve label was never displayed.
ax.legend()

plt.show()
多项式
二次函数
sin
注意:针对 sin 函数的周期性特点,需要调整好训练数据的采样密度,否则可能无法很好地拟合。