mirror of
https://github.com/irmen/prog8.git
synced 2025-01-26 19:30:59 +00:00
optimize and fix for loops
This commit is contained in:
parent
aa00db4d80
commit
487faf3a08
@ -199,6 +199,7 @@ interface IAstProcessor {
|
||||
fun process(range: RangeExpr): IExpression {
|
||||
range.from = range.from.process(this)
|
||||
range.to = range.to.process(this)
|
||||
range.step = range.step.process(this)
|
||||
return range
|
||||
}
|
||||
|
||||
@ -1316,7 +1317,7 @@ class RangeExpr(var from: IExpression,
|
||||
fromDt==DataType.STR_S && toDt==DataType.STR_S -> DataType.STR_S
|
||||
fromDt==DataType.STR_PS && toDt==DataType.STR_PS -> DataType.STR_PS
|
||||
fromDt==DataType.WORD || toDt==DataType.WORD -> DataType.WORD
|
||||
fromDt==DataType.BYTE || toDt==DataType.BYTE -> DataType.UBYTE
|
||||
fromDt==DataType.BYTE || toDt==DataType.BYTE -> DataType.BYTE
|
||||
else -> DataType.UBYTE
|
||||
}
|
||||
}
|
||||
|
@ -257,6 +257,16 @@ private class StatementTranslator(private val prog: IntermediateProgram,
|
||||
}
|
||||
}
|
||||
|
||||
private fun opcodeCompare(dt: DataType): Opcode {
|
||||
return when (dt) {
|
||||
DataType.UBYTE -> Opcode.CMP_UB
|
||||
DataType.BYTE -> Opcode.CMP_B
|
||||
DataType.UWORD -> Opcode.CMP_UW
|
||||
DataType.WORD -> Opcode.CMP_W
|
||||
else -> throw CompilerException("invalid dt $dt")
|
||||
}
|
||||
}
|
||||
|
||||
private fun opcodePushvar(dt: DataType): Opcode {
|
||||
return when (dt) {
|
||||
DataType.UBYTE, DataType.BYTE -> Opcode.PUSH_VAR_BYTE
|
||||
@ -1729,10 +1739,18 @@ private class StatementTranslator(private val prog: IntermediateProgram,
|
||||
when (loopVarDt) {
|
||||
DataType.UBYTE -> {
|
||||
if (range.first < 0 || range.first > 255 || range.last < 0 || range.last > 255)
|
||||
throw CompilerException("range out of bounds for byte")
|
||||
throw CompilerException("range out of bounds for ubyte")
|
||||
}
|
||||
DataType.UWORD -> {
|
||||
if (range.first < 0 || range.first > 65535 || range.last < 0 || range.last > 65535)
|
||||
throw CompilerException("range out of bounds for uword")
|
||||
}
|
||||
DataType.BYTE -> {
|
||||
if (range.first < -128 || range.first > 127 || range.last < -128 || range.last > 127)
|
||||
throw CompilerException("range out of bounds for byte")
|
||||
}
|
||||
DataType.WORD -> {
|
||||
if (range.first < -32768 || range.first > 32767 || range.last < -32768 || range.last > 32767)
|
||||
throw CompilerException("range out of bounds for word")
|
||||
}
|
||||
else -> throw CompilerException("range must be byte or word")
|
||||
@ -1840,11 +1858,8 @@ private class StatementTranslator(private val prog: IntermediateProgram,
|
||||
prog.label(continueLabel)
|
||||
|
||||
prog.instr(opcodeIncvar(indexVarType), callLabel = indexVar.scopedname)
|
||||
|
||||
// TODO: optimize edge cases if last value = 255 or 0 (for bytes) etc. to avoid PUSH_BYTE / SUB opcodes and make use of the wrapping around of the value.
|
||||
prog.instr(opcodePush(indexVarType), Value(indexVarType, numElements))
|
||||
prog.instr(opcodePushvar(indexVarType), callLabel = indexVar.scopedname)
|
||||
prog.instr(opcodeSub(indexVarType))
|
||||
prog.instr(opcodeCompare(indexVarType), Value(indexVarType, numElements))
|
||||
if(indexVarType==DataType.UWORD)
|
||||
prog.instr(Opcode.JNZW, callLabel = loopLabel)
|
||||
else
|
||||
@ -1889,16 +1904,25 @@ private class StatementTranslator(private val prog: IntermediateProgram,
|
||||
prog.label(loopLabel)
|
||||
translate(body)
|
||||
prog.label(continueLabel)
|
||||
val numberOfIncDecsForOptimize = 8
|
||||
when {
|
||||
range.step==1 -> prog.instr(opcodeIncvar(varDt), callLabel = varname)
|
||||
range.step==-1 -> prog.instr(opcodeDecvar(varDt), callLabel = varname)
|
||||
range.step>1 -> {
|
||||
range.step in 1..numberOfIncDecsForOptimize -> {
|
||||
repeat(range.step) {
|
||||
prog.instr(opcodeIncvar(varDt), callLabel = varname)
|
||||
}
|
||||
}
|
||||
range.step in -1 downTo -numberOfIncDecsForOptimize -> {
|
||||
repeat(abs(range.step)) {
|
||||
prog.instr(opcodeDecvar(varDt), callLabel = varname)
|
||||
}
|
||||
}
|
||||
range.step>numberOfIncDecsForOptimize -> {
|
||||
prog.instr(opcodePushvar(varDt), callLabel = varname)
|
||||
prog.instr(opcodePush(varDt), Value(varDt, range.step))
|
||||
prog.instr(opcodeAdd(varDt))
|
||||
prog.instr(opcodePopvar(varDt), callLabel = varname)
|
||||
}
|
||||
range.step<1 -> {
|
||||
range.step<numberOfIncDecsForOptimize -> {
|
||||
prog.instr(opcodePushvar(varDt), callLabel = varname)
|
||||
prog.instr(opcodePush(varDt), Value(varDt, abs(range.step)))
|
||||
prog.instr(opcodeSub(varDt))
|
||||
@ -1906,17 +1930,21 @@ private class StatementTranslator(private val prog: IntermediateProgram,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: optimize edge cases if last value = 255 or 0 (for bytes) etc. to avoid PUSH_BYTE / SUB opcodes and make use of the wrapping around of the value.
|
||||
// TODO: ubyte/uword can't count down to 0 with negative step because test value will be <0 which causes "value out of range" crash
|
||||
prog.instr(opcodePush(varDt), Value(varDt, range.last + range.step))
|
||||
prog.instr(opcodePushvar(varDt), callLabel = varname)
|
||||
prog.instr(opcodeSub(varDt))
|
||||
val loopvarJumpOpcode = when(varDt) {
|
||||
DataType.UBYTE, DataType.BYTE -> Opcode.JNZ
|
||||
DataType.UWORD, DataType.WORD -> Opcode.JNZW
|
||||
else -> throw CompilerException("invalid loop var datatype (expected byte or word) $varDt of var $varname")
|
||||
if(range.last==0) {
|
||||
// optimize for the for loop that counts to 0
|
||||
prog.instr(if(range.first>0) Opcode.BPOS else Opcode.BNEG, callLabel = loopLabel)
|
||||
} else {
|
||||
prog.instr(opcodePushvar(varDt), callLabel = varname)
|
||||
val checkValue =
|
||||
when (varDt) {
|
||||
DataType.UBYTE -> (range.last + range.step) and 255
|
||||
DataType.UWORD -> (range.last + range.step) and 65535
|
||||
DataType.BYTE, DataType.WORD -> range.last + range.step
|
||||
else -> throw CompilerException("invalid loop var dt $varDt")
|
||||
}
|
||||
prog.instr(opcodeCompare(varDt), Value(varDt, checkValue))
|
||||
prog.instr(Opcode.BNZ, callLabel = loopLabel)
|
||||
}
|
||||
prog.instr(loopvarJumpOpcode, callLabel = loopLabel)
|
||||
prog.label(breakLabel)
|
||||
prog.instr(Opcode.NOP)
|
||||
// note: ending value of loop register / variable is *undefined* after this point!
|
||||
|
@ -208,6 +208,10 @@ enum class Opcode {
|
||||
NOTEQUAL_BYTE,
|
||||
NOTEQUAL_WORD,
|
||||
NOTEQUAL_F,
|
||||
CMP_B, // sets processor status flags based on comparison, instead of actually storing a result value
|
||||
CMP_UB, // sets processor status flags based on comparison, instead of actually storing a result value
|
||||
CMP_W, // sets processor status flags based on comparison, instead of actually storing a result value
|
||||
CMP_UW, // sets processor status flags based on comparison, instead of actually storing a result value
|
||||
|
||||
// array access and simple manipulations
|
||||
READ_INDEXED_VAR_BYTE,
|
||||
|
@ -3105,8 +3105,26 @@ class AsmGen(val options: CompilationOptions, val program: IntermediateProgram,
|
||||
adc ${(ESTACK_HI+1).toHex()},x
|
||||
sta ${(ESTACK_HI+1).toHex()},x
|
||||
"""
|
||||
}
|
||||
},
|
||||
|
||||
AsmPattern(listOf(Opcode.PUSH_VAR_BYTE, Opcode.CMP_B), listOf(Opcode.PUSH_VAR_BYTE, Opcode.CMP_UB)) { segment ->
|
||||
// this pattern is encountered as part of the loop bound condition in for loops (var + cmp + jz/jnz)
|
||||
val cmpval = segment[1].arg!!.integerValue()
|
||||
" lda ${segment[0].callLabel} | cmp #$cmpval "
|
||||
},
|
||||
AsmPattern(listOf(Opcode.PUSH_VAR_WORD, Opcode.CMP_W), listOf(Opcode.PUSH_VAR_WORD, Opcode.CMP_UW)) { segment ->
|
||||
// this pattern is encountered as part of the loop bound condition in for loops (var + cmp + jz/jnz)
|
||||
"""
|
||||
lda ${segment[0].callLabel}
|
||||
cmp #<${hexVal(segment[1])}
|
||||
bne +
|
||||
lda ${segment[0].callLabel}+1
|
||||
cmp #>${hexVal(segment[1])}
|
||||
bne +
|
||||
lda #0
|
||||
+
|
||||
"""
|
||||
}
|
||||
|
||||
)
|
||||
|
||||
|
@ -447,11 +447,58 @@ class ConstantFolding(private val namespace: INameScope, private val heap: HeapV
|
||||
}
|
||||
}
|
||||
|
||||
override fun process(range: RangeExpr): IExpression {
|
||||
range.from = range.from.process(this)
|
||||
range.to = range.to.process(this)
|
||||
range.step = range.step.process(this)
|
||||
return super.process(range)
|
||||
override fun process(forLoop: ForLoop): IStatement {
|
||||
|
||||
fun adjustRangeDt(rangeFrom: LiteralValue, targetDt: DataType, rangeTo: LiteralValue, stepLiteral: LiteralValue?, range: RangeExpr): RangeExpr {
|
||||
val newFrom = rangeFrom.intoDatatype(targetDt)
|
||||
val newTo = rangeTo.intoDatatype(targetDt)
|
||||
if (newFrom != null && newTo != null) {
|
||||
val newStep: IExpression =
|
||||
if (stepLiteral != null) (stepLiteral.intoDatatype(targetDt) ?: stepLiteral) else range.step
|
||||
return RangeExpr(newFrom, newTo, newStep, range.position)
|
||||
}
|
||||
return range
|
||||
}
|
||||
|
||||
// adjust the datatype of a range expression in for loops to the loop variable.
|
||||
val resultStmt = super.process(forLoop) as ForLoop
|
||||
val iterableRange = resultStmt.iterable as? RangeExpr ?: return resultStmt
|
||||
val rangeFrom = iterableRange.from as? LiteralValue
|
||||
val rangeTo = iterableRange.to as? LiteralValue
|
||||
if(rangeFrom==null || rangeTo==null) return resultStmt
|
||||
|
||||
val loopvar = resultStmt.loopVar!!.targetStatement(namespace) as? VarDecl
|
||||
if(loopvar!=null) {
|
||||
val stepLiteral = iterableRange.step as? LiteralValue
|
||||
when(loopvar.datatype) {
|
||||
DataType.UBYTE -> {
|
||||
if(rangeFrom.type!=DataType.UBYTE) {
|
||||
// attempt to translate the iterable into ubyte values
|
||||
resultStmt.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange)
|
||||
}
|
||||
}
|
||||
DataType.BYTE -> {
|
||||
if(rangeFrom.type!=DataType.BYTE) {
|
||||
// attempt to translate the iterable into byte values
|
||||
resultStmt.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange)
|
||||
}
|
||||
}
|
||||
DataType.UWORD -> {
|
||||
if(rangeFrom.type!=DataType.UWORD) {
|
||||
// attempt to translate the iterable into uword values
|
||||
resultStmt.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange)
|
||||
}
|
||||
}
|
||||
DataType.WORD -> {
|
||||
if(rangeFrom.type!=DataType.WORD) {
|
||||
// attempt to translate the iterable into word values
|
||||
resultStmt.iterable = adjustRangeDt(rangeFrom, loopvar.datatype, rangeTo, stepLiteral, iterableRange)
|
||||
}
|
||||
}
|
||||
else -> throw FatalAstException("invalid loopvar datatype $loopvar")
|
||||
}
|
||||
}
|
||||
return resultStmt
|
||||
}
|
||||
|
||||
override fun process(literalValue: LiteralValue): LiteralValue {
|
||||
|
@ -379,25 +379,25 @@ class StackVm(private var traceOutputFile: String?) {
|
||||
checkDt(second, DataType.FLOAT)
|
||||
evalstack.push(second.add(top))
|
||||
}
|
||||
Opcode.SUB_UB -> {
|
||||
Opcode.SUB_UB, Opcode.CMP_UB -> {
|
||||
val (top, second) = evalstack.pop2()
|
||||
checkDt(top, DataType.UBYTE)
|
||||
checkDt(second, DataType.UBYTE)
|
||||
evalstack.push(second.sub(top))
|
||||
}
|
||||
Opcode.SUB_UW -> {
|
||||
Opcode.SUB_UW, Opcode.CMP_UW -> {
|
||||
val (top, second) = evalstack.pop2()
|
||||
checkDt(top, DataType.UWORD)
|
||||
checkDt(second, DataType.UWORD)
|
||||
evalstack.push(second.sub(top))
|
||||
}
|
||||
Opcode.SUB_B -> {
|
||||
Opcode.SUB_B, Opcode.CMP_B -> {
|
||||
val (top, second) = evalstack.pop2()
|
||||
checkDt(top, DataType.BYTE)
|
||||
checkDt(second, DataType.BYTE)
|
||||
evalstack.push(second.sub(top))
|
||||
}
|
||||
Opcode.SUB_W -> {
|
||||
Opcode.SUB_W, Opcode.CMP_W -> {
|
||||
val (top, second) = evalstack.pop2()
|
||||
checkDt(top, DataType.WORD)
|
||||
checkDt(second, DataType.WORD)
|
||||
|
@ -4,13 +4,27 @@
|
||||
|
||||
sub start() {
|
||||
|
||||
const word height=25
|
||||
ubyte i
|
||||
byte j
|
||||
uword uw
|
||||
word w
|
||||
|
||||
word rz=33
|
||||
word persp = (rz+200)
|
||||
persp = rz / 25
|
||||
persp = rz / height
|
||||
persp = (rz+200) / height
|
||||
for i in 5 to 0 step -1 {
|
||||
c64scr.print_ub(i)
|
||||
c64.CHROUT('\n')
|
||||
}
|
||||
c64.CHROUT('\n')
|
||||
|
||||
for j in 5 to 0 step -1 {
|
||||
c64scr.print_b(j)
|
||||
c64.CHROUT('\n')
|
||||
}
|
||||
c64.CHROUT('\n')
|
||||
|
||||
for j in -5 to 0 {
|
||||
c64scr.print_b(j)
|
||||
c64.CHROUT('\n')
|
||||
}
|
||||
c64.CHROUT('\n')
|
||||
}
|
||||
}
|
||||
|
@ -52,16 +52,14 @@ sub irq() {
|
||||
angle++
|
||||
c64.MSIGX=0
|
||||
|
||||
ubyte i=14
|
||||
nextsprite: ; @todo should be a for loop from 14 to 0 step -2 but this causes a value out of range error at the moment
|
||||
uword x = sin8u(angle*2-i*8) as uword + 50
|
||||
ubyte y = cos8u(angle*3-i*8) / 2 + 70
|
||||
c64.SPXY[i] = lsb(x)
|
||||
c64.SPXY[i+1] = y
|
||||
lsl(c64.MSIGX)
|
||||
if msb(x) c64.MSIGX++
|
||||
i-=2
|
||||
if_pl goto nextsprite
|
||||
for ubyte i in 14 to 0 step -2 {
|
||||
uword x = sin8u(angle*2-i*8) as uword + 50
|
||||
ubyte y = cos8u(angle*3-i*8) / 2 + 70
|
||||
c64.SPXY[i] = lsb(x)
|
||||
c64.SPXY[i+1] = y
|
||||
lsl(c64.MSIGX)
|
||||
if msb(x) c64.MSIGX++
|
||||
}
|
||||
|
||||
c64.EXTCOL++
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user