From 0009ce89139a7e4f953ab282ea19f823d28725c2 Mon Sep 17 00:00:00 2001 From: kris Date: Wed, 17 Nov 2021 17:09:42 +0000 Subject: [PATCH] - allow reserving a number of colours which are to be shared across all palettes. This will be useful for Total Replay which does an animation effect when displaying the image (first set palettes, then transition in pixels) - this requires us to go back to computing k-means ourself instead of using sklearn, since it can't keep some centroids fixed - try to be more careful about //gs RGB values, which are in the Rec.601 colour space. This isn't quite right yet - the issue seems to be that since we dither in linear RGB space but quantize in the nonlinear space, small differences may lead to a +/- 1 in the 4-bit //gs RGB value, which is quite noticeable. Instead we need to be clustering and/or dithering with awareness of the quantized palette space. --- convert.py | 113 ++++++++++++++++++++++++++++++++++++++--------------- dither.pyx | 54 ++++++++++++++++++++++++- screen.py | 1 + 3 files changed, 136 insertions(+), 32 deletions(-) diff --git a/convert.py b/convert.py index e266686..8483f63 100644 --- a/convert.py +++ b/convert.py @@ -28,8 +28,9 @@ import screen as screen_py class ClusterPalette: def __init__( - self, image: Image): + self, image: Image, reserved_colours=0): self._colours_cam = self._image_colours_cam(image) + self._reserved_colours = reserved_colours self._errors = [1e9] * 16 self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32) self._palettes_rgb = np.empty((16, 16, 3), dtype=np.float32) @@ -50,7 +51,14 @@ class ClusterPalette: clusters = cluster.MiniBatchKMeans(n_clusters=16, max_iter=10000) clusters.fit_predict(self._colours_cam) - return clusters.cluster_centers_ + + labels = clusters.labels_ + frequency_order = [ + k for k, v in sorted( + # List of (palette idx, frequency count) + list(zip(*np.unique(labels, return_counts=True))), + key=lambda kv: kv[1], reverse=True)] + return clusters.cluster_centers_[frequency_order] def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]: """Attempt to find new palettes that locally improve image quality. @@ -69,15 +77,16 @@ class ClusterPalette: The current (locally) best palettes are returned and can be applied using accept_palettes(). """ + new_errors = list(self._errors) + new_palettes_cam = np.copy(self._palettes_cam) + new_palettes_rgb = np.copy(self._palettes_rgb) # Compute a new 16-colour global palette for the entire image, # used as the starting center positions for k-means clustering of the # individual palettes self._global_palette = self._fit_global_palette() - new_errors = list(self._errors) - new_palettes_cam = np.copy(self._palettes_cam) - new_palettes_rgb = np.copy(self._palettes_rgb) + dynamic_colours = 16 - self._reserved_colours # The 16 palettes are striped across consecutive (overlapping) line # ranges. The basic unit is 200/16 = 12.5 lines, but we extend the @@ -100,25 +109,53 @@ class ClusterPalette: # be a major issue in practise though, and fixing it would require # implementing our own (optimized) k-means. # TODO: tune tolerance - clusters = cluster.MiniBatchKMeans( - n_clusters=16, max_iter=10000, init=self._global_palette, - n_init=1) - clusters.fit_predict(palette_pixels) - palette_error = clusters.inertia_ - if palette_error >= self._errors[palette_idx]: - # Not a local improvement to existing palette + # clusters = cluster.MiniBatchKMeans( + # n_clusters=16, max_iter=10000, + # init=self._global_palette, + # n_init=1) + # clusters.fit_predict(palette_pixels) + # + # palette_error = clusters.inertia_ + + clusters, palette_error = dither_pyx.k_means_with_fixed_centroids( + n_clusters=16, n_fixed=self._reserved_colours, + samples=palette_pixels, initial_centroids=self._global_palette, + max_iterations=1000, tolerance=1e-4 + ) + + if (palette_error >= self._errors[palette_idx] and not + self._reserved_colours): + # Not a local improvement to the existing palette, so ignore it. + # We can't take this shortcut when we're reserving colours + # because it would break the invariant that all palettes must + # share colours. continue - palette_cam = np.array(clusters.cluster_centers_).astype(np.float32) + new_palettes_cam[palette_idx, :, :] = np.array( + # clusters.cluster_centers_).astype(np.float32) + clusters).astype(np.float32) # Suppress divide by zero warning, # https://github.com/colour-science/colour/issues/900 with colour.utilities.suppress_warnings(python_warnings=True): - # SHR colour palette only uses 4-bit RGB values - palette_rgb = (np.round(colour.convert( - palette_cam, "CAM16UCS", "RGB") * 15) / 15).astype( - np.float32) - new_palettes_cam[palette_idx, :, :] = palette_cam - new_palettes_rgb[palette_idx, :, :] = palette_rgb + palette_rgb = colour.convert( + new_palettes_cam[palette_idx, :, :], "CAM16UCS", "RGB") + palette_rgb_rec601 = np.clip(image_py.srgb_to_linear( + colour.YCbCr_to_RGB( + colour.RGB_to_YCbCr( + image_py.linear_to_srgb(palette_rgb * 255) / 255, + K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), + K=colour.WEIGHTS_YCBCR['ITU-R BT.601']) * 255) / 255, 0, 1) + # palette_rgb = np.clip( + # image_py.srgb_to_linear( + # colour.YCbCr_to_RGB( + # colour.RGB_to_YCbCr( + # image_py.linear_to_srgb( + # palette_rgb[:, :] * 255) / 255, + # K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), + # K=colour.WEIGHTS_YCBCR[ + # 'ITU-R BT.601']) * 255) / 255, + # 0, 1) + new_palettes_rgb[palette_idx, :, :] = palette_rgb # palette_rgb_rec601 new_errors[palette_idx] = palette_error return new_palettes_cam, new_palettes_rgb, new_errors @@ -192,7 +229,7 @@ def main(): # TODO: flags penalty = 1e9 - iterations = 50 + iterations = 10 # 50 pygame.init() # TODO: for some reason I need to execute this twice - the first time @@ -205,8 +242,8 @@ def main(): total_image_error = 1e9 iterations_since_improvement = 0 - palette_iigs = np.empty((16, 16, 3), dtype=np.uint8) - cluster_palette = ClusterPalette(rgb) + palettes_iigs = np.empty((16, 16, 3), dtype=np.uint8) + cluster_palette = ClusterPalette(rgb, reserved_colours=1) while iterations_since_improvement < iterations: new_palettes_cam, new_palettes_rgb, new_palette_errors = ( @@ -237,11 +274,16 @@ def main(): palettes_rgb = new_palettes_rgb # Recompute 4-bit //gs RGB palettes + palette_rgb_rec601 = np.clip( + colour.YCbCr_to_RGB( + colour.RGB_to_YCbCr( + image_py.linear_to_srgb(palettes_rgb * 255) / 255, + K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), + K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1) + + palettes_iigs = np.round(palette_rgb_rec601 * 15).astype(np.uint8) for i in range(16): - palette_iigs[i, :, :] = ( - np.round(image_py.linear_to_srgb( - palettes_rgb[i, :, :] * 255) / 255 * 15)).astype(np.uint8) - screen.set_palette(i, palette_iigs[i, :, :]) + screen.set_palette(i, palettes_iigs[i, :, :]) # Recompute current screen RGB image screen.set_pixels(output_4bit) @@ -249,9 +291,18 @@ def main(): for i in range(200): screen.line_palette[i] = line_to_palette[i] output_rgb[i, :, :] = ( - palettes_rgb[line_to_palette[i]][ - output_4bit[i, :]] * 255).astype(np.uint8) - output_srgb = image_py.linear_to_srgb(output_rgb).astype(np.uint8) + palettes_rgb[line_to_palette[i]][output_4bit[i, :]] * 255 + ).astype( + # np.round(palettes_rgb[line_to_palette[i]][ + # output_4bit[i, :]] * 15) / 15 * 255).astype( + np.uint8) + output_srgb_rec709 = np.clip(colour.YCbCr_to_RGB( + colour.RGB_to_YCbCr( + image_py.linear_to_srgb(output_rgb) / 255, + K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), + K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), 0, 1) * 255 + + output_srgb = (image_py.linear_to_srgb(output_rgb)).astype(np.uint8) # dither = dither_pattern.PATTERNS[args.dither]() # bitmap = dither_pyx.dither_image( @@ -275,8 +326,8 @@ def main(): np.asarray(out_image).transpose((1, 0, 2))) # flip y/x axes canvas.blit(surface, (0, 0)) pygame.display.flip() - - unique_colours = np.unique(palette_iigs.reshape(-1, 3), axis=0).shape[0] + print((palettes_rgb * 255).astype(np.uint8)) + unique_colours = np.unique(palettes_iigs.reshape(-1, 3), axis=0).shape[0] print("%d unique colours" % unique_colours) # Save Double hi-res image diff --git a/dither.pyx b/dither.pyx index e16e499..d1e53d8 100644 --- a/dither.pyx +++ b/dither.pyx @@ -341,7 +341,7 @@ def dither_image( def dither_shr( float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb, float[:,::1] rgb_to_cam16ucs, float penalty): - cdef int y, x, idx, best_colour_idx, best_palette + cdef int y, x, idx, best_colour_idx, best_palette, i cdef double best_distance, distance, total_image_error cdef float[::1] best_colour_rgb, pixel_cam, colour_rgb, colour_cam cdef float quant_error @@ -357,6 +357,8 @@ def dither_shr( total_image_error = 0.0 for y in range(200): for x in range(320): + #for i in range(3): + # working_image[y, x, i] = np.round(working_image[y, x, i] * 15) / 15 colour_cam = convert_rgb_to_cam16ucs( rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2]) line_cam[x, :] = colour_cam @@ -489,3 +491,53 @@ cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palette best_palette_idx = palette_idx return best_palette_idx +@cython.boundscheck(False) +@cython.wraparound(False) +def k_means_with_fixed_centroids( + int n_clusters, int n_fixed, float[:, ::1] samples, float[:, ::1] initial_centroids, int max_iterations, float tolerance): + + cdef double error, best_error, centroid_movement, total_error + cdef int centroid_idx, closest_centroid_idx, i, point_idx + + cdef float[:, ::1] centroids = initial_centroids[:, :] + cdef float[::1] centroid, point, new_centroid = np.empty(3, dtype=np.float32) + + cdef float[:, ::1] centroid_sample_positions_total + cdef int[::1] centroid_sample_counts + + for iteration in range(max_iterations): + total_error = 0.0 + centroid_movement = 0.0 + centroid_sample_positions_total = np.zeros((16, 3), dtype=np.float32) + centroid_sample_counts = np.zeros(16, dtype=np.int32) + + for point_idx in range(samples.shape[0]): + point = samples[point_idx, :] + best_error = 1e9 + closest_centroid_idx = 0 + for centroid_idx in range(n_clusters): + centroid = centroids[centroid_idx, :] + error = colour_distance_squared(centroid, point) + if error < best_error: + best_error = error + closest_centroid_idx = centroid_idx + for i in range(3): + centroid_sample_positions_total[closest_centroid_idx, i] += point[i] + centroid_sample_counts[closest_centroid_idx] += 1 + total_error += best_error + + for centroid_idx in range(n_fixed, n_clusters): + if centroid_sample_counts[centroid_idx]: + for i in range(3): + new_centroid[i] = ( + centroid_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx]) + centroid_movement += colour_distance_squared(centroids[centroid_idx], new_centroid) + + centroids[centroid_idx, :] = new_centroid + + # print(iteration, total_error, centroids) + + if centroid_movement < tolerance: + break + + return centroids, total_error diff --git a/screen.py b/screen.py index dd0b554..8a51d45 100644 --- a/screen.py +++ b/screen.py @@ -55,6 +55,7 @@ class SHR320Screen: for palette_idx, palette in self.palettes.items(): for rgb_idx, rgb in enumerate(palette): r, g, b = rgb + assert r <= 15 and g <= 15 and b <= 15 # print(r, g, b) rgb_low = (g << 4) | b rgb_hi = r