如何从Pytorch张量中去掉每一个填充了零的列?

2024-03-28 08:40:29 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个pytorch张量A,如下所示:

A = 
tensor([[  4,   3,   3,  ...,   0,   0,   0],
        [ 13,   4,  13,  ...,   0,   0,   0],
        [707, 707,   4,  ...,   0,   0,   0],
        ...,
        [  7,   7,   7,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [195, 195, 195,  ...,   0,   0,   0]], dtype=torch.int32)

我想:

  • 标识其所有条目均等于0的所有列
  • 只删除所有条目都等于0的列

我可以想象这样做:

zero_list = []
for j in range(A.size()[1]):
    if torch.sum(A[:,j]) == 0:
         zero_list = zero_list.append(j)

标识其元素只有0的列 但是我不知道如何从原来的张量中删除这些填充了0的列。你知道吗

如何根据索引号从pytorch张量中删除带零的列?你知道吗

谢谢你


Tags: inforsizeif条目rangetorchpytorch
1条回答
网友
1楼 · 发布于 2024-03-28 08:40:29

索引要保留的列比索引要删除的列更有意义。你知道吗

valid_cols = []
for col_idx in range(A.size(1)):
    if not torch.all(A[:, col_idx] == 0):
        valid_cols.append(col_idx)
A = A[:, valid_cols]

或者更神秘一点

valid_cols = [col_idx for col_idx, col in enumerate(torch.split(A, 1, dim=1)) if not torch.all(col == 0)]
A = A[:, valid_cols]

相关问题 更多 >