如何在Tensorflow Fedared中加载时尚MNIST数据集?

2024-06-16 09:03:38 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在与Tensorflow federated合作一个项目。我已经设法使用TensorFlow联邦学习模拟提供的库来加载、训练和测试一些数据集

例如,我加载emnist数据集

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

它将load_data()返回的数据集作为tff.simulation.ClientData的实例。这是一个接口,允许我在客户端ID上迭代,并允许我为模拟选择数据子集

len(emnist_train.client_ids)

3383


emnist_train.element_type_structure


OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])


example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

我正在尝试用Keras加载fashion_mnist数据集,以执行一些联合操作:

fashion_train,fashion_test=tf.keras.datasets.fashion_mnist.load_data()

但是我得到了这个错误

AttributeError: 'tuple' object has no attribute 'element_spec'

因为Keras返回Numpy数组的元组,而不是像以前一样返回tff.simulation.ClientData:

def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=factory.retrieve_model(True),
        input_spec=fashion_test.element_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

iterative_process = tff.learning.build_federated_averaging_process(
    tff_model_fn, Parameters.server_adam_optimizer_fn, Parameters.client_adam_optimizer_fn)
server_state = iterative_process.initialize()

总之,

  1. 有没有办法从Keras tuple Numpy数组创建tff.simulation.ClientData的元组元素

  2. 我想到的另一个解决办法是使用 tff.simulation.HDF5ClientData并加载 手动以HDF5格式(train.h5, test.h5)获取tff.simulation.ClientData的适当文件,但我的问题是我找不到fashion_mnist HDF5文件格式的url,我指的是类似于训练和测试的url:

          fileprefix = 'fed_emnist_digitsonly'
          sha256 = '55333deb8546765427c385710ca5e7301e16f4ed8b60c1dc5ae224b42bd5b14b'
          filename = fileprefix + '.tar.bz2'
          path = tf.keras.utils.get_file(
              filename,
              origin='https://storage.googleapis.com/tff-datasets-public/' + filename,
              file_hash=sha256,
              hash_algorithm='sha256',
              extract=True,
              archive_format='tar',
              cache_dir=cache_dir)
    
          dir_path = os.path.dirname(path)
          train_client_data = hdf5_client_data.HDF5ClientData(
              os.path.join(dir_path, fileprefix + '_train.h5'))
          test_client_data = hdf5_client_data.HDF5ClientData(
              os.path.join(dir_path, fileprefix + '_test.h5'))
    
          return train_client_data, test_client_data
    

我的最终目标是使fashion_mnist数据集与TensorFlow联合学习一起工作


Tags: 数据pathtestclientdatamodeltfdir
1条回答
网友
1楼 · 发布于 2024-06-16 09:03:38

你在正确的轨道上。总而言之,^{}API返回的数据集是^{}对象。^{}返回的对象是tuple个numpy数组

因此,需要实现一个tff.simulation.ClientData来包装tf.keras.datasets.fashion_mnist.load_data返回的数据集。关于实现ClientData对象的一些以前的问题:

这确实需要回答一个重要的问题:时装MNIST数据应该如何划分为单个用户?数据集不包含可用于分区的功能。研究人员已经提出了一些综合划分数据的方法,例如,为每个参与者随机抽取一些标签,但这将对模型训练产生很大影响,有助于在这里投入一些思考

相关问题 更多 >