Fair code
Model learning
Let me explain the interpolate_steps and how information transfers between steps:
def forward(self, batch):
    # Initialize
    residue_mask = batch['protein_edit_residue']
    self.pred_res_type = torch.ones(residue_mask.sum(), 20, device=self.device) / 20

    # Backbone prediction stage
    for t in range(self.interpolate_steps):  # 5 steps
        # Key part: ratio changes with steps
        ratio = (self.interpolate_steps - t) / self.interpolate_steps  # 5/5 -> 4/5 -> 3/5 -> 2/5 -> 1/5

        # 1. Position initialization with changing ratio
        label_X, pred_X = copy.deepcopy(batch['protein_pos_backbone']), copy.deepcopy(batch['protein_pos_backbone'])
        label_ligand, pred_ligand = copy.deepcopy(batch['ligand_pos']), copy.deepcopy(batch['ligand_pos'])
        pred_X = interpolation_init(pred_X, residue_mask, ...)
        pred_X = (1 - ratio) * label_X + ratio * pred_X  # gradually reduce noise

        # 2. Information transfers between steps through pred_res_type
        h_ctx, pos_ctx, batch_ctx, mask_protein = self.compose(batch, self.pred_res_type.detach(), pred_X, pred_ligand)
        h_ctx, h_residue, pred_X, pred_ligand = self.encoder(...)

        # 3. Update predictions for next step
        self.pred_res_type = self.residue_mlp(h_residue[residue_mask])
Key points about information transfer:
- The ratio changes with steps:
  - Step 1: ratio = 5/5 = 1.0 (most noise)
  - Step 2: ratio = 4/5 = 0.8
  - Step 3: ratio = 3/5 = 0.6
  - Step 4: ratio = 2/5 = 0.4
  - Step 5: ratio = 1/5 = 0.2 (least noise)
- Information transfers through:
  - pred_res_type: residue type predictions are carried forward between steps
  - Position predictions (pred_X): each step refines positions based on the previous step's output
  - The encoder's hidden states, which process both current and previous step information
- The interpolation process:
  pred_X = (1 - ratio) * label_X + ratio * pred_X
  This gradually moves from noisy predictions (high ratio) to refined predictions (low ratio).
- Progressive refinement:
  - Early steps (high ratio): more exploration, larger changes allowed
  - Later steps (low ratio): more refinement, smaller adjustments
  - Each step builds upon previous predictions while reducing noise
So it's not independent training - it's a progressive refinement process where each step uses information from previous steps while gradually reducing the amount of randomness/noise in the predictions.
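To make the schedule and the interpolation concrete, here is a toy example with made-up coordinates (not the model's actual tensors):

import torch

interpolate_steps = 5
label_X = torch.tensor([[0.0, 0.0, 0.0]])   # stands in for the ground-truth backbone position
init_X  = torch.tensor([[4.0, 4.0, 4.0]])   # stands in for the interpolation_init output

for t in range(interpolate_steps):
    ratio = (interpolate_steps - t) / interpolate_steps   # 1.0, 0.8, 0.6, 0.4, 0.2
    pred_X = (1 - ratio) * label_X + ratio * init_X
    print(t, ratio, pred_X.tolist())
# t=0: ratio=1.0 -> [[4.0, 4.0, 4.0]]  (pure initialization, most noise)
# t=4: ratio=0.2 -> [[0.8, 0.8, 0.8]]  (close to the ground truth)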
Would you like me to explain any specific part of this process in more detail?
Only res_type is transferred; pred_X is not transferred.
You raise a good point. Let me correct my analysis. Looking at the code more carefully:
def forward(self, batch):
    # ...
    for t in range(self.interpolate_steps):
        # Each iteration starts fresh from the batch data
        label_X, pred_X = copy.deepcopy(batch['protein_pos_backbone']), copy.deepcopy(batch['protein_pos_backbone'])
        label_ligand, pred_ligand = copy.deepcopy(batch['ligand_pos']), copy.deepcopy(batch['ligand_pos'])

        # Apply noise based on ratio
        ratio = (self.interpolate_steps - t) / self.interpolate_steps
        pred_ligand += torch.randn_like(pred_ligand).to(self.device) * ratio * 0.5

        # Initialize positions fresh each time
        pred_X = interpolation_init(pred_X, residue_mask, ...)
        pred_X = (1 - ratio) * label_X + ratio * pred_X
You're correct - I was wrong about position predictions transferring between steps. Each iteration:
- Starts fresh by copying positions from the batch
- Applies noise based on the current ratio
- Does interpolation initialization
- Makes predictions
The only thing that actually transfers between steps is pred_res_type (the residue type predictions). The positions are reset and recomputed each time.
Each step is more like an independent prediction with decreasing noise level, rather than a refinement of previous position predictions. Thank you for the correction!
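To make the corrected picture concrete, here is a toy sketch of the training loop, with stub modules and made-up shapes standing in for the real encoder and MLP (not the repository's code):

import torch

interpolate_steps, n_res = 5, 4
batch_pos = torch.randn(n_res, 3)                 # stands in for batch['protein_pos_backbone']
encoder = torch.nn.Linear(3 + 20, 8)              # stub: consumes positions plus the current type distribution
residue_mlp = torch.nn.Linear(8, 20)              # stub for self.residue_mlp

pred_res_type = torch.full((n_res, 20), 1 / 20)   # uniform initialization, as in forward()
for t in range(interpolate_steps):
    ratio = (interpolate_steps - t) / interpolate_steps
    pred_X = batch_pos.clone()                                   # positions reset from the batch every step
    pred_X = (1 - ratio) * batch_pos + ratio * (pred_X + torch.randn_like(pred_X))
    h_residue = encoder(torch.cat([pred_X, pred_res_type.detach()], dim=1))  # previous step's types feed in
    pred_res_type = residue_mlp(h_residue)                       # the only state carried to the next step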
It is split into a backbone stage + full-atom stage, and the residue type does transfer.
Yes, let's look at how pred_res_type transfers from the backbone stage to the full-atom stage:
def forward(self, batch):
    # Backbone stage
    residue_mask = batch['protein_edit_residue']
    self.pred_res_type = torch.ones(residue_mask.sum(), 20, device=self.device) / 20
    # After the backbone stage loops, pred_res_type contains the final backbone predictions

    # Full atom stage
    random_mask = batch['random_mask_residue']
    for t in range(self.interpolate_steps):
        # Uses pred_res_type.detach() from the backbone stage
        h_ctx, pos_ctx, batch_ctx, mask_protein = self.compose(
            batch,
            self.pred_res_type.detach(),  # uses predictions from the backbone stage
            pred_X,
            pred_ligand,
            backbone=False
        )
And in the generate method, the transfer is even more explicit:
def generate(self, batch):
    # Backbone stage first
    # ... predictions made ...

    # Transfer to full atom stage
    select = torch.argmax(self.Softmax(self.pred_res_type), 1).view(-1)
    self.pred_res_type = torch.zeros_like(self.pred_res_type, device=self.device)
    self.pred_res_type[torch.arange(len(select), device=self.device), select] = 1

    # Apply predictions to batch
    batch['amino_acid'][batch['protein_edit_residue']] = select + 1
    batch['random_mask_residue'] = batch['protein_edit_residue']
So yes, the residue type predictions (pred_res_type) do transfer from the backbone stage to the full-atom stage, but they are used in slightly different ways in training vs. generation:
- In training: The backbone predictions are detached and used as input for full atom stage
- In generation: The backbone predictions are converted to hard assignments before full atom stage
This is the main information bridge between the two stages.
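As a small illustration of that difference (toy tensors, not the repository's code), the soft distribution is simply detached during training, while generation collapses it to a hard one-hot:

import torch

pred_res_type = torch.randn(3, 20)                       # fake type logits for 3 edited residues

# Training: the soft predictions are detached and fed into compose() as-is
soft_input = pred_res_type.detach()

# Generation: argmax -> one-hot, then written back into the batch (types are 1-indexed there)
select = torch.argmax(torch.softmax(pred_res_type, dim=1), dim=1)
hard_input = torch.zeros_like(pred_res_type)
hard_input[torch.arange(len(select)), select] = 1
amino_acid_ids = select + 1                              # as in batch['amino_acid'][...] = select + 1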
The prediction mode and the training mode are not the same.
Let me analyze the key differences between training and testing (generate) phases:
- Structure Differences:
  - Training (forward method):
    - Backbone stage: 5 steps without position info transfer
    - Full atom stage: 5 steps without refinement loops
  - Testing (generate method):
    - Backbone stage: 5 steps with position transfer
    - Full atom stage:
      - Outer loop: 10 iterations of self-consistent refinement
      - Inner loop: 5 refinement steps per iteration
- Position Handling:
  - Training:
    - Each step resets positions from the batch
    - Applies random noise based on the ratio
    - No transfer between steps
  - Testing:
    - Backbone positions are continuously refined: pred_X is updated in each step
    - Full atom positions are refined in both the 5 refinement steps and the 10 self-consistent iterations
- Key Testing Features:

# Self-consistent iterations (not in training)
for t in range(1, 10):              # outer loop
    for s in range(5):              # inner refinement
        # Refine positions
        h_ctx, h_residue, batch['protein_pos'], batch['ligand_pos'] = self.encoder(...)
    # Random masking and updating (except last iteration)
    if t != 10:
        batch = random_mask(batch, mask=True)
        # Update residue types
        batch['amino_acid'][batch['random_mask_residue']] = torch.argmax(...)
- Evaluation:
# Testing has explicit metrics
aar = (label_S[residue_mask] == batch['amino_acid'][residue_mask]).sum() / len(label_S[residue_mask]) # Sequence accuracy
rmsd = torch.sqrt((label_X[atom_mask] - batch['protein_pos'][batch['edit_backbone']]).norm(dim=1).sum() / atom_mask.sum()) # Structure accuracy
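For reference, here is a toy computation of those two expressions with made-up tensors (mirroring the quoted formulas, not a general metric implementation):

import torch

label_S = torch.tensor([3, 7, 7, 12])                   # fake ground-truth residue types
pred_S  = torch.tensor([3, 7, 5, 12])                   # fake predictions
aar = (label_S == pred_S).sum() / len(label_S)          # amino-acid recovery, here 0.75

label_X = torch.randn(16, 3)                            # fake backbone-atom coordinates
pred_X  = label_X + 0.1 * torch.randn(16, 3)
rmsd = torch.sqrt((label_X - pred_X).norm(dim=1).sum() / label_X.shape[0])  # same form as the quoted line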
The main differences are:
- Testing has more complex iterative refinement structure
- Testing maintains continuous updates to positions and residue types
- Testing uses self-consistent iterations to improve predictions
- Testing includes explicit evaluation metrics
- Testing uses less random noise and is more deterministic
This suggests the testing phase is focused on getting the best possible predictions through iterative refinement, while training is more about learning robust representations through randomization and independent steps.
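Put side by side, the two control flows look roughly like this (a schematic sketch of the loop structure only, not the actual methods):

# Training (forward): backbone stage, then full-atom stage; 5 interpolation steps each,
# with positions reset from the batch at every step
for t in range(5):        # backbone stage
    ...                   # copy positions, add ratio-scaled noise, predict once
for t in range(5):        # full-atom stage
    ...                   # same pattern, reusing pred_res_type.detach() from the backbone stage

# Generation (generate): backbone stage, then 10 self-consistent iterations x 5 refinement steps,
# with positions carried forward and residue types re-masked between iterations
for t in range(1, 10):
    for s in range(5):
        ...               # refine batch['protein_pos'] / batch['ligand_pos'] in place
    ...                   # re-mask residues and update batch['amino_acid']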