在PyTorch中,grad_fn属性究竟存储了什么?它是如何使用的?

2024-05-23 15:27:29 发布

您现在位置:Python中文网/ 问答频道 /正文

在PyTorch中,Tensor类具有grad_fn属性。这引用了用于获取张量的操作:例如,如果a = b + 2a.grad_fn将是AddBackward0。但“参考”到底意味着什么

使用inspect.getmro(type(a.grad_fn))检查AddBackward0将声明AddBackward0的唯一基类是object。此外,这个类的源代码(事实上,在grad_fn中可能遇到的任何其他类)在source code中找不到

所有这些让我想到以下问题:

  1. grad_fn中究竟存储了什么?在反向传播期间如何调用它
  2. 为什么存储在grad_fn中的对象没有某种常见的超类,为什么GitHub上没有这些对象的源代码

Tags: 对象声明source属性object源代码typepytorch
1条回答
网友
1楼 · 发布于 2024-05-23 15:27:29

grad_fn是一个函数“句柄”,用于访问适用的渐变函数。给定点处的梯度是用于在反向传播期间调整权重的系数

“句柄”是对象描述符的通用术语,旨在为对象提供适当的访问权限。例如,当您打开一个文件时,open返回一个文件句柄。实例化类时,__init__函数将返回所创建实例的句柄。句柄包含对相关项的数据和函数的引用(通常是内存地址)

它显示为泛型object类,因为它来自另一种语言的底层实现,因此它不会准确地映射到Python function类型。PyTorch处理跨语言调用和返回。此移交是预编译(共享对象)运行时系统的一部分

这足以澄清你所看到的吗

相关问题 更多 >