使用半监督学习(SSL)用未标记的数据训练分类器模型的初学者教程

150 阅读8分钟

传统上,训练计算机视觉模型(如分类器)需要标记的数据。训练数据中的每个例子都需要是一对:一幅图像和一个描述图像的人类生成的标签。

最近,新的SSL技术为Imagenet等经典挑战提供了计算机视觉中最准确的模型。半监督学习(SSL)让一个模型从有标签和无标签的数据中学习。无标签的数据仅由图像组成,没有任何标签。

SSL很好,因为通常未标记的数据比标记的数据多得多,特别是当你把模型部署到生产中时。而且,SSL减少了贴标签的时间、成本和精力。

但是,一个模型如何从没有标签的图像中学习?关键的见解是,图像本身有信息。SSL的神奇之处在于,它可以通过对基于结构相似的图像进行自动聚类,从没有标签的数据中提取信息,而这种聚类为模型提供了额外的信息来学习。

本教程使用了Google Colab中包含的几个常用Python库,包括matplotlib、numpy和TensorFlow。如果你需要安装它们,你通常可以在Jupyter笔记本中运行!pip install --upgrade pip; pip install matplotlib numpy tensorflow ,或者从命令行中运行pip install --upgrade pip; pip install matplotlib numpy tensorflow (没有感叹号)。

如果你使用谷歌Colab,确保将运行时类型改为GPU。

在本教程中,让我们在CIFAR-10数据集上训练一个分类器。这是一个经典的自然图像的研究数据集。让我们加载它并看一看。我们将看到CIFAR-10中的一些类别:青蛙、船、汽车、卡车、鹿、马、鸟、猫、狗和飞机。

import matplotlib.pyplot as plt

def plot_images(images):
  """Simple utility to render images."""
  # Visualize the data.
  _, axarr = plt.subplots(5, 5, figsize=(15,15))

  for row in range(5):
    for col in range(5):
      image = images[row*5 + col]
      axarr[row, col].imshow(image)
      
import tensorflow as tf

NUM_CLASSES = 10
# Load the data using the Keras Datasets API. 
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

plot_images(x_test)

创建模型

一般来说,你会想使用一个现成的模型架构。这可以节省你摆弄模型架构设计的精力。模型大小的一般规则是选择一个足够大的模型来处理你的数据,但又不至于大到在推理过程中很慢。对于像CIFAR-10这样非常小的数据集,我们将使用一个非常小的模型。对于图像尺寸较大的数据集,Efficient Net系列是一个不错的选择。

def get_model():    
    return tf.keras.applications.MobileNet(input_shape=(32,32,3), 
                                           weights=None, 
                                           classes=NUM_CLASSES, 
                                           classifier_activation=None)

model = get_model()

准备好数据

现在,让我们准备一下数据,将标签(从0到9的整数,代表10类物体)转换成单热向量,如[1,0,0,0,0,0,0,0]和[0,0,0,0,0,1]。我们还将把图像像素更新为模型架构所期望的范围,即[-1,1]的范围。

def normalize_data(x_train, y_train, x_test, y_test):
  """Utility to normalize the data into standard formats."""

  # Update the pixel range to [-1,1], which is expected by the model architecture.
  x_train = x = tf.keras.applications.mobilenet.preprocess_input(x_train)
  x_test = x = tf.keras.applications.mobilenet.preprocess_input(x_test)

  # Convert to one-hot labels. 
  y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
  y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

  return x_train, y_train, x_test, y_test
  
x_train, y_train, x_test, y_test = \
  normalize_data(x_train, y_train, x_test, y_test)

这个数据集包括50,000个例子。让我们把其中的5000个作为有标签的图像,把20000个作为无标签的图像。

import numpy as np

