对一批数组使用torch.where()

0 投票
1 回答
29 浏览
提问于 2025-04-12 20:00

我正在使用pytorch,想要对一批数组应用一个简单的torch.where(array > 0)操作,但不想用循环,想知道怎么用torch的函数来实现这个代码。

def batch_node_indices(states_batch):
    batch_indices = []
    for state in states_batch:
        node_indices = torch.where(state > 0)[0].detach().cpu().numpy()
        batch_indices.append(node_indices)
    return batch_indices

我尝试了不同的torch函数,但没有成功。我希望这个方法能返回一批数组,每个数组里包含状态数组中大于0的索引。

1 个回答

0

你说的“索引批次”具体指的是什么呢?

torch.where(condition) 这样的函数有个问题,就是每个批次里的项目,符合 condition=True 的元素数量都不一样。这就意味着你不能对整个批次同时使用 where,因为每个项目的输出大小都不同。

默认情况下,where 会输出一个元组,里面有多个张量,每个张量对应一个维度,显示所有符合 condition=True 的索引元组。为了处理这些不规则的大小问题,输出的索引会被压平。你可以根据需要使用这些输出的索引来获取每个批次的结果。

x = torch.randn(16, 32, 64)
indices = torch.where(x>0)
print(indices)
> (tensor([ 0,  0,  0,  ..., 15, 15, 15]),
>  tensor([ 0,  0,  0,  ..., 31, 31, 31]),
>  tensor([ 4,  5,  7,  ..., 61, 62, 63]))

index_tensor = torch.stack(indices)
# for example, select outputs from the first item in the batch
index_tensor[:, index_tensor[0] == 0]

你还可以在 torch.where 中使用额外的参数,这样可以返回一个和输入形状相同的张量。例如:

x = torch.randn(16, 32, 64)
x1 = torch.where(x>0, 1, 0) # fills 1 where x>0, 0 elsewhere

# shape is retained 
x.shape == x1.shape
> True

x2 = torch.where(x>0, x, float('-inf')) # returns `x` with -inf where x<0

撰写回答