Java数组到NumPy数组(Py4J)的快速转换

2024-05-15 17:38:48 发布

您现在位置:Python中文网/ 问答频道 /正文

这里有一些很好的例子来说明如何将NumPy数组转换成Java数组,但是反之亦然——如何将Java对象中的数据转换回NumPy数组。我有这样一个Python脚本:

    from py4j.java_gateway import JavaGateway
    gateway = JavaGateway()            # connect to the JVM
    my_java = gateway.jvm.JavaClass();  # my Java object
    ....
    int_array=my_java.doSomething(int_array); # do something

    my_numpy=np.zeros((size_y,size_x));
    for jj in range(size_y):
        for ii in range(size_x):
            my_numpy[jj,ii]=int_array[jj][ii];

my_numpy是Numpy数组,int_array是Java的整数数组-int[ ][ ]类型的数组。在Python脚本中初始化为:

^{pr2}$

尽管它是这样工作的,但它不是最快的方法,而且工作速度也很慢-对于~1000x1000阵列,转换需要5分钟以上。在

有没有办法在合理的时间内完成这个任务?在

如果我尝试:

    test=np.array(int_array)

我得到:

    ValueError: invalid __array_struct__

Tags: numpy脚本sizemynp数组javaarray
2条回答

我遇到了一个类似的问题,并找到了一个比我测试的情况快220倍的解决方案:将1628x120的短整数数组从Java传输到Numpy,运行时间从11秒减少到0.05秒。多亏了this related StackOverflow question,我开始研究py4j byte arrays,结果发现py4j有效地将Java字节数组转换成Python字节对象,反之亦然(通过值传递,而不是通过引用传递)。这是一种相当迂回的做事方式,但不太难。在

因此,如果要传输一个整数数组intArray,其维数为iMaxxjMax(为了这个例子的缘故,我假设这些都作为实例变量存储在对象中),可以先编写一个Java函数将其转换为byte[],如下所示:

public byte[] getByteArray() {
    // Set up a ByteBuffer called intBuffer
    ByteBuffer intBuffer = ByteBuffer.allocate(4*iMax*jMax); // 4 bytes in an int
    intBuffer.order(ByteOrder.LITTLE_ENDIAN); // Java's default is big-endian

    // Copy ints from intArray into intBuffer as bytes
    for (int i = 0; i < iMax; i++) {
        for (int j = 0; j < jMax; j++){
            intBuffer.putInt(intArray[i][j]);
        }
    }

    // Convert the ByteBuffer to a byte array and return it
    byte[] byteArray = intBuffer.array();
    return byteArray;
}

然后,您可以编写Python 3代码来接收字节数组并将其转换为正确形状的numpy数组:

^{pr2}$

我也遇到过类似的问题,只是试图绘制从Java端通过py4j获得的谱向量(Java数组)。 在这里,从Java数组到Python列表的转换是通过list()函数实现的。这可能会给我们一些线索,比如如何使用它来填充NumPy数组。。。在

vectors = space.getVectorsAsArray(); # Java array (MxN)
wvl = space.getAverageWavelengths(); # Java array (N)

wavelengths = list(wvl)

import matplotlib.pyplot as mp
mp.hold
for i, dataset in enumerate(vectors):
    mp.plot(wavelengths, list(dataset))

这是否比您使用的嵌套for循环要快,我不能说,但它也能做到:

^{pr2}$

相关问题 更多 >