我正在做一个pytorch项目,我的数据保存在zarr
对zarr
的随机访问成本很高,但由于zarr
使用了分块缓存,迭代速度非常快。为了利用这一事实,我将IterableDataset
与多个worker一起使用:
class Data(IterableDataset):
def __init__(self, path, start=None, end=None):
super(Data, self).__init__()
store = zarr.DirectoryStore(path)
self.array = zarr.open(store, mode='r')
if start is None:
start = 0
if end is None:
end = self.array.shape[0]
assert end > start
self.start = start
self.end = end
def __iter__(self):
return islice(self.array, self.start, self.end)
问题是self.array
具有10e9
行的顺序,对于连续的工作者,随着self.start
和self.end
自然变大,创建像itertools.islice(array, start, end)
这样的生成器需要花费大量的时间,因为islice
仍然必须迭代不需要的元素,直到它到达start
。一旦每个工人都创建了一个发电机,这就像一个符咒,但要达到这个目标需要很长时间
有没有更好的方法来创建这样的生成器?或者也许有一种更聪明的方法在pytorch
中使用zarr
我在zarr中做了一个小动作,看起来这将很容易从zarr内部启用。我已经打开了一个问题here,同时我制作了一个fork of zarr来实现函数
array.islice(start, end)
dataset
__iter__
方法如下所示:相关问题 更多 >
编程相关推荐