Tweak k-means convergence criterion to return once the total centroid position error stops decreasing.
This commit is contained in:
parent
fc35387360
commit
61b4cbb184
|
@ -266,14 +266,12 @@ class ClusterPalette:
|
||||||
initial_centroids[fixed_colours, :] = colour
|
initial_centroids[fixed_colours, :] = colour
|
||||||
fixed_colours += 1
|
fixed_colours += 1
|
||||||
|
|
||||||
palette_rgb12_iigs, palette_error = \
|
palette_rgb12_iigs = dither_pyx.k_means_with_fixed_centroids(
|
||||||
dither_pyx.k_means_with_fixed_centroids(
|
|
||||||
n_clusters=16, n_fixed=fixed_colours,
|
n_clusters=16, n_fixed=fixed_colours,
|
||||||
samples=palette_pixels,
|
samples=palette_pixels,
|
||||||
initial_centroids=initial_centroids,
|
initial_centroids=initial_centroids,
|
||||||
max_iterations=1000, tolerance=0.05,
|
max_iterations=1000,
|
||||||
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
|
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs)
|
||||||
)
|
|
||||||
# If the k-means clustering returned fewer than 16 unique colours,
|
# If the k-means clustering returned fewer than 16 unique colours,
|
||||||
# fill out the remainder with the most common pixels colours that
|
# fill out the remainder with the most common pixels colours that
|
||||||
# have not yet been used.
|
# have not yet been used.
|
||||||
|
|
45
dither.pyx
45
dither.pyx
|
@ -672,10 +672,9 @@ def convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
|
||||||
@cython.boundscheck(False)
|
@cython.boundscheck(False)
|
||||||
@cython.wraparound(False)
|
@cython.wraparound(False)
|
||||||
def k_means_with_fixed_centroids(
|
def k_means_with_fixed_centroids(
|
||||||
int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations,
|
int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations, float [:, ::1] rgb12_iigs_to_cam16ucs):
|
||||||
float tolerance, float [:, ::1] rgb12_iigs_to_cam16ucs):
|
|
||||||
|
|
||||||
cdef double error, best_error, centroid_movement, total_error
|
cdef double error, best_error, total_error, last_total_error
|
||||||
cdef int centroid_idx, closest_centroid_idx, i, point_idx
|
cdef int centroid_idx, closest_centroid_idx, i, point_idx
|
||||||
|
|
||||||
cdef (unsigned char)[:, ::1] centroids_rgb12 = np.copy(initial_centroids)
|
cdef (unsigned char)[:, ::1] centroids_rgb12 = np.copy(initial_centroids)
|
||||||
|
@ -686,20 +685,17 @@ def k_means_with_fixed_centroids(
|
||||||
cdef float[:, ::1] centroid_cam_sample_positions_total
|
cdef float[:, ::1] centroid_cam_sample_positions_total
|
||||||
cdef int[::1] centroid_sample_counts
|
cdef int[::1] centroid_sample_counts
|
||||||
|
|
||||||
# Allow centroids to move on lattice of size 15/255 in sRGB Rec.601 space -- matches //gs palette
|
last_total_error = 1e9
|
||||||
# map centroids to CAM when computing distances, cluster means etc
|
|
||||||
# Map new centroid back to closest lattice point
|
|
||||||
|
|
||||||
# Return CAM centroids
|
|
||||||
|
|
||||||
cdef int centroid_moved
|
|
||||||
for iteration in range(max_iterations):
|
for iteration in range(max_iterations):
|
||||||
centroid_moved = 1
|
|
||||||
total_error = 0.0
|
total_error = 0.0
|
||||||
centroid_movement = 0.0
|
|
||||||
centroid_cam_sample_positions_total = np.zeros((16, 3), dtype=np.float32)
|
centroid_cam_sample_positions_total = np.zeros((16, 3), dtype=np.float32)
|
||||||
centroid_sample_counts = np.zeros(16, dtype=np.int32)
|
centroid_sample_counts = np.zeros(16, dtype=np.int32)
|
||||||
|
|
||||||
|
# For each sample, associate it to the closest centroid. We want to compute the mean of all associated samples
|
||||||
|
# but we do this by accumulating the (coordinate vector) total and number of associated samples.
|
||||||
|
#
|
||||||
|
# Centroid positions are tracked in 4-bit //gs RGB colour space with distances measured in CAM16UCS colour
|
||||||
|
# space.
|
||||||
for point_idx in range(samples.shape[0]):
|
for point_idx in range(samples.shape[0]):
|
||||||
point_cam = samples[point_idx, :]
|
point_cam = samples[point_idx, :]
|
||||||
best_error = 1e9
|
best_error = 1e9
|
||||||
|
@ -715,28 +711,29 @@ def k_means_with_fixed_centroids(
|
||||||
centroid_sample_counts[closest_centroid_idx] += 1
|
centroid_sample_counts[closest_centroid_idx] += 1
|
||||||
total_error += best_error
|
total_error += best_error
|
||||||
|
|
||||||
|
# Since the allowed centroid positions are discrete (and not uniformly spaced in CAM16UCS colour space), we
|
||||||
|
# can't rely on measuring total centroid movement as a termination condition. e.g. sometimes the nearest
|
||||||
|
# available point to an intended next centroid position will increase the total distance, or centroids may
|
||||||
|
# oscillate between two neighbouring positions. Instead, we terminate when the total error stops decreasing.
|
||||||
|
if total_error >= last_total_error:
|
||||||
|
break
|
||||||
|
last_total_error = total_error
|
||||||
|
|
||||||
|
# Compute new centroid positions in CAM16UCS colour space
|
||||||
for centroid_idx in range(n_fixed, n_clusters):
|
for centroid_idx in range(n_fixed, n_clusters):
|
||||||
if centroid_sample_counts[centroid_idx]:
|
if centroid_sample_counts[centroid_idx]:
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
new_centroids_cam[centroid_idx - n_fixed, i] = (
|
new_centroids_cam[centroid_idx - n_fixed, i] = (
|
||||||
centroid_cam_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
|
centroid_cam_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
|
||||||
centroid_movement += colour_distance_squared(
|
|
||||||
_convert_rgb12_iigs_to_cam(
|
|
||||||
rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx]),
|
|
||||||
new_centroids_cam[centroid_idx - n_fixed, :])
|
|
||||||
|
|
||||||
# Convert all new centroids as a single matrix since _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead
|
# Convert all new centroids back to //gb RGB colour space (done as a single matrix since
|
||||||
|
# _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead)
|
||||||
new_centroids_rgb12 = _convert_cam16ucs_to_rgb12_iigs(new_centroids_cam)
|
new_centroids_rgb12 = _convert_cam16ucs_to_rgb12_iigs(new_centroids_cam)
|
||||||
|
|
||||||
|
# Update positions for non-fixed centroids
|
||||||
for centroid_idx in range(n_clusters - n_fixed):
|
for centroid_idx in range(n_clusters - n_fixed):
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
if centroids_rgb12[centroid_idx + n_fixed, i] != new_centroids_rgb12[centroid_idx, i]:
|
if centroids_rgb12[centroid_idx + n_fixed, i] != new_centroids_rgb12[centroid_idx, i]:
|
||||||
centroids_rgb12[centroid_idx + n_fixed, i] = new_centroids_rgb12[centroid_idx, i]
|
centroids_rgb12[centroid_idx + n_fixed, i] = new_centroids_rgb12[centroid_idx, i]
|
||||||
centroid_moved = 1
|
|
||||||
|
|
||||||
if centroid_movement < tolerance:
|
return centroids_rgb12
|
||||||
break
|
|
||||||
if centroid_moved == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
return centroids_rgb12, total_error
|
|
Loading…
Reference in New Issue