From 487faf3a08d284d1e3741a4f346673e7628728b4 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Thu, 10 Jan 2019 23:09:58 +0100 Subject: [PATCH] optimize and fix for loops --- compiler/src/prog8/ast/AST.kt | 3 +- compiler/src/prog8/compiler/Compiler.kt | 66 +++++++++++++------ .../src/prog8/compiler/intermediate/Opcode.kt | 4 ++ .../src/prog8/compiler/target/c64/AsmGen.kt | 20 +++++- .../src/prog8/optimizing/ConstantFolding.kt | 57 ++++++++++++++-- compiler/src/prog8/stackvm/StackVm.kt | 8 +-- examples/test.p8 | 26 ++++++-- examples/wizzine.p8 | 18 +++-- 8 files changed, 156 insertions(+), 46 deletions(-) diff --git a/compiler/src/prog8/ast/AST.kt b/compiler/src/prog8/ast/AST.kt index 6fd5521af..0082f7bdf 100644 --- a/compiler/src/prog8/ast/AST.kt +++ b/compiler/src/prog8/ast/AST.kt @@ -199,6 +199,7 @@ interface IAstProcessor { fun process(range: RangeExpr): IExpression { range.from = range.from.process(this) range.to = range.to.process(this) + range.step = range.step.process(this) return range } @@ -1316,7 +1317,7 @@ class RangeExpr(var from: IExpression, fromDt==DataType.STR_S && toDt==DataType.STR_S -> DataType.STR_S fromDt==DataType.STR_PS && toDt==DataType.STR_PS -> DataType.STR_PS fromDt==DataType.WORD || toDt==DataType.WORD -> DataType.WORD - fromDt==DataType.BYTE || toDt==DataType.BYTE -> DataType.UBYTE + fromDt==DataType.BYTE || toDt==DataType.BYTE -> DataType.BYTE else -> DataType.UBYTE } } diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index b5f2a708d..41ac18c0c 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -257,6 +257,16 @@ private class StatementTranslator(private val prog: IntermediateProgram, } } + private fun opcodeCompare(dt: DataType): Opcode { + return when (dt) { + DataType.UBYTE -> Opcode.CMP_UB + DataType.BYTE -> Opcode.CMP_B + DataType.UWORD -> Opcode.CMP_UW + DataType.WORD -> Opcode.CMP_W + else -> throw CompilerException("invalid dt $dt") + } + } + private fun opcodePushvar(dt: DataType): Opcode { return when (dt) { DataType.UBYTE, DataType.BYTE -> Opcode.PUSH_VAR_BYTE @@ -1729,10 +1739,18 @@ private class StatementTranslator(private val prog: IntermediateProgram, when (loopVarDt) { DataType.UBYTE -> { if (range.first < 0 || range.first > 255 || range.last < 0 || range.last > 255) - throw CompilerException("range out of bounds for byte") + throw CompilerException("range out of bounds for ubyte") } DataType.UWORD -> { if (range.first < 0 || range.first > 65535 || range.last < 0 || range.last > 65535) + throw CompilerException("range out of bounds for uword") + } + DataType.BYTE -> { + if (range.first < -128 || range.first > 127 || range.last < -128 || range.last > 127) + throw CompilerException("range out of bounds for byte") + } + DataType.WORD -> { + if (range.first < -32768 || range.first > 32767 || range.last < -32768 || range.last > 32767) throw CompilerException("range out of bounds for word") } else -> throw CompilerException("range must be byte or word") @@ -1840,11 +1858,8 @@ private class StatementTranslator(private val prog: IntermediateProgram, prog.label(continueLabel) prog.instr(opcodeIncvar(indexVarType), callLabel = indexVar.scopedname) - - // TODO: optimize edge cases if last value = 255 or 0 (for bytes) etc. to avoid PUSH_BYTE / SUB opcodes and make use of the wrapping around of the value. - prog.instr(opcodePush(indexVarType), Value(indexVarType, numElements)) prog.instr(opcodePushvar(indexVarType), callLabel = indexVar.scopedname) - prog.instr(opcodeSub(indexVarType)) + prog.instr(opcodeCompare(indexVarType), Value(indexVarType, numElements)) if(indexVarType==DataType.UWORD) prog.instr(Opcode.JNZW, callLabel = loopLabel) else @@ -1889,16 +1904,25 @@ private class StatementTranslator(private val prog: IntermediateProgram, prog.label(loopLabel) translate(body) prog.label(continueLabel) + val numberOfIncDecsForOptimize = 8 when { - range.step==1 -> prog.instr(opcodeIncvar(varDt), callLabel = varname) - range.step==-1 -> prog.instr(opcodeDecvar(varDt), callLabel = varname) - range.step>1 -> { + range.step in 1..numberOfIncDecsForOptimize -> { + repeat(range.step) { + prog.instr(opcodeIncvar(varDt), callLabel = varname) + } + } + range.step in -1 downTo -numberOfIncDecsForOptimize -> { + repeat(abs(range.step)) { + prog.instr(opcodeDecvar(varDt), callLabel = varname) + } + } + range.step>numberOfIncDecsForOptimize -> { prog.instr(opcodePushvar(varDt), callLabel = varname) prog.instr(opcodePush(varDt), Value(varDt, range.step)) prog.instr(opcodeAdd(varDt)) prog.instr(opcodePopvar(varDt), callLabel = varname) } - range.step<1 -> { + range.step { prog.instr(opcodePushvar(varDt), callLabel = varname) prog.instr(opcodePush(varDt), Value(varDt, abs(range.step))) prog.instr(opcodeSub(varDt)) @@ -1906,17 +1930,21 @@ private class StatementTranslator(private val prog: IntermediateProgram, } } - // TODO: optimize edge cases if last value = 255 or 0 (for bytes) etc. to avoid PUSH_BYTE / SUB opcodes and make use of the wrapping around of the value. - // TODO: ubyte/uword can't count down to 0 with negative step because test value will be <0 which causes "value out of range" crash - prog.instr(opcodePush(varDt), Value(varDt, range.last + range.step)) - prog.instr(opcodePushvar(varDt), callLabel = varname) - prog.instr(opcodeSub(varDt)) - val loopvarJumpOpcode = when(varDt) { - DataType.UBYTE, DataType.BYTE -> Opcode.JNZ - DataType.UWORD, DataType.WORD -> Opcode.JNZW - else -> throw CompilerException("invalid loop var datatype (expected byte or word) $varDt of var $varname") + if(range.last==0) { + // optimize for the for loop that counts to 0 + prog.instr(if(range.first>0) Opcode.BPOS else Opcode.BNEG, callLabel = loopLabel) + } else { + prog.instr(opcodePushvar(varDt), callLabel = varname) + val checkValue = + when (varDt) { + DataType.UBYTE -> (range.last + range.step) and 255 + DataType.UWORD -> (range.last + range.step) and 65535 + DataType.BYTE, DataType.WORD -> range.last + range.step + else -> throw CompilerException("invalid loop var dt $varDt") + } + prog.instr(opcodeCompare(varDt), Value(varDt, checkValue)) + prog.instr(Opcode.BNZ, callLabel = loopLabel) } - prog.instr(loopvarJumpOpcode, callLabel = loopLabel) prog.label(breakLabel) prog.instr(Opcode.NOP) // note: ending value of loop register / variable is *undefined* after this point! diff --git a/compiler/src/prog8/compiler/intermediate/Opcode.kt b/compiler/src/prog8/compiler/intermediate/Opcode.kt index ffd6491a1..c5faf6ce3 100644 --- a/compiler/src/prog8/compiler/intermediate/Opcode.kt +++ b/compiler/src/prog8/compiler/intermediate/Opcode.kt @@ -208,6 +208,10 @@ enum class Opcode { NOTEQUAL_BYTE, NOTEQUAL_WORD, NOTEQUAL_F, + CMP_B, // sets processor status flags based on comparison, instead of actually storing a result value + CMP_UB, // sets processor status flags based on comparison, instead of actually storing a result value + CMP_W, // sets processor status flags based on comparison, instead of actually storing a result value + CMP_UW, // sets processor status flags based on comparison, instead of actually storing a result value // array access and simple manipulations READ_INDEXED_VAR_BYTE, diff --git a/compiler/src/prog8/compiler/target/c64/AsmGen.kt b/compiler/src/prog8/compiler/target/c64/AsmGen.kt index 763b09a19..ff35fe0f0 100644 --- a/compiler/src/prog8/compiler/target/c64/AsmGen.kt +++ b/compiler/src/prog8/compiler/target/c64/AsmGen.kt @@ -3105,8 +3105,26 @@ class AsmGen(val options: CompilationOptions, val program: IntermediateProgram, adc ${(ESTACK_HI+1).toHex()},x sta ${(ESTACK_HI+1).toHex()},x """ - } + }, + AsmPattern(listOf(Opcode.PUSH_VAR_BYTE, Opcode.CMP_B), listOf(Opcode.PUSH_VAR_BYTE, Opcode.CMP_UB)) { segment -> + // this pattern is encountered as part of the loop bound condition in for loops (var + cmp + jz/jnz) + val cmpval = segment[1].arg!!.integerValue() + " lda ${segment[0].callLabel} | cmp #$cmpval " + }, + AsmPattern(listOf(Opcode.PUSH_VAR_WORD, Opcode.CMP_W), listOf(Opcode.PUSH_VAR_WORD, Opcode.CMP_UW)) { segment -> + // this pattern is encountered as part of the loop bound condition in for loops (var + cmp + jz/jnz) + """ + lda ${segment[0].callLabel} + cmp #<${hexVal(segment[1])} + bne + + lda ${segment[0].callLabel}+1 + cmp #>${hexVal(segment[1])} + bne + + lda #0 ++ + """ + } ) diff --git a/compiler/src/prog8/optimizing/ConstantFolding.kt b/compiler/src/prog8/optimizing/ConstantFolding.kt index 05e2f93dc..48ee68756 100644 --- a/compiler/src/prog8/optimizing/ConstantFolding.kt +++ b/compiler/src/prog8/optimizing/ConstantFolding.kt @@ -447,11 +447,58 @@ class ConstantFolding(private val namespace: INameScope, private val heap: HeapV } } - override fun process(range: RangeExpr): IExpression { - range.from = range.from.process(this) - range.to = range.to.process(this) - range.step = range.step.process(this) - return super.process(range) + override fun process(forLoop: ForLoop): IStatement { + + fun adjustRangeDt(rangeFrom: LiteralValue, targetDt: DataType, rangeTo: LiteralValue, stepLiteral: LiteralValue?, range: RangeExpr): RangeExpr { + val newFrom = rangeFrom.intoDatatype(targetDt) + val newTo = rangeTo.intoDatatype(targetDt) + if (newFrom != null && newTo != null) { + val newStep: IExpression = + if (stepLiteral != null) (stepLiteral.intoDatatype(targetDt) ?: stepLiteral) else range.step + return RangeExpr(newFrom, newTo, newStep, range.position) + } + return range + } + + // adjust the datatype of a range expression in for loops to the loop variable. + val resultStmt = super.process(forLoop) as ForLoop + val iterableRange = resultStmt.iterable as? RangeExpr ?: return resultStmt + val rangeFrom = iterableRange.from as? LiteralValue + val rangeTo = iterableRange.to as? LiteralValue + if(rangeFrom==null || rangeTo==null) return resultStmt + + val loopvar = resultStmt.loopVar!!.targetStatement(namespace) as? VarDecl + if(loopvar!=null) { + val stepLiteral = iterableRange.step as? LiteralValue + when(loopvar.datatype) { + DataType.UBYTE -> { + if(rangeFrom.type!=DataType.UBYTE) { + // attempt to translate the iterable into ubyte values + resultStmt.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange) + } + } + DataType.BYTE -> { + if(rangeFrom.type!=DataType.BYTE) { + // attempt to translate the iterable into byte values + resultStmt.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange) + } + } + DataType.UWORD -> { + if(rangeFrom.type!=DataType.UWORD) { + // attempt to translate the iterable into uword values + resultStmt.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange) + } + } + DataType.WORD -> { + if(rangeFrom.type!=DataType.WORD) { + // attempt to translate the iterable into word values + resultStmt.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange) + } + } + else -> throw FatalAstException("invalid loopvar datatype $loopvar") + } + } + return resultStmt } override fun process(literalValue: LiteralValue): LiteralValue { diff --git a/compiler/src/prog8/stackvm/StackVm.kt b/compiler/src/prog8/stackvm/StackVm.kt index 8a51ad24d..a1a1f49cd 100644 --- a/compiler/src/prog8/stackvm/StackVm.kt +++ b/compiler/src/prog8/stackvm/StackVm.kt @@ -379,25 +379,25 @@ class StackVm(private var traceOutputFile: String?) { checkDt(second, DataType.FLOAT) evalstack.push(second.add(top)) } - Opcode.SUB_UB -> { + Opcode.SUB_UB, Opcode.CMP_UB -> { val (top, second) = evalstack.pop2() checkDt(top, DataType.UBYTE) checkDt(second, DataType.UBYTE) evalstack.push(second.sub(top)) } - Opcode.SUB_UW -> { + Opcode.SUB_UW, Opcode.CMP_UW -> { val (top, second) = evalstack.pop2() checkDt(top, DataType.UWORD) checkDt(second, DataType.UWORD) evalstack.push(second.sub(top)) } - Opcode.SUB_B -> { + Opcode.SUB_B, Opcode.CMP_B -> { val (top, second) = evalstack.pop2() checkDt(top, DataType.BYTE) checkDt(second, DataType.BYTE) evalstack.push(second.sub(top)) } - Opcode.SUB_W -> { + Opcode.SUB_W, Opcode.CMP_W -> { val (top, second) = evalstack.pop2() checkDt(top, DataType.WORD) checkDt(second, DataType.WORD) diff --git a/examples/test.p8 b/examples/test.p8 index fc5684eba..e01af000c 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -4,13 +4,27 @@ sub start() { - const word height=25 + ubyte i + byte j + uword uw + word w - word rz=33 - word persp = (rz+200) - persp = rz / 25 - persp = rz / height - persp = (rz+200) / height + for i in 5 to 0 step -1 { + c64scr.print_ub(i) + c64.CHROUT('\n') + } + c64.CHROUT('\n') + for j in 5 to 0 step -1 { + c64scr.print_b(j) + c64.CHROUT('\n') + } + c64.CHROUT('\n') + + for j in -5 to 0 { + c64scr.print_b(j) + c64.CHROUT('\n') + } + c64.CHROUT('\n') } } diff --git a/examples/wizzine.p8 b/examples/wizzine.p8 index 25566132e..383698ff5 100644 --- a/examples/wizzine.p8 +++ b/examples/wizzine.p8 @@ -52,16 +52,14 @@ sub irq() { angle++ c64.MSIGX=0 - ubyte i=14 -nextsprite: ; @todo should be a for loop from 14 to 0 step -2 but this causes a value out of range error at the moment - uword x = sin8u(angle*2-i*8) as uword + 50 - ubyte y = cos8u(angle*3-i*8) / 2 + 70 - c64.SPXY[i] = lsb(x) - c64.SPXY[i+1] = y - lsl(c64.MSIGX) - if msb(x) c64.MSIGX++ - i-=2 - if_pl goto nextsprite + for ubyte i in 14 to 0 step -2 { + uword x = sin8u(angle*2-i*8) as uword + 50 + ubyte y = cos8u(angle*3-i*8) / 2 + 70 + c64.SPXY[i] = lsb(x) + c64.SPXY[i+1] = y + lsl(c64.MSIGX) + if msb(x) c64.MSIGX++ + } c64.EXTCOL++ }