Replace the innermost sub-arrays of a 3D numpy array of shape (m, n, 2) using a dictionary mapping

Asked 2025-04-12 03:29

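I have an array called input_array with shape (m, n, 2) and a dictionary mapping (of type Dict[Tuple[int, int], float]). I want to create a new output_array of shape (m, n) by replacing each innermost array of input_array with the value that mapping assigns to it.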

For example:

input_array = np.array([[[1, 2], [3, 7]], [[1, 2], [4, 5]]])
mapping = {(1, 2): 0.7, (3, 7): 0.8, (4, 5): 0.9, (2, 4): 0.3}
---
output_array = np.array([[0.7, 0.8], [0.7, 0.9]])
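A minimal reference version that pins down the intended semantics (one dictionary lookup per innermost pair), assuming every pair that occurs in input_array is a key of mapping. The name map_values_reference is hypothetical; it is meant as a correctness baseline, not as a fast solution:

```python
import numpy as np

def map_values_reference(input_array, mapping):
    # Hypothetical baseline: plain nested loops, one dict lookup per cell.
    # Assumes every pair in input_array is a key of mapping;
    # a missing pair raises KeyError.
    m, n, _ = input_array.shape
    output_array = np.empty((m, n), dtype=float)
    for i in range(m):
        for j in range(n):
            a, b = input_array[i, j]
            output_array[i, j] = mapping[(int(a), int(b))]
    return output_array

input_array = np.array([[[1, 2], [3, 7]], [[1, 2], [4, 5]]])
mapping = {(1, 2): 0.7, (3, 7): 0.8, (4, 5): 0.9, (2, 4): 0.3}
print(map_values_reference(input_array, mapping))
```

The key (2, 4) is simply never used here; extra keys in mapping are harmless.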

Attempt:

import time
import numpy as np

def map_values(input_array, mapping):
    output_array = np.empty(input_array.shape[:2], dtype=float)

    tic = time.time()
    for key in np.unique(input_array.reshape(-1, 2), axis=0):
        value = mapping[tuple(key)]
        
        mask = np.all(input_array == key, axis=-1)
        indices = np.where(mask)
        output_array[indices[0], indices[1]] = value
    toc = time.time()
    print(f'mapping loop took {toc-tic:.4f} seconds')

    return output_array
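To see what each iteration of the attempt does, here is its mask step on the example for the single key (1, 2). This only illustrates the existing code; it shows why the loop scans the whole array once per unique key:

```python
import numpy as np

input_array = np.array([[[1, 2], [3, 7]], [[1, 2], [4, 5]]])
key = np.array([1, 2])

# Broadcasting compares every innermost pair with key elementwise:
# shape (2, 2, 2) == shape (2,) -> boolean array of shape (2, 2, 2).
# np.all over the last axis is True only where both components match.
mask = np.all(input_array == key, axis=-1)
print(mask)

# The attempt repeats this full-array comparison for every unique key,
# so its total work grows roughly as (number of unique keys) * m * n.
```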

Is there a way to make the loop and the replacement faster?

1 Answer


I would use a simple loop:

out = (np.array([mapping[l] for l in
                 map(tuple, input_array.reshape(-1, 2))])
         .reshape(input_array.shape[:-1])
      )
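The one-liner above, wrapped in a hypothetical helper map_simple_loop and run on the question's example as a self-contained check:

```python
import numpy as np

def map_simple_loop(input_array, mapping):
    # Flatten to one row per innermost pair, turn each row into a hashable
    # tuple, look it up in the dict, then restore the (m, n) shape.
    pairs = input_array.reshape(-1, 2)
    values = np.array([mapping[pair] for pair in map(tuple, pairs)])
    return values.reshape(input_array.shape[:-1])

input_array = np.array([[[1, 2], [3, 7]], [[1, 2], [4, 5]]])
mapping = {(1, 2): 0.7, (3, 7): 0.8, (4, 5): 0.9, (2, 4): 0.3}
print(map_simple_loop(input_array, mapping))
```

The lookup works even though the tuples contain NumPy integers rather than Python ints, because NumPy integer scalars hash and compare equal to the corresponding Python ints.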

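Alternatively, if you expect many duplicated pairs, you can first reduce them to unique values so that fewer dictionary lookups are needed. On large data this may not be much faster, though, because np.unique itself is costly: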

tmp, idx = np.unique(input_array.reshape(-1, 2),
                     return_inverse=True, axis=0)

out = (np.array([mapping[tuple(l)] for l in tmp])[idx]
         .reshape(input_array.shape[:-1])
      )
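The intermediate arrays of the unique-based version, shown on the question's example. Only len(tmp) dictionary lookups happen, and fancy indexing with idx copies each looked-up value back to every position where that pair occurred:

```python
import numpy as np

input_array = np.array([[[1, 2], [3, 7]], [[1, 2], [4, 5]]])
mapping = {(1, 2): 0.7, (3, 7): 0.8, (4, 5): 0.9, (2, 4): 0.3}

# tmp holds each distinct pair once (rows sorted lexicographically);
# idx gives, for every original row, its row number in tmp.
tmp, idx = np.unique(input_array.reshape(-1, 2), return_inverse=True, axis=0)
print(tmp.tolist())          # [[1, 2], [3, 7], [4, 5]]
print(idx.ravel().tolist())  # [0, 1, 0, 2]

# One lookup per distinct pair, then broadcast back via idx.
unique_values = np.array([mapping[tuple(row)] for row in tmp])
out = unique_values[idx].reshape(input_array.shape[:-1])
print(out)
```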

Output:

array([[0.7, 0.8],
       [0.7, 0.9]])

Timings

For (2, 2, 2):

# original
86 µs ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# this answer: simple loop
7.08 µs ± 811 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

# this answer: only mapping unique values
60.7 µs ± 1.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

For (200, 200, 2):

# original
33.7 ms ± 4.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# this answer: simple loop
45.7 ms ± 3.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# this answer: only mapping unique values
30 ms ± 854 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

For a large m (or n), e.g. (2000, 2, 2):

# original
2.78 ms ± 366 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# this answer: simple loop
4.54 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# this answer: only mapping unique values
2.3 ms ± 6.19 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

For an even larger m*n, e.g. (2000, 2000, 2):

# original
4.5 s ± 67.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# this answer: simple loop
4.29 s ± 35.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# this answer: only mapping unique values
4.31 s ± 108 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
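At (2000, 2000, 2) all three variants land around 4.3 to 4.5 s, which suggests that the per-element Python work dominates. If (and only if) the key components are small non-negative integers, which the question does not guarantee, a different technique not used in this answer avoids per-element Python lookups altogether: scatter the dictionary into a dense 2-D lookup table once, then gather all outputs with a single fancy-indexing call. A sketch under that assumption; map_with_lut and its fill parameter are hypothetical names, and no timings are claimed for it here:

```python
import numpy as np

def map_with_lut(input_array, mapping, fill=np.nan):
    # Assumption: every key component is a small non-negative integer,
    # so a dense table indexed by (first, second) fits in memory.
    keys = np.array(list(mapping.keys()), dtype=np.intp)   # shape (k, 2)
    vals = np.array(list(mapping.values()), dtype=float)  # shape (k,)
    # Size the table to cover both the dict keys and the input values.
    rows = max(keys[:, 0].max(), input_array[..., 0].max()) + 1
    cols = max(keys[:, 1].max(), input_array[..., 1].max()) + 1
    lut = np.full((rows, cols), fill, dtype=float)
    lut[keys[:, 0], keys[:, 1]] = vals
    # Vectorized gather: one indexing call replaces all dict lookups.
    # Pairs absent from mapping come out as `fill` instead of raising KeyError.
    return lut[input_array[..., 0], input_array[..., 1]]

input_array = np.array([[[1, 2], [3, 7]], [[1, 2], [4, 5]]])
mapping = {(1, 2): 0.7, (3, 7): 0.8, (4, 5): 0.9, (2, 4): 0.3}
print(map_with_lut(input_array, mapping))
```

The table needs (largest first component + 1) * (largest second component + 1) entries, so this only makes sense for compact key ranges; negative components would index from the end of the table and give wrong results. Unlike the dictionary versions, unknown pairs silently become fill rather than raising.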

As @MatBailie pointed out, the simple loop can be improved further with np.fromiter:

N = input_array.shape[-1]
out = (np.fromiter((mapping[l] for l in
                   map(tuple, input_array.reshape(-1, N))),
                   float, count=input_array.size//N)
         .reshape(input_array.shape[:-1])
      )
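The np.fromiter variant wrapped in a hypothetical helper map_fromiter and run on the question's example. Instead of building an intermediate Python list, it fills a float array directly from the generator, and passing count lets NumPy allocate that array once up front:

```python
import numpy as np

def map_fromiter(input_array, mapping):
    # N is the length of the innermost axis (2 here); each row of the
    # flattened view is one dictionary key.
    N = input_array.shape[-1]
    flat = input_array.reshape(-1, N)
    # count = number of rows, so the output buffer is preallocated.
    values = np.fromiter((mapping[key] for key in map(tuple, flat)),
                         dtype=float, count=flat.shape[0])
    return values.reshape(input_array.shape[:-1])

input_array = np.array([[[1, 2], [3, 7]], [[1, 2], [4, 5]]])
mapping = {(1, 2): 0.7, (3, 7): 0.8, (4, 5): 0.9, (2, 4): 0.3}
print(map_fromiter(input_array, mapping))
```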
