Compare commits

...

32 Commits
v2.1 ... v2.3

Author SHA1 Message Date
4bfdbad2e4 added mandel gfx to examples 2020-07-03 23:56:36 +02:00
06137ecdc4 v2.3 2020-07-03 23:51:27 +02:00
d89f5b0df8 todo about fixing argclobbering 2020-07-03 23:49:17 +02:00
b6e2b36692 refactor 2020-07-03 23:37:38 +02:00
a6d789cfbc fixed function argument type cast bug 2020-07-03 17:24:43 +02:00
c07907e7bd fixed missing shifts codegen 2020-07-02 21:28:48 +02:00
7d8496c874 fixed missing shifts codegen 2020-07-02 19:18:47 +02:00
164ac56db1 compiler error todos 2020-07-01 22:31:38 +02:00
fdddb8ca64 slight optimization 2020-07-01 22:23:46 +02:00
a9d4b8b0fa fixed ast modifications on node arrays, in particular function call parameter lists 2020-07-01 22:03:54 +02:00
ec7b9f54c2 subroutine inlining is an optimizer step 2020-07-01 12:41:10 +02:00
307558a7e7 removed some double code related to call tree 2020-06-30 20:42:55 +02:00
febf423eab tehtriz compilation issues 2020-06-30 20:42:13 +02:00
a999c23014 simple subroutine inlining added 2020-06-27 17:03:03 +02:00
69f1ade595 gfx mandelbrot example added 2020-06-18 01:35:24 +02:00
b166576e54 comments 2020-06-17 23:27:54 +02:00
ee2ba5f398 some more optimizations for swap() function call asm code generation 2020-06-17 22:40:57 +02:00
cb9825484d some more optimized in-array assignments codegeneration 2020-06-17 21:41:38 +02:00
76cda82e23 v2.2 2020-06-16 01:43:44 +02:00
37b61d9e6b v2.2 2020-06-16 01:39:11 +02:00
52f0222a6d Got rid of old Ast transformer Api, some compiler error fixes 2020-06-16 01:25:49 +02:00
75ccac2f2c refactoring last of old Ast modification Api 2020-06-16 00:36:02 +02:00
5c771a91f7 refactoring last of old Ast modification Api 2020-06-14 16:56:48 +02:00
a242ad10e6 fix double printing of sub param vardecl 2020-06-14 13:46:46 +02:00
b5086b6a8f refactoring last of old Ast modification Api 2020-06-14 03:17:42 +02:00
3e47dad12a clearer no modifications 2020-06-14 02:54:29 +02:00
235610f40c refactored StatementOptimizer 2020-06-14 02:41:23 +02:00
6b59559c65 memory address assignment codegen 2020-06-14 02:12:40 +02:00
23e954f716 refactoring StatementOptimizer 2020-06-14 02:00:32 +02:00
983c899cad refactor AstIdentifierChecker 2020-06-13 00:14:19 +02:00
c2f9385965 refactor AstIdentifierChecker 2020-06-12 21:34:27 +02:00
ceb2c9e4f8 added string value assignment, leftstr, rightstr, substr functions 2020-06-06 00:05:39 +02:00
55 changed files with 1813 additions and 1774 deletions

View File

@ -781,3 +781,18 @@ set_array_float .proc
; -- copies the 5 bytes of the mflt value pointed to by SCRATCH_ZPWORD1, ; -- copies the 5 bytes of the mflt value pointed to by SCRATCH_ZPWORD1,
; into the 5 bytes pointed to by A/Y. Clobbers A,Y. ; into the 5 bytes pointed to by A/Y. Clobbers A,Y.
.pend .pend
swap_floats .proc
; -- swap floats pointed to by SCRATCH_ZPWORD1, SCRATCH_ZPWORD2
ldy #4
- lda (c64.SCRATCH_ZPWORD1),y
pha
lda (c64.SCRATCH_ZPWORD2),y
sta (c64.SCRATCH_ZPWORD1),y
pla
sta (c64.SCRATCH_ZPWORD2),y
dey
bpl -
rts
.pend

View File

@ -2078,3 +2078,127 @@ ror2_array_uw .proc
sta (c64.SCRATCH_ZPWORD1),y sta (c64.SCRATCH_ZPWORD1),y
+ rts + rts
.pend .pend
strcpy .proc
; copy a string (0-terminated) from A/Y to (ZPWORD1)
; it is assumed the target string is large enough.
sta c64.SCRATCH_ZPWORD2
sty c64.SCRATCH_ZPWORD2+1
ldy #$ff
- iny
lda (c64.SCRATCH_ZPWORD2),y
sta (c64.SCRATCH_ZPWORD1),y
bne -
rts
.pend
func_leftstr .proc
; leftstr(source, target, length) with params on stack
inx
lda c64.ESTACK_LO,x
tay ; length
inx
lda c64.ESTACK_LO,x
sta c64.SCRATCH_ZPWORD2
lda c64.ESTACK_HI,x
sta c64.SCRATCH_ZPWORD2+1
inx
lda c64.ESTACK_LO,x
sta c64.SCRATCH_ZPWORD1
lda c64.ESTACK_HI,x
sta c64.SCRATCH_ZPWORD1+1
lda #0
sta (c64.SCRATCH_ZPWORD2),y
- dey
cpy #$ff
bne +
rts
+ lda (c64.SCRATCH_ZPWORD1),y
sta (c64.SCRATCH_ZPWORD2),y
jmp -
.pend
func_rightstr .proc
; rightstr(source, target, length) with params on stack
; make place for the 4 parameters for substr()
dex
dex
dex
dex
; X-> .
; x+1 -> length of segment
; x+2 -> start index
; X+3 -> target LO+HI
; X+4 -> source LO+HI
; original parameters:
; x+5 -> original length LO
; x+6 -> original targetLO + HI
; x+7 -> original sourceLO + HI
; replicate paramters:
lda c64.ESTACK_LO+5,x
sta c64.ESTACK_LO+1,x
lda c64.ESTACK_LO+6,x
sta c64.ESTACK_LO+3,x
lda c64.ESTACK_HI+6,x
sta c64.ESTACK_HI+3,x
lda c64.ESTACK_LO+7,x
sta c64.ESTACK_LO+4,x
sta c64.SCRATCH_ZPWORD1
lda c64.ESTACK_HI+7,x
sta c64.ESTACK_HI+4,x
sta c64.SCRATCH_ZPWORD1+1
; determine string length
ldy #0
- lda (c64.SCRATCH_ZPWORD1),y
beq +
iny
bne -
+ tya
sec
sbc c64.ESTACK_LO+1,x ; start index = strlen - segment length
sta c64.ESTACK_LO+2,x
jsr func_substr
; unwind original params
inx
inx
inx
rts
.pend
func_substr .proc
; substr(source, target, start, length) with params on stack
inx
ldy c64.ESTACK_LO,x ; length
inx
lda c64.ESTACK_LO,x ; start
sta c64.SCRATCH_ZPB1
inx
lda c64.ESTACK_LO,x
sta c64.SCRATCH_ZPWORD2
lda c64.ESTACK_HI,x
sta c64.SCRATCH_ZPWORD2+1
inx
lda c64.ESTACK_LO,x
sta c64.SCRATCH_ZPWORD1
lda c64.ESTACK_HI,x
sta c64.SCRATCH_ZPWORD1+1
; adjust src location
clc
lda c64.SCRATCH_ZPWORD1
adc c64.SCRATCH_ZPB1
sta c64.SCRATCH_ZPWORD1
bcc +
inc c64.SCRATCH_ZPWORD1+1
+ lda #0
sta (c64.SCRATCH_ZPWORD2),y
jmp _startloop
- lda (c64.SCRATCH_ZPWORD1),y
sta (c64.SCRATCH_ZPWORD2),y
_startloop dey
cpy #$ff
bne -
rts
.pend

View File

@ -1 +1 @@
2.1 2.3

View File

@ -102,6 +102,12 @@ class AstToSourceCode(val output: (text: String) -> Unit, val program: Program):
} }
override fun visit(decl: VarDecl) { override fun visit(decl: VarDecl) {
// if the vardecl is a parameter of a subroutine, don't output it again
val paramNames = (decl.definingScope() as? Subroutine)?.parameters?.map { it.name }
if(paramNames!=null && decl.name in paramNames)
return
when(decl.type) { when(decl.type) {
VarDeclType.VAR -> {} VarDeclType.VAR -> {}
VarDeclType.CONST -> output("const ") VarDeclType.CONST -> output("const ")

View File

@ -4,7 +4,6 @@ import prog8.ast.base.*
import prog8.ast.expressions.Expression import prog8.ast.expressions.Expression
import prog8.ast.expressions.IdentifierReference import prog8.ast.expressions.IdentifierReference
import prog8.ast.processing.AstWalker import prog8.ast.processing.AstWalker
import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.processing.IAstVisitor import prog8.ast.processing.IAstVisitor
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.functions.BuiltinFunctions import prog8.functions.BuiltinFunctions
@ -156,6 +155,7 @@ interface INameScope {
} }
fun containsCodeOrVars() = statements.any { it !is Directive || it.directive == "%asminclude" || it.directive == "%asm"} fun containsCodeOrVars() = statements.any { it !is Directive || it.directive == "%asminclude" || it.directive == "%asm"}
fun containsNoVars() = statements.all { it !is VarDecl }
fun containsNoCodeNorVars() = !containsCodeOrVars() fun containsNoCodeNorVars() = !containsCodeOrVars()
fun remove(stmt: Statement) { fun remove(stmt: Statement) {
@ -230,7 +230,7 @@ class Program(val name: String, val modules: MutableList<Module>): Node {
override fun replaceChildNode(node: Node, replacement: Node) { override fun replaceChildNode(node: Node, replacement: Node) {
require(node is Module && replacement is Module) require(node is Module && replacement is Module)
val idx = modules.indexOf(node) val idx = modules.withIndex().find { it.value===node }!!.index
modules[idx] = replacement modules[idx] = replacement
replacement.parent = this replacement.parent = this
} }
@ -257,14 +257,13 @@ class Module(override val name: String,
override fun definingScope(): INameScope = program.namespace override fun definingScope(): INameScope = program.namespace
override fun replaceChildNode(node: Node, replacement: Node) { override fun replaceChildNode(node: Node, replacement: Node) {
require(node is Statement && replacement is Statement) require(node is Statement && replacement is Statement)
val idx = statements.indexOf(node) val idx = statements.withIndex().find { it.value===node }!!.index
statements[idx] = replacement statements[idx] = replacement
replacement.parent = this replacement.parent = this
} }
override fun toString() = "Module(name=$name, pos=$position, lib=$isLibraryModule)" override fun toString() = "Module(name=$name, pos=$position, lib=$isLibraryModule)"
fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
fun accept(visitor: IAstVisitor) = visitor.visit(this) fun accept(visitor: IAstVisitor) = visitor.visit(this)
fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }

View File

@ -2,7 +2,7 @@ package prog8.ast.base
import prog8.ast.expressions.IdentifierReference import prog8.ast.expressions.IdentifierReference
class FatalAstException (override var message: String) : Exception(message) open class FatalAstException (override var message: String) : Exception(message)
open class AstException (override var message: String) : Exception(message) open class AstException (override var message: String) : Exception(message)

View File

@ -6,7 +6,6 @@ import prog8.ast.processing.*
import prog8.compiler.CompilationOptions import prog8.compiler.CompilationOptions
import prog8.compiler.BeforeAsmGenerationAstChanger import prog8.compiler.BeforeAsmGenerationAstChanger
import prog8.optimizer.AssignmentTransformer import prog8.optimizer.AssignmentTransformer
import prog8.optimizer.FlattenAnonymousScopesAndNopRemover
internal fun Program.checkValid(compilerOptions: CompilationOptions, errors: ErrorReporter) { internal fun Program.checkValid(compilerOptions: CompilationOptions, errors: ErrorReporter) {
@ -26,12 +25,23 @@ internal fun Program.reorderStatements() {
reorder.applyModifications() reorder.applyModifications()
} }
internal fun Program.inlineSubroutines(): Int {
val reorder = SubroutineInliner(this)
reorder.visit(this)
return reorder.applyModifications()
}
internal fun Program.addTypecasts(errors: ErrorReporter) { internal fun Program.addTypecasts(errors: ErrorReporter) {
val caster = TypecastsAdder(this, errors) val caster = TypecastsAdder(this, errors)
caster.visit(this) caster.visit(this)
caster.applyModifications() caster.applyModifications()
} }
internal fun Program.verifyFunctionArgTypes() {
val fixer = VerifyFunctionArgTypes(this)
fixer.visit(this)
}
internal fun Program.transformAssignments(errors: ErrorReporter) { internal fun Program.transformAssignments(errors: ErrorReporter) {
val transform = AssignmentTransformer(this, errors) val transform = AssignmentTransformer(this, errors)
transform.visit(this) transform.visit(this)
@ -56,21 +66,23 @@ internal fun Program.checkRecursion(errors: ErrorReporter) {
} }
internal fun Program.checkIdentifiers(errors: ErrorReporter) { internal fun Program.checkIdentifiers(errors: ErrorReporter) {
val checker = AstIdentifiersChecker(this, errors)
checker.visit(this) val checker2 = AstIdentifiersChecker(this, errors)
checker2.visit(this)
if(errors.isEmpty()) {
val transforms = AstVariousTransforms(this)
transforms.visit(this)
transforms.applyModifications()
}
if (modules.map { it.name }.toSet().size != modules.size) { if (modules.map { it.name }.toSet().size != modules.size) {
throw FatalAstException("modules should all be unique") throw FatalAstException("modules should all be unique")
} }
} }
internal fun Program.makeForeverLoops() { internal fun Program.variousCleanups() {
val checker = ForeverLoopsMaker() val process = VariousCleanups()
checker.visit(this) process.visit(this)
checker.applyModifications() process.applyModifications()
}
internal fun Program.removeNopsFlattenAnonScopes() {
val flattener = FlattenAnonymousScopesAndNopRemover()
flattener.visit(this)
} }

View File

@ -4,11 +4,11 @@ import prog8.ast.*
import prog8.ast.antlr.escape import prog8.ast.antlr.escape
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.processing.AstWalker import prog8.ast.processing.AstWalker
import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.processing.IAstVisitor import prog8.ast.processing.IAstVisitor
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.compiler.target.CompilationTarget import prog8.compiler.target.CompilationTarget
import prog8.functions.BuiltinFunctions import prog8.functions.BuiltinFunctions
import prog8.functions.CannotEvaluateException
import prog8.functions.NotConstArgumentException import prog8.functions.NotConstArgumentException
import prog8.functions.builtinFunctionReturnType import prog8.functions.builtinFunctionReturnType
import java.util.* import java.util.*
@ -20,10 +20,9 @@ val associativeOperators = setOf("+", "*", "&", "|", "^", "or", "and", "xor", "=
sealed class Expression: Node { sealed class Expression: Node {
abstract fun constValue(program: Program): NumericLiteralValue? abstract fun constValue(program: Program): NumericLiteralValue?
abstract fun accept(visitor: IAstModifyingVisitor): Expression
abstract fun accept(visitor: IAstVisitor) abstract fun accept(visitor: IAstVisitor)
abstract fun accept(visitor: AstWalker, parent: Node) abstract fun accept(visitor: AstWalker, parent: Node)
abstract fun referencesIdentifiers(vararg name: String): Boolean // todo: remove this and add identifier usage tracking into CallGraph instead abstract fun referencesIdentifiers(vararg name: String): Boolean
abstract fun inferType(program: Program): InferredTypes.InferredType abstract fun inferType(program: Program): InferredTypes.InferredType
infix fun isSameAs(other: Expression): Boolean { infix fun isSameAs(other: Expression): Boolean {
@ -65,7 +64,6 @@ class PrefixExpression(val operator: String, var expression: Expression, overrid
} }
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -123,7 +121,6 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex
// binary expression should actually have been optimized away into a single value, before const value was requested... // binary expression should actually have been optimized away into a single value, before const value was requested...
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -237,7 +234,6 @@ class ArrayIndexedExpression(var identifier: IdentifierReference,
} }
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -274,7 +270,6 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -309,7 +304,6 @@ data class AddressOf(var identifier: IdentifierReference, override val position:
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun referencesIdentifiers(vararg name: String) = false override fun referencesIdentifiers(vararg name: String) = false
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UWORD) override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UWORD)
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
} }
@ -328,7 +322,6 @@ class DirectMemoryRead(var addressExpression: Expression, override val position:
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -388,7 +381,6 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed
override fun referencesIdentifiers(vararg name: String) = false override fun referencesIdentifiers(vararg name: String) = false
override fun constValue(program: Program) = this override fun constValue(program: Program) = this
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -479,7 +471,6 @@ class StructLiteralValue(var values: List<Expression>,
} }
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -510,7 +501,6 @@ class StringLiteralValue(val value: String,
override fun referencesIdentifiers(vararg name: String) = false override fun referencesIdentifiers(vararg name: String) = false
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -539,19 +529,18 @@ class ArrayLiteralValue(val type: InferredTypes.InferredType, // inferred be
override fun replaceChildNode(node: Node, replacement: Node) { override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression) require(replacement is Expression)
val idx = value.indexOf(node) val idx = value.withIndex().find { it.value===node }!!.index
value[idx] = replacement value[idx] = replacement
replacement.parent = this replacement.parent = this
} }
override fun referencesIdentifiers(vararg name: String) = value.any { it.referencesIdentifiers(*name) } override fun referencesIdentifiers(vararg name: String) = value.any { it.referencesIdentifiers(*name) }
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
override fun toString(): String = "$value" override fun toString(): String = "$value"
override fun inferType(program: Program): InferredTypes.InferredType = if(type.isUnknown) type else guessDatatype(program) override fun inferType(program: Program): InferredTypes.InferredType = if(type.isKnown) type else guessDatatype(program)
operator fun compareTo(other: ArrayLiteralValue): Int = throw ExpressionError("cannot order compare arrays", position) operator fun compareTo(other: ArrayLiteralValue): Int = throw ExpressionError("cannot order compare arrays", position)
override fun hashCode(): Int = Objects.hash(value, type) override fun hashCode(): Int = Objects.hash(value, type)
@ -640,7 +629,6 @@ class RangeExpr(var from: Expression,
} }
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -720,7 +708,6 @@ class RegisterExpr(val register: Register, override val position: Position) : Ex
} }
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -744,6 +731,8 @@ data class IdentifierReference(val nameInSource: List<String>, override val posi
fun targetVarDecl(namespace: INameScope): VarDecl? = targetStatement(namespace) as? VarDecl fun targetVarDecl(namespace: INameScope): VarDecl? = targetStatement(namespace) as? VarDecl
fun targetSubroutine(namespace: INameScope): Subroutine? = targetStatement(namespace) as? Subroutine fun targetSubroutine(namespace: INameScope): Subroutine? = targetStatement(namespace) as? Subroutine
override fun equals(other: Any?) = other is IdentifierReference && other.nameInSource==nameInSource
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
this.parent = parent this.parent = parent
} }
@ -768,7 +757,6 @@ data class IdentifierReference(val nameInSource: List<String>, override val posi
return "IdentifierRef($nameInSource)" return "IdentifierRef($nameInSource)"
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -812,7 +800,7 @@ class FunctionCall(override var target: IdentifierReference,
if(node===target) if(node===target)
target=replacement as IdentifierReference target=replacement as IdentifierReference
else { else {
val idx = args.indexOf(node) val idx = args.withIndex().find { it.value===node }!!.index
args[idx] = replacement as Expression args[idx] = replacement as Expression
} }
replacement.parent = this replacement.parent = this
@ -848,13 +836,16 @@ class FunctionCall(override var target: IdentifierReference,
// const-evaluating the builtin function call failed. // const-evaluating the builtin function call failed.
return null return null
} }
catch(x: CannotEvaluateException) {
// const-evaluating the builtin function call failed.
return null
}
} }
override fun toString(): String { override fun toString(): String {
return "FunctionCall(target=$target, pos=$position)" return "FunctionCall(target=$target, pos=$position)"
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)

View File

@ -326,7 +326,7 @@ internal class AstChecker(private val program: Program,
} }
override fun visit(repeatLoop: RepeatLoop) { override fun visit(repeatLoop: RepeatLoop) {
if(repeatLoop.untilCondition.referencesIdentifiers("A", "X", "Y")) // TODO use callgraph? if(repeatLoop.untilCondition.referencesIdentifiers("A", "X", "Y"))
errors.warn("using a register in the loop condition is risky (it could get clobbered)", repeatLoop.untilCondition.position) errors.warn("using a register in the loop condition is risky (it could get clobbered)", repeatLoop.untilCondition.position)
if(repeatLoop.untilCondition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes) if(repeatLoop.untilCondition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes)
errors.err("condition value should be an integer type", repeatLoop.untilCondition.position) errors.err("condition value should be an integer type", repeatLoop.untilCondition.position)
@ -334,7 +334,7 @@ internal class AstChecker(private val program: Program,
} }
override fun visit(whileLoop: WhileLoop) { override fun visit(whileLoop: WhileLoop) {
if(whileLoop.condition.referencesIdentifiers("A", "X", "Y")) // TODO use callgraph? if(whileLoop.condition.referencesIdentifiers("A", "X", "Y"))
errors.warn("using a register in the loop condition is risky (it could get clobbered)", whileLoop.condition.position) errors.warn("using a register in the loop condition is risky (it could get clobbered)", whileLoop.condition.position)
if(whileLoop.condition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes) if(whileLoop.condition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes)
errors.err("condition value should be an integer type", whileLoop.condition.position) errors.err("condition value should be an integer type", whileLoop.condition.position)
@ -463,7 +463,6 @@ internal class AstChecker(private val program: Program,
} }
// the initializer value can't refer to the variable itself (recursive definition) // the initializer value can't refer to the variable itself (recursive definition)
// TODO use callgraph for check?
if(decl.value?.referencesIdentifiers(decl.name) == true || decl.arraysize?.index?.referencesIdentifiers(decl.name) == true) { if(decl.value?.referencesIdentifiers(decl.name) == true || decl.arraysize?.index?.referencesIdentifiers(decl.name) == true) {
err("recursive var declaration") err("recursive var declaration")
} }
@ -733,7 +732,7 @@ internal class AstChecker(private val program: Program,
} }
"**" -> { "**" -> {
if(leftDt in IntegerDatatypes) if(leftDt in IntegerDatatypes)
errors.err("power operator requires floating point", expr.position) errors.err("power operator requires floating point operands", expr.position)
} }
"and", "or", "xor" -> { "and", "or", "xor" -> {
// only integer numeric operands accepted, and if literal constants, only boolean values accepted (0 or 1) // only integer numeric operands accepted, and if literal constants, only boolean values accepted (0 or 1)

View File

@ -1,8 +1,6 @@
package prog8.ast.processing package prog8.ast.processing
import prog8.ast.INameScope
import prog8.ast.Module import prog8.ast.Module
import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
@ -10,98 +8,69 @@ import prog8.ast.statements.*
import prog8.compiler.target.CompilationTarget import prog8.compiler.target.CompilationTarget
import prog8.functions.BuiltinFunctions import prog8.functions.BuiltinFunctions
// TODO implement using AstWalker instead of IAstModifyingVisitor internal class AstIdentifiersChecker(private val program: Program, private val errors: ErrorReporter) : IAstVisitor {
internal class AstIdentifiersChecker(private val program: Program,
private val errors: ErrorReporter) : IAstModifyingVisitor {
private var blocks = mutableMapOf<String, Block>() private var blocks = mutableMapOf<String, Block>()
private val vardeclsToAdd = mutableMapOf<INameScope, MutableList<VarDecl>>()
private fun nameError(name: String, position: Position, existing: Statement) { private fun nameError(name: String, position: Position, existing: Statement) {
errors.err("name conflict '$name', also defined in ${existing.position.file} line ${existing.position.line}", position) errors.err("name conflict '$name', also defined in ${existing.position.file} line ${existing.position.line}", position)
} }
override fun visit(module: Module) { override fun visit(module: Module) {
vardeclsToAdd.clear()
blocks.clear() // blocks may be redefined within a different module blocks.clear() // blocks may be redefined within a different module
super.visit(module) super.visit(module)
// add any new vardecls to the various scopes
for((where, decls) in vardeclsToAdd) {
where.statements.addAll(0, decls)
decls.forEach { it.linkParents(where as Node) }
}
} }
override fun visit(block: Block): Statement { override fun visit(block: Block) {
val existing = blocks[block.name] val existing = blocks[block.name]
if(existing!=null) if(existing!=null)
nameError(block.name, block.position, existing) nameError(block.name, block.position, existing)
else else
blocks[block.name] = block blocks[block.name] = block
return super.visit(block) super.visit(block)
} }
override fun visit(functionCall: FunctionCall): Expression { override fun visit(decl: VarDecl) {
if(functionCall.target.nameInSource.size==1 && functionCall.target.nameInSource[0]=="lsb") {
// lsb(...) is just an alias for type cast to ubyte, so replace with "... as ubyte"
val typecast = TypecastExpression(functionCall.args.single(), DataType.UBYTE, false, functionCall.position)
typecast.linkParents(functionCall.parent)
return super.visit(typecast)
}
return super.visit(functionCall)
}
override fun visit(decl: VarDecl): Statement {
// first, check if there are datatype errors on the vardecl
decl.datatypeErrors.forEach { errors.err(it.message, it.position) } decl.datatypeErrors.forEach { errors.err(it.message, it.position) }
// now check the identifier
if(decl.name in BuiltinFunctions) if(decl.name in BuiltinFunctions)
// the builtin functions can't be redefined
errors.err("builtin function cannot be redefined", decl.position) errors.err("builtin function cannot be redefined", decl.position)
if(decl.name in CompilationTarget.machine.opcodeNames) if(decl.name in CompilationTarget.machine.opcodeNames)
errors.err("can't use a cpu opcode name as a symbol: '${decl.name}'", decl.position) errors.err("can't use a cpu opcode name as a symbol: '${decl.name}'", decl.position)
// is it a struct variable? then define all its struct members as mangled names,
// and include the original decl as well.
if(decl.datatype==DataType.STRUCT) { if(decl.datatype==DataType.STRUCT) {
if(decl.structHasBeenFlattened) if (decl.structHasBeenFlattened)
return super.visit(decl) // don't do this multiple times return super.visit(decl) // don't do this multiple times
if(decl.struct==null) { if (decl.struct == null) {
errors.err("undefined struct type", decl.position) errors.err("undefined struct type", decl.position)
return super.visit(decl) return super.visit(decl)
} }
if(decl.struct!!.statements.any { (it as VarDecl).datatype !in NumericDatatypes}) if (decl.struct!!.statements.any { (it as VarDecl).datatype !in NumericDatatypes })
return super.visit(decl) // a non-numeric member, not supported. proper error is given by AstChecker later return super.visit(decl) // a non-numeric member, not supported. proper error is given by AstChecker later
if(decl.value is NumericLiteralValue) { if (decl.value is NumericLiteralValue) {
errors.err("you cannot initialize a struct using a single value", decl.position) errors.err("you cannot initialize a struct using a single value", decl.position)
return super.visit(decl) return super.visit(decl)
} }
if(decl.value != null && decl.value !is StructLiteralValue) { if (decl.value != null && decl.value !is StructLiteralValue) {
errors.err("initializing requires struct literal value", decl.value?.position ?: decl.position) errors.err("initializing requires struct literal value", decl.value?.position ?: decl.position)
return super.visit(decl) return super.visit(decl)
} }
val decls = decl.flattenStructMembers()
decls.add(decl)
val result = AnonymousScope(decls, decl.position)
result.linkParents(decl.parent)
return result
} }
val existing = program.namespace.lookup(listOf(decl.name), decl) val existing = program.namespace.lookup(listOf(decl.name), decl)
if (existing != null && existing !== decl) if (existing != null && existing !== decl)
nameError(decl.name, decl.position, existing) nameError(decl.name, decl.position, existing)
return super.visit(decl) super.visit(decl)
} }
override fun visit(subroutine: Subroutine): Statement { override fun visit(subroutine: Subroutine) {
if(subroutine.name in CompilationTarget.machine.opcodeNames) { if(subroutine.name in CompilationTarget.machine.opcodeNames) {
errors.err("can't use a cpu opcode name as a symbol: '${subroutine.name}'", subroutine.position) errors.err("can't use a cpu opcode name as a symbol: '${subroutine.name}'", subroutine.position)
} else if(subroutine.name in BuiltinFunctions) { } else if(subroutine.name in BuiltinFunctions) {
@ -138,30 +107,15 @@ internal class AstIdentifiersChecker(private val program: Program,
nameError(name, sub.position, subroutine) nameError(name, sub.position, subroutine)
} }
// inject subroutine params as local variables (if they're not there yet) (for non-kernel subroutines and non-asm parameters)
// NOTE:
// - numeric types BYTE and WORD and FLOAT are passed by value;
// - strings, arrays, matrices are passed by reference (their 16-bit address is passed as an uword parameter)
if(subroutine.asmAddress==null) {
if(subroutine.asmParameterRegisters.isEmpty()) {
subroutine.parameters
.filter { it.name !in namesInSub }
.forEach {
val vardecl = ParameterVarDecl(it.name, it.type, subroutine.position)
vardecl.linkParents(subroutine)
subroutine.statements.add(0, vardecl)
}
}
}
if(subroutine.isAsmSubroutine && subroutine.statements.any{it !is InlineAssembly}) { if(subroutine.isAsmSubroutine && subroutine.statements.any{it !is InlineAssembly}) {
errors.err("asmsub can only contain inline assembly (%asm)", subroutine.position) errors.err("asmsub can only contain inline assembly (%asm)", subroutine.position)
} }
} }
return super.visit(subroutine)
super.visit(subroutine)
} }
override fun visit(label: Label): Statement { override fun visit(label: Label) {
if(label.name in CompilationTarget.machine.opcodeNames) if(label.name in CompilationTarget.machine.opcodeNames)
errors.err("can't use a cpu opcode name as a symbol: '${label.name}'", label.position) errors.err("can't use a cpu opcode name as a symbol: '${label.name}'", label.position)
@ -179,163 +133,39 @@ internal class AstIdentifiersChecker(private val program: Program,
} }
} }
} }
return super.visit(label)
super.visit(label)
} }
override fun visit(forLoop: ForLoop): Statement { override fun visit(forLoop: ForLoop) {
// If the for loop has a decltype, it means to declare the loopvar inside the loop body if (forLoop.loopRegister != null) {
// rather than reusing an already declared loopvar from an outer scope. if (forLoop.loopRegister == Register.X)
// For loops that loop over an interable variable (instead of a range of numbers) get an
// additional interation count variable in their scope.
if(forLoop.loopRegister!=null) {
if(forLoop.loopRegister == Register.X)
errors.warn("writing to the X register is dangerous, because it's used as an internal pointer", forLoop.position) errors.warn("writing to the X register is dangerous, because it's used as an internal pointer", forLoop.position)
} else {
val loopVar = forLoop.loopVar
if (loopVar != null) {
val validName = forLoop.body.name.replace("<", "").replace(">", "").replace("-", "")
val loopvarName = "prog8_loopvar_$validName"
if (forLoop.iterable !is RangeExpr) {
val existing = if (forLoop.body.containsNoCodeNorVars()) null else forLoop.body.lookup(listOf(loopvarName), forLoop.body.statements.first())
if (existing == null) {
// create loop iteration counter variable (without value, to avoid an assignment)
val vardecl = VarDecl(VarDeclType.VAR, DataType.UBYTE, ZeropageWish.PREFER_ZEROPAGE, null, loopvarName, null, null,
isArray = false, autogeneratedDontRemove = true, position = loopVar.position)
vardecl.linkParents(forLoop.body)
forLoop.body.statements.add(0, vardecl)
loopVar.parent = forLoop.body // loopvar 'is defined in the body'
}
}
}
}
return super.visit(forLoop)
} }
override fun visit(assignTarget: AssignTarget): AssignTarget { super.visit(forLoop)
}
override fun visit(assignTarget: AssignTarget) {
if(assignTarget.register== Register.X) if(assignTarget.register== Register.X)
errors.warn("writing to the X register is dangerous, because it's used as an internal pointer", assignTarget.position) errors.warn("writing to the X register is dangerous, because it's used as an internal pointer", assignTarget.position)
return super.visit(assignTarget) super.visit(assignTarget)
} }
override fun visit(arrayLiteral: ArrayLiteralValue): Expression { override fun visit(string: StringLiteralValue) {
val array = super.visit(arrayLiteral)
if(array is ArrayLiteralValue) {
val vardecl = array.parent as? VarDecl
// adjust the datatype of the array (to an educated guess)
if(vardecl!=null) {
val arrayDt = array.type
if(!arrayDt.istype(vardecl.datatype)) {
val cast = array.cast(vardecl.datatype)
if (cast != null) {
vardecl.value = cast
cast.linkParents(vardecl)
return cast
}
}
return array
}
else {
val arrayDt = array.guessDatatype(program)
if(arrayDt.isKnown) {
// this array literal is part of an expression, turn it into an identifier reference
val litval2 = array.cast(arrayDt.typeOrElse(DataType.STRUCT))
return if (litval2 != null) {
litval2.parent = array.parent
makeIdentifierFromRefLv(litval2)
} else array
}
}
}
return array
}
override fun visit(stringLiteral: StringLiteralValue): Expression {
val string = super.visit(stringLiteral)
if(string is StringLiteralValue) {
val vardecl = string.parent as? VarDecl
// intern the string; move it into the heap
if (string.value.length !in 1..255) if (string.value.length !in 1..255)
errors.err("string literal length must be between 1 and 255", string.position) errors.err("string literal length must be between 1 and 255", string.position)
return if (vardecl != null)
string super.visit(string)
else
makeIdentifierFromRefLv(string) // replace the literal string by a identifier reference.
}
return string
} }
private fun makeIdentifierFromRefLv(array: ArrayLiteralValue): IdentifierReference { override fun visit(structDecl: StructDecl) {
// a referencetype literal value that's not declared as a variable
// we need to introduce an auto-generated variable for this to be able to refer to the value
// note: if the var references the same literal value, it is not yet de-duplicated here.
val scope = array.definingScope()
val variable = VarDecl.createAuto(array)
return replaceWithIdentifier(variable, scope, array.parent)
}
private fun makeIdentifierFromRefLv(string: StringLiteralValue): IdentifierReference {
// a referencetype literal value that's not declared as a variable
// we need to introduce an auto-generated variable for this to be able to refer to the value
// note: if the var references the same literal value, it is not yet de-duplicated here.
val scope = string.definingScope()
val variable = VarDecl.createAuto(string)
return replaceWithIdentifier(variable, scope, string.parent)
}
private fun replaceWithIdentifier(variable: VarDecl, scope: INameScope, parent: Node): IdentifierReference {
val variable1 = addVarDecl(scope, variable)
// replace the reference literal by a identifier reference
val identifier = IdentifierReference(listOf(variable1.name), variable1.position)
identifier.parent = parent
return identifier
}
override fun visit(structDecl: StructDecl): Statement {
for(member in structDecl.statements){ for(member in structDecl.statements){
val decl = member as? VarDecl val decl = member as? VarDecl
if(decl!=null && decl.datatype !in NumericDatatypes) if(decl!=null && decl.datatype !in NumericDatatypes)
errors.err("structs can only contain numerical types", decl.position) errors.err("structs can only contain numerical types", decl.position)
} }
return super.visit(structDecl) super.visit(structDecl)
} }
override fun visit(expr: BinaryExpression): Expression {
return when {
expr.left is StringLiteralValue ->
processBinaryExprWithString(expr.left as StringLiteralValue, expr.right, expr)
expr.right is StringLiteralValue ->
processBinaryExprWithString(expr.right as StringLiteralValue, expr.left, expr)
else -> super.visit(expr)
}
}
private fun processBinaryExprWithString(string: StringLiteralValue, operand: Expression, expr: BinaryExpression): Expression {
val constvalue = operand.constValue(program)
if(constvalue!=null) {
if (expr.operator == "*") {
// repeat a string a number of times
return StringLiteralValue(string.value.repeat(constvalue.number.toInt()), string.altEncoding, expr.position)
}
}
if(expr.operator == "+" && operand is StringLiteralValue) {
// concatenate two strings
return StringLiteralValue("${string.value}${operand.value}", string.altEncoding, expr.position)
}
return expr
}
private fun addVarDecl(scope: INameScope, variable: VarDecl): VarDecl {
if(scope !in vardeclsToAdd)
vardeclsToAdd[scope] = mutableListOf()
val declList = vardeclsToAdd.getValue(scope)
val existing = declList.singleOrNull { it.name==variable.name }
return if(existing!=null) {
existing
} else {
declList.add(variable)
variable
}
}
} }

