Optimize calling _convert_cam16ucs_to_rgb12_iigs since it has

significant overhead
This commit is contained in:
kris 2021-11-18 21:50:39 +00:00
parent 3159a09c27
commit c608f6b961
2 changed files with 39 additions and 56 deletions

View File

@ -60,10 +60,8 @@ class ClusterPalette:
list(zip(*np.unique(labels, return_counts=True))),
key=lambda kv: kv[1], reverse=True)]
res = np.empty((16, 3), dtype=np.uint8)
for i in range(16):
res[i, :] = dither_pyx.convert_cam16ucs_to_rgb12_iigs(
clusters.cluster_centers_[frequency_order][i].astype(
res = dither_pyx.convert_cam16ucs_to_rgb12_iigs(
clusters.cluster_centers_[frequency_order].astype(
np.float32))
return res
@ -254,7 +252,6 @@ def main():
rgb, reserved_colours=1, rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs)
while iterations_since_improvement < iterations:
print(iterations_since_improvement)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
cluster_palette.propose_palettes())

View File

@ -502,23 +502,27 @@ import colour
@cython.boundscheck(False)
@cython.wraparound(False)
cdef float[::1] linear_to_srgb_array(float[::1] a, float gamma=2.4):
cdef int i
cdef float[::1] res = np.empty(3, dtype=np.float32)
for i in range(3):
if a[i] <= 0.0031308:
res[i] = a[i] * 12.92
else:
res[i] = 1.055 * a[i] ** (1.0 / gamma) - 0.055
cdef float[:, ::1] linear_to_srgb_array(float[:, ::1] a, float gamma=2.4):
cdef int i, j
cdef float[:, ::1] res = np.empty_like(a, dtype=np.float32)
for i in range(res.shape[0]):
for j in range(3):
if a[i, j] <= 0.0031308:
res[i, j] = a[i, j] * 12.92
else:
res[i, j] = 1.055 * a[i, j] ** (1.0 / gamma) - 0.055
return res
@cython.boundscheck(False)
@cython.wraparound(False)
cdef (unsigned char)[::1] _convert_cam16ucs_to_rgb12_iigs(float[::1] point_cam):
cdef float[::1] rgb, rgb12_iigs
cdef int i
cdef (unsigned char)[:, ::1] _convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
cdef float[:, ::1] rgb
cdef (float)[:, ::1] rgb12_iigs
# Convert CAM16UCS input to RGB
# TODO: this dynamically constructs a path on the graph of colour conversions every time, which is
# presumably not very efficient. However, colour.convert doesn't provide a way to cache the composed conversion
# function so we'd have to build it ourselves (https://github.com/colour-science/colour/issues/905)
with colour.utilities.suppress_warnings(python_warnings=True):
rgb = colour.convert(point_cam, "CAM16UCS", "RGB").astype(np.float32)
@ -528,17 +532,14 @@ cdef (unsigned char)[::1] _convert_cam16ucs_to_rgb12_iigs(float[::1] point_cam):
# Gamma correct and convert Rec.709 R'G'B' to YCbCr
colour.RGB_to_YCbCr(
linear_to_srgb_array(rgb), K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1).astype(np.float32)
for i in range(3):
rgb12_iigs[i] *= 15
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1).astype(np.float32) * 15
return np.round(rgb12_iigs).astype(np.uint8)
def convert_cam16ucs_to_rgb12_iigs(float[::1] point_cam):
def convert_cam16ucs_to_rgb12_iigs(float[:, ::1] point_cam):
return _convert_cam16ucs_to_rgb12_iigs(point_cam)
@cython.boundscheck(False)
@cython.wraparound(False)
def k_means_with_fixed_centroids(
@ -549,9 +550,10 @@ def k_means_with_fixed_centroids(
cdef int centroid_idx, closest_centroid_idx, i, point_idx
cdef (unsigned char)[:, ::1] centroids_rgb12 = initial_centroids[:, :]
cdef (unsigned char)[::1] centroid_rgb12, new_centroid_rgb12
cdef (unsigned char)[:, ::1] new_centroids_rgb12
cdef float[::1] point_cam, new_centroid_cam = np.empty(3, dtype=np.float32)
cdef float[::1] point_cam
cdef float[:, ::1] new_centroids_cam = np.empty((n_clusters - n_fixed, 3), dtype=np.float32)
cdef float[:, ::1] centroid_cam_sample_positions_total
cdef int[::1] centroid_sample_counts
@ -574,8 +576,8 @@ def k_means_with_fixed_centroids(
best_error = 1e9
closest_centroid_idx = 0
for centroid_idx in range(n_clusters):
centroid_rgb12 = centroids_rgb12[centroid_idx, :]
error = colour_distance_squared(_convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroid_rgb12), point_cam)
error = colour_distance_squared(
_convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx, :]), point_cam)
if error < best_error:
best_error = error
closest_centroid_idx = centroid_idx
@ -587,16 +589,21 @@ def k_means_with_fixed_centroids(
for centroid_idx in range(n_fixed, n_clusters):
if centroid_sample_counts[centroid_idx]:
for i in range(3):
new_centroid_cam[i] = (
new_centroids_cam[centroid_idx - n_fixed, i] = (
centroid_cam_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
centroid_movement += colour_distance_squared(
_convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx]), new_centroid_cam)
new_centroid_rgb12 = _convert_cam16ucs_to_rgb12_iigs(new_centroid_cam)
for i in range(3):
if centroids_rgb12[centroid_idx, i] != new_centroid_rgb12[i]:
# print(i, centroids_rgb12[centroid_idx, i], new_centroid_rgb12[i])
centroids_rgb12[centroid_idx, i] = new_centroid_rgb12[i]
centroid_moved = 1
_convert_rgb12_iigs_to_cam(
rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx]),
new_centroids_cam[centroid_idx - n_fixed, :])
# Convert all new centroids as a single matrix since _convert_cam16ucs_to_rgb12_iigs has nontrivial overhead
new_centroids_rgb12 = _convert_cam16ucs_to_rgb12_iigs(new_centroids_cam)
for centroid_idx in range(n_clusters - n_fixed):
for i in range(3):
if centroids_rgb12[centroid_idx + n_fixed, i] != new_centroids_rgb12[centroid_idx, i]:
centroids_rgb12[centroid_idx + n_fixed, i] = new_centroids_rgb12[centroid_idx, i]
centroid_moved = 1
# print(iteration, centroid_movement, total_error, centroids_rgb12)
@ -605,25 +612,4 @@ def k_means_with_fixed_centroids(
if centroid_moved == 0:
break
return centroids_rgb12, total_error
#@cython.boundscheck(False)
#@cython.wraparound(False)
#cdef float[::1] closest_quantized_point(float [:, ::1] rgb24_to_cam, float [::1] point_cam) nogil:
# cdef unsigned int rgb12, rgb24, closest_rgb24, r, g, b
# cdef double best_distance = 1e9, distance
# for rgb12 in range(2**12):
# r = rgb12 >> 8
# g = (rgb12 >> 4) & 0xf
# b = rgb12 & 0xf
# rgb24 = (r << 20) | (r << 16) | (g << 12) | (g << 8) | (b << 4) | b
# distance = colour_distance_squared(rgb24_to_cam[rgb24], point_cam)
# # print(hex(rgb24), distance)
# if distance < best_distance:
# best_distance = distance
# closest_rgb24 = rgb24
# # print(distance, rgb24, hex(rgb24))
# # print("-->", closest_rgb24, hex(closest_rgb24), best_distance)
# return rgb24_to_cam[closest_rgb24]
return centroids_rgb12, total_error