From c78f731cd75ad131f6c16b636922f2319acbd462 Mon Sep 17 00:00:00 2001 From: kris Date: Tue, 23 Nov 2021 14:55:45 +0000 Subject: [PATCH] Refactor --- convert.py | 81 +++++++++++++++++++++++++----------------------------- 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/convert.py b/convert.py index 3dacec0..c91c180 100644 --- a/convert.py +++ b/convert.py @@ -96,36 +96,8 @@ class ClusterPalette: self._image_rgb, new_palettes_cam, new_palettes_linear_rgb, self._rgb24_to_cam16ucs, float(penalty)) - # print(total_image_error, new_total_image_error, - # self._palette_splits) - - # TODO: extract this into a function - palettes_used = [False] * 16 - for palette in new_line_to_palette: - palettes_used[palette] = True - for palette_idx, palette_used in enumerate(palettes_used): - if palette_used: - continue - print("Reassigning palette %d" % palette_idx) - max_width = 0 - split_palette_idx = -1 - idx = 0 - for lower, upper in last_good_splits: - width = upper - lower - if width > max_width: - split_palette_idx = idx - idx += 1 - - lower, upper = last_good_splits[split_palette_idx] - if upper - lower > 20: - mid = (lower + upper) // 2 - self._palette_splits[split_palette_idx] = ( - lower, mid - 1) - self._palette_splits[palette_idx] = (mid, upper) - else: - lower = np.random.randint(0, 199) - upper = np.random.randint(lower, 200) - self._palette_splits[palette_idx] = (lower, upper) + self._reassign_unused_palettes(new_line_to_palette, + last_good_splits) if new_total_image_error >= total_image_error: iterations_since_improvement += 1 @@ -135,8 +107,11 @@ class ClusterPalette: iterations_since_improvement = 0 last_good_splits = self._palette_splits total_image_error = new_total_image_error - self._accept_palettes( - new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors) + + self._palettes_cam = new_palettes_cam + self._palettes_rgb = new_palettes_rgb12_iigs + self._errors = new_palette_errors + self._palettes_accepted = True yield (new_total_image_error, new_output_4bit, new_line_to_palette, new_palettes_rgb12_iigs, new_palettes_linear_rgb) @@ -236,6 +211,34 @@ class ClusterPalette: self._apply_palette_delta(palette_to_mutate, palette_lower_delta, palette_upper_delta) + def _reassign_unused_palettes(self, new_line_to_palette, last_good_splits): + palettes_used = [False] * 16 + for palette in new_line_to_palette: + palettes_used[palette] = True + for palette_idx, palette_used in enumerate(palettes_used): + if palette_used: + continue + print("Reassigning palette %d" % palette_idx) + max_width = 0 + split_palette_idx = -1 + idx = 0 + for lower, upper in last_good_splits: + width = upper - lower + if width > max_width: + split_palette_idx = idx + idx += 1 + + lower, upper = last_good_splits[split_palette_idx] + if upper - lower > 20: + mid = (lower + upper) // 2 + self._palette_splits[split_palette_idx] = ( + lower, mid - 1) + self._palette_splits[palette_idx] = (mid, upper) + else: + lower = np.random.randint(0, 199) + upper = np.random.randint(lower, 200) + self._palette_splits[palette_idx] = (lower, upper) + def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]: """Attempt to find new palettes that locally improve image quality. @@ -262,8 +265,6 @@ class ClusterPalette: # individual palettes self._fit_global_palette() - dynamic_colours = 16 - self._reserved_colours - self._mutate_palette_splits() for palette_idx in range(16): palette_lower, palette_upper = self._palette_splits[palette_idx] @@ -299,14 +300,6 @@ class ClusterPalette: self._palettes_accepted = False return new_palettes_cam, new_palettes_rgb12_iigs, new_errors - def _accept_palettes( - self, new_palettes_cam: np.ndarray, - new_palettes_rgb: np.ndarray, new_errors: List[float]): - self._palettes_cam = np.copy(new_palettes_cam) - self._palettes_rgb = np.copy(new_palettes_rgb) - self._errors = list(new_errors) - self._palettes_accepted = True - def main(): parser = argparse.ArgumentParser() @@ -370,7 +363,7 @@ def main(): # TODO: flags penalty = 1 # 1e18 # TODO: is this needed any more? - iterations = 20# 0 + iterations = 200 pygame.init() # TODO: for some reason I need to execute this twice - the first time @@ -389,7 +382,7 @@ def main(): for (new_total_image_error, output_4bit, line_to_palette, palettes_rgb12_iigs, palettes_linear_rgb) in cluster_palette.iterate( - penalty, iterations): + penalty, iterations): if total_image_error is not None: print("Improved quality +%f%% (%f)" % (