ShapeGuard允许您以一种动态的、受einsum启发的方式非常简洁地断言张量的预期形状。
torch-shapeguard的Python项目详细描述
形卫兵
ShapeGuard允许您以一种动态的、受einsum启发的方式非常简洁地断言张量的预期形状
在ml中生成bug很容易。一个特别丰富的bug源是由于操作符的灵活性:a*b
不管a和b是向量、标量向量、向量向量等等,都可以工作。同样地,{
我发现避免错误的最好方法就是一直虔诚地检查我所有张量的形状,所以我最终花了很多时间调试并到处写诸如#(bs, n_samples, z_size)
之类的评论。在
那么为什么不通过算法检查形状呢?它很快就变丑了。在
您必须在任何地方添加assert foo.shape == (bs, n_samples, x_size)
,这实际上使您的行数和
你必须定义你所有的尺寸尺寸(bs等),这些尺寸可能会随列车/测试、批次等的不同而变化。
所以我做了个小帮手让它变得更好。我叫它ShapeGuard。导入时,它将sg
方法添加到torch.Tensor
和{ShapeGuard
类。在
您可以像使用断言一样使用sg
方法:
defforward(self,x,y):x.sg("bchw")y.sg("by")
这将验证x有4个维度,y有2个维度,并且x和y在第一个维度“b”中具有相同的大小。如果assert通过,则返回张量。这意味着您还可以在操作结果中内联使用它:
^{pr2}$如果assert失败,它将生成一条很好的错误消息。在
它的工作方式如下:第一次为一个看不见的形状调用sg时,该形状的张量大小保存在ShapeGuard.shapes
全局dict中。随后的调用会在形状dict中看到此形状,并断言该维度的张量是相同的形状。例如,如果批量大小在训练和测试之间发生变化,您可以调用ShapeGuard.reset("b")
来重置“b”形状。在
我发现通过调用ShapeGuard.reset()
来重置主nn.Module.forward
开头的所有形状效果很好。如果你想验证一个精确的尺寸,你可以传递一个整数作为形状
defforward(self,x,y):x.sg(("b",1,"h","w"))y.sg("by")
特殊形状“*”是为不应断言的形状保留的,例如,x.sg("*chw")
将断言除第一个形状之外的所有形状。在
- 项目
标签: