From 938fbae4885233631213dd157927e32e8addcf99 Mon Sep 17 00:00:00 2001 From: kris Date: Mon, 24 Aug 2020 22:28:28 +0100 Subject: [PATCH] Optimize evolve() by expanding recurrence relation --- encode_audio.py | 66 +++++++++++++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/encode_audio.py b/encode_audio.py index d2e75c1..8f60113 100755 --- a/encode_audio.py +++ b/encode_audio.py @@ -24,13 +24,16 @@ # compensate for this "dead" period by pre-positioning. import collections +import functools import sys import librosa import numpy from eta import ETA +from typing import Tuple import opcodes + # TODO: add flags to parametrize options @@ -63,29 +66,35 @@ def lookahead(step_size: int, initial_position: float, data: numpy.ndarray, return best -# TODO: share implementation with lookahead -def evolve(opcode: opcodes.Opcode, starting_position, starting_voltage, - step_size, data, starting_idx): - """Apply the effects of playing a single opcode to completion. +@functools.lru_cache(None) +def _delta_powers(length: int, step_size: int) -> Tuple[float, numpy.ndarray]: + delta = (1 - 1 / step_size) + return delta, numpy.cumprod(numpy.full(length, delta)) - Returns new state. - """ + +def evolve(opcode: opcodes.Opcode, starting_position, starting_voltage, + step_size, data, starting_idx): + # The speaker position p_i evolves according to + # p_{i+1} = p_i + (v_i - p_i) / s + # where v_i is the i'th applied voltage, s is the speaker step size + # + # Rearranging, we get p_{i+1} = v_i / s + (1-1/s) p_i + # and if we expand the recurrence relation + # p_{i+1} = Sum_{j=0}^i (1-1/s)^(i-j) v_j / s + (1-1/s)^(i+1) p_0 + # = (1-1/s)^(i+1)(1/s * Sum_{j=0}^i v_j / (1-1/s)^(j+1) + p0) opcode_length = opcodes.cycle_length(opcode) voltages = starting_voltage * opcodes.VOLTAGE_SCHEDULE[opcode] - position = starting_position - total_err = 0.0 - v = starting_voltage - last_v = v - num_flips = 0 - for i, v in enumerate(voltages): - if v != last_v: - num_flips += 1 - last_v = v - position += (v - position) / step_size - err = position - data[starting_idx + i] - total_err += err ** 2 - return position, v, total_err, starting_idx + opcode_length, num_flips + delta, delta_powers = _delta_powers(opcode_length, step_size) + + positions = delta_powers * ( + numpy.cumsum(voltages / delta_powers) / step_size + + starting_position) + + # TODO: compute error once at the end? + total_err = numpy.sum(numpy.power( + positions - data[starting_idx:starting_idx + opcode_length], 2)) + return positions[-1], voltages[-1], total_err, starting_idx + opcode_length def audio_bytestream(data: numpy.ndarray, step: int, lookahead_steps: int): @@ -104,8 +113,7 @@ def audio_bytestream(data: numpy.ndarray, step: int, lookahead_steps: int): i = 0 last_updated = 0 opcode_counts = collections.defaultdict(int) - num_flips = 0 - while i < int(dlen/10): + while i < int(dlen / 10): if (i - last_updated) > int((dlen / 1000)): eta.print_status() last_updated = i @@ -117,12 +125,23 @@ def audio_bytestream(data: numpy.ndarray, step: int, lookahead_steps: int): opcode = candidate_opcodes[opcode_idx].opcodes[0] opcode_counts[opcode] += 1 yield opcode + # print(opcode, position, voltage) - position, voltage, new_error, i, new_flips = evolve( + position, voltage, new_error, i = evolve2( opcode, position, voltage, step, data, i) + # position2, voltage2, new_error2, i2 = evolve2( + # opcode, position, voltage, step, data, i) + # print(opcode, position, voltage) + # assert i1 == i2, (i1, i2) + # assert voltage1 == voltage2, (voltage, voltage2) + # assert abs(position1 - position2) < 1e-7, (position1, position2) + # print(position1 - position2) + # i = i1 + # voltage = voltage1 + # position = position1 + total_err += new_error - num_flips += new_flips frame_offset = (frame_offset + 1) % 2048 for _ in range(frame_offset % 2048, 2047): @@ -130,7 +149,6 @@ def audio_bytestream(data: numpy.ndarray, step: int, lookahead_steps: int): yield opcodes.Opcode.EXIT eta.done() print("Total error %f" % total_err) - print("%d speaker actuations" % num_flips) print("Opcodes used:") for v, k in sorted(list(opcode_counts.items()), key=lambda kv: kv[1],