# cython: infer_types=True
# cython: profile=False
# cython: boundscheck=False
# cython: wraparound=False

cimport cython
import colour
import numpy as np

cimport common


# Adds `delta` to channel `i` of image[y, x] and clips the result to [0, 1].  Targets that fall outside the
# height x width image are skipped, so callers can walk a whole diffusion kernel without repeating the edge tests.
# Always returns 0 (an int return keeps the call cheap; there is no meaningful result).
cdef inline int _add_clipped(
        float[:, :, ::1] image, int y, int x, int i, double delta, int height, int width):
    if x < 0 or x >= width or y >= height:
        return 0
    image[y, x, i] = common.clip(image[y, x, i] + delta, 0, 1)
    return 0


# Distributes the quantisation error of channel `i` at pixel (y, x) onto the not-yet-visited neighbours.
#
# Error sent to the current row is applied at full strength; error sent one row down is scaled by `decay` and error
# sent two rows down by `decay * decay`.  The kernel weights are written with float literals on purpose: with int
# literals (e.g. `7 / 16`) they would truncate to 0 if the module were compiled with language_level=2.
# Always returns 0.
cdef inline int _diffuse_error(
        float[:, :, ::1] image, int y, int x, int i, float quant_error, float decay, int floyd_steinberg,
        int height, int width):
    if floyd_steinberg:
        # Floyd-Steinberg dither
        # 0 * 7
        # 3 5 1
        _add_clipped(image, y, x + 1, i, quant_error * (7.0 / 16), height, width)

        _add_clipped(image, y + 1, x - 1, i, decay * quant_error * (3.0 / 16), height, width)
        _add_clipped(image, y + 1, x, i, decay * quant_error * (5.0 / 16), height, width)
        _add_clipped(image, y + 1, x + 1, i, decay * quant_error * (1.0 / 16), height, width)
    else:
        # Jarvis
        # 0 0 X 7 5
        # 3 5 7 5 3
        # 1 3 5 3 1
        _add_clipped(image, y, x + 1, i, quant_error * (7.0 / 48), height, width)
        _add_clipped(image, y, x + 2, i, quant_error * (5.0 / 48), height, width)

        _add_clipped(image, y + 1, x - 2, i, decay * quant_error * (3.0 / 48), height, width)
        _add_clipped(image, y + 1, x - 1, i, decay * quant_error * (5.0 / 48), height, width)
        _add_clipped(image, y + 1, x, i, decay * quant_error * (7.0 / 48), height, width)
        _add_clipped(image, y + 1, x + 1, i, decay * quant_error * (5.0 / 48), height, width)
        _add_clipped(image, y + 1, x + 2, i, decay * quant_error * (3.0 / 48), height, width)

        _add_clipped(image, y + 2, x - 2, i, decay * decay * quant_error * (1.0 / 48), height, width)
        _add_clipped(image, y + 2, x - 1, i, decay * decay * quant_error * (3.0 / 48), height, width)
        _add_clipped(image, y + 2, x, i, decay * decay * quant_error * (5.0 / 48), height, width)
        _add_clipped(image, y + 2, x + 1, i, decay * decay * quant_error * (3.0 / 48), height, width)
        _add_clipped(image, y + 2, x + 2, i, decay * decay * quant_error * (1.0 / 48), height, width)
    return 0


