本文已参与 [新人创作礼] 活动,一起开启掘金创作之路。
train_tf.py源码如下:
from __future__ import print_function
import tensorflow as tf
import os, argparse, pathlib
from eval import eval
from data import BalanceCovidDataset
# Command-line interface of the training script. Every option has a default,
# so the script can be started without any arguments.
parser = argparse.ArgumentParser(description='COVID-Net Training Script')
# Optimisation hyper-parameters.
parser.add_argument('--epochs', default=10, type=int, help='Number of epochs')
parser.add_argument('--lr', default=0.0002, type=float, help='Learning rate')
parser.add_argument('--bs', default=8, type=int, help='Batch size')
# Location of the pretrained model that training starts from.
# NOTE(review): the help text of --weightspath says "output folder", but in this
# script the path is only used to load the meta graph and the checkpoint;
# new checkpoints are written under ./output/ instead.
parser.add_argument('--weightspath', default='models/COVIDNet-CXR3-S', type=str, help='Path to output folder')
parser.add_argument('--metaname', default='model.meta', type=str, help='Name of ckpt meta file')
parser.add_argument('--ckptname', default='model-1014', type=str, help='Name of model ckpts')
# Split listings and data location.
parser.add_argument('--trainfile', default='train_COVIDx3.txt', type=str, help='Name of train file')
parser.add_argument('--testfile', default='test_COVIDx3.txt', type=str, help='Name of test file')
# --name becomes part of the run directory name under ./output/.
parser.add_argument('--name', default='COVIDNet', type=str, help='Name of folder to store training checkpoints')
parser.add_argument('--datadir', default='data', type=str, help='Path to data folder')
# Options forwarded to the BalanceCovidDataset batch generator.
parser.add_argument('--covid_weight', default=4., type=float, help='Class weighting for covid')
parser.add_argument('--covid_percent', default=0.3, type=float, help='Percentage of covid samples in batch')
parser.add_argument('--input_size', default=480, type=int, help='Size of input (ex: if 480x480, --input_size 480)')
parser.add_argument('--top_percent', default=0.08, type=float, help='Percent top crop from top of image')
# Names of the tensors that are looked up in the imported graph.
parser.add_argument('--in_tensorname', default='input_1:0', type=str, help='Name of input tensor to graph')
parser.add_argument('--out_tensorname', default='norm_dense_1/Softmax:0', type=str, help='Name of output tensor from graph')
parser.add_argument('--logit_tensorname', default='norm_dense_1/MatMul:0', type=str, help='Name of logit tensor for loss')
parser.add_argument('--label_tensorname', default='norm_dense_1_target:0', type=str, help='Name of label tensor for loss')
parser.add_argument('--weights_tensorname', default='norm_dense_1_sample_weights:0', type=str, help='Name of sample weights tensor for loss')
# Parse the process command line into the `args` namespace used below.
args = parser.parse_args()
# Parameters
learning_rate = args.lr
batch_size = args.bs
# Interval, in epochs, between the loss report / evaluation / checkpoint step
# of the training loop.
display_step = 1

# output path
# Checkpoints of this run go to ./output/<name>-lr<learning rate>; the
# directory (and any missing parents) is created if it does not exist yet.
outputPath = './output/'
runID = args.name + '-lr' + str(learning_rate)
runPath = outputPath + runID
pathlib.Path(runPath).mkdir(parents=True, exist_ok=True)
print('Output: ' + runPath)

# Read the train/test split listings, one entry per line.
# NOTE(review): `trainfiles` is not referenced again in this listing; the
# generator below receives the path of the train file instead.
with open(args.trainfile) as f:
    trainfiles = f.readlines()
with open(args.testfile) as f:
    testfiles = f.readlines()

# Batch generator for training. The last entry of `class_weights` carries the
# --covid_weight option; the other two entries are fixed at 1.
generator = BalanceCovidDataset(data_dir=args.datadir,
                                csv_file=args.trainfile,
                                batch_size=batch_size,
                                input_shape=(args.input_size, args.input_size),
                                covid_percent=args.covid_percent,
                                class_weights=[1., 1., args.covid_weight],
                                top_percent=args.top_percent)
