如何将原始指针转换为特定形状的pytorch张量?
我从一个C++库中得到了一个原始指针,我想把它当作一个特定形状的pytorch张量来使用。因为这段代码是在一个对性能要求很高的地方执行的,所以我想确保不会进行任何堆内存分配或者复制操作。
这是我现在的代码:
def as_tensor(pointer, shape):
return torch.from_numpy(numpy.array(numpy.ctypeslib.as_array(pointer, shape = shape)))
shape = (2, 3, 4)
x = torch.zeros(shape)
p = ctypes.cast(x.data_ptr(), ctypes.POINTER(ctypes.c_float))
y = as_tensor(p, shape)
在这之前真的有必要先转换成numpy数组吗?而且我也不太确定调用 numpy.array(...)
是否会复制 as_array()
指向的内容。
2 个回答
0
如果你想把一个来自C++库的原始指针当作PyTorch的张量,并且希望保持特定的形状,同时又不想进行额外的内存分配或数据复制,你可以直接通过指定数据指针和形状来创建一个PyTorch张量。下面是你代码的一个修改版本,目的是避免不必要的复制:
import torch
import ctypes
def as_tensor(pointer, shape):
# Create a PyTorch tensor from a raw pointer without copying data
tensor = torch.as_tensor(pointer, dtype=torch.float32)
# Reshape the tensor to the specified shape
return tensor.view(shape)
shape = (2, 3, 4)
x = torch.zeros(shape)
p = ctypes.cast(x.data_ptr(), ctypes.POINTER(ctypes.c_float))
y = as_tensor(p, shape)
在这个修改后的版本中:
我们使用 torch.as_tensor()
直接从原始指针创建PyTorch张量,这样就不需要进行任何复制操作。
通过使用 torch.as_tensor()
而不是 torch.from_numpy()
,我们省去了转换成NumPy数组的中间步骤,这个步骤可能会涉及不必要的复制。
view()
方法用于在不改变底层数据的情况下,将张量重新调整为指定的形状。
使用 torch.as_tensor()
,你可以高效地将原始指针解释为PyTorch张量,而不必担心额外的内存分配或数据复制。
2
你可以通过指针和形状来创建一个一维的ctypes数组对象。它支持缓冲协议,所以可以转换成一个一维的张量,最后再进行形状调整。
最后的代码展示了 x
和 y
是共享同一块内存的。
import torch
import ctypes
from math import prod
# It additionally needs the ctypes type as torch type
def as_tensor(pointer, shape, torch_type):
arr = (pointer._type_ * prod(shape)).from_address(
ctypes.addressof(pointer.contents))
return torch.frombuffer(arr, dtype=torch_type).view(*shape)
shape = (2, 3, 4)
x = torch.zeros(shape)
p = ctypes.cast(x.data_ptr(), ctypes.POINTER(ctypes.c_float))
y = as_tensor(p, shape, torch.float)
print(y) # Print created tensor
x[1,1,0] = 3. # Modify original
print(y) # Print again
输出:
tensor([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
tensor([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[3., 0., 0., 0.],
[0., 0., 0., 0.]]])