1原有的渲染函数修改
1 添加json路径
# Echo the resolved dataset paths and the requested iteration before rendering.
print("model.extract(args).source_path", model.extract(args).source_path)
print("model.extract(args).model_path", model.extract(args).model_path)
print("args.iteration", args.iteration)
# Example values from the author's run:
#   model.extract(args).source_path -> /home/dongdong/2project/0data/house3/100/colmap7
#   model.extract(args).model_path  -> /home/dongdong/2project/0data/house3/100/gs_out7

# Initialize system state (RNG)
safe_state(args.quiet)

# Name the extracted parameter groups, then forward the two new CLI options
# (camera-track JSON path and output directory) to the renderer.
dataset_params = model.extract(args)
pipeline_params = pipeline.extract(args)
render_sets(
    dataset_params,
    args.iteration,
    pipeline_params,
    args.skip_train,
    args.skip_test,
    args.track_path,
    args.save_path,
)
2 添加加载函数
def loadRenderCameras(self, track_path, save_path, args):
    """Load the render-only cameras described by a camera-track JSON file.

    The camera infos come from ``loadRenderCamFromJson(track_path, save_path)``
    and are converted with ``cameraList_from_camInfos_render``. The resulting
    list is stored in ``self.render_cameras`` under the single scale key ``1.0``.

    Args:
        track_path: Path of the camera-track JSON file.
        save_path: Directory used to build each camera's output image name.
        args: Forwarded unchanged to ``cameraList_from_camInfos_render``.
    """
    track_cam_infos = loadRenderCamFromJson(track_path, save_path)
    # Reset the mapping first, exactly like the original, then fill scale 1.0.
    cameras_by_scale = self.render_cameras = {}
    cameras_by_scale[1.0] = cameraList_from_camInfos_render(track_cam_infos, 1, args)
