PyTorch按索引列表分割数组
我想根据一组索引来拆分一个torch数组。
比如说,我的输入数组是 torch.arange(20)
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19])
而我的索引列表是 splits = [1,2,5,10]
那么我的结果会是:
(tensor([0]),
tensor([1, 2]),
tensor([3, 4, 5, 6, 7]),
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]))
假设我的输入数组总是足够长,超过我索引列表的总和。
4 个回答
0
你可以使用PyTorch里的torch.split函数,配合列表推导式,来根据给定的索引把数组分开。
import torch
input_array = torch.arange(20)
splits = [1, 2, 5, 10]
result = [input_array.split(split) for split in splits]
print(tuple(result))
2
另一个可能的选择是先对张量进行切片,然后再用split
来分割它:
import torch
t = torch.arange(20)
splits = [1, 2, 5, 10]
out = torch.split(t[: sum(splits)], splits)
输出结果:
(tensor([0]),
tensor([1, 2]),
tensor([3, 4, 5, 6, 7]),
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]))
2
你可以使用 tensor_split
来处理累加和的 splits
(比如用 np.cumsum
),但要排除最后一部分:
import torch
import numpy as np
t = torch.arange(20)
splits = [1,2,5,10]
t.tensor_split(np.cumsum(splits).tolist())[:-1]
输出结果:
(tensor([0]),
tensor([1, 2]),
tensor([3, 4, 5, 6, 7]),
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]),
)