def prepare_data(x_train, y_train, num_labeled_examples, num_unlabeled_examples):
    """Returns labeled and unlabeled datasets."""
    num_examples = x_train.size

    assert num_labeled_examples + num_unlabeled_examples <= num_examples

    # Generate some random indices. 
    dataset_size = len(x_train)
    indices = np.array(range(dataset_size))
    generator = np.random.default_rng(seed=0)
    generator.shuffle(indices)

    # Split the indices into two sets: one for labeled, one for unlabeled. 
    labeled_train_indices = indices[:num_labeled_examples]
    unlabeled_train_indices = indices[num_labeled_examples : num_labeled_examples + num_unlabeled_examples]

    x_labeled_train = x_train[labeled_train_indices]
    y_labeled_train = y_train[labeled_train_indices]

    x_unlabeled_train = x_train[unlabeled_train_indices]
    # Since this is unlabeled, we won't need a y_labeled_data. 

    return x_labeled_train, y_labeled_train, x_unlabeled_train

NUM_LABELED = 5000
NUM_UNLABELED = 20000

x_labeled_train, y_labeled_train, x_unlabeled_train = \
    prepare_data(x_train, 
                 y_train, 
                 num_labeled_examples=NUM_LABELED, 
                 num_unlabeled_examples=NUM_UNLABELED)

del x_train, y_train

基线训练

为了测量SSL带来的性能改进,我们首先测量模型在没有SSL的标准训练循环中的性能。

让我们建立一个带有一些基本数据增强的标准训练循环。数据增强是正则化的一种,它可以对抗过度拟合,使你的模型能够更好地概括它从未见过的数据。

下面的超参数值(学习率、历时、批次大小等)是普通默认值和手动调整值的组合。

其结果是一个模型的准确率约为45%。(记住要读出验证精度,而不是训练精度)。我们的下一个任务是弄清楚我们是否可以使用SSL来提高模型的准确度。

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)

# Setup Keras augmentation. 
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False,
    featurewise_std_normalization=False,
    horizontal_flip=True)

datagen.fit(x_labeled_train)

batch_size = 64
epochs = 30
model.fit(
    x = datagen.flow(x_labeled_train, y_labeled_train, batch_size=batch_size),
    shuffle=True,
    validation_data=(x_test, y_test),
    batch_size=batch_size,
    epochs=epochs,
)
baseline_metrics = model.evaluate(x=x_test, y=y_test, return_dict=True)
print('')
print(f"Baseline model accuracy: {baseline_metrics['categorical_accuracy']}")


输出:

