我做了这个神经网络,但每次我运行它,它给我不同的损失开始,它保持不变的完整循环。我想预测'yy'中每3个'xx'值中的一个值作为输入。如何显示输出?例如:我想显示一个数组,它的预测值尽可能接近'yy'中的值。你知道吗
import tensorflow as tf
xx=(
[178.72,218.38,171.1],
[211.57,215.63,173.13],
[196.25,196.69,116.91],
[121.88,132.07,85.02],
[117.04,135.44,112.54],
[118.13,124.04,97.98],
[116.73,125.88,99.04],
[118.75,125.01,110.16],
[109.69,111.72,69.07],
[76.57,96.88,67.38],
[91.69,128.43,87.57],
[117.57,146.43,117.57]
)
yy=(
[212.09],
[195.58],
[127.6],
[116.5],
[117.95],
[117.55],
[117.55],
[110.39],
[74.33],
[91.08],
[121.75],
[127.3]
)
x=tf.placeholder(tf.float32,[None,3])
y=tf.placeholder(tf.float32,[None,1])
n1=5
n2=5
classes=12
def neuralnetwork(data):
hl1={'weights':tf.Variable(tf.random_normal([3,n1])),'biases':tf.Variable(tf.random_normal([n1]))}
hl2={'weights':tf.Variable(tf.random_normal([n1,n2])),'biases':tf.Variable(tf.random_normal([n2]))}
op={'weights':tf.Variable(tf.random_normal([n2,classes])),'biases':tf.Variable(tf.random_normal([classes]))}
l1=tf.add(tf.matmul(data,hl1['weights']),hl1['biases'])
l1=tf.nn.relu(l1)
l2=tf.add(tf.matmul(l1,hl2['weights']),hl2['biases'])
l2=tf.nn.relu(l2)
output=tf.matmul(l2,op['weights'])+op['biases']
return output
def train(x):
pred=neuralnetwork(x)
# cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
sq = tf.square(pred-y)
loss=tf.reduce_mean(sq)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
#optimizer=tf.train.RMSPropOptimizer(0.01).minimize(cost)
epochs=100
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(epochs):
epoch_loss=0
for i in range (int(1)):
batch_x=xx
batch_y=yy
# a=tf.shape(xx)
#print(sess.run(a))
c=sess.run(loss,feed_dict={x:batch_x, y: batch_y})
epoch_loss+=c
print("Epoch ",epoch," completed out of ",epochs, 'loss:', epoch_loss)
train(x)
我不确定你到底想完成什么,但在我看来这是一个回归问题,而不是分类问题。我想下面的代码就是你想要的。我已经清理了一点,但仍然试图保持它在一种方式,你会认出它。我个人会用另一种方式写这个。你知道吗
你犯了两个主要错误:
您正在尝试有12个输出节点,您可能需要的是一个节点,它尝试预测相应的y值。
您没有调用
train
操作,因此优化器实际上没有执行任何操作。例如,这些行:
这将简单地评估计算图的一部分,该部分是计算
pred
张量所必需的,使用整个数据集作为输入,将其输入占位符。你知道吗然而,正如你所见,你的网络只是学会预测标签的平均值,而不管输入是什么。你知道吗
相关问题 更多 >
编程相关推荐