在numpy数组中用元组替换整数?

2024-04-20 13:26:27 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个numpy阵列:

a = [[0 1 2 3 4]
     [0 1 2 3 4]
     [0 1 2 3 4]]

我有一个字典,其中包含要替换/映射的值:

^{pr2}$

所以我的结论是:

a = [[(000, 001) (100, 101) (200, 201) (300, 301) (400, 401)]
     [(000, 001) (100, 101) (200, 201) (300, 301) (400, 401)]
     [(000, 001) (100, 101) (200, 201) (300, 301) (400, 401)]]

根据this SO answer,基于字典的值映射的一种方法是:

b = np.copy( a )
for k, v in d.items(): b[ a == k ] = v

当键和值属于同一数据类型时,此方法有效。但是在我的例子中,键是int,而新值是tuple (of ints)。因此,我得到一个cannot assign 2 input values错误。在

我尝试了:

b = a.astype( ( np.int, 2 ) )

但是,我得到了ValueError: could not broadcast input array from shape (3,5) into shape (3,5,2)的合理误差。在

那么,如何将int映射到numpy数组中的元组呢?在


Tags: 方法answerinnumpyforinput字典so
2条回答

这个怎么样?在

import numpy as np

data = np.tile(np.arange(5), (3, 1))

lookup = { 0 : ( 0, 1 ),
           1 : ( 100, 101 ),
           2 : ( 200, 201 ),
           3 : ( 300, 301 ),
           4 : ( 400, 401 )}

# get keys and values, make sure they are ordered the same
keys, values = zip(*lookup.items())

# making use of the fact that the keys are non negative ints
# create a numpy friendly lookup table
out = np.empty((max(keys) + 1,), object)
out[list(keys)] = values

# now out can be used to look up the tuples using only numpy indexing
result = out[data]
print(result)

印刷品:

^{pr2}$

或者,可以考虑使用整数数组:

out = np.empty((max(keys) + 1, 2), int)
out[list(keys), :] = values

result = out[data, :]
print(result)

印刷品:

[[[  0   1]
  [100 101]
  [200 201]
  [300 301]
  [400 401]]

 [[  0   1]
  [100 101]
  [200 201]
  [300 301]
  [400 401]]

 [[  0   1]
  [100 101]
  [200 201]
  [300 301]
  [400 401]]]

您可以使用结构化数组(这类似于使用元组,但不会失去速度优势):

>>> rgb_dtype = np.dtype([('r', np.int64), ('g', np.int64)])
>>> arr = np.zeros(a.shape, dtype=rgb_dtype)
>>> for k, v in d.items():
...     arr[a==k] = v
>>> arr
array([[(  0,   1), (100, 101), (200, 201), (300, 301), (400, 401)],
       [(  0,   1), (100, 101), (200, 201), (300, 301), (400, 401)],
       [(  0,   1), (100, 101), (200, 201), (300, 301), (400, 401)]], 
      dtype=[('r', '<i8'), ('g', '<i8')])

for-循环可以用更快的操作来代替。但是,如果您的a与总大小相比包含的值很少,这应该足够快了。在

相关问题 更多 >