python 从深度图计算法线

def get_world_points(depth, intrinsics, extrinsics):
    """Back-project every pixel of a depth map into world space.

    Args:
        depth: H*W array of per-pixel depth values.
        intrinsics: 3*3 camera matrix, or a 4*4 matrix whose upper-left
            3*3 block is used.
        extrinsics: 4*4, world to camera.

    Returns:
        points: N*3 (N = H*W) in world space, ordered row by row over the
        image.
    """
    # Only the 3*3 part of a 4*4 intrinsics matrix is needed.
    cam_matrix = intrinsics[:3, :3] if intrinsics.shape[0] == 4 else intrinsics

    rows, cols = depth.shape
    num_pixels = rows * cols

    # u indexes columns (image x), v indexes rows (image y).
    u, v = np.meshgrid(np.arange(0, cols), np.arange(0, rows))
    u_flat = u.reshape((1, num_pixels))
    v_flat = v.reshape((1, num_pixels))
    depth_flat = depth.reshape((1, num_pixels))
    ones_row = np.ones_like(u_flat)

    # Homogeneous pixel coordinates scaled by depth, then un-projected
    # through the inverse intrinsics into the camera frame.
    scaled_pixels = np.vstack((u_flat, v_flat, ones_row)) * depth_flat
    cam_points = np.matmul(np.linalg.inv(cam_matrix), scaled_pixels)

    # Camera frame -> world frame via the inverse of the world-to-camera pose.
    cam_points_homo = np.vstack((cam_points, ones_row))
    world_points_homo = np.matmul(np.linalg.inv(extrinsics), cam_points_homo)

    return world_points_homo[:3].transpose((1, 0))
def get_camera_origins(poses_homo):
    """Compute camera centers in world space from world-to-camera poses.

    For a pose ``[R | t]`` that maps world points into the camera frame, the
    camera center ``C`` satisfies ``R @ C + t = 0``, so ``C = -inv(R) @ t``.

    Args:
        poses_homo: array or nested sequence of N world-to-camera poses,
            indexed as ``[i, :3, :3]`` (rotation part) and ``[i, :3, 3]``
            (translation part), e.g. N*4*4.

    Returns:
        N*3 array of camera centers. Empty input yields an empty array of
        shape (0, 3).
    """
    # asarray avoids copying inputs that are already ndarrays.
    poses = np.asarray(poses_homo)
    if poses.size == 0:
        return np.zeros((0, 3))

    rotations = poses[:, :3, :3]
    # Slice with 3:4 to keep translations as column vectors (N*3*1).
    translations = poses[:, :3, 3:4]

    # np.linalg.inv and @ both broadcast over the leading camera axis.
    centers = -(np.linalg.inv(rotations) @ translations)
    return centers[..., 0]
def calculate_normalmap_from_depthmap(depthmap, intrin, extrin, num_nearest_neighbors=100):
    """Estimate a per-pixel normal map from a depth map.

    The depth map is back-projected to a world-space point cloud, normals are
    estimated with Open3D from each point's nearest neighbors, and any normal
    forming an angle of less than 90 degrees with its viewing ray is negated.

    Args:
        depthmap: H*W. depth in image plane
        intrin: camera intrinsics, passed through to get_world_points
        extrin: world to cam
        num_nearest_neighbors: number of neighbors used for the KNN normal
            estimation

    Returns:
        pts: N*3 world-space points (N = H*W)
        normalmap: H*W*3
    """
    pts = get_world_points(depthmap, intrin, extrin)
    camera_origin = get_camera_origins([extrin])[0]
    height, width = depthmap.shape

    knn_param = o3d.geometry.KDTreeSearchParamKNN(knn=num_nearest_neighbors)
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(pts)
    cloud.estimate_normals(search_param=knn_param)
    normals = np.array(cloud.normals)

    # A positive dot product between the camera-to-point ray and the normal
    # means the angle between them is below 90 degrees: negate that normal.
    view_rays = pts - camera_origin.reshape(1, 3)
    normal_dir_not_correct = np.sum(view_rays * normals, axis=-1) > 0
    logging.info(f'Normals with wrong direction: {normal_dir_not_correct.sum()}')
    normals[normal_dir_not_correct] = -normals[normal_dir_not_correct]

    normalmap = normals.reshape(height, width, 3)
    return pts, normalmap
posted @ 2022-12-14 17:25  小小灰迪  阅读(611)  评论(0)  编辑  收藏  举报