TensorFlow 进度条

597 阅读3分钟
import tensorflow as tf  
from time import sleep

tf.keras.utils.Progbar

在tensorflow里,提供了tf.keras.utils.Progbar来实现进度条的功能

tf.keras.utils.Progbar(  
target,  
width=30,  
verbose=1,  
interval=0.05,  
stateful_metrics=None,  
unit_name='step'  
)  
参数作用
target预期总步数(必填)
width进度条宽度
verbose详细模式,0(静默)、1(详细)、2(半详细)
stateful_metrics此列表的参数,最终输出不会取平均值
interval最小视觉进度更新间隔(以秒为单位)
unit_name步骤计数的名称(通常为“step”或“sample”)

target

target为必填项,一般填样本的数量

progbar = tf.keras.utils.Progbar(10)  

大多数情况下,填写target参数就够了

width

width控制进度条的宽度
tf.keras.utils.Progbar(10, width=30)
显示:10/10 [==============================] - 0s 0s/step
tf.keras.utils.Progbar(10, width=30)
显示:10/10 [========================================] - 0s 0s/step
width=40相较于width=30,进度条的宽度显得更宽了

verbose

verbose参数决定进度条输出的详细程度

  • verbose=0时,进度条将被禁用,不会显示在输出中
  • verbose=1时(默认值),进度条将显示在输出中,并且会根据设置的时间间隔(interval)来更新进度条的状态
  • verbose=2时,进度条同样会显示在输出中,但会在每个步骤结束后都进行刷新,而不是根据时间间隔来更新

verbose=1时的输出
10/10 [==============================] - 0s 0s/step
verbose=1时的输出
10/10 - 0s - 0s/epoch - 0s/step

stateful_metrics

stateful_metrics参数决定最后一次的输出是否取平均值
下面俩段代码,不同的是上面的stateful_metrics参数为['loss'],下面的stateful_metrics参数是None

project_1:

progbar = tf.keras.utils.Progbar(10, stateful_metrics=['loss'])
record = 0
for i in range(10):
    loss = tf.random.uniform([1, ])
    record += loss
    progbar.update(i + 1, values=[('loss', loss)])
print('loss平均值', (record/10) .numpy())

project_2:

progbar = tf.keras.utils.Progbar(10, stateful_metrics=None)
record = 0
for i in range(10):
    loss = tf.random.uniform([1, ])
    record += loss
    progbar.update(i + 1, values=[('loss', loss)])
print('loss平均值', (record/10) .numpy())

project_1输出:
10/10 [==============================] - 0s 0s/step - loss: 0.8224
loss平均值 tf.Tensor(5.3405304, shape=(), dtype=float32)
project_2输出:
10/10 [==============================] - 0s 0s/step - loss: 0.4970
loss平均值 tf.Tensor(4.969612, shape=(), dtype=float32)
project_1的输出0.82245.3405304明显都对不上号,而project_2的输出0.49704.969612大差不差

interval

interval参数指定进度条多少秒更新一次

unit_name

unit_name参数定义计数步骤的名称
unit_name=tf
输出如下(注意后面)
10/10 [==============================] - 0s 0s/tf

Methods

add

values参数的用法就不多说了,详见上面stateful_metrics参数

add(  
n, values=None  
)  

n参数定义每一更新时,进度条添加的宽度

progbar = tf.keras.utils.Progbar(10)  
for i in range(10):  
    progbar.add(1)  
    sleep(0.1)  

每次更新,进度条都走1个单位宽度,循环10次,刚好达到progbar定义的上限10

progbar = tf.keras.utils.Progbar(10)  
for i in range(10):  
    progbar.add(1.5)  
    sleep(0.1)  

很明显1.5*10=15>10,所以进度条就会@..$=&[}(不知道咋描述),看输出吧

10/10 [===============================] - 1s 61ms/step  
12/10 [====================================] - 1s 63ms/step  
13/10 [========================================] - 1s 64ms/step  
15/10 [=============================================] - 1s 65ms/step  

update

update(  
current, values=None  
)  

update则根据current,将进度条的进度更新到current的位置

progbar = tf.keras.utils.Progbar(10)  
for i in range(10):  
    progbar.update(i+1)  
    sleep(0.1)  

这串代码,进度条的进度则会在01之间反复横跳

progbar = tf.keras.utils.Progbar(10)  
n = 1  
for i in range(10):  
    if n == 1:  
        n = 0  
    else:  
        n = 1  
    progbar.update(n)  
    sleep(0.3)   

Run

准备训练数据

EPOCH = 2  
BATCH_SIZE = 256  
trainSet = tf.random.uniform([6000, 28, 28, 1])  
trainSetSize = trainSet.shape[0]  
trainSet = tf.data.Dataset.from_tensor_slices(trainSet).batch(BATCH_SIZE)  

add方式

for epoch in range(EPOCH):  
    print(f'\nEpoch {epoch + 1}/{EPOCH}')  
    progbar = tf.keras.utils.Progbar(tf.math.ceil(trainSetSize / BATCH_SIZE), stateful_metrics=['loss'])  
    for data in trainSet:  
        # loss = train(data)  
        loss = tf.random.uniform([1, ])  
        progbar.add(1, values=[('loss', loss)])  
        sleep(0.01)  

update方式

for epoch in range(EPOCH):  
    print(f'\nEpoch {epoch + 1}/{EPOCH}')  
    progbar = tf.keras.utils.Progbar(tf.math.ceil(trainSetSize / BATCH_SIZE), stateful_metrics=['loss'])  
    for num, data in enumerate(trainSet):  
        # loss = train(data)  
        loss = tf.random.uniform([1, ])  
        progbar.update(num, values=[('loss', loss)])  
        sleep(0.01)