- allow reserving a number of colours which are to be shared across
  all palettes.  This will be useful for Total Replay, which does an
  animation effect when displaying the image (first set palettes, then
  transition in pixels)

- this requires us to go back to computing k-means ourselves instead of
  using sklearn, since sklearn can't keep some centroids fixed

- try to be more careful about //gs RGB values, which are in the
  Rec.601 colour space.  This isn't quite right yet - the issue seems
  to be that since we dither in linear RGB space but quantize in the
  nonlinear space, small differences may lead to a +/- 1 in the 4-bit
  //gs RGB value, which is quite noticeable.  Instead we need to be
  clustering and/or dithering with awareness of the quantized palette
  space.
This commit is contained in:
kris 2021-11-17 17:09:42 +00:00
parent f2f07ddc04
commit 0009ce8913
3 changed files with 136 additions and 32 deletions

View File

@ -28,8 +28,9 @@ import screen as screen_py
class ClusterPalette:
def __init__(
self, image: Image):
self, image: Image, reserved_colours=0):
self._colours_cam = self._image_colours_cam(image)
self._reserved_colours = reserved_colours
self._errors = [1e9] * 16
self._palettes_cam = np.empty((16, 16, 3), dtype=np.float32)
self._palettes_rgb = np.empty((16, 16, 3), dtype=np.float32)
@ -50,7 +51,14 @@ class ClusterPalette:
clusters = cluster.MiniBatchKMeans(n_clusters=16, max_iter=10000)
clusters.fit_predict(self._colours_cam)
return clusters.cluster_centers_
labels = clusters.labels_
frequency_order = [
k for k, v in sorted(
# List of (palette idx, frequency count)
list(zip(*np.unique(labels, return_counts=True))),
key=lambda kv: kv[1], reverse=True)]
return clusters.cluster_centers_[frequency_order]
def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
"""Attempt to find new palettes that locally improve image quality.
@ -69,15 +77,16 @@ class ClusterPalette:
The current (locally) best palettes are returned and can be applied
using accept_palettes().
"""
new_errors = list(self._errors)
new_palettes_cam = np.copy(self._palettes_cam)
new_palettes_rgb = np.copy(self._palettes_rgb)
# Compute a new 16-colour global palette for the entire image,
# used as the starting center positions for k-means clustering of the
# individual palettes
self._global_palette = self._fit_global_palette()
new_errors = list(self._errors)
new_palettes_cam = np.copy(self._palettes_cam)
new_palettes_rgb = np.copy(self._palettes_rgb)
dynamic_colours = 16 - self._reserved_colours
# The 16 palettes are striped across consecutive (overlapping) line
# ranges. The basic unit is 200/16 = 12.5 lines, but we extend the
@ -100,25 +109,53 @@ class ClusterPalette:
# be a major issue in practise though, and fixing it would require
# implementing our own (optimized) k-means.
# TODO: tune tolerance
clusters = cluster.MiniBatchKMeans(
n_clusters=16, max_iter=10000, init=self._global_palette,
n_init=1)
clusters.fit_predict(palette_pixels)
palette_error = clusters.inertia_
if palette_error >= self._errors[palette_idx]:
# Not a local improvement to existing palette
# clusters = cluster.MiniBatchKMeans(
# n_clusters=16, max_iter=10000,
# init=self._global_palette,
# n_init=1)
# clusters.fit_predict(palette_pixels)
#
# palette_error = clusters.inertia_
clusters, palette_error = dither_pyx.k_means_with_fixed_centroids(
n_clusters=16, n_fixed=self._reserved_colours,
samples=palette_pixels, initial_centroids=self._global_palette,
max_iterations=1000, tolerance=1e-4
)
if (palette_error >= self._errors[palette_idx] and not
self._reserved_colours):
# Not a local improvement to the existing palette, so ignore it.
# We can't take this shortcut when we're reserving colours
# because it would break the invariant that all palettes must
# share colours.
continue
palette_cam = np.array(clusters.cluster_centers_).astype(np.float32)
new_palettes_cam[palette_idx, :, :] = np.array(
# clusters.cluster_centers_).astype(np.float32)
clusters).astype(np.float32)
# Suppress divide by zero warning,
# https://github.com/colour-science/colour/issues/900
with colour.utilities.suppress_warnings(python_warnings=True):
# SHR colour palette only uses 4-bit RGB values
palette_rgb = (np.round(colour.convert(
palette_cam, "CAM16UCS", "RGB") * 15) / 15).astype(
np.float32)
new_palettes_cam[palette_idx, :, :] = palette_cam
new_palettes_rgb[palette_idx, :, :] = palette_rgb
palette_rgb = colour.convert(
new_palettes_cam[palette_idx, :, :], "CAM16UCS", "RGB")
palette_rgb_rec601 = np.clip(image_py.srgb_to_linear(
colour.YCbCr_to_RGB(
colour.RGB_to_YCbCr(
image_py.linear_to_srgb(palette_rgb * 255) / 255,
K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']) * 255) / 255, 0, 1)
# palette_rgb = np.clip(
# image_py.srgb_to_linear(
# colour.YCbCr_to_RGB(
# colour.RGB_to_YCbCr(
# image_py.linear_to_srgb(
# palette_rgb[:, :] * 255) / 255,
# K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
# K=colour.WEIGHTS_YCBCR[
# 'ITU-R BT.601']) * 255) / 255,
# 0, 1)
new_palettes_rgb[palette_idx, :, :] = palette_rgb # palette_rgb_rec601
new_errors[palette_idx] = palette_error
return new_palettes_cam, new_palettes_rgb, new_errors
@ -192,7 +229,7 @@ def main():
# TODO: flags
penalty = 1e9
iterations = 50
iterations = 10 # 50
pygame.init()
# TODO: for some reason I need to execute this twice - the first time
@ -205,8 +242,8 @@ def main():
total_image_error = 1e9
iterations_since_improvement = 0
palette_iigs = np.empty((16, 16, 3), dtype=np.uint8)
cluster_palette = ClusterPalette(rgb)
palettes_iigs = np.empty((16, 16, 3), dtype=np.uint8)
cluster_palette = ClusterPalette(rgb, reserved_colours=1)
while iterations_since_improvement < iterations:
new_palettes_cam, new_palettes_rgb, new_palette_errors = (
@ -237,11 +274,16 @@ def main():
palettes_rgb = new_palettes_rgb
# Recompute 4-bit //gs RGB palettes
palette_rgb_rec601 = np.clip(
colour.YCbCr_to_RGB(
colour.RGB_to_YCbCr(
image_py.linear_to_srgb(palettes_rgb * 255) / 255,
K=colour.WEIGHTS_YCBCR['ITU-R BT.709']),
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']), 0, 1)
palettes_iigs = np.round(palette_rgb_rec601 * 15).astype(np.uint8)
for i in range(16):
palette_iigs[i, :, :] = (
np.round(image_py.linear_to_srgb(
palettes_rgb[i, :, :] * 255) / 255 * 15)).astype(np.uint8)
screen.set_palette(i, palette_iigs[i, :, :])
screen.set_palette(i, palettes_iigs[i, :, :])
# Recompute current screen RGB image
screen.set_pixels(output_4bit)
@ -249,9 +291,18 @@ def main():
for i in range(200):
screen.line_palette[i] = line_to_palette[i]
output_rgb[i, :, :] = (
palettes_rgb[line_to_palette[i]][
output_4bit[i, :]] * 255).astype(np.uint8)
output_srgb = image_py.linear_to_srgb(output_rgb).astype(np.uint8)
palettes_rgb[line_to_palette[i]][output_4bit[i, :]] * 255
).astype(
# np.round(palettes_rgb[line_to_palette[i]][
# output_4bit[i, :]] * 15) / 15 * 255).astype(
np.uint8)
output_srgb_rec709 = np.clip(colour.YCbCr_to_RGB(
colour.RGB_to_YCbCr(
image_py.linear_to_srgb(output_rgb) / 255,
K=colour.WEIGHTS_YCBCR['ITU-R BT.601']),
K=colour.WEIGHTS_YCBCR['ITU-R BT.709']), 0, 1) * 255
output_srgb = (image_py.linear_to_srgb(output_rgb)).astype(np.uint8)
# dither = dither_pattern.PATTERNS[args.dither]()
# bitmap = dither_pyx.dither_image(
@ -275,8 +326,8 @@ def main():
np.asarray(out_image).transpose((1, 0, 2))) # flip y/x axes
canvas.blit(surface, (0, 0))
pygame.display.flip()
unique_colours = np.unique(palette_iigs.reshape(-1, 3), axis=0).shape[0]
print((palettes_rgb * 255).astype(np.uint8))
unique_colours = np.unique(palettes_iigs.reshape(-1, 3), axis=0).shape[0]
print("%d unique colours" % unique_colours)
# Save Double hi-res image

View File

@ -341,7 +341,7 @@ def dither_image(
def dither_shr(
float[:, :, ::1] input_rgb, float[:, :, ::1] palettes_cam, float[:, :, ::1] palettes_rgb,
float[:,::1] rgb_to_cam16ucs, float penalty):
cdef int y, x, idx, best_colour_idx, best_palette
cdef int y, x, idx, best_colour_idx, best_palette, i
cdef double best_distance, distance, total_image_error
cdef float[::1] best_colour_rgb, pixel_cam, colour_rgb, colour_cam
cdef float quant_error
@ -357,6 +357,8 @@ def dither_shr(
total_image_error = 0.0
for y in range(200):
for x in range(320):
#for i in range(3):
# working_image[y, x, i] = np.round(working_image[y, x, i] * 15) / 15
colour_cam = convert_rgb_to_cam16ucs(
rgb_to_cam16ucs, working_image[y,x,0], working_image[y,x,1], working_image[y,x,2])
line_cam[x, :] = colour_cam
@ -489,3 +491,53 @@ cdef int best_palette_for_line(float [:, ::1] line_cam, float[:, :, ::1] palette
best_palette_idx = palette_idx
return best_palette_idx
@cython.boundscheck(False)
@cython.wraparound(False)
def k_means_with_fixed_centroids(
    int n_clusters, int n_fixed, float[:, ::1] samples,
    float[:, ::1] initial_centroids, int max_iterations, float tolerance):
    """Run k-means while holding the first n_fixed centroids in place.

    Each iteration assigns every sample to its nearest centroid (as measured
    by colour_distance_squared) and then moves each non-fixed centroid that
    received at least one sample to the mean of its assigned samples.
    Centroids with index < n_fixed are never moved, and centroids that
    received no samples keep their previous position.

    Args:
        n_clusters: total number of centroids, including the fixed ones.
        n_fixed: number of leading centroids that must not be moved.
        samples: (n_samples, 3) array of points to cluster.
        initial_centroids: starting centroid positions; only the first
            n_clusters rows are used.  This array is NOT modified.
        max_iterations: upper bound on the number of assign/update rounds.
        tolerance: iteration stops once the summed squared movement of the
            centroids in a round falls below this value.

    Returns:
        (centroids, total_error) where centroids is a new float32 buffer
        holding the final centroid positions and total_error is the sum of
        squared sample-to-nearest-centroid distances measured during the
        assignment step of the last round (i.e. before that round's
        centroid update).  total_error is 0.0 if no round was executed.
    """
    cdef double error, best_error, centroid_movement
    # Initialised so that max_iterations <= 0 returns a defined value.
    cdef double total_error = 0.0
    cdef int centroid_idx, closest_centroid_idx, i, point_idx, iteration

    # Work on a private copy: a plain slice would alias the caller's buffer,
    # and the centroid updates below would then overwrite the caller's
    # initial centroids, changing the starting point of subsequent calls.
    cdef float[:, ::1] centroids = initial_centroids.copy()
    cdef float[::1] centroid, point
    cdef float[::1] new_centroid = np.empty(3, dtype=np.float32)

    # Per-centroid accumulators.  These must be sized by n_clusters (not a
    # hard-coded 16) because bounds checking is disabled for this function.
    cdef float[:, ::1] centroid_sample_positions_total = np.zeros(
        (n_clusters, 3), dtype=np.float32)
    cdef int[::1] centroid_sample_counts = np.zeros(n_clusters, dtype=np.int32)

    for iteration in range(max_iterations):
        total_error = 0.0
        centroid_movement = 0.0
        centroid_sample_positions_total[:, :] = 0.0
        centroid_sample_counts[:] = 0

        # Assignment step: attribute each sample to its closest centroid.
        for point_idx in range(samples.shape[0]):
            point = samples[point_idx, :]
            best_error = 1e9
            closest_centroid_idx = 0
            for centroid_idx in range(n_clusters):
                centroid = centroids[centroid_idx, :]
                error = colour_distance_squared(centroid, point)
                if error < best_error:
                    best_error = error
                    closest_centroid_idx = centroid_idx
            for i in range(3):
                centroid_sample_positions_total[closest_centroid_idx, i] += (
                    point[i])
            centroid_sample_counts[closest_centroid_idx] += 1
            total_error += best_error

        # Update step: only the non-fixed centroids are allowed to move.
        for centroid_idx in range(n_fixed, n_clusters):
            # An empty cluster has no mean; leave its centroid where it is.
            if centroid_sample_counts[centroid_idx]:
                for i in range(3):
                    new_centroid[i] = (
                        centroid_sample_positions_total[centroid_idx, i] /
                        centroid_sample_counts[centroid_idx])
                centroid_movement += colour_distance_squared(
                    centroids[centroid_idx], new_centroid)
                centroids[centroid_idx, :] = new_centroid

        if centroid_movement < tolerance:
            break

    return centroids, total_error

View File

@ -55,6 +55,7 @@ class SHR320Screen:
for palette_idx, palette in self.palettes.items():
for rgb_idx, rgb in enumerate(palette):
r, g, b = rgb
assert r <= 15 and g <= 15 and b <= 15
# print(r, g, b)
rgb_low = (g << 4) | b
rgb_hi = r