mirror of
https://github.com/KrisKennaway/ii-pix.git
synced 2025-03-13 20:30:18 +00:00
Tweak k-means convergence criterion to return once the total centroid position error stops decreasing.
This commit is contained in:
parent
fc35387360
commit
61b4cbb184
@ -266,14 +266,12 @@ class ClusterPalette:
|
||||
initial_centroids[fixed_colours, :] = colour
|
||||
fixed_colours += 1
|
||||
|
||||
palette_rgb12_iigs, palette_error = \
|
||||
dither_pyx.k_means_with_fixed_centroids(
|
||||
palette_rgb12_iigs = dither_pyx.k_means_with_fixed_centroids(
|
||||
n_clusters=16, n_fixed=fixed_colours,
|
||||
samples=palette_pixels,
|
||||
initial_centroids=initial_centroids,
|
||||
max_iterations=1000, tolerance=0.05,
|
||||
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
|
||||
)
|
||||
max_iterations=1000,
|
||||
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs)
|
||||
# If the k-means clustering returned fewer than 16 unique colours,
|
||||
# fill out the remainder with the most common pixels colours that
|
||||
# have not yet been used.
|
||||
|
45
dither.pyx
45
dither.pyx
@ -672,10 +672,9 @@ def convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
def k_means_with_fixed_centroids(
|
||||
int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations,
|
||||
float tolerance, float [:, ::1] rgb12_iigs_to_cam16ucs):
|
||||
int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations, float [:, ::1] rgb12_iigs_to_cam16ucs):
|
||||
|
||||
cdef double error, best_error, centroid_movement, total_error
|
||||
cdef double error, best_error, total_error, last_total_error
|
||||
cdef int centroid_idx, closest_centroid_idx, i, point_idx
|
||||
|
||||
cdef (unsigned char)[:, ::1] centroids_rgb12 = np.copy(initial_centroids)
|
||||
@ -686,20 +685,17 @@ def k_means_with_fixed_centroids(
|
||||
cdef float[:, ::1] centroid_cam_sample_positions_total
|
||||
cdef int[::1] centroid_sample_counts
|
||||
|
||||
# Allow centroids to move on lattice of size 15/255 in sRGB Rec.601 space -- matches //gs palette
|
||||
# map centroids to CAM when computing distances, cluster means etc
|
||||
# Map new centroid back to closest lattice point
|
||||
|
||||
# Return CAM centroids
|
||||
|
||||
cdef int centroid_moved
|
||||
last_total_error = 1e9
|
||||
for iteration in range(max_iterations):
|
||||
centroid_moved = 1
|
||||
total_error = 0.0
|
||||
centroid_movement = 0.0
|
||||
centroid_cam_sample_positions_total = np.zeros((16, 3), dtype=np.float32)
|
||||
centroid_sample_counts = np.zeros(16, dtype=np.int32)
|
||||
|
||||
# For each sample, associate it to the closest centroid. We want to compute the mean of all associated samples
|
||||
# but we do this by accumulating the (coordinate vector) total and number of associated samples.
|
||||
#
|
||||
# Centroid positions are tracked in 4-bit //gs RGB colour space with distances measured in CAM16UCS colour
|
||||
# space.
|
||||
for point_idx in range(samples.shape[0]):
|
||||
point_cam = samples[point_idx, :]
|
||||
best_error = 1e9
|
||||
@ -715,28 +711,29 @@ def k_means_with_fixed_centroids(
|
||||
centroid_sample_counts[closest_centroid_idx] += 1
|
||||
total_error += best_error
|
||||
|
||||
# Since the allowed centroid positions are discrete (and not uniformly spaced in CAM16UCS colour space), we
|
||||
# can't rely on measuring total centroid movement as a termination condition. e.g. sometimes the nearest
|
||||
# available point to an intended next centroid position will increase the total distance, or centroids may
|
||||
# oscillate between two neighbouring positions. Instead, we terminate when the total error stops decreasing.
|
||||
if total_error >= last_total_error:
|
||||
break
|
||||
last_total_error = total_error
|
||||
|
||||
# Compute new centroid positions in CAM16UCS colour space
|
||||
for centroid_idx in range(n_fixed, n_clusters):
|
||||
if centroid_sample_counts[centroid_idx]:
|
||||
for i in range(3):
|
||||
new_centroids_cam[centroid_idx - n_fixed, i] = (
|
||||
centroid_cam_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
|
||||
centroid_movement += colour_distance_squared(
|
||||
_convert_rgb12_iigs_to_cam(
|
||||
rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx]),
|
||||
new_centroids_cam[centroid_idx - n_fixed, :])
|
||||
|
||||
# Convert all new centroids as a single matrix since _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead
|
||||
# Convert all new centroids back to //gb RGB colour space (done as a single matrix since
|
||||
# _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead)
|
||||
new_centroids_rgb12 = _convert_cam16ucs_to_rgb12_iigs(new_centroids_cam)
|
||||
|
||||
# Update positions for non-fixed centroids
|
||||
for centroid_idx in range(n_clusters - n_fixed):
|
||||
for i in range(3):
|
||||
if centroids_rgb12[centroid_idx + n_fixed, i] != new_centroids_rgb12[centroid_idx, i]:
|
||||
centroids_rgb12[centroid_idx + n_fixed, i] = new_centroids_rgb12[centroid_idx, i]
|
||||
centroid_moved = 1
|
||||
|
||||
if centroid_movement < tolerance:
|
||||
break
|
||||
if centroid_moved == 0:
|
||||
break
|
||||
|
||||
return centroids_rgb12, total_error
|
||||
return centroids_rgb12
|
Loading…
x
Reference in New Issue
Block a user