[数字人] NeRF实现
体渲染
光沿着路径传播,如果路径上有物体,就会损失能量,我们将单位面积的光的能量记为\(I(t)\)。假设光走了\(\Delta t\)的路程,此时剩下的能量可以表示成:\(I(t+\Delta t) = I(t)\,(1 - \sigma(t)\Delta t)\)
其中\(\sigma(t)\)是该点物体的“占有密度”,因为是单位面积,所以\(\sigma(t)\Delta t\)表示的是这段路径下单位面积的物体的“占有量”,或者说光能量因穿过物体导致的“衰减”。因此可以用微分方程进行推导,得出\(I(t)\)的形式:
这是一个可分离变量的微分方程,可以直接解得 \(I(t) = I_0\, e^{-\int_{0}^{t}\sigma(s)\,ds}\)
其中\(I_0\)是初始能量,这是微分方程的初值。如果我们将\(I(t)\)规定为\([0,1]\)的量,此时\(1-I(t)\)可以表示为一个概率分布函数,即表示\([0,t]\)范围内有物体的概率,此时我们可以求出概率密度函数:
特别地,我们可以取\(I_0=1\)并且控制\(\sigma(t)\)以保证它还是一个概率密度函数,此时
因此,沿着光线路径对颜色进行积分,就是这条路径上颜色的期望,可以描述为:
一般来说,我们都是从一个近端向一个远端积分,因此改积分上下界:
其中\(T(t) = e^{-\int_{t_{near}}^{t}\sigma(s)\,ds}\),表示光线从\(t_{near}\)走到\(t\)仍未被遮挡的概率(透射率)。
离散化
在\([t_{near},t_{far}]\)中采样N个点,其中\(t_{near} \lt t_1 \lt t_2 \lt \cdots \lt t_N \lt t_{far}\),则差分(区间长度)为\(\delta_i = t_{i+1} - t_{i}\)。前面说过,概率密度为\(p(t) = T(t)\sigma(t)\),因此需要计算离散概率\(p_i\),然后计算\(\sum_{i=1}^{N} p_i c_i\)就是数值积分的结果(想想我们如何求离散概率分布的期望的),其中\(p_i\)表示\(\delta_i\)这个interval里有物体的概率,因此可以先估计这个概率:
现在将结果带入:
再考虑\(T(t_i)\):
因此最终的形式为:
实现
def render(self, t, preds):
    """Volume-render the sampled predictions along each ray (excerpt —
    the full implementation later in this file also composites and
    returns the final image).

    t:     (b, hw, num_samples, 1) sample depths along each ray.
    preds: (b, hw, num_samples, 4) network output; RGB in [..., :3],
           density sigma in [..., -1].
    """
    # t(b, hw, num_samples, 1)
    b = preds.size(0)
    sigma = preds[..., -1:] # (b, hw, num_samples, 1)
    c = preds[..., :3] # (b, hw, num_samples, 3)
    # Adjacent-sample spacings delta_i = t_{i+1} - t_i (num_samples-1 of them).
    delta = t[..., 1:, :] - t[..., :-1, :] # (b, hw, num_samples-1, 1)
    # Pad the last interval with a huge spacing so its term contributes ~0 weight.
    max_delta = torch.ones(b, delta.size(1), 1, 1).to(delta) * 1e10
    delta = torch.cat([delta, max_delta], dim = 2)
    sigma_delta = sigma * delta #(b, hw, num_samples, 1)
    exp_ = torch.exp(-sigma_delta)
    # T_i = prod_{j<i} exp(-sigma_j * delta_j): cumulative product shifted
    # right by one step so that T_1 = e^0 = 1.
    T = torch.cumprod(exp_,dim=2)
    T = torch.roll(T, 1, 2)
    T[...,0,0]=1
