Numpy中使用元组作为某些轴的多维索引
我有一个三维的numpy数组,实际上是一个矩阵的集合。我想通过一种方法把对角线的值设为零。当我打印这个元组和din
时,它们看起来完全一样,但返回的数组视图却不同。
m = np.random.normal(0, 0.2, (10, 4, 4))
din = np.diag_indices(m.shape[1], ndim = 2)
m[:, np.array([0,1,2,3]), np.array([0,1,2,3])]) # It returns an array of diagonals as expected
m[:, tuple(din)] # It returns the array
我这里漏掉了什么呢?
1 个回答
3
正如评论中所说的,你需要把索引拆开。
从 Python 3.11 开始,你可以使用:
m[:, *din]
输出结果:
array([[ 8.61622699e-02, -1.46919069e-01, -9.37771599e-02,
1.94698315e-03],
[ 1.60933774e-01, -2.77077615e-02, -1.74135776e-01,
-1.72223723e-01],
[-1.54804225e-01, 1.08146714e-01, 2.51844877e-01,
-2.91622737e-02],
[ 1.22213756e-02, 1.59703456e-02, -1.41757563e-01,
-5.02470362e-02],
[ 1.49296012e-01, -9.60208199e-03, -4.82484338e-01,
1.58012139e-02],
[-3.09847219e-01, -1.13959996e-01, -6.71019475e-01,
3.17810448e-01],
[ 2.04860543e-04, -2.16311908e-01, 1.39098046e-01,
-1.40102017e-01],
[-5.82402679e-02, 2.55831587e-01, -3.74597159e-01,
1.23205316e-01],
[-1.23942861e-01, 1.40365188e-02, -2.16884333e-02,
-2.08800511e-02],
[ 1.02934324e-01, -1.81953630e-01, 2.35600757e-01,
-2.29315601e-01]])
不过,这种写法在旧版本的 Python 中不支持,如果你用的是旧版本,可以构建一个单一的元组:
m[tuple((slice(None), *din))]
# or
m[(slice(None), *din)]