View File

@ -0,0 +1,176 @@
package prog8.ast.processing
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.*
import prog8.ast.expressions.*
import prog8.ast.statements.*
internal class AstVariousTransforms(private val program: Program) : AstWalker() {
private val noModifications = emptyList<IAstModification>()
override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
if(functionCallStatement.target.nameInSource == listOf("swap")) {
// if x and y are both just identifiers, do not rewrite (there should be asm generation for that)
// otherwise:
// rewrite swap(x,y) as follows:
// - declare a temp variable of the same datatype
// - temp = x, x = y, y= temp
val first = functionCallStatement.args[0]
val second = functionCallStatement.args[1]
if(first !is IdentifierReference && second !is IdentifierReference) {
val dt = first.inferType(program).typeOrElse(DataType.STRUCT)
val tempname = "prog8_swaptmp_${functionCallStatement.hashCode()}"
val tempvardecl = VarDecl(VarDeclType.VAR, dt, ZeropageWish.DONTCARE, null, tempname, null, null, isArray = false, autogeneratedDontRemove = true, position = first.position)
val tempvar = IdentifierReference(listOf(tempname), first.position)
val assignTemp = Assignment(
AssignTarget(null, tempvar, null, null, first.position),
null,
first,
first.position
)
val assignFirst = Assignment(
AssignTarget.fromExpr(first),
null,
second,
first.position
)
val assignSecond = Assignment(
AssignTarget.fromExpr(second),
null,
tempvar,
first.position
)
val scope = AnonymousScope(mutableListOf(tempvardecl, assignTemp, assignFirst, assignSecond), first.position)
return listOf(IAstModification.ReplaceNode(functionCallStatement, scope, parent))
}
}
return noModifications
}
override fun before(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> {
if(functionCall.target.nameInSource.size==1 && functionCall.target.nameInSource[0]=="lsb") {
// lsb(...) is just an alias for type cast to ubyte, so replace with "... as ubyte"
val typecast = TypecastExpression(functionCall.args.single(), DataType.UBYTE, false, functionCall.position)
return listOf(IAstModification.ReplaceNode(
functionCall, typecast, parent
))
}
return noModifications
}
override fun before(decl: VarDecl, parent: Node): Iterable<IAstModification> {
// is it a struct variable? then define all its struct members as mangled names,
// and include the original decl as well.
if(decl.datatype==DataType.STRUCT && !decl.structHasBeenFlattened) {
val decls = decl.flattenStructMembers()
decls.add(decl)
val result = AnonymousScope(decls, decl.position)
return listOf(IAstModification.ReplaceNode(
decl, result, parent
))
}
return noModifications
}
override fun after(subroutine: Subroutine, parent: Node): Iterable<IAstModification> {
// For non-kernel subroutines and non-asm parameters:
// inject subroutine params as local variables (if they're not there yet).
val symbolsInSub = subroutine.allDefinedSymbols()
val namesInSub = symbolsInSub.map{ it.first }.toSet()
if(subroutine.asmAddress==null) {
if(subroutine.asmParameterRegisters.isEmpty() && subroutine.parameters.isNotEmpty()) {
val vars = subroutine.statements.filterIsInstance<VarDecl>().map { it.name }.toSet()
if(!vars.containsAll(subroutine.parameters.map{it.name})) {
return subroutine.parameters
.filter { it.name !in namesInSub }
.map {
val vardecl = ParameterVarDecl(it.name, it.type, subroutine.position)
IAstModification.InsertFirst(vardecl, subroutine)
}
}
}
}
return noModifications
}
override fun before(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
when {
expr.left is StringLiteralValue ->
return listOf(IAstModification.ReplaceNode(
expr,
processBinaryExprWithString(expr.left as StringLiteralValue, expr.right, expr),
parent
))
expr.right is StringLiteralValue ->
return listOf(IAstModification.ReplaceNode(
expr,
processBinaryExprWithString(expr.right as StringLiteralValue, expr.left, expr),
parent
))
}
return noModifications
}
override fun after(string: StringLiteralValue, parent: Node): Iterable<IAstModification> {
if(string.parent !is VarDecl) {
// replace the literal string by a identifier reference to a new local vardecl
val vardecl = VarDecl.createAuto(string)
val identifier = IdentifierReference(listOf(vardecl.name), vardecl.position)
return listOf(
IAstModification.ReplaceNode(string, identifier, parent),
IAstModification.InsertFirst(vardecl, string.definingScope() as Node)
)
}
return noModifications
}
override fun after(array: ArrayLiteralValue, parent: Node): Iterable<IAstModification> {
val vardecl = array.parent as? VarDecl
if(vardecl!=null) {
// adjust the datatype of the array (to an educated guess)
val arrayDt = array.type
if(!arrayDt.istype(vardecl.datatype)) {
val cast = array.cast(vardecl.datatype)
if (cast != null && cast!=array)
return listOf(IAstModification.ReplaceNode(vardecl.value!!, cast, vardecl))
}
} else {
val arrayDt = array.guessDatatype(program)
if(arrayDt.isKnown) {
// this array literal is part of an expression, turn it into an identifier reference
val litval2 = array.cast(arrayDt.typeOrElse(DataType.STRUCT))
if(litval2!=null && litval2!=array) {
val vardecl2 = VarDecl.createAuto(litval2)
val identifier = IdentifierReference(listOf(vardecl2.name), vardecl2.position)
return listOf(
IAstModification.ReplaceNode(array, identifier, parent),
IAstModification.InsertFirst(vardecl2, array.definingScope() as Node)
)
}
}
}
return noModifications
}
private fun processBinaryExprWithString(string: StringLiteralValue, operand: Expression, expr: BinaryExpression): Expression {
val constvalue = operand.constValue(program)
if(constvalue!=null) {
if (expr.operator == "*") {
// repeat a string a number of times
return StringLiteralValue(string.value.repeat(constvalue.number.toInt()), string.altEncoding, expr.position)
}
}
if(expr.operator == "+" && operand is StringLiteralValue) {
// concatenate two strings
return StringLiteralValue("${string.value}${operand.value}", string.altEncoding, expr.position)
}
return expr
}
}

View File

@ -1,9 +1,6 @@
package prog8.ast.processing package prog8.ast.processing
import prog8.ast.INameScope import prog8.ast.*
import prog8.ast.Module
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.FatalAstException import prog8.ast.base.FatalAstException
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.* import prog8.ast.statements.*
@ -15,7 +12,7 @@ interface IAstModification {
class Remove(val node: Node, val parent: Node) : IAstModification { class Remove(val node: Node, val parent: Node) : IAstModification {
override fun perform() { override fun perform() {
if(parent is INameScope) { if(parent is INameScope) {
if (!parent.statements.remove(node)) if (!parent.statements.remove(node) && parent !is GlobalNamespace)
throw FatalAstException("attempt to remove non-existing node $node") throw FatalAstException("attempt to remove non-existing node $node")
} else { } else {
throw FatalAstException("parent of a remove modification is not an INameScope") throw FatalAstException("parent of a remove modification is not an INameScope")
@ -55,7 +52,7 @@ interface IAstModification {
class InsertAfter(val after: Statement, val stmt: Statement, val parent: Node) : IAstModification { class InsertAfter(val after: Statement, val stmt: Statement, val parent: Node) : IAstModification {
override fun perform() { override fun perform() {
if(parent is INameScope) { if(parent is INameScope) {
val idx = parent.statements.indexOf(after)+1 val idx = parent.statements.withIndex().find { it.value===after }!!.index + 1
parent.statements.add(idx, stmt) parent.statements.add(idx, stmt)
stmt.linkParents(parent) stmt.linkParents(parent)
} else { } else {

View File

@ -1,28 +0,0 @@
package prog8.ast.processing
import prog8.ast.Node
import prog8.ast.expressions.NumericLiteralValue
import prog8.ast.statements.ForeverLoop
import prog8.ast.statements.RepeatLoop
import prog8.ast.statements.WhileLoop
internal class ForeverLoopsMaker: AstWalker() {
override fun before(repeatLoop: RepeatLoop, parent: Node): Iterable<IAstModification> {
val numeric = repeatLoop.untilCondition as? NumericLiteralValue
if(numeric!=null && numeric.number.toInt() == 0) {
val forever = ForeverLoop(repeatLoop.body, repeatLoop.position)
return listOf(IAstModification.ReplaceNode(repeatLoop, forever, parent))
}
return emptyList()
}
override fun before(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> {
val numeric = whileLoop.condition as? NumericLiteralValue
if(numeric!=null && numeric.number.toInt() != 0) {
val forever = ForeverLoop(whileLoop.body, whileLoop.position)
return listOf(IAstModification.ReplaceNode(whileLoop, forever, parent))
}
return emptyList()
}
}

View File

@ -1,267 +0,0 @@
package prog8.ast.processing
import prog8.ast.Module
import prog8.ast.Program
import prog8.ast.base.FatalAstException
import prog8.ast.expressions.*
import prog8.ast.statements.*
// TODO replace all occurrences of this with AstWalker
interface IAstModifyingVisitor {
fun visit(program: Program) {
program.modules.forEach { it.accept(this) }
}
fun visit(module: Module) {
module.statements = module.statements.map { it.accept(this) }.toMutableList()
}
fun visit(expr: PrefixExpression): Expression {
expr.expression = expr.expression.accept(this)
return expr
}
fun visit(expr: BinaryExpression): Expression {
expr.left = expr.left.accept(this)
expr.right = expr.right.accept(this)
return expr
}
fun visit(directive: Directive): Statement {
return directive
}
fun visit(block: Block): Statement {
block.statements = block.statements.map { it.accept(this) }.toMutableList()
return block
}
fun visit(decl: VarDecl): Statement {
decl.value = decl.value?.accept(this)
decl.arraysize?.accept(this)
return decl
}
fun visit(subroutine: Subroutine): Statement {
subroutine.statements = subroutine.statements.map { it.accept(this) }.toMutableList()
return subroutine
}
fun visit(functionCall: FunctionCall): Expression {
val newtarget = functionCall.target.accept(this)
if(newtarget is IdentifierReference)
functionCall.target = newtarget
else
throw FatalAstException("cannot change class of function call target")
functionCall.args = functionCall.args.map { it.accept(this) }.toMutableList()
return functionCall
}
fun visit(functionCallStatement: FunctionCallStatement): Statement {
val newtarget = functionCallStatement.target.accept(this)
if(newtarget is IdentifierReference)
functionCallStatement.target = newtarget
else
throw FatalAstException("cannot change class of function call target")
functionCallStatement.args = functionCallStatement.args.map { it.accept(this) }.toMutableList()
return functionCallStatement
}
fun visit(identifier: IdentifierReference): Expression {
// note: this is an identifier that is used in an expression.
// other identifiers are simply part of the other statements (such as jumps, subroutine defs etc)
return identifier
}
fun visit(jump: Jump): Statement {
if(jump.identifier!=null) {
val ident = jump.identifier.accept(this)
if(ident is IdentifierReference && ident!==jump.identifier) {
return Jump(null, ident, null, jump.position)
}
}
return jump
}
fun visit(ifStatement: IfStatement): Statement {
ifStatement.condition = ifStatement.condition.accept(this)
ifStatement.truepart = ifStatement.truepart.accept(this) as AnonymousScope
ifStatement.elsepart = ifStatement.elsepart.accept(this) as AnonymousScope
return ifStatement
}
fun visit(branchStatement: BranchStatement): Statement {
branchStatement.truepart = branchStatement.truepart.accept(this) as AnonymousScope
branchStatement.elsepart = branchStatement.elsepart.accept(this) as AnonymousScope
return branchStatement
}
fun visit(range: RangeExpr): Expression {
range.from = range.from.accept(this)
range.to = range.to.accept(this)
range.step = range.step.accept(this)
return range
}
fun visit(label: Label): Statement {
return label
}
fun visit(literalValue: NumericLiteralValue): NumericLiteralValue {
return literalValue
}
fun visit(stringLiteral: StringLiteralValue): Expression {
return stringLiteral
}
fun visit(arrayLiteral: ArrayLiteralValue): Expression {
for(av in arrayLiteral.value.withIndex()) {
val newvalue = av.value.accept(this)
arrayLiteral.value[av.index] = newvalue
}
return arrayLiteral
}
fun visit(assignment: Assignment): Statement {
assignment.target = assignment.target.accept(this)
assignment.value = assignment.value.accept(this)
return assignment
}
fun visit(postIncrDecr: PostIncrDecr): Statement {
postIncrDecr.target = postIncrDecr.target.accept(this)
return postIncrDecr
}
fun visit(contStmt: Continue): Statement {
return contStmt
}
fun visit(breakStmt: Break): Statement {
return breakStmt
}
fun visit(forLoop: ForLoop): Statement {
when(val newloopvar = forLoop.loopVar?.accept(this)) {
is IdentifierReference -> forLoop.loopVar = newloopvar
null -> forLoop.loopVar = null
else -> throw FatalAstException("can't change class of loopvar")
}
forLoop.iterable = forLoop.iterable.accept(this)
forLoop.body = forLoop.body.accept(this) as AnonymousScope
return forLoop
}
fun visit(whileLoop: WhileLoop): Statement {
whileLoop.condition = whileLoop.condition.accept(this)
whileLoop.body = whileLoop.body.accept(this) as AnonymousScope
return whileLoop
}
fun visit(foreverLoop: ForeverLoop): Statement {
foreverLoop.body = foreverLoop.body.accept(this) as AnonymousScope
return foreverLoop
}
fun visit(repeatLoop: RepeatLoop): Statement {
repeatLoop.untilCondition = repeatLoop.untilCondition.accept(this)
repeatLoop.body = repeatLoop.body.accept(this) as AnonymousScope
return repeatLoop
}
fun visit(returnStmt: Return): Statement {
returnStmt.value = returnStmt.value?.accept(this)
return returnStmt
}
fun visit(arrayIndexedExpression: ArrayIndexedExpression): ArrayIndexedExpression {
val ident = arrayIndexedExpression.identifier.accept(this)
if(ident is IdentifierReference)
arrayIndexedExpression.identifier = ident
arrayIndexedExpression.arrayspec.accept(this)
return arrayIndexedExpression
}
fun visit(assignTarget: AssignTarget): AssignTarget {
when (val ident = assignTarget.identifier?.accept(this)) {
is IdentifierReference -> assignTarget.identifier = ident
null -> assignTarget.identifier = null
else -> throw FatalAstException("can't change class of assign target identifier")
}
assignTarget.arrayindexed = assignTarget.arrayindexed?.accept(this)
assignTarget.memoryAddress?.let { visit(it) }
return assignTarget
}
fun visit(scope: AnonymousScope): Statement {
scope.statements = scope.statements.map { it.accept(this) }.toMutableList()
return scope
}
fun visit(typecast: TypecastExpression): Expression {
typecast.expression = typecast.expression.accept(this)
return typecast
}
fun visit(memread: DirectMemoryRead): Expression {
memread.addressExpression = memread.addressExpression.accept(this)
return memread
}
fun visit(memwrite: DirectMemoryWrite) {
memwrite.addressExpression = memwrite.addressExpression.accept(this)
}
fun visit(addressOf: AddressOf): Expression {
val ident = addressOf.identifier.accept(this)
if(ident is IdentifierReference)
addressOf.identifier = ident
else
throw FatalAstException("can't change class of addressof identifier")
return addressOf
}
fun visit(inlineAssembly: InlineAssembly): Statement {
return inlineAssembly
}
fun visit(registerExpr: RegisterExpr): Expression {
return registerExpr
}
fun visit(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder): Statement {
return builtinFunctionStatementPlaceholder
}
fun visit(nopStatement: NopStatement): Statement {
return nopStatement
}
fun visit(whenStatement: WhenStatement): Statement {
whenStatement.condition = whenStatement.condition.accept(this)
whenStatement.choices.forEach { it.accept(this) }
return whenStatement
}
fun visit(whenChoice: WhenChoice) {
whenChoice.values = whenChoice.values?.map { it.accept(this) }
val stmt = whenChoice.statements.accept(this)
if(stmt is AnonymousScope)
whenChoice.statements = stmt
else {
whenChoice.statements = AnonymousScope(mutableListOf(stmt), stmt.position)
whenChoice.statements.linkParents(whenChoice)
}
}
fun visit(structDecl: StructDecl): Statement {
structDecl.statements = structDecl.statements.map{ it.accept(this) }.toMutableList()
return structDecl
}
fun visit(structLv: StructLiteralValue): Expression {
structLv.values = structLv.values.map { it.accept(this) }
return structLv
}
}

View File

@ -10,11 +10,12 @@ internal class ImportedModuleDirectiveRemover: AstWalker() {
*/ */
private val moduleLevelDirectives = listOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address") private val moduleLevelDirectives = listOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address")
private val noModifications = emptyList<IAstModification>()
override fun before(directive: Directive, parent: Node): Iterable<IAstModification> { override fun before(directive: Directive, parent: Node): Iterable<IAstModification> {
if(directive.directive in moduleLevelDirectives) { if(directive.directive in moduleLevelDirectives) {
return listOf(IAstModification.Remove(directive, parent)) return listOf(IAstModification.Remove(directive, parent))
} }
return emptyList() return noModifications
} }
} }

View File

@ -19,7 +19,7 @@ internal class StatementReorderer(val program: Program) : AstWalker() {
// - sorts the choices in when statement. // - sorts the choices in when statement.
// - insert AddressOf (&) expression where required (string params to a UWORD function param etc). // - insert AddressOf (&) expression where required (string params to a UWORD function param etc).
private val noModifications = emptyList<IAstModification>()
private val directivesToMove = setOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address", "%option") private val directivesToMove = setOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address", "%option")
override fun after(module: Module, parent: Node): Iterable<IAstModification> { override fun after(module: Module, parent: Node): Iterable<IAstModification> {
@ -33,7 +33,7 @@ internal class StatementReorderer(val program: Program) : AstWalker() {
} }
reorderVardeclsAndDirectives(module.statements) reorderVardeclsAndDirectives(module.statements)
return emptyList() return noModifications
} }
private fun reorderVardeclsAndDirectives(statements: MutableList<Statement>) { private fun reorderVardeclsAndDirectives(statements: MutableList<Statement>) {
@ -56,7 +56,7 @@ internal class StatementReorderer(val program: Program) : AstWalker() {
} }
reorderVardeclsAndDirectives(block.statements) reorderVardeclsAndDirectives(block.statements)
return emptyList() return noModifications
} }
override fun before(subroutine: Subroutine, parent: Node): Iterable<IAstModification> { override fun before(subroutine: Subroutine, parent: Node): Iterable<IAstModification> {
@ -68,7 +68,7 @@ internal class StatementReorderer(val program: Program) : AstWalker() {
) )
} }
} }
return emptyList() return noModifications
} }
override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> { override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> {
@ -86,7 +86,7 @@ internal class StatementReorderer(val program: Program) : AstWalker() {
) )
} }
} }
return emptyList() return noModifications
} }
override fun after(whenStatement: WhenStatement, parent: Node): Iterable<IAstModification> { override fun after(whenStatement: WhenStatement, parent: Node): Iterable<IAstModification> {
@ -95,7 +95,7 @@ internal class StatementReorderer(val program: Program) : AstWalker() {
} }
whenStatement.choices.clear() whenStatement.choices.clear()
choices.mapTo(whenStatement.choices) { it.second } choices.mapTo(whenStatement.choices) { it.second }
return emptyList() return noModifications
} }
override fun before(assignment: Assignment, parent: Node): Iterable<IAstModification> { override fun before(assignment: Assignment, parent: Node): Iterable<IAstModification> {
@ -119,7 +119,7 @@ internal class StatementReorderer(val program: Program) : AstWalker() {
} }
} }
return emptyList() return noModifications
} }
private fun flattenStructAssignmentFromStructLiteral(structAssignment: Assignment, program: Program): List<Assignment> { private fun flattenStructAssignmentFromStructLiteral(structAssignment: Assignment, program: Program): List<Assignment> {

View File

@ -0,0 +1,39 @@
package prog8.ast.processing
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.statements.*
import prog8.optimizer.CallGraph
internal class SubroutineInliner(private val program: Program) : AstWalker() {
private val noModifications = emptyList<IAstModification>()
private val callgraph = CallGraph(program)
override fun after(subroutine: Subroutine, parent: Node): Iterable<IAstModification> {
if(!subroutine.isAsmSubroutine && callgraph.calledBy[subroutine]!=null && subroutine.containsCodeOrVars()) {
// TODO for now, inlined subroutines can't have parameters or local variables - improve this
if(subroutine.parameters.isEmpty() && subroutine.containsNoVars()) {
if (subroutine.countStatements() <= 5) {
if (callgraph.calledBy.getValue(subroutine).size == 1 || !subroutine.statements.any { it.expensiveToInline })
return inline(subroutine)
}
}
}
return noModifications
}
private fun inline(subroutine: Subroutine): Iterable<IAstModification> {
val calls = callgraph.calledBy.getValue(subroutine)
return calls.map {
call -> IAstModification.ReplaceNode(
call,
AnonymousScope(subroutine.statements, call.position),
call.parent
)
}.plus(IAstModification.Remove(subroutine, subroutine.parent))
}
}

View File

@ -16,6 +16,8 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
* (this includes function call arguments) * (this includes function call arguments)
*/ */
private val noModifications = emptyList<IAstModification>()
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> { override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
val leftDt = expr.left.inferType(program) val leftDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program) val rightDt = expr.right.inferType(program)
@ -32,7 +34,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
} }
} }
} }
return emptyList() return noModifications
} }
override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> { override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
@ -72,7 +74,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
} }
} }
} }
return emptyList() return noModifications
} }
override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> { override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
@ -85,7 +87,9 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
private fun afterFunctionCallArgs(call: IFunctionCall, scope: INameScope): Iterable<IAstModification> { private fun afterFunctionCallArgs(call: IFunctionCall, scope: INameScope): Iterable<IAstModification> {
// see if a typecast is needed to convert the arguments into the required parameter's type // see if a typecast is needed to convert the arguments into the required parameter's type
return when(val sub = call.target.targetStatement(scope)) { val modifications = mutableListOf<IAstModification>()
when(val sub = call.target.targetStatement(scope)) {
is Subroutine -> { is Subroutine -> {
for(arg in sub.parameters.zip(call.args.withIndex())) { for(arg in sub.parameters.zip(call.args.withIndex())) {
val argItype = arg.second.value.inferType(program) val argItype = arg.second.value.inferType(program)
@ -94,26 +98,33 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
val requiredType = arg.first.type val requiredType = arg.first.type
if (requiredType != argtype) { if (requiredType != argtype) {
if (argtype isAssignableTo requiredType) { if (argtype isAssignableTo requiredType) {
return listOf(IAstModification.ReplaceNode( modifications += IAstModification.ReplaceNode(
call.args[arg.second.index], call.args[arg.second.index],
TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position), TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position),
call as Node)) call as Node)
} else if(requiredType == DataType.UWORD && argtype in PassByReferenceDatatypes) { } else if(requiredType == DataType.UWORD && argtype in PassByReferenceDatatypes) {
// we allow STR/ARRAY values in place of UWORD parameters. Take their address instead. // we allow STR/ARRAY values in place of UWORD parameters. Take their address instead.
return listOf(IAstModification.ReplaceNode( modifications += IAstModification.ReplaceNode(
call.args[arg.second.index], call.args[arg.second.index],
AddressOf(arg.second.value as IdentifierReference, arg.second.value.position), AddressOf(arg.second.value as IdentifierReference, arg.second.value.position),
call as Node)) call as Node)
} else if(arg.second.value is NumericLiteralValue) {
try {
val castedValue = (arg.second.value as NumericLiteralValue).cast(requiredType)
modifications += IAstModification.ReplaceNode(
call.args[arg.second.index],
castedValue,
call as Node)
} catch (x: ExpressionError) {
// no cast possible
}
} }
} }
} }
} }
emptyList()
} }
is BuiltinFunctionStatementPlaceholder -> { is BuiltinFunctionStatementPlaceholder -> {
val func = BuiltinFunctions.getValue(sub.name) val func = BuiltinFunctions.getValue(sub.name)
if(func.pure) {
// non-pure functions don't get automatic typecasts because sometimes they act directly on their parameters
for (arg in func.parameters.zip(call.args.withIndex())) { for (arg in func.parameters.zip(call.args.withIndex())) {
val argItype = arg.second.value.inferType(program) val argItype = arg.second.value.inferType(program)
if (argItype.isKnown) { if (argItype.isKnown) {
@ -122,20 +133,20 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
continue continue
for (possibleType in arg.first.possibleDatatypes) { for (possibleType in arg.first.possibleDatatypes) {
if (argtype isAssignableTo possibleType) { if (argtype isAssignableTo possibleType) {
return listOf(IAstModification.ReplaceNode( modifications += IAstModification.ReplaceNode(
call.args[arg.second.index], call.args[arg.second.index],
TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position), TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position),
call as Node)) call as Node)
} }
} }
} }
} }
} }
emptyList() null -> { }
}
null -> emptyList()
else -> throw FatalAstException("call to something weird $sub ${call.target}") else -> throw FatalAstException("call to something weird $sub ${call.target}")
} }
return modifications
} }
override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> { override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
@ -143,7 +154,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
if(typecast.implicit && typecast.type in setOf(DataType.FLOAT, DataType.ARRAY_F)) { if(typecast.implicit && typecast.type in setOf(DataType.FLOAT, DataType.ARRAY_F)) {
errors.warn("byte or word value implicitly converted to float. Suggestion: use explicit cast as float, a float number, or revert to integer arithmetic", typecast.position) errors.warn("byte or word value implicitly converted to float. Suggestion: use explicit cast as float, a float number, or revert to integer arithmetic", typecast.position)
} }
return emptyList() return noModifications
} }
override fun after(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> { override fun after(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> {
@ -154,7 +165,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position) ?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position)
return listOf(IAstModification.ReplaceNode(memread.addressExpression, typecast, memread)) return listOf(IAstModification.ReplaceNode(memread.addressExpression, typecast, memread))
} }
return emptyList() return noModifications
} }
override fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> { override fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> {
@ -165,7 +176,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position) ?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position)
return listOf(IAstModification.ReplaceNode(memwrite.addressExpression, typecast, memwrite)) return listOf(IAstModification.ReplaceNode(memwrite.addressExpression, typecast, memwrite))
} }
return emptyList() return noModifications
} }
override fun after(structLv: StructLiteralValue, parent: Node): Iterable<IAstModification> { override fun after(structLv: StructLiteralValue, parent: Node): Iterable<IAstModification> {
@ -210,7 +221,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
} }
} }
} }
return emptyList() return noModifications
} }
override fun after(returnStmt: Return, parent: Node): Iterable<IAstModification> { override fun after(returnStmt: Return, parent: Node): Iterable<IAstModification> {
@ -221,7 +232,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
if(subroutine.returntypes.size==1) { if(subroutine.returntypes.size==1) {
val subReturnType = subroutine.returntypes.first() val subReturnType = subroutine.returntypes.first()
if (returnValue.inferType(program).istype(subReturnType)) if (returnValue.inferType(program).istype(subReturnType))
return emptyList() return noModifications
if (returnValue is NumericLiteralValue) { if (returnValue is NumericLiteralValue) {
returnStmt.value = returnValue.cast(subroutine.returntypes.single()) returnStmt.value = returnValue.cast(subroutine.returntypes.single())
} else { } else {
@ -232,6 +243,6 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
} }
} }
} }
return emptyList() return noModifications
} }
} }

