transforms.Normalize 用于对图像进行标准化处理,其计算过程是基于输入图像的均值(mean)和标准差(std)。这个操作有助于使输入数据的分布接近标准正态分布,从而帮助神经网络更好地学习和收敛。
具体来说,transforms.Normalize 对每个通道(对于RGB图像)进行标准化。标准化的过程是将每个通道的像素值减去均值,然后除以标准差。
在训练过程中,均值和标准差的计算通常是在整个训练数据集上进行的,或者如果数据集非常庞大的时候,我们可以随机抽样来计算其均值和标准差。下面为示例代码,给定路径和需要计算的样本数量。
def InfoCalc(img_dir, calc_num):
h, w = 32, 32
imgs = np.zeros([w, h, 3, 1])
mean, std = [], []
num = 0
for root, _, files in os.walk(img_dir):
for name in files:
if name.endswith(('.jpg', '.jpeg', '.png', '.bmp')):
img_path = os.path.join(root, name)
img = cv2.imread(img_path)
img = cv2.resize(img, (h, w))
img = img[:, :, :, np.newaxis]
imgs = np.concatenate((imgs, img), axis=3)
num += 1
if num >= calc_num:
break
imgs = imgs.astype(np.float32) / 255.
#三通道图
for i in range(3):
pixels = imgs[:, :, i, :].ravel() # 展平
mean.append(np.mean(pixels))
std.append(np.std(pixels))
mean.reverse() # BGR --> RGB
std.reverse()
return mean, std