Can I set the predicted values of a neural network?

Posted 2024-04-27 00:02:00


My question is theoretical rather than practical, but I can also show some code. I have a network that maps random values from a uvw domain to an xyz domain. I want certain uvw values to map to certain xyz values that I already know, because that is the idea behind the function I want to obtain, and I want the network to overfit it.

My question has two parts:

  1. Can I set the predicted values I want directly in the network, so that it does not have to compute them? (See the sketch right after this list.)
  2. Would this affect the predictions for the other values?
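
For concreteness, here is a minimal sketch (with hypothetical names, not taken from my code) of what "setting" predictions could look like: overwrite the network's outputs at the known inputs before computing the loss, so those rows are constants and only the remaining points produce gradients.

    import torch

    phi = torch.nn.Linear(3, 3)           # stand-in for the real MLP
    t = torch.rand(100, 3)                # uvw samples
    x = torch.rand(100, 3)                # xyz targets
    anchor_idx = torch.tensor([0, 1, 2])  # indices of the known correspondences
    x_known = x[anchor_idx]               # xyz values I already know

    y = phi(t)
    # index_copy returns a copy of y with the anchor rows replaced by x_known;
    # autograd still flows through the untouched rows, while the anchor rows
    # are constants and contribute no gradient.
    y_pinned = y.index_copy(0, anchor_idx, x_known)

    loss = torch.nn.functional.mse_loss(y_pinned, x)
    loss.backward()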

Here is my code; I am showing it so we can discuss some of the notation.

    import time

    import torch

    # Assumed to be defined elsewhere in the script:
    #   t   -- [n, 3] tensor of uvw parameter points
    #   x   -- [n, 3] tensor of target xyz points
    #   P   -- [1, n, n] tensor holding the assembled transport matrix
    #   num -- index splitting the point set into two groups
    #   common.MLP and SinkhornLoss come from the project's own modules

    # The model is a simple fully connected network mapping a 3D parameter point to 3D
    phi = common.MLP(in_dim=3, out_dim=3).to(args.device)

    # eps is 1/lambda and max_iters is the maximum number of Sinkhorn iterations to do
    emd_loss_fun = SinkhornLoss(eps=args.sinkhorn_eps, max_iters=args.max_sinkhorn_iters,
                                stop_thresh=1e-3, return_transport_matrix=True)
    # TODO: add r-1 function to the weights

    mse_loss_fun = torch.nn.MSELoss()

    # Optimizer (originally Adam, now Rprop)
    optimizer = torch.optim.Rprop(phi.parameters(), lr=args.learning_rate)

    fit_start_time = time.time()

    for epoch in range(args.num_epochs):
        optimizer.zero_grad()

        # Forward pass of the neural net, evaluating the function at the parametric points
        y = phi(t)

        # Compute the Sinkhorn divergence between the reconstruction (using the
        # francis library) and the target.
        # NOTE: The Sinkhorn function expects a batch of b point sets (i.e. tensors
        # of shape [b, n, 3]); since we only have one, we unsqueeze so the inputs
        # have shape [1, n, 3].
        with torch.no_grad():
            _, M = emd_loss_fun(phi(t[num:]).unsqueeze(0), x[num:].unsqueeze(0))
            _, Q = emd_loss_fun(phi(t[0:num]).unsqueeze(0), x[0:num].unsqueeze(0))
            P[0, num:, num:] = M[0]
            P[0, 0:num, 0:num] = Q[0]
            # print(y[Q.squeeze().max(0)[1], :])

        # Project the transport matrix onto the space of permutation matrices and
        # compute the L2 loss between the permuted points
        loss = mse_loss_fun(y[P.squeeze().max(0)[1], :], x)
        # loss = mse_loss_fun(P.squeeze() @ y, x)  # use the transport matrix directly

        # Take an optimizer step
        loss.backward()
        optimizer.step()

        print("Epoch %d, loss = %f" % (epoch, loss.item()))