Epoch 1/30
79/79 [==============================] - 4s 23ms/step - loss: 2.4214 - categorical_accuracy: 0.1578 - val_loss: 2.3047 - val_categorical_accuracy: 0.1000
Epoch 2/30
79/79 [==============================] - 1s 16ms/step - loss: 2.0831 - categorical_accuracy: 0.2196 - val_loss: 2.3063 - val_categorical_accuracy: 0.1000
Epoch 3/30
79/79 [==============================] - 1s 16ms/step - loss: 1.9363 - categorical_accuracy: 0.2852 - val_loss: 2.3323 - val_categorical_accuracy: 0.1000
Epoch 4/30
79/79 [==============================] - 1s 16ms/step - loss: 1.8324 - categorical_accuracy: 0.3174 - val_loss: 2.3496 - val_categorical_accuracy: 0.1000
Epoch 5/30
79/79 [==============================] - 1s 16ms/step - loss: 1.8155 - categorical_accuracy: 0.3438 - val_loss: 2.3339 - val_categorical_accuracy: 0.1000
Epoch 6/30
79/79 [==============================] - 1s 15ms/step - loss: 1.6477 - categorical_accuracy: 0.3886 - val_loss: 2.3606 - val_categorical_accuracy: 0.1000
Epoch 7/30
79/79 [==============================] - 1s 15ms/step - loss: 1.6120 - categorical_accuracy: 0.4100 - val_loss: 2.3585 - val_categorical_accuracy: 0.1000
Epoch 8/30
79/79 [==============================] - 1s 16ms/step - loss: 1.5884 - categorical_accuracy: 0.4220 - val_loss: 2.1796 - val_categorical_accuracy: 0.2519
Epoch 9/30
79/79 [==============================] - 1s 18ms/step - loss: 1.5477 - categorical_accuracy: 0.4310 - val_loss: 1.8913 - val_categorical_accuracy: 0.3145
Epoch 10/30
79/79 [==============================] - 1s 15ms/step - loss: 1.4328 - categorical_accuracy: 0.4746 - val_loss: 1.7082 - val_categorical_accuracy: 0.3696
Epoch 11/30
79/79 [==============================] - 1s 16ms/step - loss: 1.4328 - categorical_accuracy: 0.4796 - val_loss: 1.7679 - val_categorical_accuracy: 0.3811
Epoch 12/30
79/79 [==============================] - 2s 20ms/step - loss: 1.3962 - categorical_accuracy: 0.5020 - val_loss: 1.8994 - val_categorical_accuracy: 0.3690
Epoch 13/30
79/79 [==============================] - 1s 16ms/step - loss: 1.3271 - categorical_accuracy: 0.5156 - val_loss: 2.0416 - val_categorical_accuracy: 0.3688
Epoch 14/30
79/79 [==============================] - 1s 17ms/step - loss: 1.2711 - categorical_accuracy: 0.5374 - val_loss: 1.9231 - val_categorical_accuracy: 0.3848
Epoch 15/30
79/79 [==============================] - 1s 15ms/step - loss: 1.2312 - categorical_accuracy: 0.5624 - val_loss: 1.9006 - val_categorical_accuracy: 0.3961
Epoch 16/30
79/79 [==============================] - 1s 19ms/step - loss: 1.2048 - categorical_accuracy: 0.5720 - val_loss: 2.0102 - val_categorical_accuracy: 0.4102
Epoch 17/30
79/79 [==============================] - 1s 16ms/step - loss: 1.1365 - categorical_accuracy: 0.6000 - val_loss: 2.1400 - val_categorical_accuracy: 0.3672
Epoch 18/30
79/79 [==============================] - 1s 18ms/step - loss: 1.1992 - categorical_accuracy: 0.5840 - val_loss: 2.1206 - val_categorical_accuracy: 0.3933
Epoch 19/30
79/79 [==============================] - 2s 25ms/step - loss: 1.1438 - categorical_accuracy: 0.6012 - val_loss: 2.4035 - val_categorical_accuracy: 0.4014
Epoch 20/30
79/79 [==============================] - 2s 24ms/step - loss: 1.1211 - categorical_accuracy: 0.6018 - val_loss: 2.0224 - val_categorical_accuracy: 0.4010
Epoch 21/30
79/79 [==============================] - 2s 21ms/step - loss: 1.0425 - categorical_accuracy: 0.6358 - val_loss: 2.2100 - val_categorical_accuracy: 0.3911
Epoch 22/30
79/79 [==============================] - 1s 16ms/step - loss: 1.1177 - categorical_accuracy: 0.6116 - val_loss: 1.9892 - val_categorical_accuracy: 0.4285
Epoch 23/30
79/79 [==============================] - 1s 19ms/step - loss: 1.0236 - categorical_accuracy: 0.6412 - val_loss: 2.1216 - val_categorical_accuracy: 0.4211
Epoch 24/30
79/79 [==============================] - 1s 18ms/step - loss: 0.9487 - categorical_accuracy: 0.6714 - val_loss: 2.0135 - val_categorical_accuracy: 0.4307
Epoch 25/30
79/79 [==============================] - 1s 16ms/step - loss: 1.1877 - categorical_accuracy: 0.5876 - val_loss: 2.3732 - val_categorical_accuracy: 0.3923
Epoch 26/30
79/79 [==============================] - 2s 20ms/step - loss: 1.0639 - categorical_accuracy: 0.6288 - val_loss: 1.9291 - val_categorical_accuracy: 0.4291
Epoch 27/30
79/79 [==============================] - 2s 19ms/step - loss: 0.9243 - categorical_accuracy: 0.6882 - val_loss: 1.8552 - val_categorical_accuracy: 0.4343
Epoch 28/30
79/79 [==============================] - 1s 15ms/step - loss: 0.9784 - categorical_accuracy: 0.6656 - val_loss: 2.0175 - val_categorical_accuracy: 0.4386
Epoch 29/30
79/79 [==============================] - 1s 17ms/step - loss: 0.9316 - categorical_accuracy: 0.6800 - val_loss: 1.9916 - val_categorical_accuracy: 0.4305
Epoch 30/30
79/79 [==============================] - 1s 17ms/step - loss: 0.8816 - categorical_accuracy: 0.7054 - val_loss: 2.0281 - val_categorical_accuracy: 0.4366
313/313 [==============================] - 1s 3ms/step - loss: 2.0280 - categorical_accuracy: 0.4366

