Frequently used code templates (for personal reference)

Importing libraries (for deep learning)

import os
import time
from datetime import timedelta
import json
import yaml
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

 

Plotting (general-purpose; a signal example)

import os
import time
import numpy as np
import matplotlib.pyplot as plt

plt.figure(1)
plt.subplot(2, 1, 1)
plt.plot(t, sig)  # plot the signal waveform
plt.subplot(2, 1, 2)
plt.imshow(np.abs(st_Res), origin='lower', extent=(0, len(t), 0, len(t) // 2))  # plot the time-frequency map; take the absolute value of the complex result
os.makedirs("./imgs", exist_ok=True)  # make sure the output directory exists before saving
plt.savefig("./imgs/stockwell-asdvalve09-" + time.strftime("%Y%m%d%H%M", time.localtime()))
plt.show()
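The snippet above assumes a time axis `t`, a signal `sig`, and a complex time-frequency matrix `st_Res` (the filename suggests a Stockwell transform of a valve signal) already exist. Below is a minimal sketch that produces compatible placeholders: it synthesises a two-tone test signal with NumPy and uses `scipy.signal.stft` only as a stand-in for the Stockwell transform, so the output shape (and therefore the `extent` argument above) would need adjusting for a real Stockwell result.

import numpy as np
from scipy import signal

fs = 1000                                  # assumed sampling rate in Hz
t = np.arange(0, 1.0, 1 / fs)              # 1-second time axis
sig = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 120 * t)  # synthetic two-tone signal

# stand-in complex time-frequency matrix (replace with your Stockwell transform)
_, _, st_Res = signal.stft(sig, fs=fs, nperseg=256)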

Deep-learning training template (PyTorch framework):

import os
import shutil
import time
from datetime import timedelta
import json
import yaml
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

class BaseTrainer(object):
    def __init__(self, configs=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.configs = configs   # parsed config object (e.g. loaded from a YAML file)
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.loss = None
        self.amp_scaler = None   # set in __setup_model when mixed precision is enabled
    def __setup_dataloader(self, is_train):
        # build the training data
        if is_train:
            self.train_dataset = None  # replace with your Dataset
            self.train_loader = DataLoader(self.train_dataset,
                                           batch_size=64,
                                           shuffle=True, num_workers=0)
        # build the validation data
        self.valid_dataset = None  # replace with your Dataset
        self.valid_loader = DataLoader(self.valid_dataset, batch_size=64,
                                       shuffle=False, num_workers=0)

    def __setup_model(self, is_train):
        self.model = None  # replace with your nn.Module
        self.model.to(self.device)
        # loss, optimizer & scheduler
        self.loss = nn.CrossEntropyLoss().to(self.device)  # placeholder criterion; swap in the task-specific loss
        if is_train:
            # automatic mixed precision
            if self.configs.train_conf.enable_amp:
                self.amp_scaler = torch.cuda.amp.GradScaler(init_scale=1024)
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=float(self.configs.optimizer_conf.learning_rate))
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.1)  # placeholder scheduler

    def __train_epoch(self, epoch_id):
        pass

    def train(self):
        pass

    def evaluate(self):
        pass

    def __test(self, resume_model=None):
        pass
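The `__train_epoch`, `train`, `evaluate`, and `__test` bodies are intentionally left as `pass`. As a reference, here is a minimal sketch of what `__train_epoch` could look like with the fields set up above; it assumes `self.train_loader` yields `(inputs, labels)` batches and steps the scheduler per batch, matching the fast-forward logic in `__load_checkpoint` below. Treat it as a starting point, not a fixed implementation.

    def __train_epoch(self, epoch_id):
        self.model.train()
        for batch_id, (inputs, labels) in enumerate(tqdm(self.train_loader, desc=f'epoch {epoch_id}')):
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            # forward pass and loss, with optional automatic mixed precision
            with torch.cuda.amp.autocast(enabled=self.amp_scaler is not None):
                outputs = self.model(inputs)
                loss = self.loss(outputs, labels)
            self.optimizer.zero_grad()
            if self.amp_scaler is not None:
                self.amp_scaler.scale(loss).backward()
                self.amp_scaler.step(self.optimizer)
                self.amp_scaler.update()
            else:
                loss.backward()
                self.optimizer.step()
            self.scheduler.step()  # stepped per batch, matching the fast-forward in __load_checkpoint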

Saving and loading checkpoints (PyTorch framework):

    def __load_checkpoint(self, save_model_path, resume_model):
        last_epoch = -1
        best_auc, best_pauc = 0, 0
        last_model_dir = os.path.join(save_model_path,
                                      f'{self.configs.use_model}',
                                      'last_model')
        if resume_model is not None or (os.path.exists(os.path.join(last_model_dir, 'model.pth'))
                                        and os.path.exists(os.path.join(last_model_dir, 'optimizer.pth'))):
            # automatically pick up the most recently saved model
            if resume_model is None:
                resume_model = last_model_dir
            assert os.path.exists(os.path.join(resume_model, 'model.pth')), "Model weight file does not exist!"
            assert os.path.exists(os.path.join(resume_model, 'optimizer.pth')), "Optimizer state file does not exist!"
            state_dict = torch.load(os.path.join(resume_model, 'model.pth'))
            if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)
            self.optimizer.load_state_dict(torch.load(os.path.join(resume_model, 'optimizer.pth')))
            # automatic mixed precision state
            if self.amp_scaler is not None and os.path.exists(os.path.join(resume_model, 'scaler.pth')):
                self.amp_scaler.load_state_dict(torch.load(os.path.join(resume_model, 'scaler.pth')))
            with open(os.path.join(resume_model, 'model.state'), 'r', encoding='utf-8') as f:
                json_data = json.load(f)
                last_epoch = json_data['last_epoch'] - 1
                best_auc = json_data['best_auc']
                best_pauc = json_data['best_pauc']
            self.logger.info('Successfully restored model and optimizer state from: {}'.format(resume_model))
            # dummy optimizer.step() to avoid the "scheduler.step() before optimizer.step()" warning,
            # then fast-forward the per-batch scheduler to where training stopped
            self.optimizer.step()
            [self.scheduler.step() for _ in range(last_epoch * len(self.train_loader))]
        return last_epoch, best_auc, best_pauc

    def __save_checkpoint(self, save_model_path, epoch_id, best_auc=0., best_pauc=0., best_model=False):
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        if best_model:
            model_path = os.path.join(save_model_path,
                                      f'{self.configs.use_model}',
                                      'best_model')
        else:
            model_path = os.path.join(save_model_path,
                                      f'{self.configs.use_model}',
                                      'epoch_{}'.format(epoch_id))
        os.makedirs(model_path, exist_ok=True)
        torch.save(self.optimizer.state_dict(), os.path.join(model_path, 'optimizer.pth'))
        torch.save(state_dict, os.path.join(model_path, 'model.pth'))
        # automatic mixed precision state
        if self.amp_scaler is not None:
            torch.save(self.amp_scaler.state_dict(), os.path.join(model_path, 'scaler.pth'))
        with open(os.path.join(model_path, 'model.state'), 'w', encoding='utf-8') as f:
            data = {"last_epoch": epoch_id, "best_auc": best_auc, "best_pauc": best_pauc,
                    "version": __version__}  # __version__ is expected to be defined at module/package level
            f.write(json.dumps(data))
        if not best_model:
            last_model_path = os.path.join(save_model_path,
                                           f'{self.configs.use_model}',
                                           'last_model')
            shutil.rmtree(last_model_path, ignore_errors=True)
            shutil.copytree(model_path, last_model_path)
            # delete stale epoch checkpoints
            old_model_path = os.path.join(save_model_path,
                                          f'{self.configs.use_model}',
                                          'epoch_{}'.format(epoch_id - 3))
            if os.path.exists(old_model_path):
                shutil.rmtree(old_model_path)
        self.logger.info('Model saved to: {}'.format(model_path))
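One possible way to wire these two helpers (together with `__load_pretrained` from the next section) into the `train()` stub of the template is sketched below. The config field `train_conf.max_epoch` and the assumption that `evaluate()` returns `(auc, pauc)` are illustrative, not part of the original template.

    def train(self, save_model_path='models/', resume_model=None, pretrained_model=None):
        self.__setup_dataloader(is_train=True)
        self.__setup_model(is_train=True)
        self.__load_pretrained(pretrained_model)
        last_epoch, best_auc, best_pauc = self.__load_checkpoint(save_model_path, resume_model)
        for epoch_id in range(last_epoch + 1, self.configs.train_conf.max_epoch):
            self.__train_epoch(epoch_id)
            auc, pauc = self.evaluate()
            if auc >= best_auc:
                best_auc, best_pauc = auc, pauc
                self.__save_checkpoint(save_model_path, epoch_id, best_auc, best_pauc, best_model=True)
            self.__save_checkpoint(save_model_path, epoch_id, best_auc, best_pauc)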

Loading a pretrained model (PyTorch):

    def __load_pretrained(self, pretrained_model):
        # load the pretrained model
        if pretrained_model is not None:
            if os.path.isdir(pretrained_model):
                pretrained_model = os.path.join(pretrained_model, 'model.pth')
            assert os.path.exists(pretrained_model), f"Pretrained model {pretrained_model} does not exist!"
            if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
                model_dict = self.model.module.state_dict()
            else:
                model_dict = self.model.state_dict()
            model_state_dict = torch.load(pretrained_model)
            # drop pretrained weights whose shapes do not match, and warn about missing ones
            for name, weight in model_dict.items():
                if name in model_state_dict.keys():
                    if list(weight.shape) != list(model_state_dict[name].shape):
                        self.logger.warning('{} not used, shape {} unmatched with {} in model.'
                                            .format(name, list(model_state_dict[name].shape), list(weight.shape)))
                        model_state_dict.pop(name, None)
                else:
                    self.logger.warning('Lack weight: {}'.format(name))
            if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
                self.model.module.load_state_dict(model_state_dict, strict=False)
            else:
                self.model.load_state_dict(model_state_dict, strict=False)
            self.logger.info('Successfully loaded pretrained model: {}'.format(pretrained_model))
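Design note: the loop above removes any pretrained weight whose shape does not match the current model and then calls load_state_dict with strict=False, so a pretrained backbone can still be reused even when some layers (for example the output head) differ from the checkpoint.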

 

posted @ 2023-12-29 15:07  倦鸟已归时