diff --git a/dither.pyx b/dither.pyx index 31277df..109611e 100644 --- a/dither.pyx +++ b/dither.pyx @@ -464,7 +464,7 @@ def k_means_with_fixed_centroids( cdef int i, iteration, centroid_idx, num_fixed_centroids, num_random_centroids, best_centroid_idx cdef float[::1] point, centroid, new_centroid, old_centroid cdef float[:, ::1] centroids - cdef float best_dist, centroid_movement + cdef float best_dist, centroid_movement, dist centroids = np.zeros((n_clusters, 3), dtype=np.float32) if fixed_centroids is not None: @@ -473,9 +473,9 @@ def k_means_with_fixed_centroids( num_random_centroids = n_clusters - num_fixed_centroids # TODO: kmeans++ initialization + cdef int rand_idx = random.randint(0, data.shape[0]) for i in range(num_random_centroids): - centroids[num_fixed_centroids + i, :] = data[ - random.randint(0, data.shape[0]), :] + centroids[num_fixed_centroids + i, :] = data[rand_idx, :] cdef int[::1] centroid_weights = np.zeros(n_clusters, dtype=np.int32) for iteration in range(iterations): @@ -484,7 +484,8 @@ def k_means_with_fixed_centroids( for point in data: best_dist = 1e9 best_centroid_idx = 0 - for centroid_idx, centroid in enumerate(centroids): + for centroid_idx in range(n_clusters): + centroid = centroids[centroid_idx, :] dist = colour_distance(centroid, point) if dist < best_dist: best_dist = dist @@ -492,12 +493,11 @@ def k_means_with_fixed_centroids( closest_points[best_centroid_idx].append(point) centroid_movement = 0 - for centroid_idx, points in closest_points.items(): centroid_weights[centroid_idx] = len(points) if centroid_idx < num_fixed_centroids: continue - new_centroid = np.mean(np.array(points), axis=0) + new_centroid = np.median(np.array(points), axis=0) old_centroid = centroids[centroid_idx] centroid_movement += colour_distance(old_centroid, new_centroid) centroids[centroid_idx, :] = new_centroid