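<p>I found a workaround to get a proper summary and plot for a model built with the model subclassing API. For obvious reasons, a <strong>subclassed</strong> model does not give a <code>model.summary()</code> as informative as the <strong>Sequential or Functional</strong> API does, nor a nice visualization with <code>plot_model</code>. Here I'll demonstrate both.</p>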
<pre class="lang-py prettyprint-override"><code>import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras import layers as L
from tensorflow.keras.applications import VGG16

class my_model(Model):
    def __init__(self, dim):
        super(my_model, self).__init__()
        self.dim = dim  # keep the input shape so build_graph() doesn't rely on a global
        self.Base = VGG16(input_shape=dim, include_top=False, weights='imagenet')
        self.GAP = L.GlobalAveragePooling2D()
        self.BAT = L.BatchNormalization()
        self.DROP = L.Dropout(rate=0.1)
        self.DENS = L.Dense(256, activation='relu', name='dense_A')
        self.OUT = L.Dense(1, activation='sigmoid')

    def call(self, inputs):
        x = self.Base(inputs)
        g = self.GAP(x)
        b = self.BAT(g)
        d = self.DROP(b)
        d = self.DENS(d)
        return self.OUT(d)

    # AFAIK: the most convenient way to print model.summary()
    # the way the Sequential or Functional API does.
    def build_graph(self):
        # Wrap call() in a Functional Model built on a symbolic Input
        x = Input(shape=self.dim)
        return Model(inputs=[x], outputs=self.call(x))

dim = (124, 124, 3)
model = my_model(dim)
model.build((None, *dim))
model.build_graph().summary()
</code></pre>
<p>This produces the following output:</p>
<pre><code>Layer (type)                 Output Shape              Param #
=================================================================
input_67 (InputLayer)        [(None, 124, 124, 3)]     0
_________________________________________________________________
vgg16 (Functional)           (None, 3, 3, 512)         14714688
_________________________________________________________________
global_average_pooling2d_32  (None, 512)               0
_________________________________________________________________
batch_normalization_7 (Batch (None, 512)               2048
_________________________________________________________________
dropout_5 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_A (Dense)              (None, 256)               131328
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 257
=================================================================
Total params: 14,848,321
Trainable params: 14,847,297
Non-trainable params: 1,024
</code></pre>
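<p>A quick way to double-check the numbers in this summary (including the <code>(None, 3, 3, 512)</code> shape of the VGG16 output) is to recompute them by hand. The sketch below is plain Python and assumes the standard VGG16 layout without its top (13 3x3 convolutions with 'same' padding, followed in each block by a 2x2, stride-2 max-pooling layer) and the usual parameter formulas for <code>Dense</code> and <code>BatchNormalization</code>; the helper function names are made up for this illustration.</p>

```python
# Sanity check of the model.summary() numbers, using standard
# parameter-count formulas (no TensorFlow needed).

def conv2d_params(k, c_in, c_out):
    """k x k convolution: one k*k*c_in kernel per output filter, plus biases."""
    return k * k * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """Fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out

def batchnorm_params(channels):
    """gamma, beta (trainable) + moving mean, moving variance (non-trainable)."""
    return 4 * channels, 2 * channels  # (total, non-trainable)

# VGG16 without its top: 13 conv layers with these (in, out) channels.
vgg16_convs = [(3, 64), (64, 64),
               (64, 128), (128, 128),
               (128, 256), (256, 256), (256, 256),
               (256, 512), (512, 512), (512, 512),
               (512, 512), (512, 512), (512, 512)]
vgg16 = sum(conv2d_params(3, i, o) for i, o in vgg16_convs)

# Five 2x2 max-pooling stages ('valid' padding, stride 2) floor-halve the side.
side = 124
for _ in range(5):
    side //= 2

bn_total, bn_frozen = batchnorm_params(512)
dense_a = dense_params(512, 256)
dense_out = dense_params(256, 1)
total = vgg16 + bn_total + dense_a + dense_out

print(side)                                 # 3
print(vgg16, bn_total, dense_a, dense_out)  # 14714688 2048 131328 257
print(total, total - bn_frozen, bn_frozen)  # 14848321 14847297 1024
```

<p>The 1,024 non-trainable parameters are exactly the moving mean and moving variance of the <code>BatchNormalization</code> layer; everything else, including the VGG16 base, is trainable because the base was not frozen.</p>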
<p>Now, using the <code>build_graph</code> method, we can simply plot the whole architecture:</p>
<pre class="lang-py prettyprint-override"><code>import tensorflow as tf

# Showing the commonly used arguments for newcomers.
# Note: plot_model needs the pydot package and Graphviz installed.
tf.keras.utils.plot_model(
    model.build_graph(),                      # here is the trick (for now)
    to_file='model.png', dpi=96,              # saving
    show_shapes=True, show_layer_names=True,  # show shapes and layer names
    expand_nested=False                       # set True to expand nested blocks such as vgg16
)
</code></pre>
<p>This produces the following result :-)</p>
<p><img src="https://user-images.githubusercontent.com/17668390/93187371-8e545000-f761-11ea-8d70-74dc2fe7c644.png" alt="a"/></p>