在Python中提取列表中的N维数组元素
我遇到了一个问题。我在一个类里面保存了大量的数据。这些数据大多数是和时间有关的,在最复杂的情况下,变量是三维数组。
因为列表非常灵活(不需要明确声明),我想用它们来封装我的N维数组,这样就可以用列表来保存时间相关的信息。
这里有一个典型的例子,展示了我在历史类里面的列表中,t=0、t=2和t=3时的元素(一个简单的float64矩阵):
history.params[0]
array([[ 1. , 2. , 1. ],
[ 1. , 2. , 1. ],
[ 0.04877093, 0.53176887, 0.26210472],
[ 2.76110434, 1.3569246 , 3.118208 ]])
history.params[2]
array([[ 1.00000825, 1.99998362, 1.00012835],
[ 0.62113657, 0.47057772, 5.23074169],
[ 0.04877193, 0.53076887, 0.26210472],
[ 0.02762192, 4.99387138, 2.6654337 ]])
history.params[3]
array([[ 1.00005473, 1.99923187, 1.00008009],
[ 0.62713657, 0.47157772, 5.23074169],
[ 0.04677193, 0.53476887, 0.25210472],
[ 0.02462192, 4.89387138, 2.6654337 ]])
现在,我该怎么做才能读取/提取我的矩阵中,给定坐标(x,y)在所有时间索引t下的所有元素呢?
我尝试这样做:
history.params[:][0][0]
结果是:
array([ 1., 2., 1.])
实际上,无论冒号放在哪里,我总是得到相同的值,这些值对应于我矩阵的第一行:
"history.params[0][:][0]" returns "array([ 1., 2., 1.])" in the shell
"history.params[0][0][:]" returns "array([ 1., 2., 1.])"
为什么Python在这里无法区分矩阵的元素和列表的元素呢?有什么好的解决办法吗?
当然,我可以写一些循环,创建一个新变量来重新组织我的数据,但这样有点浪费精力。我相信一定有更优雅的解决方案。
另外,我打算在某个时候使用'Cython'来优化我的代码,所以如果你有针对Cython存储这些变量的优化方案,我也很乐意听听。
2 个回答
为什么Python在这里无法区分矩阵的元素和列表的元素?最好的解决办法是什么?
因为 history.params[0]
是一个列表的列表,所以 history.params[0][0]
是一个列表,而 history.params[0][0][:]
是 list[:]
,这实际上是内层列表的一个副本。同样,history.params[0][:]
是这个列表的列表的副本,因此 history.params[0][:][0]
就是 (copy of the list of lists)[0]
,这又是你内层列表的第一行,但在这个二维数组的副本中。
如果你想把你的列表“扁平化”,也就是说把一个二维数组存储成一个列表,那么二维数组中位置 (n,m)
的元素就会变成扁平化数组中的元素 (n*M + m)
。比如你提到的4x3的数组,位置 (0,0) 就是扁平化列表的第0个元素,位置 (2,1) 就是第2*3+1 = 7个元素,以此类推。
这个方法也可以扩展到三维数组:对于大小为 KxNxM
的数组,索引 (k,n,m)
对应的元素是 (k*N*M + n*N + m)
;对于更高维度的数组也是类似的。
我建议你使用numpy.array
数组,而不是嵌套列表。
import numpy as np
# Create some data which would be equal to your "params"
a = np.array([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],
[[11, 12, 13],
[14, 15, 16],
[17, 18, 19]]])
print(a[0])
# [[1 2 3]
# [4 5 6]
# [7 8 9]]
print(a[:, 0, 0])
# [1, 11]
print(a[:, 0:2, 0])
# [[1, 4]
# [11, 14]]
此外,numpy还可以和Cython结合使用,具体可以参考这里。