Tensorflow Probability中Categorical
简介
TensorFlow Probability 是 TensorFlow 中用于概率推理和统计分析的库。
安装
安装最新版本的 TensorFlow Probability:
pip install --upgrade tensorflow-probability
安装指定版本的 TensorFlow Probability:
pip install tensorflow-probability==版本号
有关 TensorFlow 和 TensorFlow Probability 之间的版本对应关系,请参阅 TFP 版本说明。
使用
这里仅介绍我常用的一个根据概率分布采样的功能,其余功能参考官方文档。
Categorical类
用途:创建一个用于表示不同种类概率分布的对象。
import tensorflow_probability as tfp
dist = tfp.distributions.Categorical(
logits=None, # 传入的是logits的分布(未经过sotfmax)
probs=None, # 传入的是概率分布
dtype=tf.int32, # 种类的数据类型
validate_args=False,
allow_nan_stats=True,
name='Categorical'
)
类的属性和方法
dist.probs:
得到传入的probsdist.logits:
得到传入的logitsdist.prob(value):
返回某个种类的概率dist.log_prob(value):
返回某个种类的概率的logdist.sample(sample_shape=(), seed=None, name='sample', **kwargs):
按probs的分布采样种类
举例
>>> import tensorflow as tf
>>> import tensorflow_probability as tfp
>>> dist = tfp.distributions.Categorical(probs=[0.1, 0.2, 0.7], dtype='float32')
>>> print(dist.probs)
tf.Tensor([0.1 0.2 0.7], shape=(3,), dtype=float32)
>>> print(dist.logits)
None
>>> dist.sample()
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>
>>> dist.log_prob(0) # 计算种类0对应的prob的log,即log(0.1)
<tf.Tensor: shape=(), dtype=float32, numpy=-2.3025851>
>>> tf.math.log(0.1) # 结果和上面一样
<tf.Tensor: shape=(), dtype=float32, numpy=-2.3025851>
>>> dist2 = tfp.distributions.Categorical(logits=[0.1, 0.2, 0.3], dtype='float32')
# 传入logits在执行prob()时,会自动对其作sotfmax操作
>>> dist2.prob(0)
<tf.Tensor: shape=(), dtype=float32, numpy=0.3006096>
>>> dist2.prob(1)
<tf.Tensor: shape=(), dtype=float32, numpy=0.33222497>
>>> dist2.prob(2)
<tf.Tensor: shape=(), dtype=float32, numpy=0.3671654>
>>> tf.nn.softmax([0.1, 0.2, 0.3]) # 结果和上面一样
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.3006096, 0.332225 , 0.3671654], dtype=float32)>