如何在python中对这个(numpy)操作进行矢量化?

2024-04-19 08:29:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个向量的形状(batch, dim),我试图从另一个减去。目前,我正在使用一个简单的循环,根据第二个向量(即label),从1中减去向量(即error)中的特定项:

per_ts_loss=0
for i, idx in enumerate(np.argmax(label, axis=1)):
    error[i, idx] -=1
    per_ts_loss += error[i, idx]

如何将其矢量化?你知道吗

例如,错误和标签可以如下所示:

error :
array([[ 0.5488135   0.71518937  0.60276338  0.54488318  0.4236548 ]
       [ 0.64589411  0.43758721  0.891773    0.96366276  0.38344152]])
label:
    array([[0, 0, 0, 1, 0 ],
           [0, 1, 0, 0, 0]])

对于本例,运行下面的代码会产生以下结果:

for i, idx in enumerate(np.argmax(label,axis=1)):
    error[i,idx] -=1
    ls_loss += error[i,idx]

结果:

error: 
 [[ 0.5488135   0.71518937  0.60276338  0.54488318  0.4236548 ]
 [ 0.64589411  0.43758721  0.891773    0.96366276  0.38344152]]
label: 
 [[ 0.  0.  0.  1.  0.]
 [ 0.  1.  0.  0.  0.]]

error(indexes 3 and 1 are changed): 
[[ 0.5488135   0.71518937  0.60276338 -0.45511682  0.4236548 ]
 [ 0.64589411 -0.56241279  0.891773    0.96366276  0.38344152]]
per_ts_loss: 
 -1.01752960574

下面是代码本身:https://ideone.com/e1k8ra

我被困在如何使用np.argmax的结果上,因为结果是一个新的索引向量,不能简单地像这样使用:

 error[:, np.argmax(label, axis=1)] -=1

所以我被困在这里了!你知道吗


Tags: 代码infornperrorarray向量label
2条回答

也许是这样:

import numpy as np


error = np.array([[0.32783139, 0.29204386, 0.0572163 , 0.96162543, 0.8343454 ],
       [0.67308787, 0.27715222, 0.11738748, 0.091061  , 0.51806117]])

label= np.array([[0, 0, 0, 1, 0 ],
           [0, 1, 0, 0, 0]])



def f(error, label):
    per_ts_loss=0
    t=np.zeros(error.shape)
    argma=np.argmax(label, axis=1)
    t[[i for i in range(error.shape[0])],argma]=-1
    print(t)
    error+=t
    per_ts_loss += error[[i for i in range(error.shape[0])],argma]


f(error, label)

输出:

[[ 0.  0.  0. -1.  0.]
 [ 0. -1.  0.  0.  0.]]

替换:

error[:, np.argmax(label, axis=1)] -=1

使用:

error[np.arange(error.shape[0]), np.argmax(label, axis=1)] -=1

当然了

loss = error[np.arange(error.shape[0]), np.argmax(label, axis=1)].sum()

在您的示例中,您正在更改和求和error[0,3]error[1,1],简而言之error[[0,1],[3,1]]。你知道吗

相关问题 更多 >