【Tensorflow】mnist数据集安装方法

880 阅读2分钟

方法一:在线安装

Tansorflow_federated教程给的官方的安装方法。执行以下代码,会自动下载mnist数据集,并且在头一次执行下载完之后,以后执行不会重复下载。

import tensorflow_federated as tff
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

方法二:修改tff包函数

由于第一种方法非常耗时(运行之后就一直出于下载状态不动),因此就琢磨出了第二种方法,查看tff.simulation.datasets.emnist.load_data()函数的源代码并修改下载方式。首先观察源代码:

def load_data(only_digits=True, cache_dir=None):
  if only_digits:
    fileprefix = 'fed_emnist_digitsonly'
    sha256 = '55333deb8546765427c385710ca5e7301e16f4ed8b60c1dc5ae224b42bd5b14b'
  else:
    fileprefix = 'fed_emnist'
    sha256 = 'fe1ed5a502cea3a952eb105920bff8cffb32836b5173cb18a57a32c3606f3ea0'
 
  filename = fileprefix + '.tar.bz2'
  path = tf.keras.utils.get_file(
      filename,
      origin='https://storage.googleapis.com/tff-datasets-public/' + filename,
      file_hash=sha256,
      hash_algorithm='sha256',
      extract=True,
      archive_format='tar',
      cache_dir=cache_dir)
  dir_path = os.path.dirname(path)
  
  train_client_data = hdf5_client_data.HDF5ClientData(
      os.path.join(path, fileprefix + '_train.h5'))
  test_client_data = hdf5_client_data.HDF5ClientData(
      os.path.join(path, fileprefix + '_test.h5'))

  return train_client_data, test_client_data

可以看出,返回的数据集接口train_client_datatest_client_data来源于path路径获取的,而path路径是从storage.googleapis.com/tff-dataset… 路径下载的fed_emnist_digitsonly.tar.bz2,然后解压成fed_emnist_digitsonly_train.h5和fed_emnist_digitsonly_test.h5返回的。 因此根据下载路径下载好这两个数据集文件(点击下载:fed_emnist_digitsonly.tar.bz2fed_emnist.tar.bz2),解压成4个h5文件,放在一个文件夹内,再将此文件夹的路径赋值到path参数上。具体代码如下,记得path是自己指定的路径,并且要把相关内容注释掉,不然还是会线上下载。

def load_data(only_digits=True, cache_dir=None):
  if only_digits:
    fileprefix = 'fed_emnist_digitsonly'
    sha256 = '55333deb8546765427c385710ca5e7301e16f4ed8b60c1dc5ae224b42bd5b14b'
  else:
    fileprefix = 'fed_emnist'
    sha256 = 'fe1ed5a502cea3a952eb105920bff8cffb32836b5173cb18a57a32c3606f3ea0'
  '''
  filename = fileprefix + '.tar.bz2'
  path = tf.keras.utils.get_file(
      filename,
      origin='https://storage.googleapis.com/tff-datasets-public/' + filename,
      file_hash=sha256,
      hash_algorithm='sha256',
      extract=True,
      archive_format='tar',
      cache_dir=cache_dir)
  
  dir_path = os.path.dirname(path)
  '''
  train_client_data = hdf5_client_data.HDF5ClientData(
      os.path.join("C:/Users/new-f/PycharmProjects/pythonProject/data", fileprefix + '_train.h5'))  # 修改成自己的路径
  test_client_data = hdf5_client_data.HDF5ClientData(
      os.path.join("C:/Users/new-f/PycharmProjects/pythonProject/data", fileprefix + '_test.h5'))

  return train_client_data, test_client_data

方法三:自己定义函数读取数据

参考:blog.csdn.net/qq_37337494…