def dither_shr_perfect(
        float[:, :, ::1] input_rgb, float[:, ::1] full_palette_cam, float[:, ::1] full_palette_rgb,
        float[:,::1] rgb_to_cam16ucs):
    """Dither an image against a single palette shared by every pixel.

    Each pixel is converted with common.convert_rgb_to_cam16ucs, matched to the palette entry with the smallest
    common.colour_distance_squared, replaced by that entry's RGB value, and the per-channel quantisation error is
    diffused to neighbouring pixels.

    Args:
        input_rgb: (height, width, 3) image; it is copied, not modified.
        full_palette_cam: palette entries, one row of 3 components per entry, compared against the converted pixel.
        full_palette_rgb: the same palette entries as RGB rows; its row count defines the palette size.
        rgb_to_cam16ucs: table handed through to common.convert_rgb_to_cam16ucs.

    Returns:
        (total_image_error, working_image): the sum of the best squared distances over all pixels, and the
        dithered copy of the image.

    Raises:
        ValueError: if the palette has no entries.
    """
    cdef int y, x, idx, best_colour_idx, i, j
    cdef double best_distance, distance, total_image_error
    cdef float[::1] best_colour_rgb
    cdef float quant_error

    cdef float[:, :, ::1] working_image = np.copy(input_rgb)
    cdef int height = working_image.shape[0]
    cdef int width = working_image.shape[1]
    cdef int palette_size = full_palette_rgb.shape[0]

    cdef float decay = 0.5
    # Hard-wired kernel selection: non-zero picks Floyd-Steinberg, zero picks Jarvis.
    cdef int floyd_steinberg = 1

    cdef common.float3 cam, pixel_cam

    if palette_size == 0:
        raise ValueError("palette must contain at least one colour")

    total_image_error = 0.0
    for y in range(height):
        for x in range(width):
            # Convert at visit time: the pixel already carries error diffused from earlier pixels.
            pixel_cam = common.convert_rgb_to_cam16ucs(
                rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])

            best_distance = 1e9
            # Start at a valid index: negative indices are not wrapped in this module (wraparound=False).
            best_colour_idx = 0
            for idx in range(palette_size):
                for j in range(3):
                    cam.data[j] = full_palette_cam[idx, j]
                distance = common.colour_distance_squared(pixel_cam.data, cam.data)
                if distance < best_distance:
                    best_distance = distance
                    best_colour_idx = idx
            best_colour_rgb = full_palette_rgb[best_colour_idx, :]
            total_image_error += best_distance

            for i in range(3):
                quant_error = working_image[y, x, i] - best_colour_rgb[i]
                working_image[y, x, i] = best_colour_rgb[i]
                _diffuse_error(working_image, y, x, i, quant_error, decay, floyd_steinberg, height, width)

    return total_image_error, working_image


