Optimize dither_dhr.dither_image performance by about 2x (#12)

- avoid passing a float[::1] memoryview across function boundaries; this seems to require reference counting, which has a large overhead
- inline some functions
- C division
- float instead of double
This commit is contained in:
KrisKennaway 2023-02-25 21:21:43 +00:00 committed by GitHub
parent 7a4e27e0da
commit f8fbd768a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 82 additions and 33 deletions

View File

@ -1,5 +1,10 @@
cdef float clip(float a, float min_value, float max_value) nogil
cdef float[::1] convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, float r, float g, float b) nogil
# float3 is a fixed-size struct wrapping three floats. It is used to avoid passing around float[::1]
# memoryviews in the critical path; those seem to require reference counting, which has a large
# performance overhead.
cdef packed struct float3:
# The three float components.
float[3] data
cdef double colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil
cdef float3 convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, float r, float g, float b) nogil
cdef float colour_distance_squared(float[3] colour1, float[3] colour2) nogil

View File

@ -4,20 +4,28 @@
# cython: wraparound=False
# Clamp helper for a single float; usable without the GIL (nogil).
cdef float clip(float a, float min_value, float max_value) nogil:
cdef inline float clip(float a, float min_value, float max_value) nogil:
"""Clip a value between min_value and max_value inclusive."""
# max() enforces the lower bound first, then min() enforces the upper bound.
return min(max(a, min_value), max_value)
# Look up the CAM16UCS value for an (r, g, b) colour in the rgb_to_cam16ucs table.
cdef inline float[::1] convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, float r, float g, float b) nogil:
cdef inline float3 convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, float r, float g, float b) nogil:
"""Converts floating point (r,g,b) value to 3-tuple in CAM16UCS colour space, via 24-bit RGB lookup matrix."""
# Each channel is scaled by 255 and truncated by the unsigned int cast, then packed as
# (r << 16) + (g << 8) + b to form the row index into the lookup table.
# NOTE(review): assumes r, g and b lie in [0, 1] so each scaled channel fits in 8 bits -- confirm at callers.
cdef unsigned int rgb_24bit = (<unsigned int>(r*255) << 16) + (<unsigned int>(g*255) << 8) + <unsigned int>(b*255)
return rgb_to_cam16ucs[rgb_24bit]
cdef float3 res
cdef int i
# Copy the three components of the looked-up row into the struct, which is returned by value.
for i in range(3):
res.data[i] = rgb_to_cam16ucs[rgb_24bit][i]
return res
# Squared distance between two 3-component colours; no square root is taken.
cdef inline double colour_distance_squared(float[::1] colour1, float[::1] colour2) nogil:
cdef inline float colour_distance_squared(float[3] colour1, float[3] colour2) nogil:
"""Computes Euclidean squared distance between two floating-point colour 3-tuples."""
return (colour1[0] - colour2[0]) ** 2 + (colour1[1] - colour2[1]) ** 2 + (colour1[2] - colour2[2]) ** 2
# Each squared term is written as an explicit product of the component difference with itself.
return (
(colour1[0] - colour2[0]) * (colour1[0] - colour2[0]) +
(colour1[1] - colour2[1]) * (colour1[1] - colour2[1]) +
(colour1[2] - colour2[2]) * (colour1[2] - colour2[2])
)

View File

