"""
Feature extractor class for Audio Spectrogram Transformer.
"""
# 导入必要的库
from typing import List, Optional, Union
import numpy as np # 导入 NumPy 库
# 导入音频处理相关的函数和类
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, is_speech_available, is_torch_available, logging
# 如果 TorchAudio 可用,则导入相应的模块
if is_speech_available():
import torchaudio.compliance.kaldi as ta_kaldi
# 如果 Torch 可用,则导入 Torch 库
if is_torch_available():
import torch
# 获取日志记录器
logger = logging.get_logger(__name__)
# 定义 Audio Spectrogram Transformer (AST) 特征提取器类
class ASTFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs an Audio Spectrogram Transformer (AST) feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation.
Args:
feature_size (`int`, *optional*, defaults to 1):
The feature dimension of the extracted features.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
num_mel_bins (`int`, *optional*, defaults to 128):
Number of Mel-frequency bins.
max_length (`int`, *optional*, defaults to 1024):
Maximum length to which to pad/truncate the extracted features.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the log-Mel features using `mean` and `std`.
mean (`float`, *optional*, defaults to -4.2677393):
The mean value used to normalize the log-Mel features. Uses the AudioSet mean by default.
std (`float`, *optional*, defaults to 4.5689974):
The standard deviation value used to normalize the log-Mel features. Uses the AudioSet standard deviation
by default.
return_attention_mask (`bool`, *optional*, defaults to `False`):
Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`.
"""
model_input_names = ["input_values", "attention_mask"] # 定义模型输入的名称列表,包括输入值和注意力掩码
def __init__(
    self,
    feature_size=1,
    sampling_rate=16000,
    num_mel_bins=128,
    max_length=1024,
    padding_value=0.0,
    do_normalize=True,
    mean=-4.2677393,
    std=4.5689974,
    return_attention_mask=False,
    **kwargs,
):
    """
    Store the extraction settings and, when the speech backend is unavailable, precompute the numpy
    filter bank and window used by the fallback path.

    `feature_size`, `sampling_rate`, `padding_value` and any extra keyword arguments are forwarded to the
    parent class; the remaining arguments are kept as attributes of the same name.
    """
    super().__init__(
        feature_size=feature_size,
        sampling_rate=sampling_rate,
        padding_value=padding_value,
        **kwargs,
    )

    self.num_mel_bins = num_mel_bins
    self.max_length = max_length
    self.do_normalize = do_normalize
    self.mean = mean
    self.std = std
    self.return_attention_mask = return_attention_mask

    # The speech backend computes its own filter bank, so nothing else needs to be prepared.
    if is_speech_available():
        return

    filter_bank = mel_filter_bank(
        num_frequency_bins=256,
        num_mel_filters=self.num_mel_bins,
        min_frequency=20,
        max_frequency=sampling_rate // 2,
        sampling_rate=sampling_rate,
        norm=None,
        mel_scale="kaldi",
        triangularize_in_mel_space=True,
    )
    # Append one all-zero row along the first axis
    # (presumably so the bank lines up with the bins of the 512-point FFT used later - TODO confirm).
    self.mel_filters = np.pad(filter_bank, ((0, 1), (0, 0)))
    # Non-periodic Hann window of length 400, matching the frame length used for the spectrogram.
    self.window = window_function(400, "hann", periodic=False)
def _extract_fbank_features(
    self,
    waveform: np.ndarray,
    max_length: int,
) -> np.ndarray:
    """
    Compute log-mel filter bank features and pad/truncate them to exactly `max_length` frames.

    TorchAudio's Kaldi-compatible `fbank` is used when the speech backend is available; otherwise the features
    are computed with the numpy-based `spectrogram` helper. Note that TorchAudio requires 16-bit signed integers
    as inputs and hence the waveform should not be normalized before feature extraction.

    Args:
        waveform (`np.ndarray`):
            Raw audio samples of a single example.
        max_length (`int`):
            Number of frames (rows) of the returned array. Shorter inputs are zero-padded at the end, longer
            inputs are cut off.

    Returns:
        `np.ndarray`: Two-dimensional feature array with `max_length` rows (one row per frame).
    """
    # NOTE: the Kaldi-compliance rescaling `waveform * (2**15)` is left disabled here.
    if is_speech_available():
        waveform = torch.from_numpy(waveform).unsqueeze(0)
        fbank = ta_kaldi.fbank(
            waveform,
            sample_frequency=self.sampling_rate,
            window_type="hanning",
            num_mel_bins=self.num_mel_bins,
        ).numpy()
    else:
        waveform = np.squeeze(waveform)
        # The helper is transposed so that frames end up on the first axis, as in the TorchAudio branch.
        fbank = spectrogram(
            waveform,
            self.window,
            frame_length=400,
            hop_length=160,
            fft_length=512,
            power=2.0,
            center=False,
            preemphasis=0.97,
            mel_filters=self.mel_filters,
            log_mel="log",
            mel_floor=1.192092955078125e-07,
            remove_dc_offset=True,
        ).T

    # Pad or truncate along the frame axis with numpy only, so the fallback branch does not depend on torch.
    difference = max_length - fbank.shape[0]
    if difference > 0:
        fbank = np.pad(fbank, ((0, difference), (0, 0)))
    elif difference < 0:
        fbank = fbank[0:max_length, :]

    return fbank
# 根据给定的均值和标准差对输入值进行标准化处理
def normalize(self, input_values: np.ndarray) -> np.ndarray:
    """Shift `input_values` by `self.mean` and scale by twice `self.std`."""
    centered = input_values - (self.mean)
    spread = self.std * 2
    return centered / spread
# 实现对象的可调用接口,用于处理原始语音数据
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
sampling_rate: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
# 设置文件编码为 UTF-8
# 版权声明,这段代码版权属于 MIT 和 HuggingFace Inc. 团队,保留所有权利
#
# 根据 Apache License, Version 2.0 许可,除非符合许可证要求,否则不得使用此文件
# 您可以在以下网址获取许可证副本:http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,本软件是基于“原样”提供的,没有任何明示或暗示的担保或条件
# 有关详细信息,请参阅许可证
""" PyTorch Audio Spectrogram Transformer (AST) model."""
# 引入数学库
import math
# 引入类型提示
from typing import Dict, List, Optional, Set, Tuple, Union
# 引入 PyTorch 库
import torch
# 引入 PyTorch 的检查点工具
import torch.utils.checkpoint
# 引入 PyTorch 中的神经网络模块
from torch import nn
# 引入 PyTorch 中的损失函数
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
# 引入激活函数映射
from ...activations import ACT2FN
# 引入模型输出
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
# 引入预训练模型工具
from ...modeling_utils import PreTrainedModel
# 引入 PyTorch 实用工具
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
# 引入日志工具
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
# 引入 AST 配置
from .configuration_audio_spectrogram_transformer import ASTConfig
# Module-level logger.
logger = logging.get_logger(__name__)
# Name of the configuration class, passed as `config_class` to the code-sample docstring decorators below.
_CONFIG_FOR_DOC = "ASTConfig"
# Checkpoint passed to the code-sample docstring decorator of the base model.
_CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593"
# Expected output passed to the code-sample docstring decorator of the base model.
_EXPECTED_OUTPUT_SHAPE = [1, 1214, 768]
# Checkpoint passed to the code-sample docstring decorator of the audio classification model.
_SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
# Expected output passed to the code-sample docstring decorator of the audio classification model.
_SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'"
# Expected loss passed to the code-sample docstring decorator of the audio classification model.
_SEQ_CLASS_EXPECTED_LOSS = 0.17
# List of pretrained Audio Spectrogram Transformer checkpoints.
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"MIT/ast-finetuned-audioset-10-10-0.4593",
# See all Audio Spectrogram Transformer models at https://huggingface.co/models?filter=ast
]
# ASTEmbeddings 类定义,继承自 nn.Module
class ASTEmbeddings(nn.Module):
    """
    Construct the CLS token, distillation token, position and patch embeddings.
    """

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()

        width = config.hidden_size
        # Two learnable prefix tokens, each of shape (1, 1, hidden_size).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, width))
        self.distillation_token = nn.Parameter(torch.zeros(1, 1, width))
        self.patch_embeddings = ASTPatchEmbeddings(config)

        # One position embedding per patch, plus two for the prefix tokens.
        n_frequency, n_time = self.get_shape(config)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_frequency * n_time + 2, width))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def get_shape(self, config):
        """
        Return `(frequency_out_dimension, time_out_dimension)` computed with the convolution output-size
        formula `(input - patch) // stride + 1`, as described in Karpathy's cs231n notes:
        https://cs231n.github.io/convolutional-networks/#conv
        """
        patch = config.patch_size
        frequency_out_dimension = (config.num_mel_bins - patch) // config.frequency_stride + 1
        time_out_dimension = (config.max_length - patch) // config.time_stride + 1
        return frequency_out_dimension, time_out_dimension

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Embed `input_values` into patches, prepend the two prefix tokens and add position embeddings."""
        n_examples = input_values.shape[0]
        patch_tokens = self.patch_embeddings(input_values)

        # Broadcast each prefix token over the batch dimension.
        prefix_tokens = tuple(
            token.expand(n_examples, -1, -1) for token in (self.cls_token, self.distillation_token)
        )
        sequence = torch.cat((*prefix_tokens, patch_tokens), dim=1)
        sequence = sequence + self.position_embeddings
        return self.dropout(sequence)
