From 48e26a05382f95cc9de96311c88a64d0260f33b3 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Thu, 7 Dec 2017 00:23:30 +0100 Subject: [PATCH] Initial code upload --- LICENSE | 674 +++++++ README.md | 4 +- build.sbt | 35 + doc/README.md | 7 + doc/target-platforms.md | 93 + doc/tutorial/01-getting-started.md | 54 + doc/tutorial/02-functions-variables.md | 16 + examples/hello_world/hello_world.mfk | 11 + include/a8.ini | 22 + include/a8_kernel.mfk | 9 + include/c128.ini | 20 + include/c128_hardware.mfk | 5 + include/c128_kernal.mfk | 5 + include/c1531.mfk | 60 + include/c16.ini | 19 + include/c264_hardware.mfk | 1 + include/c264_kernal.mfk | 5 + include/c264_ted.mfk | 20 + include/c64.ini | 37 + include/c64_basic.mfk | 6 + include/c64_cia.mfk | 40 + include/c64_hardware.mfk | 41 + include/c64_kernal.mfk | 32 + include/c64_sid.mfk | 24 + include/c64_vic.mfk | 154 ++ include/cpu6510.mfk | 3 + include/loader_0401.mfk | 15 + include/loader_0801.mfk | 15 + include/loader_1001.mfk | 15 + include/loader_1201.mfk | 15 + include/loader_1c01.mfk | 15 + include/mouse.mfk | 8 + include/pet.ini | 19 + include/pet_kernal.mfk | 5 + include/plus4.ini | 19 + include/stdio.mfk | 10 + include/stdlib.mfk | 23 + include/vic20.ini | 19 + include/vic20_3k.ini | 19 + include/vic20_8k.ini | 19 + include/vic20_kernal.mfk | 5 + project/assembly.sbt | 2 + project/build.properties | 1 + project/buildinfo.sbt | 1 + project/plugins.sbt | 0 .../scala/millfork/CompilationOptions.scala | 116 ++ src/main/scala/millfork/Main.scala | 264 +++ .../scala/millfork/OptimizationPresets.scala | 150 ++ src/main/scala/millfork/Platform.scala | 115 ++ src/main/scala/millfork/SeparatedList.scala | 50 + .../millfork/assembly/AssemblyLine.scala | 305 +++ src/main/scala/millfork/assembly/Chunk.scala | 33 + src/main/scala/millfork/assembly/Opcode.scala | 190 ++ .../opt/AlwaysGoodOptimizations.scala | 848 +++++++++ .../assembly/opt/AssemblyOptimization.scala | 14 + .../opt/ChangeIndexRegisterOptimization.scala | 155 ++ .../assembly/opt/CmosOptimizations.scala | 36 + .../assembly/opt/CoarseFlowAnalyzer.scala | 259 +++ .../assembly/opt/DangerousOptimizations.scala | 59 + .../millfork/assembly/opt/FlowAnalyzer.scala | 34 + .../assembly/opt/LaterOptimizations.scala | 242 +++ .../assembly/opt/QuantumFlowAnalyzer.scala | 425 +++++ .../assembly/opt/ReverseFlowAnalyzer.scala | 149 ++ .../opt/RuleBasedAssemblyOptimization.scala | 757 ++++++++ .../assembly/opt/SizeOptimizations.scala | 8 + .../assembly/opt/SuperOptimizer.scala | 75 + .../opt/UndocumentedOptimizations.scala | 340 ++++ .../assembly/opt/UnusedLabelRemoval.scala | 38 + .../opt/VariableToRegisterOptimization.scala | 322 ++++ src/main/scala/millfork/cli/CliOption.scala | 201 ++ src/main/scala/millfork/cli/CliParser.scala | 81 + src/main/scala/millfork/cli/CliStatus.scala | 8 + .../scala/millfork/compiler/BuiltIns.scala | 832 ++++++++ .../compiler/CompilationContext.scala | 12 + .../scala/millfork/compiler/MfCompiler.scala | 1675 +++++++++++++++++ src/main/scala/millfork/env/Constant.scala | 224 +++ src/main/scala/millfork/env/Environment.scala | 618 ++++++ src/main/scala/millfork/env/Thing.scala | 264 +++ .../scala/millfork/error/ErrorReporting.scala | 74 + src/main/scala/millfork/node/CallGraph.scala | 151 ++ src/main/scala/millfork/node/Node.scala | 181 ++ src/main/scala/millfork/node/Program.scala | 11 + .../millfork/node/opt/NodeOptimization.scala | 16 + .../millfork/node/opt/UnreachableCode.scala | 29 + .../millfork/node/opt/UnusedFunctions.scala | 72 + .../node/opt/UnusedGlobalVariables.scala | 104 + .../node/opt/UnusedLocalVariables.scala | 114 ++ .../scala/millfork/output/Assembler.scala | 612 ++++++ .../millfork/output/CompiledMemory.scala | 29 + .../millfork/output/OutputPackager.scala | 60 + .../millfork/output/VariableAllocator.scala | 96 + src/main/scala/millfork/parser/MfParser.scala | 435 +++++ .../millfork/parser/MinimalTestCase.scala | 24 + .../scala/millfork/parser/ParserBase.scala | 169 ++ .../millfork/parser/SourceLoadingQueue.scala | 89 + .../scala/millfork/parser/TextCodec.scala | 32 + .../java/com/grapeshot/halfnes/CPURAM.java | 75 + src/test/scala/millfork/test/ArraySuite.scala | 133 ++ .../test/AssemblyOptimizationSuite.scala | 281 +++ .../scala/millfork/test/AssemblySuite.scala | 99 + .../scala/millfork/test/BasicSymonTest.scala | 28 + src/test/scala/millfork/test/BitOpSuite.scala | 36 + .../scala/millfork/test/BooleanSuite.scala | 64 + .../millfork/test/ByteDecimalMathSuite.scala | 68 + .../scala/millfork/test/ByteMathSuite.scala | 158 ++ src/test/scala/millfork/test/CmosSuite.scala | 31 + .../scala/millfork/test/ComparisonSuite.scala | 264 +++ .../millfork/test/ErasthotenesSuite.scala | 37 + .../scala/millfork/test/ForLoopSuite.scala | 76 + .../scala/millfork/test/IllegalSuite.scala | 51 + .../test/InlineAssemblyFunctionsSuite.scala | 80 + src/test/scala/millfork/test/LongTest.scala | 160 ++ .../scala/millfork/test/MinimalTest.scala | 23 + .../millfork/test/NodeOptimizationSuite.scala | 32 + src/test/scala/millfork/test/NonetSuite.scala | 29 + .../millfork/test/SeparateBytesSuite.scala | 137 ++ src/test/scala/millfork/test/ShiftSuite.scala | 63 + .../millfork/test/SignExtensionSuite.scala | 44 + .../scala/millfork/test/StackVarSuite.scala | 142 ++ .../millfork/test/TypeWideningSuite.scala | 23 + .../scala/millfork/test/WordMathSuite.scala | 103 + .../millfork/test/emu/EmuBenchmarkRun.scala | 30 + .../test/emu/EmuCmosBenchmarkRun.scala | 26 + .../test/emu/EmuNodeOptimizedRun.scala | 15 + .../test/emu/EmuOptimizedCmosRun.scala | 19 + .../millfork/test/emu/EmuOptimizedRun.scala | 15 + .../scala/millfork/test/emu/EmuPlatform.scala | 19 + .../test/emu/EmuQuantumOptimizedRun.scala | 15 + src/test/scala/millfork/test/emu/EmuRun.scala | 238 +++ .../emu/EmuSuperQuantumOptimizedRun.scala | 17 + .../test/emu/EmuSuperoptimizedRun.scala | 17 + .../test/emu/EmuUltraBenchmarkRun.scala | 35 + .../test/emu/EmuUndocumentedRun.scala | 19 + .../millfork/test/emu/EmuUnoptimizedRun.scala | 9 + .../millfork/test/emu/SymonTestRam.scala | 39 + 135 files changed, 15568 insertions(+), 1 deletion(-) create mode 100644 LICENSE create mode 100644 build.sbt create mode 100644 doc/README.md create mode 100644 doc/target-platforms.md create mode 100644 doc/tutorial/01-getting-started.md create mode 100644 doc/tutorial/02-functions-variables.md create mode 100644 examples/hello_world/hello_world.mfk create mode 100644 include/a8.ini create mode 100644 include/a8_kernel.mfk create mode 100644 include/c128.ini create mode 100644 include/c128_hardware.mfk create mode 100644 include/c128_kernal.mfk create mode 100644 include/c1531.mfk create mode 100644 include/c16.ini create mode 100644 include/c264_hardware.mfk create mode 100644 include/c264_kernal.mfk create mode 100644 include/c264_ted.mfk create mode 100644 include/c64.ini create mode 100644 include/c64_basic.mfk create mode 100644 include/c64_cia.mfk create mode 100644 include/c64_hardware.mfk create mode 100644 include/c64_kernal.mfk create mode 100644 include/c64_sid.mfk create mode 100644 include/c64_vic.mfk create mode 100644 include/cpu6510.mfk create mode 100644 include/loader_0401.mfk create mode 100644 include/loader_0801.mfk create mode 100644 include/loader_1001.mfk create mode 100644 include/loader_1201.mfk create mode 100644 include/loader_1c01.mfk create mode 100644 include/mouse.mfk create mode 100644 include/pet.ini create mode 100644 include/pet_kernal.mfk create mode 100644 include/plus4.ini create mode 100644 include/stdio.mfk create mode 100644 include/stdlib.mfk create mode 100644 include/vic20.ini create mode 100644 include/vic20_3k.ini create mode 100644 include/vic20_8k.ini create mode 100644 include/vic20_kernal.mfk create mode 100644 project/assembly.sbt create mode 100644 project/build.properties create mode 100644 project/buildinfo.sbt create mode 100644 project/plugins.sbt create mode 100644 src/main/scala/millfork/CompilationOptions.scala create mode 100644 src/main/scala/millfork/Main.scala create mode 100644 src/main/scala/millfork/OptimizationPresets.scala create mode 100644 src/main/scala/millfork/Platform.scala create mode 100644 src/main/scala/millfork/SeparatedList.scala create mode 100644 src/main/scala/millfork/assembly/AssemblyLine.scala create mode 100644 src/main/scala/millfork/assembly/Chunk.scala create mode 100644 src/main/scala/millfork/assembly/Opcode.scala create mode 100644 src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala create mode 100644 src/main/scala/millfork/assembly/opt/AssemblyOptimization.scala create mode 100644 src/main/scala/millfork/assembly/opt/ChangeIndexRegisterOptimization.scala create mode 100644 src/main/scala/millfork/assembly/opt/CmosOptimizations.scala create mode 100644 src/main/scala/millfork/assembly/opt/CoarseFlowAnalyzer.scala create mode 100644 src/main/scala/millfork/assembly/opt/DangerousOptimizations.scala create mode 100644 src/main/scala/millfork/assembly/opt/FlowAnalyzer.scala create mode 100644 src/main/scala/millfork/assembly/opt/LaterOptimizations.scala create mode 100644 src/main/scala/millfork/assembly/opt/QuantumFlowAnalyzer.scala create mode 100644 src/main/scala/millfork/assembly/opt/ReverseFlowAnalyzer.scala create mode 100644 src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala create mode 100644 src/main/scala/millfork/assembly/opt/SizeOptimizations.scala create mode 100644 src/main/scala/millfork/assembly/opt/SuperOptimizer.scala create mode 100644 src/main/scala/millfork/assembly/opt/UndocumentedOptimizations.scala create mode 100644 src/main/scala/millfork/assembly/opt/UnusedLabelRemoval.scala create mode 100644 src/main/scala/millfork/assembly/opt/VariableToRegisterOptimization.scala create mode 100644 src/main/scala/millfork/cli/CliOption.scala create mode 100644 src/main/scala/millfork/cli/CliParser.scala create mode 100644 src/main/scala/millfork/cli/CliStatus.scala create mode 100644 src/main/scala/millfork/compiler/BuiltIns.scala create mode 100644 src/main/scala/millfork/compiler/CompilationContext.scala create mode 100644 src/main/scala/millfork/compiler/MfCompiler.scala create mode 100644 src/main/scala/millfork/env/Constant.scala create mode 100644 src/main/scala/millfork/env/Environment.scala create mode 100644 src/main/scala/millfork/env/Thing.scala create mode 100644 src/main/scala/millfork/error/ErrorReporting.scala create mode 100644 src/main/scala/millfork/node/CallGraph.scala create mode 100644 src/main/scala/millfork/node/Node.scala create mode 100644 src/main/scala/millfork/node/Program.scala create mode 100644 src/main/scala/millfork/node/opt/NodeOptimization.scala create mode 100644 src/main/scala/millfork/node/opt/UnreachableCode.scala create mode 100644 src/main/scala/millfork/node/opt/UnusedFunctions.scala create mode 100644 src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala create mode 100644 src/main/scala/millfork/node/opt/UnusedLocalVariables.scala create mode 100644 src/main/scala/millfork/output/Assembler.scala create mode 100644 src/main/scala/millfork/output/CompiledMemory.scala create mode 100644 src/main/scala/millfork/output/OutputPackager.scala create mode 100644 src/main/scala/millfork/output/VariableAllocator.scala create mode 100644 src/main/scala/millfork/parser/MfParser.scala create mode 100644 src/main/scala/millfork/parser/MinimalTestCase.scala create mode 100644 src/main/scala/millfork/parser/ParserBase.scala create mode 100644 src/main/scala/millfork/parser/SourceLoadingQueue.scala create mode 100644 src/main/scala/millfork/parser/TextCodec.scala create mode 100644 src/test/java/com/grapeshot/halfnes/CPURAM.java create mode 100644 src/test/scala/millfork/test/ArraySuite.scala create mode 100644 src/test/scala/millfork/test/AssemblyOptimizationSuite.scala create mode 100644 src/test/scala/millfork/test/AssemblySuite.scala create mode 100644 src/test/scala/millfork/test/BasicSymonTest.scala create mode 100644 src/test/scala/millfork/test/BitOpSuite.scala create mode 100644 src/test/scala/millfork/test/BooleanSuite.scala create mode 100644 src/test/scala/millfork/test/ByteDecimalMathSuite.scala create mode 100644 src/test/scala/millfork/test/ByteMathSuite.scala create mode 100644 src/test/scala/millfork/test/CmosSuite.scala create mode 100644 src/test/scala/millfork/test/ComparisonSuite.scala create mode 100644 src/test/scala/millfork/test/ErasthotenesSuite.scala create mode 100644 src/test/scala/millfork/test/ForLoopSuite.scala create mode 100644 src/test/scala/millfork/test/IllegalSuite.scala create mode 100644 src/test/scala/millfork/test/InlineAssemblyFunctionsSuite.scala create mode 100644 src/test/scala/millfork/test/LongTest.scala create mode 100644 src/test/scala/millfork/test/MinimalTest.scala create mode 100644 src/test/scala/millfork/test/NodeOptimizationSuite.scala create mode 100644 src/test/scala/millfork/test/NonetSuite.scala create mode 100644 src/test/scala/millfork/test/SeparateBytesSuite.scala create mode 100644 src/test/scala/millfork/test/ShiftSuite.scala create mode 100644 src/test/scala/millfork/test/SignExtensionSuite.scala create mode 100644 src/test/scala/millfork/test/StackVarSuite.scala create mode 100644 src/test/scala/millfork/test/TypeWideningSuite.scala create mode 100644 src/test/scala/millfork/test/WordMathSuite.scala create mode 100644 src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuCmosBenchmarkRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuNodeOptimizedRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuOptimizedCmosRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuOptimizedRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuPlatform.scala create mode 100644 src/test/scala/millfork/test/emu/EmuQuantumOptimizedRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuSuperQuantumOptimizedRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuUndocumentedRun.scala create mode 100644 src/test/scala/millfork/test/emu/EmuUnoptimizedRun.scala create mode 100644 src/test/scala/millfork/test/emu/SymonTestRam.scala diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..f288702d --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/README.md b/README.md index 42efeaff..2143467b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # Millfork -A middle-level programming language targeting 6502-based microcomputers. +A middle-level programming language targeting 6502-based microcomputers. + +Distributed under GPLv3 (see [LICENSE](LICENSE)) **UNDER DEVELOPMENT, NOT FOR PRODUCTION USE** diff --git a/build.sbt b/build.sbt new file mode 100644 index 00000000..1cc38069 --- /dev/null +++ b/build.sbt @@ -0,0 +1,35 @@ +name := "millfork" + +version := "0.0.1-SNAPSHOT" + +scalaVersion := "2.12.3" + +resolvers += Resolver.mavenLocal + +libraryDependencies += "com.lihaoyi" %% "fastparse" % "1.0.0" + +libraryDependencies += "org.apache.commons" % "commons-configuration2" % "2.2" + +libraryDependencies += "org.scalactic" %% "scalactic" % "3.0.4" + +libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.4" % "test" + +// these two not in Maven Central or any other public repo +// get them from the following links or just build millfork without tests: +// https://github.com/sethm/symon +// https://github.com/andrew-hoffman/halfnes/tree/061 + +libraryDependencies += "com.loomcom.symon" % "symon" % "1.3.0-SNAPSHOT" % "test" + +libraryDependencies += "com.grapeshot" % "halfnes" % "061" % "test" + +mainClass in Compile := Some("millfork.Main") + +assemblyJarName := "millfork.jar" + +//lazy val root = (project in file(".")). +// enablePlugins(BuildInfoPlugin). +// settings( +// buildInfoKeys := Seq[BuildInfoKey](name, version, scalaVersion, sbtVersion), +// buildInfoPackage := "hello" +// ) \ No newline at end of file diff --git a/doc/README.md b/doc/README.md new file mode 100644 index 00000000..d6bdedb3 --- /dev/null +++ b/doc/README.md @@ -0,0 +1,7 @@ +# Documentation + +## Tutorial + +* [Getting started](tutorial/01-getting-started.md) + +* [Basic functions and variables](tutorial/02-functions-variables.md) \ No newline at end of file diff --git a/doc/target-platforms.md b/doc/target-platforms.md new file mode 100644 index 00000000..ec53a857 --- /dev/null +++ b/doc/target-platforms.md @@ -0,0 +1,93 @@ +# Target platforms + +Currently, Millfork supports creating disk- or tape-based programs for Commodore and Atari 8-bit computers, +but it may be expanded to support other 6502-based platforms in the future. + +## Supported platforms + +The following platforms are currently supported: + +* `c64` – Commodore 64 + +* `c16` – Commodore 16 + +* `plus4` – Commodore Plus/4 + +* `vic20` – Commodore VIC-20 without memory expansion + +* `vic20_3k` – Commodore VIC-20 with 3K memory expansion + +* `vic20_8k` – Commodore VIC-20 with 8K or 16K memory expansion + +* `c128` – Commodore 128 in its native mode + +* `pet` – Commodore PET + +* `a8` – Atari 8-bit computers + +The primary and most tested platform is Commodore 64. + +Currently, all targets assume that the program will be loaded from disk or tape. +Cartridge targets are not yet available. + +## Adding a custom platform + +Every platform is defined in an `.ini` file with an appropriate name. + +#### `[compilation]` section + +* `arch` – CPU architecture. It defines which instructions are available. Available values: + + * `nmos` + + * `strict` (= NMOS without illegal instructions) + + * `ricoh` (= NMOS without decimal mode) + + * `strictricoh` + + * `cmos` (= 65C02) + +* `modules` – comma-separated list of modules that will be automatically imported + +* other compilation options (they can be overridden using commandline options): + + * `emit_illegals` – whether the compiler should emit illegal instructions, default `false` + + * `emit_cmos` – whether the compiler should emit CMOS instructions, default is `true` on `cmos` and `false` elsewhere + + * `decimal_mode` – whether the compiler should emit decimal instructions, default is `false` on `ricoh` and `strictricoh` and `true` elsewhere + + * `ro_arrays` – whether the compiler should warn upon array writes, default is `false` + + * `prevent_jmp_indirect_bug` – whether the compiler should try to avoid the indirect JMP bug, default is `false` on `cmos` and `true` elsewhere + +#### `[allocation]` section + +* `main_org` – the address for the `main` function; all the other functions will be placed after it + +* `zp_pointers` – either a list of comma separated zeropage addresses that can be used by the program as zeropage pointers, or `all` for all. Each value should be the address of the first of two free bytes in the zeropage. + +* `himem_style` – not yet supported + +* `himem_start` – the first address used for non-zeropage variables, or `after_code` if the variables should be allocated after the code + +* `himem_end` – the last address available for non-zeropage variables + +#### `[output]` section + +* `style` – not yet supported + +* `format` – output file format; a comma-separated list of tokens: + + * literal byte values + + * `startaddr` – little-endian 16-bit address of the first used byte of the compiled output + + * `endaddr` – little-endian 16-bit address of the last used byte of the compiled output + + * `allocated` – all used bytes + + * `:` - inclusive range of bytes + +* `extension` – target file extension, with or without the dot \ No newline at end of file diff --git a/doc/tutorial/01-getting-started.md b/doc/tutorial/01-getting-started.md new file mode 100644 index 00000000..923f8fb5 --- /dev/null +++ b/doc/tutorial/01-getting-started.md @@ -0,0 +1,54 @@ +# Getting started + +## Hello world example + +Save the following as `hello_world.ml`: + +``` +import stdio + +array hello_world = "hello world" petscii + +void main(){ + putstr(hello_world, hello_world.length) + while(true){} +} +``` + +Compile is using the following commandline: + +``` +java millfork.jar hello_world.ml -o hello_world -t c64 -I path_to_millfork\include +``` + +Run the output executable (here using the VICE emulator): + +``` +x64 hello_world.prg +``` + +## Basic commandline usage + +The following options are crucial when compiling your sources: + +* `-o FILENAME` – specifies the base name for your output file, an appropriate file extension will be appended (`prg` for Commodore, `xex` for Atari, `asm` for assembly output, `lbl` for label file) + +* `-I DIR;DIR;DIR;...` – specifies the paths to directories with modules to include. + +* `-t PLATFORM` – specifies the target platform (`c64` is the default). Each platform is defined in an `.ini` file in the include directory. For the list of supported platforms, see [Supported platforms](../target-platforms.md) + +You may be also interested in the following: + +* `-O`, `-O2`, `-O3` – enable optimization (various levels) + +* `--detailed-flow` – use more resource-consuming but more precise flow analysis engine for better optimization + +* `-s` – additionally generate assembly output + +* `-g` – additionally generate a label file, in format compatible with VICE emulator + +* `-r PROGRAM` – automatically launch given program after successful compilation + +* `-Wall` – enable all warnings + +* `--help` – list all commandline options \ No newline at end of file diff --git a/doc/tutorial/02-functions-variables.md b/doc/tutorial/02-functions-variables.md new file mode 100644 index 00000000..e590f14e --- /dev/null +++ b/doc/tutorial/02-functions-variables.md @@ -0,0 +1,16 @@ +# Functions and variables + +TODO: write all of this + +## Basic types + +## Defining variables + +## Built-in operators + +### Byte operators + +| a | a | a | +| -- | -- | -- | +| a | a | a | + diff --git a/examples/hello_world/hello_world.mfk b/examples/hello_world/hello_world.mfk new file mode 100644 index 00000000..4bbbf4df --- /dev/null +++ b/examples/hello_world/hello_world.mfk @@ -0,0 +1,11 @@ +// compile with +// java -jar millfork.jar -I ${PATH}/include -t ${platform} ${PATH}/examples/hello_world/hello_world.mfk + +import stdio + +array hello_world = "hello world" petscii + +void main(){ + putstr(hello_world, hello_world.length) + while(true){} +} \ No newline at end of file diff --git a/include/a8.ini b/include/a8.ini new file mode 100644 index 00000000..0723c075 --- /dev/null +++ b/include/a8.ini @@ -0,0 +1,22 @@ +[compilation] +arch=strict +modules=a8_kernel + + +[allocation] +main_org=$2000 +; TODO +zp_pointers=$80,$82,$84,$86,$88,$8a,$8c,$8e,$90,$92,$94,$96,$98,$9a,$9c,$9e,$a0,$a2,$a4 +;TODO +himem_style=per_bank +himem_start=after_code +;TODO +himem_end=$3FFF + +[output] +;TODO +style=per_bank +format=$FF,$FF,$E0,$02,$E1,$02,startaddr,startaddr,endaddr,allocated +extension=xex + + diff --git a/include/a8_kernel.mfk b/include/a8_kernel.mfk new file mode 100644 index 00000000..76afe512 --- /dev/null +++ b/include/a8_kernel.mfk @@ -0,0 +1,9 @@ +asm void putchar(byte a) { + tax + lda $347 + pha + lda $346 + pha + txa + rts +} \ No newline at end of file diff --git a/include/c128.ini b/include/c128.ini new file mode 100644 index 00000000..6d40a6cd --- /dev/null +++ b/include/c128.ini @@ -0,0 +1,20 @@ +[compilation] +arch=nmos +modules=c128_hardware,loader_1c01,c128_kernal + + +[allocation] +main_org=$1C0D +; TODO +zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B +himem_style=per_bank +himem_start=after_code +; TODO +himem_end=$FEFF + +[output] +style=per_bank +format=startaddr,allocated +extension=prg + + diff --git a/include/c128_hardware.mfk b/include/c128_hardware.mfk new file mode 100644 index 00000000..ba0b81dd --- /dev/null +++ b/include/c128_hardware.mfk @@ -0,0 +1,5 @@ +import c64_vic +import c64_sid +import c64_cia + +array c64_color_ram [1000] @$D800 diff --git a/include/c128_kernal.mfk b/include/c128_kernal.mfk new file mode 100644 index 00000000..ac73145d --- /dev/null +++ b/include/c128_kernal.mfk @@ -0,0 +1,5 @@ +// Routines from Commodore 128 KERNAL ROM + +// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.) +// Input: A = Byte to write. +asm void putchar(byte a) @$FFD2 extern \ No newline at end of file diff --git a/include/c1531.mfk b/include/c1531.mfk new file mode 100644 index 00000000..f42e8eba --- /dev/null +++ b/include/c1531.mfk @@ -0,0 +1,60 @@ +// mouse driver for Commodore 1531 mouse on Commodore 64 + +import mouse +import c64_hardware + +sbyte _c1531_calculate_delta (byte old, byte new) { + byte mouse_delta + mouse_delta = (new - old) + mouse_delta &= $3f + if mouse_delta >= $20 { + mouse_delta |= $c0 + } + return mouse_delta +} + +byte _c1531_handle_x() { + static byte _c1531_old_pot_x + sbyte mouse_delta + byte new_pot_x + + new_pot_x = sid_paddle_x >> 1 + mouse_delta = _c1531_calculate_delta(_c1531_old_pot_x, new_pot_x) + _c1531_old_pot_x = new_pot_x + + mouse_x += mouse_delta + mouse_x.hi &= 1 + + if mouse_x > 319 { + if mouse_delta > 0 { + mouse_x = 319 + } else { + mouse_x = 0 + } + } +} + +byte _c1531_handle_y() { + static byte _c1531_old_pot_y + byte new_pot_y + sbyte mouse_delta + + new_pot_y = sid_paddle_y >> 1 + mouse_delta = _c1531_calculate_delta(_c1531_old_pot_y, new_pot_y) + _c1531_old_pot_y = new_pot_y + mouse_y -= mouse_delta + if mouse_y > 199 { + if mouse_delta > 0 { + mouse_y = 0 + } else { + mouse_y = 199 + } + } +} + +void c1531_mouse () { + + cia1_pra = ($3f & cia1_pra) | $40 + _c1531_handle_x() + _c1531_handle_y() +} \ No newline at end of file diff --git a/include/c16.ini b/include/c16.ini new file mode 100644 index 00000000..bba71a26 --- /dev/null +++ b/include/c16.ini @@ -0,0 +1,19 @@ +[compilation] +arch=nmos +modules=loader_1001,c264_kernal,c264_hardware + + +[allocation] +main_org=$100D +; TODO +zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B +himem_style=per_bank +himem_start=after_code +himem_end=$3FFF + +[output] +style=per_bank +format=startaddr,allocated +extension=prg + + diff --git a/include/c264_hardware.mfk b/include/c264_hardware.mfk new file mode 100644 index 00000000..b83cf6e6 --- /dev/null +++ b/include/c264_hardware.mfk @@ -0,0 +1 @@ +import c16_ted \ No newline at end of file diff --git a/include/c264_kernal.mfk b/include/c264_kernal.mfk new file mode 100644 index 00000000..2f1ae990 --- /dev/null +++ b/include/c264_kernal.mfk @@ -0,0 +1,5 @@ +// Routines from C16 and Plus/4 KERNAL ROM + +// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.) +// Input: A = Byte to write. +asm void putchar(byte a) @$FFD2 extern \ No newline at end of file diff --git a/include/c264_ted.mfk b/include/c264_ted.mfk new file mode 100644 index 00000000..580b08f9 --- /dev/null +++ b/include/c264_ted.mfk @@ -0,0 +1,20 @@ + +const byte black = 0 +const byte white = $71 +const byte red = $22 +const byte cyan = $43 +const byte purple = $24 +const byte green = $35 +const byte blue = $16 +const byte yellow = $57 +const byte orange = $28 +const byte brown = $19 +const byte light_red = $32 +const byte dark_grey = $21 +const byte dark_gray = $21 +const byte medium_grey = $31 +const byte medium gray = $31 +const byte light_green = $55 +const byte light_blue = $36 +const byte light_grey = $41 +const byte light_gray = $41 \ No newline at end of file diff --git a/include/c64.ini b/include/c64.ini new file mode 100644 index 00000000..dcaf4d21 --- /dev/null +++ b/include/c64.ini @@ -0,0 +1,37 @@ +; Commodore 64 +; assuming a program loaded from disk or tape + +[compilation] +; CPU architecture: nmos, strictnmos, ricoh, strictricoh, cmos +arch=nmos +; modules to load +modules=c64_hardware,loader_0801,c64_kernal,stdlib +; optionally: default flags +emit_illegals=true + + +[allocation] +; where the main function should be allocated, also the start of bank 0 +main_org=$80D +; list of free zp pointer locations (these assume that BASIC will keep working) +zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B +; where to allocate non-zp variables +himem_style=per_bank +himem_start=after_code +himem_end=$9FFF + +[output] +; how the banks are laid out in the output files; so far, there is no bank support in the compiler yet +style=per_bank +; output file format +; startaddr - little-endian address of the first used byte in the bank +; endaddr - little-endian address of the last used byte in the bank +; allocated - all used bytes in the bank +; : - bytes from the current bank +; :addr>: - bytes from arbitrary bank +; - single byte +format=startaddr,allocated +; default output file extension +extension=prg + + diff --git a/include/c64_basic.mfk b/include/c64_basic.mfk new file mode 100644 index 00000000..2e455c03 --- /dev/null +++ b/include/c64_basic.mfk @@ -0,0 +1,6 @@ +// Routines from C64 BASIC ROM + +import c64_kernal + +// print a 16-bit number on the standard output +asm void putword(word xa) @$BDCD extern \ No newline at end of file diff --git a/include/c64_cia.mfk b/include/c64_cia.mfk new file mode 100644 index 00000000..f077b672 --- /dev/null +++ b/include/c64_cia.mfk @@ -0,0 +1,40 @@ +// Hardware addresses for C64 + +// CIA1 +byte cia1_pra @$DC00 +byte cia1_prb @$DC01 +byte cia1_ddra @$DC02 +byte cia1_ddrb @$DC03 +byte cia2_pra @$DD00 +byte cia2_prb @$DD01 +byte cia2_ddra @$DD02 +byte cia2_ddrb @$DD03 + +inline asm void cia_disable_irq() { + LDA #$7f + LDA $dc0d + LDA $dd0d + LDA $dc0d + LDA $dd0d +} + + +inline void vic_bank_0000() { + cia2_ddra = $C0 + cia2_pra = $C0 +} + +inline void vic_bank_4000() { + cia2_ddra = $C0 + cia2_pra = $80 +} + +inline void vic_bank_8000() { + cia2_ddra = $C0 + cia2_pra = $40 +} + +inline void vic_bank_C000() { + cia2_ddra = $C0 + cia2_pra = $00 +} \ No newline at end of file diff --git a/include/c64_hardware.mfk b/include/c64_hardware.mfk new file mode 100644 index 00000000..0347eb68 --- /dev/null +++ b/include/c64_hardware.mfk @@ -0,0 +1,41 @@ +import c64_vic +import c64_sid +import c64_cia +import cpu6510 + +array c64_color_ram [1000] @$D800 + +inline void c64_ram_only() { + cpu6510_ddr = 7 + cpu6510_port = 0 +} + +inline void c64_ram_io() { + cpu6510_ddr = 7 + cpu6510_port = 5 +} + +inline void c64_ram_io_kernal() { + cpu6510_ddr = 7 + cpu6510_port = 6 +} + +inline void c64_ram_io_basic() { + cpu6510_ddr = 7 + cpu6510_port = 7 +} + +inline void c64_ram_charset() { + cpu6510_ddr = 7 + cpu6510_port = 1 +} + +inline void c64_ram_charset_kernal() { + cpu6510_ddr = 7 + cpu6510_port = 2 +} + +inline void c64_ram_charset_basic() { + cpu6510_ddr = 7 + cpu6510_port = 3 +} \ No newline at end of file diff --git a/include/c64_kernal.mfk b/include/c64_kernal.mfk new file mode 100644 index 00000000..ed69eb0b --- /dev/null +++ b/include/c64_kernal.mfk @@ -0,0 +1,32 @@ +// Routines from C64 KERNAL ROM + +// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.) +// Input: A = Byte to write. +asm void putchar(byte a) @$FFD2 extern + +// OPEN. Open file. (Must call SETLFS and SETNAM beforehands.) +asm void open() @$FFC0 extern + +// CLOSE. Close file. +// Input: A = Logical number. +asm void close(byte a) @$FFC0 extern + +// SETLFS. Set file parameters. +// Input: A = Logical number; X = Device number; Y = Secondary address. +asm void setlfs(byte a, byte x, byte y) @$FFBA extern + +// SETNAM. Set file name parameters. +// Input: A = File name length; X/Y = Pointer to file name. +asm void setnam(word yx, byte a) @$FFBA extern + +// LOAD. Load or verify file. (Must call SETLFS and SETNAM beforehands.) +// Input: A: 0 = Load, 1-255 = Verify; X/Y = Load address (if secondary address = 0). +// Output: Carry: 0 = No errors, 1 = Error; A = KERNAL error code (if Carry = 1); X/Y = Address of last byte loaded/verified (if Carry = 0). +asm clear_carry load(byte a, word yx) @$FFD5 extern + +// SAVE. Save file. (Must call SETLFS and SETNAM beforehands.) +// Input: A = Address of zero page register holding start address of memory area to save; X/Y = End address of memory area plus 1. +// Output: Carry: 0 = No errors, 1 = Error; A = KERNAL error code (if Carry = 1). +asm clear_carry save(byte a, word yx) @$FFD5 extern + +word irq_pointer @$314 \ No newline at end of file diff --git a/include/c64_sid.mfk b/include/c64_sid.mfk new file mode 100644 index 00000000..79cd772b --- /dev/null +++ b/include/c64_sid.mfk @@ -0,0 +1,24 @@ +// Hardware addresses for C64 + +// SID + +word sid_v1_freq @$D400 +word sid_v1_pulse @$D402 +byte sid_v1_cr @$D404 +byte sid_v1_ad @$D405 +byte sid_v1_sr @$D409 + +word sid_v2_freq @$D407 +word sid_v2_pulse @$D409 +byte sid_v2_cr @$D40B +byte sid_v2_ad @$D40C +byte sid_v2_sr @$D40D + +word sid_v3_freq @$D40E +word sid_v3_pulse @$D410 +byte sid_v3_cr @$D412 +byte sid_v3_ad @$D413 +byte sid_v3_sr @$D414 + +byte sid_paddle_x @$D419 +byte sid_paddle_y @$D41A diff --git a/include/c64_vic.mfk b/include/c64_vic.mfk new file mode 100644 index 00000000..25f33259 --- /dev/null +++ b/include/c64_vic.mfk @@ -0,0 +1,154 @@ +// Hardware addresses for C64 + +// VIC-II +byte vic_spr0_x @$D000 +byte vic_spr0_y @$D001 +byte vic_spr1_x @$D002 +byte vic_spr1_y @$D003 +byte vic_spr2_x @$D004 +byte vic_spr2_y @$D005 +byte vic_spr3_x @$D006 +byte vic_spr3_y @$D007 +byte vic_spr4_x @$D008 +byte vic_spr4_y @$D009 +byte vic_spr5_x @$D00A +byte vic_spr5_y @$D00B +byte vic_spr6_x @$D00C +byte vic_spr6_y @$D00D +byte vic_spr7_x @$D00E +byte vic_spr7_y @$D00F +byte vic_spr_hi_x @$D010 +byte vic_cr1 @$D011 +byte vic_raster @$D012 +byte vic_lp_x @$D013 +byte vic_lp_y @$D014 +byte vic_spr_ena @$D015 +byte vic_cr2 @$D016 +byte vic_spr_exp_y @$D017 +byte vic_mem @$D018 +byte vic_irq @$D019 +byte vic_irq_ena @$D01A +byte vic_spr_dp @$D01B +byte vic_spr_mcolor @$D01C +byte vic_spr_exp_x @$D01D +byte vic_spr_ss_col @$D01E +byte vic_spr_sd_col @$D01F +byte vic_border @$D020 +byte vic_bg_color0 @$D021 +byte vic_bg_color1 @$D022 +byte vic_bg_color2 @$D023 +byte vic_bg_color3 @$D024 +byte vic_spr_color1 @$D025 +byte vic_spr_color2 @$D026 +byte vic_spr0_color @$D027 +byte vic_spr1_color @$D028 +byte vic_spr2_color @$D029 +byte vic_spr3_color @$D02A +byte vic_spr4_color @$D02B +byte vic_spr5_color @$D02C +byte vic_spr6_color @$D02D +byte vic_spr7_color @$D02E + +array vic_spr_coord [16] @$D000 +array vic_spr_color [8] @$D027 + +inline void vic_enable_multicolor() { + vic_cr2 |= 0x10 +} + +inline void vic_disable_multicolor() { + vic_cr2 &= 0xEF +} + +inline void vic_enable_bitmap() { + vic_cr1 |= 0x20 +} + +inline void vic_disable_bitmap() { + vic_cr1 &= 0xDF +} + +inline void vic_24_rows() { + vic_cr1 &= 0xF7 +} + +inline void vic_25_rows() { + vic_cr1 |= 8 +} + +inline void vic_38_columns() { + vic_cr2 &= 0xF7 +} + +inline void vic_40_columns() { + vic_cr2 |= 8 +} + +inline void vic_disable_irq() { + vic_irq_ena = 0 + vic_irq += 1 +} + +// base: divisible by $400, $0000-$3C00 allowed +//inline void vic_screen(word const base) { +// vic_mem = (vic_mem & $0F) | (base >> 6) +//} + +inline void vic_charset_0000() { + vic_mem = (vic_mem & $F1) +} +inline void vic_charset_0800() { + vic_mem = (vic_mem & $F1) | 2 +} +inline void vic_charset_1000() { + vic_mem = (vic_mem & $F1) | 4 +} +inline void vic_charset_1800() { + vic_mem = (vic_mem & $F1) | 6 +} +inline void vic_charset_2000() { + vic_mem = (vic_mem & $F1) | 8 +} +inline void vic_charset_2800() { + vic_mem = (vic_mem & $F1) | $A +} +inline void vic_charset_3000() { + vic_mem = (vic_mem & $F1) | $C +} +inline void vic_charset_3800() { + vic_mem = (vic_mem & $F1) | $E +} + +inline void vic_bitmap_0000() { + vic_mem &= $F7 +} +inline void vic_bitmap_2000() { + vic_mem |= 8 +} + +// x, y < 8 +// default: x=0, y=3 +void vic_set_scroll(byte x, byte y) { + vic_cr1 = (vic_cr1 & $F8) | y + vic_cr2 = (vic_cr2 & $F8) | x +} + +const byte black = 0 +const byte white = 1 +const byte red = 2 +const byte cyan = 3 +const byte purple = 4 +const byte green = 5 +const byte blue = 6 +const byte yellow = 7 +const byte orange = 8 +const byte brown = 9 +const byte light_red = 10 +const byte dark_grey = 11 +const byte dark_gray = 11 +const byte medium_grey = 12 +const byte medium_gray = 12 +const byte light_green = 13 +const byte light_blue = 14 +const byte light_grey = 15 +const byte light_gray = 15 \ No newline at end of file diff --git a/include/cpu6510.mfk b/include/cpu6510.mfk new file mode 100644 index 00000000..90a674ca --- /dev/null +++ b/include/cpu6510.mfk @@ -0,0 +1,3 @@ + +byte cpu6510_ddr @0 +byte cpu6510_port @1 diff --git a/include/loader_0401.mfk b/include/loader_0401.mfk new file mode 100644 index 00000000..0a748c51 --- /dev/null +++ b/include/loader_0401.mfk @@ -0,0 +1,15 @@ +array _basic_loader @$401 = [ + $0b, + 4, + 10, + 0, + $9e, + $31, + $30, + $33, + $37, + 0, + 0, + 0 + ] + diff --git a/include/loader_0801.mfk b/include/loader_0801.mfk new file mode 100644 index 00000000..796b8b2f --- /dev/null +++ b/include/loader_0801.mfk @@ -0,0 +1,15 @@ +array _basic_loader @$801 = [ + $0b, + $08, + 10, + 0, + $9e, + $32, + $30, + $36, + $31, + 0, + 0, + 0 + ] + diff --git a/include/loader_1001.mfk b/include/loader_1001.mfk new file mode 100644 index 00000000..b16d858c --- /dev/null +++ b/include/loader_1001.mfk @@ -0,0 +1,15 @@ +array _basic_loader @$1001 = [ + $0b, + $10, + 10, + 0, + $9e, + $34, + $31, + $30, + $39, + 0, + 0, + 0 + ] + diff --git a/include/loader_1201.mfk b/include/loader_1201.mfk new file mode 100644 index 00000000..144a3e90 --- /dev/null +++ b/include/loader_1201.mfk @@ -0,0 +1,15 @@ +array _basic_loader @$1201 = [ + $0b, + $12, + 10, + 0, + $9e, + $34, + $36, + $32, + $31, + 0, + 0, + 0 + ] + diff --git a/include/loader_1c01.mfk b/include/loader_1c01.mfk new file mode 100644 index 00000000..73962381 --- /dev/null +++ b/include/loader_1c01.mfk @@ -0,0 +1,15 @@ +array _basic_loader @$1C01 = [ + $0b, + $1C, + 10, + 0, + $9e, + $37, + $31, + $38, + $31, + 0, + 0, + 0 + ] + diff --git a/include/mouse.mfk b/include/mouse.mfk new file mode 100644 index 00000000..f4a65210 --- /dev/null +++ b/include/mouse.mfk @@ -0,0 +1,8 @@ +// Generic module for mouse support +// Resolutions up to 512x256 are supported + + +// Mouse X coordinate +word mouse_x +// Mouse Y coordinate +byte mouse_y diff --git a/include/pet.ini b/include/pet.ini new file mode 100644 index 00000000..ca975341 --- /dev/null +++ b/include/pet.ini @@ -0,0 +1,19 @@ +[compilation] +arch=nmos +modules=loader_0401,pet_kernal + + +[allocation] +main_org=$40D +; TODO +zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B +himem_style=per_bank +himem_start=after_code +himem_end=$FFF + +[output] +style=per_bank +format=startaddr,allocated +extension=prg + + diff --git a/include/pet_kernal.mfk b/include/pet_kernal.mfk new file mode 100644 index 00000000..31d7daf8 --- /dev/null +++ b/include/pet_kernal.mfk @@ -0,0 +1,5 @@ +// Routines from Commodore PET KERNAL ROM + +// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.) +// Input: A = Byte to write. +asm void putchar(byte a) @$FFD2 extern \ No newline at end of file diff --git a/include/plus4.ini b/include/plus4.ini new file mode 100644 index 00000000..5b12893d --- /dev/null +++ b/include/plus4.ini @@ -0,0 +1,19 @@ +[compilation] +arch=nmos +modules=c264_loader,c264_kernal,c264_hardware + + +[allocation] +main_org=$100D +; TODO +zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B +himem_style=per_bank +himem_start=after_code +himem_end=$7FFF + +[output] +style=per_bank +format=startaddr,allocated +extension=prg + + diff --git a/include/stdio.mfk b/include/stdio.mfk new file mode 100644 index 00000000..16161f2d --- /dev/null +++ b/include/stdio.mfk @@ -0,0 +1,10 @@ +// target-independent standard I/O routines + +void putstr(pointer str, byte len) { + byte index + index = 0 + while (index != len) { + putchar(str[index]) + index += 1 + } +} \ No newline at end of file diff --git a/include/stdlib.mfk b/include/stdlib.mfk new file mode 100644 index 00000000..0dda8f5c --- /dev/null +++ b/include/stdlib.mfk @@ -0,0 +1,23 @@ +// target-independent things + +word nmi_routine_addr @$FFFA +word reset_routine_addr @$FFFC +word irq_routine_addr @$FFFE + +inline asm void poke(word const addr, byte const value) { + ?LDA #value + STA addr +} + +inline asm byte peek(word const addr) { + LDA addr +} + +inline asm void disable_irq() { + SEI +} + +inline asm void enable_irq() { + CLI +} + diff --git a/include/vic20.ini b/include/vic20.ini new file mode 100644 index 00000000..6c33a727 --- /dev/null +++ b/include/vic20.ini @@ -0,0 +1,19 @@ +[compilation] +arch=nmos +modules=loader_1001,vic20_kernal + + +[allocation] +main_org=$100D +; TODO +zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B +himem_style=per_bank +himem_start=after_code +himem_end=$1CFF + +[output] +style=per_bank +format=startaddr,allocated +extension=prg + + diff --git a/include/vic20_3k.ini b/include/vic20_3k.ini new file mode 100644 index 00000000..10088a92 --- /dev/null +++ b/include/vic20_3k.ini @@ -0,0 +1,19 @@ +[compilation] +arch=nmos +modules=loader_0401,vic20_kernal + + +[allocation] +main_org=$40D +; TODO +zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B +himem_style=per_bank +himem_start=after_code +himem_end=$1CFF + +[output] +style=per_bank +format=startaddr,allocated +extension=prg + + diff --git a/include/vic20_8k.ini b/include/vic20_8k.ini new file mode 100644 index 00000000..d7acfa58 --- /dev/null +++ b/include/vic20_8k.ini @@ -0,0 +1,19 @@ +[compilation] +arch=nmos +modules=loader_1201,vic20_kernal + + +[allocation] +main_org=$120D +; TODO +zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B +himem_style=per_bank +himem_start=after_code +himem_end=$1FFF + +[output] +style=per_bank +format=startaddr,allocated +extension=prg + + diff --git a/include/vic20_kernal.mfk b/include/vic20_kernal.mfk new file mode 100644 index 00000000..2f1ae990 --- /dev/null +++ b/include/vic20_kernal.mfk @@ -0,0 +1,5 @@ +// Routines from C16 and Plus/4 KERNAL ROM + +// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.) +// Input: A = Byte to write. +asm void putchar(byte a) @$FFD2 extern \ No newline at end of file diff --git a/project/assembly.sbt b/project/assembly.sbt new file mode 100644 index 00000000..cdb3c0bb --- /dev/null +++ b/project/assembly.sbt @@ -0,0 +1,2 @@ + +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.6") diff --git a/project/build.properties b/project/build.properties new file mode 100644 index 00000000..826c0bd9 --- /dev/null +++ b/project/build.properties @@ -0,0 +1 @@ +sbt.version = 0.13.16 \ No newline at end of file diff --git a/project/buildinfo.sbt b/project/buildinfo.sbt new file mode 100644 index 00000000..42c669b1 --- /dev/null +++ b/project/buildinfo.sbt @@ -0,0 +1 @@ +addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.7.0") \ No newline at end of file diff --git a/project/plugins.sbt b/project/plugins.sbt new file mode 100644 index 00000000..e69de29b diff --git a/src/main/scala/millfork/CompilationOptions.scala b/src/main/scala/millfork/CompilationOptions.scala new file mode 100644 index 00000000..1e7db997 --- /dev/null +++ b/src/main/scala/millfork/CompilationOptions.scala @@ -0,0 +1,116 @@ +package millfork + +import millfork.error.ErrorReporting + +/** + * @author Karol Stasiak + */ +// +//object CompilationOptions { +// +// +// private var instance = new CompilationOptions(Platform.C64, Map()) +// +// // TODO: ugly! +// def change(o: CompilationOptions): Unit = { +// instance = o +// } +// +// def current: CompilationOptions= instance +// +// def platform: Platform = instance.platform +// +// def flag(flag: CompilationFlag.Value):Boolean = instance.flags(flag) +// +// def flags: Map[CompilationFlag.Value, Boolean] = instance.flags +//} +class CompilationOptions(val platform: Platform, val commandLineFlags: Map[CompilationFlag.Value, Boolean]) { + + import CompilationFlag._ + import Cpu._ + + val flags: Map[CompilationFlag.Value, Boolean] = CompilationFlag.values.map { f => + f -> commandLineFlags.getOrElse(f, platform.flagOverrides.getOrElse(f, Cpu.defaultFlags(platform.cpu)(f))) + }.toMap + + def flag(f: CompilationFlag.Value) = flags(f) + + if (flags(DecimalMode)) { + if (platform.cpu == Ricoh || platform.cpu == StrictRicoh) { + ErrorReporting.warn("Decimal mode enabled for Ricoh architecture", this) + } + } + if (platform.cpu != Cmos) { + if (!flags(PreventJmpIndirectBug)) { + ErrorReporting.warn("JMP bug prevention should be enabled for non-CMOS architecture", this) + } + if (flags(EmitCmosOpcodes)) { + ErrorReporting.warn("CMOS opcodes enabled for non-CMOS architecture", this) + } + } + if (flags(EmitIllegals)) { + if (platform.cpu == Cmos) { + ErrorReporting.warn("Illegal opcodes enabled for CMOS architecture", this) + } + if (platform.cpu == StrictRicoh || platform.cpu == Ricoh) { + ErrorReporting.warn("Illegal opcodes enabled for strict architecture", this) + } + } +} + +object Cpu extends Enumeration { + + val Mos, StrictMos, Ricoh, StrictRicoh, Cmos = Value + + import CompilationFlag._ + + def defaultFlags(x: Cpu.Value): Set[CompilationFlag.Value] = x match { + case StrictMos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap) + case Mos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap) + case Ricoh => Set(PreventJmpIndirectBug, VariableOverlap) + case StrictRicoh => Set(PreventJmpIndirectBug, VariableOverlap) + case Cmos => Set(EmitCmosOpcodes, VariableOverlap) + } + + def fromString(name: String): Cpu.Value = name match { + case "nmos" => Mos + case "6502" => Mos + case "6510" => Mos + case "strict" => StrictMos + case "cmos" => Cmos + case "65c02" => Cmos + case "ricoh" => Ricoh + case "2a03" => Ricoh + case "2a07" => Ricoh + case "strictricoh" => StrictRicoh + case "strict2a03" => StrictRicoh + case "strict2a07" => StrictRicoh + case _ => ErrorReporting.fatal("Unknown CPU achitecture") + } +} + +object CompilationFlag extends Enumeration { + val + // compilation options: + EmitIllegals, EmitCmosOpcodes, DecimalMode, ReadOnlyArrays, PreventJmpIndirectBug, + // optimization options: + DetailedFlowAnalysis, DangerousOptimizations, + // memory allocation options + VariableOverlap, + // warning options + ExtraComparisonWarnings, + RorWarning, + FatalWarnings = Value + + val allWarnings: Set[CompilationFlag.Value] = Set(ExtraComparisonWarnings) + + val fromString = Map( + "emit_illegals" -> EmitIllegals, + "emit_cmos" -> EmitCmosOpcodes, + "decimal_mode" -> DecimalMode, + "ro_arrays" -> ReadOnlyArrays, + "ror_warn" -> RorWarning, + "prevent_jmp_indirect_bug" -> PreventJmpIndirectBug, + ) + +} \ No newline at end of file diff --git a/src/main/scala/millfork/Main.scala b/src/main/scala/millfork/Main.scala new file mode 100644 index 00000000..37a9cc99 --- /dev/null +++ b/src/main/scala/millfork/Main.scala @@ -0,0 +1,264 @@ +package millfork + +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} +import java.util.Locale + +import millfork.assembly.opt.{CmosOptimizations, DangerousOptimizations, SuperOptimizer, UndocumentedOptimizations} +import millfork.cli.{CliParser, CliStatus} +import millfork.env.Environment +import millfork.error.ErrorReporting +import millfork.node.StandardCallGraph +import millfork.output.Assembler +import millfork.parser.SourceLoadingQueue + +/** + * @author Karol Stasiak + */ + +case class Context(inputFileNames: List[String], + outputFileName: Option[String] = None, + runFileName: Option[String] = None, + optimizationLevel: Option[Int] = None, + platform: Option[String] = None, + outputAssembly: Boolean = false, + outputLabels: Boolean = false, + includePath: List[String] = Nil, + flags: Map[CompilationFlag.Value, Boolean] = Map(), + verbosity: Option[Int] = None) { + def changeFlag(f: CompilationFlag.Value, b: Boolean): Context = { + if (flags.contains(f)) { + if (flags(f) != b) { + ErrorReporting.error("Conflicting flags") + } + this + } else { + copy(flags = this.flags + (f -> b)) + } + } +} + +object Main { + + + def main(args: Array[String]): Unit = { + if (args.isEmpty) { + ErrorReporting.info("For help, use --help") + } + val (status, c) = parser.parse(Context(Nil), args.toList) + status match { + case CliStatus.Quit => return + case CliStatus.Failed => + ErrorReporting.fatalQuit("Invalid command line") + case CliStatus.Ok => () + } + ErrorReporting.assertNoErrors("Invalid command line") + if (c.inputFileNames.isEmpty) { + ErrorReporting.fatalQuit("No input files") + } + ErrorReporting.verbosity = c.verbosity.getOrElse(0) + val optLevel = c.optimizationLevel.getOrElse(0) + val platform = Platform.lookupPlatformFile(c.includePath, c.platform.getOrElse { + ErrorReporting.info("No platform selected, defaulting to `c64`") + "c64" + }) + val options = new CompilationOptions(platform, c.flags) + ErrorReporting.debug("Effective flags: " + options.flags) + + val output = c.outputFileName.getOrElse("a") + val assOutput = output + ".asm" + val labelOutput = output + ".lbl" + val prgOutput = if (!output.endsWith(platform.fileExtension)) output + platform.fileExtension else output + + val unoptimized = new SourceLoadingQueue( + initialFilenames = c.inputFileNames, + includePath = c.includePath, + options = options).run() + + val program = if (optLevel > 0) { + OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt)) + } else { + unoptimized + } + val callGraph = new StandardCallGraph(program) + + val env = new Environment(None, "") + env.collectDeclarations(program, options) + val extras = List( + if (options.flag(CompilationFlag.EmitIllegals)) UndocumentedOptimizations.All else Nil, + if (options.flag(CompilationFlag.EmitCmosOpcodes)) CmosOptimizations.All else Nil, + if (options.flag(CompilationFlag.DangerousOptimizations)) DangerousOptimizations.All else Nil, + ).flatten + val goodCycle = List.fill(optLevel - 1)(OptimizationPresets.Good ++ extras).flatten + val assemblyOptimizations = if (optLevel <= 0) Nil else if (optLevel >= 9) List(SuperOptimizer) else { + goodCycle ++ OptimizationPresets.AssOpt ++ extras ++ goodCycle + } + + // compile + val assembler = new Assembler(env) + val result = assembler.assemble(callGraph, assemblyOptimizations, options) + ErrorReporting.assertNoErrors("Codegen failed") + ErrorReporting.debug(f"Unoptimized code size: ${assembler.unoptimizedCodeSize}%5d B") + ErrorReporting.debug(f"Optimized code size: ${assembler.optimizedCodeSize}%5d B") + ErrorReporting.debug(f"Gain: ${(100L * (assembler.unoptimizedCodeSize - assembler.optimizedCodeSize) / assembler.unoptimizedCodeSize.toDouble).round}%5d%%") + ErrorReporting.debug(f"Initialized arrays: ${assembler.initializedArraysSize}%5d B") + + if (c.outputAssembly) { + val path = Paths.get(assOutput) + ErrorReporting.debug("Writing assembly to " + path.toAbsolutePath) + Files.write(path, result.asm.mkString("\n").getBytes(StandardCharsets.UTF_8)) + } + if (c.outputLabels) { + val path = Paths.get(labelOutput) + ErrorReporting.debug("Writing labels to " + path.toAbsolutePath) + Files.write(path, result.labels.sortWith { (a, b) => + val aLocal = a._1.head == '.' + val bLocal = b._1.head == '.' + if (aLocal == bLocal) a._1 < b._1 + else b._1 < a._1 + }.groupBy(_._2).values.map(_.head).toSeq.sortBy(_._2).map { case (l, a) => + val normalized = l.replace('$', '_').replace('.', '_') + s"al ${a.toHexString} .$normalized" + }.mkString("\n").getBytes(StandardCharsets.UTF_8)) + } + val path = Paths.get(prgOutput) + ErrorReporting.debug("Writing output to " + path.toAbsolutePath) + Files.write(path, result.code) + c.runFileName.foreach(program => + new ProcessBuilder(program, path.toAbsolutePath.toString).start() + ) + } + + private def parser = new CliParser[Context] { + + fluff("Main options:", "") + + parameter("-o", "--out").required().placeholder("").action { (p, c) => + assertNone(c.outputFileName, "Output already defined") + c.copy(outputFileName = Some(p)) + }.description("The output file name, without extension.").onWrongNumber(_ => ErrorReporting.fatalQuit("No output file specified")) + + flag("-s").action { c => + c.copy(outputAssembly = true) + }.description("Generate also the assembly output.") + + flag("-g").action { c => + c.copy(outputLabels = true) + }.description("Generate also the label file.") + + parameter("-t", "--target").placeholder("").action { (p, c) => + assertNone(c.platform, "Platform already defined") + c.copy(platform = Some(p)) + }.description("Target platform, any of: c64, c16, plus4, vic20, vic20_3k, vic20_8k, pet, c128, a8.") + + parameter("-I", "--include-dir").repeatable().placeholder(";;...").action { (paths, c) => + val n = paths.split(";") + c.copy(includePath = c.includePath ++ n) + }.description("Include paths for modules.") + + parameter("-r", "--run").placeholder("").action { (p, c) => + assertNone(c.runFileName, "Run program already defined") + c.copy(runFileName = Some(p)) + }.description("Program to run after successful compilation.") + + endOfFlags("--").description("Marks the end of options.") + + fluff("", "Verbosity options:", "") + + flag("-q", "--quiet").action { c => + assertNone(c.verbosity, "Cannot use -v and -q together") + c.copy(verbosity = Some(-1)) + }.description("Supress all messages except for errors.") + + private val verbose = flag("-v", "--verbose").maxCount(3).action { c => + if (c.verbosity.exists(_ < 0)) ErrorReporting.error("Cannot use -v and -q together", None) + c.copy(verbosity = Some(1 + c.verbosity.getOrElse(0))) + }.description("Increase verbosity.") + flag("-vv").repeatable().action(c => verbose.encounter(verbose.encounter(verbose.encounter(c)))).description("Increase verbosity even more.") + flag("-vvv").repeatable().action(c => verbose.encounter(verbose.encounter(c))).description("Increase verbosity even more.") + + fluff("", "Code generation options:", "") + + boolean("-fcmos-ops", "-fno-cmos-ops").action { (c, v) => + c.changeFlag(CompilationFlag.EmitCmosOpcodes, v) + }.description("Whether should emit CMOS opcodes.") + boolean("-fillegals", "-fno-illegals").action { (c, v) => + c.changeFlag(CompilationFlag.EmitIllegals, v) + }.description("Whether should emit illegal (undocumented) NMOS opcodes.") + boolean("-fjmp-fix", "-fno-jmp-fix").action { (c, v) => + c.changeFlag(CompilationFlag.PreventJmpIndirectBug, v) + }.description("Whether should prevent indirect JMP bug on page boundary.") + boolean("-fdecimal-mode", "-fno-decimal-mode").action { (c, v) => + c.changeFlag(CompilationFlag.DecimalMode, v) + }.description("Whether should decimal mode be available.") + boolean("-fvariable-overlap", "-fno-variable-overlap").action { (c, v) => + c.changeFlag(CompilationFlag.VariableOverlap, v) + }.description("Whether should variables overlap if their scopes do not intersect.") + + fluff("", "Optimization options:", "") + + + flag("-O0").action { c => + assertNone(c.optimizationLevel, "Optimization level already defined") + c.copy(optimizationLevel = Some(0)) + }.description("Disable all optimizations.") + flag("-O").action { c => + assertNone(c.optimizationLevel, "Optimization level already defined") + c.copy(optimizationLevel = Some(1)) + }.description("Optimize code.") + for (i <- 2 to 9) { + val f = flag("-O" + i).action { c => + assertNone(c.optimizationLevel, "Optimization level already defined") + c.copy(optimizationLevel = Some(i)) + }.description("Optimize code even more.") + if (i > 3) f.hidden() + } + flag("--detailed-flow").action { c => + c.changeFlag(CompilationFlag.DetailedFlowAnalysis, true) + }.description("Use detailed flow analysis (experimental).") + flag("--dangerous-optimizations").action { c => + c.changeFlag(CompilationFlag.DangerousOptimizations, true) + }.description("Use dangerous optimizations (experimental).") + + fluff("", "Warning options:", "") + + flag("-Wall", "--Wall").action { c => + CompilationFlag.allWarnings.foldLeft(c) { (c, f) => c.changeFlag(f, true) } + }.description("Enable extra warnings.") + + flag("-Wfatal", "--Wfatal").action { c => + c.changeFlag(CompilationFlag.FatalWarnings, true) + }.description("Treat warnings as errors.") + + fluff("", "Other options:", "") + + flag("--help").action(c => { + printHelp(20).foreach(println(_)) + assumeStatus(CliStatus.Quit) + c + }).description("Display this message.") + + flag("--version").action(c => { + println("millfork version ") + assumeStatus(CliStatus.Quit) + System.exit(0) + c + }).description("Print the version and quit.") + + + default.action { (p, c) => + if (p.startsWith("-")) { + ErrorReporting.error(s"Invalid option `$p`", None) + c + } else { + c.copy(inputFileNames = c.inputFileNames :+ p) + } + } + + def assertNone[T](value: Option[T], msg: String): Unit = { + if (value.isDefined) { + ErrorReporting.error(msg, None) + } + } + } +} diff --git a/src/main/scala/millfork/OptimizationPresets.scala b/src/main/scala/millfork/OptimizationPresets.scala new file mode 100644 index 00000000..7b160e24 --- /dev/null +++ b/src/main/scala/millfork/OptimizationPresets.scala @@ -0,0 +1,150 @@ +package millfork + +import millfork.assembly.opt._ +import millfork.node.opt.{UnreachableCode, UnusedFunctions, UnusedGlobalVariables, UnusedLocalVariables} + +/** + * @author Karol Stasiak + */ +object OptimizationPresets { + val NodeOpt = List( + UnreachableCode, + UnusedFunctions, + UnusedLocalVariables, + UnusedGlobalVariables, + ) + val AssOpt: List[AssemblyOptimization] = List[AssemblyOptimization]( + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + AlwaysGoodOptimizations.PointlessLoadAfterLoadOrStore, + LaterOptimizations.PointessLoadingForShifting, + AlwaysGoodOptimizations.SimplifiableBitOpsSequence, + AlwaysGoodOptimizations.IdempotentDuplicateRemoval, + AlwaysGoodOptimizations.BranchInPlaceRemoval, + UnusedLabelRemoval, + AlwaysGoodOptimizations.UnconditionalJumpRemoval, + UnusedLabelRemoval, + AlwaysGoodOptimizations.RearrangeMath, + LaterOptimizations.PointlessLoadAfterStore, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + AlwaysGoodOptimizations.PointlessOperationAfterLoad, + AlwaysGoodOptimizations.PointlessLoadBeforeTransfer, + VariableToRegisterOptimization, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + AlwaysGoodOptimizations.PointlessOperationPairRemoval, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + LaterOptimizations.PointlessLoadAfterStore, + AlwaysGoodOptimizations.PointlessOperationAfterLoad, + AlwaysGoodOptimizations.IdempotentDuplicateRemoval, + AlwaysGoodOptimizations.ConstantIndexPropagation, + AlwaysGoodOptimizations.PointlessLoadBeforeReturn, + AlwaysGoodOptimizations.PoinlessFlagChange, + AlwaysGoodOptimizations.FlagFlowAnalysis, + AlwaysGoodOptimizations.ConstantFlowAnalysis, + AlwaysGoodOptimizations.PointlessMath, + VariableToRegisterOptimization, + ChangeIndexRegisterOptimizationPreferringX2Y, + VariableToRegisterOptimization, + ChangeIndexRegisterOptimizationPreferringY2X, + VariableToRegisterOptimization, + AlwaysGoodOptimizations.ConstantFlowAnalysis, + LaterOptimizations.DoubleLoadToDifferentRegisters, + LaterOptimizations.DoubleLoadToTheSameRegister, + LaterOptimizations.DoubleLoadToDifferentRegisters, + LaterOptimizations.DoubleLoadToTheSameRegister, + LaterOptimizations.DoubleLoadToDifferentRegisters, + LaterOptimizations.DoubleLoadToTheSameRegister, + AlwaysGoodOptimizations.PointlessStoreAfterLoad, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + AlwaysGoodOptimizations.IdempotentDuplicateRemoval, + AlwaysGoodOptimizations.ConstantIndexPropagation, + AlwaysGoodOptimizations.ConstantFlowAnalysis, + AlwaysGoodOptimizations.PointlessRegisterTransfers, + AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeCompare, + AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn, + AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeStore, + AlwaysGoodOptimizations.PointlessStashingToIndexOverShortSafeBranch, + AlwaysGoodOptimizations.RearrangeMath, + AlwaysGoodOptimizations.PointlessStoreAfterLoad, + AlwaysGoodOptimizations.PointlessLoadBeforeReturn, + LaterOptimizations.PointessLoadingForShifting, + AlwaysGoodOptimizations.SimplifiableBitOpsSequence, + AlwaysGoodOptimizations.SimplifiableBitOpsSequence, + AlwaysGoodOptimizations.SimplifiableBitOpsSequence, + AlwaysGoodOptimizations.SimplifiableBitOpsSequence, + + LaterOptimizations.LoadingAfterShifting, + AlwaysGoodOptimizations.PointlessStoreAfterLoad, + AlwaysGoodOptimizations.PoinlessStoreBeforeStore, + LaterOptimizations.PointlessLoadAfterStore, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + + LaterOptimizations.LoadingAfterShifting, + AlwaysGoodOptimizations.PointlessStoreAfterLoad, + AlwaysGoodOptimizations.PoinlessStoreBeforeStore, + LaterOptimizations.PointlessLoadAfterStore, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + + LaterOptimizations.LoadingAfterShifting, + AlwaysGoodOptimizations.PointlessStoreAfterLoad, + AlwaysGoodOptimizations.PoinlessStoreBeforeStore, + LaterOptimizations.PointlessLoadAfterStore, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + AlwaysGoodOptimizations.TailCallOptimization, + AlwaysGoodOptimizations.UnusedCodeRemoval, + AlwaysGoodOptimizations.ReverseFlowAnalysis, + AlwaysGoodOptimizations.ModificationOfJustWrittenValue, + AlwaysGoodOptimizations.PointlessMathFromFlow, + AlwaysGoodOptimizations.PointlessMathFromFlow, + AlwaysGoodOptimizations.PointlessMathFromFlow, + AlwaysGoodOptimizations.PointlessMathFromFlow, + AlwaysGoodOptimizations.PointlessMathFromFlow, + AlwaysGoodOptimizations.PointlessMathFromFlow, + AlwaysGoodOptimizations.PointlessMathFromFlow, + AlwaysGoodOptimizations.PointlessMathFromFlow, + AlwaysGoodOptimizations.MathOperationOnTwoIdenticalMemoryOperands, + LaterOptimizations.UseZeropageAddressingMode, + + LaterOptimizations.UseXInsteadOfStack, + LaterOptimizations.UseYInsteadOfStack, + LaterOptimizations.IndexSwitchingOptimization, + ) + + val Good: List[AssemblyOptimization] = List[AssemblyOptimization]( + AlwaysGoodOptimizations.Adc0Optimization, + AlwaysGoodOptimizations.CarryFlagConversion, + DangerousOptimizations.ConstantIndexOffsetPropagation, + AlwaysGoodOptimizations.BranchInPlaceRemoval, + AlwaysGoodOptimizations.ConstantFlowAnalysis, + AlwaysGoodOptimizations.ConstantIndexPropagation, + AlwaysGoodOptimizations.FlagFlowAnalysis, + AlwaysGoodOptimizations.IdempotentDuplicateRemoval, + AlwaysGoodOptimizations.ImpossibleBranchRemoval, + AlwaysGoodOptimizations.IndexSequenceOptimization, + AlwaysGoodOptimizations.MathOperationOnTwoIdenticalMemoryOperands, + AlwaysGoodOptimizations.ModificationOfJustWrittenValue, + AlwaysGoodOptimizations.PoinlessFlagChange, + AlwaysGoodOptimizations.PointlessLoadAfterLoadOrStore, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + AlwaysGoodOptimizations.PointlessLoadBeforeReturn, + AlwaysGoodOptimizations.PointlessLoadBeforeTransfer, + AlwaysGoodOptimizations.PointlessMath, + AlwaysGoodOptimizations.PointlessMathFromFlow, + AlwaysGoodOptimizations.PointlessOperationAfterLoad, + AlwaysGoodOptimizations.PointlessOperationPairRemoval, + AlwaysGoodOptimizations.PointlessRegisterTransfers, + AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeCompare, + AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn, + AlwaysGoodOptimizations.PointlessStashingToIndexOverShortSafeBranch, + AlwaysGoodOptimizations.PointlessStoreAfterLoad, + AlwaysGoodOptimizations.PoinlessStoreBeforeStore, + AlwaysGoodOptimizations.RearrangeMath, + AlwaysGoodOptimizations.RemoveNops, + AlwaysGoodOptimizations.ReverseFlowAnalysis, + AlwaysGoodOptimizations.SimplifiableBitOpsSequence, + AlwaysGoodOptimizations.SmarterShiftingWords, + AlwaysGoodOptimizations.UnconditionalJumpRemoval, + UnusedLabelRemoval, + AlwaysGoodOptimizations.TailCallOptimization, + AlwaysGoodOptimizations.UnusedCodeRemoval, + ) +} diff --git a/src/main/scala/millfork/Platform.scala b/src/main/scala/millfork/Platform.scala new file mode 100644 index 00000000..1d05edf9 --- /dev/null +++ b/src/main/scala/millfork/Platform.scala @@ -0,0 +1,115 @@ +package millfork + +import java.io.{File, StringReader} +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} + +import millfork.error.ErrorReporting +import millfork.output._ +import org.apache.commons.configuration2.INIConfiguration + +/** + * @author Karol Stasiak + */ + +class Platform( + val cpu: Cpu.Value, + val flagOverrides: Map[CompilationFlag.Value, Boolean], + val startingModules: List[String], + val outputPackager: OutputPackager, + val allocator: VariableAllocator, + val org: Int, + val fileExtension: String, + ) + +object Platform { + + val C64 = new Platform( + Cpu.Mos, + Map(), + List("c64_hardware", "c64_loader"), + SequenceOutput(List(StartAddressOutput, AllocatedDataOutput)), + new VariableAllocator( + List(0xC1, 0xC3, 0xFB, 0xFD, 0x39, 0x3B, 0x3D, 0x43, 0x4B), + new AfterCodeByteAllocator(0xA000) + ), + 0x80D, + ".prg" + ) + + def lookupPlatformFile(includePath: List[String], platformName: String): Platform = { + includePath.foreach { dir => + val file = Paths.get(dir, platformName + ".ini").toFile + ErrorReporting.debug("Checking " + file) + if (file.exists()) { + return load(file) + } + } + ErrorReporting.fatal(s"Platfom definition `$platformName` not found", None) + } + + def load(file: File): Platform = { + val conf = new INIConfiguration() + val bytes = Files.readAllBytes(file.toPath) + conf.read(new StringReader(new String(bytes, StandardCharsets.UTF_8))) + + val cs = conf.getSection("compilation") + val cpu = Cpu.fromString(cs.get(classOf[String], "cpu", "strict")) + val flagOverrides = CompilationFlag.fromString.flatMap { case (k, f) => + cs.get(classOf[String], k, "").toLowerCase match { + case "" => None + case "false" | "off" | "0" => Some(f -> false) + case "true" | "on" | "1" => Some(f -> true) + } + } + val startingModules = cs.get(classOf[String], "modules", "").split("[, ]+").filter(_.nonEmpty).toList + + val as = conf.getSection("allocation") + val org = as.get(classOf[String], "main_org", "") match { + case "" => ErrorReporting.fatal(s"Undefined main_org") + case m => parseNumber(m) + } + val freePointers = as.get(classOf[String], "zp_pointers", "all") match { + case "all" => List.tabulate(128)(_ * 2) + case xs => xs.split("[, ]+").map(parseNumber).toList + } + val byteAllocator = (as.get(classOf[String], "himem_start", ""), as.get(classOf[String], "himem_end", "")) match { + case ("", _) => ErrorReporting.fatal(s"Undefined himem_start") + case (_, "") => ErrorReporting.fatal(s"Undefined himem_end") + case ("after_code", end) => new AfterCodeByteAllocator(parseNumber(end) + 1) + case (start, end) => new UpwardByteAllocator(parseNumber(start), parseNumber(end) + 1) + } + + val os = conf.getSection("output") + val outputPackager = SequenceOutput(os.get(classOf[String], "format", "").split("[, ]+").filter(_.nonEmpty).map { + case "startaddr" => StartAddressOutput + case "endaddr" => EndAddressOutput + case "allocated" => AllocatedDataOutput + case n => n.split(":").filter(_.nonEmpty) match { + case Array(b, s, e) => BankFragmentOutput(parseNumber(b), parseNumber(s), parseNumber(e)) + case Array(s, e) => CurrentBankFragmentOutput(parseNumber(s), parseNumber(e)) + case Array(b) => ConstOutput(parseNumber(b).toByte) + case x => ErrorReporting.fatal(s"Invalid output format: `$x`") + } + }.toList) + var fileExtension = os.get(classOf[String], "extension", ".bin") + + new Platform(cpu, flagOverrides, startingModules, outputPackager, + new VariableAllocator(freePointers, byteAllocator), org, + if (fileExtension.startsWith(".")) fileExtension else "." + fileExtension) + } + + def parseNumber(s: String): Int = { + if (s.startsWith("$")) { + Integer.parseInt(s.substring(1), 16) + } else if (s.startsWith("0x")) { + Integer.parseInt(s.substring(2), 16) + } else if (s.startsWith("%")) { + Integer.parseInt(s.substring(1), 2) + } else if (s.startsWith("0b")) { + Integer.parseInt(s.substring(2), 2) + } else { + s.toInt + } + } +} \ No newline at end of file diff --git a/src/main/scala/millfork/SeparatedList.scala b/src/main/scala/millfork/SeparatedList.scala new file mode 100644 index 00000000..c58c848c --- /dev/null +++ b/src/main/scala/millfork/SeparatedList.scala @@ -0,0 +1,50 @@ +package millfork + +/** + * @author Karol Stasiak + */ +case class SeparatedList[T, S](head: T, tail: List[(S, T)]) { + + def toPairList(initialSeparator: S) = (initialSeparator -> head) :: tail + + def size: Int = tail.size + 1 + + def items: List[T] = head :: tail.map(_._2) + + def separators: List[S] = tail.map(_._1) + + def drop(i: Int): SeparatedList[T, S] = if (i == 0) this else SeparatedList(tail(i - 1)._2, tail.drop(i)) + + def take(i: Int): SeparatedList[T, S] = if (i <= 0) ??? else SeparatedList(head, tail.take(i - 1)) + + def splitAt(i: Int): (SeparatedList[T, S], S, SeparatedList[T, S]) = { + val (a, b) = tail.splitAt(i - 1) + (SeparatedList(head, a), b.head._1, SeparatedList(b.head._2, b.tail)) + } + + def indexOfSeparator(p: S => Boolean): Int = 1 + tail.indexWhere(x => p(x._1)) + + def ::(pair: (T, S)) = SeparatedList(pair._1, (pair._2 -> head) :: tail) + + def split(p: S => Boolean): SeparatedList[SeparatedList[T, S], S] = { + val i = indexOfSeparator(p) + if (i <= 0) SeparatedList(this, Nil) + else { + val (a, b, c) = splitAt(i) + (a, b) :: c.split(p) + } + } +} + +object SeparatedList { + def of[T, S](t0: T): SeparatedList[T, S] = SeparatedList[T, S](t0, Nil) + + def of[T, S](t0: T, s1: S, t1: T): SeparatedList[T, S] = + SeparatedList(t0, List(s1 -> t1)) + + def of[T, S](t0: T, s1: S, t1: T, s2: S, t2: T): SeparatedList[T, S] = + SeparatedList(t0, List(s1 -> t1, s2 -> t2)) + + def of[T, S](t0: T, s1: S, t1: T, s2: S, t2: T, s3: S, t3: T): SeparatedList[T, S] = + SeparatedList(t0, List(s1 -> t1, s2 -> t2, s3 -> t3)) +} \ No newline at end of file diff --git a/src/main/scala/millfork/assembly/AssemblyLine.scala b/src/main/scala/millfork/assembly/AssemblyLine.scala new file mode 100644 index 00000000..a9a2a868 --- /dev/null +++ b/src/main/scala/millfork/assembly/AssemblyLine.scala @@ -0,0 +1,305 @@ +package millfork.assembly + +import java.lang.management.MemoryType + +import millfork.assembly.Opcode._ +import millfork.assembly.opt.ReadsA +import millfork.compiler.{CompilationContext, MlCompiler} +import millfork.env._ + +//noinspection TypeAnnotation +object OpcodeClasses { + + val ReadsAAlways = Set( + ADC, AND, BIT, CMP, EOR, ORA, PHA, SBC, STA, TAX, TAY, + SAX, SBX, ANC, DCP + ) + val ReadsAIfImplied = Set(ASL, LSR, ROL, ROR, INC, DEC) + val ReadsXAlways = Set( + CPX, DEX, INX, STX, TXA, TXS, SBX, + PLX, + XAA, SAX, AHX, SHX + ) + val ReadsYAlways = Set(CPY, DEY, INY, STY, TYA, PLY, SHY) + val ReadsZ = Set(BNE, BEQ, PHP) + val ReadsN = Set(BMI, BPL, PHP) + val ReadsNOrZ = ReadsZ ++ ReadsN + val ReadsV = Set(BVS, BVC, PHP) + val ReadsD = Set(PHP, ADC, SBC, RRA, ARR, ISC, DCP) // TODO: ?? + val ReadsC = Set( + PHP, ADC, SBC, BCC, BCS, ROL, ROR, + ALR, ARR, ISC, RLA, RRA, SLO, SRE // TODO: ?? + ) + val ChangesAAlways = Set( + TXA, TYA, PLA, + ORA, AND, EOR, ADC, LDA, SBC, + SLO, RLA, SRE, RRA, LAX, ISC, + XAA, ANC, ALR, ARR, LXA, LAS + ) + val ChangesAIfImplied = Set(ASL, LSR, ROL, ROR, INC, DEC) + val ChangesX = Set( + DEX, INX, TAX, LDX, TSX, + SBX, LAX, LXA, LAS, + PLX, + ) + val ChangesY = Set( + DEY, INY, TAY, LDY + ) + val ChangesS = Set( + PHA, PLA, PHP, PLP, TXS, + PHX, PHY, PLX, PLY, TAS, LAS + ) + val ChangesMemoryAlways = Set( + STA, STY, STZ, + STX, DEC, INC, + SAX, DCP, ISC, + SLO, RLA, SRE, RRA, + AHX, SHY, SHX, TAS, LAS + ) + val ChangesMemoryIfNotImplied = Set( + ASL, ROL, LSR, ROR + ) + val ReadsMemoryIfNotImpliedOrImmediate = Set( + LDY, CPX, CPY, + ORA, AND, EOR, ADC, LDA, CMP, SBC, + ASL, ROL, LSR, ROR, LDX, DEC, INC, + SLO, RLA, SRE, RRA, LAX, DCP, ISC, + LAS, + TRB, TSB + ) + val OverwritesA = Set( + LDA, PLA, TXA, TYA, + LAX, LAS + ) + val OverwritesX = Set( + TAX, LDX, TSX, PLX, + LAX, LAS + ) + val OverwritesY = Set( + TAY, LDY, PLY + ) + val OverwritesC = Set(CLC, SEC, PLP) + val OverwritesD = Set(CLD, SED, PLP) + val OverwritesI = Set(CLI, SEI, PLP) + val OverwritesV = Set(CLV, PLP) + val ConcernsAAlways = ReadsAAlways ++ ChangesAAlways + val ConcernsAIfImplied = ReadsAIfImplied ++ ChangesAIfImplied + val ConcernsXAlways = ReadsXAlways | ChangesX + val ConcernsYAlways = ReadsYAlways | ChangesY + + val ConcernsStack = Set( + PHA, PLA, PHP, PLP, + PHX, PLX, PHY, PLY, + TXS, TSX, + JSR, RTS, RTI, + TAS, LAS, + ) + + val ChangesNAndZ = Set( + ADC, AND, ASL, BIT, CMP, CPX, CPY, DEC, DEX, DEY, EOR, INC, INX, INY, LDA, + LDX, LDY, LSR, ORA, PLP, ROL, ROR, SBC, TAX, TAY, TXA, TYA, + LAX, SBX, ANC, ALR, ARR, DCP, ISC, RLA, RRA, SLO, SRE, SAX, + TSB, TRB // These two do not change N, but lets pretend they do for simplicity + ) + val ChangesC = Set( + CLC, SEC, ADC, ASL, CMP, CPX, CPY, LSR, PLP, ROL, ROR, SBC, + SBX, ANC, ALR, ARR, DCP, ISC, RLA, RRA, SLO, SRE + ) + val ChangesV = Set( + ADC, BIT, PLP, SBC, + ARR, ISC, RRA, + ) + + val SupportsAbsoluteX = Set( + ORA, AND, EOR, ADC, CMP, SBC, + ASL, ROL, LSR, ROR, DEC, INC, + SLO, RLA, SRE, RRA, DCP, ISC, + STA, LDA, LDY, STZ, SHY, + ) + + val SupportsAbsoluteY = Set( + ORA, AND, EOR, ADC, CMP, SBC, + SLO, RLA, SRE, RRA, DCP, ISC, + STA, LDA, LDX, + LAX, AHX, SHX, TAS, LAS, + ) + + val SupportsAbsolute = Set( + ORA, AND, EOR, ADC, STA, LDA, CMP, SBC, + ASL, ROL, LSR, ROR, STX, LDX, DEC, INC, + SLO, RLA, SRE, RRA, SAX, LAX, DCP, ISC, + STY, LDY, + BIT, JMP, JSR, + STZ, TRB, TSB, + ) + + val SupportsZeroPageIndirect = Set(ORA, AND, EOR, ADC, STA, LDA, CMP, SBC) + + val ShortBranching = Set(BEQ, BNE, BMI, BPL, BVC, BVS, BCC, BCS, BRA) + val AllDirectJumps = ShortBranching + JMP + val AllLinear = Set( + ORA, AND, EOR, + ADC, SBC, CMP, CPX, CPY, + DEC, DEX, DEY, INC, INX, INY, + ASL, ROL, LSR, ROR, + LDA, STA, LDX, STX, LDY, STY, + TAX, TXA, TAY, TYA, TXS, TSX, + PLA, PLP, PHA, PHP, + BIT, NOP, + CLC, SEC, CLD, SED, CLI, SEI, CLV, + STZ, PHX, PHY, PLX, PLY, TSB, TRB, + SLO, RLA, SRE, RRA, SAX, LAX, DCP, ISC, + ANC, ALR, ARR, XAA, LXA, SBX, + DISCARD_AF, DISCARD_XF, DISCARD_YF) + + val NoopDiscardsFlags = Set(DISCARD_AF, DISCARD_XF, DISCARD_YF) + val DiscardsV = NoopDiscardsFlags | OverwritesV + val DiscardsC = NoopDiscardsFlags | OverwritesC + val DiscardsD = OverwritesD + val DiscardsI = NoopDiscardsFlags | OverwritesI + +} + +object AssemblyLine { + + def treatment(lines: List[AssemblyLine], state: State.Value): Treatment.Value = + lines.map(_.treatment(state)).foldLeft(Treatment.Unchanged)(_ ~ _) + + def label(label: String): AssemblyLine = AssemblyLine.label(Label(label)) + + def label(label: Label): AssemblyLine = AssemblyLine(LABEL, AddrMode.DoesNotExist, label.toAddress) + + def discardAF() = AssemblyLine(DISCARD_AF, AddrMode.DoesNotExist, Constant.Zero) + + def discardXF() = AssemblyLine(DISCARD_XF, AddrMode.DoesNotExist, Constant.Zero) + + def discardYF() = AssemblyLine(DISCARD_YF, AddrMode.DoesNotExist, Constant.Zero) + + def immediate(opcode: Opcode.Value, value: Int) = AssemblyLine(opcode, AddrMode.Immediate, NumericConstant(value, 1)) + + def immediate(opcode: Opcode.Value, value: Constant) = AssemblyLine(opcode, AddrMode.Immediate, value) + + def implied(opcode: Opcode.Value) = AssemblyLine(opcode, AddrMode.Implied, Constant.Zero) + + def variable(ctx: CompilationContext, opcode: Opcode.Value, variable: Variable, offset: Int = 0): List[AssemblyLine] = + variable match { + case v@MemoryVariable(_, _, VariableAllocationMethod.Zeropage) => + List(AssemblyLine.zeropage(opcode, v.toAddress + offset)) + case v@RelativeVariable(_, _, _, true) => + List(AssemblyLine.zeropage(opcode, v.toAddress + offset)) + case v:VariableInMemory => List(AssemblyLine.absolute(opcode, v.toAddress + offset)) + case v:StackVariable=> List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(opcode, v.baseOffset + offset + ctx.extraStackOffset)) + } + + def zeropage(opcode: Opcode.Value, addr: Constant) = + AssemblyLine(opcode, AddrMode.ZeroPage, addr) + + def zeropage(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) = + AssemblyLine(opcode, AddrMode.ZeroPage, thing.toAddress + offset) + + def absolute(opcode: Opcode.Value, addr: Constant) = + AssemblyLine(opcode, AddrMode.Absolute, addr) + + def absolute(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) = + AssemblyLine(opcode, AddrMode.Absolute, thing.toAddress + offset) + + def relative(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) = + AssemblyLine(opcode, AddrMode.Relative, thing.toAddress + offset) + + def relative(opcode: Opcode.Value, label: String) = + AssemblyLine(opcode, AddrMode.Relative, Label(label).toAddress) + + def absoluteY(opcode: Opcode.Value, addr: Constant) = + AssemblyLine(opcode, AddrMode.AbsoluteY, addr) + + def absoluteY(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) = + AssemblyLine(opcode, AddrMode.AbsoluteY, thing.toAddress + offset) + + def absoluteX(opcode: Opcode.Value, addr: Int) = + AssemblyLine(opcode, AddrMode.AbsoluteX, NumericConstant(addr, 2)) + + def absoluteX(opcode: Opcode.Value, addr: Constant) = + AssemblyLine(opcode, AddrMode.AbsoluteX, addr) + + def absoluteX(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) = + AssemblyLine(opcode, AddrMode.AbsoluteX, thing.toAddress + offset) + + def indexedY(opcode: Opcode.Value, addr: Constant) = + AssemblyLine(opcode, AddrMode.IndexedY, addr) + + def indexedY(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) = + AssemblyLine(opcode, AddrMode.IndexedY, thing.toAddress + offset) +} + +case class AssemblyLine(opcode: Opcode.Value, addrMode: AddrMode.Value, var parameter: Constant, elidable: Boolean = true) { + + + import AddrMode._ + import State._ + import OpcodeClasses._ + import Treatment._ + + def reads(state: State.Value): Boolean = state match { + case A => if (addrMode == Implied) ReadsAIfImplied(opcode) else ReadsAAlways(opcode) + case X => addrMode == AbsoluteX || addrMode == ZeroPageX || addrMode == IndexedX || ReadsXAlways(opcode) + case Y => addrMode == AbsoluteY || addrMode == ZeroPageY || addrMode == IndexedY || ReadsYAlways(opcode) + case C => ReadsC(opcode) + case D => ReadsD(opcode) + case N => ReadsN(opcode) + case V => ReadsV(opcode) + case Z => ReadsZ(opcode) + } + + def treatment(state: State.Value): Treatment.Value = opcode match { + case LABEL => Unchanged // TODO: ??? + case NOP => Unchanged + case JSR | JMP | BEQ | BNE | BMI | BPL | BRK | BCC | BVC | BCS | BVS => Changed + case CLC => if (state == C) Cleared else Unchanged + case SEC => if (state == C) Set else Unchanged + case CLV => if (state == V) Cleared else Unchanged + case CLD => if (state == D) Cleared else Unchanged + case SED => if (state == D) Set else Unchanged + case _ => state match { // TODO: smart detection of constants + case A => + if (ChangesAAlways(opcode) || addrMode == Implied && ChangesAIfImplied(opcode)) + Changed + else + Unchanged + case X => if (ChangesX(opcode)) Changed else Unchanged + case Y => if (ChangesY(opcode)) Changed else Unchanged + case C => if (ChangesC(opcode)) Changed else Unchanged + case V => if (ChangesV(opcode)) Changed else Unchanged + case N | Z => if (ChangesNAndZ(opcode)) Changed else Unchanged + case D => Unchanged + } + } + + def sizeInBytes: Int = addrMode match { + case Implied => 1 + case Relative | ZeroPageX | ZeroPage | ZeroPageY | IndexedX | IndexedY | Immediate => 2 + case AbsoluteX | Absolute | AbsoluteY | Indirect => 3 + case DoesNotExist => 0 + } + + def cost: Int = addrMode match { + case Implied => 1000 + case Relative | Immediate => 2000 + case ZeroPage => 2001 + case ZeroPageX | ZeroPageY => 2002 + case IndexedX | IndexedY => 2003 + case Absolute => 3000 + case AbsoluteX | AbsoluteY | Indirect => 3001 + case DoesNotExist => 1 + } + + def isPrintable: Boolean = true //addrMode != AddrMode.DoesNotExist || opcode == LABEL + + override def toString: String = + if (opcode == LABEL) { + parameter.toString + } else if (addrMode == DoesNotExist) { + s" ; $opcode" + } else { + s" $opcode ${AddrMode.addrModeToString(addrMode, parameter.toString)}" + } +} diff --git a/src/main/scala/millfork/assembly/Chunk.scala b/src/main/scala/millfork/assembly/Chunk.scala new file mode 100644 index 00000000..117587b1 --- /dev/null +++ b/src/main/scala/millfork/assembly/Chunk.scala @@ -0,0 +1,33 @@ +package millfork.assembly + +import millfork.env.Label + +sealed trait Chunk { + def linearize: List[AssemblyLine] + def sizeInBytes: Int +} + +case object EmptyChunk extends Chunk { + override def linearize: Nil.type = Nil + + override def sizeInBytes = 0 +} + +case class LabelledChunk(label: String, chunk: Chunk) extends Chunk { + override def linearize: List[AssemblyLine] = AssemblyLine.label(Label(label)) :: chunk.linearize + + override def sizeInBytes: Int = chunk.sizeInBytes +} + +case class SequenceChunk(chunks: List[Chunk]) extends Chunk { + override def linearize: List[AssemblyLine] = chunks.flatMap(_.linearize) + + override def sizeInBytes: Int = chunks.map(_.sizeInBytes).sum +} + +case class LinearChunk(lines: List[AssemblyLine]) extends Chunk { + def linearize: List[AssemblyLine] = lines + + override def sizeInBytes: Int = lines.map(_.sizeInBytes).sum + +} \ No newline at end of file diff --git a/src/main/scala/millfork/assembly/Opcode.scala b/src/main/scala/millfork/assembly/Opcode.scala new file mode 100644 index 00000000..55e5b099 --- /dev/null +++ b/src/main/scala/millfork/assembly/Opcode.scala @@ -0,0 +1,190 @@ +package millfork.assembly + +import java.util.Locale + +import millfork.error.ErrorReporting +import millfork.node.Position + +object State extends Enumeration { + val A, X, Y, Z, D, C, N, V = Value +} + +object Treatment extends Enumeration { + val Unchanged, Unsure, Changed, Cleared, Set = Value + + implicit class OverriddenValue(val left: Value) extends AnyVal { + def ~(right: Treatment.Value): Treatment.Value = right match { + case Unchanged => left + case Cleared | Set => if (left == Unsure) Changed else right + case _ => right + } + } + +} + +object Opcode extends Enumeration { + val ADC, AND, ASL, + BIT, BNE, BEQ, BPL, BMI, BVS, BVC, BCC, BCS, BRK, + CMP, CPX, CPY, CLV, CLC, CLI, CLD, + DEC, DEX, DEY, + EOR, + INC, INX, INY, + JMP, JSR, + LDA, LDX, LDY, LSR, + NOP, + ORA, + PHA, PHP, PLA, PLP, + ROL, ROR, RTS, RTI, + SBC, SEC, SED, SEI, STA, STX, STY, + TAX, TAY, TXA, TXS, TSX, TYA, + + LXA, XAA, ANC, ARR, ALR, SBX, + LAX, SAX, RLA, RRA, SLO, SRE, DCP, ISC, + TAS, LAS, SHX, SHY, AHX, + STZ, PHX, PHY, PLX, PLY, + BRA, TRB, TSB, STP, WAI, + DISCARD_AF, DISCARD_XF, DISCARD_YF, + LABEL = Value + + def lookup(opcode: String, position: Option[Position]): Opcode.Value = opcode.toUpperCase(Locale.ROOT) match { + case "ADC" => ADC + case "AHX" => AHX + case "ALR" => ALR + case "ANC" => ANC + case "AND" => AND + case "ANE" => XAA + case "ARR" => ARR + case "ASL" => ASL + case "ASO" => SLO + case "AXA" => AHX + case "AXS" => SBX // TODO: could mean SAX + case "BCC" => BCC + case "BCS" => BCS + case "BEQ" => BEQ + case "BIT" => BIT + case "BMI" => BMI + case "BNE" => BNE + case "BPL" => BPL + case "BRA" => BRA + case "BRK" => BRK + case "BVC" => BVC + case "BVS" => BVS + case "CLC" => CLC + case "CLD" => CLD + case "CLI" => CLI + case "CLV" => CLV + case "CMP" => CMP + case "CPX" => CPX + case "CPY" => CPY + case "DCM" => DCP + case "DCP" => DCP + case "DEC" => DEC + case "DEX" => DEX + case "DEY" => DEY + case "EOR" => EOR + case "INC" => INC + case "INS" => ISC + case "INX" => INX + case "INY" => INY + case "ISC" => ISC + case "JMP" => JMP + case "JSR" => JSR + case "LAS" => LAS + case "LAX" => LAX + case "LDA" => LDA + case "LDX" => LDX + case "LDY" => LDY + case "LSE" => SRE + case "LSR" => LSR + case "LXA" => LXA + case "NOP" => NOP + case "OAL" => LXA + case "ORA" => ORA + case "PHA" => PHA + case "PHP" => PHP + case "PHX" => PHX + case "PHY" => PHY + case "PLA" => PLA + case "PLP" => PLP + case "PLX" => PLX + case "PLY" => PLY + case "RLA" => RLA + case "ROL" => ROL + case "ROR" => ROR + case "RRA" => RRA + case "RTI" => RTI + case "RTS" => RTS + case "SAX" => SAX // TODO: could mean SBX + case "SAY" => SHY + case "SBC" => SBC + case "SBX" => SBX + case "SEC" => SEC + case "SED" => SED + case "SEI" => SEI + case "SHX" => SHX + case "SHY" => SHY + case "SLO" => SLO + case "SRE" => SRE + case "STA" => STA + case "STP" => STP + case "STX" => STX + case "STY" => STY + case "STZ" => STZ + case "TAS" => TAS + case "TAX" => TAX + case "TAY" => TAY + case "TRB" => TRB + case "TSB" => TSB + case "TSX" => TSX + case "TXA" => TXA + case "TXS" => TXS + case "TYA" => TYA + case "WAI" => WAI + case "XAA" => XAA + case "XAS" => SHX + case _ => + ErrorReporting.error(s"Invalid opcode `$opcode`", position) + LABEL + } + +} + +object AddrMode extends Enumeration { + val Implied, + Immediate, + Relative, + ZeroPage, + ZeroPageX, + ZeroPageY, + Absolute, + AbsoluteX, + AbsoluteY, + Indirect, + IndexedX, + IndexedY, + AbsoluteIndexedX, + ZeroPageIndirect, + Undecided, + DoesNotExist = Value + + + def argumentLength(a: AddrMode.Value): Int = a match { + case Absolute | AbsoluteX | AbsoluteY | Indirect => + 2 + case _ => + 1 + } + + def addrModeToString(am: AddrMode.Value, argument: String): String = { + am match { + case Implied => "" + case Immediate => "#" + argument + case AbsoluteX | ZeroPageX => argument + ", X" + case AbsoluteY | ZeroPageY => argument + ", Y" + case IndexedX | AbsoluteIndexedX => "(" + argument + ", X)" + case IndexedY => "(" + argument + "), Y" + case Indirect | ZeroPageIndirect => "(" + argument + ")" + case _ => argument; + } + } +} diff --git a/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala b/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala new file mode 100644 index 00000000..fbf57de3 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala @@ -0,0 +1,848 @@ +package millfork.assembly.opt + +import java.util.UUID +import java.util.concurrent.atomic.AtomicInteger + +import millfork.assembly.{opt, _} +import millfork.assembly.Opcode._ +import millfork.assembly.AddrMode._ +import millfork.assembly.OpcodeClasses._ +import millfork.env._ + +/** + * These optimizations should not remove opportunities for more complex optimizations to trigger. + * + * @author Karol Stasiak + */ +object AlwaysGoodOptimizations { + + val counter = new AtomicInteger(30000) + + def getNextLabel(prefix: String) = f".${prefix}%s__${counter.getAndIncrement()}%05d" + + val PointlessMath = new RuleBasedAssemblyOptimization("Pointless math", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (HasOpcode(CLC) & Elidable) ~ + (HasOpcode(ADC) & Elidable & MatchParameter(0)) ~ + (HasOpcode(SEC) & Elidable) ~ + (HasOpcode(SBC) & Elidable & MatchParameter(0)) ~ + (LinearOrLabel & Not(ReadsNOrZ) & Not(ReadsV) & Not(ReadsC) & Not(NoopDiscardsFlags) & Not(Set(ADC, SBC))).* ~ + (NoopDiscardsFlags | Set(ADC, SBC)) ~~> (_.drop(4)), + (HasOpcode(LDA) & HasImmediate(0) & Elidable) ~ + (HasOpcode(CLC) & Elidable) ~ + (HasOpcode(ADC) & Elidable) ~ + (LinearOrLabel & Not(ReadsV) & Not(NoopDiscardsFlags) & Not(ChangesNAndZ)).* ~ + (NoopDiscardsFlags | ChangesNAndZ) ~~> (code => code(2).copy(opcode = LDA) :: code.drop(3)) + ) + + val PointlessMathFromFlow = new RuleBasedAssemblyOptimization("Pointless math from flow analysis", + needsFlowInfo = FlowInfoRequirement.BothFlows, + (Elidable & MatchA(0) & + HasOpcode(ASL) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, ctx.get[Int](0) << 1) :: Nil + }, + (Elidable & MatchA(0) & + HasOpcode(LSR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, (ctx.get[Int](0) & 0xff) >> 1) :: Nil + }, + (Elidable & MatchA(0) & + HasClear(State.C) & HasOpcode(ROL) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, ctx.get[Int](0) << 1) :: Nil + }, + (Elidable & MatchA(0) & + HasClear(State.C) & HasOpcode(ROR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, (ctx.get[Int](0) & 0xff) >> 1) :: Nil + }, + (Elidable & MatchA(0) & + HasSet(State.C) & HasOpcode(ROL) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, ctx.get[Int](0) * 2 + 1) :: Nil + }, + (Elidable & MatchA(0) & + HasSet(State.C) & HasOpcode(ROR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, 0x80 + (ctx.get[Int](0) & 0xff) / 2) :: Nil + }, + (Elidable & + MatchA(0) & MatchParameter(1) & + HasOpcode(ADC) & HasAddrMode(Immediate) & + HasClear(State.D) & HasClear(State.C) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, ctx.get[Constant](1) + ctx.get[Int](0)) :: Nil + }, + (Elidable & + MatchA(0) & MatchParameter(1) & + HasOpcode(ADC) & HasAddrMode(Immediate) & + HasClear(State.D) & HasClear(State.C) & DoesntMatterWhatItDoesWith(State.V)) ~ + Where(ctx => (ctx.get[Constant](1) + ctx.get[Int](0)).quickSimplify match { + case NumericConstant(x, _) => x == (x & 0xff) + case _ => false + }) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, ctx.get[Constant](1) + ctx.get[Int](0)) :: Nil + }, + (Elidable & + MatchA(0) & MatchParameter(1) & + HasOpcode(ADC) & HasAddrMode(Immediate) & + HasClear(State.D) & HasSet(State.C) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, ctx.get[Constant](1) + ((ctx.get[Int](0) + 1) & 0xff)) :: Nil + }, + (Elidable & + MatchA(0) & MatchParameter(1) & + HasOpcode(SBC) & HasAddrMode(Immediate) & + HasClear(State.D) & HasSet(State.C) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.Minus, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil + }, + (Elidable & + MatchA(0) & MatchParameter(1) & + HasOpcode(EOR) & HasAddrMode(Immediate)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.Exor, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil + }, + (Elidable & + MatchA(0) & MatchParameter(1) & + HasOpcode(ORA) & HasAddrMode(Immediate)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.Or, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil + }, + (Elidable & + MatchA(0) & MatchParameter(1) & + HasOpcode(AND) & HasAddrMode(Immediate)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.And, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil + }, + (Elidable & + MatchA(0) & MatchParameter(1) & + HasOpcode(ANC) & HasAddrMode(Immediate) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) => + AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.And, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil + }, + ) + + val MathOperationOnTwoIdenticalMemoryOperands = new RuleBasedAssemblyOptimization("Math operation on two identical memory operands", + needsFlowInfo = FlowInfoRequirement.BothFlows, + (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~ + (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~ + (HasClear(State.D) & HasClear(State.C) & HasOpcode(ADC) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.implied(ASL)), + + (HasOpcodeIn(Set(STA, LDA)) & HasAddrMode(AbsoluteX) & MatchAddrMode(9) & MatchParameter(0)) ~ + (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA) & Not(ChangesX)).* ~ + (HasClear(State.D) & HasClear(State.C) & HasOpcode(ADC) & HasAddrMode(AbsoluteX) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.implied(ASL)), + + (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrMode(AbsoluteY) & MatchAddrMode(9) & MatchParameter(0)) ~ + (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA) & Not(ChangesY)).* ~ + (HasClear(State.D) & HasClear(State.C) & HasOpcode(ADC) & HasAddrMode(AbsoluteY) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.implied(ASL)), + + (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~ + (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~ + (DoesntMatterWhatItDoesWith(State.N, State.Z) & HasOpcodeIn(Set(ORA, AND)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init), + + (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~ + (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~ + (DoesntMatterWhatItDoesWith(State.N, State.Z, State.C) & HasOpcode(ANC) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init), + + (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~ + (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~ + (HasOpcode(EOR) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.immediate(LDA, 0)), + ) + + val PointlessStoreAfterLoad = new RuleBasedAssemblyOptimization("Pointless store after load", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~ + (LinearOrLabel & DoesntChangeMemoryAt(0,1) & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0)).* ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init), + (HasOpcode(LDX) & MatchAddrMode(0) & MatchParameter(1)) ~ + (LinearOrLabel & DoesntChangeMemoryAt(0,1) & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0)).* ~ + (Elidable & HasOpcode(STX) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init), + (HasOpcode(LDY) & MatchAddrMode(0) & MatchParameter(1)) ~ + (LinearOrLabel & DoesntChangeMemoryAt(0,1) & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0)).* ~ + (Elidable & HasOpcode(STY) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init), + ) + + val PoinlessStoreBeforeStore = new RuleBasedAssemblyOptimization("Pointless store before store", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasAddrModeIn(Set(Absolute, ZeroPage)) & MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STY, STZ)) ~ + (LinearOrLabel & DoesNotConcernMemoryAt(2, 1)).* ~ + (MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STY, STZ)) ~~> (_.tail), + (Elidable & HasAddrModeIn(Set(AbsoluteX, ZeroPageX)) & MatchParameter(1) & MatchAddrMode(2) & Set(STA, STY, STZ)) ~ + (LinearOrLabel & DoesntChangeMemoryAt(2, 1) & Not(ReadsMemory) & Not(ChangesX)).* ~ + (MatchParameter(1) & MatchAddrMode(2) & Set(STA, STY, STZ)) ~~> (_.tail), + (Elidable & HasAddrModeIn(Set(AbsoluteY, ZeroPageY)) & MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STZ)) ~ + (LinearOrLabel & DoesntChangeMemoryAt(2, 1) & Not(ReadsMemory) & Not(ChangesY)).* ~ + (MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STZ)) ~~> (_.tail), + ) + + val PointlessLoadBeforeReturn = new RuleBasedAssemblyOptimization("Pointless load before return", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Set(LDA, TXA, TYA, EOR, AND, ORA, ANC) & Elidable) ~ (LinearOrLabel & Not(ConcernsA) & Not(ReadsNOrZ) & Not(HasOpcode(DISCARD_AF))).* ~ HasOpcode(DISCARD_AF) ~~> (_.tail), + (Set(LDX, TAX, TSX, INX, DEX) & Elidable) ~ (LinearOrLabel & Not(ConcernsX) & Not(ReadsNOrZ) & Not(HasOpcode(DISCARD_XF))).* ~ HasOpcode(DISCARD_XF) ~~> (_.tail), + (Set(LDY, TAY, INY, DEY) & Elidable) ~ (LinearOrLabel & Not(ConcernsY) & Not(ReadsNOrZ) & Not(HasOpcode(DISCARD_YF))).* ~ HasOpcode(DISCARD_YF) ~~> (_.tail), + (HasOpcode(LDX) & Elidable & MatchAddrMode(3)) ~ + (LinearOrLabel & Not(ConcernsX) & Not(ReadsNOrZ) & DoesntChangeIndexingInAddrMode(3)).*.capture(1) ~ + (HasOpcode(TXA) & Elidable) ~ + ((LinearOrLabel & Not(ConcernsX) & Not(HasOpcode(DISCARD_XF))).* ~ + HasOpcode(DISCARD_XF)).capture(2) ~~> { (c, ctx) => + ctx.get[List[AssemblyLine]](1) ++ (c.head.copy(opcode = LDA) :: ctx.get[List[AssemblyLine]](2)) + }, + (HasOpcode(LDY) & Elidable & MatchAddrMode(3)) ~ + (LinearOrLabel & Not(ConcernsY) & Not(ReadsNOrZ) & DoesntChangeIndexingInAddrMode(3)).*.capture(1) ~ + (HasOpcode(TYA) & Elidable) ~ + ((LinearOrLabel & Not(ConcernsY) & Not(HasOpcode(DISCARD_YF))).* ~ + HasOpcode(DISCARD_YF)).capture(2) ~~> { (c, ctx) => + ctx.get[List[AssemblyLine]](1) ++ (c.head.copy(opcode = LDA) :: ctx.get[List[AssemblyLine]](2)) + }, + ) + + private def operationPairBuilder(op1: Opcode.Value, op2: Opcode.Value, middle: AssemblyLinePattern) = { + (HasOpcode(op1) & Elidable) ~ + (Linear & middle).*.capture(1) ~ + (HasOpcode(op2) & Elidable) ~ + ((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(2) ~~> { (_, ctx) => + ctx.get[List[AssemblyLine]](1) ++ ctx.get[List[AssemblyLine]](2) + } + } + + val PointlessOperationPairRemoval = new RuleBasedAssemblyOptimization("Pointless operation pair", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + operationPairBuilder(PHA, PLA, Not(ConcernsA) & Not(ConcernsStack)), + operationPairBuilder(PHX, PLX, Not(ConcernsX) & Not(ConcernsStack)), + operationPairBuilder(PHY, PLY, Not(ConcernsY) & Not(ConcernsStack)), + operationPairBuilder(INX, DEX, Not(ConcernsX) & Not(ReadsNOrZ)), + operationPairBuilder(DEX, INX, Not(ConcernsX) & Not(ReadsNOrZ)), + operationPairBuilder(INY, DEY, Not(ConcernsX) & Not(ReadsNOrZ)), + operationPairBuilder(DEY, INY, Not(ConcernsX) & Not(ReadsNOrZ)), + ) + + + val BranchInPlaceRemoval = new RuleBasedAssemblyOptimization("Branch in place", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (AllDirectJumps & MatchParameter(0) & Elidable) ~ + HasOpcodeIn(NoopDiscardsFlags).* ~ + (HasOpcode(LABEL) & MatchParameter(0)) ~~> (c => c.last :: Nil) + ) + + val ImpossibleBranchRemoval = new RuleBasedAssemblyOptimization("Impossible branch", + needsFlowInfo = FlowInfoRequirement.ForwardFlow, + (HasOpcode(BCC) & HasSet(State.C) & Elidable) ~~> (_ => Nil), + (HasOpcode(BCS) & HasClear(State.C) & Elidable) ~~> (_ => Nil), + (HasOpcode(BVC) & HasSet(State.V) & Elidable) ~~> (_ => Nil), + (HasOpcode(BVS) & HasClear(State.V) & Elidable) ~~> (_ => Nil), + (HasOpcode(BNE) & HasSet(State.Z) & Elidable) ~~> (_ => Nil), + (HasOpcode(BEQ) & HasClear(State.Z) & Elidable) ~~> (_ => Nil), + (HasOpcode(BPL) & HasSet(State.N) & Elidable) ~~> (_ => Nil), + (HasOpcode(BMI) & HasClear(State.N) & Elidable) ~~> (_ => Nil), + ) + + val UnconditionalJumpRemoval = new RuleBasedAssemblyOptimization("Unconditional jump removal", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasOpcode(JMP) & HasAddrMode(Absolute) & MatchParameter(0)) ~ + (Elidable & LinearOrBranch).* ~ + (HasOpcode(LABEL) & MatchParameter(0)) ~~> (_ => Nil), + (Elidable & HasOpcode(JMP) & HasAddrMode(Absolute) & MatchParameter(0)) ~ + (Not(HasOpcode(LABEL)) & Not(MatchParameter(0))).* ~ + (HasOpcode(LABEL) & MatchParameter(0)) ~ + (HasOpcode(LABEL) | HasOpcodeIn(NoopDiscardsFlags)).* ~ + HasOpcode(RTS) ~~> (code => AssemblyLine.implied(RTS) :: code.tail), + (Elidable & HasOpcodeIn(ShortBranching) & MatchParameter(0)) ~ + (HasOpcodeIn(NoopDiscardsFlags).* ~ + (Elidable & HasOpcode(RTS))).capture(1) ~ + (HasOpcode(LABEL) & MatchParameter(0)) ~ + HasOpcodeIn(NoopDiscardsFlags).* ~ + (Elidable & HasOpcode(RTS)) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1)), + ) + + val TailCallOptimization = new RuleBasedAssemblyOptimization("Tail call optimization", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasOpcode(JSR)) ~ HasOpcodeIn(NoopDiscardsFlags).* ~ (Elidable & HasOpcode(RTS)) ~~> (c => c.tail.init :+ c.head.copy(opcode = JMP)), + (Elidable & HasOpcode(JSR)) ~ + HasOpcode(LABEL).* ~ + HasOpcodeIn(NoopDiscardsFlags).*.capture(0) ~ + HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](0) ++ (code.head.copy(opcode = JMP) :: code.tail)), + ) + + val UnusedCodeRemoval = new RuleBasedAssemblyOptimization("Unreachable code removal", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + HasOpcode(JMP) ~ (Not(HasOpcode(LABEL)) & Elidable).+ ~~> (c => c.head :: Nil) + ) + + val PoinlessFlagChange = new RuleBasedAssemblyOptimization("Pointless flag change", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (HasOpcodeIn(Set(CMP, CPX, CPY)) & Elidable) ~ NoopDiscardsFlags ~~> (_.tail), + (OverwritesC & Elidable) ~ (LinearOrLabel & Not(ReadsC) & Not(DiscardsC)).* ~ DiscardsC ~~> (_.tail), + (OverwritesD & Elidable) ~ (LinearOrLabel & Not(ReadsD) & Not(DiscardsD)).* ~ DiscardsD ~~> (_.tail), + (OverwritesV & Elidable) ~ (LinearOrLabel & Not(ReadsV) & Not(DiscardsV)).* ~ DiscardsV ~~> (_.tail) + ) + + val FlagFlowAnalysis = new RuleBasedAssemblyOptimization("Flag flow analysis", + needsFlowInfo = FlowInfoRequirement.ForwardFlow, + (HasSet(State.C) & HasOpcode(SEC) & Elidable) ~~> (_ => Nil), + (HasSet(State.D) & HasOpcode(SED) & Elidable) ~~> (_ => Nil), + (HasClear(State.C) & HasOpcode(CLC) & Elidable) ~~> (_ => Nil), + (HasClear(State.D) & HasOpcode(CLD) & Elidable) ~~> (_ => Nil), + (HasClear(State.V) & HasOpcode(CLV) & Elidable) ~~> (_ => Nil), + (HasSet(State.C) & HasOpcode(BCS) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))), + (HasClear(State.C) & HasOpcode(BCC) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))), + (HasSet(State.N) & HasOpcode(BMI) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))), + (HasClear(State.N) & HasOpcode(BPL) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))), + (HasClear(State.V) & HasOpcode(BVC) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))), + (HasSet(State.V) & HasOpcode(BVS) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))), + (HasSet(State.Z) & HasOpcode(BEQ) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))), + (HasClear(State.Z) & HasOpcode(BNE) & Elidable) ~~> (_ => Nil), + ) + + val ReverseFlowAnalysis = new RuleBasedAssemblyOptimization("Reverse flow analysis", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + (Elidable & HasOpcodeIn(Set(TXA, TYA, LDA, EOR, ORA, AND)) & DoesntMatterWhatItDoesWith(State.A, State.N, State.Z)) ~~> (_ => Nil), + (Elidable & HasOpcode(ANC) & DoesntMatterWhatItDoesWith(State.A, State.C, State.N, State.Z)) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(TAX, TSX, LDX, INX, DEX)) & DoesntMatterWhatItDoesWith(State.X, State.N, State.Z)) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(TAY, LDY, DEY, INY)) & DoesntMatterWhatItDoesWith(State.Y, State.N, State.Z)) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(LAX)) & DoesntMatterWhatItDoesWith(State.A, State.X, State.N, State.Z)) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(SEC, CLC)) & DoesntMatterWhatItDoesWith(State.C)) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(CLD, SED)) & DoesntMatterWhatItDoesWith(State.D)) ~~> (_ => Nil), + (Elidable & HasOpcode(CLV) & DoesntMatterWhatItDoesWith(State.V)) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(CMP, CPX, CPY)) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Z)) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(BIT)) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Z, State.V)) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(ASL, LSR, ROL, ROR)) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.A, State.C, State.N, State.Z)) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(ADC, SBC)) & DoesntMatterWhatItDoesWith(State.A, State.C, State.V, State.N, State.Z)) ~~> (_ => Nil), + ) + + private def modificationOfJustWrittenValue(store: Opcode.Value, + addrMode: AddrMode.Value, + initExtra: AssemblyLinePattern, + modify: Opcode.Value, + meantimeExtra: AssemblyLinePattern, + atLeastTwo: Boolean, + flagsToTrash: Seq[State.Value], + fix: ((AssemblyMatchingContext, Int) => List[AssemblyLine]), + alternateStore: Opcode.Value = LABEL) = { + val actualFlagsToTrash = List(State.N, State.Z) ++ flagsToTrash + val init = Elidable & HasOpcode(store) & HasAddrMode(addrMode) & MatchAddrMode(3) & MatchParameter(0) & DoesntMatterWhatItDoesWith(actualFlagsToTrash: _*) & initExtra + val meantime = (Linear & Not(ConcernsMemory) & meantimeExtra).* + val oneModification = Elidable & HasOpcode(modify) & HasAddrMode(addrMode) & MatchParameter(0) & DoesntMatterWhatItDoesWith(actualFlagsToTrash: _*) + val modifications = (if (atLeastTwo) oneModification ~ oneModification.+ else oneModification.+).captureLength(1) + if (alternateStore == LABEL) { + ((init ~ meantime).capture(2) ~ modifications) ~~> ((code, ctx) => fix(ctx, ctx.get[Int](1)) ++ ctx.get[List[AssemblyLine]](2)) + } else { + (init.capture(3) ~ meantime.capture(2) ~ modifications) ~~> { (code, ctx) => + fix(ctx, ctx.get[Int](1)) ++ + List(AssemblyLine(alternateStore, ctx.get[AddrMode.Value](3), ctx.get[Constant](0))) ++ + ctx.get[List[AssemblyLine]](2) + } + } + } + + val ModificationOfJustWrittenValue = new RuleBasedAssemblyOptimization("Modification of Just written value", + needsFlowInfo = FlowInfoRequirement.ForwardFlow, + modificationOfJustWrittenValue(STA, Absolute, MatchA(5), INC, Anything, atLeastTwo = false, Seq(), (c, i) => List( + AssemblyLine.immediate(LDA, (c.get[Int](5) + i) & 0xff) + )), + modificationOfJustWrittenValue(STA, Absolute, MatchA(5), DEC, Anything, atLeastTwo = false, Seq(), (c, i) => List( + AssemblyLine.immediate(LDA, (c.get[Int](5) - i) & 0xff) + )), + modificationOfJustWrittenValue(STA, ZeroPage, MatchA(5), INC, Anything, atLeastTwo = false, Seq(), (c, i) => List( + AssemblyLine.immediate(LDA, (c.get[Int](5) + i) & 0xff) + )), + modificationOfJustWrittenValue(STA, ZeroPage, MatchA(5), DEC, Anything, atLeastTwo = false, Seq(), (c, i) => List( + AssemblyLine.immediate(LDA, (c.get[Int](5) - i) & 0xff) + )), + modificationOfJustWrittenValue(STA, AbsoluteX, MatchA(5), INC, Not(ChangesX), atLeastTwo = false, Seq(), (c, i) => List( + AssemblyLine.immediate(LDA, (c.get[Int](5) + i) & 0xff) + )), + modificationOfJustWrittenValue(STA, AbsoluteX, MatchA(5), DEC, Not(ChangesX), atLeastTwo = false, Seq(), (c, i) => List( + AssemblyLine.immediate(LDA, (c.get[Int](5) - i) & 0xff) + )), + modificationOfJustWrittenValue(STA, Absolute, Anything, INC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List( + AssemblyLine.implied(CLC), + AssemblyLine.immediate(ADC, i) + )), + modificationOfJustWrittenValue(STA, Absolute, Anything, DEC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List( + AssemblyLine.implied(SEC), + AssemblyLine.immediate(SBC, i) + )), + modificationOfJustWrittenValue(STA, ZeroPage, Anything, INC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List( + AssemblyLine.implied(CLC), + AssemblyLine.immediate(ADC, i) + )), + modificationOfJustWrittenValue(STA, ZeroPage, Anything, DEC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List( + AssemblyLine.implied(SEC), + AssemblyLine.immediate(SBC, i) + )), + modificationOfJustWrittenValue(STA, AbsoluteX, Anything, INC, Not(ChangesX), atLeastTwo = true, Seq(State.C, State.V), (_, i) => List( + AssemblyLine.implied(CLC), + AssemblyLine.immediate(ADC, i) + )), + modificationOfJustWrittenValue(STA, AbsoluteX, Anything, DEC, Not(ChangesX), atLeastTwo = true, Seq(State.C, State.V), (_, i) => List( + AssemblyLine.implied(SEC), + AssemblyLine.immediate(SBC, i) + )), + modificationOfJustWrittenValue(STA, Absolute, Anything, ASL, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(ASL))), + modificationOfJustWrittenValue(STA, Absolute, Anything, LSR, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(LSR))), + modificationOfJustWrittenValue(STA, ZeroPage, Anything, ASL, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(ASL))), + modificationOfJustWrittenValue(STA, ZeroPage, Anything, LSR, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(LSR))), + modificationOfJustWrittenValue(STA, AbsoluteX, Anything, ASL, Not(ChangesX), atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(ASL))), + modificationOfJustWrittenValue(STA, AbsoluteX, Anything, LSR, Not(ChangesX), atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(LSR))), + modificationOfJustWrittenValue(STX, Absolute, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INX))), + modificationOfJustWrittenValue(STX, Absolute, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEX))), + modificationOfJustWrittenValue(STY, Absolute, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INY))), + modificationOfJustWrittenValue(STY, Absolute, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEY))), + modificationOfJustWrittenValue(STZ, Absolute, Anything, ASL, Anything, atLeastTwo = false, Seq(), (_, i) => Nil), + modificationOfJustWrittenValue(STZ, Absolute, Anything, LSR, Anything, atLeastTwo = false, Seq(), (_, i) => Nil), + modificationOfJustWrittenValue(STZ, Absolute, Anything, INC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, i)), STA), + modificationOfJustWrittenValue(STZ, Absolute, Anything, DEC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, 256 - i)), STA), + modificationOfJustWrittenValue(STX, ZeroPage, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INX))), + modificationOfJustWrittenValue(STX, ZeroPage, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEX))), + modificationOfJustWrittenValue(STY, ZeroPage, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INY))), + modificationOfJustWrittenValue(STY, ZeroPage, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEY))), + modificationOfJustWrittenValue(STZ, ZeroPage, Anything, ASL, Anything, atLeastTwo = false, Seq(), (_, i) => Nil), + modificationOfJustWrittenValue(STZ, ZeroPage, Anything, LSR, Anything, atLeastTwo = false, Seq(), (_, i) => Nil), + modificationOfJustWrittenValue(STZ, ZeroPage, Anything, INC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, i)), STA), + modificationOfJustWrittenValue(STZ, ZeroPage, Anything, DEC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, 256 - i)), STA), + ) + + val ConstantFlowAnalysis = new RuleBasedAssemblyOptimization("Constant flow analysis", + needsFlowInfo = FlowInfoRequirement.ForwardFlow, + (MatchX(0) & HasAddrMode(AbsoluteX) & SupportsAbsolute & Elidable) ~~> { (code, ctx) => + code.map(l => l.copy(addrMode = Absolute, parameter = l.parameter + ctx.get[Int](0))) + }, + (MatchY(0) & HasAddrMode(AbsoluteY) & SupportsAbsolute & Elidable) ~~> { (code, ctx) => + code.map(l => l.copy(addrMode = Absolute, parameter = l.parameter + ctx.get[Int](0))) + }, + (MatchX(0) & HasAddrMode(ZeroPageX) & Elidable) ~~> { (code, ctx) => + code.map(l => l.copy(addrMode = ZeroPage, parameter = l.parameter + ctx.get[Int](0))) + }, + (MatchY(0) & HasAddrMode(ZeroPageY) & Elidable) ~~> { (code, ctx) => + code.map(l => l.copy(addrMode = ZeroPage, parameter = l.parameter + ctx.get[Int](0))) + }, + ) + + val IdempotentDuplicateRemoval = new RuleBasedAssemblyOptimization("Idempotent duplicate operation", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + HasOpcode(RTS) ~ HasOpcodeIn(NoopDiscardsFlags).* ~ (HasOpcode(RTS) ~ Elidable) ~~> (_.take(1)) :: + HasOpcode(RTI) ~ HasOpcodeIn(NoopDiscardsFlags).* ~ (HasOpcode(RTI) ~ Elidable) ~~> (_.take(1)) :: + HasOpcode(DISCARD_XF) ~ (Not(HasOpcode(DISCARD_XF)) & HasOpcodeIn(NoopDiscardsFlags + LABEL)).* ~ HasOpcode(DISCARD_XF) ~~> (_.tail) :: + HasOpcode(DISCARD_AF) ~ (Not(HasOpcode(DISCARD_AF)) & HasOpcodeIn(NoopDiscardsFlags + LABEL)).* ~ HasOpcode(DISCARD_AF) ~~> (_.tail) :: + HasOpcode(DISCARD_YF) ~ (Not(HasOpcode(DISCARD_YF)) & HasOpcodeIn(NoopDiscardsFlags + LABEL)).* ~ HasOpcode(DISCARD_YF) ~~> (_.tail) :: + List(RTS, RTI, SEC, CLC, CLV, CLD, SED, SEI, CLI, TAX, TXA, TYA, TAY, TXS, TSX).flatMap { opcode => + Seq( + (HasOpcode(opcode) & Elidable) ~ (HasOpcodeIn(NoopDiscardsFlags) | HasOpcode(LABEL)).* ~ HasOpcode(opcode) ~~> (_.tail), + HasOpcode(opcode) ~ (HasOpcode(opcode) ~ Elidable) ~~> (_.init), + ) + }: _* + ) + + val PointlessRegisterTransfers = new RuleBasedAssemblyOptimization("Pointless register transfers", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + HasOpcode(TYA) ~ (Elidable & Set(TYA, TAY)) ~~> (_.init), + HasOpcode(TXA) ~ (Elidable & Set(TXA, TAX)) ~~> (_.init), + HasOpcode(TAY) ~ (Elidable & Set(TYA, TAY)) ~~> (_.init), + HasOpcode(TAX) ~ (Elidable & Set(TXA, TAX)) ~~> (_.init), + HasOpcode(TSX) ~ (Elidable & Set(TXS, TSX)) ~~> (_.init), + HasOpcode(TXS) ~ (Elidable & Set(TXS, TSX)) ~~> (_.init), + HasOpcode(TSX) ~ (Not(ChangesX) & Not(ChangesS) & Linear).* ~ (Elidable & Set(TXS, TSX)) ~~> (_.init), + HasOpcode(TXS) ~ (Not(ChangesX) & Not(ChangesS) & Linear).* ~ (Elidable & Set(TXS, TSX)) ~~> (_.init), + ) + + val PointlessRegisterTransfersBeforeStore = new RuleBasedAssemblyOptimization("Pointless register transfers before store", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + (Elidable & HasOpcode(TXA)) ~ + (Linear & Not(ConcernsA) & Not(ConcernsX)).* ~ + (Elidable & HasOpcode(STA) & HasAddrModeIn(Set(ZeroPage, ZeroPageY, Absolute)) & DoesntMatterWhatItDoesWith(State.A, State.N, State.Z)) ~~> (code => code.tail.init :+ code.last.copy(opcode = STX)), + (Elidable & HasOpcode(TYA)) ~ + (Linear & Not(ConcernsA) & Not(ConcernsY)).* ~ + (Elidable & HasOpcode(STA) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute)) & DoesntMatterWhatItDoesWith(State.A, State.N, State.Z)) ~~> (code => code.tail.init :+ code.last.copy(opcode = STY)), + ) + + + val PointlessRegisterTransfersBeforeReturn = new RuleBasedAssemblyOptimization("Pointless register transfers before return", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (HasOpcode(TAX) & Elidable) ~ + HasOpcode(LABEL).* ~ + HasOpcode(TXA).? ~ + ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_XF)).capture(1) ~ + HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)), + (HasOpcode(TSX) & Elidable) ~ + HasOpcode(LABEL).* ~ + HasOpcode(TSX).? ~ + ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_XF)).capture(1) ~ + HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)), + (HasOpcode(TXA) & Elidable) ~ + HasOpcode(LABEL).* ~ + HasOpcode(TAX).? ~ + ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_AF)).capture(1) ~ + HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)), + (HasOpcode(TAY) & Elidable) ~ + HasOpcode(LABEL).* ~ + HasOpcode(TYA).? ~ + ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_YF)).capture(1) ~ + HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)), + (HasOpcode(TYA) & Elidable) ~ + HasOpcode(LABEL).* ~ + HasOpcode(TAY).? ~ + ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_AF)).capture(1) ~ + HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)), + ) + + val PointlessRegisterTransfersBeforeCompare = new RuleBasedAssemblyOptimization("Pointless register transfers before compare", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + HasOpcodeIn(Set(DEX, INX, LDX, LAX)) ~ + (HasOpcode(TXA) & Elidable & DoesntMatterWhatItDoesWith(State.A)) ~~> (code => code.init), + HasOpcodeIn(Set(DEY, INY, LDY)) ~ + (HasOpcode(TYA) & Elidable & DoesntMatterWhatItDoesWith(State.A)) ~~> (code => code.init), + ) + + private def stashing(tai: Opcode.Value, tia: Opcode.Value, readsI: AssemblyLinePattern, concernsI: AssemblyLinePattern, discardIF: Opcode.Value, withRts: Boolean, withBeq: Boolean) = { + val init: AssemblyPattern = if (withBeq) { + (Linear & ChangesNAndZ & ChangesA) ~ + (HasOpcode(tai) & Elidable) ~ + (Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* ~ + (ShortBranching & ReadsNOrZ & MatchParameter(0)) + } else { + (HasOpcode(tai) & Elidable) ~ + (Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* ~ + ((ShortBranching -- ReadsNOrZ) & MatchParameter(0)) + } + val inner: AssemblyPattern = if (withRts) { + (Linear & Not(readsI) & Not(ReadsNOrZ ++ NoopDiscardsFlags)).* ~ + ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(discardIF)) ~ + HasOpcodeIn(Set(RTS, RTI)) ~ + Not(HasOpcode(LABEL)).* + } else { + (Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* + } + val end: AssemblyPattern = + (HasOpcode(LABEL) & MatchParameter(0)) ~ + (Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* ~ + (HasOpcode(tia) & Elidable) + val total = init ~ inner ~ end + if (withBeq) { + total ~~> (code => code.head :: (code.tail.tail.init :+ AssemblyLine.implied(tai))) + } else { + total ~~> (code => code.tail.init :+ AssemblyLine.implied(tai)) + } + } + + // Optimize the following patterns: + // TAX - B__ .a - don't change A - .a - TXA + // TAX - B__ .a - change A – discard X – RTS - .a - TXA + // by removing the first transfer and flipping the second one + val PointlessStashingToIndexOverShortSafeBranch = new RuleBasedAssemblyOptimization("Pointless stashing into index over short safe branch", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + // stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = false, withBeq = false), + stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = true, withBeq = false), + // stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = false, withBeq = true), + // stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = true, withBeq = true), + // + // stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = false, withBeq = false), + // stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = true, withBeq = false), + // stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = false, withBeq = true), + // stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = true, withBeq = true), + ) + + private def loadBeforeTransfer(ld1: Opcode.Value, ld2: Opcode.Value, concerns1: AssemblyLinePattern, overwrites1: State.Value, t12: Opcode.Value, ams: Set[AddrMode.Value]) = + (Elidable & HasOpcode(ld1) & MatchAddrMode(0) & MatchParameter(1) & HasAddrModeIn(ams)) ~ + (Linear & Not(ReadsNOrZ) & Not(concerns1) & DoesntChangeMemoryAt(0, 1) & DoesntChangeIndexingInAddrMode(0) & Not(HasOpcode(t12))).*.capture(2) ~ + (HasOpcode(t12) & Elidable & DoesntMatterWhatItDoesWith(overwrites1, State.N, State.Z)) ~~> { (code, ctx) => + ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = ld2) + } + + val PointlessLoadBeforeTransfer = new RuleBasedAssemblyOptimization("Pointless load before transfer", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + loadBeforeTransfer(LDX, LDA, ConcernsX, State.X, TXA, Set(ZeroPage, Absolute, IndexedY, AbsoluteY)), + loadBeforeTransfer(LDA, LDX, ConcernsA, State.A, TAX, Set(ZeroPage, Absolute, IndexedY, AbsoluteY)), + loadBeforeTransfer(LDY, LDA, ConcernsY, State.Y, TYA, Set(ZeroPage, Absolute, ZeroPageX, IndexedX, AbsoluteX)), + loadBeforeTransfer(LDA, LDY, ConcernsA, State.A, TAY, Set(ZeroPage, Absolute, ZeroPageX, IndexedX, AbsoluteX)), + ) + + private def immediateLoadBeforeTwoTransfers(ld1: Opcode.Value, ld2: Opcode.Value, concerns1: AssemblyLinePattern, overwrites1: State.Value, t12: Opcode.Value, t21: Opcode.Value) = + (Elidable & HasOpcode(ld1) & HasAddrMode(Immediate)) ~ + (Linear & Not(ReadsNOrZ) & Not(concerns1) & Not(HasOpcode(t12))).*.capture(2) ~ + (HasOpcode(t12) & Elidable & DoesntMatterWhatItDoesWith(overwrites1, State.N, State.Z)) ~~> { (code, ctx) => + ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = ld2) + } + + val YYY = new RuleBasedAssemblyOptimization("Pointless load before transfer", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + immediateLoadBeforeTwoTransfers(LDA, LDY, ConcernsA, State.A, TAY, TYA), + ) + + val ConstantIndexPropagation = new RuleBasedAssemblyOptimization("Constant index propagation", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (HasOpcode(LDX) & HasAddrMode(Immediate) & MatchParameter(0)) ~ + (Linear & Not(ChangesX) & Not(HasAddrMode(AbsoluteX))).* ~ + (Elidable & SupportsAbsolute & HasAddrMode(AbsoluteX)) ~~> { (lines, ctx) => + val last = lines.last + val offset = ctx.get[Constant](0) + lines.init :+ last.copy(addrMode = Absolute, parameter = last.parameter + offset) + }, + (HasOpcode(LDY) & HasAddrMode(Immediate) & MatchParameter(0)) ~ + (Linear & Not(ChangesY) & Not(HasAddrMode(AbsoluteY))).* ~ + (Elidable & SupportsAbsolute & HasAddrMode(AbsoluteY)) ~~> { (lines, ctx) => + val last = lines.last + val offset = ctx.get[Constant](0) + lines.init :+ last.copy(addrMode = Absolute, parameter = last.parameter + offset) + }, + ) + + val PoinlessLoadBeforeAnotherLoad = new RuleBasedAssemblyOptimization("Pointless load before another load", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Set(LDA, TXA, TYA) & Elidable) ~ (LinearOrLabel & Not(ConcernsA) & Not(ReadsNOrZ)).* ~ OverwritesA ~~> (_.tail), + (Set(LDX, TAX, TSX) & Elidable) ~ (LinearOrLabel & Not(ConcernsX) & Not(ReadsNOrZ)).* ~ OverwritesX ~~> (_.tail), + (Set(LDY, TAY) & Elidable) ~ (LinearOrLabel & Not(ConcernsY) & Not(ReadsNOrZ)).* ~ OverwritesY ~~> (_.tail), + ) + + // TODO: better proofs that memory doesn't change + val PointlessLoadAfterLoadOrStore = new RuleBasedAssemblyOptimization("Pointless load after load or store", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + + (HasOpcodeIn(Set(LDA, STA)) & HasAddrMode(Implied) & MatchParameter(1)) ~ + (Linear & Not(ChangesA)).* ~ + (Elidable & HasOpcode(LDA) & HasAddrMode(Implied) & MatchParameter(1)) ~~> (_.init), + + (HasOpcodeIn(Set(LDX, STX)) & HasAddrMode(Implied) & MatchParameter(1)) ~ + (Linear & Not(ChangesX)).* ~ + (Elidable & HasOpcode(LDX) & HasAddrMode(Implied) & MatchParameter(1)) ~~> (_.init), + + (HasOpcodeIn(Set(LDY, STY)) & HasAddrMode(Implied) & MatchParameter(1)) ~ + (Linear & Not(ChangesY)).* ~ + (Elidable & HasOpcode(LDY) & HasAddrMode(Implied) & MatchParameter(1)) ~~> (_.init), + + (HasOpcodeIn(Set(LDA, STA)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0) & DoesntChangeMemoryAt(0, 1)).* ~ + (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init), + + (HasOpcodeIn(Set(LDX, STX)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & Not(ChangesX) & DoesntChangeIndexingInAddrMode(0) & DoesntChangeMemoryAt(0, 1)).* ~ + (Elidable & HasOpcode(LDX) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init), + + (HasOpcodeIn(Set(LDY, STY)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & Not(ChangesY) & DoesntChangeIndexingInAddrMode(0) & DoesntChangeMemoryAt(0, 1)).* ~ + (Elidable & HasOpcode(LDY) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init), + ) + + val PointlessOperationAfterLoad = new RuleBasedAssemblyOptimization("Pointless operation after load", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (ChangesA & ChangesNAndZ) ~ (Elidable & HasOpcode(EOR) & HasImmediate(0)) ~~> (_.init), + (ChangesA & ChangesNAndZ) ~ (Elidable & HasOpcode(ORA) & HasImmediate(0)) ~~> (_.init), + (ChangesA & ChangesNAndZ) ~ (Elidable & HasOpcode(AND) & HasImmediate(0xff)) ~~> (_.init) + ) + + val SimplifiableBitOpsSequence = new RuleBasedAssemblyOptimization("Simplifiable sequence of bit operations", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasOpcode(EOR) & MatchImmediate(0)) ~ + (Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsA)).* ~ + (Elidable & HasOpcode(EOR) & MatchImmediate(1)) ~~> { (lines, ctx) => + lines.init.tail :+ AssemblyLine.immediate(EOR, CompoundConstant(MathOperator.Exor, ctx.get[Constant](0), ctx.get[Constant](1))) + }, + (Elidable & HasOpcode(ORA) & MatchImmediate(0)) ~ + (Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsA)).* ~ + (Elidable & HasOpcode(ORA) & MatchImmediate(1)) ~~> { (lines, ctx) => + lines.init.tail :+ AssemblyLine.immediate(ORA, CompoundConstant(MathOperator.Or, ctx.get[Constant](0), ctx.get[Constant](1))) + }, + (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~ + (Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsA)).* ~ + (Elidable & HasOpcode(AND) & MatchImmediate(1)) ~~> { (lines, ctx) => + lines.init.tail :+ AssemblyLine.immediate(AND, CompoundConstant(MathOperator.And, ctx.get[Constant](0), ctx.get[Constant](1))) + }, + (Elidable & HasOpcode(ANC) & MatchImmediate(0)) ~ + (Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsC) & Not(ReadsA)).* ~ + (Elidable & HasOpcode(ANC) & MatchImmediate(1)) ~~> { (lines, ctx) => + lines.init.tail :+ AssemblyLine.immediate(ANC, CompoundConstant(MathOperator.And, ctx.get[Constant](0), ctx.get[Constant](1))) + }, + ) + + val RemoveNops = new RuleBasedAssemblyOptimization("Removing NOP instructions", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasOpcode(NOP)) ~~> (_ => Nil) + ) + + val RearrangeMath = new RuleBasedAssemblyOptimization("Rearranging math", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasOpcode(LDA) & HasAddrMode(Immediate)) ~ + (Elidable & HasOpcodeIn(Set(CLC, SEC))) ~ + (Elidable & HasOpcode(ADC) & Not(HasAddrMode(Immediate))) ~~> { c => + c.last.copy(opcode = LDA) :: c(1) :: c.head.copy(opcode = ADC) :: Nil + }, + (Elidable & HasOpcode(LDA) & HasAddrMode(Immediate)) ~ + (Elidable & HasOpcodeIn(Set(ADC, EOR, ORA, AND)) & Not(HasAddrMode(Immediate))) ~~> { c => + c.last.copy(opcode = LDA) :: c.head.copy(opcode = c.last.opcode) :: Nil + }, + ) + + private def wordShifting(i: Int, hiFirst: Boolean, hiFromX: Boolean) = { + val ldax = if (hiFromX) LDX else LDA + val stax = if (hiFromX) STX else STA + val restriction = if (hiFromX) Not(ReadsX) else Anything + val originalStart = if (hiFirst) { + (Elidable & HasOpcode(LDA) & MatchParameter(0) & MatchAddrMode(1)) ~ + (Elidable & HasOpcode(STA) & MatchParameter(2) & MatchAddrMode(3) & restriction) ~ + (Elidable & HasOpcode(ldax) & HasImmediate(0)) ~ + (Elidable & HasOpcode(stax) & MatchParameter(4) & MatchAddrMode(5)) + } else { + (Elidable & HasOpcode(ldax) & HasImmediate(0)) ~ + (Elidable & HasOpcode(stax) & MatchParameter(4) & MatchAddrMode(5)) ~ + (Elidable & HasOpcode(LDA) & MatchParameter(0) & MatchAddrMode(1)) ~ + (Elidable & HasOpcode(STA) & MatchParameter(2) & MatchAddrMode(3) & restriction) + } + val middle = (Linear & Not(ConcernsMemory) & DoesntChangeIndexingInAddrMode(3) & DoesntChangeIndexingInAddrMode(5)).* + val singleOriginalShift = + (Elidable & HasOpcode(ASL) & MatchParameter(2) & MatchAddrMode(3)) ~ + (Elidable & HasOpcode(ROL) & MatchParameter(4) & MatchAddrMode(5) & DoesntMatterWhatItDoesWith(State.C, State.N, State.V, State.Z)) + val originalShifting = (1 to i).map(_ => singleOriginalShift).reduce(_ ~ _) + originalStart ~ middle.capture(6) ~ originalShifting ~~> { (code, ctx) => + val newStart = List( + code(0), + code(1).copy(addrMode = code(3).addrMode, parameter = code(3).parameter), + code(2), + code(3).copy(addrMode = code(1).addrMode, parameter = code(1).parameter)) + val middle = ctx.get[List[AssemblyLine]](6) + val singleNewShift = List( + AssemblyLine(LSR, ctx.get[AddrMode.Value](5), ctx.get[Constant](4)), + AssemblyLine(ROR, ctx.get[AddrMode.Value](3), ctx.get[Constant](2))) + newStart ++ middle ++ (i until 8).flatMap(_ => singleNewShift) + } + } + + val SmarterShiftingWords = new RuleBasedAssemblyOptimization("Smarter shifting of words", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + wordShifting(8, hiFirst = false, hiFromX = true), + wordShifting(8, hiFirst = false, hiFromX = false), + wordShifting(8, hiFirst = true, hiFromX = true), + wordShifting(8, hiFirst = true, hiFromX = false), + wordShifting(7, hiFirst = false, hiFromX = true), + wordShifting(7, hiFirst = false, hiFromX = false), + wordShifting(7, hiFirst = true, hiFromX = true), + wordShifting(7, hiFirst = true, hiFromX = false), + wordShifting(6, hiFirst = false, hiFromX = true), + wordShifting(6, hiFirst = false, hiFromX = false), + wordShifting(6, hiFirst = true, hiFromX = true), + wordShifting(6, hiFirst = true, hiFromX = false), + wordShifting(5, hiFirst = false, hiFromX = true), + wordShifting(5, hiFirst = false, hiFromX = false), + wordShifting(5, hiFirst = true, hiFromX = true), + wordShifting(5, hiFirst = true, hiFromX = false), + ) + + private def carryFlagConversionCase(shift: Int, firstSet: Boolean, zeroIfSet: Boolean) = { + val nonZero = 1 << shift + val test = Elidable & HasOpcode(if (firstSet) BCC else BCS) & MatchParameter(0) + val ifSet = Elidable & HasOpcode(LDA) & HasImmediate(if (zeroIfSet) 0 else nonZero) + val ifClear = Elidable & HasOpcode(LDA) & HasImmediate(if (zeroIfSet) nonZero else 0) + val jump = Elidable & HasOpcodeIn(Set(JMP, if (firstSet) BCS else BCC, if (zeroIfSet) BEQ else BNE)) & MatchParameter(1) + val elseLabel = Elidable & HasOpcode(LABEL) & MatchParameter(0) + val afterLabel = Elidable & HasOpcode(LABEL) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.C, State.N, State.V, State.Z) + val store = Elidable & (Not(ReadsC) & Linear | HasOpcodeIn(Set(RTS, JSR, RTI))) + val secondReturn = (Elidable & HasOpcodeIn(Set(RTS, RTI) | NoopDiscardsFlags)).*.capture(6) + val where = Where { ctx => + ctx.get[List[AssemblyLine]](4) == ctx.get[List[AssemblyLine]](5) || + ctx.get[List[AssemblyLine]](4) == ctx.get[List[AssemblyLine]](5) ++ ctx.get[List[AssemblyLine]](6) + } + val pattern = + if (firstSet) test ~ ifSet ~ store.*.capture(4) ~ jump ~ elseLabel ~ ifClear ~ store.*.capture(5) ~ afterLabel ~ secondReturn ~ where + else test ~ ifClear ~ store.*.capture(4) ~ jump ~ elseLabel ~ ifSet ~ store.*.capture(5) ~ afterLabel ~ secondReturn ~ where + pattern ~~> { (_, ctx) => + List( + AssemblyLine.immediate(LDA, 0), + AssemblyLine.implied(if (shift >= 4) ROR else ROL)) ++ + (if (shift >= 4) List.fill(7 - shift)(AssemblyLine.implied(LSR)) else List.fill(shift)(AssemblyLine.implied(ASL))) ++ + (if (zeroIfSet) List(AssemblyLine.immediate(EOR, nonZero)) else Nil) ++ + ctx.get[List[AssemblyLine]](5) ++ + ctx.get[List[AssemblyLine]](6) + } + } + + val CarryFlagConversion = new RuleBasedAssemblyOptimization("Carry flag conversion", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + // TODO: These yield 2 cycles more but 1–2 bytes less + // TODO: Add an "optimize for size" compilation option? + // carryFlagConversionCase(2, firstSet = false, zeroIfSet = false), + // carryFlagConversionCase(2, firstSet = true, zeroIfSet = false), + // carryFlagConversionCase(1, firstSet = true, zeroIfSet = true), + // carryFlagConversionCase(1, firstSet = false, zeroIfSet = true), + carryFlagConversionCase(1, firstSet = false, zeroIfSet = false), + carryFlagConversionCase(1, firstSet = true, zeroIfSet = false), + carryFlagConversionCase(0, firstSet = true, zeroIfSet = true), + carryFlagConversionCase(0, firstSet = false, zeroIfSet = true), + carryFlagConversionCase(0, firstSet = false, zeroIfSet = false), + carryFlagConversionCase(0, firstSet = true, zeroIfSet = false), + // carryFlagConversionCase(5, firstSet = false, zeroIfSet = false), + // carryFlagConversionCase(5, firstSet = true, zeroIfSet = false), + // carryFlagConversionCase(6, firstSet = true, zeroIfSet = true), + // carryFlagConversionCase(6, firstSet = false, zeroIfSet = true), + carryFlagConversionCase(6, firstSet = false, zeroIfSet = false), + carryFlagConversionCase(6, firstSet = true, zeroIfSet = false), + carryFlagConversionCase(7, firstSet = true, zeroIfSet = true), + carryFlagConversionCase(7, firstSet = false, zeroIfSet = true), + carryFlagConversionCase(7, firstSet = false, zeroIfSet = false), + carryFlagConversionCase(7, firstSet = true, zeroIfSet = false), + ) + + val Adc0Optimization = new RuleBasedAssemblyOptimization("ADC #0/#1 optimization", + needsFlowInfo = FlowInfoRequirement.BothFlows, + (Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.D)) ~ + (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + val label = getNextLabel("ah") + List( + AssemblyLine.relative(BCC, label), + code.last.copy(opcode = INC), + AssemblyLine.label(label)) + }, + (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~ + (Elidable & HasOpcode(ADC) & HasImmediate(0) & HasClear(State.D)) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + val label = getNextLabel("ah") + List( + AssemblyLine.relative(BCC, label), + code.last.copy(opcode = INC), + AssemblyLine.label(label)) + }, + (Elidable & HasOpcode(LDA) & HasImmediate(1) & HasClear(State.D) & HasClear(State.C)) ~ + (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + List(code.last.copy(opcode = INC)) + }, + (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & HasClear(State.C) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~ + (Elidable & HasOpcode(ADC) & HasImmediate(1) & HasClear(State.D)) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + List(code.last.copy(opcode = INC)) + }, + (Elidable & HasOpcode(TXA) & HasClear(State.D)) ~ + (Elidable & HasOpcode(ADC) & HasImmediate(0)) ~ + (Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + val label = getNextLabel("ah") + List( + AssemblyLine.relative(BCC, label), + AssemblyLine.implied(INX), + AssemblyLine.label(label)) + }, + (Elidable & HasOpcode(TYA) & HasClear(State.D)) ~ + (Elidable & HasOpcode(ADC) & HasImmediate(0)) ~ + (Elidable & HasOpcode(TAY) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + val label = getNextLabel("ah") + List( + AssemblyLine.relative(BCC, label), + AssemblyLine.implied(INY), + AssemblyLine.label(label)) + }, + (Elidable & HasOpcode(TXA) & HasClear(State.D) & HasClear(State.C)) ~ + (Elidable & HasOpcode(ADC) & HasImmediate(1)) ~ + (Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + List(AssemblyLine.implied(INX)) + }, + (Elidable & HasOpcode(TYA) & HasClear(State.D) & HasClear(State.C)) ~ + (Elidable & HasOpcode(ADC) & HasImmediate(1)) ~ + (Elidable & HasOpcode(TAY) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + List(AssemblyLine.implied(INY)) + }, + ) + + val IndexSequenceOptimization = new RuleBasedAssemblyOptimization("Index sequence optimization", + needsFlowInfo = FlowInfoRequirement.ForwardFlow, + (Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~ + Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0))) ~~> (_ => Nil), + (Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~ + Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)+1)) ~~> (_ => List(AssemblyLine.implied(INY))), + (Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~ + Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)-1)) ~~> (_ => List(AssemblyLine.implied(DEY))), + (Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~ + Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0))) ~~> (_ => Nil), + (Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~ + Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)+1)) ~~> (_ => List(AssemblyLine.implied(INX))), + (Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~ + Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)-1)) ~~> (_ => List(AssemblyLine.implied(DEX))), + ) + + +} diff --git a/src/main/scala/millfork/assembly/opt/AssemblyOptimization.scala b/src/main/scala/millfork/assembly/opt/AssemblyOptimization.scala new file mode 100644 index 00000000..5ab931fd --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/AssemblyOptimization.scala @@ -0,0 +1,14 @@ +package millfork.assembly.opt + +import millfork.CompilationOptions +import millfork.assembly.AssemblyLine +import millfork.env.NormalFunction + +/** + * @author Karol Stasiak + */ +trait AssemblyOptimization { + def name: String + + def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] +} diff --git a/src/main/scala/millfork/assembly/opt/ChangeIndexRegisterOptimization.scala b/src/main/scala/millfork/assembly/opt/ChangeIndexRegisterOptimization.scala new file mode 100644 index 00000000..d67f71c0 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/ChangeIndexRegisterOptimization.scala @@ -0,0 +1,155 @@ +package millfork.assembly.opt + +import millfork.CompilationOptions +import millfork.assembly.{AssemblyLine, OpcodeClasses} +import millfork.env.NormalFunction +import millfork.error.ErrorReporting + +/** + * @author Karol Stasiak + */ + +object ChangeIndexRegisterOptimizationPreferringX2Y extends ChangeIndexRegisterOptimization(true) +object ChangeIndexRegisterOptimizationPreferringY2X extends ChangeIndexRegisterOptimization(false) + +class ChangeIndexRegisterOptimization(preferX2Y: Boolean) extends AssemblyOptimization { + + object IndexReg extends Enumeration { + val X, Y = Value + } + + object IndexDirection extends Enumeration { + val X2Y, Y2X = Value + } + + import IndexReg._ + import IndexDirection._ + import millfork.assembly.AddrMode._ + import millfork.assembly.Opcode._ + + type IndexReg = IndexReg.Value + type IndexDirection = IndexDirection.Value + + override def name = "Changing index registers" + + override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = { + val usesIndex = code.exists(l => + OpcodeClasses.ReadsXAlways(l.opcode) || + OpcodeClasses.ReadsYAlways(l.opcode) || + OpcodeClasses.ChangesX(l.opcode) || + OpcodeClasses.ChangesY(l.opcode) || + Set(AbsoluteX, AbsoluteY, ZeroPageY, ZeroPageX, IndexedX, IndexedY)(l.addrMode) + ) + if (!usesIndex) { + return code + } + val canX2Y = f.returnType.size <= 1 && canOptimize(code, X2Y, None) + val canY2X = canOptimize(code, Y2X, None) + (canX2Y, canY2X) match { + case (false, false) => code + case (true, false) => + ErrorReporting.debug("Changing index register from X to Y") + switchX2Y(code) + case (false, true) => + ErrorReporting.debug("Changing index register from X to Y") + switchY2X(code) + case (true, true) => + if (preferX2Y) { + ErrorReporting.debug("Changing index register from X to Y (arbitrarily)") + switchX2Y(code) + } else { + ErrorReporting.debug("Changing index register from Y to X (arbitrarily)") + switchY2X(code) + } + } + } + + //noinspection OptionEqualsSome + private def canOptimize(code: List[AssemblyLine], dir: IndexDirection, loaded: Option[IndexReg]): Boolean = code match { + case AssemblyLine(_, AbsoluteY, _, _) :: xs if loaded != Some(Y) => false + case AssemblyLine(_, ZeroPageY, _, _) :: xs if loaded != Some(Y) => false + case AssemblyLine(_, IndexedX, _, _) :: xs if dir == X2Y || loaded != Some(Y) => false + case AssemblyLine(_, AbsoluteX, _, _) :: xs if loaded != Some(X) => false + case AssemblyLine(_, ZeroPageX, _, _) :: xs if loaded != Some(X) => false + case AssemblyLine(_, IndexedY, _, _) :: xs if dir == Y2X || loaded != Some(Y) => false + + // using a wrong index register for one instruction is fine + case AssemblyLine(LDY | TAY, _, _, _) :: AssemblyLine(_, IndexedY, _, _) :: xs if dir == Y2X => + canOptimize(xs, dir, None) + case AssemblyLine(LDX | TAX, _, _, _) :: AssemblyLine(_, IndexedX, _, _) :: xs if dir == X2Y => + canOptimize(xs, dir, None) + case AssemblyLine(LDX | TAX, _, _, _) :: AssemblyLine(INC | DEC | ASL | ROL | ROR | LSR | STZ, AbsoluteX | ZeroPageX, _, _) :: xs if dir == X2Y => + canOptimize(xs, dir, None) + + case AssemblyLine(INC | DEC | ASL | ROL | ROR | LSR | STZ, AbsoluteX | ZeroPageX, _, _) :: xs if dir == X2Y => false + + case AssemblyLine(LAX, _, _, _) :: xs => false + case AssemblyLine(JSR, _, _, _) :: xs => false // TODO + case AssemblyLine(JMP, _, _, _) :: xs => canOptimize(xs, dir, None) + case AssemblyLine(op, _, _, _) :: xs if OpcodeClasses.ShortBranching(op) => canOptimize(xs, dir, None) + case AssemblyLine(RTS, _, _, _) :: xs => canOptimize(xs, dir, None) + case AssemblyLine(LABEL, _, _, _) :: xs => canOptimize(xs, dir, None) + case AssemblyLine(DISCARD_XF, _, _, _) :: xs => canOptimize(xs, dir, loaded.filter(_ != X)) + case AssemblyLine(DISCARD_YF, _, _, _) :: xs => canOptimize(xs, dir, loaded.filter(_ != Y)) + case AssemblyLine(_, DoesNotExist, _, _) :: xs => canOptimize(xs, dir, loaded) + + case AssemblyLine(TAX | LDX | PLX, _, _, e) :: xs => + (e || dir == Y2X) && canOptimize(xs, dir, Some(X)) + case AssemblyLine(TAY | LDY | PLY, _, _, e) :: xs => + (e || dir == X2Y) && canOptimize(xs, dir, Some(Y)) + case AssemblyLine(TXA | STX | PHX | CPX | INX | DEX, _, _, e) :: xs => + (e || dir == Y2X) && loaded == Some(X) && canOptimize(xs, dir, Some(X)) + case AssemblyLine(TYA | STY | PHY | CPY | INY | DEY, _, _, e) :: xs => + (e || dir == X2Y) && loaded == Some(Y) && canOptimize(xs, dir, Some(Y)) + + case AssemblyLine(SAX | TXS | SBX, _, _, _) :: xs => dir == Y2X && loaded == Some(X) && canOptimize(xs, dir, Some(X)) + case AssemblyLine(TSX, _, _, _) :: xs => dir == Y2X && loaded != Some(Y) && canOptimize(xs, dir, Some(X)) + + case _ :: xs => canOptimize(xs, dir, loaded) + + case Nil => true + } + + private def switchX2Y(code: List[AssemblyLine]): List[AssemblyLine] = code match { + case (a@AssemblyLine(LDX | TAX, _, _, _)) :: (b@AssemblyLine(INC | DEC | ASL | ROL | ROR | LSR | STZ, AbsoluteX | ZeroPageX, _, _)) :: xs => a :: b :: switchX2Y(xs) + case (a@AssemblyLine(LDX | TAX, _, _, _)) :: (b@AssemblyLine(_, IndexedX, _, _)) :: xs => a :: b :: switchX2Y(xs) + case (x@AssemblyLine(TAX, _, _, _)) :: xs => x.copy(opcode = TAY) :: switchX2Y(xs) + case (x@AssemblyLine(TXA, _, _, _)) :: xs => x.copy(opcode = TYA) :: switchX2Y(xs) + case (x@AssemblyLine(STX, _, _, _)) :: xs => x.copy(opcode = STY) :: switchX2Y(xs) + case (x@AssemblyLine(LDX, _, _, _)) :: xs => x.copy(opcode = LDY) :: switchX2Y(xs) + case (x@AssemblyLine(INX, _, _, _)) :: xs => x.copy(opcode = INY) :: switchX2Y(xs) + case (x@AssemblyLine(DEX, _, _, _)) :: xs => x.copy(opcode = DEY) :: switchX2Y(xs) + case (x@AssemblyLine(CPX, _, _, _)) :: xs => x.copy(opcode = CPY) :: switchX2Y(xs) + + case AssemblyLine(LAX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected LAX") + case AssemblyLine(TXS, _, _, _) :: xs => ErrorReporting.fatal("Unexpected TXS") + case AssemblyLine(TSX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected TSX") + case AssemblyLine(SBX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected SBX") + case AssemblyLine(SAX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected SAX") + + case (x@AssemblyLine(_, AbsoluteX, _, _)) :: xs => x.copy(addrMode = AbsoluteY) :: switchX2Y(xs) + case (x@AssemblyLine(_, ZeroPageX, _, _)) :: xs => x.copy(addrMode = ZeroPageY) :: switchX2Y(xs) + case (x@AssemblyLine(_, IndexedX, _, _)) :: xs => ErrorReporting.fatal("Unexpected IndexedX") + + case x::xs => x :: switchX2Y(xs) + case Nil => Nil + } + + private def switchY2X(code: List[AssemblyLine]): List[AssemblyLine] = code match { + case AssemblyLine(LDY | TAY, _, _, _) :: AssemblyLine(_, IndexedY, _, _) :: xs => code.take(2) ++ switchY2X(xs) + case (x@AssemblyLine(TAY, _, _, _)) :: xs => x.copy(opcode = TAX) :: switchY2X(xs) + case (x@AssemblyLine(TYA, _, _, _)) :: xs => x.copy(opcode = TXA) :: switchY2X(xs) + case (x@AssemblyLine(STY, _, _, _)) :: xs => x.copy(opcode = STX) :: switchY2X(xs) + case (x@AssemblyLine(LDY, _, _, _)) :: xs => x.copy(opcode = LDX) :: switchY2X(xs) + case (x@AssemblyLine(INY, _, _, _)) :: xs => x.copy(opcode = INX) :: switchY2X(xs) + case (x@AssemblyLine(DEY, _, _, _)) :: xs => x.copy(opcode = DEX) :: switchY2X(xs) + case (x@AssemblyLine(CPY, _, _, _)) :: xs => x.copy(opcode = CPX) :: switchY2X(xs) + + case (x@AssemblyLine(_, AbsoluteY, _, _)) :: xs => x.copy(addrMode = AbsoluteX) :: switchY2X(xs) + case (x@AssemblyLine(_, ZeroPageY, _, _)) :: xs => x.copy(addrMode = ZeroPageX) :: switchY2X(xs) + case AssemblyLine(_, IndexedY, _, _) :: xs => ErrorReporting.fatal("Unexpected IndexedY") + + case x::xs => x :: switchY2X(xs) + case Nil => Nil + } +} diff --git a/src/main/scala/millfork/assembly/opt/CmosOptimizations.scala b/src/main/scala/millfork/assembly/opt/CmosOptimizations.scala new file mode 100644 index 00000000..d1b4ca60 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/CmosOptimizations.scala @@ -0,0 +1,36 @@ +package millfork.assembly.opt + +import millfork.assembly.{AssemblyLine, Opcode} +import millfork.assembly.Opcode._ +import millfork.assembly.AddrMode._ +import millfork.assembly.OpcodeClasses._ +import millfork.env.{Constant, NormalFunction} + +/** + * @author Karol Stasiak + */ +object CmosOptimizations { + + val StzAddrModes = Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX) + + val ZeroStoreAsStz = new RuleBasedAssemblyOptimization("Zero store", + needsFlowInfo = FlowInfoRequirement.ForwardFlow, + (HasA(0) & HasOpcode(STA) & Elidable & HasAddrModeIn(StzAddrModes)) ~~> {code => + code.head.copy(opcode = STZ) :: Nil + }, + (HasX(0) & HasOpcode(STX) & Elidable & HasAddrModeIn(StzAddrModes)) ~~> {code => + code.head.copy(opcode = STZ) :: Nil + }, + (HasY(0) & HasOpcode(STY) & Elidable & HasAddrModeIn(StzAddrModes)) ~~> {code => + code.head.copy(opcode = STZ) :: Nil + }, + ) + + val OptimizeZeroIndex = new RuleBasedAssemblyOptimization("Optimizing zero index", + needsFlowInfo = FlowInfoRequirement.ForwardFlow, + (Elidable & HasY(0) & HasAddrMode(IndexedY) & HasOpcodeIn(SupportsZeroPageIndirect)) ~~> (code => code.map(_.copy(addrMode = ZeroPageIndirect))), + (Elidable & HasX(0) & HasAddrMode(IndexedX) & HasOpcodeIn(SupportsZeroPageIndirect)) ~~> (code => code.map(_.copy(addrMode = ZeroPageIndirect))), + ) + + val All: List[AssemblyOptimization] = List(ZeroStoreAsStz) +} diff --git a/src/main/scala/millfork/assembly/opt/CoarseFlowAnalyzer.scala b/src/main/scala/millfork/assembly/opt/CoarseFlowAnalyzer.scala new file mode 100644 index 00000000..7e21009c --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/CoarseFlowAnalyzer.scala @@ -0,0 +1,259 @@ +package millfork.assembly.opt + +import millfork.assembly.{AssemblyLine, OpcodeClasses, State} +import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant} + +import scala.collection.immutable + +/** + * @author Karol Stasiak + */ + +sealed trait Status[T] { + def contains(value: T): Boolean + + def ~(that: Status[T]): Status[T] = { + (this, that) match { + case (AnyStatus(), _) => AnyStatus() + case (_, AnyStatus()) => AnyStatus() + case (SingleStatus(x), SingleStatus(y)) => if (x == y) SingleStatus(x) else AnyStatus() + case (SingleStatus(x), UnknownStatus()) => SingleStatus(x) + case (UnknownStatus(), SingleStatus(x)) => SingleStatus(x) + case (UnknownStatus(), UnknownStatus()) => UnknownStatus() + } + } + +} + +object Status { + + implicit class IntStatusOps(val inner: Status[Int]) extends AnyVal { + def map[T](f: Int => T): Status[T] = inner match { + case SingleStatus(x) => SingleStatus(f(x)) + case _ => AnyStatus() + } + + def z(f: Int => Int = identity): Status[Boolean] = inner match { + case SingleStatus(x) => + val y = f(x) & 0xff + SingleStatus(y == 0) + case _ => AnyStatus() + } + + def n(f: Int => Int = identity): Status[Boolean] = inner match { + case SingleStatus(x) => + val y = f(x) & 0xff + SingleStatus(y >= 0x80) + case _ => AnyStatus() + } + } + +} + + +case class SingleStatus[T](t: T) extends Status[T] { + override def contains(value: T): Boolean = t == value + + override def toString: String = t match { + case true => "1" + case false => "0" + case _ => t.toString + } +} + +case class UnknownStatus[T]() extends Status[T] { + override def contains(value: T) = false + + override def toString: String = "_" +} + +case class AnyStatus[T]() extends Status[T] { + override def contains(value: T) = false + + override def toString: String = "#" +} +//noinspection RedundantNewCaseClass +case class CpuStatus(a: Status[Int] = UnknownStatus(), + x: Status[Int] = UnknownStatus(), + y: Status[Int] = UnknownStatus(), + z: Status[Boolean] = UnknownStatus(), + n: Status[Boolean] = UnknownStatus(), + c: Status[Boolean] = UnknownStatus(), + v: Status[Boolean] = UnknownStatus(), + d: Status[Boolean] = UnknownStatus(), + ) { + + override def toString: String = s"A=$a,X=$x,Y=$y,Z=$z,N=$n,C=$c,V=$v,D=$d" + + def nz: CpuStatus = + this.copy(n = AnyStatus(), z = AnyStatus()) + + def nz(i: Long): CpuStatus = + this.copy(n = SingleStatus((i & 0x80) != 0), z = SingleStatus((i & 0xff) == 0)) + + def ~(that: CpuStatus) = new CpuStatus( + a = this.a ~ that.a, + x = this.x ~ that.x, + y = this.y ~ that.y, + z = this.z ~ that.z, + n = this.n ~ that.n, + c = this.c ~ that.c, + v = this.v ~ that.v, + d = this.d ~ that.d, + ) + + def hasClear(state: State.Value): Boolean = state match { + case State.A => a.contains(0) + case State.X => x.contains(0) + case State.Y => y.contains(0) + case State.Z => z.contains(false) + case State.N => n.contains(false) + case State.C => c.contains(false) + case State.V => v.contains(false) + case State.D => d.contains(false) + } + + def hasSet(state: State.Value): Boolean = state match { + case State.A => false + case State.X => false + case State.Y => false + case State.Z => z.contains(true) + case State.N => n.contains(true) + case State.C => c.contains(true) + case State.V => v.contains(true) + case State.D => d.contains(true) + } +} + +object CoarseFlowAnalyzer { + //noinspection RedundantNewCaseClass + def analyze(f: NormalFunction, code: List[AssemblyLine]): List[CpuStatus] = { + val flagArray = Array.fill[CpuStatus](code.length)(CpuStatus()) + val codeArray = code.toArray + val initialStatus = new CpuStatus(d = SingleStatus(false)) + + var changed = true + while (changed) { + changed = false + var currentStatus: CpuStatus = if (f.interrupt) CpuStatus() else initialStatus + for (i <- codeArray.indices) { + import millfork.assembly.Opcode._ + import millfork.assembly.AddrMode._ + if (flagArray(i) != currentStatus) { + changed = true + flagArray(i) = currentStatus + } + codeArray(i) match { + case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) => + val L = l + currentStatus = codeArray.indices.flatMap(j => codeArray(j) match { + case AssemblyLine(_, _, MemoryAddressConstant(Label(L)), _) => Some(flagArray(j)) + case _ => None + }).fold(CpuStatus())(_ ~ _) + + case AssemblyLine(BCC, _, _, _) => + currentStatus = currentStatus.copy(c = currentStatus.c ~ SingleStatus(true)) + case AssemblyLine(BCS, _, _, _) => + currentStatus = currentStatus.copy(c = currentStatus.c ~ SingleStatus(false)) + case AssemblyLine(BVS, _, _, _) => + currentStatus = currentStatus.copy(v = currentStatus.v ~ SingleStatus(false)) + case AssemblyLine(BVC, _, _, _) => + currentStatus = currentStatus.copy(v = currentStatus.v ~ SingleStatus(true)) + case AssemblyLine(BMI, _, _, _) => + currentStatus = currentStatus.copy(n = currentStatus.n ~ SingleStatus(false)) + case AssemblyLine(BPL, _, _, _) => + currentStatus = currentStatus.copy(n = currentStatus.n ~ SingleStatus(true)) + case AssemblyLine(BEQ, _, _, _) => + currentStatus = currentStatus.copy(z = currentStatus.z ~ SingleStatus(false)) + case AssemblyLine(BNE, _, _, _) => + currentStatus = currentStatus.copy(z = currentStatus.z ~ SingleStatus(true)) + + case AssemblyLine(SED, _, _, _) => + currentStatus = currentStatus.copy(d = SingleStatus(true)) + case AssemblyLine(SEC, _, _, _) => + currentStatus = currentStatus.copy(c = SingleStatus(true)) + case AssemblyLine(CLD, _, _, _) => + currentStatus = currentStatus.copy(d = SingleStatus(false)) + case AssemblyLine(CLC, _, _, _) => + currentStatus = currentStatus.copy(c = SingleStatus(false)) + case AssemblyLine(CLV, _, _, _) => + currentStatus = currentStatus.copy(v = SingleStatus(false)) + + case AssemblyLine(JSR, _, _, _) => + currentStatus = initialStatus + + case AssemblyLine(LDX, Immediate, NumericConstant(nn, _), _) => + val n = nn.toInt + currentStatus = currentStatus.nz(n).copy(x = SingleStatus(n)) + case AssemblyLine(LDY, Immediate, NumericConstant(nn, _), _) => + val n = nn.toInt + currentStatus = currentStatus.nz(n).copy(y = SingleStatus(n)) + case AssemblyLine(LDA, Immediate, NumericConstant(nn, _), _) => + val n = nn.toInt + currentStatus = currentStatus.nz(n).copy(a = SingleStatus(n)) + case AssemblyLine(LAX, Immediate, NumericConstant(nn, _), _) => + val n = nn.toInt + currentStatus = currentStatus.nz(n).copy(a = SingleStatus(n), x = SingleStatus(n)) + + case AssemblyLine(EOR, Immediate, NumericConstant(nn, _), _) => + val n = nn.toInt + currentStatus = currentStatus.copy(n = currentStatus.a.n(_ ^ n), z = currentStatus.a.z(_ ^ n), a = currentStatus.a.map(_ ^ n)) + case AssemblyLine(AND, Immediate, NumericConstant(nn, _), _) => + val n = nn.toInt + currentStatus = currentStatus.copy(n = currentStatus.a.n(_ & n), z = currentStatus.a.z(_ & n), a = currentStatus.a.map(_ & n)) + case AssemblyLine(ANC, Immediate, NumericConstant(nn, _), _) => + val n = nn.toInt + currentStatus = currentStatus.copy(n = currentStatus.a.n(_ & n), c = currentStatus.a.n(_ & n), z = currentStatus.x.z(_ & n), a = currentStatus.a.map(_ & n)) + case AssemblyLine(ORA, Immediate, NumericConstant(nn, _), _) => + val n = nn.toInt + currentStatus = currentStatus.copy(n = currentStatus.a.n(_ | n), z = currentStatus.a.z(_ | n), a = currentStatus.a.map(_ | n)) + case AssemblyLine(ALR, Immediate, NumericConstant(nn, _), _) => + val n = nn.toInt + currentStatus = currentStatus.copy( + n = currentStatus.a.n(i => (i & n & 0xff) >> 1), + z = currentStatus.a.z(i => (i & n & 0xff) >> 1), + c = currentStatus.a.map(i => (i & n & 1) == 0), + a = currentStatus.a.map(i => (i & n & 0xff) >> 1)) + + case AssemblyLine(INX, Implied, _, _) => + currentStatus = currentStatus.copy(n = currentStatus.x.n(_ + 1), z = currentStatus.x.z(_ + 1), x = currentStatus.x.map(_ + 1)) + case AssemblyLine(DEX, Implied, _, _) => + currentStatus = currentStatus.copy(n = currentStatus.x.n(_ - 1), z = currentStatus.x.z(_ - 1), x = currentStatus.x.map(_ - 1)) + case AssemblyLine(INY, Implied, _, _) => + currentStatus = currentStatus.copy(n = currentStatus.y.n(_ + 1), z = currentStatus.y.z(_ + 1), y = currentStatus.y.map(_ + 1)) + case AssemblyLine(DEY, Implied, _, _) => + currentStatus = currentStatus.copy(n = currentStatus.y.n(_ - 1), z = currentStatus.y.z(_ - 1), y = currentStatus.y.map(_ - 1)) + case AssemblyLine(TAX, _, _, _) => + currentStatus = currentStatus.copy(x = currentStatus.a, n = currentStatus.a.n(), z = currentStatus.a.z()) + case AssemblyLine(TXA, _, _, _) => + currentStatus = currentStatus.copy(a = currentStatus.x, n = currentStatus.x.n(), z = currentStatus.x.z()) + case AssemblyLine(TAY, _, _, _) => + currentStatus = currentStatus.copy(y = currentStatus.a, n = currentStatus.a.n(), z = currentStatus.a.z()) + case AssemblyLine(TYA, _, _, _) => + currentStatus = currentStatus.copy(a = currentStatus.y, n = currentStatus.y.n(), z = currentStatus.y.z()) + + case AssemblyLine(opcode, addrMode, parameter, _) => + if (OpcodeClasses.ChangesX(opcode)) currentStatus = currentStatus.copy(x = AnyStatus()) + if (OpcodeClasses.ChangesY(opcode)) currentStatus = currentStatus.copy(y = AnyStatus()) + if (OpcodeClasses.ChangesAAlways(opcode)) currentStatus = currentStatus.copy(a = AnyStatus()) + if (addrMode == Implied && OpcodeClasses.ChangesAIfImplied(opcode)) currentStatus = currentStatus.copy(a = AnyStatus()) + if (OpcodeClasses.ChangesNAndZ(opcode)) currentStatus = currentStatus.nz + if (OpcodeClasses.ChangesC(opcode)) currentStatus = currentStatus.copy(c = AnyStatus()) + if (OpcodeClasses.ChangesV(opcode)) currentStatus = currentStatus.copy(v = AnyStatus()) + if (opcode == CMP || opcode == CPX || opcode == CPY) { + if (addrMode == Immediate) parameter match { + case NumericConstant(0, _) => currentStatus = currentStatus.copy(c = SingleStatus(true)) + case _ => () + } + } + } + } +// flagArray.zip(codeArray).foreach{ +// case (fl, y) => if (y.isPrintable) println(f"$fl%-32s $y%-32s") +// } +// println("---------------------") + } + + flagArray.toList + } +} diff --git a/src/main/scala/millfork/assembly/opt/DangerousOptimizations.scala b/src/main/scala/millfork/assembly/opt/DangerousOptimizations.scala new file mode 100644 index 00000000..649bd4e1 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/DangerousOptimizations.scala @@ -0,0 +1,59 @@ +package millfork.assembly.opt + +import millfork.assembly._ +import millfork.assembly.Opcode._ +import millfork.assembly.AddrMode._ +import millfork.env._ + +/** + * @author Karol Stasiak + */ +object DangerousOptimizations { + + val ConstantIndexOffsetPropagation = new RuleBasedAssemblyOptimization("Constant index offset propagation", + // TODO: try to guess when overflow can happen + needsFlowInfo = FlowInfoRequirement.BothFlows, + (Elidable & HasOpcode(CLC)).? ~ + (Elidable & HasClear(State.C) & HasOpcode(ADC) & MatchImmediate(0) & DoesntMatterWhatItDoesWith(State.V, State.C)) ~ + ( + (HasOpcode(TAY) & DoesntMatterWhatItDoesWith(State.N, State.Z, State.A)) ~ + (Linear & Not(ConcernsY)).* + ).capture(1) ~ + (Elidable & HasAddrMode(AbsoluteY) & DoesntMatterWhatItDoesWith(State.Y)) ~~> { (code, ctx) => + val last = code.last + ctx.get[List[AssemblyLine]](1) :+ last.copy(parameter = last.parameter.+(ctx.get[Constant](0)).quickSimplify) + }, + (Elidable & HasOpcode(CLC)).? ~ + (Elidable & HasClear(State.C) & HasOpcode(ADC) & MatchImmediate(0) & DoesntMatterWhatItDoesWith(State.V, State.C)) ~ + ( + (HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.N, State.Z, State.A)) ~ + (Linear & Not(ConcernsX)).* + ).capture(1) ~ + (Elidable & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.X)) ~~> { (code, ctx) => + val last = code.last + ctx.get[List[AssemblyLine]](1) :+ last.copy(parameter = last.parameter.+(ctx.get[Constant](0)).quickSimplify) + }, + (Elidable & HasOpcode(INY) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~ + (Elidable & HasAddrMode(AbsoluteY) & DoesntMatterWhatItDoesWith(State.Y)) ~~> { (code, ctx) => + val last = code.last + List(last.copy(parameter = last.parameter.+(1).quickSimplify)) + }, + (Elidable & HasOpcode(DEY) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~ + (Elidable & HasAddrMode(AbsoluteY) & DoesntMatterWhatItDoesWith(State.Y)) ~~> { (code, ctx) => + val last = code.last + List(last.copy(parameter = last.parameter.+(-1).quickSimplify)) + }, + (Elidable & HasOpcode(INX) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~ + (Elidable & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.X)) ~~> { (code, ctx) => + val last = code.last + List(last.copy(parameter = last.parameter.+(1).quickSimplify)) + }, + (Elidable & HasOpcode(DEX) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~ + (Elidable & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.X)) ~~> { (code, ctx) => + val last = code.last + List(last.copy(parameter = last.parameter.+(-1).quickSimplify)) + }, + ) + + val All: List[AssemblyOptimization] = List(ConstantIndexOffsetPropagation) +} diff --git a/src/main/scala/millfork/assembly/opt/FlowAnalyzer.scala b/src/main/scala/millfork/assembly/opt/FlowAnalyzer.scala new file mode 100644 index 00000000..eaeb718f --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/FlowAnalyzer.scala @@ -0,0 +1,34 @@ +package millfork.assembly.opt + +import millfork.{CompilationFlag, CompilationOptions} +import millfork.assembly.{AssemblyLine, State} +import millfork.env.NormalFunction + +/** + * @author Karol Stasiak + */ + +case class FlowInfo(statusBefore: CpuStatus, importanceAfter: CpuImportance) { + + def hasClear(state: State.Value): Boolean = statusBefore.hasClear(state) + + def hasSet(state: State.Value): Boolean = statusBefore.hasSet(state) + + def isUnimportant(state: State.Value): Boolean = importanceAfter.isUnimportant(state) +} + +object FlowInfo { + val Default = FlowInfo(CpuStatus(), CpuImportance()) +} + +object FlowAnalyzer { + def analyze(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[(FlowInfo, AssemblyLine)] = { + val forwardFlow = if (options.flag(CompilationFlag.DetailedFlowAnalysis)) { + QuantumFlowAnalyzer.analyze(f, code).map(_.collapse) + } else { + CoarseFlowAnalyzer.analyze(f, code) + } + val reverseFlow = ReverseFlowAnalyzer.analyze(f, code) + forwardFlow.zip(reverseFlow).map{case (s,i) => FlowInfo(s,i)}.zip(code) + } +} diff --git a/src/main/scala/millfork/assembly/opt/LaterOptimizations.scala b/src/main/scala/millfork/assembly/opt/LaterOptimizations.scala new file mode 100644 index 00000000..3a1b2095 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/LaterOptimizations.scala @@ -0,0 +1,242 @@ +package millfork.assembly.opt + +import millfork.assembly.{AddrMode, AssemblyLine, Opcode, State} +import millfork.assembly.Opcode._ +import millfork.assembly.AddrMode._ +import millfork.assembly.OpcodeClasses._ +import millfork.env.{Constant, NormalFunction, NumericConstant} + +/** + * These optimizations help on their own, but may prevent other optimizations from triggering. + * + * @author Karol Stasiak + */ +object LaterOptimizations { + + + // This optimization tends to prevent later Variable To Register Optimization, + // so run this only after it's pretty sure V2RO won't happen any more + val DoubleLoadToDifferentRegisters = new RuleBasedAssemblyOptimization("Double load to different registers", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + TwoDifferentLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA), LDX, TAX), + TwoDifferentLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA), LDY, TAY), + TwoDifferentLoadsWithNoFlagChangeInBetween(LDX, Not(ChangesX), LDA, TXA), + TwoDifferentLoadsWithNoFlagChangeInBetween(LDY, Not(ChangesY), LDA, TYA), + TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDA, Not(ChangesA), LDX, TAX), + TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDA, Not(ChangesA), LDY, TAY), + TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDX, Not(ChangesX), LDA, TXA), + TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDY, Not(ChangesY), LDA, TYA), + ) + + private def TwoDifferentLoadsWithNoFlagChangeInBetween(opcode1: Opcode.Value, middle: AssemblyLinePattern, opcode2: Opcode.Value, transferOpcode: Opcode.Value) = { + (HasOpcode(opcode1) & MatchAddrMode(0) & MatchParameter(1)) ~ + (LinearOrLabel & Not(ChangesMemory) & middle & Not(HasOpcode(opcode2))).* ~ + (HasOpcode(opcode2) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { c => + c.init :+ AssemblyLine.implied(transferOpcode) + } + } + + private def TwoDifferentLoadsWhoseFlagsWillNotBeChecked(opcode1: Opcode.Value, middle: AssemblyLinePattern, opcode2: Opcode.Value, transferOpcode: Opcode.Value) = { + ((HasOpcode(opcode1) & MatchAddrMode(0) & MatchParameter(1)) ~ + (LinearOrLabel & Not(ChangesMemory) & middle & Not(HasOpcode(opcode2))).*).capture(2) ~ + (HasOpcode(opcode2) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~ + ((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(3) ~~> { (_, ctx) => + ctx.get[List[AssemblyLine]](2) ++ (AssemblyLine.implied(transferOpcode) :: ctx.get[List[AssemblyLine]](3)) + } + } + + private def TwoIdenticalLoadsWithNoFlagChangeInBetween(opcode: Opcode.Value, middle: AssemblyLinePattern) = { + (HasOpcode(opcode) & MatchAddrMode(0) & MatchParameter(1)) ~ + (LinearOrLabel & Not(ChangesMemory) & middle & Not(ChangesNAndZ)).* ~ + (HasOpcode(opcode) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { c => + c.init + } + } + + private def TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(opcode: Opcode.Value, middle: AssemblyLinePattern) = { + (HasOpcode(opcode) & HasAddrMode(Immediate) & MatchParameter(1)) ~ + (LinearOrLabel & middle & Not(ChangesNAndZ)).* ~ + (HasOpcode(opcode) & Elidable & HasAddrMode(Immediate) & MatchParameter(1)) ~~> { c => + c.init + } + } + + private def TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(opcode: Opcode.Value, middle: AssemblyLinePattern) = { + ((HasOpcode(opcode) & MatchAddrMode(0) & MatchParameter(1)) ~ + (LinearOrLabel & Not(ChangesMemory) & middle).*).capture(2) ~ + (HasOpcode(opcode) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~ + ((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(3) ~~> { (_, ctx) => + ctx.get[List[AssemblyLine]](2) ++ ctx.get[List[AssemblyLine]](3) + } + } + + //noinspection ZeroIndexToHead + private def InterleavedImmediateLoads(load: Opcode.Value, store: Opcode.Value) = { + (Elidable & HasOpcode(load) & MatchImmediate(0)) ~ + (Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(8)) ~ + (Elidable & HasOpcode(load) & MatchImmediate(1)) ~ + (Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(9) & DontMatchParameter(8)) ~ + (Elidable & HasOpcode(load) & MatchImmediate(0)) ~~> { c => + List(c(2), c(3), c(0), c(1)) + } + } + + //noinspection ZeroIndexToHead + private def InterleavedAbsoluteLoads(load: Opcode.Value, store: Opcode.Value) = { + (Elidable & HasOpcode(load) & HasAddrMode(Absolute) & MatchParameter(0)) ~ + (Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(8) & DontMatchParameter(0)) ~ + (Elidable & HasOpcode(load) & HasAddrMode(Absolute) & MatchParameter(1) & DontMatchParameter(8) & DontMatchParameter(0)) ~ + (Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(9) & DontMatchParameter(8) & DontMatchParameter(1) & DontMatchParameter(0)) ~ + (Elidable & HasOpcode(load) & HasAddrMode(Absolute) & MatchParameter(0)) ~~> { c => + List(c(2), c(3), c(0), c(1)) + } + } + + val DoubleLoadToTheSameRegister = new RuleBasedAssemblyOptimization("Double load to the same register", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + TwoIdenticalLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA)), + TwoIdenticalLoadsWithNoFlagChangeInBetween(LDX, Not(ChangesX)), + TwoIdenticalLoadsWithNoFlagChangeInBetween(LDY, Not(ChangesY)), + TwoIdenticalLoadsWithNoFlagChangeInBetween(LAX, Not(ChangesA) & Not(ChangesX)), + TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA)), + TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(LDX, Not(ChangesX)), + TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(LDY, Not(ChangesY)), + TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LDA, Not(ChangesA)), + TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LDX, Not(ChangesX)), + TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LDY, Not(ChangesY)), + TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LAX, Not(ChangesA) & Not(ChangesX)), + InterleavedImmediateLoads(LDA, STA), + InterleavedImmediateLoads(LDX, STX), + InterleavedImmediateLoads(LDY, STY), + InterleavedAbsoluteLoads(LDA, STA), + InterleavedAbsoluteLoads(LDX, STX), + InterleavedAbsoluteLoads(LDY, STY), + ) + + private def pointlessLoadAfterStore(store: Opcode.Value, load: Opcode.Value, addrMode: AddrMode.Value, meantime: AssemblyLinePattern = Anything) = { + ((HasOpcode(store) & HasAddrMode(addrMode) & MatchParameter(1)) ~ + (LinearOrBranch & Not(ChangesA) & Not(ChangesMemory) & meantime).*).capture(2) ~ + (HasOpcode(load) & Elidable & HasAddrMode(addrMode) & MatchParameter(1)) ~ + ((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(3) ~~> { (_, ctx) => + ctx.get[List[AssemblyLine]](2) ++ ctx.get[List[AssemblyLine]](3) + } + } + + val PointlessLoadAfterStore = new RuleBasedAssemblyOptimization("Pointless load after store", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + pointlessLoadAfterStore(STA, LDA, Absolute), + pointlessLoadAfterStore(STA, LDA, AbsoluteX, Not(ChangesX)), + pointlessLoadAfterStore(STA, LDA, AbsoluteY, Not(ChangesY)), + pointlessLoadAfterStore(STX, LDX, Absolute), + pointlessLoadAfterStore(STY, LDY, Absolute), + ) + + + private val ShiftAddrModes = Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX) + private val ShiftOpcodes = Set(ASL, ROL, ROR, LSR) + + // LDA-SHIFT-STA is slower than just SHIFT + // LDA-SHIFT-SHIFT-STA is equally fast as SHIFT-SHIFT, but the latter doesn't use the accumulator + val PointessLoadingForShifting = new RuleBasedAssemblyOptimization("Pointless loading for shifting", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasOpcode(LDA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcodeIn(ShiftOpcodes) & HasAddrMode(Implied) & MatchOpcode(2)) ~ + (Elidable & HasOpcode(STA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Not(ReadsA) & Not(OverwritesA)).* ~ OverwritesA ~~> { (code, ctx) => + AssemblyLine(ctx.get[Opcode.Value](2), ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) :: code.drop(3) + }, + (Elidable & HasOpcode(LDA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcodeIn(ShiftOpcodes) & HasAddrMode(Implied) & MatchOpcode(2)) ~ + (Elidable & HasOpcodeIn(ShiftOpcodes) & HasAddrMode(Implied) & MatchOpcode(2)) ~ + (Elidable & HasOpcode(STA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Not(ReadsA) & Not(OverwritesA)).* ~ OverwritesA ~~> { (code, ctx) => + val shift = AssemblyLine(ctx.get[Opcode.Value](2), ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) + shift :: shift :: code.drop(4) + } + ) + + // SHIFT-LDA is equally fast as LDA-SHIFT-STA, but can enable further optimizations doesn't use the accumulator + // LDA-SHIFT-SHIFT-STA is equally fast as SHIFT-SHIFT, but the latter doesn't use the accumulator + val LoadingAfterShifting = new RuleBasedAssemblyOptimization("Loading after shifting", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasOpcodeIn(ShiftOpcodes) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) => + AssemblyLine(LDA, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) :: + AssemblyLine.implied(code.head.opcode) :: + AssemblyLine(STA, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) :: + code.drop(2) + } + ) + + val UseZeropageAddressingMode = new RuleBasedAssemblyOptimization("Using zeropage addressing mode", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasAddrMode(Absolute) & MatchParameter(0)) ~ Where(ctx => ctx.get[Constant](0).quickSimplify match { + case NumericConstant(x, _) => (x & 0xff00) == 0 + case _ => false + }) ~~> (code => code.head.copy(addrMode = ZeroPage) :: Nil) + ) + + val UseXInsteadOfStack = new RuleBasedAssemblyOptimization("Using X instead of stack", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + (Elidable & HasOpcode(PHA) & DoesntMatterWhatItDoesWith(State.X)) ~ + (Not(ConcernsStack) & Not(ConcernsX)).capture(1) ~ + Where(_.isExternallyLinearBlock(1)) ~ + (Elidable & HasOpcode(PLA)) ~~> (c => + AssemblyLine.implied(TAX) :: (c.tail.init :+ AssemblyLine.implied(TXA)) + ) + ) + + val UseYInsteadOfStack = new RuleBasedAssemblyOptimization("Using Y instead of stack", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + (Elidable & HasOpcode(PHA) & DoesntMatterWhatItDoesWith(State.Y)) ~ + (Not(ConcernsStack) & Not(ConcernsY)).capture(1) ~ + Where(_.isExternallyLinearBlock(1)) ~ + (Elidable & HasOpcode(PLA)) ~~> (c => + AssemblyLine.implied(TAY) :: (c.tail.init :+ AssemblyLine.implied(TYA)) + ) + ) + + // TODO: make it more generic + val IndexSwitchingOptimization = new RuleBasedAssemblyOptimization("Index switching optimization", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + (Elidable & HasOpcode(LDY) & MatchAddrMode(2) & Not(ReadsX) & MatchParameter(0)) ~ + (Elidable & Linear & Not(ChangesY) & HasAddrMode(AbsoluteY) & SupportsAbsoluteX & Not(ConcernsX)) ~ + (HasOpcode(LDY) & Not(ConcernsX)) ~ + (Linear & Not(ChangesY) & Not(ConcernsX) & HasAddrModeIn(Set(AbsoluteY, IndexedY, ZeroPageY))) ~ + (Elidable & HasOpcode(LDY) & MatchAddrMode(2) & Not(ReadsX) & MatchParameter(0)) ~ + (Elidable & Linear & Not(ChangesY) & HasAddrMode(AbsoluteY) & SupportsAbsoluteX & Not(ConcernsX) & DoesntMatterWhatItDoesWith(State.X, State.N, State.Z)) ~~> { (code, ctx) => + List( + code(0).copy(opcode = LDX), + code(1).copy(addrMode = AbsoluteX), + code(2), + code(3), + code(5).copy(addrMode = AbsoluteX)) + }, + (Elidable & HasOpcode(LDX) & MatchAddrMode(2) & Not(ReadsY) & MatchParameter(0)) ~ + (Elidable & Linear & Not(ChangesX) & HasAddrMode(AbsoluteX) & SupportsAbsoluteY & Not(ConcernsY)) ~ + (HasOpcode(LDX) & Not(ConcernsY)) ~ + (Linear & Not(ChangesX) & Not(ConcernsY) & HasAddrModeIn(Set(AbsoluteX, IndexedX, ZeroPageX, AbsoluteIndexedX))) ~ + (Elidable & HasOpcode(LDX) & MatchAddrMode(2) & Not(ReadsY) & MatchParameter(0)) ~ + (Elidable & Linear & Not(ChangesX) & HasAddrMode(AbsoluteX) & SupportsAbsoluteY & Not(ConcernsY) & DoesntMatterWhatItDoesWith(State.Y, State.N, State.Z)) ~~> { (code, ctx) => + List( + code(0).copy(opcode = LDY), + code(1).copy(addrMode = AbsoluteY), + code(2), + code(3), + code(5).copy(addrMode = AbsoluteY)) + }, + + ) + + val All = List( + DoubleLoadToDifferentRegisters, + DoubleLoadToTheSameRegister, + IndexSwitchingOptimization, + PointlessLoadAfterStore, + PointessLoadingForShifting, + LoadingAfterShifting, + UseXInsteadOfStack, + UseYInsteadOfStack, + UseZeropageAddressingMode) +} + diff --git a/src/main/scala/millfork/assembly/opt/QuantumFlowAnalyzer.scala b/src/main/scala/millfork/assembly/opt/QuantumFlowAnalyzer.scala new file mode 100644 index 00000000..2f0cabb2 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/QuantumFlowAnalyzer.scala @@ -0,0 +1,425 @@ +package millfork.assembly.opt + +import millfork.assembly.{AssemblyLine, OpcodeClasses} +import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant} + +import scala.collection.immutable.BitSet + +/** + * @author Karol Stasiak + */ + +object QCpuStatus { + val InitialStatus = QCpuStatus((for { + c <- Seq(true, false) + v <- Seq(true, false) + n <- Seq(true, false) + z <- Seq(true, false) + } yield QFlagStatus(c = c, d = false, v = v, n = n, z = z) -> QRegStatus(a = QRegStatus.AllValues, x = QRegStatus.AllValues, y = QRegStatus.AllValues, equal = RegEquality.NoEquality)).toMap) + + val UnknownStatus = QCpuStatus((for { + c <- Seq(true, false) + v <- Seq(true, false) + n <- Seq(true, false) + z <- Seq(true, false) + } yield QFlagStatus(c = c, d = false, v = v, n = n, z = z) -> QRegStatus(a = QRegStatus.AllValues, x = QRegStatus.AllValues, y = QRegStatus.AllValues, equal = RegEquality.UnknownEquality)).toMap) + + def gather(l: List[(QFlagStatus, QRegStatus)]) = + QCpuStatus(l.groupBy(_._1). + map { case (k, vs) => k -> vs.map(_._2).reduce(_ ++ _) }. + filterNot(_._2.isEmpty)) +} + +case class QCpuStatus(data: Map[QFlagStatus, QRegStatus]) { + def collapse: CpuStatus = { + val registers = data.values.reduce(_ ++ _) + + def bitset(b: BitSet): Status[Int] = if (b.size == 1) SingleStatus(b.head) else AnyStatus() + + def flag(f: QFlagStatus => Boolean): Status[Boolean] = + if (data.keys.forall(k => f(k))) SingleStatus(true) + else if (data.keys.forall(k => !f(k))) SingleStatus(false) + else AnyStatus() + + CpuStatus( + a = bitset(registers.a), + x = bitset(registers.x), + y = bitset(registers.y), + c = flag(_.c), + d = flag(_.d), + v = flag(_.v), + z = flag(_.z), + n = flag(_.n), + ) + } + + def changeFlagUnconditionally(f: QFlagStatus => QFlagStatus): QCpuStatus = { + QCpuStatus.gather(data.toList.map { case (k, v) => f(k) -> v }) + } + + def changeFlagsInAnUnknownWay(f: QFlagStatus => QFlagStatus, g: QFlagStatus => QFlagStatus): QCpuStatus = { + QCpuStatus.gather(data.toList.flatMap { case (k, v) => List(f(k) -> v, g(k) -> v) }) + } + + def changeFlagsInAnUnknownWay(f: QFlagStatus => QFlagStatus, g: QFlagStatus => QFlagStatus, h: QFlagStatus => QFlagStatus): QCpuStatus = { + QCpuStatus.gather(data.toList.flatMap { case (k, v) => List(f(k) -> v, g(k) -> v, h(k) -> v) }) + } + + def mapRegisters(f: QRegStatus => QRegStatus): QCpuStatus = { + QCpuStatus(data.map { case (k, v) => k -> f(v) }) + } + + def mapRegisters(f: (QFlagStatus, QRegStatus) => QRegStatus): QCpuStatus = { + QCpuStatus(data.map { case (k, v) => k -> f(k, v) }) + } + + def flatMap(f: (QFlagStatus, QRegStatus) => List[(QFlagStatus, QRegStatus)]): QCpuStatus = { + QCpuStatus.gather(data.toList.flatMap { case (k, v) => f(k, v) }) + } + + def changeNZFromA: QCpuStatus = { + QCpuStatus.gather(data.toList.flatMap { case (k, v) => + List( + k.copy(n = false, z = false) -> v.whereA(i => i.toByte > 0), + k.copy(n = true, z = false) -> v.whereA(i => i.toByte < 0), + k.copy(n = false, z = true) -> v.whereA(i => i.toByte == 0)) + }) + } + + def changeNZFromX: QCpuStatus = { + QCpuStatus.gather(data.toList.flatMap { case (k, v) => + List( + k.copy(n = false, z = false) -> v.whereX(i => i.toByte > 0), + k.copy(n = true, z = false) -> v.whereX(i => i.toByte < 0), + k.copy(n = false, z = true) -> v.whereX(i => i.toByte == 0)) + }) + } + + def changeNZFromY: QCpuStatus = { + QCpuStatus.gather(data.toList.flatMap { case (k, v) => + List( + k.copy(n = false, z = false) -> v.whereY(i => i.toByte > 0), + k.copy(n = true, z = false) -> v.whereY(i => i.toByte < 0), + k.copy(n = false, z = true) -> v.whereY(i => i.toByte == 0)) + }) + } + + def ~(that: QCpuStatus): QCpuStatus = QCpuStatus.gather(this.data.toList ++ that.data.toList) +} + +object QRegStatus { + val NoValues: BitSet = BitSet.empty + val AllValues: BitSet = BitSet.fromBitMask(Array(-1L, -1L, -1L, -1L)) + +} + +object RegEquality extends Enumeration { + val NoEquality, AX, AY, XY, AXY, UnknownEquality = Value + + def or(a: Value, b: Value) = { + (a, b) match { + case (UnknownEquality, _) => b + case (_, UnknownEquality) => a + case (NoEquality, _) => NoEquality + case (_, NoEquality) => NoEquality + case (_, _) if a == b => a + case (AXY, _) => b + case (_, AXY) => a + case _ => NoEquality + } + } + + def afterTransfer(a: Value, b: Value) = { + (a, b) match { + case (UnknownEquality, _) => b + case (_, UnknownEquality) => a + case (NoEquality, _) => b + case (_, NoEquality) => a + case (_, _) if a == b => a + case _ => AXY + } + } +} + +case class QRegStatus(a: BitSet, x: BitSet, y: BitSet, equal: RegEquality.Value) { + def isEmpty: Boolean = a.isEmpty || x.isEmpty || y.isEmpty + + def ++(that: QRegStatus) = QRegStatus( + a = a ++ that.a, + x = x ++ that.x, + y = y ++ that.y, + equal = RegEquality.or(equal, that.equal)) + + def afterTransfer(transfer: RegEquality.Value): QRegStatus = + copy(equal = RegEquality.afterTransfer(equal, transfer)) + + def changeA(f: Int => Long): QRegStatus = { + val newA = a.map(i => f(i).toInt & 0xff) + val newEqual = equal match { + case RegEquality.XY => RegEquality.XY + case RegEquality.AXY => RegEquality.XY + case _ => RegEquality.NoEquality + } + QRegStatus(newA, x, y, newEqual) + } + + def changeX(f: Int => Long): QRegStatus = { + val newA = a.map(i => f(i).toInt & 0xff) + val newEqual = equal match { + case RegEquality.XY => RegEquality.XY + case RegEquality.AXY => RegEquality.XY + case _ => RegEquality.NoEquality + } + QRegStatus(newA, x, y, newEqual) + } + + def changeY(f: Int => Long): QRegStatus = { + val newA = a.map(i => f(i).toInt & 0xff) + val newEqual = equal match { + case RegEquality.XY => RegEquality.XY + case RegEquality.AXY => RegEquality.XY + case _ => RegEquality.NoEquality + } + QRegStatus(newA, x, y, newEqual) + } + + def whereA(f: Int => Boolean): QRegStatus = + equal match { + case RegEquality.AXY => + copy(a = a.filter(f), x = x.filter(f), y = y.filter(f)) + case RegEquality.AY => + copy(a = a.filter(f), y = y.filter(f)) + case RegEquality.AX => + copy(a = a.filter(f), x = x.filter(f)) + case _ => + copy(a = a.filter(f)) + } + + def whereX(f: Int => Boolean): QRegStatus = + equal match { + case RegEquality.AXY => + copy(a = a.filter(f), x = x.filter(f), y = y.filter(f)) + case RegEquality.XY => + copy(x = x.filter(f), y = y.filter(f)) + case RegEquality.AX => + copy(a = a.filter(f), x = x.filter(f)) + case _ => + copy(x = x.filter(f)) + } + + def whereY(f: Int => Boolean): QRegStatus = + equal match { + case RegEquality.AXY => + copy(a = a.filter(f), x = x.filter(f), y = y.filter(f)) + case RegEquality.AY => + copy(a = a.filter(f), y = y.filter(f)) + case RegEquality.XY => + copy(x = x.filter(f), y = y.filter(f)) + case _ => + copy(y = y.filter(f)) + } +} + +case class QFlagStatus(c: Boolean, d: Boolean, v: Boolean, z: Boolean, n: Boolean) + +object QuantumFlowAnalyzer { + private def loBit(b: Boolean) = if (b) 1 else 0 + + private def hiBit(b: Boolean) = if (b) 0x80 else 0 + + //noinspection RedundantNewCaseClass + def analyze(f: NormalFunction, code: List[AssemblyLine]): List[QCpuStatus] = { + val flagArray = Array.fill[QCpuStatus](code.length)(QCpuStatus.UnknownStatus) + val codeArray = code.toArray + + var changed = true + while (changed) { + changed = false + var currentStatus: QCpuStatus = if (f.interrupt) QCpuStatus.UnknownStatus else QCpuStatus.UnknownStatus + for (i <- codeArray.indices) { + import millfork.assembly.Opcode._ + import millfork.assembly.AddrMode._ + if (flagArray(i) != currentStatus) { + changed = true + flagArray(i) = currentStatus + } + codeArray(i) match { + case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) => + val L = l + currentStatus = codeArray.indices.flatMap(j => codeArray(j) match { + case AssemblyLine(_, _, MemoryAddressConstant(Label(L)), _) => Some(flagArray(j)) + case _ => None + }).fold(QCpuStatus.UnknownStatus)(_ ~ _) + + case AssemblyLine(BCC, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = true)) + case AssemblyLine(BCS, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = false)) + case AssemblyLine(BVS, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = false)) + case AssemblyLine(BVC, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = true)) + case AssemblyLine(BMI, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(n = false)) + case AssemblyLine(BPL, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(n = true)) + case AssemblyLine(BEQ, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(z = false)) + case AssemblyLine(BNE, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(z = true)) + + case AssemblyLine(SED, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(d = true)) + case AssemblyLine(SEC, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = true)) + case AssemblyLine(CLD, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(d = false)) + case AssemblyLine(CLC, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = false)) + case AssemblyLine(CLV, _, _, _) => + currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = false)) + + case AssemblyLine(JSR, _, _, _) => + currentStatus = QCpuStatus.InitialStatus + + case AssemblyLine(LDX, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.mapRegisters(r => r.changeX(_ => n)).changeNZFromX + case AssemblyLine(LDY, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.mapRegisters(r => r.changeY(_ => n)).changeNZFromY + case AssemblyLine(LDA, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.mapRegisters(r => r.changeA(_ => n)).changeNZFromA + case AssemblyLine(LAX, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.mapRegisters(r => r.changeA(_ => n).changeX(_ => n).afterTransfer(RegEquality.AX)).changeNZFromA + + case AssemblyLine(EOR, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.mapRegisters(r => r.changeA(_ ^ n)).changeNZFromA + case AssemblyLine(AND, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.mapRegisters(r => r.changeA(_ & n)).changeNZFromA + case AssemblyLine(ANC, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.mapRegisters(r => r.changeA(_ & n)).changeNZFromA.changeFlagUnconditionally(f => f.copy(c = f.z)) + case AssemblyLine(ORA, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.mapRegisters(r => r.changeA(_ | n)).changeNZFromA + + case AssemblyLine(INX, Implied, _, _) => + currentStatus = currentStatus.mapRegisters(r => r.changeX(_ + 1)).changeNZFromX + case AssemblyLine(DEX, Implied, _, _) => + currentStatus = currentStatus.mapRegisters(r => r.changeX(_ - 1)).changeNZFromX + case AssemblyLine(INY, Implied, _, _) => + currentStatus = currentStatus.mapRegisters(r => r.changeY(_ - 1)).changeNZFromY + case AssemblyLine(DEY, Implied, _, _) => + currentStatus = currentStatus.mapRegisters(r => r.changeY(_ - 1)).changeNZFromY + case AssemblyLine(TAX, _, _, _) => + currentStatus = currentStatus.mapRegisters(r => r.copy(x = r.a).afterTransfer(RegEquality.AX)).changeNZFromX + case AssemblyLine(TXA, _, _, _) => + currentStatus = currentStatus.mapRegisters(r => r.copy(a = r.x).afterTransfer(RegEquality.AX)).changeNZFromA + case AssemblyLine(TAY, _, _, _) => + currentStatus = currentStatus.mapRegisters(r => r.copy(y = r.a).afterTransfer(RegEquality.AY)).changeNZFromY + case AssemblyLine(TYA, _, _, _) => + currentStatus = currentStatus.mapRegisters(r => r.copy(a = r.y).afterTransfer(RegEquality.AY)).changeNZFromA + + case AssemblyLine(ROL, Implied, _, _) => + currentStatus = currentStatus.flatMap((f, r) => List( + f.copy(c = true) -> r.whereA(a => (a & 0x80) != 0).changeA(a => a * 2 + loBit(f.c)), + f.copy(c = false) -> r.whereA(a => (a & 0x80) == 0).changeA(a => a * 2 + loBit(f.c)), + )).changeNZFromA + case AssemblyLine(ROR, Implied, _, _) => + currentStatus = currentStatus.flatMap((f, r) => List( + f.copy(c = true) -> r.whereA(a => (a & 1) != 0).changeA(a => (a >>> 2) & 0x7f | hiBit(f.c)), + f.copy(c = false) -> r.whereA(a => (a & 1) == 0).changeA(a => (a >>> 2) & 0x7f | hiBit(f.c)), + )).changeNZFromA + case AssemblyLine(ASL, Implied, _, _) => + currentStatus = currentStatus.flatMap((f, r) => List( + f.copy(c = true) -> r.whereA(a => (a & 0x80) != 0).changeA(a => a * 2), + f.copy(c = false) -> r.whereA(a => (a & 0x80) == 0).changeA(a => a * 2), + )).changeNZFromA + case AssemblyLine(LSR, Implied, _, _) => + currentStatus = currentStatus.flatMap((f, r) => List( + f.copy(c = true) -> r.whereA(a => (a & 1) != 0).changeA(a => (a >>> 2) & 0x7f), + f.copy(c = false) -> r.whereA(a => (a & 1) == 0).changeA(a => (a >>> 2) & 0x7f), + )).changeNZFromA + case AssemblyLine(ALR, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.flatMap((f, r) => List( + f.copy(c = true) -> r.whereA(a => (a & n & 1) != 0).changeA(a => ((a & n) >>> 2) & 0x7f), + f.copy(c = false) -> r.whereA(a => (a & n & 1) == 0).changeA(a => ((a & n) >>> 2) & 0x7f), + )).changeNZFromA + case AssemblyLine(ADC, Immediate, NumericConstant(nn, _), _) => + val n = nn & 0xff + currentStatus = currentStatus.flatMap((f, r) => + if (f.d) { + val regs = r.copy(a = QRegStatus.AllValues).changeA(_.toLong) + List( + f.copy(c = false, v = false) -> regs, + f.copy(c = true, v = false) -> regs, + f.copy(c = false, v = true) -> regs, + f.copy(c = true, v = true) -> regs, + ) + } else { + if (f.c) { + val regs = r.changeA(_ + n + 1) + List( + f.copy(c = false, v = false) -> regs.whereA(_ >= n), + f.copy(c = true, v = false) -> regs.whereA(_ < n), + f.copy(c = false, v = true) -> regs.whereA(_ >= n), + f.copy(c = true, v = true) -> regs.whereA(_ < n), + ) + } else { + val regs = r.changeA(_ + n) + List( + f.copy(c = false, v = false) -> regs.whereA(_ > n), + f.copy(c = true, v = false) -> regs.whereA(_ <= n), + f.copy(c = false, v = true) -> regs.whereA(_ > n), + f.copy(c = true, v = true) -> regs.whereA(_ <= n), + ) + } + } + ).changeNZFromA + case AssemblyLine(SBC, Immediate, NumericConstant(n, _), _) => + currentStatus = currentStatus.flatMap((f, r) => + if (f.d) { + val regs = r.copy(a = QRegStatus.AllValues).changeA(_.toLong) + // TODO: guess the carry flag correctly + List( + f.copy(c = false, v = false) -> regs, + f.copy(c = true, v = false) -> regs, + f.copy(c = false, v = true) -> regs, + f.copy(c = true, v = true) -> regs, + ) + } else { + val regs = if (f.c) r.changeA(_ - n) else r.changeA(_ - n - 1) + List( + f.copy(c = false, v = false) -> regs, + f.copy(c = true, v = false) -> regs, + f.copy(c = false, v = true) -> regs, + f.copy(c = true, v = true) -> regs, + ) + } + ).changeNZFromA + + case AssemblyLine(opcode, addrMode, parameter, _) => + if (OpcodeClasses.ChangesX(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(x = QRegStatus.AllValues)) + if (OpcodeClasses.ChangesY(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(y = QRegStatus.AllValues)) + if (OpcodeClasses.ChangesAAlways(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(a = QRegStatus.AllValues)) + if (addrMode == Implied && OpcodeClasses.ChangesAIfImplied(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(a = QRegStatus.AllValues)) + if (OpcodeClasses.ChangesNAndZ(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay( + _.copy(n = false, z = false), + _.copy(n = true, z = false), + _.copy(n = false, z = true)) + if (OpcodeClasses.ChangesC(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(_.copy(c = false), _.copy(c = true)) + if (OpcodeClasses.ChangesV(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(_.copy(v = false), _.copy(v = true)) + if (opcode == CMP || opcode == CPX || opcode == CPY) { + if (addrMode == Immediate) parameter match { + case NumericConstant(0, _) => currentStatus = currentStatus.changeFlagUnconditionally(_.copy(c = true)) + case _ => () + } + } + } + } + // flagArray.zip(codeArray).foreach{ + // case (fl, y) => if (y.isPrintable) println(f"$fl%-32s $y%-32s") + // } + // println("---------------------") + } + + flagArray.toList + } +} diff --git a/src/main/scala/millfork/assembly/opt/ReverseFlowAnalyzer.scala b/src/main/scala/millfork/assembly/opt/ReverseFlowAnalyzer.scala new file mode 100644 index 00000000..00d8ff91 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/ReverseFlowAnalyzer.scala @@ -0,0 +1,149 @@ +package millfork.assembly.opt + +import millfork.assembly.{AssemblyLine, OpcodeClasses, State} +import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant} + +import scala.collection.immutable + +/** + * @author Karol Stasiak + */ + +sealed trait Importance { + def ~(that: Importance) = (this, that) match { + case (_, Important) | (Important, _) => Important + case (_, Unimportant) | (Unimportant, _) => Unimportant + case (UnknownImportance, UnknownImportance) => UnknownImportance + } +} + +case object Important extends Importance { + override def toString = "!" +} + + +case object Unimportant extends Importance { + override def toString = "*" +} + +case object UnknownImportance extends Importance { + override def toString = "?" +} + +//noinspection RedundantNewCaseClass +case class CpuImportance(a: Importance = UnknownImportance, + x: Importance = UnknownImportance, + y: Importance = UnknownImportance, + n: Importance = UnknownImportance, + z: Importance = UnknownImportance, + v: Importance = UnknownImportance, + c: Importance = UnknownImportance, + d: Importance = UnknownImportance, + ) { + override def toString: String = s"A=$a,X=$x,Y=$y,Z=$z,N=$n,C=$c,V=$v,D=$d" + + def ~(that: CpuImportance) = new CpuImportance( + a = this.a ~ that.a, + x = this.x ~ that.x, + y = this.y ~ that.y, + z = this.z ~ that.z, + n = this.n ~ that.n, + c = this.c ~ that.c, + v = this.v ~ that.v, + d = this.d ~ that.d, + ) + + def isUnimportant(state: State.Value): Boolean = state match { + case State.A => a == Unimportant + case State.X => x == Unimportant + case State.Y => y == Unimportant + case State.Z => z == Unimportant + case State.N => n == Unimportant + case State.C => c == Unimportant + case State.V => v == Unimportant + case State.D => d == Unimportant + } +} + +object ReverseFlowAnalyzer { + //noinspection RedundantNewCaseClass + def analyze(f: NormalFunction, code: List[AssemblyLine]): List[CpuImportance] = { + val importanceArray = Array.fill[CpuImportance](code.length)(new CpuImportance()) + val codeArray = code.toArray + val initialStatus = new CpuStatus(d = SingleStatus(false)) + + var changed = true + val finalImportance = new CpuImportance(a = Important, x = Important, y = Important, c = Important, v = Important, d = Important, z = Important, n = Important) + changed = true + while (changed) { + changed = false + var currentImportance: CpuImportance = finalImportance + for (i <- codeArray.indices.reverse) { + import millfork.assembly.Opcode._ + import millfork.assembly.AddrMode._ + if (importanceArray(i) != currentImportance) { + changed = true + importanceArray(i) = currentImportance + } + codeArray(i) match { + case AssemblyLine(opcode, Relative, MemoryAddressConstant(Label(l)), _) if OpcodeClasses.ShortBranching(opcode) => + val L = l + val labelIndex = codeArray.indexWhere { + case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(L)), _) => true + case _ => false + } + currentImportance = if (labelIndex < 0) finalImportance else importanceArray(labelIndex) ~ currentImportance + case _ => + } + codeArray(i) match { + // TODO: JSR? + case AssemblyLine(JMP, Absolute, MemoryAddressConstant(Label(l)), _) => + val L = l + val labelIndex = codeArray.indexWhere { + case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(L)), _) => true + case _ => false + } + currentImportance = if (labelIndex < 0) finalImportance else importanceArray(labelIndex) + case AssemblyLine(JMP, Indirect, _, _) => + currentImportance = finalImportance + case AssemblyLine(BNE | BEQ, _, _, _) => + currentImportance = currentImportance.copy(z = Important) + case AssemblyLine(BMI | BPL, _, _, _) => + currentImportance = currentImportance.copy(n = Important) + case AssemblyLine(SED | CLD, _, _, _) => + currentImportance = currentImportance.copy(d = Unimportant) + case AssemblyLine(RTS, _, _, _) => + currentImportance = finalImportance + case AssemblyLine(DISCARD_XF, _, _, _) => + currentImportance = currentImportance.copy(x = Unimportant, n = Unimportant, z = Unimportant, c = Unimportant, v = Unimportant) + case AssemblyLine(DISCARD_YF, _, _, _) => + currentImportance = currentImportance.copy(y = Unimportant, n = Unimportant, z = Unimportant, c = Unimportant, v = Unimportant) + case AssemblyLine(DISCARD_AF, _, _, _) => + currentImportance = currentImportance.copy(a = Unimportant, n = Unimportant, z = Unimportant, c = Unimportant, v = Unimportant) + case AssemblyLine(opcode, addrMode, _, _) => + if (OpcodeClasses.ChangesC(opcode)) currentImportance = currentImportance.copy(c = Unimportant) + if (OpcodeClasses.ChangesV(opcode)) currentImportance = currentImportance.copy(v = Unimportant) + if (OpcodeClasses.ChangesNAndZ(opcode)) currentImportance = currentImportance.copy(n = Unimportant, z = Unimportant) + if (OpcodeClasses.OverwritesA(opcode)) currentImportance = currentImportance.copy(a = Unimportant) + if (OpcodeClasses.OverwritesX(opcode)) currentImportance = currentImportance.copy(x = Unimportant) + if (OpcodeClasses.OverwritesY(opcode)) currentImportance = currentImportance.copy(y = Unimportant) + if (OpcodeClasses.ReadsC(opcode)) currentImportance = currentImportance.copy(c = Important) + if (OpcodeClasses.ReadsD(opcode)) currentImportance = currentImportance.copy(d = Important) + if (OpcodeClasses.ReadsV(opcode)) currentImportance = currentImportance.copy(v = Important) + if (OpcodeClasses.ReadsXAlways(opcode)) currentImportance = currentImportance.copy(x = Important) + if (OpcodeClasses.ReadsYAlways(opcode)) currentImportance = currentImportance.copy(y = Important) + if (OpcodeClasses.ReadsAAlways(opcode)) currentImportance = currentImportance.copy(a = Important) + if (OpcodeClasses.ReadsAIfImplied(opcode) && addrMode == Implied) currentImportance = currentImportance.copy(a = Important) + if (addrMode == AbsoluteX || addrMode == IndexedX || addrMode == ZeroPageX) currentImportance = currentImportance.copy(x = Important) + if (addrMode == AbsoluteY || addrMode == IndexedY || addrMode == ZeroPageY) currentImportance = currentImportance.copy(y = Important) + } + } + } +// importanceArray.zip(codeArray).foreach{ +// case (i, y) => if (y.isPrintable) println(f"$y%-32s $i%-32s") +// } +// println("---------------------") + + importanceArray.toList + } +} diff --git a/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala b/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala new file mode 100644 index 00000000..48cb84e0 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala @@ -0,0 +1,757 @@ +package millfork.assembly.opt + +import millfork.{CompilationFlag, CompilationOptions} +import millfork.assembly._ +import millfork.env._ +import millfork.error.ErrorReporting + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ + +object FlowInfoRequirement extends Enumeration { + + val NoRequirement, BothFlows, ForwardFlow, BackwardFlow = Value + + def assertForward(x: FlowInfoRequirement.Value): Unit = x match { + case BothFlows | ForwardFlow => () + case NoRequirement | BackwardFlow => ErrorReporting.fatal("Forward flow info required") + } + + def assertBackward(x: FlowInfoRequirement.Value): Unit = x match { + case BothFlows | BackwardFlow => () + case NoRequirement | ForwardFlow => ErrorReporting.fatal("Backward flow info required") + } +} + +class RuleBasedAssemblyOptimization(val name: String, val needsFlowInfo: FlowInfoRequirement.Value, val rules: AssemblyRule*) extends AssemblyOptimization { + + rules.foreach(_.pattern.validate(needsFlowInfo)) + + override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = { + val effectiveCode = code.map(a => a.copy(parameter = a.parameter.quickSimplify)) + val taggedCode = needsFlowInfo match { + case FlowInfoRequirement.NoRequirement => effectiveCode.map(FlowInfo.Default -> _) + case FlowInfoRequirement.BothFlows => FlowAnalyzer.analyze(f, effectiveCode, options) + case FlowInfoRequirement.ForwardFlow => + if (options.flag(CompilationFlag.DetailedFlowAnalysis)) { + QuantumFlowAnalyzer.analyze(f, code).map(s => FlowInfo(s.collapse, CpuImportance())).zip(code) + } else { + CoarseFlowAnalyzer.analyze(f, code).map(s => FlowInfo(s, CpuImportance())).zip(code) + } + case FlowInfoRequirement.BackwardFlow => + ReverseFlowAnalyzer.analyze(f, code).map(i => FlowInfo(CpuStatus(), i)).zip(code) + } + optimizeImpl(f, taggedCode, options) + } + + def optimizeImpl(f: NormalFunction, code: List[(FlowInfo, AssemblyLine)], options: CompilationOptions): List[AssemblyLine] = { + code match { + case Nil => Nil + case head :: tail => + for ((rule, index) <- rules.zipWithIndex) { + val ctx = new AssemblyMatchingContext + rule.pattern.matchTo(ctx, code) match { + case Some(rest: List[(FlowInfo, AssemblyLine)]) => + val matchedChunkToOptimize: List[AssemblyLine] = code.take(code.length - rest.length).map(_._2) + val optimizedChunk: List[AssemblyLine] = rule.result(matchedChunkToOptimize, ctx) + ErrorReporting.debug(s"Applied $name ($index)") + if (needsFlowInfo != FlowInfoRequirement.NoRequirement) { + val before = code.head._1.statusBefore + val after = code(matchedChunkToOptimize.length - 1)._1.importanceAfter + ErrorReporting.trace(s"Before: $before") + ErrorReporting.trace(s"After: $after") + } + matchedChunkToOptimize.filter(_.isPrintable).foreach(l => ErrorReporting.trace(l.toString)) + ErrorReporting.trace(" ↓") + optimizedChunk.filter(_.isPrintable).foreach(l => ErrorReporting.trace(l.toString)) + if (needsFlowInfo != FlowInfoRequirement.NoRequirement) { + return optimizedChunk ++ optimizeImpl(f, rest, options) + } else { + return optimize(f, optimizedChunk ++ rest.map(_._2), options) + } + case None => () + } + } + head._2 :: optimizeImpl(f, tail, options) + } + } +} + +class AssemblyMatchingContext { + private val map = mutable.Map[Int, Any]() + + def addObject(i: Int, o: Any): Boolean = { + if (map.contains(i)) { + map(i) == o + } else { + map(i) = o + true + } + } + + def dontMatch(i: Int, o: Any): Boolean = { + if (map.contains(i)) { + map(i) != o + } else { + false + } + } + + def get[T: Manifest](i: Int): T = { + val t = map(i) + val clazz = implicitly[Manifest[T]].runtimeClass match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + // TODO + case x => x + } + if (clazz.isInstance(t)) { + t.asInstanceOf[T] + } else { + if (i eq null) { + ErrorReporting.fatal(s"Value at index $i is null") + } else { + ErrorReporting.fatal(s"Value at index $i is a ${t.getClass.getSimpleName}, not a ${clazz.getSimpleName}") + } + } + } + + def isExternallyLinearBlock(i: Int): Boolean = { + val labels = mutable.Set[String]() + val jumps = mutable.Set[String]() + get[List[AssemblyLine]](i).foreach { + case AssemblyLine(Opcode.RTS | Opcode.RTI | Opcode.BRK, _, _, _) => + return false + case AssemblyLine(Opcode.JMP, AddrMode.Indirect, _, _) => + return false + case AssemblyLine(Opcode.LABEL, _, MemoryAddressConstant(Label(l)), _) => + labels += l + case AssemblyLine(Opcode.JMP, AddrMode.Absolute, MemoryAddressConstant(Label(l)), _) => + jumps += l + case AssemblyLine(Opcode.JMP, AddrMode.Absolute, _, _) => + return false + case AssemblyLine(_, AddrMode.Relative, MemoryAddressConstant(Label(l)), _) => + jumps += l + case AssemblyLine(br, _, _, _) if OpcodeClasses.ShortBranching(br) => + return false + case _ => () + } + // if a jump leads inside the block, then it's internal + // if a jump leads outside the block, then it's external + jumps --= labels + jumps.isEmpty + } + + def areMemoryReferencesProvablyNonOverlapping(param1: Int, addrMode1: Int, param2: Int, addrMode2: Int): Boolean = { + val p1 = get[Constant](param1).quickSimplify + val a1 = get[AddrMode.Value](addrMode1) + val p2 = get[Constant](param2).quickSimplify + val a2 = get[AddrMode.Value](addrMode2) + import AddrMode._ + val badAddrModes = Set(IndexedX, IndexedY, ZeroPageIndirect, AbsoluteIndexedX) + if (badAddrModes(a1) || badAddrModes(a2)) return false + + def handleKnownDistance(distance: Short): Boolean = { + val indexingAddrModes = Set(AbsoluteIndexedX, AbsoluteX, ZeroPageX, AbsoluteY, ZeroPageY) + val a1Indexing = indexingAddrModes(a1) + val a2Indexing = indexingAddrModes(a2) + (a1Indexing, a2Indexing) match { + case (false, false) => distance != 0 + case (true, false) => distance > 255 || distance < 0 + case (false, true) => distance > 0 || distance < -255 + case (true, true) => distance > 255 || distance < -255 + } + } + + (p1, p2) match { + case (NumericConstant(n1, _), NumericConstant(n2, _)) => + handleKnownDistance((n2 - n1).toShort) + case (a, CompoundConstant(MathOperator.Plus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify => + handleKnownDistance(distance.toShort) + case (CompoundConstant(MathOperator.Plus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify => + handleKnownDistance((-distance).toShort) + case (a, CompoundConstant(MathOperator.Minus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify => + handleKnownDistance((-distance).toShort) + case (CompoundConstant(MathOperator.Minus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify => + handleKnownDistance(distance.toShort) + case (MemoryAddressConstant(MemoryVariable(a, _, _)), MemoryAddressConstant(MemoryVariable(b, _, _))) => + a.takeWhile(_ != '.') != a.takeWhile(_ != '.') // TODO: ??? + case _ => + false + } + } +} + +case class AssemblyRule(pattern: AssemblyPattern, result: (List[AssemblyLine], AssemblyMatchingContext) => List[AssemblyLine]) { + +} + +trait AssemblyPattern { + + def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = () + + def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] + + def ~(x: AssemblyPattern) = Concatenation(this, x) + + def ~(x: AssemblyLinePattern) = Concatenation(this, x) + + def ~~>(result: (List[AssemblyLine], AssemblyMatchingContext) => List[AssemblyLine]) = AssemblyRule(this, result) + + def ~~>(result: List[AssemblyLine] => List[AssemblyLine]) = AssemblyRule(this, (code, _) => result(code)) + + def capture(i: Int) = Capture(i, this) + + def captureLength(i: Int) = CaptureLength(i, this) +} + +case class Capture(i: Int, pattern: AssemblyPattern) extends AssemblyPattern { + override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = + for { + rest <- pattern.matchTo(ctx, code) + } yield { + ctx.addObject(i, code.take(code.length - rest.length).map(_._2)) + rest + } + + override def toString: String = s"(?<$i>$pattern)" +} + +case class CaptureLength(i: Int, pattern: AssemblyPattern) extends AssemblyPattern { + override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = + for { + rest <- pattern.matchTo(ctx, code) + } yield { + ctx.addObject(i, code.length - rest.length) + rest + } + + override def toString: String = s"(?<$i>$pattern)" +} + + +case class Where(predicate: (AssemblyMatchingContext => Boolean)) extends AssemblyPattern { + def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = { + if (predicate(ctx)) Some(code) else None + } + + override def toString: String = "Where(...)" +} + +case class Concatenation(l: AssemblyPattern, r: AssemblyPattern) extends AssemblyPattern { + + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = { + l.validate(needsFlowInfo) + r.validate(needsFlowInfo) + } + + def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = { + for { + middle <- l.matchTo(ctx, code) + end <- r.matchTo(ctx, middle) + } yield end + } + + override def toString: String = (l, r) match { + case (_: Both, _: Both) => s"($l) · ($r)" + case (_, _: Both) => s"$l · ($r)" + case (_: Both, _) => s"($l) · $r" + case _ => s"$l · $r" + } +} + +case class Many(rule: AssemblyLinePattern) extends AssemblyPattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = { + rule.validate(needsFlowInfo) + } + + def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = { + var c = code + while (true) { + c match { + case Nil => + return Some(Nil) + case x :: xs => + if (rule.matchLineTo(ctx, x._1, x._2)) { + c = xs + } else { + return Some(c) + } + } + } + None + } + + override def toString: String = s"[$rule]*" +} + +case class ManyWhereAtLeastOne(rule: AssemblyLinePattern, atLeastOneIsThis: AssemblyLinePattern) extends AssemblyPattern { + + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = { + rule.validate(needsFlowInfo) + } + + def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = { + var c = code + var oneFound = false + while (true) { + c match { + case Nil => + return Some(Nil) + case x :: xs => + if (atLeastOneIsThis.matchLineTo(ctx, x._1, x._2)) { + oneFound = true + } + if (rule.matchLineTo(ctx, x._1, x._2)) { + c = xs + } else { + if (oneFound) { + return Some(c) + } else { + return None + } + } + } + } + None + } + + override def toString: String = s"[∃$atLeastOneIsThis:$rule]*" +} + +case class Opt(rule: AssemblyLinePattern) extends AssemblyPattern { + + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = { + rule.validate(needsFlowInfo) + } + + def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = { + code match { + case Nil => + Some(Nil) + case x :: xs => + if (rule.matchLineTo(ctx, x._1, x._2)) { + Some(xs) + } else { + Some(code) + } + } + } + + override def toString: String = s"[$rule]?" +} + +trait AssemblyLinePattern extends AssemblyPattern { + def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = code match { + case Nil => None + case x :: xs => if (matchLineTo(ctx, x._1, x._2)) Some(xs) else None + } + + def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean + + def unary_! : AssemblyLinePattern = Not(this) + + def ? : AssemblyPattern = Opt(this) + + def * : AssemblyPattern = Many(this) + + def + : AssemblyPattern = this ~ Many(this) + + def |(x: AssemblyLinePattern): AssemblyLinePattern = Either(this, x) + + def &(x: AssemblyLinePattern): AssemblyLinePattern = Both(this, x) + + protected def memoryAccessDoesntOverlap(a1: AddrMode.Value, p1: Constant, a2: AddrMode.Value, p2: Constant): Boolean = { + import AddrMode._ + val badAddrModes = Set(IndexedX, IndexedY, ZeroPageIndirect, AbsoluteIndexedX) + if (badAddrModes(a1) || badAddrModes(a2)) return false + val goodAddrModes = Set(Implied, Immediate, Relative) + if (goodAddrModes(a1) || goodAddrModes(a2)) return true + + def handleKnownDistance(distance: Short): Boolean = { + val indexingAddrModes = Set(AbsoluteIndexedX, AbsoluteX, ZeroPageX, AbsoluteY, ZeroPageY) + val a1Indexing = indexingAddrModes(a1) + val a2Indexing = indexingAddrModes(a2) + (a1Indexing, a2Indexing) match { + case (false, false) => distance != 0 + case (true, false) => distance > 255 || distance < 0 + case (false, true) => distance > 0 || distance < -255 + case (true, true) => distance > 255 || distance < -255 + } + } + + (p1.quickSimplify, p2.quickSimplify) match { + case (NumericConstant(n1, _), NumericConstant(n2, _)) => + handleKnownDistance((n2 - n1).toShort) + case (a, CompoundConstant(MathOperator.Plus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify => + handleKnownDistance(distance.toShort) + case (CompoundConstant(MathOperator.Plus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify => + handleKnownDistance((-distance).toShort) + case (a, CompoundConstant(MathOperator.Minus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify => + handleKnownDistance((-distance).toShort) + case (CompoundConstant(MathOperator.Minus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify => + handleKnownDistance(distance.toShort) + case (MemoryAddressConstant(a: ThingInMemory), MemoryAddressConstant(b:ThingInMemory)) => + a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ??? + case (CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(a: ThingInMemory), NumericConstant(_, _)), + MemoryAddressConstant(b: ThingInMemory)) => + a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ??? + case (MemoryAddressConstant(a: ThingInMemory), + CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) => + a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ??? + case (CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(a: ThingInMemory), NumericConstant(_, _)), + CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) => + a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ??? + case _ => + false + } + } +} + +//noinspection LanguageFeature +object AssemblyLinePattern { + implicit def __implicitOpcodeIn(ops: Set[Opcode.Value]): AssemblyLinePattern = HasOpcodeIn(ops) +} + +case class MatchA(i: Int) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + flowInfo.statusBefore.a match { + case SingleStatus(value) => ctx.addObject(i, value) + case _ => false + } +} + +case class MatchX(i: Int) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + flowInfo.statusBefore.x match { + case SingleStatus(value) => ctx.addObject(i, value) + case _ => false + } +} + +case class MatchY(i: Int) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + flowInfo.statusBefore.y match { + case SingleStatus(value) => ctx.addObject(i, value) + case _ => false + } +} + +case class HasA(value: Int) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + flowInfo.statusBefore.a.contains(value) +} + +case class HasX(value: Int) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + flowInfo.statusBefore.x.contains(value) +} + +case class HasY(value: Int) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + flowInfo.statusBefore.y.contains(value) +} + +case class DoesntMatterWhatItDoesWith(states: State.Value*) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertBackward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + states.forall(state => flowInfo.importanceAfter.isUnimportant(state)) + + override def toString: String = states.mkString("[¯\\_(ツ)_/¯:", ",", "]") +} + +case class HasSet(state: State.Value) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + flowInfo.hasSet(state) +} + +case class HasClear(state: State.Value) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + flowInfo.hasClear(state) +} + +case object Anything extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + true +} + +case class Not(inner: AssemblyLinePattern) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = inner.validate(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + !inner.matchLineTo(ctx, flowInfo, line) + + override def toString: String = "¬" + inner +} + +case class Both(l: AssemblyLinePattern, r: AssemblyLinePattern) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = { + l.validate(needsFlowInfo) + r.validate(needsFlowInfo) + } + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + l.matchLineTo(ctx, flowInfo, line) && r.matchLineTo(ctx, flowInfo, line) + + override def toString: String = l + " ∧ " + r +} + +case class Either(l: AssemblyLinePattern, r: AssemblyLinePattern) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = { + l.validate(needsFlowInfo) + r.validate(needsFlowInfo) + } + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + l.matchLineTo(ctx, flowInfo, line) || r.matchLineTo(ctx, flowInfo, line) + + override def toString: String = s"($l ∨ $r)" +} + +case object Elidable extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + line.elidable +} + +case object Linear extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.AllLinear(line.opcode) +} + +case object LinearOrBranch extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.AllLinear(line.opcode) || OpcodeClasses.ShortBranching(line.opcode) +} + +case object LinearOrLabel extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + line.opcode == Opcode.LABEL || OpcodeClasses.AllLinear(line.opcode) +} + +case object ReadsA extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.ReadsAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ReadsAIfImplied(line.opcode) +} + +case object ReadsMemory extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + line.addrMode match { + case AddrMode.Indirect => true + case AddrMode.Implied | AddrMode.Immediate => false + case _ => + OpcodeClasses.ReadsMemoryIfNotImpliedOrImmediate(line.opcode) + } +} + +case object ReadsX extends AssemblyLinePattern { + val XAddrModes = Set(AddrMode.AbsoluteX, AddrMode.IndexedX, AddrMode.ZeroPageX, AddrMode.AbsoluteIndexedX) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.ReadsXAlways(line.opcode) || XAddrModes(line.addrMode) +} + +case object ReadsY extends AssemblyLinePattern { + val YAddrModes = Set(AddrMode.AbsoluteY, AddrMode.IndexedY, AddrMode.ZeroPageY) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.ReadsYAlways(line.opcode) || YAddrModes(line.addrMode) +} + +case object ConcernsC extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.ReadsC(line.opcode) && OpcodeClasses.ChangesC(line.opcode) +} + +case object ConcernsA extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.ConcernsAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ConcernsAIfImplied(line.opcode) +} + +case object ConcernsX extends AssemblyLinePattern { + val XAddrModes = Set(AddrMode.AbsoluteX, AddrMode.IndexedX, AddrMode.ZeroPageX) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.ConcernsXAlways(line.opcode) || XAddrModes(line.addrMode) +} + +case object ConcernsY extends AssemblyLinePattern { + val YAddrModes = Set(AddrMode.AbsoluteY, AddrMode.IndexedY, AddrMode.ZeroPageY) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.ConcernsYAlways(line.opcode) || YAddrModes(line.addrMode) +} + +case object ChangesA extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.ChangesAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ChangesAIfImplied(line.opcode) +} + +case object ChangesMemory extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + OpcodeClasses.ChangesMemoryAlways(line.opcode) || line.addrMode != AddrMode.Implied && OpcodeClasses.ChangesMemoryIfNotImplied(line.opcode) +} + +case class DoesntChangeMemoryAt(addrMode1: Int, param1: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = { + val p1 = ctx.get[Constant](param1) + val p2 = line.parameter + val a1 = ctx.get[AddrMode.Value](addrMode1) + val a2 = line.addrMode + val changesSomeMemory = OpcodeClasses.ChangesMemoryAlways(line.opcode) || line.addrMode != AddrMode.Implied && OpcodeClasses.ChangesMemoryIfNotImplied(line.opcode) + !changesSomeMemory || memoryAccessDoesntOverlap(a1, p1, a2, p2) + } +} + +case object ConcernsMemory extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + ReadsMemory.matchLineTo(ctx, flowInfo, line) && ChangesMemory.matchLineTo(ctx, flowInfo, line) +} + +case class DoesNotConcernMemoryAt(addrMode1: Int, param1: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = { + val p1 = ctx.get[Constant](param1) + val p2 = line.parameter + val a1 = ctx.get[AddrMode.Value](addrMode1) + val a2 = line.addrMode + memoryAccessDoesntOverlap(a1, p1, a2, p2) + } +} + +case class HasOpcode(op: Opcode.Value) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + line.opcode == op + + override def toString: String = op.toString +} + +case class HasOpcodeIn(ops: Set[Opcode.Value]) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + ops(line.opcode) + + override def toString: String = ops.mkString("{", ",", "}") +} + +case class HasAddrMode(am: AddrMode.Value) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + line.addrMode == am + + override def toString: String = am.toString +} + +case class HasAddrModeIn(ams: Set[AddrMode.Value]) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + ams(line.addrMode) + + override def toString: String = ams.mkString("{", ",", "}") +} + +case class HasImmediate(i: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + line.addrMode == AddrMode.Immediate && (line.parameter.quickSimplify match { + case NumericConstant(j, _) => (i & 0xff) == (j & 0xff) + case _ => false + }) + + override def toString: String = "#" + i +} + +case class MatchObject(i: Int, f: Function[AssemblyLine, Any]) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + ctx.addObject(i, f(line)) + + override def toString: String = s"(?<$i>...)" +} + +case class MatchParameter(i: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + ctx.addObject(i, line.parameter.quickSimplify) + + override def toString: String = s"(?<$i>Param)" +} + +case class DontMatchParameter(i: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + ctx.dontMatch(i, line.parameter.quickSimplify) + + override def toString: String = s"¬(?<$i>Param)" +} + +case class MatchAddrMode(i: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + ctx.addObject(i, line.addrMode) + + override def toString: String = s"¬(?<$i>AddrMode)" +} + +case class MatchOpcode(i: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + ctx.addObject(i, line.opcode) + + override def toString: String = s"¬(?<$i>Op)" +} + +case class MatchImmediate(i: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + if (line.addrMode == AddrMode.Immediate) { + ctx.addObject(i, line.parameter.quickSimplify) + } else false + + override def toString: String = s"(?<$i>#)" +} + + +case class DoesntChangeIndexingInAddrMode(i: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + ctx.get[AddrMode.Value](i) match { + case AddrMode.ZeroPageX | AddrMode.AbsoluteX | AddrMode.IndexedX | AddrMode.AbsoluteIndexedX => !OpcodeClasses.ChangesX.contains(line.opcode) + case AddrMode.ZeroPageY | AddrMode.AbsoluteY | AddrMode.IndexedY => !OpcodeClasses.ChangesY.contains(line.opcode) + case _ => true + } + + override def toString: String = s"¬(?<$i>AddrMode)" +} + +case class Before(pattern: AssemblyPattern) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = { + pattern.validate(needsFlowInfo) + } + + override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = code match { + case Nil => None + case x :: xs => pattern.matchTo(ctx, xs) match { + case Some(m) => Some(xs) + case None => None + } + } + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = ??? +} \ No newline at end of file diff --git a/src/main/scala/millfork/assembly/opt/SizeOptimizations.scala b/src/main/scala/millfork/assembly/opt/SizeOptimizations.scala new file mode 100644 index 00000000..d76eefb5 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/SizeOptimizations.scala @@ -0,0 +1,8 @@ +package millfork.assembly.opt + +/** + * @author Karol Stasiak + */ +object SizeOptimizations { + +} diff --git a/src/main/scala/millfork/assembly/opt/SuperOptimizer.scala b/src/main/scala/millfork/assembly/opt/SuperOptimizer.scala new file mode 100644 index 00000000..6213d13d --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/SuperOptimizer.scala @@ -0,0 +1,75 @@ +package millfork.assembly.opt + +import millfork.{CompilationFlag, CompilationOptions, OptimizationPresets} +import millfork.assembly.{AddrMode, AssemblyLine, Opcode} +import millfork.env.NormalFunction +import millfork.error.ErrorReporting + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ +object SuperOptimizer extends AssemblyOptimization { + + def optimize(m: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = { + val oldVerbosity = ErrorReporting.verbosity + ErrorReporting.verbosity = -1 + var allOptimizers = OptimizationPresets.Good ++ LaterOptimizations.All + if (options.flag(CompilationFlag.EmitIllegals)) { + allOptimizers ++= UndocumentedOptimizations.All + } + if (options.flag(CompilationFlag.EmitCmosOpcodes)) { + allOptimizers ++= CmosOptimizations.All + } + allOptimizers ++= List( + VariableToRegisterOptimization, + ChangeIndexRegisterOptimizationPreferringX2Y, + ChangeIndexRegisterOptimizationPreferringY2X, + UnusedLabelRemoval) + val seenSoFar = mutable.Set[CodeView]() + val queue = mutable.Queue[(List[AssemblyOptimization], List[AssemblyLine])]() + val leaves = mutable.ListBuffer[(List[AssemblyOptimization], List[AssemblyLine])]() + seenSoFar += viewCode(code) + queue.enqueue(Nil -> code) + while(queue.nonEmpty) { + val (optsSoFar, codeSoFar) = queue.dequeue() + var isLeaf = true + allOptimizers.par.foreach { o => + val optimized = o.optimize(m, codeSoFar, options) + val view = viewCode(optimized) + seenSoFar.synchronized{ + if (!seenSoFar(view)) { + isLeaf = false + seenSoFar += view + queue.enqueue((o :: optsSoFar) -> optimized) + } + } + } + if (isLeaf) { +// println(codeSoFar.map(_.sizeInBytes).sum + " B: " + optsSoFar.reverse.map(_.name).mkString(" -> ")) + leaves += optsSoFar -> codeSoFar + } + } + + val result = leaves.minBy(_._2.map(_.cost).sum) + ErrorReporting.verbosity = oldVerbosity + ErrorReporting.debug(s"Visited ${leaves.size} leaves") + ErrorReporting.debug(s"${code.map(_.sizeInBytes).sum} B -> ${result._2.map(_.sizeInBytes).sum} B: ${result._1.reverse.map(_.name).mkString(" -> ")}") + result._1.reverse.foldLeft(code){(c, opt) => + val n = opt.optimize(m, c, options) +// println(c.mkString("","","")) +// println(n.mkString("","","")) + n + } + result._2 + } + + override val name = "Superoptimizer" + + def viewCode(code: List[AssemblyLine]): CodeView = { + CodeView(code.map(l => l.opcode -> l.addrMode)) + } +} + +case class CodeView(content: List[(Opcode.Value, AddrMode.Value)]) diff --git a/src/main/scala/millfork/assembly/opt/UndocumentedOptimizations.scala b/src/main/scala/millfork/assembly/opt/UndocumentedOptimizations.scala new file mode 100644 index 00000000..456256c6 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/UndocumentedOptimizations.scala @@ -0,0 +1,340 @@ +package millfork.assembly.opt + +import java.util.concurrent.atomic.AtomicInteger + +import millfork.assembly.{AddrMode, AssemblyLine, Opcode, State} +import millfork.assembly.Opcode._ +import millfork.assembly.AddrMode._ +import millfork.assembly.OpcodeClasses._ +import millfork.env.{Constant, NormalFunction, NumericConstant} + +/** + * @author Karol Stasiak + */ +object UndocumentedOptimizations { + + val counter = new AtomicInteger(30000) + + def getNextLabel(prefix: String) = f".${prefix}%s__${counter.getAndIncrement()}%05d" + + // TODO: test these + + private val LaxAddrModeRestriction = Not(HasAddrModeIn(Set(AbsoluteX, ZeroPageX, IndexedX, Immediate))) + + //noinspection ScalaUnnecessaryParentheses + val UseLax = new RuleBasedAssemblyOptimization("Using undocumented instruction LAX", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~ + (LinearOrLabel & Not(ConcernsA) & Not(ChangesMemory) & Not(HasOpcode(LDX))).*.capture(2) ~ + (HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) => + ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX) + }, + (HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~ + (LinearOrLabel & Not(ConcernsX) & Not(ChangesMemory) & Not(HasOpcode(LDA))).*.capture(2) ~ + (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) => + ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX) + }, + + (HasOpcode(LDA) & Elidable & LaxAddrModeRestriction) ~ + (LinearOrLabel & Not(ConcernsA) & Not(ChangesMemory) & Not(HasOpcode(TAX))).*.capture(2) ~ + (HasOpcode(TAX) & Elidable) ~~> { (code, ctx) => + ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX) + }, + (HasOpcode(LDX) & Elidable & LaxAddrModeRestriction) ~ + (LinearOrLabel & Not(ConcernsX) & Not(ChangesMemory) & Not(HasOpcode(TXA))).*.capture(2) ~ + (HasOpcode(TXA) & Elidable) ~~> { (code, ctx) => + ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX) + }, + + (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~ + (LinearOrLabel & Not(ConcernsX) & Not(ChangesA) & Not(ChangesMemory) & Not(HasOpcode(LDX))).*.capture(2) ~ + (HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) => + code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2) + }, + (HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~ + (LinearOrLabel & Not(ConcernsA) & Not(ChangesX) & Not(ChangesMemory) & Not(HasOpcode(LDA))).*.capture(2) ~ + (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) => + code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2) + }, + + (HasOpcode(LDA) & Elidable & LaxAddrModeRestriction) ~ + (LinearOrLabel & Not(ConcernsX) & Not(ChangesA) & Not(ChangesMemory) & Not(HasOpcode(TAX))).*.capture(2) ~ + (HasOpcode(TAX) & Elidable & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) => + code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2) + }, + (HasOpcode(LDX) & Elidable & LaxAddrModeRestriction) ~ + (LinearOrLabel & Not(ConcernsA) & Not(ChangesX) & Not(ChangesMemory) & Not(HasOpcode(TXA))).*.capture(2) ~ + (HasOpcode(TXA) & Elidable & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) => + code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2) + }, + ) + + val SaxModes: Set[AddrMode.Value] = Set(ZeroPage, IndexedX, ZeroPageY, Absolute) + + val UseSax = new RuleBasedAssemblyOptimization("Using undocumented instruction SAX", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & Not(ConcernsA) & Not(ConcernsX)).?.capture(10) ~ + (HasOpcode(AND) & Elidable & MatchAddrMode(2) & MatchParameter(3) & Not(ReadsX)) ~ + (Linear & Not(ConcernsA) & Not(ConcernsX)).?.capture(11) ~ + (HasOpcode(STA) & Elidable & MatchAddrMode(4) & MatchParameter(5) & HasAddrModeIn(SaxModes) & DontMatchParameter(0)) ~ + (Linear & Not(ConcernsA) & Not(ConcernsX) & Not(ChangesMemory)).?.capture(12) ~ + (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~ + (LinearOrLabel & Not(ConcernsX)).*.capture(13) ~ OverwritesX ~~> { (code, ctx) => + val lda = code.head + val ldx = AssemblyLine(LDX, ctx.get[AddrMode.Value](2), ctx.get[Constant](3)) + val sax = AssemblyLine(SAX, ctx.get[AddrMode.Value](4), ctx.get[Constant](5)) + val fragment0 = lda :: ctx.get[List[AssemblyLine]](10) + val fragment1 = ldx :: ctx.get[List[AssemblyLine]](11) + val fragment2 = sax :: ctx.get[List[AssemblyLine]](12) + val fragment3 = ctx.get[List[AssemblyLine]](13) + List(fragment0, fragment1, fragment2, fragment3).flatten + }, + ) + + def andConstant(const: Constant, mask: Int): Option[Long] = const match { + case NumericConstant(n, _) => Some(n & mask) + case _ => None + } + + val UseAnc = new RuleBasedAssemblyOptimization("Using undocumented instruction ANC", + needsFlowInfo = FlowInfoRequirement.BothFlows, + (Elidable & HasOpcode(LDA) & HasImmediate(0)) ~ + (Elidable & HasOpcode(CLC)) ~~> (_ => List(AssemblyLine.immediate(ANC, 0))), + (Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.C)) ~~> (_ => List(AssemblyLine.immediate(ANC, 0))), + (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~ + Where(c => andConstant(c.get[Constant](0), 0x80).contains(0)) ~ + (Elidable & HasOpcode(CLC)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))), + (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~ + Where(c => andConstant(c.get[Constant](0), 0x80).contains(0x80)) ~ + (Elidable & HasOpcode(SEC)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))), + (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~ + (Elidable & HasOpcode(CMP) & HasImmediate(0x80) & DoesntMatterWhatItDoesWith(State.Z, State.N)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))), + (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~ + (Elidable & HasOpcode(CMP) & HasImmediate(0x80) & DoesntMatterWhatItDoesWith(State.Z, State.N)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))), + (Elidable & HasOpcode(AND) & MatchImmediate(0) & HasClear(State.C)) ~ + Where(c => andConstant(c.get[Constant](0), 0x80).contains(0)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))), + (Elidable & HasOpcode(AND) & MatchImmediate(0) & HasSet(State.C)) ~ + Where(c => andConstant(c.get[Constant](0), 0x80).contains(0)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))), + (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~ + (Elidable & HasOpcodeIn(Set(ROL, ASL)) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.Z, State.N, State.A)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))), + ) + + val UseSbx = new RuleBasedAssemblyOptimization("Using undocumented instruction SBX", + needsFlowInfo = FlowInfoRequirement.BothFlows, + (Elidable & HasOpcode(DEX) & DoesntMatterWhatItDoesWith(State.A, State.C)).+.captureLength(0) ~ + Where(_.get[Int](0) > 2) ~~> ((_, ctx) => List( + AssemblyLine.implied(TXA), + AssemblyLine.immediate(SBX, ctx.get[Int](0)), + )), + (Elidable & HasOpcode(INX) & DoesntMatterWhatItDoesWith(State.A, State.C)).+.captureLength(0) ~ + Where(_.get[Int](0) > 2) ~~> ((_, ctx) => List( + AssemblyLine.implied(TXA), + AssemblyLine.immediate(SBX, 256 - ctx.get[Int](0)), + )), + HasOpcode(TXA) ~ + (Elidable & HasOpcode(CLC)).? ~ + (Elidable & HasClear(State.C) & HasClear(State.D) & HasOpcode(ADC) & MatchImmediate(0)) ~ + (Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.C, State.A)) ~~> ((code, ctx) => List( + code.head, + AssemblyLine.immediate(SBX, 256 - ctx.get[Int](0)), + )), + HasOpcode(TXA) ~ + (Elidable & HasOpcode(SEC)).? ~ + (Elidable & HasSet(State.C) & HasClear(State.D) & HasOpcode(SBC) & MatchImmediate(0)) ~ + (Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.C, State.A)) ~~> ((code, ctx) => List( + code.head, + AssemblyLine.immediate(SBX, ctx.get[Int](0)), + )), + ) + + + val UseAlr = new RuleBasedAssemblyOptimization("Using undocumented instruction ALR", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (Elidable & HasOpcode(AND) & HasAddrMode(Immediate)) ~ + (Elidable & HasOpcode(LSR) & HasAddrMode(Implied)) ~~> { (code, ctx) => + List(AssemblyLine.immediate(ALR, code.head.parameter)) + }, + (Elidable & HasOpcode(LSR) & HasAddrMode(Implied)) ~ + (Elidable & HasOpcode(CLC)) ~~> { (code, ctx) => + List(AssemblyLine.immediate(ALR, 0xFE)) + }, + ) + + val UseArr = new RuleBasedAssemblyOptimization("Using undocumented instruction ARR", + needsFlowInfo = FlowInfoRequirement.BothFlows, + (HasClear(State.D) & Elidable & HasOpcode(AND) & HasAddrMode(Immediate)) ~ + (Elidable & HasOpcode(ROR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) => + List(AssemblyLine.immediate(ARR, code.head.parameter)) + }, + ) + + private def trivialSequence1(o1: Opcode.Value, o2: Opcode.Value, extra: AssemblyLinePattern, combined: Opcode.Value) = + (Elidable & HasOpcode(o1) & HasAddrMode(Absolute) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & DoesNotConcernMemoryAt(0, 1) & extra).* ~ + (Elidable & HasOpcode(o2) & HasAddrMode(Absolute) & MatchParameter(1)) ~~> { (code, ctx) => + code.tail.init :+ AssemblyLine(combined, Absolute, ctx.get[Constant](1)) + } + + private def trivialSequence2(o1: Opcode.Value, o2: Opcode.Value, extra: AssemblyLinePattern, combined: Opcode.Value) = + (Elidable & HasOpcode(o1) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & DoesNotConcernMemoryAt(0, 1) & extra).* ~ + (Elidable & HasOpcode(o2) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) => + code.tail.init :+ AssemblyLine(combined, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) + } + + // ROL c LDA c AND d => LDA d RLA c + private def trivialCommutativeSequence(o1: Opcode.Value, o2: Opcode.Value, combined: Opcode.Value) = { + (Elidable & HasOpcode(o1) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(LDA) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(o2) & MatchAddrMode(2) & MatchParameter(3)) ~~> { code => + List(code(2).copy(opcode = LDA), code(1).copy(opcode = combined)) + } + } + + val UseSlo = new RuleBasedAssemblyOptimization("Using undocumented instruction SLO", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + trivialSequence1(ASL, ORA, Not(ConcernsC), SLO), + trivialSequence2(ASL, ORA, Not(ConcernsC), SLO), + trivialCommutativeSequence(ASL, ORA, SLO), + (Elidable & HasOpcode(ASL) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & Not(ConcernsMemory)).* ~ + (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) => + code.tail.init ++ List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SLO, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))) + }, + (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & Not(ConcernsMemory) & Not(ChangesA)).*.capture(2) ~ + (Elidable & HasOpcode(ASL) & HasAddrMode(Implied)) ~ + (Linear & Not(ConcernsMemory) & Not(ChangesA) & Not(ReadsC) & Not(ReadsNOrZ)).*.capture(3) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) => + List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SRE, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))) ++ + ctx.get[List[AssemblyLine]](2) ++ + ctx.get[List[AssemblyLine]](3) + }, + ) + + val UseSre = new RuleBasedAssemblyOptimization("Using undocumented instruction SRE", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + trivialSequence1(LSR, EOR, Not(ConcernsC), SRE), + trivialSequence2(LSR, EOR, Not(ConcernsC), SRE), + trivialCommutativeSequence(LSR, EOR, SRE), + (Elidable & HasOpcode(LSR) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & Not(ConcernsMemory)).* ~ + (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) => + code.tail.init ++ List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SRE, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))) + }, + (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Linear & Not(ConcernsMemory) & Not(ChangesA)).*.capture(2) ~ + (Elidable & HasOpcode(LSR) & HasAddrMode(Implied)) ~ + (Linear & Not(ConcernsMemory) & Not(ChangesA) & Not(ReadsC) & Not(ReadsNOrZ)).*.capture(3) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) => + List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SRE, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))) ++ + ctx.get[List[AssemblyLine]](2) ++ + ctx.get[List[AssemblyLine]](3) + }, + ) + + val UseRla = new RuleBasedAssemblyOptimization("Using undocumented instruction RLA", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + trivialSequence1(ROL, AND, Not(ConcernsC), RLA), + trivialSequence2(ROL, AND, Not(ConcernsC), RLA), + trivialCommutativeSequence(ROL, AND, RLA), + ) + + val UseRra = new RuleBasedAssemblyOptimization("Using undocumented instruction RRA", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + // TODO: is it ok? carry flag and stuff? + trivialSequence1(ROR, ADC, Not(ConcernsC), RRA), + trivialSequence2(ROR, ADC, Not(ConcernsC), RRA), + trivialCommutativeSequence(ROR, ADC, RRA), + ) + + val UseDcp = new RuleBasedAssemblyOptimization("Using undocumented instruction DCP", + needsFlowInfo = FlowInfoRequirement.BothFlows, + trivialSequence1(DEC, CMP, Not(ConcernsC), DCP), + trivialSequence2(DEC, CMP, Not(ConcernsC), DCP), + (Elidable & HasOpcode(LDA) & HasAddrModeIn(Set(IndexedX, ZeroPageX, AbsoluteX))) ~ + (Elidable & HasOpcode(TAX)) ~ + (Elidable & HasOpcode(DEC) & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.A, State.Y, State.X, State.C, State.Z, State.N, State.V)) ~~> { code => + List(code.head.copy(opcode = LDY), code.last.copy(opcode = DCP, addrMode = AbsoluteY)) + }, + (Elidable & HasOpcode(DEC) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(LDA) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(CMP) & MatchAddrMode(2) & MatchParameter(3) & DoesntMatterWhatItDoesWith(State.V, State.C, State.N, State.A)) ~~> { code => + List(code(2).copy(opcode = LDA), code(1).copy(opcode = DCP)) + } + ) + + val UseIsc = new RuleBasedAssemblyOptimization("Using undocumented instruction ISC", + needsFlowInfo = FlowInfoRequirement.BothFlows, + trivialSequence1(INC, SBC, Not(ReadsC), ISC), + trivialSequence2(INC, SBC, Not(ReadsC), ISC), + (Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.D)) ~ + (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + val label = getNextLabel("is") + List( + AssemblyLine.relative(BCC, label), + code.last.copy(opcode = ISC), + AssemblyLine.label(label)) + }, + (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~ + (Elidable & HasOpcode(ADC) & HasImmediate(0) & HasClear(State.D)) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + val label = getNextLabel("is") + List( + AssemblyLine.relative(BCC, label), + code.last.copy(opcode = ISC), + AssemblyLine.label(label)) + }, + (Elidable & HasOpcode(CLC)).? ~ + (Elidable & HasOpcode(LDA) & HasImmediate(1) & HasClear(State.D) & HasClear(State.C)) ~ + (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + List(code.last.copy(opcode = ISC)) + }, + (Elidable & HasOpcode(CLC)).? ~ + (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & HasClear(State.D) & HasClear(State.C) & MatchAddrMode(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~ + (Elidable & HasOpcode(ADC) & HasImmediate(1)) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchAddrMode(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + List(code.last.copy(opcode = ISC)) + }, + (Elidable & HasOpcode(SEC)).? ~ + (Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.D) & HasSet(State.C)) ~ + (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + List(code.last.copy(opcode = ISC)) + }, + (Elidable & HasOpcode(SEC)).? ~ + (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & HasClear(State.D) & HasSet(State.C) & MatchAddrMode(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~ + (Elidable & HasOpcode(ADC) & HasImmediate(0)) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchAddrMode(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code => + List(code.last.copy(opcode = ISC)) + }, + (Elidable & HasOpcode(LDA) & HasAddrModeIn(Set(IndexedX, ZeroPageX, AbsoluteX))) ~ + (Elidable & HasOpcode(TAX)) ~ + (Elidable & HasOpcode(INC) & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.A, State.Y, State.X, State.C, State.Z, State.N, State.V)) ~~> { code => + List(code.head.copy(opcode = LDY), code.last.copy(opcode = ISC, addrMode = AbsoluteY)) + }, + (Elidable & HasOpcode(INC) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(LDA) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(CMP) & HasClear(State.D) & MatchAddrMode(2) & MatchParameter(3) & DoesntMatterWhatItDoesWith(State.V, State.C, State.N, State.A)) ~~> { code => + List(code(2).copy(opcode = LDA), AssemblyLine.implied(SEC), code(1).copy(opcode = ISC)) + } + ) + + val All: List[AssemblyOptimization] = List( + UseLax, + UseSax, + UseSbx, + UseAnc, + UseSlo, + UseSre, + UseAlr, + UseArr, + UseRla, + UseRra, + UseIsc, + UseDcp, + ) +} diff --git a/src/main/scala/millfork/assembly/opt/UnusedLabelRemoval.scala b/src/main/scala/millfork/assembly/opt/UnusedLabelRemoval.scala new file mode 100644 index 00000000..dd919b48 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/UnusedLabelRemoval.scala @@ -0,0 +1,38 @@ +package millfork.assembly.opt + +import millfork.CompilationOptions +import millfork.assembly.AddrMode._ +import millfork.assembly.Opcode._ +import millfork.assembly.{AddrMode, AssemblyLine} +import millfork.env._ +import millfork.error.ErrorReporting + +/** + * @author Karol Stasiak + */ +object UnusedLabelRemoval extends AssemblyOptimization { + + override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = { + val usedLabels = code.flatMap { + case AssemblyLine(LABEL, _, _, _) => None + case AssemblyLine(_, _, MemoryAddressConstant(Label(l)), _) => Some(l) + case _ => None + }.toSet + val definedLabels = code.flatMap { + case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) => Some(l).filter(_.startsWith(".")) + case _ => None + }.toSet + val toRemove = definedLabels -- usedLabels + if (toRemove.nonEmpty) { + ErrorReporting.debug("Removing labels: " + toRemove.mkString(", ")) + code.filterNot { + case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) => toRemove(l) + case _ => false + } + } else { + code + } + } + + override def name = "Unused label removal" +} diff --git a/src/main/scala/millfork/assembly/opt/VariableToRegisterOptimization.scala b/src/main/scala/millfork/assembly/opt/VariableToRegisterOptimization.scala new file mode 100644 index 00000000..e6bc2b12 --- /dev/null +++ b/src/main/scala/millfork/assembly/opt/VariableToRegisterOptimization.scala @@ -0,0 +1,322 @@ +package millfork.assembly.opt + +import millfork.CompilationOptions +import millfork.assembly.{AddrMode, AssemblyLine} +import millfork.assembly.Opcode._ +import millfork.assembly.AddrMode._ +import millfork.env._ +import millfork.error.ErrorReporting + +import scala.annotation.tailrec + +/** + * @author Karol Stasiak + */ +object VariableToRegisterOptimization extends AssemblyOptimization { + + // If any of these opcodes is present within a method, + // then it's too hard to assign any variable to a register. + private val opcodesThatAlwaysPrecludeXAllocation = Set(JSR, STX, TXA, PHX, PLX, INX, DEX, CPX, SBX, SAX) + + private val opcodesThatAlwaysPrecludeYAllocation = Set(JSR, STY, TYA, PHY, PLY, INY, DEY, CPY) + + // If any of these opcodes is used on a variable + // then it's too hard to assign that variable to a register. + // Also, LDY prevents assigning a variable to X and LDX prevents assigning a variable to Y. + private val opcodesThatCannotBeUsedWithIndexRegistersAsParameters = + Set(EOR, ORA, AND, BIT, ADC, SBC, CMP, CPX, CPY, STY, STX) + + override def name = "Allocating variables to index registers" + + + override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = { + val paramVariables = f.params match { + case NormalParamSignature(ps) => + ps.map(_.name).toSet + case _ => + // assembly functions do not get this optimization + return code + } + val stillUsedVariables = code.flatMap { + case AssemblyLine(_, _, MemoryAddressConstant(th), _) => Some(th.name) + case _ => None + }.toSet + val localVariables = f.environment.getAllLocalVariables.filter { + case MemoryVariable(name, typ, VariableAllocationMethod.Auto) => + typ.size == 1 && !paramVariables(name) && stillUsedVariables(name) + case _ => false + } + + val candidates = None :: localVariables.map(v => Option(v.name)) + + val variants = for { + vx <- candidates.par + vy <- candidates + if vx != vy + (score, prologueLength) <- canBeInlined(vx, vy, code.tail, Some(1)) + if prologueLength >= 1 + } yield (score, prologueLength, vx, vy) + + if (variants.isEmpty) { + return code + } + + val (_, bestPrologueLength, bestX, bestY) = variants.max + + if ((bestX.isDefined || bestY.isDefined) && bestPrologueLength != 0xffff) { + (bestX, bestY) match { + case (Some(x), Some(y)) => ErrorReporting.debug(s"Inlining $x to X and $y to Y") + case (Some(x), None) => ErrorReporting.debug(s"Inlining $x to X") + case (None, Some(y)) => ErrorReporting.debug(s"Inlining $y to Y") + case _ => + } + bestX.foreach(f.environment.removeVariable) + bestY.foreach(f.environment.removeVariable) + code.take(bestPrologueLength) ++ inlineVars(bestX, bestY, code.drop(bestPrologueLength)) + } else { + code + } + } + + + private def add(i: Int) = (p: (Int, Int)) => (p._1 + i) -> p._2 + + private def mark(i: Option[Int]) = (p: (Int, Int)) => p._1 -> i.getOrElse(p._2) + + def canBeInlined(xCandidate: Option[String], yCandidate: Option[String], lines: List[AssemblyLine], instrCounter: Option[Int]): Option[(Int, Int)] = { + val vx = xCandidate.getOrElse("-") + val vy = yCandidate.getOrElse("-") + val next = instrCounter.map(_ + 1) + val next2 = instrCounter.map(_ + 2) + lines match { + case AssemblyLine(_, Immediate, SubbyteConstant(MemoryAddressConstant(th), _), _) :: xs + if th.name == vx || th.name == vy => + // if an address of a variable is used, then that variable cannot be assigned to a register + None + case AssemblyLine(_, Immediate, HalfWordConstant(MemoryAddressConstant(th), _), _) :: xs + if th.name == vx || th.name == vy => + // if an address of a variable is used, then that variable cannot be assigned to a register + None + + case AssemblyLine(_, AbsoluteX | AbsoluteY | ZeroPageX | ZeroPageY, MemoryAddressConstant(th), _) :: xs => + // if a variable is used as an array, then it cannot be assigned to a register + if (th.name == vx || th.name == vy) { + None + } else { + canBeInlined(xCandidate, yCandidate, xs, next) + } + + case AssemblyLine(opcode, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vx && (opcode == LDY || opcodesThatCannotBeUsedWithIndexRegistersAsParameters(opcode)) => + // if a variable is used by some opcodes, then it cannot be assigned to a register + None + + case AssemblyLine(opcode, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vy && (opcode == LDX || opcode == LAX || opcodesThatCannotBeUsedWithIndexRegistersAsParameters(opcode)) => + // if a variable is used by some opcodes, then it cannot be assigned to a register + None + + case AssemblyLine(LDX, Absolute, MemoryAddressConstant(th), elidable) :: xs + if xCandidate.isDefined => + // if a register is populated with a different variable, then this variable cannot be assigned to that register + // removing LDX saves 3 cycles + if (elidable && th.name == vx) { + canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter)) + } else { + None + } + + case AssemblyLine(LAX, Absolute, MemoryAddressConstant(th), elidable) :: xs + if xCandidate.isDefined => + // LAX = LDX-LDA, and since LDX simplifies to nothing and LDA simplifies to TXA, + // LAX simplifies to TXA, saving two bytes + if (elidable && th.name == vx) { + canBeInlined(xCandidate, yCandidate, xs, None).map(add(2)).map(mark(instrCounter)) + } else { + None + } + + case AssemblyLine(LDY, Absolute, MemoryAddressConstant(th), elidable) :: xs if yCandidate.isDefined => + // if a register is populated with a different variable, then this variable cannot be assigned to that register + // removing LDX saves 3 cycles + if (elidable && th.name == vy) { + canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter)) + } else { + None + } + + case AssemblyLine(LDX, _, _, _) :: xs if xCandidate.isDefined => + // if a register is populated with something else than a variable, then no variable cannot be assigned to that register + None + + case AssemblyLine(LDY, _, _, _) :: xs if yCandidate.isDefined => + // if a register is populated with something else than a variable, then no variable cannot be assigned to that register + None + + case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), elidable) :: AssemblyLine(TAX, _, _, elidable2) :: xs + if xCandidate.isDefined => + // a variable cannot be inlined if there is TAX not after LDA of that variable + // but LDA-TAX can be simplified to TXA + if (elidable && elidable2 && th.name == vx) { + canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter)) + } else { + None + } + + case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), elidable) :: AssemblyLine(TAY, _, _, elidable2) :: xs + if yCandidate.isDefined => + // a variable cannot be inlined if there is TAY not after LDA of that variable + // but LDA-TAY can be simplified to TYA + if (elidable && elidable2 && th.name == vy) { + canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter)) + } else { + None + } + + case AssemblyLine(LDA | STA | INC | DEC, Absolute, MemoryAddressConstant(th), elidable) :: xs => + // changing LDA->TXA, STA->TAX, INC->INX, DEC->DEX saves 2 cycles + if (th.name == vy || th.name == vx) { + if (elidable) canBeInlined(xCandidate, yCandidate, xs, None).map(add(2)).map(mark(instrCounter)) + else None + } else { + canBeInlined(xCandidate, yCandidate, xs, next) + } + + case AssemblyLine(TAX, _, _, _) :: xs if xCandidate.isDefined => + // a variable cannot be inlined if there is TAX not after LDA of that variable + if (instrCounter.isDefined) { + canBeInlined(xCandidate, yCandidate, xs, next) + } else None + + case AssemblyLine(TAY, _, _, _) :: xs if yCandidate.isDefined => + // a variable cannot be inlined if there is TAY not after LDA of that variable + if (instrCounter.isDefined) { + canBeInlined(xCandidate, yCandidate, xs, next) + } else None + + case AssemblyLine(LABEL, _, _, _) :: xs => + // labels always end the initial section + canBeInlined(xCandidate, yCandidate, xs, None).map(mark(instrCounter)) + + case x :: xs => + if (instrCounter.isDefined) { + canBeInlined(xCandidate, yCandidate, xs, next) + } else { + if (xCandidate.isDefined && opcodesThatAlwaysPrecludeXAllocation(x.opcode)) { + None + } else if (yCandidate.isDefined && opcodesThatAlwaysPrecludeYAllocation(x.opcode)) { + None + } else { + canBeInlined(xCandidate, yCandidate, xs, next) + } + } + + case Nil => Some(0 -> -1) + } + } + + def inlineVars(xCandidate: Option[String], yCandidate: Option[String], lines: List[AssemblyLine]): List[AssemblyLine] = { + val vx = xCandidate.getOrElse("-") + val vy = yCandidate.getOrElse("-") + lines match { + case AssemblyLine(INC, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vx => + AssemblyLine.implied(INX) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(INC, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vy => + AssemblyLine.implied(INY) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(DEC, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vx => + AssemblyLine.implied(DEX) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(DEC, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vy => + AssemblyLine.implied(DEY) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDX, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vx => + inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LAX, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vx => + AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDY, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vy => + inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), true) :: AssemblyLine(TAX, _, _, true) :: xs + if th.name == vx => + // these TXA's may get optimized away by a different optimization + AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), true) :: AssemblyLine(TAY, _, _, true) :: xs + if th.name == vy => + // these TYA's may get optimized away by a different optimization + AssemblyLine.implied(TYA) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDA, am, param, true) :: AssemblyLine(STA, Absolute, MemoryAddressConstant(th), true) :: xs + if th.name == vx && doesntUseX(am) => + // these TXA's may get optimized away by a different optimization + AssemblyLine(LDX, am, param) :: AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDA, am, param, true) :: AssemblyLine(STA, Absolute, MemoryAddressConstant(th), true) :: xs + if th.name == vy && doesntUseY(am) => + // these TYA's may get optimized away by a different optimization + AssemblyLine(LDY, am, param) :: AssemblyLine.implied(TYA) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: AssemblyLine(CMP, am, param, true) :: xs + if th.name == vx && doesntUseXOrY(am) => + // ditto + AssemblyLine.implied(TXA) :: AssemblyLine(CPX, am, param) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: AssemblyLine(CMP, am, param, true) :: xs + if th.name == vy && doesntUseXOrY(am) => + // ditto + AssemblyLine.implied(TYA) :: AssemblyLine(CPY, am, param) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vx => + AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vy => + AssemblyLine.implied(TYA) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(STA, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vx => + AssemblyLine.implied(TAX) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(STA, Absolute, MemoryAddressConstant(th), _) :: xs + if th.name == vy => + AssemblyLine.implied(TAY) :: inlineVars(xCandidate, yCandidate, xs) + + case AssemblyLine(TAX, _, _, _) :: xs if xCandidate.isDefined => + ErrorReporting.fatal("Unexpected TAX") + + case AssemblyLine(TAY, _, _, _) :: xs if yCandidate.isDefined => + ErrorReporting.fatal("Unexpected TAY") + + case x :: xs => x :: inlineVars(xCandidate, yCandidate, xs) + + case Nil => Nil + } + } + + def doesntUseY(am: AddrMode.Value): Boolean = am match { + case AbsoluteY | ZeroPageY | IndexedY => false + case _ => true + } + + def doesntUseX(am: AddrMode.Value): Boolean = am match { + case AbsoluteX | ZeroPageX | IndexedX => false + case _ => true + } + + def doesntUseXOrY(am: AddrMode.Value): Boolean = am match { + case Immediate | ZeroPage | Absolute | Relative | Indirect => true + case _ => false + } +} diff --git a/src/main/scala/millfork/cli/CliOption.scala b/src/main/scala/millfork/cli/CliOption.scala new file mode 100644 index 00000000..91925f55 --- /dev/null +++ b/src/main/scala/millfork/cli/CliOption.scala @@ -0,0 +1,201 @@ +package millfork.cli + +/** + * @author Karol Stasiak + */ +trait CliOption[T, O <: CliOption[T, O]] { + this: O => + def toStrings(firstTab: Int): List[String] = { + val fl = firstLine + if (_description == "") { + List(fl) + } else if (fl.length < firstTab) { + List(fl.padTo(firstTab, ' ') + _description) + } else { + List(fl, "".padTo(firstTab, ' ') + _description) + } + } + + protected def firstLine: String = names.mkString(" | ") + + def names: Seq[String] + + private[cli] def length: Int + + private[cli] val _shortName: String + private[cli] var _description: String = "" + private[cli] var _hidden = false + private[cli] var _maxEncounters = 1 + private[cli] var _minEncounters = 0 + private[cli] var _actualEncounters = 0 + private[cli] var _onTooFew: Option[Int => Unit] = None + private[cli] var _onTooMany: Option[Int => Unit] = None + + def validate(): Boolean = { + var ok = true + if (_actualEncounters < _minEncounters) { + _onTooFew.fold(throw new IllegalArgumentException(s"Too few ${_shortName} options: required ${_minEncounters}, given ${_actualEncounters}"))(_ (_actualEncounters)) + ok = false + } + if (_actualEncounters > _maxEncounters) { + _onTooMany.fold()(_ (_actualEncounters)) + ok = false + } + ok + } + + def onWrongNumber(action: Int => Unit): Unit = { + _onTooFew = Some(action) + _onTooMany = Some(action) + } + + def onTooFew(action: Int => Unit): Unit = { + _onTooFew = Some(action) + } + + def onTooMany(action: Int => Unit): Unit = { + _onTooMany = Some(action) + } + + def encounter(): Unit = { + _actualEncounters += 1 + } + + def description(d: String): O = { + _description = d + this + } + + def hidden(): O = { + _hidden = true + this + } + + def minCount(count: Int): O = { + _minEncounters = count + this + } + + def maxCount(count: Int): O = { + _maxEncounters = count + this + } + + def required(): O = minCount(1) + + def repeatable(): O = maxCount(Int.MaxValue) +} + +class Fluff[T](val text: Seq[String]) extends CliOption[T, Fluff[T]] { + this.repeatable() + + override def toStrings(firstTab: Int): List[String] = text.toList + + override def length = 0 + + override val _shortName = "" + + override def names = Nil +} + +class NoMoreOptions[T](val names: Seq[String]) extends CliOption[T, NoMoreOptions[T]] { + this.repeatable() + + override def length = 1 + + override val _shortName = names.head +} + +class UnknownParamOption[T] extends CliOption[T, UnknownParamOption[T]] { + this._hidden = true + + override def length = 0 + + val names: Seq[String] = Nil + private var _action: ((String, T) => T) = (_, x) => x + + def action(a: ((String, T) => T)): UnknownParamOption[T] = { + _action = a + this + } + + def encounter(value: String, t: T): T = { + encounter() + _action(value, t) + } + + override private[cli] val _shortName = "" +} + +class FlagOption[T](val names: Seq[String]) extends CliOption[T, FlagOption[T]] { + override def length = 1 + + private var _action: (T => T) = x => x + + def action(a: (T => T)): FlagOption[T] = { + _action = a + this + } + + def encounter(t: T): T = { + encounter() + _action(t) + } + + override val _shortName = names.head +} + +class BooleanOption[T](val trueName: String, val falseName: String) extends CliOption[T, BooleanOption[T]] { + override def length = 1 + + private var _action: ((T,Boolean) => T) = (x,_) => x + + def action(a: ((T,Boolean) => T)): BooleanOption[T] = { + _action = a + this + } + + def encounter(asName: String, t: T): T = { + encounter() + if (asName == trueName) { + return _action(t, true) + } + if (asName == falseName) { + return _action(t, false) + } + t + } + + override val _shortName = names.head + + override protected def firstLine: String = trueName + " | " + falseName + + override def names = Seq(trueName, falseName) +} + +class ParamOption[T](val names: Seq[String]) extends CliOption[T, ParamOption[T]] { + + override protected def firstLine: String = names.mkString(" | ") + " " + _paramPlaceholder + + override def length = 2 + + private var _action: ((String, T) => T) = (_, x) => x + private var _paramPlaceholder: String = "" + + def placeholder(p: String): ParamOption[T] = { + _paramPlaceholder = p + this + } + + def action(a: ((String, T) => T)): ParamOption[T] = { + _action = a + this + } + + def encounter(value: String, t: T): T = { + encounter() + _action(value, t) + } + + override val _shortName = names.head +} \ No newline at end of file diff --git a/src/main/scala/millfork/cli/CliParser.scala b/src/main/scala/millfork/cli/CliParser.scala new file mode 100644 index 00000000..25c7563f --- /dev/null +++ b/src/main/scala/millfork/cli/CliParser.scala @@ -0,0 +1,81 @@ +package millfork.cli + +import fastparse.core.Parsed.Failure + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ +class CliParser[T] { + + private val options = mutable.ArrayBuffer[CliOption[T, _]]() + private val mapFlags = mutable.Map[String, CliOption[T, _]]() + private val mapOptions = mutable.Map[String, CliOption[T, _]]() + private val _default = new UnknownParamOption[T]().action((p, _) => throw new IllegalArgumentException(s"Unknown option $p")) + private var _status: Option[CliStatus.Value] = None + options += _default + + private def add[O <: CliOption[T, _]](o: O) = { + options += o + o.length match { + case 1 => + o.names.foreach { n => mapFlags(n) = o } + case 2 => + o.names.foreach { n => mapOptions(n) = o } + case _ => () + } + o + } + + def parse(context: T, args: List[String]): (CliStatus.Value, T) = { + val t = parseInner(context, args) + _status.getOrElse(if (options.forall(_.validate())) CliStatus.Ok else CliStatus.Failed) -> t + } + + def assumeStatus(s: CliStatus.Value): Unit = { + _status = Some(s) + } + + private def parseInner(context: T, args: List[String]): T = { + args match { + case k :: v :: xs if mapOptions.contains(k) => + mapOptions(k) match { + case p: ParamOption[T] => parseInner(p.encounter(v, context), xs) + case _ => ??? + } + case k :: xs if mapFlags.contains(k) => + mapFlags(k) match { + case p: FlagOption[T] => + parseInner(p.encounter(context), xs) + case p: BooleanOption[T] => + parseInner(p.encounter(k, context), xs) + case p: NoMoreOptions[T] => + p.encounter() + xs.foldLeft(context)((t, x) => _default.encounter(x, t)) + case _ => ??? + } + case x :: xs => + parseInner(_default.encounter(x, context), xs) + case Nil => context + } + } + + + def fluff(text: String*): Unit = add(new Fluff[T](text)) + + def flag(names: String*): FlagOption[T] = add(new FlagOption[T](names)) + + def boolean(trueName: String, falseName: String): BooleanOption[T] = add(new BooleanOption[T](trueName, falseName)) + + def endOfFlags(names: String*): NoMoreOptions[T] = add(new NoMoreOptions[T](names)) + + def default: UnknownParamOption[T] = _default + + def printHelp(firstTab: Int): List[String] = { + options.filterNot(_._hidden).toList.flatMap(_.toStrings(firstTab)) + } + + def parameter(names: String*): ParamOption[T] = add(new ParamOption[T](names)) + +} diff --git a/src/main/scala/millfork/cli/CliStatus.scala b/src/main/scala/millfork/cli/CliStatus.scala new file mode 100644 index 00000000..800d5782 --- /dev/null +++ b/src/main/scala/millfork/cli/CliStatus.scala @@ -0,0 +1,8 @@ +package millfork.cli + +/** + * @author Karol Stasiak + */ +object CliStatus extends Enumeration { + val Ok, Failed, Quit = Value +} diff --git a/src/main/scala/millfork/compiler/BuiltIns.scala b/src/main/scala/millfork/compiler/BuiltIns.scala new file mode 100644 index 00000000..452962d6 --- /dev/null +++ b/src/main/scala/millfork/compiler/BuiltIns.scala @@ -0,0 +1,832 @@ +package millfork.compiler + +import millfork.{CompilationFlag, CompilationOptions} +import millfork.assembly._ +import millfork.env._ +import millfork.node._ +import millfork.assembly.Opcode._ +import millfork.assembly.AddrMode._ +import millfork.error.ErrorReporting + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import scala.reflect.macros.blackbox + + +object ComparisonType extends Enumeration { + val Equal, NotEqual, + LessUnsigned, LessSigned, + GreaterUnsigned, GreaterSigned, + LessOrEqualUnsigned, LessOrEqualSigned, + GreaterOrEqualUnsigned, GreaterOrEqualSigned = Value + + def flip(x: ComparisonType.Value): ComparisonType.Value = x match { + case LessUnsigned => GreaterUnsigned + case GreaterUnsigned => LessUnsigned + case LessOrEqualUnsigned => GreaterOrEqualUnsigned + case GreaterOrEqualUnsigned => LessOrEqualUnsigned + case LessSigned => GreaterSigned + case GreaterSigned => LessSigned + case LessOrEqualSigned => GreaterOrEqualSigned + case GreaterOrEqualSigned => LessOrEqualSigned + case _ => x + } + + def negate(x: ComparisonType.Value): ComparisonType.Value = x match { + case LessUnsigned => GreaterOrEqualUnsigned + case GreaterUnsigned => LessOrEqualUnsigned + case LessOrEqualUnsigned => GreaterUnsigned + case GreaterOrEqualUnsigned => LessUnsigned + case LessSigned => GreaterOrEqualSigned + case GreaterSigned => LessOrEqualSigned + case LessOrEqualSigned => GreaterSigned + case GreaterOrEqualSigned => LessSigned + case Equal => NotEqual + case NotEqual => Equal + } +} + +/** + * @author Karol Stasiak + */ +object BuiltIns { + + object IndexChoice extends Enumeration { + val RequireX, PreferX, PreferY = Value + } + + def wrapInSedCldIfNeeded(decimal: Boolean, code: List[AssemblyLine]): List[AssemblyLine] = { + if (decimal) { + AssemblyLine.implied(SED) :: (code :+ AssemblyLine.implied(CLD)) + } else { + code + } + } + + def staTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == STA) x.copy(opcode = op) else x) + + def ldTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == LDA || x.opcode == LDX || x.opcode == LDY) x.copy(opcode = op) else x) + + def simpleOperation(opcode: Opcode.Value, ctx: CompilationContext, source: Expression, indexChoice: IndexChoice.Value, preserveA: Boolean, commutative: Boolean): List[AssemblyLine] = { + val env = ctx.env + val parts: (List[AssemblyLine], List[AssemblyLine]) = env.eval(source).fold { + val b = env.get[Type]("byte") + source match { + case VariableExpression(name) => + val v = env.get[Variable](name) + if (v.typ.size > 1) { + ErrorReporting.error(s"Variable `$name` is too big for a built-in operation", source.position) + return Nil + } + Nil -> AssemblyLine.variable(ctx, opcode, v) + case IndexedExpression(arrayName, index) => + indexChoice match { + case IndexChoice.RequireX | IndexChoice.PreferX => + val array = env.getArrayOrPointer(arrayName) + val calculateIndex = MlCompiler.compile(ctx, index, Some(b -> RegisterVariable(Register.X, b)), NoBranching) + val baseAddress = array match { + case c: ConstantThing => c.value + case a: MlArray => a.toAddress + } + calculateIndex -> List(AssemblyLine.absoluteX(opcode, baseAddress)) + case IndexChoice.PreferY => + val array = env.getArrayOrPointer(arrayName) + val calculateIndex = MlCompiler.compile(ctx, index, Some(b -> RegisterVariable(Register.Y, b)), NoBranching) + val baseAddress = array match { + case c: ConstantThing => c.value + case a: MlArray => a.toAddress + } + calculateIndex -> List(AssemblyLine.absoluteY(opcode, baseAddress)) + } + case f: FunctionCallExpression if commutative => + // TODO: is it ok? + return List(AssemblyLine.implied(PHA)) ++ MlCompiler.compile(ctx.addStack(1), f, Some(b -> RegisterVariable(Register.A, b)), NoBranching) ++ List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(opcode, 0x101), + AssemblyLine.implied(INX), + AssemblyLine.implied(TXS)) + case _ => + ErrorReporting.error("Right-hand-side expression is too complex", source.position) + return Nil + } + } { + const => + if (const.requiredSize > 1) { + ErrorReporting.error("Constant too big for a built-in operation", source.position) + } + Nil -> List(AssemblyLine.immediate(opcode, const)) + } + val preparations = parts._1 + val finalRead = parts._2 + if (preserveA && AssemblyLine.treatment(preparations, State.A) != Treatment.Unchanged) { + AssemblyLine.implied(PHA) :: (preparations ++ (AssemblyLine.implied(PLA) :: finalRead)) + } else { + preparations ++ finalRead + } + } + + def insertBeforeLast(item: AssemblyLine, list: List[AssemblyLine]): List[AssemblyLine] = list match { + case Nil => Nil + case last :: dex :: txs :: Nil if dex.opcode == DEX && txs.opcode == TXS => item :: last :: dex :: txs :: Nil + case last :: inx :: txs :: Nil if inx.opcode == INX && txs.opcode == TXS => item :: last :: inx :: txs :: Nil + case last :: Nil => item :: last :: Nil + case first :: rest => first :: insertBeforeLast(item, rest) + } + + def compileAddition(ctx: CompilationContext, params: List[(Boolean, Expression)], decimal: Boolean): List[AssemblyLine] = { + if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode)) { + ErrorReporting.warn("Unsupported decimal operation", ctx.options, params.head._2.position) + } + // if (params.isEmpty) { + // return Nil + // } + val env = ctx.env + val b = env.get[Type]("byte") + val sortedParams = params.sortBy { case (subtract, expr) => + val constPart = env.eval(expr) match { + case Some(NumericConstant(_, _)) => "Z" + case Some(_) => "Y" + case None => expr match { + case VariableExpression(_) => "V" + case IndexedExpression(_, LiteralExpression(_, _)) => "K" + case IndexedExpression(_, VariableExpression(_)) => "J" + case IndexedExpression(_, _) => "I" + case _ => "A" + } + } + val subtractPart = if (subtract) "X" else "P" + constPart + subtractPart + } + // TODO: merge constants + val normalizedParams = sortedParams + + val h = normalizedParams.head + val firstParamCompiled = MlCompiler.compile(ctx, h._2, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + val firstParamSignCompiled = if (h._1) { + List(AssemblyLine.immediate(EOR, 0xff), AssemblyLine.implied(SEC), AssemblyLine.immediate(ADC, 0)) + } else { + Nil + } + + val remainingParamsCompiled = normalizedParams.tail.flatMap { p => + if (p._1) { + insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = false)) + } else { + insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = true)) + } + } + + wrapInSedCldIfNeeded(decimal, firstParamCompiled ++ firstParamSignCompiled ++ remainingParamsCompiled) + } + + def compileBitOps(opcode: Opcode.Value, ctx: CompilationContext, params: List[Expression]): List[AssemblyLine] = { + val b = ctx.env.get[Type]("byte") + + val sortedParams = params.sortBy { expr => + ctx.env.eval(expr) match { + case Some(NumericConstant(_, _)) => "Z" + case Some(_) => "Y" + case None => expr match { + case VariableExpression(_) => "V" + case IndexedExpression(_, LiteralExpression(_, _)) => "K" + case IndexedExpression(_, VariableExpression(_)) => "J" + case IndexedExpression(_, _) => "I" + case _ => "A" + } + } + } + + val h = sortedParams.head + val firstParamCompiled = MlCompiler.compile(ctx, h, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + + val remainingParamsCompiled = sortedParams.tail.flatMap { p => + simpleOperation(opcode, ctx, p, IndexChoice.PreferY, preserveA = true, commutative = true) + } + + firstParamCompiled ++ remainingParamsCompiled + } + + def compileShiftOps(opcode: Opcode.Value, ctx: CompilationContext, l: Expression, r: Expression): List[AssemblyLine] = { + val b = ctx.env.get[Type]("byte") + val firstParamCompiled = MlCompiler.compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + ctx.env.eval(r) match { + case Some(NumericConstant(0, _)) => + Nil + case Some(NumericConstant(v, _)) if v > 0 => + firstParamCompiled ++ List.fill(v.toInt)(AssemblyLine.implied(opcode)) + case _ => + ErrorReporting.error("Cannot shift by a non-constant amount") + Nil + } + } + + def compileNonetOps(ctx: CompilationContext, lhs: LhsExpression, rhs: Expression): List[AssemblyLine] = { + val env = ctx.env + val b = env.get[Type]("byte") + val (ldaHi, ldaLo) = lhs match { + case v: VariableExpression => + val variable = env.get[Variable](v.name) + AssemblyLine.variable(ctx, LDA, variable, 1) -> AssemblyLine.variable(ctx, LDA, variable, 0) + case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) => + AssemblyLine.variable(ctx, LDA, env.get[Variable](h.name), 0) -> AssemblyLine.variable(ctx, LDA, env.get[Variable](l.name), 0) + case _ => + ??? + } + env.eval(rhs) match { + case Some(NumericConstant(0, _)) => + Nil + case Some(NumericConstant(shift, _)) if shift > 0 => + if (ctx.options.flag(CompilationFlag.RorWarning)) + ErrorReporting.warn("ROR instruction generated", ctx.options, lhs.position) + ldaHi ++ List(AssemblyLine.implied(ROR)) ++ ldaLo ++ List(AssemblyLine.implied(ROR)) ++ List.fill(shift.toInt - 1)(AssemblyLine.implied(LSR)) + case _ => + ErrorReporting.error("Non-constant shift amount", rhs.position) // TODO + Nil + } + } + + def compileInPlaceByteShiftOps(opcode: Opcode.Value, ctx: CompilationContext, lhs: LhsExpression, rhs: Expression): List[AssemblyLine] = { + val env = ctx.env + val b = env.get[Type]("byte") + val firstParamCompiled = MlCompiler.compile(ctx, lhs, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + env.eval(rhs) match { + case Some(NumericConstant(0, _)) => + Nil + case Some(NumericConstant(v, _)) if v > 0 => + val result = simpleOperation(opcode, ctx, lhs, IndexChoice.RequireX, preserveA = true, commutative = false) + result ++ List.fill(v.toInt - 1)(result.last) + case _ => + ErrorReporting.error("Non-constant shift amount", rhs.position) // TODO + Nil + } + } + + def compileInPlaceWordOrLongShiftOps(ctx: CompilationContext, lhs: LhsExpression, rhs: Expression, aslRatherThanLsr: Boolean): List[AssemblyLine] = { + val env = ctx.env + val b = env.get[Type]("byte") + val targetBytes = lhs match { + case v: VariableExpression => + val variable = env.get[Variable](v.name) + List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) } + case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) => + List( + AssemblyLine.variable(ctx, STA, env.get[Variable](l.name)), + AssemblyLine.variable(ctx, STA, env.get[Variable](h.name))) + } + val lo = targetBytes.head + val hi = targetBytes.last + env.eval(rhs) match { + case Some(NumericConstant(0, _)) => + Nil + case Some(NumericConstant(shift, _)) if shift > 0 => + List.fill(shift.toInt)(if (aslRatherThanLsr) { + staTo(ASL, lo) ++ targetBytes.tail.flatMap { b => staTo(ROL, b) } + } else { + if (ctx.options.flag(CompilationFlag.RorWarning)) + ErrorReporting.warn("ROR instruction generated", ctx.options, lhs.position) + staTo(LSR, hi) ++ targetBytes.reverse.tail.flatMap { b => staTo(ROR, b) } + }).flatten + case _ => + ErrorReporting.error("Non-constant shift amount", rhs.position) // TODO + Nil + } + } + + def compileByteComparison(ctx: CompilationContext, compType: ComparisonType.Value, lhs: Expression, rhs: Expression, branches: BranchSpec): List[AssemblyLine] = { + val env = ctx.env + val b = env.get[Type]("byte") + val firstParamCompiled = MlCompiler.compile(ctx, lhs, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + env.eval(rhs) match { + case Some(NumericConstant(0, _)) => + compType match { + case ComparisonType.LessUnsigned => + ErrorReporting.warn("Unsigned < 0 is always false", ctx.options, lhs.position) + case ComparisonType.LessOrEqualUnsigned => + if (ctx.options.flag(CompilationFlag.ExtraComparisonWarnings)) + ErrorReporting.warn("Unsigned <= 0 means the same as unsigned == 0", ctx.options, lhs.position) + case ComparisonType.GreaterUnsigned => + if (ctx.options.flag(CompilationFlag.ExtraComparisonWarnings)) + ErrorReporting.warn("Unsigned > 0 means the same as unsigned != 0", ctx.options, lhs.position) + case ComparisonType.GreaterOrEqualUnsigned => + ErrorReporting.warn("Unsigned >= 0 is always true", ctx.options, lhs.position) + case _ => + } + case Some(NumericConstant(1, _)) => + if (ctx.options.flag(CompilationFlag.ExtraComparisonWarnings)) { + compType match { + case ComparisonType.LessUnsigned => + ErrorReporting.warn("Unsigned < 1 means the same as unsigned == 0", ctx.options, lhs.position) + case ComparisonType.GreaterOrEqualUnsigned => + ErrorReporting.warn("Unsigned >= 1 means the same as unsigned != 0", ctx.options, lhs.position) + case _ => + } + } + case _ => + } + val secondParamCompiledUnoptimized = simpleOperation(CMP, ctx, rhs, IndexChoice.PreferY, preserveA = true, commutative = false) + val secondParamCompiled = compType match { + case ComparisonType.Equal | ComparisonType.NotEqual | ComparisonType.LessSigned | ComparisonType.GreaterOrEqualSigned => + secondParamCompiledUnoptimized match { + case List(AssemblyLine(CMP, Immediate, NumericConstant(0, _), true)) => + if (OpcodeClasses.ChangesAAlways(firstParamCompiled.last.opcode)) { + Nil + } else { + secondParamCompiledUnoptimized + } + case _ => secondParamCompiledUnoptimized + } + case _ => secondParamCompiledUnoptimized + } + val (effectiveComparisonType, label) = branches match { + case NoBranching => return Nil + case BranchIfTrue(l) => compType -> l + case BranchIfFalse(l) => ComparisonType.negate(compType) -> l + } + val branchingCompiled = effectiveComparisonType match { + case ComparisonType.Equal => + List(AssemblyLine.relative(BEQ, Label(label))) + case ComparisonType.NotEqual => + List(AssemblyLine.relative(BNE, Label(label))) + + case ComparisonType.LessUnsigned => + List(AssemblyLine.relative(BCC, Label(label))) + case ComparisonType.GreaterOrEqualUnsigned => + List(AssemblyLine.relative(BCS, Label(label))) + case ComparisonType.LessOrEqualUnsigned => + List(AssemblyLine.relative(BCC, Label(label)), AssemblyLine.relative(BEQ, Label(label))) + case ComparisonType.GreaterUnsigned => + val x = MlCompiler.nextLabel("co") + List( + AssemblyLine.relative(BEQ, x), + AssemblyLine.relative(BCS, Label(label)), + AssemblyLine.label(x)) + + case ComparisonType.LessSigned => + List(AssemblyLine.relative(BMI, Label(label))) + case ComparisonType.GreaterOrEqualSigned => + List(AssemblyLine.relative(BPL, Label(label))) + case ComparisonType.LessOrEqualSigned => + List(AssemblyLine.relative(BMI, Label(label)), AssemblyLine.relative(BEQ, Label(label))) + case ComparisonType.GreaterSigned => + val x = MlCompiler.nextLabel("co") + List( + AssemblyLine.relative(BEQ, x), + AssemblyLine.relative(BPL, Label(label)), + AssemblyLine.label(x)) + } + firstParamCompiled ++ secondParamCompiled ++ branchingCompiled + + } + + def compileWordComparison(ctx: CompilationContext, compType: ComparisonType.Value, lhs: Expression, rhs: Expression, branches: BranchSpec): List[AssemblyLine] = { + val env = ctx.env + // TODO: comparing stack variables + val b = env.get[Type]("byte") + val w = env.get[Type]("word") + + val (effectiveComparisonType, x) = branches match { + case NoBranching => return Nil + case BranchIfTrue(label) => compType -> label + case BranchIfFalse(label) => ComparisonType.negate(compType) -> label + } + val (lh, ll, rh, rl, ram) = (lhs, env.eval(lhs), rhs, env.eval(rhs)) match { + case (_, Some(NumericConstant(lc, _)), _, Some(NumericConstant(rc, _))) => + return if (effectiveComparisonType match { + // TODO: those masks are probably wrong + case ComparisonType.Equal => + (lc & 0xffff) == (rc & 0xffff) // ?? + case ComparisonType.NotEqual => + (lc & 0xffff) != (rc & 0xffff) // ?? + + case ComparisonType.LessOrEqualUnsigned => + (lc & 0xffff) <= (rc & 0xffff) + case ComparisonType.GreaterOrEqualUnsigned => + (lc & 0xffff) >= (rc & 0xffff) + case ComparisonType.GreaterUnsigned => + (lc & 0xffff) > (rc & 0xffff) + case ComparisonType.LessUnsigned => + (lc & 0xffff) < (rc & 0xffff) + + case ComparisonType.LessOrEqualSigned => + lc.toShort <= rc.toShort + case ComparisonType.GreaterOrEqualSigned => + lc.toShort >= rc.toShort + case ComparisonType.GreaterSigned => + lc.toShort > rc.toShort + case ComparisonType.LessSigned => + lc.toShort < rc.toShort + }) List(AssemblyLine.absolute(JMP, Label(x))) else Nil + case (_, Some(lc), _, Some(rc)) => + // TODO: comparing late-bound constants + ??? + case (_, Some(lc), rv: VariableInMemory, None) => + return compileWordComparison(ctx, ComparisonType.flip(compType), rhs, lhs, branches) + case (v: VariableExpression, None, _, Some(rc)) => + // TODO: stack variables + (env.get[VariableInMemory](v.name + ".hi").toAddress, + env.get[VariableInMemory](v.name + ".lo").toAddress, + rc.hiByte, + rc.loByte, + Immediate) + case (lv: VariableExpression, None, rv: VariableExpression, None) => + // TODO: stack variables + (env.get[VariableInMemory](lv.name + ".hi").toAddress, + env.get[VariableInMemory](lv.name + ".lo").toAddress, + env.get[VariableInMemory](rv.name + ".hi").toAddress, + env.get[VariableInMemory](rv.name + ".lo").toAddress, Absolute) + } + effectiveComparisonType match { + case ComparisonType.Equal => + val innerLabel = MlCompiler.nextLabel("cp") + List(AssemblyLine.absolute(LDA, ll), + AssemblyLine(CMP, ram, rl), + AssemblyLine.relative(BNE, innerLabel), + AssemblyLine.absolute(LDA, lh), + AssemblyLine(CMP, ram, rh), + AssemblyLine.relative(BEQ, Label(x)), + AssemblyLine.label(innerLabel)) + + case ComparisonType.NotEqual => + List(AssemblyLine.absolute(LDA, ll), + AssemblyLine(CMP, ram, rl), + AssemblyLine.relative(BNE, Label(x)), + AssemblyLine.absolute(LDA, lh), + AssemblyLine(CMP, ram, rh), + AssemblyLine.relative(BNE, Label(x))) + + case ComparisonType.LessUnsigned => + val innerLabel = MlCompiler.nextLabel("cp") + List(AssemblyLine.absolute(LDA, lh), + AssemblyLine(CMP, ram, rh), + AssemblyLine.relative(BCC, Label(x)), + AssemblyLine.relative(BNE, innerLabel), + AssemblyLine.absolute(LDA, ll), + AssemblyLine(CMP, ram, rl), + AssemblyLine.relative(BCC, Label(x)), + AssemblyLine.label(innerLabel)) + + case ComparisonType.LessOrEqualUnsigned => + val innerLabel = MlCompiler.nextLabel("cp") + List(AssemblyLine(LDA, ram, rh), + AssemblyLine.absolute(CMP, lh), + AssemblyLine.relative(BCC, innerLabel), + AssemblyLine.relative(BNE, x), + AssemblyLine(LDA, ram, rl), + AssemblyLine.absolute(CMP, ll), + AssemblyLine.relative(BCS, x), + AssemblyLine.label(innerLabel)) + + case ComparisonType.GreaterUnsigned => + val innerLabel = MlCompiler.nextLabel("cp") + List(AssemblyLine(LDA, ram, rh), + AssemblyLine.absolute(CMP, lh), + AssemblyLine.relative(BCC, Label(x)), + AssemblyLine.relative(BNE, innerLabel), + AssemblyLine(LDA, ram, rl), + AssemblyLine.absolute(CMP, ll), + AssemblyLine.relative(BCC, Label(x)), + AssemblyLine.label(innerLabel)) + + case ComparisonType.GreaterOrEqualUnsigned => + val innerLabel = MlCompiler.nextLabel("cp") + List(AssemblyLine.absolute(LDA, lh), + AssemblyLine(CMP, ram, rh), + AssemblyLine.relative(BCC, innerLabel), + AssemblyLine.relative(BNE, x), + AssemblyLine.absolute(LDA, ll), + AssemblyLine(CMP, ram, rl), + AssemblyLine.relative(BCS, x), + AssemblyLine.label(innerLabel)) + + case _ => ??? + // TODO: signed word comparisons + } + } + + def compileInPlaceByteMultiplication(ctx: CompilationContext, v: LhsExpression, addend: Expression): List[AssemblyLine] = { + val b = ctx.env.get[Type]("byte") + ctx.env.eval(addend) match { + case Some(NumericConstant(0, _)) => + AssemblyLine.immediate(LDA, 0) :: MlCompiler.compileByteStorage(ctx, Register.A, v) + case Some(NumericConstant(1, _)) => + Nil + case Some(NumericConstant(x, _)) => + compileByteMultiplication(ctx, v, x.toInt) ++ MlCompiler.compileByteStorage(ctx, Register.A, v) + case _ => + ErrorReporting.error("Multiplying by not a constant not supported", v.position) + Nil + } + } + + def compileByteMultiplication(ctx: CompilationContext, v: Expression, c: Int): List[AssemblyLine] = { + val result = ListBuffer[AssemblyLine]() + // TODO: optimise + val addingCode = simpleOperation(ADC, ctx, v, IndexChoice.PreferY, preserveA = false, commutative = false) + val adc = addingCode.last + val indexing = addingCode.init + result ++= indexing + result += AssemblyLine.immediate(LDA, 0) + val mult = c & 0xff + var mask = 128 + var empty = true + while (mask > 0) { + if (!empty) { + result += AssemblyLine.implied(ASL) + } + if ((mult & mask) != 0) { + result ++= List(AssemblyLine.implied(CLC), adc) + empty = false + } + + mask >>>= 1 + } + result.toList + } + + def compileByteMultiplication(ctx: CompilationContext, params: List[Expression]): List[AssemblyLine] = { + val (constants, variables) = params.map(p => p -> ctx.env.eval(p)).partition(_._2.exists(_.isInstanceOf[NumericConstant])) + val constant = constants.map(_._2.get.asInstanceOf[NumericConstant].value).foldLeft(1L)(_ * _).toInt + variables.length match { + case 0 => List(AssemblyLine.immediate(LDA, constant & 0xff)) + case 1 =>compileByteMultiplication(ctx, variables.head._1, constant) + case 2 => + ErrorReporting.error("Multiplying by not a constant not supported", params.head.position) + Nil + } + } + + def compileInPlaceByteAddition(ctx: CompilationContext, v: LhsExpression, addend: Expression, subtract: Boolean, decimal: Boolean): List[AssemblyLine] = { + if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode)) { + ErrorReporting.warn("Unsupported decimal operation", ctx.options, v.position) + } + val env = ctx.env + val b = env.get[Type]("byte") + env.eval(addend) match { + case Some(NumericConstant(0, _)) => Nil + case Some(NumericConstant(1, _)) if !decimal => if (subtract) { + simpleOperation(DEC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true) + } else { + simpleOperation(INC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true) + } + // TODO: compile +=2 to two INCs + case Some(NumericConstant(-1, _)) if !decimal => if (subtract) { + simpleOperation(INC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true) + } else { + simpleOperation(DEC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true) + } + case _ => + val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + val modifyLhs = if (subtract) { + insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = false)) + } else { + insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = true)) + } + val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v) + wrapInSedCldIfNeeded(decimal, loadLhs ++ modifyLhs ++ storeLhs) + } + } + + def compileInPlaceWordOrLongAddition(ctx: CompilationContext, lhs: LhsExpression, addend: Expression, subtract: Boolean, decimal: Boolean): List[AssemblyLine] = { + if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode)) { + ErrorReporting.warn("Unsupported decimal operation", ctx.options, lhs.position) + } + val env = ctx.env + val b = env.get[Type]("byte") + val w = env.get[Type]("word") + val targetBytes: List[List[AssemblyLine]] = lhs match { + case v: VariableExpression => + val variable = env.get[Variable](v.name) + List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) } + case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) => + val lv = env.get[Variable](l.name) + val hv = env.get[Variable](h.name) + List( + AssemblyLine.variable(ctx, STA, lv), + AssemblyLine.variable(ctx, STA, hv)) + } + val lhsIsStack = targetBytes.head.head.opcode == TSX + val targetSize = targetBytes.size + val addendType = MlCompiler.getExpressionType(ctx, addend) + var addendSize = addendType.size + + def isRhsComplex(xs: List[AssemblyLine]): Boolean = xs match { + case AssemblyLine(LDA, _, _, _) :: Nil => false + case AssemblyLine(LDA, _, _, _) :: AssemblyLine(LDX, _, _, _) :: Nil => false + case _ => true + } + + def isRhsStack(xs: List[AssemblyLine]): Boolean = xs.exists(_.opcode == TSX) + + val (calculateRhs, addendByteRead0): (List[AssemblyLine], List[List[AssemblyLine]]) = env.eval(addend) match { + case Some(constant) => + addendSize = targetSize + Nil -> List.tabulate(targetSize)(i => List(AssemblyLine.immediate(LDA, constant.subbyte(i)))) + case None => + addendSize match { + case 1 => + val base = MlCompiler.compile(ctx, addend, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + if (subtract) { + if (isRhsComplex(base)) { + if (isRhsStack(base)) { + ErrorReporting.warn("Subtracting a stack-based value", ctx.options) + ??? + } + (base ++ List(AssemblyLine.implied(PHA))) -> List(List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, 0x101))) + } else { + Nil -> base.map(_ :: Nil) + } + } else { + base -> List(Nil) + } + case 2 => + val base = MlCompiler.compile(ctx, addend, Some(w -> RegisterVariable(Register.AX, w)), NoBranching) + if (isRhsStack(base)) { + val fixedBase = MlCompiler.compile(ctx, addend, Some(w -> RegisterVariable(Register.AY, w)), NoBranching) + if (subtract) { + ErrorReporting.warn("Subtracting a stack-based value", ctx.options) + if (isRhsComplex(base)) { + ??? + } else { + Nil -> fixedBase + ??? + } + } else { + fixedBase -> List(Nil, List(AssemblyLine.implied(TYA))) + } + } else { + if (subtract) { + if (isRhsComplex(base)) { + (base ++ List( + AssemblyLine.implied(PHA), + AssemblyLine.implied(TXA), + AssemblyLine.implied(PHA)) + ) -> List( + List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, 0x102)), + List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, 0x101))) + } else { + Nil -> base.map(_ :: Nil) + } + } else { + if (lhsIsStack) { + val fixedBase = MlCompiler.compile(ctx, addend, Some(w -> RegisterVariable(Register.AY, w)), NoBranching) + fixedBase -> List(Nil, List(AssemblyLine.implied(TYA))) + } else { + base -> List(Nil, List(AssemblyLine.implied(TXA))) + } + } + } + case _ => Nil -> (addend match { + case vv: VariableExpression => + val source = env.get[Variable](vv.name) + List.tabulate(addendSize)(i => AssemblyLine.variable(ctx, LDA, source, i)) + }) + } + } + val addendByteRead = addendByteRead0 ++ List.fill((targetSize - addendByteRead0.size) max 0)(List(AssemblyLine.immediate(LDA, 0))) + val buffer = mutable.ListBuffer[AssemblyLine]() + buffer ++= calculateRhs + buffer += AssemblyLine.implied(if (subtract) SEC else CLC) + val extendMultipleBytes = targetSize > addendSize + 1 + val extendAtLeastOneByte = targetSize > addendSize + for (i <- 0 until targetSize) { + if (subtract) { + if (addendSize < targetSize && addendType.isSigned) { + // TODO: sign extension + ??? + } + buffer ++= staTo(LDA, targetBytes(i)) + buffer ++= ldTo(SBC, addendByteRead(i)) + buffer ++= targetBytes(i) + } else { + if (i >= addendSize) { + if (addendType.isSigned) { + val label = MlCompiler.nextLabel("sx") + buffer += AssemblyLine.implied(TXA) + if (i == addendSize) { + buffer += AssemblyLine.immediate(ORA, 0x7f) + buffer += AssemblyLine.relative(BMI, label) + buffer += AssemblyLine.immediate(LDA, 0) + buffer += AssemblyLine.label(label) + if (extendMultipleBytes) buffer += AssemblyLine.implied(TAX) + } + } else { + buffer += AssemblyLine.immediate(LDA, 0) + } + } else { + buffer ++= addendByteRead(i) + if (addendType.isSigned && i == addendSize - 1 && extendAtLeastOneByte) { + buffer += AssemblyLine.implied(TAX) + } + } + buffer ++= staTo(ADC, targetBytes(i)) + buffer ++= targetBytes(i) + } + } + for (i <- 0 until calculateRhs.count(a => a.opcode == PHA) - calculateRhs.count(a => a.opcode == PLA)) { + buffer += AssemblyLine.implied(PLA) + } + wrapInSedCldIfNeeded(decimal, buffer.toList) + } + + def compileInPlaceByteBitOp(ctx: CompilationContext, v: LhsExpression, param: Expression, operation: Opcode.Value): List[AssemblyLine] = { + val env = ctx.env + val b = env.get[Type]("byte") + (operation, env.eval(param)) match { + case (EOR, Some(NumericConstant(0, _))) + | (ORA, Some(NumericConstant(0, _))) + | (AND, Some(NumericConstant(0xff, _))) + | (AND, Some(NumericConstant(-1, _))) => + Nil + case _ => + val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + val modifyLhs = simpleOperation(operation, ctx, param, IndexChoice.PreferY, preserveA = true, commutative = true) + val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v) + loadLhs ++ modifyLhs ++ storeLhs + } + } + + + def compileInPlaceWordOrLongBitOp(ctx: CompilationContext, lhs: LhsExpression, param: Expression, operation: Opcode.Value): List[AssemblyLine] = { + val env = ctx.env + val b = env.get[Type]("byte") + val w = env.get[Type]("word") + val targetBytes: List[List[AssemblyLine]] = lhs match { + case v: VariableExpression => + val variable = env.get[Variable](v.name) + List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) } + case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) => + val lv = env.get[Variable](l.name) + val hv = env.get[Variable](h.name) + List( + AssemblyLine.variable(ctx, STA, lv), + AssemblyLine.variable(ctx, STA, hv)) + case _ => + ??? + } + val lo = targetBytes.head + val targetSize = targetBytes.size + val paramType = MlCompiler.getExpressionType(ctx, param) + var paramSize = paramType.size + val extendMultipleBytes = targetSize > paramSize + 1 + val extendAtLeastOneByte = targetSize > paramSize + val (calculateRhs, addendByteRead) = env.eval(param) match { + case Some(constant) => + paramSize = targetSize + Nil -> List.tabulate(targetSize)(i => List(AssemblyLine.immediate(LDA, constant.subbyte(i)))) + case None => + paramSize match { + case 1 => + val base = MlCompiler.compile(ctx, param, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + base -> List(Nil) + case 2 => + val base = MlCompiler.compile(ctx, param, Some(w -> RegisterVariable(Register.AX, w)), NoBranching) + base -> List(Nil, List(AssemblyLine.implied(TXA))) + case _ => Nil -> (param match { + case vv: VariableExpression => + val source = env.get[Variable](vv.name) + List.tabulate(paramSize)(i => AssemblyLine.variable(ctx, LDA, source, i)) + }) + } + } + val AllOnes = (1L << (8 * targetSize)) - 1 + (operation, env.eval(param)) match { + case (EOR, Some(NumericConstant(0, _))) + | (ORA, Some(NumericConstant(0, _))) + | (AND, Some(NumericConstant(AllOnes, _))) => + Nil + case _ => + val buffer = mutable.ListBuffer[AssemblyLine]() + buffer ++= calculateRhs + for (i <- 0 until targetSize) { + if (i < paramSize) { + buffer ++= addendByteRead(i) + if (paramType.isSigned && i == paramSize - 1 && extendAtLeastOneByte) { + buffer += AssemblyLine.implied(TAX) + } + } else { + if (paramType.isSigned) { + val label = MlCompiler.nextLabel("sx") + buffer += AssemblyLine.implied(TXA) + if (i == paramSize) { + buffer += AssemblyLine.immediate(ORA, 0x7f) + buffer += AssemblyLine.relative(BMI, label) + buffer += AssemblyLine.immediate(LDA, 0) + buffer += AssemblyLine.label(label) + if (extendMultipleBytes) buffer += AssemblyLine.implied(TAX) + } + } else { + buffer += AssemblyLine.immediate(LDA, 0) + } + } + buffer ++= staTo(operation, targetBytes(i)) + buffer ++= targetBytes(i) + } + for (i <- 0 until calculateRhs.count(a => a.opcode == PHA) - calculateRhs.count(a => a.opcode == PLA)) { + buffer += AssemblyLine.implied(PLA) + } + buffer.toList + } + } + + +} diff --git a/src/main/scala/millfork/compiler/CompilationContext.scala b/src/main/scala/millfork/compiler/CompilationContext.scala new file mode 100644 index 00000000..a321eb54 --- /dev/null +++ b/src/main/scala/millfork/compiler/CompilationContext.scala @@ -0,0 +1,12 @@ +package millfork.compiler + +import millfork.{CompilationFlag, CompilationOptions} +import millfork.env.{Environment, MangledFunction, NormalFunction} + +/** + * @author Karol Stasiak + */ +case class CompilationContext(env: Environment, function: NormalFunction, extraStackOffset: Int, options: CompilationOptions){ + + def addStack(i: Int): CompilationContext = this.copy(extraStackOffset = extraStackOffset + i) +} diff --git a/src/main/scala/millfork/compiler/MfCompiler.scala b/src/main/scala/millfork/compiler/MfCompiler.scala new file mode 100644 index 00000000..a977a2ec --- /dev/null +++ b/src/main/scala/millfork/compiler/MfCompiler.scala @@ -0,0 +1,1675 @@ +package millfork.compiler + +import java.util.concurrent.atomic.AtomicLong + +import millfork.{CompilationFlag, CompilationOptions} +import millfork.assembly._ +import millfork.env._ +import millfork.node.{Register, _} +import millfork.assembly.AddrMode._ +import millfork.assembly.Opcode._ +import millfork.error.ErrorReporting + +import scala.collection.JavaConverters._ + +/** + * @author Karol Stasiak + */ + +sealed trait BranchSpec { + def flip: BranchSpec +} + +case object NoBranching extends BranchSpec { + override def flip = this +} + +case class BranchIfTrue(label: String) extends BranchSpec { + override def flip = BranchIfFalse(label) +} + +case class BranchIfFalse(label: String) extends BranchSpec { + override def flip = BranchIfTrue(label) +} + +object BranchSpec { + val None = NoBranching +} + +//noinspection NotImplementedCode,ScalaUnusedSymbol +object MlCompiler { + + + private var labelCounter = new AtomicLong + + def nextLabel(prefix: String): String = "." + prefix + "__" + labelCounter.incrementAndGet().formatted("%05d") + + def compile(ctx: CompilationContext): Chunk = { + val chunk = compile(ctx, ctx.function.code) + val prefix = (if (ctx.function.interrupt) { + if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) { + List( + AssemblyLine.implied(SEI), + AssemblyLine.implied(PHA), + AssemblyLine.implied(PHX), + AssemblyLine.implied(PHY), + AssemblyLine.implied(CLD)) + } else { + List( + AssemblyLine.implied(SEI), + AssemblyLine.implied(PHA), + AssemblyLine.implied(TXA), + AssemblyLine.implied(PHA), + AssemblyLine.implied(TYA), + AssemblyLine.implied(PHA), + AssemblyLine.implied(CLD)) + } + } else Nil) ++ stackPointerFixAtBeginning(ctx) + if (prefix.nonEmpty) { + LabelledChunk(ctx.function.name, SequenceChunk(List(LinearChunk(prefix), chunk))) + } else { + LabelledChunk(ctx.function.name, chunk) + } + } + + def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = { + val paramsWithTypes = f.expressions.map(x => getExpressionType(ctx, x) -> x) + ctx.env.lookupFunction(f.functionName, paramsWithTypes).getOrElse( + ErrorReporting.fatal(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1)}`", f.position)) + } + + def getExpressionType(ctx: CompilationContext, expr: Expression): Type = { + val env = ctx.env + val b = env.get[Type]("byte") + val bool = env.get[Type]("bool$") + val v = env.get[Type]("void") + val w = env.get[Type]("word") + val l = env.get[Type]("long") + expr match { + case LiteralExpression(value, size) => + size match { + case 1 => b + case 2 => w + case 3 | 4 => l + } + case VariableExpression(name) => + env.get[TypedThing](name, expr.position).typ + case HalfWordExpression(param, _) => + getExpressionType(ctx, param) + b + case IndexedExpression(_, _) => b + case SeparateBytesExpression(h, l) => + if (getExpressionType(ctx, h).size > 1) ErrorReporting.error("Hi byte too large", h.position) + if (getExpressionType(ctx, l).size > 1) ErrorReporting.error("Lo byte too large", l.position) + w + case SumExpression(params, _) => b + case FunctionCallExpression("not", params) => bool + case FunctionCallExpression("*", params) => b + case FunctionCallExpression("|", params) => b + case FunctionCallExpression("&", params) => b + case FunctionCallExpression("^", params) => b + case FunctionCallExpression("<<", params) => b + case FunctionCallExpression(">>", params) => b + case FunctionCallExpression("<<'", params) => b + case FunctionCallExpression(">>'", params) => b + case FunctionCallExpression(">>>>", params) => b + case FunctionCallExpression("&&", params) => bool + case FunctionCallExpression("||", params) => bool + case FunctionCallExpression("^^", params) => bool + case FunctionCallExpression("==", params) => bool + case FunctionCallExpression("!=", params) => bool + case FunctionCallExpression("<", params) => bool + case FunctionCallExpression(">", params) => bool + case FunctionCallExpression("<=", params) => bool + case FunctionCallExpression(">=", params) => bool + case FunctionCallExpression("+=", params) => v + case FunctionCallExpression("-=", params) => v + case FunctionCallExpression("*=", params) => v + case FunctionCallExpression("+'=", params) => v + case FunctionCallExpression("-'=", params) => v + case FunctionCallExpression("*'=", params) => v + case FunctionCallExpression("|=", params) => v + case FunctionCallExpression("&=", params) => v + case FunctionCallExpression("^=", params) => v + case FunctionCallExpression("<<=", params) => v + case FunctionCallExpression(">>=", params) => v + case FunctionCallExpression("<<'=", params) => v + case FunctionCallExpression(">>'=", params) => v + case f@FunctionCallExpression(name, params) => + lookupFunction(ctx, f).returnType + } + } + + def compileConstant(ctx: CompilationContext, expr: Constant, target: Variable): List[AssemblyLine] = { + target match { + case RegisterVariable(Register.A, _) => List(AssemblyLine(LDA, Immediate, expr)) + case RegisterVariable(Register.X, _) => List(AssemblyLine(LDX, Immediate, expr)) + case RegisterVariable(Register.Y, _) => List(AssemblyLine(LDY, Immediate, expr)) + case RegisterVariable(Register.AX, _) => List( + AssemblyLine(LDA, Immediate, expr.loByte), + AssemblyLine(LDX, Immediate, expr.hiByte)) + case RegisterVariable(Register.AY, _) => List( + AssemblyLine(LDA, Immediate, expr.loByte), + AssemblyLine(LDY, Immediate, expr.hiByte)) + case RegisterVariable(Register.XA, _) => List( + AssemblyLine(LDA, Immediate, expr.hiByte), + AssemblyLine(LDX, Immediate, expr.loByte)) + case RegisterVariable(Register.YA, _) => List( + AssemblyLine(LDA, Immediate, expr.hiByte), + AssemblyLine(LDY, Immediate, expr.loByte)) + case m: VariableInMemory => + val addr = m.toAddress + m.typ.size match { + case 0 => Nil + case 1 => List( + AssemblyLine(LDA, Immediate, expr.loByte), + AssemblyLine(STA, Absolute, addr)) + case 2 => List( + AssemblyLine(LDA, Immediate, expr.loByte), + AssemblyLine(STA, Absolute, addr), + AssemblyLine(LDA, Immediate, expr.hiByte), + AssemblyLine(STA, Absolute, addr + 1)) + case s => List.tabulate(s)(i => List( + AssemblyLine(LDA, Immediate, expr.subbyte(i)), + AssemblyLine(STA, Absolute, addr + i))).flatten + } + case StackVariable(_, t, offset) => + t.size match { + case 0 => Nil + case 1 => List( + AssemblyLine.implied(TSX), + AssemblyLine.immediate(LDA, expr.loByte), + AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset)) + case 2 => List( + AssemblyLine.implied(TSX), + AssemblyLine.immediate(LDA, expr.loByte), + AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset), + AssemblyLine.immediate(LDA, expr.hiByte), + AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset + 1)) + case s => AssemblyLine.implied(TSX) :: List.tabulate(s)(i => List( + AssemblyLine.immediate(LDA, expr.subbyte(i)), + AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset + i))).flatten + } + } + } + + def fixTsx(code: List[AssemblyLine]): List[AssemblyLine] = code match { + case (tsx@AssemblyLine(TSX, _, _, _)) :: xs => tsx :: AssemblyLine.implied(INX) :: fixTsx(xs) + case (txs@AssemblyLine(TXS, _, _, _)) :: xs => ??? + case x :: xs => x :: fixTsx(xs) + case Nil => Nil + } + + def preserveRegisterIfNeeded(ctx: CompilationContext, register: Register.Value, code: List[AssemblyLine]): List[AssemblyLine] = { + val state = register match { + case Register.A => State.A + case Register.X => State.X + case Register.Y => State.Y + } + + val cmos = ctx.options.flag(CompilationFlag.EmitCmosOpcodes) + if (AssemblyLine.treatment(code, state) != Treatment.Unchanged) { + register match { + case Register.A => AssemblyLine.implied(PHA) +: fixTsx(code) :+ AssemblyLine.implied(PLA) + case Register.X => if (cmos) { + List( + AssemblyLine.implied(PHA), + AssemblyLine.implied(PHX), + ) ++ fixTsx(fixTsx(code)) ++ List( + AssemblyLine.implied(PLX), + AssemblyLine.implied(PLA), + ) + } else { + List( + AssemblyLine.implied(PHA), + AssemblyLine.implied(TXA), + AssemblyLine.implied(PHA), + ) ++ fixTsx(fixTsx(code)) ++ List( + AssemblyLine.implied(PLA), + AssemblyLine.implied(TAX), + AssemblyLine.implied(PLA), + ) + } + case Register.Y => if (cmos) { + List( + AssemblyLine.implied(PHA), + AssemblyLine.implied(PHY), + ) ++ fixTsx(fixTsx(code)) ++ List( + AssemblyLine.implied(PLY), + AssemblyLine.implied(PLA), + ) + } else { + List( + AssemblyLine.implied(PHA), + AssemblyLine.implied(TYA), + AssemblyLine.implied(PHA), + ) ++ fixTsx(fixTsx(code)) ++ List( + AssemblyLine.implied(PLA), + AssemblyLine.implied(TAY), + AssemblyLine.implied(PLA), + ) + } + } + } else { + code + } + } + + def compileByteStorage(ctx: CompilationContext, register: Register.Value, target: LhsExpression): List[AssemblyLine] = { + val env = ctx.env + val b = env.get[Type]("byte") + val store = register match { + case Register.A => STA + case Register.X => STX + case Register.Y => STY + } + val transferToA = register match { + case Register.A => NOP + case Register.X => TXA + case Register.Y => TYA + } + target match { + case VariableExpression(name) => + val v = env.get[Variable](name) + v.typ.size match { + case 0 => ??? + case 1 => + v match { + case mv: VariableInMemory => AssemblyLine.absolute(store, mv) :: Nil + case sv@StackVariable(_, _, offset) => AssemblyLine.implied(transferToA) :: AssemblyLine.implied(TSX) :: AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset) :: Nil + } + case s if s > 1 => + v match { + case mv: VariableInMemory => + AssemblyLine.absolute(store, mv) :: + AssemblyLine.immediate(LDA, 0) :: + List.tabulate(s - 1)(i => AssemblyLine.absolute(STA, mv.toAddress + (i + 1))) + case sv@StackVariable(_, _, offset) => + AssemblyLine.implied(transferToA) :: + AssemblyLine.implied(TSX) :: + AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset) :: + List.tabulate(s - 1)(i => AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset + i + 1)) + } + } + case IndexedExpression(arrayName, indexExpr) => + val array = env.getArrayOrPointer(arrayName) + val (variableIndex, constIndex) = env.evalVariableAndConstantSubParts(indexExpr) + + def storeToArrayAtUnknownIndex(variableIndex: Expression, arrayAddr: Constant) = { + // TODO check typ + val indexRegister = if (register == Register.Y) Register.X else Register.Y + val calculatingIndex = preserveRegisterIfNeeded(ctx, register, compile(ctx, variableIndex, Some(b, RegisterVariable(indexRegister, b)), NoBranching)) + if (register == Register.A) { + indexRegister match { + case Register.Y => + calculatingIndex ++ List(AssemblyLine.absoluteY(STA, arrayAddr + constIndex)) + case Register.X => + calculatingIndex ++ List(AssemblyLine.absoluteX(STA, arrayAddr + constIndex)) + } + } else { + indexRegister match { + case Register.Y => + calculatingIndex ++ List(AssemblyLine.implied(transferToA), AssemblyLine.absoluteY(STA, arrayAddr + constIndex)) + case Register.X => + calculatingIndex ++ List(AssemblyLine.implied(transferToA), AssemblyLine.absoluteX(STA, arrayAddr + constIndex)) + } + } + } + + (array, variableIndex) match { + + case (p: ConstantThing, None) => + List(AssemblyLine.absolute(store, env.genRelativeVariable(p.value + constIndex, b, zeropage = false))) + case (p: ConstantThing, Some(v)) => + storeToArrayAtUnknownIndex(v, p.value) + + case (a@InitializedArray(_, _, _), None) => + List(AssemblyLine.absolute(store, env.genRelativeVariable(a.toAddress + constIndex, b, zeropage = false))) + case (a@InitializedArray(_, _, _), Some(v)) => + storeToArrayAtUnknownIndex(v, a.toAddress) + + case (a@UninitializedArray(_, _), None) => + List(AssemblyLine.absolute(store, env.genRelativeVariable(a.toAddress + constIndex, b, zeropage = false))) + case (a@UninitializedArray(_, _), Some(v)) => + storeToArrayAtUnknownIndex(v, a.toAddress) + + case (RelativeArray(_, arrayAddr, _), None) => + List(AssemblyLine.absolute(store, env.genRelativeVariable(arrayAddr + constIndex, b, zeropage = false))) + case (RelativeArray(_, arrayAddr, _), Some(v)) => + storeToArrayAtUnknownIndex(v, arrayAddr) + + // TODO: are those two below okay? + case (RelativeVariable(_, arrayAddr, typ, _), None) => + List(AssemblyLine.absolute(store, env.genRelativeVariable(arrayAddr + constIndex, b, zeropage = false))) + case (RelativeVariable(_, arrayAddr, typ, _), Some(v)) => + storeToArrayAtUnknownIndex(v, arrayAddr) + + //TODO: should there be a type check or a zeropage check? + case (pointerVariable@MemoryVariable(_, typ, _), None) => + register match { + case Register.A => + List(AssemblyLine.immediate(LDY, constIndex), AssemblyLine.indexedY(STA, pointerVariable)) + case Register.Y => + List(AssemblyLine.implied(TYA), AssemblyLine.immediate(LDY, constIndex), AssemblyLine.indexedY(STA, pointerVariable), AssemblyLine.implied(TAY)) + case Register.X => + List(AssemblyLine.immediate(LDY, constIndex), AssemblyLine.implied(TXA), AssemblyLine.indexedY(STA, pointerVariable)) + case _ => + ErrorReporting.error("Cannot store a word in an array", target.position) + Nil + } + case (pointerVariable@MemoryVariable(_, typ, _), Some(_)) => + val calculatingIndex = compile(ctx, indexExpr, Some(b, RegisterVariable(Register.Y, b)), NoBranching) + register match { + case Register.A => + preserveRegisterIfNeeded(ctx, Register.A, calculatingIndex) :+ AssemblyLine.indexedY(STA, pointerVariable) + case Register.X => + preserveRegisterIfNeeded(ctx, Register.X, calculatingIndex) ++ List(AssemblyLine.implied(TXA), AssemblyLine.indexedY(STA, pointerVariable)) + case Register.Y => + AssemblyLine.implied(TYA) :: preserveRegisterIfNeeded(ctx, Register.A, calculatingIndex) ++ List( + AssemblyLine.indexedY(STA, pointerVariable), AssemblyLine.implied(TAY) + ) + case _ => + ErrorReporting.error("Cannot store a word in an array", target.position) + Nil + } + } + + } + } + + def assertCompatible(exprType: Type, variableType: Type): Unit = { + // TODO + } + + val noop: List[AssemblyLine] = Nil + + def callingContext(ctx: CompilationContext, v: MemoryVariable): CompilationContext = { + val result = new Environment(Some(ctx.env), "") + result.registerVariable(VariableDeclarationStatement(v.name, v.typ.name, stack = false, global = false, constant = false, volatile = false, initialValue = None, address = None), ctx.options) + ctx.copy(env = result) + } + + def assertBinary(ctx: CompilationContext, params: List[Expression]): (Expression, Expression, Int) = { + if (params.length != 2) { + ErrorReporting.fatal("sfgdgfsd", None) + } + (params.head, params(1)) match { + case (l: Expression, r: Expression) => (l, r, getExpressionType(ctx, l).size max getExpressionType(ctx, r).size) + } + } + + def assertComparison(ctx: CompilationContext, params: List[Expression]): (Expression, Expression, Int, Boolean) = { + if (params.length != 2) { + ErrorReporting.fatal("sfgdgfsd", None) + } + (params.head, params(1)) match { + case (l: Expression, r: Expression) => + val lt = getExpressionType(ctx, l) + val rt = getExpressionType(ctx, r) + (l, r, lt.size max rt.size, lt.isSigned || rt.isSigned) + } + } + + def assertBool(ctx: CompilationContext, params: List[Expression], expectedParamCount: Int): Unit = { + if (params.length != expectedParamCount) { + ErrorReporting.error("Invalid number of parameters", params.headOption.flatMap(_.position)) + return + } + params.foreach { param => + if (!getExpressionType(ctx, param).isInstanceOf[BooleanType]) + ErrorReporting.fatal("Parameter should be boolean", param.position) + } + } + + def assertAssignmentLike(ctx: CompilationContext, params: List[Expression]): (LhsExpression, Expression, Int) = { + if (params.length != 2) { + ErrorReporting.fatal("sfgdgfsd", None) + } + (params.head, params(1)) match { + case (l: LhsExpression, r: Expression) => + val lsize = getExpressionType(ctx, l).size + val rsize = getExpressionType(ctx, r).size + if (lsize < rsize) { + ErrorReporting.error("Left-hand-side expression is of smaller type than the right-hand-side expression", l.position) + } + (l, r, lsize) + case (err: Expression, _) => ErrorReporting.fatal("Invalid left-hand-side expression", err.position) + } + } + + def compile(ctx: CompilationContext, expr: Expression, exprTypeAndVariable: Option[(Type, Variable)], branches: BranchSpec): List[AssemblyLine] = { + val env = ctx.env + val b = env.get[Type]("byte") + val w = env.get[Type]("word") + expr match { + case HalfWordExpression(expression, _) => ??? // TODO + case LiteralExpression(value, size) => + exprTypeAndVariable.fold(noop) { case (exprType, target) => + assertCompatible(exprType, target.typ) + compileConstant(ctx, NumericConstant(value, size), target) + } + case VariableExpression(name) => + exprTypeAndVariable.fold(noop) { case (exprType, target) => + assertCompatible(exprType, target.typ) + env.eval(expr).map(c => compileConstant(ctx, c, target)).getOrElse { + env.get[TypedThing](name) match { + case source: VariableInMemory => + target match { + case RegisterVariable(Register.A, _) => List(AssemblyLine.absolute(LDA, source)) + case RegisterVariable(Register.X, _) => List(AssemblyLine.absolute(LDX, source)) + case RegisterVariable(Register.Y, _) => List(AssemblyLine.absolute(LDY, source)) + case RegisterVariable(Register.AX, _) => + exprType.size match { + case 1 => if (exprType.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.absolute(LDA, source), + AssemblyLine.implied(PHA), + AssemblyLine.immediate(ORA, 0x7F), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label), + AssemblyLine.implied(TAX), + AssemblyLine.implied(PLA)) + } else List( + AssemblyLine.absolute(LDA, source), + AssemblyLine.immediate(LDX, 0)) + case 2 => List( + AssemblyLine.absolute(LDA, source), + AssemblyLine.absolute(LDX, source, 1)) + } + case RegisterVariable(Register.AY, _) => + exprType.size match { + case 1 => if (exprType.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.absolute(LDA, source), + AssemblyLine.implied(PHA), + AssemblyLine.immediate(ORA, 0x7F), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label), + AssemblyLine.implied(TAY), + AssemblyLine.implied(PLA)) + } else { + List( + AssemblyLine.absolute(LDA, source), + AssemblyLine.immediate(LDY, 0)) + } + case 2 => List( + AssemblyLine.absolute(LDA, source), + AssemblyLine.absolute(LDY, source, 1)) + } + case RegisterVariable(Register.XA, _) => + exprType.size match { + case 1 => if (exprType.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.absolute(LDX, source), + AssemblyLine.implied(TXA), + AssemblyLine.immediate(ORA, 0x7F), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) + } else List( + AssemblyLine.absolute(LDX, source), + AssemblyLine.immediate(LDA, 0)) + case 2 => List( + AssemblyLine.absolute(LDX, source), + AssemblyLine.absolute(LDA, source, 1)) + } + case RegisterVariable(Register.YA, _) => + exprType.size match { + case 1 => if (exprType.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.absolute(LDY, source), + AssemblyLine.implied(TYA), + AssemblyLine.immediate(ORA, 0x7F), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) + } else List( + AssemblyLine.absolute(LDY, source), + AssemblyLine.immediate(LDA, 0)) + case 2 => List( + AssemblyLine.absolute(LDY, source), + AssemblyLine.absolute(LDA, source, 1)) + } + case target: VariableInMemory => + if (exprType.size > target.typ.size) { + ErrorReporting.error(s"Variable `$target.name` is too small", expr.position) + Nil + } else { + val copy = List.tabulate(exprType.size)(i => List(AssemblyLine.absolute(LDA, source, i), AssemblyLine.absolute(STA, target, i))) + val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.immediate(ORA, 0x7F), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) ++ + List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absolute(STA, target, i + exprType.size)) + } else { + AssemblyLine.immediate(LDA, 0) :: + List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absolute(STA, target, i + exprType.size)) + } + copy.flatten ++ extend + } + case target: StackVariable => + if (exprType.size > target.typ.size) { + ErrorReporting.error(s"Variable `$target.name` is too small", expr.position) + Nil + } else { + val copy = List.tabulate(exprType.size)(i => List(AssemblyLine.absolute(LDA, source, i), AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i))) + val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.immediate(ORA, 0x7F), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) ++ + List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i + exprType.size)) + } else { + AssemblyLine.immediate(LDA, 0) :: + List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i + exprType.size)) + } + AssemblyLine.implied(TSX) :: (copy.flatten ++ extend) + } + } + case source@StackVariable(_, sourceType, offset) => + target match { + case RegisterVariable(Register.A, _) => List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset)) + case RegisterVariable(Register.X, _) => List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset), AssemblyLine.implied(TAX)) + case RegisterVariable(Register.Y, _) => List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDY, offset + ctx.extraStackOffset)) + case RegisterVariable(Register.AX, _) => + exprType.size match { + case 1 => if (exprType.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset), + AssemblyLine.implied(PHA), + AssemblyLine.immediate(ORA, 0x7F), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label), + AssemblyLine.implied(TAX), + AssemblyLine.implied(PLA)) + } else List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset), + AssemblyLine.immediate(LDX, 0)) + case 2 => List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset), + AssemblyLine.implied(PHA), + AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + 1), + AssemblyLine.implied(TAX), + AssemblyLine.implied(PLA)) + } + case RegisterVariable(Register.AY, _) => + exprType.size match { + case 1 => if (exprType.isSigned) { + val label = nextLabel("sx") + ??? // TODO + } else { + List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset), + AssemblyLine.immediate(LDY, 0)) + } + case 2 => List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset), + AssemblyLine.absoluteX(LDY, offset + ctx.extraStackOffset + 1)) + } + case RegisterVariable(Register.XA, _) => + ??? // TODO + case RegisterVariable(Register.YA, _) => + exprType.size match { + case 1 => if (exprType.isSigned) { + val label = nextLabel("sx") + ??? // TODO + } else { + List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(LDY, offset + ctx.extraStackOffset), + AssemblyLine.immediate(LDA, 0)) + } + case 2 => List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(LDY, offset + ctx.extraStackOffset), + AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + 1)) + } + case target: VariableInMemory => + if (exprType.size > target.typ.size) { + ErrorReporting.error(s"Variable `$target.name` is too small", expr.position) + Nil + } else { + val copy = List.tabulate(exprType.size)(i => List(AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + i), AssemblyLine.absolute(STA, target, i))) + val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.immediate(ORA, 0x7F), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) ++ + List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absolute(STA, target, i + exprType.size)) + } else { + AssemblyLine.immediate(LDA, 0) :: + List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absolute(STA, target, i + exprType.size)) + } + AssemblyLine.implied(TSX) :: (copy.flatten ++ extend) + } + case target: StackVariable => + if (exprType.size > target.typ.size) { + ErrorReporting.error(s"Variable `$target.name` is too small", expr.position) + Nil + } else { + val copy = List.tabulate(exprType.size)(i => List(AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + i), AssemblyLine.absoluteX(STA, target.baseOffset + i))) + val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.immediate(ORA, 0x7F), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) ++ + List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i + exprType.size)) + } else { + AssemblyLine.immediate(LDA, 0) :: + List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i + exprType.size)) + } + AssemblyLine.implied(TSX) :: (copy.flatten ++ extend) + } + } + case source@ConstantThing(_, value, _) => + compileConstant(ctx, value, target) + } + } + } + case IndexedExpression(arrayName, indexExpr) => + val array = env.getArrayOrPointer(arrayName) + // TODO: check + val (variableIndex, constantIndex) = env.evalVariableAndConstantSubParts(indexExpr) + exprTypeAndVariable.fold(noop) { case (exprType, target) => + + val register = target match { + case RegisterVariable(r, _) => r + case _ => Register.A + } + val suffix = target match { + case RegisterVariable(_, _) => Nil + case target: VariableInMemory => + if (target.typ.size == 1) { + AssemblyLine.variable(ctx, STA, target) + } + else if (target.typ.isSigned) { + val label = nextLabel("sx") + AssemblyLine.variable(ctx, STA, target) ++ + List( + AssemblyLine.immediate(ORA, 0x7f), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) ++ + List.tabulate(target.typ.size - 1)(i => AssemblyLine.variable(ctx, STA, target, i + 1)).flatten + } else { + AssemblyLine.variable(ctx, STA, target) ++ + List(AssemblyLine.immediate(LDA, 0)) ++ + List.tabulate(target.typ.size - 1)(i => AssemblyLine.variable(ctx, STA, target, i + 1)).flatten + } + } + val load = register match { + case Register.A | Register.AX | Register.AY => LDA + case Register.X => LDX + case Register.Y => LDY + } + + def loadFromArrayAtUnknownIndex(variableIndex: Expression, arrayAddr: Constant) = { + // TODO check typ + val indexRegister = if (register == Register.Y) Register.X else Register.Y + val calculatingIndex = compile(ctx, variableIndex, Some(b, RegisterVariable(indexRegister, b)), NoBranching) + indexRegister match { + case Register.Y => + calculatingIndex ++ List(AssemblyLine.absoluteY(load, arrayAddr + constantIndex)) + case Register.X => + calculatingIndex ++ List(AssemblyLine.absoluteX(load, arrayAddr + constantIndex)) + } + } + + val result = (array, variableIndex) match { + case (a: ConstantThing, None) => + List(AssemblyLine.absolute(load, env.genRelativeVariable(a.value + constantIndex, b, zeropage = false))) + case (a: ConstantThing, Some(v)) => + loadFromArrayAtUnknownIndex(v, a.value) + + case (a: MlArray, None) => + List(AssemblyLine.absolute(load, env.genRelativeVariable(a.toAddress + constantIndex, b, zeropage = false))) + case (a: MlArray, Some(v)) => + loadFromArrayAtUnknownIndex(v, a.toAddress) + + // TODO: see above + case (RelativeVariable(_, arrayAddr, typ, _), None) => + List(AssemblyLine.absolute(load, env.genRelativeVariable(arrayAddr + constantIndex, b, zeropage = false))) + case (RelativeVariable(_, arrayAddr, typ, _), Some(v)) => + loadFromArrayAtUnknownIndex(v, arrayAddr) + + // TODO: see above + case (pointerVariable@MemoryVariable(_, typ, _), None) => + register match { + case Register.A => + List(AssemblyLine.immediate(LDY, constantIndex), AssemblyLine.indexedY(LDA, pointerVariable)) + case Register.Y => + List(AssemblyLine.immediate(LDY, constantIndex), AssemblyLine.indexedY(LDA, pointerVariable), AssemblyLine.implied(TAY)) + case Register.X => + List(AssemblyLine.immediate(LDY, constantIndex), AssemblyLine.indexedY(LDX, pointerVariable)) + } + case (pointerVariable@MemoryVariable(_, typ, _), Some(_)) => + val calculatingIndex = compile(ctx, indexExpr, Some(b, RegisterVariable(Register.Y, b)), NoBranching) + register match { + case Register.A => + calculatingIndex :+ AssemblyLine.indexedY(LDA, pointerVariable) + case Register.X => + calculatingIndex :+ AssemblyLine.indexedY(LDX, pointerVariable) + case Register.Y => + calculatingIndex ++ List(AssemblyLine.indexedY(LDA, pointerVariable), AssemblyLine.implied(TAY)) + } + } + register match { + case Register.A | Register.X | Register.Y => result ++ suffix + case Register.AX => result :+ AssemblyLine.immediate(LDX, 0) + case Register.AY => result :+ AssemblyLine.immediate(LDY, 0) + } + } + case SumExpression(params, decimal) => + assertAllBytesForSum("Long addition not supported", ctx, params) + val calculate = BuiltIns.compileAddition(ctx, params, decimal = decimal) + val store = expressionStorageFromAX(ctx, exprTypeAndVariable, expr.position) + calculate ++ store + case SeparateBytesExpression(h, l) => + exprTypeAndVariable.fold { + // TODO: order? + compile(ctx, l, None, branches) ++ compile(ctx, h, None, branches) + } { case (exprType, target) => + assertCompatible(exprType, target.typ) + target match { + case RegisterVariable(Register.A | Register.X | Register.Y, _) => compile(ctx, l, exprTypeAndVariable, branches) + case RegisterVariable(Register.AX, _) => + compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), branches) ++ + compile(ctx, h, Some(b -> RegisterVariable(Register.X, b)), branches) + case RegisterVariable(Register.AY, _) => + compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), branches) ++ + compile(ctx, h, Some(b -> RegisterVariable(Register.Y, b)), branches) + case RegisterVariable(Register.XA, _) => + compile(ctx, l, Some(b -> RegisterVariable(Register.X, b)), branches) ++ + compile(ctx, h, Some(b -> RegisterVariable(Register.A, b)), branches) + case RegisterVariable(Register.YA, _) => + compile(ctx, l, Some(b -> RegisterVariable(Register.Y, b)), branches) ++ + compile(ctx, h, Some(b -> RegisterVariable(Register.A, b)), branches) + case target: VariableInMemory => + target.typ.size match { + case 1 => + ErrorReporting.error(s"Variable `$target.name` cannot hold a word", expr.position) + Nil + case 2 => + compile(ctx, l, Some(b -> env.genRelativeVariable(target.toAddress, b, zeropage = target.zeropage)), branches) ++ + compile(ctx, h, Some(b -> env.genRelativeVariable(target.toAddress + 1, b, zeropage = target.zeropage)), branches) + } + case target: StackVariable => + target.typ.size match { + case 1 => + ErrorReporting.error(s"Variable `$target.name` cannot hold a word", expr.position) + Nil + case 2 => + compile(ctx, l, Some(b -> StackVariable("", b, target.baseOffset + ctx.extraStackOffset)), branches) ++ + compile(ctx, h, Some(b -> StackVariable("", b, target.baseOffset + ctx.extraStackOffset + 1)), branches) + } + } + } + + case f@FunctionCallExpression(name, params) => + val calculate = name match { + case "not" => + assertBool(ctx, params, 1) + compile(ctx, params.head, exprTypeAndVariable, branches.flip) + case "&&" => + assertBool(ctx, params, 2) + val a = params.head + val b = params(1) + branches match { + case BranchIfFalse(_) => + compile(ctx, a, exprTypeAndVariable, branches) ++ compile(ctx, b, exprTypeAndVariable, branches) + case _ => + val skip = nextLabel("an") + compile(ctx, a, exprTypeAndVariable, BranchIfFalse(skip)) ++ + compile(ctx, b, exprTypeAndVariable, branches) ++ + List(AssemblyLine.label(skip)) + } + case "||" => + assertBool(ctx, params, 2) + val a = params.head + val b = params(1) + branches match { + case BranchIfTrue(_) => + compile(ctx, a, exprTypeAndVariable, branches) ++ compile(ctx, b, exprTypeAndVariable, branches) + case _ => + val skip = nextLabel("or") + compile(ctx, a, exprTypeAndVariable, BranchIfTrue(skip)) ++ + compile(ctx, b, exprTypeAndVariable, branches) ++ + List(AssemblyLine.label(skip)) + } + case "^^" => ??? + case "&" => + assertAllBytes("Long bit ops not supported", ctx, params) + BuiltIns.compileBitOps(AND, ctx, params) + case "*" => + assertAllBytes("Long multiplication not supported", ctx, params) + BuiltIns.compileByteMultiplication(ctx, params) + case "|" => + assertAllBytes("Long bit ops not supported", ctx, params) + BuiltIns.compileBitOps(ORA, ctx, params) + case "^" => + assertAllBytes("Long bit ops not supported", ctx, params) + BuiltIns.compileBitOps(EOR, ctx, params) + case ">>>>" => + val (l, r, 2) = assertBinary(ctx, params) + l match { + case v: LhsExpression => + BuiltIns.compileNonetOps(ctx, v, r) + } + case "<<" => + assertAllBytes("Long shift ops not supported", ctx, params) + val (l, r, 1) = assertBinary(ctx, params) + BuiltIns.compileShiftOps(ASL, ctx, l, r) + case ">>" => + assertAllBytes("Long shift ops not supported", ctx, params) + val (l, r, 1) = assertBinary(ctx, params) + BuiltIns.compileShiftOps(LSR, ctx, l, r) + case "<" => + // TODO: signed + val (l, r, size, signed) = assertComparison(ctx, params) + size match { + case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) + } + case ">=" => + // TODO: signed + val (l, r, size, signed) = assertComparison(ctx, params) + size match { + case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) + } + case ">" => + // TODO: signed + val (l, r, size, signed) = assertComparison(ctx, params) + size match { + case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) + } + case "<=" => + // TODO: signed + val (l, r, size, signed) = assertComparison(ctx, params) + size match { + case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) + } + case "==" => + val (l, r, size) = assertBinary(ctx, params) + size match { + case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.Equal, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.Equal, l, r, branches) + } + case "!=" => + val (l, r, size) = assertBinary(ctx, params) + size match { + case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.NotEqual, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.NotEqual, l, r, branches) + } + case "+=" => + val (l, r, size) = assertAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = false, decimal = false) + case 2 => + l match { + case v: LhsExpression => + BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = false, decimal = false) + } + case i if i > 2 => + l match { + case v: VariableExpression => + BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = false, decimal = false) + } + } + case "-=" => + val (l, r, size) = assertAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = true, decimal = false) + case 2 => + l match { + case v: LhsExpression => + BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = true, decimal = false) + } + case i if i > 2 => + l match { + case v: VariableExpression => + BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = true, decimal = false) + } + } + case "+'=" => + val (l, r, size) = assertAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = false, decimal = true) + case 2 => + l match { + case v: LhsExpression => + BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = false, decimal = true) + } + case i if i > 2 => + l match { + case v: VariableExpression => + BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = false, decimal = true) + } + } + case "-'=" => + val (l, r, size) = assertAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = true, decimal = true) + case 2 => + l match { + case v: LhsExpression => + BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = true, decimal = true) + } + case i if i > 2 => + l match { + case v: VariableExpression => + BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = true, decimal = true) + } + } + case "<<=" => + val (l, r, size) = assertAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteShiftOps(ASL, ctx, l, r) + case i if i >= 2 => + l match { + case v: LhsExpression => + BuiltIns.compileInPlaceWordOrLongShiftOps(ctx, v, r, aslRatherThanLsr = true) + } + } + case ">>=" => + val (l, r, size) = assertAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteShiftOps(LSR, ctx, l, r) + case i if i >= 2 => + l match { + case v: LhsExpression => + BuiltIns.compileInPlaceWordOrLongShiftOps(ctx, v, r, aslRatherThanLsr = false) + } + } + case "*=" => + assertAllBytes("Long multiplication not supported", ctx, params) + val (l, r, 1) = assertAssignmentLike(ctx, params) + BuiltIns.compileInPlaceByteMultiplication(ctx, l, r) + case "&=" => + val (l, r, size) = assertAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteBitOp(ctx, l, r, AND) + case i if i >= 2 => + l match { + case v: LhsExpression => + BuiltIns.compileInPlaceWordOrLongBitOp(ctx, l, r, AND) + } + } + case "^=" => + val (l, r, size) = assertAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteBitOp(ctx, l, r, EOR) + case i if i >= 2 => + l match { + case v: LhsExpression => + BuiltIns.compileInPlaceWordOrLongBitOp(ctx, l, r, EOR) + } + } + case "|=" => + val (l, r, size) = assertAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteBitOp(ctx, l, r, ORA) + case i if i >= 2 => + l match { + case v: LhsExpression => + BuiltIns.compileInPlaceWordOrLongBitOp(ctx, l, r, ORA) + } + } + case _ => + lookupFunction(ctx, f) match { + case function: InlinedFunction => + inlineFunction(function, params, Some(ctx)).map { + case AssemblyStatement(opcode, addrMode, expression, elidable) => + val param = env.eval(expression).getOrElse { + expression match { + case VariableExpression(name) => env.get[ThingInMemory](name).toAddress + case _ => + ErrorReporting.error("Inlining failed due to non-constant things", expression.position) + Constant.Zero + } + } + AssemblyLine(opcode, addrMode, param, elidable) + + } + case function: EmptyFunction => + ??? // TODO: type conversion? + case function: FunctionInMemory => + function match { + case nf: NormalFunction => + if (nf.interrupt) { + ErrorReporting.error(s"Calling an interrupt function `${f.functionName}`", expr.position) + } + case _ => () + } + val result = function.params match { + case AssemblyParamSignature(paramConvs) => + val pairs = params.zip(paramConvs) + val secondViaMemory = pairs.flatMap { + case (paramExpr, AssemblyParam(typ, paramVar: VariableInMemory, AssemblyParameterPassingBehaviour.Copy)) => + compile(ctx, paramExpr, Some(typ -> paramVar), NoBranching) + case _ => Nil + } + val thirdViaRegisters = pairs.flatMap { + case (paramExpr, AssemblyParam(typ, paramVar@RegisterVariable(register, _), AssemblyParameterPassingBehaviour.Copy)) => + compile(ctx, paramExpr, Some(typ -> paramVar), NoBranching) + + // TODO: fix + case _ => Nil + } + secondViaMemory ++ thirdViaRegisters :+ AssemblyLine.absolute(JSR, function) + case NormalParamSignature(paramVars) => + params.zip(paramVars).flatMap { + case (paramExpr, paramVar) => + val callCtx = callingContext(ctx, paramVar) + compileAssignment(callCtx, paramExpr, VariableExpression(paramVar.name)) + } ++ List(AssemblyLine.absolute(JSR, function)) + } + result + } + } + val store = expressionStorageFromAX(ctx, exprTypeAndVariable, expr.position) + calculate ++ store + } + } + + def expressionStorageFromAX(ctx: CompilationContext, exprTypeAndVariable: Option[(Type, Variable)], position: Option[Position]): List[AssemblyLine] = { + exprTypeAndVariable.fold(noop) { + case (VoidType, _) => ??? + case (_, RegisterVariable(Register.A, _)) => noop + case (_, RegisterVariable(Register.X, _)) => List(AssemblyLine.implied(TAX)) + case (_, RegisterVariable(Register.Y, _)) => List(AssemblyLine.implied(TAY)) + case (_, RegisterVariable(Register.AX, _)) => + // TODO: sign extension + noop + case (_, RegisterVariable(Register.XA, _)) => + // TODO: sign extension + if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) { + List( + AssemblyLine.implied(PHA), + AssemblyLine.implied(PHX), + AssemblyLine.implied(PLA), + AssemblyLine.implied(PLX)) + } else { + List( + AssemblyLine.implied(TAY), + AssemblyLine.implied(TXA), + AssemblyLine.implied(PHA), + AssemblyLine.implied(TYA), + AssemblyLine.implied(TAX), + AssemblyLine.implied(PLA)) // fuck this shit + } + case (_, RegisterVariable(Register.YA, _)) => { + // TODO: sign extension + List( + AssemblyLine.implied(TAY), + AssemblyLine.implied(TXA)) + } + case (_, RegisterVariable(Register.AY, _)) => + // TODO: sign extension + List( + AssemblyLine.implied(PHA), + AssemblyLine.implied(TXA), + AssemblyLine.implied(TAY), + AssemblyLine.implied(PLA)) + case (t, v: VariableInMemory) => t.size match { + case 1 => v.typ.size match { + case 1 => + List(AssemblyLine.absolute(STA, v)) + case s if s > 1 => + if (t.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.absolute(STA, v), + AssemblyLine.immediate(ORA, 0x7f), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) ++ List.tabulate(s - 1)(i => AssemblyLine.absolute(STA, v, i + 1)) + } else { + List( + AssemblyLine.absolute(STA, v), + AssemblyLine.immediate(LDA, 0)) ++ + List.tabulate(s - 1)(i => AssemblyLine.absolute(STA, v, i + 1)) + } + } + case 2 => v.typ.size match { + case 1 => + ErrorReporting.error(s"Variable `${v.name}` cannot hold a word", position) + Nil + case 2 => + List(AssemblyLine.absolute(STA, v), AssemblyLine.absolute(STX, v, 1)) + case s if s > 2 => + if (t.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.absolute(STA, v), + AssemblyLine.absolute(STX, v, 1), + AssemblyLine.implied(TXA), + AssemblyLine.immediate(ORA, 0x7f), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) ++ List.tabulate(s - 2)(i => AssemblyLine.absolute(STA, v, i + 2)) + } else { + List( + AssemblyLine.absolute(STA, v), + AssemblyLine.absolute(STX, v, 1), + AssemblyLine.immediate(LDA, 0)) ++ + List.tabulate(s - 2)(i => AssemblyLine.absolute(STA, v, i + 2)) + } + } + } + case (t, v: StackVariable) => t.size match { + case 1 => v.typ.size match { + case 1 => + List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset)) + case s if s > 1 => + if (t.isSigned) { + val label = nextLabel("sx") + List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset), + AssemblyLine.immediate(ORA, 0x7f), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) ++ List.tabulate(s - 1)(i => AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset + i + 1)) + } else { + List( + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset), + AssemblyLine.immediate(LDA, 0)) ++ + List.tabulate(s - 1)(i => AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset + i + 1)) + } + } + case 2 => v.typ.size match { + case 1 => + ErrorReporting.error(s"Variable `${v.name}` cannot hold a word", position) + Nil + case 2 => + List( + AssemblyLine.implied(TAY), + AssemblyLine.implied(TXA), + AssemblyLine.implied(TSX), + AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset + 1), + AssemblyLine.implied(TYA), + AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset)) + case s if s > 2 => ??? + } + } + } + } + + private def assertAllBytesForSum(msg: String, ctx: CompilationContext, params: List[(Boolean, Expression)]): Unit = { + if (params.exists { case (_, expr) => getExpressionType(ctx, expr).size != 1 }) { + ErrorReporting.fatal(msg, params.head._2.position) + } + } + + private def assertAllBytes(msg: String, ctx: CompilationContext, params: List[Expression]): Unit = { + if (params.exists { expr => getExpressionType(ctx, expr).size != 1 }) { + ErrorReporting.fatal(msg, params.head.position) + } + } + + def compileAssignment(ctx: CompilationContext, source: Expression, target: LhsExpression): List[AssemblyLine] = { + val env = ctx.env + val b = env.get[Type]("byte") + val w = env.get[Type]("word") + target match { + case VariableExpression(name) => + val v = env.get[Variable](name, target.position) + // TODO check v.typ + compile(ctx, source, Some((getExpressionType(ctx, source), v)), NoBranching) + case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) => + compile(ctx, source, Some(w, RegisterVariable(Register.AX, w)), NoBranching) ++ + compileByteStorage(ctx, Register.A, l) ++ compileByteStorage(ctx, Register.X, h) + case SeparateBytesExpression(_, _) => + ErrorReporting.error("Invalid left-hand-side use of `:`") + Nil + case _ => + compile(ctx, source, Some(b, RegisterVariable(Register.A, b)), NoBranching) ++ compileByteStorage(ctx, Register.A, target) + } + } + + + def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): Chunk = { + SequenceChunk(statements.map(s => compile(ctx, s))) + } + + def inlineFunction(i: InlinedFunction, params: List[Expression], cc: Option[CompilationContext]): List[ExecutableStatement] = { + var actualCode = i.code + i.params match { + case AssemblyParamSignature(assParams) => + assParams.zip(params).foreach { + case (AssemblyParam(typ, Placeholder(ph, phType), AssemblyParameterPassingBehaviour.ByReference), actualParam) => + actualParam match { + case VariableExpression(vname) => + cc.foreach(_.env.get[ThingInMemory](vname)) + case l: LhsExpression => + // TODO: ?? + cc.foreach(c => compileByteStorage(c, Register.A, l)) + case _ => + ErrorReporting.error("A non-assignable expression was passed to an inlineable function as a `ref` parameter", actualParam.position) + } + actualCode = actualCode.map { + case a@AssemblyStatement(_, _, expr, _) => + a.copy(expression = expr.replaceVariable(ph, actualParam)) + case x => x + } + case (AssemblyParam(typ, Placeholder(ph, phType), AssemblyParameterPassingBehaviour.ByConstant), actualParam) => + cc.foreach(_.env.eval(actualParam).getOrElse(Constant.error("Non-constant expression was passed to an inlineable function as a `const` parameter", actualParam.position))) + actualCode = actualCode.map { + case a@AssemblyStatement(_, _, expr, _) => + a.copy(expression = expr.replaceVariable(ph, actualParam)) + case x => x + } + case (AssemblyParam(_, _, AssemblyParameterPassingBehaviour.Copy), actualParam) => + ??? + case (_, actualParam) => + } + case NormalParamSignature(Nil) => i.code + case NormalParamSignature(normalParams) => ??? + } + actualCode + } + + def stackPointerFixAtBeginning(ctx: CompilationContext): List[AssemblyLine] = { + val m = ctx.function + if (m.stackVariablesSize == 0) return Nil + if (ctx.options.flag(CompilationFlag.EmitIllegals)) { + if (m.stackVariablesSize > 4) + return List( + AssemblyLine.implied(TSX), + AssemblyLine.immediate(LDA, 0xff), + AssemblyLine.immediate(SBX, m.stackVariablesSize), + AssemblyLine.implied(TXS)) + } + List.fill(m.stackVariablesSize)(AssemblyLine.implied(PHA)) + } + + def stackPointerFixBeforeReturn(ctx: CompilationContext): List[AssemblyLine] = { + val m = ctx.function + if (m.stackVariablesSize == 0) return Nil + + if (m.returnType.size == 0 && m.stackVariablesSize <= 2) + return List.fill(m.stackVariablesSize)(AssemblyLine.implied(PLA)) + + if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) { + if (m.returnType.size == 1 && m.stackVariablesSize <= 2) { + return List.fill(m.stackVariablesSize)(AssemblyLine.implied(PLX)) + } + if (m.returnType.size == 2 && m.stackVariablesSize <= 2) { + return List.fill(m.stackVariablesSize)(AssemblyLine.implied(PLY)) + } + } + + if (ctx.options.flag(CompilationFlag.EmitIllegals)) { + if (m.returnType.size == 0 && m.stackVariablesSize > 4) + return List( + AssemblyLine.implied(TSX), + AssemblyLine.immediate(LDA, 0xff), + AssemblyLine.immediate(SBX, 256 - m.stackVariablesSize), + AssemblyLine.implied(TXS)) + if (m.returnType.size == 1 && m.stackVariablesSize > 6) + return List( + AssemblyLine.implied(TAY), + AssemblyLine.implied(TSX), + AssemblyLine.immediate(LDA, 0xff), + AssemblyLine.immediate(SBX, 256 - m.stackVariablesSize), + AssemblyLine.implied(TXS), + AssemblyLine.implied(TYA)) + } + + AssemblyLine.implied(TSX) :: (List.fill(m.stackVariablesSize)(AssemblyLine.implied(INX)) :+ AssemblyLine.implied(TXS)) + } + + def compile(ctx: CompilationContext, statement: ExecutableStatement): Chunk = { + val env = ctx.env + val m = ctx.function + val b = env.get[Type]("byte") + val w = env.get[Type]("word") + val someRegisterA = Some(b, RegisterVariable(Register.A, b)) + val someRegisterAX = Some(w, RegisterVariable(Register.AX, w)) + val someRegisterYA = Some(w, RegisterVariable(Register.YA, w)) + val returnInstructions = if (m.interrupt) { + if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) { + List( + AssemblyLine.implied(PLY), + AssemblyLine.implied(PLX), + AssemblyLine.implied(PLA), + AssemblyLine.implied(CLI), + AssemblyLine.implied(RTI)) + } else { + List( + AssemblyLine.implied(PLA), + AssemblyLine.implied(TAY), + AssemblyLine.implied(PLA), + AssemblyLine.implied(TAX), + AssemblyLine.implied(PLA), + AssemblyLine.implied(CLI), + AssemblyLine.implied(RTI)) + } + } else { + List(AssemblyLine.implied(RTS)) + } + statement match { + case AssemblyStatement(o, a, x, e) => + val c: Constant = x match { + // TODO: hmmm + case VariableExpression(name) => + if (OpcodeClasses.ShortBranching(o) || o == JMP || o == LABEL) { + MemoryAddressConstant(Label(name)) + } else{ + env.eval(x).getOrElse(env.get[ThingInMemory](name, x.position).toAddress) + } + case _ => + env.eval(x).getOrElse(Constant.error(s"`$x` is not a constant", x.position)) + } + val actualAddrMode = if (OpcodeClasses.ShortBranching(o) && a == Absolute) Relative else a + LinearChunk(List(AssemblyLine(o, actualAddrMode, c, e))) + case Assignment(dest, source) => + LinearChunk(compileAssignment(ctx, source, dest)) + case ExpressionStatement(e@FunctionCallExpression(name, params)) => + env.lookupFunction(name, params.map(p => getExpressionType(ctx, p) -> p)) match { + case Some(i: InlinedFunction) => + compile(ctx, inlineFunction(i, params, Some(ctx))) + case _ => + LinearChunk(compile(ctx, e, None, NoBranching)) + } + case ExpressionStatement(e) => + LinearChunk(compile(ctx, e, None, NoBranching)) + case BlockStatement(s) => + SequenceChunk(s.map(compile(ctx, _))) + case ReturnStatement(None) => + // TODO: return type check + // TODO: better stackpointer fix + ctx.function.returnType match { + case _: BooleanType => + LinearChunk(stackPointerFixBeforeReturn(ctx) ++ returnInstructions) + case t => t.size match { + case 0 => + LinearChunk(stackPointerFixBeforeReturn(ctx) ++ + List(AssemblyLine.discardAF(), AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions) + case 1 => + LinearChunk(stackPointerFixBeforeReturn(ctx) ++ + List(AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions) + case 2 => + LinearChunk(stackPointerFixBeforeReturn(ctx) ++ + List(AssemblyLine.discardYF()) ++ returnInstructions) + } + } + case ReturnStatement(Some(e)) => + m.returnType match { + case _: BooleanType => + m.returnType.size match { + case 0 => + ErrorReporting.error("Cannot return anything from a void function", statement.position) + LinearChunk(stackPointerFixBeforeReturn(ctx) ++ returnInstructions) + case 1 => + LinearChunk(compile(ctx, e, someRegisterA, NoBranching) ++ stackPointerFixBeforeReturn(ctx) ++ returnInstructions) + case 2 => + LinearChunk(compile(ctx, e, someRegisterAX, NoBranching) ++ stackPointerFixBeforeReturn(ctx) ++ returnInstructions) + } + case _ => + m.returnType.size match { + case 0 => + ErrorReporting.error("Cannot return anything from a void function", statement.position) + LinearChunk(stackPointerFixBeforeReturn(ctx) ++ List(AssemblyLine.discardAF(), AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions) + case 1 => + LinearChunk(compile(ctx, e, someRegisterA, NoBranching) ++ stackPointerFixBeforeReturn(ctx) ++ List(AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions) + case 2 => + // TODO: ??? + val stackPointerFix = stackPointerFixBeforeReturn(ctx) + if (stackPointerFix.isEmpty) { + LinearChunk(compile(ctx, e, someRegisterAX, NoBranching) ++ List(AssemblyLine.discardYF()) ++ returnInstructions) + } else { + LinearChunk(compile(ctx, e, someRegisterYA, NoBranching) ++ + stackPointerFix ++ + List(AssemblyLine.implied(TAX), AssemblyLine.implied(TYA), AssemblyLine.discardYF()) ++ + returnInstructions) + } + } + } + case IfStatement(condition, thenPart, elsePart) => + val condType = getExpressionType(ctx, condition) + val thenBlock = compile(ctx, thenPart) + val elseBlock = compile(ctx, elsePart) + val largeThenBlock = thenBlock.sizeInBytes > 100 + val largeElseBlock = elseBlock.sizeInBytes > 100 + condType match { + case ConstantBooleanType(_, true) => thenBlock + case ConstantBooleanType(_, false) => elseBlock + case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) => + (thenPart, elsePart) match { + case (Nil, Nil) => EmptyChunk + case (Nil, _) => + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching)) + if (largeElseBlock) { + val middle = nextLabel("el") + val end = nextLabel("fi") + SequenceChunk(List(conditionBlock, branchChunk(jumpIfFalse, middle), jmpChunk(end), labelChunk(middle), elseBlock, labelChunk(end))) + } else { + val end = nextLabel("fi") + SequenceChunk(List(conditionBlock, branchChunk(jumpIfTrue, end), elseBlock, labelChunk(end))) + } + case (_, Nil) => + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching)) + if (largeThenBlock) { + val middle = nextLabel("th") + val end = nextLabel("fi") + SequenceChunk(List(conditionBlock, branchChunk(jumpIfTrue, middle), jmpChunk(end), labelChunk(middle), elseBlock, labelChunk(end))) + } else { + val end = nextLabel("fi") + SequenceChunk(List(conditionBlock, branchChunk(jumpIfFalse, end), thenBlock, labelChunk(end))) + } + case _ => + // TODO: large blocks + if (largeElseBlock || largeThenBlock) ErrorReporting.error("Large blocks in if statement", statement.position) + val middle = nextLabel("el") + val end = nextLabel("fi") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching)) + SequenceChunk(List(conditionBlock, branchChunk(jumpIfFalse, middle), thenBlock, jmpChunk(end), labelChunk(middle), elseBlock, labelChunk(end))) + } + case BuiltInBooleanType => + (thenPart, elsePart) match { + case (Nil, Nil) => EmptyChunk + case (Nil, _) => + val end = nextLabel("fi") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfTrue(end))) + SequenceChunk(List(conditionBlock, elseBlock, labelChunk(end))) + case (_, Nil) => + val end = nextLabel("fi") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfFalse(end))) + SequenceChunk(List(conditionBlock, thenBlock, labelChunk(end))) + case _ => + val middle = nextLabel("el") + val end = nextLabel("fi") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfFalse(middle))) + SequenceChunk(List(conditionBlock, thenBlock, jmpChunk(end), labelChunk(middle), elseBlock, labelChunk(end))) + } + case _ => + ErrorReporting.error(s"Illegal type for a condition: `$condType`", condition.position) + EmptyChunk + } + case WhileStatement(condition, bodyPart) => + val condType = getExpressionType(ctx, condition) + val bodyBlock = compile(ctx, bodyPart) + val largeBodyBlock = bodyBlock.sizeInBytes > 100 + condType match { + case ConstantBooleanType(_, true) => + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching)) + val start = nextLabel("wh") + SequenceChunk(List(labelChunk(start), bodyBlock, jmpChunk(start))) + case ConstantBooleanType(_, false) => EmptyChunk + case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) => + if (largeBodyBlock) { + val start = nextLabel("wh") + val middle = nextLabel("he") + val end = nextLabel("ew") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching)) + SequenceChunk(List(labelChunk(start), conditionBlock, branchChunk(jumpIfTrue, middle), jmpChunk(end), bodyBlock, jmpChunk(start), labelChunk(end))) + } else { + val start = nextLabel("wh") + val end = nextLabel("ew") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching)) + SequenceChunk(List(labelChunk(start), conditionBlock, branchChunk(jumpIfFalse, end), bodyBlock, jmpChunk(start), labelChunk(end))) + } + case BuiltInBooleanType => + if (largeBodyBlock) { + val start = nextLabel("wh") + val middle = nextLabel("he") + val end = nextLabel("ew") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfTrue(middle))) + SequenceChunk(List(labelChunk(start), conditionBlock, jmpChunk(end), labelChunk(middle), bodyBlock, jmpChunk(start), labelChunk(end))) + } else { + val start = nextLabel("wh") + val end = nextLabel("ew") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfFalse(end))) + SequenceChunk(List(labelChunk(start), conditionBlock, bodyBlock, jmpChunk(start), labelChunk(end))) + } + case _ => + ErrorReporting.error(s"Illegal type for a condition: `$condType`", condition.position) + EmptyChunk + } + case DoWhileStatement(bodyPart, condition) => + val condType = getExpressionType(ctx, condition) + val bodyBlock = compile(ctx, bodyPart) + val largeBodyBlock = bodyBlock.sizeInBytes > 100 + condType match { + case ConstantBooleanType(_, true) => + val start = nextLabel("do") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching)) + SequenceChunk(List(labelChunk(start), bodyBlock, jmpChunk(start))) + case ConstantBooleanType(_, false) => bodyBlock + case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) => + val start = nextLabel("do") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching)) + if (largeBodyBlock) { + val end = nextLabel("od") + SequenceChunk(List(labelChunk(start), bodyBlock, conditionBlock, branchChunk(jumpIfFalse, end), jmpChunk(start), labelChunk(end))) + } else { + SequenceChunk(List(labelChunk(start), bodyBlock, conditionBlock, branchChunk(jumpIfTrue, start))) + } + case BuiltInBooleanType => + val start = nextLabel("do") + if (largeBodyBlock) { + val end = nextLabel("od") + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfFalse(end))) + SequenceChunk(List(labelChunk(start), bodyBlock, conditionBlock, jmpChunk(start), labelChunk(end))) + } else { + val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfTrue(start))) + SequenceChunk(List(labelChunk(start), bodyBlock, conditionBlock)) + } + case _ => + ErrorReporting.error(s"Illegal type for a condition: `$condType`", condition.position) + EmptyChunk + } + case f@ForStatement(variable, start, end, direction, body) => + // TODO: check sizes + // TODO: special faster cases + val vex = VariableExpression(f.variable) + val one = LiteralExpression(1, 1) + val increment = ExpressionStatement(FunctionCallExpression("+=", List(vex, one))) + val decrement = ExpressionStatement(FunctionCallExpression("-=", List(vex, one))) + (direction, env.eval(start), env.eval(end)) match { + + case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e - 1 => + compile(ctx, Assignment(vex, f.start) :: f.body) + case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s >= e => + EmptyChunk + + case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e => + compile(ctx, Assignment(vex, f.start) :: f.body) + case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s > e => + EmptyChunk + + case (ForDirection.ParallelUntil, Some(NumericConstant(0, ssize)), Some(NumericConstant(e, _))) if e > 0 => + compile(ctx, List( + Assignment(vex, f.end), + DoWhileStatement(decrement :: f.body, FunctionCallExpression("!=", List(vex, f.start))) + )) + + case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s == e => + compile(ctx, Assignment(vex, LiteralExpression(s, ssize)) :: f.body) + case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s < e => + EmptyChunk + case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(0, esize))) if s > 0 => + compile(ctx, List( + Assignment(vex, f.start), + DoWhileStatement(f.body :+ decrement, FunctionCallExpression("!=", List(vex, f.end))) + )) + + + case (ForDirection.Until | ForDirection.ParallelUntil, _, _) => + compile(ctx, List( + Assignment(vex, f.start), + WhileStatement( + FunctionCallExpression("<", List(vex, f.end)), + f.body :+ increment), + )) + case (ForDirection.To | ForDirection.ParallelTo,_,_) => + compile(ctx, List( + Assignment(vex, f.start), + WhileStatement( + FunctionCallExpression("<=", List(vex, f.end)), + f.body :+ increment), + )) + case (ForDirection.DownTo,_,_) => + compile(ctx, List( + Assignment(vex, f.start), + IfStatement( + FunctionCallExpression(">=", List(vex, f.end)), + List(DoWhileStatement( + f.body :+ decrement, + FunctionCallExpression("!=", List(vex, f.end)) + )), + Nil) + )) + } + // TODO + } + } + + private def labelChunk(labelName: String) = { + LinearChunk(List(AssemblyLine.label(Label(labelName)))) + } + + private def jmpChunk(labelName: String) = { + LinearChunk(List(AssemblyLine.absolute(JMP, Label(labelName)))) + } + + private def branchChunk(opcode: Opcode.Value, labelName: String) = { + LinearChunk(List(AssemblyLine.relative(opcode, Label(labelName)))) + } +} diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala new file mode 100644 index 00000000..abc95f07 --- /dev/null +++ b/src/main/scala/millfork/env/Constant.scala @@ -0,0 +1,224 @@ +package millfork.env + +import millfork.error.ErrorReporting +import millfork.node.Position + +object Constant { + val Zero: Constant = NumericConstant(0, 1) + val One: Constant = NumericConstant(1, 1) + + def error(msg: String, position: Option[Position] = None): Constant = { + ErrorReporting.error(msg, position) + Zero + } + + def minimumSize(value: Long): Int = if (value < -128 || value > 255) 2 else 1 // TODO !!! +} + +import millfork.env.Constant.minimumSize +import millfork.error.ErrorReporting +import millfork.node.Position + +sealed trait Constant { + + def asl(i: Constant): Constant = i match { + case NumericConstant(sa, _) => asl(sa.toInt) + case _ => CompoundConstant(MathOperator.Shl, this, i) + } + + def asl(i: Int): Constant = CompoundConstant(MathOperator.Shl, this, NumericConstant(i, 1)) + + def requiredSize: Int + + def +(that: Constant): Constant = CompoundConstant(MathOperator.Plus, this, that) + + def -(that: Constant): Constant = CompoundConstant(MathOperator.Minus, this, that) + + def +(that: Long): Constant = if (that == 0) this else this + NumericConstant(that, minimumSize(that)) + + def -(that: Long): Constant = this + (-that) + + def loByte: Constant = { + if (requiredSize == 1) return this + HalfWordConstant(this, hi = false) + } + + def hiByte: Constant = { + if (requiredSize == 1) Constant.Zero + else HalfWordConstant(this, hi = true) + } + + def subbyte(index: Int): Constant = { + if (requiredSize <= index) Constant.Zero + else index match { + case 0 => loByte + case 1 => hiByte + case _ => SubbyteConstant(this, index) + } + } + + def isLowestByteAlwaysEqual(i: Int) : Boolean = false + + def quickSimplify: Constant = this +} + +case class UnexpandedConstant(name: String, requiredSize: Int) extends Constant + +case class NumericConstant(value: Long, requiredSize: Int) extends Constant { + if (requiredSize == 1) { + if (value < -128 || value > 255) { + throw new IllegalArgumentException("The constant is too big") + } + } + + override def isLowestByteAlwaysEqual(i: Int) : Boolean = (value & 0xff) == (i&0xff) + + override def asl(i: Int) = NumericConstant(value << i, requiredSize + i / 8) + + override def +(that: Constant): Constant = that + value + + override def +(that: Long) = NumericConstant(value + that, minimumSize(value + that)) + + override def toString: String = if (value > 9) value.formatted("$%X") else value.toString +} + +case class MemoryAddressConstant(var thing: ThingInMemory) extends Constant { + override def requiredSize = 2 + + override def toString: String = thing.name +} + +case class HalfWordConstant(base: Constant, hi: Boolean) extends Constant { + override def quickSimplify: Constant = { + val simplified = base.quickSimplify + simplified match { + case NumericConstant(x, size) => if (hi) { + if (size == 1) Constant.Zero else NumericConstant((x >> 8) & 0xff, 1) + } else { + NumericConstant(x & 0xff, 1) + } + case _ => HalfWordConstant(simplified, hi) + } + } + + override def requiredSize = 1 + + override def toString: String = base + (if (hi) ".hi" else ".lo") +} + +case class SubbyteConstant(base: Constant, index: Int) extends Constant { + override def quickSimplify: Constant = { + val simplified = base.quickSimplify + simplified match { + case NumericConstant(x, size) => if (index >= size) { + Constant.Zero + } else { + NumericConstant((x >> (index * 8)) & 0xff, 1) + } + case _ => SubbyteConstant(simplified, index) + } + } + + override def requiredSize = 1 + + override def toString: String = base + (index match { + case 0 => ".lo" + case 1 => ".hi" + case 2 => ".b2" + case 3 => ".b3" + }) +} + +object MathOperator extends Enumeration { + val Plus, Minus, Times, Shl, Shr, + DecimalPlus, DecimalMinus, DecimalTimes, DecimalShl, DecimalShr, + And, Or, Exor = Value +} + +case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Constant) extends Constant { + override def quickSimplify: Constant = { + val l = lhs.quickSimplify + val r = rhs.quickSimplify + (l, r) match { + case (NumericConstant(lv, ls), NumericConstant(rv, rs)) => + var size = ls max rs + val value = operator match { + case MathOperator.Plus => lv + rv + case MathOperator.Minus => lv - rv + case MathOperator.Times => lv * rv + case MathOperator.Shl => lv << rv + case MathOperator.Shr => lv >> rv + case MathOperator.Exor => lv ^ rv + case MathOperator.Or => lv | rv + case MathOperator.And => lv & rv + case _ => return this + } + operator match { + case MathOperator.Times | MathOperator.Shl => + val mask = (1 << (size * 8)) - 1 + if (value != (value & mask)){ + size = ls + rs + } + case _ => + } + NumericConstant(value, size) + case _ => CompoundConstant(operator, l, r) + } + } + + + import MathOperator._ + + override def +(that: Constant): Constant = { + that match { + case NumericConstant(n, _) => this + n + case _ => super.+(that) + } + } + + override def +(that: Long): Constant = { + if (that == 0) { + return this + } + val That = that + val MinusThat = -that + this match { + case CompoundConstant(Plus, NumericConstant(MinusThat, _), r) => r + case CompoundConstant(Plus, l, NumericConstant(MinusThat, _)) => l + case CompoundConstant(Plus, NumericConstant(x, _), r) => CompoundConstant(Plus, r, NumericConstant(x + that, minimumSize(x + that))) + case CompoundConstant(Plus, l, NumericConstant(x, _)) => CompoundConstant(Plus, l, NumericConstant(x + that, minimumSize(x + that))) + case CompoundConstant(Minus, l, NumericConstant(That, _)) => l + case _ => CompoundConstant(Plus, this, NumericConstant(that, minimumSize(that))) + } + } + + private def plhs = lhs match { + case _: NumericConstant | _: MemoryAddressConstant => lhs + case _ => "(" + lhs + ')' + } + + private def prhs = lhs match { + case _: NumericConstant | _: MemoryAddressConstant => rhs + case _ => "(" + rhs + ')' + } + + override def toString: String = { + operator match { + case Plus => f"$plhs + $prhs" + case Minus => f"$plhs - $prhs" + case Times => f"$plhs * $prhs" + case Shl => f"$plhs << $prhs" + case Shr => f"$plhs >> $prhs" + case DecimalPlus => f"$plhs +' $prhs" + case DecimalMinus => f"$plhs -' $prhs" + case DecimalTimes => f"$plhs *' $prhs" + case DecimalShl => f"$plhs <<' $prhs" + case DecimalShr => f"$plhs >>' $prhs" + case And => f"$plhs & $prhs" + case Or => f"$plhs | $prhs" + case Exor => f"$plhs ^ $prhs" + } + } + + override def requiredSize: Int = lhs.requiredSize max rhs.requiredSize +} \ No newline at end of file diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala new file mode 100644 index 00000000..c5d8cb2b --- /dev/null +++ b/src/main/scala/millfork/env/Environment.scala @@ -0,0 +1,618 @@ +package millfork.env + +import java.util.concurrent.atomic.AtomicLong + +import millfork.{CompilationFlag, CompilationOptions} +import millfork.assembly.Opcode +import millfork.compiler._ +import millfork.error.ErrorReporting +import millfork.node._ +import millfork.output.VariableAllocator + +import scala.collection.mutable + + +/** + * @author Karol Stasiak + */ +//noinspection NotImplementedCode +class Environment(val parent: Option[Environment], val prefix: String) { + + + private var baseStackOffset = 0x101 + private val relVarId = new AtomicLong + + def genRelativeVariable(constant: Constant, typ: Type, zeropage: Boolean): RelativeVariable = { + val variable = RelativeVariable(".rv__" + relVarId.incrementAndGet().formatted("%06d"), constant, typ, zeropage = zeropage) + addThing(variable, None) + variable + } + + + def allThings: Environment = { + val allThings: Map[String, Thing] = things.values.map { + case m: FunctionInMemory => + m.environment.getAllPrefixedThings + case m: InlinedFunction => + m.environment.getAllPrefixedThings + case _ => Map[String, Thing]() + }.fold(things.toMap)(_ ++ _) + val e = new Environment(None, "") + e.things.clear() + e.things ++= allThings + e + } + + + private def getAllPrefixedThings = { + things.toMap.map { case (n, th) => (if (n.startsWith(".")) n else prefix + n, th) } + } + + def getAllLocalVariables: List[Variable] = things.values.flatMap { + case v: Variable => + Some(v) + case _ => None + }.toList + + def allPreallocatables: List[PrellocableThing] = things.values.flatMap { + case m: NormalFunction => Some(m) + case m: InitializedArray => Some(m) + case _ => None + }.toList + + def allConstants: List[ConstantThing] = things.values.flatMap { + case m: NormalFunction => m.environment.allConstants + case m: InlinedFunction => m.environment.allConstants + case m: ConstantThing => List(m) + case _ => Nil + }.toList + + def allocateVariables(nf: Option[NormalFunction], callGraph: CallGraph, allocator: VariableAllocator, options: CompilationOptions, onEachVariable: (String, Int) => Unit): Unit = { + val b = get[Type]("byte") + val p = get[Type]("pointer") + var params = nf.fold(List[String]()) { f => + f.params match { + case NormalParamSignature(ps) => + ps.map(p => p.name) + case _ => + Nil + } + }.toSet + val toAdd = things.values.flatMap { + case m: UninitializedMemory => + val vertex = if (options.flag(CompilationFlag.VariableOverlap)) { + nf.fold[VariableVertex](GlobalVertex) { f => + if (m.alloc == VariableAllocationMethod.Static) { + GlobalVertex + } else if (params(m.name)) { + ParamVertex(f.name) + } else { + LocalVertex(f.name) + } + } + } else GlobalVertex + m.alloc match { + case VariableAllocationMethod.None => + Nil + case VariableAllocationMethod.Zeropage => + m.sizeInBytes match { + case 2 => + val addr = + allocator.allocatePointer(callGraph, vertex) + onEachVariable(m.name, addr) + List( + ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p) + ) + } + case VariableAllocationMethod.Auto | VariableAllocationMethod.Static => + m.sizeInBytes match { + case 0 => Nil + case 2 => + val addr = + allocator.allocateBytes(callGraph, vertex, options, 2) + onEachVariable(m.name, addr) + List( + ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p) + ) + case count => + val addr = allocator.allocateBytes(callGraph, vertex, options, count) + onEachVariable(m.name, addr) + List( + ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p) + ) + } + } + case f: NormalFunction => + f.environment.allocateVariables(Some(f), callGraph, allocator, options, onEachVariable) + Nil + case _ => Nil + }.toList + val tagged: List[(String, Thing)] = toAdd.map(x => x.name -> x) + things ++= tagged + } + + val things: mutable.Map[String, Thing] = mutable.Map() + + private def addThing(t: Thing, position: Option[Position]): Unit = { + assertNotDefined(t.name, position) + things(t.name.stripPrefix(prefix)) = t + } + + def removeVariable(str: String): Unit = { + things -= str + things -= str + ".addr" + } + + def get[T <: Thing : Manifest](name: String, position: Option[Position] = None): T = { + val clazz = implicitly[Manifest[T]].runtimeClass + if (things.contains(name)) { + val t: Thing = things(name) + if ((t ne null) && clazz.isInstance(t)) { + t.asInstanceOf[T] + } else { + ErrorReporting.fatal(s"`$name` is not a ${clazz.getSimpleName}", position) + } + } else parent.fold { + ErrorReporting.fatal(s"${clazz.getSimpleName} `$name` is not defined", position) + } { + _.get[T](name, position) + } + } + + def maybeGet[T <: Thing : Manifest](name: String): Option[T] = { + if (things.contains(name)) { + val t: Thing = things(name) + val clazz = implicitly[Manifest[T]].runtimeClass + if ((t ne null) && clazz.isInstance(t)) { + Some(t.asInstanceOf[T]) + } else { + None + } + } else parent.flatMap { + _.maybeGet[T](name) + } + } + + def getArrayOrPointer(arrayName: String): Thing = { + maybeGet[ThingInMemory](arrayName). + orElse(maybeGet[ThingInMemory](arrayName + ".array")). + orElse(maybeGet[ConstantThing](arrayName)). + getOrElse(ErrorReporting.fatal(s"`$arrayName` is not an array or a pointer")) + } + + if (parent.isEmpty) { + addThing(VoidType, None) + addThing(BuiltInBooleanType, None) + addThing(BasicPlainType("byte", 1), None) + addThing(BasicPlainType("word", 2), None) + addThing(BasicPlainType("long", 4), None) + addThing(DerivedPlainType("pointer", get[PlainType]("word"), isSigned = false), None) + addThing(DerivedPlainType("ubyte", get[PlainType]("byte"), isSigned = false), None) + addThing(DerivedPlainType("sbyte", get[PlainType]("byte"), isSigned = true), None) + addThing(DerivedPlainType("cent", get[PlainType]("byte"), isSigned = false), None) + val trueType = ConstantBooleanType("true$", value = true) + val falseType = ConstantBooleanType("false$", value = false) + addThing(trueType, None) + addThing(falseType, None) + addThing(ConstantThing("true", NumericConstant(0, 0), trueType), None) + addThing(ConstantThing("false", NumericConstant(0, 0), falseType), None) + addThing(FlagBooleanType("set_carry", Opcode.BCS, Opcode.BCC), None) + addThing(FlagBooleanType("clear_carry", Opcode.BCC, Opcode.BCS), None) + addThing(FlagBooleanType("set_overflow", Opcode.BVS, Opcode.BVC), None) + addThing(FlagBooleanType("clear_overflow", Opcode.BVC, Opcode.BVS), None) + addThing(FlagBooleanType("set_zero", Opcode.BEQ, Opcode.BNE), None) + addThing(FlagBooleanType("clear_zero", Opcode.BNE, Opcode.BEQ), None) + addThing(FlagBooleanType("set_negative", Opcode.BMI, Opcode.BPL), None) + addThing(FlagBooleanType("clear_negative", Opcode.BPL, Opcode.BMI), None) + } + + def assertNotDefined(name: String, position: Option[Position]): Unit = { + if (things.contains(name) || parent.exists(_.things.contains(name))) + ErrorReporting.fatal(s"`$name` is already defined", position) + } + + def registerType(stmt: TypeDefinitionStatement): Unit = { + // addThing(DerivedPlainType(stmt.name, get(stmt.parent))) + ??? + } + + def sequence[A](a: List[Option[A]]): Option[List[A]] = a match { + case Nil => Some(Nil) + case None :: _ => None + case Some(r) :: t => sequence(t) map (r :: _) + } + + def evalVariableAndConstantSubParts(e: Expression): (Option[Expression], Constant) = + e match { + case SumExpression(params, false) => + val (constants, variables) = params.map { case (sign, expr) => (sign, expr, eval(expr)) }.partition(_._3.isDefined) + val constant = eval(SumExpression(constants.map(x => (x._1, x._2)), decimal = false)).get + val variable = variables match { + case Nil => None + case List((false, x, _)) => Some(x) + case _ => Some(SumExpression(variables.map(x => (x._1, x._2)), decimal = false)) + } + variable -> constant + case _ => eval(e) match { + case Some(c) => None -> c + case None => Some(e) -> Constant.Zero + } + } + + def eval(e: Expression): Option[Constant] = { + e match { + case LiteralExpression(value, size) => Some(NumericConstant(value, size)) + case VariableExpression(name) => + maybeGet[ConstantThing](name).map(_.value) + case IndexedExpression(_, _) => None + case HalfWordExpression(param, hi) => eval(e).map(c => if (hi) c.hiByte else c.loByte) + case SumExpression(params, decimal) => + params.map { + case (minus, param) => (minus, eval(param)) + }.foldLeft(Some(Constant.Zero).asInstanceOf[Option[Constant]]) { (oc, pair) => + oc.flatMap { c => + pair match { + case (_, None) => None + case (minus, Some(addend)) => + val op = if (decimal) { + if (minus) MathOperator.DecimalMinus else MathOperator.DecimalPlus + } else { + if (minus) MathOperator.Minus else MathOperator.Plus + } + Some(CompoundConstant(op, c, addend)) + } + } + } + case SeparateBytesExpression(h, l) => for { + lc <- eval(l) + hc <- eval(h) + } yield hc.asl(8) + lc + case FunctionCallExpression(name, params) => + name match { + case "*" => + constantOperation(MathOperator.Times, params) + case "&&" | "&" => + constantOperation(MathOperator.And, params) + case "^" => + constantOperation(MathOperator.Exor, params) + case "||" | "|" => + constantOperation(MathOperator.Or, params) + case _ => + None + } + } + } + + private def constantOperation(op: MathOperator.Value, params: List[Expression]) = { + params.map(eval(_)).reduceLeft[Option[Constant]] { (oc, om) => + for { + c <- oc + m <- om + } yield CompoundConstant(op, c, m) + } + } + + def registerFunction(stmt: FunctionDeclarationStatement, options: CompilationOptions): Unit = { + val w = get[Type]("word") + val name = stmt.name + val resultType = get[Type](stmt.resultType) + + if (stmt.reentrant && stmt.interrupt) ErrorReporting.error(s"Reentrant function `$name` cannot be an interrupt handler", stmt.position) + if (stmt.reentrant && stmt.params.nonEmpty) ErrorReporting.error(s"Reentrant function `$name` cannot have parameters", stmt.position) + if (stmt.interrupt && stmt.params.nonEmpty) ErrorReporting.error(s"Interrupt function `$name` cannot have parameters", stmt.position) + if (stmt.inlined) { + if (!stmt.assembly) { + if (stmt.params.nonEmpty) ErrorReporting.error(s"Inline non-assembly function `$name` cannot have parameters", stmt.position) // TODO: ??? + if (resultType != VoidType) ErrorReporting.error(s"Inline non-assembly function `$name` must return void", stmt.position) + } + if (stmt.params.exists(_.assemblyParamPassingConvention.inNonInlinedOnly)) + ErrorReporting.error(s"Inline function `$name` cannot have by-variable parameters", stmt.position) + } else { + if (!stmt.assembly) { + if (stmt.params.exists(!_.assemblyParamPassingConvention.isInstanceOf[ByVariable])) + ErrorReporting.error(s"Non-assembly function `$name` cannot have non-variable parameters", stmt.position) + } + if (stmt.params.exists(_.assemblyParamPassingConvention.inInlinedOnly)) + ErrorReporting.error(s"Non-inline function `$name` cannot have inlinable parameters", stmt.position) + } + + val env = new Environment(Some(this), name + "$") + stmt.params.foreach(p => env.registerParameter(p)) + val params = if (stmt.assembly) { + AssemblyParamSignature(stmt.params.map { + pd => + val typ = env.get[Type](pd.typ) + pd.assemblyParamPassingConvention match { + case ByVariable(vn) => + AssemblyParam(typ, env.get[MemoryVariable](vn), AssemblyParameterPassingBehaviour.Copy) + case ByRegister(reg) => + AssemblyParam(typ, RegisterVariable(reg, typ), AssemblyParameterPassingBehaviour.Copy) + case ByConstant(vn) => + AssemblyParam(typ, Placeholder(vn, typ), AssemblyParameterPassingBehaviour.ByConstant) + case ByReference(vn) => + AssemblyParam(typ, Placeholder(vn, typ), AssemblyParameterPassingBehaviour.ByReference) + } + }) + } else { + NormalParamSignature(stmt.params.map { pd => + env.get[MemoryVariable](pd.assemblyParamPassingConvention.asInstanceOf[ByVariable].name) + }) + } + stmt.statements match { + case None => + stmt.address match { + case None => + ErrorReporting.error(s"Extern function `${stmt.name}`needs an address", stmt.position) + case Some(a) => + val addr = eval(a).getOrElse(Constant.error(s"Address of `${stmt.name}` is not a constant", stmt.position)) + val mangled = ExternFunction( + name, + resultType, + params, + addr, + env + ) + addThing(mangled, stmt.position) + registerAddressConstant(mangled, stmt.position) + addThing(ConstantThing(name + '`', addr, w), stmt.position) + } + + case Some(statements) => + statements.foreach { + case v: VariableDeclarationStatement => env.registerVariable(v, options) + case _ => () + } + val executableStatements = statements.flatMap { + case e: ExecutableStatement => Some(e) + case _ => None + } + val needsExtraRTS = !stmt.inlined && !stmt.assembly && (statements.isEmpty || !statements.last.isInstanceOf[ReturnStatement]) + if (stmt.inlined) { + val mangled = new InlinedFunction( + name, + resultType, + params, + env, + executableStatements ++ (if (needsExtraRTS) List(AssemblyStatement.implied(Opcode.RTS, elidable = true)) else Nil), + ) + addThing(mangled, stmt.position) + } else { + var stackVariablesSize = env.things.values.map { + case StackVariable(n, t, _) if !n.contains(".") => t.size + case _ => 0 + }.sum + val mangled = NormalFunction( + name, + resultType, + params, + env, + stackVariablesSize, + stmt.address.map(a => this.eval(a).getOrElse(Constant.error(s"Address of `${stmt.name}` is not a constant"))), + executableStatements ++ (if (needsExtraRTS) List(ReturnStatement(None)) else Nil), + interrupt = stmt.interrupt, + reentrant = stmt.reentrant, + position = stmt.position + ) + addThing(mangled, stmt.position) + registerAddressConstant(mangled, stmt.position) + } + } + } + + private def registerAddressConstant(thing: ThingInMemory, position: Option[Position]): Unit = { + val addr = thing.toAddress + addThing(ConstantThing(thing.name + ".addr", addr, get[Type]("pointer")), position) + addThing(ConstantThing(thing.name + ".addr.hi", addr.hiByte, get[Type]("byte")), position) + addThing(ConstantThing(thing.name + ".addr.lo", addr.loByte, get[Type]("byte")), position) + } + + def registerParameter(stmt: ParameterDeclaration): Unit = { + val typ = get[Type](stmt.typ) + val b = get[Type]("byte") + val p = get[Type]("pointer") + stmt.assemblyParamPassingConvention match { + case ByVariable(name) => + val zp = typ.name == "pointer" // TODO + val v = MemoryVariable(prefix + name, typ, if (zp) VariableAllocationMethod.Zeropage else VariableAllocationMethod.Auto) + addThing(v, stmt.position) + registerAddressConstant(v, stmt.position) + if (typ.size == 2) { + val addr = v.toAddress + addThing(RelativeVariable(v.name + ".hi", addr + 1, b, zeropage = zp), stmt.position) + addThing(RelativeVariable(v.name + ".lo", addr, b, zeropage = zp), stmt.position) + } + case ByRegister(_) => () + case ByConstant(name) => + val v = ConstantThing(prefix + name, UnexpandedConstant(prefix + name, typ.size), typ) + addThing(v, stmt.position) + case ByReference(name) => + val addr = UnexpandedConstant(prefix + name, typ.size) + val v = RelativeVariable(prefix + name, addr, p, zeropage = false) + addThing(v, stmt.position) + addThing(RelativeVariable(v.name + ".hi", addr + 1, b, zeropage = false), stmt.position) + addThing(RelativeVariable(v.name + ".lo", addr, b, zeropage = false), stmt.position) + } + } + + def registerArray(stmt: ArrayDeclarationStatement): Unit = { + val b = get[Type]("byte") + val p = get[Type]("pointer") + stmt.elements match { + case None => + stmt.length match { + case None => ErrorReporting.error(s"Array `${stmt.name}` without size nor contents", stmt.position) + case Some(l) => + val address = stmt.address.map(a => eval(a).getOrElse(ErrorReporting.fatal(s"Array `${stmt.name}` has non-constant address", stmt.position))) + val lengthConst = eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position)) + lengthConst match { + case NumericConstant(length, _) => + if (length > 0xffff || length < 0) ErrorReporting.error(s"Array `${stmt.name}` has invalid length", stmt.position) + val array = address match { + case None => UninitializedArray(stmt.name + ".array", length.toInt) + case Some(aa) => RelativeArray(stmt.name + ".array", aa, length.toInt) + } + addThing(array, stmt.position) + registerAddressConstant(MemoryVariable(stmt.name, p, VariableAllocationMethod.None), stmt.position) + val a = address match { + case None => array.toAddress + case Some(aa) => aa + } + addThing(RelativeVariable(stmt.name + ".first", a, b, zeropage = false), stmt.position) + addThing(ConstantThing(stmt.name, a, p), stmt.position) + addThing(ConstantThing(stmt.name + ".hi", a.hiByte, b), stmt.position) + addThing(ConstantThing(stmt.name + ".lo", a.loByte, b), stmt.position) + addThing(ConstantThing(stmt.name + ".array.hi", a.hiByte, b), stmt.position) + addThing(ConstantThing(stmt.name + ".array.lo", a.loByte, b), stmt.position) + if (length < 256) { + addThing(ConstantThing(stmt.name + ".length", lengthConst, b), stmt.position) + } + case _ => ErrorReporting.error(s"Array `${stmt.name}` has weird length", stmt.position) + } + } + case Some(contents) => + stmt.length match { + case None => + case Some(l) => + val lengthConst = eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position)) + lengthConst match { + case NumericConstant(ll, _) => + if (ll != contents.length) ErrorReporting.error(s"Array `${stmt.name}` has different declared and actual length", stmt.position) + case _ => ErrorReporting.error(s"Array `${stmt.name}` has weird length", stmt.position) + } + } + val length = contents.length + if (length > 0xffff || length < 0) ErrorReporting.error(s"Array `${stmt.name}` has invalid length", stmt.position) + val address = stmt.address.map(a => eval(a).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant address", stmt.position))) + val data = contents.map(x => eval(x).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant contents", stmt.position))) + val array = InitializedArray(stmt.name + ".array", address, data) + addThing(array, stmt.position) + registerAddressConstant(MemoryVariable(stmt.name, p, VariableAllocationMethod.None), stmt.position) + val a = address match { + case None => array.toAddress + case Some(aa) => aa + } + addThing(RelativeVariable(stmt.name + ".first", a, b, zeropage = false), stmt.position) + addThing(ConstantThing(stmt.name, a, p), stmt.position) + addThing(ConstantThing(stmt.name + ".hi", a.hiByte, b), stmt.position) + addThing(ConstantThing(stmt.name + ".lo", a.loByte, b), stmt.position) + addThing(ConstantThing(stmt.name + ".array.hi", a.hiByte, b), stmt.position) + addThing(ConstantThing(stmt.name + ".array.lo", a.loByte, b), stmt.position) + if (length < 256) { + addThing(ConstantThing(stmt.name + ".length", NumericConstant(length, 1), b), stmt.position) + } + } + } + + def registerVariable(stmt: VariableDeclarationStatement, options: CompilationOptions): Unit = { + if (stmt.volatile) { + ErrorReporting.warn("`volatile` not yet supported", options) + } + val name = stmt.name + val position = stmt.position + if (stmt.stack && parent.isEmpty) { + if (stmt.stack && stmt.global) ErrorReporting.error(s"`$name` is static or global and cannot be on stack", position) + } + val b = get[Type]("byte") + val typ = get[PlainType](stmt.typ) + if (stmt.typ == "pointer") { + // if (stmt.constant) { + // ErrorReporting.error(s"Pointer `${stmt.name}` cannot be constant") + // } + stmt.address.flatMap(eval) match { + case Some(NumericConstant(a, _)) => + if ((a & 0xff00) != 0) + ErrorReporting.error(s"Pointer `${stmt.name}` cannot be located outside the zero page") + case _ => () + } + } + if (stmt.constant) { + if (stmt.stack) ErrorReporting.error(s"`$name` is a constant and cannot be on stack", position) + if (stmt.address.isDefined) ErrorReporting.error(s"`$name` is a constant and cannot have an address", position) + if (stmt.initialValue.isEmpty) ErrorReporting.error(s"`$name` is a constant and requires a value", position) + val constantValue: Constant = stmt.initialValue.flatMap(eval).getOrElse(Constant.error(s"`$name` has a non-constant value", position)) + if (constantValue.requiredSize > typ.size) ErrorReporting.error(s"`$name` is has an invalid value: not in the range of `$typ`", position) + addThing(ConstantThing(prefix + name, constantValue, typ), stmt.position) + if (typ.size == 2) { + addThing(ConstantThing(prefix + name + ".hi", constantValue + 1, b), stmt.position) + addThing(ConstantThing(prefix + name + ".lo", constantValue, b), stmt.position) + } + } else { + if (stmt.stack && stmt.global) ErrorReporting.error(s"`$name` is static or global and cannot be on stack", position) + if (stmt.initialValue.isDefined) ErrorReporting.error(s"`$name` is not a constant and cannot have a value", position) + if (stmt.stack) { + val v = StackVariable(prefix + name, typ, this.baseStackOffset) + baseStackOffset += typ.size + addThing(v, stmt.position) + if (typ.size == 2) { + addThing(StackVariable(prefix + name + ".lo", b, baseStackOffset), stmt.position) + addThing(StackVariable(prefix + name + ".hi", b, baseStackOffset + 1), stmt.position) + } + } else { + val (v, addr) = stmt.address.fold[(VariableInMemory, Constant)]({ + val alloc = if (typ.name == "pointer") VariableAllocationMethod.Zeropage else if (stmt.global) VariableAllocationMethod.Static else VariableAllocationMethod.Auto + val v = MemoryVariable(prefix + name, typ, alloc) + registerAddressConstant(v, stmt.position) + (v, v.toAddress) + })(a => { + val addr = eval(a).getOrElse(Constant.error(s"Address of `$name` has a non-constant value", position)) + val zp = addr match { + case NumericConstant(n, _) => n < 0x100 + case _ => false + } + (RelativeVariable(prefix + name, addr, typ, zeropage = zp), addr) + }) + addThing(v, stmt.position) + if (!v.isInstanceOf[MemoryVariable]) { + addThing(ConstantThing(v.name + "`", addr, b), stmt.position) + } + if (typ.size == 2) { + addThing(RelativeVariable(prefix + name + ".hi", addr + 1, b, zeropage = v.zeropage), stmt.position) + addThing(RelativeVariable(prefix + name + ".lo", addr, b, zeropage = v.zeropage), stmt.position) + } + } + } + } + + def lookup[T <: Thing : Manifest](name: String): Option[T] = { + if (things.contains(name)) { + maybeGet(name) + } else { + parent.flatMap(_.lookup[T](name)) + } + } + + def lookupFunction(name: String, actualParams: List[(Type, Expression)]): Option[MangledFunction] = { + if (things.contains(name)) { + val function = get[MangledFunction](name) + if (function.params.length != actualParams.length) { + ErrorReporting.error(s"Invalid number of parameters for function `$name`", actualParams.headOption.flatMap(_._2.position)) + } + function.params match { + case NormalParamSignature(params) => + function.params.types.zip(actualParams).zip(params).foreach { case ((required, (actual, expr)), m) => + if (!actual.isAssignableTo(required)) { + ErrorReporting.error(s"Invalid value for parameter `${m.name}` of function `$name`", expr.position) + } + } + case AssemblyParamSignature(params) => + function.params.types.zip(actualParams).zipWithIndex.foreach { case ((required, (actual, expr)), ix) => + if (!actual.isAssignableTo(required)) { + ErrorReporting.error(s"Invalid value for parameter ${ix + 1} of function `$name`", expr.position) + } + } + } + Some(function) + } else { + parent.flatMap(_.lookupFunction(name, actualParams)) + } + } + + def collectDeclarations(program: Program, options: CompilationOptions): Unit = { + program.declarations.foreach { + case f: FunctionDeclarationStatement => registerFunction(f, options) + case v: VariableDeclarationStatement => registerVariable(v, options) + case a: ArrayDeclarationStatement => registerArray(a) + case i: ImportStatement => () + } + } +} diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala new file mode 100644 index 00000000..38d13035 --- /dev/null +++ b/src/main/scala/millfork/env/Thing.scala @@ -0,0 +1,264 @@ +package millfork.env + +import millfork.assembly.Opcode +import millfork.error.ErrorReporting +import millfork.node._ + +sealed trait Thing { + def name: String +} + +sealed trait Type extends Thing { + + def size: Int + + def isSigned: Boolean + + def isSubtypeOf(other: Type): Boolean = this == other + + def isCompatible(other: Type): Boolean = this == other + + override def toString(): String = name + + def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType) +} + +case object VoidType extends Type { + def size = 0 + + def isSigned = false + + override def name = "void" +} + +sealed trait PlainType extends Type { + override def isCompatible(other: Type): Boolean = this == other || this.isSubtypeOf(other) || other.isSubtypeOf(this) + + override def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType) || (targetType match { + case BasicPlainType(_, size) => size > this.size // TODO + case _ => false + }) +} + +case class BasicPlainType(name: String, size: Int) extends PlainType { + def isSigned = false + + override def isSubtypeOf(other: Type): Boolean = this == other +} + +case class DerivedPlainType(name: String, parent: PlainType, isSigned: Boolean) extends PlainType { + def size: Int = parent.size + + override def isSubtypeOf(other: Type): Boolean = parent == other || parent.isSubtypeOf(other) +} + +sealed trait BooleanType extends Type { + def size = 0 + + def isSigned = false +} + +case class ConstantBooleanType(name: String, value: Boolean) extends BooleanType + +case class FlagBooleanType(name: String, jumpIfTrue: Opcode.Value, jumpIfFalse: Opcode.Value) extends BooleanType + +case object BuiltInBooleanType extends BooleanType { + override def name = "bool$" +} + +sealed trait TypedThing extends Thing { + def typ: Type +} + + +sealed trait ThingInMemory extends Thing { + def toAddress: Constant +} + +sealed trait PrellocableThing extends ThingInMemory { + def shouldGenerate: Boolean + + def address: Option[Constant] + + def toAddress: Constant = address.getOrElse(MemoryAddressConstant(this)) +} + +case class Label(name: String) extends ThingInMemory { + override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this) +} + +sealed trait Variable extends TypedThing + +case class BlackHole(typ: Type) extends Variable { + override def name = "" +} + +sealed trait VariableInMemory extends Variable with ThingInMemory { + + def zeropage: Boolean +} + +case class RegisterVariable(register: Register.Value, typ: Type) extends Variable { + def name: String = register.toString +} + +case class Placeholder(name: String, typ: Type) extends Variable + +sealed trait UninitializedMemory extends ThingInMemory { + def sizeInBytes: Int + + def alloc: VariableAllocationMethod.Value +} + +object VariableAllocationMethod extends Enumeration { + val Auto, Static, Zeropage, None = Value +} + +case class StackVariable(name: String, typ: Type, baseOffset: Int) extends Variable { + def sizeInBytes: Int = typ.size +} + +case class MemoryVariable(name: String, typ: Type, alloc: VariableAllocationMethod.Value) extends VariableInMemory with UninitializedMemory { + override def sizeInBytes: Int = typ.size + + override def zeropage: Boolean = alloc == VariableAllocationMethod.Zeropage + + override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this) +} + +trait MlArray extends ThingInMemory + +case class UninitializedArray(name: String, sizeInBytes: Int) extends MlArray with UninitializedMemory { + override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this) + + override def alloc = VariableAllocationMethod.Static +} + +case class RelativeArray(name: String, address: Constant, sizeInBytes: Int) extends MlArray { + override def toAddress: Constant = address +} + +case class InitializedArray(name: String, address: Option[Constant], contents: List[Constant]) extends MlArray with PrellocableThing { + override def shouldGenerate = true +} + +case class RelativeVariable(name: String, address: Constant, typ: Type, zeropage: Boolean) extends VariableInMemory { + override def toAddress: Constant = address +} + + +sealed trait MangledFunction extends Thing { + def name: String + + def returnType: Type + + def params: ParamSignature + + def interrupt: Boolean +} + +case class EmptyFunction(name: String, + returnType: Type, + paramType: Type) extends MangledFunction { + override def params = EmptyFunctionParamSignature(paramType) + + override def interrupt = false +} + +case class InlinedFunction(name: String, + returnType: Type, + params: ParamSignature, + environment: Environment, + code: List[ExecutableStatement]) extends MangledFunction { + override def interrupt = false +} + +sealed trait FunctionInMemory extends MangledFunction with ThingInMemory { + def environment: Environment +} + +case class ExternFunction(name: String, + returnType: Type, + params: ParamSignature, + address: Constant, + environment: Environment) extends FunctionInMemory { + override def toAddress: Constant = address + + override def interrupt = false +} + +case class NormalFunction(name: String, + returnType: Type, + params: ParamSignature, + environment: Environment, + stackVariablesSize: Int, + address: Option[Constant], + code: List[ExecutableStatement], + interrupt: Boolean, + reentrant: Boolean, + position: Option[Position]) extends FunctionInMemory with PrellocableThing { + override def shouldGenerate = true +} + +case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing + +trait ParamSignature { + def types: List[Type] + + def length: Int +} + +case class NormalParamSignature(params: List[MemoryVariable]) extends ParamSignature { + override def length: Int = params.length + + override def types: List[Type] = params.map(_.typ) +} + +sealed trait ParamPassingConvention { + def inInlinedOnly: Boolean + + def inNonInlinedOnly: Boolean +} + +case class ByRegister(register: Register.Value) extends ParamPassingConvention { + override def inInlinedOnly = false + + override def inNonInlinedOnly = false +} + +case class ByVariable(name: String) extends ParamPassingConvention { + override def inInlinedOnly = false + + override def inNonInlinedOnly = true +} + +case class ByConstant(name: String) extends ParamPassingConvention { + override def inInlinedOnly = true + + override def inNonInlinedOnly = false +} + +case class ByReference(name: String) extends ParamPassingConvention { + override def inInlinedOnly = true + + override def inNonInlinedOnly = false +} + +object AssemblyParameterPassingBehaviour extends Enumeration { + val Copy, ByReference, ByConstant = Value +} + +case class AssemblyParam(typ: Type, variable: TypedThing, behaviour: AssemblyParameterPassingBehaviour.Value) + + +case class AssemblyParamSignature(params: List[AssemblyParam]) extends ParamSignature { + override def length: Int = params.length + + override def types: List[Type] = params.map(_.typ) +} + +case class EmptyFunctionParamSignature(paramType: Type) extends ParamSignature { + override def length: Int = 1 + + override def types: List[Type] = List(paramType) +} \ No newline at end of file diff --git a/src/main/scala/millfork/error/ErrorReporting.scala b/src/main/scala/millfork/error/ErrorReporting.scala new file mode 100644 index 00000000..a3cdbbb9 --- /dev/null +++ b/src/main/scala/millfork/error/ErrorReporting.scala @@ -0,0 +1,74 @@ +package millfork.error + +import millfork.{CompilationFlag, CompilationOptions} +import millfork.node.Position + +object ErrorReporting { + + var verbosity = 0 + + var hasErrors = false + + def f(position: Option[Position]): String = position.fold("")(p => s"(${p.line}:${p.column}) ") + + def info(msg: String, position: Option[Position] = None): Unit = { + if (verbosity < 0) return + println("INFO: " + f(position) + msg) + flushOutput() + } + + def debug(msg: String, position: Option[Position] = None): Unit = { + if (verbosity < 1) return + println("DEBUG: " + f(position) + msg) + flushOutput() + } + + def trace(msg: String, position: Option[Position] = None): Unit = { + if (verbosity < 2) return + println("TRACE: " + f(position) + msg) + flushOutput() + } + + private def flushOutput(): Unit = { + System.out.flush() + System.err.flush() + } + + def warn(msg: String, options: CompilationOptions, position: Option[Position] = None): Unit = { + if (verbosity < 0) return + println("WARN: " + f(position) + msg) + flushOutput() + if (options.flag(CompilationFlag.FatalWarnings)) { + hasErrors = true + } + } + + def error(msg: String, position: Option[Position] = None): Unit = { + hasErrors = true + println("ERROR: " + f(position) + msg) + flushOutput() + } + + def fatal(msg: String, position: Option[Position] = None): Nothing = { + hasErrors = true + println("FATAL: " + f(position) + msg) + flushOutput() + throw new RuntimeException(msg) + } + + def fatalQuit(msg: String, position: Option[Position] = None): Nothing = { + hasErrors = true + println("FATAL: " + f(position) + msg) + flushOutput() + System.exit(1) + throw new RuntimeException(msg) + } + + def assertNoErrors(msg: String): Unit = { + if (hasErrors) { + error(msg) + fatal("Build halted due to previous errors") + } + } + +} \ No newline at end of file diff --git a/src/main/scala/millfork/node/CallGraph.scala b/src/main/scala/millfork/node/CallGraph.scala new file mode 100644 index 00000000..9fc994bd --- /dev/null +++ b/src/main/scala/millfork/node/CallGraph.scala @@ -0,0 +1,151 @@ +package millfork.node + +import millfork.error.ErrorReporting + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ + +sealed trait VariableVertex { + def function: String +} + +case class ParamVertex(function: String) extends VariableVertex + +case class LocalVertex(function: String) extends VariableVertex + +case object GlobalVertex extends VariableVertex { + override def function = "" +} + +trait CallGraph { + def canOverlap(a: VariableVertex, b: VariableVertex): Boolean +} + +object RestrictiveCallGraph extends CallGraph { + + def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false +} + +class StandardCallGraph(program: Program) extends CallGraph { + + private val entryPoints = mutable.Set[String]() + // (F,G) means function F calls function G + private val callEdges = mutable.Set[(String, String)]() + // (F,G) means function G is called when building parameters for function F + private val paramEdges = mutable.Set[(String, String)]() + private val multiaccessibleFunctions = mutable.Set[String]() + private val everCalledFunctions = mutable.Set[String]() + private val allFunctions = mutable.Set[String]() + + entryPoints += "main" + program.declarations.foreach(s => add(None, Nil, s)) + everCalledFunctions.retain(allFunctions) + + def add(currentFunction: Option[String], callingFunctions: List[String], node: Node): Unit = { + node match { + case f: FunctionDeclarationStatement => + allFunctions += f.name + if (f.address.isDefined || f.interrupt) entryPoints += f.name + f.statements.getOrElse(Nil).foreach(s => this.add(Some(f.name), Nil, s)) + case s: Statement => + s.getAllExpressions.foreach(e => add(currentFunction, callingFunctions, e)) + case g: FunctionCallExpression => + everCalledFunctions += g.functionName + currentFunction.foreach(f => callEdges += f -> g.functionName) + callingFunctions.foreach(f => paramEdges += f -> g.functionName) + g.expressions.foreach(expr => add(currentFunction, g.functionName :: callingFunctions, expr)) + case x: VariableExpression => + val varName = x.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr") + everCalledFunctions += varName + case _ => () + } + } + + + def fillOut(): Unit = { + var changed = true + while (changed) { + changed = false + val toAdd = for { + (a, b) <- callEdges + (c, d) <- callEdges + if b == c + if !callEdges.contains(a -> d) + } yield (a, d) + if (toAdd.nonEmpty) { + callEdges ++= toAdd + changed = true + } + } + + changed = true + while (changed) { + changed = false + val toAdd = for { + (a, b) <- paramEdges + (c, d) <- callEdges + if b == c + if !paramEdges.contains(a -> d) + } yield (a, d) + if (toAdd.nonEmpty) { + paramEdges ++= toAdd + changed = true + } + } + multiaccessibleFunctions ++= entryPoints + everCalledFunctions ++= entryPoints + callEdges.filter(e => entryPoints.contains(e._1)).foreach(e => everCalledFunctions += e._2) + multiaccessibleFunctions ++= callEdges.filter(e => entryPoints.contains(e._1)).map(_._2).groupBy(identity).filter(p => p._2.size > 1).keys + + ErrorReporting.trace("Call edges:") + callEdges.toList.sorted.foreach(s => ErrorReporting.trace(s.toString)) + + ErrorReporting.trace("Param edges:") + paramEdges.toList.sorted.foreach(s => ErrorReporting.trace(s.toString)) + + ErrorReporting.trace("Entry points:") + entryPoints.toList.sorted.foreach(ErrorReporting.trace(_)) + + ErrorReporting.trace("Multiaccessible functions:") + multiaccessibleFunctions.toList.sorted.foreach(ErrorReporting.trace(_)) + + ErrorReporting.trace("Ever called functions:") + everCalledFunctions.toList.sorted.foreach(ErrorReporting.trace(_)) + } + + def isEverCalled(function: String): Boolean = { + everCalledFunctions(function) + } + + def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = { + if (a.function == b.function) { + return false + } + if (a == GlobalVertex || b == GlobalVertex) { + return false + } + if (multiaccessibleFunctions(a.function) || multiaccessibleFunctions(b.function)) { + return false + } + if (callEdges(a.function -> b.function) || callEdges(b.function -> a.function)) { + return false + } + a match { + case ParamVertex(af) => + if (paramEdges(af -> b.function)) return false + case _ => + } + b match { + case ParamVertex(bf) => + if (paramEdges(bf -> a.function)) return false + case _ => + } + ErrorReporting.trace(s"$a and $b can overlap") + true + } + + +} diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala new file mode 100644 index 00000000..51d1367f --- /dev/null +++ b/src/main/scala/millfork/node/Node.scala @@ -0,0 +1,181 @@ +package millfork.node + +import millfork.assembly.{AddrMode, Opcode} +import millfork.env.{Label, ParamPassingConvention} + +case class Position(filename: String, line: Int, column: Int, cursor: Int) + +sealed trait Node { + var position: Option[Position] = None +} + +object Node { + implicit class NodeOps[N<:Node](val node: N) extends AnyVal { + def pos(position: Position): N = { + node.position = Some(position) + node + } + } +} + +sealed trait Expression extends Node { + def replaceVariable(variable: String, actualParam: Expression): Expression +} + +case class LiteralExpression(value: Long, requiredSize: Int) extends Expression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = this +} + +case class BooleanLiteralExpression(value: Boolean) extends Expression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = this +} + +sealed trait LhsExpression extends Expression + +case object BlackHoleExpression extends LhsExpression { + override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this +} + +case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsExpression { + def replaceVariable(variable: String, actualParam: Expression): Expression = + SeparateBytesExpression( + hi.replaceVariable(variable, actualParam), + lo.replaceVariable(variable, actualParam)) +} + +case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = + SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal) +} + +case class FunctionCallExpression(functionName: String, expressions: List[Expression]) extends Expression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = + FunctionCallExpression(functionName, expressions.map { + _.replaceVariable(variable, actualParam) + }) +} + +case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = + HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte) +} + +object Register extends Enumeration { + val A, X, Y, AX, AY, YA, XA, XY, YX = Value +} + +//case class Indexing(child: Expression, register: Register.Value) extends Expression + +case class VariableExpression(name: String) extends LhsExpression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = + if (name == variable) actualParam else this +} + +case class IndexedExpression(name: String, index: Expression) extends LhsExpression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = + if (name == variable) { + actualParam match { + case VariableExpression(actualVariable) => IndexedExpression(actualVariable, index.replaceVariable(variable, actualParam)) + case _ => ??? // TODO + } + } else IndexedExpression(name, index.replaceVariable(variable, actualParam)) +} + +sealed trait Statement extends Node { + def getAllExpressions: List[Expression] +} + +sealed trait DeclarationStatement extends Statement + +case class TypeDefinitionStatement(name: String, parent: String) extends DeclarationStatement { + override def getAllExpressions: List[Expression] = Nil +} + +case class VariableDeclarationStatement(name: String, + typ: String, + global: Boolean, + stack: Boolean, + constant: Boolean, + volatile: Boolean, + initialValue: Option[Expression], + address: Option[Expression]) extends DeclarationStatement { + override def getAllExpressions: List[Expression] = List(initialValue, address).flatten +} + +case class ArrayDeclarationStatement(name: String, + length: Option[Expression], + address: Option[Expression], + elements: Option[List[Expression]]) extends DeclarationStatement { + override def getAllExpressions: List[Expression] = List(length, address).flatten ++ elements.getOrElse(Nil) +} + +case class ParameterDeclaration(typ: String, + assemblyParamPassingConvention: ParamPassingConvention) extends Node + +case class ImportStatement(filename: String) extends DeclarationStatement { + override def getAllExpressions: List[Expression] = Nil +} + +case class FunctionDeclarationStatement(name: String, + resultType: String, + params: List[ParameterDeclaration], + address: Option[Expression], + statements: Option[List[Statement]], + inlined: Boolean, + assembly: Boolean, + interrupt: Boolean, + reentrant: Boolean) extends DeclarationStatement { + override def getAllExpressions: List[Expression] = address.toList ++ statements.getOrElse(Nil).flatMap(_.getAllExpressions) +} + +sealed trait ExecutableStatement extends Statement + +case class ExpressionStatement(expression: Expression) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = List(expression) +} + +case class ReturnStatement(value: Option[Expression]) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = value.toList +} + +case class Assignment(destination: LhsExpression, source: Expression) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = List(destination, source) +} + +case class LabelStatement(label: Label) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = Nil +} + +case class AssemblyStatement(opcode: Opcode.Value, addrMode: AddrMode.Value, expression: Expression, elidable: Boolean) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = List(expression) +} + +case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = condition :: (thenBranch ++ elseBranch).flatMap(_.getAllExpressions) +} + +case class WhileStatement(condition: Expression, body: List[ExecutableStatement]) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions) +} + +object ForDirection extends Enumeration { + val To, Until, DownTo, ParallelTo, ParallelUntil = Value +} + +case class ForStatement(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: List[ExecutableStatement]) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = start :: end :: body.flatMap(_.getAllExpressions) +} + +case class DoWhileStatement(body: List[ExecutableStatement], condition: Expression) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions) +} + +case class BlockStatement(body: List[ExecutableStatement]) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = body.flatMap(_.getAllExpressions) +} + +object AssemblyStatement { + def implied(opcode: Opcode.Value, elidable: Boolean) = AssemblyStatement(opcode, AddrMode.Implied, LiteralExpression(0, 1), elidable) + + def nonexistent(opcode: Opcode.Value) = AssemblyStatement(opcode, AddrMode.DoesNotExist, LiteralExpression(0, 1), elidable = true) +} \ No newline at end of file diff --git a/src/main/scala/millfork/node/Program.scala b/src/main/scala/millfork/node/Program.scala new file mode 100644 index 00000000..89f1e12a --- /dev/null +++ b/src/main/scala/millfork/node/Program.scala @@ -0,0 +1,11 @@ +package millfork.node + +import millfork.node.opt.NodeOptimization + +/** + * @author Karol Stasiak + */ +case class Program(declarations: List[DeclarationStatement]) { + def applyNodeOptimization(o: NodeOptimization) = Program(o.optimize(declarations).asInstanceOf[List[DeclarationStatement]]) + def +(p:Program): Program = Program(this.declarations ++ p.declarations) +} diff --git a/src/main/scala/millfork/node/opt/NodeOptimization.scala b/src/main/scala/millfork/node/opt/NodeOptimization.scala new file mode 100644 index 00000000..94421e8e --- /dev/null +++ b/src/main/scala/millfork/node/opt/NodeOptimization.scala @@ -0,0 +1,16 @@ +package millfork.node.opt + +import millfork.node.{ExecutableStatement, Expression, Node, Statement} + +/** + * @author Karol Stasiak + */ +trait NodeOptimization { + def optimize(nodes: List[Node]): List[Node] + + def optimizeExecutableStatements(nodes: List[ExecutableStatement]): List[ExecutableStatement] = + optimize(nodes).asInstanceOf[List[ExecutableStatement]] + + def optimizeStatements(nodes: List[Statement]): List[Statement] = + optimize(nodes).asInstanceOf[List[Statement]] +} diff --git a/src/main/scala/millfork/node/opt/UnreachableCode.scala b/src/main/scala/millfork/node/opt/UnreachableCode.scala new file mode 100644 index 00000000..fbff2a7b --- /dev/null +++ b/src/main/scala/millfork/node/opt/UnreachableCode.scala @@ -0,0 +1,29 @@ +package millfork.node.opt + +import millfork.node._ + +/** + * @author Karol Stasiak + */ +object UnreachableCode extends NodeOptimization { + + override def optimize(nodes: List[Node]): List[Node] = nodes match { + case (x:FunctionDeclarationStatement)::xs => + x.copy(statements = x.statements.map(optimizeStatements)) :: optimize(xs) + case (x:IfStatement)::xs => + x.copy( + thenBranch = optimizeExecutableStatements(x.thenBranch), + elseBranch = optimizeExecutableStatements(x.elseBranch)) :: optimize(xs) + case (x:WhileStatement)::xs => + x.copy(body = optimizeExecutableStatements(x.body)) :: optimize(xs) + case (x:DoWhileStatement)::xs => + x.copy(body = optimizeExecutableStatements(x.body)) :: optimize(xs) + case (x:ReturnStatement) :: xs => + x :: Nil + case x :: xs => + x :: optimize(xs) + case Nil => + Nil + } + +} diff --git a/src/main/scala/millfork/node/opt/UnusedFunctions.scala b/src/main/scala/millfork/node/opt/UnusedFunctions.scala new file mode 100644 index 00000000..bbb158f8 --- /dev/null +++ b/src/main/scala/millfork/node/opt/UnusedFunctions.scala @@ -0,0 +1,72 @@ +package millfork.node.opt + +import millfork.env._ +import millfork.error.ErrorReporting +import millfork.node._ + +/** + * @author Karol Stasiak + */ +object UnusedFunctions extends NodeOptimization { + + override def optimize(nodes: List[Node]): List[Node] = { + val allNormalFunctions = nodes.flatMap { + case v: FunctionDeclarationStatement => if (v.address.isDefined || v.interrupt || v.name == "main") Nil else List(v.name) + case _ => Nil + }.toSet + val allCalledFunctions = getAllCalledFunctions(nodes).toSet + val unusedFunctions = allNormalFunctions -- allCalledFunctions + if (unusedFunctions.nonEmpty) { + ErrorReporting.debug("Removing unused functions: " + unusedFunctions.mkString(", ")) + } + removeFunctionsFromProgram(nodes, unusedFunctions) + } + + private def removeFunctionsFromProgram(nodes: List[Node], unusedVariables: Set[String]): List[Node] = { + nodes match { + case (x: FunctionDeclarationStatement) :: xs if unusedVariables(x.name) => + removeFunctionsFromProgram(xs, unusedVariables) + case x :: xs => + x :: removeFunctionsFromProgram(xs, unusedVariables) + case Nil => + Nil + } + } + + def getAllCalledFunctions(c: Constant): List[String] = c match { + case HalfWordConstant(cc, _) => getAllCalledFunctions(cc) + case SubbyteConstant(cc, _) => getAllCalledFunctions(cc) + case CompoundConstant(_, l, r) => getAllCalledFunctions(l) ++ getAllCalledFunctions(r) + case MemoryAddressConstant(th) => List( + th.name, + th.name.stripSuffix(".addr"), + th.name.stripSuffix(".hi"), + th.name.stripSuffix(".lo"), + th.name.stripSuffix(".addr.lo"), + th.name.stripSuffix(".addr.hi")) + case _ => Nil + } + + def getAllCalledFunctions(expressions: List[Node]): List[String] = expressions.flatMap { + case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList) + case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil)) + case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) + case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) + case s: Statement => getAllCalledFunctions(s.getAllExpressions) + case s: VariableExpression => List( + s.name, + s.name.stripSuffix(".addr"), + s.name.stripSuffix(".hi"), + s.name.stripSuffix(".lo"), + s.name.stripSuffix(".addr.lo"), + s.name.stripSuffix(".addr.hi")) + case s: LiteralExpression => Nil + case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil) + case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2)) + case FunctionCallExpression(name, xs) => name :: getAllCalledFunctions(xs) + case IndexedExpression(arr, index) => arr :: getAllCalledFunctions(List(index)) + case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l)) + case _ => Nil + } + +} diff --git a/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala b/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala new file mode 100644 index 00000000..83025f09 --- /dev/null +++ b/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala @@ -0,0 +1,104 @@ +package millfork.node.opt + +import millfork.env._ +import millfork.error.ErrorReporting +import millfork.node._ + +/** + * @author Karol Stasiak + */ +object UnusedGlobalVariables extends NodeOptimization { + + override def optimize(nodes: List[Node]): List[Node] = { + + // TODO: volatile + val allNonvolatileGlobalVariables = nodes.flatMap { + case v: VariableDeclarationStatement => if (v.address.isDefined) Nil else List(v.name) + case v: ArrayDeclarationStatement => if (v.address.isDefined) Nil else List(v.name) + case _ => Nil + }.toSet + val allReadVariables = getAllReadVariables(nodes).toSet + val unusedVariables = allNonvolatileGlobalVariables -- allReadVariables + if (unusedVariables.nonEmpty) { + ErrorReporting.debug("Removing unused global variables: " + unusedVariables.mkString(", ")) + } + removeVariablesFromProgram(nodes, unusedVariables.flatMap(v => Set(v, v + ".hi", v + ".lo"))) + } + + private def removeVariablesFromProgram(nodes: List[Node], unusedVariables: Set[String]): List[Node] = { + nodes match { + case (x: ArrayDeclarationStatement) :: xs if unusedVariables(x.name) => removeVariablesFromProgram(xs, unusedVariables) + case (x: VariableDeclarationStatement) :: xs if unusedVariables(x.name) => removeVariablesFromProgram(xs, unusedVariables) + case (x: FunctionDeclarationStatement) :: xs => + x.copy(statements = x.statements.map(s => removeVariablesFromStatement(s, unusedVariables))) :: removeVariablesFromProgram(xs, unusedVariables) + case x :: xs => + x :: removeVariablesFromProgram(xs, unusedVariables) + case Nil => + Nil + } + } + + def getAllReadVariables(c: Constant): List[String] = c match { + case HalfWordConstant(cc, _) => getAllReadVariables(cc) + case SubbyteConstant(cc, _) => getAllReadVariables(cc) + case CompoundConstant(_, l, r) => getAllReadVariables(l) ++ getAllReadVariables(r) + case MemoryAddressConstant(th) => List(th.name.takeWhile(_ != '.')) + case _ => Nil + } + + def getAllReadVariables(expressions: List[Node]): List[String] = expressions.flatMap { + case s: VariableDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.initialValue.toList) + case s: ArrayDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.elements.getOrElse(Nil)) + case s: FunctionDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.statements.getOrElse(Nil)) + case Assignment(VariableExpression(_), expr) => getAllReadVariables(expr :: Nil) + case ExpressionStatement(FunctionCallExpression(op, VariableExpression(_) :: params)) if op.endsWith("=") => getAllReadVariables(params) + case s: Statement => getAllReadVariables(s.getAllExpressions) + case s: VariableExpression => List(s.name.takeWhile(_ != '.')) + case s: LiteralExpression => Nil + case HalfWordExpression(param, _) => getAllReadVariables(param :: Nil) + case SumExpression(xs, _) => getAllReadVariables(xs.map(_._2)) + case FunctionCallExpression(name, xs) => name :: getAllReadVariables(xs) + case IndexedExpression(arr, index) => arr :: getAllReadVariables(List(index)) + case SeparateBytesExpression(h, l) => getAllReadVariables(List(h, l)) + case _ => Nil + } + + def removeVariablesFromStatement(statements: List[Statement], globalsToRemove: Set[String]): List[Statement] = statements.flatMap { + case s: VariableDeclarationStatement => + if (globalsToRemove(s.name)) None else Some(s) + case s@ExpressionStatement(FunctionCallExpression(op, VariableExpression(n) :: params)) if op.endsWith("=") => + if (globalsToRemove(n)) params.map(ExpressionStatement) else Some(s) + case s@Assignment(VariableExpression(n), expr) => + if (globalsToRemove(n)) Some(ExpressionStatement(expr)) else Some(s) + case s@Assignment(SeparateBytesExpression(VariableExpression(h), VariableExpression(l)), expr) => + if (globalsToRemove(h)) { + if (globalsToRemove(l)) + Some(ExpressionStatement(expr)) + else + Some(Assignment(SeparateBytesExpression(BlackHoleExpression, VariableExpression(l)), expr)) + } else { + if (globalsToRemove(l)) + Some(Assignment(SeparateBytesExpression(VariableExpression(h), BlackHoleExpression), expr)) + else + Some(s) + } + case s@Assignment(SeparateBytesExpression(h, VariableExpression(l)), expr) => + if (globalsToRemove(l)) Some(Assignment(SeparateBytesExpression(h, BlackHoleExpression), expr)) + else Some(s) + case s@Assignment(SeparateBytesExpression(VariableExpression(h), l), expr) => + if (globalsToRemove(h)) Some(Assignment(SeparateBytesExpression(BlackHoleExpression, l), expr)) + else Some(s) + case s: IfStatement => + Some(s.copy( + thenBranch = removeVariablesFromStatement(s.thenBranch, globalsToRemove).asInstanceOf[List[ExecutableStatement]], + elseBranch = removeVariablesFromStatement(s.elseBranch, globalsToRemove).asInstanceOf[List[ExecutableStatement]])) + case s: WhileStatement => + Some(s.copy( + body = removeVariablesFromStatement(s.body, globalsToRemove).asInstanceOf[List[ExecutableStatement]])) + case s: DoWhileStatement => + Some(s.copy( + body = removeVariablesFromStatement(s.body, globalsToRemove).asInstanceOf[List[ExecutableStatement]])) + case s => Some(s) + } + +} diff --git a/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala b/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala new file mode 100644 index 00000000..5b66d389 --- /dev/null +++ b/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala @@ -0,0 +1,114 @@ +package millfork.node.opt + +import millfork.assembly.AssemblyLine +import millfork.env._ +import millfork.error.ErrorReporting +import millfork.node._ + +/** + * @author Karol Stasiak + */ +object UnusedLocalVariables extends NodeOptimization { + + override def optimize(nodes: List[Node]): List[Node] = nodes match { + case (x: FunctionDeclarationStatement) :: xs => + x.copy(statements = x.statements.map(optimizeVariables)) :: optimize(xs) + case x :: xs => + x :: optimize(xs) + case Nil => + Nil + } + + def getAllLocalVariables(statements: List[Statement]): List[String] = statements.flatMap { + case v: VariableDeclarationStatement => List(v.name) + case x: IfStatement => getAllLocalVariables(x.thenBranch) ++ getAllLocalVariables(x.elseBranch) + case x: WhileStatement => getAllLocalVariables(x.body) + case x: DoWhileStatement => getAllLocalVariables(x.body) + case _ => Nil + } + + def getAllReadVariables(c: Constant): List[String] = c match { + case HalfWordConstant(cc, _) => getAllReadVariables(cc) + case SubbyteConstant(cc, _) => getAllReadVariables(cc) + case CompoundConstant(_, l, r) => getAllReadVariables(l) ++ getAllReadVariables(r) + case MemoryAddressConstant(th) => List( + th.name, + th.name.stripSuffix(".addr"), + th.name.stripSuffix(".hi"), + th.name.stripSuffix(".lo"), + th.name.stripSuffix(".addr.lo"), + th.name.stripSuffix(".addr.hi")) + case _ => Nil + } + + def getAllReadVariables(expressions: List[Node]): List[String] = expressions.flatMap { + case s: VariableExpression => List( + s.name, + s.name.stripSuffix(".addr"), + s.name.stripSuffix(".hi"), + s.name.stripSuffix(".lo"), + s.name.stripSuffix(".addr.lo"), + s.name.stripSuffix(".addr.hi")) + case s: LiteralExpression => Nil + case HalfWordExpression(param, _) => getAllReadVariables(param :: Nil) + case SumExpression(xs, _) => getAllReadVariables(xs.map(_._2)) + case FunctionCallExpression(_, xs) => getAllReadVariables(xs) + case IndexedExpression(arr, index) => arr :: getAllReadVariables(List(index)) + case SeparateBytesExpression(h, l) => getAllReadVariables(List(h, l)) + case _ => Nil + } + + + def optimizeVariables(statements: List[Statement]): List[Statement] = { + val allLocals = getAllLocalVariables(statements) + val allRead = getAllReadVariables(statements.flatMap { + case Assignment(VariableExpression(_), expression) => List(expression) + case ExpressionStatement(FunctionCallExpression(op, VariableExpression(_) :: params)) if op.endsWith("=") => params + case x => x.getAllExpressions + }).toSet + val localsToRemove = allLocals.filterNot(allRead).toSet + if (localsToRemove.nonEmpty) { + ErrorReporting.debug("Removing unused local variables: " + localsToRemove.mkString(", ")) + } + removeVariables(statements, localsToRemove) + } + + def removeVariables(statements: List[Statement], localsToRemove: Set[String]): List[Statement] = statements.flatMap { + case s: VariableDeclarationStatement => + if (localsToRemove(s.name)) None else Some(s) + case s@ExpressionStatement(FunctionCallExpression(op, VariableExpression(n) :: params)) if op.endsWith("=") => + if (localsToRemove(n)) params.map(ExpressionStatement) else Some(s) + case s@Assignment(VariableExpression(n), expr) => + if (localsToRemove(n)) Some(ExpressionStatement(expr)) else Some(s) + case s@Assignment(SeparateBytesExpression(VariableExpression(h), VariableExpression(l)), expr) => + if (localsToRemove(h)) { + if (localsToRemove(l)) + Some(ExpressionStatement(expr)) + else + Some(Assignment(SeparateBytesExpression(BlackHoleExpression, VariableExpression(l)), expr)) + } else { + if (localsToRemove(l)) + Some(Assignment(SeparateBytesExpression(VariableExpression(h), BlackHoleExpression), expr)) + else + Some(s) + } + case s@Assignment(SeparateBytesExpression(h, VariableExpression(l)), expr) => + if (localsToRemove(l)) Some(Assignment(SeparateBytesExpression(h, BlackHoleExpression), expr)) + else Some(s) + case s@Assignment(SeparateBytesExpression(VariableExpression(h), l), expr) => + if (localsToRemove(h)) Some(Assignment(SeparateBytesExpression(BlackHoleExpression, l), expr)) + else Some(s) + case s: IfStatement => + Some(s.copy( + thenBranch = removeVariables(s.thenBranch, localsToRemove).asInstanceOf[List[ExecutableStatement]], + elseBranch = removeVariables(s.elseBranch, localsToRemove).asInstanceOf[List[ExecutableStatement]])) + case s: WhileStatement => + Some(s.copy( + body = removeVariables(s.body, localsToRemove).asInstanceOf[List[ExecutableStatement]])) + case s: DoWhileStatement => + Some(s.copy( + body = removeVariables(s.body, localsToRemove).asInstanceOf[List[ExecutableStatement]])) + case s => Some(s) + } + +} diff --git a/src/main/scala/millfork/output/Assembler.scala b/src/main/scala/millfork/output/Assembler.scala new file mode 100644 index 00000000..d4f06231 --- /dev/null +++ b/src/main/scala/millfork/output/Assembler.scala @@ -0,0 +1,612 @@ +package millfork.output + +import millfork.assembly.opt.AssemblyOptimization +import millfork.assembly.{AddrMode, AssemblyLine, Opcode} +import millfork.compiler.{CompilationContext, MlCompiler} +import millfork.env._ +import millfork.error.ErrorReporting +import millfork.node.CallGraph +import millfork.{CompilationFlag, CompilationOptions} + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ + +case class AssemblerOutput(code: Array[Byte], asm: Array[String], labels: List[(String, Int)]) + +class Assembler(private val rootEnv: Environment) { + + var env = rootEnv.allThings + var unoptimizedCodeSize = 0 + var optimizedCodeSize = 0 + var initializedArraysSize = 0 + + val mem = new CompiledMemory + val labelMap = mutable.Map[String, Int]() + val bytesToWriteLater = mutable.ListBuffer[(Int, Constant)]() + val wordsToWriteLater = mutable.ListBuffer[(Int, Constant)]() + + def writeByte(bank: Int, addr: Int, value: Byte): Unit = { + if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects") + mem.banks(bank).occupied(addr) = true + mem.banks(bank).readable(addr) = true + mem.banks(bank).output(addr) = value.toByte + } + + def writeByte(bank: Int, addr: Int, value: Constant): Unit = { + if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects") + mem.banks(bank).occupied(addr) = true + mem.banks(bank).readable(addr) = true + value match { + case NumericConstant(x, _) => + if (x > 0xffff) ErrorReporting.error("Byte overflow") + mem.banks(0).output(addr) = x.toByte + case _ => + bytesToWriteLater += addr -> value + } + } + + def writeWord(bank: Int, addr: Int, value: Constant): Unit = { + if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects") + mem.banks(bank).occupied(addr) = true + mem.banks(bank).occupied(addr + 1) = true + mem.banks(bank).readable(addr) = true + mem.banks(bank).readable(addr + 1) = true + value match { + case NumericConstant(x, _) => + if (x > 0xffff) ErrorReporting.error("Word overflow") + mem.banks(bank).output(addr) = x.toByte + mem.banks(bank).output(addr + 1) = (x >> 8).toByte + case _ => + wordsToWriteLater += addr -> value + } + } + + def deepConstResolve(c: Constant): Long = { + c match { + case NumericConstant(v, _) => v + case MemoryAddressConstant(th) => + if (labelMap.contains(th.name)) return labelMap(th.name) + if (labelMap.contains(th.name + "`")) return labelMap(th.name) + if (labelMap.contains(th.name + ".addr")) return labelMap(th.name) + val x1 = env.maybeGet[ConstantThing](th.name).map(_.value) + val x2 = env.maybeGet[ConstantThing](th.name + "`").map(_.value) + val x3 = env.maybeGet[NormalFunction](th.name).flatMap(_.address) + val x4 = env.maybeGet[ConstantThing](th.name + ".addr").map(_.value) + val x5 = env.maybeGet[RelativeVariable](th.name).map(_.address) + val x6 = env.maybeGet[ConstantThing](th.name.stripSuffix(".array") + ".addr").map(_.value) + val x = x1.orElse(x2).orElse(x3).orElse(x4).orElse(x5).orElse(x6) + x match { + case Some(cc) => + deepConstResolve(cc) + case None => + println(th) + ??? + } + case HalfWordConstant(cc, true) => deepConstResolve(cc).>>>(8).&(0xff) + case HalfWordConstant(cc, false) => deepConstResolve(cc).&(0xff) + case SubbyteConstant(cc, i) => deepConstResolve(cc).>>>(i * 8).&(0xff) + case CompoundConstant(operator, lc, rc) => + val l = deepConstResolve(lc) + val r = deepConstResolve(rc) + operator match { + case MathOperator.Plus => l + r + case MathOperator.Minus => l - r + case MathOperator.Times => l * r + case MathOperator.Shl => l << r + case MathOperator.Shr => l >>> r + case MathOperator.DecimalPlus => asDecimal(l, r, _ + _) + case MathOperator.DecimalMinus => asDecimal(l, r, _ - _) + case MathOperator.DecimalTimes => asDecimal(l, r, _ * _) + case MathOperator.DecimalShl => asDecimal(l, 1 << r, _ * _) + case MathOperator.DecimalShr => asDecimal(l, 1 << r, _ / _) + case MathOperator.And => l & r + case MathOperator.Exor => l ^ r + case MathOperator.Or => l | r + } + } + } + + private def parseNormalToDecimalValue(a: Long): Long = { + if (a < 0) -parseNormalToDecimalValue(-a) + var x = a + var result = 0L + var multiplier = 1L + while (x > 0) { + result += multiplier * (a % 16L) + x /= 16L + multiplier *= 10L + } + result + } + + private def storeDecimalValueInNormalRespresentation(a: Long): Long = { + if (a < 0) -storeDecimalValueInNormalRespresentation(-a) + var x = a + var result = 0L + var multiplier = 1L + while (x > 0) { + result += multiplier * (a % 10L) + x /= 10L + multiplier *= 16L + } + result + } + + private def asDecimal(a: Long, b: Long, f: (Long, Long) => Long): Long = + storeDecimalValueInNormalRespresentation(f(parseNormalToDecimalValue(a), parseNormalToDecimalValue(b))) + + def assemble(callGraph: CallGraph, optimizations: Seq[AssemblyOptimization], options: CompilationOptions): AssemblerOutput = { + val platform = options.platform + + val assembly = mutable.ArrayBuffer[String]() + + env.allPreallocatables.foreach { + case InitializedArray(name, Some(NumericConstant(address, _)), items) => + var index = address.toInt + assembly.append("* = $" + index.toHexString) + assembly.append(name) + for (item <- items) { + writeByte(0, index, item) + assembly.append(" !byte " + item) + mem.banks(0).writeable(index) = true + index += 1 + } + initializedArraysSize += items.length + case InitializedArray(name, Some(_), items) => ??? + case f: NormalFunction if f.address.isDefined => + var index = f.address.get.asInstanceOf[NumericConstant].value.toInt + labelMap(f.name) = index + compileFunction(f, index, optimizations, assembly, options) + case _ => + } + + var index = platform.org + env.allPreallocatables.foreach { + case f: NormalFunction if f.address.isEmpty && f.name == "main" => + labelMap(f.name) = index + index = compileFunction(f, index, optimizations, assembly, options) + case _ => + } + env.allPreallocatables.foreach { + case f: NormalFunction if f.address.isEmpty && f.name != "main" => + labelMap(f.name) = index + index = compileFunction(f, index, optimizations, assembly, options) + case _ => + } + env.allPreallocatables.foreach { + case InitializedArray(name, None, items) => + labelMap(name) = index + assembly.append("* = $" + index.toHexString) + assembly.append(name) + for (item <- items) { + writeByte(0, index, item) + assembly.append(" !byte " + item) + mem.banks(0).writeable(index) = true + index += 1 + } + initializedArraysSize += items.length + case _ => + } + val allocator = platform.allocator + allocator.notifyAboutEndOfCode(index) + allocator.onEachByte = { addr => + mem.banks(0).readable(addr) = true + mem.banks(0).writeable(addr) = true + } + env.allocateVariables(None, callGraph, allocator, options, labelMap.put) + + env = rootEnv.allThings + + for ((addr, b) <- bytesToWriteLater) { + val value = deepConstResolve(b) + mem.banks(0).output(addr) = value.toByte + } + for ((addr, b) <- wordsToWriteLater) { + val value = deepConstResolve(b) + mem.banks(0).output(addr) = value.toByte + mem.banks(0).output(addr + 1) = value.>>>(8).toByte + } + + val start = mem.banks(0).occupied.indexOf(true) + val end = mem.banks(0).occupied.lastIndexOf(true) + val length = end - start + 1 + mem.banks(0).start = start + mem.banks(0).end = end + + labelMap.toList.sorted.foreach {case (l, v) => + assembly += f"$l%-30s = $$$v%04X" + } + labelMap.toList.sortBy{case (a,b) => b->a}.foreach {case (l, v) => + assembly += f" ; $$$v%04X = $l%s" + } + + AssemblerOutput(platform.outputPackager.packageOutput(mem, 0), assembly.toArray, labelMap.toList) + } + + private def compileFunction(f: NormalFunction, startFrom: Int, optimizations: Seq[AssemblyOptimization], assOut: mutable.ArrayBuffer[String], options: CompilationOptions): Int = { + ErrorReporting.debug("Compiling: " + f.name, f.position) + var index = startFrom + assOut.append("* = $" + startFrom.toHexString) + val unoptimized = MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize + unoptimizedCodeSize += unoptimized.map(_.sizeInBytes).sum + val code = optimizations.foldLeft(unoptimized) { (c, opt) => + opt.optimize(f, c, options) + } + optimizedCodeSize += code.map(_.sizeInBytes).sum + import millfork.assembly.AddrMode._ + import millfork.assembly.Opcode._ + for (instr <- code) { + if (instr.isPrintable) { + assOut.append(instr.toString) + } + instr match { + case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(labelName)), _) => + labelMap(labelName) = index + case AssemblyLine(_, DoesNotExist, _, _) => + () + case AssemblyLine(op, Implied, _, _) => + writeByte(0, index, Assembler.opcodeFor(op, Implied, options)) + index += 1 + case AssemblyLine(op, Relative, param, _) => + writeByte(0, index, Assembler.opcodeFor(op, Relative, options)) + writeByte(0, index + 1, param - (index + 2)) + index += 2 + case AssemblyLine(op, am@(Immediate | ZeroPage | ZeroPageX | ZeroPageY | IndexedY | IndexedX | ZeroPageIndirect), param, _) => + writeByte(0, index, Assembler.opcodeFor(op, am, options)) + writeByte(0, index + 1, param) + index += 2 + case AssemblyLine(op, am@(Absolute | AbsoluteY | AbsoluteX | Indirect | AbsoluteIndexedX), param, _) => + writeByte(0, index, Assembler.opcodeFor(op, am, options)) + writeWord(0, index + 1, param) + index += 3 + } + } + index + } +} + +object Assembler { + val opcodes = mutable.Map[(Opcode.Value, AddrMode.Value), Byte]() + val illegalOpcodes = mutable.Map[(Opcode.Value, AddrMode.Value), Byte]() + val cmosOpcodes = mutable.Map[(Opcode.Value, AddrMode.Value), Byte]() + + def opcodeFor(opcode: Opcode.Value, addrMode: AddrMode.Value, options: CompilationOptions): Byte = { + val key = opcode -> addrMode + opcodes.get(key) match { + case Some(v) => v + case None => + illegalOpcodes.get(key) match { + case Some(v) => + if (options.flag(CompilationFlag.EmitIllegals)) v + else ErrorReporting.fatal("Cannot assemble an illegal opcode " + key) + case None => + cmosOpcodes.get(key) match { + case Some(v) => + if (options.flag(CompilationFlag.EmitCmosOpcodes)) v + else ErrorReporting.fatal("Cannot assemble a CMOS opcode " + key) + case None => + ErrorReporting.fatal("Cannot assemble an unknown opcode " + key) + } + } + } + } + + private def op(op: Opcode.Value, am: AddrMode.Value, x: Int): Unit = { + if (x < 0 || x > 0xff) ??? + opcodes(op -> am) = x.toByte + if (am == AddrMode.Relative) opcodes(op -> AddrMode.Immediate) = x.toByte + } + + private def cm(op: Opcode.Value, am: AddrMode.Value, x: Int): Unit = { + if (x < 0 || x > 0xff) ??? + cmosOpcodes(op -> am) = x.toByte + } + + private def il(op: Opcode.Value, am: AddrMode.Value, x: Int): Unit = { + if (x < 0 || x > 0xff) ??? + illegalOpcodes(op -> am) = x.toByte + } + + def getStandardLegalOpcodes: Set[Int] = opcodes.values.map(_ & 0xff).toSet + + import AddrMode._ + import Opcode._ + + op(ADC, Immediate, 0x69) + op(ADC, ZeroPage, 0x65) + op(ADC, ZeroPageX, 0x75) + op(ADC, Absolute, 0x6D) + op(ADC, AbsoluteX, 0x7D) + op(ADC, AbsoluteY, 0x79) + op(ADC, IndexedX, 0x61) + op(ADC, IndexedY, 0x71) + + op(AND, Immediate, 0x29) + op(AND, ZeroPage, 0x25) + op(AND, ZeroPageX, 0x35) + op(AND, Absolute, 0x2D) + op(AND, AbsoluteX, 0x3D) + op(AND, AbsoluteY, 0x39) + op(AND, IndexedX, 0x21) + op(AND, IndexedY, 0x31) + + op(ASL, Implied, 0x0A) + op(ASL, ZeroPage, 0x06) + op(ASL, ZeroPageX, 0x16) + op(ASL, Absolute, 0x0E) + op(ASL, AbsoluteX, 0x1E) + + op(BIT, ZeroPage, 0x24) + op(BIT, Absolute, 0x2C) + + op(BPL, Relative, 0x10) + op(BMI, Relative, 0x30) + op(BVC, Relative, 0x50) + op(BVS, Relative, 0x70) + op(BCC, Relative, 0x90) + op(BCS, Relative, 0xB0) + op(BNE, Relative, 0xD0) + op(BEQ, Relative, 0xF0) + + op(BRK, Implied, 0) + + op(CMP, Immediate, 0xC9) + op(CMP, ZeroPage, 0xC5) + op(CMP, ZeroPageX, 0xD5) + op(CMP, Absolute, 0xCD) + op(CMP, AbsoluteX, 0xDD) + op(CMP, AbsoluteY, 0xD9) + op(CMP, IndexedX, 0xC1) + op(CMP, IndexedY, 0xD1) + + op(CPX, Immediate, 0xE0) + op(CPX, ZeroPage, 0xE4) + op(CPX, Absolute, 0xEC) + + op(CPY, Immediate, 0xC0) + op(CPY, ZeroPage, 0xC4) + op(CPY, Absolute, 0xCC) + + op(DEC, ZeroPage, 0xC6) + op(DEC, ZeroPageX, 0xD6) + op(DEC, Absolute, 0xCE) + op(DEC, AbsoluteX, 0xDE) + + op(EOR, Immediate, 0x49) + op(EOR, ZeroPage, 0x45) + op(EOR, ZeroPageX, 0x55) + op(EOR, Absolute, 0x4D) + op(EOR, AbsoluteX, 0x5D) + op(EOR, AbsoluteY, 0x59) + op(EOR, IndexedX, 0x41) + op(EOR, IndexedY, 0x51) + + op(INC, ZeroPage, 0xE6) + op(INC, ZeroPageX, 0xF6) + op(INC, Absolute, 0xEE) + op(INC, AbsoluteX, 0xFE) + + op(CLC, Implied, 0x18) + op(SEC, Implied, 0x38) + op(CLI, Implied, 0x58) + op(SEI, Implied, 0x78) + op(CLV, Implied, 0xB8) + op(CLD, Implied, 0xD8) + op(SED, Implied, 0xF8) + + op(JMP, Absolute, 0x4C) + op(JMP, Indirect, 0x6C) + + op(JSR, Absolute, 0x20) + + op(LDA, Immediate, 0xA9) + op(LDA, ZeroPage, 0xA5) + op(LDA, ZeroPageX, 0xB5) + op(LDA, Absolute, 0xAD) + op(LDA, AbsoluteX, 0xBD) + op(LDA, AbsoluteY, 0xB9) + op(LDA, IndexedX, 0xA1) + op(LDA, IndexedY, 0xB1) + + op(LDX, Immediate, 0xA2) + op(LDX, ZeroPage, 0xA6) + op(LDX, ZeroPageY, 0xB6) + op(LDX, Absolute, 0xAE) + op(LDX, AbsoluteY, 0xBE) + + op(LDY, Immediate, 0xA0) + op(LDY, ZeroPage, 0xA4) + op(LDY, ZeroPageX, 0xB4) + op(LDY, Absolute, 0xAC) + op(LDY, AbsoluteX, 0xBC) + + op(LSR, Implied, 0x4A) + op(LSR, ZeroPage, 0x46) + op(LSR, ZeroPageX, 0x56) + op(LSR, Absolute, 0x4E) + op(LSR, AbsoluteX, 0x5E) + + op(NOP, Implied, 0xEA) + + op(ORA, Immediate, 0x09) + op(ORA, ZeroPage, 0x05) + op(ORA, ZeroPageX, 0x15) + op(ORA, Absolute, 0x0D) + op(ORA, AbsoluteX, 0x1D) + op(ORA, AbsoluteY, 0x19) + op(ORA, IndexedX, 0x01) + op(ORA, IndexedY, 0x11) + + op(TAX, Implied, 0xAA) + op(TXA, Implied, 0x8A) + op(DEX, Implied, 0xCA) + op(INX, Implied, 0xE8) + op(TAY, Implied, 0xA8) + op(TYA, Implied, 0x98) + op(DEY, Implied, 0x88) + op(INY, Implied, 0xC8) + + op(ROL, Implied, 0x2A) + op(ROL, ZeroPage, 0x26) + op(ROL, ZeroPageX, 0x36) + op(ROL, Absolute, 0x2E) + op(ROL, AbsoluteX, 0x3E) + + op(ROR, Implied, 0x6A) + op(ROR, ZeroPage, 0x66) + op(ROR, ZeroPageX, 0x76) + op(ROR, Absolute, 0x6E) + op(ROR, AbsoluteX, 0x7E) + + op(RTI, Implied, 0x40) + op(RTS, Implied, 0x60) + + op(SBC, Immediate, 0xE9) + op(SBC, ZeroPage, 0xE5) + op(SBC, ZeroPageX, 0xF5) + op(SBC, Absolute, 0xED) + op(SBC, AbsoluteX, 0xFD) + op(SBC, AbsoluteY, 0xF9) + op(SBC, IndexedX, 0xE1) + op(SBC, IndexedY, 0xF1) + + op(STA, ZeroPage, 0x85) + op(STA, ZeroPageX, 0x95) + op(STA, Absolute, 0x8D) + op(STA, AbsoluteX, 0x9D) + op(STA, AbsoluteY, 0x99) + op(STA, IndexedX, 0x81) + op(STA, IndexedY, 0x91) + + op(TXS, Implied, 0x9A) + op(TSX, Implied, 0xBA) + op(PHA, Implied, 0x48) + op(PLA, Implied, 0x68) + op(PHP, Implied, 0x08) + op(PLP, Implied, 0x28) + + op(STX, ZeroPage, 0x86) + op(STX, ZeroPageY, 0x96) + op(STX, Absolute, 0x8E) + + op(STY, ZeroPage, 0x84) + op(STY, ZeroPageX, 0x94) + op(STY, Absolute, 0x8C) + + il(LAX, ZeroPage, 0xA7) + il(LAX, ZeroPageY, 0xB7) + il(LAX, Absolute, 0xAF) + il(LAX, AbsoluteY, 0xBF) + il(LAX, IndexedX, 0xA3) + il(LAX, IndexedY, 0xB3) + + il(SAX, ZeroPage, 0x87) + il(SAX, ZeroPageY, 0x97) + il(SAX, Absolute, 0x8F) + il(TAS, AbsoluteY, 0x9B) + il(AHX, AbsoluteY, 0x9F) + il(SAX, IndexedX, 0x83) + il(AHX, IndexedY, 0x93) + + il(ANC, Immediate, 0x0B) + il(ALR, Immediate, 0x4B) + il(ARR, Immediate, 0x6B) + il(XAA, Immediate, 0x8B) + il(LXA, Immediate, 0xAB) + il(SBX, Immediate, 0xCB) + + il(SLO, ZeroPage, 0x07) + il(SLO, ZeroPageX, 0x17) + il(SLO, IndexedX, 0x03) + il(SLO, IndexedY, 0x13) + il(SLO, Absolute, 0x0F) + il(SLO, AbsoluteX, 0x1F) + il(SLO, AbsoluteY, 0x1B) + + il(RLA, ZeroPage, 0x27) + il(RLA, ZeroPageX, 0x37) + il(RLA, IndexedX, 0x23) + il(RLA, IndexedY, 0x33) + il(RLA, Absolute, 0x2F) + il(RLA, AbsoluteX, 0x3F) + il(RLA, AbsoluteY, 0x3B) + + il(SRE, ZeroPage, 0x47) + il(SRE, ZeroPageX, 0x57) + il(SRE, IndexedX, 0x43) + il(SRE, IndexedY, 0x53) + il(SRE, Absolute, 0x4F) + il(SRE, AbsoluteX, 0x5F) + il(SRE, AbsoluteY, 0x5B) + + il(RRA, ZeroPage, 0x67) + il(RRA, ZeroPageX, 0x77) + il(RRA, IndexedX, 0x63) + il(RRA, IndexedY, 0x73) + il(RRA, Absolute, 0x6F) + il(RRA, AbsoluteX, 0x7F) + il(RRA, AbsoluteY, 0x7B) + + il(DCP, ZeroPage, 0xC7) + il(DCP, ZeroPageX, 0xD7) + il(DCP, IndexedX, 0xC3) + il(DCP, IndexedY, 0xD3) + il(DCP, Absolute, 0xCF) + il(DCP, AbsoluteX, 0xDF) + il(DCP, AbsoluteY, 0xDB) + + il(ISC, ZeroPage, 0xE7) + il(ISC, ZeroPageX, 0xF7) + il(ISC, IndexedX, 0xE3) + il(ISC, IndexedY, 0xF3) + il(ISC, Absolute, 0xEF) + il(ISC, AbsoluteX, 0xFF) + il(ISC, AbsoluteY, 0xFB) + + il(NOP, Immediate, 0x80) + il(NOP, ZeroPage, 0x44) + il(NOP, ZeroPageX, 0x54) + il(NOP, Absolute, 0x5C) + il(NOP, AbsoluteX, 0x1C) + + cm(NOP, Immediate, 0x02) + cm(NOP, ZeroPage, 0x44) + cm(NOP, ZeroPageX, 0x54) + cm(NOP, Absolute, 0x5C) + + cm(STZ, ZeroPage, 0x64) + cm(STZ, ZeroPageX, 0x74) + cm(STZ, Absolute, 0x9C) + cm(STZ, AbsoluteX, 0x9E) + + cm(PHX, Implied, 0xDA) + cm(PHY, Implied, 0x5A) + cm(PLX, Implied, 0xFA) + cm(PLY, Implied, 0x7A) + + cm(ORA, ZeroPageIndirect, 0x12) + cm(AND, ZeroPageIndirect, 0x32) + cm(EOR, ZeroPageIndirect, 0x52) + cm(ADC, ZeroPageIndirect, 0x72) + cm(STA, ZeroPageIndirect, 0x92) + cm(LDA, ZeroPageIndirect, 0xB2) + cm(CMP, ZeroPageIndirect, 0xD2) + cm(SBC, ZeroPageIndirect, 0xF2) + + cm(TSB, ZeroPage, 0x04) + cm(TSB, Absolute, 0x0C) + cm(TRB, ZeroPage, 0x14) + cm(TRB, Absolute, 0x1C) + + cm(BIT, ZeroPageX, 0x34) + cm(BIT, AbsoluteX, 0x3C) + cm(INC, Implied, 0x1A) + cm(DEC, Implied, 0x3A) + cm(JMP, AbsoluteIndexedX, 0x7C) + cm(WAI, Implied, 0xCB) + cm(STP, Implied, 0xDB) + +} diff --git a/src/main/scala/millfork/output/CompiledMemory.scala b/src/main/scala/millfork/output/CompiledMemory.scala new file mode 100644 index 00000000..82f91768 --- /dev/null +++ b/src/main/scala/millfork/output/CompiledMemory.scala @@ -0,0 +1,29 @@ +package millfork.output + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ +class CompiledMemory { + val banks = mutable.Map(0 -> new MemoryBank) +} + +class MemoryBank { + def readByte(addr: Int) = output(addr) & 0xff + + def readWord(addr: Int) = readByte(addr) + (readByte(addr + 1) << 8) + + def readMedium(addr: Int) = readByte(addr) + (readByte(addr + 1) << 8) + (readByte(addr + 2) << 16) + + def readLong(addr: Int) = readByte(addr) + (readByte(addr + 1) << 8) + (readByte(addr + 2) << 16) + (readByte(addr + 3) << 24) + + def readWord(addrHi: Int, addrLo: Int) = readByte(addrLo) + (readByte(addrHi) << 8) + + val output = Array.fill[Byte](1 << 16)(0) + val occupied = Array.fill(1 << 16)(false) + val readable = Array.fill(1 << 16)(false) + val writeable = Array.fill(1 << 16)(false) + var start: Int = 0 + var end: Int = 0 +} diff --git a/src/main/scala/millfork/output/OutputPackager.scala b/src/main/scala/millfork/output/OutputPackager.scala new file mode 100644 index 00000000..997d7881 --- /dev/null +++ b/src/main/scala/millfork/output/OutputPackager.scala @@ -0,0 +1,60 @@ +package millfork.output + +import java.io.ByteArrayOutputStream + +/** + * @author Karol Stasiak + */ +trait OutputPackager { + def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] +} + +case class SequenceOutput(children: List[OutputPackager]) extends OutputPackager { + def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = { + val baos = new ByteArrayOutputStream + children.foreach { c => + val a = c.packageOutput(mem, bank) + baos.write(a, 0, a.length) + } + baos.toByteArray + } +} + +case class ConstOutput(byte: Byte) extends OutputPackager { + def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = Array(byte) +} + +case class CurrentBankFragmentOutput(start: Int, end: Int) extends OutputPackager { + def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = { + val b = mem.banks(bank) + b.output.slice(start, end + 1) + } +} + +case class BankFragmentOutput(alwaysBank: Int, start: Int, end: Int) extends OutputPackager { + def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = { + val b = mem.banks(alwaysBank) + b.output.slice(start, end + 1) + } +} + +object StartAddressOutput extends OutputPackager { + def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = { + val b = mem.banks(bank) + Array(b.start.toByte, b.start.>>(8).toByte) + } +} + +object EndAddressOutput extends OutputPackager { + def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = { + val b = mem.banks(bank) + Array(b.end.toByte, b.end.>>(8).toByte) + } +} + +object AllocatedDataOutput extends OutputPackager { + def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = { + val b = mem.banks(bank) + b.output.slice(b.start, b.end + 1) + } +} \ No newline at end of file diff --git a/src/main/scala/millfork/output/VariableAllocator.scala b/src/main/scala/millfork/output/VariableAllocator.scala new file mode 100644 index 00000000..7263518c --- /dev/null +++ b/src/main/scala/millfork/output/VariableAllocator.scala @@ -0,0 +1,96 @@ +package millfork.output + +import millfork.error.ErrorReporting +import millfork.node.{CallGraph, VariableVertex} +import millfork.{CompilationFlag, CompilationOptions} + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ + +sealed trait ByteAllocator { + def notifyAboutEndOfCode(org: Int): Unit + + def allocateBytes(count: Int, options: CompilationOptions): Int +} + +class UpwardByteAllocator(startAt: Int, endBefore: Int) extends ByteAllocator { + private var nextByte = startAt + + def allocateBytes(count: Int, options: CompilationOptions): Int = { + if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1 + val t = nextByte + nextByte += count + if (nextByte > endBefore) { + ErrorReporting.fatal("Out of high memory") + } + t + } + + def notifyAboutEndOfCode(org: Int): Unit = () +} + +class AfterCodeByteAllocator(endBefore: Int) extends ByteAllocator { + var nextByte = 0x200 + + def allocateBytes(count: Int, options: CompilationOptions): Int = { + if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1 + val t = nextByte + nextByte += count + if (nextByte > endBefore) { + ErrorReporting.fatal("Out of high memory") + } + t + } + + def notifyAboutEndOfCode(org: Int): Unit = nextByte = org +} + +class VariableAllocator(private var pointers: List[Int], private val bytes: ByteAllocator) { + + private var pointerMap = mutable.Map[Int, Set[VariableVertex]]() + private var variableMap = mutable.Map[Int, mutable.Map[Int, Set[VariableVertex]]]() + + var onEachByte: (Int => Unit) = _ + + def allocatePointer(callGraph: CallGraph, p: VariableVertex): Int = { + pointerMap.foreach { case (addr, alreadyThere) => + if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) { + pointerMap(addr) += p + return addr + } + } + pointers match { + case Nil => + ErrorReporting.fatal("Out of zero-page memory") + case next :: rest => + pointers = rest + onEachByte(next) + onEachByte(next + 1) + pointerMap(next) = Set(p) + next + } + } + + def allocateByte(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions): Int = allocateBytes(callGraph, p, options, 1) + + def allocateBytes(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions, count: Int): Int = { + if (!variableMap.contains(count)) { + variableMap(count) = mutable.Map() + } + variableMap(count).foreach { case (a, alreadyThere) => + if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) { + variableMap(count)(a) += p + return a + } + } + val addr = bytes.allocateBytes(count, options) + (addr to (addr + count)).foreach(onEachByte) + variableMap(count)(addr) = Set(p) + addr + } + + def notifyAboutEndOfCode(org: Int): Unit = bytes.notifyAboutEndOfCode(org) +} diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala new file mode 100644 index 00000000..667a5360 --- /dev/null +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -0,0 +1,435 @@ +package millfork.parser + +import java.nio.file.{Files, Paths} + +import fastparse.all._ +import millfork.assembly.{AddrMode, Opcode} +import millfork.env._ +import millfork.error.ErrorReporting +import millfork.node._ +import millfork.{CompilationOptions, SeparatedList} + +/** + * @author Karol Stasiak + */ +case class MfParser(filename: String, input: String, currentDirectory: String, options: CompilationOptions) { + + var lastPosition = Position(filename, 1, 1, 0) + var lastLabel = "" + + def toAst: Parsed[Program] = program.parse(input + "\n\n\n") + + private val lineStarts: Array[Int] = (0 +: input.zipWithIndex.filter(_._1 == '\n').map(_._2)).toArray + + def position(label: String = ""): P[Position] = Index.map(i => indexToPosition(i, label)) + + def indexToPosition(i: Int, label: String): Position = { + val prefix = lineStarts.takeWhile(_ <= i) + val newPosition = Position(filename, prefix.length, i - prefix.last, i) + if (newPosition.cursor > lastPosition.cursor) { + lastPosition = newPosition + lastLabel = label + } + newPosition + } + + val comment: P[Unit] = P("//" ~/ CharsWhile(c => c != '\n' && c != '\r', min = 0) ~ ("\r\n" | "\r" | "\n")) + + val SWS: P[Unit] = P(CharsWhileIn(" \t", min = 1)).opaque("") + + val HWS: P[Unit] = P(CharsWhileIn(" \t", min = 0)).opaque("") + + val AWS: P[Unit] = P((CharIn(" \t\n\r;") | NoCut(comment)).rep(min = 0)).opaque("") + + val EOL: P[Unit] = P(HWS ~ ("\r\n" | "\r" | "\n" | comment).opaque("") ~ AWS).opaque("") + + val letter: P[String] = P(CharIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_").!) + + val letterOrDigit: P[Unit] = P(CharIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_.$1234567890")) + + val lettersOrDigits: P[String] = P(CharsWhileIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_.$1234567890", min = 0).!) + + val identifier: P[String] = P((letter ~ lettersOrDigits).map { case (a, b) => a + b }).opaque("") + + // def operator: P[String] = P(CharsWhileIn("!-+*/><=~|&^", min=1).!) // TODO: only valid operators + + // TODO: 3-byte types + def size(value: Int, wordLiteral: Boolean, longLiteral: Boolean): Int = + if (value > 255 || value < -128 || wordLiteral) + if (value > 0xffff || longLiteral) 4 else 2 + else 1 + + def sign(abs: Int, minus: Boolean): Int = if (minus) -abs else abs + + val decimalAtom: P[LiteralExpression] = + for { + p <- position() + minus <- "-".!.? + s <- CharsWhileIn("1234567890", min = 1).!.opaque("") ~ !("x" | "b") + } yield { + val abs = Integer.parseInt(s, 10) + val value = sign(abs, minus.isDefined) + LiteralExpression(value, size(value, s.length > 3, s.length > 5)).pos(p) + } + + val binaryAtom: P[LiteralExpression] = + for { + p <- position() + minus <- "-".!.? + _ <- P("0b" | "%") ~/ Pass + s <- CharsWhileIn("01", min = 1).!.opaque("") + } yield { + val abs = Integer.parseInt(s, 2) + val value = sign(abs, minus.isDefined) + LiteralExpression(value, size(value, s.length > 8, s.length > 16)).pos(p) + } + + val hexAtom: P[LiteralExpression] = + for { + p <- position() + minus <- "-".!.? + _ <- P("0x" | "$") ~/ Pass + s <- CharsWhileIn("1234567890abcdefABCDEF", min = 1).!.opaque("") + } yield { + val abs = Integer.parseInt(s, 16) + val value = sign(abs, minus.isDefined) + LiteralExpression(value, size(value, s.length > 2, s.length > 4)).pos(p) + } + + val literalAtom: P[LiteralExpression] = binaryAtom | hexAtom | decimalAtom + + val atom: P[Expression] = P(literalAtom | (position() ~ identifier).map { case (p, i) => VariableExpression(i).pos(p) }) + + val mlOperators = List( + List("+=", "-=", "+'=", "-'=", "^=", "&=", "|=", "*=", "*'=", "<<=", ">>=", "<<'=", ">>'="), + List("||", "^^"), + List("&&"), + List("==", "<=", ">=", "!=", "<", ">"), + List(":"), + List("+'", "-'", "<<'", ">>'", ">>>>", "+", "-", "&", "|", "^", "<<", ">>"), + List("*'", "*")) + + val nonStatementLevel = 1 // everything but not `=` + val mathLevel = 4 // the `:` operator + + def flags(allowed: String*): P[Set[String]] = StringIn(allowed: _*).!.rep(min = 0, sep = SWS).map(_.toSet).opaque("") + + def variableDefinition(implicitlyGlobal: Boolean): P[DeclarationStatement] = for { + p <- position() + flags <- flags("const", "static", "volatile", "stack") ~ HWS + typ <- identifier ~ SWS + name <- identifier ~/ HWS ~/ Pass + addr <- ("@" ~/ HWS ~/ mlExpression(1)).?.opaque("
") ~ HWS + initialValue <- ("=" ~/ HWS ~/ mlExpression(1)).? ~ HWS + _ <- &(EOL) ~/ "" + } yield { + VariableDeclarationStatement(name, typ, + global = implicitlyGlobal || flags("static"), + stack = flags("stack"), + constant = flags("const"), + volatile = flags("volatile"), + initialValue, addr).pos(p) + } + + val externFunctionBody: P[Option[List[Statement]]] = P("extern" ~/ PassWith(None)) + + val paramDefinition: P[ParameterDeclaration] = for { + p <- position() + typ <- identifier ~/ SWS ~/ Pass + name <- identifier ~/ Pass + } yield { + ParameterDeclaration(typ, ByVariable(name)).pos(p) + } + + val appcSimple: P[ParamPassingConvention] = P("xy" | "yx" | "ax" | "ay" | "xa" | "ya" | "stack" | "a" | "x" | "y").!.map { + case "xy" => ByRegister(Register.XY) + case "yx" => ByRegister(Register.YX) + case "ax" => ByRegister(Register.AX) + case "ay" => ByRegister(Register.AY) + case "xa" => ByRegister(Register.XA) + case "ya" => ByRegister(Register.YA) + case "a" => ByRegister(Register.A) + case "x" => ByRegister(Register.X) + case "y" => ByRegister(Register.Y) + case x => ErrorReporting.fatal(s"Unknown assembly parameter passing convention: `$x`") + } + + val appcComplex: P[ParamPassingConvention] = P((("const" | "ref").! ~/ AWS).? ~ AWS ~ identifier) map { + case (None, name) => ByVariable(name) + case (Some("const"), name) => ByConstant(name) + case (Some("ref"), name) => ByReference(name) + case x => ErrorReporting.fatal(s"Unknown assembly parameter passing convention: `$x`") + } + + val asmParamDefinition: P[ParameterDeclaration] = for { + p <- position() + typ <- identifier ~ SWS + appc <- appcSimple | appcComplex + } yield ParameterDeclaration(typ, appc).pos(p) + + + val arrayListContents: P[List[Expression]] = ("[" ~/ AWS ~/ mlExpression(nonStatementLevel).rep(sep = AWS ~ "," ~/ AWS) ~ AWS ~ "]" ~/ Pass).map(_.toList) + + val doubleQuotedString: P[List[Char]] = P("\"" ~/ CharsWhile(c => c != '\"' && c != '\n' && c != '\r').! ~ "\"").map(_.toList) + + val codec: P[TextCodec] = P(position() ~ identifier).map { + case (_, "ascii") => TextCodec.Ascii + case (_, "petscii") => TextCodec.Petscii + case (_, "pet") => TextCodec.Petscii + case (p, x) => + ErrorReporting.error(s"Unknown string encoding: `$x`", Some(p)) + TextCodec.Ascii + } + + def arrayFileContents: P[List[Expression]] = for { + p <- "file" ~ HWS ~/ "(" ~/ HWS ~/ position() + filePath <- doubleQuotedString ~/ HWS + optSlice <- ("," ~/ HWS ~/ literalAtom ~/ HWS ~/ "," ~/ HWS ~/ literalAtom ~/ HWS ~/ Pass).? + _ <- ")" ~/ Pass + } yield { + val data = Files.readAllBytes(Paths.get(currentDirectory, filePath.mkString)) + val slice = optSlice.fold(data) { + case (start, length) => data.drop(start.value.toInt).take(length.value.toInt) + } + slice.map(c => LiteralExpression(c & 0xff, 1)).toList + } + + def arrayStringContents: P[List[Expression]] = P(position() ~ doubleQuotedString ~/ HWS ~ codec).map { + case (p, s, co) => s.map(c => LiteralExpression(co.decode(None, c), 1).pos(p)) + } + + def arrayContents: P[List[Expression]] = arrayListContents | arrayFileContents | arrayStringContents + + def arrayDefinition: P[ArrayDeclarationStatement] = for { + p <- position() + name <- "array" ~ !letterOrDigit ~/ SWS ~ identifier ~ HWS + length <- ("[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~ "]").? ~ HWS + addr <- ("@" ~/ HWS ~/ mlExpression(1)).? ~/ HWS + contents <- ("=" ~/ HWS ~/ arrayContents).? ~/ HWS + } yield ArrayDeclarationStatement(name, length, addr, contents).pos(p) + + def tightMlExpression: P[Expression] = P(mlParenExpr | functionCall | mlIndexedExpression | atom) // TODO + + def mlExpression(level: Int): P[Expression] = { + val allowedOperators = mlOperators.drop(level).flatten + + def inner: P[SeparatedList[Expression, String]] = { + for { + head <- tightMlExpression ~/ HWS + maybeOperator <- StringIn(allowedOperators: _*).!.? + maybeTail <- maybeOperator.fold[P[Option[List[(String, Expression)]]]](Pass.map(_ => None))(o => (HWS ~/ inner ~/ HWS).map(x2 => Some((o -> x2.head) :: x2.tail))) + } yield { + maybeTail.fold[SeparatedList[Expression, String]](SeparatedList.of(head))(t => SeparatedList(head, t)) + } + } + + def p(list: SeparatedList[Expression, String], level: Int): Expression = + if (level == mlOperators.length) list.head + else { + val xs = list.split(mlOperators(level).toSet(_)) + xs.separators.distinct match { + case Nil => + if (xs.tail.nonEmpty) + ErrorReporting.error("Too many different operators") + p(xs.head, level + 1) + case List("+") | List("-") | List("+", "-") | List("-", "+") => + SumExpression(xs.toPairList("+").map { case (op, value) => (op == "-", p(value, level + 1)) }, decimal = false) + case List("+'") | List("-'") | List("+'", "-'") | List("-'", "+'") => + SumExpression(xs.toPairList("+").map { case (op, value) => (op == "-", p(value, level + 1)) }, decimal = true) + case List(":") => + if (xs.size != 2) { + ErrorReporting.error("The `:` operator can have only two arguments", xs.head.head.position) + LiteralExpression(0, 1) + } else { + SeparateBytesExpression(p(xs.head, level + 1), p(xs.tail.head._2, level + 1)) + } + case List(op) => + FunctionCallExpression(op, xs.items.map(value => p(value, level + 1))) + case _ => + ErrorReporting.error("Too many different operators") + LiteralExpression(0, 1) + } + } + + inner.map(x => p(x, 0)) + } + + def mlLhsExpressionSimple: P[LhsExpression] = mlIndexedExpression | (position() ~ identifier).map { case (p, n) => VariableExpression(n).pos(p) } + + def mlLhsExpression: P[LhsExpression] = { + val separated = position() ~ mlLhsExpressionSimple ~ HWS ~ ":" ~/ HWS ~ mlLhsExpressionSimple + separated.map { case (p, h, l) => SeparateBytesExpression(h, l).pos(p) } | mlLhsExpressionSimple + } + + + def mlParenExpr: P[Expression] = P("(" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~/ ")") + + def mlIndexedExpression: P[IndexedExpression] = for { + p <- position() + array <- identifier + index <- HWS ~ "[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~/ "]" + } yield IndexedExpression(array, index).pos(p) + + def functionCall: P[FunctionCallExpression] = for { + p <- position() + name <- identifier + params <- HWS ~ "(" ~/ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "" + } yield FunctionCallExpression(name, params.toList).pos(p) + + val expressionStatement: P[ExecutableStatement] = mlExpression(0).map(ExpressionStatement) + + val assignmentStatement: P[ExecutableStatement] = + (position() ~ mlLhsExpression ~ HWS ~ "=" ~/ HWS ~ mlExpression(1)).map { + case (p, l, r) => Assignment(l, r).pos(p) + } + + def keywordStatement: P[ExecutableStatement] = P(returnStatement | ifStatement | whileStatement | forStatement | doWhileStatement | inlineAssembly | assignmentStatement) + + def executableStatement: P[ExecutableStatement] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.pos(p) } + + // TODO: label and instruction in one line + def asmLabel: P[ExecutableStatement] = (identifier ~ HWS ~ ":" ~/ HWS).map(l => AssemblyStatement(Opcode.LABEL, AddrMode.DoesNotExist, VariableExpression(l), elidable = true)) + + // def zeropageAddrModeHint: P[Option[Boolean]] = Pass + + def asmOpcode: P[Opcode.Value] = (position() ~ letter.rep(exactly = 3).!).map { case (p, o) => Opcode.lookup(o, Some(p)) } + + def asmExpression: P[Expression] = (position() ~ NoCut( + ("<" ~/ HWS ~ mlExpression(mathLevel)).map(e => HalfWordExpression(e, hiByte = false)) | + (">" ~/ HWS ~ mlExpression(mathLevel)).map(e => HalfWordExpression(e, hiByte = true)) | + mlExpression(mathLevel) + )).map { case (p, e) => e.pos(p) } + + val commaX = HWS ~ "," ~ HWS ~ ("X" | "x") ~ HWS + val commaY = HWS ~ "," ~ HWS ~ ("Y" | "y") ~ HWS + + def asmParameter: P[(AddrMode.Value, Expression)] = { + (SWS ~ ( + ("#" ~ asmExpression).map(AddrMode.Immediate -> _) | + ("(" ~ HWS ~ asmExpression ~ HWS ~ ")" ~ commaY).map(AddrMode.IndexedY -> _) | + ("(" ~ HWS ~ asmExpression ~ commaX ~ ")").map(AddrMode.IndexedX -> _) | + ("(" ~ HWS ~ asmExpression ~ HWS ~ ")").map(AddrMode.Indirect -> _) | + (asmExpression ~ commaX).map(AddrMode.AbsoluteX -> _) | + (asmExpression ~ commaY).map(AddrMode.AbsoluteY -> _) | + asmExpression.map(AddrMode.Absolute -> _) + )).?.map(_.getOrElse(AddrMode.Implied -> LiteralExpression(0, 1))) + } + + def elidable: P[Boolean] = ("?".! ~/ HWS).?.map(_.isDefined) + + def asmInstruction: P[ExecutableStatement] = { + val lineParser: P[(Boolean, Opcode.Value, (AddrMode.Value, Expression))] = !"}" ~ elidable ~/ asmOpcode ~/ asmParameter + lineParser.map { case (elid, op, param) => + AssemblyStatement(op, param._1, param._2, elid) + } + } + + def asmStatement: P[ExecutableStatement] = (position("assembly statement") ~ P(asmLabel | asmInstruction)).map { case (p, s) => s.pos(p) } // TODO: macros + + def statement: P[Statement] = (position() ~ P(keywordStatement | variableDefinition(false) | expressionStatement)).map { case (p, s) => s.pos(p) } + + def asmStatements: P[List[ExecutableStatement]] = ("{" ~/ AWS ~/ asmStatement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList) + + def statements: P[List[Statement]] = ("{" ~/ AWS ~ statement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList) + + def executableStatements: P[Seq[ExecutableStatement]] = "{" ~/ AWS ~/ executableStatement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~ "}" + + def returnStatement: P[ExecutableStatement] = ("return" ~ !letterOrDigit ~/ HWS ~ mlExpression(nonStatementLevel).?).map(ReturnStatement) + + def ifStatement: P[ExecutableStatement] = for { + condition <- "if" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel) + thenBranch <- AWS ~/ executableStatements + elseBranch <- (AWS ~ "else" ~/ AWS ~/ executableStatements).? + } yield IfStatement(condition, thenBranch.toList, elseBranch.getOrElse(Nil).toList) + + def whileStatement: P[ExecutableStatement] = for { + condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel) + body <- AWS ~ executableStatements + } yield WhileStatement(condition, body.toList) + + def forDirection: P[ForDirection.Value] = + ("parallel" ~ HWS ~ "to").!.map(_ => ForDirection.ParallelTo) | + ("parallel" ~ HWS ~ "until").!.map(_ => ForDirection.ParallelUntil) | + "until".!.map(_ => ForDirection.Until) | + "to".!.map(_ => ForDirection.To) | + ("down" ~/ HWS ~/ "to").!.map(_ => ForDirection.DownTo) + + def forStatement: P[ExecutableStatement] = for { + identifier <- "for" ~ SWS ~/ identifier ~/ "," ~/ Pass + start <- mlExpression(nonStatementLevel) ~ HWS ~ "," ~/ HWS ~/ Pass + direction <- forDirection ~/ HWS ~/ "," ~/ HWS ~/ Pass + end <- mlExpression(nonStatementLevel) + body <- AWS ~ executableStatements + } yield ForStatement(identifier, start, end, direction, body.toList) + + def inlineAssembly: P[ExecutableStatement] = for { + condition <- "asm" ~ !letterOrDigit ~/ Pass + body <- AWS ~ asmStatements + } yield BlockStatement(body) + + //noinspection MutatorLikeMethodIsParameterless + def doWhileStatement: P[ExecutableStatement] = for { + body <- "do" ~ !letterOrDigit ~/ AWS ~ executableStatements ~/ AWS + condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel) + } yield DoWhileStatement(body.toList, condition) + + def functionDefinition: P[DeclarationStatement] = for { + p <- position() + flags <- flags("asm", "inline", "interrupt", "reentrant") ~ HWS + returnType <- identifier ~ SWS + name <- identifier ~ HWS + params <- "(" ~/ AWS ~/ (if (flags("asm")) asmParamDefinition else paramDefinition).rep(sep = AWS ~ "," ~/ AWS) ~ AWS ~ ")" ~/ AWS + addr <- ("@" ~/ HWS ~/ mlExpression(1)).?.opaque("
") ~/ AWS + statements <- (externFunctionBody | (if (flags("asm")) asmStatements else statements).map(l => Some(l))) ~/ Pass + } yield { + if (flags("interrupt") && flags("inline")) ErrorReporting.error(s"Interrupt function `$name` cannot be inline", Some(p)) + if (flags("interrupt") && flags("reentrant")) ErrorReporting.error("Interrupt function `$name` cannot be reentrant", Some(p)) + if (flags("inline") && flags("reentrant")) ErrorReporting.error("Reentrant and inline exclude each other", Some(p)) + if (flags("interrupt") && returnType != "void") ErrorReporting.error("Interrupt function `$name` has to return void", Some(p)) + if (addr.isEmpty && statements.isEmpty) ErrorReporting.error("Extern function `$name` must have an address", Some(p)) + if (statements.isEmpty && !flags("asm") && params.nonEmpty) ErrorReporting.error("Extern non-asm function `$name` cannot have parameters", Some(p)) + if (flags("asm")) statements match { + case Some(Nil) => ErrorReporting.warn("Assembly function `$name` is empty, did you mean RTS or RTI", options, Some(p)) + case Some(xs) => + if (flags("interrupt")) { + if (xs.exists { + case AssemblyStatement(Opcode.RTS, _, _, _) => true + case _ => false + }) ErrorReporting.warn("Assembly interrupt function `$name` contains RTS, did you mean RTI?", options, Some(p)) + } else { + if (xs.exists { + case AssemblyStatement(Opcode.RTI, _, _, _) => true + case _ => false + }) ErrorReporting.warn("Assembly non-interrupt function `$name` contains RTI, did you mean RTS?", options, Some(p)) + } + if (!flags("inline")) { + xs.last match { + case AssemblyStatement(Opcode.RTS, _, _, _) => () // OK + case AssemblyStatement(Opcode.RTI, _, _, _) => () // OK + case AssemblyStatement(Opcode.JMP, _, _, _) => () // OK + case _ => + val validReturn = if (flags("interrupt")) "RTI" else "RTS" + ErrorReporting.warn(s"Non-inline assembly function `$name` should end in " + validReturn, options, Some(p)) + } + } + case None => () + } + FunctionDeclarationStatement(name, returnType, params.toList, + addr, + statements, + flags("inline"), + flags("asm"), + flags("interrupt"), + flags("reentrant")).pos(p) + } + + def importStatement: Parser[ImportStatement] = ("import" ~ !letterOrDigit ~/ SWS ~/ identifier).map(ImportStatement) + + def program: Parser[Program] = for { + _ <- Start ~/ AWS ~/ Pass + definitions <- (importStatement | arrayDefinition | functionDefinition | variableDefinition(true)).rep(sep = EOL) + _ <- AWS ~ End + } yield Program(definitions.toList) + + +} diff --git a/src/main/scala/millfork/parser/MinimalTestCase.scala b/src/main/scala/millfork/parser/MinimalTestCase.scala new file mode 100644 index 00000000..2f5852d6 --- /dev/null +++ b/src/main/scala/millfork/parser/MinimalTestCase.scala @@ -0,0 +1,24 @@ +package millfork.parser + +import fastparse.all._ +import fastparse.core + +object MinimalTestCase { + def AWS: P[Unit] = "\n".rep(min = 0).opaque("").log() + + def EOL: P[Unit] = "\n".rep(min = 1).opaque("").log() + + def identifier: P[String] = CharPred(_.isLetter).rep(min = 1).!.opaque("").log() + + def identifierWithSpace: P[String] = (identifier ~/ AWS ~/ Pass).opaque("").log() + + def separator: P[Unit] = ("," ~/ AWS ~/ Pass).opaque("").log() + + def identifiers: P[Seq[String]] = identifierWithSpace.rep(min = 0, sep = separator)//.opaque("").log() + + def array: P[Seq[String]] = ("[" ~/ AWS ~/ identifiers ~/ "]" ~/ Pass)//.opaque("").log() + + def arrays: Parser[Seq[Seq[String]]] = (array ~/ EOL).rep(min = 0, sep = !End ~/ Pass)//.opaque("").log() + + def program: Parser[Seq[Seq[String]]] = Start ~/ AWS ~/ arrays ~/ End +} diff --git a/src/main/scala/millfork/parser/ParserBase.scala b/src/main/scala/millfork/parser/ParserBase.scala new file mode 100644 index 00000000..4836d700 --- /dev/null +++ b/src/main/scala/millfork/parser/ParserBase.scala @@ -0,0 +1,169 @@ +package millfork.parser + +import millfork.node.Position + +/** + * @author Karol Stasiak + */ +case class ParseException(msg: String, position: Option[Position]) extends Exception + +class ParserBase(filename: String, input: String) { + + def reset(): Unit = { + cursor = 0 + line = 1 + column = FirstColumn + } + + private val FirstColumn = 0 + private val length = input.length + private var cursor = 0 + private var line = 1 + private var column = FirstColumn + + def position = Position(filename, line, column, cursor) + + def restorePosition(p: Position): Unit = { + cursor = p.cursor + column = p.column + line = p.line + } + + def error(msg: String, pos: Option[Position]): Nothing = throw ParseException(msg, pos) + + def error(msg: String, pos: Position): Nothing = throw ParseException(msg, Some(pos)) + + def error(msg: String): Nothing = throw ParseException(msg, Some(position)) + + def error() = throw ParseException("Syntax error", Some(position)) + + def nextChar() = { + if (cursor >= length) error("Unexpected end of input") + val c = input(cursor) + cursor += 1 + if (c == '\n') { + line += 1 + column = FirstColumn + } else { + column += 1 + } + c + } + + def peekChar(): Char = { + if (cursor >= length) '\0' else input(cursor) + } + + def require(char: Char): Char = { + val pos = position + val c = nextChar() + if (c != char) error(s"Expected `$char`", pos) + c + } + def require(p: Char=>Boolean, errorMsg: String = "Unexpected character"): Char = { + val pos = position + val c = nextChar() + if (!p(c)) error(errorMsg, pos) + c + } + + def require(s: String): String = { + val c = peekChars(s.length) + if (c != s) error(s"Expected `$s`") + 1 to s.length foreach (_=>nextChar()) + s + } + + def requireAny(s: String, errorMsg: String = "Unexpected character"): Char = { + val c = nextChar() + if (s.contains(c)) c + else error(errorMsg) + } + + def peek2Chars(): String = { + peekChars(2) + } + + def peekChars(n: Int): String = { + if (cursor > length - n) input.substring(cursor) else input.substring(cursor, cursor + n) + } + + def charsWhile(pred: Char => Boolean, min: Int = 0, errorMsg: String = "Unexpected character"): String = { + val sb = new StringBuilder() + while (pred(peekChar())) { + sb += nextChar() + } + val s = sb.toString + if (s.length < min) error(errorMsg) + else s + } + + def skipNextIfMatches(c: Char): Boolean = { + if (peekChar() == c) { + nextChar() + true + } else { + false + } + } + + def either(c: Char, s: String): Unit = { + if (peekChar() == c) { + nextChar() + } else if (peekChars(s.length) == s) { + require(s) + } else { + error(s"Expected either `$c` or `$s`") + } + } + + def sepOrEnd(sep: Char, end: Char): Boolean = { + val p = position + val c = nextChar() + if (c == sep) true + else if (c==end) false + else error(s"Expected `$sep` or `$end`", p) + } + + def anyOf[T](errorMsg: String, alternatives: (()=> T)*): T = { + alternatives.foreach { t => + val p = position + try { + return t() + } catch { + case _: ParseException => restorePosition(p) + } + } + error(errorMsg) + } + + def surrounded[T](left: => Any, content: => T, right: => Any): T = { + left + val result = content + right + content + } + + def followed[T](content: => T, right: => Any): T = { + val result = content + right + content + } + + def attempt[T](content: => T): Option[T] = { + val p = position + try { + Some(content) + } catch { + case _: ParseException => None + } + } + + def opaque[T](errorMsg: String)(block: =>T) :T={ + try { + block + } catch{ + case p:ParseException => error(errorMsg, p.position) + } + } +} diff --git a/src/main/scala/millfork/parser/SourceLoadingQueue.scala b/src/main/scala/millfork/parser/SourceLoadingQueue.scala new file mode 100644 index 00000000..cba2df00 --- /dev/null +++ b/src/main/scala/millfork/parser/SourceLoadingQueue.scala @@ -0,0 +1,89 @@ +package millfork.parser + +import java.nio.file.{Files, Paths} + +import fastparse.core.Parsed.{Failure, Success} +import millfork.CompilationOptions +import millfork.error.ErrorReporting +import millfork.node.{ImportStatement, Position, Program} + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ +class SourceLoadingQueue(val initialFilenames: List[String], val includePath: List[String], val options: CompilationOptions) { + + private val parsedModules = mutable.Map[String, Program]() + private val moduleQueue = mutable.Queue[() => Unit]() + val extension: String = ".ml" + + + def run(): Program = { + initialFilenames.foreach { i => + parseModule(extractName(i), includePath, Right(i), options) + } + options.platform.startingModules.foreach {m => + moduleQueue.enqueue(() => parseModule(m, includePath, Left(None), options)) + } + while (moduleQueue.nonEmpty) { + moduleQueue.dequeueAll(_ => true).par.foreach(_()) + } + ErrorReporting.assertNoErrors("Parse failed") + parsedModules.values.reduce(_ + _) + } + + def lookupModuleFile(includePath: List[String], moduleName: String, position: Option[Position]): String = { + includePath.foreach { dir => + val file = Paths.get(dir, moduleName + extension).toFile + ErrorReporting.debug("Checking " + file) + if (file.exists()) { + return file.getAbsolutePath + } + } + ErrorReporting.fatal(s"Module `$moduleName` not found", position) + } + + def parseModule(moduleName: String, includePath: List[String], why: Either[Option[Position], String], options: CompilationOptions): Unit = { + val filename: String = why.fold(p => lookupModuleFile(includePath, moduleName, p), s => s) + ErrorReporting.debug(s"Parsing $filename") + val path = Paths.get(filename) + val parentDir = path.toFile.getAbsoluteFile.getParent + val src = new String(Files.readAllBytes(path)) + val parser = MfParser(filename, src, parentDir, options) + parser.toAst match { + case Success(prog, _) => + parsedModules.synchronized { + parsedModules.put(moduleName, prog) + prog.declarations.foreach { + case s@ImportStatement(m) => + if (!parsedModules.contains(m)) { + moduleQueue.enqueue(() => parseModule(m, parentDir :: includePath, Left(s.position), options)) + } + case _ => () + } + } + case f@Failure(a, b, d) => + ErrorReporting.error(s"Failed to parse the module `$moduleName` in $filename", Some(parser.indexToPosition(f.index, parser.lastLabel))) +// ErrorReporting.error(a.toString) +// ErrorReporting.error(b.toString) +// ErrorReporting.error(d.toString) +// ErrorReporting.error(d.traced.expected) +// ErrorReporting.error(d.traced.stack.toString) +// ErrorReporting.error(d.traced.traceParsers.toString) +// ErrorReporting.error(d.traced.fullStack.toString) +// ErrorReporting.error(f.toString) + if (parser.lastLabel != "") { + ErrorReporting.error(s"Syntax error: ${parser.lastLabel} expected", Some(parser.lastPosition)) + } else { + ErrorReporting.error("Syntax error", Some(parser.lastPosition)) + } + } + } + + def extractName(i: String): String = { + val noExt = i.stripSuffix(extension) + val lastSlash = noExt.lastIndexOf('/') max noExt.lastIndexOf('\\') + if (lastSlash >= 0) i.substring(lastSlash + 1) else i + } +} diff --git a/src/main/scala/millfork/parser/TextCodec.scala b/src/main/scala/millfork/parser/TextCodec.scala new file mode 100644 index 00000000..cad9580b --- /dev/null +++ b/src/main/scala/millfork/parser/TextCodec.scala @@ -0,0 +1,32 @@ +package millfork.parser +import millfork.error.ErrorReporting +import millfork.node.Position + +/** + * @author Karol Stasiak + */ +class TextCodec(val name:String, private val map: String, private val extra: Map[Char,Int]) { + def decode(position: Option[Position], c: Char): Int = { + if (extra.contains(c)) extra(c) else { + val index = map.indexOf(c) + if (index >= 0) { + index + } else { + ErrorReporting.fatal("Invalid character in string in ") + } + } + } +} + +object TextCodec { + val NotAChar = '\ufffd' + + val Ascii = new TextCodec("ASCII", 0.until(127).map{i => if (i<32) NotAChar else i.toChar}.mkString, Map.empty) + + val Petscii = new TextCodec("PETSCII", + "\ufffd" * 32 + 0x20.to(0x3f).map(_.toChar).mkString + "@abcdefghijklmnopqrstuvwxyz[£]↑←–ABCDEFGHIJKLMNOPQRSTUVWXYZ", + Map('^' -> 0x5E, 'π' -> 0x7E) + ) + + +} diff --git a/src/test/java/com/grapeshot/halfnes/CPURAM.java b/src/test/java/com/grapeshot/halfnes/CPURAM.java new file mode 100644 index 00000000..1227b5a4 --- /dev/null +++ b/src/test/java/com/grapeshot/halfnes/CPURAM.java @@ -0,0 +1,75 @@ +package com.grapeshot.halfnes; + +import com.grapeshot.halfnes.mappers.Mapper; +import millfork.output.MemoryBank; + +/** + * Since the original CPURAM class was a convoluted mess of dependencies, + * I overrode it with mine that has only few pieces of junk glue to make it work + * @author Karol Stasiak + */ +@SuppressWarnings("unused") +public class CPURAM { + + private final MemoryBank mem; + + // required by the CPU class for some reason + public Mapper mapper = new Mapper() { + @Override + public TVType getTVType() { + // the base class returns null, but this can't be null + return TVType.DENDY; + } + }; + // required by the CPU class for some reason + public APU apu = new APU(null, null, this); + + public CPURAM(MemoryBank mem) { + boolean[] readable = mem.readable(); + boolean[] writeable = mem.writeable(); + for (int i = 0xfffe; i >= 0; i--) { + if (readable[i]) { + // allow for dummy fetches by implied instructions + readable[i + 1] = true; + } + } + readable[0] = true; + readable[1] = true; + readable[2] = true; + for (int i = 0x100; i <= 0x1ff; i++) { + readable[i] = true; + writeable[i] = true; + } + for (int i = 0x4000; i <= 0x407f; i++) { + readable[i] = true; + writeable[i] = true; + } + for (int i = 0xc000; i <= 0xcfff; i++) { + readable[i] = true; + writeable[i] = true; + } + for (int i = 0xfffa; i <= 0xffff; i++) { + readable[i] = true; + writeable[i] = true; + } + this.mem = mem; + } + + + public final int read(int addr) { + addr &= 0xffff; + if (!mem.readable()[addr]) { + throw new RuntimeException("Can't read from $" + Integer.toHexString(addr)); + } + return mem.output()[addr] & 0xff; + } + + public final void write(int addr, int data) { + addr &= 0xffff; + if (!mem.writeable()[addr]) { + throw new RuntimeException("Can't write to $" + Integer.toHexString(addr)); + } + mem.output()[addr] = (byte) data; + } + +} diff --git a/src/test/scala/millfork/test/ArraySuite.scala b/src/test/scala/millfork/test/ArraySuite.scala new file mode 100644 index 00000000..68bd25f4 --- /dev/null +++ b/src/test/scala/millfork/test/ArraySuite.scala @@ -0,0 +1,133 @@ +package millfork.test + +import millfork.{Cpu, OptimizationPresets} +import millfork.assembly.opt.{AlwaysGoodOptimizations, DangerousOptimizations} +import millfork.test.emu._ +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class ArraySuite extends FunSuite with Matchers { + + test("Array assignment") { + val m = EmuSuperOptimizedRun( + """ + | array output [3] @$c000 + | array input = [5,6,7] + | void main () { + | copyEntry(0) + | copyEntry(1) + | copyEntry(2) + | } + | void copyEntry(byte index) { + | output[index] = input[index] + | } + """.stripMargin) + m.readByte(0xc000) should equal(5) + m.readByte(0xc001) should equal(6) + m.readByte(0xc002) should equal(7) + + } + test("Array assignment with offset") { + EmuUltraBenchmarkRun( + """ + | array output [8] @$c000 + | void main () { + | byte i + | i = 0 + | while i != 6 { + | output[i + 2] = i + 1 + | output[i] = output[i] + | i += 1 + | } + | } + """.stripMargin) { m => + m.readByte(0xc002) should equal(1) + m.readByte(0xc007) should equal(6) + } + } + + test("Array assignment with offset 1") { + val m = new EmuRun(Cpu.StrictMos, Nil, DangerousOptimizations.All ++ OptimizationPresets.Good, true)( + """ + | array output [8] @$c000 + | void main () { + | byte i + | i = 0 + | while i != 6 { + | output[i + 2] = i + 1 + | output[i] = output[i] + | i += 1 + | } + | } + """.stripMargin) + m.readByte(0xc002) should equal(1) + m.readByte(0xc007) should equal(6) + } + + test("Array assignment through a pointer") { + val m = EmuUnoptimizedRun( + """ + | array output [3] @$c000 + | pointer p + | void main () { + | p = output.addr + | byte i + | byte ignored + | i = 1 + | word w + | w = $105 + | p[i]:ignored = w + | } + """.stripMargin) + m.readByte(0xc001) should equal(1) + + } + + test("Array in place math") { + EmuBenchmarkRun( + """ + | array output [4] @$c000 + | void main () { + | byte i + | i = 3 + | output[i] = 3 + | output[i + 1 - 1] *= 4 + | output[3] *= 5 + | } + """.stripMargin)(_.readByte(0xc003) should equal(60)) + } + + test("Array simple read") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | array a[7] + | void main () { + | byte i + | i = 6 + | a[i] = 6 + | output = a[i] + | } + """.stripMargin)(_.readByte(0xc000) should equal(6)) + } + + test("Array simple read 2") { + EmuBenchmarkRun( + """ + | word output @$c000 + | array a[7] + | void main () { + | output = 777 + | byte i + | i = 6 + | a[i] = 6 + | output = a[i] + | } + """.stripMargin){m => + m.readByte(0xc000) should equal(6) + m.readByte(0xc001) should equal(0) + } + } +} diff --git a/src/test/scala/millfork/test/AssemblyOptimizationSuite.scala b/src/test/scala/millfork/test/AssemblyOptimizationSuite.scala new file mode 100644 index 00000000..783fb8e0 --- /dev/null +++ b/src/test/scala/millfork/test/AssemblyOptimizationSuite.scala @@ -0,0 +1,281 @@ +package millfork.test + +import millfork.{Cpu, OptimizationPresets} +import millfork.assembly.opt.{AlwaysGoodOptimizations, LaterOptimizations, VariableToRegisterOptimization} +import millfork.test.emu.{EmuBenchmarkRun, EmuUltraBenchmarkRun, EmuRun} +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class AssemblyOptimizationSuite extends FunSuite with Matchers { + + test("Duplicate RTS") { + EmuBenchmarkRun( + """ + | void main () { + | if 1 == 1 { + | return + | } + | } + """.stripMargin) { _ => } + } + + test("Inlining variable") { + EmuBenchmarkRun( + """ + | array output [5] @$C000 + | void main () { + | byte i + | i = 0 + | while (i<5) { + | output[i] = i + | i += 1 + | } + | } + """.stripMargin)(_.readByte(0xc003) should equal(3)) + } + + test("Loading modified variables") { + EmuBenchmarkRun( + """ + | byte output @$C000 + | void main () { + | byte x + | output = 5 + | output += 1 + | output += 1 + | output += 1 + | x = output + | output = x + | } + """.stripMargin)(_.readByte(0xc000) should equal(8)) + } + + test("Bit ops") { + EmuBenchmarkRun( + """ + | byte output @$C000 + | void main () { + | output ^= output + | output |= 5 | 6 + | output |= 5 | 6 + | output &= 5 & 6 + | output ^= 8 ^ 16 + | } + """.stripMargin)(_.readByte(0xc000) should equal(28)) + } + + test("Inlining after a while") { + EmuBenchmarkRun( + """ + | array output [2]@$C000 + | void main () { + | byte i + | output[0] = 6 + | lol() + | i = 1 + | if (i > 0) { + | output[i] = 4 + | } + | } + | void lol() {} + """.stripMargin)(_.readWord(0xc000) should equal(0x406)) + } + + test("Tail call") { + EmuBenchmarkRun( + """ + | byte output @$C000 + | void main () { + | if (output != 55) { + | output += 1 + | main() + | } + | } + """.stripMargin)(_.readByte(0xc000) should equal(55)) + } + + test("LDA-TAY elimination") { + new EmuRun(Cpu.StrictMos, OptimizationPresets.NodeOpt, List(VariableToRegisterOptimization, AlwaysGoodOptimizations.YYY), false)( + """ + | array mouse_pointer[64] + | array arrow[64] + | byte output @$C000 + | void main () { + | byte i + | i = 0 + | while i < 63 { + | mouse_pointer[i] = arrow[i] + | i += 1 + | } + | } + """.stripMargin) + } + + test("Carry flag after AND-LSR") { + EmuUltraBenchmarkRun( + """ + | byte output @$C000 + | void main () { + | output = f(5) + | } + | byte f(byte x) { + | return ((x & $1E) >> 1) + 3 + | } + | + """.stripMargin)(_.readByte(0xc000) should equal(5)) + } + + test("Index sequence") { + EmuUltraBenchmarkRun( + """ + | array output[6] @$C000 + | void main () { + | pointer o + | o = output.addr + | o[3] = 8 + | o[4] = 8 + | o[5] = 8 + | } + | + """.stripMargin){m => + m.readByte(0xc005) should equal(8) + } + } + + test("Index switching") { + EmuUltraBenchmarkRun( + """ + | array output1[6] @$C000 + | array output2[6] @$C010 + | array input[6] @$C010 + | void main () { + | static byte a + | static byte b + | input[5] = 3 + | a = five() + | b = five() + | output1[a] = input[b] + | output2[a] = input[b] + | } + | byte five() { + | return 5 + | } + | + """.stripMargin){m => + m.readByte(0xc005) should equal(3) + m.readByte(0xc015) should equal(3) + } + } + + test("TAX-BCC-RTS-TXA optimization") { + new EmuRun(Cpu.StrictMos, + OptimizationPresets.NodeOpt, List( + AlwaysGoodOptimizations.PointlessStoreAfterLoad, + LaterOptimizations.PointlessLoadAfterStore, + VariableToRegisterOptimization, + LaterOptimizations.DoubleLoadToDifferentRegisters, + LaterOptimizations.DoubleLoadToTheSameRegister, + AlwaysGoodOptimizations.PointlessRegisterTransfers, + AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn, + AlwaysGoodOptimizations.PointlessLoadBeforeReturn, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + AlwaysGoodOptimizations.PointlessRegisterTransfers, + AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn, + AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn, + AlwaysGoodOptimizations.PointlessLoadBeforeReturn, + AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad, + AlwaysGoodOptimizations.PointlessStashingToIndexOverShortSafeBranch, + AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn, + AlwaysGoodOptimizations.IdempotentDuplicateRemoval, + AlwaysGoodOptimizations.IdempotentDuplicateRemoval, + AlwaysGoodOptimizations.IdempotentDuplicateRemoval), false)( + """ + | byte output @$C000 + | void main(){ delta() } + | byte delta () { + | output = 0 + | byte mouse_delta + | mouse_delta = 6 + | mouse_delta &= $3f + | if mouse_delta >= $20 { + | mouse_delta |= $c0 + | return 3 + | } + | return mouse_delta + | } + """.stripMargin).readByte(0xc000) should equal(0) + } + + test("Memory access detection"){ + EmuUltraBenchmarkRun( + """ + | array h [4] @$C000 + | array l [4] @$C404 + | word output @$C00C + | word a @$C200 + | void main () { + | byte i + | a = 0x102 + | barrier() + | for i,0,until,4 { + | h[i]:l[i] = a + | } + | a.lo:a.hi=a + | output = a + | } + | void barrier (){} + | + """.stripMargin){m => + m.readByte(0xc000) should equal(1) + m.readByte(0xc001) should equal(1) + m.readByte(0xc002) should equal(1) + m.readByte(0xc003) should equal(1) + m.readByte(0xc404) should equal(2) + m.readByte(0xc405) should equal(2) + m.readByte(0xc406) should equal(2) + m.readByte(0xc407) should equal(2) + m.readWord(0xc00c) should equal(0x201) + } + } + + test("Memory access detection 2"){ + EmuUltraBenchmarkRun( + """ + | array h [4] + | array l [4] + | word output @$C00C + | word ptrh @$C000 + | word ptrl @$C002 + | void main () { + | ptrh = h.addr + | ptrl = l.addr + | byte i + | word a + | a = 0x102 + | barrier() + | for i,0,until,4 { + | h[i]:l[i] = a + | } + | a.lo:a.hi=a + | output = a + | couput + | } + | void barrier (){} + | + """.stripMargin){m => + val ptrh = 0xffff & m.readWord(0xC000) + val ptrl = 0xffff & m.readWord(0xC002) + m.readByte(ptrh + 0) should equal(1) + m.readByte(ptrh + 1) should equal(1) + m.readByte(ptrh + 2) should equal(1) + m.readByte(ptrh + 3) should equal(1) + m.readByte(ptrl + 0) should equal(2) + m.readByte(ptrl + 1) should equal(2) + m.readByte(ptrl + 2) should equal(2) + m.readByte(ptrl + 3) should equal(2) + m.readWord(0xc00c) should equal(0x201) + } + } +} diff --git a/src/test/scala/millfork/test/AssemblySuite.scala b/src/test/scala/millfork/test/AssemblySuite.scala new file mode 100644 index 00000000..1c7acd5c --- /dev/null +++ b/src/test/scala/millfork/test/AssemblySuite.scala @@ -0,0 +1,99 @@ +package millfork.test +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class AssemblySuite extends FunSuite with Matchers { + + test("Inline assembly") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = 0 + | asm { + | inc $c000 + | } + | } + """.stripMargin)(_.readByte(0xc000) should equal(1)) + } + + test("Assembly functions") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = 0 + | thing() + | } + | asm void thing() { + | inc $c000 + | rts + | } + """.stripMargin)(_.readByte(0xc000) should equal(1)) + } + + test("Empty assembly") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = 1 + | asm {} + | } + """.stripMargin)(_.readByte(0xc000) should equal(1)) + } + + test("Passing params to assembly") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = f(5) + | } + | asm byte f(byte a) { + | clc + | adc #5 + | rts + | } + """.stripMargin)(_.readByte(0xc000) should equal(10)) + } + + test("Inline asm functions") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = 0 + | f() + | f() + | } + | inline asm void f() { + | inc $c000 + | rts + | } + """.stripMargin)(_.readByte(0xc000) should equal(1)) + } + + test("Inline asm functions 2") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = 0 + | add(output, 5) + | add(output, 5) + | } + | inline asm void add(byte ref v, byte const c) { + | lda v + | clc + | adc #c + | sta v + | rts + | } + """.stripMargin)(_.readByte(0xc000) should equal(5)) + } + +} diff --git a/src/test/scala/millfork/test/BasicSymonTest.scala b/src/test/scala/millfork/test/BasicSymonTest.scala new file mode 100644 index 00000000..ddf0f6a3 --- /dev/null +++ b/src/test/scala/millfork/test/BasicSymonTest.scala @@ -0,0 +1,28 @@ +package millfork.test + +import millfork.test.emu.EmuUnoptimizedRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class BasicSymonTest extends FunSuite with Matchers { + test("Empty test") { + EmuUnoptimizedRun( + """ + | void main () { + | + | } + """.stripMargin) + } + + test("Byte assignment") { + EmuUnoptimizedRun( + """ + | byte output @$c000 + | void main () { + | output = (1) + | } + """.stripMargin).readByte(0xc000) should equal(1) + } +} diff --git a/src/test/scala/millfork/test/BitOpSuite.scala b/src/test/scala/millfork/test/BitOpSuite.scala new file mode 100644 index 00000000..6156483d --- /dev/null +++ b/src/test/scala/millfork/test/BitOpSuite.scala @@ -0,0 +1,36 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class BitOpSuite extends FunSuite with Matchers { + + test("Word AND") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | byte b + | output = $5555 + | b = $4E + | output &= b + | } + """.stripMargin)(_.readWord(0xc000) should equal(0x44)) + } + test("Long AND and EOR") { + EmuBenchmarkRun(""" + | long output @$c000 + | void main () { + | byte b + | word w + | output = $55555555 + | w = $505 + | output ^= w + | b = $4E + | output &= b + | } + """.stripMargin)(_.readLong(0xc000) should equal(0x40)) + } +} diff --git a/src/test/scala/millfork/test/BooleanSuite.scala b/src/test/scala/millfork/test/BooleanSuite.scala new file mode 100644 index 00000000..f9bf8b6d --- /dev/null +++ b/src/test/scala/millfork/test/BooleanSuite.scala @@ -0,0 +1,64 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class BooleanSuite extends FunSuite with Matchers { + + test("Not") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | array input = [5,6,7] + | void main () { + | byte a + | a = 5 + | if not(a < 3) {output = 4} + | if not(a > 3) {output = 3} + | } + """.stripMargin)(_.readByte(0xc000) should equal(4)) + + } + + + test("And") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | array input = [5,6,7] + | void main () { + | byte a + | byte b + | a = 5 + | b = 5 + | if a > 3 && b > 3 {output = 4} + | if a < 3 && b > 3 {output = 5} + | if a > 3 && b < 3 {output = 2} + | if a < 3 && b < 3 {output = 3} + | } + """.stripMargin)(_.readByte(0xc000) should equal(4)) + } + + + test("Or") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | array input = [5,6,7] + | void main () { + | byte a + | byte b + | a = 5 + | b = 5 + | output = 0 + | if a > 3 || b > 3 {output += 4} + | if a < 3 || b > 3 {output += 5} + | if a > 3 || b < 3 {output += 2} + | if a < 3 || b < 3 {output = 30} + | } + """.stripMargin)(_.readByte(0xc000) should equal(11)) + } +} diff --git a/src/test/scala/millfork/test/ByteDecimalMathSuite.scala b/src/test/scala/millfork/test/ByteDecimalMathSuite.scala new file mode 100644 index 00000000..d0d8afae --- /dev/null +++ b/src/test/scala/millfork/test/ByteDecimalMathSuite.scala @@ -0,0 +1,68 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class ByteDecimalMathSuite extends FunSuite with Matchers { + + test("Decimal byte addition") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | byte a + | void main () { + | a = $36 + | output = a +' a + | } + """.stripMargin)(_.readByte(0xc000) should equal(0x72)) + } + + test("Decimal byte addition 2") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | byte a + | void main () { + | a = 1 + | output = a +' $69 + | } + """.stripMargin)(_.readByte(0xc000) should equal(0x70)) + } + + test("In-place decimal byte addition") { + EmuBenchmarkRun( + """ + | array output[3] @$c000 + | byte a + | void main () { + | a = 1 + | output[1] = 5 + | output[a] +'= 1 + | output[a] +'= $36 + | } + """.stripMargin)(_.readByte(0xc001) should equal(0x42)) + } + + test("In-place decimal byte addition 2") { + EmuBenchmarkRun( + """ + | array output[3] @$c000 + | void main () { + | byte x + | byte y + | byte tmpx + | byte tmpy + | tmpx = one() + | tmpy = one() + | x = tmpx + | y = tmpy + | output[y] = $39 + | output[x] +'= 1 + | } + | byte one() { return 1 } + """.stripMargin)(_.readByte(0xc001) should equal(0x40)) + } +} diff --git a/src/test/scala/millfork/test/ByteMathSuite.scala b/src/test/scala/millfork/test/ByteMathSuite.scala new file mode 100644 index 00000000..067743e2 --- /dev/null +++ b/src/test/scala/millfork/test/ByteMathSuite.scala @@ -0,0 +1,158 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class ByteMathSuite extends FunSuite with Matchers { + + test("Complex expression") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = (one() + one()) | (((one()<<2)-1) ^ one()) + | } + | byte one() { + | return 1 + | } + """.stripMargin)(_.readByte(0xc000) should equal(2)) + } + + test("Byte addition") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | byte a + | void main () { + | a = 1 + | output = a + a + | } + """.stripMargin)(_.readByte(0xc000) should equal(2)) + } + + test("Byte addition 2") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | byte a + | void main () { + | a = 1 + | output = a + 65 + | } + """.stripMargin)(_.readByte(0xc000) should equal(66)) + } + + test("In-place byte addition") { + EmuBenchmarkRun( + """ + | array output[3] @$c000 + | byte a + | void main () { + | a = 1 + | output[1] = 5 + | output[a] += 1 + | output[a] += 36 + | } + """.stripMargin)(_.readByte(0xc001) should equal(42)) + } + + test("Parameter order") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | array arr[6] + | void main () { + | output = 42 + | } + | byte test1(byte a) @$6000 { + | return 5 + a + | } + | byte test2(byte a) @$6100 { + | return 5 | a + | } + | byte test3(byte a) @$6200 { + | return a + arr[a] + | } + """.stripMargin)(_.readByte(0xc000) should equal(42)) + } + + test("In-place byte addition 2") { + EmuBenchmarkRun( + """ + | array output[3] @$c000 + | void main () { + | byte x + | byte y + | byte tmpx + | byte tmpy + | tmpx = one() + | tmpy = one() + | x = tmpx + | y = tmpy + | output[y] = 36 + | output[x] += 1 + | } + | byte one() { return 1 } + """.stripMargin)(_.readByte(0xc001) should equal(37)) + } + + test("In-place byte multiplication") { + multiplyCase1(0, 0) + multiplyCase1(0, 1) + multiplyCase1(0, 2) + multiplyCase1(0, 5) + multiplyCase1(1, 0) + multiplyCase1(5, 0) + multiplyCase1(7, 0) + multiplyCase1(2, 5) + multiplyCase1(7, 2) + multiplyCase1(100, 2) + multiplyCase1(54, 4) + multiplyCase1(2, 100) + multiplyCase1(4, 54) + } + + private def multiplyCase1(x: Int, y: Int): Unit = { + EmuBenchmarkRun( + s""" + | byte output @$$c000 + | void main () { + | output = $x + | output *= $y + | } + """. + stripMargin)(_.readByte(0xc000) should equal(x * y)) + } + + test("Byte multiplication") { + multiplyCase2(0, 0) + multiplyCase2(0, 1) + multiplyCase2(0, 2) + multiplyCase2(0, 5) + multiplyCase2(1, 0) + multiplyCase2(5, 0) + multiplyCase2(7, 0) + multiplyCase2(2, 5) + multiplyCase2(7, 2) + multiplyCase2(100, 2) + multiplyCase2(54, 4) + multiplyCase2(2, 100) + multiplyCase2(4, 54) + } + + private def multiplyCase2(x: Int, y: Int): Unit = { + EmuBenchmarkRun( + s""" + | byte output @$$c000 + | void main () { + | byte a + | a = $x + | output = a * $y + | } + """. + stripMargin)(_.readByte(0xc000) should equal(x * y)) + } +} diff --git a/src/test/scala/millfork/test/CmosSuite.scala b/src/test/scala/millfork/test/CmosSuite.scala new file mode 100644 index 00000000..23be132c --- /dev/null +++ b/src/test/scala/millfork/test/CmosSuite.scala @@ -0,0 +1,31 @@ +package millfork.test + +import millfork.test.emu.EmuCmosBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class CmosSuite extends FunSuite with Matchers { + + test("Zero store 1") { + EmuCmosBenchmarkRun(""" + | word output @$c000 + | void main () { + | output = 1 + | output = 0 + | } + """.stripMargin)(_.readWord(0xc000) should equal(0)) + } + test("Zero store 2") { + EmuCmosBenchmarkRun(""" + | byte output @$c000 + | void main () { + | output = 1 + | output = 0 + | output += 1 + | output <<= 1 + | } + """.stripMargin)(_.readWord(0xc000) should equal(2)) + } +} diff --git a/src/test/scala/millfork/test/ComparisonSuite.scala b/src/test/scala/millfork/test/ComparisonSuite.scala new file mode 100644 index 00000000..5b849d3b --- /dev/null +++ b/src/test/scala/millfork/test/ComparisonSuite.scala @@ -0,0 +1,264 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class ComparisonSuite extends FunSuite with Matchers { + + test("Equality and inequality") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = 5 + | if (output == 5) { + | output += 1 + | } else { + | output +=2 + | } + | if (output != 6) { + | output += 78 + | } + | } + """.stripMargin)(_.readWord(0xc000) should equal(6)) + } + + test("Less") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = 5 + | while output < 150 { + | output += 1 + | } + | } + """.stripMargin)(_.readWord(0xc000) should equal(150)) + } + + test("Compare to zero") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | byte a + | a = 150 + | while a != 0 { + | a -= 1 + | output += 1 + | } + | } + """.stripMargin)(_.readWord(0xc000) should equal(150)) + } + + test("Carry flag optimization test") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | byte a + | a = 150 + | if (a >= 50) { + | output = 4 + | } else { + | output = 0 + | } + | output += get(55) + | } + | byte get(byte x) { + | if x >= 6 {return 0} else {return 128} + | } + """.stripMargin)(_.readWord(0xc000) should equal(4)) + } + + test("Does it even work") { + EmuBenchmarkRun( + """ + | word output @$c000 + | void main () { + | byte a + | a = 150 + | if a != 0 { + | output = 345 + | } + | } + """.stripMargin)(_.readWord(0xc000) should equal(345)) + } + + test("Word comparison constant") { + val src = + """ + | byte output @$c000 + | void main () { + | output = 0 + | if 2222 == 2222 { output += 1 } + | if 2222 == 3333 { output -= 1 } + | } + """.stripMargin + EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+'))) + } + + test("Word comparison == and !=") { + val src = + """ + | byte output @$c000 + | void main () { + | word a + | word b + | word c + | output = 0 + | a = 4 + | b = 4 + | c = 5 + | if a == 4 { output += 1 } + | if a == b { output += 1 } + | if a != c { output += 1 } + | if a != 5 { output += 1 } + | if a != 260 { output += 1 } + | } + """.stripMargin + EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+'))) + } + + test("Word comparison <=") { + val src = + """ + | byte output @$c000 + | void main () { + | word a + | word b + | word c + | output = 0 + | a = 4 + | b = 4 + | c = 5 + | if a <= 4 { output += 1 } + | if a <= 257 { output += 1 } + | if a <= b { output += 1 } + | if a <= c { output += 1 } + | } + """.stripMargin + EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+'))) + } + test("Word comparison <") { + val src = + """ + | byte output @$c000 + | void main () { + | word a + | word b + | word c + | output = 0 + | a = 4 + | b = 4 + | c = 5 + | if a < 5 { output += 1 } + | if a < c { output += 1 } + | if a < 257 { output += 1 } + | } + """.stripMargin + EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+'))) + } + + + test("Word comparison >") { + val src = + """ + | byte output @$c000 + | void main () { + | word a + | word b + | word c + | output = 0 + | a = 4 + | b = 4 + | c = 5 + | if c > a { output += 1 } + | if c > 1 { output += 1 } + | if c > 0 { output += 1 } + | } + """.stripMargin + EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+'))) + } + + test("Word comparison >=") { + val src = + """ + | byte output @$c000 + | void main () { + | word a + | word b + | word c + | output = 0 + | a = 4 + | b = 4 + | c = 5 + | if c >= 1 { output += 1 } + | if c >= a { output += 1 } + | if a >= a { output += 1 } + | if a >= 4 { output += 1 } + | if a >= 4 { output += 1 } + | if a >= 0 { output += 1 } + | } + """.stripMargin + EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+'))) + } + + test("Signed comparison >=") { + val src = + """ + | byte output @$c000 + | void main () { + | sbyte a + | sbyte b + | sbyte c + | output = 0 + | a = 4 + | b = 4 + | c = 5 + | if c >= 1 { output += 1 } + | if c >= a { output += 1 } + | if a >= a { output += 1 } + | if a >= 4 { output += 1 } + | if a >= 4 { output += 1 } + | if a >= 0 { output += 1 } + | } + """.stripMargin + EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+'))) + } + + test("Signed comparison < and <=") { + val src = + """ + | byte output @$c000 + | void main () { + | sbyte a + | sbyte b + | sbyte c + | output = 0 + | a = -1 + | b = 0 + | c = 1 + | if a < 0 { output += 1 } + | if b < 0 { output -= 7 } + | if c < 0 { output -= 7 } + | if a < 1 { output += 1 } + | if b < 1 { output += 1 } + | if c < 1 { output -= 7 } + | if a <= 0 { output += 1 } + | if b <= 0 { output += 1 } + | if c <= 0 { output -= 7 } + | if a <= 1 { output += 1 } + | if b <= 1 { output += 1 } + | if c <= 1 { output += 1 } + | if a <= -1 { output += 1 } + | if b <= -1 { output -= 7 } + | if c <= -1 { output -= 7 } + | } + """.stripMargin + EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+'))) + } +} diff --git a/src/test/scala/millfork/test/ErasthotenesSuite.scala b/src/test/scala/millfork/test/ErasthotenesSuite.scala new file mode 100644 index 00000000..34228b25 --- /dev/null +++ b/src/test/scala/millfork/test/ErasthotenesSuite.scala @@ -0,0 +1,37 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class ErasthotenesSuite extends FunSuite with Matchers { + + test("Erasthotenes") { + EmuBenchmarkRun( + """ + | const pointer sieve = $C000 + | const byte sqrt = 128 + | void main () { + | byte i + | word j + | pointer s + | i = 2 + | while i < sqrt { + | if sieve[i] == 0 { + | j = i << 1 + | s = sieve + | s += j + | while j.hi == 0 { + | s[0] = 1 + | s += i + | j += i + | } + | } + | i += 1 + | } + | } + """.stripMargin){_=>} + } +} diff --git a/src/test/scala/millfork/test/ForLoopSuite.scala b/src/test/scala/millfork/test/ForLoopSuite.scala new file mode 100644 index 00000000..bb9d9929 --- /dev/null +++ b/src/test/scala/millfork/test/ForLoopSuite.scala @@ -0,0 +1,76 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class ForLoopSuite extends FunSuite with Matchers { + + test("For-to") { + EmuBenchmarkRun( + """ + | word output @$c000 + | void main () { + | byte i + | output = 0 + | for i,0,to,5 { + | output += i + | } + | } + """.stripMargin)(_.readByte(0xc000) should equal(15)) + } + test("For-downto") { + EmuBenchmarkRun( + """ + | word output @$c000 + | void main () { + | byte i + | output = 0 + | for i,5,downto,0 { + | output += i + | } + | } + """.stripMargin)(_.readByte(0xc000) should equal(15)) + } + test("For-until") { + EmuBenchmarkRun( + """ + | word output @$c000 + | void main () { + | byte i + | output = 0 + | for i,0,until,6 { + | output += i + | } + | } + """.stripMargin)(_.readByte(0xc000) should equal(15)) + } + test("For-parallelto") { + EmuBenchmarkRun( + """ + | word output @$c000 + | void main () { + | byte i + | output = 0 + | for i,0,parallelto,5 { + | output += i + | } + | } + """.stripMargin)(_.readByte(0xc000) should equal(15)) + } + test("For-paralleluntil") { + EmuBenchmarkRun( + """ + | word output @$c000 + | void main () { + | byte i + | output = 0 + | for i,0,paralleluntil,6 { + | output += i + | } + | } + """.stripMargin)(_.readByte(0xc000) should equal(15)) + } +} diff --git a/src/test/scala/millfork/test/IllegalSuite.scala b/src/test/scala/millfork/test/IllegalSuite.scala new file mode 100644 index 00000000..36e8d312 --- /dev/null +++ b/src/test/scala/millfork/test/IllegalSuite.scala @@ -0,0 +1,51 @@ +package millfork.test + +import millfork.test.emu.EmuUndocumentedRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class IllegalSuite extends FunSuite with Matchers { + + test("ALR test 1") { + EmuUndocumentedRun(""" + | byte output @$c000 + | byte input @$cfff + | void main () { + | output = (input & 0x45) >> 1 + | } + """.stripMargin) + } + + test("ALR test 2") { + EmuUndocumentedRun(""" + | byte output @$c000 + | byte input @$cfff + | void main () { + | output = (input & 0x45) >> 1 + | } + """.stripMargin) + } + + test("ISC/DCP test") { + val m = EmuUndocumentedRun(""" + | array output [10] @$c000 + | void main () { + | stack byte a + | output[5] = 36 + | output[7] = 52 + | five() + | a = 5 + | five() + | output[a] += 1 + | output[a + 2] -= 1 + | } + | byte five () { + | return 5 + | } + """.stripMargin) + m.readByte(0xc005) should equal(37) + m.readByte(0xc007) should equal(51) + } +} diff --git a/src/test/scala/millfork/test/InlineAssemblyFunctionsSuite.scala b/src/test/scala/millfork/test/InlineAssemblyFunctionsSuite.scala new file mode 100644 index 00000000..2d62fe33 --- /dev/null +++ b/src/test/scala/millfork/test/InlineAssemblyFunctionsSuite.scala @@ -0,0 +1,80 @@ +package millfork.test + +import millfork.assembly.opt.DangerousOptimizations +import millfork.test.emu.EmuBenchmarkRun +import millfork.{Cpu, OptimizationPresets} +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class InlineAssemblyFunctionsSuite extends FunSuite with Matchers { + + test("Poke test 1") { + EmuBenchmarkRun( + """ + | inline asm void poke(word ref addr, byte const value) { + | ?LDA #value + | STA addr + | } + | + | byte output @$c000 + | void main () { + | poke(output, 5) + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(5) + } + } + test("Peek test 1") { + EmuBenchmarkRun( + """ + | inline asm byte peek(word ref addr) { + | ?LDA addr + | } + | + | byte output @$c000 + | void main () { + | byte a + | a = 5 + | output = peek(a) + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(5) + } + } + test("Poke test 2") { + EmuBenchmarkRun( + """ + | inline asm void poke(word const addr, byte const value) { + | ?LDA #value + | STA addr + | } + | + | byte output @$c000 + | void main () { + | poke($c000, 5) + | poke($c001, 5) + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(5) + } + } + test("Peek test 2") { + EmuBenchmarkRun( + """ + | inline asm byte peek(word const addr) { + | ?LDA addr + | } + | + | byte output @$c000 + | void main () { + | byte a + | a = 5 + | output = peek(a.addr) + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(5) + } + } +} diff --git a/src/test/scala/millfork/test/LongTest.scala b/src/test/scala/millfork/test/LongTest.scala new file mode 100644 index 00000000..129bb24f --- /dev/null +++ b/src/test/scala/millfork/test/LongTest.scala @@ -0,0 +1,160 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class LongTest extends FunSuite with Matchers { + + test("Long assignment") { + EmuBenchmarkRun( + """ + | long output4 @$c000 + | long output2 @$c004 + | long output1 @$c008 + | void main () { + | output4 = $11223344 + | output2 = $11223344 + | output1 = $11223344 + | output2 = $7788 + | output1 = $55 + | } + """.stripMargin) { m => + m.readLong(0xc000) should equal(0x11223344) + m.readLong(0xc004) should equal(0x7788) + m.readLong(0xc008) should equal(0x55) + } + } + test("Long assignment 2") { + EmuBenchmarkRun( + """ + | long output4 @$c000 + | long output2 @$c004 + | word output1 @$c008 + | void main () { + | word w + | byte b + | w = $7788 + | b = $55 + | output4 = $11223344 + | output2 = $11223344 + | output1 = $11223344 + | output2 = w + | output1 = b + | } + """.stripMargin) { m => + m.readLong(0xc000) should equal(0x11223344) + m.readLong(0xc004) should equal(0x7788) + m.readLong(0xc008) should equal(0x55) + } + } + test("Long addition") { + EmuBenchmarkRun( + """ + | long output @$c000 + | void main () { + | word w + | long l + | byte b + | w = $8000 + | b = $8 + | l = $50000 + | output = 0 + | output += l + | output += w + | output += b + | } + """.stripMargin) { m => + m.readLong(0xc000) should equal(0x58008) + } + } + test("Long addition 2") { + EmuBenchmarkRun( + """ + | long output @$c000 + | void main () { + | output = 0 + | output += $50000 + | output += $8000 + | output += $8 + | } + """.stripMargin) { m => + m.readLong(0xc000) should equal(0x58008) + } + } + test("Long subtraction") { + EmuBenchmarkRun( + """ + | long output @$c000 + | void main () { + | word w + | long l + | byte b + | w = $8000 + | b = $8 + | l = $50000 + | output = $58008 + | output -= l + | output -= w + | output -= b + | } + """.stripMargin) { m => + m.readLong(0xc000) should equal(0) + } + } + test("Long subtraction 2") { + EmuBenchmarkRun( + """ + | long output @$c000 + | void main () { + | output = $58008 + | output -= $50000 + | output -= $8000 + | output -= $8 + | } + """.stripMargin) { m => + m.readLong(0xc000) should equal(0) + } + } + test("Long subtraction 3") { + EmuBenchmarkRun( + """ + | long output @$c000 + | void main () { + | output = $58008 + | output -= w() + | output -= b() + | } + | byte b() { + | return $8 + | } + | word w() { + | return $8000 + | } + """.stripMargin) { m => + m.readLong(0xc000) should equal(0x50000) + } + } + + test("Long AND") { + EmuBenchmarkRun( + """ + | long output @$c000 + | void main () { + | output = $FFFFFF + | output &= w() + | output &= b() + | } + | byte b() { + | return $77 + | } + | word w() { + | return $CCCC + | } + """.stripMargin) { m => + m.readLong(0xc000) should equal(0x44) + } + } +} diff --git a/src/test/scala/millfork/test/MinimalTest.scala b/src/test/scala/millfork/test/MinimalTest.scala new file mode 100644 index 00000000..e82c1346 --- /dev/null +++ b/src/test/scala/millfork/test/MinimalTest.scala @@ -0,0 +1,23 @@ +package millfork.test + +import fastparse.core.Mutable.Failure +import fastparse.core.Parsed.Success +import millfork.parser.MinimalTestCase +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class MinimalTest extends FunSuite with Matchers { +// ignore("a") { +// MinimalTestCase.program.parse("[]\n[a,g,%h]\n") match { +// case Success(unoptimized, _) => +// println(unoptimized) +// case f:Failure[_,_] => +// val g =f +// println(f) +// println(f.originalParser) +// fail() +// } +// } +} diff --git a/src/test/scala/millfork/test/NodeOptimizationSuite.scala b/src/test/scala/millfork/test/NodeOptimizationSuite.scala new file mode 100644 index 00000000..13f196a7 --- /dev/null +++ b/src/test/scala/millfork/test/NodeOptimizationSuite.scala @@ -0,0 +1,32 @@ +package millfork.test + +import millfork.test.emu.EmuNodeOptimizedRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class NodeOptimizationSuite extends FunSuite with Matchers { + + test("Unreachable after return") { + EmuNodeOptimizedRun( + """ + | byte crash @$ffff + | void main () { + | return + | crash = 2 + | } + """.stripMargin) + } + + test("Unused local variable") { + EmuNodeOptimizedRun( + """ + | byte crash @$ffff + | void main () { + | byte a + | a = crash + | } + """.stripMargin) + } +} diff --git a/src/test/scala/millfork/test/NonetSuite.scala b/src/test/scala/millfork/test/NonetSuite.scala new file mode 100644 index 00000000..f78f055e --- /dev/null +++ b/src/test/scala/millfork/test/NonetSuite.scala @@ -0,0 +1,29 @@ +package millfork.test + +import millfork.test.emu.EmuUnoptimizedRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class NonetSuite extends FunSuite with Matchers { + + test("Nonet operations") { + val m = EmuUnoptimizedRun( + """ + | array output [3] @$c000 + | array input = [5,6,7] + | void main () { + | word a + | a = $110 + | output[1] = a >>>> 1 + | output[2] = a >>>> 2 + | } + | void copyEntry(byte index) { + | output[index] = input[index] + | } + """.stripMargin) + m.readByte(0xc001) should equal(0x88) + m.readByte(0xc002) should equal(0x44) + } +} diff --git a/src/test/scala/millfork/test/SeparateBytesSuite.scala b/src/test/scala/millfork/test/SeparateBytesSuite.scala new file mode 100644 index 00000000..b6e2dd1a --- /dev/null +++ b/src/test/scala/millfork/test/SeparateBytesSuite.scala @@ -0,0 +1,137 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class SeparateBytesSuite extends FunSuite with Matchers { + + test("Separate assignment 1") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | output = 2:3 + | } + """.stripMargin)(_.readWord(0xc000) should equal(0x203)) + } + + test("Separate assignment 2") { + EmuBenchmarkRun(""" + | byte output @$c000 + | byte ignore @$c001 + | void main () { + | word w + | w = $355 + | output:ignore = w + | } + """.stripMargin)(_.readByte(0xc000) should equal(3)) + } + + test("Separate assignment 3") { + EmuBenchmarkRun(""" + | byte output @$c000 + | byte ignore @$c001 + | void main () { + | output:ignore = lol() + | } + | word lol() { + | return $567 + | } + """.stripMargin)(_.readByte(0xc000) should equal(5)) + } + + test("Separate assignment 4") { + EmuBenchmarkRun(""" + | array output [5] @$c000 + | byte ignore @$c001 + | void main () { + | byte index + | index = 3 + | output[index]:ignore = lol() + | } + | word lol() { + | return $567 + | } + """.stripMargin)(_.readByte(0xc003) should equal(5)) + } + + test("Separate assignment 5") { + EmuBenchmarkRun(""" + | array output [5] @$c000 + | byte ignore @$c001 + | void main () { + | byte index + | index = 3 + | ignore:output[index] = lol() + | } + | word lol() { + | return $567 + | } + """.stripMargin)(_.readByte(0xc003) should equal(0x67)) + } + + test("Magic split array") { + EmuBenchmarkRun(""" + | array hi [16] @$c000 + | array lo [16] @$c010 + | void main () { + | word a + | word b + | word tmp + | a = 1 + | b = 1 + | byte i + | i = 0 + | while i < 16 { + | hi[i]:lo[i] = a + | tmp = a + | tmp += b + | a = b + | b = tmp + | i += 1 + | } + | } + """.stripMargin) { m=> + m.readWord(0xc000, 0xc010) should equal(1) + m.readWord(0xc001, 0xc011) should equal(1) + m.readWord(0xc002, 0xc012) should equal(2) + m.readWord(0xc003, 0xc013) should equal(3) + m.readWord(0xc004, 0xc014) should equal(5) + m.readWord(0xc005, 0xc015) should equal(8) + m.readWord(0xc006, 0xc016) should equal(13) + m.readWord(0xc007, 0xc017) should equal(21) + } + } + + test("Separate addition") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | byte h + | byte l + | h = 6 + | l = 5 + | output = $101 + | output += h:l + | } + """.stripMargin)(_.readWord(0xc000) should equal(0x706)) + } + + test("Separate increase") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | byte h + | byte l + | h = 6 + | l = 5 + | (h:l) += $101 + | output = h:l + | (h:l) += 1 + | output = h:l + | } + """.stripMargin)(_.readWord(0xc000) should equal(0x707)) + } +} diff --git a/src/test/scala/millfork/test/ShiftSuite.scala b/src/test/scala/millfork/test/ShiftSuite.scala new file mode 100644 index 00000000..68ab499b --- /dev/null +++ b/src/test/scala/millfork/test/ShiftSuite.scala @@ -0,0 +1,63 @@ +package millfork.test +import millfork.test.emu.{EmuBenchmarkRun, EmuUnoptimizedRun} +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class ShiftSuite extends FunSuite with Matchers { + + test("In-place shifting") { + EmuUnoptimizedRun(""" + | array output [3] @$c000 + | void main () { + | output[0] = 1 + | output[1] = 3 + | output[output[0]] <<= 2 + | } + """.stripMargin).readByte(0xc001) should equal(12) + } + + test("Byte shifting") { + EmuBenchmarkRun(""" + | byte output @$c000 + | void main () { + | byte a + | a = 3 + | output = a << 2 + | } + """.stripMargin)(_.readByte(0xc000) should equal(12)) + } + + test("Word shifting") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | byte a + | a = 3 + | output = a + | output <<= 7 + | } + """.stripMargin)(_.readWord(0xc000) should equal(128 * 3)) + } + + test("Long shifting left") { + EmuBenchmarkRun(""" + | long output @$c000 + | void main () { + | output = $1010301 + | output <<= 2 + | } + """.stripMargin)(_.readLong(0xc000) should equal(0x4040C04)) + } + + test("Long shifting right") { + EmuBenchmarkRun(""" + | long output @$c000 + | void main () { + | output = $4040C04 + | output >>= 2 + | } + """.stripMargin)(_.readLong(0xc000) should equal(0x1010301)) + } +} diff --git a/src/test/scala/millfork/test/SignExtensionSuite.scala b/src/test/scala/millfork/test/SignExtensionSuite.scala new file mode 100644 index 00000000..993e8b7c --- /dev/null +++ b/src/test/scala/millfork/test/SignExtensionSuite.scala @@ -0,0 +1,44 @@ +package millfork.test + +import millfork.test.emu.EmuUnoptimizedRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class SignExtensionSuite extends FunSuite with Matchers { + + test("Sbyte to Word") { + EmuUnoptimizedRun(""" + | word output @$c000 + | void main () { + | sbyte b + | b = -1 + | output = b + | } + """.stripMargin).readWord(0xc000) should equal(0xffff) + } + test("Sbyte to Word 2") { + EmuUnoptimizedRun(""" + | word output @$c000 + | void main () { + | output = b() + | } + | sbyte b() { + | return -1 + | } + """.stripMargin).readWord(0xc000) should equal(0xffff) + } + test("Sbyte to Long") { + EmuUnoptimizedRun(""" + | long output @$c000 + | void main () { + | output = 421 + | output += b() + | } + | sbyte b() { + | return -1 + | } + """.stripMargin).readLong(0xc000) should equal(420) + } +} diff --git a/src/test/scala/millfork/test/StackVarSuite.scala b/src/test/scala/millfork/test/StackVarSuite.scala new file mode 100644 index 00000000..8e5051f0 --- /dev/null +++ b/src/test/scala/millfork/test/StackVarSuite.scala @@ -0,0 +1,142 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class StackVarSuite extends FunSuite with Matchers { + + test("Basic stack assignment") { + EmuBenchmarkRun(""" + | byte output @$c000 + | void main () { + | stack byte a + | stack byte b + | b = 4 + | a = b + | output = a + | a = output + | } + """.stripMargin)(_.readByte(0xc000) should equal(4)) + } + + test("Stack byte addition") { + EmuBenchmarkRun(""" + | byte output @$c000 + | void main () { + | stack byte a + | stack byte b + | a = $11 + | b = $44 + | b += zzz() + | b += a + | output = b + | } + | byte zzz() { + | return $22 + | } + """.stripMargin)(_.readWord(0xc000) should equal(0x77)) + } + + test("Complex expressions involving stack variables") { + EmuBenchmarkRun(""" + | byte output @$c000 + | void main () { + | stack byte a + | a = 7 + | output = f(a) + f(a) + f(a) + | } + | asm byte f(byte a) { + | rts + | } + """.stripMargin)(_.readWord(0xc000) should equal(21)) + } + +// test("Stack byte subtraction") { +// SymonUnoptimizedRun(""" +// | byte output @$c000 +// | void main () { +// | stack byte a +// | stack byte b +// | b = $77 +// | a = $11 +// | b -= zzz() +// | b -= a +// | output = b +// | } +// | byte zzz() { +// | return $22 +// | } +// """.stripMargin).readByte(0xc000) should equal(0x44) +// } + + test("Stack word addition") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | stack word a + | stack word b + | a = $111 + | b = $444 + | b += zzz() + | b += a + | output = b + | } + | word zzz() { + | return $222 + | } + """.stripMargin)(_.readWord(0xc000) should equal(0x777)) + } + + test("Recursion") { + EmuBenchmarkRun(""" + | array output [6] @$c000 + | byte fails @$c010 + | void main () { + | word w + | byte i + | for i,0,until,output.length { + | w = fib(i) + | if w.hi != 0 { fails += 1 } + | output[i] = w.lo + | } + | } + | word fib(byte i) { + | stack byte j + | j = i + | if j < 2 { + | return 1 + | } + | stack word sum + | sum = fib(j-1) + | sum += fib(j-2) + | sum &= $0F3F + | return sum + | } + """.stripMargin){ m => + m.readByte(0xc010) should equal(0) + m.readByte(0xc000) should equal(1) + m.readByte(0xc001) should equal(1) + m.readByte(0xc002) should equal(2) + m.readByte(0xc003) should equal(3) + m.readByte(0xc004) should equal(5) + m.readByte(0xc005) should equal(8) + } + } + + + test("Indexing") { + EmuBenchmarkRun(""" + | array output [200] @$c000 + | void main () { + | stack byte a + | stack byte b + | a = $11 + | b = $44 + | output[a + b] = $66 + | } + """.stripMargin){m => m.readWord(0xc055) should equal(0x66) } + } +} diff --git a/src/test/scala/millfork/test/TypeWideningSuite.scala b/src/test/scala/millfork/test/TypeWideningSuite.scala new file mode 100644 index 00000000..7160f1b0 --- /dev/null +++ b/src/test/scala/millfork/test/TypeWideningSuite.scala @@ -0,0 +1,23 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class TypeWideningSuite extends FunSuite with Matchers { + + test("Word after simple ops") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | output = random() + | output = output.hi << 1 + | } + | word random() { + | return $777 + | } + """.stripMargin)(_.readWord(0xc000) should equal(14)) + } +} diff --git a/src/test/scala/millfork/test/WordMathSuite.scala b/src/test/scala/millfork/test/WordMathSuite.scala new file mode 100644 index 00000000..83db90d8 --- /dev/null +++ b/src/test/scala/millfork/test/WordMathSuite.scala @@ -0,0 +1,103 @@ +package millfork.test +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class WordMathSuite extends FunSuite with Matchers { + + test("Word addition") { + EmuBenchmarkRun(""" + | word output @$c000 + | word a + | void main () { + | a = 640 + | output = a + | output += a + | } + """.stripMargin)(_.readWord(0xc000) should equal(1280)) + } + + test("Word subtraction") { + EmuBenchmarkRun(""" + | word output @$c000 + | word a + | void main () { + | a = 640 + | output = 740 + | output -= a + | } + """.stripMargin)(_.readWord(0xc000) should equal(100)) + } + + test("Word subtraction 2") { + EmuBenchmarkRun(""" + | word output @$c000 + | word a + | void main () { + | a = 640 + | output = a + | output -= 400 + | } + """.stripMargin)(_.readWord(0xc000) should equal(240)) + } + + test("Byte-to-word addition") { + EmuBenchmarkRun(""" + | word output @$c000 + | word pair + | void main () { + | pair = $A5A5 + | pair.lo = 1 + | output = 640 + | output += pair.lo + | } + """.stripMargin)(_.readWord(0xc000) should equal(641)) + } + + test("Literal addition") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | output = 640 + | output += -0050 + | } + """.stripMargin)(_.readWord(0xc000) should equal(590)) + } + + test("Array element addition") { + EmuBenchmarkRun(""" + | word output @$c000 + | word pair + | array b[2] + | void main () { + | byte i + | i = 1 + | b[1] = 5 + | pair = $A5A5 + | pair.lo = 1 + | output = 640 + | output += b[i] + | } + """.stripMargin)(_.readWord(0xc000) should equal(645)) + } + + test("nesdev.com example") { + EmuBenchmarkRun(""" + | byte output @$c000 + | array map [256] @$c300 + | array b[2] + | void main () { + | output = get(5, 6) + | } + | byte get(byte mx, byte my) { + | pointer p + | p = mx + | p <<= 5 + | p += map + | return p[my] + | } + """.stripMargin)(m => ()) + } +} diff --git a/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala new file mode 100644 index 00000000..9188abaf --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala @@ -0,0 +1,30 @@ +package millfork.test.emu + +import millfork.output.MemoryBank + +/** + * @author Karol Stasiak + */ +object EmuBenchmarkRun { + def apply(source:String)(verifier: MemoryBank=>Unit) = { + val (Timings(t0, _), m0) = EmuUnoptimizedRun.apply2(source) + val (Timings(t1, _), m1) = EmuOptimizedRun.apply2(source) +// val (Timings(t2, _), m2) = SymonSuperOptimizedRun.apply2(source) +//val (Timings(t3, _), m3) = SymonQuantumOptimizedRun.apply2(source) + println(f"Before optimization: $t0%7d") + println(f"After optimization: $t1%7d") +// println(f"After quantum: $t3%7d") +// println(f"After superopt.: $t2%7d") + println(f"Gain: ${(100L*(t0-t1)/t0.toDouble).round}%7d%%") +// println(f"Quantum gain: ${(100L*(t0-t3)/t0.toDouble).round}%7d%%") +// println(f"Superopt. gain: ${(100L*(t0-t2)/t0.toDouble).round}%7d%%") + println(f"Running unoptimized") + verifier(m0) + println(f"Running optimized") + verifier(m1) +// println(f"Running quantum optimized") +// verifier(m3) +// println(f"Running superoptimized") +// verifier(m2) + } +} diff --git a/src/test/scala/millfork/test/emu/EmuCmosBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuCmosBenchmarkRun.scala new file mode 100644 index 00000000..77f14579 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuCmosBenchmarkRun.scala @@ -0,0 +1,26 @@ +package millfork.test.emu + +import millfork.output.MemoryBank + +/** + * @author Karol Stasiak + */ +object EmuCmosBenchmarkRun { + def apply(source:String)(verifier: MemoryBank=>Unit) = { + val (Timings(_, t0), m0) = EmuUnoptimizedRun.apply2(source) + val (Timings(_, t1), m1) = EmuOptimizedRun.apply2(source) + val (Timings(_, t2), m2) = EmuOptimizedCmosRun.apply2(source) + println(f"Before optimization: $t0%7d") + println(f"After NMOS optimization: $t1%7d") + println(f"After CMOS optimization: $t2%7d") + println(f"Gain unopt->NMOS: ${(100L*(t0-t1)/t0.toDouble).round}%7d%%") + println(f"Gain unopt->CMOS: ${(100L*(t0-t2)/t0.toDouble).round}%7d%%") + println(f"Gain NMOS->CMOS: ${(100L*(t1-t2)/t1.toDouble).round}%7d%%") + println(f"Running unoptimized") + verifier(m0) + println(f"Running NMOS-optimized") + verifier(m1) + println(f"Running CMOS-optimized") + verifier(m2) + } +} diff --git a/src/test/scala/millfork/test/emu/EmuNodeOptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuNodeOptimizedRun.scala new file mode 100644 index 00000000..f3f65144 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuNodeOptimizedRun.scala @@ -0,0 +1,15 @@ +package millfork.test.emu + +import millfork.Cpu +import millfork.node.opt.{UnreachableCode, UnusedLocalVariables} + +/** + * @author Karol Stasiak + */ +object EmuNodeOptimizedRun extends EmuRun( + Cpu.StrictMos, + List( + UnreachableCode, + UnusedLocalVariables), + Nil, + false) \ No newline at end of file diff --git a/src/test/scala/millfork/test/emu/EmuOptimizedCmosRun.scala b/src/test/scala/millfork/test/emu/EmuOptimizedCmosRun.scala new file mode 100644 index 00000000..91aed3e5 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuOptimizedCmosRun.scala @@ -0,0 +1,19 @@ +package millfork.test.emu + +import millfork.assembly.opt.CmosOptimizations +import millfork.{Cpu, OptimizationPresets} + +/** + * @author Karol Stasiak + */ +object EmuOptimizedCmosRun extends EmuRun( + Cpu.Cmos, + OptimizationPresets.NodeOpt, + OptimizationPresets.AssOpt ++ + CmosOptimizations.All ++ OptimizationPresets.Good ++ + CmosOptimizations.All ++ OptimizationPresets.Good ++ + CmosOptimizations.All ++ OptimizationPresets.Good, + false) + + + diff --git a/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala new file mode 100644 index 00000000..13d252b0 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala @@ -0,0 +1,15 @@ +package millfork.test.emu + +import millfork.{Cpu, OptimizationPresets} + +/** + * @author Karol Stasiak + */ +object EmuOptimizedRun extends EmuRun( + Cpu.StrictMos, + OptimizationPresets.NodeOpt, + OptimizationPresets.AssOpt ++ OptimizationPresets.Good ++ OptimizationPresets.Good ++ OptimizationPresets.Good, + false) + + + diff --git a/src/test/scala/millfork/test/emu/EmuPlatform.scala b/src/test/scala/millfork/test/emu/EmuPlatform.scala new file mode 100644 index 00000000..4caca5f3 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuPlatform.scala @@ -0,0 +1,19 @@ +package millfork.test.emu + +import millfork.output.{AfterCodeByteAllocator, CurrentBankFragmentOutput, VariableAllocator} +import millfork.{Cpu, Platform} + +/** + * @author Karol Stasiak + */ +object EmuPlatform { + def get(cpu: Cpu.Value) = new Platform( + cpu, + Map(), + Nil, + CurrentBankFragmentOutput(0, 0xffff), + new VariableAllocator((0 until 256 by 2).toList, new AfterCodeByteAllocator(0xff00)), + 0x200, + ".bin" + ) +} diff --git a/src/test/scala/millfork/test/emu/EmuQuantumOptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuQuantumOptimizedRun.scala new file mode 100644 index 00000000..9a85ac7c --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuQuantumOptimizedRun.scala @@ -0,0 +1,15 @@ +package millfork.test.emu + +import millfork.{Cpu, OptimizationPresets} + +/** + * @author Karol Stasiak + */ +object EmuQuantumOptimizedRun extends EmuRun( + Cpu.StrictMos, + OptimizationPresets.NodeOpt, + OptimizationPresets.AssOpt ++ OptimizationPresets.Good ++ OptimizationPresets.Good ++ OptimizationPresets.Good, + true) + + + diff --git a/src/test/scala/millfork/test/emu/EmuRun.scala b/src/test/scala/millfork/test/emu/EmuRun.scala new file mode 100644 index 00000000..59b0c853 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuRun.scala @@ -0,0 +1,238 @@ +package millfork.test.emu + +import com.grapeshot.halfnes.{CPU, CPURAM} +import com.loomcom.symon.InstructionTable.CpuBehavior +import com.loomcom.symon.{Bus, Cpu, CpuState} +import fastparse.core.Parsed.{Failure, Success} +import millfork.assembly.opt.AssemblyOptimization +import millfork.compiler.{CompilationContext, MlCompiler} +import millfork.env.{Environment, InitializedArray, NormalFunction} +import millfork.error.ErrorReporting +import millfork.node.StandardCallGraph +import millfork.node.opt.NodeOptimization +import millfork.output.{Assembler, MemoryBank} +import millfork.parser.MfParser +import millfork.{CompilationFlag, CompilationOptions} +import org.scalatest.Matchers + +/** + * @author Karol Stasiak + */ +case class Timings(nmos: Long, cmos: Long) + +class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], assemblyOptimizations: List[AssemblyOptimization], quantum: Boolean) extends Matchers { + + def apply(source: String): MemoryBank = { + apply2(source)._2 + } + + def emitIllegals = false + + private val timingNmos = Array[Int]( + 7, 6, 0, 8, 3, 3, 5, 5, 3, 2, 2, 2, 4, 4, 6, 6, + 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7, + 6, 6, 0, 8, 3, 3, 5, 5, 4, 2, 2, 2, 4, 4, 6, 6, + 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7, + + 6, 6, 0, 8, 3, 3, 5, 5, 3, 2, 2, 2, 3, 4, 6, 6, + 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7, + 6, 6, 0, 8, 3, 3, 5, 5, 4, 2, 2, 2, 5, 4, 6, 6, + 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7, + + 2, 6, 2, 6, 3, 3, 3, 3, 2, 2, 2, 2, 4, 4, 4, 4, + 2, 6, 0, 6, 4, 4, 4, 4, 2, 5, 2, 5, 5, 5, 5, 5, + 2, 6, 2, 6, 3, 3, 3, 3, 2, 2, 2, 2, 4, 4, 4, 4, + 2, 5, 0, 5, 4, 4, 4, 4, 2, 4, 2, 4, 4, 4, 4, 4, + + 2, 6, 2, 8, 3, 3, 5, 5, 2, 2, 2, 2, 4, 4, 6, 6, + 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7, + 2, 6, 2, 8, 3, 3, 5, 5, 2, 2, 2, 2, 4, 4, 6, 6, + 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7, + ) + + private val timingCmos = Array[Int]( + 7, 6, 2, 1, 5, 3, 5, 5, 3, 2, 2, 1, 6, 4, 6, 5, + 2, 5, 5, 1, 5, 4, 6, 5, 2, 4, 2, 1, 6, 4, 6, 5, + 6, 6, 2, 1, 3, 3, 5, 5, 4, 2, 2, 1, 4, 4, 6, 5, + 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 2, 1, 4, 4, 6, 5, + + 6, 6, 2, 1, 3, 3, 5, 5, 3, 2, 2, 1, 3, 4, 6, 5, + 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 3, 1, 8, 4, 6, 5, + 6, 6, 2, 1, 3, 3, 5, 5, 4, 2, 2, 1, 6, 4, 6, 5, + 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 4, 1, 6, 4, 6, 5, + + 3, 6, 2, 1, 3, 3, 3, 5, 2, 2, 2, 1, 4, 4, 4, 5, + 2, 6, 5, 1, 4, 4, 4, 5, 2, 5, 2, 1, 4, 5, 5, 5, + 2, 6, 2, 1, 3, 3, 3, 5, 2, 2, 2, 1, 4, 4, 4, 5, + 2, 5, 5, 1, 4, 4, 4, 5, 2, 4, 2, 1, 4, 4, 4, 5, + + 2, 6, 2, 1, 3, 3, 5, 5, 2, 2, 2, 3, 4, 4, 6, 5, + 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 3, 3, 4, 4, 7, 5, + 2, 6, 2, 1, 3, 3, 5, 5, 2, 2, 2, 1, 4, 4, 6, 5, + 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 4, 1, 4, 4, 7, 5, + ) + + private val variableLength = Set(0x10, 0x30, 0x50, 0x70, 0x90, 0xb0, 0xd0, 0xf0) + + private val TooManyCycles: Long = 1000000 + + private def formatBool(b: Boolean, c: Char) = if (b) c else '-' + + private def formatState(q: CpuState): String = + f"A=${q.a}%02X X=${q.x}%02X Y=${q.y}%02X S=${q.sp}%02X PC=${q.pc}%04X " + + formatBool(q.negativeFlag, 'N') + formatBool(q.overflowFlag, 'V') + formatBool(q.breakFlag, 'B') + + formatBool(q.decimalModeFlag, 'D') + formatBool(q.irqDisableFlag, 'I') + formatBool(q.zeroFlag, 'Z') + formatBool(q.carryFlag, 'C') + + def apply2(source: String): (Timings, MemoryBank) = { + Console.out.flush() + Console.err.flush() + println(source) + val platform = EmuPlatform.get(cpu) + val options = new CompilationOptions(platform, Map( + CompilationFlag.EmitIllegals -> this.emitIllegals, + CompilationFlag.DetailedFlowAnalysis -> quantum, + )) + ErrorReporting.hasErrors = false + ErrorReporting.verbosity = 999 + val parserF = MfParser("", source, "", options) + parserF.toAst match { + case Success(unoptimized, _) => + ErrorReporting.assertNoErrors("Parse failed") + + + // prepare + val program = nodeOptimizations.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt)) + val callGraph = new StandardCallGraph(program) + val env = new Environment(None, "") + env.collectDeclarations(program, options) + + val hasOptimizations = assemblyOptimizations.nonEmpty + var optimizedSize = 0L + var unoptimizedSize = 0L + // print asm + env.allPreallocatables.foreach { + case f: NormalFunction => + val result = MlCompiler.compile(CompilationContext(f.environment, f, 0, options)) + val unoptimized = result.linearize + if (hasOptimizations) { + val optimized = assemblyOptimizations.foldLeft(unoptimized) { (c, opt) => + opt.optimize(f, c, options) + } + println("Unoptimized:") + unoptimized.filter(_.isPrintable).foreach(println(_)) + println("Optimized:") + optimized.filter(_.isPrintable).foreach(println(_)) + unoptimizedSize += unoptimized.map(_.sizeInBytes).sum + optimizedSize += optimized.map(_.sizeInBytes).sum + } else { + unoptimized.filter(_.isPrintable).foreach(println(_)) + unoptimizedSize += unoptimized.map(_.sizeInBytes).sum + optimizedSize += unoptimized.map(_.sizeInBytes).sum + } + case d: InitializedArray => + println(d.name) + d.contents.foreach(c => println(" !byte " + c)) + unoptimizedSize += d.contents.length + optimizedSize += d.contents.length + } + + ErrorReporting.assertNoErrors("Compile failed") + + if (unoptimizedSize == optimizedSize) { + println(f"Size: $unoptimizedSize%5d B") + } else { + println(f"Unoptimized size: $unoptimizedSize%5d B") + println(f"Optimized size: $optimizedSize%5d B") + println(f"Gain: ${(100L * (unoptimizedSize - optimizedSize) / unoptimizedSize.toDouble).round}%5d%%") + } + + // compile + val assembler = new Assembler(env) + assembler.assemble(callGraph, assemblyOptimizations, options) + assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") } + + ErrorReporting.assertNoErrors("Code generation failed") + + val memoryBank = assembler.mem.banks(0) + platform.cpu match { + case millfork.Cpu.Cmos => + runViaSymon(memoryBank, platform.org, CpuBehavior.CMOS_6502) + case millfork.Cpu.Ricoh => + runViaHalfnes(memoryBank, platform.org) + case millfork.Cpu.Mos => + ErrorReporting.fatal("There's no NMOS emulator with decimal mode support") + Timings(-1, -1) -> memoryBank + case _ => + runViaSymon(memoryBank, platform.org, CpuBehavior.NMOS_6502) + } + case f: Failure[_, _] => + println(f) + println(f.extra.toString) + println(f.lastParser.toString) + ErrorReporting.error("Syntax error", Some(parserF.lastPosition)) + ??? + } + } + + def runViaHalfnes(memoryBank: MemoryBank, org: Int): (Timings, MemoryBank) = { + val cpu = new CPU(new CPURAM(memoryBank)) + cpu.reset() + cpu.PC = org + // stack underflow cannot be easily detected directly, + // but since the stack is full of zeroes, an underflowing RTS jumps to $0001 + while (cpu.PC.&(0xffff) > 1 && cpu.clocks < TooManyCycles) { + // println(cpu.status()) + cpu.runcycle(0, 0) + } + println("clocks: " + cpu.clocks) + System.out.flush() + cpu.clocks.toLong should be < TooManyCycles + println(cpu.clocks + " NMOS cycles") + cpu.flagstobyte().&(8).==(0) should be(true) + Timings(cpu.clocks, 0) -> memoryBank + } + + def runViaSymon(memoryBank: MemoryBank, org: Int, behavior: CpuBehavior): (Timings, MemoryBank) = { + val cpu = new Cpu + cpu.setBehavior(behavior) + val ram = new SymonTestRam(memoryBank) + val bus = new Bus(1 << 16) + bus.addCpu(cpu) + bus.addDevice(ram) + cpu.setBus(bus) + cpu.setProgramCounter(org) + cpu.setStackPointer(0xff) + val legal = Assembler.getStandardLegalOpcodes + + var countNmos = 0L + var countCmos = 0L + while (cpu.getStackPointer > 1 && countCmos < TooManyCycles) { + // println(cpu.disassembleNextOp()) + val pcBefore = cpu.getProgramCounter + cpu.step() + val pcAfter = cpu.getProgramCounter + // println(formatState(cpu.getCpuState)) + val instruction = cpu.getInstruction + if (behavior == CpuBehavior.NMOS_6502 || behavior == CpuBehavior.NMOS_WITH_ROR_BUG) { + if (!legal(instruction)) { + throw new RuntimeException("unexpected illegal: " + instruction.toHexString) + } + } + countNmos += timingNmos(instruction) + countCmos += timingCmos(instruction) + if (variableLength(instruction)) { + val jump = pcAfter - pcBefore + if (jump <= 0 || jump > 3) { + countNmos += 1 + countCmos += 1 + } + } + } + countCmos should be < TooManyCycles + println(countNmos + " NMOS cycles") + println(countCmos + " CMOS cycles") + cpu.getDecimalModeFlag should be(false) + Timings(countNmos, countCmos) -> memoryBank + } + +} diff --git a/src/test/scala/millfork/test/emu/EmuSuperQuantumOptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuSuperQuantumOptimizedRun.scala new file mode 100644 index 00000000..fdf03635 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuSuperQuantumOptimizedRun.scala @@ -0,0 +1,17 @@ +package millfork.test.emu + +import millfork.assembly.opt.SuperOptimizer +import millfork.{Cpu, OptimizationPresets} + +/** + * @author Karol Stasiak + */ +// TODO : it doesn't work +object EmuSuperQuantumOptimizedRun extends EmuRun( + Cpu.StrictMos, + OptimizationPresets.NodeOpt, + List(SuperOptimizer), + true) + + + diff --git a/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala new file mode 100644 index 00000000..30be3c0d --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala @@ -0,0 +1,17 @@ +package millfork.test.emu + +import millfork.assembly.opt.SuperOptimizer +import millfork.{Cpu, OptimizationPresets} + +/** + * @author Karol Stasiak + */ +// TODO : it doesn't work +object EmuSuperOptimizedRun extends EmuRun( + Cpu.StrictMos, + OptimizationPresets.NodeOpt, + List(SuperOptimizer), + false) + + + diff --git a/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala new file mode 100644 index 00000000..008b7dcc --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala @@ -0,0 +1,35 @@ +package millfork.test.emu + +import millfork.output.MemoryBank + +/** + * @author Karol Stasiak + */ +object EmuUltraBenchmarkRun { + def apply(source:String)(verifier: MemoryBank=>Unit) = { + val (Timings(t0, _), m0) = EmuUnoptimizedRun.apply2(source) + val (Timings(t1, _), m1) = EmuOptimizedRun.apply2(source) + val (Timings(t2, _), m2) = EmuSuperOptimizedRun.apply2(source) + val (Timings(t3, _), m3) = EmuQuantumOptimizedRun.apply2(source) + val (Timings(t4, _), m4) = EmuSuperQuantumOptimizedRun.apply2(source) + println(f"Before optimization: $t0%7d") + println(f"After optimization: $t1%7d") + println(f"After superopt.: $t2%7d") + println(f"After quantum: $t3%7d") + println(f"After superquantum: $t4%7d") + println(f"Gain: ${(100L*(t0-t1)/t0.toDouble).round}%7d%%") + println(f"Superopt. gain: ${(100L*(t0-t2)/t0.toDouble).round}%7d%%") + println(f"Quantum gain: ${(100L*(t0-t3)/t0.toDouble).round}%7d%%") + println(f"Super quantum gain: ${(100L*(t0-t4)/t0.toDouble).round}%7d%%") + println(f"Running unoptimized") + verifier(m0) + println(f"Running optimized") + verifier(m1) + println(f"Running superoptimized") + verifier(m2) + println(f"Running quantum optimized") + verifier(m3) + println(f"Running superquantum optimized") + verifier(m4) + } +} diff --git a/src/test/scala/millfork/test/emu/EmuUndocumentedRun.scala b/src/test/scala/millfork/test/emu/EmuUndocumentedRun.scala new file mode 100644 index 00000000..16f9d1ae --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuUndocumentedRun.scala @@ -0,0 +1,19 @@ +package millfork.test.emu + +import millfork.assembly.opt.UndocumentedOptimizations +import millfork.{Cpu, OptimizationPresets} + +/** + * @author Karol Stasiak + */ +object EmuUndocumentedRun extends EmuRun( + Cpu.Ricoh, // not Cpu.Mos, because I haven't found an emulator that supports both illegals and decimal mode yet + OptimizationPresets.NodeOpt, + OptimizationPresets.AssOpt ++ UndocumentedOptimizations.All ++ OptimizationPresets.Good ++ UndocumentedOptimizations.All ++ OptimizationPresets.Good, + false) { + + override def emitIllegals = true +} + + + diff --git a/src/test/scala/millfork/test/emu/EmuUnoptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuUnoptimizedRun.scala new file mode 100644 index 00000000..71af4c70 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuUnoptimizedRun.scala @@ -0,0 +1,9 @@ +package millfork.test.emu + +import millfork.Cpu + + +/** + * @author Karol Stasiak + */ +object EmuUnoptimizedRun extends EmuRun(Cpu.StrictMos, Nil, Nil, false) \ No newline at end of file diff --git a/src/test/scala/millfork/test/emu/SymonTestRam.scala b/src/test/scala/millfork/test/emu/SymonTestRam.scala new file mode 100644 index 00000000..0ae21e28 --- /dev/null +++ b/src/test/scala/millfork/test/emu/SymonTestRam.scala @@ -0,0 +1,39 @@ +package millfork.test.emu + +import com.loomcom.symon.devices.Device +import millfork.output.MemoryBank + +/** + * @author Karol Stasiak + */ +class SymonTestRam(mem: MemoryBank) extends Device(0x0000, 0xffff, "RAM") { + + mem.readable(1) = true + mem.readable(2) = true + + (0x100 to 0x1ff).foreach { stack => + mem.writeable(stack) = true + mem.readable(stack) = true + } + + (0xc000 to 0xcfff).foreach { himem => + mem.writeable(himem) = true + mem.readable(himem) = true + } + + override def write(i: Int, i1: Int): Unit = { + if (!mem.writeable(i)) { + throw new RuntimeException(s"Can't write to $$${i.toHexString}") + } + mem.output(i) = i1.toByte + } + + override def read(i: Int, b: Boolean): Int = { + if (!mem.readable(i)) { + throw new RuntimeException(s"Can't read from $$${i.toHexString}") + } + mem.output(i) + } + + override def toString: String = "TestRam" +}