diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..f288702d
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+
+ Copyright (C)
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ Copyright (C)
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+.
diff --git a/README.md b/README.md
index 42efeaff..2143467b 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
# Millfork
-A middle-level programming language targeting 6502-based microcomputers.
+A middle-level programming language targeting 6502-based microcomputers.
+
+Distributed under GPLv3 (see [LICENSE](LICENSE))
**UNDER DEVELOPMENT, NOT FOR PRODUCTION USE**
diff --git a/build.sbt b/build.sbt
new file mode 100644
index 00000000..1cc38069
--- /dev/null
+++ b/build.sbt
@@ -0,0 +1,35 @@
+name := "millfork"
+
+version := "0.0.1-SNAPSHOT"
+
+scalaVersion := "2.12.3"
+
+resolvers += Resolver.mavenLocal
+
+libraryDependencies += "com.lihaoyi" %% "fastparse" % "1.0.0"
+
+libraryDependencies += "org.apache.commons" % "commons-configuration2" % "2.2"
+
+libraryDependencies += "org.scalactic" %% "scalactic" % "3.0.4"
+
+libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.4" % "test"
+
+// these two not in Maven Central or any other public repo
+// get them from the following links or just build millfork without tests:
+// https://github.com/sethm/symon
+// https://github.com/andrew-hoffman/halfnes/tree/061
+
+libraryDependencies += "com.loomcom.symon" % "symon" % "1.3.0-SNAPSHOT" % "test"
+
+libraryDependencies += "com.grapeshot" % "halfnes" % "061" % "test"
+
+mainClass in Compile := Some("millfork.Main")
+
+assemblyJarName := "millfork.jar"
+
+//lazy val root = (project in file(".")).
+// enablePlugins(BuildInfoPlugin).
+// settings(
+// buildInfoKeys := Seq[BuildInfoKey](name, version, scalaVersion, sbtVersion),
+// buildInfoPackage := "hello"
+// )
\ No newline at end of file
diff --git a/doc/README.md b/doc/README.md
new file mode 100644
index 00000000..d6bdedb3
--- /dev/null
+++ b/doc/README.md
@@ -0,0 +1,7 @@
+# Documentation
+
+## Tutorial
+
+* [Getting started](tutorial/01-getting-started.md)
+
+* [Basic functions and variables](tutorial/02-functions-variables.md)
\ No newline at end of file
diff --git a/doc/target-platforms.md b/doc/target-platforms.md
new file mode 100644
index 00000000..ec53a857
--- /dev/null
+++ b/doc/target-platforms.md
@@ -0,0 +1,93 @@
+# Target platforms
+
+Currently, Millfork supports creating disk- or tape-based programs for Commodore and Atari 8-bit computers,
+but it may be expanded to support other 6502-based platforms in the future.
+
+## Supported platforms
+
+The following platforms are currently supported:
+
+* `c64` – Commodore 64
+
+* `c16` – Commodore 16
+
+* `plus4` – Commodore Plus/4
+
+* `vic20` – Commodore VIC-20 without memory expansion
+
+* `vic20_3k` – Commodore VIC-20 with 3K memory expansion
+
+* `vic20_8k` – Commodore VIC-20 with 8K or 16K memory expansion
+
+* `c128` – Commodore 128 in its native mode
+
+* `pet` – Commodore PET
+
+* `a8` – Atari 8-bit computers
+
+The primary and most tested platform is Commodore 64.
+
+Currently, all targets assume that the program will be loaded from disk or tape.
+Cartridge targets are not yet available.
+
+## Adding a custom platform
+
+Every platform is defined in an `.ini` file with an appropriate name.
+
+#### `[compilation]` section
+
+* `arch` – CPU architecture. It defines which instructions are available. Available values:
+
+ * `nmos`
+
+ * `strict` (= NMOS without illegal instructions)
+
+ * `ricoh` (= NMOS without decimal mode)
+
+ * `strictricoh`
+
+ * `cmos` (= 65C02)
+
+* `modules` – comma-separated list of modules that will be automatically imported
+
+* other compilation options (they can be overridden using commandline options):
+
+ * `emit_illegals` – whether the compiler should emit illegal instructions, default `false`
+
+ * `emit_cmos` – whether the compiler should emit CMOS instructions, default is `true` on `cmos` and `false` elsewhere
+
+ * `decimal_mode` – whether the compiler should emit decimal instructions, default is `false` on `ricoh` and `strictricoh` and `true` elsewhere
+
+ * `ro_arrays` – whether the compiler should warn upon array writes, default is `false`
+
+ * `prevent_jmp_indirect_bug` – whether the compiler should try to avoid the indirect JMP bug, default is `false` on `cmos` and `true` elsewhere
+
+#### `[allocation]` section
+
+* `main_org` – the address for the `main` function; all the other functions will be placed after it
+
+* `zp_pointers` – either a list of comma separated zeropage addresses that can be used by the program as zeropage pointers, or `all` for all. Each value should be the address of the first of two free bytes in the zeropage.
+
+* `himem_style` – not yet supported
+
+* `himem_start` – the first address used for non-zeropage variables, or `after_code` if the variables should be allocated after the code
+
+* `himem_end` – the last address available for non-zeropage variables
+
+#### `[output]` section
+
+* `style` – not yet supported
+
+* `format` – output file format; a comma-separated list of tokens:
+
+ * literal byte values
+
+ * `startaddr` – little-endian 16-bit address of the first used byte of the compiled output
+
+ * `endaddr` – little-endian 16-bit address of the last used byte of the compiled output
+
+ * `allocated` – all used bytes
+
+ * `:` - inclusive range of bytes
+
+* `extension` – target file extension, with or without the dot
\ No newline at end of file
diff --git a/doc/tutorial/01-getting-started.md b/doc/tutorial/01-getting-started.md
new file mode 100644
index 00000000..923f8fb5
--- /dev/null
+++ b/doc/tutorial/01-getting-started.md
@@ -0,0 +1,54 @@
+# Getting started
+
+## Hello world example
+
+Save the following as `hello_world.ml`:
+
+```
+import stdio
+
+array hello_world = "hello world" petscii
+
+void main(){
+ putstr(hello_world, hello_world.length)
+ while(true){}
+}
+```
+
+Compile is using the following commandline:
+
+```
+java millfork.jar hello_world.ml -o hello_world -t c64 -I path_to_millfork\include
+```
+
+Run the output executable (here using the VICE emulator):
+
+```
+x64 hello_world.prg
+```
+
+## Basic commandline usage
+
+The following options are crucial when compiling your sources:
+
+* `-o FILENAME` – specifies the base name for your output file, an appropriate file extension will be appended (`prg` for Commodore, `xex` for Atari, `asm` for assembly output, `lbl` for label file)
+
+* `-I DIR;DIR;DIR;...` – specifies the paths to directories with modules to include.
+
+* `-t PLATFORM` – specifies the target platform (`c64` is the default). Each platform is defined in an `.ini` file in the include directory. For the list of supported platforms, see [Supported platforms](../target-platforms.md)
+
+You may be also interested in the following:
+
+* `-O`, `-O2`, `-O3` – enable optimization (various levels)
+
+* `--detailed-flow` – use more resource-consuming but more precise flow analysis engine for better optimization
+
+* `-s` – additionally generate assembly output
+
+* `-g` – additionally generate a label file, in format compatible with VICE emulator
+
+* `-r PROGRAM` – automatically launch given program after successful compilation
+
+* `-Wall` – enable all warnings
+
+* `--help` – list all commandline options
\ No newline at end of file
diff --git a/doc/tutorial/02-functions-variables.md b/doc/tutorial/02-functions-variables.md
new file mode 100644
index 00000000..e590f14e
--- /dev/null
+++ b/doc/tutorial/02-functions-variables.md
@@ -0,0 +1,16 @@
+# Functions and variables
+
+TODO: write all of this
+
+## Basic types
+
+## Defining variables
+
+## Built-in operators
+
+### Byte operators
+
+| a | a | a |
+| -- | -- | -- |
+| a | a | a |
+
diff --git a/examples/hello_world/hello_world.mfk b/examples/hello_world/hello_world.mfk
new file mode 100644
index 00000000..4bbbf4df
--- /dev/null
+++ b/examples/hello_world/hello_world.mfk
@@ -0,0 +1,11 @@
+// compile with
+// java -jar millfork.jar -I ${PATH}/include -t ${platform} ${PATH}/examples/hello_world/hello_world.mfk
+
+import stdio
+
+array hello_world = "hello world" petscii
+
+void main(){
+ putstr(hello_world, hello_world.length)
+ while(true){}
+}
\ No newline at end of file
diff --git a/include/a8.ini b/include/a8.ini
new file mode 100644
index 00000000..0723c075
--- /dev/null
+++ b/include/a8.ini
@@ -0,0 +1,22 @@
+[compilation]
+arch=strict
+modules=a8_kernel
+
+
+[allocation]
+main_org=$2000
+; TODO
+zp_pointers=$80,$82,$84,$86,$88,$8a,$8c,$8e,$90,$92,$94,$96,$98,$9a,$9c,$9e,$a0,$a2,$a4
+;TODO
+himem_style=per_bank
+himem_start=after_code
+;TODO
+himem_end=$3FFF
+
+[output]
+;TODO
+style=per_bank
+format=$FF,$FF,$E0,$02,$E1,$02,startaddr,startaddr,endaddr,allocated
+extension=xex
+
+
diff --git a/include/a8_kernel.mfk b/include/a8_kernel.mfk
new file mode 100644
index 00000000..76afe512
--- /dev/null
+++ b/include/a8_kernel.mfk
@@ -0,0 +1,9 @@
+asm void putchar(byte a) {
+ tax
+ lda $347
+ pha
+ lda $346
+ pha
+ txa
+ rts
+}
\ No newline at end of file
diff --git a/include/c128.ini b/include/c128.ini
new file mode 100644
index 00000000..6d40a6cd
--- /dev/null
+++ b/include/c128.ini
@@ -0,0 +1,20 @@
+[compilation]
+arch=nmos
+modules=c128_hardware,loader_1c01,c128_kernal
+
+
+[allocation]
+main_org=$1C0D
+; TODO
+zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B
+himem_style=per_bank
+himem_start=after_code
+; TODO
+himem_end=$FEFF
+
+[output]
+style=per_bank
+format=startaddr,allocated
+extension=prg
+
+
diff --git a/include/c128_hardware.mfk b/include/c128_hardware.mfk
new file mode 100644
index 00000000..ba0b81dd
--- /dev/null
+++ b/include/c128_hardware.mfk
@@ -0,0 +1,5 @@
+import c64_vic
+import c64_sid
+import c64_cia
+
+array c64_color_ram [1000] @$D800
diff --git a/include/c128_kernal.mfk b/include/c128_kernal.mfk
new file mode 100644
index 00000000..ac73145d
--- /dev/null
+++ b/include/c128_kernal.mfk
@@ -0,0 +1,5 @@
+// Routines from Commodore 128 KERNAL ROM
+
+// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.)
+// Input: A = Byte to write.
+asm void putchar(byte a) @$FFD2 extern
\ No newline at end of file
diff --git a/include/c1531.mfk b/include/c1531.mfk
new file mode 100644
index 00000000..f42e8eba
--- /dev/null
+++ b/include/c1531.mfk
@@ -0,0 +1,60 @@
+// mouse driver for Commodore 1531 mouse on Commodore 64
+
+import mouse
+import c64_hardware
+
+sbyte _c1531_calculate_delta (byte old, byte new) {
+ byte mouse_delta
+ mouse_delta = (new - old)
+ mouse_delta &= $3f
+ if mouse_delta >= $20 {
+ mouse_delta |= $c0
+ }
+ return mouse_delta
+}
+
+byte _c1531_handle_x() {
+ static byte _c1531_old_pot_x
+ sbyte mouse_delta
+ byte new_pot_x
+
+ new_pot_x = sid_paddle_x >> 1
+ mouse_delta = _c1531_calculate_delta(_c1531_old_pot_x, new_pot_x)
+ _c1531_old_pot_x = new_pot_x
+
+ mouse_x += mouse_delta
+ mouse_x.hi &= 1
+
+ if mouse_x > 319 {
+ if mouse_delta > 0 {
+ mouse_x = 319
+ } else {
+ mouse_x = 0
+ }
+ }
+}
+
+byte _c1531_handle_y() {
+ static byte _c1531_old_pot_y
+ byte new_pot_y
+ sbyte mouse_delta
+
+ new_pot_y = sid_paddle_y >> 1
+ mouse_delta = _c1531_calculate_delta(_c1531_old_pot_y, new_pot_y)
+ _c1531_old_pot_y = new_pot_y
+ mouse_y -= mouse_delta
+ if mouse_y > 199 {
+ if mouse_delta > 0 {
+ mouse_y = 0
+ } else {
+ mouse_y = 199
+ }
+ }
+}
+
+void c1531_mouse () {
+
+ cia1_pra = ($3f & cia1_pra) | $40
+ _c1531_handle_x()
+ _c1531_handle_y()
+}
\ No newline at end of file
diff --git a/include/c16.ini b/include/c16.ini
new file mode 100644
index 00000000..bba71a26
--- /dev/null
+++ b/include/c16.ini
@@ -0,0 +1,19 @@
+[compilation]
+arch=nmos
+modules=loader_1001,c264_kernal,c264_hardware
+
+
+[allocation]
+main_org=$100D
+; TODO
+zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B
+himem_style=per_bank
+himem_start=after_code
+himem_end=$3FFF
+
+[output]
+style=per_bank
+format=startaddr,allocated
+extension=prg
+
+
diff --git a/include/c264_hardware.mfk b/include/c264_hardware.mfk
new file mode 100644
index 00000000..b83cf6e6
--- /dev/null
+++ b/include/c264_hardware.mfk
@@ -0,0 +1 @@
+import c16_ted
\ No newline at end of file
diff --git a/include/c264_kernal.mfk b/include/c264_kernal.mfk
new file mode 100644
index 00000000..2f1ae990
--- /dev/null
+++ b/include/c264_kernal.mfk
@@ -0,0 +1,5 @@
+// Routines from C16 and Plus/4 KERNAL ROM
+
+// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.)
+// Input: A = Byte to write.
+asm void putchar(byte a) @$FFD2 extern
\ No newline at end of file
diff --git a/include/c264_ted.mfk b/include/c264_ted.mfk
new file mode 100644
index 00000000..580b08f9
--- /dev/null
+++ b/include/c264_ted.mfk
@@ -0,0 +1,20 @@
+
+const byte black = 0
+const byte white = $71
+const byte red = $22
+const byte cyan = $43
+const byte purple = $24
+const byte green = $35
+const byte blue = $16
+const byte yellow = $57
+const byte orange = $28
+const byte brown = $19
+const byte light_red = $32
+const byte dark_grey = $21
+const byte dark_gray = $21
+const byte medium_grey = $31
+const byte medium gray = $31
+const byte light_green = $55
+const byte light_blue = $36
+const byte light_grey = $41
+const byte light_gray = $41
\ No newline at end of file
diff --git a/include/c64.ini b/include/c64.ini
new file mode 100644
index 00000000..dcaf4d21
--- /dev/null
+++ b/include/c64.ini
@@ -0,0 +1,37 @@
+; Commodore 64
+; assuming a program loaded from disk or tape
+
+[compilation]
+; CPU architecture: nmos, strictnmos, ricoh, strictricoh, cmos
+arch=nmos
+; modules to load
+modules=c64_hardware,loader_0801,c64_kernal,stdlib
+; optionally: default flags
+emit_illegals=true
+
+
+[allocation]
+; where the main function should be allocated, also the start of bank 0
+main_org=$80D
+; list of free zp pointer locations (these assume that BASIC will keep working)
+zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B
+; where to allocate non-zp variables
+himem_style=per_bank
+himem_start=after_code
+himem_end=$9FFF
+
+[output]
+; how the banks are laid out in the output files; so far, there is no bank support in the compiler yet
+style=per_bank
+; output file format
+; startaddr - little-endian address of the first used byte in the bank
+; endaddr - little-endian address of the last used byte in the bank
+; allocated - all used bytes in the bank
+; : - bytes from the current bank
+; :addr>: - bytes from arbitrary bank
+; - single byte
+format=startaddr,allocated
+; default output file extension
+extension=prg
+
+
diff --git a/include/c64_basic.mfk b/include/c64_basic.mfk
new file mode 100644
index 00000000..2e455c03
--- /dev/null
+++ b/include/c64_basic.mfk
@@ -0,0 +1,6 @@
+// Routines from C64 BASIC ROM
+
+import c64_kernal
+
+// print a 16-bit number on the standard output
+asm void putword(word xa) @$BDCD extern
\ No newline at end of file
diff --git a/include/c64_cia.mfk b/include/c64_cia.mfk
new file mode 100644
index 00000000..f077b672
--- /dev/null
+++ b/include/c64_cia.mfk
@@ -0,0 +1,40 @@
+// Hardware addresses for C64
+
+// CIA1
+byte cia1_pra @$DC00
+byte cia1_prb @$DC01
+byte cia1_ddra @$DC02
+byte cia1_ddrb @$DC03
+byte cia2_pra @$DD00
+byte cia2_prb @$DD01
+byte cia2_ddra @$DD02
+byte cia2_ddrb @$DD03
+
+inline asm void cia_disable_irq() {
+ LDA #$7f
+ LDA $dc0d
+ LDA $dd0d
+ LDA $dc0d
+ LDA $dd0d
+}
+
+
+inline void vic_bank_0000() {
+ cia2_ddra = $C0
+ cia2_pra = $C0
+}
+
+inline void vic_bank_4000() {
+ cia2_ddra = $C0
+ cia2_pra = $80
+}
+
+inline void vic_bank_8000() {
+ cia2_ddra = $C0
+ cia2_pra = $40
+}
+
+inline void vic_bank_C000() {
+ cia2_ddra = $C0
+ cia2_pra = $00
+}
\ No newline at end of file
diff --git a/include/c64_hardware.mfk b/include/c64_hardware.mfk
new file mode 100644
index 00000000..0347eb68
--- /dev/null
+++ b/include/c64_hardware.mfk
@@ -0,0 +1,41 @@
+import c64_vic
+import c64_sid
+import c64_cia
+import cpu6510
+
+array c64_color_ram [1000] @$D800
+
+inline void c64_ram_only() {
+ cpu6510_ddr = 7
+ cpu6510_port = 0
+}
+
+inline void c64_ram_io() {
+ cpu6510_ddr = 7
+ cpu6510_port = 5
+}
+
+inline void c64_ram_io_kernal() {
+ cpu6510_ddr = 7
+ cpu6510_port = 6
+}
+
+inline void c64_ram_io_basic() {
+ cpu6510_ddr = 7
+ cpu6510_port = 7
+}
+
+inline void c64_ram_charset() {
+ cpu6510_ddr = 7
+ cpu6510_port = 1
+}
+
+inline void c64_ram_charset_kernal() {
+ cpu6510_ddr = 7
+ cpu6510_port = 2
+}
+
+inline void c64_ram_charset_basic() {
+ cpu6510_ddr = 7
+ cpu6510_port = 3
+}
\ No newline at end of file
diff --git a/include/c64_kernal.mfk b/include/c64_kernal.mfk
new file mode 100644
index 00000000..ed69eb0b
--- /dev/null
+++ b/include/c64_kernal.mfk
@@ -0,0 +1,32 @@
+// Routines from C64 KERNAL ROM
+
+// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.)
+// Input: A = Byte to write.
+asm void putchar(byte a) @$FFD2 extern
+
+// OPEN. Open file. (Must call SETLFS and SETNAM beforehands.)
+asm void open() @$FFC0 extern
+
+// CLOSE. Close file.
+// Input: A = Logical number.
+asm void close(byte a) @$FFC0 extern
+
+// SETLFS. Set file parameters.
+// Input: A = Logical number; X = Device number; Y = Secondary address.
+asm void setlfs(byte a, byte x, byte y) @$FFBA extern
+
+// SETNAM. Set file name parameters.
+// Input: A = File name length; X/Y = Pointer to file name.
+asm void setnam(word yx, byte a) @$FFBA extern
+
+// LOAD. Load or verify file. (Must call SETLFS and SETNAM beforehands.)
+// Input: A: 0 = Load, 1-255 = Verify; X/Y = Load address (if secondary address = 0).
+// Output: Carry: 0 = No errors, 1 = Error; A = KERNAL error code (if Carry = 1); X/Y = Address of last byte loaded/verified (if Carry = 0).
+asm clear_carry load(byte a, word yx) @$FFD5 extern
+
+// SAVE. Save file. (Must call SETLFS and SETNAM beforehands.)
+// Input: A = Address of zero page register holding start address of memory area to save; X/Y = End address of memory area plus 1.
+// Output: Carry: 0 = No errors, 1 = Error; A = KERNAL error code (if Carry = 1).
+asm clear_carry save(byte a, word yx) @$FFD5 extern
+
+word irq_pointer @$314
\ No newline at end of file
diff --git a/include/c64_sid.mfk b/include/c64_sid.mfk
new file mode 100644
index 00000000..79cd772b
--- /dev/null
+++ b/include/c64_sid.mfk
@@ -0,0 +1,24 @@
+// Hardware addresses for C64
+
+// SID
+
+word sid_v1_freq @$D400
+word sid_v1_pulse @$D402
+byte sid_v1_cr @$D404
+byte sid_v1_ad @$D405
+byte sid_v1_sr @$D409
+
+word sid_v2_freq @$D407
+word sid_v2_pulse @$D409
+byte sid_v2_cr @$D40B
+byte sid_v2_ad @$D40C
+byte sid_v2_sr @$D40D
+
+word sid_v3_freq @$D40E
+word sid_v3_pulse @$D410
+byte sid_v3_cr @$D412
+byte sid_v3_ad @$D413
+byte sid_v3_sr @$D414
+
+byte sid_paddle_x @$D419
+byte sid_paddle_y @$D41A
diff --git a/include/c64_vic.mfk b/include/c64_vic.mfk
new file mode 100644
index 00000000..25f33259
--- /dev/null
+++ b/include/c64_vic.mfk
@@ -0,0 +1,154 @@
+// Hardware addresses for C64
+
+// VIC-II
+byte vic_spr0_x @$D000
+byte vic_spr0_y @$D001
+byte vic_spr1_x @$D002
+byte vic_spr1_y @$D003
+byte vic_spr2_x @$D004
+byte vic_spr2_y @$D005
+byte vic_spr3_x @$D006
+byte vic_spr3_y @$D007
+byte vic_spr4_x @$D008
+byte vic_spr4_y @$D009
+byte vic_spr5_x @$D00A
+byte vic_spr5_y @$D00B
+byte vic_spr6_x @$D00C
+byte vic_spr6_y @$D00D
+byte vic_spr7_x @$D00E
+byte vic_spr7_y @$D00F
+byte vic_spr_hi_x @$D010
+byte vic_cr1 @$D011
+byte vic_raster @$D012
+byte vic_lp_x @$D013
+byte vic_lp_y @$D014
+byte vic_spr_ena @$D015
+byte vic_cr2 @$D016
+byte vic_spr_exp_y @$D017
+byte vic_mem @$D018
+byte vic_irq @$D019
+byte vic_irq_ena @$D01A
+byte vic_spr_dp @$D01B
+byte vic_spr_mcolor @$D01C
+byte vic_spr_exp_x @$D01D
+byte vic_spr_ss_col @$D01E
+byte vic_spr_sd_col @$D01F
+byte vic_border @$D020
+byte vic_bg_color0 @$D021
+byte vic_bg_color1 @$D022
+byte vic_bg_color2 @$D023
+byte vic_bg_color3 @$D024
+byte vic_spr_color1 @$D025
+byte vic_spr_color2 @$D026
+byte vic_spr0_color @$D027
+byte vic_spr1_color @$D028
+byte vic_spr2_color @$D029
+byte vic_spr3_color @$D02A
+byte vic_spr4_color @$D02B
+byte vic_spr5_color @$D02C
+byte vic_spr6_color @$D02D
+byte vic_spr7_color @$D02E
+
+array vic_spr_coord [16] @$D000
+array vic_spr_color [8] @$D027
+
+inline void vic_enable_multicolor() {
+ vic_cr2 |= 0x10
+}
+
+inline void vic_disable_multicolor() {
+ vic_cr2 &= 0xEF
+}
+
+inline void vic_enable_bitmap() {
+ vic_cr1 |= 0x20
+}
+
+inline void vic_disable_bitmap() {
+ vic_cr1 &= 0xDF
+}
+
+inline void vic_24_rows() {
+ vic_cr1 &= 0xF7
+}
+
+inline void vic_25_rows() {
+ vic_cr1 |= 8
+}
+
+inline void vic_38_columns() {
+ vic_cr2 &= 0xF7
+}
+
+inline void vic_40_columns() {
+ vic_cr2 |= 8
+}
+
+inline void vic_disable_irq() {
+ vic_irq_ena = 0
+ vic_irq += 1
+}
+
+// base: divisible by $400, $0000-$3C00 allowed
+//inline void vic_screen(word const base) {
+// vic_mem = (vic_mem & $0F) | (base >> 6)
+//}
+
+inline void vic_charset_0000() {
+ vic_mem = (vic_mem & $F1)
+}
+inline void vic_charset_0800() {
+ vic_mem = (vic_mem & $F1) | 2
+}
+inline void vic_charset_1000() {
+ vic_mem = (vic_mem & $F1) | 4
+}
+inline void vic_charset_1800() {
+ vic_mem = (vic_mem & $F1) | 6
+}
+inline void vic_charset_2000() {
+ vic_mem = (vic_mem & $F1) | 8
+}
+inline void vic_charset_2800() {
+ vic_mem = (vic_mem & $F1) | $A
+}
+inline void vic_charset_3000() {
+ vic_mem = (vic_mem & $F1) | $C
+}
+inline void vic_charset_3800() {
+ vic_mem = (vic_mem & $F1) | $E
+}
+
+inline void vic_bitmap_0000() {
+ vic_mem &= $F7
+}
+inline void vic_bitmap_2000() {
+ vic_mem |= 8
+}
+
+// x, y < 8
+// default: x=0, y=3
+void vic_set_scroll(byte x, byte y) {
+ vic_cr1 = (vic_cr1 & $F8) | y
+ vic_cr2 = (vic_cr2 & $F8) | x
+}
+
+const byte black = 0
+const byte white = 1
+const byte red = 2
+const byte cyan = 3
+const byte purple = 4
+const byte green = 5
+const byte blue = 6
+const byte yellow = 7
+const byte orange = 8
+const byte brown = 9
+const byte light_red = 10
+const byte dark_grey = 11
+const byte dark_gray = 11
+const byte medium_grey = 12
+const byte medium_gray = 12
+const byte light_green = 13
+const byte light_blue = 14
+const byte light_grey = 15
+const byte light_gray = 15
\ No newline at end of file
diff --git a/include/cpu6510.mfk b/include/cpu6510.mfk
new file mode 100644
index 00000000..90a674ca
--- /dev/null
+++ b/include/cpu6510.mfk
@@ -0,0 +1,3 @@
+
+byte cpu6510_ddr @0
+byte cpu6510_port @1
diff --git a/include/loader_0401.mfk b/include/loader_0401.mfk
new file mode 100644
index 00000000..0a748c51
--- /dev/null
+++ b/include/loader_0401.mfk
@@ -0,0 +1,15 @@
+array _basic_loader @$401 = [
+ $0b,
+ 4,
+ 10,
+ 0,
+ $9e,
+ $31,
+ $30,
+ $33,
+ $37,
+ 0,
+ 0,
+ 0
+ ]
+
diff --git a/include/loader_0801.mfk b/include/loader_0801.mfk
new file mode 100644
index 00000000..796b8b2f
--- /dev/null
+++ b/include/loader_0801.mfk
@@ -0,0 +1,15 @@
+array _basic_loader @$801 = [
+ $0b,
+ $08,
+ 10,
+ 0,
+ $9e,
+ $32,
+ $30,
+ $36,
+ $31,
+ 0,
+ 0,
+ 0
+ ]
+
diff --git a/include/loader_1001.mfk b/include/loader_1001.mfk
new file mode 100644
index 00000000..b16d858c
--- /dev/null
+++ b/include/loader_1001.mfk
@@ -0,0 +1,15 @@
+array _basic_loader @$1001 = [
+ $0b,
+ $10,
+ 10,
+ 0,
+ $9e,
+ $34,
+ $31,
+ $30,
+ $39,
+ 0,
+ 0,
+ 0
+ ]
+
diff --git a/include/loader_1201.mfk b/include/loader_1201.mfk
new file mode 100644
index 00000000..144a3e90
--- /dev/null
+++ b/include/loader_1201.mfk
@@ -0,0 +1,15 @@
+array _basic_loader @$1201 = [
+ $0b,
+ $12,
+ 10,
+ 0,
+ $9e,
+ $34,
+ $36,
+ $32,
+ $31,
+ 0,
+ 0,
+ 0
+ ]
+
diff --git a/include/loader_1c01.mfk b/include/loader_1c01.mfk
new file mode 100644
index 00000000..73962381
--- /dev/null
+++ b/include/loader_1c01.mfk
@@ -0,0 +1,15 @@
+array _basic_loader @$1C01 = [
+ $0b,
+ $1C,
+ 10,
+ 0,
+ $9e,
+ $37,
+ $31,
+ $38,
+ $31,
+ 0,
+ 0,
+ 0
+ ]
+
diff --git a/include/mouse.mfk b/include/mouse.mfk
new file mode 100644
index 00000000..f4a65210
--- /dev/null
+++ b/include/mouse.mfk
@@ -0,0 +1,8 @@
+// Generic module for mouse support
+// Resolutions up to 512x256 are supported
+
+
+// Mouse X coordinate
+word mouse_x
+// Mouse Y coordinate
+byte mouse_y
diff --git a/include/pet.ini b/include/pet.ini
new file mode 100644
index 00000000..ca975341
--- /dev/null
+++ b/include/pet.ini
@@ -0,0 +1,19 @@
+[compilation]
+arch=nmos
+modules=loader_0401,pet_kernal
+
+
+[allocation]
+main_org=$40D
+; TODO
+zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B
+himem_style=per_bank
+himem_start=after_code
+himem_end=$FFF
+
+[output]
+style=per_bank
+format=startaddr,allocated
+extension=prg
+
+
diff --git a/include/pet_kernal.mfk b/include/pet_kernal.mfk
new file mode 100644
index 00000000..31d7daf8
--- /dev/null
+++ b/include/pet_kernal.mfk
@@ -0,0 +1,5 @@
+// Routines from Commodore PET KERNAL ROM
+
+// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.)
+// Input: A = Byte to write.
+asm void putchar(byte a) @$FFD2 extern
\ No newline at end of file
diff --git a/include/plus4.ini b/include/plus4.ini
new file mode 100644
index 00000000..5b12893d
--- /dev/null
+++ b/include/plus4.ini
@@ -0,0 +1,19 @@
+[compilation]
+arch=nmos
+modules=c264_loader,c264_kernal,c264_hardware
+
+
+[allocation]
+main_org=$100D
+; TODO
+zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B
+himem_style=per_bank
+himem_start=after_code
+himem_end=$7FFF
+
+[output]
+style=per_bank
+format=startaddr,allocated
+extension=prg
+
+
diff --git a/include/stdio.mfk b/include/stdio.mfk
new file mode 100644
index 00000000..16161f2d
--- /dev/null
+++ b/include/stdio.mfk
@@ -0,0 +1,10 @@
+// target-independent standard I/O routines
+
+void putstr(pointer str, byte len) {
+ byte index
+ index = 0
+ while (index != len) {
+ putchar(str[index])
+ index += 1
+ }
+}
\ No newline at end of file
diff --git a/include/stdlib.mfk b/include/stdlib.mfk
new file mode 100644
index 00000000..0dda8f5c
--- /dev/null
+++ b/include/stdlib.mfk
@@ -0,0 +1,23 @@
+// target-independent things
+
+word nmi_routine_addr @$FFFA
+word reset_routine_addr @$FFFC
+word irq_routine_addr @$FFFE
+
+inline asm void poke(word const addr, byte const value) {
+ ?LDA #value
+ STA addr
+}
+
+inline asm byte peek(word const addr) {
+ LDA addr
+}
+
+inline asm void disable_irq() {
+ SEI
+}
+
+inline asm void enable_irq() {
+ CLI
+}
+
diff --git a/include/vic20.ini b/include/vic20.ini
new file mode 100644
index 00000000..6c33a727
--- /dev/null
+++ b/include/vic20.ini
@@ -0,0 +1,19 @@
+[compilation]
+arch=nmos
+modules=loader_1001,vic20_kernal
+
+
+[allocation]
+main_org=$100D
+; TODO
+zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B
+himem_style=per_bank
+himem_start=after_code
+himem_end=$1CFF
+
+[output]
+style=per_bank
+format=startaddr,allocated
+extension=prg
+
+
diff --git a/include/vic20_3k.ini b/include/vic20_3k.ini
new file mode 100644
index 00000000..10088a92
--- /dev/null
+++ b/include/vic20_3k.ini
@@ -0,0 +1,19 @@
+[compilation]
+arch=nmos
+modules=loader_0401,vic20_kernal
+
+
+[allocation]
+main_org=$40D
+; TODO
+zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B
+himem_style=per_bank
+himem_start=after_code
+himem_end=$1CFF
+
+[output]
+style=per_bank
+format=startaddr,allocated
+extension=prg
+
+
diff --git a/include/vic20_8k.ini b/include/vic20_8k.ini
new file mode 100644
index 00000000..d7acfa58
--- /dev/null
+++ b/include/vic20_8k.ini
@@ -0,0 +1,19 @@
+[compilation]
+arch=nmos
+modules=loader_1201,vic20_kernal
+
+
+[allocation]
+main_org=$120D
+; TODO
+zp_pointers=$C1,$C3,$FB,$FD,$39,$3B,$3D,$43,$4B
+himem_style=per_bank
+himem_start=after_code
+himem_end=$1FFF
+
+[output]
+style=per_bank
+format=startaddr,allocated
+extension=prg
+
+
diff --git a/include/vic20_kernal.mfk b/include/vic20_kernal.mfk
new file mode 100644
index 00000000..2f1ae990
--- /dev/null
+++ b/include/vic20_kernal.mfk
@@ -0,0 +1,5 @@
+// Routines from C16 and Plus/4 KERNAL ROM
+
+// CHROUT. Write byte to default output. (If not screen, must call OPEN and CHKOUT beforehands.)
+// Input: A = Byte to write.
+asm void putchar(byte a) @$FFD2 extern
\ No newline at end of file
diff --git a/project/assembly.sbt b/project/assembly.sbt
new file mode 100644
index 00000000..cdb3c0bb
--- /dev/null
+++ b/project/assembly.sbt
@@ -0,0 +1,2 @@
+
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.6")
diff --git a/project/build.properties b/project/build.properties
new file mode 100644
index 00000000..826c0bd9
--- /dev/null
+++ b/project/build.properties
@@ -0,0 +1 @@
+sbt.version = 0.13.16
\ No newline at end of file
diff --git a/project/buildinfo.sbt b/project/buildinfo.sbt
new file mode 100644
index 00000000..42c669b1
--- /dev/null
+++ b/project/buildinfo.sbt
@@ -0,0 +1 @@
+addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.7.0")
\ No newline at end of file
diff --git a/project/plugins.sbt b/project/plugins.sbt
new file mode 100644
index 00000000..e69de29b
diff --git a/src/main/scala/millfork/CompilationOptions.scala b/src/main/scala/millfork/CompilationOptions.scala
new file mode 100644
index 00000000..1e7db997
--- /dev/null
+++ b/src/main/scala/millfork/CompilationOptions.scala
@@ -0,0 +1,116 @@
+package millfork
+
+import millfork.error.ErrorReporting
+
+/**
+ * @author Karol Stasiak
+ */
+//
+//object CompilationOptions {
+//
+//
+// private var instance = new CompilationOptions(Platform.C64, Map())
+//
+// // TODO: ugly!
+// def change(o: CompilationOptions): Unit = {
+// instance = o
+// }
+//
+// def current: CompilationOptions= instance
+//
+// def platform: Platform = instance.platform
+//
+// def flag(flag: CompilationFlag.Value):Boolean = instance.flags(flag)
+//
+// def flags: Map[CompilationFlag.Value, Boolean] = instance.flags
+//}
+class CompilationOptions(val platform: Platform, val commandLineFlags: Map[CompilationFlag.Value, Boolean]) {
+
+ import CompilationFlag._
+ import Cpu._
+
+ val flags: Map[CompilationFlag.Value, Boolean] = CompilationFlag.values.map { f =>
+ f -> commandLineFlags.getOrElse(f, platform.flagOverrides.getOrElse(f, Cpu.defaultFlags(platform.cpu)(f)))
+ }.toMap
+
+ def flag(f: CompilationFlag.Value) = flags(f)
+
+ if (flags(DecimalMode)) {
+ if (platform.cpu == Ricoh || platform.cpu == StrictRicoh) {
+ ErrorReporting.warn("Decimal mode enabled for Ricoh architecture", this)
+ }
+ }
+ if (platform.cpu != Cmos) {
+ if (!flags(PreventJmpIndirectBug)) {
+ ErrorReporting.warn("JMP bug prevention should be enabled for non-CMOS architecture", this)
+ }
+ if (flags(EmitCmosOpcodes)) {
+ ErrorReporting.warn("CMOS opcodes enabled for non-CMOS architecture", this)
+ }
+ }
+ if (flags(EmitIllegals)) {
+ if (platform.cpu == Cmos) {
+ ErrorReporting.warn("Illegal opcodes enabled for CMOS architecture", this)
+ }
+ if (platform.cpu == StrictRicoh || platform.cpu == Ricoh) {
+ ErrorReporting.warn("Illegal opcodes enabled for strict architecture", this)
+ }
+ }
+}
+
+object Cpu extends Enumeration {
+
+ val Mos, StrictMos, Ricoh, StrictRicoh, Cmos = Value
+
+ import CompilationFlag._
+
+ def defaultFlags(x: Cpu.Value): Set[CompilationFlag.Value] = x match {
+ case StrictMos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap)
+ case Mos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap)
+ case Ricoh => Set(PreventJmpIndirectBug, VariableOverlap)
+ case StrictRicoh => Set(PreventJmpIndirectBug, VariableOverlap)
+ case Cmos => Set(EmitCmosOpcodes, VariableOverlap)
+ }
+
+ def fromString(name: String): Cpu.Value = name match {
+ case "nmos" => Mos
+ case "6502" => Mos
+ case "6510" => Mos
+ case "strict" => StrictMos
+ case "cmos" => Cmos
+ case "65c02" => Cmos
+ case "ricoh" => Ricoh
+ case "2a03" => Ricoh
+ case "2a07" => Ricoh
+ case "strictricoh" => StrictRicoh
+ case "strict2a03" => StrictRicoh
+ case "strict2a07" => StrictRicoh
+ case _ => ErrorReporting.fatal("Unknown CPU achitecture")
+ }
+}
+
+object CompilationFlag extends Enumeration {
+ val
+ // compilation options:
+ EmitIllegals, EmitCmosOpcodes, DecimalMode, ReadOnlyArrays, PreventJmpIndirectBug,
+ // optimization options:
+ DetailedFlowAnalysis, DangerousOptimizations,
+ // memory allocation options
+ VariableOverlap,
+ // warning options
+ ExtraComparisonWarnings,
+ RorWarning,
+ FatalWarnings = Value
+
+ val allWarnings: Set[CompilationFlag.Value] = Set(ExtraComparisonWarnings)
+
+ val fromString = Map(
+ "emit_illegals" -> EmitIllegals,
+ "emit_cmos" -> EmitCmosOpcodes,
+ "decimal_mode" -> DecimalMode,
+ "ro_arrays" -> ReadOnlyArrays,
+ "ror_warn" -> RorWarning,
+ "prevent_jmp_indirect_bug" -> PreventJmpIndirectBug,
+ )
+
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/Main.scala b/src/main/scala/millfork/Main.scala
new file mode 100644
index 00000000..37a9cc99
--- /dev/null
+++ b/src/main/scala/millfork/Main.scala
@@ -0,0 +1,264 @@
+package millfork
+
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files, Paths}
+import java.util.Locale
+
+import millfork.assembly.opt.{CmosOptimizations, DangerousOptimizations, SuperOptimizer, UndocumentedOptimizations}
+import millfork.cli.{CliParser, CliStatus}
+import millfork.env.Environment
+import millfork.error.ErrorReporting
+import millfork.node.StandardCallGraph
+import millfork.output.Assembler
+import millfork.parser.SourceLoadingQueue
+
+/**
+ * @author Karol Stasiak
+ */
+
+case class Context(inputFileNames: List[String],
+ outputFileName: Option[String] = None,
+ runFileName: Option[String] = None,
+ optimizationLevel: Option[Int] = None,
+ platform: Option[String] = None,
+ outputAssembly: Boolean = false,
+ outputLabels: Boolean = false,
+ includePath: List[String] = Nil,
+ flags: Map[CompilationFlag.Value, Boolean] = Map(),
+ verbosity: Option[Int] = None) {
+ def changeFlag(f: CompilationFlag.Value, b: Boolean): Context = {
+ if (flags.contains(f)) {
+ if (flags(f) != b) {
+ ErrorReporting.error("Conflicting flags")
+ }
+ this
+ } else {
+ copy(flags = this.flags + (f -> b))
+ }
+ }
+}
+
+object Main {
+
+
+ def main(args: Array[String]): Unit = {
+ if (args.isEmpty) {
+ ErrorReporting.info("For help, use --help")
+ }
+ val (status, c) = parser.parse(Context(Nil), args.toList)
+ status match {
+ case CliStatus.Quit => return
+ case CliStatus.Failed =>
+ ErrorReporting.fatalQuit("Invalid command line")
+ case CliStatus.Ok => ()
+ }
+ ErrorReporting.assertNoErrors("Invalid command line")
+ if (c.inputFileNames.isEmpty) {
+ ErrorReporting.fatalQuit("No input files")
+ }
+ ErrorReporting.verbosity = c.verbosity.getOrElse(0)
+ val optLevel = c.optimizationLevel.getOrElse(0)
+ val platform = Platform.lookupPlatformFile(c.includePath, c.platform.getOrElse {
+ ErrorReporting.info("No platform selected, defaulting to `c64`")
+ "c64"
+ })
+ val options = new CompilationOptions(platform, c.flags)
+ ErrorReporting.debug("Effective flags: " + options.flags)
+
+ val output = c.outputFileName.getOrElse("a")
+ val assOutput = output + ".asm"
+ val labelOutput = output + ".lbl"
+ val prgOutput = if (!output.endsWith(platform.fileExtension)) output + platform.fileExtension else output
+
+ val unoptimized = new SourceLoadingQueue(
+ initialFilenames = c.inputFileNames,
+ includePath = c.includePath,
+ options = options).run()
+
+ val program = if (optLevel > 0) {
+ OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt))
+ } else {
+ unoptimized
+ }
+ val callGraph = new StandardCallGraph(program)
+
+ val env = new Environment(None, "")
+ env.collectDeclarations(program, options)
+ val extras = List(
+ if (options.flag(CompilationFlag.EmitIllegals)) UndocumentedOptimizations.All else Nil,
+ if (options.flag(CompilationFlag.EmitCmosOpcodes)) CmosOptimizations.All else Nil,
+ if (options.flag(CompilationFlag.DangerousOptimizations)) DangerousOptimizations.All else Nil,
+ ).flatten
+ val goodCycle = List.fill(optLevel - 1)(OptimizationPresets.Good ++ extras).flatten
+ val assemblyOptimizations = if (optLevel <= 0) Nil else if (optLevel >= 9) List(SuperOptimizer) else {
+ goodCycle ++ OptimizationPresets.AssOpt ++ extras ++ goodCycle
+ }
+
+ // compile
+ val assembler = new Assembler(env)
+ val result = assembler.assemble(callGraph, assemblyOptimizations, options)
+ ErrorReporting.assertNoErrors("Codegen failed")
+ ErrorReporting.debug(f"Unoptimized code size: ${assembler.unoptimizedCodeSize}%5d B")
+ ErrorReporting.debug(f"Optimized code size: ${assembler.optimizedCodeSize}%5d B")
+ ErrorReporting.debug(f"Gain: ${(100L * (assembler.unoptimizedCodeSize - assembler.optimizedCodeSize) / assembler.unoptimizedCodeSize.toDouble).round}%5d%%")
+ ErrorReporting.debug(f"Initialized arrays: ${assembler.initializedArraysSize}%5d B")
+
+ if (c.outputAssembly) {
+ val path = Paths.get(assOutput)
+ ErrorReporting.debug("Writing assembly to " + path.toAbsolutePath)
+ Files.write(path, result.asm.mkString("\n").getBytes(StandardCharsets.UTF_8))
+ }
+ if (c.outputLabels) {
+ val path = Paths.get(labelOutput)
+ ErrorReporting.debug("Writing labels to " + path.toAbsolutePath)
+ Files.write(path, result.labels.sortWith { (a, b) =>
+ val aLocal = a._1.head == '.'
+ val bLocal = b._1.head == '.'
+ if (aLocal == bLocal) a._1 < b._1
+ else b._1 < a._1
+ }.groupBy(_._2).values.map(_.head).toSeq.sortBy(_._2).map { case (l, a) =>
+ val normalized = l.replace('$', '_').replace('.', '_')
+ s"al ${a.toHexString} .$normalized"
+ }.mkString("\n").getBytes(StandardCharsets.UTF_8))
+ }
+ val path = Paths.get(prgOutput)
+ ErrorReporting.debug("Writing output to " + path.toAbsolutePath)
+ Files.write(path, result.code)
+ c.runFileName.foreach(program =>
+ new ProcessBuilder(program, path.toAbsolutePath.toString).start()
+ )
+ }
+
+ private def parser = new CliParser[Context] {
+
+ fluff("Main options:", "")
+
+ parameter("-o", "--out").required().placeholder("").action { (p, c) =>
+ assertNone(c.outputFileName, "Output already defined")
+ c.copy(outputFileName = Some(p))
+ }.description("The output file name, without extension.").onWrongNumber(_ => ErrorReporting.fatalQuit("No output file specified"))
+
+ flag("-s").action { c =>
+ c.copy(outputAssembly = true)
+ }.description("Generate also the assembly output.")
+
+ flag("-g").action { c =>
+ c.copy(outputLabels = true)
+ }.description("Generate also the label file.")
+
+ parameter("-t", "--target").placeholder("").action { (p, c) =>
+ assertNone(c.platform, "Platform already defined")
+ c.copy(platform = Some(p))
+ }.description("Target platform, any of: c64, c16, plus4, vic20, vic20_3k, vic20_8k, pet, c128, a8.")
+
+ parameter("-I", "--include-dir").repeatable().placeholder(";;...").action { (paths, c) =>
+ val n = paths.split(";")
+ c.copy(includePath = c.includePath ++ n)
+ }.description("Include paths for modules.")
+
+ parameter("-r", "--run").placeholder("").action { (p, c) =>
+ assertNone(c.runFileName, "Run program already defined")
+ c.copy(runFileName = Some(p))
+ }.description("Program to run after successful compilation.")
+
+ endOfFlags("--").description("Marks the end of options.")
+
+ fluff("", "Verbosity options:", "")
+
+ flag("-q", "--quiet").action { c =>
+ assertNone(c.verbosity, "Cannot use -v and -q together")
+ c.copy(verbosity = Some(-1))
+ }.description("Supress all messages except for errors.")
+
+ private val verbose = flag("-v", "--verbose").maxCount(3).action { c =>
+ if (c.verbosity.exists(_ < 0)) ErrorReporting.error("Cannot use -v and -q together", None)
+ c.copy(verbosity = Some(1 + c.verbosity.getOrElse(0)))
+ }.description("Increase verbosity.")
+ flag("-vv").repeatable().action(c => verbose.encounter(verbose.encounter(verbose.encounter(c)))).description("Increase verbosity even more.")
+ flag("-vvv").repeatable().action(c => verbose.encounter(verbose.encounter(c))).description("Increase verbosity even more.")
+
+ fluff("", "Code generation options:", "")
+
+ boolean("-fcmos-ops", "-fno-cmos-ops").action { (c, v) =>
+ c.changeFlag(CompilationFlag.EmitCmosOpcodes, v)
+ }.description("Whether should emit CMOS opcodes.")
+ boolean("-fillegals", "-fno-illegals").action { (c, v) =>
+ c.changeFlag(CompilationFlag.EmitIllegals, v)
+ }.description("Whether should emit illegal (undocumented) NMOS opcodes.")
+ boolean("-fjmp-fix", "-fno-jmp-fix").action { (c, v) =>
+ c.changeFlag(CompilationFlag.PreventJmpIndirectBug, v)
+ }.description("Whether should prevent indirect JMP bug on page boundary.")
+ boolean("-fdecimal-mode", "-fno-decimal-mode").action { (c, v) =>
+ c.changeFlag(CompilationFlag.DecimalMode, v)
+ }.description("Whether should decimal mode be available.")
+ boolean("-fvariable-overlap", "-fno-variable-overlap").action { (c, v) =>
+ c.changeFlag(CompilationFlag.VariableOverlap, v)
+ }.description("Whether should variables overlap if their scopes do not intersect.")
+
+ fluff("", "Optimization options:", "")
+
+
+ flag("-O0").action { c =>
+ assertNone(c.optimizationLevel, "Optimization level already defined")
+ c.copy(optimizationLevel = Some(0))
+ }.description("Disable all optimizations.")
+ flag("-O").action { c =>
+ assertNone(c.optimizationLevel, "Optimization level already defined")
+ c.copy(optimizationLevel = Some(1))
+ }.description("Optimize code.")
+ for (i <- 2 to 9) {
+ val f = flag("-O" + i).action { c =>
+ assertNone(c.optimizationLevel, "Optimization level already defined")
+ c.copy(optimizationLevel = Some(i))
+ }.description("Optimize code even more.")
+ if (i > 3) f.hidden()
+ }
+ flag("--detailed-flow").action { c =>
+ c.changeFlag(CompilationFlag.DetailedFlowAnalysis, true)
+ }.description("Use detailed flow analysis (experimental).")
+ flag("--dangerous-optimizations").action { c =>
+ c.changeFlag(CompilationFlag.DangerousOptimizations, true)
+ }.description("Use dangerous optimizations (experimental).")
+
+ fluff("", "Warning options:", "")
+
+ flag("-Wall", "--Wall").action { c =>
+ CompilationFlag.allWarnings.foldLeft(c) { (c, f) => c.changeFlag(f, true) }
+ }.description("Enable extra warnings.")
+
+ flag("-Wfatal", "--Wfatal").action { c =>
+ c.changeFlag(CompilationFlag.FatalWarnings, true)
+ }.description("Treat warnings as errors.")
+
+ fluff("", "Other options:", "")
+
+ flag("--help").action(c => {
+ printHelp(20).foreach(println(_))
+ assumeStatus(CliStatus.Quit)
+ c
+ }).description("Display this message.")
+
+ flag("--version").action(c => {
+ println("millfork version ")
+ assumeStatus(CliStatus.Quit)
+ System.exit(0)
+ c
+ }).description("Print the version and quit.")
+
+
+ default.action { (p, c) =>
+ if (p.startsWith("-")) {
+ ErrorReporting.error(s"Invalid option `$p`", None)
+ c
+ } else {
+ c.copy(inputFileNames = c.inputFileNames :+ p)
+ }
+ }
+
+ def assertNone[T](value: Option[T], msg: String): Unit = {
+ if (value.isDefined) {
+ ErrorReporting.error(msg, None)
+ }
+ }
+ }
+}
diff --git a/src/main/scala/millfork/OptimizationPresets.scala b/src/main/scala/millfork/OptimizationPresets.scala
new file mode 100644
index 00000000..7b160e24
--- /dev/null
+++ b/src/main/scala/millfork/OptimizationPresets.scala
@@ -0,0 +1,150 @@
+package millfork
+
+import millfork.assembly.opt._
+import millfork.node.opt.{UnreachableCode, UnusedFunctions, UnusedGlobalVariables, UnusedLocalVariables}
+
+/**
+ * @author Karol Stasiak
+ */
+object OptimizationPresets {
+ val NodeOpt = List(
+ UnreachableCode,
+ UnusedFunctions,
+ UnusedLocalVariables,
+ UnusedGlobalVariables,
+ )
+ val AssOpt: List[AssemblyOptimization] = List[AssemblyOptimization](
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+ AlwaysGoodOptimizations.PointlessLoadAfterLoadOrStore,
+ LaterOptimizations.PointessLoadingForShifting,
+ AlwaysGoodOptimizations.SimplifiableBitOpsSequence,
+ AlwaysGoodOptimizations.IdempotentDuplicateRemoval,
+ AlwaysGoodOptimizations.BranchInPlaceRemoval,
+ UnusedLabelRemoval,
+ AlwaysGoodOptimizations.UnconditionalJumpRemoval,
+ UnusedLabelRemoval,
+ AlwaysGoodOptimizations.RearrangeMath,
+ LaterOptimizations.PointlessLoadAfterStore,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+ AlwaysGoodOptimizations.PointlessOperationAfterLoad,
+ AlwaysGoodOptimizations.PointlessLoadBeforeTransfer,
+ VariableToRegisterOptimization,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+ AlwaysGoodOptimizations.PointlessOperationPairRemoval,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+ LaterOptimizations.PointlessLoadAfterStore,
+ AlwaysGoodOptimizations.PointlessOperationAfterLoad,
+ AlwaysGoodOptimizations.IdempotentDuplicateRemoval,
+ AlwaysGoodOptimizations.ConstantIndexPropagation,
+ AlwaysGoodOptimizations.PointlessLoadBeforeReturn,
+ AlwaysGoodOptimizations.PoinlessFlagChange,
+ AlwaysGoodOptimizations.FlagFlowAnalysis,
+ AlwaysGoodOptimizations.ConstantFlowAnalysis,
+ AlwaysGoodOptimizations.PointlessMath,
+ VariableToRegisterOptimization,
+ ChangeIndexRegisterOptimizationPreferringX2Y,
+ VariableToRegisterOptimization,
+ ChangeIndexRegisterOptimizationPreferringY2X,
+ VariableToRegisterOptimization,
+ AlwaysGoodOptimizations.ConstantFlowAnalysis,
+ LaterOptimizations.DoubleLoadToDifferentRegisters,
+ LaterOptimizations.DoubleLoadToTheSameRegister,
+ LaterOptimizations.DoubleLoadToDifferentRegisters,
+ LaterOptimizations.DoubleLoadToTheSameRegister,
+ LaterOptimizations.DoubleLoadToDifferentRegisters,
+ LaterOptimizations.DoubleLoadToTheSameRegister,
+ AlwaysGoodOptimizations.PointlessStoreAfterLoad,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+ AlwaysGoodOptimizations.IdempotentDuplicateRemoval,
+ AlwaysGoodOptimizations.ConstantIndexPropagation,
+ AlwaysGoodOptimizations.ConstantFlowAnalysis,
+ AlwaysGoodOptimizations.PointlessRegisterTransfers,
+ AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeCompare,
+ AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn,
+ AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeStore,
+ AlwaysGoodOptimizations.PointlessStashingToIndexOverShortSafeBranch,
+ AlwaysGoodOptimizations.RearrangeMath,
+ AlwaysGoodOptimizations.PointlessStoreAfterLoad,
+ AlwaysGoodOptimizations.PointlessLoadBeforeReturn,
+ LaterOptimizations.PointessLoadingForShifting,
+ AlwaysGoodOptimizations.SimplifiableBitOpsSequence,
+ AlwaysGoodOptimizations.SimplifiableBitOpsSequence,
+ AlwaysGoodOptimizations.SimplifiableBitOpsSequence,
+ AlwaysGoodOptimizations.SimplifiableBitOpsSequence,
+
+ LaterOptimizations.LoadingAfterShifting,
+ AlwaysGoodOptimizations.PointlessStoreAfterLoad,
+ AlwaysGoodOptimizations.PoinlessStoreBeforeStore,
+ LaterOptimizations.PointlessLoadAfterStore,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+
+ LaterOptimizations.LoadingAfterShifting,
+ AlwaysGoodOptimizations.PointlessStoreAfterLoad,
+ AlwaysGoodOptimizations.PoinlessStoreBeforeStore,
+ LaterOptimizations.PointlessLoadAfterStore,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+
+ LaterOptimizations.LoadingAfterShifting,
+ AlwaysGoodOptimizations.PointlessStoreAfterLoad,
+ AlwaysGoodOptimizations.PoinlessStoreBeforeStore,
+ LaterOptimizations.PointlessLoadAfterStore,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+ AlwaysGoodOptimizations.TailCallOptimization,
+ AlwaysGoodOptimizations.UnusedCodeRemoval,
+ AlwaysGoodOptimizations.ReverseFlowAnalysis,
+ AlwaysGoodOptimizations.ModificationOfJustWrittenValue,
+ AlwaysGoodOptimizations.PointlessMathFromFlow,
+ AlwaysGoodOptimizations.PointlessMathFromFlow,
+ AlwaysGoodOptimizations.PointlessMathFromFlow,
+ AlwaysGoodOptimizations.PointlessMathFromFlow,
+ AlwaysGoodOptimizations.PointlessMathFromFlow,
+ AlwaysGoodOptimizations.PointlessMathFromFlow,
+ AlwaysGoodOptimizations.PointlessMathFromFlow,
+ AlwaysGoodOptimizations.PointlessMathFromFlow,
+ AlwaysGoodOptimizations.MathOperationOnTwoIdenticalMemoryOperands,
+ LaterOptimizations.UseZeropageAddressingMode,
+
+ LaterOptimizations.UseXInsteadOfStack,
+ LaterOptimizations.UseYInsteadOfStack,
+ LaterOptimizations.IndexSwitchingOptimization,
+ )
+
+ val Good: List[AssemblyOptimization] = List[AssemblyOptimization](
+ AlwaysGoodOptimizations.Adc0Optimization,
+ AlwaysGoodOptimizations.CarryFlagConversion,
+ DangerousOptimizations.ConstantIndexOffsetPropagation,
+ AlwaysGoodOptimizations.BranchInPlaceRemoval,
+ AlwaysGoodOptimizations.ConstantFlowAnalysis,
+ AlwaysGoodOptimizations.ConstantIndexPropagation,
+ AlwaysGoodOptimizations.FlagFlowAnalysis,
+ AlwaysGoodOptimizations.IdempotentDuplicateRemoval,
+ AlwaysGoodOptimizations.ImpossibleBranchRemoval,
+ AlwaysGoodOptimizations.IndexSequenceOptimization,
+ AlwaysGoodOptimizations.MathOperationOnTwoIdenticalMemoryOperands,
+ AlwaysGoodOptimizations.ModificationOfJustWrittenValue,
+ AlwaysGoodOptimizations.PoinlessFlagChange,
+ AlwaysGoodOptimizations.PointlessLoadAfterLoadOrStore,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+ AlwaysGoodOptimizations.PointlessLoadBeforeReturn,
+ AlwaysGoodOptimizations.PointlessLoadBeforeTransfer,
+ AlwaysGoodOptimizations.PointlessMath,
+ AlwaysGoodOptimizations.PointlessMathFromFlow,
+ AlwaysGoodOptimizations.PointlessOperationAfterLoad,
+ AlwaysGoodOptimizations.PointlessOperationPairRemoval,
+ AlwaysGoodOptimizations.PointlessRegisterTransfers,
+ AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeCompare,
+ AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn,
+ AlwaysGoodOptimizations.PointlessStashingToIndexOverShortSafeBranch,
+ AlwaysGoodOptimizations.PointlessStoreAfterLoad,
+ AlwaysGoodOptimizations.PoinlessStoreBeforeStore,
+ AlwaysGoodOptimizations.RearrangeMath,
+ AlwaysGoodOptimizations.RemoveNops,
+ AlwaysGoodOptimizations.ReverseFlowAnalysis,
+ AlwaysGoodOptimizations.SimplifiableBitOpsSequence,
+ AlwaysGoodOptimizations.SmarterShiftingWords,
+ AlwaysGoodOptimizations.UnconditionalJumpRemoval,
+ UnusedLabelRemoval,
+ AlwaysGoodOptimizations.TailCallOptimization,
+ AlwaysGoodOptimizations.UnusedCodeRemoval,
+ )
+}
diff --git a/src/main/scala/millfork/Platform.scala b/src/main/scala/millfork/Platform.scala
new file mode 100644
index 00000000..1d05edf9
--- /dev/null
+++ b/src/main/scala/millfork/Platform.scala
@@ -0,0 +1,115 @@
+package millfork
+
+import java.io.{File, StringReader}
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files, Paths}
+
+import millfork.error.ErrorReporting
+import millfork.output._
+import org.apache.commons.configuration2.INIConfiguration
+
+/**
+ * @author Karol Stasiak
+ */
+
+class Platform(
+ val cpu: Cpu.Value,
+ val flagOverrides: Map[CompilationFlag.Value, Boolean],
+ val startingModules: List[String],
+ val outputPackager: OutputPackager,
+ val allocator: VariableAllocator,
+ val org: Int,
+ val fileExtension: String,
+ )
+
+object Platform {
+
+ val C64 = new Platform(
+ Cpu.Mos,
+ Map(),
+ List("c64_hardware", "c64_loader"),
+ SequenceOutput(List(StartAddressOutput, AllocatedDataOutput)),
+ new VariableAllocator(
+ List(0xC1, 0xC3, 0xFB, 0xFD, 0x39, 0x3B, 0x3D, 0x43, 0x4B),
+ new AfterCodeByteAllocator(0xA000)
+ ),
+ 0x80D,
+ ".prg"
+ )
+
+ def lookupPlatformFile(includePath: List[String], platformName: String): Platform = {
+ includePath.foreach { dir =>
+ val file = Paths.get(dir, platformName + ".ini").toFile
+ ErrorReporting.debug("Checking " + file)
+ if (file.exists()) {
+ return load(file)
+ }
+ }
+ ErrorReporting.fatal(s"Platfom definition `$platformName` not found", None)
+ }
+
+ def load(file: File): Platform = {
+ val conf = new INIConfiguration()
+ val bytes = Files.readAllBytes(file.toPath)
+ conf.read(new StringReader(new String(bytes, StandardCharsets.UTF_8)))
+
+ val cs = conf.getSection("compilation")
+ val cpu = Cpu.fromString(cs.get(classOf[String], "cpu", "strict"))
+ val flagOverrides = CompilationFlag.fromString.flatMap { case (k, f) =>
+ cs.get(classOf[String], k, "").toLowerCase match {
+ case "" => None
+ case "false" | "off" | "0" => Some(f -> false)
+ case "true" | "on" | "1" => Some(f -> true)
+ }
+ }
+ val startingModules = cs.get(classOf[String], "modules", "").split("[, ]+").filter(_.nonEmpty).toList
+
+ val as = conf.getSection("allocation")
+ val org = as.get(classOf[String], "main_org", "") match {
+ case "" => ErrorReporting.fatal(s"Undefined main_org")
+ case m => parseNumber(m)
+ }
+ val freePointers = as.get(classOf[String], "zp_pointers", "all") match {
+ case "all" => List.tabulate(128)(_ * 2)
+ case xs => xs.split("[, ]+").map(parseNumber).toList
+ }
+ val byteAllocator = (as.get(classOf[String], "himem_start", ""), as.get(classOf[String], "himem_end", "")) match {
+ case ("", _) => ErrorReporting.fatal(s"Undefined himem_start")
+ case (_, "") => ErrorReporting.fatal(s"Undefined himem_end")
+ case ("after_code", end) => new AfterCodeByteAllocator(parseNumber(end) + 1)
+ case (start, end) => new UpwardByteAllocator(parseNumber(start), parseNumber(end) + 1)
+ }
+
+ val os = conf.getSection("output")
+ val outputPackager = SequenceOutput(os.get(classOf[String], "format", "").split("[, ]+").filter(_.nonEmpty).map {
+ case "startaddr" => StartAddressOutput
+ case "endaddr" => EndAddressOutput
+ case "allocated" => AllocatedDataOutput
+ case n => n.split(":").filter(_.nonEmpty) match {
+ case Array(b, s, e) => BankFragmentOutput(parseNumber(b), parseNumber(s), parseNumber(e))
+ case Array(s, e) => CurrentBankFragmentOutput(parseNumber(s), parseNumber(e))
+ case Array(b) => ConstOutput(parseNumber(b).toByte)
+ case x => ErrorReporting.fatal(s"Invalid output format: `$x`")
+ }
+ }.toList)
+ var fileExtension = os.get(classOf[String], "extension", ".bin")
+
+ new Platform(cpu, flagOverrides, startingModules, outputPackager,
+ new VariableAllocator(freePointers, byteAllocator), org,
+ if (fileExtension.startsWith(".")) fileExtension else "." + fileExtension)
+ }
+
+ def parseNumber(s: String): Int = {
+ if (s.startsWith("$")) {
+ Integer.parseInt(s.substring(1), 16)
+ } else if (s.startsWith("0x")) {
+ Integer.parseInt(s.substring(2), 16)
+ } else if (s.startsWith("%")) {
+ Integer.parseInt(s.substring(1), 2)
+ } else if (s.startsWith("0b")) {
+ Integer.parseInt(s.substring(2), 2)
+ } else {
+ s.toInt
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/SeparatedList.scala b/src/main/scala/millfork/SeparatedList.scala
new file mode 100644
index 00000000..c58c848c
--- /dev/null
+++ b/src/main/scala/millfork/SeparatedList.scala
@@ -0,0 +1,50 @@
+package millfork
+
+/**
+ * @author Karol Stasiak
+ */
+case class SeparatedList[T, S](head: T, tail: List[(S, T)]) {
+
+ def toPairList(initialSeparator: S) = (initialSeparator -> head) :: tail
+
+ def size: Int = tail.size + 1
+
+ def items: List[T] = head :: tail.map(_._2)
+
+ def separators: List[S] = tail.map(_._1)
+
+ def drop(i: Int): SeparatedList[T, S] = if (i == 0) this else SeparatedList(tail(i - 1)._2, tail.drop(i))
+
+ def take(i: Int): SeparatedList[T, S] = if (i <= 0) ??? else SeparatedList(head, tail.take(i - 1))
+
+ def splitAt(i: Int): (SeparatedList[T, S], S, SeparatedList[T, S]) = {
+ val (a, b) = tail.splitAt(i - 1)
+ (SeparatedList(head, a), b.head._1, SeparatedList(b.head._2, b.tail))
+ }
+
+ def indexOfSeparator(p: S => Boolean): Int = 1 + tail.indexWhere(x => p(x._1))
+
+ def ::(pair: (T, S)) = SeparatedList(pair._1, (pair._2 -> head) :: tail)
+
+ def split(p: S => Boolean): SeparatedList[SeparatedList[T, S], S] = {
+ val i = indexOfSeparator(p)
+ if (i <= 0) SeparatedList(this, Nil)
+ else {
+ val (a, b, c) = splitAt(i)
+ (a, b) :: c.split(p)
+ }
+ }
+}
+
+object SeparatedList {
+ def of[T, S](t0: T): SeparatedList[T, S] = SeparatedList[T, S](t0, Nil)
+
+ def of[T, S](t0: T, s1: S, t1: T): SeparatedList[T, S] =
+ SeparatedList(t0, List(s1 -> t1))
+
+ def of[T, S](t0: T, s1: S, t1: T, s2: S, t2: T): SeparatedList[T, S] =
+ SeparatedList(t0, List(s1 -> t1, s2 -> t2))
+
+ def of[T, S](t0: T, s1: S, t1: T, s2: S, t2: T, s3: S, t3: T): SeparatedList[T, S] =
+ SeparatedList(t0, List(s1 -> t1, s2 -> t2, s3 -> t3))
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/assembly/AssemblyLine.scala b/src/main/scala/millfork/assembly/AssemblyLine.scala
new file mode 100644
index 00000000..a9a2a868
--- /dev/null
+++ b/src/main/scala/millfork/assembly/AssemblyLine.scala
@@ -0,0 +1,305 @@
+package millfork.assembly
+
+import java.lang.management.MemoryType
+
+import millfork.assembly.Opcode._
+import millfork.assembly.opt.ReadsA
+import millfork.compiler.{CompilationContext, MlCompiler}
+import millfork.env._
+
+//noinspection TypeAnnotation
+object OpcodeClasses {
+
+ val ReadsAAlways = Set(
+ ADC, AND, BIT, CMP, EOR, ORA, PHA, SBC, STA, TAX, TAY,
+ SAX, SBX, ANC, DCP
+ )
+ val ReadsAIfImplied = Set(ASL, LSR, ROL, ROR, INC, DEC)
+ val ReadsXAlways = Set(
+ CPX, DEX, INX, STX, TXA, TXS, SBX,
+ PLX,
+ XAA, SAX, AHX, SHX
+ )
+ val ReadsYAlways = Set(CPY, DEY, INY, STY, TYA, PLY, SHY)
+ val ReadsZ = Set(BNE, BEQ, PHP)
+ val ReadsN = Set(BMI, BPL, PHP)
+ val ReadsNOrZ = ReadsZ ++ ReadsN
+ val ReadsV = Set(BVS, BVC, PHP)
+ val ReadsD = Set(PHP, ADC, SBC, RRA, ARR, ISC, DCP) // TODO: ??
+ val ReadsC = Set(
+ PHP, ADC, SBC, BCC, BCS, ROL, ROR,
+ ALR, ARR, ISC, RLA, RRA, SLO, SRE // TODO: ??
+ )
+ val ChangesAAlways = Set(
+ TXA, TYA, PLA,
+ ORA, AND, EOR, ADC, LDA, SBC,
+ SLO, RLA, SRE, RRA, LAX, ISC,
+ XAA, ANC, ALR, ARR, LXA, LAS
+ )
+ val ChangesAIfImplied = Set(ASL, LSR, ROL, ROR, INC, DEC)
+ val ChangesX = Set(
+ DEX, INX, TAX, LDX, TSX,
+ SBX, LAX, LXA, LAS,
+ PLX,
+ )
+ val ChangesY = Set(
+ DEY, INY, TAY, LDY
+ )
+ val ChangesS = Set(
+ PHA, PLA, PHP, PLP, TXS,
+ PHX, PHY, PLX, PLY, TAS, LAS
+ )
+ val ChangesMemoryAlways = Set(
+ STA, STY, STZ,
+ STX, DEC, INC,
+ SAX, DCP, ISC,
+ SLO, RLA, SRE, RRA,
+ AHX, SHY, SHX, TAS, LAS
+ )
+ val ChangesMemoryIfNotImplied = Set(
+ ASL, ROL, LSR, ROR
+ )
+ val ReadsMemoryIfNotImpliedOrImmediate = Set(
+ LDY, CPX, CPY,
+ ORA, AND, EOR, ADC, LDA, CMP, SBC,
+ ASL, ROL, LSR, ROR, LDX, DEC, INC,
+ SLO, RLA, SRE, RRA, LAX, DCP, ISC,
+ LAS,
+ TRB, TSB
+ )
+ val OverwritesA = Set(
+ LDA, PLA, TXA, TYA,
+ LAX, LAS
+ )
+ val OverwritesX = Set(
+ TAX, LDX, TSX, PLX,
+ LAX, LAS
+ )
+ val OverwritesY = Set(
+ TAY, LDY, PLY
+ )
+ val OverwritesC = Set(CLC, SEC, PLP)
+ val OverwritesD = Set(CLD, SED, PLP)
+ val OverwritesI = Set(CLI, SEI, PLP)
+ val OverwritesV = Set(CLV, PLP)
+ val ConcernsAAlways = ReadsAAlways ++ ChangesAAlways
+ val ConcernsAIfImplied = ReadsAIfImplied ++ ChangesAIfImplied
+ val ConcernsXAlways = ReadsXAlways | ChangesX
+ val ConcernsYAlways = ReadsYAlways | ChangesY
+
+ val ConcernsStack = Set(
+ PHA, PLA, PHP, PLP,
+ PHX, PLX, PHY, PLY,
+ TXS, TSX,
+ JSR, RTS, RTI,
+ TAS, LAS,
+ )
+
+ val ChangesNAndZ = Set(
+ ADC, AND, ASL, BIT, CMP, CPX, CPY, DEC, DEX, DEY, EOR, INC, INX, INY, LDA,
+ LDX, LDY, LSR, ORA, PLP, ROL, ROR, SBC, TAX, TAY, TXA, TYA,
+ LAX, SBX, ANC, ALR, ARR, DCP, ISC, RLA, RRA, SLO, SRE, SAX,
+ TSB, TRB // These two do not change N, but lets pretend they do for simplicity
+ )
+ val ChangesC = Set(
+ CLC, SEC, ADC, ASL, CMP, CPX, CPY, LSR, PLP, ROL, ROR, SBC,
+ SBX, ANC, ALR, ARR, DCP, ISC, RLA, RRA, SLO, SRE
+ )
+ val ChangesV = Set(
+ ADC, BIT, PLP, SBC,
+ ARR, ISC, RRA,
+ )
+
+ val SupportsAbsoluteX = Set(
+ ORA, AND, EOR, ADC, CMP, SBC,
+ ASL, ROL, LSR, ROR, DEC, INC,
+ SLO, RLA, SRE, RRA, DCP, ISC,
+ STA, LDA, LDY, STZ, SHY,
+ )
+
+ val SupportsAbsoluteY = Set(
+ ORA, AND, EOR, ADC, CMP, SBC,
+ SLO, RLA, SRE, RRA, DCP, ISC,
+ STA, LDA, LDX,
+ LAX, AHX, SHX, TAS, LAS,
+ )
+
+ val SupportsAbsolute = Set(
+ ORA, AND, EOR, ADC, STA, LDA, CMP, SBC,
+ ASL, ROL, LSR, ROR, STX, LDX, DEC, INC,
+ SLO, RLA, SRE, RRA, SAX, LAX, DCP, ISC,
+ STY, LDY,
+ BIT, JMP, JSR,
+ STZ, TRB, TSB,
+ )
+
+ val SupportsZeroPageIndirect = Set(ORA, AND, EOR, ADC, STA, LDA, CMP, SBC)
+
+ val ShortBranching = Set(BEQ, BNE, BMI, BPL, BVC, BVS, BCC, BCS, BRA)
+ val AllDirectJumps = ShortBranching + JMP
+ val AllLinear = Set(
+ ORA, AND, EOR,
+ ADC, SBC, CMP, CPX, CPY,
+ DEC, DEX, DEY, INC, INX, INY,
+ ASL, ROL, LSR, ROR,
+ LDA, STA, LDX, STX, LDY, STY,
+ TAX, TXA, TAY, TYA, TXS, TSX,
+ PLA, PLP, PHA, PHP,
+ BIT, NOP,
+ CLC, SEC, CLD, SED, CLI, SEI, CLV,
+ STZ, PHX, PHY, PLX, PLY, TSB, TRB,
+ SLO, RLA, SRE, RRA, SAX, LAX, DCP, ISC,
+ ANC, ALR, ARR, XAA, LXA, SBX,
+ DISCARD_AF, DISCARD_XF, DISCARD_YF)
+
+ val NoopDiscardsFlags = Set(DISCARD_AF, DISCARD_XF, DISCARD_YF)
+ val DiscardsV = NoopDiscardsFlags | OverwritesV
+ val DiscardsC = NoopDiscardsFlags | OverwritesC
+ val DiscardsD = OverwritesD
+ val DiscardsI = NoopDiscardsFlags | OverwritesI
+
+}
+
+object AssemblyLine {
+
+ def treatment(lines: List[AssemblyLine], state: State.Value): Treatment.Value =
+ lines.map(_.treatment(state)).foldLeft(Treatment.Unchanged)(_ ~ _)
+
+ def label(label: String): AssemblyLine = AssemblyLine.label(Label(label))
+
+ def label(label: Label): AssemblyLine = AssemblyLine(LABEL, AddrMode.DoesNotExist, label.toAddress)
+
+ def discardAF() = AssemblyLine(DISCARD_AF, AddrMode.DoesNotExist, Constant.Zero)
+
+ def discardXF() = AssemblyLine(DISCARD_XF, AddrMode.DoesNotExist, Constant.Zero)
+
+ def discardYF() = AssemblyLine(DISCARD_YF, AddrMode.DoesNotExist, Constant.Zero)
+
+ def immediate(opcode: Opcode.Value, value: Int) = AssemblyLine(opcode, AddrMode.Immediate, NumericConstant(value, 1))
+
+ def immediate(opcode: Opcode.Value, value: Constant) = AssemblyLine(opcode, AddrMode.Immediate, value)
+
+ def implied(opcode: Opcode.Value) = AssemblyLine(opcode, AddrMode.Implied, Constant.Zero)
+
+ def variable(ctx: CompilationContext, opcode: Opcode.Value, variable: Variable, offset: Int = 0): List[AssemblyLine] =
+ variable match {
+ case v@MemoryVariable(_, _, VariableAllocationMethod.Zeropage) =>
+ List(AssemblyLine.zeropage(opcode, v.toAddress + offset))
+ case v@RelativeVariable(_, _, _, true) =>
+ List(AssemblyLine.zeropage(opcode, v.toAddress + offset))
+ case v:VariableInMemory => List(AssemblyLine.absolute(opcode, v.toAddress + offset))
+ case v:StackVariable=> List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(opcode, v.baseOffset + offset + ctx.extraStackOffset))
+ }
+
+ def zeropage(opcode: Opcode.Value, addr: Constant) =
+ AssemblyLine(opcode, AddrMode.ZeroPage, addr)
+
+ def zeropage(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
+ AssemblyLine(opcode, AddrMode.ZeroPage, thing.toAddress + offset)
+
+ def absolute(opcode: Opcode.Value, addr: Constant) =
+ AssemblyLine(opcode, AddrMode.Absolute, addr)
+
+ def absolute(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
+ AssemblyLine(opcode, AddrMode.Absolute, thing.toAddress + offset)
+
+ def relative(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
+ AssemblyLine(opcode, AddrMode.Relative, thing.toAddress + offset)
+
+ def relative(opcode: Opcode.Value, label: String) =
+ AssemblyLine(opcode, AddrMode.Relative, Label(label).toAddress)
+
+ def absoluteY(opcode: Opcode.Value, addr: Constant) =
+ AssemblyLine(opcode, AddrMode.AbsoluteY, addr)
+
+ def absoluteY(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
+ AssemblyLine(opcode, AddrMode.AbsoluteY, thing.toAddress + offset)
+
+ def absoluteX(opcode: Opcode.Value, addr: Int) =
+ AssemblyLine(opcode, AddrMode.AbsoluteX, NumericConstant(addr, 2))
+
+ def absoluteX(opcode: Opcode.Value, addr: Constant) =
+ AssemblyLine(opcode, AddrMode.AbsoluteX, addr)
+
+ def absoluteX(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
+ AssemblyLine(opcode, AddrMode.AbsoluteX, thing.toAddress + offset)
+
+ def indexedY(opcode: Opcode.Value, addr: Constant) =
+ AssemblyLine(opcode, AddrMode.IndexedY, addr)
+
+ def indexedY(opcode: Opcode.Value, thing: ThingInMemory, offset: Int = 0) =
+ AssemblyLine(opcode, AddrMode.IndexedY, thing.toAddress + offset)
+}
+
+case class AssemblyLine(opcode: Opcode.Value, addrMode: AddrMode.Value, var parameter: Constant, elidable: Boolean = true) {
+
+
+ import AddrMode._
+ import State._
+ import OpcodeClasses._
+ import Treatment._
+
+ def reads(state: State.Value): Boolean = state match {
+ case A => if (addrMode == Implied) ReadsAIfImplied(opcode) else ReadsAAlways(opcode)
+ case X => addrMode == AbsoluteX || addrMode == ZeroPageX || addrMode == IndexedX || ReadsXAlways(opcode)
+ case Y => addrMode == AbsoluteY || addrMode == ZeroPageY || addrMode == IndexedY || ReadsYAlways(opcode)
+ case C => ReadsC(opcode)
+ case D => ReadsD(opcode)
+ case N => ReadsN(opcode)
+ case V => ReadsV(opcode)
+ case Z => ReadsZ(opcode)
+ }
+
+ def treatment(state: State.Value): Treatment.Value = opcode match {
+ case LABEL => Unchanged // TODO: ???
+ case NOP => Unchanged
+ case JSR | JMP | BEQ | BNE | BMI | BPL | BRK | BCC | BVC | BCS | BVS => Changed
+ case CLC => if (state == C) Cleared else Unchanged
+ case SEC => if (state == C) Set else Unchanged
+ case CLV => if (state == V) Cleared else Unchanged
+ case CLD => if (state == D) Cleared else Unchanged
+ case SED => if (state == D) Set else Unchanged
+ case _ => state match { // TODO: smart detection of constants
+ case A =>
+ if (ChangesAAlways(opcode) || addrMode == Implied && ChangesAIfImplied(opcode))
+ Changed
+ else
+ Unchanged
+ case X => if (ChangesX(opcode)) Changed else Unchanged
+ case Y => if (ChangesY(opcode)) Changed else Unchanged
+ case C => if (ChangesC(opcode)) Changed else Unchanged
+ case V => if (ChangesV(opcode)) Changed else Unchanged
+ case N | Z => if (ChangesNAndZ(opcode)) Changed else Unchanged
+ case D => Unchanged
+ }
+ }
+
+ def sizeInBytes: Int = addrMode match {
+ case Implied => 1
+ case Relative | ZeroPageX | ZeroPage | ZeroPageY | IndexedX | IndexedY | Immediate => 2
+ case AbsoluteX | Absolute | AbsoluteY | Indirect => 3
+ case DoesNotExist => 0
+ }
+
+ def cost: Int = addrMode match {
+ case Implied => 1000
+ case Relative | Immediate => 2000
+ case ZeroPage => 2001
+ case ZeroPageX | ZeroPageY => 2002
+ case IndexedX | IndexedY => 2003
+ case Absolute => 3000
+ case AbsoluteX | AbsoluteY | Indirect => 3001
+ case DoesNotExist => 1
+ }
+
+ def isPrintable: Boolean = true //addrMode != AddrMode.DoesNotExist || opcode == LABEL
+
+ override def toString: String =
+ if (opcode == LABEL) {
+ parameter.toString
+ } else if (addrMode == DoesNotExist) {
+ s" ; $opcode"
+ } else {
+ s" $opcode ${AddrMode.addrModeToString(addrMode, parameter.toString)}"
+ }
+}
diff --git a/src/main/scala/millfork/assembly/Chunk.scala b/src/main/scala/millfork/assembly/Chunk.scala
new file mode 100644
index 00000000..117587b1
--- /dev/null
+++ b/src/main/scala/millfork/assembly/Chunk.scala
@@ -0,0 +1,33 @@
+package millfork.assembly
+
+import millfork.env.Label
+
+sealed trait Chunk {
+ def linearize: List[AssemblyLine]
+ def sizeInBytes: Int
+}
+
+case object EmptyChunk extends Chunk {
+ override def linearize: Nil.type = Nil
+
+ override def sizeInBytes = 0
+}
+
+case class LabelledChunk(label: String, chunk: Chunk) extends Chunk {
+ override def linearize: List[AssemblyLine] = AssemblyLine.label(Label(label)) :: chunk.linearize
+
+ override def sizeInBytes: Int = chunk.sizeInBytes
+}
+
+case class SequenceChunk(chunks: List[Chunk]) extends Chunk {
+ override def linearize: List[AssemblyLine] = chunks.flatMap(_.linearize)
+
+ override def sizeInBytes: Int = chunks.map(_.sizeInBytes).sum
+}
+
+case class LinearChunk(lines: List[AssemblyLine]) extends Chunk {
+ def linearize: List[AssemblyLine] = lines
+
+ override def sizeInBytes: Int = lines.map(_.sizeInBytes).sum
+
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/assembly/Opcode.scala b/src/main/scala/millfork/assembly/Opcode.scala
new file mode 100644
index 00000000..55e5b099
--- /dev/null
+++ b/src/main/scala/millfork/assembly/Opcode.scala
@@ -0,0 +1,190 @@
+package millfork.assembly
+
+import java.util.Locale
+
+import millfork.error.ErrorReporting
+import millfork.node.Position
+
+object State extends Enumeration {
+ val A, X, Y, Z, D, C, N, V = Value
+}
+
+object Treatment extends Enumeration {
+ val Unchanged, Unsure, Changed, Cleared, Set = Value
+
+ implicit class OverriddenValue(val left: Value) extends AnyVal {
+ def ~(right: Treatment.Value): Treatment.Value = right match {
+ case Unchanged => left
+ case Cleared | Set => if (left == Unsure) Changed else right
+ case _ => right
+ }
+ }
+
+}
+
+object Opcode extends Enumeration {
+ val ADC, AND, ASL,
+ BIT, BNE, BEQ, BPL, BMI, BVS, BVC, BCC, BCS, BRK,
+ CMP, CPX, CPY, CLV, CLC, CLI, CLD,
+ DEC, DEX, DEY,
+ EOR,
+ INC, INX, INY,
+ JMP, JSR,
+ LDA, LDX, LDY, LSR,
+ NOP,
+ ORA,
+ PHA, PHP, PLA, PLP,
+ ROL, ROR, RTS, RTI,
+ SBC, SEC, SED, SEI, STA, STX, STY,
+ TAX, TAY, TXA, TXS, TSX, TYA,
+
+ LXA, XAA, ANC, ARR, ALR, SBX,
+ LAX, SAX, RLA, RRA, SLO, SRE, DCP, ISC,
+ TAS, LAS, SHX, SHY, AHX,
+ STZ, PHX, PHY, PLX, PLY,
+ BRA, TRB, TSB, STP, WAI,
+ DISCARD_AF, DISCARD_XF, DISCARD_YF,
+ LABEL = Value
+
+ def lookup(opcode: String, position: Option[Position]): Opcode.Value = opcode.toUpperCase(Locale.ROOT) match {
+ case "ADC" => ADC
+ case "AHX" => AHX
+ case "ALR" => ALR
+ case "ANC" => ANC
+ case "AND" => AND
+ case "ANE" => XAA
+ case "ARR" => ARR
+ case "ASL" => ASL
+ case "ASO" => SLO
+ case "AXA" => AHX
+ case "AXS" => SBX // TODO: could mean SAX
+ case "BCC" => BCC
+ case "BCS" => BCS
+ case "BEQ" => BEQ
+ case "BIT" => BIT
+ case "BMI" => BMI
+ case "BNE" => BNE
+ case "BPL" => BPL
+ case "BRA" => BRA
+ case "BRK" => BRK
+ case "BVC" => BVC
+ case "BVS" => BVS
+ case "CLC" => CLC
+ case "CLD" => CLD
+ case "CLI" => CLI
+ case "CLV" => CLV
+ case "CMP" => CMP
+ case "CPX" => CPX
+ case "CPY" => CPY
+ case "DCM" => DCP
+ case "DCP" => DCP
+ case "DEC" => DEC
+ case "DEX" => DEX
+ case "DEY" => DEY
+ case "EOR" => EOR
+ case "INC" => INC
+ case "INS" => ISC
+ case "INX" => INX
+ case "INY" => INY
+ case "ISC" => ISC
+ case "JMP" => JMP
+ case "JSR" => JSR
+ case "LAS" => LAS
+ case "LAX" => LAX
+ case "LDA" => LDA
+ case "LDX" => LDX
+ case "LDY" => LDY
+ case "LSE" => SRE
+ case "LSR" => LSR
+ case "LXA" => LXA
+ case "NOP" => NOP
+ case "OAL" => LXA
+ case "ORA" => ORA
+ case "PHA" => PHA
+ case "PHP" => PHP
+ case "PHX" => PHX
+ case "PHY" => PHY
+ case "PLA" => PLA
+ case "PLP" => PLP
+ case "PLX" => PLX
+ case "PLY" => PLY
+ case "RLA" => RLA
+ case "ROL" => ROL
+ case "ROR" => ROR
+ case "RRA" => RRA
+ case "RTI" => RTI
+ case "RTS" => RTS
+ case "SAX" => SAX // TODO: could mean SBX
+ case "SAY" => SHY
+ case "SBC" => SBC
+ case "SBX" => SBX
+ case "SEC" => SEC
+ case "SED" => SED
+ case "SEI" => SEI
+ case "SHX" => SHX
+ case "SHY" => SHY
+ case "SLO" => SLO
+ case "SRE" => SRE
+ case "STA" => STA
+ case "STP" => STP
+ case "STX" => STX
+ case "STY" => STY
+ case "STZ" => STZ
+ case "TAS" => TAS
+ case "TAX" => TAX
+ case "TAY" => TAY
+ case "TRB" => TRB
+ case "TSB" => TSB
+ case "TSX" => TSX
+ case "TXA" => TXA
+ case "TXS" => TXS
+ case "TYA" => TYA
+ case "WAI" => WAI
+ case "XAA" => XAA
+ case "XAS" => SHX
+ case _ =>
+ ErrorReporting.error(s"Invalid opcode `$opcode`", position)
+ LABEL
+ }
+
+}
+
+object AddrMode extends Enumeration {
+ val Implied,
+ Immediate,
+ Relative,
+ ZeroPage,
+ ZeroPageX,
+ ZeroPageY,
+ Absolute,
+ AbsoluteX,
+ AbsoluteY,
+ Indirect,
+ IndexedX,
+ IndexedY,
+ AbsoluteIndexedX,
+ ZeroPageIndirect,
+ Undecided,
+ DoesNotExist = Value
+
+
+ def argumentLength(a: AddrMode.Value): Int = a match {
+ case Absolute | AbsoluteX | AbsoluteY | Indirect =>
+ 2
+ case _ =>
+ 1
+ }
+
+ def addrModeToString(am: AddrMode.Value, argument: String): String = {
+ am match {
+ case Implied => ""
+ case Immediate => "#" + argument
+ case AbsoluteX | ZeroPageX => argument + ", X"
+ case AbsoluteY | ZeroPageY => argument + ", Y"
+ case IndexedX | AbsoluteIndexedX => "(" + argument + ", X)"
+ case IndexedY => "(" + argument + "), Y"
+ case Indirect | ZeroPageIndirect => "(" + argument + ")"
+ case _ => argument;
+ }
+ }
+}
diff --git a/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala b/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala
new file mode 100644
index 00000000..fbf57de3
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala
@@ -0,0 +1,848 @@
+package millfork.assembly.opt
+
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicInteger
+
+import millfork.assembly.{opt, _}
+import millfork.assembly.Opcode._
+import millfork.assembly.AddrMode._
+import millfork.assembly.OpcodeClasses._
+import millfork.env._
+
+/**
+ * These optimizations should not remove opportunities for more complex optimizations to trigger.
+ *
+ * @author Karol Stasiak
+ */
+object AlwaysGoodOptimizations {
+
+ val counter = new AtomicInteger(30000)
+
+ def getNextLabel(prefix: String) = f".${prefix}%s__${counter.getAndIncrement()}%05d"
+
+ val PointlessMath = new RuleBasedAssemblyOptimization("Pointless math",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (HasOpcode(CLC) & Elidable) ~
+ (HasOpcode(ADC) & Elidable & MatchParameter(0)) ~
+ (HasOpcode(SEC) & Elidable) ~
+ (HasOpcode(SBC) & Elidable & MatchParameter(0)) ~
+ (LinearOrLabel & Not(ReadsNOrZ) & Not(ReadsV) & Not(ReadsC) & Not(NoopDiscardsFlags) & Not(Set(ADC, SBC))).* ~
+ (NoopDiscardsFlags | Set(ADC, SBC)) ~~> (_.drop(4)),
+ (HasOpcode(LDA) & HasImmediate(0) & Elidable) ~
+ (HasOpcode(CLC) & Elidable) ~
+ (HasOpcode(ADC) & Elidable) ~
+ (LinearOrLabel & Not(ReadsV) & Not(NoopDiscardsFlags) & Not(ChangesNAndZ)).* ~
+ (NoopDiscardsFlags | ChangesNAndZ) ~~> (code => code(2).copy(opcode = LDA) :: code.drop(3))
+ )
+
+ val PointlessMathFromFlow = new RuleBasedAssemblyOptimization("Pointless math from flow analysis",
+ needsFlowInfo = FlowInfoRequirement.BothFlows,
+ (Elidable & MatchA(0) &
+ HasOpcode(ASL) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, ctx.get[Int](0) << 1) :: Nil
+ },
+ (Elidable & MatchA(0) &
+ HasOpcode(LSR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, (ctx.get[Int](0) & 0xff) >> 1) :: Nil
+ },
+ (Elidable & MatchA(0) &
+ HasClear(State.C) & HasOpcode(ROL) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, ctx.get[Int](0) << 1) :: Nil
+ },
+ (Elidable & MatchA(0) &
+ HasClear(State.C) & HasOpcode(ROR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, (ctx.get[Int](0) & 0xff) >> 1) :: Nil
+ },
+ (Elidable & MatchA(0) &
+ HasSet(State.C) & HasOpcode(ROL) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, ctx.get[Int](0) * 2 + 1) :: Nil
+ },
+ (Elidable & MatchA(0) &
+ HasSet(State.C) & HasOpcode(ROR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, 0x80 + (ctx.get[Int](0) & 0xff) / 2) :: Nil
+ },
+ (Elidable &
+ MatchA(0) & MatchParameter(1) &
+ HasOpcode(ADC) & HasAddrMode(Immediate) &
+ HasClear(State.D) & HasClear(State.C) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, ctx.get[Constant](1) + ctx.get[Int](0)) :: Nil
+ },
+ (Elidable &
+ MatchA(0) & MatchParameter(1) &
+ HasOpcode(ADC) & HasAddrMode(Immediate) &
+ HasClear(State.D) & HasClear(State.C) & DoesntMatterWhatItDoesWith(State.V)) ~
+ Where(ctx => (ctx.get[Constant](1) + ctx.get[Int](0)).quickSimplify match {
+ case NumericConstant(x, _) => x == (x & 0xff)
+ case _ => false
+ }) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, ctx.get[Constant](1) + ctx.get[Int](0)) :: Nil
+ },
+ (Elidable &
+ MatchA(0) & MatchParameter(1) &
+ HasOpcode(ADC) & HasAddrMode(Immediate) &
+ HasClear(State.D) & HasSet(State.C) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, ctx.get[Constant](1) + ((ctx.get[Int](0) + 1) & 0xff)) :: Nil
+ },
+ (Elidable &
+ MatchA(0) & MatchParameter(1) &
+ HasOpcode(SBC) & HasAddrMode(Immediate) &
+ HasClear(State.D) & HasSet(State.C) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.Minus, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
+ },
+ (Elidable &
+ MatchA(0) & MatchParameter(1) &
+ HasOpcode(EOR) & HasAddrMode(Immediate)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.Exor, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
+ },
+ (Elidable &
+ MatchA(0) & MatchParameter(1) &
+ HasOpcode(ORA) & HasAddrMode(Immediate)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.Or, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
+ },
+ (Elidable &
+ MatchA(0) & MatchParameter(1) &
+ HasOpcode(AND) & HasAddrMode(Immediate)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.And, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
+ },
+ (Elidable &
+ MatchA(0) & MatchParameter(1) &
+ HasOpcode(ANC) & HasAddrMode(Immediate) & DoesntMatterWhatItDoesWith(State.C)) ~~> { (code, ctx) =>
+ AssemblyLine.immediate(LDA, CompoundConstant(MathOperator.And, NumericConstant(ctx.get[Int](0), 1), ctx.get[Constant](1)).quickSimplify) :: Nil
+ },
+ )
+
+ val MathOperationOnTwoIdenticalMemoryOperands = new RuleBasedAssemblyOptimization("Math operation on two identical memory operands",
+ needsFlowInfo = FlowInfoRequirement.BothFlows,
+ (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~
+ (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~
+ (HasClear(State.D) & HasClear(State.C) & HasOpcode(ADC) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.implied(ASL)),
+
+ (HasOpcodeIn(Set(STA, LDA)) & HasAddrMode(AbsoluteX) & MatchAddrMode(9) & MatchParameter(0)) ~
+ (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA) & Not(ChangesX)).* ~
+ (HasClear(State.D) & HasClear(State.C) & HasOpcode(ADC) & HasAddrMode(AbsoluteX) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.implied(ASL)),
+
+ (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrMode(AbsoluteY) & MatchAddrMode(9) & MatchParameter(0)) ~
+ (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA) & Not(ChangesY)).* ~
+ (HasClear(State.D) & HasClear(State.C) & HasOpcode(ADC) & HasAddrMode(AbsoluteY) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.implied(ASL)),
+
+ (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~
+ (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~
+ (DoesntMatterWhatItDoesWith(State.N, State.Z) & HasOpcodeIn(Set(ORA, AND)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init),
+
+ (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~
+ (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~
+ (DoesntMatterWhatItDoesWith(State.N, State.Z, State.C) & HasOpcode(ANC) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init),
+
+ (HasOpcodeIn(Set(STA, LDA, LAX)) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchAddrMode(9) & MatchParameter(0)) ~
+ (Linear & DoesntChangeMemoryAt(9, 0) & Not(ChangesA)).* ~
+ (HasOpcode(EOR) & HasAddrModeIn(Set(ZeroPage, Absolute)) & MatchParameter(0) & Elidable) ~~> (code => code.init :+ AssemblyLine.immediate(LDA, 0)),
+ )
+
+ val PointlessStoreAfterLoad = new RuleBasedAssemblyOptimization("Pointless store after load",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (LinearOrLabel & DoesntChangeMemoryAt(0,1) & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0)).* ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
+ (HasOpcode(LDX) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (LinearOrLabel & DoesntChangeMemoryAt(0,1) & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0)).* ~
+ (Elidable & HasOpcode(STX) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
+ (HasOpcode(LDY) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (LinearOrLabel & DoesntChangeMemoryAt(0,1) & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0)).* ~
+ (Elidable & HasOpcode(STY) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
+ )
+
+ val PoinlessStoreBeforeStore = new RuleBasedAssemblyOptimization("Pointless store before store",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasAddrModeIn(Set(Absolute, ZeroPage)) & MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STY, STZ)) ~
+ (LinearOrLabel & DoesNotConcernMemoryAt(2, 1)).* ~
+ (MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STY, STZ)) ~~> (_.tail),
+ (Elidable & HasAddrModeIn(Set(AbsoluteX, ZeroPageX)) & MatchParameter(1) & MatchAddrMode(2) & Set(STA, STY, STZ)) ~
+ (LinearOrLabel & DoesntChangeMemoryAt(2, 1) & Not(ReadsMemory) & Not(ChangesX)).* ~
+ (MatchParameter(1) & MatchAddrMode(2) & Set(STA, STY, STZ)) ~~> (_.tail),
+ (Elidable & HasAddrModeIn(Set(AbsoluteY, ZeroPageY)) & MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STZ)) ~
+ (LinearOrLabel & DoesntChangeMemoryAt(2, 1) & Not(ReadsMemory) & Not(ChangesY)).* ~
+ (MatchParameter(1) & MatchAddrMode(2) & Set(STA, SAX, STX, STZ)) ~~> (_.tail),
+ )
+
+ val PointlessLoadBeforeReturn = new RuleBasedAssemblyOptimization("Pointless load before return",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Set(LDA, TXA, TYA, EOR, AND, ORA, ANC) & Elidable) ~ (LinearOrLabel & Not(ConcernsA) & Not(ReadsNOrZ) & Not(HasOpcode(DISCARD_AF))).* ~ HasOpcode(DISCARD_AF) ~~> (_.tail),
+ (Set(LDX, TAX, TSX, INX, DEX) & Elidable) ~ (LinearOrLabel & Not(ConcernsX) & Not(ReadsNOrZ) & Not(HasOpcode(DISCARD_XF))).* ~ HasOpcode(DISCARD_XF) ~~> (_.tail),
+ (Set(LDY, TAY, INY, DEY) & Elidable) ~ (LinearOrLabel & Not(ConcernsY) & Not(ReadsNOrZ) & Not(HasOpcode(DISCARD_YF))).* ~ HasOpcode(DISCARD_YF) ~~> (_.tail),
+ (HasOpcode(LDX) & Elidable & MatchAddrMode(3)) ~
+ (LinearOrLabel & Not(ConcernsX) & Not(ReadsNOrZ) & DoesntChangeIndexingInAddrMode(3)).*.capture(1) ~
+ (HasOpcode(TXA) & Elidable) ~
+ ((LinearOrLabel & Not(ConcernsX) & Not(HasOpcode(DISCARD_XF))).* ~
+ HasOpcode(DISCARD_XF)).capture(2) ~~> { (c, ctx) =>
+ ctx.get[List[AssemblyLine]](1) ++ (c.head.copy(opcode = LDA) :: ctx.get[List[AssemblyLine]](2))
+ },
+ (HasOpcode(LDY) & Elidable & MatchAddrMode(3)) ~
+ (LinearOrLabel & Not(ConcernsY) & Not(ReadsNOrZ) & DoesntChangeIndexingInAddrMode(3)).*.capture(1) ~
+ (HasOpcode(TYA) & Elidable) ~
+ ((LinearOrLabel & Not(ConcernsY) & Not(HasOpcode(DISCARD_YF))).* ~
+ HasOpcode(DISCARD_YF)).capture(2) ~~> { (c, ctx) =>
+ ctx.get[List[AssemblyLine]](1) ++ (c.head.copy(opcode = LDA) :: ctx.get[List[AssemblyLine]](2))
+ },
+ )
+
+ private def operationPairBuilder(op1: Opcode.Value, op2: Opcode.Value, middle: AssemblyLinePattern) = {
+ (HasOpcode(op1) & Elidable) ~
+ (Linear & middle).*.capture(1) ~
+ (HasOpcode(op2) & Elidable) ~
+ ((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(2) ~~> { (_, ctx) =>
+ ctx.get[List[AssemblyLine]](1) ++ ctx.get[List[AssemblyLine]](2)
+ }
+ }
+
+ val PointlessOperationPairRemoval = new RuleBasedAssemblyOptimization("Pointless operation pair",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ operationPairBuilder(PHA, PLA, Not(ConcernsA) & Not(ConcernsStack)),
+ operationPairBuilder(PHX, PLX, Not(ConcernsX) & Not(ConcernsStack)),
+ operationPairBuilder(PHY, PLY, Not(ConcernsY) & Not(ConcernsStack)),
+ operationPairBuilder(INX, DEX, Not(ConcernsX) & Not(ReadsNOrZ)),
+ operationPairBuilder(DEX, INX, Not(ConcernsX) & Not(ReadsNOrZ)),
+ operationPairBuilder(INY, DEY, Not(ConcernsX) & Not(ReadsNOrZ)),
+ operationPairBuilder(DEY, INY, Not(ConcernsX) & Not(ReadsNOrZ)),
+ )
+
+
+ val BranchInPlaceRemoval = new RuleBasedAssemblyOptimization("Branch in place",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (AllDirectJumps & MatchParameter(0) & Elidable) ~
+ HasOpcodeIn(NoopDiscardsFlags).* ~
+ (HasOpcode(LABEL) & MatchParameter(0)) ~~> (c => c.last :: Nil)
+ )
+
+ val ImpossibleBranchRemoval = new RuleBasedAssemblyOptimization("Impossible branch",
+ needsFlowInfo = FlowInfoRequirement.ForwardFlow,
+ (HasOpcode(BCC) & HasSet(State.C) & Elidable) ~~> (_ => Nil),
+ (HasOpcode(BCS) & HasClear(State.C) & Elidable) ~~> (_ => Nil),
+ (HasOpcode(BVC) & HasSet(State.V) & Elidable) ~~> (_ => Nil),
+ (HasOpcode(BVS) & HasClear(State.V) & Elidable) ~~> (_ => Nil),
+ (HasOpcode(BNE) & HasSet(State.Z) & Elidable) ~~> (_ => Nil),
+ (HasOpcode(BEQ) & HasClear(State.Z) & Elidable) ~~> (_ => Nil),
+ (HasOpcode(BPL) & HasSet(State.N) & Elidable) ~~> (_ => Nil),
+ (HasOpcode(BMI) & HasClear(State.N) & Elidable) ~~> (_ => Nil),
+ )
+
+ val UnconditionalJumpRemoval = new RuleBasedAssemblyOptimization("Unconditional jump removal",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasOpcode(JMP) & HasAddrMode(Absolute) & MatchParameter(0)) ~
+ (Elidable & LinearOrBranch).* ~
+ (HasOpcode(LABEL) & MatchParameter(0)) ~~> (_ => Nil),
+ (Elidable & HasOpcode(JMP) & HasAddrMode(Absolute) & MatchParameter(0)) ~
+ (Not(HasOpcode(LABEL)) & Not(MatchParameter(0))).* ~
+ (HasOpcode(LABEL) & MatchParameter(0)) ~
+ (HasOpcode(LABEL) | HasOpcodeIn(NoopDiscardsFlags)).* ~
+ HasOpcode(RTS) ~~> (code => AssemblyLine.implied(RTS) :: code.tail),
+ (Elidable & HasOpcodeIn(ShortBranching) & MatchParameter(0)) ~
+ (HasOpcodeIn(NoopDiscardsFlags).* ~
+ (Elidable & HasOpcode(RTS))).capture(1) ~
+ (HasOpcode(LABEL) & MatchParameter(0)) ~
+ HasOpcodeIn(NoopDiscardsFlags).* ~
+ (Elidable & HasOpcode(RTS)) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1)),
+ )
+
+ val TailCallOptimization = new RuleBasedAssemblyOptimization("Tail call optimization",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasOpcode(JSR)) ~ HasOpcodeIn(NoopDiscardsFlags).* ~ (Elidable & HasOpcode(RTS)) ~~> (c => c.tail.init :+ c.head.copy(opcode = JMP)),
+ (Elidable & HasOpcode(JSR)) ~
+ HasOpcode(LABEL).* ~
+ HasOpcodeIn(NoopDiscardsFlags).*.capture(0) ~
+ HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](0) ++ (code.head.copy(opcode = JMP) :: code.tail)),
+ )
+
+ val UnusedCodeRemoval = new RuleBasedAssemblyOptimization("Unreachable code removal",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ HasOpcode(JMP) ~ (Not(HasOpcode(LABEL)) & Elidable).+ ~~> (c => c.head :: Nil)
+ )
+
+ val PoinlessFlagChange = new RuleBasedAssemblyOptimization("Pointless flag change",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (HasOpcodeIn(Set(CMP, CPX, CPY)) & Elidable) ~ NoopDiscardsFlags ~~> (_.tail),
+ (OverwritesC & Elidable) ~ (LinearOrLabel & Not(ReadsC) & Not(DiscardsC)).* ~ DiscardsC ~~> (_.tail),
+ (OverwritesD & Elidable) ~ (LinearOrLabel & Not(ReadsD) & Not(DiscardsD)).* ~ DiscardsD ~~> (_.tail),
+ (OverwritesV & Elidable) ~ (LinearOrLabel & Not(ReadsV) & Not(DiscardsV)).* ~ DiscardsV ~~> (_.tail)
+ )
+
+ val FlagFlowAnalysis = new RuleBasedAssemblyOptimization("Flag flow analysis",
+ needsFlowInfo = FlowInfoRequirement.ForwardFlow,
+ (HasSet(State.C) & HasOpcode(SEC) & Elidable) ~~> (_ => Nil),
+ (HasSet(State.D) & HasOpcode(SED) & Elidable) ~~> (_ => Nil),
+ (HasClear(State.C) & HasOpcode(CLC) & Elidable) ~~> (_ => Nil),
+ (HasClear(State.D) & HasOpcode(CLD) & Elidable) ~~> (_ => Nil),
+ (HasClear(State.V) & HasOpcode(CLV) & Elidable) ~~> (_ => Nil),
+ (HasSet(State.C) & HasOpcode(BCS) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
+ (HasClear(State.C) & HasOpcode(BCC) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
+ (HasSet(State.N) & HasOpcode(BMI) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
+ (HasClear(State.N) & HasOpcode(BPL) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
+ (HasClear(State.V) & HasOpcode(BVC) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
+ (HasSet(State.V) & HasOpcode(BVS) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
+ (HasSet(State.Z) & HasOpcode(BEQ) & Elidable) ~~> (c => c.map(_.copy(opcode = JMP, addrMode = Absolute))),
+ (HasClear(State.Z) & HasOpcode(BNE) & Elidable) ~~> (_ => Nil),
+ )
+
+ val ReverseFlowAnalysis = new RuleBasedAssemblyOptimization("Reverse flow analysis",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ (Elidable & HasOpcodeIn(Set(TXA, TYA, LDA, EOR, ORA, AND)) & DoesntMatterWhatItDoesWith(State.A, State.N, State.Z)) ~~> (_ => Nil),
+ (Elidable & HasOpcode(ANC) & DoesntMatterWhatItDoesWith(State.A, State.C, State.N, State.Z)) ~~> (_ => Nil),
+ (Elidable & HasOpcodeIn(Set(TAX, TSX, LDX, INX, DEX)) & DoesntMatterWhatItDoesWith(State.X, State.N, State.Z)) ~~> (_ => Nil),
+ (Elidable & HasOpcodeIn(Set(TAY, LDY, DEY, INY)) & DoesntMatterWhatItDoesWith(State.Y, State.N, State.Z)) ~~> (_ => Nil),
+ (Elidable & HasOpcodeIn(Set(LAX)) & DoesntMatterWhatItDoesWith(State.A, State.X, State.N, State.Z)) ~~> (_ => Nil),
+ (Elidable & HasOpcodeIn(Set(SEC, CLC)) & DoesntMatterWhatItDoesWith(State.C)) ~~> (_ => Nil),
+ (Elidable & HasOpcodeIn(Set(CLD, SED)) & DoesntMatterWhatItDoesWith(State.D)) ~~> (_ => Nil),
+ (Elidable & HasOpcode(CLV) & DoesntMatterWhatItDoesWith(State.V)) ~~> (_ => Nil),
+ (Elidable & HasOpcodeIn(Set(CMP, CPX, CPY)) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Z)) ~~> (_ => Nil),
+ (Elidable & HasOpcodeIn(Set(BIT)) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Z, State.V)) ~~> (_ => Nil),
+ (Elidable & HasOpcodeIn(Set(ASL, LSR, ROL, ROR)) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.A, State.C, State.N, State.Z)) ~~> (_ => Nil),
+ (Elidable & HasOpcodeIn(Set(ADC, SBC)) & DoesntMatterWhatItDoesWith(State.A, State.C, State.V, State.N, State.Z)) ~~> (_ => Nil),
+ )
+
+ private def modificationOfJustWrittenValue(store: Opcode.Value,
+ addrMode: AddrMode.Value,
+ initExtra: AssemblyLinePattern,
+ modify: Opcode.Value,
+ meantimeExtra: AssemblyLinePattern,
+ atLeastTwo: Boolean,
+ flagsToTrash: Seq[State.Value],
+ fix: ((AssemblyMatchingContext, Int) => List[AssemblyLine]),
+ alternateStore: Opcode.Value = LABEL) = {
+ val actualFlagsToTrash = List(State.N, State.Z) ++ flagsToTrash
+ val init = Elidable & HasOpcode(store) & HasAddrMode(addrMode) & MatchAddrMode(3) & MatchParameter(0) & DoesntMatterWhatItDoesWith(actualFlagsToTrash: _*) & initExtra
+ val meantime = (Linear & Not(ConcernsMemory) & meantimeExtra).*
+ val oneModification = Elidable & HasOpcode(modify) & HasAddrMode(addrMode) & MatchParameter(0) & DoesntMatterWhatItDoesWith(actualFlagsToTrash: _*)
+ val modifications = (if (atLeastTwo) oneModification ~ oneModification.+ else oneModification.+).captureLength(1)
+ if (alternateStore == LABEL) {
+ ((init ~ meantime).capture(2) ~ modifications) ~~> ((code, ctx) => fix(ctx, ctx.get[Int](1)) ++ ctx.get[List[AssemblyLine]](2))
+ } else {
+ (init.capture(3) ~ meantime.capture(2) ~ modifications) ~~> { (code, ctx) =>
+ fix(ctx, ctx.get[Int](1)) ++
+ List(AssemblyLine(alternateStore, ctx.get[AddrMode.Value](3), ctx.get[Constant](0))) ++
+ ctx.get[List[AssemblyLine]](2)
+ }
+ }
+ }
+
+ val ModificationOfJustWrittenValue = new RuleBasedAssemblyOptimization("Modification of Just written value",
+ needsFlowInfo = FlowInfoRequirement.ForwardFlow,
+ modificationOfJustWrittenValue(STA, Absolute, MatchA(5), INC, Anything, atLeastTwo = false, Seq(), (c, i) => List(
+ AssemblyLine.immediate(LDA, (c.get[Int](5) + i) & 0xff)
+ )),
+ modificationOfJustWrittenValue(STA, Absolute, MatchA(5), DEC, Anything, atLeastTwo = false, Seq(), (c, i) => List(
+ AssemblyLine.immediate(LDA, (c.get[Int](5) - i) & 0xff)
+ )),
+ modificationOfJustWrittenValue(STA, ZeroPage, MatchA(5), INC, Anything, atLeastTwo = false, Seq(), (c, i) => List(
+ AssemblyLine.immediate(LDA, (c.get[Int](5) + i) & 0xff)
+ )),
+ modificationOfJustWrittenValue(STA, ZeroPage, MatchA(5), DEC, Anything, atLeastTwo = false, Seq(), (c, i) => List(
+ AssemblyLine.immediate(LDA, (c.get[Int](5) - i) & 0xff)
+ )),
+ modificationOfJustWrittenValue(STA, AbsoluteX, MatchA(5), INC, Not(ChangesX), atLeastTwo = false, Seq(), (c, i) => List(
+ AssemblyLine.immediate(LDA, (c.get[Int](5) + i) & 0xff)
+ )),
+ modificationOfJustWrittenValue(STA, AbsoluteX, MatchA(5), DEC, Not(ChangesX), atLeastTwo = false, Seq(), (c, i) => List(
+ AssemblyLine.immediate(LDA, (c.get[Int](5) - i) & 0xff)
+ )),
+ modificationOfJustWrittenValue(STA, Absolute, Anything, INC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
+ AssemblyLine.implied(CLC),
+ AssemblyLine.immediate(ADC, i)
+ )),
+ modificationOfJustWrittenValue(STA, Absolute, Anything, DEC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
+ AssemblyLine.implied(SEC),
+ AssemblyLine.immediate(SBC, i)
+ )),
+ modificationOfJustWrittenValue(STA, ZeroPage, Anything, INC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
+ AssemblyLine.implied(CLC),
+ AssemblyLine.immediate(ADC, i)
+ )),
+ modificationOfJustWrittenValue(STA, ZeroPage, Anything, DEC, Anything, atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
+ AssemblyLine.implied(SEC),
+ AssemblyLine.immediate(SBC, i)
+ )),
+ modificationOfJustWrittenValue(STA, AbsoluteX, Anything, INC, Not(ChangesX), atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
+ AssemblyLine.implied(CLC),
+ AssemblyLine.immediate(ADC, i)
+ )),
+ modificationOfJustWrittenValue(STA, AbsoluteX, Anything, DEC, Not(ChangesX), atLeastTwo = true, Seq(State.C, State.V), (_, i) => List(
+ AssemblyLine.implied(SEC),
+ AssemblyLine.immediate(SBC, i)
+ )),
+ modificationOfJustWrittenValue(STA, Absolute, Anything, ASL, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(ASL))),
+ modificationOfJustWrittenValue(STA, Absolute, Anything, LSR, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(LSR))),
+ modificationOfJustWrittenValue(STA, ZeroPage, Anything, ASL, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(ASL))),
+ modificationOfJustWrittenValue(STA, ZeroPage, Anything, LSR, Anything, atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(LSR))),
+ modificationOfJustWrittenValue(STA, AbsoluteX, Anything, ASL, Not(ChangesX), atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(ASL))),
+ modificationOfJustWrittenValue(STA, AbsoluteX, Anything, LSR, Not(ChangesX), atLeastTwo = false, Seq(State.C), (_, i) => List.fill(i)(AssemblyLine.implied(LSR))),
+ modificationOfJustWrittenValue(STX, Absolute, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INX))),
+ modificationOfJustWrittenValue(STX, Absolute, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEX))),
+ modificationOfJustWrittenValue(STY, Absolute, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INY))),
+ modificationOfJustWrittenValue(STY, Absolute, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEY))),
+ modificationOfJustWrittenValue(STZ, Absolute, Anything, ASL, Anything, atLeastTwo = false, Seq(), (_, i) => Nil),
+ modificationOfJustWrittenValue(STZ, Absolute, Anything, LSR, Anything, atLeastTwo = false, Seq(), (_, i) => Nil),
+ modificationOfJustWrittenValue(STZ, Absolute, Anything, INC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, i)), STA),
+ modificationOfJustWrittenValue(STZ, Absolute, Anything, DEC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, 256 - i)), STA),
+ modificationOfJustWrittenValue(STX, ZeroPage, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INX))),
+ modificationOfJustWrittenValue(STX, ZeroPage, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEX))),
+ modificationOfJustWrittenValue(STY, ZeroPage, Anything, INC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(INY))),
+ modificationOfJustWrittenValue(STY, ZeroPage, Anything, DEC, Anything, atLeastTwo = false, Seq(), (_, i) => List.fill(i)(AssemblyLine.implied(DEY))),
+ modificationOfJustWrittenValue(STZ, ZeroPage, Anything, ASL, Anything, atLeastTwo = false, Seq(), (_, i) => Nil),
+ modificationOfJustWrittenValue(STZ, ZeroPage, Anything, LSR, Anything, atLeastTwo = false, Seq(), (_, i) => Nil),
+ modificationOfJustWrittenValue(STZ, ZeroPage, Anything, INC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, i)), STA),
+ modificationOfJustWrittenValue(STZ, ZeroPage, Anything, DEC, Anything, atLeastTwo = false, Seq(State.A), (_, i) => List(AssemblyLine.immediate(LDA, 256 - i)), STA),
+ )
+
+ val ConstantFlowAnalysis = new RuleBasedAssemblyOptimization("Constant flow analysis",
+ needsFlowInfo = FlowInfoRequirement.ForwardFlow,
+ (MatchX(0) & HasAddrMode(AbsoluteX) & SupportsAbsolute & Elidable) ~~> { (code, ctx) =>
+ code.map(l => l.copy(addrMode = Absolute, parameter = l.parameter + ctx.get[Int](0)))
+ },
+ (MatchY(0) & HasAddrMode(AbsoluteY) & SupportsAbsolute & Elidable) ~~> { (code, ctx) =>
+ code.map(l => l.copy(addrMode = Absolute, parameter = l.parameter + ctx.get[Int](0)))
+ },
+ (MatchX(0) & HasAddrMode(ZeroPageX) & Elidable) ~~> { (code, ctx) =>
+ code.map(l => l.copy(addrMode = ZeroPage, parameter = l.parameter + ctx.get[Int](0)))
+ },
+ (MatchY(0) & HasAddrMode(ZeroPageY) & Elidable) ~~> { (code, ctx) =>
+ code.map(l => l.copy(addrMode = ZeroPage, parameter = l.parameter + ctx.get[Int](0)))
+ },
+ )
+
+ val IdempotentDuplicateRemoval = new RuleBasedAssemblyOptimization("Idempotent duplicate operation",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ HasOpcode(RTS) ~ HasOpcodeIn(NoopDiscardsFlags).* ~ (HasOpcode(RTS) ~ Elidable) ~~> (_.take(1)) ::
+ HasOpcode(RTI) ~ HasOpcodeIn(NoopDiscardsFlags).* ~ (HasOpcode(RTI) ~ Elidable) ~~> (_.take(1)) ::
+ HasOpcode(DISCARD_XF) ~ (Not(HasOpcode(DISCARD_XF)) & HasOpcodeIn(NoopDiscardsFlags + LABEL)).* ~ HasOpcode(DISCARD_XF) ~~> (_.tail) ::
+ HasOpcode(DISCARD_AF) ~ (Not(HasOpcode(DISCARD_AF)) & HasOpcodeIn(NoopDiscardsFlags + LABEL)).* ~ HasOpcode(DISCARD_AF) ~~> (_.tail) ::
+ HasOpcode(DISCARD_YF) ~ (Not(HasOpcode(DISCARD_YF)) & HasOpcodeIn(NoopDiscardsFlags + LABEL)).* ~ HasOpcode(DISCARD_YF) ~~> (_.tail) ::
+ List(RTS, RTI, SEC, CLC, CLV, CLD, SED, SEI, CLI, TAX, TXA, TYA, TAY, TXS, TSX).flatMap { opcode =>
+ Seq(
+ (HasOpcode(opcode) & Elidable) ~ (HasOpcodeIn(NoopDiscardsFlags) | HasOpcode(LABEL)).* ~ HasOpcode(opcode) ~~> (_.tail),
+ HasOpcode(opcode) ~ (HasOpcode(opcode) ~ Elidable) ~~> (_.init),
+ )
+ }: _*
+ )
+
+ val PointlessRegisterTransfers = new RuleBasedAssemblyOptimization("Pointless register transfers",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ HasOpcode(TYA) ~ (Elidable & Set(TYA, TAY)) ~~> (_.init),
+ HasOpcode(TXA) ~ (Elidable & Set(TXA, TAX)) ~~> (_.init),
+ HasOpcode(TAY) ~ (Elidable & Set(TYA, TAY)) ~~> (_.init),
+ HasOpcode(TAX) ~ (Elidable & Set(TXA, TAX)) ~~> (_.init),
+ HasOpcode(TSX) ~ (Elidable & Set(TXS, TSX)) ~~> (_.init),
+ HasOpcode(TXS) ~ (Elidable & Set(TXS, TSX)) ~~> (_.init),
+ HasOpcode(TSX) ~ (Not(ChangesX) & Not(ChangesS) & Linear).* ~ (Elidable & Set(TXS, TSX)) ~~> (_.init),
+ HasOpcode(TXS) ~ (Not(ChangesX) & Not(ChangesS) & Linear).* ~ (Elidable & Set(TXS, TSX)) ~~> (_.init),
+ )
+
+ val PointlessRegisterTransfersBeforeStore = new RuleBasedAssemblyOptimization("Pointless register transfers before store",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ (Elidable & HasOpcode(TXA)) ~
+ (Linear & Not(ConcernsA) & Not(ConcernsX)).* ~
+ (Elidable & HasOpcode(STA) & HasAddrModeIn(Set(ZeroPage, ZeroPageY, Absolute)) & DoesntMatterWhatItDoesWith(State.A, State.N, State.Z)) ~~> (code => code.tail.init :+ code.last.copy(opcode = STX)),
+ (Elidable & HasOpcode(TYA)) ~
+ (Linear & Not(ConcernsA) & Not(ConcernsY)).* ~
+ (Elidable & HasOpcode(STA) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute)) & DoesntMatterWhatItDoesWith(State.A, State.N, State.Z)) ~~> (code => code.tail.init :+ code.last.copy(opcode = STY)),
+ )
+
+
+ val PointlessRegisterTransfersBeforeReturn = new RuleBasedAssemblyOptimization("Pointless register transfers before return",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (HasOpcode(TAX) & Elidable) ~
+ HasOpcode(LABEL).* ~
+ HasOpcode(TXA).? ~
+ ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_XF)).capture(1) ~
+ HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
+ (HasOpcode(TSX) & Elidable) ~
+ HasOpcode(LABEL).* ~
+ HasOpcode(TSX).? ~
+ ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_XF)).capture(1) ~
+ HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
+ (HasOpcode(TXA) & Elidable) ~
+ HasOpcode(LABEL).* ~
+ HasOpcode(TAX).? ~
+ ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_AF)).capture(1) ~
+ HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
+ (HasOpcode(TAY) & Elidable) ~
+ HasOpcode(LABEL).* ~
+ HasOpcode(TYA).? ~
+ ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_YF)).capture(1) ~
+ HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
+ (HasOpcode(TYA) & Elidable) ~
+ HasOpcode(LABEL).* ~
+ HasOpcode(TAY).? ~
+ ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(DISCARD_AF)).capture(1) ~
+ HasOpcode(RTS) ~~> ((code, ctx) => ctx.get[List[AssemblyLine]](1) ++ (AssemblyLine.implied(RTS) :: code.tail)),
+ )
+
+ val PointlessRegisterTransfersBeforeCompare = new RuleBasedAssemblyOptimization("Pointless register transfers before compare",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ HasOpcodeIn(Set(DEX, INX, LDX, LAX)) ~
+ (HasOpcode(TXA) & Elidable & DoesntMatterWhatItDoesWith(State.A)) ~~> (code => code.init),
+ HasOpcodeIn(Set(DEY, INY, LDY)) ~
+ (HasOpcode(TYA) & Elidable & DoesntMatterWhatItDoesWith(State.A)) ~~> (code => code.init),
+ )
+
+ private def stashing(tai: Opcode.Value, tia: Opcode.Value, readsI: AssemblyLinePattern, concernsI: AssemblyLinePattern, discardIF: Opcode.Value, withRts: Boolean, withBeq: Boolean) = {
+ val init: AssemblyPattern = if (withBeq) {
+ (Linear & ChangesNAndZ & ChangesA) ~
+ (HasOpcode(tai) & Elidable) ~
+ (Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* ~
+ (ShortBranching & ReadsNOrZ & MatchParameter(0))
+ } else {
+ (HasOpcode(tai) & Elidable) ~
+ (Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* ~
+ ((ShortBranching -- ReadsNOrZ) & MatchParameter(0))
+ }
+ val inner: AssemblyPattern = if (withRts) {
+ (Linear & Not(readsI) & Not(ReadsNOrZ ++ NoopDiscardsFlags)).* ~
+ ManyWhereAtLeastOne(HasOpcodeIn(NoopDiscardsFlags), HasOpcode(discardIF)) ~
+ HasOpcodeIn(Set(RTS, RTI)) ~
+ Not(HasOpcode(LABEL)).*
+ } else {
+ (Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).*
+ }
+ val end: AssemblyPattern =
+ (HasOpcode(LABEL) & MatchParameter(0)) ~
+ (Linear & Not(concernsI) & Not(ChangesA) & Not(ReadsNOrZ)).* ~
+ (HasOpcode(tia) & Elidable)
+ val total = init ~ inner ~ end
+ if (withBeq) {
+ total ~~> (code => code.head :: (code.tail.tail.init :+ AssemblyLine.implied(tai)))
+ } else {
+ total ~~> (code => code.tail.init :+ AssemblyLine.implied(tai))
+ }
+ }
+
+ // Optimize the following patterns:
+ // TAX - B__ .a - don't change A - .a - TXA
+ // TAX - B__ .a - change A – discard X – RTS - .a - TXA
+ // by removing the first transfer and flipping the second one
+ val PointlessStashingToIndexOverShortSafeBranch = new RuleBasedAssemblyOptimization("Pointless stashing into index over short safe branch",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ // stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = false, withBeq = false),
+ stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = true, withBeq = false),
+ // stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = false, withBeq = true),
+ // stashing(TAX, TXA, ReadsX, ConcernsX, DISCARD_XF, withRts = true, withBeq = true),
+ //
+ // stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = false, withBeq = false),
+ // stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = true, withBeq = false),
+ // stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = false, withBeq = true),
+ // stashing(TAY, TYA, ReadsY, ConcernsY, DISCARD_YF, withRts = true, withBeq = true),
+ )
+
+ private def loadBeforeTransfer(ld1: Opcode.Value, ld2: Opcode.Value, concerns1: AssemblyLinePattern, overwrites1: State.Value, t12: Opcode.Value, ams: Set[AddrMode.Value]) =
+ (Elidable & HasOpcode(ld1) & MatchAddrMode(0) & MatchParameter(1) & HasAddrModeIn(ams)) ~
+ (Linear & Not(ReadsNOrZ) & Not(concerns1) & DoesntChangeMemoryAt(0, 1) & DoesntChangeIndexingInAddrMode(0) & Not(HasOpcode(t12))).*.capture(2) ~
+ (HasOpcode(t12) & Elidable & DoesntMatterWhatItDoesWith(overwrites1, State.N, State.Z)) ~~> { (code, ctx) =>
+ ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = ld2)
+ }
+
+ val PointlessLoadBeforeTransfer = new RuleBasedAssemblyOptimization("Pointless load before transfer",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ loadBeforeTransfer(LDX, LDA, ConcernsX, State.X, TXA, Set(ZeroPage, Absolute, IndexedY, AbsoluteY)),
+ loadBeforeTransfer(LDA, LDX, ConcernsA, State.A, TAX, Set(ZeroPage, Absolute, IndexedY, AbsoluteY)),
+ loadBeforeTransfer(LDY, LDA, ConcernsY, State.Y, TYA, Set(ZeroPage, Absolute, ZeroPageX, IndexedX, AbsoluteX)),
+ loadBeforeTransfer(LDA, LDY, ConcernsA, State.A, TAY, Set(ZeroPage, Absolute, ZeroPageX, IndexedX, AbsoluteX)),
+ )
+
+ private def immediateLoadBeforeTwoTransfers(ld1: Opcode.Value, ld2: Opcode.Value, concerns1: AssemblyLinePattern, overwrites1: State.Value, t12: Opcode.Value, t21: Opcode.Value) =
+ (Elidable & HasOpcode(ld1) & HasAddrMode(Immediate)) ~
+ (Linear & Not(ReadsNOrZ) & Not(concerns1) & Not(HasOpcode(t12))).*.capture(2) ~
+ (HasOpcode(t12) & Elidable & DoesntMatterWhatItDoesWith(overwrites1, State.N, State.Z)) ~~> { (code, ctx) =>
+ ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = ld2)
+ }
+
+ val YYY = new RuleBasedAssemblyOptimization("Pointless load before transfer",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ immediateLoadBeforeTwoTransfers(LDA, LDY, ConcernsA, State.A, TAY, TYA),
+ )
+
+ val ConstantIndexPropagation = new RuleBasedAssemblyOptimization("Constant index propagation",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (HasOpcode(LDX) & HasAddrMode(Immediate) & MatchParameter(0)) ~
+ (Linear & Not(ChangesX) & Not(HasAddrMode(AbsoluteX))).* ~
+ (Elidable & SupportsAbsolute & HasAddrMode(AbsoluteX)) ~~> { (lines, ctx) =>
+ val last = lines.last
+ val offset = ctx.get[Constant](0)
+ lines.init :+ last.copy(addrMode = Absolute, parameter = last.parameter + offset)
+ },
+ (HasOpcode(LDY) & HasAddrMode(Immediate) & MatchParameter(0)) ~
+ (Linear & Not(ChangesY) & Not(HasAddrMode(AbsoluteY))).* ~
+ (Elidable & SupportsAbsolute & HasAddrMode(AbsoluteY)) ~~> { (lines, ctx) =>
+ val last = lines.last
+ val offset = ctx.get[Constant](0)
+ lines.init :+ last.copy(addrMode = Absolute, parameter = last.parameter + offset)
+ },
+ )
+
+ val PoinlessLoadBeforeAnotherLoad = new RuleBasedAssemblyOptimization("Pointless load before another load",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Set(LDA, TXA, TYA) & Elidable) ~ (LinearOrLabel & Not(ConcernsA) & Not(ReadsNOrZ)).* ~ OverwritesA ~~> (_.tail),
+ (Set(LDX, TAX, TSX) & Elidable) ~ (LinearOrLabel & Not(ConcernsX) & Not(ReadsNOrZ)).* ~ OverwritesX ~~> (_.tail),
+ (Set(LDY, TAY) & Elidable) ~ (LinearOrLabel & Not(ConcernsY) & Not(ReadsNOrZ)).* ~ OverwritesY ~~> (_.tail),
+ )
+
+ // TODO: better proofs that memory doesn't change
+ val PointlessLoadAfterLoadOrStore = new RuleBasedAssemblyOptimization("Pointless load after load or store",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+
+ (HasOpcodeIn(Set(LDA, STA)) & HasAddrMode(Implied) & MatchParameter(1)) ~
+ (Linear & Not(ChangesA)).* ~
+ (Elidable & HasOpcode(LDA) & HasAddrMode(Implied) & MatchParameter(1)) ~~> (_.init),
+
+ (HasOpcodeIn(Set(LDX, STX)) & HasAddrMode(Implied) & MatchParameter(1)) ~
+ (Linear & Not(ChangesX)).* ~
+ (Elidable & HasOpcode(LDX) & HasAddrMode(Implied) & MatchParameter(1)) ~~> (_.init),
+
+ (HasOpcodeIn(Set(LDY, STY)) & HasAddrMode(Implied) & MatchParameter(1)) ~
+ (Linear & Not(ChangesY)).* ~
+ (Elidable & HasOpcode(LDY) & HasAddrMode(Implied) & MatchParameter(1)) ~~> (_.init),
+
+ (HasOpcodeIn(Set(LDA, STA)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & Not(ChangesA) & DoesntChangeIndexingInAddrMode(0) & DoesntChangeMemoryAt(0, 1)).* ~
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
+
+ (HasOpcodeIn(Set(LDX, STX)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & Not(ChangesX) & DoesntChangeIndexingInAddrMode(0) & DoesntChangeMemoryAt(0, 1)).* ~
+ (Elidable & HasOpcode(LDX) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
+
+ (HasOpcodeIn(Set(LDY, STY)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & Not(ChangesY) & DoesntChangeIndexingInAddrMode(0) & DoesntChangeMemoryAt(0, 1)).* ~
+ (Elidable & HasOpcode(LDY) & MatchAddrMode(0) & MatchParameter(1)) ~~> (_.init),
+ )
+
+ val PointlessOperationAfterLoad = new RuleBasedAssemblyOptimization("Pointless operation after load",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (ChangesA & ChangesNAndZ) ~ (Elidable & HasOpcode(EOR) & HasImmediate(0)) ~~> (_.init),
+ (ChangesA & ChangesNAndZ) ~ (Elidable & HasOpcode(ORA) & HasImmediate(0)) ~~> (_.init),
+ (ChangesA & ChangesNAndZ) ~ (Elidable & HasOpcode(AND) & HasImmediate(0xff)) ~~> (_.init)
+ )
+
+ val SimplifiableBitOpsSequence = new RuleBasedAssemblyOptimization("Simplifiable sequence of bit operations",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasOpcode(EOR) & MatchImmediate(0)) ~
+ (Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsA)).* ~
+ (Elidable & HasOpcode(EOR) & MatchImmediate(1)) ~~> { (lines, ctx) =>
+ lines.init.tail :+ AssemblyLine.immediate(EOR, CompoundConstant(MathOperator.Exor, ctx.get[Constant](0), ctx.get[Constant](1)))
+ },
+ (Elidable & HasOpcode(ORA) & MatchImmediate(0)) ~
+ (Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsA)).* ~
+ (Elidable & HasOpcode(ORA) & MatchImmediate(1)) ~~> { (lines, ctx) =>
+ lines.init.tail :+ AssemblyLine.immediate(ORA, CompoundConstant(MathOperator.Or, ctx.get[Constant](0), ctx.get[Constant](1)))
+ },
+ (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
+ (Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsA)).* ~
+ (Elidable & HasOpcode(AND) & MatchImmediate(1)) ~~> { (lines, ctx) =>
+ lines.init.tail :+ AssemblyLine.immediate(AND, CompoundConstant(MathOperator.And, ctx.get[Constant](0), ctx.get[Constant](1)))
+ },
+ (Elidable & HasOpcode(ANC) & MatchImmediate(0)) ~
+ (Linear & Not(ChangesA) & Not(ReadsNOrZ) & Not(ReadsC) & Not(ReadsA)).* ~
+ (Elidable & HasOpcode(ANC) & MatchImmediate(1)) ~~> { (lines, ctx) =>
+ lines.init.tail :+ AssemblyLine.immediate(ANC, CompoundConstant(MathOperator.And, ctx.get[Constant](0), ctx.get[Constant](1)))
+ },
+ )
+
+ val RemoveNops = new RuleBasedAssemblyOptimization("Removing NOP instructions",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasOpcode(NOP)) ~~> (_ => Nil)
+ )
+
+ val RearrangeMath = new RuleBasedAssemblyOptimization("Rearranging math",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasOpcode(LDA) & HasAddrMode(Immediate)) ~
+ (Elidable & HasOpcodeIn(Set(CLC, SEC))) ~
+ (Elidable & HasOpcode(ADC) & Not(HasAddrMode(Immediate))) ~~> { c =>
+ c.last.copy(opcode = LDA) :: c(1) :: c.head.copy(opcode = ADC) :: Nil
+ },
+ (Elidable & HasOpcode(LDA) & HasAddrMode(Immediate)) ~
+ (Elidable & HasOpcodeIn(Set(ADC, EOR, ORA, AND)) & Not(HasAddrMode(Immediate))) ~~> { c =>
+ c.last.copy(opcode = LDA) :: c.head.copy(opcode = c.last.opcode) :: Nil
+ },
+ )
+
+ private def wordShifting(i: Int, hiFirst: Boolean, hiFromX: Boolean) = {
+ val ldax = if (hiFromX) LDX else LDA
+ val stax = if (hiFromX) STX else STA
+ val restriction = if (hiFromX) Not(ReadsX) else Anything
+ val originalStart = if (hiFirst) {
+ (Elidable & HasOpcode(LDA) & MatchParameter(0) & MatchAddrMode(1)) ~
+ (Elidable & HasOpcode(STA) & MatchParameter(2) & MatchAddrMode(3) & restriction) ~
+ (Elidable & HasOpcode(ldax) & HasImmediate(0)) ~
+ (Elidable & HasOpcode(stax) & MatchParameter(4) & MatchAddrMode(5))
+ } else {
+ (Elidable & HasOpcode(ldax) & HasImmediate(0)) ~
+ (Elidable & HasOpcode(stax) & MatchParameter(4) & MatchAddrMode(5)) ~
+ (Elidable & HasOpcode(LDA) & MatchParameter(0) & MatchAddrMode(1)) ~
+ (Elidable & HasOpcode(STA) & MatchParameter(2) & MatchAddrMode(3) & restriction)
+ }
+ val middle = (Linear & Not(ConcernsMemory) & DoesntChangeIndexingInAddrMode(3) & DoesntChangeIndexingInAddrMode(5)).*
+ val singleOriginalShift =
+ (Elidable & HasOpcode(ASL) & MatchParameter(2) & MatchAddrMode(3)) ~
+ (Elidable & HasOpcode(ROL) & MatchParameter(4) & MatchAddrMode(5) & DoesntMatterWhatItDoesWith(State.C, State.N, State.V, State.Z))
+ val originalShifting = (1 to i).map(_ => singleOriginalShift).reduce(_ ~ _)
+ originalStart ~ middle.capture(6) ~ originalShifting ~~> { (code, ctx) =>
+ val newStart = List(
+ code(0),
+ code(1).copy(addrMode = code(3).addrMode, parameter = code(3).parameter),
+ code(2),
+ code(3).copy(addrMode = code(1).addrMode, parameter = code(1).parameter))
+ val middle = ctx.get[List[AssemblyLine]](6)
+ val singleNewShift = List(
+ AssemblyLine(LSR, ctx.get[AddrMode.Value](5), ctx.get[Constant](4)),
+ AssemblyLine(ROR, ctx.get[AddrMode.Value](3), ctx.get[Constant](2)))
+ newStart ++ middle ++ (i until 8).flatMap(_ => singleNewShift)
+ }
+ }
+
+ val SmarterShiftingWords = new RuleBasedAssemblyOptimization("Smarter shifting of words",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ wordShifting(8, hiFirst = false, hiFromX = true),
+ wordShifting(8, hiFirst = false, hiFromX = false),
+ wordShifting(8, hiFirst = true, hiFromX = true),
+ wordShifting(8, hiFirst = true, hiFromX = false),
+ wordShifting(7, hiFirst = false, hiFromX = true),
+ wordShifting(7, hiFirst = false, hiFromX = false),
+ wordShifting(7, hiFirst = true, hiFromX = true),
+ wordShifting(7, hiFirst = true, hiFromX = false),
+ wordShifting(6, hiFirst = false, hiFromX = true),
+ wordShifting(6, hiFirst = false, hiFromX = false),
+ wordShifting(6, hiFirst = true, hiFromX = true),
+ wordShifting(6, hiFirst = true, hiFromX = false),
+ wordShifting(5, hiFirst = false, hiFromX = true),
+ wordShifting(5, hiFirst = false, hiFromX = false),
+ wordShifting(5, hiFirst = true, hiFromX = true),
+ wordShifting(5, hiFirst = true, hiFromX = false),
+ )
+
+ private def carryFlagConversionCase(shift: Int, firstSet: Boolean, zeroIfSet: Boolean) = {
+ val nonZero = 1 << shift
+ val test = Elidable & HasOpcode(if (firstSet) BCC else BCS) & MatchParameter(0)
+ val ifSet = Elidable & HasOpcode(LDA) & HasImmediate(if (zeroIfSet) 0 else nonZero)
+ val ifClear = Elidable & HasOpcode(LDA) & HasImmediate(if (zeroIfSet) nonZero else 0)
+ val jump = Elidable & HasOpcodeIn(Set(JMP, if (firstSet) BCS else BCC, if (zeroIfSet) BEQ else BNE)) & MatchParameter(1)
+ val elseLabel = Elidable & HasOpcode(LABEL) & MatchParameter(0)
+ val afterLabel = Elidable & HasOpcode(LABEL) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.C, State.N, State.V, State.Z)
+ val store = Elidable & (Not(ReadsC) & Linear | HasOpcodeIn(Set(RTS, JSR, RTI)))
+ val secondReturn = (Elidable & HasOpcodeIn(Set(RTS, RTI) | NoopDiscardsFlags)).*.capture(6)
+ val where = Where { ctx =>
+ ctx.get[List[AssemblyLine]](4) == ctx.get[List[AssemblyLine]](5) ||
+ ctx.get[List[AssemblyLine]](4) == ctx.get[List[AssemblyLine]](5) ++ ctx.get[List[AssemblyLine]](6)
+ }
+ val pattern =
+ if (firstSet) test ~ ifSet ~ store.*.capture(4) ~ jump ~ elseLabel ~ ifClear ~ store.*.capture(5) ~ afterLabel ~ secondReturn ~ where
+ else test ~ ifClear ~ store.*.capture(4) ~ jump ~ elseLabel ~ ifSet ~ store.*.capture(5) ~ afterLabel ~ secondReturn ~ where
+ pattern ~~> { (_, ctx) =>
+ List(
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.implied(if (shift >= 4) ROR else ROL)) ++
+ (if (shift >= 4) List.fill(7 - shift)(AssemblyLine.implied(LSR)) else List.fill(shift)(AssemblyLine.implied(ASL))) ++
+ (if (zeroIfSet) List(AssemblyLine.immediate(EOR, nonZero)) else Nil) ++
+ ctx.get[List[AssemblyLine]](5) ++
+ ctx.get[List[AssemblyLine]](6)
+ }
+ }
+
+ val CarryFlagConversion = new RuleBasedAssemblyOptimization("Carry flag conversion",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ // TODO: These yield 2 cycles more but 1–2 bytes less
+ // TODO: Add an "optimize for size" compilation option?
+ // carryFlagConversionCase(2, firstSet = false, zeroIfSet = false),
+ // carryFlagConversionCase(2, firstSet = true, zeroIfSet = false),
+ // carryFlagConversionCase(1, firstSet = true, zeroIfSet = true),
+ // carryFlagConversionCase(1, firstSet = false, zeroIfSet = true),
+ carryFlagConversionCase(1, firstSet = false, zeroIfSet = false),
+ carryFlagConversionCase(1, firstSet = true, zeroIfSet = false),
+ carryFlagConversionCase(0, firstSet = true, zeroIfSet = true),
+ carryFlagConversionCase(0, firstSet = false, zeroIfSet = true),
+ carryFlagConversionCase(0, firstSet = false, zeroIfSet = false),
+ carryFlagConversionCase(0, firstSet = true, zeroIfSet = false),
+ // carryFlagConversionCase(5, firstSet = false, zeroIfSet = false),
+ // carryFlagConversionCase(5, firstSet = true, zeroIfSet = false),
+ // carryFlagConversionCase(6, firstSet = true, zeroIfSet = true),
+ // carryFlagConversionCase(6, firstSet = false, zeroIfSet = true),
+ carryFlagConversionCase(6, firstSet = false, zeroIfSet = false),
+ carryFlagConversionCase(6, firstSet = true, zeroIfSet = false),
+ carryFlagConversionCase(7, firstSet = true, zeroIfSet = true),
+ carryFlagConversionCase(7, firstSet = false, zeroIfSet = true),
+ carryFlagConversionCase(7, firstSet = false, zeroIfSet = false),
+ carryFlagConversionCase(7, firstSet = true, zeroIfSet = false),
+ )
+
+ val Adc0Optimization = new RuleBasedAssemblyOptimization("ADC #0/#1 optimization",
+ needsFlowInfo = FlowInfoRequirement.BothFlows,
+ (Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.D)) ~
+ (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ val label = getNextLabel("ah")
+ List(
+ AssemblyLine.relative(BCC, label),
+ code.last.copy(opcode = INC),
+ AssemblyLine.label(label))
+ },
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~
+ (Elidable & HasOpcode(ADC) & HasImmediate(0) & HasClear(State.D)) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ val label = getNextLabel("ah")
+ List(
+ AssemblyLine.relative(BCC, label),
+ code.last.copy(opcode = INC),
+ AssemblyLine.label(label))
+ },
+ (Elidable & HasOpcode(LDA) & HasImmediate(1) & HasClear(State.D) & HasClear(State.C)) ~
+ (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(code.last.copy(opcode = INC))
+ },
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & HasClear(State.C) & MatchParameter(2) & HasAddrModeIn(Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX))) ~
+ (Elidable & HasOpcode(ADC) & HasImmediate(1) & HasClear(State.D)) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(code.last.copy(opcode = INC))
+ },
+ (Elidable & HasOpcode(TXA) & HasClear(State.D)) ~
+ (Elidable & HasOpcode(ADC) & HasImmediate(0)) ~
+ (Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ val label = getNextLabel("ah")
+ List(
+ AssemblyLine.relative(BCC, label),
+ AssemblyLine.implied(INX),
+ AssemblyLine.label(label))
+ },
+ (Elidable & HasOpcode(TYA) & HasClear(State.D)) ~
+ (Elidable & HasOpcode(ADC) & HasImmediate(0)) ~
+ (Elidable & HasOpcode(TAY) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ val label = getNextLabel("ah")
+ List(
+ AssemblyLine.relative(BCC, label),
+ AssemblyLine.implied(INY),
+ AssemblyLine.label(label))
+ },
+ (Elidable & HasOpcode(TXA) & HasClear(State.D) & HasClear(State.C)) ~
+ (Elidable & HasOpcode(ADC) & HasImmediate(1)) ~
+ (Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(AssemblyLine.implied(INX))
+ },
+ (Elidable & HasOpcode(TYA) & HasClear(State.D) & HasClear(State.C)) ~
+ (Elidable & HasOpcode(ADC) & HasImmediate(1)) ~
+ (Elidable & HasOpcode(TAY) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(AssemblyLine.implied(INY))
+ },
+ )
+
+ val IndexSequenceOptimization = new RuleBasedAssemblyOptimization("Index sequence optimization",
+ needsFlowInfo = FlowInfoRequirement.ForwardFlow,
+ (Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~
+ Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0))) ~~> (_ => Nil),
+ (Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~
+ Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)+1)) ~~> (_ => List(AssemblyLine.implied(INY))),
+ (Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~
+ Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)-1)) ~~> (_ => List(AssemblyLine.implied(DEY))),
+ (Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~
+ Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0))) ~~> (_ => Nil),
+ (Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~
+ Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)+1)) ~~> (_ => List(AssemblyLine.implied(INX))),
+ (Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~
+ Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)-1)) ~~> (_ => List(AssemblyLine.implied(DEX))),
+ )
+
+
+}
diff --git a/src/main/scala/millfork/assembly/opt/AssemblyOptimization.scala b/src/main/scala/millfork/assembly/opt/AssemblyOptimization.scala
new file mode 100644
index 00000000..5ab931fd
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/AssemblyOptimization.scala
@@ -0,0 +1,14 @@
+package millfork.assembly.opt
+
+import millfork.CompilationOptions
+import millfork.assembly.AssemblyLine
+import millfork.env.NormalFunction
+
+/**
+ * @author Karol Stasiak
+ */
+trait AssemblyOptimization {
+ def name: String
+
+ def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine]
+}
diff --git a/src/main/scala/millfork/assembly/opt/ChangeIndexRegisterOptimization.scala b/src/main/scala/millfork/assembly/opt/ChangeIndexRegisterOptimization.scala
new file mode 100644
index 00000000..d67f71c0
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/ChangeIndexRegisterOptimization.scala
@@ -0,0 +1,155 @@
+package millfork.assembly.opt
+
+import millfork.CompilationOptions
+import millfork.assembly.{AssemblyLine, OpcodeClasses}
+import millfork.env.NormalFunction
+import millfork.error.ErrorReporting
+
+/**
+ * @author Karol Stasiak
+ */
+
+object ChangeIndexRegisterOptimizationPreferringX2Y extends ChangeIndexRegisterOptimization(true)
+object ChangeIndexRegisterOptimizationPreferringY2X extends ChangeIndexRegisterOptimization(false)
+
+class ChangeIndexRegisterOptimization(preferX2Y: Boolean) extends AssemblyOptimization {
+
+ object IndexReg extends Enumeration {
+ val X, Y = Value
+ }
+
+ object IndexDirection extends Enumeration {
+ val X2Y, Y2X = Value
+ }
+
+ import IndexReg._
+ import IndexDirection._
+ import millfork.assembly.AddrMode._
+ import millfork.assembly.Opcode._
+
+ type IndexReg = IndexReg.Value
+ type IndexDirection = IndexDirection.Value
+
+ override def name = "Changing index registers"
+
+ override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
+ val usesIndex = code.exists(l =>
+ OpcodeClasses.ReadsXAlways(l.opcode) ||
+ OpcodeClasses.ReadsYAlways(l.opcode) ||
+ OpcodeClasses.ChangesX(l.opcode) ||
+ OpcodeClasses.ChangesY(l.opcode) ||
+ Set(AbsoluteX, AbsoluteY, ZeroPageY, ZeroPageX, IndexedX, IndexedY)(l.addrMode)
+ )
+ if (!usesIndex) {
+ return code
+ }
+ val canX2Y = f.returnType.size <= 1 && canOptimize(code, X2Y, None)
+ val canY2X = canOptimize(code, Y2X, None)
+ (canX2Y, canY2X) match {
+ case (false, false) => code
+ case (true, false) =>
+ ErrorReporting.debug("Changing index register from X to Y")
+ switchX2Y(code)
+ case (false, true) =>
+ ErrorReporting.debug("Changing index register from X to Y")
+ switchY2X(code)
+ case (true, true) =>
+ if (preferX2Y) {
+ ErrorReporting.debug("Changing index register from X to Y (arbitrarily)")
+ switchX2Y(code)
+ } else {
+ ErrorReporting.debug("Changing index register from Y to X (arbitrarily)")
+ switchY2X(code)
+ }
+ }
+ }
+
+ //noinspection OptionEqualsSome
+ private def canOptimize(code: List[AssemblyLine], dir: IndexDirection, loaded: Option[IndexReg]): Boolean = code match {
+ case AssemblyLine(_, AbsoluteY, _, _) :: xs if loaded != Some(Y) => false
+ case AssemblyLine(_, ZeroPageY, _, _) :: xs if loaded != Some(Y) => false
+ case AssemblyLine(_, IndexedX, _, _) :: xs if dir == X2Y || loaded != Some(Y) => false
+ case AssemblyLine(_, AbsoluteX, _, _) :: xs if loaded != Some(X) => false
+ case AssemblyLine(_, ZeroPageX, _, _) :: xs if loaded != Some(X) => false
+ case AssemblyLine(_, IndexedY, _, _) :: xs if dir == Y2X || loaded != Some(Y) => false
+
+ // using a wrong index register for one instruction is fine
+ case AssemblyLine(LDY | TAY, _, _, _) :: AssemblyLine(_, IndexedY, _, _) :: xs if dir == Y2X =>
+ canOptimize(xs, dir, None)
+ case AssemblyLine(LDX | TAX, _, _, _) :: AssemblyLine(_, IndexedX, _, _) :: xs if dir == X2Y =>
+ canOptimize(xs, dir, None)
+ case AssemblyLine(LDX | TAX, _, _, _) :: AssemblyLine(INC | DEC | ASL | ROL | ROR | LSR | STZ, AbsoluteX | ZeroPageX, _, _) :: xs if dir == X2Y =>
+ canOptimize(xs, dir, None)
+
+ case AssemblyLine(INC | DEC | ASL | ROL | ROR | LSR | STZ, AbsoluteX | ZeroPageX, _, _) :: xs if dir == X2Y => false
+
+ case AssemblyLine(LAX, _, _, _) :: xs => false
+ case AssemblyLine(JSR, _, _, _) :: xs => false // TODO
+ case AssemblyLine(JMP, _, _, _) :: xs => canOptimize(xs, dir, None)
+ case AssemblyLine(op, _, _, _) :: xs if OpcodeClasses.ShortBranching(op) => canOptimize(xs, dir, None)
+ case AssemblyLine(RTS, _, _, _) :: xs => canOptimize(xs, dir, None)
+ case AssemblyLine(LABEL, _, _, _) :: xs => canOptimize(xs, dir, None)
+ case AssemblyLine(DISCARD_XF, _, _, _) :: xs => canOptimize(xs, dir, loaded.filter(_ != X))
+ case AssemblyLine(DISCARD_YF, _, _, _) :: xs => canOptimize(xs, dir, loaded.filter(_ != Y))
+ case AssemblyLine(_, DoesNotExist, _, _) :: xs => canOptimize(xs, dir, loaded)
+
+ case AssemblyLine(TAX | LDX | PLX, _, _, e) :: xs =>
+ (e || dir == Y2X) && canOptimize(xs, dir, Some(X))
+ case AssemblyLine(TAY | LDY | PLY, _, _, e) :: xs =>
+ (e || dir == X2Y) && canOptimize(xs, dir, Some(Y))
+ case AssemblyLine(TXA | STX | PHX | CPX | INX | DEX, _, _, e) :: xs =>
+ (e || dir == Y2X) && loaded == Some(X) && canOptimize(xs, dir, Some(X))
+ case AssemblyLine(TYA | STY | PHY | CPY | INY | DEY, _, _, e) :: xs =>
+ (e || dir == X2Y) && loaded == Some(Y) && canOptimize(xs, dir, Some(Y))
+
+ case AssemblyLine(SAX | TXS | SBX, _, _, _) :: xs => dir == Y2X && loaded == Some(X) && canOptimize(xs, dir, Some(X))
+ case AssemblyLine(TSX, _, _, _) :: xs => dir == Y2X && loaded != Some(Y) && canOptimize(xs, dir, Some(X))
+
+ case _ :: xs => canOptimize(xs, dir, loaded)
+
+ case Nil => true
+ }
+
+ private def switchX2Y(code: List[AssemblyLine]): List[AssemblyLine] = code match {
+ case (a@AssemblyLine(LDX | TAX, _, _, _)) :: (b@AssemblyLine(INC | DEC | ASL | ROL | ROR | LSR | STZ, AbsoluteX | ZeroPageX, _, _)) :: xs => a :: b :: switchX2Y(xs)
+ case (a@AssemblyLine(LDX | TAX, _, _, _)) :: (b@AssemblyLine(_, IndexedX, _, _)) :: xs => a :: b :: switchX2Y(xs)
+ case (x@AssemblyLine(TAX, _, _, _)) :: xs => x.copy(opcode = TAY) :: switchX2Y(xs)
+ case (x@AssemblyLine(TXA, _, _, _)) :: xs => x.copy(opcode = TYA) :: switchX2Y(xs)
+ case (x@AssemblyLine(STX, _, _, _)) :: xs => x.copy(opcode = STY) :: switchX2Y(xs)
+ case (x@AssemblyLine(LDX, _, _, _)) :: xs => x.copy(opcode = LDY) :: switchX2Y(xs)
+ case (x@AssemblyLine(INX, _, _, _)) :: xs => x.copy(opcode = INY) :: switchX2Y(xs)
+ case (x@AssemblyLine(DEX, _, _, _)) :: xs => x.copy(opcode = DEY) :: switchX2Y(xs)
+ case (x@AssemblyLine(CPX, _, _, _)) :: xs => x.copy(opcode = CPY) :: switchX2Y(xs)
+
+ case AssemblyLine(LAX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected LAX")
+ case AssemblyLine(TXS, _, _, _) :: xs => ErrorReporting.fatal("Unexpected TXS")
+ case AssemblyLine(TSX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected TSX")
+ case AssemblyLine(SBX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected SBX")
+ case AssemblyLine(SAX, _, _, _) :: xs => ErrorReporting.fatal("Unexpected SAX")
+
+ case (x@AssemblyLine(_, AbsoluteX, _, _)) :: xs => x.copy(addrMode = AbsoluteY) :: switchX2Y(xs)
+ case (x@AssemblyLine(_, ZeroPageX, _, _)) :: xs => x.copy(addrMode = ZeroPageY) :: switchX2Y(xs)
+ case (x@AssemblyLine(_, IndexedX, _, _)) :: xs => ErrorReporting.fatal("Unexpected IndexedX")
+
+ case x::xs => x :: switchX2Y(xs)
+ case Nil => Nil
+ }
+
+ private def switchY2X(code: List[AssemblyLine]): List[AssemblyLine] = code match {
+ case AssemblyLine(LDY | TAY, _, _, _) :: AssemblyLine(_, IndexedY, _, _) :: xs => code.take(2) ++ switchY2X(xs)
+ case (x@AssemblyLine(TAY, _, _, _)) :: xs => x.copy(opcode = TAX) :: switchY2X(xs)
+ case (x@AssemblyLine(TYA, _, _, _)) :: xs => x.copy(opcode = TXA) :: switchY2X(xs)
+ case (x@AssemblyLine(STY, _, _, _)) :: xs => x.copy(opcode = STX) :: switchY2X(xs)
+ case (x@AssemblyLine(LDY, _, _, _)) :: xs => x.copy(opcode = LDX) :: switchY2X(xs)
+ case (x@AssemblyLine(INY, _, _, _)) :: xs => x.copy(opcode = INX) :: switchY2X(xs)
+ case (x@AssemblyLine(DEY, _, _, _)) :: xs => x.copy(opcode = DEX) :: switchY2X(xs)
+ case (x@AssemblyLine(CPY, _, _, _)) :: xs => x.copy(opcode = CPX) :: switchY2X(xs)
+
+ case (x@AssemblyLine(_, AbsoluteY, _, _)) :: xs => x.copy(addrMode = AbsoluteX) :: switchY2X(xs)
+ case (x@AssemblyLine(_, ZeroPageY, _, _)) :: xs => x.copy(addrMode = ZeroPageX) :: switchY2X(xs)
+ case AssemblyLine(_, IndexedY, _, _) :: xs => ErrorReporting.fatal("Unexpected IndexedY")
+
+ case x::xs => x :: switchY2X(xs)
+ case Nil => Nil
+ }
+}
diff --git a/src/main/scala/millfork/assembly/opt/CmosOptimizations.scala b/src/main/scala/millfork/assembly/opt/CmosOptimizations.scala
new file mode 100644
index 00000000..d1b4ca60
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/CmosOptimizations.scala
@@ -0,0 +1,36 @@
+package millfork.assembly.opt
+
+import millfork.assembly.{AssemblyLine, Opcode}
+import millfork.assembly.Opcode._
+import millfork.assembly.AddrMode._
+import millfork.assembly.OpcodeClasses._
+import millfork.env.{Constant, NormalFunction}
+
+/**
+ * @author Karol Stasiak
+ */
+object CmosOptimizations {
+
+ val StzAddrModes = Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX)
+
+ val ZeroStoreAsStz = new RuleBasedAssemblyOptimization("Zero store",
+ needsFlowInfo = FlowInfoRequirement.ForwardFlow,
+ (HasA(0) & HasOpcode(STA) & Elidable & HasAddrModeIn(StzAddrModes)) ~~> {code =>
+ code.head.copy(opcode = STZ) :: Nil
+ },
+ (HasX(0) & HasOpcode(STX) & Elidable & HasAddrModeIn(StzAddrModes)) ~~> {code =>
+ code.head.copy(opcode = STZ) :: Nil
+ },
+ (HasY(0) & HasOpcode(STY) & Elidable & HasAddrModeIn(StzAddrModes)) ~~> {code =>
+ code.head.copy(opcode = STZ) :: Nil
+ },
+ )
+
+ val OptimizeZeroIndex = new RuleBasedAssemblyOptimization("Optimizing zero index",
+ needsFlowInfo = FlowInfoRequirement.ForwardFlow,
+ (Elidable & HasY(0) & HasAddrMode(IndexedY) & HasOpcodeIn(SupportsZeroPageIndirect)) ~~> (code => code.map(_.copy(addrMode = ZeroPageIndirect))),
+ (Elidable & HasX(0) & HasAddrMode(IndexedX) & HasOpcodeIn(SupportsZeroPageIndirect)) ~~> (code => code.map(_.copy(addrMode = ZeroPageIndirect))),
+ )
+
+ val All: List[AssemblyOptimization] = List(ZeroStoreAsStz)
+}
diff --git a/src/main/scala/millfork/assembly/opt/CoarseFlowAnalyzer.scala b/src/main/scala/millfork/assembly/opt/CoarseFlowAnalyzer.scala
new file mode 100644
index 00000000..7e21009c
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/CoarseFlowAnalyzer.scala
@@ -0,0 +1,259 @@
+package millfork.assembly.opt
+
+import millfork.assembly.{AssemblyLine, OpcodeClasses, State}
+import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant}
+
+import scala.collection.immutable
+
+/**
+ * @author Karol Stasiak
+ */
+
+sealed trait Status[T] {
+ def contains(value: T): Boolean
+
+ def ~(that: Status[T]): Status[T] = {
+ (this, that) match {
+ case (AnyStatus(), _) => AnyStatus()
+ case (_, AnyStatus()) => AnyStatus()
+ case (SingleStatus(x), SingleStatus(y)) => if (x == y) SingleStatus(x) else AnyStatus()
+ case (SingleStatus(x), UnknownStatus()) => SingleStatus(x)
+ case (UnknownStatus(), SingleStatus(x)) => SingleStatus(x)
+ case (UnknownStatus(), UnknownStatus()) => UnknownStatus()
+ }
+ }
+
+}
+
+object Status {
+
+ implicit class IntStatusOps(val inner: Status[Int]) extends AnyVal {
+ def map[T](f: Int => T): Status[T] = inner match {
+ case SingleStatus(x) => SingleStatus(f(x))
+ case _ => AnyStatus()
+ }
+
+ def z(f: Int => Int = identity): Status[Boolean] = inner match {
+ case SingleStatus(x) =>
+ val y = f(x) & 0xff
+ SingleStatus(y == 0)
+ case _ => AnyStatus()
+ }
+
+ def n(f: Int => Int = identity): Status[Boolean] = inner match {
+ case SingleStatus(x) =>
+ val y = f(x) & 0xff
+ SingleStatus(y >= 0x80)
+ case _ => AnyStatus()
+ }
+ }
+
+}
+
+
+case class SingleStatus[T](t: T) extends Status[T] {
+ override def contains(value: T): Boolean = t == value
+
+ override def toString: String = t match {
+ case true => "1"
+ case false => "0"
+ case _ => t.toString
+ }
+}
+
+case class UnknownStatus[T]() extends Status[T] {
+ override def contains(value: T) = false
+
+ override def toString: String = "_"
+}
+
+case class AnyStatus[T]() extends Status[T] {
+ override def contains(value: T) = false
+
+ override def toString: String = "#"
+}
+//noinspection RedundantNewCaseClass
+case class CpuStatus(a: Status[Int] = UnknownStatus(),
+ x: Status[Int] = UnknownStatus(),
+ y: Status[Int] = UnknownStatus(),
+ z: Status[Boolean] = UnknownStatus(),
+ n: Status[Boolean] = UnknownStatus(),
+ c: Status[Boolean] = UnknownStatus(),
+ v: Status[Boolean] = UnknownStatus(),
+ d: Status[Boolean] = UnknownStatus(),
+ ) {
+
+ override def toString: String = s"A=$a,X=$x,Y=$y,Z=$z,N=$n,C=$c,V=$v,D=$d"
+
+ def nz: CpuStatus =
+ this.copy(n = AnyStatus(), z = AnyStatus())
+
+ def nz(i: Long): CpuStatus =
+ this.copy(n = SingleStatus((i & 0x80) != 0), z = SingleStatus((i & 0xff) == 0))
+
+ def ~(that: CpuStatus) = new CpuStatus(
+ a = this.a ~ that.a,
+ x = this.x ~ that.x,
+ y = this.y ~ that.y,
+ z = this.z ~ that.z,
+ n = this.n ~ that.n,
+ c = this.c ~ that.c,
+ v = this.v ~ that.v,
+ d = this.d ~ that.d,
+ )
+
+ def hasClear(state: State.Value): Boolean = state match {
+ case State.A => a.contains(0)
+ case State.X => x.contains(0)
+ case State.Y => y.contains(0)
+ case State.Z => z.contains(false)
+ case State.N => n.contains(false)
+ case State.C => c.contains(false)
+ case State.V => v.contains(false)
+ case State.D => d.contains(false)
+ }
+
+ def hasSet(state: State.Value): Boolean = state match {
+ case State.A => false
+ case State.X => false
+ case State.Y => false
+ case State.Z => z.contains(true)
+ case State.N => n.contains(true)
+ case State.C => c.contains(true)
+ case State.V => v.contains(true)
+ case State.D => d.contains(true)
+ }
+}
+
+object CoarseFlowAnalyzer {
+ //noinspection RedundantNewCaseClass
+ def analyze(f: NormalFunction, code: List[AssemblyLine]): List[CpuStatus] = {
+ val flagArray = Array.fill[CpuStatus](code.length)(CpuStatus())
+ val codeArray = code.toArray
+ val initialStatus = new CpuStatus(d = SingleStatus(false))
+
+ var changed = true
+ while (changed) {
+ changed = false
+ var currentStatus: CpuStatus = if (f.interrupt) CpuStatus() else initialStatus
+ for (i <- codeArray.indices) {
+ import millfork.assembly.Opcode._
+ import millfork.assembly.AddrMode._
+ if (flagArray(i) != currentStatus) {
+ changed = true
+ flagArray(i) = currentStatus
+ }
+ codeArray(i) match {
+ case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) =>
+ val L = l
+ currentStatus = codeArray.indices.flatMap(j => codeArray(j) match {
+ case AssemblyLine(_, _, MemoryAddressConstant(Label(L)), _) => Some(flagArray(j))
+ case _ => None
+ }).fold(CpuStatus())(_ ~ _)
+
+ case AssemblyLine(BCC, _, _, _) =>
+ currentStatus = currentStatus.copy(c = currentStatus.c ~ SingleStatus(true))
+ case AssemblyLine(BCS, _, _, _) =>
+ currentStatus = currentStatus.copy(c = currentStatus.c ~ SingleStatus(false))
+ case AssemblyLine(BVS, _, _, _) =>
+ currentStatus = currentStatus.copy(v = currentStatus.v ~ SingleStatus(false))
+ case AssemblyLine(BVC, _, _, _) =>
+ currentStatus = currentStatus.copy(v = currentStatus.v ~ SingleStatus(true))
+ case AssemblyLine(BMI, _, _, _) =>
+ currentStatus = currentStatus.copy(n = currentStatus.n ~ SingleStatus(false))
+ case AssemblyLine(BPL, _, _, _) =>
+ currentStatus = currentStatus.copy(n = currentStatus.n ~ SingleStatus(true))
+ case AssemblyLine(BEQ, _, _, _) =>
+ currentStatus = currentStatus.copy(z = currentStatus.z ~ SingleStatus(false))
+ case AssemblyLine(BNE, _, _, _) =>
+ currentStatus = currentStatus.copy(z = currentStatus.z ~ SingleStatus(true))
+
+ case AssemblyLine(SED, _, _, _) =>
+ currentStatus = currentStatus.copy(d = SingleStatus(true))
+ case AssemblyLine(SEC, _, _, _) =>
+ currentStatus = currentStatus.copy(c = SingleStatus(true))
+ case AssemblyLine(CLD, _, _, _) =>
+ currentStatus = currentStatus.copy(d = SingleStatus(false))
+ case AssemblyLine(CLC, _, _, _) =>
+ currentStatus = currentStatus.copy(c = SingleStatus(false))
+ case AssemblyLine(CLV, _, _, _) =>
+ currentStatus = currentStatus.copy(v = SingleStatus(false))
+
+ case AssemblyLine(JSR, _, _, _) =>
+ currentStatus = initialStatus
+
+ case AssemblyLine(LDX, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn.toInt
+ currentStatus = currentStatus.nz(n).copy(x = SingleStatus(n))
+ case AssemblyLine(LDY, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn.toInt
+ currentStatus = currentStatus.nz(n).copy(y = SingleStatus(n))
+ case AssemblyLine(LDA, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn.toInt
+ currentStatus = currentStatus.nz(n).copy(a = SingleStatus(n))
+ case AssemblyLine(LAX, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn.toInt
+ currentStatus = currentStatus.nz(n).copy(a = SingleStatus(n), x = SingleStatus(n))
+
+ case AssemblyLine(EOR, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn.toInt
+ currentStatus = currentStatus.copy(n = currentStatus.a.n(_ ^ n), z = currentStatus.a.z(_ ^ n), a = currentStatus.a.map(_ ^ n))
+ case AssemblyLine(AND, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn.toInt
+ currentStatus = currentStatus.copy(n = currentStatus.a.n(_ & n), z = currentStatus.a.z(_ & n), a = currentStatus.a.map(_ & n))
+ case AssemblyLine(ANC, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn.toInt
+ currentStatus = currentStatus.copy(n = currentStatus.a.n(_ & n), c = currentStatus.a.n(_ & n), z = currentStatus.x.z(_ & n), a = currentStatus.a.map(_ & n))
+ case AssemblyLine(ORA, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn.toInt
+ currentStatus = currentStatus.copy(n = currentStatus.a.n(_ | n), z = currentStatus.a.z(_ | n), a = currentStatus.a.map(_ | n))
+ case AssemblyLine(ALR, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn.toInt
+ currentStatus = currentStatus.copy(
+ n = currentStatus.a.n(i => (i & n & 0xff) >> 1),
+ z = currentStatus.a.z(i => (i & n & 0xff) >> 1),
+ c = currentStatus.a.map(i => (i & n & 1) == 0),
+ a = currentStatus.a.map(i => (i & n & 0xff) >> 1))
+
+ case AssemblyLine(INX, Implied, _, _) =>
+ currentStatus = currentStatus.copy(n = currentStatus.x.n(_ + 1), z = currentStatus.x.z(_ + 1), x = currentStatus.x.map(_ + 1))
+ case AssemblyLine(DEX, Implied, _, _) =>
+ currentStatus = currentStatus.copy(n = currentStatus.x.n(_ - 1), z = currentStatus.x.z(_ - 1), x = currentStatus.x.map(_ - 1))
+ case AssemblyLine(INY, Implied, _, _) =>
+ currentStatus = currentStatus.copy(n = currentStatus.y.n(_ + 1), z = currentStatus.y.z(_ + 1), y = currentStatus.y.map(_ + 1))
+ case AssemblyLine(DEY, Implied, _, _) =>
+ currentStatus = currentStatus.copy(n = currentStatus.y.n(_ - 1), z = currentStatus.y.z(_ - 1), y = currentStatus.y.map(_ - 1))
+ case AssemblyLine(TAX, _, _, _) =>
+ currentStatus = currentStatus.copy(x = currentStatus.a, n = currentStatus.a.n(), z = currentStatus.a.z())
+ case AssemblyLine(TXA, _, _, _) =>
+ currentStatus = currentStatus.copy(a = currentStatus.x, n = currentStatus.x.n(), z = currentStatus.x.z())
+ case AssemblyLine(TAY, _, _, _) =>
+ currentStatus = currentStatus.copy(y = currentStatus.a, n = currentStatus.a.n(), z = currentStatus.a.z())
+ case AssemblyLine(TYA, _, _, _) =>
+ currentStatus = currentStatus.copy(a = currentStatus.y, n = currentStatus.y.n(), z = currentStatus.y.z())
+
+ case AssemblyLine(opcode, addrMode, parameter, _) =>
+ if (OpcodeClasses.ChangesX(opcode)) currentStatus = currentStatus.copy(x = AnyStatus())
+ if (OpcodeClasses.ChangesY(opcode)) currentStatus = currentStatus.copy(y = AnyStatus())
+ if (OpcodeClasses.ChangesAAlways(opcode)) currentStatus = currentStatus.copy(a = AnyStatus())
+ if (addrMode == Implied && OpcodeClasses.ChangesAIfImplied(opcode)) currentStatus = currentStatus.copy(a = AnyStatus())
+ if (OpcodeClasses.ChangesNAndZ(opcode)) currentStatus = currentStatus.nz
+ if (OpcodeClasses.ChangesC(opcode)) currentStatus = currentStatus.copy(c = AnyStatus())
+ if (OpcodeClasses.ChangesV(opcode)) currentStatus = currentStatus.copy(v = AnyStatus())
+ if (opcode == CMP || opcode == CPX || opcode == CPY) {
+ if (addrMode == Immediate) parameter match {
+ case NumericConstant(0, _) => currentStatus = currentStatus.copy(c = SingleStatus(true))
+ case _ => ()
+ }
+ }
+ }
+ }
+// flagArray.zip(codeArray).foreach{
+// case (fl, y) => if (y.isPrintable) println(f"$fl%-32s $y%-32s")
+// }
+// println("---------------------")
+ }
+
+ flagArray.toList
+ }
+}
diff --git a/src/main/scala/millfork/assembly/opt/DangerousOptimizations.scala b/src/main/scala/millfork/assembly/opt/DangerousOptimizations.scala
new file mode 100644
index 00000000..649bd4e1
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/DangerousOptimizations.scala
@@ -0,0 +1,59 @@
+package millfork.assembly.opt
+
+import millfork.assembly._
+import millfork.assembly.Opcode._
+import millfork.assembly.AddrMode._
+import millfork.env._
+
+/**
+ * @author Karol Stasiak
+ */
+object DangerousOptimizations {
+
+ val ConstantIndexOffsetPropagation = new RuleBasedAssemblyOptimization("Constant index offset propagation",
+ // TODO: try to guess when overflow can happen
+ needsFlowInfo = FlowInfoRequirement.BothFlows,
+ (Elidable & HasOpcode(CLC)).? ~
+ (Elidable & HasClear(State.C) & HasOpcode(ADC) & MatchImmediate(0) & DoesntMatterWhatItDoesWith(State.V, State.C)) ~
+ (
+ (HasOpcode(TAY) & DoesntMatterWhatItDoesWith(State.N, State.Z, State.A)) ~
+ (Linear & Not(ConcernsY)).*
+ ).capture(1) ~
+ (Elidable & HasAddrMode(AbsoluteY) & DoesntMatterWhatItDoesWith(State.Y)) ~~> { (code, ctx) =>
+ val last = code.last
+ ctx.get[List[AssemblyLine]](1) :+ last.copy(parameter = last.parameter.+(ctx.get[Constant](0)).quickSimplify)
+ },
+ (Elidable & HasOpcode(CLC)).? ~
+ (Elidable & HasClear(State.C) & HasOpcode(ADC) & MatchImmediate(0) & DoesntMatterWhatItDoesWith(State.V, State.C)) ~
+ (
+ (HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.N, State.Z, State.A)) ~
+ (Linear & Not(ConcernsX)).*
+ ).capture(1) ~
+ (Elidable & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.X)) ~~> { (code, ctx) =>
+ val last = code.last
+ ctx.get[List[AssemblyLine]](1) :+ last.copy(parameter = last.parameter.+(ctx.get[Constant](0)).quickSimplify)
+ },
+ (Elidable & HasOpcode(INY) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~
+ (Elidable & HasAddrMode(AbsoluteY) & DoesntMatterWhatItDoesWith(State.Y)) ~~> { (code, ctx) =>
+ val last = code.last
+ List(last.copy(parameter = last.parameter.+(1).quickSimplify))
+ },
+ (Elidable & HasOpcode(DEY) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~
+ (Elidable & HasAddrMode(AbsoluteY) & DoesntMatterWhatItDoesWith(State.Y)) ~~> { (code, ctx) =>
+ val last = code.last
+ List(last.copy(parameter = last.parameter.+(-1).quickSimplify))
+ },
+ (Elidable & HasOpcode(INX) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~
+ (Elidable & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.X)) ~~> { (code, ctx) =>
+ val last = code.last
+ List(last.copy(parameter = last.parameter.+(1).quickSimplify))
+ },
+ (Elidable & HasOpcode(DEX) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~
+ (Elidable & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.X)) ~~> { (code, ctx) =>
+ val last = code.last
+ List(last.copy(parameter = last.parameter.+(-1).quickSimplify))
+ },
+ )
+
+ val All: List[AssemblyOptimization] = List(ConstantIndexOffsetPropagation)
+}
diff --git a/src/main/scala/millfork/assembly/opt/FlowAnalyzer.scala b/src/main/scala/millfork/assembly/opt/FlowAnalyzer.scala
new file mode 100644
index 00000000..eaeb718f
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/FlowAnalyzer.scala
@@ -0,0 +1,34 @@
+package millfork.assembly.opt
+
+import millfork.{CompilationFlag, CompilationOptions}
+import millfork.assembly.{AssemblyLine, State}
+import millfork.env.NormalFunction
+
+/**
+ * @author Karol Stasiak
+ */
+
+case class FlowInfo(statusBefore: CpuStatus, importanceAfter: CpuImportance) {
+
+ def hasClear(state: State.Value): Boolean = statusBefore.hasClear(state)
+
+ def hasSet(state: State.Value): Boolean = statusBefore.hasSet(state)
+
+ def isUnimportant(state: State.Value): Boolean = importanceAfter.isUnimportant(state)
+}
+
+object FlowInfo {
+ val Default = FlowInfo(CpuStatus(), CpuImportance())
+}
+
+object FlowAnalyzer {
+ def analyze(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[(FlowInfo, AssemblyLine)] = {
+ val forwardFlow = if (options.flag(CompilationFlag.DetailedFlowAnalysis)) {
+ QuantumFlowAnalyzer.analyze(f, code).map(_.collapse)
+ } else {
+ CoarseFlowAnalyzer.analyze(f, code)
+ }
+ val reverseFlow = ReverseFlowAnalyzer.analyze(f, code)
+ forwardFlow.zip(reverseFlow).map{case (s,i) => FlowInfo(s,i)}.zip(code)
+ }
+}
diff --git a/src/main/scala/millfork/assembly/opt/LaterOptimizations.scala b/src/main/scala/millfork/assembly/opt/LaterOptimizations.scala
new file mode 100644
index 00000000..3a1b2095
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/LaterOptimizations.scala
@@ -0,0 +1,242 @@
+package millfork.assembly.opt
+
+import millfork.assembly.{AddrMode, AssemblyLine, Opcode, State}
+import millfork.assembly.Opcode._
+import millfork.assembly.AddrMode._
+import millfork.assembly.OpcodeClasses._
+import millfork.env.{Constant, NormalFunction, NumericConstant}
+
+/**
+ * These optimizations help on their own, but may prevent other optimizations from triggering.
+ *
+ * @author Karol Stasiak
+ */
+object LaterOptimizations {
+
+
+ // This optimization tends to prevent later Variable To Register Optimization,
+ // so run this only after it's pretty sure V2RO won't happen any more
+ val DoubleLoadToDifferentRegisters = new RuleBasedAssemblyOptimization("Double load to different registers",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ TwoDifferentLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA), LDX, TAX),
+ TwoDifferentLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA), LDY, TAY),
+ TwoDifferentLoadsWithNoFlagChangeInBetween(LDX, Not(ChangesX), LDA, TXA),
+ TwoDifferentLoadsWithNoFlagChangeInBetween(LDY, Not(ChangesY), LDA, TYA),
+ TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDA, Not(ChangesA), LDX, TAX),
+ TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDA, Not(ChangesA), LDY, TAY),
+ TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDX, Not(ChangesX), LDA, TXA),
+ TwoDifferentLoadsWhoseFlagsWillNotBeChecked(LDY, Not(ChangesY), LDA, TYA),
+ )
+
+ private def TwoDifferentLoadsWithNoFlagChangeInBetween(opcode1: Opcode.Value, middle: AssemblyLinePattern, opcode2: Opcode.Value, transferOpcode: Opcode.Value) = {
+ (HasOpcode(opcode1) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (LinearOrLabel & Not(ChangesMemory) & middle & Not(HasOpcode(opcode2))).* ~
+ (HasOpcode(opcode2) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { c =>
+ c.init :+ AssemblyLine.implied(transferOpcode)
+ }
+ }
+
+ private def TwoDifferentLoadsWhoseFlagsWillNotBeChecked(opcode1: Opcode.Value, middle: AssemblyLinePattern, opcode2: Opcode.Value, transferOpcode: Opcode.Value) = {
+ ((HasOpcode(opcode1) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (LinearOrLabel & Not(ChangesMemory) & middle & Not(HasOpcode(opcode2))).*).capture(2) ~
+ (HasOpcode(opcode2) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~
+ ((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(3) ~~> { (_, ctx) =>
+ ctx.get[List[AssemblyLine]](2) ++ (AssemblyLine.implied(transferOpcode) :: ctx.get[List[AssemblyLine]](3))
+ }
+ }
+
+ private def TwoIdenticalLoadsWithNoFlagChangeInBetween(opcode: Opcode.Value, middle: AssemblyLinePattern) = {
+ (HasOpcode(opcode) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (LinearOrLabel & Not(ChangesMemory) & middle & Not(ChangesNAndZ)).* ~
+ (HasOpcode(opcode) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { c =>
+ c.init
+ }
+ }
+
+ private def TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(opcode: Opcode.Value, middle: AssemblyLinePattern) = {
+ (HasOpcode(opcode) & HasAddrMode(Immediate) & MatchParameter(1)) ~
+ (LinearOrLabel & middle & Not(ChangesNAndZ)).* ~
+ (HasOpcode(opcode) & Elidable & HasAddrMode(Immediate) & MatchParameter(1)) ~~> { c =>
+ c.init
+ }
+ }
+
+ private def TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(opcode: Opcode.Value, middle: AssemblyLinePattern) = {
+ ((HasOpcode(opcode) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (LinearOrLabel & Not(ChangesMemory) & middle).*).capture(2) ~
+ (HasOpcode(opcode) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~
+ ((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(3) ~~> { (_, ctx) =>
+ ctx.get[List[AssemblyLine]](2) ++ ctx.get[List[AssemblyLine]](3)
+ }
+ }
+
+ //noinspection ZeroIndexToHead
+ private def InterleavedImmediateLoads(load: Opcode.Value, store: Opcode.Value) = {
+ (Elidable & HasOpcode(load) & MatchImmediate(0)) ~
+ (Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(8)) ~
+ (Elidable & HasOpcode(load) & MatchImmediate(1)) ~
+ (Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(9) & DontMatchParameter(8)) ~
+ (Elidable & HasOpcode(load) & MatchImmediate(0)) ~~> { c =>
+ List(c(2), c(3), c(0), c(1))
+ }
+ }
+
+ //noinspection ZeroIndexToHead
+ private def InterleavedAbsoluteLoads(load: Opcode.Value, store: Opcode.Value) = {
+ (Elidable & HasOpcode(load) & HasAddrMode(Absolute) & MatchParameter(0)) ~
+ (Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(8) & DontMatchParameter(0)) ~
+ (Elidable & HasOpcode(load) & HasAddrMode(Absolute) & MatchParameter(1) & DontMatchParameter(8) & DontMatchParameter(0)) ~
+ (Elidable & HasOpcode(store) & HasAddrMode(Absolute) & MatchParameter(9) & DontMatchParameter(8) & DontMatchParameter(1) & DontMatchParameter(0)) ~
+ (Elidable & HasOpcode(load) & HasAddrMode(Absolute) & MatchParameter(0)) ~~> { c =>
+ List(c(2), c(3), c(0), c(1))
+ }
+ }
+
+ val DoubleLoadToTheSameRegister = new RuleBasedAssemblyOptimization("Double load to the same register",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ TwoIdenticalLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA)),
+ TwoIdenticalLoadsWithNoFlagChangeInBetween(LDX, Not(ChangesX)),
+ TwoIdenticalLoadsWithNoFlagChangeInBetween(LDY, Not(ChangesY)),
+ TwoIdenticalLoadsWithNoFlagChangeInBetween(LAX, Not(ChangesA) & Not(ChangesX)),
+ TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(LDA, Not(ChangesA)),
+ TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(LDX, Not(ChangesX)),
+ TwoIdenticalImmediateLoadsWithNoFlagChangeInBetween(LDY, Not(ChangesY)),
+ TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LDA, Not(ChangesA)),
+ TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LDX, Not(ChangesX)),
+ TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LDY, Not(ChangesY)),
+ TwoIdenticalLoadsWhoseFlagsWillNotBeChecked(LAX, Not(ChangesA) & Not(ChangesX)),
+ InterleavedImmediateLoads(LDA, STA),
+ InterleavedImmediateLoads(LDX, STX),
+ InterleavedImmediateLoads(LDY, STY),
+ InterleavedAbsoluteLoads(LDA, STA),
+ InterleavedAbsoluteLoads(LDX, STX),
+ InterleavedAbsoluteLoads(LDY, STY),
+ )
+
+ private def pointlessLoadAfterStore(store: Opcode.Value, load: Opcode.Value, addrMode: AddrMode.Value, meantime: AssemblyLinePattern = Anything) = {
+ ((HasOpcode(store) & HasAddrMode(addrMode) & MatchParameter(1)) ~
+ (LinearOrBranch & Not(ChangesA) & Not(ChangesMemory) & meantime).*).capture(2) ~
+ (HasOpcode(load) & Elidable & HasAddrMode(addrMode) & MatchParameter(1)) ~
+ ((LinearOrLabel & Not(ReadsNOrZ) & Not(ChangesNAndZ)).* ~ ChangesNAndZ).capture(3) ~~> { (_, ctx) =>
+ ctx.get[List[AssemblyLine]](2) ++ ctx.get[List[AssemblyLine]](3)
+ }
+ }
+
+ val PointlessLoadAfterStore = new RuleBasedAssemblyOptimization("Pointless load after store",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ pointlessLoadAfterStore(STA, LDA, Absolute),
+ pointlessLoadAfterStore(STA, LDA, AbsoluteX, Not(ChangesX)),
+ pointlessLoadAfterStore(STA, LDA, AbsoluteY, Not(ChangesY)),
+ pointlessLoadAfterStore(STX, LDX, Absolute),
+ pointlessLoadAfterStore(STY, LDY, Absolute),
+ )
+
+
+ private val ShiftAddrModes = Set(ZeroPage, ZeroPageX, Absolute, AbsoluteX)
+ private val ShiftOpcodes = Set(ASL, ROL, ROR, LSR)
+
+ // LDA-SHIFT-STA is slower than just SHIFT
+ // LDA-SHIFT-SHIFT-STA is equally fast as SHIFT-SHIFT, but the latter doesn't use the accumulator
+ val PointessLoadingForShifting = new RuleBasedAssemblyOptimization("Pointless loading for shifting",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasOpcode(LDA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Elidable & HasOpcodeIn(ShiftOpcodes) & HasAddrMode(Implied) & MatchOpcode(2)) ~
+ (Elidable & HasOpcode(STA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Not(ReadsA) & Not(OverwritesA)).* ~ OverwritesA ~~> { (code, ctx) =>
+ AssemblyLine(ctx.get[Opcode.Value](2), ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) :: code.drop(3)
+ },
+ (Elidable & HasOpcode(LDA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Elidable & HasOpcodeIn(ShiftOpcodes) & HasAddrMode(Implied) & MatchOpcode(2)) ~
+ (Elidable & HasOpcodeIn(ShiftOpcodes) & HasAddrMode(Implied) & MatchOpcode(2)) ~
+ (Elidable & HasOpcode(STA) & HasAddrModeIn(ShiftAddrModes) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Not(ReadsA) & Not(OverwritesA)).* ~ OverwritesA ~~> { (code, ctx) =>
+ val shift = AssemblyLine(ctx.get[Opcode.Value](2), ctx.get[AddrMode.Value](0), ctx.get[Constant](1))
+ shift :: shift :: code.drop(4)
+ }
+ )
+
+ // SHIFT-LDA is equally fast as LDA-SHIFT-STA, but can enable further optimizations doesn't use the accumulator
+ // LDA-SHIFT-SHIFT-STA is equally fast as SHIFT-SHIFT, but the latter doesn't use the accumulator
+ val LoadingAfterShifting = new RuleBasedAssemblyOptimization("Loading after shifting",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasOpcodeIn(ShiftOpcodes) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
+ AssemblyLine(LDA, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) ::
+ AssemblyLine.implied(code.head.opcode) ::
+ AssemblyLine(STA, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)) ::
+ code.drop(2)
+ }
+ )
+
+ val UseZeropageAddressingMode = new RuleBasedAssemblyOptimization("Using zeropage addressing mode",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasAddrMode(Absolute) & MatchParameter(0)) ~ Where(ctx => ctx.get[Constant](0).quickSimplify match {
+ case NumericConstant(x, _) => (x & 0xff00) == 0
+ case _ => false
+ }) ~~> (code => code.head.copy(addrMode = ZeroPage) :: Nil)
+ )
+
+ val UseXInsteadOfStack = new RuleBasedAssemblyOptimization("Using X instead of stack",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ (Elidable & HasOpcode(PHA) & DoesntMatterWhatItDoesWith(State.X)) ~
+ (Not(ConcernsStack) & Not(ConcernsX)).capture(1) ~
+ Where(_.isExternallyLinearBlock(1)) ~
+ (Elidable & HasOpcode(PLA)) ~~> (c =>
+ AssemblyLine.implied(TAX) :: (c.tail.init :+ AssemblyLine.implied(TXA))
+ )
+ )
+
+ val UseYInsteadOfStack = new RuleBasedAssemblyOptimization("Using Y instead of stack",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ (Elidable & HasOpcode(PHA) & DoesntMatterWhatItDoesWith(State.Y)) ~
+ (Not(ConcernsStack) & Not(ConcernsY)).capture(1) ~
+ Where(_.isExternallyLinearBlock(1)) ~
+ (Elidable & HasOpcode(PLA)) ~~> (c =>
+ AssemblyLine.implied(TAY) :: (c.tail.init :+ AssemblyLine.implied(TYA))
+ )
+ )
+
+ // TODO: make it more generic
+ val IndexSwitchingOptimization = new RuleBasedAssemblyOptimization("Index switching optimization",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ (Elidable & HasOpcode(LDY) & MatchAddrMode(2) & Not(ReadsX) & MatchParameter(0)) ~
+ (Elidable & Linear & Not(ChangesY) & HasAddrMode(AbsoluteY) & SupportsAbsoluteX & Not(ConcernsX)) ~
+ (HasOpcode(LDY) & Not(ConcernsX)) ~
+ (Linear & Not(ChangesY) & Not(ConcernsX) & HasAddrModeIn(Set(AbsoluteY, IndexedY, ZeroPageY))) ~
+ (Elidable & HasOpcode(LDY) & MatchAddrMode(2) & Not(ReadsX) & MatchParameter(0)) ~
+ (Elidable & Linear & Not(ChangesY) & HasAddrMode(AbsoluteY) & SupportsAbsoluteX & Not(ConcernsX) & DoesntMatterWhatItDoesWith(State.X, State.N, State.Z)) ~~> { (code, ctx) =>
+ List(
+ code(0).copy(opcode = LDX),
+ code(1).copy(addrMode = AbsoluteX),
+ code(2),
+ code(3),
+ code(5).copy(addrMode = AbsoluteX))
+ },
+ (Elidable & HasOpcode(LDX) & MatchAddrMode(2) & Not(ReadsY) & MatchParameter(0)) ~
+ (Elidable & Linear & Not(ChangesX) & HasAddrMode(AbsoluteX) & SupportsAbsoluteY & Not(ConcernsY)) ~
+ (HasOpcode(LDX) & Not(ConcernsY)) ~
+ (Linear & Not(ChangesX) & Not(ConcernsY) & HasAddrModeIn(Set(AbsoluteX, IndexedX, ZeroPageX, AbsoluteIndexedX))) ~
+ (Elidable & HasOpcode(LDX) & MatchAddrMode(2) & Not(ReadsY) & MatchParameter(0)) ~
+ (Elidable & Linear & Not(ChangesX) & HasAddrMode(AbsoluteX) & SupportsAbsoluteY & Not(ConcernsY) & DoesntMatterWhatItDoesWith(State.Y, State.N, State.Z)) ~~> { (code, ctx) =>
+ List(
+ code(0).copy(opcode = LDY),
+ code(1).copy(addrMode = AbsoluteY),
+ code(2),
+ code(3),
+ code(5).copy(addrMode = AbsoluteY))
+ },
+
+ )
+
+ val All = List(
+ DoubleLoadToDifferentRegisters,
+ DoubleLoadToTheSameRegister,
+ IndexSwitchingOptimization,
+ PointlessLoadAfterStore,
+ PointessLoadingForShifting,
+ LoadingAfterShifting,
+ UseXInsteadOfStack,
+ UseYInsteadOfStack,
+ UseZeropageAddressingMode)
+}
+
diff --git a/src/main/scala/millfork/assembly/opt/QuantumFlowAnalyzer.scala b/src/main/scala/millfork/assembly/opt/QuantumFlowAnalyzer.scala
new file mode 100644
index 00000000..2f0cabb2
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/QuantumFlowAnalyzer.scala
@@ -0,0 +1,425 @@
+package millfork.assembly.opt
+
+import millfork.assembly.{AssemblyLine, OpcodeClasses}
+import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant}
+
+import scala.collection.immutable.BitSet
+
+/**
+ * @author Karol Stasiak
+ */
+
+object QCpuStatus {
+ val InitialStatus = QCpuStatus((for {
+ c <- Seq(true, false)
+ v <- Seq(true, false)
+ n <- Seq(true, false)
+ z <- Seq(true, false)
+ } yield QFlagStatus(c = c, d = false, v = v, n = n, z = z) -> QRegStatus(a = QRegStatus.AllValues, x = QRegStatus.AllValues, y = QRegStatus.AllValues, equal = RegEquality.NoEquality)).toMap)
+
+ val UnknownStatus = QCpuStatus((for {
+ c <- Seq(true, false)
+ v <- Seq(true, false)
+ n <- Seq(true, false)
+ z <- Seq(true, false)
+ } yield QFlagStatus(c = c, d = false, v = v, n = n, z = z) -> QRegStatus(a = QRegStatus.AllValues, x = QRegStatus.AllValues, y = QRegStatus.AllValues, equal = RegEquality.UnknownEquality)).toMap)
+
+ def gather(l: List[(QFlagStatus, QRegStatus)]) =
+ QCpuStatus(l.groupBy(_._1).
+ map { case (k, vs) => k -> vs.map(_._2).reduce(_ ++ _) }.
+ filterNot(_._2.isEmpty))
+}
+
+case class QCpuStatus(data: Map[QFlagStatus, QRegStatus]) {
+ def collapse: CpuStatus = {
+ val registers = data.values.reduce(_ ++ _)
+
+ def bitset(b: BitSet): Status[Int] = if (b.size == 1) SingleStatus(b.head) else AnyStatus()
+
+ def flag(f: QFlagStatus => Boolean): Status[Boolean] =
+ if (data.keys.forall(k => f(k))) SingleStatus(true)
+ else if (data.keys.forall(k => !f(k))) SingleStatus(false)
+ else AnyStatus()
+
+ CpuStatus(
+ a = bitset(registers.a),
+ x = bitset(registers.x),
+ y = bitset(registers.y),
+ c = flag(_.c),
+ d = flag(_.d),
+ v = flag(_.v),
+ z = flag(_.z),
+ n = flag(_.n),
+ )
+ }
+
+ def changeFlagUnconditionally(f: QFlagStatus => QFlagStatus): QCpuStatus = {
+ QCpuStatus.gather(data.toList.map { case (k, v) => f(k) -> v })
+ }
+
+ def changeFlagsInAnUnknownWay(f: QFlagStatus => QFlagStatus, g: QFlagStatus => QFlagStatus): QCpuStatus = {
+ QCpuStatus.gather(data.toList.flatMap { case (k, v) => List(f(k) -> v, g(k) -> v) })
+ }
+
+ def changeFlagsInAnUnknownWay(f: QFlagStatus => QFlagStatus, g: QFlagStatus => QFlagStatus, h: QFlagStatus => QFlagStatus): QCpuStatus = {
+ QCpuStatus.gather(data.toList.flatMap { case (k, v) => List(f(k) -> v, g(k) -> v, h(k) -> v) })
+ }
+
+ def mapRegisters(f: QRegStatus => QRegStatus): QCpuStatus = {
+ QCpuStatus(data.map { case (k, v) => k -> f(v) })
+ }
+
+ def mapRegisters(f: (QFlagStatus, QRegStatus) => QRegStatus): QCpuStatus = {
+ QCpuStatus(data.map { case (k, v) => k -> f(k, v) })
+ }
+
+ def flatMap(f: (QFlagStatus, QRegStatus) => List[(QFlagStatus, QRegStatus)]): QCpuStatus = {
+ QCpuStatus.gather(data.toList.flatMap { case (k, v) => f(k, v) })
+ }
+
+ def changeNZFromA: QCpuStatus = {
+ QCpuStatus.gather(data.toList.flatMap { case (k, v) =>
+ List(
+ k.copy(n = false, z = false) -> v.whereA(i => i.toByte > 0),
+ k.copy(n = true, z = false) -> v.whereA(i => i.toByte < 0),
+ k.copy(n = false, z = true) -> v.whereA(i => i.toByte == 0))
+ })
+ }
+
+ def changeNZFromX: QCpuStatus = {
+ QCpuStatus.gather(data.toList.flatMap { case (k, v) =>
+ List(
+ k.copy(n = false, z = false) -> v.whereX(i => i.toByte > 0),
+ k.copy(n = true, z = false) -> v.whereX(i => i.toByte < 0),
+ k.copy(n = false, z = true) -> v.whereX(i => i.toByte == 0))
+ })
+ }
+
+ def changeNZFromY: QCpuStatus = {
+ QCpuStatus.gather(data.toList.flatMap { case (k, v) =>
+ List(
+ k.copy(n = false, z = false) -> v.whereY(i => i.toByte > 0),
+ k.copy(n = true, z = false) -> v.whereY(i => i.toByte < 0),
+ k.copy(n = false, z = true) -> v.whereY(i => i.toByte == 0))
+ })
+ }
+
+ def ~(that: QCpuStatus): QCpuStatus = QCpuStatus.gather(this.data.toList ++ that.data.toList)
+}
+
+object QRegStatus {
+ val NoValues: BitSet = BitSet.empty
+ val AllValues: BitSet = BitSet.fromBitMask(Array(-1L, -1L, -1L, -1L))
+
+}
+
+object RegEquality extends Enumeration {
+ val NoEquality, AX, AY, XY, AXY, UnknownEquality = Value
+
+ def or(a: Value, b: Value) = {
+ (a, b) match {
+ case (UnknownEquality, _) => b
+ case (_, UnknownEquality) => a
+ case (NoEquality, _) => NoEquality
+ case (_, NoEquality) => NoEquality
+ case (_, _) if a == b => a
+ case (AXY, _) => b
+ case (_, AXY) => a
+ case _ => NoEquality
+ }
+ }
+
+ def afterTransfer(a: Value, b: Value) = {
+ (a, b) match {
+ case (UnknownEquality, _) => b
+ case (_, UnknownEquality) => a
+ case (NoEquality, _) => b
+ case (_, NoEquality) => a
+ case (_, _) if a == b => a
+ case _ => AXY
+ }
+ }
+}
+
+case class QRegStatus(a: BitSet, x: BitSet, y: BitSet, equal: RegEquality.Value) {
+ def isEmpty: Boolean = a.isEmpty || x.isEmpty || y.isEmpty
+
+ def ++(that: QRegStatus) = QRegStatus(
+ a = a ++ that.a,
+ x = x ++ that.x,
+ y = y ++ that.y,
+ equal = RegEquality.or(equal, that.equal))
+
+ def afterTransfer(transfer: RegEquality.Value): QRegStatus =
+ copy(equal = RegEquality.afterTransfer(equal, transfer))
+
+ def changeA(f: Int => Long): QRegStatus = {
+ val newA = a.map(i => f(i).toInt & 0xff)
+ val newEqual = equal match {
+ case RegEquality.XY => RegEquality.XY
+ case RegEquality.AXY => RegEquality.XY
+ case _ => RegEquality.NoEquality
+ }
+ QRegStatus(newA, x, y, newEqual)
+ }
+
+ def changeX(f: Int => Long): QRegStatus = {
+ val newA = a.map(i => f(i).toInt & 0xff)
+ val newEqual = equal match {
+ case RegEquality.XY => RegEquality.XY
+ case RegEquality.AXY => RegEquality.XY
+ case _ => RegEquality.NoEquality
+ }
+ QRegStatus(newA, x, y, newEqual)
+ }
+
+ def changeY(f: Int => Long): QRegStatus = {
+ val newA = a.map(i => f(i).toInt & 0xff)
+ val newEqual = equal match {
+ case RegEquality.XY => RegEquality.XY
+ case RegEquality.AXY => RegEquality.XY
+ case _ => RegEquality.NoEquality
+ }
+ QRegStatus(newA, x, y, newEqual)
+ }
+
+ def whereA(f: Int => Boolean): QRegStatus =
+ equal match {
+ case RegEquality.AXY =>
+ copy(a = a.filter(f), x = x.filter(f), y = y.filter(f))
+ case RegEquality.AY =>
+ copy(a = a.filter(f), y = y.filter(f))
+ case RegEquality.AX =>
+ copy(a = a.filter(f), x = x.filter(f))
+ case _ =>
+ copy(a = a.filter(f))
+ }
+
+ def whereX(f: Int => Boolean): QRegStatus =
+ equal match {
+ case RegEquality.AXY =>
+ copy(a = a.filter(f), x = x.filter(f), y = y.filter(f))
+ case RegEquality.XY =>
+ copy(x = x.filter(f), y = y.filter(f))
+ case RegEquality.AX =>
+ copy(a = a.filter(f), x = x.filter(f))
+ case _ =>
+ copy(x = x.filter(f))
+ }
+
+ def whereY(f: Int => Boolean): QRegStatus =
+ equal match {
+ case RegEquality.AXY =>
+ copy(a = a.filter(f), x = x.filter(f), y = y.filter(f))
+ case RegEquality.AY =>
+ copy(a = a.filter(f), y = y.filter(f))
+ case RegEquality.XY =>
+ copy(x = x.filter(f), y = y.filter(f))
+ case _ =>
+ copy(y = y.filter(f))
+ }
+}
+
+case class QFlagStatus(c: Boolean, d: Boolean, v: Boolean, z: Boolean, n: Boolean)
+
+object QuantumFlowAnalyzer {
+ private def loBit(b: Boolean) = if (b) 1 else 0
+
+ private def hiBit(b: Boolean) = if (b) 0x80 else 0
+
+ //noinspection RedundantNewCaseClass
+ def analyze(f: NormalFunction, code: List[AssemblyLine]): List[QCpuStatus] = {
+ val flagArray = Array.fill[QCpuStatus](code.length)(QCpuStatus.UnknownStatus)
+ val codeArray = code.toArray
+
+ var changed = true
+ while (changed) {
+ changed = false
+ var currentStatus: QCpuStatus = if (f.interrupt) QCpuStatus.UnknownStatus else QCpuStatus.UnknownStatus
+ for (i <- codeArray.indices) {
+ import millfork.assembly.Opcode._
+ import millfork.assembly.AddrMode._
+ if (flagArray(i) != currentStatus) {
+ changed = true
+ flagArray(i) = currentStatus
+ }
+ codeArray(i) match {
+ case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) =>
+ val L = l
+ currentStatus = codeArray.indices.flatMap(j => codeArray(j) match {
+ case AssemblyLine(_, _, MemoryAddressConstant(Label(L)), _) => Some(flagArray(j))
+ case _ => None
+ }).fold(QCpuStatus.UnknownStatus)(_ ~ _)
+
+ case AssemblyLine(BCC, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = true))
+ case AssemblyLine(BCS, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = false))
+ case AssemblyLine(BVS, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = false))
+ case AssemblyLine(BVC, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = true))
+ case AssemblyLine(BMI, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(n = false))
+ case AssemblyLine(BPL, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(n = true))
+ case AssemblyLine(BEQ, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(z = false))
+ case AssemblyLine(BNE, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(z = true))
+
+ case AssemblyLine(SED, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(d = true))
+ case AssemblyLine(SEC, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = true))
+ case AssemblyLine(CLD, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(d = false))
+ case AssemblyLine(CLC, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = false))
+ case AssemblyLine(CLV, _, _, _) =>
+ currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = false))
+
+ case AssemblyLine(JSR, _, _, _) =>
+ currentStatus = QCpuStatus.InitialStatus
+
+ case AssemblyLine(LDX, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeX(_ => n)).changeNZFromX
+ case AssemblyLine(LDY, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeY(_ => n)).changeNZFromY
+ case AssemblyLine(LDA, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeA(_ => n)).changeNZFromA
+ case AssemblyLine(LAX, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeA(_ => n).changeX(_ => n).afterTransfer(RegEquality.AX)).changeNZFromA
+
+ case AssemblyLine(EOR, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeA(_ ^ n)).changeNZFromA
+ case AssemblyLine(AND, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeA(_ & n)).changeNZFromA
+ case AssemblyLine(ANC, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeA(_ & n)).changeNZFromA.changeFlagUnconditionally(f => f.copy(c = f.z))
+ case AssemblyLine(ORA, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeA(_ | n)).changeNZFromA
+
+ case AssemblyLine(INX, Implied, _, _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeX(_ + 1)).changeNZFromX
+ case AssemblyLine(DEX, Implied, _, _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeX(_ - 1)).changeNZFromX
+ case AssemblyLine(INY, Implied, _, _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeY(_ - 1)).changeNZFromY
+ case AssemblyLine(DEY, Implied, _, _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.changeY(_ - 1)).changeNZFromY
+ case AssemblyLine(TAX, _, _, _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.copy(x = r.a).afterTransfer(RegEquality.AX)).changeNZFromX
+ case AssemblyLine(TXA, _, _, _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.copy(a = r.x).afterTransfer(RegEquality.AX)).changeNZFromA
+ case AssemblyLine(TAY, _, _, _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.copy(y = r.a).afterTransfer(RegEquality.AY)).changeNZFromY
+ case AssemblyLine(TYA, _, _, _) =>
+ currentStatus = currentStatus.mapRegisters(r => r.copy(a = r.y).afterTransfer(RegEquality.AY)).changeNZFromA
+
+ case AssemblyLine(ROL, Implied, _, _) =>
+ currentStatus = currentStatus.flatMap((f, r) => List(
+ f.copy(c = true) -> r.whereA(a => (a & 0x80) != 0).changeA(a => a * 2 + loBit(f.c)),
+ f.copy(c = false) -> r.whereA(a => (a & 0x80) == 0).changeA(a => a * 2 + loBit(f.c)),
+ )).changeNZFromA
+ case AssemblyLine(ROR, Implied, _, _) =>
+ currentStatus = currentStatus.flatMap((f, r) => List(
+ f.copy(c = true) -> r.whereA(a => (a & 1) != 0).changeA(a => (a >>> 2) & 0x7f | hiBit(f.c)),
+ f.copy(c = false) -> r.whereA(a => (a & 1) == 0).changeA(a => (a >>> 2) & 0x7f | hiBit(f.c)),
+ )).changeNZFromA
+ case AssemblyLine(ASL, Implied, _, _) =>
+ currentStatus = currentStatus.flatMap((f, r) => List(
+ f.copy(c = true) -> r.whereA(a => (a & 0x80) != 0).changeA(a => a * 2),
+ f.copy(c = false) -> r.whereA(a => (a & 0x80) == 0).changeA(a => a * 2),
+ )).changeNZFromA
+ case AssemblyLine(LSR, Implied, _, _) =>
+ currentStatus = currentStatus.flatMap((f, r) => List(
+ f.copy(c = true) -> r.whereA(a => (a & 1) != 0).changeA(a => (a >>> 2) & 0x7f),
+ f.copy(c = false) -> r.whereA(a => (a & 1) == 0).changeA(a => (a >>> 2) & 0x7f),
+ )).changeNZFromA
+ case AssemblyLine(ALR, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.flatMap((f, r) => List(
+ f.copy(c = true) -> r.whereA(a => (a & n & 1) != 0).changeA(a => ((a & n) >>> 2) & 0x7f),
+ f.copy(c = false) -> r.whereA(a => (a & n & 1) == 0).changeA(a => ((a & n) >>> 2) & 0x7f),
+ )).changeNZFromA
+ case AssemblyLine(ADC, Immediate, NumericConstant(nn, _), _) =>
+ val n = nn & 0xff
+ currentStatus = currentStatus.flatMap((f, r) =>
+ if (f.d) {
+ val regs = r.copy(a = QRegStatus.AllValues).changeA(_.toLong)
+ List(
+ f.copy(c = false, v = false) -> regs,
+ f.copy(c = true, v = false) -> regs,
+ f.copy(c = false, v = true) -> regs,
+ f.copy(c = true, v = true) -> regs,
+ )
+ } else {
+ if (f.c) {
+ val regs = r.changeA(_ + n + 1)
+ List(
+ f.copy(c = false, v = false) -> regs.whereA(_ >= n),
+ f.copy(c = true, v = false) -> regs.whereA(_ < n),
+ f.copy(c = false, v = true) -> regs.whereA(_ >= n),
+ f.copy(c = true, v = true) -> regs.whereA(_ < n),
+ )
+ } else {
+ val regs = r.changeA(_ + n)
+ List(
+ f.copy(c = false, v = false) -> regs.whereA(_ > n),
+ f.copy(c = true, v = false) -> regs.whereA(_ <= n),
+ f.copy(c = false, v = true) -> regs.whereA(_ > n),
+ f.copy(c = true, v = true) -> regs.whereA(_ <= n),
+ )
+ }
+ }
+ ).changeNZFromA
+ case AssemblyLine(SBC, Immediate, NumericConstant(n, _), _) =>
+ currentStatus = currentStatus.flatMap((f, r) =>
+ if (f.d) {
+ val regs = r.copy(a = QRegStatus.AllValues).changeA(_.toLong)
+ // TODO: guess the carry flag correctly
+ List(
+ f.copy(c = false, v = false) -> regs,
+ f.copy(c = true, v = false) -> regs,
+ f.copy(c = false, v = true) -> regs,
+ f.copy(c = true, v = true) -> regs,
+ )
+ } else {
+ val regs = if (f.c) r.changeA(_ - n) else r.changeA(_ - n - 1)
+ List(
+ f.copy(c = false, v = false) -> regs,
+ f.copy(c = true, v = false) -> regs,
+ f.copy(c = false, v = true) -> regs,
+ f.copy(c = true, v = true) -> regs,
+ )
+ }
+ ).changeNZFromA
+
+ case AssemblyLine(opcode, addrMode, parameter, _) =>
+ if (OpcodeClasses.ChangesX(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(x = QRegStatus.AllValues))
+ if (OpcodeClasses.ChangesY(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(y = QRegStatus.AllValues))
+ if (OpcodeClasses.ChangesAAlways(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(a = QRegStatus.AllValues))
+ if (addrMode == Implied && OpcodeClasses.ChangesAIfImplied(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(a = QRegStatus.AllValues))
+ if (OpcodeClasses.ChangesNAndZ(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(
+ _.copy(n = false, z = false),
+ _.copy(n = true, z = false),
+ _.copy(n = false, z = true))
+ if (OpcodeClasses.ChangesC(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(_.copy(c = false), _.copy(c = true))
+ if (OpcodeClasses.ChangesV(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(_.copy(v = false), _.copy(v = true))
+ if (opcode == CMP || opcode == CPX || opcode == CPY) {
+ if (addrMode == Immediate) parameter match {
+ case NumericConstant(0, _) => currentStatus = currentStatus.changeFlagUnconditionally(_.copy(c = true))
+ case _ => ()
+ }
+ }
+ }
+ }
+ // flagArray.zip(codeArray).foreach{
+ // case (fl, y) => if (y.isPrintable) println(f"$fl%-32s $y%-32s")
+ // }
+ // println("---------------------")
+ }
+
+ flagArray.toList
+ }
+}
diff --git a/src/main/scala/millfork/assembly/opt/ReverseFlowAnalyzer.scala b/src/main/scala/millfork/assembly/opt/ReverseFlowAnalyzer.scala
new file mode 100644
index 00000000..00d8ff91
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/ReverseFlowAnalyzer.scala
@@ -0,0 +1,149 @@
+package millfork.assembly.opt
+
+import millfork.assembly.{AssemblyLine, OpcodeClasses, State}
+import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant}
+
+import scala.collection.immutable
+
+/**
+ * @author Karol Stasiak
+ */
+
+sealed trait Importance {
+ def ~(that: Importance) = (this, that) match {
+ case (_, Important) | (Important, _) => Important
+ case (_, Unimportant) | (Unimportant, _) => Unimportant
+ case (UnknownImportance, UnknownImportance) => UnknownImportance
+ }
+}
+
+case object Important extends Importance {
+ override def toString = "!"
+}
+
+
+case object Unimportant extends Importance {
+ override def toString = "*"
+}
+
+case object UnknownImportance extends Importance {
+ override def toString = "?"
+}
+
+//noinspection RedundantNewCaseClass
+case class CpuImportance(a: Importance = UnknownImportance,
+ x: Importance = UnknownImportance,
+ y: Importance = UnknownImportance,
+ n: Importance = UnknownImportance,
+ z: Importance = UnknownImportance,
+ v: Importance = UnknownImportance,
+ c: Importance = UnknownImportance,
+ d: Importance = UnknownImportance,
+ ) {
+ override def toString: String = s"A=$a,X=$x,Y=$y,Z=$z,N=$n,C=$c,V=$v,D=$d"
+
+ def ~(that: CpuImportance) = new CpuImportance(
+ a = this.a ~ that.a,
+ x = this.x ~ that.x,
+ y = this.y ~ that.y,
+ z = this.z ~ that.z,
+ n = this.n ~ that.n,
+ c = this.c ~ that.c,
+ v = this.v ~ that.v,
+ d = this.d ~ that.d,
+ )
+
+ def isUnimportant(state: State.Value): Boolean = state match {
+ case State.A => a == Unimportant
+ case State.X => x == Unimportant
+ case State.Y => y == Unimportant
+ case State.Z => z == Unimportant
+ case State.N => n == Unimportant
+ case State.C => c == Unimportant
+ case State.V => v == Unimportant
+ case State.D => d == Unimportant
+ }
+}
+
+object ReverseFlowAnalyzer {
+ //noinspection RedundantNewCaseClass
+ def analyze(f: NormalFunction, code: List[AssemblyLine]): List[CpuImportance] = {
+ val importanceArray = Array.fill[CpuImportance](code.length)(new CpuImportance())
+ val codeArray = code.toArray
+ val initialStatus = new CpuStatus(d = SingleStatus(false))
+
+ var changed = true
+ val finalImportance = new CpuImportance(a = Important, x = Important, y = Important, c = Important, v = Important, d = Important, z = Important, n = Important)
+ changed = true
+ while (changed) {
+ changed = false
+ var currentImportance: CpuImportance = finalImportance
+ for (i <- codeArray.indices.reverse) {
+ import millfork.assembly.Opcode._
+ import millfork.assembly.AddrMode._
+ if (importanceArray(i) != currentImportance) {
+ changed = true
+ importanceArray(i) = currentImportance
+ }
+ codeArray(i) match {
+ case AssemblyLine(opcode, Relative, MemoryAddressConstant(Label(l)), _) if OpcodeClasses.ShortBranching(opcode) =>
+ val L = l
+ val labelIndex = codeArray.indexWhere {
+ case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(L)), _) => true
+ case _ => false
+ }
+ currentImportance = if (labelIndex < 0) finalImportance else importanceArray(labelIndex) ~ currentImportance
+ case _ =>
+ }
+ codeArray(i) match {
+ // TODO: JSR?
+ case AssemblyLine(JMP, Absolute, MemoryAddressConstant(Label(l)), _) =>
+ val L = l
+ val labelIndex = codeArray.indexWhere {
+ case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(L)), _) => true
+ case _ => false
+ }
+ currentImportance = if (labelIndex < 0) finalImportance else importanceArray(labelIndex)
+ case AssemblyLine(JMP, Indirect, _, _) =>
+ currentImportance = finalImportance
+ case AssemblyLine(BNE | BEQ, _, _, _) =>
+ currentImportance = currentImportance.copy(z = Important)
+ case AssemblyLine(BMI | BPL, _, _, _) =>
+ currentImportance = currentImportance.copy(n = Important)
+ case AssemblyLine(SED | CLD, _, _, _) =>
+ currentImportance = currentImportance.copy(d = Unimportant)
+ case AssemblyLine(RTS, _, _, _) =>
+ currentImportance = finalImportance
+ case AssemblyLine(DISCARD_XF, _, _, _) =>
+ currentImportance = currentImportance.copy(x = Unimportant, n = Unimportant, z = Unimportant, c = Unimportant, v = Unimportant)
+ case AssemblyLine(DISCARD_YF, _, _, _) =>
+ currentImportance = currentImportance.copy(y = Unimportant, n = Unimportant, z = Unimportant, c = Unimportant, v = Unimportant)
+ case AssemblyLine(DISCARD_AF, _, _, _) =>
+ currentImportance = currentImportance.copy(a = Unimportant, n = Unimportant, z = Unimportant, c = Unimportant, v = Unimportant)
+ case AssemblyLine(opcode, addrMode, _, _) =>
+ if (OpcodeClasses.ChangesC(opcode)) currentImportance = currentImportance.copy(c = Unimportant)
+ if (OpcodeClasses.ChangesV(opcode)) currentImportance = currentImportance.copy(v = Unimportant)
+ if (OpcodeClasses.ChangesNAndZ(opcode)) currentImportance = currentImportance.copy(n = Unimportant, z = Unimportant)
+ if (OpcodeClasses.OverwritesA(opcode)) currentImportance = currentImportance.copy(a = Unimportant)
+ if (OpcodeClasses.OverwritesX(opcode)) currentImportance = currentImportance.copy(x = Unimportant)
+ if (OpcodeClasses.OverwritesY(opcode)) currentImportance = currentImportance.copy(y = Unimportant)
+ if (OpcodeClasses.ReadsC(opcode)) currentImportance = currentImportance.copy(c = Important)
+ if (OpcodeClasses.ReadsD(opcode)) currentImportance = currentImportance.copy(d = Important)
+ if (OpcodeClasses.ReadsV(opcode)) currentImportance = currentImportance.copy(v = Important)
+ if (OpcodeClasses.ReadsXAlways(opcode)) currentImportance = currentImportance.copy(x = Important)
+ if (OpcodeClasses.ReadsYAlways(opcode)) currentImportance = currentImportance.copy(y = Important)
+ if (OpcodeClasses.ReadsAAlways(opcode)) currentImportance = currentImportance.copy(a = Important)
+ if (OpcodeClasses.ReadsAIfImplied(opcode) && addrMode == Implied) currentImportance = currentImportance.copy(a = Important)
+ if (addrMode == AbsoluteX || addrMode == IndexedX || addrMode == ZeroPageX) currentImportance = currentImportance.copy(x = Important)
+ if (addrMode == AbsoluteY || addrMode == IndexedY || addrMode == ZeroPageY) currentImportance = currentImportance.copy(y = Important)
+ }
+ }
+ }
+// importanceArray.zip(codeArray).foreach{
+// case (i, y) => if (y.isPrintable) println(f"$y%-32s $i%-32s")
+// }
+// println("---------------------")
+
+ importanceArray.toList
+ }
+}
diff --git a/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala b/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala
new file mode 100644
index 00000000..48cb84e0
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala
@@ -0,0 +1,757 @@
+package millfork.assembly.opt
+
+import millfork.{CompilationFlag, CompilationOptions}
+import millfork.assembly._
+import millfork.env._
+import millfork.error.ErrorReporting
+
+import scala.collection.mutable
+
+/**
+ * @author Karol Stasiak
+ */
+
+object FlowInfoRequirement extends Enumeration {
+
+ val NoRequirement, BothFlows, ForwardFlow, BackwardFlow = Value
+
+ def assertForward(x: FlowInfoRequirement.Value): Unit = x match {
+ case BothFlows | ForwardFlow => ()
+ case NoRequirement | BackwardFlow => ErrorReporting.fatal("Forward flow info required")
+ }
+
+ def assertBackward(x: FlowInfoRequirement.Value): Unit = x match {
+ case BothFlows | BackwardFlow => ()
+ case NoRequirement | ForwardFlow => ErrorReporting.fatal("Backward flow info required")
+ }
+}
+
+class RuleBasedAssemblyOptimization(val name: String, val needsFlowInfo: FlowInfoRequirement.Value, val rules: AssemblyRule*) extends AssemblyOptimization {
+
+ rules.foreach(_.pattern.validate(needsFlowInfo))
+
+ override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
+ val effectiveCode = code.map(a => a.copy(parameter = a.parameter.quickSimplify))
+ val taggedCode = needsFlowInfo match {
+ case FlowInfoRequirement.NoRequirement => effectiveCode.map(FlowInfo.Default -> _)
+ case FlowInfoRequirement.BothFlows => FlowAnalyzer.analyze(f, effectiveCode, options)
+ case FlowInfoRequirement.ForwardFlow =>
+ if (options.flag(CompilationFlag.DetailedFlowAnalysis)) {
+ QuantumFlowAnalyzer.analyze(f, code).map(s => FlowInfo(s.collapse, CpuImportance())).zip(code)
+ } else {
+ CoarseFlowAnalyzer.analyze(f, code).map(s => FlowInfo(s, CpuImportance())).zip(code)
+ }
+ case FlowInfoRequirement.BackwardFlow =>
+ ReverseFlowAnalyzer.analyze(f, code).map(i => FlowInfo(CpuStatus(), i)).zip(code)
+ }
+ optimizeImpl(f, taggedCode, options)
+ }
+
+ def optimizeImpl(f: NormalFunction, code: List[(FlowInfo, AssemblyLine)], options: CompilationOptions): List[AssemblyLine] = {
+ code match {
+ case Nil => Nil
+ case head :: tail =>
+ for ((rule, index) <- rules.zipWithIndex) {
+ val ctx = new AssemblyMatchingContext
+ rule.pattern.matchTo(ctx, code) match {
+ case Some(rest: List[(FlowInfo, AssemblyLine)]) =>
+ val matchedChunkToOptimize: List[AssemblyLine] = code.take(code.length - rest.length).map(_._2)
+ val optimizedChunk: List[AssemblyLine] = rule.result(matchedChunkToOptimize, ctx)
+ ErrorReporting.debug(s"Applied $name ($index)")
+ if (needsFlowInfo != FlowInfoRequirement.NoRequirement) {
+ val before = code.head._1.statusBefore
+ val after = code(matchedChunkToOptimize.length - 1)._1.importanceAfter
+ ErrorReporting.trace(s"Before: $before")
+ ErrorReporting.trace(s"After: $after")
+ }
+ matchedChunkToOptimize.filter(_.isPrintable).foreach(l => ErrorReporting.trace(l.toString))
+ ErrorReporting.trace(" ↓")
+ optimizedChunk.filter(_.isPrintable).foreach(l => ErrorReporting.trace(l.toString))
+ if (needsFlowInfo != FlowInfoRequirement.NoRequirement) {
+ return optimizedChunk ++ optimizeImpl(f, rest, options)
+ } else {
+ return optimize(f, optimizedChunk ++ rest.map(_._2), options)
+ }
+ case None => ()
+ }
+ }
+ head._2 :: optimizeImpl(f, tail, options)
+ }
+ }
+}
+
+class AssemblyMatchingContext {
+ private val map = mutable.Map[Int, Any]()
+
+ def addObject(i: Int, o: Any): Boolean = {
+ if (map.contains(i)) {
+ map(i) == o
+ } else {
+ map(i) = o
+ true
+ }
+ }
+
+ def dontMatch(i: Int, o: Any): Boolean = {
+ if (map.contains(i)) {
+ map(i) != o
+ } else {
+ false
+ }
+ }
+
+ def get[T: Manifest](i: Int): T = {
+ val t = map(i)
+ val clazz = implicitly[Manifest[T]].runtimeClass match {
+ case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+ case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+ // TODO
+ case x => x
+ }
+ if (clazz.isInstance(t)) {
+ t.asInstanceOf[T]
+ } else {
+ if (i eq null) {
+ ErrorReporting.fatal(s"Value at index $i is null")
+ } else {
+ ErrorReporting.fatal(s"Value at index $i is a ${t.getClass.getSimpleName}, not a ${clazz.getSimpleName}")
+ }
+ }
+ }
+
+ def isExternallyLinearBlock(i: Int): Boolean = {
+ val labels = mutable.Set[String]()
+ val jumps = mutable.Set[String]()
+ get[List[AssemblyLine]](i).foreach {
+ case AssemblyLine(Opcode.RTS | Opcode.RTI | Opcode.BRK, _, _, _) =>
+ return false
+ case AssemblyLine(Opcode.JMP, AddrMode.Indirect, _, _) =>
+ return false
+ case AssemblyLine(Opcode.LABEL, _, MemoryAddressConstant(Label(l)), _) =>
+ labels += l
+ case AssemblyLine(Opcode.JMP, AddrMode.Absolute, MemoryAddressConstant(Label(l)), _) =>
+ jumps += l
+ case AssemblyLine(Opcode.JMP, AddrMode.Absolute, _, _) =>
+ return false
+ case AssemblyLine(_, AddrMode.Relative, MemoryAddressConstant(Label(l)), _) =>
+ jumps += l
+ case AssemblyLine(br, _, _, _) if OpcodeClasses.ShortBranching(br) =>
+ return false
+ case _ => ()
+ }
+ // if a jump leads inside the block, then it's internal
+ // if a jump leads outside the block, then it's external
+ jumps --= labels
+ jumps.isEmpty
+ }
+
+ def areMemoryReferencesProvablyNonOverlapping(param1: Int, addrMode1: Int, param2: Int, addrMode2: Int): Boolean = {
+ val p1 = get[Constant](param1).quickSimplify
+ val a1 = get[AddrMode.Value](addrMode1)
+ val p2 = get[Constant](param2).quickSimplify
+ val a2 = get[AddrMode.Value](addrMode2)
+ import AddrMode._
+ val badAddrModes = Set(IndexedX, IndexedY, ZeroPageIndirect, AbsoluteIndexedX)
+ if (badAddrModes(a1) || badAddrModes(a2)) return false
+
+ def handleKnownDistance(distance: Short): Boolean = {
+ val indexingAddrModes = Set(AbsoluteIndexedX, AbsoluteX, ZeroPageX, AbsoluteY, ZeroPageY)
+ val a1Indexing = indexingAddrModes(a1)
+ val a2Indexing = indexingAddrModes(a2)
+ (a1Indexing, a2Indexing) match {
+ case (false, false) => distance != 0
+ case (true, false) => distance > 255 || distance < 0
+ case (false, true) => distance > 0 || distance < -255
+ case (true, true) => distance > 255 || distance < -255
+ }
+ }
+
+ (p1, p2) match {
+ case (NumericConstant(n1, _), NumericConstant(n2, _)) =>
+ handleKnownDistance((n2 - n1).toShort)
+ case (a, CompoundConstant(MathOperator.Plus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify =>
+ handleKnownDistance(distance.toShort)
+ case (CompoundConstant(MathOperator.Plus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify =>
+ handleKnownDistance((-distance).toShort)
+ case (a, CompoundConstant(MathOperator.Minus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify =>
+ handleKnownDistance((-distance).toShort)
+ case (CompoundConstant(MathOperator.Minus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify =>
+ handleKnownDistance(distance.toShort)
+ case (MemoryAddressConstant(MemoryVariable(a, _, _)), MemoryAddressConstant(MemoryVariable(b, _, _))) =>
+ a.takeWhile(_ != '.') != a.takeWhile(_ != '.') // TODO: ???
+ case _ =>
+ false
+ }
+ }
+}
+
+case class AssemblyRule(pattern: AssemblyPattern, result: (List[AssemblyLine], AssemblyMatchingContext) => List[AssemblyLine]) {
+
+}
+
+trait AssemblyPattern {
+
+ def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = ()
+
+ def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]]
+
+ def ~(x: AssemblyPattern) = Concatenation(this, x)
+
+ def ~(x: AssemblyLinePattern) = Concatenation(this, x)
+
+ def ~~>(result: (List[AssemblyLine], AssemblyMatchingContext) => List[AssemblyLine]) = AssemblyRule(this, result)
+
+ def ~~>(result: List[AssemblyLine] => List[AssemblyLine]) = AssemblyRule(this, (code, _) => result(code))
+
+ def capture(i: Int) = Capture(i, this)
+
+ def captureLength(i: Int) = CaptureLength(i, this)
+}
+
+case class Capture(i: Int, pattern: AssemblyPattern) extends AssemblyPattern {
+ override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] =
+ for {
+ rest <- pattern.matchTo(ctx, code)
+ } yield {
+ ctx.addObject(i, code.take(code.length - rest.length).map(_._2))
+ rest
+ }
+
+ override def toString: String = s"(?<$i>$pattern)"
+}
+
+case class CaptureLength(i: Int, pattern: AssemblyPattern) extends AssemblyPattern {
+ override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] =
+ for {
+ rest <- pattern.matchTo(ctx, code)
+ } yield {
+ ctx.addObject(i, code.length - rest.length)
+ rest
+ }
+
+ override def toString: String = s"(?<$i>$pattern)"
+}
+
+
+case class Where(predicate: (AssemblyMatchingContext => Boolean)) extends AssemblyPattern {
+ def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
+ if (predicate(ctx)) Some(code) else None
+ }
+
+ override def toString: String = "Where(...)"
+}
+
+case class Concatenation(l: AssemblyPattern, r: AssemblyPattern) extends AssemblyPattern {
+
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
+ l.validate(needsFlowInfo)
+ r.validate(needsFlowInfo)
+ }
+
+ def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
+ for {
+ middle <- l.matchTo(ctx, code)
+ end <- r.matchTo(ctx, middle)
+ } yield end
+ }
+
+ override def toString: String = (l, r) match {
+ case (_: Both, _: Both) => s"($l) · ($r)"
+ case (_, _: Both) => s"$l · ($r)"
+ case (_: Both, _) => s"($l) · $r"
+ case _ => s"$l · $r"
+ }
+}
+
+case class Many(rule: AssemblyLinePattern) extends AssemblyPattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
+ rule.validate(needsFlowInfo)
+ }
+
+ def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
+ var c = code
+ while (true) {
+ c match {
+ case Nil =>
+ return Some(Nil)
+ case x :: xs =>
+ if (rule.matchLineTo(ctx, x._1, x._2)) {
+ c = xs
+ } else {
+ return Some(c)
+ }
+ }
+ }
+ None
+ }
+
+ override def toString: String = s"[$rule]*"
+}
+
+case class ManyWhereAtLeastOne(rule: AssemblyLinePattern, atLeastOneIsThis: AssemblyLinePattern) extends AssemblyPattern {
+
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
+ rule.validate(needsFlowInfo)
+ }
+
+ def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
+ var c = code
+ var oneFound = false
+ while (true) {
+ c match {
+ case Nil =>
+ return Some(Nil)
+ case x :: xs =>
+ if (atLeastOneIsThis.matchLineTo(ctx, x._1, x._2)) {
+ oneFound = true
+ }
+ if (rule.matchLineTo(ctx, x._1, x._2)) {
+ c = xs
+ } else {
+ if (oneFound) {
+ return Some(c)
+ } else {
+ return None
+ }
+ }
+ }
+ }
+ None
+ }
+
+ override def toString: String = s"[∃$atLeastOneIsThis:$rule]*"
+}
+
+case class Opt(rule: AssemblyLinePattern) extends AssemblyPattern {
+
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
+ rule.validate(needsFlowInfo)
+ }
+
+ def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = {
+ code match {
+ case Nil =>
+ Some(Nil)
+ case x :: xs =>
+ if (rule.matchLineTo(ctx, x._1, x._2)) {
+ Some(xs)
+ } else {
+ Some(code)
+ }
+ }
+ }
+
+ override def toString: String = s"[$rule]?"
+}
+
+trait AssemblyLinePattern extends AssemblyPattern {
+ def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = code match {
+ case Nil => None
+ case x :: xs => if (matchLineTo(ctx, x._1, x._2)) Some(xs) else None
+ }
+
+ def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean
+
+ def unary_! : AssemblyLinePattern = Not(this)
+
+ def ? : AssemblyPattern = Opt(this)
+
+ def * : AssemblyPattern = Many(this)
+
+ def + : AssemblyPattern = this ~ Many(this)
+
+ def |(x: AssemblyLinePattern): AssemblyLinePattern = Either(this, x)
+
+ def &(x: AssemblyLinePattern): AssemblyLinePattern = Both(this, x)
+
+ protected def memoryAccessDoesntOverlap(a1: AddrMode.Value, p1: Constant, a2: AddrMode.Value, p2: Constant): Boolean = {
+ import AddrMode._
+ val badAddrModes = Set(IndexedX, IndexedY, ZeroPageIndirect, AbsoluteIndexedX)
+ if (badAddrModes(a1) || badAddrModes(a2)) return false
+ val goodAddrModes = Set(Implied, Immediate, Relative)
+ if (goodAddrModes(a1) || goodAddrModes(a2)) return true
+
+ def handleKnownDistance(distance: Short): Boolean = {
+ val indexingAddrModes = Set(AbsoluteIndexedX, AbsoluteX, ZeroPageX, AbsoluteY, ZeroPageY)
+ val a1Indexing = indexingAddrModes(a1)
+ val a2Indexing = indexingAddrModes(a2)
+ (a1Indexing, a2Indexing) match {
+ case (false, false) => distance != 0
+ case (true, false) => distance > 255 || distance < 0
+ case (false, true) => distance > 0 || distance < -255
+ case (true, true) => distance > 255 || distance < -255
+ }
+ }
+
+ (p1.quickSimplify, p2.quickSimplify) match {
+ case (NumericConstant(n1, _), NumericConstant(n2, _)) =>
+ handleKnownDistance((n2 - n1).toShort)
+ case (a, CompoundConstant(MathOperator.Plus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify =>
+ handleKnownDistance(distance.toShort)
+ case (CompoundConstant(MathOperator.Plus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify =>
+ handleKnownDistance((-distance).toShort)
+ case (a, CompoundConstant(MathOperator.Minus, b, NumericConstant(distance, _))) if a.quickSimplify == b.quickSimplify =>
+ handleKnownDistance((-distance).toShort)
+ case (CompoundConstant(MathOperator.Minus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify =>
+ handleKnownDistance(distance.toShort)
+ case (MemoryAddressConstant(a: ThingInMemory), MemoryAddressConstant(b:ThingInMemory)) =>
+ a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ???
+ case (CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(a: ThingInMemory), NumericConstant(_, _)),
+ MemoryAddressConstant(b: ThingInMemory)) =>
+ a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ???
+ case (MemoryAddressConstant(a: ThingInMemory),
+ CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) =>
+ a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ???
+ case (CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(a: ThingInMemory), NumericConstant(_, _)),
+ CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) =>
+ a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ???
+ case _ =>
+ false
+ }
+ }
+}
+
+//noinspection LanguageFeature
+object AssemblyLinePattern {
+ implicit def __implicitOpcodeIn(ops: Set[Opcode.Value]): AssemblyLinePattern = HasOpcodeIn(ops)
+}
+
+case class MatchA(i: Int) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
+ FlowInfoRequirement.assertForward(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ flowInfo.statusBefore.a match {
+ case SingleStatus(value) => ctx.addObject(i, value)
+ case _ => false
+ }
+}
+
+case class MatchX(i: Int) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
+ FlowInfoRequirement.assertForward(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ flowInfo.statusBefore.x match {
+ case SingleStatus(value) => ctx.addObject(i, value)
+ case _ => false
+ }
+}
+
+case class MatchY(i: Int) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
+ FlowInfoRequirement.assertForward(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ flowInfo.statusBefore.y match {
+ case SingleStatus(value) => ctx.addObject(i, value)
+ case _ => false
+ }
+}
+
+case class HasA(value: Int) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
+ FlowInfoRequirement.assertForward(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ flowInfo.statusBefore.a.contains(value)
+}
+
+case class HasX(value: Int) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
+ FlowInfoRequirement.assertForward(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ flowInfo.statusBefore.x.contains(value)
+}
+
+case class HasY(value: Int) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
+ FlowInfoRequirement.assertForward(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ flowInfo.statusBefore.y.contains(value)
+}
+
+case class DoesntMatterWhatItDoesWith(states: State.Value*) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
+ FlowInfoRequirement.assertBackward(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ states.forall(state => flowInfo.importanceAfter.isUnimportant(state))
+
+ override def toString: String = states.mkString("[¯\\_(ツ)_/¯:", ",", "]")
+}
+
+case class HasSet(state: State.Value) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
+ FlowInfoRequirement.assertForward(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ flowInfo.hasSet(state)
+}
+
+case class HasClear(state: State.Value) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit =
+ FlowInfoRequirement.assertForward(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ flowInfo.hasClear(state)
+}
+
+case object Anything extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ true
+}
+
+case class Not(inner: AssemblyLinePattern) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = inner.validate(needsFlowInfo)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ !inner.matchLineTo(ctx, flowInfo, line)
+
+ override def toString: String = "¬" + inner
+}
+
+case class Both(l: AssemblyLinePattern, r: AssemblyLinePattern) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
+ l.validate(needsFlowInfo)
+ r.validate(needsFlowInfo)
+ }
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ l.matchLineTo(ctx, flowInfo, line) && r.matchLineTo(ctx, flowInfo, line)
+
+ override def toString: String = l + " ∧ " + r
+}
+
+case class Either(l: AssemblyLinePattern, r: AssemblyLinePattern) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
+ l.validate(needsFlowInfo)
+ r.validate(needsFlowInfo)
+ }
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ l.matchLineTo(ctx, flowInfo, line) || r.matchLineTo(ctx, flowInfo, line)
+
+ override def toString: String = s"($l ∨ $r)"
+}
+
+case object Elidable extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ line.elidable
+}
+
+case object Linear extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.AllLinear(line.opcode)
+}
+
+case object LinearOrBranch extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.AllLinear(line.opcode) || OpcodeClasses.ShortBranching(line.opcode)
+}
+
+case object LinearOrLabel extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ line.opcode == Opcode.LABEL || OpcodeClasses.AllLinear(line.opcode)
+}
+
+case object ReadsA extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.ReadsAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ReadsAIfImplied(line.opcode)
+}
+
+case object ReadsMemory extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ line.addrMode match {
+ case AddrMode.Indirect => true
+ case AddrMode.Implied | AddrMode.Immediate => false
+ case _ =>
+ OpcodeClasses.ReadsMemoryIfNotImpliedOrImmediate(line.opcode)
+ }
+}
+
+case object ReadsX extends AssemblyLinePattern {
+ val XAddrModes = Set(AddrMode.AbsoluteX, AddrMode.IndexedX, AddrMode.ZeroPageX, AddrMode.AbsoluteIndexedX)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.ReadsXAlways(line.opcode) || XAddrModes(line.addrMode)
+}
+
+case object ReadsY extends AssemblyLinePattern {
+ val YAddrModes = Set(AddrMode.AbsoluteY, AddrMode.IndexedY, AddrMode.ZeroPageY)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.ReadsYAlways(line.opcode) || YAddrModes(line.addrMode)
+}
+
+case object ConcernsC extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.ReadsC(line.opcode) && OpcodeClasses.ChangesC(line.opcode)
+}
+
+case object ConcernsA extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.ConcernsAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ConcernsAIfImplied(line.opcode)
+}
+
+case object ConcernsX extends AssemblyLinePattern {
+ val XAddrModes = Set(AddrMode.AbsoluteX, AddrMode.IndexedX, AddrMode.ZeroPageX)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.ConcernsXAlways(line.opcode) || XAddrModes(line.addrMode)
+}
+
+case object ConcernsY extends AssemblyLinePattern {
+ val YAddrModes = Set(AddrMode.AbsoluteY, AddrMode.IndexedY, AddrMode.ZeroPageY)
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.ConcernsYAlways(line.opcode) || YAddrModes(line.addrMode)
+}
+
+case object ChangesA extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.ChangesAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ChangesAIfImplied(line.opcode)
+}
+
+case object ChangesMemory extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ OpcodeClasses.ChangesMemoryAlways(line.opcode) || line.addrMode != AddrMode.Implied && OpcodeClasses.ChangesMemoryIfNotImplied(line.opcode)
+}
+
+case class DoesntChangeMemoryAt(addrMode1: Int, param1: Int) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = {
+ val p1 = ctx.get[Constant](param1)
+ val p2 = line.parameter
+ val a1 = ctx.get[AddrMode.Value](addrMode1)
+ val a2 = line.addrMode
+ val changesSomeMemory = OpcodeClasses.ChangesMemoryAlways(line.opcode) || line.addrMode != AddrMode.Implied && OpcodeClasses.ChangesMemoryIfNotImplied(line.opcode)
+ !changesSomeMemory || memoryAccessDoesntOverlap(a1, p1, a2, p2)
+ }
+}
+
+case object ConcernsMemory extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ ReadsMemory.matchLineTo(ctx, flowInfo, line) && ChangesMemory.matchLineTo(ctx, flowInfo, line)
+}
+
+case class DoesNotConcernMemoryAt(addrMode1: Int, param1: Int) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = {
+ val p1 = ctx.get[Constant](param1)
+ val p2 = line.parameter
+ val a1 = ctx.get[AddrMode.Value](addrMode1)
+ val a2 = line.addrMode
+ memoryAccessDoesntOverlap(a1, p1, a2, p2)
+ }
+}
+
+case class HasOpcode(op: Opcode.Value) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ line.opcode == op
+
+ override def toString: String = op.toString
+}
+
+case class HasOpcodeIn(ops: Set[Opcode.Value]) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ ops(line.opcode)
+
+ override def toString: String = ops.mkString("{", ",", "}")
+}
+
+case class HasAddrMode(am: AddrMode.Value) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ line.addrMode == am
+
+ override def toString: String = am.toString
+}
+
+case class HasAddrModeIn(ams: Set[AddrMode.Value]) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ ams(line.addrMode)
+
+ override def toString: String = ams.mkString("{", ",", "}")
+}
+
+case class HasImmediate(i: Int) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ line.addrMode == AddrMode.Immediate && (line.parameter.quickSimplify match {
+ case NumericConstant(j, _) => (i & 0xff) == (j & 0xff)
+ case _ => false
+ })
+
+ override def toString: String = "#" + i
+}
+
+case class MatchObject(i: Int, f: Function[AssemblyLine, Any]) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ ctx.addObject(i, f(line))
+
+ override def toString: String = s"(?<$i>...)"
+}
+
+case class MatchParameter(i: Int) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ ctx.addObject(i, line.parameter.quickSimplify)
+
+ override def toString: String = s"(?<$i>Param)"
+}
+
+case class DontMatchParameter(i: Int) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ ctx.dontMatch(i, line.parameter.quickSimplify)
+
+ override def toString: String = s"¬(?<$i>Param)"
+}
+
+case class MatchAddrMode(i: Int) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ ctx.addObject(i, line.addrMode)
+
+ override def toString: String = s"¬(?<$i>AddrMode)"
+}
+
+case class MatchOpcode(i: Int) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ ctx.addObject(i, line.opcode)
+
+ override def toString: String = s"¬(?<$i>Op)"
+}
+
+case class MatchImmediate(i: Int) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ if (line.addrMode == AddrMode.Immediate) {
+ ctx.addObject(i, line.parameter.quickSimplify)
+ } else false
+
+ override def toString: String = s"(?<$i>#)"
+}
+
+
+case class DoesntChangeIndexingInAddrMode(i: Int) extends AssemblyLinePattern {
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean =
+ ctx.get[AddrMode.Value](i) match {
+ case AddrMode.ZeroPageX | AddrMode.AbsoluteX | AddrMode.IndexedX | AddrMode.AbsoluteIndexedX => !OpcodeClasses.ChangesX.contains(line.opcode)
+ case AddrMode.ZeroPageY | AddrMode.AbsoluteY | AddrMode.IndexedY => !OpcodeClasses.ChangesY.contains(line.opcode)
+ case _ => true
+ }
+
+ override def toString: String = s"¬(?<$i>AddrMode)"
+}
+
+case class Before(pattern: AssemblyPattern) extends AssemblyLinePattern {
+ override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = {
+ pattern.validate(needsFlowInfo)
+ }
+
+ override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): Option[List[(FlowInfo, AssemblyLine)]] = code match {
+ case Nil => None
+ case x :: xs => pattern.matchTo(ctx, xs) match {
+ case Some(m) => Some(xs)
+ case None => None
+ }
+ }
+
+ override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = ???
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/assembly/opt/SizeOptimizations.scala b/src/main/scala/millfork/assembly/opt/SizeOptimizations.scala
new file mode 100644
index 00000000..d76eefb5
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/SizeOptimizations.scala
@@ -0,0 +1,8 @@
+package millfork.assembly.opt
+
+/**
+ * @author Karol Stasiak
+ */
+object SizeOptimizations {
+
+}
diff --git a/src/main/scala/millfork/assembly/opt/SuperOptimizer.scala b/src/main/scala/millfork/assembly/opt/SuperOptimizer.scala
new file mode 100644
index 00000000..6213d13d
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/SuperOptimizer.scala
@@ -0,0 +1,75 @@
+package millfork.assembly.opt
+
+import millfork.{CompilationFlag, CompilationOptions, OptimizationPresets}
+import millfork.assembly.{AddrMode, AssemblyLine, Opcode}
+import millfork.env.NormalFunction
+import millfork.error.ErrorReporting
+
+import scala.collection.mutable
+
+/**
+ * @author Karol Stasiak
+ */
+object SuperOptimizer extends AssemblyOptimization {
+
+ def optimize(m: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
+ val oldVerbosity = ErrorReporting.verbosity
+ ErrorReporting.verbosity = -1
+ var allOptimizers = OptimizationPresets.Good ++ LaterOptimizations.All
+ if (options.flag(CompilationFlag.EmitIllegals)) {
+ allOptimizers ++= UndocumentedOptimizations.All
+ }
+ if (options.flag(CompilationFlag.EmitCmosOpcodes)) {
+ allOptimizers ++= CmosOptimizations.All
+ }
+ allOptimizers ++= List(
+ VariableToRegisterOptimization,
+ ChangeIndexRegisterOptimizationPreferringX2Y,
+ ChangeIndexRegisterOptimizationPreferringY2X,
+ UnusedLabelRemoval)
+ val seenSoFar = mutable.Set[CodeView]()
+ val queue = mutable.Queue[(List[AssemblyOptimization], List[AssemblyLine])]()
+ val leaves = mutable.ListBuffer[(List[AssemblyOptimization], List[AssemblyLine])]()
+ seenSoFar += viewCode(code)
+ queue.enqueue(Nil -> code)
+ while(queue.nonEmpty) {
+ val (optsSoFar, codeSoFar) = queue.dequeue()
+ var isLeaf = true
+ allOptimizers.par.foreach { o =>
+ val optimized = o.optimize(m, codeSoFar, options)
+ val view = viewCode(optimized)
+ seenSoFar.synchronized{
+ if (!seenSoFar(view)) {
+ isLeaf = false
+ seenSoFar += view
+ queue.enqueue((o :: optsSoFar) -> optimized)
+ }
+ }
+ }
+ if (isLeaf) {
+// println(codeSoFar.map(_.sizeInBytes).sum + " B: " + optsSoFar.reverse.map(_.name).mkString(" -> "))
+ leaves += optsSoFar -> codeSoFar
+ }
+ }
+
+ val result = leaves.minBy(_._2.map(_.cost).sum)
+ ErrorReporting.verbosity = oldVerbosity
+ ErrorReporting.debug(s"Visited ${leaves.size} leaves")
+ ErrorReporting.debug(s"${code.map(_.sizeInBytes).sum} B -> ${result._2.map(_.sizeInBytes).sum} B: ${result._1.reverse.map(_.name).mkString(" -> ")}")
+ result._1.reverse.foldLeft(code){(c, opt) =>
+ val n = opt.optimize(m, c, options)
+// println(c.mkString("","",""))
+// println(n.mkString("","",""))
+ n
+ }
+ result._2
+ }
+
+ override val name = "Superoptimizer"
+
+ def viewCode(code: List[AssemblyLine]): CodeView = {
+ CodeView(code.map(l => l.opcode -> l.addrMode))
+ }
+}
+
+case class CodeView(content: List[(Opcode.Value, AddrMode.Value)])
diff --git a/src/main/scala/millfork/assembly/opt/UndocumentedOptimizations.scala b/src/main/scala/millfork/assembly/opt/UndocumentedOptimizations.scala
new file mode 100644
index 00000000..456256c6
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/UndocumentedOptimizations.scala
@@ -0,0 +1,340 @@
+package millfork.assembly.opt
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import millfork.assembly.{AddrMode, AssemblyLine, Opcode, State}
+import millfork.assembly.Opcode._
+import millfork.assembly.AddrMode._
+import millfork.assembly.OpcodeClasses._
+import millfork.env.{Constant, NormalFunction, NumericConstant}
+
+/**
+ * @author Karol Stasiak
+ */
+object UndocumentedOptimizations {
+
+ val counter = new AtomicInteger(30000)
+
+ def getNextLabel(prefix: String) = f".${prefix}%s__${counter.getAndIncrement()}%05d"
+
+ // TODO: test these
+
+ private val LaxAddrModeRestriction = Not(HasAddrModeIn(Set(AbsoluteX, ZeroPageX, IndexedX, Immediate)))
+
+ //noinspection ScalaUnnecessaryParentheses
+ val UseLax = new RuleBasedAssemblyOptimization("Using undocumented instruction LAX",
+ needsFlowInfo = FlowInfoRequirement.BackwardFlow,
+ (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~
+ (LinearOrLabel & Not(ConcernsA) & Not(ChangesMemory) & Not(HasOpcode(LDX))).*.capture(2) ~
+ (HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
+ ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX)
+ },
+ (HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~
+ (LinearOrLabel & Not(ConcernsX) & Not(ChangesMemory) & Not(HasOpcode(LDA))).*.capture(2) ~
+ (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
+ ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX)
+ },
+
+ (HasOpcode(LDA) & Elidable & LaxAddrModeRestriction) ~
+ (LinearOrLabel & Not(ConcernsA) & Not(ChangesMemory) & Not(HasOpcode(TAX))).*.capture(2) ~
+ (HasOpcode(TAX) & Elidable) ~~> { (code, ctx) =>
+ ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX)
+ },
+ (HasOpcode(LDX) & Elidable & LaxAddrModeRestriction) ~
+ (LinearOrLabel & Not(ConcernsX) & Not(ChangesMemory) & Not(HasOpcode(TXA))).*.capture(2) ~
+ (HasOpcode(TXA) & Elidable) ~~> { (code, ctx) =>
+ ctx.get[List[AssemblyLine]](2) :+ code.head.copy(opcode = LAX)
+ },
+
+ (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~
+ (LinearOrLabel & Not(ConcernsX) & Not(ChangesA) & Not(ChangesMemory) & Not(HasOpcode(LDX))).*.capture(2) ~
+ (HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) =>
+ code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2)
+ },
+ (HasOpcode(LDX) & Elidable & MatchAddrMode(0) & MatchParameter(1) & LaxAddrModeRestriction) ~
+ (LinearOrLabel & Not(ConcernsA) & Not(ChangesX) & Not(ChangesMemory) & Not(HasOpcode(LDA))).*.capture(2) ~
+ (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) =>
+ code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2)
+ },
+
+ (HasOpcode(LDA) & Elidable & LaxAddrModeRestriction) ~
+ (LinearOrLabel & Not(ConcernsX) & Not(ChangesA) & Not(ChangesMemory) & Not(HasOpcode(TAX))).*.capture(2) ~
+ (HasOpcode(TAX) & Elidable & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) =>
+ code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2)
+ },
+ (HasOpcode(LDX) & Elidable & LaxAddrModeRestriction) ~
+ (LinearOrLabel & Not(ConcernsA) & Not(ChangesX) & Not(ChangesMemory) & Not(HasOpcode(TXA))).*.capture(2) ~
+ (HasOpcode(TXA) & Elidable & DoesntMatterWhatItDoesWith(State.N, State.Z)) ~~> { (code, ctx) =>
+ code.head.copy(opcode = LAX) :: ctx.get[List[AssemblyLine]](2)
+ },
+ )
+
+ val SaxModes: Set[AddrMode.Value] = Set(ZeroPage, IndexedX, ZeroPageY, Absolute)
+
+ val UseSax = new RuleBasedAssemblyOptimization("Using undocumented instruction SAX",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & Not(ConcernsA) & Not(ConcernsX)).?.capture(10) ~
+ (HasOpcode(AND) & Elidable & MatchAddrMode(2) & MatchParameter(3) & Not(ReadsX)) ~
+ (Linear & Not(ConcernsA) & Not(ConcernsX)).?.capture(11) ~
+ (HasOpcode(STA) & Elidable & MatchAddrMode(4) & MatchParameter(5) & HasAddrModeIn(SaxModes) & DontMatchParameter(0)) ~
+ (Linear & Not(ConcernsA) & Not(ConcernsX) & Not(ChangesMemory)).?.capture(12) ~
+ (HasOpcode(LDA) & Elidable & MatchAddrMode(0) & MatchParameter(1)) ~
+ (LinearOrLabel & Not(ConcernsX)).*.capture(13) ~ OverwritesX ~~> { (code, ctx) =>
+ val lda = code.head
+ val ldx = AssemblyLine(LDX, ctx.get[AddrMode.Value](2), ctx.get[Constant](3))
+ val sax = AssemblyLine(SAX, ctx.get[AddrMode.Value](4), ctx.get[Constant](5))
+ val fragment0 = lda :: ctx.get[List[AssemblyLine]](10)
+ val fragment1 = ldx :: ctx.get[List[AssemblyLine]](11)
+ val fragment2 = sax :: ctx.get[List[AssemblyLine]](12)
+ val fragment3 = ctx.get[List[AssemblyLine]](13)
+ List(fragment0, fragment1, fragment2, fragment3).flatten
+ },
+ )
+
+ def andConstant(const: Constant, mask: Int): Option[Long] = const match {
+ case NumericConstant(n, _) => Some(n & mask)
+ case _ => None
+ }
+
+ val UseAnc = new RuleBasedAssemblyOptimization("Using undocumented instruction ANC",
+ needsFlowInfo = FlowInfoRequirement.BothFlows,
+ (Elidable & HasOpcode(LDA) & HasImmediate(0)) ~
+ (Elidable & HasOpcode(CLC)) ~~> (_ => List(AssemblyLine.immediate(ANC, 0))),
+ (Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.C)) ~~> (_ => List(AssemblyLine.immediate(ANC, 0))),
+ (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
+ Where(c => andConstant(c.get[Constant](0), 0x80).contains(0)) ~
+ (Elidable & HasOpcode(CLC)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
+ (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
+ Where(c => andConstant(c.get[Constant](0), 0x80).contains(0x80)) ~
+ (Elidable & HasOpcode(SEC)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
+ (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
+ (Elidable & HasOpcode(CMP) & HasImmediate(0x80) & DoesntMatterWhatItDoesWith(State.Z, State.N)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
+ (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
+ (Elidable & HasOpcode(CMP) & HasImmediate(0x80) & DoesntMatterWhatItDoesWith(State.Z, State.N)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
+ (Elidable & HasOpcode(AND) & MatchImmediate(0) & HasClear(State.C)) ~
+ Where(c => andConstant(c.get[Constant](0), 0x80).contains(0)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
+ (Elidable & HasOpcode(AND) & MatchImmediate(0) & HasSet(State.C)) ~
+ Where(c => andConstant(c.get[Constant](0), 0x80).contains(0)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
+ (Elidable & HasOpcode(AND) & MatchImmediate(0)) ~
+ (Elidable & HasOpcodeIn(Set(ROL, ASL)) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.Z, State.N, State.A)) ~~> ((_, ctx) => List(AssemblyLine.immediate(ANC, ctx.get[Int](0)))),
+ )
+
+ val UseSbx = new RuleBasedAssemblyOptimization("Using undocumented instruction SBX",
+ needsFlowInfo = FlowInfoRequirement.BothFlows,
+ (Elidable & HasOpcode(DEX) & DoesntMatterWhatItDoesWith(State.A, State.C)).+.captureLength(0) ~
+ Where(_.get[Int](0) > 2) ~~> ((_, ctx) => List(
+ AssemblyLine.implied(TXA),
+ AssemblyLine.immediate(SBX, ctx.get[Int](0)),
+ )),
+ (Elidable & HasOpcode(INX) & DoesntMatterWhatItDoesWith(State.A, State.C)).+.captureLength(0) ~
+ Where(_.get[Int](0) > 2) ~~> ((_, ctx) => List(
+ AssemblyLine.implied(TXA),
+ AssemblyLine.immediate(SBX, 256 - ctx.get[Int](0)),
+ )),
+ HasOpcode(TXA) ~
+ (Elidable & HasOpcode(CLC)).? ~
+ (Elidable & HasClear(State.C) & HasClear(State.D) & HasOpcode(ADC) & MatchImmediate(0)) ~
+ (Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.C, State.A)) ~~> ((code, ctx) => List(
+ code.head,
+ AssemblyLine.immediate(SBX, 256 - ctx.get[Int](0)),
+ )),
+ HasOpcode(TXA) ~
+ (Elidable & HasOpcode(SEC)).? ~
+ (Elidable & HasSet(State.C) & HasClear(State.D) & HasOpcode(SBC) & MatchImmediate(0)) ~
+ (Elidable & HasOpcode(TAX) & DoesntMatterWhatItDoesWith(State.C, State.A)) ~~> ((code, ctx) => List(
+ code.head,
+ AssemblyLine.immediate(SBX, ctx.get[Int](0)),
+ )),
+ )
+
+
+ val UseAlr = new RuleBasedAssemblyOptimization("Using undocumented instruction ALR",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ (Elidable & HasOpcode(AND) & HasAddrMode(Immediate)) ~
+ (Elidable & HasOpcode(LSR) & HasAddrMode(Implied)) ~~> { (code, ctx) =>
+ List(AssemblyLine.immediate(ALR, code.head.parameter))
+ },
+ (Elidable & HasOpcode(LSR) & HasAddrMode(Implied)) ~
+ (Elidable & HasOpcode(CLC)) ~~> { (code, ctx) =>
+ List(AssemblyLine.immediate(ALR, 0xFE))
+ },
+ )
+
+ val UseArr = new RuleBasedAssemblyOptimization("Using undocumented instruction ARR",
+ needsFlowInfo = FlowInfoRequirement.BothFlows,
+ (HasClear(State.D) & Elidable & HasOpcode(AND) & HasAddrMode(Immediate)) ~
+ (Elidable & HasOpcode(ROR) & HasAddrMode(Implied) & DoesntMatterWhatItDoesWith(State.C, State.V)) ~~> { (code, ctx) =>
+ List(AssemblyLine.immediate(ARR, code.head.parameter))
+ },
+ )
+
+ private def trivialSequence1(o1: Opcode.Value, o2: Opcode.Value, extra: AssemblyLinePattern, combined: Opcode.Value) =
+ (Elidable & HasOpcode(o1) & HasAddrMode(Absolute) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & DoesNotConcernMemoryAt(0, 1) & extra).* ~
+ (Elidable & HasOpcode(o2) & HasAddrMode(Absolute) & MatchParameter(1)) ~~> { (code, ctx) =>
+ code.tail.init :+ AssemblyLine(combined, Absolute, ctx.get[Constant](1))
+ }
+
+ private def trivialSequence2(o1: Opcode.Value, o2: Opcode.Value, extra: AssemblyLinePattern, combined: Opcode.Value) =
+ (Elidable & HasOpcode(o1) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & DoesNotConcernMemoryAt(0, 1) & extra).* ~
+ (Elidable & HasOpcode(o2) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
+ code.tail.init :+ AssemblyLine(combined, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))
+ }
+
+ // ROL c LDA c AND d => LDA d RLA c
+ private def trivialCommutativeSequence(o1: Opcode.Value, o2: Opcode.Value, combined: Opcode.Value) = {
+ (Elidable & HasOpcode(o1) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Elidable & HasOpcode(LDA) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Elidable & HasOpcode(o2) & MatchAddrMode(2) & MatchParameter(3)) ~~> { code =>
+ List(code(2).copy(opcode = LDA), code(1).copy(opcode = combined))
+ }
+ }
+
+ val UseSlo = new RuleBasedAssemblyOptimization("Using undocumented instruction SLO",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ trivialSequence1(ASL, ORA, Not(ConcernsC), SLO),
+ trivialSequence2(ASL, ORA, Not(ConcernsC), SLO),
+ trivialCommutativeSequence(ASL, ORA, SLO),
+ (Elidable & HasOpcode(ASL) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & Not(ConcernsMemory)).* ~
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
+ code.tail.init ++ List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SLO, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)))
+ },
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & Not(ConcernsMemory) & Not(ChangesA)).*.capture(2) ~
+ (Elidable & HasOpcode(ASL) & HasAddrMode(Implied)) ~
+ (Linear & Not(ConcernsMemory) & Not(ChangesA) & Not(ReadsC) & Not(ReadsNOrZ)).*.capture(3) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
+ List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SRE, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))) ++
+ ctx.get[List[AssemblyLine]](2) ++
+ ctx.get[List[AssemblyLine]](3)
+ },
+ )
+
+ val UseSre = new RuleBasedAssemblyOptimization("Using undocumented instruction SRE",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ trivialSequence1(LSR, EOR, Not(ConcernsC), SRE),
+ trivialSequence2(LSR, EOR, Not(ConcernsC), SRE),
+ trivialCommutativeSequence(LSR, EOR, SRE),
+ (Elidable & HasOpcode(LSR) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & Not(ConcernsMemory)).* ~
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
+ code.tail.init ++ List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SRE, ctx.get[AddrMode.Value](0), ctx.get[Constant](1)))
+ },
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Linear & Not(ConcernsMemory) & Not(ChangesA)).*.capture(2) ~
+ (Elidable & HasOpcode(LSR) & HasAddrMode(Implied)) ~
+ (Linear & Not(ConcernsMemory) & Not(ChangesA) & Not(ReadsC) & Not(ReadsNOrZ)).*.capture(3) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { (code, ctx) =>
+ List(AssemblyLine.immediate(LDA, 0), AssemblyLine(SRE, ctx.get[AddrMode.Value](0), ctx.get[Constant](1))) ++
+ ctx.get[List[AssemblyLine]](2) ++
+ ctx.get[List[AssemblyLine]](3)
+ },
+ )
+
+ val UseRla = new RuleBasedAssemblyOptimization("Using undocumented instruction RLA",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ trivialSequence1(ROL, AND, Not(ConcernsC), RLA),
+ trivialSequence2(ROL, AND, Not(ConcernsC), RLA),
+ trivialCommutativeSequence(ROL, AND, RLA),
+ )
+
+ val UseRra = new RuleBasedAssemblyOptimization("Using undocumented instruction RRA",
+ needsFlowInfo = FlowInfoRequirement.NoRequirement,
+ // TODO: is it ok? carry flag and stuff?
+ trivialSequence1(ROR, ADC, Not(ConcernsC), RRA),
+ trivialSequence2(ROR, ADC, Not(ConcernsC), RRA),
+ trivialCommutativeSequence(ROR, ADC, RRA),
+ )
+
+ val UseDcp = new RuleBasedAssemblyOptimization("Using undocumented instruction DCP",
+ needsFlowInfo = FlowInfoRequirement.BothFlows,
+ trivialSequence1(DEC, CMP, Not(ConcernsC), DCP),
+ trivialSequence2(DEC, CMP, Not(ConcernsC), DCP),
+ (Elidable & HasOpcode(LDA) & HasAddrModeIn(Set(IndexedX, ZeroPageX, AbsoluteX))) ~
+ (Elidable & HasOpcode(TAX)) ~
+ (Elidable & HasOpcode(DEC) & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.A, State.Y, State.X, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(code.head.copy(opcode = LDY), code.last.copy(opcode = DCP, addrMode = AbsoluteY))
+ },
+ (Elidable & HasOpcode(DEC) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Elidable & HasOpcode(LDA) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Elidable & HasOpcode(CMP) & MatchAddrMode(2) & MatchParameter(3) & DoesntMatterWhatItDoesWith(State.V, State.C, State.N, State.A)) ~~> { code =>
+ List(code(2).copy(opcode = LDA), code(1).copy(opcode = DCP))
+ }
+ )
+
+ val UseIsc = new RuleBasedAssemblyOptimization("Using undocumented instruction ISC",
+ needsFlowInfo = FlowInfoRequirement.BothFlows,
+ trivialSequence1(INC, SBC, Not(ReadsC), ISC),
+ trivialSequence2(INC, SBC, Not(ReadsC), ISC),
+ (Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.D)) ~
+ (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ val label = getNextLabel("is")
+ List(
+ AssemblyLine.relative(BCC, label),
+ code.last.copy(opcode = ISC),
+ AssemblyLine.label(label))
+ },
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
+ (Elidable & HasOpcode(ADC) & HasImmediate(0) & HasClear(State.D)) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ val label = getNextLabel("is")
+ List(
+ AssemblyLine.relative(BCC, label),
+ code.last.copy(opcode = ISC),
+ AssemblyLine.label(label))
+ },
+ (Elidable & HasOpcode(CLC)).? ~
+ (Elidable & HasOpcode(LDA) & HasImmediate(1) & HasClear(State.D) & HasClear(State.C)) ~
+ (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(code.last.copy(opcode = ISC))
+ },
+ (Elidable & HasOpcode(CLC)).? ~
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & HasClear(State.D) & HasClear(State.C) & MatchAddrMode(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
+ (Elidable & HasOpcode(ADC) & HasImmediate(1)) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchAddrMode(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(code.last.copy(opcode = ISC))
+ },
+ (Elidable & HasOpcode(SEC)).? ~
+ (Elidable & HasOpcode(LDA) & HasImmediate(0) & HasClear(State.D) & HasSet(State.C)) ~
+ (Elidable & HasOpcode(ADC) & MatchAddrMode(1) & MatchParameter(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchParameter(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(code.last.copy(opcode = ISC))
+ },
+ (Elidable & HasOpcode(SEC)).? ~
+ (Elidable & HasOpcode(LDA) & MatchAddrMode(1) & HasClear(State.D) & HasSet(State.C) & MatchAddrMode(2) & HasAddrModeIn(Set(IndexedX, IndexedY, AbsoluteY))) ~
+ (Elidable & HasOpcode(ADC) & HasImmediate(0)) ~
+ (Elidable & HasOpcode(STA) & MatchAddrMode(1) & MatchAddrMode(2) & DoesntMatterWhatItDoesWith(State.A, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(code.last.copy(opcode = ISC))
+ },
+ (Elidable & HasOpcode(LDA) & HasAddrModeIn(Set(IndexedX, ZeroPageX, AbsoluteX))) ~
+ (Elidable & HasOpcode(TAX)) ~
+ (Elidable & HasOpcode(INC) & HasAddrMode(AbsoluteX) & DoesntMatterWhatItDoesWith(State.A, State.Y, State.X, State.C, State.Z, State.N, State.V)) ~~> { code =>
+ List(code.head.copy(opcode = LDY), code.last.copy(opcode = ISC, addrMode = AbsoluteY))
+ },
+ (Elidable & HasOpcode(INC) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Elidable & HasOpcode(LDA) & Not(HasAddrMode(Immediate)) & MatchAddrMode(0) & MatchParameter(1)) ~
+ (Elidable & HasOpcode(CMP) & HasClear(State.D) & MatchAddrMode(2) & MatchParameter(3) & DoesntMatterWhatItDoesWith(State.V, State.C, State.N, State.A)) ~~> { code =>
+ List(code(2).copy(opcode = LDA), AssemblyLine.implied(SEC), code(1).copy(opcode = ISC))
+ }
+ )
+
+ val All: List[AssemblyOptimization] = List(
+ UseLax,
+ UseSax,
+ UseSbx,
+ UseAnc,
+ UseSlo,
+ UseSre,
+ UseAlr,
+ UseArr,
+ UseRla,
+ UseRra,
+ UseIsc,
+ UseDcp,
+ )
+}
diff --git a/src/main/scala/millfork/assembly/opt/UnusedLabelRemoval.scala b/src/main/scala/millfork/assembly/opt/UnusedLabelRemoval.scala
new file mode 100644
index 00000000..dd919b48
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/UnusedLabelRemoval.scala
@@ -0,0 +1,38 @@
+package millfork.assembly.opt
+
+import millfork.CompilationOptions
+import millfork.assembly.AddrMode._
+import millfork.assembly.Opcode._
+import millfork.assembly.{AddrMode, AssemblyLine}
+import millfork.env._
+import millfork.error.ErrorReporting
+
+/**
+ * @author Karol Stasiak
+ */
+object UnusedLabelRemoval extends AssemblyOptimization {
+
+ override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
+ val usedLabels = code.flatMap {
+ case AssemblyLine(LABEL, _, _, _) => None
+ case AssemblyLine(_, _, MemoryAddressConstant(Label(l)), _) => Some(l)
+ case _ => None
+ }.toSet
+ val definedLabels = code.flatMap {
+ case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) => Some(l).filter(_.startsWith("."))
+ case _ => None
+ }.toSet
+ val toRemove = definedLabels -- usedLabels
+ if (toRemove.nonEmpty) {
+ ErrorReporting.debug("Removing labels: " + toRemove.mkString(", "))
+ code.filterNot {
+ case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) => toRemove(l)
+ case _ => false
+ }
+ } else {
+ code
+ }
+ }
+
+ override def name = "Unused label removal"
+}
diff --git a/src/main/scala/millfork/assembly/opt/VariableToRegisterOptimization.scala b/src/main/scala/millfork/assembly/opt/VariableToRegisterOptimization.scala
new file mode 100644
index 00000000..e6bc2b12
--- /dev/null
+++ b/src/main/scala/millfork/assembly/opt/VariableToRegisterOptimization.scala
@@ -0,0 +1,322 @@
+package millfork.assembly.opt
+
+import millfork.CompilationOptions
+import millfork.assembly.{AddrMode, AssemblyLine}
+import millfork.assembly.Opcode._
+import millfork.assembly.AddrMode._
+import millfork.env._
+import millfork.error.ErrorReporting
+
+import scala.annotation.tailrec
+
+/**
+ * @author Karol Stasiak
+ */
+object VariableToRegisterOptimization extends AssemblyOptimization {
+
+ // If any of these opcodes is present within a method,
+ // then it's too hard to assign any variable to a register.
+ private val opcodesThatAlwaysPrecludeXAllocation = Set(JSR, STX, TXA, PHX, PLX, INX, DEX, CPX, SBX, SAX)
+
+ private val opcodesThatAlwaysPrecludeYAllocation = Set(JSR, STY, TYA, PHY, PLY, INY, DEY, CPY)
+
+ // If any of these opcodes is used on a variable
+ // then it's too hard to assign that variable to a register.
+ // Also, LDY prevents assigning a variable to X and LDX prevents assigning a variable to Y.
+ private val opcodesThatCannotBeUsedWithIndexRegistersAsParameters =
+ Set(EOR, ORA, AND, BIT, ADC, SBC, CMP, CPX, CPY, STY, STX)
+
+ override def name = "Allocating variables to index registers"
+
+
+ override def optimize(f: NormalFunction, code: List[AssemblyLine], options: CompilationOptions): List[AssemblyLine] = {
+ val paramVariables = f.params match {
+ case NormalParamSignature(ps) =>
+ ps.map(_.name).toSet
+ case _ =>
+ // assembly functions do not get this optimization
+ return code
+ }
+ val stillUsedVariables = code.flatMap {
+ case AssemblyLine(_, _, MemoryAddressConstant(th), _) => Some(th.name)
+ case _ => None
+ }.toSet
+ val localVariables = f.environment.getAllLocalVariables.filter {
+ case MemoryVariable(name, typ, VariableAllocationMethod.Auto) =>
+ typ.size == 1 && !paramVariables(name) && stillUsedVariables(name)
+ case _ => false
+ }
+
+ val candidates = None :: localVariables.map(v => Option(v.name))
+
+ val variants = for {
+ vx <- candidates.par
+ vy <- candidates
+ if vx != vy
+ (score, prologueLength) <- canBeInlined(vx, vy, code.tail, Some(1))
+ if prologueLength >= 1
+ } yield (score, prologueLength, vx, vy)
+
+ if (variants.isEmpty) {
+ return code
+ }
+
+ val (_, bestPrologueLength, bestX, bestY) = variants.max
+
+ if ((bestX.isDefined || bestY.isDefined) && bestPrologueLength != 0xffff) {
+ (bestX, bestY) match {
+ case (Some(x), Some(y)) => ErrorReporting.debug(s"Inlining $x to X and $y to Y")
+ case (Some(x), None) => ErrorReporting.debug(s"Inlining $x to X")
+ case (None, Some(y)) => ErrorReporting.debug(s"Inlining $y to Y")
+ case _ =>
+ }
+ bestX.foreach(f.environment.removeVariable)
+ bestY.foreach(f.environment.removeVariable)
+ code.take(bestPrologueLength) ++ inlineVars(bestX, bestY, code.drop(bestPrologueLength))
+ } else {
+ code
+ }
+ }
+
+
+ private def add(i: Int) = (p: (Int, Int)) => (p._1 + i) -> p._2
+
+ private def mark(i: Option[Int]) = (p: (Int, Int)) => p._1 -> i.getOrElse(p._2)
+
+ def canBeInlined(xCandidate: Option[String], yCandidate: Option[String], lines: List[AssemblyLine], instrCounter: Option[Int]): Option[(Int, Int)] = {
+ val vx = xCandidate.getOrElse("-")
+ val vy = yCandidate.getOrElse("-")
+ val next = instrCounter.map(_ + 1)
+ val next2 = instrCounter.map(_ + 2)
+ lines match {
+ case AssemblyLine(_, Immediate, SubbyteConstant(MemoryAddressConstant(th), _), _) :: xs
+ if th.name == vx || th.name == vy =>
+ // if an address of a variable is used, then that variable cannot be assigned to a register
+ None
+ case AssemblyLine(_, Immediate, HalfWordConstant(MemoryAddressConstant(th), _), _) :: xs
+ if th.name == vx || th.name == vy =>
+ // if an address of a variable is used, then that variable cannot be assigned to a register
+ None
+
+ case AssemblyLine(_, AbsoluteX | AbsoluteY | ZeroPageX | ZeroPageY, MemoryAddressConstant(th), _) :: xs =>
+ // if a variable is used as an array, then it cannot be assigned to a register
+ if (th.name == vx || th.name == vy) {
+ None
+ } else {
+ canBeInlined(xCandidate, yCandidate, xs, next)
+ }
+
+ case AssemblyLine(opcode, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vx && (opcode == LDY || opcodesThatCannotBeUsedWithIndexRegistersAsParameters(opcode)) =>
+ // if a variable is used by some opcodes, then it cannot be assigned to a register
+ None
+
+ case AssemblyLine(opcode, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vy && (opcode == LDX || opcode == LAX || opcodesThatCannotBeUsedWithIndexRegistersAsParameters(opcode)) =>
+ // if a variable is used by some opcodes, then it cannot be assigned to a register
+ None
+
+ case AssemblyLine(LDX, Absolute, MemoryAddressConstant(th), elidable) :: xs
+ if xCandidate.isDefined =>
+ // if a register is populated with a different variable, then this variable cannot be assigned to that register
+ // removing LDX saves 3 cycles
+ if (elidable && th.name == vx) {
+ canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter))
+ } else {
+ None
+ }
+
+ case AssemblyLine(LAX, Absolute, MemoryAddressConstant(th), elidable) :: xs
+ if xCandidate.isDefined =>
+ // LAX = LDX-LDA, and since LDX simplifies to nothing and LDA simplifies to TXA,
+ // LAX simplifies to TXA, saving two bytes
+ if (elidable && th.name == vx) {
+ canBeInlined(xCandidate, yCandidate, xs, None).map(add(2)).map(mark(instrCounter))
+ } else {
+ None
+ }
+
+ case AssemblyLine(LDY, Absolute, MemoryAddressConstant(th), elidable) :: xs if yCandidate.isDefined =>
+ // if a register is populated with a different variable, then this variable cannot be assigned to that register
+ // removing LDX saves 3 cycles
+ if (elidable && th.name == vy) {
+ canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter))
+ } else {
+ None
+ }
+
+ case AssemblyLine(LDX, _, _, _) :: xs if xCandidate.isDefined =>
+ // if a register is populated with something else than a variable, then no variable cannot be assigned to that register
+ None
+
+ case AssemblyLine(LDY, _, _, _) :: xs if yCandidate.isDefined =>
+ // if a register is populated with something else than a variable, then no variable cannot be assigned to that register
+ None
+
+ case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), elidable) :: AssemblyLine(TAX, _, _, elidable2) :: xs
+ if xCandidate.isDefined =>
+ // a variable cannot be inlined if there is TAX not after LDA of that variable
+ // but LDA-TAX can be simplified to TXA
+ if (elidable && elidable2 && th.name == vx) {
+ canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter))
+ } else {
+ None
+ }
+
+ case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), elidable) :: AssemblyLine(TAY, _, _, elidable2) :: xs
+ if yCandidate.isDefined =>
+ // a variable cannot be inlined if there is TAY not after LDA of that variable
+ // but LDA-TAY can be simplified to TYA
+ if (elidable && elidable2 && th.name == vy) {
+ canBeInlined(xCandidate, yCandidate, xs, None).map(add(3)).map(mark(instrCounter))
+ } else {
+ None
+ }
+
+ case AssemblyLine(LDA | STA | INC | DEC, Absolute, MemoryAddressConstant(th), elidable) :: xs =>
+ // changing LDA->TXA, STA->TAX, INC->INX, DEC->DEX saves 2 cycles
+ if (th.name == vy || th.name == vx) {
+ if (elidable) canBeInlined(xCandidate, yCandidate, xs, None).map(add(2)).map(mark(instrCounter))
+ else None
+ } else {
+ canBeInlined(xCandidate, yCandidate, xs, next)
+ }
+
+ case AssemblyLine(TAX, _, _, _) :: xs if xCandidate.isDefined =>
+ // a variable cannot be inlined if there is TAX not after LDA of that variable
+ if (instrCounter.isDefined) {
+ canBeInlined(xCandidate, yCandidate, xs, next)
+ } else None
+
+ case AssemblyLine(TAY, _, _, _) :: xs if yCandidate.isDefined =>
+ // a variable cannot be inlined if there is TAY not after LDA of that variable
+ if (instrCounter.isDefined) {
+ canBeInlined(xCandidate, yCandidate, xs, next)
+ } else None
+
+ case AssemblyLine(LABEL, _, _, _) :: xs =>
+ // labels always end the initial section
+ canBeInlined(xCandidate, yCandidate, xs, None).map(mark(instrCounter))
+
+ case x :: xs =>
+ if (instrCounter.isDefined) {
+ canBeInlined(xCandidate, yCandidate, xs, next)
+ } else {
+ if (xCandidate.isDefined && opcodesThatAlwaysPrecludeXAllocation(x.opcode)) {
+ None
+ } else if (yCandidate.isDefined && opcodesThatAlwaysPrecludeYAllocation(x.opcode)) {
+ None
+ } else {
+ canBeInlined(xCandidate, yCandidate, xs, next)
+ }
+ }
+
+ case Nil => Some(0 -> -1)
+ }
+ }
+
+ def inlineVars(xCandidate: Option[String], yCandidate: Option[String], lines: List[AssemblyLine]): List[AssemblyLine] = {
+ val vx = xCandidate.getOrElse("-")
+ val vy = yCandidate.getOrElse("-")
+ lines match {
+ case AssemblyLine(INC, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vx =>
+ AssemblyLine.implied(INX) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(INC, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vy =>
+ AssemblyLine.implied(INY) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(DEC, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vx =>
+ AssemblyLine.implied(DEX) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(DEC, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vy =>
+ AssemblyLine.implied(DEY) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDX, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vx =>
+ inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LAX, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vx =>
+ AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDY, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vy =>
+ inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), true) :: AssemblyLine(TAX, _, _, true) :: xs
+ if th.name == vx =>
+ // these TXA's may get optimized away by a different optimization
+ AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), true) :: AssemblyLine(TAY, _, _, true) :: xs
+ if th.name == vy =>
+ // these TYA's may get optimized away by a different optimization
+ AssemblyLine.implied(TYA) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDA, am, param, true) :: AssemblyLine(STA, Absolute, MemoryAddressConstant(th), true) :: xs
+ if th.name == vx && doesntUseX(am) =>
+ // these TXA's may get optimized away by a different optimization
+ AssemblyLine(LDX, am, param) :: AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDA, am, param, true) :: AssemblyLine(STA, Absolute, MemoryAddressConstant(th), true) :: xs
+ if th.name == vy && doesntUseY(am) =>
+ // these TYA's may get optimized away by a different optimization
+ AssemblyLine(LDY, am, param) :: AssemblyLine.implied(TYA) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: AssemblyLine(CMP, am, param, true) :: xs
+ if th.name == vx && doesntUseXOrY(am) =>
+ // ditto
+ AssemblyLine.implied(TXA) :: AssemblyLine(CPX, am, param) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: AssemblyLine(CMP, am, param, true) :: xs
+ if th.name == vy && doesntUseXOrY(am) =>
+ // ditto
+ AssemblyLine.implied(TYA) :: AssemblyLine(CPY, am, param) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vx =>
+ AssemblyLine.implied(TXA) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(LDA, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vy =>
+ AssemblyLine.implied(TYA) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(STA, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vx =>
+ AssemblyLine.implied(TAX) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(STA, Absolute, MemoryAddressConstant(th), _) :: xs
+ if th.name == vy =>
+ AssemblyLine.implied(TAY) :: inlineVars(xCandidate, yCandidate, xs)
+
+ case AssemblyLine(TAX, _, _, _) :: xs if xCandidate.isDefined =>
+ ErrorReporting.fatal("Unexpected TAX")
+
+ case AssemblyLine(TAY, _, _, _) :: xs if yCandidate.isDefined =>
+ ErrorReporting.fatal("Unexpected TAY")
+
+ case x :: xs => x :: inlineVars(xCandidate, yCandidate, xs)
+
+ case Nil => Nil
+ }
+ }
+
+ def doesntUseY(am: AddrMode.Value): Boolean = am match {
+ case AbsoluteY | ZeroPageY | IndexedY => false
+ case _ => true
+ }
+
+ def doesntUseX(am: AddrMode.Value): Boolean = am match {
+ case AbsoluteX | ZeroPageX | IndexedX => false
+ case _ => true
+ }
+
+ def doesntUseXOrY(am: AddrMode.Value): Boolean = am match {
+ case Immediate | ZeroPage | Absolute | Relative | Indirect => true
+ case _ => false
+ }
+}
diff --git a/src/main/scala/millfork/cli/CliOption.scala b/src/main/scala/millfork/cli/CliOption.scala
new file mode 100644
index 00000000..91925f55
--- /dev/null
+++ b/src/main/scala/millfork/cli/CliOption.scala
@@ -0,0 +1,201 @@
+package millfork.cli
+
+/**
+ * @author Karol Stasiak
+ */
+trait CliOption[T, O <: CliOption[T, O]] {
+ this: O =>
+ def toStrings(firstTab: Int): List[String] = {
+ val fl = firstLine
+ if (_description == "") {
+ List(fl)
+ } else if (fl.length < firstTab) {
+ List(fl.padTo(firstTab, ' ') + _description)
+ } else {
+ List(fl, "".padTo(firstTab, ' ') + _description)
+ }
+ }
+
+ protected def firstLine: String = names.mkString(" | ")
+
+ def names: Seq[String]
+
+ private[cli] def length: Int
+
+ private[cli] val _shortName: String
+ private[cli] var _description: String = ""
+ private[cli] var _hidden = false
+ private[cli] var _maxEncounters = 1
+ private[cli] var _minEncounters = 0
+ private[cli] var _actualEncounters = 0
+ private[cli] var _onTooFew: Option[Int => Unit] = None
+ private[cli] var _onTooMany: Option[Int => Unit] = None
+
+ def validate(): Boolean = {
+ var ok = true
+ if (_actualEncounters < _minEncounters) {
+ _onTooFew.fold(throw new IllegalArgumentException(s"Too few ${_shortName} options: required ${_minEncounters}, given ${_actualEncounters}"))(_ (_actualEncounters))
+ ok = false
+ }
+ if (_actualEncounters > _maxEncounters) {
+ _onTooMany.fold()(_ (_actualEncounters))
+ ok = false
+ }
+ ok
+ }
+
+ def onWrongNumber(action: Int => Unit): Unit = {
+ _onTooFew = Some(action)
+ _onTooMany = Some(action)
+ }
+
+ def onTooFew(action: Int => Unit): Unit = {
+ _onTooFew = Some(action)
+ }
+
+ def onTooMany(action: Int => Unit): Unit = {
+ _onTooMany = Some(action)
+ }
+
+ def encounter(): Unit = {
+ _actualEncounters += 1
+ }
+
+ def description(d: String): O = {
+ _description = d
+ this
+ }
+
+ def hidden(): O = {
+ _hidden = true
+ this
+ }
+
+ def minCount(count: Int): O = {
+ _minEncounters = count
+ this
+ }
+
+ def maxCount(count: Int): O = {
+ _maxEncounters = count
+ this
+ }
+
+ def required(): O = minCount(1)
+
+ def repeatable(): O = maxCount(Int.MaxValue)
+}
+
+class Fluff[T](val text: Seq[String]) extends CliOption[T, Fluff[T]] {
+ this.repeatable()
+
+ override def toStrings(firstTab: Int): List[String] = text.toList
+
+ override def length = 0
+
+ override val _shortName = ""
+
+ override def names = Nil
+}
+
+class NoMoreOptions[T](val names: Seq[String]) extends CliOption[T, NoMoreOptions[T]] {
+ this.repeatable()
+
+ override def length = 1
+
+ override val _shortName = names.head
+}
+
+class UnknownParamOption[T] extends CliOption[T, UnknownParamOption[T]] {
+ this._hidden = true
+
+ override def length = 0
+
+ val names: Seq[String] = Nil
+ private var _action: ((String, T) => T) = (_, x) => x
+
+ def action(a: ((String, T) => T)): UnknownParamOption[T] = {
+ _action = a
+ this
+ }
+
+ def encounter(value: String, t: T): T = {
+ encounter()
+ _action(value, t)
+ }
+
+ override private[cli] val _shortName = ""
+}
+
+class FlagOption[T](val names: Seq[String]) extends CliOption[T, FlagOption[T]] {
+ override def length = 1
+
+ private var _action: (T => T) = x => x
+
+ def action(a: (T => T)): FlagOption[T] = {
+ _action = a
+ this
+ }
+
+ def encounter(t: T): T = {
+ encounter()
+ _action(t)
+ }
+
+ override val _shortName = names.head
+}
+
+class BooleanOption[T](val trueName: String, val falseName: String) extends CliOption[T, BooleanOption[T]] {
+ override def length = 1
+
+ private var _action: ((T,Boolean) => T) = (x,_) => x
+
+ def action(a: ((T,Boolean) => T)): BooleanOption[T] = {
+ _action = a
+ this
+ }
+
+ def encounter(asName: String, t: T): T = {
+ encounter()
+ if (asName == trueName) {
+ return _action(t, true)
+ }
+ if (asName == falseName) {
+ return _action(t, false)
+ }
+ t
+ }
+
+ override val _shortName = names.head
+
+ override protected def firstLine: String = trueName + " | " + falseName
+
+ override def names = Seq(trueName, falseName)
+}
+
+class ParamOption[T](val names: Seq[String]) extends CliOption[T, ParamOption[T]] {
+
+ override protected def firstLine: String = names.mkString(" | ") + " " + _paramPlaceholder
+
+ override def length = 2
+
+ private var _action: ((String, T) => T) = (_, x) => x
+ private var _paramPlaceholder: String = ""
+
+ def placeholder(p: String): ParamOption[T] = {
+ _paramPlaceholder = p
+ this
+ }
+
+ def action(a: ((String, T) => T)): ParamOption[T] = {
+ _action = a
+ this
+ }
+
+ def encounter(value: String, t: T): T = {
+ encounter()
+ _action(value, t)
+ }
+
+ override val _shortName = names.head
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/cli/CliParser.scala b/src/main/scala/millfork/cli/CliParser.scala
new file mode 100644
index 00000000..25c7563f
--- /dev/null
+++ b/src/main/scala/millfork/cli/CliParser.scala
@@ -0,0 +1,81 @@
+package millfork.cli
+
+import fastparse.core.Parsed.Failure
+
+import scala.collection.mutable
+
+/**
+ * @author Karol Stasiak
+ */
+class CliParser[T] {
+
+ private val options = mutable.ArrayBuffer[CliOption[T, _]]()
+ private val mapFlags = mutable.Map[String, CliOption[T, _]]()
+ private val mapOptions = mutable.Map[String, CliOption[T, _]]()
+ private val _default = new UnknownParamOption[T]().action((p, _) => throw new IllegalArgumentException(s"Unknown option $p"))
+ private var _status: Option[CliStatus.Value] = None
+ options += _default
+
+ private def add[O <: CliOption[T, _]](o: O) = {
+ options += o
+ o.length match {
+ case 1 =>
+ o.names.foreach { n => mapFlags(n) = o }
+ case 2 =>
+ o.names.foreach { n => mapOptions(n) = o }
+ case _ => ()
+ }
+ o
+ }
+
+ def parse(context: T, args: List[String]): (CliStatus.Value, T) = {
+ val t = parseInner(context, args)
+ _status.getOrElse(if (options.forall(_.validate())) CliStatus.Ok else CliStatus.Failed) -> t
+ }
+
+ def assumeStatus(s: CliStatus.Value): Unit = {
+ _status = Some(s)
+ }
+
+ private def parseInner(context: T, args: List[String]): T = {
+ args match {
+ case k :: v :: xs if mapOptions.contains(k) =>
+ mapOptions(k) match {
+ case p: ParamOption[T] => parseInner(p.encounter(v, context), xs)
+ case _ => ???
+ }
+ case k :: xs if mapFlags.contains(k) =>
+ mapFlags(k) match {
+ case p: FlagOption[T] =>
+ parseInner(p.encounter(context), xs)
+ case p: BooleanOption[T] =>
+ parseInner(p.encounter(k, context), xs)
+ case p: NoMoreOptions[T] =>
+ p.encounter()
+ xs.foldLeft(context)((t, x) => _default.encounter(x, t))
+ case _ => ???
+ }
+ case x :: xs =>
+ parseInner(_default.encounter(x, context), xs)
+ case Nil => context
+ }
+ }
+
+
+ def fluff(text: String*): Unit = add(new Fluff[T](text))
+
+ def flag(names: String*): FlagOption[T] = add(new FlagOption[T](names))
+
+ def boolean(trueName: String, falseName: String): BooleanOption[T] = add(new BooleanOption[T](trueName, falseName))
+
+ def endOfFlags(names: String*): NoMoreOptions[T] = add(new NoMoreOptions[T](names))
+
+ def default: UnknownParamOption[T] = _default
+
+ def printHelp(firstTab: Int): List[String] = {
+ options.filterNot(_._hidden).toList.flatMap(_.toStrings(firstTab))
+ }
+
+ def parameter(names: String*): ParamOption[T] = add(new ParamOption[T](names))
+
+}
diff --git a/src/main/scala/millfork/cli/CliStatus.scala b/src/main/scala/millfork/cli/CliStatus.scala
new file mode 100644
index 00000000..800d5782
--- /dev/null
+++ b/src/main/scala/millfork/cli/CliStatus.scala
@@ -0,0 +1,8 @@
+package millfork.cli
+
+/**
+ * @author Karol Stasiak
+ */
+object CliStatus extends Enumeration {
+ val Ok, Failed, Quit = Value
+}
diff --git a/src/main/scala/millfork/compiler/BuiltIns.scala b/src/main/scala/millfork/compiler/BuiltIns.scala
new file mode 100644
index 00000000..452962d6
--- /dev/null
+++ b/src/main/scala/millfork/compiler/BuiltIns.scala
@@ -0,0 +1,832 @@
+package millfork.compiler
+
+import millfork.{CompilationFlag, CompilationOptions}
+import millfork.assembly._
+import millfork.env._
+import millfork.node._
+import millfork.assembly.Opcode._
+import millfork.assembly.AddrMode._
+import millfork.error.ErrorReporting
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+import scala.reflect.macros.blackbox
+
+
+object ComparisonType extends Enumeration {
+ val Equal, NotEqual,
+ LessUnsigned, LessSigned,
+ GreaterUnsigned, GreaterSigned,
+ LessOrEqualUnsigned, LessOrEqualSigned,
+ GreaterOrEqualUnsigned, GreaterOrEqualSigned = Value
+
+ def flip(x: ComparisonType.Value): ComparisonType.Value = x match {
+ case LessUnsigned => GreaterUnsigned
+ case GreaterUnsigned => LessUnsigned
+ case LessOrEqualUnsigned => GreaterOrEqualUnsigned
+ case GreaterOrEqualUnsigned => LessOrEqualUnsigned
+ case LessSigned => GreaterSigned
+ case GreaterSigned => LessSigned
+ case LessOrEqualSigned => GreaterOrEqualSigned
+ case GreaterOrEqualSigned => LessOrEqualSigned
+ case _ => x
+ }
+
+ def negate(x: ComparisonType.Value): ComparisonType.Value = x match {
+ case LessUnsigned => GreaterOrEqualUnsigned
+ case GreaterUnsigned => LessOrEqualUnsigned
+ case LessOrEqualUnsigned => GreaterUnsigned
+ case GreaterOrEqualUnsigned => LessUnsigned
+ case LessSigned => GreaterOrEqualSigned
+ case GreaterSigned => LessOrEqualSigned
+ case LessOrEqualSigned => GreaterSigned
+ case GreaterOrEqualSigned => LessSigned
+ case Equal => NotEqual
+ case NotEqual => Equal
+ }
+}
+
+/**
+ * @author Karol Stasiak
+ */
+object BuiltIns {
+
+ object IndexChoice extends Enumeration {
+ val RequireX, PreferX, PreferY = Value
+ }
+
+ def wrapInSedCldIfNeeded(decimal: Boolean, code: List[AssemblyLine]): List[AssemblyLine] = {
+ if (decimal) {
+ AssemblyLine.implied(SED) :: (code :+ AssemblyLine.implied(CLD))
+ } else {
+ code
+ }
+ }
+
+ def staTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == STA) x.copy(opcode = op) else x)
+
+ def ldTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == LDA || x.opcode == LDX || x.opcode == LDY) x.copy(opcode = op) else x)
+
+ def simpleOperation(opcode: Opcode.Value, ctx: CompilationContext, source: Expression, indexChoice: IndexChoice.Value, preserveA: Boolean, commutative: Boolean): List[AssemblyLine] = {
+ val env = ctx.env
+ val parts: (List[AssemblyLine], List[AssemblyLine]) = env.eval(source).fold {
+ val b = env.get[Type]("byte")
+ source match {
+ case VariableExpression(name) =>
+ val v = env.get[Variable](name)
+ if (v.typ.size > 1) {
+ ErrorReporting.error(s"Variable `$name` is too big for a built-in operation", source.position)
+ return Nil
+ }
+ Nil -> AssemblyLine.variable(ctx, opcode, v)
+ case IndexedExpression(arrayName, index) =>
+ indexChoice match {
+ case IndexChoice.RequireX | IndexChoice.PreferX =>
+ val array = env.getArrayOrPointer(arrayName)
+ val calculateIndex = MlCompiler.compile(ctx, index, Some(b -> RegisterVariable(Register.X, b)), NoBranching)
+ val baseAddress = array match {
+ case c: ConstantThing => c.value
+ case a: MlArray => a.toAddress
+ }
+ calculateIndex -> List(AssemblyLine.absoluteX(opcode, baseAddress))
+ case IndexChoice.PreferY =>
+ val array = env.getArrayOrPointer(arrayName)
+ val calculateIndex = MlCompiler.compile(ctx, index, Some(b -> RegisterVariable(Register.Y, b)), NoBranching)
+ val baseAddress = array match {
+ case c: ConstantThing => c.value
+ case a: MlArray => a.toAddress
+ }
+ calculateIndex -> List(AssemblyLine.absoluteY(opcode, baseAddress))
+ }
+ case f: FunctionCallExpression if commutative =>
+ // TODO: is it ok?
+ return List(AssemblyLine.implied(PHA)) ++ MlCompiler.compile(ctx.addStack(1), f, Some(b -> RegisterVariable(Register.A, b)), NoBranching) ++ List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(opcode, 0x101),
+ AssemblyLine.implied(INX),
+ AssemblyLine.implied(TXS))
+ case _ =>
+ ErrorReporting.error("Right-hand-side expression is too complex", source.position)
+ return Nil
+ }
+ } {
+ const =>
+ if (const.requiredSize > 1) {
+ ErrorReporting.error("Constant too big for a built-in operation", source.position)
+ }
+ Nil -> List(AssemblyLine.immediate(opcode, const))
+ }
+ val preparations = parts._1
+ val finalRead = parts._2
+ if (preserveA && AssemblyLine.treatment(preparations, State.A) != Treatment.Unchanged) {
+ AssemblyLine.implied(PHA) :: (preparations ++ (AssemblyLine.implied(PLA) :: finalRead))
+ } else {
+ preparations ++ finalRead
+ }
+ }
+
+ def insertBeforeLast(item: AssemblyLine, list: List[AssemblyLine]): List[AssemblyLine] = list match {
+ case Nil => Nil
+ case last :: dex :: txs :: Nil if dex.opcode == DEX && txs.opcode == TXS => item :: last :: dex :: txs :: Nil
+ case last :: inx :: txs :: Nil if inx.opcode == INX && txs.opcode == TXS => item :: last :: inx :: txs :: Nil
+ case last :: Nil => item :: last :: Nil
+ case first :: rest => first :: insertBeforeLast(item, rest)
+ }
+
+ def compileAddition(ctx: CompilationContext, params: List[(Boolean, Expression)], decimal: Boolean): List[AssemblyLine] = {
+ if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode)) {
+ ErrorReporting.warn("Unsupported decimal operation", ctx.options, params.head._2.position)
+ }
+ // if (params.isEmpty) {
+ // return Nil
+ // }
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val sortedParams = params.sortBy { case (subtract, expr) =>
+ val constPart = env.eval(expr) match {
+ case Some(NumericConstant(_, _)) => "Z"
+ case Some(_) => "Y"
+ case None => expr match {
+ case VariableExpression(_) => "V"
+ case IndexedExpression(_, LiteralExpression(_, _)) => "K"
+ case IndexedExpression(_, VariableExpression(_)) => "J"
+ case IndexedExpression(_, _) => "I"
+ case _ => "A"
+ }
+ }
+ val subtractPart = if (subtract) "X" else "P"
+ constPart + subtractPart
+ }
+ // TODO: merge constants
+ val normalizedParams = sortedParams
+
+ val h = normalizedParams.head
+ val firstParamCompiled = MlCompiler.compile(ctx, h._2, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
+ val firstParamSignCompiled = if (h._1) {
+ List(AssemblyLine.immediate(EOR, 0xff), AssemblyLine.implied(SEC), AssemblyLine.immediate(ADC, 0))
+ } else {
+ Nil
+ }
+
+ val remainingParamsCompiled = normalizedParams.tail.flatMap { p =>
+ if (p._1) {
+ insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = false))
+ } else {
+ insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = true))
+ }
+ }
+
+ wrapInSedCldIfNeeded(decimal, firstParamCompiled ++ firstParamSignCompiled ++ remainingParamsCompiled)
+ }
+
+ def compileBitOps(opcode: Opcode.Value, ctx: CompilationContext, params: List[Expression]): List[AssemblyLine] = {
+ val b = ctx.env.get[Type]("byte")
+
+ val sortedParams = params.sortBy { expr =>
+ ctx.env.eval(expr) match {
+ case Some(NumericConstant(_, _)) => "Z"
+ case Some(_) => "Y"
+ case None => expr match {
+ case VariableExpression(_) => "V"
+ case IndexedExpression(_, LiteralExpression(_, _)) => "K"
+ case IndexedExpression(_, VariableExpression(_)) => "J"
+ case IndexedExpression(_, _) => "I"
+ case _ => "A"
+ }
+ }
+ }
+
+ val h = sortedParams.head
+ val firstParamCompiled = MlCompiler.compile(ctx, h, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
+
+ val remainingParamsCompiled = sortedParams.tail.flatMap { p =>
+ simpleOperation(opcode, ctx, p, IndexChoice.PreferY, preserveA = true, commutative = true)
+ }
+
+ firstParamCompiled ++ remainingParamsCompiled
+ }
+
+ def compileShiftOps(opcode: Opcode.Value, ctx: CompilationContext, l: Expression, r: Expression): List[AssemblyLine] = {
+ val b = ctx.env.get[Type]("byte")
+ val firstParamCompiled = MlCompiler.compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
+ ctx.env.eval(r) match {
+ case Some(NumericConstant(0, _)) =>
+ Nil
+ case Some(NumericConstant(v, _)) if v > 0 =>
+ firstParamCompiled ++ List.fill(v.toInt)(AssemblyLine.implied(opcode))
+ case _ =>
+ ErrorReporting.error("Cannot shift by a non-constant amount")
+ Nil
+ }
+ }
+
+ def compileNonetOps(ctx: CompilationContext, lhs: LhsExpression, rhs: Expression): List[AssemblyLine] = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val (ldaHi, ldaLo) = lhs match {
+ case v: VariableExpression =>
+ val variable = env.get[Variable](v.name)
+ AssemblyLine.variable(ctx, LDA, variable, 1) -> AssemblyLine.variable(ctx, LDA, variable, 0)
+ case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) =>
+ AssemblyLine.variable(ctx, LDA, env.get[Variable](h.name), 0) -> AssemblyLine.variable(ctx, LDA, env.get[Variable](l.name), 0)
+ case _ =>
+ ???
+ }
+ env.eval(rhs) match {
+ case Some(NumericConstant(0, _)) =>
+ Nil
+ case Some(NumericConstant(shift, _)) if shift > 0 =>
+ if (ctx.options.flag(CompilationFlag.RorWarning))
+ ErrorReporting.warn("ROR instruction generated", ctx.options, lhs.position)
+ ldaHi ++ List(AssemblyLine.implied(ROR)) ++ ldaLo ++ List(AssemblyLine.implied(ROR)) ++ List.fill(shift.toInt - 1)(AssemblyLine.implied(LSR))
+ case _ =>
+ ErrorReporting.error("Non-constant shift amount", rhs.position) // TODO
+ Nil
+ }
+ }
+
+ def compileInPlaceByteShiftOps(opcode: Opcode.Value, ctx: CompilationContext, lhs: LhsExpression, rhs: Expression): List[AssemblyLine] = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val firstParamCompiled = MlCompiler.compile(ctx, lhs, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
+ env.eval(rhs) match {
+ case Some(NumericConstant(0, _)) =>
+ Nil
+ case Some(NumericConstant(v, _)) if v > 0 =>
+ val result = simpleOperation(opcode, ctx, lhs, IndexChoice.RequireX, preserveA = true, commutative = false)
+ result ++ List.fill(v.toInt - 1)(result.last)
+ case _ =>
+ ErrorReporting.error("Non-constant shift amount", rhs.position) // TODO
+ Nil
+ }
+ }
+
+ def compileInPlaceWordOrLongShiftOps(ctx: CompilationContext, lhs: LhsExpression, rhs: Expression, aslRatherThanLsr: Boolean): List[AssemblyLine] = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val targetBytes = lhs match {
+ case v: VariableExpression =>
+ val variable = env.get[Variable](v.name)
+ List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) }
+ case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) =>
+ List(
+ AssemblyLine.variable(ctx, STA, env.get[Variable](l.name)),
+ AssemblyLine.variable(ctx, STA, env.get[Variable](h.name)))
+ }
+ val lo = targetBytes.head
+ val hi = targetBytes.last
+ env.eval(rhs) match {
+ case Some(NumericConstant(0, _)) =>
+ Nil
+ case Some(NumericConstant(shift, _)) if shift > 0 =>
+ List.fill(shift.toInt)(if (aslRatherThanLsr) {
+ staTo(ASL, lo) ++ targetBytes.tail.flatMap { b => staTo(ROL, b) }
+ } else {
+ if (ctx.options.flag(CompilationFlag.RorWarning))
+ ErrorReporting.warn("ROR instruction generated", ctx.options, lhs.position)
+ staTo(LSR, hi) ++ targetBytes.reverse.tail.flatMap { b => staTo(ROR, b) }
+ }).flatten
+ case _ =>
+ ErrorReporting.error("Non-constant shift amount", rhs.position) // TODO
+ Nil
+ }
+ }
+
+ def compileByteComparison(ctx: CompilationContext, compType: ComparisonType.Value, lhs: Expression, rhs: Expression, branches: BranchSpec): List[AssemblyLine] = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val firstParamCompiled = MlCompiler.compile(ctx, lhs, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
+ env.eval(rhs) match {
+ case Some(NumericConstant(0, _)) =>
+ compType match {
+ case ComparisonType.LessUnsigned =>
+ ErrorReporting.warn("Unsigned < 0 is always false", ctx.options, lhs.position)
+ case ComparisonType.LessOrEqualUnsigned =>
+ if (ctx.options.flag(CompilationFlag.ExtraComparisonWarnings))
+ ErrorReporting.warn("Unsigned <= 0 means the same as unsigned == 0", ctx.options, lhs.position)
+ case ComparisonType.GreaterUnsigned =>
+ if (ctx.options.flag(CompilationFlag.ExtraComparisonWarnings))
+ ErrorReporting.warn("Unsigned > 0 means the same as unsigned != 0", ctx.options, lhs.position)
+ case ComparisonType.GreaterOrEqualUnsigned =>
+ ErrorReporting.warn("Unsigned >= 0 is always true", ctx.options, lhs.position)
+ case _ =>
+ }
+ case Some(NumericConstant(1, _)) =>
+ if (ctx.options.flag(CompilationFlag.ExtraComparisonWarnings)) {
+ compType match {
+ case ComparisonType.LessUnsigned =>
+ ErrorReporting.warn("Unsigned < 1 means the same as unsigned == 0", ctx.options, lhs.position)
+ case ComparisonType.GreaterOrEqualUnsigned =>
+ ErrorReporting.warn("Unsigned >= 1 means the same as unsigned != 0", ctx.options, lhs.position)
+ case _ =>
+ }
+ }
+ case _ =>
+ }
+ val secondParamCompiledUnoptimized = simpleOperation(CMP, ctx, rhs, IndexChoice.PreferY, preserveA = true, commutative = false)
+ val secondParamCompiled = compType match {
+ case ComparisonType.Equal | ComparisonType.NotEqual | ComparisonType.LessSigned | ComparisonType.GreaterOrEqualSigned =>
+ secondParamCompiledUnoptimized match {
+ case List(AssemblyLine(CMP, Immediate, NumericConstant(0, _), true)) =>
+ if (OpcodeClasses.ChangesAAlways(firstParamCompiled.last.opcode)) {
+ Nil
+ } else {
+ secondParamCompiledUnoptimized
+ }
+ case _ => secondParamCompiledUnoptimized
+ }
+ case _ => secondParamCompiledUnoptimized
+ }
+ val (effectiveComparisonType, label) = branches match {
+ case NoBranching => return Nil
+ case BranchIfTrue(l) => compType -> l
+ case BranchIfFalse(l) => ComparisonType.negate(compType) -> l
+ }
+ val branchingCompiled = effectiveComparisonType match {
+ case ComparisonType.Equal =>
+ List(AssemblyLine.relative(BEQ, Label(label)))
+ case ComparisonType.NotEqual =>
+ List(AssemblyLine.relative(BNE, Label(label)))
+
+ case ComparisonType.LessUnsigned =>
+ List(AssemblyLine.relative(BCC, Label(label)))
+ case ComparisonType.GreaterOrEqualUnsigned =>
+ List(AssemblyLine.relative(BCS, Label(label)))
+ case ComparisonType.LessOrEqualUnsigned =>
+ List(AssemblyLine.relative(BCC, Label(label)), AssemblyLine.relative(BEQ, Label(label)))
+ case ComparisonType.GreaterUnsigned =>
+ val x = MlCompiler.nextLabel("co")
+ List(
+ AssemblyLine.relative(BEQ, x),
+ AssemblyLine.relative(BCS, Label(label)),
+ AssemblyLine.label(x))
+
+ case ComparisonType.LessSigned =>
+ List(AssemblyLine.relative(BMI, Label(label)))
+ case ComparisonType.GreaterOrEqualSigned =>
+ List(AssemblyLine.relative(BPL, Label(label)))
+ case ComparisonType.LessOrEqualSigned =>
+ List(AssemblyLine.relative(BMI, Label(label)), AssemblyLine.relative(BEQ, Label(label)))
+ case ComparisonType.GreaterSigned =>
+ val x = MlCompiler.nextLabel("co")
+ List(
+ AssemblyLine.relative(BEQ, x),
+ AssemblyLine.relative(BPL, Label(label)),
+ AssemblyLine.label(x))
+ }
+ firstParamCompiled ++ secondParamCompiled ++ branchingCompiled
+
+ }
+
+ def compileWordComparison(ctx: CompilationContext, compType: ComparisonType.Value, lhs: Expression, rhs: Expression, branches: BranchSpec): List[AssemblyLine] = {
+ val env = ctx.env
+ // TODO: comparing stack variables
+ val b = env.get[Type]("byte")
+ val w = env.get[Type]("word")
+
+ val (effectiveComparisonType, x) = branches match {
+ case NoBranching => return Nil
+ case BranchIfTrue(label) => compType -> label
+ case BranchIfFalse(label) => ComparisonType.negate(compType) -> label
+ }
+ val (lh, ll, rh, rl, ram) = (lhs, env.eval(lhs), rhs, env.eval(rhs)) match {
+ case (_, Some(NumericConstant(lc, _)), _, Some(NumericConstant(rc, _))) =>
+ return if (effectiveComparisonType match {
+ // TODO: those masks are probably wrong
+ case ComparisonType.Equal =>
+ (lc & 0xffff) == (rc & 0xffff) // ??
+ case ComparisonType.NotEqual =>
+ (lc & 0xffff) != (rc & 0xffff) // ??
+
+ case ComparisonType.LessOrEqualUnsigned =>
+ (lc & 0xffff) <= (rc & 0xffff)
+ case ComparisonType.GreaterOrEqualUnsigned =>
+ (lc & 0xffff) >= (rc & 0xffff)
+ case ComparisonType.GreaterUnsigned =>
+ (lc & 0xffff) > (rc & 0xffff)
+ case ComparisonType.LessUnsigned =>
+ (lc & 0xffff) < (rc & 0xffff)
+
+ case ComparisonType.LessOrEqualSigned =>
+ lc.toShort <= rc.toShort
+ case ComparisonType.GreaterOrEqualSigned =>
+ lc.toShort >= rc.toShort
+ case ComparisonType.GreaterSigned =>
+ lc.toShort > rc.toShort
+ case ComparisonType.LessSigned =>
+ lc.toShort < rc.toShort
+ }) List(AssemblyLine.absolute(JMP, Label(x))) else Nil
+ case (_, Some(lc), _, Some(rc)) =>
+ // TODO: comparing late-bound constants
+ ???
+ case (_, Some(lc), rv: VariableInMemory, None) =>
+ return compileWordComparison(ctx, ComparisonType.flip(compType), rhs, lhs, branches)
+ case (v: VariableExpression, None, _, Some(rc)) =>
+ // TODO: stack variables
+ (env.get[VariableInMemory](v.name + ".hi").toAddress,
+ env.get[VariableInMemory](v.name + ".lo").toAddress,
+ rc.hiByte,
+ rc.loByte,
+ Immediate)
+ case (lv: VariableExpression, None, rv: VariableExpression, None) =>
+ // TODO: stack variables
+ (env.get[VariableInMemory](lv.name + ".hi").toAddress,
+ env.get[VariableInMemory](lv.name + ".lo").toAddress,
+ env.get[VariableInMemory](rv.name + ".hi").toAddress,
+ env.get[VariableInMemory](rv.name + ".lo").toAddress, Absolute)
+ }
+ effectiveComparisonType match {
+ case ComparisonType.Equal =>
+ val innerLabel = MlCompiler.nextLabel("cp")
+ List(AssemblyLine.absolute(LDA, ll),
+ AssemblyLine(CMP, ram, rl),
+ AssemblyLine.relative(BNE, innerLabel),
+ AssemblyLine.absolute(LDA, lh),
+ AssemblyLine(CMP, ram, rh),
+ AssemblyLine.relative(BEQ, Label(x)),
+ AssemblyLine.label(innerLabel))
+
+ case ComparisonType.NotEqual =>
+ List(AssemblyLine.absolute(LDA, ll),
+ AssemblyLine(CMP, ram, rl),
+ AssemblyLine.relative(BNE, Label(x)),
+ AssemblyLine.absolute(LDA, lh),
+ AssemblyLine(CMP, ram, rh),
+ AssemblyLine.relative(BNE, Label(x)))
+
+ case ComparisonType.LessUnsigned =>
+ val innerLabel = MlCompiler.nextLabel("cp")
+ List(AssemblyLine.absolute(LDA, lh),
+ AssemblyLine(CMP, ram, rh),
+ AssemblyLine.relative(BCC, Label(x)),
+ AssemblyLine.relative(BNE, innerLabel),
+ AssemblyLine.absolute(LDA, ll),
+ AssemblyLine(CMP, ram, rl),
+ AssemblyLine.relative(BCC, Label(x)),
+ AssemblyLine.label(innerLabel))
+
+ case ComparisonType.LessOrEqualUnsigned =>
+ val innerLabel = MlCompiler.nextLabel("cp")
+ List(AssemblyLine(LDA, ram, rh),
+ AssemblyLine.absolute(CMP, lh),
+ AssemblyLine.relative(BCC, innerLabel),
+ AssemblyLine.relative(BNE, x),
+ AssemblyLine(LDA, ram, rl),
+ AssemblyLine.absolute(CMP, ll),
+ AssemblyLine.relative(BCS, x),
+ AssemblyLine.label(innerLabel))
+
+ case ComparisonType.GreaterUnsigned =>
+ val innerLabel = MlCompiler.nextLabel("cp")
+ List(AssemblyLine(LDA, ram, rh),
+ AssemblyLine.absolute(CMP, lh),
+ AssemblyLine.relative(BCC, Label(x)),
+ AssemblyLine.relative(BNE, innerLabel),
+ AssemblyLine(LDA, ram, rl),
+ AssemblyLine.absolute(CMP, ll),
+ AssemblyLine.relative(BCC, Label(x)),
+ AssemblyLine.label(innerLabel))
+
+ case ComparisonType.GreaterOrEqualUnsigned =>
+ val innerLabel = MlCompiler.nextLabel("cp")
+ List(AssemblyLine.absolute(LDA, lh),
+ AssemblyLine(CMP, ram, rh),
+ AssemblyLine.relative(BCC, innerLabel),
+ AssemblyLine.relative(BNE, x),
+ AssemblyLine.absolute(LDA, ll),
+ AssemblyLine(CMP, ram, rl),
+ AssemblyLine.relative(BCS, x),
+ AssemblyLine.label(innerLabel))
+
+ case _ => ???
+ // TODO: signed word comparisons
+ }
+ }
+
+ def compileInPlaceByteMultiplication(ctx: CompilationContext, v: LhsExpression, addend: Expression): List[AssemblyLine] = {
+ val b = ctx.env.get[Type]("byte")
+ ctx.env.eval(addend) match {
+ case Some(NumericConstant(0, _)) =>
+ AssemblyLine.immediate(LDA, 0) :: MlCompiler.compileByteStorage(ctx, Register.A, v)
+ case Some(NumericConstant(1, _)) =>
+ Nil
+ case Some(NumericConstant(x, _)) =>
+ compileByteMultiplication(ctx, v, x.toInt) ++ MlCompiler.compileByteStorage(ctx, Register.A, v)
+ case _ =>
+ ErrorReporting.error("Multiplying by not a constant not supported", v.position)
+ Nil
+ }
+ }
+
+ def compileByteMultiplication(ctx: CompilationContext, v: Expression, c: Int): List[AssemblyLine] = {
+ val result = ListBuffer[AssemblyLine]()
+ // TODO: optimise
+ val addingCode = simpleOperation(ADC, ctx, v, IndexChoice.PreferY, preserveA = false, commutative = false)
+ val adc = addingCode.last
+ val indexing = addingCode.init
+ result ++= indexing
+ result += AssemblyLine.immediate(LDA, 0)
+ val mult = c & 0xff
+ var mask = 128
+ var empty = true
+ while (mask > 0) {
+ if (!empty) {
+ result += AssemblyLine.implied(ASL)
+ }
+ if ((mult & mask) != 0) {
+ result ++= List(AssemblyLine.implied(CLC), adc)
+ empty = false
+ }
+
+ mask >>>= 1
+ }
+ result.toList
+ }
+
+ def compileByteMultiplication(ctx: CompilationContext, params: List[Expression]): List[AssemblyLine] = {
+ val (constants, variables) = params.map(p => p -> ctx.env.eval(p)).partition(_._2.exists(_.isInstanceOf[NumericConstant]))
+ val constant = constants.map(_._2.get.asInstanceOf[NumericConstant].value).foldLeft(1L)(_ * _).toInt
+ variables.length match {
+ case 0 => List(AssemblyLine.immediate(LDA, constant & 0xff))
+ case 1 =>compileByteMultiplication(ctx, variables.head._1, constant)
+ case 2 =>
+ ErrorReporting.error("Multiplying by not a constant not supported", params.head.position)
+ Nil
+ }
+ }
+
+ def compileInPlaceByteAddition(ctx: CompilationContext, v: LhsExpression, addend: Expression, subtract: Boolean, decimal: Boolean): List[AssemblyLine] = {
+ if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode)) {
+ ErrorReporting.warn("Unsupported decimal operation", ctx.options, v.position)
+ }
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ env.eval(addend) match {
+ case Some(NumericConstant(0, _)) => Nil
+ case Some(NumericConstant(1, _)) if !decimal => if (subtract) {
+ simpleOperation(DEC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true)
+ } else {
+ simpleOperation(INC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true)
+ }
+ // TODO: compile +=2 to two INCs
+ case Some(NumericConstant(-1, _)) if !decimal => if (subtract) {
+ simpleOperation(INC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true)
+ } else {
+ simpleOperation(DEC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true)
+ }
+ case _ =>
+ val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
+ val modifyLhs = if (subtract) {
+ insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = false))
+ } else {
+ insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = true))
+ }
+ val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v)
+ wrapInSedCldIfNeeded(decimal, loadLhs ++ modifyLhs ++ storeLhs)
+ }
+ }
+
+ def compileInPlaceWordOrLongAddition(ctx: CompilationContext, lhs: LhsExpression, addend: Expression, subtract: Boolean, decimal: Boolean): List[AssemblyLine] = {
+ if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode)) {
+ ErrorReporting.warn("Unsupported decimal operation", ctx.options, lhs.position)
+ }
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val w = env.get[Type]("word")
+ val targetBytes: List[List[AssemblyLine]] = lhs match {
+ case v: VariableExpression =>
+ val variable = env.get[Variable](v.name)
+ List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) }
+ case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) =>
+ val lv = env.get[Variable](l.name)
+ val hv = env.get[Variable](h.name)
+ List(
+ AssemblyLine.variable(ctx, STA, lv),
+ AssemblyLine.variable(ctx, STA, hv))
+ }
+ val lhsIsStack = targetBytes.head.head.opcode == TSX
+ val targetSize = targetBytes.size
+ val addendType = MlCompiler.getExpressionType(ctx, addend)
+ var addendSize = addendType.size
+
+ def isRhsComplex(xs: List[AssemblyLine]): Boolean = xs match {
+ case AssemblyLine(LDA, _, _, _) :: Nil => false
+ case AssemblyLine(LDA, _, _, _) :: AssemblyLine(LDX, _, _, _) :: Nil => false
+ case _ => true
+ }
+
+ def isRhsStack(xs: List[AssemblyLine]): Boolean = xs.exists(_.opcode == TSX)
+
+ val (calculateRhs, addendByteRead0): (List[AssemblyLine], List[List[AssemblyLine]]) = env.eval(addend) match {
+ case Some(constant) =>
+ addendSize = targetSize
+ Nil -> List.tabulate(targetSize)(i => List(AssemblyLine.immediate(LDA, constant.subbyte(i))))
+ case None =>
+ addendSize match {
+ case 1 =>
+ val base = MlCompiler.compile(ctx, addend, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
+ if (subtract) {
+ if (isRhsComplex(base)) {
+ if (isRhsStack(base)) {
+ ErrorReporting.warn("Subtracting a stack-based value", ctx.options)
+ ???
+ }
+ (base ++ List(AssemblyLine.implied(PHA))) -> List(List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, 0x101)))
+ } else {
+ Nil -> base.map(_ :: Nil)
+ }
+ } else {
+ base -> List(Nil)
+ }
+ case 2 =>
+ val base = MlCompiler.compile(ctx, addend, Some(w -> RegisterVariable(Register.AX, w)), NoBranching)
+ if (isRhsStack(base)) {
+ val fixedBase = MlCompiler.compile(ctx, addend, Some(w -> RegisterVariable(Register.AY, w)), NoBranching)
+ if (subtract) {
+ ErrorReporting.warn("Subtracting a stack-based value", ctx.options)
+ if (isRhsComplex(base)) {
+ ???
+ } else {
+ Nil -> fixedBase
+ ???
+ }
+ } else {
+ fixedBase -> List(Nil, List(AssemblyLine.implied(TYA)))
+ }
+ } else {
+ if (subtract) {
+ if (isRhsComplex(base)) {
+ (base ++ List(
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(TXA),
+ AssemblyLine.implied(PHA))
+ ) -> List(
+ List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, 0x102)),
+ List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, 0x101)))
+ } else {
+ Nil -> base.map(_ :: Nil)
+ }
+ } else {
+ if (lhsIsStack) {
+ val fixedBase = MlCompiler.compile(ctx, addend, Some(w -> RegisterVariable(Register.AY, w)), NoBranching)
+ fixedBase -> List(Nil, List(AssemblyLine.implied(TYA)))
+ } else {
+ base -> List(Nil, List(AssemblyLine.implied(TXA)))
+ }
+ }
+ }
+ case _ => Nil -> (addend match {
+ case vv: VariableExpression =>
+ val source = env.get[Variable](vv.name)
+ List.tabulate(addendSize)(i => AssemblyLine.variable(ctx, LDA, source, i))
+ })
+ }
+ }
+ val addendByteRead = addendByteRead0 ++ List.fill((targetSize - addendByteRead0.size) max 0)(List(AssemblyLine.immediate(LDA, 0)))
+ val buffer = mutable.ListBuffer[AssemblyLine]()
+ buffer ++= calculateRhs
+ buffer += AssemblyLine.implied(if (subtract) SEC else CLC)
+ val extendMultipleBytes = targetSize > addendSize + 1
+ val extendAtLeastOneByte = targetSize > addendSize
+ for (i <- 0 until targetSize) {
+ if (subtract) {
+ if (addendSize < targetSize && addendType.isSigned) {
+ // TODO: sign extension
+ ???
+ }
+ buffer ++= staTo(LDA, targetBytes(i))
+ buffer ++= ldTo(SBC, addendByteRead(i))
+ buffer ++= targetBytes(i)
+ } else {
+ if (i >= addendSize) {
+ if (addendType.isSigned) {
+ val label = MlCompiler.nextLabel("sx")
+ buffer += AssemblyLine.implied(TXA)
+ if (i == addendSize) {
+ buffer += AssemblyLine.immediate(ORA, 0x7f)
+ buffer += AssemblyLine.relative(BMI, label)
+ buffer += AssemblyLine.immediate(LDA, 0)
+ buffer += AssemblyLine.label(label)
+ if (extendMultipleBytes) buffer += AssemblyLine.implied(TAX)
+ }
+ } else {
+ buffer += AssemblyLine.immediate(LDA, 0)
+ }
+ } else {
+ buffer ++= addendByteRead(i)
+ if (addendType.isSigned && i == addendSize - 1 && extendAtLeastOneByte) {
+ buffer += AssemblyLine.implied(TAX)
+ }
+ }
+ buffer ++= staTo(ADC, targetBytes(i))
+ buffer ++= targetBytes(i)
+ }
+ }
+ for (i <- 0 until calculateRhs.count(a => a.opcode == PHA) - calculateRhs.count(a => a.opcode == PLA)) {
+ buffer += AssemblyLine.implied(PLA)
+ }
+ wrapInSedCldIfNeeded(decimal, buffer.toList)
+ }
+
+ def compileInPlaceByteBitOp(ctx: CompilationContext, v: LhsExpression, param: Expression, operation: Opcode.Value): List[AssemblyLine] = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ (operation, env.eval(param)) match {
+ case (EOR, Some(NumericConstant(0, _)))
+ | (ORA, Some(NumericConstant(0, _)))
+ | (AND, Some(NumericConstant(0xff, _)))
+ | (AND, Some(NumericConstant(-1, _))) =>
+ Nil
+ case _ =>
+ val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
+ val modifyLhs = simpleOperation(operation, ctx, param, IndexChoice.PreferY, preserveA = true, commutative = true)
+ val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v)
+ loadLhs ++ modifyLhs ++ storeLhs
+ }
+ }
+
+
+ def compileInPlaceWordOrLongBitOp(ctx: CompilationContext, lhs: LhsExpression, param: Expression, operation: Opcode.Value): List[AssemblyLine] = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val w = env.get[Type]("word")
+ val targetBytes: List[List[AssemblyLine]] = lhs match {
+ case v: VariableExpression =>
+ val variable = env.get[Variable](v.name)
+ List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) }
+ case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) =>
+ val lv = env.get[Variable](l.name)
+ val hv = env.get[Variable](h.name)
+ List(
+ AssemblyLine.variable(ctx, STA, lv),
+ AssemblyLine.variable(ctx, STA, hv))
+ case _ =>
+ ???
+ }
+ val lo = targetBytes.head
+ val targetSize = targetBytes.size
+ val paramType = MlCompiler.getExpressionType(ctx, param)
+ var paramSize = paramType.size
+ val extendMultipleBytes = targetSize > paramSize + 1
+ val extendAtLeastOneByte = targetSize > paramSize
+ val (calculateRhs, addendByteRead) = env.eval(param) match {
+ case Some(constant) =>
+ paramSize = targetSize
+ Nil -> List.tabulate(targetSize)(i => List(AssemblyLine.immediate(LDA, constant.subbyte(i))))
+ case None =>
+ paramSize match {
+ case 1 =>
+ val base = MlCompiler.compile(ctx, param, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
+ base -> List(Nil)
+ case 2 =>
+ val base = MlCompiler.compile(ctx, param, Some(w -> RegisterVariable(Register.AX, w)), NoBranching)
+ base -> List(Nil, List(AssemblyLine.implied(TXA)))
+ case _ => Nil -> (param match {
+ case vv: VariableExpression =>
+ val source = env.get[Variable](vv.name)
+ List.tabulate(paramSize)(i => AssemblyLine.variable(ctx, LDA, source, i))
+ })
+ }
+ }
+ val AllOnes = (1L << (8 * targetSize)) - 1
+ (operation, env.eval(param)) match {
+ case (EOR, Some(NumericConstant(0, _)))
+ | (ORA, Some(NumericConstant(0, _)))
+ | (AND, Some(NumericConstant(AllOnes, _))) =>
+ Nil
+ case _ =>
+ val buffer = mutable.ListBuffer[AssemblyLine]()
+ buffer ++= calculateRhs
+ for (i <- 0 until targetSize) {
+ if (i < paramSize) {
+ buffer ++= addendByteRead(i)
+ if (paramType.isSigned && i == paramSize - 1 && extendAtLeastOneByte) {
+ buffer += AssemblyLine.implied(TAX)
+ }
+ } else {
+ if (paramType.isSigned) {
+ val label = MlCompiler.nextLabel("sx")
+ buffer += AssemblyLine.implied(TXA)
+ if (i == paramSize) {
+ buffer += AssemblyLine.immediate(ORA, 0x7f)
+ buffer += AssemblyLine.relative(BMI, label)
+ buffer += AssemblyLine.immediate(LDA, 0)
+ buffer += AssemblyLine.label(label)
+ if (extendMultipleBytes) buffer += AssemblyLine.implied(TAX)
+ }
+ } else {
+ buffer += AssemblyLine.immediate(LDA, 0)
+ }
+ }
+ buffer ++= staTo(operation, targetBytes(i))
+ buffer ++= targetBytes(i)
+ }
+ for (i <- 0 until calculateRhs.count(a => a.opcode == PHA) - calculateRhs.count(a => a.opcode == PLA)) {
+ buffer += AssemblyLine.implied(PLA)
+ }
+ buffer.toList
+ }
+ }
+
+
+}
diff --git a/src/main/scala/millfork/compiler/CompilationContext.scala b/src/main/scala/millfork/compiler/CompilationContext.scala
new file mode 100644
index 00000000..a321eb54
--- /dev/null
+++ b/src/main/scala/millfork/compiler/CompilationContext.scala
@@ -0,0 +1,12 @@
+package millfork.compiler
+
+import millfork.{CompilationFlag, CompilationOptions}
+import millfork.env.{Environment, MangledFunction, NormalFunction}
+
+/**
+ * @author Karol Stasiak
+ */
+case class CompilationContext(env: Environment, function: NormalFunction, extraStackOffset: Int, options: CompilationOptions){
+
+ def addStack(i: Int): CompilationContext = this.copy(extraStackOffset = extraStackOffset + i)
+}
diff --git a/src/main/scala/millfork/compiler/MfCompiler.scala b/src/main/scala/millfork/compiler/MfCompiler.scala
new file mode 100644
index 00000000..a977a2ec
--- /dev/null
+++ b/src/main/scala/millfork/compiler/MfCompiler.scala
@@ -0,0 +1,1675 @@
+package millfork.compiler
+
+import java.util.concurrent.atomic.AtomicLong
+
+import millfork.{CompilationFlag, CompilationOptions}
+import millfork.assembly._
+import millfork.env._
+import millfork.node.{Register, _}
+import millfork.assembly.AddrMode._
+import millfork.assembly.Opcode._
+import millfork.error.ErrorReporting
+
+import scala.collection.JavaConverters._
+
+/**
+ * @author Karol Stasiak
+ */
+
+sealed trait BranchSpec {
+ def flip: BranchSpec
+}
+
+case object NoBranching extends BranchSpec {
+ override def flip = this
+}
+
+case class BranchIfTrue(label: String) extends BranchSpec {
+ override def flip = BranchIfFalse(label)
+}
+
+case class BranchIfFalse(label: String) extends BranchSpec {
+ override def flip = BranchIfTrue(label)
+}
+
+object BranchSpec {
+ val None = NoBranching
+}
+
+//noinspection NotImplementedCode,ScalaUnusedSymbol
+object MlCompiler {
+
+
+ private var labelCounter = new AtomicLong
+
+ def nextLabel(prefix: String): String = "." + prefix + "__" + labelCounter.incrementAndGet().formatted("%05d")
+
+ def compile(ctx: CompilationContext): Chunk = {
+ val chunk = compile(ctx, ctx.function.code)
+ val prefix = (if (ctx.function.interrupt) {
+ if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) {
+ List(
+ AssemblyLine.implied(SEI),
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(PHX),
+ AssemblyLine.implied(PHY),
+ AssemblyLine.implied(CLD))
+ } else {
+ List(
+ AssemblyLine.implied(SEI),
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(TXA),
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(TYA),
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(CLD))
+ }
+ } else Nil) ++ stackPointerFixAtBeginning(ctx)
+ if (prefix.nonEmpty) {
+ LabelledChunk(ctx.function.name, SequenceChunk(List(LinearChunk(prefix), chunk)))
+ } else {
+ LabelledChunk(ctx.function.name, chunk)
+ }
+ }
+
+ def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = {
+ val paramsWithTypes = f.expressions.map(x => getExpressionType(ctx, x) -> x)
+ ctx.env.lookupFunction(f.functionName, paramsWithTypes).getOrElse(
+ ErrorReporting.fatal(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1)}`", f.position))
+ }
+
+ def getExpressionType(ctx: CompilationContext, expr: Expression): Type = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val bool = env.get[Type]("bool$")
+ val v = env.get[Type]("void")
+ val w = env.get[Type]("word")
+ val l = env.get[Type]("long")
+ expr match {
+ case LiteralExpression(value, size) =>
+ size match {
+ case 1 => b
+ case 2 => w
+ case 3 | 4 => l
+ }
+ case VariableExpression(name) =>
+ env.get[TypedThing](name, expr.position).typ
+ case HalfWordExpression(param, _) =>
+ getExpressionType(ctx, param)
+ b
+ case IndexedExpression(_, _) => b
+ case SeparateBytesExpression(h, l) =>
+ if (getExpressionType(ctx, h).size > 1) ErrorReporting.error("Hi byte too large", h.position)
+ if (getExpressionType(ctx, l).size > 1) ErrorReporting.error("Lo byte too large", l.position)
+ w
+ case SumExpression(params, _) => b
+ case FunctionCallExpression("not", params) => bool
+ case FunctionCallExpression("*", params) => b
+ case FunctionCallExpression("|", params) => b
+ case FunctionCallExpression("&", params) => b
+ case FunctionCallExpression("^", params) => b
+ case FunctionCallExpression("<<", params) => b
+ case FunctionCallExpression(">>", params) => b
+ case FunctionCallExpression("<<'", params) => b
+ case FunctionCallExpression(">>'", params) => b
+ case FunctionCallExpression(">>>>", params) => b
+ case FunctionCallExpression("&&", params) => bool
+ case FunctionCallExpression("||", params) => bool
+ case FunctionCallExpression("^^", params) => bool
+ case FunctionCallExpression("==", params) => bool
+ case FunctionCallExpression("!=", params) => bool
+ case FunctionCallExpression("<", params) => bool
+ case FunctionCallExpression(">", params) => bool
+ case FunctionCallExpression("<=", params) => bool
+ case FunctionCallExpression(">=", params) => bool
+ case FunctionCallExpression("+=", params) => v
+ case FunctionCallExpression("-=", params) => v
+ case FunctionCallExpression("*=", params) => v
+ case FunctionCallExpression("+'=", params) => v
+ case FunctionCallExpression("-'=", params) => v
+ case FunctionCallExpression("*'=", params) => v
+ case FunctionCallExpression("|=", params) => v
+ case FunctionCallExpression("&=", params) => v
+ case FunctionCallExpression("^=", params) => v
+ case FunctionCallExpression("<<=", params) => v
+ case FunctionCallExpression(">>=", params) => v
+ case FunctionCallExpression("<<'=", params) => v
+ case FunctionCallExpression(">>'=", params) => v
+ case f@FunctionCallExpression(name, params) =>
+ lookupFunction(ctx, f).returnType
+ }
+ }
+
+ def compileConstant(ctx: CompilationContext, expr: Constant, target: Variable): List[AssemblyLine] = {
+ target match {
+ case RegisterVariable(Register.A, _) => List(AssemblyLine(LDA, Immediate, expr))
+ case RegisterVariable(Register.X, _) => List(AssemblyLine(LDX, Immediate, expr))
+ case RegisterVariable(Register.Y, _) => List(AssemblyLine(LDY, Immediate, expr))
+ case RegisterVariable(Register.AX, _) => List(
+ AssemblyLine(LDA, Immediate, expr.loByte),
+ AssemblyLine(LDX, Immediate, expr.hiByte))
+ case RegisterVariable(Register.AY, _) => List(
+ AssemblyLine(LDA, Immediate, expr.loByte),
+ AssemblyLine(LDY, Immediate, expr.hiByte))
+ case RegisterVariable(Register.XA, _) => List(
+ AssemblyLine(LDA, Immediate, expr.hiByte),
+ AssemblyLine(LDX, Immediate, expr.loByte))
+ case RegisterVariable(Register.YA, _) => List(
+ AssemblyLine(LDA, Immediate, expr.hiByte),
+ AssemblyLine(LDY, Immediate, expr.loByte))
+ case m: VariableInMemory =>
+ val addr = m.toAddress
+ m.typ.size match {
+ case 0 => Nil
+ case 1 => List(
+ AssemblyLine(LDA, Immediate, expr.loByte),
+ AssemblyLine(STA, Absolute, addr))
+ case 2 => List(
+ AssemblyLine(LDA, Immediate, expr.loByte),
+ AssemblyLine(STA, Absolute, addr),
+ AssemblyLine(LDA, Immediate, expr.hiByte),
+ AssemblyLine(STA, Absolute, addr + 1))
+ case s => List.tabulate(s)(i => List(
+ AssemblyLine(LDA, Immediate, expr.subbyte(i)),
+ AssemblyLine(STA, Absolute, addr + i))).flatten
+ }
+ case StackVariable(_, t, offset) =>
+ t.size match {
+ case 0 => Nil
+ case 1 => List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.immediate(LDA, expr.loByte),
+ AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset))
+ case 2 => List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.immediate(LDA, expr.loByte),
+ AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset),
+ AssemblyLine.immediate(LDA, expr.hiByte),
+ AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset + 1))
+ case s => AssemblyLine.implied(TSX) :: List.tabulate(s)(i => List(
+ AssemblyLine.immediate(LDA, expr.subbyte(i)),
+ AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset + i))).flatten
+ }
+ }
+ }
+
+ def fixTsx(code: List[AssemblyLine]): List[AssemblyLine] = code match {
+ case (tsx@AssemblyLine(TSX, _, _, _)) :: xs => tsx :: AssemblyLine.implied(INX) :: fixTsx(xs)
+ case (txs@AssemblyLine(TXS, _, _, _)) :: xs => ???
+ case x :: xs => x :: fixTsx(xs)
+ case Nil => Nil
+ }
+
+ def preserveRegisterIfNeeded(ctx: CompilationContext, register: Register.Value, code: List[AssemblyLine]): List[AssemblyLine] = {
+ val state = register match {
+ case Register.A => State.A
+ case Register.X => State.X
+ case Register.Y => State.Y
+ }
+
+ val cmos = ctx.options.flag(CompilationFlag.EmitCmosOpcodes)
+ if (AssemblyLine.treatment(code, state) != Treatment.Unchanged) {
+ register match {
+ case Register.A => AssemblyLine.implied(PHA) +: fixTsx(code) :+ AssemblyLine.implied(PLA)
+ case Register.X => if (cmos) {
+ List(
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(PHX),
+ ) ++ fixTsx(fixTsx(code)) ++ List(
+ AssemblyLine.implied(PLX),
+ AssemblyLine.implied(PLA),
+ )
+ } else {
+ List(
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(TXA),
+ AssemblyLine.implied(PHA),
+ ) ++ fixTsx(fixTsx(code)) ++ List(
+ AssemblyLine.implied(PLA),
+ AssemblyLine.implied(TAX),
+ AssemblyLine.implied(PLA),
+ )
+ }
+ case Register.Y => if (cmos) {
+ List(
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(PHY),
+ ) ++ fixTsx(fixTsx(code)) ++ List(
+ AssemblyLine.implied(PLY),
+ AssemblyLine.implied(PLA),
+ )
+ } else {
+ List(
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(TYA),
+ AssemblyLine.implied(PHA),
+ ) ++ fixTsx(fixTsx(code)) ++ List(
+ AssemblyLine.implied(PLA),
+ AssemblyLine.implied(TAY),
+ AssemblyLine.implied(PLA),
+ )
+ }
+ }
+ } else {
+ code
+ }
+ }
+
+ def compileByteStorage(ctx: CompilationContext, register: Register.Value, target: LhsExpression): List[AssemblyLine] = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val store = register match {
+ case Register.A => STA
+ case Register.X => STX
+ case Register.Y => STY
+ }
+ val transferToA = register match {
+ case Register.A => NOP
+ case Register.X => TXA
+ case Register.Y => TYA
+ }
+ target match {
+ case VariableExpression(name) =>
+ val v = env.get[Variable](name)
+ v.typ.size match {
+ case 0 => ???
+ case 1 =>
+ v match {
+ case mv: VariableInMemory => AssemblyLine.absolute(store, mv) :: Nil
+ case sv@StackVariable(_, _, offset) => AssemblyLine.implied(transferToA) :: AssemblyLine.implied(TSX) :: AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset) :: Nil
+ }
+ case s if s > 1 =>
+ v match {
+ case mv: VariableInMemory =>
+ AssemblyLine.absolute(store, mv) ::
+ AssemblyLine.immediate(LDA, 0) ::
+ List.tabulate(s - 1)(i => AssemblyLine.absolute(STA, mv.toAddress + (i + 1)))
+ case sv@StackVariable(_, _, offset) =>
+ AssemblyLine.implied(transferToA) ::
+ AssemblyLine.implied(TSX) ::
+ AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset) ::
+ List.tabulate(s - 1)(i => AssemblyLine.absoluteX(STA, offset + ctx.extraStackOffset + i + 1))
+ }
+ }
+ case IndexedExpression(arrayName, indexExpr) =>
+ val array = env.getArrayOrPointer(arrayName)
+ val (variableIndex, constIndex) = env.evalVariableAndConstantSubParts(indexExpr)
+
+ def storeToArrayAtUnknownIndex(variableIndex: Expression, arrayAddr: Constant) = {
+ // TODO check typ
+ val indexRegister = if (register == Register.Y) Register.X else Register.Y
+ val calculatingIndex = preserveRegisterIfNeeded(ctx, register, compile(ctx, variableIndex, Some(b, RegisterVariable(indexRegister, b)), NoBranching))
+ if (register == Register.A) {
+ indexRegister match {
+ case Register.Y =>
+ calculatingIndex ++ List(AssemblyLine.absoluteY(STA, arrayAddr + constIndex))
+ case Register.X =>
+ calculatingIndex ++ List(AssemblyLine.absoluteX(STA, arrayAddr + constIndex))
+ }
+ } else {
+ indexRegister match {
+ case Register.Y =>
+ calculatingIndex ++ List(AssemblyLine.implied(transferToA), AssemblyLine.absoluteY(STA, arrayAddr + constIndex))
+ case Register.X =>
+ calculatingIndex ++ List(AssemblyLine.implied(transferToA), AssemblyLine.absoluteX(STA, arrayAddr + constIndex))
+ }
+ }
+ }
+
+ (array, variableIndex) match {
+
+ case (p: ConstantThing, None) =>
+ List(AssemblyLine.absolute(store, env.genRelativeVariable(p.value + constIndex, b, zeropage = false)))
+ case (p: ConstantThing, Some(v)) =>
+ storeToArrayAtUnknownIndex(v, p.value)
+
+ case (a@InitializedArray(_, _, _), None) =>
+ List(AssemblyLine.absolute(store, env.genRelativeVariable(a.toAddress + constIndex, b, zeropage = false)))
+ case (a@InitializedArray(_, _, _), Some(v)) =>
+ storeToArrayAtUnknownIndex(v, a.toAddress)
+
+ case (a@UninitializedArray(_, _), None) =>
+ List(AssemblyLine.absolute(store, env.genRelativeVariable(a.toAddress + constIndex, b, zeropage = false)))
+ case (a@UninitializedArray(_, _), Some(v)) =>
+ storeToArrayAtUnknownIndex(v, a.toAddress)
+
+ case (RelativeArray(_, arrayAddr, _), None) =>
+ List(AssemblyLine.absolute(store, env.genRelativeVariable(arrayAddr + constIndex, b, zeropage = false)))
+ case (RelativeArray(_, arrayAddr, _), Some(v)) =>
+ storeToArrayAtUnknownIndex(v, arrayAddr)
+
+ // TODO: are those two below okay?
+ case (RelativeVariable(_, arrayAddr, typ, _), None) =>
+ List(AssemblyLine.absolute(store, env.genRelativeVariable(arrayAddr + constIndex, b, zeropage = false)))
+ case (RelativeVariable(_, arrayAddr, typ, _), Some(v)) =>
+ storeToArrayAtUnknownIndex(v, arrayAddr)
+
+ //TODO: should there be a type check or a zeropage check?
+ case (pointerVariable@MemoryVariable(_, typ, _), None) =>
+ register match {
+ case Register.A =>
+ List(AssemblyLine.immediate(LDY, constIndex), AssemblyLine.indexedY(STA, pointerVariable))
+ case Register.Y =>
+ List(AssemblyLine.implied(TYA), AssemblyLine.immediate(LDY, constIndex), AssemblyLine.indexedY(STA, pointerVariable), AssemblyLine.implied(TAY))
+ case Register.X =>
+ List(AssemblyLine.immediate(LDY, constIndex), AssemblyLine.implied(TXA), AssemblyLine.indexedY(STA, pointerVariable))
+ case _ =>
+ ErrorReporting.error("Cannot store a word in an array", target.position)
+ Nil
+ }
+ case (pointerVariable@MemoryVariable(_, typ, _), Some(_)) =>
+ val calculatingIndex = compile(ctx, indexExpr, Some(b, RegisterVariable(Register.Y, b)), NoBranching)
+ register match {
+ case Register.A =>
+ preserveRegisterIfNeeded(ctx, Register.A, calculatingIndex) :+ AssemblyLine.indexedY(STA, pointerVariable)
+ case Register.X =>
+ preserveRegisterIfNeeded(ctx, Register.X, calculatingIndex) ++ List(AssemblyLine.implied(TXA), AssemblyLine.indexedY(STA, pointerVariable))
+ case Register.Y =>
+ AssemblyLine.implied(TYA) :: preserveRegisterIfNeeded(ctx, Register.A, calculatingIndex) ++ List(
+ AssemblyLine.indexedY(STA, pointerVariable), AssemblyLine.implied(TAY)
+ )
+ case _ =>
+ ErrorReporting.error("Cannot store a word in an array", target.position)
+ Nil
+ }
+ }
+
+ }
+ }
+
+ def assertCompatible(exprType: Type, variableType: Type): Unit = {
+ // TODO
+ }
+
+ val noop: List[AssemblyLine] = Nil
+
+ def callingContext(ctx: CompilationContext, v: MemoryVariable): CompilationContext = {
+ val result = new Environment(Some(ctx.env), "")
+ result.registerVariable(VariableDeclarationStatement(v.name, v.typ.name, stack = false, global = false, constant = false, volatile = false, initialValue = None, address = None), ctx.options)
+ ctx.copy(env = result)
+ }
+
+ def assertBinary(ctx: CompilationContext, params: List[Expression]): (Expression, Expression, Int) = {
+ if (params.length != 2) {
+ ErrorReporting.fatal("sfgdgfsd", None)
+ }
+ (params.head, params(1)) match {
+ case (l: Expression, r: Expression) => (l, r, getExpressionType(ctx, l).size max getExpressionType(ctx, r).size)
+ }
+ }
+
+ def assertComparison(ctx: CompilationContext, params: List[Expression]): (Expression, Expression, Int, Boolean) = {
+ if (params.length != 2) {
+ ErrorReporting.fatal("sfgdgfsd", None)
+ }
+ (params.head, params(1)) match {
+ case (l: Expression, r: Expression) =>
+ val lt = getExpressionType(ctx, l)
+ val rt = getExpressionType(ctx, r)
+ (l, r, lt.size max rt.size, lt.isSigned || rt.isSigned)
+ }
+ }
+
+ def assertBool(ctx: CompilationContext, params: List[Expression], expectedParamCount: Int): Unit = {
+ if (params.length != expectedParamCount) {
+ ErrorReporting.error("Invalid number of parameters", params.headOption.flatMap(_.position))
+ return
+ }
+ params.foreach { param =>
+ if (!getExpressionType(ctx, param).isInstanceOf[BooleanType])
+ ErrorReporting.fatal("Parameter should be boolean", param.position)
+ }
+ }
+
+ def assertAssignmentLike(ctx: CompilationContext, params: List[Expression]): (LhsExpression, Expression, Int) = {
+ if (params.length != 2) {
+ ErrorReporting.fatal("sfgdgfsd", None)
+ }
+ (params.head, params(1)) match {
+ case (l: LhsExpression, r: Expression) =>
+ val lsize = getExpressionType(ctx, l).size
+ val rsize = getExpressionType(ctx, r).size
+ if (lsize < rsize) {
+ ErrorReporting.error("Left-hand-side expression is of smaller type than the right-hand-side expression", l.position)
+ }
+ (l, r, lsize)
+ case (err: Expression, _) => ErrorReporting.fatal("Invalid left-hand-side expression", err.position)
+ }
+ }
+
+ def compile(ctx: CompilationContext, expr: Expression, exprTypeAndVariable: Option[(Type, Variable)], branches: BranchSpec): List[AssemblyLine] = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val w = env.get[Type]("word")
+ expr match {
+ case HalfWordExpression(expression, _) => ??? // TODO
+ case LiteralExpression(value, size) =>
+ exprTypeAndVariable.fold(noop) { case (exprType, target) =>
+ assertCompatible(exprType, target.typ)
+ compileConstant(ctx, NumericConstant(value, size), target)
+ }
+ case VariableExpression(name) =>
+ exprTypeAndVariable.fold(noop) { case (exprType, target) =>
+ assertCompatible(exprType, target.typ)
+ env.eval(expr).map(c => compileConstant(ctx, c, target)).getOrElse {
+ env.get[TypedThing](name) match {
+ case source: VariableInMemory =>
+ target match {
+ case RegisterVariable(Register.A, _) => List(AssemblyLine.absolute(LDA, source))
+ case RegisterVariable(Register.X, _) => List(AssemblyLine.absolute(LDX, source))
+ case RegisterVariable(Register.Y, _) => List(AssemblyLine.absolute(LDY, source))
+ case RegisterVariable(Register.AX, _) =>
+ exprType.size match {
+ case 1 => if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.absolute(LDA, source),
+ AssemblyLine.implied(PHA),
+ AssemblyLine.immediate(ORA, 0x7F),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label),
+ AssemblyLine.implied(TAX),
+ AssemblyLine.implied(PLA))
+ } else List(
+ AssemblyLine.absolute(LDA, source),
+ AssemblyLine.immediate(LDX, 0))
+ case 2 => List(
+ AssemblyLine.absolute(LDA, source),
+ AssemblyLine.absolute(LDX, source, 1))
+ }
+ case RegisterVariable(Register.AY, _) =>
+ exprType.size match {
+ case 1 => if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.absolute(LDA, source),
+ AssemblyLine.implied(PHA),
+ AssemblyLine.immediate(ORA, 0x7F),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label),
+ AssemblyLine.implied(TAY),
+ AssemblyLine.implied(PLA))
+ } else {
+ List(
+ AssemblyLine.absolute(LDA, source),
+ AssemblyLine.immediate(LDY, 0))
+ }
+ case 2 => List(
+ AssemblyLine.absolute(LDA, source),
+ AssemblyLine.absolute(LDY, source, 1))
+ }
+ case RegisterVariable(Register.XA, _) =>
+ exprType.size match {
+ case 1 => if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.absolute(LDX, source),
+ AssemblyLine.implied(TXA),
+ AssemblyLine.immediate(ORA, 0x7F),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label))
+ } else List(
+ AssemblyLine.absolute(LDX, source),
+ AssemblyLine.immediate(LDA, 0))
+ case 2 => List(
+ AssemblyLine.absolute(LDX, source),
+ AssemblyLine.absolute(LDA, source, 1))
+ }
+ case RegisterVariable(Register.YA, _) =>
+ exprType.size match {
+ case 1 => if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.absolute(LDY, source),
+ AssemblyLine.implied(TYA),
+ AssemblyLine.immediate(ORA, 0x7F),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label))
+ } else List(
+ AssemblyLine.absolute(LDY, source),
+ AssemblyLine.immediate(LDA, 0))
+ case 2 => List(
+ AssemblyLine.absolute(LDY, source),
+ AssemblyLine.absolute(LDA, source, 1))
+ }
+ case target: VariableInMemory =>
+ if (exprType.size > target.typ.size) {
+ ErrorReporting.error(s"Variable `$target.name` is too small", expr.position)
+ Nil
+ } else {
+ val copy = List.tabulate(exprType.size)(i => List(AssemblyLine.absolute(LDA, source, i), AssemblyLine.absolute(STA, target, i)))
+ val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.immediate(ORA, 0x7F),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label)) ++
+ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absolute(STA, target, i + exprType.size))
+ } else {
+ AssemblyLine.immediate(LDA, 0) ::
+ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absolute(STA, target, i + exprType.size))
+ }
+ copy.flatten ++ extend
+ }
+ case target: StackVariable =>
+ if (exprType.size > target.typ.size) {
+ ErrorReporting.error(s"Variable `$target.name` is too small", expr.position)
+ Nil
+ } else {
+ val copy = List.tabulate(exprType.size)(i => List(AssemblyLine.absolute(LDA, source, i), AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i)))
+ val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.immediate(ORA, 0x7F),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label)) ++
+ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i + exprType.size))
+ } else {
+ AssemblyLine.immediate(LDA, 0) ::
+ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i + exprType.size))
+ }
+ AssemblyLine.implied(TSX) :: (copy.flatten ++ extend)
+ }
+ }
+ case source@StackVariable(_, sourceType, offset) =>
+ target match {
+ case RegisterVariable(Register.A, _) => List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset))
+ case RegisterVariable(Register.X, _) => List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset), AssemblyLine.implied(TAX))
+ case RegisterVariable(Register.Y, _) => List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(LDY, offset + ctx.extraStackOffset))
+ case RegisterVariable(Register.AX, _) =>
+ exprType.size match {
+ case 1 => if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset),
+ AssemblyLine.implied(PHA),
+ AssemblyLine.immediate(ORA, 0x7F),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label),
+ AssemblyLine.implied(TAX),
+ AssemblyLine.implied(PLA))
+ } else List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset),
+ AssemblyLine.immediate(LDX, 0))
+ case 2 => List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset),
+ AssemblyLine.implied(PHA),
+ AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + 1),
+ AssemblyLine.implied(TAX),
+ AssemblyLine.implied(PLA))
+ }
+ case RegisterVariable(Register.AY, _) =>
+ exprType.size match {
+ case 1 => if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ ??? // TODO
+ } else {
+ List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset),
+ AssemblyLine.immediate(LDY, 0))
+ }
+ case 2 => List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset),
+ AssemblyLine.absoluteX(LDY, offset + ctx.extraStackOffset + 1))
+ }
+ case RegisterVariable(Register.XA, _) =>
+ ??? // TODO
+ case RegisterVariable(Register.YA, _) =>
+ exprType.size match {
+ case 1 => if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ ??? // TODO
+ } else {
+ List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(LDY, offset + ctx.extraStackOffset),
+ AssemblyLine.immediate(LDA, 0))
+ }
+ case 2 => List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(LDY, offset + ctx.extraStackOffset),
+ AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + 1))
+ }
+ case target: VariableInMemory =>
+ if (exprType.size > target.typ.size) {
+ ErrorReporting.error(s"Variable `$target.name` is too small", expr.position)
+ Nil
+ } else {
+ val copy = List.tabulate(exprType.size)(i => List(AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + i), AssemblyLine.absolute(STA, target, i)))
+ val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.immediate(ORA, 0x7F),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label)) ++
+ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absolute(STA, target, i + exprType.size))
+ } else {
+ AssemblyLine.immediate(LDA, 0) ::
+ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absolute(STA, target, i + exprType.size))
+ }
+ AssemblyLine.implied(TSX) :: (copy.flatten ++ extend)
+ }
+ case target: StackVariable =>
+ if (exprType.size > target.typ.size) {
+ ErrorReporting.error(s"Variable `$target.name` is too small", expr.position)
+ Nil
+ } else {
+ val copy = List.tabulate(exprType.size)(i => List(AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + i), AssemblyLine.absoluteX(STA, target.baseOffset + i)))
+ val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.immediate(ORA, 0x7F),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label)) ++
+ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i + exprType.size))
+ } else {
+ AssemblyLine.immediate(LDA, 0) ::
+ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i + exprType.size))
+ }
+ AssemblyLine.implied(TSX) :: (copy.flatten ++ extend)
+ }
+ }
+ case source@ConstantThing(_, value, _) =>
+ compileConstant(ctx, value, target)
+ }
+ }
+ }
+ case IndexedExpression(arrayName, indexExpr) =>
+ val array = env.getArrayOrPointer(arrayName)
+ // TODO: check
+ val (variableIndex, constantIndex) = env.evalVariableAndConstantSubParts(indexExpr)
+ exprTypeAndVariable.fold(noop) { case (exprType, target) =>
+
+ val register = target match {
+ case RegisterVariable(r, _) => r
+ case _ => Register.A
+ }
+ val suffix = target match {
+ case RegisterVariable(_, _) => Nil
+ case target: VariableInMemory =>
+ if (target.typ.size == 1) {
+ AssemblyLine.variable(ctx, STA, target)
+ }
+ else if (target.typ.isSigned) {
+ val label = nextLabel("sx")
+ AssemblyLine.variable(ctx, STA, target) ++
+ List(
+ AssemblyLine.immediate(ORA, 0x7f),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label)) ++
+ List.tabulate(target.typ.size - 1)(i => AssemblyLine.variable(ctx, STA, target, i + 1)).flatten
+ } else {
+ AssemblyLine.variable(ctx, STA, target) ++
+ List(AssemblyLine.immediate(LDA, 0)) ++
+ List.tabulate(target.typ.size - 1)(i => AssemblyLine.variable(ctx, STA, target, i + 1)).flatten
+ }
+ }
+ val load = register match {
+ case Register.A | Register.AX | Register.AY => LDA
+ case Register.X => LDX
+ case Register.Y => LDY
+ }
+
+ def loadFromArrayAtUnknownIndex(variableIndex: Expression, arrayAddr: Constant) = {
+ // TODO check typ
+ val indexRegister = if (register == Register.Y) Register.X else Register.Y
+ val calculatingIndex = compile(ctx, variableIndex, Some(b, RegisterVariable(indexRegister, b)), NoBranching)
+ indexRegister match {
+ case Register.Y =>
+ calculatingIndex ++ List(AssemblyLine.absoluteY(load, arrayAddr + constantIndex))
+ case Register.X =>
+ calculatingIndex ++ List(AssemblyLine.absoluteX(load, arrayAddr + constantIndex))
+ }
+ }
+
+ val result = (array, variableIndex) match {
+ case (a: ConstantThing, None) =>
+ List(AssemblyLine.absolute(load, env.genRelativeVariable(a.value + constantIndex, b, zeropage = false)))
+ case (a: ConstantThing, Some(v)) =>
+ loadFromArrayAtUnknownIndex(v, a.value)
+
+ case (a: MlArray, None) =>
+ List(AssemblyLine.absolute(load, env.genRelativeVariable(a.toAddress + constantIndex, b, zeropage = false)))
+ case (a: MlArray, Some(v)) =>
+ loadFromArrayAtUnknownIndex(v, a.toAddress)
+
+ // TODO: see above
+ case (RelativeVariable(_, arrayAddr, typ, _), None) =>
+ List(AssemblyLine.absolute(load, env.genRelativeVariable(arrayAddr + constantIndex, b, zeropage = false)))
+ case (RelativeVariable(_, arrayAddr, typ, _), Some(v)) =>
+ loadFromArrayAtUnknownIndex(v, arrayAddr)
+
+ // TODO: see above
+ case (pointerVariable@MemoryVariable(_, typ, _), None) =>
+ register match {
+ case Register.A =>
+ List(AssemblyLine.immediate(LDY, constantIndex), AssemblyLine.indexedY(LDA, pointerVariable))
+ case Register.Y =>
+ List(AssemblyLine.immediate(LDY, constantIndex), AssemblyLine.indexedY(LDA, pointerVariable), AssemblyLine.implied(TAY))
+ case Register.X =>
+ List(AssemblyLine.immediate(LDY, constantIndex), AssemblyLine.indexedY(LDX, pointerVariable))
+ }
+ case (pointerVariable@MemoryVariable(_, typ, _), Some(_)) =>
+ val calculatingIndex = compile(ctx, indexExpr, Some(b, RegisterVariable(Register.Y, b)), NoBranching)
+ register match {
+ case Register.A =>
+ calculatingIndex :+ AssemblyLine.indexedY(LDA, pointerVariable)
+ case Register.X =>
+ calculatingIndex :+ AssemblyLine.indexedY(LDX, pointerVariable)
+ case Register.Y =>
+ calculatingIndex ++ List(AssemblyLine.indexedY(LDA, pointerVariable), AssemblyLine.implied(TAY))
+ }
+ }
+ register match {
+ case Register.A | Register.X | Register.Y => result ++ suffix
+ case Register.AX => result :+ AssemblyLine.immediate(LDX, 0)
+ case Register.AY => result :+ AssemblyLine.immediate(LDY, 0)
+ }
+ }
+ case SumExpression(params, decimal) =>
+ assertAllBytesForSum("Long addition not supported", ctx, params)
+ val calculate = BuiltIns.compileAddition(ctx, params, decimal = decimal)
+ val store = expressionStorageFromAX(ctx, exprTypeAndVariable, expr.position)
+ calculate ++ store
+ case SeparateBytesExpression(h, l) =>
+ exprTypeAndVariable.fold {
+ // TODO: order?
+ compile(ctx, l, None, branches) ++ compile(ctx, h, None, branches)
+ } { case (exprType, target) =>
+ assertCompatible(exprType, target.typ)
+ target match {
+ case RegisterVariable(Register.A | Register.X | Register.Y, _) => compile(ctx, l, exprTypeAndVariable, branches)
+ case RegisterVariable(Register.AX, _) =>
+ compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), branches) ++
+ compile(ctx, h, Some(b -> RegisterVariable(Register.X, b)), branches)
+ case RegisterVariable(Register.AY, _) =>
+ compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), branches) ++
+ compile(ctx, h, Some(b -> RegisterVariable(Register.Y, b)), branches)
+ case RegisterVariable(Register.XA, _) =>
+ compile(ctx, l, Some(b -> RegisterVariable(Register.X, b)), branches) ++
+ compile(ctx, h, Some(b -> RegisterVariable(Register.A, b)), branches)
+ case RegisterVariable(Register.YA, _) =>
+ compile(ctx, l, Some(b -> RegisterVariable(Register.Y, b)), branches) ++
+ compile(ctx, h, Some(b -> RegisterVariable(Register.A, b)), branches)
+ case target: VariableInMemory =>
+ target.typ.size match {
+ case 1 =>
+ ErrorReporting.error(s"Variable `$target.name` cannot hold a word", expr.position)
+ Nil
+ case 2 =>
+ compile(ctx, l, Some(b -> env.genRelativeVariable(target.toAddress, b, zeropage = target.zeropage)), branches) ++
+ compile(ctx, h, Some(b -> env.genRelativeVariable(target.toAddress + 1, b, zeropage = target.zeropage)), branches)
+ }
+ case target: StackVariable =>
+ target.typ.size match {
+ case 1 =>
+ ErrorReporting.error(s"Variable `$target.name` cannot hold a word", expr.position)
+ Nil
+ case 2 =>
+ compile(ctx, l, Some(b -> StackVariable("", b, target.baseOffset + ctx.extraStackOffset)), branches) ++
+ compile(ctx, h, Some(b -> StackVariable("", b, target.baseOffset + ctx.extraStackOffset + 1)), branches)
+ }
+ }
+ }
+
+ case f@FunctionCallExpression(name, params) =>
+ val calculate = name match {
+ case "not" =>
+ assertBool(ctx, params, 1)
+ compile(ctx, params.head, exprTypeAndVariable, branches.flip)
+ case "&&" =>
+ assertBool(ctx, params, 2)
+ val a = params.head
+ val b = params(1)
+ branches match {
+ case BranchIfFalse(_) =>
+ compile(ctx, a, exprTypeAndVariable, branches) ++ compile(ctx, b, exprTypeAndVariable, branches)
+ case _ =>
+ val skip = nextLabel("an")
+ compile(ctx, a, exprTypeAndVariable, BranchIfFalse(skip)) ++
+ compile(ctx, b, exprTypeAndVariable, branches) ++
+ List(AssemblyLine.label(skip))
+ }
+ case "||" =>
+ assertBool(ctx, params, 2)
+ val a = params.head
+ val b = params(1)
+ branches match {
+ case BranchIfTrue(_) =>
+ compile(ctx, a, exprTypeAndVariable, branches) ++ compile(ctx, b, exprTypeAndVariable, branches)
+ case _ =>
+ val skip = nextLabel("or")
+ compile(ctx, a, exprTypeAndVariable, BranchIfTrue(skip)) ++
+ compile(ctx, b, exprTypeAndVariable, branches) ++
+ List(AssemblyLine.label(skip))
+ }
+ case "^^" => ???
+ case "&" =>
+ assertAllBytes("Long bit ops not supported", ctx, params)
+ BuiltIns.compileBitOps(AND, ctx, params)
+ case "*" =>
+ assertAllBytes("Long multiplication not supported", ctx, params)
+ BuiltIns.compileByteMultiplication(ctx, params)
+ case "|" =>
+ assertAllBytes("Long bit ops not supported", ctx, params)
+ BuiltIns.compileBitOps(ORA, ctx, params)
+ case "^" =>
+ assertAllBytes("Long bit ops not supported", ctx, params)
+ BuiltIns.compileBitOps(EOR, ctx, params)
+ case ">>>>" =>
+ val (l, r, 2) = assertBinary(ctx, params)
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileNonetOps(ctx, v, r)
+ }
+ case "<<" =>
+ assertAllBytes("Long shift ops not supported", ctx, params)
+ val (l, r, 1) = assertBinary(ctx, params)
+ BuiltIns.compileShiftOps(ASL, ctx, l, r)
+ case ">>" =>
+ assertAllBytes("Long shift ops not supported", ctx, params)
+ val (l, r, 1) = assertBinary(ctx, params)
+ BuiltIns.compileShiftOps(LSR, ctx, l, r)
+ case "<" =>
+ // TODO: signed
+ val (l, r, size, signed) = assertComparison(ctx, params)
+ size match {
+ case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches)
+ case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches)
+ }
+ case ">=" =>
+ // TODO: signed
+ val (l, r, size, signed) = assertComparison(ctx, params)
+ size match {
+ case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches)
+ case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches)
+ }
+ case ">" =>
+ // TODO: signed
+ val (l, r, size, signed) = assertComparison(ctx, params)
+ size match {
+ case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches)
+ case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches)
+ }
+ case "<=" =>
+ // TODO: signed
+ val (l, r, size, signed) = assertComparison(ctx, params)
+ size match {
+ case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches)
+ case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches)
+ }
+ case "==" =>
+ val (l, r, size) = assertBinary(ctx, params)
+ size match {
+ case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.Equal, l, r, branches)
+ case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.Equal, l, r, branches)
+ }
+ case "!=" =>
+ val (l, r, size) = assertBinary(ctx, params)
+ size match {
+ case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.NotEqual, l, r, branches)
+ case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.NotEqual, l, r, branches)
+ }
+ case "+=" =>
+ val (l, r, size) = assertAssignmentLike(ctx, params)
+ size match {
+ case 1 =>
+ BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = false, decimal = false)
+ case 2 =>
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = false, decimal = false)
+ }
+ case i if i > 2 =>
+ l match {
+ case v: VariableExpression =>
+ BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = false, decimal = false)
+ }
+ }
+ case "-=" =>
+ val (l, r, size) = assertAssignmentLike(ctx, params)
+ size match {
+ case 1 =>
+ BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = true, decimal = false)
+ case 2 =>
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = true, decimal = false)
+ }
+ case i if i > 2 =>
+ l match {
+ case v: VariableExpression =>
+ BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = true, decimal = false)
+ }
+ }
+ case "+'=" =>
+ val (l, r, size) = assertAssignmentLike(ctx, params)
+ size match {
+ case 1 =>
+ BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = false, decimal = true)
+ case 2 =>
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = false, decimal = true)
+ }
+ case i if i > 2 =>
+ l match {
+ case v: VariableExpression =>
+ BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = false, decimal = true)
+ }
+ }
+ case "-'=" =>
+ val (l, r, size) = assertAssignmentLike(ctx, params)
+ size match {
+ case 1 =>
+ BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = true, decimal = true)
+ case 2 =>
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = true, decimal = true)
+ }
+ case i if i > 2 =>
+ l match {
+ case v: VariableExpression =>
+ BuiltIns.compileInPlaceWordOrLongAddition(ctx, v, r, subtract = true, decimal = true)
+ }
+ }
+ case "<<=" =>
+ val (l, r, size) = assertAssignmentLike(ctx, params)
+ size match {
+ case 1 =>
+ BuiltIns.compileInPlaceByteShiftOps(ASL, ctx, l, r)
+ case i if i >= 2 =>
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileInPlaceWordOrLongShiftOps(ctx, v, r, aslRatherThanLsr = true)
+ }
+ }
+ case ">>=" =>
+ val (l, r, size) = assertAssignmentLike(ctx, params)
+ size match {
+ case 1 =>
+ BuiltIns.compileInPlaceByteShiftOps(LSR, ctx, l, r)
+ case i if i >= 2 =>
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileInPlaceWordOrLongShiftOps(ctx, v, r, aslRatherThanLsr = false)
+ }
+ }
+ case "*=" =>
+ assertAllBytes("Long multiplication not supported", ctx, params)
+ val (l, r, 1) = assertAssignmentLike(ctx, params)
+ BuiltIns.compileInPlaceByteMultiplication(ctx, l, r)
+ case "&=" =>
+ val (l, r, size) = assertAssignmentLike(ctx, params)
+ size match {
+ case 1 =>
+ BuiltIns.compileInPlaceByteBitOp(ctx, l, r, AND)
+ case i if i >= 2 =>
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileInPlaceWordOrLongBitOp(ctx, l, r, AND)
+ }
+ }
+ case "^=" =>
+ val (l, r, size) = assertAssignmentLike(ctx, params)
+ size match {
+ case 1 =>
+ BuiltIns.compileInPlaceByteBitOp(ctx, l, r, EOR)
+ case i if i >= 2 =>
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileInPlaceWordOrLongBitOp(ctx, l, r, EOR)
+ }
+ }
+ case "|=" =>
+ val (l, r, size) = assertAssignmentLike(ctx, params)
+ size match {
+ case 1 =>
+ BuiltIns.compileInPlaceByteBitOp(ctx, l, r, ORA)
+ case i if i >= 2 =>
+ l match {
+ case v: LhsExpression =>
+ BuiltIns.compileInPlaceWordOrLongBitOp(ctx, l, r, ORA)
+ }
+ }
+ case _ =>
+ lookupFunction(ctx, f) match {
+ case function: InlinedFunction =>
+ inlineFunction(function, params, Some(ctx)).map {
+ case AssemblyStatement(opcode, addrMode, expression, elidable) =>
+ val param = env.eval(expression).getOrElse {
+ expression match {
+ case VariableExpression(name) => env.get[ThingInMemory](name).toAddress
+ case _ =>
+ ErrorReporting.error("Inlining failed due to non-constant things", expression.position)
+ Constant.Zero
+ }
+ }
+ AssemblyLine(opcode, addrMode, param, elidable)
+
+ }
+ case function: EmptyFunction =>
+ ??? // TODO: type conversion?
+ case function: FunctionInMemory =>
+ function match {
+ case nf: NormalFunction =>
+ if (nf.interrupt) {
+ ErrorReporting.error(s"Calling an interrupt function `${f.functionName}`", expr.position)
+ }
+ case _ => ()
+ }
+ val result = function.params match {
+ case AssemblyParamSignature(paramConvs) =>
+ val pairs = params.zip(paramConvs)
+ val secondViaMemory = pairs.flatMap {
+ case (paramExpr, AssemblyParam(typ, paramVar: VariableInMemory, AssemblyParameterPassingBehaviour.Copy)) =>
+ compile(ctx, paramExpr, Some(typ -> paramVar), NoBranching)
+ case _ => Nil
+ }
+ val thirdViaRegisters = pairs.flatMap {
+ case (paramExpr, AssemblyParam(typ, paramVar@RegisterVariable(register, _), AssemblyParameterPassingBehaviour.Copy)) =>
+ compile(ctx, paramExpr, Some(typ -> paramVar), NoBranching)
+
+ // TODO: fix
+ case _ => Nil
+ }
+ secondViaMemory ++ thirdViaRegisters :+ AssemblyLine.absolute(JSR, function)
+ case NormalParamSignature(paramVars) =>
+ params.zip(paramVars).flatMap {
+ case (paramExpr, paramVar) =>
+ val callCtx = callingContext(ctx, paramVar)
+ compileAssignment(callCtx, paramExpr, VariableExpression(paramVar.name))
+ } ++ List(AssemblyLine.absolute(JSR, function))
+ }
+ result
+ }
+ }
+ val store = expressionStorageFromAX(ctx, exprTypeAndVariable, expr.position)
+ calculate ++ store
+ }
+ }
+
+ def expressionStorageFromAX(ctx: CompilationContext, exprTypeAndVariable: Option[(Type, Variable)], position: Option[Position]): List[AssemblyLine] = {
+ exprTypeAndVariable.fold(noop) {
+ case (VoidType, _) => ???
+ case (_, RegisterVariable(Register.A, _)) => noop
+ case (_, RegisterVariable(Register.X, _)) => List(AssemblyLine.implied(TAX))
+ case (_, RegisterVariable(Register.Y, _)) => List(AssemblyLine.implied(TAY))
+ case (_, RegisterVariable(Register.AX, _)) =>
+ // TODO: sign extension
+ noop
+ case (_, RegisterVariable(Register.XA, _)) =>
+ // TODO: sign extension
+ if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) {
+ List(
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(PHX),
+ AssemblyLine.implied(PLA),
+ AssemblyLine.implied(PLX))
+ } else {
+ List(
+ AssemblyLine.implied(TAY),
+ AssemblyLine.implied(TXA),
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(TYA),
+ AssemblyLine.implied(TAX),
+ AssemblyLine.implied(PLA)) // fuck this shit
+ }
+ case (_, RegisterVariable(Register.YA, _)) => {
+ // TODO: sign extension
+ List(
+ AssemblyLine.implied(TAY),
+ AssemblyLine.implied(TXA))
+ }
+ case (_, RegisterVariable(Register.AY, _)) =>
+ // TODO: sign extension
+ List(
+ AssemblyLine.implied(PHA),
+ AssemblyLine.implied(TXA),
+ AssemblyLine.implied(TAY),
+ AssemblyLine.implied(PLA))
+ case (t, v: VariableInMemory) => t.size match {
+ case 1 => v.typ.size match {
+ case 1 =>
+ List(AssemblyLine.absolute(STA, v))
+ case s if s > 1 =>
+ if (t.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.absolute(STA, v),
+ AssemblyLine.immediate(ORA, 0x7f),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label)) ++ List.tabulate(s - 1)(i => AssemblyLine.absolute(STA, v, i + 1))
+ } else {
+ List(
+ AssemblyLine.absolute(STA, v),
+ AssemblyLine.immediate(LDA, 0)) ++
+ List.tabulate(s - 1)(i => AssemblyLine.absolute(STA, v, i + 1))
+ }
+ }
+ case 2 => v.typ.size match {
+ case 1 =>
+ ErrorReporting.error(s"Variable `${v.name}` cannot hold a word", position)
+ Nil
+ case 2 =>
+ List(AssemblyLine.absolute(STA, v), AssemblyLine.absolute(STX, v, 1))
+ case s if s > 2 =>
+ if (t.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.absolute(STA, v),
+ AssemblyLine.absolute(STX, v, 1),
+ AssemblyLine.implied(TXA),
+ AssemblyLine.immediate(ORA, 0x7f),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label)) ++ List.tabulate(s - 2)(i => AssemblyLine.absolute(STA, v, i + 2))
+ } else {
+ List(
+ AssemblyLine.absolute(STA, v),
+ AssemblyLine.absolute(STX, v, 1),
+ AssemblyLine.immediate(LDA, 0)) ++
+ List.tabulate(s - 2)(i => AssemblyLine.absolute(STA, v, i + 2))
+ }
+ }
+ }
+ case (t, v: StackVariable) => t.size match {
+ case 1 => v.typ.size match {
+ case 1 =>
+ List(AssemblyLine.implied(TSX), AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset))
+ case s if s > 1 =>
+ if (t.isSigned) {
+ val label = nextLabel("sx")
+ List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset),
+ AssemblyLine.immediate(ORA, 0x7f),
+ AssemblyLine.relative(BMI, label),
+ AssemblyLine.immediate(LDA, 0),
+ AssemblyLine.label(label)) ++ List.tabulate(s - 1)(i => AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset + i + 1))
+ } else {
+ List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset),
+ AssemblyLine.immediate(LDA, 0)) ++
+ List.tabulate(s - 1)(i => AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset + i + 1))
+ }
+ }
+ case 2 => v.typ.size match {
+ case 1 =>
+ ErrorReporting.error(s"Variable `${v.name}` cannot hold a word", position)
+ Nil
+ case 2 =>
+ List(
+ AssemblyLine.implied(TAY),
+ AssemblyLine.implied(TXA),
+ AssemblyLine.implied(TSX),
+ AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset + 1),
+ AssemblyLine.implied(TYA),
+ AssemblyLine.absoluteX(STA, v.baseOffset + ctx.extraStackOffset))
+ case s if s > 2 => ???
+ }
+ }
+ }
+ }
+
+ private def assertAllBytesForSum(msg: String, ctx: CompilationContext, params: List[(Boolean, Expression)]): Unit = {
+ if (params.exists { case (_, expr) => getExpressionType(ctx, expr).size != 1 }) {
+ ErrorReporting.fatal(msg, params.head._2.position)
+ }
+ }
+
+ private def assertAllBytes(msg: String, ctx: CompilationContext, params: List[Expression]): Unit = {
+ if (params.exists { expr => getExpressionType(ctx, expr).size != 1 }) {
+ ErrorReporting.fatal(msg, params.head.position)
+ }
+ }
+
+ def compileAssignment(ctx: CompilationContext, source: Expression, target: LhsExpression): List[AssemblyLine] = {
+ val env = ctx.env
+ val b = env.get[Type]("byte")
+ val w = env.get[Type]("word")
+ target match {
+ case VariableExpression(name) =>
+ val v = env.get[Variable](name, target.position)
+ // TODO check v.typ
+ compile(ctx, source, Some((getExpressionType(ctx, source), v)), NoBranching)
+ case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) =>
+ compile(ctx, source, Some(w, RegisterVariable(Register.AX, w)), NoBranching) ++
+ compileByteStorage(ctx, Register.A, l) ++ compileByteStorage(ctx, Register.X, h)
+ case SeparateBytesExpression(_, _) =>
+ ErrorReporting.error("Invalid left-hand-side use of `:`")
+ Nil
+ case _ =>
+ compile(ctx, source, Some(b, RegisterVariable(Register.A, b)), NoBranching) ++ compileByteStorage(ctx, Register.A, target)
+ }
+ }
+
+
+ def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): Chunk = {
+ SequenceChunk(statements.map(s => compile(ctx, s)))
+ }
+
+ def inlineFunction(i: InlinedFunction, params: List[Expression], cc: Option[CompilationContext]): List[ExecutableStatement] = {
+ var actualCode = i.code
+ i.params match {
+ case AssemblyParamSignature(assParams) =>
+ assParams.zip(params).foreach {
+ case (AssemblyParam(typ, Placeholder(ph, phType), AssemblyParameterPassingBehaviour.ByReference), actualParam) =>
+ actualParam match {
+ case VariableExpression(vname) =>
+ cc.foreach(_.env.get[ThingInMemory](vname))
+ case l: LhsExpression =>
+ // TODO: ??
+ cc.foreach(c => compileByteStorage(c, Register.A, l))
+ case _ =>
+ ErrorReporting.error("A non-assignable expression was passed to an inlineable function as a `ref` parameter", actualParam.position)
+ }
+ actualCode = actualCode.map {
+ case a@AssemblyStatement(_, _, expr, _) =>
+ a.copy(expression = expr.replaceVariable(ph, actualParam))
+ case x => x
+ }
+ case (AssemblyParam(typ, Placeholder(ph, phType), AssemblyParameterPassingBehaviour.ByConstant), actualParam) =>
+ cc.foreach(_.env.eval(actualParam).getOrElse(Constant.error("Non-constant expression was passed to an inlineable function as a `const` parameter", actualParam.position)))
+ actualCode = actualCode.map {
+ case a@AssemblyStatement(_, _, expr, _) =>
+ a.copy(expression = expr.replaceVariable(ph, actualParam))
+ case x => x
+ }
+ case (AssemblyParam(_, _, AssemblyParameterPassingBehaviour.Copy), actualParam) =>
+ ???
+ case (_, actualParam) =>
+ }
+ case NormalParamSignature(Nil) => i.code
+ case NormalParamSignature(normalParams) => ???
+ }
+ actualCode
+ }
+
+ def stackPointerFixAtBeginning(ctx: CompilationContext): List[AssemblyLine] = {
+ val m = ctx.function
+ if (m.stackVariablesSize == 0) return Nil
+ if (ctx.options.flag(CompilationFlag.EmitIllegals)) {
+ if (m.stackVariablesSize > 4)
+ return List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.immediate(LDA, 0xff),
+ AssemblyLine.immediate(SBX, m.stackVariablesSize),
+ AssemblyLine.implied(TXS))
+ }
+ List.fill(m.stackVariablesSize)(AssemblyLine.implied(PHA))
+ }
+
+ def stackPointerFixBeforeReturn(ctx: CompilationContext): List[AssemblyLine] = {
+ val m = ctx.function
+ if (m.stackVariablesSize == 0) return Nil
+
+ if (m.returnType.size == 0 && m.stackVariablesSize <= 2)
+ return List.fill(m.stackVariablesSize)(AssemblyLine.implied(PLA))
+
+ if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) {
+ if (m.returnType.size == 1 && m.stackVariablesSize <= 2) {
+ return List.fill(m.stackVariablesSize)(AssemblyLine.implied(PLX))
+ }
+ if (m.returnType.size == 2 && m.stackVariablesSize <= 2) {
+ return List.fill(m.stackVariablesSize)(AssemblyLine.implied(PLY))
+ }
+ }
+
+ if (ctx.options.flag(CompilationFlag.EmitIllegals)) {
+ if (m.returnType.size == 0 && m.stackVariablesSize > 4)
+ return List(
+ AssemblyLine.implied(TSX),
+ AssemblyLine.immediate(LDA, 0xff),
+ AssemblyLine.immediate(SBX, 256 - m.stackVariablesSize),
+ AssemblyLine.implied(TXS))
+ if (m.returnType.size == 1 && m.stackVariablesSize > 6)
+ return List(
+ AssemblyLine.implied(TAY),
+ AssemblyLine.implied(TSX),
+ AssemblyLine.immediate(LDA, 0xff),
+ AssemblyLine.immediate(SBX, 256 - m.stackVariablesSize),
+ AssemblyLine.implied(TXS),
+ AssemblyLine.implied(TYA))
+ }
+
+ AssemblyLine.implied(TSX) :: (List.fill(m.stackVariablesSize)(AssemblyLine.implied(INX)) :+ AssemblyLine.implied(TXS))
+ }
+
+ def compile(ctx: CompilationContext, statement: ExecutableStatement): Chunk = {
+ val env = ctx.env
+ val m = ctx.function
+ val b = env.get[Type]("byte")
+ val w = env.get[Type]("word")
+ val someRegisterA = Some(b, RegisterVariable(Register.A, b))
+ val someRegisterAX = Some(w, RegisterVariable(Register.AX, w))
+ val someRegisterYA = Some(w, RegisterVariable(Register.YA, w))
+ val returnInstructions = if (m.interrupt) {
+ if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) {
+ List(
+ AssemblyLine.implied(PLY),
+ AssemblyLine.implied(PLX),
+ AssemblyLine.implied(PLA),
+ AssemblyLine.implied(CLI),
+ AssemblyLine.implied(RTI))
+ } else {
+ List(
+ AssemblyLine.implied(PLA),
+ AssemblyLine.implied(TAY),
+ AssemblyLine.implied(PLA),
+ AssemblyLine.implied(TAX),
+ AssemblyLine.implied(PLA),
+ AssemblyLine.implied(CLI),
+ AssemblyLine.implied(RTI))
+ }
+ } else {
+ List(AssemblyLine.implied(RTS))
+ }
+ statement match {
+ case AssemblyStatement(o, a, x, e) =>
+ val c: Constant = x match {
+ // TODO: hmmm
+ case VariableExpression(name) =>
+ if (OpcodeClasses.ShortBranching(o) || o == JMP || o == LABEL) {
+ MemoryAddressConstant(Label(name))
+ } else{
+ env.eval(x).getOrElse(env.get[ThingInMemory](name, x.position).toAddress)
+ }
+ case _ =>
+ env.eval(x).getOrElse(Constant.error(s"`$x` is not a constant", x.position))
+ }
+ val actualAddrMode = if (OpcodeClasses.ShortBranching(o) && a == Absolute) Relative else a
+ LinearChunk(List(AssemblyLine(o, actualAddrMode, c, e)))
+ case Assignment(dest, source) =>
+ LinearChunk(compileAssignment(ctx, source, dest))
+ case ExpressionStatement(e@FunctionCallExpression(name, params)) =>
+ env.lookupFunction(name, params.map(p => getExpressionType(ctx, p) -> p)) match {
+ case Some(i: InlinedFunction) =>
+ compile(ctx, inlineFunction(i, params, Some(ctx)))
+ case _ =>
+ LinearChunk(compile(ctx, e, None, NoBranching))
+ }
+ case ExpressionStatement(e) =>
+ LinearChunk(compile(ctx, e, None, NoBranching))
+ case BlockStatement(s) =>
+ SequenceChunk(s.map(compile(ctx, _)))
+ case ReturnStatement(None) =>
+ // TODO: return type check
+ // TODO: better stackpointer fix
+ ctx.function.returnType match {
+ case _: BooleanType =>
+ LinearChunk(stackPointerFixBeforeReturn(ctx) ++ returnInstructions)
+ case t => t.size match {
+ case 0 =>
+ LinearChunk(stackPointerFixBeforeReturn(ctx) ++
+ List(AssemblyLine.discardAF(), AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions)
+ case 1 =>
+ LinearChunk(stackPointerFixBeforeReturn(ctx) ++
+ List(AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions)
+ case 2 =>
+ LinearChunk(stackPointerFixBeforeReturn(ctx) ++
+ List(AssemblyLine.discardYF()) ++ returnInstructions)
+ }
+ }
+ case ReturnStatement(Some(e)) =>
+ m.returnType match {
+ case _: BooleanType =>
+ m.returnType.size match {
+ case 0 =>
+ ErrorReporting.error("Cannot return anything from a void function", statement.position)
+ LinearChunk(stackPointerFixBeforeReturn(ctx) ++ returnInstructions)
+ case 1 =>
+ LinearChunk(compile(ctx, e, someRegisterA, NoBranching) ++ stackPointerFixBeforeReturn(ctx) ++ returnInstructions)
+ case 2 =>
+ LinearChunk(compile(ctx, e, someRegisterAX, NoBranching) ++ stackPointerFixBeforeReturn(ctx) ++ returnInstructions)
+ }
+ case _ =>
+ m.returnType.size match {
+ case 0 =>
+ ErrorReporting.error("Cannot return anything from a void function", statement.position)
+ LinearChunk(stackPointerFixBeforeReturn(ctx) ++ List(AssemblyLine.discardAF(), AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions)
+ case 1 =>
+ LinearChunk(compile(ctx, e, someRegisterA, NoBranching) ++ stackPointerFixBeforeReturn(ctx) ++ List(AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions)
+ case 2 =>
+ // TODO: ???
+ val stackPointerFix = stackPointerFixBeforeReturn(ctx)
+ if (stackPointerFix.isEmpty) {
+ LinearChunk(compile(ctx, e, someRegisterAX, NoBranching) ++ List(AssemblyLine.discardYF()) ++ returnInstructions)
+ } else {
+ LinearChunk(compile(ctx, e, someRegisterYA, NoBranching) ++
+ stackPointerFix ++
+ List(AssemblyLine.implied(TAX), AssemblyLine.implied(TYA), AssemblyLine.discardYF()) ++
+ returnInstructions)
+ }
+ }
+ }
+ case IfStatement(condition, thenPart, elsePart) =>
+ val condType = getExpressionType(ctx, condition)
+ val thenBlock = compile(ctx, thenPart)
+ val elseBlock = compile(ctx, elsePart)
+ val largeThenBlock = thenBlock.sizeInBytes > 100
+ val largeElseBlock = elseBlock.sizeInBytes > 100
+ condType match {
+ case ConstantBooleanType(_, true) => thenBlock
+ case ConstantBooleanType(_, false) => elseBlock
+ case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) =>
+ (thenPart, elsePart) match {
+ case (Nil, Nil) => EmptyChunk
+ case (Nil, _) =>
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching))
+ if (largeElseBlock) {
+ val middle = nextLabel("el")
+ val end = nextLabel("fi")
+ SequenceChunk(List(conditionBlock, branchChunk(jumpIfFalse, middle), jmpChunk(end), labelChunk(middle), elseBlock, labelChunk(end)))
+ } else {
+ val end = nextLabel("fi")
+ SequenceChunk(List(conditionBlock, branchChunk(jumpIfTrue, end), elseBlock, labelChunk(end)))
+ }
+ case (_, Nil) =>
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching))
+ if (largeThenBlock) {
+ val middle = nextLabel("th")
+ val end = nextLabel("fi")
+ SequenceChunk(List(conditionBlock, branchChunk(jumpIfTrue, middle), jmpChunk(end), labelChunk(middle), elseBlock, labelChunk(end)))
+ } else {
+ val end = nextLabel("fi")
+ SequenceChunk(List(conditionBlock, branchChunk(jumpIfFalse, end), thenBlock, labelChunk(end)))
+ }
+ case _ =>
+ // TODO: large blocks
+ if (largeElseBlock || largeThenBlock) ErrorReporting.error("Large blocks in if statement", statement.position)
+ val middle = nextLabel("el")
+ val end = nextLabel("fi")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching))
+ SequenceChunk(List(conditionBlock, branchChunk(jumpIfFalse, middle), thenBlock, jmpChunk(end), labelChunk(middle), elseBlock, labelChunk(end)))
+ }
+ case BuiltInBooleanType =>
+ (thenPart, elsePart) match {
+ case (Nil, Nil) => EmptyChunk
+ case (Nil, _) =>
+ val end = nextLabel("fi")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfTrue(end)))
+ SequenceChunk(List(conditionBlock, elseBlock, labelChunk(end)))
+ case (_, Nil) =>
+ val end = nextLabel("fi")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfFalse(end)))
+ SequenceChunk(List(conditionBlock, thenBlock, labelChunk(end)))
+ case _ =>
+ val middle = nextLabel("el")
+ val end = nextLabel("fi")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfFalse(middle)))
+ SequenceChunk(List(conditionBlock, thenBlock, jmpChunk(end), labelChunk(middle), elseBlock, labelChunk(end)))
+ }
+ case _ =>
+ ErrorReporting.error(s"Illegal type for a condition: `$condType`", condition.position)
+ EmptyChunk
+ }
+ case WhileStatement(condition, bodyPart) =>
+ val condType = getExpressionType(ctx, condition)
+ val bodyBlock = compile(ctx, bodyPart)
+ val largeBodyBlock = bodyBlock.sizeInBytes > 100
+ condType match {
+ case ConstantBooleanType(_, true) =>
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching))
+ val start = nextLabel("wh")
+ SequenceChunk(List(labelChunk(start), bodyBlock, jmpChunk(start)))
+ case ConstantBooleanType(_, false) => EmptyChunk
+ case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) =>
+ if (largeBodyBlock) {
+ val start = nextLabel("wh")
+ val middle = nextLabel("he")
+ val end = nextLabel("ew")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching))
+ SequenceChunk(List(labelChunk(start), conditionBlock, branchChunk(jumpIfTrue, middle), jmpChunk(end), bodyBlock, jmpChunk(start), labelChunk(end)))
+ } else {
+ val start = nextLabel("wh")
+ val end = nextLabel("ew")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching))
+ SequenceChunk(List(labelChunk(start), conditionBlock, branchChunk(jumpIfFalse, end), bodyBlock, jmpChunk(start), labelChunk(end)))
+ }
+ case BuiltInBooleanType =>
+ if (largeBodyBlock) {
+ val start = nextLabel("wh")
+ val middle = nextLabel("he")
+ val end = nextLabel("ew")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfTrue(middle)))
+ SequenceChunk(List(labelChunk(start), conditionBlock, jmpChunk(end), labelChunk(middle), bodyBlock, jmpChunk(start), labelChunk(end)))
+ } else {
+ val start = nextLabel("wh")
+ val end = nextLabel("ew")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfFalse(end)))
+ SequenceChunk(List(labelChunk(start), conditionBlock, bodyBlock, jmpChunk(start), labelChunk(end)))
+ }
+ case _ =>
+ ErrorReporting.error(s"Illegal type for a condition: `$condType`", condition.position)
+ EmptyChunk
+ }
+ case DoWhileStatement(bodyPart, condition) =>
+ val condType = getExpressionType(ctx, condition)
+ val bodyBlock = compile(ctx, bodyPart)
+ val largeBodyBlock = bodyBlock.sizeInBytes > 100
+ condType match {
+ case ConstantBooleanType(_, true) =>
+ val start = nextLabel("do")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching))
+ SequenceChunk(List(labelChunk(start), bodyBlock, jmpChunk(start)))
+ case ConstantBooleanType(_, false) => bodyBlock
+ case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) =>
+ val start = nextLabel("do")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, NoBranching))
+ if (largeBodyBlock) {
+ val end = nextLabel("od")
+ SequenceChunk(List(labelChunk(start), bodyBlock, conditionBlock, branchChunk(jumpIfFalse, end), jmpChunk(start), labelChunk(end)))
+ } else {
+ SequenceChunk(List(labelChunk(start), bodyBlock, conditionBlock, branchChunk(jumpIfTrue, start)))
+ }
+ case BuiltInBooleanType =>
+ val start = nextLabel("do")
+ if (largeBodyBlock) {
+ val end = nextLabel("od")
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfFalse(end)))
+ SequenceChunk(List(labelChunk(start), bodyBlock, conditionBlock, jmpChunk(start), labelChunk(end)))
+ } else {
+ val conditionBlock = LinearChunk(compile(ctx, condition, someRegisterA, BranchIfTrue(start)))
+ SequenceChunk(List(labelChunk(start), bodyBlock, conditionBlock))
+ }
+ case _ =>
+ ErrorReporting.error(s"Illegal type for a condition: `$condType`", condition.position)
+ EmptyChunk
+ }
+ case f@ForStatement(variable, start, end, direction, body) =>
+ // TODO: check sizes
+ // TODO: special faster cases
+ val vex = VariableExpression(f.variable)
+ val one = LiteralExpression(1, 1)
+ val increment = ExpressionStatement(FunctionCallExpression("+=", List(vex, one)))
+ val decrement = ExpressionStatement(FunctionCallExpression("-=", List(vex, one)))
+ (direction, env.eval(start), env.eval(end)) match {
+
+ case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e - 1 =>
+ compile(ctx, Assignment(vex, f.start) :: f.body)
+ case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s >= e =>
+ EmptyChunk
+
+ case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e =>
+ compile(ctx, Assignment(vex, f.start) :: f.body)
+ case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s > e =>
+ EmptyChunk
+
+ case (ForDirection.ParallelUntil, Some(NumericConstant(0, ssize)), Some(NumericConstant(e, _))) if e > 0 =>
+ compile(ctx, List(
+ Assignment(vex, f.end),
+ DoWhileStatement(decrement :: f.body, FunctionCallExpression("!=", List(vex, f.start)))
+ ))
+
+ case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s == e =>
+ compile(ctx, Assignment(vex, LiteralExpression(s, ssize)) :: f.body)
+ case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s < e =>
+ EmptyChunk
+ case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(0, esize))) if s > 0 =>
+ compile(ctx, List(
+ Assignment(vex, f.start),
+ DoWhileStatement(f.body :+ decrement, FunctionCallExpression("!=", List(vex, f.end)))
+ ))
+
+
+ case (ForDirection.Until | ForDirection.ParallelUntil, _, _) =>
+ compile(ctx, List(
+ Assignment(vex, f.start),
+ WhileStatement(
+ FunctionCallExpression("<", List(vex, f.end)),
+ f.body :+ increment),
+ ))
+ case (ForDirection.To | ForDirection.ParallelTo,_,_) =>
+ compile(ctx, List(
+ Assignment(vex, f.start),
+ WhileStatement(
+ FunctionCallExpression("<=", List(vex, f.end)),
+ f.body :+ increment),
+ ))
+ case (ForDirection.DownTo,_,_) =>
+ compile(ctx, List(
+ Assignment(vex, f.start),
+ IfStatement(
+ FunctionCallExpression(">=", List(vex, f.end)),
+ List(DoWhileStatement(
+ f.body :+ decrement,
+ FunctionCallExpression("!=", List(vex, f.end))
+ )),
+ Nil)
+ ))
+ }
+ // TODO
+ }
+ }
+
+ private def labelChunk(labelName: String) = {
+ LinearChunk(List(AssemblyLine.label(Label(labelName))))
+ }
+
+ private def jmpChunk(labelName: String) = {
+ LinearChunk(List(AssemblyLine.absolute(JMP, Label(labelName))))
+ }
+
+ private def branchChunk(opcode: Opcode.Value, labelName: String) = {
+ LinearChunk(List(AssemblyLine.relative(opcode, Label(labelName))))
+ }
+}
diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala
new file mode 100644
index 00000000..abc95f07
--- /dev/null
+++ b/src/main/scala/millfork/env/Constant.scala
@@ -0,0 +1,224 @@
+package millfork.env
+
+import millfork.error.ErrorReporting
+import millfork.node.Position
+
+object Constant {
+ val Zero: Constant = NumericConstant(0, 1)
+ val One: Constant = NumericConstant(1, 1)
+
+ def error(msg: String, position: Option[Position] = None): Constant = {
+ ErrorReporting.error(msg, position)
+ Zero
+ }
+
+ def minimumSize(value: Long): Int = if (value < -128 || value > 255) 2 else 1 // TODO !!!
+}
+
+import millfork.env.Constant.minimumSize
+import millfork.error.ErrorReporting
+import millfork.node.Position
+
+sealed trait Constant {
+
+ def asl(i: Constant): Constant = i match {
+ case NumericConstant(sa, _) => asl(sa.toInt)
+ case _ => CompoundConstant(MathOperator.Shl, this, i)
+ }
+
+ def asl(i: Int): Constant = CompoundConstant(MathOperator.Shl, this, NumericConstant(i, 1))
+
+ def requiredSize: Int
+
+ def +(that: Constant): Constant = CompoundConstant(MathOperator.Plus, this, that)
+
+ def -(that: Constant): Constant = CompoundConstant(MathOperator.Minus, this, that)
+
+ def +(that: Long): Constant = if (that == 0) this else this + NumericConstant(that, minimumSize(that))
+
+ def -(that: Long): Constant = this + (-that)
+
+ def loByte: Constant = {
+ if (requiredSize == 1) return this
+ HalfWordConstant(this, hi = false)
+ }
+
+ def hiByte: Constant = {
+ if (requiredSize == 1) Constant.Zero
+ else HalfWordConstant(this, hi = true)
+ }
+
+ def subbyte(index: Int): Constant = {
+ if (requiredSize <= index) Constant.Zero
+ else index match {
+ case 0 => loByte
+ case 1 => hiByte
+ case _ => SubbyteConstant(this, index)
+ }
+ }
+
+ def isLowestByteAlwaysEqual(i: Int) : Boolean = false
+
+ def quickSimplify: Constant = this
+}
+
+case class UnexpandedConstant(name: String, requiredSize: Int) extends Constant
+
+case class NumericConstant(value: Long, requiredSize: Int) extends Constant {
+ if (requiredSize == 1) {
+ if (value < -128 || value > 255) {
+ throw new IllegalArgumentException("The constant is too big")
+ }
+ }
+
+ override def isLowestByteAlwaysEqual(i: Int) : Boolean = (value & 0xff) == (i&0xff)
+
+ override def asl(i: Int) = NumericConstant(value << i, requiredSize + i / 8)
+
+ override def +(that: Constant): Constant = that + value
+
+ override def +(that: Long) = NumericConstant(value + that, minimumSize(value + that))
+
+ override def toString: String = if (value > 9) value.formatted("$%X") else value.toString
+}
+
+case class MemoryAddressConstant(var thing: ThingInMemory) extends Constant {
+ override def requiredSize = 2
+
+ override def toString: String = thing.name
+}
+
+case class HalfWordConstant(base: Constant, hi: Boolean) extends Constant {
+ override def quickSimplify: Constant = {
+ val simplified = base.quickSimplify
+ simplified match {
+ case NumericConstant(x, size) => if (hi) {
+ if (size == 1) Constant.Zero else NumericConstant((x >> 8) & 0xff, 1)
+ } else {
+ NumericConstant(x & 0xff, 1)
+ }
+ case _ => HalfWordConstant(simplified, hi)
+ }
+ }
+
+ override def requiredSize = 1
+
+ override def toString: String = base + (if (hi) ".hi" else ".lo")
+}
+
+case class SubbyteConstant(base: Constant, index: Int) extends Constant {
+ override def quickSimplify: Constant = {
+ val simplified = base.quickSimplify
+ simplified match {
+ case NumericConstant(x, size) => if (index >= size) {
+ Constant.Zero
+ } else {
+ NumericConstant((x >> (index * 8)) & 0xff, 1)
+ }
+ case _ => SubbyteConstant(simplified, index)
+ }
+ }
+
+ override def requiredSize = 1
+
+ override def toString: String = base + (index match {
+ case 0 => ".lo"
+ case 1 => ".hi"
+ case 2 => ".b2"
+ case 3 => ".b3"
+ })
+}
+
+object MathOperator extends Enumeration {
+ val Plus, Minus, Times, Shl, Shr,
+ DecimalPlus, DecimalMinus, DecimalTimes, DecimalShl, DecimalShr,
+ And, Or, Exor = Value
+}
+
+case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Constant) extends Constant {
+ override def quickSimplify: Constant = {
+ val l = lhs.quickSimplify
+ val r = rhs.quickSimplify
+ (l, r) match {
+ case (NumericConstant(lv, ls), NumericConstant(rv, rs)) =>
+ var size = ls max rs
+ val value = operator match {
+ case MathOperator.Plus => lv + rv
+ case MathOperator.Minus => lv - rv
+ case MathOperator.Times => lv * rv
+ case MathOperator.Shl => lv << rv
+ case MathOperator.Shr => lv >> rv
+ case MathOperator.Exor => lv ^ rv
+ case MathOperator.Or => lv | rv
+ case MathOperator.And => lv & rv
+ case _ => return this
+ }
+ operator match {
+ case MathOperator.Times | MathOperator.Shl =>
+ val mask = (1 << (size * 8)) - 1
+ if (value != (value & mask)){
+ size = ls + rs
+ }
+ case _ =>
+ }
+ NumericConstant(value, size)
+ case _ => CompoundConstant(operator, l, r)
+ }
+ }
+
+
+ import MathOperator._
+
+ override def +(that: Constant): Constant = {
+ that match {
+ case NumericConstant(n, _) => this + n
+ case _ => super.+(that)
+ }
+ }
+
+ override def +(that: Long): Constant = {
+ if (that == 0) {
+ return this
+ }
+ val That = that
+ val MinusThat = -that
+ this match {
+ case CompoundConstant(Plus, NumericConstant(MinusThat, _), r) => r
+ case CompoundConstant(Plus, l, NumericConstant(MinusThat, _)) => l
+ case CompoundConstant(Plus, NumericConstant(x, _), r) => CompoundConstant(Plus, r, NumericConstant(x + that, minimumSize(x + that)))
+ case CompoundConstant(Plus, l, NumericConstant(x, _)) => CompoundConstant(Plus, l, NumericConstant(x + that, minimumSize(x + that)))
+ case CompoundConstant(Minus, l, NumericConstant(That, _)) => l
+ case _ => CompoundConstant(Plus, this, NumericConstant(that, minimumSize(that)))
+ }
+ }
+
+ private def plhs = lhs match {
+ case _: NumericConstant | _: MemoryAddressConstant => lhs
+ case _ => "(" + lhs + ')'
+ }
+
+ private def prhs = lhs match {
+ case _: NumericConstant | _: MemoryAddressConstant => rhs
+ case _ => "(" + rhs + ')'
+ }
+
+ override def toString: String = {
+ operator match {
+ case Plus => f"$plhs + $prhs"
+ case Minus => f"$plhs - $prhs"
+ case Times => f"$plhs * $prhs"
+ case Shl => f"$plhs << $prhs"
+ case Shr => f"$plhs >> $prhs"
+ case DecimalPlus => f"$plhs +' $prhs"
+ case DecimalMinus => f"$plhs -' $prhs"
+ case DecimalTimes => f"$plhs *' $prhs"
+ case DecimalShl => f"$plhs <<' $prhs"
+ case DecimalShr => f"$plhs >>' $prhs"
+ case And => f"$plhs & $prhs"
+ case Or => f"$plhs | $prhs"
+ case Exor => f"$plhs ^ $prhs"
+ }
+ }
+
+ override def requiredSize: Int = lhs.requiredSize max rhs.requiredSize
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala
new file mode 100644
index 00000000..c5d8cb2b
--- /dev/null
+++ b/src/main/scala/millfork/env/Environment.scala
@@ -0,0 +1,618 @@
+package millfork.env
+
+import java.util.concurrent.atomic.AtomicLong
+
+import millfork.{CompilationFlag, CompilationOptions}
+import millfork.assembly.Opcode
+import millfork.compiler._
+import millfork.error.ErrorReporting
+import millfork.node._
+import millfork.output.VariableAllocator
+
+import scala.collection.mutable
+
+
+/**
+ * @author Karol Stasiak
+ */
+//noinspection NotImplementedCode
+class Environment(val parent: Option[Environment], val prefix: String) {
+
+
+ private var baseStackOffset = 0x101
+ private val relVarId = new AtomicLong
+
+ def genRelativeVariable(constant: Constant, typ: Type, zeropage: Boolean): RelativeVariable = {
+ val variable = RelativeVariable(".rv__" + relVarId.incrementAndGet().formatted("%06d"), constant, typ, zeropage = zeropage)
+ addThing(variable, None)
+ variable
+ }
+
+
+ def allThings: Environment = {
+ val allThings: Map[String, Thing] = things.values.map {
+ case m: FunctionInMemory =>
+ m.environment.getAllPrefixedThings
+ case m: InlinedFunction =>
+ m.environment.getAllPrefixedThings
+ case _ => Map[String, Thing]()
+ }.fold(things.toMap)(_ ++ _)
+ val e = new Environment(None, "")
+ e.things.clear()
+ e.things ++= allThings
+ e
+ }
+
+
+ private def getAllPrefixedThings = {
+ things.toMap.map { case (n, th) => (if (n.startsWith(".")) n else prefix + n, th) }
+ }
+
+ def getAllLocalVariables: List[Variable] = things.values.flatMap {
+ case v: Variable =>
+ Some(v)
+ case _ => None
+ }.toList
+
+ def allPreallocatables: List[PrellocableThing] = things.values.flatMap {
+ case m: NormalFunction => Some(m)
+ case m: InitializedArray => Some(m)
+ case _ => None
+ }.toList
+
+ def allConstants: List[ConstantThing] = things.values.flatMap {
+ case m: NormalFunction => m.environment.allConstants
+ case m: InlinedFunction => m.environment.allConstants
+ case m: ConstantThing => List(m)
+ case _ => Nil
+ }.toList
+
+ def allocateVariables(nf: Option[NormalFunction], callGraph: CallGraph, allocator: VariableAllocator, options: CompilationOptions, onEachVariable: (String, Int) => Unit): Unit = {
+ val b = get[Type]("byte")
+ val p = get[Type]("pointer")
+ var params = nf.fold(List[String]()) { f =>
+ f.params match {
+ case NormalParamSignature(ps) =>
+ ps.map(p => p.name)
+ case _ =>
+ Nil
+ }
+ }.toSet
+ val toAdd = things.values.flatMap {
+ case m: UninitializedMemory =>
+ val vertex = if (options.flag(CompilationFlag.VariableOverlap)) {
+ nf.fold[VariableVertex](GlobalVertex) { f =>
+ if (m.alloc == VariableAllocationMethod.Static) {
+ GlobalVertex
+ } else if (params(m.name)) {
+ ParamVertex(f.name)
+ } else {
+ LocalVertex(f.name)
+ }
+ }
+ } else GlobalVertex
+ m.alloc match {
+ case VariableAllocationMethod.None =>
+ Nil
+ case VariableAllocationMethod.Zeropage =>
+ m.sizeInBytes match {
+ case 2 =>
+ val addr =
+ allocator.allocatePointer(callGraph, vertex)
+ onEachVariable(m.name, addr)
+ List(
+ ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p)
+ )
+ }
+ case VariableAllocationMethod.Auto | VariableAllocationMethod.Static =>
+ m.sizeInBytes match {
+ case 0 => Nil
+ case 2 =>
+ val addr =
+ allocator.allocateBytes(callGraph, vertex, options, 2)
+ onEachVariable(m.name, addr)
+ List(
+ ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p)
+ )
+ case count =>
+ val addr = allocator.allocateBytes(callGraph, vertex, options, count)
+ onEachVariable(m.name, addr)
+ List(
+ ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p)
+ )
+ }
+ }
+ case f: NormalFunction =>
+ f.environment.allocateVariables(Some(f), callGraph, allocator, options, onEachVariable)
+ Nil
+ case _ => Nil
+ }.toList
+ val tagged: List[(String, Thing)] = toAdd.map(x => x.name -> x)
+ things ++= tagged
+ }
+
+ val things: mutable.Map[String, Thing] = mutable.Map()
+
+ private def addThing(t: Thing, position: Option[Position]): Unit = {
+ assertNotDefined(t.name, position)
+ things(t.name.stripPrefix(prefix)) = t
+ }
+
+ def removeVariable(str: String): Unit = {
+ things -= str
+ things -= str + ".addr"
+ }
+
+ def get[T <: Thing : Manifest](name: String, position: Option[Position] = None): T = {
+ val clazz = implicitly[Manifest[T]].runtimeClass
+ if (things.contains(name)) {
+ val t: Thing = things(name)
+ if ((t ne null) && clazz.isInstance(t)) {
+ t.asInstanceOf[T]
+ } else {
+ ErrorReporting.fatal(s"`$name` is not a ${clazz.getSimpleName}", position)
+ }
+ } else parent.fold {
+ ErrorReporting.fatal(s"${clazz.getSimpleName} `$name` is not defined", position)
+ } {
+ _.get[T](name, position)
+ }
+ }
+
+ def maybeGet[T <: Thing : Manifest](name: String): Option[T] = {
+ if (things.contains(name)) {
+ val t: Thing = things(name)
+ val clazz = implicitly[Manifest[T]].runtimeClass
+ if ((t ne null) && clazz.isInstance(t)) {
+ Some(t.asInstanceOf[T])
+ } else {
+ None
+ }
+ } else parent.flatMap {
+ _.maybeGet[T](name)
+ }
+ }
+
+ def getArrayOrPointer(arrayName: String): Thing = {
+ maybeGet[ThingInMemory](arrayName).
+ orElse(maybeGet[ThingInMemory](arrayName + ".array")).
+ orElse(maybeGet[ConstantThing](arrayName)).
+ getOrElse(ErrorReporting.fatal(s"`$arrayName` is not an array or a pointer"))
+ }
+
+ if (parent.isEmpty) {
+ addThing(VoidType, None)
+ addThing(BuiltInBooleanType, None)
+ addThing(BasicPlainType("byte", 1), None)
+ addThing(BasicPlainType("word", 2), None)
+ addThing(BasicPlainType("long", 4), None)
+ addThing(DerivedPlainType("pointer", get[PlainType]("word"), isSigned = false), None)
+ addThing(DerivedPlainType("ubyte", get[PlainType]("byte"), isSigned = false), None)
+ addThing(DerivedPlainType("sbyte", get[PlainType]("byte"), isSigned = true), None)
+ addThing(DerivedPlainType("cent", get[PlainType]("byte"), isSigned = false), None)
+ val trueType = ConstantBooleanType("true$", value = true)
+ val falseType = ConstantBooleanType("false$", value = false)
+ addThing(trueType, None)
+ addThing(falseType, None)
+ addThing(ConstantThing("true", NumericConstant(0, 0), trueType), None)
+ addThing(ConstantThing("false", NumericConstant(0, 0), falseType), None)
+ addThing(FlagBooleanType("set_carry", Opcode.BCS, Opcode.BCC), None)
+ addThing(FlagBooleanType("clear_carry", Opcode.BCC, Opcode.BCS), None)
+ addThing(FlagBooleanType("set_overflow", Opcode.BVS, Opcode.BVC), None)
+ addThing(FlagBooleanType("clear_overflow", Opcode.BVC, Opcode.BVS), None)
+ addThing(FlagBooleanType("set_zero", Opcode.BEQ, Opcode.BNE), None)
+ addThing(FlagBooleanType("clear_zero", Opcode.BNE, Opcode.BEQ), None)
+ addThing(FlagBooleanType("set_negative", Opcode.BMI, Opcode.BPL), None)
+ addThing(FlagBooleanType("clear_negative", Opcode.BPL, Opcode.BMI), None)
+ }
+
+ def assertNotDefined(name: String, position: Option[Position]): Unit = {
+ if (things.contains(name) || parent.exists(_.things.contains(name)))
+ ErrorReporting.fatal(s"`$name` is already defined", position)
+ }
+
+ def registerType(stmt: TypeDefinitionStatement): Unit = {
+ // addThing(DerivedPlainType(stmt.name, get(stmt.parent)))
+ ???
+ }
+
+ def sequence[A](a: List[Option[A]]): Option[List[A]] = a match {
+ case Nil => Some(Nil)
+ case None :: _ => None
+ case Some(r) :: t => sequence(t) map (r :: _)
+ }
+
+ def evalVariableAndConstantSubParts(e: Expression): (Option[Expression], Constant) =
+ e match {
+ case SumExpression(params, false) =>
+ val (constants, variables) = params.map { case (sign, expr) => (sign, expr, eval(expr)) }.partition(_._3.isDefined)
+ val constant = eval(SumExpression(constants.map(x => (x._1, x._2)), decimal = false)).get
+ val variable = variables match {
+ case Nil => None
+ case List((false, x, _)) => Some(x)
+ case _ => Some(SumExpression(variables.map(x => (x._1, x._2)), decimal = false))
+ }
+ variable -> constant
+ case _ => eval(e) match {
+ case Some(c) => None -> c
+ case None => Some(e) -> Constant.Zero
+ }
+ }
+
+ def eval(e: Expression): Option[Constant] = {
+ e match {
+ case LiteralExpression(value, size) => Some(NumericConstant(value, size))
+ case VariableExpression(name) =>
+ maybeGet[ConstantThing](name).map(_.value)
+ case IndexedExpression(_, _) => None
+ case HalfWordExpression(param, hi) => eval(e).map(c => if (hi) c.hiByte else c.loByte)
+ case SumExpression(params, decimal) =>
+ params.map {
+ case (minus, param) => (minus, eval(param))
+ }.foldLeft(Some(Constant.Zero).asInstanceOf[Option[Constant]]) { (oc, pair) =>
+ oc.flatMap { c =>
+ pair match {
+ case (_, None) => None
+ case (minus, Some(addend)) =>
+ val op = if (decimal) {
+ if (minus) MathOperator.DecimalMinus else MathOperator.DecimalPlus
+ } else {
+ if (minus) MathOperator.Minus else MathOperator.Plus
+ }
+ Some(CompoundConstant(op, c, addend))
+ }
+ }
+ }
+ case SeparateBytesExpression(h, l) => for {
+ lc <- eval(l)
+ hc <- eval(h)
+ } yield hc.asl(8) + lc
+ case FunctionCallExpression(name, params) =>
+ name match {
+ case "*" =>
+ constantOperation(MathOperator.Times, params)
+ case "&&" | "&" =>
+ constantOperation(MathOperator.And, params)
+ case "^" =>
+ constantOperation(MathOperator.Exor, params)
+ case "||" | "|" =>
+ constantOperation(MathOperator.Or, params)
+ case _ =>
+ None
+ }
+ }
+ }
+
+ private def constantOperation(op: MathOperator.Value, params: List[Expression]) = {
+ params.map(eval(_)).reduceLeft[Option[Constant]] { (oc, om) =>
+ for {
+ c <- oc
+ m <- om
+ } yield CompoundConstant(op, c, m)
+ }
+ }
+
+ def registerFunction(stmt: FunctionDeclarationStatement, options: CompilationOptions): Unit = {
+ val w = get[Type]("word")
+ val name = stmt.name
+ val resultType = get[Type](stmt.resultType)
+
+ if (stmt.reentrant && stmt.interrupt) ErrorReporting.error(s"Reentrant function `$name` cannot be an interrupt handler", stmt.position)
+ if (stmt.reentrant && stmt.params.nonEmpty) ErrorReporting.error(s"Reentrant function `$name` cannot have parameters", stmt.position)
+ if (stmt.interrupt && stmt.params.nonEmpty) ErrorReporting.error(s"Interrupt function `$name` cannot have parameters", stmt.position)
+ if (stmt.inlined) {
+ if (!stmt.assembly) {
+ if (stmt.params.nonEmpty) ErrorReporting.error(s"Inline non-assembly function `$name` cannot have parameters", stmt.position) // TODO: ???
+ if (resultType != VoidType) ErrorReporting.error(s"Inline non-assembly function `$name` must return void", stmt.position)
+ }
+ if (stmt.params.exists(_.assemblyParamPassingConvention.inNonInlinedOnly))
+ ErrorReporting.error(s"Inline function `$name` cannot have by-variable parameters", stmt.position)
+ } else {
+ if (!stmt.assembly) {
+ if (stmt.params.exists(!_.assemblyParamPassingConvention.isInstanceOf[ByVariable]))
+ ErrorReporting.error(s"Non-assembly function `$name` cannot have non-variable parameters", stmt.position)
+ }
+ if (stmt.params.exists(_.assemblyParamPassingConvention.inInlinedOnly))
+ ErrorReporting.error(s"Non-inline function `$name` cannot have inlinable parameters", stmt.position)
+ }
+
+ val env = new Environment(Some(this), name + "$")
+ stmt.params.foreach(p => env.registerParameter(p))
+ val params = if (stmt.assembly) {
+ AssemblyParamSignature(stmt.params.map {
+ pd =>
+ val typ = env.get[Type](pd.typ)
+ pd.assemblyParamPassingConvention match {
+ case ByVariable(vn) =>
+ AssemblyParam(typ, env.get[MemoryVariable](vn), AssemblyParameterPassingBehaviour.Copy)
+ case ByRegister(reg) =>
+ AssemblyParam(typ, RegisterVariable(reg, typ), AssemblyParameterPassingBehaviour.Copy)
+ case ByConstant(vn) =>
+ AssemblyParam(typ, Placeholder(vn, typ), AssemblyParameterPassingBehaviour.ByConstant)
+ case ByReference(vn) =>
+ AssemblyParam(typ, Placeholder(vn, typ), AssemblyParameterPassingBehaviour.ByReference)
+ }
+ })
+ } else {
+ NormalParamSignature(stmt.params.map { pd =>
+ env.get[MemoryVariable](pd.assemblyParamPassingConvention.asInstanceOf[ByVariable].name)
+ })
+ }
+ stmt.statements match {
+ case None =>
+ stmt.address match {
+ case None =>
+ ErrorReporting.error(s"Extern function `${stmt.name}`needs an address", stmt.position)
+ case Some(a) =>
+ val addr = eval(a).getOrElse(Constant.error(s"Address of `${stmt.name}` is not a constant", stmt.position))
+ val mangled = ExternFunction(
+ name,
+ resultType,
+ params,
+ addr,
+ env
+ )
+ addThing(mangled, stmt.position)
+ registerAddressConstant(mangled, stmt.position)
+ addThing(ConstantThing(name + '`', addr, w), stmt.position)
+ }
+
+ case Some(statements) =>
+ statements.foreach {
+ case v: VariableDeclarationStatement => env.registerVariable(v, options)
+ case _ => ()
+ }
+ val executableStatements = statements.flatMap {
+ case e: ExecutableStatement => Some(e)
+ case _ => None
+ }
+ val needsExtraRTS = !stmt.inlined && !stmt.assembly && (statements.isEmpty || !statements.last.isInstanceOf[ReturnStatement])
+ if (stmt.inlined) {
+ val mangled = new InlinedFunction(
+ name,
+ resultType,
+ params,
+ env,
+ executableStatements ++ (if (needsExtraRTS) List(AssemblyStatement.implied(Opcode.RTS, elidable = true)) else Nil),
+ )
+ addThing(mangled, stmt.position)
+ } else {
+ var stackVariablesSize = env.things.values.map {
+ case StackVariable(n, t, _) if !n.contains(".") => t.size
+ case _ => 0
+ }.sum
+ val mangled = NormalFunction(
+ name,
+ resultType,
+ params,
+ env,
+ stackVariablesSize,
+ stmt.address.map(a => this.eval(a).getOrElse(Constant.error(s"Address of `${stmt.name}` is not a constant"))),
+ executableStatements ++ (if (needsExtraRTS) List(ReturnStatement(None)) else Nil),
+ interrupt = stmt.interrupt,
+ reentrant = stmt.reentrant,
+ position = stmt.position
+ )
+ addThing(mangled, stmt.position)
+ registerAddressConstant(mangled, stmt.position)
+ }
+ }
+ }
+
+ private def registerAddressConstant(thing: ThingInMemory, position: Option[Position]): Unit = {
+ val addr = thing.toAddress
+ addThing(ConstantThing(thing.name + ".addr", addr, get[Type]("pointer")), position)
+ addThing(ConstantThing(thing.name + ".addr.hi", addr.hiByte, get[Type]("byte")), position)
+ addThing(ConstantThing(thing.name + ".addr.lo", addr.loByte, get[Type]("byte")), position)
+ }
+
+ def registerParameter(stmt: ParameterDeclaration): Unit = {
+ val typ = get[Type](stmt.typ)
+ val b = get[Type]("byte")
+ val p = get[Type]("pointer")
+ stmt.assemblyParamPassingConvention match {
+ case ByVariable(name) =>
+ val zp = typ.name == "pointer" // TODO
+ val v = MemoryVariable(prefix + name, typ, if (zp) VariableAllocationMethod.Zeropage else VariableAllocationMethod.Auto)
+ addThing(v, stmt.position)
+ registerAddressConstant(v, stmt.position)
+ if (typ.size == 2) {
+ val addr = v.toAddress
+ addThing(RelativeVariable(v.name + ".hi", addr + 1, b, zeropage = zp), stmt.position)
+ addThing(RelativeVariable(v.name + ".lo", addr, b, zeropage = zp), stmt.position)
+ }
+ case ByRegister(_) => ()
+ case ByConstant(name) =>
+ val v = ConstantThing(prefix + name, UnexpandedConstant(prefix + name, typ.size), typ)
+ addThing(v, stmt.position)
+ case ByReference(name) =>
+ val addr = UnexpandedConstant(prefix + name, typ.size)
+ val v = RelativeVariable(prefix + name, addr, p, zeropage = false)
+ addThing(v, stmt.position)
+ addThing(RelativeVariable(v.name + ".hi", addr + 1, b, zeropage = false), stmt.position)
+ addThing(RelativeVariable(v.name + ".lo", addr, b, zeropage = false), stmt.position)
+ }
+ }
+
+ def registerArray(stmt: ArrayDeclarationStatement): Unit = {
+ val b = get[Type]("byte")
+ val p = get[Type]("pointer")
+ stmt.elements match {
+ case None =>
+ stmt.length match {
+ case None => ErrorReporting.error(s"Array `${stmt.name}` without size nor contents", stmt.position)
+ case Some(l) =>
+ val address = stmt.address.map(a => eval(a).getOrElse(ErrorReporting.fatal(s"Array `${stmt.name}` has non-constant address", stmt.position)))
+ val lengthConst = eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position))
+ lengthConst match {
+ case NumericConstant(length, _) =>
+ if (length > 0xffff || length < 0) ErrorReporting.error(s"Array `${stmt.name}` has invalid length", stmt.position)
+ val array = address match {
+ case None => UninitializedArray(stmt.name + ".array", length.toInt)
+ case Some(aa) => RelativeArray(stmt.name + ".array", aa, length.toInt)
+ }
+ addThing(array, stmt.position)
+ registerAddressConstant(MemoryVariable(stmt.name, p, VariableAllocationMethod.None), stmt.position)
+ val a = address match {
+ case None => array.toAddress
+ case Some(aa) => aa
+ }
+ addThing(RelativeVariable(stmt.name + ".first", a, b, zeropage = false), stmt.position)
+ addThing(ConstantThing(stmt.name, a, p), stmt.position)
+ addThing(ConstantThing(stmt.name + ".hi", a.hiByte, b), stmt.position)
+ addThing(ConstantThing(stmt.name + ".lo", a.loByte, b), stmt.position)
+ addThing(ConstantThing(stmt.name + ".array.hi", a.hiByte, b), stmt.position)
+ addThing(ConstantThing(stmt.name + ".array.lo", a.loByte, b), stmt.position)
+ if (length < 256) {
+ addThing(ConstantThing(stmt.name + ".length", lengthConst, b), stmt.position)
+ }
+ case _ => ErrorReporting.error(s"Array `${stmt.name}` has weird length", stmt.position)
+ }
+ }
+ case Some(contents) =>
+ stmt.length match {
+ case None =>
+ case Some(l) =>
+ val lengthConst = eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position))
+ lengthConst match {
+ case NumericConstant(ll, _) =>
+ if (ll != contents.length) ErrorReporting.error(s"Array `${stmt.name}` has different declared and actual length", stmt.position)
+ case _ => ErrorReporting.error(s"Array `${stmt.name}` has weird length", stmt.position)
+ }
+ }
+ val length = contents.length
+ if (length > 0xffff || length < 0) ErrorReporting.error(s"Array `${stmt.name}` has invalid length", stmt.position)
+ val address = stmt.address.map(a => eval(a).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant address", stmt.position)))
+ val data = contents.map(x => eval(x).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant contents", stmt.position)))
+ val array = InitializedArray(stmt.name + ".array", address, data)
+ addThing(array, stmt.position)
+ registerAddressConstant(MemoryVariable(stmt.name, p, VariableAllocationMethod.None), stmt.position)
+ val a = address match {
+ case None => array.toAddress
+ case Some(aa) => aa
+ }
+ addThing(RelativeVariable(stmt.name + ".first", a, b, zeropage = false), stmt.position)
+ addThing(ConstantThing(stmt.name, a, p), stmt.position)
+ addThing(ConstantThing(stmt.name + ".hi", a.hiByte, b), stmt.position)
+ addThing(ConstantThing(stmt.name + ".lo", a.loByte, b), stmt.position)
+ addThing(ConstantThing(stmt.name + ".array.hi", a.hiByte, b), stmt.position)
+ addThing(ConstantThing(stmt.name + ".array.lo", a.loByte, b), stmt.position)
+ if (length < 256) {
+ addThing(ConstantThing(stmt.name + ".length", NumericConstant(length, 1), b), stmt.position)
+ }
+ }
+ }
+
+ def registerVariable(stmt: VariableDeclarationStatement, options: CompilationOptions): Unit = {
+ if (stmt.volatile) {
+ ErrorReporting.warn("`volatile` not yet supported", options)
+ }
+ val name = stmt.name
+ val position = stmt.position
+ if (stmt.stack && parent.isEmpty) {
+ if (stmt.stack && stmt.global) ErrorReporting.error(s"`$name` is static or global and cannot be on stack", position)
+ }
+ val b = get[Type]("byte")
+ val typ = get[PlainType](stmt.typ)
+ if (stmt.typ == "pointer") {
+ // if (stmt.constant) {
+ // ErrorReporting.error(s"Pointer `${stmt.name}` cannot be constant")
+ // }
+ stmt.address.flatMap(eval) match {
+ case Some(NumericConstant(a, _)) =>
+ if ((a & 0xff00) != 0)
+ ErrorReporting.error(s"Pointer `${stmt.name}` cannot be located outside the zero page")
+ case _ => ()
+ }
+ }
+ if (stmt.constant) {
+ if (stmt.stack) ErrorReporting.error(s"`$name` is a constant and cannot be on stack", position)
+ if (stmt.address.isDefined) ErrorReporting.error(s"`$name` is a constant and cannot have an address", position)
+ if (stmt.initialValue.isEmpty) ErrorReporting.error(s"`$name` is a constant and requires a value", position)
+ val constantValue: Constant = stmt.initialValue.flatMap(eval).getOrElse(Constant.error(s"`$name` has a non-constant value", position))
+ if (constantValue.requiredSize > typ.size) ErrorReporting.error(s"`$name` is has an invalid value: not in the range of `$typ`", position)
+ addThing(ConstantThing(prefix + name, constantValue, typ), stmt.position)
+ if (typ.size == 2) {
+ addThing(ConstantThing(prefix + name + ".hi", constantValue + 1, b), stmt.position)
+ addThing(ConstantThing(prefix + name + ".lo", constantValue, b), stmt.position)
+ }
+ } else {
+ if (stmt.stack && stmt.global) ErrorReporting.error(s"`$name` is static or global and cannot be on stack", position)
+ if (stmt.initialValue.isDefined) ErrorReporting.error(s"`$name` is not a constant and cannot have a value", position)
+ if (stmt.stack) {
+ val v = StackVariable(prefix + name, typ, this.baseStackOffset)
+ baseStackOffset += typ.size
+ addThing(v, stmt.position)
+ if (typ.size == 2) {
+ addThing(StackVariable(prefix + name + ".lo", b, baseStackOffset), stmt.position)
+ addThing(StackVariable(prefix + name + ".hi", b, baseStackOffset + 1), stmt.position)
+ }
+ } else {
+ val (v, addr) = stmt.address.fold[(VariableInMemory, Constant)]({
+ val alloc = if (typ.name == "pointer") VariableAllocationMethod.Zeropage else if (stmt.global) VariableAllocationMethod.Static else VariableAllocationMethod.Auto
+ val v = MemoryVariable(prefix + name, typ, alloc)
+ registerAddressConstant(v, stmt.position)
+ (v, v.toAddress)
+ })(a => {
+ val addr = eval(a).getOrElse(Constant.error(s"Address of `$name` has a non-constant value", position))
+ val zp = addr match {
+ case NumericConstant(n, _) => n < 0x100
+ case _ => false
+ }
+ (RelativeVariable(prefix + name, addr, typ, zeropage = zp), addr)
+ })
+ addThing(v, stmt.position)
+ if (!v.isInstanceOf[MemoryVariable]) {
+ addThing(ConstantThing(v.name + "`", addr, b), stmt.position)
+ }
+ if (typ.size == 2) {
+ addThing(RelativeVariable(prefix + name + ".hi", addr + 1, b, zeropage = v.zeropage), stmt.position)
+ addThing(RelativeVariable(prefix + name + ".lo", addr, b, zeropage = v.zeropage), stmt.position)
+ }
+ }
+ }
+ }
+
+ def lookup[T <: Thing : Manifest](name: String): Option[T] = {
+ if (things.contains(name)) {
+ maybeGet(name)
+ } else {
+ parent.flatMap(_.lookup[T](name))
+ }
+ }
+
+ def lookupFunction(name: String, actualParams: List[(Type, Expression)]): Option[MangledFunction] = {
+ if (things.contains(name)) {
+ val function = get[MangledFunction](name)
+ if (function.params.length != actualParams.length) {
+ ErrorReporting.error(s"Invalid number of parameters for function `$name`", actualParams.headOption.flatMap(_._2.position))
+ }
+ function.params match {
+ case NormalParamSignature(params) =>
+ function.params.types.zip(actualParams).zip(params).foreach { case ((required, (actual, expr)), m) =>
+ if (!actual.isAssignableTo(required)) {
+ ErrorReporting.error(s"Invalid value for parameter `${m.name}` of function `$name`", expr.position)
+ }
+ }
+ case AssemblyParamSignature(params) =>
+ function.params.types.zip(actualParams).zipWithIndex.foreach { case ((required, (actual, expr)), ix) =>
+ if (!actual.isAssignableTo(required)) {
+ ErrorReporting.error(s"Invalid value for parameter ${ix + 1} of function `$name`", expr.position)
+ }
+ }
+ }
+ Some(function)
+ } else {
+ parent.flatMap(_.lookupFunction(name, actualParams))
+ }
+ }
+
+ def collectDeclarations(program: Program, options: CompilationOptions): Unit = {
+ program.declarations.foreach {
+ case f: FunctionDeclarationStatement => registerFunction(f, options)
+ case v: VariableDeclarationStatement => registerVariable(v, options)
+ case a: ArrayDeclarationStatement => registerArray(a)
+ case i: ImportStatement => ()
+ }
+ }
+}
diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala
new file mode 100644
index 00000000..38d13035
--- /dev/null
+++ b/src/main/scala/millfork/env/Thing.scala
@@ -0,0 +1,264 @@
+package millfork.env
+
+import millfork.assembly.Opcode
+import millfork.error.ErrorReporting
+import millfork.node._
+
+sealed trait Thing {
+ def name: String
+}
+
+sealed trait Type extends Thing {
+
+ def size: Int
+
+ def isSigned: Boolean
+
+ def isSubtypeOf(other: Type): Boolean = this == other
+
+ def isCompatible(other: Type): Boolean = this == other
+
+ override def toString(): String = name
+
+ def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType)
+}
+
+case object VoidType extends Type {
+ def size = 0
+
+ def isSigned = false
+
+ override def name = "void"
+}
+
+sealed trait PlainType extends Type {
+ override def isCompatible(other: Type): Boolean = this == other || this.isSubtypeOf(other) || other.isSubtypeOf(this)
+
+ override def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType) || (targetType match {
+ case BasicPlainType(_, size) => size > this.size // TODO
+ case _ => false
+ })
+}
+
+case class BasicPlainType(name: String, size: Int) extends PlainType {
+ def isSigned = false
+
+ override def isSubtypeOf(other: Type): Boolean = this == other
+}
+
+case class DerivedPlainType(name: String, parent: PlainType, isSigned: Boolean) extends PlainType {
+ def size: Int = parent.size
+
+ override def isSubtypeOf(other: Type): Boolean = parent == other || parent.isSubtypeOf(other)
+}
+
+sealed trait BooleanType extends Type {
+ def size = 0
+
+ def isSigned = false
+}
+
+case class ConstantBooleanType(name: String, value: Boolean) extends BooleanType
+
+case class FlagBooleanType(name: String, jumpIfTrue: Opcode.Value, jumpIfFalse: Opcode.Value) extends BooleanType
+
+case object BuiltInBooleanType extends BooleanType {
+ override def name = "bool$"
+}
+
+sealed trait TypedThing extends Thing {
+ def typ: Type
+}
+
+
+sealed trait ThingInMemory extends Thing {
+ def toAddress: Constant
+}
+
+sealed trait PrellocableThing extends ThingInMemory {
+ def shouldGenerate: Boolean
+
+ def address: Option[Constant]
+
+ def toAddress: Constant = address.getOrElse(MemoryAddressConstant(this))
+}
+
+case class Label(name: String) extends ThingInMemory {
+ override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this)
+}
+
+sealed trait Variable extends TypedThing
+
+case class BlackHole(typ: Type) extends Variable {
+ override def name = ""
+}
+
+sealed trait VariableInMemory extends Variable with ThingInMemory {
+
+ def zeropage: Boolean
+}
+
+case class RegisterVariable(register: Register.Value, typ: Type) extends Variable {
+ def name: String = register.toString
+}
+
+case class Placeholder(name: String, typ: Type) extends Variable
+
+sealed trait UninitializedMemory extends ThingInMemory {
+ def sizeInBytes: Int
+
+ def alloc: VariableAllocationMethod.Value
+}
+
+object VariableAllocationMethod extends Enumeration {
+ val Auto, Static, Zeropage, None = Value
+}
+
+case class StackVariable(name: String, typ: Type, baseOffset: Int) extends Variable {
+ def sizeInBytes: Int = typ.size
+}
+
+case class MemoryVariable(name: String, typ: Type, alloc: VariableAllocationMethod.Value) extends VariableInMemory with UninitializedMemory {
+ override def sizeInBytes: Int = typ.size
+
+ override def zeropage: Boolean = alloc == VariableAllocationMethod.Zeropage
+
+ override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this)
+}
+
+trait MlArray extends ThingInMemory
+
+case class UninitializedArray(name: String, sizeInBytes: Int) extends MlArray with UninitializedMemory {
+ override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this)
+
+ override def alloc = VariableAllocationMethod.Static
+}
+
+case class RelativeArray(name: String, address: Constant, sizeInBytes: Int) extends MlArray {
+ override def toAddress: Constant = address
+}
+
+case class InitializedArray(name: String, address: Option[Constant], contents: List[Constant]) extends MlArray with PrellocableThing {
+ override def shouldGenerate = true
+}
+
+case class RelativeVariable(name: String, address: Constant, typ: Type, zeropage: Boolean) extends VariableInMemory {
+ override def toAddress: Constant = address
+}
+
+
+sealed trait MangledFunction extends Thing {
+ def name: String
+
+ def returnType: Type
+
+ def params: ParamSignature
+
+ def interrupt: Boolean
+}
+
+case class EmptyFunction(name: String,
+ returnType: Type,
+ paramType: Type) extends MangledFunction {
+ override def params = EmptyFunctionParamSignature(paramType)
+
+ override def interrupt = false
+}
+
+case class InlinedFunction(name: String,
+ returnType: Type,
+ params: ParamSignature,
+ environment: Environment,
+ code: List[ExecutableStatement]) extends MangledFunction {
+ override def interrupt = false
+}
+
+sealed trait FunctionInMemory extends MangledFunction with ThingInMemory {
+ def environment: Environment
+}
+
+case class ExternFunction(name: String,
+ returnType: Type,
+ params: ParamSignature,
+ address: Constant,
+ environment: Environment) extends FunctionInMemory {
+ override def toAddress: Constant = address
+
+ override def interrupt = false
+}
+
+case class NormalFunction(name: String,
+ returnType: Type,
+ params: ParamSignature,
+ environment: Environment,
+ stackVariablesSize: Int,
+ address: Option[Constant],
+ code: List[ExecutableStatement],
+ interrupt: Boolean,
+ reentrant: Boolean,
+ position: Option[Position]) extends FunctionInMemory with PrellocableThing {
+ override def shouldGenerate = true
+}
+
+case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing
+
+trait ParamSignature {
+ def types: List[Type]
+
+ def length: Int
+}
+
+case class NormalParamSignature(params: List[MemoryVariable]) extends ParamSignature {
+ override def length: Int = params.length
+
+ override def types: List[Type] = params.map(_.typ)
+}
+
+sealed trait ParamPassingConvention {
+ def inInlinedOnly: Boolean
+
+ def inNonInlinedOnly: Boolean
+}
+
+case class ByRegister(register: Register.Value) extends ParamPassingConvention {
+ override def inInlinedOnly = false
+
+ override def inNonInlinedOnly = false
+}
+
+case class ByVariable(name: String) extends ParamPassingConvention {
+ override def inInlinedOnly = false
+
+ override def inNonInlinedOnly = true
+}
+
+case class ByConstant(name: String) extends ParamPassingConvention {
+ override def inInlinedOnly = true
+
+ override def inNonInlinedOnly = false
+}
+
+case class ByReference(name: String) extends ParamPassingConvention {
+ override def inInlinedOnly = true
+
+ override def inNonInlinedOnly = false
+}
+
+object AssemblyParameterPassingBehaviour extends Enumeration {
+ val Copy, ByReference, ByConstant = Value
+}
+
+case class AssemblyParam(typ: Type, variable: TypedThing, behaviour: AssemblyParameterPassingBehaviour.Value)
+
+
+case class AssemblyParamSignature(params: List[AssemblyParam]) extends ParamSignature {
+ override def length: Int = params.length
+
+ override def types: List[Type] = params.map(_.typ)
+}
+
+case class EmptyFunctionParamSignature(paramType: Type) extends ParamSignature {
+ override def length: Int = 1
+
+ override def types: List[Type] = List(paramType)
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/error/ErrorReporting.scala b/src/main/scala/millfork/error/ErrorReporting.scala
new file mode 100644
index 00000000..a3cdbbb9
--- /dev/null
+++ b/src/main/scala/millfork/error/ErrorReporting.scala
@@ -0,0 +1,74 @@
+package millfork.error
+
+import millfork.{CompilationFlag, CompilationOptions}
+import millfork.node.Position
+
+object ErrorReporting {
+
+ var verbosity = 0
+
+ var hasErrors = false
+
+ def f(position: Option[Position]): String = position.fold("")(p => s"(${p.line}:${p.column}) ")
+
+ def info(msg: String, position: Option[Position] = None): Unit = {
+ if (verbosity < 0) return
+ println("INFO: " + f(position) + msg)
+ flushOutput()
+ }
+
+ def debug(msg: String, position: Option[Position] = None): Unit = {
+ if (verbosity < 1) return
+ println("DEBUG: " + f(position) + msg)
+ flushOutput()
+ }
+
+ def trace(msg: String, position: Option[Position] = None): Unit = {
+ if (verbosity < 2) return
+ println("TRACE: " + f(position) + msg)
+ flushOutput()
+ }
+
+ private def flushOutput(): Unit = {
+ System.out.flush()
+ System.err.flush()
+ }
+
+ def warn(msg: String, options: CompilationOptions, position: Option[Position] = None): Unit = {
+ if (verbosity < 0) return
+ println("WARN: " + f(position) + msg)
+ flushOutput()
+ if (options.flag(CompilationFlag.FatalWarnings)) {
+ hasErrors = true
+ }
+ }
+
+ def error(msg: String, position: Option[Position] = None): Unit = {
+ hasErrors = true
+ println("ERROR: " + f(position) + msg)
+ flushOutput()
+ }
+
+ def fatal(msg: String, position: Option[Position] = None): Nothing = {
+ hasErrors = true
+ println("FATAL: " + f(position) + msg)
+ flushOutput()
+ throw new RuntimeException(msg)
+ }
+
+ def fatalQuit(msg: String, position: Option[Position] = None): Nothing = {
+ hasErrors = true
+ println("FATAL: " + f(position) + msg)
+ flushOutput()
+ System.exit(1)
+ throw new RuntimeException(msg)
+ }
+
+ def assertNoErrors(msg: String): Unit = {
+ if (hasErrors) {
+ error(msg)
+ fatal("Build halted due to previous errors")
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/node/CallGraph.scala b/src/main/scala/millfork/node/CallGraph.scala
new file mode 100644
index 00000000..9fc994bd
--- /dev/null
+++ b/src/main/scala/millfork/node/CallGraph.scala
@@ -0,0 +1,151 @@
+package millfork.node
+
+import millfork.error.ErrorReporting
+
+import scala.collection.mutable
+
+/**
+ * @author Karol Stasiak
+ */
+
+sealed trait VariableVertex {
+ def function: String
+}
+
+case class ParamVertex(function: String) extends VariableVertex
+
+case class LocalVertex(function: String) extends VariableVertex
+
+case object GlobalVertex extends VariableVertex {
+ override def function = ""
+}
+
+trait CallGraph {
+ def canOverlap(a: VariableVertex, b: VariableVertex): Boolean
+}
+
+object RestrictiveCallGraph extends CallGraph {
+
+ def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false
+}
+
+class StandardCallGraph(program: Program) extends CallGraph {
+
+ private val entryPoints = mutable.Set[String]()
+ // (F,G) means function F calls function G
+ private val callEdges = mutable.Set[(String, String)]()
+ // (F,G) means function G is called when building parameters for function F
+ private val paramEdges = mutable.Set[(String, String)]()
+ private val multiaccessibleFunctions = mutable.Set[String]()
+ private val everCalledFunctions = mutable.Set[String]()
+ private val allFunctions = mutable.Set[String]()
+
+ entryPoints += "main"
+ program.declarations.foreach(s => add(None, Nil, s))
+ everCalledFunctions.retain(allFunctions)
+
+ def add(currentFunction: Option[String], callingFunctions: List[String], node: Node): Unit = {
+ node match {
+ case f: FunctionDeclarationStatement =>
+ allFunctions += f.name
+ if (f.address.isDefined || f.interrupt) entryPoints += f.name
+ f.statements.getOrElse(Nil).foreach(s => this.add(Some(f.name), Nil, s))
+ case s: Statement =>
+ s.getAllExpressions.foreach(e => add(currentFunction, callingFunctions, e))
+ case g: FunctionCallExpression =>
+ everCalledFunctions += g.functionName
+ currentFunction.foreach(f => callEdges += f -> g.functionName)
+ callingFunctions.foreach(f => paramEdges += f -> g.functionName)
+ g.expressions.foreach(expr => add(currentFunction, g.functionName :: callingFunctions, expr))
+ case x: VariableExpression =>
+ val varName = x.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr")
+ everCalledFunctions += varName
+ case _ => ()
+ }
+ }
+
+
+ def fillOut(): Unit = {
+ var changed = true
+ while (changed) {
+ changed = false
+ val toAdd = for {
+ (a, b) <- callEdges
+ (c, d) <- callEdges
+ if b == c
+ if !callEdges.contains(a -> d)
+ } yield (a, d)
+ if (toAdd.nonEmpty) {
+ callEdges ++= toAdd
+ changed = true
+ }
+ }
+
+ changed = true
+ while (changed) {
+ changed = false
+ val toAdd = for {
+ (a, b) <- paramEdges
+ (c, d) <- callEdges
+ if b == c
+ if !paramEdges.contains(a -> d)
+ } yield (a, d)
+ if (toAdd.nonEmpty) {
+ paramEdges ++= toAdd
+ changed = true
+ }
+ }
+ multiaccessibleFunctions ++= entryPoints
+ everCalledFunctions ++= entryPoints
+ callEdges.filter(e => entryPoints.contains(e._1)).foreach(e => everCalledFunctions += e._2)
+ multiaccessibleFunctions ++= callEdges.filter(e => entryPoints.contains(e._1)).map(_._2).groupBy(identity).filter(p => p._2.size > 1).keys
+
+ ErrorReporting.trace("Call edges:")
+ callEdges.toList.sorted.foreach(s => ErrorReporting.trace(s.toString))
+
+ ErrorReporting.trace("Param edges:")
+ paramEdges.toList.sorted.foreach(s => ErrorReporting.trace(s.toString))
+
+ ErrorReporting.trace("Entry points:")
+ entryPoints.toList.sorted.foreach(ErrorReporting.trace(_))
+
+ ErrorReporting.trace("Multiaccessible functions:")
+ multiaccessibleFunctions.toList.sorted.foreach(ErrorReporting.trace(_))
+
+ ErrorReporting.trace("Ever called functions:")
+ everCalledFunctions.toList.sorted.foreach(ErrorReporting.trace(_))
+ }
+
+ def isEverCalled(function: String): Boolean = {
+ everCalledFunctions(function)
+ }
+
+ def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = {
+ if (a.function == b.function) {
+ return false
+ }
+ if (a == GlobalVertex || b == GlobalVertex) {
+ return false
+ }
+ if (multiaccessibleFunctions(a.function) || multiaccessibleFunctions(b.function)) {
+ return false
+ }
+ if (callEdges(a.function -> b.function) || callEdges(b.function -> a.function)) {
+ return false
+ }
+ a match {
+ case ParamVertex(af) =>
+ if (paramEdges(af -> b.function)) return false
+ case _ =>
+ }
+ b match {
+ case ParamVertex(bf) =>
+ if (paramEdges(bf -> a.function)) return false
+ case _ =>
+ }
+ ErrorReporting.trace(s"$a and $b can overlap")
+ true
+ }
+
+
+}
diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala
new file mode 100644
index 00000000..51d1367f
--- /dev/null
+++ b/src/main/scala/millfork/node/Node.scala
@@ -0,0 +1,181 @@
+package millfork.node
+
+import millfork.assembly.{AddrMode, Opcode}
+import millfork.env.{Label, ParamPassingConvention}
+
+case class Position(filename: String, line: Int, column: Int, cursor: Int)
+
+sealed trait Node {
+ var position: Option[Position] = None
+}
+
+object Node {
+ implicit class NodeOps[N<:Node](val node: N) extends AnyVal {
+ def pos(position: Position): N = {
+ node.position = Some(position)
+ node
+ }
+ }
+}
+
+sealed trait Expression extends Node {
+ def replaceVariable(variable: String, actualParam: Expression): Expression
+}
+
+case class LiteralExpression(value: Long, requiredSize: Int) extends Expression {
+ override def replaceVariable(variable: String, actualParam: Expression): Expression = this
+}
+
+case class BooleanLiteralExpression(value: Boolean) extends Expression {
+ override def replaceVariable(variable: String, actualParam: Expression): Expression = this
+}
+
+sealed trait LhsExpression extends Expression
+
+case object BlackHoleExpression extends LhsExpression {
+ override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this
+}
+
+case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsExpression {
+ def replaceVariable(variable: String, actualParam: Expression): Expression =
+ SeparateBytesExpression(
+ hi.replaceVariable(variable, actualParam),
+ lo.replaceVariable(variable, actualParam))
+}
+
+case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression {
+ override def replaceVariable(variable: String, actualParam: Expression): Expression =
+ SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal)
+}
+
+case class FunctionCallExpression(functionName: String, expressions: List[Expression]) extends Expression {
+ override def replaceVariable(variable: String, actualParam: Expression): Expression =
+ FunctionCallExpression(functionName, expressions.map {
+ _.replaceVariable(variable, actualParam)
+ })
+}
+
+case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression {
+ override def replaceVariable(variable: String, actualParam: Expression): Expression =
+ HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte)
+}
+
+object Register extends Enumeration {
+ val A, X, Y, AX, AY, YA, XA, XY, YX = Value
+}
+
+//case class Indexing(child: Expression, register: Register.Value) extends Expression
+
+case class VariableExpression(name: String) extends LhsExpression {
+ override def replaceVariable(variable: String, actualParam: Expression): Expression =
+ if (name == variable) actualParam else this
+}
+
+case class IndexedExpression(name: String, index: Expression) extends LhsExpression {
+ override def replaceVariable(variable: String, actualParam: Expression): Expression =
+ if (name == variable) {
+ actualParam match {
+ case VariableExpression(actualVariable) => IndexedExpression(actualVariable, index.replaceVariable(variable, actualParam))
+ case _ => ??? // TODO
+ }
+ } else IndexedExpression(name, index.replaceVariable(variable, actualParam))
+}
+
+sealed trait Statement extends Node {
+ def getAllExpressions: List[Expression]
+}
+
+sealed trait DeclarationStatement extends Statement
+
+case class TypeDefinitionStatement(name: String, parent: String) extends DeclarationStatement {
+ override def getAllExpressions: List[Expression] = Nil
+}
+
+case class VariableDeclarationStatement(name: String,
+ typ: String,
+ global: Boolean,
+ stack: Boolean,
+ constant: Boolean,
+ volatile: Boolean,
+ initialValue: Option[Expression],
+ address: Option[Expression]) extends DeclarationStatement {
+ override def getAllExpressions: List[Expression] = List(initialValue, address).flatten
+}
+
+case class ArrayDeclarationStatement(name: String,
+ length: Option[Expression],
+ address: Option[Expression],
+ elements: Option[List[Expression]]) extends DeclarationStatement {
+ override def getAllExpressions: List[Expression] = List(length, address).flatten ++ elements.getOrElse(Nil)
+}
+
+case class ParameterDeclaration(typ: String,
+ assemblyParamPassingConvention: ParamPassingConvention) extends Node
+
+case class ImportStatement(filename: String) extends DeclarationStatement {
+ override def getAllExpressions: List[Expression] = Nil
+}
+
+case class FunctionDeclarationStatement(name: String,
+ resultType: String,
+ params: List[ParameterDeclaration],
+ address: Option[Expression],
+ statements: Option[List[Statement]],
+ inlined: Boolean,
+ assembly: Boolean,
+ interrupt: Boolean,
+ reentrant: Boolean) extends DeclarationStatement {
+ override def getAllExpressions: List[Expression] = address.toList ++ statements.getOrElse(Nil).flatMap(_.getAllExpressions)
+}
+
+sealed trait ExecutableStatement extends Statement
+
+case class ExpressionStatement(expression: Expression) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = List(expression)
+}
+
+case class ReturnStatement(value: Option[Expression]) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = value.toList
+}
+
+case class Assignment(destination: LhsExpression, source: Expression) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = List(destination, source)
+}
+
+case class LabelStatement(label: Label) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = Nil
+}
+
+case class AssemblyStatement(opcode: Opcode.Value, addrMode: AddrMode.Value, expression: Expression, elidable: Boolean) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = List(expression)
+}
+
+case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = condition :: (thenBranch ++ elseBranch).flatMap(_.getAllExpressions)
+}
+
+case class WhileStatement(condition: Expression, body: List[ExecutableStatement]) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions)
+}
+
+object ForDirection extends Enumeration {
+ val To, Until, DownTo, ParallelTo, ParallelUntil = Value
+}
+
+case class ForStatement(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: List[ExecutableStatement]) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = start :: end :: body.flatMap(_.getAllExpressions)
+}
+
+case class DoWhileStatement(body: List[ExecutableStatement], condition: Expression) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions)
+}
+
+case class BlockStatement(body: List[ExecutableStatement]) extends ExecutableStatement {
+ override def getAllExpressions: List[Expression] = body.flatMap(_.getAllExpressions)
+}
+
+object AssemblyStatement {
+ def implied(opcode: Opcode.Value, elidable: Boolean) = AssemblyStatement(opcode, AddrMode.Implied, LiteralExpression(0, 1), elidable)
+
+ def nonexistent(opcode: Opcode.Value) = AssemblyStatement(opcode, AddrMode.DoesNotExist, LiteralExpression(0, 1), elidable = true)
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/node/Program.scala b/src/main/scala/millfork/node/Program.scala
new file mode 100644
index 00000000..89f1e12a
--- /dev/null
+++ b/src/main/scala/millfork/node/Program.scala
@@ -0,0 +1,11 @@
+package millfork.node
+
+import millfork.node.opt.NodeOptimization
+
+/**
+ * @author Karol Stasiak
+ */
+case class Program(declarations: List[DeclarationStatement]) {
+ def applyNodeOptimization(o: NodeOptimization) = Program(o.optimize(declarations).asInstanceOf[List[DeclarationStatement]])
+ def +(p:Program): Program = Program(this.declarations ++ p.declarations)
+}
diff --git a/src/main/scala/millfork/node/opt/NodeOptimization.scala b/src/main/scala/millfork/node/opt/NodeOptimization.scala
new file mode 100644
index 00000000..94421e8e
--- /dev/null
+++ b/src/main/scala/millfork/node/opt/NodeOptimization.scala
@@ -0,0 +1,16 @@
+package millfork.node.opt
+
+import millfork.node.{ExecutableStatement, Expression, Node, Statement}
+
+/**
+ * @author Karol Stasiak
+ */
+trait NodeOptimization {
+ def optimize(nodes: List[Node]): List[Node]
+
+ def optimizeExecutableStatements(nodes: List[ExecutableStatement]): List[ExecutableStatement] =
+ optimize(nodes).asInstanceOf[List[ExecutableStatement]]
+
+ def optimizeStatements(nodes: List[Statement]): List[Statement] =
+ optimize(nodes).asInstanceOf[List[Statement]]
+}
diff --git a/src/main/scala/millfork/node/opt/UnreachableCode.scala b/src/main/scala/millfork/node/opt/UnreachableCode.scala
new file mode 100644
index 00000000..fbff2a7b
--- /dev/null
+++ b/src/main/scala/millfork/node/opt/UnreachableCode.scala
@@ -0,0 +1,29 @@
+package millfork.node.opt
+
+import millfork.node._
+
+/**
+ * @author Karol Stasiak
+ */
+object UnreachableCode extends NodeOptimization {
+
+ override def optimize(nodes: List[Node]): List[Node] = nodes match {
+ case (x:FunctionDeclarationStatement)::xs =>
+ x.copy(statements = x.statements.map(optimizeStatements)) :: optimize(xs)
+ case (x:IfStatement)::xs =>
+ x.copy(
+ thenBranch = optimizeExecutableStatements(x.thenBranch),
+ elseBranch = optimizeExecutableStatements(x.elseBranch)) :: optimize(xs)
+ case (x:WhileStatement)::xs =>
+ x.copy(body = optimizeExecutableStatements(x.body)) :: optimize(xs)
+ case (x:DoWhileStatement)::xs =>
+ x.copy(body = optimizeExecutableStatements(x.body)) :: optimize(xs)
+ case (x:ReturnStatement) :: xs =>
+ x :: Nil
+ case x :: xs =>
+ x :: optimize(xs)
+ case Nil =>
+ Nil
+ }
+
+}
diff --git a/src/main/scala/millfork/node/opt/UnusedFunctions.scala b/src/main/scala/millfork/node/opt/UnusedFunctions.scala
new file mode 100644
index 00000000..bbb158f8
--- /dev/null
+++ b/src/main/scala/millfork/node/opt/UnusedFunctions.scala
@@ -0,0 +1,72 @@
+package millfork.node.opt
+
+import millfork.env._
+import millfork.error.ErrorReporting
+import millfork.node._
+
+/**
+ * @author Karol Stasiak
+ */
+object UnusedFunctions extends NodeOptimization {
+
+ override def optimize(nodes: List[Node]): List[Node] = {
+ val allNormalFunctions = nodes.flatMap {
+ case v: FunctionDeclarationStatement => if (v.address.isDefined || v.interrupt || v.name == "main") Nil else List(v.name)
+ case _ => Nil
+ }.toSet
+ val allCalledFunctions = getAllCalledFunctions(nodes).toSet
+ val unusedFunctions = allNormalFunctions -- allCalledFunctions
+ if (unusedFunctions.nonEmpty) {
+ ErrorReporting.debug("Removing unused functions: " + unusedFunctions.mkString(", "))
+ }
+ removeFunctionsFromProgram(nodes, unusedFunctions)
+ }
+
+ private def removeFunctionsFromProgram(nodes: List[Node], unusedVariables: Set[String]): List[Node] = {
+ nodes match {
+ case (x: FunctionDeclarationStatement) :: xs if unusedVariables(x.name) =>
+ removeFunctionsFromProgram(xs, unusedVariables)
+ case x :: xs =>
+ x :: removeFunctionsFromProgram(xs, unusedVariables)
+ case Nil =>
+ Nil
+ }
+ }
+
+ def getAllCalledFunctions(c: Constant): List[String] = c match {
+ case HalfWordConstant(cc, _) => getAllCalledFunctions(cc)
+ case SubbyteConstant(cc, _) => getAllCalledFunctions(cc)
+ case CompoundConstant(_, l, r) => getAllCalledFunctions(l) ++ getAllCalledFunctions(r)
+ case MemoryAddressConstant(th) => List(
+ th.name,
+ th.name.stripSuffix(".addr"),
+ th.name.stripSuffix(".hi"),
+ th.name.stripSuffix(".lo"),
+ th.name.stripSuffix(".addr.lo"),
+ th.name.stripSuffix(".addr.hi"))
+ case _ => Nil
+ }
+
+ def getAllCalledFunctions(expressions: List[Node]): List[String] = expressions.flatMap {
+ case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList)
+ case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil))
+ case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil))
+ case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil)
+ case s: Statement => getAllCalledFunctions(s.getAllExpressions)
+ case s: VariableExpression => List(
+ s.name,
+ s.name.stripSuffix(".addr"),
+ s.name.stripSuffix(".hi"),
+ s.name.stripSuffix(".lo"),
+ s.name.stripSuffix(".addr.lo"),
+ s.name.stripSuffix(".addr.hi"))
+ case s: LiteralExpression => Nil
+ case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil)
+ case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2))
+ case FunctionCallExpression(name, xs) => name :: getAllCalledFunctions(xs)
+ case IndexedExpression(arr, index) => arr :: getAllCalledFunctions(List(index))
+ case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l))
+ case _ => Nil
+ }
+
+}
diff --git a/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala b/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala
new file mode 100644
index 00000000..83025f09
--- /dev/null
+++ b/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala
@@ -0,0 +1,104 @@
+package millfork.node.opt
+
+import millfork.env._
+import millfork.error.ErrorReporting
+import millfork.node._
+
+/**
+ * @author Karol Stasiak
+ */
+object UnusedGlobalVariables extends NodeOptimization {
+
+ override def optimize(nodes: List[Node]): List[Node] = {
+
+ // TODO: volatile
+ val allNonvolatileGlobalVariables = nodes.flatMap {
+ case v: VariableDeclarationStatement => if (v.address.isDefined) Nil else List(v.name)
+ case v: ArrayDeclarationStatement => if (v.address.isDefined) Nil else List(v.name)
+ case _ => Nil
+ }.toSet
+ val allReadVariables = getAllReadVariables(nodes).toSet
+ val unusedVariables = allNonvolatileGlobalVariables -- allReadVariables
+ if (unusedVariables.nonEmpty) {
+ ErrorReporting.debug("Removing unused global variables: " + unusedVariables.mkString(", "))
+ }
+ removeVariablesFromProgram(nodes, unusedVariables.flatMap(v => Set(v, v + ".hi", v + ".lo")))
+ }
+
+ private def removeVariablesFromProgram(nodes: List[Node], unusedVariables: Set[String]): List[Node] = {
+ nodes match {
+ case (x: ArrayDeclarationStatement) :: xs if unusedVariables(x.name) => removeVariablesFromProgram(xs, unusedVariables)
+ case (x: VariableDeclarationStatement) :: xs if unusedVariables(x.name) => removeVariablesFromProgram(xs, unusedVariables)
+ case (x: FunctionDeclarationStatement) :: xs =>
+ x.copy(statements = x.statements.map(s => removeVariablesFromStatement(s, unusedVariables))) :: removeVariablesFromProgram(xs, unusedVariables)
+ case x :: xs =>
+ x :: removeVariablesFromProgram(xs, unusedVariables)
+ case Nil =>
+ Nil
+ }
+ }
+
+ def getAllReadVariables(c: Constant): List[String] = c match {
+ case HalfWordConstant(cc, _) => getAllReadVariables(cc)
+ case SubbyteConstant(cc, _) => getAllReadVariables(cc)
+ case CompoundConstant(_, l, r) => getAllReadVariables(l) ++ getAllReadVariables(r)
+ case MemoryAddressConstant(th) => List(th.name.takeWhile(_ != '.'))
+ case _ => Nil
+ }
+
+ def getAllReadVariables(expressions: List[Node]): List[String] = expressions.flatMap {
+ case s: VariableDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.initialValue.toList)
+ case s: ArrayDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.elements.getOrElse(Nil))
+ case s: FunctionDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.statements.getOrElse(Nil))
+ case Assignment(VariableExpression(_), expr) => getAllReadVariables(expr :: Nil)
+ case ExpressionStatement(FunctionCallExpression(op, VariableExpression(_) :: params)) if op.endsWith("=") => getAllReadVariables(params)
+ case s: Statement => getAllReadVariables(s.getAllExpressions)
+ case s: VariableExpression => List(s.name.takeWhile(_ != '.'))
+ case s: LiteralExpression => Nil
+ case HalfWordExpression(param, _) => getAllReadVariables(param :: Nil)
+ case SumExpression(xs, _) => getAllReadVariables(xs.map(_._2))
+ case FunctionCallExpression(name, xs) => name :: getAllReadVariables(xs)
+ case IndexedExpression(arr, index) => arr :: getAllReadVariables(List(index))
+ case SeparateBytesExpression(h, l) => getAllReadVariables(List(h, l))
+ case _ => Nil
+ }
+
+ def removeVariablesFromStatement(statements: List[Statement], globalsToRemove: Set[String]): List[Statement] = statements.flatMap {
+ case s: VariableDeclarationStatement =>
+ if (globalsToRemove(s.name)) None else Some(s)
+ case s@ExpressionStatement(FunctionCallExpression(op, VariableExpression(n) :: params)) if op.endsWith("=") =>
+ if (globalsToRemove(n)) params.map(ExpressionStatement) else Some(s)
+ case s@Assignment(VariableExpression(n), expr) =>
+ if (globalsToRemove(n)) Some(ExpressionStatement(expr)) else Some(s)
+ case s@Assignment(SeparateBytesExpression(VariableExpression(h), VariableExpression(l)), expr) =>
+ if (globalsToRemove(h)) {
+ if (globalsToRemove(l))
+ Some(ExpressionStatement(expr))
+ else
+ Some(Assignment(SeparateBytesExpression(BlackHoleExpression, VariableExpression(l)), expr))
+ } else {
+ if (globalsToRemove(l))
+ Some(Assignment(SeparateBytesExpression(VariableExpression(h), BlackHoleExpression), expr))
+ else
+ Some(s)
+ }
+ case s@Assignment(SeparateBytesExpression(h, VariableExpression(l)), expr) =>
+ if (globalsToRemove(l)) Some(Assignment(SeparateBytesExpression(h, BlackHoleExpression), expr))
+ else Some(s)
+ case s@Assignment(SeparateBytesExpression(VariableExpression(h), l), expr) =>
+ if (globalsToRemove(h)) Some(Assignment(SeparateBytesExpression(BlackHoleExpression, l), expr))
+ else Some(s)
+ case s: IfStatement =>
+ Some(s.copy(
+ thenBranch = removeVariablesFromStatement(s.thenBranch, globalsToRemove).asInstanceOf[List[ExecutableStatement]],
+ elseBranch = removeVariablesFromStatement(s.elseBranch, globalsToRemove).asInstanceOf[List[ExecutableStatement]]))
+ case s: WhileStatement =>
+ Some(s.copy(
+ body = removeVariablesFromStatement(s.body, globalsToRemove).asInstanceOf[List[ExecutableStatement]]))
+ case s: DoWhileStatement =>
+ Some(s.copy(
+ body = removeVariablesFromStatement(s.body, globalsToRemove).asInstanceOf[List[ExecutableStatement]]))
+ case s => Some(s)
+ }
+
+}
diff --git a/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala b/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala
new file mode 100644
index 00000000..5b66d389
--- /dev/null
+++ b/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala
@@ -0,0 +1,114 @@
+package millfork.node.opt
+
+import millfork.assembly.AssemblyLine
+import millfork.env._
+import millfork.error.ErrorReporting
+import millfork.node._
+
+/**
+ * @author Karol Stasiak
+ */
+object UnusedLocalVariables extends NodeOptimization {
+
+ override def optimize(nodes: List[Node]): List[Node] = nodes match {
+ case (x: FunctionDeclarationStatement) :: xs =>
+ x.copy(statements = x.statements.map(optimizeVariables)) :: optimize(xs)
+ case x :: xs =>
+ x :: optimize(xs)
+ case Nil =>
+ Nil
+ }
+
+ def getAllLocalVariables(statements: List[Statement]): List[String] = statements.flatMap {
+ case v: VariableDeclarationStatement => List(v.name)
+ case x: IfStatement => getAllLocalVariables(x.thenBranch) ++ getAllLocalVariables(x.elseBranch)
+ case x: WhileStatement => getAllLocalVariables(x.body)
+ case x: DoWhileStatement => getAllLocalVariables(x.body)
+ case _ => Nil
+ }
+
+ def getAllReadVariables(c: Constant): List[String] = c match {
+ case HalfWordConstant(cc, _) => getAllReadVariables(cc)
+ case SubbyteConstant(cc, _) => getAllReadVariables(cc)
+ case CompoundConstant(_, l, r) => getAllReadVariables(l) ++ getAllReadVariables(r)
+ case MemoryAddressConstant(th) => List(
+ th.name,
+ th.name.stripSuffix(".addr"),
+ th.name.stripSuffix(".hi"),
+ th.name.stripSuffix(".lo"),
+ th.name.stripSuffix(".addr.lo"),
+ th.name.stripSuffix(".addr.hi"))
+ case _ => Nil
+ }
+
+ def getAllReadVariables(expressions: List[Node]): List[String] = expressions.flatMap {
+ case s: VariableExpression => List(
+ s.name,
+ s.name.stripSuffix(".addr"),
+ s.name.stripSuffix(".hi"),
+ s.name.stripSuffix(".lo"),
+ s.name.stripSuffix(".addr.lo"),
+ s.name.stripSuffix(".addr.hi"))
+ case s: LiteralExpression => Nil
+ case HalfWordExpression(param, _) => getAllReadVariables(param :: Nil)
+ case SumExpression(xs, _) => getAllReadVariables(xs.map(_._2))
+ case FunctionCallExpression(_, xs) => getAllReadVariables(xs)
+ case IndexedExpression(arr, index) => arr :: getAllReadVariables(List(index))
+ case SeparateBytesExpression(h, l) => getAllReadVariables(List(h, l))
+ case _ => Nil
+ }
+
+
+ def optimizeVariables(statements: List[Statement]): List[Statement] = {
+ val allLocals = getAllLocalVariables(statements)
+ val allRead = getAllReadVariables(statements.flatMap {
+ case Assignment(VariableExpression(_), expression) => List(expression)
+ case ExpressionStatement(FunctionCallExpression(op, VariableExpression(_) :: params)) if op.endsWith("=") => params
+ case x => x.getAllExpressions
+ }).toSet
+ val localsToRemove = allLocals.filterNot(allRead).toSet
+ if (localsToRemove.nonEmpty) {
+ ErrorReporting.debug("Removing unused local variables: " + localsToRemove.mkString(", "))
+ }
+ removeVariables(statements, localsToRemove)
+ }
+
+ def removeVariables(statements: List[Statement], localsToRemove: Set[String]): List[Statement] = statements.flatMap {
+ case s: VariableDeclarationStatement =>
+ if (localsToRemove(s.name)) None else Some(s)
+ case s@ExpressionStatement(FunctionCallExpression(op, VariableExpression(n) :: params)) if op.endsWith("=") =>
+ if (localsToRemove(n)) params.map(ExpressionStatement) else Some(s)
+ case s@Assignment(VariableExpression(n), expr) =>
+ if (localsToRemove(n)) Some(ExpressionStatement(expr)) else Some(s)
+ case s@Assignment(SeparateBytesExpression(VariableExpression(h), VariableExpression(l)), expr) =>
+ if (localsToRemove(h)) {
+ if (localsToRemove(l))
+ Some(ExpressionStatement(expr))
+ else
+ Some(Assignment(SeparateBytesExpression(BlackHoleExpression, VariableExpression(l)), expr))
+ } else {
+ if (localsToRemove(l))
+ Some(Assignment(SeparateBytesExpression(VariableExpression(h), BlackHoleExpression), expr))
+ else
+ Some(s)
+ }
+ case s@Assignment(SeparateBytesExpression(h, VariableExpression(l)), expr) =>
+ if (localsToRemove(l)) Some(Assignment(SeparateBytesExpression(h, BlackHoleExpression), expr))
+ else Some(s)
+ case s@Assignment(SeparateBytesExpression(VariableExpression(h), l), expr) =>
+ if (localsToRemove(h)) Some(Assignment(SeparateBytesExpression(BlackHoleExpression, l), expr))
+ else Some(s)
+ case s: IfStatement =>
+ Some(s.copy(
+ thenBranch = removeVariables(s.thenBranch, localsToRemove).asInstanceOf[List[ExecutableStatement]],
+ elseBranch = removeVariables(s.elseBranch, localsToRemove).asInstanceOf[List[ExecutableStatement]]))
+ case s: WhileStatement =>
+ Some(s.copy(
+ body = removeVariables(s.body, localsToRemove).asInstanceOf[List[ExecutableStatement]]))
+ case s: DoWhileStatement =>
+ Some(s.copy(
+ body = removeVariables(s.body, localsToRemove).asInstanceOf[List[ExecutableStatement]]))
+ case s => Some(s)
+ }
+
+}
diff --git a/src/main/scala/millfork/output/Assembler.scala b/src/main/scala/millfork/output/Assembler.scala
new file mode 100644
index 00000000..d4f06231
--- /dev/null
+++ b/src/main/scala/millfork/output/Assembler.scala
@@ -0,0 +1,612 @@
+package millfork.output
+
+import millfork.assembly.opt.AssemblyOptimization
+import millfork.assembly.{AddrMode, AssemblyLine, Opcode}
+import millfork.compiler.{CompilationContext, MlCompiler}
+import millfork.env._
+import millfork.error.ErrorReporting
+import millfork.node.CallGraph
+import millfork.{CompilationFlag, CompilationOptions}
+
+import scala.collection.mutable
+
+/**
+ * @author Karol Stasiak
+ */
+
+case class AssemblerOutput(code: Array[Byte], asm: Array[String], labels: List[(String, Int)])
+
+class Assembler(private val rootEnv: Environment) {
+
+ var env = rootEnv.allThings
+ var unoptimizedCodeSize = 0
+ var optimizedCodeSize = 0
+ var initializedArraysSize = 0
+
+ val mem = new CompiledMemory
+ val labelMap = mutable.Map[String, Int]()
+ val bytesToWriteLater = mutable.ListBuffer[(Int, Constant)]()
+ val wordsToWriteLater = mutable.ListBuffer[(Int, Constant)]()
+
+ def writeByte(bank: Int, addr: Int, value: Byte): Unit = {
+ if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
+ mem.banks(bank).occupied(addr) = true
+ mem.banks(bank).readable(addr) = true
+ mem.banks(bank).output(addr) = value.toByte
+ }
+
+ def writeByte(bank: Int, addr: Int, value: Constant): Unit = {
+ if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
+ mem.banks(bank).occupied(addr) = true
+ mem.banks(bank).readable(addr) = true
+ value match {
+ case NumericConstant(x, _) =>
+ if (x > 0xffff) ErrorReporting.error("Byte overflow")
+ mem.banks(0).output(addr) = x.toByte
+ case _ =>
+ bytesToWriteLater += addr -> value
+ }
+ }
+
+ def writeWord(bank: Int, addr: Int, value: Constant): Unit = {
+ if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
+ mem.banks(bank).occupied(addr) = true
+ mem.banks(bank).occupied(addr + 1) = true
+ mem.banks(bank).readable(addr) = true
+ mem.banks(bank).readable(addr + 1) = true
+ value match {
+ case NumericConstant(x, _) =>
+ if (x > 0xffff) ErrorReporting.error("Word overflow")
+ mem.banks(bank).output(addr) = x.toByte
+ mem.banks(bank).output(addr + 1) = (x >> 8).toByte
+ case _ =>
+ wordsToWriteLater += addr -> value
+ }
+ }
+
+ def deepConstResolve(c: Constant): Long = {
+ c match {
+ case NumericConstant(v, _) => v
+ case MemoryAddressConstant(th) =>
+ if (labelMap.contains(th.name)) return labelMap(th.name)
+ if (labelMap.contains(th.name + "`")) return labelMap(th.name)
+ if (labelMap.contains(th.name + ".addr")) return labelMap(th.name)
+ val x1 = env.maybeGet[ConstantThing](th.name).map(_.value)
+ val x2 = env.maybeGet[ConstantThing](th.name + "`").map(_.value)
+ val x3 = env.maybeGet[NormalFunction](th.name).flatMap(_.address)
+ val x4 = env.maybeGet[ConstantThing](th.name + ".addr").map(_.value)
+ val x5 = env.maybeGet[RelativeVariable](th.name).map(_.address)
+ val x6 = env.maybeGet[ConstantThing](th.name.stripSuffix(".array") + ".addr").map(_.value)
+ val x = x1.orElse(x2).orElse(x3).orElse(x4).orElse(x5).orElse(x6)
+ x match {
+ case Some(cc) =>
+ deepConstResolve(cc)
+ case None =>
+ println(th)
+ ???
+ }
+ case HalfWordConstant(cc, true) => deepConstResolve(cc).>>>(8).&(0xff)
+ case HalfWordConstant(cc, false) => deepConstResolve(cc).&(0xff)
+ case SubbyteConstant(cc, i) => deepConstResolve(cc).>>>(i * 8).&(0xff)
+ case CompoundConstant(operator, lc, rc) =>
+ val l = deepConstResolve(lc)
+ val r = deepConstResolve(rc)
+ operator match {
+ case MathOperator.Plus => l + r
+ case MathOperator.Minus => l - r
+ case MathOperator.Times => l * r
+ case MathOperator.Shl => l << r
+ case MathOperator.Shr => l >>> r
+ case MathOperator.DecimalPlus => asDecimal(l, r, _ + _)
+ case MathOperator.DecimalMinus => asDecimal(l, r, _ - _)
+ case MathOperator.DecimalTimes => asDecimal(l, r, _ * _)
+ case MathOperator.DecimalShl => asDecimal(l, 1 << r, _ * _)
+ case MathOperator.DecimalShr => asDecimal(l, 1 << r, _ / _)
+ case MathOperator.And => l & r
+ case MathOperator.Exor => l ^ r
+ case MathOperator.Or => l | r
+ }
+ }
+ }
+
+ private def parseNormalToDecimalValue(a: Long): Long = {
+ if (a < 0) -parseNormalToDecimalValue(-a)
+ var x = a
+ var result = 0L
+ var multiplier = 1L
+ while (x > 0) {
+ result += multiplier * (a % 16L)
+ x /= 16L
+ multiplier *= 10L
+ }
+ result
+ }
+
+ private def storeDecimalValueInNormalRespresentation(a: Long): Long = {
+ if (a < 0) -storeDecimalValueInNormalRespresentation(-a)
+ var x = a
+ var result = 0L
+ var multiplier = 1L
+ while (x > 0) {
+ result += multiplier * (a % 10L)
+ x /= 10L
+ multiplier *= 16L
+ }
+ result
+ }
+
+ private def asDecimal(a: Long, b: Long, f: (Long, Long) => Long): Long =
+ storeDecimalValueInNormalRespresentation(f(parseNormalToDecimalValue(a), parseNormalToDecimalValue(b)))
+
+ def assemble(callGraph: CallGraph, optimizations: Seq[AssemblyOptimization], options: CompilationOptions): AssemblerOutput = {
+ val platform = options.platform
+
+ val assembly = mutable.ArrayBuffer[String]()
+
+ env.allPreallocatables.foreach {
+ case InitializedArray(name, Some(NumericConstant(address, _)), items) =>
+ var index = address.toInt
+ assembly.append("* = $" + index.toHexString)
+ assembly.append(name)
+ for (item <- items) {
+ writeByte(0, index, item)
+ assembly.append(" !byte " + item)
+ mem.banks(0).writeable(index) = true
+ index += 1
+ }
+ initializedArraysSize += items.length
+ case InitializedArray(name, Some(_), items) => ???
+ case f: NormalFunction if f.address.isDefined =>
+ var index = f.address.get.asInstanceOf[NumericConstant].value.toInt
+ labelMap(f.name) = index
+ compileFunction(f, index, optimizations, assembly, options)
+ case _ =>
+ }
+
+ var index = platform.org
+ env.allPreallocatables.foreach {
+ case f: NormalFunction if f.address.isEmpty && f.name == "main" =>
+ labelMap(f.name) = index
+ index = compileFunction(f, index, optimizations, assembly, options)
+ case _ =>
+ }
+ env.allPreallocatables.foreach {
+ case f: NormalFunction if f.address.isEmpty && f.name != "main" =>
+ labelMap(f.name) = index
+ index = compileFunction(f, index, optimizations, assembly, options)
+ case _ =>
+ }
+ env.allPreallocatables.foreach {
+ case InitializedArray(name, None, items) =>
+ labelMap(name) = index
+ assembly.append("* = $" + index.toHexString)
+ assembly.append(name)
+ for (item <- items) {
+ writeByte(0, index, item)
+ assembly.append(" !byte " + item)
+ mem.banks(0).writeable(index) = true
+ index += 1
+ }
+ initializedArraysSize += items.length
+ case _ =>
+ }
+ val allocator = platform.allocator
+ allocator.notifyAboutEndOfCode(index)
+ allocator.onEachByte = { addr =>
+ mem.banks(0).readable(addr) = true
+ mem.banks(0).writeable(addr) = true
+ }
+ env.allocateVariables(None, callGraph, allocator, options, labelMap.put)
+
+ env = rootEnv.allThings
+
+ for ((addr, b) <- bytesToWriteLater) {
+ val value = deepConstResolve(b)
+ mem.banks(0).output(addr) = value.toByte
+ }
+ for ((addr, b) <- wordsToWriteLater) {
+ val value = deepConstResolve(b)
+ mem.banks(0).output(addr) = value.toByte
+ mem.banks(0).output(addr + 1) = value.>>>(8).toByte
+ }
+
+ val start = mem.banks(0).occupied.indexOf(true)
+ val end = mem.banks(0).occupied.lastIndexOf(true)
+ val length = end - start + 1
+ mem.banks(0).start = start
+ mem.banks(0).end = end
+
+ labelMap.toList.sorted.foreach {case (l, v) =>
+ assembly += f"$l%-30s = $$$v%04X"
+ }
+ labelMap.toList.sortBy{case (a,b) => b->a}.foreach {case (l, v) =>
+ assembly += f" ; $$$v%04X = $l%s"
+ }
+
+ AssemblerOutput(platform.outputPackager.packageOutput(mem, 0), assembly.toArray, labelMap.toList)
+ }
+
+ private def compileFunction(f: NormalFunction, startFrom: Int, optimizations: Seq[AssemblyOptimization], assOut: mutable.ArrayBuffer[String], options: CompilationOptions): Int = {
+ ErrorReporting.debug("Compiling: " + f.name, f.position)
+ var index = startFrom
+ assOut.append("* = $" + startFrom.toHexString)
+ val unoptimized = MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize
+ unoptimizedCodeSize += unoptimized.map(_.sizeInBytes).sum
+ val code = optimizations.foldLeft(unoptimized) { (c, opt) =>
+ opt.optimize(f, c, options)
+ }
+ optimizedCodeSize += code.map(_.sizeInBytes).sum
+ import millfork.assembly.AddrMode._
+ import millfork.assembly.Opcode._
+ for (instr <- code) {
+ if (instr.isPrintable) {
+ assOut.append(instr.toString)
+ }
+ instr match {
+ case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(labelName)), _) =>
+ labelMap(labelName) = index
+ case AssemblyLine(_, DoesNotExist, _, _) =>
+ ()
+ case AssemblyLine(op, Implied, _, _) =>
+ writeByte(0, index, Assembler.opcodeFor(op, Implied, options))
+ index += 1
+ case AssemblyLine(op, Relative, param, _) =>
+ writeByte(0, index, Assembler.opcodeFor(op, Relative, options))
+ writeByte(0, index + 1, param - (index + 2))
+ index += 2
+ case AssemblyLine(op, am@(Immediate | ZeroPage | ZeroPageX | ZeroPageY | IndexedY | IndexedX | ZeroPageIndirect), param, _) =>
+ writeByte(0, index, Assembler.opcodeFor(op, am, options))
+ writeByte(0, index + 1, param)
+ index += 2
+ case AssemblyLine(op, am@(Absolute | AbsoluteY | AbsoluteX | Indirect | AbsoluteIndexedX), param, _) =>
+ writeByte(0, index, Assembler.opcodeFor(op, am, options))
+ writeWord(0, index + 1, param)
+ index += 3
+ }
+ }
+ index
+ }
+}
+
+object Assembler {
+ val opcodes = mutable.Map[(Opcode.Value, AddrMode.Value), Byte]()
+ val illegalOpcodes = mutable.Map[(Opcode.Value, AddrMode.Value), Byte]()
+ val cmosOpcodes = mutable.Map[(Opcode.Value, AddrMode.Value), Byte]()
+
+ def opcodeFor(opcode: Opcode.Value, addrMode: AddrMode.Value, options: CompilationOptions): Byte = {
+ val key = opcode -> addrMode
+ opcodes.get(key) match {
+ case Some(v) => v
+ case None =>
+ illegalOpcodes.get(key) match {
+ case Some(v) =>
+ if (options.flag(CompilationFlag.EmitIllegals)) v
+ else ErrorReporting.fatal("Cannot assemble an illegal opcode " + key)
+ case None =>
+ cmosOpcodes.get(key) match {
+ case Some(v) =>
+ if (options.flag(CompilationFlag.EmitCmosOpcodes)) v
+ else ErrorReporting.fatal("Cannot assemble a CMOS opcode " + key)
+ case None =>
+ ErrorReporting.fatal("Cannot assemble an unknown opcode " + key)
+ }
+ }
+ }
+ }
+
+ private def op(op: Opcode.Value, am: AddrMode.Value, x: Int): Unit = {
+ if (x < 0 || x > 0xff) ???
+ opcodes(op -> am) = x.toByte
+ if (am == AddrMode.Relative) opcodes(op -> AddrMode.Immediate) = x.toByte
+ }
+
+ private def cm(op: Opcode.Value, am: AddrMode.Value, x: Int): Unit = {
+ if (x < 0 || x > 0xff) ???
+ cmosOpcodes(op -> am) = x.toByte
+ }
+
+ private def il(op: Opcode.Value, am: AddrMode.Value, x: Int): Unit = {
+ if (x < 0 || x > 0xff) ???
+ illegalOpcodes(op -> am) = x.toByte
+ }
+
+ def getStandardLegalOpcodes: Set[Int] = opcodes.values.map(_ & 0xff).toSet
+
+ import AddrMode._
+ import Opcode._
+
+ op(ADC, Immediate, 0x69)
+ op(ADC, ZeroPage, 0x65)
+ op(ADC, ZeroPageX, 0x75)
+ op(ADC, Absolute, 0x6D)
+ op(ADC, AbsoluteX, 0x7D)
+ op(ADC, AbsoluteY, 0x79)
+ op(ADC, IndexedX, 0x61)
+ op(ADC, IndexedY, 0x71)
+
+ op(AND, Immediate, 0x29)
+ op(AND, ZeroPage, 0x25)
+ op(AND, ZeroPageX, 0x35)
+ op(AND, Absolute, 0x2D)
+ op(AND, AbsoluteX, 0x3D)
+ op(AND, AbsoluteY, 0x39)
+ op(AND, IndexedX, 0x21)
+ op(AND, IndexedY, 0x31)
+
+ op(ASL, Implied, 0x0A)
+ op(ASL, ZeroPage, 0x06)
+ op(ASL, ZeroPageX, 0x16)
+ op(ASL, Absolute, 0x0E)
+ op(ASL, AbsoluteX, 0x1E)
+
+ op(BIT, ZeroPage, 0x24)
+ op(BIT, Absolute, 0x2C)
+
+ op(BPL, Relative, 0x10)
+ op(BMI, Relative, 0x30)
+ op(BVC, Relative, 0x50)
+ op(BVS, Relative, 0x70)
+ op(BCC, Relative, 0x90)
+ op(BCS, Relative, 0xB0)
+ op(BNE, Relative, 0xD0)
+ op(BEQ, Relative, 0xF0)
+
+ op(BRK, Implied, 0)
+
+ op(CMP, Immediate, 0xC9)
+ op(CMP, ZeroPage, 0xC5)
+ op(CMP, ZeroPageX, 0xD5)
+ op(CMP, Absolute, 0xCD)
+ op(CMP, AbsoluteX, 0xDD)
+ op(CMP, AbsoluteY, 0xD9)
+ op(CMP, IndexedX, 0xC1)
+ op(CMP, IndexedY, 0xD1)
+
+ op(CPX, Immediate, 0xE0)
+ op(CPX, ZeroPage, 0xE4)
+ op(CPX, Absolute, 0xEC)
+
+ op(CPY, Immediate, 0xC0)
+ op(CPY, ZeroPage, 0xC4)
+ op(CPY, Absolute, 0xCC)
+
+ op(DEC, ZeroPage, 0xC6)
+ op(DEC, ZeroPageX, 0xD6)
+ op(DEC, Absolute, 0xCE)
+ op(DEC, AbsoluteX, 0xDE)
+
+ op(EOR, Immediate, 0x49)
+ op(EOR, ZeroPage, 0x45)
+ op(EOR, ZeroPageX, 0x55)
+ op(EOR, Absolute, 0x4D)
+ op(EOR, AbsoluteX, 0x5D)
+ op(EOR, AbsoluteY, 0x59)
+ op(EOR, IndexedX, 0x41)
+ op(EOR, IndexedY, 0x51)
+
+ op(INC, ZeroPage, 0xE6)
+ op(INC, ZeroPageX, 0xF6)
+ op(INC, Absolute, 0xEE)
+ op(INC, AbsoluteX, 0xFE)
+
+ op(CLC, Implied, 0x18)
+ op(SEC, Implied, 0x38)
+ op(CLI, Implied, 0x58)
+ op(SEI, Implied, 0x78)
+ op(CLV, Implied, 0xB8)
+ op(CLD, Implied, 0xD8)
+ op(SED, Implied, 0xF8)
+
+ op(JMP, Absolute, 0x4C)
+ op(JMP, Indirect, 0x6C)
+
+ op(JSR, Absolute, 0x20)
+
+ op(LDA, Immediate, 0xA9)
+ op(LDA, ZeroPage, 0xA5)
+ op(LDA, ZeroPageX, 0xB5)
+ op(LDA, Absolute, 0xAD)
+ op(LDA, AbsoluteX, 0xBD)
+ op(LDA, AbsoluteY, 0xB9)
+ op(LDA, IndexedX, 0xA1)
+ op(LDA, IndexedY, 0xB1)
+
+ op(LDX, Immediate, 0xA2)
+ op(LDX, ZeroPage, 0xA6)
+ op(LDX, ZeroPageY, 0xB6)
+ op(LDX, Absolute, 0xAE)
+ op(LDX, AbsoluteY, 0xBE)
+
+ op(LDY, Immediate, 0xA0)
+ op(LDY, ZeroPage, 0xA4)
+ op(LDY, ZeroPageX, 0xB4)
+ op(LDY, Absolute, 0xAC)
+ op(LDY, AbsoluteX, 0xBC)
+
+ op(LSR, Implied, 0x4A)
+ op(LSR, ZeroPage, 0x46)
+ op(LSR, ZeroPageX, 0x56)
+ op(LSR, Absolute, 0x4E)
+ op(LSR, AbsoluteX, 0x5E)
+
+ op(NOP, Implied, 0xEA)
+
+ op(ORA, Immediate, 0x09)
+ op(ORA, ZeroPage, 0x05)
+ op(ORA, ZeroPageX, 0x15)
+ op(ORA, Absolute, 0x0D)
+ op(ORA, AbsoluteX, 0x1D)
+ op(ORA, AbsoluteY, 0x19)
+ op(ORA, IndexedX, 0x01)
+ op(ORA, IndexedY, 0x11)
+
+ op(TAX, Implied, 0xAA)
+ op(TXA, Implied, 0x8A)
+ op(DEX, Implied, 0xCA)
+ op(INX, Implied, 0xE8)
+ op(TAY, Implied, 0xA8)
+ op(TYA, Implied, 0x98)
+ op(DEY, Implied, 0x88)
+ op(INY, Implied, 0xC8)
+
+ op(ROL, Implied, 0x2A)
+ op(ROL, ZeroPage, 0x26)
+ op(ROL, ZeroPageX, 0x36)
+ op(ROL, Absolute, 0x2E)
+ op(ROL, AbsoluteX, 0x3E)
+
+ op(ROR, Implied, 0x6A)
+ op(ROR, ZeroPage, 0x66)
+ op(ROR, ZeroPageX, 0x76)
+ op(ROR, Absolute, 0x6E)
+ op(ROR, AbsoluteX, 0x7E)
+
+ op(RTI, Implied, 0x40)
+ op(RTS, Implied, 0x60)
+
+ op(SBC, Immediate, 0xE9)
+ op(SBC, ZeroPage, 0xE5)
+ op(SBC, ZeroPageX, 0xF5)
+ op(SBC, Absolute, 0xED)
+ op(SBC, AbsoluteX, 0xFD)
+ op(SBC, AbsoluteY, 0xF9)
+ op(SBC, IndexedX, 0xE1)
+ op(SBC, IndexedY, 0xF1)
+
+ op(STA, ZeroPage, 0x85)
+ op(STA, ZeroPageX, 0x95)
+ op(STA, Absolute, 0x8D)
+ op(STA, AbsoluteX, 0x9D)
+ op(STA, AbsoluteY, 0x99)
+ op(STA, IndexedX, 0x81)
+ op(STA, IndexedY, 0x91)
+
+ op(TXS, Implied, 0x9A)
+ op(TSX, Implied, 0xBA)
+ op(PHA, Implied, 0x48)
+ op(PLA, Implied, 0x68)
+ op(PHP, Implied, 0x08)
+ op(PLP, Implied, 0x28)
+
+ op(STX, ZeroPage, 0x86)
+ op(STX, ZeroPageY, 0x96)
+ op(STX, Absolute, 0x8E)
+
+ op(STY, ZeroPage, 0x84)
+ op(STY, ZeroPageX, 0x94)
+ op(STY, Absolute, 0x8C)
+
+ il(LAX, ZeroPage, 0xA7)
+ il(LAX, ZeroPageY, 0xB7)
+ il(LAX, Absolute, 0xAF)
+ il(LAX, AbsoluteY, 0xBF)
+ il(LAX, IndexedX, 0xA3)
+ il(LAX, IndexedY, 0xB3)
+
+ il(SAX, ZeroPage, 0x87)
+ il(SAX, ZeroPageY, 0x97)
+ il(SAX, Absolute, 0x8F)
+ il(TAS, AbsoluteY, 0x9B)
+ il(AHX, AbsoluteY, 0x9F)
+ il(SAX, IndexedX, 0x83)
+ il(AHX, IndexedY, 0x93)
+
+ il(ANC, Immediate, 0x0B)
+ il(ALR, Immediate, 0x4B)
+ il(ARR, Immediate, 0x6B)
+ il(XAA, Immediate, 0x8B)
+ il(LXA, Immediate, 0xAB)
+ il(SBX, Immediate, 0xCB)
+
+ il(SLO, ZeroPage, 0x07)
+ il(SLO, ZeroPageX, 0x17)
+ il(SLO, IndexedX, 0x03)
+ il(SLO, IndexedY, 0x13)
+ il(SLO, Absolute, 0x0F)
+ il(SLO, AbsoluteX, 0x1F)
+ il(SLO, AbsoluteY, 0x1B)
+
+ il(RLA, ZeroPage, 0x27)
+ il(RLA, ZeroPageX, 0x37)
+ il(RLA, IndexedX, 0x23)
+ il(RLA, IndexedY, 0x33)
+ il(RLA, Absolute, 0x2F)
+ il(RLA, AbsoluteX, 0x3F)
+ il(RLA, AbsoluteY, 0x3B)
+
+ il(SRE, ZeroPage, 0x47)
+ il(SRE, ZeroPageX, 0x57)
+ il(SRE, IndexedX, 0x43)
+ il(SRE, IndexedY, 0x53)
+ il(SRE, Absolute, 0x4F)
+ il(SRE, AbsoluteX, 0x5F)
+ il(SRE, AbsoluteY, 0x5B)
+
+ il(RRA, ZeroPage, 0x67)
+ il(RRA, ZeroPageX, 0x77)
+ il(RRA, IndexedX, 0x63)
+ il(RRA, IndexedY, 0x73)
+ il(RRA, Absolute, 0x6F)
+ il(RRA, AbsoluteX, 0x7F)
+ il(RRA, AbsoluteY, 0x7B)
+
+ il(DCP, ZeroPage, 0xC7)
+ il(DCP, ZeroPageX, 0xD7)
+ il(DCP, IndexedX, 0xC3)
+ il(DCP, IndexedY, 0xD3)
+ il(DCP, Absolute, 0xCF)
+ il(DCP, AbsoluteX, 0xDF)
+ il(DCP, AbsoluteY, 0xDB)
+
+ il(ISC, ZeroPage, 0xE7)
+ il(ISC, ZeroPageX, 0xF7)
+ il(ISC, IndexedX, 0xE3)
+ il(ISC, IndexedY, 0xF3)
+ il(ISC, Absolute, 0xEF)
+ il(ISC, AbsoluteX, 0xFF)
+ il(ISC, AbsoluteY, 0xFB)
+
+ il(NOP, Immediate, 0x80)
+ il(NOP, ZeroPage, 0x44)
+ il(NOP, ZeroPageX, 0x54)
+ il(NOP, Absolute, 0x5C)
+ il(NOP, AbsoluteX, 0x1C)
+
+ cm(NOP, Immediate, 0x02)
+ cm(NOP, ZeroPage, 0x44)
+ cm(NOP, ZeroPageX, 0x54)
+ cm(NOP, Absolute, 0x5C)
+
+ cm(STZ, ZeroPage, 0x64)
+ cm(STZ, ZeroPageX, 0x74)
+ cm(STZ, Absolute, 0x9C)
+ cm(STZ, AbsoluteX, 0x9E)
+
+ cm(PHX, Implied, 0xDA)
+ cm(PHY, Implied, 0x5A)
+ cm(PLX, Implied, 0xFA)
+ cm(PLY, Implied, 0x7A)
+
+ cm(ORA, ZeroPageIndirect, 0x12)
+ cm(AND, ZeroPageIndirect, 0x32)
+ cm(EOR, ZeroPageIndirect, 0x52)
+ cm(ADC, ZeroPageIndirect, 0x72)
+ cm(STA, ZeroPageIndirect, 0x92)
+ cm(LDA, ZeroPageIndirect, 0xB2)
+ cm(CMP, ZeroPageIndirect, 0xD2)
+ cm(SBC, ZeroPageIndirect, 0xF2)
+
+ cm(TSB, ZeroPage, 0x04)
+ cm(TSB, Absolute, 0x0C)
+ cm(TRB, ZeroPage, 0x14)
+ cm(TRB, Absolute, 0x1C)
+
+ cm(BIT, ZeroPageX, 0x34)
+ cm(BIT, AbsoluteX, 0x3C)
+ cm(INC, Implied, 0x1A)
+ cm(DEC, Implied, 0x3A)
+ cm(JMP, AbsoluteIndexedX, 0x7C)
+ cm(WAI, Implied, 0xCB)
+ cm(STP, Implied, 0xDB)
+
+}
diff --git a/src/main/scala/millfork/output/CompiledMemory.scala b/src/main/scala/millfork/output/CompiledMemory.scala
new file mode 100644
index 00000000..82f91768
--- /dev/null
+++ b/src/main/scala/millfork/output/CompiledMemory.scala
@@ -0,0 +1,29 @@
+package millfork.output
+
+import scala.collection.mutable
+
+/**
+ * @author Karol Stasiak
+ */
+class CompiledMemory {
+ val banks = mutable.Map(0 -> new MemoryBank)
+}
+
+class MemoryBank {
+ def readByte(addr: Int) = output(addr) & 0xff
+
+ def readWord(addr: Int) = readByte(addr) + (readByte(addr + 1) << 8)
+
+ def readMedium(addr: Int) = readByte(addr) + (readByte(addr + 1) << 8) + (readByte(addr + 2) << 16)
+
+ def readLong(addr: Int) = readByte(addr) + (readByte(addr + 1) << 8) + (readByte(addr + 2) << 16) + (readByte(addr + 3) << 24)
+
+ def readWord(addrHi: Int, addrLo: Int) = readByte(addrLo) + (readByte(addrHi) << 8)
+
+ val output = Array.fill[Byte](1 << 16)(0)
+ val occupied = Array.fill(1 << 16)(false)
+ val readable = Array.fill(1 << 16)(false)
+ val writeable = Array.fill(1 << 16)(false)
+ var start: Int = 0
+ var end: Int = 0
+}
diff --git a/src/main/scala/millfork/output/OutputPackager.scala b/src/main/scala/millfork/output/OutputPackager.scala
new file mode 100644
index 00000000..997d7881
--- /dev/null
+++ b/src/main/scala/millfork/output/OutputPackager.scala
@@ -0,0 +1,60 @@
+package millfork.output
+
+import java.io.ByteArrayOutputStream
+
+/**
+ * @author Karol Stasiak
+ */
+trait OutputPackager {
+ def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte]
+}
+
+case class SequenceOutput(children: List[OutputPackager]) extends OutputPackager {
+ def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
+ val baos = new ByteArrayOutputStream
+ children.foreach { c =>
+ val a = c.packageOutput(mem, bank)
+ baos.write(a, 0, a.length)
+ }
+ baos.toByteArray
+ }
+}
+
+case class ConstOutput(byte: Byte) extends OutputPackager {
+ def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = Array(byte)
+}
+
+case class CurrentBankFragmentOutput(start: Int, end: Int) extends OutputPackager {
+ def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
+ val b = mem.banks(bank)
+ b.output.slice(start, end + 1)
+ }
+}
+
+case class BankFragmentOutput(alwaysBank: Int, start: Int, end: Int) extends OutputPackager {
+ def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
+ val b = mem.banks(alwaysBank)
+ b.output.slice(start, end + 1)
+ }
+}
+
+object StartAddressOutput extends OutputPackager {
+ def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
+ val b = mem.banks(bank)
+ Array(b.start.toByte, b.start.>>(8).toByte)
+ }
+}
+
+object EndAddressOutput extends OutputPackager {
+ def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
+ val b = mem.banks(bank)
+ Array(b.end.toByte, b.end.>>(8).toByte)
+ }
+}
+
+object AllocatedDataOutput extends OutputPackager {
+ def packageOutput(mem: CompiledMemory, bank: Int): Array[Byte] = {
+ val b = mem.banks(bank)
+ b.output.slice(b.start, b.end + 1)
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/millfork/output/VariableAllocator.scala b/src/main/scala/millfork/output/VariableAllocator.scala
new file mode 100644
index 00000000..7263518c
--- /dev/null
+++ b/src/main/scala/millfork/output/VariableAllocator.scala
@@ -0,0 +1,96 @@
+package millfork.output
+
+import millfork.error.ErrorReporting
+import millfork.node.{CallGraph, VariableVertex}
+import millfork.{CompilationFlag, CompilationOptions}
+
+import scala.collection.mutable
+
+/**
+ * @author Karol Stasiak
+ */
+
+sealed trait ByteAllocator {
+ def notifyAboutEndOfCode(org: Int): Unit
+
+ def allocateBytes(count: Int, options: CompilationOptions): Int
+}
+
+class UpwardByteAllocator(startAt: Int, endBefore: Int) extends ByteAllocator {
+ private var nextByte = startAt
+
+ def allocateBytes(count: Int, options: CompilationOptions): Int = {
+ if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1
+ val t = nextByte
+ nextByte += count
+ if (nextByte > endBefore) {
+ ErrorReporting.fatal("Out of high memory")
+ }
+ t
+ }
+
+ def notifyAboutEndOfCode(org: Int): Unit = ()
+}
+
+class AfterCodeByteAllocator(endBefore: Int) extends ByteAllocator {
+ var nextByte = 0x200
+
+ def allocateBytes(count: Int, options: CompilationOptions): Int = {
+ if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1
+ val t = nextByte
+ nextByte += count
+ if (nextByte > endBefore) {
+ ErrorReporting.fatal("Out of high memory")
+ }
+ t
+ }
+
+ def notifyAboutEndOfCode(org: Int): Unit = nextByte = org
+}
+
+class VariableAllocator(private var pointers: List[Int], private val bytes: ByteAllocator) {
+
+ private var pointerMap = mutable.Map[Int, Set[VariableVertex]]()
+ private var variableMap = mutable.Map[Int, mutable.Map[Int, Set[VariableVertex]]]()
+
+ var onEachByte: (Int => Unit) = _
+
+ def allocatePointer(callGraph: CallGraph, p: VariableVertex): Int = {
+ pointerMap.foreach { case (addr, alreadyThere) =>
+ if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) {
+ pointerMap(addr) += p
+ return addr
+ }
+ }
+ pointers match {
+ case Nil =>
+ ErrorReporting.fatal("Out of zero-page memory")
+ case next :: rest =>
+ pointers = rest
+ onEachByte(next)
+ onEachByte(next + 1)
+ pointerMap(next) = Set(p)
+ next
+ }
+ }
+
+ def allocateByte(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions): Int = allocateBytes(callGraph, p, options, 1)
+
+ def allocateBytes(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions, count: Int): Int = {
+ if (!variableMap.contains(count)) {
+ variableMap(count) = mutable.Map()
+ }
+ variableMap(count).foreach { case (a, alreadyThere) =>
+ if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) {
+ variableMap(count)(a) += p
+ return a
+ }
+ }
+ val addr = bytes.allocateBytes(count, options)
+ (addr to (addr + count)).foreach(onEachByte)
+ variableMap(count)(addr) = Set(p)
+ addr
+ }
+
+ def notifyAboutEndOfCode(org: Int): Unit = bytes.notifyAboutEndOfCode(org)
+}
diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala
new file mode 100644
index 00000000..667a5360
--- /dev/null
+++ b/src/main/scala/millfork/parser/MfParser.scala
@@ -0,0 +1,435 @@
+package millfork.parser
+
+import java.nio.file.{Files, Paths}
+
+import fastparse.all._
+import millfork.assembly.{AddrMode, Opcode}
+import millfork.env._
+import millfork.error.ErrorReporting
+import millfork.node._
+import millfork.{CompilationOptions, SeparatedList}
+
+/**
+ * @author Karol Stasiak
+ */
+case class MfParser(filename: String, input: String, currentDirectory: String, options: CompilationOptions) {
+
+ var lastPosition = Position(filename, 1, 1, 0)
+ var lastLabel = ""
+
+ def toAst: Parsed[Program] = program.parse(input + "\n\n\n")
+
+ private val lineStarts: Array[Int] = (0 +: input.zipWithIndex.filter(_._1 == '\n').map(_._2)).toArray
+
+ def position(label: String = ""): P[Position] = Index.map(i => indexToPosition(i, label))
+
+ def indexToPosition(i: Int, label: String): Position = {
+ val prefix = lineStarts.takeWhile(_ <= i)
+ val newPosition = Position(filename, prefix.length, i - prefix.last, i)
+ if (newPosition.cursor > lastPosition.cursor) {
+ lastPosition = newPosition
+ lastLabel = label
+ }
+ newPosition
+ }
+
+ val comment: P[Unit] = P("//" ~/ CharsWhile(c => c != '\n' && c != '\r', min = 0) ~ ("\r\n" | "\r" | "\n"))
+
+ val SWS: P[Unit] = P(CharsWhileIn(" \t", min = 1)).opaque("")
+
+ val HWS: P[Unit] = P(CharsWhileIn(" \t", min = 0)).opaque("")
+
+ val AWS: P[Unit] = P((CharIn(" \t\n\r;") | NoCut(comment)).rep(min = 0)).opaque("")
+
+ val EOL: P[Unit] = P(HWS ~ ("\r\n" | "\r" | "\n" | comment).opaque("") ~ AWS).opaque("")
+
+ val letter: P[String] = P(CharIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_").!)
+
+ val letterOrDigit: P[Unit] = P(CharIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_.$1234567890"))
+
+ val lettersOrDigits: P[String] = P(CharsWhileIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_.$1234567890", min = 0).!)
+
+ val identifier: P[String] = P((letter ~ lettersOrDigits).map { case (a, b) => a + b }).opaque("")
+
+ // def operator: P[String] = P(CharsWhileIn("!-+*/><=~|&^", min=1).!) // TODO: only valid operators
+
+ // TODO: 3-byte types
+ def size(value: Int, wordLiteral: Boolean, longLiteral: Boolean): Int =
+ if (value > 255 || value < -128 || wordLiteral)
+ if (value > 0xffff || longLiteral) 4 else 2
+ else 1
+
+ def sign(abs: Int, minus: Boolean): Int = if (minus) -abs else abs
+
+ val decimalAtom: P[LiteralExpression] =
+ for {
+ p <- position()
+ minus <- "-".!.?
+ s <- CharsWhileIn("1234567890", min = 1).!.opaque("") ~ !("x" | "b")
+ } yield {
+ val abs = Integer.parseInt(s, 10)
+ val value = sign(abs, minus.isDefined)
+ LiteralExpression(value, size(value, s.length > 3, s.length > 5)).pos(p)
+ }
+
+ val binaryAtom: P[LiteralExpression] =
+ for {
+ p <- position()
+ minus <- "-".!.?
+ _ <- P("0b" | "%") ~/ Pass
+ s <- CharsWhileIn("01", min = 1).!.opaque("")
+ } yield {
+ val abs = Integer.parseInt(s, 2)
+ val value = sign(abs, minus.isDefined)
+ LiteralExpression(value, size(value, s.length > 8, s.length > 16)).pos(p)
+ }
+
+ val hexAtom: P[LiteralExpression] =
+ for {
+ p <- position()
+ minus <- "-".!.?
+ _ <- P("0x" | "$") ~/ Pass
+ s <- CharsWhileIn("1234567890abcdefABCDEF", min = 1).!.opaque("")
+ } yield {
+ val abs = Integer.parseInt(s, 16)
+ val value = sign(abs, minus.isDefined)
+ LiteralExpression(value, size(value, s.length > 2, s.length > 4)).pos(p)
+ }
+
+ val literalAtom: P[LiteralExpression] = binaryAtom | hexAtom | decimalAtom
+
+ val atom: P[Expression] = P(literalAtom | (position() ~ identifier).map { case (p, i) => VariableExpression(i).pos(p) })
+
+ val mlOperators = List(
+ List("+=", "-=", "+'=", "-'=", "^=", "&=", "|=", "*=", "*'=", "<<=", ">>=", "<<'=", ">>'="),
+ List("||", "^^"),
+ List("&&"),
+ List("==", "<=", ">=", "!=", "<", ">"),
+ List(":"),
+ List("+'", "-'", "<<'", ">>'", ">>>>", "+", "-", "&", "|", "^", "<<", ">>"),
+ List("*'", "*"))
+
+ val nonStatementLevel = 1 // everything but not `=`
+ val mathLevel = 4 // the `:` operator
+
+ def flags(allowed: String*): P[Set[String]] = StringIn(allowed: _*).!.rep(min = 0, sep = SWS).map(_.toSet).opaque("")
+
+ def variableDefinition(implicitlyGlobal: Boolean): P[DeclarationStatement] = for {
+ p <- position()
+ flags <- flags("const", "static", "volatile", "stack") ~ HWS
+ typ <- identifier ~ SWS
+ name <- identifier ~/ HWS ~/ Pass
+ addr <- ("@" ~/ HWS ~/ mlExpression(1)).?.opaque("") ~ HWS
+ initialValue <- ("=" ~/ HWS ~/ mlExpression(1)).? ~ HWS
+ _ <- &(EOL) ~/ ""
+ } yield {
+ VariableDeclarationStatement(name, typ,
+ global = implicitlyGlobal || flags("static"),
+ stack = flags("stack"),
+ constant = flags("const"),
+ volatile = flags("volatile"),
+ initialValue, addr).pos(p)
+ }
+
+ val externFunctionBody: P[Option[List[Statement]]] = P("extern" ~/ PassWith(None))
+
+ val paramDefinition: P[ParameterDeclaration] = for {
+ p <- position()
+ typ <- identifier ~/ SWS ~/ Pass
+ name <- identifier ~/ Pass
+ } yield {
+ ParameterDeclaration(typ, ByVariable(name)).pos(p)
+ }
+
+ val appcSimple: P[ParamPassingConvention] = P("xy" | "yx" | "ax" | "ay" | "xa" | "ya" | "stack" | "a" | "x" | "y").!.map {
+ case "xy" => ByRegister(Register.XY)
+ case "yx" => ByRegister(Register.YX)
+ case "ax" => ByRegister(Register.AX)
+ case "ay" => ByRegister(Register.AY)
+ case "xa" => ByRegister(Register.XA)
+ case "ya" => ByRegister(Register.YA)
+ case "a" => ByRegister(Register.A)
+ case "x" => ByRegister(Register.X)
+ case "y" => ByRegister(Register.Y)
+ case x => ErrorReporting.fatal(s"Unknown assembly parameter passing convention: `$x`")
+ }
+
+ val appcComplex: P[ParamPassingConvention] = P((("const" | "ref").! ~/ AWS).? ~ AWS ~ identifier) map {
+ case (None, name) => ByVariable(name)
+ case (Some("const"), name) => ByConstant(name)
+ case (Some("ref"), name) => ByReference(name)
+ case x => ErrorReporting.fatal(s"Unknown assembly parameter passing convention: `$x`")
+ }
+
+ val asmParamDefinition: P[ParameterDeclaration] = for {
+ p <- position()
+ typ <- identifier ~ SWS
+ appc <- appcSimple | appcComplex
+ } yield ParameterDeclaration(typ, appc).pos(p)
+
+
+ val arrayListContents: P[List[Expression]] = ("[" ~/ AWS ~/ mlExpression(nonStatementLevel).rep(sep = AWS ~ "," ~/ AWS) ~ AWS ~ "]" ~/ Pass).map(_.toList)
+
+ val doubleQuotedString: P[List[Char]] = P("\"" ~/ CharsWhile(c => c != '\"' && c != '\n' && c != '\r').! ~ "\"").map(_.toList)
+
+ val codec: P[TextCodec] = P(position() ~ identifier).map {
+ case (_, "ascii") => TextCodec.Ascii
+ case (_, "petscii") => TextCodec.Petscii
+ case (_, "pet") => TextCodec.Petscii
+ case (p, x) =>
+ ErrorReporting.error(s"Unknown string encoding: `$x`", Some(p))
+ TextCodec.Ascii
+ }
+
+ def arrayFileContents: P[List[Expression]] = for {
+ p <- "file" ~ HWS ~/ "(" ~/ HWS ~/ position()
+ filePath <- doubleQuotedString ~/ HWS
+ optSlice <- ("," ~/ HWS ~/ literalAtom ~/ HWS ~/ "," ~/ HWS ~/ literalAtom ~/ HWS ~/ Pass).?
+ _ <- ")" ~/ Pass
+ } yield {
+ val data = Files.readAllBytes(Paths.get(currentDirectory, filePath.mkString))
+ val slice = optSlice.fold(data) {
+ case (start, length) => data.drop(start.value.toInt).take(length.value.toInt)
+ }
+ slice.map(c => LiteralExpression(c & 0xff, 1)).toList
+ }
+
+ def arrayStringContents: P[List[Expression]] = P(position() ~ doubleQuotedString ~/ HWS ~ codec).map {
+ case (p, s, co) => s.map(c => LiteralExpression(co.decode(None, c), 1).pos(p))
+ }
+
+ def arrayContents: P[List[Expression]] = arrayListContents | arrayFileContents | arrayStringContents
+
+ def arrayDefinition: P[ArrayDeclarationStatement] = for {
+ p <- position()
+ name <- "array" ~ !letterOrDigit ~/ SWS ~ identifier ~ HWS
+ length <- ("[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~ "]").? ~ HWS
+ addr <- ("@" ~/ HWS ~/ mlExpression(1)).? ~/ HWS
+ contents <- ("=" ~/ HWS ~/ arrayContents).? ~/ HWS
+ } yield ArrayDeclarationStatement(name, length, addr, contents).pos(p)
+
+ def tightMlExpression: P[Expression] = P(mlParenExpr | functionCall | mlIndexedExpression | atom) // TODO
+
+ def mlExpression(level: Int): P[Expression] = {
+ val allowedOperators = mlOperators.drop(level).flatten
+
+ def inner: P[SeparatedList[Expression, String]] = {
+ for {
+ head <- tightMlExpression ~/ HWS
+ maybeOperator <- StringIn(allowedOperators: _*).!.?
+ maybeTail <- maybeOperator.fold[P[Option[List[(String, Expression)]]]](Pass.map(_ => None))(o => (HWS ~/ inner ~/ HWS).map(x2 => Some((o -> x2.head) :: x2.tail)))
+ } yield {
+ maybeTail.fold[SeparatedList[Expression, String]](SeparatedList.of(head))(t => SeparatedList(head, t))
+ }
+ }
+
+ def p(list: SeparatedList[Expression, String], level: Int): Expression =
+ if (level == mlOperators.length) list.head
+ else {
+ val xs = list.split(mlOperators(level).toSet(_))
+ xs.separators.distinct match {
+ case Nil =>
+ if (xs.tail.nonEmpty)
+ ErrorReporting.error("Too many different operators")
+ p(xs.head, level + 1)
+ case List("+") | List("-") | List("+", "-") | List("-", "+") =>
+ SumExpression(xs.toPairList("+").map { case (op, value) => (op == "-", p(value, level + 1)) }, decimal = false)
+ case List("+'") | List("-'") | List("+'", "-'") | List("-'", "+'") =>
+ SumExpression(xs.toPairList("+").map { case (op, value) => (op == "-", p(value, level + 1)) }, decimal = true)
+ case List(":") =>
+ if (xs.size != 2) {
+ ErrorReporting.error("The `:` operator can have only two arguments", xs.head.head.position)
+ LiteralExpression(0, 1)
+ } else {
+ SeparateBytesExpression(p(xs.head, level + 1), p(xs.tail.head._2, level + 1))
+ }
+ case List(op) =>
+ FunctionCallExpression(op, xs.items.map(value => p(value, level + 1)))
+ case _ =>
+ ErrorReporting.error("Too many different operators")
+ LiteralExpression(0, 1)
+ }
+ }
+
+ inner.map(x => p(x, 0))
+ }
+
+ def mlLhsExpressionSimple: P[LhsExpression] = mlIndexedExpression | (position() ~ identifier).map { case (p, n) => VariableExpression(n).pos(p) }
+
+ def mlLhsExpression: P[LhsExpression] = {
+ val separated = position() ~ mlLhsExpressionSimple ~ HWS ~ ":" ~/ HWS ~ mlLhsExpressionSimple
+ separated.map { case (p, h, l) => SeparateBytesExpression(h, l).pos(p) } | mlLhsExpressionSimple
+ }
+
+
+ def mlParenExpr: P[Expression] = P("(" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~/ ")")
+
+ def mlIndexedExpression: P[IndexedExpression] = for {
+ p <- position()
+ array <- identifier
+ index <- HWS ~ "[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~/ "]"
+ } yield IndexedExpression(array, index).pos(p)
+
+ def functionCall: P[FunctionCallExpression] = for {
+ p <- position()
+ name <- identifier
+ params <- HWS ~ "(" ~/ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ ""
+ } yield FunctionCallExpression(name, params.toList).pos(p)
+
+ val expressionStatement: P[ExecutableStatement] = mlExpression(0).map(ExpressionStatement)
+
+ val assignmentStatement: P[ExecutableStatement] =
+ (position() ~ mlLhsExpression ~ HWS ~ "=" ~/ HWS ~ mlExpression(1)).map {
+ case (p, l, r) => Assignment(l, r).pos(p)
+ }
+
+ def keywordStatement: P[ExecutableStatement] = P(returnStatement | ifStatement | whileStatement | forStatement | doWhileStatement | inlineAssembly | assignmentStatement)
+
+ def executableStatement: P[ExecutableStatement] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.pos(p) }
+
+ // TODO: label and instruction in one line
+ def asmLabel: P[ExecutableStatement] = (identifier ~ HWS ~ ":" ~/ HWS).map(l => AssemblyStatement(Opcode.LABEL, AddrMode.DoesNotExist, VariableExpression(l), elidable = true))
+
+ // def zeropageAddrModeHint: P[Option[Boolean]] = Pass
+
+ def asmOpcode: P[Opcode.Value] = (position() ~ letter.rep(exactly = 3).!).map { case (p, o) => Opcode.lookup(o, Some(p)) }
+
+ def asmExpression: P[Expression] = (position() ~ NoCut(
+ ("<" ~/ HWS ~ mlExpression(mathLevel)).map(e => HalfWordExpression(e, hiByte = false)) |
+ (">" ~/ HWS ~ mlExpression(mathLevel)).map(e => HalfWordExpression(e, hiByte = true)) |
+ mlExpression(mathLevel)
+ )).map { case (p, e) => e.pos(p) }
+
+ val commaX = HWS ~ "," ~ HWS ~ ("X" | "x") ~ HWS
+ val commaY = HWS ~ "," ~ HWS ~ ("Y" | "y") ~ HWS
+
+ def asmParameter: P[(AddrMode.Value, Expression)] = {
+ (SWS ~ (
+ ("#" ~ asmExpression).map(AddrMode.Immediate -> _) |
+ ("(" ~ HWS ~ asmExpression ~ HWS ~ ")" ~ commaY).map(AddrMode.IndexedY -> _) |
+ ("(" ~ HWS ~ asmExpression ~ commaX ~ ")").map(AddrMode.IndexedX -> _) |
+ ("(" ~ HWS ~ asmExpression ~ HWS ~ ")").map(AddrMode.Indirect -> _) |
+ (asmExpression ~ commaX).map(AddrMode.AbsoluteX -> _) |
+ (asmExpression ~ commaY).map(AddrMode.AbsoluteY -> _) |
+ asmExpression.map(AddrMode.Absolute -> _)
+ )).?.map(_.getOrElse(AddrMode.Implied -> LiteralExpression(0, 1)))
+ }
+
+ def elidable: P[Boolean] = ("?".! ~/ HWS).?.map(_.isDefined)
+
+ def asmInstruction: P[ExecutableStatement] = {
+ val lineParser: P[(Boolean, Opcode.Value, (AddrMode.Value, Expression))] = !"}" ~ elidable ~/ asmOpcode ~/ asmParameter
+ lineParser.map { case (elid, op, param) =>
+ AssemblyStatement(op, param._1, param._2, elid)
+ }
+ }
+
+ def asmStatement: P[ExecutableStatement] = (position("assembly statement") ~ P(asmLabel | asmInstruction)).map { case (p, s) => s.pos(p) } // TODO: macros
+
+ def statement: P[Statement] = (position() ~ P(keywordStatement | variableDefinition(false) | expressionStatement)).map { case (p, s) => s.pos(p) }
+
+ def asmStatements: P[List[ExecutableStatement]] = ("{" ~/ AWS ~/ asmStatement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList)
+
+ def statements: P[List[Statement]] = ("{" ~/ AWS ~ statement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList)
+
+ def executableStatements: P[Seq[ExecutableStatement]] = "{" ~/ AWS ~/ executableStatement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~ "}"
+
+ def returnStatement: P[ExecutableStatement] = ("return" ~ !letterOrDigit ~/ HWS ~ mlExpression(nonStatementLevel).?).map(ReturnStatement)
+
+ def ifStatement: P[ExecutableStatement] = for {
+ condition <- "if" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)
+ thenBranch <- AWS ~/ executableStatements
+ elseBranch <- (AWS ~ "else" ~/ AWS ~/ executableStatements).?
+ } yield IfStatement(condition, thenBranch.toList, elseBranch.getOrElse(Nil).toList)
+
+ def whileStatement: P[ExecutableStatement] = for {
+ condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)
+ body <- AWS ~ executableStatements
+ } yield WhileStatement(condition, body.toList)
+
+ def forDirection: P[ForDirection.Value] =
+ ("parallel" ~ HWS ~ "to").!.map(_ => ForDirection.ParallelTo) |
+ ("parallel" ~ HWS ~ "until").!.map(_ => ForDirection.ParallelUntil) |
+ "until".!.map(_ => ForDirection.Until) |
+ "to".!.map(_ => ForDirection.To) |
+ ("down" ~/ HWS ~/ "to").!.map(_ => ForDirection.DownTo)
+
+ def forStatement: P[ExecutableStatement] = for {
+ identifier <- "for" ~ SWS ~/ identifier ~/ "," ~/ Pass
+ start <- mlExpression(nonStatementLevel) ~ HWS ~ "," ~/ HWS ~/ Pass
+ direction <- forDirection ~/ HWS ~/ "," ~/ HWS ~/ Pass
+ end <- mlExpression(nonStatementLevel)
+ body <- AWS ~ executableStatements
+ } yield ForStatement(identifier, start, end, direction, body.toList)
+
+ def inlineAssembly: P[ExecutableStatement] = for {
+ condition <- "asm" ~ !letterOrDigit ~/ Pass
+ body <- AWS ~ asmStatements
+ } yield BlockStatement(body)
+
+ //noinspection MutatorLikeMethodIsParameterless
+ def doWhileStatement: P[ExecutableStatement] = for {
+ body <- "do" ~ !letterOrDigit ~/ AWS ~ executableStatements ~/ AWS
+ condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)
+ } yield DoWhileStatement(body.toList, condition)
+
+ def functionDefinition: P[DeclarationStatement] = for {
+ p <- position()
+ flags <- flags("asm", "inline", "interrupt", "reentrant") ~ HWS
+ returnType <- identifier ~ SWS
+ name <- identifier ~ HWS
+ params <- "(" ~/ AWS ~/ (if (flags("asm")) asmParamDefinition else paramDefinition).rep(sep = AWS ~ "," ~/ AWS) ~ AWS ~ ")" ~/ AWS
+ addr <- ("@" ~/ HWS ~/ mlExpression(1)).?.opaque("") ~/ AWS
+ statements <- (externFunctionBody | (if (flags("asm")) asmStatements else statements).map(l => Some(l))) ~/ Pass
+ } yield {
+ if (flags("interrupt") && flags("inline")) ErrorReporting.error(s"Interrupt function `$name` cannot be inline", Some(p))
+ if (flags("interrupt") && flags("reentrant")) ErrorReporting.error("Interrupt function `$name` cannot be reentrant", Some(p))
+ if (flags("inline") && flags("reentrant")) ErrorReporting.error("Reentrant and inline exclude each other", Some(p))
+ if (flags("interrupt") && returnType != "void") ErrorReporting.error("Interrupt function `$name` has to return void", Some(p))
+ if (addr.isEmpty && statements.isEmpty) ErrorReporting.error("Extern function `$name` must have an address", Some(p))
+ if (statements.isEmpty && !flags("asm") && params.nonEmpty) ErrorReporting.error("Extern non-asm function `$name` cannot have parameters", Some(p))
+ if (flags("asm")) statements match {
+ case Some(Nil) => ErrorReporting.warn("Assembly function `$name` is empty, did you mean RTS or RTI", options, Some(p))
+ case Some(xs) =>
+ if (flags("interrupt")) {
+ if (xs.exists {
+ case AssemblyStatement(Opcode.RTS, _, _, _) => true
+ case _ => false
+ }) ErrorReporting.warn("Assembly interrupt function `$name` contains RTS, did you mean RTI?", options, Some(p))
+ } else {
+ if (xs.exists {
+ case AssemblyStatement(Opcode.RTI, _, _, _) => true
+ case _ => false
+ }) ErrorReporting.warn("Assembly non-interrupt function `$name` contains RTI, did you mean RTS?", options, Some(p))
+ }
+ if (!flags("inline")) {
+ xs.last match {
+ case AssemblyStatement(Opcode.RTS, _, _, _) => () // OK
+ case AssemblyStatement(Opcode.RTI, _, _, _) => () // OK
+ case AssemblyStatement(Opcode.JMP, _, _, _) => () // OK
+ case _ =>
+ val validReturn = if (flags("interrupt")) "RTI" else "RTS"
+ ErrorReporting.warn(s"Non-inline assembly function `$name` should end in " + validReturn, options, Some(p))
+ }
+ }
+ case None => ()
+ }
+ FunctionDeclarationStatement(name, returnType, params.toList,
+ addr,
+ statements,
+ flags("inline"),
+ flags("asm"),
+ flags("interrupt"),
+ flags("reentrant")).pos(p)
+ }
+
+ def importStatement: Parser[ImportStatement] = ("import" ~ !letterOrDigit ~/ SWS ~/ identifier).map(ImportStatement)
+
+ def program: Parser[Program] = for {
+ _ <- Start ~/ AWS ~/ Pass
+ definitions <- (importStatement | arrayDefinition | functionDefinition | variableDefinition(true)).rep(sep = EOL)
+ _ <- AWS ~ End
+ } yield Program(definitions.toList)
+
+
+}
diff --git a/src/main/scala/millfork/parser/MinimalTestCase.scala b/src/main/scala/millfork/parser/MinimalTestCase.scala
new file mode 100644
index 00000000..2f5852d6
--- /dev/null
+++ b/src/main/scala/millfork/parser/MinimalTestCase.scala
@@ -0,0 +1,24 @@
+package millfork.parser
+
+import fastparse.all._
+import fastparse.core
+
+object MinimalTestCase {
+ def AWS: P[Unit] = "\n".rep(min = 0).opaque("").log()
+
+ def EOL: P[Unit] = "\n".rep(min = 1).opaque("").log()
+
+ def identifier: P[String] = CharPred(_.isLetter).rep(min = 1).!.opaque("").log()
+
+ def identifierWithSpace: P[String] = (identifier ~/ AWS ~/ Pass).opaque("").log()
+
+ def separator: P[Unit] = ("," ~/ AWS ~/ Pass).opaque("").log()
+
+ def identifiers: P[Seq[String]] = identifierWithSpace.rep(min = 0, sep = separator)//.opaque("").log()
+
+ def array: P[Seq[String]] = ("[" ~/ AWS ~/ identifiers ~/ "]" ~/ Pass)//.opaque("").log()
+
+ def arrays: Parser[Seq[Seq[String]]] = (array ~/ EOL).rep(min = 0, sep = !End ~/ Pass)//.opaque("").log()
+
+ def program: Parser[Seq[Seq[String]]] = Start ~/ AWS ~/ arrays ~/ End
+}
diff --git a/src/main/scala/millfork/parser/ParserBase.scala b/src/main/scala/millfork/parser/ParserBase.scala
new file mode 100644
index 00000000..4836d700
--- /dev/null
+++ b/src/main/scala/millfork/parser/ParserBase.scala
@@ -0,0 +1,169 @@
+package millfork.parser
+
+import millfork.node.Position
+
+/**
+ * @author Karol Stasiak
+ */
+case class ParseException(msg: String, position: Option[Position]) extends Exception
+
+class ParserBase(filename: String, input: String) {
+
+ def reset(): Unit = {
+ cursor = 0
+ line = 1
+ column = FirstColumn
+ }
+
+ private val FirstColumn = 0
+ private val length = input.length
+ private var cursor = 0
+ private var line = 1
+ private var column = FirstColumn
+
+ def position = Position(filename, line, column, cursor)
+
+ def restorePosition(p: Position): Unit = {
+ cursor = p.cursor
+ column = p.column
+ line = p.line
+ }
+
+ def error(msg: String, pos: Option[Position]): Nothing = throw ParseException(msg, pos)
+
+ def error(msg: String, pos: Position): Nothing = throw ParseException(msg, Some(pos))
+
+ def error(msg: String): Nothing = throw ParseException(msg, Some(position))
+
+ def error() = throw ParseException("Syntax error", Some(position))
+
+ def nextChar() = {
+ if (cursor >= length) error("Unexpected end of input")
+ val c = input(cursor)
+ cursor += 1
+ if (c == '\n') {
+ line += 1
+ column = FirstColumn
+ } else {
+ column += 1
+ }
+ c
+ }
+
+ def peekChar(): Char = {
+ if (cursor >= length) '\0' else input(cursor)
+ }
+
+ def require(char: Char): Char = {
+ val pos = position
+ val c = nextChar()
+ if (c != char) error(s"Expected `$char`", pos)
+ c
+ }
+ def require(p: Char=>Boolean, errorMsg: String = "Unexpected character"): Char = {
+ val pos = position
+ val c = nextChar()
+ if (!p(c)) error(errorMsg, pos)
+ c
+ }
+
+ def require(s: String): String = {
+ val c = peekChars(s.length)
+ if (c != s) error(s"Expected `$s`")
+ 1 to s.length foreach (_=>nextChar())
+ s
+ }
+
+ def requireAny(s: String, errorMsg: String = "Unexpected character"): Char = {
+ val c = nextChar()
+ if (s.contains(c)) c
+ else error(errorMsg)
+ }
+
+ def peek2Chars(): String = {
+ peekChars(2)
+ }
+
+ def peekChars(n: Int): String = {
+ if (cursor > length - n) input.substring(cursor) else input.substring(cursor, cursor + n)
+ }
+
+ def charsWhile(pred: Char => Boolean, min: Int = 0, errorMsg: String = "Unexpected character"): String = {
+ val sb = new StringBuilder()
+ while (pred(peekChar())) {
+ sb += nextChar()
+ }
+ val s = sb.toString
+ if (s.length < min) error(errorMsg)
+ else s
+ }
+
+ def skipNextIfMatches(c: Char): Boolean = {
+ if (peekChar() == c) {
+ nextChar()
+ true
+ } else {
+ false
+ }
+ }
+
+ def either(c: Char, s: String): Unit = {
+ if (peekChar() == c) {
+ nextChar()
+ } else if (peekChars(s.length) == s) {
+ require(s)
+ } else {
+ error(s"Expected either `$c` or `$s`")
+ }
+ }
+
+ def sepOrEnd(sep: Char, end: Char): Boolean = {
+ val p = position
+ val c = nextChar()
+ if (c == sep) true
+ else if (c==end) false
+ else error(s"Expected `$sep` or `$end`", p)
+ }
+
+ def anyOf[T](errorMsg: String, alternatives: (()=> T)*): T = {
+ alternatives.foreach { t =>
+ val p = position
+ try {
+ return t()
+ } catch {
+ case _: ParseException => restorePosition(p)
+ }
+ }
+ error(errorMsg)
+ }
+
+ def surrounded[T](left: => Any, content: => T, right: => Any): T = {
+ left
+ val result = content
+ right
+ content
+ }
+
+ def followed[T](content: => T, right: => Any): T = {
+ val result = content
+ right
+ content
+ }
+
+ def attempt[T](content: => T): Option[T] = {
+ val p = position
+ try {
+ Some(content)
+ } catch {
+ case _: ParseException => None
+ }
+ }
+
+ def opaque[T](errorMsg: String)(block: =>T) :T={
+ try {
+ block
+ } catch{
+ case p:ParseException => error(errorMsg, p.position)
+ }
+ }
+}
diff --git a/src/main/scala/millfork/parser/SourceLoadingQueue.scala b/src/main/scala/millfork/parser/SourceLoadingQueue.scala
new file mode 100644
index 00000000..cba2df00
--- /dev/null
+++ b/src/main/scala/millfork/parser/SourceLoadingQueue.scala
@@ -0,0 +1,89 @@
+package millfork.parser
+
+import java.nio.file.{Files, Paths}
+
+import fastparse.core.Parsed.{Failure, Success}
+import millfork.CompilationOptions
+import millfork.error.ErrorReporting
+import millfork.node.{ImportStatement, Position, Program}
+
+import scala.collection.mutable
+
+/**
+ * @author Karol Stasiak
+ */
+class SourceLoadingQueue(val initialFilenames: List[String], val includePath: List[String], val options: CompilationOptions) {
+
+ private val parsedModules = mutable.Map[String, Program]()
+ private val moduleQueue = mutable.Queue[() => Unit]()
+ val extension: String = ".ml"
+
+
+ def run(): Program = {
+ initialFilenames.foreach { i =>
+ parseModule(extractName(i), includePath, Right(i), options)
+ }
+ options.platform.startingModules.foreach {m =>
+ moduleQueue.enqueue(() => parseModule(m, includePath, Left(None), options))
+ }
+ while (moduleQueue.nonEmpty) {
+ moduleQueue.dequeueAll(_ => true).par.foreach(_())
+ }
+ ErrorReporting.assertNoErrors("Parse failed")
+ parsedModules.values.reduce(_ + _)
+ }
+
+ def lookupModuleFile(includePath: List[String], moduleName: String, position: Option[Position]): String = {
+ includePath.foreach { dir =>
+ val file = Paths.get(dir, moduleName + extension).toFile
+ ErrorReporting.debug("Checking " + file)
+ if (file.exists()) {
+ return file.getAbsolutePath
+ }
+ }
+ ErrorReporting.fatal(s"Module `$moduleName` not found", position)
+ }
+
+ def parseModule(moduleName: String, includePath: List[String], why: Either[Option[Position], String], options: CompilationOptions): Unit = {
+ val filename: String = why.fold(p => lookupModuleFile(includePath, moduleName, p), s => s)
+ ErrorReporting.debug(s"Parsing $filename")
+ val path = Paths.get(filename)
+ val parentDir = path.toFile.getAbsoluteFile.getParent
+ val src = new String(Files.readAllBytes(path))
+ val parser = MfParser(filename, src, parentDir, options)
+ parser.toAst match {
+ case Success(prog, _) =>
+ parsedModules.synchronized {
+ parsedModules.put(moduleName, prog)
+ prog.declarations.foreach {
+ case s@ImportStatement(m) =>
+ if (!parsedModules.contains(m)) {
+ moduleQueue.enqueue(() => parseModule(m, parentDir :: includePath, Left(s.position), options))
+ }
+ case _ => ()
+ }
+ }
+ case f@Failure(a, b, d) =>
+ ErrorReporting.error(s"Failed to parse the module `$moduleName` in $filename", Some(parser.indexToPosition(f.index, parser.lastLabel)))
+// ErrorReporting.error(a.toString)
+// ErrorReporting.error(b.toString)
+// ErrorReporting.error(d.toString)
+// ErrorReporting.error(d.traced.expected)
+// ErrorReporting.error(d.traced.stack.toString)
+// ErrorReporting.error(d.traced.traceParsers.toString)
+// ErrorReporting.error(d.traced.fullStack.toString)
+// ErrorReporting.error(f.toString)
+ if (parser.lastLabel != "") {
+ ErrorReporting.error(s"Syntax error: ${parser.lastLabel} expected", Some(parser.lastPosition))
+ } else {
+ ErrorReporting.error("Syntax error", Some(parser.lastPosition))
+ }
+ }
+ }
+
+ def extractName(i: String): String = {
+ val noExt = i.stripSuffix(extension)
+ val lastSlash = noExt.lastIndexOf('/') max noExt.lastIndexOf('\\')
+ if (lastSlash >= 0) i.substring(lastSlash + 1) else i
+ }
+}
diff --git a/src/main/scala/millfork/parser/TextCodec.scala b/src/main/scala/millfork/parser/TextCodec.scala
new file mode 100644
index 00000000..cad9580b
--- /dev/null
+++ b/src/main/scala/millfork/parser/TextCodec.scala
@@ -0,0 +1,32 @@
+package millfork.parser
+import millfork.error.ErrorReporting
+import millfork.node.Position
+
+/**
+ * @author Karol Stasiak
+ */
+class TextCodec(val name:String, private val map: String, private val extra: Map[Char,Int]) {
+ def decode(position: Option[Position], c: Char): Int = {
+ if (extra.contains(c)) extra(c) else {
+ val index = map.indexOf(c)
+ if (index >= 0) {
+ index
+ } else {
+ ErrorReporting.fatal("Invalid character in string in ")
+ }
+ }
+ }
+}
+
+object TextCodec {
+ val NotAChar = '\ufffd'
+
+ val Ascii = new TextCodec("ASCII", 0.until(127).map{i => if (i<32) NotAChar else i.toChar}.mkString, Map.empty)
+
+ val Petscii = new TextCodec("PETSCII",
+ "\ufffd" * 32 + 0x20.to(0x3f).map(_.toChar).mkString + "@abcdefghijklmnopqrstuvwxyz[£]↑←–ABCDEFGHIJKLMNOPQRSTUVWXYZ",
+ Map('^' -> 0x5E, 'π' -> 0x7E)
+ )
+
+
+}
diff --git a/src/test/java/com/grapeshot/halfnes/CPURAM.java b/src/test/java/com/grapeshot/halfnes/CPURAM.java
new file mode 100644
index 00000000..1227b5a4
--- /dev/null
+++ b/src/test/java/com/grapeshot/halfnes/CPURAM.java
@@ -0,0 +1,75 @@
+package com.grapeshot.halfnes;
+
+import com.grapeshot.halfnes.mappers.Mapper;
+import millfork.output.MemoryBank;
+
+/**
+ * Since the original CPURAM class was a convoluted mess of dependencies,
+ * I overrode it with mine that has only few pieces of junk glue to make it work
+ * @author Karol Stasiak
+ */
+@SuppressWarnings("unused")
+public class CPURAM {
+
+ private final MemoryBank mem;
+
+ // required by the CPU class for some reason
+ public Mapper mapper = new Mapper() {
+ @Override
+ public TVType getTVType() {
+ // the base class returns null, but this can't be null
+ return TVType.DENDY;
+ }
+ };
+ // required by the CPU class for some reason
+ public APU apu = new APU(null, null, this);
+
+ public CPURAM(MemoryBank mem) {
+ boolean[] readable = mem.readable();
+ boolean[] writeable = mem.writeable();
+ for (int i = 0xfffe; i >= 0; i--) {
+ if (readable[i]) {
+ // allow for dummy fetches by implied instructions
+ readable[i + 1] = true;
+ }
+ }
+ readable[0] = true;
+ readable[1] = true;
+ readable[2] = true;
+ for (int i = 0x100; i <= 0x1ff; i++) {
+ readable[i] = true;
+ writeable[i] = true;
+ }
+ for (int i = 0x4000; i <= 0x407f; i++) {
+ readable[i] = true;
+ writeable[i] = true;
+ }
+ for (int i = 0xc000; i <= 0xcfff; i++) {
+ readable[i] = true;
+ writeable[i] = true;
+ }
+ for (int i = 0xfffa; i <= 0xffff; i++) {
+ readable[i] = true;
+ writeable[i] = true;
+ }
+ this.mem = mem;
+ }
+
+
+ public final int read(int addr) {
+ addr &= 0xffff;
+ if (!mem.readable()[addr]) {
+ throw new RuntimeException("Can't read from $" + Integer.toHexString(addr));
+ }
+ return mem.output()[addr] & 0xff;
+ }
+
+ public final void write(int addr, int data) {
+ addr &= 0xffff;
+ if (!mem.writeable()[addr]) {
+ throw new RuntimeException("Can't write to $" + Integer.toHexString(addr));
+ }
+ mem.output()[addr] = (byte) data;
+ }
+
+}
diff --git a/src/test/scala/millfork/test/ArraySuite.scala b/src/test/scala/millfork/test/ArraySuite.scala
new file mode 100644
index 00000000..68bd25f4
--- /dev/null
+++ b/src/test/scala/millfork/test/ArraySuite.scala
@@ -0,0 +1,133 @@
+package millfork.test
+
+import millfork.{Cpu, OptimizationPresets}
+import millfork.assembly.opt.{AlwaysGoodOptimizations, DangerousOptimizations}
+import millfork.test.emu._
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class ArraySuite extends FunSuite with Matchers {
+
+ test("Array assignment") {
+ val m = EmuSuperOptimizedRun(
+ """
+ | array output [3] @$c000
+ | array input = [5,6,7]
+ | void main () {
+ | copyEntry(0)
+ | copyEntry(1)
+ | copyEntry(2)
+ | }
+ | void copyEntry(byte index) {
+ | output[index] = input[index]
+ | }
+ """.stripMargin)
+ m.readByte(0xc000) should equal(5)
+ m.readByte(0xc001) should equal(6)
+ m.readByte(0xc002) should equal(7)
+
+ }
+ test("Array assignment with offset") {
+ EmuUltraBenchmarkRun(
+ """
+ | array output [8] @$c000
+ | void main () {
+ | byte i
+ | i = 0
+ | while i != 6 {
+ | output[i + 2] = i + 1
+ | output[i] = output[i]
+ | i += 1
+ | }
+ | }
+ """.stripMargin) { m =>
+ m.readByte(0xc002) should equal(1)
+ m.readByte(0xc007) should equal(6)
+ }
+ }
+
+ test("Array assignment with offset 1") {
+ val m = new EmuRun(Cpu.StrictMos, Nil, DangerousOptimizations.All ++ OptimizationPresets.Good, true)(
+ """
+ | array output [8] @$c000
+ | void main () {
+ | byte i
+ | i = 0
+ | while i != 6 {
+ | output[i + 2] = i + 1
+ | output[i] = output[i]
+ | i += 1
+ | }
+ | }
+ """.stripMargin)
+ m.readByte(0xc002) should equal(1)
+ m.readByte(0xc007) should equal(6)
+ }
+
+ test("Array assignment through a pointer") {
+ val m = EmuUnoptimizedRun(
+ """
+ | array output [3] @$c000
+ | pointer p
+ | void main () {
+ | p = output.addr
+ | byte i
+ | byte ignored
+ | i = 1
+ | word w
+ | w = $105
+ | p[i]:ignored = w
+ | }
+ """.stripMargin)
+ m.readByte(0xc001) should equal(1)
+
+ }
+
+ test("Array in place math") {
+ EmuBenchmarkRun(
+ """
+ | array output [4] @$c000
+ | void main () {
+ | byte i
+ | i = 3
+ | output[i] = 3
+ | output[i + 1 - 1] *= 4
+ | output[3] *= 5
+ | }
+ """.stripMargin)(_.readByte(0xc003) should equal(60))
+ }
+
+ test("Array simple read") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | array a[7]
+ | void main () {
+ | byte i
+ | i = 6
+ | a[i] = 6
+ | output = a[i]
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(6))
+ }
+
+ test("Array simple read 2") {
+ EmuBenchmarkRun(
+ """
+ | word output @$c000
+ | array a[7]
+ | void main () {
+ | output = 777
+ | byte i
+ | i = 6
+ | a[i] = 6
+ | output = a[i]
+ | }
+ """.stripMargin){m =>
+ m.readByte(0xc000) should equal(6)
+ m.readByte(0xc001) should equal(0)
+ }
+ }
+}
diff --git a/src/test/scala/millfork/test/AssemblyOptimizationSuite.scala b/src/test/scala/millfork/test/AssemblyOptimizationSuite.scala
new file mode 100644
index 00000000..783fb8e0
--- /dev/null
+++ b/src/test/scala/millfork/test/AssemblyOptimizationSuite.scala
@@ -0,0 +1,281 @@
+package millfork.test
+
+import millfork.{Cpu, OptimizationPresets}
+import millfork.assembly.opt.{AlwaysGoodOptimizations, LaterOptimizations, VariableToRegisterOptimization}
+import millfork.test.emu.{EmuBenchmarkRun, EmuUltraBenchmarkRun, EmuRun}
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class AssemblyOptimizationSuite extends FunSuite with Matchers {
+
+ test("Duplicate RTS") {
+ EmuBenchmarkRun(
+ """
+ | void main () {
+ | if 1 == 1 {
+ | return
+ | }
+ | }
+ """.stripMargin) { _ => }
+ }
+
+ test("Inlining variable") {
+ EmuBenchmarkRun(
+ """
+ | array output [5] @$C000
+ | void main () {
+ | byte i
+ | i = 0
+ | while (i<5) {
+ | output[i] = i
+ | i += 1
+ | }
+ | }
+ """.stripMargin)(_.readByte(0xc003) should equal(3))
+ }
+
+ test("Loading modified variables") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$C000
+ | void main () {
+ | byte x
+ | output = 5
+ | output += 1
+ | output += 1
+ | output += 1
+ | x = output
+ | output = x
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(8))
+ }
+
+ test("Bit ops") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$C000
+ | void main () {
+ | output ^= output
+ | output |= 5 | 6
+ | output |= 5 | 6
+ | output &= 5 & 6
+ | output ^= 8 ^ 16
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(28))
+ }
+
+ test("Inlining after a while") {
+ EmuBenchmarkRun(
+ """
+ | array output [2]@$C000
+ | void main () {
+ | byte i
+ | output[0] = 6
+ | lol()
+ | i = 1
+ | if (i > 0) {
+ | output[i] = 4
+ | }
+ | }
+ | void lol() {}
+ """.stripMargin)(_.readWord(0xc000) should equal(0x406))
+ }
+
+ test("Tail call") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$C000
+ | void main () {
+ | if (output != 55) {
+ | output += 1
+ | main()
+ | }
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(55))
+ }
+
+ test("LDA-TAY elimination") {
+ new EmuRun(Cpu.StrictMos, OptimizationPresets.NodeOpt, List(VariableToRegisterOptimization, AlwaysGoodOptimizations.YYY), false)(
+ """
+ | array mouse_pointer[64]
+ | array arrow[64]
+ | byte output @$C000
+ | void main () {
+ | byte i
+ | i = 0
+ | while i < 63 {
+ | mouse_pointer[i] = arrow[i]
+ | i += 1
+ | }
+ | }
+ """.stripMargin)
+ }
+
+ test("Carry flag after AND-LSR") {
+ EmuUltraBenchmarkRun(
+ """
+ | byte output @$C000
+ | void main () {
+ | output = f(5)
+ | }
+ | byte f(byte x) {
+ | return ((x & $1E) >> 1) + 3
+ | }
+ |
+ """.stripMargin)(_.readByte(0xc000) should equal(5))
+ }
+
+ test("Index sequence") {
+ EmuUltraBenchmarkRun(
+ """
+ | array output[6] @$C000
+ | void main () {
+ | pointer o
+ | o = output.addr
+ | o[3] = 8
+ | o[4] = 8
+ | o[5] = 8
+ | }
+ |
+ """.stripMargin){m =>
+ m.readByte(0xc005) should equal(8)
+ }
+ }
+
+ test("Index switching") {
+ EmuUltraBenchmarkRun(
+ """
+ | array output1[6] @$C000
+ | array output2[6] @$C010
+ | array input[6] @$C010
+ | void main () {
+ | static byte a
+ | static byte b
+ | input[5] = 3
+ | a = five()
+ | b = five()
+ | output1[a] = input[b]
+ | output2[a] = input[b]
+ | }
+ | byte five() {
+ | return 5
+ | }
+ |
+ """.stripMargin){m =>
+ m.readByte(0xc005) should equal(3)
+ m.readByte(0xc015) should equal(3)
+ }
+ }
+
+ test("TAX-BCC-RTS-TXA optimization") {
+ new EmuRun(Cpu.StrictMos,
+ OptimizationPresets.NodeOpt, List(
+ AlwaysGoodOptimizations.PointlessStoreAfterLoad,
+ LaterOptimizations.PointlessLoadAfterStore,
+ VariableToRegisterOptimization,
+ LaterOptimizations.DoubleLoadToDifferentRegisters,
+ LaterOptimizations.DoubleLoadToTheSameRegister,
+ AlwaysGoodOptimizations.PointlessRegisterTransfers,
+ AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn,
+ AlwaysGoodOptimizations.PointlessLoadBeforeReturn,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+ AlwaysGoodOptimizations.PointlessRegisterTransfers,
+ AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn,
+ AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn,
+ AlwaysGoodOptimizations.PointlessLoadBeforeReturn,
+ AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
+ AlwaysGoodOptimizations.PointlessStashingToIndexOverShortSafeBranch,
+ AlwaysGoodOptimizations.PointlessRegisterTransfersBeforeReturn,
+ AlwaysGoodOptimizations.IdempotentDuplicateRemoval,
+ AlwaysGoodOptimizations.IdempotentDuplicateRemoval,
+ AlwaysGoodOptimizations.IdempotentDuplicateRemoval), false)(
+ """
+ | byte output @$C000
+ | void main(){ delta() }
+ | byte delta () {
+ | output = 0
+ | byte mouse_delta
+ | mouse_delta = 6
+ | mouse_delta &= $3f
+ | if mouse_delta >= $20 {
+ | mouse_delta |= $c0
+ | return 3
+ | }
+ | return mouse_delta
+ | }
+ """.stripMargin).readByte(0xc000) should equal(0)
+ }
+
+ test("Memory access detection"){
+ EmuUltraBenchmarkRun(
+ """
+ | array h [4] @$C000
+ | array l [4] @$C404
+ | word output @$C00C
+ | word a @$C200
+ | void main () {
+ | byte i
+ | a = 0x102
+ | barrier()
+ | for i,0,until,4 {
+ | h[i]:l[i] = a
+ | }
+ | a.lo:a.hi=a
+ | output = a
+ | }
+ | void barrier (){}
+ |
+ """.stripMargin){m =>
+ m.readByte(0xc000) should equal(1)
+ m.readByte(0xc001) should equal(1)
+ m.readByte(0xc002) should equal(1)
+ m.readByte(0xc003) should equal(1)
+ m.readByte(0xc404) should equal(2)
+ m.readByte(0xc405) should equal(2)
+ m.readByte(0xc406) should equal(2)
+ m.readByte(0xc407) should equal(2)
+ m.readWord(0xc00c) should equal(0x201)
+ }
+ }
+
+ test("Memory access detection 2"){
+ EmuUltraBenchmarkRun(
+ """
+ | array h [4]
+ | array l [4]
+ | word output @$C00C
+ | word ptrh @$C000
+ | word ptrl @$C002
+ | void main () {
+ | ptrh = h.addr
+ | ptrl = l.addr
+ | byte i
+ | word a
+ | a = 0x102
+ | barrier()
+ | for i,0,until,4 {
+ | h[i]:l[i] = a
+ | }
+ | a.lo:a.hi=a
+ | output = a
+ | couput
+ | }
+ | void barrier (){}
+ |
+ """.stripMargin){m =>
+ val ptrh = 0xffff & m.readWord(0xC000)
+ val ptrl = 0xffff & m.readWord(0xC002)
+ m.readByte(ptrh + 0) should equal(1)
+ m.readByte(ptrh + 1) should equal(1)
+ m.readByte(ptrh + 2) should equal(1)
+ m.readByte(ptrh + 3) should equal(1)
+ m.readByte(ptrl + 0) should equal(2)
+ m.readByte(ptrl + 1) should equal(2)
+ m.readByte(ptrl + 2) should equal(2)
+ m.readByte(ptrl + 3) should equal(2)
+ m.readWord(0xc00c) should equal(0x201)
+ }
+ }
+}
diff --git a/src/test/scala/millfork/test/AssemblySuite.scala b/src/test/scala/millfork/test/AssemblySuite.scala
new file mode 100644
index 00000000..1c7acd5c
--- /dev/null
+++ b/src/test/scala/millfork/test/AssemblySuite.scala
@@ -0,0 +1,99 @@
+package millfork.test
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class AssemblySuite extends FunSuite with Matchers {
+
+ test("Inline assembly") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = 0
+ | asm {
+ | inc $c000
+ | }
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(1))
+ }
+
+ test("Assembly functions") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = 0
+ | thing()
+ | }
+ | asm void thing() {
+ | inc $c000
+ | rts
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(1))
+ }
+
+ test("Empty assembly") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = 1
+ | asm {}
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(1))
+ }
+
+ test("Passing params to assembly") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = f(5)
+ | }
+ | asm byte f(byte a) {
+ | clc
+ | adc #5
+ | rts
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(10))
+ }
+
+ test("Inline asm functions") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = 0
+ | f()
+ | f()
+ | }
+ | inline asm void f() {
+ | inc $c000
+ | rts
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(1))
+ }
+
+ test("Inline asm functions 2") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = 0
+ | add(output, 5)
+ | add(output, 5)
+ | }
+ | inline asm void add(byte ref v, byte const c) {
+ | lda v
+ | clc
+ | adc #c
+ | sta v
+ | rts
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(5))
+ }
+
+}
diff --git a/src/test/scala/millfork/test/BasicSymonTest.scala b/src/test/scala/millfork/test/BasicSymonTest.scala
new file mode 100644
index 00000000..ddf0f6a3
--- /dev/null
+++ b/src/test/scala/millfork/test/BasicSymonTest.scala
@@ -0,0 +1,28 @@
+package millfork.test
+
+import millfork.test.emu.EmuUnoptimizedRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class BasicSymonTest extends FunSuite with Matchers {
+ test("Empty test") {
+ EmuUnoptimizedRun(
+ """
+ | void main () {
+ |
+ | }
+ """.stripMargin)
+ }
+
+ test("Byte assignment") {
+ EmuUnoptimizedRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = (1)
+ | }
+ """.stripMargin).readByte(0xc000) should equal(1)
+ }
+}
diff --git a/src/test/scala/millfork/test/BitOpSuite.scala b/src/test/scala/millfork/test/BitOpSuite.scala
new file mode 100644
index 00000000..6156483d
--- /dev/null
+++ b/src/test/scala/millfork/test/BitOpSuite.scala
@@ -0,0 +1,36 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class BitOpSuite extends FunSuite with Matchers {
+
+ test("Word AND") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | void main () {
+ | byte b
+ | output = $5555
+ | b = $4E
+ | output &= b
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(0x44))
+ }
+ test("Long AND and EOR") {
+ EmuBenchmarkRun("""
+ | long output @$c000
+ | void main () {
+ | byte b
+ | word w
+ | output = $55555555
+ | w = $505
+ | output ^= w
+ | b = $4E
+ | output &= b
+ | }
+ """.stripMargin)(_.readLong(0xc000) should equal(0x40))
+ }
+}
diff --git a/src/test/scala/millfork/test/BooleanSuite.scala b/src/test/scala/millfork/test/BooleanSuite.scala
new file mode 100644
index 00000000..f9bf8b6d
--- /dev/null
+++ b/src/test/scala/millfork/test/BooleanSuite.scala
@@ -0,0 +1,64 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class BooleanSuite extends FunSuite with Matchers {
+
+ test("Not") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | array input = [5,6,7]
+ | void main () {
+ | byte a
+ | a = 5
+ | if not(a < 3) {output = 4}
+ | if not(a > 3) {output = 3}
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(4))
+
+ }
+
+
+ test("And") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | array input = [5,6,7]
+ | void main () {
+ | byte a
+ | byte b
+ | a = 5
+ | b = 5
+ | if a > 3 && b > 3 {output = 4}
+ | if a < 3 && b > 3 {output = 5}
+ | if a > 3 && b < 3 {output = 2}
+ | if a < 3 && b < 3 {output = 3}
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(4))
+ }
+
+
+ test("Or") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | array input = [5,6,7]
+ | void main () {
+ | byte a
+ | byte b
+ | a = 5
+ | b = 5
+ | output = 0
+ | if a > 3 || b > 3 {output += 4}
+ | if a < 3 || b > 3 {output += 5}
+ | if a > 3 || b < 3 {output += 2}
+ | if a < 3 || b < 3 {output = 30}
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(11))
+ }
+}
diff --git a/src/test/scala/millfork/test/ByteDecimalMathSuite.scala b/src/test/scala/millfork/test/ByteDecimalMathSuite.scala
new file mode 100644
index 00000000..d0d8afae
--- /dev/null
+++ b/src/test/scala/millfork/test/ByteDecimalMathSuite.scala
@@ -0,0 +1,68 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class ByteDecimalMathSuite extends FunSuite with Matchers {
+
+ test("Decimal byte addition") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | byte a
+ | void main () {
+ | a = $36
+ | output = a +' a
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(0x72))
+ }
+
+ test("Decimal byte addition 2") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | byte a
+ | void main () {
+ | a = 1
+ | output = a +' $69
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(0x70))
+ }
+
+ test("In-place decimal byte addition") {
+ EmuBenchmarkRun(
+ """
+ | array output[3] @$c000
+ | byte a
+ | void main () {
+ | a = 1
+ | output[1] = 5
+ | output[a] +'= 1
+ | output[a] +'= $36
+ | }
+ """.stripMargin)(_.readByte(0xc001) should equal(0x42))
+ }
+
+ test("In-place decimal byte addition 2") {
+ EmuBenchmarkRun(
+ """
+ | array output[3] @$c000
+ | void main () {
+ | byte x
+ | byte y
+ | byte tmpx
+ | byte tmpy
+ | tmpx = one()
+ | tmpy = one()
+ | x = tmpx
+ | y = tmpy
+ | output[y] = $39
+ | output[x] +'= 1
+ | }
+ | byte one() { return 1 }
+ """.stripMargin)(_.readByte(0xc001) should equal(0x40))
+ }
+}
diff --git a/src/test/scala/millfork/test/ByteMathSuite.scala b/src/test/scala/millfork/test/ByteMathSuite.scala
new file mode 100644
index 00000000..067743e2
--- /dev/null
+++ b/src/test/scala/millfork/test/ByteMathSuite.scala
@@ -0,0 +1,158 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class ByteMathSuite extends FunSuite with Matchers {
+
+ test("Complex expression") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = (one() + one()) | (((one()<<2)-1) ^ one())
+ | }
+ | byte one() {
+ | return 1
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(2))
+ }
+
+ test("Byte addition") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | byte a
+ | void main () {
+ | a = 1
+ | output = a + a
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(2))
+ }
+
+ test("Byte addition 2") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | byte a
+ | void main () {
+ | a = 1
+ | output = a + 65
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(66))
+ }
+
+ test("In-place byte addition") {
+ EmuBenchmarkRun(
+ """
+ | array output[3] @$c000
+ | byte a
+ | void main () {
+ | a = 1
+ | output[1] = 5
+ | output[a] += 1
+ | output[a] += 36
+ | }
+ """.stripMargin)(_.readByte(0xc001) should equal(42))
+ }
+
+ test("Parameter order") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | array arr[6]
+ | void main () {
+ | output = 42
+ | }
+ | byte test1(byte a) @$6000 {
+ | return 5 + a
+ | }
+ | byte test2(byte a) @$6100 {
+ | return 5 | a
+ | }
+ | byte test3(byte a) @$6200 {
+ | return a + arr[a]
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(42))
+ }
+
+ test("In-place byte addition 2") {
+ EmuBenchmarkRun(
+ """
+ | array output[3] @$c000
+ | void main () {
+ | byte x
+ | byte y
+ | byte tmpx
+ | byte tmpy
+ | tmpx = one()
+ | tmpy = one()
+ | x = tmpx
+ | y = tmpy
+ | output[y] = 36
+ | output[x] += 1
+ | }
+ | byte one() { return 1 }
+ """.stripMargin)(_.readByte(0xc001) should equal(37))
+ }
+
+ test("In-place byte multiplication") {
+ multiplyCase1(0, 0)
+ multiplyCase1(0, 1)
+ multiplyCase1(0, 2)
+ multiplyCase1(0, 5)
+ multiplyCase1(1, 0)
+ multiplyCase1(5, 0)
+ multiplyCase1(7, 0)
+ multiplyCase1(2, 5)
+ multiplyCase1(7, 2)
+ multiplyCase1(100, 2)
+ multiplyCase1(54, 4)
+ multiplyCase1(2, 100)
+ multiplyCase1(4, 54)
+ }
+
+ private def multiplyCase1(x: Int, y: Int): Unit = {
+ EmuBenchmarkRun(
+ s"""
+ | byte output @$$c000
+ | void main () {
+ | output = $x
+ | output *= $y
+ | }
+ """.
+ stripMargin)(_.readByte(0xc000) should equal(x * y))
+ }
+
+ test("Byte multiplication") {
+ multiplyCase2(0, 0)
+ multiplyCase2(0, 1)
+ multiplyCase2(0, 2)
+ multiplyCase2(0, 5)
+ multiplyCase2(1, 0)
+ multiplyCase2(5, 0)
+ multiplyCase2(7, 0)
+ multiplyCase2(2, 5)
+ multiplyCase2(7, 2)
+ multiplyCase2(100, 2)
+ multiplyCase2(54, 4)
+ multiplyCase2(2, 100)
+ multiplyCase2(4, 54)
+ }
+
+ private def multiplyCase2(x: Int, y: Int): Unit = {
+ EmuBenchmarkRun(
+ s"""
+ | byte output @$$c000
+ | void main () {
+ | byte a
+ | a = $x
+ | output = a * $y
+ | }
+ """.
+ stripMargin)(_.readByte(0xc000) should equal(x * y))
+ }
+}
diff --git a/src/test/scala/millfork/test/CmosSuite.scala b/src/test/scala/millfork/test/CmosSuite.scala
new file mode 100644
index 00000000..23be132c
--- /dev/null
+++ b/src/test/scala/millfork/test/CmosSuite.scala
@@ -0,0 +1,31 @@
+package millfork.test
+
+import millfork.test.emu.EmuCmosBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class CmosSuite extends FunSuite with Matchers {
+
+ test("Zero store 1") {
+ EmuCmosBenchmarkRun("""
+ | word output @$c000
+ | void main () {
+ | output = 1
+ | output = 0
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(0))
+ }
+ test("Zero store 2") {
+ EmuCmosBenchmarkRun("""
+ | byte output @$c000
+ | void main () {
+ | output = 1
+ | output = 0
+ | output += 1
+ | output <<= 1
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(2))
+ }
+}
diff --git a/src/test/scala/millfork/test/ComparisonSuite.scala b/src/test/scala/millfork/test/ComparisonSuite.scala
new file mode 100644
index 00000000..5b849d3b
--- /dev/null
+++ b/src/test/scala/millfork/test/ComparisonSuite.scala
@@ -0,0 +1,264 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class ComparisonSuite extends FunSuite with Matchers {
+
+ test("Equality and inequality") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = 5
+ | if (output == 5) {
+ | output += 1
+ | } else {
+ | output +=2
+ | }
+ | if (output != 6) {
+ | output += 78
+ | }
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(6))
+ }
+
+ test("Less") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | output = 5
+ | while output < 150 {
+ | output += 1
+ | }
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(150))
+ }
+
+ test("Compare to zero") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | byte a
+ | a = 150
+ | while a != 0 {
+ | a -= 1
+ | output += 1
+ | }
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(150))
+ }
+
+ test("Carry flag optimization test") {
+ EmuBenchmarkRun(
+ """
+ | byte output @$c000
+ | void main () {
+ | byte a
+ | a = 150
+ | if (a >= 50) {
+ | output = 4
+ | } else {
+ | output = 0
+ | }
+ | output += get(55)
+ | }
+ | byte get(byte x) {
+ | if x >= 6 {return 0} else {return 128}
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(4))
+ }
+
+ test("Does it even work") {
+ EmuBenchmarkRun(
+ """
+ | word output @$c000
+ | void main () {
+ | byte a
+ | a = 150
+ | if a != 0 {
+ | output = 345
+ | }
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(345))
+ }
+
+ test("Word comparison constant") {
+ val src =
+ """
+ | byte output @$c000
+ | void main () {
+ | output = 0
+ | if 2222 == 2222 { output += 1 }
+ | if 2222 == 3333 { output -= 1 }
+ | }
+ """.stripMargin
+ EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
+ }
+
+ test("Word comparison == and !=") {
+ val src =
+ """
+ | byte output @$c000
+ | void main () {
+ | word a
+ | word b
+ | word c
+ | output = 0
+ | a = 4
+ | b = 4
+ | c = 5
+ | if a == 4 { output += 1 }
+ | if a == b { output += 1 }
+ | if a != c { output += 1 }
+ | if a != 5 { output += 1 }
+ | if a != 260 { output += 1 }
+ | }
+ """.stripMargin
+ EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
+ }
+
+ test("Word comparison <=") {
+ val src =
+ """
+ | byte output @$c000
+ | void main () {
+ | word a
+ | word b
+ | word c
+ | output = 0
+ | a = 4
+ | b = 4
+ | c = 5
+ | if a <= 4 { output += 1 }
+ | if a <= 257 { output += 1 }
+ | if a <= b { output += 1 }
+ | if a <= c { output += 1 }
+ | }
+ """.stripMargin
+ EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
+ }
+ test("Word comparison <") {
+ val src =
+ """
+ | byte output @$c000
+ | void main () {
+ | word a
+ | word b
+ | word c
+ | output = 0
+ | a = 4
+ | b = 4
+ | c = 5
+ | if a < 5 { output += 1 }
+ | if a < c { output += 1 }
+ | if a < 257 { output += 1 }
+ | }
+ """.stripMargin
+ EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
+ }
+
+
+ test("Word comparison >") {
+ val src =
+ """
+ | byte output @$c000
+ | void main () {
+ | word a
+ | word b
+ | word c
+ | output = 0
+ | a = 4
+ | b = 4
+ | c = 5
+ | if c > a { output += 1 }
+ | if c > 1 { output += 1 }
+ | if c > 0 { output += 1 }
+ | }
+ """.stripMargin
+ EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
+ }
+
+ test("Word comparison >=") {
+ val src =
+ """
+ | byte output @$c000
+ | void main () {
+ | word a
+ | word b
+ | word c
+ | output = 0
+ | a = 4
+ | b = 4
+ | c = 5
+ | if c >= 1 { output += 1 }
+ | if c >= a { output += 1 }
+ | if a >= a { output += 1 }
+ | if a >= 4 { output += 1 }
+ | if a >= 4 { output += 1 }
+ | if a >= 0 { output += 1 }
+ | }
+ """.stripMargin
+ EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
+ }
+
+ test("Signed comparison >=") {
+ val src =
+ """
+ | byte output @$c000
+ | void main () {
+ | sbyte a
+ | sbyte b
+ | sbyte c
+ | output = 0
+ | a = 4
+ | b = 4
+ | c = 5
+ | if c >= 1 { output += 1 }
+ | if c >= a { output += 1 }
+ | if a >= a { output += 1 }
+ | if a >= 4 { output += 1 }
+ | if a >= 4 { output += 1 }
+ | if a >= 0 { output += 1 }
+ | }
+ """.stripMargin
+ EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
+ }
+
+ test("Signed comparison < and <=") {
+ val src =
+ """
+ | byte output @$c000
+ | void main () {
+ | sbyte a
+ | sbyte b
+ | sbyte c
+ | output = 0
+ | a = -1
+ | b = 0
+ | c = 1
+ | if a < 0 { output += 1 }
+ | if b < 0 { output -= 7 }
+ | if c < 0 { output -= 7 }
+ | if a < 1 { output += 1 }
+ | if b < 1 { output += 1 }
+ | if c < 1 { output -= 7 }
+ | if a <= 0 { output += 1 }
+ | if b <= 0 { output += 1 }
+ | if c <= 0 { output -= 7 }
+ | if a <= 1 { output += 1 }
+ | if b <= 1 { output += 1 }
+ | if c <= 1 { output += 1 }
+ | if a <= -1 { output += 1 }
+ | if b <= -1 { output -= 7 }
+ | if c <= -1 { output -= 7 }
+ | }
+ """.stripMargin
+ EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
+ }
+}
diff --git a/src/test/scala/millfork/test/ErasthotenesSuite.scala b/src/test/scala/millfork/test/ErasthotenesSuite.scala
new file mode 100644
index 00000000..34228b25
--- /dev/null
+++ b/src/test/scala/millfork/test/ErasthotenesSuite.scala
@@ -0,0 +1,37 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class ErasthotenesSuite extends FunSuite with Matchers {
+
+ test("Erasthotenes") {
+ EmuBenchmarkRun(
+ """
+ | const pointer sieve = $C000
+ | const byte sqrt = 128
+ | void main () {
+ | byte i
+ | word j
+ | pointer s
+ | i = 2
+ | while i < sqrt {
+ | if sieve[i] == 0 {
+ | j = i << 1
+ | s = sieve
+ | s += j
+ | while j.hi == 0 {
+ | s[0] = 1
+ | s += i
+ | j += i
+ | }
+ | }
+ | i += 1
+ | }
+ | }
+ """.stripMargin){_=>}
+ }
+}
diff --git a/src/test/scala/millfork/test/ForLoopSuite.scala b/src/test/scala/millfork/test/ForLoopSuite.scala
new file mode 100644
index 00000000..bb9d9929
--- /dev/null
+++ b/src/test/scala/millfork/test/ForLoopSuite.scala
@@ -0,0 +1,76 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class ForLoopSuite extends FunSuite with Matchers {
+
+ test("For-to") {
+ EmuBenchmarkRun(
+ """
+ | word output @$c000
+ | void main () {
+ | byte i
+ | output = 0
+ | for i,0,to,5 {
+ | output += i
+ | }
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(15))
+ }
+ test("For-downto") {
+ EmuBenchmarkRun(
+ """
+ | word output @$c000
+ | void main () {
+ | byte i
+ | output = 0
+ | for i,5,downto,0 {
+ | output += i
+ | }
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(15))
+ }
+ test("For-until") {
+ EmuBenchmarkRun(
+ """
+ | word output @$c000
+ | void main () {
+ | byte i
+ | output = 0
+ | for i,0,until,6 {
+ | output += i
+ | }
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(15))
+ }
+ test("For-parallelto") {
+ EmuBenchmarkRun(
+ """
+ | word output @$c000
+ | void main () {
+ | byte i
+ | output = 0
+ | for i,0,parallelto,5 {
+ | output += i
+ | }
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(15))
+ }
+ test("For-paralleluntil") {
+ EmuBenchmarkRun(
+ """
+ | word output @$c000
+ | void main () {
+ | byte i
+ | output = 0
+ | for i,0,paralleluntil,6 {
+ | output += i
+ | }
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(15))
+ }
+}
diff --git a/src/test/scala/millfork/test/IllegalSuite.scala b/src/test/scala/millfork/test/IllegalSuite.scala
new file mode 100644
index 00000000..36e8d312
--- /dev/null
+++ b/src/test/scala/millfork/test/IllegalSuite.scala
@@ -0,0 +1,51 @@
+package millfork.test
+
+import millfork.test.emu.EmuUndocumentedRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class IllegalSuite extends FunSuite with Matchers {
+
+ test("ALR test 1") {
+ EmuUndocumentedRun("""
+ | byte output @$c000
+ | byte input @$cfff
+ | void main () {
+ | output = (input & 0x45) >> 1
+ | }
+ """.stripMargin)
+ }
+
+ test("ALR test 2") {
+ EmuUndocumentedRun("""
+ | byte output @$c000
+ | byte input @$cfff
+ | void main () {
+ | output = (input & 0x45) >> 1
+ | }
+ """.stripMargin)
+ }
+
+ test("ISC/DCP test") {
+ val m = EmuUndocumentedRun("""
+ | array output [10] @$c000
+ | void main () {
+ | stack byte a
+ | output[5] = 36
+ | output[7] = 52
+ | five()
+ | a = 5
+ | five()
+ | output[a] += 1
+ | output[a + 2] -= 1
+ | }
+ | byte five () {
+ | return 5
+ | }
+ """.stripMargin)
+ m.readByte(0xc005) should equal(37)
+ m.readByte(0xc007) should equal(51)
+ }
+}
diff --git a/src/test/scala/millfork/test/InlineAssemblyFunctionsSuite.scala b/src/test/scala/millfork/test/InlineAssemblyFunctionsSuite.scala
new file mode 100644
index 00000000..2d62fe33
--- /dev/null
+++ b/src/test/scala/millfork/test/InlineAssemblyFunctionsSuite.scala
@@ -0,0 +1,80 @@
+package millfork.test
+
+import millfork.assembly.opt.DangerousOptimizations
+import millfork.test.emu.EmuBenchmarkRun
+import millfork.{Cpu, OptimizationPresets}
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class InlineAssemblyFunctionsSuite extends FunSuite with Matchers {
+
+ test("Poke test 1") {
+ EmuBenchmarkRun(
+ """
+ | inline asm void poke(word ref addr, byte const value) {
+ | ?LDA #value
+ | STA addr
+ | }
+ |
+ | byte output @$c000
+ | void main () {
+ | poke(output, 5)
+ | }
+ """.stripMargin) { m =>
+ m.readByte(0xc000) should equal(5)
+ }
+ }
+ test("Peek test 1") {
+ EmuBenchmarkRun(
+ """
+ | inline asm byte peek(word ref addr) {
+ | ?LDA addr
+ | }
+ |
+ | byte output @$c000
+ | void main () {
+ | byte a
+ | a = 5
+ | output = peek(a)
+ | }
+ """.stripMargin) { m =>
+ m.readByte(0xc000) should equal(5)
+ }
+ }
+ test("Poke test 2") {
+ EmuBenchmarkRun(
+ """
+ | inline asm void poke(word const addr, byte const value) {
+ | ?LDA #value
+ | STA addr
+ | }
+ |
+ | byte output @$c000
+ | void main () {
+ | poke($c000, 5)
+ | poke($c001, 5)
+ | }
+ """.stripMargin) { m =>
+ m.readByte(0xc000) should equal(5)
+ }
+ }
+ test("Peek test 2") {
+ EmuBenchmarkRun(
+ """
+ | inline asm byte peek(word const addr) {
+ | ?LDA addr
+ | }
+ |
+ | byte output @$c000
+ | void main () {
+ | byte a
+ | a = 5
+ | output = peek(a.addr)
+ | }
+ """.stripMargin) { m =>
+ m.readByte(0xc000) should equal(5)
+ }
+ }
+}
diff --git a/src/test/scala/millfork/test/LongTest.scala b/src/test/scala/millfork/test/LongTest.scala
new file mode 100644
index 00000000..129bb24f
--- /dev/null
+++ b/src/test/scala/millfork/test/LongTest.scala
@@ -0,0 +1,160 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class LongTest extends FunSuite with Matchers {
+
+ test("Long assignment") {
+ EmuBenchmarkRun(
+ """
+ | long output4 @$c000
+ | long output2 @$c004
+ | long output1 @$c008
+ | void main () {
+ | output4 = $11223344
+ | output2 = $11223344
+ | output1 = $11223344
+ | output2 = $7788
+ | output1 = $55
+ | }
+ """.stripMargin) { m =>
+ m.readLong(0xc000) should equal(0x11223344)
+ m.readLong(0xc004) should equal(0x7788)
+ m.readLong(0xc008) should equal(0x55)
+ }
+ }
+ test("Long assignment 2") {
+ EmuBenchmarkRun(
+ """
+ | long output4 @$c000
+ | long output2 @$c004
+ | word output1 @$c008
+ | void main () {
+ | word w
+ | byte b
+ | w = $7788
+ | b = $55
+ | output4 = $11223344
+ | output2 = $11223344
+ | output1 = $11223344
+ | output2 = w
+ | output1 = b
+ | }
+ """.stripMargin) { m =>
+ m.readLong(0xc000) should equal(0x11223344)
+ m.readLong(0xc004) should equal(0x7788)
+ m.readLong(0xc008) should equal(0x55)
+ }
+ }
+ test("Long addition") {
+ EmuBenchmarkRun(
+ """
+ | long output @$c000
+ | void main () {
+ | word w
+ | long l
+ | byte b
+ | w = $8000
+ | b = $8
+ | l = $50000
+ | output = 0
+ | output += l
+ | output += w
+ | output += b
+ | }
+ """.stripMargin) { m =>
+ m.readLong(0xc000) should equal(0x58008)
+ }
+ }
+ test("Long addition 2") {
+ EmuBenchmarkRun(
+ """
+ | long output @$c000
+ | void main () {
+ | output = 0
+ | output += $50000
+ | output += $8000
+ | output += $8
+ | }
+ """.stripMargin) { m =>
+ m.readLong(0xc000) should equal(0x58008)
+ }
+ }
+ test("Long subtraction") {
+ EmuBenchmarkRun(
+ """
+ | long output @$c000
+ | void main () {
+ | word w
+ | long l
+ | byte b
+ | w = $8000
+ | b = $8
+ | l = $50000
+ | output = $58008
+ | output -= l
+ | output -= w
+ | output -= b
+ | }
+ """.stripMargin) { m =>
+ m.readLong(0xc000) should equal(0)
+ }
+ }
+ test("Long subtraction 2") {
+ EmuBenchmarkRun(
+ """
+ | long output @$c000
+ | void main () {
+ | output = $58008
+ | output -= $50000
+ | output -= $8000
+ | output -= $8
+ | }
+ """.stripMargin) { m =>
+ m.readLong(0xc000) should equal(0)
+ }
+ }
+ test("Long subtraction 3") {
+ EmuBenchmarkRun(
+ """
+ | long output @$c000
+ | void main () {
+ | output = $58008
+ | output -= w()
+ | output -= b()
+ | }
+ | byte b() {
+ | return $8
+ | }
+ | word w() {
+ | return $8000
+ | }
+ """.stripMargin) { m =>
+ m.readLong(0xc000) should equal(0x50000)
+ }
+ }
+
+ test("Long AND") {
+ EmuBenchmarkRun(
+ """
+ | long output @$c000
+ | void main () {
+ | output = $FFFFFF
+ | output &= w()
+ | output &= b()
+ | }
+ | byte b() {
+ | return $77
+ | }
+ | word w() {
+ | return $CCCC
+ | }
+ """.stripMargin) { m =>
+ m.readLong(0xc000) should equal(0x44)
+ }
+ }
+}
diff --git a/src/test/scala/millfork/test/MinimalTest.scala b/src/test/scala/millfork/test/MinimalTest.scala
new file mode 100644
index 00000000..e82c1346
--- /dev/null
+++ b/src/test/scala/millfork/test/MinimalTest.scala
@@ -0,0 +1,23 @@
+package millfork.test
+
+import fastparse.core.Mutable.Failure
+import fastparse.core.Parsed.Success
+import millfork.parser.MinimalTestCase
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class MinimalTest extends FunSuite with Matchers {
+// ignore("a") {
+// MinimalTestCase.program.parse("[]\n[a,g,%h]\n") match {
+// case Success(unoptimized, _) =>
+// println(unoptimized)
+// case f:Failure[_,_] =>
+// val g =f
+// println(f)
+// println(f.originalParser)
+// fail()
+// }
+// }
+}
diff --git a/src/test/scala/millfork/test/NodeOptimizationSuite.scala b/src/test/scala/millfork/test/NodeOptimizationSuite.scala
new file mode 100644
index 00000000..13f196a7
--- /dev/null
+++ b/src/test/scala/millfork/test/NodeOptimizationSuite.scala
@@ -0,0 +1,32 @@
+package millfork.test
+
+import millfork.test.emu.EmuNodeOptimizedRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class NodeOptimizationSuite extends FunSuite with Matchers {
+
+ test("Unreachable after return") {
+ EmuNodeOptimizedRun(
+ """
+ | byte crash @$ffff
+ | void main () {
+ | return
+ | crash = 2
+ | }
+ """.stripMargin)
+ }
+
+ test("Unused local variable") {
+ EmuNodeOptimizedRun(
+ """
+ | byte crash @$ffff
+ | void main () {
+ | byte a
+ | a = crash
+ | }
+ """.stripMargin)
+ }
+}
diff --git a/src/test/scala/millfork/test/NonetSuite.scala b/src/test/scala/millfork/test/NonetSuite.scala
new file mode 100644
index 00000000..f78f055e
--- /dev/null
+++ b/src/test/scala/millfork/test/NonetSuite.scala
@@ -0,0 +1,29 @@
+package millfork.test
+
+import millfork.test.emu.EmuUnoptimizedRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class NonetSuite extends FunSuite with Matchers {
+
+ test("Nonet operations") {
+ val m = EmuUnoptimizedRun(
+ """
+ | array output [3] @$c000
+ | array input = [5,6,7]
+ | void main () {
+ | word a
+ | a = $110
+ | output[1] = a >>>> 1
+ | output[2] = a >>>> 2
+ | }
+ | void copyEntry(byte index) {
+ | output[index] = input[index]
+ | }
+ """.stripMargin)
+ m.readByte(0xc001) should equal(0x88)
+ m.readByte(0xc002) should equal(0x44)
+ }
+}
diff --git a/src/test/scala/millfork/test/SeparateBytesSuite.scala b/src/test/scala/millfork/test/SeparateBytesSuite.scala
new file mode 100644
index 00000000..b6e2dd1a
--- /dev/null
+++ b/src/test/scala/millfork/test/SeparateBytesSuite.scala
@@ -0,0 +1,137 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class SeparateBytesSuite extends FunSuite with Matchers {
+
+ test("Separate assignment 1") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | void main () {
+ | output = 2:3
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(0x203))
+ }
+
+ test("Separate assignment 2") {
+ EmuBenchmarkRun("""
+ | byte output @$c000
+ | byte ignore @$c001
+ | void main () {
+ | word w
+ | w = $355
+ | output:ignore = w
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(3))
+ }
+
+ test("Separate assignment 3") {
+ EmuBenchmarkRun("""
+ | byte output @$c000
+ | byte ignore @$c001
+ | void main () {
+ | output:ignore = lol()
+ | }
+ | word lol() {
+ | return $567
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(5))
+ }
+
+ test("Separate assignment 4") {
+ EmuBenchmarkRun("""
+ | array output [5] @$c000
+ | byte ignore @$c001
+ | void main () {
+ | byte index
+ | index = 3
+ | output[index]:ignore = lol()
+ | }
+ | word lol() {
+ | return $567
+ | }
+ """.stripMargin)(_.readByte(0xc003) should equal(5))
+ }
+
+ test("Separate assignment 5") {
+ EmuBenchmarkRun("""
+ | array output [5] @$c000
+ | byte ignore @$c001
+ | void main () {
+ | byte index
+ | index = 3
+ | ignore:output[index] = lol()
+ | }
+ | word lol() {
+ | return $567
+ | }
+ """.stripMargin)(_.readByte(0xc003) should equal(0x67))
+ }
+
+ test("Magic split array") {
+ EmuBenchmarkRun("""
+ | array hi [16] @$c000
+ | array lo [16] @$c010
+ | void main () {
+ | word a
+ | word b
+ | word tmp
+ | a = 1
+ | b = 1
+ | byte i
+ | i = 0
+ | while i < 16 {
+ | hi[i]:lo[i] = a
+ | tmp = a
+ | tmp += b
+ | a = b
+ | b = tmp
+ | i += 1
+ | }
+ | }
+ """.stripMargin) { m=>
+ m.readWord(0xc000, 0xc010) should equal(1)
+ m.readWord(0xc001, 0xc011) should equal(1)
+ m.readWord(0xc002, 0xc012) should equal(2)
+ m.readWord(0xc003, 0xc013) should equal(3)
+ m.readWord(0xc004, 0xc014) should equal(5)
+ m.readWord(0xc005, 0xc015) should equal(8)
+ m.readWord(0xc006, 0xc016) should equal(13)
+ m.readWord(0xc007, 0xc017) should equal(21)
+ }
+ }
+
+ test("Separate addition") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | void main () {
+ | byte h
+ | byte l
+ | h = 6
+ | l = 5
+ | output = $101
+ | output += h:l
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(0x706))
+ }
+
+ test("Separate increase") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | void main () {
+ | byte h
+ | byte l
+ | h = 6
+ | l = 5
+ | (h:l) += $101
+ | output = h:l
+ | (h:l) += 1
+ | output = h:l
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(0x707))
+ }
+}
diff --git a/src/test/scala/millfork/test/ShiftSuite.scala b/src/test/scala/millfork/test/ShiftSuite.scala
new file mode 100644
index 00000000..68ab499b
--- /dev/null
+++ b/src/test/scala/millfork/test/ShiftSuite.scala
@@ -0,0 +1,63 @@
+package millfork.test
+import millfork.test.emu.{EmuBenchmarkRun, EmuUnoptimizedRun}
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class ShiftSuite extends FunSuite with Matchers {
+
+ test("In-place shifting") {
+ EmuUnoptimizedRun("""
+ | array output [3] @$c000
+ | void main () {
+ | output[0] = 1
+ | output[1] = 3
+ | output[output[0]] <<= 2
+ | }
+ """.stripMargin).readByte(0xc001) should equal(12)
+ }
+
+ test("Byte shifting") {
+ EmuBenchmarkRun("""
+ | byte output @$c000
+ | void main () {
+ | byte a
+ | a = 3
+ | output = a << 2
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(12))
+ }
+
+ test("Word shifting") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | void main () {
+ | byte a
+ | a = 3
+ | output = a
+ | output <<= 7
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(128 * 3))
+ }
+
+ test("Long shifting left") {
+ EmuBenchmarkRun("""
+ | long output @$c000
+ | void main () {
+ | output = $1010301
+ | output <<= 2
+ | }
+ """.stripMargin)(_.readLong(0xc000) should equal(0x4040C04))
+ }
+
+ test("Long shifting right") {
+ EmuBenchmarkRun("""
+ | long output @$c000
+ | void main () {
+ | output = $4040C04
+ | output >>= 2
+ | }
+ """.stripMargin)(_.readLong(0xc000) should equal(0x1010301))
+ }
+}
diff --git a/src/test/scala/millfork/test/SignExtensionSuite.scala b/src/test/scala/millfork/test/SignExtensionSuite.scala
new file mode 100644
index 00000000..993e8b7c
--- /dev/null
+++ b/src/test/scala/millfork/test/SignExtensionSuite.scala
@@ -0,0 +1,44 @@
+package millfork.test
+
+import millfork.test.emu.EmuUnoptimizedRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class SignExtensionSuite extends FunSuite with Matchers {
+
+ test("Sbyte to Word") {
+ EmuUnoptimizedRun("""
+ | word output @$c000
+ | void main () {
+ | sbyte b
+ | b = -1
+ | output = b
+ | }
+ """.stripMargin).readWord(0xc000) should equal(0xffff)
+ }
+ test("Sbyte to Word 2") {
+ EmuUnoptimizedRun("""
+ | word output @$c000
+ | void main () {
+ | output = b()
+ | }
+ | sbyte b() {
+ | return -1
+ | }
+ """.stripMargin).readWord(0xc000) should equal(0xffff)
+ }
+ test("Sbyte to Long") {
+ EmuUnoptimizedRun("""
+ | long output @$c000
+ | void main () {
+ | output = 421
+ | output += b()
+ | }
+ | sbyte b() {
+ | return -1
+ | }
+ """.stripMargin).readLong(0xc000) should equal(420)
+ }
+}
diff --git a/src/test/scala/millfork/test/StackVarSuite.scala b/src/test/scala/millfork/test/StackVarSuite.scala
new file mode 100644
index 00000000..8e5051f0
--- /dev/null
+++ b/src/test/scala/millfork/test/StackVarSuite.scala
@@ -0,0 +1,142 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class StackVarSuite extends FunSuite with Matchers {
+
+ test("Basic stack assignment") {
+ EmuBenchmarkRun("""
+ | byte output @$c000
+ | void main () {
+ | stack byte a
+ | stack byte b
+ | b = 4
+ | a = b
+ | output = a
+ | a = output
+ | }
+ """.stripMargin)(_.readByte(0xc000) should equal(4))
+ }
+
+ test("Stack byte addition") {
+ EmuBenchmarkRun("""
+ | byte output @$c000
+ | void main () {
+ | stack byte a
+ | stack byte b
+ | a = $11
+ | b = $44
+ | b += zzz()
+ | b += a
+ | output = b
+ | }
+ | byte zzz() {
+ | return $22
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(0x77))
+ }
+
+ test("Complex expressions involving stack variables") {
+ EmuBenchmarkRun("""
+ | byte output @$c000
+ | void main () {
+ | stack byte a
+ | a = 7
+ | output = f(a) + f(a) + f(a)
+ | }
+ | asm byte f(byte a) {
+ | rts
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(21))
+ }
+
+// test("Stack byte subtraction") {
+// SymonUnoptimizedRun("""
+// | byte output @$c000
+// | void main () {
+// | stack byte a
+// | stack byte b
+// | b = $77
+// | a = $11
+// | b -= zzz()
+// | b -= a
+// | output = b
+// | }
+// | byte zzz() {
+// | return $22
+// | }
+// """.stripMargin).readByte(0xc000) should equal(0x44)
+// }
+
+ test("Stack word addition") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | void main () {
+ | stack word a
+ | stack word b
+ | a = $111
+ | b = $444
+ | b += zzz()
+ | b += a
+ | output = b
+ | }
+ | word zzz() {
+ | return $222
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(0x777))
+ }
+
+ test("Recursion") {
+ EmuBenchmarkRun("""
+ | array output [6] @$c000
+ | byte fails @$c010
+ | void main () {
+ | word w
+ | byte i
+ | for i,0,until,output.length {
+ | w = fib(i)
+ | if w.hi != 0 { fails += 1 }
+ | output[i] = w.lo
+ | }
+ | }
+ | word fib(byte i) {
+ | stack byte j
+ | j = i
+ | if j < 2 {
+ | return 1
+ | }
+ | stack word sum
+ | sum = fib(j-1)
+ | sum += fib(j-2)
+ | sum &= $0F3F
+ | return sum
+ | }
+ """.stripMargin){ m =>
+ m.readByte(0xc010) should equal(0)
+ m.readByte(0xc000) should equal(1)
+ m.readByte(0xc001) should equal(1)
+ m.readByte(0xc002) should equal(2)
+ m.readByte(0xc003) should equal(3)
+ m.readByte(0xc004) should equal(5)
+ m.readByte(0xc005) should equal(8)
+ }
+ }
+
+
+ test("Indexing") {
+ EmuBenchmarkRun("""
+ | array output [200] @$c000
+ | void main () {
+ | stack byte a
+ | stack byte b
+ | a = $11
+ | b = $44
+ | output[a + b] = $66
+ | }
+ """.stripMargin){m => m.readWord(0xc055) should equal(0x66) }
+ }
+}
diff --git a/src/test/scala/millfork/test/TypeWideningSuite.scala b/src/test/scala/millfork/test/TypeWideningSuite.scala
new file mode 100644
index 00000000..7160f1b0
--- /dev/null
+++ b/src/test/scala/millfork/test/TypeWideningSuite.scala
@@ -0,0 +1,23 @@
+package millfork.test
+
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class TypeWideningSuite extends FunSuite with Matchers {
+
+ test("Word after simple ops") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | void main () {
+ | output = random()
+ | output = output.hi << 1
+ | }
+ | word random() {
+ | return $777
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(14))
+ }
+}
diff --git a/src/test/scala/millfork/test/WordMathSuite.scala b/src/test/scala/millfork/test/WordMathSuite.scala
new file mode 100644
index 00000000..83db90d8
--- /dev/null
+++ b/src/test/scala/millfork/test/WordMathSuite.scala
@@ -0,0 +1,103 @@
+package millfork.test
+import millfork.test.emu.EmuBenchmarkRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+ * @author Karol Stasiak
+ */
+class WordMathSuite extends FunSuite with Matchers {
+
+ test("Word addition") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | word a
+ | void main () {
+ | a = 640
+ | output = a
+ | output += a
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(1280))
+ }
+
+ test("Word subtraction") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | word a
+ | void main () {
+ | a = 640
+ | output = 740
+ | output -= a
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(100))
+ }
+
+ test("Word subtraction 2") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | word a
+ | void main () {
+ | a = 640
+ | output = a
+ | output -= 400
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(240))
+ }
+
+ test("Byte-to-word addition") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | word pair
+ | void main () {
+ | pair = $A5A5
+ | pair.lo = 1
+ | output = 640
+ | output += pair.lo
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(641))
+ }
+
+ test("Literal addition") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | void main () {
+ | output = 640
+ | output += -0050
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(590))
+ }
+
+ test("Array element addition") {
+ EmuBenchmarkRun("""
+ | word output @$c000
+ | word pair
+ | array b[2]
+ | void main () {
+ | byte i
+ | i = 1
+ | b[1] = 5
+ | pair = $A5A5
+ | pair.lo = 1
+ | output = 640
+ | output += b[i]
+ | }
+ """.stripMargin)(_.readWord(0xc000) should equal(645))
+ }
+
+ test("nesdev.com example") {
+ EmuBenchmarkRun("""
+ | byte output @$c000
+ | array map [256] @$c300
+ | array b[2]
+ | void main () {
+ | output = get(5, 6)
+ | }
+ | byte get(byte mx, byte my) {
+ | pointer p
+ | p = mx
+ | p <<= 5
+ | p += map
+ | return p[my]
+ | }
+ """.stripMargin)(m => ())
+ }
+}
diff --git a/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala
new file mode 100644
index 00000000..9188abaf
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala
@@ -0,0 +1,30 @@
+package millfork.test.emu
+
+import millfork.output.MemoryBank
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuBenchmarkRun {
+ def apply(source:String)(verifier: MemoryBank=>Unit) = {
+ val (Timings(t0, _), m0) = EmuUnoptimizedRun.apply2(source)
+ val (Timings(t1, _), m1) = EmuOptimizedRun.apply2(source)
+// val (Timings(t2, _), m2) = SymonSuperOptimizedRun.apply2(source)
+//val (Timings(t3, _), m3) = SymonQuantumOptimizedRun.apply2(source)
+ println(f"Before optimization: $t0%7d")
+ println(f"After optimization: $t1%7d")
+// println(f"After quantum: $t3%7d")
+// println(f"After superopt.: $t2%7d")
+ println(f"Gain: ${(100L*(t0-t1)/t0.toDouble).round}%7d%%")
+// println(f"Quantum gain: ${(100L*(t0-t3)/t0.toDouble).round}%7d%%")
+// println(f"Superopt. gain: ${(100L*(t0-t2)/t0.toDouble).round}%7d%%")
+ println(f"Running unoptimized")
+ verifier(m0)
+ println(f"Running optimized")
+ verifier(m1)
+// println(f"Running quantum optimized")
+// verifier(m3)
+// println(f"Running superoptimized")
+// verifier(m2)
+ }
+}
diff --git a/src/test/scala/millfork/test/emu/EmuCmosBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuCmosBenchmarkRun.scala
new file mode 100644
index 00000000..77f14579
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuCmosBenchmarkRun.scala
@@ -0,0 +1,26 @@
+package millfork.test.emu
+
+import millfork.output.MemoryBank
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuCmosBenchmarkRun {
+ def apply(source:String)(verifier: MemoryBank=>Unit) = {
+ val (Timings(_, t0), m0) = EmuUnoptimizedRun.apply2(source)
+ val (Timings(_, t1), m1) = EmuOptimizedRun.apply2(source)
+ val (Timings(_, t2), m2) = EmuOptimizedCmosRun.apply2(source)
+ println(f"Before optimization: $t0%7d")
+ println(f"After NMOS optimization: $t1%7d")
+ println(f"After CMOS optimization: $t2%7d")
+ println(f"Gain unopt->NMOS: ${(100L*(t0-t1)/t0.toDouble).round}%7d%%")
+ println(f"Gain unopt->CMOS: ${(100L*(t0-t2)/t0.toDouble).round}%7d%%")
+ println(f"Gain NMOS->CMOS: ${(100L*(t1-t2)/t1.toDouble).round}%7d%%")
+ println(f"Running unoptimized")
+ verifier(m0)
+ println(f"Running NMOS-optimized")
+ verifier(m1)
+ println(f"Running CMOS-optimized")
+ verifier(m2)
+ }
+}
diff --git a/src/test/scala/millfork/test/emu/EmuNodeOptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuNodeOptimizedRun.scala
new file mode 100644
index 00000000..f3f65144
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuNodeOptimizedRun.scala
@@ -0,0 +1,15 @@
+package millfork.test.emu
+
+import millfork.Cpu
+import millfork.node.opt.{UnreachableCode, UnusedLocalVariables}
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuNodeOptimizedRun extends EmuRun(
+ Cpu.StrictMos,
+ List(
+ UnreachableCode,
+ UnusedLocalVariables),
+ Nil,
+ false)
\ No newline at end of file
diff --git a/src/test/scala/millfork/test/emu/EmuOptimizedCmosRun.scala b/src/test/scala/millfork/test/emu/EmuOptimizedCmosRun.scala
new file mode 100644
index 00000000..91aed3e5
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuOptimizedCmosRun.scala
@@ -0,0 +1,19 @@
+package millfork.test.emu
+
+import millfork.assembly.opt.CmosOptimizations
+import millfork.{Cpu, OptimizationPresets}
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuOptimizedCmosRun extends EmuRun(
+ Cpu.Cmos,
+ OptimizationPresets.NodeOpt,
+ OptimizationPresets.AssOpt ++
+ CmosOptimizations.All ++ OptimizationPresets.Good ++
+ CmosOptimizations.All ++ OptimizationPresets.Good ++
+ CmosOptimizations.All ++ OptimizationPresets.Good,
+ false)
+
+
+
diff --git a/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala
new file mode 100644
index 00000000..13d252b0
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala
@@ -0,0 +1,15 @@
+package millfork.test.emu
+
+import millfork.{Cpu, OptimizationPresets}
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuOptimizedRun extends EmuRun(
+ Cpu.StrictMos,
+ OptimizationPresets.NodeOpt,
+ OptimizationPresets.AssOpt ++ OptimizationPresets.Good ++ OptimizationPresets.Good ++ OptimizationPresets.Good,
+ false)
+
+
+
diff --git a/src/test/scala/millfork/test/emu/EmuPlatform.scala b/src/test/scala/millfork/test/emu/EmuPlatform.scala
new file mode 100644
index 00000000..4caca5f3
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuPlatform.scala
@@ -0,0 +1,19 @@
+package millfork.test.emu
+
+import millfork.output.{AfterCodeByteAllocator, CurrentBankFragmentOutput, VariableAllocator}
+import millfork.{Cpu, Platform}
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuPlatform {
+ def get(cpu: Cpu.Value) = new Platform(
+ cpu,
+ Map(),
+ Nil,
+ CurrentBankFragmentOutput(0, 0xffff),
+ new VariableAllocator((0 until 256 by 2).toList, new AfterCodeByteAllocator(0xff00)),
+ 0x200,
+ ".bin"
+ )
+}
diff --git a/src/test/scala/millfork/test/emu/EmuQuantumOptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuQuantumOptimizedRun.scala
new file mode 100644
index 00000000..9a85ac7c
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuQuantumOptimizedRun.scala
@@ -0,0 +1,15 @@
+package millfork.test.emu
+
+import millfork.{Cpu, OptimizationPresets}
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuQuantumOptimizedRun extends EmuRun(
+ Cpu.StrictMos,
+ OptimizationPresets.NodeOpt,
+ OptimizationPresets.AssOpt ++ OptimizationPresets.Good ++ OptimizationPresets.Good ++ OptimizationPresets.Good,
+ true)
+
+
+
diff --git a/src/test/scala/millfork/test/emu/EmuRun.scala b/src/test/scala/millfork/test/emu/EmuRun.scala
new file mode 100644
index 00000000..59b0c853
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuRun.scala
@@ -0,0 +1,238 @@
+package millfork.test.emu
+
+import com.grapeshot.halfnes.{CPU, CPURAM}
+import com.loomcom.symon.InstructionTable.CpuBehavior
+import com.loomcom.symon.{Bus, Cpu, CpuState}
+import fastparse.core.Parsed.{Failure, Success}
+import millfork.assembly.opt.AssemblyOptimization
+import millfork.compiler.{CompilationContext, MlCompiler}
+import millfork.env.{Environment, InitializedArray, NormalFunction}
+import millfork.error.ErrorReporting
+import millfork.node.StandardCallGraph
+import millfork.node.opt.NodeOptimization
+import millfork.output.{Assembler, MemoryBank}
+import millfork.parser.MfParser
+import millfork.{CompilationFlag, CompilationOptions}
+import org.scalatest.Matchers
+
+/**
+ * @author Karol Stasiak
+ */
+case class Timings(nmos: Long, cmos: Long)
+
+class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], assemblyOptimizations: List[AssemblyOptimization], quantum: Boolean) extends Matchers {
+
+ def apply(source: String): MemoryBank = {
+ apply2(source)._2
+ }
+
+ def emitIllegals = false
+
+ private val timingNmos = Array[Int](
+ 7, 6, 0, 8, 3, 3, 5, 5, 3, 2, 2, 2, 4, 4, 6, 6,
+ 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7,
+ 6, 6, 0, 8, 3, 3, 5, 5, 4, 2, 2, 2, 4, 4, 6, 6,
+ 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7,
+
+ 6, 6, 0, 8, 3, 3, 5, 5, 3, 2, 2, 2, 3, 4, 6, 6,
+ 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7,
+ 6, 6, 0, 8, 3, 3, 5, 5, 4, 2, 2, 2, 5, 4, 6, 6,
+ 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7,
+
+ 2, 6, 2, 6, 3, 3, 3, 3, 2, 2, 2, 2, 4, 4, 4, 4,
+ 2, 6, 0, 6, 4, 4, 4, 4, 2, 5, 2, 5, 5, 5, 5, 5,
+ 2, 6, 2, 6, 3, 3, 3, 3, 2, 2, 2, 2, 4, 4, 4, 4,
+ 2, 5, 0, 5, 4, 4, 4, 4, 2, 4, 2, 4, 4, 4, 4, 4,
+
+ 2, 6, 2, 8, 3, 3, 5, 5, 2, 2, 2, 2, 4, 4, 6, 6,
+ 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7,
+ 2, 6, 2, 8, 3, 3, 5, 5, 2, 2, 2, 2, 4, 4, 6, 6,
+ 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7,
+ )
+
+ private val timingCmos = Array[Int](
+ 7, 6, 2, 1, 5, 3, 5, 5, 3, 2, 2, 1, 6, 4, 6, 5,
+ 2, 5, 5, 1, 5, 4, 6, 5, 2, 4, 2, 1, 6, 4, 6, 5,
+ 6, 6, 2, 1, 3, 3, 5, 5, 4, 2, 2, 1, 4, 4, 6, 5,
+ 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 2, 1, 4, 4, 6, 5,
+
+ 6, 6, 2, 1, 3, 3, 5, 5, 3, 2, 2, 1, 3, 4, 6, 5,
+ 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 3, 1, 8, 4, 6, 5,
+ 6, 6, 2, 1, 3, 3, 5, 5, 4, 2, 2, 1, 6, 4, 6, 5,
+ 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 4, 1, 6, 4, 6, 5,
+
+ 3, 6, 2, 1, 3, 3, 3, 5, 2, 2, 2, 1, 4, 4, 4, 5,
+ 2, 6, 5, 1, 4, 4, 4, 5, 2, 5, 2, 1, 4, 5, 5, 5,
+ 2, 6, 2, 1, 3, 3, 3, 5, 2, 2, 2, 1, 4, 4, 4, 5,
+ 2, 5, 5, 1, 4, 4, 4, 5, 2, 4, 2, 1, 4, 4, 4, 5,
+
+ 2, 6, 2, 1, 3, 3, 5, 5, 2, 2, 2, 3, 4, 4, 6, 5,
+ 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 3, 3, 4, 4, 7, 5,
+ 2, 6, 2, 1, 3, 3, 5, 5, 2, 2, 2, 1, 4, 4, 6, 5,
+ 2, 5, 5, 1, 4, 4, 6, 5, 2, 4, 4, 1, 4, 4, 7, 5,
+ )
+
+ private val variableLength = Set(0x10, 0x30, 0x50, 0x70, 0x90, 0xb0, 0xd0, 0xf0)
+
+ private val TooManyCycles: Long = 1000000
+
+ private def formatBool(b: Boolean, c: Char) = if (b) c else '-'
+
+ private def formatState(q: CpuState): String =
+ f"A=${q.a}%02X X=${q.x}%02X Y=${q.y}%02X S=${q.sp}%02X PC=${q.pc}%04X " +
+ formatBool(q.negativeFlag, 'N') + formatBool(q.overflowFlag, 'V') + formatBool(q.breakFlag, 'B') +
+ formatBool(q.decimalModeFlag, 'D') + formatBool(q.irqDisableFlag, 'I') + formatBool(q.zeroFlag, 'Z') + formatBool(q.carryFlag, 'C')
+
+ def apply2(source: String): (Timings, MemoryBank) = {
+ Console.out.flush()
+ Console.err.flush()
+ println(source)
+ val platform = EmuPlatform.get(cpu)
+ val options = new CompilationOptions(platform, Map(
+ CompilationFlag.EmitIllegals -> this.emitIllegals,
+ CompilationFlag.DetailedFlowAnalysis -> quantum,
+ ))
+ ErrorReporting.hasErrors = false
+ ErrorReporting.verbosity = 999
+ val parserF = MfParser("", source, "", options)
+ parserF.toAst match {
+ case Success(unoptimized, _) =>
+ ErrorReporting.assertNoErrors("Parse failed")
+
+
+ // prepare
+ val program = nodeOptimizations.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt))
+ val callGraph = new StandardCallGraph(program)
+ val env = new Environment(None, "")
+ env.collectDeclarations(program, options)
+
+ val hasOptimizations = assemblyOptimizations.nonEmpty
+ var optimizedSize = 0L
+ var unoptimizedSize = 0L
+ // print asm
+ env.allPreallocatables.foreach {
+ case f: NormalFunction =>
+ val result = MlCompiler.compile(CompilationContext(f.environment, f, 0, options))
+ val unoptimized = result.linearize
+ if (hasOptimizations) {
+ val optimized = assemblyOptimizations.foldLeft(unoptimized) { (c, opt) =>
+ opt.optimize(f, c, options)
+ }
+ println("Unoptimized:")
+ unoptimized.filter(_.isPrintable).foreach(println(_))
+ println("Optimized:")
+ optimized.filter(_.isPrintable).foreach(println(_))
+ unoptimizedSize += unoptimized.map(_.sizeInBytes).sum
+ optimizedSize += optimized.map(_.sizeInBytes).sum
+ } else {
+ unoptimized.filter(_.isPrintable).foreach(println(_))
+ unoptimizedSize += unoptimized.map(_.sizeInBytes).sum
+ optimizedSize += unoptimized.map(_.sizeInBytes).sum
+ }
+ case d: InitializedArray =>
+ println(d.name)
+ d.contents.foreach(c => println(" !byte " + c))
+ unoptimizedSize += d.contents.length
+ optimizedSize += d.contents.length
+ }
+
+ ErrorReporting.assertNoErrors("Compile failed")
+
+ if (unoptimizedSize == optimizedSize) {
+ println(f"Size: $unoptimizedSize%5d B")
+ } else {
+ println(f"Unoptimized size: $unoptimizedSize%5d B")
+ println(f"Optimized size: $optimizedSize%5d B")
+ println(f"Gain: ${(100L * (unoptimizedSize - optimizedSize) / unoptimizedSize.toDouble).round}%5d%%")
+ }
+
+ // compile
+ val assembler = new Assembler(env)
+ assembler.assemble(callGraph, assemblyOptimizations, options)
+ assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") }
+
+ ErrorReporting.assertNoErrors("Code generation failed")
+
+ val memoryBank = assembler.mem.banks(0)
+ platform.cpu match {
+ case millfork.Cpu.Cmos =>
+ runViaSymon(memoryBank, platform.org, CpuBehavior.CMOS_6502)
+ case millfork.Cpu.Ricoh =>
+ runViaHalfnes(memoryBank, platform.org)
+ case millfork.Cpu.Mos =>
+ ErrorReporting.fatal("There's no NMOS emulator with decimal mode support")
+ Timings(-1, -1) -> memoryBank
+ case _ =>
+ runViaSymon(memoryBank, platform.org, CpuBehavior.NMOS_6502)
+ }
+ case f: Failure[_, _] =>
+ println(f)
+ println(f.extra.toString)
+ println(f.lastParser.toString)
+ ErrorReporting.error("Syntax error", Some(parserF.lastPosition))
+ ???
+ }
+ }
+
+ def runViaHalfnes(memoryBank: MemoryBank, org: Int): (Timings, MemoryBank) = {
+ val cpu = new CPU(new CPURAM(memoryBank))
+ cpu.reset()
+ cpu.PC = org
+ // stack underflow cannot be easily detected directly,
+ // but since the stack is full of zeroes, an underflowing RTS jumps to $0001
+ while (cpu.PC.&(0xffff) > 1 && cpu.clocks < TooManyCycles) {
+ // println(cpu.status())
+ cpu.runcycle(0, 0)
+ }
+ println("clocks: " + cpu.clocks)
+ System.out.flush()
+ cpu.clocks.toLong should be < TooManyCycles
+ println(cpu.clocks + " NMOS cycles")
+ cpu.flagstobyte().&(8).==(0) should be(true)
+ Timings(cpu.clocks, 0) -> memoryBank
+ }
+
+ def runViaSymon(memoryBank: MemoryBank, org: Int, behavior: CpuBehavior): (Timings, MemoryBank) = {
+ val cpu = new Cpu
+ cpu.setBehavior(behavior)
+ val ram = new SymonTestRam(memoryBank)
+ val bus = new Bus(1 << 16)
+ bus.addCpu(cpu)
+ bus.addDevice(ram)
+ cpu.setBus(bus)
+ cpu.setProgramCounter(org)
+ cpu.setStackPointer(0xff)
+ val legal = Assembler.getStandardLegalOpcodes
+
+ var countNmos = 0L
+ var countCmos = 0L
+ while (cpu.getStackPointer > 1 && countCmos < TooManyCycles) {
+ // println(cpu.disassembleNextOp())
+ val pcBefore = cpu.getProgramCounter
+ cpu.step()
+ val pcAfter = cpu.getProgramCounter
+ // println(formatState(cpu.getCpuState))
+ val instruction = cpu.getInstruction
+ if (behavior == CpuBehavior.NMOS_6502 || behavior == CpuBehavior.NMOS_WITH_ROR_BUG) {
+ if (!legal(instruction)) {
+ throw new RuntimeException("unexpected illegal: " + instruction.toHexString)
+ }
+ }
+ countNmos += timingNmos(instruction)
+ countCmos += timingCmos(instruction)
+ if (variableLength(instruction)) {
+ val jump = pcAfter - pcBefore
+ if (jump <= 0 || jump > 3) {
+ countNmos += 1
+ countCmos += 1
+ }
+ }
+ }
+ countCmos should be < TooManyCycles
+ println(countNmos + " NMOS cycles")
+ println(countCmos + " CMOS cycles")
+ cpu.getDecimalModeFlag should be(false)
+ Timings(countNmos, countCmos) -> memoryBank
+ }
+
+}
diff --git a/src/test/scala/millfork/test/emu/EmuSuperQuantumOptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuSuperQuantumOptimizedRun.scala
new file mode 100644
index 00000000..fdf03635
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuSuperQuantumOptimizedRun.scala
@@ -0,0 +1,17 @@
+package millfork.test.emu
+
+import millfork.assembly.opt.SuperOptimizer
+import millfork.{Cpu, OptimizationPresets}
+
+/**
+ * @author Karol Stasiak
+ */
+// TODO : it doesn't work
+object EmuSuperQuantumOptimizedRun extends EmuRun(
+ Cpu.StrictMos,
+ OptimizationPresets.NodeOpt,
+ List(SuperOptimizer),
+ true)
+
+
+
diff --git a/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala
new file mode 100644
index 00000000..30be3c0d
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala
@@ -0,0 +1,17 @@
+package millfork.test.emu
+
+import millfork.assembly.opt.SuperOptimizer
+import millfork.{Cpu, OptimizationPresets}
+
+/**
+ * @author Karol Stasiak
+ */
+// TODO : it doesn't work
+object EmuSuperOptimizedRun extends EmuRun(
+ Cpu.StrictMos,
+ OptimizationPresets.NodeOpt,
+ List(SuperOptimizer),
+ false)
+
+
+
diff --git a/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala
new file mode 100644
index 00000000..008b7dcc
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala
@@ -0,0 +1,35 @@
+package millfork.test.emu
+
+import millfork.output.MemoryBank
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuUltraBenchmarkRun {
+ def apply(source:String)(verifier: MemoryBank=>Unit) = {
+ val (Timings(t0, _), m0) = EmuUnoptimizedRun.apply2(source)
+ val (Timings(t1, _), m1) = EmuOptimizedRun.apply2(source)
+ val (Timings(t2, _), m2) = EmuSuperOptimizedRun.apply2(source)
+ val (Timings(t3, _), m3) = EmuQuantumOptimizedRun.apply2(source)
+ val (Timings(t4, _), m4) = EmuSuperQuantumOptimizedRun.apply2(source)
+ println(f"Before optimization: $t0%7d")
+ println(f"After optimization: $t1%7d")
+ println(f"After superopt.: $t2%7d")
+ println(f"After quantum: $t3%7d")
+ println(f"After superquantum: $t4%7d")
+ println(f"Gain: ${(100L*(t0-t1)/t0.toDouble).round}%7d%%")
+ println(f"Superopt. gain: ${(100L*(t0-t2)/t0.toDouble).round}%7d%%")
+ println(f"Quantum gain: ${(100L*(t0-t3)/t0.toDouble).round}%7d%%")
+ println(f"Super quantum gain: ${(100L*(t0-t4)/t0.toDouble).round}%7d%%")
+ println(f"Running unoptimized")
+ verifier(m0)
+ println(f"Running optimized")
+ verifier(m1)
+ println(f"Running superoptimized")
+ verifier(m2)
+ println(f"Running quantum optimized")
+ verifier(m3)
+ println(f"Running superquantum optimized")
+ verifier(m4)
+ }
+}
diff --git a/src/test/scala/millfork/test/emu/EmuUndocumentedRun.scala b/src/test/scala/millfork/test/emu/EmuUndocumentedRun.scala
new file mode 100644
index 00000000..16f9d1ae
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuUndocumentedRun.scala
@@ -0,0 +1,19 @@
+package millfork.test.emu
+
+import millfork.assembly.opt.UndocumentedOptimizations
+import millfork.{Cpu, OptimizationPresets}
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuUndocumentedRun extends EmuRun(
+ Cpu.Ricoh, // not Cpu.Mos, because I haven't found an emulator that supports both illegals and decimal mode yet
+ OptimizationPresets.NodeOpt,
+ OptimizationPresets.AssOpt ++ UndocumentedOptimizations.All ++ OptimizationPresets.Good ++ UndocumentedOptimizations.All ++ OptimizationPresets.Good,
+ false) {
+
+ override def emitIllegals = true
+}
+
+
+
diff --git a/src/test/scala/millfork/test/emu/EmuUnoptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuUnoptimizedRun.scala
new file mode 100644
index 00000000..71af4c70
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/EmuUnoptimizedRun.scala
@@ -0,0 +1,9 @@
+package millfork.test.emu
+
+import millfork.Cpu
+
+
+/**
+ * @author Karol Stasiak
+ */
+object EmuUnoptimizedRun extends EmuRun(Cpu.StrictMos, Nil, Nil, false)
\ No newline at end of file
diff --git a/src/test/scala/millfork/test/emu/SymonTestRam.scala b/src/test/scala/millfork/test/emu/SymonTestRam.scala
new file mode 100644
index 00000000..0ae21e28
--- /dev/null
+++ b/src/test/scala/millfork/test/emu/SymonTestRam.scala
@@ -0,0 +1,39 @@
+package millfork.test.emu
+
+import com.loomcom.symon.devices.Device
+import millfork.output.MemoryBank
+
+/**
+ * @author Karol Stasiak
+ */
+class SymonTestRam(mem: MemoryBank) extends Device(0x0000, 0xffff, "RAM") {
+
+ mem.readable(1) = true
+ mem.readable(2) = true
+
+ (0x100 to 0x1ff).foreach { stack =>
+ mem.writeable(stack) = true
+ mem.readable(stack) = true
+ }
+
+ (0xc000 to 0xcfff).foreach { himem =>
+ mem.writeable(himem) = true
+ mem.readable(himem) = true
+ }
+
+ override def write(i: Int, i1: Int): Unit = {
+ if (!mem.writeable(i)) {
+ throw new RuntimeException(s"Can't write to $$${i.toHexString}")
+ }
+ mem.output(i) = i1.toByte
+ }
+
+ override def read(i: Int, b: Boolean): Int = {
+ if (!mem.readable(i)) {
+ throw new RuntimeException(s"Can't read from $$${i.toHexString}")
+ }
+ mem.output(i)
+ }
+
+ override def toString: String = "TestRam"
+}