WIP - interleave 3 successive palettes for each contiguous row range.

This avoids the banding, but it's not clear whether the result is
better overall.
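
Concretely, each row gets the palette for its 200/16 = 12.5-row band,
offset by row % 3, so three successive palettes interleave within every
band. A small illustrative sketch of the mapping (variable names are
mine; the real loop is in the first hunk below):

for line in range(6):
    base = int(line / (200 / 16))       # which of the 16 palette bands
    palette = min(base + line % 3, 15)  # offset by 0/1/2, clamped to 15
    print(line, palette)                # -> 0 0, 1 1, 2 2, 3 0, 4 1, 5 2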

Also implement my own k-means clustering, which is able to keep some
centroids fixed, e.g. to retain some palette entries while swapping
out others.  I was hoping this would improve colour blending across
neighbouring palettes, but it's also not clear that it does.
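
The intended wiring is currently commented out in the second hunk
below: pin the first 8 entries of the previous palette as fixed
centroids so that neighbouring palettes share half their colours. A
hypothetical sketch of that usage, reusing palette_colours / dither_pyx
from the diff (the committed loop still passes fixed_centroids=None):

prev_palette_cam = None
for palette_idx in range(16):
    line_colours = np.array(palette_colours[palette_idx], dtype=np.float32)
    # Pin the first 8 centroids to the previous palette (if any); the
    # remaining 8 are free to be re-fitted for this line grouping.
    fixed = prev_palette_cam[:8, :] if prev_palette_cam is not None else None
    prev_palette_cam = dither_pyx.k_means_with_fixed_centroids(
        16, line_colours, fixed_centroids=fixed, tolerance=1e-6)
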
kris 2021-11-10 18:30:39 +00:00
parent 322123522c
commit 8c34d87216
2 changed files with 94 additions and 14 deletions

@@ -23,13 +23,29 @@ import screen as screen_py
# - support LR/DLR
# - support HGR
def cluster_palette(image: Image):
    shuffle_lines = list(range(200))
    random.shuffle(shuffle_lines)
    line_to_palette = {}
    for idx, line in enumerate(shuffle_lines):
        line_to_palette[line] = idx % 16

    # shuffle_lines = list(range(200))
    # random.shuffle(shuffle_lines)
    # for idx, line in enumerate(shuffle_lines):
    #     line_to_palette[line] = idx % 16

    # for line in range(200):
    #     if line % 3 == 0:
    #         line_to_palette[line] = int(line / (200 / 16))
    #     elif line % 3 == 1:
    #         line_to_palette[line] = np.clip(int(line / (200 / 16)) + 1, 0, 15)
    #     else:
    #         line_to_palette[line] = np.clip(int(line / (200 / 16)) + 2, 0, 15)

    for line in range(200):
        if line % 3 == 0:
            line_to_palette[line] = int(line / (200 / 16))
        elif line % 3 == 1:
            line_to_palette[line] = np.clip(int(line / (200 / 16)) + 1, 0, 15)
        else:
            line_to_palette[line] = np.clip(int(line / (200 / 16)) + 2, 0, 15)

    colours_rgb = np.asarray(image).reshape((-1, 3))

    with colour.utilities.suppress_warnings(colour_usage_warnings=True):
@@ -43,16 +59,25 @@ def cluster_palette(image: Image):
        palette_colours[palette].extend(
            colours_cam[line * 320:(line + 1) * 320])

    for palette in range(16):
        kmeans = KMeans(n_clusters=16, max_iter=10000)
        kmeans.fit_predict(palette_colours[palette])
        palette_cam = kmeans.cluster_centers_

    # For each line grouping, find big palette entries with minimal total
    # distance
    palette_cam = None
    for palette_idx in range(16):
        line_colours = palette_colours[palette_idx]
        # if palette_idx > 0:
        #     fixed_centroids = palette_cam[:8, :]
        # else:
        fixed_centroids = None
        # print(np.array(line_colours), fixed_centroids)
        palette_cam = dither_pyx.k_means_with_fixed_centroids(16, np.array(
            line_colours), fixed_centroids=fixed_centroids, tolerance=1e-6)

        with colour.utilities.suppress_warnings(colour_usage_warnings=True):
            palette_rgb = colour.convert(palette_cam, "CAM16UCS", "RGB")
            # SHR colour palette only uses 4-bit values
            palette_rgb = np.round(palette_rgb * 15) / 15
            # palette_rgb = palette_rgb.astype(np.float32) / 255
            palettes_rgb[palette] = palette_rgb.astype(np.float32)
            palettes_rgb[palette_idx] = palette_rgb.astype(np.float32)
    # print(palettes_rgb)

    return palettes_rgb, line_to_palette
@@ -130,8 +155,8 @@ def main():
    for i in range(200):
        screen.line_palette[i] = line_to_palette[i]
        output_rgb[i, :, :] = (
            palettes_rgb[line_to_palette[i]][
                output_4bit[i, :]] * 255).astype(np.uint8)
            palettes_rgb[line_to_palette[i]][
                output_4bit[i, :]] * 255).astype(np.uint8)
    output_srgb = image_py.linear_to_srgb(output_rgb).astype(np.uint8)

    # dither = dither_pattern.PATTERNS[args.dither]()

@@ -425,4 +425,59 @@ def dither_shr(float[:, :, ::1] working_image, object palettes_rgb, float[:,::1]
            #     working_image[y + 2, x + 2, i] + quant_error[i] * (1 / 48),
            #     0, 1)

    return np.array(output_4bit, dtype=np.uint8)
    return np.array(output_4bit, dtype=np.uint8)

import collections
import random


def k_means_with_fixed_centroids(
        int n_clusters, float[:, ::1] data, float[:, ::1] fixed_centroids = None,
        int iterations = 10000, float tolerance = 1e-3):
    cdef int i, iteration, centroid_idx, num_fixed_centroids, num_random_centroids, best_centroid_idx
    cdef float[::1] point, centroid, new_centroid, old_centroid
    cdef float[:, ::1] centroids
    cdef float best_dist, centroid_movement

    # Seed the first entries from the caller's fixed centroids; only the
    # remaining entries are re-fitted below.
    centroids = np.zeros((n_clusters, 3), dtype=np.float32)
    if fixed_centroids is not None:
        centroids[:fixed_centroids.shape[0], :] = fixed_centroids
    num_fixed_centroids = fixed_centroids.shape[0] if fixed_centroids is not None else 0
    num_random_centroids = n_clusters - num_fixed_centroids

    # TODO: kmeans++ initialization
    for i in range(num_random_centroids):
        # random.randint is inclusive at both ends, so cap at shape[0] - 1
        centroids[num_fixed_centroids + i, :] = data[
            random.randint(0, data.shape[0] - 1), :]

    cdef int[::1] centroid_weights = np.zeros(n_clusters, dtype=np.int32)
    for iteration in range(iterations):
        # print("centroids ", centroids)
        # Assignment step: attach each point to its nearest centroid.
        closest_points = collections.defaultdict(list)
        for point in data:
            best_dist = 1e9
            best_centroid_idx = 0
            for centroid_idx, centroid in enumerate(centroids):
                dist = colour_distance_squared(centroid, point)
                if dist < best_dist:
                    best_dist = dist
                    best_centroid_idx = centroid_idx
            closest_points[best_centroid_idx].append(point)

        # Update step: move non-fixed centroids to the mean of their points.
        centroid_movement = 0
        for centroid_idx, points in closest_points.items():
            centroid_weights[centroid_idx] = len(points)
            if centroid_idx < num_fixed_centroids:
                continue
            new_centroid = np.mean(np.array(points), axis=0)
            old_centroid = centroids[centroid_idx]
            centroid_movement += colour_distance_squared(old_centroid, new_centroid)
            centroids[centroid_idx, :] = new_centroid

        # print("iteration %d: movement %f" % (iteration, centroid_movement))
        if centroid_movement < tolerance:
            break

    # Return the centroids ordered by descending cluster size.
    weighted_centroids = list(zip(centroid_weights, [tuple(c) for c in centroids]))
    print(weighted_centroids)
    return np.array(
        [c for w, c in sorted(weighted_centroids, reverse=True)],
        dtype=np.float32)