Improvements to image quality:

- Preprocess the source image by dithering with the full 12-bit //gs
  colour palette, ignoring SHR palette restrictions (i.e. each pixel
  chosen independently from 4096 colours)

- Using this as the ground truth allows much better handling of
  e.g. solid colours, which were being dithered inconsistently with
  the previous approach

- Also, when fitting an SHR palette, pin (hold fixed rather than cluster)
  any colour that accounts for more than 10% of the source pixels being
  fitted.  This also encourages more uniformity in regions of solid colour.
This commit is contained in:
kris 2021-11-25 11:46:42 +00:00
parent 870c008827
commit ad50ed103d
2 changed files with 237 additions and 76 deletions

View File

@ -30,14 +30,27 @@ import screen as screen_py
class ClusterPalette:
def __init__(
self, image: Image, rgb12_iigs_to_cam16ucs, rgb24_to_cam16ucs,
self, image: np.ndarray, rgb12_iigs_to_cam16ucs, rgb24_to_cam16ucs,
fixed_colours=0):
# Source image in 24-bit linear RGB colour space
self._image_rgb = image
# Conversion matrix from 12-bit //gs RGB colour space to CAM16UCS
# colour space
self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs
# Source image in CAM16UCS colour space
self._colours_cam = self._image_colours_cam(image)
# Conversion matrix from 24-bit linear RGB colour space to CAM16UCS
# colour space
self._rgb24_to_cam16ucs = rgb24_to_cam16ucs
# Preprocessed source image in 24-bit linear RGB colour space. We
# first dither the source image using the full 12-bit //gs RGB colour
# palette, ignoring SHR palette limitations (i.e. 4096 independent
# colours for each pixel). This gives much better results for e.g.
# solid blocks of colour, which would be dithered inconsistently if
# targeting the source image directly.
self._image_rgb = self._perfect_dither(image)
# Preprocessed source image in CAM16UCS colour space
self._colours_cam = self._image_colours_cam(self._image_rgb)
# How many image colours to fix identically across all 16 SHR
# palettes. These are taken to be the most prevalent colours from
@ -58,14 +71,6 @@ class ClusterPalette:
# defaultdict(list) mapping palette index to lines using this palette
self._palette_lines = self._init_palette_lines()
# Conversion matrix from 12-bit //gs RGB colour space to CAM16UCS
# colour space
self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs
# Conversion matrix from 24-bit linear RGB colour space to CAM16UCS
# colour space
self._rgb24_to_cam16ucs = rgb24_to_cam16ucs
def _image_colours_cam(self, image: Image):
colours_rgb = np.asarray(image) # .reshape((-1, 3))
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
@ -113,6 +118,23 @@ class ClusterPalette:
int(np.round(palette_upper))))
return palette_ranges
def _perfect_dither(self, source_image: np.ndarray):
    """Return source_image dithered against the whole 12-bit //gs palette.

    Every pixel may independently use any entry of the full palette held in
    self._rgb12_iigs_to_cam16ucs, i.e. SHR per-line palette restrictions
    are ignored.  The dithering itself is delegated to
    dither_pyx.dither_shr_perfect; the total error it reports is discarded
    and only the dithered RGB image is returned.
    """
    palette_cam = self._rgb12_iigs_to_cam16ucs

    # colour.convert can emit a divide by zero warning here, see
    # https://github.com/colour-science/colour/issues/900
    with colour.utilities.suppress_warnings(python_warnings=True):
        palette_linear_rgb = colour.convert(
            palette_cam, "CAM16UCS", "RGB").astype(np.float32)

    _total_error, dithered_rgb = dither_pyx.dither_shr_perfect(
        source_image, palette_cam, palette_linear_rgb,
        self._rgb24_to_cam16ucs)
    return dithered_rgb
