【CasMVSNet】代码学习记录
我们在学习/科研过程中,时常要参(bai)考(piao)别人的开源代码。很多深度学习的代码是基于PyTorch的,那我们就来看一下代码的组织格式吧。
正如一个人有两条腿走路,CV领域也有模型和数据两条腿。
├── dataset # 数据集相关文件夹
├── model # 模型相关文件夹
│ ├── sub_module.py # 网络的子模块
│ └── xxnet.py # 基于子模块构建的网络
├── train.py # 模型训练脚本
├── valid.py # 模型验证脚本文件
├── test.py
├── utils.py # 其它的一些工具脚本
├── checkpoints # 训练过程中产生的模型
└── log # 日志文件
一、数据集
数据加载方面主要涉及到Dataset
和DataLoader
这两个类
1.1 Dataset
torch.utils.data.Dataset
是一个抽象类,自定义的数据集类需要继承此类,并实现两个成员方法:__getitem__()
和__len__()
其中__getitem__()
根据输入的key返回对应的训练样本,而__len__()
返回数据集中训练样本的总数。
(举个例子:每一个样本包括ref_img, src_imgs和对应的相机参数以及ref_depth.)
def __getitem__(self, index):
img_path, label = self.data[index].img_path, self.data[index].label
img = Image.open(img_path)
return img, label
第二个方法__len__
返回数据集的长度
这里所讲的dataset, 按照官方文档上来说是Map-Style datasets
。即可以通过索引idx
来访问第idx
幅图片和对应的标签。
1.2 DataLoader
构造函数
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
参数解析
- dataset:上面自定义的dataset
- batch_size:一次加载几组数据
- shuffle:乱序数据加载,通常在train时设为True; test时置为False
- num_worker:用于数据加载的线程数。默认为0表示使用主进程加载,参考👉 https://pytorch-lightning.readthedocs.io/en/latest/performance.html#num-workers
- pin_memeory:页锁定内存,一般默认设置即可。
1.3 Transform
读取图片时使用PIL
来读取,可以进行crop, resize, flip等操作,最重要的是将读取的图片对象转为Tensor
https://pytorch.apachecn.org/docs/1.0/torchvision_transforms.html
二、模型
模型中网络的子模块和网络类需要继承nn.Module
实现前向传播forward()
函数
构造函数中需要super().__init__()
, 即调用父类nn.Module
的构造函数
三、训练
dataloader中数据feed至定义的xx-net中,进行forward。forward的产物计算loss, backward