From 7eed1ebbf8b74aa3fcc0f1fa52ed14633957f79e Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Wed, 10 Jul 2019 02:54:39 +0200 Subject: [PATCH] optimized typecasting more --- compiler/src/prog8/ast/base/Base.kt | 7 ++ .../prog8/ast/expressions/AstExpressions.kt | 20 +++-- .../src/prog8/compiler/AstToSourceCode.kt | 5 +- compiler/src/prog8/compiler/Main.kt | 2 +- .../prog8/optimizer/SimplifyExpressions.kt | 30 +++++-- examples/test.p8 | 79 ++----------------- 6 files changed, 49 insertions(+), 94 deletions(-) diff --git a/compiler/src/prog8/ast/base/Base.kt b/compiler/src/prog8/ast/base/Base.kt index b30f5afef..513d94713 100644 --- a/compiler/src/prog8/ast/base/Base.kt +++ b/compiler/src/prog8/ast/base/Base.kt @@ -44,6 +44,13 @@ enum class DataType { in WordDatatypes -> other in ByteDatatypes else -> true } + + infix fun equalsSize(other: DataType) = + when(this) { + in ByteDatatypes -> other in ByteDatatypes + in WordDatatypes -> other in WordDatatypes + else -> false + } } enum class Register { diff --git a/compiler/src/prog8/ast/expressions/AstExpressions.kt b/compiler/src/prog8/ast/expressions/AstExpressions.kt index 7248b4418..504a156d5 100644 --- a/compiler/src/prog8/ast/expressions/AstExpressions.kt +++ b/compiler/src/prog8/ast/expressions/AstExpressions.kt @@ -519,17 +519,15 @@ open class LiteralValue(val type: DataType, return LiteralValue(targettype, floatvalue = wordvalue!!.toDouble(), position = position) } DataType.FLOAT -> { - if(floor(floatvalue!!) ==floatvalue) { - val value = floatvalue.toInt() - if (targettype == DataType.BYTE && value in -128..127) - return LiteralValue(targettype, bytevalue = value.toShort(), position = position) - if (targettype == DataType.UBYTE && value in 0..255) - return LiteralValue(targettype, bytevalue = value.toShort(), position = position) - if (targettype == DataType.WORD && value in -32768..32767) - return LiteralValue(targettype, wordvalue = value, position = position) - if (targettype == DataType.UWORD && value in 0..65535) - return LiteralValue(targettype, wordvalue = value, position = position) - } + val value = floatvalue!!.toInt() + if (targettype == DataType.BYTE && value in -128..127) + return LiteralValue(targettype, bytevalue = value.toShort(), position = position) + if (targettype == DataType.UBYTE && value in 0..255) + return LiteralValue(targettype, bytevalue = value.toShort(), position = position) + if (targettype == DataType.WORD && value in -32768..32767) + return LiteralValue(targettype, wordvalue = value, position = position) + if (targettype == DataType.UWORD && value in 0..65535) + return LiteralValue(targettype, wordvalue = value, position = position) } in StringDatatypes -> { if(targettype in StringDatatypes) diff --git a/compiler/src/prog8/compiler/AstToSourceCode.kt b/compiler/src/prog8/compiler/AstToSourceCode.kt index 8b41815f2..9cd5089bb 100644 --- a/compiler/src/prog8/compiler/AstToSourceCode.kt +++ b/compiler/src/prog8/compiler/AstToSourceCode.kt @@ -360,8 +360,9 @@ class AstToSourceCode(val output: (text: String) -> Unit): IAstVisitor { } override fun visit(typecast: TypecastExpression) { + output("(") typecast.expression.accept(this) - output(" as ${datatypeString(typecast.type)} ") + output(" as ${datatypeString(typecast.type)}) ") } override fun visit(memread: DirectMemoryRead) { @@ -424,6 +425,6 @@ class AstToSourceCode(val output: (text: String) -> Unit): IAstVisitor { outputln("") } override fun visit(nopStatement: NopStatement) { - TODO("NOP???") + output("; NOP") } } diff --git a/compiler/src/prog8/compiler/Main.kt b/compiler/src/prog8/compiler/Main.kt index 2371dab64..707dfa6e4 100644 --- a/compiler/src/prog8/compiler/Main.kt +++ b/compiler/src/prog8/compiler/Main.kt @@ -87,7 +87,7 @@ fun compileProgram(filepath: Path, programAst.checkValid(compilerOptions) // check if final tree is valid programAst.checkRecursion() // check if there are recursive subroutine calls - // printAst(programAst) + printAst(programAst) // namespace.debugPrint() if(generateVmCode) { diff --git a/compiler/src/prog8/optimizer/SimplifyExpressions.kt b/compiler/src/prog8/optimizer/SimplifyExpressions.kt index 14e58c94e..d7be2ee38 100644 --- a/compiler/src/prog8/optimizer/SimplifyExpressions.kt +++ b/compiler/src/prog8/optimizer/SimplifyExpressions.kt @@ -36,8 +36,19 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying } override fun visit(typecast: TypecastExpression): IExpression { - // remove redundant typecasts var tc = typecast + + // try to statically convert a literal value into one of the desired type + val literal = tc.expression as? LiteralValue + if(literal!=null) { + val newLiteral = literal.cast(tc.type) + if(newLiteral!=null && newLiteral!==literal) { + optimizationsDone++ + return newLiteral + } + } + + // remove redundant typecasts while(true) { val expr = tc.expression if(expr !is TypecastExpression || expr.type!=tc.type) { @@ -50,16 +61,21 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying } } - // if the previous typecast was casting to a 'bigger' type, just ignore that one val subTc = tc.expression as? TypecastExpression - if(subTc!=null && subTc.type largerThan tc.type) { - subTc.type = tc.type - subTc.parent = tc.parent - optimizationsDone++ - return subTc + if(subTc!=null) { + // if the previous typecast was casting to a 'bigger' type, just ignore that one + // if the previous typecast was casting to a similar type, ignore that one + if(subTc.type largerThan tc.type || subTc.type equalsSize tc.type) { + subTc.type = tc.type + subTc.parent = tc.parent + optimizationsDone++ + return subTc + } } + return super.visit(tc) } + optimizationsDone++ tc = expr } diff --git a/examples/test.p8 b/examples/test.p8 index c273bebd2..71d281d13 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,85 +1,18 @@ %import c64utils %zeropage basicsafe +%option enable_floats ~ main { sub start() { - ubyte aa = 100 - ubyte yy = 22 - uword uw = (aa as uword)*yy - c64scr.print("stack (255?): ") - c64scr.print_ub(X) - c64.CHROUT('\n') + word zc + word qq = zc>>13 + ubyte[] colors = [1,2,3,4,5,6,7,8] - aa=30 - yy=2 + uword bb = zc>>13 + c64.SPCOL[0] = colors[(zc>>13) as byte + 4] - c64scr.print_ub(7) - c64scr.print("?: ") - check(3, 4) - - c64scr.print_ub(aa+yy) - c64scr.print("?: ") - check(aa, yy) - aa++ - c64scr.print_ub(aa+yy) - c64scr.print("?: ") - check(aa, yy) - aa++ - c64scr.print_ub(aa+yy) - c64scr.print("?: ") - check(aa, yy) - - c64scr.print_uw(uw) - c64scr.print("?: ") - checkuw(uw) - uw++ - c64scr.print_uw(uw) - c64scr.print("?: ") - checkuw(uw) - uw++ - c64scr.print_uw(uw) - c64scr.print("?: ") - checkuw(uw) - - c64scr.print("stack (255?): ") - c64scr.print_ub(X) } - sub checkuw(uword uw) { - when uw { - 12345 -> c64scr.print("12345") - 12346 -> c64scr.print("12346") - 2200 -> c64scr.print("2200") - 2202 -> c64scr.print("2202") - 12347 -> c64scr.print("12347") - else -> c64scr.print("not in table") - } - c64.CHROUT('\n') - } - - sub check(ubyte a, ubyte y) { - when a+y { - 10 -> { - c64scr.print("ten") - } - 5, 6, 7 -> c64scr.print("five or six or seven") - 30 -> c64scr.print("thirty") - 31 -> c64scr.print("thirty1") - 32 -> c64scr.print("thirty2") - 33 -> c64scr.print("thirty3") - 99 -> c64scr.print("nn") - 55 -> { - ; should be optimized away - } - 56 -> { - ; should be optimized away - } - else -> { - c64scr.print("not in table") - } - } - c64.CHROUT('\n') - } }