class ASTSelfAttention(nn.Module):
    """
    Multi-head self-attention: projects the hidden states to queries, keys and values, computes scaled
    dot-product attention per head and concatenates the per-head results.
    """

    def __init__(self, config: ASTConfig) -> None:
        """
        Args:
            config: Configuration providing `hidden_size`, `num_attention_heads`, `qkv_bias` and
                `attention_probs_dropout_prob`.

        Raises:
            ValueError: If `hidden_size` is not a multiple of `num_attention_heads` and the config has no
                `embedding_size` attribute.
        """
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections producing queries, keys and values for all heads at once.
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        # Dropout applied to the attention probabilities.
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into `(num_heads, head_size)` and move the head axis to position 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Args:
            hidden_states: Input tensor whose last dimension is `hidden_size`.
            head_mask: Optional tensor multiplied into the attention probabilities.
            output_attentions: Whether to also return the attention probabilities.

        Returns:
            `(context_layer,)`, or `(context_layer, attention_probs)` when `output_attentions` is set.
        """
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Raw attention scores: dot product between queries and keys, scaled by sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the scores to probabilities over the key positions.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Scale the probabilities by the head mask when one is given.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Move the head axis back next to the head-size axis and merge the two into `all_head_size`.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
class ASTSelfOutput(nn.Module):
    """
    Output projection of the attention block. The residual connection is not applied here but in ASTLayer
    (unlike in other models), because the layernorm is applied before each block.
    """

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        width = config.hidden_size
        # Square projection followed by dropout.
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` and apply dropout; `input_tensor` is accepted but not used."""
        return self.dropout(self.dense(hidden_states))
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST
class ASTAttention(nn.Module):
    """Self-attention followed by its output projection, with support for pruning attention heads."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.attention = ASTSelfAttention(config)
        self.output = ASTSelfOutput(config)
        # Heads that have already been pruned from this module.
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given attention heads from the query/key/value and output projections."""
        if not heads:
            return

        self_attention = self.attention
        heads, keep_index = find_pruneable_heads_and_indices(
            heads, self_attention.num_attention_heads, self_attention.attention_head_size, self.pruned_heads
        )

        # Prune the three input projections, then the output projection along dim 1.
        for projection in ("query", "key", "value"):
            pruned = prune_linear_layer(getattr(self_attention, projection), keep_index)
            setattr(self_attention, projection, pruned)
        self.output.dense = prune_linear_layer(self.output.dense, keep_index, dim=1)

        # Update the head bookkeeping and remember which heads were removed.
        self_attention.num_attention_heads = self_attention.num_attention_heads - len(heads)
        self_attention.all_head_size = self_attention.attention_head_size * self_attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Run self-attention and the output projection; any attention weights are passed through."""
        attention_result = self.attention(hidden_states, head_mask, output_attentions)
        projected = self.output(attention_result[0], hidden_states)
        # Everything after the first element (the attention weights, if requested) is forwarded unchanged.
        return (projected, *attention_result[1:])
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
class ASTIntermediate(nn.Module):
    """First half of the feed-forward block: linear expansion followed by the configured activation."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key into ACT2FN or already a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer and then the activation function."""
        return self.intermediate_act_fn(self.dense(hidden_states))
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->AST
class ASTOutput(nn.Module):
    """Second half of the feed-forward block: projection back to `hidden_size`, dropout and residual add."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, apply dropout and add `input_tensor` as the residual."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST
class ASTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Index of the sequence-length dimension.
        self.seq_len_dim = 1

        self.attention = ASTAttention(config)
        self.intermediate = ASTIntermediate(config)
        self.output = ASTOutput(config)

        # LayerNorm applied before the attention block and before the feed-forward block respectively.
        norm_kwargs = {"eps": config.layer_norm_eps}
        self.layernorm_before = nn.LayerNorm(config.hidden_size, **norm_kwargs)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Apply pre-norm self-attention and the pre-norm feed-forward block, each with a residual connection.

        Returns a tuple whose first element is the layer output; the attention weights follow when
        `output_attentions` is set.
        """
        attention_result = self.attention(
            self.layernorm_before(hidden_states),
            head_mask,
            output_attentions=output_attentions,
        )

        # First residual connection: attention output added to the un-normalized input.
        residual = attention_result[0] + hidden_states

        # Feed-forward block; the second residual connection is applied inside `self.output`.
        expanded = self.intermediate(self.layernorm_after(residual))
        layer_output = self.output(expanded, residual)

        # Attention weights (if any) are appended after the layer output.
        return (layer_output, *attention_result[1:])
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST
class ASTEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` ASTLayer blocks applied one after another."""

    def __init__(self, config: ASTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(ASTLayer(config) for _ in range(config.num_hidden_layers))
        # Gradient checkpointing is off by default.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Run all layers in order.

        Hidden states (the input of every layer plus the final output) are collected only when
        `output_hidden_states` is set, attention weights only when `output_attentions` is set. With
        `return_dict=False` a tuple of the non-None results is returned, otherwise a `BaseModelOutput`.
        """
        # Collectors stay None when the corresponding output was not requested.
        hidden_trace = [] if output_hidden_states else None
        attention_trace = [] if output_attentions else None

        for index, block in enumerate(self.layer):
            if hidden_trace is not None:
                hidden_trace.append(hidden_states)

            block_head_mask = None if head_mask is None else head_mask[index]

            if self.gradient_checkpointing and self.training:
                # Route the call through the checkpointing function while training.
                block_outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    block_head_mask,
                    output_attentions,
                )
            else:
                block_outputs = block(hidden_states, block_head_mask, output_attentions)

            hidden_states = block_outputs[0]

            if attention_trace is not None:
                attention_trace.append(block_outputs[1])

        # The output of the last layer completes the list of hidden states.
        if hidden_trace is not None:
            hidden_trace.append(hidden_states)

        all_hidden_states = None if hidden_trace is None else tuple(hidden_trace)
        all_self_attentions = None if attention_trace is None else tuple(attention_trace)

        if not return_dict:
            candidates = (hidden_states, all_hidden_states, all_self_attentions)
            return tuple(item for item in candidates if item is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class ASTPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that takes care of weight initialization and provides a simple interface for
    downloading and loading pretrained models.
    """

    # Class-level settings read by the base class.
    config_class = ASTConfig
    base_model_prefix = "audio_spectrogram_transformer"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """
        Initialize the weights of `module`.

        LayerNorm modules get zero bias and unit weight. Linear and Conv2d modules get truncated-normal
        weights with standard deviation `config.initializer_range` and zero bias. Other modules are left
        untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        # Upcast the input in `fp32` and cast it back to the desired `dtype` to avoid
        # `trunc_normal_cpu` not being implemented in `half`.
        weight_fp32 = module.weight.data.to(torch.float32)
        weight_fp32 = nn.init.trunc_normal_(weight_fp32, mean=0.0, std=self.config.initializer_range)
        module.weight.data = weight_fp32.to(module.weight.dtype)

        if module.bias is not None:
            module.bias.data.zero_()
# Shared class-level documentation text; passed to `add_start_docstrings` for the model classes below.
AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`ASTConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Documentation of the `forward` arguments shared by the AST models; passed to
# `add_start_docstrings_to_model_forward` on each model's `forward` method.
AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, max_length, num_mel_bins)`):
            Float values mel features extracted from the raw audio waveform. Raw audio waveform can be obtained by
            loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
            the soundfile library (`pip install soundfile`). To prepare the array into `input_values`, the
            [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
            tensor of type `torch.FloatTensor`. See [`~ASTFeatureExtractor.__call__`]

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# The class that follows defines the AST transformer model without any task-specific output head.
@add_start_docstrings(
    "The bare AST Model transformer outputting raw hidden-states without any specific head on top.",
    AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
)
class ASTModel(ASTPreTrainedModel):
    # Bare AST backbone: embeddings -> encoder -> final LayerNorm, plus a pooled output built from the
    # first two tokens. Documented with comments rather than docstrings because the decorators supply
    # docstring text (presumably merged with any existing docstring - verify against the decorators).

    def __init__(self, config: ASTConfig) -> None:
        super().__init__(config)
        self.config = config

        # Embeddings and encoder built from the provided configuration.
        self.embeddings = ASTEmbeddings(config)
        self.encoder = ASTEncoder(config)

        # Layer normalization applied to the encoder output.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ASTPatchEmbeddings:
        # The patch embedding module is what consumes the model input.
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Returns a `BaseModelOutputWithPooling` when `return_dict` resolves to True, otherwise the tuple
        # `(sequence_output, pooled_output, *remaining_encoder_outputs)`.
        # Raises ValueError when `input_values` is None.

        # Flags that were not given explicitly fall back to the values stored in the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_values is None:
            raise ValueError("You have to specify input_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(input_values)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The first encoder output is the last hidden state; normalize it.
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        # Pooled output: mean of the first two tokens (the CLS and distillation tokens).
        pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class ASTMLPHead(nn.Module):
    """Classification head: LayerNorm followed by a linear layer (or an identity when there are no labels)."""

    def __init__(self, config: ASTConfig):
        super().__init__()
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Without labels there is nothing to project to, so the normalized features pass through unchanged.
        if config.num_labels > 0:
            self.dense = nn.Linear(config.hidden_size, config.num_labels)
        else:
            self.dense = nn.Identity()

    def forward(self, hidden_state):
        """Normalize `hidden_state` and map it through the final layer."""
        normalized = self.layernorm(hidden_state)
        return self.dense(normalized)
@add_start_docstrings(
    """
    Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled
    output) e.g. for datasets like AudioSet, Speech Commands v2.
    """,
    AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
)
class ASTForAudioClassification(ASTPreTrainedModel):
    def __init__(self, config: ASTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        # Backbone producing the pooled representation.
        self.audio_spectrogram_transformer = ASTModel(config)

        # Classifier head
        self.classifier = ASTMLPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_SEQ_CLASS_CHECKPOINT,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # Fall back to the configured default when the caller did not choose an output format.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.audio_spectrogram_transformer(
            input_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the backbone output is the pooled representation.
        logits = self.classifier(outputs[1])

        loss = None
        if labels is not None:
            model_config = self.config

            # Infer the problem type once, from the number of labels and the label dtype.
            if model_config.problem_type is None:
                if self.num_labels == 1:
                    model_config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    model_config.problem_type = "single_label_classification"
                else:
                    model_config.problem_type = "multi_label_classification"

            problem_type = model_config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if not return_dict:
            # Tuple output: logits followed by whatever the backbone returned after the pooled output,
            # with the loss in front when one was computed.
            tail = (logits,) + outputs[2:]
            return tail if loss is None else (loss,) + tail

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# 导入类型检查模块
from typing import TYPE_CHECKING
# 导入自定义异常和模块延迟加载工具
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
# Maps each submodule name to the public names it exports; handed to `_LazyModule` at the bottom.
_import_structure = {
    "configuration_audio_spectrogram_transformer": [
        "AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "ASTConfig",
    ],
    "feature_extraction_audio_spectrogram_transformer": ["ASTFeatureExtractor"],
}
# The modeling entries are only advertised when torch is available.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # torch is available: expose the modeling submodule's names as well.
    _import_structure["modeling_audio_spectrogram_transformer"] = [
        "AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ASTForAudioClassification",
        "ASTModel",
        "ASTPreTrainedModel",
    ]
# Under static type checking, import the names for real so that checkers can resolve them.
if TYPE_CHECKING:
    from .configuration_audio_spectrogram_transformer import (
        AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ASTConfig,
    )
    from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
    # Same torch guard as above, applied to the real imports.
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # torch is available: import the modeling names.
        from .modeling_audio_spectrogram_transformer import (
            AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            ASTForAudioClassification,
            ASTModel,
            ASTPreTrainedModel,
        )
# At runtime, replace this module's entry in `sys.modules` with a `_LazyModule` built from `_import_structure`.
else:
    import sys
    # NOTE(review): presumably `_LazyModule` defers the submodule imports until attribute access - confirm.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\auto\auto_factory.py
# 设置编码为 UTF-8,确保可以正确处理中文和其他特殊字符
# Copyright 2021 The HuggingFace Inc. team.
# 根据 Apache License, Version 2.0 授权许可,进行版权声明和许可信息的设置
# 导入必要的模块和函数
import copy # 导入 copy 模块,用于对象的深拷贝操作
import importlib # 导入 importlib 模块,用于动态导入模块和类
import json # 导入 json 模块,用于 JSON 数据的序列化和反序列化
import os # 导入 os 模块,用于操作系统相关功能的访问
import warnings # 导入 warnings 模块,用于警告的处理
from collections import OrderedDict # 从 collections 模块导入 OrderedDict 类,用于有序字典的创建
# 从其他模块中导入必要的函数和类
from ...configuration_utils import PretrainedConfig # 导入 PretrainedConfig 类,用于预训练模型配置管理
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code # 导入动态模块工具函数
from ...utils import (
CONFIG_NAME, # 从 utils 模块导入 CONFIG_NAME 常量,用于配置文件名
cached_file, # 导入 cached_file 函数,用于缓存文件处理
copy_func, # 导入 copy_func 函数,用于函数的复制
extract_commit_hash, # 导入 extract_commit_hash 函数,用于提取提交哈希值
find_adapter_config_file, # 导入 find_adapter_config_file 函数,用于查找适配器配置文件
is_peft_available, # 导入 is_peft_available 函数,用于检查 PEFT 是否可用
logging, # 导入 logging 模块,用于日志记录
requires_backends, # 导入 requires_backends 装饰器,用于声明后端依赖
)
# 从当前模块的子模块中导入必要的类和函数
from .configuration_auto import (
AutoConfig, # 导入 AutoConfig 类,用于自动配置模型
model_type_to_module_name, # 导入 model_type_to_module_name 函数,用于模型类型到模块名的映射
replace_list_option_in_docstrings, # 导入 replace_list_option_in_docstrings 函数,用于替换文档字符串中的列表选项
)
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)
# Template for the class-level docstring of auto model classes. `auto_class_update` replaces
# "BaseAutoModelClass" with the concrete class name, and `insert_head_doc` rewrites the
# "one of the model classes of the library " phrase.
CLASS_DOCSTRING = """
This is a generic model class that will be instantiated as one of the model classes of the library when created
with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class
method.
This class cannot be instantiated directly using `__init__()` (throws an error).
"""
# Template for the `from_config` docstring of auto model classes. `auto_class_update` substitutes
# "BaseAutoModelClass" and "checkpoint_placeholder"; the "List options" line is the placeholder that
# `replace_list_option_in_docstrings` replaces with the generated list.
FROM_CONFIG_DOCSTRING = """
Instantiates one of the model classes of the library from a configuration.
Note:
Loading a model from its configuration file does **not** load the model weights. It only affects the
model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights.
Args:
config ([`PretrainedConfig`]):
The model class to instantiate is selected based on the configuration class:
List options
attn_implementation (`str`, *optional*):
The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Examples:
```
>>> from transformers import AutoConfig, BaseAutoModelClass
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("checkpoint_placeholder")
>>> model = BaseAutoModelClass.from_config(config)
```
"""
# 空行
# 空行
"""
BaseAutoModelClass is a base class for auto models. It provides functionality to select and instantiate a model class based on a configuration.
"""
def _get_model_class(config, model_mapping):
    """
    Pick the model class matching `config` out of `model_mapping`.

    Args:
        config: Configuration instance; `type(config)` is the lookup key into `model_mapping`.
        model_mapping: Mapping from configuration class to a single model class or to a list/tuple of them.

    Returns:
        The mapped class when the entry is not a list/tuple. Otherwise the first candidate whose `__name__`
        equals an entry of `config.architectures` (tried as-is, then prefixed with `TF`, then with `Flax`),
        falling back to the first candidate when nothing matches.
    """
    supported_models = model_mapping[type(config)]
    if not isinstance(supported_models, (list, tuple)):
        return supported_models

    name_to_model = {model.__name__: model for model in supported_models}
    # The attribute may exist but hold None; treat that the same as "no architectures recorded"
    # instead of failing when iterating over None.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        for candidate in (arch, f"TF{arch}", f"Flax{arch}"):
            if candidate in name_to_model:
                return name_to_model[candidate]

    # Architectures unset or not matching any supported model: default to the first entry.
    return supported_models[0]
# 空行
class _BaseAutoModelClass:
    # Base class for auto models.
    # Mapping from configuration class to model class(es); `None` on this base class.
    _model_mapping = None

    def __init__(self, *args, **kwargs):
        # Direct instantiation is not supported; point the caller at the factory methods.
        own_name = self.__class__.__name__
        raise EnvironmentError(
            f"{own_name} is designed to be instantiated "
            f"using the `{own_name}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{own_name}.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config, **kwargs):
        # Builds a model from `config`, using either the class referenced in `config.auto_map`
        # (remote code, when trusted) or the class registered for `type(config)` in `_model_mapping`.
        # Raises ValueError when neither source provides a class.
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, config._name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            # A "repo--Class" reference carries its own repo id; otherwise use the config's path.
            if "--" not in class_ref:
                repo_id = config.name_or_path
            else:
                repo_id, class_ref = class_ref.split("--")
            model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
            if not os.path.isdir(config._name_or_path):
                cls.register(config.__class__, model_class, exist_ok=True)
            else:
                model_class.register_for_auto_class(cls.__name__)
            # `code_revision` is dropped from kwargs; its value is not used here.
            kwargs.pop("code_revision", None)
            return model_class._from_config(config, **kwargs)

        if type(config) in cls._model_mapping.keys():
            return _get_model_class(config, cls._model_mapping)._from_config(config, **kwargs)

        known_configs = ", ".join(c.__name__ for c in cls._model_mapping.keys())
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {known_configs}."
        )

    @classmethod
    def register(cls, config_class, model_class, exist_ok=False):
        """
        Register a new model for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            model_class ([`PreTrainedModel`]):
                The model to register.
            exist_ok (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to the underlying mapping's `register`.

        Raises:
            ValueError: If `model_class` declares a `config_class` attribute different from `config_class`.
        """
        declares_config = hasattr(model_class, "config_class")
        if declares_config and model_class.config_class != config_class:
            raise ValueError(
                "The model class you are passing has a `config_class` attribute that is not consistent with the "
                f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix "
                "one of those so they match!"
            )
        cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)
class _BaseAutoBackboneClass(_BaseAutoModelClass):
    # Base class for auto backbone models.
    _model_mapping = None

    @classmethod
    def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Builds a `TimmBackboneConfig` for `pretrained_model_name_or_path` and instantiates through
        # `from_config`. `out_features` and `output_loading_info=True` are rejected with ValueError.
        requires_backends(cls, ["vision", "timm"])
        from ...models.timm_backbone import TimmBackboneConfig

        config = kwargs.pop("config", TimmBackboneConfig())

        if kwargs.get("out_features", None) is not None:
            raise ValueError("Cannot specify `out_features` for timm backbones")
        if kwargs.get("output_loading_info", False):
            raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")

        # Explicit kwargs win; otherwise fall back to the value carried by `config`.
        overrides = {
            field: kwargs.pop(field, getattr(config, field))
            for field in ("num_channels", "features_only", "use_pretrained_backbone", "out_indices")
        }
        config = TimmBackboneConfig(backbone=pretrained_model_name_or_path, **overrides)
        return super().from_config(config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Routes to the timm loader when `use_timm_backbone` is truthy, else to the parent implementation.
        if kwargs.pop("use_timm_backbone", False):
            return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
def insert_head_doc(docstring, head_doc=""):
    """
    Rewrite the "one of the model classes of the library " phrase of `docstring`.

    With a non-empty `head_doc` the phrase gains a "(with a {head_doc} head)" suffix; with an empty one it
    becomes "one of the base model classes of the library ". Returns the new string.
    """
    anchor = "one of the model classes of the library "
    if len(head_doc) == 0:
        substitute = "one of the base model classes of the library "
    else:
        substitute = f"one of the model classes of the library (with a {head_doc} head) "
    return docstring.replace(anchor, substitute)
def auto_class_update(cls, checkpoint_for_example="google-bert/bert-base-cased", head_doc=""):
    """
    Attach documented `from_config` / `from_pretrained` class methods and a class docstring to `cls`.

    Both methods are copies of the `_BaseAutoModelClass` implementations whose docstrings are built from the
    module templates, with the class name, `checkpoint_for_example` and `head_doc` substituted in.
    Returns `cls` (modified in place).
    """
    mapping = cls._model_mapping
    cls_name = cls.__name__

    cls.__doc__ = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc).replace("BaseAutoModelClass", cls_name)

    # --- from_config ---
    from_config = copy_func(_BaseAutoModelClass.from_config)
    from_config.__doc__ = (
        insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
        .replace("BaseAutoModelClass", cls_name)
        .replace("checkpoint_placeholder", checkpoint_for_example)
    )
    fill_config_options = replace_list_option_in_docstrings(mapping._model_mapping, use_model_types=False)
    cls.from_config = classmethod(fill_config_options(from_config))

    # --- from_pretrained ---
    # The template is chosen from the class-name prefix.
    if cls_name.startswith("TF"):
        pretrained_template = FROM_PRETRAINED_TF_DOCSTRING
    elif cls_name.startswith("Flax"):
        pretrained_template = FROM_PRETRAINED_FLAX_DOCSTRING
    else:
        pretrained_template = FROM_PRETRAINED_TORCH_DOCSTRING

    from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
    # Shortcut = last path component of the checkpoint, up to its first dash.
    shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
    from_pretrained.__doc__ = (
        insert_head_doc(pretrained_template, head_doc=head_doc)
        .replace("BaseAutoModelClass", cls_name)
        .replace("checkpoint_placeholder", checkpoint_for_example)
        .replace("shortcut_placeholder", shortcut)
    )
    fill_pretrained_options = replace_list_option_in_docstrings(mapping._model_mapping)
    cls.from_pretrained = classmethod(fill_pretrained_options(from_pretrained))

    return cls
