本文已参与 [新人创作礼] 活动,一起开启掘金创作之路。
inference.py源码如下:
import numpy as np
import tensorflow as tf
import os, argparse
import cv2
from data import process_image_file
parser = argparse.ArgumentParser(description='COVID-Net Inference')
parser.add_argument('--weightspath', default='models/COVIDNet-CXR3-S', type=str, help='Path to output folder')
parser.add_argument('--metaname', default='model.meta', type=str, help='Name of ckpt meta file')
parser.add_argument('--ckptname', default='model-1014', type=str, help='Name of model ckpts')
parser.add_argument('--imagepath', default='assets/ex-covid.jpeg', type=str, help='Full path to image to be inferenced')
parser.add_argument('--in_tensorname', default='input_1:0', type=str, help='Name of input tensor to graph')
parser.add_argument('--out_tensorname', default='norm_dense_1/Softmax:0', type=str, help='Name of output tensor from graph')
parser.add_argument('--input_size', default=480, type=int, help='Size of input (ex: if 480x480, --input_size 480)')
parser.add_argument('--top_percent', default=0.08, type=float, help='Percent top crop from top of image')
args = parser.parse_args()
mapping = {'normal': 0, 'pneumonia': 1, 'COVID-19': 2}
inv_mapping = {0: 'normal', 1: 'pneumonia', 2: 'COVID-19'}
sess = tf.Session()
tf.get_default_graph()
saver = tf.train.import_meta_graph(os.path.join(args.weightspath, args.metaname))
saver.restore(sess, os.path.join(args.weightspath, args.ckptname))
graph = tf.get_default_graph()
image_tensor = graph.get_tensor_by_name(args.in_tensorname)
pred_tensor = graph.get_tensor_by_name(args.out_tensorname)
x = process_image_file(args.imagepath, args.top_percent, args.input_size)
x = x.astype('float32') / 255.0
pred = sess.run(pred_tensor, feed_dict={image_tensor: np.expand_dims(x, axis=0)})
print('Prediction: {}'.format(inv_mapping[pred.argmax(axis=1)[0]]))
print('Confidence')
print('Normal: {:.3f}, Pneumonia: {:.3f}, COVID-19: {:.3f}'.format(pred[0][0], pred[0][1], pred[0][2]))
print('**DISCLAIMER**')
print('Do not use this prediction for self-diagnosis. You should check with your local authorities for the latest advice on seeking medical assistance.')
初始的几句import无需多说,分别导入了numpy,tensorflow,os,argparse和cv2。
from data import process_image_file
这个语句是从data模块中导入process_image_file函数。
接下来的一段是与argparse相关的内容,重点说一下。
模块简介
argparse是Python中的一个常用模块,和sys.argv()功能类似,主要用于编写命令行接口:对于程序所需要的参数,它可以进行正确的解析。另外,argparse还可以自动的生成help和 usage信息,当程序的参数无效时,它可以自动生成错误信息。
argparse是一个Python模块:命令行选项、参数和子命令解析器。
argparse 模块可以让人轻松编写用户友好的命令行接口。程序定义它需要的参数,然后 argparse 将弄清如何从 sys.argv 解析出那些参数。 argparse 模块还会自动生成帮助和使用手册,并在用户给程序传入无效参数时报出错误信息。
[参考文档]:argparse — Parser for command-line options, arguments and sub-commands — Python 3.11.0 documentation
使用流程
1. 创建解析器
class argparse.ArgumentParser(prog=None, usage=None, description=None, epilog=None, parents=[], formatter_class=argparse.HelpFormatter, argument_default=None, conflict_handler=’error’, add_help=True, allow_abbrev=True)
每个参数解释如下:
- prog - The name of the program (default: sys.argv[0])
- 程序文件名
- usage - The string describing the program usage (default: generated from arguments added to parser)
- 程序使用说明
- description - Text to display before the argument help (default: none)
- 程序目的说明
- epilog - Text to display after the argument help (default: none)
- 程序说明后记
- parents - A list of ArgumentParser objects whose arguments should also be included (default: [])
- ArgumentParser对象的父对象的参数列表
- formatter_class - A class for customizing the help output
help信息的说明格式
- prefix_chars - The set of characters that prefix optional arguments (default: ‘-‘)
- 命令行参数的前缀
- fromfile_prefix_chars - The set of characters that prefix files from which additional arguments should be read (default: None)
- argument_default - The global default value for arguments (default: None)
- 参数的全局默认值
- conflict_handler - The strategy for resolving conflicting optionals (usually unnecessary)
- 冲突处理
- add_help - Add a -h/–help option to the parser (default: True)
- 是否增加
help选项
- 是否增加
- allow_abbrev - Allows long options to be abbreviated if the abbreviation is unambiguous. (default: True)
- 是否使用参数的缩写
例如:
parser = argparse.ArgumentParser(description='Process some intergers.')
使用argparser的第一步是创建一个ArgumentParser对象,ArgumentParser对象包含将命令行解析成Python数据类型所需的全部信息。ArgumentParser对象保存了所有必要的信息,用以将命令行参数解析为相应的python数据类型。
2. 添加参数
调用add_argument()向ArgumentParser对象添加命令行参数信息,这些信息告诉ArgumentParser对象如何处理命令行参数。可以通过调用parse_agrs()来使用这些命令行参数。
ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])
每个参数解释如下:
- name or flags - Either a name or a list of option strings, e.g. foo or -f, –foo.
- 参数名或者参数标识
- 带
-的为可选参数(optional parameter) - 不带
-的为必选参数(positional parametrer)。
- action - The basic type of action to be taken when this argument is encountered at the command line.
- 参数的处理方法
- nargs - The number of command-line arguments that should be consumed.
- 参数的数量
- const - A constant value required by some action and nargs selections.
- 参数的常量值
- default - The value produced if the argument is absent from the command line.
- 参数的默认值
- type - The type to which the command-line argument should be converted.
- 参数的数据类型
- choices - A container of the allowable values for the argument.
- 参数的取值范围
- required - Whether or not the command-line option may be omitted (optionals only).
- 参数是否可以忽略不写 ,仅对可选参数有效
- help - A brief description of what the argument does.
- 参数的说明信息
- metavar - A name for the argument in usage messages.
- 参数在说明信息
usage中的名称
- 参数在说明信息
- dest - The name of the attribute to be added to the object returned by
parse_args().- 对象的属性名
例如:
parser.add_argument('integers', metavar='N', type=int, nargs='+', help='an integer for the accumulator')
parser.add_argument('--sum', dest='accumulate', action='store_const', const=sum, default=max, help='sum the integers (default: find the max)')
3. 解析参数
ArgumentParser通过parse_args()方法来解析ArgumentParser对象中保存的命令行参数:将命令行参数解析成相应的数据类型并采取相应的动作,它返回一个Namespace对象。
>>>parser.parse_args(['--sum', '7', '-1', '42' ])
Namespace(accumulate=, integers=[7, -1, 42])
在实际的python脚本中,parse_args()一般并不使用参数,它的参数由sys.argv决定。
举例1
下列代码是python程序,它可以接受多个整数,并返回它们的和或者最大值。
import argparse
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('integers', metavar='N', type=int, nargs='+',
help='an integer for the accumulator')
parser.add_argument('--sum', dest='accumulate', action='store_const',
const=sum, default=max,
help='sum the integers (default: find the max)')
args = parser.parse_args()
print(args.accumulate(args.integers))
假定上述代码保存为prog.py,它可以在命令行执行并自动提供有用的help信息:
$ python prog.py -h
usage: prog.py [-h] [--sum] N [N ...]
Process some integers.
positional arguments:
N an integer for the accumulator
optional arguments:
-h, --help show this help message and exit
--sum sum the integers (default: find the max)
该程序可以接受相应的参数并给出相应的输出:
$ python prog.py 1 2 3 4
4
$ python prog.py 1 2 3 4 --sum
10
如果传入无效的参数,它可以自动生成error信息:
$ python prog.py a b c
usage: prog.py [-h] [--sum] N [N ...]
prog.py: error: argument N: invalid int value: 'a'
举例2
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--sparse', action='store_true', default=False, help='GAT with sparse version or not.')
parser.add_argument('--seed', type=int, default=72, help='Random seed.')
parser.add_argument('--epochs', type=int, default=10000, help='Number of epochs to train.')
args = parser.parse_args()
print(args.sparse)
print(args.seed)
print(args.epochs)
执行以上程序,打印内容如下:
False
72
10000
Process finished with exit code 0
介绍完了参数解析相关内容,接下来是2个字典:
mapping = {'normal': 0, 'pneumonia': 1, 'COVID-19': 2}
inv_mapping = {0: 'normal', 1: 'pneumonia', 2: 'COVID-19'}
很好理解,3分类。正常对应0;普通肺炎对应1;新冠肺炎对应2。反字典对应关系也一致。
接下来的2句无需过多解释,
sess = tf.Session():创建会话;
tf.get_default_graph():获取当前默认的计算图。
saver = tf.train.import_meta_graph(os.path.join(args.weightspath, args.metaname))
用来加载meta文件中的图,以及图上定义的结点参数包括权重偏置项等需要训练的参数,也包括训练过程生成的中间参数。所有参数都是通过graph调用接口get_tensor_by_name(name="训练时的参数名称")来获取。
saver.restore(sess, os.path.join(args.weightspath, args.ckptname))
通过restore函数来给图里的变量赋值。
image_tensor = graph.get_tensor_by_name(args.in_tensorname)
pred_tensor = graph.get_tensor_by_name(args.out_tensorname)