Numpy多维数组的arctan2

0 投票
1 回答
1509 浏览
提问于 2025-04-18 15:01

我正在整理一些代码,这段代码原本是用来处理单个的 float 值的,所以它在使用一维(最终可能是二维) numpy.arrays 作为输入时工作得很好。

简化到一个最小的例子,这个函数看起来是这样的(虽然这个例子并没有做什么有用的事情,但如果去掉 do_mathdo_some_more_math,它会产生正好描述的行为):

def do_complicated_math(r, g, b):
    rgb = numpy.array([r, g, b])

    # Math! No change in array shape. To run example just comment out.
    rgb = do_math(rgb)

    m_2 = numpy.array([[rgb[0], 0, 0], [0, rgb[1], 0], [0, 0, rgb[2]]])

    # Get additional matrices needed for transformation.
    # These are actually predefined 3x3 float arrays
    m_1 = numpy.ones((3, 3))
    m_3 = numpy.ones((3, 3))

    # Transform the rgb array
    rgb_transformed = m_1.dot(m_2).dot(m_3).dot(rgb)

    # More math! No change in array shape. To run example just comment out.
    rgb_transformed = do_some_more_math(rgb_transformed)

    # Almost done just one more thing...
    return numpy.arctan2(rgb_transformed, rgb_transformed)

# Works fine
do_complicated_math(1, 1, 1)

# Fails
x = numpy.ones(6)
do_complicated_math(x, x, x)

这个函数在 rgb 是单独的数字时工作得很好,然而,如果把它们作为 numpy.array 提供(例如,为了同时转换多个 rgb 值),那么 numpy.arctan2 就会抛出以下异常:

Traceback (most recent call last):
  (...) line 32, in do_complicated_math
    numpy.arctan2(rgb_transformed, rgb_transformed)
AttributeError: 'numpy.ndarray' object has no attribute 'arctan2'

我还没有找到任何明确的答案来解释这是什么意思。arctan2 在与多维数组一起使用时似乎工作得很好,如下所示:

numpy.arctan2(numpy.ones((3,4,5)), numpy.ones((3,4,5)))

所以我猜问题可能出在 m_2 的创建方式,或者 m_1m_2m_3rgb 的乘法是如何传播的,但我似乎无法弄清楚到底是哪里出错了。

1 个回答

3

问题在于,当你把 rgb_transformed 传给 arctan2 时,它不再是标准的 numpy 数组,而变成了一个对象数组:

print rgb_transformed
"""[[array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])]
 [array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])]
 [array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])
  array([ 9.,  9.,  9.,  9.,  9.,  9.])]]"""
print rgb_transformed.shape
#(3, 6)
print rgb_transformed.dtype
#object

所以这个问题比我想的简单:

这一行:

m_2 = numpy.array([[rgb[0], 0, 0], [0, rgb[1], 0], [0, 0, rgb[2]]])
print m_2
#array([[array([ 1.,  1.,  1.,  1.,  1.,  1.]), 0, 0],
#       [0, array([ 1.,  1.,  1.,  1.,  1.,  1.]), 0],
#       [0, 0, array([ 1.,  1.,  1.,  1.,  1.,  1.])]], dtype=object)

在这里创建了对象数组,这种情况会影响到后面的代码。

补充说明

要解决这个问题,你可能需要稍微改变一下数组的广播方式。基本上就是要调整外层维度,以反映变化的 rgb 值。免责声明:我没有好的方法来验证这个结果是否适合你的问题,所以请谨慎对待输出结果。

import numpy as np

def do_complicated_math(r, g, b):
    rgb = np.array([r, g, b])

    # create a transposed version of the m_2 array
    m_2 = np.zeros((r.size,3,3))
    for ii,ar in enumerate(rgb):
        m_2[:,ii][:,ii][:] = ar
    m_1 = np.ones((3, 3))
    m_3 = np.ones((3, 3))

    rgb_transformed = m_1.dot(m_2).dot(m_3).dot(rgb)

    print rgb_transformed
    return np.arctan2(rgb_transformed, rgb_transformed)

x = np.ones(6)
do_complicated_math(x, x, x)                                                                                                                        

r = np.array([0.2,0.3,0.1])
g = np.array([1.0,1.0,0.2])
b = np.array([0.3,0.3,0.3])
do_complicated_math(r, g, b)

这只适用于数组作为输入,但处理单个值作为输入应该也很简单。

撰写回答