def dither_shr(
        float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb,
        float[:,::1] rgb_to_cam16ucs):
    """Dither an image line by line, choosing one palette per line.

    For every line the palette with the lowest summed best-match distance is selected via best_palette_for_line;
    each pixel of the line is then matched to the closest entry of that palette and the quantisation error is
    diffused to neighbouring pixels.

    Args:
        input_rgb: (height, width, 3) image; it is copied, not modified.
        palettes_cam: (palette, entry, 3) palette entries compared against the converted pixels.
        palettes_rgb: the same palettes as RGB values, indexed identically.
        rgb_to_cam16ucs: table handed through to common.convert_rgb_to_cam16ucs.

    Returns:
        A 4-tuple of
        * uint8 array (height, width) with the chosen palette entry index per pixel,
        * int32 buffer (height) with the chosen palette index per line,
        * the sum of the best squared distances over all pixels,
        * float64 array (height) with the per-line error reported by best_palette_for_line.

    Raises:
        ValueError: if there are no palettes or the palettes have no entries.
    """
    cdef int y, x, idx, best_colour_idx, best_palette, i, j
    cdef double best_distance, distance, total_image_error
    cdef float[::1] best_colour_rgb
    cdef float quant_error
    cdef float[:, ::1] palette_rgb, palette_cam

    cdef int height = input_rgb.shape[0]
    cdef int width = input_rgb.shape[1]
    cdef int palette_size = palettes_cam.shape[1]

    cdef (unsigned char)[:, ::1] output_4bit = np.zeros((height, width), dtype=np.uint8)
    cdef float[:, :, ::1] working_image = np.copy(input_rgb)
    cdef float[:, ::1] line_cam = np.zeros((width, 3), dtype=np.float32)
    cdef int[::1] line_to_palette = np.zeros(height, dtype=np.int32)
    cdef double[::1] palette_line_errors = np.zeros(height, dtype=np.float64)
    cdef PaletteSelection palette_line

    cdef float decay = 0.5
    # Hard-wired kernel selection: non-zero picks Floyd-Steinberg, zero picks Jarvis.
    cdef int floyd_steinberg = 1

    cdef common.float3 pixel_cam, cam

    if palettes_cam.shape[0] == 0 or palette_size == 0:
        raise ValueError("palettes must contain at least one palette with at least one colour")

    best_palette = -1
    total_image_error = 0.0
    for y in range(height):
        # Snapshot the line (including error diffused from earlier lines) to pick its palette.
        for x in range(width):
            pixel_cam = common.convert_rgb_to_cam16ucs(
                rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])
            for j in range(3):
                line_cam[x, j] = pixel_cam.data[j]

        palette_line = best_palette_for_line(line_cam, palettes_cam, best_palette)
        best_palette = palette_line.palette_idx
        palette_line_errors[y] = palette_line.total_error
        palette_rgb = palettes_rgb[best_palette, :, :]
        palette_cam = palettes_cam[best_palette, :, :]
        line_to_palette[y] = best_palette

        for x in range(width):
            # Re-convert at visit time: error diffused within this line has changed the pixel since the snapshot.
            pixel_cam = common.convert_rgb_to_cam16ucs(
                rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])

            best_distance = 1e9
            # Start at a valid index: negative indices are not wrapped in this module (wraparound=False).
            best_colour_idx = 0
            for idx in range(palette_size):
                for j in range(3):
                    cam.data[j] = palette_cam[idx, j]
                distance = common.colour_distance_squared(pixel_cam.data, cam.data)
                if distance < best_distance:
                    best_distance = distance
                    best_colour_idx = idx
            best_colour_rgb = palette_rgb[best_colour_idx]
            output_4bit[y, x] = best_colour_idx
            total_image_error += best_distance

            for i in range(3):
                quant_error = working_image[y, x, i] - best_colour_rgb[i]
                working_image[y, x, i] = best_colour_rgb[i]
                _diffuse_error(working_image, y, x, i, quant_error, decay, floyd_steinberg, height, width)

    return (
        np.array(output_4bit, dtype=np.uint8),
        line_to_palette,
        total_image_error,
        np.array(palette_line_errors, dtype=np.float64)
    )


# Result of best_palette_for_line: the winning palette index and its summed error over the line.
cdef struct PaletteSelection:
    int palette_idx
    double total_error


# Selects the palette that minimises the sum, over all pixels of the line, of the squared distance between the pixel
# and its closest palette entry.  Palette and entry counts are taken from the shape of palettes_cam.
#
# last_palette_idx is accepted but not currently used by the selection.
cdef PaletteSelection best_palette_for_line(
        float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int last_palette_idx) nogil:
    cdef int palette_idx, best_palette_idx, palette_entry_idx, pixel_idx
    cdef double best_total_dist, total_dist, best_pixel_dist, pixel_dist
    cdef float[:, ::1] palette_cam
    cdef common.float3 pixel_cam, cam
    cdef int j
    cdef int line_size = line_cam.shape[0]
    cdef int palette_count = palettes_cam.shape[0]
    cdef int palette_size = palettes_cam.shape[1]
    cdef PaletteSelection res

    best_total_dist = 1e9
    # Default to a valid index so the caller never indexes with -1 when no palette scores below the initial bound.
    best_palette_idx = 0
    for palette_idx in range(palette_count):
        palette_cam = palettes_cam[palette_idx, :, :]
        total_dist = 0
        for pixel_idx in range(line_size):
            for j in range(3):
                pixel_cam.data[j] = line_cam[pixel_idx, j]
            best_pixel_dist = 1e9
            for palette_entry_idx in range(palette_size):
                for j in range(3):
                    cam.data[j] = palette_cam[palette_entry_idx, j]
                pixel_dist = common.colour_distance_squared(pixel_cam.data, cam.data)
                if pixel_dist < best_pixel_dist:
                    best_pixel_dist = pixel_dist
            total_dist += best_pixel_dist
        if total_dist < best_total_dist:
            best_total_dist = total_dist
            best_palette_idx = palette_idx

    res.palette_idx = best_palette_idx
    res.total_error = best_total_dist
    return res


