PyTorch question about loss and number of epochs

Posted 2024-05-08 18:27:37


I'm building a neural network by adapting the code shown in curiosily's tutorial. Instead of the weather data, I'm feeding in my own data (all numeric) to solve a time-series regression problem. Under the "Finding good parameters" section, they compute the loss (the difference between the predicted and actual output values).

With my data (and trying different optimizers, numbers of nodes, numbers of layers, etc.), the Train set - loss and Test set - loss values decrease over a number of epochs, and then the loss increases again. The accuracy is always 0.0. I would like to understand why this happens, what the ideal loss value is (zero?), and how to tune the model parameters to avoid this problem.

I'm using essentially the same code as the tutorial, but with a different neural network:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self, n_features):
        super(Net, self).__init__()
        # n_features = number of inputs
        n1 = 8 # number of nodes in layer 1
        n2 = 5
        n3 = 4
        n4 = 5
        n5 = 2
        self.fc1 = nn.Linear(n_features, n1)
        self.fc2 = nn.Linear(n1, n2)
        self.fc3 = nn.Linear(n2, n3)
        self.fc4 = nn.Linear(n3, n4)
        self.fc5 = nn.Linear(n4, n5)
        self.fc6 = nn.Linear(n5, 1)

    def forward(self, x):
        #x = F.relu(self.fc1(x))
        x = torch.tanh(self.fc1(x)) # activation function in layer 1
        x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        x = torch.sigmoid(self.fc4(x))
        x = torch.tanh(self.fc5(x))
        return torch.sigmoid(self.fc6(x))
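
As a quick sanity check (not from the tutorial), the model can be instantiated with the 8 features of my data and run on a dummy batch:

# instantiate with 8 input features and run a dummy forward pass
model = Net(n_features=8)
dummy = torch.randn(4, 8)   # batch of 4 samples, 8 features each
print(model(dummy).shape)   # torch.Size([4, 1])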

For the training/test data,

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

this gives

torch.Size([20, 8]) torch.Size([20])
torch.Size([6, 8]) torch.Size([6])

Here is some of my data:

                 Price        f1          f2         f3           f4  \
Date                                                                   
2015-03-02   90.196107  1803.892  113.146970  12.643646  2125.656231   
2015-03-09   64.135647  1800.734  107.968714   5.875968  2121.790735   
2015-03-16   79.552756  1704.983  110.304459  12.003638  2009.193045   
2015-03-23   82.191813  1607.716  107.720195   6.442494  2020.463010   
2015-03-30   69.386627  1522.380  108.315439  13.252422  1979.088367   
2016-03-07   66.651752  2084.698  113.987594  15.707330  2101.044023   
2016-03-14   65.263433  2089.886  110.828986  10.185968  2126.727206   
2016-03-21   67.420919  2152.666  111.177730   8.500986  2167.854746   
2016-03-28   41.540860  2280.450   95.394193  11.750658  2103.708359   
2017-03-06   45.244413  2383.778  110.464190  21.425014  2053.123167   
2017-03-13   54.460675  2289.858  109.539569  10.345976  1982.583561   
2017-03-20   41.063493  2185.491  106.347338  25.485176  1946.495832   
2017-03-27   49.431981  2087.931  110.003395  10.732664  2032.264678   
2018-03-05   73.660636  2204.947  108.703186   5.965236  2017.757273   
2018-03-12   65.089474  2244.313  105.978320  11.164498  2102.231834   
2018-03-19   61.284307  2240.600  106.864093   8.307786  2130.436459   
2018-03-26   57.872814  2256.034  107.546072  16.750366  2153.384082   
2019-03-04  173.318212  1826.327  113.837832  16.328690  2130.480772   
2019-03-11  199.718808  1789.397  110.402293   6.385144  2038.025531   
2019-03-18  206.258064  1809.019  109.644544   4.469384  1957.963904   
2019-03-25  186.447336  1779.967  111.211074  17.378698  1948.683384   
2020-03-02   63.820617  2586.044  113.275140   8.278228  2108.441593   
2020-03-09   52.762931  2513.891  111.669942  12.933696  2087.767817   
2020-03-16   72.150978  2467.322  109.775070  15.961352  2058.925025   
2020-03-23   75.902965  2394.069  111.015771  18.886624  2023.038540   
2020-03-30   51.715278  2298.855   95.129930  10.840378  2122.552675   

                    f5          f6  year  week  
Date                                            
2015-03-02  321349.480  232757.674  2015    10  
2015-03-09  319000.479  221875.266  2015    11  
2015-03-16  329682.915  226521.004  2015    12  
2015-03-23  323335.102  221358.104  2015    13  
2015-03-30  335423.556  222942.088  2015    14  
2016-03-07  324917.837  235534.038  2016    10  
2016-03-14  318739.973  229351.230  2016    11  
2016-03-21  311516.881  231233.470  2016    12  
2016-03-28  317998.580  198436.598  2016    13  
2017-03-06  333304.312  227996.148  2017    10  
2017-03-13  319538.063  225794.464  2017    11  
2017-03-20  343361.214  219506.514  2017    12  
2017-03-27  326703.683  227488.980  2017    13  
2018-03-05  306569.458  225853.320  2018    10  
2018-03-12  309483.605  219876.156  2018    11  
2018-03-19  316931.421  221450.730  2018    12  
2018-03-26  322248.386  224380.222  2018    13  
2019-03-04  340449.937  235389.124  2019    10  
2019-03-11  323107.510  227822.394  2019    11  
2019-03-18  322681.705  226564.046  2019    12  
2019-03-25  342102.164  229219.588  2019    13  
2020-03-02  343116.127  234588.908  2020    10  
2020-03-09  345827.356  230804.352  2020    11  
2020-03-16  341559.653  226640.770  2020    12  
2020-03-23  344563.904  229330.532  2020    13  
2020-03-30  327042.742  196731.040  2020    14  

I split the data into training/test sets:

from sklearn.model_selection import train_test_split

# inputs (all columns except the target)
cols0 = [i for i in cols if i != 'Price']
X = mydata[cols0]

# output (target)
y = mydata[['Price']]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_SEED)
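
The frames are then converted to float tensors following the tutorial's approach (a sketch; this is where the torch.Size([20, 8]) and torch.Size([20]) shapes above come from):

X_train = torch.from_numpy(X_train.to_numpy()).float()
y_train = torch.squeeze(torch.from_numpy(y_train.to_numpy()).float())  # (20, 1) -> (20,)
X_test = torch.from_numpy(X_test.to_numpy()).float()
y_test = torch.squeeze(torch.from_numpy(y_test.to_numpy()).float())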

2 Answers

Two things helped a lot:

  • Normalizing the data, e.g. with a scaler (the tutorial does not normalize it)

  • Using a loss function designed for regression problems, e.g. torch.nn.L1Loss or torch.nn.MSELoss (the tutorial considers a binary classification problem and therefore uses torch.nn.BCELoss); a combined sketch follows below
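
Putting both points together, a minimal sketch (assumptions, not from the original answer: sklearn's StandardScaler and MinMaxScaler stand in for the scaler, net is an instance of the Net class from the question, and X_train/y_train are still the pandas frames from the split):

from sklearn.preprocessing import StandardScaler, MinMaxScaler
import torch
import torch.nn as nn

# Fit scalers on the training data only, to avoid leaking test information.
x_scaler = StandardScaler().fit(X_train)
y_scaler = MinMaxScaler().fit(y_train)   # maps Price into [0, 1], matching the sigmoid output

X_t = torch.from_numpy(x_scaler.transform(X_train)).float()
y_t = torch.from_numpy(y_scaler.transform(y_train)).float()  # shape (n, 1), like net's output

criterion = nn.MSELoss()   # regression loss instead of nn.BCELoss
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

for epoch in range(1000):
    optimizer.zero_grad()
    loss = criterion(net(X_t), y_t)
    loss.backward()
    optimizer.step()

(Alternatively, drop the final torch.sigmoid in Net and scale the target with StandardScaler instead; an unbounded output is the more usual setup for regression.)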

The original post deals with a binary classification problem, where the accuracy metric makes sense (note that the predicted floats are first converted to a boolean tensor: predicted = y_pred.ge(.5).view(-1)).

Your question, on the other hand, indicates that you are dealing with a regression problem, in which case accuracy is meaningless: exactly predicting a floating-point value is nearly impossible.
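
If you want a scalar quality metric in place of accuracy, the usual regression choices are MAE or RMSE; a sketch (hypothetical names: X_test_t and y_test_t are the test inputs/targets scaled the same way as the training data above):

# evaluate without tracking gradients
with torch.no_grad():
    y_pred = net(X_test_t)
    mae = torch.mean(torch.abs(y_pred - y_test_t)).item()
    rmse = torch.sqrt(torch.mean((y_pred - y_test_t) ** 2)).item()
print(f'test MAE: {mae:.3f}, test RMSE: {rmse:.3f}')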
