* Modify Floyd-Steinberg dithering to diffuse less error in the y
  direction.  Otherwise, errors can accumulate in an RGB channel if
  there are no palette colours with an extremal value, and then when
  we introduce a new palette the error all suddenly discharges in a
  spurious horizontal line.  This now gives quite good results!  (See
  the first sketch after this list.)

* Switch to using L1-norm for k-means, per suggestion of Lucas
  Scharenbroich: "A k-medians effectively uses an L1 distance metric
  instead of L2 for k-means.  Using a squared distance metric causes
  the fit to "fall off" too quickly and allows too many of the k
  centroids to cluster around areas of high density, which results in
  many similar colors being selected.  A linear cost function forces
  the centroids to spread out since the error influence has a broader
  range."
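
For the first bullet, here is a minimal sketch of the modified kernel in plain
numpy (the function name and the plain-RGB nearest-colour search are
illustrative only; the real code in the Cython diff below works per palette
line and matches colours in CAM16-UCS).  The right-neighbour weight stays at
7/16, but the weights pushed into row y+1 are halved from 3/16, 5/16, 1/16 to
3/32, 5/32, 1/32, so part of the quantisation error is deliberately dropped
instead of being carried downward:

    import numpy as np

    def floyd_steinberg_reduced_y(image, palette_rgb):
        # Sketch only: dither a float RGB image (values in [0, 1]) against a
        # single 16-colour palette, diffusing less error in the y direction.
        h, w, _ = image.shape
        output = np.zeros((h, w), dtype=np.uint8)
        for y in range(h):
            for x in range(w):
                pixel = image[y, x].copy()
                # Nearest palette entry (plain RGB distance for brevity).
                idx = int(np.argmin(((palette_rgb - pixel) ** 2).sum(axis=1)))
                output[y, x] = idx
                err = pixel - palette_rgb[idx]
                image[y, x] = palette_rgb[idx]
                # Same-row neighbour keeps the usual Floyd-Steinberg 7/16 ...
                if x + 1 < w:
                    image[y, x + 1] = np.clip(
                        image[y, x + 1] + err * (7 / 16), 0, 1)
                # ... but the next-row weights are halved (3/32, 5/32, 1/32).
                if y + 1 < h:
                    if x > 0:
                        image[y + 1, x - 1] = np.clip(
                            image[y + 1, x - 1] + err * (3 / 32), 0, 1)
                    image[y + 1, x] = np.clip(
                        image[y + 1, x] + err * (5 / 32), 0, 1)
                    if x + 1 < w:
                        image[y + 1, x + 1] = np.clip(
                            image[y + 1, x + 1] + err * (1 / 32), 0, 1)
        return output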
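
And for the second bullet, a small numpy sketch of what the switch to an L1
metric means for the cluster-assignment step (the helper name is made up; the
actual change is inside the Cython k_means_with_fixed_centroids shown in the
diff below, which swaps colour_distance_squared for an L1 colour_distance):

    import numpy as np

    def assign_to_centroids_l1(points, centroids):
        # Sketch only: assign each point to its nearest centroid using the
        # L1 metric (sum of absolute per-channel differences) rather than
        # squared Euclidean distance.
        # dists[i, j] = sum_k |points[i, k] - centroids[j, k]|
        dists = np.abs(points[:, None, :] - centroids[None, :, :]).sum(axis=2)
        return dists.argmin(axis=1)

Note that this is not a full k-medians: the centroid update in the diff below
still takes the mean of each cluster; only the distance metric changes.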
kris 2021-11-11 11:10:22 +00:00
parent 8c34d87216
commit ee2229d0ea
3 changed files with 80 additions and 60 deletions


@@ -65,6 +65,10 @@ def cluster_palette(image: Image):
palette_cam = None
for palette_idx in range(16):
line_colours = palette_colours[palette_idx]
#if palette_idx < 15:
# line_colours += palette_colours[palette_idx + 1]
# if palette_idx < 14:
# line_colours += palette_colours[palette_idx + 2]
# if palette_idx > 0:
# fixed_centroids = palette_cam[:8, :]
# else:
@@ -72,6 +76,11 @@ def cluster_palette(image: Image):
# print(np.array(line_colours), fixed_centroids)
palette_cam = dither_pyx.k_means_with_fixed_centroids(16, np.array(
line_colours), fixed_centroids=fixed_centroids, tolerance=1e-6)
#kmeans = KMeans(n_clusters=16, max_iter=10000)
#kmeans.fit_predict(line_colours)
#palette_cam = kmeans.cluster_centers_
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
palette_rgb = colour.convert(palette_cam, "CAM16UCS", "RGB")
# SHR colour palette only uses 4-bit values


@@ -163,6 +163,11 @@ cdef inline float[::1] convert_rgb_to_cam16ucs(float[:, ::1] rgb_to_cam16ucs, fl
cdef int rgb_24bit = (<int>(r*255) << 16) + (<int>(g*255) << 8) + <int>(b*255)
return rgb_to_cam16ucs[rgb_24bit]
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline float fabs(float value) nogil:
return -value if value < 0 else value
@cython.boundscheck(False)
@cython.wraparound(False)
@@ -170,6 +175,12 @@ cdef inline float colour_distance_squared(float[::1] colour1, float[::1] colour2
return (colour1[0] - colour2[0]) ** 2 + (colour1[1] - colour2[1]) ** 2 + (colour1[2] - colour2[2]) ** 2
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline float colour_distance(float[::1] colour1, float[::1] colour2) nogil:
return fabs(colour1[0] - colour2[0]) + fabs(colour1[1] - colour2[1]) + fabs(colour1[2] - colour2[2])
# Perform error diffusion to a single image row.
#
# Args:
@@ -326,16 +337,17 @@ def dither_image(
import colour
@cython.boundscheck(False)
@cython.boundscheck(True)
@cython.wraparound(False)
def dither_shr(float[:, :, ::1] working_image, object palettes_rgb, float[:,::1] rgb_to_cam16ucs, object line_to_palette):
cdef int y, x, idx, best_colour_idx
cdef float best_distance, distance
cdef float[::1] best_colour_rgb, pixel_cam, colour_rgb, colour_cam
cdef float[3] quant_error
cdef float quant_error
cdef float[:, ::1] palette_rgb
cdef (unsigned char)[:, ::1] output_4bit = np.zeros((200, 320), dtype=np.uint8)
# cdef (unsigned char)[:, :, ::1] output_rgb = np.zeros((200, 320, 3), dtype=np.uint8)
for y in range(200):
print(y)
@@ -345,18 +357,19 @@ def dither_shr(float[:, :, ::1] working_image, object palettes_rgb, float[:,::1]
rgb_to_cam16ucs, working_image[y, x, 0], working_image[y, x, 1], working_image[y, x, 2])
best_distance = 1e9
best_colour_idx = 0
best_colour_idx = -1
for idx, colour_rgb in enumerate(palette_rgb):
colour_cam = convert_rgb_to_cam16ucs(rgb_to_cam16ucs, colour_rgb[0], colour_rgb[1], colour_rgb[2])
distance = colour_distance_squared(pixel_cam, colour_cam)
if distance < best_distance:
best_distance = distance
best_colour_rgb = colour_rgb
best_colour_idx = idx
best_colour_rgb = palette_rgb[best_colour_idx]
output_4bit[y, x] = best_colour_idx
for i in range(3):
quant_error[i] = working_image[y, x, i] - best_colour_rgb[i]
# output_rgb[y,x,i] = <int>(best_colour_rgb[i] * 255)
quant_error = working_image[y, x, i] - best_colour_rgb[i]
# Floyd-Steinberg dither
# 0 * 7
@@ -364,68 +377,66 @@ def dither_shr(float[:, :, ::1] working_image, object palettes_rgb, float[:,::1]
working_image[y, x, i] = best_colour_rgb[i]
if x < 319:
working_image[y, x + 1, i] = clip(
working_image[y, x + 1, i] + quant_error[i] * (7 / 16), 0, 1)
working_image[y, x + 1, i] + quant_error * (7 / 16), 0, 1)
if y < 199:
if x > 0:
working_image[y + 1, x - 1, i] = clip(
working_image[y + 1, x - 1, i] + quant_error[i] * (3 / 16), 0,
1)
working_image[y + 1, x - 1, i] + quant_error * (3 / 32), 0, 1)
working_image[y + 1, x, i] = clip(
working_image[y + 1, x, i] + quant_error[i] * (5 / 16), 0, 1)
working_image[y + 1, x, i] + quant_error * (5 / 32), 0, 1)
if x < 319:
working_image[y + 1, x + 1, i] = clip(
working_image[y + 1, x + 1, i] + quant_error[i] * (1 / 16),
0, 1)
working_image[y + 1, x + 1, i] + quant_error * (1 / 32), 0, 1)
# # 0 0 X 7 5
# # 3 5 7 5 3
# # 1 3 5 3 1
# if x < 319:
# working_image[y, x + 1, i] = clip(
# working_image[y, x + 1, i] + quant_error[i] * (7 / 48), 0, 1)
# if x < 318:
# working_image[y, x + 2, i] = clip(
# working_image[y, x + 2, i] + quant_error[i] * (5 / 48), 0, 1)
# if y < 199:
# if x > 1:
# working_image[y + 1, x - 2, i] = clip(
# working_image[y + 1, x - 2, i] + quant_error[i] * (3 / 48), 0,
# 1)
# if x > 0:
# working_image[y + 1, x - 1, i] = clip(
# working_image[y + 1, x - 1, i] + quant_error[i] * (5 / 48), 0,
# 1)
# working_image[y + 1, x, i] = clip(
# working_image[y + 1, x, i] + quant_error[i] * (7 / 48), 0, 1)
# if x < 319:
# working_image[y + 1, x + 1, i] = clip(
# working_image[y + 1, x + 1, i] + quant_error[i] * (5 / 48),
# 0, 1)
# if x < 318:
# working_image[y + 1, x + 2, i] = clip(
# working_image[y + 1, x + 2, i] + quant_error[i] * (3 / 48),
# 0, 1)
# if y < 198:
# if x > 1:
# working_image[y + 2, x - 2, i] = clip(
# working_image[y + 2, x - 2, i] + quant_error[i] * (1 / 48), 0,
# 1)
# if x > 0:
# working_image[y + 2, x - 1, i] = clip(
# working_image[y + 2, x - 1, i] + quant_error[i] * (3 / 48), 0,
# 1)
# working_image[y + 2, x, i] = clip(
# working_image[y + 2, x, i] + quant_error[i] * (5 / 48), 0, 1)
# if x < 319:
# working_image[y + 2, x + 1, i] = clip(
# working_image[y + 2, x + 1, i] + quant_error[i] * (3 / 48),
# 0, 1)
# if x < 318:
# working_image[y + 2, x + 2, i] = clip(
# working_image[y + 2, x + 2, i] + quant_error[i] * (1 / 48),
# 0, 1)
#if x < 319:
# working_image[y, x + 1, i] = clip(
# working_image[y, x + 1, i] + quant_error * (7 / 48), 0, 1)
#if x < 318:
# working_image[y, x + 2, i] = clip(
# working_image[y, x + 2, i] + quant_error * (5 / 48), 0, 1)
#if y < 199:
# if x > 1:
# working_image[y + 1, x - 2, i] = clip(
# working_image[y + 1, x - 2, i] + quant_error * (3 / 48), 0,
# 1)
# if x > 0:
# working_image[y + 1, x - 1, i] = clip(
# working_image[y + 1, x - 1, i] + quant_error * (5 / 48), 0,
# 1)
# working_image[y + 1, x, i] = clip(
# working_image[y + 1, x, i] + quant_error * (7 / 48), 0, 1)
# if x < 319:
# working_image[y + 1, x + 1, i] = clip(
# working_image[y + 1, x + 1, i] + quant_error * (5 / 48),
# 0, 1)
# if x < 318:
# working_image[y + 1, x + 2, i] = clip(
# working_image[y + 1, x + 2, i] + quant_error * (3 / 48),
# 0, 1)
#if y < 198:
# if x > 1:
# working_image[y + 2, x - 2, i] = clip(
# working_image[y + 2, x - 2, i] + quant_error * (1 / 48), 0,
# 1)
# if x > 0:
# working_image[y + 2, x - 1, i] = clip(
# working_image[y + 2, x - 1, i] + quant_error * (3 / 48), 0,
# 1)
# working_image[y + 2, x, i] = clip(
# working_image[y + 2, x, i] + quant_error * (5 / 48), 0, 1)
# if x < 319:
# working_image[y + 2, x + 1, i] = clip(
# working_image[y + 2, x + 1, i] + quant_error * (3 / 48),
# 0, 1)
# if x < 318:
# working_image[y + 2, x + 2, i] = clip(
# working_image[y + 2, x + 2, i] + quant_error * (1 / 48),
# 0, 1)
return np.array(output_4bit, dtype=np.uint8)
return np.array(output_4bit, dtype=np.uint8) #, np.array(output_rgb, dtype=np.uint8)
import collections
import random
@@ -457,7 +468,7 @@ def k_means_with_fixed_centroids(
best_dist = 1e9
best_centroid_idx = 0
for centroid_idx, centroid in enumerate(centroids):
dist = colour_distance_squared(centroid, point)
dist = colour_distance(centroid, point)
if dist < best_dist:
best_dist = dist
best_centroid_idx = centroid_idx
@@ -471,7 +482,7 @@ def k_means_with_fixed_centroids(
continue
new_centroid = np.mean(np.array(points), axis=0)
old_centroid = centroids[centroid_idx]
centroid_movement += colour_distance_squared(old_centroid, new_centroid)
centroid_movement += colour_distance(old_centroid, new_centroid)
centroids[centroid_idx, :] = new_centroid
# print("iteration %d: movement %f" % (iteration, centroid_movement))
if centroid_movement < tolerance:


@@ -28,7 +28,7 @@ class SHR320Screen:
# XXX check element range
if palette.dtype != np.uint8:
raise ValueError("Palette must be of type np.uint8")
print(palette)
# print(palette)
self.palettes[idx] = np.array(palette)
def set_pixels(self, pixels):
@@ -58,7 +58,7 @@ class SHR320Screen:
# print(r, g, b)
rgb_low = (g << 4) | b
rgb_hi = r
print(hex(rgb_hi), hex(rgb_low))
# print(hex(rgb_hi), hex(rgb_low))
palette_idx_offset = palette_offset + (32 * palette_idx)
dump[palette_idx_offset + (2 * rgb_idx)] = rgb_low
dump[palette_idx_offset + (2 * rgb_idx + 1)] = rgb_hi