# 手撕注意力机制 (Hand-written multi-head attention)
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each passed through their own linear
    projection, split into ``heads`` chunks of size ``d_model // heads``,
    attended independently per head, re-joined, and finally passed through
    an output projection.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
        heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``heads``.
    """

    def __init__(self, d_model: int, heads: int = 8):
        super().__init__()
        # Without this check the per-head reshape in ``forward`` would either
        # fail late or silently scramble features across heads.
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )
        self.d_model = d_model
        self.heads = heads
        self.d_k = d_model // heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, x: torch.Tensor, bsz: int) -> torch.Tensor:
        """Reshape (bsz, len, d_model) into (bsz, heads, len, d_k)."""
        return x.view(bsz, -1, self.heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor = None
    ):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape (bsz, len_q, d_model).
            k: Key tensor of shape (bsz, len_k, d_model).
            v: Value tensor of shape (bsz, len_k, d_model).
            mask: Optional tensor broadcastable to
                (bsz, heads, len_q, len_k). Positions where it equals 0 are
                excluded from attention.

        Returns:
            Tensor of shape (bsz, len_q, d_model).
        """
        bsz = q.shape[0]
        # Each input goes through its *own* projection layer.
        q = self._split_heads(self.q_proj(q), bsz)
        k = self._split_heads(self.k_proj(k), bsz)
        v = self._split_heads(self.v_proj(v), bsz)

        # Scale by sqrt(d_k) so the logits do not grow with the head size.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the dtype's most negative finite value instead of a fixed
            # -1e9, which is not representable in half precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1)

        context = torch.matmul(weights, v)
        # Undo the head split: (bsz, heads, len_q, d_k) -> (bsz, len_q, d_model).
        concat = context.transpose(1, 2).contiguous().view(bsz, -1, self.d_model)
        return self.o_proj(concat)
if __name__ == "__main__":
    # Smoke test: causal self-attention over one random batch.
    bsz, seq_len, d_ff = 4, 1024, 512
    x = torch.randn(bsz, seq_len, d_ff)
    attention = MultiheadAttention(d_model=512, heads=8)
    # Lower-triangular ones: entries above the diagonal are 0 and therefore
    # masked out by the attention module.
    mask = torch.ones(seq_len, seq_len).tril()
    out = attention(q=x, k=x, v=x, mask=mask)
    print(out.shape)
