1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-26 20:33:02 +00:00

6502: Faster summing for-loops

This commit is contained in:
Karol Stasiak 2019-06-25 18:23:31 +02:00
parent 260cfd50c4
commit 2282e56845
4 changed files with 135 additions and 3 deletions

View File

@ -38,6 +38,8 @@ where `11111` is a sequential number and `xx` is the type:
* `fe` body of a `for` statement over a list
* `fo` certain optimized `for` loops
* `he` beginning of the body of a `while` statement
* `in` increment for larger types

View File

@ -1,9 +1,9 @@
package millfork.compiler.mos
import millfork.CompilationFlag
import millfork.assembly.mos.{AssemblyLine, AssemblyLine0}
import millfork.compiler.{BranchSpec, CompilationContext}
import millfork.env.{NumericConstant, Type, VariableInMemory}
import millfork.assembly.mos.{AddrMode, AssemblyLine, AssemblyLine0, Opcode}
import millfork.compiler.{AbstractExpressionCompiler, BranchSpec, CompilationContext}
import millfork.env.{Label, MemoryAddressConstant, NumericConstant, Type, Variable, VariableInMemory}
import millfork.node._
import millfork.assembly.mos.Opcode._
@ -176,4 +176,101 @@ object MosBulkMemoryOperations {
}
loadAll ++ setWholePages ++ setRest
}
/**
 * Tries to compile a `for` loop whose whole body is one compound assignment
 * (`target <op>= source`) into a tight loop that keeps the running value in A
 * (or in A/X for a two-byte target) and uses Y as the loop counter.
 *
 * The emitted code has the shape
 * `load target; set up Y; loop: <carry> <op> source; step Y; branch; store target`,
 * so the target is read once before the loop and written once after it.
 *
 * @param ctx              compilation context used for lookups, label generation and expression compilation
 * @param targetExpression the variable being accumulated into
 * @param operator         one of `+=`, `-=`, `+'=`, `-'=`, `|=`, `&=`, `^=`; anything else throws `IllegalArgumentException`
 * @param source           the right-hand side of the compound assignment
 * @param f                the enclosing `for` statement (supplies loop variable, bounds and direction)
 * @return `Some(code)` when the loop fits one of the supported shapes, `None` when the caller
 *         should fall back to the generic `for` compilation
 */