# 定义函数 `get_values`,接收一个映射 `model_mapping`,返回所有值的列表
def get_values(model_mapping):
    """Return the values of `model_mapping` as one flat list, expanding list/tuple values one level."""
    flattened = []
    for entry in model_mapping.values():
        flattened.extend(entry if isinstance(entry, (list, tuple)) else [entry])
    return flattened
# 定义函数 `getattribute_from_module`,根据模块和属性获取属性值
def getattribute_from_module(module, attr):
    """
    Resolve `attr` on `module`, falling back to the top-level `transformers` module.

    `None` resolves to `None`; a tuple is resolved element by element into a tuple. Raises ValueError when
    the name is found neither on `module` nor on `transformers`.
    """
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple([getattribute_from_module(module, item) for item in attr])
    if hasattr(module, attr):
        return getattr(module, attr)

    # Not on the given module: retry once on the main `transformers` module.
    transformers_module = importlib.import_module("transformers")
    if module != transformers_module:
        try:
            return getattribute_from_module(transformers_module, attr)
        except ValueError:
            raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
    raise ValueError(f"Could not find {attr} in {transformers_module}!")
# 定义类 `_LazyAutoMapping`,继承自 `OrderedDict`
class _LazyAutoMapping(OrderedDict):
    """
    A mapping from config classes to objects (models or tokenizers, for instance) whose keys and values are
    loaded on access.

    Args:
        config_mapping: Map of model type to config class name.
        model_mapping: Map of model type to model (or tokenizer) class name(s).
    """

    def __init__(self, config_mapping, model_mapping):
        self._config_mapping = config_mapping
        # Config class name -> model type, used to go from a config class back to its model type.
        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
        self._model_mapping = model_mapping
        # Store a back-reference to this lazy mapping on the raw mapping object.
        self._model_mapping._model_mapping = self
        # Entries added at runtime through `register`, keyed by config class.
        self._extra_content = {}
        # Cache of already imported model modules, keyed by module name.
        self._modules = {}

    def __len__(self):
        common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
        return len(common_keys) + len(self._extra_content)

    def __getitem__(self, key):
        # Runtime registrations take precedence over the built-in tables.
        if key in self._extra_content:
            return self._extra_content[key]
        model_type = self._reverse_config_mapping[key.__name__]
        if model_type in self._model_mapping:
            model_name = self._model_mapping[model_type]
            return self._load_attr_from_module(model_type, model_name)

        # One config class name may be shared by several model types; try each of them.
        model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
        for mtype in model_types:
            if mtype in self._model_mapping:
                model_name = self._model_mapping[mtype]
                return self._load_attr_from_module(mtype, model_name)
        raise KeyError(key)

    def _load_attr_from_module(self, model_type, attr):
        """Import (once) the `transformers.models` submodule for `model_type` and resolve `attr` on it."""
        module_name = model_type_to_module_name(model_type)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattribute_from_module(self._modules[module_name], attr)

    def keys(self):
        """Return a list of the config classes that have a model entry, followed by the registered keys."""
        mapping_keys = [
            self._load_attr_from_module(key, name)
            for key, name in self._config_mapping.items()
            if key in self._model_mapping.keys()
        ]
        return mapping_keys + list(self._extra_content.keys())

    def get(self, key, default=None):
        """
        Return the value for `key`, or `default` when the lookup raises `KeyError`.

        `default` is optional and defaults to `None`, matching `dict.get`.
        """
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __bool__(self):
        return bool(self.keys())

    def values(self):
        """Return a list of the loaded objects for model types that have a config, then the registered values."""
        mapping_values = [
            self._load_attr_from_module(key, name)
            for key, name in self._model_mapping.items()
            if key in self._config_mapping.keys()
        ]
        return mapping_values + list(self._extra_content.values())

    def items(self):
        """Return a list of `(config class, object)` pairs, built-in entries first, then registered ones."""
        mapping_items = [
            (
                self._load_attr_from_module(key, self._config_mapping[key]),
                self._load_attr_from_module(key, self._model_mapping[key]),
            )
            for key in self._model_mapping.keys()
            if key in self._config_mapping.keys()
        ]
        return mapping_items + list(self._extra_content.items())

    def __iter__(self):
        return iter(self.keys())

    def __contains__(self, item):
        if item in self._extra_content:
            return True
        # Built-in entries are matched through the config class name.
        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
            return False
        model_type = self._reverse_config_mapping[item.__name__]
        return model_type in self._model_mapping

    def register(self, key, value, exist_ok=False):
        """
        Register a new model in this mapping.

        Raises:
            ValueError: If `key` is a config class whose model type already has a built-in entry and
                `exist_ok` is false.
        """
        if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
            model_type = self._reverse_config_mapping[key.__name__]
            if model_type in self._model_mapping.keys() and not exist_ok:
                raise ValueError(f"'{key}' is already used by a Transformers model.")

        self._extra_content[key] = value
.\models\auto\configuration_auto.py
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Config class."""
# 导入标准库和第三方模块
import importlib
import os
import re
import warnings
from collections import OrderedDict
from typing import List, Union
# 导入自定义模块
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, logging
# 获取 logger 对象
logger = logging.get_logger(__name__)
# Ordered lookup tables keyed by model type. As read by the code in this file:
# `CONFIG_MAPPING_NAMES` maps to config class names, `CONFIG_ARCHIVE_MAP_MAPPING_NAMES` to the name of
# the archive-map attribute in the model's module, and `MODEL_NAMES_MAPPING` to the name shown in the
# generated docs. The entry lists are empty here.
CONFIG_MAPPING_NAMES = OrderedDict(
    []
)
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    []
)
MODEL_NAMES_MAPPING = OrderedDict(
    []
)
# Model types whose module name gets a "deprecated." prefix in `model_type_to_module_name`.
# Names are compared after "-" has been replaced by "_", so entries use underscores.
DEPRECATED_MODELS = [
    "bort",
    "mctct",
    "mmbt",
    "open_llama",
    "retribert",
    "tapex",
    "trajectory_transformer",
    "transfo_xl",
    "van",
]
# Model types whose module name is not obtained by the plain "-" -> "_" rule; looked up first by
# `model_type_to_module_name`.
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
    [
        ("openai-gpt", "openai"),
        ("data2vec-audio", "data2vec"),
        ("data2vec-text", "data2vec"),
        ("data2vec-vision", "data2vec"),
        ("donut-swin", "donut"),
        ("kosmos-2", "kosmos2"),
        ("maskformer-swin", "maskformer"),
        ("xclip", "x_clip"),
        ("clip_vision_model", "clip"),
        ("siglip_vision_model", "siglip"),
        ("chinese_clip_vision_model", "chinese_clip"),
    ]
)
def model_type_to_module_name(key):
    """Converts a config key to the corresponding module."""
    # Explicit overrides win over the generic rule.
    if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
        return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]

    # Generic rule: dashes become underscores; deprecated models get a "deprecated." prefix.
    normalized = key.replace("-", "_")
    return f"deprecated.{normalized}" if normalized in DEPRECATED_MODELS else normalized
def config_class_to_model_type(config):
    """Converts a config class name to the corresponding model type"""
    # Built-in table first: values are class names, compared directly with `config`.
    for model_type in CONFIG_MAPPING_NAMES:
        if CONFIG_MAPPING_NAMES[model_type] == config:
            return model_type
    # Then the configs registered at runtime: those are classes, so compare by `__name__`.
    for model_type, registered in CONFIG_MAPPING._extra_content.items():
        if registered.__name__ == config:
            return model_type
    return None
class _LazyConfigMapping(OrderedDict):
    """
    Dictionary of model type -> config class in which the classes are only imported when looked up.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._extra_content = {}
        self._modules = {}

    def __getitem__(self, key):
        # Entries added through `register` are returned as stored.
        if key in self._extra_content:
            return self._extra_content[key]
        if key not in self._mapping:
            raise KeyError(key)

        class_name = self._mapping[key]
        module_name = model_type_to_module_name(key)
        module = self._modules.get(module_name)
        if module is None:
            # First access for this module: import it and cache it.
            module = importlib.import_module(f".{module_name}", "transformers.models")
            self._modules[module_name] = module
        if hasattr(module, class_name):
            return getattr(module, class_name)

        # Some of the mappings have entries model_type -> config of another model type. In that case we try to
        # grab the object at the top level.
        return getattr(importlib.import_module("transformers"), class_name)

    def keys(self):
        # Built-in model types first, then registered ones.
        return [*self._mapping.keys(), *self._extra_content.keys()]

    def values(self):
        builtin_values = [self[model_type] for model_type in self._mapping.keys()]
        return builtin_values + list(self._extra_content.values())

    def items(self):
        builtin_items = [(model_type, self[model_type]) for model_type in self._mapping.keys()]
        return builtin_items + list(self._extra_content.items())

    def __iter__(self):
        return iter([*self._mapping.keys(), *self._extra_content.keys()])

    def __contains__(self, item):
        return item in self._mapping or item in self._extra_content

    def register(self, key, value, exist_ok=False):
        """
        Register a new configuration in this mapping.

        Raises ValueError when `key` is already a built-in model type and `exist_ok` is false.
        """
        if not exist_ok and key in self._mapping.keys():
            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
        self._extra_content[key] = value
