From d645cc5964eea720ea85abaee02de781100a86ca Mon Sep 17 00:00:00 2001 From: kris Date: Wed, 24 Nov 2021 15:21:50 +0000 Subject: [PATCH] Tidy --- convert.py | 157 ++++++++++++++++------------------------------------- 1 file changed, 48 insertions(+), 109 deletions(-) diff --git a/convert.py b/convert.py index 3921656..813fc8a 100644 --- a/convert.py +++ b/convert.py @@ -32,65 +32,40 @@ class ClusterPalette: def __init__( self, image: Image, rgb12_iigs_to_cam16ucs, rgb24_to_cam16ucs, reserved_colours=0): + + # Source image in 24-bit linear RGB colour space self._image_rgb = image + + # Source image in CAM16UCS colour space self._colours_cam = self._image_colours_cam(image) - self._errors = [1e9] * 16 - - # We fit a 16-colour palette against the entire image which is used - # as starting values for fitting the 16 SHR palettes. This helps to - # provide better global consistency of colours across the palettes, - # e.g. for large blocks of colour. Otherwise these can take a while - # to converge. - self._global_palette = np.empty((16, 3), dtype=np.uint8) - # How many image colours to fix identically across all 16 SHR # palettes. These are taken to be the most prevalent colours from # _global_palette. self._reserved_colours = reserved_colours - # 16 SHR palettes each of 16 colours, in CAM16UCS format + # We fit a 16-colour palette against the entire image which is used + # as starting values for fitting the reserved colours in the 16 SHR + # palettes. + self._global_palette = np.empty((16, 3), dtype=np.uint8) + + # 16 SHR palettes each of 16 colours, in CAM16UCS colour space self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32) - # 16 SHR palettes each of 16 colours, in //gs 4-bit RGB format + # 16 SHR palettes each of 16 colours, in //gs 4-bit RGB colour space self._palettes_rgb = np.empty((16, 16, 3), dtype=np.uint8) + # defaultdict(list) mapping palette index to lines using this palette + self._palette_lines = self._init_palette_lines() + # Conversion matrix from 12-bit //gs RGB colour space to CAM16UCS # colour space self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs + # Conversion matrix from 24-bit linear RGB colour space to CAM16UCS + # colour space self._rgb24_to_cam16ucs = rgb24_to_cam16ucs - self._palette_lines = defaultdict(list) - - # List of line ranges used to train the 16 SHR palettes - # [(lower_0, upper_0), ...] - # self._palette_splits = self._equal_palette_splits() - self._init_palette_lines() - - # Whether the previous iteration of proposed palettes was accepted - self._palettes_accepted = False - - # Which palette index's line ranges did we mutate in previous iteration - self._palette_mutate_idx = 0 - - # Delta applied to palette split in previous iteration - self._palette_mutate_delta = (0, 0) - - def _init_palette_lines(self): - palette_splits = self._equal_palette_splits() - for i, lh in enumerate(palette_splits): - l, h = lh - self._palette_lines[i].extend(list(range(l, h))) - - # lines = list(range(200)) - # random.shuffle(lines) - # idx = 0 - # while lines: - # self._palette_lines[idx].append(lines.pop()) - # idx += 1 - - def _image_colours_cam(self, image: Image): colours_rgb = np.asarray(image) # .reshape((-1, 3)) with colour.utilities.suppress_warnings(colour_usage_warnings=True): @@ -98,6 +73,23 @@ class ClusterPalette: "CAM16UCS").astype(np.float32) return colours_cam + def _init_palette_lines(self, init_random = False): + palette_lines = defaultdict(list) + + if init_random: + lines = list(range(200)) + random.shuffle(lines) + idx = 0 + while lines: + palette_lines[idx].append(lines.pop()) + idx += 1 + else: + palette_splits = self._equal_palette_splits() + for i, lh in enumerate(palette_splits): + l, h = lh + palette_lines[i].extend(list(range(l, h))) + return palette_lines + def _equal_palette_splits(self, palette_height=35): # The 16 palettes are striped across consecutive (overlapping) line # ranges. Since nearby lines tend to have similar colours, this has @@ -133,6 +125,8 @@ class ClusterPalette: self._image_rgb, palettes_cam, palettes_linear_rgb, self._rgb24_to_cam16ucs, float(penalty)) + # Update map of palettes to image lines for which the palette was the + # best match palette_lines = defaultdict(list) for line, palette in enumerate(line_to_palette): palette_lines[palette].append(line) @@ -146,23 +140,21 @@ class ClusterPalette: def iterate(self, penalty: float, max_inner_iterations: int, max_outer_iterations: int): total_image_error = 1e9 - # last_good_splits = self._palette_splits outer_iterations_since_improvement = 0 while outer_iterations_since_improvement < max_outer_iterations: print("New iteration") inner_iterations_since_improvement = 0 - # self._palette_splits = self._equal_palette_splits() - self._init_palette_lines() + self._palette_lines = self._init_palette_lines() self._fit_global_palette() while inner_iterations_since_improvement < max_inner_iterations: # print("Iterations %d" % inner_iterations_since_improvement) - new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = ( + new_palettes_cam, new_palettes_rgb12_iigs = ( self._propose_palettes()) - # Recompute image with proposed palettes and check whether it has - # lower total image error than our previous best. + # Recompute image with proposed palettes and check whether it + # has lower total image error than our previous best. (output_4bit, line_to_palette, palettes_linear_rgb, new_total_image_error) = self._dither_image( new_palettes_cam, penalty) @@ -171,27 +163,24 @@ class ClusterPalette: # within a palette self._reassign_unused_palettes(line_to_palette) - # print(total_image_error, new_total_image_error) if new_total_image_error >= total_image_error: inner_iterations_since_improvement += 1 continue - # We found a globally better set of palettes + # We found a globally better set of palettes, so restart the + # clocks inner_iterations_since_improvement = 0 outer_iterations_since_improvement = -1 - # last_good_splits = self._palette_splits total_image_error = new_total_image_error self._palettes_cam = new_palettes_cam self._palettes_rgb = new_palettes_rgb12_iigs - self._errors = new_palette_errors - self._palettes_accepted = True yield (new_total_image_error, output_4bit, line_to_palette, new_palettes_rgb12_iigs, palettes_linear_rgb) outer_iterations_since_improvement += 1 - def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]: + def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray]: """Attempt to find new palettes that locally improve image quality. Re-fit a set of 16 palettes from (overlapping) line ranges of the @@ -208,7 +197,6 @@ class ClusterPalette: The current (locally) best palettes are returned and can be applied using accept_palettes(). """ - new_errors = list(self._errors) new_palettes_cam = np.empty_like(self._palettes_cam) new_palettes_rgb12_iigs = np.empty_like(self._palettes_rgb) @@ -217,14 +205,15 @@ class ClusterPalette: # individual palettes self._fit_global_palette() - # self._mutate_palette_splits() for palette_idx in range(16): - # print(palette_idx, self._palette_lines[palette_idx]) - # palette_lower, palette_upper = self._palette_splits[palette_idx] palette_pixels = ( self._colours_cam[ self._palette_lines[palette_idx], :, :].reshape(-1, 3)) + # Fix reserved colours from the global palette and pick unique + # random colours from the sample points for the remaining initial + # centroids. This tends to increase the number of colours in the + # resulting image, and improves quality. initial_centroids = self._global_palette pixels_rgb_iigs = dither_pyx.convert_cam16ucs_to_rgb12_iigs( palette_pixels) @@ -236,10 +225,8 @@ class ClusterPalette: 0]) new_colour = pixels_rgb_iigs[choice, :] if tuple(new_colour) in seen_colours: - # print("Skipping") continue seen_colours.add(tuple(new_colour)) - # print(i, choice) initial_centroids[i, :] = new_colour palettes_rgb12_iigs, palette_error = \ @@ -251,13 +238,6 @@ class ClusterPalette: rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs ) - # if (palette_error >= self._errors[palette_idx] and not - # self._reserved_colours): - # # Not a local improvement to the existing palette, so ignore it. - # # We can't take this shortcut when we're reserving colours - # # because it would break the invariant that all palettes must - # # share colours. - # continue for i in range(16): new_palettes_cam[palette_idx, i, :] = ( np.array(dither_pyx.convert_rgb12_iigs_to_cam( @@ -265,10 +245,9 @@ class ClusterPalette: i]), dtype=np.float32)) new_palettes_rgb12_iigs[palette_idx, :, :] = palettes_rgb12_iigs - new_errors[palette_idx] = palette_error self._palettes_accepted = False - return new_palettes_cam, new_palettes_rgb12_iigs, new_errors + return new_palettes_cam, new_palettes_rgb12_iigs def _fit_global_palette(self): """Compute a 16-colour palette for the entire image to use as @@ -299,7 +278,6 @@ class ClusterPalette: best_palette_lines = [v for k, v in sorted(list(zip( self._palette_line_errors, range(200))))] - # print(self._palette_lines) for palette_idx, palette_used in enumerate(palettes_used): if palette_used: continue @@ -309,44 +287,6 @@ class ClusterPalette: worst_line = best_palette_lines.pop() self._palette_lines[palette_idx] = [worst_line] - # print("Picked line %d with error %f" % (worst_line, - # self._palette_line_errors[worst_line])) - - - # - # worst_average_palette_error = 0 - # split_palette_idx = -1 - # idx = 0 - # for idx, lines in self._palette_lines.items(): - # if len(lines) < 10: - # continue - # average_palette_error = np.sum(self._palette_line_errors[ - # lines]) / len(lines) - # print(idx, average_palette_error) - # if average_palette_error > worst_average_palette_error: - # worst_average_palette_error = average_palette_error - # split_palette_idx = idx - # - # print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error)) - # # TODO: split off lines with largest error - # - # palette_line_errors = self._palette_line_errors[ - # self._palette_lines[split_palette_idx]] - # - # print(sorted( - # list(zip(palette_line_errors, self._palette_lines[ - # split_palette_idx])), reverse=True)) - # best_palette_lines = [v for k, v in sorted( - # list(zip(palette_line_errors, self._palette_lines[ - # split_palette_idx])))] - # num_max_lines = len(self._palette_lines[split_palette_idx]) - # - # self._palette_lines[split_palette_idx] = best_palette_lines[ - # :num_max_lines // 2] - # # Move worst half to new palette - # self._palette_lines[palette_idx] = best_palette_lines[ - # num_max_lines // 2:] - def main(): parser = argparse.ArgumentParser() @@ -410,7 +350,7 @@ def main(): # TODO: flags penalty = 1 # 1e18 # TODO: is this needed any more? - inner_iterations = 10 # 20 + inner_iterations = 10 outer_iterations = 20 pygame.init() @@ -476,7 +416,6 @@ def main(): canvas.blit(surface, (0, 0)) pygame.display.flip() - # print((palettes_rgb * 255).astype(np.uint8)) unique_colours = np.unique( palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0] print("%d unique colours" % unique_colours)