View File

@ -1,6 +1,8 @@
# Millfork
A middle-level programming language targeting 6502-based microcomputers.
Distributed under GPLv3 (see [LICENSE](LICENSE))

build.sbt Normal file
View File

@ -0,0 +1,35 @@
name := "millfork"
version := "0.0.1-SNAPSHOT"
scalaVersion := "2.12.3"
resolvers += Resolver.mavenLocal
libraryDependencies += "com.lihaoyi" %% "fastparse" % "1.0.0"
libraryDependencies += "org.apache.commons" % "commons-configuration2" % "2.2"
libraryDependencies += "org.scalactic" %% "scalactic" % "3.0.4"
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.4" % "test"
// these two not in Maven Central or any other public repo
// get them from the following links or just build millfork without tests:
libraryDependencies += "com.loomcom.symon" % "symon" % "1.3.0-SNAPSHOT" % "test"
libraryDependencies += "com.grapeshot" % "halfnes" % "061" % "test"
mainClass in Compile := Some("millfork.Main")
assemblyJarName := "millfork.jar"
//lazy val root = (project in file(".")).
// enablePlugins(BuildInfoPlugin).
// settings(
// buildInfoKeys := Seq[BuildInfoKey](name, version, scalaVersion, sbtVersion),
// buildInfoPackage := "hello"
// )

doc/ Normal file
View File

@ -0,0 +1,7 @@
# Documentation
## Tutorial
* [Getting started](tutorial/
* [Basic functions and variables](tutorial/

doc/ Normal file
View File

@ -0,0 +1,93 @@
# Target platforms
Currently, Millfork supports creating disk- or tape-based programs for Commodore and Atari 8-bit computers,
but it may be expanded to support other 6502-based platforms in the future.
## Supported platforms
The following platforms are currently supported:
* `c64` Commodore 64
* `c16` Commodore 16
* `plus4` Commodore Plus/4
* `vic20` Commodore VIC-20 without memory expansion
* `vic20_3k` Commodore VIC-20 with 3K memory expansion
* `vic20_8k` Commodore VIC-20 with 8K or 16K memory expansion
* `c128` Commodore 128 in its native mode
* `pet` Commodore PET
* `a8` Atari 8-bit computers
The primary and most tested platform is Commodore 64.
Currently, all targets assume that the program will be loaded from disk or tape.
Cartridge targets are not yet available.
## Adding a custom platform
Every platform is defined in an `.ini` file with an appropriate name.
#### `[compilation]` section
* `arch` CPU architecture. It defines which instructions are available. Available values:
    * `nmos`
    * `strict` (= NMOS without illegal instructions)
    * `ricoh` (= NMOS without decimal mode)
    * `strictricoh`
    * `cmos` (= 65C02)
* `modules` comma-separated list of modules that will be automatically imported
* other compilation options (they can be overridden using commandline options):
    * `emit_illegals` whether the compiler should emit illegal instructions, default `false`
    * `emit_cmos` whether the compiler should emit CMOS instructions, default is `true` on `cmos` and `false` elsewhere
    * `decimal_mode` whether the compiler should emit decimal instructions, default is `false` on `ricoh` and `strictricoh` and `true` elsewhere
    * `ro_arrays` whether the compiler should warn upon array writes, default is `false`
    * `prevent_jmp_indirect_bug` whether the compiler should try to avoid the indirect JMP bug, default is `false` on `cmos` and `true` elsewhere
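
As an illustration, a `[compilation]` section for a hypothetical NMOS-based platform might look like the sketch below. The key names follow the list above and the module names are taken from files in the include directory. The chosen values are only examples, and writing booleans as `true`/`false` is an assumption.

```ini
; hypothetical platform definition, for illustration only
[compilation]
arch=nmos
modules=c64_hardware,c64_kernal,stdlib
; optional flag overrides (example values, not recommendations)
emit_illegals=false
ro_arrays=false
```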
#### `[allocation]` section
* `main_org` the address for the `main` function; all the other functions will be placed after it
* `zp_pointers` either a comma-separated list of zeropage addresses that can be used by the program as zeropage pointers, or `all` for all. Each value should be the address of the first of two free bytes in the zeropage.
* `himem_style` not yet supported
* `himem_start` the first address used for non-zeropage variables, or `after_code` if the variables should be allocated after the code
* `himem_end` the last address available for non-zeropage variables
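
A possible `[allocation]` section is sketched below. All addresses are illustrative assumptions rather than a tested configuration, and the sketch assumes that addresses may be written in the `$`-prefixed hexadecimal notation used in Millfork sources.

```ini
; illustrative values only
[allocation]
; hypothetical start address of the main function
main_org=$80D
; each entry is the first of two free zeropage bytes
zp_pointers=$FB,$FD
; place non-zeropage variables right after the code...
himem_start=after_code
; ...and let them extend up to this address (inclusive)
himem_end=$9FFF
```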
#### `[output]` section
* `style` not yet supported
* `format` output file format; a comma-separated list of tokens:
    * literal byte values
    * `startaddr` little-endian 16-bit address of the first used byte of the compiled output
    * `endaddr` little-endian 16-bit address of the last used byte of the compiled output
    * `allocated` all used bytes
    * `<addr>:<addr>` - inclusive range of bytes
* `extension` target file extension, with or without the dot
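
For example, a Commodore-style program file, which starts with a two-byte load address followed by the program bytes, could plausibly be described as follows. This is a minimal sketch built from the tokens listed above, not a verified platform definition.

```ini
[output]
; two-byte start address, then every allocated byte
format=startaddr,allocated
extension=prg
```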

View File

@ -0,0 +1,54 @@
# Getting started
## Hello world example
Save the following as ``:

    import stdio
    array hello_world = "hello world" petscii
    void main(){
        putstr(hello_world, hello_world.length)
    }

Compile it using the following command line:

    java -jar millfork.jar -o hello_world -t c64 -I path_to_millfork\include

Run the output executable (here using the VICE emulator):

    x64 hello_world.prg

## Basic commandline usage
The following options are crucial when compiling your sources:
* `-o FILENAME` specifies the base name for your output file; an appropriate file extension will be appended (`prg` for Commodore, `xex` for Atari, `asm` for assembly output, `lbl` for label file)
* `-I DIR;DIR;DIR;...` specifies the paths to directories with modules to include.
* `-t PLATFORM` specifies the target platform (`c64` is the default). Each platform is defined in an `.ini` file in the include directory. For the list of supported platforms, see [Supported platforms](../
You may also be interested in the following:
* `-O`, `-O2`, `-O3` enable optimization (various levels)
* `--detailed-flow` use more resource-consuming but more precise flow analysis engine for better optimization
* `-s` additionally generate assembly output
* `-g` additionally generate a label file, in format compatible with VICE emulator
* `-r PROGRAM` automatically launch given program after successful compilation
* `-Wall` enable all warnings
* `--help` list all commandline options

View File

@ -0,0 +1,16 @@
# Functions and variables
TODO: write all of this
## Basic types
## Defining variables
## Built-in operators
### Byte operators
| a | a | a |
| -- | -- | -- |
| a | a | a |

View File

@ -0,0 +1,11 @@
// compile with
// java -jar millfork.jar -I ${PATH}/include -t ${platform} ${PATH}/examples/hello_world/hello_world.mfk
import stdio
array hello_world = "hello world" petscii
void main(){
    putstr(hello_world, hello_world.length)
}

include/a8.ini Normal file
View File

@ -0,0 +1,22 @@

include/a8_kernel.mfk Normal file
View File

@ -0,0 +1,9 @@
asm void putchar(byte a) {
lda $347
lda $346

include/c128.ini Normal file
View File

@ -0,0 +1,20 @@

View File

@ -0,0 +1,5 @@
import c64_vic
import c64_sid
import c64_cia
array c64_color_ram [1000] @$D800

include/c128_kernal.mfk Normal file
View File

@ -0,0 +1,5 @@
// Routines from Commodore 128 KERNAL ROM
// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehand.)
// Input: A = Byte to write.
asm void putchar(byte a) @$FFD2 extern

include/c1531.mfk Normal file
View File

@ -0,0 +1,60 @@
// mouse driver for Commodore 1531 mouse on Commodore 64
import mouse
import c64_hardware
sbyte _c1531_calculate_delta (byte old, byte new) {
byte mouse_delta
mouse_delta = (new - old)
mouse_delta &= $3f
if mouse_delta >= $20 {
mouse_delta |= $c0
return mouse_delta
byte _c1531_handle_x() {
static byte _c1531_old_pot_x
sbyte mouse_delta
byte new_pot_x
new_pot_x = sid_paddle_x >> 1
mouse_delta = _c1531_calculate_delta(_c1531_old_pot_x, new_pot_x)
_c1531_old_pot_x = new_pot_x
mouse_x += mouse_delta
mouse_x.hi &= 1
if mouse_x > 319 {
if mouse_delta > 0 {
mouse_x = 319
} else {
mouse_x = 0
byte _c1531_handle_y() {
static byte _c1531_old_pot_y
byte new_pot_y
sbyte mouse_delta
new_pot_y = sid_paddle_y >> 1
mouse_delta = _c1531_calculate_delta(_c1531_old_pot_y, new_pot_y)
_c1531_old_pot_y = new_pot_y
mouse_y -= mouse_delta
if mouse_y > 199 {
if mouse_delta > 0 {
mouse_y = 0
} else {
mouse_y = 199
void c1531_mouse () {
cia1_pra = ($3f & cia1_pra) | $40

include/c16.ini Normal file
View File

@ -0,0 +1,19 @@

View File

@ -0,0 +1 @@
import c16_ted

include/c264_kernal.mfk Normal file
View File

@ -0,0 +1,5 @@
// Routines from C16 and Plus/4 KERNAL ROM
// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehand.)
// Input: A = Byte to write.
asm void putchar(byte a) @$FFD2 extern

include/c264_ted.mfk Normal file
View File

@ -0,0 +1,20 @@
const byte black = 0
const byte white = $71
const byte red = $22
const byte cyan = $43
const byte purple = $24
const byte green = $35
const byte blue = $16
const byte yellow = $57
const byte orange = $28
const byte brown = $19
const byte light_red = $32
const byte dark_grey = $21
const byte dark_gray = $21
const byte medium_grey = $31
const byte medium_gray = $31
const byte light_green = $55
const byte light_blue = $36
const byte light_grey = $41
const byte light_gray = $41

include/c64.ini Normal file
View File

@ -0,0 +1,37 @@
; Commodore 64
; assuming a program loaded from disk or tape
; CPU architecture: nmos, strict, ricoh, strictricoh, cmos
; modules to load
; optionally: default flags
; where the main function should be allocated, also the start of bank 0
; list of free zp pointer locations (these assume that BASIC will keep working)
; where to allocate non-zp variables
; how the banks are laid out in the output files; so far, there is no bank support in the compiler yet
; output file format
; startaddr - little-endian address of the first used byte in the bank
; endaddr - little-endian address of the last used byte in the bank
; allocated - all used bytes in the bank
; <addr>:<addr> - bytes from the current bank
; <bank>:<addr>:<addr> - bytes from arbitrary bank
; <byte> - single byte
; default output file extension

include/c64_basic.mfk Normal file
View File

@ -0,0 +1,6 @@
// Routines from C64 BASIC ROM
import c64_kernal
// print a 16-bit number on the standard output
asm void putword(word xa) @$BDCD extern

include/c64_cia.mfk Normal file
View File

@ -0,0 +1,40 @@
// Hardware addresses for C64
// CIA1
byte cia1_pra @$DC00
byte cia1_prb @$DC01
byte cia1_ddra @$DC02
byte cia1_ddrb @$DC03
byte cia2_pra @$DD00
byte cia2_prb @$DD01
byte cia2_ddra @$DD02
byte cia2_ddrb @$DD03
inline asm void cia_disable_irq() {
    LDA #$7f
    STA $dc0d
    STA $dd0d
    LDA $dc0d
    LDA $dd0d
}
inline void vic_bank_0000() {
    cia2_ddra = $C0
    cia2_pra = $C0
}
inline void vic_bank_4000() {
    cia2_ddra = $C0
    cia2_pra = $80
}
inline void vic_bank_8000() {
    cia2_ddra = $C0
    cia2_pra = $40
}
inline void vic_bank_C000() {
    cia2_ddra = $C0
    cia2_pra = $00
}

include/c64_hardware.mfk Normal file
View File

@ -0,0 +1,41 @@
import c64_vic
import c64_sid
import c64_cia
import cpu6510
array c64_color_ram [1000] @$D800
inline void c64_ram_only() {
    cpu6510_ddr = 7
    cpu6510_port = 0
}
inline void c64_ram_io() {
    cpu6510_ddr = 7
    cpu6510_port = 5
}
inline void c64_ram_io_kernal() {
    cpu6510_ddr = 7
    cpu6510_port = 6
}
inline void c64_ram_io_basic() {
    cpu6510_ddr = 7
    cpu6510_port = 7
}
inline void c64_ram_charset() {
    cpu6510_ddr = 7
    cpu6510_port = 1
}
inline void c64_ram_charset_kernal() {
    cpu6510_ddr = 7
    cpu6510_port = 2
}
inline void c64_ram_charset_basic() {
    cpu6510_ddr = 7
    cpu6510_port = 3
}

include/c64_kernal.mfk Normal file
View File

@ -0,0 +1,32 @@
// Routines from C64 KERNAL ROM
// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehand.)
// Input: A = Byte to write.
asm void putchar(byte a) @$FFD2 extern
// OPEN. Open file. (Must call SETLFS and SETNAM beforehand.)
asm void open() @$FFC0 extern
// CLOSE. Close file.
// Input: A = Logical number.
asm void close(byte a) @$FFC3 extern
// SETLFS. Set file parameters.
// Input: A = Logical number; X = Device number; Y = Secondary address.
asm void setlfs(byte a, byte x, byte y) @$FFBA extern
// SETNAM. Set file name parameters.
// Input: A = File name length; X/Y = Pointer to file name.
asm void setnam(word yx, byte a) @$FFBD extern
// LOAD. Load or verify file. (Must call SETLFS and SETNAM beforehand.)
// Input: A: 0 = Load, 1-255 = Verify; X/Y = Load address (if secondary address = 0).
// Output: Carry: 0 = No errors, 1 = Error; A = KERNAL error code (if Carry = 1); X/Y = Address of last byte loaded/verified (if Carry = 0).
asm clear_carry load(byte a, word yx) @$FFD5 extern
// SAVE. Save file. (Must call SETLFS and SETNAM beforehand.)
// Input: A = Address of zero page register holding start address of memory area to save; X/Y = End address of memory area plus 1.
// Output: Carry: 0 = No errors, 1 = Error; A = KERNAL error code (if Carry = 1).
asm clear_carry save(byte a, word yx) @$FFD8 extern
word irq_pointer @$314

include/c64_sid.mfk Normal file
View File

@ -0,0 +1,24 @@
// Hardware addresses for C64
// SID
word sid_v1_freq @$D400
word sid_v1_pulse @$D402
byte sid_v1_cr @$D404
byte sid_v1_ad @$D405
byte sid_v1_sr @$D406
word sid_v2_freq @$D407
word sid_v2_pulse @$D409
byte sid_v2_cr @$D40B
byte sid_v2_ad @$D40C
byte sid_v2_sr @$D40D
word sid_v3_freq @$D40E
word sid_v3_pulse @$D410
byte sid_v3_cr @$D412
byte sid_v3_ad @$D413
byte sid_v3_sr @$D414
byte sid_paddle_x @$D419
byte sid_paddle_y @$D41A

include/c64_vic.mfk Normal file
View File

@ -0,0 +1,154 @@
// Hardware addresses for C64
byte vic_spr0_x @$D000
byte vic_spr0_y @$D001
byte vic_spr1_x @$D002
byte vic_spr1_y @$D003
byte vic_spr2_x @$D004
byte vic_spr2_y @$D005
byte vic_spr3_x @$D006
byte vic_spr3_y @$D007
byte vic_spr4_x @$D008
byte vic_spr4_y @$D009
byte vic_spr5_x @$D00A
byte vic_spr5_y @$D00B
byte vic_spr6_x @$D00C
byte vic_spr6_y @$D00D
byte vic_spr7_x @$D00E
byte vic_spr7_y @$D00F
byte vic_spr_hi_x @$D010
byte vic_cr1 @$D011
byte vic_raster @$D012
byte vic_lp_x @$D013
byte vic_lp_y @$D014
byte vic_spr_ena @$D015
byte vic_cr2 @$D016
byte vic_spr_exp_y @$D017
byte vic_mem @$D018
byte vic_irq @$D019
byte vic_irq_ena @$D01A
byte vic_spr_dp @$D01B
byte vic_spr_mcolor @$D01C
byte vic_spr_exp_x @$D01D
byte vic_spr_ss_col @$D01E
byte vic_spr_sd_col @$D01F
byte vic_border @$D020
byte vic_bg_color0 @$D021
byte vic_bg_color1 @$D022
byte vic_bg_color2 @$D023
byte vic_bg_color3 @$D024
byte vic_spr_color1 @$D025
byte vic_spr_color2 @$D026
byte vic_spr0_color @$D027
byte vic_spr1_color @$D028
byte vic_spr2_color @$D029
byte vic_spr3_color @$D02A
byte vic_spr4_color @$D02B
byte vic_spr5_color @$D02C
byte vic_spr6_color @$D02D
byte vic_spr7_color @$D02E
array vic_spr_coord [16] @$D000
array vic_spr_color [8] @$D027
inline void vic_enable_multicolor() {
    vic_cr2 |= 0x10
}
inline void vic_disable_multicolor() {
    vic_cr2 &= 0xEF
}
inline void vic_enable_bitmap() {
    vic_cr1 |= 0x20
}
inline void vic_disable_bitmap() {
    vic_cr1 &= 0xDF
}
inline void vic_24_rows() {
    vic_cr1 &= 0xF7
}
inline void vic_25_rows() {
    vic_cr1 |= 8
}
inline void vic_38_columns() {
    vic_cr2 &= 0xF7
}
inline void vic_40_columns() {
    vic_cr2 |= 8
}
inline void vic_disable_irq() {
    vic_irq_ena = 0
    vic_irq += 1
}
// base: divisible by $400, $0000-$3C00 allowed
//inline void vic_screen(word const base) {
//    vic_mem = (vic_mem & $0F) | (base >> 6)
//}
inline void vic_charset_0000() {
    vic_mem = (vic_mem & $F1)
}
inline void vic_charset_0800() {
    vic_mem = (vic_mem & $F1) | 2
}
inline void vic_charset_1000() {
    vic_mem = (vic_mem & $F1) | 4
}
inline void vic_charset_1800() {
    vic_mem = (vic_mem & $F1) | 6
}
inline void vic_charset_2000() {
    vic_mem = (vic_mem & $F1) | 8
}
inline void vic_charset_2800() {
    vic_mem = (vic_mem & $F1) | $A
}
inline void vic_charset_3000() {
    vic_mem = (vic_mem & $F1) | $C
}
inline void vic_charset_3800() {
    vic_mem = (vic_mem & $F1) | $E
}
inline void vic_bitmap_0000() {
    vic_mem &= $F7
}
inline void vic_bitmap_2000() {
    vic_mem |= 8
}
// x, y < 8
// default: x=0, y=3
void vic_set_scroll(byte x, byte y) {
    vic_cr1 = (vic_cr1 & $F8) | y
    vic_cr2 = (vic_cr2 & $F8) | x
}
const byte black = 0
const byte white = 1
const byte red = 2
const byte cyan = 3
const byte purple = 4
const byte green = 5
const byte blue = 6
const byte yellow = 7
const byte orange = 8
const byte brown = 9
const byte light_red = 10
const byte dark_grey = 11
const byte dark_gray = 11
const byte medium_grey = 12
const byte medium_gray = 12
const byte light_green = 13
const byte light_blue = 14
const byte light_grey = 15
const byte light_gray = 15

include/cpu6510.mfk Normal file
View File

@ -0,0 +1,3 @@
byte cpu6510_ddr @0
byte cpu6510_port @1

include/loader_0401.mfk Normal file
View File

@ -0,0 +1,15 @@
array _basic_loader @$401 = [

include/loader_0801.mfk Normal file
View File

@ -0,0 +1,15 @@
array _basic_loader @$801 = [

include/loader_1001.mfk Normal file
View File

@ -0,0 +1,15 @@
array _basic_loader @$1001 = [

include/loader_1201.mfk Normal file
View File

@ -0,0 +1,15 @@
array _basic_loader @$1201 = [

include/loader_1c01.mfk Normal file
View File

@ -0,0 +1,15 @@
array _basic_loader @$1C01 = [

include/mouse.mfk Normal file
View File

@ -0,0 +1,8 @@
// Generic module for mouse support
// Resolutions up to 512x256 are supported
// Mouse X coordinate
word mouse_x
// Mouse Y coordinate
byte mouse_y

include/pet.ini Normal file
View File

@ -0,0 +1,19 @@

include/pet_kernal.mfk Normal file
View File

@ -0,0 +1,5 @@
// Routines from Commodore PET KERNAL ROM
// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehand.)
// Input: A = Byte to write.
asm void putchar(byte a) @$FFD2 extern

include/plus4.ini Normal file
View File

@ -0,0 +1,19 @@

include/stdio.mfk Normal file
View File

@ -0,0 +1,10 @@
// target-independent standard I/O routines
void putstr(pointer str, byte len) {
    byte index
    index = 0
    while (index != len) {
        putchar(str[index])
        index += 1
    }
}

include/stdlib.mfk Normal file
View File

@ -0,0 +1,23 @@
// target-independent things
word nmi_routine_addr @$FFFA
word reset_routine_addr @$FFFC
word irq_routine_addr @$FFFE
inline asm void poke(word const addr, byte const value) {
?LDA #value
STA addr
inline asm byte peek(word const addr) {
LDA addr
inline asm void disable_irq() {
inline asm void enable_irq() {

include/vic20.ini Normal file
View File

@ -0,0 +1,19 @@

include/vic20_3k.ini Normal file
View File

@ -0,0 +1,19 @@

include/vic20_8k.ini Normal file
View File

@ -0,0 +1,19 @@

include/vic20_kernal.mfk Normal file
View File

@ -0,0 +1,5 @@
// Routines from Commodore VIC-20 KERNAL ROM
// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehand.)
// Input: A = Byte to write.
asm void putchar(byte a) @$FFD2 extern

project/assembly.sbt Normal file
View File

@ -0,0 +1,2 @@
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.6")

project/ Normal file
View File

@ -0,0 +1 @@
sbt.version = 0.13.16

project/buildinfo.sbt Normal file
View File

@ -0,0 +1 @@
addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.7.0")

project/plugins.sbt Normal file
View File

View File

@ -0,0 +1,116 @@
package millfork
import millfork.error.ErrorReporting
/**
  * @author Karol Stasiak
  */
//object CompilationOptions {
// private var instance = new CompilationOptions(Platform.C64, Map())
// // TODO: ugly!
// def change(o: CompilationOptions): Unit = {
// instance = o
// }
// def current: CompilationOptions= instance
// def platform: Platform = instance.platform
// def flag(flag: CompilationFlag.Value):Boolean = instance.flags(flag)
// def flags: Map[CompilationFlag.Value, Boolean] = instance.flags
class CompilationOptions(val platform: Platform, val commandLineFlags: Map[CompilationFlag.Value, Boolean]) {
import CompilationFlag._
import Cpu._
val flags: Map[CompilationFlag.Value, Boolean] = CompilationFlag.values.map { f =>
  f -> commandLineFlags.getOrElse(f, platform.flagOverrides.getOrElse(f, Cpu.defaultFlags(platform.cpu)(f)))
}.toMap
def flag(f: CompilationFlag.Value) = flags(f)
if (flags(DecimalMode)) {
if (platform.cpu == Ricoh || platform.cpu == StrictRicoh) {
ErrorReporting.warn("Decimal mode enabled for Ricoh architecture", this)
if (platform.cpu != Cmos) {
if (!flags(PreventJmpIndirectBug)) {
ErrorReporting.warn("JMP bug prevention should be enabled for non-CMOS architecture", this)
if (flags(EmitCmosOpcodes)) {
ErrorReporting.warn("CMOS opcodes enabled for non-CMOS architecture", this)
if (flags(EmitIllegals)) {
if (platform.cpu == Cmos) {
ErrorReporting.warn("Illegal opcodes enabled for CMOS architecture", this)
if (platform.cpu == StrictRicoh || platform.cpu == StrictMos) {
ErrorReporting.warn("Illegal opcodes enabled for strict architecture", this)
object Cpu extends Enumeration {
val Mos, StrictMos, Ricoh, StrictRicoh, Cmos = Value
import CompilationFlag._
def defaultFlags(x: Cpu.Value): Set[CompilationFlag.Value] = x match {
case StrictMos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap)
case Mos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap)
case Ricoh => Set(PreventJmpIndirectBug, VariableOverlap)
case StrictRicoh => Set(PreventJmpIndirectBug, VariableOverlap)
case Cmos => Set(EmitCmosOpcodes, VariableOverlap)
def fromString(name: String): Cpu.Value = name match {
case "nmos" => Mos
case "6502" => Mos
case "6510" => Mos
case "strict" => StrictMos
case "cmos" => Cmos
case "65c02" => Cmos
case "ricoh" => Ricoh
case "2a03" => Ricoh
case "2a07" => Ricoh
case "strictricoh" => StrictRicoh
case "strict2a03" => StrictRicoh
case "strict2a07" => StrictRicoh
case _ => ErrorReporting.fatal("Unknown CPU architecture")
object CompilationFlag extends Enumeration {
val
// compilation options:
EmitIllegals, EmitCmosOpcodes, DecimalMode, ReadOnlyArrays, PreventJmpIndirectBug,
// optimization options:
DetailedFlowAnalysis, DangerousOptimizations,
// memory allocation options
VariableOverlap,
// warning options
ExtraComparisonWarnings,
RorWarning,
FatalWarnings = Value
val allWarnings: Set[CompilationFlag.Value] = Set(ExtraComparisonWarnings)
val fromString = Map(
"emit_illegals" -> EmitIllegals,
"emit_cmos" -> EmitCmosOpcodes,
"decimal_mode" -> DecimalMode,
"ro_arrays" -> ReadOnlyArrays,
"ror_warn" -> RorWarning,
"prevent_jmp_indirect_bug" -> PreventJmpIndirectBug,

View File

@ -0,0 +1,264 @@
package millfork
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
import java.util.Locale
import millfork.assembly.opt.{CmosOptimizations, DangerousOptimizations, SuperOptimizer, UndocumentedOptimizations}
import millfork.cli.{CliParser, CliStatus}
import millfork.env.Environment
import millfork.error.ErrorReporting
import millfork.node.StandardCallGraph
import millfork.output.Assembler
import millfork.parser.SourceLoadingQueue
/**
  * @author Karol Stasiak
  */
case class Context(inputFileNames: List[String],
outputFileName: Option[String] = None,
runFileName: Option[String] = None,
optimizationLevel: Option[Int] = None,
platform: Option[String] = None,
outputAssembly: Boolean = false,
outputLabels: Boolean = false,
includePath: List[String] = Nil,
flags: Map[CompilationFlag.Value, Boolean] = Map(),
verbosity: Option[Int] = None) {
def changeFlag(f: CompilationFlag.Value, b: Boolean): Context = {
  if (flags.contains(f)) {
    if (flags(f) != b) {
      ErrorReporting.error("Conflicting flags")
    }
    this
  } else {
    copy(flags = this.flags + (f -> b))
  }
}
object Main {
def main(args: Array[String]): Unit = {
if (args.isEmpty) {
  ErrorReporting.info("For help, use --help")
}
val (status, c) = parser.parse(Context(Nil), args.toList)
status match {
case CliStatus.Quit => return
case CliStatus.Failed =>
ErrorReporting.fatalQuit("Invalid command line")
case CliStatus.Ok => ()
ErrorReporting.assertNoErrors("Invalid command line")
if (c.inputFileNames.isEmpty) {
ErrorReporting.fatalQuit("No input files")
ErrorReporting.verbosity = c.verbosity.getOrElse(0)
val optLevel = c.optimizationLevel.getOrElse(0)
val platform = Platform.lookupPlatformFile(c.includePath, c.platform.getOrElse {
  ErrorReporting.info("No platform selected, defaulting to `c64`")
  "c64"
})
val options = new CompilationOptions(platform, c.flags)
ErrorReporting.debug("Effective flags: " + options.flags)
val output = c.outputFileName.getOrElse("a")
val assOutput = output + ".asm"
val labelOutput = output + ".lbl"
val prgOutput = if (!output.endsWith(platform.fileExtension)) output + platform.fileExtension else output
val unoptimized = new SourceLoadingQueue(
initialFilenames = c.inputFileNames,
includePath = c.includePath,
options = options).run()
val program = if (optLevel > 0) {
  OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt))
} else {
  unoptimized
}
val callGraph = new StandardCallGraph(program)
val env = new Environment(None, "")
env.collectDeclarations(program, options)
val extras = List(
if (options.flag(CompilationFlag.EmitIllegals)) UndocumentedOptimizations.All else Nil,
if (options.flag(CompilationFlag.EmitCmosOpcodes)) CmosOptimizations.All else Nil,
if (options.flag(CompilationFlag.DangerousOptimizations)) DangerousOptimizations.All else Nil,
val goodCycle = List.fill(optLevel - 1)(OptimizationPresets.Good ++ extras).flatten
val assemblyOptimizations = if (optLevel <= 0) Nil else if (optLevel >= 9) List(SuperOptimizer) else {
goodCycle ++ OptimizationPresets.AssOpt ++ extras ++ goodCycle
// compile
val assembler = new Assembler(env)
val result = assembler.assemble(callGraph, assemblyOptimizations, options)
ErrorReporting.assertNoErrors("Codegen failed")
ErrorReporting.debug(f"Unoptimized code size: ${assembler.unoptimizedCodeSize}%5d B")
ErrorReporting.debug(f"Optimized code size: ${assembler.optimizedCodeSize}%5d B")
ErrorReporting.debug(f"Gain: ${(100L * (assembler.unoptimizedCodeSize - assembler.optimizedCodeSize) / assembler.unoptimizedCodeSize.toDouble).round}%5d%%")
ErrorReporting.debug(f"Initialized arrays: ${assembler.initializedArraysSize}%5d B")
if (c.outputAssembly) {
val path = Paths.get(assOutput)
ErrorReporting.debug("Writing assembly to " + path.toAbsolutePath)
Files.write(path, result.asm.mkString("\n").getBytes(StandardCharsets.UTF_8))
if (c.outputLabels) {
val path = Paths.get(labelOutput)
ErrorReporting.debug("Writing labels to " + path.toAbsolutePath)
Files.write(path, result.labels.sortWith { (a, b) =>
val aLocal = a._1.head == '.'
val bLocal = b._1.head == '.'
if (aLocal == bLocal) a._1 < b._1
else b._1 < a._1
}.groupBy(_._2) { case (l, a) =>
val normalized = l.replace('$', '_').replace('.', '_')
s"al ${a.toHexString} .$normalized"
val path = Paths.get(prgOutput)
ErrorReporting.debug("Writing output to " + path.toAbsolutePath)
Files.write(path, result.code)
c.runFileName.foreach(program =>
new ProcessBuilder(program, path.toAbsolutePath.toString).start()
private def parser = new CliParser[Context] {
fluff("Main options:", "")
parameter("-o", "--out").required().placeholder("<file>").action { (p, c) =>
assertNone(c.outputFileName, "Output already defined")
c.copy(outputFileName = Some(p))
}.description("The output file name, without extension.").onWrongNumber(_ => ErrorReporting.fatalQuit("No output file specified"))
flag("-s").action { c =>
c.copy(outputAssembly = true)
}.description("Generate also the assembly output.")
flag("-g").action { c =>
c.copy(outputLabels = true)
}.description("Generate also the label file.")
parameter("-t", "--target").placeholder("<platform>").action { (p, c) =>
assertNone(c.platform, "Platform already defined")
c.copy(platform = Some(p))
}.description("Target platform, any of: c64, c16, plus4, vic20, vic20_3k, vic20_8k, pet, c128, a8.")
parameter("-I", "--include-dir").repeatable().placeholder("<dir>;<dir>;...").action { (paths, c) =>
val n = paths.split(";")
c.copy(includePath = c.includePath ++ n)
}.description("Include paths for modules.")
parameter("-r", "--run").placeholder("<program>").action { (p, c) =>
assertNone(c.runFileName, "Run program already defined")
c.copy(runFileName = Some(p))
}.description("Program to run after successful compilation.")
endOfFlags("--").description("Marks the end of options.")
fluff("", "Verbosity options:", "")
flag("-q", "--quiet").action { c =>
assertNone(c.verbosity, "Cannot use -v and -q together")
c.copy(verbosity = Some(-1))
}.description("Suppress all messages except for errors.")
private val verbose = flag("-v", "--verbose").maxCount(3).action { c =>
if (c.verbosity.exists(_ < 0)) ErrorReporting.error("Cannot use -v and -q together", None)
c.copy(verbosity = Some(1 + c.verbosity.getOrElse(0)))
}.description("Increase verbosity.")
flag("-vv").repeatable().action(c => verbose.encounter(verbose.encounter(c))).description("Increase verbosity even more.")
flag("-vvv").repeatable().action(c => verbose.encounter(verbose.encounter(verbose.encounter(c)))).description("Increase verbosity even more.")
fluff("", "Code generation options:", "")
boolean("-fcmos-ops", "-fno-cmos-ops").action { (c, v) =>
c.changeFlag(CompilationFlag.EmitCmosOpcodes, v)
}.description("Whether to emit CMOS opcodes.")
boolean("-fillegals", "-fno-illegals").action { (c, v) =>
c.changeFlag(CompilationFlag.EmitIllegals, v)
}.description("Whether to emit illegal (undocumented) NMOS opcodes.")
boolean("-fjmp-fix", "-fno-jmp-fix").action { (c, v) =>
c.changeFlag(CompilationFlag.PreventJmpIndirectBug, v)
}.description("Whether to prevent the indirect JMP bug on page boundary.")
boolean("-fdecimal-mode", "-fno-decimal-mode").action { (c, v) =>
c.changeFlag(CompilationFlag.DecimalMode, v)
}.description("Whether decimal mode should be available.")
boolean("-fvariable-overlap", "-fno-variable-overlap").action { (c, v) =>
c.changeFlag(CompilationFlag.VariableOverlap, v)
}.description("Whether variables should overlap if their scopes do not intersect.")
fluff("", "Optimization options:", "")
flag("-O0").action { c =>
assertNone(c.optimizationLevel, "Optimization level already defined")
c.copy(optimizationLevel = Some(0))
}.description("Disable all optimizations.")
flag("-O").action { c =>
assertNone(c.optimizationLevel, "Optimization level already defined")
c.copy(optimizationLevel = Some(1))
}.description("Optimize code.")
for (i <- 2 to 9) {
val f = flag("-O" + i).action { c =>
assertNone(c.optimizationLevel, "Optimization level already defined")
c.copy(optimizationLevel = Some(i))
}.description("Optimize code even more.")
if (i > 3) f.hidden()
flag("--detailed-flow").action { c =>
c.changeFlag(CompilationFlag.DetailedFlowAnalysis, true)
}.description("Use detailed flow analysis (experimental).")
flag("--dangerous-optimizations").action { c =>
c.changeFlag(CompilationFlag.DangerousOptimizations, true)
}.description("Use dangerous optimizations (experimental).")
fluff("", "Warning options:", "")
flag("-Wall", "--Wall").action { c =>
CompilationFlag.allWarnings.foldLeft(c) { (c, f) => c.changeFlag(f, true) }
}.description("Enable extra warnings.")
flag("-Wfatal", "--Wfatal").action { c =>
c.changeFlag(CompilationFlag.FatalWarnings, true)
}.description("Treat warnings as errors.")
fluff("", "Other options:", "")
flag("--help").action(c => {
}).description("Display this message.")
flag("--version").action(c => {
println("millfork version ")
}).description("Print the version and quit.")
default.action { (p, c) =>
if (p.startsWith("-")) {
ErrorReporting.error(s"Invalid option `$p`", None)
} else {
c.copy(inputFileNames = c.inputFileNames :+ p)
def assertNone[T](value: Option[T], msg: String): Unit = {
if (value.isDefined) {
ErrorReporting.error(msg, None)

View File

@ -0,0 +1,150 @@
package millfork
import millfork.assembly.opt._
import millfork.node.opt.{UnreachableCode, UnusedFunctions, UnusedGlobalVariables, UnusedLocalVariables}
/**
  * @author Karol Stasiak
  */
object OptimizationPresets {
val NodeOpt = List(
val AssOpt: List[AssemblyOptimization] = List[AssemblyOptimization](
val Good: List[AssemblyOptimization] = List[AssemblyOptimization](

View File

@ -0,0 +1,115 @@
package millfork
import java.io.{File, StringReader}
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
import millfork.error.ErrorReporting
import millfork.output._
import org.apache.commons.configuration2.INIConfiguration
/**
  * @author Karol Stasiak
  */
class Platform(
val cpu: Cpu.Value,
val flagOverrides: Map[CompilationFlag.Value, Boolean],
val startingModules: List[String],
val outputPackager: OutputPackager,
val allocator: VariableAllocator,
val org: Int,
val fileExtension: String,
object Platform {
val C64 = new Platform(
List("c64_hardware", "c64_loader"),
SequenceOutput(List(StartAddressOutput, AllocatedDataOutput)),
new VariableAllocator(
List(0xC1, 0xC3, 0xFB, 0xFD, 0x39, 0x3B, 0x3D, 0x43, 0x4B),
new AfterCodeByteAllocator(0xA000)
def lookupPlatformFile(includePath: List[String], platformName: String): Platform = {
includePath.foreach { dir =>
val file = Paths.get(dir, platformName + ".ini").toFile
ErrorReporting.debug("Checking " + file)
if (file.exists()) {
return load(file)
ErrorReporting.fatal(s"Platform definition `$platformName` not found", None)
def load(file: File): Platform = {
val conf = new INIConfiguration()
val bytes = Files.readAllBytes(file.toPath)
conf.read(new StringReader(new String(bytes, StandardCharsets.UTF_8)))
val cs = conf.getSection("compilation")
val cpu = Cpu.fromString(cs.get(classOf[String], "cpu", "strict"))
val flagOverrides = CompilationFlag.fromString.flatMap { case (k, f) =>
cs.get(classOf[String], k, "").toLowerCase match {
case "" => None
case "false" | "off" | "0" => Some(f -> false)
case "true" | "on" | "1" => Some(f -> true)
val startingModules = cs.get(classOf[String], "modules", "").split("[, ]+").filter(_.nonEmpty).toList
val as = conf.getSection("allocation")
val org = as.get(classOf[String], "main_org", "") match {
case "" => ErrorReporting.fatal(s"Undefined main_org")
case m => parseNumber(m)
val freePointers = as.get(classOf[String], "zp_pointers", "all") match {
case "all" => List.tabulate(128)(_ * 2)
case xs => xs.split("[, ]+").map(parseNumber).toList
val byteAllocator = (as.get(classOf[String], "himem_start", ""), as.get(classOf[String], "himem_end", "")) match {
case ("", _) => ErrorReporting.fatal(s"Undefined himem_start")
case (_, "") => ErrorReporting.fatal(s"Undefined himem_end")
case ("after_code", end) => new AfterCodeByteAllocator(parseNumber(end) + 1)
case (start, end) => new UpwardByteAllocator(parseNumber(start), parseNumber(end) + 1)
val os = conf.getSection("output")
val outputPackager = SequenceOutput(os.get(classOf[String], "format", "").split("[, ]+").filter(_.nonEmpty).map {
case "startaddr" => StartAddressOutput
case "endaddr" => EndAddressOutput
case "allocated" => AllocatedDataOutput
case n => n.split(":").filter(_.nonEmpty) match {
case Array(b, s, e) => BankFragmentOutput(parseNumber(b), parseNumber(s), parseNumber(e))
case Array(s, e) => CurrentBankFragmentOutput(parseNumber(s), parseNumber(e))
case Array(b) => ConstOutput(parseNumber(b).toByte)
case x => ErrorReporting.fatal(s"Invalid output format: `$x`")
val fileExtension = os.get(classOf[String], "extension", ".bin")
new Platform(cpu, flagOverrides, startingModules, outputPackager,
new VariableAllocator(freePointers, byteAllocator), org,
if (fileExtension.startsWith(".")) fileExtension else "." + fileExtension)
def parseNumber(s: String): Int = {
if (s.startsWith("$")) {
Integer.parseInt(s.substring(1), 16)
} else if (s.startsWith("0x")) {
Integer.parseInt(s.substring(2), 16)
} else if (s.startsWith("%")) {
Integer.parseInt(s.substring(1), 2)
} else if (s.startsWith("0b")) {
Integer.parseInt(s.substring(2), 2)
} else {

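The numeric literals accepted in platform `.ini` files can be tried out in isolation. The following is a minimal sketch of the prefix handling in `parseNumber`, not the project's code: the object name `NumberSyntaxSketch` is hypothetical, and the plain-decimal fallback is an assumption, since that branch is not shown above.

```scala
object NumberSyntaxSketch {
  // Mirrors the prefixes handled by Platform.parseNumber:
  // "$" and "0x" for hexadecimal, "%" and "0b" for binary.
  def parseNumber(s: String): Int =
    if (s.startsWith("$")) Integer.parseInt(s.substring(1), 16)
    else if (s.startsWith("0x")) Integer.parseInt(s.substring(2), 16)
    else if (s.startsWith("%")) Integer.parseInt(s.substring(1), 2)
    else if (s.startsWith("0b")) Integer.parseInt(s.substring(2), 2)
    else Integer.parseInt(s) // assumption: anything else is plain decimal

  def main(args: Array[String]): Unit = {
    // Same splitting rule as the zp_pointers option: commas and/or spaces.
    val pointers = "$FB, 0xFD %1000001".split("[, ]+").map(parseNumber).toList
    println(pointers) // prints List(251, 253, 65)
  }
}
```

Note that `himem_end` is inclusive in the configuration file, which is why the loader above adds 1 to the parsed value before handing it to the allocator.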
View File

@ -0,0 +1,50 @@
package millfork
/** @author Karol Stasiak */
case class SeparatedList[T, S](head: T, tail: List[(S, T)]) {
def toPairList(initialSeparator: S) = (initialSeparator -> head) :: tail
def size: Int = tail.size + 1
def items: List[T] = head :: tail.map(_._2)
def separators: List[S] = tail.map(_._1)
def drop(i: Int): SeparatedList[T, S] = if (i == 0) this else SeparatedList(tail(i - 1)._2, tail.drop(i))
def take(i: Int): SeparatedList[T, S] = if (i <= 0) ??? else SeparatedList(head, tail.take(i - 1))
def splitAt(i: Int): (SeparatedList[T, S], S, SeparatedList[T, S]) = {
val (a, b) = tail.splitAt(i - 1)
(SeparatedList(head, a), b.head._1, SeparatedList(b.head._2, b.tail))
def indexOfSeparator(p: S => Boolean): Int = 1 + tail.indexWhere(x => p(x._1))
def ::(pair: (T, S)) = SeparatedList(pair._1, (pair._2 -> head) :: tail)
def split(p: S => Boolean): SeparatedList[SeparatedList[T, S], S] = {
val i = indexOfSeparator(p)
if (i <= 0) SeparatedList(this, Nil)
else {
val (a, b, c) = splitAt(i)
(a, b) :: c.split(p)
object SeparatedList {
def of[T, S](t0: T): SeparatedList[T, S] = SeparatedList[T, S](t0, Nil)
def of[T, S](t0: T, s1: S, t1: T): SeparatedList[T, S] =
SeparatedList(t0, List(s1 -> t1))
def of[T, S](t0: T, s1: S, t1: T, s2: S, t2: T): SeparatedList[T, S] =
SeparatedList(t0, List(s1 -> t1, s2 -> t2))
def of[T, S](t0: T, s1: S, t1: T, s2: S, t2: T, s3: S, t3: T): SeparatedList[T, S] =
SeparatedList(t0, List(s1 -> t1, s2 -> t2, s3 -> t3))

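A short usage sketch may help show how a list of items interleaved with separators behaves, for example operands joined by operators. The class below is a hypothetical, cut-down stand-in named `SepList`, not the project's `SeparatedList`; the bodies of `items` and `separators` are assumptions about the obvious implementation.

```scala
object SepListSketch {
  // One leading item, then (separator, item) pairs.
  final case class SepList[T, S](head: T, tail: List[(S, T)]) {
    def size: Int = tail.size + 1
    def items: List[T] = head :: tail.map(_._2)   // assumed implementation
    def separators: List[S] = tail.map(_._1)      // assumed implementation

    // 1-based index of the first separator matching p, or 0 if none matches.
    def indexOfSeparator(p: S => Boolean): Int = 1 + tail.indexWhere(x => p(x._1))

    // Splits around the i-th separator (1-based) and returns that separator too.
    def splitAt(i: Int): (SepList[T, S], S, SepList[T, S]) = {
      val (a, b) = tail.splitAt(i - 1)
      (SepList(head, a), b.head._1, SepList(b.head._2, b.tail))
    }
  }

  // Represents the token sequence 1 + 2 * 3.
  val expr: SepList[Int, String] = SepList(1, List("+" -> 2, "*" -> 3))

  def main(args: Array[String]): Unit = {
    println(expr.items)                      // List(1, 2, 3)
    println(expr.separators)                 // List(+, *)
    println(expr.indexOfSeparator(_ == "*")) // 2
    println(expr.splitAt(2))
  }
}
```

Returning 0 from `indexOfSeparator` when nothing matches is what lets `split` use `i <= 0` as its stop condition.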
View File

@ -0,0 +1,305 @@
package millfork.assembly
import millfork.assembly.Opcode._
import millfork.assembly.opt.ReadsA
import millfork.compiler.{CompilationContext, MlCompiler}
import millfork.env._
//noinspection TypeAnnotation
object OpcodeClasses {
val ReadsAAlways = Set(
val ReadsAIfImplied = Set(ASL, LSR, ROL, ROR, INC, DEC)
val ReadsXAlways = Set(
val ReadsYAlways = Set(CPY, DEY, INY, STY, TYA, PLY, SHY)
val ReadsZ = Set(BNE, BEQ, PHP)
val ReadsN = Set(BMI, BPL, PHP)
val ReadsNOrZ = ReadsZ ++ ReadsN
val ReadsV = Set(BVS, BVC, PHP)
val ReadsD = Set(PHP, ADC, SBC, RRA, ARR, ISC, DCP) // TODO: ??
val ReadsC = Set(
val ChangesAAlways = Set(
val ChangesAIfImplied = Set(ASL, LSR, ROL, ROR, INC, DEC)
val ChangesX = Set(
val ChangesY = Set(
val ChangesS = Set(
val ChangesMemoryAlways = Set(
val ChangesMemoryIfNotImplied = Set(
val ReadsMemoryIfNotImpliedOrImmediate = Set(
val OverwritesA = Set(
val OverwritesX = Set(
val OverwritesY = Set(
val OverwritesC = Set(CLC, SEC, PLP)
val OverwritesD = Set(CLD, SED, PLP)
val OverwritesI = Set(CLI, SEI, PLP)
val OverwritesV = Set(CLV, PLP)
val ConcernsAAlways = ReadsAAlways ++ ChangesAAlways
val ConcernsAIfImplied = ReadsAIfImplied ++ ChangesAIfImplied
val ConcernsXAlways = ReadsXAlways | ChangesX
val ConcernsYAlways = ReadsYAlways | ChangesY
val ConcernsStack = Set(
val ChangesNAndZ = Set(
TSB, TRB // These two do not change N, but let's pretend they do for simplicity
val ChangesC = Set(
val ChangesV = Set(
val SupportsAbsoluteX = Set(
val SupportsAbsoluteY = Set(
val SupportsAbsolute = Set(
val SupportsZeroPageIndirect = Set(ORA, AND, EOR, ADC, STA, LDA, CMP, SBC)
val ShortBranching = Set(BEQ, BNE, BMI, BPL, BVC, BVS, BCC, BCS, BRA)
val AllDirectJumps = ShortBranching + JMP
val AllLinear = Set(
val NoopDiscardsFlags = Set(DISCARD_AF, DISCARD_XF, DISCARD_YF)
val DiscardsV = NoopDiscardsFlags | OverwritesV
val DiscardsC = NoopDiscardsFlags | OverwritesC
val DiscardsD = OverwritesD
val DiscardsI = NoopDiscardsFlags | OverwritesI
object AssemblyLine {
def treatment(lines: List[AssemblyLine], state: State.Value): Treatment.Value = lines.map(_.treatment(state)).foldLeft(Treatment.Unchanged)(_ ~ _)
def label(label: String): AssemblyLine = AssemblyLine.label(Label(label))
def label(label: Label): AssemblyLine = AssemblyLine(LABEL, AddrMode.DoesNotExist, label.toAddress)
def discardAF() = AssemblyLine(DISCARD_AF, AddrMode.DoesNotExist, Constant.Zero)
def discardXF() = AssemblyLine(DISCARD_XF, AddrMode.DoesNotExist, Constant.Zero)
def discardYF() = AssemblyLine(DISCARD_YF, AddrMode.DoesNotExist, Constant.Zero)
def immediate(opcode: Opcode.Value, value: Int) = AssemblyLine(opcode, AddrMode.Immediate, NumericConstant(value, 1))
def immediate(opcode: Opcode.Value, value: Constant) = AssemblyLine(opcode, AddrMode.Immediate, value)
def implied(opcode: Opcode.Value) = AssemblyLine(opcode, AddrMode.Implied, Constant.Zero)
def variable(ctx: CompilationContext, opcode: Opcode.Value, variable: Variable, offset: Int = 0): List[AssemblyLine] =
variable match {
case v@MemoryVariable(_, _, VariableAllocationMethod.Zeropage) =>
List(AssemblyLine.zeropage(opcode, v.toAddress + offset))
case v@RelativeVariable(_, _, _, true) =>
List(AssemblyLine.zeropage(opcode, v.toAddress + offset))
case v:VariableInMemory => List(AssemblyLine.absolute(opcode, v.toAddress + offset))
case v:StackVariable=> List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(opcode, v.baseOffset + offset + ctx.extraStackOffset))
def zeropage(opcode: Opcode.Value, addr: Constant) =
AssemblyLine(opcode, AddrMode.ZeroPage, addr)
def zeropage(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
AssemblyLine(opcode, AddrMode.ZeroPage, thing.toAddress + offset)
def absolute(opcode: Opcode.Value, addr: Constant) =
AssemblyLine(opcode, AddrMode.Absolute, addr)
def absolute(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
AssemblyLine(opcode, AddrMode.Absolute, thing.toAddress + offset)
def relative(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
AssemblyLine(opcode, AddrMode.Relative, thing.toAddress + offset)
def relative(opcode: Opcode.Value, label: String) =
AssemblyLine(opcode, AddrMode.Relative, Label(label).toAddress)
def absoluteY(opcode: Opcode.Value, addr: Constant) =
AssemblyLine(opcode, AddrMode.AbsoluteY, addr)
def absoluteY(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
AssemblyLine(opcode, AddrMode.AbsoluteY, thing.toAddress + offset)
def absoluteX(opcode: Opcode.Value, addr: Int) =
AssemblyLine(opcode, AddrMode.AbsoluteX, NumericConstant(addr, 2))
def absoluteX(opcode: Opcode.Value, addr: Constant) =
AssemblyLine(opcode, AddrMode.AbsoluteX, addr)
def absoluteX(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
AssemblyLine(opcode, AddrMode.AbsoluteX, thing.toAddress + offset)
def indexedY(opcode: Opcode.Value, addr: Constant) =
AssemblyLine(opcode, AddrMode.IndexedY, addr)
def indexedY(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
AssemblyLine(opcode, AddrMode.IndexedY, thing.toAddress + offset)
case class AssemblyLine(opcode: Opcode.Value, addrMode: AddrMode.Value, var parameter: Constant, elidable: Boolean = true) {
import AddrMode._
import State._
import OpcodeClasses._
import Treatment._
def reads(state: State.Value): Boolean = state match {
case A => if (addrMode == Implied) ReadsAIfImplied(opcode) else ReadsAAlways(opcode)
case X => addrMode == AbsoluteX || addrMode == ZeroPageX || addrMode == IndexedX || ReadsXAlways(opcode)
case Y => addrMode == AbsoluteY || addrMode == ZeroPageY || addrMode == IndexedY || ReadsYAlways(opcode)
case C => ReadsC(opcode)
case D => ReadsD(opcode)
case N => ReadsN(opcode)
case V => ReadsV(opcode)
case Z => ReadsZ(opcode)
def treatment(state: State.Value): Treatment.Value = opcode match {
case LABEL => Unchanged // TODO: ???
case NOP => Unchanged
case JSR | JMP | BEQ | BNE | BMI | BPL | BRK | BCC | BVC | BCS | BVS => Changed
case CLC => if (state == C) Cleared else Unchanged
case SEC => if (state == C) Set else Unchanged
case CLV => if (state == V) Cleared else Unchanged
case CLD => if (state == D) Cleared else Unchanged
case SED => if (state == D) Set else Unchanged
case _ => state match { // TODO: smart detection of constants
case A =>
if (ChangesAAlways(opcode) || addrMode == Implied && ChangesAIfImplied(opcode)) Changed else Unchanged
case X => if (ChangesX(opcode)) Changed else Unchanged
case Y => if (ChangesY(opcode)) Changed else Unchanged
case C => if (ChangesC(opcode)) Changed else Unchanged
case V => if (ChangesV(opcode)) Changed else Unchanged
case N | Z => if (ChangesNAndZ(opcode)) Changed else Unchanged
case D => Unchanged
def sizeInBytes: Int = addrMode match {
case Implied => 1
case Relative | ZeroPageX | ZeroPage | ZeroPageY | IndexedX | IndexedY | Immediate => 2
case AbsoluteX | Absolute | AbsoluteY | Indirect => 3
case DoesNotExist => 0
def cost: Int = addrMode match {
case Implied => 1000
case Relative | Immediate => 2000
case ZeroPage => 2001
case ZeroPageX | ZeroPageY => 2002
case IndexedX | IndexedY => 2003
case Absolute => 3000
case AbsoluteX | AbsoluteY | Indirect => 3001
case DoesNotExist => 1
def isPrintable: Boolean = true //addrMode != AddrMode.DoesNotExist || opcode == LABEL
override def toString: String =
if (opcode == LABEL) {
} else if (addrMode == DoesNotExist) {
s" ; $opcode"
} else {
s" $opcode ${AddrMode.addrModeToString(addrMode, parameter.toString)}"

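The size table in `sizeInBytes` can be checked by hand on a tiny instruction sequence. The sketch below is a simplified illustration with hypothetical names (`SizeSketch`, `Mode`); it covers only a subset of the addressing modes listed above and represents instructions as plain pairs rather than `AssemblyLine` values.

```scala
object SizeSketch {
  sealed trait Mode
  case object Implied extends Mode
  case object Immediate extends Mode
  case object ZeroPage extends Mode
  case object Absolute extends Mode
  case object DoesNotExist extends Mode // labels and other pseudo-instructions

  // Subset of the table in AssemblyLine.sizeInBytes:
  // opcode byte plus 0, 1 or 2 operand bytes; pseudo-instructions emit nothing.
  def sizeInBytes(mode: Mode): Int = mode match {
    case Implied              => 1
    case Immediate | ZeroPage => 2
    case Absolute             => 3
    case DoesNotExist         => 0
  }

  // label: LDA #0 ; CLC ; ADC $C000
  val program: List[(String, Mode)] = List(
    "LABEL" -> DoesNotExist,
    "LDA"   -> Immediate,
    "CLC"   -> Implied,
    "ADC"   -> Absolute)

  // 0 + 2 + 1 + 3 = 6 bytes
  val totalSize: Int = program.map { case (_, mode) => sizeInBytes(mode) }.sum

  def main(args: Array[String]): Unit = println(totalSize)
}
```

The `cost` method above follows the same shape, but adds a small per-mode offset so that, among encodings of equal length, the cheaper addressing mode compares lower.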
View File

@ -0,0 +1,33 @@
package millfork.assembly
import millfork.env.Label
sealed trait Chunk {
def linearize: List[AssemblyLine]
def sizeInBytes: Int
case object EmptyChunk extends Chunk {
override def linearize: Nil.type = Nil
override def sizeInBytes = 0
case class LabelledChunk(label: String, chunk: Chunk) extends Chunk {
override def linearize: List[AssemblyLine] = AssemblyLine.label(Label(label)) :: chunk.linearize
override def sizeInBytes: Int = chunk.sizeInBytes
case class SequenceChunk(chunks: List[Chunk]) extends Chunk {
override def linearize: List[AssemblyLine] = chunks.flatMap(_.linearize)
override def sizeInBytes: Int = chunks.map(_.sizeInBytes).sum
case class LinearChunk(lines: List[AssemblyLine]) extends Chunk {
def linearize: List[AssemblyLine] = lines
override def sizeInBytes: Int = lines.map(_.sizeInBytes).sum

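The `Chunk` hierarchy is a small composite: nested chunks are flattened into one instruction list by `linearize`. The following sketch illustrates that flattening under simplifying assumptions: lines are plain strings instead of `AssemblyLine` values, labels are rendered with a trailing colon purely for readability, and the enclosing object name `ChunkSketch` is hypothetical.

```scala
object ChunkSketch {
  sealed trait Chunk {
    def linearize: List[String]
  }

  case object EmptyChunk extends Chunk {
    def linearize: List[String] = Nil
  }

  final case class LinearChunk(lines: List[String]) extends Chunk {
    def linearize: List[String] = lines
  }

  // A label line is emitted first, followed by the wrapped chunk.
  final case class LabelledChunk(label: String, chunk: Chunk) extends Chunk {
    def linearize: List[String] = (label + ":") :: chunk.linearize
  }

  // Children are flattened in order.
  final case class SequenceChunk(chunks: List[Chunk]) extends Chunk {
    def linearize: List[String] = chunks.flatMap(_.linearize)
  }

  val tree: Chunk = SequenceChunk(List(
    LabelledChunk("start", LinearChunk(List("LDA #0", "STA $D020"))),
    EmptyChunk,
    LinearChunk(List("RTS"))))

  val flat: List[String] = tree.linearize

  def main(args: Array[String]): Unit = flat.foreach(println)
}
```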
View File

@ -0,0 +1,190 @@
package millfork.assembly
import java.util.Locale
import millfork.error.ErrorReporting
import millfork.node.Position
object State extends Enumeration {
val A, X, Y, Z, D, C, N, V = Value
object Treatment extends Enumeration {
val Unchanged, Unsure, Changed, Cleared, Set = Value
implicit class OverriddenValue(val left: Value) extends AnyVal {
def ~(right: Treatment.Value): Treatment.Value = right match {
case Unchanged => left
case Cleared | Set => if (left == Unsure) Changed else right
case _ => right
object Opcode extends Enumeration {
val ADC, AND, ASL,
LABEL = Value
def lookup(opcode: String, position: Option[Position]): Opcode.Value = opcode.toUpperCase(Locale.ROOT) match {
case "ADC" => ADC
case "AHX" => AHX
case "ALR" => ALR
case "ANC" => ANC
case "AND" => AND
case "ANE" => XAA
case "ARR" => ARR
case "ASL" => ASL
case "ASO" => SLO
case "AXA" => AHX
case "AXS" => SBX // TODO: could mean SAX
case "BCC" => BCC
case "BCS" => BCS
case "BEQ" => BEQ
case "BIT" => BIT
case "BMI" => BMI
case "BNE" => BNE
case "BPL" => BPL
case "BRA" => BRA
case "BRK" => BRK
case "BVC" => BVC
case "BVS" => BVS
case "CLC" => CLC
case "CLD" => CLD
case "CLI" => CLI
case "CLV" => CLV
case "CMP" => CMP
case "CPX" => CPX
case "CPY" => CPY
case "DCM" => DCP
case "DCP" => DCP
case "DEC" => DEC
case "DEX" => DEX
case "DEY" => DEY
case "EOR" => EOR
case "INC" => INC
case "INS" => ISC
case "INX" => INX
case "INY" => INY
case "ISC" => ISC
case "JMP" => JMP
case "JSR" => JSR
case "LAS" => LAS
case "LAX" => LAX
case "LDA" => LDA
case "LDX" => LDX
case "LDY" => LDY
case "LSE" => SRE
case "LSR" => LSR
case "LXA" => LXA
case "NOP" => NOP
case "OAL" => LXA
case "ORA" => ORA
case "PHA" => PHA
case "PHP" => PHP
case "PHX" => PHX
case "PHY" => PHY
case "PLA" => PLA
case "PLP" => PLP
case "PLX" => PLX
case "PLY" => PLY
case "RLA" => RLA
case "ROL" => ROL
case "ROR" => ROR
case "RRA" => RRA
case "RTI" => RTI
case "RTS" => RTS
case "SAX" => SAX // TODO: could mean SBX
case "SAY" => SHY
case "SBC" => SBC
case "SBX" => SBX
case "SEC" => SEC
case "SED" => SED
case "SEI" => SEI
case "SHX" => SHX
case "SHY" => SHY
case "SLO" => SLO
case "SRE" => SRE
case "STA" => STA
case "STP" => STP
case "STX" => STX
case "STY" => STY
case "STZ" => STZ
case "TAS" => TAS
case "TAX" => TAX
case "TAY" => TAY
case "TRB" => TRB
case "TSB" => TSB
case "TSX" => TSX
case "TXA" => TXA
case "TXS" => TXS
case "TYA" => TYA
case "WAI" => WAI
case "XAA" => XAA
case "XAS" => SHX
case _ =>
ErrorReporting.error(s"Invalid opcode `$opcode`", position)
object AddrMode extends Enumeration {
val Implied,
DoesNotExist = Value
def argumentLength(a: AddrMode.Value): Int = a match {
case Absolute | AbsoluteX | AbsoluteY | Indirect =>
case _ =>
def addrModeToString(am: AddrMode.Value, argument: String): String = {
am match {
case Implied => ""
case Immediate => "#" + argument
case AbsoluteX | ZeroPageX => argument + ", X"
case AbsoluteY | ZeroPageY => argument + ", Y"
case IndexedX | AbsoluteIndexedX => "(" + argument + ", X)"
case IndexedY => "(" + argument + "), Y"
case Indirect | ZeroPageIndirect => "(" + argument + ")"
case _ => argument

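The `~` operator on `Treatment` values describes what happens to a register or flag when one instruction follows another. The sketch below restates that rule as an ordinary function and folds it over a sequence; the names `TreatmentSketch`, `andThen` and `overall` are hypothetical, and starting the fold from `Unchanged` is an assumption about how a sequence would be summarised.

```scala
object TreatmentSketch {
  object Treatment extends Enumeration {
    val Unchanged, Unsure, Changed, Cleared, Set = Value
  }
  import Treatment._

  // Same rule as the `~` operator above: the later instruction wins,
  // except that a later Unchanged keeps the earlier result, and a
  // Cleared/Set following Unsure is downgraded to Changed.
  def andThen(left: Treatment.Value, right: Treatment.Value): Treatment.Value = right match {
    case Unchanged     => left
    case Cleared | Set => if (left == Unsure) Changed else right
    case _             => right
  }

  // Summarises a whole sequence of per-instruction treatments.
  def overall(steps: List[Treatment.Value]): Treatment.Value =
    steps.foldLeft(Unchanged: Treatment.Value)(andThen)

  def main(args: Array[String]): Unit = {
    println(overall(List(Set, Unchanged)))    // Set
    println(overall(List(Unsure, Cleared)))   // Changed
    println(overall(List(Changed, Cleared)))  // Cleared
  }
}
```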
View File

@ -0,0 +1,848 @@
package millfork.assembly.opt
import java.util.UUID
import java.util.concurrent.atomic.AtomicInteger
import millfork.assembly.{opt, _}
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
import millfork.assembly.OpcodeClasses._
import millfork.env._
/**
 * These optimizations should not remove opportunities for more complex optimizations to trigger.
 *
 * @author Karol Stasiak
 */
object AlwaysGoodOptimizations {
val counter = new AtomicInteger(30000)
def getNextLabel(prefix: String) = f".${prefix}%s__${counter.getAndIncrement()}%05d"
val PointlessMath = new RuleBasedAssemblyOptimization("Pointless math",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(HasOpcode(CLC) & Elidable) ~
(HasOpcode(ADC) & Elidable & MatchParameter(0)) ~
(HasOpcode(SEC) & Elidable) ~
(HasOpcode(SBC) & Elidable & MatchParameter(0)) ~
(LinearOrLabel & Not(ReadsNOrZ) & Not(ReadsV) & Not(ReadsC) & Not(NoopDiscardsFlags) & Not(Set(ADC, SBC))).* ~
(NoopDiscardsFlags | Set(ADC, SBC)) ~~> (_.drop(4)),
(HasOpcode(LDA) & HasImmediate(0) & Elidable) ~
(HasOpcode(CLC) & Elidable) ~
(HasOpcode(ADC) & Elidable) ~
(LinearOrLabel & Not(ReadsV) & Not(NoopDiscardsFlags) & Not(ChangesNAndZ)).* ~
(NoopDiscardsFlags | ChangesNAndZ) ~~> (code => code(2).copy(opcode = LDA) :: code.drop(3))
val PointlessMathFromFlow = new RuleBasedAssemblyOptimization("Pointless math from flow analysis",
needsFlowInfo = FlowInfoRequirement.BothFlows,
(Elidable & MatchA(0) &
HasOpcode(ASL) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, ctx.get[Int](0) << 1) :: Nil
(Elidable & MatchA(0) &
HasOpcode(LSR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, (ctx.get[Int](0) & 0xff) >> 1) :: Nil
(Elidable & MatchA(0) &
HasClear(State.C) & HasOpcode(ROL) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, ctx.get[Int](0) << 1) :: Nil
(Elidable & MatchA(0) &
HasClear(State.C) & HasOpcode(ROR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, (ctx.get[Int](0) & 0xff) >> 1) :: Nil
(Elidable & MatchA(0) &
HasSet(State.C) & HasOpcode(ROL) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, ctx.get[Int](0) * 2 + 1) :: Nil
(Elidable & MatchA(0) &
HasSet(State.C) & HasOpcode(ROR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, 0x80 + (ctx.get[Int](0) & 0xff) / 2) :: Nil
(Elidable &
MatchA(0) & MatchParameter(1) &
HasOpcode(ADC) & HasAddrMode(Immediate) &
HasClear(State.D) & HasClear(State.C) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, ctx.get[Constant](1) + ctx.get[Int](0)) :: Nil
(Elidable &
MatchA(0) & MatchParameter(1) &
HasOpcode(ADC) & HasAddrMode(Immediate) &
HasClear(State.D) & HasClear(State.C) & DoesntMatterWhatItDoesWith(State.V)) ~
Where(ctx => (ctx.get[Constant](1) + ctx.get[Int](0)).quickSimplify match {
case NumericConstant(x, _) => x == (x & 0xff)
case _ => false
}) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, ctx.get[Constant](1) + ctx.get[Int](0)) :: Nil
(Elidable &
MatchA(0) & MatchParameter(1) &
HasOpcode(ADC) & HasAddrMode(Immediate) &
HasClear(State.D) & HasSet(State.C) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, ctx.get[Constant](1) + ((ctx.get[Int](0) + 1) & 0xff)) :: Nil
(Elidable &
MatchA(0) & MatchParameter(1) &
HasOpcode(SBC) & HasAddrMode(Immediate) &
HasClear(State.D) & HasSet(State.C) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.Minus, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
(Elidable &
MatchA(0) & MatchParameter(1) &
HasOpcode(EOR) & HasAddrMode(Immediate)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.Exor, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
(Elidable &
MatchA(0) & MatchParameter(1) &
HasOpcode(ORA) & HasAddrMode(Immediate)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.Or, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
(Elidable &
MatchA(0) & MatchParameter(1) &
HasOpcode(AND) & HasAddrMode(Immediate)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.And, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
(Elidable &
MatchA(0) & MatchParameter(1) &
HasOpcode(ANC) & HasAddrMode(Immediate) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.And, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
val MathOperationOnTwoIdenticalMemoryOperands = new RuleBasedAssemblyOptimization("Math operation on two identical memory operands",
needsFlowInfo = FlowInfoRequirement.BothFlows,
(HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~
(Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~
(HasClear(State.D) & HasClear(State.C) & HasOpcode(ADC) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.implied(ASL)),
(HasOpcodeIn(Set(STA, LDA)) & HasAddrMode(AbsoluteX) & MatchAddrMode(9) & MatchParameter(0)) ~
(Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA) & Not(ChangesX)).* ~
(HasClear(State.D) & HasClear(State.C) & HasOpcode(ADC) & HasAddrMode(AbsoluteX) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.implied(ASL)),
(HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrMode(AbsoluteY) & MatchAddrMode(9) & MatchParameter(0)) ~
(Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA) & Not(ChangesY)).* ~
(HasClear(State.D) & HasClear(State.C) & HasOpcode(ADC) & HasAddrMode(AbsoluteY) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.implied(ASL)),
(HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~
(Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~
(DoesntMatterWhatItDoesWith(State.N, State.Z) & HasOpcodeIn(Set(ORA, AND)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init),
(HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~
(Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~
(DoesntMatterWhatItDoesWith(State.N, State.Z, State.C) & HasOpcode(ANC) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init),
(HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~
(Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~
(HasOpcode(EOR) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.immediate(LDA, 0)),
val PointlessStoreAfterLoad = new RuleBasedAssemblyOptimization("Pointless store after load",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~
(LinearOrLabel & DoesntChangeMemoryAt(0,1) & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0)).* ~
(Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
(HasOpcode(LDX) & MatchAddrMode(0) & MatchParameter(1)) ~
(LinearOrLabel & DoesntChangeMemoryAt(0,1) & Not(ChangesX) & DoesntChangeIndexingInAddrMode(0)).* ~
(Elidable & HasOpcode(STX) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
(HasOpcode(LDY) & MatchAddrMode(0) & MatchParameter(1)) ~
(LinearOrLabel & DoesntChangeMemoryAt(0,1) & Not(ChangesY) & DoesntChangeIndexingInAddrMode(0)).* ~
(Elidable & HasOpcode(STY) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
val PoinlessStoreBeforeStore = new RuleBasedAssemblyOptimization("Pointless store before store",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasAddrModeIn(Set(Absolute, ZeroPage)) & MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STY, STZ)) ~
(LinearOrLabel & DoesNotConcernMemoryAt(2, 1)).* ~
(MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STY, STZ)) ~~> (_.tail),
(Elidable & HasAddrModeIn(Set(AbsoluteX, ZeroPageX)) & MatchParameter(1) & MatchAddrMode(2) & Set(STA, STY, STZ)) ~
(LinearOrLabel & DoesntChangeMemoryAt(2, 1) & Not(ReadsMemory) & Not(ChangesX)).* ~
(MatchParameter(1) & MatchAddrMode(2) & Set(STA, STY, STZ)) ~~> (_.tail),
(Elidable & HasAddrModeIn(Set(AbsoluteY, ZeroPageY)) & MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STZ)) ~
(LinearOrLabel & DoesntChangeMemoryAt(2, 1) & Not(ReadsMemory) & Not(ChangesY)).* ~
(MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STZ)) ~~> (_.tail),
val PointlessLoadBeforeReturn = new RuleBasedAssemblyOptimization("Pointless load before return",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Set(LDA, TXA, TYA, EOR, AND, ORA, ANC) & Elidable) ~ (LinearOrLabel & Not(ConcernsA) & Not(ReadsNOrZ) & Not(HasOpcode(DISCARD_AF))).* ~ HasOpcode(DISCARD_AF) ~~> (_.tail),
(Set(LDX, TAX, TSX, INX, DEX) & Elidable) ~ (LinearOrLabel & Not(ConcernsX) & Not(ReadsNOrZ) & Not(HasOpcode(DISCARD_XF))).* ~ HasOpcode(DISCARD_XF) ~~> (_.tail),
(Set(LDY, TAY, INY, DEY) & Elidable) ~ (LinearOrLabel & Not(ConcernsY) & Not(ReadsNOrZ) & Not(HasOpcode(DISCARD_YF))).* ~ HasOpcode(DISCARD_YF) ~~> (_.tail),
(HasOpcode(LDX) & Elidable & MatchAddrMode(3)) ~
(LinearOrLabel & Not(ConcernsX) & Not(ReadsNOrZ) & DoesntChangeIndexingInAddrMode(3)).*.capture(1) ~
(HasOpcode(TXA) & Elidable) ~
((LinearOrLabel & Not(ConcernsX) & Not(HasOpcode(DISCARD_XF))).* ~
HasOpcode(DISCARD_XF)).capture(2) ~~> { (c, ctx) =>
ctx.get[List[AssemblyLine]](1) ++ (c.head.copy(opcode = LDA) :: ctx.get[List[AssemblyLine]](2))
(HasOpcode(LDY) & Elidable & MatchAddrMode(3)) ~
(LinearOrLabel & Not(ConcernsY) & Not(ReadsNOrZ) & DoesntChangeIndexingInAddrMode(3)).*.capture(1) ~
(HasOpcode(TYA) & Elidable) ~
((LinearOrLabel & Not(ConcernsY) & Not(HasOpcode(DISCARD_YF))).* ~
HasOpcode(DISCARD_YF)).capture(2) ~~> { (c, ctx) =>
ctx.get[List[AssemblyLine]](1) ++ (c.head.copy(opcode = LDA) :: ctx.get[List[AssemblyLine]](2))
private def operationPairBuilder(op1: Opcode.Value, op2: Opcode.Value, middle: AssemblyLinePattern) = {
(HasOpcode(op1) & Elidable) ~
(Linear & middle).*.capture(1) ~
(HasOpcode(op2) & Elidable) ~
((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(2) ~~> { (_, ctx) =>
ctx.get[List[AssemblyLine]](1) ++ ctx.get[List[AssemblyLine]](2)
val PointlessOperationPairRemoval = new RuleBasedAssemblyOptimization("Pointless operation pair",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
operationPairBuilder(PHA, PLA, Not(ConcernsA) & Not(ConcernsStack)),
operationPairBuilder(PHX, PLX, Not(ConcernsX) & Not(ConcernsStack)),
operationPairBuilder(PHY, PLY, Not(ConcernsY) & Not(ConcernsStack)),
operationPairBuilder(INX, DEX, Not(ConcernsX) & Not(ReadsNOrZ)),
operationPairBuilder(DEX, INX, Not(ConcernsX) & Not(ReadsNOrZ)),
operationPairBuilder(INY, DEY, Not(ConcernsY) & Not(ReadsNOrZ)),
operationPairBuilder(DEY, INY, Not(ConcernsY) & Not(ReadsNOrZ)),
val BranchInPlaceRemoval = new RuleBasedAssemblyOptimization("Branch in place",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(AllDirectJumps & MatchParameter(0) & Elidable) ~
HasOpcodeIn(NoopDiscardsFlags).* ~
(HasOpcode(LABEL) & MatchParameter(0)) ~~> (c => c.last :: Nil)
val ImpossibleBranchRemoval = new RuleBasedAssemblyOptimization("Impossible branch",
needsFlowInfo = FlowInfoRequirement.ForwardFlow,
(HasOpcode(BCC) & HasSet(State.C) & Elidable) ~~> (_ => Nil),
(HasOpcode(BCS) & HasClear(State.C) & Elidable) ~~> (_ => Nil),
(HasOpcode(BVC) & HasSet(State.V) & Elidable) ~~> (_ => Nil),
(HasOpcode(BVS) & HasClear(State.V) & Elidable) ~~> (_ => Nil),
(HasOpcode(BNE) & HasSet(State.Z) & Elidable) ~~> (_ => Nil),
(HasOpcode(BEQ) & HasClear(State.Z) & Elidable) ~~> (_ => Nil),
(HasOpcode(BPL) & HasSet(State.N) & Elidable) ~~> (_ => Nil),
(HasOpcode(BMI) & HasClear(State.N) & Elidable) ~~> (_ => Nil),
val UnconditionalJumpRemoval = new RuleBasedAssemblyOptimization("Unconditional jump removal",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasOpcode(JMP) & HasAddrMode(Absolute) & MatchParameter(0)) ~
(Elidable & LinearOrBranch).* ~
(HasOpcode(LABEL) & MatchParameter(0)) ~~> (_ => Nil),
(Elidable & HasOpcode(JMP) & HasAddrMode(Absolute) & MatchParameter(0)) ~
(Not(HasOpcode(LABEL)) & Not(MatchParameter(0))).* ~
(HasOpcode(LABEL) & MatchParameter(0)) ~
(HasOpcode(LABEL) | HasOpcodeIn(NoopDiscardsFlags)).* ~
HasOpcode(RTS) ~~> (code => AssemblyLine.implied(RTS) :: code.tail),
(Elidable & HasOpcodeIn(ShortBranching) & MatchParameter(0)) ~
(HasOpcodeIn(NoopDiscardsFlags).* ~
(Elidable & HasOpcode(RTS))).capture(1) ~
(HasOpcode(LABEL) & MatchParameter(0)) ~
HasOpcodeIn(NoopDiscardsFlags).* ~
(Elidable & HasOpcode(RTS)) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1)),
val TailCallOptimization = new RuleBasedAssemblyOptimization("Tail call optimization",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasOpcode(JSR)) ~ HasOpcodeIn(NoopDiscardsFlags).* ~ (Elidable & HasOpcode(RTS)) ~~> (c => c.tail.init :+ c.head.copy(opcode = JMP)),
(Elidable & HasOpcode(JSR)) ~
HasOpcode(LABEL).* ~
HasOpcodeIn(NoopDiscardsFlags).*.capture(0) ~
HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](0) ++ (code.head.copy(opcode = JMP) :: code.tail)),
val UnusedCodeRemoval = new RuleBasedAssemblyOptimization("Unreachable code removal",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
HasOpcode(JMP) ~ (Not(HasOpcode(LABEL)) & Elidable).+ ~~> (c => c.head :: Nil)
val PoinlessFlagChange = new RuleBasedAssemblyOptimization("Pointless flag change",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(HasOpcodeIn(Set(CMP, CPX, CPY)) & Elidable) ~ NoopDiscardsFlags ~~> (_.tail),
(OverwritesC & Elidable) ~ (LinearOrLabel & Not(ReadsC) & Not(DiscardsC)).* ~ DiscardsC ~~> (_.tail),
(OverwritesD & Elidable) ~ (LinearOrLabel & Not(ReadsD) & Not(DiscardsD)).* ~ DiscardsD ~~> (_.tail),
(OverwritesV & Elidable) ~ (LinearOrLabel & Not(ReadsV) & Not(DiscardsV)).* ~ DiscardsV ~~> (_.tail)
val FlagFlowAnalysis = new RuleBasedAssemblyOptimization("Flag flow analysis",
needsFlowInfo = FlowInfoRequirement.ForwardFlow,
(HasSet(State.C) & HasOpcode(SEC) & Elidable) ~~> (_ => Nil),
(HasSet(State.D) & HasOpcode(SED) & Elidable) ~~> (_ => Nil),
(HasClear(State.C) & HasOpcode(CLC) & Elidable) ~~> (_ => Nil),
(HasClear(State.D) & HasOpcode(CLD) & Elidable) ~~> (_ => Nil),
(HasClear(State.V) & HasOpcode(CLV) & Elidable) ~~> (_ => Nil),
(HasSet(State.C) & HasOpcode(BCS) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
(HasClear(State.C) & HasOpcode(BCC) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
(HasSet(State.N) & HasOpcode(BMI) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
(HasClear(State.N) & HasOpcode(BPL) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
(HasClear(State.V) & HasOpcode(BVC) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
(HasSet(State.V) & HasOpcode(BVS) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
(HasSet(State.Z) & HasOpcode(BEQ) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
(HasClear(State.Z) & HasOpcode(BNE) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
val ReverseFlowAnalysis = new RuleBasedAssemblyOptimization("Reverse flow analysis",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
(Elidable & HasOpcodeIn(Set(TXA, TYA, LDA, EOR, ORA, AND)) & DoesntMatterWhatItDoesWith(State.A, State.N, State.Z)) ~~> (_ => Nil),
(Elidable & HasOpcode(ANC) & DoesntMatterWhatItDoesWith(State.A, State.C, State.N, State.Z)) ~~> (_ => Nil),
(Elidable & HasOpcodeIn(Set(TAX, TSX, LDX, INX, DEX)) & DoesntMatterWhatItDoesWith(State.X, State.N, State.Z)) ~~> (_ => Nil),
(Elidable & HasOpcodeIn(Set(TAY, LDY, DEY, INY)) & DoesntMatterWhatItDoesWith(State.Y, State.N, State.Z)) ~~> (_ => Nil),
(Elidable & HasOpcodeIn(Set(LAX)) & DoesntMatterWhatItDoesWith(State.A, State.X, State.N, State.Z)) ~~> (_ => Nil),
(Elidable & HasOpcodeIn(Set(SEC, CLC)) & DoesntMatterWhatItDoesWith(State.C)) ~~> (_ => Nil),
(Elidable & HasOpcodeIn(Set(CLD, SED)) & DoesntMatterWhatItDoesWith(State.D)) ~~> (_ => Nil),
(Elidable & HasOpcode(CLV) & DoesntMatterWhatItDoesWith(State.V)) ~~> (_ => Nil),
(Elidable & HasOpcodeIn(Set(CMP, CPX, CPY)) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Z)) ~~> (_ => Nil),
(Elidable & HasOpcodeIn(Set(BIT)) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Z, State.V)) ~~> (_ => Nil),
(Elidable & HasOpcodeIn(Set(ASL, LSR, ROL, ROR)) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.A, State.C, State.N, State.Z)) ~~> (_ => Nil),
(Elidable & HasOpcodeIn(Set(ADC, SBC)) & DoesntMatterWhatItDoesWith(State.A, State.C, State.V, State.N, State.Z)) ~~> (_ => Nil),
)
private def modificationOfJustWrittenValue(store: Opcode.Value,
addrMode: AddrMode.Value,
initExtra: AssemblyLinePattern,
modify: Opcode.Value,
meantimeExtra: AssemblyLinePattern,
atLeastTwo: Boolean,
flagsToTrash: Seq[State.Value],
fix: ((AssemblyMatchingContext, Int) => List[AssemblyLine]),
alternateStore: Opcode.Value = LABEL) = {
val actualFlagsToTrash = List(State.N, State.Z) ++ flagsToTrash
val init = Elidable & HasOpcode(store) & HasAddrMode(addrMode) & MatchAddrMode(3) & MatchParameter(0) & DoesntMatterWhatItDoesWith(actualFlagsToTrash: _*) & initExtra
val meantime = (Linear & Not(ConcernsMemory) & meantimeExtra).*
val oneModification = Elidable & HasOpcode(modify) & HasAddrMode(addrMode) & MatchParameter(0) & DoesntMatterWhatItDoesWith(actualFlagsToTrash: _*)
val modifications = (if (atLeastTwo) oneModification ~ oneModification.+ else oneModification.+).captureLength(1)
if (alternateStore == LABEL) {
((init ~ meantime).capture(2) ~ modifications) ~~> ((code, ctx) => fix(ctx, ctx.get[Int](1)) ++ ctx.get[List[AssemblyLine]](2))
} else {
(init.capture(3) ~ meantime.capture(2) ~ modifications) ~~> { (code, ctx) =>
fix(ctx, ctx.get[Int](1)) ++
List(AssemblyLine(alternateStore, ctx.get[AddrMode.Value](3), ctx.get[Constant](0))) ++
ctx.get[List[AssemblyLine]](2)
}
}
}
val ModificationOfJustWrittenValue = new RuleBasedAssemblyOptimization("Modification of just written value",
needsFlowInfo = FlowInfoRequirement.ForwardFlow,
modificationOfJustWrittenValue(STA, Absolute, MatchA(5), INC, Anything, atLeastTwo = false, Seq(), (c, i) => List(
AssemblyLine.immediate(LDA, (c.get[Int](5) + i) & 0xff)
)),
modificationOfJustWrittenValue(STA, Absolute, MatchA(5), DEC, Anything, atLeastTwo = false, Seq(), (c, i) => List(
AssemblyLine.immediate(LDA, (c.get[Int](5) - i) & 0xff)
)),
modificationOfJustWrittenValue(STA, ZeroPage, MatchA(5), INC, Anything, atLeastTwo = false, Seq(), (c, i) => List(
AssemblyLine.immediate(LDA, (c.get[Int](5) + i) & 0xff)
)),
modificationOfJustWrittenValue(STA, ZeroPage, MatchA(5), DEC, Anything, atLeastTwo = false, Seq(), (c, i) => List(
AssemblyLine.immediate(LDA, (c.get[Int](5) - i) & 0xff)
)),
modificationOfJustWrittenValue(STA, AbsoluteX, MatchA(5), INC, Not(ChangesX), atLeastTwo = false, Seq(), (c, i) => List(
AssemblyLine.immediate(LDA, (c.get[Int](5) + i) & 0xff)
)),
modificationOfJustWrittenValue(STA, AbsoluteX, MatchA(5), DEC, Not(ChangesX), atLeastTwo = false, Seq(), (c, i) => List(
AssemblyLine.immediate(LDA, (c.get[Int](5) - i) & 0xff)
)),
modificationOfJustWrittenValue(STA, Absolute, Anything, INC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
AssemblyLine.immediate(ADC, i)
)),
modificationOfJustWrittenValue(STA, Absolute, Anything, DEC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
AssemblyLine.immediate(SBC, i)
)),
modificationOfJustWrittenValue(STA, ZeroPage, Anything, INC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
AssemblyLine.immediate(ADC, i)
)),
modificationOfJustWrittenValue(STA, ZeroPage, Anything, DEC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
AssemblyLine.immediate(SBC, i)
)),
modificationOfJustWrittenValue(STA, AbsoluteX, Anything, INC, Not(ChangesX), atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
AssemblyLine.immediate(ADC, i)
)),
modificationOfJustWrittenValue(STA, AbsoluteX, Anything, DEC, Not(ChangesX), atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
AssemblyLine.immediate(SBC, i)
)),
modificationOfJustWrittenValue(STA, Absolute, Anything, ASL, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(ASL))),
modificationOfJustWrittenValue(STA, Absolute, Anything, LSR, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(LSR))),
modificationOfJustWrittenValue(STA, ZeroPage, Anything, ASL, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(ASL))),
modificationOfJustWrittenValue(STA, ZeroPage, Anything, LSR, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(LSR))),
modificationOfJustWrittenValue(STA, AbsoluteX, Anything, ASL, Not(ChangesX), atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(ASL))),
modificationOfJustWrittenValue(STA, AbsoluteX, Anything, LSR, Not(ChangesX), atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(LSR))),
modificationOfJustWrittenValue(STX, Absolute, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INX))),
modificationOfJustWrittenValue(STX, Absolute, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEX))),
modificationOfJustWrittenValue(STY, Absolute, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INY))),
modificationOfJustWrittenValue(STY, Absolute, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEY))),
modificationOfJustWrittenValue(STZ, Absolute, Anything, ASL, Anything, atLeastTwo = false, Seq(), (_, i) => Nil),
modificationOfJustWrittenValue(STZ, Absolute, Anything, LSR, Anything, atLeastTwo = false, Seq(), (_, i) => Nil),
modificationOfJustWrittenValue(STZ, Absolute, Anything, INC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, i)), STA),
modificationOfJustWrittenValue(STZ, Absolute, Anything, DEC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, 256 - i)), STA),
modificationOfJustWrittenValue(STX, ZeroPage, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INX))),
modificationOfJustWrittenValue(STX, ZeroPage, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEX))),
modificationOfJustWrittenValue(STY, ZeroPage, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INY))),
modificationOfJustWrittenValue(STY, ZeroPage, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEY))),
modificationOfJustWrittenValue(STZ, ZeroPage, Anything, ASL, Anything, atLeastTwo = false, Seq(), (_, i) => Nil),
modificationOfJustWrittenValue(STZ, ZeroPage, Anything, LSR, Anything, atLeastTwo = false, Seq(), (_, i) => Nil),
modificationOfJustWrittenValue(STZ, ZeroPage, Anything, INC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, i)), STA),
modificationOfJustWrittenValue(STZ, ZeroPage, Anything, DEC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, 256 - i)), STA),
)
val ConstantFlowAnalysis = new RuleBasedAssemblyOptimization("Constant flow analysis",
needsFlowInfo = FlowInfoRequirement.ForwardFlow,
(MatchX(0) & HasAddrMode(AbsoluteX) & SupportsAbsolute & Elidable) ~~> { (code, ctx) =>
code.map(l => l.copy(addrMode = Absolute, parameter = l.parameter + ctx.get[Int](0)))
},
(MatchY(0) & HasAddrMode(AbsoluteY) & SupportsAbsolute & Elidable) ~~> { (code, ctx) =>
code.map(l => l.copy(addrMode = Absolute, parameter = l.parameter + ctx.get[Int](0)))
},
(MatchX(0) & HasAddrMode(ZeroPageX) & Elidable) ~~> { (code, ctx) =>
code.map(l => l.copy(addrMode = ZeroPage, parameter = l.parameter + ctx.get[Int](0)))
},
(MatchY(0) & HasAddrMode(ZeroPageY) & Elidable) ~~> { (code, ctx) =>
code.map(l => l.copy(addrMode = ZeroPage, parameter = l.parameter + ctx.get[Int](0)))
},
)
val IdempotentDuplicateRemoval = new RuleBasedAssemblyOptimization("Idempotent duplicate operation",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
HasOpcode(RTS) ~ HasOpcodeIn(NoopDiscardsFlags).* ~ (HasOpcode(RTS) ~ Elidable) ~~> (_.take(1)) ::
HasOpcode(RTI) ~ HasOpcodeIn(NoopDiscardsFlags).* ~ (HasOpcode(RTI) ~ Elidable) ~~> (_.take(1)) ::
HasOpcode(DISCARD_XF) ~ (Not(HasOpcode(DISCARD_XF)) & HasOpcodeIn(NoopDiscardsFlags + LABEL)).* ~ HasOpcode(DISCARD_XF) ~~> (_.tail) ::
HasOpcode(DISCARD_AF) ~ (Not(HasOpcode(DISCARD_AF)) & HasOpcodeIn(NoopDiscardsFlags + LABEL)).* ~ HasOpcode(DISCARD_AF) ~~> (_.tail) ::
HasOpcode(DISCARD_YF) ~ (Not(HasOpcode(DISCARD_YF)) & HasOpcodeIn(NoopDiscardsFlags + LABEL)).* ~ HasOpcode(DISCARD_YF) ~~> (_.tail) ::
List(RTS, RTI, SEC, CLC, CLV, CLD, SED, SEI, CLI, TAX, TXA, TYA, TAY, TXS, TSX).flatMap { opcode =>
Seq(
(HasOpcode(opcode) & Elidable) ~ (HasOpcodeIn(NoopDiscardsFlags) | HasOpcode(LABEL)).* ~ HasOpcode(opcode) ~~> (_.tail),
HasOpcode(opcode) ~ (HasOpcode(opcode) ~ Elidable) ~~> (_.init),
)
}: _*
)
val PointlessRegisterTransfers = new RuleBasedAssemblyOptimization("Pointless register transfers",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
HasOpcode(TYA) ~ (Elidable & Set(TYA, TAY)) ~~> (_.init),
HasOpcode(TXA) ~ (Elidable & Set(TXA, TAX)) ~~> (_.init),
HasOpcode(TAY) ~ (Elidable & Set(TYA, TAY)) ~~> (_.init),
HasOpcode(TAX) ~ (Elidable & Set(TXA, TAX)) ~~> (_.init),
HasOpcode(TSX) ~ (Elidable & Set(TXS, TSX)) ~~> (_.init),
HasOpcode(TXS) ~ (Elidable & Set(TXS, TSX)) ~~> (_.init),
HasOpcode(TSX) ~ (Not(ChangesX) & Not(ChangesS) & Linear).* ~ (Elidable & Set(TXS, TSX)) ~~> (_.init),
HasOpcode(TXS) ~ (Not(ChangesX) & Not(ChangesS) & Linear).* ~ (Elidable & Set(TXS, TSX)) ~~> (_.init),
)
val PointlessRegisterTransfersBeforeStore = new RuleBasedAssemblyOptimization("Pointless register transfers before store",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
(Elidable & HasOpcode(TXA)) ~
(Linear & Not(ConcernsA) & Not(ConcernsX)).* ~
(Elidable & HasOpcode(STA) & HasAddrModeIn(Set(ZeroPage, ZeroPageY, Absolute)) & DoesntMatterWhatItDoesWith(State.A, State.N, State.Z)) ~~> (code => code.tail.init :+ code.last.copy(opcode = STX)),
(Elidable & HasOpcode(TYA)) ~
(Linear & Not(ConcernsA) & Not(ConcernsY)).* ~
(Elidable & HasOpcode(STA) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute)) & DoesntMatterWhatItDoesWith(State.A, State.N, State.Z)) ~~> (code => code.tail.init :+ code.last.copy(opcode = STY)),
)
val PointlessRegisterTransfersBeforeReturn = new RuleBasedAssemblyOptimization("Pointless register transfers before return",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(HasOpcode(TAX) & Elidable) ~
HasOpcode(LABEL).* ~
HasOpcode(TXA).? ~
ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_XF)).capture(1) ~
HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
(HasOpcode(TSX) & Elidable) ~
HasOpcode(LABEL).* ~
HasOpcode(TSX).? ~
ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_XF)).capture(1) ~
HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
(HasOpcode(TXA) & Elidable) ~
HasOpcode(LABEL).* ~
HasOpcode(TAX).? ~
ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_AF)).capture(1) ~
HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
(HasOpcode(TAY) & Elidable) ~
HasOpcode(LABEL).* ~
HasOpcode(TYA).? ~
ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_YF)).capture(1) ~
HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
(HasOpcode(TYA) & Elidable) ~
HasOpcode(LABEL).* ~
HasOpcode(TAY).? ~
ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_AF)).capture(1) ~
HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
)
val PointlessRegisterTransfersBeforeCompare = new RuleBasedAssemblyOptimization("Pointless register transfers before compare",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
HasOpcodeIn(Set(DEX, INX, LDX, LAX)) ~
(HasOpcode(TXA) & Elidable & DoesntMatterWhatItDoesWith(State.A)) ~~> (code => code.init),
HasOpcodeIn(Set(DEY, INY, LDY)) ~
(HasOpcode(TYA) & Elidable & DoesntMatterWhatItDoesWith(State.A)) ~~> (code => code.init),
)
private def stashing(tai: Opcode.Value, tia: Opcode.Value, readsI: AssemblyLinePattern, concernsI: AssemblyLinePattern, discardIF: Opcode.Value, withRts: Boolean, withBeq: Boolean) = {
val init: AssemblyPattern = if (withBeq) {
(Linear & ChangesNAndZ & ChangesA) ~
(HasOpcode(tai) & Elidable) ~
(Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* ~
(ShortBranching & ReadsNOrZ & MatchParameter(0))
} else {
(HasOpcode(tai) & Elidable) ~
(Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* ~
((ShortBranching -- ReadsNOrZ) & MatchParameter(0))
}
val inner: AssemblyPattern = if (withRts) {
(Linear & Not(readsI) & Not(ReadsNOrZ ++ NoopDiscardsFlags)).* ~
ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(discardIF)) ~
HasOpcodeIn(Set(RTS, RTI)) ~
} else {
(Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).*
}
val end: AssemblyPattern =
(HasOpcode(LABEL) & MatchParameter(0)) ~
(Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* ~
(HasOpcode(tia) & Elidable)
val total = init ~ inner ~ end
if (withBeq) {
total ~~> (code => code.head :: (code.tail.tail.init :+ AssemblyLine.implied(tai)))
} else {
total ~~> (code => code.tail.init :+ AssemblyLine.implied(tai))
}
}
// Optimize the following patterns:
// TAX - B__ .a - don't change A - .a - TXA
// TAX - B__ .a - change A discard X RTS - .a - TXA
// by removing the first transfer and flipping the second one
val PointlessStashingToIndexOverShortSafeBranch = new RuleBasedAssemblyOptimization("Pointless stashing into index over short safe branch",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
// stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = false, withBeq = false),
stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = true, withBeq = false),
// stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = false, withBeq = true),
// stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = true, withBeq = true),
// stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = false, withBeq = false),
// stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = true, withBeq = false),
// stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = false, withBeq = true),
// stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = true, withBeq = true),
)
private def loadBeforeTransfer(ld1: Opcode.Value, ld2: Opcode.Value, concerns1: AssemblyLinePattern, overwrites1: State.Value, t12: Opcode.Value, ams: Set[AddrMode.Value]) =
(Elidable & HasOpcode(ld1) & MatchAddrMode(0) & MatchParameter(1) & HasAddrModeIn(ams)) ~
(Linear & Not(ReadsNOrZ) & Not(concerns1) & DoesntChangeMemoryAt(0, 1) & DoesntChangeIndexingInAddrMode(0) & Not(HasOpcode(t12))).*.capture(2) ~
(HasOpcode(t12) & Elidable & DoesntMatterWhatItDoesWith(overwrites1, State.N, State.Z)) ~~> { (code, ctx) =>
ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = ld2)
}
val PointlessLoadBeforeTransfer = new RuleBasedAssemblyOptimization("Pointless load before transfer",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
loadBeforeTransfer(LDX, LDA, ConcernsX, State.X, TXA, Set(ZeroPage, Absolute, IndexedY, AbsoluteY)),
loadBeforeTransfer(LDA, LDX, ConcernsA, State.A, TAX, Set(ZeroPage, Absolute, IndexedY, AbsoluteY)),
loadBeforeTransfer(LDY, LDA, ConcernsY, State.Y, TYA, Set(ZeroPage, Absolute, ZeroPageX, IndexedX, AbsoluteX)),
loadBeforeTransfer(LDA, LDY, ConcernsA, State.A, TAY, Set(ZeroPage, Absolute, ZeroPageX, IndexedX, AbsoluteX)),
)
private def immediateLoadBeforeTwoTransfers(ld1: Opcode.Value, ld2: Opcode.Value, concerns1: AssemblyLinePattern, overwrites1: State.Value, t12: Opcode.Value, t21: Opcode.Value) =
(Elidable & HasOpcode(ld1) & HasAddrMode(Immediate)) ~
(Linear & Not(ReadsNOrZ) & Not(concerns1) & Not(HasOpcode(t12))).*.capture(2) ~
(HasOpcode(t12) & Elidable & DoesntMatterWhatItDoesWith(overwrites1, State.N, State.Z)) ~~> { (code, ctx) =>
ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = ld2)
}
val YYY = new RuleBasedAssemblyOptimization("Pointless load before transfer",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
immediateLoadBeforeTwoTransfers(LDA, LDY, ConcernsA, State.A, TAY, TYA),
)
val ConstantIndexPropagation = new RuleBasedAssemblyOptimization("Constant index propagation",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(HasOpcode(LDX) & HasAddrMode(Immediate) & MatchParameter(0)) ~
(Linear & Not(ChangesX) & Not(HasAddrMode(AbsoluteX))).* ~
(Elidable & SupportsAbsolute & HasAddrMode(AbsoluteX)) ~~> { (lines, ctx) =>
val last = lines.last
val offset = ctx.get[Constant](0)
lines.init :+ last.copy(addrMode = Absolute, parameter = last.parameter + offset)
},
(HasOpcode(LDY) & HasAddrMode(Immediate) & MatchParameter(0)) ~
(Linear & Not(ChangesY) & Not(HasAddrMode(AbsoluteY))).* ~
(Elidable & SupportsAbsolute & HasAddrMode(AbsoluteY)) ~~> { (lines, ctx) =>
val last = lines.last
val offset = ctx.get[Constant](0)
lines.init :+ last.copy(addrMode = Absolute, parameter = last.parameter + offset)
},
)
val PoinlessLoadBeforeAnotherLoad = new RuleBasedAssemblyOptimization("Pointless load before another load",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Set(LDA, TXA, TYA) & Elidable) ~ (LinearOrLabel & Not(ConcernsA) & Not(ReadsNOrZ)).* ~ OverwritesA ~~> (_.tail),
(Set(LDX, TAX, TSX) & Elidable) ~ (LinearOrLabel & Not(ConcernsX) & Not(ReadsNOrZ)).* ~ OverwritesX ~~> (_.tail),
(Set(LDY, TAY) & Elidable) ~ (LinearOrLabel & Not(ConcernsY) & Not(ReadsNOrZ)).* ~ OverwritesY ~~> (_.tail),
)
// TODO: better proofs that memory doesn't change
val PointlessLoadAfterLoadOrStore = new RuleBasedAssemblyOptimization("Pointless load after load or store",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(HasOpcodeIn(Set(LDA, STA)) & HasAddrMode(Implied) & MatchParameter(1)) ~
(Linear & Not(ChangesA)).* ~
(Elidable & HasOpcode(LDA) & HasAddrMode(Implied) & MatchParameter(1)) ~~> (_.init),
(HasOpcodeIn(Set(LDX, STX)) & HasAddrMode(Implied) & MatchParameter(1)) ~
(Linear & Not(ChangesX)).* ~
(Elidable & HasOpcode(LDX) & HasAddrMode(Implied) & MatchParameter(1)) ~~> (_.init),
(HasOpcodeIn(Set(LDY, STY)) & HasAddrMode(Implied) & MatchParameter(1)) ~
(Linear & Not(ChangesY)).* ~
(Elidable & HasOpcode(LDY) & HasAddrMode(Implied) & MatchParameter(1)) ~~> (_.init),
(HasOpcodeIn(Set(LDA, STA)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0) & DoesntChangeMemoryAt(0, 1)).* ~
(Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
(HasOpcodeIn(Set(LDX, STX)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & Not(ChangesX) & DoesntChangeIndexingInAddrMode(0) & DoesntChangeMemoryAt(0, 1)).* ~
(Elidable & HasOpcode(LDX) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
(HasOpcodeIn(Set(LDY, STY)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & Not(ChangesY) & DoesntChangeIndexingInAddrMode(0) & DoesntChangeMemoryAt(0, 1)).* ~
(Elidable & HasOpcode(LDY) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
)
val PointlessOperationAfterLoad = new RuleBasedAssemblyOptimization("Pointless operation after load",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(ChangesA & ChangesNAndZ) ~ (Elidable & HasOpcode(EOR) & HasImmediate(0)) ~~> (_.init),
(ChangesA & ChangesNAndZ) ~ (Elidable & HasOpcode(ORA) & HasImmediate(0)) ~~> (_.init),
(ChangesA & ChangesNAndZ) ~ (Elidable & HasOpcode(AND) & HasImmediate(0xff)) ~~> (_.init)
)
val SimplifiableBitOpsSequence = new RuleBasedAssemblyOptimization("Simplifiable sequence of bit operations",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasOpcode(EOR) & MatchImmediate(0)) ~
(Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsA)).* ~
(Elidable & HasOpcode(EOR) & MatchImmediate(1)) ~~> { (lines, ctx) =>
lines.init.tail :+ AssemblyLine.immediate(EOR, CompoundConstant(MathOperator.Exor, ctx.get[Constant](0), ctx.get[Constant](1)))
},
(Elidable & HasOpcode(ORA) & MatchImmediate(0)) ~
(Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsA)).* ~
(Elidable & HasOpcode(ORA) & MatchImmediate(1)) ~~> { (lines, ctx) =>
lines.init.tail :+ AssemblyLine.immediate(ORA, CompoundConstant(MathOperator.Or, ctx.get[Constant](0), ctx.get[Constant](1)))
},
(Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
(Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsA)).* ~
(Elidable & HasOpcode(AND) & MatchImmediate(1)) ~~> { (lines, ctx) =>
lines.init.tail :+ AssemblyLine.immediate(AND, CompoundConstant(MathOperator.And, ctx.get[Constant](0), ctx.get[Constant](1)))
},
(Elidable & HasOpcode(ANC) & MatchImmediate(0)) ~
(Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsC) & Not(ReadsA)).* ~
(Elidable & HasOpcode(ANC) & MatchImmediate(1)) ~~> { (lines, ctx) =>
lines.init.tail :+ AssemblyLine.immediate(ANC, CompoundConstant(MathOperator.And, ctx.get[Constant](0), ctx.get[Constant](1)))
},
)
val RemoveNops = new RuleBasedAssemblyOptimization("Removing NOP instructions",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasOpcode(NOP)) ~~> (_ => Nil)
)
val RearrangeMath = new RuleBasedAssemblyOptimization("Rearranging math",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasOpcode(LDA) & HasAddrMode(Immediate)) ~
(Elidable & HasOpcodeIn(Set(CLC, SEC))) ~
(Elidable & HasOpcode(ADC) & Not(HasAddrMode(Immediate))) ~~> { c =>
c.last.copy(opcode = LDA) :: c(1) :: c.head.copy(opcode = ADC) :: Nil
},
(Elidable & HasOpcode(LDA) & HasAddrMode(Immediate)) ~
(Elidable & HasOpcodeIn(Set(ADC, EOR, ORA, AND)) & Not(HasAddrMode(Immediate))) ~~> { c =>
c.last.copy(opcode = LDA) :: c.head.copy(opcode = c.last.opcode) :: Nil
},
)
private def wordShifting(i: Int, hiFirst: Boolean, hiFromX: Boolean) = {
val ldax = if (hiFromX) LDX else LDA
val stax = if (hiFromX) STX else STA
val restriction = if (hiFromX) Not(ReadsX) else Anything
val originalStart = if (hiFirst) {
(Elidable & HasOpcode(LDA) & MatchParameter(0) & MatchAddrMode(1)) ~
(Elidable & HasOpcode(STA) & MatchParameter(2) & MatchAddrMode(3) & restriction) ~
(Elidable & HasOpcode(ldax) & HasImmediate(0)) ~
(Elidable & HasOpcode(stax) & MatchParameter(4) & MatchAddrMode(5))
} else {
(Elidable & HasOpcode(ldax) & HasImmediate(0)) ~
(Elidable & HasOpcode(stax) & MatchParameter(4) & MatchAddrMode(5)) ~
(Elidable & HasOpcode(LDA) & MatchParameter(0) & MatchAddrMode(1)) ~
(Elidable & HasOpcode(STA) & MatchParameter(2) & MatchAddrMode(3) & restriction)
}
val middle = (Linear & Not(ConcernsMemory) & DoesntChangeIndexingInAddrMode(3) & DoesntChangeIndexingInAddrMode(5)).*
val singleOriginalShift =
(Elidable & HasOpcode(ASL) & MatchParameter(2) & MatchAddrMode(3)) ~
(Elidable & HasOpcode(ROL) & MatchParameter(4) & MatchAddrMode(5) & DoesntMatterWhatItDoesWith(State.C, State.N, State.V, State.Z))
val originalShifting = (1 to i).map(_ => singleOriginalShift).reduce(_ ~ _)
originalStart ~ middle.capture(6) ~ originalShifting ~~> { (code, ctx) =>
val newStart = List(
code(1).copy(addrMode = code(3).addrMode, parameter = code(3).parameter),
code(3).copy(addrMode = code(1).addrMode, parameter = code(1).parameter))
val middle = ctx.get[List[AssemblyLine]](6)
val singleNewShift = List(
AssemblyLine(LSR, ctx.get[AddrMode.Value](5), ctx.get[Constant](4)),
AssemblyLine(ROR, ctx.get[AddrMode.Value](3), ctx.get[Constant](2)))
newStart ++ middle ++ (i until 8).flatMap(_ => singleNewShift)
}
}
val SmarterShiftingWords = new RuleBasedAssemblyOptimization("Smarter shifting of words",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
wordShifting(8, hiFirst = false, hiFromX = true),
wordShifting(8, hiFirst = false, hiFromX = false),
wordShifting(8, hiFirst = true, hiFromX = true),
wordShifting(8, hiFirst = true, hiFromX = false),
wordShifting(7, hiFirst = false, hiFromX = true),
wordShifting(7, hiFirst = false, hiFromX = false),
wordShifting(7, hiFirst = true, hiFromX = true),
wordShifting(7, hiFirst = true, hiFromX = false),
wordShifting(6, hiFirst = false, hiFromX = true),
wordShifting(6, hiFirst = false, hiFromX = false),
wordShifting(6, hiFirst = true, hiFromX = true),
wordShifting(6, hiFirst = true, hiFromX = false),
wordShifting(5, hiFirst = false, hiFromX = true),
wordShifting(5, hiFirst = false, hiFromX = false),
wordShifting(5, hiFirst = true, hiFromX = true),
wordShifting(5, hiFirst = true, hiFromX = false),
)
private def carryFlagConversionCase(shift: Int, firstSet: Boolean, zeroIfSet: Boolean) = {
val nonZero = 1 << shift
val test = Elidable & HasOpcode(if (firstSet) BCC else BCS) & MatchParameter(0)
val ifSet = Elidable & HasOpcode(LDA) & HasImmediate(if (zeroIfSet) 0 else nonZero)
val ifClear = Elidable & HasOpcode(LDA) & HasImmediate(if (zeroIfSet) nonZero else 0)
val jump = Elidable & HasOpcodeIn(Set(JMP, if (firstSet) BCS else BCC, if (zeroIfSet) BEQ else BNE)) & MatchParameter(1)
val elseLabel = Elidable & HasOpcode(LABEL) & MatchParameter(0)
val afterLabel = Elidable & HasOpcode(LABEL) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.C, State.N, State.V, State.Z)
val store = Elidable & (Not(ReadsC) & Linear | HasOpcodeIn(Set(RTS, JSR, RTI)))
val secondReturn = (Elidable & HasOpcodeIn(Set(RTS, RTI) | NoopDiscardsFlags)).*.capture(6)
val where = Where { ctx =>
ctx.get[List[AssemblyLine]](4) == ctx.get[List[AssemblyLine]](5) ||
ctx.get[List[AssemblyLine]](4) == ctx.get[List[AssemblyLine]](5) ++ ctx.get[List[AssemblyLine]](6)
}
val pattern =
if (firstSet) test ~ ifSet ~ store.*.capture(4) ~ jump ~ elseLabel ~ ifClear ~ store.*.capture(5) ~ afterLabel ~ secondReturn ~ where
else test ~ ifClear ~ store.*.capture(4) ~ jump ~ elseLabel ~ ifSet ~ store.*.capture(5) ~ afterLabel ~ secondReturn ~ where
pattern ~~> { (_, ctx) =>
List(
AssemblyLine.immediate(LDA, 0),
AssemblyLine.implied(if (shift >= 4) ROR else ROL)) ++
(if (shift >= 4) List.fill(7 - shift)(AssemblyLine.implied(LSR)) else List.fill(shift)(AssemblyLine.implied(ASL))) ++
(if (zeroIfSet) List(AssemblyLine.immediate(EOR, nonZero)) else Nil) ++
ctx.get[List[AssemblyLine]](5) ++
ctx.get[List[AssemblyLine]](6)
}
}
val CarryFlagConversion = new RuleBasedAssemblyOptimization("Carry flag conversion",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
// TODO: These cost 2 more cycles but save 12 bytes
// TODO: Add an "optimize for size" compilation option?
// carryFlagConversionCase(2, firstSet = false, zeroIfSet = false),
// carryFlagConversionCase(2, firstSet = true, zeroIfSet = false),
// carryFlagConversionCase(1, firstSet = true, zeroIfSet = true),
// carryFlagConversionCase(1, firstSet = false, zeroIfSet = true),
carryFlagConversionCase(1, firstSet = false, zeroIfSet = false),
carryFlagConversionCase(1, firstSet = true, zeroIfSet = false),
carryFlagConversionCase(0, firstSet = true, zeroIfSet = true),
carryFlagConversionCase(0, firstSet = false, zeroIfSet = true),
carryFlagConversionCase(0, firstSet = false, zeroIfSet = false),
carryFlagConversionCase(0, firstSet = true, zeroIfSet = false),
// carryFlagConversionCase(5, firstSet = false, zeroIfSet = false),
// carryFlagConversionCase(5, firstSet = true, zeroIfSet = false),
// carryFlagConversionCase(6, firstSet = true, zeroIfSet = true),
// carryFlagConversionCase(6, firstSet = false, zeroIfSet = true),
carryFlagConversionCase(6, firstSet = false, zeroIfSet = false),
carryFlagConversionCase(6, firstSet = true, zeroIfSet = false),
carryFlagConversionCase(7, firstSet = true, zeroIfSet = true),
carryFlagConversionCase(7, firstSet = false, zeroIfSet = true),
carryFlagConversionCase(7, firstSet = false, zeroIfSet = false),
carryFlagConversionCase(7, firstSet = true, zeroIfSet = false),
)
val Adc0Optimization = new RuleBasedAssemblyOptimization("ADC #0/#1 optimization",
needsFlowInfo = FlowInfoRequirement.BothFlows,
(Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.D)) ~
(Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
val label = getNextLabel("ah")
List(
AssemblyLine.relative(BCC, label),
code.last.copy(opcode = INC),
AssemblyLine.label(label))
},
(Elidable & HasOpcode(LDA) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~
(Elidable & HasOpcode(ADC) & HasImmediate(0) & HasClear(State.D)) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
val label = getNextLabel("ah")
List(
AssemblyLine.relative(BCC, label),
code.last.copy(opcode = INC),
AssemblyLine.label(label))
},
(Elidable & HasOpcode(LDA) & HasImmediate(1) & HasClear(State.D) & HasClear(State.C)) ~
(Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(code.last.copy(opcode = INC))
},
(Elidable & HasOpcode(LDA) & MatchAddrMode(1) & HasClear(State.C) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~
(Elidable & HasOpcode(ADC) & HasImmediate(1) & HasClear(State.D)) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(code.last.copy(opcode = INC))
},
(Elidable & HasOpcode(TXA) & HasClear(State.D)) ~
(Elidable & HasOpcode(ADC) & HasImmediate(0)) ~
(Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
val label = getNextLabel("ah")
List(
AssemblyLine.relative(BCC, label),
AssemblyLine.implied(INX),
AssemblyLine.label(label))
},
(Elidable & HasOpcode(TYA) & HasClear(State.D)) ~
(Elidable & HasOpcode(ADC) & HasImmediate(0)) ~
(Elidable & HasOpcode(TAY) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
val label = getNextLabel("ah")
List(
AssemblyLine.relative(BCC, label),
AssemblyLine.implied(INY),
AssemblyLine.label(label))
},
(Elidable & HasOpcode(TXA) & HasClear(State.D) & HasClear(State.C)) ~
(Elidable & HasOpcode(ADC) & HasImmediate(1)) ~
(Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(AssemblyLine.implied(INX))
},
(Elidable & HasOpcode(TYA) & HasClear(State.D) & HasClear(State.C)) ~
(Elidable & HasOpcode(ADC) & HasImmediate(1)) ~
(Elidable & HasOpcode(TAY) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(AssemblyLine.implied(INY))
},
)
val IndexSequenceOptimization = new RuleBasedAssemblyOptimization("Index sequence optimization",
needsFlowInfo = FlowInfoRequirement.ForwardFlow,
(Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~
Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0))) ~~> (_ => Nil),
(Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~
Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)+1)) ~~> (_ => List(AssemblyLine.implied(INY))),
(Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~
Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)-1)) ~~> (_ => List(AssemblyLine.implied(DEY))),
(Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~
Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0))) ~~> (_ => Nil),
(Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~
Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)+1)) ~~> (_ => List(AssemblyLine.implied(INX))),
(Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~
Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)-1)) ~~> (_ => List(AssemblyLine.implied(DEX))),
)

@ -0,0 +1,14 @@
package millfork.assembly.opt
import millfork.CompilationOptions
import millfork.assembly.AssemblyLine
import millfork.env.NormalFunction
/** @author Karol Stasiak */
trait AssemblyOptimization {
def name: String
def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine]

@ -0,0 +1,155 @@
package millfork.assembly.opt
import millfork.CompilationOptions
import millfork.assembly.{AssemblyLine, OpcodeClasses}
import millfork.env.NormalFunction
import millfork.error.ErrorReporting
/** @author Karol Stasiak */
object ChangeIndexRegisterOptimizationPreferringX2Y extends ChangeIndexRegisterOptimization(true)
object ChangeIndexRegisterOptimizationPreferringY2X extends ChangeIndexRegisterOptimization(false)
class ChangeIndexRegisterOptimization(preferX2Y: Boolean) extends AssemblyOptimization {
object IndexReg extends Enumeration {
val X, Y = Value
object IndexDirection extends Enumeration {
val X2Y, Y2X = Value
import IndexReg._
import IndexDirection._
import millfork.assembly.AddrMode._
import millfork.assembly.Opcode._
type IndexReg = IndexReg.Value
type IndexDirection = IndexDirection.Value
override def name = "Changing index registers"
override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
val usesIndex = code.exists(l =>
OpcodeClasses.ReadsXAlways(l.opcode) ||
OpcodeClasses.ReadsYAlways(l.opcode) ||
OpcodeClasses.ChangesX(l.opcode) ||
OpcodeClasses.ChangesY(l.opcode) ||
Set(AbsoluteX, AbsoluteY, ZeroPageY, ZeroPageX, IndexedX, IndexedY)(l.addrMode)
if (!usesIndex) {
return code
val canX2Y = f.returnType.size <= 1 && canOptimize(code, X2Y, None)
val canY2X = canOptimize(code, Y2X, None)
(canX2Y, canY2X) match {
case (false, false) => code
case (true, false) =>
ErrorReporting.debug("Changing index register from X to Y")
switchX2Y(code)
case (false, true) =>
ErrorReporting.debug("Changing index register from Y to X")
switchY2X(code)
case (true, true) =>
if (preferX2Y) {
ErrorReporting.debug("Changing index register from X to Y (arbitrarily)")
switchX2Y(code)
} else {
ErrorReporting.debug("Changing index register from Y to X (arbitrarily)")
switchY2X(code)
}
}
//noinspection OptionEqualsSome
private def canOptimize(code: List[AssemblyLine], dir: IndexDirection, loaded: Option[IndexReg]): Boolean = code match {
case AssemblyLine(_, AbsoluteY, _, _) :: xs if loaded != Some(Y) => false
case AssemblyLine(_, ZeroPageY, _, _) :: xs if loaded != Some(Y) => false
case AssemblyLine(_, IndexedX, _, _) :: xs if dir == X2Y || loaded != Some(X) => false
case AssemblyLine(_, AbsoluteX, _, _) :: xs if loaded != Some(X) => false
case AssemblyLine(_, ZeroPageX, _, _) :: xs if loaded != Some(X) => false
case AssemblyLine(_, IndexedY, _, _) :: xs if dir == Y2X || loaded != Some(Y) => false
// using a wrong index register for one instruction is fine
case AssemblyLine(LDY | TAY, _, _, _) :: AssemblyLine(_, IndexedY, _, _) :: xs if dir == Y2X =>
canOptimize(xs, dir, None)
case AssemblyLine(LDX | TAX, _, _, _) :: AssemblyLine(_, IndexedX, _, _) :: xs if dir == X2Y =>
canOptimize(xs, dir, None)
case AssemblyLine(LDX | TAX, _, _, _) :: AssemblyLine(INC | DEC | ASL | ROL | ROR | LSR | STZ, AbsoluteX | ZeroPageX, _, _) :: xs if dir == X2Y =>
canOptimize(xs, dir, None)
case AssemblyLine(INC | DEC | ASL | ROL | ROR | LSR | STZ, AbsoluteX | ZeroPageX, _, _) :: xs if dir == X2Y => false
case AssemblyLine(LAX, _, _, _) :: xs => false
case AssemblyLine(JSR, _, _, _) :: xs => false // TODO
case AssemblyLine(JMP, _, _, _) :: xs => canOptimize(xs, dir, None)
case AssemblyLine(op, _, _, _) :: xs if OpcodeClasses.ShortBranching(op) => canOptimize(xs, dir, None)
case AssemblyLine(RTS, _, _, _) :: xs => canOptimize(xs, dir, None)
case AssemblyLine(LABEL, _, _, _) :: xs => canOptimize(xs, dir, None)
case AssemblyLine(DISCARD_XF, _, _, _) :: xs => canOptimize(xs, dir, loaded.filter(_ != X))
case AssemblyLine(DISCARD_YF, _, _, _) :: xs => canOptimize(xs, dir, loaded.filter(_ != Y))
case AssemblyLine(_, DoesNotExist, _, _) :: xs => canOptimize(xs, dir, loaded)
case AssemblyLine(TAX | LDX | PLX, _, _, e) :: xs =>
(e || dir == Y2X) && canOptimize(xs, dir, Some(X))
case AssemblyLine(TAY | LDY | PLY, _, _, e) :: xs =>
(e || dir == X2Y) && canOptimize(xs, dir, Some(Y))
case AssemblyLine(TXA | STX | PHX | CPX | INX | DEX, _, _, e) :: xs =>
(e || dir == Y2X) && loaded == Some(X) && canOptimize(xs, dir, Some(X))
case AssemblyLine(TYA | STY | PHY | CPY | INY | DEY, _, _, e) :: xs =>
(e || dir == X2Y) && loaded == Some(Y) && canOptimize(xs, dir, Some(Y))
case AssemblyLine(SAX | TXS | SBX, _, _, _) :: xs => dir == Y2X && loaded == Some(X) && canOptimize(xs, dir, Some(X))
case AssemblyLine(TSX, _, _, _) :: xs => dir == Y2X && loaded != Some(Y) && canOptimize(xs, dir, Some(X))
case _ :: xs => canOptimize(xs, dir, loaded)
case Nil => true
private def switchX2Y(code: List[AssemblyLine]): List[AssemblyLine] = code match {
case (a@AssemblyLine(LDX | TAX, _, _, _)) :: (b@AssemblyLine(INC | DEC | ASL | ROL | ROR | LSR | STZ, AbsoluteX | ZeroPageX, _, _)) :: xs => a :: b :: switchX2Y(xs)
case (a@AssemblyLine(LDX | TAX, _, _, _)) :: (b@AssemblyLine(_, IndexedX, _, _)) :: xs => a :: b :: switchX2Y(xs)
case (x@AssemblyLine(TAX, _, _, _)) :: xs => x.copy(opcode = TAY) :: switchX2Y(xs)
case (x@AssemblyLine(TXA, _, _, _)) :: xs => x.copy(opcode = TYA) :: switchX2Y(xs)
case (x@AssemblyLine(STX, _, _, _)) :: xs => x.copy(opcode = STY) :: switchX2Y(xs)
case (x@AssemblyLine(LDX, _, _, _)) :: xs => x.copy(opcode = LDY) :: switchX2Y(xs)
case (x@AssemblyLine(INX, _, _, _)) :: xs => x.copy(opcode = INY) :: switchX2Y(xs)
case (x@AssemblyLine(DEX, _, _, _)) :: xs => x.copy(opcode = DEY) :: switchX2Y(xs)
case (x@AssemblyLine(CPX, _, _, _)) :: xs => x.copy(opcode = CPY) :: switchX2Y(xs)
case AssemblyLine(LAX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected LAX")
case AssemblyLine(TXS, _, _, _) :: xs => ErrorReporting.fatal("Unexpected TXS")
case AssemblyLine(TSX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected TSX")
case AssemblyLine(SBX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected SBX")
case AssemblyLine(SAX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected SAX")
case (x@AssemblyLine(_, AbsoluteX, _, _)) :: xs => x.copy(addrMode = AbsoluteY) :: switchX2Y(xs)
case (x@AssemblyLine(_, ZeroPageX, _, _)) :: xs => x.copy(addrMode = ZeroPageY) :: switchX2Y(xs)
case (x@AssemblyLine(_, IndexedX, _, _)) :: xs => ErrorReporting.fatal("Unexpected IndexedX")
case x::xs => x :: switchX2Y(xs)
case Nil => Nil
private def switchY2X(code: List[AssemblyLine]): List[AssemblyLine] = code match {
case AssemblyLine(LDY | TAY, _, _, _) :: AssemblyLine(_, IndexedY, _, _) :: xs => code.take(2) ++ switchY2X(xs)
case (x@AssemblyLine(TAY, _, _, _)) :: xs => x.copy(opcode = TAX) :: switchY2X(xs)
case (x@AssemblyLine(TYA, _, _, _)) :: xs => x.copy(opcode = TXA) :: switchY2X(xs)
case (x@AssemblyLine(STY, _, _, _)) :: xs => x.copy(opcode = STX) :: switchY2X(xs)
case (x@AssemblyLine(LDY, _, _, _)) :: xs => x.copy(opcode = LDX) :: switchY2X(xs)
case (x@AssemblyLine(INY, _, _, _)) :: xs => x.copy(opcode = INX) :: switchY2X(xs)
case (x@AssemblyLine(DEY, _, _, _)) :: xs => x.copy(opcode = DEX) :: switchY2X(xs)
case (x@AssemblyLine(CPY, _, _, _)) :: xs => x.copy(opcode = CPX) :: switchY2X(xs)
case (x@AssemblyLine(_, AbsoluteY, _, _)) :: xs => x.copy(addrMode = AbsoluteX) :: switchY2X(xs)
case (x@AssemblyLine(_, ZeroPageY, _, _)) :: xs => x.copy(addrMode = ZeroPageX) :: switchY2X(xs)
case AssemblyLine(_, IndexedY, _, _) :: xs => ErrorReporting.fatal("Unexpected IndexedY")
case x::xs => x :: switchY2X(xs)
case Nil => Nil
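The core of `switchX2Y` above is a mechanical renaming of X-based opcodes and addressing modes to their Y counterparts. The following is a minimal sketch of that renaming only. It assumes a hypothetical `Line` model with string fields instead of Millfork's `AssemblyLine`, and it leaves out both the `canOptimize` legality check and the special cases that keep a load followed by an `IndexedX` instruction untouched.

```scala
object IndexSwapSketch {
  // Hypothetical miniature instruction: an opcode name plus an addressing-mode name.
  final case class Line(opcode: String, mode: String)

  // Opcodes that have a direct Y counterpart.
  private val opcodeX2Y = Map(
    "LDX" -> "LDY", "STX" -> "STY", "TAX" -> "TAY", "TXA" -> "TYA",
    "INX" -> "INY", "DEX" -> "DEY", "CPX" -> "CPY")

  // Addressing modes that have a direct Y counterpart.
  private val modeX2Y = Map("AbsoluteX" -> "AbsoluteY", "ZeroPageX" -> "ZeroPageY")

  /** Renames every X-based opcode and addressing mode; everything else is kept as-is. */
  def switchX2Y(code: List[Line]): List[Line] =
    code.map(l => Line(opcodeX2Y.getOrElse(l.opcode, l.opcode), modeX2Y.getOrElse(l.mode, l.mode)))
}
```

In the real optimization this renaming is only safe after the whole function has been checked, because some instructions (for example `TXS` or `STZ abs,X`) have no Y equivalent.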

@ -0,0 +1,36 @@
package millfork.assembly.opt
import millfork.assembly.{AssemblyLine, Opcode}
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
import millfork.assembly.OpcodeClasses._
import millfork.env.{Constant, NormalFunction}
/** @author Karol Stasiak */
object CmosOptimizations {
val StzAddrModes = Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX)
val ZeroStoreAsStz = new RuleBasedAssemblyOptimization("Zero store",
needsFlowInfo = FlowInfoRequirement.ForwardFlow,
(HasA(0) & HasOpcode(STA) & Elidable & HasAddrModeIn(StzAddrModes)) ~~> {code =>
code.head.copy(opcode = STZ) :: Nil
(HasX(0) & HasOpcode(STX) & Elidable & HasAddrModeIn(StzAddrModes)) ~~> {code =>
code.head.copy(opcode = STZ) :: Nil
(HasY(0) & HasOpcode(STY) & Elidable & HasAddrModeIn(StzAddrModes)) ~~> {code =>
code.head.copy(opcode = STZ) :: Nil
val OptimizeZeroIndex = new RuleBasedAssemblyOptimization("Optimizing zero index",
needsFlowInfo = FlowInfoRequirement.ForwardFlow,
(Elidable & HasY(0) & HasAddrMode(IndexedY) & HasOpcodeIn(SupportsZeroPageIndirect)) ~~> (code => code.map(_.copy(addrMode = ZeroPageIndirect))),
(Elidable & HasX(0) & HasAddrMode(IndexedX) & HasOpcodeIn(SupportsZeroPageIndirect)) ~~> (code => code.map(_.copy(addrMode = ZeroPageIndirect))),
val All: List[AssemblyOptimization] = List(ZeroStoreAsStz)

@ -0,0 +1,259 @@
package millfork.assembly.opt
import millfork.assembly.{AssemblyLine, OpcodeClasses, State}
import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant}
import scala.collection.immutable
/** @author Karol Stasiak */
sealed trait Status[T] {
def contains(value: T): Boolean
def ~(that: Status[T]): Status[T] = {
(this, that) match {
case (AnyStatus(), _) => AnyStatus()
case (_, AnyStatus()) => AnyStatus()
case (SingleStatus(x), SingleStatus(y)) => if (x == y) SingleStatus(x) else AnyStatus()
case (SingleStatus(x), UnknownStatus()) => SingleStatus(x)
case (UnknownStatus(), SingleStatus(x)) => SingleStatus(x)
case (UnknownStatus(), UnknownStatus()) => UnknownStatus()
object Status {
implicit class IntStatusOps(val inner: Status[Int]) extends AnyVal {
def map[T](f: Int => T): Status[T] = inner match {
case SingleStatus(x) => SingleStatus(f(x))
case _ => AnyStatus()
def z(f: Int => Int = identity): Status[Boolean] = inner match {
case SingleStatus(x) =>
val y = f(x) & 0xff
SingleStatus(y == 0)
case _ => AnyStatus()
def n(f: Int => Int = identity): Status[Boolean] = inner match {
case SingleStatus(x) =>
val y = f(x) & 0xff
SingleStatus(y >= 0x80)
case _ => AnyStatus()
case class SingleStatus[T](t: T) extends Status[T] {
override def contains(value: T): Boolean = t == value
override def toString: String = t match {
case true => "1"
case false => "0"
case _ => t.toString
case class UnknownStatus[T]() extends Status[T] {
override def contains(value: T) = false
override def toString: String = "_"
case class AnyStatus[T]() extends Status[T] {
override def contains(value: T) = false
override def toString: String = "#"
//noinspection RedundantNewCaseClass
case class CpuStatus(a: Status[Int] = UnknownStatus(),
x: Status[Int] = UnknownStatus(),
y: Status[Int] = UnknownStatus(),
z: Status[Boolean] = UnknownStatus(),
n: Status[Boolean] = UnknownStatus(),
c: Status[Boolean] = UnknownStatus(),
v: Status[Boolean] = UnknownStatus(),
d: Status[Boolean] = UnknownStatus(),
) {
override def toString: String = s"A=$a,X=$x,Y=$y,Z=$z,N=$n,C=$c,V=$v,D=$d"
def nz: CpuStatus =
this.copy(n = AnyStatus(), z = AnyStatus())
def nz(i: Long): CpuStatus =
this.copy(n = SingleStatus((i & 0x80) != 0), z = SingleStatus((i & 0xff) == 0))
def ~(that: CpuStatus) = new CpuStatus(
a = this.a ~ that.a,
x = this.x ~ that.x,
y = this.y ~ that.y,
z = this.z ~ that.z,
n = this.n ~ that.n,
c = this.c ~ that.c,
v = this.v ~ that.v,
d = this.d ~ that.d,
def hasClear(state: State.Value): Boolean = state match {
case State.A => a.contains(0)
case State.X => x.contains(0)
case State.Y => y.contains(0)
case State.Z => z.contains(false)
case State.N => n.contains(false)
case State.C => c.contains(false)
case State.V => v.contains(false)
case State.D => d.contains(false)
def hasSet(state: State.Value): Boolean = state match {
case State.A => false
case State.X => false
case State.Y => false
case State.Z => z.contains(true)
case State.N => n.contains(true)
case State.C => c.contains(true)
case State.V => v.contains(true)
case State.D => d.contains(true)
object CoarseFlowAnalyzer {
//noinspection RedundantNewCaseClass
def analyze(f: NormalFunction, code: List[AssemblyLine]): List[CpuStatus] = {
val flagArray = Array.fill[CpuStatus](code.length)(CpuStatus())
val codeArray = code.toArray
val initialStatus = new CpuStatus(d = SingleStatus(false))
var changed = true
while (changed) {
changed = false
var currentStatus: CpuStatus = if (f.interrupt) CpuStatus() else initialStatus
for (i <- codeArray.indices) {
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
if (flagArray(i) != currentStatus) {
changed = true
flagArray(i) = currentStatus
codeArray(i) match {
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) =>
val L = l
currentStatus = codeArray.indices.flatMap(j => codeArray(j) match {
case AssemblyLine(_, _, MemoryAddressConstant(Label(L)), _) => Some(flagArray(j))
case _ => None
}).fold(CpuStatus())(_ ~ _)
case AssemblyLine(BCC, _, _, _) =>
currentStatus = currentStatus.copy(c = currentStatus.c ~ SingleStatus(true))
case AssemblyLine(BCS, _, _, _) =>
currentStatus = currentStatus.copy(c = currentStatus.c ~ SingleStatus(false))
case AssemblyLine(BVS, _, _, _) =>
currentStatus = currentStatus.copy(v = currentStatus.v ~ SingleStatus(false))
case AssemblyLine(BVC, _, _, _) =>
currentStatus = currentStatus.copy(v = currentStatus.v ~ SingleStatus(true))
case AssemblyLine(BMI, _, _, _) =>
currentStatus = currentStatus.copy(n = currentStatus.n ~ SingleStatus(false))
case AssemblyLine(BPL, _, _, _) =>
currentStatus = currentStatus.copy(n = currentStatus.n ~ SingleStatus(true))
case AssemblyLine(BEQ, _, _, _) =>
currentStatus = currentStatus.copy(z = currentStatus.z ~ SingleStatus(false))
case AssemblyLine(BNE, _, _, _) =>
currentStatus = currentStatus.copy(z = currentStatus.z ~ SingleStatus(true))
case AssemblyLine(SED, _, _, _) =>
currentStatus = currentStatus.copy(d = SingleStatus(true))
case AssemblyLine(SEC, _, _, _) =>
currentStatus = currentStatus.copy(c = SingleStatus(true))
case AssemblyLine(CLD, _, _, _) =>
currentStatus = currentStatus.copy(d = SingleStatus(false))
case AssemblyLine(CLC, _, _, _) =>
currentStatus = currentStatus.copy(c = SingleStatus(false))
case AssemblyLine(CLV, _, _, _) =>
currentStatus = currentStatus.copy(v = SingleStatus(false))
case AssemblyLine(JSR, _, _, _) =>
currentStatus = initialStatus
case AssemblyLine(LDX, Immediate, NumericConstant(nn, _), _) =>
val n = nn.toInt
currentStatus = currentStatus.nz(n).copy(x = SingleStatus(n))
case AssemblyLine(LDY, Immediate, NumericConstant(nn, _), _) =>
val n = nn.toInt
currentStatus = currentStatus.nz(n).copy(y = SingleStatus(n))
case AssemblyLine(LDA, Immediate, NumericConstant(nn, _), _) =>
val n = nn.toInt
currentStatus = currentStatus.nz(n).copy(a = SingleStatus(n))
case AssemblyLine(LAX, Immediate, NumericConstant(nn, _), _) =>
val n = nn.toInt
currentStatus = currentStatus.nz(n).copy(a = SingleStatus(n), x = SingleStatus(n))
case AssemblyLine(EOR, Immediate, NumericConstant(nn, _), _) =>
val n = nn.toInt
currentStatus = currentStatus.copy(n = currentStatus.a.n(_ ^ n), z = currentStatus.a.z(_ ^ n), a = currentStatus.a.map(_ ^ n))
case AssemblyLine(AND, Immediate, NumericConstant(nn, _), _) =>
val n = nn.toInt
currentStatus = currentStatus.copy(n = currentStatus.a.n(_ & n), z = currentStatus.a.z(_ & n), a = currentStatus.a.map(_ & n))
case AssemblyLine(ANC, Immediate, NumericConstant(nn, _), _) =>
val n = nn.toInt
// ANC only involves the accumulator: Z comes from A & n, and C is a copy of N
currentStatus = currentStatus.copy(n = currentStatus.a.n(_ & n), c = currentStatus.a.n(_ & n), z = currentStatus.a.z(_ & n), a = currentStatus.a.map(_ & n))
case AssemblyLine(ORA, Immediate, NumericConstant(nn, _), _) =>
val n = nn.toInt
currentStatus = currentStatus.copy(n = currentStatus.a.n(_ | n), z = currentStatus.a.z(_ | n), a = currentStatus.a.map(_ | n))
case AssemblyLine(ALR, Immediate, NumericConstant(nn, _), _) =>
val n = nn.toInt
currentStatus = currentStatus.copy(
n = currentStatus.a.n(i => (i & n & 0xff) >> 1),
z = currentStatus.a.z(i => (i & n & 0xff) >> 1),
// the shift moves bit 0 of A & n into the carry, so C is set when that bit is 1
c = currentStatus.a.map(i => (i & n & 1) != 0),
a = currentStatus.a.map(i => (i & n & 0xff) >> 1))
case AssemblyLine(INX, Implied, _, _) =>
currentStatus = currentStatus.copy(n = currentStatus.x.n(_ + 1), z = currentStatus.x.z(_ + 1), x = currentStatus.x.map(_ + 1))
case AssemblyLine(DEX, Implied, _, _) =>
currentStatus = currentStatus.copy(n = currentStatus.x.n(_ - 1), z = currentStatus.x.z(_ - 1), x = currentStatus.x.map(_ - 1))
case AssemblyLine(INY, Implied, _, _) =>
currentStatus = currentStatus.copy(n = currentStatus.y.n(_ + 1), z = currentStatus.y.z(_ + 1), y = currentStatus.y.map(_ + 1))
case AssemblyLine(DEY, Implied, _, _) =>
currentStatus = currentStatus.copy(n = currentStatus.y.n(_ - 1), z = currentStatus.y.z(_ - 1), y = currentStatus.y.map(_ - 1))
case AssemblyLine(TAX, _, _, _) =>
currentStatus = currentStatus.copy(x = currentStatus.a, n = currentStatus.a.n(), z = currentStatus.a.z())
case AssemblyLine(TXA, _, _, _) =>
currentStatus = currentStatus.copy(a = currentStatus.x, n = currentStatus.x.n(), z = currentStatus.x.z())
case AssemblyLine(TAY, _, _, _) =>
currentStatus = currentStatus.copy(y = currentStatus.a, n = currentStatus.a.n(), z = currentStatus.a.z())
case AssemblyLine(TYA, _, _, _) =>
currentStatus = currentStatus.copy(a = currentStatus.y, n = currentStatus.y.n(), z = currentStatus.y.z())
case AssemblyLine(opcode, addrMode, parameter, _) =>
if (OpcodeClasses.ChangesX(opcode)) currentStatus = currentStatus.copy(x = AnyStatus())
if (OpcodeClasses.ChangesY(opcode)) currentStatus = currentStatus.copy(y = AnyStatus())
if (OpcodeClasses.ChangesAAlways(opcode)) currentStatus = currentStatus.copy(a = AnyStatus())
if (addrMode == Implied && OpcodeClasses.ChangesAIfImplied(opcode)) currentStatus = currentStatus.copy(a = AnyStatus())
if (OpcodeClasses.ChangesNAndZ(opcode)) currentStatus = currentStatus.nz
if (OpcodeClasses.ChangesC(opcode)) currentStatus = currentStatus.copy(c = AnyStatus())
if (OpcodeClasses.ChangesV(opcode)) currentStatus = currentStatus.copy(v = AnyStatus())
if (opcode == CMP || opcode == CPX || opcode == CPY) {
if (addrMode == Immediate) parameter match {
case NumericConstant(0, _) => currentStatus = currentStatus.copy(c = SingleStatus(true))
case _ => ()
// case (fl, y) => if (y.isPrintable) println(f"$fl%-32s $y%-32s")
// }
// println("---------------------")
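The `~` operator on `Status` above joins the facts that arrive at a label from several code paths. A minimal standalone sketch of the same idea follows. The names `NotSeenYet`, `Known`, `Anything` and `merge` are hypothetical, chosen here to spell out how `UnknownStatus`, `SingleStatus` and `AnyStatus` appear to be used; they are not Millfork's API.

```scala
object StatusSketch {
  sealed trait Status[+T]
  case object NotSeenYet extends Status[Nothing]         // no path has contributed a value yet
  final case class Known[T](value: T) extends Status[T]  // all paths so far agree on this value
  case object Anything extends Status[Nothing]           // paths disagree, so nothing can be assumed

  /** Joins the knowledge coming from two control-flow paths. */
  def merge[T](a: Status[T], b: Status[T]): Status[T] = (a, b) match {
    case (Anything, _) | (_, Anything) => Anything
    case (NotSeenYet, other)           => other
    case (other, NotSeenYet)           => other
    case (Known(x), Known(y))          => if (x == y) Known(x) else Anything
  }
}
```

Because a merge can only keep a fact or move it towards `Anything`, repeating the analysis pass until nothing changes, as the `while (changed)` loop above does, eventually reaches a fixed point.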

@ -0,0 +1,59 @@
package millfork.assembly.opt
import millfork.assembly._
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
import millfork.env._
/** @author Karol Stasiak */
object DangerousOptimizations {
val ConstantIndexOffsetPropagation = new RuleBasedAssemblyOptimization("Constant index offset propagation",
// TODO: try to guess when overflow can happen
needsFlowInfo = FlowInfoRequirement.BothFlows,
(Elidable & HasOpcode(CLC)).? ~
(Elidable & HasClear(State.C) & HasOpcode(ADC) & MatchImmediate(0) & DoesntMatterWhatItDoesWith(State.V, State.C)) ~
(HasOpcode(TAY) & DoesntMatterWhatItDoesWith(State.N, State.Z, State.A)) ~
(Linear & Not(ConcernsY)).*
).capture(1) ~
(Elidable & HasAddrMode(AbsoluteY) & DoesntMatterWhatItDoesWith(State.Y)) ~~> { (code, ctx) =>
val last = code.last
ctx.get[List[AssemblyLine]](1) :+ last.copy(parameter = last.parameter.+(ctx.get[Constant](0)).quickSimplify)
(Elidable & HasOpcode(CLC)).? ~
(Elidable & HasClear(State.C) & HasOpcode(ADC) & MatchImmediate(0) & DoesntMatterWhatItDoesWith(State.V, State.C)) ~
(HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.N, State.Z, State.A)) ~
(Linear & Not(ConcernsX)).*
).capture(1) ~
(Elidable & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.X)) ~~> { (code, ctx) =>
val last = code.last
ctx.get[List[AssemblyLine]](1) :+ last.copy(parameter = last.parameter.+(ctx.get[Constant](0)).quickSimplify)
(Elidable & HasOpcode(INY) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~
(Elidable & HasAddrMode(AbsoluteY) & DoesntMatterWhatItDoesWith(State.Y)) ~~> { (code, ctx) =>
val last = code.last
List(last.copy(parameter = last.parameter.+(1).quickSimplify))
(Elidable & HasOpcode(DEY) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~
(Elidable & HasAddrMode(AbsoluteY) & DoesntMatterWhatItDoesWith(State.Y)) ~~> { (code, ctx) =>
val last = code.last
List(last.copy(parameter = last.parameter.+(-1).quickSimplify))
(Elidable & HasOpcode(INX) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~
(Elidable & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.X)) ~~> { (code, ctx) =>
val last = code.last
List(last.copy(parameter = last.parameter.+(1).quickSimplify))
(Elidable & HasOpcode(DEX) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~
(Elidable & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.X)) ~~> { (code, ctx) =>
val last = code.last
List(last.copy(parameter = last.parameter.+(-1).quickSimplify))
val All: List[AssemblyOptimization] = List(ConstantIndexOffsetPropagation)
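The rules above fold a constant change of the index register into the operand address, for example turning `INY` followed by `LDA base,Y` into `LDA base+1,Y`. The sketch below shows, under a simplified address model, where the two forms disagree, which is presumably what the overflow `TODO` refers to. The function names are hypothetical and only illustrate the arithmetic.

```scala
object IndexFoldSketch {
  // Effective address of `LDA base,Y` executed right after INY: Y wraps around at 8 bits.
  def afterIny(base: Int, y: Int): Int = (base + ((y + 1) & 0xff)) & 0xffff

  // Effective address after folding the increment into the operand: `LDA base+1,Y`.
  def folded(base: Int, y: Int): Int = (base + 1 + (y & 0xff)) & 0xffff
}
```

Both functions agree for every `y` below 255. For `y == 255` the original code reads `base` itself, while the folded form reads `base + 256`, so the rewrite is only safe when the index cannot wrap.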

@ -0,0 +1,34 @@
package millfork.assembly.opt
import millfork.{CompilationFlag, CompilationOptions}
import millfork.assembly.{AssemblyLine, State}
import millfork.env.NormalFunction
/** @author Karol Stasiak */
case class FlowInfo(statusBefore: CpuStatus, importanceAfter: CpuImportance) {
def hasClear(state: State.Value): Boolean = statusBefore.hasClear(state)
def hasSet(state: State.Value): Boolean = statusBefore.hasSet(state)
def isUnimportant(state: State.Value): Boolean = importanceAfter.isUnimportant(state)
object FlowInfo {
val Default = FlowInfo(CpuStatus(), CpuImportance())
object FlowAnalyzer {
def analyze(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[(FlowInfo, AssemblyLine)] = {
val forwardFlow = if (options.flag(CompilationFlag.DetailedFlowAnalysis)) {
QuantumFlowAnalyzer.analyze(f, code).map(_.collapse)
} else {
CoarseFlowAnalyzer.analyze(f, code)
val reverseFlow = ReverseFlowAnalyzer.analyze(f, code)
forwardFlow.zip(reverseFlow).map{case (s,i) => FlowInfo(s,i)}.zip(code)

@ -0,0 +1,242 @@
package millfork.assembly.opt
import millfork.assembly.{AddrMode, AssemblyLine, Opcode, State}
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
import millfork.assembly.OpcodeClasses._
import millfork.env.{Constant, NormalFunction, NumericConstant}
/**
 * These optimizations help on their own, but may prevent other optimizations from triggering.
 * @author Karol Stasiak
 */
object LaterOptimizations {
// This optimization tends to prevent the later Variable To Register Optimization (V2RO),
// so run it only once V2RO is unlikely to trigger any more
val DoubleLoadToDifferentRegisters = new RuleBasedAssemblyOptimization("Double load to different registers",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
TwoDifferentLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA), LDX, TAX),
TwoDifferentLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA), LDY, TAY),
TwoDifferentLoadsWithNoFlagChangeInBetween(LDX, Not(ChangesX), LDA, TXA),
TwoDifferentLoadsWithNoFlagChangeInBetween(LDY, Not(ChangesY), LDA, TYA),
TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDA, Not(ChangesA), LDX, TAX),
TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDA, Not(ChangesA), LDY, TAY),
TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDX, Not(ChangesX), LDA, TXA),
TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDY, Not(ChangesY), LDA, TYA),
private def TwoDifferentLoadsWithNoFlagChangeInBetween(opcode1: Opcode.Value, middle: AssemblyLinePattern, opcode2: Opcode.Value, transferOpcode: Opcode.Value) = {
(HasOpcode(opcode1) & MatchAddrMode(0) & MatchParameter(1)) ~
(LinearOrLabel & Not(ChangesMemory) & middle & Not(HasOpcode(opcode2))).* ~
(HasOpcode(opcode2) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { c =>
c.init :+ AssemblyLine.implied(transferOpcode)
private def TwoDifferentLoadsWhoseFlagsWillNotBeChecked(opcode1: Opcode.Value, middle: AssemblyLinePattern, opcode2: Opcode.Value, transferOpcode: Opcode.Value) = {
((HasOpcode(opcode1) & MatchAddrMode(0) & MatchParameter(1)) ~
(LinearOrLabel & Not(ChangesMemory) & middle & Not(HasOpcode(opcode2))).*).capture(2) ~
(HasOpcode(opcode2) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~
((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(3) ~~> { (_, ctx) =>
ctx.get[List[AssemblyLine]](2) ++ (AssemblyLine.implied(transferOpcode) :: ctx.get[List[AssemblyLine]](3))
private def TwoIdenticalLoadsWithNoFlagChangeInBetween(opcode: Opcode.Value, middle: AssemblyLinePattern) = {
(HasOpcode(opcode) & MatchAddrMode(0) & MatchParameter(1)) ~
(LinearOrLabel & Not(ChangesMemory) & middle & Not(ChangesNAndZ)).* ~
(HasOpcode(opcode) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { c =>
private def TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(opcode: Opcode.Value, middle: AssemblyLinePattern) = {
(HasOpcode(opcode) & HasAddrMode(Immediate) & MatchParameter(1)) ~
(LinearOrLabel & middle & Not(ChangesNAndZ)).* ~
(HasOpcode(opcode) & Elidable & HasAddrMode(Immediate) & MatchParameter(1)) ~~> { c =>
private def TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(opcode: Opcode.Value, middle: AssemblyLinePattern) = {
((HasOpcode(opcode) & MatchAddrMode(0) & MatchParameter(1)) ~
(LinearOrLabel & Not(ChangesMemory) & middle).*).capture(2) ~
(HasOpcode(opcode) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~
((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(3) ~~> { (_, ctx) =>
ctx.get[List[AssemblyLine]](2) ++ ctx.get[List[AssemblyLine]](3)
//noinspection ZeroIndexToHead
private def InterleavedImmediateLoads(load: Opcode.Value, store: Opcode.Value) = {
(Elidable & HasOpcode(load) & MatchImmediate(0)) ~
(Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(8)) ~
(Elidable & HasOpcode(load) & MatchImmediate(1)) ~
(Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(9) & DontMatchParameter(8)) ~
(Elidable & HasOpcode(load) & MatchImmediate(0)) ~~> { c =>
List(c(2), c(3), c(0), c(1))
//noinspection ZeroIndexToHead
private def InterleavedAbsoluteLoads(load: Opcode.Value, store: Opcode.Value) = {
(Elidable & HasOpcode(load) & HasAddrMode(Absolute) & MatchParameter(0)) ~
(Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(8) & DontMatchParameter(0)) ~
(Elidable & HasOpcode(load) & HasAddrMode(Absolute) & MatchParameter(1) & DontMatchParameter(8) & DontMatchParameter(0)) ~
(Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(9) & DontMatchParameter(8) & DontMatchParameter(1) & DontMatchParameter(0)) ~
(Elidable & HasOpcode(load) & HasAddrMode(Absolute) & MatchParameter(0)) ~~> { c =>
List(c(2), c(3), c(0), c(1))
val DoubleLoadToTheSameRegister = new RuleBasedAssemblyOptimization("Double load to the same register",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
TwoIdenticalLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA)),
TwoIdenticalLoadsWithNoFlagChangeInBetween(LDX, Not(ChangesX)),
TwoIdenticalLoadsWithNoFlagChangeInBetween(LDY, Not(ChangesY)),
TwoIdenticalLoadsWithNoFlagChangeInBetween(LAX, Not(ChangesA) & Not(ChangesX)),
TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA)),
TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(LDX, Not(ChangesX)),
TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(LDY, Not(ChangesY)),
TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LDA, Not(ChangesA)),
TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LDX, Not(ChangesX)),
TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LDY, Not(ChangesY)),
TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LAX, Not(ChangesA) & Not(ChangesX)),
InterleavedImmediateLoads(LDA, STA),
InterleavedImmediateLoads(LDX, STX),
InterleavedImmediateLoads(LDY, STY),
InterleavedAbsoluteLoads(LDA, STA),
InterleavedAbsoluteLoads(LDX, STX),
InterleavedAbsoluteLoads(LDY, STY),
private def pointlessLoadAfterStore(store: Opcode.Value, load: Opcode.Value, addrMode: AddrMode.Value, meantime: AssemblyLinePattern = Anything) = {
((HasOpcode(store) & HasAddrMode(addrMode) & MatchParameter(1)) ~
(LinearOrBranch & Not(ChangesA) & Not(ChangesMemory) & meantime).*).capture(2) ~
(HasOpcode(load) & Elidable & HasAddrMode(addrMode) & MatchParameter(1)) ~
((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(3) ~~> { (_, ctx) =>
ctx.get[List[AssemblyLine]](2) ++ ctx.get[List[AssemblyLine]](3)
val PointlessLoadAfterStore = new RuleBasedAssemblyOptimization("Pointless load after store",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
pointlessLoadAfterStore(STA, LDA, Absolute),
pointlessLoadAfterStore(STA, LDA, AbsoluteX, Not(ChangesX)),
pointlessLoadAfterStore(STA, LDA, AbsoluteY, Not(ChangesY)),
pointlessLoadAfterStore(STX, LDX, Absolute),
pointlessLoadAfterStore(STY, LDY, Absolute),
private val ShiftAddrModes = Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX)
private val ShiftOpcodes = Set(ASL, ROL, ROR, LSR)
// LDA-SHIFT-STA is slower than just SHIFT
// LDA-SHIFT-SHIFT-STA is equally fast as SHIFT-SHIFT, but the latter doesn't use the accumulator
val PointessLoadingForShifting = new RuleBasedAssemblyOptimization("Pointless loading for shifting",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasOpcode(LDA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~
(Elidable & HasOpcodeIn(ShiftOpcodes) & HasAddrMode(Implied) & MatchOpcode(2)) ~
(Elidable & HasOpcode(STA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~
(Not(ReadsA) & Not(OverwritesA)).* ~ OverwritesA ~~> { (code, ctx) =>
AssemblyLine(ctx.get[Opcode.Value](2), ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) :: code.drop(3)
(Elidable & HasOpcode(LDA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~
(Elidable & HasOpcodeIn(ShiftOpcodes) & HasAddrMode(Implied) & MatchOpcode(2)) ~
(Elidable & HasOpcodeIn(ShiftOpcodes) & HasAddrMode(Implied) & MatchOpcode(2)) ~
(Elidable & HasOpcode(STA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~
(Not(ReadsA) & Not(OverwritesA)).* ~ OverwritesA ~~> { (code, ctx) =>
val shift = AssemblyLine(ctx.get[Opcode.Value](2), ctx.get[AddrMode.Value](0), ctx.get[Constant](1))
shift :: shift :: code.drop(4)
// SHIFT-LDA is as fast as LDA-SHIFT-STA, but the latter form can enable further optimizations
// LDA-SHIFT-SHIFT-STA is equally fast as SHIFT-SHIFT, but the latter doesn't use the accumulator
val LoadingAfterShifting = new RuleBasedAssemblyOptimization("Loading after shifting",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasOpcodeIn(ShiftOpcodes) & MatchAddrMode(0) & MatchParameter(1)) ~
(Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
AssemblyLine(LDA, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) ::
AssemblyLine.implied(code.head.opcode) ::
AssemblyLine(STA, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) ::
val UseZeropageAddressingMode = new RuleBasedAssemblyOptimization("Using zeropage addressing mode",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasAddrMode(Absolute) & MatchParameter(0)) ~ Where(ctx => ctx.get[Constant](0).quickSimplify match {
case NumericConstant(x, _) => (x & 0xff00) == 0
case _ => false
}) ~~> (code => code.head.copy(addrMode = ZeroPage) :: Nil)
val UseXInsteadOfStack = new RuleBasedAssemblyOptimization("Using X instead of stack",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
(Elidable & HasOpcode(PHA) & DoesntMatterWhatItDoesWith(State.X)) ~
(Not(ConcernsStack) & Not(ConcernsX)).capture(1) ~
Where(_.isExternallyLinearBlock(1)) ~
(Elidable & HasOpcode(PLA)) ~~> (c =>
AssemblyLine.implied(TAX) :: (c.tail.init :+ AssemblyLine.implied(TXA))
val UseYInsteadOfStack = new RuleBasedAssemblyOptimization("Using Y instead of stack",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
(Elidable & HasOpcode(PHA) & DoesntMatterWhatItDoesWith(State.Y)) ~
(Not(ConcernsStack) & Not(ConcernsY)).capture(1) ~
Where(_.isExternallyLinearBlock(1)) ~
(Elidable & HasOpcode(PLA)) ~~> (c =>
AssemblyLine.implied(TAY) :: (c.tail.init :+ AssemblyLine.implied(TYA))
// TODO: make it more generic
val IndexSwitchingOptimization = new RuleBasedAssemblyOptimization("Index switching optimization",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
(Elidable & HasOpcode(LDY) & MatchAddrMode(2) & Not(ReadsX) & MatchParameter(0)) ~
(Elidable & Linear & Not(ChangesY) & HasAddrMode(AbsoluteY) & SupportsAbsoluteX & Not(ConcernsX)) ~
(HasOpcode(LDY) & Not(ConcernsX)) ~
(Linear & Not(ChangesY) & Not(ConcernsX) & HasAddrModeIn(Set(AbsoluteY, IndexedY, ZeroPageY))) ~
(Elidable & HasOpcode(LDY) & MatchAddrMode(2) & Not(ReadsX) & MatchParameter(0)) ~
(Elidable & Linear & Not(ChangesY) & HasAddrMode(AbsoluteY) & SupportsAbsoluteX & Not(ConcernsX) & DoesntMatterWhatItDoesWith(State.X, State.N, State.Z)) ~~> { (code, ctx) =>
List(
code(0).copy(opcode = LDX),
code(1).copy(addrMode = AbsoluteX),
code(2),
code(3),
code(5).copy(addrMode = AbsoluteX))
(Elidable & HasOpcode(LDX) & MatchAddrMode(2) & Not(ReadsY) & MatchParameter(0)) ~
(Elidable & Linear & Not(ChangesX) & HasAddrMode(AbsoluteX) & SupportsAbsoluteY & Not(ConcernsY)) ~
(HasOpcode(LDX) & Not(ConcernsY)) ~
(Linear & Not(ChangesX) & Not(ConcernsY) & HasAddrModeIn(Set(AbsoluteX, IndexedX, ZeroPageX, AbsoluteIndexedX))) ~
(Elidable & HasOpcode(LDX) & MatchAddrMode(2) & Not(ReadsY) & MatchParameter(0)) ~
(Elidable & Linear & Not(ChangesX) & HasAddrMode(AbsoluteX) & SupportsAbsoluteY & Not(ConcernsY) & DoesntMatterWhatItDoesWith(State.Y, State.N, State.Z)) ~~> { (code, ctx) =>
List(
code(0).copy(opcode = LDY),
code(1).copy(addrMode = AbsoluteY),
code(2),
code(3),
code(5).copy(addrMode = AbsoluteY))
val All = List(


@ -0,0 +1,425 @@
package millfork.assembly.opt
import millfork.assembly.{AssemblyLine, OpcodeClasses}
import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant}
import scala.collection.immutable.BitSet
/** @author Karol Stasiak */
object QCpuStatus {
val InitialStatus = QCpuStatus((for {
c <- Seq(true, false)
v <- Seq(true, false)
n <- Seq(true, false)
z <- Seq(true, false)
} yield QFlagStatus(c = c, d = false, v = v, n = n, z = z) -> QRegStatus(a = QRegStatus.AllValues, x = QRegStatus.AllValues, y = QRegStatus.AllValues, equal = RegEquality.NoEquality)).toMap)
val UnknownStatus = QCpuStatus((for {
c <- Seq(true, false)
v <- Seq(true, false)
n <- Seq(true, false)
z <- Seq(true, false)
} yield QFlagStatus(c = c, d = false, v = v, n = n, z = z) -> QRegStatus(a = QRegStatus.AllValues, x = QRegStatus.AllValues, y = QRegStatus.AllValues, equal = RegEquality.UnknownEquality)).toMap)
def gather(l: List[(QFlagStatus, QRegStatus)]) =
QCpuStatus(l.groupBy(_._1).
map { case (k, vs) => k -> vs.map(_._2).reduce(_ ++ _) }.
filterNot(_._2.isEmpty))
case class QCpuStatus(data: Map[QFlagStatus, QRegStatus]) {
def collapse: CpuStatus = {
val registers = data.values.reduce(_ ++ _)
def bitset(b: BitSet): Status[Int] = if (b.size == 1) SingleStatus(b.head) else AnyStatus()
def flag(f: QFlagStatus => Boolean): Status[Boolean] =
if (data.keys.forall(k => f(k))) SingleStatus(true)
else if (data.keys.forall(k => !f(k))) SingleStatus(false)
else AnyStatus()
a = bitset(registers.a),
x = bitset(registers.x),
y = bitset(registers.y),
c = flag(_.c),
d = flag(_.d),
v = flag(_.v),
z = flag(_.z),
n = flag(_.n),
def changeFlagUnconditionally(f: QFlagStatus => QFlagStatus): QCpuStatus = {
QCpuStatus.gather(data.toList.map { case (k, v) => f(k) -> v })
def changeFlagsInAnUnknownWay(f: QFlagStatus => QFlagStatus, g: QFlagStatus => QFlagStatus): QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) => List(f(k) -> v, g(k) -> v) })
def changeFlagsInAnUnknownWay(f: QFlagStatus => QFlagStatus, g: QFlagStatus => QFlagStatus, h: QFlagStatus => QFlagStatus): QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) => List(f(k) -> v, g(k) -> v, h(k) -> v) })
def mapRegisters(f: QRegStatus => QRegStatus): QCpuStatus = {
QCpuStatus(data.map { case (k, v) => k -> f(v) })
def mapRegisters(f: (QFlagStatus, QRegStatus) => QRegStatus): QCpuStatus = {
QCpuStatus(data.map { case (k, v) => k -> f(k, v) })
def flatMap(f: (QFlagStatus, QRegStatus) => List[(QFlagStatus, QRegStatus)]): QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) => f(k, v) })
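// Illustrative note on the changeNZFrom* helpers below (a sketch, assuming register
// values are stored as 0..255 in the BitSet): `i.toByte` reinterprets the value as a
// signed byte, so the three branches partition the possible values as follows:
//   0x01..0x7f -> toByte > 0  -> N = false, Z = false
//   0x80..0xff -> toByte < 0  -> N = true,  Z = false   (e.g. 0x80.toByte == -128)
//   0x00       -> toByte == 0 -> N = false, Z = true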
def changeNZFromA: QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) =>
k.copy(n = false, z = false) -> v.whereA(i => i.toByte > 0),
k.copy(n = true, z = false) -> v.whereA(i => i.toByte < 0),
k.copy(n = false, z = true) -> v.whereA(i => i.toByte == 0))
def changeNZFromX: QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) =>
k.copy(n = false, z = false) -> v.whereX(i => i.toByte > 0),
k.copy(n = true, z = false) -> v.whereX(i => i.toByte < 0),
k.copy(n = false, z = true) -> v.whereX(i => i.toByte == 0))
def changeNZFromY: QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) =>
k.copy(n = false, z = false) -> v.whereY(i => i.toByte > 0),
k.copy(n = true, z = false) -> v.whereY(i => i.toByte < 0),
k.copy(n = false, z = true) -> v.whereY(i => i.toByte == 0))
def ~(that: QCpuStatus): QCpuStatus = QCpuStatus.gather(this.data.toList ++ that.data.toList)
object QRegStatus {
val NoValues: BitSet = BitSet.empty
val AllValues: BitSet = BitSet.fromBitMask(Array(-1L, -1L, -1L, -1L))
object RegEquality extends Enumeration {
val NoEquality, AX, AY, XY, AXY, UnknownEquality = Value
def or(a: Value, b: Value) = {
(a, b) match {
case (UnknownEquality, _) => b
case (_, UnknownEquality) => a
case (NoEquality, _) => NoEquality
case (_, NoEquality) => NoEquality
case (_, _) if a == b => a
case (AXY, _) => b
case (_, AXY) => a
case _ => NoEquality
def afterTransfer(a: Value, b: Value) = {
(a, b) match {
case (UnknownEquality, _) => b
case (_, UnknownEquality) => a
case (NoEquality, _) => b
case (_, NoEquality) => a
case (_, _) if a == b => a
case _ => AXY
case class QRegStatus(a: BitSet, x: BitSet, y: BitSet, equal: RegEquality.Value) {
def isEmpty: Boolean = a.isEmpty || x.isEmpty || y.isEmpty
def ++(that: QRegStatus) = QRegStatus(
a = a ++ that.a,
x = x ++ that.x,
y = y ++ that.y,
equal = RegEquality.or(equal, that.equal))
def afterTransfer(transfer: RegEquality.Value): QRegStatus =
copy(equal = RegEquality.afterTransfer(equal, transfer))
def changeA(f: Int => Long): QRegStatus = {
val newA = a.map(i => f(i).toInt & 0xff)
val newEqual = equal match {
case RegEquality.XY => RegEquality.XY
case RegEquality.AXY => RegEquality.XY
case _ => RegEquality.NoEquality
QRegStatus(newA, x, y, newEqual)
def changeX(f: Int => Long): QRegStatus = {
// only X changes, so an equality between A and Y survives
val newX = x.map(i => f(i).toInt & 0xff)
val newEqual = equal match {
case RegEquality.AY => RegEquality.AY
case RegEquality.AXY => RegEquality.AY
case _ => RegEquality.NoEquality
QRegStatus(a, newX, y, newEqual)
def changeY(f: Int => Long): QRegStatus = {
// only Y changes, so an equality between A and X survives
val newY = y.map(i => f(i).toInt & 0xff)
val newEqual = equal match {
case RegEquality.AX => RegEquality.AX
case RegEquality.AXY => RegEquality.AX
case _ => RegEquality.NoEquality
QRegStatus(a, x, newY, newEqual)
def whereA(f: Int => Boolean): QRegStatus =
equal match {
case RegEquality.AXY =>
copy(a = a.filter(f), x = x.filter(f), y = y.filter(f))
case RegEquality.AY =>
copy(a = a.filter(f), y = y.filter(f))
case RegEquality.AX =>
copy(a = a.filter(f), x = x.filter(f))
case _ =>
copy(a = a.filter(f))
def whereX(f: Int => Boolean): QRegStatus =
equal match {
case RegEquality.AXY =>
copy(a = a.filter(f), x = x.filter(f), y = y.filter(f))
case RegEquality.XY =>
copy(x = x.filter(f), y = y.filter(f))
case RegEquality.AX =>
copy(a = a.filter(f), x = x.filter(f))
case _ =>
copy(x = x.filter(f))
def whereY(f: Int => Boolean): QRegStatus =
equal match {
case RegEquality.AXY =>
copy(a = a.filter(f), x = x.filter(f), y = y.filter(f))
case RegEquality.AY =>
copy(a = a.filter(f), y = y.filter(f))
case RegEquality.XY =>
copy(x = x.filter(f), y = y.filter(f))
case _ =>
copy(y = y.filter(f))
case class QFlagStatus(c: Boolean, d: Boolean, v: Boolean, z: Boolean, n: Boolean)
object QuantumFlowAnalyzer {
private def loBit(b: Boolean) = if (b) 1 else 0
private def hiBit(b: Boolean) = if (b) 0x80 else 0
//noinspection RedundantNewCaseClass
def analyze(f: NormalFunction, code: List[AssemblyLine]): List[QCpuStatus] = {
val flagArray = Array.fill[QCpuStatus](code.length)(QCpuStatus.UnknownStatus)
val codeArray = code.toArray
var changed = true
while (changed) {
changed = false
var currentStatus: QCpuStatus = if (f.interrupt) QCpuStatus.UnknownStatus else QCpuStatus.UnknownStatus
for (i <- codeArray.indices) {
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
if (flagArray(i) != currentStatus) {
changed = true
flagArray(i) = currentStatus
codeArray(i) match {
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) =>
val L = l
currentStatus = codeArray.indices.flatMap(j => codeArray(j) match {
case AssemblyLine(_, _, MemoryAddressConstant(Label(L)), _) => Some(flagArray(j))
case _ => None
}).fold(QCpuStatus.UnknownStatus)(_ ~ _)
case AssemblyLine(BCC, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = true))
case AssemblyLine(BCS, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = false))
case AssemblyLine(BVS, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = false))
case AssemblyLine(BVC, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = true))
case AssemblyLine(BMI, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(n = false))
case AssemblyLine(BPL, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(n = true))
case AssemblyLine(BEQ, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(z = false))
case AssemblyLine(BNE, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(z = true))
case AssemblyLine(SED, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(d = true))
case AssemblyLine(SEC, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = true))
case AssemblyLine(CLD, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(d = false))
case AssemblyLine(CLC, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = false))
case AssemblyLine(CLV, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = false))
case AssemblyLine(JSR, _, _, _) =>
currentStatus = QCpuStatus.InitialStatus
case AssemblyLine(LDX, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeX(_ => n)).changeNZFromX
case AssemblyLine(LDY, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeY(_ => n)).changeNZFromY
case AssemblyLine(LDA, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ => n)).changeNZFromA
case AssemblyLine(LAX, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ => n).changeX(_ => n).afterTransfer(RegEquality.AX)).changeNZFromA
case AssemblyLine(EOR, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ ^ n)).changeNZFromA
case AssemblyLine(AND, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ & n)).changeNZFromA
case AssemblyLine(ANC, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ & n)).changeNZFromA.changeFlagUnconditionally(f => f.copy(c = f.n)) // ANC copies bit 7 of the result (N) into C
case AssemblyLine(ORA, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ | n)).changeNZFromA
case AssemblyLine(INX, Implied, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeX(_ + 1)).changeNZFromX
case AssemblyLine(DEX, Implied, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeX(_ - 1)).changeNZFromX
case AssemblyLine(INY, Implied, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeY(_ + 1)).changeNZFromY
case AssemblyLine(DEY, Implied, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeY(_ - 1)).changeNZFromY
case AssemblyLine(TAX, _, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.copy(x = r.a).afterTransfer(RegEquality.AX)).changeNZFromX
case AssemblyLine(TXA, _, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.copy(a = r.x).afterTransfer(RegEquality.AX)).changeNZFromA
case AssemblyLine(TAY, _, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.copy(y = r.a).afterTransfer(RegEquality.AY)).changeNZFromY
case AssemblyLine(TYA, _, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.copy(a = r.y).afterTransfer(RegEquality.AY)).changeNZFromA
case AssemblyLine(ROL, Implied, _, _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & 0x80) != 0).changeA(a => a * 2 + loBit(f.c)),
f.copy(c = false) -> r.whereA(a => (a & 0x80) == 0).changeA(a => a * 2 + loBit(f.c)),
case AssemblyLine(ROR, Implied, _, _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & 1) != 0).changeA(a => (a >>> 1) & 0x7f | hiBit(f.c)),
f.copy(c = false) -> r.whereA(a => (a & 1) == 0).changeA(a => (a >>> 1) & 0x7f | hiBit(f.c)),
case AssemblyLine(ASL, Implied, _, _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & 0x80) != 0).changeA(a => a * 2),
f.copy(c = false) -> r.whereA(a => (a & 0x80) == 0).changeA(a => a * 2),
case AssemblyLine(LSR, Implied, _, _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & 1) != 0).changeA(a => (a >>> 1) & 0x7f),
f.copy(c = false) -> r.whereA(a => (a & 1) == 0).changeA(a => (a >>> 1) & 0x7f),
case AssemblyLine(ALR, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & n & 1) != 0).changeA(a => ((a & n) >>> 1) & 0x7f),
f.copy(c = false) -> r.whereA(a => (a & n & 1) == 0).changeA(a => ((a & n) >>> 1) & 0x7f),
case AssemblyLine(ADC, Immediate, NumericConstant(nn, _), _) =>
val n = nn & 0xff
currentStatus = currentStatus.flatMap((f, r) =>
if (f.d) {
val regs = r.copy(a = QRegStatus.AllValues).changeA(_.toLong)
f.copy(c = false, v = false) -> regs,
f.copy(c = true, v = false) -> regs,
f.copy(c = false, v = true) -> regs,
f.copy(c = true, v = true) -> regs,
} else {
if (f.c) {
// carry in: result = (a + n + 1) & 0xff, carry out iff result <= n
val regs = r.changeA(_ + n + 1)
f.copy(c = false, v = false) -> regs.whereA(_ > n),
f.copy(c = true, v = false) -> regs.whereA(_ <= n),
f.copy(c = false, v = true) -> regs.whereA(_ > n),
f.copy(c = true, v = true) -> regs.whereA(_ <= n),
} else {
// no carry in: result = (a + n) & 0xff, carry out iff result < n
val regs = r.changeA(_ + n)
f.copy(c = false, v = false) -> regs.whereA(_ >= n),
f.copy(c = true, v = false) -> regs.whereA(_ < n),
f.copy(c = false, v = true) -> regs.whereA(_ >= n),
f.copy(c = true, v = true) -> regs.whereA(_ < n),
case AssemblyLine(SBC, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.flatMap((f, r) =>
if (f.d) {
val regs = r.copy(a = QRegStatus.AllValues).changeA(_.toLong)
// TODO: guess the carry flag correctly
f.copy(c = false, v = false) -> regs,
f.copy(c = true, v = false) -> regs,
f.copy(c = false, v = true) -> regs,
f.copy(c = true, v = true) -> regs,
} else {
val regs = if (f.c) r.changeA(_ - n) else r.changeA(_ - n - 1)
f.copy(c = false, v = false) -> regs,
f.copy(c = true, v = false) -> regs,
f.copy(c = false, v = true) -> regs,
f.copy(c = true, v = true) -> regs,
case AssemblyLine(opcode, addrMode, parameter, _) =>
if (OpcodeClasses.ChangesX(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(x = QRegStatus.AllValues))
if (OpcodeClasses.ChangesY(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(y = QRegStatus.AllValues))
if (OpcodeClasses.ChangesAAlways(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(a = QRegStatus.AllValues))
if (addrMode == Implied && OpcodeClasses.ChangesAIfImplied(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(a = QRegStatus.AllValues))
if (OpcodeClasses.ChangesNAndZ(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(
_.copy(n = false, z = false),
_.copy(n = true, z = false),
_.copy(n = false, z = true))
if (OpcodeClasses.ChangesC(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(_.copy(c = false), _.copy(c = true))
if (OpcodeClasses.ChangesV(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(_.copy(v = false), _.copy(v = true))
if (opcode == CMP || opcode == CPX || opcode == CPY) {
if (addrMode == Immediate) parameter match {
case NumericConstant(0, _) => currentStatus = currentStatus.changeFlagUnconditionally(_.copy(c = true))
case _ => ()
// case (fl, y) => if (y.isPrintable) println(f"$fl%-32s $y%-32s")
// }
// println("---------------------")


@ -0,0 +1,149 @@
package millfork.assembly.opt
import millfork.assembly.{AssemblyLine, OpcodeClasses, State}
import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant}
import scala.collection.immutable
/** @author Karol Stasiak */
sealed trait Importance {
def ~(that: Importance) = (this, that) match {
case (_, Important) | (Important, _) => Important
case (_, Unimportant) | (Unimportant, _) => Unimportant
case (UnknownImportance, UnknownImportance) => UnknownImportance
case object Important extends Importance {
override def toString = "!"
case object Unimportant extends Importance {
override def toString = "*"
case object UnknownImportance extends Importance {
override def toString = "?"
//noinspection RedundantNewCaseClass
case class CpuImportance(a: Importance = UnknownImportance,
x: Importance = UnknownImportance,
y: Importance = UnknownImportance,
n: Importance = UnknownImportance,
z: Importance = UnknownImportance,
v: Importance = UnknownImportance,
c: Importance = UnknownImportance,
d: Importance = UnknownImportance,
) {
override def toString: String = s"A=$a,X=$x,Y=$y,Z=$z,N=$n,C=$c,V=$v,D=$d"
def ~(that: CpuImportance) = new CpuImportance(
a = this.a ~ that.a,
x = this.x ~ that.x,
y = this.y ~ that.y,
z = this.z ~ that.z,
n = this.n ~ that.n,
c = this.c ~ that.c,
v = this.v ~ that.v,
d = this.d ~ that.d,
def isUnimportant(state: State.Value): Boolean = state match {
case State.A => a == Unimportant
case State.X => x == Unimportant
case State.Y => y == Unimportant
case State.Z => z == Unimportant
case State.N => n == Unimportant
case State.C => c == Unimportant
case State.V => v == Unimportant
case State.D => d == Unimportant
object ReverseFlowAnalyzer {
//noinspection RedundantNewCaseClass
def analyze(f: NormalFunction, code: List[AssemblyLine]): List[CpuImportance] = {
val importanceArray = Array.fill[CpuImportance](code.length)(new CpuImportance())
val codeArray = code.toArray
val initialStatus = new CpuStatus(d = SingleStatus(false))
var changed = true
val finalImportance = new CpuImportance(a = Important, x = Important, y = Important, c = Important, v = Important, d = Important, z = Important, n = Important)
changed = true
while (changed) {
changed = false
var currentImportance: CpuImportance = finalImportance
for (i <- codeArray.indices.reverse) {
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
if (importanceArray(i) != currentImportance) {
changed = true
importanceArray(i) = currentImportance
codeArray(i) match {
case AssemblyLine(opcode, Relative, MemoryAddressConstant(Label(l)), _) if OpcodeClasses.ShortBranching(opcode) =>
val L = l
val labelIndex = codeArray.indexWhere {
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(L)), _) => true
case _ => false
currentImportance = if (labelIndex < 0) finalImportance else importanceArray(labelIndex) ~ currentImportance
case _ =>
codeArray(i) match {
case AssemblyLine(JMP, Absolute, MemoryAddressConstant(Label(l)), _) =>
val L = l
val labelIndex = codeArray.indexWhere {
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(L)), _) => true
case _ => false
currentImportance = if (labelIndex < 0) finalImportance else importanceArray(labelIndex)
case AssemblyLine(JMP, Indirect, _, _) =>
currentImportance = finalImportance
case AssemblyLine(BNE | BEQ, _, _, _) =>
currentImportance = currentImportance.copy(z = Important)
case AssemblyLine(BMI | BPL, _, _, _) =>
currentImportance = currentImportance.copy(n = Important)
case AssemblyLine(SED | CLD, _, _, _) =>
currentImportance = currentImportance.copy(d = Unimportant)
case AssemblyLine(RTS, _, _, _) =>
currentImportance = finalImportance
case AssemblyLine(DISCARD_XF, _, _, _) =>
currentImportance = currentImportance.copy(x = Unimportant, n = Unimportant, z = Unimportant, c = Unimportant, v = Unimportant)
case AssemblyLine(DISCARD_YF, _, _, _) =>
currentImportance = currentImportance.copy(y = Unimportant, n = Unimportant, z = Unimportant, c = Unimportant, v = Unimportant)
case AssemblyLine(DISCARD_AF, _, _, _) =>
currentImportance = currentImportance.copy(a = Unimportant, n = Unimportant, z = Unimportant, c = Unimportant, v = Unimportant)
case AssemblyLine(opcode, addrMode, _, _) =>
if (OpcodeClasses.ChangesC(opcode)) currentImportance = currentImportance.copy(c = Unimportant)
if (OpcodeClasses.ChangesV(opcode)) currentImportance = currentImportance.copy(v = Unimportant)
if (OpcodeClasses.ChangesNAndZ(opcode)) currentImportance = currentImportance.copy(n = Unimportant, z = Unimportant)
if (OpcodeClasses.OverwritesA(opcode)) currentImportance = currentImportance.copy(a = Unimportant)
if (OpcodeClasses.OverwritesX(opcode)) currentImportance = currentImportance.copy(x = Unimportant)
if (OpcodeClasses.OverwritesY(opcode)) currentImportance = currentImportance.copy(y = Unimportant)
if (OpcodeClasses.ReadsC(opcode)) currentImportance = currentImportance.copy(c = Important)
if (OpcodeClasses.ReadsD(opcode)) currentImportance = currentImportance.copy(d = Important)
if (OpcodeClasses.ReadsV(opcode)) currentImportance = currentImportance.copy(v = Important)
if (OpcodeClasses.ReadsXAlways(opcode)) currentImportance = currentImportance.copy(x = Important)
if (OpcodeClasses.ReadsYAlways(opcode)) currentImportance = currentImportance.copy(y = Important)
if (OpcodeClasses.ReadsAAlways(opcode)) currentImportance = currentImportance.copy(a = Important)
if (OpcodeClasses.ReadsAIfImplied(opcode) && addrMode == Implied) currentImportance = currentImportance.copy(a = Important)
if (addrMode == AbsoluteX || addrMode == IndexedX || addrMode == ZeroPageX) currentImportance = currentImportance.copy(x = Important)
if (addrMode == AbsoluteY || addrMode == IndexedY || addrMode == ZeroPageY) currentImportance = currentImportance.copy(y = Important)
// case (i, y) => if (y.isPrintable) println(f"$y%-32s $i%-32s")
// }
// println("---------------------")


@ -0,0 +1,757 @@
package millfork.assembly.opt
import millfork.{CompilationFlag, CompilationOptions}
import millfork.assembly._
import millfork.env._
import millfork.error.ErrorReporting
import scala.collection.mutable
/** @author Karol Stasiak */
object FlowInfoRequirement extends Enumeration {
val NoRequirement, BothFlows, ForwardFlow, BackwardFlow = Value
def assertForward(x: FlowInfoRequirement.Value): Unit = x match {
case BothFlows | ForwardFlow => ()
case NoRequirement | BackwardFlow => ErrorReporting.fatal("Forward flow info required")
def assertBackward(x: FlowInfoRequirement.Value): Unit = x match {
case BothFlows | BackwardFlow => ()
case NoRequirement | ForwardFlow => ErrorReporting.fatal("Backward flow info required")
class RuleBasedAssemblyOptimization(val name: String, val needsFlowInfo: FlowInfoRequirement.Value, val rules: AssemblyRule*) extends AssemblyOptimization {
override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
val effectiveCode = code.map(a => a.copy(parameter = a.parameter.quickSimplify))
val taggedCode = needsFlowInfo match {
case FlowInfoRequirement.NoRequirement => effectiveCode.map(FlowInfo(CpuStatus(), CpuImportance()) -> _)
case FlowInfoRequirement.BothFlows => FlowAnalyzer.analyze(f, effectiveCode, options)
case FlowInfoRequirement.ForwardFlow =>
if (options.flag(CompilationFlag.DetailedFlowAnalysis)) {
QuantumFlowAnalyzer.analyze(f, code).map(s => FlowInfo(s.collapse, CpuImportance())).zip(code)
} else {
CoarseFlowAnalyzer.analyze(f, code).map(s => FlowInfo(s, CpuImportance())).zip(code)
case FlowInfoRequirement.BackwardFlow =>
ReverseFlowAnalyzer.analyze(f, code).map(i => FlowInfo(CpuStatus(), i)).zip(code)
optimizeImpl(f, taggedCode, options)
def optimizeImpl(f: NormalFunction, code: List[(FlowInfo, AssemblyLine)], options: CompilationOptions): List[AssemblyLine] = {
code match {
case Nil => Nil
case head :: tail =>
for ((rule, index) <- rules.zipWithIndex) {
val ctx = new AssemblyMatchingContext
rule.pattern.matchTo(ctx, code) match {
case Some(rest: List[(FlowInfo, AssemblyLine)]) =>
val matchedChunkToOptimize: List[AssemblyLine] = code.take(code.length - rest.length).map(_._2)
val optimizedChunk: List[AssemblyLine] = rule.result(matchedChunkToOptimize, ctx)
ErrorReporting.debug(s"Applied $name ($index)")
if (needsFlowInfo != FlowInfoRequirement.NoRequirement) {
val before = code.head._1.statusBefore
val after = code(matchedChunkToOptimize.length - 1)._1.importanceAfter
ErrorReporting.trace(s"Before: $before")
ErrorReporting.trace(s"After: $after")
matchedChunkToOptimize.filter(_.isPrintable).foreach(l => ErrorReporting.trace(l.toString))
ErrorReporting.trace(" ↓")
optimizedChunk.filter(_.isPrintable).foreach(l => ErrorReporting.trace(l.toString))
if (needsFlowInfo != FlowInfoRequirement.NoRequirement) {
return optimizedChunk ++ optimizeImpl(f, rest, options)
} else {
return optimize(f, optimizedChunk ++ rest.map(_._2), options)
case None => ()
head._2 :: optimizeImpl(f, tail, options)
class AssemblyMatchingContext {
private val map = mutable.Map[Int, Any]()
def addObject(i: Int, o: Any): Boolean = {
if (map.contains(i)) {
map(i) == o
} else {
map(i) = o
true
def dontMatch(i: Int, o: Any): Boolean = {
if (map.contains(i)) {
map(i) != o
} else {
true
def get[T: Manifest](i: Int): T = {
val t = map(i)
val clazz = implicitly[Manifest[T]].runtimeClass match {
case java.lang.Integer.TYPE => classOf[java.lang.Integer]
case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
case x => x
if (clazz.isInstance(t)) {
t.asInstanceOf[T]
} else {
if (t == null) {
ErrorReporting.fatal(s"Value at index $i is null")
} else {
ErrorReporting.fatal(s"Value at index $i is a ${t.getClass.getSimpleName}, not a ${clazz.getSimpleName}")
def isExternallyLinearBlock(i: Int): Boolean = {
val labels = mutable.Set[String]()
val jumps = mutable.Set[String]()
get[List[AssemblyLine]](i).foreach {
case AssemblyLine(Opcode.RTS | Opcode.RTI | Opcode.BRK, _, _, _) =>
return false
case AssemblyLine(Opcode.JMP, AddrMode.Indirect, _, _) =>
return false
case AssemblyLine(Opcode.LABEL, _, MemoryAddressConstant(Label(l)), _) =>
labels += l
case AssemblyLine(Opcode.JMP, AddrMode.Absolute, MemoryAddressConstant(Label(l)), _) =>
jumps += l
case AssemblyLine(Opcode.JMP, AddrMode.Absolute, _, _) =>
return false
case AssemblyLine(_, AddrMode.Relative, MemoryAddressConstant(Label(l)), _) =>
jumps += l
case AssemblyLine(br, _, _, _) if OpcodeClasses.ShortBranching(br) =>
return false
case _ => ()
// if a jump leads inside the block, then it's internal
// if a jump leads outside the block, then it's external
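// Worked example (illustrative label names, not taken from the original source):
//   labels = Set("loop"), jumps = Set("loop")        => jumps -- labels == Set()      (only internal jumps)
//   labels = Set("loop"), jumps = Set("loop", "end") => jumps -- labels == Set("end") (one external jump)
// The block is externally linear only when no jump target is left after removing the block's own labels.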
jumps --= labels
jumps.isEmpty
def areMemoryReferencesProvablyNonOverlapping(param1: Int, addrMode1: Int, param2: Int, addrMode2: Int): Boolean = {
val p1 = get[Constant](param1).quickSimplify
val a1 = get[AddrMode.Value](addrMode1)
val p2 = get[Constant](param2).quickSimplify
val a2 = get[AddrMode.Value](addrMode2)
import AddrMode._
val badAddrModes = Set(IndexedX, IndexedY, ZeroPageIndirect, AbsoluteIndexedX)
if (badAddrModes(a1) || badAddrModes(a2)) return false
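// Reading aid for handleKnownDistance below (a sketch of the reasoning, assuming
// distance = address2 - address1 and that an X/Y-indexed access may touch base..base+255):
//   (false, false): two fixed addresses overlap only when distance == 0
//   (true,  false): access 1 covers p1..p1+255, so p2 is safe when distance < 0 or distance > 255
//   (false, true):  access 2 covers p2..p2+255, so p1 is safe when distance > 0 or distance < -255
//   (true,  true):  both ranges are 256 bytes wide, so they are disjoint only when |distance| > 255
// e.g. STA $0400,X against LDA $0500 gives distance = 0x100 = 256 > 255: provably non-overlapping.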
def handleKnownDistance(distance: Short): Boolean = {
val indexingAddrModes = Set(AbsoluteIndexedX, AbsoluteX, ZeroPageX, AbsoluteY, ZeroPageY)
val a1Indexing = indexingAddrModes(a1)
val a2Indexing = indexingAddrModes(a2)
(a1Indexing, a2Indexing) match {
case (false, false) => distance != 0
case (true, false) => distance > 255 || distance < 0
case (false, true) => distance > 0 || distance < -255
case (true, true) => distance > 255 || distance < -255
(p1, p2) match {
case (NumericConstant(n1, _), NumericConstant(n2, _)) =>
handleKnownDistance((n2 - n1).toShort)
case (a, CompoundConstant(MathOperator.Plus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify =>
handleKnownDistance(distance.toShort)
case (CompoundConstant(MathOperator.Plus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify =>
handleKnownDistance((-distance).toShort)
case (a, CompoundConstant(MathOperator.Minus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify =>
handleKnownDistance((-distance).toShort)
case (CompoundConstant(MathOperator.Minus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify =>
handleKnownDistance(distance.toShort)
case (MemoryAddressConstant(MemoryVariable(a, _, _)), MemoryAddressConstant(MemoryVariable(b, _, _))) =>
a.takeWhile(_ != '.') != b.takeWhile(_ != '.') // TODO: ???
case _ =>
false
case class AssemblyRule(pattern: AssemblyPattern, result: (List[AssemblyLine], AssemblyMatchingContext) => List[AssemblyLine]) {
trait AssemblyPattern {
def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = ()
def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]]
def ~(x: AssemblyPattern) = Concatenation(this, x)
def ~(x: AssemblyLinePattern) = Concatenation(this, x)
def ~~>(result: (List[AssemblyLine], AssemblyMatchingContext) => List[AssemblyLine]) = AssemblyRule(this, result)
def ~~>(result: List[AssemblyLine] => List[AssemblyLine]) = AssemblyRule(this, (code, _) => result(code))
def capture(i: Int) = Capture(i, this)
def captureLength(i: Int) = CaptureLength(i, this)
case class Capture(i: Int, pattern: AssemblyPattern) extends AssemblyPattern {
override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] =
for {
rest <- pattern.matchTo(ctx, code)
} yield {
ctx.addObject(i, code.take(code.length - rest.length).map(_._2))
override def toString: String = s"(?<$i>$pattern)"
case class CaptureLength(i: Int, pattern: AssemblyPattern) extends AssemblyPattern {
override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] =
for {
rest <- pattern.matchTo(ctx, code)
} yield {
ctx.addObject(i, code.length - rest.length)
override def toString: String = s"(?<$i>$pattern)"
case class Where(predicate: (AssemblyMatchingContext => Boolean)) extends AssemblyPattern {
def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
if (predicate(ctx)) Some(code) else None
override def toString: String = "Where(...)"
case class Concatenation(l: AssemblyPattern, r: AssemblyPattern) extends AssemblyPattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
for {
middle <- l.matchTo(ctx, code)
end <- r.matchTo(ctx, middle)
} yield end
override def toString: String = (l, r) match {
case (_: Both, _: Both) => s"($l) · ($r)"
case (_, _: Both) => s"$l · ($r)"
case (_: Both, _) => s"($l) · $r"
case _ => s"$l · $r"
case class Many(rule: AssemblyLinePattern) extends AssemblyPattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
var c = code
while (true) {
c match {
case Nil =>
return Some(Nil)
case x :: xs =>
if (rule.matchLineTo(ctx, x._1, x._2)) {
c = xs
} else {
return Some(c)
override def toString: String = s"[$rule]*"
case class ManyWhereAtLeastOne(rule: AssemblyLinePattern, atLeastOneIsThis: AssemblyLinePattern) extends AssemblyPattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
var c = code
var oneFound = false
while (true) {
c match {
case Nil =>
return Some(Nil)
case x :: xs =>
if (atLeastOneIsThis.matchLineTo(ctx, x._1, x._2)) {
oneFound = true
if (rule.matchLineTo(ctx, x._1, x._2)) {
c = xs
} else {
if (oneFound) {
return Some(c)
} else {
return None
override def toString: String = s"[∃$atLeastOneIsThis:$rule]*"
case class Opt(rule: AssemblyLinePattern) extends AssemblyPattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
code match {
case Nil =>
case x :: xs =>
if (rule.matchLineTo(ctx, x._1, x._2)) {
} else {
override def toString: String = s"[$rule]?"
trait AssemblyLinePattern extends AssemblyPattern {
def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = code match {
case Nil => None
case x :: xs => if (matchLineTo(ctx, x._1, x._2)) Some(xs) else None
def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean
def unary_! : AssemblyLinePattern = Not(this)
def ? : AssemblyPattern = Opt(this)
def * : AssemblyPattern = Many(this)
def + : AssemblyPattern = this ~ Many(this)
def |(x: AssemblyLinePattern): AssemblyLinePattern = Either(this, x)
def &(x: AssemblyLinePattern): AssemblyLinePattern = Both(this, x)
protected def memoryAccessDoesntOverlap(a1: AddrMode.Value, p1: Constant, a2: AddrMode.Value, p2: Constant): Boolean = {
import AddrMode._
val badAddrModes = Set(IndexedX, IndexedY, ZeroPageIndirect, AbsoluteIndexedX)
if (badAddrModes(a1) || badAddrModes(a2)) return false
val goodAddrModes = Set(Implied, Immediate, Relative)
if (goodAddrModes(a1) || goodAddrModes(a2)) return true
def handleKnownDistance(distance: Short): Boolean = {
val indexingAddrModes = Set(AbsoluteIndexedX, AbsoluteX, ZeroPageX, AbsoluteY, ZeroPageY)
val a1Indexing = indexingAddrModes(a1)
val a2Indexing = indexingAddrModes(a2)
(a1Indexing, a2Indexing) match {
case (false, false) => distance != 0
case (true, false) => distance > 255 || distance < 0
case (false, true) => distance > 0 || distance < -255
case (true, true) => distance > 255 || distance < -255
(p1.quickSimplify, p2.quickSimplify) match {
case (NumericConstant(n1, _), NumericConstant(n2, _)) =>
handleKnownDistance((n2 - n1).toShort)
case (a, CompoundConstant(MathOperator.Plus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify =>
case (CompoundConstant(MathOperator.Plus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify =>
case (a, CompoundConstant(MathOperator.Minus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify =>
case (CompoundConstant(MathOperator.Minus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify =>
case (MemoryAddressConstant(a: ThingInMemory), MemoryAddressConstant(b:ThingInMemory)) => != '.') != != '.') // TODO: ???
case (CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(a: ThingInMemory), NumericConstant(_, _)),
MemoryAddressConstant(b: ThingInMemory)) => != '.') != != '.') // TODO: ???
case (MemoryAddressConstant(a: ThingInMemory),
CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) => != '.') != != '.') // TODO: ???
case (CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(a: ThingInMemory), NumericConstant(_, _)),
CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) => != '.') != != '.') // TODO: ???
case _ =>
//noinspection LanguageFeature
object AssemblyLinePattern {
implicit def __implicitOpcodeIn(ops: Set[Opcode.Value]): AssemblyLinePattern = HasOpcodeIn(ops)
case class MatchA(i: Int) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
flowInfo.statusBefore.a match {
case SingleStatus(value) => ctx.addObject(i, value)
case _ => false
case class MatchX(i: Int) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
flowInfo.statusBefore.x match {
case SingleStatus(value) => ctx.addObject(i, value)
case _ => false
case class MatchY(i: Int) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
flowInfo.statusBefore.y match {
case SingleStatus(value) => ctx.addObject(i, value)
case _ => false
case class HasA(value: Int) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
case class HasX(value: Int) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
case class HasY(value: Int) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
case class DoesntMatterWhatItDoesWith(states: State.Value*) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
states.forall(state => flowInfo.importanceAfter.isUnimportant(state))
override def toString: String = states.mkString("[¯\\_(ツ)_/¯:", ",", "]")
case class HasSet(state: State.Value) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
case class HasClear(state: State.Value) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
case object Anything extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
case class Not(inner: AssemblyLinePattern) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = inner.validate(needsFlowInfo)
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
!inner.matchLineTo(ctx, flowInfo, line)
override def toString: String = "¬" + inner
case class Both(l: AssemblyLinePattern, r: AssemblyLinePattern) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
l.matchLineTo(ctx, flowInfo, line) && r.matchLineTo(ctx, flowInfo, line)
override def toString: String = l + " ∧ " + r
case class Either(l: AssemblyLinePattern, r: AssemblyLinePattern) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
l.matchLineTo(ctx, flowInfo, line) || r.matchLineTo(ctx, flowInfo, line)
override def toString: String = s"($l ∨ $r)"
case object Elidable extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
case object Linear extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
case object LinearOrBranch extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.AllLinear(line.opcode) || OpcodeClasses.ShortBranching(line.opcode)
case object LinearOrLabel extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
line.opcode == Opcode.LABEL || OpcodeClasses.AllLinear(line.opcode)
case object ReadsA extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.ReadsAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ReadsAIfImplied(line.opcode)
case object ReadsMemory extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
line.addrMode match {
case AddrMode.Indirect => true
case AddrMode.Implied | AddrMode.Immediate => false
case _ =>
case object ReadsX extends AssemblyLinePattern {
val XAddrModes = Set(AddrMode.AbsoluteX, AddrMode.IndexedX, AddrMode.ZeroPageX, AddrMode.AbsoluteIndexedX)
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.ReadsXAlways(line.opcode) || XAddrModes(line.addrMode)
case object ReadsY extends AssemblyLinePattern {
val YAddrModes = Set(AddrMode.AbsoluteY, AddrMode.IndexedY, AddrMode.ZeroPageY)
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.ReadsYAlways(line.opcode) || YAddrModes(line.addrMode)
case object ConcernsC extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.ReadsC(line.opcode) || OpcodeClasses.ChangesC(line.opcode)
case object ConcernsA extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.ConcernsAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ConcernsAIfImplied(line.opcode)
case object ConcernsX extends AssemblyLinePattern {
val XAddrModes = Set(AddrMode.AbsoluteX, AddrMode.IndexedX, AddrMode.ZeroPageX, AddrMode.AbsoluteIndexedX)
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.ConcernsXAlways(line.opcode) || XAddrModes(line.addrMode)
case object ConcernsY extends AssemblyLinePattern {
val YAddrModes = Set(AddrMode.AbsoluteY, AddrMode.IndexedY, AddrMode.ZeroPageY)
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.ConcernsYAlways(line.opcode) || YAddrModes(line.addrMode)
case object ChangesA extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.ChangesAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ChangesAIfImplied(line.opcode)
case object ChangesMemory extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
OpcodeClasses.ChangesMemoryAlways(line.opcode) || line.addrMode != AddrMode.Implied && OpcodeClasses.ChangesMemoryIfNotImplied(line.opcode)
case class DoesntChangeMemoryAt(addrMode1: Int, param1: Int) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = {
val p1 = ctx.get[Constant](param1)
val p2 = line.parameter
val a1 = ctx.get[AddrMode.Value](addrMode1)
val a2 = line.addrMode
val changesSomeMemory = OpcodeClasses.ChangesMemoryAlways(line.opcode) || line.addrMode != AddrMode.Implied && OpcodeClasses.ChangesMemoryIfNotImplied(line.opcode)
!changesSomeMemory || memoryAccessDoesntOverlap(a1, p1, a2, p2)
case object ConcernsMemory extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
ReadsMemory.matchLineTo(ctx, flowInfo, line) || ChangesMemory.matchLineTo(ctx, flowInfo, line)
case class DoesNotConcernMemoryAt(addrMode1: Int, param1: Int) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = {
val p1 = ctx.get[Constant](param1)
val p2 = line.parameter
val a1 = ctx.get[AddrMode.Value](addrMode1)
val a2 = line.addrMode
memoryAccessDoesntOverlap(a1, p1, a2, p2)
case class HasOpcode(op: Opcode.Value) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
line.opcode == op
override def toString: String = op.toString
case class HasOpcodeIn(ops: Set[Opcode.Value]) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
override def toString: String = ops.mkString("{", ",", "}")
case class HasAddrMode(am: AddrMode.Value) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
line.addrMode == am
override def toString: String = am.toString
case class HasAddrModeIn(ams: Set[AddrMode.Value]) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
override def toString: String = ams.mkString("{", ",", "}")
case class HasImmediate(i: Int) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
line.addrMode == AddrMode.Immediate && (line.parameter.quickSimplify match {
case NumericConstant(j, _) => (i & 0xff) == (j & 0xff)
case _ => false
override def toString: String = "#" + i
case class MatchObject(i: Int, f: Function[AssemblyLine, Any]) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
ctx.addObject(i, f(line))
override def toString: String = s"(?<$i>...)"
case class MatchParameter(i: Int) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
ctx.addObject(i, line.parameter.quickSimplify)
override def toString: String = s"(?<$i>Param)"
case class DontMatchParameter(i: Int) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
ctx.dontMatch(i, line.parameter.quickSimplify)
override def toString: String = s"¬(?<$i>Param)"
case class MatchAddrMode(i: Int) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
ctx.addObject(i, line.addrMode)
override def toString: String = s"(?<$i>AddrMode)"
case class MatchOpcode(i: Int) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
ctx.addObject(i, line.opcode)
override def toString: String = s"(?<$i>Op)"
case class MatchImmediate(i: Int) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
if (line.addrMode == AddrMode.Immediate) {
ctx.addObject(i, line.parameter.quickSimplify)
} else false
override def toString: String = s"(?<$i>#)"
case class DoesntChangeIndexingInAddrMode(i: Int) extends AssemblyLinePattern {
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
ctx.get[AddrMode.Value](i) match {
case AddrMode.ZeroPageX | AddrMode.AbsoluteX | AddrMode.IndexedX | AddrMode.AbsoluteIndexedX => !OpcodeClasses.ChangesX.contains(line.opcode)
case AddrMode.ZeroPageY | AddrMode.AbsoluteY | AddrMode.IndexedY => !OpcodeClasses.ChangesY.contains(line.opcode)
case _ => true
override def toString: String = s"DoesntChangeIndexingInAddrMode(?<$i>AddrMode)"
case class Before(pattern: AssemblyPattern) extends AssemblyLinePattern {
override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = code match {
case Nil => None
case x :: xs => pattern.matchTo(ctx, xs) match {
case Some(m) => Some(xs)
case None => None
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = ???
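// Illustrative sketch, not part of Millfork: how the sequence patterns above compose.
// A pattern consumes a prefix of the input and returns the unmatched rest, or None on failure.
// SketchPattern, One, Star, Maybe and Seq2 are hypothetical names that loosely mirror
// AssemblyLinePattern, Many, Opt and Concatenation, working on plain strings instead of
// (FlowInfo, AssemblyLine) pairs. The behaviour of Maybe is an assumption about what Opt does.
object PatternSketch {
  sealed trait SketchPattern {
    def matchTo(code: List[String]): Option[List[String]]
    def ~(next: SketchPattern): SketchPattern = Seq2(this, next)
  }

  // Matches exactly one line that satisfies the predicate.
  case class One(p: String => Boolean) extends SketchPattern {
    def matchTo(code: List[String]): Option[List[String]] = code match {
      case x :: xs if p(x) => Some(xs)
      case _ => None
    }
  }

  // Greedy zero-or-more: never fails, stops at the first non-matching line.
  case class Star(p: String => Boolean) extends SketchPattern {
    def matchTo(code: List[String]): Option[List[String]] = Some(code.dropWhile(p))
  }

  // Zero-or-one: consumes the head only if it matches.
  case class Maybe(p: String => Boolean) extends SketchPattern {
    def matchTo(code: List[String]): Option[List[String]] = code match {
      case x :: xs if p(x) => Some(xs)
      case _ => Some(code)
    }
  }

  // Sequencing: the right pattern continues where the left one stopped.
  case class Seq2(l: SketchPattern, r: SketchPattern) extends SketchPattern {
    def matchTo(code: List[String]): Option[List[String]] =
      l.matchTo(code).flatMap(r.matchTo)
  }

  // LDA, any number of NOPs, an optional CLC, then TAX; whatever follows is left unmatched.
  def demo(): Option[List[String]] = {
    val pattern = One(_ == "LDA") ~ Star(_ == "NOP") ~ Maybe(_ == "CLC") ~ One(_ == "TAX")
    pattern.matchTo(List("LDA", "NOP", "NOP", "TAX", "RTS"))
  }
}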

View File

@ -0,0 +1,8 @@
package millfork.assembly.opt
/** @author Karol Stasiak */
object SizeOptimizations {

View File

@ -0,0 +1,75 @@
package millfork.assembly.opt
import millfork.{CompilationFlag, CompilationOptions, OptimizationPresets}
import millfork.assembly.{AddrMode, AssemblyLine, Opcode}
import millfork.env.NormalFunction
import millfork.error.ErrorReporting
import scala.collection.mutable
/** @author Karol Stasiak */
object SuperOptimizer extends AssemblyOptimization {
def optimize(m: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
val oldVerbosity = ErrorReporting.verbosity
ErrorReporting.verbosity = -1
var allOptimizers = OptimizationPresets.Good ++ LaterOptimizations.All
if (options.flag(CompilationFlag.EmitIllegals)) {
allOptimizers ++= UndocumentedOptimizations.All
if (options.flag(CompilationFlag.EmitCmosOpcodes)) {
allOptimizers ++= CmosOptimizations.All
allOptimizers ++= List(
val seenSoFar = mutable.Set[CodeView]()
val queue = mutable.Queue[(List[AssemblyOptimization], List[AssemblyLine])]()
val leaves = mutable.ListBuffer[(List[AssemblyOptimization], List[AssemblyLine])]()
seenSoFar += viewCode(code)
queue.enqueue(Nil -> code)
while(queue.nonEmpty) {
val (optsSoFar, codeSoFar) = queue.dequeue()
var isLeaf = true
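// Sequential on purpose: the loop body mutates isLeaf, seenSoFar and queue, none of which are thread-safe.
allOptimizers.foreach { o =>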
val optimized = o.optimize(m, codeSoFar, options)
val view = viewCode(optimized)
if (!seenSoFar(view)) {
isLeaf = false
seenSoFar += view
queue.enqueue((o :: optsSoFar) -> optimized)
if (isLeaf) {
// println( + " B: " +" -> "))
leaves += optsSoFar -> codeSoFar
val result = leaves.minBy(
ErrorReporting.verbosity = oldVerbosity
ErrorReporting.debug(s"Visited ${leaves.size} leaves")
ErrorReporting.debug(s"${} B -> ${} B: ${" -> ")}")
result._1.reverse.foldLeft(code){(c, opt) =>
val n = opt.optimize(m, c, options)
// println(c.mkString("","",""))
// println(n.mkString("","",""))
override val name = "Superoptimizer"
def viewCode(code: List[AssemblyLine]): CodeView = {
CodeView(code.map(l => l.opcode -> l.addrMode))
case class CodeView(content: List[(Opcode.Value, AddrMode.Value)])
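// Illustrative sketch, not part of Millfork: the breadth-first search used by the superoptimizer
// above, reduced to lists of integers. SearchSketch, Step and search are hypothetical names.
// Every step is tried on every reachable state; states seen before are skipped, and a state
// from which no step yields anything new is recorded as a leaf. The sketch picks the shortest
// leaf as a stand-in for the size metric, and it keys the seen-set on the full state, whereas
// the code above keys it on a CodeView of opcodes and addressing modes.
object SearchSketch {
  import scala.collection.mutable

  type Step = List[Int] => List[Int]

  def search(start: List[Int], steps: List[Step]): List[Int] = {
    val seen = mutable.Set[List[Int]](start)
    val queue = mutable.Queue[List[Int]](start)
    val leaves = mutable.ListBuffer[List[Int]]()
    while (queue.nonEmpty) {
      val current = queue.dequeue()
      var isLeaf = true
      steps.foreach { step =>
        val next = step(current)
        if (!seen(next)) {
          isLeaf = false
          seen += next
          queue.enqueue(next)
        }
      }
      if (isLeaf) leaves += current
    }
    // At least one leaf always exists: the last dequeued state has no unseen successors.
    leaves.minBy(_.size)
  }

  def demo(): List[Int] = {
    val dropZeros: Step = _.filter(_ != 0)
    val dropAdjacentDuplicates: Step = xs =>
      xs.foldRight(List.empty[Int])((x, acc) => if (acc.headOption.contains(x)) acc else x :: acc)
    search(List(1, 0, 1, 2, 2), List(dropZeros, dropAdjacentDuplicates))
  }
}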

View File

@ -0,0 +1,340 @@
package millfork.assembly.opt
import java.util.concurrent.atomic.AtomicInteger
import millfork.assembly.{AddrMode, AssemblyLine, Opcode, State}
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
import millfork.assembly.OpcodeClasses._
import millfork.env.{Constant, NormalFunction, NumericConstant}
/** @author Karol Stasiak */
object UndocumentedOptimizations {
val counter = new AtomicInteger(30000)
def getNextLabel(prefix: String) = f".${prefix}%s__${counter.getAndIncrement()}%05d"
// TODO: test these
private val LaxAddrModeRestriction = Not(HasAddrModeIn(Set(AbsoluteX, ZeroPageX, IndexedX, Immediate)))
//noinspection ScalaUnnecessaryParentheses
val UseLax = new RuleBasedAssemblyOptimization("Using undocumented instruction LAX",
needsFlowInfo = FlowInfoRequirement.BackwardFlow,
(HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~
(LinearOrLabel & Not(ConcernsA) & Not(ChangesMemory) & Not(HasOpcode(LDX))).*.capture(2) ~
(HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX)
(HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~
(LinearOrLabel & Not(ConcernsX) & Not(ChangesMemory) & Not(HasOpcode(LDA))).*.capture(2) ~
(HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX)
(HasOpcode(LDA) & Elidable & LaxAddrModeRestriction) ~
(LinearOrLabel & Not(ConcernsA) & Not(ChangesMemory) & Not(HasOpcode(TAX))).*.capture(2) ~
(HasOpcode(TAX) & Elidable) ~~> { (code, ctx) =>
ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX)
(HasOpcode(LDX) & Elidable & LaxAddrModeRestriction) ~
(LinearOrLabel & Not(ConcernsX) & Not(ChangesMemory) & Not(HasOpcode(TXA))).*.capture(2) ~
(HasOpcode(TXA) & Elidable) ~~> { (code, ctx) =>
ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX)
(HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~
(LinearOrLabel & Not(ConcernsX) & Not(ChangesA) & Not(ChangesMemory) & Not(HasOpcode(LDX))).*.capture(2) ~
(HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) =>
code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2)
(HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~
(LinearOrLabel & Not(ConcernsA) & Not(ChangesX) & Not(ChangesMemory) & Not(HasOpcode(LDA))).*.capture(2) ~
(HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) =>
code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2)
(HasOpcode(LDA) & Elidable & LaxAddrModeRestriction) ~
(LinearOrLabel & Not(ConcernsX) & Not(ChangesA) & Not(ChangesMemory) & Not(HasOpcode(TAX))).*.capture(2) ~
(HasOpcode(TAX) & Elidable & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) =>
code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2)
(HasOpcode(LDX) & Elidable & LaxAddrModeRestriction) ~
(LinearOrLabel & Not(ConcernsA) & Not(ChangesX) & Not(ChangesMemory) & Not(HasOpcode(TXA))).*.capture(2) ~
(HasOpcode(TXA) & Elidable & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) =>
code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2)
val SaxModes: Set[AddrMode.Value] = Set(ZeroPage, IndexedX, ZeroPageY, Absolute)
val UseSax = new RuleBasedAssemblyOptimization("Using undocumented instruction SAX",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & Not(ConcernsA) & Not(ConcernsX)).?.capture(10) ~
(HasOpcode(AND) & Elidable & MatchAddrMode(2) & MatchParameter(3) & Not(ReadsX)) ~
(Linear & Not(ConcernsA) & Not(ConcernsX)).?.capture(11) ~
(HasOpcode(STA) & Elidable & MatchAddrMode(4) & MatchParameter(5) & HasAddrModeIn(SaxModes) & DontMatchParameter(0)) ~
(Linear & Not(ConcernsA) & Not(ConcernsX) & Not(ChangesMemory)).?.capture(12) ~
(HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~
(LinearOrLabel & Not(ConcernsX)).*.capture(13) ~ OverwritesX ~~> { (code, ctx) =>
val lda = code.head
val ldx = AssemblyLine(LDX, ctx.get[AddrMode.Value](2), ctx.get[Constant](3))
val sax = AssemblyLine(SAX, ctx.get[AddrMode.Value](4), ctx.get[Constant](5))
val fragment0 = lda :: ctx.get[List[AssemblyLine]](10)
val fragment1 = ldx :: ctx.get[List[AssemblyLine]](11)
val fragment2 = sax :: ctx.get[List[AssemblyLine]](12)
val fragment3 = ctx.get[List[AssemblyLine]](13)
List(fragment0, fragment1, fragment2, fragment3).flatten
def andConstant(const: Constant, mask: Int): Option[Long] = const match {
case NumericConstant(n, _) => Some(n & mask)
case _ => None
val UseAnc = new RuleBasedAssemblyOptimization("Using undocumented instruction ANC",
needsFlowInfo = FlowInfoRequirement.BothFlows,
(Elidable & HasOpcode(LDA) & HasImmediate(0)) ~
(Elidable & HasOpcode(CLC)) ~~> (_ => List(AssemblyLine.immediate(ANC, 0))),
(Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.C)) ~~> (_ => List(AssemblyLine.immediate(ANC, 0))),
(Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
Where(c => andConstant(c.get[Constant](0), 0x80).contains(0)) ~
(Elidable & HasOpcode(CLC)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
(Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
Where(c => andConstant(c.get[Constant](0), 0x80).contains(0x80)) ~
(Elidable & HasOpcode(SEC)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
(Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
(Elidable & HasOpcode(CMP) & HasImmediate(0x80) & DoesntMatterWhatItDoesWith(State.Z, State.N)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
(Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
(Elidable & HasOpcode(CMP) & HasImmediate(0x80) & DoesntMatterWhatItDoesWith(State.Z, State.N)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
(Elidable & HasOpcode(AND) & MatchImmediate(0) & HasClear(State.C)) ~
Where(c => andConstant(c.get[Constant](0), 0x80).contains(0)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
(Elidable & HasOpcode(AND) & MatchImmediate(0) & HasSet(State.C)) ~
Where(c => andConstant(c.get[Constant](0), 0x80).contains(0)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
(Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
(Elidable & HasOpcodeIn(Set(ROL, ASL)) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.Z, State.N, State.A)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
val UseSbx = new RuleBasedAssemblyOptimization("Using undocumented instruction SBX",
needsFlowInfo = FlowInfoRequirement.BothFlows,
(Elidable & HasOpcode(DEX) & DoesntMatterWhatItDoesWith(State.A, State.C)).+.captureLength(0) ~
Where(_.get[Int](0) > 2) ~~> ((_, ctx) => List(
AssemblyLine.immediate(SBX, ctx.get[Int](0)),
(Elidable & HasOpcode(INX) & DoesntMatterWhatItDoesWith(State.A, State.C)).+.captureLength(0) ~
Where(_.get[Int](0) > 2) ~~> ((_, ctx) => List(
AssemblyLine.immediate(SBX, 256 - ctx.get[Int](0)),
HasOpcode(TXA) ~
(Elidable & HasOpcode(CLC)).? ~
(Elidable & HasClear(State.C) & HasClear(State.D) & HasOpcode(ADC) & MatchImmediate(0)) ~
(Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.C, State.A)) ~~> ((code, ctx) => List(
AssemblyLine.immediate(SBX, 256 - ctx.get[Int](0)),
HasOpcode(TXA) ~
(Elidable & HasOpcode(SEC)).? ~
(Elidable & HasSet(State.C) & HasClear(State.D) & HasOpcode(SBC) & MatchImmediate(0)) ~
(Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.C, State.A)) ~~> ((code, ctx) => List(
AssemblyLine.immediate(SBX, ctx.get[Int](0)),
val UseAlr = new RuleBasedAssemblyOptimization("Using undocumented instruction ALR",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
(Elidable & HasOpcode(AND) & HasAddrMode(Immediate)) ~
(Elidable & HasOpcode(LSR) & HasAddrMode(Implied)) ~~> { (code, ctx) =>
List(AssemblyLine.immediate(ALR, code.head.parameter))
(Elidable & HasOpcode(LSR) & HasAddrMode(Implied)) ~
(Elidable & HasOpcode(CLC)) ~~> { (code, ctx) =>
List(AssemblyLine.immediate(ALR, 0xFE))
val UseArr = new RuleBasedAssemblyOptimization("Using undocumented instruction ARR",
needsFlowInfo = FlowInfoRequirement.BothFlows,
(HasClear(State.D) & Elidable & HasOpcode(AND) & HasAddrMode(Immediate)) ~
(Elidable & HasOpcode(ROR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) =>
List(AssemblyLine.immediate(ARR, code.head.parameter))
private def trivialSequence1(o1: Opcode.Value, o2: Opcode.Value, extra: AssemblyLinePattern, combined: Opcode.Value) =
(Elidable & HasOpcode(o1) & HasAddrMode(Absolute) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & DoesNotConcernMemoryAt(0, 1) & extra).* ~
(Elidable & HasOpcode(o2) & HasAddrMode(Absolute) & MatchParameter(1)) ~~> { (code, ctx) =>
code.tail.init :+ AssemblyLine(combined, Absolute, ctx.get[Constant](1))
private def trivialSequence2(o1: Opcode.Value, o2: Opcode.Value, extra: AssemblyLinePattern, combined: Opcode.Value) =
(Elidable & HasOpcode(o1) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & DoesNotConcernMemoryAt(0, 1) & extra).* ~
(Elidable & HasOpcode(o2) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
code.tail.init :+ AssemblyLine(combined, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))
// ROL c LDA c AND d => LDA d RLA c
private def trivialCommutativeSequence(o1: Opcode.Value, o2: Opcode.Value, combined: Opcode.Value) = {
(Elidable & HasOpcode(o1) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Elidable & HasOpcode(LDA) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Elidable & HasOpcode(o2) & MatchAddrMode(2) & MatchParameter(3)) ~~> { code =>
List(code(2).copy(opcode = LDA), code(1).copy(opcode = combined))
val UseSlo = new RuleBasedAssemblyOptimization("Using undocumented instruction SLO",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
trivialSequence1(ASL, ORA, Not(ConcernsC), SLO),
trivialSequence2(ASL, ORA, Not(ConcernsC), SLO),
trivialCommutativeSequence(ASL, ORA, SLO),
(Elidable & HasOpcode(ASL) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & Not(ConcernsMemory)).* ~
(Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
code.tail.init ++ List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SLO, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)))
(Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & Not(ConcernsMemory) & Not(ChangesA)).*.capture(2) ~
(Elidable & HasOpcode(ASL) & HasAddrMode(Implied)) ~
(Linear & Not(ConcernsMemory) & Not(ChangesA) & Not(ReadsC) & Not(ReadsNOrZ)).*.capture(3) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SLO, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))) ++
ctx.get[List[AssemblyLine]](2) ++
val UseSre = new RuleBasedAssemblyOptimization("Using undocumented instruction SRE",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
trivialSequence1(LSR, EOR, Not(ConcernsC), SRE),
trivialSequence2(LSR, EOR, Not(ConcernsC), SRE),
trivialCommutativeSequence(LSR, EOR, SRE),
(Elidable & HasOpcode(LSR) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & Not(ConcernsMemory)).* ~
(Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
code.tail.init ++ List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SRE, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)))
(Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~
(Linear & Not(ConcernsMemory) & Not(ChangesA)).*.capture(2) ~
(Elidable & HasOpcode(LSR) & HasAddrMode(Implied)) ~
(Linear & Not(ConcernsMemory) & Not(ChangesA) & Not(ReadsC) & Not(ReadsNOrZ)).*.capture(3) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SRE, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))) ++
ctx.get[List[AssemblyLine]](2) ++
val UseRla = new RuleBasedAssemblyOptimization("Using undocumented instruction RLA",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
trivialSequence1(ROL, AND, Not(ConcernsC), RLA),
trivialSequence2(ROL, AND, Not(ConcernsC), RLA),
trivialCommutativeSequence(ROL, AND, RLA),
val UseRra = new RuleBasedAssemblyOptimization("Using undocumented instruction RRA",
needsFlowInfo = FlowInfoRequirement.NoRequirement,
// TODO: is it ok? carry flag and stuff?
trivialSequence1(ROR, ADC, Not(ConcernsC), RRA),
trivialSequence2(ROR, ADC, Not(ConcernsC), RRA),
trivialCommutativeSequence(ROR, ADC, RRA),
val UseDcp = new RuleBasedAssemblyOptimization("Using undocumented instruction DCP",
needsFlowInfo = FlowInfoRequirement.BothFlows,
trivialSequence1(DEC, CMP, Not(ConcernsC), DCP),
trivialSequence2(DEC, CMP, Not(ConcernsC), DCP),
(Elidable & HasOpcode(LDA) & HasAddrModeIn(Set(IndexedX, ZeroPageX, AbsoluteX))) ~
(Elidable & HasOpcode(TAX)) ~
(Elidable & HasOpcode(DEC) & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.A, State.Y, State.X, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(code.head.copy(opcode = LDY), code.last.copy(opcode = DCP, addrMode = AbsoluteY))
(Elidable & HasOpcode(DEC) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Elidable & HasOpcode(LDA) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Elidable & HasOpcode(CMP) & MatchAddrMode(2) & MatchParameter(3) & DoesntMatterWhatItDoesWith(State.V, State.C, State.N, State.A)) ~~> { code =>
List(code(2).copy(opcode = LDA), code(1).copy(opcode = DCP))
val UseIsc = new RuleBasedAssemblyOptimization("Using undocumented instruction ISC",
needsFlowInfo = FlowInfoRequirement.BothFlows,
trivialSequence1(INC, SBC, Not(ReadsC), ISC),
trivialSequence2(INC, SBC, Not(ReadsC), ISC),
(Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.D)) ~
(Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
val label = getNextLabel("is")
AssemblyLine.relative(BCC, label),
code.last.copy(opcode = ISC),
(Elidable & HasOpcode(LDA) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
(Elidable & HasOpcode(ADC) & HasImmediate(0) & HasClear(State.D)) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
val label = getNextLabel("is")
AssemblyLine.relative(BCC, label),
code.last.copy(opcode = ISC),
(Elidable & HasOpcode(CLC)).? ~
(Elidable & HasOpcode(LDA) & HasImmediate(1) & HasClear(State.D) & HasClear(State.C)) ~
(Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(code.last.copy(opcode = ISC))
(Elidable & HasOpcode(CLC)).? ~
(Elidable & HasOpcode(LDA) & MatchAddrMode(1) & HasClear(State.D) & HasClear(State.C) & MatchAddrMode(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
(Elidable & HasOpcode(ADC) & HasImmediate(1)) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchAddrMode(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(code.last.copy(opcode = ISC))
(Elidable & HasOpcode(SEC)).? ~
(Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.D) & HasSet(State.C)) ~
(Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(code.last.copy(opcode = ISC))
(Elidable & HasOpcode(SEC)).? ~
(Elidable & HasOpcode(LDA) & MatchAddrMode(1) & HasClear(State.D) & HasSet(State.C) & MatchAddrMode(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
(Elidable & HasOpcode(ADC) & HasImmediate(0)) ~
(Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchAddrMode(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(code.last.copy(opcode = ISC))
(Elidable & HasOpcode(LDA) & HasAddrModeIn(Set(IndexedX, ZeroPageX, AbsoluteX))) ~
(Elidable & HasOpcode(TAX)) ~
(Elidable & HasOpcode(INC) & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.A, State.Y, State.X, State.C, State.Z, State.N, State.V)) ~~> { code =>
List(code.head.copy(opcode = LDY), code.last.copy(opcode = ISC, addrMode = AbsoluteY))
(Elidable & HasOpcode(INC) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Elidable & HasOpcode(LDA) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
(Elidable & HasOpcode(CMP) & HasClear(State.D) & MatchAddrMode(2) & MatchParameter(3) & DoesntMatterWhatItDoesWith(State.V, State.C, State.N, State.A)) ~~> { code =>
List(code(2).copy(opcode = LDA), AssemblyLine.implied(SEC), code(1).copy(opcode = ISC))
val All: List[AssemblyOptimization] = List(

@ -0,0 +1,38 @@
package millfork.assembly.opt
import millfork.CompilationOptions
import millfork.assembly.AddrMode._
import millfork.assembly.Opcode._
import millfork.assembly.{AddrMode, AssemblyLine}
import millfork.env._
import millfork.error.ErrorReporting
/** @author Karol Stasiak */
object UnusedLabelRemoval extends AssemblyOptimization {
override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
val usedLabels = code.flatMap {
case AssemblyLine(LABEL, _, _, _) => None
case AssemblyLine(_, _, MemoryAddressConstant(Label(l)), _) => Some(l)
case _ => None
val definedLabels = code.flatMap {
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) => Some(l).filter(_.startsWith("."))
case _ => None
val toRemove = definedLabels -- usedLabels
if (toRemove.nonEmpty) {
ErrorReporting.debug("Removing labels: " + toRemove.mkString(", "))
code.filterNot {
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) => toRemove(l)
case _ => false
} else {
override def name = "Unused label removal"
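The label-removal pass above can be illustrated with a minimal, self-contained sketch. `LabelDef`, `Instr`, and `UnusedLabelSketch` are hypothetical stand-ins invented for this example, not Millfork's real `AssemblyLine` model; only the idea (collect referenced labels, collect defined local labels starting with a dot, drop the difference) is taken from the code above.

```scala
// Hypothetical, simplified stand-ins for Millfork's AssemblyLine (assumption, not the real classes).
sealed trait Line
final case class LabelDef(name: String) extends Line
final case class Instr(opcode: String, target: Option[String] = None) extends Line

object UnusedLabelSketch {
  def removeUnusedLabels(code: List[Line]): List[Line] = {
    // Labels referenced by any instruction operand.
    val used: Set[String] = code.collect { case Instr(_, Some(l)) => l }.toSet
    // Only local labels (leading dot) are candidates for removal.
    val defined: Set[String] = code.collect { case LabelDef(l) if l.startsWith(".") => l }.toSet
    val toRemove = defined -- used
    // Keep every line except definitions of unused local labels.
    code.filterNot {
      case LabelDef(l) => toRemove(l)
      case _ => false
    }
  }
}
```

Working with sets keeps the membership test in the final filter cheap and makes the "defined minus used" step a single expression.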

@ -0,0 +1,322 @@
package millfork.assembly.opt
import millfork.CompilationOptions
import millfork.assembly.{AddrMode, AssemblyLine}
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
import millfork.env._
import millfork.error.ErrorReporting
import scala.annotation.tailrec
/** @author Karol Stasiak */
object VariableToRegisterOptimization extends AssemblyOptimization {
// If any of these opcodes is present within a method,
// then it's too hard to assign any variable to a register.
private val opcodesThatAlwaysPrecludeXAllocation = Set(JSR, STX, TXA, PHX, PLX, INX, DEX, CPX, SBX, SAX)
private val opcodesThatAlwaysPrecludeYAllocation = Set(JSR, STY, TYA, PHY, PLY, INY, DEY, CPY)
// If any of these opcodes is used on a variable
// then it's too hard to assign that variable to a register.
// Also, LDY prevents assigning a variable to X and LDX prevents assigning a variable to Y.
private val opcodesThatCannotBeUsedWithIndexRegistersAsParameters =
override def name = "Allocating variables to index registers"
override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
val paramVariables = f.params match {
case NormalParamSignature(ps) =>
ps.map(_.name).toSet
case _ =>
// assembly functions do not get this optimization
return code
val stillUsedVariables = code.flatMap {
case AssemblyLine(_, _, MemoryAddressConstant(th), _) => Some(th.name)
case _ => None
val localVariables = f.environment.getAllLocalVariables.filter {
case MemoryVariable(name, typ, VariableAllocationMethod.Auto) =>
typ.size == 1 && !paramVariables(name) && stillUsedVariables(name)
case _ => false
val candidates = None :: localVariables.map(v => Option(v.name))
val variants = for {
vx <- candidates.par
vy <- candidates
if vx != vy
(score, prologueLength) <- canBeInlined(vx, vy, code.tail, Some(1))
if prologueLength >= 1
} yield (score, prologueLength, vx, vy)
if (variants.isEmpty) {
return code
val (_, bestPrologueLength, bestX, bestY) = variants.max
if ((bestX.isDefined || bestY.isDefined) && bestPrologueLength != 0xffff) {
(bestX, bestY) match {
case (Some(x), Some(y)) => ErrorReporting.debug(s"Inlining $x to X and $y to Y")
case (Some(x), None) => ErrorReporting.debug(s"Inlining $x to X")
case (None, Some(y)) => ErrorReporting.debug(s"Inlining $y to Y")
case _ =>
code.take(bestPrologueLength) ++ inlineVars(bestX, bestY, code.drop(bestPrologueLength))
} else {
private def add(i: Int) = (p: (Int, Int)) => (p._1 + i) -> p._2
private def mark(i: Option[Int]) = (p: (Int, Int)) => p._1 -> i.getOrElse(p._2)
def canBeInlined(xCandidate: Option[String], yCandidate: Option[String], lines: List[AssemblyLine], instrCounter: Option[Int]): Option[(Int, Int)] = {
val vx = xCandidate.getOrElse("-")
val vy = yCandidate.getOrElse("-")
val next = instrCounter.map(_ + 1)
val next2 = instrCounter.map(_ + 2)
lines match {
case AssemblyLine(_, Immediate, SubbyteConstant(MemoryAddressConstant(th), _), _) :: xs
if th.name == vx || th.name == vy =>
// if an address of a variable is used, then that variable cannot be assigned to a register
case AssemblyLine(_, Immediate, HalfWordConstant(MemoryAddressConstant(th), _), _) :: xs
if th.name == vx || th.name == vy =>
// if an address of a variable is used, then that variable cannot be assigned to a register
case AssemblyLine(_, AbsoluteX | AbsoluteY | ZeroPageX | ZeroPageY, MemoryAddressConstant(th), _) :: xs =>
// if a variable is used as an array, then it cannot be assigned to a register
if (th.name == vx || th.name == vy) {
} else {
canBeInlined(xCandidate, yCandidate, xs, next)
case AssemblyLine(opcode, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vx && (opcode == LDY || opcodesThatCannotBeUsedWithIndexRegistersAsParameters(opcode)) =>
// if a variable is used by some opcodes, then it cannot be assigned to a register
case AssemblyLine(opcode, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vy && (opcode == LDX || opcode == LAX || opcodesThatCannotBeUsedWithIndexRegistersAsParameters(opcode)) =>
// if a variable is used by some opcodes, then it cannot be assigned to a register
case AssemblyLine(LDX, Absolute, MemoryAddressConstant(th), elidable) :: xs
if xCandidate.isDefined =>
// if a register is populated with a different variable, then this variable cannot be assigned to that register
// removing LDX saves 3 cycles
if (elidable && th.name == vx) {
canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter))
} else {
case AssemblyLine(LAX, Absolute, MemoryAddressConstant(th), elidable) :: xs
if xCandidate.isDefined =>
// LAX = LDX-LDA, and since LDX simplifies to nothing and LDA simplifies to TXA,
// LAX simplifies to TXA, saving two bytes
if (elidable && th.name == vx) {
canBeInlined(xCandidate, yCandidate, xs, None).map(add(2)).map(mark(instrCounter))
} else {
case AssemblyLine(LDY, Absolute, MemoryAddressConstant(th), elidable) :: xs if yCandidate.isDefined =>
// if a register is populated with a different variable, then this variable cannot be assigned to that register
// removing LDY saves 3 cycles
if (elidable && th.name == vy) {
canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter))
} else {
case AssemblyLine(LDX, _, _, _) :: xs if xCandidate.isDefined =>
// if a register is populated with something other than a variable, then no variable can be assigned to that register
case AssemblyLine(LDY, _, _, _) :: xs if yCandidate.isDefined =>
// if a register is populated with something other than a variable, then no variable can be assigned to that register
case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), elidable) :: AssemblyLine(TAX, _, _, elidable2) :: xs
if xCandidate.isDefined =>
// a variable cannot be inlined if there is TAX not after LDA of that variable
// but LDA-TAX can be simplified to TXA
if (elidable && elidable2 && th.name == vx) {
canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter))
} else {
case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), elidable) :: AssemblyLine(TAY, _, _, elidable2) :: xs
if yCandidate.isDefined =>
// a variable cannot be inlined if there is TAY not after LDA of that variable
// but LDA-TAY can be simplified to TYA
if (elidable && elidable2 && th.name == vy) {
canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter))
} else {
case AssemblyLine(LDA | STA | INC | DEC, Absolute, MemoryAddressConstant(th), elidable) :: xs =>
// changing LDA->TXA, STA->TAX, INC->INX, DEC->DEX saves 2 cycles
if (th.name == vy || th.name == vx) {
if (elidable) canBeInlined(xCandidate, yCandidate, xs, None).map(add(2)).map(mark(instrCounter))
else None
} else {
canBeInlined(xCandidate, yCandidate, xs, next)
case AssemblyLine(TAX, _, _, _) :: xs if xCandidate.isDefined =>
// a variable cannot be inlined if there is TAX not after LDA of that variable
if (instrCounter.isDefined) {
canBeInlined(xCandidate, yCandidate, xs, next)
} else None
case AssemblyLine(TAY, _, _, _) :: xs if yCandidate.isDefined =>
// a variable cannot be inlined if there is TAY not after LDA of that variable
if (instrCounter.isDefined) {
canBeInlined(xCandidate, yCandidate, xs, next)
} else None
case AssemblyLine(LABEL, _, _, _) :: xs =>
// labels always end the initial section
canBeInlined(xCandidate, yCandidate, xs, None).map(mark(instrCounter))
case x :: xs =>
if (instrCounter.isDefined) {
canBeInlined(xCandidate, yCandidate, xs, next)
} else {
if (xCandidate.isDefined && opcodesThatAlwaysPrecludeXAllocation(x.opcode)) {
} else if (yCandidate.isDefined && opcodesThatAlwaysPrecludeYAllocation(x.opcode)) {
} else {
canBeInlined(xCandidate, yCandidate, xs, next)
case Nil => Some(0 -> -1)
def inlineVars(xCandidate: Option[String], yCandidate: Option[String], lines: List[AssemblyLine]): List[AssemblyLine] = {
val vx = xCandidate.getOrElse("-")
val vy = yCandidate.getOrElse("-")
lines match {
case AssemblyLine(INC, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vx =>
AssemblyLine.implied(INX) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(INC, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vy =>
AssemblyLine.implied(INY) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(DEC, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vx =>
AssemblyLine.implied(DEX) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(DEC, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vy =>
AssemblyLine.implied(DEY) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDX, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vx =>
inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LAX, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vx =>
AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDY, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vy =>
inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), true) :: AssemblyLine(TAX, _, _, true) :: xs
if th.name == vx =>
// these TXA's may get optimized away by a different optimization
AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), true) :: AssemblyLine(TAY, _, _, true) :: xs
if th.name == vy =>
// these TYA's may get optimized away by a different optimization
AssemblyLine.implied(TYA) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDA, am, param, true) :: AssemblyLine(STA, Absolute, MemoryAddressConstant(th), true) :: xs
if th.name == vx && doesntUseX(am) =>
// these TXA's may get optimized away by a different optimization
AssemblyLine(LDX, am, param) :: AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDA, am, param, true) :: AssemblyLine(STA, Absolute, MemoryAddressConstant(th), true) :: xs
if th.name == vy && doesntUseY(am) =>
// these TYA's may get optimized away by a different optimization
AssemblyLine(LDY, am, param) :: AssemblyLine.implied(TYA) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: AssemblyLine(CMP, am, param, true) :: xs
if th.name == vx && doesntUseXOrY(am) =>
// ditto
AssemblyLine.implied(TXA) :: AssemblyLine(CPX, am, param) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: AssemblyLine(CMP, am, param, true) :: xs
if th.name == vy && doesntUseXOrY(am) =>
// ditto
AssemblyLine.implied(TYA) :: AssemblyLine(CPY, am, param) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vx =>
AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vy =>
AssemblyLine.implied(TYA) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(STA, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vx =>
AssemblyLine.implied(TAX) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(STA, Absolute, MemoryAddressConstant(th), _) :: xs
if th.name == vy =>
AssemblyLine.implied(TAY) :: inlineVars(xCandidate, yCandidate, xs)
case AssemblyLine(TAX, _, _, _) :: xs if xCandidate.isDefined =>
ErrorReporting.fatal("Unexpected TAX")
case AssemblyLine(TAY, _, _, _) :: xs if yCandidate.isDefined =>
ErrorReporting.fatal("Unexpected TAY")
case x :: xs => x :: inlineVars(xCandidate, yCandidate, xs)
case Nil => Nil
def doesntUseY(am: AddrMode.Value): Boolean = am match {
case AbsoluteY | ZeroPageY | IndexedY => false
case _ => true
def doesntUseX(am: AddrMode.Value): Boolean = am match {
case AbsoluteX | ZeroPageX | IndexedX => false
case _ => true
def doesntUseXOrY(am: AddrMode.Value): Boolean = am match {
case Immediate | ZeroPage | Absolute | Relative | Indirect => true
case _ => false
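The core rewrite performed by `inlineVars` (a variable that lives in X no longer needs memory accesses) can be sketched in isolation. This is a simplified illustration under stated assumptions: `Asm` and `InlineToXSketch` are hypothetical names, opcodes are plain strings, and the sketch ignores elidability, the Y register, and the multi-instruction patterns handled above.

```scala
// Hypothetical, simplified instruction model (assumption; not Millfork's AssemblyLine).
final case class Asm(opcode: String, operand: Option[String] = None)

object InlineToXSketch {
  // Rewrites accesses to `variable` as if its value were kept in the X register.
  def inlineToX(variable: String, lines: List[Asm]): List[Asm] = lines.flatMap {
    case Asm("LDX", Some(v)) if v == variable => Nil              // value is already in X
    case Asm("LDA", Some(v)) if v == variable => List(Asm("TXA")) // load becomes a transfer
    case Asm("STA", Some(v)) if v == variable => List(Asm("TAX")) // store becomes a transfer
    case Asm("INC", Some(v)) if v == variable => List(Asm("INX"))
    case Asm("DEC", Some(v)) if v == variable => List(Asm("DEX"))
    case other => List(other)                                     // unrelated lines pass through
  }
}
```

For example, `InlineToXSketch.inlineToX("i", List(Asm("LDA", Some("i")), Asm("INC", Some("i"))))` yields a `TXA` followed by an `INX`. The real pass first runs `canBeInlined` to make sure such a rewrite is safe and worthwhile; the sketch assumes that check has already succeeded.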

@ -0,0 +1,201 @@
package millfork.cli
/** @author Karol Stasiak */
trait CliOption[T, O <: CliOption[T, O]] {
this: O =>
def toStrings(firstTab: Int): List[String] = {
val fl = firstLine
if (_description == "") {
} else if (fl.length < firstTab) {
List(fl.padTo(firstTab, ' ') + _description)
} else {
List(fl, "".padTo(firstTab, ' ') + _description)
protected def firstLine: String = names.mkString(" | ")
def names: Seq[String]
private[cli] def length: Int
private[cli] val _shortName: String
private[cli] var _description: String = ""
private[cli] var _hidden = false
private[cli] var _maxEncounters = 1
private[cli] var _minEncounters = 0
private[cli] var _actualEncounters = 0
private[cli] var _onTooFew: Option[Int => Unit] = None
private[cli] var _onTooMany: Option[Int => Unit] = None
def validate(): Boolean = {
var ok = true
if (_actualEncounters < _minEncounters) {
_onTooFew.fold(throw new IllegalArgumentException(s"Too few ${_shortName} options: required ${_minEncounters}, given ${_actualEncounters}"))(_ (_actualEncounters))
ok = false
if (_actualEncounters > _maxEncounters) {
_onTooMany.fold()(_ (_actualEncounters))
ok = false
def onWrongNumber(action: Int => Unit): Unit = {
_onTooFew = Some(action)
_onTooMany = Some(action)
def onTooFew(action: Int => Unit): Unit = {
_onTooFew = Some(action)
def onTooMany(action: Int => Unit): Unit = {
_onTooMany = Some(action)
def encounter(): Unit = {
_actualEncounters += 1
def description(d: String): O = {
_description = d
def hidden(): O = {
_hidden = true
def minCount(count: Int): O = {
_minEncounters = count
def maxCount(count: Int): O = {
_maxEncounters = count
def required(): O = minCount(1)
def repeatable(): O = maxCount(Int.MaxValue)
class Fluff[T](val text: Seq[String]) extends CliOption[T, Fluff[T]] {
override def toStrings(firstTab: Int): List[String] = text.toList
override def length = 0
override val _shortName = ""
override def names = Nil
class NoMoreOptions[T](val names: Seq[String]) extends CliOption[T, NoMoreOptions[T]] {
override def length = 1
override val _shortName = names.head
class UnknownParamOption[T] extends CliOption[T, UnknownParamOption[T]] {
this._hidden = true
override def length = 0
val names: Seq[String] = Nil
private var _action: ((String, T) => T) = (_, x) => x
def action(a: ((String, T) => T)): UnknownParamOption[T] = {
_action = a
def encounter(value: String, t: T): T = {
_action(value, t)
override private[cli] val _shortName = ""
class FlagOption[T](val names: Seq[String]) extends CliOption[T, FlagOption[T]] {
override def length = 1
private var _action: (T => T) = x => x
def action(a: (T => T)): FlagOption[T] = {
_action = a
def encounter(t: T): T = {
override val _shortName = names.head
class BooleanOption[T](val trueName: String, val falseName: String) extends CliOption[T, BooleanOption[T]] {
override def length = 1
private var _action: ((T,Boolean) => T) = (x,_) => x
def action(a: ((T,Boolean) => T)): BooleanOption[T] = {
_action = a
def encounter(asName: String, t: T): T = {
if (asName == trueName) {
return _action(t, true)
if (asName == falseName) {
return _action(t, false)
override val _shortName = names.head
override protected def firstLine: String = trueName + " | " + falseName
override def names = Seq(trueName, falseName)
class ParamOption[T](val names: Seq[String]) extends CliOption[T, ParamOption[T]] {
override protected def firstLine: String = names.mkString(" | ") + " " + _paramPlaceholder
override def length = 2
private var _action: ((String, T) => T) = (_, x) => x
private var _paramPlaceholder: String = "<x>"
def placeholder(p: String): ParamOption[T] = {
_paramPlaceholder = p
def action(a: ((String, T) => T)): ParamOption[T] = {
_action = a
def encounter(value: String, t: T): T = {
_action(value, t)
override val _shortName = names.head
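The `CliOption[T, O <: CliOption[T, O]]` signature with the `this: O =>` self-type is what lets shared setters such as `description` or `required` return the concrete option type, so calls can be chained with subclass-specific methods like `action`. Below is a minimal sketch of that pattern; `Opt`, `Flag`, `currentDescription`, and the `Int` state are hypothetical simplifications invented for this example.

```scala
// Hypothetical reduction of the self-typed fluent option pattern (assumption, not the real CliOption).
trait Opt[O <: Opt[O]] { this: O =>
  private var _description = ""
  private var _min = 0
  // Returning `this` typed as O keeps the concrete subclass type in the chain.
  def description(d: String): O = { _description = d; this }
  def required(): O = { _min = 1; this }
  def currentDescription: String = _description
  def minCount: Int = _min
}

final class Flag(val name: String) extends Opt[Flag] {
  private var _action: Int => Int = x => x
  def action(a: Int => Int): Flag = { _action = a; this }
  def encounter(state: Int): Int = _action(state)
}
```

With this shape, `new Flag("-v").description("verbose").required().action(_ + 1)` type-checks, because `description` and `required` return `Flag` rather than the base trait.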

@ -0,0 +1,81 @@
package millfork.cli
import fastparse.core.Parsed.Failure
import scala.collection.mutable
/** @author Karol Stasiak */
class CliParser[T] {
private val options = mutable.ArrayBuffer[CliOption[T, _]]()
private val mapFlags = mutable.Map[String, CliOption[T, _]]()
private val mapOptions = mutable.Map[String, CliOption[T, _]]()
private val _default = new UnknownParamOption[T]().action((p, _) => throw new IllegalArgumentException(s"Unknown option $p"))
private var _status: Option[CliStatus.Value] = None
options += _default
private def add[O <: CliOption[T, _]](o: O) = {
options += o
o.length match {
case 1 =>
o.names.foreach { n => mapFlags(n) = o }
case 2 =>
o.names.foreach { n => mapOptions(n) = o }
case _ => ()
def parse(context: T, args: List[String]): (CliStatus.Value, T) = {
val t = parseInner(context, args)
_status.getOrElse(if (options.forall(_.validate())) CliStatus.Ok else CliStatus.Failed) -> t
def assumeStatus(s: CliStatus.Value): Unit = {
_status = Some(s)
private def parseInner(context: T, args: List[String]): T = {
args match {
case k :: v :: xs if mapOptions.contains(k) =>
mapOptions(k) match {
case p: ParamOption[T] => parseInner(p.encounter(v, context), xs)
case _ => ???
case k :: xs if mapFlags.contains(k) =>
mapFlags(k) match {
case p: FlagOption[T] =>
parseInner(p.encounter(context), xs)
case p: BooleanOption[T] =>
parseInner(p.encounter(k, context), xs)
case p: NoMoreOptions[T] =>
xs.foldLeft(context)((t, x) => _default.encounter(x, t))
case _ => ???
case x :: xs =>
parseInner(_default.encounter(x, context), xs)
case Nil => context
def fluff(text: String*): Unit = add(new Fluff[T](text))
def flag(names: String*): FlagOption[T] = add(new FlagOption[T](names))
def boolean(trueName: String, falseName: String): BooleanOption[T] = add(new BooleanOption[T](trueName, falseName))
def endOfFlags(names: String*): NoMoreOptions[T] = add(new NoMoreOptions[T](names))
def default: UnknownParamOption[T] = _default
def printHelp(firstTab: Int): List[String] = {
def parameter(names: String*): ParamOption[T] = add(new ParamOption[T](names))
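The recursive shape of `parseInner` (match on the head of the argument list, fold the effect into the context, recurse on the rest) can be shown with a small standalone sketch. `Config`, `ArgLoopSketch`, and the option names `-o`, `-v`, and `--` are assumptions made up for this example; the real parser looks options up in its `mapOptions` and `mapFlags` tables instead of hard-coding them.

```scala
object ArgLoopSketch {
  // Hypothetical context type standing in for the parser's generic T.
  final case class Config(verbose: Boolean = false, output: Option[String] = None, inputs: List[String] = Nil)

  @scala.annotation.tailrec
  def parse(cfg: Config, args: List[String]): Config = args match {
    // An option with a parameter consumes two arguments.
    case "-o" :: value :: rest => parse(cfg.copy(output = Some(value)), rest)
    // A flag consumes one argument.
    case "-v" :: rest => parse(cfg.copy(verbose = true), rest)
    // End-of-flags marker: everything after it is treated as a plain input.
    case "--" :: rest => cfg.copy(inputs = cfg.inputs ++ rest)
    // Anything else falls through to the default handler.
    case other :: rest => parse(cfg.copy(inputs = cfg.inputs :+ other), rest)
    case Nil => cfg
  }
}
```

Ordering the two-argument case first mirrors the original, where parameterized options are tried before flags so that the value is never mistaken for a separate argument.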

@ -0,0 +1,8 @@
package millfork.cli
/** @author Karol Stasiak */
object CliStatus extends Enumeration {
val Ok, Failed, Quit = Value

@ -0,0 +1,832 @@
package millfork.compiler
import millfork.{CompilationFlag, CompilationOptions}
import millfork.assembly._
import millfork.env._
import millfork.node._
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
import millfork.error.ErrorReporting
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.reflect.macros.blackbox
object ComparisonType extends Enumeration {
val Equal, NotEqual,
LessUnsigned, LessSigned,
GreaterUnsigned, GreaterSigned,
LessOrEqualUnsigned, LessOrEqualSigned,
GreaterOrEqualUnsigned, GreaterOrEqualSigned = Value
def flip(x: ComparisonType.Value): ComparisonType.Value = x match {
case LessUnsigned => GreaterUnsigned
case GreaterUnsigned => LessUnsigned
case LessOrEqualUnsigned => GreaterOrEqualUnsigned
case GreaterOrEqualUnsigned => LessOrEqualUnsigned
case LessSigned => GreaterSigned
case GreaterSigned => LessSigned
case LessOrEqualSigned => GreaterOrEqualSigned
case GreaterOrEqualSigned => LessOrEqualSigned
case _ => x
def negate(x: ComparisonType.Value): ComparisonType.Value = x match {
case LessUnsigned => GreaterOrEqualUnsigned
case GreaterUnsigned => LessOrEqualUnsigned
case LessOrEqualUnsigned => GreaterUnsigned
case GreaterOrEqualUnsigned => LessUnsigned
case LessSigned => GreaterOrEqualSigned
case GreaterSigned => LessOrEqualSigned
case LessOrEqualSigned => GreaterSigned
case GreaterOrEqualSigned => LessSigned
case Equal => NotEqual
case NotEqual => Equal
/** @author Karol Stasiak */
object BuiltIns {
object IndexChoice extends Enumeration {
val RequireX, PreferX, PreferY = Value
def wrapInSedCldIfNeeded(decimal: Boolean, code: List[AssemblyLine]): List[AssemblyLine] = {
if (decimal) {
AssemblyLine.implied(SED) :: (code :+ AssemblyLine.implied(CLD))
} else {
def staTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == STA) x.copy(opcode = op) else x)
def ldTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == LDA || x.opcode == LDX || x.opcode == LDY) x.copy(opcode = op) else x)
def simpleOperation(opcode: Opcode.Value, ctx: CompilationContext, source: Expression, indexChoice: IndexChoice.Value, preserveA: Boolean, commutative: Boolean): List[AssemblyLine] = {
val env = ctx.env
val parts: (List[AssemblyLine], List[AssemblyLine]) = env.eval(source).fold {
val b = env.get[Type]("byte")
source match {
case VariableExpression(name) =>
val v = env.get[Variable](name)
if (v.typ.size > 1) {
ErrorReporting.error(s"Variable `$name` is too big for a built-in operation", source.position)
return Nil
Nil -> AssemblyLine.variable(ctx, opcode, v)
case IndexedExpression(arrayName, index) =>
indexChoice match {
case IndexChoice.RequireX | IndexChoice.PreferX =>
val array = env.getArrayOrPointer(arrayName)
val calculateIndex = MlCompiler.compile(ctx, index, Some(b -> RegisterVariable(Register.X, b)), NoBranching)
val baseAddress = array match {
case c: ConstantThing => c.value
case a: MlArray => a.toAddress
calculateIndex -> List(AssemblyLine.absoluteX(opcode, baseAddress))
case IndexChoice.PreferY =>
val array = env.getArrayOrPointer(arrayName)
val calculateIndex = MlCompiler.compile(ctx, index, Some(b -> RegisterVariable(Register.Y, b)), NoBranching)
val baseAddress = array match {
case c: ConstantThing => c.value
case a: MlArray => a.toAddress
calculateIndex -> List(AssemblyLine.absoluteY(opcode, baseAddress))
case f: FunctionCallExpression if commutative =>
// TODO: is it ok?
return List(AssemblyLine.implied(PHA)) ++ MlCompiler.compile(ctx.addStack(1), f, Some(b -> RegisterVariable(Register.A, b)), NoBranching) ++ List(
AssemblyLine.absoluteX(opcode, 0x101),
case _ =>
ErrorReporting.error("Right-hand-side expression is too complex", source.position)
return Nil
} {
const =>
if (const.requiredSize > 1) {
ErrorReporting.error("Constant too big for a built-in operation", source.position)
Nil -> List(AssemblyLine.immediate(opcode, const))
val preparations = parts._1
val finalRead = parts._2
if (preserveA && AssemblyLine.treatment(preparations, State.A) != Treatment.Unchanged) {
AssemblyLine.implied(PHA) :: (preparations ++ (AssemblyLine.implied(PLA) :: finalRead))
} else {
preparations ++ finalRead
def insertBeforeLast(item: AssemblyLine, list: List[AssemblyLine]): List[AssemblyLine] = list match {
case Nil => Nil
case last :: dex :: txs :: Nil if dex.opcode == DEX && txs.opcode == TXS => item :: last :: dex :: txs :: Nil
case last :: inx :: txs :: Nil if inx.opcode == INX && txs.opcode == TXS => item :: last :: inx :: txs :: Nil
case last :: Nil => item :: last :: Nil
case first :: rest => first :: insertBeforeLast(item, rest)
def compileAddition(ctx: CompilationContext, params: List[(Boolean, Expression)], decimal: Boolean): List[AssemblyLine] = {
if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode)) {
ErrorReporting.warn("Unsupported decimal operation", ctx.options, params.head._2.position)
// if (params.isEmpty) {
// return Nil
// }
val env = ctx.env
val b = env.get[Type]("byte")
val sortedParams = params.sortBy { case (subtract, expr) =>
val constPart = env.eval(expr) match {
case Some(NumericConstant(_, _)) => "Z"
case Some(_) => "Y"
case None => expr match {
case VariableExpression(_) => "V"
case IndexedExpression(_, LiteralExpression(_, _)) => "K"
case IndexedExpression(_, VariableExpression(_)) => "J"
case IndexedExpression(_, _) => "I"
case _ => "A"
val subtractPart = if (subtract) "X" else "P"
constPart + subtractPart
// TODO: merge constants
val normalizedParams = sortedParams
val h = normalizedParams.head
val firstParamCompiled = MlCompiler.compile(ctx, h._2, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
val firstParamSignCompiled = if (h._1) {
List(AssemblyLine.immediate(EOR, 0xff), AssemblyLine.implied(SEC), AssemblyLine.immediate(ADC, 0))
} else {
val remainingParamsCompiled = normalizedParams.tail.flatMap { p =>
if (p._1) {
insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = false))
} else {
insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = true))
wrapInSedCldIfNeeded(decimal, firstParamCompiled ++ firstParamSignCompiled ++ remainingParamsCompiled)
def compileBitOps(opcode: Opcode.Value, ctx: CompilationContext, params: List[Expression]): List[AssemblyLine] = {
val b = ctx.env.get[Type]("byte")
val sortedParams = params.sortBy { expr =>
ctx.env.eval(expr) match {
case Some(NumericConstant(_, _)) => "Z"
case Some(_) => "Y"
case None => expr match {
case VariableExpression(_) => "V"
case IndexedExpression(_, LiteralExpression(_, _)) => "K"
case IndexedExpression(_, VariableExpression(_)) => "J"
case IndexedExpression(_, _) => "I"
case _ => "A"
val h = sortedParams.head
val firstParamCompiled = MlCompiler.compile(ctx, h, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
val remainingParamsCompiled = sortedParams.tail.flatMap { p =>
simpleOperation(opcode, ctx, p, IndexChoice.PreferY, preserveA = true, commutative = true)
firstParamCompiled ++ remainingParamsCompiled
def compileShiftOps(opcode: Opcode.Value, ctx: CompilationContext, l: Expression, r: Expression): List[AssemblyLine] = {
val b = ctx.env.get[Type]("byte")
val firstParamCompiled = MlCompiler.compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
ctx.env.eval(r) match {
case Some(NumericConstant(0, _)) =>
case Some(NumericConstant(v, _)) if v > 0 =>
firstParamCompiled ++ List.fill(v.toInt)(AssemblyLine.implied(opcode))
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount")
def compileNonetOps(ctx: CompilationContext, lhs: LhsExpression, rhs: Expression): List[AssemblyLine] = {
val env = ctx.env
val b = env.get[Type]("byte")
val (ldaHi, ldaLo) = lhs match {
case v: VariableExpression =>
val variable = env.get[Variable](v.name)
AssemblyLine.variable(ctx, LDA, variable, 1) -> AssemblyLine.variable(ctx, LDA, variable, 0)
case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) =>
AssemblyLine.variable(ctx, LDA, env.get[Variable](h.name), 0) -> AssemblyLine.variable(ctx, LDA, env.get[Variable](l.name), 0)
case _ =>
env.eval(rhs) match {
case Some(NumericConstant(0, _)) =>
case Some(NumericConstant(shift, _)) if shift > 0 =>
if (ctx.options.flag(CompilationFlag.RorWarning))
ErrorReporting.warn("ROR instruction generated", ctx.options, lhs.position)
ldaHi ++ List(AssemblyLine.implied(ROR)) ++ ldaLo ++ List(AssemblyLine.implied(ROR)) ++ List.fill(shift.toInt - 1)(AssemblyLine.implied(LSR))
case _ =>
ErrorReporting.error("Non-constant shift amount", rhs.position) // TODO
def compileInPlaceByteShiftOps(opcode: Opcode.Value, ctx: CompilationContext, lhs: LhsExpression, rhs: Expression): List[AssemblyLine] = {
val env = ctx.env
val b = env.get[Type]("byte")
val firstParamCompiled = MlCompiler.compile(ctx, lhs, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
env.eval(rhs) match {
case Some(NumericConstant(0, _)) =>
case Some(NumericConstant(v, _)) if v > 0 =>
val result = simpleOperation(opcode, ctx, lhs, IndexChoice.RequireX, preserveA = true, commutative = false)
result ++ List.fill(v.toInt - 1)(result.last)
case _ =>
ErrorReporting.error("Non-constant shift amount", rhs.position) // TODO
def compileInPlaceWordOrLongShiftOps(ctx: CompilationContext, lhs: LhsExpression, rhs: Expression, aslRatherThanLsr: Boolean): List[AssemblyLine] = {
val env = ctx.env
val b = env.get[Type]("byte")
val targetBytes = lhs match {
case v: VariableExpression =>
val variable = env.get[Variable](v.name)
List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) }
case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) =>
AssemblyLine.variable(ctx, STA, env.get[Variable](,
AssemblyLine.variable(ctx, STA, env.get[Variable](
val lo = targetBytes.head
val hi = targetBytes.last
env.eval(rhs) match {
case Some(NumericConstant(0, _)) =>
case Some(NumericConstant(shift, _)) if shift > 0 =>
List.fill(shift.toInt)(if (aslRatherThanLsr) {
staTo(ASL, lo) ++ targetBytes.tail.flatMap { b => staTo(ROL, b) }
} else {
if (ctx.options.flag(CompilationFlag.RorWarning))
ErrorReporting.warn("ROR instruction generated", ctx.options, lhs.position)
staTo(LSR, hi) ++ targetBytes.reverse.tail.flatMap { b => staTo(ROR, b) }
case _ =>
ErrorReporting.error("Non-constant shift amount", rhs.position) // TODO
def compileByteComparison(ctx: CompilationContext, compType: ComparisonType.Value, lhs: Expression, rhs: Expression, branches: BranchSpec): List[AssemblyLine] = {
val env = ctx.env
val b = env.get[Type]("byte")
val firstParamCompiled = MlCompiler.compile(ctx, lhs, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
env.eval(rhs) match {
case Some(NumericConstant(0, _)) =>
compType match {
case ComparisonType.LessUnsigned =>
ErrorReporting.warn("Unsigned < 0 is always false", ctx.options, lhs.position)
case ComparisonType.LessOrEqualUnsigned =>
if (ctx.options.flag(CompilationFlag.ExtraComparisonWarnings))
ErrorReporting.warn("Unsigned <= 0 means the same as unsigned == 0", ctx.options, lhs.position)
case ComparisonType.GreaterUnsigned =>
if (ctx.options.flag(CompilationFlag.ExtraComparisonWarnings))
ErrorReporting.warn("Unsigned > 0 means the same as unsigned != 0", ctx.options, lhs.position)
case ComparisonType.GreaterOrEqualUnsigned =>
ErrorReporting.warn("Unsigned >= 0 is always true", ctx.options, lhs.position)
case _ =>
case Some(NumericConstant(1, _)) =>
if (ctx.options.flag(CompilationFlag.ExtraComparisonWarnings)) {
compType match {
case ComparisonType.LessUnsigned =>
ErrorReporting.warn("Unsigned < 1 means the same as unsigned == 0", ctx.options, lhs.position)
case ComparisonType.GreaterOrEqualUnsigned =>
ErrorReporting.warn("Unsigned >= 1 means the same as unsigned != 0", ctx.options, lhs.position)
case _ =>
case _ =>
val secondParamCompiledUnoptimized = simpleOperation(CMP, ctx, rhs, IndexChoice.PreferY, preserveA = true, commutative = false)
val secondParamCompiled = compType match {
case ComparisonType.Equal | ComparisonType.NotEqual | ComparisonType.LessSigned | ComparisonType.GreaterOrEqualSigned =>
secondParamCompiledUnoptimized match {
case List(AssemblyLine(CMP, Immediate, NumericConstant(0, _), true)) =>
if (OpcodeClasses.ChangesAAlways(firstParamCompiled.last.opcode)) {
} else {
case _ => secondParamCompiledUnoptimized
case _ => secondParamCompiledUnoptimized
val (effectiveComparisonType, label) = branches match {
case NoBranching => return Nil
case BranchIfTrue(l) => compType -> l
case BranchIfFalse(l) => ComparisonType.negate(compType) -> l
val branchingCompiled = effectiveComparisonType match {
case ComparisonType.Equal =>
List(AssemblyLine.relative(BEQ, Label(label)))
case ComparisonType.NotEqual =>
List(AssemblyLine.relative(BNE, Label(label)))
case ComparisonType.LessUnsigned =>
List(AssemblyLine.relative(BCC, Label(label)))
case ComparisonType.GreaterOrEqualUnsigned =>
List(AssemblyLine.relative(BCS, Label(label)))
case ComparisonType.LessOrEqualUnsigned =>
List(AssemblyLine.relative(BCC, Label(label)), AssemblyLine.relative(BEQ, Label(label)))
case ComparisonType.GreaterUnsigned =>
val x = MlCompiler.nextLabel("co")
AssemblyLine.relative(BEQ, x),
AssemblyLine.relative(BCS, Label(label)),
case ComparisonType.LessSigned =>
List(AssemblyLine.relative(BMI, Label(label)))
case ComparisonType.GreaterOrEqualSigned =>
List(AssemblyLine.relative(BPL, Label(label)))
case ComparisonType.LessOrEqualSigned =>
List(AssemblyLine.relative(BMI, Label(label)), AssemblyLine.relative(BEQ, Label(label)))
case ComparisonType.GreaterSigned =>
val x = MlCompiler.nextLabel("co")
AssemblyLine.relative(BEQ, x),
AssemblyLine.relative(BPL, Label(label)),
firstParamCompiled ++ secondParamCompiled ++ branchingCompiled
def compileWordComparison(ctx: CompilationContext, compType: ComparisonType.Value, lhs: Expression, rhs: Expression, branches: BranchSpec): List[AssemblyLine] = {
val env = ctx.env
// TODO: comparing stack variables
val b = env.get[Type]("byte")
val w = env.get[Type]("word")
val (effectiveComparisonType, x) = branches match {
case NoBranching => return Nil
case BranchIfTrue(label) => compType -> label
case BranchIfFalse(label) => ComparisonType.negate(compType) -> label
val (lh, ll, rh, rl, ram) = (lhs, env.eval(lhs), rhs, env.eval(rhs)) match {
case (_, Some(NumericConstant(lc, _)), _, Some(NumericConstant(rc, _))) =>
return if (effectiveComparisonType match {
// TODO: those masks are probably wrong
case ComparisonType.Equal =>
(lc & 0xffff) == (rc & 0xffff) // ??
case ComparisonType.NotEqual =>
(lc & 0xffff) != (rc & 0xffff) // ??
case ComparisonType.LessOrEqualUnsigned =>
(lc & 0xffff) <= (rc & 0xffff)
case ComparisonType.GreaterOrEqualUnsigned =>
(lc & 0xffff) >= (rc & 0xffff)
case ComparisonType.GreaterUnsigned =>
(lc & 0xffff) > (rc & 0xffff)
case ComparisonType.LessUnsigned =>
(lc & 0xffff) < (rc & 0xffff)
case ComparisonType.LessOrEqualSigned =>
lc.toShort <= rc.toShort
case ComparisonType.GreaterOrEqualSigned =>
lc.toShort >= rc.toShort
case ComparisonType.GreaterSigned =>
lc.toShort > rc.toShort
case ComparisonType.LessSigned =>
lc.toShort < rc.toShort
}) List(AssemblyLine.absolute(JMP, Label(x))) else Nil
case (_, Some(lc), _, Some(rc)) =>
// TODO: comparing late-bound constants
case (_, Some(lc), rv: VariableInMemory, None) =>
return compileWordComparison(ctx, ComparisonType.flip(compType), rhs, lhs, branches)
case (v: VariableExpression, None, _, Some(rc)) =>
// TODO: stack variables
(env.get[VariableInMemory]( + ".hi").toAddress,
env.get[VariableInMemory]( + ".lo").toAddress,
case (lv: VariableExpression, None, rv: VariableExpression, None) =>
// TODO: stack variables
(env.get[VariableInMemory]( + ".hi").toAddress,
env.get[VariableInMemory]( + ".lo").toAddress,
env.get[VariableInMemory]( + ".hi").toAddress,
env.get[VariableInMemory]( + ".lo").toAddress, Absolute)
effectiveComparisonType match {
case ComparisonType.Equal =>
val innerLabel = MlCompiler.nextLabel("cp")
List(AssemblyLine.absolute(LDA, ll),
AssemblyLine(CMP, ram, rl),
AssemblyLine.relative(BNE, innerLabel),
AssemblyLine.absolute(LDA, lh),
AssemblyLine(CMP, ram, rh),
AssemblyLine.relative(BEQ, Label(x)),
case ComparisonType.NotEqual =>
List(AssemblyLine.absolute(LDA, ll),
AssemblyLine(CMP, ram, rl),
AssemblyLine.relative(BNE, Label(x)),
AssemblyLine.absolute(LDA, lh),
AssemblyLine(CMP, ram, rh),
AssemblyLine.relative(BNE, Label(x)))
case ComparisonType.LessUnsigned =>
val innerLabel = MlCompiler.nextLabel("cp")
List(AssemblyLine.absolute(LDA, lh),
AssemblyLine(CMP, ram, rh),
AssemblyLine.relative(BCC, Label(x)),
AssemblyLine.relative(BNE, innerLabel),
AssemblyLine.absolute(LDA, ll),
AssemblyLine(CMP, ram, rl),
AssemblyLine.relative(BCC, Label(x)),
case ComparisonType.LessOrEqualUnsigned =>
val innerLabel = MlCompiler.nextLabel("cp")
List(AssemblyLine(LDA, ram, rh),
AssemblyLine.absolute(CMP, lh),
AssemblyLine.relative(BCC, innerLabel),
AssemblyLine.relative(BNE, x),
AssemblyLine(LDA, ram, rl),
AssemblyLine.absolute(CMP, ll),
AssemblyLine.relative(BCS, x),
case ComparisonType.GreaterUnsigned =>
val innerLabel = MlCompiler.nextLabel("cp")
List(AssemblyLine(LDA, ram, rh),
AssemblyLine.absolute(CMP, lh),
AssemblyLine.relative(BCC, Label(x)),
AssemblyLine.relative(BNE, innerLabel),
AssemblyLine(LDA, ram, rl),
AssemblyLine.absolute(CMP, ll),
AssemblyLine.relative(BCC, Label(x)),
case ComparisonType.GreaterOrEqualUnsigned =>
val innerLabel = MlCompiler.nextLabel("cp")
List(AssemblyLine.absolute(LDA, lh),
AssemblyLine(CMP, ram, rh),
AssemblyLine.relative(BCC, innerLabel),
AssemblyLine.relative(BNE, x),
AssemblyLine.absolute(LDA, ll),
AssemblyLine(CMP, ram, rl),
AssemblyLine.relative(BCS, x),
case _ => ???
// TODO: signed word comparisons
def compileInPlaceByteMultiplication(ctx: CompilationContext, v: LhsExpression, addend: Expression): List[AssemblyLine] = {
val b = ctx.env.get[Type]("byte")
ctx.env.eval(addend) match {
case Some(NumericConstant(0, _)) =>
AssemblyLine.immediate(LDA, 0) :: MlCompiler.compileByteStorage(ctx, Register.A, v)
case Some(NumericConstant(1, _)) =>
case Some(NumericConstant(x, _)) =>
compileByteMultiplication(ctx, v, x.toInt) ++ MlCompiler.compileByteStorage(ctx, Register.A, v)
case _ =>
ErrorReporting.error("Multiplying by a non-constant value is not supported", v.position)
def compileByteMultiplication(ctx: CompilationContext, v: Expression, c: Int): List[AssemblyLine] = {
val result = ListBuffer[AssemblyLine]()
// TODO: optimise
val addingCode = simpleOperation(ADC, ctx, v, IndexChoice.PreferY, preserveA = false, commutative = false)
val adc = addingCode.last
val indexing = addingCode.init
result ++= indexing
result += AssemblyLine.immediate(LDA, 0)
val mult = c & 0xff
var mask = 128
var empty = true
while (mask > 0) {
if (!empty) {
result += AssemblyLine.implied(ASL)
if ((mult & mask) != 0) {
result ++= List(AssemblyLine.implied(CLC), adc)
empty = false
mask >>>= 1
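The mask loop above appears to be a shift-and-add multiplication by a constant: it walks the multiplier from bit 7 down to bit 0, doubles the accumulator once the first set bit has been seen, and adds the operand for every set bit. The sketch below replays that scheme on plain integers, assuming the brace structure that the flattened listing suggests. `ShiftAddSketch`, `Op`, `plan` and `run` are hypothetical names introduced for illustration and are not part of Millfork.

```scala
object ShiftAddSketch {
  // Hypothetical stand-ins for the emitted instructions; not Millfork's AssemblyLine.
  sealed trait Op
  case object Asl extends Op // accumulator <<= 1, truncated to 8 bits
  case object Add extends Op // accumulator += operand, truncated to 8 bits (CLC + ADC)

  // Builds the instruction plan for multiplying by the constant c (low 8 bits only).
  def plan(c: Int): List[Op] = {
    val mult = c & 0xff
    val ops = scala.collection.mutable.ListBuffer[Op]()
    var mask = 128
    var empty = true // stays true until the first set bit has been seen
    while (mask > 0) {
      if (!empty) ops += Asl
      if ((mult & mask) != 0) {
        ops += Add
        empty = false
      }
      mask >>>= 1
    }
    ops.toList
  }

  // Interprets a plan with an 8-bit accumulator that starts at 0 (the LDA #0).
  def run(ops: List[Op], operand: Int): Int =
    ops.foldLeft(0) {
      case (acc, Asl) => (acc << 1) & 0xff
      case (acc, Add) => (acc + operand) & 0xff
    }
}
```

For example, `ShiftAddSketch.plan(10)` yields one add, two shifts, one add and one shift, which computes `((v * 2 * 2) + v) * 2`, that is `10 * v` truncated to a byte.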
def compileByteMultiplication(ctx: CompilationContext, params: List[Expression]): List[AssemblyLine] = {
val (constants, variables) = => p -> ctx.env.eval(p)).partition(_._2.exists(_.isInstanceOf[NumericConstant]))
val constant =[NumericConstant].value).foldLeft(1L)(_ * _).toInt
variables.length match {
case 0 => List(AssemblyLine.immediate(LDA, constant & 0xff))
case 1 => compileByteMultiplication(ctx, variables.head._1, constant)
case 2 =>
ErrorReporting.error("Multiplying by a non-constant value is not supported", params.head.position)
def compileInPlaceByteAddition(ctx: CompilationContext, v: LhsExpression, addend: Expression, subtract: Boolean, decimal: Boolean): List[AssemblyLine] = {
if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode)) {
ErrorReporting.warn("Unsupported decimal operation", ctx.options, v.position)
val env = ctx.env
val b = env.get[Type]("byte")
env.eval(addend) match {
case Some(NumericConstant(0, _)) => Nil
case Some(NumericConstant(1, _)) if !decimal => if (subtract) {
simpleOperation(DEC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true)
} else {
simpleOperation(INC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true)
// TODO: compile +=2 to two INCs
case Some(NumericConstant(-1, _)) if !decimal => if (subtract) {
simpleOperation(INC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true)
} else {
simpleOperation(DEC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true)
case _ =>
val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
val modifyLhs = if (subtract) {
insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = false))
} else {
insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = true))
val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v)
wrapInSedCldIfNeeded(decimal, loadLhs ++ modifyLhs ++ storeLhs)
def compileInPlaceWordOrLongAddition(ctx: CompilationContext, lhs: LhsExpression, addend: Expression, subtract: Boolean, decimal: Boolean): List[AssemblyLine] = {
if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode)) {
ErrorReporting.warn("Unsupported decimal operation", ctx.options, lhs.position)
val env = ctx.env
val b = env.get[Type]("byte")
val w = env.get[Type]("word")
val targetBytes: List[List[AssemblyLine]] = lhs match {
case v: VariableExpression =>
val variable = env.get[Variable](
List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) }
case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) =>
val lv = env.get[Variable](
val hv = env.get[Variable](
AssemblyLine.variable(ctx, STA, lv),
AssemblyLine.variable(ctx, STA, hv))
val lhsIsStack = targetBytes.head.head.opcode == TSX
val targetSize = targetBytes.size
val addendType = MlCompiler.getExpressionType(ctx, addend)
var addendSize = addendType.size
def isRhsComplex(xs: List[AssemblyLine]): Boolean = xs match {
case AssemblyLine(LDA, _, _, _) :: Nil => false
case AssemblyLine(LDA, _, _, _) :: AssemblyLine(LDX, _, _, _) :: Nil => false
case _ => true
def isRhsStack(xs: List[AssemblyLine]): Boolean = xs.exists(_.opcode == TSX)
val (calculateRhs, addendByteRead0): (List[AssemblyLine], List[List[AssemblyLine]]) = env.eval(addend) match {
case Some(constant) =>
addendSize = targetSize
Nil -> List.tabulate(targetSize)(i => List(AssemblyLine.immediate(LDA, constant.subbyte(i))))
case None =>
addendSize match {
case 1 =>
val base = MlCompiler.compile(ctx, addend, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
if (subtract) {
if (isRhsComplex(base)) {
if (isRhsStack(base)) {
ErrorReporting.warn("Subtracting a stack-based value", ctx.options)
(base ++ List(AssemblyLine.implied(PHA))) -> List(List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, 0x101)))
} else {
Nil -> :: Nil)
} else {
base -> List(Nil)
case 2 =>
val base = MlCompiler.compile(ctx, addend, Some(w -> RegisterVariable(Register.AX, w)), NoBranching)
if (isRhsStack(base)) {
val fixedBase = MlCompiler.compile(ctx, addend, Some(w -> RegisterVariable(Register.AY, w)), NoBranching)
if (subtract) {
ErrorReporting.warn("Subtracting a stack-based value", ctx.options)
if (isRhsComplex(base)) {
} else {
Nil -> fixedBase
} else {
fixedBase -> List(Nil, List(AssemblyLine.implied(TYA)))
} else {
if (subtract) {
if (isRhsComplex(base)) {
(base ++ List(
) -> List(
List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, 0x102)),
List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, 0x101)))
} else {
Nil -> :: Nil)
} else {
if (lhsIsStack) {
val fixedBase = MlCompiler.compile(ctx, addend, Some(w -> RegisterVariable(Register.AY, w)), NoBranching)
fixedBase -> List(Nil, List(AssemblyLine.implied(TYA)))
} else {
base -> List(Nil, List(AssemblyLine.implied(TXA)))
case _ => Nil -> (addend match {
case vv: VariableExpression =>
val source = env.get[Variable](
List.tabulate(addendSize)(i => AssemblyLine.variable(ctx, LDA, source, i))
val addendByteRead = addendByteRead0 ++ List.fill((targetSize - addendByteRead0.size) max 0)(List(AssemblyLine.immediate(LDA, 0)))
val buffer = mutable.ListBuffer[AssemblyLine]()
buffer ++= calculateRhs
buffer += AssemblyLine.implied(if (subtract) SEC else CLC)
val extendMultipleBytes = targetSize > addendSize + 1
val extendAtLeastOneByte = targetSize > addendSize
for (i <- 0 until targetSize) {
if (subtract) {
if (addendSize < targetSize && addendType.isSigned) {
// TODO: sign extension
buffer ++= staTo(LDA, targetBytes(i))
buffer ++= ldTo(SBC, addendByteRead(i))
buffer ++= targetBytes(i)
} else {
if (i >= addendSize) {
if (addendType.isSigned) {
val label = MlCompiler.nextLabel("sx")
buffer += AssemblyLine.implied(TXA)
if (i == addendSize) {
buffer += AssemblyLine.immediate(ORA, 0x7f)
buffer += AssemblyLine.relative(BMI, label)
buffer += AssemblyLine.immediate(LDA, 0)
buffer += AssemblyLine.label(label)
if (extendMultipleBytes) buffer += AssemblyLine.implied(TAX)
} else {
buffer += AssemblyLine.immediate(LDA, 0)
} else {
buffer ++= addendByteRead(i)
if (addendType.isSigned && i == addendSize - 1 && extendAtLeastOneByte) {
buffer += AssemblyLine.implied(TAX)
buffer ++= staTo(ADC, targetBytes(i))
buffer ++= targetBytes(i)
for (i <- 0 until calculateRhs.count(a => a.opcode == PHA) - calculateRhs.count(a => a.opcode == PLA)) {
buffer += AssemblyLine.implied(PLA)
wrapInSedCldIfNeeded(decimal, buffer.toList)
def compileInPlaceByteBitOp(ctx: CompilationContext, v: LhsExpression, param: Expression, operation: Opcode.Value): List[AssemblyLine] = {
val env = ctx.env
val b = env.get[Type]("byte")
(operation, env.eval(param)) match {
case (EOR, Some(NumericConstant(0, _)))
| (ORA, Some(NumericConstant(0, _)))
| (AND, Some(NumericConstant(0xff, _)))
| (AND, Some(NumericConstant(-1, _))) =>
case _ =>
val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
val modifyLhs = simpleOperation(operation, ctx, param, IndexChoice.PreferY, preserveA = true, commutative = true)
val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v)
loadLhs ++ modifyLhs ++ storeLhs
def compileInPlaceWordOrLongBitOp(ctx: CompilationContext, lhs: LhsExpression, param: Expression, operation: Opcode.Value): List[AssemblyLine] = {
val env = ctx.env
val b = env.get[Type]("byte")
val w = env.get[Type]("word")
val targetBytes: List[List[AssemblyLine]] = lhs match {
case v: VariableExpression =>
val variable = env.get[Variable](
List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) }
case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) =>
val lv = env.get[Variable](
val hv = env.get[Variable](
AssemblyLine.variable(ctx, STA, lv),
AssemblyLine.variable(ctx, STA, hv))
case _ =>
val lo = targetBytes.head
val targetSize = targetBytes.size
val paramType = MlCompiler.getExpressionType(ctx, param)
var paramSize = paramType.size
val extendMultipleBytes = targetSize > paramSize + 1
val extendAtLeastOneByte = targetSize > paramSize
val (calculateRhs, addendByteRead) = env.eval(param) match {
case Some(constant) =>
paramSize = targetSize
Nil -> List.tabulate(targetSize)(i => List(AssemblyLine.immediate(LDA, constant.subbyte(i))))
case None =>
paramSize match {
case 1 =>
val base = MlCompiler.compile(ctx, param, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
base -> List(Nil)
case 2 =>
val base = MlCompiler.compile(ctx, param, Some(w -> RegisterVariable(Register.AX, w)), NoBranching)
base -> List(Nil, List(AssemblyLine.implied(TXA)))
case _ => Nil -> (param match {
case vv: VariableExpression =>
val source = env.get[Variable](
List.tabulate(paramSize)(i => AssemblyLine.variable(ctx, LDA, source, i))
val AllOnes = (1L << (8 * targetSize)) - 1
(operation, env.eval(param)) match {
case (EOR, Some(NumericConstant(0, _)))
| (ORA, Some(NumericConstant(0, _)))
| (AND, Some(NumericConstant(AllOnes, _))) =>
case _ =>
val buffer = mutable.ListBuffer[AssemblyLine]()
buffer ++= calculateRhs
for (i <- 0 until targetSize) {
if (i < paramSize) {
buffer ++= addendByteRead(i)
if (paramType.isSigned && i == paramSize - 1 && extendAtLeastOneByte) {
buffer += AssemblyLine.implied(TAX)
} else {
if (paramType.isSigned) {
val label = MlCompiler.nextLabel("sx")
buffer += AssemblyLine.implied(TXA)
if (i == paramSize) {
buffer += AssemblyLine.immediate(ORA, 0x7f)
buffer += AssemblyLine.relative(BMI, label)
buffer += AssemblyLine.immediate(LDA, 0)
buffer += AssemblyLine.label(label)
if (extendMultipleBytes) buffer += AssemblyLine.implied(TAX)
} else {
buffer += AssemblyLine.immediate(LDA, 0)
buffer ++= staTo(operation, targetBytes(i))
buffer ++= targetBytes(i)
for (i <- 0 until calculateRhs.count(a => a.opcode == PHA) - calculateRhs.count(a => a.opcode == PLA)) {
buffer += AssemblyLine.implied(PLA)
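The sign-extension sequence used in the word and long operations above (`ORA #$7F`, `BMI`, `LDA #0`) turns the top byte of a narrower signed operand into a filler byte of either `$FF` or `$00`. A minimal sketch of that computation on plain integers follows. `SignExtensionSketch`, `extensionByte` and `widen` are hypothetical names used only for illustration, and the sketch assumes the flattened listing reflects the order in which the instructions are emitted.

```scala
object SignExtensionSketch {
  // Returns the byte the ORA #$7F / BMI / LDA #0 sequence leaves in A,
  // given the most significant byte of the narrower signed operand.
  def extensionByte(topByte: Int): Int = {
    val a = (topByte & 0xff) | 0x7f // ORA #$7F: 0xFF if bit 7 was set, 0x7F otherwise
    if ((a & 0x80) != 0) a          // BMI taken: A keeps 0xFF
    else 0                          // branch not taken: LDA #0
  }

  // Widens a signed byte to `size` bytes, least significant byte first.
  def widen(value: Byte, size: Int): List[Int] =
    (value & 0xff) :: List.fill(size - 1)(extensionByte(value))
}
```

Because `ORA #$7F` sets every bit except bit 7, the negative flag alone decides the outcome, so no comparison instruction is needed.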

package millfork.compiler
import millfork.{CompilationFlag, CompilationOptions}
import millfork.env.{Environment, MangledFunction, NormalFunction}
/** @author Karol Stasiak */
case class CompilationContext(env: Environment, function: NormalFunction, extraStackOffset: Int, options: CompilationOptions){
def addStack(i: Int): CompilationContext = this.copy(extraStackOffset = extraStackOffset + i)

package millfork.env
import millfork.error.ErrorReporting
import millfork.node.Position
object Constant {
val Zero: Constant = NumericConstant(0, 1)
val One: Constant = NumericConstant(1, 1)
def error(msg: String, position: Option[Position] = None): Constant = {
ErrorReporting.error(msg, position)
def minimumSize(value: Long): Int = if (value < -128 || value > 255) 2 else 1 // TODO !!!
import millfork.env.Constant.minimumSize
import millfork.error.ErrorReporting
import millfork.node.Position
sealed trait Constant {
def asl(i: Constant): Constant = i match {
case NumericConstant(sa, _) => asl(sa.toInt)
case _ => CompoundConstant(MathOperator.Shl, this, i)
def asl(i: Int): Constant = CompoundConstant(MathOperator.Shl, this, NumericConstant(i, 1))
def requiredSize: Int
def +(that: Constant): Constant = CompoundConstant(MathOperator.Plus, this, that)
def -(that: Constant): Constant = CompoundConstant(MathOperator.Minus, this, that)
def +(that: Long): Constant = if (that == 0) this else this + NumericConstant(that, minimumSize(that))
def -(that: Long): Constant = this + (-that)
def loByte: Constant = {
if (requiredSize == 1) return this
HalfWordConstant(this, hi = false)
def hiByte: Constant = {
if (requiredSize == 1) Constant.Zero
else HalfWordConstant(this, hi = true)
def subbyte(index: Int): Constant = {
if (requiredSize <= index) Constant.Zero
else index match {
case 0 => loByte
case 1 => hiByte
case _ => SubbyteConstant(this, index)
def isLowestByteAlwaysEqual(i: Int) : Boolean = false
def quickSimplify: Constant = this
case class UnexpandedConstant(name: String, requiredSize: Int) extends Constant
case class NumericConstant(value: Long, requiredSize: Int) extends Constant {
if (requiredSize == 1) {
if (value < -128 || value > 255) {
throw new IllegalArgumentException("The constant is too big")
override def isLowestByteAlwaysEqual(i: Int) : Boolean = (value & 0xff) == (i&0xff)
override def asl(i: Int) = NumericConstant(value << i, requiredSize + i / 8)
override def +(that: Constant): Constant = that + value
override def +(that: Long) = NumericConstant(value + that, minimumSize(value + that))
override def toString: String = if (value > 9) value.formatted("$%X") else value.toString
case class MemoryAddressConstant(var thing: ThingInMemory) extends Constant {
override def requiredSize = 2
override def toString: String =
case class HalfWordConstant(base: Constant, hi: Boolean) extends Constant {
override def quickSimplify: Constant = {
val simplified = base.quickSimplify
simplified match {
case NumericConstant(x, size) => if (hi) {
if (size == 1) Constant.Zero else NumericConstant((x >> 8) & 0xff, 1)
} else {
NumericConstant(x & 0xff, 1)
case _ => HalfWordConstant(simplified, hi)
override def requiredSize = 1
override def toString: String = base + (if (hi) ".hi" else ".lo")
case class SubbyteConstant(base: Constant, index: Int) extends Constant {
override def quickSimplify: Constant = {
val simplified = base.quickSimplify
simplified match {
case NumericConstant(x, size) => if (index >= size) {
} else {
NumericConstant((x >> (index * 8)) & 0xff, 1)
case _ => SubbyteConstant(simplified, index)
override def requiredSize = 1
override def toString: String = base + (index match {
case 0 => ".lo"
case 1 => ".hi"
case 2 => ".b2"
case 3 => ".b3"
object MathOperator extends Enumeration {
val Plus, Minus, Times, Shl, Shr,
DecimalPlus, DecimalMinus, DecimalTimes, DecimalShl, DecimalShr,
And, Or, Exor = Value
case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Constant) extends Constant {
override def quickSimplify: Constant = {
val l = lhs.quickSimplify
val r = rhs.quickSimplify
(l, r) match {
case (NumericConstant(lv, ls), NumericConstant(rv, rs)) =>
var size = ls max rs
val value = operator match {
case MathOperator.Plus => lv + rv
case MathOperator.Minus => lv - rv
case MathOperator.Times => lv * rv
case MathOperator.Shl => lv << rv
case MathOperator.Shr => lv >> rv
case MathOperator.Exor => lv ^ rv
case MathOperator.Or => lv | rv
case MathOperator.And => lv & rv
case _ => return this
operator match {
case MathOperator.Times | MathOperator.Shl =>
val mask = (1L << (size * 8)) - 1 // Long arithmetic: an Int shift by 32 would wrap around for 4-byte values
if (value != (value & mask)){
size = ls + rs
case _ =>
NumericConstant(value, size)
case _ => CompoundConstant(operator, l, r)
import MathOperator._
override def +(that: Constant): Constant = {
that match {
case NumericConstant(n, _) => this + n
case _ => super.+(that)
override def +(that: Long): Constant = {
if (that == 0) {
return this
val That = that
val MinusThat = -that
this match {
case CompoundConstant(Plus, NumericConstant(MinusThat, _), r) => r
case CompoundConstant(Plus, l, NumericConstant(MinusThat, _)) => l
case CompoundConstant(Plus, NumericConstant(x, _), r) => CompoundConstant(Plus, r, NumericConstant(x + that, minimumSize(x + that)))
case CompoundConstant(Plus, l, NumericConstant(x, _)) => CompoundConstant(Plus, l, NumericConstant(x + that, minimumSize(x + that)))
case CompoundConstant(Minus, l, NumericConstant(That, _)) => l
case _ => CompoundConstant(Plus, this, NumericConstant(that, minimumSize(that)))
private def plhs = lhs match {
case _: NumericConstant | _: MemoryAddressConstant => lhs
case _ => "(" + lhs + ')'
private def prhs = rhs match {
case _: NumericConstant | _: MemoryAddressConstant => rhs
case _ => "(" + rhs + ')'
override def toString: String = {
operator match {
case Plus => f"$plhs + $prhs"
case Minus => f"$plhs - $prhs"
case Times => f"$plhs * $prhs"
case Shl => f"$plhs << $prhs"
case Shr => f"$plhs >> $prhs"
case DecimalPlus => f"$plhs +' $prhs"
case DecimalMinus => f"$plhs -' $prhs"
case DecimalTimes => f"$plhs *' $prhs"
case DecimalShl => f"$plhs <<' $prhs"
case DecimalShr => f"$plhs >>' $prhs"
case And => f"$plhs & $prhs"
case Or => f"$plhs | $prhs"
case Exor => f"$plhs ^ $prhs"
override def requiredSize: Int = lhs.requiredSize max rhs.requiredSize
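The size check in `quickSimplify` relies on a mask of `size * 8` one bits. On the JVM, the shift count of an `Int` shift is taken modulo 32, so such a mask is only reliable for 4-byte values when it is built from a `Long`. The sketch below contrasts the two forms. `SizeMaskSketch`, `intMask`, `longMask` and `overflows` are hypothetical helper names, not Millfork functions.

```scala
object SizeMaskSketch {
  // Mask built with Int arithmetic, as in `(1 << (size * 8)) - 1`.
  def intMask(size: Int): Long = (1 << (size * 8)) - 1

  // Mask built with Long arithmetic.
  def longMask(size: Int): Long = (1L << (size * 8)) - 1

  // True when `value` does not fit in `size` bytes.
  def overflows(value: Long, size: Int): Boolean =
    value != (value & longMask(size))
}
```

With `size = 4`, the `Int` form shifts by 32, which is the same as shifting by 0, so the mask collapses to 0 and every non-zero value would look like an overflow.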

package millfork.env
import java.util.concurrent.atomic.AtomicLong
import millfork.{CompilationFlag, CompilationOptions}
import millfork.assembly.Opcode
import millfork.compiler._
import millfork.error.ErrorReporting
import millfork.node._
import millfork.output.VariableAllocator
import scala.collection.mutable
/** @author Karol Stasiak */
//noinspection NotImplementedCode
class Environment(val parent: Option[Environment], val prefix: String) {
private var baseStackOffset = 0x101
private val relVarId = new AtomicLong
def genRelativeVariable(constant: Constant, typ: Type, zeropage: Boolean): RelativeVariable = {
val variable = RelativeVariable(".rv__" + relVarId.incrementAndGet().formatted("%06d"), constant, typ, zeropage = zeropage)
addThing(variable, None)
def allThings: Environment = {
val allThings: Map[String, Thing] = {
case m: FunctionInMemory =>
case m: InlinedFunction =>
case _ => Map[String, Thing]()
}.fold(things.toMap)(_ ++ _)
val e = new Environment(None, "")
e.things ++= allThings
private def getAllPrefixedThings = { { case (n, th) => (if (n.startsWith(".")) n else prefix + n, th) }
def getAllLocalVariables: List[Variable] = things.values.flatMap {
case v: Variable =>
case _ => None
def allPreallocatables: List[PrellocableThing] = things.values.flatMap {
case m: NormalFunction => Some(m)
case m: InitializedArray => Some(m)
case _ => None
def allConstants: List[ConstantThing] = things.values.flatMap {
case m: NormalFunction => m.environment.allConstants
case m: InlinedFunction => m.environment.allConstants
case m: ConstantThing => List(m)
case _ => Nil
def allocateVariables(nf: Option[NormalFunction], callGraph: CallGraph, allocator: VariableAllocator, options: CompilationOptions, onEachVariable: (String, Int) => Unit): Unit = {
val b = get[Type]("byte")
val p = get[Type]("pointer")
var params = nf.fold(List[String]()) { f =>
f.params match {
case NormalParamSignature(ps) => =>
case _ =>
val toAdd = things.values.flatMap {
case m: UninitializedMemory =>
val vertex = if (options.flag(CompilationFlag.VariableOverlap)) {
nf.fold[VariableVertex](GlobalVertex) { f =>
if (m.alloc == VariableAllocationMethod.Static) {
} else if (params( {
} else {
} else GlobalVertex
m.alloc match {
case VariableAllocationMethod.None =>
case VariableAllocationMethod.Zeropage =>
m.sizeInBytes match {
case 2 =>
val addr =
allocator.allocatePointer(callGraph, vertex)
onEachVariable(, addr)
ConstantThing( + "`", NumericConstant(addr, 2), p)
case VariableAllocationMethod.Auto | VariableAllocationMethod.Static =>
m.sizeInBytes match {
case 0 => Nil
case 2 =>
val addr =
allocator.allocateBytes(callGraph, vertex, options, 2)
onEachVariable(, addr)
ConstantThing( + "`", NumericConstant(addr, 2), p)
case count =>
val addr = allocator.allocateBytes(callGraph, vertex, options, count)
onEachVariable(, addr)
ConstantThing( + "`", NumericConstant(addr, 2), p)
case f: NormalFunction =>
f.environment.allocateVariables(Some(f), callGraph, allocator, options, onEachVariable)
case _ => Nil
val tagged: List[(String, Thing)] = => -> x)
things ++= tagged
val things: mutable.Map[String, Thing] = mutable.Map()
private def addThing(t: Thing, position: Option[Position]): Unit = {
assertNotDefined(, position)
things( = t
def removeVariable(str: String): Unit = {
things -= str
things -= str + ".addr"
def get[T <: Thing : Manifest](name: String, position: Option[Position] = None): T = {
val clazz = implicitly[Manifest[T]].runtimeClass
if (things.contains(name)) {
val t: Thing = things(name)
if ((t ne null) && clazz.isInstance(t)) {
} else {
ErrorReporting.fatal(s"`$name` is not a ${clazz.getSimpleName}", position)
} else parent.fold {
ErrorReporting.fatal(s"${clazz.getSimpleName} `$name` is not defined", position)
} {
_.get[T](name, position)
def maybeGet[T <: Thing : Manifest](name: String): Option[T] = {
if (things.contains(name)) {
val t: Thing = things(name)
val clazz = implicitly[Manifest[T]].runtimeClass
if ((t ne null) && clazz.isInstance(t)) {
} else {
} else parent.flatMap {
def getArrayOrPointer(arrayName: String): Thing = {
orElse(maybeGet[ThingInMemory](arrayName + ".array")).
getOrElse(ErrorReporting.fatal(s"`$arrayName` is not an array or a pointer"))
if (parent.isEmpty) {
addThing(VoidType, None)
addThing(BuiltInBooleanType, None)
addThing(BasicPlainType("byte", 1), None)
addThing(BasicPlainType("word", 2), None)
addThing(BasicPlainType("long", 4), None)
addThing(DerivedPlainType("pointer", get[PlainType]("word"), isSigned = false), None)
addThing(DerivedPlainType("ubyte", get[PlainType]("byte"), isSigned = false), None)
addThing(DerivedPlainType("sbyte", get[PlainType]("byte"), isSigned = true), None)
addThing(DerivedPlainType("cent", get[PlainType]("byte"), isSigned = false), None)
val trueType = ConstantBooleanType("true$", value = true)
val falseType = ConstantBooleanType("false$", value = false)
addThing(trueType, None)
addThing(falseType, None)
addThing(ConstantThing("true", NumericConstant(0, 0), trueType), None)
addThing(ConstantThing("false", NumericConstant(0, 0), falseType), None)
addThing(FlagBooleanType("set_carry", Opcode.BCS, Opcode.BCC), None)
addThing(FlagBooleanType("clear_carry", Opcode.BCC, Opcode.BCS), None)
addThing(FlagBooleanType("set_overflow", Opcode.BVS, Opcode.BVC), None)
addThing(FlagBooleanType("clear_overflow", Opcode.BVC, Opcode.BVS), None)
addThing(FlagBooleanType("set_zero", Opcode.BEQ, Opcode.BNE), None)
addThing(FlagBooleanType("clear_zero", Opcode.BNE, Opcode.BEQ), None)
addThing(FlagBooleanType("set_negative", Opcode.BMI, Opcode.BPL), None)
addThing(FlagBooleanType("clear_negative", Opcode.BPL, Opcode.BMI), None)
def assertNotDefined(name: String, position: Option[Position]): Unit = {
if (things.contains(name) || parent.exists(_.things.contains(name)))
ErrorReporting.fatal(s"`$name` is already defined", position)
def registerType(stmt: TypeDefinitionStatement): Unit = {
// addThing(DerivedPlainType(, get(stmt.parent)))
def sequence[A](a: List[Option[A]]): Option[List[A]] = a match {
case Nil => Some(Nil)
case None :: _ => None
case Some(r) :: t => sequence(t) map (r :: _)
def evalVariableAndConstantSubParts(e: Expression): (Option[Expression], Constant) =
e match {
case SumExpression(params, false) =>
val (constants, variables) = { case (sign, expr) => (sign, expr, eval(expr)) }.partition(_._3.isDefined)
val constant = eval(SumExpression( => (x._1, x._2)), decimal = false)).get
val variable = variables match {
case Nil => None
case List((false, x, _)) => Some(x)
case _ => Some(SumExpression( => (x._1, x._2)), decimal = false))
variable -> constant
case _ => eval(e) match {
case Some(c) => None -> c
case None => Some(e) -> Constant.Zero
def eval(e: Expression): Option[Constant] = {
e match {
case LiteralExpression(value, size) => Some(NumericConstant(value, size))
case VariableExpression(name) =>
case IndexedExpression(_, _) => None
case HalfWordExpression(param, hi) => eval(param).map(c => if (hi) c.hiByte else c.loByte) // evaluate the inner expression; eval(e) would recurse forever
case SumExpression(params, decimal) => {
case (minus, param) => (minus, eval(param))
}.foldLeft(Some(Constant.Zero).asInstanceOf[Option[Constant]]) { (oc, pair) =>
oc.flatMap { c =>
pair match {
case (_, None) => None
case (minus, Some(addend)) =>
val op = if (decimal) {
if (minus) MathOperator.DecimalMinus else MathOperator.DecimalPlus
} else {
if (minus) MathOperator.Minus else MathOperator.Plus
Some(CompoundConstant(op, c, addend))
case SeparateBytesExpression(h, l) => for {
lc <- eval(l)
hc <- eval(h)
} yield hc.asl(8) + lc
case FunctionCallExpression(name, params) =>
name match {
case "*" =>
constantOperation(MathOperator.Times, params)
case "&&" | "&" =>
constantOperation(MathOperator.And, params)
case "^" =>
constantOperation(MathOperator.Exor, params)
case "||" | "|" =>
constantOperation(MathOperator.Or, params)
case _ =>
private def constantOperation(op: MathOperator.Value, params: List[Expression]) = {[Option[Constant]] { (oc, om) =>
for {
c <- oc
m <- om
} yield CompoundConstant(op, c, m)
def registerFunction(stmt: FunctionDeclarationStatement, options: CompilationOptions): Unit = {
val w = get[Type]("word")
val name =
val resultType = get[Type](stmt.resultType)
if (stmt.reentrant && stmt.interrupt) ErrorReporting.error(s"Reentrant function `$name` cannot be an interrupt handler", stmt.position)
if (stmt.reentrant && stmt.params.nonEmpty) ErrorReporting.error(s"Reentrant function `$name` cannot have parameters", stmt.position)
if (stmt.interrupt && stmt.params.nonEmpty) ErrorReporting.error(s"Interrupt function `$name` cannot have parameters", stmt.position)
if (stmt.inlined) {
if (!stmt.assembly) {
if (stmt.params.nonEmpty) ErrorReporting.error(s"Inline non-assembly function `$name` cannot have parameters", stmt.position) // TODO: ???
if (resultType != VoidType) ErrorReporting.error(s"Inline non-assembly function `$name` must return void", stmt.position)
if (stmt.params.exists(_.assemblyParamPassingConvention.inNonInlinedOnly))
ErrorReporting.error(s"Inline function `$name` cannot have by-variable parameters", stmt.position)
} else {
if (!stmt.assembly) {
if (stmt.params.exists(!_.assemblyParamPassingConvention.isInstanceOf[ByVariable]))
ErrorReporting.error(s"Non-assembly function `$name` cannot have non-variable parameters", stmt.position)
if (stmt.params.exists(_.assemblyParamPassingConvention.inInlinedOnly))
ErrorReporting.error(s"Non-inline function `$name` cannot have inlinable parameters", stmt.position)
val env = new Environment(Some(this), name + "$")
stmt.params.foreach(p => env.registerParameter(p))
val params = if (stmt.assembly) {
AssemblyParamSignature( {
pd =>
val typ = env.get[Type](pd.typ)
pd.assemblyParamPassingConvention match {
case ByVariable(vn) =>
AssemblyParam(typ, env.get[MemoryVariable](vn), AssemblyParameterPassingBehaviour.Copy)
case ByRegister(reg) =>
AssemblyParam(typ, RegisterVariable(reg, typ), AssemblyParameterPassingBehaviour.Copy)
case ByConstant(vn) =>
AssemblyParam(typ, Placeholder(vn, typ), AssemblyParameterPassingBehaviour.ByConstant)
case ByReference(vn) =>
AssemblyParam(typ, Placeholder(vn, typ), AssemblyParameterPassingBehaviour.ByReference)
} else {
NormalParamSignature( { pd =>
stmt.statements match {
case None =>
stmt.address match {
case None =>
ErrorReporting.error(s"Extern function `${}` needs an address", stmt.position)
case Some(a) =>
val addr = eval(a).getOrElse(Constant.error(s"Address of `${}` is not a constant", stmt.position))
val mangled = ExternFunction(
addThing(mangled, stmt.position)
registerAddressConstant(mangled, stmt.position)
addThing(ConstantThing(name + '`', addr, w), stmt.position)
case Some(statements) =>
statements.foreach {
case v: VariableDeclarationStatement => env.registerVariable(v, options)
case _ => ()
val executableStatements = statements.flatMap {
case e: ExecutableStatement => Some(e)
case _ => None
val needsExtraRTS = !stmt.inlined && !stmt.assembly && (statements.isEmpty || !statements.last.isInstanceOf[ReturnStatement])
if (stmt.inlined) {
val mangled = new InlinedFunction(
executableStatements ++ (if (needsExtraRTS) List(AssemblyStatement.implied(Opcode.RTS, elidable = true)) else Nil),
addThing(mangled, stmt.position)
} else {
var stackVariablesSize = {
case StackVariable(n, t, _) if !n.contains(".") => t.size
case _ => 0
val mangled = NormalFunction(
stackVariablesSize, => this.eval(a).getOrElse(Constant.error(s"Address of `${}` is not a constant"))),
executableStatements ++ (if (needsExtraRTS) List(ReturnStatement(None)) else Nil),
interrupt = stmt.interrupt,
reentrant = stmt.reentrant,
position = stmt.position
addThing(mangled, stmt.position)
registerAddressConstant(mangled, stmt.position)
private def registerAddressConstant(thing: ThingInMemory, position: Option[Position]): Unit = {
val addr = thing.toAddress
addThing(ConstantThing( + ".addr", addr, get[Type]("pointer")), position)
addThing(ConstantThing( + ".addr.hi", addr.hiByte, get[Type]("byte")), position)
addThing(ConstantThing( + ".addr.lo", addr.loByte, get[Type]("byte")), position)
def registerParameter(stmt: ParameterDeclaration): Unit = {
val typ = get[Type](stmt.typ)
val b = get[Type]("byte")
val p = get[Type]("pointer")
stmt.assemblyParamPassingConvention match {
case ByVariable(name) =>
val zp = == "pointer" // TODO
val v = MemoryVariable(prefix + name, typ, if (zp) VariableAllocationMethod.Zeropage else VariableAllocationMethod.Auto)
addThing(v, stmt.position)
registerAddressConstant(v, stmt.position)
if (typ.size == 2) {
val addr = v.toAddress
addThing(RelativeVariable( + ".hi", addr + 1, b, zeropage = zp), stmt.position)
addThing(RelativeVariable( + ".lo", addr, b, zeropage = zp), stmt.position)
case ByRegister(_) => ()
case ByConstant(name) =>
val v = ConstantThing(prefix + name, UnexpandedConstant(prefix + name, typ.size), typ)
addThing(v, stmt.position)
case ByReference(name) =>
val addr = UnexpandedConstant(prefix + name, typ.size)
val v = RelativeVariable(prefix + name, addr, p, zeropage = false)
addThing(v, stmt.position)
addThing(RelativeVariable( + ".hi", addr + 1, b, zeropage = false), stmt.position)
addThing(RelativeVariable( + ".lo", addr, b, zeropage = false), stmt.position)
def registerArray(stmt: ArrayDeclarationStatement): Unit = {
val b = get[Type]("byte")
val p = get[Type]("pointer")
stmt.elements match {
case None =>
stmt.length match {
case None => ErrorReporting.error(s"Array `${}` has neither size nor contents", stmt.position)
case Some(l) =>
val address = => eval(a).getOrElse(ErrorReporting.fatal(s"Array `${}` has non-constant address", stmt.position)))
val lengthConst = eval(l).getOrElse(Constant.error(s"Array `${}` has non-constant length", stmt.position))
lengthConst match {
case NumericConstant(length, _) =>
if (length > 0xffff || length < 0) ErrorReporting.error(s"Array `${}` has invalid length", stmt.position)
val array = address match {
case None => UninitializedArray( + ".array", length.toInt)
case Some(aa) => RelativeArray( + ".array", aa, length.toInt)
addThing(array, stmt.position)
registerAddressConstant(MemoryVariable(, p, VariableAllocationMethod.None), stmt.position)
val a = address match {
case None => array.toAddress
case Some(aa) => aa
addThing(RelativeVariable( + ".first", a, b, zeropage = false), stmt.position)
addThing(ConstantThing(, a, p), stmt.position)
addThing(ConstantThing( + ".hi", a.hiByte, b), stmt.position)
addThing(ConstantThing( + ".lo", a.loByte, b), stmt.position)
addThing(ConstantThing( + ".array.hi", a.hiByte, b), stmt.position)
addThing(ConstantThing( + ".array.lo", a.loByte, b), stmt.position)
if (length < 256) {
addThing(ConstantThing( + ".length", lengthConst, b), stmt.position)
case _ => ErrorReporting.error(s"Array `${}` has weird length", stmt.position)
case Some(contents) =>
stmt.length match {
case None =>
case Some(l) =>
val lengthConst = eval(l).getOrElse(Constant.error(s"Array `${}` has non-constant length", stmt.position))
lengthConst match {
case NumericConstant(ll, _) =>
if (ll != contents.length) ErrorReporting.error(s"Array `${}` has different declared and actual length", stmt.position)
case _ => ErrorReporting.error(s"Array `${}` has weird length", stmt.position)
val length = contents.length
if (length > 0xffff || length < 0) ErrorReporting.error(s"Array `${}` has invalid length", stmt.position)
val address = => eval(a).getOrElse(Constant.error(s"Array `${}` has non-constant address", stmt.position)))
val data = => eval(x).getOrElse(Constant.error(s"Array `${}` has non-constant contents", stmt.position)))
val array = InitializedArray( + ".array", address, data)
addThing(array, stmt.position)
registerAddressConstant(MemoryVariable(, p, VariableAllocationMethod.None), stmt.position)
val a = address match {
case None => array.toAddress
case Some(aa) => aa
addThing(RelativeVariable( + ".first", a, b, zeropage = false), stmt.position)
addThing(ConstantThing(, a, p), stmt.position)
addThing(ConstantThing( + ".hi", a.hiByte, b), stmt.position)
addThing(ConstantThing( + ".lo", a.loByte, b), stmt.position)
addThing(ConstantThing( + ".array.hi", a.hiByte, b), stmt.position)
addThing(ConstantThing( + ".array.lo", a.loByte, b), stmt.position)
if (length < 256) {
addThing(ConstantThing( + ".length", NumericConstant(length, 1), b), stmt.position)
def registerVariable(stmt: VariableDeclarationStatement, options: CompilationOptions): Unit = {
if (stmt.volatile) {
ErrorReporting.warn("`volatile` not yet supported", options)
val name =
val position = stmt.position
if (stmt.stack && parent.isEmpty) {
if (stmt.stack && ErrorReporting.error(s"`$name` is static or global and cannot be on stack", position)
val b = get[Type]("byte")
val typ = get[PlainType](stmt.typ)
if (stmt.typ == "pointer") {
// if (stmt.constant) {
// ErrorReporting.error(s"Pointer `${}` cannot be constant")
// }
stmt.address.flatMap(eval) match {
case Some(NumericConstant(a, _)) =>
if ((a & 0xff00) != 0)
ErrorReporting.error(s"Pointer `${}` cannot be located outside the zero page")
case _ => ()
if (stmt.constant) {
if (stmt.stack) ErrorReporting.error(s"`$name` is a constant and cannot be on stack", position)
if (stmt.address.isDefined) ErrorReporting.error(s"`$name` is a constant and cannot have an address", position)
if (stmt.initialValue.isEmpty) ErrorReporting.error(s"`$name` is a constant and requires a value", position)
val constantValue: Constant = stmt.initialValue.flatMap(eval).getOrElse(Constant.error(s"`$name` has a non-constant value", position))
if (constantValue.requiredSize > typ.size) ErrorReporting.error(s"`$name` has an invalid value: not in the range of `$typ`", position)
addThing(ConstantThing(prefix + name, constantValue, typ), stmt.position)
if (typ.size == 2) {
addThing(ConstantThing(prefix + name + ".hi", constantValue.hiByte, b), stmt.position) // high byte of the value, not value + 1
addThing(ConstantThing(prefix + name + ".lo", constantValue.loByte, b), stmt.position)
} else {
if (stmt.stack && ErrorReporting.error(s"`$name` is static or global and cannot be on stack", position)
if (stmt.initialValue.isDefined) ErrorReporting.error(s"`$name` is not a constant and cannot have a value", position)
if (stmt.stack) {
val v = StackVariable(prefix + name, typ, this.baseStackOffset)
baseStackOffset += typ.size
addThing(v, stmt.position)
if (typ.size == 2) {
addThing(StackVariable(prefix + name + ".lo", b, baseStackOffset), stmt.position)
addThing(StackVariable(prefix + name + ".hi", b, baseStackOffset + 1), stmt.position)
} else {
val (v, addr) = stmt.address.fold[(VariableInMemory, Constant)]({
val alloc = if ( == "pointer") VariableAllocationMethod.Zeropage else if ( VariableAllocationMethod.Static else VariableAllocationMethod.Auto
val v = MemoryVariable(prefix + name, typ, alloc)
registerAddressConstant(v, stmt.position)
(v, v.toAddress)
})(a => {
val addr = eval(a).getOrElse(Constant.error(s"Address of `$name` has a non-constant value", position))
val zp = addr match {
case NumericConstant(n, _) => n < 0x100
case _ => false
(RelativeVariable(prefix + name, addr, typ, zeropage = zp), addr)
addThing(v, stmt.position)
if (!v.isInstanceOf[MemoryVariable]) {
addThing(ConstantThing( + "`", addr, b), stmt.position)
if (typ.size == 2) {
addThing(RelativeVariable(prefix + name + ".hi", addr + 1, b, zeropage = v.zeropage), stmt.position)
addThing(RelativeVariable(prefix + name + ".lo", addr, b, zeropage = v.zeropage), stmt.position)
def lookup[T <: Thing : Manifest](name: String): Option[T] = {
if (things.contains(name)) {
} else {
def lookupFunction(name: String, actualParams: List[(Type, Expression)]): Option[MangledFunction] = {
if (things.contains(name)) {
val function = get[MangledFunction](name)
if (function.params.length != actualParams.length) {
ErrorReporting.error(s"Invalid number of parameters for function `$name`", actualParams.headOption.flatMap(_._2.position))
function.params match {
case NormalParamSignature(params) => { case ((required, (actual, expr)), m) =>
if (!actual.isAssignableTo(required)) {
ErrorReporting.error(s"Invalid value for parameter `${}` of function `$name`", expr.position)
case AssemblyParamSignature(params) => { case ((required, (actual, expr)), ix) =>
if (!actual.isAssignableTo(required)) {
ErrorReporting.error(s"Invalid value for parameter ${ix + 1} of function `$name`", expr.position)
} else {
parent.flatMap(_.lookupFunction(name, actualParams))
def collectDeclarations(program: Program, options: CompilationOptions): Unit = {
program.declarations.foreach {
case f: FunctionDeclarationStatement => registerFunction(f, options)
case v: VariableDeclarationStatement => registerVariable(v, options)
case a: ArrayDeclarationStatement => registerArray(a)
case i: ImportStatement => ()
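
The constant evaluator above treats an expression as constant only if every sub-expression is constant. A minimal standalone sketch of that all-or-nothing pattern follows. `SequenceSketch` and `foldSum` are hypothetical names used for illustration, and plain `Long` values stand in for the compiler's `Constant` type.

```scala
object SequenceSketch {
  // Same shape as the `sequence` helper in the listing:
  // the result is defined only if every element is defined.
  def sequence[A](a: List[Option[A]]): Option[List[A]] = a match {
    case Nil => Some(Nil)
    case None :: _ => None
    case Some(r) :: t => sequence(t).map(r :: _)
  }

  // Hypothetical stand-in for folding a sum expression:
  // each term is (isSubtracted, value-if-constant).
  def foldSum(terms: List[(Boolean, Option[Long])]): Option[Long] =
    terms.foldLeft(Option(0L)) { case (acc, (minus, term)) =>
      for (c <- acc; t <- term) yield if (minus) c - t else c + t
    }

  def main(args: Array[String]): Unit = {
    println(sequence(List(Some(1), Some(2), Some(3))))            // Some(List(1, 2, 3))
    println(sequence(List(Some(1), None)))                        // None
    println(foldSum(List((false, Some(10L)), (true, Some(3L)))))  // Some(7)
    println(foldSum(List((false, Some(10L)), (false, None))))     // None
  }
}
```

A single non-constant term makes the whole result `None`, which is why the listing can fall back to reporting a "not a constant" error with `getOrElse`.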

package millfork.env
import millfork.assembly.Opcode
import millfork.error.ErrorReporting
import millfork.node._
sealed trait Thing {
def name: String
sealed trait Type extends Thing {
def size: Int
def isSigned: Boolean
def isSubtypeOf(other: Type): Boolean = this == other
def isCompatible(other: Type): Boolean = this == other
override def toString(): String = name
def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType)
case object VoidType extends Type {
def size = 0
def isSigned = false
override def name = "void"
sealed trait PlainType extends Type {
override def isCompatible(other: Type): Boolean = this == other || this.isSubtypeOf(other) || other.isSubtypeOf(this)
override def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType) || (targetType match {
case BasicPlainType(_, size) => size > this.size // TODO
case _ => false
case class BasicPlainType(name: String, size: Int) extends PlainType {
def isSigned = false
override def isSubtypeOf(other: Type): Boolean = this == other
case class DerivedPlainType(name: String, parent: PlainType, isSigned: Boolean) extends PlainType {
def size: Int = parent.size
override def isSubtypeOf(other: Type): Boolean = parent == other || parent.isSubtypeOf(other)
sealed trait BooleanType extends Type {
def size = 0
def isSigned = false
case class ConstantBooleanType(name: String, value: Boolean) extends BooleanType
case class FlagBooleanType(name: String, jumpIfTrue: Opcode.Value, jumpIfFalse: Opcode.Value) extends BooleanType
case object BuiltInBooleanType extends BooleanType {
override def name = "bool$"
sealed trait TypedThing extends Thing {
def typ: Type
sealed trait ThingInMemory extends Thing {
def toAddress: Constant
sealed trait PrellocableThing extends ThingInMemory {
def shouldGenerate: Boolean
def address: Option[Constant]
def toAddress: Constant = address.getOrElse(MemoryAddressConstant(this))
case class Label(name: String) extends ThingInMemory {
override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this)
sealed trait Variable extends TypedThing
case class BlackHole(typ: Type) extends Variable {
override def name = "<black hole>"
sealed trait VariableInMemory extends Variable with ThingInMemory {
def zeropage: Boolean
case class RegisterVariable(register: Register.Value, typ: Type) extends Variable {
def name: String = register.toString
case class Placeholder(name: String, typ: Type) extends Variable
sealed trait UninitializedMemory extends ThingInMemory {
def sizeInBytes: Int
def alloc: VariableAllocationMethod.Value
object VariableAllocationMethod extends Enumeration {
val Auto, Static, Zeropage, None = Value
case class StackVariable(name: String, typ: Type, baseOffset: Int) extends Variable {
def sizeInBytes: Int = typ.size
case class MemoryVariable(name: String, typ: Type, alloc: VariableAllocationMethod.Value) extends VariableInMemory with UninitializedMemory {
override def sizeInBytes: Int = typ.size
override def zeropage: Boolean = alloc == VariableAllocationMethod.Zeropage
override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this)
trait MlArray extends ThingInMemory
case class UninitializedArray(name: String, sizeInBytes: Int) extends MlArray with UninitializedMemory {
override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this)
override def alloc = VariableAllocationMethod.Static
case class RelativeArray(name: String, address: Constant, sizeInBytes: Int) extends MlArray {
override def toAddress: Constant = address
case class InitializedArray(name: String, address: Option[Constant], contents: List[Constant]) extends MlArray with PrellocableThing {
override def shouldGenerate = true
case class RelativeVariable(name: String, address: Constant, typ: Type, zeropage: Boolean) extends VariableInMemory {
override def toAddress: Constant = address
sealed trait MangledFunction extends Thing {
def name: String
def returnType: Type
def params: ParamSignature
def interrupt: Boolean
case class EmptyFunction(name: String,
returnType: Type,
paramType: Type) extends MangledFunction {
override def params = EmptyFunctionParamSignature(paramType)
override def interrupt = false
case class InlinedFunction(name: String,
returnType: Type,
params: ParamSignature,
environment: Environment,
code: List[ExecutableStatement]) extends MangledFunction {
override def interrupt = false
sealed trait FunctionInMemory extends MangledFunction with ThingInMemory {
def environment: Environment
case class ExternFunction(name: String,
returnType: Type,
params: ParamSignature,
address: Constant,
environment: Environment) extends FunctionInMemory {
override def toAddress: Constant = address
override def interrupt = false
case class NormalFunction(name: String,
returnType: Type,
params: ParamSignature,
environment: Environment,
stackVariablesSize: Int,
address: Option[Constant],
code: List[ExecutableStatement],
interrupt: Boolean,
reentrant: Boolean,
position: Option[Position]) extends FunctionInMemory with PrellocableThing {
override def shouldGenerate = true
case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing
trait ParamSignature {
def types: List[Type]
def length: Int
case class NormalParamSignature(params: List[MemoryVariable]) extends ParamSignature {
override def length: Int = params.length
override def types: List[Type] =
sealed trait ParamPassingConvention {
def inInlinedOnly: Boolean
def inNonInlinedOnly: Boolean
case class ByRegister(register: Register.Value) extends ParamPassingConvention {
override def inInlinedOnly = false
override def inNonInlinedOnly = false
case class ByVariable(name: String) extends ParamPassingConvention {
override def inInlinedOnly = false
override def inNonInlinedOnly = true
case class ByConstant(name: String) extends ParamPassingConvention {
override def inInlinedOnly = true
override def inNonInlinedOnly = false
case class ByReference(name: String) extends ParamPassingConvention {
override def inInlinedOnly = true
override def inNonInlinedOnly = false
object AssemblyParameterPassingBehaviour extends Enumeration {
val Copy, ByReference, ByConstant = Value
case class AssemblyParam(typ: Type, variable: TypedThing, behaviour: AssemblyParameterPassingBehaviour.Value)
case class AssemblyParamSignature(params: List[AssemblyParam]) extends ParamSignature {
override def length: Int = params.length
override def types: List[Type] =
case class EmptyFunctionParamSignature(paramType: Type) extends ParamSignature {
override def length: Int = 1
override def types: List[Type] = List(paramType)
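
The assignability rules for plain types defined above can be tried out in isolation. The sketch below is a simplified model, not the compiler's own classes: `Basic` and `Derived` are hypothetical stand-ins for `BasicPlainType` and `DerivedPlainType`, signedness is omitted, and a two-byte `word` type is assumed.

```scala
object TypeSketch {
  sealed trait PlainType {
    def size: Int
    def isSubtypeOf(other: PlainType): Boolean
    // Compatible: same type, or related by derivation in either direction.
    def isCompatible(other: PlainType): Boolean =
      this == other || this.isSubtypeOf(other) || other.isSubtypeOf(this)
    // Assignable: compatible, or the target is a strictly wider basic type.
    def isAssignableTo(target: PlainType): Boolean = isCompatible(target) || (target match {
      case Basic(_, targetSize) => targetSize > this.size
      case _ => false
    })
  }

  case class Basic(name: String, size: Int) extends PlainType {
    def isSubtypeOf(other: PlainType): Boolean = this == other
  }

  case class Derived(name: String, parent: PlainType) extends PlainType {
    def size: Int = parent.size
    def isSubtypeOf(other: PlainType): Boolean = parent == other || parent.isSubtypeOf(other)
  }

  def main(args: Array[String]): Unit = {
    val byte = Basic("byte", 1)
    val word = Basic("word", 2) // assumed size, for illustration only
    val sbyte = Derived("sbyte", byte)
    println(sbyte.isAssignableTo(byte)) // true: sbyte derives from byte
    println(byte.isAssignableTo(word))  // true: widening to a larger basic type
    println(word.isAssignableTo(byte))  // false: would narrow
    println(byte.isAssignableTo(sbyte)) // true: compatibility is symmetric
  }
}
```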

package millfork.error
import millfork.{CompilationFlag, CompilationOptions}
import millfork.node.Position
object ErrorReporting {
var verbosity = 0
var hasErrors = false
def f(position: Option[Position]): String = position.fold("")(p => s"(${p.line}:${p.column}) ")
def info(msg: String, position: Option[Position] = None): Unit = {
if (verbosity < 0) return
println("INFO: " + f(position) + msg)
def debug(msg: String, position: Option[Position] = None): Unit = {
if (verbosity < 1) return
println("DEBUG: " + f(position) + msg)
def trace(msg: String, position: Option[Position] = None): Unit = {
if (verbosity < 2) return
println("TRACE: " + f(position) + msg)
private def flushOutput(): Unit = {
def warn(msg: String, options: CompilationOptions, position: Option[Position] = None): Unit = {
if (verbosity < 0) return
println("WARN: " + f(position) + msg)
if (options.flag(CompilationFlag.FatalWarnings)) {
hasErrors = true
def error(msg: String, position: Option[Position] = None): Unit = {
hasErrors = true
println("ERROR: " + f(position) + msg)
def fatal(msg: String, position: Option[Position] = None): Nothing = {
hasErrors = true
println("FATAL: " + f(position) + msg)
throw new RuntimeException(msg)
def fatalQuit(msg: String, position: Option[Position] = None): Nothing = {
hasErrors = true
println("FATAL: " + f(position) + msg)
throw new RuntimeException(msg)
def assertNoErrors(msg: String): Unit = {
if (hasErrors) {
fatal("Build halted due to previous errors")

@ -0,0 +1,151 @@
package millfork.node
import millfork.error.ErrorReporting
import scala.collection.mutable
* @author Karol Stasiak
sealed trait VariableVertex {
def function: String
case class ParamVertex(function: String) extends VariableVertex
case class LocalVertex(function: String) extends VariableVertex
case object GlobalVertex extends VariableVertex {
override def function = ""
trait CallGraph {
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean
object RestrictiveCallGraph extends CallGraph {
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false
class StandardCallGraph(program: Program) extends CallGraph {
private val entryPoints = mutable.Set[String]()
// (F,G) means function F calls function G
private val callEdges = mutable.Set[(String, String)]()
// (F,G) means function G is called when building parameters for function F
private val paramEdges = mutable.Set[(String, String)]()
private val multiaccessibleFunctions = mutable.Set[String]()
private val everCalledFunctions = mutable.Set[String]()
private val allFunctions = mutable.Set[String]()
entryPoints += "main"
program.declarations.foreach(s => add(None, Nil, s))
def add(currentFunction: Option[String], callingFunctions: List[String], node: Node): Unit = {
node match {
case f: FunctionDeclarationStatement =>
allFunctions +=
if (f.address.isDefined || f.interrupt) entryPoints +=
f.statements.getOrElse(Nil).foreach(s => this.add(Some(, Nil, s))
case s: Statement =>
s.getAllExpressions.foreach(e => add(currentFunction, callingFunctions, e))
case g: FunctionCallExpression =>
everCalledFunctions += g.functionName
currentFunction.foreach(f => callEdges += f -> g.functionName)
callingFunctions.foreach(f => paramEdges += f -> g.functionName)
g.expressions.foreach(expr => add(currentFunction, g.functionName :: callingFunctions, expr))
case x: VariableExpression =>
val varName =".hi").stripSuffix(".lo").stripSuffix(".addr")
everCalledFunctions += varName
case _ => ()
def fillOut(): Unit = {
var changed = true
while (changed) {
changed = false
val toAdd = for {
(a, b) <- callEdges
(c, d) <- callEdges
if b == c
if !callEdges.contains(a -> d)
} yield (a, d)
if (toAdd.nonEmpty) {
callEdges ++= toAdd
changed = true
changed = true
while (changed) {
changed = false
val toAdd = for {
(a, b) <- paramEdges
(c, d) <- callEdges
if b == c
if !paramEdges.contains(a -> d)
} yield (a, d)
if (toAdd.nonEmpty) {
paramEdges ++= toAdd
changed = true
multiaccessibleFunctions ++= entryPoints
everCalledFunctions ++= entryPoints
callEdges.filter(e => entryPoints.contains(e._1)).foreach(e => everCalledFunctions += e._2)
multiaccessibleFunctions ++= callEdges.filter(e => entryPoints.contains(e._1)).map(_._2).groupBy(identity).filter(p => p._2.size > 1).keys
ErrorReporting.trace("Call edges:")
callEdges.toList.sorted.foreach(s => ErrorReporting.trace(s.toString))
ErrorReporting.trace("Param edges:")
paramEdges.toList.sorted.foreach(s => ErrorReporting.trace(s.toString))
ErrorReporting.trace("Entry points:")
ErrorReporting.trace("Multiaccessible functions:")
ErrorReporting.trace("Ever called functions:")
def isEverCalled(function: String): Boolean = {
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = {
if (a.function == b.function) {
return false
if (a == GlobalVertex || b == GlobalVertex) {
return false
if (multiaccessibleFunctions(a.function) || multiaccessibleFunctions(b.function)) {
return false
if (callEdges(a.function -> b.function) || callEdges(b.function -> a.function)) {
return false
a match {
case ParamVertex(af) =>
if (paramEdges(af -> b.function)) return false
case _ =>
b match {
case ParamVertex(bf) =>
if (paramEdges(bf -> a.function)) return false
case _ =>
ErrorReporting.trace(s"$a and $b can overlap")
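
The `fillOut` loop above extends the call edges to their transitive closure, and `canOverlap` then refuses to share storage between functions that can be active at the same time. A reduced sketch of that idea follows. It assumes only call edges; the entry-point, multiaccessible-function and parameter-edge checks from the listing are left out, and `CallGraphSketch` is a hypothetical name.

```scala
object CallGraphSketch {
  // Join edges (a -> b) and (b -> d) into (a -> d) until nothing new appears.
  def transitiveClosure(edges: Set[(String, String)]): Set[(String, String)] = {
    val extra = for {
      (a, b) <- edges
      (c, d) <- edges
      if b == c
    } yield (a, d)
    val next = edges ++ extra
    if (next == edges) edges else transitiveClosure(next)
  }

  def main(args: Array[String]): Unit = {
    val calls = transitiveClosure(Set("main" -> "draw", "draw" -> "plot", "main" -> "beep"))

    // Variables of f and g may share memory only if neither function
    // can be on the call stack while the other one runs.
    def canOverlap(f: String, g: String): Boolean =
      f != g && !calls(f -> g) && !calls(g -> f)

    println(calls("main" -> "plot"))     // true: reachable through draw
    println(canOverlap("plot", "beep"))  // true: never active together
    println(canOverlap("draw", "plot"))  // false: draw calls plot
  }
}
```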

package millfork.node
import millfork.assembly.{AddrMode, Opcode}
import millfork.env.{Label, ParamPassingConvention}
case class Position(filename: String, line: Int, column: Int, cursor: Int)
sealed trait Node {
var position: Option[Position] = None
object Node {
implicit class NodeOps[N<:Node](val node: N) extends AnyVal {
def pos(position: Position): N = {
node.position = Some(position)
sealed trait Expression extends Node {
def replaceVariable(variable: String, actualParam: Expression): Expression
case class LiteralExpression(value: Long, requiredSize: Int) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
case class BooleanLiteralExpression(value: Boolean) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
sealed trait LhsExpression extends Expression
case object BlackHoleExpression extends LhsExpression {
override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this
case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsExpression {
def replaceVariable(variable: String, actualParam: Expression): Expression =
hi.replaceVariable(variable, actualParam),
lo.replaceVariable(variable, actualParam))
case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression =
SumExpression( { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal)
case class FunctionCallExpression(functionName: String, expressions: List[Expression]) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression =
FunctionCallExpression(functionName, {
_.replaceVariable(variable, actualParam)
case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression =
HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte)
object Register extends Enumeration {
val A, X, Y, AX, AY, YA, XA, XY, YX = Value
//case class Indexing(child: Expression, register: Register.Value) extends Expression
case class VariableExpression(name: String) extends LhsExpression {
override def replaceVariable(variable: String, actualParam: Expression): Expression =
if (name == variable) actualParam else this
case class IndexedExpression(name: String, index: Expression) extends LhsExpression {
override def replaceVariable(variable: String, actualParam: Expression): Expression =
if (name == variable) {
actualParam match {
case VariableExpression(actualVariable) => IndexedExpression(actualVariable, index.replaceVariable(variable, actualParam))
case _ => ??? // TODO
} else IndexedExpression(name, index.replaceVariable(variable, actualParam))
sealed trait Statement extends Node {
def getAllExpressions: List[Expression]
sealed trait DeclarationStatement extends Statement
case class TypeDefinitionStatement(name: String, parent: String) extends DeclarationStatement {
override def getAllExpressions: List[Expression] = Nil
case class VariableDeclarationStatement(name: String,
typ: String,
global: Boolean,
stack: Boolean,
constant: Boolean,
volatile: Boolean,
initialValue: Option[Expression],
address: Option[Expression]) extends DeclarationStatement {
override def getAllExpressions: List[Expression] = List(initialValue, address).flatten
case class ArrayDeclarationStatement(name: String,
length: Option[Expression],
address: Option[Expression],
elements: Option[List[Expression]]) extends DeclarationStatement {
override def getAllExpressions: List[Expression] = List(length, address).flatten ++ elements.getOrElse(Nil)
case class ParameterDeclaration(typ: String,
assemblyParamPassingConvention: ParamPassingConvention) extends Node
case class ImportStatement(filename: String) extends DeclarationStatement {
override def getAllExpressions: List[Expression] = Nil
case class FunctionDeclarationStatement(name: String,
resultType: String,
params: List[ParameterDeclaration],
address: Option[Expression],
statements: Option[List[Statement]],
inlined: Boolean,
assembly: Boolean,
interrupt: Boolean,
reentrant: Boolean) extends DeclarationStatement {
override def getAllExpressions: List[Expression] = address.toList ++ statements.getOrElse(Nil).flatMap(_.getAllExpressions)
sealed trait ExecutableStatement extends Statement
case class ExpressionStatement(expression: Expression) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = List(expression)
case class ReturnStatement(value: Option[Expression]) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = value.toList
case class Assignment(destination: LhsExpression, source: Expression) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = List(destination, source)
case class LabelStatement(label: Label) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = Nil
case class AssemblyStatement(opcode: Opcode.Value, addrMode: AddrMode.Value, expression: Expression, elidable: Boolean) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = List(expression)
case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = condition :: (thenBranch ++ elseBranch).flatMap(_.getAllExpressions)
case class WhileStatement(condition: Expression, body: List[ExecutableStatement]) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions)
object ForDirection extends Enumeration {
val To, Until, DownTo, ParallelTo, ParallelUntil = Value
case class ForStatement(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: List[ExecutableStatement]) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = start :: end :: body.flatMap(_.getAllExpressions)
case class DoWhileStatement(body: List[ExecutableStatement], condition: Expression) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions)
case class BlockStatement(body: List[ExecutableStatement]) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = body.flatMap(_.getAllExpressions)
object AssemblyStatement {
def implied(opcode: Opcode.Value, elidable: Boolean) = AssemblyStatement(opcode, AddrMode.Implied, LiteralExpression(0, 1), elidable)
def nonexistent(opcode: Opcode.Value) = AssemblyStatement(opcode, AddrMode.DoesNotExist, LiteralExpression(0, 1), elidable = true)
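
The statement classes above all expose `getAllExpressions`, which lets later passes walk a whole function body without knowing each statement type. The following is a minimal sketch of that pattern; the names `StatementSketch`, `Expr`, `Lit`, `Var`, `Stmt`, `Return`, `Assign` and `If` are simplified, hypothetical stand-ins and not the real Millfork node classes.

```scala
// Hypothetical, reduced model of the statement tree (not the real millfork.node classes).
object StatementSketch {
  sealed trait Expr
  final case class Lit(value: Int) extends Expr
  final case class Var(name: String) extends Expr

  sealed trait Stmt {
    // Every statement reports the expressions it contains, recursively.
    def getAllExpressions: List[Expr]
  }
  final case class Return(value: Option[Expr]) extends Stmt {
    def getAllExpressions: List[Expr] = value.toList
  }
  final case class Assign(destination: Var, source: Expr) extends Stmt {
    def getAllExpressions: List[Expr] = List(destination, source)
  }
  final case class If(condition: Expr, thenBranch: List[Stmt], elseBranch: List[Stmt]) extends Stmt {
    // The condition comes first, followed by everything found in both branches.
    def getAllExpressions: List[Expr] =
      condition :: (thenBranch ++ elseBranch).flatMap(_.getAllExpressions)
  }

  // if (flag) { x = 1 } else { return }
  val example: Stmt = If(Var("flag"), List(Assign(Var("x"), Lit(1))), List(Return(None)))
}
```

Compound statements only need to concatenate their own expressions with those of their children, so adding a new statement type does not require touching the passes that consume the list.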

package millfork.node
import millfork.node.opt.NodeOptimization
/** @author Karol Stasiak */
case class Program(declarations: List[DeclarationStatement]) {
def applyNodeOptimization(o: NodeOptimization) = Program(o.optimize(declarations).asInstanceOf[List[DeclarationStatement]])
def +(p:Program): Program = Program(this.declarations ++ p.declarations)

package millfork.node.opt
import millfork.node.{ExecutableStatement, Expression, Node, Statement}
/** @author Karol Stasiak */
trait NodeOptimization {
def optimize(nodes: List[Node]): List[Node]
def optimizeExecutableStatements(nodes: List[ExecutableStatement]): List[ExecutableStatement] =
def optimizeStatements(nodes: List[Statement]): List[Statement] =

package millfork.node.opt
import millfork.node._
/** @author Karol Stasiak */
object UnreachableCode extends NodeOptimization {
override def optimize(nodes: List[Node]): List[Node] = nodes match {
case (x:FunctionDeclarationStatement)::xs =>
x.copy(statements = :: optimize(xs)
case (x:IfStatement)::xs =>
thenBranch = optimizeExecutableStatements(x.thenBranch),
elseBranch = optimizeExecutableStatements(x.elseBranch)) :: optimize(xs)
case (x:WhileStatement)::xs =>
x.copy(body = optimizeExecutableStatements(x.body)) :: optimize(xs)
case (x:DoWhileStatement)::xs =>
x.copy(body = optimizeExecutableStatements(x.body)) :: optimize(xs)
case (x:ReturnStatement) :: xs =>
x :: Nil
case x :: xs =>
x :: optimize(xs)
case Nil =>
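
The idea behind `UnreachableCode` is that nothing after a `return` in the same statement list can execute, so the list can be cut there while nested bodies are processed recursively. Below is a small sketch of that idea on a hypothetical, reduced statement type (`Call`, `Return` and `While` are illustrative names only).

```scala
// Hypothetical, reduced statement model used only to illustrate the pass.
object UnreachableSketch {
  sealed trait Stmt
  final case class Call(name: String) extends Stmt
  case object Return extends Stmt
  final case class While(body: List[Stmt]) extends Stmt

  // Keep statements up to and including the first return; recurse into loop bodies.
  def optimize(stmts: List[Stmt]): List[Stmt] = stmts match {
    case Return :: _         => Return :: Nil
    case While(body) :: rest => While(optimize(body)) :: optimize(rest)
    case other :: rest       => other :: optimize(rest)
    case Nil                 => Nil
  }
}
```

Note that a `return` inside a loop body truncates only that body; statements after the loop are kept, because the loop may exit without reaching the `return`.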

package millfork.node.opt
import millfork.env._
import millfork.error.ErrorReporting
import millfork.node._
/** @author Karol Stasiak */
object UnusedFunctions extends NodeOptimization {
override def optimize(nodes: List[Node]): List[Node] = {
val allNormalFunctions = nodes.flatMap {
case v: FunctionDeclarationStatement => if (v.address.isDefined || v.interrupt || == "main") Nil else List(
case _ => Nil
val allCalledFunctions = getAllCalledFunctions(nodes).toSet
val unusedFunctions = allNormalFunctions -- allCalledFunctions
if (unusedFunctions.nonEmpty) {
ErrorReporting.debug("Removing unused functions: " + unusedFunctions.mkString(", "))
removeFunctionsFromProgram(nodes, unusedFunctions)
private def removeFunctionsFromProgram(nodes: List[Node], unusedVariables: Set[String]): List[Node] = {
nodes match {
case (x: FunctionDeclarationStatement) :: xs if unusedVariables( =>
removeFunctionsFromProgram(xs, unusedVariables)
case x :: xs =>
x :: removeFunctionsFromProgram(xs, unusedVariables)
case Nil =>
def getAllCalledFunctions(c: Constant): List[String] = c match {
case HalfWordConstant(cc, _) => getAllCalledFunctions(cc)
case SubbyteConstant(cc, _) => getAllCalledFunctions(cc)
case CompoundConstant(_, l, r) => getAllCalledFunctions(l) ++ getAllCalledFunctions(r)
case MemoryAddressConstant(th) => List(,".addr"),".hi"),".lo"),".addr.lo"),".addr.hi"))
case _ => Nil
def getAllCalledFunctions(expressions: List[Node]): List[String] = expressions.flatMap {
case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList)
case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil))
case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil))
case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil)
case s: Statement => getAllCalledFunctions(s.getAllExpressions)
case s: VariableExpression => List(,".addr"),".hi"),".lo"),".addr.lo"),".addr.hi"))
case s: LiteralExpression => Nil
case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil)
case SumExpression(xs, _) => getAllCalledFunctions(
case FunctionCallExpression(name, xs) => name :: getAllCalledFunctions(xs)
case IndexedExpression(arr, index) => arr :: getAllCalledFunctions(List(index))
case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l))
case _ => Nil
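
`UnusedFunctions` boils down to a set difference: the names of removable functions minus every name that is referenced anywhere. The sketch below shows that computation on a flattened, hypothetical program model (`Fn` and `removeUnused` are illustrative names; the real pass walks the node tree instead).

```scala
// Hypothetical flattened view of a program: each function lists the names it references.
object UnusedFunctionsSketch {
  final case class Fn(name: String, calls: List[String], interrupt: Boolean = false)

  def removeUnused(program: List[Fn]): List[Fn] = {
    // Entry points (main and interrupt handlers) are never candidates for removal.
    val candidates = program.filterNot(f => f.interrupt || f.name == "main").map(_.name).toSet
    // Calls are collected from every function, including ones about to be removed.
    val called = program.flatMap(_.calls).toSet
    val unused = candidates -- called
    program.filterNot(f => unused(f.name))
  }
}
```

Because calls are gathered in a single pass, a function referenced only by another unused function survives one run of this sketch; running it repeatedly until nothing changes would remove such chains.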

package millfork.node.opt
import millfork.env._
import millfork.error.ErrorReporting
import millfork.node._
/** @author Karol Stasiak */
object UnusedGlobalVariables extends NodeOptimization {
override def optimize(nodes: List[Node]): List[Node] = {
// TODO: volatile
val allNonvolatileGlobalVariables = nodes.flatMap {
case v: VariableDeclarationStatement => if (v.address.isDefined) Nil else List(
case v: ArrayDeclarationStatement => if (v.address.isDefined) Nil else List(
case _ => Nil
val allReadVariables = getAllReadVariables(nodes).toSet
val unusedVariables = allNonvolatileGlobalVariables -- allReadVariables
if (unusedVariables.nonEmpty) {
ErrorReporting.debug("Removing unused global variables: " + unusedVariables.mkString(", "))
removeVariablesFromProgram(nodes, unusedVariables.flatMap(v => Set(v, v + ".hi", v + ".lo")))
private def removeVariablesFromProgram(nodes: List[Node], unusedVariables: Set[String]): List[Node] = {
nodes match {
case (x: ArrayDeclarationStatement) :: xs if unusedVariables( => removeVariablesFromProgram(xs, unusedVariables)
case (x: VariableDeclarationStatement) :: xs if unusedVariables( => removeVariablesFromProgram(xs, unusedVariables)
case (x: FunctionDeclarationStatement) :: xs =>
x.copy(statements = => removeVariablesFromStatement(s, unusedVariables))) :: removeVariablesFromProgram(xs, unusedVariables)
case x :: xs =>
x :: removeVariablesFromProgram(xs, unusedVariables)
case Nil =>
def getAllReadVariables(c: Constant): List[String] = c match {
case HalfWordConstant(cc, _) => getAllReadVariables(cc)
case SubbyteConstant(cc, _) => getAllReadVariables(cc)
case CompoundConstant(_, l, r) => getAllReadVariables(l) ++ getAllReadVariables(r)
case MemoryAddressConstant(th) => List( != '.'))
case _ => Nil
def getAllReadVariables(expressions: List[Node]): List[String] = expressions.flatMap {
case s: VariableDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.initialValue.toList)
case s: ArrayDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.elements.getOrElse(Nil))
case s: FunctionDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.statements.getOrElse(Nil))
case Assignment(VariableExpression(_), expr) => getAllReadVariables(expr :: Nil)
case ExpressionStatement(FunctionCallExpression(op, VariableExpression(_) :: params)) if op.endsWith("=") => getAllReadVariables(params)
case s: Statement => getAllReadVariables(s.getAllExpressions)
case s: VariableExpression => List( != '.'))
case s: LiteralExpression => Nil
case HalfWordExpression(param, _) => getAllReadVariables(param :: Nil)
case SumExpression(xs, _) => getAllReadVariables(
case FunctionCallExpression(name, xs) => name :: getAllReadVariables(xs)
case IndexedExpression(arr, index) => arr :: getAllReadVariables(List(index))
case SeparateBytesExpression(h, l) => getAllReadVariables(List(h, l))
case _ => Nil
def removeVariablesFromStatement(statements: List[Statement], globalsToRemove: Set[String]): List[Statement] = statements.flatMap {
case s: VariableDeclarationStatement =>
if (globalsToRemove( None else Some(s)
case s@ExpressionStatement(FunctionCallExpression(op, VariableExpression(n) :: params)) if op.endsWith("=") =>
if (globalsToRemove(n)) else Some(s)
case s@Assignment(VariableExpression(n), expr) =>
if (globalsToRemove(n)) Some(ExpressionStatement(expr)) else Some(s)
case s@Assignment(SeparateBytesExpression(VariableExpression(h), VariableExpression(l)), expr) =>
if (globalsToRemove(h)) {
if (globalsToRemove(l))
Some(Assignment(SeparateBytesExpression(BlackHoleExpression, VariableExpression(l)), expr))
} else {
if (globalsToRemove(l))
Some(Assignment(SeparateBytesExpression(VariableExpression(h), BlackHoleExpression), expr))
case s@Assignment(SeparateBytesExpression(h, VariableExpression(l)), expr) =>
if (globalsToRemove(l)) Some(Assignment(SeparateBytesExpression(h, BlackHoleExpression), expr))
else Some(s)
case s@Assignment(SeparateBytesExpression(VariableExpression(h), l), expr) =>
if (globalsToRemove(h)) Some(Assignment(SeparateBytesExpression(BlackHoleExpression, l), expr))
else Some(s)
case s: IfStatement =>
thenBranch = removeVariablesFromStatement(s.thenBranch, globalsToRemove).asInstanceOf[List[ExecutableStatement]],
elseBranch = removeVariablesFromStatement(s.elseBranch, globalsToRemove).asInstanceOf[List[ExecutableStatement]]))
case s: WhileStatement =>
body = removeVariablesFromStatement(s.body, globalsToRemove).asInstanceOf[List[ExecutableStatement]]))
case s: DoWhileStatement =>
body = removeVariablesFromStatement(s.body, globalsToRemove).asInstanceOf[List[ExecutableStatement]]))
case s => Some(s)

package millfork.node.opt
import millfork.assembly.AssemblyLine
import millfork.env._
import millfork.error.ErrorReporting
import millfork.node._
/** @author Karol Stasiak */
object UnusedLocalVariables extends NodeOptimization {
override def optimize(nodes: List[Node]): List[Node] = nodes match {
case (x: FunctionDeclarationStatement) :: xs =>
x.copy(statements = :: optimize(xs)
case x :: xs =>
x :: optimize(xs)
case Nil =>
def getAllLocalVariables(statements: List[Statement]): List[String] = statements.flatMap {
case v: VariableDeclarationStatement => List(
case x: IfStatement => getAllLocalVariables(x.thenBranch) ++ getAllLocalVariables(x.elseBranch)
case x: WhileStatement => getAllLocalVariables(x.body)
case x: DoWhileStatement => getAllLocalVariables(x.body)
case _ => Nil
def getAllReadVariables(c: Constant): List[String] = c match {
case HalfWordConstant(cc, _) => getAllReadVariables(cc)
case SubbyteConstant(cc, _) => getAllReadVariables(cc)
case CompoundConstant(_, l, r) => getAllReadVariables(l) ++ getAllReadVariables(r)
case MemoryAddressConstant(th) => List(,".addr"),".hi"),".lo"),".addr.lo"),".addr.hi"))
case _ => Nil
def getAllReadVariables(expressions: List[Node]): List[String] = expressions.flatMap {
case s: VariableExpression => List(,".addr"),".hi"),".lo"),".addr.lo"),".addr.hi"))
case s: LiteralExpression => Nil
case HalfWordExpression(param, _) => getAllReadVariables(param :: Nil)
case SumExpression(xs, _) => getAllReadVariables(
case FunctionCallExpression(_, xs) => getAllReadVariables(xs)
case IndexedExpression(arr, index) => arr :: getAllReadVariables(List(index))
case SeparateBytesExpression(h, l) => getAllReadVariables(List(h, l))
case _ => Nil
def optimizeVariables(statements: List[Statement]): List[Statement] = {
val allLocals = getAllLocalVariables(statements)
val allRead = getAllReadVariables(statements.flatMap {
case Assignment(VariableExpression(_), expression) => List(expression)
case ExpressionStatement(FunctionCallExpression(op, VariableExpression(_) :: params)) if op.endsWith("=") => params
case x => x.getAllExpressions
val localsToRemove = allLocals.filterNot(allRead).toSet
if (localsToRemove.nonEmpty) {
ErrorReporting.debug("Removing unused local variables: " + localsToRemove.mkString(", "))
removeVariables(statements, localsToRemove)
def removeVariables(statements: List[Statement], localsToRemove: Set[String]): List[Statement] = statements.flatMap {
case s: VariableDeclarationStatement =>
if (localsToRemove( None else Some(s)
case s@ExpressionStatement(FunctionCallExpression(op, VariableExpression(n) :: params)) if op.endsWith("=") =>
if (localsToRemove(n)) else Some(s)
case s@Assignment(VariableExpression(n), expr) =>
if (localsToRemove(n)) Some(ExpressionStatement(expr)) else Some(s)
case s@Assignment(SeparateBytesExpression(VariableExpression(h), VariableExpression(l)), expr) =>
if (localsToRemove(h)) {
if (localsToRemove(l))
Some(Assignment(SeparateBytesExpression(BlackHoleExpression, VariableExpression(l)), expr))
} else {
if (localsToRemove(l))
Some(Assignment(SeparateBytesExpression(VariableExpression(h), BlackHoleExpression), expr))
case s@Assignment(SeparateBytesExpression(h, VariableExpression(l)), expr) =>
if (localsToRemove(l)) Some(Assignment(SeparateBytesExpression(h, BlackHoleExpression), expr))
else Some(s)
case s@Assignment(SeparateBytesExpression(VariableExpression(h), l), expr) =>
if (localsToRemove(h)) Some(Assignment(SeparateBytesExpression(BlackHoleExpression, l), expr))
else Some(s)
case s: IfStatement =>
thenBranch = removeVariables(s.thenBranch, localsToRemove).asInstanceOf[List[ExecutableStatement]],
elseBranch = removeVariables(s.elseBranch, localsToRemove).asInstanceOf[List[ExecutableStatement]]))
case s: WhileStatement =>
body = removeVariables(s.body, localsToRemove).asInstanceOf[List[ExecutableStatement]]))
case s: DoWhileStatement =>
body = removeVariables(s.body, localsToRemove).asInstanceOf[List[ExecutableStatement]]))
case s => Some(s)
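
When a variable is never read, its declaration can be dropped, but the right-hand side of an assignment to it may still have side effects. The removal code above therefore turns such assignments into plain expression statements. A reduced sketch of that rule follows; `Declare`, `Assign` and `Eval` are hypothetical stand-ins, and expressions are kept as plain strings for brevity.

```scala
// Hypothetical, reduced statement model; expressions are plain strings here.
object UnusedLocalsSketch {
  sealed trait Stmt
  final case class Declare(name: String) extends Stmt
  final case class Assign(target: String, source: String) extends Stmt
  final case class Eval(source: String) extends Stmt // expression evaluated only for its side effects

  def removeVariables(stmts: List[Stmt], toRemove: Set[String]): List[Stmt] = stmts.flatMap {
    case Declare(n) if toRemove(n)     => None            // the declaration disappears
    case Assign(n, src) if toRemove(n) => Some(Eval(src)) // keep the right-hand side
    case s                             => Some(s)
  }
}
```

For example, removing `t` from `t = f()` leaves the call `f()` in place, so any effect of `f` is preserved.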

package millfork.output
import millfork.assembly.opt.AssemblyOptimization
import millfork.assembly.{AddrMode, AssemblyLine, Opcode}
import millfork.compiler.{CompilationContext, MlCompiler}
import millfork.env._
import millfork.error.ErrorReporting
import millfork.node.CallGraph
import millfork.{CompilationFlag, CompilationOptions}
import scala.collection.mutable
/** @author Karol Stasiak */
case class AssemblerOutput(code: Array[Byte], asm: Array[String], labels: List[(String, Int)])
class Assembler(private val rootEnv: Environment) {
var env = rootEnv.allThings
var unoptimizedCodeSize = 0
var optimizedCodeSize = 0
var initializedArraysSize = 0
val mem = new CompiledMemory
val labelMap = mutable.Map[String, Int]()
val bytesToWriteLater = mutable.ListBuffer[(Int, Constant)]()
val wordsToWriteLater = mutable.ListBuffer[(Int, Constant)]()
def writeByte(bank: Int, addr: Int, value: Byte): Unit = {
if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
mem.banks(bank).occupied(addr) = true
mem.banks(bank).readable(addr) = true
mem.banks(bank).output(addr) = value.toByte
def writeByte(bank: Int, addr: Int, value: Constant): Unit = {
if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
mem.banks(bank).occupied(addr) = true
mem.banks(bank).readable(addr) = true
value match {
case NumericConstant(x, _) =>
if (x > 0xff) ErrorReporting.error("Byte overflow")
mem.banks(bank).output(addr) = x.toByte
case _ =>
bytesToWriteLater += addr -> value
def writeWord(bank: Int, addr: Int, value: Constant): Unit = {
if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
mem.banks(bank).occupied(addr) = true
mem.banks(bank).occupied(addr + 1) = true
mem.banks(bank).readable(addr) = true
mem.banks(bank).readable(addr + 1) = true
value match {
case NumericConstant(x, _) =>
if (x > 0xffff) ErrorReporting.error("Word overflow")
mem.banks(bank).output(addr) = x.toByte
mem.banks(bank).output(addr + 1) = (x >> 8).toByte
case _ =>
wordsToWriteLater += addr -> value
def deepConstResolve(c: Constant): Long = {
c match {
case NumericConstant(v, _) => v
case MemoryAddressConstant(th) =>
if (labelMap.contains( return labelMap(
if (labelMap.contains( + "`")) return labelMap(
if (labelMap.contains( + ".addr")) return labelMap(
val x1 = env.maybeGet[ConstantThing](
val x2 = env.maybeGet[ConstantThing]( + "`").map(_.value)
val x3 = env.maybeGet[NormalFunction](
val x4 = env.maybeGet[ConstantThing]( + ".addr").map(_.value)
val x5 = env.maybeGet[RelativeVariable](
val x6 = env.maybeGet[ConstantThing](".array") + ".addr").map(_.value)
val x = x1.orElse(x2).orElse(x3).orElse(x4).orElse(x5).orElse(x6)
x match {
case Some(cc) =>
case None =>
case HalfWordConstant(cc, true) => deepConstResolve(cc).>>>(8).&(0xff)
case HalfWordConstant(cc, false) => deepConstResolve(cc).&(0xff)
case SubbyteConstant(cc, i) => deepConstResolve(cc).>>>(i * 8).&(0xff)
case CompoundConstant(operator, lc, rc) =>
val l = deepConstResolve(lc)
val r = deepConstResolve(rc)
operator match {
case MathOperator.Plus => l + r
case MathOperator.Minus => l - r
case MathOperator.Times => l * r
case MathOperator.Shl => l << r
case MathOperator.Shr => l >>> r
case MathOperator.DecimalPlus => asDecimal(l, r, _ + _)
case MathOperator.DecimalMinus => asDecimal(l, r, _ - _)
case MathOperator.DecimalTimes => asDecimal(l, r, _ * _)
case MathOperator.DecimalShl => asDecimal(l, 1 << r, _ * _)
case MathOperator.DecimalShr => asDecimal(l, 1 << r, _ / _)
case MathOperator.And => l & r
case MathOperator.Exor => l ^ r
case MathOperator.Or => l | r
private def parseNormalToDecimalValue(a: Long): Long = {
if (a < 0) return -parseNormalToDecimalValue(-a)
var x = a
var result = 0L
var multiplier = 1L
while (x > 0) {
// take the lowest hex digit of the remaining value, not of the original argument
result += multiplier * (x % 16L)
x /= 16L
multiplier *= 10L
private def storeDecimalValueInNormalRespresentation(a: Long): Long = {
if (a < 0) return -storeDecimalValueInNormalRespresentation(-a)
var x = a
var result = 0L
var multiplier = 1L
while (x > 0) {
// take the lowest decimal digit of the remaining value, not of the original argument
result += multiplier * (x % 10L)
x /= 10L
multiplier *= 16L
private def asDecimal(a: Long, b: Long, f: (Long, Long) => Long): Long =
storeDecimalValueInNormalRespresentation(f(parseNormalToDecimalValue(a), parseNormalToDecimalValue(b)))
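
The decimal operators work on packed BCD values: each hex digit of the stored number is read as one decimal digit, the arithmetic is done on ordinary numbers, and the result is packed again. The following self-contained sketch shows the round trip; `BcdSketch`, `fromBcd` and `toBcd` are hypothetical names chosen for this illustration, and the digit loops use the running value `x`.

```scala
// Illustrative re-implementation with hypothetical names; not the assembler's own helpers.
object BcdSketch {
  // Read each hex digit of a packed-BCD value as a decimal digit: 0x25 becomes 25.
  def fromBcd(a: Long): Long =
    if (a < 0) -fromBcd(-a)
    else {
      var x = a
      var result = 0L
      var multiplier = 1L
      while (x > 0) {
        result += multiplier * (x % 16L)
        x /= 16L
        multiplier *= 10L
      }
      result
    }

  // Store each decimal digit of a number in its own hex digit: 25 becomes 0x25.
  def toBcd(a: Long): Long =
    if (a < 0) -toBcd(-a)
    else {
      var x = a
      var result = 0L
      var multiplier = 1L
      while (x > 0) {
        result += multiplier * (x % 10L)
        x /= 10L
        multiplier *= 16L
      }
      result
    }

  // Apply an ordinary arithmetic function to two packed-BCD operands.
  def asDecimal(a: Long, b: Long, f: (Long, Long) => Long): Long =
    toBcd(f(fromBcd(a), fromBcd(b)))
}
```

With these definitions, decimal addition of `0x19` and `0x03` yields `0x22`, whereas binary addition of the same bit patterns would give `0x1C`.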
def assemble(callGraph: CallGraph, optimizations: Seq[AssemblyOptimization], options: CompilationOptions): AssemblerOutput = {
val platform = options.platform
val assembly = mutable.ArrayBuffer[String]()
env.allPreallocatables.foreach {
case InitializedArray(name, Some(NumericConstant(address, _)), items) =>
var index = address.toInt
assembly.append("* = $" + index.toHexString)
for (item <- items) {
writeByte(0, index, item)
assembly.append(" !byte " + item)
mem.banks(0).writeable(index) = true
index += 1
initializedArraysSize += items.length
case InitializedArray(name, Some(_), items) => ???
case f: NormalFunction if f.address.isDefined =>
var index = f.address.get.asInstanceOf[NumericConstant].value.toInt
labelMap( = index
compileFunction(f, index, optimizations, assembly, options)
case _ =>
var index =
env.allPreallocatables.foreach {
case f: NormalFunction if f.address.isEmpty && == "main" =>
labelMap( = index
index = compileFunction(f, index, optimizations, assembly, options)
case _ =>
env.allPreallocatables.foreach {
case f: NormalFunction if f.address.isEmpty && != "main" =>
labelMap( = index
index = compileFunction(f, index, optimizations, assembly, options)
case _ =>
env.allPreallocatables.foreach {
case InitializedArray(name, None, items) =>
labelMap(name) = index
assembly.append("* = $" + index.toHexString)
for (item <- items) {
writeByte(0, index, item)
assembly.append(" !byte " + item)
mem.banks(0).writeable(index) = true
index += 1
initializedArraysSize += items.length
case _ =>
val allocator = platform.allocator
allocator.onEachByte = { addr =>
mem.banks(0).readable(addr) = true
mem.banks(0).writeable(addr) = true
env.allocateVariables(None, callGraph, allocator, options, labelMap.put)
env = rootEnv.allThings
for ((addr, b) <- bytesToWriteLater) {
val value = deepConstResolve(b)
mem.banks(0).output(addr) = value.toByte
for ((addr, b) <- wordsToWriteLater) {
val value = deepConstResolve(b)
mem.banks(0).output(addr) = value.toByte
mem.banks(0).output(addr + 1) = value.>>>(8).toByte
val start = mem.banks(0).occupied.indexOf(true)
val end = mem.banks(0).occupied.lastIndexOf(true)
val length = end - start + 1
mem.banks(0).start = start
mem.banks(0).end = end
labelMap.toList.sorted.foreach {case (l, v) =>
assembly += f"$l%-30s = $$$v%04X"
labelMap.toList.sortBy{case (a,b) => b->a}.foreach {case (l, v) =>
assembly += f" ; $$$v%04X = $l%s"
AssemblerOutput(platform.outputPackager.packageOutput(mem, 0), assembly.toArray, labelMap.toList)
private def compileFunction(f: NormalFunction, startFrom: Int, optimizations: Seq[AssemblyOptimization], assOut: mutable.ArrayBuffer[String], options: CompilationOptions): Int = {
ErrorReporting.debug("Compiling: " +, f.position)
var index = startFrom
assOut.append("* = $" + startFrom.toHexString)
val unoptimized = MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize
unoptimizedCodeSize +=
val code = optimizations.foldLeft(unoptimized) { (c, opt) =>
opt.optimize(f, c, options)
optimizedCodeSize +=
import millfork.assembly.AddrMode._
import millfork.assembly.Opcode._
for (instr <- code) {
if (instr.isPrintable) {
instr match {
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(labelName)), _) =>
labelMap(labelName) = index
case AssemblyLine(_, DoesNotExist, _, _) =>
case AssemblyLine(op, Implied, _, _) =>
writeByte(0, index, Assembler.opcodeFor(op, Implied, options))
index += 1
case AssemblyLine(op, Relative, param, _) =>
writeByte(0, index, Assembler.opcodeFor(op, Relative, options))
writeByte(0, index + 1, param - (index + 2))
index += 2
case AssemblyLine(op, am@(Immediate | ZeroPage | ZeroPageX | ZeroPageY | IndexedY | IndexedX | ZeroPageIndirect), param, _) =>
writeByte(0, index, Assembler.opcodeFor(op, am, options))
writeByte(0, index + 1, param)
index += 2
case AssemblyLine(op, am@(Absolute | AbsoluteY | AbsoluteX | Indirect | AbsoluteIndexedX), param, _) =>
writeByte(0, index, Assembler.opcodeFor(op, am, options))
writeWord(0, index + 1, param)
index += 3
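
The emit loop in `compileFunction` advances the output address by one, two or three bytes depending on the addressing mode, and encodes a branch operand relative to the address of the following instruction. The sketch below restates those rules on a reduced, hypothetical set of modes; `EncodingSketch`, `instructionSize`, `branchOffset` and `absoluteOperand` are illustrative names, and the range check on branches is an addition of this sketch.

```scala
// Hypothetical, reduced addressing-mode model for illustration only.
object EncodingSketch {
  sealed trait Mode
  case object Implied extends Mode
  case object Immediate extends Mode
  case object ZeroPage extends Mode
  case object Relative extends Mode
  case object Absolute extends Mode

  // Opcode byte plus zero, one or two operand bytes.
  def instructionSize(mode: Mode): Int = mode match {
    case Implied                         => 1
    case Immediate | ZeroPage | Relative => 2
    case Absolute                        => 3
  }

  // A branch operand is the distance from the end of the two-byte branch instruction.
  def branchOffset(instructionAddress: Int, target: Int): Byte = {
    val offset = target - (instructionAddress + 2)
    require(offset >= -128 && offset <= 127, s"branch out of range: $offset")
    offset.toByte
  }

  // Absolute operands are stored low byte first.
  def absoluteOperand(address: Int): (Byte, Byte) =
    (address.toByte, (address >> 8).toByte)
}
```

A branch to its own address therefore encodes as `-2`, and the address `0xC000` is emitted as the bytes `0x00`, `0xC0`.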
object Assembler {
val opcodes = mutable.Map[(Opcode.Value, AddrMode.Value), Byte]()
val illegalOpcodes = mutable.Map[(Opcode.Value, AddrMode.Value), Byte]()
val cmosOpcodes = mutable.Map[(Opcode.Value, AddrMode.Value), Byte]()
def opcodeFor(opcode: Opcode.Value, addrMode: AddrMode.Value, options: CompilationOptions): Byte = {
val key = opcode -> addrMode
opcodes.get(key) match {
case Some(v) => v
case None =>
illegalOpcodes.get(key) match {
case Some(v) =>
if (options.flag(CompilationFlag.EmitIllegals)) v
else ErrorReporting.fatal("Cannot assemble an illegal opcode " + key)
case None =>
cmosOpcodes.get(key) match {
case Some(v) =>
if (options.flag(CompilationFlag.EmitCmosOpcodes)) v
else ErrorReporting.fatal("Cannot assemble a CMOS opcode " + key)
case None =>
ErrorReporting.fatal("Cannot assemble an unknown opcode " + key)
private def op(op: Opcode.Value, am: AddrMode.Value, x: Int): Unit = {
if (x < 0 || x > 0xff) ???
opcodes(op -> am) = x.toByte
if (am == AddrMode.Relative) opcodes(op -> AddrMode.Immediate) = x.toByte
private def cm(op: Opcode.Value, am: AddrMode.Value, x: Int): Unit = {
if (x < 0 || x > 0xff) ???
cmosOpcodes(op -> am) = x.toByte
private def il(op: Opcode.Value, am: AddrMode.Value, x: Int): Unit = {
if (x < 0 || x > 0xff) ???
illegalOpcodes(op -> am) = x.toByte
def getStandardLegalOpcodes: Set[Int] = & 0xff).toSet
import AddrMode._
import Opcode._
op(ADC, Immediate, 0x69)
op(ADC, ZeroPage, 0x65)
op(ADC, ZeroPageX, 0x75)
op(ADC, Absolute, 0x6D)
op(ADC, AbsoluteX, 0x7D)
op(ADC, AbsoluteY, 0x79)
op(ADC, IndexedX, 0x61)
op(ADC, IndexedY, 0x71)
op(AND, Immediate, 0x29)
op(AND, ZeroPage, 0x25)
op(AND, ZeroPageX, 0x35)
op(AND, Absolute, 0x2D)
op(AND, AbsoluteX, 0x3D)
op(AND, AbsoluteY, 0x39)
op(AND, IndexedX, 0x21)
op(AND, IndexedY, 0x31)
op(ASL, Implied, 0x0A)
op(ASL, ZeroPage, 0x06)
op(ASL, ZeroPageX, 0x16)
op(ASL, Absolute, 0x0E)
op(ASL, AbsoluteX, 0x1E)
op(BIT, ZeroPage, 0x24)
op(BIT, Absolute, 0x2C)
op(BPL, Relative, 0x10)
op(BMI, Relative, 0x30)
op(BVC, Relative, 0x50)
op(BVS, Relative, 0x70)
op(BCC, Relative, 0x90)
op(BCS, Relative, 0xB0)
op(BNE, Relative, 0xD0)
op(BEQ, Relative, 0xF0)
op(BRK, Implied, 0)
op(CMP, Immediate, 0xC9)
op(CMP, ZeroPage, 0xC5)
op(CMP, ZeroPageX, 0xD5)
op(CMP, Absolute, 0xCD)
op(CMP, AbsoluteX, 0xDD)
op(CMP, AbsoluteY, 0xD9)
op(CMP, IndexedX, 0xC1)
op(CMP, IndexedY, 0xD1)
op(CPX, Immediate, 0xE0)
op(CPX, ZeroPage, 0xE4)
op(CPX, Absolute, 0xEC)
op(CPY, Immediate, 0xC0)
op(CPY, ZeroPage, 0xC4)
op(CPY, Absolute, 0xCC)
op(DEC, ZeroPage, 0xC6)
op(DEC, ZeroPageX, 0xD6)
op(DEC, Absolute, 0xCE)
op(DEC, AbsoluteX, 0xDE)
op(EOR, Immediate, 0x49)
op(EOR, ZeroPage, 0x45)
op(EOR, ZeroPageX, 0x55)
op(EOR, Absolute, 0x4D)
op(EOR, AbsoluteX, 0x5D)
op(EOR, AbsoluteY, 0x59)
op(EOR, IndexedX, 0x41)
op(EOR, IndexedY, 0x51)
op(INC, ZeroPage, 0xE6)
op(INC, ZeroPageX, 0xF6)
op(INC, Absolute, 0xEE)
op(INC, AbsoluteX, 0xFE)
op(CLC, Implied, 0x18)
op(SEC, Implied, 0x38)
op(CLI, Implied, 0x58)
op(SEI, Implied, 0x78)
op(CLV, Implied, 0xB8)
op(CLD, Implied, 0xD8)
op(SED, Implied, 0xF8)
op(JMP, Absolute, 0x4C)
op(JMP, Indirect, 0x6C)
op(JSR, Absolute, 0x20)
op(LDA, Immediate, 0xA9)
op(LDA, ZeroPage, 0xA5)
op(LDA, ZeroPageX, 0xB5)
op(LDA, Absolute, 0xAD)
op(LDA, AbsoluteX, 0xBD)
op(LDA, AbsoluteY, 0xB9)
op(LDA, IndexedX, 0xA1)
op(LDA, IndexedY, 0xB1)
op(LDX, Immediate, 0xA2)
op(LDX, ZeroPage, 0xA6)
op(LDX, ZeroPageY, 0xB6)
op(LDX, Absolute, 0xAE)
op(LDX, AbsoluteY, 0xBE)
op(LDY, Immediate, 0xA0)
op(LDY, ZeroPage, 0xA4)
op(LDY, ZeroPageX, 0xB4)
op(LDY, Absolute, 0xAC)
op(LDY, AbsoluteX, 0xBC)
op(LSR, Implied, 0x4A)
op(LSR, ZeroPage, 0x46)
op(LSR, ZeroPageX, 0x56)
op(LSR, Absolute, 0x4E)
op(LSR, AbsoluteX, 0x5E)
op(NOP, Implied, 0xEA)
op(ORA, Immediate, 0x09)
op(ORA, ZeroPage, 0x05)
op(ORA, ZeroPageX, 0x15)
op(ORA, Absolute, 0x0D)
op(ORA, AbsoluteX, 0x1D)
op(ORA, AbsoluteY, 0x19)
op(ORA, IndexedX, 0x01)
op(ORA, IndexedY, 0x11)
op(TAX, Implied, 0xAA)
op(TXA, Implied, 0x8A)
op(DEX, Implied, 0xCA)
op(INX, Implied, 0xE8)
op(TAY, Implied, 0xA8)
op(TYA, Implied, 0x98)
op(DEY, Implied, 0x88)
op(INY, Implied, 0xC8)
op(ROL, Implied, 0x2A)
op(ROL, ZeroPage, 0x26)
op(ROL, ZeroPageX, 0x36)
op(ROL, Absolute, 0x2E)
op(ROL, AbsoluteX, 0x3E)
op(ROR, Implied, 0x6A)
op(ROR, ZeroPage, 0x66)
op(ROR, ZeroPageX, 0x76)
op(ROR, Absolute, 0x6E)
op(ROR, AbsoluteX, 0x7E)
op(RTI, Implied, 0x40)
op(RTS, Implied, 0x60)
op(SBC, Immediate, 0xE9)
op(SBC, ZeroPage, 0xE5)
op(SBC, ZeroPageX, 0xF5)
op(SBC, Absolute, 0xED)
op(SBC, AbsoluteX, 0xFD)
op(SBC, AbsoluteY, 0xF9)
op(SBC, IndexedX, 0xE1)
op(SBC, IndexedY, 0xF1)
op(STA, ZeroPage, 0x85)
op(STA, ZeroPageX, 0x95)
op(STA, Absolute, 0x8D)
op(STA, AbsoluteX, 0x9D)
op(STA, AbsoluteY, 0x99)
op(STA, IndexedX, 0x81)
op(STA, IndexedY, 0x91)
op(TXS, Implied, 0x9A)
op(TSX, Implied, 0xBA)
op(PHA, Implied, 0x48)
op(PLA, Implied, 0x68)
op(PHP, Implied, 0x08)
op(PLP, Implied, 0x28)
op(STX, ZeroPage, 0x86)
op(STX, ZeroPageY, 0x96)
op(STX, Absolute, 0x8E)
op(STY, ZeroPage, 0x84)
op(STY, ZeroPageX, 0x94)
op(STY, Absolute, 0x8C)
il(LAX, ZeroPage, 0xA7)
il(LAX, ZeroPageY, 0xB7)
il(LAX, Absolute, 0xAF)
il(LAX, AbsoluteY, 0xBF)
il(LAX, IndexedX, 0xA3)
il(LAX, IndexedY, 0xB3)
il(SAX, ZeroPage, 0x87)
il(SAX, ZeroPageY, 0x97)
il(SAX, Absolute, 0x8F)
il(TAS, AbsoluteY, 0x9B)
il(AHX, AbsoluteY, 0x9F)
il(SAX, IndexedX, 0x83)
il(AHX, IndexedY, 0x93)
il(ANC, Immediate, 0x0B)
il(ALR, Immediate, 0x4B)
il(ARR, Immediate, 0x6B)
il(XAA, Immediate, 0x8B)
il(LXA, Immediate, 0xAB)
il(SBX, Immediate, 0xCB)
il(SLO, ZeroPage, 0x07)
il(SLO, ZeroPageX, 0x17)
il(SLO, IndexedX, 0x03)
il(SLO, IndexedY, 0x13)
il(SLO, Absolute, 0x0F)
il(SLO, AbsoluteX, 0x1F)
il(SLO, AbsoluteY, 0x1B)
il(RLA, ZeroPage, 0x27)
il(RLA, ZeroPageX, 0x37)
il(RLA, IndexedX, 0x23)
il(RLA, IndexedY, 0x33)
il(RLA, Absolute, 0x2F)
il(RLA, AbsoluteX, 0x3F)
il(RLA, AbsoluteY, 0x3B)
il(SRE, ZeroPage, 0x47)
il(SRE, ZeroPageX, 0x57)
il(SRE, IndexedX, 0x43)
il(SRE, IndexedY, 0x53)
il(SRE, Absolute, 0x4F)
il(SRE, AbsoluteX, 0x5F)
il(SRE, AbsoluteY, 0x5B)
il(RRA, ZeroPage, 0x67)
il(RRA, ZeroPageX, 0x77)
il(RRA, IndexedX, 0x63)
il(RRA, IndexedY, 0x73)
il(RRA, Absolute, 0x6F)
il(RRA, AbsoluteX, 0x7F)
il(RRA, AbsoluteY, 0x7B)
il(DCP, ZeroPage, 0xC7)
il(DCP, ZeroPageX, 0xD7)
il(DCP, IndexedX, 0xC3)
il(DCP, IndexedY, 0xD3)
il(DCP, Absolute, 0xCF)
il(DCP, AbsoluteX, 0xDF)
il(DCP, AbsoluteY, 0xDB)
il(ISC, ZeroPage, 0xE7)
il(ISC, ZeroPageX, 0xF7)
il(ISC, IndexedX, 0xE3)
il(ISC, IndexedY, 0xF3)
il(ISC, Absolute, 0xEF)
il(ISC, AbsoluteX, 0xFF)
il(ISC, AbsoluteY, 0xFB)
il(NOP, Immediate, 0x80)
il(NOP, ZeroPage, 0x44)
il(NOP, ZeroPageX, 0x54)
il(NOP, Absolute, 0x5C)
il(NOP, AbsoluteX, 0x1C)
cm(NOP, Immediate, 0x02)
cm(NOP, ZeroPage, 0x44)
cm(NOP, ZeroPageX, 0x54)
cm(NOP, Absolute, 0x5C)
cm(STZ, ZeroPage, 0x64)
cm(STZ, ZeroPageX, 0x74)
cm(STZ, Absolute, 0x9C)
cm(STZ, AbsoluteX, 0x9E)
cm(PHX, Implied, 0xDA)
cm(PHY, Implied, 0x5A)
cm(PLX, Implied, 0xFA)
cm(PLY, Implied, 0x7A)
cm(ORA, ZeroPageIndirect, 0x12)
cm(AND, ZeroPageIndirect, 0x32)
cm(EOR, ZeroPageIndirect, 0x52)
cm(ADC, ZeroPageIndirect, 0x72)
cm(STA, ZeroPageIndirect, 0x92)
cm(LDA, ZeroPageIndirect, 0xB2)
cm(CMP, ZeroPageIndirect, 0xD2)
cm(SBC, ZeroPageIndirect, 0xF2)
cm(TSB, ZeroPage, 0x04)
cm(TSB, Absolute, 0x0C)
cm(TRB, ZeroPage, 0x14)
cm(TRB, Absolute, 0x1C)
cm(BIT, ZeroPageX, 0x34)
cm(BIT, AbsoluteX, 0x3C)
cm(INC, Implied, 0x1A)
cm(DEC, Implied, 0x3A)
cm(JMP, AbsoluteIndexedX, 0x7C)
cm(WAI, Implied, 0xCB)
cm(STP, Implied, 0xDB)
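
`opcodeFor` consults three tables in order: documented opcodes, undocumented NMOS opcodes, and CMOS-only opcodes, with the latter two gated by compilation flags. A compact sketch of that lookup is shown below. It is only an illustration under stated assumptions: keys are plain strings, the flags are passed as booleans, errors are returned as `Left` instead of being reported through `ErrorReporting.fatal`, and each table holds just one or two entries copied from the listing above.

```scala
// Hypothetical, string-keyed miniature of the three opcode tables.
object OpcodeLookupSketch {
  type Key = (String, String) // (mnemonic, addressing mode)

  val legal: Map[Key, Int]   = Map(("LDA", "Immediate") -> 0xA9, ("NOP", "Implied") -> 0xEA)
  val illegal: Map[Key, Int] = Map(("LAX", "ZeroPage") -> 0xA7)
  val cmos: Map[Key, Int]    = Map(("STZ", "ZeroPage") -> 0x64)

  def opcodeFor(mnemonic: String, mode: String, emitIllegals: Boolean, emitCmos: Boolean): Either[String, Int] = {
    val key = (mnemonic, mode)
    if (legal.contains(key)) Right(legal(key))
    else if (illegal.contains(key)) {
      // Undocumented opcodes are only emitted when explicitly allowed.
      if (emitIllegals) Right(illegal(key)) else Left(s"Cannot assemble an illegal opcode $key")
    } else if (cmos.contains(key)) {
      // CMOS opcodes are only emitted when the target CPU supports them.
      if (emitCmos) Right(cmos(key)) else Left(s"Cannot assemble a CMOS opcode $key")
    } else Left(s"Cannot assemble an unknown opcode $key")
  }
}
```

Checking the documented table first means a mnemonic that exists in several tables with different modes still resolves to its standard encoding whenever one is available.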

package millfork.output
import scala.collection.mutable
/** @author Karol Stasiak */
class CompiledMemory {
val banks = mutable.Map(0 -> new MemoryBank)
class MemoryBank {
def readByte(addr: Int) = output(addr) & 0xff
def readWord(addr: Int) = readByte(addr) + (readByte(addr + 1) << 8)
def readMedium(addr: Int) = readByte(addr) + (readByte(addr + 1) << 8) + (readByte(addr + 2) << 16)
def readLong(addr: Int) = readByte(addr) + (readByte(addr + 1) << 8) + (readByte(addr + 2) << 16) + (readByte(addr + 3) << 24)
def readWord(addrHi: Int, addrLo: Int) = readByte(addrLo) + (readByte(addrHi) << 8)
val output = Array.fill[Byte](1 << 16)(0)
val occupied = Array.fill(1 << 16)(false)
val readable = Array.fill(1 << 16)(false)
val writeable = Array.fill(1 << 16)(false)
var start: Int = 0
var end: Int = 0
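
`MemoryBank` stores a 64 KiB image as signed JVM bytes, so reads mask with `0xff` to recover the unsigned value, and multi-byte reads combine bytes low byte first. The sketch below demonstrates this on a cut-down bank; `MemoryBankSketch.Bank` and its `writeWord` helper are hypothetical and exist only for this example.

```scala
// Hypothetical cut-down bank: only the output array and little-endian accessors.
object MemoryBankSketch {
  final class Bank {
    val output: Array[Byte] = Array.fill[Byte](1 << 16)(0)

    // JVM bytes are signed; masking yields the unsigned value 0..255.
    def readByte(addr: Int): Int = output(addr) & 0xff

    // Little-endian: low byte at addr, high byte at addr + 1.
    def readWord(addr: Int): Int = readByte(addr) + (readByte(addr + 1) << 8)

    def writeWord(addr: Int, value: Int): Unit = {
      output(addr) = value.toByte
      output(addr + 1) = (value >> 8).toByte
    }
  }
}
```

Without the mask, a stored `0xFE` would read back as `-2` and corrupt any word assembled from it.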

package millfork.output
import java.io.ByteArrayOutputStream
/** @author Karol Stasiak */
trait OutputPackager {
def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte]
case class SequenceOutput(children: List[OutputPackager]) extends OutputPackager {
def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
val baos = new ByteArrayOutputStream
children.foreach { c =>
val a = c.packageOutput(mem, bank)
baos.write(a, 0, a.length)
case class ConstOutput(byte: Byte) extends OutputPackager {
def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = Array(byte)
case class CurrentBankFragmentOutput(start: Int, end: Int) extends OutputPackager {
def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
val b = mem.banks(bank)
b.output.slice(start, end + 1)
case class BankFragmentOutput(alwaysBank: Int, start: Int, end: Int) extends OutputPackager {
def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
val b = mem.banks(alwaysBank)
b.output.slice(start, end + 1)
object StartAddressOutput extends OutputPackager {
def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
val b = mem.banks(bank)
Array(b.start.toByte, b.start.>>(8).toByte)
object EndAddressOutput extends OutputPackager {
def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
val b = mem.banks(bank)
Array(b.end.toByte, b.end.>>(8).toByte)
object AllocatedDataOutput extends OutputPackager {
def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
val b = mem.banks(bank)
b.output.slice(b.start, b.end + 1)
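
The packagers are small composable pieces: each one turns the compiled memory into a byte array, and `SequenceOutput` concatenates the results of its children. As a sketch of how such pieces might be combined, the example below builds a file consisting of a two-byte little-endian load address followed by the allocated data, which is the layout of a Commodore-style PRG file. `PackagerSketch`, `Bank`, `Packager`, `StartAddress`, `AllocatedData` and `Sequence` are simplified, hypothetical counterparts of the classes above.

```scala
import java.io.ByteArrayOutputStream

// Hypothetical, single-bank counterparts of the packager classes.
object PackagerSketch {
  final case class Bank(output: Array[Byte], start: Int, end: Int)

  trait Packager {
    def pack(bank: Bank): Array[Byte]
  }

  // Two-byte little-endian start address.
  case object StartAddress extends Packager {
    def pack(bank: Bank): Array[Byte] = Array(bank.start.toByte, (bank.start >> 8).toByte)
  }

  // All bytes from start to end, inclusive.
  case object AllocatedData extends Packager {
    def pack(bank: Bank): Array[Byte] = bank.output.slice(bank.start, bank.end + 1)
  }

  // Concatenation of the children's output, in order.
  final case class Sequence(children: List[Packager]) extends Packager {
    def pack(bank: Bank): Array[Byte] = {
      val baos = new ByteArrayOutputStream
      children.foreach { c =>
        val a = c.pack(bank)
        baos.write(a, 0, a.length)
      }
      baos.toByteArray
    }
  }
}
```

Packing three bytes placed at `0x0801` with `Sequence(List(StartAddress, AllocatedData))` produces `01 08` followed by the three data bytes.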

package millfork.output
import millfork.error.ErrorReporting
import millfork.node.{CallGraph, VariableVertex}
import millfork.{CompilationFlag, CompilationOptions}
import scala.collection.mutable
/** @author Karol Stasiak */
sealed trait ByteAllocator {
def notifyAboutEndOfCode(org: Int): Unit
def allocateBytes(count: Int, options: CompilationOptions): Int
class UpwardByteAllocator(startAt: Int, endBefore: Int) extends ByteAllocator {
private var nextByte = startAt
def allocateBytes(count: Int, options: CompilationOptions): Int = {
if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1
val t = nextByte
nextByte += count
if (nextByte > endBefore) {
ErrorReporting.fatal("Out of high memory")
def notifyAboutEndOfCode(org: Int): Unit = ()
class AfterCodeByteAllocator(endBefore: Int) extends ByteAllocator {
var nextByte = 0x200
def allocateBytes(count: Int, options: CompilationOptions): Int = {
if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1
val t = nextByte
nextByte += count
if (nextByte > endBefore) {
ErrorReporting.fatal("Out of high memory")
def notifyAboutEndOfCode(org: Int): Unit = nextByte = org
class VariableAllocator(private var pointers: List[Int], private val bytes: ByteAllocator) {
private var pointerMap = mutable.Map[Int, Set[VariableVertex]]()
private var variableMap = mutable.Map[Int, mutable.Map[Int, Set[VariableVertex]]]()
var onEachByte: (Int => Unit) = _
def allocatePointer(callGraph: CallGraph, p: VariableVertex): Int = {
pointerMap.foreach { case (addr, alreadyThere) =>
if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) {
pointerMap(addr) += p
return addr
pointers match {
case Nil =>
ErrorReporting.fatal("Out of zero-page memory")
case next :: rest =>
pointers = rest
onEachByte(next + 1)
pointerMap(next) = Set(p)
def allocateByte(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions): Int = allocateBytes(callGraph, p, options, 1)
def allocateBytes(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions, count: Int): Int = {
if (!variableMap.contains(count)) {
variableMap(count) = mutable.Map()
variableMap(count).foreach { case (a, alreadyThere) =>
if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) {
variableMap(count)(a) += p
return a
val addr = bytes.allocateBytes(count, options)
(addr to (addr + count)).foreach(onEachByte)
variableMap(count)(addr) = Set(p)
def notifyAboutEndOfCode(org: Int): Unit = bytes.notifyAboutEndOfCode(org)
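
Both byte allocators above skip one byte when a 2-byte allocation would start at an address ending in `$FF` and `PreventJmpIndirectBug` is set, so that a word never straddles a page boundary. The sketch below reproduces only that rule in a self-contained form; `BumpAllocator` and its boolean parameter are assumptions made for this example and stand in for the real classes and `CompilationOptions`.

```scala
// Hypothetical stand-alone sketch; not the project's ByteAllocator.
final class BumpAllocator(startAt: Int, endBefore: Int, avoidPageCrossingWords: Boolean) {
  private var nextByte = startAt

  /** Reserves `count` bytes and returns the address of the first one. */
  def allocate(count: Int): Int = {
    // A word starting at $xxFF would cross a page boundary, which
    // JMP (indirect) on the NMOS 6502 reads incorrectly, so skip that byte.
    if (avoidPageCrossingWords && count == 2 && (nextByte & 0xff) == 0xff) nextByte += 1
    val address = nextByte
    nextByte += count
    if (nextByte > endBefore) throw new IllegalStateException("Out of memory")
    address
  }
}
```

Starting at `0x2fe`, a 1-byte allocation returns `0x2fe`, and the following 2-byte allocation returns `0x300` rather than `0x2ff`.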

package millfork.parser
import java.nio.file.{Files, Paths}
import fastparse.all._
import millfork.assembly.{AddrMode, Opcode}
import millfork.env._
import millfork.error.ErrorReporting
import millfork.node._
import millfork.{CompilationOptions, SeparatedList}
* @author Karol Stasiak
case class MfParser(filename: String, input: String, currentDirectory: String, options: CompilationOptions) {
var lastPosition = Position(filename, 1, 1, 0)
var lastLabel = ""
def toAst: Parsed[Program] = program.parse(input + "\n\n\n")
private val lineStarts: Array[Int] = (0 +: input.zipWithIndex.filter(_._1 == '\n').map(_._2)).toArray
def position(label: String = ""): P[Position] = Index.map(i => indexToPosition(i, label))
def indexToPosition(i: Int, label: String): Position = {
val prefix = lineStarts.takeWhile(_ <= i)
val newPosition = Position(filename, prefix.length, i - prefix.last, i)
if (newPosition.cursor > lastPosition.cursor) {
lastPosition = newPosition
lastLabel = label
val comment: P[Unit] = P("//" ~/ CharsWhile(c => c != '\n' && c != '\r', min = 0) ~ ("\r\n" | "\r" | "\n"))
val SWS: P[Unit] = P(CharsWhileIn(" \t", min = 1)).opaque("<horizontal whitespace>")
val HWS: P[Unit] = P(CharsWhileIn(" \t", min = 0)).opaque("<horizontal whitespace>")
val AWS: P[Unit] = P((CharIn(" \t\n\r;") | NoCut(comment)).rep(min = 0)).opaque("<any whitespace>")
val EOL: P[Unit] = P(HWS ~ ("\r\n" | "\r" | "\n" | comment).opaque("<first line break>") ~ AWS).opaque("<line break>")
val letter: P[String] = P(CharIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_").!)
val letterOrDigit: P[Unit] = P(CharIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_.$1234567890"))
val lettersOrDigits: P[String] = P(CharsWhileIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_.$1234567890", min = 0).!)
val identifier: P[String] = P((letter ~ lettersOrDigits).map { case (a, b) => a + b }).opaque("<identifier>")
// def operator: P[String] = P(CharsWhileIn("!-+*/><=~|&^", min=1).!) // TODO: only valid operators
// TODO: 3-byte types
def size(value: Int, wordLiteral: Boolean, longLiteral: Boolean): Int =
if (value > 255 || value < -128 || wordLiteral)
if (value > 0xffff || longLiteral) 4 else 2
else 1
def sign(abs: Int, minus: Boolean): Int = if (minus) -abs else abs
val decimalAtom: P[LiteralExpression] =
for {
p <- position()
minus <- "-".!.?
s <- CharsWhileIn("1234567890", min = 1).!.opaque("<decimal digits>") ~ !("x" | "b")
} yield {
val abs = Integer.parseInt(s, 10)
val value = sign(abs, minus.isDefined)
LiteralExpression(value, size(value, s.length > 3, s.length > 5)).pos(p)
val binaryAtom: P[LiteralExpression] =
for {
p <- position()
minus <- "-".!.?
_ <- P("0b" | "%") ~/ Pass
s <- CharsWhileIn("01", min = 1).!.opaque("<binary digits>")
} yield {
val abs = Integer.parseInt(s, 2)
val value = sign(abs, minus.isDefined)
LiteralExpression(value, size(value, s.length > 8, s.length > 16)).pos(p)
val hexAtom: P[LiteralExpression] =
for {
p <- position()
minus <- "-".!.?
_ <- P("0x" | "$") ~/ Pass
s <- CharsWhileIn("1234567890abcdefABCDEF", min = 1).!.opaque("<hex digits>")
} yield {
val abs = Integer.parseInt(s, 16)
val value = sign(abs, minus.isDefined)
LiteralExpression(value, size(value, s.length > 2, s.length > 4)).pos(p)
val literalAtom: P[LiteralExpression] = binaryAtom | hexAtom | decimalAtom
val atom: P[Expression] = P(literalAtom | (position() ~ identifier).map { case (p, i) => VariableExpression(i).pos(p) })
val mlOperators = List(
List("+=", "-=", "+'=", "-'=", "^=", "&=", "|=", "*=", "*'=", "<<=", ">>=", "<<'=", ">>'="),
List("||", "^^"),
List("==", "<=", ">=", "!=", "<", ">"),
List("+'", "-'", "<<'", ">>'", ">>>>", "+", "-", "&", "|", "^", "<<", ">>"),
List("*'", "*"))
val nonStatementLevel = 1 // every operator level except the assignment operators
val mathLevel = 4 // the `:` operator
def flags(allowed: String*): P[Set[String]] = StringIn(allowed: _*).!.rep(min = 0, sep = SWS).map(_.toSet).opaque("<flags>")
def variableDefinition(implicitlyGlobal: Boolean): P[DeclarationStatement] = for {
p <- position()
flags <- flags("const", "static", "volatile", "stack") ~ HWS
typ <- identifier ~ SWS
name <- identifier ~/ HWS ~/ Pass
addr <- ("@" ~/ HWS ~/ mlExpression(1)).?.opaque("<address>") ~ HWS
initialValue <- ("=" ~/ HWS ~/ mlExpression(1)).? ~ HWS
_ <- &(EOL) ~/ ""
} yield {
VariableDeclarationStatement(name, typ,
global = implicitlyGlobal || flags("static"),
stack = flags("stack"),
constant = flags("const"),
volatile = flags("volatile"),
initialValue, addr).pos(p)
val externFunctionBody: P[Option[List[Statement]]] = P("extern" ~/ PassWith(None))
val paramDefinition: P[ParameterDeclaration] = for {
p <- position()
typ <- identifier ~/ SWS ~/ Pass
name <- identifier ~/ Pass
} yield {
ParameterDeclaration(typ, ByVariable(name)).pos(p)
val appcSimple: P[ParamPassingConvention] = P("xy" | "yx" | "ax" | "ay" | "xa" | "ya" | "stack" | "a" | "x" | "y").!.map {
case "xy" => ByRegister(Register.XY)
case "yx" => ByRegister(Register.YX)
case "ax" => ByRegister(Register.AX)
case "ay" => ByRegister(Register.AY)
case "xa" => ByRegister(Register.XA)
case "ya" => ByRegister(Register.YA)
case "a" => ByRegister(Register.A)
case "x" => ByRegister(Register.X)
case "y" => ByRegister(Register.Y)
case x => ErrorReporting.fatal(s"Unknown assembly parameter passing convention: `$x`")
val appcComplex: P[ParamPassingConvention] = P((("const" | "ref").! ~/ AWS).? ~ AWS ~ identifier) map {
case (None, name) => ByVariable(name)
case (Some("const"), name) => ByConstant(name)
case (Some("ref"), name) => ByReference(name)
case x => ErrorReporting.fatal(s"Unknown assembly parameter passing convention: `$x`")
val asmParamDefinition: P[ParameterDeclaration] = for {
p <- position()
typ <- identifier ~ SWS
appc <- appcSimple | appcComplex
} yield ParameterDeclaration(typ, appc).pos(p)
val arrayListContents: P[List[Expression]] = ("[" ~/ AWS ~/ mlExpression(nonStatementLevel).rep(sep = AWS ~ "," ~/ AWS) ~ AWS ~ "]" ~/ Pass).map(_.toList)
val doubleQuotedString: P[List[Char]] = P("\"" ~/ CharsWhile(c => c != '\"' && c != '\n' && c != '\r').! ~ "\"").map(_.toList)
val codec: P[TextCodec] = P(position() ~ identifier).map {
case (_, "ascii") => TextCodec.Ascii
case (_, "petscii") => TextCodec.Petscii
case (_, "pet") => TextCodec.Petscii
case (p, x) =>
ErrorReporting.error(s"Unknown string encoding: `$x`", Some(p))
def arrayFileContents: P[List[Expression]] = for {
p <- "file" ~ HWS ~/ "(" ~/ HWS ~/ position()
filePath <- doubleQuotedString ~/ HWS
optSlice <- ("," ~/ HWS ~/ literalAtom ~/ HWS ~/ "," ~/ HWS ~/ literalAtom ~/ HWS ~/ Pass).?
_ <- ")" ~/ Pass
} yield {
val data = Files.readAllBytes(Paths.get(currentDirectory, filePath.mkString))
val slice = optSlice.fold(data) {
case (start, length) => data.drop(start.value.toInt).take(length.value.toInt)
}
slice.map(c => LiteralExpression(c & 0xff, 1)).toList
def arrayStringContents: P[List[Expression]] = P(position() ~ doubleQuotedString ~/ HWS ~ codec).map {
case (p, s, co) => s.map(c => LiteralExpression(co.decode(None, c), 1).pos(p))
def arrayContents: P[List[Expression]] = arrayListContents | arrayFileContents | arrayStringContents
def arrayDefinition: P[ArrayDeclarationStatement] = for {
p <- position()
name <- "array" ~ !letterOrDigit ~/ SWS ~ identifier ~ HWS
length <- ("[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~ "]").? ~ HWS
addr <- ("@" ~/ HWS ~/ mlExpression(1)).? ~/ HWS
contents <- ("=" ~/ HWS ~/ arrayContents).? ~/ HWS
} yield ArrayDeclarationStatement(name, length, addr, contents).pos(p)
def tightMlExpression: P[Expression] = P(mlParenExpr | functionCall | mlIndexedExpression | atom) // TODO
def mlExpression(level: Int): P[Expression] = {
val allowedOperators = mlOperators.drop(level).flatten
def inner: P[SeparatedList[Expression, String]] = {
for {
head <- tightMlExpression ~/ HWS
maybeOperator <- StringIn(allowedOperators: _*).!.?
maybeTail <- maybeOperator.fold[P[Option[List[(String, Expression)]]]](Pass.map(_ => None))(o => (HWS ~/ inner ~/ HWS).map(x2 => Some((o -> x2.head) :: x2.tail)))
} yield {
maybeTail.fold[SeparatedList[Expression, String]](SeparatedList.of(head))(t => SeparatedList(head, t))
def p(list: SeparatedList[Expression, String], level: Int): Expression =
if (level == mlOperators.length) list.head
else {
val xs = list.split(mlOperators(level).toSet(_))
xs.separators.distinct match {
case Nil =>
if (xs.tail.nonEmpty)
ErrorReporting.error("Too many different operators")
p(xs.head, level + 1)
case List("+") | List("-") | List("+", "-") | List("-", "+") =>
SumExpression(xs.toPairList("+").map { case (op, value) => (op == "-", p(value, level + 1)) }, decimal = false)
case List("+'") | List("-'") | List("+'", "-'") | List("-'", "+'") =>
SumExpression(xs.toPairList("+").map { case (op, value) => (op == "-", p(value, level + 1)) }, decimal = true)
case List(":") =>
if (xs.size != 2) {
ErrorReporting.error("The `:` operator can have only two arguments", xs.head.head.position)
LiteralExpression(0, 1)
} else {
SeparateBytesExpression(p(xs.head, level + 1), p(xs.tail.head._2, level + 1))
case List(op) =>
FunctionCallExpression(op, xs.items.map(value => p(value, level + 1)))
case _ =>
ErrorReporting.error("Too many different operators")
LiteralExpression(0, 1)
}
inner.map(x => p(x, 0))
def mlLhsExpressionSimple: P[LhsExpression] = mlIndexedExpression | (position() ~ identifier).map { case (p, n) => VariableExpression(n).pos(p) }
def mlLhsExpression: P[LhsExpression] = {
val separated = position() ~ mlLhsExpressionSimple ~ HWS ~ ":" ~/ HWS ~ mlLhsExpressionSimple
separated.map { case (p, h, l) => SeparateBytesExpression(h, l).pos(p) } | mlLhsExpressionSimple
def mlParenExpr: P[Expression] = P("(" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~/ ")")
def mlIndexedExpression: P[IndexedExpression] = for {
p <- position()
array <- identifier
index <- HWS ~ "[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~/ "]"
} yield IndexedExpression(array, index).pos(p)
def functionCall: P[FunctionCallExpression] = for {
p <- position()
name <- identifier
params <- HWS ~ "(" ~/ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ ""
} yield FunctionCallExpression(name, params.toList).pos(p)
val expressionStatement: P[ExecutableStatement] = mlExpression(0).map(ExpressionStatement)
val assignmentStatement: P[ExecutableStatement] =
(position() ~ mlLhsExpression ~ HWS ~ "=" ~/ HWS ~ mlExpression(1)).map {
case (p, l, r) => Assignment(l, r).pos(p)
def keywordStatement: P[ExecutableStatement] = P(returnStatement | ifStatement | whileStatement | forStatement | doWhileStatement | inlineAssembly | assignmentStatement)
def executableStatement: P[ExecutableStatement] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.pos(p) }
// TODO: label and instruction in one line
def asmLabel: P[ExecutableStatement] = (identifier ~ HWS ~ ":" ~/ HWS).map(l => AssemblyStatement(Opcode.LABEL, AddrMode.DoesNotExist, VariableExpression(l), elidable = true))
// def zeropageAddrModeHint: P[Option[Boolean]] = Pass
def asmOpcode: P[Opcode.Value] = (position() ~ letter.rep(exactly = 3).!).map { case (p, o) => Opcode.lookup(o, Some(p)) }
def asmExpression: P[Expression] = (position() ~ NoCut(
("<" ~/ HWS ~ mlExpression(mathLevel)).map(e => HalfWordExpression(e, hiByte = false)) |
(">" ~/ HWS ~ mlExpression(mathLevel)).map(e => HalfWordExpression(e, hiByte = true)) |
mlExpression(mathLevel)
)).map { case (p, e) => e.pos(p) }
val commaX = HWS ~ "," ~ HWS ~ ("X" | "x") ~ HWS
val commaY = HWS ~ "," ~ HWS ~ ("Y" | "y") ~ HWS
def asmParameter: P[(AddrMode.Value, Expression)] = {
(SWS ~ (
("#" ~ asmExpression).map(AddrMode.Immediate -> _) |
("(" ~ HWS ~ asmExpression ~ HWS ~ ")" ~ commaY).map(AddrMode.IndexedY -> _) |
("(" ~ HWS ~ asmExpression ~ commaX ~ ")").map(AddrMode.IndexedX -> _) |
("(" ~ HWS ~ asmExpression ~ HWS ~ ")").map(AddrMode.Indirect -> _) |
(asmExpression ~ commaX).map(AddrMode.AbsoluteX -> _) |
(asmExpression ~ commaY).map(AddrMode.AbsoluteY -> _) |
asmExpression.map(AddrMode.Absolute -> _)
)).?.map(_.getOrElse(AddrMode.Implied -> LiteralExpression(0, 1)))
def elidable: P[Boolean] = ("?".! ~/ HWS).?.map(_.isDefined)
def asmInstruction: P[ExecutableStatement] = {
val lineParser: P[(Boolean, Opcode.Value, (AddrMode.Value, Expression))] = !"}" ~ elidable ~/ asmOpcode ~/ asmParameter
lineParser.map { case (elid, op, param) =>
AssemblyStatement(op, param._1, param._2, elid)
def asmStatement: P[ExecutableStatement] = (position("assembly statement") ~ P(asmLabel | asmInstruction)).map { case (p, s) => s.pos(p) } // TODO: macros
def statement: P[Statement] = (position() ~ P(keywordStatement | variableDefinition(false) | expressionStatement)).map { case (p, s) => s.pos(p) }
def asmStatements: P[List[ExecutableStatement]] = ("{" ~/ AWS ~/ asmStatement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList)
def statements: P[List[Statement]] = ("{" ~/ AWS ~ statement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList)
def executableStatements: P[Seq[ExecutableStatement]] = "{" ~/ AWS ~/ executableStatement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~ "}"
def returnStatement: P[ExecutableStatement] = ("return" ~ !letterOrDigit ~/ HWS ~ mlExpression(nonStatementLevel).?).map(ReturnStatement)
def ifStatement: P[ExecutableStatement] = for {
condition <- "if" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)
thenBranch <- AWS ~/ executableStatements
elseBranch <- (AWS ~ "else" ~/ AWS ~/ executableStatements).?
} yield IfStatement(condition, thenBranch.toList, elseBranch.getOrElse(Nil).toList)
def whileStatement: P[ExecutableStatement] = for {
condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)
body <- AWS ~ executableStatements
} yield WhileStatement(condition, body.toList)
def forDirection: P[ForDirection.Value] =
("parallel" ~ HWS ~ "to").!.map(_ => ForDirection.ParallelTo) |
("parallel" ~ HWS ~ "until").!.map(_ => ForDirection.ParallelUntil) |
"until".!.map(_ => ForDirection.Until) |
"to".!.map(_ => ForDirection.To) |
("down" ~/ HWS ~/ "to").!.map(_ => ForDirection.DownTo)
def forStatement: P[ExecutableStatement] = for {
identifier <- "for" ~ SWS ~/ identifier ~/ "," ~/ Pass
start <- mlExpression(nonStatementLevel) ~ HWS ~ "," ~/ HWS ~/ Pass
direction <- forDirection ~/ HWS ~/ "," ~/ HWS ~/ Pass
end <- mlExpression(nonStatementLevel)
body <- AWS ~ executableStatements
} yield ForStatement(identifier, start, end, direction, body.toList)
def inlineAssembly: P[ExecutableStatement] = for {
condition <- "asm" ~ !letterOrDigit ~/ Pass
body <- AWS ~ asmStatements
} yield BlockStatement(body)
//noinspection MutatorLikeMethodIsParameterless
def doWhileStatement: P[ExecutableStatement] = for {
body <- "do" ~ !letterOrDigit ~/ AWS ~ executableStatements ~/ AWS
condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)
} yield DoWhileStatement(body.toList, condition)
def functionDefinition: P[DeclarationStatement] = for {
p <- position()
flags <- flags("asm", "inline", "interrupt", "reentrant") ~ HWS
returnType <- identifier ~ SWS
name <- identifier ~ HWS
params <- "(" ~/ AWS ~/ (if (flags("asm")) asmParamDefinition else paramDefinition).rep(sep = AWS ~ "," ~/ AWS) ~ AWS ~ ")" ~/ AWS
addr <- ("@" ~/ HWS ~/ mlExpression(1)).?.opaque("<address>") ~/ AWS
statements <- (externFunctionBody | (if (flags("asm")) asmStatements else statements).map(l => Some(l))) ~/ Pass
} yield {
if (flags("interrupt") && flags("inline")) ErrorReporting.error(s"Interrupt function `$name` cannot be inline", Some(p))
if (flags("interrupt") && flags("reentrant")) ErrorReporting.error(s"Interrupt function `$name` cannot be reentrant", Some(p))
if (flags("inline") && flags("reentrant")) ErrorReporting.error("Reentrant and inline exclude each other", Some(p))
if (flags("interrupt") && returnType != "void") ErrorReporting.error(s"Interrupt function `$name` has to return void", Some(p))
if (addr.isEmpty && statements.isEmpty) ErrorReporting.error(s"Extern function `$name` must have an address", Some(p))
if (statements.isEmpty && !flags("asm") && params.nonEmpty) ErrorReporting.error(s"Extern non-asm function `$name` cannot have parameters", Some(p))
if (flags("asm")) statements match {
case Some(Nil) => ErrorReporting.warn(s"Assembly function `$name` is empty, did you mean RTS or RTI?", options, Some(p))
case Some(xs) =>
if (flags("interrupt")) {
if (xs.exists {
case AssemblyStatement(Opcode.RTS, _, _, _) => true
case _ => false
}) ErrorReporting.warn(s"Assembly interrupt function `$name` contains RTS, did you mean RTI?", options, Some(p))
} else {
if (xs.exists {
case AssemblyStatement(Opcode.RTI, _, _, _) => true
case _ => false
}) ErrorReporting.warn(s"Assembly non-interrupt function `$name` contains RTI, did you mean RTS?", options, Some(p))
if (!flags("inline")) {
xs.last match {
case AssemblyStatement(Opcode.RTS, _, _, _) => () // OK
case AssemblyStatement(Opcode.RTI, _, _, _) => () // OK
case AssemblyStatement(Opcode.JMP, _, _, _) => () // OK
case _ =>
val validReturn = if (flags("interrupt")) "RTI" else "RTS"
ErrorReporting.warn(s"Non-inline assembly function `$name` should end in " + validReturn, options, Some(p))
case None => ()
FunctionDeclarationStatement(name, returnType, params.toList,
def importStatement: Parser[ImportStatement] = ("import" ~ !letterOrDigit ~/ SWS ~/ identifier).map(ImportStatement)
def program: Parser[Program] = for {
_ <- Start ~/ AWS ~/ Pass
definitions <- (importStatement | arrayDefinition | functionDefinition | variableDefinition(true)).rep(sep = EOL)
_ <- AWS ~ End
} yield Program(definitions.toList)
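
The `size` helper in `MfParser` decides how many bytes a literal occupies from both its value and how many digits were written, so leading zeros widen a literal. The sketch below copies that rule into a self-contained object and applies it to hex digit strings using the same thresholds as `hexAtom`; the `LiteralSize` object and `hexLiteralSize` are names assumed for this example only.

```scala
// Stand-alone illustration; `LiteralSize` and `hexLiteralSize` are hypothetical names.
object LiteralSize {
  /** Same rule as the parser's `size`: 1, 2 or 4 bytes. */
  def size(value: Int, wordLiteral: Boolean, longLiteral: Boolean): Int =
    if (value > 255 || value < -128 || wordLiteral) {
      if (value > 0xffff || longLiteral) 4 else 2
    } else 1

  /** Size of a hex literal given its digits (without the `$` or `0x` prefix). */
  def hexLiteralSize(digits: String): Int = {
    val value = Integer.parseInt(digits, 16)
    // More than 2 digits forces a word, more than 4 forces a long.
    size(value, digits.length > 2, digits.length > 4)
  }
}
```

Under this rule `ff` takes 1 byte, `00ff` and `105` take 2 bytes, and `000001` takes 4 bytes.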

package millfork.parser
import fastparse.all._
import fastparse.core
object MinimalTestCase {
def AWS: P[Unit] = "\n".rep(min = 0).opaque("<any whitespace>").log()
def EOL: P[Unit] = "\n".rep(min = 1).opaque("<line break>").log()
def identifier: P[String] = CharPred(_.isLetter).rep(min = 1).!.opaque("<identifier>").log()
def identifierWithSpace: P[String] = (identifier ~/ AWS ~/ Pass).opaque("<identifier with space>").log()
def separator: P[Unit] = ("," ~/ AWS ~/ Pass).opaque("<comma>").log()
def identifiers: P[Seq[String]] = identifierWithSpace.rep(min = 0, sep = separator)//.opaque("<separated identifiers>").log()
def array: P[Seq[String]] = ("[" ~/ AWS ~/ identifiers ~/ "]" ~/ Pass)//.opaque("<array>").log()
def arrays: Parser[Seq[Seq[String]]] = (array ~/ EOL).rep(min = 0, sep = !End ~/ Pass)//.opaque("<arrays>").log()
def program: Parser[Seq[Seq[String]]] = Start ~/ AWS ~/ arrays ~/ End

package millfork.parser
import millfork.node.Position
* @author Karol Stasiak
case class ParseException(msg: String, position: Option[Position]) extends Exception
class ParserBase(filename: String, input: String) {
def reset(): Unit = {
cursor = 0
line = 1
column = FirstColumn
private val FirstColumn = 0
private val length = input.length
private var cursor = 0
private var line = 1
private var column = FirstColumn
def position = Position(filename, line, column, cursor)
def restorePosition(p: Position): Unit = {
cursor = p.cursor
column = p.column
line = p.line
def error(msg: String, pos: Option[Position]): Nothing = throw ParseException(msg, pos)
def error(msg: String, pos: Position): Nothing = throw ParseException(msg, Some(pos))
def error(msg: String): Nothing = throw ParseException(msg, Some(position))
def error() = throw ParseException("Syntax error", Some(position))
def nextChar() = {
if (cursor >= length) error("Unexpected end of input")
val c = input(cursor)
cursor += 1
if (c == '\n') {
line += 1
column = FirstColumn
} else {
column += 1
def peekChar(): Char = {
if (cursor >= length) '\0' else input(cursor)
def require(char: Char): Char = {
val pos = position
val c = nextChar()
if (c != char) error(s"Expected `$char`", pos)
def require(p: Char=>Boolean, errorMsg: String = "Unexpected character"): Char = {
val pos = position
val c = nextChar()
if (!p(c)) error(errorMsg, pos)
def require(s: String): String = {
val c = peekChars(s.length)
if (c != s) error(s"Expected `$s`")
1 to s.length foreach (_=>nextChar())
def requireAny(s: String, errorMsg: String = "Unexpected character"): Char = {
val c = nextChar()
if (s.contains(c)) c
else error(errorMsg)
def peek2Chars(): String = {
def peekChars(n: Int): String = {
if (cursor > length - n) input.substring(cursor) else input.substring(cursor, cursor + n)
def charsWhile(pred: Char => Boolean, min: Int = 0, errorMsg: String = "Unexpected character"): String = {
val sb = new StringBuilder()
while (pred(peekChar())) {
sb += nextChar()
val s = sb.toString
if (s.length < min) error(errorMsg)
else s
def skipNextIfMatches(c: Char): Boolean = {
if (peekChar() == c) {
} else {
def either(c: Char, s: String): Unit = {
if (peekChar() == c) {
} else if (peekChars(s.length) == s) {
} else {
error(s"Expected either `$c` or `$s`")
def sepOrEnd(sep: Char, end: Char): Boolean = {
val p = position
val c = nextChar()
if (c == sep) true
else if (c==end) false
else error(s"Expected `$sep` or `$end`", p)
def anyOf[T](errorMsg: String, alternatives: (()=> T)*): T = {
alternatives.foreach { t =>
val p = position
try {
return t()
} catch {
case _: ParseException => restorePosition(p)
def surrounded[T](left: => Any, content: => T, right: => Any): T = {
val result = content
def followed[T](content: => T, right: => Any): T = {
val result = content
def attempt[T](content: => T): Option[T] = {
val p = position
try {
} catch {
case _: ParseException => None
def opaque[T](errorMsg: String)(block: =>T) :T={
try {
} catch{
case p:ParseException => error(errorMsg, p.position)

package millfork.parser
import java.nio.file.{Files, Paths}
import fastparse.core.Parsed.{Failure, Success}
import millfork.CompilationOptions
import millfork.error.ErrorReporting
import millfork.node.{ImportStatement, Position, Program}
import scala.collection.mutable
* @author Karol Stasiak
class SourceLoadingQueue(val initialFilenames: List[String], val includePath: List[String], val options: CompilationOptions) {
private val parsedModules = mutable.Map[String, Program]()
private val moduleQueue = mutable.Queue[() => Unit]()
val extension: String = ".ml"
def run(): Program = {
initialFilenames.foreach { i =>
parseModule(extractName(i), includePath, Right(i), options)
options.platform.startingModules.foreach {m =>
moduleQueue.enqueue(() => parseModule(m, includePath, Left(None), options))
while (moduleQueue.nonEmpty) {
moduleQueue.dequeueAll(_ => true).par.foreach(_())
ErrorReporting.assertNoErrors("Parse failed")
parsedModules.values.reduce(_ + _)
def lookupModuleFile(includePath: List[String], moduleName: String, position: Option[Position]): String = {
includePath.foreach { dir =>
val file = Paths.get(dir, moduleName + extension).toFile
ErrorReporting.debug("Checking " + file)
if (file.exists()) {
return file.getAbsolutePath
ErrorReporting.fatal(s"Module `$moduleName` not found", position)
def parseModule(moduleName: String, includePath: List[String], why: Either[Option[Position], String], options: CompilationOptions): Unit = {
val filename: String = why.fold(p => lookupModuleFile(includePath, moduleName, p), s => s)
ErrorReporting.debug(s"Parsing $filename")
val path = Paths.get(filename)
val parentDir = path.toFile.getAbsoluteFile.getParent
val src = new String(Files.readAllBytes(path))
val parser = MfParser(filename, src, parentDir, options)
parser.toAst match {
case Success(prog, _) =>
parsedModules.synchronized {
parsedModules.put(moduleName, prog)
prog.declarations.foreach {
case s@ImportStatement(m) =>
if (!parsedModules.contains(m)) {
moduleQueue.enqueue(() => parseModule(m, parentDir :: includePath, Left(s.position), options))
case _ => ()
case f@Failure(a, b, d) =>
ErrorReporting.error(s"Failed to parse the module `$moduleName` in $filename", Some(parser.indexToPosition(f.index, parser.lastLabel)))
// ErrorReporting.error(a.toString)
// ErrorReporting.error(b.toString)
// ErrorReporting.error(d.toString)
// ErrorReporting.error(d.traced.expected)
// ErrorReporting.error(d.traced.stack.toString)
// ErrorReporting.error(d.traced.traceParsers.toString)
// ErrorReporting.error(d.traced.fullStack.toString)
// ErrorReporting.error(f.toString)
if (parser.lastLabel != "") {
ErrorReporting.error(s"Syntax error: ${parser.lastLabel} expected", Some(parser.lastPosition))
} else {
ErrorReporting.error("Syntax error", Some(parser.lastPosition))
def extractName(i: String): String = {
val noExt = i.stripSuffix(extension)
val lastSlash = noExt.lastIndexOf('/') max noExt.lastIndexOf('\\')
if (lastSlash >= 0) noExt.substring(lastSlash + 1) else noExt

package millfork.parser
import millfork.error.ErrorReporting
import millfork.node.Position
* @author Karol Stasiak
class TextCodec(val name:String, private val map: String, private val extra: Map[Char,Int]) {
def decode(position: Option[Position], c: Char): Int = {
if (extra.contains(c)) extra(c) else {
val index = map.indexOf(c)
if (index >= 0) {
index
} else {
ErrorReporting.fatal("Invalid character in string", position)
object TextCodec {
val NotAChar = '\ufffd'
val Ascii = new TextCodec("ASCII", 0.until(127).map{i => if (i<32) NotAChar else i.toChar}.mkString, Map.empty)
val Petscii = new TextCodec("PETSCII",
"\ufffd" * 32 + 0x20.to(0x3f).map(_.toChar).mkString + "@abcdefghijklmnopqrstuvwxyz[£]↑←ABCDEFGHIJKLMNOPQRSTUVWXYZ",
Map('^' -> 0x5E, 'π' -> 0x7E)
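
`TextCodec` maps a character to a byte by its index in a table string, with a small override map consulted first. A self-contained sketch of that lookup is shown below; it returns an `Option` instead of reporting a fatal error, and `TableCodecSketch`, `TableCodec` and `encode` are names assumed for this example rather than the project's API.

```scala
// Hypothetical stand-alone sketch of a table-driven codec; not the project's TextCodec.
object TableCodecSketch {
  // Placeholder marking table slots that have no printable character.
  val NotAChar = '\ufffd'

  final class TableCodec(table: String, extra: Map[Char, Int]) {
    /** Looks the character up in the overrides first, then by its index in the table. */
    def encode(c: Char): Option[Int] =
      extra.get(c).orElse {
        val index = table.indexOf(c)
        if (index >= 0 && c != NotAChar) Some(index) else None
      }
  }

  // ASCII: control codes below 32 are not encodable; 32..126 map to themselves.
  val ascii = new TableCodec(
    (0 until 127).map(i => if (i < 32) NotAChar else i.toChar).mkString,
    Map.empty)
}
```

With this table, `TableCodecSketch.ascii.encode('A')` gives `Some(65)`, while a newline gives `None`.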

package com.grapeshot.halfnes;
import com.grapeshot.halfnes.mappers.Mapper;
import millfork.output.MemoryBank;
* Since the original CPURAM class was a convoluted mess of dependencies,
* I replaced it with my own version, which keeps only the few pieces of glue code needed to make it work
* @author Karol Stasiak
public class CPURAM {
private final MemoryBank mem;
// required by the CPU class for some reason
public Mapper mapper = new Mapper() {
public TVType getTVType() {
// the base class returns null, but this can't be null
return TVType.DENDY;
// required by the CPU class for some reason
public APU apu = new APU(null, null, this);
public CPURAM(MemoryBank mem) {
boolean[] readable = mem.readable();
boolean[] writeable = mem.writeable();
for (int i = 0xfffe; i >= 0; i--) {
if (readable[i]) {
// allow for dummy fetches by implied instructions
readable[i + 1] = true;
readable[0] = true;
readable[1] = true;
readable[2] = true;
for (int i = 0x100; i <= 0x1ff; i++) {
readable[i] = true;
writeable[i] = true;
for (int i = 0x4000; i <= 0x407f; i++) {
readable[i] = true;
writeable[i] = true;
for (int i = 0xc000; i <= 0xcfff; i++) {
readable[i] = true;
writeable[i] = true;
for (int i = 0xfffa; i <= 0xffff; i++) {
readable[i] = true;
writeable[i] = true;
this.mem = mem;
public final int read(int addr) {
addr &= 0xffff;
if (!mem.readable()[addr]) {
throw new RuntimeException("Can't read from $" + Integer.toHexString(addr));
return mem.output()[addr] & 0xff;
public final void write(int addr, int data) {
addr &= 0xffff;
if (!mem.writeable()[addr]) {
throw new RuntimeException("Can't write to $" + Integer.toHexString(addr));
mem.output()[addr] = (byte) data;

package millfork.test
import millfork.{Cpu, OptimizationPresets}
import millfork.assembly.opt.{AlwaysGoodOptimizations, DangerousOptimizations}
import millfork.test.emu._
import org.scalatest.{FunSuite, Matchers}
* @author Karol Stasiak
class ArraySuite extends FunSuite with Matchers {
test("Array assignment") {
val m = EmuSuperOptimizedRun(
| array output [3] @$c000
| array input = [5,6,7]
| void main () {
| copyEntry(0)
| copyEntry(1)
| copyEntry(2)
| }
| void copyEntry(byte index) {
| output[index] = input[index]
| }
m.readByte(0xc000) should equal(5)
m.readByte(0xc001) should equal(6)
m.readByte(0xc002) should equal(7)
test("Array assignment with offset") {
| array output [8] @$c000
| void main () {
| byte i
| i = 0
| while i != 6 {
| output[i + 2] = i + 1
| output[i] = output[i]
| i += 1
| }
| }
""".stripMargin) { m =>
m.readByte(0xc002) should equal(1)
m.readByte(0xc007) should equal(6)
test("Array assignment with offset 1") {
val m = new EmuRun(Cpu.StrictMos, Nil, DangerousOptimizations.All ++ OptimizationPresets.Good, true)(
| array output [8] @$c000
| void main () {
| byte i
| i = 0
| while i != 6 {
| output[i + 2] = i + 1
| output[i] = output[i]
| i += 1
| }
| }
m.readByte(0xc002) should equal(1)
m.readByte(0xc007) should equal(6)
test("Array assignment through a pointer") {
val m = EmuUnoptimizedRun(
| array output [3] @$c000
| pointer p
| void main () {
| p = output.addr
| byte i
| byte ignored
| i = 1
| word w
| w = $105
| p[i]:ignored = w
| }
m.readByte(0xc001) should equal(1)
test("Array in place math") {
| array output [4] @$c000
| void main () {
| byte i
| i = 3
| output[i] = 3
| output[i + 1 - 1] *= 4
| output[3] *= 5
| }
""".stripMargin)(_.readByte(0xc003) should equal(60))
test("Array simple read") {
| byte output @$c000
| array a[7]
| void main () {
| byte i
| i = 6
| a[i] = 6
| output = a[i]
| }
""".stripMargin)(_.readByte(0xc000) should equal(6))
test("Array simple read 2") {
| word output @$c000
| array a[7]
| void main () {
| output = 777
| byte i
| i = 6
| a[i] = 6
| output = a[i]
| }
""".stripMargin){m =>
m.readByte(0xc000) should equal(6)
m.readByte(0xc001) should equal(0)

package millfork.test
import millfork.{Cpu, OptimizationPresets}
import millfork.assembly.opt.{AlwaysGoodOptimizations, LaterOptimizations, VariableToRegisterOptimization}
import millfork.test.emu.{EmuBenchmarkRun, EmuUltraBenchmarkRun, EmuRun}
import org.scalatest.{FunSuite, Matchers}
* @author Karol Stasiak
class AssemblyOptimizationSuite extends FunSuite with Matchers {
test("Duplicate RTS") {
| void main () {
| if 1 == 1 {
| return
| }
| }
""".stripMargin) { _ => }
test("Inlining variable") {
| array output [5] @$C000
| void main () {
| byte i
| i = 0
| while (i<5) {
| output[i] = i
| i += 1
| }
| }
""".stripMargin)(_.readByte(0xc003) should equal(3))
test("Loading modified variables") {
| byte output @$C000
| void main () {
| byte x
| output = 5
| output += 1
| output += 1
| output += 1
| x = output
| output = x
| }
""".stripMargin)(_.readByte(0xc000) should equal(8))
test("Bit ops") {
| byte output @$C000
| void main () {
| output ^= output
| output |= 5 | 6
| output |= 5 | 6
| output &= 5 & 6
| output ^= 8 ^ 16
| }
""".stripMargin)(_.readByte(0xc000) should equal(28))
test("Inlining after a while") {
| array output [2]@$C000
| void main () {
| byte i
| output[0] = 6
| lol()
| i = 1
| if (i > 0) {
| output[i] = 4
| }
| }
| void lol() {}
""".stripMargin)(_.readWord(0xc000) should equal(0x406))
test("Tail call") {
| byte output @$C000
| void main () {
| if (output != 55) {
| output += 1
| main()
| }
| }
""".stripMargin)(_.readByte(0xc000) should equal(55))
  test("LDA-TAY elimination") {
    new EmuRun(Cpu.StrictMos, OptimizationPresets.NodeOpt, List(VariableToRegisterOptimization, AlwaysGoodOptimizations.YYY), false)(
      """
        | array mouse_pointer[64]
        | array arrow[64]
        | byte output @$C000
        | void main () {
        |   byte i
        |   i = 0
        |   while i < 63 {
        |     mouse_pointer[i] = arrow[i]
        |     i += 1
        |   }
        | }
      """.stripMargin)
  }
test("Carry flag after AND-LSR") {
| byte output @$C000
| void main () {
| output = f(5)
| }
| byte f(byte x) {
| return ((x & $1E) >> 1) + 3
| }
""".stripMargin)(_.readByte(0xc000) should equal(5))
test("Index sequence") {
| array output[6] @$C000
| void main () {
| pointer o
| o = output.addr
| o[3] = 8
| o[4] = 8
| o[5] = 8
| }
""".stripMargin){m =>
m.readByte(0xc005) should equal(8)
test("Index switching") {
| array output1[6] @$C000
| array output2[6] @$C010
| array input[6] @$C010
| void main () {
| static byte a
| static byte b
| input[5] = 3
| a = five()
| b = five()
| output1[a] = input[b]
| output2[a] = input[b]
| }
| byte five() {
| return 5
| }
""".stripMargin){m =>
m.readByte(0xc005) should equal(3)
m.readByte(0xc015) should equal(3)
  test("TAX-BCC-RTS-TXA optimization") {
    new EmuRun(Cpu.StrictMos,
      OptimizationPresets.NodeOpt, List(
        AlwaysGoodOptimizations.IdempotentDuplicateRemoval), false)(
      """
        | byte output @$C000
        | void main(){ delta() }
        | byte delta () {
        |   output = 0
        |   byte mouse_delta
        |   mouse_delta = 6
        |   mouse_delta &= $3f
        |   if mouse_delta >= $20 {
        |     mouse_delta |= $c0
        |     return 3
        |   }
        |   return mouse_delta
        | }
      """.stripMargin).readByte(0xc000) should equal(0)
  }
test("Memory access detection"){
| array h [4] @$C000
| array l [4] @$C404
| word output @$C00C
| word a @$C200
| void main () {
| byte i
| a = 0x102
| barrier()
| for i,0,until,4 {
| h[i]:l[i] = a
| }
| a.lo:a.hi=a
| output = a
| }
| void barrier (){}
""".stripMargin){m =>
m.readByte(0xc000) should equal(1)
m.readByte(0xc001) should equal(1)
m.readByte(0xc002) should equal(1)
m.readByte(0xc003) should equal(1)
m.readByte(0xc404) should equal(2)
m.readByte(0xc405) should equal(2)
m.readByte(0xc406) should equal(2)
m.readByte(0xc407) should equal(2)
m.readWord(0xc00c) should equal(0x201)
test("Memory access detection 2"){
| array h [4]
| array l [4]
| word output @$C00C
| word ptrh @$C000
| word ptrl @$C002
| void main () {
| ptrh = h.addr
| ptrl = l.addr
| byte i
| word a
| a = 0x102
| barrier()
| for i,0,until,4 {
| h[i]:l[i] = a
| }
| a.lo:a.hi=a
| output = a
| }
| void barrier (){}
""".stripMargin){m =>
val ptrh = 0xffff & m.readWord(0xC000)
val ptrl = 0xffff & m.readWord(0xC002)
m.readByte(ptrh + 0) should equal(1)
m.readByte(ptrh + 1) should equal(1)
m.readByte(ptrh + 2) should equal(1)
m.readByte(ptrh + 3) should equal(1)
m.readByte(ptrl + 0) should equal(2)
m.readByte(ptrl + 1) should equal(2)
m.readByte(ptrl + 2) should equal(2)
m.readByte(ptrl + 3) should equal(2)
m.readWord(0xc00c) should equal(0x201)

package millfork.test
import millfork.test.emu.EmuBenchmarkRun
import org.scalatest.{FunSuite, Matchers}
/**
  * @author Karol Stasiak
  */
class AssemblySuite extends FunSuite with Matchers {
  test("Inline assembly") {
    EmuBenchmarkRun(
      """
        | byte output @$c000
        | void main () {
        |   output = 0
        |   asm {
        |     inc $c000
        |   }
        | }
      """.stripMargin)(_.readByte(0xc000) should equal(1))
  }
  test("Assembly functions") {
    EmuBenchmarkRun(
      """
        | byte output @$c000
        | void main () {
        |   output = 0
        |   thing()
        | }
        | asm void thing() {
        |   inc $c000
        |   rts
        | }
      """.stripMargin)(_.readByte(0xc000) should equal(1))
  }
  test("Empty assembly") {
    EmuBenchmarkRun(
      """
        | byte output @$c000
        | void main () {
        |   output = 1
        |   asm {}
        | }
      """.stripMargin)(_.readByte(0xc000) should equal(1))
  }
  test("Passing params to assembly") {
    EmuBenchmarkRun(
      """
        | byte output @$c000
        | void main () {
        |   output = f(5)
        | }
        | asm byte f(byte a) {
        |   clc
        |   adc #5
        |   rts
        | }
      """.stripMargin)(_.readByte(0xc000) should equal(10))
  }
  test("Inline asm functions") {
    EmuBenchmarkRun(
      """
        | byte output @$c000
        | void main () {
        |   output = 0
        |   f()
        |   f()
        | }
        | inline asm void f() {
        |   inc $c000
        |   rts
        | }
      """.stripMargin)(_.readByte(0xc000) should equal(1))
  }
  test("Inline asm functions 2") {
    EmuBenchmarkRun(
      """
        | byte output @$c000
        | void main () {
        |   output = 0
        |   add(output, 5)
        |   add(output, 5)
        | }
        | inline asm void add(byte ref v, byte const c) {
        |   lda v
        |   clc
        |   adc #c
        |   sta v
        |   rts
        | }
      """.stripMargin)(_.readByte(0xc000) should equal(5))
  }
}

Some files were not shown because too many files have changed in this diff