可视化Keras模型中各层的输出

2024-04-29 18:32:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我在Keras中有一个训练好的模型。我想加载这个模型,用一个样本进行测试,然后可视化每一层的输出。我使用Keract来实现,代码如下,但运行时产生了下面的错误,我不知道原因。请帮我解决这个错误,或者提供另一种显示各层输出的方法。提前谢谢!

import keract
from keras.datasets import mnist,cifar10
from keras.models import load_model
from keras.layers import Input, Concatenate, GaussianNoise,Cropping2D,Activation,Dropout,BatchNormalization,MaxPool2D,AveragePooling2D,ZeroPadding2D
from keras.layers import Conv2D, AtrousConv2D
from keras.models import Model
from keras.datasets import mnist,cifar10
from keras.callbacks import TensorBoard
from keras import backend as K
from keras import layers
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as Kr
from keras.optimizers import SGD,RMSprop,Adam
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
import numpy as np
import pylab as pl
import matplotlib.cm as cm
from matplotlib import pyplot
from keras import optimizers
from keras import regularizers
import scipy.io as sio
from tqdm import tqdm

from tensorflow.python.keras.layers import Lambda;
from keras.engine.topology import Layer
from keras.layers import DepthwiseConv2D
class SaltAndPepper(Layer):
    """Salt-and-pepper noise layer.

    Each call draws a selection mask shaped like one sample
    (``K.shape(inputs)[1:]``). Every position is selected with probability
    ``ratio``. Selected positions are overwritten with 0 or 1, chosen with
    equal probability. All other positions pass through unchanged. The
    per-sample masks broadcast over the batch axis, so every sample in a
    batch shares the same mask.

    The noise is applied on every call, in both training and inference.
    The ``training`` argument is accepted but not used.
    """

    def __init__(self, ratio, **kwargs):
        super(SaltAndPepper, self).__init__(**kwargs)
        self.supports_masking = True
        # Probability that a given element is replaced by noise.
        self.ratio = ratio

    def call(self, inputs, training=None):
        # NOTE: to restrict the noise to training only, the alternative is
        # K.in_train_phase(<noised>, inputs, training=training).
        sample_shape = K.shape(inputs)[1:]
        hit = K.random_binomial(shape=sample_shape, p=self.ratio)
        # 0 (pepper) or 1 (salt) with equal probability.
        salt_or_pepper = K.random_binomial(shape=sample_shape, p=0.5)
        kept = inputs * (1 - hit)
        return kept + salt_or_pepper * hit

    def get_config(self):
        """Return the base layer config extended with ``ratio``."""
        merged = dict(super(SaltAndPepper, self).get_config())
        merged.update({'ratio': self.ratio})
        return merged



# ----------------------- cover-image encoder -----------------------------------
# Input: a (32, 32, 1) image. Four Conv2D -> BatchNormalization -> ELU stages.
# padding='same' with the default stride of 1 keeps the 32x32 spatial size, and
# the last conv has 1 filter, so `acte` is (batch, 32, 32, 1).
# The commented-out `Kr.layers.ReLU` lines are an earlier ReLU variant of each
# activation.
image = Input((32, 32, 1),name='input')
# dilation_rate has no effect on a 1x1 kernel; kept as-is to match the trained model.
conv1 = Conv2D(64, (1, 1),padding='same', name='enc_conv1',dilation_rate=(2,2))(image)
bncv1=BatchNormalization(name='enc_bn1')(conv1)
# act1=Kr.layers.ReLU(name='enc_ac1')(bncv1)
# NOTE(review): this layer is named 'ac1' rather than following the 'enc_acN'
# pattern; renaming changes the layer name, so only do it if weights are never
# loaded by name.
act1=Activation('elu',name='ac1')(bncv1)

conv2 = Conv2D(64, (5, 5),padding='same', name='enc_conv2',dilation_rate=(2,2))(act1)
bncv2=BatchNormalization(name='enc_bn2')(conv2)
# act2=Kr.layers.ReLU(name='enc_ac2')(bncv2)
act2=Activation('elu',name='enc_ac2')(bncv2)

conv3 = Conv2D(64, (5, 5), padding='same', name='enc_conv3',dilation_rate=(2,2))(act2)
bncv3=BatchNormalization(name='enc_bch3')(conv3)
# act3=Kr.layers.ReLU(name='enc_ac3')(bncv3)
act3=Activation('elu',name='enc_ac3')(bncv3)

