
import jax
import jax.numpy as np
import numpy as onp
import itertools
from sklearn.datasets import fetch_openml
# Load MNIST and lay it out with one example per column: X_train / X_test are
# (784, n) and Y_train / Y_test are one-hot (10, n) matrices.
# as_frame=False: recent scikit-learn versions return a pandas DataFrame/Series
# by default, which breaks the NumPy-style `.reshape` and `[:, index]`
# operations below.
mnist = fetch_openml('mnist_784', data_home='mnist_784', as_frame=False)
X, y = mnist["data"], mnist["target"].astype(np.float32)
X = X / 255  # presumably raw 0-255 pixel intensities, scaled to [0, 1]
digits = 10
examples = y.shape[0]
y = y.reshape(1, examples)
# One-hot encode: eye(10)[labels] is (1, examples, 10); transposing and
# dropping the singleton axis yields (10, examples).
Y_new = np.eye(digits)[y.astype('int32')]
Y_new = Y_new.T.reshape(digits, examples)
# The first 60,000 examples are used for training, the remainder for testing.
m = 60000
m_test = X.shape[0] - m
X_train, X_test = X[:m].T, X[m:].T
Y_train, Y_test = Y_new[:, :m], Y_new[:, m:]
# Shuffle the training columns with a fixed seed. The same global NumPy RNG
# is used later for the random weight initialisation.
onp.random.seed(138)
shuffle_index = onp.random.permutation(m)
X_train, Y_train = X_train[:, shuffle_index], Y_train[:, shuffle_index]
def sigmoid(x):
    """Element-wise logistic function, 1 / (1 + exp(-x))."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def net(params, x, verbose=False):
    """Forward pass of the two-layer sigmoid/softmax classifier.

    Args:
        params: [w1, b1, w2, b2]. The hidden layer is sigmoid(w1 @ x + b1) and
            the output layer is softmax(w2 @ hidden + b2).
        x: inputs with one example per column (the script passes X_train and
            X_test, which are built that way).
        verbose: when True, print the input and output shapes. This is off by
            default because `net` is also evaluated inside `jax.grad(loss)` on
            every training step.

    Returns:
        Class probabilities normalised over axis 0, i.e. one column per example.
    """
    w1, b1, w2, b2 = params
    a1 = sigmoid(np.dot(w1, x) + b1)
    z2 = np.dot(w2, a1) + b2
    # jax.nn.softmax subtracts the per-column maximum before exponentiating,
    # so large logits no longer overflow exp() and turn the output into nan
    # (the hand-written exp(z2) / sum(exp(z2)) gave inf / inf).
    out = jax.nn.softmax(z2, axis=0)
    if verbose:
        print(x.shape, out.shape)
    return out
def loss(params, x, y, eps=1e-12):
    """Mean categorical cross-entropy of `net(params, x)` against targets `y`.

    Args:
        params: [w1, b1, w2, b2], forwarded to `net`.
        x: network input, forwarded to `net`.
        y: target probabilities in the same layout as the network output. The
            script passes one-hot matrices with one example per column, and
            `y.shape[1]` is taken as the number of examples.
        eps: floor applied to the predicted probabilities before the log. A
            probability that underflows to exactly 0 would otherwise give
            log(0) = -inf, and 0 * -inf = nan then poisons both the loss and
            its gradient.

    Returns:
        Scalar -(1 / m) * sum(y * log(p)) with m = y.shape[1].
    """
    out = net(params, x)
    log_out = np.log(np.maximum(out, eps))
    L_sum = np.sum(np.multiply(y, log_out))
    m = y.shape[1]
    return -(1 / m) * L_sum
def test_all_inputs(inputs, params):
    """Print a thresholded 0/1 prediction per input and compare with XOR targets.

    NOTE(review): this looks like a leftover from an XOR example and is not
    called anywhere in the visible script. It assumes `net` yields a single
    value per input (it is passed to `int(...)` after `> 0.5`) and that each
    input unpacks into the operands of `onp.bitwise_xor`. The `net` in this
    file returns a 10-way softmax over axis 0, so both assumptions look false
    here. Confirm before relying on it.

    Returns:
        True iff the list of predictions equals the list of XOR targets.
    """
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    # List-to-list comparison: a single bool, not an element-wise result.
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])
# Layer sizes and optimiser settings.
n_x = X_train.shape[0]  # input features per example (rows of X_train)
n_h = 64                # hidden-layer width
learning_rate = 1

# Weights: standard-normal draws from the global NumPy RNG (seeded earlier with
# onp.random.seed). They are kept as NumPy arrays, which support the in-place
# edits done by topNSetToZero. W1 is drawn before W2, exactly as before, so the
# same seed yields the same weights.
# NOTE(review): no fan-in scaling is applied; with n_x inputs per hidden unit
# the pre-activations may be large enough to saturate the sigmoid.
W1 = onp.random.randn(n_h, n_x)
W2 = onp.random.randn(digits, n_h)

# Biases: zero (units, 1) JAX columns that broadcast across example columns.
b1 = np.zeros((n_h, 1))
b2 = np.zeros((digits, 1))
def initial_params():
    """Return the starting parameter list ``[W1, b1, W2, b2]``.

    The module-level arrays themselves are returned (not copies), so in-place
    edits to an element of the returned list also change the global.
    """
    return [W1, b1, W2, b2]
# Gradient of `loss` with respect to its first argument, the params list.
loss_grad = jax.grad(loss, argnums=0)
# NOTE(review): `inputs` is not read anywhere in the visible script.
inputs = X_train
params = initial_params()
def topNSetToZero(topN, W):
    """Zero all but the `topN` smallest and `topN` largest entries of each column of 2-D `W`.

    `W` is modified in place and nothing is returned, so `W` must support item
    assignment: a NumPy array does, a JAX array does not. Columns with at most
    2 * topN entries are left unchanged.
    """
    # Each row of W.T is a view onto one column of W, so writes go through.
    for column in W.T:
        ranked = column.argsort()  # positions in ascending order of value
        # Drop the first and last `topN` positions; zero what remains.
        middle = ranked[topN:len(column) - topN]
        column[middle] = 0.0
# Base output directory on the mounted Google Drive; every save_* helper
# writes beneath it.
root_dir = "/content/drive/My Drive/mnist_using_jax_model_csv/"
import numpy as numpy_real
from datetime import datetime
import os
def save_model_csv(dt_str):
    """Write the module-level W1, b1, W2 and b2 as comma-separated CSV files.

    The files go into ``<root_dir>/<dt_str>/csv/``. That directory is created
    with os.makedirs without exist_ok, so a second call with the same dt_str
    raises. The directory path is printed once all four files are written.
    """
    dt_dir = "%s/%s/csv/" % (root_dir, dt_str)
    os.makedirs(dt_dir)
    outputs = (
        ("%s/W1.csv" % dt_dir, W1),
        ("%s/b1.csv" % dt_dir, b1),
        ("%s/W2.csv" % dt_dir, W2),
        ("%s/b2.csv" % dt_dir, b2),
    )
    for path, array in outputs:
        numpy_real.savetxt(path, array, delimiter=",")
    print(dt_dir)
def save_w1k(w1k, file_full_path):
    """Render a 2-D weight map as a 28x28 RGB PNG at `file_full_path`.

    Zero weights are black, positive weights red and negative weights green.
    The shape of `w1k` is printed first. The result is reshaped to (28, 28, 3),
    so `w1k` must hold 784 values in a 2-D layout (e.g. 28x28).
    """
    print("w1k.shape", w1k.shape)
    from PIL import Image
    # Three identical channels, so each pixel can be overwritten with an RGB triple.
    rgb = numpy_real.repeat(numpy_real.expand_dims(w1k, axis=2), [3], axis=2)
    # Masks are taken from the untouched input, so the assignment order is irrelevant.
    colour_rules = (
        (w1k == 0.0, [0, 0, 0]),
        (w1k > 0.0, [255, 0, 0]),
        (w1k < 0.0, [0, 255, 0]),
    )
    for mask, colour in colour_rules:
        rgb[mask] = colour
    pixels = numpy_real.array(rgb.reshape(28, 28, 3), dtype=numpy_real.uint8)
    Image.fromarray(pixels, 'RGB').save(file_full_path)
def save_W(dt_str):
    """Save every row of the module-level W1 as a 28x28 CSV and as a colour PNG.

    CSVs go to ``<root_dir>/<dt_str>/W1_28_28/W1_<i>.csv`` and PNGs to
    ``<root_dir>/<dt_str>/W1_28_28_img/W1_<i>.png``. Both directories are
    created with os.makedirs without exist_ok, so neither may exist yet.
    """
    csv_dir = "%s/%s/W1_28_28/" % (root_dir, dt_str)
    png_dir = "%s/%s/W1_28_28_img/" % (root_dir, dt_str)
    for directory in (csv_dir, png_dir):
        os.makedirs(directory)
    for i, row in enumerate(W1):
        w1k = row.reshape(28, 28)
        numpy_real.savetxt("%s/W1_%s.csv" % (csv_dir, i), w1k, delimiter=",")
        save_w1k(w1k, "%s/W1_%s.png" % (png_dir, i))
def imgI_w1k_merge():
    """For each of the 10 outputs, collect the non-zero W1 weights of every input column.

    For output i, only the rows of the module-level W1 whose outgoing weight
    W2[i, :] is non-zero are considered. For each column q of W1, the non-zero
    entries of those rows are listed in row order.

    Returns:
        A list of 10 entries, one per output row of W2. Each entry is a list
        with one (possibly empty) list of weights per column of W1.
    """
    merged_per_output = []
    for digit in range(10):
        connected_rows = W1[W2[digit, :] != 0, :]
        merged_per_output.append(
            [[w for w in column if w != 0] for column in connected_rows.T]
        )
    return merged_per_output
def imgI_w1k_alignment_sort(imgI_w1k_merged_ls):
    """Pad each of the first 10 x 784 value lists to at least two entries and sort it.

    Empty lists become [0, 0], and single-value lists get a 0 appended (the
    inner list itself is mutated). Every entry is then replaced by a sorted
    NumPy array. Longer lists are sorted but not truncated. The outer list is
    modified in place and also returned.
    """
    for i in range(10):
        row = imgI_w1k_merged_ls[i]
        for t in range(784):
            values = row[t]
            size = len(values)
            if size == 0:
                values = [0, 0]
            elif size == 1:
                values.append(0)
            row[t] = numpy_real.sort(values)
    return imgI_w1k_merged_ls
def imgI_w1k_split_then_save(imgI_w1k_merged_ls, dt_str):
    """Split each pixel's value pair into two (10, 784) maps and save them as images.

    Expects the output of imgI_w1k_alignment_sort, i.e. 10 x 784 sorted pairs.
    The first value of each pair goes to `neg`, the second to `positive`.
    Images are written to ``<root_dir>/<dt_str>/W1_28_28_merge_W2_img_neg/<i>.png``
    and ``.../W1_28_28_merge_W2_img_positive/<i>.png``. Intermediate shapes are
    printed along the way.

    NOTE(review): the pairs are sorted ascending, so the first value is only
    guaranteed to be the smaller one; it is negative only when the pair
    actually contains a negative weight.

    Returns:
        (neg, positive), each shaped (10, 784).
    """
    neg_dir = "%s/%s/W1_28_28_merge_W2_img_neg/" % (root_dir, dt_str)
    positive_dir = "%s/%s/W1_28_28_merge_W2_img_positive/" % (root_dir, dt_str)
    for directory in (neg_dir, positive_dir):
        os.makedirs(directory, exist_ok=True)
    stacked = numpy_real.asarray(imgI_w1k_merged_ls)
    print(stacked.shape)
    neg, positive = numpy_real.split(stacked, 2, axis=2)
    print("before neg.shape,positive.shape", neg.shape, positive.shape)
    neg, positive = neg.reshape((10, 784)), positive.reshape((10, 784))
    print("neg.shape,positive.shape", neg.shape, positive.shape)
    for digit, (neg_row, positive_row) in enumerate(zip(neg, positive)):
        save_w1k(neg_row.reshape((28, 28)), '%s/%s.png' % (neg_dir, digit))
        save_w1k(positive_row.reshape((28, 28)), '%s/%s.png' % (positive_dir, digit))
    print("imgI_w1k_split_then_save end")
    return (neg, positive)
from google.colab import drive
# Mount Google Drive in Colab so the save_* helpers can write under root_dir,
# which points into /content/drive. force_remount=True presumably re-mounts
# even if Drive is already mounted.
drive.mount('/content/drive', force_remount=True)
# Full-batch gradient descent. After each update the weight matrices are pruned
# with topNSetToZero, so every column keeps only its smallest and largest entry.
num_steps = 102  # n = 0..101: the same number of updates as the former count()/break loop
for n in range(num_steps):
    grads = loss_grad(params, X_train, Y_train)
    params = [param - learning_rate * grad for param, grad in zip(params, grads)]
    # topNSetToZero edits its argument in place, which JAX arrays do not allow.
    # Prune writable NumPy copies of the *current* weights and convert back.
    # Pruning the module-level W1/W2 instead would only touch the stale
    # initial arrays, which stop being part of `params` after the first update.
    params = [onp.array(p) for p in params]
    topNSetToZero(1, params[0])
    topNSetToZero(1, params[2])
    params = [np.asarray(p) for p in params]
    print(loss(params, X_train, Y_train))
# save_model_csv, save_W and imgI_w1k_merge read the module-level W1, b1, W2
# and b2, so point those names at the trained, pruned weights before saving.
W1, b1, W2, b2 = [onp.array(p) for p in params]
# A timestamp (YYYYmmddHHMMSS) names this run's output directory under root_dir.
dt_str = datetime.now().strftime('%Y%m%d%H%M%S')
save_model_csv(dt_str)
save_W(dt_str)
# Per-output pixel weights: gather the non-zero W1 entries, pad short lists to
# two values and sort them, then write the two resulting image sets.
imgI_w1k_merged_ls = imgI_w1k_alignment_sort(imgI_w1k_merge())
neg, positive = imgI_w1k_split_then_save(imgI_w1k_merged_ls, dt_str)
# Test-set accuracy: the predicted class is the argmax over axis 0, because
# rows index the 10 classes and columns index examples. The bare `.shape`
# expressions left over from the notebook did nothing and have been removed.
Y_test_predict = net(params, X_test)
Y_test_predict_idx = np.argmax(Y_test_predict, axis=0)
Y_test_idx = np.argmax(Y_test, axis=0)
test_correct_count = np.sum(Y_test_predict_idx == Y_test_idx)
test_accuracy = test_correct_count / m_test
print("test_accuracy:", test_accuracy)