多标记回归keras神经网络中的未知问题

2024-03-28 23:57:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我对神经网络和keras是个新手,在使用我的真实数据之前,我会努力确保事情正常进行。你知道吗

这是一个有1000个样本,三个输入和三个输出的神经网络

X.csv包含:(索引重复三次)

1,1,1

2,2,2

直到10001000

Y.csv包含三个标签:(索引、索引*5、索引/5)

1,5,0.2英寸

2,10,0.4英寸

至10005000200

random.seed(42)
X = np.genfromtxt(r'C:\Users\boss\Desktop\X.csv' , delimiter=',')
y = np.genfromtxt(r'C:\Users\boss\Desktop\Y.csv' , delimiter=',')
y1,y2,y3 = y[:, 0:1],y[:, 1:2],y[:, 2:]
X_train, X_test, y1_train, y1_test, y2_train, y2_test, y3_train, y3_test = train_test_split(X, y1,y2,y3, test_size =0.3, random_state = 0)
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

inp = Input((3,)) 
x = Dense(10, activation='relu')(inp)
x = Dense(10, activation='relu')(x)
x = Dense(10, activation='relu')(x)
out1 = Dense(1,  activation='linear')(x)
out2 = Dense(1,  activation='linear')(x)
out3 = Dense(1,  activation='linear')(x)

model = Model(inputs=inp, outputs=[out1,out2,out3])
model.compile(optimizer = "adam", loss = 'mse')
model.fit(x=X_train, y=[y1_train,y2_train,y3_train], batch_size=100, epochs=10, verbose=1, validation_split=0.3,  shuffle=True)            

#plot predicted data vs real data
y_pred = model.predict(X_test)
plt.plot(y1_test, color = 'red', label = 'Real data')
plt.plot(y_pred[0], color = 'blue', label = 'Predicted data')
plt.title('y1')
plt.legend()
plt.show()

plt.plot(y2_test, color = 'red', label = 'Real data')
plt.plot(y_pred[1], color = 'blue', label = 'Predicted data')
plt.title('y2')
plt.legend()
plt.show()

plt.plot(y3_test, color = 'red', label = 'Real data')
plt.plot(y_pred[2], color = 'blue', label = 'Predicted data')
plt.title('y3')
plt.legend()
plt.show()

不幸的是,损失和验证损失都是巨大的(百万) 另一个问题是,尽管使用了随机种子,结果每次都不一样


Tags: csvtestdatamodelplottrainpltactivation
1条回答
网友
1楼 · 发布于 2024-03-28 23:57:59

造成高损失的一个可能原因是,只有10个年头的年头很少有好结果。试试100、1000等,看看效果如何改善。你知道吗

对于可复制的随机数生成,您还需要为Numpy和TensorFlow指定种子(如果您使用的是TensorFlow后端,这是默认的)。以下是this article给出的示例:

from numpy.random import seed
seed(1)
from tensorflow import set_random_seed
set_random_seed(2)

相关问题 更多 >