代码
demo.py,我按照tinyNerf重新写了下:
import torch
import torchvision
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import os
class Net(nn.Module):
    """MLP mapping a 36-dim positional encoding to (r, g, b, sigma).

    36 = 2 * 3 * L with L = 6 encoding frequencies (see position_encoding).
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(36, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
        )

    def forward(self, x):
        """Return (..., 4): RGB squashed to [0, 1] by sigmoid, density
        made non-negative by ReLU."""
        y = self.net(x)
        rgb = torch.sigmoid(y[..., :3])
        sigma = torch.nn.functional.relu(y[..., -1:])
        # BUG FIX: the original assigned these back into `y` with in-place
        # slice assignment, mutating a tensor that is part of the autograd
        # graph; build the output out-of-place instead.
        return torch.cat([rgb, sigma], dim=-1)
class Dataset(torch.utils.data.Dataset):
    """Tiny-NeRF dataset: training images plus camera-to-world poses,
    loaded from `tiny_nerf_data.npz` (expected in the working directory)."""

    def __init__(self):
        self.load()

    def __len__(self):
        # One sample per image.
        return len(self.images)

    def load(self):
        # NOTE(review): assumes the .npz archive contains keys "images",
        # "poses" (4x4 cam2world matrices) and "focal" — confirm against
        # the tiny-NeRF data download.
        data = np.load("tiny_nerf_data.npz")
        images = data["images"]
        tform_cam2world = torch.from_numpy(data["poses"])
        focal_length = torch.from_numpy(data["focal"])
        # Images stay as numpy arrays; converted to tensors per item.
        self.images = images
        self.tform_cam2world = tform_cam2world
        self.focal_length = focal_length
        # images is (N, H, W, C); record the render resolution.
        self.height, self.width = images.shape[1:3]

    def __getitem__(self, i):
        # Returns (image tensor, camera-to-world pose) for sample i.
        image = self.images[i]
        tform_cam2world = self.tform_cam2world[i]
        return torch.from_numpy(image), tform_cam2world
class TinyNeRF():
    """Minimal NeRF pipeline: per-pixel ray generation, stratified depth
    sampling, positional encoding, network query, and volume rendering."""

    def __init__(self, net, focal, width, height, znear, zfar, L, num_samples):
        self.focal = focal            # camera focal length (pixels)
        self.width = width            # image width
        self.height = height          # image height
        self.znear = znear            # near plane depth
        self.zfar = zfar              # far plane depth
        self.num_samples = num_samples  # depth samples per ray
        self.net = net                # the radiance-field MLP
        self.L = L                    # positional-encoding frequency count

    def sample_query_points(self, tform):
        """Cast one ray per pixel and sample points along it.

        tform: (b, 4, 4) camera-to-world matrices.
        Returns (points, t): (b, hw, num_samples, 3) world-space sample
        positions and (b, hw, num_samples, 1) their depths.
        """
        b = tform.size(0)
        xx, yy = torch.arange(self.width), torch.arange(self.height)
        # indexing="xy" yields (h, w)-shaped grids matching the image layout.
        x, y = torch.meshgrid(xx, yy, indexing="xy")  # (h, w)
        # Center the principal point; flip y so +y points up.
        x, y = x - 0.5 * self.width, 0.5 * self.height - y
        # Camera-space directions: the camera looks down -z, so z = -1 and
        # x, y are pre-divided by focal.
        directions = torch.stack([x / self.focal, y / self.focal, -torch.ones_like(x)], dim=-1)  # (h,w,3)
        directions = directions.view(-1, 3, 1).to(tform)  # (hw, 3, 1)
        # Rotate into world space: (R @ d) computed as an elementwise
        # product with the rotation block followed by a sum over the last axis.
        # NOTE(review): this line was truncated in the source; ".sum(dim = -1)" restored.
        directions = (directions.view(1, self.height, self.width, 1, 3)
                      * tform[:, :3, :3].view(b, 1, 1, 3, 3)).sum(dim=-1)
        directions = directions.view(b, -1, 1, 3)  # (b, hw, 1, 3)
        # Stratified sampling: uniform bins plus per-bin random jitter.
        t = torch.linspace(self.znear, self.zfar, self.num_samples)
        noise = torch.rand([b, self.height * self.width, self.num_samples, 1])
        t = t.view(1, 1, self.num_samples, 1) + noise * (self.zfar - self.znear) / self.num_samples
        t = t.to(tform)
        origins = tform[:, :3, -1].view(b, 1, 1, 3)  # camera center (tx, ty, tz)
        return origins + directions * t, t  # (b, hw, num_samples, 3)

    def position_encoding(self, x):
        """Map 3-D positions to gamma(p) = (sin(2^k p), cos(2^k p)),
        k = 0..L-1, giving a (b, n, 2*3*L) encoding."""
        b = x.size(0)
        x_flatten = x.view(-1, 3, 1)
        freq = (2 ** torch.arange(0, self.L)).view(1, 1, -1).to(x)
        y_sin = torch.sin(x_flatten * freq)
        y_cos = torch.cos(x_flatten * freq)  # (n, 3, L)
        return torch.cat([y_sin, y_cos], dim=-1).view(b, -1, 2 * 3 * self.L)

    def render(self, t, preds):
        """Volume-render: composite the per-sample colors with weights
        T_i * (1 - exp(-sigma_i * delta_i)) and reshape to an image.

        t:     (b, hw, num_samples, 1) sample depths.
        preds: (b, hw, num_samples, 4), RGB + sigma.
        Returns (b, height, width, 3).
        """
        b = preds.size(0)
        sigma = preds[..., -1:]  # (b, hw, num_samples, 1)
        c = preds[..., :3]       # (b, hw, num_samples, 3)
        delta = t[..., 1:, :] - t[..., :-1, :]  # (b, hw, num_samples-1, 1)
        # The last sample has no successor; use a huge spacing so its
        # transmittance factor is ~0 beyond it.
        max_delta = torch.ones(b, delta.size(1), 1, 1).to(delta) * 1e10
        delta = torch.cat([delta, max_delta], dim=2)
        sigma_delta = sigma * delta  # (b, hw, num_samples, 1)
        exp_ = torch.exp(-sigma_delta)
        # T_i = prod_{j<i} exp(-sigma_j delta_j): exclusive cumulative
        # product, realized by a shift-right plus T_1 = e^0 = 1.
        T = torch.cumprod(exp_, dim=2)
        T = torch.roll(T, 1, 2)
        T[..., 0, 0] = 1
        return torch.sum(T * (1 - exp_) * c, dim=2).view(b, self.height, self.width, 3)

    def query(self, tform):
        """Sample rays, encode positions, and evaluate the network.
        Returns (preds, t) with preds shaped (b, hw, num_samples, 4)."""
        b = tform.size(0)
        pts, t = self.sample_query_points(tform)
        pts_code = self.position_encoding(pts)
        # BUG FIX: the original called the module-level global `net` here,
        # which only worked by accident; use the instance's network.
        pred = self.net(pts_code)
        return pred.view(b, self.height * self.width, self.num_samples, 4), t
def train(dataloader, nerf, optimizer, device, epochs=300):
    """Train the NeRF network with per-pixel MSE against reference images.

    After each epoch, saves a side-by-side grid (rendered | reference) of
    the last batch to images/{epoch}.jpg.

    Args:
        dataloader: yields (image, cam2world pose) batches.
        nerf: TinyNeRF instance wrapping the network.
        optimizer: optimizer over the network parameters.
        device: torch device string, e.g. "cuda".
        epochs: number of passes over the dataset (default 300, matching
            the original hard-coded loop).
    """
    for epoch in range(epochs):
        bar = tqdm(dataloader)
        for img, tform in bar:
            img = img.to(device)
            tform = tform.to(device)
            optimizer.zero_grad()
            preds, t = nerf.query(tform)
            img_rendered = nerf.render(t, preds)
            loss = torch.nn.functional.mse_loss(img_rendered, img)
            loss.backward()
            optimizer.step()
            bar.set_description("[train] Epoch: {} Loss: {:.6f}".format(epoch + 1, loss.item()))
        # Save a comparison grid of the last batch: NHWC -> NCHW for torchvision.
        img_out = img_rendered.permute(0, 3, 1, 2).contiguous()
        img_ref = img.permute(0, 3, 1, 2).contiguous()
        img_out = torchvision.utils.make_grid(torch.cat([img_out, img_ref], dim=0), nrow=2)
        torchvision.utils.save_image(img_out.cpu(), f"images/{epoch}.jpg")
if __name__ == "__main__":
    # Rendered comparison grids are written to ./images/, one per epoch.
    if not os.path.exists("images"):
        os.mkdir("images")
    device = "cuda"
    batch_size = 8
    dataset = Dataset()
    net = Net().to(device)
    nerf = TinyNeRF(net, focal=dataset.focal_length, width=dataset.width,
                    height=dataset.height, znear=2, zfar=6, L=6,
                    num_samples=32)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size, True)
    # NOTE(review): this line was garbled in the source
    # ("opti train(dataloader, nerf, optimizer, device) r = 5e-3)");
    # reconstructed as an Adam optimizer with lr=5e-3, matching the
    # tiny-NeRF reference implementation — confirm against the original.
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-3)
    train(dataloader, nerf, optimizer, device)
100 epoch时有个基本的样子:
其中主要在于2个函数,sample_query_points和render。
def sample_query_points(self, tform):
    """Cast one ray per pixel and sample points along it.

    tform: (b, 4, 4) camera-to-world matrices.
    Returns (points, t): (b, hw, num_samples, 3) world-space sample
    positions and (b, hw, num_samples, 1) their depths.
    """
    b = tform.size(0)
    xx, yy = torch.arange(self.width), torch.arange(self.height)
    # indexing="xy" yields (h, w)-shaped grids matching the image layout.
    x, y = torch.meshgrid(xx, yy, indexing="xy")  # (h, w)
    # Center the principal point; flip y so +y points up.
    x, y = x - 0.5 * self.width, 0.5 * self.height - y
    # Camera-space directions: the camera looks down -z, so z = -1 and
    # x, y are pre-divided by focal.
    directions = torch.stack([x / self.focal, y / self.focal, -torch.ones_like(x)], dim=-1)  # (h,w,3)
    directions = directions.view(-1, 3, 1).to(tform)  # (hw, 3, 1)
    # Rotate into world space: (R @ d) via elementwise product with the
    # rotation block and a sum over the last axis.
    # NOTE(review): this line was truncated in the source; ".sum(dim = -1)" restored.
    directions = (directions.view(1, self.height, self.width, 1, 3)
                  * tform[:, :3, :3].view(b, 1, 1, 3, 3)).sum(dim=-1)
    directions = directions.view(b, -1, 1, 3)  # (b, hw, 1, 3)
    # Stratified sampling: uniform bins plus per-bin random jitter.
    t = torch.linspace(self.znear, self.zfar, self.num_samples)
    noise = torch.rand([b, self.height * self.width, self.num_samples, 1])
    t = t.view(1, 1, self.num_samples, 1) + noise * (self.zfar - self.znear) / self.num_samples
    t = t.to(tform)
    origins = tform[:, :3, -1].view(b, 1, 1, 3)  # camera center (tx, ty, tz)
    return origins + directions * t, t  # (b, hw, num_samples, 3)
这里meshgrid的indexing='xy'是让生成的图像坐标跟图像是同shape的,即(h,w)。
然后就是像素坐标系到相机坐标系的转换,他规定的是z轴向里(他的变换矩阵是根据这个规定算来的),所以虚像的平面应该是z=-focal。
代码里x和y先除以focal,z设置为-1是因为后面早晚要除的,x/focal * t才是点的坐标,然后先做相机系到世界系的变换,再采样t,跟先采样z再变换一样的。
然后就是均匀采样num_samples个点,得到这些点的坐标。
def render(self, t, preds):
    """Volume-render the sampled predictions along each ray (excerpt —
    the full implementation above also composites and returns the image).

    t:     (b, hw, num_samples, 1) sample depths along each ray.
    preds: (b, hw, num_samples, 4) network output; RGB in [..., :3],
           density sigma in [..., -1].
    """
    # t(b, hw, num_samples, 1)
    b = preds.size(0)
    sigma = preds[..., -1:] # (b, hw, num_samples, 1)
    c = preds[..., :3] # (b, hw, num_samples, 3)
    # Adjacent-sample spacings delta_i = t_{i+1} - t_i (num_samples-1 of them).
    delta = t[..., 1:, :] - t[..., :-1, :] # (b, hw, num_samples-1, 1)
    # Pad the last interval with a huge spacing so its term contributes ~0 weight.
    max_delta = torch.ones(b, delta.size(1), 1, 1).to(delta) * 1e10
    delta = torch.cat([delta, max_delta], dim = 2)
    sigma_delta = sigma * delta #(b, hw, num_samples, 1)
    exp_ = torch.exp(-sigma_delta)
    # T_i = prod_{j<i} exp(-sigma_j * delta_j): cumulative product shifted
    # right by one step so that T_1 = e^0 = 1.
    T = torch.cumprod(exp_,dim=2)
    T = torch.roll(T, 1, 2)
    T[...,0,0]=1
这个上面推导过了,最后一行设第一个元素为1是因为对j的求和到\(i-1\),此时是\(e^0=1\)
max_delta那里也是因为差分的最后一个元素没有下一个元素了,所以我们给很大的差分会让这项算出来的权重几乎为0.