用Cython编写PyTorch类

2024-04-16 22:12:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图找到一个PyTorchnn.Module类的例子,它是用Cython为speed编写的,但没有找到任何东西。假设我用Python编写了下面的类,那么最好的Cython翻译是什么

class Actor(nn.Module):
    def __init__(self, state_size, action_size, hidden_size):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, hidden_size)
        self.l4 = nn.Linear(hidden_size, action_size)
        self.log_std = nn.Parameter(-0.5 * torch.ones(action_size, dtype=torch.float32))

    def forward(self, x):
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        x = torch.relu(self.l3(x))
        mu = self.l4(x)
        return mu

    def dist(self, mu):
        pi = Normal(mu, torch.exp(self.log_std))
        return pi

    def log_prob(self, pi, action):
        return pi.log_prob(action).sum(axis=-1)

Tags: selflogsizereturndefpiactionnn