Split out common utility functions into a shared module

This commit is contained in:
kris 2021-11-26 12:26:46 +00:00
parent 0dc2c0a7a0
commit ae89682dab
5 changed files with 82 additions and 78 deletions

5
common.pxd Normal file
View File

@ -0,0 +1,5 @@
cdef float clip(float a, float min_value, float max_value) nogil
cdef float[::1] convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, float r, float g, float b) nogil
cdef double colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil

18
common.pyx Normal file
View File

@ -0,0 +1,18 @@
# cython: infer_types=True
# cython: profile=False
# cython: boundscheck=False
# cython: wraparound=False
cdef float clip(float a, float min_value, float max_value) nogil:
return min(max(a, min_value), max_value)
cdef inline float[::1] convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, float r, float g, float b) nogil:
cdef unsigned int rgb_24bit = (<unsigned int>(r*255) << 16) + (<unsigned int>(g*255) << 8) + <unsigned int>(b*255)
return rgb_to_cam16ucs[rgb_24bit]
cdef inline double colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil:
return (colour1[0] - colour2[0]) ** 2 + (colour1[1] - colour2[1]) ** 2 + (colour1[2] - colour2[2]) ** 2

View File

@ -7,6 +7,8 @@ cimport cython
import numpy as np
from libc.stdlib cimport malloc, free
cimport common
# TODO: use a cdef class
# C representation of dither_pattern.DitherPattern data, for efficient access.
@ -19,10 +21,6 @@ cdef struct Dither:
int y_origin
cdef float clip(float a, float min_value, float max_value) nogil:
return min(max(a, min_value), max_value)
# Compute left-hand bounding box for dithering at horizontal position x.
cdef int dither_bounds_xl(Dither *dither, int x) nogil:
cdef int el = max(dither.x_origin - x, 0)
@ -140,10 +138,10 @@ cdef int dither_lookahead(Dither* dither, float[:, :, ::1] palette_cam16, float[
quant_error[j] = lah_image_rgb[i * lah_shape2 + j] - palette_rgb[next_pixels, phase, j]
apply_one_line(dither, xl, xr, i, lah_image_rgb, lah_shape2, quant_error)
lah_cam16ucs = convert_rgb_to_cam16ucs(
lah_cam16ucs = common.convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, lah_image_rgb[i*lah_shape2], lah_image_rgb[i*lah_shape2+1],
lah_image_rgb[i*lah_shape2+2])
total_error += colour_distance_squared(lah_cam16ucs, palette_cam16[next_pixels, phase])
total_error += common.colour_distance_squared(lah_cam16ucs, palette_cam16[next_pixels, phase])
if total_error >= best_error:
# No need to continue
@ -157,19 +155,6 @@ cdef int dither_lookahead(Dither* dither, float[:, :, ::1] palette_cam16, float[
return best
cdef inline float[::1] convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, float r, float g, float b) nogil:
cdef unsigned int rgb_24bit = (<unsigned int>(r*255) << 16) + (<unsigned int>(g*255) << 8) + <unsigned int>(b*255)
return rgb_to_cam16ucs[rgb_24bit]
cdef inline float fabs(float value) nogil:
return -value if value < 0 else value
cdef inline double colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil:
return (colour1[0] - colour2[0]) ** 2 + (colour1[1] - colour2[1]) ** 2 + (colour1[2] - colour2[2]) ** 2
# Perform error diffusion to a single image row.
#
# Args:
@ -190,7 +175,7 @@ cdef void apply_one_line(Dither* dither, int xl, int xr, int x, float[] image, i
for i in range(xl, xr):
error_fraction = dither.pattern[i - x + dither.x_origin]
for j in range(3):
image[i * image_shape1 + j] = clip(image[i * image_shape1 + j] + error_fraction * quant_error[j], 0, 1)
image[i * image_shape1 + j] = common.clip(image[i * image_shape1 + j] + error_fraction * quant_error[j], 0, 1)
# Perform error diffusion across multiple image rows.
@ -218,7 +203,7 @@ cdef void apply(Dither* dither, int x_res, int y_res, int x, int y, float[:,:,::
for j in range(xl, xr):
error_fraction = dither.pattern[(i - y) * dither.x_shape + j - x + dither.x_origin]
for k in range(3):
image[i,j,k] = clip(image[i,j,k] + error_fraction * quant_error[k], 0, 1)
image[i,j,k] = common.clip(image[i,j,k] + error_fraction * quant_error[k], 0, 1)
cdef image_nbit_to_bitmap(

View File

@ -4,20 +4,10 @@
# cython: wraparound=False
cimport cython
import colour
import numpy as np
# TODO: move these into a common module
cdef float clip(float a, float min_value, float max_value) nogil:
return min(max(a, min_value), max_value)
cdef inline float[::1] convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, float r, float g, float b) nogil:
cdef unsigned int rgb_24bit = (<unsigned int>(r*255) << 16) + (<unsigned int>(g*255) << 8) + <unsigned int>(b*255)
return rgb_to_cam16ucs[rgb_24bit]
cdef inline double colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil:
return (colour1[0] - colour2[0]) ** 2 + (colour1[1] - colour2[1]) ** 2 + (colour1[2] - colour2[2]) ** 2
cimport common
def dither_shr_perfect(
@ -40,17 +30,17 @@ def dither_shr_perfect(
total_image_error = 0.0
for y in range(200):
for x in range(320):
line_cam[x, :] = convert_rgb_to_cam16ucs(
line_cam[x, :] = common.convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
for x in range(320):
pixel_cam = convert_rgb_to_cam16ucs(
pixel_cam = common.convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])
best_distance = 1e9
best_colour_idx = -1
for idx in range(palette_size):
distance = colour_distance_squared(pixel_cam, full_palette_cam[idx, :])
distance = common.colour_distance_squared(pixel_cam, full_palette_cam[idx, :])
if distance < best_distance:
best_distance = distance
best_colour_idx = idx
@ -66,16 +56,16 @@ def dither_shr_perfect(
# 0 * 7
# 3 5 1
if x < 319:
working_image[y, x + 1, i] = clip(
working_image[y, x + 1, i] = common.clip(
working_image[y, x + 1, i] + quant_error * (7 / 16), 0, 1)
if y < 199:
if x > 0:
working_image[y + 1, x - 1, i] = clip(
working_image[y + 1, x - 1, i] = common.clip(
working_image[y + 1, x - 1, i] + decay * quant_error * (3 / 16), 0, 1)
working_image[y + 1, x, i] = clip(
working_image[y + 1, x, i] = common.clip(
working_image[y + 1, x, i] + decay * quant_error * (5 / 16), 0, 1)
if x < 319:
working_image[y + 1, x + 1, i] = clip(
working_image[y + 1, x + 1, i] = common.clip(
working_image[y + 1, x + 1, i] + decay * quant_error * (1 / 16), 0, 1)
else:
# Jarvis
@ -83,47 +73,47 @@ def dither_shr_perfect(
# 3 5 7 5 3
# 1 3 5 3 1
if x < 319:
working_image[y, x + 1, i] = clip(
working_image[y, x + 1, i] = common.clip(
working_image[y, x + 1, i] + quant_error * (7 / 48), 0, 1)
if x < 318:
working_image[y, x + 2, i] = clip(
working_image[y, x + 2, i] = common.clip(
working_image[y, x + 2, i] + quant_error * (5 / 48), 0, 1)
if y < 199:
if x > 1:
working_image[y + 1, x - 2, i] = clip(
working_image[y + 1, x - 2, i] = common.clip(
working_image[y + 1, x - 2, i] + decay * quant_error * (3 / 48), 0,
1)
if x > 0:
working_image[y + 1, x - 1, i] = clip(
working_image[y + 1, x - 1, i] = common.clip(
working_image[y + 1, x - 1, i] + decay * quant_error * (5 / 48), 0,
1)
working_image[y + 1, x, i] = clip(
working_image[y + 1, x, i] = common.clip(
working_image[y + 1, x, i] + decay * quant_error * (7 / 48), 0, 1)
if x < 319:
working_image[y + 1, x + 1, i] = clip(
working_image[y + 1, x + 1, i] = common.clip(
working_image[y + 1, x + 1, i] + decay * quant_error * (5 / 48),
0, 1)
if x < 318:
working_image[y + 1, x + 2, i] = clip(
working_image[y + 1, x + 2, i] = common.clip(
working_image[y + 1, x + 2, i] + decay * quant_error * (3 / 48),
0, 1)
if y < 198:
if x > 1:
working_image[y + 2, x - 2, i] = clip(
working_image[y + 2, x - 2, i] = common.clip(
working_image[y + 2, x - 2, i] + decay * decay * quant_error * (1 / 48), 0,
1)
if x > 0:
working_image[y + 2, x - 1, i] = clip(
working_image[y + 2, x - 1, i] = common.clip(
working_image[y + 2, x - 1, i] + decay * decay * quant_error * (3 / 48), 0,
1)
working_image[y + 2, x, i] = clip(
working_image[y + 2, x, i] = common.clip(
working_image[y + 2, x, i] + decay * decay * quant_error * (5 / 48), 0, 1)
if x < 319:
working_image[y + 2, x + 1, i] = clip(
working_image[y + 2, x + 1, i] = common.clip(
working_image[y + 2, x + 1, i] + decay * decay * quant_error * (3 / 48),
0, 1)
if x < 318:
working_image[y + 2, x + 2, i] = clip(
working_image[y + 2, x + 2, i] = common.clip(
working_image[y + 2, x + 2, i] + decay * decay * quant_error * (1 / 48),
0, 1)
@ -154,7 +144,7 @@ def dither_shr(
total_image_error = 0.0
for y in range(200):
for x in range(320):
line_cam[x, :] = convert_rgb_to_cam16ucs(
line_cam[x, :] = common.convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
palette_line = best_palette_for_line(line_cam, palettes_cam, best_palette)
@ -166,13 +156,13 @@ def dither_shr(
line_to_palette[y] = best_palette
for x in range(320):
pixel_cam = convert_rgb_to_cam16ucs(
pixel_cam = common.convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])
best_distance = 1e9
best_colour_idx = -1
for idx in range(16):
distance = colour_distance_squared(pixel_cam, palette_cam[idx, :])
distance = common.colour_distance_squared(pixel_cam, palette_cam[idx, :])
if distance < best_distance:
best_distance = distance
best_colour_idx = idx
@ -189,16 +179,16 @@ def dither_shr(
# 0 * 7
# 3 5 1
if x < 319:
working_image[y, x + 1, i] = clip(
working_image[y, x + 1, i] = common.clip(
working_image[y, x + 1, i] + quant_error * (7 / 16), 0, 1)
if y < 199:
if x > 0:
working_image[y + 1, x - 1, i] = clip(
working_image[y + 1, x - 1, i] = common.clip(
working_image[y + 1, x - 1, i] + decay * quant_error * (3 / 16), 0, 1)
working_image[y + 1, x, i] = clip(
working_image[y + 1, x, i] = common.clip(
working_image[y + 1, x, i] + decay * quant_error * (5 / 16), 0, 1)
if x < 319:
working_image[y + 1, x + 1, i] = clip(
working_image[y + 1, x + 1, i] = common.clip(
working_image[y + 1, x + 1, i] + decay * quant_error * (1 / 16), 0, 1)
else:
# Jarvis
@ -206,51 +196,54 @@ def dither_shr(
# 3 5 7 5 3
# 1 3 5 3 1
if x < 319:
working_image[y, x + 1, i] = clip(
working_image[y, x + 1, i] = common.clip(
working_image[y, x + 1, i] + quant_error * (7 / 48), 0, 1)
if x < 318:
working_image[y, x + 2, i] = clip(
working_image[y, x + 2, i] = common.clip(
working_image[y, x + 2, i] + quant_error * (5 / 48), 0, 1)
if y < 199:
if x > 1:
working_image[y + 1, x - 2, i] = clip(
working_image[y + 1, x - 2, i] = common.clip(
working_image[y + 1, x - 2, i] + decay * quant_error * (3 / 48), 0,
1)
if x > 0:
working_image[y + 1, x - 1, i] = clip(
working_image[y + 1, x - 1, i] = common.clip(
working_image[y + 1, x - 1, i] + decay * quant_error * (5 / 48), 0,
1)
working_image[y + 1, x, i] = clip(
working_image[y + 1, x, i] = common.clip(
working_image[y + 1, x, i] + decay * quant_error * (7 / 48), 0, 1)
if x < 319:
working_image[y + 1, x + 1, i] = clip(
working_image[y + 1, x + 1, i] = common.clip(
working_image[y + 1, x + 1, i] + decay * quant_error * (5 / 48),
0, 1)
if x < 318:
working_image[y + 1, x + 2, i] = clip(
working_image[y + 1, x + 2, i] = common.clip(
working_image[y + 1, x + 2, i] + decay * quant_error * (3 / 48),
0, 1)
if y < 198:
if x > 1:
working_image[y + 2, x - 2, i] = clip(
working_image[y + 2, x - 2, i] = common.clip(
working_image[y + 2, x - 2, i] + decay * decay * quant_error * (1 / 48), 0,
1)
if x > 0:
working_image[y + 2, x - 1, i] = clip(
working_image[y + 2, x - 1, i] = common.clip(
working_image[y + 2, x - 1, i] + decay * decay * quant_error * (3 / 48), 0,
1)
working_image[y + 2, x, i] = clip(
working_image[y + 2, x, i] = common.clip(
working_image[y + 2, x, i] + decay * decay * quant_error * (5 / 48), 0, 1)
if x < 319:
working_image[y + 2, x + 1, i] = clip(
working_image[y + 2, x + 1, i] = common.clip(
working_image[y + 2, x + 1, i] + decay * decay * quant_error * (3 / 48),
0, 1)
if x < 318:
working_image[y + 2, x + 2, i] = clip(
working_image[y + 2, x + 2, i] = common.clip(
working_image[y + 2, x + 2, i] + decay * decay * quant_error * (1 / 48),
0, 1)
return np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error, np.array(palette_line_errors, dtype=np.float64)
return (
np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error,
np.array(palette_line_errors, dtype=np.float64)
)
cdef struct PaletteSelection:
@ -258,7 +251,8 @@ cdef struct PaletteSelection:
double total_error
cdef PaletteSelection best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int last_palette_idx) nogil:
cdef PaletteSelection best_palette_for_line(
float [:, ::1] line_cam, float[:, :, ::1] palettes_cam, int last_palette_idx) nogil:
cdef int palette_idx, best_palette_idx, palette_entry_idx, pixel_idx
cdef double best_total_dist, total_dist, best_pixel_dist, pixel_dist
cdef float[:, ::1] palette_cam
@ -274,7 +268,7 @@ cdef PaletteSelection best_palette_for_line(float [:, ::1] line_cam, float[:, :,
pixel_cam = line_cam[pixel_idx]
best_pixel_dist = 1e9
for palette_entry_idx in range(16):
pixel_dist = colour_distance_squared(pixel_cam, palette_cam[palette_entry_idx, :])
pixel_dist = common.colour_distance_squared(pixel_cam, palette_cam[palette_entry_idx, :])
if pixel_dist < best_pixel_dist:
best_pixel_dist = pixel_dist
total_dist += best_pixel_dist
@ -293,10 +287,10 @@ cdef float[::1] _convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs
return rgb12_iigs_to_cam16ucs[rgb12]
# Wrapper around _convert_rgb12_iigs_to_cam to allow calling from python while retaining fast path for cython calls.
def convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs, (unsigned char)[::1] point_rgb12) -> float[::1]:
return _convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, point_rgb12)
import colour
@cython.cdivision(True)
cdef float[:, ::1] linear_to_srgb_array(float[:, ::1] a, float gamma=2.4):
@ -310,6 +304,7 @@ cdef float[:, ::1] linear_to_srgb_array(float[:, ::1] a, float gamma=2.4):
res[i, j] = 1.055 * a[i, j] ** (1.0 / gamma) - 0.055
return res
cdef (unsigned char)[:, ::1] _convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
cdef float[:, ::1] rgb
cdef (float)[:, ::1] rgb12_iigs
@ -329,7 +324,8 @@ cdef (unsigned char)[:, ::1] _convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1).astype(np.float32) * 15
return np.round(rgb12_iigs).astype(np.uint8)
# Wrapper around _convert_cam16ucs_to_rgb12_iigs to allow calling from python while retaining fast path for cython
# calls.
def convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
return _convert_cam16ucs_to_rgb12_iigs(point_cam)
@ -366,7 +362,7 @@ def k_means_with_fixed_centroids(
best_error = 1e9
closest_centroid_idx = 0
for centroid_idx in range(n_clusters):
error = colour_distance_squared(
error = common.colour_distance_squared(
_convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx, :]), point_cam)
if error < best_error:
best_error = error

View File

@ -6,7 +6,7 @@ Cython.Compiler.Options.annotate = True
setup(
ext_modules=cythonize(
["dither_dhr.pyx", "dither_shr.pyx"],
["common.pyx", "dither_dhr.pyx", "dither_shr.pyx"],
annotate=True,
compiler_directives={'language_level': "3"}
)