PyTorch的有用附加层。
torchmore的Python项目详细描述
火炬
torchmore
库是一个包含层和实用程序的小库
用于为图像识别、OCR和其他应用程序编写Pythorch模型。在
弹性
flex
库执行简单的大小推断。它通过将各个层包装在一个包装器中来实现,该包装器仅在维度数据可用时实例化该层。包装器可以在以后移除,模型变成一个只有完全标准模块的模型。看起来像这样:
from torch import nn
from torchmore import layers, flex
noutput = 10
model = nn.Sequential(
layers.Input("BDHW"),
flex.Conv2d(100),
flex.BatchNorm(),
nn.ReLU(),
flex.Conv2d(100),
flex.BatchNorm(),
nn.ReLU(),
layers.Reshape([1, [2, 3, 4]]),
flex.Full(100),
flex.BatchNorm(),
nn.ReLU(),
flex.Full(noutput)
)
flex.shape_inference(model, (1, 1, 28, 28))
flex
库现在为以下层提供包装:
Linear
Conv1d
,Conv2d
,Conv3d
ConvTranspose1d
,ConvTranspose2d
,ConvTranspose3d
LSTM
,BDL_LSTM
,BDHW_LSTM
BatchNorm1d
,BatchNorm2d
,BatchNorm3d
BatchNorm
您可以直接使用Flex
。以下两层相同:
也就是说,您可以轻松地将任何层转换为Flex
层,即使它不在库中。在
层
在图层。输入在
Input
层是一个方便的小层,它重新排列输入维度、检查大小范围和值范围,并自动将数据传输到运行模型的当前设备。在
例如,考虑以下Input
层:
layers.Input("BHWD", "BDHW", range=(0, 1), sizes=[None, 1, None, None]),
上面写着:
- 重新排序的bhd是“为了得到”输入
- 输入值必须在$[0,1]$
- 输入张量必须为$D=1$
- 输入张量被传输到与模型权重相同的设备上
.order
属性
注意,如果输入张量有一个.order
属性,那么它将用于将输入维度重新排序为所需的维度。此模型允许接受多个输入。考虑
model = nn.Sequential(
layers.Input("BHWD", "BDHW", range=(0, 1), sizes=[None, 1, None, None]),
...
)
a = torch.rand((1, 100, 150, 1))
b = a.permute(0, 3, 1, 2)
b.order = "BDHW"
assert model(a) == model(b)
在重新排列图层在
Reorder
层像Tensor.permute
那样对轴重新排序,但这样做的方式可以更好地记录正在发生的事情。考虑以下代码片段:
layers.Reorder("BDL", "LBD"),
flex.LSTM(100, bidirectional=True),
layers.Reorder("LBD", "BDL"),
flex.Conv1d(noutput, 1),
layers.Reorder("BDL", "BLD")
字母本身是任意的,但常见的选择是“BDLHW”。这可能比一系列排列更清楚。在
在层次感。有趣在
对于基于模块的网络,添加功能非常方便。Fun
层允许这样做,如:
layers.Fun("lambda x: x.permute(2, 0, 1)")
注意,由于函数被指定为字符串,因此可以对其进行pickle处理。在
LSTM层
layers.LSTM
:一个简单的LSTM层,它只是对状态输出进行dicardlayers.BDL_LSTM
:一个LSTM变量,它是Conv1d
层的替换layers.BDHW_LSTM
:MDLSTM变体,它是Conv2d
层的直接替换layers.BDHW_LSTM_to_BDH
:按行排列的LSTM,将维度减少1
其他层
这些可能偶尔有用:
layers.Info(info="", every=1000000)
:打印有关激活的信息layers.CheckSizes(...)
:检查通过的张量的大小layers.CheckRange(...)
:检查值的范围layers.Permute(...)
:轴排列(如x.permute)layers.Reshape(...)
:张量整形,可选择组合轴layers.View(...)
:相当于x.viewlayers.Parallel
:并行运行两个模块并将结果堆叠layers.SimplePooling2d
:已打包的最大池/未冷却layers.AcrossPooling2d
:带卷积的最大池化/非制冷
- 项目
标签: