From 83b047b73f344f17595ea83d3b111c7f5c76961a Mon Sep 17 00:00:00 2001 From: kris Date: Tue, 16 Nov 2021 15:44:04 +0000 Subject: [PATCH] Whoops, fix a major bug with the iterated image fitting: we don't want to mutate our source image! Fix another bug introduced in the previous commit: convert from linear RGB before quantizing //gs RGB palette since //gs RGB values are in Rec.601 colour space. Switch to double for colour_distance_squared and related variables, not sure if it matters though. When iterating palette clustering, reject the new palettes if they would increase the total image error. This prevents accepting changes that are local improvements to one palette but which would introduce more net errors elsewhere when this palette is reused. This now seems to give monotonic improvements in image quality so no need to write out intermediate images any more. --- convert.py | 79 ++++++++++++++++++++++++++++++++---------------------- dither.pyx | 17 +++++++----- 2 files changed, 57 insertions(+), 39 deletions(-) diff --git a/convert.py b/convert.py index 8fa01ee..0b931e6 100644 --- a/convert.py +++ b/convert.py @@ -24,10 +24,8 @@ class ClusterPalette: def __init__(self, image: Image): self._colours_cam = self._image_colours_cam(image) self._best_palette_distances = [1e9] * 16 - self._iterations = 0 self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32) self._palettes_rgb = np.empty((16, 16, 3), dtype=np.float32) - self._global_palette = self._fit_global_palette() def _image_colours_cam(self, image: Image): colours_rgb = np.asarray(image).reshape((-1, 3)) @@ -46,16 +44,15 @@ class ClusterPalette: return clusters.cluster_centers_ def iterate(self): - self._iterations += 1 - print("Iteration %d" % self._iterations) + self._global_palette = self._fit_global_palette() for palette_idx in range(16): p_lower = max(palette_idx - 1.5, 0) p_upper = min(palette_idx + 2.5, 16) # TODO: dynamically tune palette cuts palette_pixels = self._colours_cam[ - int(p_lower * (200 
/ 16)) * 320:int(p_upper * ( - 200 / 16)) * 320, :] + int(p_lower * (200 / 16)) * 320:int(p_upper * ( + 200 / 16)) * 320, :] best_wce = self._best_palette_distances[palette_idx] # TODO: tolerance @@ -63,12 +60,7 @@ class ClusterPalette: n_clusters=16, max_iter=10000, init=self._global_palette, n_init=1) clusters.fit_predict(palette_pixels) - if clusters.inertia_ < (best_wce * 0.99): - # TODO: sentinel - if best_wce < 1e9: - print("Improved palette %d (+%f%%)" % ( - palette_idx, best_wce / clusters.inertia_)) - + if clusters.inertia_ < best_wce: self._palettes_cam[palette_idx, :, :] = np.array( clusters.cluster_centers_).astype(np.float32) best_wce = clusters.inertia_ @@ -147,26 +139,51 @@ def main(): gamma=args.gamma_correct)).astype(np.float32) / 255 # TODO: flags - penalty = 10 # 1e9 - iterations = 50 + penalty = 1e9 # 0 # 1e9 + iterations = 50 # 0 pygame.init() # TODO: for some reason I need to execute this twice - the first time # the window is created and immediately destroyed - pygame.display.set_mode((640, 400)) + _ = pygame.display.set_mode((640, 400)) canvas = pygame.display.set_mode((640, 400)) canvas.fill((0, 0, 0)) pygame.display.flip() + total_image_error = 1e9 cluster_palette = ClusterPalette(rgb) + image_generation = 0 for iteration in range(iterations): - palettes_cam, palettes_rgb = cluster_palette.iterate() - for i in range(16): - screen.set_palette(i, (np.round(palettes_rgb[i, :, :] * 15)).astype( - np.uint8)) + old_best_palette_distances = cluster_palette._best_palette_distances + old_palettes_cam = cluster_palette._palettes_cam + old_palettes_rgb = cluster_palette._palettes_rgb - output_4bit, line_to_palette = dither_pyx.dither_shr( - rgb, palettes_cam, palettes_rgb, rgb_to_cam16, float(penalty)) + new_palettes_cam, new_palettes_rgb = cluster_palette.iterate() + output_4bit, line_to_palette, new_total_image_error = \ + dither_pyx.dither_shr( + rgb, new_palettes_cam, new_palettes_rgb, rgb_to_cam16, + float(penalty) + ) + + if 
new_total_image_error < total_image_error: + if total_image_error < 1e9: + print("Improved quality +%f%% (%f)" % ( + (1 - new_total_image_error / total_image_error) * 100, + new_total_image_error)) + total_image_error = new_total_image_error + palettes_rgb = new_palettes_rgb + else: + cluster_palette._palettes_cam = old_palettes_cam + cluster_palette._palettes_rgb = old_palettes_rgb + cluster_palette._best_palette_distances = old_best_palette_distances + continue + + image_generation += 1 + for i in range(16): + screen.set_palette(i, ( + np.round(image_py.linear_to_srgb(palettes_rgb[i, :, + :] * 255) / 255 * 15)).astype( + np.uint8)) screen.set_pixels(output_4bit) output_rgb = np.empty((200, 320, 3), dtype=np.uint8) for i in range(200): @@ -199,17 +216,15 @@ def main(): canvas.blit(surface, (0, 0)) pygame.display.flip() - # Save Double hi-res image - outfile = os.path.join(os.path.splitext(args.output)[0] + - "-%d-preview.png" % cluster_palette._iterations) - out_image.save(outfile, "PNG") - screen.pack() - # with open(args.output, "wb") as f: - # f.write(bytes(screen.aux)) - # f.write(bytes(screen.main)) - with open("%s-%s" % (args.output, cluster_palette._iterations), - "wb") as f: - f.write(bytes(screen.memory)) + # Save Double hi-res image + outfile = os.path.join(os.path.splitext(args.output)[0] + "-preview.png") + out_image.save(outfile, "PNG") + screen.pack() + # with open(args.output, "wb") as f: + # f.write(bytes(screen.aux)) + # f.write(bytes(screen.main)) + with open(args.output, "wb") as f: + f.write(bytes(screen.memory)) if __name__ == "__main__": diff --git a/dither.pyx b/dither.pyx index cea6a28..ff974d3 100644 --- a/dither.pyx +++ b/dither.pyx @@ -171,7 +171,7 @@ cdef inline float fabs(float value) nogil: @cython.boundscheck(False) @cython.wraparound(False) -cdef inline float colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil: +cdef inline double colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil: return 
(colour1[0] - colour2[0]) ** 2 + (colour1[1] - colour2[1]) ** 2 + (colour1[2] - colour2[2]) ** 2 @@ -339,20 +339,21 @@ import colour @cython.boundscheck(False) @cython.wraparound(False) -def dither_shr(float[:, :, ::1] working_image, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb, float[:,::1] rgb_to_cam16ucs, float penalty): +def dither_shr(float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb, float[:,::1] rgb_to_cam16ucs, float penalty): cdef int y, x, idx, best_colour_idx, best_palette - cdef float best_distance, distance + cdef double best_distance, distance, total_image_error cdef float[::1] best_colour_rgb, pixel_cam, colour_rgb, colour_cam cdef float quant_error cdef float[:, ::1] palette_rgb cdef (unsigned char)[:, ::1] output_4bit = np.zeros((200, 320), dtype=np.uint8) - # cdef (unsigned char)[:, :, ::1] output_rgb = np.zeros((200, 320, 3), dtype=np.uint8) - + cdef float[:, :, ::1] working_image = np.copy(input_rgb) cdef float[:, ::1] line_cam = np.zeros((320, 3), dtype=np.float32) cdef int[::1] line_to_palette = np.zeros(200, dtype=np.int32) + best_palette = 15 + total_image_error = 0.0 for y in range(200): # print(y) for x in range(320): @@ -380,6 +381,8 @@ def dither_shr(float[:, :, ::1] working_image, float[:, :, ::1] palettes_cam, fl best_colour_idx = idx best_colour_rgb = palette_rgb[best_colour_idx] output_4bit[y, x] = best_colour_idx + total_image_error += best_distance + # print(y,x,best_distance,total_image_error) for i in range(3): quant_error = working_image[y, x, i] - best_colour_rgb[i] @@ -449,7 +452,7 @@ def dither_shr(float[:, :, ::1] working_image, float[:, :, ::1] palettes_cam, fl # working_image[y + 2, x + 2, i] + quant_error * (1 / 48), # 0, 1) - return np.array(output_4bit, dtype=np.uint8), line_to_palette + return np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error import collections import random @@ -511,7 +514,7 @@ def k_means_with_fixed_centroids( 
@cython.wraparound(False) cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int base_palette_idx, int last_palette_idx, float last_penalty) nogil: cdef int palette_idx, best_palette_idx, palette_entry_idx, pixel_idx - cdef float best_total_dist, total_dist, best_pixel_dist, pixel_dist + cdef double best_total_dist, total_dist, best_pixel_dist, pixel_dist cdef float[:, ::1] palette_cam cdef float[::1] pixel_cam, palette_entry