TensorFlow neural network loss not decreasing

Posted 2024-05-15 10:09:21


My plain network model does not seem to reduce its loss no matter how long I train it. inputs is a numpy stack of 15 arrays of shape (1, 14), so it has shape (15, 14).
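The snippet below uses inputs and labels without defining them; for context, they can be stood in for with dummy arrays of the shapes described above (my real data is loaded elsewhere, and the (15, 11) one-hot label shape is an assumption that matches the 11-unit output layer and what softmax_cross_entropy_with_logits_v2 expects):

import numpy as np

# Dummy stand-ins with the shapes described above:
# 15 arrays of shape (1, 14) stacked into one (15, 14) float matrix.
inputs = np.vstack([np.random.rand(1, 14).astype(np.float32) for _ in range(15)])
# One label distribution per row: shape (15, 11), here a random one-hot per example.
labels = np.eye(11, dtype=np.float32)[np.random.randint(0, 11, size=15)]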

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

with tf.name_scope("bnn"):
    model = tf.keras.Sequential([
        tfp.layers.DenseFlipout(64, activation=tf.nn.relu),
        tfp.layers.DenseFlipout(64, activation=tf.nn.relu),
        tfp.layers.DenseFlipout(11, activation=tf.nn.softmax)
    ])

logits = model(inputs)
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=labels, logits=logits)
#kl = sum(model.losses)
loss = neg_log_likelihood #+ kl
train_op_bnn = tf.train.AdamOptimizer().minimize(loss)

init_op = tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer())

with tf.Session() as sess:
    sess.run(init_op)
    for i in range(10):   
        sess.run(train_op_bnn)
        print(sess.run(loss))
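A note on shapes: tf.nn.softmax_cross_entropy_with_logits_v2 returns one loss value per example, so loss above is a (15,)-shaped tensor, which is why every print shows 15 numbers. A minimal sketch of how the same quantity could be reduced to a single scalar per step for easier monitoring (purely an aggregation change, using the same ops as above):

# Sketch: average the per-example cross-entropy into one scalar per step;
# tf.train.AdamOptimizer().minimize() could take this scalar in place of the
# (15,)-shaped loss above, and the commented-out KL term would be added on
# top of this mean if re-enabled.
scalar_loss = tf.reduce_mean(neg_log_likelihood)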

Even when I run the training loop 100 times instead of 10, the loss does not change. It looks like it is just producing random numbers, as in the output below (see also the sanity-check sketch after it). Any ideas?

[ 15.69408512  15.44436646  13.1471653   10.95459461  11.92738056
  12.26817703  10.54849815  15.23202133  10.96777344  10.42760086
  11.41384125  16.70359612  14.71702576  12.59389114  12.59498119]
[ 15.69178391  15.45760155  13.13955212  10.97087193  11.9185276
  12.26686096  10.55150986  15.24072647  10.98205566  10.42508125
  11.40711594  16.70509338  14.71866608  12.59212685  12.58044815]
[ 15.70432568  15.43920803  13.14484024  10.96325684  11.90746498
  12.27936172  10.54476738  15.23231792  10.98124218  10.4410696
  11.41601944  16.70531845  14.71773529  12.58877563  12.58486748]
[ 15.69456196  15.4549036   13.13622952  10.9618206   11.92374229
  12.27278805  10.55258274  15.23033237  10.98199749  10.45040035
  11.40854454  16.69827271  14.71369648  12.58154106  12.58543587]
[ 15.70057106  15.44137669  13.15152454  10.97329521  11.91176605
  12.27191162  10.55643845  15.22959518  10.96763611  10.43885517
  11.40656662  16.70225334  14.71477509  12.58106422  12.57350349]
[ 15.70051384  15.44955826  13.12762356  10.97265244  11.92464542
  12.26436138  10.54278946  15.2416935   10.95931625  10.44235325
  11.39641094  16.70422935  14.71526909  12.58607388  12.5754776 ]
[ 15.70247078  15.44031525  13.13246441  10.96818161  11.90959644
  12.27048016  10.55867577  15.23018265  10.96870041  10.4413271
  11.40160179  16.70223618  14.71558762  12.58408928  12.56538963]
[ 15.69963455  15.43683147  13.12852192  10.97309399  11.92388725
  12.27491188  10.5465889   15.22896194  10.96969795  10.43502808
  11.40288258  16.70007324  14.7202301   12.58245087  12.57666397]
[ 15.70012856  15.43531322  13.13196182  10.9636631   11.92444801
  12.27731323  10.55225563  15.2232151   10.9690609   10.43749809
  11.4017868   16.69387817  14.71770382  12.57458782  12.56506252]
[ 15.70418262  15.43191147  13.13453293  10.95469475  11.91213608
  12.2595768   10.55391121  15.23048401  10.95438766  10.43799973
  11.40246582  16.69694519  14.72452354  12.58216476  12.553545  ]
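For what it's worth, some jitter in the numbers above is expected even without any training, because DenseFlipout re-samples its weight perturbations on every forward pass. A sketch of a sanity check along those lines, assuming the same graph as above:

with tf.Session() as sess:
    sess.run(init_op)
    # Two evaluations with no train step in between still differ slightly,
    # because each sess.run(loss) draws fresh Flipout weight samples.
    print(sess.run(loss))
    print(sess.run(loss))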
