Numpy:生成四元数乘法的批处理版本

2024-06-07 15:12:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我转换了以下函数

def quaternion_multiply(quaternion0, quaternion1):
    """Return multiplication of two quaternions.

    >>> q = quaternion_multiply([1, -2, 3, 4], [-5, 6, 7, 8])
    >>> numpy.allclose(q, [-44, -14, 48, 28])
    True

    """
    x0, y0, z0, w0 = quaternion0
    x1, y1, z1, w1 = quaternion1
    return numpy.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=numpy.float64)

到批处理版本

^{pr2}$

此函数用于处理形状为(?)的四元数1和四元数0?,4)。现在我希望函数可以处理任意数量的维,比如(?)?,?,4)。怎么做?在


Tags: 函数numpydefmultiplyw1x1y1z0
3条回答

只需将axis-=-1传递给np.split沿最后一个轴拆分,就可以得到您所追求的行为。在

由于数组有一个恼人的大小为1的尾随维度,而不是沿着一个新维度堆叠,然后将其压缩,您可以简单地将它们连接起来,同样沿着(最后一个)axis=-1

def quat_multiply(self, quaternion0, quaternion1):
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)
    return np.concatenate(
        (x1*w0 + y1*z0 - z1*y0 + w1*x0,
         -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
         -x1*x0 - y1*y0 - z1*z0 + w1*w0),
        axis=-1)

请注意,使用此方法,不仅可以乘法任意维数的形状相同的四元数堆栈:

^{pr2}$

但是,你也可以得到很好的广播,也就是说,你可以用一个四元数乘以一堆四元数,而不必调整维数:

>>> a = np.random.rand(6, 5, 4)
>>> b = np.random.rand(4)
>>> quat_multiply(None, a, b).shape
(6, 5, 4)

或者在一条直线上尽可能少地摆弄两个堆栈之间的所有交叉积:

>>> a = np.random.rand(6, 4)
>>> b = np.random.rand(5, 4)
>>> quat_multiply(None, a[:, None], b).shape
(6, 5, 4)

你快到了!您只需对如何拆分和连接阵列稍微小心一点:

def quat_multiply(quaternion0, quaternion1):
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)

    return np.squeeze(np.stack((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), axis=-1), axis=-2)

在这里,我们两次使用axis=-1沿最后一个轴拆分,然后沿最后一个轴连接回来。最后,我们挤出倒数第二个轴,正如您正确注意到的。为了证明它的有效性:

^{pr2}$

希望这就是你需要的!这应该适用于任意维度和任意数量的维度。在

注意:np.split似乎不适用于列表。所以你只能把数组传递给你的新函数,就像我上面所做的那样。如果您想传递列表,您可以改为调用

 np.split(np.asarray(quaternion0), 4, -1)

在你的功能内。在

而且,您的测试用例似乎是错误的。我想您已经交换了quaternion0和{}的位置:我在测试q0q1时,把它们交换回来了。在

您可以使用^{}将最后一个轴放在前面,帮助我们在不实际拆分4个数组的情况下将它们切片。我们执行所需的操作,最后将第一个轴发送回末尾,以保持输出数组形状与输入相同。因此,我们将有一个解决一般的n维n维nArray,如-

def quat_multiply_ndim(quaternion0, quaternion1):
    x0, y0, z0, w0 = np.rollaxis(quaternion0, -1, 0)
    x1, y1, z1, w1 = np.rollaxis(quaternion1, -1, 0)
    result = np.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=np.float64)
    return np.rollaxis(result,0, result.ndim)

样本运行-

^{pr2}$

相关问题 更多 >

    热门问题