Baseline model accuracy: 0.436599999666214

使用SSL进行训练

现在,让我们看看是否可以通过在训练数据中添加未标记的数据来提高模型的准确性。我们将使用Masterful,一个为计算机视觉模型(如我们的分类器)实现SSL的平台。

让我们来安装Masterful。在谷歌Colab中,我们可以从笔记本单元中进行管道安装。我们也可以通过命令行来安装它。更多细节,请看Masterful安装指南

!pip install --upgrade pip
!pip install masterful

import masterful

masterful = masterful.register()

输出:

Loaded Masterful version 0.4.1. This software is distributed free of
charge for personal projects and evaluation purposes.
See http://www.masterfulai.com/personal-and-evaluation-agreement for details.
Sign up in the next 45 days at https://www.masterfulai.com/get-it-now
to continue using Masterful.

设置Masterful

现在,我们来设置Masterful的一些配置参数。

# Start fresh with a new model
tf.keras.backend.clear_session()
model = get_model()

# Tell Masterful that your model is performing a classification task
# with 10 labels and that the image pixel range is 
# [-1,1]. Also, the model outputs logits rather than a softmax activation.
model_params = masterful.architecture.learn_architecture_params(
    model=model,
    task=masterful.enums.Task.CLASSIFICATION,
    input_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    prediction_logits=True,
)

# Tell Masterful that your labeled training data is using one-hot labels. 
labeled_training_data_params = masterful.data.learn_data_params(
    dataset=(x_labeled_train, y_labeled_train),
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=False,
)

unlabeled_training_data_params = masterful.data.learn_data_params(
    dataset=(x_unlabeled_train,),
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=None,
)

# Tell Masterful that your test/validation data is using one-hot labels. 
test_data_params = masterful.data.learn_data_params(
    dataset=(x_test, y_test),
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=False,
)

# Let Masterful meta-learn ideal optimization hyperparameters like
# batch size, learning rate, optimizer, learning rate schedule, and epochs.
# This will speed up training. 
optimization_params = masterful.optimization.learn_optimization_params(
    model,
    model_params,
    (x_labeled_train, y_labeled_train),
    labeled_training_data_params,
)

# Let Masterful meta-learn ideal regularization hyperparameters. Regularization
# is an important ingredient of SSL. Meta-learning can
# take a while so we'll use a precached set of parameters.
# regularization_params = \
#   masterful.regularization.learn_regularization_params(model, 
#                                                        model_params, 
#                                                        optimization_params, 
#                                                        (x_labeled_train, y_labeled_train),
#                                                        labeled_training_data_params)

regularization_params = masterful.regularization.parameters.CIFAR10_SMALL

# Let Masterful meta-learn ideal SSL hyperparameters. 
ssl_params = masterful.ssl.learn_ssl_params(
    (x_labeled_train, y_labeled_train),
    labeled_training_data_params,
    unlabeled_datasets=[((x_unlabeled_train,), unlabeled_training_data_params)],
)

输出:

MASTERFUL: Learning optimal batch size.
MASTERFUL: Learning optimal initial learning rate for batch size 256.

现在,我们已经准备好使用SSL技术进行训练了!我们将调用masterful.training.train,它是进入Masterful训练引擎的入口。

training_report = masterful.training.train(
    model,
    model_params,
    optimization_params,
    regularization_params,
    ssl_params,
    (x_labeled_train, y_labeled_train),
    labeled_training_data_params,
    (x_test, y_test),
    test_data_params,
    unlabeled_datasets=[((x_unlabeled_train,), unlabeled_training_data_params)],
)

输出:

MASTERFUL: Training model with semi-supervised learning enabled.
MASTERFUL: Performing basic dataset analysis.
MASTERFUL: Training model with:
MASTERFUL: 	5000 labeled examples.
MASTERFUL: 	10000 validation examples.
MASTERFUL: 	0 synthetic examples.
MASTERFUL: 	20000 unlabeled examples.
MASTERFUL: Training model with learned parameters partridge-boiled-cap in two phases.
MASTERFUL: The first phase is supervised training with the learned parameters.
MASTERFUL: The second phase is semi-supervised training to boost performance.
MASTERFUL: Warming up model for supervised training.
MASTERFUL: 	Warming up batch norm statistics (this could take a few minutes).
MASTERFUL: 	Warming up training for 500 steps.
100%|██████████| 500/500 [00:47<00:00, 10.59steps/s]
MASTERFUL: 	Validating batch norm statistics after warmup for stability (this could take a few minutes).
MASTERFUL: Starting Phase 1: Supervised training until the validation loss stabilizes...
Supervised Training: 100%|██████████| 6300/6300 [02:33<00:00, 41.13steps/s]
MASTERFUL: Starting Phase 2: Semi-supervised training until the validation loss stabilizes...
MASTERFUL: Warming up model for semi-supervised training.
MASTERFUL: 	Warming up batch norm statistics (this could take a few minutes).
MASTERFUL: 	Warming up training for 500 steps.
100%|██████████| 500/500 [00:23<00:00, 20.85steps/s]
MASTERFUL: 	Validating batch norm statistics after warmup for stability (this could take a few minutes).
Semi-Supervised Training: 100%|██████████| 11868/11868 [08:06<00:00, 24.39steps/s]

分析结果

你传入masterful.training.train的模型现在已经被训练并更新到位,所以你能够像其他训练好的Keras模型一样对它进行评估。

masterful_metrics = model.evaluate(
    x_test, y_test, return_dict=True, verbose=0
)
print(f"Baseline model accuracy: {baseline_metrics['categorical_accuracy']}")
print(f"Masterful model accuracy: {masterful_metrics['categorical_accuracy']}")

输出:

Baseline model accuracy: 0.436599999666214
Masterful model accuracy: 0.558899998664856

结果的可视化

正如你所看到的,你将准确率从0.45左右提高到0.56。当然,更严格的研究会试图消除基线训练和通过Masterful平台使用SSL训练之间的其他差异,以及重复运行几次并产生误差条和P值。现在,让我们确保将其绘制成图表,以帮助解释我们的结果。

import matplotlib.cm as cm
from matplotlib.colors import Normalize
 
data = (baseline_metrics['categorical_accuracy'], masterful_metrics['categorical_accuracy'])
fig, ax = plt.subplots(1, 1)
   
ax.bar(range(2), data, color=('gray', 'red'))

plt.xlabel("Training Method")
plt.ylabel("Accuracy")

plt.xticks((0,1), ("baseline", "SSL with Masterful"))

plt.show()

结论

祝贺你!我们刚刚在一个简单的教程中成功地采用了SSL这种最先进的训练方法来提高你的模型精度。在这一过程中,你避免了标注的成本和精力。

SSL不仅适用于分类--各种不同类型的方法几乎适用于任何计算机视觉任务。要想深入了解这个主题,并看到SSL在物体检测方面的作用,请查看这里的其他教程。