Entropy Coding Implementation
The model-level entry points simply wrap the analysis/synthesis transforms around the entropy bottleneck:

```python
def compress(self, x):
    y = self.g_a(x)                                  # analysis transform: image -> latent
    y_strings = self.entropy_bottleneck.compress(y)  # entropy-encode the latent
    return {"strings": [y_strings], "shape": y.size()[-2:]}

def decompress(self, strings, shape):
    assert isinstance(strings, list) and len(strings) == 1
    y_hat = self.entropy_bottleneck.decompress(strings[0], shape)
    x_hat = self.g_s(y_hat).clamp_(0, 1)             # synthesis transform: latent -> image
    return {"x_hat": x_hat}
```
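For orientation, here is a minimal round-trip sketch driving these two methods, assuming a pretrained factorized-prior model from `compressai.zoo` (the model choice, quality level, and dummy input are illustrative, not part of the quoted code):

```python
import torch
from compressai.zoo import bmshj2018_factorized

# Load a pretrained factorized-prior model (quality level is illustrative).
net = bmshj2018_factorized(quality=4, pretrained=True).eval()
net.update()  # make sure the quantized CDF tables are built before coding

x = torch.rand(1, 3, 256, 256)  # dummy image in [0, 1]
with torch.no_grad():
    out = net.compress(x)                               # {"strings": [...], "shape": (h, w)}
    rec = net.decompress(out["strings"], out["shape"])  # {"x_hat": ...}

num_bytes = sum(len(s) for s in out["strings"][0])
print(f"compressed size: {num_bytes} bytes, x_hat shape: {rec['x_hat'].shape}")
```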
`EntropyBottleneck` prepares per-channel indexes and medians before delegating to the shared `EntropyModel` codec:

```python
class EntropyBottleneck(EntropyModel):
    def compress(self, x):
        indexes = self._build_indexes(x.size())    # one CDF row per channel
        medians = self._get_medians().detach()     # per-channel quantization offsets
        spatial_dims = len(x.size()) - 2
        medians = self._extend_ndims(medians, spatial_dims)
        medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        output_size = (len(strings), self._quantized_cdf.size(0), *size)
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
        return super().decompress(strings, indexes, medians.dtype, medians)
```
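The `_build_indexes` helper is not shown above. Conceptually, the factorized bottleneck learns one distribution per channel, so every element of channel c is coded with CDF row c. A minimal sketch of that idea (a simplified reimplementation, not the library source):

```python
import torch

def build_channel_indexes(size):
    # size = (N, C, H, W): every element in channel c uses CDF row c.
    n, c, *spatial = size
    indexes = torch.arange(c).view(1, -1, *([1] * len(spatial)))
    return indexes.expand(n, -1, *spatial).int()

print(build_channel_indexes((2, 3, 4, 4)).shape)  # torch.Size([2, 3, 4, 4])
```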
class _EntropyCoder:
A thin proxy around the actual low-level entropy coder; the backends commonly used underneath are asymmetric numeral systems (ANS) coding and range coding.
Encoding and decoding are then driven by an `indexes` tensor, which selects for each symbol the quantized CDF row it is coded with.
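A rough sketch of the proxy idea follows (hypothetical class name; the real `_EntropyCoder` also supports an optional range-coder backend, and the delegation below assumes CompressAI's bundled rANS codec in `compressai.ans`):

```python
class EntropyCoderProxy:
    """Hypothetical sketch: forward encode/decode calls to a chosen backend."""

    def __init__(self, method: str = "ans"):
        if method == "ans":
            # CompressAI ships a bundled rANS codec (compressai.ans).
            from compressai import ans

            self._encoder = ans.RansEncoder()
            self._decoder = ans.RansDecoder()
        else:
            raise ValueError(f"Unsupported entropy coder: {method}")

    def encode_with_indexes(self, *args, **kwargs):
        # Delegate to the backend: symbols + indexes + CDF tables -> bytestring.
        return self._encoder.encode_with_indexes(*args, **kwargs)

    def decode_with_indexes(self, *args, **kwargs):
        # Delegate to the backend: bytestring + indexes + CDF tables -> symbols.
        return self._decoder.decode_with_indexes(*args, **kwargs)
```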
class EntropyModel(nn.Module):
The base entropy-model class converts a probability mass function (PMF) into a quantized cumulative distribution function (CDF), and defines placeholder methods that subclasses are expected to override with concrete implementations.
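The PMF-to-CDF conversion amounts to renormalizing the PMF (plus tail mass) into integer cumulative frequencies at a fixed precision, which is what a rANS/range coder consumes. A minimal sketch of the idea, assuming 16-bit precision (the library performs this step in native code and additionally guards against zero-frequency symbols):

```python
import torch

def pmf_to_quantized_cdf_sketch(pmf: torch.Tensor, precision: int = 16) -> torch.Tensor:
    # cdf[0] = 0 and cdf[-1] = 2**precision; intermediate entries are
    # cumulative integer frequencies proportional to the pmf.
    cdf = torch.zeros(len(pmf) + 1, dtype=torch.int32)
    cdf[1:] = torch.cumsum(torch.round(pmf * (2**precision)).int(), dim=0)
    cdf[-1] = 2**precision  # force the exact total mass after rounding
    return cdf

pmf = torch.tensor([0.1, 0.4, 0.4, 0.1])
print(pmf_to_quantized_cdf_sketch(pmf))
# tensor([    0,  6554, 32768, 58982, 65536], dtype=torch.int32)
```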
```python
class GaussianConditional(EntropyModel):  # Gaussian conditional model
    r"""Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh,
    S. J. Hwang, N. Johnston, in `"Variational image compression with a scale
    hyperprior" <https://arxiv.org/abs/1802.01436>`_.

    This is a re-implementation of the Gaussian conditional layer in
    *tensorflow/compression*. See the `tensorflow documentation
    <https://github.com/tensorflow/compression/blob/v1.3/docs/api_docs/python/tfc/GaussianConditional.md>`__
    for more information.
    """

    def __init__(
        self,
        scale_table: Optional[Union[List, Tuple]],  # list or tuple of standard deviations
        *args: Any,
        scale_bound: float = 0.11,
        tail_mass: float = 1e-9,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)

        if not isinstance(scale_table, (type(None), list, tuple)):
            raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"')

        if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1:
            raise ValueError(f'Invalid scale_table length "{len(scale_table)}"')

        if scale_table and (
            scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table)
        ):
            raise ValueError(f'Invalid scale_table "({scale_table})"')

        self.tail_mass = float(tail_mass)
        if scale_bound is None and scale_table:
            scale_bound = self.scale_table[0]
        if scale_bound <= 0:
            raise ValueError("Invalid parameters")
        self.lower_bound_scale = LowerBound(scale_bound)

        self.register_buffer(
            "scale_table",
            self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(),
        )

        self.register_buffer(
            "scale_bound",
            torch.Tensor([float(scale_bound)]) if scale_bound is not None else None,
        )

    @staticmethod
    def _prepare_scale_table(scale_table):
        return torch.Tensor(tuple(float(s) for s in scale_table))

    def _standardized_cumulative(self, inputs: Tensor) -> Tensor:
        half = float(0.5)
        const = float(-(2**-0.5))
        # Using the complementary error function maximizes numerical precision.
        return half * torch.erfc(const * inputs)

    @staticmethod
    def _standardized_quantile(quantile):
        return scipy.stats.norm.ppf(quantile)

    def update_scale_table(self, scale_table, force=False):
        # Check if we need to update the gaussian conditional parameters, the
        # offsets are only computed and stored when the conditional model is
        # updated.
        if self._offset.numel() > 0 and not force:
            return False
        device = self.scale_table.device
        self.scale_table = self._prepare_scale_table(scale_table).to(device)
        self.update()
        return True

    def update(self):
        # Sample each Gaussian's PMF on an integer grid wide enough to keep
        # all but tail_mass of the probability, one row per table scale.
        multiplier = -self._standardized_quantile(self.tail_mass / 2)
        pmf_center = torch.ceil(self.scale_table * multiplier).int()
        pmf_length = 2 * pmf_center + 1
        max_length = torch.max(pmf_length).item()

        device = pmf_center.device
        samples = torch.abs(
            torch.arange(max_length, device=device).int() - pmf_center[:, None]
        )
        samples_scale = self.scale_table.unsqueeze(1)
        samples = samples.float()
        samples_scale = samples_scale.float()
        upper = self._standardized_cumulative((0.5 - samples) / samples_scale)
        lower = self._standardized_cumulative((-0.5 - samples) / samples_scale)
        pmf = upper - lower

        tail_mass = 2 * lower[:, :1]

        quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2)
        quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
        self._quantized_cdf = quantized_cdf
        self._offset = -pmf_center
        self._cdf_length = pmf_length + 2

    def _likelihood(
        self, inputs: Tensor, scales: Tensor, means: Optional[Tensor] = None
    ) -> Tensor:
        # Probability of the interval [values - 0.5, values + 0.5] under
        # a zero-mean Gaussian with the given scales.
        half = float(0.5)

        if means is not None:
            values = inputs - means
        else:
            values = inputs

        scales = self.lower_bound_scale(scales)

        values = torch.abs(values)
        upper = self._standardized_cumulative((half - values) / scales)
        lower = self._standardized_cumulative((-half - values) / scales)
        likelihood = upper - lower

        return likelihood

    def forward(
        self,
        inputs: Tensor,
        scales: Tensor,
        means: Optional[Tensor] = None,
        training: Optional[bool] = None,
    ) -> Tuple[Tensor, Tensor]:
        if training is None:
            training = self.training
        outputs = self.quantize(inputs, "noise" if training else "dequantize", means)
        likelihood = self._likelihood(outputs, scales, means)
        if self.use_likelihood_bound:
            likelihood = self.likelihood_lower_bound(likelihood)
        return outputs, likelihood

    def build_indexes(self, scales: Tensor) -> Tensor:
        # Map each scale to the index of the smallest table entry >= scale.
        scales = self.lower_bound_scale(scales)
        indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int()
        for s in self.scale_table[:-1]:
            indexes -= (scales <= s).int()
        return indexes
```
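To see how this layer is typically driven, the sketch below maps hyperprior-predicted scales to table indexes and estimates the rate. It assumes the log-spaced scale table commonly used in CompressAI models (0.11 to 256 over 64 levels); the latent shapes and random inputs are illustrative:

```python
import math
import torch
from compressai.entropy_models import GaussianConditional

# The scale table commonly used in CompressAI models: 64 log-spaced scales.
scale_table = torch.exp(torch.linspace(math.log(0.11), math.log(256), 64)).tolist()

gaussian_conditional = GaussianConditional(scale_table)
gaussian_conditional.update()  # build the quantized CDFs from the scale table

y = torch.randn(1, 192, 16, 16)                # latent to be coded
scales_hat = torch.rand(1, 192, 16, 16) + 0.1  # scales predicted by a hyperprior

indexes = gaussian_conditional.build_indexes(scales_hat)     # CDF row per element
y_hat, y_likelihoods = gaussian_conditional(y, scales_hat)   # noisy/quantized + likelihoods
bits = (-torch.log2(y_likelihoods)).sum()
print(indexes.shape, f"estimated rate: {bits.item():.0f} bits")
```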