Common code templates (personal reference)
Library imports (for deep learning)
import os
import time
from datetime import timedelta
import json
import yaml
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
Plotting (general purpose; signal processing example)
import time
import numpy as np
import matplotlib.pyplot as plt

plt.figure(1)
plt.subplot(2, 1, 1)
plt.plot(t, sig)  # waveform of the signal
plt.subplot(2, 1, 2)
# time-frequency plot: take the absolute value of the complex-valued result
plt.imshow(np.abs(st_Res), origin='lower', extent=(0, len(t), 0, len(t) // 2))
plt.savefig("./imgs/stockwell-asdvalve09-" + time.strftime("%Y%m%d%H%M", time.localtime()))
plt.show()
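The snippet above assumes `t`, `sig` and `st_Res` (a complex time-frequency matrix, here a Stockwell-transform result) already exist. A minimal sketch that fabricates comparable inputs, using scipy.signal.stft purely as a stand-in for the Stockwell transform:

import numpy as np
from scipy import signal

# Hypothetical inputs for the plotting snippet above; the STFT is only a stand-in,
# so the shape of st_Res will differ from a true Stockwell-transform result.
fs = 1000                                    # assumed sampling rate in Hz
t = np.arange(0, 1.0, 1.0 / fs)              # 1 second of samples
sig = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 120 * t)

f, tau, Zxx = signal.stft(sig, fs=fs, nperseg=128)
st_Res = Zxx                                 # complex matrix; plotted via np.abs() above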
Deep learning training template (PyTorch):
import os
import time
from datetime import timedelta
import json
import yaml
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


class BaseTrainer(object):
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.optimizer = None
        self.loss = None
        self.amp_scaler = None   # set in __setup_model when AMP is enabled
        self.configs = None      # fill with the parsed yaml/json config

    def __setup_dataloader(self, is_train):
        if is_train:
            self.train_dataset = None  # replace with your training Dataset
            self.train_loader = DataLoader(self.train_dataset, batch_size=64, shuffle=True, num_workers=0)
        # validation data
        self.valid_dataset = None      # replace with your validation Dataset
        self.valid_loader = DataLoader(self.valid_dataset, batch_size=64, shuffle=True, num_workers=0)

    def __setup_model(self, is_train):
        self.model = None              # replace with your nn.Module
        self.model.to(self.device)
        # loss, optimizer & scheduler
        self.loss = nn.CrossEntropyLoss().to(self.device)  # placeholder loss, swap for your own
        if is_train:
            if self.configs.train_conf.enable_amp:
                self.amp_scaler = torch.cuda.amp.GradScaler(init_scale=1024)
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=float(self.configs.optimizer_conf.learning_rate))

    def __train_epoch(self, epoch_id):
        pass

    def train(self):
        pass

    def evaluate(self):
        pass

    def __test(self, resume_model=None):
        pass
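`__train_epoch` is left as `pass` above. A minimal sketch of what its body might look like as a method of the class, assuming the loaders, loss, optimizer and optional AMP scaler set up in the template (unpacking batches as `(features, labels)` is an assumption about the dataset):

def __train_epoch(self, epoch_id):
    # Minimal sketch of one training epoch; relies on attributes from the template above.
    self.model.train()
    use_amp = self.amp_scaler is not None
    for features, labels in tqdm(self.train_loader, desc=f'epoch {epoch_id}'):
        features, labels = features.to(self.device), labels.to(self.device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = self.model(features)
            loss_value = self.loss(output, labels)
        self.optimizer.zero_grad()
        if use_amp:
            # scale the loss, step through the scaler, then update its scale factor
            self.amp_scaler.scale(loss_value).backward()
            self.amp_scaler.step(self.optimizer)
            self.amp_scaler.update()
        else:
            loss_value.backward()
            self.optimizer.step()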
Saving and loading checkpoints (PyTorch):
# These methods belong to the trainer class above; note that `shutil` must also be
# imported at module level, and `__version__` is the package version string.
def __load_checkpoint(self, save_model_path, resume_model):
    last_epoch = -1
    best_auc, best_pauc = 0, 0
    last_model_dir = os.path.join(save_model_path, f'{self.configs.use_model}', 'last_model')
    if resume_model is not None or (os.path.exists(os.path.join(last_model_dir, 'model.pth'))
                                    and os.path.exists(os.path.join(last_model_dir, 'optimizer.pth'))):
        # pick up the most recently saved model automatically
        if resume_model is None:
            resume_model = last_model_dir
        assert os.path.exists(os.path.join(resume_model, 'model.pth')), "Model parameter file does not exist!"
        assert os.path.exists(os.path.join(resume_model, 'optimizer.pth')), "Optimizer parameter file does not exist!"
        state_dict = torch.load(os.path.join(resume_model, 'model.pth'))
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model.module.load_state_dict(state_dict)
        else:
            self.model.load_state_dict(state_dict)
        self.optimizer.load_state_dict(torch.load(os.path.join(resume_model, 'optimizer.pth')))
        # automatic mixed precision state
        if self.amp_scaler is not None and os.path.exists(os.path.join(resume_model, 'scaler.pth')):
            self.amp_scaler.load_state_dict(torch.load(os.path.join(resume_model, 'scaler.pth')))
        with open(os.path.join(resume_model, 'model.state'), 'r', encoding='utf-8') as f:
            json_data = json.load(f)
            last_epoch = json_data['last_epoch'] - 1
            best_auc = json_data['best_auc']
            best_pauc = json_data['best_pauc']
        self.logger.info('Restored model and optimizer state from: {}'.format(resume_model))
        # dummy optimizer step, then fast-forward the scheduler to the resumed position
        self.optimizer.step()
        [self.scheduler.step() for _ in range(last_epoch * len(self.train_loader))]
    return last_epoch, best_auc, best_pauc


def __save_checkpoint(self, save_model_path, epoch_id, best_auc=0., best_pauc=0., best_model=False):
    if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
        state_dict = self.model.module.state_dict()
    else:
        state_dict = self.model.state_dict()
    if best_model:
        model_path = os.path.join(save_model_path, f'{self.configs.use_model}', 'best_model')
    else:
        model_path = os.path.join(save_model_path, f'{self.configs.use_model}', 'epoch_{}'.format(epoch_id))
    os.makedirs(model_path, exist_ok=True)
    torch.save(self.optimizer.state_dict(), os.path.join(model_path, 'optimizer.pth'))
    torch.save(state_dict, os.path.join(model_path, 'model.pth'))
    # automatic mixed precision state
    if self.amp_scaler is not None:
        torch.save(self.amp_scaler.state_dict(), os.path.join(model_path, 'scaler.pth'))
    with open(os.path.join(model_path, 'model.state'), 'w', encoding='utf-8') as f:
        data = {"last_epoch": epoch_id, "best_auc": best_auc, "best_pauc": best_pauc, "version": __version__}
        f.write(json.dumps(data))
    if not best_model:
        last_model_path = os.path.join(save_model_path, f'{self.configs.use_model}', 'last_model')
        shutil.rmtree(last_model_path, ignore_errors=True)
        shutil.copytree(model_path, last_model_path)
        # remove stale epoch checkpoints
        old_model_path = os.path.join(save_model_path, f'{self.configs.use_model}', 'epoch_{}'.format(epoch_id - 3))
        if os.path.exists(old_model_path):
            shutil.rmtree(old_model_path)
    self.logger.info('Saved model to: {}'.format(model_path))
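A rough sketch of how these two helpers are typically wired into `train()` as a method of the trainer class; `max_epoch` under `self.configs.train_conf` and the `(auc, pauc)` return of `evaluate()` are assumptions for illustration, not part of the template:

def train(self, save_model_path='models/', resume_model=None):
    # Sketch of a training loop built around the checkpoint helpers above.
    self.__setup_dataloader(is_train=True)
    self.__setup_model(is_train=True)
    last_epoch, best_auc, best_pauc = self.__load_checkpoint(save_model_path, resume_model)
    for epoch_id in range(last_epoch + 1, self.configs.train_conf.max_epoch):  # max_epoch is assumed
        self.__train_epoch(epoch_id)
        auc, pauc = self.evaluate()  # assumed to return the metrics tracked in the checkpoints
        if auc >= best_auc:
            best_auc, best_pauc = auc, pauc
            self.__save_checkpoint(save_model_path, epoch_id, best_auc, best_pauc, best_model=True)
        self.__save_checkpoint(save_model_path, epoch_id, best_auc, best_pauc)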
Loading a pretrained model (PyTorch):
def __load_pretrained(self, pretrained_model):
    # load pretrained weights if a path is given
    if pretrained_model is not None:
        if os.path.isdir(pretrained_model):
            pretrained_model = os.path.join(pretrained_model, 'model.pth')
        assert os.path.exists(pretrained_model), f"{pretrained_model} does not exist!"
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            model_dict = self.model.module.state_dict()
        else:
            model_dict = self.model.state_dict()
        model_state_dict = torch.load(pretrained_model)
        # drop pretrained parameters that are missing or whose shapes do not match
        for name, weight in model_dict.items():
            if name in model_state_dict.keys():
                if list(weight.shape) != list(model_state_dict[name].shape):
                    self.logger.warning('{} not used, shape {} unmatched with {} in model.'
                                        .format(name, list(model_state_dict[name].shape), list(weight.shape)))
                    model_state_dict.pop(name, None)
            else:
                self.logger.warning('Lack weight: {}'.format(name))
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model.module.load_state_dict(model_state_dict, strict=False)
        else:
            self.model.load_state_dict(model_state_dict, strict=False)
        self.logger.info('Successfully loaded pretrained model: {}'.format(pretrained_model))
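The same shape-filtered partial loading, written as a standalone helper for use outside the trainer class; `load_partial_state` and its arguments are illustrative names, not part of the template:

import torch
import torch.nn as nn


def load_partial_state(model: nn.Module, checkpoint_path: str) -> nn.Module:
    # Load only those pretrained weights whose names and shapes match the current model.
    pretrained = torch.load(checkpoint_path, map_location='cpu')
    model_dict = model.state_dict()
    filtered = {k: v for k, v in pretrained.items()
                if k in model_dict and v.shape == model_dict[k].shape}
    model.load_state_dict(filtered, strict=False)
    return model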