import numpy as np
import torch
from torch.nn.functional import normalize
from torch.nn import CrossEntropyLoss

# Toy demo of the symmetric (CLIP-style) contrastive loss on random data,
# following the pseudocode in Radford et al., "Learning Transferable Visual
# Models From Natural Language Supervision" (2021).
#
# Fixes vs. the original script:
#   * stay in torch end-to-end — no numpy<->torch round trips
#     (np.dot on torch tensors only worked by accidental __array__ coercion);
#   * build labels with torch.arange, which is int64 on every platform
#     (np.arange is int32 on Windows, and CrossEntropyLoss needs int64);
#   * seed the RNG so the printed values are reproducible;
#   * name the magic dimensions.

torch.manual_seed(0)

BATCH = 32   # number of (image, text) pairs in the batch
FEAT = 512   # raw feature dimension from each encoder
EMBED = 32   # joint embedding dimension

# Ground-truth pairing: the i-th image matches the i-th text.
labels = torch.arange(BATCH)
print(labels.shape)

# Random stand-ins for encoder features and the two projection matrices.
I_f = torch.randn(BATCH, FEAT)
T_f = torch.randn(BATCH, FEAT)
W_i = torch.randn(FEAT, EMBED)
W_t = torch.randn(FEAT, EMBED)

# Project into the joint embedding space and L2-normalize each row, so the
# dot products below are cosine similarities.
I_e = normalize(I_f @ W_i, p=2, dim=1)
T_e = normalize(T_f @ T_f.new_tensor(W_t.numpy()) if False else W_t, p=2, dim=1)
T_e = normalize(T_f @ W_t, p=2, dim=1)

# Temperature: CLIP learns T and uses exp(T) as the logit scale.
# NOTE(review): T = 10 gives a scale of ~22026, which drives the softmax to
# near-one-hot; CLIP clamps the learned value in practice. Kept as-is to
# match the pseudocode being reproduced.
T = 10
logits = (I_e @ T_e.T) * np.exp(T)
print(logits)

# Symmetric cross-entropy: classify each image over all texts (rows) and
# each text over all images (columns), then average the two losses.
loss = CrossEntropyLoss()
loss_i = loss(logits, labels)
print(loss_i)
loss_t = loss(logits.T, labels)
print(loss_t)
total_loss = (loss_t + loss_i) / 2
print(total_loss)