#DrO1=Dropout(0.25,name='Dro1')(BN)
# Collapse to a single channel: (batch, 32, 32, 1).
encoded =  Conv2D(1, (5, 5), padding='same',name='encoded_I',dilation_rate=(2,2))(act3)
bncve=BatchNormalization(name='enc_bch4')(encoded)
# acte=Kr.layers.ReLU(name='enc_ac4')(bncve)
acte=Activation('elu',name='enc_ac4')(bncve)

#-----------------------passing wtm to a network-------------------------------
# Watermark branch. Input: a (4, 4, 1) watermark. Four Conv2D -> BatchNormalization
# -> ELU stages with padding='same', so the spatial size stays 4x4. The last conv
# has 1 filter, so `actw` is (batch, 4, 4, 1).
wtm=Input((4,4,1),name='watermark')
conv1w = Conv2D(64, (2, 2), padding='same', name='convl1w',dilation_rate=(2,2))(wtm)
bncv4=BatchNormalization(name='enc_bch5')(conv1w)
# act4=Kr.layers.ReLU(name='enc_ac5')(bncv4)
act4=Activation('elu',name='enc_ac5')(bncv4)

conv2w = Conv2D(64, (2, 2),  padding='same', name='convl2w',dilation_rate=(2,2))(act4)
bncv5=BatchNormalization(name='enc_bch6')(conv2w)
# act5=Kr.layers.ReLU(name='enc_ac6')(bncv5)
act5=Activation('elu',name='enc_ac6')(bncv5)

conv3w = Conv2D(64, (2, 2), padding='same', name='convl3w',dilation_rate=(2,2))(act5)
bncv6=BatchNormalization(name='enc_bch7')(conv3w)
# act6=Kr.layers.ReLU(name='enc_ac7')(bncv6)
act6=Activation('elu',name='enc_ac7')(bncv6)

# Collapse to one channel (no dilation on this last conv): (batch, 4, 4, 1).
encodedw =  Conv2D(1, (2, 2), padding='same',name='encoded_w')(act6)
bncvw=BatchNormalization(name='enc_bch8')(encodedw)
# actw=Kr.layers.ReLU(name='enc_ac8')(bncvw)
actw=Activation('elu',name='enc_ac8')(bncvw)

#-----------------------adding w---------------------------------------
# Lambda calls K.tile(actw, n=(1, 8, 8, 1)): the (4, 4, 1) watermark features are
# repeated 8 times along height and width, giving (batch, 32, 32, 1) to match `acte`.
wtmN=Kr.layers.Lambda(K.tile,arguments={'n':(1,8,8,1)},name='lambda')(actw)
# Stack image and watermark features along the channel axis: (batch, 32, 32, 2).
encoded_merged = Concatenate(axis=3,name='concat')([acte, wtmN])
#-----------------------decoder------------------------------------------------
#------------------------------------------------------------------------------
# Maps the merged (32, 32, 2) tensor back to a (32, 32, 1) watermarked image.
# Despite the "deconv" names these are ordinary Conv2D layers; padding='same'
# keeps every stage at 32x32.
#deconv_input=Input((28,28,1),name='inputTodeconv')
#encoded_merged = Input((28, 28, 2))
deconv1 = Conv2D(64, (1, 1), padding='same', name='convl1d',dilation_rate=(2,2))(encoded_merged)
bncv7=BatchNormalization(name='dec_bch9')(deconv1)
# act7=Kr.layers.ReLU(name='dec_ac9')(bncv7)
act7=Activation('elu',name='dec_ac9')(bncv7)

deconv2 = Conv2D(64, (5, 5),padding='same', name='convl2d',dilation_rate=(2,2))(act7)
bncv8=BatchNormalization(name='dec_bch10')(deconv2)
# act8=Kr.layers.ReLU(name='dec_ac10')(bncv8)
act8=Activation('elu',name='dec_ac10')(bncv8)

deconv3 = Conv2D(64, (5, 5), padding='same', name='convl3d',dilation_rate=(2,2))(act8)
bncv9=BatchNormalization(name='dec_bch11')(deconv3)
# act9=Kr.layers.ReLU(name='dec_ac11')(bncv9)
act9=Activation('elu',name='dec_ac11')(bncv9)

