mirror of
https://github.com/KrisKennaway/ii-pix.git
synced 2025-01-30 15:36:53 +00:00
Refactor
This commit is contained in:
parent
6988b19b43
commit
0323b80e68
158
convert.py
158
convert.py
@ -28,7 +28,9 @@ import screen as screen_py
|
||||
|
||||
class ClusterPalette:
|
||||
def __init__(
|
||||
self, image: Image, rgb12_iigs_to_cam16ucs, reserved_colours=0):
|
||||
self, image: Image, rgb12_iigs_to_cam16ucs, rgb24_to_cam16ucs,
|
||||
reserved_colours=0):
|
||||
self._image_rgb = image
|
||||
self._colours_cam = self._image_colours_cam(image)
|
||||
|
||||
self._errors = [1e9] * 16
|
||||
@ -55,6 +57,8 @@ class ClusterPalette:
|
||||
# colour space
|
||||
self._rgb12_iigs_to_cam16ucs = rgb12_iigs_to_cam16ucs
|
||||
|
||||
self._rgb24_to_cam16ucs = rgb24_to_cam16ucs
|
||||
|
||||
# List of line ranges used to train the 16 SHR palettes
|
||||
# [(lower_0, upper_0), ...]
|
||||
self._palette_splits = self._palette_splits()
|
||||
@ -68,6 +72,75 @@ class ClusterPalette:
|
||||
# Delta applied to palette split in previous iteration
|
||||
self._palette_mutate_delta = (0, 0)
|
||||
|
||||
def iterate(self, penalty: float, max_iterations: int):
|
||||
iterations_since_improvement = 0
|
||||
total_image_error = 1e9
|
||||
|
||||
last_good_splits = self._palette_splits
|
||||
|
||||
while iterations_since_improvement < max_iterations:
|
||||
# print("Iterations %d" % iterations_since_improvement)
|
||||
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
|
||||
self._propose_palettes())
|
||||
|
||||
# Suppress divide by zero warning,
|
||||
# https://github.com/colour-science/colour/issues/900
|
||||
with colour.utilities.suppress_warnings(python_warnings=True):
|
||||
new_palettes_linear_rgb = colour.convert(
|
||||
new_palettes_cam, "CAM16UCS", "RGB").astype(np.float32)
|
||||
|
||||
# Recompute image with proposed palettes and check whether it has
|
||||
# lower total image error than our previous best.
|
||||
new_output_4bit, new_line_to_palette, new_total_image_error = \
|
||||
dither_pyx.dither_shr(
|
||||
self._image_rgb, new_palettes_cam, new_palettes_linear_rgb,
|
||||
self._rgb24_to_cam16ucs, float(penalty))
|
||||
|
||||
# print(total_image_error, new_total_image_error,
|
||||
# self._palette_splits)
|
||||
|
||||
# TODO: extract this into a function
|
||||
palettes_used = [False] * 16
|
||||
for palette in new_line_to_palette:
|
||||
palettes_used[palette] = True
|
||||
for palette_idx, palette_used in enumerate(palettes_used):
|
||||
if palette_used:
|
||||
continue
|
||||
print("Reassigning palette %d" % palette_idx)
|
||||
max_width = 0
|
||||
split_palette_idx = -1
|
||||
idx = 0
|
||||
for lower, upper in last_good_splits:
|
||||
width = upper - lower
|
||||
if width > max_width:
|
||||
split_palette_idx = idx
|
||||
idx += 1
|
||||
|
||||
lower, upper = last_good_splits[split_palette_idx]
|
||||
if upper - lower > 20:
|
||||
mid = (lower + upper) // 2
|
||||
self._palette_splits[split_palette_idx] = (
|
||||
lower, mid - 1)
|
||||
self._palette_splits[palette_idx] = (mid, upper)
|
||||
else:
|
||||
lower = np.random.randint(0, 199)
|
||||
upper = np.random.randint(lower, 200)
|
||||
self._palette_splits[palette_idx] = (lower, upper)
|
||||
|
||||
if new_total_image_error >= total_image_error:
|
||||
iterations_since_improvement += 1
|
||||
continue
|
||||
|
||||
# We found a globally better set of palettes
|
||||
iterations_since_improvement = 0
|
||||
last_good_splits = self._palette_splits
|
||||
total_image_error = new_total_image_error
|
||||
self._accept_palettes(
|
||||
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors)
|
||||
|
||||
yield (new_total_image_error, new_output_4bit, new_line_to_palette,
|
||||
new_palettes_rgb12_iigs, new_palettes_linear_rgb)
|
||||
|
||||
def _image_colours_cam(self, image: Image):
|
||||
colours_rgb = np.asarray(image).reshape((-1, 3))
|
||||
with colour.utilities.suppress_warnings(colour_usage_warnings=True):
|
||||
@ -163,7 +236,7 @@ class ClusterPalette:
|
||||
self._apply_palette_delta(palette_to_mutate, palette_lower_delta,
|
||||
palette_upper_delta)
|
||||
|
||||
def propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
|
||||
def _propose_palettes(self) -> Tuple[np.ndarray, np.ndarray, List[float]]:
|
||||
"""Attempt to find new palettes that locally improve image quality.
|
||||
|
||||
Re-fit a set of 16 palettes from (overlapping) line ranges of the
|
||||
@ -226,7 +299,7 @@ class ClusterPalette:
|
||||
self._palettes_accepted = False
|
||||
return new_palettes_cam, new_palettes_rgb12_iigs, new_errors
|
||||
|
||||
def accept_palettes(
|
||||
def _accept_palettes(
|
||||
self, new_palettes_cam: np.ndarray,
|
||||
new_palettes_rgb: np.ndarray, new_errors: List[float]):
|
||||
self._palettes_cam = np.copy(new_palettes_cam)
|
||||
@ -297,7 +370,7 @@ def main():
|
||||
|
||||
# TODO: flags
|
||||
penalty = 1 # 1e18 # TODO: is this needed any more?
|
||||
iterations = 200
|
||||
iterations = 20# 0
|
||||
|
||||
pygame.init()
|
||||
# TODO: for some reason I need to execute this twice - the first time
|
||||
@ -307,83 +380,22 @@ def main():
|
||||
canvas.fill((0, 0, 0))
|
||||
pygame.display.flip()
|
||||
|
||||
total_image_error = 1e9
|
||||
iterations_since_improvement = 0
|
||||
|
||||
total_image_error = None
|
||||
# TODO: reserved_colours should be a flag
|
||||
cluster_palette = ClusterPalette(
|
||||
rgb, reserved_colours=1, rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs)
|
||||
last_good_splits = cluster_palette._palette_splits
|
||||
rgb, reserved_colours=1,
|
||||
rgb12_iigs_to_cam16ucs=rgb12_iigs_to_cam16ucs,
|
||||
rgb24_to_cam16ucs=rgb24_to_cam16ucs)
|
||||
|
||||
while iterations_since_improvement < iterations:
|
||||
# print("Iterations %d" % iterations_since_improvement)
|
||||
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors = (
|
||||
cluster_palette.propose_palettes())
|
||||
for (new_total_image_error, output_4bit, line_to_palette,
|
||||
palettes_rgb12_iigs, palettes_linear_rgb) in cluster_palette.iterate(
|
||||
penalty, iterations):
|
||||
|
||||
# Suppress divide by zero warning,
|
||||
# https://github.com/colour-science/colour/issues/900
|
||||
with colour.utilities.suppress_warnings(python_warnings=True):
|
||||
new_palettes_linear_rgb = colour.convert(
|
||||
new_palettes_cam, "CAM16UCS", "RGB").astype(np.float32)
|
||||
|
||||
# Recompute image with proposed palettes and check whether it has
|
||||
# lower total image error than our previous best.
|
||||
new_output_4bit, new_line_to_palette, new_total_image_error = \
|
||||
dither_pyx.dither_shr(
|
||||
rgb, new_palettes_cam, new_palettes_linear_rgb,
|
||||
rgb24_to_cam16ucs, float(penalty))
|
||||
|
||||
# print(total_image_error, new_total_image_error,
|
||||
# cluster_palette._palette_splits)
|
||||
|
||||
# TODO: move this into ClusterPalettes
|
||||
palettes_used = [False] * 16
|
||||
for palette in new_line_to_palette:
|
||||
palettes_used[palette] = True
|
||||
for palette_idx, palette_used in enumerate(palettes_used):
|
||||
if palette_used:
|
||||
continue
|
||||
print("Reassigning palette %d" % palette_idx)
|
||||
max_width = 0
|
||||
split_palette_idx = -1
|
||||
idx = 0
|
||||
for lower, upper in last_good_splits:
|
||||
width = upper - lower
|
||||
if width > max_width:
|
||||
split_palette_idx = idx
|
||||
idx += 1
|
||||
|
||||
lower, upper = last_good_splits[split_palette_idx]
|
||||
if upper - lower > 20:
|
||||
mid = (lower + upper) // 2
|
||||
cluster_palette._palette_splits[split_palette_idx] = (
|
||||
lower, mid - 1)
|
||||
cluster_palette._palette_splits[palette_idx] = (mid, upper)
|
||||
else:
|
||||
lower = np.random.randint(0, 199)
|
||||
upper = np.random.randint(lower, 200)
|
||||
cluster_palette._palette_splits[palette_idx] = (lower, upper)
|
||||
|
||||
if new_total_image_error >= total_image_error:
|
||||
iterations_since_improvement += 1
|
||||
continue
|
||||
|
||||
# We found a globally better set of palettes
|
||||
iterations_since_improvement = 0
|
||||
cluster_palette.accept_palettes(
|
||||
new_palettes_cam, new_palettes_rgb12_iigs, new_palette_errors)
|
||||
last_good_splits = cluster_palette._palette_splits
|
||||
|
||||
if total_image_error < 1e9:
|
||||
if total_image_error is not None:
|
||||
print("Improved quality +%f%% (%f)" % (
|
||||
(1 - new_total_image_error / total_image_error) * 100,
|
||||
new_total_image_error))
|
||||
# print(cluster_palette._palette_splits)
|
||||
output_4bit = new_output_4bit
|
||||
line_to_palette = new_line_to_palette
|
||||
total_image_error = new_total_image_error
|
||||
palettes_rgb12_iigs = new_palettes_rgb12_iigs
|
||||
palettes_linear_rgb = new_palettes_linear_rgb
|
||||
for i in range(16):
|
||||
screen.set_palette(i, palettes_rgb12_iigs[i, :, :])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user