mirror of
https://github.com/KarolS/millfork.git
synced 2024-09-29 09:55:38 +00:00
6502: Faster summing for-loops
This commit is contained in:
parent
260cfd50c4
commit
2282e56845
@ -38,6 +38,8 @@ where `11111` is a sequential number and `xx` is the type:
|
|||||||
|
|
||||||
* `fe` – body of an `for` statement over a list
|
* `fe` – body of an `for` statement over a list
|
||||||
|
|
||||||
|
* `fo` – certain optimized `for` loops
|
||||||
|
|
||||||
* `he` – beginning of the body of a `while` statement
|
* `he` – beginning of the body of a `while` statement
|
||||||
|
|
||||||
* `in` – increment for larger types
|
* `in` – increment for larger types
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package millfork.compiler.mos
|
package millfork.compiler.mos
|
||||||
|
|
||||||
import millfork.CompilationFlag
|
import millfork.CompilationFlag
|
||||||
import millfork.assembly.mos.{AssemblyLine, AssemblyLine0}
|
import millfork.assembly.mos.{AddrMode, AssemblyLine, AssemblyLine0, Opcode}
|
||||||
import millfork.compiler.{BranchSpec, CompilationContext}
|
import millfork.compiler.{AbstractExpressionCompiler, BranchSpec, CompilationContext}
|
||||||
import millfork.env.{NumericConstant, Type, VariableInMemory}
|
import millfork.env.{Label, MemoryAddressConstant, NumericConstant, Type, Variable, VariableInMemory}
|
||||||
import millfork.node._
|
import millfork.node._
|
||||||
import millfork.assembly.mos.Opcode._
|
import millfork.assembly.mos.Opcode._
|
||||||
|
|
||||||
@ -176,4 +176,101 @@ object MosBulkMemoryOperations {
|
|||||||
}
|
}
|
||||||
loadAll ++ setWholePages ++ setRest
|
loadAll ++ setWholePages ++ setRest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def compileFold(ctx: CompilationContext, targetExpression: VariableExpression, operator: String, source: Expression, f: ForStatement): Option[List[AssemblyLine]] = {
|
||||||
|
import AddrMode._
|
||||||
|
import ForDirection._
|
||||||
|
val target = ctx.env.maybeGet[Variable](targetExpression.name).getOrElse(return None)
|
||||||
|
if (target.isVolatile) return None
|
||||||
|
val sourceType = AbstractExpressionCompiler.getExpressionType(ctx, source)
|
||||||
|
if (sourceType.size != 1) return None
|
||||||
|
if (!sourceType.isArithmetic) return None
|
||||||
|
if (operator == "+=") {
|
||||||
|
if (!target.typ.isArithmetic) return None
|
||||||
|
if (target.typ.size > 2) return None
|
||||||
|
if (target.typ.size == 2 && sourceType.isSigned) return None
|
||||||
|
} else {
|
||||||
|
if (!target.typ.isArithmetic) return None
|
||||||
|
if (target.typ.size != 1) return None
|
||||||
|
}
|
||||||
|
val indexVariable = ctx.env.get[Variable](f.variable)
|
||||||
|
val loadSource = MosExpressionCompiler.compileToA(ctx, source) match {
|
||||||
|
case List(AssemblyLine0(LDY, Absolute | ZeroPage, MemoryAddressConstant(index)), l@AssemblyLine0(LDA, AbsoluteY | IndexedY, _)) if index.name == indexVariable.name => l
|
||||||
|
case List(l@AssemblyLine0(LDA, Absolute | ZeroPage | Immediate, _)) => l
|
||||||
|
case _ => return None
|
||||||
|
}
|
||||||
|
|
||||||
|
def isSimple(l: List[AssemblyLine]): Boolean = l match {
|
||||||
|
case List(l@AssemblyLine0(LDA, Absolute | ZeroPage | Immediate, _)) => true
|
||||||
|
case _ => false
|
||||||
|
}
|
||||||
|
|
||||||
|
def isConstant(l: List[AssemblyLine], c: Long => Boolean): Boolean = l match {
|
||||||
|
case List(l@AssemblyLine0(LDA, Immediate, NumericConstant(n, _))) => c(n)
|
||||||
|
case _ => false
|
||||||
|
}
|
||||||
|
|
||||||
|
def targetifyY(l: List[AssemblyLine]): List[AssemblyLine] = if (isSimple(l)) l.map(_.copy(opcode = LDY)) else l :+ AssemblyLine.implied(TAY)
|
||||||
|
|
||||||
|
def toCpy(l: List[AssemblyLine], branch: Opcode.Value, label: String): List[AssemblyLine] = if (isSimple(l)) l.map(_.copy(opcode = CPY)) :+ AssemblyLine.relative(branch, label) else throw new IllegalArgumentException
|
||||||
|
def branch(branch: Opcode.Value, label: String): List[AssemblyLine] = if (branch == JMP) List(AssemblyLine.absolute(JMP, Label(label))) else List(AssemblyLine.relative(branch, label))
|
||||||
|
|
||||||
|
val loadStart = MosExpressionCompiler.compileToA(ctx, f.start)
|
||||||
|
val loadEnd = MosExpressionCompiler.compileToA(ctx, f.end)
|
||||||
|
|
||||||
|
val dey = AssemblyLine.implied(DEY)
|
||||||
|
val iny = AssemblyLine.implied(INY)
|
||||||
|
lazy val skipLabel = ctx.nextLabel("fo") // TODO
|
||||||
|
val loopLabel = ctx.nextLabel("fo") // TODO
|
||||||
|
val loopStart = if (operator.contains("'")) List(AssemblyLine.implied(SED), AssemblyLine.label(loopLabel)) else List(AssemblyLine.label(loopLabel))
|
||||||
|
val cld = if (operator.contains("'")) List(AssemblyLine.implied(CLD)) else Nil
|
||||||
|
lazy val loopSkip = List(AssemblyLine.label(skipLabel))
|
||||||
|
|
||||||
|
val frame: (List[AssemblyLine], List[AssemblyLine]) = f.direction match {
|
||||||
|
case ParallelTo | DownTo if isConstant(loadStart, _ == 0) && isConstant(loadEnd, i => i > 0 && i <= 0x7f) =>
|
||||||
|
(targetifyY(loadEnd) ++ loopStart) -> (dey :: branch(BPL, loopLabel) ++ cld)
|
||||||
|
case ParallelTo | To if isConstant(loadStart, i => i > 0 && i < 255) && isConstant(loadEnd, _ == 255) =>
|
||||||
|
(targetifyY(loadStart) ++ loopStart) -> (iny :: branch(BNE, loopLabel) ++ cld)
|
||||||
|
case ParallelUntil if isConstant(loadStart, _ == 0) && isConstant(loadEnd, i => i > 0 && i <= 0x7f) =>
|
||||||
|
(targetifyY(loadEnd).map(l => l.copy(parameter = l.parameter - 1)) ++ loopStart) -> (dey :: branch(BPL, loopLabel) ++ cld)
|
||||||
|
case ParallelUntil if isConstant(loadStart, _ == 0) =>
|
||||||
|
(targetifyY(loadEnd) ++ loopStart :+ dey) -> (toCpy(loadStart, BNE, loopLabel) ++ cld)
|
||||||
|
case Until | ParallelUntil =>
|
||||||
|
if (isSimple(loadEnd)) {
|
||||||
|
if (cld.isEmpty) {
|
||||||
|
(targetifyY(loadStart) ++ branch(JMP, skipLabel) ++ loopStart) -> (iny :: loopSkip ++ toCpy(loadEnd, BNE, loopLabel))
|
||||||
|
} else {
|
||||||
|
(targetifyY(loadStart) ++ toCpy(loadEnd, BEQ, skipLabel) ++ loopStart) -> (iny :: toCpy(loadEnd, BNE, loopLabel) ++ cld ++ loopSkip)
|
||||||
|
}
|
||||||
|
} else return None
|
||||||
|
case To | ParallelTo | DownTo =>
|
||||||
|
return None
|
||||||
|
}
|
||||||
|
val opcode = operator match {
|
||||||
|
case "+=" | "+'=" => ADC
|
||||||
|
case "-=" | "-'=" => SBC
|
||||||
|
case "|=" => ORA
|
||||||
|
case "&=" => AND
|
||||||
|
case "^=" => EOR
|
||||||
|
case _ => throw new IllegalArgumentException(operator)
|
||||||
|
}
|
||||||
|
val carry = operator match {
|
||||||
|
case "+=" | "+'=" => List(AssemblyLine.implied(CLC))
|
||||||
|
case "-=" | "-'=" => List(AssemblyLine.implied(SEC))
|
||||||
|
case _ => Nil
|
||||||
|
}
|
||||||
|
val incHi = if (target.typ.size == 2) {
|
||||||
|
val l = ctx.nextLabel("ah")
|
||||||
|
List(AssemblyLine.relative(BCC, l), AssemblyLine.implied(INX), AssemblyLine.label(l))
|
||||||
|
} else Nil
|
||||||
|
|
||||||
|
val body = (carry ++ List(loadSource.copy(opcode = opcode)) ++ incHi).map(_.position(targetExpression.position))
|
||||||
|
|
||||||
|
val loadTarget = if (target.typ.size == 2) MosExpressionCompiler.compileToAX(ctx, targetExpression) else MosExpressionCompiler.compileToA(ctx, targetExpression)
|
||||||
|
val storeTarget =
|
||||||
|
if (target.typ.size == 2) MosExpressionCompiler.expressionStorageFromAX(ctx, Some(target.typ -> target), targetExpression.position)
|
||||||
|
else MosExpressionCompiler.expressionStorageFromA(ctx, Some(target.typ -> target), targetExpression.position)
|
||||||
|
|
||||||
|
Some(loadTarget ++ frame._1 ++ body ++ frame._2 ++ storeTarget)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -285,6 +285,12 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] {
|
|||||||
compileDoWhileStatement(ctx, s)
|
compileDoWhileStatement(ctx, s)
|
||||||
case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression))) if !source.containsVariable(variable) =>
|
case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression))) if !source.containsVariable(variable) =>
|
||||||
MosBulkMemoryOperations.compileMemset(ctx, target, source, f) -> Nil
|
MosBulkMemoryOperations.compileMemset(ctx, target, source, f) -> Nil
|
||||||
|
case f@ForStatement(variable, start, end, _, List(ExpressionStatement(FunctionCallExpression(operator@("+=" | "-=" | "+'=" | "-'=" | "|=" | "^=" | "&="), List(target: VariableExpression, source)))))
|
||||||
|
if !target.containsVariable(variable) && !start.containsVariable(target.name) && !end.containsVariable(target.name) =>
|
||||||
|
MosBulkMemoryOperations.compileFold(ctx, target, operator, source, f) match {
|
||||||
|
case Some(x) => x -> Nil
|
||||||
|
case None => compileForStatement(ctx, f)
|
||||||
|
}
|
||||||
case f:ForStatement =>
|
case f:ForStatement =>
|
||||||
compileForStatement(ctx,f)
|
compileForStatement(ctx,f)
|
||||||
case f:ForEachStatement =>
|
case f:ForEachStatement =>
|
||||||
|
@ -369,4 +369,31 @@ class ForLoopSuite extends FunSuite with Matchers {
|
|||||||
m.readByte(0xc008) should equal(42)
|
m.readByte(0xc008) should equal(42)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Folding loops") {
|
||||||
|
EmuUnoptimizedRun(
|
||||||
|
"""
|
||||||
|
|array a [100]
|
||||||
|
|void main() {
|
||||||
|
| byte sum
|
||||||
|
| byte i
|
||||||
|
| pointer p
|
||||||
|
| p = a.addr
|
||||||
|
| sum = 0
|
||||||
|
| for i,0,paralleluntil,100 { sum += a[i] }
|
||||||
|
| for i,0,paralleluntil,100 { sum +'= a[i] }
|
||||||
|
| for i,0,paralleluntil,100 { sum &= a[i] }
|
||||||
|
| for i,0,until,100 { sum &= a[i] }
|
||||||
|
| for i,0,until,50 { sum &= a[i+1] }
|
||||||
|
| for i,0,parallelto,50 { sum &= a[i] }
|
||||||
|
| for i,0,until,100 { sum &= p[i] }
|
||||||
|
| word wsum
|
||||||
|
| for i,0,paralleluntil,100 { wsum += a[i] }
|
||||||
|
| stack byte ssum
|
||||||
|
| stack word swsum
|
||||||
|
| for i,0,paralleluntil,100 { ssum += a[i] }
|
||||||
|
| for i,0,paralleluntil,100 { swsum += a[i] }
|
||||||
|
|}
|
||||||
|
""".stripMargin)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user