def compileFold(ctx: CompilationContext, targetExpression: VariableExpression, operator: String, source: Expression, f: ForStatement): Option[List[AssemblyLine]] = {
  import AddrMode._
  import ForDirection._
  val target = ctx.env.maybeGet[Variable](targetExpression.name).getOrElse(return None)
  if (target.isVolatile) return None
  // The running value is written back only after the loop, so a source that reads
  // the target from memory would keep seeing its pre-loop value.
  if (source.containsVariable(targetExpression.name)) return None
  val sourceType = AbstractExpressionCompiler.getExpressionType(ctx, source)
  if (sourceType.size != 1) return None
  if (!sourceType.isArithmetic) return None
  if (!target.typ.isArithmetic) return None
  if (operator == "+=") {
    // Only plain addition may widen into a two-byte target, and only for unsigned sources
    // (the high byte is maintained by a simple carry-increment of X).
    if (target.typ.size > 2) return None
    if (target.typ.size == 2 && sourceType.isSigned) return None
  } else {
    if (target.typ.size != 1) return None
  }
  val indexVariable = ctx.env.get[Variable](f.variable)
  val loadSource = MosExpressionCompiler.compileToA(ctx, source) match {
    // array[i] / pointer[i]: drop the LDY, because Y already is the loop counter
    case List(AssemblyLine0(LDY, Absolute | ZeroPage, MemoryAddressConstant(index)), l@AssemblyLine0(LDA, AbsoluteY | IndexedY, _)) if index.name == indexVariable.name => l
    // Loop-invariant operand. The loop variable itself is never stored to memory by the
    // folded loop, so a source that reads it must not take this path.
    case List(l@AssemblyLine0(LDA, Absolute | ZeroPage | Immediate, _)) if !source.containsVariable(f.variable) => l
    case _ => return None
  }

  // True if the value can be loaded straight into Y (or compared with CPY) without touching A.
  def isSimple(l: List[AssemblyLine]): Boolean = l match {
    case List(AssemblyLine0(LDA, Absolute | ZeroPage | Immediate, _)) => true
    case _ => false
  }

  // True if the code is a single immediate load of a numeric constant satisfying `c`.
  def isConstant(l: List[AssemblyLine], c: Long => Boolean): Boolean = l match {
    case List(AssemblyLine0(LDA, Immediate, NumericConstant(n, _))) => c(n)
    case _ => false
  }

  // Only called with simple operands below; the TAY fallback would clobber the running value held in A.
  def targetifyY(l: List[AssemblyLine]): List[AssemblyLine] = if (isSimple(l)) l.map(_.copy(opcode = LDY)) else l :+ AssemblyLine.implied(TAY)

  def toCpy(l: List[AssemblyLine], branch: Opcode.Value, label: String): List[AssemblyLine] = if (isSimple(l)) l.map(_.copy(opcode = CPY)) :+ AssemblyLine.relative(branch, label) else throw new IllegalArgumentException

  def branch(branch: Opcode.Value, label: String): List[AssemblyLine] = if (branch == JMP) List(AssemblyLine.absolute(JMP, Label(label))) else List(AssemblyLine.relative(branch, label))

  val loadStart = MosExpressionCompiler.compileToA(ctx, f.start)
  val loadEnd = MosExpressionCompiler.compileToA(ctx, f.end)
  val dey = AssemblyLine.implied(DEY)
  val iny = AssemblyLine.implied(INY)
  lazy val skipLabel = ctx.nextLabel("fo") // TODO
  val loopLabel = ctx.nextLabel("fo") // TODO
  // Decimal operators (those containing ') run the whole loop with the decimal flag set.
  val loopStart = if (operator.contains("'")) List(AssemblyLine.implied(SED), AssemblyLine.label(loopLabel)) else List(AssemblyLine.label(loopLabel))
  val cld = if (operator.contains("'")) List(AssemblyLine.implied(CLD)) else Nil
  lazy val loopSkip = List(AssemblyLine.label(skipLabel))
  // frame._1 is emitted before the loop body, frame._2 after it.
  val frame: (List[AssemblyLine], List[AssemblyLine]) = f.direction match {
    // 0..end in any order: count Y down from end to 0, leaving the loop when Y goes negative
    case ParallelTo if isConstant(loadStart, _ == 0) && isConstant(loadEnd, i => i > 0 && i <= 0x7f) =>
      (targetifyY(loadEnd) ++ loopStart) -> (dey :: branch(BPL, loopLabel) ++ cld)
    // start downto 0: for a descending loop the larger bound is `start`, so the roles are swapped
    case DownTo if isConstant(loadEnd, _ == 0) && isConstant(loadStart, i => i > 0 && i <= 0x7f) =>
      (targetifyY(loadStart) ++ loopStart) -> (dey :: branch(BPL, loopLabel) ++ cld)
    // start..255: count Y up until it wraps around to zero
    case ParallelTo | To if isConstant(loadStart, i => i > 0 && i < 255) && isConstant(loadEnd, _ == 255) =>
      (targetifyY(loadStart) ++ loopStart) -> (iny :: branch(BNE, loopLabel) ++ cld)
    // 0 until end in any order, small constant end: count Y down from end-1 to 0
    case ParallelUntil if isConstant(loadStart, _ == 0) && isConstant(loadEnd, i => i > 0 && i <= 0x7f) =>
      (targetifyY(loadEnd).map(l => l.copy(parameter = l.parameter - 1)) ++ loopStart) -> (dey :: branch(BPL, loopLabel) ++ cld)
    // 0 until end in any order: pre-decrement Y, stop once it reaches 0
    // NOTE(review): an `end` that is 0 at run time makes Y wrap and the body run 256 times - confirm this is acceptable.
    case ParallelUntil if isConstant(loadStart, _ == 0) && isSimple(loadEnd) =>
      (targetifyY(loadEnd) ++ loopStart :+ dey) -> (toCpy(loadStart, BNE, loopLabel) ++ cld)
    case Until | ParallelUntil =>
      // Both bounds have to go straight into Y / CPY: computing either of them through A
      // would destroy the running value that was loaded before the loop.
      if (!isSimple(loadStart) || !isSimple(loadEnd)) return None
      if (cld.isEmpty) {
        // Enter the loop at the comparison so that an empty range executes no iterations.
        (targetifyY(loadStart) ++ branch(JMP, skipLabel) ++ loopStart) -> (iny :: loopSkip ++ toCpy(loadEnd, BNE, loopLabel))
      } else {
        // Decimal mode: test for an empty range up front so SED/CLD always stay paired.
        (targetifyY(loadStart) ++ toCpy(loadEnd, BEQ, skipLabel) ++ loopStart) -> (iny :: toCpy(loadEnd, BNE, loopLabel) ++ cld ++ loopSkip)
      }
    case To | ParallelTo | DownTo =>
      return None
  }
  val opcode = operator match {
    case "+=" | "+'=" => ADC
    case "-=" | "-'=" => SBC
    case "|=" => ORA
    case "&=" => AND
    case "^=" => EOR
    case _ => throw new IllegalArgumentException(operator)
  }
  // Addition and subtraction need a known carry before every step.
  val carry = operator match {
    case "+=" | "+'=" => List(AssemblyLine.implied(CLC))
    case "-=" | "-'=" => List(AssemblyLine.implied(SEC))
    case _ => Nil
  }
  // Two-byte target: propagate the carry of each addition into the high byte kept in X.
  val incHi = if (target.typ.size == 2) {
    val l = ctx.nextLabel("ah")
    List(AssemblyLine.relative(BCC, l), AssemblyLine.implied(INX), AssemblyLine.label(l))
  } else Nil
  val body = (carry ++ List(loadSource.copy(opcode = opcode)) ++ incHi).map(_.position(targetExpression.position))
  val loadTarget = if (target.typ.size == 2) MosExpressionCompiler.compileToAX(ctx, targetExpression) else MosExpressionCompiler.compileToA(ctx, targetExpression)
  val storeTarget =
    if (target.typ.size == 2) MosExpressionCompiler.expressionStorageFromAX(ctx, Some(target.typ -> target), targetExpression.position)
    else MosExpressionCompiler.expressionStorageFromA(ctx, Some(target.typ -> target), targetExpression.position)
  Some(loadTarget ++ frame._1 ++ body ++ frame._2 ++ storeTarget)
}
}