deconv4 = Conv2D(64, (5, 5), padding='same', name='convl4d',dilation_rate=(2,2))(act9)
bncv10=BatchNormalization(name='dec_bch12')(deconv4)
# act10=Kr.layers.ReLU(name='dec_ac12')(bncv10)
act10=Activation('elu',name='dec_ac12')(bncv10)
# Single-channel reconstructed image: (batch, 32, 32, 1).
decoded = Conv2D(1, (5, 5), padding='same', name='decoder_output',dilation_rate=(2,2))(act10) 
bncv15=BatchNormalization(name='dec_bch17')(decoded)
# 'imageprim' is the model's first output (the watermarked image).
act15=Activation('elu',name='imageprim')(bncv15)
#-----------------salt-pepper --------------------------------------------
# Simulated attack: on average half of the pixels (ratio 0.5) are replaced by 0 or 1.
# Shape is unchanged: (batch, 32, 32, 1).
# NOTE(review): the noise is applied to bncv15 (before the 'imageprim' ELU), not to
# act15, so the extraction branch never sees the activated image; confirm this is intended.
decoded_noise=SaltAndPepper(0.5,name='SandP')(bncv15)
#----------------------w extraction------------------------------------
# Recovers the (4, 4, 1) watermark from the noised (32, 32, 1) image.
# These convs use Keras' default padding='valid', so each 5x5 kernel shrinks the
# height and width by 4. The trailing shape comments give the output height/width.
#convw1 = Conv2D(64, (5,5), name='conl1w')(decoded_noise)#28   (earlier 5x5 variant)
convw1 = Conv2D(64, (1,1), name='conl1w')(decoded_noise)  # 32x32 (1x1 kernel keeps size)
bncv16=BatchNormalization(name='dec_bch18')(convw1)
# act16=Kr.layers.ReLU(name='dec_ac18')(bncv16)
act16=Activation('elu',name='dec_ac18')(bncv16)

convw2 = Conv2D(64, (5,5), name='conl2w')(act16)  # 28x28
bncv17=BatchNormalization(name='dec_bch19')(convw2)
# act17=Kr.layers.ReLU(name='dec_ac19')(bncv17)
act17=Activation('elu',name='dec_ac19')(bncv17)

#Avw1=AveragePooling2D(pool_size=(2,2))(convw2)
convw3 = Conv2D(64, (5,5),name='conl3w')(act17)  # 24x24
bncv18=BatchNormalization(name='dec_bch20')(convw3)
# act18=Kr.layers.ReLU(name='dec_ac20')(bncv18)
act18=Activation('elu',name='dec_ac20')(bncv18)

# NOTE(review): unlike its siblings this conv has its own activation='relu' before
# BatchNormalization + ELU; left as-is because the saved weights were trained with it.
convw4 = Conv2D(64, (5,5), activation='relu' ,name='conl4w')(act18)  # 20x20
bncv19=BatchNormalization(name='dec_bch21')(convw4)
# act19=Kr.layers.ReLU(name='dec_ac21')(bncv19)
act19=Activation('elu',name='dec_ac21')(bncv19)

#Avw2=AveragePooling2D(pool_size=(2,2))(convw4)
convw5 = Conv2D(64, (5,5), name='conl5w')(act19)  # 16x16
bncv20=BatchNormalization(name='dec_bch22')(convw5)
# act20=Kr.layers.ReLU(name='dec_ac22')(bncv20)
act20=Activation('elu',name='dec_ac22')(bncv20)

convw6 = Conv2D(64, (5,5), name='conl6w')(act20)  # 12x12
bncv21=BatchNormalization(name='dec_bch23')(convw6)
# act21=Kr.layers.ReLU(name='dec_ac23')(bncv21)
act21=Activation('elu',name='dec_ac23')(bncv21)

convw7 = Conv2D(64, (5,5), name='conl7w')(act21)  # 8x8
bncv22=BatchNormalization(name='dec_bch24')(convw7)
# act22=Kr.layers.ReLU(name='dec_ac24')(bncv22)
act22=Activation('elu',name='dec_ac24')(bncv22)

convw8 = Conv2D(64, (5,5), name='conl8w')(act22)  # 4x4
bncv23=BatchNormalization(name='dec_bch25')(convw8)
# act23=Kr.layers.ReLU(name='dec_ac25')(bncv23)
act23=Activation('elu',name='dec_ac25')(bncv23)

# 1x1 conv to a single channel (dilation is a no-op on a 1x1 kernel): (batch, 4, 4, 1).
pred_w = Conv2D(1, (1, 1),padding='same', name='reconstructed_W',dilation_rate=(2,2))(act23)
bncv24=BatchNormalization(name='dec_bch26')(pred_w)
# Sigmoid maps the recovered watermark into (0, 1), to compare with the binary input.
act24=Activation('sigmoid', name='wprim')(bncv24)  
# Full model: inputs [image, watermark] -> outputs [watermarked image, recovered watermark].
w_extraction=Model(inputs=[image,wtm],outputs=[act15,act24])

