MNIST neural network with restricted (sparsified) weights
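
A two-layer MNIST classifier written with JAX. After every gradient step, each column of W1 and W2 is restricted to its single largest and single smallest entry (everything in between is zeroed); the resulting sparse weights are saved to Google Drive as CSV files and visualized as 28x28 images.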

import jax
import jax.numpy as np
import numpy as onp

#import random
import itertools

from sklearn.datasets import fetch_openml

# as_frame=False keeps data and target as NumPy arrays; the pandas DataFrame that newer
# scikit-learn versions return by default would break the column indexing used below.
mnist = fetch_openml('mnist_784', data_home='mnist_784', as_frame=False)

X, y = mnist["data"], mnist["target"].astype(np.float32)

X = X / 255

digits = 10
examples = y.shape[0]

y = y.reshape(1, examples)

# One-hot encode the labels; Y_new ends up with shape (digits, examples).
Y_new = np.eye(digits)[y.astype('int32')]
Y_new = Y_new.T.reshape(digits, examples)
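
# Standard MNIST split: the first 60,000 examples for training, the remaining 10,000 for test.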

m = 60000
m_test = X.shape[0] - m

X_train, X_test = X[:m].T, X[m:].T
Y_train, Y_test = Y_new[:,:m], Y_new[:,m:]

onp.random.seed(138)

shuffle_index = onp.random.permutation(m)
X_train, Y_train = X_train[:, shuffle_index], Y_train[:, shuffle_index]





def sigmoid(x): return 1/(1+np.exp(-x))

def net(params, x):
  # Forward pass: sigmoid hidden layer followed by a softmax output layer.
  # Shapes: w1 (64, 784), x (784, batch) -> a1 (64, batch); w2 (10, 64) -> a2 (10, batch).
  w1, b1, w2, b2 = params
  z1 = np.dot(w1, x) + b1
  a1 = sigmoid(z1)
  z2 = np.dot(w2, a1) + b2
  a2 = np.exp(z2) / np.sum(np.exp(z2), axis=0)
  # print(x.shape, a2.shape)  # debug output; prints traced shapes under jax.grad
  return a2

def loss(params, x, y):
  # Cross-entropy between the softmax output and the one-hot targets, averaged over the batch.
  out = net(params, x)
  L_sum = np.sum(np.multiply(y, np.log(out)))
  m = y.shape[1]
  L = -(1 / m) * L_sum
  return L

# Leftover helper from an XOR example (note the bitwise_xor); it is not used for MNIST below.
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])
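
# Network size: 784 input pixels -> 64 hidden units -> 10 digit classes,
# trained by full-batch gradient descent with learning rate 1.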

n_x = X_train.shape[0]
n_h = 64
learning_rate = 1

W1 = onp.random.randn(n_h, n_x)
b1 = np.zeros((n_h, 1))
W2 = onp.random.randn(digits, n_h)
b2 = np.zeros((digits, 1))

def initial_params():
    return [
        W1,  # w1
        b1,  # b1
        W2,  # w2
        b2,  #b2
    ]
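
# jax.grad(loss) returns a function that computes the gradient of the loss with
# respect to its first argument, params.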

loss_grad = jax.grad(loss)

inputs = X_train  # unused below

params = initial_params()

def topNSetToZero(topN, W):
  # For every column of W, keep only the topN largest and topN smallest entries and set
  # everything in between to zero. W must be a mutable NumPy array; it is modified in place.
  n_cols = W.shape[1]
  for i in range(n_cols):
    w1k = W[:, i]
    betweenNIndexList = w1k.argsort()[topN:len(w1k) - topN]
    w1k[betweenNIndexList] = 0.0
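
# For example, topNSetToZero(1, W1) keeps only the largest and the smallest entry in
# each of W1's 784 columns (one pair of weights per input pixel) and zeroes the other 62.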

root_dir="/content/drive/My Drive/mnist_using_jax_model_csv/"

import numpy as numpy_real
from datetime import datetime
import os
def save_model_csv(dt_str):
  # Save the module-level W1/b1/W2/b2 as CSV files under root_dir/<dt_str>/csv/.
  dt_dir = "%s/%s/csv/" % (root_dir, dt_str)
  os.makedirs(dt_dir)
  numpy_real.savetxt("%s/W1.csv" % dt_dir, W1, delimiter=",")
  numpy_real.savetxt("%s/b1.csv" % dt_dir, b1, delimiter=",")
  numpy_real.savetxt("%s/W2.csv" % dt_dir, W2, delimiter=",")
  numpy_real.savetxt("%s/b2.csv" % dt_dir, b2, delimiter=",")
  print(dt_dir)
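
To read the saved CSVs back into parameter arrays later, a sketch along these lines should work (load_model_csv is a hypothetical helper, not defined elsewhere in this post):

def load_model_csv(dt_str):
  dt_dir = "%s/%s/csv/" % (root_dir, dt_str)
  # loadtxt returns 1-D arrays for the single-column bias files, hence the reshape.
  W1_l = numpy_real.loadtxt("%s/W1.csv" % dt_dir, delimiter=",")
  b1_l = numpy_real.loadtxt("%s/b1.csv" % dt_dir, delimiter=",").reshape(-1, 1)
  W2_l = numpy_real.loadtxt("%s/W2.csv" % dt_dir, delimiter=",")
  b2_l = numpy_real.loadtxt("%s/b2.csv" % dt_dir, delimiter=",").reshape(-1, 1)
  return [W1_l, b1_l, W2_l, b2_l]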

def save_w1k(w1k, file_full_path):
    # Render one 28x28 weight map as an RGB image:
    # red = positive weight, green = negative weight, black = zero (pruned).
    print("w1k.shape", w1k.shape)
    from PIL import Image
    w1kImg = numpy_real.expand_dims(w1k, axis=2)
    w1kImg = w1kImg.repeat([3], axis=2)
    w1kImg[w1k == 0.0] = [0, 0, 0]
    w1kImg[w1k > 0.0] = [255, 0, 0]
    w1kImg[w1k < 0.0] = [0, 255, 0]
    w1k_rgb = numpy_real.array(w1kImg.reshape(28, 28, 3), dtype=numpy_real.uint8)
    img = Image.fromarray(w1k_rgb, 'RGB')
    img.save(file_full_path)
    # img.show()

def save_W(dt_str):
  # Save each row of W1 (one hidden unit's 784 weights) as a 28x28 CSV and as a PNG.
  dt_dir = "%s/%s/W1_28_28/" % (root_dir, dt_str)
  dt_img_dir = "%s/%s/W1_28_28_img/" % (root_dir, dt_str)
  os.makedirs(dt_dir)
  os.makedirs(dt_img_dir)
  for i in range(W1.shape[0]):
    w1k = W1[i, :].reshape(28, 28)
    numpy_real.savetxt("%s/W1_%s.csv" % (dt_dir, i), w1k, delimiter=",")
    save_w1k(w1k, "%s/W1_%s.png" % (dt_img_dir, i))

def imgI_w1k_merge():
  # For each digit class i, take the rows of W1 whose outgoing weight W2[i, k] survived the
  # pruning, and collect the surviving first-layer weights pixel by pixel (784 pixels).
  imgI_w1k_merged_ls = []
  for i in range(10):
    imgI_w1k_ls = W1[W2[i, :] != 0, :]
    imgI_w1k_merged = [[] for _ in range(imgI_w1k_ls.shape[1])]
    for t in range(imgI_w1k_ls.shape[0]):
      for q in range(imgI_w1k_ls.shape[1]):  # q runs over the 784 pixels
        if imgI_w1k_ls[t, q] != 0:
          imgI_w1k_merged[q].append(imgI_w1k_ls[t, q])
    imgI_w1k_merged_ls.append(imgI_w1k_merged)
  return imgI_w1k_merged_ls

def imgI_w1k_alignment_sort(imgI_w1k_merged_ls):
  # Pad every per-pixel list to exactly two values (filling with zeros) and sort it, so that
  # index 0 holds the smaller (typically negative) weight and index 1 the larger (typically
  # positive) one. topNSetToZero(1, W1) guarantees at most two nonzero weights per pixel.
  for i in range(10):
    for t in range(784):
      if len(imgI_w1k_merged_ls[i][t]) == 0:
        imgI_w1k_merged_ls[i][t] = [0, 0]
      elif len(imgI_w1k_merged_ls[i][t]) == 1:
        imgI_w1k_merged_ls[i][t].append(0)
      imgI_w1k_merged_ls[i][t] = numpy_real.sort(imgI_w1k_merged_ls[i][t])
  return imgI_w1k_merged_ls

def imgI_w1k_split_then_save(imgI_w1k_merged_ls, dt_str):
  # Split each per-pixel pair into a "neg" image (the smaller weight) and a "positive" image
  # (the larger weight) for each digit class, and save both sets of 28x28 maps as PNGs.
  dt_merge_W2_img_neg_dir = "%s/%s/W1_28_28_merge_W2_img_neg/" % (root_dir, dt_str)
  dt_merge_W2_img_positive_dir = "%s/%s/W1_28_28_merge_W2_img_positive/" % (root_dir, dt_str)
  os.makedirs(dt_merge_W2_img_neg_dir, exist_ok=True)
  os.makedirs(dt_merge_W2_img_positive_dir, exist_ok=True)
  imgI_w1k_merged_ls_nparr = numpy_real.asarray(imgI_w1k_merged_ls)
  print(imgI_w1k_merged_ls_nparr.shape)  # (10, 784, 2)
  neg, positive = numpy_real.split(imgI_w1k_merged_ls_nparr, 2, axis=2)
  print("before neg.shape,positive.shape", neg.shape, positive.shape)
  neg = neg.reshape((10, 784))
  positive = positive.reshape((10, 784))
  print("neg.shape,positive.shape", neg.shape, positive.shape)
  for i in range(10):
    save_w1k(neg[i].reshape((28, 28)), '%s/%s.png' % (dt_merge_W2_img_neg_dir, i))
    save_w1k(positive[i].reshape((28, 28)), '%s/%s.png' % (dt_merge_W2_img_positive_dir, i))
  print("imgI_w1k_split_then_save end")
  return (neg, positive)

from google.colab import drive
# drive.flush_and_unmount()
drive.mount('/content/drive', force_remount=True)

# !chmod  755 /content/drive
# !mkdir /content/drive/My\ Drive/mnist_using_jax_model_csv

for n in itertools.count():
  grads = loss_grad(params, X_train, Y_train)
  params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
  # Apply the sparsity constraint to the *updated* weights and sync them back into the
  # module-level arrays, so both training and the save_* helpers below see the pruned,
  # trained parameters (topNSetToZero mutates NumPy arrays in place).
  W1, b1, W2, b2 = [onp.array(p) for p in params]
  topNSetToZero(1, W1)
  topNSetToZero(1, W2)
  params = [W1, b1, W2, b2]
  print(loss(params, X_train, Y_train))
  if n > 100:
    break

dt_str=datetime.now().strftime( '%Y%m%d%H%M%S' )
save_model_csv(dt_str)    
save_W(dt_str)



imgI_w1k_merged_ls=imgI_w1k_merge()
imgI_w1k_merged_ls=imgI_w1k_alignment_sort(imgI_w1k_merged_ls)

neg,positive=imgI_w1k_split_then_save(imgI_w1k_merged_ls,dt_str)


# Evaluate on the 10,000 held-out test images: predicted class = argmax of the softmax output.
Y_test_predict = net(params, X_test)
Y_test_predict_idx = np.argmax(Y_test_predict, axis=0)
print(Y_test_predict_idx.shape)

Y_test_idx = np.argmax(Y_test, axis=0)
print(Y_test_idx.shape)

test_correct_count = np.sum(Y_test_predict_idx == Y_test_idx)
test_accuracy = test_correct_count / m_test
print("test_accuracy:", test_accuracy)