# Fine-tune the pretrained graph inside a single TF1 session: import the graph,
# attach a weighted cross-entropy loss and an Adam optimizer, restore the
# pretrained weights, then train and periodically evaluate / checkpoint.
with tf.Session() as sess:
    tf.get_default_graph()  # return value is discarded here
    # Rebuild the pretrained graph from its .meta file; the returned saver is
    # used both to restore the pretrained weights and to write checkpoints.
    saver = tf.train.import_meta_graph(os.path.join(args.weightspath, args.metaname))

    graph = tf.get_default_graph()

    # Look up the tensors needed for feeding and for the loss by name.
    image_tensor = graph.get_tensor_by_name(args.in_tensorname)
    labels_tensor = graph.get_tensor_by_name(args.label_tensorname)
    sample_weights = graph.get_tensor_by_name(args.weights_tensorname)
    pred_tensor = graph.get_tensor_by_name(args.logit_tensorname)
    # loss expects unscaled logits since it performs a softmax on logits internally for efficiency

    # Define loss and optimizer
    # Per-sample cross entropy is scaled by the sample weights before averaging.
    loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=pred_tensor, labels=labels_tensor)*sample_weights)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(loss_op)

    # Initialize the variables
    init = tf.global_variables_initializer()

    # Run the initializer
    # This runs before the restore below, so the restored checkpoint values
    # are loaded after initialization rather than being reset by it.
    sess.run(init)

    # load weights
    saver.restore(sess, os.path.join(args.weightspath, args.ckptname))
    #saver.restore(sess, tf.train.latest_checkpoint(args.weightspath))

    # save base model
    saver.save(sess, os.path.join(runPath, 'model'))
    print('Saved baseline checkpoint')
    print('Baseline eval:')
    eval(sess, graph, testfiles, os.path.join(args.datadir,'test'),
         args.in_tensorname, args.out_tensorname, args.input_size)

    # Training cycle
    print('Training started')
    total_batch = len(generator)
    # NOTE(review): the progress bar is created once and reused for every epoch.
    progbar = tf.keras.utils.Progbar(total_batch)
    for epoch in range(args.epochs):
        for i in range(total_batch):
            # Run optimization
            batch_x, batch_y, weights = next(generator)
            sess.run(train_op, feed_dict={image_tensor: batch_x,
                                          labels_tensor: batch_y,
                                          sample_weights: weights})
            progbar.update(i+1)

        if epoch % display_step == 0:
            # Report the loss on the last batch of the epoch: compute the
            # logits first, then feed them back in to evaluate the loss.
            pred = sess.run(pred_tensor, feed_dict={image_tensor:batch_x})
            loss = sess.run(loss_op, feed_dict={pred_tensor: pred,
                                                labels_tensor: batch_y,
                                                sample_weights: weights})
            print("Epoch:", '%04d' % (epoch + 1), "Minibatch loss=", "{:.9f}".format(loss))
            eval(sess, graph, testfiles, os.path.join(args.datadir,'test'),
                 args.in_tensorname, args.out_tensorname, args.input_size)
            # Only the weights are written; the meta graph was saved with the
            # baseline checkpoint above.
            saver.save(sess, os.path.join(runPath, 'model'), global_step=epoch+1, write_meta_graph=False)
            print('Saving checkpoint at epoch {}'.format(epoch + 1))


print("Optimization Finished!")
from __future__ import print_function
在文件开头加上 from __future__ import print_function 之后,即使在 Python 2.x 下,print 也必须像 Python 3.x 那样作为函数、加括号调用(Python 2.x 中 print 原本是语句,不需要括号;Python 3.x 中 print 是函数,必须加括号)。注意该语句必须出现在文件中其他所有代码之前。
import tensorflow as tf
import os, argparse, pathlib
导入tensorflow模块。
导入os,argparse,pathlib模块。
from eval import eval
from data import BalanceCovidDataset
从eval模块(工程中的eval.py)导入eval函数。注意导入后,这个eval会覆盖Python内置的eval函数。
从data模块导入BalanceCovidDataset类,后文用它构造训练数据生成器generator,并对其调用len()和next()。
接下来的一段参数解析( arg parser)相关的代码详见 COVID-Net工程源码详解(三) - inference.py解析。
# Parameters
learning_rate = args.lr
batch_size = args.bs
display_step = 1
学习率learning_rate赋值为args.lr,默认为0.0002。
批次大小batch_size赋值为args.bs,默认为8。
display_step赋值为1。
# output path
outputPath = './output/'
runID = args.name + '-lr' + str(learning_rate)
runPath = outputPath + runID
pathlib.Path(runPath).mkdir(parents=True, exist_ok=True)
runID 为"COVIDNet-lr0.0002"。
runPath为"./output/COVIDNet-lr0.0002"。
在当前路径下创建output/COVIDNet-lr0.0002文件夹。parents=True时,会依次创建路径中间缺少的文件夹。
此处pathlib的mkdir用到了两个参数(另有mode参数,这里使用默认值):
- parents:为True时,如果父目录不存在,会一并创建缺少的父目录;为False时,父目录不存在会抛出FileNotFoundError。
- exist_ok:为True时,如果目录已存在则不会抛出异常;为False(默认)时,目录已存在会抛出FileExistsError。
# Read every line of the train split listing into `trainfiles`; the file is
# closed automatically when the `with` block ends.
with open(args.trainfile) as f:
    trainfiles = f.readlines()
打开train_split_v3.txt文件,并读入全部内容,保存于trainfiles。此处没有使用默认参数train_COVIDx3.txt,而是使用了由命令行传入的文件名。参见 COVID-Net工程源码详解(四) - train_eval_inference.md解析。
# Read every line of the test split listing into `testfiles`; the file is
# closed automatically when the `with` block ends.
with open(args.testfile) as f:
    testfiles = f.readlines()
打开test_split_v3.txt文件,并读入全部内容,保存于testfiles。此处没有使用默认参数test_COVIDx3.txt,而是使用了由命令行传入的文件名。
generator = BalanceCovidDataset(data_dir=args.datadir,
csv_file=args.trainfile,
batch_size=batch_size,
input_shape=(args.input_size, args.input_size),
covid_percent=args.covid_percent,
class_weights=[1., 1., args.covid_weight],
top_percent=args.top_percent)
重头戏来了。使用data模块中的BalanceCovidDataset类构建训练数据生成器generator,由它按批次提供训练数据,详见 COVID-Net工程源码详解(五) - data.py解析。