手写BP全连接神经网络

501 阅读1分钟

本次手写bp神经网络的优化过程,网络结构如图 image.png 这里做了一点改动,隐层和输出层都使用了sigmoid函数转换 模拟一个样本 FP 和 BP 过程。 代码如下

import numpy as np

_w = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65]
_b = [0.35, 0.65]
_x = [5, 10]
_y = [0.2, 0.5]
_lr = 0.5

def w(index):
    return _w[index - 1]
def x(index):
    return _x[index - 1]
def b(index):
    return _b[index - 1]
def y(index):
    return _y[index - 1]
def set_w(index, gd):
    _w[index - 1] = _w[index - 1] - _lr * gd
def set_b(index, gd):
    _b[index -1] = _b[index - 1] - _lr * gd
def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def training():
    # 1. FP前向过程 计算损失
    h1 = sigmoid(w(1)*x(1) + w(2)*x(2) + b(1))
    h2 = sigmoid(w(3)*x(1) + w(4)*x(2) + b(1))
    h3 = sigmoid(w(5)*x(1) + w(6)*x(2) + b(1))
    o1 = sigmoid(h1*w(7) + h2*w(9) + h3*w(11) + b(2))
    o2 = sigmoid(h1*w(8) + h2*w(10) + h3*w(12) + b(2))
    loss = 0.5 * (y(1)-o1)**2 + 0.5 * (y(2)-o2)**2

    # 2.BP反向过程 参数更新
    # 梯度的简化写法loss/w1 = loss/h1 * h1/w1 = (loss/o1*o1/h1 + loss/o2*o2/h2) * h1/w1
    t1 = (o1-y(1)) * o1 * (1-o1)
    t2 = (o2-y(2)) * o2 * (1-o2)
    set_w(1, gd=(t1 * w(7) + t2 * w(8)) * h1 * (1 - h1) * x(1))
    set_w(2, gd=(t1 * w(7) + t2 * w(8)) * h1 * (1 - h1) * x(2))
    set_w(3, gd=(t1 * w(7) + t2 * w(8)) * h2 * (1 - h2) * x(1))
    set_w(4, gd=(t1 * w(7) + t2 * w(8)) * h2 * (1 - h2) * x(2))
    set_w(5, gd=(t1 * w(7) + t2 * w(8)) * h3 * (1 - h3) * x(1))
    set_w(6, gd=(t1 * w(7) + t2 * w(8)) * h3 * (1 - h3) * x(2))
    set_w(7, t1*h1)
    set_w(8, t2 * h1)
    set_w(9, t1 * h2)
    set_w(10, t2 * h2)
    set_w(11, t1 * h3)
    set_w(12, t2 * h3)
    #更新b
    # loss/b2 = loss/o1*o1/b2 + loss/o2*o2/b2
    set_b(1, (t1*w(7))*h1*(1-h1)+(t1*w(9))*h2*(1-h2)+(t1*w(11))*h3*(1-h3)+
          (t2*w(8))*h1*(1-h1)+(t2*w(10))*h2*(1-h2)+(t2*w(12))*h3*(1-h3))
    set_b(2, t1+t2)
    return loss, o1, o2


if __name__ == '__main__':
    print('demo')
    #training()
    for i in range(10000):
        _loss = training()
    print(_loss)
    print(_w)
    print(_b)

输出结果:

image.png