This commit is contained in:
kris 2021-11-16 21:07:13 +00:00
parent 613a36909c
commit bb70eea7b0
1 changed files with 4 additions and 61 deletions

View File

@ -335,11 +335,12 @@ def dither_image(
free(cdither.pattern)
return image_nbit_to_bitmap(image_nbit, xres, yres, palette_depth)
import colour
@cython.boundscheck(False)
@cython.wraparound(False)
def dither_shr(float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb, float[:,::1] rgb_to_cam16ucs, float penalty):
def dither_shr(
float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb,
float[:,::1] rgb_to_cam16ucs, float penalty):
cdef int y, x, idx, best_colour_idx, best_palette
cdef double best_distance, distance, total_image_error
cdef float[::1] best_colour_rgb, pixel_cam, colour_rgb, colour_cam
@ -355,14 +356,12 @@ def dither_shr(float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[
best_palette = 15
total_image_error = 0.0
for y in range(200):
# print(y)
for x in range(320):
colour_cam = convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
line_cam[x, :] = colour_cam
best_palette = best_palette_for_line(line_cam, palettes_cam, <int>(y * 16 / 200), best_palette, penalty)
# print("-->", best_palette)
palette_rgb = palettes_rgb[best_palette, :, :]
line_to_palette[y] = best_palette
@ -382,7 +381,6 @@ def dither_shr(float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[
best_colour_rgb = palette_rgb[best_colour_idx]
output_4bit[y, x] = best_colour_idx
total_image_error += best_distance
# print(y,x,best_distance,total_image_error)
for i in range(3):
quant_error = working_image[y, x, i] - best_colour_rgb[i]
@ -395,6 +393,7 @@ def dither_shr(float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[
working_image[y, x + 1, i] = clip(
working_image[y, x + 1, i] + quant_error * (7 / 16), 0, 1)
if y < 199:
# TODO: parametrize the 0.5x decay factor
if x > 0:
working_image[y + 1, x - 1, i] = clip(
working_image[y + 1, x - 1, i] + quant_error * (3 / 32), 0, 1)
@ -454,62 +453,6 @@ def dither_shr(float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[
return np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error
import collections
import random
@cython.boundscheck(False)
@cython.wraparound(False)
def k_means_with_fixed_centroids(
int n_clusters, float[:, ::1] data, float[:, ::1] fixed_centroids = None,
int iterations = 10000, float tolerance = 1e-3):
cdef int i, iteration, centroid_idx, num_fixed_centroids, num_random_centroids, best_centroid_idx
cdef float[::1] point, centroid, new_centroid, old_centroid
cdef float[:, ::1] centroids
cdef float best_dist, centroid_movement, dist
centroids = np.zeros((n_clusters, 3), dtype=np.float32)
if fixed_centroids is not None:
centroids[:fixed_centroids.shape[0], :] = fixed_centroids
num_fixed_centroids = fixed_centroids.shape[0] if fixed_centroids is not None else 0
num_random_centroids = n_clusters - num_fixed_centroids
# TODO: kmeans++ initialization
cdef int rand_idx = random.randint(0, data.shape[0])
for i in range(num_random_centroids):
centroids[num_fixed_centroids + i, :] = data[rand_idx, :]
cdef int[::1] centroid_weights = np.zeros(n_clusters, dtype=np.int32)
for iteration in range(iterations):
# print("centroids ", centroids)
closest_points = collections.defaultdict(list)
for point in data:
best_dist = 1e9
best_centroid_idx = 0
for centroid_idx in range(n_clusters):
centroid = centroids[centroid_idx, :]
dist = colour_distance(centroid, point)
if dist < best_dist:
best_dist = dist
best_centroid_idx = centroid_idx
closest_points[best_centroid_idx].append(point)
centroid_movement = 0
for centroid_idx, points in closest_points.items():
centroid_weights[centroid_idx] = len(points)
if centroid_idx < num_fixed_centroids:
continue
new_centroid = np.median(np.array(points), axis=0)
old_centroid = centroids[centroid_idx]
centroid_movement += colour_distance(old_centroid, new_centroid)
centroids[centroid_idx, :] = new_centroid
# print("iteration %d: movement %f" % (iteration, centroid_movement))
if centroid_movement < tolerance:
break
weighted_centroids = list(zip(centroid_weights, [tuple(c) for c in centroids]))
print(weighted_centroids)
return np.array([c for w, c in sorted(weighted_centroids, reverse=True)], dtype=np.float32)
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int base_palette_idx, int last_palette_idx, float last_penalty) nogil: