CausalTAD Explained
The code combines a VAE with a confidence model for trajectory embedding and anomaly detection. Here's how the two work together:
The Confidence Model:
- Takes trajectory sequences and embeds them using an embedding layer
- Encodes into latent space using mu and sigma encoders
- Samples from the latent distribution and decodes
- Computes a confidence score representing how "normal" each trajectory point is
- Returns NLL loss + KL divergence as the confidence measure
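A minimal sketch of what such a confidence module could look like (the class name, the GRU encoder, and the per-step reconstruction target are illustrative assumptions, not the original identifiers):
import torch
import torch.nn as nn

class Confidence(nn.Module):
    """Hypothetical confidence module: embed -> mu/sigma -> sample -> decode -> NLL + KL."""
    def __init__(self, num_nodes, hidden_size=128):
        super().__init__()
        self.embedding = nn.Embedding(num_nodes, hidden_size)
        self.encoder = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.mu = nn.Linear(hidden_size, hidden_size)
        self.logvar = nn.Linear(hidden_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, num_nodes)

    def forward(self, src):
        x = self.embedding(src)                       # (B, T, H), src holds road-segment ids
        h, _ = self.encoder(x)                        # per-step hidden states
        mu, logvar = self.mu(h), self.logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)   # reparameterization trick
        logits = self.decoder(z)                      # reconstruct the segment at each step
        nll = nn.functional.cross_entropy(logits.transpose(1, 2), src, reduction='none')     # (B, T)
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)                  # (B, T)
        return nll + kl                               # per-step confidence; lower = more "normal"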
The VAE Architecture:
- Encoder:
  - Takes the source trajectory and encodes it into a latent distribution (mu, sigma)
  - Uses an RNN for sequence encoding
- Decoder:
  - Samples from the latent space
  - Decodes the sample into a trajectory reconstruction
  - Has two decoders:
    - The main decoder reconstructs the full trajectory
    - The SD decoder predicts the source-destination pair
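A rough sketch of this dual-decoder VAE, assuming GRU encoder/decoders (the layer names, and the use of already-embedded inputs, are assumptions for illustration):
import torch
import torch.nn as nn

class TrajectoryVAE(nn.Module):
    """Illustrative RNN-based VAE with a main decoder and an SD (source-destination) decoder."""
    def __init__(self, hidden_size=128, num_nodes=10000):
        super().__init__()
        self.encoder = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.mu = nn.Linear(hidden_size, hidden_size)
        self.logvar = nn.Linear(hidden_size, hidden_size)
        self.decoder = nn.GRU(hidden_size, hidden_size, batch_first=True)     # full trajectory
        self.sd_decoder = nn.GRU(hidden_size, hidden_size, batch_first=True)  # source/destination only
        self.out = nn.Linear(hidden_size, num_nodes)
        self.sd_out = nn.Linear(hidden_size, num_nodes)

    def forward(self, src, trg):
        # src, trg: already-embedded trajectories of shape (B, T, hidden_size)
        _, h = self.encoder(src)                      # final hidden state (1, B, H)
        mu, logvar = self.mu(h), self.logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
        dec, _ = self.decoder(trg, z)                 # teacher-forced reconstruction
        sd = torch.cat([src[:, :1], src[:, -1:]], dim=1)   # embedded start and end points
        sd_dec, _ = self.sd_decoder(sd, z)
        return kl, self.out(dec), self.sd_out(sd_dec)     # kl_loss, p_x, sd_p_x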
Training flow:
confidence = self.confidence(src) # Get confidence scores
src = self.road_embedding(src) # Embed road segments
kl_loss, p_x, sd_p_x = self.vae(src, trg) # VAE encoding/decoding
nll_loss = self.loss_fn(p_x, label, mask) # Reconstruction loss
The combined architecture detects anomalies in several ways:
- Confidence scores identify unusual patterns
- VAE reconstruction error highlights deviations from normal trajectories
- SD decoder ensures start/end points are reasonable
- Masked loss focuses on valid road segment transitions
All losses are combined to train the model to recognize normal vs anomalous trajectories.
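As a rough illustration (the loss weights, masking, and length normalization below are assumptions, not the paper's exact recipe), training sums the terms, and inference scores each trajectory by its normalized reconstruction NLL:
import torch

def combined_loss(nll_loss, kl_loss, confidence, kl_weight=1.0, conf_weight=1.0):
    # Hypothetical weighting of the three training terms
    return nll_loss + kl_weight * kl_loss + conf_weight * confidence.mean()

def anomaly_score(logits, label, mask):
    # Length-normalized NLL of the observed segments; higher = more anomalous
    log_p = torch.log_softmax(logits, dim=-1)                      # (B, T, num_nodes)
    nll = -log_p.gather(-1, label.unsqueeze(-1)).squeeze(-1)       # (B, T)
    return (nll * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # (B,)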
Revised implementation, using a Transformer in place of the RNN:
import torch
import torch.nn as nn
import math
class TransformerTrajectoryModel(nn.Module):
def __init__(self, num_nodes, d_model=128, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=512, dropout=0.1):
super().__init__()
self.d_model = d_model
# Node embeddings
self.node_embedding = nn.Embedding(num_nodes, d_model)
self.pos_encoder = PositionalEncoding(d_model, dropout)
# Transformer
self.transformer = nn.Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=True
)
# Confidence module
self.confidence_encoder = nn.Sequential(
nn.Linear(d_model, d_model),
nn.ReLU(),
nn.Linear(d_model, d_model)
)
self.confidence_mu = nn.Linear(d_model, d_model)
self.confidence_logvar = nn.Linear(d_model, d_model)
# Trajectory VAE
self.encoder_mu = nn.Linear(d_model, d_model)
self.encoder_logvar = nn.Linear(d_model, d_model)
        self.decoder = nn.Linear(d_model, num_nodes)  # extra reconstruction head (not used in forward below)
# Output projection
self.output_projection = nn.Linear(d_model, num_nodes)
    def generate_mask(self, sz, device):
        # Causal mask for autoregressive decoding: -inf above the diagonal, 0.0 elsewhere
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)
def encode(self, src, src_mask=None):
src = self.node_embedding(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
memory = self.transformer.encoder(src, src_mask)
return memory
def decode(self, tgt, memory, tgt_mask=None, memory_mask=None):
tgt = self.node_embedding(tgt) * math.sqrt(self.d_model)
tgt = self.pos_encoder(tgt)
output = self.transformer.decoder(tgt, memory, tgt_mask, memory_mask)
return output
def compute_confidence(self, x):
h = self.confidence_encoder(x)
mu = self.confidence_mu(h)
logvar = self.confidence_logvar(h)
std = torch.exp(0.5 * logvar)
# Sample using reparameterization trick
eps = torch.randn_like(std)
z = mu + eps * std
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
return z, kl_div
def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
# Encode trajectory
memory = self.encode(src, src_mask)
# Compute confidence
confidence_latent, confidence_kl = self.compute_confidence(memory)
# Encode to latent space
mu = self.encoder_mu(memory)
logvar = self.encoder_logvar(memory)
std = torch.exp(0.5 * logvar)
# Sample using reparameterization trick
eps = torch.randn_like(std)
z = mu + eps * std
# Compute VAE KL divergence
vae_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
# Decode
if tgt_mask is None:
tgt_mask = self.generate_mask(tgt.size(1), tgt.device)
output = self.decode(tgt, z, tgt_mask, memory_mask)
# Project to vocabulary space
logits = self.output_projection(output)
return {
'logits': logits,
'confidence_kl': confidence_kl,
'vae_kl': vae_kl
}
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
Key changes from the original:
- Uses Transformer encoder-decoder instead of RNN
- Maintains same VAE + confidence structure but implements with self-attention
- Adds positional encoding for sequence information
- Uses masked self-attention for autoregressive decoding
- Combines confidence and VAE latent spaces through the transformer memory
Usage remains similar, but the model benefits from parallel processing and better modeling of long-range dependencies.
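For example, a minimal usage sketch (the teacher-forcing shift and the KL weights are assumptions about how one might train this model):
import torch
import torch.nn as nn

model = TransformerTrajectoryModel(num_nodes=10000)
src = torch.randint(0, 10000, (32, 20))      # batch of road-segment id sequences
tgt = src.clone()                            # reconstruct the same trajectory

out = model(src, tgt[:, :-1])                # teacher forcing: predict the next segment
recon = nn.functional.cross_entropy(out['logits'].reshape(-1, 10000), tgt[:, 1:].reshape(-1))
loss = recon + 0.1 * out['vae_kl'].mean() + 0.1 * out['confidence_kl'].mean()
loss.backward()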