View File

@ -0,0 +1,43 @@
package prog8.ast.processing
import prog8.ast.INameScope
import prog8.ast.Node
import prog8.ast.expressions.NumericLiteralValue
import prog8.ast.expressions.TypecastExpression
import prog8.ast.statements.AnonymousScope
import prog8.ast.statements.NopStatement
internal class VariousCleanups: AstWalker() {
private val noModifications = emptyList<IAstModification>()
override fun before(nopStatement: NopStatement, parent: Node): Iterable<IAstModification> {
return listOf(IAstModification.Remove(nopStatement, parent))
}
override fun before(scope: AnonymousScope, parent: Node): Iterable<IAstModification> {
return if(parent is INameScope)
listOf(ScopeFlatten(scope, parent as INameScope))
else
noModifications
}
class ScopeFlatten(val scope: AnonymousScope, val into: INameScope) : IAstModification {
override fun perform() {
val idx = into.statements.indexOf(scope)
if(idx>=0) {
into.statements.addAll(idx+1, scope.statements)
into.statements.remove(scope)
}
}
}
override fun before(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
if(typecast.expression is NumericLiteralValue) {
val value = (typecast.expression as NumericLiteralValue).cast(typecast.type)
return listOf(IAstModification.ReplaceNode(typecast, value, parent))
}
return noModifications
}
}

View File

@ -0,0 +1,42 @@
package prog8.ast.processing
import prog8.ast.IFunctionCall
import prog8.ast.INameScope
import prog8.ast.Program
import prog8.ast.base.DataType
import prog8.ast.expressions.FunctionCall
import prog8.ast.statements.BuiltinFunctionStatementPlaceholder
import prog8.ast.statements.FunctionCallStatement
import prog8.ast.statements.Subroutine
import prog8.compiler.CompilerException
import prog8.functions.BuiltinFunctions
class VerifyFunctionArgTypes(val program: Program) : IAstVisitor {
override fun visit(functionCall: FunctionCall)
= checkTypes(functionCall as IFunctionCall, functionCall.definingScope())
override fun visit(functionCallStatement: FunctionCallStatement)
= checkTypes(functionCallStatement as IFunctionCall, functionCallStatement.definingScope())
private fun checkTypes(call: IFunctionCall, scope: INameScope) {
val argtypes = call.args.map { it.inferType(program).typeOrElse(DataType.STRUCT) }
val target = call.target.targetStatement(scope)
when(target) {
is Subroutine -> {
val paramtypes = target.parameters.map { it.type }
if(argtypes!=paramtypes)
throw CompilerException("parameter type mismatch $call")
}
is BuiltinFunctionStatementPlaceholder -> {
val func = BuiltinFunctions.getValue(target.name)
val paramtypes = func.parameters.map { it.possibleDatatypes }
for(x in argtypes.zip(paramtypes)) {
if(x.first !in x.second)
throw CompilerException("parameter type mismatch $call")
}
}
else -> {}
}
}
}

View File

@ -4,12 +4,10 @@ import prog8.ast.*
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.processing.AstWalker import prog8.ast.processing.AstWalker
import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.processing.IAstVisitor import prog8.ast.processing.IAstVisitor
sealed class Statement : Node { sealed class Statement : Node {
abstract fun accept(visitor: IAstModifyingVisitor) : Statement
abstract fun accept(visitor: IAstVisitor) abstract fun accept(visitor: IAstVisitor)
abstract fun accept(visitor: AstWalker, parent: Node) abstract fun accept(visitor: AstWalker, parent: Node)
@ -44,7 +42,6 @@ sealed class Statement : Node {
class BuiltinFunctionStatementPlaceholder(val name: String, override val position: Position) : Statement() { class BuiltinFunctionStatementPlaceholder(val name: String, override val position: Position) : Statement() {
override var parent: Node = ParentSentinel override var parent: Node = ParentSentinel
override fun linkParents(parent: Node) {} override fun linkParents(parent: Node) {}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
override fun definingScope(): INameScope = BuiltinFunctionScopePlaceholder override fun definingScope(): INameScope = BuiltinFunctionScopePlaceholder
@ -72,12 +69,11 @@ class Block(override val name: String,
override fun replaceChildNode(node: Node, replacement: Node) { override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Statement) require(replacement is Statement)
val idx = statements.indexOf(node) val idx = statements.withIndex().find { it.value===node }!!.index
statements[idx] = replacement statements[idx] = replacement
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -98,7 +94,6 @@ data class Directive(val directive: String, val args: List<DirectiveArg>, overri
} }
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -121,7 +116,6 @@ data class Label(val name: String, override val position: Position) : Statement(
} }
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -145,7 +139,6 @@ open class Return(var value: Expression?, override val position: Position) : Sta
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -155,7 +148,6 @@ open class Return(var value: Expression?, override val position: Position) : Sta
} }
class ReturnFromIrq(override val position: Position) : Return(null, position) { class ReturnFromIrq(override val position: Position) : Return(null, position) {
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun toString(): String { override fun toString(): String {
@ -173,7 +165,6 @@ class Continue(override val position: Position) : Statement() {
} }
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -187,7 +178,6 @@ class Break(override val position: Position) : Statement() {
} }
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -281,7 +271,6 @@ open class VarDecl(val type: VarDeclType,
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -338,11 +327,8 @@ class ArrayIndex(var index: Expression, override val position: Position) : Node
} }
} }
fun accept(visitor: IAstModifyingVisitor) {
index = index.accept(visitor)
}
fun accept(visitor: IAstVisitor) = index.accept(visitor) fun accept(visitor: IAstVisitor) = index.accept(visitor)
fun accept(visitor: AstWalker, parent: Node) = index.accept(visitor, parent) fun accept(visitor: AstWalker, parent: Node) = index.accept(visitor, this)
override fun toString(): String { override fun toString(): String {
return("ArrayIndex($index, pos=$position)") return("ArrayIndex($index, pos=$position)")
@ -354,7 +340,7 @@ class ArrayIndex(var index: Expression, override val position: Position) : Node
open class Assignment(var target: AssignTarget, var aug_op : String?, var value: Expression, override val position: Position) : Statement() { open class Assignment(var target: AssignTarget, var aug_op : String?, var value: Expression, override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline override val expensiveToInline
get() = value !is NumericLiteralValue get() = value is BinaryExpression
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
this.parent = parent this.parent = parent
@ -371,7 +357,6 @@ open class Assignment(var target: AssignTarget, var aug_op : String?, var value:
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -427,7 +412,6 @@ data class AssignTarget(val register: Register?,
replacement.parent = this replacement.parent = this
} }
fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
fun accept(visitor: IAstVisitor) = visitor.visit(this) fun accept(visitor: IAstVisitor) = visitor.visit(this)
fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -538,7 +522,6 @@ class PostIncrDecr(var target: AssignTarget, val operator: String, override val
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -560,7 +543,6 @@ class Jump(val address: Int?,
} }
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -587,13 +569,12 @@ class FunctionCallStatement(override var target: IdentifierReference,
if(node===target) if(node===target)
target = replacement as IdentifierReference target = replacement as IdentifierReference
else { else {
val idx = args.indexOf(node) val idx = args.withIndex().find { it.value===node }!!.index
args[idx] = replacement as Expression args[idx] = replacement as Expression
} }
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -611,7 +592,6 @@ class InlineAssembly(val assembly: String, override val position: Position) : St
} }
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -639,12 +619,11 @@ class AnonymousScope(override var statements: MutableList<Statement>,
override fun replaceChildNode(node: Node, replacement: Node) { override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Statement) require(replacement is Statement)
val idx = statements.indexOf(node) val idx = statements.withIndex().find { it.value===node }!!.index
statements[idx] = replacement statements[idx] = replacement
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -658,7 +637,6 @@ class NopStatement(override val position: Position): Statement() {
} }
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -690,8 +668,6 @@ class Subroutine(override val name: String,
get() = statements.any { it.expensiveToInline } get() = statements.any { it.expensiveToInline }
override lateinit var parent: Node override lateinit var parent: Node
val calledBy = mutableListOf<Node>()
val calls = mutableSetOf<Subroutine>()
val scopedname: String by lazy { makeScopedName(name) } val scopedname: String by lazy { makeScopedName(name) }
@ -703,12 +679,11 @@ class Subroutine(override val name: String,
override fun replaceChildNode(node: Node, replacement: Node) { override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Statement) require(replacement is Statement)
val idx = statements.indexOf(node) val idx = statements.withIndex().find { it.value===node }!!.index
statements[idx] = replacement statements[idx] = replacement
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -723,6 +698,32 @@ class Subroutine(override val name: String,
.filter { it is InlineAssembly } .filter { it is InlineAssembly }
.map { (it as InlineAssembly).assembly } .map { (it as InlineAssembly).assembly }
.count { " rti" in it || "\trti" in it || " rts" in it || "\trts" in it || " jmp" in it || "\tjmp" in it } .count { " rti" in it || "\trti" in it || " rts" in it || "\trts" in it || " jmp" in it || "\tjmp" in it }
fun countStatements(): Int {
class StatementCounter: IAstVisitor {
var count = 0
override fun visit(block: Block) {
count += block.statements.size
super.visit(block)
}
override fun visit(subroutine: Subroutine) {
count += subroutine.statements.size
super.visit(subroutine)
}
override fun visit(scope: AnonymousScope) {
count += scope.statements.size
super.visit(scope)
}
}
// the (recursive) number of statements
val counter = StatementCounter()
counter.visit(this)
return counter.count
}
} }
@ -765,7 +766,6 @@ class IfStatement(var condition: Expression,
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -794,7 +794,6 @@ class BranchStatement(var condition: BranchCondition,
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -825,7 +824,6 @@ class ForLoop(val loopRegister: Register?,
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -861,7 +859,6 @@ class WhileLoop(var condition: Expression,
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -881,7 +878,6 @@ class ForeverLoop(var body: AnonymousScope, override val position: Position) : S
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -907,7 +903,6 @@ class RepeatLoop(var body: AnonymousScope,
replacement.parent = this replacement.parent = this
} }
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -928,7 +923,7 @@ class WhenStatement(var condition: Expression,
if(node===condition) if(node===condition)
condition = replacement as Expression condition = replacement as Expression
else { else {
val idx = choices.indexOf(node) val idx = choices.withIndex().find { it.value===node }!!.index
choices[idx] = replacement as WhenChoice choices[idx] = replacement as WhenChoice
} }
replacement.parent = this replacement.parent = this
@ -952,7 +947,6 @@ class WhenStatement(var condition: Expression,
} }
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -978,7 +972,6 @@ class WhenChoice(var values: List<Expression>?, // if null, this is t
} }
fun accept(visitor: IAstVisitor) = visitor.visit(this) fun accept(visitor: IAstVisitor) = visitor.visit(this)
fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@ -997,7 +990,7 @@ class StructDecl(override val name: String,
override fun replaceChildNode(node: Node, replacement: Node) { override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Statement) require(replacement is Statement)
val idx = statements.indexOf(node) val idx = statements.withIndex().find { it.value===node }!!.index
statements[idx] = replacement statements[idx] = replacement
replacement.parent = this replacement.parent = this
} }
@ -1006,7 +999,6 @@ class StructDecl(override val name: String,
get() = this.statements.size get() = this.statements.size
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
fun nameOfFirstMember() = (statements.first() as VarDecl).name fun nameOfFirstMember() = (statements.first() as VarDecl).name
@ -1031,6 +1023,5 @@ class DirectMemoryWrite(var addressExpression: Expression, override val position
} }
fun accept(visitor: IAstVisitor) = visitor.visit(this) fun accept(visitor: IAstVisitor) = visitor.visit(this)
fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }

View File

@ -11,12 +11,14 @@ import prog8.ast.statements.*
internal class BeforeAsmGenerationAstChanger(val program: Program, val errors: ErrorReporter) : AstWalker() { internal class BeforeAsmGenerationAstChanger(val program: Program, val errors: ErrorReporter) : AstWalker() {
private val noModifications = emptyList<IAstModification>()
override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> { override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> {
if (decl.value == null && decl.type == VarDeclType.VAR && decl.datatype in NumericDatatypes) { if (decl.value == null && decl.type == VarDeclType.VAR && decl.datatype in NumericDatatypes) {
// a numeric vardecl without an initial value is initialized with zero. // a numeric vardecl without an initial value is initialized with zero.
decl.value = decl.zeroElementValue() decl.value = decl.zeroElementValue()
} }
return emptyList() return noModifications
} }
override fun after(scope: AnonymousScope, parent: Node): Iterable<IAstModification> { override fun after(scope: AnonymousScope, parent: Node): Iterable<IAstModification> {
@ -45,7 +47,7 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, val errors: E
decls.map { IAstModification.InsertFirst(it, sub) } // move it up to the subroutine decls.map { IAstModification.InsertFirst(it, sub) } // move it up to the subroutine
} }
} }
return emptyList() return noModifications
} }
override fun after(subroutine: Subroutine, parent: Node): Iterable<IAstModification> { override fun after(subroutine: Subroutine, parent: Node): Iterable<IAstModification> {
@ -99,6 +101,6 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, val errors: E
} }
} }
return emptyList() return noModifications
} }
} }

View File

