Tweak k-means convergence criterion to return once the total centroid position error stops decreasing.

This commit is contained in:
kris 2021-11-25 21:33:12 +00:00
parent fc35387360
commit 61b4cbb184
2 changed files with 24 additions and 29 deletions

View File

@ -266,14 +266,12 @@ class ClusterPalette:
initial_centroids[fixed_colours, :] = colour
fixed_colours += 1
palette_rgb12_iigs, palette_error = \
dither_pyx.k_means_with_fixed_centroids(
palette_rgb12_iigs = dither_pyx.k_means_with_fixed_centroids(
n_clusters=16, n_fixed=fixed_colours,
samples=palette_pixels,
initial_centroids=initial_centroids,
max_iterations=1000, tolerance=0.05,
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
)
max_iterations=1000,
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs)
# If the k-means clustering returned fewer than 16 unique colours,
# fill out the remainder with the most common pixel colours that
# have not yet been used.

View File

@ -672,10 +672,9 @@ def convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
@cython.boundscheck(False)
@cython.wraparound(False)
def k_means_with_fixed_centroids(
int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations,
float tolerance, float [:, ::1] rgb12_iigs_to_cam16ucs):
int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations, float [:, ::1] rgb12_iigs_to_cam16ucs):
cdef double error, best_error, centroid_movement, total_error
cdef double error, best_error, total_error, last_total_error
cdef int centroid_idx, closest_centroid_idx, i, point_idx
cdef (unsigned char)[:, ::1] centroids_rgb12 = np.copy(initial_centroids)
@ -686,20 +685,17 @@ def k_means_with_fixed_centroids(
cdef float[:, ::1] centroid_cam_sample_positions_total
cdef int[::1] centroid_sample_counts
# Allow centroids to move on lattice of size 15/255 in sRGB Rec.601 space -- matches //gs palette
# map centroids to CAM when computing distances, cluster means etc
# Map new centroid back to closest lattice point
# Return CAM centroids
cdef int centroid_moved
last_total_error = 1e9
for iteration in range(max_iterations):
centroid_moved = 1
total_error = 0.0
centroid_movement = 0.0
centroid_cam_sample_positions_total = np.zeros((16, 3), dtype=np.float32)
centroid_sample_counts = np.zeros(16, dtype=np.int32)
# For each sample, associate it to the closest centroid. We want to compute the mean of all associated samples
# but we do this by accumulating the (coordinate vector) total and number of associated samples.
#
# Centroid positions are tracked in 4-bit //gs RGB colour space with distances measured in CAM16UCS colour
# space.
for point_idx in range(samples.shape[0]):
point_cam = samples[point_idx, :]
best_error = 1e9
@ -715,28 +711,29 @@ def k_means_with_fixed_centroids(
centroid_sample_counts[closest_centroid_idx] += 1
total_error += best_error
# Since the allowed centroid positions are discrete (and not uniformly spaced in CAM16UCS colour space), we
# can't rely on measuring total centroid movement as a termination condition. e.g. sometimes the nearest
# available point to an intended next centroid position will increase the total distance, or centroids may
# oscillate between two neighbouring positions. Instead, we terminate when the total error stops decreasing.
if total_error >= last_total_error:
break
last_total_error = total_error
# Compute new centroid positions in CAM16UCS colour space
for centroid_idx in range(n_fixed, n_clusters):
if centroid_sample_counts[centroid_idx]:
for i in range(3):
new_centroids_cam[centroid_idx - n_fixed, i] = (
centroid_cam_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
centroid_movement += colour_distance_squared(
_convert_rgb12_iigs_to_cam(
rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx]),
new_centroids_cam[centroid_idx - n_fixed, :])
# Convert all new centroids as a single matrix since _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead
# Convert all new centroids back to //gs RGB colour space (done as a single matrix since
# _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead)
new_centroids_rgb12 = _convert_cam16ucs_to_rgb12_iigs(new_centroids_cam)
# Update positions for non-fixed centroids
for centroid_idx in range(n_clusters - n_fixed):
for i in range(3):
if centroids_rgb12[centroid_idx + n_fixed, i] != new_centroids_rgb12[centroid_idx, i]:
centroids_rgb12[centroid_idx + n_fixed, i] = new_centroids_rgb12[centroid_idx, i]
centroid_moved = 1
if centroid_movement < tolerance:
break
if centroid_moved == 0:
break
return centroids_rgb12, total_error
return centroids_rgb12