什么是Keras/TensorFlow损失函数中的'from_logits=True'?

386 阅读3分钟

像Keras这样的深度学习框架降低了大众的入门门槛,并将DL模型的开发民主化给没有经验的民间人士,他们可以依靠合理的默认值和简化的API来承担重任,并产生体面的结果。

在使用Keras损失函数进行分类时,较新的深度学习从业者之间出现了一个常见的困惑,例如CategoricalCrossentropySparseCategoricalCrossentropy

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Or
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)

from_logits 标志指的是什么?

答案相当简单,但需要看一下我们试图用损失函数来分级的网络的输出。

Logits和SoftMax概率

长话短说:

概率是归一化的--也就是说,其范围在[0..1] 。对数没有被规范化,其范围在[-inf...+inf] 之间。

取决于你的网络的输出层。

output = keras.layers.Dense(n, activation='softmax')(x)
# Or
output = keras.layers.Dense(n)(x)

Dense 层的输出将返回:

  • 概率:输出通过SoftMax函数,该函数将输出归一化为一组超过n 的概率,所有这些概率加起来就是1
  • 对数:n 激活。

这种误解可能来自于允许你向一个层添加激活的简短语法,似乎是一个单层,尽管它只是简写。

output = keras.layers.Dense(n, activation='softmax')(x)
# Equivalent to
dense = keras.layers.Dense(n)(x)
output = keras.layers.Activation('softmax')(dense)

你的损失函数必须被告知它是否应该期望一个归一化分布(通过SoftMax函数的输出)或对数。因此,from_logits 标志!

什么时候from_logits=True

如果你的网络将输出概率规范化,你的损失函数应该将from_logits 设置为False ,因为它不接受logits。这也是所有接受该标志的损失类的默认值,因为大多数人在他们的输出层中添加了一个activation='softmax'

model = keras.Sequential([
    keras.layers.Input(shape=(10, 1)),
    # Outputs normalized probability - from_logits=False
    keras.layers.Dense(10, activation='softmax') 
])

input_data = tf.random.uniform(shape=[1, 1])
output = model(input_data)
print(output)

这就导致了:

tf.Tensor(
[[[0.12467965 0.10423233 0.10054766 0.09162105 0.09144577 0.07093797
   0.12523937 0.11292477 0.06583504 0.11253635]]], shape=(1, 1, 10), dtype=float32)

**False由于这个网络的结果是一个归一化的分布--当将输出与目标输出进行比较,并通过分类损失函数进行分级时(针对适当的任务)-- **你应该将from_logits ,**或者让默认值保持不变。

另一方面,如果你的网络不在输出上应用SoftMax:

model = keras.Sequential([
    keras.layers.Input(shape=(10, 1)),
    # Outputs logits - from_logits=True
    keras.layers.Dense(10)
])

input_data = tf.random.uniform(shape=[1, 1])
output = model(input_data)
print(output)

这就导致了:

tf.Tensor(
[[[-0.06081138  0.04154852  0.00153442  0.0705068  -0.01139916
    0.08506121  0.1211026  -0.10112958 -0.03410497  0.08653068]]], shape=(1, 1, 10), dtype=float32)

你需要将from_logits 设置 True,以便损失函数能够正确处理输出。

什么时候在输出上使用SoftMax?

大多数从业者在输出上应用SoftMax,以提供一个归一化的概率分布,因为在许多情况下,这是你将使用网络的目的--特别是在简化的教育材料中。然而,在某些情况下,你不想在输出上应用该函数,而是在应用SoftMax或其他函数之前以不同的方式处理它。

一个明显的例子来自于NLP模型,其中一个真正的大词汇的概率可以存在于输出张量中。将SoftMax应用于所有的词汇并贪婪地获取argmax ,通常不会产生很好的结果。

然而,如果你观察对数,提取Top-K(其中K可以是任何数字,但通常介于[0...10] ),然后才对词汇中可能存在的前k个标记应用SoftMax,就会使分布发生明显变化,通常会产生更真实的结果。

这就是所谓的Top-K抽样,虽然它不是理想的策略,但通常明显优于贪婪抽样。

结语

在这个简短的指南中,我们看了一下Keras损失类的from_logits 论点,这常常会引起新的从业者的疑问。

这种困惑可能来自于允许在层本身的定义中,在其他层之上添加激活层的简短语法。我们最后看了一下,什么时候该参数应该被设置为TrueFalse ,什么时候输出应该被保留为logits或通过激活函数(如SoftMax)传递。