How to store AlexNet's gradients as a numpy array (at every iteration) in Python?

Posted 2024-05-16 23:06:59


I want to store the model's final gradient vector as a numpy array. Is there a simple, intuitive way to do this with TensorFlow?

I want to store AlexNet's gradient vector (as a numpy array) at every iteration until convergence.


Tags: method, model, numpy, tensorflow, array, vector, intuitive, gradient
2 answers

Below is a model with an AlexNet-like architecture that captures the gradients at every epoch:

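# (1) Import dependencies
# Written for standalone (multi-backend) Keras 2.x on a TensorFlow 1.x backend:
# K.gradients, used below, needs graph mode and fails under TF2 eager execution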
import keras
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.layers.normalization import BatchNormalization
import numpy as np
np.random.seed(1000)

# (2) Get Data
import tflearn.datasets.oxflower17 as oxflower17
x, y = oxflower17.load_data(one_hot=True)

# (3) Create a sequential model
model = Sequential()

# 1st Convolutional Layer
model.add(Conv2D(filters=96, input_shape=(224,224,3), kernel_size=(11,11), strides=(4,4), padding='valid'))
model.add(Activation('relu'))
# Pooling 
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='valid'))
# Batch Normalisation before passing it to the next layer
model.add(BatchNormalization())

# 2nd Convolutional Layer
model.add(Conv2D(filters=256, kernel_size=(11,11), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Pooling
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='valid'))
# Batch Normalisation
model.add(BatchNormalization())

# 3rd Convolutional Layer
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Batch Normalisation
model.add(BatchNormalization())

# 4th Convolutional Layer
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Batch Normalisation
model.add(BatchNormalization())

# 5th Convolutional Layer
model.add(Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Pooling
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='valid'))
# Batch Normalisation
model.add(BatchNormalization())

# Passing it to a dense layer
model.add(Flatten())
# 1st Dense Layer
model.add(Dense(4096))  # input_shape is only meaningful on the first layer, so it is omitted here
model.add(Activation('relu'))
# Add Dropout to prevent overfitting
model.add(Dropout(0.4))
# Batch Normalisation
model.add(BatchNormalization())

# 2nd Dense Layer
model.add(Dense(4096))
model.add(Activation('relu'))
# Add Dropout
model.add(Dropout(0.4))
# Batch Normalisation
model.add(BatchNormalization())

# 3rd Dense Layer
model.add(Dense(1000))
model.add(Activation('relu'))
# Add Dropout
model.add(Dropout(0.4))
# Batch Normalisation
model.add(BatchNormalization())

# Output Layer
model.add(Dense(17))
model.add(Activation('softmax'))

model.summary()

# (4) Compile 
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

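# (5) Define Gradient Function
# Returns a backend function mapping (inputs, targets, sample weights)
# to d(total_loss)/d(w) for every trainable weight tensor
def get_gradient_func(model):
    grads = K.gradients(model.total_loss, model.trainable_weights)
    # Sequential is a subclass of Model, so the (private) feed placeholders
    # are read from the model itself rather than the deprecated model.model
    inputs = model._feed_inputs + model._feed_targets + model._feed_sample_weights
    func = K.function(inputs, grads)
    return func

# (6) Train the model such that gradients are captured for every epoch
get_gradient = get_gradient_func(model)  # build the gradient ops once, not every epoch
epoch_gradient = []
for epoch in range(1,5):
    model.fit(x, y, batch_size=64, epochs=epoch, initial_epoch=(epoch-1), verbose=1, validation_split=0.2, shuffle=True)
    # Gradient of the loss over the full dataset, with every sample weight set to 1
    grads = get_gradient([x, y, np.ones(len(y))])
    epoch_gradient.append(grads)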

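# (7) Convert to a two-dimensional (epoch, weight tensor) array.
# dtype=object is needed because each weight's gradient has a different shape;
# recent NumPy versions raise a ValueError on such ragged input without it
gradient = np.asarray(epoch_gradient, dtype=object)
print("Total number of epochs run:", epoch)
print("Gradient Array has the shape:", gradient.shape)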

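Output: `gradient` is a two-dimensional array with one row per epoch and one entry per trainable weight tensor, so the per-layer structure of the gradients is preserved. Here there are 34 entries: a kernel and a bias for each of the 5 Conv2D and 4 Dense layers, plus gamma and beta for each of the 8 BatchNormalization layers. Note that these are full-dataset gradients computed once after each epoch, not per mini-batch iteration.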

Model: "sequential_34"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_115 (Conv2D)          (None, 54, 54, 96)        34944     
_________________________________________________________________
activation_213 (Activation)  (None, 54, 54, 96)        0         
_________________________________________________________________
max_pooling2d_83 (MaxPooling (None, 27, 27, 96)        0         
_________________________________________________________________
batch_normalization_180 (Bat (None, 27, 27, 96)        384       
_________________________________________________________________
conv2d_116 (Conv2D)          (None, 17, 17, 256)       2973952   
_________________________________________________________________
activation_214 (Activation)  (None, 17, 17, 256)       0         
_________________________________________________________________
max_pooling2d_84 (MaxPooling (None, 8, 8, 256)         0         
_________________________________________________________________
batch_normalization_181 (Bat (None, 8, 8, 256)         1024      
_________________________________________________________________
conv2d_117 (Conv2D)          (None, 6, 6, 384)         885120    
_________________________________________________________________
activation_215 (Activation)  (None, 6, 6, 384)         0         
_________________________________________________________________
batch_normalization_182 (Bat (None, 6, 6, 384)         1536      
_________________________________________________________________
conv2d_118 (Conv2D)          (None, 4, 4, 384)         1327488   
_________________________________________________________________
activation_216 (Activation)  (None, 4, 4, 384)         0         
_________________________________________________________________
batch_normalization_183 (Bat (None, 4, 4, 384)         1536      
_________________________________________________________________
conv2d_119 (Conv2D)          (None, 2, 2, 256)         884992    
_________________________________________________________________
activation_217 (Activation)  (None, 2, 2, 256)         0         
_________________________________________________________________
max_pooling2d_85 (MaxPooling (None, 1, 1, 256)         0         
_________________________________________________________________
batch_normalization_184 (Bat (None, 1, 1, 256)         1024      
_________________________________________________________________
flatten_34 (Flatten)         (None, 256)               0         
_________________________________________________________________
dense_99 (Dense)             (None, 4096)              1052672   
_________________________________________________________________
activation_218 (Activation)  (None, 4096)              0         
_________________________________________________________________
dropout_66 (Dropout)         (None, 4096)              0         
_________________________________________________________________
batch_normalization_185 (Bat (None, 4096)              16384     
_________________________________________________________________
dense_100 (Dense)            (None, 4096)              16781312  
_________________________________________________________________
activation_219 (Activation)  (None, 4096)              0         
_________________________________________________________________
dropout_67 (Dropout)         (None, 4096)              0         
_________________________________________________________________
batch_normalization_186 (Bat (None, 4096)              16384     
_________________________________________________________________
dense_101 (Dense)            (None, 1000)              4097000   
_________________________________________________________________
activation_220 (Activation)  (None, 1000)              0         
_________________________________________________________________
dropout_68 (Dropout)         (None, 1000)              0         
_________________________________________________________________
batch_normalization_187 (Bat (None, 1000)              4000      
_________________________________________________________________
dense_102 (Dense)            (None, 17)                17017     
_________________________________________________________________
activation_221 (Activation)  (None, 17)                0         
=================================================================
Total params: 28,096,769
Trainable params: 28,075,633
Non-trainable params: 21,136
_________________________________________________________________
Train on 1088 samples, validate on 272 samples
Epoch 1/1
1088/1088 [==============================] - 22s 20ms/step - loss: 3.1251 - acc: 0.2178 - val_loss: 13.0005 - val_acc: 0.1140
Train on 1088 samples, validate on 272 samples
Epoch 2/2
 128/1088 [==>...........................] - ETA: 1s - loss: 2.3913 - acc: 0.2656/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py:111: UserWarning: `Sequential.model` is deprecated. `Sequential` is a subclass of `Model`, you can just use your `Sequential` instance directly.
  warnings.warn('`Sequential.model` is deprecated. '
1088/1088 [==============================] - 2s 2ms/step - loss: 2.2318 - acc: 0.3465 - val_loss: 9.6171 - val_acc: 0.1912
Train on 1088 samples, validate on 272 samples
Epoch 3/3
  64/1088 [>.............................] - ETA: 1s - loss: 1.5143 - acc: 0.5000/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py:111: UserWarning: `Sequential.model` is deprecated. `Sequential` is a subclass of `Model`, you can just use your `Sequential` instance directly.
  warnings.warn('`Sequential.model` is deprecated. '
1088/1088 [==============================] - 2s 2ms/step - loss: 1.8109 - acc: 0.4320 - val_loss: 4.3375 - val_acc: 0.3162
Train on 1088 samples, validate on 272 samples
Epoch 4/4
  64/1088 [>.............................] - ETA: 1s - loss: 1.7827 - acc: 0.4688/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py:111: UserWarning: `Sequential.model` is deprecated. `Sequential` is a subclass of `Model`, you can just use your `Sequential` instance directly.
  warnings.warn('`Sequential.model` is deprecated. '
1088/1088 [==============================] - 2s 2ms/step - loss: 1.5861 - acc: 0.4871 - val_loss: 3.4091 - val_acc: 0.3787
Total number of epochs run: 4
Gradient Array has the shape: (4, 34)
/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py:111: UserWarning: `Sequential.model` is deprecated. `Sequential` is a subclass of `Model`, you can just use your `Sequential` instance directly.
  warnings.warn('`Sequential.model` is deprecated. '
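Each row of the `gradient` array above still holds 34 separately shaped arrays. If you want a single gradient vector per epoch, as the question asks, a minimal sketch is to flatten every per-weight gradient and concatenate them. The `flatten_gradients` helper and the small random `epoch_gradient` below are hypothetical stand-ins for the real AlexNet gradients.

```python
import numpy as np

def flatten_gradients(per_weight_grads):
    """Concatenate a list of per-weight gradient arrays into one 1-D vector."""
    return np.concatenate([np.ravel(g) for g in per_weight_grads])

# Hypothetical stand-in for epoch_gradient: 3 epochs, each with a kernel (2x3) and a bias (3,)
rng = np.random.default_rng(0)
epoch_gradient = [[rng.standard_normal((2, 3)), rng.standard_normal(3)] for _ in range(3)]

# One row per epoch, one column per scalar parameter: shape (epochs, total_params)
gradient_matrix = np.stack([flatten_gradients(g) for g in epoch_gradient])
print(gradient_matrix.shape)  # (3, 9)
```

For the AlexNet model above, each row would have 28,075,633 entries (the trainable parameter count in the summary), roughly 112 MB per row in float32, so keeping one row per iteration until convergence grows quickly; writing rows to disk with `np.save` is one option.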

We can do this with the following code (TensorFlow 2.x, eager execution):

import tensorflow as tf
import numpy as np

print(tf.__version__)

#Define the input tensor
x = tf.constant([3.0,6.0,9.0])

#Define the Gradient Function
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, x)

#Output Gradient Tensor
print("Output Gradient Tensor:",dy_dx)

#Convert to array
a = np.asarray(dy_dx)
print("Gradient array:",a)
print("Array shape:",a.shape)
print("Output type:",type(a))

The output of this code is:

2.1.0
Output Gradient Tensor: tf.Tensor([ 6. 12. 18.], shape=(3,), dtype=float32)
Gradient array: [ 6. 12. 18.]
Array shape: (3,)
Output type: <class 'numpy.ndarray'>
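The question asks for gradients at every iteration. A minimal sketch, assuming TensorFlow 2.x with `tf.keras`, is to combine the `tf.GradientTape` idea above with a custom training loop and keep a NumPy copy of each step's gradients. The tiny two-layer model and random data are hypothetical stand-ins for AlexNet and oxflower17.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for AlexNet and the oxflower17 data: a tiny model on random data
tf.random.set_seed(0)
rng = np.random.default_rng(0)
x = rng.random((32, 8)).astype("float32")
y = np.eye(3, dtype="float32")[rng.integers(0, 3, size=32)]  # one-hot labels

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])
loss_fn = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)  # 4 mini-batches per epoch

step_gradients = []  # one entry per training iteration (mini-batch step)
for epoch in range(2):
    for xb, yb in dataset:
        with tf.GradientTape() as tape:
            loss = loss_fn(yb, model(xb, training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Keep a NumPy copy of each per-weight gradient for this step
        step_gradients.append([g.numpy() for g in grads])

print(len(step_gradients))  # 2 epochs x 4 steps = 8
```

Replacing the fixed `range(2)` with a loop that stops on a convergence criterion gives the "every iteration until convergence" behavior; each stored entry is a list of NumPy arrays, one per trainable variable.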
