COVID-Net工程源码详解(三) - inference.py解析

243 阅读6分钟

本文已参与 [新人创作礼] 活动,一起开启掘金创作之路。​

inference.py源码如下:

import numpy as np
import tensorflow as tf
import os, argparse
import cv2

from data import process_image_file

parser = argparse.ArgumentParser(description='COVID-Net Inference')
parser.add_argument('--weightspath', default='models/COVIDNet-CXR3-S', type=str, help='Path to output folder')
parser.add_argument('--metaname', default='model.meta', type=str, help='Name of ckpt meta file')
parser.add_argument('--ckptname', default='model-1014', type=str, help='Name of model ckpts')
parser.add_argument('--imagepath', default='assets/ex-covid.jpeg', type=str, help='Full path to image to be inferenced')
parser.add_argument('--in_tensorname', default='input_1:0', type=str, help='Name of input tensor to graph')
parser.add_argument('--out_tensorname', default='norm_dense_1/Softmax:0', type=str, help='Name of output tensor from graph')
parser.add_argument('--input_size', default=480, type=int, help='Size of input (ex: if 480x480, --input_size 480)')
parser.add_argument('--top_percent', default=0.08, type=float, help='Percent top crop from top of image')

args = parser.parse_args()

mapping = {'normal': 0, 'pneumonia': 1, 'COVID-19': 2}
inv_mapping = {0: 'normal', 1: 'pneumonia', 2: 'COVID-19'}

sess = tf.Session()
tf.get_default_graph()
saver = tf.train.import_meta_graph(os.path.join(args.weightspath, args.metaname))
saver.restore(sess, os.path.join(args.weightspath, args.ckptname))

graph = tf.get_default_graph()

image_tensor = graph.get_tensor_by_name(args.in_tensorname)
pred_tensor = graph.get_tensor_by_name(args.out_tensorname)

x = process_image_file(args.imagepath, args.top_percent, args.input_size)
x = x.astype('float32') / 255.0
pred = sess.run(pred_tensor, feed_dict={image_tensor: np.expand_dims(x, axis=0)})

print('Prediction: {}'.format(inv_mapping[pred.argmax(axis=1)[0]]))
print('Confidence')
print('Normal: {:.3f}, Pneumonia: {:.3f}, COVID-19: {:.3f}'.format(pred[0][0], pred[0][1], pred[0][2]))
print('**DISCLAIMER**')
print('Do not use this prediction for self-diagnosis. You should check with your local authorities for the latest advice on seeking medical assistance.')

初始的几句import无需多说,分别导入了numpy,tensorflow,os,argparse和cv2。

from data import process_image_file

这个语句是从data模块中导入process_image_file函数。

接下来的一段是与argparse相关的内容,重点说一下。

模块简介

argparse是Python中的一个常用模块,和sys.argv()功能类似,主要用于编写命令行接口:对于程序所需要的参数,它可以进行正确的解析。另外,argparse还可以自动的生成helpusage信息,当程序的参数无效时,它可以自动生成错误信息。

argparse是一个Python模块:命令行选项、参数和子命令解析器。

argparse 模块可以让人轻松编写用户友好的命令行接口。程序定义它需要的参数,然后 argparse 将弄清如何从 sys.argv 解析出那些参数。 argparse 模块还会自动生成帮助和使用手册,并在用户给程序传入无效参数时报出错误信息。

[参考文档]:argparse — Parser for command-line options, arguments and sub-commands — Python 3.11.0 documentation

使用流程

1. 创建解析器

class argparse.ArgumentParser(prog=None, usage=None, description=None, epilog=None, parents=[], formatter_class=argparse.HelpFormatter, argument_default=None, conflict_handler=’error’, add_help=True, allow_abbrev=True)

每个参数解释如下:

  • prog - The name of the program (default: sys.argv[0])
    • 程序文件名
  • usage - The string describing the program usage (default: generated from arguments added to parser)
    • 程序使用说明
  • description - Text to display before the argument help (default: none)
    • 程序目的说明
  • epilog - Text to display after the argument help (default: none)
    • 程序说明后记
  • parents - A list of ArgumentParser objects whose arguments should also be included (default: [])
    • ArgumentParser对象的父对象的参数列表
  • formatter_class - A class for customizing the help output
    • help信息的说明格式
  • prefix_chars - The set of characters that prefix optional arguments (default: ‘-‘)
    • 命令行参数的前缀
  • fromfile_prefix_chars - The set of characters that prefix files from which additional arguments should be read (default: None)
  • argument_default - The global default value for arguments (default: None)
    • 参数的全局默认值
  • conflict_handler - The strategy for resolving conflicting optionals (usually unnecessary)
    • 冲突处理
  • add_help - Add a -h/–help option to the parser (default: True)
    • 是否增加help选项
  • allow_abbrev - Allows long options to be abbreviated if the abbreviation is unambiguous. (default: True)
    • 是否使用参数的缩写

例如:

parser = argparse.ArgumentParser(description='Process some intergers.')

