transforms.Normalize的均值和标准差计算过程

231 阅读1分钟

transforms.Normalize 用于对图像进行标准化处理,其计算过程是基于输入图像的均值(mean)和标准差(std)。这个操作有助于使输入数据的分布接近标准正态分布,从而帮助神经网络更好地学习和收敛。

具体来说,transforms.Normalize 对每个通道(对于RGB图像)进行标准化。标准化的过程是将每个通道的像素值减去均值,然后除以标准差。

input[channel]=input[channel]mean[channel]std[channel]input[channel] = \frac{input[channel]−mean[channel]}{std[channel]}

在训练过程中,均值和标准差的计算通常是在整个训练数据集上进行的,或者如果数据集非常庞大的时候,我们可以随机抽样来计算其均值和标准差。下面为示例代码,给定路径和需要计算的样本数量。

def InfoCalc(img_dir, calc_num):
    h, w = 32, 32
    imgs = np.zeros([w, h, 3, 1])
    mean, std = [], []
    num = 0
    for root, _, files in os.walk(img_dir):
        for name in files:
            if name.endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                img_path = os.path.join(root, name)
                img = cv2.imread(img_path)
                img = cv2.resize(img, (h, w))
                img = img[:, :, :, np.newaxis]
                imgs = np.concatenate((imgs, img), axis=3)
                num += 1
                if num >= calc_num:
                    break

    imgs = imgs.astype(np.float32) / 255.

    #三通道图
    for i in range(3):
        pixels = imgs[:, :, i, :].ravel()  # 展平
        mean.append(np.mean(pixels))
        std.append(np.std(pixels))

    mean.reverse()  # BGR --> RGB
    std.reverse()

    return mean, std