Numpy中使用元组作为某些轴的多维索引

2 投票
1 回答
37 浏览
提问于 2025-04-14 18:29

我有一个三维的numpy数组,实际上是一个矩阵的集合。我想通过一种方法把对角线的值设为零。当我打印这个元组和din时,它们看起来完全一样,但返回的数组视图却不同。

m = np.random.normal(0, 0.2, (10, 4, 4))
din = np.diag_indices(m.shape[1], ndim = 2)

m[:, np.array([0,1,2,3]), np.array([0,1,2,3])]) # It returns an array of diagonals as expected
m[:, tuple(din)] # It returns the array

我这里漏掉了什么呢?

1 个回答

3

正如评论中所说的,你需要把索引拆开。

从 Python 3.11 开始,你可以使用:

m[:, *din]

输出结果:

array([[ 8.61622699e-02, -1.46919069e-01, -9.37771599e-02,
         1.94698315e-03],
       [ 1.60933774e-01, -2.77077615e-02, -1.74135776e-01,
        -1.72223723e-01],
       [-1.54804225e-01,  1.08146714e-01,  2.51844877e-01,
        -2.91622737e-02],
       [ 1.22213756e-02,  1.59703456e-02, -1.41757563e-01,
        -5.02470362e-02],
       [ 1.49296012e-01, -9.60208199e-03, -4.82484338e-01,
         1.58012139e-02],
       [-3.09847219e-01, -1.13959996e-01, -6.71019475e-01,
         3.17810448e-01],
       [ 2.04860543e-04, -2.16311908e-01,  1.39098046e-01,
        -1.40102017e-01],
       [-5.82402679e-02,  2.55831587e-01, -3.74597159e-01,
         1.23205316e-01],
       [-1.23942861e-01,  1.40365188e-02, -2.16884333e-02,
        -2.08800511e-02],
       [ 1.02934324e-01, -1.81953630e-01,  2.35600757e-01,
        -2.29315601e-01]])

不过,这种写法在旧版本的 Python 中不支持,如果你用的是旧版本,可以构建一个单一的元组:

m[tuple((slice(None), *din))]

# or
m[(slice(None), *din)]

撰写回答