API介绍
tf.random.truncated_normal 是一个TensorFlow函数,用于生成截断正态分布(truncated normal distribution)的随机数。
截断正态分布是指在一个给定区间内的正态分布,超出这个区间的值将被排除。这个函数可以生成在指定区间内满足正态分布的随机数。
tf.random.truncated_normal 的函数签名如下:
tf.random.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)
参数说明:
shape:生成随机数的形状(shape)。可以是一个整数,表示一维张量,或者一个整数元组,表示多维张量。mean:正态分布的均值,默认为0.0。stddev:正态分布的标准差,默认为1.0。dtype:生成随机数的数据类型,默认为tf.float32。seed:随机数生成器的种子值,用于保证可重复性。如果指定了种子值,每次使用相同的种子生成的随机数序列将相同。name:操作的名称,可选。
这个函数返回一个与指定形状相符的张量,其中的元素是从截断正态分布中随机生成的。
截断区间
tf.random.truncated_normal 生成的截断正态分布的区间是均值±2倍标准差。这是正态分布的常用截断区间,通常认为在这个区间外的数据是异常值或不合理的值。
具体来说,对于给定的均值 mean 和标准差 stddev,tf.random.truncated_normal 生成的随机数满足 mean - 2 * stddev < x < mean + 2 * stddev。如果生成的随机数超出这个区间,将会重新生成,直到满足条件为止。
这种截断方式可以确保生成的随机数在均值附近,而不会出现过于极端的值。