Loading

MaskFormer代码理解

MaskFormer代码理解

模型主体

代码全部放在/mask_former下面。进入到这个路径,首先发现有一个mask_former_model.py,打开可以发现这就是论文提出的MaskFormer主类所在的文件。先看forward函数的前面五行:

def forward(self, batched_inputs):
  images = [x["image"].to(self.device) for x in batched_inputs]
  images = [(x - self.pixel_mean) / self.pixel_std for x in images]
  images = ImageList.from_tensors(images, self.size_divisibility)

  features = self.backbone(images.tensor)
  outputs = self.sem_seg_head(features)
  ...

前三行是对输入进行处理,之后将处理过的image送入backbone得到features,再将features通过sem_seg_head得到output。这里的backbone和sem_seg_head都是通过类里的from_config函数构建的。配置文件都在项目的configs路径下面,随便打开一个文件夹,例如ade20k-150,会发现下面有一个Base-ADE20K-150.yaml,其他的yaml文件都是继承自该文件,此外还有一个文件夹为swin,存放swin transformer作为backbone的配置文件(都继承自maskformer_R50_bs16_160k.yaml)。从这些配置文件得知,backbone可以为resnet或swin,而sem_seg_head为MaskFormerHead。

先看backbone,以swin为例,代码存放于/mask_former/backbone/swin.py。forward函数如下,返回的是提取到的特征:

def forward(self, x):
  ......
  for i in range(self.num_layers):
    layer = self.layers[i]
    x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

    if i in self.out_indices:
      norm_layer = getattr(self, f"norm{i}")
      x_out = norm_layer(x_out)

      out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
      outs["res{}".format(i + 2)] = out

      return outs

backbone处理得到的特征会送入MaskFormerHead。代码存放于/mask_former/heads/mask_former_head.py。按惯例先看forward:

def forward(self, features):
  return self.layers(features)
def layers(self, features):
  mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
  if self.transformer_in_feature == "transformer_encoder":
    assert (
      transformer_encoder_features is not None
    ), "Please use the TransformerEncoderPixelDecoder."
    predictions = self.predictor(transformer_encoder_features, mask_features)
  else:
    predictions = self.predictor(features[self.transformer_in_feature], mask_features)
    return predictions

forward调用了self.layers,layers函数干的事就是首先通过self.pixel_decoder.forward_features对features进行处理得到mask_features以及transformer_encoder_features,之后通过self.predictor计算得到predictions。

根据之前的经验,发现self.pixel_decoder是根据配置文件构建的。这里如果去找配置文件,会发现只有一部分配置文件如maskformer_panoptic_R50_bs64_554k.yaml指出PIXEL_DECODER_NAME: "TransformerEncoderPixelDecoder",原因暂时还不清楚。总之,可以确定pixel_decoder就是TransformerEncoderPixelDecoder,代码存放于/mask_former/heads/pixel_decoder.py。TransformerEncoderPixelDecoder的forward_featuers函数如下:

def forward_features(self, features):
  # Reverse feature maps into top-down order (from low to high resolution)
  for idx, f in enumerate(self.in_features[::-1]):
    x = features[f]
    lateral_conv = self.lateral_convs[idx]
    output_conv = self.output_convs[idx]
    if lateral_conv is None:
      transformer = self.input_proj(x)
      pos = self.pe_layer(x)
      transformer = self.transformer(transformer, None, pos)
      y = output_conv(transformer)
      # save intermediate feature as input to Transformer decoder
      transformer_encoder_features = transformer
    else:
      cur_fpn = lateral_conv(x)
      # Following FPN implementation, we use nearest upsampling here
      y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
      y = output_conv(y)
      return self.mask_features(y), transformer_encoder_features

注释说的很明白,这个函数的作用就是将低分辨率的特征上采样得到高分辨率的特征。self.lateral_convs和self.output_convs都是定义在父类BasePixelDecoder中,将输入特征的通道维度变为conv_dim。特别地,对于最后一个特征,其对应的lateral_conv为None,此时需要用self.transformer进行处理。self.transformer在TransformerEncoderPixelDecoder中定义,本质是TransformerEncoderOnly这个类的对象,其forward函数做的就是将投影后的输入特征通过TransformerEncoderLayer计算self-attention:

def forward(self, src, mask, pos_embed):
  # flatten NxCxHxW to HWxNxC
  bs, c, h, w = src.shape
  src = src.flatten(2).permute(2, 0, 1)
  pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
  if mask is not None:
    mask = mask.flatten(1)

    memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
    return memory.permute(1, 2, 0).view(bs, c, h, w)

注意最后返回的是memory.permute(1, 2, 0).view(bs, c, h, w),这是因为计算self-attention时进行了维度的调整。总之,TransformerEncoderPixelDecoder进行forward以后返回两个结果:经过self.mask_features将通道调整为mask_dim的特征以及self.transformer输出的特征(通道为conv_dim)。

再回到MaskFormerHead的forward函数,接下来要做的就是:

  if self.transformer_in_feature == "transformer_encoder":
    assert (
      transformer_encoder_features is not None
    ), "Please use the TransformerEncoderPixelDecoder."
    predictions = self.predictor(transformer_encoder_features, mask_features)
  else:
    predictions = self.predictor(features[self.transformer_in_feature], mask_features)
    return predictions

这个分支的含义是根据self.transformer_in_feature找到对应的特征,和mask_features一起送入self.predictor。从from_config函数可以知道self.transformer_in_feature来自于cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,而transformer_predictor是TransformerPredictor这个类的对象。先看配置文件,TRANSFORMER_IN_FEATURE有的是res5,有的是transformer_encoder,那就接着往下看,找到TransformerPredictor这个类,其位于/mask_former/modeling/transformer/transformer_predictor.py,查看其forward函数:

def forward(self, x, mask_features):
  pos = self.pe_layer(x)

  src = x
  mask = None
  hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)

  if self.mask_classification:
    outputs_class = self.class_embed(hs)
    out = {"pred_logits": outputs_class[-1]}
  else:
    out = {}

    if self.aux_loss:
      # [l, bs, queries, embed]
      mask_embed = self.mask_embed(hs)
      outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
      out["pred_masks"] = outputs_seg_masks[-1]
      out["aux_outputs"] = self._set_aux_loss(
        outputs_class if self.mask_classification else None, outputs_seg_masks
      )
    else:
      # FIXME h_boxes takes the last one computed, keep this in mind
      # [bs, queries, embed]
      mask_embed = self.mask_embed(hs[-1])
      outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
      out["pred_masks"] = outputs_seg_masks
      return out

pos是x的位置编码,首先将x、mask_features以及self.query_embed计算得到hs和memory。这里的x与mask_features都是前面提到的,而self.query_embed对应的是论文中的那N个query,是可学习的参数:

self.query_embed = nn.Embedding(num_queries, hidden_dim)

而self.transformer是Transformer类的对象,包含一个完整的encode和decode的过程:

def forward(self, src, mask, query_embed, pos_embed):
	# flatten NxCxHxW to HWxNxC
  bs, c, h, w = src.shape
  src = src.flatten(2).permute(2, 0, 1)
  pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
  query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
  if mask is not None:
  	mask = mask.flatten(1)

  tgt = torch.zeros_like(query_embed)
  memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
  hs = self.decoder(
    tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
  )
  return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

得到hs和memory以后就可以计算概率了:

  if self.mask_classification:
    outputs_class = self.class_embed(hs)
    out = {"pred_logits": outputs_class[-1]}
  else:
    out = {}

这里的self.class_embed就是nn.Linear(hidden_dim, num_classes + 1)。注意这里取outputs_class[-1]作为red_logits是因为这里的hs实际上是前面decoder中一系列layer的输出拼接的结果(为了计算辅助损失,需要用到中间输出),所以pred_logits取的是最后一个layer输出计算得到的结果,对应到论文就是fig2中的N class predictions。

接下来的部分为:

if self.aux_loss:
  # [l, bs, queries, embed]
  mask_embed = self.mask_embed(hs)
  outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
  out["pred_masks"] = outputs_seg_masks[-1]
  out["aux_outputs"] = self._set_aux_loss(
    outputs_class if self.mask_classification else None, outputs_seg_masks
  )
else:
  # FIXME h_boxes takes the last one computed, keep this in mind
  # [bs, queries, embed]
  mask_embed = self.mask_embed(hs[-1])
  outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
  out["pred_masks"] = outputs_seg_masks

