mirror of
https://github.com/KrisKennaway/ii-pix.git
synced 2024-06-08 06:29:31 +00:00
k-means should be using median with L1 norm, otherwise it may not converge
Also optimize a tiny bit
This commit is contained in:
parent
5cab854269
commit
52af982159
12
dither.pyx
12
dither.pyx
|
@ -464,7 +464,7 @@ def k_means_with_fixed_centroids(
|
|||
cdef int i, iteration, centroid_idx, num_fixed_centroids, num_random_centroids, best_centroid_idx
|
||||
cdef float[::1] point, centroid, new_centroid, old_centroid
|
||||
cdef float[:, ::1] centroids
|
||||
cdef float best_dist, centroid_movement
|
||||
cdef float best_dist, centroid_movement, dist
|
||||
|
||||
centroids = np.zeros((n_clusters, 3), dtype=np.float32)
|
||||
if fixed_centroids is not None:
|
||||
|
@ -473,9 +473,9 @@ def k_means_with_fixed_centroids(
|
|||
num_random_centroids = n_clusters - num_fixed_centroids
|
||||
|
||||
# TODO: kmeans++ initialization
|
||||
cdef int rand_idx = random.randint(0, data.shape[0])
|
||||
for i in range(num_random_centroids):
|
||||
centroids[num_fixed_centroids + i, :] = data[
|
||||
random.randint(0, data.shape[0]), :]
|
||||
centroids[num_fixed_centroids + i, :] = data[rand_idx, :]
|
||||
|
||||
cdef int[::1] centroid_weights = np.zeros(n_clusters, dtype=np.int32)
|
||||
for iteration in range(iterations):
|
||||
|
@ -484,7 +484,8 @@ def k_means_with_fixed_centroids(
|
|||
for point in data:
|
||||
best_dist = 1e9
|
||||
best_centroid_idx = 0
|
||||
for centroid_idx, centroid in enumerate(centroids):
|
||||
for centroid_idx in range(n_clusters):
|
||||
centroid = centroids[centroid_idx, :]
|
||||
dist = colour_distance(centroid, point)
|
||||
if dist < best_dist:
|
||||
best_dist = dist
|
||||
|
@ -492,12 +493,11 @@ def k_means_with_fixed_centroids(
|
|||
closest_points[best_centroid_idx].append(point)
|
||||
|
||||
centroid_movement = 0
|
||||
|
||||
for centroid_idx, points in closest_points.items():
|
||||
centroid_weights[centroid_idx] = len(points)
|
||||
if centroid_idx < num_fixed_centroids:
|
||||
continue
|
||||
new_centroid = np.mean(np.array(points), axis=0)
|
||||
new_centroid = np.median(np.array(points), axis=0)
|
||||
old_centroid = centroids[centroid_idx]
|
||||
centroid_movement += colour_distance(old_centroid, new_centroid)
|
||||
centroids[centroid_idx, :] = new_centroid
|
||||
|
|
Loading…
Reference in New Issue
Block a user