# Looks up the table row for a colour given as three components: the components are packed into a single row index
# as (c0 << 8) | (c1 << 4) | c2 and the first three columns of that row are returned.
cdef common.float3 _convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs, (unsigned char)[::1] point_rgb12) nogil:
    cdef int rgb12 = (point_rgb12[0] << 8) | (point_rgb12[1] << 4) | point_rgb12[2]
    cdef int i
    cdef common.float3 res

    for i in range(3):
        res.data[i] = rgb12_iigs_to_cam16ucs[rgb12, i]
    return res


# Wrapper around _convert_rgb12_iigs_to_cam to allow calling from python while retaining fast path for cython calls.
def convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs, (unsigned char)[::1] point_rgb12) -> float[::1]:
    """Return the table entry for one colour as a new 3-element float32 buffer.

    Args:
        rgb12_iigs_to_cam16ucs: lookup table, indexed as described for _convert_rgb12_iigs_to_cam.
        point_rgb12: the three colour components used to build the table index.
    """
    cdef common.float3 cam = _convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, point_rgb12)
    cdef int i
    cdef float[::1] res = np.empty((3), dtype=np.float32)
    for i in range(3):
        res[i] = cam.data[i]
    return res


# Applies the piecewise sRGB transfer function element-wise and returns the result as a new float32 array:
# values <= 0.0031308 are scaled by 12.92, larger values map to 1.055 * v ** (1 / gamma) - 0.055.
# Every column of `a` is processed (the column count is taken from a.shape[1]).
@cython.cdivision(True)
cdef float[:, ::1] linear_to_srgb_array(float[:, ::1] a, float gamma=2.4):
    cdef int i, j
    cdef int n_rows = a.shape[0]
    cdef int n_cols = a.shape[1]
    # Loop-invariant, so computed once.  cdivision is enabled, so this cannot raise.
    cdef double exponent = 1.0 / gamma
    cdef float value
    cdef float[:, ::1] res = np.empty_like(a, dtype=np.float32)
    for i in range(n_rows):
        for j in range(n_cols):
            value = a[i, j]
            if value <= 0.0031308:
                res[i, j] = value * 12.92
            else:
                res[i, j] = 1.055 * value ** exponent - 0.055
    return res


