CONTINUAL LOCAL TRAINING FOR BETTER INITIALIZATION OF FEDERATED MODELS 论文解读+代码分析

957 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。 论文地址点这里

一. 介绍

本篇文章的介绍和之前的持续学习、联邦学习类似,就不赘述。 作者提出的为局部持续训练的联邦学习来缓解权值的发散,并将不同的局部客户端的知识不断整合到全局模型中,保证了全局模型具有更好的泛化能力。全局模型参数的重要性权重在中央服务器上的一个小的代理数据集上进行评估,然后用于约束本地训练。通过这种方式,联邦模型能够在保持其原始性能的同时学习客户端的知识。

二. 方法

FedCL通过参数正则化持续训练处理全球模型中发生变化的重要参数。FedCL通过在服务器估算代理数据集的重要性权重,然后将其分配给客户端。

2.1 估算代理数据的重要性矩阵

根据EWC的思想,我们通过在服务端上的代理数据集(xk,yk)Dproxy(x_k,y_k)\in \mathcal{D}_{proxy},来计算参数的重要程度:

wk=L(θg(xk),yk)θg(1)w^k=\parallel\frac{\partial\mathcal{L}(\theta_g(x_k),y_k)}{\partial\theta_g}\parallel \tag 1

对所有的代理数据集进行一次计算后,求平均即可:

Ω=1Dproxy(xk,yk)Dproxywk(2)\Omega = \frac{1}{|\mathcal{D}_{proxy}|}\sum_{(x_k,y_k)\in \mathcal{D}_{proxy}}w_k \tag 2

2.2 在客户端上进行持续训练

在服务端上进行计算后的,我们在客户端上进行学习:

Lk(θk)=Llocal(θk)+λi,jΩi,j(θk,i,jθg,i,j)2(3)\mathcal{L}_k(\theta_k)=\mathcal{L}_{local}(\theta_k)+\lambda\sum_{i,j}\Omega_{i,j}(\theta_{k,i,j}-\theta_{g,i,j})^2 \tag 3

三. 代码介绍

本文是基于正则化的持续学习,唯一不同的是在服务端需要部署一部分数据和模型,来评估当前的权重因子。

def train(self):
    for rnd in range(self.rounds):
    	## 选择客户端
        np.random.shuffle(self.nets_pool)
        pool = mp.Pool(self.num_per_rnd)
        self.q = mp.Manager().Queue()
        dict_new = self.global_agent.model.state_dict()
        ## 在对应的轮数进行因子评估
        if self.estimate_weights_in_center and rnd % self.interval == 0:
            w_d = self.global_agent.estimate_weights(self.policy)
        else:
            w_d = None
        ## 加载训练
        for net in self.nets_pool[:self.num_per_rnd]:
            net.model.load_state_dict(dict_new)
            net.set_lr(self.global_agent.lr)
            pool.apply_async(train_local_mp, (net, self.local_epochs, rnd, self.q, self.policy, w_d))
        pool.close()
        pool.join()
        self.update_global(rnd)

评估方式和EWC等方式一致,如下:

def ewc(self, train_loader=None):
    if train_loader is None:
        train_loader = self.train_loader
    tmp_weights = dict()
    for k, p in self.model.named_parameters():
        tmp_weights[k] = torch.zeros_like(p)
    self.model.eval()
    num_examples = 0
    for image_batch, label_batch in BackgroundGenerator(train_loader):
        image_batch, label_batch = image_batch.cuda(), label_batch.cuda()
        num_examples += image_batch.size(0)
        # compute output
        output, _ = self.model(image_batch)
        loss = self.criterion(output, label_batch)
        # compute gradient
        self.optimizer.zero_grad()
        loss.backward()
        for k, p in self.model.named_parameters():
            tmp_weights[k].add_(p.grad.detach() ** 2)
    for k, v in tmp_weights.items():
        tmp_weights[k] = torch.sum(v).div(num_examples)
    return tmp_weights