[Digital Human] NeRF Implementation

Volume Rendering

Light propagates along a path and loses energy wherever the path passes through matter. Denote the light energy per unit area by \(I(t)\). After the light travels a further distance \(\Delta t\), the remaining energy can be written as:

\[I(t+\Delta t) = I(t)\,(1-\sigma(t)\Delta t) \]

where \(\sigma(t)\) is the "occupancy density" of matter at that point. Since everything is per unit area, \(\sigma(t)\Delta t\) is the amount of matter occupying this stretch of the path per unit area, in other words the attenuation the light suffers from passing through matter. From this we can set up a differential equation and derive the form of \(I(t)\):

\[\begin{align} \Delta I =& I(t+\Delta t) - I(t) \\ =&-I(t)\sigma(t)\Delta t \\ \frac{\Delta I}{\Delta t} =&-I(t)\sigma(t) \\ I'(t) =& -I(t)\sigma(t) \end{align} \]

This is a separable differential equation, which can be solved directly:

\[\begin{align} I(t) = I_0e^{-\int_0^t \sigma(s)ds} \end{align} \]

where \(I_0\) is the initial energy, i.e. the initial value of the differential equation. If we restrict \(I(t)\) to \([0,1]\), then \(1-I(t)\) can be read as a cumulative distribution function: the probability that there is matter somewhere in \([0,t]\). Differentiating gives the probability density function:

\[\begin{align}(1-I(t))'=f(t) = I_0 \sigma(t)e^{-\int_0^t\sigma(s)ds}\end{align} \]

In particular, we can take \(I_0=1\) and constrain \(\sigma(t)\) so that this remains a probability density function, in which case

\[\begin{align}f(t) = \sigma(t)e^{-\int_0^t\sigma(s)ds}\end{align} \]
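
Since \(1-I(t)\) is the CDF, \(f\) should satisfy \(\int_0^T f(t)dt = 1 - e^{-\int_0^T\sigma(t)dt}\), which is easy to check numerically. A quick sketch of mine (the density profile is made up; any \(\sigma(t)\ge 0\) works):

import numpy as np

t = np.linspace(0.0, 10.0, 100_000)
sigma = 0.5 + 0.3 * np.sin(t) ** 2   # hypothetical density profile

# cumulative optical depth tau(t) = integral_0^t sigma, via the trapezoid rule
tau = np.concatenate([[0.0], np.cumsum(0.5 * (sigma[1:] + sigma[:-1]) * np.diff(t))])
f = sigma * np.exp(-tau)

print(np.trapz(f, t))           # ~0.998
print(1.0 - np.exp(-tau[-1]))   # matches: f integrates to 1 - e^{-tau(T)}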

Therefore, integrating the color along the ray against this density gives the expected color along the path:

\[\begin{align}\int_0^Tf(t)c(t)dt = \int_0^T \sigma(t)e^{-\int_0^t\sigma(s)ds}c(t)dt\end{align} \]

In practice we integrate from a near bound to a far bound, so we change the limits of integration accordingly:

\[\begin{align}c(r)=\int_{t_{near}}^{t_{far}} \sigma(t)e^{-\int_{t_{near}}^{t}\sigma(s)ds}c(t)dt = \int_{t_{near}}^{t_{far}} \sigma(t)T(t)c(t)dt\end{align} \]

where \(T(t) =e^{-\int_{t_{near}}^{t}\sigma(s)ds}\) is the accumulated transmittance.
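
For constant \(\sigma\) and constant color \(c\), the integral has the closed form \(c(1-e^{-\sigma(t_{far}-t_{near})})\), which gives a simple sanity check (a sketch with made-up constants, not from the original post):

import numpy as np

sigma0, color, t_near, t_far = 0.7, 0.4, 2.0, 6.0        # hypothetical constants
t = np.linspace(t_near, t_far, 100_000)
T = np.exp(-sigma0 * (t - t_near))                        # transmittance
print(np.trapz(sigma0 * T * color, t))                    # ~0.3757
print(color * (1 - np.exp(-sigma0 * (t_far - t_near))))   # closed form, same value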

Discretization

Sample N points in \([t_{near},t_{far}]\) with \(t_{near} \lt t_1 \lt t_2 \lt \cdots \lt t_N \lt t_{far}\), and let the differences (interval lengths) be \(\delta_i = t_{i+1} - t_{i}\). As shown above, the probability density is \(p(t) = T(t)\sigma(t)\), so we need the discrete probabilities \(p_i\); then \(\sum_{i=1}^{N} p_i c_i\) is the numerical value of the integral (recall how the expectation of a discrete distribution is computed). Here \(p_i\) is the probability that there is matter inside the interval \(\delta_i\), which we can estimate first:

\[\begin{align}p_i=&\int_{t_i}^{t_{i+1}}T(t)\sigma(t)dt \\=& (1-T(t_{i+1})) - (1-T(t_i))\\ =& T(t_i) - T(t_{i+1})\\ =&T(t_i)(1-\frac{T(t_{i+1})}{T(t_i)}) \\ =&T(t_i)(1-e^{-\int_{t_i}^{t_{i+1}}\sigma(s)ds}) \\\approx& T(t_i)(1-e^{-\delta_i \sigma_i})\end{align} \]

Substituting this back in:

\[\begin{align}c(r) =& \sum_{i=1}^N p_ic_i \\ =&\sum_{i=1}^N T(t_i)(1-e^{-\delta_i \sigma_i}) c_i \end{align} \]

Next, consider \(T(t_i)\):

\[\begin{align}T(t_i) =& e^{-\int_{t_{near}}^{t_i}\sigma(s)ds} \\ \approx&e^{-\sum_{j=1}^{i-1}\sigma_j\delta_j}\end{align} \]

The final discrete form is therefore:

\[\begin{align}c(r) =& \sum_{i=1}^N e^{-\sum_{j=1}^{i-1}\sigma_j\delta_j}(1-e^{-\delta_i \sigma_i}) c_i \end{align} \]
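
This sum translates directly into a few lines of NumPy (a minimal single-ray sketch of mine; the variable names are not from the original post):

import numpy as np

def render_ray(sigma, delta, colors):
    # alpha_i = 1 - e^{-sigma_i delta_i}: opacity of interval i
    alpha = 1.0 - np.exp(-sigma * delta)
    # exclusive cumulative product: T_i = e^{-sum_{j<i} sigma_j delta_j}
    T = np.cumprod(np.concatenate([[1.0], np.exp(-sigma * delta)]))[:-1]
    return ((T * alpha)[:, None] * colors).sum(axis=0)

sigma = np.array([0.0, 0.5, 2.0, 0.1])    # hypothetical densities
delta = np.array([0.1, 0.1, 0.1, 1e10])   # last interval padded, as in the code below
colors = np.array([[1.0, 0.0, 0.0]] * 4)  # all-red samples
print(render_ray(sigma, delta, colors))   # ~[1, 0, 0]: the opaque last interval absorbs the rest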

Implementation

def render(self, t, preds):
        # t: (b, hw, num_samples, 1) sorted sample depths along each ray
        b = preds.size(0)
        sigma = preds[..., -1:] # (b, hw, num_samples, 1)
        c = preds[..., :3] # (b, hw, num_samples, 3)
        delta = t[..., 1:, :] - t[..., :-1, :] # (b, hw, num_samples-1, 1)
        # pad the last interval with a huge value (it has no next sample)
        max_delta = torch.ones(b, delta.size(1), 1, 1).to(delta) * 1e10
        delta = torch.cat([delta, max_delta], dim = 2) # (b, hw, num_samples, 1)
        sigma_delta = sigma * delta # (b, hw, num_samples, 1)
        exp_ = torch.exp(-sigma_delta)
        T = torch.cumprod(exp_, dim = 2)
        # shift to an exclusive product: T_i = exp(-sum_{j<i} sigma_j delta_j)
        T = torch.roll(T, 1, 2)
        T[..., 0, 0] = 1
        return torch.sum(T * (1 - exp_) * c, dim = 2).view(b, self.height, self.width, 3)
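
A quick shape check (my own sketch; it relies on the TinyNeRF class from the full listing below, and since render never touches net or focal, those can be None here):

import torch

nerf = TinyNeRF(net=None, focal=None, width=4, height=4,
                znear=2, zfar=6, L=6, num_samples=8)
t = torch.sort(torch.rand(2, 16, 8, 1), dim=2).values  # (b, hw, num_samples, 1)
preds = torch.rand(2, 16, 8, 4)                        # rgb + sigma per sample
print(nerf.render(t, preds).shape)                     # torch.Size([2, 4, 4, 3])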

Code

Here is demo.py, which I rewrote following tinyNerf:

import torch
import torchvision
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import os
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(36,128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,4)
        )
    def forward(self, x):
        y = self.net(x)
        rgb = torch.sigmoid(y[..., :3])                 # rgb in [0, 1]
        sigma = torch.nn.functional.relu(y[..., -1:])   # density >= 0
        # avoid in-place writes on an autograd tensor; build the output fresh
        return torch.cat([rgb, sigma], dim = -1)
class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.load()
    def __len__(self):
        return len(self.images)
    def load(self):
        data = np.load("tiny_nerf_data.npz")
        images = data["images"]
        tform_cam2world = torch.from_numpy(data["poses"])
        focal_length = torch.from_numpy(data["focal"])
        self.images = images
        self.tform_cam2world = tform_cam2world
        self.focal_length = focal_length
        self.height, self.width = images.shape[1:3]
    def __getitem__(self, i):
        image = self.images[i]
        tform_cam2world = self.tform_cam2world[i]
        return torch.from_numpy(image), tform_cam2world
class TinyNeRF():
    def __init__(self, net, focal, width, height, znear, zfar, L, num_samples):
        self.focal = focal
        self.width = width
        self.height = height
        self.znear = znear
        self.zfar = zfar
        self.num_samples = num_samples
        self.net = net
        self.L = L
    def sample_query_points(self, tform):
        b = tform.size(0)
        xx, yy = torch.arange(self.width), torch.arange(self.height)
        x, y = torch.meshgrid(xx, yy,indexing = "xy" ) # (h, w)
        x, y = x-0.5 * self.width , 0.5 * self.height-y
        directions = torch.stack([x / self.focal, y / self.focal, -torch.ones_like(x)], dim = -1) # (h,w,3)
        directions = directions.view(-1,3, 1).to(tform) #(hw, 3, 1)
        directions = (directions.view(1, self.height, self.width, 1, 3) * tform[:, :3, :3].view(b, 1, 1, 3, 3)).sum(dim = -1) # rotate rays into the world frame: (b, h, w, 3)
        directions = directions.view(b, -1, 1, 3) #(b, hw, 1, 3)
        t = torch.linspace(self.znear, self.zfar, self.num_samples)
        noise = torch.rand([b, self.height * self.width, self.num_samples, 1])
        t = t.view(1, 1, self.num_samples, 1) + noise * (self.zfar - self.znear) / self.num_samples
        t = t.to(tform)
        origins = tform[:, :3, -1].view(b, 1, 1, 3) # tx ty tz
        return origins + directions * t, t # (b, wh, num_samples, 3)
    def position_encoding(self, x):
        b = x.size(0)
        x_flatten = x.view(-1,3, 1)
        freq = (2**torch.arange(0, self.L)).view(1,1,-1).to(x)
        y_sin = torch.sin(x_flatten * freq)
        y_cos = torch.cos(x_flatten * freq) # (n, 3, L)
        return torch.cat([y_sin, y_cos], dim = -1).view(b, -1, 2*3*self.L)
    def render(self, t, preds):
        # t(b, hw, num_samples, 1)
        b = preds.size(0)
        sigma = preds[..., -1:] # (b, hw, num_samples, 1)
        c = preds[..., :3] # (b, hw, num_samples, 3)
        delta = t[..., 1:, :] - t[..., :-1, :] # (b, hw, num_samples-1, 1)
        max_delta = torch.ones(b, delta.size(1), 1, 1).to(delta) * 1e10
        delta = torch.cat([delta, max_delta], dim = 2)
        sigma_delta = sigma * delta #(b, hw, num_samples, 1)
        exp_ = torch.exp(-sigma_delta)
        T = torch.cumprod(exp_,dim=2)
        T = torch.roll(T, 1, 2)
        T[...,0,0]=1
        return torch.sum(T*(1-exp_)*c, dim = 2).view(b, self.height, self.width, 3)
    def query(self, tform):
        b = tform.size(0)
        pts, t = self.sample_query_points(tform)
        pts_code = self.position_encoding(pts)
        pred = self.net(pts_code)
        return pred.view(b, self.height*self.width, self.num_samples, 4), t # (b, hw, num_samples, 4)
def train(dataloader, nerf, optimizer, device):
    for epoch in range(300):
        bar = tqdm(dataloader)
        for img, tform in bar:
            img = img.to(device)
            tform = tform.to(device)
            optimizer.zero_grad()
            preds, t = nerf.query(tform)
            img_rendered = nerf.render(t, preds)
            loss = torch.nn.functional.mse_loss(img_rendered, img)
            loss.backward()
            optimizer.step()
            bar.set_description("[train] Epoch: {} Loss: {:.6f}".format(epoch+1, loss.item()))
        img_out = img_rendered.permute(0,3,1,2).contiguous()
        img_ref = img.permute(0,3,1,2).contiguous()
        img_out = torchvision.utils.make_grid(torch.cat([img_out, img_ref], dim = 0), nrow=2)
        torchvision.utils.save_image(img_out.cpu(),f"images/{epoch}.jpg" )
if __name__ == "__main__":
    if not os.path.exists("images"):
        os.mkdir("images")
    device = "cuda"
    batch_size = 8
    dataset = Dataset()
    net = Net().to(device)
    nerf = TinyNeRF(net, focal = dataset.focal_length, width = dataset.width,
                   height=dataset.height, znear = 2, zfar = 6, L = 6,
                   num_samples = 32)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size, True)
    optimizer = torch.optim.Adam(net.parameters(), lr = 5e-3)
    train(dataloader, nerf, optimizer, device)
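
Once trained, any camera-to-world pose can be pushed through the same pipeline to render a novel view. A sketch (here I just reuse the first training pose, but any 4x4 pose works):

with torch.no_grad():
    pose = dataset.tform_cam2world[0:1].to(device)  # (1, 4, 4) cam2world matrix
    preds, t = nerf.query(pose)
    img = nerf.render(t, preds)                     # (1, h, w, 3)
    torchvision.utils.save_image(img.permute(0, 3, 1, 2).cpu(), "novel_view.jpg")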

By epoch 100 the render already takes on the basic shape of the scene (the rendered/reference grids are saved under images/).

The two functions that do the real work are sample_query_points and render.

def sample_query_points(self, tform):
        b = tform.size(0)
        xx, yy = torch.arange(self.width), torch.arange(self.height)
        x, y = torch.meshgrid(xx, yy,indexing = "xy" ) # (h, w)
        x, y = x-0.5 * self.width , 0.5 * self.height-y
        directions = torch.stack([x / self.focal, y / self.focal, -torch.ones_like(x)], dim = -1) # (h,w,3)
        directions = directions.view(-1,3, 1).to(tform) #(hw, 3, 1)
        directions = (directions.view(1, self.height, self.width, 1, 3) * tform[:, :3, :3].view(b, 1, 1, 3, 3)).sum(dim = -1) # rotate rays into the world frame: (b, h, w, 3)
        directions = directions.view(b, -1, 1, 3) #(b, hw, 1, 3)
        t = torch.linspace(self.znear, self.zfar, self.num_samples)
        noise = torch.rand([b, self.height * self.width, self.num_samples, 1])
        t = t.view(1, 1, self.num_samples, 1) + noise * (self.zfar - self.znear) / self.num_samples
        t = t.to(tform)
        origins = tform[:, :3, -1].view(b, 1, 1, 3) # tx ty tz
        return origins + directions * t, t # (b, wh, num_samples, 3)

Here meshgrid's indexing="xy" makes the generated coordinate grids have the same shape as the image, i.e. (h, w).
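
A tiny example of the indexing="xy" behavior (my sketch):

import torch

x, y = torch.meshgrid(torch.arange(3), torch.arange(2), indexing="xy")
print(x.shape)  # torch.Size([2, 3]): (h, w), same layout as the image
print(x)        # tensor([[0, 1, 2], [0, 1, 2]])  x varies along columns
print(y)        # tensor([[0, 0, 0], [1, 1, 1]])  y varies along rows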

Next comes the conversion from pixel coordinates to camera coordinates. The dataset's convention is that the camera looks along the negative z-axis (its pose matrices are built under that convention), so the virtual image plane sits at z = -focal.

In the code, x and y are divided by focal and z is set to -1 because the division has to happen sooner or later anyway: x/focal * t is the actual point coordinate. Rotating the direction from the camera frame to the world frame first and then sampling t yields the same points as sampling the depth first and then transforming.

Then num_samples points are sampled along each ray, uniformly spaced with stratified jitter, giving the query point coordinates.
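
One way to convince yourself the ray math is right (my own sketch): with an identity camera-to-world pose, the rays start at the origin and every sampled point must satisfy z = -t, since the camera looks along -z. The focal value below is made up.

import torch

nerf_dbg = TinyNeRF(net=None, focal=torch.tensor(10.0), width=4, height=4,
                    znear=2, zfar=6, L=6, num_samples=8)
pose = torch.eye(4).unsqueeze(0)          # (1, 4, 4) identity cam2world
pts, t = nerf_dbg.sample_query_points(pose)
print(torch.allclose(pts[..., 2:], -t))   # True: depth along -z equals t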

def render(self, t, preds):
        # t(b, hw, num_samples, 1)
        b = preds.size(0)
        sigma = preds[..., -1:] # (b, hw, num_samples, 1)
        c = preds[..., :3] # (b, hw, num_samples, 3)
        delta = t[..., 1:, :] - t[..., :-1, :] # (b, hw, num_samples-1, 1)
        max_delta = torch.ones(b, delta.size(1), 1, 1).to(delta) * 1e10
        delta = torch.cat([delta, max_delta], dim = 2)
        sigma_delta = sigma * delta #(b, hw, num_samples, 1)
        exp_ = torch.exp(-sigma_delta)
        T = torch.cumprod(exp_,dim=2)
        T = torch.roll(T, 1, 2)
        T[...,0,0]=1

This is exactly the formula derived above. The line T[..., 0, 0] = 1 sets the first element to 1 because the sum over j only runs up to \(i-1\): for \(i=1\) the sum is empty, giving \(e^0=1\).
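
The roll trick is just an exclusive cumulative product; a toy example with made-up numbers makes it concrete:

import torch

a = torch.tensor([0.9, 0.8, 0.5]).view(1, 1, 3, 1)
T = torch.cumprod(a, dim=2)   # [0.9, 0.72, 0.36] (inclusive product)
T = torch.roll(T, 1, 2)       # [0.36, 0.9, 0.72] (last element wraps to the front)
T[..., 0, 0] = 1              # [1.0, 0.9, 0.72]  (exclusive product)
print(T.flatten())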

The max_delta padding exists because the last interval has no next sample to difference against. With a huge \(\delta_N\), the last alpha \(1-e^{-\delta_N\sigma_N}\) saturates to 1 whenever \(\sigma_N \gt 0\), so whatever transmittance remains on the ray is absorbed by the last sample (and if \(\sigma_N = 0\) the term still contributes nothing).
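
The saturation is easy to see numerically (my sketch):

import torch

sigma_last = torch.tensor([0.0, 1e-4, 0.5])
print(1 - torch.exp(-sigma_last * 1e10))  # tensor([0., 1., 1.])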
