十进制非线性分类

2024-04-19 20:56:07 发布

您现在位置:Python中文网/ 问答频道 /正文

我是机器学习和张量流的新手,想用数据做一个简单的二维分类,不能线性分离。在

Current Result 在左侧,您可以看到模型的训练数据。 右侧显示训练模型预测的结果。在

到目前为止,我过度拟合我的模型,所以所有可能的输入都被输入到模型中。 我的预期结果将是一个非常高的准确性,因为模型已经'知道'每个答案。 不幸的是,我使用的深度神经网络只能用线性除法器进行分离,这与我的数据不符。在

我就是这样训练我的模特的:

def testDNN(data):
  """ 
  * data is a list of tuples (x, y, b), 
  * where (x, y) is the input vector and b is the expected output
  """
  # Build neural network
  net = tflearn.input_data(shape=[None, 2])

  net = tflearn.fully_connected(net, 100)
  net = tflearn.fully_connected(net, 100)
  net = tflearn.fully_connected(net, 100)


  net = tflearn.fully_connected(net, 2, activation='softmax')
  net = tflearn.regression(net)

  # Define model
  model = tflearn.DNN(net)

  # check if we already have a trained model
  # Start training (apply gradient descent algorithm)
  model.fit(
    [(x,y) for (x,y,b) in data], 
    [([1, 0] if b else [0, 1]) for (x,y,b) in data], 
    n_epoch=2, show_metric=True)

  return lambda x,y: model.predict([[x, y]])[0][0]

大部分都是从tflearn的示例中获取的,所以我不太清楚每一行都做了什么。在


Tags: the数据in模型forinputdatanet
1条回答
网友
1楼 · 发布于 2024-04-19 20:56:07

你的网络需要一个非线性激活函数。激活函数是神经网络拟合非线性函数的方法。Tflearn默认使用线性激活,您可以将其更改为“sigmoid”,并查看结果是否有所改善。在

相关问题 更多 >