json数据格式
{ "camera_angle_x": 1.1168334919260177, 视场角度x "camera_angle_y": 0.6916973331686856, 视场角度y "fl_x": 1536.6241314375784, 像素焦距fx (像素个数) = f_x焦距(米)/单个像素物理长度(米) "fl_y": 1498.621663983134, 像素焦距fy (像素个数) = f_y焦距(米)/单个像素物理长度(米) "k1": 0.0, 畸变系数 "k2": 0.0, "k3": 0.0, "k4": 0, "p1": 0.0, "p2": 0.0, "is_fisheye": false, "cx": 960.0, 中心 "cy": 540.0, 中心 "w": 1920.0, "h": 1080.0, "aabb_scale": 1, "frames": [ { "file_path": "./images/0001.jpg", "sharpness": 1, "transform_matrix": [ [ 0.9999527616170583, 0.00279275512886567, 0.009309943781165604, -4.738427091683277 ], [ -0.002729430736777636, 0.9999731031529604, -0.006807582426336303, 2.9547333944796534 ], [ -0.009328705283768106, 0.006781850000436799, 0.9999334886722716, 2.240625948778779 ], [ 0.0, 0.0, 0.0, 1.0 ] ], "cx": 960.0, "cy": 540.0, "w": 1920.0, "h": 1080.0, "k1": 0.0, "k2": 0.0, "k3": 0.0, "k4": 0, "p1": 0.0, "p2": 0.0, "camera_angle_x": 1.1168334919260177, "camera_angle_y": 0.6916973331686856, "fl_x": 1536.6241314375784, "fl_y": 1498.621663983134 }, { "file_path": "./images/0002.jpg", "sharpness": 1, "transform_matrix": [ [ 0.9999527616170583, 0.00279275512886567, 0.009309943781165604, -4.7374427825107475 ], [ -0.002729430736777636, 0.9999731031529604, -0.006807582426336303, 2.954988109418036 ], [ -0.009328705283768106, 0.006781850000436799, 0.9999334886722716, 2.240385724373548 ], [ 0.0, 0.0, 0.0, 1.0 ] ], "cx": 960.0, "cy": 540.0, "w": 1920.0, "h": 1080.0, "k1": 0.0, "k2": 0.0, "k3": 0.0, "k4": 0, "p1": 0.0, "p2": 0.0, "camera_angle_x": 1.1168334919260177, "camera_angle_y": 0.6916973331686856, "fl_x": 1536.6241314375784, "fl_y": 1498.621663983134 },
2-1加载函数
def loadRenderCamFromJson(track_path, save_path,
                          default_fx=1536.6241314375784,
                          default_fy=1498.621663983134):
    """Build one ``CameraInfo`` per frame of a camera-track JSON file.

    Each entry of the JSON ``"frames"`` list must provide ``file_path``,
    ``transform_matrix`` (camera-to-world), ``w``, ``h``, ``cx`` and ``cy``.
    Focal lengths are read from the frame's ``fl_x`` / ``fl_y``; if a frame
    lacks them the file-level values are used, and ``default_fx`` /
    ``default_fy`` are the last resort.

    Args:
        track_path: Path of the camera-track JSON file.
        save_path: Directory joined with each frame's file name to form the
            output image path stored in ``CameraInfo.image_name``.
        default_fx: Fallback focal length in pixels along x.
        default_fy: Fallback focal length in pixels along y.

    Returns:
        A list of ``CameraInfo``, one per frame, in file order.

    Raises:
        OSError: If ``track_path`` cannot be opened.
        KeyError: If a required key is missing from the file or a frame.
    """
    with open(track_path, encoding="utf-8") as json_file:
        contents = json.load(json_file)

    cam_infos = []
    for idx, frame in enumerate(contents["frames"]):
        # The rendered image is written under save_path with the frame's own
        # file name; ".jpg" is appended when the name has no extension.
        outname = os.path.join(save_path, os.path.basename(frame["file_path"]))
        if not os.path.splitext(outname)[1]:
            outname += ".jpg"

        # NeRF 'transform_matrix' is a camera-to-world transform.
        # NOTE(review): the OpenGL/Blender -> COLMAP axis flip
        # (c2w[:3, 1:3] *= -1) was disabled in the original, so the track is
        # presumably already in COLMAP axes -- confirm against the exporter.
        c2w = np.array(frame["transform_matrix"])

        # Get the world-to-camera transform and set R, T.
        w2c = np.linalg.inv(c2w)
        R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]

        # JSON stores the size as floats (e.g. 1920.0); array shapes need ints.
        width = int(frame["w"])
        height = int(frame["h"])

        # Prefer per-frame intrinsics over the previously hard-coded constants.
        fx = frame.get("fl_x", contents.get("fl_x", default_fx))
        fy = frame.get("fl_y", contents.get("fl_y", default_fy))
        FovY = focal2fov(fy, height)
        FovX = focal2fov(fx, width)

        # No source image exists for a render-only camera, so a black
        # placeholder sized to the frame is stored instead.
        image = np.zeros((height, width, 3), dtype=np.uint8)

        cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,
                                    image=image, image_path=" ",
                                    image_name=outname,
                                    width=width, height=height,
                                    cx=frame["cx"], cy=frame["cy"]))

    return cam_infos
2-2转化函数 原先自带的
def cameraList_from_camInfos_render(cam_infos, resolution_scale, args):
    """Convert camera infos into render cameras, preserving their order.

    Each element of ``cam_infos`` is passed to ``loadCamRender`` together with
    its position in the sequence, ``args`` and ``resolution_scale``.

    Returns:
        A list with one ``loadCamRender`` result per input element.
    """
    return [
        loadCamRender(args, position, info, resolution_scale)
        for position, info in enumerate(cam_infos)
    ]
def loadCamRender(args, id, cam_info, resolution_scale):
    """Create a ``Camera`` from a render-only ``cam_info``.

    ``args`` and ``resolution_scale`` are accepted but not used here.

    Args:
        args: Unused.
        id: Value stored as the camera's ``uid``.
        cam_info: Source of pose (``R``, ``T``), field of view, principal
            point (``cx``, ``cy``), image data and output image name.
        resolution_scale: Unused.

    Returns:
        The constructed ``Camera``.
    """
    # TODO: how should Camera.image be used when only rendering?
    # NOTE(review): cam_info.image is handed to PILtoTorch with resolution 1;
    # how PILtoTorch treats that input is not visible here -- confirm.
    image_tensor = PILtoTorch(cam_info.image, 1)
    rgb_channels = image_tensor[:3, ...]  # keep at most the first three channels

    camera_kwargs = dict(
        colmap_id=cam_info.uid,
        R=cam_info.R,
        T=cam_info.T,
        FoVx=cam_info.FovX,
        FoVy=cam_info.FovY,
        cx=cam_info.cx,
        cy=cam_info.cy,
        image=rgb_channels,
        gt_alpha_mask=None,
        image_name=cam_info.image_name,
        uid=id,
    )
    return Camera(**camera_kwargs)
