深度学习框架Keras的深入理解:Keras标准工作流程、回调函数使用、自定义评估循环

564 阅读14分钟

公众号:尤而小屋
作者:Peter
编辑:Peter

Python深度学习-深入理解Keras:Keras标准工作流程、回调函数使用、自定义训练循环和评估循环。

本文对Keras的部分做深入了解,主要包含:

  • Keras标准工作流程
  • 如何使用Keras的回调函数
  • 如何自定义编写训练循环和评估循环

Keras标准工作流程

标准的工作流程:

  • compile:编译
  • fit:训练
  • evaluate:评估
  • predict:预测

定义模型

In [1]:

#  keras标准工作流程

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [2]:

下面是函数式API的写法:

def get_mnist_model():
    """
    函数式API的流程
    """
    inputs = keras.Input(shape=(28*28,))
    features = layers.Dense(512, activation="relu")(inputs)
    features = layers.Dropout(0.5)(features)
    outputs = layers.Dense(10, activation="softmax")(features)
    model = keras.Model(inputs, outputs)
    return model

定义数据

使用内置的mnist数据集:

In [3]:

(images, labels), (test_images, test_labels) = mnist.load_data()

# 像素尺度缩放
images = images.reshape((60000, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28*28)).astype("float32") / 255

# 验证集和训练集
train_images, valid_images = images[10000:], images[:10000]
train_labels, valid_labels = labels[10000:], labels[:10000]

模型编译、训练、评估、预测

In [4]:

model = get_mnist_model()
model.compile(optimizer="rmsprop", 
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"]
             )  

model.fit(train_images, train_labels, 
          epochs=3, 
          validation_data=(valid_images, valid_labels))  # 训练;使用验证集来监控模型性能

test_metrics = model.evaluate(test_images, test_labels)  # 在测试集上评估模型

predictions = model.predict(test_images)  # 模型预测

test_metrics
Epoch 1/3
1563/1563 [==============================] - 6s 4ms/step - loss: 0.2946 - accuracy: 0.9135 - val_loss: 0.1455 - val_accuracy: 0.9576
Epoch 2/3
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1615 - accuracy: 0.9540 - val_loss: 0.1177 - val_accuracy: 0.9651
Epoch 3/3
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1313 - accuracy: 0.9638 - val_loss: 0.1129 - val_accuracy: 0.9693
313/313 [==============================] - 0s 905us/step - loss: 0.0983 - accuracy: 0.9708
313/313 [==============================] - 0s 802us/step

Out[4]:

[0.0983111560344696, 0.97079998254776]

上面就是一个最为简单的从准备数据到预测评估的过程

自定义指标

上面的是内置的方法来标准化过程,用户可以自定义指标。

  • 常用的分类和回归的指标都在keras.metrics模块中。Keras指标是keras.metrics.Metric类的子类。
  • 与层一样,指标具有一个存储在TensorFlow变量中的内部状态。但是其无法进行反向传播来更新,需要手动编写更新逻辑,这个逻辑通过update_state来实现。

In [5]:

# 自定义指标均方根误差RMSE(内置中没有这个指标)

# 本段代码可以直接套用

import tensorflow as tf

