PyTorch按索引列表分割数组

0 投票
4 回答
57 浏览
提问于 2025-04-14 18:29

我想根据一组索引来拆分一个torch数组。

比如说,我的输入数组是 torch.arange(20)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

而我的索引列表是 splits = [1,2,5,10]

那么我的结果会是:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))

假设我的输入数组总是足够长,超过我索引列表的总和。

4 个回答

0

你可以使用PyTorch里的torch.split函数,配合列表推导式,来根据给定的索引把数组分开。

import torch

input_array = torch.arange(20)
splits = [1, 2, 5, 10]

result = [input_array.split(split) for split in splits]
print(tuple(result))
2

另一个可能的选择是先对张量进行切片,然后再用split来分割它:

import torch

t = torch.arange(20)

splits = [1, 2, 5, 10]

out = torch.split(t[: sum(splits)], splits)

输出结果:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))
2

你可以使用 tensor_split 来处理累加和的 splits(比如用 np.cumsum),但要排除最后一部分:

import torch
import numpy as np

t = torch.arange(20)
splits = [1,2,5,10]

t.tensor_split(np.cumsum(splits).tolist())[:-1]

输出结果:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]),
)

撰写回答