# Converts rows of CAM16UCS coordinates to rows of colour components scaled to 0..15 and rounded to uint8.
# TODO: optimize
cdef (unsigned char)[:, ::1] _convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
    cdef float[:, ::1] rgb
    cdef (float)[:, ::1] rgb12_iigs

    # Convert CAM16UCS input to RGB.  Even though this dynamically constructs a path on the graph of colour conversions
    # every time, in practice this seems to have a negligible overhead compared to the actual conversion functions.
    with colour.utilities.suppress_warnings(python_warnings=True):
        rgb = colour.convert(point_cam, "CAM16UCS", "RGB").astype(np.float32)

    # TODO: precompute this conversion matrix since it's static.  This accounts for about 10% of the CPU time here.
    rgb12_iigs = np.ascontiguousarray(
        np.clip(
            # Convert to Rec.601 R'G'B'
            colour.YCbCr_to_RGB(
                # Gamma correct and convert Rec.709 R'G'B' to YCbCr
                colour.RGB_to_YCbCr(
                    linear_to_srgb_array(rgb), K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
                K=colour.WEIGHTS_YCBCR['ITU-R BT.601']),
            0, 1)
    ).astype(np.float32) * 15
    return np.round(rgb12_iigs).astype(np.uint8)


# Wrapper around _convert_cam16ucs_to_rgb12_iigs to allow calling from python while retaining fast path for cython
# calls.
def convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
    """Python-callable entry point; forwards to _convert_cam16ucs_to_rgb12_iigs unchanged."""
    return _convert_cam16ucs_to_rgb12_iigs(point_cam)


@cython.cdivision(True)
def k_means_with_fixed_centroids(
        int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids,
        int max_iterations, float [:, ::1] rgb12_iigs_to_cam16ucs):
    """Run k-means where the first n_fixed centroids never move.

    Args:
        n_clusters: total number of centroids (fixed plus free).
        n_fixed: number of leading centroids that are kept at their initial position.
        samples: sample points, one row of 3 CAM16UCS components per sample.
        initial_centroids: starting centroid positions as rows of 3 components; copied, not modified.
        max_iterations: upper bound on the number of assignment/update rounds.
        rgb12_iigs_to_cam16ucs: lookup table used to map centroid positions into the space of the samples.

    Returns:
        The centroid positions (same layout as initial_centroids) at the point the loop stopped.  When the loop
        stops because the error did not decrease, these are the positions whose error was just measured.
    """
    cdef double error, best_error, total_error, last_total_error
    cdef int centroid_idx, closest_centroid_idx, i, j, point_idx, iteration
    cdef (unsigned char)[:, ::1] centroids_rgb12 = np.copy(initial_centroids)
    cdef (unsigned char)[:, ::1] new_centroids_rgb12

    cdef common.float3 point_cam, centroid_cam
    cdef float[:, ::1] new_centroids_cam = np.empty((n_clusters - n_fixed, 3), dtype=np.float32)
    cdef float[:, ::1] centroid_cam_sample_positions_total
    cdef int[::1] centroid_sample_counts

    last_total_error = 1e9
    for iteration in range(max_iterations):
        total_error = 0.0
        # Sized by n_clusters (not a fixed 16) so that any cluster count stays in bounds.
        centroid_cam_sample_positions_total = np.zeros((n_clusters, 3), dtype=np.float32)
        centroid_sample_counts = np.zeros(n_clusters, dtype=np.int32)

        # For each sample, associate it to the closest centroid.  We want to compute the mean of all associated samples
        # but we do this by accumulating the (coordinate vector) total and number of associated samples.
        #
        # Centroid positions are tracked in 4-bit //gs RGB colour space with distances measured in CAM16UCS colour
        # space.
        for point_idx in range(samples.shape[0]):
            for j in range(3):
                point_cam.data[j] = samples[point_idx, j]
            best_error = 1e9
            closest_centroid_idx = 0
            for centroid_idx in range(n_clusters):
                error = common.colour_distance_squared(
                    _convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx, :]).data,
                    point_cam.data)
                if error < best_error:
                    best_error = error
                    closest_centroid_idx = centroid_idx
            for i in range(3):
                centroid_cam_sample_positions_total[closest_centroid_idx, i] += point_cam.data[i]
            centroid_sample_counts[closest_centroid_idx] += 1
            total_error += best_error

        # Since the allowed centroid positions are discrete (and not uniformly spaced in CAM16UCS colour space), we
        # can't rely on measuring total centroid movement as a termination condition.  e.g. sometimes the nearest
        # available point to an intended next centroid position will increase the total distance, or centroids may
        # oscillate between two neighbouring positions.  Instead, we terminate when the total error stops decreasing.
        if total_error >= last_total_error:
            break
        last_total_error = total_error

        # Compute new centroid positions in CAM16UCS colour space
        for centroid_idx in range(n_fixed, n_clusters):
            if centroid_sample_counts[centroid_idx]:
                for i in range(3):
                    new_centroids_cam[centroid_idx - n_fixed, i] = (
                        centroid_cam_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
            else:
                # No samples were assigned: feed the centroid's current position through the batch conversion so the
                # row never contains uninitialised data.  The centroid itself is left untouched below.
                centroid_cam = _convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx, :])
                for i in range(3):
                    new_centroids_cam[centroid_idx - n_fixed, i] = centroid_cam.data[i]

        # Convert all new centroids back to //gs RGB colour space (done as a single matrix since
        # _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead)
        new_centroids_rgb12 = _convert_cam16ucs_to_rgb12_iigs(new_centroids_cam)

        # Update positions for non-fixed centroids; a centroid without samples has no mean to move to.
        for centroid_idx in range(n_fixed, n_clusters):
            if centroid_sample_counts[centroid_idx] == 0:
                continue
            for i in range(3):
                centroids_rgb12[centroid_idx, i] = new_centroids_rgb12[centroid_idx - n_fixed, i]

    return centroids_rgb12