Working version! Quantize the k-means centroids in 12-bit //gs RGB

space but continue to use CAM16-UCS for distances and updating
centroid positions, before mapping back to the nearest legal 12-bit
RGB position.

Needs some more work to deal with the fact that now that there are
discrete distances (but no fixed minimum) between allowed centroid
positions, the previous notion of convergence doesn't apply.  Actually
the centroids can oscillate between positions.

There is room for optimization but this is already reasonably
performant, and the image quality is much higher \o/
This commit is contained in:
kris 2021-11-17 22:49:06 +00:00
parent 0009ce8913
commit ed2082344a
3 changed files with 218 additions and 86 deletions

View File

@ -28,13 +28,14 @@ import screen as screen_py
class ClusterPalette:
def __init__(
self, image: Image, reserved_colours=0):
self, image: Image, rgb12_iigs_to_cam16ucs, reserved_colours=0):
self._colours_cam = self._image_colours_cam(image)
self._reserved_colours = reserved_colours
self._errors = [1e9] * 16
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.float32)
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.uint8)
self._global_palette = np.empty((16, 16, 3), dtype=np.float32)
self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs
def _image_colours_cam(self, image: Image):
colours_rgb = np.asarray(image).reshape((-1, 3))
@ -58,7 +59,14 @@ class ClusterPalette:
# List of (palette idx, frequency count)
list(zip(*np.unique(labels, return_counts=True))),
key=lambda kv: kv[1], reverse=True)]
return clusters.cluster_centers_[frequency_order]
res = np.empty((16, 3), dtype=np.uint8)
for i in range(16):
res[i, :] = dither_pyx.convert_cam16ucs_to_rgb12_iigs(
clusters.cluster_centers_[frequency_order][i].astype(
np.float32))
print(res)
return res
def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
"""Attempt to find new palettes that locally improve image quality.
@ -79,7 +87,7 @@ class ClusterPalette:
"""
new_errors = list(self._errors)
new_palettes_cam = np.copy(self._palettes_cam)
new_palettes_rgb = np.copy(self._palettes_rgb)
new_palettes_rgb12_iigs = np.copy(self._palettes_rgb)
# Compute a new 16-colour global palette for the entire image,
# used as the starting center positions for k-means clustering of the
@ -117,11 +125,14 @@ class ClusterPalette:
#
# palette_error = clusters.inertia_
clusters, palette_error = dither_pyx.k_means_with_fixed_centroids(
n_clusters=16, n_fixed=self._reserved_colours,
samples=palette_pixels, initial_centroids=self._global_palette,
max_iterations=1000, tolerance=1e-4
)
palettes_rgb12_iigs, palette_error = \
dither_pyx.k_means_with_fixed_centroids(
n_clusters=16, n_fixed=self._reserved_colours,
samples=palette_pixels,
initial_centroids=self._global_palette,
max_iterations=1000, tolerance=0.05,
rgb12_iigs_to_cam16ucs=self._rgb12_iigs_to_cam16ucs
)
if (palette_error >= self._errors[palette_idx] and not
self._reserved_colours):
@ -130,21 +141,23 @@ class ClusterPalette:
# because it would break the invariant that all palettes must
# share colours.
continue
for i in range(16):
new_palettes_cam[palette_idx, i, :] = (
np.array(dither_pyx.convert_rgb12_iigs_to_cam(
self._rgb12_iigs_to_cam16ucs, palettes_rgb12_iigs[
i]), dtype=np.float32))
new_palettes_cam[palette_idx, :, :] = np.array(
# clusters.cluster_centers_).astype(np.float32)
clusters).astype(np.float32)
# Suppress divide by zero warning,
# https://github.com/colour-science/colour/issues/900
with colour.utilities.suppress_warnings(python_warnings=True):
palette_rgb = colour.convert(
new_palettes_cam[palette_idx, :, :], "CAM16UCS", "RGB")
palette_rgb_rec601 = np.clip(image_py.srgb_to_linear(
colour.YCbCr_to_RGB(
colour.RGB_to_YCbCr(
image_py.linear_to_srgb(palette_rgb * 255) / 255,
K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']) * 255) / 255, 0, 1)
# with colour.utilities.suppress_warnings(python_warnings=True):
# palette_rgb = colour.convert(
# new_palettes_cam[palette_idx, :, :], "CAM16UCS", "RGB")
# palette_rgb_rec601 = np.clip(image_py.srgb_to_linear(
# colour.YCbCr_to_RGB(
# colour.RGB_to_YCbCr(
# image_py.linear_to_srgb(palette_rgb * 255) / 255,
# K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
# K=colour.WEIGHTS_YCBCR['ITU-R BT.601']) * 255) / 255, 0, 1)
# palette_rgb = np.clip(
# image_py.srgb_to_linear(
# colour.YCbCr_to_RGB(
@ -155,10 +168,10 @@ class ClusterPalette:
# K=colour.WEIGHTS_YCBCR[
# 'ITU-R BT.601']) * 255) / 255,
# 0, 1)
new_palettes_rgb[palette_idx, :, :] = palette_rgb # palette_rgb_rec601
new_palettes_rgb12_iigs[palette_idx, :, :] = palettes_rgb12_iigs
new_errors[palette_idx] = palette_error
return new_palettes_cam, new_palettes_rgb, new_errors
return new_palettes_cam, new_palettes_rgb12_iigs, new_errors
def accept_palettes(
self, new_palettes_cam: np.ndarray,
@ -216,7 +229,8 @@ def main():
# Conversion matrix from RGB to CAM16UCS colour values. Indexed by
# 24-bit RGB value
rgb_to_cam16 = np.load("data/rgb_to_cam16ucs.npy")
rgb24_to_cam16ucs = np.load("data/rgb24_to_cam16ucs.npy")
rgb12_iigs_to_cam16ucs = np.load("data/rgb12_iigs_to_cam16ucs.npy")
# Open and resize source image
image = image_py.open(args.input)
@ -242,19 +256,26 @@ def main():
total_image_error = 1e9
iterations_since_improvement = 0
palettes_iigs = np.empty((16, 16, 3), dtype=np.uint8)
cluster_palette = ClusterPalette(rgb, reserved_colours=1)
# palettes_iigs = np.empty((16, 16, 3), dtype=np.uint8)
cluster_palette = ClusterPalette(
rgb, reserved_colours=1, rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs)
while iterations_since_improvement < iterations:
new_palettes_cam, new_palettes_rgb, new_palette_errors = (
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
cluster_palette.propose_palettes())
# Suppress divide by zero warning,
# https://github.com/colour-science/colour/issues/900
with colour.utilities.suppress_warnings(python_warnings=True):
new_palettes_linear_rgb = colour.convert(
new_palettes_cam, "CAM16UCS", "RGB").astype(np.float32)
# Recompute image with proposed palettes and check whether it has
# lower total image error than our previous best.
new_output_4bit, new_line_to_palette, new_total_image_error = \
dither_pyx.dither_shr(
rgb, new_palettes_cam, new_palettes_rgb, rgb_to_cam16,
float(penalty))
rgb, new_palettes_cam, new_palettes_linear_rgb,
rgb24_to_cam16ucs, float(penalty))
if new_total_image_error >= total_image_error:
iterations_since_improvement += 1
continue
@ -262,7 +283,7 @@ def main():
# We found a globally better set of palettes
iterations_since_improvement = 0
cluster_palette.accept_palettes(
new_palettes_cam, new_palettes_rgb, new_palette_errors)
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors)
if total_image_error < 1e9:
print("Improved quality +%f%% (%f)" % (
@ -271,19 +292,19 @@ def main():
output_4bit = new_output_4bit
line_to_palette = new_line_to_palette
total_image_error = new_total_image_error
palettes_rgb = new_palettes_rgb
# Recompute 4-bit //gs RGB palettes
palette_rgb_rec601 = np.clip(
colour.YCbCr_to_RGB(
colour.RGB_to_YCbCr(
image_py.linear_to_srgb(palettes_rgb * 255) / 255,
K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1)
palettes_iigs = np.round(palette_rgb_rec601 * 15).astype(np.uint8)
palettes_rgb12_iigs = new_palettes_rgb12_iigs
palettes_linear_rgb = new_palettes_linear_rgb
# # Recompute 4-bit //gs RGB palettes
# palette_rgb_rec601 = np.clip(
# colour.YCbCr_to_RGB(
# colour.RGB_to_YCbCr(
# image_py.linear_to_srgb(palettes_rgb12_iigs * 255) / 255,
# K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
# K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1)
#
# palettes_iigs = np.round(palette_rgb_rec601 * 15).astype(np.uint8)
for i in range(16):
screen.set_palette(i, palettes_iigs[i, :, :])
screen.set_palette(i, palettes_rgb12_iigs[i, :, :])
# Recompute current screen RGB image
screen.set_pixels(output_4bit)
@ -291,22 +312,22 @@ def main():
for i in range(200):
screen.line_palette[i] = line_to_palette[i]
output_rgb[i, :, :] = (
palettes_rgb[line_to_palette[i]][output_4bit[i, :]] * 255
palettes_linear_rgb[line_to_palette[i]][output_4bit[i, :]] * 255
).astype(
# np.round(palettes_rgb[line_to_palette[i]][
# output_4bit[i, :]] * 15) / 15 * 255).astype(
np.uint8)
output_srgb_rec709 = np.clip(colour.YCbCr_to_RGB(
colour.RGB_to_YCbCr(
image_py.linear_to_srgb(output_rgb) / 255,
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']),
K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), 0, 1) * 255
# output_srgb_rec709 = np.clip(colour.YCbCr_to_RGB(
# colour.RGB_to_YCbCr(
# image_py.linear_to_srgb(output_rgb) / 255,
# K=colour.WEIGHTS_YCBCR['ITU-R BT.601']),
# K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), 0, 1) * 255
output_srgb = (image_py.linear_to_srgb(output_rgb)).astype(np.uint8)
# dither = dither_pattern.PATTERNS[args.dither]()
# bitmap = dither_pyx.dither_image(
# screen, rgb, dither, args.lookahead, args.verbose, rgb_to_cam16)
# screen, rgb, dither, args.lookahead, args.verbose, rgb24_to_cam16ucs)
# Show output image by rendering in target palette
# output_palette_name = args.show_palette or args.palette
@ -326,8 +347,8 @@ def main():
np.asarray(out_image).transpose((1, 0, 2))) # flip y/x axes
canvas.blit(surface, (0, 0))
pygame.display.flip()
print((palettes_rgb * 255).astype(np.uint8))
unique_colours = np.unique(palettes_iigs.reshape(-1, 3), axis=0).shape[0]
# print((palettes_rgb * 255).astype(np.uint8))
unique_colours = np.unique(palettes_rgb12_iigs.reshape(-1, 3), axis=0).shape[0]
print("%d unique colours" % unique_colours)
# Save Double hi-res image

View File

@ -157,10 +157,10 @@ cdef int dither_lookahead(Dither* dither, float[:, :, ::1] palette_cam16, float[
return best
@cython.boundscheck(False)
@cython.boundscheck(True)
@cython.wraparound(False)
cdef inline float[::1] convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, float r, float g, float b) nogil:
cdef int rgb_24bit = (<int>(r*255) << 16) + (<int>(g*255) << 8) + <int>(b*255)
cdef unsigned int rgb_24bit = (<unsigned int>(r*255) << 16) + (<unsigned int>(g*255) << 8) + <unsigned int>(b*255)
return rgb_to_cam16ucs[rgb_24bit]
@cython.boundscheck(False)
@ -357,8 +357,6 @@ def dither_shr(
total_image_error = 0.0
for y in range(200):
for x in range(320):
#for i in range(3):
# working_image[y, x, i] = np.round(working_image[y, x, i] * 15) / 15
colour_cam = convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
line_cam[x, :] = colour_cam
@ -491,53 +489,133 @@ cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palette
best_palette_idx = palette_idx
return best_palette_idx
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.boundscheck(True)
@cython.wraparound(True)
def convert_rgb12_iigs_to_cam(float [:, ::1] rgb12_iigs_to_cam16ucs, (unsigned char)[::1] point_rgb12) -> float[::1]:
cdef int rgb12 = (point_rgb12[0] << 8) | (point_rgb12[1] << 4) | point_rgb12[2]
return rgb12_iigs_to_cam16ucs[rgb12]
import colour
cdef float[::1] linear_to_srgb_array(float[::1] a, float gamma=2.4):
cdef int i
cdef float[::1] res = np.empty(3, dtype=np.float32)
for i in range(3):
if a[i] <= 0.0031308:
res[i] = a[i] * 12.92
else:
res[i] = 1.055 * a[i] ** (1.0 / gamma) - 0.055
return res
@cython.boundscheck(True)
@cython.wraparound(True)
def convert_cam16ucs_to_rgb12_iigs(float[::1] point_cam) -> int[::1]: # XXX return type
cdef float[::1] rgb, rgb12_iigs
cdef int i
# Convert CAM16UCS input to RGB
with colour.utilities.suppress_warnings(python_warnings=True):
rgb = colour.convert(point_cam, "CAM16UCS", "RGB").astype(np.float32)
rgb12_iigs = np.clip(
# Convert to Rec.601 R'G'B'
colour.YCbCr_to_RGB(
# Gamma correct and convert Rec.709 R'G'B' to YCbCr
colour.RGB_to_YCbCr(
linear_to_srgb_array(rgb), K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1).astype(np.float32)
for i in range(3):
rgb12_iigs[i] *= 15
return np.round(rgb12_iigs).astype(np.uint8)
@cython.boundscheck(True)
@cython.wraparound(True)
def k_means_with_fixed_centroids(
int n_clusters, int n_fixed, float[:, ::1] samples, float[:, ::1] initial_centroids, int max_iterations, float tolerance):
int n_clusters, int n_fixed, float[:, ::1] samples, (unsigned char)[:, ::1] initial_centroids, int max_iterations,
float tolerance, float [:, ::1] rgb12_iigs_to_cam16ucs):
cdef double error, best_error, centroid_movement, total_error
cdef int centroid_idx, closest_centroid_idx, i, point_idx
cdef float[:, ::1] centroids = initial_centroids[:, :]
cdef float[::1] centroid, point, new_centroid = np.empty(3, dtype=np.float32)
cdef (unsigned char)[:, ::1] centroids_rgb12 = initial_centroids[:, :]
cdef (unsigned char)[::1] centroid_rgb12, new_centroid_rgb12
cdef float[:, ::1] centroid_sample_positions_total
cdef float[::1] point_cam, new_centroid_cam = np.empty(3, dtype=np.float32)
cdef float[:, ::1] centroid_cam_sample_positions_total
cdef int[::1] centroid_sample_counts
# Allow centroids to move on lattice of size 15/255 in sRGB Rec.601 space -- matches //gs palette
# map centroids to CAM when computing distances, cluster means etc
# Map new centroid back to closest lattice point
# Return CAM centroids
cdef int centroid_moved
for iteration in range(max_iterations):
centroid_moved = 1
total_error = 0.0
centroid_movement = 0.0
centroid_sample_positions_total = np.zeros((16, 3), dtype=np.float32)
centroid_cam_sample_positions_total = np.zeros((16, 3), dtype=np.float32)
centroid_sample_counts = np.zeros(16, dtype=np.int32)
for point_idx in range(samples.shape[0]):
point = samples[point_idx, :]
point_cam = samples[point_idx, :]
best_error = 1e9
closest_centroid_idx = 0
for centroid_idx in range(n_clusters):
centroid = centroids[centroid_idx, :]
error = colour_distance_squared(centroid, point)
centroid_rgb12 = centroids_rgb12[centroid_idx, :]
error = colour_distance_squared(convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroid_rgb12), point_cam)
if error < best_error:
best_error = error
closest_centroid_idx = centroid_idx
for i in range(3):
centroid_sample_positions_total[closest_centroid_idx, i] += point[i]
centroid_cam_sample_positions_total[closest_centroid_idx, i] += point_cam[i]
centroid_sample_counts[closest_centroid_idx] += 1
total_error += best_error
for centroid_idx in range(n_fixed, n_clusters):
if centroid_sample_counts[centroid_idx]:
for i in range(3):
new_centroid[i] = (
centroid_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
centroid_movement += colour_distance_squared(centroids[centroid_idx], new_centroid)
new_centroid_cam[i] = (
centroid_cam_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
centroid_movement += colour_distance_squared(
convert_rgb12_iigs_to_cam(rgb12_iigs_to_cam16ucs, centroids_rgb12[centroid_idx]), new_centroid_cam)
new_centroid_rgb12 = convert_cam16ucs_to_rgb12_iigs(new_centroid_cam)
for i in range(3):
if centroids_rgb12[centroid_idx, i] != new_centroid_rgb12[i]:
print(i, centroids_rgb12[centroid_idx, i], new_centroid_rgb12[i])
centroids_rgb12[centroid_idx, i] = new_centroid_rgb12[i]
centroid_moved = 1
centroids[centroid_idx, :] = new_centroid
# print(iteration, total_error, centroids)
print(iteration, centroid_movement, total_error, centroids_rgb12)
if centroid_movement < tolerance:
break
if centroid_moved == 0:
break
return centroids_rgb12, total_error
#@cython.boundscheck(False)
#@cython.wraparound(False)
#cdef float[::1] closest_quantized_point(float [:, ::1] rgb24_to_cam, float [::1] point_cam) nogil:
# cdef unsigned int rgb12, rgb24, closest_rgb24, r, g, b
# cdef double best_distance = 1e9, distance
# for rgb12 in range(2**12):
# r = rgb12 >> 8
# g = (rgb12 >> 4) & 0xf
# b = rgb12 & 0xf
# rgb24 = (r << 20) | (r << 16) | (g << 12) | (g << 8) | (b << 4) | b
# distance = colour_distance_squared(rgb24_to_cam[rgb24], point_cam)
# # print(hex(rgb24), distance)
# if distance < best_distance:
# best_distance = distance
# closest_rgb24 = rgb24
# # print(distance, rgb24, hex(rgb24))
# # print("-->", closest_rgb24, hex(closest_rgb24), best_distance)
# return rgb24_to_cam[closest_rgb24]
return centroids, total_error

View File

@ -10,24 +10,57 @@ import colour
import numpy as np
def srgb_to_linear_array(a: np.ndarray, gamma=2.4) -> np.ndarray:
return np.where(a <= 0.04045, a / 12.92, ((a + 0.055) / 1.055) ** gamma)
def main():
print("Precomputing conversion matrix from RGB to CAM16UCS colour space")
# print("Precomputing conversion matrix from 24-bit RGB to CAM16UCS colour "
# "space")
# # Compute matrix of all 24-bit RGB values, normalized to 0..1 range
# bits24 = np.arange(2 ** 24, dtype=np.uint32).reshape(-1, 1)
# all_rgb24 = np.concatenate(
# [bits24 >> 16 & 0xff, bits24 >> 8 & 0xff, bits24 & 0xff],
# axis=1).astype(np.float32) / 255
# del bits24
#
# with colour.utilities.suppress_warnings(colour_usage_warnings=True):
# # Compute matrix of corresponding CAM16UCS colour values, indexed
# # by 24-bit RGB value
# rgb24_to_cam16ucs = colour.convert(all_rgb24, "RGB", "CAM16UCS").astype(
# np.float32)
# del all_rgb24
# np.save("data/rgb24_to_cam16ucs.npy", rgb24_to_cam16ucs)
# del rgb24_to_cam16ucs
# Compute matrix of all 24-bit RGB values, normalized to 0..1 range
bits24 = np.arange(2 ** 24, dtype=np.uint32).reshape(-1, 1)
all_rgb = np.concatenate(
[bits24 >> 16 & 0xff, bits24 >> 8 & 0xff, bits24 & 0xff],
axis=1).astype(np.float32) / 255
del bits24
print("Precomputing conversion matrix from 12-bit //gs RGB to CAM16UCS "
"colour space")
# Compute matrix of all 12-bit RGB values, normalized to 0..1 range
bits12 = np.arange(2 ** 12, dtype=np.uint32).reshape(-1, 1)
r = bits12 >> 8
g = (bits12 >> 4) & 0xf
b = bits12 & 0xf
all_rgb12 = np.concatenate(
[(r << 4) | r, (g << 4) | g, (b << 4) | b], axis=1).astype(
np.float32) / 255
del bits12, r, g, b
# //gs RGB values use gamma-corrected Rec.601 RGB colour space. We need to
# convert to Rec.709 RGB as preparation for converting to CAM16UCS. We
# do this via the YCbCr intermediate color model.
rgb12_iigs = np.clip(srgb_to_linear_array(
np.clip(colour.YCbCr_to_RGB(
colour.RGB_to_YCbCr(
all_rgb12, K=colour.WEIGHTS_YCBCR[
'ITU-R BT.601']),
K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), 0, 1)), 0, 1)
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
# Compute matrix of corresponding CAM16UCS colour values, indexed
# by 24-bit RGB value
all_cam16 = colour.convert(all_rgb, "RGB", "CAM16UCS").astype(
np.float32)
del all_rgb
np.save("data/rgb_to_cam16ucs.npy", all_cam16)
# by 12-bit //gs RGB value
rgb12_iigs_to_cam16ucs = colour.convert(
rgb12_iigs, "RGB", "CAM16UCS").astype(np.float32)
del rgb12_iigs
np.save("data/rgb12_iigs_to_cam16ucs.npy", rgb12_iigs_to_cam16ucs)
del rgb12_iigs_to_cam16ucs
if __name__ == "__main__":
main()