Summaries for all PyTorch models, similar to model.summary() in Keras

Detailed description of the modelsummary Python project


modelsummary (PyTorch model summary)

Keras-style model.summary() in PyTorch (cf. torchsummary)

This is a PyTorch library that improves on the visualization offered by torchsummary and torchsummaryX. Inspired by torchsummary, I wrote the code presented here. It works regardless of the number of input arguments!
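
Because summary() takes its inputs variadically (*inputs), a model whose forward() expects several tensors can be summarized by passing each tensor as its own positional argument. A minimal sketch, where TwoInputNet is a hypothetical model written only for this illustration:

import torch
import torch.nn as nn
from modelsummary import summary

# Hypothetical two-input model, used only to show that summary()
# accepts a variable number of input tensors.
class TwoInputNet(nn.Module):
    def __init__(self):
        super(TwoInputNet, self).__init__()
        self.fc_a = nn.Linear(16, 8)
        self.fc_b = nn.Linear(32, 8)

    def forward(self, a, b):
        return self.fc_a(a) + self.fc_b(b)

# Each input tensor is its own positional argument.
summary(TwoInputNet(), torch.zeros(1, 16), torch.zeros(1, 32), show_input=True)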

import torch
import torch.nn as nn
import torch.nn.functional as F
from modelsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# show input shape
summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True)

# show output shape
summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False)
-----------------------------------------------------------------------
             Layer (type)                Input Shape         Param #
=======================================================================
                 Conv2d-1            [-1, 1, 28, 28]             260
                 Conv2d-2           [-1, 10, 12, 12]           5,020
              Dropout2d-3             [-1, 20, 8, 8]               0
                 Linear-4                  [-1, 320]          16,050
                 Linear-5                   [-1, 50]             510
=======================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
-----------------------------------------------------------------------

-----------------------------------------------------------------------
             Layer (type)               Output Shape         Param #
=======================================================================
                 Conv2d-1           [-1, 10, 24, 24]             260
                 Conv2d-2             [-1, 20, 8, 8]           5,020
              Dropout2d-3             [-1, 20, 8, 8]               0
                 Linear-4                   [-1, 50]          16,050
                 Linear-5                   [-1, 10]             510
=======================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
-----------------------------------------------------------------------
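
These parameter counts can be verified by hand from the layer definitions (for a conv layer: kernel area × input channels + 1 bias, per output channel):

conv1 = (5 * 5 * 1 + 1) * 10       # = 260
conv2 = (5 * 5 * 10 + 1) * 20      # = 5,020
fc1 = 320 * 50 + 50                # = 16,050
fc2 = 50 * 10 + 10                 # = 510
assert conv1 + conv2 + fc1 + fc2 == 21840   # matches "Total params" above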

Quick start

Just install modelsummary with pip.

Download

pip install modelsummary

You can use the library as shown below; pass whatever input tensor or tensors your model's forward() expects. To see more detail, refer to the example code above.

from modelsummary import summary

model = your_model_name()

# show input shape
summary(model, input_tensor, show_input=True)

# show output shape
summary(model, input_tensor, show_input=False)

# show hierarchical struct
summary(model, input_tensor, show_hierarchical=True)

The summary function takes the following parameters:

def summary(model, *inputs, batch_size=-1, show_input=True, show_hierarchical=False)

Options

  • model: your model class (an nn.Module instance)
  • *inputs: the input tensor data (variadic, hence the asterisk)
  • batch_size: batch size shown in the summary; -1 keeps the input tensor's own batch dimension (default: -1)
  • show_input: show input-shape data; if this parameter is False, output shapes are shown instead (default: True)
  • show_hierarchical: also print the hierarchical module structure (default: False); the sketch below combines these options
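
Combining these options on the Net model from the first example might look like the following sketch (one assumption about display behavior: batch_size is taken to replace the -1 batch dimension in the printed table, as in torchsummary):

import torch
from modelsummary import summary

model = Net()                          # the CNN defined earlier
dummy = torch.zeros((1, 1, 28, 28))    # one MNIST-sized input

# Input shapes, with the batch dimension presumably printed as 64 instead of -1
summary(model, dummy, batch_size=64, show_input=True)

# Output shapes plus the hierarchical module tree in one call
summary(model, dummy, show_input=False, show_hierarchical=True)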

Results

Example run with the Transformer model from the "Attention Is All You Need" paper (2017).
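
The tables and the hierarchical view below show a sequence length of 5 and source/target Embedding tables of size 6 and 7, so the two inputs were presumably integer token tensors along these lines (a hedged reconstruction; only the names enc_inputs and dec_inputs come from the calls below):

import torch

enc_inputs = torch.randint(0, 6, (1, 5))   # source token ids, seq_len = 5
dec_inputs = torch.randint(0, 7, (1, 5))   # target token ids, seq_len = 5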

  1. Show input shape
# show input shape
summary(model, enc_inputs, dec_inputs, show_input=True)

-----------------------------------------------------------------------
             Layer (type)                Input Shape         Param #
=======================================================================
                Encoder-1                    [-1, 5]               0
              Embedding-2                    [-1, 5]           3,072
              Embedding-3                    [-1, 5]           3,072
           EncoderLayer-4               [-1, 5, 512]               0
     MultiHeadAttention-5               [-1, 5, 512]               0
                 Linear-6               [-1, 5, 512]         262,656
                 Linear-7               [-1, 5, 512]         262,656
                 Linear-8               [-1, 5, 512]         262,656
  PoswiseFeedForwardNet-9               [-1, 5, 512]               0
                Conv1d-10               [-1, 512, 5]       1,050,624
                Conv1d-11              [-1, 2048, 5]       1,049,088
          EncoderLayer-12               [-1, 5, 512]               0
    MultiHeadAttention-13               [-1, 5, 512]               0
                Linear-14               [-1, 5, 512]         262,656
                Linear-15               [-1, 5, 512]         262,656
                Linear-16               [-1, 5, 512]         262,656
 PoswiseFeedForwardNet-17               [-1, 5, 512]               0
                Conv1d-18               [-1, 512, 5]       1,050,624
                Conv1d-19              [-1, 2048, 5]       1,049,088
          EncoderLayer-20               [-1, 5, 512]               0
    MultiHeadAttention-21               [-1, 5, 512]               0
                Linear-22               [-1, 5, 512]         262,656
                Linear-23               [-1, 5, 512]         262,656
                Linear-24               [-1, 5, 512]         262,656
 PoswiseFeedForwardNet-25               [-1, 5, 512]               0
                Conv1d-26               [-1, 512, 5]       1,050,624
                Conv1d-27              [-1, 2048, 5]       1,049,088
          EncoderLayer-28               [-1, 5, 512]               0
    MultiHeadAttention-29               [-1, 5, 512]               0
                Linear-30               [-1, 5, 512]         262,656
                Linear-31               [-1, 5, 512]         262,656
                Linear-32               [-1, 5, 512]         262,656
 PoswiseFeedForwardNet-33               [-1, 5, 512]               0
                Conv1d-34               [-1, 512, 5]       1,050,624
                Conv1d-35              [-1, 2048, 5]       1,049,088
          EncoderLayer-36               [-1, 5, 512]               0
    MultiHeadAttention-37               [-1, 5, 512]               0
                Linear-38               [-1, 5, 512]         262,656
                Linear-39               [-1, 5, 512]         262,656
                Linear-40               [-1, 5, 512]         262,656
 PoswiseFeedForwardNet-41               [-1, 5, 512]               0
                Conv1d-42               [-1, 512, 5]       1,050,624
                Conv1d-43              [-1, 2048, 5]       1,049,088
          EncoderLayer-44               [-1, 5, 512]               0
    MultiHeadAttention-45               [-1, 5, 512]               0
                Linear-46               [-1, 5, 512]         262,656
                Linear-47               [-1, 5, 512]         262,656
                Linear-48               [-1, 5, 512]         262,656
 PoswiseFeedForwardNet-49               [-1, 5, 512]               0
                Conv1d-50               [-1, 512, 5]       1,050,624
                Conv1d-51              [-1, 2048, 5]       1,049,088
               Decoder-52                    [-1, 5]               0
             Embedding-53                    [-1, 5]           3,584
             Embedding-54                    [-1, 5]           3,072
          DecoderLayer-55               [-1, 5, 512]               0
    MultiHeadAttention-56               [-1, 5, 512]               0
                Linear-57               [-1, 5, 512]         262,656
                Linear-58               [-1, 5, 512]         262,656
                Linear-59               [-1, 5, 512]         262,656
    MultiHeadAttention-60               [-1, 5, 512]               0
                Linear-61               [-1, 5, 512]         262,656
                Linear-62               [-1, 5, 512]         262,656
                Linear-63               [-1, 5, 512]         262,656
 PoswiseFeedForwardNet-64               [-1, 5, 512]               0
                Conv1d-65               [-1, 512, 5]       1,050,624
                Conv1d-66              [-1, 2048, 5]       1,049,088
          DecoderLayer-67               [-1, 5, 512]               0
    MultiHeadAttention-68               [-1, 5, 512]               0
                Linear-69               [-1, 5, 512]         262,656
                Linear-70               [-1, 5, 512]         262,656
                Linear-71               [-1, 5, 512]         262,656
    MultiHeadAttention-72               [-1, 5, 512]               0
                Linear-73               [-1, 5, 512]         262,656
                Linear-74               [-1, 5, 512]         262,656
                Linear-75               [-1, 5, 512]         262,656
 PoswiseFeedForwardNet-76               [-1, 5, 512]               0
                Conv1d-77               [-1, 512, 5]       1,050,624
                Conv1d-78              [-1, 2048, 5]       1,049,088
          DecoderLayer-79               [-1, 5, 512]               0
    MultiHeadAttention-80               [-1, 5, 512]               0
                Linear-81               [-1, 5, 512]         262,656
                Linear-82               [-1, 5, 512]         262,656
                Linear-83               [-1, 5, 512]         262,656
    MultiHeadAttention-84               [-1, 5, 512]               0
                Linear-85               [-1, 5, 512]         262,656
                Linear-86               [-1, 5, 512]         262,656
                Linear-87               [-1, 5, 512]         262,656
 PoswiseFeedForwardNet-88               [-1, 5, 512]               0
                Conv1d-89               [-1, 512, 5]       1,050,624
                Conv1d-90              [-1, 2048, 5]       1,049,088
          DecoderLayer-91               [-1, 5, 512]               0
    MultiHeadAttention-92               [-1, 5, 512]               0
                Linear-93               [-1, 5, 512]         262,656
                Linear-94               [-1, 5, 512]         262,656
                Linear-95               [-1, 5, 512]         262,656
    MultiHeadAttention-96               [-1, 5, 512]               0
                Linear-97               [-1, 5, 512]         262,656
                Linear-98               [-1, 5, 512]         262,656
                Linear-99               [-1, 5, 512]         262,656
PoswiseFeedForwardNet-100               [-1, 5, 512]               0
               Conv1d-101               [-1, 512, 5]       1,050,624
               Conv1d-102              [-1, 2048, 5]       1,049,088
         DecoderLayer-103               [-1, 5, 512]               0
   MultiHeadAttention-104               [-1, 5, 512]               0
               Linear-105               [-1, 5, 512]         262,656
               Linear-106               [-1, 5, 512]         262,656
               Linear-107               [-1, 5, 512]         262,656
   MultiHeadAttention-108               [-1, 5, 512]               0
               Linear-109               [-1, 5, 512]         262,656
               Linear-110               [-1, 5, 512]         262,656
               Linear-111               [-1, 5, 512]         262,656
PoswiseFeedForwardNet-112               [-1, 5, 512]               0
               Conv1d-113               [-1, 512, 5]       1,050,624
               Conv1d-114              [-1, 2048, 5]       1,049,088
         DecoderLayer-115               [-1, 5, 512]               0
   MultiHeadAttention-116               [-1, 5, 512]               0
               Linear-117               [-1, 5, 512]         262,656
               Linear-118               [-1, 5, 512]         262,656
               Linear-119               [-1, 5, 512]         262,656
   MultiHeadAttention-120               [-1, 5, 512]               0
               Linear-121               [-1, 5, 512]         262,656
               Linear-122               [-1, 5, 512]         262,656
               Linear-123               [-1, 5, 512]         262,656
PoswiseFeedForwardNet-124               [-1, 5, 512]               0
               Conv1d-125               [-1, 512, 5]       1,050,624
               Conv1d-126              [-1, 2048, 5]       1,049,088
               Linear-127               [-1, 5, 512]           3,584
=======================================================================
Total params: 39,396,352
Trainable params: 39,390,208
Non-trainable params: 6,144
-----------------------------------------------------------------------
  2. Show output shape
# show output shape
summary(model, enc_inputs, dec_inputs, show_input=False)

-----------------------------------------------------------------------
             Layer (type)               Output Shape         Param #
=======================================================================
              Embedding-1               [-1, 5, 512]           3,072
              Embedding-2               [-1, 5, 512]           3,072
                 Linear-3               [-1, 5, 512]         262,656
                 Linear-4               [-1, 5, 512]         262,656
                 Linear-5               [-1, 5, 512]         262,656
     MultiHeadAttention-6              [-1, 8, 5, 5]               0
                 Conv1d-7              [-1, 2048, 5]       1,050,624
                 Conv1d-8               [-1, 512, 5]       1,049,088
  PoswiseFeedForwardNet-9               [-1, 5, 512]               0
          EncoderLayer-10              [-1, 8, 5, 5]               0
                Linear-11               [-1, 5, 512]         262,656
                Linear-12               [-1, 5, 512]         262,656
                Linear-13               [-1, 5, 512]         262,656
    MultiHeadAttention-14              [-1, 8, 5, 5]               0
                Conv1d-15              [-1, 2048, 5]       1,050,624
                Conv1d-16               [-1, 512, 5]       1,049,088
 PoswiseFeedForwardNet-17               [-1, 5, 512]               0
          EncoderLayer-18              [-1, 8, 5, 5]               0
                Linear-19               [-1, 5, 512]         262,656
                Linear-20               [-1, 5, 512]         262,656
                Linear-21               [-1, 5, 512]         262,656
    MultiHeadAttention-22              [-1, 8, 5, 5]               0
                Conv1d-23              [-1, 2048, 5]       1,050,624
                Conv1d-24               [-1, 512, 5]       1,049,088
 PoswiseFeedForwardNet-25               [-1, 5, 512]               0
          EncoderLayer-26              [-1, 8, 5, 5]               0
                Linear-27               [-1, 5, 512]         262,656
                Linear-28               [-1, 5, 512]         262,656
                Linear-29               [-1, 5, 512]         262,656
    MultiHeadAttention-30              [-1, 8, 5, 5]               0
                Conv1d-31              [-1, 2048, 5]       1,050,624
                Conv1d-32               [-1, 512, 5]       1,049,088
 PoswiseFeedForwardNet-33               [-1, 5, 512]               0
          EncoderLayer-34              [-1, 8, 5, 5]               0
                Linear-35               [-1, 5, 512]         262,656
                Linear-36               [-1, 5, 512]         262,656
                Linear-37               [-1, 5, 512]         262,656
    MultiHeadAttention-38              [-1, 8, 5, 5]               0
                Conv1d-39              [-1, 2048, 5]       1,050,624
                Conv1d-40               [-1, 512, 5]       1,049,088
 PoswiseFeedForwardNet-41               [-1, 5, 512]               0
          EncoderLayer-42              [-1, 8, 5, 5]               0
                Linear-43               [-1, 5, 512]         262,656
                Linear-44               [-1, 5, 512]         262,656
                Linear-45               [-1, 5, 512]         262,656
    MultiHeadAttention-46              [-1, 8, 5, 5]               0
                Conv1d-47              [-1, 2048, 5]       1,050,624
                Conv1d-48               [-1, 512, 5]       1,049,088
 PoswiseFeedForwardNet-49               [-1, 5, 512]               0
          EncoderLayer-50              [-1, 8, 5, 5]               0
               Encoder-51              [-1, 8, 5, 5]               0
             Embedding-52               [-1, 5, 512]           3,584
             Embedding-53               [-1, 5, 512]           3,072
                Linear-54               [-1, 5, 512]         262,656
                Linear-55               [-1, 5, 512]         262,656
                Linear-56               [-1, 5, 512]         262,656
    MultiHeadAttention-57              [-1, 8, 5, 5]               0
                Linear-58               [-1, 5, 512]         262,656
                Linear-59               [-1, 5, 512]         262,656
                Linear-60               [-1, 5, 512]         262,656
    MultiHeadAttention-61              [-1, 8, 5, 5]               0
                Conv1d-62              [-1, 2048, 5]       1,050,624
                Conv1d-63               [-1, 512, 5]       1,049,088
 PoswiseFeedForwardNet-64               [-1, 5, 512]               0
          DecoderLayer-65              [-1, 8, 5, 5]               0
                Linear-66               [-1, 5, 512]         262,656
                Linear-67               [-1, 5, 512]         262,656
                Linear-68               [-1, 5, 512]         262,656
    MultiHeadAttention-69              [-1, 8, 5, 5]               0
                Linear-70               [-1, 5, 512]         262,656
                Linear-71               [-1, 5, 512]         262,656
                Linear-72               [-1, 5, 512]         262,656
    MultiHeadAttention-73              [-1, 8, 5, 5]               0
                Conv1d-74              [-1, 2048, 5]       1,050,624
                Conv1d-75               [-1, 512, 5]       1,049,088
 PoswiseFeedForwardNet-76               [-1, 5, 512]               0
          DecoderLayer-77              [-1, 8, 5, 5]               0
                Linear-78               [-1, 5, 512]         262,656
                Linear-79               [-1, 5, 512]         262,656
                Linear-80               [-1, 5, 512]         262,656
    MultiHeadAttention-81              [-1, 8, 5, 5]               0
                Linear-82               [-1, 5, 512]         262,656
                Linear-83               [-1, 5, 512]         262,656
                Linear-84               [-1, 5, 512]         262,656
    MultiHeadAttention-85              [-1, 8, 5, 5]               0
                Conv1d-86              [-1, 2048, 5]       1,050,624
                Conv1d-87               [-1, 512, 5]       1,049,088
 PoswiseFeedForwardNet-88               [-1, 5, 512]               0
          DecoderLayer-89              [-1, 8, 5, 5]               0
                Linear-90               [-1, 5, 512]         262,656
                Linear-91               [-1, 5, 512]         262,656
                Linear-92               [-1, 5, 512]         262,656
    MultiHeadAttention-93              [-1, 8, 5, 5]               0
                Linear-94               [-1, 5, 512]         262,656
                Linear-95               [-1, 5, 512]         262,656
                Linear-96               [-1, 5, 512]         262,656
    MultiHeadAttention-97              [-1, 8, 5, 5]               0
                Conv1d-98              [-1, 2048, 5]       1,050,624
                Conv1d-99               [-1, 512, 5]       1,049,088
PoswiseFeedForwardNet-100               [-1, 5, 512]               0
         DecoderLayer-101              [-1, 8, 5, 5]               0
               Linear-102               [-1, 5, 512]         262,656
               Linear-103               [-1, 5, 512]         262,656
               Linear-104               [-1, 5, 512]         262,656
   MultiHeadAttention-105              [-1, 8, 5, 5]               0
               Linear-106               [-1, 5, 512]         262,656
               Linear-107               [-1, 5, 512]         262,656
               Linear-108               [-1, 5, 512]         262,656
   MultiHeadAttention-109              [-1, 8, 5, 5]               0
               Conv1d-110              [-1, 2048, 5]       1,050,624
               Conv1d-111               [-1, 512, 5]       1,049,088
PoswiseFeedForwardNet-112               [-1, 5, 512]               0
         DecoderLayer-113              [-1, 8, 5, 5]               0
               Linear-114               [-1, 5, 512]         262,656
               Linear-115               [-1, 5, 512]         262,656
               Linear-116               [-1, 5, 512]         262,656
   MultiHeadAttention-117              [-1, 8, 5, 5]               0
               Linear-118               [-1, 5, 512]         262,656
               Linear-119               [-1, 5, 512]         262,656
               Linear-120               [-1, 5, 512]         262,656
   MultiHeadAttention-121              [-1, 8, 5, 5]               0
               Conv1d-122              [-1, 2048, 5]       1,050,624
               Conv1d-123               [-1, 512, 5]       1,049,088
PoswiseFeedForwardNet-124               [-1, 5, 512]               0
         DecoderLayer-125              [-1, 8, 5, 5]               0
              Decoder-126              [-1, 8, 5, 5]               0
               Linear-127                 [-1, 5, 7]           3,584
=======================================================================
Total params: 39,396,352
Trainable params: 39,390,208
Non-trainable params: 6,144
-----------------------------------------------------------------------
  3. Show hierarchical summary
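The tree below would be produced by a call of this form (inferred from the usage section above; the original page omits this snippet):

# show hierarchical struct
summary(model, enc_inputs, dec_inputs, show_hierarchical=True)
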
Transformer(
  (encoder): Encoder(
    (src_emb): Embedding(6, 512), 3,072 params
    (pos_emb): Embedding(6, 512), 3,072 params
    (layers): ModuleList(
      (0): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (1): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (2): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (3): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (4): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (5): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
    ), 17,326,080 params
  ), 17,332,224 params
  (decoder): Decoder(
    (tgt_emb): Embedding(7, 512), 3,584 params
    (pos_emb): Embedding(6, 512), 3,072 params
    (layers): ModuleList(
      (0): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (1): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (2): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (3): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (4): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (5): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
    ), 22,053,888 params
  ), 22,060,544 params
  (projection): Linear(in_features=512, out_features=7, bias=False), 3,584 params
), 39,396,352 params
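
The same hand arithmetic applies here (a check, not library output): each 512 → 512 projection has 512 × 512 weights plus 512 biases, and the non-trainable total matches the two positional embedding tables exactly:

w_q = 512 * 512 + 512              # = 262,656, each W_Q/W_K/W_V above
conv1 = 2048 * 512 * 1 + 2048      # = 1,050,624, Conv1d(512, 2048, kernel_size=1)
pos_emb = 6 * 512                  # = 3,072, one positional Embedding(6, 512)
assert 2 * pos_emb == 6144         # matches "Non-trainable params"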

References

code_reference = {
    'https://github.com/pytorch/pytorch/issues/2001',
    'https://gist.github.com/HTLife/b6640af9d6e7d765411f8aa9aa94b837',
    'https://github.com/sksq96/pytorch-summary',
    'Inspired by https://github.com/sksq96/pytorch-summary',
}