@ -24,21 +24,21 @@ cdef struct Dither:
# Compute left-hand bounding box for dithering at horizontal position x.
#
# Returns x - dither.x_origin clamped below at 0, i.e. max(x - dither.x_origin, 0).
cdef int dither_bounds_xl(Dither *dither, int x) nogil:
cdef inline int dither_bounds_xl(Dither *dither, int x) nogil:
# el is how far x - dither.x_origin falls below 0 (0 when it does not).
cdef int el = max(dither.x_origin - x, 0)
# Adding el back cancels any negative part, so xl is never below 0.
cdef int xl = x - dither.x_origin + el
return xl
# Compute right-hand bounding box for dithering at horizontal position x.
#
# Returns min(x - dither.x_origin + dither.x_shape, x_res - dither.x_origin).
cdef int dither_bounds_xr(Dither *dither, int x_res, int x) nogil:
cdef inline int dither_bounds_xr(Dither *dither, int x_res, int x) nogil:
# er is dither.x_shape, capped at the distance from x to x_res.
# NOTE(review): presumably x_shape is the horizontal extent of the dither pattern and x_res the
# image width -- confirm against the Dither struct definition and the callers.
cdef int er = min(dither.x_shape, x_res - x)
cdef int xr = x - dither.x_origin + er
return xr
# Compute upper bounding box for dithering at vertical position y.
cdef int dither_bounds_yt(Dither *dither, int y) nogil:
cdef inline int dither_bounds_yt(Dither *dither, int y) nogil:
cdef int et = max(dither.y_origin - y, 0)
cdef int yt = y - dither.y_origin + et
@ -46,7 +46,7 @@ cdef int dither_bounds_yt(Dither *dither, int y) nogil:
# Compute lower bounding box for dithering at vertical position y.
#
# Returns min(y - dither.y_origin + dither.y_shape, y_res - dither.y_origin).
cdef int dither_bounds_yb(Dither *dither, int y_res, int y) nogil:
cdef inline int dither_bounds_yb(Dither *dither, int y_res, int y) nogil:
# eb is dither.y_shape, capped at the distance from y to y_res.
# NOTE(review): presumably y_shape is the vertical extent of the dither pattern and y_res the
# image height -- confirm against the Dither struct definition and the callers.
cdef int eb = min(dither.y_shape, y_res - y)
cdef int yb = y - dither.y_origin + eb
return yb
@ -128,6 +128,7 @@ cdef struct Context:
#
# Returns: index from 0 .. 2**lookahead into options_nbit representing best available choice for position (x,y)
#
@cython.cdivision(True)
cdef int dither_lookahead(Dither* dither, unsigned char palette_depth, float[:, :, ::1] palette_cam16,
float[:, :, ::1] palette_rgb, float[:, :, ::1] image_rgb, int x, int y, unsigned char last_pixels,
int x_res, float[:,::1] rgb_to_cam16ucs, Context context) nogil:
@ -138,13 +139,15 @@ cdef int dither_lookahead(Dither* dither, unsigned char palette_depth, float[:,
cdef float total_error
cdef unsigned char current_pixels
cdef int phase
cdef float[::1] lah_cam16ucs
cdef common.float3 lah_cam16ucs
cdef float[3] cam
# Don't bother dithering past the lookahead horizon or edge of screen.
cdef int xxr = min(x + context.pixel_lookahead, x_res)
cdef int lah_shape1 = xxr - x
cdef int lah_shape2 = 3
# TODO: try again with memoryview - does it actually have overhead here?
cdef float *lah_image_rgb = <float *> malloc(lah_shape1 * lah_shape2 * sizeof(float))
# For each 2**lookahead possibilities for the on/off state of the next lookahead pixels, apply error diffusion
@ -184,10 +187,13 @@ cdef int dither_lookahead(Dither* dither, unsigned char palette_depth, float[:,
quant_error[j] = lah_image_rgb[i * lah_shape2 + j] - palette_rgb[current_pixels, phase, j]
apply_one_line(dither, xl, xr, i, lah_image_rgb, lah_shape2, quant_error)
# Accumulate error distance from pixel colour to target colour in CAM16UCS colour space
lah_cam16ucs = common.convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, lah_image_rgb[i*lah_shape2], lah_image_rgb[i*lah_shape2+1],
lah_image_rgb[i*lah_shape2+2])
total_error += common.colour_distance_squared(lah_cam16ucs, palette_cam16[current_pixels, phase])
for j in range(3):
cam[j] = palette_cam16[current_pixels, phase, j]
total_error += common.colour_distance_squared(lah_cam16ucs.data, cam)
if total_error >= best_error:
# No need to continue
@ -212,7 +218,7 @@ cdef int dither_lookahead(Dither* dither, unsigned char palette_depth, float[:,
# image_shape1: horizontal dimension of image
# quant_error: RGB quantization error to be diffused
#
cdef void apply_one_line(Dither* dither, int xl, int xr, int x, float[] image, int image_shape1,
cdef inline void apply_one_line(Dither* dither, int xl, int xr, int x, float[] image, int image_shape1,
float[] quant_error) nogil:
cdef int i, j
@ -274,8 +280,9 @@ cdef image_nbit_to_bitmap(
#
# Returns: tuple of n-bit output image array and RGB output image array
#
@cython.cdivision(True)
def dither_image(
screen, float[:, :, ::1] image_rgb, dither, int lookahead, unsigned char verbose, float[:,::1] rgb_to_cam16ucs):
screen, float[:, :, ::1] image_rgb, dither, int lookahead, unsigned char verbose, float[:, ::1] rgb_to_cam16ucs):
cdef int y, x
cdef unsigned char i, j, pixels_nbit, phase
cdef float[3] quant_error

View File

@ -13,9 +13,9 @@ cimport common
def dither_shr_perfect(
float[:, :, ::1] input_rgb, float[:, ::1] full_palette_cam, float[:, ::1] full_palette_rgb,
float[:,::1] rgb_to_cam16ucs):
cdef int y, x, idx, best_colour_idx, i
cdef int y, x, idx, best_colour_idx, i, j
cdef double best_distance, distance, total_image_error
cdef float[::1] best_colour_rgb, pixel_cam
cdef float[::1] best_colour_rgb
cdef float quant_error
cdef float[:, ::1] palette_rgb, palette_cam
@ -27,11 +27,15 @@ def dither_shr_perfect(
cdef float decay = 0.5
cdef int floyd_steinberg = 1
cdef common.float3 cam, pixel_cam
total_image_error = 0.0
for y in range(200):
for x in range(320):
line_cam[x, :] = common.convert_rgb_to_cam16ucs(
cam = common.convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
for j in range(3):
line_cam[x, j] = cam.data[j]
for x in range(320):
pixel_cam = common.convert_rgb_to_cam16ucs(
@ -40,7 +44,9 @@ def dither_shr_perfect(
best_distance = 1e9
best_colour_idx = -1
for idx in range(palette_size):
distance = common.colour_distance_squared(pixel_cam, full_palette_cam[idx, :])
for j in range(3):
cam.data[j] = full_palette_cam[idx,j]
distance = common.colour_distance_squared(pixel_cam.data, cam.data)
if distance < best_distance:
best_distance = distance
best_colour_idx = idx
@ -123,9 +129,9 @@ def dither_shr_perfect(
def dither_shr(
float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb,
float[:,::1] rgb_to_cam16ucs):
cdef int y, x, idx, best_colour_idx, best_palette, i
cdef int y, x, idx, best_colour_idx, best_palette, i, j
cdef double best_distance, distance, total_image_error
cdef float[::1] best_colour_rgb, pixel_cam
cdef float[::1] best_colour_rgb
cdef float quant_error
cdef float[:, ::1] palette_rgb, palette_cam
@ -140,12 +146,16 @@ def dither_shr(
cdef float decay = 0.5
cdef int floyd_steinberg = 1
cdef common.float3 pixel_cam, cam
best_palette = -1
total_image_error = 0.0
for y in range(200):
for x in range(320):
line_cam[x, :] = common.convert_rgb_to_cam16ucs(
pixel_cam = common.convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
for j in range(3):
line_cam[x, j] = pixel_cam.data[j]
palette_line = best_palette_for_line(line_cam, palettes_cam, best_palette)
best_palette = palette_line.palette_idx
@ -162,7 +172,9 @@ def dither_shr(
best_distance = 1e9
best_colour_idx = -1
for idx in range(16):
distance = common.colour_distance_squared(pixel_cam, palette_cam[idx, :])
for j in range(3):
cam.data[j] = palette_cam[idx, j]
distance = common.colour_distance_squared(pixel_cam.data, cam.data)
if distance < best_distance:
best_distance = distance
best_colour_idx = idx
@ -256,7 +268,8 @@ cdef PaletteSelection best_palette_for_line(
cdef int palette_idx, best_palette_idx, palette_entry_idx, pixel_idx
cdef double best_total_dist, total_dist, best_pixel_dist, pixel_dist
cdef float[:, ::1] palette_cam
cdef float[::1] pixel_cam
cdef common.float3 pixel_cam, cam
cdef int j
best_total_dist = 1e9
best_palette_idx = -1
@ -265,10 +278,13 @@ cdef PaletteSelection best_palette_for_line(
palette_cam = palettes_cam[palette_idx, :, :]
total_dist = 0
for pixel_idx in range(line_size):
pixel_cam = line_cam[pixel_idx]
for j in range(3):
pixel_cam.data[j] = line_cam[pixel_idx, j]
best_pixel_dist = 1e9
for palette_entry_idx in range(16):
pixel_dist = common.colour_distance_squared(pixel_cam, palette_cam[palette_entry_idx, :])
for j in range(3):
cam.data[j] = palette_cam[palette_entry_idx, j]
pixel_dist = common.colour_distance_squared(pixel_cam.data, cam.data)
if pixel_dist < best_pixel_dist:
best_pixel_dist = pixel_dist
total_dist += best_pixel_dist
@ -282,14 +298,24 @@ cdef PaletteSelection best_palette_for_line(
return res
# Look up the row of rgb12_iigs_to_cam16ucs selected by the three components of point_rgb12.
cdef float[::1] _convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs, (unsigned char)[::1] point_rgb12) nogil:
cdef common.float3 _convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs, (unsigned char)[::1] point_rgb12) nogil:
# Pack the three components into one table index, using shifts of 8, 4 and 0 bits.
# NOTE(review): assumes each component is a 4-bit value (0..15) so the fields do not overlap -- confirm.
cdef int rgb12 = (point_rgb12[0] << 8) | (point_rgb12[1] << 4) | point_rgb12[2]
return rgb12_iigs_to_cam16ucs[rgb12]
cdef int i
cdef common.float3 res
# Copy the three components of the looked-up row into the struct, which is returned by value.
for i in range(3):
res.data[i] = rgb12_iigs_to_cam16ucs[rgb12, i]
return res
# Wrapper around _convert_rgb12_iigs_to_cam to allow calling from Python while retaining the fast path for
# Cython calls.
def convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs, (unsigned char)[::1] point_rgb12) -> float[::1]:
return _convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, point_rgb12)
cdef common.float3 cam = _convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, point_rgb12)
cdef int i
# (3) is just the integer 3, so this allocates a 1-D float32 array of length 3 to hold the result.
cdef float[::1] res = np.empty((3), dtype=np.float32)
# Copy the struct's components into the array-backed memoryview that is returned to the caller.
for i in range(3):
res[i] = cam.data[i]
return res
@cython.cdivision(True)
@ -305,6 +331,7 @@ cdef float[:, ::1] linear_to_srgb_array(float[:, ::1] a, float gamma=2.4):
return res
# TODO: optimize
cdef (unsigned char)[:, ::1] _convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
cdef float[:, ::1] rgb
cdef (float)[:, ::1] rgb12_iigs
@ -343,7 +370,7 @@ def k_means_with_fixed_centroids(
cdef (unsigned char)[:, ::1] centroids_rgb12 = np.copy(initial_centroids)
cdef (unsigned char)[:, ::1] new_centroids_rgb12
cdef float[::1] point_cam
cdef common.float3 point_cam
cdef float[:, ::1] new_centroids_cam = np.empty((n_clusters - n_fixed, 3), dtype=np.float32)
cdef float[:, ::1] centroid_cam_sample_positions_total
cdef int[::1] centroid_sample_counts
@ -360,17 +387,19 @@ def k_means_with_fixed_centroids(
# Centroid positions are tracked in 4-bit //gs RGB colour space with distances measured in CAM16UCS colour
# space.
for point_idx in range(samples.shape[0]):
point_cam = samples[point_idx, :]
for j in range(3):
point_cam.data[j] = samples[point_idx, j]
best_error = 1e9
closest_centroid_idx = 0
for centroid_idx in range(n_clusters):
error = common.colour_distance_squared(
_convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx, :]), point_cam)
_convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx, :]).data,
point_cam.data)
if error < best_error:
best_error = error
closest_centroid_idx = centroid_idx
for i in range(3):
centroid_cam_sample_positions_total[closest_centroid_idx, i] += point_cam[i]
centroid_cam_sample_positions_total[closest_centroid_idx, i] += point_cam.data[i]
centroid_sample_counts[closest_centroid_idx] += 1
total_error += best_error