我使用detach().clone().cpu().numpy()仍然抛出TypeError: 无法将cuda:0设备类型张量转换为numpy

0 投票
1 回答
25 浏览
提问于 2025-04-12 18:24

这个错误发生在第7行的函数里。

def visualize_embedding(h, color, epoch=None, loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().clone().cpu().numpy()
    print(type(h))
    plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()

错误信息:

<class 'numpy.ndarray'>
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[17], line 21
     19 loss, h = train(data)
     20 if epoch % 10 == 0:
---> 21     visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
     22     time.sleep(0.3)

Cell In[16], line 16
     14 h = h.detach().clone().cpu().numpy()
     15 print(type(h))
---> 16 plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
     17 if epoch is not None and loss is not None:
     18     plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)

File c:\Users\polyu\Documents\RA\hkjc_dm\hkjc_dm\model\src\venvModel4\lib\site-packages\matplotlib\pyplot.py:3684, in scatter(x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, edgecolors, plotnonfinite, data, **kwargs)
   3665 @_copy_docstring_and_deprecators(Axes.scatter)
   3666 def scatter(
   3667     x: float | ArrayLike,
   (...)
   3682     **kwargs,
   3683 ) -> PathCollection:
-> 3684     __ret = gca().scatter(
   3685         x,
   3686         y,
...
   1030     return self.numpy()
   1031 else:
-> 1032     return self.numpy().astype(dtype, copy=False)

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

h已经是ndarray(就是一种数据结构),为什么还会给我一个转换成cuda张量的错误呢?顺便说一下,h的形状是[batch_size, 2]。

1 个回答

1

我猜这个错误可能不是由 h 引起的,可能是颜色的问题!你可以检查一下 data.y 是否是一个 GPU 张量,如果是的话,你可以像处理 h 一样,对它使用 detach/cpu/numpy 这些方法。

撰写回答