方法一:在线安装
Tansorflow_federated教程给的官方的安装方法。执行以下代码,会自动下载mnist数据集,并且在头一次执行下载完之后,以后执行不会重复下载。
import tensorflow_federated as tff
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
方法二:修改tff包函数
由于第一种方法非常耗时(运行之后就一直出于下载状态不动),因此就琢磨出了第二种方法,查看tff.simulation.datasets.emnist.load_data()函数的源代码并修改下载方式。首先观察源代码:
def load_data(only_digits=True, cache_dir=None):
if only_digits:
fileprefix = 'fed_emnist_digitsonly'
sha256 = '55333deb8546765427c385710ca5e7301e16f4ed8b60c1dc5ae224b42bd5b14b'
else:
fileprefix = 'fed_emnist'
sha256 = 'fe1ed5a502cea3a952eb105920bff8cffb32836b5173cb18a57a32c3606f3ea0'
filename = fileprefix + '.tar.bz2'
path = tf.keras.utils.get_file(
filename,
origin='https://storage.googleapis.com/tff-datasets-public/' + filename,
file_hash=sha256,
hash_algorithm='sha256',
extract=True,
archive_format='tar',
cache_dir=cache_dir)
dir_path = os.path.dirname(path)
train_client_data = hdf5_client_data.HDF5ClientData(
os.path.join(path, fileprefix + '_train.h5'))
test_client_data = hdf5_client_data.HDF5ClientData(
os.path.join(path, fileprefix + '_test.h5'))
return train_client_data, test_client_data
可以看出,返回的数据集接口train_client_data和test_client_data来源于path路径获取的,而path路径是从storage.googleapis.com/tff-dataset…
路径下载的fed_emnist_digitsonly.tar.bz2,然后解压成fed_emnist_digitsonly_train.h5和fed_emnist_digitsonly_test.h5返回的。
因此根据下载路径下载好这两个数据集文件(点击下载:fed_emnist_digitsonly.tar.bz2和fed_emnist.tar.bz2),解压成4个h5文件,放在一个文件夹内,再将此文件夹的路径赋值到path参数上。具体代码如下,记得path是自己指定的路径,并且要把相关内容注释掉,不然还是会线上下载。
def load_data(only_digits=True, cache_dir=None):
if only_digits:
fileprefix = 'fed_emnist_digitsonly'
sha256 = '55333deb8546765427c385710ca5e7301e16f4ed8b60c1dc5ae224b42bd5b14b'
else:
fileprefix = 'fed_emnist'
sha256 = 'fe1ed5a502cea3a952eb105920bff8cffb32836b5173cb18a57a32c3606f3ea0'
'''
filename = fileprefix + '.tar.bz2'
path = tf.keras.utils.get_file(
filename,
origin='https://storage.googleapis.com/tff-datasets-public/' + filename,
file_hash=sha256,
hash_algorithm='sha256',
extract=True,
archive_format='tar',
cache_dir=cache_dir)
dir_path = os.path.dirname(path)
'''
train_client_data = hdf5_client_data.HDF5ClientData(
os.path.join("C:/Users/new-f/PycharmProjects/pythonProject/data", fileprefix + '_train.h5')) # 修改成自己的路径
test_client_data = hdf5_client_data.HDF5ClientData(
os.path.join("C:/Users/new-f/PycharmProjects/pythonProject/data", fileprefix + '_test.h5'))
return train_client_data, test_client_data