# CLIP: minimal PyTorch demo of the symmetric image-text contrastive loss
# (adapted from the pseudocode in the CLIP paper; originally a blog post,
# ~1 minute read)
import numpy as np
import torch
from torch.nn.functional import normalize
from torch.nn import CrossEntropyLoss

# Toy reproduction of CLIP's symmetric contrastive loss on random data.
batch_size = 32  # number of (image, text) pairs in the batch
feat_dim = 512   # raw encoder feature dimension
embed_dim = 32   # joint multimodal embedding dimension (toy value)

# The i-th image matches the i-th text, so the positive targets sit on
# the diagonal of the similarity matrix: target class of row i is i.
labels = torch.arange(batch_size)
print(labels.shape)

# Stand-ins for encoder outputs: image_encoder(I) -> [n, d_i],
# text_encoder(T) -> [n, d_t].
I_f = torch.randn(batch_size, feat_dim)
T_f = torch.randn(batch_size, feat_dim)

# W_i and W_t learn to project each single-modality feature into the
# shared multimodal space.
W_i = torch.randn(feat_dim, embed_dim)
W_t = torch.randn(feat_dim, embed_dim)

# Project, then L2-normalize so the dot product below is cosine similarity.
I_e = normalize(I_f @ W_i, p=2, dim=1)
T_e = normalize(T_f @ W_t, p=2, dim=1)

# CLIP initializes its learnable temperature to 0.07; dividing by a small
# temperature sharpens the softmax over the similarity scores.
# (The original code used `np.exp(10)` ~ 22026 as the scale, which
# saturates the softmax and makes the loss numerically degenerate.)
temperature = 0.07
logits = (I_e @ T_e.t()) / temperature
print(logits)

loss = CrossEntropyLoss()
# Image -> text direction: each row of `logits` is classified over texts
# (CrossEntropyLoss applies softmax over dim=1 by default).
loss_i = loss(logits, labels)
print(loss_i)
# Text -> image direction: transpose so each text row is classified over images.
loss_t = loss(logits.t(), labels)
print(loss_t)

# CLIP's final objective is the mean of the two directional losses.
total_loss = (loss_t + loss_i) / 2
print(total_loss)