我正在努力使用Tensorflow的剪枝库,并且没有找到很多有用的例子,所以我正在寻找帮助来修剪一个在MNIST数据集上训练过的简单模型。如果有人能帮我解决这个问题,或者提供一个如何使用MNIST上的库的例子,我将不胜感激。在
我的代码的前半部分相当标准,除了我的模型有两个300单位宽的隐藏层,使用layers.masked_fully_connected
进行修剪。在
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data
# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])
# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
logits = tf.contrib.layers.fully_connected(layer2, 10, tf.nn.relu)
# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))
# Training op
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)
# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
然后我试图定义必要的修剪操作,但是我得到了一个错误。在
^{pr2}$此行出错:
prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [?,10] [[Node: Placeholder_1 = Placeholderdtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:GPU:0"]] [[Node: global_step/_57 = _Recv_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_71_global_step", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
我想它需要一种不同的运行方式来代替列车运行,但我没有发现任何有效的调整。在
同样,如果你有一个不同的工作例子,修剪一个基于MNIST的模型,我会认为这是一个答案。在
我能得到的最简单的修剪库示例,我想把它发布在这里,以防它能帮助其他一些在文档方面有困难的noobie。在
Roman Nikishin要求的代码可以保存模型,这只是我最初回答的一个小小的扩展。在
相关问题 更多 >
编程相关推荐