# Lazily loaded mapping from model type to config class, backed by `CONFIG_MAPPING_NAMES`.
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
class _LazyLoadAllMappings(OrderedDict):
    """
    A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys,
    values, etc.)

    Args:
        mapping: The mapping to load, from model type to the name of the mapping attribute in that model's module.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        # Flipped to True once every per-model mapping has been merged into `_data`.
        self._initialized = False
        self._data = {}

    def _initialize(self):
        """Merge every per-model mapping into `_data` on first use, emitting a deprecation warning."""
        if self._initialized:
            return
        warnings.warn(
            "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
            "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
            FutureWarning,
        )

        for model_type, map_name in self._mapping.items():
            module_name = model_type_to_module_name(model_type)
            module = importlib.import_module(f".{module_name}", "transformers.models")
            mapping = getattr(module, map_name)
            self._data.update(mapping)

        self._initialized = True

    def __getitem__(self, key):
        self._initialize()
        return self._data[key]

    def keys(self):
        self._initialize()
        return self._data.keys()

    def values(self):
        self._initialize()
        return self._data.values()

    def items(self):
        self._initialize()
        # Return key/value pairs; returning the keys view here broke `for k, v in mapping.items()`.
        return self._data.items()

    def __iter__(self):
        self._initialize()
        return iter(self._data)

    def __contains__(self, item):
        self._initialize()
        return item in self._data
# Deprecated aggregate of all config archive maps; its content is loaded on first access.
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)
def _get_class_name(model_class: Union[str, List[str]]):
    """Format one class name, or a list/tuple of names (skipping `None`), as doc links joined by " or "."""
    if not isinstance(model_class, (list, tuple)):
        return f"[`{model_class}`]"
    linked = [f"[`{name}`]" for name in model_class if name is not None]
    return " or ".join(linked)
# 返回格式化的模型类名称字符串,接受字符串或字符串列表参数
def _list_model_options(indent, config_to_class=None, use_model_types=True):
    """
    Build the bullet list that replaces the "List options" placeholder in docstrings.

    Args:
        indent: Prefix put in front of every generated line.
        config_to_class: Optional mapping from model type to class name(s). Required when
            `use_model_types` is false.
        use_model_types: If true, one bullet per model type; otherwise one bullet per configuration class.

    Returns:
        The generated lines joined with newlines, sorted by model type or by configuration class name.

    Raises:
        ValueError: If `use_model_types` is false and `config_to_class` is None.
    """
    if not use_model_types and config_to_class is None:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")

    if not use_model_types:
        # Keyed by configuration class name.
        config_to_name = {
            CONFIG_MAPPING_NAMES[model_type]: _get_class_name(target)
            for model_type, target in config_to_class.items()
            if model_type in CONFIG_MAPPING_NAMES
        }
        config_to_model_name = {
            config_name: MODEL_NAMES_MAPPING[model_type] for model_type, config_name in CONFIG_MAPPING_NAMES.items()
        }
        entries = [
            f"{indent}- [`{config_name}`] configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
            for config_name in sorted(config_to_name.keys())
        ]
        return "\n".join(entries)

    # Keyed by model type.
    if config_to_class is None:
        model_type_to_name = {
            model_type: f"[`{config_name}`]" for model_type, config_name in CONFIG_MAPPING_NAMES.items()
        }
    else:
        model_type_to_name = {
            model_type: _get_class_name(target)
            for model_type, target in config_to_class.items()
            if model_type in MODEL_NAMES_MAPPING
        }
    entries = [
        f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
        for model_type in sorted(model_type_to_name.keys())
    ]
    return "\n".join(entries)
# 定义一个装饰器函数,用于替换函数的文档字符串中的特定部分,以生成新的文档字符串
def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    """
    Decorator factory that fills the "List options" placeholder line of a function's docstring.

    The first docstring line consisting only of "List options" (with optional surrounding whitespace) is
    replaced by the output of `_list_model_options`. The decorator raises ValueError when the docstring has
    no such line.
    """

    def docstring_decorator(fn):
        docstrings = fn.__doc__
        lines = docstrings.split("\n")

        # Locate the first placeholder line, remembering its position and leading whitespace.
        placeholder = None
        for position, line in enumerate(lines):
            placeholder = re.search(r"^(\s*)List options\s*$", line)
            if placeholder is not None:
                break

        if placeholder is None:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current"
                f" docstring is:\n{docstrings}"
            )

        indent = placeholder.groups()[0]
        if use_model_types:
            # Model-type lists get extra indentation relative to the placeholder.
            indent = f"{indent} "
        lines[position] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator
# 定义一个配置类 AutoConfig
class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the [`~AutoConfig.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        # Direct instantiation is not supported.
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs):
        """
        Instantiate the configuration class registered for `model_type`.

        Extra positional and keyword arguments are forwarded to the configuration class.

        Raises:
            ValueError: If `model_type` is not in `CONFIG_MAPPING`.
        """
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )

    # `register` takes no `cls`/`self`, so it must be a static method: as a classmethod the class itself
    # would be bound to `model_type` and every argument would shift by one. It is also not wrapped in
    # `replace_list_option_in_docstrings`, because its docstring has no "List options" placeholder and the
    # decorator would raise at class-definition time.
    @staticmethod
    def register(model_type, config, exist_ok=False):
        """
        Register a new configuration for this class.

        Args:
            model_type (`str`): The model type like "bert" or "gpt".
            config ([`PretrainedConfig`]): The config to register.
            exist_ok (`bool`, *optional*, defaults to `False`): Forwarded to `CONFIG_MAPPING.register`.

        Raises:
            ValueError: If `config` is a `PretrainedConfig` subclass whose `model_type` differs from `model_type`.
        """
        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
            raise ValueError(
                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
                f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
                "match!"
            )
        CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
# 设置脚本的编码格式为 UTF-8
# 版权声明,使用 Apache License Version 2.0 授权许可
# 只有遵循许可证的条款,才能使用该文件
# 可以从 http://www.apache.org/licenses/LICENSE-2.0 获取许可证的副本
# 除非适用法律要求或书面同意,否则不得使用此文件
# 此软件根据 "原样" 分发,不提供任何形式的明示或暗示担保或条件
# 有关详细信息,请参阅许可证
""" AutoFeatureExtractor class."""
# 导入必要的模块
import importlib # 动态导入模块的功能
import json # 处理 JSON 数据的模块
import os # 提供与操作系统相关的功能
import warnings # 控制警告信息的输出
from collections import OrderedDict # 提供有序字典的数据结构
from typing import Dict, Optional, Union # 导入类型提示所需的类型
# 导入 transformers 库中的其他模块和函数
from ...configuration_utils import PretrainedConfig # 预训练配置类
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code # 动态模块相关工具函数
from ...feature_extraction_utils import FeatureExtractionMixin # 特征提取混合类
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging # 提供各种实用功能的工具函数
from .auto_factory import _LazyAutoMapping # 自动工厂类的延迟映射
from .configuration_auto import (
CONFIG_MAPPING_NAMES, # 配置映射名称列表
AutoConfig, # 自动配置类
model_type_to_module_name, # 模型类型到模块名称的映射函数
replace_list_option_in_docstrings, # 在文档字符串中替换列表选项的函数
)
# 获取日志记录器对象
logger = logging.get_logger(__name__)
# Ordered table from model type to feature extractor class name.
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
    [
        # The entries are not included here; the table is empty.
    ]
)
# Lazy mapping built from the config-name table and the feature-extractor-name table.
FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
def feature_extractor_class_from_name(class_name: str):
    """
    Look up a feature extractor class by its name.

    Args:
        class_name (`str`): Name of the feature extractor class.

    Returns:
        The class found in the matching `transformers.models` submodule, among the extractors registered at
        runtime, or on the top-level `transformers` module (tried in that order); `None` when it is found
        nowhere.
    """
    for model_type, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
        if class_name not in extractors:
            continue
        module = importlib.import_module(f".{model_type_to_module_name(model_type)}", "transformers.models")
        try:
            return getattr(module, class_name)
        except AttributeError:
            # Listed for this model type but missing from the module: keep looking.
            continue

    # Extractors registered at runtime are matched on their `__name__`.
    for registered in FEATURE_EXTRACTOR_MAPPING._extra_content.values():
        if getattr(registered, "__name__", None) == class_name:
            return registered

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the
    # main init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)
    return None
def get_feature_extractor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Load the feature extractor configuration of a pretrained model.

    The file named `FEATURE_EXTRACTOR_NAME` is resolved with `get_file_from_repo` and parsed as JSON.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Name or path of the pretrained model, forwarded to `get_file_from_repo`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Cache directory, forwarded to `get_file_from_repo`.
        force_download (`bool`, *optional*, defaults to `False`):
            Forwarded to `get_file_from_repo`.
        resume_download (`bool`, *optional*, defaults to `False`):
            Forwarded to `get_file_from_repo`.
        proxies (`Dict[str, str]`, *optional*):
            Proxy settings, forwarded to `get_file_from_repo`.
        token (`str` or `bool`, *optional*):
            Access token, forwarded to `get_file_from_repo`.
        revision (`str`, *optional*):
            Repository revision, forwarded to `get_file_from_repo`.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Forwarded to `get_file_from_repo`.
        kwargs:
            Only the deprecated `use_auth_token` key is read; when set it is used as `token`.

    Returns:
        `Dict`: The parsed content of the feature extractor configuration file, or an empty dict when the file
        could not be located.

    Raises:
        ValueError: If both `token` and the deprecated `use_auth_token` are specified.
    """
    # Backward compatibility: accept the deprecated `use_auth_token` keyword as an alias of `token`.
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    resolved_config_file = get_file_from_repo(
        pretrained_model_name_or_path,
        FEATURE_EXTRACTOR_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
    )
    if resolved_config_file is None:
        logger.info(
            "Could not locate the feature extractor configuration file, will try to use the model config instead."
        )
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)
class AutoFeatureExtractor:
    r"""
    This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
    library when created with the [`AutoFeatureExtractor.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        # Direct instantiation is not supported; always raise.
        raise EnvironmentError(
            "AutoFeatureExtractor is designed to be instantiated "
            "using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
        )

    # NOTE(review): the `from_pretrained` class method referenced in the docstring is not defined in this
    # excerpt. The `@classmethod` / `@replace_list_option_in_docstrings(...)` decorators that were stacked on
    # top of `register` presumably belonged to it; they were removed so that `register` is a plain static method.

    @staticmethod
    def register(config_class, feature_extractor_class, exist_ok=False):
        """
        Register a new feature extractor for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            feature_extractor_class ([`FeatureExtractionMixin`]):
                The feature extractor to register.
            exist_ok (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to `FEATURE_EXTRACTOR_MAPPING.register`.
        """
        FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class, exist_ok=exist_ok)
.\models\auto\image_processing_auto.py
# 设置编码格式为 UTF-8
# 版权声明,指明代码版权归 HuggingFace Inc. 团队所有
# 使用 Apache License, Version 2.0 许可协议,详见链接
# 除非法律另有规定或书面同意,否则不得使用本文件
# 详细信息请查看许可协议:http://www.apache.org/licenses/LICENSE-2.0
# 引入 warnings 库,用于发出警告信息
import warnings
# collections 模块中的 OrderedDict 类,用于创建有序字典
from collections import OrderedDict
# typing 模块,用于类型提示
from typing import Dict, Optional, Union
# 从相应模块中导入函数和类
# configuration_utils 模块中的 PretrainedConfig 类
from ...configuration_utils import PretrainedConfig
# dynamic_module_utils 中的函数,用于从动态模块中获取类
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
# image_processing_utils 中的 ImageProcessingMixin 类
from ...image_processing_utils import ImageProcessingMixin
# utils 中的各种实用函数和常量
from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, logging
# 从当前包中导入 auto_factory 模块的 _LazyAutoMapping 类
from .auto_factory import _LazyAutoMapping
# 从当前包中导入 configuration_auto 模块中的若干变量和函数
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
# Module-level logger.
logger = logging.get_logger(__name__)

# Image processor class names keyed by model type (the keys are passed to
# `model_type_to_module_name` when a class is looked up by name).
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
    # NOTE: the mapping entries are omitted in this excerpt; the concrete
    # model-type -> class-name pairs are expected to be filled in here.
)

# Mapping object built by `_LazyAutoMapping` from the config names and the image processor names.
IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
def image_processor_class_from_name(class_name: str):
    """
    Resolve an image processor class from its name.

    The lookup runs in three stages: the model modules listed in `IMAGE_PROCESSOR_MAPPING_NAMES`, then the extra
    content of `IMAGE_PROCESSOR_MAPPING`, and finally the top-level `transformers` module.

    Args:
        class_name (`str`): Name of the image processor class to look up.

    Returns:
        The matching object if one is found, otherwise `None`.
    """
    # Stage 1: model modules whose mapping entry contains the requested name.
    for model_type, candidates in IMAGE_PROCESSOR_MAPPING_NAMES.items():
        if class_name not in candidates:
            continue
        package_name = model_type_to_module_name(model_type)
        model_module = importlib.import_module(f".{package_name}", "transformers.models")
        try:
            return getattr(model_module, class_name)
        except AttributeError:
            # The module does not expose the class; keep scanning the remaining entries.
            continue

    # Stage 2: objects held in the mapping's extra content, matched on `__name__`.
    for _key, registered in IMAGE_PROCESSOR_MAPPING._extra_content.items():
        if getattr(registered, "__name__", None) == class_name:
            return registered

    # Stage 3: fall back to whatever the main `transformers` module exposes under that name.
    transformers_module = importlib.import_module("transformers")
    if not hasattr(transformers_module, class_name):
        return None
    return getattr(transformers_module, class_name)
def get_image_processor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the image processor configuration from a pretrained model image processor configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the image processor's
              `save_pretrained` method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the image processor configuration from local files.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the image processor. An empty dict is returned when the configuration file
        cannot be located.

    Raises:
        ValueError: If both `token` and the deprecated `use_auth_token` are specified.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    image_processor_config = get_image_processor_config("google-bert/bert-base-uncased")
    # This model does not have a image processor config so the result will be an empty dict.
    image_processor_config = get_image_processor_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained image processor locally and you can reload its config
    from transformers import AutoImageProcessor

    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    image_processor.save_pretrained("image-processor-test")
    image_processor_config = get_image_processor_config("image-processor-test")
    ```
    """
    # Backward compatibility: accept the deprecated `use_auth_token` keyword as an alias of `token`.
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    resolved_config_file = get_file_from_repo(
        pretrained_model_name_or_path,
        IMAGE_PROCESSOR_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
    )
    if resolved_config_file is None:
        logger.info(
            "Could not locate the image processor configuration file, will try to use the model config instead."
        )
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)
class AutoImageProcessor:
    r"""
    This is a generic image processor class that will be instantiated as one of the image processor classes of the
    library when created with the [`AutoImageProcessor.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        # Direct instantiation is not supported; always raise.
        raise EnvironmentError(
            "AutoImageProcessor is designed to be instantiated "
            "using the `AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)` method."
        )

    # NOTE(review): the `from_pretrained` class method referenced in the docstring is not defined in this
    # excerpt. The `@classmethod` / `@replace_list_option_in_docstrings(...)` decorators that were stacked on
    # top of `register` presumably belonged to it; they were removed so that `register` is a plain static method.

    @staticmethod
    def register(config_class, image_processor_class, exist_ok=False):
        """
        Register a new image processor for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            image_processor_class ([`ImageProcessingMixin`]):
                The image processor to register.
            exist_ok (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to `IMAGE_PROCESSOR_MAPPING.register`.
        """
        IMAGE_PROCESSOR_MAPPING.register(config_class, image_processor_class, exist_ok=exist_ok)
.\models\auto\modeling_auto.py
# 设置文件编码为 UTF-8
# 版权声明和许可信息
#
# 根据 Apache 许可证版本 2.0 授权使用此文件
# 除非符合许可证的条件,否则不得使用此文件
# 可以在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件基于"原样"分发,无任何担保或条件
# 请查阅许可证了解具体的法律条文和允许条件
""" Auto Model class."""
# 导入警告模块
import warnings
# 导入有序字典模块
from collections import OrderedDict
# 导入日志记录工具
from ...utils import logging
# 从 auto_factory 模块导入相关类和函数
from .auto_factory import (
_BaseAutoBackboneClass,
_BaseAutoModelClass,
_LazyAutoMapping,
auto_class_update,
)
# 导入自动生成的配置映射
from .configuration_auto import CONFIG_MAPPING_NAMES
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)
# Base model mapping names. NOTE: empty in this excerpt (no entries are listed).
MODEL_MAPPING_NAMES = OrderedDict(
    # No entries listed here.
)

# Mapping names for pretraining models. NOTE: empty in this excerpt.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    # No entries listed here.
)

# Mapping names for models with a language-modeling head. NOTE: empty in this excerpt.
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    # No entries listed here.
)

# Mapping names for causal language models. NOTE: empty in this excerpt.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    # No entries listed here.
)
# Model type -> base model class name for image models.
# The entry list is passed to the OrderedDict constructor; previously the dict was created empty and the
# list followed as a stand-alone expression with no effect.
MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        ("beit", "BeitModel"),
        ("bit", "BitModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convnext", "ConvNextModel"),
        ("convnextv2", "ConvNextV2Model"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("deta", "DetaModel"),
        ("detr", "DetrModel"),
        ("dinat", "DinatModel"),
        ("dinov2", "Dinov2Model"),
        ("dpt", "DPTModel"),
        ("efficientformer", "EfficientFormerModel"),
        ("efficientnet", "EfficientNetModel"),
        ("focalnet", "FocalNetModel"),
        ("glpn", "GLPNModel"),
        ("imagegpt", "ImageGPTModel"),
        ("levit", "LevitModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mobilevitv2", "MobileViTV2Model"),
        ("nat", "NatModel"),
        ("poolformer", "PoolFormerModel"),
        ("pvt", "PvtModel"),
        ("regnet", "RegNetModel"),
        ("resnet", "ResNetModel"),
        ("segformer", "SegformerModel"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
        ("swin2sr", "Swin2SRModel"),
        ("swinv2", "Swinv2Model"),
        ("table-transformer", "TableTransformerModel"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vit", "ViTModel"),
        ("vit_hybrid", "ViTHybridModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("vitdet", "VitDetModel"),
        ("vivit", "VivitModel"),
        ("yolos", "YolosModel"),
    ]
)
# Model type -> class name for masked image modeling models.
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        ("deit", "DeiTForMaskedImageModeling"),
        ("focalnet", "FocalNetForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)
# Model type -> class name for causal image modeling models.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        ("imagegpt", "ImageGPTForCausalImageModeling"),
    ]
)
# Model type -> class name(s) for image classification models.
# A value is either a single class name or a tuple of class names when a model type has several
# image-classification heads.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("beit", "BeitForImageClassification"),
        ("bit", "BitForImageClassification"),
        ("clip", "CLIPForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("convnextv2", "ConvNextV2ForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        (
            "deit",
            ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
        ),
        ("dinat", "DinatForImageClassification"),
        ("dinov2", "Dinov2ForImageClassification"),
        (
            "efficientformer",
            (
                "EfficientFormerForImageClassification",
                "EfficientFormerForImageClassificationWithTeacher",
            ),
        ),
        ("efficientnet", "EfficientNetForImageClassification"),
        ("focalnet", "FocalNetForImageClassification"),
        ("imagegpt", "ImageGPTForImageClassification"),
        (
            "levit",
            ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
        ),
        ("mobilenet_v1", "MobileNetV1ForImageClassification"),
        ("mobilenet_v2", "MobileNetV2ForImageClassification"),
        ("mobilevit", "MobileViTForImageClassification"),
        ("mobilevitv2", "MobileViTV2ForImageClassification"),
        ("nat", "NatForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("pvt", "PvtForImageClassification"),
        ("pvt_v2", "PvtV2ForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("siglip", "SiglipForImageClassification"),
        ("swiftformer", "SwiftFormerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
        ("vit_hybrid", "ViTHybridForImageClassification"),
        ("vit_msn", "ViTMSNForImageClassification"),
    ]
)
# Model type -> class name for image segmentation models.
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Do not add new models here, this class will be deprecated in the future.
        # Model for Image Segmentation mapping
        ("detr", "DetrForSegmentation"),
    ]
)
# Model type -> class name for semantic segmentation models.
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
        ("upernet", "UperNetForSemanticSegmentation"),
    ]
)
# Model type -> class name for instance segmentation models.
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Instance Segmentation mapping
        # MaskFormerForInstanceSegmentation can be removed from this mapping in v5
        ("maskformer", "MaskFormerForInstanceSegmentation"),
    ]
)
# Model type -> class name for universal segmentation models.
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Universal Segmentation mapping
        ("detr", "DetrForSegmentation"),
        ("mask2former", "Mask2FormerForUniversalSegmentation"),
        ("maskformer", "MaskFormerForInstanceSegmentation"),
        ("oneformer", "OneFormerForUniversalSegmentation"),
    ]
)
# Model type -> class name for video classification models.
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("timesformer", "TimesformerForVideoClassification"),
        ("videomae", "VideoMAEForVideoClassification"),
        ("vivit", "VivitForVideoClassification"),
    ]
)
# Model type -> class name for vision-to-sequence models.
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("blip", "BlipForConditionalGeneration"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("git", "GitForCausalLM"),
        ("instructblip", "InstructBlipForConditionalGeneration"),
        ("kosmos-2", "Kosmos2ForConditionalGeneration"),
        ("llava", "LlavaForConditionalGeneration"),
        ("llava_next", "LlavaNextForConditionalGeneration"),
        ("pix2struct", "Pix2StructForConditionalGeneration"),
        ("vipllava", "VipLlavaForConditionalGeneration"),
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    ]
)
# Model type -> class name for masked language modeling models.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("funnel", "FunnelForMaskedLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForMaskedLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerForMaskedLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("tapas", "TapasForMaskedLM"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xmod", "XmodForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)
# Model type -> class name for object detection models.
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Object Detection mapping
        ("conditional_detr", "ConditionalDetrForObjectDetection"),
        ("deformable_detr", "DeformableDetrForObjectDetection"),
        ("deta", "DetaForObjectDetection"),
        ("detr", "DetrForObjectDetection"),
        ("table-transformer", "TableTransformerForObjectDetection"),
        ("yolos", "YolosForObjectDetection"),
    ]
)
# Model type -> class name for zero-shot object detection models.
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Object Detection mapping
        ("owlv2", "Owlv2ForObjectDetection"),
        ("owlvit", "OwlViTForObjectDetection"),
    ]
)
# Model type -> class name for depth estimation models.
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for depth estimation mapping
        ("depth_anything", "DepthAnythingForDepthEstimation"),
        ("dpt", "DPTForDepthEstimation"),
        ("glpn", "GLPNForDepthEstimation"),
    ]
)
# Model type -> class name for sequence-to-sequence causal language models.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "BartForConditionalGeneration"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot", "BlenderbotForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mt5", "MT5ForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("pegasus", "PegasusForConditionalGeneration"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("prophetnet", "ProphetNetForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForTextToText"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("umt5", "UMT5ForConditionalGeneration"),
        ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
    ]
)
# Model type -> class name for speech sequence-to-sequence models.
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for speech Seq2Seq mapping
        ("pop2piano", "Pop2PianoForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForSpeechToText"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("speecht5", "SpeechT5ForSpeechToText"),
        ("whisper", "WhisperForConditionalGeneration"),
    ]
)
# Model type -> class name for sequence classification models.
# NOTE(review): the entries are absent from this excerpt; only the missing opening bracket was restored so
# that the statement parses. Restore the entries from the original source.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
    ]
)
# Model type -> class name for question answering models.
# NOTE(review): the entries are absent from this excerpt; only the missing opening bracket was restored so
# that the statement parses. Restore the entries from the original source.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
    ]
)
# Model type -> class name for table question answering models.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)
# Model type -> class name for visual question answering models.
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("blip", "BlipForQuestionAnswering"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("vilt", "ViltForQuestionAnswering"),
    ]
)
# Model type -> class name for document question answering models.
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    ]
)
# Model type -> class name for token classification models.
# NOTE(review): the entries are absent from this excerpt; only the missing opening bracket was restored so
# that the statement parses. Restore the entries from the original source.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
    ]
)
# Model type -> class name for multiple choice models.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("ernie_m", "ErnieMForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("mega", "MegaForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("mra", "MraForMultipleChoice"),
        ("nezha", "NezhaForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
        ("roc_bert", "RoCBertForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("xmod", "XmodForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)
# Model type -> class name for next sentence prediction models.
# The entry list is passed to the OrderedDict constructor; previously the dict was created empty and the
# list followed as a stand-alone expression with no effect.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
        ("nezha", "NezhaForNextSentencePrediction"),
        ("qdqbert", "QDQBertForNextSentencePrediction"),
    ]
)
# Model type -> class name for audio classification models.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("audio-spectrogram-transformer", "ASTForAudioClassification"),
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
        ("whisper", "WhisperForAudioClassification"),
    ]
)
# Model type -> class name for connectionist temporal classification (CTC) models.
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        # Model for Connectionist temporal classification (CTC) mapping
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("mctct", "MCTCTForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-bert", "Wav2Vec2BertForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)
# Model type -> class name for audio frame classification models.
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)
# Model type -> class name for audio x-vector models.
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio XVector mapping
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-bert", "Wav2Vec2BertForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)
# Model type -> class name for text-to-spectrogram models.
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text-To-Spectrogram mapping
        ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
        ("speecht5", "SpeechT5ForTextToSpeech"),
    ]
)
# Model type -> class name for text-to-waveform models.
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text-To-Waveform mapping
        ("bark", "BarkModel"),
        ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
        ("musicgen", "MusicgenForConditionalGeneration"),
        ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForTextToSpeech"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
        ("vits", "VitsModel"),
    ]
)
# Model type -> class name for zero-shot image classification.
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict([
    ("align", "AlignModel"),
    ("altclip", "AltCLIPModel"),
    ("blip", "BlipModel"),
    ("chinese_clip", "ChineseCLIPModel"),
    ("clip", "CLIPModel"),
    ("clipseg", "CLIPSegModel"),
    ("siglip", "SiglipModel"),
])
# Model type -> class name for backbone networks.
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict([
    ("beit", "BeitBackbone"),
    ("bit", "BitBackbone"),
    ("convnext", "ConvNextBackbone"),
    ("convnextv2", "ConvNextV2Backbone"),
    ("dinat", "DinatBackbone"),
    ("dinov2", "Dinov2Backbone"),
    ("focalnet", "FocalNetBackbone"),
    ("maskformer-swin", "MaskFormerSwinBackbone"),
    ("nat", "NatBackbone"),
    ("pvt_v2", "PvtV2Backbone"),
    ("resnet", "ResNetBackbone"),
    ("swin", "SwinBackbone"),
    ("swinv2", "Swinv2Backbone"),
    ("timm_backbone", "TimmBackbone"),
    ("vitdet", "VitDetBackbone"),
])
# Model type -> class name for mask generation.
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict([
    ("sam", "SamModel"),
])
# Model type -> class name for keypoint detection.
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict([
    ("superpoint", "SuperPointForKeypointDetection"),
])
# Model type -> class name for text encoding.
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict([
    ("albert", "AlbertModel"),
    ("bert", "BertModel"),
    ("big_bird", "BigBirdModel"),
    ("data2vec-text", "Data2VecTextModel"),
    ("deberta", "DebertaModel"),
    ("deberta-v2", "DebertaV2Model"),
    ("distilbert", "DistilBertModel"),
    ("electra", "ElectraModel"),
    ("flaubert", "FlaubertModel"),
    ("ibert", "IBertModel"),
    ("longformer", "LongformerModel"),
    ("mobilebert", "MobileBertModel"),
    ("mt5", "MT5EncoderModel"),
    ("nystromformer", "NystromformerModel"),
    ("reformer", "ReformerModel"),
    ("rembert", "RemBertModel"),
    ("roberta", "RobertaModel"),
    ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
    ("roc_bert", "RoCBertModel"),
    ("roformer", "RoFormerModel"),
    ("squeezebert", "SqueezeBertModel"),
    ("t5", "T5EncoderModel"),
    ("umt5", "UMT5EncoderModel"),
    ("xlm", "XLMModel"),
    ("xlm-roberta", "XLMRobertaModel"),
    ("xlm-roberta-xl", "XLMRobertaXLModel"),
])
# Model type -> class name for time-series classification.
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict([
    ("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
    ("patchtst", "PatchTSTForClassification"),
])
# Model type -> class name for time-series regression.
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict([
    ("patchtsmixer", "PatchTSMixerForRegression"),
    ("patchtst", "PatchTSTForRegression"),
])
# Model type -> class name for image-to-image models.
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict([
    ("swin2sr", "Swin2SRForImageSuperResolution"),
])
# Task-specific auto mappings. Each one is a `_LazyAutoMapping` built from the
# shared CONFIG_MAPPING_NAMES table and the model-class-name table for that task.
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES)
# Auto mapping for table question answering, built from the shared config table
# and the table-QA model-class-name table.
# (A stray, no-op duplicate of the argument list that followed this call was removed.)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
# Remaining task-specific auto mappings. Each one is a `_LazyAutoMapping` built from
# the shared CONFIG_MAPPING_NAMES table and the model-class-name table for that task.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
# Audio and speech tasks.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
)
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)
# Backbones, mask generation, keypoint detection and text encoding.
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
# Time series and image-to-image tasks.
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
# Private base for the deprecated `AutoModelWithLMHead` alias.
# The `class` header was missing, which left `_model_mapping` as a stray module
# global and `_AutoModelWithLMHead` undefined at the call below.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


# Post-process the class with `auto_class_update`, using "language modeling" as the head description.
_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
# Auto class backed by MODEL_FOR_CAUSAL_LM_MAPPING (class body re-indented so it parses).
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# Post-process the class with `auto_class_update`, using "causal language modeling" as the head description.
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
# Auto class backed by MODEL_FOR_MASKED_LM_MAPPING (class body re-indented so it parses).
class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


# Post-process the class with `auto_class_update`, using "masked language modeling" as the head description.
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
# Auto class backed by MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING (class body re-indented so it parses).
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Post-process the class with `auto_class_update`; the example checkpoint is "google-t5/t5-base".
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="google-t5/t5-base",
)
# Auto class backed by MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING (class body re-indented so it parses).
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update`, using "sequence classification" as the head description.
AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification, head_doc="sequence classification"
)
# Auto class backed by MODEL_FOR_QUESTION_ANSWERING_MAPPING (class body re-indented so it parses).
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Post-process the class with `auto_class_update`, using "question answering" as the head description.
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
# Auto class backed by MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING (class body re-indented so it parses).
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Post-process the class with `auto_class_update`; the example checkpoint is "google/tapas-base-finetuned-wtq".
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
# Auto class backed by MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING (class body re-indented so it parses).
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


# Post-process the class with `auto_class_update`; the example checkpoint is "dandelin/vilt-b32-finetuned-vqa".
AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)
# Auto class backed by MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING (class body re-indented so it parses).
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# Post-process the class with `auto_class_update`.
# NOTE(review): the unbalanced quoting inside `checkpoint_for_example` looks deliberate,
# presumably so that a `revision=` argument appears when the value is interpolated
# between double quotes; confirm against `auto_class_update` before "fixing" it.
AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)
# Auto class backed by MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING (class body re-indented so it parses).
class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update`, using "token classification" as the head description.
AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
# Auto class backed by MODEL_FOR_MULTIPLE_CHOICE_MAPPING (class body re-indented so it parses).
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Post-process the class with `auto_class_update`, using "multiple choice" as the head description.
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
# Auto class backed by MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING (class body re-indented so it parses).
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Post-process the class with `auto_class_update`, as every sibling auto class does.
# The call had been truncated to `AutoModelForNextSentencePrediction, head_doc = "..."`,
# a tuple-unpacking assignment that raises ValueError at import time.
AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
# Auto class backed by MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING (class body re-indented so it parses).
class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update`, using "image classification" as the head description.
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
# Auto class backed by MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING (class body re-indented so it parses).
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update`, using "zero-shot image classification" as the head description.
AutoModelForZeroShotImageClassification = auto_class_update(
    AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)
# Auto class backed by MODEL_FOR_IMAGE_SEGMENTATION_MAPPING (class body re-indented so it parses).
class AutoModelForImageSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


# Post-process the class with `auto_class_update`, using "image segmentation" as the head description.
AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
# Auto class backed by MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING (class body re-indented so it parses).
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


# Post-process the class with `auto_class_update`, using "semantic segmentation" as the head description.
AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)
# Auto class backed by MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING (class body re-indented so it parses).
class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING


# Post-process the class with `auto_class_update`, using "universal image segmentation" as the head description.
AutoModelForUniversalSegmentation = auto_class_update(
    AutoModelForUniversalSegmentation, head_doc="universal image segmentation"
)
# Auto class backed by MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING (class body re-indented so it parses).
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


# Post-process the class with `auto_class_update`, using "instance segmentation" as the head description.
AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation, head_doc="instance segmentation"
)
# Auto class backed by MODEL_FOR_OBJECT_DETECTION_MAPPING (class body re-indented so it parses).
class AutoModelForObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


# Post-process the class with `auto_class_update`, using "object detection" as the head description.
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
# Auto class backed by MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING (class body re-indented so it parses).
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


# Post-process the class with `auto_class_update`, using "zero-shot object detection" as the head description.
AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
)
# Auto class backed by MODEL_FOR_DEPTH_ESTIMATION_MAPPING (class body re-indented so it parses).
class AutoModelForDepthEstimation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


# Post-process the class with `auto_class_update`, using "depth estimation" as the head description.
AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")
# Auto class backed by MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING (class body re-indented so it parses).
class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update`, using "video classification" as the head description.
AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")
# Auto class backed by MODEL_FOR_VISION_2_SEQ_MAPPING (class body re-indented so it parses).
class AutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


# Post-process the class with `auto_class_update`, using "vision-to-text modeling" as the head description.
AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")
# Auto class backed by MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING (class body re-indented so it parses).
class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update`, using "audio classification" as the head description.
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
# Auto class backed by MODEL_FOR_CTC_MAPPING.
# The `class` header was missing, which left `_model_mapping` as a stray module
# global and `AutoModelForCTC` undefined at the call below.
class AutoModelForCTC(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CTC_MAPPING


# Post-process the class with `auto_class_update`, using "connectionist temporal classification" as the head description.
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
# Auto class backed by MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING (class body re-indented so it parses).
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Post-process the class with `auto_class_update`, using the speech-to-text head description.
AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
# Auto class backed by MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING (class body re-indented so it parses).
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update`, using "audio frame (token) classification" as the head description.
AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
)
# Auto class backed by MODEL_FOR_AUDIO_XVECTOR_MAPPING (class body re-indented so it parses).
class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
# Auto class backed by MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING (class body re-indented so it parses).
class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
# Auto class backed by MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING (class body re-indented so it parses).
class AutoModelForTextToWaveform(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
# Auto class backed by MODEL_FOR_BACKBONE_MAPPING; unlike its siblings it derives
# from `_BaseAutoBackboneClass` (class body re-indented so it parses).
class AutoBackbone(_BaseAutoBackboneClass):
    _model_mapping = MODEL_FOR_BACKBONE_MAPPING
# Post-process AutoModelForAudioXVector with `auto_class_update`, using
# "audio retrieval via x-vector" as the head description.
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
# Auto class backed by MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING (class body re-indented so it parses).
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


# Post-process the class with `auto_class_update`, using "masked image modeling" as the head description.
AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
# Deprecated public alias of `_AutoModelWithLMHead`. Both constructors emit a
# FutureWarning and then delegate unchanged to the parent implementation.
# (The class body, decorators and method bodies are re-indented so the class parses.)
class AutoModelWithLMHead(_AutoModelWithLMHead):
    @classmethod
    def from_config(cls, config):
        # Warn on every call, then defer to the parent `from_config`.
        warnings.warn(
            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
            "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Warn on every call, then forward all arguments to the parent `from_pretrained`.
        warnings.warn(
            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
            "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
# File: .\models\auto\modeling_flax_auto.py
# Standard-library import.
from collections import OrderedDict
# Project logging helper.
from ...utils import logging
# Auto-class factory helpers.
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
# Shared table of configuration names.
from .configuration_auto import CONFIG_MAPPING_NAMES
# Module-level logger.
logger = logging.get_logger(__name__)
# Model type -> Flax base model class name.
FLAX_MODEL_MAPPING_NAMES = OrderedDict([
    ("albert", "FlaxAlbertModel"),
    ("bart", "FlaxBartModel"),
    ("beit", "FlaxBeitModel"),
    ("bert", "FlaxBertModel"),
    ("big_bird", "FlaxBigBirdModel"),
    ("blenderbot", "FlaxBlenderbotModel"),
    ("blenderbot-small", "FlaxBlenderbotSmallModel"),
    ("bloom", "FlaxBloomModel"),
    ("clip", "FlaxCLIPModel"),
    ("distilbert", "FlaxDistilBertModel"),
    ("electra", "FlaxElectraModel"),
    ("gemma", "FlaxGemmaModel"),
    ("gpt-sw3", "FlaxGPT2Model"),
    ("gpt2", "FlaxGPT2Model"),
    ("gpt_neo", "FlaxGPTNeoModel"),
    ("gptj", "FlaxGPTJModel"),
    ("llama", "FlaxLlamaModel"),
    ("longt5", "FlaxLongT5Model"),
    ("marian", "FlaxMarianModel"),
    ("mbart", "FlaxMBartModel"),
    ("mistral", "FlaxMistralModel"),
    ("mt5", "FlaxMT5Model"),
    ("opt", "FlaxOPTModel"),
    ("pegasus", "FlaxPegasusModel"),
    ("regnet", "FlaxRegNetModel"),
    ("resnet", "FlaxResNetModel"),
    ("roberta", "FlaxRobertaModel"),
    ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
    ("roformer", "FlaxRoFormerModel"),
    ("t5", "FlaxT5Model"),
    ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
    ("vit", "FlaxViTModel"),
    ("wav2vec2", "FlaxWav2Vec2Model"),
    ("whisper", "FlaxWhisperModel"),
    ("xglm", "FlaxXGLMModel"),
    ("xlm-roberta", "FlaxXLMRobertaModel"),
])
# Model type -> Flax class name used for pretraining.
# The closing parenthesis of the OrderedDict(...) call was missing, which made
# the module fail to parse; it is restored here.
FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "FlaxAlbertForPreTraining"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForPreTraining"),
        ("big_bird", "FlaxBigBirdForPreTraining"),
        ("electra", "FlaxElectraForPreTraining"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("t5", "FlaxT5ForConditionalGeneration"),
        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)
# Model type -> Flax class name for masked language modeling.
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict([
    ("albert", "FlaxAlbertForMaskedLM"),
    ("bart", "FlaxBartForConditionalGeneration"),
    ("bert", "FlaxBertForMaskedLM"),
    ("big_bird", "FlaxBigBirdForMaskedLM"),
    ("distilbert", "FlaxDistilBertForMaskedLM"),
    ("electra", "FlaxElectraForMaskedLM"),
    ("mbart", "FlaxMBartForConditionalGeneration"),
    ("roberta", "FlaxRobertaForMaskedLM"),
    ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
    ("roformer", "FlaxRoFormerForMaskedLM"),
    ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
])
# Model type -> Flax class name for sequence-to-sequence language modeling.
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict([
    ("bart", "FlaxBartForConditionalGeneration"),
    ("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
    ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
    ("encoder-decoder", "FlaxEncoderDecoderModel"),
    ("longt5", "FlaxLongT5ForConditionalGeneration"),
    ("marian", "FlaxMarianMTModel"),
    ("mbart", "FlaxMBartForConditionalGeneration"),
    ("mt5", "FlaxMT5ForConditionalGeneration"),
    ("pegasus", "FlaxPegasusForConditionalGeneration"),
    ("t5", "FlaxT5ForConditionalGeneration"),
])
# Model type -> Flax class name for image classification.
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict([
    ("beit", "FlaxBeitForImageClassification"),
    ("regnet", "FlaxRegNetForImageClassification"),
    ("resnet", "FlaxResNetForImageClassification"),
    ("vit", "FlaxViTForImageClassification"),
])
# Model type -> Flax class name for vision-to-sequence models.
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict([
    ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
])
# Model type -> Flax class name for causal language modeling.
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict([
    ("bart", "FlaxBartForCausalLM"),
    ("bert", "FlaxBertForCausalLM"),
    ("big_bird", "FlaxBigBirdForCausalLM"),
    ("bloom", "FlaxBloomForCausalLM"),
    ("electra", "FlaxElectraForCausalLM"),
    ("gemma", "FlaxGemmaForCausalLM"),
    ("gpt-sw3", "FlaxGPT2LMHeadModel"),
    ("gpt2", "FlaxGPT2LMHeadModel"),
    ("gpt_neo", "FlaxGPTNeoForCausalLM"),
    ("gptj", "FlaxGPTJForCausalLM"),
    ("llama", "FlaxLlamaForCausalLM"),
    ("mistral", "FlaxMistralForCausalLM"),
    ("opt", "FlaxOPTForCausalLM"),
    ("roberta", "FlaxRobertaForCausalLM"),
    ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
    ("xglm", "FlaxXGLMForCausalLM"),
    ("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
])
# Model type -> Flax class name for sequence classification.
# The closing parenthesis of the OrderedDict(...) call was missing, which made
# the module fail to parse; it is restored here.
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "FlaxAlbertForSequenceClassification"),
        ("bart", "FlaxBartForSequenceClassification"),
        ("bert", "FlaxBertForSequenceClassification"),
        ("big_bird", "FlaxBigBirdForSequenceClassification"),
        ("distilbert", "FlaxDistilBertForSequenceClassification"),
        ("electra", "FlaxElectraForSequenceClassification"),
        ("mbart", "FlaxMBartForSequenceClassification"),
        ("roberta", "FlaxRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "FlaxRoFormerForSequenceClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
    ]
)
# Model type -> Flax class name for question answering.
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict([
    ("albert", "FlaxAlbertForQuestionAnswering"),
    ("bart", "FlaxBartForQuestionAnswering"),
    ("bert", "FlaxBertForQuestionAnswering"),
    ("big_bird", "FlaxBigBirdForQuestionAnswering"),
    ("distilbert", "FlaxDistilBertForQuestionAnswering"),
    ("electra", "FlaxElectraForQuestionAnswering"),
    ("mbart", "FlaxMBartForQuestionAnswering"),
    ("roberta", "FlaxRobertaForQuestionAnswering"),
    ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
    ("roformer", "FlaxRoFormerForQuestionAnswering"),
    ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
])
# Model type -> Flax class name for token classification.
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict([
    ("albert", "FlaxAlbertForTokenClassification"),
    ("bert", "FlaxBertForTokenClassification"),
    ("big_bird", "FlaxBigBirdForTokenClassification"),
    ("distilbert", "FlaxDistilBertForTokenClassification"),
    ("electra", "FlaxElectraForTokenClassification"),
    ("roberta", "FlaxRobertaForTokenClassification"),
    ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
    ("roformer", "FlaxRoFormerForTokenClassification"),
    ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
])
# Model type -> Flax class name for multiple choice.
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict([
    ("albert", "FlaxAlbertForMultipleChoice"),
    ("bert", "FlaxBertForMultipleChoice"),
    ("big_bird", "FlaxBigBirdForMultipleChoice"),
    ("distilbert", "FlaxDistilBertForMultipleChoice"),
    ("electra", "FlaxElectraForMultipleChoice"),
    ("roberta", "FlaxRobertaForMultipleChoice"),
    ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
    ("roformer", "FlaxRoFormerForMultipleChoice"),
    ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
])
# Next-sentence prediction: model-type key -> name of the matching Flax class.
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "FlaxBertForNextSentencePrediction"),
    ]
)
# Speech sequence-to-sequence: model-type key -> name of the matching Flax class.
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
    ]
)
# Audio classification: model-type key -> name of the matching Flax class.
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("whisper", "FlaxWhisperForAudioClassification"),
    ]
)
# Wrap each Flax *_MAPPING_NAMES table, together with CONFIG_MAPPING_NAMES, in a _LazyAutoMapping.
# Base models.
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
# Pre-training heads.
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
# Masked language modeling heads.
FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
# Sequence-to-sequence language modeling heads.
# (A stray, result-discarding tuple expression that used to follow this assignment was removed:
# it only evaluated the two names already passed as arguments here and had no effect.)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
# Remaining Flax task mappings: every *_MAPPING_NAMES table is wrapped, together with
# CONFIG_MAPPING_NAMES, in a _LazyAutoMapping under the same name minus the `_NAMES` suffix.

# Image classification.
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
# Vision-to-sequence.
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
)
# Causal language modeling.
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
# Sequence classification.
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
# Question answering.
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
# Token classification.
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
# Multiple choice.
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
# Next-sentence prediction.
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
)
# Speech sequence-to-sequence.
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
# Audio classification.
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
)
# Auto class for base Flax models; its `_model_mapping` is FLAX_MODEL_MAPPING.
class FlaxAutoModel(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_MAPPING


# Rebind the name to the result of auto_class_update (called with its default arguments).
FlaxAutoModel = auto_class_update(FlaxAutoModel)
# Auto class for pre-training heads; its `_model_mapping` is FLAX_MODEL_FOR_PRETRAINING_MAPPING.
class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="pretraining".
FlaxAutoModelForPreTraining = auto_class_update(
    FlaxAutoModelForPreTraining,
    head_doc="pretraining",
)
# Auto class for causal-LM heads; its `_model_mapping` is FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.
class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="causal language modeling".
FlaxAutoModelForCausalLM = auto_class_update(
    FlaxAutoModelForCausalLM,
    head_doc="causal language modeling",
)
# Auto class for masked-LM heads; its `_model_mapping` is FLAX_MODEL_FOR_MASKED_LM_MAPPING.
class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="masked language modeling".
FlaxAutoModelForMaskedLM = auto_class_update(
    FlaxAutoModelForMaskedLM,
    head_doc="masked language modeling",
)
# Auto class for seq2seq-LM heads; its `_model_mapping` is FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Rebind the name to the result of auto_class_update; this call also overrides
# checkpoint_for_example with "google-t5/t5-base".
FlaxAutoModelForSeq2SeqLM = auto_class_update(
    FlaxAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="google-t5/t5-base",
)
# Auto class for sequence-classification heads; its `_model_mapping` is
# FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="sequence classification".
FlaxAutoModelForSequenceClassification = auto_class_update(
    FlaxAutoModelForSequenceClassification,
    head_doc="sequence classification",
)
# Auto class for question-answering heads; its `_model_mapping` is
# FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING.
class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="question answering".
FlaxAutoModelForQuestionAnswering = auto_class_update(
    FlaxAutoModelForQuestionAnswering,
    head_doc="question answering",
)
# Auto class for token-classification heads; its `_model_mapping` is
# FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="token classification".
FlaxAutoModelForTokenClassification = auto_class_update(
    FlaxAutoModelForTokenClassification,
    head_doc="token classification",
)
# Auto class for multiple-choice heads; its `_model_mapping` is FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="multiple choice".
FlaxAutoModelForMultipleChoice = auto_class_update(
    FlaxAutoModelForMultipleChoice,
    head_doc="multiple choice",
)
# Auto class for next-sentence-prediction heads; its `_model_mapping` is
# FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="next sentence prediction".
FlaxAutoModelForNextSentencePrediction = auto_class_update(
    FlaxAutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
# Auto class for image-classification heads; its `_model_mapping` is
# FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="image classification".
FlaxAutoModelForImageClassification = auto_class_update(
    FlaxAutoModelForImageClassification,
    head_doc="image classification",
)
# Auto class for vision-to-text heads; its `_model_mapping` is FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.
class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING


# Rebind the name to the result of auto_class_update, passing head_doc="vision-to-text modeling".
FlaxAutoModelForVision2Seq = auto_class_update(
    FlaxAutoModelForVision2Seq,
    head_doc="vision-to-text modeling",
)
# Auto class for speech seq2seq heads; its `_model_mapping` is FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Rebind the name to the result of auto_class_update, passing
# head_doc="sequence-to-sequence speech-to-text modeling".
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
    FlaxAutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)
.\models\auto\modeling_tf_auto.py
# 指定编码格式为UTF-8
# 版权声明和许可证信息,告知代码的版权和许可使用条款
# 版权归The HuggingFace Inc.团队所有,许可类型为Apache License, Version 2.0
# 除非符合许可证规定,否则不得使用此文件
# 引入警告模块,用于处理警告信息
import warnings
# 引入有序字典模块,用于保存模型映射关系的有序字典
from collections import OrderedDict
# 从当前包中的utils模块中导入logging工具
from ...utils import logging
# 从当前包中的.auto_factory模块导入自动模型工厂的基类和映射类相关函数
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
# 从当前包中的.configuration_auto模块导入配置映射名称
from .configuration_auto import CONFIG_MAPPING_NAMES
# 获取或创建当前模块的日志记录器对象
logger = logging.get_logger(__name__)
# Base models: model-type key -> name of the TensorFlow class (or a tuple of class names).
# NOTE(review): this copy of the table was cut off right after the "swin" entry (a dangling
# comment announced a "t5" entry but the entry itself and everything after it were missing,
# and the literal was never closed). The brackets are closed here so the module parses;
# the entries from "t5" onward still have to be restored from the upstream file — TODO confirm.
TF_MODEL_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertModel"),
        ("bart", "TFBartModel"),
        ("bert", "TFBertModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("blip", "TFBlipModel"),
        ("camembert", "TFCamembertModel"),
        ("clip", "TFCLIPModel"),
        ("convbert", "TFConvBertModel"),
        ("convnext", "TFConvNextModel"),
        ("convnextv2", "TFConvNextV2Model"),
        ("ctrl", "TFCTRLModel"),
        ("cvt", "TFCvtModel"),
        ("data2vec-vision", "TFData2VecVisionModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("deit", "TFDeiTModel"),
        ("distilbert", "TFDistilBertModel"),
        ("dpr", "TFDPRQuestionEncoder"),
        ("efficientformer", "TFEfficientFormerModel"),
        ("electra", "TFElectraModel"),
        ("esm", "TFEsmModel"),
        ("flaubert", "TFFlaubertModel"),
        # Two candidate classes are listed for "funnel".
        ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
        ("gpt-sw3", "TFGPT2Model"),
        ("gpt2", "TFGPT2Model"),
        ("gptj", "TFGPTJModel"),
        ("groupvit", "TFGroupViTModel"),
        ("hubert", "TFHubertModel"),
        ("layoutlm", "TFLayoutLMModel"),
        ("layoutlmv3", "TFLayoutLMv3Model"),
        ("led", "TFLEDModel"),
        ("longformer", "TFLongformerModel"),
        ("lxmert", "TFLxmertModel"),
        ("marian", "TFMarianModel"),
        ("mbart", "TFMBartModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mobilevit", "TFMobileViTModel"),
        ("mpnet", "TFMPNetModel"),
        ("mt5", "TFMT5Model"),
        ("openai-gpt", "TFOpenAIGPTModel"),
        ("opt", "TFOPTModel"),
        ("pegasus", "TFPegasusModel"),
        ("regnet", "TFRegNetModel"),
        ("rembert", "TFRemBertModel"),
        ("resnet", "TFResNetModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("sam", "TFSamModel"),
        ("segformer", "TFSegformerModel"),
        ("speech_to_text", "TFSpeech2TextModel"),
        ("swin", "TFSwinModel"),
    ]
)
# Pre-training: model-type key -> name of the TensorFlow class used for pre-training.
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertForPreTraining"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForPreTraining"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForPreTraining"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForPreTraining"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("lxmert", "TFLxmertForPreTraining"),
        ("mobilebert", "TFMobileBertForPreTraining"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("vit_mae", "TFViTMAEForPreTraining"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)
# Language-modeling head: model-type key -> name of the TensorFlow class with an LM head.
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertForMaskedLM"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("led", "TFLEDForConditionalGeneration"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("marian", "TFMarianMTModel"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("whisper", "TFWhisperForConditionalGeneration"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)
# Causal language modeling: model-type key -> name of the TensorFlow class.
# Each entry is a (model-type key, class name) pair. The closing parenthesis of the
# OrderedDict(...) call was missing and is restored here.
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "TFBertLMHeadModel"),
        ("camembert", "TFCamembertForCausalLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("opt", "TFOPTForCausalLM"),
        ("rembert", "TFRemBertForCausalLM"),
        ("roberta", "TFRobertaForCausalLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
        ("roformer", "TFRoFormerForCausalLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xglm", "TFXGLMForCausalLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForCausalLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)
# Masked image modeling: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        ("deit", "TFDeiTForMaskedImageModeling"),
        ("swin", "TFSwinForMaskedImageModeling"),
    ]
)
# Image classification: model-type key -> name of the TensorFlow class.
# "deit" and "efficientformer" each list two class names (plain and "WithTeacher").
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("convnext", "TFConvNextForImageClassification"),
        ("convnextv2", "TFConvNextV2ForImageClassification"),
        ("cvt", "TFCvtForImageClassification"),
        ("data2vec-vision", "TFData2VecVisionForImageClassification"),
        ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
        (
            "efficientformer",
            ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
        ),
        ("mobilevit", "TFMobileViTForImageClassification"),
        ("regnet", "TFRegNetForImageClassification"),
        ("resnet", "TFResNetForImageClassification"),
        ("segformer", "TFSegformerForImageClassification"),
        ("swin", "TFSwinForImageClassification"),
        ("vit", "TFViTForImageClassification"),
    ]
)
# Zero-shot image classification: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("blip", "TFBlipModel"),
        ("clip", "TFCLIPModel"),
    ]
)
# Semantic segmentation: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
        ("mobilevit", "TFMobileViTForSemanticSegmentation"),
        ("segformer", "TFSegformerForSemanticSegmentation"),
    ]
)
# Vision-to-sequence: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("blip", "TFBlipForConditionalGeneration"),
        ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
    ]
)
# Masked language modeling: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertForMaskedLM"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("deberta", "TFDebertaForMaskedLM"),
        ("deberta-v2", "TFDebertaV2ForMaskedLM"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("tapas", "TFTapasForMaskedLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
    ]
)
# Sequence-to-sequence language modeling: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        ("bart", "TFBartForConditionalGeneration"),
        ("blenderbot", "TFBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "TFEncoderDecoderModel"),
        ("led", "TFLEDForConditionalGeneration"),
        ("marian", "TFMarianMTModel"),
        ("mbart", "TFMBartForConditionalGeneration"),
        ("mt5", "TFMT5ForConditionalGeneration"),
        ("pegasus", "TFPegasusForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
    ]
)
# Speech sequence-to-sequence: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("whisper", "TFWhisperForConditionalGeneration"),
    ]
)
# Sequence classification: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertForSequenceClassification"),
        ("bart", "TFBartForSequenceClassification"),
        ("bert", "TFBertForSequenceClassification"),
        ("camembert", "TFCamembertForSequenceClassification"),
        ("convbert", "TFConvBertForSequenceClassification"),
        ("ctrl", "TFCTRLForSequenceClassification"),
        ("deberta", "TFDebertaForSequenceClassification"),
        ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
        ("distilbert", "TFDistilBertForSequenceClassification"),
        ("electra", "TFElectraForSequenceClassification"),
        ("esm", "TFEsmForSequenceClassification"),
        ("flaubert", "TFFlaubertForSequenceClassification"),
        ("funnel", "TFFunnelForSequenceClassification"),
        ("gpt-sw3", "TFGPT2ForSequenceClassification"),
        ("gpt2", "TFGPT2ForSequenceClassification"),
        ("gptj", "TFGPTJForSequenceClassification"),
        ("layoutlm", "TFLayoutLMForSequenceClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
        ("longformer", "TFLongformerForSequenceClassification"),
        ("mobilebert", "TFMobileBertForSequenceClassification"),
        ("mpnet", "TFMPNetForSequenceClassification"),
        ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
        ("rembert", "TFRemBertForSequenceClassification"),
        ("roberta", "TFRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "TFRoFormerForSequenceClassification"),
        ("tapas", "TFTapasForSequenceClassification"),
        ("transfo-xl", "TFTransfoXLForSequenceClassification"),
        ("xlm", "TFXLMForSequenceClassification"),
        ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
        ("xlnet", "TFXLNetForSequenceClassification"),
    ]
)
# Question answering: model-type key -> name of the TensorFlow class.
# The closing parenthesis of the OrderedDict(...) call was missing and is restored here.
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertForQuestionAnswering"),
        ("bert", "TFBertForQuestionAnswering"),
        ("camembert", "TFCamembertForQuestionAnswering"),
        ("convbert", "TFConvBertForQuestionAnswering"),
        ("deberta", "TFDebertaForQuestionAnswering"),
        ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
        ("distilbert", "TFDistilBertForQuestionAnswering"),
        ("electra", "TFElectraForQuestionAnswering"),
        ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
        ("funnel", "TFFunnelForQuestionAnswering"),
        ("gptj", "TFGPTJForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
        ("longformer", "TFLongformerForQuestionAnswering"),
        ("mobilebert", "TFMobileBertForQuestionAnswering"),
        ("mpnet", "TFMPNetForQuestionAnswering"),
        ("rembert", "TFRemBertForQuestionAnswering"),
        ("roberta", "TFRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "TFRoFormerForQuestionAnswering"),
        ("xlm", "TFXLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
        ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
    ]
)
# Audio classification: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("wav2vec2", "TFWav2Vec2ForSequenceClassification"),
    ]
)
# Document question answering: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("layoutlm", "TFLayoutLMForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
    ]
)
# Table question answering: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("tapas", "TFTapasForQuestionAnswering"),
    ]
)
# Token classification: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertForTokenClassification"),
        ("bert", "TFBertForTokenClassification"),
        ("camembert", "TFCamembertForTokenClassification"),
        ("convbert", "TFConvBertForTokenClassification"),
        ("deberta", "TFDebertaForTokenClassification"),
        ("deberta-v2", "TFDebertaV2ForTokenClassification"),
        ("distilbert", "TFDistilBertForTokenClassification"),
        ("electra", "TFElectraForTokenClassification"),
        ("esm", "TFEsmForTokenClassification"),
        ("flaubert", "TFFlaubertForTokenClassification"),
        ("funnel", "TFFunnelForTokenClassification"),
        ("layoutlm", "TFLayoutLMForTokenClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
        ("longformer", "TFLongformerForTokenClassification"),
        ("mobilebert", "TFMobileBertForTokenClassification"),
        ("mpnet", "TFMPNetForTokenClassification"),
        ("rembert", "TFRemBertForTokenClassification"),
        ("roberta", "TFRobertaForTokenClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
        ("roformer", "TFRoFormerForTokenClassification"),
        ("xlm", "TFXLMForTokenClassification"),
        ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
        ("xlnet", "TFXLNetForTokenClassification"),
    ]
)
# Multiple choice: model-type key -> name of the TensorFlow class.
# The closing parenthesis of the OrderedDict(...) call was missing and is restored here.
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertForMultipleChoice"),
        ("bert", "TFBertForMultipleChoice"),
        ("camembert", "TFCamembertForMultipleChoice"),
        ("convbert", "TFConvBertForMultipleChoice"),
        ("deberta-v2", "TFDebertaV2ForMultipleChoice"),
        ("distilbert", "TFDistilBertForMultipleChoice"),
        ("electra", "TFElectraForMultipleChoice"),
        ("flaubert", "TFFlaubertForMultipleChoice"),
        ("funnel", "TFFunnelForMultipleChoice"),
        ("longformer", "TFLongformerForMultipleChoice"),
        ("mobilebert", "TFMobileBertForMultipleChoice"),
        ("mpnet", "TFMPNetForMultipleChoice"),
        ("rembert", "TFRemBertForMultipleChoice"),
        ("roberta", "TFRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "TFRoFormerForMultipleChoice"),
        ("xlm", "TFXLMForMultipleChoice"),
        ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
        ("xlnet", "TFXLNetForMultipleChoice"),
    ]
)
# Next-sentence prediction: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "TFBertForNextSentencePrediction"),
        ("mobilebert", "TFMobileBertForNextSentencePrediction"),
    ]
)
# Mask generation: model-type key -> name of the TensorFlow class.
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        ("sam", "TFSamModel"),
    ]
)
# Text encoding: model-type key -> name of the TensorFlow class.
# Note that "mt5" and "t5" point at their encoder-only classes.
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertModel"),
        ("bert", "TFBertModel"),
        ("convbert", "TFConvBertModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("distilbert", "TFDistilBertModel"),
        ("electra", "TFElectraModel"),
        ("flaubert", "TFFlaubertModel"),
        ("longformer", "TFLongformerModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mt5", "TFMT5EncoderModel"),
        ("rembert", "TFRemBertModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("t5", "TFT5EncoderModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
    ]
)
# TensorFlow task mappings: every *_MAPPING_NAMES table is wrapped, together with
# CONFIG_MAPPING_NAMES, in a _LazyAutoMapping under the same name minus the `_NAMES` suffix.

# Base models.
TF_MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_MAPPING_NAMES,
)
# Pre-training.
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
# Language-modeling head.
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES,
)
# Causal language modeling.
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
# Masked image modeling.
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
)
# Image classification.
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
# Zero-shot image classification.
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
# Semantic segmentation.
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
)
# Vision-to-sequence.
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
)
# Masked language modeling.
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
# Sequence-to-sequence language modeling.
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
# 创建 LazyAutoMapping 对象,将 CONFIG_MAPPING_NAMES 映射到 TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
# 创建 LazyAutoMapping 对象,将 CONFIG_MAPPING_NAMES 映射到 TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
# 创建 LazyAutoMapping 对象,将 CONFIG_MAPPING_NAMES 映射到 TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
# `_LazyAutoMapping` pairing `CONFIG_MAPPING_NAMES` with the document-question-answering class names.
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
# Remaining per-task `_LazyAutoMapping` objects, each pairing `CONFIG_MAPPING_NAMES` with the
# `*_MAPPING_NAMES` table of the same task.
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
)
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete mask-generation model.
    _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
class TFAutoModelForTextEncoding(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete text-encoding model.
    _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
class TFAutoModel(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete base model.
    _model_mapping = TF_MODEL_MAPPING


# Rebind the name to whatever `auto_class_update` returns for the class (no `head_doc` supplied).
TFAutoModel = auto_class_update(TFAutoModel)
class TFAutoModelForAudioClassification(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete audio-classification model.
    _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, with "audio classification" as `head_doc`.
TFAutoModelForAudioClassification = auto_class_update(
    TFAutoModelForAudioClassification,
    head_doc="audio classification",
)
class TFAutoModelForPreTraining(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete pretraining model.
    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


# Rebind the name to the result of `auto_class_update`, with "pretraining" as `head_doc`.
TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
# Kept private on purpose: the public class of the same name (minus the underscore) is the one that
# adds the deprecation warnings.
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete LM-head model.
    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


# Rebind the name to the result of `auto_class_update`, with "language modeling" as `head_doc`.
_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
class TFAutoModelForCausalLM(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete causal-LM model.
    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the name to the result of `auto_class_update`, with "causal language modeling" as `head_doc`.
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete masked-image-modeling model.
    _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


# Rebind the name to the result of `auto_class_update`, with "masked image modeling" as `head_doc`.
TFAutoModelForMaskedImageModeling = auto_class_update(
    TFAutoModelForMaskedImageModeling,
    head_doc="masked image modeling",
)
class TFAutoModelForImageClassification(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete image-classification model.
    _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, with "image classification" as `head_doc`.
TFAutoModelForImageClassification = auto_class_update(
    TFAutoModelForImageClassification,
    head_doc="image classification",
)
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete zero-shot image-classification model.
    _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, with "zero-shot image classification" as
# `head_doc`. This call is the only statement needed here: a bare
# `Name, head_doc="..."` line would be parsed as a two-target unpacking assignment of the string
# and fail at import time.
TFAutoModelForZeroShotImageClassification = auto_class_update(
    TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete semantic-segmentation model.
    _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


# Rebind the name to the result of `auto_class_update`, with "semantic segmentation" as `head_doc`.
TFAutoModelForSemanticSegmentation = auto_class_update(
    TFAutoModelForSemanticSegmentation,
    head_doc="semantic segmentation",
)
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete vision-to-sequence model.
    _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


# Rebind the name to the result of `auto_class_update`, with "vision-to-text modeling" as `head_doc`.
TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete masked-LM model.
    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


# Rebind the name to the result of `auto_class_update`, with "masked language modeling" as `head_doc`.
TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete sequence-to-sequence LM model.
    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Rebind the name to the result of `auto_class_update`, with "sequence-to-sequence language modeling"
# as `head_doc` and "google-t5/t5-base" as the example checkpoint.
TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="google-t5/t5-base",
)
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete sequence-classification model.
    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, with "sequence classification" as `head_doc`.
TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification,
    head_doc="sequence classification",
)
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete question-answering model.
    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Rebind the name to the result of `auto_class_update`, with "question answering" as `head_doc`.
TFAutoModelForQuestionAnswering = auto_class_update(
    TFAutoModelForQuestionAnswering,
    head_doc="question answering",
)
class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete document-question-answering model.
    _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# Rebind the name to the result of `auto_class_update`, with "document question answering" as
# `head_doc`.
# NOTE(review): the mixed quoting in `checkpoint_for_example` looks deliberate - presumably the value
# is pasted between double quotes in a generated example so that it expands to a checkpoint name plus
# a `revision` argument. Confirm against `auto_class_update` before "fixing" it.
TFAutoModelForDocumentQuestionAnswering = auto_class_update(
    TFAutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)
class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete table-question-answering model.
    _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Rebind the name to the result of `auto_class_update`, with "table question answering" as `head_doc`
# and "google/tapas-base-finetuned-wtq" as the example checkpoint.
TFAutoModelForTableQuestionAnswering = auto_class_update(
    TFAutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete token-classification model.
    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, with "token classification" as `head_doc`.
TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification,
    head_doc="token classification",
)
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete multiple-choice model.
    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Rebind the name to the result of `auto_class_update`, with "multiple choice" as `head_doc`.
TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete next-sentence-prediction model.
    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Rebind the name to the result of `auto_class_update`, with "next sentence prediction" as `head_doc`.
TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Mapping this auto class uses to select the concrete speech sequence-to-sequence model.
    _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Rebind the name to the result of `auto_class_update`, with
# "sequence-to-sequence speech-to-text modeling" as `head_doc`.
TFAutoModelForSpeechSeq2Seq = auto_class_update(
    TFAutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    # Deprecated public wrapper around `_TFAutoModelWithLMHead`: both constructors emit a
    # `FutureWarning` pointing at the task-specific auto classes, then defer to the parent.

    @classmethod
    def from_config(cls, config):
        # Build the deprecation notice first, then warn and delegate to the parent implementation.
        deprecation_notice = (
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
            " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
            " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models."
        )
        warnings.warn(deprecation_notice, FutureWarning)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Same notice as `from_config`; all arguments are forwarded to the parent untouched.
        deprecation_notice = (
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
            " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
            " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models."
        )
        warnings.warn(deprecation_notice, FutureWarning)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
.\models\auto\processing_auto.py
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" AutoProcessor class."""
# 导入必要的库和模块
import importlib # 导入动态导入模块的功能
import inspect # 导入用于检查对象的属性和方法的模块
import json # 导入处理 JSON 格式数据的模块
import os # 导入与操作系统交互的功能
import warnings # 导入警告处理模块
from collections import OrderedDict # 导入有序字典类型
# 导入其他本地库和模块
from ...configuration_utils import PretrainedConfig # 导入预训练配置类
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code # 导入动态模块加载函数
from ...feature_extraction_utils import FeatureExtractionMixin # 导入特征提取混合类
from ...image_processing_utils import ImageProcessingMixin # 导入图像处理混合类
from ...processing_utils import ProcessorMixin # 导入处理混合类
from ...tokenization_utils import TOKENIZER_CONFIG_FILE # 导入分词配置文件
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, get_file_from_repo, logging # 导入工具函数和日志记录
# 导入本地自动化处理工厂和配置模块
from .auto_factory import _LazyAutoMapping # 导入惰性自动映射类
from .configuration_auto import (
CONFIG_MAPPING_NAMES, # 导入配置映射名称
AutoConfig, # 导入自动配置类
model_type_to_module_name, # 导入模型类型到模块名称的映射函数
replace_list_option_in_docstrings, # 导入替换文档字符串中列表选项的函数
)
from .feature_extraction_auto import AutoFeatureExtractor # 导入自动特征提取器类
from .image_processing_auto import AutoImageProcessor # 导入自动图像处理器类
from .tokenization_auto import AutoTokenizer # 导入自动分词器类
# Logger for this module, obtained through the package's `logging.get_logger` helper.
logger = logging.get_logger(__name__)
# Ordered table: model type -> name of the processor class registered for it. Several model types
# share one processor class (e.g. "groupvit" uses `CLIPProcessor`; "hubert", "sew", "sew-d" and
# "unispeech" use `Wav2Vec2Processor`).
PROCESSOR_MAPPING_NAMES = OrderedDict(
    [
        ("align", "AlignProcessor"),
        ("altclip", "AltCLIPProcessor"),
        ("bark", "BarkProcessor"),
        ("blip", "BlipProcessor"),
        ("blip-2", "Blip2Processor"),
        ("bridgetower", "BridgeTowerProcessor"),
        ("chinese_clip", "ChineseCLIPProcessor"),
        ("clap", "ClapProcessor"),
        ("clip", "CLIPProcessor"),
        ("clipseg", "CLIPSegProcessor"),
        ("clvp", "ClvpProcessor"),
        ("flava", "FlavaProcessor"),
        ("fuyu", "FuyuProcessor"),
        ("git", "GitProcessor"),
        ("groupvit", "CLIPProcessor"),
        ("hubert", "Wav2Vec2Processor"),
        ("idefics", "IdeficsProcessor"),
        ("instructblip", "InstructBlipProcessor"),
        ("kosmos-2", "Kosmos2Processor"),
        ("layoutlmv2", "LayoutLMv2Processor"),
        ("layoutlmv3", "LayoutLMv3Processor"),
        ("llava", "LlavaProcessor"),
        ("llava_next", "LlavaNextProcessor"),
        ("markuplm", "MarkupLMProcessor"),
        ("mctct", "MCTCTProcessor"),
        ("mgp-str", "MgpstrProcessor"),
        ("oneformer", "OneFormerProcessor"),
        ("owlv2", "Owlv2Processor"),
        ("owlvit", "OwlViTProcessor"),
        ("pix2struct", "Pix2StructProcessor"),
        ("pop2piano", "Pop2PianoProcessor"),
        ("sam", "SamProcessor"),
        ("seamless_m4t", "SeamlessM4TProcessor"),
        ("sew", "Wav2Vec2Processor"),
        ("sew-d", "Wav2Vec2Processor"),
        ("siglip", "SiglipProcessor"),
        ("speech_to_text", "Speech2TextProcessor"),
        ("speech_to_text_2", "Speech2Text2Processor"),
        ("speecht5", "SpeechT5Processor"),
        ("trocr", "TrOCRProcessor"),
        ("tvlt", "TvltProcessor"),
        ("tvp", "TvpProcessor"),
        ("unispeech", "Wav2Vec2Processor"),
        # TODO(review): this listing was cut off after "unispeech" (the next entry, whose key starts
        # with "un", was truncated). Restore the remaining entries from the authoritative file.
    ]
)
# `_LazyAutoMapping` built from `CONFIG_MAPPING_NAMES` and `PROCESSOR_MAPPING_NAMES`.
PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)
def processor_class_from_name(class_name: str):
    """
    Look up a processor class by its name.

    The search runs in three stages and returns the first hit:

    1. entries of `PROCESSOR_MAPPING_NAMES` whose value contains `class_name`, resolved by importing
       the matching module under `transformers.models`;
    2. classes held in `PROCESSOR_MAPPING._extra_content`, matched on `__name__`;
    3. attributes of the top-level `transformers` module.

    Args:
        class_name (`str`): Name of the processor class to find.

    Returns:
        The matching class object, or `None` when no stage finds it.
    """
    for model_type, processor_name in PROCESSOR_MAPPING_NAMES.items():
        if class_name not in processor_name:
            continue
        model_module = importlib.import_module(f".{model_type_to_module_name(model_type)}", "transformers.models")
        try:
            return getattr(model_module, class_name)
        except AttributeError:
            # The module does not expose that name; keep scanning the remaining entries.
            continue

    for registered in PROCESSOR_MAPPING._extra_content.values():
        if getattr(registered, "__name__", None) == class_name:
            return registered

    # Last resort: the name may still be reachable from the main `transformers` module.
    transformers_module = importlib.import_module("transformers")
    return getattr(transformers_module, class_name) if hasattr(transformers_module, class_name) else None
