mirror of
https://github.com/KarolS/millfork.git
synced 2025-01-26 20:33:02 +00:00
6502: Faster summing for-loops
This commit is contained in:
parent
260cfd50c4
commit
2282e56845
@ -38,6 +38,8 @@ where `11111` is a sequential number and `xx` is the type:
|
||||
|
||||
* `fe` – body of a `for` statement over a list
|
||||
|
||||
* `fo` – certain optimized `for` loops
|
||||
|
||||
* `he` – beginning of the body of a `while` statement
|
||||
|
||||
* `in` – increment for larger types
|
||||
|
@ -1,9 +1,9 @@
|
||||
package millfork.compiler.mos
|
||||
|
||||
import millfork.CompilationFlag
|
||||
import millfork.assembly.mos.{AssemblyLine, AssemblyLine0}
|
||||
import millfork.compiler.{BranchSpec, CompilationContext}
|
||||
import millfork.env.{NumericConstant, Type, VariableInMemory}
|
||||
import millfork.assembly.mos.{AddrMode, AssemblyLine, AssemblyLine0, Opcode}
|
||||
import millfork.compiler.{AbstractExpressionCompiler, BranchSpec, CompilationContext}
|
||||
import millfork.env.{Label, MemoryAddressConstant, NumericConstant, Type, Variable, VariableInMemory}
|
||||
import millfork.node._
|
||||
import millfork.assembly.mos.Opcode._
|
||||
|
||||
@ -176,4 +176,101 @@ object MosBulkMemoryOperations {
|
||||
}
|
||||
loadAll ++ setWholePages ++ setRest
|
||||
}
|
||||
|
||||
/**
  * Attempts to compile a `for` loop whose whole body is `target <operator> source` into a tight loop
  * that keeps the running result in A (plus the high byte in X for 2-byte targets) and the loop counter in Y.
  *
  * The emitted code has the shape: load target, set up Y, loop { apply `source` to A; step Y }, store target.
  * Neither the target nor the index variable is written back to memory while the loop runs,
  * so sources that would read either of them from memory are rejected.
  *
  * @param ctx              compilation context, used for symbol lookup and label generation
  * @param targetExpression the variable being accumulated into
  * @param operator         one of `+=`, `-=`, `+'=`, `-'=`, `|=`, `&=`, `^=`;
  *                         any other operator may end in an IllegalArgumentException
  * @param source           the single-byte arithmetic expression folded into the target on every iteration
  * @param f                the loop statement; its variable, start, end and direction are read
  * @return the generated code, or None when the loop does not fit any supported shape
  */
