TensorFlow Federated:如何调整联邦数据集中的非IID性?

2024-05-16 02:29:11 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在TensorFlow Federated(TFF)中测试一些算法。在这方面,我想在具有不同“级别”的数据异构性(即非IID性)的相同联邦数据集上测试和比较它们

因此,我想知道是否有任何方法可以自动或半自动地控制和调整特定联邦数据集中的非IID“级别”,例如通过TFF API或传统TF API(可能在数据集UTIL内部)

更实际:例如,TFF提供的EMNIST联邦数据集有3383个客户端,每个客户端都有自己的手写字符。然而,这些本地数据集在本地示例的数量和表示的类(所有类或多或少都在本地表示)方面似乎相当平衡。 如果我想要一个联邦数据集(例如,从TFF的EMNIST数据集开始),即:

  • 从语法上讲是非IID的,例如,客户端只持有N个类中的一个类(总是指分类任务)。这就是{}{a1}的目的吗。如果是这样,我应该如何从联邦数据集(如TFF已经提供的数据集)使用它
  • 本地示例数量不平衡(例如,一个客户有10个示例,另一个客户有100个示例)
  • 两种可能性

我应该如何在TFF框架内准备具有这些特征的联邦数据集

我应该手工做所有的事情吗?或者你们中的一些人对自动化这个过程有什么建议吗

另外一个问题:在Hsu等人的论文"Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification"中,他们利用Dirichlet分布来合成不完全相同的客户群体,并使用浓度参数来控制客户之间的一致性。这似乎是一种调整方法,以产生具有不同异质性级别的数据集。任何关于如何在TFF框架内实现这一策略(或类似策略)的建议,或者仅仅是在TensorFlow(Python)中考虑一个简单的数据集(如EMNIST),都将非常有用

多谢各位


Tags: 数据方法框架api客户端示例数量客户
1条回答
网友
1楼 · 发布于 2024-05-16 02:29:11

对于联合学习模拟,在实验驱动程序中用Python设置客户机数据集以实现所需的分布是非常合理的。在某些高层,TFF处理建模数据位置(“类型系统中的放置”)和计算逻辑。重新混合/生成模拟数据集并不是该库的核心,尽管您已经发现了一些有用的库。在python中直接通过操作tf.data.Dataset然后将客户机数据集“推”到TFF计算中来实现这一点似乎很简单

非IID标签

是的,^{}就是为了这个目的

它采用tf.data.Dataset并基本上过滤掉所有与label_keydesired_label值不匹配的示例(假设数据集生成类似dict的结构)

对于EMNIST而言,要创建一个包含所有数据集的数据集(无论用户为何),可以通过以下方式实现:

train_data, _ = tff.simulation.datasets.emnist.load_data()
ones = tff.simulation.datasets.build_single_label_dataset(
  train_data.create_tf_dataset_from_all_clients(),
  label_key='label', desired_label=1)
print(ones.element_spec)
>>> OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
print(next(iter(ones))['label'])
>>> tf.Tensor(1, shape=(), dtype=int32)

数据不平衡

使用^{}^{}的组合可用于创建数据不平衡

train_data, _ = tff.simulation.datasets.emnist.load_data()
datasets = [train_data.create_tf_dataset_for_client(id) for id in train_data.client_ids[:2]]
print([tf.data.experimental.cardinality(ds).numpy() for ds in datasets])
>>> [93, 109]
datasets[0] = datasets[0].repeat(5)
datasets[1] = datasets[1].take(5)
print([tf.data.experimental.cardinality(ds).numpy() for ds in datasets])
>>> [465, 5]

相关问题 更多 >