hrnet
train.py的第221行负责构建train_loader。make_dataloader中的关键代码是先执行dataset = build_dataset(cfg, is_train),然后直接用torch.utils.data.DataLoader创建data_loader。
而build_dataset又分为三部分:build_transforms(数据变换)、heatmap的生成和offset的生成。
# Build the training data loader; the distributed flag is forwarded from
# `args.distributed`.
train_loader = make_dataloader(cfg, is_train=True, distributed=args.distributed)
下面是make_dataloader内部调用的函数build_dataset,它定义在dataset/build.py里:
def build_dataset(cfg, is_train):
    """Build the training dataset described by ``cfg``.

    Creates the transform pipeline, a heatmap generator and an offset
    generator from the ``cfg.DATASET`` settings, then instantiates the
    dataset class named by ``cfg.DATASET.DATASET`` with them.

    Args:
        cfg: Configuration object; this function reads
            ``cfg.DATASET.OUTPUT_SIZE``, ``NUM_JOINTS``, ``OFFSET_RADIUS``,
            ``DATASET`` and ``TRAIN``.
        is_train: Must be ``True``; this builder is for training only.

    Returns:
        The constructed dataset instance.

    Raises:
        AssertionError: If ``is_train`` is not ``True``.
    """
    assert is_train is True, 'Please only use build_dataset for training.'

    transforms = build_transforms(cfg, is_train)

    heatmap_generator = HeatmapGenerator(
        cfg.DATASET.OUTPUT_SIZE, cfg.DATASET.NUM_JOINTS
    )
    # OUTPUT_SIZE is passed twice on purpose, exactly as in the original call.
    # NOTE(review): presumably these are the output height and width - confirm
    # against OffsetGenerator's signature.
    offset_generator = OffsetGenerator(
        cfg.DATASET.OUTPUT_SIZE, cfg.DATASET.OUTPUT_SIZE,
        cfg.DATASET.NUM_JOINTS, cfg.DATASET.OFFSET_RADIUS
    )

    # NOTE(security): eval() resolves the dataset class from a config string.
    # This is only safe while the configuration comes from a trusted source;
    # consider an explicit name -> class registry if configs can be
    # user-supplied.
    dataset = eval(cfg.DATASET.DATASET)(
        cfg,
        cfg.DATASET.TRAIN,
        heatmap_generator,
        offset_generator,
        transforms
    )

    return dataset
先转换再读取数据
输出的是18
class RandomAffineTransform(object):
    """Holds random-affine augmentation settings and affine helpers.

    Only the constructor and two private helpers are visible in this
    excerpt: one builds a 3x3 affine matrix, the other applies an affine
    matrix to 2-D joint coordinates.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 max_rotation,
                 min_scale,
                 max_scale,
                 scale_type,
                 max_translate):
        """Store the augmentation parameters.

        ``output_size`` is normalised to a list: a non-list value is wrapped
        in a one-element list, a list is stored as given.
        """
        self.input_size = input_size
        if isinstance(output_size, list):
            self.output_size = output_size
        else:
            self.output_size = [output_size]

        self.max_rotation = max_rotation
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.scale_type = scale_type
        self.max_translate = max_translate

    def _get_affine_matrix(self, center, scale, res, rot=0):
        """Build a 3x3 affine matrix mapping source pixels into ``res``.

        Args:
            center: ``(x, y)`` point in the source that maps to the middle
                of the output.
            scale: Source extent divided by 200 (the side length used is
                ``200 * scale``).
            res: Output resolution; ``res[1]`` scales x and ``res[0]``
                scales y. NOTE(review): presumably ``(height, width)`` -
                confirm with the caller.
            rot: Rotation in degrees about the output centre; ``0`` skips
                the rotation step.

        Returns:
            ``(t, scale)`` where ``t`` is the 3x3 matrix and ``scale`` is
            the product of its x and y scale factors taken before any
            rotation is applied.
        """
        # Generate transformation matrix
        h = 200 * scale
        t = np.zeros((3, 3))
        t[0, 0] = float(res[1]) / h
        t[1, 1] = float(res[0]) / h
        t[0, 2] = res[1] * (-float(center[0]) / h + .5)
        t[1, 2] = res[0] * (-float(center[1]) / h + .5)
        t[2, 2] = 1
        # The incoming `scale` is deliberately overwritten with the product
        # of the two diagonal factors; this is the value returned.
        scale = t[0, 0] * t[1, 1]

        if rot != 0:
            rot = -rot  # To match direction of rotation from cropping
            rot_mat = np.zeros((3, 3))
            rot_rad = rot * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
            rot_mat[2, 2] = 1
            # Need to rotate around center: shift the output centre to the
            # origin, rotate, then shift back.
            t_mat = np.eye(3)
            t_mat[0, 2] = -res[1]/2
            t_mat[1, 2] = -res[0]/2
            t_inv = t_mat.copy()
            t_inv[:2, 2] *= -1
            t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))

        return t, scale

    def _affine_joints(self, joints, mat):
        """Apply the affine matrix ``mat`` to every 2-D point in ``joints``.

        ``joints`` may have any shape whose last axis holds ``(x, y)``
        pairs. It is flattened to ``(-1, 2)``, extended with a homogeneous
        ``1`` column, multiplied by ``mat.T`` and reshaped back to the
        original shape. ``mat`` therefore needs to be a 2x3 matrix for the
        final reshape to succeed.
        """
        joints = np.array(joints)
        shape = joints.shape
        # NOTE (from the original write-up): the author printed `shape` here
        # and reports seeing 18 - unverified.
        joints = joints.reshape(-1, 2)
        # `joints[:, 0:1]*0+1` builds the homogeneous column with the same
        # dtype as the input (and keeps NaNs as NaN), so it is kept as-is.
        homogeneous = np.concatenate((joints, joints[:, 0:1]*0+1), axis=1)
        return np.dot(homogeneous, mat.T).reshape(shape)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· C#/.NET/.NET Core技术前沿周刊 | 第 29 期(2025年3.1-3.9)
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
2022-01-13 查看网络参数