def compileFold(ctx: CompilationContext, targetExpression: VariableExpression, operator: String, source: Expression, f: ForStatement): Option[List[AssemblyLine]] = {
  import AddrMode._
  import ForDirection._
  val target = ctx.env.maybeGet[Variable](targetExpression.name) match {
    case Some(v) => v
    case None => return None
  }
  if (target.isVolatile) return None
  val sourceType = AbstractExpressionCompiler.getExpressionType(ctx, source)
  if (sourceType.size != 1) return None
  if (!sourceType.isArithmetic) return None
  if (!target.typ.isArithmetic) return None
  if (operator == "+=") {
    // only plain addition supports a 2-byte target, and only with an unsigned source,
    // because the high byte is maintained with a simple carry-driven INX
    if (target.typ.size > 2) return None
    if (target.typ.size == 2 && sourceType.isSigned) return None
  } else {
    if (target.typ.size != 1) return None
  }
  val indexVariable = ctx.env.get[Variable](f.variable)

  // The index lives in Y and the target lives in A/X for the whole loop, so their memory
  // copies are stale inside the loop body; a source that reads them directly would be wrong.
  def readsLoopState(l: AssemblyLine): Boolean = l match {
    case AssemblyLine0(LDA, Absolute | ZeroPage, MemoryAddressConstant(th)) =>
      th.name == indexVariable.name || th.name == target.name
    case _ => false
  }

  val loadSource = MosExpressionCompiler.compileToA(ctx, source) match {
    // array[index] or pointer[index]: the LDY is dropped because Y already holds the index
    case List(AssemblyLine0(LDY, Absolute | ZeroPage, MemoryAddressConstant(index)), l@AssemblyLine0(LDA, AbsoluteY | IndexedY, _)) if index.name == indexVariable.name => l
    // loop-invariant operand
    case List(l@AssemblyLine0(LDA, Absolute | ZeroPage | Immediate, _)) if !readsLoopState(l) => l
    case _ => return None
  }

  // A "simple" load is a single LDA that can be turned into LDY/CPY by swapping the opcode.
  def isSimple(l: List[AssemblyLine]): Boolean = l match {
    case List(AssemblyLine0(LDA, Absolute | ZeroPage | Immediate, _)) => true
    case _ => false
  }

  def isConstant(l: List[AssemblyLine], c: Long => Boolean): Boolean = l match {
    case List(AssemblyLine0(LDA, Immediate, NumericConstant(n, _))) => c(n)
    case _ => false
  }

  // Loads the given bound into Y. Only simple loads are accepted: anything else would have to
  // go through A, which already holds the accumulated value at the point where this code is placed.
  def targetifyY(l: List[AssemblyLine]): List[AssemblyLine] =
    if (isSimple(l)) l.map(_.copy(opcode = LDY)) else throw new IllegalArgumentException

  def toCpy(l: List[AssemblyLine], branch: Opcode.Value, label: String): List[AssemblyLine] =
    if (isSimple(l)) l.map(_.copy(opcode = CPY)) :+ AssemblyLine.relative(branch, label) else throw new IllegalArgumentException

  def branch(branch: Opcode.Value, label: String): List[AssemblyLine] =
    if (branch == JMP) List(AssemblyLine.absolute(JMP, Label(label))) else List(AssemblyLine.relative(branch, label))

  val loadStart = MosExpressionCompiler.compileToA(ctx, f.start)
  val loadEnd = MosExpressionCompiler.compileToA(ctx, f.end)

  val dey = AssemblyLine.implied(DEY)
  val iny = AssemblyLine.implied(INY)
  lazy val skipLabel = ctx.nextLabel("fo") // TODO
  val loopLabel = ctx.nextLabel("fo") // TODO
  // decimal operators run the whole loop with the D flag set
  val decimal = operator.contains("'")
  val loopStart = if (decimal) List(AssemblyLine.implied(SED), AssemblyLine.label(loopLabel)) else List(AssemblyLine.label(loopLabel))
  val cld = if (decimal) List(AssemblyLine.implied(CLD)) else Nil
  lazy val loopSkip = List(AssemblyLine.label(skipLabel))

  // frame._1 goes before the loop body, frame._2 after it
  val frame: (List[AssemblyLine], List[AssemblyLine]) = f.direction match {
    // DownTo is deliberately not handled here: with a start of 0 and a positive end it does not
    // describe the "end down to 0" loop emitted below, so it falls through to the generic compiler.
    case ParallelTo if isConstant(loadStart, _ == 0) && isConstant(loadEnd, i => i > 0 && i <= 0x7f) =>
      (targetifyY(loadEnd) ++ loopStart) -> (dey :: (branch(BPL, loopLabel) ++ cld))
    case ParallelTo | To if isConstant(loadStart, i => i > 0 && i < 255) && isConstant(loadEnd, _ == 255) =>
      // Y wraps from 255 to 0, which ends the loop
      (targetifyY(loadStart) ++ loopStart) -> (iny :: (branch(BNE, loopLabel) ++ cld))
    case ParallelUntil if isConstant(loadStart, _ == 0) && isConstant(loadEnd, i => i > 0 && i <= 0x7f) =>
      (targetifyY(loadEnd).map(l => l.copy(parameter = l.parameter - 1)) ++ loopStart) -> (dey :: (branch(BPL, loopLabel) ++ cld))
    case ParallelUntil if isConstant(loadStart, _ == 0) && isSimple(loadEnd) =>
      // NOTE(review): if the end value is 0 at run time this iterates 256 times instead of 0 - confirm this is intended
      ((targetifyY(loadEnd) ++ loopStart) :+ dey) -> (toCpy(loadStart, BNE, loopLabel) ++ cld)
    case Until | ParallelUntil if isSimple(loadStart) && isSimple(loadEnd) =>
      if (cld.isEmpty) {
        // enter the loop at the comparison, so an empty range executes no iterations
        (targetifyY(loadStart) ++ branch(JMP, skipLabel) ++ loopStart) -> (iny :: (loopSkip ++ toCpy(loadEnd, BNE, loopLabel)))
      } else {
        // in decimal mode the emptiness check is done up front so that SED/CLD stay paired
        (targetifyY(loadStart) ++ toCpy(loadEnd, BEQ, skipLabel) ++ loopStart) -> (iny :: (toCpy(loadEnd, BNE, loopLabel) ++ cld ++ loopSkip))
      }
    case _ =>
      return None
  }
  val opcode = operator match {
    case "+=" | "+'=" => ADC
    case "-=" | "-'=" => SBC
    case "|=" => ORA
    case "&=" => AND
    case "^=" => EOR
    case _ => throw new IllegalArgumentException(operator)
  }
  val carry = operator match {
    case "+=" | "+'=" => List(AssemblyLine.implied(CLC))
    case "-=" | "-'=" => List(AssemblyLine.implied(SEC))
    case _ => Nil
  }
  // for 2-byte targets the high byte is kept in X and bumped whenever the low-byte addition carries
  val incHi = if (target.typ.size == 2) {
    val l = ctx.nextLabel("ah")
    List(AssemblyLine.relative(BCC, l), AssemblyLine.implied(INX), AssemblyLine.label(l))
  } else Nil

  val body = (carry ++ List(loadSource.copy(opcode = opcode)) ++ incHi).map(_.position(targetExpression.position))

  val loadTarget =
    if (target.typ.size == 2) MosExpressionCompiler.compileToAX(ctx, targetExpression)
    else MosExpressionCompiler.compileToA(ctx, targetExpression)
  val storeTarget =
    if (target.typ.size == 2) MosExpressionCompiler.expressionStorageFromAX(ctx, Some(target.typ -> target), targetExpression.position)
    else MosExpressionCompiler.expressionStorageFromA(ctx, Some(target.typ -> target), targetExpression.position)

  Some(loadTarget ++ frame._1 ++ body ++ frame._2 ++ storeTarget)
}
|
||||
}
|
||||
|
@ -285,6 +285,12 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] {
|
||||
compileDoWhileStatement(ctx, s)
|
||||
case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression))) if !source.containsVariable(variable) =>
|
||||
MosBulkMemoryOperations.compileMemset(ctx, target, source, f) -> Nil
|
||||
case f@ForStatement(variable, start, end, _, List(ExpressionStatement(FunctionCallExpression(operator@("+=" | "-=" | "+'=" | "-'=" | "|=" | "^=" | "&="), List(target: VariableExpression, source)))))
|
||||
if !target.containsVariable(variable) && !start.containsVariable(target.name) && !end.containsVariable(target.name) =>
|
||||
MosBulkMemoryOperations.compileFold(ctx, target, operator, source, f) match {
|
||||
case Some(x) => x -> Nil
|
||||
case None => compileForStatement(ctx, f)
|
||||
}
|
||||
case f:ForStatement =>
|
||||
compileForStatement(ctx,f)
|
||||
case f:ForEachStatement =>
|
||||
|
@ -369,4 +369,31 @@ class ForLoopSuite extends FunSuite with Matchers {
|
||||
m.readByte(0xc008) should equal(42)
|
||||
}
|
||||
}
|
||||
|
||||
// Smoke test for loops whose body is a single compound assignment to a scalar variable.
// The program is only passed to EmuUnoptimizedRun; nothing is asserted about the resulting sums,
// and only `sum` is given an initial value.
// NOTE(review): presumably meant to exercise the folding-loop code paths for byte, word,
// stack and pointer-indexed cases - confirm, and consider asserting on the results.
test("Folding loops") {
|
||||
EmuUnoptimizedRun(
|
||||
"""
|
||||
|array a [100]
|
||||
|void main() {
|
||||
| byte sum
|
||||
| byte i
|
||||
| pointer p
|
||||
| p = a.addr
|
||||
| sum = 0
|
||||
| for i,0,paralleluntil,100 { sum += a[i] }
|
||||
| for i,0,paralleluntil,100 { sum +'= a[i] }
|
||||
| for i,0,paralleluntil,100 { sum &= a[i] }
|
||||
| for i,0,until,100 { sum &= a[i] }
|
||||
| for i,0,until,50 { sum &= a[i+1] }
|
||||
| for i,0,parallelto,50 { sum &= a[i] }
|
||||
| for i,0,until,100 { sum &= p[i] }
|
||||
| word wsum
|
||||
| for i,0,paralleluntil,100 { wsum += a[i] }
|
||||
| stack byte ssum
|
||||
| stack word swsum
|
||||
| for i,0,paralleluntil,100 { ssum += a[i] }
|
||||
| for i,0,paralleluntil,100 { swsum += a[i] }
|
||||
|}
|
||||
""".stripMargin)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user