mirror of
https://github.com/KarolS/millfork.git
synced 2025-01-01 06:29:53 +00:00
8080: Use pointers instead of indexing when traversing an array in a loop
This commit is contained in:
parent
8304650b3e
commit
a3b21c4810
@ -1,12 +1,139 @@
|
||||
package millfork.compiler.z80
|
||||
|
||||
import millfork.compiler.{AbstractStatementPreprocessor, CompilationContext}
|
||||
import millfork.node.{ExecutableStatement, ForStatement}
|
||||
import millfork.env.{MemoryVariable, MfArray, PointerType, Thing, Variable, VariableAllocationMethod}
|
||||
import millfork.node.{Assignment, DerefDebuggingExpression, DoWhileStatement, ExecutableStatement, Expression, ExpressionStatement, ForDirection, ForEachStatement, ForStatement, FunctionCallExpression, IfStatement, IndexedExpression, IndirectFieldExpression, LhsExpression, LiteralExpression, Node, Statement, SumExpression, VariableDeclarationStatement, VariableExpression, WhileStatement}
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
class Z80StatementPreprocessor(ctx: CompilationContext, statements: List[ExecutableStatement]) extends AbstractStatementPreprocessor(ctx, statements) {
|
||||
|
||||
def maybeOptimizeForStatement(f: ForStatement): Option[(ExecutableStatement, VV)] = None
|
||||
def findIndexedArrays(nodes: Seq[Node], variable: String): Seq[String] = nodes.flatMap(n => findIndexedArrays(n, variable))
|
||||
|
||||
def findIndexedArrays(node: Node, variable: String): Seq[String] = node match {
|
||||
case f: ForStatement =>
|
||||
if (f.variable == variable) Nil else f.getChildStatements.flatMap(s => findIndexedArrays(s, variable))
|
||||
case f: Assignment => findIndexedArrays(f.destination, variable) ++ findIndexedArrays(f.source, variable)
|
||||
case f: IfStatement =>
|
||||
findIndexedArrays(f.condition, variable) ++ findIndexedArrays(f.thenBranch, variable) ++ findIndexedArrays(f.elseBranch, variable)
|
||||
case f: WhileStatement => findIndexedArrays(f.condition, variable) ++ findIndexedArrays(f.body, variable) ++ findIndexedArrays(f.increment, variable)
|
||||
case f: DoWhileStatement => findIndexedArrays(f.condition, variable) ++ findIndexedArrays(f.body, variable) ++ findIndexedArrays(f.increment, variable)
|
||||
case f: ForEachStatement => findIndexedArrays(f.values.right.toOption.getOrElse(Nil), variable) ++ findIndexedArrays(f.body, variable)
|
||||
case f: ExpressionStatement => findIndexedArrays(f.expression, variable)
|
||||
case f: FunctionCallExpression => findIndexedArrays(f.expressions, variable)
|
||||
case f: SumExpression => findIndexedArrays(f.expressions.map(_._2), variable)
|
||||
case f: VariableExpression => Nil
|
||||
case f: LiteralExpression => Nil
|
||||
case f: DerefDebuggingExpression => Nil
|
||||
case IndexedExpression(a, VariableExpression(v)) => if (v == variable) {
|
||||
ctx.env.maybeGet[Thing](a + ".array") match {
|
||||
case Some(_: MfArray) => Seq(a)
|
||||
case _ => Nil
|
||||
}
|
||||
} else Nil
|
||||
case IndexedExpression(_, e) => findIndexedArrays(e, variable)
|
||||
case _ => Nil
|
||||
}
|
||||
|
||||
|
||||
def maybeOptimizeForStatement(f: ForStatement): Option[(ExecutableStatement, VV)] = {
|
||||
// TODO: figure out when this is useful
|
||||
// Currently all instances of arr[i] are replaced with arr`popt`i[0], where arr`popt`i is a new pointer variable.
|
||||
// This breaks the main Millfork promise of not using hidden variables!
|
||||
// This may be increase code size or runtime in certain circumstances, more experimentation is needed.
|
||||
if (!optimize) return None
|
||||
if (ctx.env.eval(f.start).isEmpty) return None
|
||||
if (f.variable.contains(".")) return None
|
||||
if (f.start.containsVariable(f.variable)) return None
|
||||
if (f.end.containsVariable(f.variable)) return None
|
||||
val indexVariable = env.get[Variable](f.variable)
|
||||
if (indexVariable.typ.size != 1) return None
|
||||
if (indexVariable.isVolatile) return None
|
||||
indexVariable match {
|
||||
case v: MemoryVariable =>
|
||||
if (v.alloc == VariableAllocationMethod.Static) return None
|
||||
case _ => return None
|
||||
}
|
||||
val indexedArrays = findIndexedArrays(f.body, f.variable).toSet
|
||||
if (indexedArrays.isEmpty) return None
|
||||
if (indexedArrays.size > 2) return None // TODO: is this the optimal limit?
|
||||
for (a <- indexedArrays) {
|
||||
val array = ctx.env.get[MfArray](a + ".array")
|
||||
// Evil hidden memory usage:
|
||||
env.registerVariable(VariableDeclarationStatement(
|
||||
a + "`popt`" + f.variable,
|
||||
"pointer",
|
||||
None,
|
||||
global = false,
|
||||
stack = false,
|
||||
constant = false,
|
||||
volatile = false,
|
||||
register = false,
|
||||
None,
|
||||
None,
|
||||
None
|
||||
), ctx.options, isPointy = true)
|
||||
}
|
||||
|
||||
def replaceArrayIndexingsE(node: Expression): Expression = node.replaceIndexedExpression(
|
||||
i => i.index match {
|
||||
case VariableExpression(vn) => vn == f.variable && indexedArrays(i.name)
|
||||
case _ => false
|
||||
},
|
||||
i => {
|
||||
val array = ctx.env.get[MfArray](i.name + ".array")
|
||||
optimizeExpr(IndirectFieldExpression(
|
||||
FunctionCallExpression("pointer." + array.elementType.name, List(VariableExpression(i.name + "`popt`" + f.variable))),
|
||||
Seq(LiteralExpression(0, 1)),
|
||||
Seq()), Map())
|
||||
}
|
||||
)
|
||||
|
||||
def replaceArrayIndexingsL(node: LhsExpression): LhsExpression = replaceArrayIndexingsE(node.asInstanceOf[Expression]).asInstanceOf[LhsExpression]
|
||||
|
||||
def replaceArrayIndexingsS(node: ExecutableStatement): ExecutableStatement = node match {
|
||||
case Assignment(t, s) => Assignment(replaceArrayIndexingsL(t), replaceArrayIndexingsE(s)).pos(node.position)
|
||||
case ExpressionStatement(e) => ExpressionStatement(replaceArrayIndexingsE(e)).pos(node.position)
|
||||
case IfStatement(c, t, e) => IfStatement(replaceArrayIndexingsE(c), replaceArrayIndexings(t), replaceArrayIndexings(e)).pos(node.position)
|
||||
case WhileStatement(c, b, i, l) => WhileStatement(replaceArrayIndexingsE(c), replaceArrayIndexings(b), replaceArrayIndexings(i), l).pos(node.position)
|
||||
case DoWhileStatement(b, i, c, l) => DoWhileStatement(replaceArrayIndexings(b), replaceArrayIndexings(i), replaceArrayIndexingsE(c), l).pos(node.position)
|
||||
case _ => throw new ArrayIndexOutOfBoundsException // TODO
|
||||
}
|
||||
|
||||
def replaceArrayIndexings(nodes: List[ExecutableStatement]): List[ExecutableStatement] = nodes.map(replaceArrayIndexingsS)
|
||||
|
||||
val newDirection = f.direction match {
|
||||
case ForDirection.ParallelUntil => ForDirection.Until
|
||||
case ForDirection.ParallelTo => ForDirection.To
|
||||
case d => d
|
||||
}
|
||||
val operator = newDirection match {
|
||||
case ForDirection.DownTo => "-="
|
||||
case _ => "+="
|
||||
}
|
||||
try {
|
||||
val newBody = replaceArrayIndexings(f.body) ++ indexedArrays.map(name => {
|
||||
val array = ctx.env.get[MfArray](name + ".array")
|
||||
ExpressionStatement(FunctionCallExpression(operator, List(
|
||||
VariableExpression(name + "`popt`" + f.variable),
|
||||
LiteralExpression(1, 1))))
|
||||
})
|
||||
val optStart = optimizeExpr(f.start, Map())
|
||||
Some(IfStatement(VariableExpression("true"),
|
||||
indexedArrays.map(name => {
|
||||
val array = ctx.env.get[MfArray](name + ".array")
|
||||
Assignment(
|
||||
VariableExpression(name + "`popt`" + f.variable),
|
||||
FunctionCallExpression("pointer." + array.elementType.name, List(
|
||||
SumExpression(List(false -> VariableExpression(name + ".addr"), false -> optStart), decimal = false)
|
||||
)))
|
||||
}).toList :+ ForStatement(f.variable, optStart, optimizeExpr(f.end, Map()), newDirection, optimizeStmts(newBody, Map())._1),
|
||||
Nil
|
||||
) -> Map())
|
||||
} catch {
|
||||
// too complex, give up:
|
||||
case _: ArrayIndexOutOfBoundsException => None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ object Node {
|
||||
|
||||
sealed trait Expression extends Node {
|
||||
def replaceVariable(variable: String, actualParam: Expression): Expression
|
||||
def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression
|
||||
def containsVariable(variable: String): Boolean
|
||||
def getPointies: Seq[String]
|
||||
def isPure: Boolean
|
||||
@ -38,6 +39,7 @@ sealed trait Expression extends Node {
|
||||
|
||||
case class ConstantArrayElementExpression(constant: Constant) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def getPointies: Seq[String] = Seq.empty
|
||||
override def isPure: Boolean = true
|
||||
@ -46,6 +48,7 @@ case class ConstantArrayElementExpression(constant: Constant) extends Expression
|
||||
|
||||
case class LiteralExpression(value: Long, requiredSize: Int) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def getPointies: Seq[String] = Seq.empty
|
||||
override def isPure: Boolean = true
|
||||
@ -54,6 +57,7 @@ case class LiteralExpression(value: Long, requiredSize: Int) extends Expression
|
||||
|
||||
case class TextLiteralExpression(characters: List[Expression]) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def getPointies: Seq[String] = Seq.empty
|
||||
override def isPure: Boolean = true
|
||||
@ -62,6 +66,7 @@ case class TextLiteralExpression(characters: List[Expression]) extends Expressio
|
||||
|
||||
case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def getPointies: Seq[String] = Seq.empty
|
||||
override def isPure: Boolean = true
|
||||
@ -70,6 +75,7 @@ case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expre
|
||||
|
||||
case class BooleanLiteralExpression(value: Boolean) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def getPointies: Seq[String] = Seq.empty
|
||||
override def isPure: Boolean = true
|
||||
@ -80,6 +86,7 @@ sealed trait LhsExpression extends Expression
|
||||
|
||||
case object BlackHoleExpression extends LhsExpression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def getPointies: Seq[String] = Seq.empty
|
||||
override def isPure: Boolean = true
|
||||
@ -91,6 +98,10 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx
|
||||
SeparateBytesExpression(
|
||||
hi.replaceVariable(variable, actualParam),
|
||||
lo.replaceVariable(variable, actualParam)).pos(position)
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
|
||||
SeparateBytesExpression(
|
||||
hi.replaceIndexedExpression(predicate, replacement),
|
||||
lo.replaceIndexedExpression(predicate, replacement)).pos(position)
|
||||
override def containsVariable(variable: String): Boolean = hi.containsVariable(variable) || lo.containsVariable(variable)
|
||||
override def getPointies: Seq[String] = hi.getPointies ++ lo.getPointies
|
||||
override def isPure: Boolean = hi.isPure && lo.isPure
|
||||
@ -100,6 +111,8 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx
|
||||
case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression =
|
||||
SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal).pos(position)
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
|
||||
SumExpression(expressions.map { case (n, e) => n -> e.replaceIndexedExpression(predicate, replacement) }, decimal).pos(position)
|
||||
override def containsVariable(variable: String): Boolean = expressions.exists(_._2.containsVariable(variable))
|
||||
override def getPointies: Seq[String] = expressions.flatMap(_._2.getPointies)
|
||||
override def isPure: Boolean = expressions.forall(_._2.isPure)
|
||||
@ -111,6 +124,10 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres
|
||||
FunctionCallExpression(functionName, expressions.map {
|
||||
_.replaceVariable(variable, actualParam)
|
||||
}).pos(position)
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
|
||||
FunctionCallExpression(functionName, expressions.map {
|
||||
_.replaceIndexedExpression(predicate, replacement)
|
||||
}).pos(position)
|
||||
override def containsVariable(variable: String): Boolean = expressions.exists(_.containsVariable(variable))
|
||||
override def getPointies: Seq[String] = expressions.flatMap(_.getPointies)
|
||||
override def isPure: Boolean = false // TODO
|
||||
@ -120,6 +137,8 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres
|
||||
case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression =
|
||||
HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte).pos(position)
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
|
||||
HalfWordExpression(expression.replaceIndexedExpression(predicate, replacement), hiByte).pos(position)
|
||||
override def containsVariable(variable: String): Boolean = expression.containsVariable(variable)
|
||||
override def getPointies: Seq[String] = expression.getPointies
|
||||
override def isPure: Boolean = expression.isPure
|
||||
@ -183,6 +202,7 @@ object ZRegister extends Enumeration {
|
||||
case class VariableExpression(name: String) extends LhsExpression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression =
|
||||
if (name == variable) actualParam else this
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = name == variable
|
||||
override def getPointies: Seq[String] = if (name.endsWith(".addr.lo")) Seq(name.stripSuffix(".addr.lo")) else Seq.empty
|
||||
override def isPure: Boolean = true
|
||||
@ -198,6 +218,9 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress
|
||||
case _ => ??? // TODO
|
||||
}
|
||||
} else IndexedExpression(name, index.replaceVariable(variable, actualParam)).pos(position)
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
|
||||
if (predicate(this)) replacement(this).pos(position = position)
|
||||
else IndexedExpression(name, index.replaceIndexedExpression(predicate, replacement))
|
||||
override def containsVariable(variable: String): Boolean = name == variable || index.containsVariable(variable)
|
||||
override def getPointies: Seq[String] = Seq(name)
|
||||
override def isPure: Boolean = index.isPure
|
||||
@ -210,6 +233,12 @@ case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expressio
|
||||
root.replaceVariable(variable, actualParam),
|
||||
firstIndices.map(_.replaceVariable(variable, actualParam)),
|
||||
fields.map{case (f, i) => f -> i.map(_.replaceVariable(variable, actualParam))})
|
||||
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
|
||||
IndirectFieldExpression(
|
||||
root.replaceIndexedExpression(predicate, replacement),
|
||||
firstIndices.map(_.replaceIndexedExpression(predicate, replacement)),
|
||||
fields.map{case (f, i) => f -> i.map(_.replaceIndexedExpression(predicate, replacement))})
|
||||
|
||||
override def containsVariable(variable: String): Boolean =
|
||||
root.containsVariable(variable) ||
|
||||
@ -228,6 +257,9 @@ case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expressio
|
||||
|
||||
case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) extends LhsExpression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefDebuggingExpression(inner.replaceVariable(variable, actualParam), preferredSize)
|
||||
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
|
||||
DerefDebuggingExpression(inner.replaceIndexedExpression(predicate, replacement), preferredSize)
|
||||
|
||||
override def containsVariable(variable: String): Boolean = inner.containsVariable(variable)
|
||||
|
||||
@ -244,6 +276,9 @@ case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) exten
|
||||
case class DerefExpression(inner: Expression, offset: Int, targetType: Type) extends LhsExpression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefExpression(inner.replaceVariable(variable, actualParam), offset, targetType)
|
||||
|
||||
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
|
||||
DerefExpression(inner.replaceIndexedExpression(predicate, replacement), offset, targetType)
|
||||
|
||||
override def containsVariable(variable: String): Boolean = inner.containsVariable(variable)
|
||||
|
||||
override def getPointies: Seq[String] = inner match {
|
||||
|
@ -370,7 +370,7 @@ class ForLoopSuite extends FunSuite with Matchers {
|
||||
}
|
||||
}
|
||||
|
||||
test("Folding loops") {
|
||||
test("Folding loops on 6502") {
|
||||
EmuUnoptimizedRun(
|
||||
"""
|
||||
|array a [100]
|
||||
@ -396,4 +396,19 @@ class ForLoopSuite extends FunSuite with Matchers {
|
||||
|}
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
test("Folding loops on Z80") {
|
||||
EmuOptimizedZ80Run(
|
||||
"""
|
||||
|array a [10] = [0,1,2,3,4,5,6,7,8,9]
|
||||
|byte output @$c000
|
||||
|void main() {
|
||||
| byte sum
|
||||
| byte i
|
||||
| sum = 0
|
||||
| for i,0,paralleluntil,10 { sum += a[i] }
|
||||
| output = sum
|
||||
|}
|
||||
""".stripMargin).readByte(0xc000) should equal(45)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user