self.aux_loss为True则计算辅助损失,这里先不去管,只看else这个分支。self.mask_embed是MLP(hidden_dim, hidden_dim, mask_dim, 3),这是为了获得embedding space大小为mask_dim的N个mask embedding,这样就可以与之前算的per-pixel embeddings(也就是代码中的mask_features)计算得到N个mask了。

损失函数

有了out["pred_masks"]和out["pred_logits"]就可以进行损失函数计算了。回到MaskFormer这个类的forward函数,在训练时,计算得到outputs以后进入如下分支:

if self.training:
  # mask classification target
  if "instances" in batched_inputs[0]:
    gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
    targets = self.prepare_targets(gt_instances, images)
  else:
    targets = None

    # bipartite matching-based loss
  losses = self.criterion(outputs, targets)

  for k in list(losses.keys()):
    if k in self.criterion.weight_dict:
      losses[k] *= self.criterion.weight_dict[k]
    else:
      # remove this loss if not specified in `weight_dict`
      losses.pop(k)

  return losses

criterion是通过SetCriterion构建的,其位于/mask_former/modeling/criterion.py。查看其forward函数:

def forward(self, outputs, targets):
  """This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
  outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

  # Retrieve the matching between the outputs of the last layer and the targets
  indices = self.matcher(outputs_without_aux, targets)

  # Compute the average number of target boxes accross all nodes, for normalization purposes
  num_masks = sum(len(t["labels"]) for t in targets)
  num_masks = torch.as_tensor(
    [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
  )
  if is_dist_avail_and_initialized():
    torch.distributed.all_reduce(num_masks)
    num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

    # Compute all the requested losses
    losses = {}
    for loss in self.losses:
      losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))

      # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
  if "aux_outputs" in outputs:
    for i, aux_outputs in enumerate(outputs["aux_outputs"]):
      indices = self.matcher(aux_outputs, targets)
      for loss in self.losses:
        l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
        l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
        losses.update(l_dict)

  return losses

outputs_without_aux就是将之前的输出去掉了和辅助损失计算有关的项。然后通过self.matcher将目标与最后一层的输出进行匹配。由MaskFormer的from_config函数可知,self.matcher是HungarianMatcher类的对象,其位于mask_former/modeling/matcher.py。查看其forward函数,发现执行的是memory_efficient_forward:

def memory_efficient_forward(self, outputs, targets):
  """More memory-friendly matching"""
  bs, num_queries = outputs["pred_logits"].shape[:2]

  # Work out the mask padding size
  masks = [v["masks"] for v in targets]
  h_max = max([m.shape[1] for m in masks])
  w_max = max([m.shape[2] for m in masks])

  indices = []

  # Iterate through batch size
  for b in range(bs):

    out_prob = outputs["pred_logits"][b].softmax(-1)  # [num_queries, num_classes]
    out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]

    tgt_ids = targets[b]["labels"]
    # gt masks are already padded when preparing target
    tgt_mask = targets[b]["masks"].to(out_mask)

    # Compute the classification cost. Contrary to the loss, we don't use the NLL,
    # but approximate it in 1 - proba[target class].
    # The 1 is a constant that doesn't change the matching, it can be ommitted.
    cost_class = -out_prob[:, tgt_ids]

    # Downsample gt masks to save memory
    tgt_mask = F.interpolate(tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest")

    # Flatten spatial dimension
    out_mask = out_mask.flatten(1)  # [batch_size * num_queries, H*W]
    tgt_mask = tgt_mask[:, 0].flatten(1)  # [num_total_targets, H*W]

    # Compute the focal loss between masks
    cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask)

    # Compute the dice loss betwen masks
    cost_dice = batch_dice_loss(out_mask, tgt_mask)

    # Final cost matrix
    C = (
      self.cost_mask * cost_mask
      + self.cost_class * cost_class
      + self.cost_dice * cost_dice
    )
    C = C.reshape(num_queries, -1).cpu()

    indices.append(linear_sum_assignment(C))
    return [
      (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
      for i, j in indices
    ]

out_prob和out_mask是当前这个batch的N个query对应每个类的概率以及N个query对应的mask。

TODO 剩余训练与推理部分待补充

posted @ 2023-02-06 22:59  脂环  阅读(1173)  评论(0编辑  收藏  举报