熵编码实现

 1     def compress(self, x):
 2         y = self.g_a(x)
 3         y_strings = self.entropy_bottleneck.compress(y)
 4         return {"strings": [y_strings], "shape": y.size()[-2:]}
 5 
 6     def decompress(self, strings, shape):
 7         assert isinstance(strings, list) and len(strings) == 1
 8         y_hat = self.entropy_bottleneck.decompress(strings[0], shape)
 9         x_hat = self.g_s(y_hat).clamp_(0, 1)
10         return {"x_hat": x_hat}
class EntropyBottleneck(EntropyModel):
    """Entropy bottleneck (excerpt: only ``compress`` and ``decompress``).

    Both methods build per-channel CDF indexes and pass the per-channel
    medians from ``self._get_medians()`` to the base-class ``compress`` /
    ``decompress`` together with those indexes.
    """

    def compress(self, x):
        """Entropy-code ``x`` via the base-class coder.

        Args:
            x: tensor assumed to be laid out as (batch, channels, *spatial)
                -- TODO confirm; dim 0 is used as the batch size and
                ``x.dim() - 2`` as the number of spatial dimensions.

        Returns:
            Whatever ``EntropyModel.compress`` returns; ``decompress`` below
            treats ``len(strings)`` as the batch size.
        """
        indexes = self._build_indexes(x.size())
        medians = self._get_medians().detach()
        spatial_dims = len(x.size()) - 2
        # Add trailing singleton spatial dims so the medians broadcast.
        medians = self._extend_ndims(medians, spatial_dims)
        # Repeat over the batch dim; -1 keeps channel and spatial dims as-is.
        medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decode ``strings`` into a tensor of shape ``(len(strings), C, *size)``.

        ``C`` is ``self._quantized_cdf.size(0)``, i.e. one channel per
        quantized CDF. The output dtype matches the medians' dtype.

        Args:
            strings: encoded strings, one per batch element.
            size: spatial size of the decoded tensor.
        """
        output_size = (len(strings), self._quantized_cdf.size(0), *size)
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
        return super().decompress(strings, indexes, medians.dtype, medians)

 

class _EntropyCoder:
    """Entropy coder backend (body omitted in this excerpt).

    Per the accompanying notes, the underlying implementation typically uses
    asymmetric numeral systems (ANS) or range coding, with symbols encoded /
    decoded together with indexes that select the quantized CDF to use.
    """

底层代码中常用的熵编码算法是非对称数系编码（ANS）和区间编码（range coding）。
编码/解码时，根据每个符号的索引（index）选择对应的量化 CDF 表进行编码/解码。
EntropyModel 基类负责将概率质量函数（PMF）转换为量化的累积分布函数（CDF），并定义了一个占位方法，由子类提供具体实现。

class EntropyModel(nn.Module):
    """Base class of the entropy models in this file (body omitted here).

    ``EntropyBottleneck`` and ``GaussianConditional`` both subclass it. Per the
    accompanying notes, it converts probability mass functions into quantized
    CDFs (``GaussianConditional.update`` calls ``self._pmf_to_cdf``) and
    declares a placeholder method that subclasses implement.
    """

 

  1 class GaussianConditional(EntropyModel):     #高斯条件模型
  2     r"""Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh,
  3     S. J. Hwang, N. Johnston, in `"Variational image compression with a scale
  4     hyperprior" <https://arxiv.org/abs/1802.01436>`_.
  5 
  6     This is a re-implementation of the Gaussian conditional layer in
  7     *tensorflow/compression*. See the `tensorflow documentation
  8     <https://github.com/tensorflow/compression/blob/v1.3/docs/api_docs/python/tfc/GaussianConditional.md>`__
  9     for more information.
 10     """
 11 
 12     def __init__(
 13         self,
 14         scale_table: Optional[Union[List, Tuple]],     #标准差的table列表或元组
 15         *args: Any,
 16         scale_bound: float = 0.11,
 17         tail_mass: float = 1e-9,
 18         **kwargs: Any,
 19     ):
 20         super().__init__(*args, **kwargs)
 21 
 22         if not isinstance(scale_table, (type(None), list, tuple)):
 23             raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"')
 24 
 25         if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1:
 26             raise ValueError(f'Invalid scale_table length "{len(scale_table)}"')
 27 
 28         if scale_table and (
 29             scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table)
 30         ):
 31             raise ValueError(f'Invalid scale_table "({scale_table})"')
 32 
 33         self.tail_mass = float(tail_mass)
 34         if scale_bound is None and scale_table:
 35             scale_bound = self.scale_table[0]
 36         if scale_bound <= 0:
 37             raise ValueError("Invalid parameters")
 38         self.lower_bound_scale = LowerBound(scale_bound)
 39 
 40         self.register_buffer(
 41             "scale_table",
 42             self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(),
 43         )
 44 
 45         self.register_buffer(
 46             "scale_bound",
 47             torch.Tensor([float(scale_bound)]) if scale_bound is not None else None,
 48         )
 49 
 50     @staticmethod
 51     def _prepare_scale_table(scale_table):
 52         return torch.Tensor(tuple(float(s) for s in scale_table))
 53 
 54     def _standardized_cumulative(self, inputs: Tensor) -> Tensor:
 55         half = float(0.5)
 56         const = float(-(2**-0.5))
 57         # Using the complementary error function maximizes numerical precision.
 58         return half * torch.erfc(const * inputs)
 59 
 60     @staticmethod
 61     def _standardized_quantile(quantile):
 62         return scipy.stats.norm.ppf(quantile)
 63 
 64     def update_scale_table(self, scale_table, force=False):
 65         # Check if we need to update the gaussian conditional parameters, the
 66         # offsets are only computed and stored when the conditonal model is
 67         # updated.
 68         if self._offset.numel() > 0 and not force:
 69             return False
 70         device = self.scale_table.device
 71         self.scale_table = self._prepare_scale_table(scale_table).to(device)
 72         self.update()
 73         return True
 74 
 75     def update(self):
 76         multiplier = -self._standardized_quantile(self.tail_mass / 2)
 77         pmf_center = torch.ceil(self.scale_table * multiplier).int()
 78         pmf_length = 2 * pmf_center + 1
 79         max_length = torch.max(pmf_length).item()
 80 
 81         device = pmf_center.device
 82         samples = torch.abs(
 83             torch.arange(max_length, device=device).int() - pmf_center[:, None]
 84         )
 85         samples_scale = self.scale_table.unsqueeze(1)
 86         samples = samples.float()
 87         samples_scale = samples_scale.float()
 88         upper = self._standardized_cumulative((0.5 - samples) / samples_scale)
 89         lower = self._standardized_cumulative((-0.5 - samples) / samples_scale)
 90         pmf = upper - lower
 91 
 92         tail_mass = 2 * lower[:, :1]
 93 
 94         quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2)
 95         quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
 96         self._quantized_cdf = quantized_cdf
 97         self._offset = -pmf_center
 98         self._cdf_length = pmf_length + 2
 99 
100     def _likelihood(
101         self, inputs: Tensor, scales: Tensor, means: Optional[Tensor] = None
102     ) -> Tensor:
103         half = float(0.5)
104 
105         if means is not None:
106             values = inputs - means
107         else:
108             values = inputs
109 
110         scales = self.lower_bound_scale(scales)
111 
112         values = torch.abs(values)
113         upper = self._standardized_cumulative((half - values) / scales)
114         lower = self._standardized_cumulative((-half - values) / scales)
115         likelihood = upper - lower
116 
117         return likelihood
118 
119     def forward(
120         self,
121         inputs: Tensor,
122         scales: Tensor,
123         means: Optional[Tensor] = None,
124         training: Optional[bool] = None,
125     ) -> Tuple[Tensor, Tensor]:
126         if training is None:
127             training = self.training
128         outputs = self.quantize(inputs, "noise" if training else "dequantize", means)
129         likelihood = self._likelihood(outputs, scales, means)
130         if self.use_likelihood_bound:
131             likelihood = self.likelihood_lower_bound(likelihood)
132         return outputs, likelihood
133 
134     def build_indexes(self, scales: Tensor) -> Tensor:
135         scales = self.lower_bound_scale(scales)
136         indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int()
137         for s in self.scale_table[:-1]:
138             indexes -= (scales <= s).int()
139         return indexes

 

posted @ 2023-12-01 20:22  浪矢-CL  阅读(124)  评论(0)  编辑  收藏  举报