w_extraction.summary()
# Random binary watermark: one sample shaped (1, 4, 4, 1) to match the
# `watermark` input of the model.
W = np.random.randint(low=0, high=2, size=(1, 4, 4, 1)).astype(np.float32)

img_rows = 32
img_cols = 32
(x_train_cifar, y_train_cifar), (x_test_cifar, y_test_cifar) = cifar10.load_data()
# Keep only colour channel 1 of CIFAR-10 so every image becomes (32, 32, 1),
# the shape of the `input` layer.
# NOTE(review): the first reshape of each split assumes load_data() returned
# channels_last arrays (the Keras default); confirm for your keras.json.
x_train_cifar = x_train_cifar.reshape(x_train_cifar.shape[0], img_rows, img_cols, 3)
x_train_cifar = x_train_cifar[:, :, :, 1]
x_train_cifar = x_train_cifar.reshape(x_train_cifar.shape[0], img_rows, img_cols, 1)
x_test_cifar = x_test_cifar.reshape(x_test_cifar.shape[0], img_rows, img_cols, 3)
x_test_cifar = x_test_cifar[:, :, :, 1]
x_test_cifar = x_test_cifar.reshape(x_test_cifar.shape[0], img_rows, img_cols, 1)

# Scale uint8 pixels to [0, 1]. The test split gets the same preprocessing as
# the training split so either one can be fed to the model.
x_train_cifar = x_train_cifar.astype('float32') / 255.0
x_test_cifar = x_test_cifar.astype('float32') / 255.0

w_extraction.load_weights('E:/my_weights/test_withAttack_saltAndpepper_11022020.h5')

# keract.get_activations(w_extraction, ...) fails here with
# "AssertionError: Not a model or layer!". This model is built with the
# standalone `keras` package, while the installed keract presumably type-checks
# against `tensorflow.keras` classes (verify against your keract version).
# Instead, build a multi-output Model from the *same* `keras` package whose
# outputs are every non-input layer's output, and run one forward pass on it.
probe_layers = [layer for layer in w_extraction.layers
                if type(layer).__name__ != 'InputLayer']
activation_model = Model(inputs=w_extraction.inputs,
                         outputs=[layer.output for layer in probe_layers])
sample = [x_train_cifar[8000:8001], W]
layer_outputs = activation_model.predict(sample)
# layer name -> numpy activation, in model order (dicts keep insertion order on 3.7+).
activations = {layer.name: output
               for layer, output in zip(probe_layers, layer_outputs)}
# keract's plotting helper only needs this name -> array mapping.
keract.display_activations(activations)

错误是:

activations = keract.get_activations(w_extraction, [x_train_cifar[8000:8001],W])
Traceback (most recent call last):

  File "<ipython-input-15-b5e1693c12be>", line 1, in <module>
    activations = keract.get_activations(w_extraction, [x_train_cifar[8000:8001],W])

  File "D:\software\Anaconda3\envs\py37\lib\site-packages\keract\keract.py", line 282, in get_activations
    nodes = _get_nodes(model, output_format, layer_names=layer_names, nested=nested)

  File "D:\software\Anaconda3\envs\py37\lib\site-packages\keract\keract.py", line 193, in _get_nodes
    assert is_model_or_layer, 'Not a model or layer!'

AssertionError: Not a model or layer!

Tags: name, from, import, layers, train, activation, dec, keras
1条回答
网友
1楼 · 发布于 2024-04-29 18:32:43

你可能会发现我不久前写的一个模块很有用 https://github.com/brianmanderson/Visualizing_Model/blob/master/Visualize_Model.py

# Usage sketch for the third-party Visualize_Model module linked above.
# `some_keras_model`, r'some_path_to_image_folder', `some_image_to_predict_on`
# and 'layer_name' are placeholders to replace with your own objects and names.
from Visualizing_Model.Visualize_Model import ModelVisualizationClass
model = some_keras_model
# save_images=True with out_path: presumably activation images are written to that folder.
visualizer = ModelVisualizationClass(model=model, save_images=True,
                                     out_path=r'some_path_to_image_folder')
# NOTE(review): a two-input model such as w_extraction presumably needs
# x = [image_batch, watermark_batch]; check that predict_on_tensor accepts a list.
x = some_image_to_predict_on
visualizer.print_all_layers()  # Prints the names of all your model layers
visualizer.predict_on_tensor(x)  # Predicts on an input image, needed for writing activation images
visualizer.plot_activation('layer_name')  # Opens a matplotlib figure and shows activations in grid

相关问题 更多 >