@ -42,7 +42,7 @@ fun compileProgram(filepath: Path,
optimizeAst(programAst, errors) optimizeAst(programAst, errors)
postprocessAst(programAst, errors, compilationOptions) postprocessAst(programAst, errors, compilationOptions)
// printAst(programAst) // TODO // printAst(programAst)
if(writeAssembly) if(writeAssembly)
programName = writeAssembly(programAst, errors, outputDir, optimize, compilationOptions) programName = writeAssembly(programAst, errors, outputDir, optimize, compilationOptions)
@ -144,13 +144,12 @@ private fun processAst(programAst: Program, errors: ErrorReporter, compilerOptio
println("Processing...") println("Processing...")
programAst.checkIdentifiers(errors) programAst.checkIdentifiers(errors)
errors.handle() errors.handle()
programAst.makeForeverLoops()
programAst.constantFold(errors) programAst.constantFold(errors)
errors.handle() errors.handle()
programAst.removeNopsFlattenAnonScopes()
programAst.reorderStatements() programAst.reorderStatements()
programAst.addTypecasts(errors) programAst.addTypecasts(errors)
errors.handle() errors.handle()
programAst.variousCleanups()
programAst.checkValid(compilerOptions, errors) programAst.checkValid(compilerOptions, errors)
errors.handle() errors.handle()
programAst.checkIdentifiers(errors) programAst.checkIdentifiers(errors)
@ -164,9 +163,10 @@ private fun optimizeAst(programAst: Program, errors: ErrorReporter) {
// keep optimizing expressions and statements until no more steps remain // keep optimizing expressions and statements until no more steps remain
val optsDone1 = programAst.simplifyExpressions() val optsDone1 = programAst.simplifyExpressions()
val optsDone2 = programAst.optimizeStatements(errors) val optsDone2 = programAst.optimizeStatements(errors)
val optsDone3 = programAst.inlineSubroutines()
programAst.constantFold(errors) // because simplified statements and expressions could give rise to more constants that can be folded away: programAst.constantFold(errors) // because simplified statements and expressions could give rise to more constants that can be folded away:
errors.handle() errors.handle()
if (optsDone1 + optsDone2 == 0) if (optsDone1 + optsDone2 + optsDone3 == 0)
break break
} }
@ -180,11 +180,12 @@ private fun postprocessAst(programAst: Program, errors: ErrorReporter, compilerO
errors.handle() errors.handle()
programAst.addTypecasts(errors) programAst.addTypecasts(errors)
errors.handle() errors.handle()
programAst.removeNopsFlattenAnonScopes() programAst.variousCleanups()
programAst.checkValid(compilerOptions, errors) // check if final tree is still valid programAst.checkValid(compilerOptions, errors) // check if final tree is still valid
errors.handle() errors.handle()
programAst.checkRecursion(errors) // check if there are recursive subroutine calls programAst.checkRecursion(errors) // check if there are recursive subroutine calls
errors.handle() errors.handle()
programAst.verifyFunctionArgTypes()
} }
private fun writeAssembly(programAst: Program, errors: ErrorReporter, outputDir: Path, private fun writeAssembly(programAst: Program, errors: ErrorReporter, outputDir: Path,
@ -194,7 +195,7 @@ private fun writeAssembly(programAst: Program, errors: ErrorReporter, outputDir:
programAst.processAstBeforeAsmGeneration(errors) programAst.processAstBeforeAsmGeneration(errors)
errors.handle() errors.handle()
// printAst(programAst) // TODO // printAst(programAst)
val assembly = CompilationTarget.asmGenerator( val assembly = CompilationTarget.asmGenerator(
programAst, programAst,

View File

@ -691,7 +691,6 @@ internal class AsmGen(private val program: Program,
loopEndLabels.push(endLabel) loopEndLabels.push(endLabel)
loopContinueLabels.push(whileLabel) loopContinueLabels.push(whileLabel)
out(whileLabel) out(whileLabel)
// TODO optimize for the simple cases, can we avoid stack use?
expressionsAsmGen.translateExpression(stmt.condition) expressionsAsmGen.translateExpression(stmt.condition)
val conditionDt = stmt.condition.inferType(program) val conditionDt = stmt.condition.inferType(program)
if(!conditionDt.isKnown) if(!conditionDt.isKnown)
@ -720,7 +719,6 @@ internal class AsmGen(private val program: Program,
loopEndLabels.push(endLabel) loopEndLabels.push(endLabel)
loopContinueLabels.push(repeatLabel) loopContinueLabels.push(repeatLabel)
out(repeatLabel) out(repeatLabel)
// TODO optimize this for the simple cases, can we avoid stack use?
translate(stmt.body) translate(stmt.body)
expressionsAsmGen.translateExpression(stmt.untilCondition) expressionsAsmGen.translateExpression(stmt.untilCondition)
val conditionDt = stmt.untilCondition.inferType(program) val conditionDt = stmt.untilCondition.inferType(program)

View File

@ -46,7 +46,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
} }
} }
// TODO this is the FALLBACK: // TODO this is the slow FALLBACK, eventually we don't want to have to use it anymore:
errors.warn("using suboptimal in-place assignment code (this should still be optimized)", assign.position) errors.warn("using suboptimal in-place assignment code (this should still be optimized)", assign.position)
val normalAssignment = assign.asDesugaredNonaugmented() val normalAssignment = assign.asDesugaredNonaugmented()
return translateNormalAssignment(normalAssignment) return translateNormalAssignment(normalAssignment)
@ -91,7 +91,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
else -> throw AssemblyError("assignment to array: invalid array dt $arrayDt") else -> throw AssemblyError("assignment to array: invalid array dt $arrayDt")
} }
} else { } else {
TODO() TODO("aug assignment to element in array/string")
} }
return true return true
} }
@ -155,27 +155,29 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
sta $targetName+1,y sta $targetName+1,y
""") """)
} }
DataType.ARRAY_F -> return false // TODO optimize? DataType.ARRAY_F -> return false // TODO optimize instead of fallback?
else -> throw AssemblyError("assignment to array: invalid array dt $arrayDt") else -> throw AssemblyError("assignment to array: invalid array dt $arrayDt")
} }
return true return true
} }
is AddressOf -> { is AddressOf -> {
TODO("$assign") TODO("assign address into array $assign")
} }
is DirectMemoryRead -> { is DirectMemoryRead -> {
TODO("$assign") TODO("assign memory read into array $assign")
} }
is ArrayIndexedExpression -> { is ArrayIndexedExpression -> {
if(assign.aug_op != "setvalue")
return false // we don't put effort into optimizing anything beside simple assignment
val valueArrayExpr = assign.value as ArrayIndexedExpression val valueArrayExpr = assign.value as ArrayIndexedExpression
val valueArrayIndex = valueArrayExpr.arrayspec.index val valueArrayIndex = valueArrayExpr.arrayspec.index
if(valueArrayIndex is RegisterExpr || arrayIndex is RegisterExpr) {
throw AssemblyError("cannot generate code for array operations with registers as index")
}
val valueVariablename = asmgen.asmIdentifierName(valueArrayExpr.identifier) val valueVariablename = asmgen.asmIdentifierName(valueArrayExpr.identifier)
val valueDt = valueArrayExpr.identifier.inferType(program).typeOrElse(DataType.STRUCT) val valueDt = valueArrayExpr.identifier.inferType(program).typeOrElse(DataType.STRUCT)
when(arrayDt) { when(arrayDt) {
DataType.ARRAY_UB -> { DataType.ARRAY_UB, DataType.ARRAY_B, DataType.STR -> {
if (arrayDt != DataType.ARRAY_B && arrayDt != DataType.ARRAY_UB && arrayDt != DataType.STR)
throw AssemblyError("assignment to array: expected byte array or string")
if (assign.aug_op == "setvalue") {
if (valueArrayIndex is NumericLiteralValue) if (valueArrayIndex is NumericLiteralValue)
asmgen.out(" ldy #${valueArrayIndex.number.toHex()}") asmgen.out(" ldy #${valueArrayIndex.number.toHex()}")
else else
@ -186,14 +188,87 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
else else
asmgen.translateArrayIndexIntoY(targetArray) asmgen.translateArrayIndexIntoY(targetArray)
asmgen.out(" sta $targetName,y") asmgen.out(" sta $targetName,y")
} else {
return false // TODO optimize
} }
DataType.ARRAY_UW, DataType.ARRAY_W -> {
if (valueArrayIndex is NumericLiteralValue)
asmgen.out(" ldy #2*${valueArrayIndex.number.toHex()}")
else {
asmgen.translateArrayIndexIntoA(valueArrayExpr)
asmgen.out(" asl a | tay")
}
asmgen.out("""
lda $valueVariablename,y
pha
lda $valueVariablename+1,y
pha
""")
if (arrayIndex is NumericLiteralValue)
asmgen.out(" ldy #2*${arrayIndex.number.toHex()}")
else {
asmgen.translateArrayIndexIntoA(targetArray)
asmgen.out(" asl a | tay")
}
asmgen.out("""
pla
sta $targetName+1,y
pla
sta $targetName,y
""")
return true
}
DataType.ARRAY_F -> {
if (valueArrayIndex is NumericLiteralValue)
asmgen.out(" ldy #5*${valueArrayIndex.number.toHex()}")
else {
asmgen.translateArrayIndexIntoA(valueArrayExpr)
asmgen.out("""
sta ${C64Zeropage.SCRATCH_REG}
asl a
asl a
clc
adc ${C64Zeropage.SCRATCH_REG}
tay
""")
}
asmgen.out("""
lda $valueVariablename,y
pha
lda $valueVariablename+1,y
pha
lda $valueVariablename+2,y
pha
lda $valueVariablename+3,y
pha
lda $valueVariablename+4,y
pha
""")
if (arrayIndex is NumericLiteralValue)
asmgen.out(" ldy #5*${arrayIndex.number.toHex()}")
else {
asmgen.translateArrayIndexIntoA(targetArray)
asmgen.out("""
sta ${C64Zeropage.SCRATCH_REG}
asl a
asl a
clc
adc ${C64Zeropage.SCRATCH_REG}
tay
""")
}
asmgen.out("""
pla
sta $targetName+4,y
pla
sta $targetName+3,y
pla
sta $targetName+2,y
pla
sta $targetName+1,y
pla
sta $targetName,y
""")
return true
} }
DataType.ARRAY_B -> TODO()
DataType.ARRAY_UW -> TODO()
DataType.ARRAY_W -> TODO()
DataType.ARRAY_F -> TODO()
else -> throw AssemblyError("assignment to array: invalid array dt") else -> throw AssemblyError("assignment to array: invalid array dt")
} }
return true return true
@ -219,13 +294,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" lda #$hexValue | sta $hexAddr") "setvalue" -> asmgen.out(" lda #$hexValue | sta $hexAddr")
"+=" -> asmgen.out(" lda $hexAddr | clc | adc #$hexValue | sta $hexAddr") "+=" -> asmgen.out(" lda $hexAddr | clc | adc #$hexValue | sta $hexAddr")
"-=" -> asmgen.out(" lda $hexAddr | sec | sbc #$hexValue | sta $hexAddr") "-=" -> asmgen.out(" lda $hexAddr | sec | sbc #$hexValue | sta $hexAddr")
"/=" -> TODO("/=") "/=" -> TODO("membyte /= const $hexValue")
"*=" -> TODO("*=") "*=" -> TODO("membyte *= const $hexValue")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" lda $hexAddr | and #$hexValue | sta $hexAddr") "&=" -> asmgen.out(" lda $hexAddr | and #$hexValue | sta $hexAddr")
"|=" -> asmgen.out(" lda $hexAddr | ora #$hexValue | sta $hexAddr") "|=" -> asmgen.out(" lda $hexAddr | ora #$hexValue | sta $hexAddr")
"^=" -> asmgen.out(" lda $hexAddr | eor #$hexValue | sta $hexAddr") "^=" -> asmgen.out(" lda $hexAddr | eor #$hexValue | sta $hexAddr")
"%=" -> TODO("%=") "%=" -> TODO("membyte %= const $hexValue")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -258,9 +332,8 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
Register.Y -> asmgen.out(" sty ${C64Zeropage.SCRATCH_B1} | lda $hexAddr | sec | sbc ${C64Zeropage.SCRATCH_B1} | sta $hexAddr") Register.Y -> asmgen.out(" sty ${C64Zeropage.SCRATCH_B1} | lda $hexAddr | sec | sbc ${C64Zeropage.SCRATCH_B1} | sta $hexAddr")
} }
} }
"/=" -> TODO("/=") "/=" -> TODO("membyte /= register")
"*=" -> TODO("*=") "*=" -> TODO("membyte *= register")
"**=" -> TODO("**=")
"&=" -> { "&=" -> {
when ((assign.value as RegisterExpr).register) { when ((assign.value as RegisterExpr).register) {
Register.A -> asmgen.out(" and $hexAddr | sta $hexAddr") Register.A -> asmgen.out(" and $hexAddr | sta $hexAddr")
@ -282,7 +355,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
Register.Y -> asmgen.out(" tya | eor $hexAddr | sta $hexAddr") Register.Y -> asmgen.out(" tya | eor $hexAddr | sta $hexAddr")
} }
} }
"%=" -> TODO("%=") "%=" -> TODO("membyte %= register")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -291,17 +364,24 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
} }
is IdentifierReference -> { is IdentifierReference -> {
val sourceName = asmgen.asmIdentifierName(assign.value as IdentifierReference) val sourceName = asmgen.asmIdentifierName(assign.value as IdentifierReference)
TODO("$assign") when(assign.aug_op) {
"setvalue" -> asmgen.out(" lda $sourceName | sta $hexAddr")
else -> TODO("membyte aug.assign variable $assign")
} }
is AddressOf -> { return true
TODO("$assign")
} }
is DirectMemoryRead -> { is DirectMemoryRead -> {
TODO("$assign") val memory = (assign.value as DirectMemoryRead).addressExpression.constValue(program)!!.number.toHex()
when(assign.aug_op) {
"setvalue" -> asmgen.out(" lda $memory | sta $hexAddr")
else -> TODO("membyte aug.assign memread $assign")
}
return true
} }
is ArrayIndexedExpression -> { is ArrayIndexedExpression -> {
TODO("$assign") TODO("membyte = array value $assign")
} }
is AddressOf -> throw AssemblyError("can't assign address to byte")
else -> { else -> {
fallbackAssignment(assign) fallbackAssignment(assign)
return true return true
@ -312,7 +392,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
} }
private fun inplaceAssignToNonConstMemoryByte(assign: Assignment): Boolean { private fun inplaceAssignToNonConstMemoryByte(assign: Assignment): Boolean {
// target address is not constant, so evaluate it on the stack // target address is not constant, so evaluate it from the stack
asmgen.translateExpression(assign.target.memoryAddress!!.addressExpression) asmgen.translateExpression(assign.target.memoryAddress!!.addressExpression)
asmgen.out(""" asmgen.out("""
inx inx
@ -330,13 +410,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" lda #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y") "setvalue" -> asmgen.out(" lda #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y")
"+=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | clc | adc #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y") "+=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | clc | adc #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y")
"-=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | sec | sbc #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y") "-=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | sec | sbc #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y")
"/=" -> TODO("/=") "/=" -> TODO("membyte /= const $hexValue")
"*=" -> TODO("*=") "*=" -> TODO("membyte *= const $hexValue")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | and #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y") "&=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | and #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y")
"|=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | ora #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y") "|=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | ora #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y")
"^=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | eor #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y") "^=" -> asmgen.out(" lda (${C64Zeropage.SCRATCH_W1}),y | eor #$hexValue | sta (${C64Zeropage.SCRATCH_W1}),y")
"%=" -> TODO("%=") "%=" -> TODO("membyte %= const $hexValue")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -370,9 +449,8 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
Register.Y -> asmgen.out(" tya | ldy #0 | sta ${C64Zeropage.SCRATCH_B1} | lda (${C64Zeropage.SCRATCH_W1}),y | sec | sbc ${C64Zeropage.SCRATCH_B1} | sta (${C64Zeropage.SCRATCH_W1}),y") Register.Y -> asmgen.out(" tya | ldy #0 | sta ${C64Zeropage.SCRATCH_B1} | lda (${C64Zeropage.SCRATCH_W1}),y | sec | sbc ${C64Zeropage.SCRATCH_B1} | sta (${C64Zeropage.SCRATCH_W1}),y")
} }
} }
"/=" -> TODO("/=") "/=" -> TODO("membyte /= register")
"*=" -> TODO("*=") "*=" -> TODO("membyte *= register")
"**=" -> TODO("**=")
"&=" -> { "&=" -> {
when ((assign.value as RegisterExpr).register) { when ((assign.value as RegisterExpr).register) {
Register.A -> asmgen.out(" ldy #0 | and (${C64Zeropage.SCRATCH_W1}),y| sta (${C64Zeropage.SCRATCH_W1}),y") Register.A -> asmgen.out(" ldy #0 | and (${C64Zeropage.SCRATCH_W1}),y| sta (${C64Zeropage.SCRATCH_W1}),y")
@ -394,7 +472,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
Register.Y -> asmgen.out(" tya | ldy #0 | eor (${C64Zeropage.SCRATCH_W1}),y | sta (${C64Zeropage.SCRATCH_W1}),y") Register.Y -> asmgen.out(" tya | ldy #0 | eor (${C64Zeropage.SCRATCH_W1}),y | sta (${C64Zeropage.SCRATCH_W1}),y")
} }
} }
"%=" -> TODO("%=") "%=" -> TODO("membyte %= register")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -403,13 +481,10 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
} }
is IdentifierReference -> { is IdentifierReference -> {
val sourceName = asmgen.asmIdentifierName(assign.value as IdentifierReference) val sourceName = asmgen.asmIdentifierName(assign.value as IdentifierReference)
TODO("$assign") TODO("membyte = variable $assign")
}
is AddressOf -> {
TODO("$assign")
} }
is DirectMemoryRead -> { is DirectMemoryRead -> {
TODO("$assign") TODO("membyte = memread $assign")
} }
is ArrayIndexedExpression -> { is ArrayIndexedExpression -> {
if (assign.aug_op == "setvalue") { if (assign.aug_op == "setvalue") {
@ -436,6 +511,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
} }
return true return true
} }
is AddressOf -> throw AssemblyError("can't assign memory address to memory byte")
else -> { else -> {
fallbackAssignment(assign) fallbackAssignment(assign)
return true return true
@ -459,13 +535,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" lda #$hexValue | sta $targetName") "setvalue" -> asmgen.out(" lda #$hexValue | sta $targetName")
"+=" -> asmgen.out(" lda $targetName | clc | adc #$hexValue | sta $targetName") "+=" -> asmgen.out(" lda $targetName | clc | adc #$hexValue | sta $targetName")
"-=" -> asmgen.out(" lda $targetName | sec | sbc #$hexValue | sta $targetName") "-=" -> asmgen.out(" lda $targetName | sec | sbc #$hexValue | sta $targetName")
"/=" -> TODO("/=") "/=" -> TODO("variable /= const $hexValue")
"*=" -> TODO("*=") "*=" -> TODO("variable *= const $hexValue")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" lda $targetName | and #$hexValue | sta $targetName") "&=" -> asmgen.out(" lda $targetName | and #$hexValue | sta $targetName")
"|=" -> asmgen.out(" lda $targetName | ora #$hexValue | sta $targetName") "|=" -> asmgen.out(" lda $targetName | ora #$hexValue | sta $targetName")
"^=" -> asmgen.out(" lda $targetName | eor #$hexValue | sta $targetName") "^=" -> asmgen.out(" lda $targetName | eor #$hexValue | sta $targetName")
"%=" -> TODO("%=") "%=" -> TODO("variable %= const $hexValue")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -485,7 +560,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
Register.Y -> asmgen.out(" sty $targetName") Register.Y -> asmgen.out(" sty $targetName")
} }
} }
else -> TODO("$assign") else -> TODO("aug.assign variable = register $assign")
} }
return true return true
} }
@ -495,13 +570,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" lda $sourceName | sta $targetName") "setvalue" -> asmgen.out(" lda $sourceName | sta $targetName")
"+=" -> asmgen.out(" lda $targetName | clc | adc $sourceName | sta $targetName") "+=" -> asmgen.out(" lda $targetName | clc | adc $sourceName | sta $targetName")
"-=" -> asmgen.out(" lda $targetName | sec | sbc $sourceName | sta $targetName") "-=" -> asmgen.out(" lda $targetName | sec | sbc $sourceName | sta $targetName")
"/=" -> TODO("/=") "/=" -> TODO("variable /= variable")
"*=" -> TODO("*=") "*=" -> TODO("variable *= variable")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" lda $targetName | and $sourceName | sta $targetName") "&=" -> asmgen.out(" lda $targetName | and $sourceName | sta $targetName")
"|=" -> asmgen.out(" lda $targetName | ora $sourceName | sta $targetName") "|=" -> asmgen.out(" lda $targetName | ora $sourceName | sta $targetName")
"^=" -> asmgen.out(" lda $targetName | eor $sourceName | sta $targetName") "^=" -> asmgen.out(" lda $targetName | eor $sourceName | sta $targetName")
"%=" -> TODO("%=") "%=" -> TODO("variable %= variable")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -509,7 +583,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
return true return true
} }
is DirectMemoryRead -> { is DirectMemoryRead -> {
TODO("$assign") TODO("variable = memory read $assign")
} }
is ArrayIndexedExpression -> { is ArrayIndexedExpression -> {
if (assign.aug_op == "setvalue") { if (assign.aug_op == "setvalue") {
@ -572,7 +646,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
sta $targetName+1 sta $targetName+1
""") """)
} }
else -> TODO("$assign") else -> TODO("variable aug.assign ${assign.aug_op} const $hexNumber")
} }
return true return true
} }
@ -617,7 +691,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
return true return true
} }
else -> { else -> {
TODO("$assign") TODO("variable aug.assign variable")
} }
} }
return true return true
@ -714,7 +788,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
""") """)
return true return true
} }
else -> TODO("$assign") else -> TODO("float const value aug.assign $assign")
} }
return true return true
} }
@ -727,7 +801,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> assignFromFloatVariable(assign.target, assign.value as IdentifierReference) "setvalue" -> assignFromFloatVariable(assign.target, assign.value as IdentifierReference)
"+=" -> return false // TODO optimized float += variable "+=" -> return false // TODO optimized float += variable
"-=" -> return false // TODO optimized float -= variable "-=" -> return false // TODO optimized float -= variable
else -> TODO("$assign") else -> TODO("float non-const value aug.assign $assign")
} }
return true return true
} }
@ -763,7 +837,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
sta $targetName+4 sta $targetName+4
""") """)
} }
else -> TODO("$assign") else -> TODO("float $assign")
} }
return true return true
} }
@ -773,6 +847,21 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
} }
} }
} }
DataType.STR -> {
val identifier = assign.value as? IdentifierReference
?: throw AssemblyError("string value assignment expects identifier value")
val sourceName = asmgen.asmIdentifierName(identifier)
asmgen.out("""
lda #<$targetName
sta ${C64Zeropage.SCRATCH_W1}
lda #>$targetName
sta ${C64Zeropage.SCRATCH_W1+1}
lda #<$sourceName
ldy #>$sourceName
jsr prog8_lib.strcpy
""")
return true
}
else -> throw AssemblyError("assignment to identifier: invalid target datatype: $targetType") else -> throw AssemblyError("assignment to identifier: invalid target datatype: $targetType")
} }
return false return false
@ -788,13 +877,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" lda #$hexValue") "setvalue" -> asmgen.out(" lda #$hexValue")
"+=" -> asmgen.out(" clc | adc #$hexValue") "+=" -> asmgen.out(" clc | adc #$hexValue")
"-=" -> asmgen.out(" sec | sbc #$hexValue") "-=" -> asmgen.out(" sec | sbc #$hexValue")
"/=" -> TODO("/=") "/=" -> TODO("A /= const $hexValue")
"*=" -> TODO("*=") "*=" -> TODO("A *= const $hexValue")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" and #$hexValue") "&=" -> asmgen.out(" and #$hexValue")
"|=" -> asmgen.out(" ora #$hexValue") "|=" -> asmgen.out(" ora #$hexValue")
"^=" -> asmgen.out(" eor #$hexValue") "^=" -> asmgen.out(" eor #$hexValue")
"%=" -> TODO("%=") "%=" -> TODO("A %= const $hexValue")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -805,13 +893,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" ldx #$hexValue") "setvalue" -> asmgen.out(" ldx #$hexValue")
"+=" -> asmgen.out(" txa | clc | adc #$hexValue | tax") "+=" -> asmgen.out(" txa | clc | adc #$hexValue | tax")
"-=" -> asmgen.out(" txa | sec | sbc #$hexValue | tax") "-=" -> asmgen.out(" txa | sec | sbc #$hexValue | tax")
"/=" -> TODO("/=") "/=" -> TODO("X /= const $hexValue")
"*=" -> TODO("*=") "*=" -> TODO("X *= const $hexValue")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" txa | and #$hexValue | tax") "&=" -> asmgen.out(" txa | and #$hexValue | tax")
"|=" -> asmgen.out(" txa | ora #$hexValue | tax") "|=" -> asmgen.out(" txa | ora #$hexValue | tax")
"^=" -> asmgen.out(" txa | eor #$hexValue | tax") "^=" -> asmgen.out(" txa | eor #$hexValue | tax")
"%=" -> TODO("%=") "%=" -> TODO("X %= const $hexValue")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -822,13 +909,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" ldy #$hexValue") "setvalue" -> asmgen.out(" ldy #$hexValue")
"+=" -> asmgen.out(" tya | clc | adc #$hexValue | tay") "+=" -> asmgen.out(" tya | clc | adc #$hexValue | tay")
"-=" -> asmgen.out(" tya | sec | sbc #$hexValue | tay") "-=" -> asmgen.out(" tya | sec | sbc #$hexValue | tay")
"/=" -> TODO("/=") "/=" -> TODO("Y /= const $hexValue")
"*=" -> TODO("*=") "*=" -> TODO("Y *= const $hexValue")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" tya | and #$hexValue | tay") "&=" -> asmgen.out(" tya | and #$hexValue | tay")
"|=" -> asmgen.out(" tya | ora #$hexValue | tay") "|=" -> asmgen.out(" tya | ora #$hexValue | tay")
"^=" -> asmgen.out(" tya | eor #$hexValue | tay") "^=" -> asmgen.out(" tya | eor #$hexValue | tay")
"%=" -> TODO("%=") "%=" -> TODO("Y %= const $hexValue")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -848,13 +934,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" lda $sourceName") "setvalue" -> asmgen.out(" lda $sourceName")
"+=" -> asmgen.out(" clc | adc $sourceName") "+=" -> asmgen.out(" clc | adc $sourceName")
"-=" -> asmgen.out(" sec | sbc $sourceName") "-=" -> asmgen.out(" sec | sbc $sourceName")
"/=" -> TODO("/=") "/=" -> TODO("A /= variable")
"*=" -> TODO("*=") "*=" -> TODO("A *= variable")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" and $sourceName") "&=" -> asmgen.out(" and $sourceName")
"|=" -> asmgen.out(" ora $sourceName") "|=" -> asmgen.out(" ora $sourceName")
"^=" -> asmgen.out(" eor $sourceName") "^=" -> asmgen.out(" eor $sourceName")
"%=" -> TODO("%=") "%=" -> TODO("A %= variable")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -865,13 +950,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" ldx $sourceName") "setvalue" -> asmgen.out(" ldx $sourceName")
"+=" -> asmgen.out(" txa | clc | adc $sourceName | tax") "+=" -> asmgen.out(" txa | clc | adc $sourceName | tax")
"-=" -> asmgen.out(" txa | sec | sbc $sourceName | tax") "-=" -> asmgen.out(" txa | sec | sbc $sourceName | tax")
"/=" -> TODO("/=") "/=" -> TODO("X /= variable")
"*=" -> TODO("*=") "*=" -> TODO("X *= variable")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" txa | and $sourceName | tax") "&=" -> asmgen.out(" txa | and $sourceName | tax")
"|=" -> asmgen.out(" txa | ora $sourceName | tax") "|=" -> asmgen.out(" txa | ora $sourceName | tax")
"^=" -> asmgen.out(" txa | eor $sourceName | tax") "^=" -> asmgen.out(" txa | eor $sourceName | tax")
"%=" -> TODO("%=") "%=" -> TODO("X %= variable")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -882,13 +966,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" ldy $sourceName") "setvalue" -> asmgen.out(" ldy $sourceName")
"+=" -> asmgen.out(" tya | clc | adc $sourceName | tay") "+=" -> asmgen.out(" tya | clc | adc $sourceName | tay")
"-=" -> asmgen.out(" tya | sec | sbc $sourceName | tay") "-=" -> asmgen.out(" tya | sec | sbc $sourceName | tay")
"/=" -> TODO("/=") "/=" -> TODO("Y /= variable")
"*=" -> TODO("*=") "*=" -> TODO("Y *= variable")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" tya | and $sourceName | tay") "&=" -> asmgen.out(" tya | and $sourceName | tay")
"|=" -> asmgen.out(" tya | ora $sourceName | tay") "|=" -> asmgen.out(" tya | ora $sourceName | tay")
"^=" -> asmgen.out(" tya | eor $sourceName | tay") "^=" -> asmgen.out(" tya | eor $sourceName | tay")
"%=" -> TODO("%=") "%=" -> TODO("Y %= variable")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -903,8 +986,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
when ((assign.value as RegisterExpr).register) { when ((assign.value as RegisterExpr).register) {
Register.A -> { Register.A -> {
when (assign.target.register!!) { when (assign.target.register!!) {
Register.A -> { Register.A -> {}
}
Register.X -> asmgen.out(" tax") Register.X -> asmgen.out(" tax")
Register.Y -> asmgen.out(" tay") Register.Y -> asmgen.out(" tay")
} }
@ -912,8 +994,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
Register.X -> { Register.X -> {
when (assign.target.register!!) { when (assign.target.register!!) {
Register.A -> asmgen.out(" txa") Register.A -> asmgen.out(" txa")
Register.X -> { Register.X -> {}
}
Register.Y -> asmgen.out(" txa | tay") Register.Y -> asmgen.out(" txa | tay")
} }
} }
@ -921,8 +1002,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
when (assign.target.register!!) { when (assign.target.register!!) {
Register.A -> asmgen.out(" tya") Register.A -> asmgen.out(" tya")
Register.X -> asmgen.out(" tya | tax") Register.X -> asmgen.out(" tya | tax")
Register.Y -> { Register.Y -> {}
}
} }
} }
} }
@ -977,9 +1057,8 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
} }
} }
} }
"/=" -> TODO("/=") "/=" -> TODO("register /= register")
"*=" -> TODO("*=") "*=" -> TODO("register *= register")
"**=" -> TODO("**=")
"&=" -> { "&=" -> {
when ((assign.value as RegisterExpr).register) { when ((assign.value as RegisterExpr).register) {
Register.A -> { Register.A -> {
@ -1005,9 +1084,9 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
} }
} }
} }
"|=" -> TODO() "|=" -> TODO("register |= register")
"^=" -> TODO() "^=" -> TODO("register ^= register")
"%=" -> TODO("%=") "%=" -> TODO("register %= register")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -1024,13 +1103,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" lda $hexAddr") "setvalue" -> asmgen.out(" lda $hexAddr")
"+=" -> asmgen.out(" clc | adc $hexAddr") "+=" -> asmgen.out(" clc | adc $hexAddr")
"-=" -> asmgen.out(" sec | sbc $hexAddr") "-=" -> asmgen.out(" sec | sbc $hexAddr")
"/=" -> TODO("/=") "/=" -> TODO("A /= memory $hexAddr")
"*=" -> TODO("*=") "*=" -> TODO("A *= memory $hexAddr")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" and $hexAddr") "&=" -> asmgen.out(" and $hexAddr")
"|=" -> asmgen.out(" ora $hexAddr") "|=" -> asmgen.out(" ora $hexAddr")
"^=" -> asmgen.out(" eor $hexAddr") "^=" -> asmgen.out(" eor $hexAddr")
"%=" -> TODO("%=") "%=" -> TODO("A %= memory $hexAddr")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -1041,13 +1119,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" ldx $hexAddr") "setvalue" -> asmgen.out(" ldx $hexAddr")
"+=" -> asmgen.out(" txa | clc | adc $hexAddr | tax") "+=" -> asmgen.out(" txa | clc | adc $hexAddr | tax")
"-=" -> asmgen.out(" txa | sec | sbc $hexAddr | tax") "-=" -> asmgen.out(" txa | sec | sbc $hexAddr | tax")
"/=" -> TODO("/=") "/=" -> TODO("X /= memory $hexAddr")
"*=" -> TODO("*=") "*=" -> TODO("X *= memory $hexAddr")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" txa | and $hexAddr | tax") "&=" -> asmgen.out(" txa | and $hexAddr | tax")
"|=" -> asmgen.out(" txa | ora $hexAddr | tax") "|=" -> asmgen.out(" txa | ora $hexAddr | tax")
"^=" -> asmgen.out(" txa | eor $hexAddr | tax") "^=" -> asmgen.out(" txa | eor $hexAddr | tax")
"%=" -> TODO("%=") "%=" -> TODO("X %= memory $hexAddr")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -1058,13 +1135,12 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
"setvalue" -> asmgen.out(" ldy $hexAddr") "setvalue" -> asmgen.out(" ldy $hexAddr")
"+=" -> asmgen.out(" tya | clc | adc $hexAddr | tay") "+=" -> asmgen.out(" tya | clc | adc $hexAddr | tay")
"-=" -> asmgen.out(" tya | sec | sbc $hexAddr | tay") "-=" -> asmgen.out(" tya | sec | sbc $hexAddr | tay")
"/=" -> TODO("/=") "/=" -> TODO("Y /= memory $hexAddr")
"*=" -> TODO("*=") "*=" -> TODO("Y *= memory $hexAddr")
"**=" -> TODO("**=")
"&=" -> asmgen.out(" tya | and $hexAddr | tay") "&=" -> asmgen.out(" tya | and $hexAddr | tay")
"|=" -> asmgen.out(" tya | ora $hexAddr | tay") "|=" -> asmgen.out(" tya | ora $hexAddr | tay")
"^=" -> asmgen.out(" tya | eor $hexAddr | tay") "^=" -> asmgen.out(" tya | eor $hexAddr | tay")
"%=" -> TODO("%=") "%=" -> TODO("Y %= memory $hexAddr")
"<<=" -> throw AssemblyError("<<= should have been replaced by lsl()") "<<=" -> throw AssemblyError("<<= should have been replaced by lsl()")
">>=" -> throw AssemblyError("<<= should have been replaced by lsr()") ">>=" -> throw AssemblyError("<<= should have been replaced by lsr()")
else -> throw AssemblyError("invalid aug_op ${assign.aug_op}") else -> throw AssemblyError("invalid aug_op ${assign.aug_op}")
@ -1074,9 +1150,6 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
return true return true
} }
} }
is AddressOf -> {
TODO("$assign")
}
is ArrayIndexedExpression -> { is ArrayIndexedExpression -> {
if (assign.aug_op == "setvalue") { if (assign.aug_op == "setvalue") {
val arrayExpr = assign.value as ArrayIndexedExpression val arrayExpr = assign.value as ArrayIndexedExpression
@ -1102,6 +1175,7 @@ internal class AssignmentAsmGen(private val program: Program, private val errors
} }
return true return true
} }
is AddressOf -> throw AssemblyError("can't load a memory address into a register")
else -> { else -> {
fallbackAssignment(assign) fallbackAssignment(assign)
return true return true

View File

@ -2,12 +2,8 @@ package prog8.compiler.target.c64.codegen
import prog8.ast.IFunctionCall import prog8.ast.IFunctionCall
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.ByteDatatypes import prog8.ast.base.*
import prog8.ast.base.DataType
import prog8.ast.base.Register
import prog8.ast.base.WordDatatypes
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.AssignTarget
import prog8.ast.statements.FunctionCallStatement import prog8.ast.statements.FunctionCallStatement
import prog8.compiler.AssemblyError import prog8.compiler.AssemblyError
import prog8.compiler.target.c64.C64MachineDefinition.C64Zeropage import prog8.compiler.target.c64.C64MachineDefinition.C64Zeropage
@ -581,32 +577,32 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, private val
ldy $firstName ldy $firstName
lda $secondName lda $secondName
sta $firstName sta $firstName
tya sty $secondName
sta $secondName
ldy $firstName+1 ldy $firstName+1
lda $secondName+1 lda $secondName+1
sta $firstName+1 sta $firstName+1
tya sty $secondName+1
sta $secondName+1
""") """)
return return
} }
if(dt.istype(DataType.FLOAT)) { if(dt.istype(DataType.FLOAT)) {
TODO("optimized case for swapping 2 float vars-- asm subroutine") asmgen.out("""
lda #<$firstName
sta ${C64Zeropage.SCRATCH_W1}
lda #>$firstName
sta ${C64Zeropage.SCRATCH_W1+1}
lda #<$secondName
sta ${C64Zeropage.SCRATCH_W2}
lda #>$secondName
sta ${C64Zeropage.SCRATCH_W2+1}
jsr c64flt.swap_floats
""")
return return
} }
} }
// TODO more optimized cases? for instance swapping elements of array vars? // other types of swap() calls should have been replaced by a different statement sequence involving a temp variable
throw AssemblyError("no asm generation for swap funccall $fcall")
// suboptimal code via the evaluation stack...
asmgen.translateExpression(first)
asmgen.translateExpression(second)
// pop in reverse order
val firstTarget = AssignTarget.fromExpr(first)
val secondTarget = AssignTarget.fromExpr(second)
asmgen.assignFromEvalResult(firstTarget)
asmgen.assignFromEvalResult(secondTarget)
} }
private fun funcAbs(fcall: IFunctionCall, func: FSignature) { private fun funcAbs(fcall: IFunctionCall, func: FSignature) {

View File

@ -28,7 +28,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
is RegisterExpr -> translateExpression(expression) is RegisterExpr -> translateExpression(expression)
is IdentifierReference -> translateExpression(expression) is IdentifierReference -> translateExpression(expression)
is FunctionCall -> translateExpression(expression) is FunctionCall -> translateExpression(expression)
is ArrayLiteralValue, is StringLiteralValue -> throw AssemblyError("no asm gen for string/array assignment") is ArrayLiteralValue, is StringLiteralValue -> throw AssemblyError("no asm gen for string/array literal value assignment - should have been replaced by a variable")
is StructLiteralValue -> throw AssemblyError("struct literal value assignment should have been flattened") is StructLiteralValue -> throw AssemblyError("struct literal value assignment should have been flattened")
is RangeExpr -> throw AssemblyError("range expression should have been changed into array values") is RangeExpr -> throw AssemblyError("range expression should have been changed into array values")
} }
@ -240,16 +240,26 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
} }
} }
DataType.UWORD -> { DataType.UWORD -> {
if(amount<=2) var left = amount
repeat(amount) { asmgen.out(" lsr $ESTACK_HI_PLUS1_HEX,x | ror $ESTACK_LO_PLUS1_HEX,x") } while(left>=7) {
asmgen.out(" jsr math.shift_right_uw_7")
left -= 7
}
if (left in 0..2)
repeat(left) { asmgen.out(" lsr $ESTACK_HI_PLUS1_HEX,x | ror $ESTACK_LO_PLUS1_HEX,x") }
else else
asmgen.out(" jsr math.shift_right_uw_$amount") // 3-7 (8+ is done via other optimizations) asmgen.out(" jsr math.shift_right_uw_$left")
} }
DataType.WORD -> { DataType.WORD -> {
if(amount<=2) var left = amount
repeat(amount) { asmgen.out(" lda $ESTACK_HI_PLUS1_HEX,x | asl a | ror $ESTACK_HI_PLUS1_HEX,x | ror $ESTACK_LO_PLUS1_HEX,x") } while(left>=7) {
asmgen.out(" jsr math.shift_right_w_7")
left -= 7
}
if (left in 0..2)
repeat(left) { asmgen.out(" lda $ESTACK_HI_PLUS1_HEX,x | asl a | ror $ESTACK_HI_PLUS1_HEX,x | ror $ESTACK_LO_PLUS1_HEX,x") }
else else
asmgen.out(" jsr math.shift_right_w_$amount") // 3-7 (8+ is done via other optimizations) asmgen.out(" jsr math.shift_right_w_$left")
} }
else -> throw AssemblyError("weird type") else -> throw AssemblyError("weird type")
} }
@ -269,11 +279,15 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
} }
} }
else { else {
if(amount<=2) { var left=amount
repeat(amount) { asmgen.out(" asl $ESTACK_LO_PLUS1_HEX,x | rol $ESTACK_HI_PLUS1_HEX,x") } while(left>=7) {
} else { asmgen.out(" jsr math.shift_left_w_7")
asmgen.out(" jsr math.shift_left_w_$amount") // 3-7 (8+ is done via other optimizations) left -= 7
} }
if (left in 0..2)
repeat(left) { asmgen.out(" asl $ESTACK_LO_PLUS1_HEX,x | rol $ESTACK_HI_PLUS1_HEX,x") }
else
asmgen.out(" jsr math.shift_left_w_$left")
} }
return return
} }

