对一批数组使用torch.where()
我正在使用pytorch,想要对一批数组应用一个简单的torch.where(array > 0)操作,但不想用循环,想知道怎么用torch的函数来实现这个代码。
def batch_node_indices(states_batch):
batch_indices = []
for state in states_batch:
node_indices = torch.where(state > 0)[0].detach().cpu().numpy()
batch_indices.append(node_indices)
return batch_indices
我尝试了不同的torch函数,但没有成功。我希望这个方法能返回一批数组,每个数组里包含状态数组中大于0的索引。
1 个回答
0
你说的“索引批次”具体指的是什么呢?
像 torch.where(condition)
这样的函数有个问题,就是每个批次里的项目,符合 condition=True
的元素数量都不一样。这就意味着你不能对整个批次同时使用 where
,因为每个项目的输出大小都不同。
默认情况下,where
会输出一个元组,里面有多个张量,每个张量对应一个维度,显示所有符合 condition=True
的索引元组。为了处理这些不规则的大小问题,输出的索引会被压平。你可以根据需要使用这些输出的索引来获取每个批次的结果。
x = torch.randn(16, 32, 64)
indices = torch.where(x>0)
print(indices)
> (tensor([ 0, 0, 0, ..., 15, 15, 15]),
> tensor([ 0, 0, 0, ..., 31, 31, 31]),
> tensor([ 4, 5, 7, ..., 61, 62, 63]))
index_tensor = torch.stack(indices)
# for example, select outputs from the first item in the batch
index_tensor[:, index_tensor[0] == 0]
你还可以在 torch.where
中使用额外的参数,这样可以返回一个和输入形状相同的张量。例如:
x = torch.randn(16, 32, 64)
x1 = torch.where(x>0, 1, 0) # fills 1 where x>0, 0 elsewhere
# shape is retained
x.shape == x1.shape
> True
x2 = torch.where(x>0, x, float('-inf')) # returns `x` with -inf where x<0