class RootMeanSquareError(keras.metrics.Metric):  # 继承父类Metric
    
    def __init__(self, name="rmse", **kwargs):  # 构造函数
        super().__init__(name=name, **kwargs)
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")  # 访问add_weight方法
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros", dtype="int32")
     
    # 基于update_state来实现
    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        基于update_state方法实现状态更新逻辑  
        y_true:真实值  
        y_pred:预测值
        """
        
        y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1])  # 编码
        mse = tf.reduce_sum(tf.square(y_true - y_pred))  # mse求解
        self.mse_sum.assign_add(mse)  # 对mse的累计求和
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)  # 总样本数    
        
    def result(self):
        """
        使用result方法返回指标的当前值
        """
        return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))
    
    # 重置指标状态    
    def reset_state(self):
        self.mse_sum.assign(0.)
        self.total_samples.assign(0) 

模型训练(基于自定义指标)

In [6]:

model = get_mnist_model()

model.compile(optimizer="rmsprop", 
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy", RootMeanSquareError()])  # 列表中传入自定义指标

model.fit(train_images, train_labels, 
          epochs=3, 
          validation_data=(valid_images, valid_labels))

test_metrics = model.evaluate(test_images, test_labels)
Epoch 1/3
1563/1563 [==============================] - 6s 4ms/step - loss: 0.2945 - accuracy: 0.9108 - rmse: 7.1743 - val_loss: 0.1480 - val_accuracy: 0.9554 - val_rmse: 7.3592
Epoch 2/3
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1609 - accuracy: 0.9543 - rmse: 7.3517 - val_loss: 0.1205 - val_accuracy: 0.9642 - val_rmse: 7.3981
Epoch 3/3
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1313 - accuracy: 0.9635 - rmse: 7.3858 - val_loss: 0.1024 - val_accuracy: 0.9723 - val_rmse: 7.4194
313/313 [==============================] - 0s 1ms/step - loss: 0.0931 - accuracy: 0.9727 - rmse: 7.4312

使用回调函数

Keras中的回调函数是一个对象(实现了特定方法的类实例),在调用fit函数时被传入模型,并在训练过程中的不同时间点被模型调用。

简介

回调函数可以访问模型状态或者性能的所有数据,还可以采取下面的功能:

  • 中断训练
  • 保存模型
  • 加载权重
  • 改变模型状态等

常用的回调函数的功能:

  1. 模型检查点model checkpointing:在训练过程中的不同时间点保存模型的当前状态
  2. 早停early stopping:如果验证损失不在改变,则提前终止
  3. 在训练过程中,动态调节某些参数:比如学习率等
  4. 在训练过程中,记录训练指标和验证指标,或者将模型学到的表示可视化
keras.callbacks.ModelCheckpoint
keras.callbacks.EarlyStoppping
keras.callbacks.LearningRateScheduler
keras.callbacks.ReduceLROnPlateau
keras.callbacks.CSVLogger

使用回调函数

以早停EarlyStopping & 模型检查点ModelCheckpoint为例,介绍如何使用回调函数。

早停可以让模型在验证损失不在改变的时候提前终止,通过EarlyStopping回调函数来实现。 通常和ModelCheckpoint回调函数使用,该函数在训练过程中不断保存模型。使得在某个点停止后保存的仍然是最佳模型。

In [7]:

callback_list = [
    # 早停
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy",  # 监控模型的验证精度
        patience=2  # 如果精度在两轮内不变,则中断训练
    ),
    # 模型检查点
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoint_path.keras",  # 模型文件保存路径
        monitor="val_loss",  # 两个参数的含义:当val_loss改善时,才会覆盖模型文件,这样便会一致保存最佳模型
        save_best_only=True
    )
    
]

In [8]:

基于早停法实现模型训练:callbacks参数数据

model = get_mnist_model()

model.compile(optimizer="rmsprop",  # 优化器
              loss="sparse_categorical_crossentropy",  # 损失
              metrics=["accuracy"]  # 监控指标-精度
             )

model.fit(train_images, 
          train_labels, 
          epochs=20, 
          callbacks=callback_list,  # 使用callbacks参数
          validation_data=(valid_images, valid_labels))
Epoch 1/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.2959 - accuracy: 0.9114 - val_loss: 0.1450 - val_accuracy: 0.9584
Epoch 2/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1599 - accuracy: 0.9539 - val_loss: 0.1172 - val_accuracy: 0.9672
Epoch 3/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1306 - accuracy: 0.9631 - val_loss: 0.1044 - val_accuracy: 0.9710
Epoch 4/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1122 - accuracy: 0.9681 - val_loss: 0.1018 - val_accuracy: 0.9732
Epoch 5/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1044 - accuracy: 0.9717 - val_loss: 0.0965 - val_accuracy: 0.9763
Epoch 6/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0946 - accuracy: 0.9740 - val_loss: 0.0990 - val_accuracy: 0.9756
Epoch 7/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0866 - accuracy: 0.9762 - val_loss: 0.0985 - val_accuracy: 0.9768
Epoch 8/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0848 - accuracy: 0.9770 - val_loss: 0.0873 - val_accuracy: 0.9788
Epoch 9/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0779 - accuracy: 0.9795 - val_loss: 0.0900 - val_accuracy: 0.9788
Epoch 10/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0722 - accuracy: 0.9801 - val_loss: 0.0972 - val_accuracy: 0.9785

Out[8]:

<keras.callbacks.History at 0x1b156d38eb0>

可以看到指定训练20轮,但是实际上在未达到20轮训练就已经停止了。

模型保存和加载

In [9]:

model.save("my_checkpoint_path")  # 保存

In [10]:

model = keras.models.load_model("checkpoint_path.keras")  # 加载模型检查点处的模型

自定义回调函数

如果我们想在训练中采取特定的行动,但是这些行动没有包含在内置回调函数中,可以自己编写回调函数。

回调函数实现的方式是将keras.callbacks.Callback类子类化。然后实现下列方法,在训练过程中的不同时间点被调用。

on_epoch_begin(epoch,logs)  # 每轮开始时
on_epoch_end(epoch,logs)  # 每轮结束时
on_batch_begin(batch,logs)  # 在处理每个批次前
on_batch_end(batch,logs)  # 在处理每个批次后
on_train_begin(logs)   # 在训练开始前
on_train_end(logs)  # 在训练开始后

在调用这些方法的时候,都会用到参数logs,这个参数是个字典,它包含前一个批量、前一个轮次或前一个训练的信息,比如验证指标或者训练指标等。

In [11]:

# 通过Callback类子类化来创建自定义回调函数

# 在训练过程中保存每个批量损失值组成的列表,在每轮结束时保存这些损失值组成的图

from matplotlib import pyplot as plt

class LossHistory(keras.callbacks.Callback):  # 继承父类
    """
    每个方法都有logs参数,它是字典
    """
    def on_train_begin(self, logs):
        self.per_batch_losses = []
        
    def on_batch_end(self, batch, logs):  
        self.per_batch_losses.append(logs.get("loss"))  # 获取每个batch下的损失
        
    def on_epoch_end(self, epoch, logs):
        plt.clf()  # 图形清零
        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses, label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}")  # 保存图像
        self.per_batch_losses = []

测试自定义的回调函数:

In [12]:

model = get_mnist_model()

model.compile(optimizer="rmsprop",  # 优化器
              loss="sparse_categorical_crossentropy",  # 损失
              metrics=["accuracy"]  # 监控指标-精度
             )

model.fit(train_images, 
          train_labels, 
          epochs=10, 
          callbacks=[LossHistory()],  # callbacks参数;必须是列表的形式
          validation_data=(valid_images, valid_labels))
Epoch 1/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.2941 - accuracy: 0.9125 - val_loss: 0.1541 - val_accuracy: 0.9563
Epoch 2/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1587 - accuracy: 0.9541 - val_loss: 0.1188 - val_accuracy: 0.9671
Epoch 3/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1313 - accuracy: 0.9627 - val_loss: 0.0983 - val_accuracy: 0.9731
Epoch 4/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1120 - accuracy: 0.9691 - val_loss: 0.1007 - val_accuracy: 0.9733
Epoch 5/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1025 - accuracy: 0.9714 - val_loss: 0.0961 - val_accuracy: 0.9764
Epoch 6/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0947 - accuracy: 0.9744 - val_loss: 0.1068 - val_accuracy: 0.9752
Epoch 7/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0913 - accuracy: 0.9758 - val_loss: 0.0883 - val_accuracy: 0.9784
Epoch 8/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0805 - accuracy: 0.9780 - val_loss: 0.0968 - val_accuracy: 0.9784
Epoch 9/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0781 - accuracy: 0.9784 - val_loss: 0.0963 - val_accuracy: 0.9804
Epoch 10/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0740 - accuracy: 0.9805 - val_loss: 0.0966 - val_accuracy: 0.9787

Out[12]:

<keras.callbacks.History at 0x1b1592ae640>

基于回调函数利用TensorBoard进行监控和可视化

TensorBoard是一个基于浏览器的应用程序,可以在本地运行,它在训练过程中可以监控模型的最佳方式,它可以实现下面的内容:

  • 在训练过程中以可视化的方式监控指标
  • 将模型架构可视化
  • 将激活函数和梯度的直方图可视化
  • 以三维形式研究嵌入

如果想将TensorBoard与Keras模型的fit方法联用,可以用keras.callbacks.TensorBoard回调函数

基于TensorBoard的回调函数

In [13]:

# 让回调函数写入日志的位置

model = get_mnist_model()

model.compile(optimizer="rmsprop",  
              loss="sparse_categorical_crossentropy",  
              metrics=["accuracy"])  

# 不用事先本地手动创建;代码运行会自动创建logs文件夹    
tensorboard = keras.callbacks.TensorBoard(log_dir="./logs",)  

model.fit(train_images, 
          train_labels, 
          epochs=10, 
          callbacks=[tensorboard],  # callbacks参数是列表
          validation_data=(valid_images, valid_labels))
Epoch 1/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.2924 - accuracy: 0.9130 - val_loss: 0.1504 - val_accuracy: 0.9568
Epoch 2/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1622 - accuracy: 0.9527 - val_loss: 0.1195 - val_accuracy: 0.9657
Epoch 3/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1294 - accuracy: 0.9629 - val_loss: 0.1138 - val_accuracy: 0.9697
Epoch 4/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.1158 - accuracy: 0.9682 - val_loss: 0.0992 - val_accuracy: 0.9735
Epoch 5/10
1563/1563 [==============================] - 7s 5ms/step - loss: 0.1040 - accuracy: 0.9713 - val_loss: 0.1016 - val_accuracy: 0.9744
Epoch 6/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0957 - accuracy: 0.9743 - val_loss: 0.0865 - val_accuracy: 0.9776
Epoch 7/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0881 - accuracy: 0.9762 - val_loss: 0.0966 - val_accuracy: 0.9780
Epoch 8/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0858 - accuracy: 0.9776 - val_loss: 0.0899 - val_accuracy: 0.9794
Epoch 9/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0771 - accuracy: 0.9793 - val_loss: 0.0908 - val_accuracy: 0.9790
Epoch 10/10
1563/1563 [==============================] - 6s 4ms/step - loss: 0.0744 - accuracy: 0.9799 - val_loss: 0.0921 - val_accuracy: 0.9791

Out[13]:

<keras.callbacks.History at 0x1b15e8780a0>

启动TensorBoard显示界面

第一步先安装TensorBoard,如果没有安装

pip install TensorBoard

1、在命令窗口中启动语句:

# 启动界面
tensorboard --logdir=tensorboard_path

最终的结果:

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.12.3 at http://localhost:6006/ (Press CTRL+C to quit)

可以直接进入本地的6006端口。

2、如果想在jupyter中直接使用tensorboard:依次执行下面的两条命令

%load_ext tensorboard
%tensorboard --logdir logs  # logs表示文件目录地址

In [14]:

%load_ext tensorboard

In [15]:

# 启动命令
%tensorboard --logdir logs/  # 在当前的cell中直接使用

==补充图==

编写自定义的训练循环和评估循环

fit()工作流程在易用性和灵活性之间实现了很好的平衡。然而,有时即使自定义指标、损失函数和回调函数,也无法满足一切需求。

内置的fit流程只针对监督学习supervised learning。其他的机器学习任务,比如生成式学习generative learning、自监督学习self-supervised learning和强化学习reinforcement learning,则无法满足。

这个时候需要编写自定义的训练逻辑。本节从头开始实现fit()方法。

训练和推断

低阶训练循环示例中:

  • 步骤1:前向传播是通过predictions=model(inputs)完成
  • 步骤2:检索梯度带计算的梯度是通过gradients=tape.gradient(loss,model.weights)完成的

某些Keras层中,在训练过程和推断过程中具有不同的行为。这些层的call方法中有一个名为training的参数。

比如:调用dropout(inputs, training=True)将会舍弃一些单元,而调用dropout(inputs, training=False)则不会舍弃。

在函数式模型和序贯模型的call方法中,也有training这个参数,前向传播变成:predictions=model(inputs, training=True)

检索模型权重的梯度时,使用:tape.gradients(loss, model.trainable_weights)。层和模型具有以下两种权重:

  • 可训练权重trainable weight:通过反向传播对这些权重进行更新,将损失最小化。Dense层的核和偏置就是可训练权重。
  • 不可训练权重non-trainable weight:在前向传播中,这些权重所在的层对它们进行更新。在Keras的所有内置层中,唯一不可训练的权重层是BatchNormalization,实现特征的规范化。

指标的低阶用法

在低阶训练循环中,可能会用到Keras指标。指标API的实现:目标值和预测值组成的批量调用update_state(y_true, y_pred),然后使用result方法查询当前指标值。

In [16]:

metric = keras.metrics.SparseCategoricalAccuracy()
targets = [0,1,2]
predictions = [[1,0,0],[0,1,0],[0,0,1]]

metric.update_state(targets, predictions)
current_result = metric.result()

print(f"result: {current_result:.2f}")
result: 1.00

跟踪某个标量值,比如模型损失的均值,使用keras.metrics.Mean()指标来实现:

In [17]:

values = [0,1,2,3,4]
mean_tracker = keras.metrics.Mean()

for value in values:
    mean_tracker.update_state(value)

print(f"Mean of values: {mean_tracker.result():.2f}")
Mean of values: 2.00

如果想重置当前结果,可以使用metric.reset_state()

完整的训练循环和评估循环

将前向传播、反向传播和指标跟踪组合成一个类似fit的训练步骤函数:

训练循环train_step

In [18]:

model = get_mnist_model()

loss_fn = keras.losses.SparseCategoricalCrossentropy()  # 损失函数
optimizer = keras.optimizers.RMSprop()  # 优化器
metrics = [keras.metrics.SparseCategoricalAccuracy()]  # 需要监控的指标列表
loss_tracking_metric = keras.metrics.Mean()  # 准备Mean指标跟踪损失均值

def train_step(inputs, targets):
    with tf.GradientTape() as tape:  # 前向传播  training=True
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)
    
    # 反向传播 model_trainable_weights    
    gradients = tape.gradient(loss, model.trainable_weights)  
    optimizer.apply_gradients(zip(gradients, model.trainable_weights)) 
    
    # 跟踪指标
    logs = {}  # 字典形式
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs[metric.name] = metric.result()
    
    # 跟踪损失均值    
    loss_tracking_metric.update_state(loss)
    logs["loss"] = loss_tracking_metric.result()
    
    return logs  # 返回指标和损失值

在每轮开始时和进行评估之前,我们需要重置指标的状态:

In [19]:

def reset_metrics():
    for metric in metrics:
        metric.reset_state()   
    loss_tracking_metric.reset_state()

编写完整的训练本身,tf.data.Dataset对象将Numpy数据转成一个迭代器,以大小为32的批量来迭代数据:

In [20]:

# 逐步编写训练循环:循环本身
training_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
training_dataset = training_dataset.batch(32)
epochs = 3

for epoch in range(epochs):
    reset_metrics()    # 指标重置
    for inputs_batch, targets_batch in training_dataset:
        logs = train_step(inputs_batch, targets_batch)   
        
    print(f"Result at the end of epoch {epoch}")
    
    for k, v in logs.items():
        print(f"...{k}: {v:.4f}")  # 保留4位小数
Result at the end of epoch 0
...sparse_categorical_accuracy: 0.9144
...loss: 0.2913
Result at the end of epoch 1
...sparse_categorical_accuracy: 0.9542
...loss: 0.1589
Result at the end of epoch 2
...sparse_categorical_accuracy: 0.9635
...loss: 0.1300

评估循环test_step

In [21]:

import time  # 运行计时

In [22]:

def test_step(inputs, targets):
    """
    test_step是train_step()逻辑的子集;省略了处理更新权重的代码(即所有设计GradientTape的代码)
    """
    predictions = model(inputs, training=False)
    loss = loss_fn(targets, predictions)
    
    logs = {}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs["val_" + metric.name] = metric.result()
        
    loss_tracking_metric.update_state(loss)
    logs["val_loss"] = loss_tracking_metric.result()
    
    return logs

start = time.time()

val_dataset = tf.data.Dataset.from_tensor_slices((valid_images, valid_labels))
val_dataset = val_dataset.batch(32)

reset_metrics()

for inputs_batch, targets_batch in val_dataset:
    logs = test_step(inputs_batch, targets_batch)
    
for k, v in logs.items():
    print(f"...{k}:{v:.4f}")
    
end = time.time()
print("未使用@tf.function的运行时间: ",end - start)
...val_sparse_categorical_accuracy:0.9668
...val_loss:0.1210
未使用@tf.function的运行时间:  1.4751169681549072

利用tf.function加速运算

自定义循环的运行速度比内置的fit核evaluate要慢很多;默认情况下,TensorFlow代码是逐行急切执行的。急切执行让调试代码变得容易,但是性能上远非最佳。

高效做法:将TensorFlow代码编译成计算图,对该计算图进行全局优化,这是逐行解释代码无法实现的。

只要一行代码:@tf.function

In [23]:

@tf.function   # 一行代码!!!!!!
def test_step(inputs, targets):
    """
    test_step是train_step()逻辑的子集;省略了处理更新权重的代码(即所有设计GradientTape的代码)
    """
    predictions = model(inputs, training=False)
    loss = loss_fn(targets, predictions)
    
    logs = {}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs["val_" + metric.name] = metric.result()
        
    loss_tracking_metric.update_state(loss)
    logs["val_loss"] = loss_tracking_metric.result()
    
    return logs

start = time.time()

val_dataset = tf.data.Dataset.from_tensor_slices((valid_images, valid_labels))
val_dataset = val_dataset.batch(32)

reset_metrics()

for inputs_batch, targets_batch in val_dataset:
    logs = test_step(inputs_batch, targets_batch)
    
for k, v in logs.items():
    print(f"...{k}:{v:.4f}")
    
end = time.time()
print("使用@tf.function的运行时间: ",end - start)
...val_sparse_categorical_accuracy:0.9668
...val_loss:0.1210
使用@tf.function的运行时间:  0.41889119148254395

对比两次运行的结果,添加了@tf.function之后,运行时间缩短了3倍多。

注意:调试代码时,最好使用急切执行,不要使用@tf.function装饰器。一旦代码能够成功运行之后,便可使用其进行加速。

在fit中使用自定义训练循环

自定义训练步骤

自定义训练循环的特点:

  1. 拥有很强的灵活性
  2. 需要编写大量的代码
  3. 无法利用fit提供的诸多方便性,比如回调函数或者对分布式训练的支持等

如果想自定义训练算法,但是仍想使用keras内置训练逻辑的强大功能,折中方法:编写自定义的训练步骤函数,让Keras完成其他工作

通过覆盖Model类的train_step()方法来实现

In [24]:

loss_fn = keras.losses.SparseCategoricalCrossentropy()
loss_tracker = keras.metrics.Mean(name="loss")  # 跟踪训练和评估过程的损失均值

class CustmoModel(keras.Model):
    def train_step(self, data):  # 覆盖train_step方法
        inputs, targets = data
        
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)  # 也可用model(inputs, training=True) 模型就是类本身
            loss = loss_fn(targets, predictions)
            
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        loss_tracker.update_state(loss)  # 更新损失跟踪器指标(均值)
        
        return {"loss": loss_tracker.result()}  # 返回当前损失跟踪器的损失均值
    
    @property
    def metrics(self):
        return [loss_tracker]

实例化模型并进行训练:

In [25]:

inputs = keras.Input(shape=(28*28))
features = layers.Dense(512,activation="relu")(inputs)
features = layers.Dropout(0.5)(features)

outputs = layers.Dense(10, activation="softmax")(features)
model = CustmoModel(inputs, outputs)

model.compile(optimizer=keras.optimizers.RMSprop())
model.fit(train_images, train_labels, epochs=3)
Epoch 1/3
1563/1563 [==============================] - 5s 3ms/step - loss: 0.2967
Epoch 2/3
1563/1563 [==============================] - 5s 3ms/step - loss: 0.1600
Epoch 3/3
1563/1563 [==============================] - 5s 3ms/step - loss: 0.1310

Out[25]:

<keras.callbacks.History at 0x1b15e884c40>

指标处理

在调用compile之后,可用访问的内容:

  • self.compiled_loss:传入compile()的损失函数
  • self.compiled_metrics:传入的指标列表的包装器,它允许调用self.compiled_metrics.update_state()来一次性更新所有指标。
  • self.metrics:传入compile()的指标列表。它还包括一个跟踪损失的指标,类似于用loss_tracking_metric手动实现的例子

In [26]:

class CustomModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss = self.compiled_loss(targets, predictionsctions) # 传入compile()的损失函数
            
        gradients = tape.gradient(loss, self.trainabel_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        self.compiled_metrics.update_state(targets, predictions)
        
        return {m.name:m.result() for m in self.metrics}

In [27]:

# 测试代码

inputs = keras.Input(shape=(28*28))
features = layers.Dense(512,activation="relu")(inputs)
features = layers.Dropout(0.5)(features)

outputs = layers.Dense(10, activation="softmax")(features)
model = CustmoModel(inputs, outputs)

model.compile(optimizer=keras.optimizers.RMSprop(),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=[keras.metrics.SparseTopKCategoricalAccuracy()]
             )

model.fit(train_images, train_labels, epochs=3)
Epoch 1/3
1563/1563 [==============================] - 5s 3ms/step - loss: 0.2996
Epoch 2/3
1563/1563 [==============================] - 5s 3ms/step - loss: 0.1618
Epoch 3/3
1563/1563 [==============================] - 5s 3ms/step - loss: 0.1303

Out[27]:

<keras.callbacks.History at 0x1b1610c7c10>