diff --git a/convert.py b/convert.py index 23101f5..3dacec0 100644 --- a/convert.py +++ b/convert.py @@ -28,7 +28,9 @@ import screen as screen_py class ClusterPalette: def __init__( - self, image: Image, rgb12_iigs_to_cam16ucs, reserved_colours=0): + self, image: Image, rgb12_iigs_to_cam16ucs, rgb24_to_cam16ucs, + reserved_colours=0): + self._image_rgb = image self._colours_cam = self._image_colours_cam(image) self._errors = [1e9] * 16 @@ -55,6 +57,8 @@ class ClusterPalette: # colour space self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs + self._rgb24_to_cam16ucs = rgb24_to_cam16ucs + # List of line ranges used to train the 16 SHR palettes # [(lower_0, upper_0), ...] self._palette_splits = self._palette_splits() @@ -68,6 +72,75 @@ class ClusterPalette: # Delta applied to palette split in previous iteration self._palette_mutate_delta = (0, 0) + def iterate(self, penalty: float, max_iterations: int): + iterations_since_improvement = 0 + total_image_error = 1e9 + + last_good_splits = self._palette_splits + + while iterations_since_improvement < max_iterations: + # print("Iterations %d" % iterations_since_improvement) + new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = ( + self._propose_palettes()) + + # Suppress divide by zero warning, + # https://github.com/colour-science/colour/issues/900 + with colour.utilities.suppress_warnings(python_warnings=True): + new_palettes_linear_rgb = colour.convert( + new_palettes_cam, "CAM16UCS", "RGB").astype(np.float32) + + # Recompute image with proposed palettes and check whether it has + # lower total image error than our previous best. + new_output_4bit, new_line_to_palette, new_total_image_error = \ + dither_pyx.dither_shr( + self._image_rgb, new_palettes_cam, new_palettes_linear_rgb, + self._rgb24_to_cam16ucs, float(penalty)) + + # print(total_image_error, new_total_image_error, + # self._palette_splits) + + # TODO: extract this into a function + palettes_used = [False] * 16 + for palette in new_line_to_palette: + palettes_used[palette] = True + for palette_idx, palette_used in enumerate(palettes_used): + if palette_used: + continue + print("Reassigning palette %d" % palette_idx) + max_width = 0 + split_palette_idx = -1 + idx = 0 + for lower, upper in last_good_splits: + width = upper - lower + if width > max_width: + split_palette_idx = idx + idx += 1 + + lower, upper = last_good_splits[split_palette_idx] + if upper - lower > 20: + mid = (lower + upper) // 2 + self._palette_splits[split_palette_idx] = ( + lower, mid - 1) + self._palette_splits[palette_idx] = (mid, upper) + else: + lower = np.random.randint(0, 199) + upper = np.random.randint(lower, 200) + self._palette_splits[palette_idx] = (lower, upper) + + if new_total_image_error >= total_image_error: + iterations_since_improvement += 1 + continue + + # We found a globally better set of palettes + iterations_since_improvement = 0 + last_good_splits = self._palette_splits + total_image_error = new_total_image_error + self._accept_palettes( + new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors) + + yield (new_total_image_error, new_output_4bit, new_line_to_palette, + new_palettes_rgb12_iigs, new_palettes_linear_rgb) + def _image_colours_cam(self, image: Image): colours_rgb = np.asarray(image).reshape((-1, 3)) with colour.utilities.suppress_warnings(colour_usage_warnings=True): @@ -163,7 +236,7 @@ class ClusterPalette: self._apply_palette_delta(palette_to_mutate, palette_lower_delta, palette_upper_delta) - def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]: + def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]: """Attempt to find new palettes that locally improve image quality. Re-fit a set of 16 palettes from (overlapping) line ranges of the @@ -226,7 +299,7 @@ class ClusterPalette: self._palettes_accepted = False return new_palettes_cam, new_palettes_rgb12_iigs, new_errors - def accept_palettes( + def _accept_palettes( self, new_palettes_cam: np.ndarray, new_palettes_rgb: np.ndarray, new_errors: List[float]): self._palettes_cam = np.copy(new_palettes_cam) @@ -297,7 +370,7 @@ def main(): # TODO: flags penalty = 1 # 1e18 # TODO: is this needed any more? - iterations = 200 + iterations = 20# 0 pygame.init() # TODO: for some reason I need to execute this twice - the first time @@ -307,83 +380,22 @@ def main(): canvas.fill((0, 0, 0)) pygame.display.flip() - total_image_error = 1e9 - iterations_since_improvement = 0 - + total_image_error = None # TODO: reserved_colours should be a flag cluster_palette = ClusterPalette( - rgb, reserved_colours=1, rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs) - last_good_splits = cluster_palette._palette_splits + rgb, reserved_colours=1, + rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs, + rgb24_to_cam16ucs=rgb24_to_cam16ucs) - while iterations_since_improvement < iterations: - # print("Iterations %d" % iterations_since_improvement) - new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = ( - cluster_palette.propose_palettes()) + for (new_total_image_error, output_4bit, line_to_palette, + palettes_rgb12_iigs, palettes_linear_rgb) in cluster_palette.iterate( + penalty, iterations): - # Suppress divide by zero warning, - # https://github.com/colour-science/colour/issues/900 - with colour.utilities.suppress_warnings(python_warnings=True): - new_palettes_linear_rgb = colour.convert( - new_palettes_cam, "CAM16UCS", "RGB").astype(np.float32) - - # Recompute image with proposed palettes and check whether it has - # lower total image error than our previous best. - new_output_4bit, new_line_to_palette, new_total_image_error = \ - dither_pyx.dither_shr( - rgb, new_palettes_cam, new_palettes_linear_rgb, - rgb24_to_cam16ucs, float(penalty)) - - # print(total_image_error, new_total_image_error, - # cluster_palette._palette_splits) - - # TODO: move this into ClusterPalettes - palettes_used = [False] * 16 - for palette in new_line_to_palette: - palettes_used[palette] = True - for palette_idx, palette_used in enumerate(palettes_used): - if palette_used: - continue - print("Reassigning palette %d" % palette_idx) - max_width = 0 - split_palette_idx = -1 - idx = 0 - for lower, upper in last_good_splits: - width = upper - lower - if width > max_width: - split_palette_idx = idx - idx += 1 - - lower, upper = last_good_splits[split_palette_idx] - if upper - lower > 20: - mid = (lower + upper) // 2 - cluster_palette._palette_splits[split_palette_idx] = ( - lower, mid - 1) - cluster_palette._palette_splits[palette_idx] = (mid, upper) - else: - lower = np.random.randint(0, 199) - upper = np.random.randint(lower, 200) - cluster_palette._palette_splits[palette_idx] = (lower, upper) - - if new_total_image_error >= total_image_error: - iterations_since_improvement += 1 - continue - - # We found a globally better set of palettes - iterations_since_improvement = 0 - cluster_palette.accept_palettes( - new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors) - last_good_splits = cluster_palette._palette_splits - - if total_image_error < 1e9: + if total_image_error is not None: print("Improved quality +%f%% (%f)" % ( (1 - new_total_image_error / total_image_error) * 100, new_total_image_error)) - # print(cluster_palette._palette_splits) - output_4bit = new_output_4bit - line_to_palette = new_line_to_palette total_image_error = new_total_image_error - palettes_rgb12_iigs = new_palettes_rgb12_iigs - palettes_linear_rgb = new_palettes_linear_rgb for i in range(16): screen.set_palette(i, palettes_rgb12_iigs[i, :, :])