From 61b4cbb18499cadc94bc1e872ce287eb9a3d3718 Mon Sep 17 00:00:00 2001 From: kris Date: Thu, 25 Nov 2021 21:33:12 +0000 Subject: [PATCH] Tweak k-means convergence criterion to return once the total centroid position error stops decreasing. --- convert.py | 8 +++----- dither.pyx | 45 +++++++++++++++++++++------------------------ 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/convert.py b/convert.py index 86298ea..68a7d4d 100644 --- a/convert.py +++ b/convert.py @@ -266,14 +266,12 @@ class ClusterPalette: initial_centroids[fixed_colours, :] = colour fixed_colours += 1 - palette_rgb12_iigs, palette_error = \ - dither_pyx.k_means_with_fixed_centroids( + palette_rgb12_iigs = dither_pyx.k_means_with_fixed_centroids( n_clusters=16, n_fixed=fixed_colours, samples=palette_pixels, initial_centroids=initial_centroids, - max_iterations=1000, tolerance=0.05, - rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs - ) + max_iterations=1000, + rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs) # If the k-means clustering returned fewer than 16 unique colours, # fill out the remainder with the most common pixels colours that # have not yet been used. diff --git a/dither.pyx b/dither.pyx index d7a63df..60152b2 100644 --- a/dither.pyx +++ b/dither.pyx @@ -672,10 +672,9 @@ def convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam): @cython.boundscheck(False) @cython.wraparound(False) def k_means_with_fixed_centroids( - int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations, - float tolerance, float [:, ::1] rgb12_iigs_to_cam16ucs): + int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations, float [:, ::1] rgb12_iigs_to_cam16ucs): - cdef double error, best_error, centroid_movement, total_error + cdef double error, best_error, total_error, last_total_error cdef int centroid_idx, closest_centroid_idx, i, point_idx cdef (unsigned char)[:, ::1] centroids_rgb12 = np.copy(initial_centroids) @@ -686,20 +685,17 @@ def k_means_with_fixed_centroids( cdef float[:, ::1] centroid_cam_sample_positions_total cdef int[::1] centroid_sample_counts - # Allow centroids to move on lattice of size 15/255 in sRGB Rec.601 space -- matches //gs palette - # map centroids to CAM when computing distances, cluster means etc - # Map new centroid back to closest lattice point - - # Return CAM centroids - - cdef int centroid_moved + last_total_error = 1e9 for iteration in range(max_iterations): - centroid_moved = 1 total_error = 0.0 - centroid_movement = 0.0 centroid_cam_sample_positions_total = np.zeros((16, 3), dtype=np.float32) centroid_sample_counts = np.zeros(16, dtype=np.int32) + # For each sample, associate it to the closest centroid. We want to compute the mean of all associated samples + # but we do this by accumulating the (coordinate vector) total and number of associated samples. + # + # Centroid positions are tracked in 4-bit //gs RGB colour space with distances measured in CAM16UCS colour + # space. for point_idx in range(samples.shape[0]): point_cam = samples[point_idx, :] best_error = 1e9 @@ -715,28 +711,29 @@ def k_means_with_fixed_centroids( centroid_sample_counts[closest_centroid_idx] += 1 total_error += best_error + # Since the allowed centroid positions are discrete (and not uniformly spaced in CAM16UCS colour space), we + # can't rely on measuring total centroid movement as a termination condition. e.g. sometimes the nearest + # available point to an intended next centroid position will increase the total distance, or centroids may + # oscillate between two neighbouring positions. Instead, we terminate when the total error stops decreasing. + if total_error >= last_total_error: + break + last_total_error = total_error + + # Compute new centroid positions in CAM16UCS colour space for centroid_idx in range(n_fixed, n_clusters): if centroid_sample_counts[centroid_idx]: for i in range(3): new_centroids_cam[centroid_idx - n_fixed, i] = ( centroid_cam_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx]) - centroid_movement += colour_distance_squared( - _convert_rgb12_iigs_to_cam( - rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx]), - new_centroids_cam[centroid_idx - n_fixed, :]) - # Convert all new centroids as a single matrix since _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead + # Convert all new centroids back to //gb RGB colour space (done as a single matrix since + # _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead) new_centroids_rgb12 = _convert_cam16ucs_to_rgb12_iigs(new_centroids_cam) + # Update positions for non-fixed centroids for centroid_idx in range(n_clusters - n_fixed): for i in range(3): if centroids_rgb12[centroid_idx + n_fixed, i] != new_centroids_rgb12[centroid_idx, i]: centroids_rgb12[centroid_idx + n_fixed, i] = new_centroids_rgb12[centroid_idx, i] - centroid_moved = 1 - if centroid_movement < tolerance: - break - if centroid_moved == 0: - break - - return centroids_rgb12, total_error \ No newline at end of file + return centroids_rgb12 \ No newline at end of file