<p>由于您计划使用iterable数据集,因此不需要随机访问(<code>IterableDataset</code>不支持随机采样)。在这种情况下,为什么不将所有内容写入二进制文件并对其进行迭代呢?我发现在实践中,这通常比其他解决方案快得多。这应该比保存为文本文件快得多,因为可以避免将文本转换为数字的开销</p>
<p>示例实现可能如下所示。首先,我们可以按如下方式构建一个二进制文件(包含作为占位符的随机数据)</p>
<pre class="lang-py prettyprint-override"><code>import numpy as np
from tqdm import tqdm
filename = 'data.bin'
num_samples = 3600000
rows, cols = 30, 32
dtype = np.float32
# format: <num_samples> <rows> <cols> <sample0> <sample1>...
with open(filename, 'wb') as fout:
# write a header that contains the total number of samples and the rows and columns per sample
fout.write(np.array((num_samples, rows, cols), dtype=np.int32).tobytes())
for i in tqdm(range(num_samples)):
# random placeholder
sample = np.random.randn(rows, cols).astype(dtype)
# write data to file
fout.write(sample.tobytes())
</code></pre>
<p>然后我们可以定义一个<code>IterableDataset</code>,如下所示</p>
<pre class="lang-py prettyprint-override"><code>import numpy as np
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm
def binary_reader(filename, start=None, end=None, dtype=np.float32):
itemsize = np.dtype(dtype).itemsize
with open(filename, 'rb') as fin:
num_samples, rows, cols = np.frombuffer(fin.read(3 * np.dtype(np.int32).itemsize), dtype=np.int32)
start = start if start is not None else 0
end = end if end is not None else num_samples
blocksize = itemsize * rows * cols
start_offset = start * blocksize
fin.seek(start_offset, 1)
for _ in range(start, end):
yield np.frombuffer(fin.read(blocksize), dtype=dtype).reshape(rows, cols).copy()
class BinaryIterableDataset(IterableDataset):
def __init__(self, filename, start=None, end=None, dtype=np.float32):
super().__init__()
self.filename = filename
self.start = start
self.end = end
self.dtype = dtype
def __iter__(self):
return binary_reader(self.filename, self.start, self.end, self.dtype)
</code></pre>
<p>通过在我的系统(使用SSD存储)上对该数据集进行快速测试,我发现我能够在大约10秒内迭代360万个样本</p>
<pre class="lang-py prettyprint-override"><code>dataset = BinaryIterableDataset('data.bin')
for sample in tqdm(dataset):
pass
</code></pre>
<pre><code>3600000it [00:09, 374026.17it/s]
</code></pre>
<p>使用带有<code>batch_size=256</code>的<code>DataLoader</code>迭代整个数据集大约需要20秒(转换为张量和创建批处理会有一些开销)。对于这个数据集,我发现当使用并行加载时,将数据传输到共享内存和从共享内存传输数据的开销实际上比仅使用0个worker要慢得多。因此,我建议使用<code>num_workers=0</code>。与任何iterable数据集一样,您需要添加额外的逻辑来支持num_workers>;1,虽然我不确定在这种情况下是否值得</p>
<pre class="lang-py prettyprint-override"><code>loader = DataLoader(dataset, batch_size=256, num_workers=0)
for batch in tqdm(loader):
# batch is a tensor of shape (256, 30, 32)
pass
</code></pre>
<pre><code>14063it [00:19, 710.49it/s]
</code></pre>
<p>请注意<code>data.bin</code>文件不能跨使用不同字节顺序的系统移植。尽管可以进行修改以支持这一点</p>