# 手撕 softmax (Hand-written softmax)
def softmax(x: torch.Tensor, dim: int = -1):
    """Return the softmax of ``x`` along ``dim``.

    Args:
        x: Input tensor of logits.
        dim: Axis along which the values are normalized to sum to 1.
            Defaults to the last axis.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Subtracting the per-slice maximum leaves the result unchanged but keeps
    # every exponent <= 0, so exp() cannot overflow.
    shifted = x - x.max(dim=dim, keepdim=True).values
    exp_x = torch.exp(shifted)
    return exp_x / exp_x.sum(dim=dim, keepdim=True)
# 手撕 cross_entropy (Hand-written cross-entropy)
def cross_entropy(x: torch.Tensor, y: torch.Tensor):
    """Return the mean cross-entropy between targets ``y`` and probabilities ``x``.

    Computes ``-sum(y * log(x))`` over dimension 1 and averages the result
    over the remaining entries.

    Args:
        x: Predicted probabilities (not logits), with classes on dimension 1.
        y: Target distribution with the same shape as ``x``, e.g. one-hot.

    Returns:
        Scalar tensor holding the averaged loss.
    """
    # xlogy(y, x) equals y * log(x) but is defined as 0 where y == 0, so a
    # zero predicted probability for a class with zero target weight
    # contributes nothing instead of turning the loss into NaN (0 * -inf).
    return -torch.xlogy(y, x).sum(dim=1).mean()
# 使用 SGD 实现平方根 (Square root via gradient descent)
import numpy as np
def sqrt_sgd(x, lr, epochs, tolerance, target=None):
    """Approximate a square root by gradient descent on ``(x**2 - target)**2``.

    Args:
        x: Initial guess.
        lr: Step size of each gradient update.
        epochs: Maximum number of updates.
        tolerance: Stop once the loss measured before an update drops below
            this value.
        target: Number whose square root is sought. When omitted, the
            module-level variable ``a`` is used, which keeps existing
            four-argument callers working.

    Returns:
        The estimate after the last update performed.
    """
    if target is None:
        # Legacy behaviour: the radicand used to be read from a global.
        target = a
    for i in range(epochs):
        residual = x ** 2 - target
        loss = residual ** 2
        # d/dx (x^2 - target)^2 = 4 * x * (x^2 - target)
        grad = 4 * x * residual
        # Rebind rather than update in place so an array passed by the
        # caller is never mutated.
        x = x - lr * grad
        if i % 100 == 0:
            # ``loss`` is the value before this step; ``x`` is already updated.
            print(f"epoch {i}: x = {x}, loss = {loss}")
        if loss < tolerance:
            break
    return x
if __name__ == "__main__":
    # Demo: estimate sqrt(a) starting from a random guess in [0, 1).
    a = 25
    # The update x <- x - lr * 4x(x^2 - a) has slope 1 - 8 * a * lr at the
    # root. With a = 25 the previous lr = 0.01 made that slope exactly -1,
    # so the iterate kept oscillating around 5 and never met the tolerance.
    # lr = 0.001 gives a slope of 0.8 and converges in well under 1000 steps.
    lr = 0.001
    epochs = 1000
    tolerance = 1e-7
    x = np.random.rand()
    sqrt_value = sqrt_sgd(x, lr, epochs, tolerance)
    print(f"The square root of {a} is approximately {sqrt_value}")
# 神经网络实现 k 分类 (Neural network for k-class classification)
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
# Hyper-parameters shared by the data pipeline and the model.
batch_size = 64
input_size = 784  # 28 * 28 pixels per flattened image
hidden_size = 512
num_classes = 10

# ToTensor is followed by Normalize, which computes (x - mean) / std per
# channel. 0.1307 and 0.3081 are the commonly quoted mean and standard
# deviation of the MNIST training pixels; they were previously passed to the
# wrong keywords.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits are downloaded into ./data on first use.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Shuffle only the training split; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
class MyNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=784, hidden_size=512, num_classes=10) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Flatten ``x`` to (N, input_size) and return logits of shape (N, num_classes)."""
        # Take the flatten width from the first layer so that a non-default
        # ``input_size`` works too (previously hard-coded to 28 * 28).
        x = x.view(-1, self.l1.in_features)
        # The layers are registered as l1/l2 in __init__; the old code called
        # non-existent fc1/fc2 attributes and raised AttributeError.
        x = self.l1(x)
        x = self.relu(x)
        return self.l2(x)
# Network, loss and optimizer used by train() and evaluate() below.
model = MyNet(input_size, hidden_size, num_classes)
# CrossEntropyLoss takes raw logits, which is what the model outputs.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Uses the module-level ``model``, ``criterion`` and ``optimizer``. Every
    300 batches the mean loss over those batches is printed.

    Args:
        epoch: Zero-based epoch index; only used in the progress message.
    """
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, target) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            # "%.3f" prints three decimals; the old "%3.f" rounded the loss
            # to a whole number.
            print("[%5d, %5d] loss: %.3f" % (epoch + 1, batch_idx + 1, running_loss / 300))
            # Assignment, not comparison: start a fresh 300-batch window.
            running_loss = 0.0
def evaluate():
    """Print the accuracy of the module-level ``model`` on ``test_loader``.

    The model is put into evaluation mode for the duration of the pass and
    its previous training/eval mode is restored afterwards.
    """
    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                outputs = model(images)
                # Index of the largest logit is the predicted class. Gradients
                # are already disabled, so the legacy ``.data`` is not needed.
                _, predicted = torch.max(outputs, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        # Avoid dividing by zero when the loader yields no samples.
        print("Accuracy on test set: n/a (no samples)")
        return
    print("Accuracy on test set: %d%%" % (100 * correct / total))
if __name__ == "__main__":
    # Train for a fixed number of epochs, reporting test accuracy after each.
    num_epochs = 10
    for epoch in range(num_epochs):
        train(epoch)
        evaluate()