class AutoProcessor:
    r"""
    This is a generic processor class that will be instantiated as one of the processor classes of the library when
    created with the [`AutoProcessor.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    # NOTE(review): the docstring and the error message below refer to `from_pretrained`, which is not
    # defined in this view of the class - confirm it has not been dropped from the file.

    def __init__(self):
        # Direct construction is always refused; the error message points at `from_pretrained`.
        raise EnvironmentError(
            "AutoProcessor is designed to be instantiated "
            "using the `AutoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
        )

    # `register` takes neither `cls` nor `self`, so it has to be a static method: as a class method
    # the class itself would be bound to `config_class` and every argument would shift by one.
    @staticmethod
    def register(config_class, processor_class, exist_ok=False):
        """
        Register a new processor for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            processor_class ([`FeatureExtractorMixin`]): The processor to register.
            exist_ok (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to `PROCESSOR_MAPPING.register`.
        """
        PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)
.\models\auto\tokenization_auto.py
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Tokenizer class."""
import importlib # 导入用于动态导入模块的标准库
import json # 导入处理 JSON 格式数据的标准库
import os # 导入与操作系统交互的标准库
import warnings # 导入警告处理相关的标准库
from collections import OrderedDict # 导入有序字典的标准库
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union # 导入类型提示相关的标准库
from ...configuration_utils import PretrainedConfig # 导入预训练配置类
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code # 导入动态模块相关工具函数
from ...tokenization_utils import PreTrainedTokenizer # 导入预训练分词器基类
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE # 导入分词器配置文件常量
from ...utils import ( # 导入一些工具函数
cached_file,
extract_commit_hash,
is_g2p_en_available,
is_sentencepiece_available,
is_tokenizers_available,
logging,
)
from ..encoder_decoder import EncoderDecoderConfig # 导入编码器解码器配置类
from .auto_factory import _LazyAutoMapping # 导入自动工厂相关类
from .configuration_auto import ( # 导入自动配置相关的模块
CONFIG_MAPPING_NAMES,
AutoConfig,
config_class_to_model_type,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
# Bring in the fast tokenizer base class when `is_tokenizers_available()` is true; otherwise keep a
# `None` placeholder under the same name.
if is_tokenizers_available():
    from ...tokenization_utils_fast import PreTrainedTokenizerFast
else:
    PreTrainedTokenizerFast = None

# Logger for this module.
logger = logging.get_logger(__name__)

# NOTE(review): both branches below build an empty mapping; presumably the runtime branch is meant to
# carry the real (slow, fast) tokenizer-name entries - confirm against the authoritative file.
if TYPE_CHECKING:
    # Annotated form: model type -> tuple of two optional strings.
    TOKENIZER_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
    TOKENIZER_MAPPING_NAMES = OrderedDict()

# `_LazyAutoMapping` built from `CONFIG_MAPPING_NAMES` and `TOKENIZER_MAPPING_NAMES`.
TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)

# Inverse of `CONFIG_MAPPING_NAMES`: value -> key.
CONFIG_TO_TYPE = {config_name: model_type for model_type, config_name in CONFIG_MAPPING_NAMES.items()}
def tokenizer_class_from_name(class_name: str):
    """
    Look up a tokenizer class by its name.

    `"PreTrainedTokenizerFast"` is answered directly from the module-level name. Otherwise the search
    runs in three stages and returns the first hit:

    1. entries of `TOKENIZER_MAPPING_NAMES` whose value contains `class_name`, resolved by importing
       the matching module under `transformers.models`;
    2. tokenizers held in `TOKENIZER_MAPPING._extra_content`, matched on `__name__`;
    3. attributes of the top-level `transformers` module.

    Args:
        class_name (`str`): Name of the tokenizer class to find.

    Returns:
        The matching class object, or `None` when no stage finds it.
    """
    if class_name == "PreTrainedTokenizerFast":
        return PreTrainedTokenizerFast

    for model_type, tokenizer_names in TOKENIZER_MAPPING_NAMES.items():
        if class_name not in tokenizer_names:
            continue
        model_module = importlib.import_module(f".{model_type_to_module_name(model_type)}", "transformers.models")
        try:
            return getattr(model_module, class_name)
        except AttributeError:
            # The module does not expose that name; keep scanning the remaining entries.
            continue

    for registered_tokenizers in TOKENIZER_MAPPING._extra_content.values():
        for candidate in registered_tokenizers:
            if getattr(candidate, "__name__", None) == class_name:
                return candidate

    # Nothing matched so far; the name may still be reachable from the main `transformers` module.
    transformers_module = importlib.import_module("transformers")
    return getattr(transformers_module, class_name) if hasattr(transformers_module, class_name) else None
def get_tokenizer_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    **kwargs,
):
    """
    Loads the tokenizer configuration from a pretrained model tokenizer configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the tokenizer config is located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.
        kwargs:
            Only two keys are read: the deprecated `use_auth_token` (used as `token` when `token` is not given) and
            `_commit_hash` (forwarded to `cached_file`).

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the tokenizer. Empty when no tokenizer configuration file could be located;
        otherwise the parsed JSON content plus a `"_commit_hash"` entry.

    Raises:
        ValueError: If both `token` and `use_auth_token` are specified.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased")
    # This model does not have a tokenizer config so the result will be an empty dict.
    tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained tokenizer locally and you can reload its config
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
    tokenizer.save_pretrained("tokenizer-test")
    tokenizer_config = get_tokenizer_config("tokenizer-test")
    ```"""
    # Backward compatibility: accept the deprecated `use_auth_token` keyword, but refuse ambiguity.
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    commit_hash = kwargs.get("_commit_hash", None)
    # The three `_raise_exceptions_for_*` flags are switched off and a `None` result is handled below.
    resolved_config_file = cached_file(
        pretrained_model_name_or_path,
        TOKENIZER_CONFIG_FILE,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        subfolder=subfolder,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
        _commit_hash=commit_hash,
    )
    if resolved_config_file is None:
        logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
        return {}
    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)

    with open(resolved_config_file, encoding="utf-8") as reader:
        result = json.load(reader)
    # Record the commit hash alongside the parsed configuration.
    result["_commit_hash"] = commit_hash
    return result
