diff --git a/convert.py b/convert.py index e266686..8483f63 100644 --- a/convert.py +++ b/convert.py @@ -28,8 +28,9 @@ import screen as screen_py class ClusterPalette: def __init__( - self, image: Image): + self, image: Image, reserved_colours=0): self._colours_cam = self._image_colours_cam(image) + self._reserved_colours = reserved_colours self._errors = [1e9] * 16 self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32) self._palettes_rgb = np.empty((16, 16, 3), dtype=np.float32) @@ -50,7 +51,14 @@ class ClusterPalette: clusters = cluster.MiniBatchKMeans(n_clusters=16, max_iter=10000) clusters.fit_predict(self._colours_cam) - return clusters.cluster_centers_ + + labels = clusters.labels_ + frequency_order = [ + k for k, v in sorted( + # List of (palette idx, frequency count) + list(zip(*np.unique(labels, return_counts=True))), + key=lambda kv: kv[1], reverse=True)] + return clusters.cluster_centers_[frequency_order] def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]: """Attempt to find new palettes that locally improve image quality. @@ -69,15 +77,16 @@ class ClusterPalette: The current (locally) best palettes are returned and can be applied using accept_palettes(). """ + new_errors = list(self._errors) + new_palettes_cam = np.copy(self._palettes_cam) + new_palettes_rgb = np.copy(self._palettes_rgb) # Compute a new 16-colour global palette for the entire image, # used as the starting center positions for k-means clustering of the # individual palettes self._global_palette = self._fit_global_palette() - new_errors = list(self._errors) - new_palettes_cam = np.copy(self._palettes_cam) - new_palettes_rgb = np.copy(self._palettes_rgb) + dynamic_colours = 16 - self._reserved_colours # The 16 palettes are striped across consecutive (overlapping) line # ranges. 
The basic unit is 200/16 = 12.5 lines, but we extend the @@ -100,25 +109,53 @@ class ClusterPalette: # be a major issue in practise though, and fixing it would require # implementing our own (optimized) k-means. # TODO: tune tolerance - clusters = cluster.MiniBatchKMeans( - n_clusters=16, max_iter=10000, init=self._global_palette, - n_init=1) - clusters.fit_predict(palette_pixels) - palette_error = clusters.inertia_ - if palette_error >= self._errors[palette_idx]: - # Not a local improvement to existing palette + # clusters = cluster.MiniBatchKMeans( + # n_clusters=16, max_iter=10000, + # init=self._global_palette, + # n_init=1) + # clusters.fit_predict(palette_pixels) + # + # palette_error = clusters.inertia_ + + clusters, palette_error = dither_pyx.k_means_with_fixed_centroids( + n_clusters=16, n_fixed=self._reserved_colours, + samples=palette_pixels, initial_centroids=self._global_palette, + max_iterations=1000, tolerance=1e-4 + ) + + if (palette_error >= self._errors[palette_idx] and not + self._reserved_colours): + # Not a local improvement to the existing palette, so ignore it. + # We can't take this shortcut when we're reserving colours + # because it would break the invariant that all palettes must + # share colours. 
continue - palette_cam = np.array(clusters.cluster_centers_).astype(np.float32) + new_palettes_cam[palette_idx, :, :] = np.array( + # clusters.cluster_centers_).astype(np.float32) + clusters).astype(np.float32) # Suppress divide by zero warning, # https://github.com/colour-science/colour/issues/900 with colour.utilities.suppress_warnings(python_warnings=True): - # SHR colour palette only uses 4-bit RGB values - palette_rgb = (np.round(colour.convert( - palette_cam, "CAM16UCS", "RGB") * 15) / 15).astype( - np.float32) - new_palettes_cam[palette_idx, :, :] = palette_cam - new_palettes_rgb[palette_idx, :, :] = palette_rgb + palette_rgb = colour.convert( + new_palettes_cam[palette_idx, :, :], "CAM16UCS", "RGB") + palette_rgb_rec601 = np.clip(image_py.srgb_to_linear( + colour.YCbCr_to_RGB( + colour.RGB_to_YCbCr( + image_py.linear_to_srgb(palette_rgb * 255) / 255, + K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), + K=colour.WEIGHTS_YCBCR['ITU-R BT.601']) * 255) / 255, 0, 1) + # palette_rgb = np.clip( + # image_py.srgb_to_linear( + # colour.YCbCr_to_RGB( + # colour.RGB_to_YCbCr( + # image_py.linear_to_srgb( + # palette_rgb[:, :] * 255) / 255, + # K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), + # K=colour.WEIGHTS_YCBCR[ + # 'ITU-R BT.601']) * 255) / 255, + # 0, 1) + new_palettes_rgb[palette_idx, :, :] = palette_rgb # palette_rgb_rec601 new_errors[palette_idx] = palette_error return new_palettes_cam, new_palettes_rgb, new_errors @@ -192,7 +229,7 @@ def main(): # TODO: flags penalty = 1e9 - iterations = 50 + iterations = 10 # 50 pygame.init() # TODO: for some reason I need to execute this twice - the first time @@ -205,8 +242,8 @@ def main(): total_image_error = 1e9 iterations_since_improvement = 0 - palette_iigs = np.empty((16, 16, 3), dtype=np.uint8) - cluster_palette = ClusterPalette(rgb) + palettes_iigs = np.empty((16, 16, 3), dtype=np.uint8) + cluster_palette = ClusterPalette(rgb, reserved_colours=1) while iterations_since_improvement < iterations: new_palettes_cam, 
new_palettes_rgb, new_palette_errors = ( @@ -237,11 +274,16 @@ def main(): palettes_rgb = new_palettes_rgb # Recompute 4-bit //gs RGB palettes + palette_rgb_rec601 = np.clip( + colour.YCbCr_to_RGB( + colour.RGB_to_YCbCr( + image_py.linear_to_srgb(palettes_rgb * 255) / 255, + K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), + K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1) + + palettes_iigs = np.round(palette_rgb_rec601 * 15).astype(np.uint8) for i in range(16): - palette_iigs[i, :, :] = ( - np.round(image_py.linear_to_srgb( - palettes_rgb[i, :, :] * 255) / 255 * 15)).astype(np.uint8) - screen.set_palette(i, palette_iigs[i, :, :]) + screen.set_palette(i, palettes_iigs[i, :, :]) # Recompute current screen RGB image screen.set_pixels(output_4bit) @@ -249,9 +291,18 @@ def main(): for i in range(200): screen.line_palette[i] = line_to_palette[i] output_rgb[i, :, :] = ( - palettes_rgb[line_to_palette[i]][ - output_4bit[i, :]] * 255).astype(np.uint8) - output_srgb = image_py.linear_to_srgb(output_rgb).astype(np.uint8) + palettes_rgb[line_to_palette[i]][output_4bit[i, :]] * 255 + ).astype( + # np.round(palettes_rgb[line_to_palette[i]][ + # output_4bit[i, :]] * 15) / 15 * 255).astype( + np.uint8) + output_srgb_rec709 = np.clip(colour.YCbCr_to_RGB( + colour.RGB_to_YCbCr( + image_py.linear_to_srgb(output_rgb) / 255, + K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), + K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), 0, 1) * 255 + + output_srgb = (image_py.linear_to_srgb(output_rgb)).astype(np.uint8) # dither = dither_pattern.PATTERNS[args.dither]() # bitmap = dither_pyx.dither_image( @@ -275,8 +326,8 @@ def main(): np.asarray(out_image).transpose((1, 0, 2))) # flip y/x axes canvas.blit(surface, (0, 0)) pygame.display.flip() - - unique_colours = np.unique(palette_iigs.reshape(-1, 3), axis=0).shape[0] + print((palettes_rgb * 255).astype(np.uint8)) + unique_colours = np.unique(palettes_iigs.reshape(-1, 3), axis=0).shape[0] print("%d unique colours" % unique_colours) # Save Double hi-res image diff --git 
a/dither.pyx b/dither.pyx
index e16e499..d1e53d8 100644
--- a/dither.pyx
+++ b/dither.pyx
@@ -341,7 +341,7 @@ def dither_image(
 def dither_shr(
         float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb,
         float[:,::1] rgb_to_cam16ucs, float penalty):
-    cdef int y, x, idx, best_colour_idx, best_palette
+    cdef int y, x, idx, best_colour_idx, best_palette, i
     cdef double best_distance, distance, total_image_error
     cdef float[::1] best_colour_rgb, pixel_cam, colour_rgb, colour_cam
     cdef float quant_error
@@ -357,6 +357,8 @@ def dither_shr(
     total_image_error = 0.0
     for y in range(200):
         for x in range(320):
+            #for i in range(3):
+            #    working_image[y, x, i] = np.round(working_image[y, x, i] * 15) / 15
             colour_cam = convert_rgb_to_cam16ucs(
                 rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
             line_cam[x, :] = colour_cam
@@ -489,3 +491,94 @@ cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palette
             best_palette_idx = palette_idx
     return best_palette_idx
 
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def k_means_with_fixed_centroids(
+        int n_clusters, int n_fixed, float[:, ::1] samples,
+        float[:, ::1] initial_centroids, int max_iterations, float tolerance):
+    """Run k-means on 3-component samples, keeping the first n_fixed centroids fixed.
+
+    Each round assigns every sample to the centroid for which
+    colour_distance_squared is smallest, then moves each non-fixed centroid
+    that received at least one sample to the mean of its samples.  Centroids
+    with index < n_fixed are never moved, and centroids that received no
+    samples keep their previous position.
+
+    Args:
+        n_clusters: number of centroids; the first n_clusters rows of
+            initial_centroids are used.
+        n_fixed: number of leading centroids that must not be moved.
+        samples: (n_samples, 3) float32 array of points to cluster.
+        initial_centroids: float32 array with at least n_clusters rows of 3
+            components giving the starting centroids.  It is copied and
+            left unmodified.
+        max_iterations: maximum number of assign/update rounds.
+        tolerance: stop once the summed colour_distance_squared movement of
+            the centroids within one round falls below this value.
+
+    Returns:
+        (centroids, total_error): the final centroids, and the sum over all
+        samples of the distance to the chosen centroid as measured in the
+        last assignment step (i.e. before that round's centroid update).
+        total_error is 0.0 when no round was run.
+
+    Raises:
+        ValueError: if n_fixed is not in [0, n_clusters], or
+            initial_centroids has fewer than n_clusters rows.
+    """
+    cdef double error, best_error, centroid_movement, total_error = 0.0
+    cdef int centroid_idx, closest_centroid_idx, i, point_idx, iteration
+
+    cdef float[:, ::1] centroids
+    cdef float[::1] centroid, point, new_centroid = np.empty(3, dtype=np.float32)
+
+    cdef float[:, ::1] centroid_sample_positions_total
+    cdef int[::1] centroid_sample_counts
+
+    if n_fixed < 0 or n_fixed > n_clusters:
+        raise ValueError("n_fixed must be between 0 and n_clusters")
+    if initial_centroids.shape[0] < n_clusters:
+        raise ValueError("initial_centroids must have at least n_clusters rows")
+
+    # Work on a private copy.  Slicing a memoryview ([:, :]) only yields another
+    # view of the caller's buffer, so updating it in place would overwrite the
+    # caller's initial centroids (and leak into later calls that reuse them).
+    centroids = np.array(initial_centroids, dtype=np.float32)
+
+    for iteration in range(max_iterations):
+        total_error = 0.0
+        centroid_movement = 0.0
+        # Sized by n_clusters (not a literal 16) because bounds checking is
+        # disabled for this function.
+        centroid_sample_positions_total = np.zeros((n_clusters, 3), dtype=np.float32)
+        centroid_sample_counts = np.zeros(n_clusters, dtype=np.int32)
+
+        # Assignment step: accumulate each sample into its nearest centroid.
+        for point_idx in range(samples.shape[0]):
+            point = samples[point_idx, :]
+            best_error = 1e9
+            closest_centroid_idx = 0
+            for centroid_idx in range(n_clusters):
+                centroid = centroids[centroid_idx, :]
+                error = colour_distance_squared(centroid, point)
+                if error < best_error:
+                    best_error = error
+                    closest_centroid_idx = centroid_idx
+            for i in range(3):
+                centroid_sample_positions_total[closest_centroid_idx, i] += point[i]
+            centroid_sample_counts[closest_centroid_idx] += 1
+            total_error += best_error
+
+        # Update step: only centroids >= n_fixed that received samples move.
+        for centroid_idx in range(n_fixed, n_clusters):
+            if centroid_sample_counts[centroid_idx]:
+                for i in range(3):
+                    new_centroid[i] = (
+                        centroid_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
+                centroid_movement += colour_distance_squared(centroids[centroid_idx], new_centroid)
+                centroids[centroid_idx, :] = new_centroid
+
+        if centroid_movement < tolerance:
+            break
+
+    return centroids, total_error
diff --git a/screen.py b/screen.py
index dd0b554..8a51d45 100644
--- a/screen.py
+++ b/screen.py
@@ -55,6 +55,7 @@ class SHR320Screen:
         for palette_idx, palette in self.palettes.items():
             for rgb_idx, rgb in enumerate(palette):
                 r, g, b = rgb
+                assert r <= 15 and g <= 15 and b <= 15
                 # print(r, g, b)
                 rgb_low = (g << 4) | b
                 rgb_hi = r