如何使用pytorch构建具有多个输出(和多个类别)的神经网络?

0 投票
1 回答
29 浏览
提问于 2025-04-13 00:18

我正在处理一个多输出(也就是输出目标超过1个)和多分类(也就是类别超过1个)的问题。我认为这也可以称为多任务问题。

举个例子,我的训练特征数据的形状是(4, 6),也就是说有4行(样本)和6列(特征);而我的训练目标数据的形状是(4, 3),也就是说有4行(样本)和3列(目标)。对于每个目标,我有三种不同的类别:-1、0和1。

我为这个问题定义了一个示例模型架构(和数据),如下所示:

import pandas as pd
from torch import nn 
from logging import log
import torch
feature_data = {
    'A': [1, 2, 3, 4],
    'B': [5, 6, 7, 8],
    'C': [9, 10, 11, 12],
    'D': [13, 14, 15, 16],
    'E': [17, 18, 19, 20],
    'F': [21, 22, 23, 24]
}

target_data = {
    'Col1': [1, -1, 0, 1],
    'Col2': [-1, 0, 1, -1],
    'Col3': [-1, 0, 1, 1]
}

# Create the DataFrame
train_feature_data = pd.DataFrame(feature_data) 
train_target_data = pd.DataFrame(target_data)
device = "cuda" if torch.cuda.is_available() else "cpu"

# create the model
class MyModel(nn.Module):
  def __init__(self, inputs=6, l1=12, outputs=3):
      super().__init__()
      self.sequence = nn.Sequential(
        nn.Linear(inputs, l1),
        nn.Linear(l1, outputs),
        nn.Softmax(dim=1)
    )
      
  def forward(self, x):
      x = self.sequence(x)
      return x
    
x_train = torch.tensor(train_feature_data.to_numpy()).type(torch.float)
model = MyModel(inputs = 6, l1 = 12, outputs = 3).to(device)
model(x_train.to(device=device))

当我把训练数据传入模型时(也就是当我调用model(x_train.to(device=device))),我得到的结果是一个形状为(4, 3)的数组。

根据这个资源 资源,我原本期待得到的结果形状是(4, 3, 3),其中第一个轴(也就是4)表示我的特征和目标数据中的样本数量,第二个轴(也就是中间的3)代表每个样本的logits(在这里因为我使用了softmax函数,所以这将是预测的概率),而这个值是3,因为我有三种类别;第三个轴(或者说形状中的最右边的3)表示我在训练目标数据中有多少个输出/列。

请问有没有人能给我一些指导,告诉我我在这里做错了什么(如果我的方法是错的),以及该如何修正。谢谢。

1 个回答

1

你的模型将输入的形状从 (4, 6) 转换到 (4, 12),这是在第一层线性层中完成的,然后在第二层中转换到 (4, 3)

如果你想要输出的形状是 (4, 3, 3),那么你需要让第二层的输出是 (4, 3*3),然后再进行形状调整。

n_problems = 3
classes_per_problem = 3

model = nn.Linear(6, n_problems*classes_per_problem)

x = torch.randn(4, 6)
x1 = model(x)
bs, _ = x1.shape
x1 = x1.reshape(bs, classes_per_problem, n_problems)

y = torch.randint(high=classes_per_problem, size=(bs, n_problems))
loss_function = nn.CrossEntropyLoss()

loss = loss_function(x1, y)

撰写回答