View File

@ -16,7 +16,7 @@ import prog8.compiler.toHex
import kotlin.math.absoluteValue import kotlin.math.absoluteValue
// todo choose more efficient comparisons to avoid needless lda's // todo choose more efficient comparisons to avoid needless lda's
// todo optimize common case step == 2 / -2 // todo optimize common case when step == 2 or -2
internal class ForLoopsAsmGen(private val program: Program, private val asmgen: AsmGen) { internal class ForLoopsAsmGen(private val program: Program, private val asmgen: AsmGen) {
@ -37,7 +37,7 @@ internal class ForLoopsAsmGen(private val program: Program, private val asmgen:
is IdentifierReference -> { is IdentifierReference -> {
translateForOverIterableVar(stmt, iterableDt.typeOrElse(DataType.STRUCT), stmt.iterable as IdentifierReference) translateForOverIterableVar(stmt, iterableDt.typeOrElse(DataType.STRUCT), stmt.iterable as IdentifierReference)
} }
else -> throw AssemblyError("can't iterate over ${stmt.iterable}") else -> throw AssemblyError("can't iterate over ${stmt.iterable.javaClass} - should have been replaced by a variable")
} }
} }
@ -339,7 +339,7 @@ $continueLabel inc $loopLabel+1
$endLabel""") $endLabel""")
} }
DataType.ARRAY_UB, DataType.ARRAY_B -> { DataType.ARRAY_UB, DataType.ARRAY_B -> {
// TODO: optimize loop code when the length of the array is < 256, don't need a separate counter in such cases // TODO: optimize loop code when the length of the array is < 256, don't need a separate counter var in such cases
val length = decl.arraysize!!.size()!! val length = decl.arraysize!!.size()!!
if(stmt.loopRegister!=null && stmt.loopRegister!= Register.A) if(stmt.loopRegister!=null && stmt.loopRegister!= Register.A)
throw AssemblyError("can only use A") throw AssemblyError("can only use A")
@ -366,7 +366,7 @@ $counterLabel .byte 0
$endLabel""") $endLabel""")
} }
DataType.ARRAY_W, DataType.ARRAY_UW -> { DataType.ARRAY_W, DataType.ARRAY_UW -> {
// TODO: optimize loop code when the length of the array is < 256, don't need a separate counter in such cases // TODO: optimize loop code when the length of the array is < 256, don't need a separate counter var in such cases
val length = decl.arraysize!!.size()!! * 2 val length = decl.arraysize!!.size()!! * 2
if(stmt.loopRegister!=null) if(stmt.loopRegister!=null)
throw AssemblyError("can't use register to loop over words") throw AssemblyError("can't use register to loop over words")
@ -410,7 +410,7 @@ $endLabel""")
} }
private fun translateForOverConstRange(stmt: ForLoop, iterableDt: DataType, range: IntProgression) { private fun translateForOverConstRange(stmt: ForLoop, iterableDt: DataType, range: IntProgression) {
// TODO: optimize loop code when the range is < 256 iterations, don't need a separate counter in such cases // TODO: optimize loop code when the range is < 256 iterations, don't need a separate counter var in such cases
if (range.isEmpty()) if (range.isEmpty())
throw AssemblyError("empty range") throw AssemblyError("empty range")
val loopLabel = asmgen.makeLabel("for_loop") val loopLabel = asmgen.makeLabel("for_loop")

View File

@ -87,7 +87,20 @@ val BuiltinFunctions = mapOf(
FParam("address", IterableDatatypes + DataType.UWORD), FParam("address", IterableDatatypes + DataType.UWORD),
FParam("numwords", setOf(DataType.UWORD)), FParam("numwords", setOf(DataType.UWORD)),
FParam("wordvalue", setOf(DataType.UWORD, DataType.WORD))), null), FParam("wordvalue", setOf(DataType.UWORD, DataType.WORD))), null),
"strlen" to FSignature(true, listOf(FParam("string", setOf(DataType.STR))), DataType.UBYTE, ::builtinStrlen) "strlen" to FSignature(true, listOf(FParam("string", setOf(DataType.STR))), DataType.UBYTE, ::builtinStrlen),
"substr" to FSignature(false, listOf(
FParam("source", IterableDatatypes + DataType.UWORD),
FParam("target", IterableDatatypes + DataType.UWORD),
FParam("start", setOf(DataType.UBYTE)),
FParam("length", setOf(DataType.UBYTE))), null),
"leftstr" to FSignature(false, listOf(
FParam("source", IterableDatatypes + DataType.UWORD),
FParam("target", IterableDatatypes + DataType.UWORD),
FParam("length", setOf(DataType.UBYTE))), null),
"rightstr" to FSignature(false, listOf(
FParam("source", IterableDatatypes + DataType.UWORD),
FParam("target", IterableDatatypes + DataType.UWORD),
FParam("length", setOf(DataType.UBYTE))), null)
) )
fun builtinMax(array: List<Number>): Number = array.maxBy { it.toDouble() }!! fun builtinMax(array: List<Number>): Number = array.maxBy { it.toDouble() }!!
@ -172,6 +185,7 @@ fun builtinFunctionReturnType(function: String, args: List<Expression>, program:
class NotConstArgumentException: AstException("not a const argument to a built-in function") class NotConstArgumentException: AstException("not a const argument to a built-in function")
class CannotEvaluateException(func:String, msg: String): FatalAstException("cannot evaluate built-in function $func: $msg")
private fun oneDoubleArg(args: List<Expression>, position: Position, program: Program, function: (arg: Double)->Number): NumericLiteralValue { private fun oneDoubleArg(args: List<Expression>, position: Position, program: Program, function: (arg: Double)->Number): NumericLiteralValue {
@ -252,17 +266,22 @@ private fun builtinLen(args: List<Expression>, position: Position, program: Prog
return NumericLiteralValue.optimalInteger((args[0] as ArrayLiteralValue).value.size, position) return NumericLiteralValue.optimalInteger((args[0] as ArrayLiteralValue).value.size, position)
if(args[0] !is IdentifierReference) if(args[0] !is IdentifierReference)
throw SyntaxError("len argument should be an identifier, but is ${args[0]}", position) throw SyntaxError("len argument should be an identifier, but is ${args[0]}", position)
val target = (args[0] as IdentifierReference).targetVarDecl(program.namespace)!! val target = (args[0] as IdentifierReference).targetVarDecl(program.namespace)
?: throw CannotEvaluateException("len", "no target vardecl")
return when(target.datatype) { return when(target.datatype) {
DataType.ARRAY_UB, DataType.ARRAY_B, DataType.ARRAY_UW, DataType.ARRAY_W -> { DataType.ARRAY_UB, DataType.ARRAY_B, DataType.ARRAY_UW, DataType.ARRAY_W -> {
arraySize = target.arraysize!!.size()!! arraySize = target.arraysize?.size()
if(arraySize==null)
throw CannotEvaluateException("len", "arraysize unknown")
if(arraySize>256) if(arraySize>256)
throw CompilerException("array length exceeds byte limit ${target.position}") throw CompilerException("array length exceeds byte limit ${target.position}")
NumericLiteralValue.optimalInteger(arraySize, args[0].position) NumericLiteralValue.optimalInteger(arraySize, args[0].position)
} }
DataType.ARRAY_F -> { DataType.ARRAY_F -> {
arraySize = target.arraysize!!.size()!! arraySize = target.arraysize?.size()
if(arraySize==null)
throw CannotEvaluateException("len", "arraysize unknown")
if(arraySize>256) if(arraySize>256)
throw CompilerException("array length exceeds byte limit ${target.position}") throw CompilerException("array length exceeds byte limit ${target.position}")
NumericLiteralValue.optimalInteger(arraySize, args[0].position) NumericLiteralValue.optimalInteger(arraySize, args[0].position)

View File

@ -14,6 +14,7 @@ import prog8.ast.statements.PostIncrDecr
internal class AssignmentTransformer(val program: Program, val errors: ErrorReporter) : AstWalker() { internal class AssignmentTransformer(val program: Program, val errors: ErrorReporter) : AstWalker() {
var optimizationsDone: Int = 0 var optimizationsDone: Int = 0
private val noModifications = emptyList<IAstModification>()
override fun before(assignment: Assignment, parent: Node): Iterable<IAstModification> { override fun before(assignment: Assignment, parent: Node): Iterable<IAstModification> {
// modify A = A + 5 back into augmented form A += 5 for easier code generation for optimized in-place assignments // modify A = A + 5 back into augmented form A += 5 for easier code generation for optimized in-place assignments
@ -27,7 +28,7 @@ internal class AssignmentTransformer(val program: Program, val errors: ErrorRepo
assignment.aug_op = binExpr.operator + "=" assignment.aug_op = binExpr.operator + "="
assignment.value.parent = assignment assignment.value.parent = assignment
optimizationsDone++ optimizationsDone++
return emptyList() return noModifications
} }
} }
assignment.aug_op = "setvalue" assignment.aug_op = "setvalue"
@ -151,6 +152,6 @@ internal class AssignmentTransformer(val program: Program, val errors: ErrorRepo
} }
} }
} }
return emptyList() return noModifications
} }
} }

View File

@ -24,10 +24,10 @@ private val asmRefRx = Regex("""[\-+a-zA-Z0-9_ \t]+(...)[ \t]+(\S+).*""", RegexO
class CallGraph(private val program: Program) : IAstVisitor { class CallGraph(private val program: Program) : IAstVisitor {
val modulesImporting = mutableMapOf<Module, List<Module>>().withDefault { mutableListOf() } val imports = mutableMapOf<Module, List<Module>>().withDefault { mutableListOf() }
val modulesImportedBy = mutableMapOf<Module, List<Module>>().withDefault { mutableListOf() } val importedBy = mutableMapOf<Module, List<Module>>().withDefault { mutableListOf() }
val subroutinesCalling = mutableMapOf<INameScope, List<Subroutine>>().withDefault { mutableListOf() } val calls = mutableMapOf<INameScope, List<Subroutine>>().withDefault { mutableListOf() }
val subroutinesCalledBy = mutableMapOf<Subroutine, List<Node>>().withDefault { mutableListOf() } val calledBy = mutableMapOf<Subroutine, List<Node>>().withDefault { mutableListOf() }
// TODO add dataflow graph: what statements use what variables - can be used to eliminate unused vars // TODO add dataflow graph: what statements use what variables - can be used to eliminate unused vars
val usedSymbols = mutableSetOf<Statement>() val usedSymbols = mutableSetOf<Statement>()
@ -55,17 +55,8 @@ class CallGraph(private val program: Program) : IAstVisitor {
it.importedBy.clear() it.importedBy.clear()
it.imports.clear() it.imports.clear()
it.importedBy.addAll(modulesImportedBy.getValue(it)) it.importedBy.addAll(importedBy.getValue(it))
it.imports.addAll(modulesImporting.getValue(it)) it.imports.addAll(imports.getValue(it))
forAllSubroutines(it) { sub ->
sub.calledBy.clear()
sub.calls.clear()
sub.calledBy.addAll(subroutinesCalledBy.getValue(sub))
sub.calls.addAll(subroutinesCalling.getValue(sub))
}
} }
val rootmodule = program.modules.first() val rootmodule = program.modules.first()
@ -85,8 +76,8 @@ class CallGraph(private val program: Program) : IAstVisitor {
val thisModule = directive.definingModule() val thisModule = directive.definingModule()
if (directive.directive == "%import") { if (directive.directive == "%import") {
val importedModule: Module = program.modules.single { it.name == directive.args[0].name } val importedModule: Module = program.modules.single { it.name == directive.args[0].name }
modulesImporting[thisModule] = modulesImporting.getValue(thisModule).plus(importedModule) imports[thisModule] = imports.getValue(thisModule).plus(importedModule)
modulesImportedBy[importedModule] = modulesImportedBy.getValue(importedModule).plus(thisModule) importedBy[importedModule] = importedBy.getValue(importedModule).plus(thisModule)
} else if (directive.directive == "%asminclude") { } else if (directive.directive == "%asminclude") {
val asm = loadAsmIncludeFile(directive.args[0].str!!, thisModule.source) val asm = loadAsmIncludeFile(directive.args[0].str!!, thisModule.source)
val scope = directive.definingScope() val scope = directive.definingScope()
@ -141,8 +132,8 @@ class CallGraph(private val program: Program) : IAstVisitor {
val otherSub = functionCall.target.targetSubroutine(program.namespace) val otherSub = functionCall.target.targetSubroutine(program.namespace)
if (otherSub != null) { if (otherSub != null) {
functionCall.definingSubroutine()?.let { thisSub -> functionCall.definingSubroutine()?.let { thisSub ->
subroutinesCalling[thisSub] = subroutinesCalling.getValue(thisSub).plus(otherSub) calls[thisSub] = calls.getValue(thisSub).plus(otherSub)
subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(functionCall) calledBy[otherSub] = calledBy.getValue(otherSub).plus(functionCall)
} }
} }
super.visit(functionCall) super.visit(functionCall)
@ -152,8 +143,8 @@ class CallGraph(private val program: Program) : IAstVisitor {
val otherSub = functionCallStatement.target.targetSubroutine(program.namespace) val otherSub = functionCallStatement.target.targetSubroutine(program.namespace)
if (otherSub != null) { if (otherSub != null) {
functionCallStatement.definingSubroutine()?.let { thisSub -> functionCallStatement.definingSubroutine()?.let { thisSub ->
subroutinesCalling[thisSub] = subroutinesCalling.getValue(thisSub).plus(otherSub) calls[thisSub] = calls.getValue(thisSub).plus(otherSub)
subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(functionCallStatement) calledBy[otherSub] = calledBy.getValue(otherSub).plus(functionCallStatement)
} }
} }
super.visit(functionCallStatement) super.visit(functionCallStatement)
@ -163,8 +154,8 @@ class CallGraph(private val program: Program) : IAstVisitor {
val otherSub = jump.identifier?.targetSubroutine(program.namespace) val otherSub = jump.identifier?.targetSubroutine(program.namespace)
if (otherSub != null) { if (otherSub != null) {
jump.definingSubroutine()?.let { thisSub -> jump.definingSubroutine()?.let { thisSub ->
subroutinesCalling[thisSub] = subroutinesCalling.getValue(thisSub).plus(otherSub) calls[thisSub] = calls.getValue(thisSub).plus(otherSub)
subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(jump) calledBy[otherSub] = calledBy.getValue(otherSub).plus(jump)
} }
} }
super.visit(jump) super.visit(jump)
@ -190,14 +181,14 @@ class CallGraph(private val program: Program) : IAstVisitor {
if (jumptarget != null && (jumptarget[0].isLetter() || jumptarget[0] == '_')) { if (jumptarget != null && (jumptarget[0].isLetter() || jumptarget[0] == '_')) {
val node = program.namespace.lookup(jumptarget.split('.'), context) val node = program.namespace.lookup(jumptarget.split('.'), context)
if (node is Subroutine) { if (node is Subroutine) {
subroutinesCalling[scope] = subroutinesCalling.getValue(scope).plus(node) calls[scope] = calls.getValue(scope).plus(node)
subroutinesCalledBy[node] = subroutinesCalledBy.getValue(node).plus(context) calledBy[node] = calledBy.getValue(node).plus(context)
} else if (jumptarget.contains('.')) { } else if (jumptarget.contains('.')) {
// maybe only the first part already refers to a subroutine // maybe only the first part already refers to a subroutine
val node2 = program.namespace.lookup(listOf(jumptarget.substringBefore('.')), context) val node2 = program.namespace.lookup(listOf(jumptarget.substringBefore('.')), context)
if (node2 is Subroutine) { if (node2 is Subroutine) {
subroutinesCalling[scope] = subroutinesCalling.getValue(scope).plus(node2) calls[scope] = calls.getValue(scope).plus(node2)
subroutinesCalledBy[node2] = subroutinesCalledBy.getValue(node2).plus(context) calledBy[node2] = calledBy.getValue(node2).plus(context)
} }
} }
} }
@ -209,8 +200,8 @@ class CallGraph(private val program: Program) : IAstVisitor {
if (target.contains('.')) { if (target.contains('.')) {
val node = program.namespace.lookup(listOf(target.substringBefore('.')), context) val node = program.namespace.lookup(listOf(target.substringBefore('.')), context)
if (node is Subroutine) { if (node is Subroutine) {
subroutinesCalling[scope] = subroutinesCalling.getValue(scope).plus(node) calls[scope] = calls.getValue(scope).plus(node)
subroutinesCalledBy[node] = subroutinesCalledBy.getValue(node).plus(context) calledBy[node] = calledBy.getValue(node).plus(context)
} }
} }
} }

View File

@ -1,25 +1,46 @@
package prog8.optimizer package prog8.optimizer
import prog8.ast.IFunctionCall import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.AstWalker
import prog8.ast.processing.IAstModification
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.compiler.target.CompilationTarget import prog8.compiler.target.CompilationTarget
import prog8.functions.BuiltinFunctions
// TODO implement using AstWalker instead of IAstModifyingVisitor // First thing to do is replace all constant identifiers with their actual value,
internal class ConstantFoldingOptimizer(private val program: Program, private val errors: ErrorReporter) : IAstModifyingVisitor { // and the array var initializer values and sizes.
var optimizationsDone: Int = 0 // This is needed because further constant optimizations depend on those.
internal class ConstantIdentifierReplacer(private val program: Program, private val errors: ErrorReporter) : AstWalker() {
private val noModifications = emptyList<IAstModification>()
override fun visit(decl: VarDecl): Statement { override fun after(identifier: IdentifierReference, parent: Node): Iterable<IAstModification> {
// replace identifiers that refer to const value, with the value itself
// if it's a simple type and if it's not a left hand side variable
if(identifier.parent is AssignTarget)
return noModifications
var forloop = identifier.parent as? ForLoop
if(forloop==null)
forloop = identifier.parent.parent as? ForLoop
if(forloop!=null && identifier===forloop.loopVar)
return noModifications
val cval = identifier.constValue(program) ?: return noModifications
return when (cval.type) {
in NumericDatatypes -> listOf(IAstModification.ReplaceNode(identifier, NumericLiteralValue(cval.type, cval.number, identifier.position), identifier.parent))
in PassByReferenceDatatypes -> throw FatalAstException("pass-by-reference type should not be considered a constant")
else -> noModifications
}
}
override fun before(decl: VarDecl, parent: Node): Iterable<IAstModification> {
// the initializer value can't refer to the variable itself (recursive definition) // the initializer value can't refer to the variable itself (recursive definition)
// TODO: use call graph for this? // TODO: use call graph for this?
if(decl.value?.referencesIdentifiers(decl.name) == true || decl.arraysize?.index?.referencesIdentifiers(decl.name) == true) { if(decl.value?.referencesIdentifiers(decl.name) == true || decl.arraysize?.index?.referencesIdentifiers(decl.name) == true) {
errors.err("recursive var declaration", decl.position) errors.err("recursive var declaration", decl.position)
return decl return noModifications
} }
if(decl.type==VarDeclType.CONST || decl.type==VarDeclType.VAR) { if(decl.type==VarDeclType.CONST || decl.type==VarDeclType.VAR) {
@ -28,15 +49,20 @@ internal class ConstantFoldingOptimizer(private val program: Program, private va
// for arrays that have no size specifier (or a non-constant one) attempt to deduce the size // for arrays that have no size specifier (or a non-constant one) attempt to deduce the size
val arrayval = decl.value as? ArrayLiteralValue val arrayval = decl.value as? ArrayLiteralValue
if(arrayval!=null) { if(arrayval!=null) {
decl.arraysize = ArrayIndex(NumericLiteralValue.optimalInteger(arrayval.value.size, decl.position), decl.position) return listOf(IAstModification.SetExpression(
optimizationsDone++ { decl.arraysize = ArrayIndex(it, decl.position) },
NumericLiteralValue.optimalInteger(arrayval.value.size, decl.position),
decl
))
} }
} }
else if(decl.arraysize?.size()==null) { else if(decl.arraysize?.size()==null) {
val size = decl.arraysize!!.index.accept(this) val size = decl.arraysize!!.index.constValue(program)
if(size is NumericLiteralValue) { if(size!=null) {
decl.arraysize = ArrayIndex(size, decl.position) return listOf(IAstModification.SetExpression(
optimizationsDone++ { decl.arraysize = ArrayIndex(it, decl.position) },
size, decl
))
} }
} }
} }
@ -47,9 +73,7 @@ internal class ConstantFoldingOptimizer(private val program: Program, private va
val litval = decl.value as? NumericLiteralValue val litval = decl.value as? NumericLiteralValue
if (litval!=null && litval.type in IntegerDatatypes) { if (litval!=null && litval.type in IntegerDatatypes) {
val newValue = NumericLiteralValue(DataType.FLOAT, litval.number.toDouble(), litval.position) val newValue = NumericLiteralValue(DataType.FLOAT, litval.number.toDouble(), litval.position)
decl.value = newValue return listOf(IAstModification.ReplaceNode(decl.value!!, newValue, decl))
optimizationsDone++
return super.visit(decl)
} }
} }
DataType.ARRAY_UB, DataType.ARRAY_B, DataType.ARRAY_UW, DataType.ARRAY_W -> { DataType.ARRAY_UB, DataType.ARRAY_B, DataType.ARRAY_UW, DataType.ARRAY_W -> {
@ -63,23 +87,21 @@ internal class ConstantFoldingOptimizer(private val program: Program, private va
val constRange = rangeExpr.toConstantIntegerRange() val constRange = rangeExpr.toConstantIntegerRange()
if(constRange!=null) { if(constRange!=null) {
val eltType = rangeExpr.inferType(program).typeOrElse(DataType.UBYTE) val eltType = rangeExpr.inferType(program).typeOrElse(DataType.UBYTE)
if(eltType in ByteDatatypes) { val newValue = if(eltType in ByteDatatypes) {
decl.value = ArrayLiteralValue(InferredTypes.InferredType.known(decl.datatype), ArrayLiteralValue(InferredTypes.InferredType.known(decl.datatype),
constRange.map { NumericLiteralValue(eltType, it.toShort(), decl.value!!.position) }.toTypedArray(), constRange.map { NumericLiteralValue(eltType, it.toShort(), decl.value!!.position) }.toTypedArray(),
position = decl.value!!.position) position = decl.value!!.position)
} else { } else {
decl.value = ArrayLiteralValue(InferredTypes.InferredType.known(decl.datatype), ArrayLiteralValue(InferredTypes.InferredType.known(decl.datatype),
constRange.map { NumericLiteralValue(eltType, it, decl.value!!.position) }.toTypedArray(), constRange.map { NumericLiteralValue(eltType, it, decl.value!!.position) }.toTypedArray(),
position = decl.value!!.position) position = decl.value!!.position)
} }
decl.value!!.linkParents(decl) return listOf(IAstModification.ReplaceNode(decl.value!!, newValue, decl))
optimizationsDone++
return super.visit(decl)
} }
} }
if(numericLv!=null && numericLv.type== DataType.FLOAT) if(numericLv!=null && numericLv.type==DataType.FLOAT)
errors.err("arraysize requires only integers here", numericLv.position) errors.err("arraysize requires only integers here", numericLv.position)
val size = decl.arraysize?.size() ?: return decl val size = decl.arraysize?.size() ?: return noModifications
if (rangeExpr==null && numericLv!=null) { if (rangeExpr==null && numericLv!=null) {
// arraysize initializer is empty or a single int, and we know the size; create the arraysize. // arraysize initializer is empty or a single int, and we know the size; create the arraysize.
val fillvalue = numericLv.number.toInt() val fillvalue = numericLv.number.toInt()
@ -105,19 +127,27 @@ internal class ConstantFoldingOptimizer(private val program: Program, private va
// create the array itself, filled with the fillvalue. // create the array itself, filled with the fillvalue.
val array = Array(size) {fillvalue}.map { NumericLiteralValue(ArrayElementTypes.getValue(decl.datatype), it, numericLv.position) as Expression}.toTypedArray() val array = Array(size) {fillvalue}.map { NumericLiteralValue(ArrayElementTypes.getValue(decl.datatype), it, numericLv.position) as Expression}.toTypedArray()
val refValue = ArrayLiteralValue(InferredTypes.InferredType.known(decl.datatype), array, position = numericLv.position) val refValue = ArrayLiteralValue(InferredTypes.InferredType.known(decl.datatype), array, position = numericLv.position)
decl.value = refValue return listOf(IAstModification.ReplaceNode(decl.value!!, refValue, decl))
refValue.parent=decl
optimizationsDone++
return super.visit(decl)
} }
} }
DataType.ARRAY_F -> { DataType.ARRAY_F -> {
val size = decl.arraysize?.size() ?: return decl val size = decl.arraysize?.size() ?: return noModifications
val litval = decl.value as? NumericLiteralValue val litval = decl.value as? NumericLiteralValue
if(litval==null) { val rangeExpr = decl.value as? RangeExpr
// there's no initialization value, but the size is known, so we're ok. if(rangeExpr!=null) {
return super.visit(decl) // convert the initializer range expression to an actual array of floats
} else { val declArraySize = decl.arraysize?.size()
if(declArraySize!=null && declArraySize!=rangeExpr.size())
errors.err("range expression size doesn't match declared array size", decl.value?.position!!)
val constRange = rangeExpr.toConstantIntegerRange()
if(constRange!=null) {
val newValue = ArrayLiteralValue(InferredTypes.InferredType.known(DataType.ARRAY_F),
constRange.map { NumericLiteralValue(DataType.FLOAT, it.toDouble(), decl.value!!.position) }.toTypedArray(),
position = decl.value!!.position)
return listOf(IAstModification.ReplaceNode(decl.value!!, newValue, decl))
}
}
if(rangeExpr==null && litval!=null) {
// arraysize initializer is a single int, and we know the size. // arraysize initializer is a single int, and we know the size.
val fillvalue = litval.number.toDouble() val fillvalue = litval.number.toDouble()
if (fillvalue < CompilationTarget.machine.FLOAT_MAX_NEGATIVE || fillvalue > CompilationTarget.machine.FLOAT_MAX_POSITIVE) if (fillvalue < CompilationTarget.machine.FLOAT_MAX_NEGATIVE || fillvalue > CompilationTarget.machine.FLOAT_MAX_POSITIVE)
@ -126,10 +156,7 @@ internal class ConstantFoldingOptimizer(private val program: Program, private va
// create the array itself, filled with the fillvalue. // create the array itself, filled with the fillvalue.
val array = Array(size) {fillvalue}.map { NumericLiteralValue(DataType.FLOAT, it, litval.position) as Expression}.toTypedArray() val array = Array(size) {fillvalue}.map { NumericLiteralValue(DataType.FLOAT, it, litval.position) as Expression}.toTypedArray()
val refValue = ArrayLiteralValue(InferredTypes.InferredType.known(DataType.ARRAY_F), array, position = litval.position) val refValue = ArrayLiteralValue(InferredTypes.InferredType.known(DataType.ARRAY_F), array, position = litval.position)
decl.value = refValue return listOf(IAstModification.ReplaceNode(decl.value!!, refValue, decl))
refValue.parent=decl
optimizationsDone++
return super.visit(decl)
} }
} }
} }
@ -144,135 +171,69 @@ internal class ConstantFoldingOptimizer(private val program: Program, private va
if(declValue!=null && decl.type==VarDeclType.VAR if(declValue!=null && decl.type==VarDeclType.VAR
&& declValue is NumericLiteralValue && !declValue.inferType(program).istype(decl.datatype)) { && declValue is NumericLiteralValue && !declValue.inferType(program).istype(decl.datatype)) {
// cast the numeric literal to the appropriate datatype of the variable // cast the numeric literal to the appropriate datatype of the variable
decl.value = declValue.cast(decl.datatype) return listOf(IAstModification.ReplaceNode(decl.value!!, declValue.cast(decl.datatype), decl))
} }
return super.visit(decl) return noModifications
} }
}
/**
* replace identifiers that refer to const value, with the value itself (if it's a simple type)
*/
override fun visit(identifier: IdentifierReference): Expression {
// don't replace when it's an assignment target or loop variable
if(identifier.parent is AssignTarget)
return identifier
var forloop = identifier.parent as? ForLoop
if(forloop==null)
forloop = identifier.parent.parent as? ForLoop
if(forloop!=null && identifier===forloop.loopVar)
return identifier
val cval = identifier.constValue(program) ?: return identifier internal class ConstantFoldingOptimizer(private val program: Program, private val errors: ErrorReporter) : AstWalker() {
return when (cval.type) { private val noModifications = emptyList<IAstModification>()
in NumericDatatypes -> {
val copy = NumericLiteralValue(cval.type, cval.number, identifier.position)
copy.parent = identifier.parent
copy
}
in PassByReferenceDatatypes -> throw FatalAstException("pass-by-reference type should not be considered a constant")
else -> identifier
}
}
override fun visit(functionCall: FunctionCall): Expression { override fun before(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> {
super.visit(functionCall)
typeCastConstArguments(functionCall)
return functionCall.constValue(program) ?: functionCall
}
override fun visit(functionCallStatement: FunctionCallStatement): Statement {
super.visit(functionCallStatement)
typeCastConstArguments(functionCallStatement)
return functionCallStatement
}
private fun typeCastConstArguments(functionCall: IFunctionCall) {
if(functionCall.target.nameInSource.size==1) {
val builtinFunction = BuiltinFunctions[functionCall.target.nameInSource.single()]
if(builtinFunction!=null) {
// match the arguments of a builtin function signature.
for(arg in functionCall.args.withIndex().zip(builtinFunction.parameters)) {
val possibleDts = arg.second.possibleDatatypes
val argConst = arg.first.value.constValue(program)
if(argConst!=null && argConst.type !in possibleDts) {
val convertedValue = argConst.cast(possibleDts.first())
functionCall.args[arg.first.index] = convertedValue
optimizationsDone++
}
}
return
}
}
// match the arguments of a subroutine.
val subroutine = functionCall.target.targetSubroutine(program.namespace)
if(subroutine!=null) {
// if types differ, try to typecast constant arguments to the function call to the desired data type of the parameter
for(arg in functionCall.args.withIndex().zip(subroutine.parameters)) {
val expectedDt = arg.second.type
val argConst = arg.first.value.constValue(program)
if(argConst!=null && argConst.type!=expectedDt) {
val convertedValue = argConst.cast(expectedDt)
functionCall.args[arg.first.index] = convertedValue
optimizationsDone++
}
}
}
}
override fun visit(memread: DirectMemoryRead): Expression {
// @( &thing ) --> thing // @( &thing ) --> thing
val addrOf = memread.addressExpression as? AddressOf val addrOf = memread.addressExpression as? AddressOf
if(addrOf!=null) return if(addrOf!=null)
return super.visit(addrOf.identifier) listOf(IAstModification.ReplaceNode(memread, addrOf.identifier, parent))
return super.visit(memread) else
noModifications
} }
/** override fun after(expr: PrefixExpression, parent: Node): Iterable<IAstModification> {
* Try to accept a unary prefix expression. // Try to turn a unary prefix expression into a single constant value.
* Compile-time constant sub expressions will be evaluated on the spot. // Compile-time constant sub expressions will be evaluated on the spot.
* For instance, the expression for "- 4.5" will be optimized into the float literal -4.5 // For instance, the expression for "- 4.5" will be optimized into the float literal -4.5
*/ val subexpr = expr.expression
override fun visit(expr: PrefixExpression): Expression {
val prefixExpr=super.visit(expr)
if(prefixExpr !is PrefixExpression)
return prefixExpr
val subexpr = prefixExpr.expression
if (subexpr is NumericLiteralValue) { if (subexpr is NumericLiteralValue) {
// accept prefixed literal values (such as -3, not true) // accept prefixed literal values (such as -3, not true)
return when (prefixExpr.operator) { return when (expr.operator) {
"+" -> subexpr "+" -> listOf(IAstModification.ReplaceNode(expr, subexpr, parent))
"-" -> when (subexpr.type) { "-" -> when (subexpr.type) {
in IntegerDatatypes -> { in IntegerDatatypes -> {
optimizationsDone++ listOf(IAstModification.ReplaceNode(expr,
NumericLiteralValue.optimalNumeric(-subexpr.number.toInt(), subexpr.position) NumericLiteralValue.optimalNumeric(-subexpr.number.toInt(), subexpr.position),
parent))
} }
DataType.FLOAT -> { DataType.FLOAT -> {
optimizationsDone++ listOf(IAstModification.ReplaceNode(expr,
NumericLiteralValue(DataType.FLOAT, -subexpr.number.toDouble(), subexpr.position) NumericLiteralValue(DataType.FLOAT, -subexpr.number.toDouble(), subexpr.position),
parent))
} }
else -> throw ExpressionError("can only take negative of int or float", subexpr.position) else -> throw ExpressionError("can only take negative of int or float", subexpr.position)
} }
"~" -> when (subexpr.type) { "~" -> when (subexpr.type) {
in IntegerDatatypes -> { in IntegerDatatypes -> {
optimizationsDone++ listOf(IAstModification.ReplaceNode(expr,
NumericLiteralValue.optimalNumeric(subexpr.number.toInt().inv(), subexpr.position) NumericLiteralValue.optimalNumeric(subexpr.number.toInt().inv(), subexpr.position),
parent))
} }
else -> throw ExpressionError("can only take bitwise inversion of int", subexpr.position) else -> throw ExpressionError("can only take bitwise inversion of int", subexpr.position)
} }
"not" -> { "not" -> {
optimizationsDone++ listOf(IAstModification.ReplaceNode(expr,
NumericLiteralValue.fromBoolean(subexpr.number.toDouble() == 0.0, subexpr.position) NumericLiteralValue.fromBoolean(subexpr.number.toDouble() == 0.0, subexpr.position),
parent))
} }
else -> throw ExpressionError(prefixExpr.operator, subexpr.position) else -> throw ExpressionError(expr.operator, subexpr.position)
} }
} }
return prefixExpr return noModifications
} }
/** /**
* Try to accept a binary expression. * Try to constfold a binary expression.
* Compile-time constant sub expressions will be evaluated on the spot. * Compile-time constant sub expressions will be evaluated on the spot.
* For instance, "9 * (4 + 2)" will be optimized into the integer literal 54. * For instance, "9 * (4 + 2)" will be optimized into the integer literal 54.
* *
@ -288,13 +249,7 @@ internal class ConstantFoldingOptimizer(private val program: Program, private va
* (X / c1) * c2 -> X / (c2/c1) * (X / c1) * c2 -> X / (c2/c1)
* (X + c1) - c2 -> X + (c1-c2) * (X + c1) - c2 -> X + (c1-c2)
*/ */
override fun visit(expr: BinaryExpression): Expression { override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
super.visit(expr)
if(expr.left is StringLiteralValue || expr.left is ArrayLiteralValue
|| expr.right is StringLiteralValue || expr.right is ArrayLiteralValue)
throw FatalAstException("binexpr with reference litval instead of numeric")
val leftconst = expr.left.constValue(program) val leftconst = expr.left.constValue(program)
val rightconst = expr.right.constValue(program) val rightconst = expr.right.constValue(program)
@ -308,218 +263,62 @@ internal class ConstantFoldingOptimizer(private val program: Program, private va
val subrightconst = subExpr.right.constValue(program) val subrightconst = subExpr.right.constValue(program)
if ((subleftconst != null && subrightconst == null) || (subleftconst==null && subrightconst!=null)) { if ((subleftconst != null && subrightconst == null) || (subleftconst==null && subrightconst!=null)) {
// try reordering. // try reordering.
return groupTwoConstsTogether(expr, subExpr, val change = groupTwoConstsTogether(expr, subExpr,
leftconst != null, rightconst != null, leftconst != null, rightconst != null,
subleftconst != null, subrightconst != null) subleftconst != null, subrightconst != null)
return change?.let { listOf(it) } ?: noModifications
} }
} }
// const fold when both operands are a const // const fold when both operands are a const
return when { if(leftconst != null && rightconst != null) {
leftconst != null && rightconst != null -> {
optimizationsDone++
val evaluator = ConstExprEvaluator() val evaluator = ConstExprEvaluator()
evaluator.evaluate(leftconst, expr.operator, rightconst) return listOf(IAstModification.ReplaceNode(
expr,
evaluator.evaluate(leftconst, expr.operator, rightconst),
parent
))
} }
else -> expr return noModifications
}
override fun after(array: ArrayLiteralValue, parent: Node): Iterable<IAstModification> {
// because constant folding can result in arrays that are now suddenly capable
// of telling the type of all their elements (for instance, when they contained -2 which
// was a prefix expression earlier), we recalculate the array's datatype.
if(array.type.isKnown)
return noModifications
// if the array literalvalue is inside an array vardecl, take the type from that
// otherwise infer it from the elements of the array
val vardeclType = (array.parent as? VarDecl)?.datatype
if(vardeclType!=null) {
val newArray = array.cast(vardeclType)
if (newArray != null && newArray != array)
return listOf(IAstModification.ReplaceNode(array, newArray, parent))
} else {
val arrayDt = array.guessDatatype(program)
if (arrayDt.isKnown) {
val newArray = array.cast(arrayDt.typeOrElse(DataType.STRUCT))
if (newArray != null && newArray != array)
return listOf(IAstModification.ReplaceNode(array, newArray, parent))
} }
} }
private fun groupTwoConstsTogether(expr: BinaryExpression, return noModifications
subExpr: BinaryExpression, }
leftIsConst: Boolean,
rightIsConst: Boolean, override fun after(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> {
subleftIsConst: Boolean, // the args of a fuction are constfolded via recursion already.
subrightIsConst: Boolean): Expression val constvalue = functionCall.constValue(program)
{ return if(constvalue!=null)
// todo: this implements only a small set of possible reorderings at this time listOf(IAstModification.ReplaceNode(functionCall, constvalue, parent))
if(expr.operator==subExpr.operator) {
// both operators are the isSameAs.
// If + or *, we can simply swap the const of expr and Var in subexpr.
if(expr.operator=="+" || expr.operator=="*") {
if(leftIsConst) {
if(subleftIsConst)
expr.left = subExpr.right.also { subExpr.right = expr.left }
else else
expr.left = subExpr.left.also { subExpr.left = expr.left } noModifications
} else {
if(subleftIsConst)
expr.right = subExpr.right.also {subExpr.right = expr.right }
else
expr.right = subExpr.left.also { subExpr.left = expr.right }
}
optimizationsDone++
return expr
} }
// If - or /, we simetimes must reorder more, and flip operators (- -> +, / -> *) override fun after(forLoop: ForLoop, parent: Node): Iterable<IAstModification> {
if(expr.operator=="-" || expr.operator=="/") {
optimizationsDone++
if(leftIsConst) {
return if(subleftIsConst) {
val tmp = subExpr.right
subExpr.right = subExpr.left
subExpr.left = expr.left
expr.left = tmp
expr.operator = if(expr.operator=="-") "+" else "*"
expr
} else
BinaryExpression(
BinaryExpression(expr.left, if (expr.operator == "-") "+" else "*", subExpr.right, subExpr.position),
expr.operator, subExpr.left, expr.position)
} else {
return if(subleftIsConst) {
expr.right = subExpr.right.also { subExpr.right = expr.right }
expr
} else
BinaryExpression(
subExpr.left, expr.operator,
BinaryExpression(expr.right, if (expr.operator == "-") "+" else "*", subExpr.right, subExpr.position),
expr.position)
}
}
return expr
}
else
{
if(expr.operator=="/" && subExpr.operator=="*") {
optimizationsDone++
if(leftIsConst) {
return if(subleftIsConst) {
// C1/(C2*V) -> (C1/C2)/V
BinaryExpression(
BinaryExpression(expr.left, "/", subExpr.left, subExpr.position),
"/",
subExpr.right, expr.position)
} else {
// C1/(V*C2) -> (C1/C2)/V
BinaryExpression(
BinaryExpression(expr.left, "/", subExpr.right, subExpr.position),
"/",
subExpr.left, expr.position)
}
} else {
return if(subleftIsConst) {
// (C1*V)/C2 -> (C1/C2)*V
BinaryExpression(
BinaryExpression(subExpr.left, "/", expr.right, subExpr.position),
"*",
subExpr.right, expr.position)
} else {
// (V*C1)/C2 -> (C1/C2)*V
BinaryExpression(
BinaryExpression(subExpr.right, "/", expr.right, subExpr.position),
"*",
subExpr.left, expr.position)
}
}
}
else if(expr.operator=="*" && subExpr.operator=="/") {
optimizationsDone++
if(leftIsConst) {
return if(subleftIsConst) {
// C1*(C2/V) -> (C1*C2)/V
BinaryExpression(
BinaryExpression(expr.left, "*", subExpr.left, subExpr.position),
"/",
subExpr.right, expr.position)
} else {
// C1*(V/C2) -> (C1/C2)*V
BinaryExpression(
BinaryExpression(expr.left, "/", subExpr.right, subExpr.position),
"*",
subExpr.left, expr.position)
}
} else {
return if(subleftIsConst) {
// (C1/V)*C2 -> (C1*C2)/V
BinaryExpression(
BinaryExpression(subExpr.left, "*", expr.right, subExpr.position),
"/",
subExpr.right, expr.position)
} else {
// (V/C1)*C2 -> (C1/C2)*V
BinaryExpression(
BinaryExpression(expr.right, "/", subExpr.right, subExpr.position),
"*",
subExpr.left, expr.position)
}
}
}
else if(expr.operator=="+" && subExpr.operator=="-") {
optimizationsDone++
if(leftIsConst){
return if(subleftIsConst){
// c1+(c2-v) -> (c1+c2)-v
BinaryExpression(
BinaryExpression(expr.left, "+", subExpr.left, subExpr.position),
"-",
subExpr.right, expr.position)
} else {
// c1+(v-c2) -> v+(c1-c2)
BinaryExpression(
BinaryExpression(expr.left, "-", subExpr.right, subExpr.position),
"+",
subExpr.left, expr.position)
}
} else {
return if(subleftIsConst) {
// (c1-v)+c2 -> (c1+c2)-v
BinaryExpression(
BinaryExpression(subExpr.left, "+", expr.right, subExpr.position),
"-",
subExpr.right, expr.position)
} else {
// (v-c1)+c2 -> v+(c2-c1)
BinaryExpression(
BinaryExpression(expr.right, "-", subExpr.right, subExpr.position),
"+",
subExpr.left, expr.position)
}
}
}
else if(expr.operator=="-" && subExpr.operator=="+") {
optimizationsDone++
if(leftIsConst) {
return if(subleftIsConst) {
// c1-(c2+v) -> (c1-c2)-v
BinaryExpression(
BinaryExpression(expr.left, "-", subExpr.left, subExpr.position),
"-",
subExpr.right, expr.position)
} else {
// c1-(v+c2) -> (c1-c2)-v
BinaryExpression(
BinaryExpression(expr.left, "-", subExpr.right, subExpr.position),
"-",
subExpr.left, expr.position)
}
} else {
return if(subleftIsConst) {
// (c1+v)-c2 -> v+(c1-c2)
BinaryExpression(
BinaryExpression(subExpr.left, "-", expr.right, subExpr.position),
"+",
subExpr.right, expr.position)
} else {
// (v+c1)-c2 -> v+(c1-c2)
BinaryExpression(
BinaryExpression(subExpr.right, "-", expr.right, subExpr.position),
"+",
subExpr.left, expr.position)
}
}
}
return expr
}
}
override fun visit(forLoop: ForLoop): Statement {
fun adjustRangeDt(rangeFrom: NumericLiteralValue, targetDt: DataType, rangeTo: NumericLiteralValue, stepLiteral: NumericLiteralValue?, range: RangeExpr): RangeExpr { fun adjustRangeDt(rangeFrom: NumericLiteralValue, targetDt: DataType, rangeTo: NumericLiteralValue, stepLiteral: NumericLiteralValue?, range: RangeExpr): RangeExpr {
val newFrom: NumericLiteralValue val newFrom: NumericLiteralValue
val newTo: NumericLiteralValue val newTo: NumericLiteralValue
@ -537,87 +336,259 @@ internal class ConstantFoldingOptimizer(private val program: Program, private va
return RangeExpr(newFrom, newTo, newStep, range.position) return RangeExpr(newFrom, newTo, newStep, range.position)
} }
val forLoop2 = super.visit(forLoop) as ForLoop
// check if we need to adjust an array literal to the loop variable's datatype
val array = forLoop2.iterable as? ArrayLiteralValue
if(array!=null) {
val loopvarDt: DataType = when {
forLoop.loopVar!=null -> forLoop.loopVar!!.inferType(program).typeOrElse(DataType.UBYTE)
forLoop.loopRegister!=null -> DataType.UBYTE
else -> throw FatalAstException("weird for loop")
}
val arrayType = when(loopvarDt) {
DataType.UBYTE -> DataType.ARRAY_UB
DataType.BYTE -> DataType.ARRAY_B
DataType.UWORD -> DataType.ARRAY_UW
DataType.WORD -> DataType.ARRAY_W
DataType.FLOAT -> DataType.ARRAY_F
else -> throw FatalAstException("invalid array elt type")
}
val array2 = array.cast(arrayType)
if(array2!=null && array2!==array) {
forLoop2.iterable = array2
array2.linkParents(forLoop2)
}
}
// adjust the datatype of a range expression in for loops to the loop variable. // adjust the datatype of a range expression in for loops to the loop variable.
val iterableRange = forLoop2.iterable as? RangeExpr ?: return forLoop2 val iterableRange = forLoop.iterable as? RangeExpr ?: return noModifications
val rangeFrom = iterableRange.from as? NumericLiteralValue val rangeFrom = iterableRange.from as? NumericLiteralValue
val rangeTo = iterableRange.to as? NumericLiteralValue val rangeTo = iterableRange.to as? NumericLiteralValue
if(rangeFrom==null || rangeTo==null) return forLoop2 if(rangeFrom==null || rangeTo==null) return noModifications
val loopvar = forLoop2.loopVar?.targetVarDecl(program.namespace) val loopvar = forLoop.loopVar?.targetVarDecl(program.namespace)
if(loopvar!=null) { if(loopvar!=null) {
val stepLiteral = iterableRange.step as? NumericLiteralValue val stepLiteral = iterableRange.step as? NumericLiteralValue
when(loopvar.datatype) { when(loopvar.datatype) {
DataType.UBYTE -> { DataType.UBYTE -> {
if(rangeFrom.type!= DataType.UBYTE) { if(rangeFrom.type!= DataType.UBYTE) {
// attempt to translate the iterable into ubyte values // attempt to translate the iterable into ubyte values
forLoop2.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange) val newIter = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange)
return listOf(IAstModification.ReplaceNode(forLoop.iterable, newIter, forLoop))
} }
} }
DataType.BYTE -> { DataType.BYTE -> {
if(rangeFrom.type!= DataType.BYTE) { if(rangeFrom.type!= DataType.BYTE) {
// attempt to translate the iterable into byte values // attempt to translate the iterable into byte values
forLoop2.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange) val newIter = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange)
return listOf(IAstModification.ReplaceNode(forLoop.iterable, newIter, forLoop))
} }
} }
DataType.UWORD -> { DataType.UWORD -> {
if(rangeFrom.type!= DataType.UWORD) { if(rangeFrom.type!= DataType.UWORD) {
// attempt to translate the iterable into uword values // attempt to translate the iterable into uword values
forLoop2.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange) val newIter = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange)
return listOf(IAstModification.ReplaceNode(forLoop.iterable, newIter, forLoop))
} }
} }
DataType.WORD -> { DataType.WORD -> {
if(rangeFrom.type!= DataType.WORD) { if(rangeFrom.type!= DataType.WORD) {
// attempt to translate the iterable into word values // attempt to translate the iterable into word values
forLoop2.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange) val newIter = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange)
return listOf(IAstModification.ReplaceNode(forLoop.iterable, newIter, forLoop))
} }
} }
else -> throw FatalAstException("invalid loopvar datatype $loopvar") else -> throw FatalAstException("invalid loopvar datatype $loopvar")
} }
} }
return forLoop2
return noModifications
} }
override fun visit(arrayLiteral: ArrayLiteralValue): Expression { private class ShuffleOperands(val expr: BinaryExpression,
// because constant folding can result in arrays that are now suddenly capable val exprOperator: String?,
// of telling the type of all their elements (for instance, when they contained -2 which val subExpr: BinaryExpression,
// was a prefix expression earlier), we recalculate the array's datatype. val newExprLeft: Expression?,
val array = super.visit(arrayLiteral) val newExprRight: Expression?,
if(array is ArrayLiteralValue) { val newSubexprLeft: Expression?,
if(array.type.isKnown) val newSubexprRight: Expression?
return array ): IAstModification {
val arrayDt = array.guessDatatype(program) override fun perform() {
if(arrayDt.isKnown) { if(exprOperator!=null) expr.operator = exprOperator
val newArray = arrayLiteral.cast(arrayDt.typeOrElse(DataType.STRUCT)) if(newExprLeft!=null) expr.left = newExprLeft
if(newArray!=null) if(newExprRight!=null) expr.right = newExprRight
return newArray if(newSubexprLeft!=null) subExpr.left = newSubexprLeft
if(newSubexprRight!=null) subExpr.right = newSubexprRight
} }
} }
return array
private fun groupTwoConstsTogether(expr: BinaryExpression,
subExpr: BinaryExpression,
leftIsConst: Boolean,
rightIsConst: Boolean,
subleftIsConst: Boolean,
subrightIsConst: Boolean): IAstModification?
{
// todo: this implements only a small set of possible reorderings at this time
if(expr.operator==subExpr.operator) {
// both operators are the same.
// If + or *, we can simply shuffle the const operands around to optimize.
if(expr.operator=="+" || expr.operator=="*") {
return if(leftIsConst) {
if(subleftIsConst)
ShuffleOperands(expr, null, subExpr, subExpr.right, null, null, expr.left)
else
ShuffleOperands(expr, null, subExpr, subExpr.left, null, expr.left, null)
} else {
if(subleftIsConst)
ShuffleOperands(expr, null, subExpr, null, subExpr.right, null, expr.right)
else
ShuffleOperands(expr, null, subExpr, null, subExpr.left, expr.right, null)
} }
}
// If - or /, we simetimes must reorder more, and flip operators (- -> +, / -> *)
if(expr.operator=="-" || expr.operator=="/") {
if(leftIsConst) {
return if (subleftIsConst) {
ShuffleOperands(expr, if (expr.operator == "-") "+" else "*", subExpr, subExpr.right, null, expr.left, subExpr.left)
} else {
IAstModification.ReplaceNode(expr,
BinaryExpression(
BinaryExpression(expr.left, if (expr.operator == "-") "+" else "*", subExpr.right, subExpr.position),
expr.operator, subExpr.left, expr.position),
expr.parent)
}
} else {
return if(subleftIsConst) {
return ShuffleOperands(expr, null, subExpr, null, subExpr.right, null, expr.right)
} else {
IAstModification.ReplaceNode(expr,
BinaryExpression(
subExpr.left, expr.operator,
BinaryExpression(expr.right, if (expr.operator == "-") "+" else "*", subExpr.right, subExpr.position),
expr.position),
expr.parent)
}
}
}
return null
}
else
{
if(expr.operator=="/" && subExpr.operator=="*") {
if(leftIsConst) {
val change = if(subleftIsConst) {
// C1/(C2*V) -> (C1/C2)/V
BinaryExpression(
BinaryExpression(expr.left, "/", subExpr.left, subExpr.position),
"/",
subExpr.right, expr.position)
} else {
// C1/(V*C2) -> (C1/C2)/V
BinaryExpression(
BinaryExpression(expr.left, "/", subExpr.right, subExpr.position),
"/",
subExpr.left, expr.position)
}
return IAstModification.ReplaceNode(expr, change, expr.parent)
} else {
val change = if(subleftIsConst) {
// (C1*V)/C2 -> (C1/C2)*V
BinaryExpression(
BinaryExpression(subExpr.left, "/", expr.right, subExpr.position),
"*",
subExpr.right, expr.position)
} else {
// (V*C1)/C2 -> (C1/C2)*V
BinaryExpression(
BinaryExpression(subExpr.right, "/", expr.right, subExpr.position),
"*",
subExpr.left, expr.position)
}
return IAstModification.ReplaceNode(expr, change, expr.parent)
}
}
else if(expr.operator=="*" && subExpr.operator=="/") {
if(leftIsConst) {
val change = if(subleftIsConst) {
// C1*(C2/V) -> (C1*C2)/V
BinaryExpression(
BinaryExpression(expr.left, "*", subExpr.left, subExpr.position),
"/",
subExpr.right, expr.position)
} else {
// C1*(V/C2) -> (C1/C2)*V
BinaryExpression(
BinaryExpression(expr.left, "/", subExpr.right, subExpr.position),
"*",
subExpr.left, expr.position)
}
return IAstModification.ReplaceNode(expr, change, expr.parent)
} else {
val change = if(subleftIsConst) {
// (C1/V)*C2 -> (C1*C2)/V
BinaryExpression(
BinaryExpression(subExpr.left, "*", expr.right, subExpr.position),
"/",
subExpr.right, expr.position)
} else {
// (V/C1)*C2 -> (C1/C2)*V
BinaryExpression(
BinaryExpression(expr.right, "/", subExpr.right, subExpr.position),
"*",
subExpr.left, expr.position)
}
return IAstModification.ReplaceNode(expr, change, expr.parent)
}
}
else if(expr.operator=="+" && subExpr.operator=="-") {
if(leftIsConst){
val change = if(subleftIsConst){
// c1+(c2-v) -> (c1+c2)-v
BinaryExpression(
BinaryExpression(expr.left, "+", subExpr.left, subExpr.position),
"-",
subExpr.right, expr.position)
} else {
// c1+(v-c2) -> v+(c1-c2)
BinaryExpression(
BinaryExpression(expr.left, "-", subExpr.right, subExpr.position),
"+",
subExpr.left, expr.position)
}
return IAstModification.ReplaceNode(expr, change, expr.parent)
} else {
val change = if(subleftIsConst) {
// (c1-v)+c2 -> (c1+c2)-v
BinaryExpression(
BinaryExpression(subExpr.left, "+", expr.right, subExpr.position),
"-",
subExpr.right, expr.position)
} else {
// (v-c1)+c2 -> v+(c2-c1)
BinaryExpression(
BinaryExpression(expr.right, "-", subExpr.right, subExpr.position),
"+",
subExpr.left, expr.position)
}
return IAstModification.ReplaceNode(expr, change, expr.parent)
}
}
else if(expr.operator=="-" && subExpr.operator=="+") {
if(leftIsConst) {
val change = if(subleftIsConst) {
// c1-(c2+v) -> (c1-c2)-v
BinaryExpression(
BinaryExpression(expr.left, "-", subExpr.left, subExpr.position),
"-",
subExpr.right, expr.position)
} else {
// c1-(v+c2) -> (c1-c2)-v
BinaryExpression(
BinaryExpression(expr.left, "-", subExpr.right, subExpr.position),
"-",
subExpr.left, expr.position)
}
return IAstModification.ReplaceNode(expr, change, expr.parent)
} else {
val change = if(subleftIsConst) {
// (c1+v)-c2 -> v+(c1-c2)
BinaryExpression(
BinaryExpression(subExpr.left, "-", expr.right, subExpr.position),
"+",
subExpr.right, expr.position)
} else {
// (v+c1)-c2 -> v+(c1-c2)
BinaryExpression(
BinaryExpression(subExpr.right, "-", expr.right, subExpr.position),
"+",
subExpr.left, expr.position)
}
return IAstModification.ReplaceNode(expr, change, expr.parent)
}
}
return null
}
}
} }

View File

@ -22,11 +22,12 @@ import kotlin.math.pow
internal class ExpressionSimplifier(private val program: Program) : AstWalker() { internal class ExpressionSimplifier(private val program: Program) : AstWalker() {
private val powersOfTwo = (1..16).map { (2.0).pow(it) }.toSet() private val powersOfTwo = (1..16).map { (2.0).pow(it) }.toSet()
private val negativePowersOfTwo = powersOfTwo.map { -it }.toSet() private val negativePowersOfTwo = powersOfTwo.map { -it }.toSet()
private val noModifications = emptyList<IAstModification>()
override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> { override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
if (assignment.aug_op != null) if (assignment.aug_op != null)
throw FatalAstException("augmented assignments should have been converted to normal assignments before this optimizer: $assignment") throw FatalAstException("augmented assignments should have been converted to normal assignments before this optimizer: $assignment")
return emptyList() return noModifications
} }
override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> { override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
@ -82,10 +83,10 @@ internal class ExpressionSimplifier(private val program: Program) : AstWalker()
if (newExpr != null) if (newExpr != null)
return listOf(IAstModification.ReplaceNode(expr, newExpr, parent)) return listOf(IAstModification.ReplaceNode(expr, newExpr, parent))
} }
else -> return emptyList() else -> return noModifications
} }
} }
return emptyList() return noModifications
} }
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> { override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
@ -297,7 +298,7 @@ internal class ExpressionSimplifier(private val program: Program) : AstWalker()
if(newExpr != null) if(newExpr != null)
return listOf(IAstModification.ReplaceNode(expr, newExpr, parent)) return listOf(IAstModification.ReplaceNode(expr, newExpr, parent))
return emptyList() return noModifications
} }
private fun determineY(x: Expression, subBinExpr: BinaryExpression): Expression? { private fun determineY(x: Expression, subBinExpr: BinaryExpression): Expression? {

View File

@ -5,14 +5,23 @@ import prog8.ast.base.ErrorReporter
internal fun Program.constantFold(errors: ErrorReporter) { internal fun Program.constantFold(errors: ErrorReporter) {
val replacer = ConstantIdentifierReplacer(this, errors)
replacer.visit(this)
if(errors.isEmpty()) {
replacer.applyModifications()
val optimizer = ConstantFoldingOptimizer(this, errors) val optimizer = ConstantFoldingOptimizer(this, errors)
optimizer.visit(this) optimizer.visit(this)
while (errors.isEmpty() && optimizer.applyModifications() > 0) {
while(errors.isEmpty() && optimizer.optimizationsDone>0) {
optimizer.optimizationsDone = 0
optimizer.visit(this) optimizer.visit(this)
} }
if(errors.isEmpty()) {
replacer.visit(this)
replacer.applyModifications()
}
}
if(errors.isEmpty()) if(errors.isEmpty())
modules.forEach { it.linkParents(namespace) } // re-link in final configuration modules.forEach { it.linkParents(namespace) } // re-link in final configuration
} }
@ -21,9 +30,11 @@ internal fun Program.constantFold(errors: ErrorReporter) {
internal fun Program.optimizeStatements(errors: ErrorReporter): Int { internal fun Program.optimizeStatements(errors: ErrorReporter): Int {
val optimizer = StatementOptimizer(this, errors) val optimizer = StatementOptimizer(this, errors)
optimizer.visit(this) optimizer.visit(this)
val optimizationCount = optimizer.applyModifications()
modules.forEach { it.linkParents(this.namespace) } // re-link in final configuration modules.forEach { it.linkParents(this.namespace) } // re-link in final configuration
return optimizer.optimizationsDone return optimizationCount
} }
internal fun Program.simplifyExpressions() : Int { internal fun Program.simplifyExpressions() : Int {

View File

@ -1,46 +0,0 @@
package prog8.optimizer
import prog8.ast.INameScope
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.processing.IAstVisitor
import prog8.ast.statements.AnonymousScope
import prog8.ast.statements.NopStatement
import prog8.ast.statements.Statement
internal class FlattenAnonymousScopesAndNopRemover: IAstVisitor {
private var scopesToFlatten = mutableListOf<INameScope>()
private val nopStatements = mutableListOf<NopStatement>()
override fun visit(program: Program) {
super.visit(program)
for(scope in scopesToFlatten.reversed()) {
val namescope = scope.parent as INameScope
val idx = namescope.statements.indexOf(scope as Statement)
if(idx>=0) {
val nop = NopStatement.insteadOf(namescope.statements[idx])
nop.parent = namescope as Node
namescope.statements[idx] = nop
namescope.statements.addAll(idx, scope.statements)
scope.statements.forEach { it.parent = namescope }
visit(nop)
}
}
this.nopStatements.forEach {
it.definingScope().remove(it)
}
}
override fun visit(scope: AnonymousScope) {
if(scope.parent is INameScope) {
scopesToFlatten.add(scope) // get rid of the anonymous scope
}
return super.visit(scope)
}
override fun visit(nopStatement: NopStatement) {
nopStatements.add(nopStatement)
}
}

View File

@ -1,10 +1,12 @@
package prog8.optimizer package prog8.optimizer
import prog8.ast.INameScope import prog8.ast.INameScope
import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.AstWalker
import prog8.ast.processing.IAstModification
import prog8.ast.processing.IAstVisitor import prog8.ast.processing.IAstVisitor
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.compiler.target.CompilationTarget import prog8.compiler.target.CompilationTarget
@ -14,54 +16,37 @@ import kotlin.math.floor
/* /*
TODO: remove unreachable code after return and exit() TODO: remove unreachable code after return and exit()
TODO: proper inlining of tiny subroutines (at first, restrict to subs without parameters and variables in them, and build it up from there: correctly renaming/relocating all variables in them and refs to those as well)
*/ */
// TODO implement using AstWalker instead of IAstModifyingVisitor
internal class StatementOptimizer(private val program: Program, internal class StatementOptimizer(private val program: Program,
private val errors: ErrorReporter) : IAstModifyingVisitor { private val errors: ErrorReporter) : AstWalker() {
var optimizationsDone: Int = 0
private set
private val pureBuiltinFunctions = BuiltinFunctions.filter { it.value.pure } private val noModifications = emptyList<IAstModification>()
private val callgraph = CallGraph(program) private val callgraph = CallGraph(program)
private val vardeclsToRemove = mutableListOf<VarDecl>() private val pureBuiltinFunctions = BuiltinFunctions.filter { it.value.pure }
override fun visit(program: Program) { override fun after(block: Block, parent: Node): Iterable<IAstModification> {
super.visit(program)
for(decl in vardeclsToRemove) {
decl.definingScope().remove(decl)
}
}
override fun visit(block: Block): Statement {
if("force_output" !in block.options()) { if("force_output" !in block.options()) {
if (block.containsNoCodeNorVars()) { if (block.containsNoCodeNorVars()) {
optimizationsDone++
errors.warn("removing empty block '${block.name}'", block.position) errors.warn("removing empty block '${block.name}'", block.position)
return NopStatement.insteadOf(block) return listOf(IAstModification.Remove(block, parent))
} }
if (block !in callgraph.usedSymbols) { if (block !in callgraph.usedSymbols) {
optimizationsDone++
errors.warn("removing unused block '${block.name}'", block.position) errors.warn("removing unused block '${block.name}'", block.position)
return NopStatement.insteadOf(block) // remove unused block return listOf(IAstModification.Remove(block, parent))
} }
} }
return noModifications
}
return super.visit(block) override fun after(subroutine: Subroutine, parent: Node): Iterable<IAstModification> {
}
override fun visit(subroutine: Subroutine): Statement {
super.visit(subroutine)
val forceOutput = "force_output" in subroutine.definingBlock().options() val forceOutput = "force_output" in subroutine.definingBlock().options()
if(subroutine.asmAddress==null && !forceOutput) { if(subroutine.asmAddress==null && !forceOutput) {
if(subroutine.containsNoCodeNorVars()) { if(subroutine.containsNoCodeNorVars()) {
errors.warn("removing empty subroutine '${subroutine.name}'", subroutine.position) errors.warn("removing empty subroutine '${subroutine.name}'", subroutine.position)
optimizationsDone++ return listOf(IAstModification.Remove(subroutine, parent))
return NopStatement.insteadOf(subroutine)
} }
} }
@ -72,23 +57,341 @@ internal class StatementOptimizer(private val program: Program,
if(subroutine !in callgraph.usedSymbols && !forceOutput) { if(subroutine !in callgraph.usedSymbols && !forceOutput) {
errors.warn("removing unused subroutine '${subroutine.name}'", subroutine.position) errors.warn("removing unused subroutine '${subroutine.name}'", subroutine.position)
optimizationsDone++ return listOf(IAstModification.Remove(subroutine, parent))
return NopStatement.insteadOf(subroutine)
} }
return subroutine return noModifications
} }
override fun visit(decl: VarDecl): Statement { override fun after(scope: AnonymousScope, parent: Node): Iterable<IAstModification> {
val linesToRemove = deduplicateAssignments(scope.statements)
return linesToRemove.reversed().map { IAstModification.Remove(scope.statements[it], scope) }
}
override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> {
val forceOutput = "force_output" in decl.definingBlock().options() val forceOutput = "force_output" in decl.definingBlock().options()
if(decl !in callgraph.usedSymbols && !forceOutput) { if(decl !in callgraph.usedSymbols && !forceOutput) {
if(decl.type == VarDeclType.VAR) if(decl.type == VarDeclType.VAR)
errors.warn("removing unused variable ${decl.type} '${decl.name}'", decl.position) errors.warn("removing unused variable '${decl.name}'", decl.position)
optimizationsDone++
return NopStatement.insteadOf(decl) return listOf(IAstModification.Remove(decl, parent))
} }
return super.visit(decl) return noModifications
}
override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
if(functionCallStatement.target.nameInSource.size==1 && functionCallStatement.target.nameInSource[0] in BuiltinFunctions) {
val functionName = functionCallStatement.target.nameInSource[0]
if (functionName in pureBuiltinFunctions) {
errors.warn("statement has no effect (function return value is discarded)", functionCallStatement.position)
return listOf(IAstModification.Remove(functionCallStatement, parent))
}
}
// printing a literal string of just 2 or 1 characters is replaced by directly outputting those characters
// this is a C-64 specific optimization
if(functionCallStatement.target.nameInSource==listOf("c64scr", "print")) {
val arg = functionCallStatement.args.single()
val stringVar: IdentifierReference?
stringVar = if(arg is AddressOf) {
arg.identifier
} else {
arg as? IdentifierReference
}
if(stringVar!=null) {
val vardecl = stringVar.targetVarDecl(program.namespace)!!
val string = vardecl.value!! as StringLiteralValue
val pos = functionCallStatement.position
if(string.value.length==1) {
val firstCharEncoded = CompilationTarget.encodeString(string.value, string.altEncoding)[0]
val chrout = FunctionCallStatement(
IdentifierReference(listOf("c64", "CHROUT"), pos),
mutableListOf(NumericLiteralValue(DataType.UBYTE, firstCharEncoded.toInt(), pos)),
functionCallStatement.void, pos
)
return listOf(IAstModification.ReplaceNode(functionCallStatement, chrout, parent))
} else if(string.value.length==2) {
val firstTwoCharsEncoded = CompilationTarget.encodeString(string.value.take(2), string.altEncoding)
val chrout1 = FunctionCallStatement(
IdentifierReference(listOf("c64", "CHROUT"), pos),
mutableListOf(NumericLiteralValue(DataType.UBYTE, firstTwoCharsEncoded[0].toInt(), pos)),
functionCallStatement.void, pos
)
val chrout2 = FunctionCallStatement(
IdentifierReference(listOf("c64", "CHROUT"), pos),
mutableListOf(NumericLiteralValue(DataType.UBYTE, firstTwoCharsEncoded[1].toInt(), pos)),
functionCallStatement.void, pos
)
val anonscope = AnonymousScope(mutableListOf(), pos)
anonscope.statements.add(chrout1)
anonscope.statements.add(chrout2)
return listOf(IAstModification.ReplaceNode(functionCallStatement, anonscope, parent))
}
}
}
// if the first instruction in the called subroutine is a return statement, remove the jump altogeter
val subroutine = functionCallStatement.target.targetSubroutine(program.namespace)
if(subroutine!=null) {
val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull()
if(first is ReturnFromIrq || first is Return)
return listOf(IAstModification.Remove(functionCallStatement, parent))
}
return noModifications
}
override fun before(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> {
// if the first instruction in the called subroutine is a return statement with constant value, replace with the constant value
val subroutine = functionCall.target.targetSubroutine(program.namespace)
if(subroutine!=null) {
val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull()
if(first is Return && first.value!=null) {
val constval = first.value?.constValue(program)
if(constval!=null)
return listOf(IAstModification.ReplaceNode(functionCall, constval, parent))
}
}
return noModifications
}
override fun after(ifStatement: IfStatement, parent: Node): Iterable<IAstModification> {
// remove empty if statements
if(ifStatement.truepart.containsNoCodeNorVars() && ifStatement.elsepart.containsNoCodeNorVars())
return listOf(IAstModification.Remove(ifStatement, parent))
// empty true part? switch with the else part
if(ifStatement.truepart.containsNoCodeNorVars() && ifStatement.elsepart.containsCodeOrVars()) {
val invertedCondition = PrefixExpression("not", ifStatement.condition, ifStatement.condition.position)
val emptyscope = AnonymousScope(mutableListOf(), ifStatement.elsepart.position)
val truepart = AnonymousScope(ifStatement.elsepart.statements, ifStatement.truepart.position)
return listOf(
IAstModification.ReplaceNode(ifStatement.condition, invertedCondition, ifStatement),
IAstModification.ReplaceNode(ifStatement.truepart, truepart, ifStatement),
IAstModification.ReplaceNode(ifStatement.elsepart, emptyscope, ifStatement)
)
}
val constvalue = ifStatement.condition.constValue(program)
if(constvalue!=null) {
return if(constvalue.asBooleanValue){
// always true -> keep only if-part
errors.warn("condition is always true", ifStatement.position)
listOf(IAstModification.ReplaceNode(ifStatement, ifStatement.truepart, parent))
} else {
// always false -> keep only else-part
errors.warn("condition is always false", ifStatement.position)
listOf(IAstModification.ReplaceNode(ifStatement, ifStatement.elsepart, parent))
}
}
return noModifications
}
override fun after(forLoop: ForLoop, parent: Node): Iterable<IAstModification> {
if(forLoop.body.containsNoCodeNorVars()) {
// remove empty for loop
return listOf(IAstModification.Remove(forLoop, parent))
} else if(forLoop.body.statements.size==1) {
val loopvar = forLoop.body.statements[0] as? VarDecl
if(loopvar!=null && loopvar.name==forLoop.loopVar?.nameInSource?.singleOrNull()) {
// remove empty for loop (only loopvar decl in it)
return listOf(IAstModification.Remove(forLoop, parent))
}
}
val range = forLoop.iterable as? RangeExpr
if(range!=null) {
if(range.size()==1) {
// for loop over a (constant) range of just a single value-- optimize the loop away
// loopvar/reg = range value , follow by block
val scope = AnonymousScope(mutableListOf(), forLoop.position)
scope.statements.add(Assignment(AssignTarget(forLoop.loopRegister, forLoop.loopVar, null, null, forLoop.position), null, range.from, forLoop.position))
scope.statements.addAll(forLoop.body.statements)
return listOf(IAstModification.ReplaceNode(forLoop, scope, parent))
}
}
val iterable = (forLoop.iterable as? IdentifierReference)?.targetVarDecl(program.namespace)
if(iterable!=null) {
if(iterable.datatype==DataType.STR) {
val sv = iterable.value as StringLiteralValue
val size = sv.value.length
if(size==1) {
// loop over string of length 1 -> just assign the single character
val character = CompilationTarget.encodeString(sv.value, sv.altEncoding)[0]
val byte = NumericLiteralValue(DataType.UBYTE, character, iterable.position)
val scope = AnonymousScope(mutableListOf(), forLoop.position)
scope.statements.add(Assignment(AssignTarget(forLoop.loopRegister, forLoop.loopVar, null, null, forLoop.position), null, byte, forLoop.position))
scope.statements.addAll(forLoop.body.statements)
return listOf(IAstModification.ReplaceNode(forLoop, scope, parent))
}
}
else if(iterable.datatype in ArrayDatatypes) {
val size = iterable.arraysize!!.size()
if(size==1) {
// loop over array of length 1 -> just assign the single value
val av = (iterable.value as ArrayLiteralValue).value[0].constValue(program)?.number
if(av!=null) {
val scope = AnonymousScope(mutableListOf(), forLoop.position)
scope.statements.add(Assignment(
AssignTarget(forLoop.loopRegister, forLoop.loopVar, null, null, forLoop.position), null, NumericLiteralValue.optimalInteger(av.toInt(), iterable.position),
forLoop.position))
scope.statements.addAll(forLoop.body.statements)
return listOf(IAstModification.ReplaceNode(forLoop, scope, parent))
}
}
}
}
return noModifications
}
override fun before(repeatLoop: RepeatLoop, parent: Node): Iterable<IAstModification> {
val constvalue = repeatLoop.untilCondition.constValue(program)
if(constvalue!=null) {
if(constvalue.asBooleanValue) {
// always true -> keep only the statement block (if there are no continue and break statements)
errors.warn("condition is always true", repeatLoop.untilCondition.position)
if(!hasContinueOrBreak(repeatLoop.body))
return listOf(IAstModification.ReplaceNode(repeatLoop, repeatLoop.body, parent))
} else {
// always false
val forever = ForeverLoop(repeatLoop.body, repeatLoop.position)
return listOf(IAstModification.ReplaceNode(repeatLoop, forever, parent))
}
}
return noModifications
}
override fun before(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> {
val constvalue = whileLoop.condition.constValue(program)
if(constvalue!=null) {
return if(constvalue.asBooleanValue) {
// always true
val forever = ForeverLoop(whileLoop.body, whileLoop.position)
listOf(IAstModification.ReplaceNode(whileLoop, forever, parent))
} else {
// always false -> remove the while statement altogether
errors.warn("condition is always false", whileLoop.condition.position)
listOf(IAstModification.Remove(whileLoop, parent))
}
}
return noModifications
}
override fun after(whenStatement: WhenStatement, parent: Node): Iterable<IAstModification> {
// remove empty choices
class ChoiceRemover(val choice: WhenChoice) : IAstModification {
override fun perform() {
whenStatement.choices.remove(choice)
}
}
return whenStatement.choices
.filter { !it.statements.containsCodeOrVars() }
.map { ChoiceRemover(it) }
}
override fun after(jump: Jump, parent: Node): Iterable<IAstModification> {
// if the jump is to the next statement, remove the jump
val scope = jump.definingScope()
val label = jump.identifier?.targetStatement(scope)
if(label!=null && scope.statements.indexOf(label) == scope.statements.indexOf(jump)+1)
return listOf(IAstModification.Remove(jump, parent))
return noModifications
}
override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
if(assignment.aug_op!=null)
throw FatalAstException("augmented assignments should have been converted to normal assignments before this optimizer: $assignment")
// remove assignments to self
if(assignment.target isSameAs assignment.value) {
if(assignment.target.isNotMemory(program.namespace))
return listOf(IAstModification.Remove(assignment, parent))
}
val targetIDt = assignment.target.inferType(program, assignment)
if(!targetIDt.isKnown)
throw FatalAstException("can't infer type of assignment target")
// optimize binary expressions a bit
val targetDt = targetIDt.typeOrElse(DataType.STRUCT)
val bexpr=assignment.value as? BinaryExpression
if(bexpr!=null) {
val cv = bexpr.right.constValue(program)?.number?.toDouble()
if (cv != null && assignment.target isSameAs bexpr.left) {
// assignments of the form: X = X <operator> <expr>
// remove assignments that have no effect (such as X=X+0)
// optimize/rewrite some other expressions
val vardeclDt = (assignment.target.identifier?.targetVarDecl(program.namespace))?.type
when (bexpr.operator) {
"+" -> {
if (cv == 0.0) {
return listOf(IAstModification.Remove(assignment, parent))
} else if (targetDt in IntegerDatatypes && floor(cv) == cv) {
if ((vardeclDt == VarDeclType.MEMORY && cv in 1.0..3.0) || (vardeclDt != VarDeclType.MEMORY && cv in 1.0..8.0)) {
// replace by several INCs (a bit less when dealing with memory targets)
val incs = AnonymousScope(mutableListOf(), assignment.position)
repeat(cv.toInt()) {
incs.statements.add(PostIncrDecr(assignment.target, "++", assignment.position))
}
return listOf(IAstModification.ReplaceNode(assignment, incs, parent))
}
}
}
"-" -> {
if (cv == 0.0) {
return listOf(IAstModification.Remove(assignment, parent))
} else if (targetDt in IntegerDatatypes && floor(cv) == cv) {
if ((vardeclDt == VarDeclType.MEMORY && cv in 1.0..3.0) || (vardeclDt != VarDeclType.MEMORY && cv in 1.0..8.0)) {
// replace by several DECs (a bit less when dealing with memory targets)
val decs = AnonymousScope(mutableListOf(), assignment.position)
repeat(cv.toInt()) {
decs.statements.add(PostIncrDecr(assignment.target, "--", assignment.position))
}
return listOf(IAstModification.ReplaceNode(assignment, decs, parent))
}
}
}
"*" -> if (cv == 1.0) return listOf(IAstModification.Remove(assignment, parent))
"/" -> if (cv == 1.0) return listOf(IAstModification.Remove(assignment, parent))
"**" -> if (cv == 1.0) return listOf(IAstModification.Remove(assignment, parent))
"|" -> if (cv == 0.0) return listOf(IAstModification.Remove(assignment, parent))
"^" -> if (cv == 0.0) return listOf(IAstModification.Remove(assignment, parent))
"<<" -> {
if (cv == 0.0)
return listOf(IAstModification.Remove(assignment, parent))
// replace by in-place lsl(...) call
val scope = AnonymousScope(mutableListOf(), assignment.position)
var numshifts = cv.toInt()
while (numshifts > 0) {
scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("lsl"), assignment.position),
mutableListOf(bexpr.left), true, assignment.position))
numshifts--
}
return listOf(IAstModification.ReplaceNode(assignment, scope, parent))
}
">>" -> {
if (cv == 0.0)
return listOf(IAstModification.Remove(assignment, parent))
// replace by in-place lsr(...) call
val scope = AnonymousScope(mutableListOf(), assignment.position)
var numshifts = cv.toInt()
while (numshifts > 0) {
scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("lsr"), assignment.position),
mutableListOf(bexpr.left), true, assignment.position))
numshifts--
}
return listOf(IAstModification.ReplaceNode(assignment, scope, parent))
}
}
}
}
return noModifications
} }
private fun deduplicateAssignments(statements: List<Statement>): MutableList<Int> { private fun deduplicateAssignments(statements: List<Statement>): MutableList<Int> {
@ -116,207 +419,6 @@ internal class StatementOptimizer(private val program: Program,
return linesToRemove return linesToRemove
} }
override fun visit(functionCallStatement: FunctionCallStatement): Statement {
if(functionCallStatement.target.nameInSource.size==1 && functionCallStatement.target.nameInSource[0] in BuiltinFunctions) {
val functionName = functionCallStatement.target.nameInSource[0]
if (functionName in pureBuiltinFunctions) {
errors.warn("statement has no effect (function return value is discarded)", functionCallStatement.position)
optimizationsDone++
return NopStatement.insteadOf(functionCallStatement)
}
}
if(functionCallStatement.target.nameInSource==listOf("c64scr", "print") ||
functionCallStatement.target.nameInSource==listOf("c64scr", "print_p")) {
// printing a literal string of just 2 or 1 characters is replaced by directly outputting those characters
val arg = functionCallStatement.args.single()
val stringVar: IdentifierReference?
stringVar = if(arg is AddressOf) {
arg.identifier
} else {
arg as? IdentifierReference
}
if(stringVar!=null) {
val vardecl = stringVar.targetVarDecl(program.namespace)!!
val string = vardecl.value!! as StringLiteralValue
if(string.value.length==1) {
val firstCharEncoded = CompilationTarget.encodeString(string.value, string.altEncoding)[0]
functionCallStatement.args.clear()
functionCallStatement.args.add(NumericLiteralValue.optimalInteger(firstCharEncoded.toInt(), functionCallStatement.position))
functionCallStatement.target = IdentifierReference(listOf("c64", "CHROUT"), functionCallStatement.target.position)
vardeclsToRemove.add(vardecl)
optimizationsDone++
return functionCallStatement
} else if(string.value.length==2) {
val firstTwoCharsEncoded = CompilationTarget.encodeString(string.value.take(2), string.altEncoding)
val scope = AnonymousScope(mutableListOf(), functionCallStatement.position)
scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("c64", "CHROUT"), functionCallStatement.target.position),
mutableListOf(NumericLiteralValue.optimalInteger(firstTwoCharsEncoded[0].toInt(), functionCallStatement.position)),
functionCallStatement.void, functionCallStatement.position))
scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("c64", "CHROUT"), functionCallStatement.target.position),
mutableListOf(NumericLiteralValue.optimalInteger(firstTwoCharsEncoded[1].toInt(), functionCallStatement.position)),
functionCallStatement.void, functionCallStatement.position))
vardeclsToRemove.add(vardecl)
optimizationsDone++
return scope
}
}
}
// if it calls a subroutine,
// and the first instruction in the subroutine is a jump, call that jump target instead
// if the first instruction in the subroutine is a return statement, replace with a nop instruction
val subroutine = functionCallStatement.target.targetSubroutine(program.namespace)
if(subroutine!=null) {
val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull()
if(first is Jump && first.identifier!=null) {
optimizationsDone++
return FunctionCallStatement(first.identifier, functionCallStatement.args, functionCallStatement.void, functionCallStatement.position)
}
if(first is ReturnFromIrq || first is Return) {
optimizationsDone++
return NopStatement.insteadOf(functionCallStatement)
}
}
return super.visit(functionCallStatement)
}
override fun visit(functionCall: FunctionCall): Expression {
// if it calls a subroutine,
// and the first instruction in the subroutine is a jump, call that jump target instead
// if the first instruction in the subroutine is a return statement with constant value, replace with the constant value
val subroutine = functionCall.target.targetSubroutine(program.namespace)
if(subroutine!=null) {
val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull()
if(first is Jump && first.identifier!=null) {
optimizationsDone++
return FunctionCall(first.identifier, functionCall.args, functionCall.position)
}
if(first is Return && first.value!=null) {
val constval = first.value?.constValue(program)
if(constval!=null)
return constval
}
}
return super.visit(functionCall)
}
override fun visit(ifStatement: IfStatement): Statement {
super.visit(ifStatement)
if(ifStatement.truepart.containsNoCodeNorVars() && ifStatement.elsepart.containsNoCodeNorVars()) {
optimizationsDone++
return NopStatement.insteadOf(ifStatement)
}
if(ifStatement.truepart.containsNoCodeNorVars() && ifStatement.elsepart.containsCodeOrVars()) {
// invert the condition and move else part to true part
ifStatement.truepart = ifStatement.elsepart
ifStatement.elsepart = AnonymousScope(mutableListOf(), ifStatement.elsepart.position)
ifStatement.condition = PrefixExpression("not", ifStatement.condition, ifStatement.condition.position)
optimizationsDone++
return ifStatement
}
val constvalue = ifStatement.condition.constValue(program)
if(constvalue!=null) {
return if(constvalue.asBooleanValue){
// always true -> keep only if-part
errors.warn("condition is always true", ifStatement.position)
optimizationsDone++
ifStatement.truepart
} else {
// always false -> keep only else-part
errors.warn("condition is always false", ifStatement.position)
optimizationsDone++
ifStatement.elsepart
}
}
return ifStatement
}
override fun visit(forLoop: ForLoop): Statement {
super.visit(forLoop)
if(forLoop.body.containsNoCodeNorVars()) {
// remove empty for loop
optimizationsDone++
return NopStatement.insteadOf(forLoop)
} else if(forLoop.body.statements.size==1) {
val loopvar = forLoop.body.statements[0] as? VarDecl
if(loopvar!=null && loopvar.name==forLoop.loopVar?.nameInSource?.singleOrNull()) {
// remove empty for loop
optimizationsDone++
return NopStatement.insteadOf(forLoop)
}
}
val range = forLoop.iterable as? RangeExpr
if(range!=null) {
if(range.size()==1) {
// for loop over a (constant) range of just a single value-- optimize the loop away
// loopvar/reg = range value , follow by block
val assignment = Assignment(AssignTarget(forLoop.loopRegister, forLoop.loopVar, null, null, forLoop.position), null, range.from, forLoop.position)
forLoop.body.statements.add(0, assignment)
optimizationsDone++
return forLoop.body
}
}
return forLoop
}
override fun visit(whileLoop: WhileLoop): Statement {
super.visit(whileLoop)
val constvalue = whileLoop.condition.constValue(program)
if(constvalue!=null) {
return if(constvalue.asBooleanValue){
// always true -> print a warning, and optimize into a forever-loop
errors.warn("condition is always true", whileLoop.condition.position)
optimizationsDone++
ForeverLoop(whileLoop.body, whileLoop.position)
} else {
// always false -> remove the while statement altogether
errors.warn("condition is always false", whileLoop.condition.position)
optimizationsDone++
NopStatement.insteadOf(whileLoop)
}
}
return whileLoop
}
override fun visit(repeatLoop: RepeatLoop): Statement {
super.visit(repeatLoop)
val constvalue = repeatLoop.untilCondition.constValue(program)
if(constvalue!=null) {
return if(constvalue.asBooleanValue){
// always true -> keep only the statement block (if there are no continue and break statements)
errors.warn("condition is always true", repeatLoop.untilCondition.position)
if(hasContinueOrBreak(repeatLoop.body))
repeatLoop
else {
optimizationsDone++
repeatLoop.body
}
} else {
// always false -> print a warning, and optimize into a forever loop
errors.warn("condition is always false", repeatLoop.untilCondition.position)
optimizationsDone++
ForeverLoop(repeatLoop.body, repeatLoop.position)
}
}
return repeatLoop
}
override fun visit(whenStatement: WhenStatement): Statement {
val choices = whenStatement.choices.toList()
for(choice in choices) {
if(choice.statements.containsNoCodeNorVars())
whenStatement.choices.remove(choice)
}
return super.visit(whenStatement)
}
private fun hasContinueOrBreak(scope: INameScope): Boolean { private fun hasContinueOrBreak(scope: INameScope): Boolean {
class Searcher: IAstVisitor class Searcher: IAstVisitor
@ -341,185 +443,4 @@ internal class StatementOptimizer(private val program: Program,
return s.count > 0 return s.count > 0
} }
override fun visit(jump: Jump): Statement {
val subroutine = jump.identifier?.targetSubroutine(program.namespace)
if(subroutine!=null) {
// if the first instruction in the subroutine is another jump, shortcut this one
val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull()
if(first is Jump) {
optimizationsDone++
return first
}
}
// if the jump is to the next statement, remove the jump
val scope = jump.definingScope()
val label = jump.identifier?.targetStatement(scope)
if(label!=null) {
if(scope.statements.indexOf(label) == scope.statements.indexOf(jump)+1) {
optimizationsDone++
return NopStatement.insteadOf(jump)
}
}
return jump
}
override fun visit(assignment: Assignment): Statement {
if(assignment.aug_op!=null)
throw FatalAstException("augmented assignments should have been converted to normal assignments before this optimizer: $assignment")
if(assignment.target isSameAs assignment.value) {
if(assignment.target.isNotMemory(program.namespace)) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
}
}
val targetIDt = assignment.target.inferType(program, assignment)
if(!targetIDt.isKnown)
throw FatalAstException("can't infer type of assignment target")
val targetDt = targetIDt.typeOrElse(DataType.STRUCT)
val bexpr=assignment.value as? BinaryExpression
if(bexpr!=null) {
val cv = bexpr.right.constValue(program)?.number?.toDouble()
if (cv == null) {
if (bexpr.operator == "+" && targetDt != DataType.FLOAT) {
if (bexpr.left isSameAs bexpr.right && assignment.target isSameAs bexpr.left) {
bexpr.operator = "*"
bexpr.right = NumericLiteralValue.optimalInteger(2, assignment.value.position)
optimizationsDone++
return assignment
}
}
} else {
if (assignment.target isSameAs bexpr.left) {
// remove assignments that have no effect X=X , X+=0, X-=0, X*=1, X/=1, X//=1, A |= 0, A ^= 0, A<<=0, etc etc
// A = A <operator> B
val vardeclDt = (assignment.target.identifier?.targetVarDecl(program.namespace))?.type
when (bexpr.operator) {
"+" -> {
if (cv == 0.0) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
} else if (targetDt in IntegerDatatypes && floor(cv) == cv) {
if ((vardeclDt == VarDeclType.MEMORY && cv in 1.0..3.0) || (vardeclDt != VarDeclType.MEMORY && cv in 1.0..8.0)) {
// replace by several INCs (a bit less when dealing with memory targets)
val decs = AnonymousScope(mutableListOf(), assignment.position)
repeat(cv.toInt()) {
decs.statements.add(PostIncrDecr(assignment.target, "++", assignment.position))
}
return decs
}
}
}
"-" -> {
if (cv == 0.0) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
} else if (targetDt in IntegerDatatypes && floor(cv) == cv) {
if ((vardeclDt == VarDeclType.MEMORY && cv in 1.0..3.0) || (vardeclDt != VarDeclType.MEMORY && cv in 1.0..8.0)) {
// replace by several DECs (a bit less when dealing with memory targets)
val decs = AnonymousScope(mutableListOf(), assignment.position)
repeat(cv.toInt()) {
decs.statements.add(PostIncrDecr(assignment.target, "--", assignment.position))
}
return decs
}
}
}
"*" -> if (cv == 1.0) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
}
"/" -> if (cv == 1.0) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
}
"**" -> if (cv == 1.0) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
}
"|" -> if (cv == 0.0) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
}
"^" -> if (cv == 0.0) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
}
"<<" -> {
if (cv == 0.0) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
}
if (((targetDt == DataType.UWORD || targetDt == DataType.WORD) && cv > 15.0) ||
((targetDt == DataType.UBYTE || targetDt == DataType.BYTE) && cv > 7.0)) {
assignment.value = NumericLiteralValue.optimalInteger(0, assignment.value.position)
assignment.value.linkParents(assignment)
optimizationsDone++
} else {
// replace by in-place lsl(...) call
val scope = AnonymousScope(mutableListOf(), assignment.position)
var numshifts = cv.toInt()
while (numshifts > 0) {
scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("lsl"), assignment.position),
mutableListOf(bexpr.left), true, assignment.position))
numshifts--
}
optimizationsDone++
return scope
}
}
">>" -> {
if (cv == 0.0) {
optimizationsDone++
return NopStatement.insteadOf(assignment)
}
if ((targetDt == DataType.UWORD && cv > 15.0) || (targetDt == DataType.UBYTE && cv > 7.0)) {
assignment.value = NumericLiteralValue.optimalInteger(0, assignment.value.position)
assignment.value.linkParents(assignment)
optimizationsDone++
} else {
// replace by in-place lsr(...) call
val scope = AnonymousScope(mutableListOf(), assignment.position)
var numshifts = cv.toInt()
while (numshifts > 0) {
scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("lsr"), assignment.position),
mutableListOf(bexpr.left), true, assignment.position))
numshifts--
}
optimizationsDone++
return scope
}
}
}
}
}
}
return super.visit(assignment)
}
override fun visit(scope: AnonymousScope): Statement {
val linesToRemove = deduplicateAssignments(scope.statements)
if(linesToRemove.isNotEmpty()) {
linesToRemove.reversed().forEach{scope.statements.removeAt(it)}
}
return super.visit(scope)
}
override fun visit(label: Label): Statement {
// remove duplicate labels
val stmts = label.definingScope().statements
val startIdx = stmts.indexOf(label)
if(startIdx< stmts.lastIndex && stmts[startIdx+1] == label)
return NopStatement.insteadOf(label)
return super.visit(label)
}
} }

View File

@ -17,7 +17,7 @@ internal class UnusedCodeRemover: AstWalker() {
val entrypoint = program.entrypoint() val entrypoint = program.entrypoint()
program.modules.forEach { program.modules.forEach {
callgraph.forAllSubroutines(it) { sub -> callgraph.forAllSubroutines(it) { sub ->
if (sub !== entrypoint && !sub.keepAlways && (sub.calledBy.isEmpty() || (sub.containsNoCodeNorVars() && !sub.isAsmSubroutine))) if (sub !== entrypoint && !sub.keepAlways && (callgraph.calledBy[sub].isNullOrEmpty() || (sub.containsNoCodeNorVars() && !sub.isAsmSubroutine)))
removals.add(IAstModification.Remove(sub, sub.definingScope() as Node)) removals.add(IAstModification.Remove(sub, sub.definingScope() as Node))
} }
} }
@ -30,7 +30,7 @@ internal class UnusedCodeRemover: AstWalker() {
// remove modules that are not imported, or are empty (unless it's a library modules) // remove modules that are not imported, or are empty (unless it's a library modules)
program.modules.forEach { program.modules.forEach {
if (!it.isLibraryModule && (it.importedBy.isEmpty() || it.containsNoCodeNorVars())) if (!it.isLibraryModule && (it.importedBy.isEmpty() || it.containsNoCodeNorVars()))
removals.add(IAstModification.Remove(it, it.parent)) // TODO does removing modules work like this? removals.add(IAstModification.Remove(it, it.parent))
} }
return removals return removals

View File

@ -156,7 +156,7 @@ Design principles and features
- The compiler tries to optimize the program and generated code a bit, but hand-tuning of the - The compiler tries to optimize the program and generated code a bit, but hand-tuning of the
performance or space-critical parts will likely still be required. This is supported by performance or space-critical parts will likely still be required. This is supported by
the ability to easily write embedded assembly code directly in the program source code. the ability to easily write embedded assembly code directly in the program source code.
- There are many built-in functions, such as ``sin``, ``cos``, ``rnd``, ``abs``, ``min``, ``max``, ``sqrt``, ``msb``, ``rol``, ``ror``, ``swap``, ``memset``, ``memcopy``, ``sort`` and ``reverse`` - There are many built-in functions, such as ``sin``, ``cos``, ``rnd``, ``abs``, ``min``, ``max``, ``sqrt``, ``msb``, ``rol``, ``ror``, ``swap``, ``memset``, ``memcopy``, ``substr``, ``sort`` and ``reverse`` (and others)
- Assembling the generated code into a program wil be done by an external cross-assembler tool. - Assembling the generated code into a program wil be done by an external cross-assembler tool.

View File

@ -279,16 +279,23 @@ This @-prefix can also be used for character byte values.
You can concatenate two string literals using '+' (not very useful though) or repeat You can concatenate two string literals using '+' (not very useful though) or repeat
a string literal a given number of times using '*':: a string literal a given number of times using '*'. You can also assign a new string
value to another string. No bounds check is done so be sure the destination string is
large enough to contain the new value::
str string1 = "first part" + "second part" str string1 = "first part" + "second part"
str string2 = "hello!" * 10 str string2 = "hello!" * 10
string1 = string2
string1 = "new value"
.. caution:: .. caution::
Avoid changing strings after they've been created. It's probably best to avoid changing strings after they've been created. This
includes changing certain letters by index, or by assigning a new value, or by
modifying the string via other means for example ``substr`` function and its cousins.
This is because if your program exits and is restarted (without loading it again), This is because if your program exits and is restarted (without loading it again),
it will then start working with the changed strings instead of the original ones. it will then start working with the changed strings instead of the original ones!
The same is true for arrays. The same is true for arrays.
@ -802,6 +809,22 @@ memsetw(address, numwords, wordvalue)
Efficiently set a part of memory to the given (u)word value. Efficiently set a part of memory to the given (u)word value.
But the most efficient will always be to write a specialized fill routine in assembly yourself! But the most efficient will always be to write a specialized fill routine in assembly yourself!
leftstr(source, target, length)
Copies the left side of the source string of the given length to target string.
It is assumed the target string buffer is large enough to contain the result.
Modifies in-place, doesn't return a value (so can't be used in an expression).
rightstr(source, target, length)
Copies the right side of the source string of the given length to target string.
It is assumed the target string buffer is large enough to contain the result.
Modifies in-place, doesn't return a value (so can't be used in an expression).
substr(source, target, start, length)
Copies a segment from the source string, starting at the given index,
and of the given length to target string.
It is assumed the target string buffer is large enough to contain the result.
Modifies in-place, doesn't return a value (so can't be used in an expression).
swap(x, y) swap(x, y)
Swap the values of numerical variables (or memory locations) x and y in a fast way. Swap the values of numerical variables (or memory locations) x and y in a fast way.

View File

@ -2,11 +2,14 @@
TODO TODO
==== ====
- finalize (most) of the still missing "new" assignment asm code generation - BUG FIX: fix register argument clobbering when calling asmsubs. (see fixme_argclobber.p8)
- finalize (most) of the still missing "new" assignment asm code generation
- aliases for imported symbols for example perhaps '%alias print = c64scr.print' - aliases for imported symbols for example perhaps '%alias print = c64scr.print'
- option to load library files from a directory instead of the embedded ones (easier library development/debugging) - option to load library files from a directory instead of the embedded ones (easier library development/debugging)
- investigate support for 8bitguy's Commander X16 platform https://murray2.com/forums/commander-x16.9/ and https://github.com/commanderx16/x16-docs - investigate support for 8bitguy's Commander X16 platform https://www.commanderx16.com and https://github.com/commanderx16/x16-docs
- see if we can group some errors together for instance the (now single) errors about unidentified symbols
More optimizations More optimizations

View File

@ -1,8 +1,7 @@
%import c64utils %import c64utils
;%import c64flt
;%option enable_floats
%zeropage dontuse %zeropage dontuse
main { main {
sub start() { sub start() {

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,6 +1,7 @@
%import c64lib %import c64lib
%import c64utils %import c64utils
spritedata $2000 { spritedata $2000 {
; this memory block contains the sprite data ; this memory block contains the sprite data
; it must start on an address aligned to 64 bytes. ; it must start on an address aligned to 64 bytes.

View File

@ -0,0 +1,47 @@
%import c64lib
%import c64utils
%import c64flt
%zeropage basicsafe
%option enable_floats
; TODO: fix register argument clobbering when calling asmsubs.
; for instance if the first arg goes into Y, and the second in A,
; but when calculating the second argument clobbers Y, the first argument gets destroyed.
main {
sub start() {
function(20, calculate())
asmfunction(20, calculate())
c64.CHROUT('\n')
if @($0400)==@($0402) and @($0401) == @($0403) {
c64scr.print("ok: results are same\n")
} else {
c64scr.print("error: result differ; arg got clobbered\n")
}
}
sub function(ubyte a1, ubyte a2) {
; non-asm function passes via stack, this is ok
@($0400) = a1
@($0401) = a2
}
asmsub asmfunction(ubyte a1 @ Y, ubyte a2 @ A) {
; asm-function passes via registers, risk of clobbering
%asm {{
sty $0402
sta $0403
}}
}
sub calculate() -> ubyte {
Y = 99
return Y
}
}

View File

@ -1,6 +1,9 @@
%import c64lib %import c64lib
%import c64graphics %import c64graphics
; TODO fix compiler errors when compiling without optimizations
main { main {
sub start() { sub start() {

View File

@ -0,0 +1,54 @@
%import c64lib
%import c64flt
%import c64graphics
%zeropage floatsafe
; Draw a mandelbrot in graphics mode (the image will be 256 x 200 pixels).
; NOTE: this will take an eternity to draw on a real c64.
; even in Vice in warp mode (700% speed on my machine) it's slow, but you can see progress
; TODO fix compiler errors when compiling without optimizations
main {
const ubyte width = 255
const ubyte height = 200
const ubyte max_iter = 16
sub start() {
graphics.enable_bitmap_mode()
ubyte pixelx
ubyte pixely
for pixely in 0 to height-1 {
float yy = (pixely as float)/0.4/height - 1.0
for pixelx in 0 to width-1 {
float xx = (pixelx as float)/0.3/width - 2.2
float xsquared = 0.0
float ysquared = 0.0
float x = 0.0
float y = 0.0
ubyte iter = 0
while iter<max_iter and xsquared+ysquared<4.0 {
y = x*y*2.0 + yy
x = xsquared - ysquared + xx
xsquared = x*x
ysquared = y*y
iter++
}
if iter & 1 {
graphics.plotx = pixelx
graphics.plot(pixely)
}
}
}
forever {
}
}
}

View File

@ -30,7 +30,7 @@ main {
float y = 0.0 float y = 0.0
ubyte iter = 0 ubyte iter = 0
while (iter<max_iter and xsquared+ysquared<4.0) { while iter<max_iter and xsquared+ysquared<4.0 {
y = x*y*2.0 + yy y = x*y*2.0 + yy
x = xsquared - ysquared + xx x = xsquared - ysquared + xx
xsquared = x*x xsquared = x*x

View File

@ -1,36 +0,0 @@
%import c64lib
%zeropage basicsafe
main {
sub start() {
str s1 = "apple"
str s2 = "banana"
byte[] a1 = [66,77,88,0]
ubyte i1 = 101
uword w1 = 000
c64.STROUT(s1)
c64.CHROUT('\n')
c64.STROUT(a1)
c64.CHROUT('\n')
c64scr.print("bla\n")
; c64scr.print_uwhex(s1, true)
; w1 = &s1
; c64scr.print_uwhex(w1, true)
;
; c64scr.print_uwhex(a1, true)
; w1 = &a1
; c64scr.print_uwhex(w1, true)
;
; s1 = s1
; s1 = s2
; s2 = "zzz"
}
}

View File

@ -1,6 +1,9 @@
%import c64utils %import c64utils
%zeropage basicsafe %zeropage basicsafe
; TODO fix compiler errors when compiling ( /= )
main { main {
struct Color { struct Color {

View File

@ -12,10 +12,8 @@ main {
ubyte color ubyte color
forever { forever {
float x = sin(t) ubyte xx=(sin(t) * width/2.2) + width/2.0 as ubyte
float y = cos(t*1.1356) ubyte yy=(cos(t*1.1356) * height/2.2) + height/2.0 as ubyte
ubyte xx=(x * width/2.2) + width/2.0 as ubyte
ubyte yy=(y * height/2.2) + height/2.0 as ubyte
c64scr.setcc(xx, yy, 81, color) c64scr.setcc(xx, yy, 81, color)
t += 0.08 t += 0.08
color++ color++

View File

@ -7,6 +7,11 @@
; staged speed increase ; staged speed increase
; some simple sound effects ; some simple sound effects
; TODO fix noCollision() at bottom when compiled without optimizations (codegen issue).
main { main {
const ubyte boardOffsetX = 14 const ubyte boardOffsetX = 14
@ -107,6 +112,7 @@ waitkey:
break break
} }
} }
if dropypos>ypos { if dropypos>ypos {
ypos = dropypos ypos = dropypos
sound.blockdrop() sound.blockdrop()
@ -543,6 +549,7 @@ blocklogic {
sub noCollision(ubyte xpos, ubyte ypos) -> ubyte { sub noCollision(ubyte xpos, ubyte ypos) -> ubyte {
ubyte i ubyte i
for i in 15 downto 0 { for i in 15 downto 0 {
; TODO FIX THIS when compiling without optimizations (codegen problem: clobbering register arguments, see fixme_argclobber):
if currentBlock[i] and c64scr.getchr(xpos + (i&3), ypos+i/4)!=32 if currentBlock[i] and c64scr.getchr(xpos + (i&3), ypos+i/4)!=32
return false return false
} }

View File

@ -2,44 +2,45 @@
%import c64utils %import c64utils
%import c64flt %import c64flt
%zeropage basicsafe %zeropage basicsafe
%option enable_floats
; TODO: fix register argument clobbering when calling asmsubs.
; for instance if the first arg goes into Y, and the second in A,
; but when calculating the second argument clobbers Y, the first argument gets destroyed.
main { main {
sub start() { sub start() {
function(20, calculate())
asmfunction(20, calculate())
A += 50 c64.CHROUT('\n')
A += Y + 1 if @($0400)==@($0402) and @($0401) == @($0403) {
A -= Y + 1 c64scr.print("ok: results are same\n")
A += Y - 1 } else {
A -= Y - 1 c64scr.print("error: result differ; arg got clobbered\n")
}
}
A += Y + 2 sub function(ubyte a1, ubyte a2) {
A -= Y + 2 ; non-asm function passes via stack, this is ok
A += Y - 2 @($0400) = a1
A -= Y - 2 @($0401) = a2
}
; ubyte ubb asmsub asmfunction(ubyte a1 @ Y, ubyte a2 @ A) {
; byte bb ; asm-function passes via registers, risk of clobbering
; uword uww %asm {{
; word ww sty $0402
; word ww2 sta $0403
; }}
; A = ubb*0 }
; Y = ubb*1
; A = ubb*2
; Y = ubb*4
; A = ubb*8
; Y = ubb*16
; A = ubb*32
; Y = ubb*64
; A = ubb*128
; Y = ubb+ubb+ubb
; A = ubb+ubb+ubb+ubb
; ww = ww2+ww2
; ww = ww2+ww2+ww2
; ww = ww2+ww2+ww2+ww2
sub calculate() -> ubyte {
Y = 99
return Y
} }
} }

View File

@ -3,6 +3,7 @@
%import c64graphics %import c64graphics
%option enable_floats %option enable_floats
; TODO fix compiler errors when compiling without optimizations
main { main {