diff --git a/convert.py b/convert.py index e532a92..a90c7e6 100644 --- a/convert.py +++ b/convert.py @@ -102,32 +102,43 @@ class ClusterPalette: int(np.round(palette_upper)))) return palette_ranges + def _dither_image(self, palettes_cam, penalty): + # Suppress divide by zero warning, + # https://github.com/colour-science/colour/issues/900 + with colour.utilities.suppress_warnings(python_warnings=True): + palettes_linear_rgb = colour.convert( + palettes_cam, "CAM16UCS", "RGB").astype(np.float32) + + output_4bit, line_to_palette, total_image_error = \ + dither_pyx.dither_shr( + self._image_rgb, palettes_cam, palettes_linear_rgb, + self._rgb24_to_cam16ucs, float(penalty)) + + return (output_4bit, line_to_palette, palettes_linear_rgb, + total_image_error) + def iterate(self, penalty: float, max_iterations: int): iterations_since_improvement = 0 total_image_error = 1e9 last_good_splits = self._palette_splits + (output_4bit, line_to_palette, palettes_linear_rgb, + new_total_image_error) = self._dither_image( + self._palettes_cam, penalty) while iterations_since_improvement < max_iterations: + self._reassign_unused_palettes(line_to_palette, + last_good_splits) + # print("Iterations %d" % iterations_since_improvement) new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = ( self._propose_palettes()) - # Suppress divide by zero warning, - # https://github.com/colour-science/colour/issues/900 - with colour.utilities.suppress_warnings(python_warnings=True): - new_palettes_linear_rgb = colour.convert( - new_palettes_cam, "CAM16UCS", "RGB").astype(np.float32) - # Recompute image with proposed palettes and check whether it has # lower total image error than our previous best. - new_output_4bit, new_line_to_palette, new_total_image_error = \ - dither_pyx.dither_shr( - self._image_rgb, new_palettes_cam, new_palettes_linear_rgb, - self._rgb24_to_cam16ucs, float(penalty)) - - self._reassign_unused_palettes(new_line_to_palette, - last_good_splits) + (output_4bit, line_to_palette, palettes_linear_rgb, + new_total_image_error) = self._dither_image( + new_palettes_cam, penalty) if new_total_image_error >= total_image_error: iterations_since_improvement += 1 @@ -143,8 +154,8 @@ class ClusterPalette: self._errors = new_palette_errors self._palettes_accepted = True - yield (new_total_image_error, new_output_4bit, new_line_to_palette, - new_palettes_rgb12_iigs, new_palettes_linear_rgb) + yield (new_total_image_error, output_4bit, line_to_palette, + new_palettes_rgb12_iigs, palettes_linear_rgb) def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]: """Attempt to find new palettes that locally improve image quality.