def _dither_image(self, palettes_cam, penalty):
# Suppress divide by zero warning,
# https://github.com/colour-science/colour/issues/900
@ -227,9 +249,29 @@ class ClusterPalette:
seen_colours.add(tuple(new_colour))
initial_centroids[i, :] = new_colour
# If there are any single colours in our source //gs RGB
# pixels that represent more than fixed_colour_fraction_threshold
# of the pixels, then fix these colours for the palette instead of
# clustering them. This reduces artifacting on blocks of
# colour.
fixed_colour_fraction_threshold = 0.1
fixed_colours = self._fixed_colours
for colour, freq in sorted(list(zip(
*np.unique(dither_pyx.convert_cam16ucs_to_rgb12_iigs(
palette_pixels), return_counts=True, axis=0))),
key=lambda kv: kv[1], reverse=True):
if freq < (palette_pixels.shape[0] *
fixed_colour_fraction_threshold):
break
# print(colour, freq)
if tuple(colour) not in seen_colours:
seen_colours.add(tuple(colour))
initial_centroids[fixed_colours, :] = colour
fixed_colours += 1
palettes_rgb12_iigs, palette_error = \
dither_pyx.k_means_with_fixed_centroids(
n_clusters=16, n_fixed=self._fixed_colours,
n_clusters=16, n_fixed=fixed_colours,
samples=palette_pixels,
initial_centroids=initial_centroids,
max_iterations=1000, tolerance=0.05,
@ -260,6 +302,7 @@ class ClusterPalette:
palette_freq = {idx: 0 for idx in range(16)}
for idx, freq in zip(*np.unique(clusters.labels_, return_counts=True)):
palette_freq[idx] = freq
frequency_order = [
k for k, v in sorted(
list(palette_freq.items()), key=lambda kv: kv[1], reverse=True)]

View File

@ -336,6 +336,121 @@ def dither_image(
return image_nbit_to_bitmap(image_nbit, xres, yres, palette_depth)
@cython.boundscheck(False)
@cython.wraparound(False)
def dither_shr_perfect(
        float[:, :, ::1] input_rgb, float[:, ::1] full_palette_cam, float[:, ::1] full_palette_rgb,
        float[:,::1] rgb_to_cam16ucs):
    """Dither an image with every pixel free to use any colour of the full palette.

    Each pixel is converted with convert_rgb_to_cam16ucs, matched to the
    palette entry with the smallest colour_distance_squared, and the per-channel
    RGB quantisation error is diffused to not-yet-visited neighbours
    (Floyd-Steinberg by default; a Jarvis kernel is kept behind the
    floyd_steinberg switch).  Error pushed to lower rows is additionally scaled
    by `decay` per row.

    Args:
        input_rgb: source image indexed as [y, x, channel]; it is copied, not
            modified.  Three channels are read.
            NOTE(review): presumably linear RGB in [0, 1], since diffused values
            are passed through clip(..., 0, 1) -- confirm against the caller.
        full_palette_cam: palette colours in the space used for matching, one
            row per palette entry.
        full_palette_rgb: the same palette entries as RGB, row-aligned with
            full_palette_cam; these values are written to the output.
        rgb_to_cam16ucs: conversion table handed to convert_rgb_to_cam16ucs.

    Returns:
        (total_image_error, working_image): the sum over all pixels of the
        winning squared colour distance, and the dithered image as a float
        memoryview with the same shape as input_rgb.

    Raises:
        ValueError: if the palette has no entries.
    """
    cdef int y, x, idx, best_colour_idx, i
    cdef double best_distance, distance, total_image_error
    cdef float[::1] best_colour_rgb, pixel_cam
    cdef float quant_error

    cdef float[:, :, ::1] working_image = np.copy(input_rgb)

    # Bounds checking is disabled for this function, so take the loop limits
    # from the buffer itself instead of assuming a 200x320 image.
    cdef int yres = input_rgb.shape[0]
    cdef int xres = input_rgb.shape[1]

    cdef int palette_size = full_palette_rgb.shape[0]

    cdef float decay = 0.5
    cdef float min_quant_error = 0.0  # 0.02
    cdef int floyd_steinberg = 1

    if palette_size == 0:
        # Without this, best_colour_idx would stay at -1 and be used as an
        # index with wraparound disabled.
        raise ValueError("full_palette_rgb must contain at least one colour")

    total_image_error = 0.0
    for y in range(yres):
        for x in range(xres):
            # Convert at visit time: the pixel may already carry error
            # diffused from previously processed neighbours.
            pixel_cam = convert_rgb_to_cam16ucs(
                rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])

            # Exhaustive nearest-colour search over the full palette.
            best_distance = 1e9
            best_colour_idx = -1
            for idx in range(palette_size):
                distance = colour_distance_squared(pixel_cam, full_palette_cam[idx, :])
                if distance < best_distance:
                    best_distance = distance
                    best_colour_idx = idx
            best_colour_rgb = full_palette_rgb[best_colour_idx]
            total_image_error += best_distance

            for i in range(3):
                quant_error = working_image[y, x, i] - best_colour_rgb[i]
                if abs(quant_error) <= min_quant_error:
                    quant_error = 0

                working_image[y, x, i] = best_colour_rgb[i]
                if floyd_steinberg:
                    # Floyd-Steinberg dither
                    # 0 * 7
                    # 3 5 1
                    if x < xres - 1:
                        working_image[y, x + 1, i] = clip(
                            working_image[y, x + 1, i] + quant_error * (7 / 16), 0, 1)
                    if y < yres - 1:
                        if x > 0:
                            working_image[y + 1, x - 1, i] = clip(
                                working_image[y + 1, x - 1, i] + decay * quant_error * (3 / 16), 0, 1)
                        working_image[y + 1, x, i] = clip(
                            working_image[y + 1, x, i] + decay * quant_error * (5 / 16), 0, 1)
                        if x < xres - 1:
                            working_image[y + 1, x + 1, i] = clip(
                                working_image[y + 1, x + 1, i] + decay * quant_error * (1 / 16), 0, 1)
                else:
                    # Jarvis
                    # 0 0 X 7 5
                    # 3 5 7 5 3
                    # 1 3 5 3 1
                    if x < xres - 1:
                        working_image[y, x + 1, i] = clip(
                            working_image[y, x + 1, i] + quant_error * (7 / 48), 0, 1)
                    if x < xres - 2:
                        working_image[y, x + 2, i] = clip(
                            working_image[y, x + 2, i] + quant_error * (5 / 48), 0, 1)
                    if y < yres - 1:
                        if x > 1:
                            working_image[y + 1, x - 2, i] = clip(
                                working_image[y + 1, x - 2, i] + decay * quant_error * (3 / 48), 0, 1)
                        if x > 0:
                            working_image[y + 1, x - 1, i] = clip(
                                working_image[y + 1, x - 1, i] + decay * quant_error * (5 / 48), 0, 1)
                        working_image[y + 1, x, i] = clip(
                            working_image[y + 1, x, i] + decay * quant_error * (7 / 48), 0, 1)
                        if x < xres - 1:
                            working_image[y + 1, x + 1, i] = clip(
                                working_image[y + 1, x + 1, i] + decay * quant_error * (5 / 48), 0, 1)
                        if x < xres - 2:
                            working_image[y + 1, x + 2, i] = clip(
                                working_image[y + 1, x + 2, i] + decay * quant_error * (3 / 48), 0, 1)
                    if y < yres - 2:
                        # Two rows down: the decay factor is applied twice.
                        if x > 1:
                            working_image[y + 2, x - 2, i] = clip(
                                working_image[y + 2, x - 2, i] + decay * decay * quant_error * (1 / 48), 0, 1)
                        if x > 0:
                            working_image[y + 2, x - 1, i] = clip(
                                working_image[y + 2, x - 1, i] + decay * decay * quant_error * (3 / 48), 0, 1)
                        working_image[y + 2, x, i] = clip(
                            working_image[y + 2, x, i] + decay * decay * quant_error * (5 / 48), 0, 1)
                        if x < xres - 1:
                            working_image[y + 2, x + 1, i] = clip(
                                working_image[y + 2, x + 1, i] + decay * decay * quant_error * (3 / 48), 0, 1)
                        if x < xres - 2:
                            working_image[y + 2, x + 2, i] = clip(
                                working_image[y + 2, x + 2, i] + decay * decay * quant_error * (1 / 48), 0, 1)

    return total_image_error, working_image
@cython.boundscheck(False)
@cython.wraparound(False)
def dither_shr(
@ -357,6 +472,7 @@ def dither_shr(
cdef float decay = 0.5
cdef float min_quant_error = 0.0 # 0.02
cdef int floyd_steinberg = 1
best_palette = -1
total_image_error = 0.0
@ -393,70 +509,72 @@ def dither_shr(
if abs(quant_error) <= min_quant_error:
quant_error = 0
# Floyd-Steinberg dither
# 0 * 7
# 3 5 1
working_image[y, x, i] = best_colour_rgb[i]
if x < 319:
working_image[y, x + 1, i] = clip(
working_image[y, x + 1, i] + quant_error * (7 / 16), 0, 1)
if y < 199:
if x > 0:
working_image[y + 1, x - 1, i] = clip(
working_image[y + 1, x - 1, i] + decay * quant_error * (3 / 32), 0, 1)
working_image[y + 1, x, i] = clip(
working_image[y + 1, x, i] + decay * quant_error * (5 / 32), 0, 1)
if floyd_steinberg:
# Floyd-Steinberg dither
# 0 * 7
# 3 5 1
if x < 319:
working_image[y + 1, x + 1, i] = clip(
working_image[y + 1, x + 1, i] + decay * quant_error * (1 / 32), 0, 1)
# # 0 0 X 7 5
# # 3 5 7 5 3
# # 1 3 5 3 1
#if x < 319:
# working_image[y, x + 1, i] = clip(
# working_image[y, x + 1, i] + quant_error * (7 / 48), 0, 1)
#if x < 318:
# working_image[y, x + 2, i] = clip(
# working_image[y, x + 2, i] + quant_error * (5 / 48), 0, 1)
#if y < 199:
# if x > 1:
# working_image[y + 1, x - 2, i] = clip(
# working_image[y + 1, x - 2, i] + decay * quant_error * (3 / 48), 0,
# 1)
# if x > 0:
# working_image[y + 1, x - 1, i] = clip(
# working_image[y + 1, x - 1, i] + decay * quant_error * (5 / 48), 0,
# 1)
# working_image[y + 1, x, i] = clip(
# working_image[y + 1, x, i] + decay * quant_error * (7 / 48), 0, 1)
# if x < 319:
# working_image[y + 1, x + 1, i] = clip(
# working_image[y + 1, x + 1, i] + decay * quant_error * (5 / 48),
# 0, 1)
# if x < 318:
# working_image[y + 1, x + 2, i] = clip(
# working_image[y + 1, x + 2, i] + decay * quant_error * (3 / 48),
# 0, 1)
#if y < 198:
# if x > 1:
# working_image[y + 2, x - 2, i] = clip(
# working_image[y + 2, x - 2, i] + decay * decay * quant_error * (1 / 48), 0,
# 1)
# if x > 0:
# working_image[y + 2, x - 1, i] = clip(
# working_image[y + 2, x - 1, i] + decay * decay * quant_error * (3 / 48), 0,
# 1)
# working_image[y + 2, x, i] = clip(
# working_image[y + 2, x, i] + decay * decay * quant_error * (5 / 48), 0, 1)
# if x < 319:
# working_image[y + 2, x + 1, i] = clip(
# working_image[y + 2, x + 1, i] + decay * decay * quant_error * (3 / 48),
# 0, 1)
# if x < 318:
# working_image[y + 2, x + 2, i] = clip(
# working_image[y + 2, x + 2, i] + decay * decay * quant_error * (1 / 48),
# 0, 1)
working_image[y, x + 1, i] = clip(
working_image[y, x + 1, i] + quant_error * (7 / 16), 0, 1)
if y < 199:
if x > 0:
working_image[y + 1, x - 1, i] = clip(
working_image[y + 1, x - 1, i] + decay * quant_error * (3 / 16), 0, 1)
working_image[y + 1, x, i] = clip(
working_image[y + 1, x, i] + decay * quant_error * (5 / 16), 0, 1)
if x < 319:
working_image[y + 1, x + 1, i] = clip(
working_image[y + 1, x + 1, i] + decay * quant_error * (1 / 16), 0, 1)
else:
# Jarvis
# 0 0 X 7 5
# 3 5 7 5 3
# 1 3 5 3 1
if x < 319:
working_image[y, x + 1, i] = clip(
working_image[y, x + 1, i] + quant_error * (7 / 48), 0, 1)
if x < 318:
working_image[y, x + 2, i] = clip(
working_image[y, x + 2, i] + quant_error * (5 / 48), 0, 1)
if y < 199:
if x > 1:
working_image[y + 1, x - 2, i] = clip(
working_image[y + 1, x - 2, i] + decay * quant_error * (3 / 48), 0,
1)
if x > 0:
working_image[y + 1, x - 1, i] = clip(
working_image[y + 1, x - 1, i] + decay * quant_error * (5 / 48), 0,
1)
working_image[y + 1, x, i] = clip(
working_image[y + 1, x, i] + decay * quant_error * (7 / 48), 0, 1)
if x < 319:
working_image[y + 1, x + 1, i] = clip(
working_image[y + 1, x + 1, i] + decay * quant_error * (5 / 48),
0, 1)
if x < 318:
working_image[y + 1, x + 2, i] = clip(
working_image[y + 1, x + 2, i] + decay * quant_error * (3 / 48),
0, 1)
if y < 198:
if x > 1:
working_image[y + 2, x - 2, i] = clip(
working_image[y + 2, x - 2, i] + decay * decay * quant_error * (1 / 48), 0,
1)
if x > 0:
working_image[y + 2, x - 1, i] = clip(
working_image[y + 2, x - 1, i] + decay * decay * quant_error * (3 / 48), 0,
1)
working_image[y + 2, x, i] = clip(
working_image[y + 2, x, i] + decay * decay * quant_error * (5 / 48), 0, 1)
if x < 319:
working_image[y + 2, x + 1, i] = clip(
working_image[y + 2, x + 1, i] + decay * decay * quant_error * (3 / 48),
0, 1)
if x < 318:
working_image[y + 2, x + 2, i] = clip(
working_image[y + 2, x + 2, i] + decay * decay * quant_error * (1 / 48),
0, 1)
return np.array(output_4bit, dtype=np.uint8), line_to_palette, total_image_error, np.array(palette_line_errors, dtype=np.float64)