【笔记】Spherical Harmonic Lighting 球谐光照再探

更新一下,使用pyshtools(或者skylib)就可以简单地完成这个工作了。


其实就是想找一个demo,从一个图片(天空盒)里整出来球谐光照,但好像没找到短小的demo

自己瞎摸索着写了一个(采样那里偷懒了,直接在cubemap的六个面上均匀采样;这样采到的方向在球面上并不均匀,如果不按立体角加权,估计出来的系数是有偏的。理论上应该在球面上均匀采样,再映射回cubemap取颜色)

自己的理解也不是很透彻,不知道对不对,效果类似这样:

image

主要涉及到了三个坐标系之间的转换,单位球上的球坐标,cubemap展开以后的uv坐标,合上之后的xyz坐标

然后就是采样,把环境光投影(project)到SH基函数上得到一组系数,再用这组系数重建(reconstruct)出cubemap来可视化一下

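值得一提的是scipy的SH函数参数顺序和常见写法是反着的:签名是 sph_harm(m, n, theta, phi),先传 m(order)再传 n(degree),而且 theta 是方位角 [0, 2π]、phi 是极角 [0, π]。也就是说第1、2个参数之间,以及第3、4个参数之间,都要各自调换一下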

然后写法也比较粗暴,速度很差

import torch
import cv2
import numpy as np
from scipy.special import sph_harm
from matplotlib import pyplot as plt
import math

def uv_2_xyz(face,u,v):
    """Map a cubemap face and its (u, v) texture coordinates to a 3D point.

    Args:
        face: face name, one of 'px', 'nx', 'py', 'ny', 'pz', 'nz'.
        u: horizontal texture coordinate, expected in [0, 1].
        v: vertical texture coordinate, expected in [0, 1].
    Returns:
        Tuple (x, y, z) on the surface of the cube spanning [-1, 1] on each
        axis. The point is NOT normalised to unit length. An unrecognised
        face name yields (0, 0, 0).
    """
    # Re-centre the texture coordinates from [0, 1] to [-1, 1].
    a = 2*u-1
    b = 2*v-1

    # One entry per face: the fixed axis is +/-1, the other two follow (a, b).
    per_face = {
        'px': (1, b, -a),
        'nx': (-1, b, a),
        'py': (a, 1, -b),
        'ny': (a, -1, b),
        'pz': (a, b, 1),
        'nz': (-a, b, -1),
    }
    return per_face.get(face, (0, 0, 0))

def convert_xyz_to_cube_uv(x, y, z):
    """Project a direction onto the cubemap and return its face and (u, v).

    The face is the one belonging to the axis with the largest absolute
    component; ties are resolved in the order X, then Y, then Z.

    Args:
        x, y, z: scalar components of the direction (need not be normalised).
    Returns:
        Tuple (index, u, v): the face name ('px', 'nx', 'py', 'ny', 'pz' or
        'nz') and the texture coordinates, each in [0, 1].

    Note:
        An all-zero direction has a zero dominant axis, so the final division
        divides by zero.
    """
    ax, ay, az = abs(x), abs(y), abs(z)

    if ax >= ay and ax >= az:
        major = ax
        if x > 0:
            # +X face: u runs from +z to -z, v runs from -y to +y.
            index, uc, vc = 'px', -z, y
        else:
            # -X face: u runs from -z to +z, v runs from -y to +y.
            index, uc, vc = 'nx', z, y
    elif ay >= ax and ay >= az:
        major = ay
        if y > 0:
            # +Y face: u runs from -x to +x, v runs from +z to -z.
            index, uc, vc = 'py', x, -z
        else:
            # -Y face: u runs from -x to +x, v runs from -z to +z.
            index, uc, vc = 'ny', x, z
    elif az >= ax and az >= ay:
        major = az
        if z > 0:
            # +Z face: u runs from -x to +x, v runs from -y to +y.
            index, uc, vc = 'pz', x, y
        else:
            # -Z face: u runs from +x to -x, v runs from -y to +y.
            index, uc, vc = 'nz', -x, y

    # Divide by the dominant axis to land on the face, then map [-1, 1] to [0, 1].
    u = 0.5 * (uc / major + 1.0)
    v = 0.5 * (vc / major + 1.0)
    return index, u, v

def xyz_2_theta_phi(x,y,z):
    """Convert Cartesian coordinates to spherical angles.

    Works element-wise on NumPy arrays as well as on plain scalars.

    Args:
        x, y, z: Cartesian components (scalars or broadcastable arrays). The
            vector need not be normalised. A zero vector has no direction and
            yields NaN for theta.
    Returns:
        Tuple (theta, phi): theta is the polar angle measured from +z, in
        [0, pi]; phi is the azimuth measured from +x towards +y, in [0, 2*pi].
    """
    r = np.sqrt(x**2+y**2+z**2)
    # Clip so floating-point round-off can never push the ratio outside the
    # domain of arccos.
    theta = np.arccos(np.clip(z/r, -1.0, 1.0))
    # arctan2 returns (-pi, pi]; wrap negative azimuths up by 2*pi. np.mod is
    # used instead of `if phi < 0` so that array inputs are supported too.
    phi = np.mod(np.arctan2(y,x), 2*np.pi)
    return theta, phi

def uniform_spherical_sample(n_samples):
    """Draw directions uniformly distributed over the unit sphere.

    Directions are obtained by normalising 3D standard-normal samples and are
    returned as spherical angles rather than Cartesian points.

    Args:
        n_samples (int): number of directions to draw.
    Returns:
        Tensor: shape (n_samples, 2); column 0 is the polar angle theta in
        [0, pi], column 1 is the azimuth phi as returned by atan2, i.e. in
        [-pi, pi] (not wrapped to [0, 2*pi]).
    """
    gaussian = torch.randn(n_samples, 3)
    # An isotropic Gaussian normalised to unit length is uniform on the sphere.
    directions = gaussian / gaussian.norm(dim=1, keepdim=True)
    x, y, z = directions.unbind(dim=1)
    polar = torch.acos(z)
    azimuth = torch.atan2(y, x)
    return torch.stack((polar, azimuth), dim=1) # [n_samples, 2]

def _angles_to_numpy(angles):
    """Return *angles* as a float64 NumPy array with at least one dimension."""
    if isinstance(angles, torch.Tensor):
        angles = angles.detach().cpu().numpy()
    return np.atleast_1d(np.asarray(angles, dtype=np.float64))


def spherical_harmonics_basis(theta, phi, n_orders):
    """Evaluate the real spherical harmonics basis at the given directions.

    Args:
        theta (Tensor | ndarray | sequence): polar angle, in range [0, pi].
        phi (Tensor | ndarray | sequence): azimuthal angle (any 2*pi-periodic
            representation, e.g. [0, 2*pi]).
        n_orders (int): number of bands; bands l = 0 .. n_orders - 1 are used.
    Returns:
        Tensor: basis values of shape (n_samples, n_orders**2). Columns are
        ordered by band l, and within a band by m = -l .. l. The dtype and
        device follow `theta` when it is a floating-point tensor; otherwise
        the result is a float64 CPU tensor.
    Raises:
        ValueError: if n_orders is smaller than 1.
    """
    if n_orders < 1:
        raise ValueError(f"n_orders must be >= 1, got {n_orders}")

    out_dtype, out_device = torch.float64, None
    if isinstance(theta, torch.Tensor):
        out_device = theta.device
        if theta.is_floating_point():
            out_dtype = theta.dtype

    # Evaluate in NumPy/float64 explicitly instead of relying on NumPy ufuncs
    # implicitly wrapping torch tensors, which is dtype- and version-fragile.
    polar = _angles_to_numpy(theta)
    azimuth = _angles_to_numpy(phi)

    # ! attention: scipy sph_harm takes (m, l, azimuth, polar), i.e. both the
    # index pair and the angle pair are swapped w.r.t. the usual notation.
    # In this function `order` is the band l and `degree` is m.
    sqrt2 = np.sqrt(2.0)
    basis = []
    for order in range(n_orders):
        # Complex Y_l^|m| is shared by the +m and -m real harmonics.
        complex_ylm = [sph_harm(m, order, azimuth, polar) for m in range(order + 1)]
        for degree in range(-order, order + 1):
            Ylm = complex_ylm[abs(degree)]
            sign = (-1.0) ** abs(degree)
            if degree < 0:
                basis.append(sqrt2 * sign * Ylm.imag)
            elif degree > 0:
                basis.append(sqrt2 * sign * Ylm.real)
            else:
                basis.append(Ylm.real)

    stacked = np.stack(basis, axis=1) # [n_samples, n_orders**2]
    return torch.from_numpy(stacked).to(device=out_device, dtype=out_dtype)

# Load the six cubemap faces as float32 RGB images with values in [0, 1].
cubemap = {}
for name in ['px','nx','py','ny','pz','nz']:
    face_path = f'data/cubemap/{name}.png'
    face_bgr = cv2.imread(face_path)
    if face_bgr is None:
        # cv2.imread does not raise on a missing/unreadable file; it returns
        # None, which would otherwise surface as an obscure cvtColor error.
        raise FileNotFoundError(f'cannot read cubemap face image: {face_path}')
    cubemap[name] = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)/255.0
cubemap_h, cubemap_w = cubemap['px'].shape[:2]
# All later indexing uses the size of the 'px' face, so every face must match it.
for name, face_img in cubemap.items():
    if face_img.shape[:2] != (cubemap_h, cubemap_w):
        raise ValueError(
            f'cubemap face {name!r} has size {face_img.shape[:2]}, '
            f'expected {(cubemap_h, cubemap_w)}'
        )

n_samples = 100000
n_orders = 4

# Monte Carlo projection of the environment map onto the SH basis.
#
# Samples are drawn uniformly on the cube faces, which is NOT uniform on the
# sphere: a patch dA on a face at distance r from the origin subtends the solid
# angle dA / r**3. Each sample is therefore weighted by 1 / r**3, and the
# estimate is normalised by the weight sum so that the weights integrate to the
# full sphere (4*pi).
face_names = ['px','nx','py','ny','pz','nz']
sample_theta_phi = []
sample_color = []
sample_weight = []
for _ in range(n_samples):
    sample_face = np.random.choice(face_names)
    sample_u = np.random.uniform(0,1)
    sample_v = np.random.uniform(0,1)
    x,y,z = uv_2_xyz(sample_face,sample_u,sample_v)
    theta, phi = xyz_2_theta_phi(x,y,z)
    sample_theta_phi.append([theta,phi])
    # uniform(0, 1) excludes 1.0, so the integer texel index stays in range.
    color = cubemap[sample_face][int(sample_v*cubemap_h),int(sample_u*cubemap_w)]
    sample_color.append(color)
    sample_weight.append((x*x + y*y + z*z) ** -1.5)

sample_theta_phi = torch.tensor(np.array(sample_theta_phi)) # [n_samples, 2]
sample_color = torch.tensor(np.array(sample_color)) # [n_samples, 3], RGB,float [0,1]
sample_weight = torch.tensor(np.array(sample_weight)) # [n_samples], solid-angle weights

Y = spherical_harmonics_basis(sample_theta_phi[:,0],sample_theta_phi[:,1],n_orders) # [n_samples, n_orders**2]

# coeffs[k, c] = 4*pi * sum_i w_i * Y[i, k] * color[i, c] / sum_i w_i
sample_weight = sample_weight.to(Y.dtype)
weighted_Y = Y * sample_weight.unsqueeze(-1) # [n_samples, n_orders**2]
coeffs = torch.matmul(weighted_Y.T, sample_color.to(Y.dtype)) # [n_orders**2, 3]
coeffs *= 4*np.pi / sample_weight.sum() # [n_orders**2, 3]

# print(coeffs)

# Reconstruct every face from the SH coefficients, show it next to the
# original face and write it to disk as '<face>_new.png'.
new_cubemap = {}
for face in ['px','nx','py','ny','pz','nz']:
    plt.suptitle(face)
    plt.subplot(1,2,1)
    plt.axis('off')
    plt.imshow(cubemap[face])

    theta_list, phi_list = [],[]
    for i in range(cubemap_h):
        for j in range(cubemap_w):
            # Evaluate at the texel centre (+0.5), not at its top-left corner.
            x,y,z = uv_2_xyz(face,(j+0.5)/cubemap_w,(i+0.5)/cubemap_h)
            theta, phi = xyz_2_theta_phi(x,y,z)
            theta_list.append(theta)
            phi_list.append(phi)
    # Use the coefficient dtype so the matmul below never mixes precisions.
    theta = torch.tensor(np.array(theta_list), dtype=coeffs.dtype) # [h*w]
    phi = torch.tensor(np.array(phi_list), dtype=coeffs.dtype) # [h*w]
    Y = spherical_harmonics_basis(theta,phi,n_orders) # [h*w, n_orders**2]
    tmp = torch.matmul(Y,coeffs).numpy() # [h*w, 3]
    reconstructed = tmp.reshape(cubemap_h,cubemap_w,3)

    # Report the raw range: a truncated SH expansion can over/undershoot [0, 1].
    print(reconstructed.min(),reconstructed.max())
    # Clip before the uint8 conversion; otherwise out-of-range values wrap
    # around and show up as speckle artefacts.
    face_rgb = (np.clip(reconstructed, 0.0, 1.0)*255).astype(np.uint8)
    plt.subplot(1,2,2)
    plt.axis('off')
    plt.imshow(face_rgb)
    # OpenCV expects BGR channel order when writing.
    new_cubemap[face] = cv2.cvtColor(face_rgb, cv2.COLOR_RGB2BGR)
    plt.show()
    cv2.imwrite(f'data/cubemap/{face}_new.png',new_cubemap[face])


posted @ 2023-02-27 21:39  GhostCai  阅读(151)  评论(0编辑  收藏  举报