tf.random.categorical 是一个TensorFlow函数,用于从分类分布中随机抽取样本。
函数的签名如下:
tf.random.categorical(logits, num_samples=1, dtype=tf.int64)
参数说明:
logits:一个2-D张量,表示分类分布的概率。每一行代表一个样本,每一列代表一个类别。logits的形状应该为[num_samples, num_classes]。num_samples:要从分布中抽取的样本数量。默认为1。dtype:返回的张量的数据类型。默认为tf.int64。
函数的返回值是一个形状为(num_samples,)的张量,表示从分类分布中随机抽取的样本。
这个函数通常用于多分类问题中,例如机器学习和深度学习的分类任务。它可以根据给定的概率分布随机生成样本标签。
示例
import tensorflow as tf
# 定义分类分布的概率
logits = tf.constant([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]])
# 从分类分布中随机抽取样本
samples = tf.random.categorical(logits, num_samples=10)
# 打印结果
print(samples)
tf.Tensor(
[[2 2 1 0 2 2 1 2 2 0]
[2 1 0 0 0 2 0 2 0 2]], shape=(2, 10), dtype=int64)