Pytorch 梯度反转层及测试
参考文献: 梯度反转
import torch.nn as nn
from torch.autograd.function import Function
class Grl_func(Function):
def __init__(self):
super(Grl_func, self).__init__()
@staticmethod
def forward(ctx, x, lambda_):
ctx.save_for_backward(lambda_)
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
lambda_, = ctx.saved_tensors
grad_input = grad_output.clone()
return -lambda_ * grad_input, None
class GRL(nn.Module):
def __init__(self, lambda_=0.):
super(GRL, self).__init__()
self.lambda_ = torch.tensor(lambda_)
def set_lambda(self, lambda_):
self.lambda_ = torch.tensor(lambda_)
def forward(self, x):
return Grl_func.apply(x, self.lambda_)
# 首先建立一个全连接的子module,继承nn.Module
class Linear1(nn.Module):
def __init__(self):
super(Linear1, self).__init__() # 调用nn.Module构造函数
# 使用nn.Parameter来构造需要学习的参数
self.w = nn.Parameter(torch.tensor([[1., 2., 3.], [1., 1., 1.]]))
self.b = nn.Parameter(torch.tensor([1., 1., 1.]))
# 在forward中实现向前传播过程
def forward(self, x):
x = x.matmul(self.w) # 使用Tensor.matmul实现矩阵相乘
y = x + self.b.expand_as(x) # 使用Tensor.expand_as()来保证矩阵形状一致
return y
# 首先建立一个全连接的子module,继承nn.Module
class Linear2(nn.Module):
def __init__(self):
super(Linear2, self).__init__() # 调用nn.Module构造函数
# 使用nn.Parameter来构造需要学习的参数
self.w = nn.Parameter(torch.tensor([[1., 2., 3.], [1., 1., 1.], [1., 1., 1.]]))
self.b = nn.Parameter(torch.tensor([1., 1., 1.]))
# 在forward中实现向前传播过程
def forward(self, x):
x = x.matmul(self.w) # 使用Tensor.matmul实现矩阵相乘
y = x + self.b.expand_as(x) # 使用Tensor.expand_as()来保证矩阵形状一致
return y
# 实例化一个网络,并赋值全连接中的维数,最终输出二维代表了二分类
perception1 = Linear1()
perception2 = Linear2()
grl = GRL()
grl.set_lambda(1.0)
# 随机生成数据,注意这里的4代表了样本数为4,每个样本有两维
data = torch.tensor([[2., 1.], [1., 1.]])
output = perception1(data)
# output = grl(output) # 是有效的
output = perception2(output)
output = grl(output) # 是有效的
print(f'output:\n {output}\n')
output.sum().backward()
print(f'perception1.w.grad:\n {perception1.w.grad}\n')
print(f'perception1.b.grad:\n {perception1.b.grad}\n')
print(f'perception2.w.grad:\n {perception2.w.grad}\n')
print(f'perception2.b.grad:\n {perception2.b.grad}\n')