如何将原始指针转换为特定形状的pytorch张量?

0 投票
2 回答
103 浏览
提问于 2025-04-14 16:13

我从一个C++库中得到了一个原始指针,我想把它当作一个特定形状的pytorch张量来使用。因为这段代码是在一个对性能要求很高的地方执行的,所以我想确保不会进行任何堆内存分配或者复制操作。

这是我现在的代码:

def as_tensor(pointer, shape):
    return torch.from_numpy(numpy.array(numpy.ctypeslib.as_array(pointer, shape = shape)))

shape = (2, 3, 4)
x = torch.zeros(shape)

p = ctypes.cast(x.data_ptr(), ctypes.POINTER(ctypes.c_float))
y = as_tensor(p, shape)

在这之前真的有必要先转换成numpy数组吗?而且我也不太确定调用 numpy.array(...) 是否会复制 as_array() 指向的内容。

2 个回答

0

如果你想把一个来自C++库的原始指针当作PyTorch的张量,并且希望保持特定的形状,同时又不想进行额外的内存分配或数据复制,你可以直接通过指定数据指针和形状来创建一个PyTorch张量。下面是你代码的一个修改版本,目的是避免不必要的复制:

import torch
import ctypes

def as_tensor(pointer, shape):
    # Create a PyTorch tensor from a raw pointer without copying data
    tensor = torch.as_tensor(pointer, dtype=torch.float32)
    # Reshape the tensor to the specified shape
    return tensor.view(shape)

shape = (2, 3, 4)
x = torch.zeros(shape)

p = ctypes.cast(x.data_ptr(), ctypes.POINTER(ctypes.c_float))
y = as_tensor(p, shape)

在这个修改后的版本中:

我们使用 torch.as_tensor() 直接从原始指针创建PyTorch张量,这样就不需要进行任何复制操作。 通过使用 torch.as_tensor() 而不是 torch.from_numpy(),我们省去了转换成NumPy数组的中间步骤,这个步骤可能会涉及不必要的复制。 view() 方法用于在不改变底层数据的情况下,将张量重新调整为指定的形状。 使用 torch.as_tensor(),你可以高效地将原始指针解释为PyTorch张量,而不必担心额外的内存分配或数据复制。

2

你可以通过指针和形状来创建一个一维的ctypes数组对象。它支持缓冲协议,所以可以转换成一个一维的张量,最后再进行形状调整。

最后的代码展示了 xy 是共享同一块内存的。

import torch
import ctypes
from math import prod

# It additionally needs the ctypes type as torch type
def as_tensor(pointer, shape, torch_type):
    arr = (pointer._type_ * prod(shape)).from_address(
        ctypes.addressof(pointer.contents))
    
    return torch.frombuffer(arr, dtype=torch_type).view(*shape)

shape = (2, 3, 4)
x = torch.zeros(shape)

p = ctypes.cast(x.data_ptr(), ctypes.POINTER(ctypes.c_float))

y = as_tensor(p, shape, torch.float)

print(y)  # Print created tensor

x[1,1,0] = 3.  # Modify original

print(y)  # Print again

输出:

tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [3., 0., 0., 0.],
         [0., 0., 0., 0.]]])

撰写回答