3 数据给渲染函数
调用
def getRenderCameras(self, scale=1.0):
    """Return the render cameras stored for ``scale``.

    Raises:
        KeyError: If no cameras are stored under ``scale``.
    """
    cameras_by_scale = self.render_cameras
    return cameras_by_scale[scale]
完整代码
render.py
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#

import torch
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel


def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
    """Render every view and save each result to ``view.image_name``.

    ``model_path``, ``name`` and ``iteration`` are kept so existing callers
    keep working; they are not used because each view already carries its
    own output path.

    Args:
        model_path: Unused.
        name: Unused.
        iteration: Unused.
        views: Iterable of cameras to render.
        gaussians: Gaussian model passed to ``render``.
        pipeline: Pipeline parameters passed to ``render``.
        background: Background colour tensor passed to ``render``.
    """
    for view in tqdm(views, desc="Rendering progress"):
        rendering = render(view, gaussians, pipeline, background)["render"]
        torchvision.utils.save_image(rendering, view.image_name)


def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams,
                skip_train: bool, skip_test: bool, track_path: str, save_path: str):
    """Load a trained scene and render the cameras of a JSON camera track.

    ``skip_train`` and ``skip_test`` are accepted for command-line
    compatibility but ignored: only the track cameras are rendered.

    Args:
        dataset: Extracted model parameters; ``sh_degree``,
            ``white_background`` and ``model_path`` are read here.
        iteration: Iteration passed to ``Scene`` as ``load_iteration``.
        pipeline: Extracted pipeline parameters passed to ``render``.
        skip_train: Ignored.
        skip_test: Ignored.
        track_path: Path of the camera-track JSON file.
        save_path: Output directory for the rendered images; an empty string
            leaves directory creation out, so images land relative to the
            current working directory.
    """
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        # Per the author's note, `dataset` carries (1) the COLMAP sparse
        # reconstruction path and (2) the trained 3D Gaussian scene path.
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        # Pass the extracted parameters, not the ModelParams class itself.
        scene.loadRenderCameras(track_path, save_path, dataset)

        # os.makedirs("") raises, so only create a directory that was named.
        if save_path:
            os.makedirs(save_path, exist_ok=True)

        render_set(dataset.model_path, " ", scene.loaded_iter,
                   scene.getRenderCameras(), gaussians, pipeline, background)


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--track_path", default="", type=str)
    parser.add_argument("--save_path", default="", type=str)
    parser.add_argument("--width", default=1920, type=int)
    parser.add_argument("--height", default=1080, type=int)
    args = get_combined_args(parser)

    # Fail before the scene is loaded instead of when the JSON is opened.
    if not args.track_path:
        parser.error("--track_path is required: path to the camera-track JSON file")

    print("1训练好的模型路径 -m " + args.model_path)
    # Author's note: a config file is found in the model directory, e.g.
    #   /home/dongdong/2project/0data/house3/100/gs_out7/cfg_args
    # It records (1) the COLMAP reconstruction path (colmap/images,
    # colmap/sparse/0/...) and (2) the training parameters.

    dataset = model.extract(args)
    print("model.extract(args).source_path", dataset.source_path)
    print("model.extract(args).model_path", dataset.model_path)
    print("args.iteration", args.iteration)
    # Example values from the author's run:
    #   source_path -> /home/dongdong/2project/0data/house3/100/colmap7
    #   model_path  -> /home/dongdong/2project/0data/house3/100/gs_out7

    # Initialize system state (RNG)
    safe_state(args.quiet)

    render_sets(dataset, args.iteration, pipeline.extract(args),
                args.skip_train, args.skip_test, args.track_path, args.save_path)