diff --git a/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt b/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt index c32cc0d6e..ea29a626d 100644 --- a/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt +++ b/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt @@ -300,9 +300,7 @@ internal class AsmGen2(val program: Program, } } DataType.ARRAY_F -> { - val array = (decl.value as ReferenceLiteralValue).array - if(array==null) - TODO("fix this") + val array = (decl.value as ReferenceLiteralValue).array ?: throw AssemblyError("array should not be null?") val floatFills = array.map { val number = (it as NumericLiteralValue).number makeFloatFill(MachineDefinition.Mflpt5.fromNumber(number)) @@ -367,9 +365,7 @@ internal class AsmGen2(val program: Program, } private fun makeArrayFillDataUnsigned(decl: VarDecl): List { - val array = (decl.value as ReferenceLiteralValue).array - if(array==null) - TODO("fix this") + val array = (decl.value as ReferenceLiteralValue).array ?: throw AssemblyError("array should not be null?") return when { decl.datatype == DataType.ARRAY_UB -> // byte array can never contain pointer-to types, so treat values as all integers @@ -387,9 +383,7 @@ internal class AsmGen2(val program: Program, } private fun makeArrayFillDataSigned(decl: VarDecl): List { - val array = (decl.value as ReferenceLiteralValue).array - if(array==null) - TODO("fix this ${decl.value}") + val array = (decl.value as ReferenceLiteralValue).array ?: throw AssemblyError("array should not be null?") return when { decl.datatype == DataType.ARRAY_UB -> @@ -1865,44 +1859,103 @@ $endLabel""") } } + private val optimizedByteMultiplications = setOf(3,5,6,7,9,10,11,12,13,14,15,20,25,40) + private val optimizedWordMultiplications = setOf(3,5,6,7,9,10,12,15,20,25,40) + private val powerOfTwos = setOf(0,1,2,4,8,16,32,64,128,256) + private fun translateExpression(expr: BinaryExpression) { val leftDt = expr.left.inferType(program)!! val rightDt = expr.right.inferType(program)!! + + // see if we can apply some optimized routines when(expr.operator) { ">>" -> { // bit-shifts are always by a constant number (for now) translateExpression(expr.left) val amount = expr.right.constValue(program)!!.number.toInt() - when(leftDt) { + when (leftDt) { DataType.UBYTE -> repeat(amount) { out(" lsr $ESTACK_LO_PLUS1_HEX,x") } DataType.BYTE -> repeat(amount) { out(" lda $ESTACK_LO_PLUS1_HEX,x | asl a | ror $ESTACK_LO_PLUS1_HEX,x") } DataType.UWORD -> repeat(amount) { out(" lsr $ESTACK_HI_PLUS1_HEX,x | ror $ESTACK_LO_PLUS1_HEX,x") } - DataType.WORD -> repeat(amount) { out( " lda $ESTACK_HI_PLUS1_HEX,x | asl a | ror $ESTACK_HI_PLUS1_HEX,x | ror $ESTACK_LO_PLUS1_HEX,x") } + DataType.WORD -> repeat(amount) { out(" lda $ESTACK_HI_PLUS1_HEX,x | asl a | ror $ESTACK_HI_PLUS1_HEX,x | ror $ESTACK_LO_PLUS1_HEX,x") } else -> throw AssemblyError("weird type") } + return } "<<" -> { // bit-shifts are always by a constant number (for now) translateExpression(expr.left) val amount = expr.right.constValue(program)!!.number.toInt() - if(leftDt in ByteDatatypes) + if (leftDt in ByteDatatypes) repeat(amount) { out(" asl $ESTACK_LO_PLUS1_HEX,x") } else repeat(amount) { out(" asl $ESTACK_LO_PLUS1_HEX,x | rol $ESTACK_HI_PLUS1_HEX,x") } + return } - else -> { - translateExpression(expr.left) - translateExpression(expr.right) - if(leftDt!=rightDt) - throw AssemblyError("binary operator ${expr.operator} left/right dt not identical") // is this strictly required always? - when (leftDt) { - in ByteDatatypes -> translateBinaryOperatorBytes(expr.operator, leftDt) - in WordDatatypes -> translateBinaryOperatorWords(expr.operator, leftDt) - DataType.FLOAT -> translateBinaryOperatorFloats(expr.operator) - else -> throw AssemblyError("non-numerical datatype") + "*" -> { + val value = expr.right.constValue(program) + if(value!=null) { + if(rightDt in IntegerDatatypes) { + val amount = value.number.toInt() + if(amount in powerOfTwos) + printWarning("${expr.right.position} multiplication by power of 2 should have been optimized into a left shift instruction: $amount") + when(rightDt) { + DataType.UBYTE -> { + if(amount in optimizedByteMultiplications) { + translateExpression(expr.left) + out(" jsr math.mul_byte_$amount") + return + } + } + DataType.BYTE -> { + if(amount in optimizedByteMultiplications) { + translateExpression(expr.left) + out(" jsr math.mul_byte_$amount") + return + } + if(amount.absoluteValue in optimizedByteMultiplications) { + translateExpression(expr.left) + out(" jsr prog8_lib.neg_b | jsr math.mul_byte_${amount.absoluteValue}") + return + } + } + DataType.UWORD -> { + if(amount in optimizedWordMultiplications) { + translateExpression(expr.left) + out(" jsr math.mul_word_$amount") + return + } + } + DataType.WORD -> { + if(amount in optimizedWordMultiplications) { + translateExpression(expr.left) + out(" jsr math.mul_word_$amount") + return + } + if(amount.absoluteValue in optimizedWordMultiplications) { + translateExpression(expr.left) + out(" jsr prog8_lib.neg_w | jsr math.mul_word_${amount.absoluteValue}") + return + } + } + else -> {} + } + } } } } + + // the general, non-optimized cases + translateExpression(expr.left) + translateExpression(expr.right) + if(leftDt!=rightDt) + throw AssemblyError("binary operator ${expr.operator} left/right dt not identical") // is this strictly required always? + when (leftDt) { + in ByteDatatypes -> translateBinaryOperatorBytes(expr.operator, leftDt) + in WordDatatypes -> translateBinaryOperatorWords(expr.operator, leftDt) + DataType.FLOAT -> translateBinaryOperatorFloats(expr.operator) + else -> throw AssemblyError("non-numerical datatype") + } } private fun translateExpression(expr: PrefixExpression) { diff --git a/examples/cube3d.p8 b/examples/cube3d.p8 index 4c4cdd1e7..56263a290 100644 --- a/examples/cube3d.p8 +++ b/examples/cube3d.p8 @@ -75,22 +75,28 @@ main { ; plot the points of the 3d cube ; first the points on the back, then the points on the front (painter algorithm) - for ubyte i in 0 to len(xcoor)-1 { - word rz = rotatedz[i] + ubyte i + word rz + word persp + byte sx + byte sy + + for i in 0 to len(xcoor)-1 { + rz = rotatedz[i] if rz >= 10 { - word persp = (rz+200) / height - byte sx = rotatedx[i] / persp as byte + width/2 - byte sy = rotatedy[i] / persp as byte + height/2 + persp = (rz+200) / height + sx = rotatedx[i] / persp as byte + width/2 + sy = rotatedy[i] / persp as byte + height/2 c64scr.setcc(sx as ubyte, sy as ubyte, 46, vertexcolors[(rz as byte >>5) + 3]) } } - for ubyte i in 0 to len(xcoor)-1 { - word rz = rotatedz[i] + for i in 0 to len(xcoor)-1 { + rz = rotatedz[i] if rz < 10 { - word persp = (rz+200) / height - byte sx = rotatedx[i] / persp as byte + width/2 - byte sy = rotatedy[i] / persp as byte + height/2 + persp = (rz+200) / height + sx = rotatedx[i] / persp as byte + width/2 + sy = rotatedy[i] / persp as byte + height/2 c64scr.setcc(sx as ubyte, sy as ubyte, 81, vertexcolors[(rz as byte >>5) + 3]) } } diff --git a/examples/test.p8 b/examples/test.p8 index 01191bbb5..71cfec818 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -8,135 +8,20 @@ main { sub start() { - c64scr.plot(0,24) + byte bb + ubyte ub + word ww + uword uw + float fl - ubyte ub=200 - byte bb=-100 - uword uw = 2000 - word ww = -1000 - float fl = 999.99 - ubyte[3] ubarr = 200 - byte[3] barr = -100 - uword[3] uwarr = 2000 - word[3] warr = -1000 - float[3] flarr = 999.99 + bb = 10*bb + ub = 12*ub + ww = 15*ww + uw = 20*uw + fl = 20*fl - c64scr.print("++\n") - ub++ - bb++ - uw++ - ww++ - fl++ - ubarr[1]++ - barr[1]++ - uwarr[1]++ - warr[1]++ - flarr[1] ++ - check_ub(ub, 201) - Y=100 - Y++ - check_ub(Y, 101) - check_fl(fl, 1000.99) - check_b(bb, -99) - check_uw(uw, 2001) - check_w(ww, -999) - check_ub(ubarr[0], 200) - check_fl(flarr[0], 999.99) - check_b(barr[0], -100) - check_uw(uwarr[0], 2000) - check_w(warr[0], -1000) - check_ub(ubarr[1], 201) - check_fl(flarr[1], 1000.99) - check_b(barr[1], -99) - check_uw(uwarr[1], 2001) - check_w(warr[1], -999) - c64scr.print("--\n") - ub-- - bb-- - uw-- - ww-- - fl-- - ubarr[1]-- - barr[1]-- - uwarr[1]-- - warr[1]-- - flarr[1] -- - check_ub(ub, 200) - Y=100 - Y-- - check_ub(Y, 99) - check_fl(fl, 999.99) - check_b(bb, -100) - check_uw(uw, 2000) - check_w(ww, -1000) - check_ub(ubarr[1], 200) - check_fl(flarr[1], 999.99) - check_b(barr[1], -100) - check_uw(uwarr[1], 2000) - check_w(warr[1], -1000) - - @($0400+400-1) = X } - sub check_ub(ubyte value, ubyte expected) { - if value==expected - c64scr.print(" ok ") - else - c64scr.print("err! ") - c64scr.print(" ubyte ") - c64scr.print_ub(value) - c64.CHROUT(',') - c64scr.print_ub(expected) - c64.CHROUT('\n') - } - - sub check_b(byte value, byte expected) { - if value==expected - c64scr.print(" ok ") - else - c64scr.print("err! ") - c64scr.print(" byte ") - c64scr.print_b(value) - c64.CHROUT(',') - c64scr.print_b(expected) - c64.CHROUT('\n') - } - - sub check_uw(uword value, uword expected) { - if value==expected - c64scr.print(" ok ") - else - c64scr.print("err! ") - c64scr.print(" uword ") - c64scr.print_uw(value) - c64.CHROUT(',') - c64scr.print_uw(expected) - c64.CHROUT('\n') - } - - sub check_w(word value, word expected) { - if value==expected - c64scr.print(" ok ") - else - c64scr.print("err! ") - c64scr.print(" word ") - c64scr.print_w(value) - c64.CHROUT(',') - c64scr.print_w(expected) - c64.CHROUT('\n') - } - - sub check_fl(float value, float expected) { - if value==expected - c64scr.print(" ok ") - else - c64scr.print("err! ") - c64scr.print(" float ") - c64flt.print_f(value) - c64.CHROUT(',') - c64flt.print_f(expected) - c64.CHROUT('\n') - } }