- allow reserving a number of colours which are to be shared across
  all palettes.  This will be useful for Total Replay, which does an
  animation effect when displaying the image (first set the palettes, then
  transition in the pixels)

- this requires us to go back to computing k-means ourselves instead of
  using sklearn, since sklearn can't keep some centroids fixed (see the
  sketch below)

- try to be more careful about //gs RGB values, which are in the
  Rec.601 colour space (see the conversion sketch below).  This isn't
  quite right yet - the issue seems to be that since we dither in linear
  RGB space but quantize in the nonlinear space, small differences can
  lead to a +/- 1 step in the 4-bit //gs RGB value, which is quite
  noticeable.  Instead we need to be clustering and/or dithering with
  awareness of the quantized palette space.
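
A minimal NumPy sketch of the fixed-centroid k-means idea (illustrative
only - the actual implementation is the Cython k_means_with_fixed_centroids
added below, which clusters in CAM16-UCS; the function name and parameters
here are just for exposition):

    import numpy as np

    def kmeans_fixed(samples, initial_centroids, n_fixed,
                     max_iterations=100, tolerance=1e-4):
        # Lloyd's algorithm, but the first n_fixed centroids never move,
        # so the reserved colours stay identical across all palettes.
        centroids = np.array(initial_centroids, dtype=np.float32)
        total_error = 0.0
        for _ in range(max_iterations):
            # Assign each sample to its nearest centroid (squared distance).
            distances = ((samples[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
            labels = distances.argmin(axis=1)
            total_error = distances[np.arange(len(samples)), labels].sum()
            movement = 0.0
            # Only the non-reserved centroids are updated.
            for k in range(n_fixed, len(centroids)):
                members = samples[labels == k]
                if len(members):
                    new_centroid = members.mean(axis=0)
                    movement += ((centroids[k] - new_centroid) ** 2).sum()
                    centroids[k] = new_centroid
            if movement < tolerance:
                break
        return centroids, total_error

And a rough sketch of the Rec.709 -> Rec.601 re-encoding applied before
quantizing to the 4-bit //gs palette registers (the helper name is made up;
the colour-science calls are the same ones used in the diff):

    import colour
    import numpy as np

    def srgb_to_rec601_rgb(rgb_srgb):
        # Re-encode gamma-encoded Rec.709 R'G'B' as Rec.601 R'G'B' by
        # round-tripping through Y'CbCr with mismatched weights.
        ycbcr = colour.RGB_to_YCbCr(
            rgb_srgb, K=colour.WEIGHTS_YCBCR['ITU-R BT.709'])
        return np.clip(colour.YCbCr_to_RGB(
            ycbcr, K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1)

    # Quantize to 4 bits per channel for the SHR palette entries.
    palette_iigs = np.round(
        srgb_to_rec601_rgb(np.array([[1.0, 0.5, 0.25]])) * 15).astype(np.uint8)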
kris 2021-11-17 17:09:42 +00:00
parent f2f07ddc04
commit 0009ce8913
3 changed files with 136 additions and 32 deletions

File 1 of 3:

@@ -28,8 +28,9 @@ import screen as screen_py

 class ClusterPalette:
     def __init__(
-            self, image: Image):
+            self, image: Image, reserved_colours=0):
         self._colours_cam = self._image_colours_cam(image)
+        self._reserved_colours = reserved_colours
         self._errors = [1e9] * 16
         self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
         self._palettes_rgb = np.empty((16, 16, 3), dtype=np.float32)
@@ -50,7 +51,14 @@ class ClusterPalette:
         clusters = cluster.MiniBatchKMeans(n_clusters=16, max_iter=10000)
         clusters.fit_predict(self._colours_cam)
-        return clusters.cluster_centers_
+        labels = clusters.labels_
+        frequency_order = [
+            k for k, v in sorted(
+                # List of (palette idx, frequency count)
+                list(zip(*np.unique(labels, return_counts=True))),
+                key=lambda kv: kv[1], reverse=True)]
+        return clusters.cluster_centers_[frequency_order]

     def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
         """Attempt to find new palettes that locally improve image quality.
@@ -69,15 +77,16 @@ class ClusterPalette:
         The current (locally) best palettes are returned and can be applied
         using accept_palettes().
         """
+        new_errors = list(self._errors)
+        new_palettes_cam = np.copy(self._palettes_cam)
+        new_palettes_rgb = np.copy(self._palettes_rgb)
+
         # Compute a new 16-colour global palette for the entire image,
         # used as the starting center positions for k-means clustering of the
         # individual palettes
         self._global_palette = self._fit_global_palette()

-        new_errors = list(self._errors)
-        new_palettes_cam = np.copy(self._palettes_cam)
-        new_palettes_rgb = np.copy(self._palettes_rgb)
+        dynamic_colours = 16 - self._reserved_colours

         # The 16 palettes are striped across consecutive (overlapping) line
         # ranges. The basic unit is 200/16 = 12.5 lines, but we extend the
@@ -100,25 +109,53 @@ class ClusterPalette:
             # be a major issue in practise though, and fixing it would require
             # implementing our own (optimized) k-means.
             # TODO: tune tolerance
-            clusters = cluster.MiniBatchKMeans(
-                n_clusters=16, max_iter=10000, init=self._global_palette,
-                n_init=1)
-            clusters.fit_predict(palette_pixels)
-            palette_error = clusters.inertia_
-            if palette_error >= self._errors[palette_idx]:
-                # Not a local improvement to existing palette
+            # clusters = cluster.MiniBatchKMeans(
+            #     n_clusters=16, max_iter=10000,
+            #     init=self._global_palette,
+            #     n_init=1)
+            # clusters.fit_predict(palette_pixels)
+            #
+            # palette_error = clusters.inertia_
+            clusters, palette_error = dither_pyx.k_means_with_fixed_centroids(
+                n_clusters=16, n_fixed=self._reserved_colours,
+                samples=palette_pixels, initial_centroids=self._global_palette,
+                max_iterations=1000, tolerance=1e-4
+            )
+            if (palette_error >= self._errors[palette_idx] and not
+                    self._reserved_colours):
+                # Not a local improvement to the existing palette, so ignore it.
+                # We can't take this shortcut when we're reserving colours
+                # because it would break the invariant that all palettes must
+                # share colours.
                 continue
-            palette_cam = np.array(clusters.cluster_centers_).astype(np.float32)
+            new_palettes_cam[palette_idx, :, :] = np.array(
+                # clusters.cluster_centers_).astype(np.float32)
+                clusters).astype(np.float32)

             # Suppress divide by zero warning,
             # https://github.com/colour-science/colour/issues/900
             with colour.utilities.suppress_warnings(python_warnings=True):
-                # SHR colour palette only uses 4-bit RGB values
-                palette_rgb = (np.round(colour.convert(
-                    palette_cam, "CAM16UCS", "RGB") * 15) / 15).astype(
-                    np.float32)
-            new_palettes_cam[palette_idx, :, :] = palette_cam
-            new_palettes_rgb[palette_idx, :, :] = palette_rgb
+                palette_rgb = colour.convert(
+                    new_palettes_cam[palette_idx, :, :], "CAM16UCS", "RGB")
+                palette_rgb_rec601 = np.clip(image_py.srgb_to_linear(
+                    colour.YCbCr_to_RGB(
+                        colour.RGB_to_YCbCr(
+                            image_py.linear_to_srgb(palette_rgb * 255) / 255,
+                            K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
+                        K=colour.WEIGHTS_YCBCR['ITU-R BT.601']) * 255) / 255, 0, 1)
+                # palette_rgb = np.clip(
+                #     image_py.srgb_to_linear(
+                #         colour.YCbCr_to_RGB(
+                #             colour.RGB_to_YCbCr(
+                #                 image_py.linear_to_srgb(
+                #                     palette_rgb[:, :] * 255) / 255,
+                #                 K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
+                #             K=colour.WEIGHTS_YCBCR[
+                #                 'ITU-R BT.601']) * 255) / 255,
+                #     0, 1)
+            new_palettes_rgb[palette_idx, :, :] = palette_rgb  # palette_rgb_rec601
             new_errors[palette_idx] = palette_error
         return new_palettes_cam, new_palettes_rgb, new_errors
@@ -192,7 +229,7 @@ def main():
     # TODO: flags
     penalty = 1e9
-    iterations = 50
+    iterations = 10  # 50

     pygame.init()
     # TODO: for some reason I need to execute this twice - the first time
@@ -205,8 +242,8 @@ def main():
     total_image_error = 1e9
     iterations_since_improvement = 0
-    palette_iigs = np.empty((16, 16, 3), dtype=np.uint8)
-    cluster_palette = ClusterPalette(rgb)
+    palettes_iigs = np.empty((16, 16, 3), dtype=np.uint8)
+    cluster_palette = ClusterPalette(rgb, reserved_colours=1)

     while iterations_since_improvement < iterations:
         new_palettes_cam, new_palettes_rgb, new_palette_errors = (
@@ -237,11 +274,16 @@ def main():
             palettes_rgb = new_palettes_rgb

             # Recompute 4-bit //gs RGB palettes
+            palette_rgb_rec601 = np.clip(
+                colour.YCbCr_to_RGB(
+                    colour.RGB_to_YCbCr(
+                        image_py.linear_to_srgb(palettes_rgb * 255) / 255,
+                        K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
+                    K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1)
+            palettes_iigs = np.round(palette_rgb_rec601 * 15).astype(np.uint8)
             for i in range(16):
-                palette_iigs[i, :, :] = (
-                    np.round(image_py.linear_to_srgb(
-                        palettes_rgb[i, :, :] * 255) / 255 * 15)).astype(np.uint8)
-                screen.set_palette(i, palette_iigs[i, :, :])
+                screen.set_palette(i, palettes_iigs[i, :, :])

             # Recompute current screen RGB image
             screen.set_pixels(output_4bit)
@@ -249,9 +291,18 @@ def main():
             for i in range(200):
                 screen.line_palette[i] = line_to_palette[i]
                 output_rgb[i, :, :] = (
-                    palettes_rgb[line_to_palette[i]][
-                        output_4bit[i, :]] * 255).astype(np.uint8)
-            output_srgb = image_py.linear_to_srgb(output_rgb).astype(np.uint8)
+                    palettes_rgb[line_to_palette[i]][output_4bit[i, :]] * 255
+                ).astype(
+                    # np.round(palettes_rgb[line_to_palette[i]][
+                    #     output_4bit[i, :]] * 15) / 15 * 255).astype(
+                    np.uint8)
+
+            output_srgb_rec709 = np.clip(colour.YCbCr_to_RGB(
+                colour.RGB_to_YCbCr(
+                    image_py.linear_to_srgb(output_rgb) / 255,
+                    K=colour.WEIGHTS_YCBCR['ITU-R BT.601']),
+                K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), 0, 1) * 255
+            output_srgb = (image_py.linear_to_srgb(output_rgb)).astype(np.uint8)

             # dither = dither_pattern.PATTERNS[args.dither]()
             # bitmap = dither_pyx.dither_image(
@@ -275,8 +326,8 @@ def main():
                 np.asarray(out_image).transpose((1, 0, 2)))  # flip y/x axes
             canvas.blit(surface, (0, 0))
             pygame.display.flip()
+            print((palettes_rgb * 255).astype(np.uint8))

-    unique_colours = np.unique(palette_iigs.reshape(-1, 3), axis=0).shape[0]
+    unique_colours = np.unique(palettes_iigs.reshape(-1, 3), axis=0).shape[0]
     print("%d unique colours" % unique_colours)

     # Save Double hi-res image

File 2 of 3:

@@ -341,7 +341,7 @@ def dither_image(
 def dither_shr(
         float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb,
         float[:,::1] rgb_to_cam16ucs, float penalty):
-    cdef int y, x, idx, best_colour_idx, best_palette
+    cdef int y, x, idx, best_colour_idx, best_palette, i
     cdef double best_distance, distance, total_image_error
     cdef float[::1] best_colour_rgb, pixel_cam, colour_rgb, colour_cam
     cdef float quant_error
@@ -357,6 +357,8 @@ def dither_shr(
     total_image_error = 0.0
     for y in range(200):
         for x in range(320):
+            #for i in range(3):
+            #    working_image[y, x, i] = np.round(working_image[y, x, i] * 15) / 15
             colour_cam = convert_rgb_to_cam16ucs(
                 rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
             line_cam[x, :] = colour_cam
@@ -489,3 +491,53 @@ cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palette
             best_palette_idx = palette_idx
     return best_palette_idx
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def k_means_with_fixed_centroids(
+        int n_clusters, int n_fixed, float[:, ::1] samples, float[:, ::1] initial_centroids, int max_iterations, float tolerance):
+
+    cdef double error, best_error, centroid_movement, total_error
+    cdef int centroid_idx, closest_centroid_idx, i, point_idx
+
+    cdef float[:, ::1] centroids = initial_centroids[:, :]
+    cdef float[::1] centroid, point, new_centroid = np.empty(3, dtype=np.float32)
+    cdef float[:, ::1] centroid_sample_positions_total
+    cdef int[::1] centroid_sample_counts
+
+    for iteration in range(max_iterations):
+        total_error = 0.0
+        centroid_movement = 0.0
+        centroid_sample_positions_total = np.zeros((16, 3), dtype=np.float32)
+        centroid_sample_counts = np.zeros(16, dtype=np.int32)
+
+        for point_idx in range(samples.shape[0]):
+            point = samples[point_idx, :]
+            best_error = 1e9
+            closest_centroid_idx = 0
+            for centroid_idx in range(n_clusters):
+                centroid = centroids[centroid_idx, :]
+                error = colour_distance_squared(centroid, point)
+                if error < best_error:
+                    best_error = error
+                    closest_centroid_idx = centroid_idx
+            for i in range(3):
+                centroid_sample_positions_total[closest_centroid_idx, i] += point[i]
+            centroid_sample_counts[closest_centroid_idx] += 1
+            total_error += best_error

+        for centroid_idx in range(n_fixed, n_clusters):
+            if centroid_sample_counts[centroid_idx]:
+                for i in range(3):
+                    new_centroid[i] = (
+                        centroid_sample_positions_total[centroid_idx, i] / centroid_sample_counts[centroid_idx])
+                centroid_movement += colour_distance_squared(centroids[centroid_idx], new_centroid)
+                centroids[centroid_idx, :] = new_centroid
+
+        # print(iteration, total_error, centroids)
+        if centroid_movement < tolerance:
+            break
+
+    return centroids, total_error

File 3 of 3:

@@ -55,6 +55,7 @@ class SHR320Screen:
         for palette_idx, palette in self.palettes.items():
             for rgb_idx, rgb in enumerate(palette):
                 r, g, b = rgb
+                assert r <= 15 and g <= 15 and b <= 15
                 # print(r, g, b)
                 rgb_low = (g << 4) | b
                 rgb_hi = r