本文已参与「新人创作礼」活动,一起开启掘金创作之路。
一、可视化工具网址
Netscope 在线编辑器:http://ethereon.github.io/netscope/#/editor
二、数据打包成lmdb格式(lmdb.py)
#!/usr/bin/env python
# encoding: utf-8
'''
@author: lljyoyo1995
@license: (C) Copyright 2014-2019,Fish Group.
@contact: 1821381898@qq.com
@software: PyCharm
@file: lmdb.py
@time: 19-8-21 下午3:42
@desc:
'''
import lmdb
import numpy as np
import cv2
import caffe
from caffe.proto import caffe_pb2
def write():
    """Generate a sample LMDB database of serialized Caffe ``Datum`` records.

    Writes 256 records into the ``lmdb_data`` directory. Each record holds a
    3x64x64 ``uint8`` array of ones, uses its own index as the label, and is
    stored under the zero-padded 8-digit index as key.

    :return: None
    """
    lmdb_file = 'lmdb_data'  # directory that holds the LMDB files
    num_samples = 256
    # map_size is the upper bound on the database size: 1e12 = 10^12 bytes.
    lmdb_env = lmdb.open(lmdb_file, map_size=int(1e12))
    try:
        # The transaction context manager commits when the block finishes
        # normally and aborts if an exception escapes it.
        with lmdb_env.begin(write=True) as lmdb_txn:
            for x in range(num_samples):
                data = np.ones((3, 64, 64), np.uint8)
                label = x
                datum = caffe.io.array_to_datum(data, label)
                # LMDB keys must be bytes; a plain str key is rejected on
                # Python 3, so encode the formatted index explicitly.
                key = '{:0>8d}'.format(x).encode('ascii')
                lmdb_txn.put(key, datum.SerializeToString())  # serialize
    finally:
        # Always release the environment, even if writing failed.
        lmdb_env.close()
def read():
    """Read every record from the ``lmdb_data`` database and print it.

    Each stored value is parsed as a Caffe ``Datum``; its label and its
    decoded array are printed in key order.

    :return: None
    """
    lmdb_env = lmdb.open('lmdb_data')
    try:
        # Read-only transaction; the context manager releases it on exit.
        with lmdb_env.begin() as lmdb_txn:
            datum = caffe_pb2.Datum()
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)  # deserialize
                label = datum.label
                data = caffe.io.datum_to_array(datum)
                print(label)
                print(data)
    finally:
        # Close the environment so its file handles are not leaked.
        lmdb_env.close()
# Script entry point: build the sample database first, then dump it back out.
if __name__ == "__main__":
    write()
    read()
运行 lmdb.py,会在 lmdb_data 文件夹下生成 data.mdb 文件(注意:脚本文件名若与 lmdb 库同名,`import lmdb` 会导入脚本自身而不是 lmdb 库,建议将脚本改名,例如 make_lmdb.py)
三、构建网络(create_net.py)
#!/usr/bin/env python
# encoding: utf-8
'''
@author: lljyoyo1995
@license: (C) Copyright 2014-2019,Fish Group.
@contact: 1821381898@qq.com
@software: PyCharm
@file: create_net.py
@time: 19-8-21 下午6:03
@desc:
'''
import caffe
def create_net():
    """Describe a small CNN with ``caffe.NetSpec`` and write it as prototxt.

    The network is: Data -> Conv(20, 5x5) -> ReLU -> MaxPool ->
    Conv(32, 3x3, pad 1) -> ReLU -> MaxPool -> FC(1024) -> ReLU ->
    Dropout(0.5) -> FC(10) -> SoftmaxWithLoss. The definition is written to
    ``net//tt.prototxt``; the ``net`` directory is created if it is missing.

    :return: None
    """
    import os

    layers = caffe.layers
    params = caffe.params

    # Network specification that collects the layers below.
    net = caffe.NetSpec()

    # Data layer: two tops (data and label) read from an LMDB source.
    net.data, net.label = layers.Data(
        source='data.lmdb',
        backend=params.Data.LMDB,
        batch_size=32,
        ntop=2,
        transform_param=dict(crop_size=40, mirror=True),
    )

    # First convolution stage: conv -> in-place ReLU -> max pooling.
    net.conv1 = layers.Convolution(
        net.data,
        num_output=20,
        kernel_size=5,
        weight_filler={'type': 'xavier'},
        bias_filler={'type': 'xavier'},
    )
    net.relu1 = layers.ReLU(net.conv1, in_place=True)
    net.pool1 = layers.Pooling(
        net.relu1,
        pool=params.Pooling.MAX,
        kernel_size=3,
        stride=2,
    )

    # Second convolution stage: conv -> in-place ReLU -> max pooling.
    net.conv2 = layers.Convolution(
        net.pool1,
        num_output=32,
        kernel_size=3,
        pad=1,
        weight_filler={'type': 'xavier'},
        bias_filler={'type': 'xavier'},
    )
    net.relu2 = layers.ReLU(net.conv2, in_place=True)
    net.pool2 = layers.Pooling(
        net.relu2,
        pool=params.Pooling.MAX,
        kernel_size=3,
        stride=2,
    )

    # Fully connected layer followed by in-place ReLU.
    net.fc3 = layers.InnerProduct(
        net.pool2,
        num_output=1024,
        weight_filler=dict(type='xavier'),
    )
    net.relu3 = layers.ReLU(net.fc3, in_place=True)

    # Dropout layer.
    net.drop = layers.Dropout(
        net.relu3,
        dropout_param=dict(dropout_ratio=0.5),
    )

    # Output layer with 10 units.
    net.fc4 = layers.InnerProduct(
        net.drop,
        num_output=10,
        weight_filler=dict(type='xavier'),
    )

    # Loss layer. The attribute name becomes the layer "type" in the
    # prototxt, and Caffe layer types are case-sensitive, so this must be
    # spelled "SoftmaxWithLoss" (not "softmaxWithLoss").
    net.loss = layers.SoftmaxWithLoss(net.fc4, net.label)

    # Write the prototxt file, creating the target directory when needed so
    # open() does not fail on a fresh checkout.
    prototxt_path = 'net//tt.prototxt'
    out_dir = os.path.dirname(prototxt_path)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    with open(prototxt_path, 'w') as f:
        f.write(str(net.to_proto()))
# Script entry point: generate the network definition file.
if __name__ == "__main__":
    create_net()
运行 create_net.py,在 net 目录下生成 tt.prototxt 网络结构文件
四、caffe网络结构可视化
将生成的tt.prototxt文件中的代码,拷贝到 Netscope 中的指定位置,按住 【Shift + Enter】查看网络结构