View File

@ -285,6 +285,12 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] {
compileDoWhileStatement(ctx, s)
case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression))) if !source.containsVariable(variable) =>
MosBulkMemoryOperations.compileMemset(ctx, target, source, f) -> Nil
case f@ForStatement(variable, start, end, _, List(ExpressionStatement(FunctionCallExpression(operator@("+=" | "-=" | "+'=" | "-'=" | "|=" | "^=" | "&="), List(target: VariableExpression, source)))))
if !target.containsVariable(variable) && !start.containsVariable(target.name) && !end.containsVariable(target.name) =>
MosBulkMemoryOperations.compileFold(ctx, target, operator, source, f) match {
case Some(x) => x -> Nil
case None => compileForStatement(ctx, f)
}
case f:ForStatement =>
compileForStatement(ctx,f)
case f:ForEachStatement =>

View File

@ -369,4 +369,31 @@ class ForLoopSuite extends FunSuite with Matchers {
m.readByte(0xc008) should equal(42)
}
}
// Smoke test for summing/folding `for` loops.
// The program exercises `+=`, `+'=` and `&=` over an array with `paralleluntil`, `until`
// and `parallelto` bounds, an offset index (`a[i+1]`), indexing through a pointer,
// a word-sized accumulator, and stack-allocated byte and word accumulators.
// No results are asserted here: the test only passes the program to EmuUnoptimizedRun,
// so it presumably verifies that such loops compile and run without errors - TODO confirm
// and consider asserting the computed sums.
test("Folding loops") {
EmuUnoptimizedRun(
"""
|array a [100]
|void main() {
| byte sum
| byte i
| pointer p
| p = a.addr
| sum = 0
| for i,0,paralleluntil,100 { sum += a[i] }
| for i,0,paralleluntil,100 { sum +'= a[i] }
| for i,0,paralleluntil,100 { sum &= a[i] }
| for i,0,until,100 { sum &= a[i] }
| for i,0,until,50 { sum &= a[i+1] }
| for i,0,parallelto,50 { sum &= a[i] }
| for i,0,until,100 { sum &= p[i] }
| word wsum
| for i,0,paralleluntil,100 { wsum += a[i] }
| stack byte ssum
| stack word swsum
| for i,0,paralleluntil,100 { ssum += a[i] }
| for i,0,paralleluntil,100 { swsum += a[i] }
|}
""".stripMargin)
}
}