TrajPreModel

轨迹预测模型

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

#######################################
class TrajPreModel(nn.Module):
    """Self-attention next-location (trajectory) prediction model.

    Pipeline: location-id embedding -> dropout -> self-attention ->
    dropout -> projection onto the location vocabulary -> log-softmax.

    Args:
        loc_size: Number of distinct location ids (embedding rows / output
            classes).
        loc_emb_size: Embedding dimension; also passed to
            ``MultiSelfAttention`` as its model width.
        hidden_size: Stored on the instance; not used by this class.
        head_num: Number of attention heads passed to ``MultiSelfAttention``.
        dropout_p: Dropout probability for both dropout applications and for
            ``MultiSelfAttention``.
        is_weight_sharing: Keyword-only. If True, the output projection reuses
            the embedding matrix (``F.linear(output, self.weight)``) instead of
            ``self.fc``. Defaults to False, the previously hard-coded value.
    """

    def __init__(self, loc_size=528, loc_emb_size=128, hidden_size=32,
                 head_num=1, dropout_p=0, *, is_weight_sharing=False):
        super().__init__()
        self.loc_size = loc_size
        self.loc_emb_size = loc_emb_size
        self.hidden_size = hidden_size
        self.heads = head_num
        self.dropout_p = dropout_p

        # Embedding table of shape (loc_size, loc_emb_size).
        self.emb_loc = nn.Embedding(self.loc_size, self.loc_emb_size)
        # Alias to the same Parameter as the embedding table; it is the
        # projection matrix when weight sharing is enabled. Kept as an
        # attribute because removing it would likely change checkpoint keys.
        self.weight = self.emb_loc.weight

        # ------------- model -------------
        self.attention = MultiSelfAttention(
            self.heads, self.loc_emb_size, dropout=self.dropout_p)
        # Created unconditionally so the parameter set does not depend on
        # the weight-sharing flag.
        self.fc = nn.Linear(self.loc_emb_size, self.loc_size)
        self.is_weight_sharing = is_weight_sharing
        self.init_weights()
        self.dropout = nn.Dropout(p=dropout_p)

    def init_weights(self):
        """Initialise parameters selected by name.

        - names containing ``weight_ih``: Xavier-uniform init
        - names containing ``weight_hh``: orthogonal init
        - names containing ``bias``: set to 0

        The three passes run in this order, as before, so the random-number
        consumption (and hence the initial values) is unchanged. This class
        itself defines no parameters named ``weight_ih``/``weight_hh``; whether
        ``MultiSelfAttention`` does depends on its (unseen) implementation.
        """
        params = list(self.named_parameters())
        for name, param in params:
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param.data)
        for name, param in params:
            if 'weight_hh' in name:
                nn.init.orthogonal_(param.data)
        for name, param in params:
            if 'bias' in name:
                nn.init.constant_(param.data, 0)

    def forward(self, x):
        """Return log-probabilities over locations for each input position.

        Args:
            x: Indexable input whose element ``x[1]`` holds location ids for
                ``nn.Embedding``. Presumably shaped (batch_size, seq_len);
                TODO confirm with the caller.

        Returns:
            Log-softmax scores over the ``loc_size`` locations, reshaped with
            ``view(-1, loc_size)``. For a 2-D ``x[1]`` this is
            (batch_size * seq_len, loc_size).
        """
        seq = x[1]  # location ids
        loc_emb = self.emb_loc(seq)
        output = self.dropout(loc_emb)

        # Self-attention: query, key and value are the same tensor.
        # Assumes MultiSelfAttention returns a tensor whose last dim is
        # loc_emb_size; verify against its implementation.
        output = self.attention(output, output, output)
        output = self.dropout(output)

        if self.is_weight_sharing:
            # Tied output layer: logits are dot products with each location
            # embedding (weight is (loc_size, loc_emb_size)); no bias term.
            y = F.linear(output, self.weight)
        else:
            y = self.fc(output)

        score = F.log_softmax(y, dim=-1)
        # Flatten all leading dimensions into one: (N, loc_size).
        return score.view(-1, self.loc_size)

posted @ 2020-05-19 22:15  li修远  阅读(116)  评论(0)  编辑  收藏  举报