Pytorch的Normal
pytorch的normal,从来没有用过,稍微整理下,导包:
from torch.distributions import Normal
normal.sample()
normal.rsample()
rsample()不是在定义的正太分布上采样,而是先对标准正太分布N(0,1)进行采样,然后输出:
mean+std×采样值
normal.log_prob(c)
log_prob(value)是计算value在定义的正太分布中对应的概率的对数