diff --git a/compiler/src/prog8/ast/processing/AstChecker.kt b/compiler/src/prog8/ast/processing/AstChecker.kt index a9a9f7001..9d48748c6 100644 --- a/compiler/src/prog8/ast/processing/AstChecker.kt +++ b/compiler/src/prog8/ast/processing/AstChecker.kt @@ -150,14 +150,14 @@ internal class AstChecker(private val program: Program, checkResult.add(ExpressionError("word loop variable can only loop over bytes or words", forLoop.position)) } DataType.FLOAT -> { - if(iterableDt!= DataType.FLOAT && iterableDt != DataType.ARRAY_F) - checkResult.add(ExpressionError("float loop variable can only loop over floats", forLoop.position)) + checkResult.add(ExpressionError("for loop only supports integers", forLoop.position)) } else -> checkResult.add(ExpressionError("loop variable must be numeric type", forLoop.position)) } } } } + super.visit(forLoop) } @@ -772,8 +772,11 @@ internal class AstChecker(private val program: Program, super.visit(range) val from = range.from.constValue(program) val to = range.to.constValue(program) - val stepLv = range.step.constValue(program) ?: NumericLiteralValue(DataType.UBYTE, 1, range.position) - if (stepLv.type !in IntegerDatatypes || stepLv.number.toInt() == 0) { + val stepLv = range.step.constValue(program) + if(stepLv==null) { + err("range step must be a constant integer") + return + } else if (stepLv.type !in IntegerDatatypes || stepLv.number.toInt() == 0) { err("range step must be an integer != 0") return } diff --git a/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt b/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt index a49837735..c13569de0 100644 --- a/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt +++ b/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt @@ -952,71 +952,180 @@ internal class AsmGen2(val program: Program, } private fun translate(stmt: ForLoop) { - val iterableDt = stmt.iterable.inferType(program) - val loopLabel = makeLabel("for_loop") - val endLabel = makeLabel("for_end") + val iterableDt = stmt.iterable.inferType(program)!! when(stmt.iterable) { is RangeExpr -> { val range = (stmt.iterable as RangeExpr).toConstantIntegerRange() - if(range==null) - TODO("non-const range loop") - if(range.isEmpty()) - throw AssemblyError("empty range") - when(iterableDt) { - in ByteDatatypes -> { - if(stmt.loopRegister!=null) { + if(range==null) { + translateForOverNonconstRange(stmt, iterableDt, stmt.iterable as RangeExpr) + } else { + if (range.isEmpty()) + throw AssemblyError("empty range") + translateForOverConstRange(stmt, iterableDt, range) + } + } + is IdentifierReference -> { + translateForOverIterableVar(stmt, iterableDt, stmt.iterable as IdentifierReference) + } + else -> throw AssemblyError("can't iterate over ${stmt.iterable}") + } + } - // loop register over range + private fun translateForOverNonconstRange(stmt: ForLoop, iterableDt: DataType, range: RangeExpr) { + TODO("non-const range loop") + } - if(stmt.loopRegister!=Register.A) - throw AssemblyError("can only use A") - when { - range.step==1 -> { - // step = 1 - val counterLabel = makeLabel("for_counter") - out(""" + private fun translateForOverIterableVar(stmt: ForLoop, iterableDt: DataType, ident: IdentifierReference) { + val loopLabel = makeLabel("for_loop") + val endLabel = makeLabel("for_end") + val iterableName = asmIdentifierName(ident) + val decl = ident.targetVarDecl(program.namespace)!! + when(iterableDt) { + DataType.STR, DataType.STR_S -> { + if(stmt.loopRegister!=null && stmt.loopRegister!=Register.A) + throw AssemblyError("can only use A") + out(""" + lda #<$iterableName + ldy #>$iterableName + sta $loopLabel+1 + sty $loopLabel+2 +$loopLabel lda ${65535.toHex()} ; modified + beq $endLabel""") + if(stmt.loopVar!=null) + out(" sta ${asmIdentifierName(stmt.loopVar!!)}") + translate(stmt.body) + out(""" + inc $loopLabel+1 + bne $loopLabel + inc $loopLabel+2 + bne $loopLabel +$endLabel""") + } + DataType.ARRAY_UB, DataType.ARRAY_B -> { + val length = decl.arraysize!!.size() + if(stmt.loopRegister!=null && stmt.loopRegister!=Register.A) + throw AssemblyError("can only use A") + val counterLabel = makeLabel("for_counter") + val modifiedLabel = makeLabel("for_modified") + out(""" + lda #<$iterableName + ldy #>$iterableName + sta $modifiedLabel+1 + sty $modifiedLabel+2 + ldy #0 +$loopLabel sty $counterLabel + cpy #$length + beq $endLabel +$modifiedLabel lda ${65535.toHex()},y ; modified""") + if(stmt.loopVar!=null) + out(" sta ${asmIdentifierName(stmt.loopVar!!)}") + translate(stmt.body) + out(""" + ldy $counterLabel + iny + jmp $loopLabel +$counterLabel .byte 0 +$endLabel""") + } + DataType.ARRAY_W, DataType.ARRAY_UW -> { + val length = decl.arraysize!!.size()!! * 2 + if(stmt.loopRegister!=null) + throw AssemblyError("can't use register to loop over words") + val counterLabel = makeLabel("for_counter") + val modifiedLabel = makeLabel("for_modified") + val modifiedLabel2 = makeLabel("for_modified2") + val loopvarName = asmIdentifierName(stmt.loopVar!!) + out(""" + lda #<$iterableName + ldy #>$iterableName + sta $modifiedLabel+1 + sty $modifiedLabel+2 + lda #<$iterableName+1 + ldy #>$iterableName+1 + sta $modifiedLabel2+1 + sty $modifiedLabel2+2 + ldy #0 +$loopLabel sty $counterLabel + cpy #$length + beq $endLabel +$modifiedLabel lda ${65535.toHex()},y ; modified + sta $loopvarName +$modifiedLabel2 lda ${65535.toHex()},y ; modified + sta $loopvarName+1""") + translate(stmt.body) + out(""" + ldy $counterLabel + iny + iny + jmp $loopLabel +$counterLabel .byte 0 +$endLabel""") + } + DataType.ARRAY_F -> { + throw AssemblyError("for loop with floating point variables is not supported") + } + else -> throw AssemblyError("can't iterate over $iterableDt") + } + } + + private fun translateForOverConstRange(stmt: ForLoop, iterableDt: DataType, range: IntProgression) { + val loopLabel = makeLabel("for_loop") + val endLabel = makeLabel("for_end") + when(iterableDt) { + in ByteDatatypes -> { + if(stmt.loopRegister!=null) { + + // loop register over range + + if(stmt.loopRegister!=Register.A) + throw AssemblyError("can only use A") + when { + range.step==1 -> { + // step = 1 + val counterLabel = makeLabel("for_counter") + out(""" lda #${range.first} sta $loopLabel+1 lda #${range.last-range.first+1 and 255} sta $counterLabel $loopLabel lda #0 ; modified""") - translate(stmt.body) - out(""" + translate(stmt.body) + out(""" dec $counterLabel beq $endLabel inc $loopLabel+1 jmp $loopLabel $counterLabel .byte 0 $endLabel""") - } - range.step==-1 -> { - // step = -1 - val counterLabel = makeLabel("for_counter") - out(""" + } + range.step==-1 -> { + // step = -1 + val counterLabel = makeLabel("for_counter") + out(""" lda #${range.first} sta $loopLabel+1 lda #${range.first-range.last+1 and 255} sta $counterLabel $loopLabel lda #0 ; modified """) - translate(stmt.body) - out(""" + translate(stmt.body) + out(""" dec $counterLabel beq $endLabel dec $loopLabel+1 jmp $loopLabel $counterLabel .byte 0 $endLabel""") - } - range.step>0 -> { - // step >= 2 - val counterLabel = makeLabel("for_counter") - out(""" + } + range.step >= 2 -> { + // step >= 2 + val counterLabel = makeLabel("for_counter") + out(""" lda #${(range.last-range.first) / range.step + 1} sta $counterLabel lda #${range.first} $loopLabel pha""") - translate(stmt.body) - out(""" + translate(stmt.body) + out(""" pla dec $counterLabel beq $endLabel @@ -1025,17 +1134,17 @@ $loopLabel pha""") jmp $loopLabel $counterLabel .byte 0 $endLabel""") - } - else -> { - // step <= -2 - val counterLabel = makeLabel("for_counter") - out(""" + } + else -> { + // step <= -2 + val counterLabel = makeLabel("for_counter") + out(""" lda #${(range.first-range.last) / range.step.absoluteValue + 1} sta $counterLabel lda #${range.first} $loopLabel pha""") - translate(stmt.body) - out(""" + translate(stmt.body) + out(""" pla dec $counterLabel beq $endLabel @@ -1044,61 +1153,59 @@ $loopLabel pha""") jmp $loopLabel $counterLabel .byte 0 $endLabel""") - } - } + } + } - } else { + } else { - // loop over byte range via loopvar - val varname = asmIdentifierName(stmt.loopVar!!) - when { - range.step==1 -> { - // step = 1 - val counterLabel = makeLabel("for_counter") - out(""" + // loop over byte range via loopvar + val varname = asmIdentifierName(stmt.loopVar!!) + val counterLabel = makeLabel("for_counter") + when { + range.step==1 -> { + // step = 1 + out(""" lda #${range.first} sta $varname lda #${range.last-range.first+1 and 255} sta $counterLabel $loopLabel""") - translate(stmt.body) - out(""" + translate(stmt.body) + out(""" dec $counterLabel beq $endLabel inc $varname jmp $loopLabel $counterLabel .byte 0 $endLabel""") - } - range.step==-1 -> { - // step = -1 - val counterLabel = makeLabel("for_counter") - out(""" + } + range.step==-1 -> { + // step = -1 + out(""" lda #${range.first} sta $varname lda #${range.first-range.last+1 and 255} sta $counterLabel $loopLabel""") - translate(stmt.body) - out(""" + translate(stmt.body) + out(""" dec $counterLabel beq $endLabel dec $varname jmp $loopLabel $counterLabel .byte 0 $endLabel""") - } - range.step>0 -> { - // step >= 2 - val counterLabel = makeLabel("for_counter") - out(""" + } + range.step >= 2 -> { + // step >= 2 + out(""" lda #${(range.last-range.first) / range.step + 1} sta $counterLabel lda #${range.first} sta $varname $loopLabel""") - translate(stmt.body) - out(""" + translate(stmt.body) + out(""" dec $counterLabel beq $endLabel lda $varname @@ -1108,18 +1215,17 @@ $loopLabel""") jmp $loopLabel $counterLabel .byte 0 $endLabel""") - } - else -> { - // step <= -2 - val counterLabel = makeLabel("for_counter") - out(""" + } + else -> { + // step <= -2 + out(""" lda #${(range.first-range.last) / range.step.absoluteValue + 1} sta $counterLabel lda #${range.first} sta $varname $loopLabel""") - translate(stmt.body) - out(""" + translate(stmt.body) + out(""" dec $counterLabel beq $endLabel lda $varname @@ -1129,89 +1235,111 @@ $loopLabel""") jmp $loopLabel $counterLabel .byte 0 $endLabel""") - } - } } } - in WordDatatypes -> { - TODO("forloop over word range $stmt") // TODO - } - else -> throw AssemblyError("range expression can only be byte or word") } } - is IdentifierReference -> { - val ident = (stmt.iterable as IdentifierReference) - val iterableName = asmIdentifierName(ident) - val decl = ident.targetVarDecl(program.namespace)!! - when(iterableDt) { - DataType.STR, DataType.STR_S -> { - if(stmt.loopRegister!=null && stmt.loopRegister!=Register.A) - throw AssemblyError("can only use A") + in WordDatatypes -> { + // loop over word range via loopvar + val counterLabel = makeLabel("for_counter") + val varname = asmIdentifierName(stmt.loopVar!!) + when { + range.step == 1 -> { + // step = 1 out(""" - lda #<$iterableName - ldy #>$iterableName - sta $loopLabel+1 - sty $loopLabel+2 -$loopLabel lda ${65535.toHex()} ; modified - beq $endLabel""") - if(stmt.loopVar!=null) - out(" sta ${asmIdentifierName(stmt.loopVar!!)}") + lda #<${range.first} + ldy #>${range.first} + sta $varname + sty $varname+1 + lda #${range.last - range.first + 1 and 255} + sta $counterLabel +$loopLabel""") translate(stmt.body) out(""" - inc $loopLabel+1 - bne $loopLabel - inc $loopLabel+2 - bne $loopLabel -$endLabel""") - } - DataType.ARRAY_UB, DataType.ARRAY_B -> { - val length = decl.arraysize!!.size() - if(stmt.loopRegister!=null && stmt.loopRegister!=Register.A) - throw AssemblyError("can only use A") - val counterLabel = makeLabel("for_counter") - val modifiedLabel = makeLabel("for_modified") - out(""" - lda #<$iterableName - ldy #>$iterableName - sta $modifiedLabel+1 - sty $modifiedLabel+2 - ldy #0 -$loopLabel sty $counterLabel - cpy #$length + dec $counterLabel beq $endLabel -$modifiedLabel lda ${65535.toHex()},y ; modified""") - if(stmt.loopVar!=null) - out(" sta ${asmIdentifierName(stmt.loopVar!!)}") - translate(stmt.body) - out(""" - ldy $counterLabel - iny + inc $varname + bne $loopLabel + inc $varname+1 jmp $loopLabel -$counterLabel .byte 0 +$counterLabel .byte 0 $endLabel""") } - DataType.ARRAY_W, DataType.ARRAY_UW -> { - val length = decl.arraysize!!.size() - println("forloop over word array len $length $stmt") // TODO - if(stmt.loopRegister!=null) { - TODO("loop register over wordarray of len $length") - } else { - TODO("loop variable over wordarray of len $length") - } + range.step == -1 -> { + // step = 1 + out(""" + lda #<${range.first} + ldy #>${range.first} + sta $varname + sty $varname+1 + lda #${range.first - range.last + 1 and 255} + sta $counterLabel +$loopLabel""") + translate(stmt.body) + out(""" + dec $counterLabel + beq $endLabel + lda $varname + bne + + dec $varname+1 ++ dec $varname + jmp $loopLabel +$counterLabel .byte 0 +$endLabel""") } - DataType.ARRAY_F -> { - val length = decl.arraysize!!.size() - println("forloop over float array len $length $stmt") // TODO - if(stmt.loopRegister!=null) { - throw AssemblyError("can't use register to loop over floats") - } else { - TODO("loop variable over floatarray of len $length") - } + range.step >= 2 -> { + // step >= 2 + out(""" + lda #<${range.first} + ldy #>${range.first} + sta $varname + sty $varname+1 + lda #${(range.last-range.first) / range.step + 1} + sta $counterLabel +$loopLabel""") + translate(stmt.body) + out(""" + dec $counterLabel + beq $endLabel + clc + lda $varname + adc #<${range.step} + sta $varname + lda $varname+1 + adc #>${range.step} + sta $varname+1 + jmp $loopLabel +$counterLabel .byte 0 +$endLabel""") + } + else -> { + // step <= -2 + out(""" + lda #<${range.first} + ldy #>${range.first} + sta $varname + sty $varname+1 + lda #${(range.first-range.last) / range.step.absoluteValue + 1} + sta $counterLabel +$loopLabel""") + translate(stmt.body) + out(""" + dec $counterLabel + beq $endLabel + sec + lda $varname + sbc #<${range.step.absoluteValue} + sta $varname + lda $varname+1 + sbc #>${range.step.absoluteValue} + sta $varname+1 + jmp $loopLabel +$counterLabel .byte 0 +$endLabel""") } - else -> throw AssemblyError("can't iterate over $iterableDt") } } - else -> throw AssemblyError("can't iterate over ${stmt.iterable}") + else -> throw AssemblyError("range expression can only be byte or word") } } diff --git a/examples/test.p8 b/examples/test.p8 index f79fe8b13..97e7ee0f1 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -8,6 +8,9 @@ main { byte bvar ubyte var2 + ubyte[] barr = [22,33,44,55,66] + word[] warr = [-111,222,-333,444] + for A in "hello" { c64scr.print_ub(A) c64.CHROUT(',') @@ -43,6 +46,12 @@ main { c64.CHROUT(',') } c64.CHROUT('\n') + + for A in barr { + c64scr.print_ub(A) + c64.CHROUT(',') + } + c64.CHROUT('\n') c64.CHROUT('\n') for ubyte cc in "hello" { @@ -80,13 +89,49 @@ main { c64.CHROUT(',') } c64.CHROUT('\n') + + for ubyte cc7 in barr { + c64scr.print_ub(cc7) + c64.CHROUT(',') + } + c64.CHROUT('\n') c64.CHROUT('\n') + for uword ww1 in [1111, 2222, 3333] { + c64scr.print_uw(ww1) + c64.CHROUT(',') + } + c64.CHROUT('\n') -; for float fl in [1.1, 2.2, 5.5, 99.99] { -; c64flt.print_f(fl) -; c64.CHROUT(',') -; } -; c64.CHROUT('\n') + for word ww2 in warr { + c64scr.print_w(ww2) + c64.CHROUT(',') + } + c64.CHROUT('\n') + + for uword ww3 in 1111 to 1122 { + c64scr.print_uw(ww3) + c64.CHROUT(',') + } + c64.CHROUT('\n') + + for uword ww3b in 2000 to 1990 step -1 { + c64scr.print_uw(ww3b) + c64.CHROUT(',') + } + c64.CHROUT('\n') + + for uword ww3c in 1111 to 50000 step 3333 { + c64scr.print_uw(ww3c) + c64.CHROUT(',') + } + c64.CHROUT('\n') + + for word ww4 in 999 to -999 step -500 { + c64scr.print_w(ww4) + c64.CHROUT(',') + } + c64.CHROUT('\n') + c64.CHROUT('\n') } }