Tensorflow Probability中Categorical

简介

TensorFlow Probability 是 TensorFlow 中用于概率推理和统计分析的库。

安装

安装最新版本的 TensorFlow Probability:

 pip install --upgrade tensorflow-probability

安装指定版本的 TensorFlow Probability:

 pip install tensorflow-probability==版本号

有关 TensorFlow 和 TensorFlow Probability 之间的版本对应关系,请参阅 TFP 版本说明

使用

这里仅介绍我常用的一个根据概率分布采样的功能,其余功能参考官方文档

Categorical类

用途:创建一个用于表示不同种类概率分布的对象。

import tensorflow_probability as tfp

dist = tfp.distributions.Categorical(
    						logits=None,  # 传入的是logits的分布(未经过sotfmax)
  							probs=None,  # 传入的是概率分布
  							dtype=tf.int32,  # 种类的数据类型
  							validate_args=False,
    						allow_nan_stats=True, 
  							name='Categorical'
)

类的属性和方法

  • dist.probs:得到传入的probs
  • dist.logits:得到传入的logits
  • dist.prob(value):返回某个种类的概率
  • dist.log_prob(value):返回某个种类的概率的log
  • dist.sample(sample_shape=(), seed=None, name='sample', **kwargs):按probs的分布采样种类

举例

>>> import tensorflow as tf
>>> import tensorflow_probability as tfp
>>> dist = tfp.distributions.Categorical(probs=[0.1, 0.2, 0.7], dtype='float32')

>>> print(dist.probs)
tf.Tensor([0.1 0.2 0.7], shape=(3,), dtype=float32)

>>> print(dist.logits)
None

>>> dist.sample()
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>

>>> dist.log_prob(0)  # 计算种类0对应的prob的log,即log(0.1)
<tf.Tensor: shape=(), dtype=float32, numpy=-2.3025851>

>>> tf.math.log(0.1)  # 结果和上面一样
<tf.Tensor: shape=(), dtype=float32, numpy=-2.3025851>
>>> dist2 = tfp.distributions.Categorical(logits=[0.1, 0.2, 0.3], dtype='float32')
# 传入logits在执行prob()时,会自动对其作sotfmax操作
>>> dist2.prob(0)
<tf.Tensor: shape=(), dtype=float32, numpy=0.3006096>

>>> dist2.prob(1)
<tf.Tensor: shape=(), dtype=float32, numpy=0.33222497>

>>> dist2.prob(2)
<tf.Tensor: shape=(), dtype=float32, numpy=0.3671654>

>>> tf.nn.softmax([0.1, 0.2, 0.3])  # 结果和上面一样
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.3006096, 0.332225 , 0.3671654], dtype=float32)>

参考

https://tensorflow.google.cn/probability/api_docs/python/tfp/distributions/Categorical?skip_cache=true#attributes

posted @ 2021-01-11 15:12  火锅先生  阅读(1898)  评论(0编辑  收藏  举报