class AutoTokenizer:
    r"""
    This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
    created with the [`AutoTokenizer.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        # Direct construction is deliberately blocked; the error message points callers to the factory method.
        raise EnvironmentError(
            "AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
        )

    # `register` is called on the class itself, and its first parameter is the caller-supplied configuration
    # class (it becomes the mapping key below), so it must not be bound to `cls` the way a classmethod would be.
    @staticmethod
    def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
        """
        Register a new tokenizer in this mapping.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            slow_tokenizer_class ([`PreTrainedTokenizer`], *optional*):
                The slow tokenizer to register.
            fast_tokenizer_class ([`PreTrainedTokenizerFast`], *optional*):
                The fast tokenizer to register.
            exist_ok (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to `TOKENIZER_MAPPING.register`; presumably controls whether an already
                registered `config_class` may be overwritten (confirm against the mapping implementation).

        Raises:
            ValueError: If neither tokenizer class is given, if a fast tokenizer is passed as
                `slow_tokenizer_class`, if a slow tokenizer is passed as `fast_tokenizer_class`, or if the fast
                tokenizer's `slow_tokenizer_class` attribute differs from the slow tokenizer that was passed.
        """
        # At least one of the two tokenizer flavours is required.
        if slow_tokenizer_class is None and fast_tokenizer_class is None:
            raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class`.")
        # Guard against the two arguments having been swapped by the caller.
        if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
            raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
        if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
            raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")

        # When both are supplied, the slow class advertised by the fast tokenizer must be the one being
        # registered next to it.
        if (
            slow_tokenizer_class is not None
            and fast_tokenizer_class is not None
            and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
            and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
        ):
            raise ValueError(
                "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
                "consistent with the slow tokenizer class you passed (fast tokenizer has "
                f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those "
                "so they match!"
            )

        # If this config was already registered as extra content, keep whichever half of the pair the caller
        # did not provide this time instead of replacing it with None.
        if config_class in TOKENIZER_MAPPING._extra_content:
            existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
            if slow_tokenizer_class is None:
                slow_tokenizer_class = existing_slow
            if fast_tokenizer_class is None:
                fast_tokenizer_class = existing_fast

        # Store the (slow, fast) pair under the configuration class.
        TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok)