From 3b8767782bf29342f74bd40c9088aafc19127d29 Mon Sep 17 00:00:00 2001 From: kris Date: Wed, 24 Nov 2021 11:47:39 +0000 Subject: [PATCH] Each run seems to converge fairly quickly but there is a lot of variation across runs. Run in a loop and keep the running best. --- convert.py | 184 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 105 insertions(+), 79 deletions(-) diff --git a/convert.py b/convert.py index d96b6f2..3cf2645 100644 --- a/convert.py +++ b/convert.py @@ -61,9 +61,12 @@ class ClusterPalette: self._rgb24_to_cam16ucs = rgb24_to_cam16ucs + self._palette_lines = defaultdict(list) + # List of line ranges used to train the 16 SHR palettes # [(lower_0, upper_0), ...] self._palette_splits = self._equal_palette_splits() + self._init_palette_lines() # Whether the previous iteration of proposed palettes was accepted self._palettes_accepted = False @@ -74,7 +77,7 @@ class ClusterPalette: # Delta applied to palette split in previous iteration self._palette_mutate_delta = (0, 0) - self._palette_lines = defaultdict(list) + def _init_palette_lines(self): for i, lh in enumerate(self._palette_splits): l, h = lh self._palette_lines[i].extend(list(range(l, h))) @@ -132,41 +135,46 @@ class ClusterPalette: total_image_error) def iterate(self, penalty: float, max_iterations: int): - iterations_since_improvement = 0 total_image_error = 1e9 last_good_splits = self._palette_splits - self._fit_global_palette() - while iterations_since_improvement < max_iterations: - # print("Iterations %d" % iterations_since_improvement) - new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = ( - self._propose_palettes()) - - # Recompute image with proposed palettes and check whether it has - # lower total image error than our previous best. - (output_4bit, line_to_palette, palettes_linear_rgb, - new_total_image_error) = self._dither_image( - new_palettes_cam, penalty) - - self._reassign_unused_palettes(line_to_palette, - last_good_splits) - - print(total_image_error, new_total_image_error) - if new_total_image_error >= total_image_error: - iterations_since_improvement += 1 - continue - - # We found a globally better set of palettes + while True: + print("New iteration") iterations_since_improvement = 0 - last_good_splits = self._palette_splits - total_image_error = new_total_image_error + self._palette_splits = self._equal_palette_splits() + self._init_palette_lines() - self._palettes_cam = new_palettes_cam - self._palettes_rgb = new_palettes_rgb12_iigs - self._errors = new_palette_errors - self._palettes_accepted = True + self._fit_global_palette() + while iterations_since_improvement < max_iterations: + # print("Iterations %d" % iterations_since_improvement) + new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = ( + self._propose_palettes()) - yield (new_total_image_error, output_4bit, line_to_palette, - new_palettes_rgb12_iigs, palettes_linear_rgb) + # Recompute image with proposed palettes and check whether it has + # lower total image error than our previous best. + (output_4bit, line_to_palette, palettes_linear_rgb, + new_total_image_error) = self._dither_image( + new_palettes_cam, penalty) + + self._reassign_unused_palettes(line_to_palette, + last_good_splits) + + # print(total_image_error, new_total_image_error) + if new_total_image_error >= total_image_error: + iterations_since_improvement += 1 + continue + + # We found a globally better set of palettes + iterations_since_improvement = 0 + last_good_splits = self._palette_splits + total_image_error = new_total_image_error + + self._palettes_cam = new_palettes_cam + self._palettes_rgb = new_palettes_rgb12_iigs + self._errors = new_palette_errors + self._palettes_accepted = True + + yield (new_total_image_error, output_4bit, line_to_palette, + new_palettes_rgb12_iigs, palettes_linear_rgb) def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]: """Attempt to find new palettes that locally improve image quality. @@ -192,7 +200,7 @@ class ClusterPalette: # Compute a new 16-colour global palette for the entire image, # used as the starting center positions for k-means clustering of the # individual palettes - # self._fit_global_palette() + self._fit_global_palette() self._mutate_palette_splits() for palette_idx in range(16): @@ -211,13 +219,13 @@ class ClusterPalette: rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs ) - if (palette_error >= self._errors[palette_idx] and not - self._reserved_colours): - # Not a local improvement to the existing palette, so ignore it. - # We can't take this shortcut when we're reserving colours - # because it would break the invariant that all palettes must - # share colours. - continue + # if (palette_error >= self._errors[palette_idx] and not + # self._reserved_colours): + # # Not a local improvement to the existing palette, so ignore it. + # # We can't take this shortcut when we're reserving colours + # # because it would break the invariant that all palettes must + # # share colours. + # continue for i in range(16): new_palettes_cam[palette_idx, i, :] = ( np.array(dither_pyx.convert_rgb12_iigs_to_cam( @@ -299,43 +307,56 @@ class ClusterPalette: palettes_used = [False] * 16 for palette in new_line_to_palette: palettes_used[palette] = True + best_palette_lines = [v for k, v in sorted(list(zip( + self._palette_line_errors, range(200))))] + + # print(self._palette_lines) for palette_idx, palette_used in enumerate(palettes_used): if palette_used: continue - print("Reassigning palette %d" % palette_idx) + # print("Reassigning palette %d" % palette_idx) - worst_average_palette_error = 0 - split_palette_idx = -1 - idx = 0 - for idx, lines in self._palette_lines.items(): - if len(lines) < 10: - continue - average_palette_error = np.sum(self._palette_line_errors[ - lines]) / len(lines) - print(idx, average_palette_error) - if average_palette_error > worst_average_palette_error: - worst_average_palette_error = average_palette_error - split_palette_idx = idx + # TODO: also remove from old entry + worst_line = best_palette_lines.pop() + self._palette_lines[palette_idx] = [worst_line] - print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error)) - # TODO: split off lines with largest error + # print("Picked line %d with error %f" % (worst_line, + # self._palette_line_errors[worst_line])) - palette_line_errors = self._palette_line_errors[ - self._palette_lines[split_palette_idx]] - print(sorted( - list(zip(palette_line_errors, self._palette_lines[ - split_palette_idx])), reverse=True)) - best_palette_lines = [v for k, v in sorted( - list(zip(palette_line_errors, self._palette_lines[ - split_palette_idx])))] - num_max_lines = len(self._palette_lines[split_palette_idx]) - - self._palette_lines[split_palette_idx] = best_palette_lines[ - :num_max_lines // 2] - # Move worst half to new palette - self._palette_lines[palette_idx] = best_palette_lines[ - num_max_lines // 2:] + # + # worst_average_palette_error = 0 + # split_palette_idx = -1 + # idx = 0 + # for idx, lines in self._palette_lines.items(): + # if len(lines) < 10: + # continue + # average_palette_error = np.sum(self._palette_line_errors[ + # lines]) / len(lines) + # print(idx, average_palette_error) + # if average_palette_error > worst_average_palette_error: + # worst_average_palette_error = average_palette_error + # split_palette_idx = idx + # + # print("Picked %d with avg error %f" % (split_palette_idx, worst_average_palette_error)) + # # TODO: split off lines with largest error + # + # palette_line_errors = self._palette_line_errors[ + # self._palette_lines[split_palette_idx]] + # + # print(sorted( + # list(zip(palette_line_errors, self._palette_lines[ + # split_palette_idx])), reverse=True)) + # best_palette_lines = [v for k, v in sorted( + # list(zip(palette_line_errors, self._palette_lines[ + # split_palette_idx])))] + # num_max_lines = len(self._palette_lines[split_palette_idx]) + # + # self._palette_lines[split_palette_idx] = best_palette_lines[ + # :num_max_lines // 2] + # # Move worst half to new palette + # self._palette_lines[palette_idx] = best_palette_lines[ + # num_max_lines // 2:] def main(): @@ -400,7 +421,7 @@ def main(): # TODO: flags penalty = 1 # 1e18 # TODO: is this needed any more? - iterations = 200 + iterations = 10 # 20 pygame.init() # TODO: for some reason I need to execute this twice - the first time @@ -417,6 +438,7 @@ def main(): rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs, rgb24_to_cam16ucs=rgb24_to_cam16ucs) + seq = 0 for (new_total_image_error, output_4bit, line_to_palette, palettes_rgb12_iigs, palettes_linear_rgb) in cluster_palette.iterate( penalty, iterations): @@ -463,20 +485,24 @@ def main(): np.asarray(out_image).transpose((1, 0, 2))) # flip y/x axes canvas.blit(surface, (0, 0)) pygame.display.flip() + + seq += 1 + # Save Double hi-res image + outfile = os.path.join( + os.path.splitext(args.output)[0] + "-%d-preview.png" % seq) + out_image.save(outfile, "PNG") + screen.pack() + # with open(args.output, "wb") as f: + # f.write(bytes(screen.aux)) + # f.write(bytes(screen.main)) + with open(args.output, "wb") as f: + f.write(bytes(screen.memory)) + # print((palettes_rgb * 255).astype(np.uint8)) unique_colours = np.unique( palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0] print("%d unique colours" % unique_colours) - # Save Double hi-res image - outfile = os.path.join(os.path.splitext(args.output)[0] + "-preview.png") - out_image.save(outfile, "PNG") - screen.pack() - # with open(args.output, "wb") as f: - # f.write(bytes(screen.aux)) - # f.write(bytes(screen.main)) - with open(args.output, "wb") as f: - f.write(bytes(screen.memory)) if __name__ == "__main__":