Abstract
现代移动设备含有大量的可以用来训练模型的数据,(通过使用这些数据来训练模型)可以极大的改善用户的体验。例如,语言模型可以增强语言识别和文本输入,图像模型可以自动选择好的图片。然而,这些丰富的数据通常对隐私很敏感、或者数据巨大,或两者兼有,这可能会导致数据无法导入中心节点,无法用传统的集中式机器学习途径去训练。我们提出了让训练的数据分散在移动设备上,通过聚合本地的计算来进行更新,去训练一个共享的模型。我们把这种去中心化的方法叫做联邦学习。
基于可迭代的模型聚合,我们提出了一个实用的深度网络联邦学习模型,并且还进行了广泛的实验评估,考虑了五个不同的模型结构和四个数据集。这些实验证明了这种方法对于无偏且非独立同分布(典型特征)的数据是鲁棒的,通信的代价是最重要的限制,我们展示了,和同步的随机梯度下降相比,我们的通信代价减少了10-100倍。
Introduction
手机和平板电脑成为越来越多人的首要计算设备。因为经常被携带着,这些设备上拥有强大的传感器(包括相机、麦克风、GPS),加上他们经常被携带,意味着它们可以获得大量数据,其中有很多是隐私数据。在这些数据上训练的模型承诺通过支持更智能的应用程序来大大提高可用性。但是这些数据的敏感性意味着把它们集中起来训练模型是有风险且责任重大的。
我们研究了一种学习技术,允许用户从共享的模型(从大量数据中训练的)中受益,而不需要将数据集中存储。我们将我们的方法称为联邦学习,因为学习任务是由中央服务器协调的参与设备(我们称之为客户端 client)的松散联盟解决的。每个客户端都有一个从不上传到服务器的本地训练数据集。相反,每个客户端都会计算对服务器维护的当前全局模型的更新,并且仅传递此更新。这是对2012年白宫关于消费者数据隐私的报告提出的集中收集或数据最小化原则的直接应用。由于这些更新专门用于改进当前模型,因此在应用这些更新后没有理由存储它们。
这种方法的一个主要优点是将模型训练与直接访问原始训练数据的需求解耦。显然,仍然需要对服务器协调处理给予信任。对于训练目标的应用可以具体到每个拥有数据的设备上,联合学习可以通过将攻击面限制为仅设备,而不是设备和云来显着降低隐私和安全风险。
我们的主要贡献:
- 将移动设备的分散数据进行训练作为一个重要的研究方向;
- 选择可以应用于此设置的简单实用的算法;
- 对所提出方法的广泛实证评估。
具体的,我们提出了 FedAvg 算法,它将每个客户端上的随机梯度下降 (SGD) 与执行模型平均的服务器相结合。在这个算法上我们执行了大量的实验,证明了其对于非平衡和非独立同分布数据的鲁棒性,并且可以将训练分散数据上的深度网络所需的通信轮次减少几个数量级。
联邦学习 联邦学习的理想问题具有如下几个特点:
- 与在数据中心通常可用的代理数据上进行训练相比,对来自移动设备的真实世界数据的训练具有明显的优势。
- 这些数据对隐私敏感或数据量大(与模型的大小相比较),因此最好不要纯粹出于模型训练的目的将其记录到数据中心(为集中收集原则服务)
- 对于监督学习,数据上的标签可以从用户交互中自然推断出来。
许多支持移动设备智能行为的模型都符合上述标准。举两个例子,考虑图像识别,预测哪一张图片在未来是最可能被多次观看或分享的;和语言模型,可用于通过改进解码、下一个单词预测甚至预测整个回复来改进触摸屏键盘上的语音识别和文本输入。这两项任务的潜在训练数据(用户拍摄的所有照片以及他们在移动键盘上键入的所有内容,包括密码,URL,消息等)都可能对隐私敏感。
绘制这些示例的分布也可能与易于获得的代理数据集有很大不同:聊天和短信中的语言使用通常与标准语言语料库有很大不同,比如维基百科和其他网页文件;用户自己拍摄的照片也和典型的相册不同。
最后,这些数据的标签都是直接可用的:输入的文本是自标注的,用于学习语言模型,照片标签可以通过用户与其照片应用程序的交互来定义(哪些照片被删除,共享或查看)。
这两个任务都非常适合学习神经网络。对于图像分类的前馈神经网络,特别是卷积神经网络,都有很好的表现。对于语言模型的循环神经网络RNN,特别是LSTM也有很好的表现。
隐私 相比较于其余的把数据放在一起训练的模型,联邦学习有着不同的隐私优势。即使是用匿名化的数据,也可能通过与其他数据的联合使用而造成风险。相比之下,为联邦学习传输的信息是改进特定模型所需的最小更新(当然,隐私优势的强度取决于更新的内容)。更新本身可以(也应该)是短暂的。相对于原始数据,其绝不会包含更多的信息,而且一般包含的更少。此外,聚合算法不需要更新的来源,因此可以在不通过混合网络(如Tor [7])或通过受信任的第三方识别元数据的情况下传输更新。在本文末尾,我们简要讨论了将联合学习与安全的多方计算和差分隐私相结合的可能性。
联邦优化 我们将联邦学习中隐含的优化问题称为联邦优化,与分布式优化的联系(和对比)。和典型的分布式优化相比,联邦优化有如下几个关键性质:
- 非独立同分布。在客户端上的训练数据是基于特定用户的移动设备,因此,任何用户的局部数据都不会代表总体的分布。
- 不平衡性。同样,一些用户会比其他用户使用的服务或者app可能比另一些用户更多,所以会产生更多的本地训练数据。
- 广泛分布。我们预计参与优化的客户端数量将远远大于每个客户端的平均示例数。
- 通信限制。移动设备经常处于离线状态,或者连接速度慢或成本高昂。
在我们的工作中,我们的着重点在于非独立同分布和非平衡的优化,以及通信限制的关键性质。
部署的联合优化系统还必须解决无数实际问题:客户机的数据会随着增加和删除而改变;以复杂方式与本地数据分布相关的客户端可用性(例如,说美式英语的人接通电话的时间可能与说英式英语的人不同);以及从不响应或发送攻击性更新的客户端。
这些问题超出了当前工作的范围。相反,我们使用适合实验的可控环境,但仍需要解决客户端可用性以及不平衡和非IID数据的关键问题。
我们假设一个在多轮通信中进行的同步更新方案。
有一组固定的 K 个客户端,每个客户端都有一个固定的本地数据集。在每轮开始时,随机选择客户端的比例C,服务器将当前全局算法状态发送到每个客户端(例如,当前模型参数)。我们只选择一小部分客户来提高效率,因为我们的实验表明,在超过某个点的情况下增加更多客户端会减少回报。然后,每个选定的客户端根据全局状态及其本地数据集执行本地计算,并向服务器发送更新。然后,服务器将这些更新应用于其全局状态,并重复该过程。
虽然我们专注于非凸神经网络目标,但我们提出的算法适用于以下形式的任何有限和目标:
对于机器学习的问题,典型的取,即预测的损失和模型的参数有关。假设有个客户端,客户端的数据是分段的,是客户端上的数据点的索引集,。因此,我们可以把(1)式重写为:
如果分布 是通过随机在客户端上均匀分布训练示例而形成的,那么就会有,其中期望值超过分配给固定客户端的示例集。这是在独立同分布下的分布优化算法;我们参考的这个设置在非独立同分布下并不适用(即能是的一个较差的估计)。
在数据集中的优化里,通信的代价是非常小的,计算的代价占据了大部分,最近的大部分重点是使用GPU来降低这些成本。相反,联邦优化中,通讯占据了主导地位——我们通常会受到1 MB / s或更小的上传带宽的限制。此外,客户通常只有在充电、接通电源和不计流量的 Wi-Fi 连接时才会自愿参与优化。而且,我们预计每一个客户端在一天只会参与少量的更新轮次。另一方面,由于任何单个设备上的数据集与总数据集相比都很小,而现代智能手机拥有相对较快的处理器(包括gpu),,而且现代的智能手机有非常快速的处理器(包括GPU),对于大部分设备,相对于通信代价,计算变得非常容易。因此,我们的目标是利用额外的计算,为了减少训练所需的通信的轮数。
有两种可以增加计算量的方法:
- 提高并行度,我们在每个通信轮次之间使用更多独立工作的客户端。
- 增加了每个客户端的计算量,每个客户端在每轮通信之间执行更复杂的计算,而不是执行像梯度计算这样的简单计算。
我们研究了这两种方法,但是我们实现的加速主要是由于在每个客户端上添加了更多的计算,使用了客户端的最低并行度级别。
相关工作
McDonald等人已经研究了通过对本地训练模型进行迭代平均来进行分布式训练感知机(2010);
Povey等人研究了用于语音识别DNN(2015);
张等研究了一种具有 “软”平均的异步方法(2015)。
这些工作并没有考虑到数据是非平衡和非独立同分布的,这些性质恰好是联邦学习的设定。我们将这种算法风格适应联邦学习设定,并执行适当的实证评估,它提出的问题与数据中心设置中的相关问题不同,并且需要不同的方法。
Neverova等人使用与我们类似的动机[29]还讨论了在设备上保留敏感用户数据的优势。
Shokri和Shmatikov[35]的工作在几个方面相关:他们专注于训练深度网络,强调隐私的重要性,并通过在每轮通信期间仅共享参数的子集来解决通信成本问题;然而,他们也不考虑不平衡和非IID数据,经验评估有限。
在凸设置中,分布式优化和估计也有大量的研究,一些算法特别注重通信效率。
除了假设凸性之外,这项现有工作通常还要求客户端数量远小于每个客户端的示例数量,数据以独立同分布的方式分布在客户端之间,并且每个节点都有一个独特的数量的数据点。所有的这些设定都和联邦学习设定相反。
SGD的异步分布式形式也被应用于训练神经网络,比如Dean等人,但是这种方法需要在联邦学习中进行大量的更新。
分布式共识算法(例如,[41])放宽了IID假设,但仍然不适合在很多客户端上进行通信约束优化。
我们考虑的(参数化)算法系列的一个点是简单的一次性平均,其中每个客户端求解的模型最小化(可能正则化)其本地数据上的损失,并将这些模型平均以生成最终的全局模型。这种方法已经在具有独立同分布数据的凸的情况下进行了广泛的研究,并且已知在最坏的情况下,生成的全局模型并不比在单个客户端上训练模型更好。
The Federated Averaging Algorithm
最近深度学习的众多成功应用几乎完全依赖于随机梯度下降(SGD)的变体进行优化。事实上,许多改进可以理解为调整模型的结构(以及损耗函数),使其更容易通过简单的基于梯度的模型进行优化。因此,我们的联邦学习构建算法也从随机梯度优化开始。
SGD可以简单地应用于联邦优化问题,其中每轮通信都进行一次批量梯度计算(比如在随机选择的客户端上)。这种方法在计算上是有效的,但需要大量的训练轮才能产生良好的模型(例如,即使使用高级方法,如批量标准化,Ioffe和Szegedy [21]训练的MNIST在60大小的小批次上进行50000步)。我们在CIFAR-10实验中考虑了这个基线。
在联邦设定中,挂钟时间的成本很少,可以吸引更多客户端,因此对于我们的基线,我们使用大批量同步SGD。Chen等人[8]的实验表明,这种方法在数据中心环境中是最先进的,它优于异步方法。为了在联合设置中应用此方法,我们在每轮中选择一个比例为的客户端,并计算这些客户端持有的所有数据的损失梯度。
因此, 控制全局批大小, 对应于全批处理(非随机)梯度下降。我们将此基线算法称为 FederatedSGD(或 FedSGD)。
FedSGD的典型实现是 ,固定一个学习率 ,并且对于每一个客户端 计算 (是在当前的模型 下本地数据上的平均梯度),服务端聚合这些梯度,并且应用其来更新 , 其中 。一个等价的更新是,对于任意 , ^ ^,
和。也就是说,每个客户端在本地使用其本地数据对当前模型进行一步梯度下降,然后服务器对结果模型进行加权平均。
以这种方式编写算法后,我们可以通过在平均步骤之前迭代本地更新,来为每个客户端添加更多计算。我们把这种方法叫做 FederatedAveraging (or FedAvg).
计算量由三个关键的参数来控制:
- C,每一轮中进行计算的客户端比例(C是小数)
- E,每轮训练中每个客户端在本地数据集上进行训练更新的次数
- B,用于客户端更新的本地小批量的大小
记 为完整的本地数据集被视为单个小批量处理。则当 , 时,FedAvg等价于FedSGD。对于有个本地实例的客户端,每轮训练中本地更新的次数为 ; 完整的伪代码在算法1中给定:
对于一般的非凸目标函数,在参数空间中平均模型可能导致一个比较坏的结果。
图1:通过使用对50个均匀间隔的值平均两个模型的参数和生成的模型的完整 MNIST 训练集的损失。模型 和 在不同的小数据集上使用 SGD 进行训练。对于左边的曲线, 和 使用不同的随机种子初始化,对于右边的曲线,使用了共同种子。注意轴的刻度。水平线给出了 或 实现的最佳损失(它们非常接近,对应于 和 处的垂直线)。通过共享初始化,平均模型可以显着减少整个训练集的损失(比任一父模型的损失都要好得多)。
按照Goodfellow等人(2015年)的方法 ,当我们平均两个从不同初始条件训练的MNIST数字识别模型时,我们恰好看到了这种不良结果(图1,左)。对于此图,父模型 和 分别在来自 MNIST 训练集中的 600 个示例的非重叠 IID 样本上进行训练。训练是通过 SGD 进行的,固定学习率为 0.1,用于对大小为 50 的小批量进行 240 次更新(或 E = 20 次遍历大小为 600 的小数据集)。这大约是模型开始过度拟合其本地数据集的训练量。
最近的工作表明,在实践中,参数充分的神经网络效果良好,更不容易出现不好的局部极小值(Dauphin等,2014; Goodfellow等,2015; Choromanska等,2015)。
但是当我们开始从相同的随机初始化条件开始,然后将其独立的用在不同的数据子集上训练,我们发现简单的参数平均也有很好的表现。(图1,右):两个模型的平均 在MNIST数据集中有很好的损失下降,这个下降比在单独的小数据集上独立训练要好。虽然图 1 从随机初始化开始,但请注意共享起始模型 用于每一轮 FedAvg,因此同样的直觉适用。
Experimental Results
图像处理和语言模型的良好模型都可以极大的提高移动设备的可用性。对于每一个模型,我们可以选择一个代理的数据集,以便于我们可以更好的研究FedAvg的超参数。虽然每次训练的规模都比较小,我们为这些实验训练了2000个单独模型。然后,我们介绍了CIFAR-10图像分类的分类基准。最后,未来证明FedAvg在真实问题和自然数据划分中的有效性,我们评估了大型的语言模型任务。
我们最初的训练包含了两个数据集上的三个模型。
MNIST数据集识别:
- 多层感知机:2个隐藏层,每个隐藏层有200个节点,激活函数用Relu,总共 199,210 个参数,我们称之为 MNIST 2NN
- CNN:5x5的卷积层(第一层有32个通道,第二层64个,每一层有2x2的最大池化。一个512个节点的全连接层,总共1663370个参数。
为了学习联邦优化,我们也需要探索数据怎样分布在客户端。我们学习了两种在客户端划分MNIAST数据集的方法:
- IID:数据随机打乱,划分到100个客户端,每个客户端有600张图片
- 非IID:按照数字标签(0-9)对数据进行排序,然后划分为200个大小为300的”片段“,然后给100个clients分配2个”片段“。(每人600条数据,且最多有两个数字标签的数据)
这是数据的病态非 IID 分区,因为大多数客户端只有两位数的示例。因此,通过这种方法就可以探索我们的算法在极度非独立同分布数据下的表现(每一种划分都是平衡的)。
探索用户批量与通信轮数(精度提升快慢或损失函数收敛速度)的关系
(2NN达到目标精度97%(E=1),CNN达到目标精度99%(E=5))
(基线:C=0,表示单个用户参与优化)表格1:注意 C = 0.0 对应于每轮一个客户;由于我们为 MNIST 数据使用 100 个客户端,因此这些行对应于 1、10、20、50 和 100 个客户端。每个表条目都给出了实现 2NN 测试集准确度 97% 和 CNN 99% 所需的通信轮数,以及相对于 C = 0 基线的加速。五次大批量运行没有在允许的时间内达到目标精度。
结果表明:
当 时(用户所有数据一次性地迭代优化),随着通信轮数增加,精度提升对参与用户数量不敏感;
当时(小批量),随着通信轮数增加,精度提升对参与用户数量十分敏感(特别是Non-IID数据);
当 时(每轮通信10个用户参与优化),可以达到计算效率和收敛速度的平衡(min(C*通信轮数),即最小化用户总计算量)。
莎士比亚全集 数据集
对于语言模型,我们在莎士比亚全集数据集上构建了模型。用至少两行为每个剧中的每个角色构造一个客户数据集。这种划分数据集的方法产生了1146个客户端。对于每个客户端,我们划分了80%的行作为训练,20%的行作为测试。
训练集中有3,564,579个字符,在测试集中具有870,014个字符。这个数据是相当不平衡的,很多角色只有几行数据,一些角色则有很多的数据。然后,观察到测试集不是随机样本,而是每个剧本按时间顺序将行分为训练集与测试集。此外,使用相同的训练/测试拆分,还构造了平衡的IID版本的数据集,也有1146个客户端。
在这些数据上,我们训练了一个堆叠字符级 LSTM 语言模型,该模型在读取一行中的每个字符后,预测下一个字符。该模型将一系列字符作为输入,并将每个字符嵌入到学习的 8 维空间中。然后通过 2 个 LSTM 层处理嵌入的字符,每个层有 256 个节点。最后,第二个 LSTM 层的输出被发送到每个字符一个节点的 softmax 输出层。完整模型有 866,578 个参数,我们使用 80 个字符的展开长度进行训练。
表 2:FedAvg 与 FedSGD 达到目标精度所需的通信轮数(第一行, 和 )。 列给出了,即每轮的预期更新次数。
SGD对于超参数的调整很敏感。报告在这里的结果是基于广泛的网格搜索。我们检查以确保最佳学习率在我们的网格中间,并且最佳学习率之间没有显着差异。除非另有说明,我们绘制了为每个 x 轴值单独选择的最佳执行率的指标。我们发现最佳学习率不会随着其他参数的变化而变化太大。
Increasing parallelism 提高并行性(fix E,analyze C & B)
我们首先实验了客户端的比例C,这控制了多个客户端并行计算的数量。表格1展示了改变C对于每一个MNIST模型的影响。通信的次数影响了测试集的准确率。
为了计算这一点,我们为参数设置组合构建一个学习曲线,如上所述优化 ,然后通过采用在之前所有回合中实现的测试集精度的最佳值来使每个曲线单调改进。然后,我们使用形成曲线的离散点之间的线性插值来计算曲线与目标精度相交的轮数。
最好参考图2来理解这一点,其中灰线显示了目标。
图2
当 时(MNIST中每个客户端的600个数据每一轮作为一个批量),增加客户端的比例只有很小的优势。
当 时,当 时有显著的提升,特别是对于非独立同分布的情况。
基于这些结果,对于大多数实验,我们将 固定下来,这对于计算有效和收敛速度上面有很好的平衡。
和 与 的通信次数比较,表格1中显示了显着的加速(之后再探究)
MNIST CNN 的测试集精度与通信轮数。灰色线显示表 2 中使用的目标精度。附录 A 中的图 7 给出了 2NN 的绘图。
Increasing computation per client 每个客户端提高计算量(固定 C=0.1),对于每轮通信次数,每个客户端需要提高计算量,减小 ,或者提高 , 或者都做。
图 2 表明,每轮添加更多本地 SGD 更新可以显著降低通信成本,表格2则量化了加速的效果。每一轮中,每个客户端期望的更新次数是,其中期望是随机客户端 。通过统计,我们对表格2进行了排序。我们可以看见,通过改变 和 都可以增加 的值。
只要 足够大,可以充分利用客户端硬件上可用的并行性,降低它的计算时间基本上没有成本,因此在实践中,这应该是第一个调整的参数。
对于独立同分布的MNIST数据,在每一个客户端使用更多的计算量减少了达到所需的准确率的通信次数。(CNN提高了35倍,2NN提高了46倍)对于非独立同分布的数据,加速的效果更小,基本上是2.8倍到3.7倍之间。
令人印象深刻的是,当我们简单地平均在完全不同的数字对上训练的模型的参数时,平均提供了一切优势(与实际上发散相比)。因此,我们认为这是这种方法稳健性的有力证据。
莎士比亚的不平衡和非IID分布(按剧中的角色)更能代表我们对现实世界应用所期望的数据分布类型。令人鼓舞的是,对于这个问题,在非IID和不平衡数据上学习实际上要容易得多(平衡IID数据加速95×而非平衡IID数据加速13×);我们推测这主要是由于某些角色具有相对较大的本地数据集,这使得增加的本地训练特别有价值。
对于所有三个模型类,FedAvg 收敛到比基线 FedSGD 收敛到更高的测试集精度。即使线超出绘制的范围,这种趋势也会继续。例如,对于CNN,, FedSGD模型在1200轮后最终达到99.22%的准确率(并且在6000轮后没有进一步提高),而, FedAvg模型在300轮后达到99.44%的准确率。我们推测,除了降低通信成本外,模型平均还产生了类似于Dropout所获得的正则化收益[36]。
我们主要关注的是泛化性能,但 FedAvg 在优化训练损失方面也很有效,甚至超出了测试集精度稳定点的范围。我们观察到所有三个模型类的相似行为,并在附录A的图6中给出了MNIST CNN的图。
Can we over-optimize on the client datasets? 当前模型参数仅影响通过初始化在每个ClientUpdate中执行的优化。因此,正如,至少对于凸问题,初始条件最终应该是无关紧要的,并且无论初始化如何,都将达到全局最小值。即使对于非凸问题,人们也可以推测,只要初始化位于同一盆地中,该算法就会收敛到相同的局部最小值。也就是说,我们预计,虽然一轮平均可能会产生一个合理的模型,但额外的沟通(和平均)回合不会产生进一步的改进。
图 3 显示了在初始训练期间比较大的对莎士比亚 LSTM 问题的影响。事实上,对于非常大量的局部训练epoch,FedAvg可以趋于稳定或发散。
图3:对于固定学习率的莎士比亚LSTM,在平均步骤之间固定和的许多局部epoch(大)的训练效果。
这一结果表明,对于某些模型,特别是在收敛的后期阶段,衰减每轮局部计算量(移动到更小的或更大的)可能是有用的,以同样的方式衰减学习率也是有用的。附录A中的图8给出了MNIST CNN的类似实验。有趣的是,对于这个模型,我们看到大值的收敛速度没有显著退化。然而,对于下面描述的大规模语言建模任务,我们看到E = 1比E = 5的性能略好(参见附录A中的图10)。
图4:CIFAR10实验的测试精度与通信。FedSGD每轮使用0.9934的学习率衰减;FedAvg使用,每轮学习率衰减。
CIFAR experiments CIFAR实验
我们还对CIFAR-10数据集[24]进一步验证FedAvg。数据集由 10 类 32x32 图像组成,其中 3 个RGB 通道。有 50,000 个训练示例,并且有10,000 个测试示例,我们将其划分为 100 个客户端,每个客户端包含 500 个训练和 100 个测试。由于这些数据没有自然的用户分区,我们考虑了平衡和IID设置。
模型架构取自TensorFlow教程[38],它由两个卷积层组成,然后是两个全连接层,然后是线性变换层以生成 logits,总共约 10的6次方个参数。请注意,最先进的方法已经实现了CIFAR的96.5%[19]的测试精度;然而,我们使用的标准模型足以满足我们的需求,因为我们的目标是评估我们的优化方法,而不是在这项任务上达到最佳的准确性。图像作为训练输入管道的一部分进行预处理,该管道包括将图像裁剪为24x24,随机左右翻转以及调整对比度,亮度和美白。
对于这些实验,我们考虑了一个额外的基线,即整个训练集的标准SGD训练(无用户分区),使用100大小的小批量。在 197,500 次小容量更新后,我们实现了 86% 的测试准确率(每个小容量更新都需要在联合设置中进行一轮通信)。
FedAvg在仅2,000轮通信后就实现了85%的类似测试精度。对于所有算法,除了初始学习速率之外,我们还调整了学习速率衰减参数。表 3 给出了基线 SGD、FedSGD 和 FedAvg 达到三个不同精度目标的通信轮次数,图 4 给出了 FedAvg 与 FedSGD 的学习速率曲线。
表3:在CIFAR10上达到目标测试集精度时,相对于基准SGD的轮数和加速。SGD使用了100个小批量。FedSGD和FedAvg使用, FedAvg使用。
图4:CIFAR10实验的测试精度与通信。FedSGD每轮使用0.9934的学习率衰减;FedAvg使用,每轮学习率衰减。
通过对SGD和FedAvg进行尺寸为的小批量的实验,我们还可以将精度视为此类小批量梯度计算的函数。我们预计SGD在这里做得更好,因为在每次小分量计算之后都会采取一个连续的步骤。但是,如附录中的图 9 所示,对于 和 的适度值,FedAvg 对每个小分量计算的进度相似。此外,我们看到标准 SGD 和 FedAvg(每轮只有一个客户端())都显示出显著的准确性振荡,而对更多客户端求平均值可以消除这种情况。
Large-scale LSTM experiments 大规模 LSTM 实验
我们在大规模下一个单词预测任务上运行了实验,以证明我们的方法在现实世界中的有效性。我们的训练数据集包含来自大型社交网络的 1000 万个公开帖子。我们按作者对帖子进行了分组,总共有超过500,000个客户。此数据集是用户移动设备上将存在的文本输入数据类型的实际代理。我们将每个客户端数据集限制为最多 5000 个单词,并在来自不同(非训练型)作者的 1e5 个帖子的测试集上报告准确性(在 10000 个可能性中,预测概率最高的数据比例为正确的下一个单词)。我们的模型是一个256节点的LSTM,词汇量为10,000个单词。每个单词的输入和输出嵌入的维度为192,并与模型共同训练;总共有4,950,544个参数。我们使用了10个单词的展开。
这些实验需要大量的计算资源,因此我们没有彻底探索超参数:所有运行都在每轮200个客户端上训练;FedAvg 使用 和 。我们探讨了FedAvg和基准FedSGD的各种学习率。图5显示了最佳学习速率的单调学习曲线。的FedSGD需要820通信次数才能达到10.5%的准确率,而η = 9.0的FedAvg仅在35个通信次数中就达到了10.5%的准确率(比FedSGD少23×)。
我们观察到FedAvg的测试精度差异较低,参见附录A中的图10。此图还包括 的结果,其表现略低于 。
图5:大规模语言模型词LSTM的单调学习曲线。
Conclusions and Future Work
我们的实验表明,联邦学习是可行的,因为FedAvg使用相对较少的通信次数训练高质量的模型,正如各种模型架构的结果所证明的那样:多层感知器,两个不同的卷积NN,一个双层字符LSTM和一个大规模的单词级LSTM。
虽然联邦学习提供了许多实用的隐私优势,但通过差分隐私提供更强的保证,安全的多方计算,或者它们的组合是未来工作的一个有趣的方向。请注意,这两类技术最自然地适用于FedAvg等同步算法。