Loading

transformer 中的bert是如何初始化的

动机:在看BertForMaskedLM 的实现代码时,发现在class init的时候有一个self.post_init() 函数,希望看一下它内部调用的哪个函数,以及如果我们自己定义了一些新的模型参数或者embedding怎么进行初始化?

在代码里有两个init_weights 函数,分别是post_init调用的,另一个我们可以用于初始化我们自己的参数:
1.def init_weights(self):

 def init_weights(self):
        """
        If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
        initialization logic in `_init_weights`.
        """
        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        if _init_weights:
            # Initialize weights
            self.apply(self._initialize_weights)

            # Tie weights should be skipped when not initializing all weights
            # since from_pretrained(...) calls tie weights anyways
            self.tie_weights()

Q: self.apply(self._initialize_weights) 是什么意思呢?
A: 在 Hugging Face 的 Transformers 库中,self.apply(self._initialize_weights) 这行代码的作用是在模型对象创建时,将 _initialize_weights 方法应用到模型的所有参数中,从而初始化所有的权重。

具体来说,self.apply() 是 PyTorch 中 nn.Module 的一个方法,用于将一个函数应用到模型的所有模块中。这个函数可以是任意的自定义函数,它会遍历模型的所有模块,并对每个模块应用这个函数。在这个过程中,模块本身作为参数传递给函数,因此函数可以对模块进行任意的操作。

在这里,我们将 self._initialize_weights() 方法作为参数传递给 self.apply() 方法,这样就可以将这个方法应用到模型的所有模块中。由于 _initialize_weights() 方法会遍历模型的所有模块,并对每个模块的权重进行初始化,因此这个操作可以确保模型的所有权重都被正确地初始化了。

以下是 self.apply() 方法的代码实现,可以看到,这个方法会递归遍历模型的所有子模块,并对每个子模块调用 fn() 函数,最终还会对模型本身调用 fn() 函数。在我们的例子中,fn() 函数就是 _initialize_weights() 方法。

def apply(self, fn):
    """Applies ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
    Typical use includes initializing the parameters of a model (see also :meth:`torch.nn.Module.apply`).
    """
    for module in self.children():
        module.apply(fn)
    fn(self)

Q: _initialize_weights() 方法是怎样定义的呢
A: 在 Hugging Face 的 Transformers 库中,PreTrainedModel 类的 _initialize_weights() 方法是在 modeling_utils.py 文件中定义的。这个方法用于初始化模型的权重,并且在模型的构造函数中被调用。

具体来说,_initialize_weights() 方法会遍历模型的所有模块,对每个模块的权重进行初始化。如果模块包含可训练的权重,则会使用 PyTorch 的默认的权重初始化方法来初始化这些权重。如果模块是一个 nn.Linear 层,则会使用正态分布的方法来初始化权重,同时使用小的偏置值来避免死亡 ReLU 问题。

以下是 _initialize_weights() 方法的代码实现,注意到这个方法会根据模块的类型不同采取不同的初始化方法,以确保每个权重都能够得到合适的初始化。

def _initialize_weights(self, module):
    """Initialize the weights."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # 使用默认的初始化方法来初始化可训练的权重
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
        # 避免死亡 ReLU 问题,使用小的偏置值来初始化
        module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        # 初始化层归一化的权重为 1,偏置为 0
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

2.def _init_weights(self, module):

 def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

这个函数和刚刚的_initialize_weights作用一致,_initialize_weights也可以写成如下样子:

def _initialize_weights(self, module):
        """
        Initialize the weights if they are not already initialized.
        """
        if getattr(module, "_is_hf_initialized", False):
            return
        self._init_weights(module)
        module._is_hf_initialized = True

Q: 我们如果有新定义的参数如何进行初始化?
A:如果你初始化了新的embedding,你可以如下方式初始化

# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)

其他:
我们一般会用 from_pretrained 加载预训练好的模型参数,在这种情况下,大概的模型加载流程如下:

  1. 找到正确的基础模型类进行初始化
  2. 使用伪随机初始化来初始化该类(通过使用_init_weights您提到的函数)
  3. 找到具有预训练权重的文件
  4. 在适用的情况下使用预先训练的权重覆盖我们刚刚创建的模型的权重,在初始化参数时,如果模型结构与预训练模型不同,那么只有与预训练模型相同的部分才会被初始化。
posted @ 2023-03-15 20:34  戴墨镜的长颈鹿  阅读(1758)  评论(0编辑  收藏  举报