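I implemented ViT (Vision Transformer) by hand in pure NumPy, and it runs on a Raspberry Pi. The model is fine-tuned with Hugging Face, its parameters are saved in NumPy format, and inference is pure NumPy.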
First, some background copied from Zhihu:
Following the flowchart above, a ViT block can be broken down into the following steps:
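(1) Patch embedding: take a 224x224 input image and split it into fixed-size 16x16 patches. Each image yields (224x224)/(16x16) = 196 patches, so the input sequence length is 196. Each flattened patch has 16x16x3 = 768 values, and the linear projection layer is 768xN (N = 768), so after the projection the input is still 196x768: 196 tokens, each of dimension 768. A special [CLS] token is then prepended, giving a final shape of 197x768. At this point, patch embedding has turned a vision problem into a seq2seq problem.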
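(2) Positional encoding (standard learnable 1D position embeddings): ViT also needs position information. The position encoding can be thought of as a table with N rows, where N equals the input sequence length (197 here). Each row is a vector whose dimension equals the token embedding dimension (768). Note that the position encoding is added (sum), not concatenated, so after adding it the shape is still 197x768.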
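(3) LN / multi-head attention / LN: the LN output is still 197x768. Multi-head self-attention first projects the input to q, k and v. With a single head, q, k and v are each 197x768. With 12 heads (768/12 = 64), each head's q, k and v are 197x64, giving 12 sets of q, k, v. The outputs of the 12 heads are concatenated back to 197x768. After the residual addition, another LN (the one in front of the MLP) keeps the shape at 197x768.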
(4) MLP: expand the dimension, then project it back: 197x768 is expanded to 197x3072 and then reduced to 197x768 again.
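The shape bookkeeping in steps (1) to (4) can be checked with a minimal sketch. Assumptions, since only the shapes matter here: random weights stand in for trained parameters, LayerNorm has no learned scale/shift, biases are dropped, and ReLU replaces GELU. The patch embedding is computed by cutting the image into 16x16 patches and multiplying by the flattened kernel. This is equivalent to the stride-16 `conv2d` in the NumPy code, because the stride equals the kernel size, so the patches do not overlap.

```python
import numpy as np

rng = np.random.default_rng(0)
P, C, D, HEADS, MLP_DIM = 16, 3, 768, 12, 3072

def layer_norm(x, eps=1e-12):
    mu, var = x.mean(-1, keepdims=True), x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)  # learned scale/shift omitted

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def split_heads(t):
    # (197, 768) -> (12, 197, 64): head h owns columns h*64 .. h*64+63
    return t.reshape(197, HEADS, D // HEADS).transpose(1, 0, 2)

# (1) Patch embedding: 3x224x224 image -> 196 patches of 16*16*3 = 768 values -> 196x768
image = rng.standard_normal((C, 224, 224))
w_patch = rng.standard_normal((D, C, P, P)) * 0.02   # same layout as a conv kernel (out, in, kh, kw)
g = 224 // P                                          # 14 patches per side
patches = image.reshape(C, g, P, g, P).transpose(1, 3, 0, 2, 4).reshape(g * g, C * P * P)
tokens = patches @ w_patch.reshape(D, -1).T           # (196, 768)
x = np.concatenate([np.zeros((1, D)), tokens])        # prepend a CLS token -> (197, 768)

# (2) Learnable position table, added element-wise (not concatenated)
x = x + rng.standard_normal((197, D)) * 0.02          # still (197, 768)

# (3) LN -> 12-head self-attention (768 / 12 = 64 dims per head) -> residual
W_q, W_k, W_v, W_o = (rng.standard_normal((D, D)) * 0.02 for _ in range(4))
h = layer_norm(x)
q, k, v = split_heads(h @ W_q), split_heads(h @ W_k), split_heads(h @ W_v)
attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(D // HEADS)) @ v  # (12, 197, 64)
x = x + attn.transpose(1, 0, 2).reshape(197, D) @ W_o               # concat heads -> (197, 768)

# (4) LN -> MLP 768 -> 3072 -> 768 -> residual
W_1 = rng.standard_normal((D, MLP_DIM)) * 0.02
W_2 = rng.standard_normal((MLP_DIM, D)) * 0.02
hidden = np.maximum(0, layer_norm(x) @ W_1)           # (197, 3072); ReLU as a stand-in for GELU
x = x + hidden @ W_2                                  # back to (197, 768)
print(tokens.shape, attn.shape, hidden.shape, x.shape)
```

Replacing the loop-based convolution with this reshape and matmul is likely a cheap speed-up on a Raspberry Pi, because it becomes a single matrix multiply.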
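After one block, the shape is still 197x768, the same as the input, so blocks can be stacked. Finally, the output corresponding to the [CLS] token, Z0, is taken as the encoder's final output and serves as the image representation (an alternative is to drop the [CLS] token and average the outputs of all tokens). This is Eq. (4) in the figure below, y = LN(z_L^0). An MLP head is then attached for image classification.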
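The two pooling options can be compared in a short sketch. Random arrays stand in for the last block's output and for a classifier with the same layout as `classifier.weight` (10, 768). These are assumptions, not trained weights. The final LayerNorm has no learned scale/shift here. Because LayerNorm works per token, normalising only the CLS row gives the same result as normalising every token and then slicing. The NumPy code relies on this.

```python
import numpy as np

rng = np.random.default_rng(0)
encoder_out = rng.standard_normal((1, 197, 768))  # stand-in: (batch, tokens, dim) after the last block
W_head = rng.standard_normal((10, 768)) * 0.02     # stand-in with the layout of classifier.weight
b_head = np.zeros(10)

def layer_norm(x, eps=1e-12):
    mu, var = x.mean(-1, keepdims=True), x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)  # learned scale/shift omitted

cls_repr = layer_norm(encoder_out[:, 0])                 # option 1: y = LN(z_L^0), the CLS output
mean_repr = layer_norm(encoder_out[:, 1:].mean(axis=1))  # option 2: average the patch tokens instead
                                                         # (without a CLS token, average all tokens)
logits = cls_repr @ W_head.T + b_head                    # (1, 10) class scores
prediction = logits.argmax(axis=-1)                      # (1,) predicted class index
print(cls_repr.shape, mean_repr.shape, logits.shape, prediction)
```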
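Below is the NumPy implementation of ViT, where each component's implementation details can be read directly. It differs from BERT in more than the embedding layer. The block structure is also different: layer normalization comes before the attention layer and the feed-forward layer (pre-LN), whereas BERT places it after them (post-LN).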
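The pre-LN / post-LN difference can be shown with a minimal sketch. `toy_attn` and `toy_mlp` are hypothetical stand-ins for the real attention and MLP sub-layers, and LayerNorm has no learned scale/shift. With post-LN, every block output is re-normalised. With pre-LN, the residual stream carries the raw sum. This is why the ViT checkpoint has an extra final LayerNorm (`vit.layernorm`) in front of the classifier.

```python
import numpy as np

def layer_norm(x, eps=1e-12):
    mu, var = x.mean(-1, keepdims=True), x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)  # learned scale/shift omitted for brevity

def pre_ln_block(x, attn, mlp):
    # ViT (pre-LN): normalise the input of each sub-layer; the residual path stays un-normalised
    x = x + attn(layer_norm(x))
    return x + mlp(layer_norm(x))

def post_ln_block(x, attn, mlp):
    # BERT (post-LN): add the sub-layer output to its input, then normalise the sum
    x = layer_norm(x + attn(x))
    return layer_norm(x + mlp(x))

rng = np.random.default_rng(0)
W_a = rng.standard_normal((768, 768)) * 0.02
W_m = rng.standard_normal((768, 768)) * 0.02

def toy_attn(h):  # hypothetical stand-in for multi-head attention
    return h @ W_a

def toy_mlp(h):   # hypothetical stand-in for the 768 -> 3072 -> 768 MLP
    return np.tanh(h @ W_m)

x = rng.standard_normal((197, 768)) * 5 + 3  # deliberately un-normalised input
pre_out = pre_ln_block(x, toy_attn, toy_mlp)
post_out = post_ln_block(x, toy_attn, toy_mlp)
print(pre_out.var(-1).mean(), post_out.var(-1).mean())
```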
```python
import os

import numpy as np
from PIL import Image

# Load the parameters exported from the Hugging Face model (see the training script below)
model_data = np.load('vit_model_params.npz')
for name in model_data:
    print(name, model_data[name].shape)

patch_embedding_weight = model_data["vit.embeddings.patch_embeddings.projection.weight"]  # (768, 3, 16, 16)
patch_embedding_bias = model_data["vit.embeddings.patch_embeddings.projection.bias"]      # (768,)
position_embeddings = model_data["vit.embeddings.position_embeddings"]                    # (1, 197, 768)
cls_token_embeddings = model_data["vit.embeddings.cls_token"]                             # (1, 1, 768)


def conv2d(images, weight, bias, stride=1, padding=0):
    """Naive 2D convolution. images: (N, C, H, W), weight: (F, C, HH, WW)."""
    if padding > 0:
        images = np.pad(images, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    N, C, H, W = images.shape
    F, _, HH, WW = weight.shape
    # Output spatial size
    H_out = (H - HH) // stride + 1
    W_out = (W - WW) // stride + 1
    out = np.zeros((N, F, H_out, W_out))
    for i in range(H_out):
        for j in range(W_out):
            # Current window: (N, C, HH, WW)
            window = images[:, :, i * stride:i * stride + HH, j * stride:j * stride + WW]
            # Contract over (C, HH, WW) -> (N, F); works for any batch size N
            out[:, :, i, j] = np.tensordot(window, weight, axes=([1, 2, 3], [1, 2, 3])) + bias
    return out


def patch_embedding(images):
    # A 16x16 convolution with stride 16 = linear projection of non-overlapping 16x16 patches
    kernel_size = 16
    return conv2d(images, patch_embedding_weight, patch_embedding_bias, stride=kernel_size)


def model_input(images):
    n = images.shape[0]
    # (N, 768, 14, 14) -> (N, 196, 768)
    patch_embedded = patch_embedding(images).reshape(n, 768, -1).transpose(0, 2, 1)
    # Prepend the CLS token -> (N, 197, 768)
    cls_tokens = np.broadcast_to(cls_token_embeddings, (n, 1, 768))
    patch_embedded = np.concatenate([cls_tokens, patch_embedded], axis=1)
    # Learnable 1D position embeddings are added, not concatenated
    return patch_embedded + position_embeddings


def softmax(x, axis=-1):
    # Subtract the max before exp() so large scores cannot overflow
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    attention_weights = softmax(scores, axis=-1)
    output = np.matmul(attention_weights, V)
    return output, attention_weights


def multihead_attention(x, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O):
    # PyTorch nn.Linear stores weights as (out_features, in_features), hence the transposes
    q = np.matmul(x, W_Q.T) + B_Q
    k = np.matmul(x, W_K.T) + B_K
    v = np.matmul(x, W_V.T) + B_V
    # Split the 768 channels into num_heads contiguous chunks of 64
    q = np.split(q, num_heads, axis=-1)
    k = np.split(k, num_heads, axis=-1)
    v = np.split(v, num_heads, axis=-1)
    outputs = [scaled_dot_product_attention(q_, k_, v_)[0] for q_, k_, v_ in zip(q, k, v)]
    # Concatenate the heads and apply the output projection
    outputs = np.concatenate(outputs, axis=-1)
    return np.matmul(outputs, W_O.T) + B_O


def layer_normalization(x, weight, bias, eps=1e-12):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    normalized_x = (x - mean) / np.sqrt(variance + eps)
    return weight * normalized_x + bias


def feed_forward_layer(inputs, weight, bias, activation='relu'):
    linear_output = np.matmul(inputs, weight) + bias
    if activation == 'relu':
        return np.maximum(0, linear_output)
    if activation == 'gelu':
        # tanh approximation of GELU
        return 0.5 * linear_output * (1 + np.tanh(np.sqrt(2 / np.pi) * (linear_output + 0.044715 * np.power(linear_output, 3))))
    if activation == 'tanh':
        return np.tanh(linear_output)
    return linear_output  # no activation


def residual_connection(inputs, residual):
    return inputs + residual


def vit(hidden, num_heads=12, num_layers=12):
    for i in range(num_layers):
        p = f'vit.encoder.layer.{i}.'
        # Attention weights
        W_Q, B_Q = model_data[p + 'attention.attention.query.weight'], model_data[p + 'attention.attention.query.bias']
        W_K, B_K = model_data[p + 'attention.attention.key.weight'], model_data[p + 'attention.attention.key.bias']
        W_V, B_V = model_data[p + 'attention.attention.value.weight'], model_data[p + 'attention.attention.value.bias']
        W_O, B_O = model_data[p + 'attention.output.dense.weight'], model_data[p + 'attention.output.dense.bias']
        # MLP weights: 768 -> 3072 -> 768
        intermediate_weight = model_data[p + 'intermediate.dense.weight']
        intermediate_bias = model_data[p + 'intermediate.dense.bias']
        dense_weight = model_data[p + 'output.dense.weight']
        dense_bias = model_data[p + 'output.dense.bias']
        # LayerNorms applied before attention and before the MLP (pre-LN)
        ln_before_weight = model_data[p + 'layernorm_before.weight']
        ln_before_bias = model_data[p + 'layernorm_before.bias']
        ln_after_weight = model_data[p + 'layernorm_after.weight']
        ln_after_bias = model_data[p + 'layernorm_after.bias']

        output = layer_normalization(hidden, ln_before_weight, ln_before_bias)
        output = multihead_attention(output, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O)
        output1 = residual_connection(hidden, output)  # matches the HF model's output at this point
        output = layer_normalization(output1, ln_after_weight, ln_after_bias)  # matches
        output = feed_forward_layer(output, intermediate_weight.T, intermediate_bias, activation='gelu')
        output = feed_forward_layer(output, dense_weight.T, dense_bias, activation='')
        hidden = residual_connection(output1, output)

    # Final LayerNorm (vit.layernorm) on the CLS token only; LayerNorm is per-token,
    # so normalising hidden[:, 0] equals normalising every token and then slicing
    cls_output = layer_normalization(hidden[:, 0], model_data['vit.layernorm.weight'], model_data['vit.layernorm.bias'])  # matches
    logits = feed_forward_layer(cls_output, model_data['classifier.weight'].T, model_data['classifier.bias'], activation='')  # matches
    return np.argmax(logits, axis=-1)


folder_path = './cifar10'  # folder containing the test images


def infer_images_in_folder(folder_path):
    for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)
        if os.path.isfile(file_path) and file_name.endswith(('.jpg', '.jpeg', '.png')):
            # Same preprocessing as training: RGB, 224x224, scaled to [0, 1], no mean/std normalisation
            image = Image.open(file_path).convert('RGB').resize((224, 224))
            label = file_name.split(".")[0].split("_")[1]  # file names look like <id>_<label>.jpg
            image = np.array(image) / 255.0
            image = np.transpose(image, (2, 0, 1))[np.newaxis]  # HWC -> (1, C, H, W)
            print("file_path:", file_path, "img size:", image.shape, "label:", label)
            predicted_class = vit(model_input(image))
            print('Predicted class:', predicted_class)


if __name__ == "__main__":
    infer_images_in_folder(folder_path)
```
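Two numerical details in the code above are worth checking in isolation. First, the softmax should subtract the row maximum before `exp()`, otherwise large attention scores overflow. Second, `feed_forward_layer` uses the tanh approximation of GELU. As far as I know, the Hugging Face ViT config defaults to `hidden_act="gelu"`, which is the exact erf-based form, so the sketch measures how far apart the two are on [-6, 6]. `gelu_exact` is written here with `math.erf` because NumPy has no built-in erf.

```python
import math

import numpy as np

def softmax_naive(x, axis=-1):
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def softmax_stable(x, axis=-1):
    # softmax(x) == softmax(x - c) for any per-row constant c; using the row max keeps exp() <= 1
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def gelu_tanh(x):
    # the tanh approximation used by feed_forward_layer
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def gelu_exact(x):
    # exact GELU: x * Phi(x), with the standard normal CDF Phi written via erf
    return 0.5 * x * (1 + np.vectorize(math.erf)(x / math.sqrt(2)))

scores = np.array([[1000.0, 999.0, 998.0]], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(scores)  # exp(1000) overflows to inf, and inf / inf gives nan
stable = softmax_stable(scores)

grid = np.linspace(-6.0, 6.0, 2001)
gelu_gap = np.abs(gelu_tanh(grid) - gelu_exact(grid)).max()
print(naive, stable, gelu_gap)
```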
Results:
```
file_path: ./cifar10/8619_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/6042_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/6801_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7946_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/6925_2.jpg img size: (1, 3, 224, 224) label: 2
Predicted class: [2]
file_path: ./cifar10/6007_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7903_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/7064_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/2713_8.jpg img size: (1, 3, 224, 224) label: 8
Predicted class: [8]
file_path: ./cifar10/8575_9.jpg img size: (1, 3, 224, 224) label: 9
Predicted class: [9]
file_path: ./cifar10/1985_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/5312_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/593_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/8093_7.jpg img size: (1, 3, 224, 224) label: 7
Predicted class: [7]
file_path: ./cifar10/6862_5.jpg img size: (1, 3, 224, 224) label: 5
```
Model parameters:
The 16 `vit.encoder.layer.{i}.*` tensors appear once for each of the 12 encoder layers (i = 0 to 11), with identical shapes in every layer. They are listed once here, with `{i}` standing for the layer index:

```
vit.embeddings.cls_token (1, 1, 768)
vit.embeddings.position_embeddings (1, 197, 768)
vit.embeddings.patch_embeddings.projection.weight (768, 3, 16, 16)
vit.embeddings.patch_embeddings.projection.bias (768,)
vit.encoder.layer.{i}.attention.attention.query.weight (768, 768)
vit.encoder.layer.{i}.attention.attention.query.bias (768,)
vit.encoder.layer.{i}.attention.attention.key.weight (768, 768)
vit.encoder.layer.{i}.attention.attention.key.bias (768,)
vit.encoder.layer.{i}.attention.attention.value.weight (768, 768)
vit.encoder.layer.{i}.attention.attention.value.bias (768,)
vit.encoder.layer.{i}.attention.output.dense.weight (768, 768)
vit.encoder.layer.{i}.attention.output.dense.bias (768,)
vit.encoder.layer.{i}.intermediate.dense.weight (3072, 768)
vit.encoder.layer.{i}.intermediate.dense.bias (3072,)
vit.encoder.layer.{i}.output.dense.weight (768, 3072)
vit.encoder.layer.{i}.output.dense.bias (768,)
vit.encoder.layer.{i}.layernorm_before.weight (768,)
vit.encoder.layer.{i}.layernorm_before.bias (768,)
vit.encoder.layer.{i}.layernorm_after.weight (768,)
vit.encoder.layer.{i}.layernorm_after.bias (768,)
vit.layernorm.weight (768,)
vit.layernorm.bias (768,)
classifier.weight (10, 768)
classifier.bias (10,)
```
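Because every encoder layer has the same shapes, the listing can be turned into a parameter count. This helps when judging whether the model fits in a Raspberry Pi's memory. The sketch rebuilds the name/shape table from the listing (no real checkpoint needed) and sums the tensor sizes. By my arithmetic, this gives 200 tensors and 85,806,346 parameters, or about 343 MB as float32.

```python
import numpy as np

D, MLP_DIM, LAYERS, CLASSES = 768, 3072, 12, 10

# Shapes copied from the parameter listing
shapes = {
    "vit.embeddings.cls_token": (1, 1, D),
    "vit.embeddings.position_embeddings": (1, 197, D),
    "vit.embeddings.patch_embeddings.projection.weight": (D, 3, 16, 16),
    "vit.embeddings.patch_embeddings.projection.bias": (D,),
    "vit.layernorm.weight": (D,),
    "vit.layernorm.bias": (D,),
    "classifier.weight": (CLASSES, D),
    "classifier.bias": (CLASSES,),
}
per_layer = {
    "attention.attention.query.weight": (D, D), "attention.attention.query.bias": (D,),
    "attention.attention.key.weight": (D, D), "attention.attention.key.bias": (D,),
    "attention.attention.value.weight": (D, D), "attention.attention.value.bias": (D,),
    "attention.output.dense.weight": (D, D), "attention.output.dense.bias": (D,),
    "intermediate.dense.weight": (MLP_DIM, D), "intermediate.dense.bias": (MLP_DIM,),
    "output.dense.weight": (D, MLP_DIM), "output.dense.bias": (D,),
    "layernorm_before.weight": (D,), "layernorm_before.bias": (D,),
    "layernorm_after.weight": (D,), "layernorm_after.bias": (D,),
}
for i in range(LAYERS):
    for name, shape in per_layer.items():
        shapes[f"vit.encoder.layer.{i}.{name}"] = shape

total = sum(int(np.prod(s)) for s in shapes.values())
print(f"{len(shapes)} tensors, {total:,} parameters, {total * 4 / 1e6:.0f} MB as float32")
```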
Hugging Face training code: fine-tune on CIFAR-10 and save the model parameters in NumPy format, so the NumPy model can load them directly:
```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from transformers import ViTForImageClassification

# Random seed
torch.manual_seed(42)

# Hyperparameters
batch_size = 64
num_epochs = 1
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing: resize to 224x224 and scale to [0, 1] (no mean/std normalisation,
# which is why the NumPy inference code only divides by 255)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load CIFAR-10
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Load the pretrained ViT model
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)

# Replace the 1000-class ImageNet head with a 10-class head
num_classes = 10
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device)
# (Optional) train only the new head (classifier.weight and classifier.bias are the last two parameters):
# for x in list(vit_model.parameters())[:-2]:
#     x.requires_grad = False

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

# Fine-tune the ViT model
for epoch in range(num_epochs):
    print("epoch:", epoch)
    vit_model.train()
    train_loss = 0.0  # sum of per-batch mean losses
    train_correct = 0
    for images, labels in tqdm(train_loader, total=len(train_loader)):
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        outputs = vit_model(images)
        loss = criterion(outputs.logits, labels)
        # Backward pass and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.logits, 1)
        train_correct += (predicted == labels).sum().item()
    # Training accuracy
    train_accuracy = 100.0 * train_correct / len(train_dataset)

    # Evaluate on the test set
    vit_model.eval()
    test_loss = 0.0
    test_correct = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, total=len(test_loader)):
            images, labels = images.to(device), labels.to(device)
            outputs = vit_model(images)
            test_loss += criterion(outputs.logits, labels).item()
            _, predicted = torch.max(outputs.logits, 1)
            test_correct += (predicted == labels).sum().item()
    # Test accuracy
    test_accuracy = 100.0 * test_correct / len(test_dataset)

    # Training loss, training accuracy and test accuracy for this epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, '
          f'Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')

torch.save(vit_model.state_dict(), 'vit_model_parameters.pth')

# Print the ViT parameter shapes
for name, param in vit_model.named_parameters():
    print(name, param.data.shape)

# Save the parameters in NumPy format for the NumPy inference code
model_params = {name: param.data.cpu().numpy() for name, param in vit_model.named_parameters()}
np.savez('vit_model_params.npz', **model_params)
```
Epoch [1/1], Train Loss: 97.7498, Train Accuracy: 96.21%, Test Accuracy: 96.86%
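The NumPy inference code looks weights up by their exact PyTorch parameter names (for example `vit.encoder.layer.0.attention.attention.query.weight`), so it relies on `np.savez` keeping those dotted names as keys. A small round trip with two hypothetical stand-in tensors (not the real checkpoint) shows that the keys and the float32 dtype survive. Opening the archive in a `with` block closes the underlying zip file afterwards.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-ins for {name: param.data.cpu().numpy() ...} from the training script
params = {
    "vit.layernorm.weight": np.ones(768, dtype=np.float32),
    "classifier.weight": np.zeros((10, 768), dtype=np.float32),
}

path = os.path.join(tempfile.mkdtemp(), "vit_model_params.npz")
np.savez(path, **params)  # dotted names become archive member names

with np.load(path) as model_data:  # NpzFile loads lazily; the context manager closes the zip
    loaded = {name: model_data[name] for name in model_data.files}

print(sorted(loaded), loaded["classifier.weight"].shape, loaded["classifier.weight"].dtype)
```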