TensorFlow Print fixing shape issue?

Posted 2024-03-29 01:16:29


I've run into a very strange issue and am hoping someone more familiar with this can help. I'm trying out a basic LSTM for some binary classification with the following code:

import numpy as np
import tensorflow as tf

class FakeData(object):
    def __init__(self, n):
        self.x = np.random.randint(4, size=(n, 90, 4))
        blah = np.random.randint(2, size=(n))
        self.y = np.zeros((n,2))
        self.y[:,0] = blah
        self.y[:,1] = 1 - blah
        self.mask = np.arange(n)
        self.cnt = 0
        self.n = n

    def getdata(self, n):
        if self.cnt + n > self.n:
            np.random.shuffle(self.mask)
            self.cnt = 0
        mask = self.mask[self.cnt : self.cnt + n]
        self.cnt += n
        return self.x[mask], self.y[mask]

n_data = 10000
batch_size = 10
fd = FakeData(n_data)
n_units = 200
n_classes = 2

x = tf.placeholder(tf.float32, shape=[None, 90, 4])
y_ = tf.placeholder(tf.float32, shape=[None, n_classes])
dropout = tf.placeholder(tf.float32)

w_out = tf.Variable(tf.truncated_normal([n_units, n_classes]))
b_out = tf.Variable(tf.truncated_normal([n_classes]))

lstm = tf.contrib.rnn.LSTMCell(n_units)
cell = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=1.0 - dropout)

new_x = tf.unstack(x, 90, 1)
new_x = tf.Print(new_x, [tf.shape(new_x)], message='newx is: ')
output, state = tf.nn.dynamic_rnn(cell, new_x, dtype=tf.float32)
output = tf.Print(output, [tf.shape(output)], message='output is: ')

logits = tf.matmul(output[-1], w_out) + b_out
logits = tf.Print(logits, [tf.shape(logits)], message='logits is: ')
preds = tf.nn.softmax(logits)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, 
    labels=y_))
training = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

correct = tf.equal(tf.argmax(preds, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

# training loop
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(10):
        batch_x, batch_y = fd.getdata(batch_size)
        sess.run([training], feed_dict={x: batch_x, y_: batch_y, dropout: 0})
        if i % 100 == 0:
            print "Accuracy {}".format(accuracy.eval(feed_dict={x: batch_x, 
                y_: batch_y, dropout: 0}))

The specific problem I'm having is that, for some reason, when I run the code without the tf.Print lines, I get a strange shape/transpose error:

ValueError: Dimension must be 2 but is 3 for 'transpose' (op: 'Transpose') with shapes: [?,4], [3].

on the line

output, state = tf.nn.dynamic_rnn(cell, new_x, dtype=tf.float32)

However, when I include the tf.Print lines, it logs the shapes correctly and is able to run the whole session. Am I missing something?

For clarity, the shapes should be:

input: n x 90 x 4
new_x: 90 x n x 4
output: 90 x n x 200
logits: n x 2


1 Answer

#1 · Posted 2024-03-29 01:16:29

Adding an answer here in case anyone runs into this issue in the future.

It turns out that a lot of the older RNN examples use tf.unstack. However, that turns the input into a list of 2-D tensors, which tf.nn.dynamic_rnn cannot take as input; it expects a single 3-D tensor. The tf.Print call was converting that list of 2-D tensors back into one 3-D tensor, which is why the graph could then be built and run correctly. The solution is to reorder the dimensions some other way, such as:

new_x = tf.transpose(x, perm=(1, 0, 2))

(thanks, rvinas)
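For completeness, here is a minimal sketch of how the relevant part of the graph could look after the fix (TensorFlow 1.x assumed, reusing x, cell, w_out and b_out from the question's code). Note that if you feed time-major data after the transpose, dynamic_rnn should also be told via time_major=True so it doesn't treat the 90 time steps as the batch dimension:

# Option 1: feed the batch-major input directly; output is [batch, 90, n_units],
# so the last time step is output[:, -1, :] rather than output[-1].
output, state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
last = output[:, -1, :]

# Option 2: transpose to time-major and pass time_major=True; output is
# [90, batch, n_units], so output[-1] keeps working as in the original code.
# new_x = tf.transpose(x, perm=(1, 0, 2))
# output, state = tf.nn.dynamic_rnn(cell, new_x, time_major=True, dtype=tf.float32)
# last = output[-1]

logits = tf.matmul(last, w_out) + b_out

Alternatively, if you want to keep the tf.unstack call, tf.nn.static_rnn is the variant that accepts exactly that list of 2-D tensors.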
