动态构建通用numpy数组索引

2 投票
1 回答
843 浏览
提问于 2025-04-18 08:28

基本背景

我正在写一些代码,目的是为了简化支持向量机(SVM)的训练,特别是当数据的特征数量不一样时。同时,我还想根据用户指定的“切片”来可视化这些SVM的决策边界。如果我的数据集中有 n 个特征和 m 个样本,我会生成一个 (n+1)-维 的网格,其中第一个索引的每个切片都是一个 m x m x ... 的网格,维度为 n。然后,我可以用我的SVM来对网格中的每个数据点进行分类。

接下来我想做的是在用户指定的任意两个维度上绘制这些结果的切片。我已经有了可以在数据只有两个特征时绘制的代码,但一旦我添加第三个特征,就开始遇到索引问题了。

问题陈述

假设我有一个三维矩阵 predictions,我想在我的网格 mesh 中绘制这些预测值,特别是与 index0=0index1=1 相关的值,以及这些维度中的训练数据。我可以通过类似下面的函数调用来实现:

import matplotlib.pyplot as plt
plt.contourf(mesh[index0,:,:,0], mesh[index1,:,:,0], pred[:,:,0])
plt.scatter(samples[:,index0], samples[:,index1], c=labels)
plt.show()

我想知道的是,如何动态构建我的索引数组,这样如果 index0=0index1=1,我们就能得到上面的代码;但如果 index0=1index1=2,我们就能得到:

plt.contourf(mesh[index0,0,:,:], mesh[index1,0,:,:], pred[0,:,:])

如果 index0=0index1=2,我们又能得到:

plt.contourf(mesh[index0,:,0,:], mesh[index1,:,0,:], pred[:,0,:])

我该如何动态构建这些索引呢?对于那些我可能事先不知道特征数量的情况,有没有更好的方法呢?

更多细节

我尝试过类似这样的代码:

mesh_indices0 = [0]*len(mesh.shape)
mesh_indices0[0] = index0
mesh_indices0[index0+1] = ':'    # syntax error: I cannot add this dynamically
mesh_indices0[index1+1] = ':'    # same problem

我也试着从相反的方向入手,用 mesh_indices = [:]*len(mesh.shape),但这也是无效的语法。我还考虑过尝试这样的代码:

mesh_indices[index0+1] = np.r_[:len(samples[:, 1])]

其中 samples 是我的 m x n 观察数据。不过我觉得这看起来很笨重,所以我想应该有更好的方法。

1 个回答

1

我不太确定我是否完全理解你想做什么,但如果你想处理切片,应该使用Python中的slice对象:

mesh[index0,0,:,:]

这相当于:

mesh[index0,0,slice(0,mesh.shape[2]),slice(0,mesh.shape[3])]

另外要注意,你可以用一个列表或元组来索引切片和索引:

inds = (index0, 0, slice(0,mesh.shape[2]), slice(0,mesh.shape[3]))
mesh[inds]

把这些结合起来,你可以创建一个:等价的slice对象列表,然后用你的具体索引替换掉合适的那个。或者,反过来也可以:

mesh_indices = [0]*len(mesh.shape)
mesh_indices[0] = index0
mesh_indices[index0+1] = slice(0, mesh.shape[index0+1])
mesh_indices[index1+1] = slice(0, mesh.shape[index1+1])

撰写回答