使用argparser的第一步是创建一个ArgumentParser对象,ArgumentParser对象包含将命令行解析成Python数据类型所需的全部信息。ArgumentParser对象保存了所有必要的信息,用以将命令行参数解析为相应的python数据类型。

2. 添加参数

调用add_argument()向ArgumentParser对象添加命令行参数信息,这些信息告诉ArgumentParser对象如何处理命令行参数。可以通过调用parse_agrs()来使用这些命令行参数。

ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])

每个参数解释如下:

  • name or flags - Either a name or a list of option strings, e.g. foo or -f, –foo.
    • 参数名或者参数标识
    • -的为可选参数(optional parameter)
    • 不带-的为必选参数(positional parametrer)。
  • action - The basic type of action to be taken when this argument is encountered at the command line.
    • 参数的处理方法
  • nargs - The number of command-line arguments that should be consumed.
    • 参数的数量
  • const - A constant value required by some action and nargs selections.
    • 参数的常量值
  • default - The value produced if the argument is absent from the command line.
    • 参数的默认值
  • type - The type to which the command-line argument should be converted.
    • 参数的数据类型
  • choices - A container of the allowable values for the argument.
    • 参数的取值范围
  • required - Whether or not the command-line option may be omitted (optionals only).
    • 参数是否可以忽略不写 ,仅对可选参数有效
  • help - A brief description of what the argument does.
    • 参数的说明信息
  • metavar - A name for the argument in usage messages.
    • 参数在说明信息usage中的名称
  • dest - The name of the attribute to be added to the object returned by
    parse_args().
    • 对象的属性名

例如:

parser.add_argument('integers', metavar='N', type=int, nargs='+', help='an integer for the accumulator')

parser.add_argument('--sum', dest='accumulate', action='store_const', const=sum, default=max, help='sum the integers (default: find the max)')

3. 解析参数

ArgumentParser通过parse_args()方法来解析ArgumentParser对象中保存的命令行参数:将命令行参数解析成相应的数据类型并采取相应的动作,它返回一个Namespace对象。

>>>parser.parse_args(['--sum', '7', '-1', '42' ])

Namespace(accumulate=, integers=[7, -1, 42])

在实际的python脚本中,parse_args()一般并不使用参数,它的参数由sys.argv决定。

举例1

下列代码是python程序,它可以接受多个整数,并返回它们的和或者最大值。

import argparse

parser = argparse.ArgumentParser(description='Process some integers.')

parser.add_argument('integers', metavar='N', type=int, nargs='+',
                    help='an integer for the accumulator')
parser.add_argument('--sum', dest='accumulate', action='store_const',
                    const=sum, default=max,
                    help='sum the integers (default: find the max)')

args = parser.parse_args()
print(args.accumulate(args.integers))

假定上述代码保存为prog.py,它可以在命令行执行并自动提供有用的help信息:

$ python prog.py -h
usage: prog.py [-h] [--sum] N [N ...]

Process some integers.

positional arguments:
 N           an integer for the accumulator

optional arguments:
 -h, --help  show this help message and exit
 --sum       sum the integers (default: find the max)

该程序可以接受相应的参数并给出相应的输出:

$ python prog.py 1 2 3 4
4
$ python prog.py 1 2 3 4 --sum
10

如果传入无效的参数,它可以自动生成error信息:

$ python prog.py a b c
usage: prog.py [-h] [--sum] N [N ...]
prog.py: error: argument N: invalid int value: 'a'

举例2

import argparse
  
parser = argparse.ArgumentParser()

parser.add_argument('--sparse', action='store_true', default=False, help='GAT with sparse version or not.')
parser.add_argument('--seed', type=int, default=72, help='Random seed.')
parser.add_argument('--epochs', type=int, default=10000, help='Number of epochs to train.')
  
args = parser.parse_args()
  
print(args.sparse)
print(args.seed)
print(args.epochs)

执行以上程序,打印内容如下:

False

72

10000

Process finished with exit code 0

介绍完了参数解析相关内容,接下来是2个字典:

mapping = {'normal': 0, 'pneumonia': 1, 'COVID-19': 2}
inv_mapping = {0: 'normal', 1: 'pneumonia', 2: 'COVID-19'}

很好理解,3分类。正常对应0;普通肺炎对应1;新冠肺炎对应2。反字典对应关系也一致。

接下来的2句无需过多解释,

sess = tf.Session():创建会话;

tf.get_default_graph():获取当前默认的计算图。

saver = tf.train.import_meta_graph(os.path.join(args.weightspath, args.metaname))

用来加载meta文件中的图,以及图上定义的结点参数包括权重偏置项等需要训练的参数,也包括训练过程生成的中间参数。所有参数都是通过graph调用接口get_tensor_by_name(name="训练时的参数名称")来获取。

saver.restore(sess, os.path.join(args.weightspath, args.ckptname))

通过restore函数来给图里的变量赋值。

image_tensor = graph.get_tensor_by_name(args.in_tensorname)
pred_tensor = graph.get_tensor_by_name(args.out_tensorname)