Loading

PyTorch 中 `model(input)[...]` 的含义：模型调用的返回值类型

对于 `model = Model()`，执行 `model(input)` 时会调用 Model 类中定义的 `forward(input)` 函数。原因是 Model 继承自 `nn.Module`，而 `nn.Module` 实现了 `__call__` 方法，`__call__` 在内部会调用 `forward`。

举个例子

 1 import math, random
 2 import numpy as np
 3 
 4 import torch
 5 import torch.nn as nn
 6 import torch.optim as optim
 7 import torch.autograd as autograd 
 8 import torch.nn.functional as F
 9 USE_CUDA = torch.cuda.is_available()
10 Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs)
11 
class Encoder(nn.Module):
    """Single-layer encoder: a Linear map from ``din`` to ``hidden_dim`` followed by ReLU.

    Args:
        din: size of the last dimension of the input (default 32).
        hidden_dim: size of the last dimension of the output (default 128).
    """

    def __init__(self, din=32, hidden_dim=128):
        super().__init__()
        # The attribute name ``fc`` fixes the state_dict keys ("fc.weight", "fc.bias").
        self.fc = nn.Linear(din, hidden_dim)

    def forward(self, x):
        """Return ``relu(fc(x))``; the output is non-negative."""
        return F.relu(self.fc(x))
20 
class AttModel(nn.Module):
    """Masked dot-product attention layer with ReLU projections.

    Each input row is projected to value, query and key vectors (Linear + ReLU).
    The attention weights are a masked softmax of the query/key scores along the
    last axis. The weighted sum of values then goes through a final Linear + ReLU.

    Args:
        n_node: accepted for interface compatibility; not used by this layer.
        din: size of the last dimension of the input.
        hidden_dim: size of the value/query/key projections.
        dout: size of the last dimension of the output.
    """

    def __init__(self, n_node, din, hidden_dim, dout):
        super().__init__()
        # Names and registration order are kept: they set the parameter order,
        # the state_dict keys and the RNG consumption order during initialisation.
        self.fcv = nn.Linear(din, hidden_dim)
        self.fck = nn.Linear(din, hidden_dim)
        self.fcq = nn.Linear(din, hidden_dim)
        self.fcout = nn.Linear(hidden_dim, dout)

    def forward(self, x, mask):
        """Attend over the rows of ``x``, restricted by ``mask``.

        Args:
            x: 3-D tensor, since torch.bmm requires batched matrices; assumed to be
                (batch, N, din).
            mask: multiplied element-wise with the (batch, N, N) score matrix.
                Entries of 1 keep a pair and entries of 0 exclude it. Presumably
                a 0/1 adjacency matrix — TODO confirm with callers.

        Returns:
            Tensor of shape (batch, N, dout), non-negative because of the final ReLU.
        """
        values = F.relu(self.fcv(x))
        queries = F.relu(self.fcq(x))
        keys_t = F.relu(self.fck(x)).permute(0, 2, 1)
        scores = torch.bmm(queries, keys_t)
        # Excluded pairs get a zeroed score minus a huge constant, so their softmax
        # weight is ~0. Note that no 1/sqrt(d) scaling is applied to the scores.
        # NOTE(review): 9e15 is outside the float16 range — verify behaviour
        # before running this layer in half precision.
        logits = scores * mask - 9e15 * (1 - mask)
        weights = F.softmax(logits, dim=2)
        # A residual connection (mixed + values) exists in the source only as a
        # disabled, commented-out line; it is intentionally not applied here.
        mixed = torch.bmm(weights, values)
        return F.relu(self.fcout(mixed))
39 
class Q_Net(nn.Module):
    """Linear output head mapping ``hidden_dim`` features to ``dout`` values.

    No activation is applied, so the outputs are unbounded.

    Args:
        hidden_dim: size of the last dimension of the input.
        dout: size of the last dimension of the output.
    """

    def __init__(self, hidden_dim, dout):
        super().__init__()
        # The attribute name ``fc`` fixes the state_dict keys.
        self.fc = nn.Linear(hidden_dim, dout)

    def forward(self, x):
        """Return the raw linear projection ``fc(x)``."""
        return self.fc(x)
View Code

 

 1 class DGN(nn.Module):
 2     def __init__(self,n_agent,num_inputs,hidden_dim,num_actions):
 3         super(DGN, self).__init__()
 4         
 5         self.encoder = Encoder(num_inputs,hidden_dim)
 6         self.att_1 = AttModel(n_agent,hidden_dim,hidden_dim,hidden_dim)
 7         self.att_2 = AttModel(n_agent,hidden_dim,hidden_dim,hidden_dim)
 8         self.q_net = Q_Net(hidden_dim,num_actions)
 9         
10     def forward(self, x, mask):
11         h1 = self.encoder(x)
12         h2 = self.att_1(h1, mask)
13         h3 = self.att_2(h2, mask)
14         q = self.q_net(h3)
15         return q 

 

在调试器的监视（Watch）窗口中查看调用结果，可以看到：

`model(x, mask)` 的返回值是 Tensor 类型。model 本身是一个 `nn.Module` 对象，调用它得到的输出才是 Tensor。

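因此 `model(input)[0]` 是对返回的 Tensor 沿第 0 维（batch 维）做索引，取出的是该 batch 中第一个样本的输出，而不是整个 batch。以上面的 DGN 为例，输入 x 是 3 维张量（`torch.bmm` 要求如此），输出形状为 (batch, N, num_actions)；取 `[0]` 之后，形状为 (N, num_actions)。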

posted @ 2021-12-02 09:51  ArkiWang  阅读(651)  评论(0编辑  收藏  举报