diff --git a/convert.py b/convert.py index d964515..dd59e9c 100644 --- a/convert.py +++ b/convert.py @@ -30,13 +30,44 @@ class ClusterPalette: def __init__( self, image: Image, rgb12_iigs_to_cam16ucs, reserved_colours=0): self._colours_cam = self._image_colours_cam(image) - self._reserved_colours = reserved_colours + self._errors = [1e9] * 16 + + # We fit a 16-colour palette against the entire image which is used + # as starting values for fitting the 16 SHR palettes. This helps to + # provide better global consistency of colours across the palettes, + # e.g. for large blocks of colour. Otherwise these can take a while + # to converge. + self._global_palette = np.empty((16, 3), dtype=np.uint8) + + # How many image colours to fix identically across all 16 SHR + # palettes. These are taken to be the most prevalent colours from + # _global_palette. + self._reserved_colours = reserved_colours + + # 16 SHR palettes each of 16 colours, in CAM16UCS format self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32) + + # 16 SHR palettes each of 16 colours, in //gs 4-bit RGB format self._palettes_rgb = np.empty((16, 16, 3), dtype=np.uint8) - self._global_palette = np.empty((16, 16, 3), dtype=np.float32) + + # Conversion matrix from 12-bit //gs RGB colour space to CAM16UCS + # colour space self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs + # List of line ranges used to train the 16 SHR palettes + # [(lower_0, upper_0), ...] + self._palette_splits = self._palette_splits() + + # Whether the previous iteration of proposed palettes was accepted + self._palettes_accepted = False + + # Which palette index's line ranges did we mutate in previous iteration + self._palette_mutate_idx = 0 + + # Delta applied to palette split in previous iteration + self._palette_mutate_delta = (0, 0) + def _image_colours_cam(self, image: Image): colours_rgb = np.asarray(image).reshape((-1, 3)) with colour.utilities.suppress_warnings(colour_usage_warnings=True): @@ -60,10 +91,9 @@ class ClusterPalette: list(zip(*np.unique(labels, return_counts=True))), key=lambda kv: kv[1], reverse=True)] - res = dither_pyx.convert_cam16ucs_to_rgb12_iigs( + return dither_pyx.convert_cam16ucs_to_rgb12_iigs( clusters.cluster_centers_[frequency_order].astype( np.float32)) - return res def _palette_splits(self, palette_height=35): # The 16 palettes are striped across consecutive (overlapping) line @@ -74,7 +104,7 @@ class ClusterPalette: # has height H and overlaps the previous one by L lines, then the # boundaries are at lines: # (0, H), (H-L, 2H-L), (2H-2L, 3H-2L), ..., (15H-15L, 16H - 15L) - # i.e. 16H - 15L = 200, sofor a given palette height H we need to + # i.e. 16H - 15L = 200, so for a given palette height H we need to # overlap by: # L = (16H - 200)/15 @@ -86,9 +116,51 @@ class ClusterPalette: palette_upper = palette_lower + palette_height palette_ranges.append((int(np.round(palette_lower)), int(np.round(palette_upper)))) - # print(palette_ranges) return palette_ranges + def _apply_palette_delta( + self, palette_to_mutate, palette_lower_delta, palette_upper_delta): + old_lower, old_upper = self._palette_splits[palette_to_mutate] + new_lower = old_lower + palette_lower_delta + new_upper = old_upper + palette_upper_delta + + new_lower = np.clip(new_lower, 0, np.clip(new_upper, 1, 200) - 1) + new_upper = np.clip(new_upper, new_lower + 1, 200) + assert new_lower >= 0, new_upper-1 + + self._palette_splits[palette_to_mutate] = (new_lower, new_upper) + self._palette_mutate_idx = palette_to_mutate + self._palette_mutate_delta = (palette_lower_delta, palette_upper_delta) + + def _mutate_palette_splits(self): + if self._palettes_accepted: + # Last time was good, keep going + self._apply_palette_delta(self._palette_mutate_idx, + self._palette_mutate_delta[0], + self._palette_mutate_delta[1]) + else: + # undo last mutation + self._apply_palette_delta(self._palette_mutate_idx, + -self._palette_mutate_delta[0], + -self._palette_mutate_delta[1]) + + # Pick a palette endpoint to move up or down + palette_to_mutate = np.random.randint(0, 16) + while True: + if palette_to_mutate > 0: + palette_lower_delta = np.random.randint(-20, 21) + else: + palette_lower_delta = 0 + if palette_to_mutate < 15: + palette_upper_delta = np.random.randint(-20, 21) + else: + palette_upper_delta = 0 + if palette_lower_delta != 0 or palette_upper_delta != 0: + break + + self._apply_palette_delta(palette_to_mutate, palette_lower_delta, + palette_upper_delta) + def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]: """Attempt to find new palettes that locally improve image quality. @@ -117,9 +189,9 @@ class ClusterPalette: dynamic_colours = 16 - self._reserved_colours - palette_splits = self._palette_splits() + self._mutate_palette_splits() for palette_idx in range(16): - palette_lower, palette_upper = palette_splits[palette_idx] + palette_lower, palette_upper = self._palette_splits[palette_idx] # TODO: dynamically tune palette cuts palette_pixels = self._colours_cam[ palette_lower * 320:palette_upper * 320, :] @@ -149,6 +221,7 @@ class ClusterPalette: new_palettes_rgb12_iigs[palette_idx, :, :] = palettes_rgb12_iigs new_errors[palette_idx] = palette_error + self._palettes_accepted = False return new_palettes_cam, new_palettes_rgb12_iigs, new_errors def accept_palettes( @@ -157,6 +230,7 @@ class ClusterPalette: self._palettes_cam = np.copy(new_palettes_cam) self._palettes_rgb = np.copy(new_palettes_rgb) self._errors = list(new_errors) + self._palettes_accepted = True def main(): @@ -220,8 +294,8 @@ def main(): gamma=args.gamma_correct)).astype(np.float32) / 255 # TODO: flags - penalty = 1e9 - iterations = 50 + penalty = 1 # 1e18 # TODO: is this needed any more? + iterations = 200 pygame.init() # TODO: for some reason I need to execute this twice - the first time @@ -234,10 +308,13 @@ def main(): total_image_error = 1e9 iterations_since_improvement = 0 + # TODO: reserved_colours should be a flag cluster_palette = ClusterPalette( rgb, reserved_colours=1, rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs) + last_good_splits = cluster_palette._palette_splits while iterations_since_improvement < iterations: + # print("Iterations %d" % iterations_since_improvement) new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = ( cluster_palette.propose_palettes()) @@ -253,6 +330,37 @@ def main(): dither_pyx.dither_shr( rgb, new_palettes_cam, new_palettes_linear_rgb, rgb24_to_cam16ucs, float(penalty)) + + # print(total_image_error, new_total_image_error, + # cluster_palette._palette_splits) + + # TODO: move this into ClusterPalettes + palettes_used = [False] * 16 + for palette in new_line_to_palette: + palettes_used[palette] = True + for palette_idx, palette in enumerate(palettes_used): + if palette: + continue + print("Reassigning palette %d" % palette_idx) + max_width = 0 + split_palette_idx = -1 + idx = 0 + for lower, upper in last_good_splits: + width = upper - lower + if width > max_width: + split_palette_idx = idx + idx += 1 + + lower, upper = last_good_splits[split_palette_idx] + if upper - lower > 20: + mid = (lower + upper) // 2 + cluster_palette._palette_splits[split_palette_idx] = (lower, mid) + cluster_palette._palette_splits[palette_idx] = (mid, upper) + else: + lower = np.random.randint(0, 199) + upper = np.random.randint(lower, 200) + cluster_palette._palette_splits[palette_idx] = (lower, upper) + if new_total_image_error >= total_image_error: iterations_since_improvement += 1 continue @@ -261,11 +369,13 @@ def main(): iterations_since_improvement = 0 cluster_palette.accept_palettes( new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors) + last_good_splits = cluster_palette._palette_splits if total_image_error < 1e9: print("Improved quality +%f%% (%f)" % ( (1 - new_total_image_error / total_image_error) * 100, new_total_image_error)) + # print(cluster_palette._palette_splits) output_4bit = new_output_4bit line_to_palette = new_line_to_palette total_image_error = new_total_image_error