1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-01 06:29:53 +00:00

8080: Use pointers instead of indexing when traversing an array in a loop

This commit is contained in:
Karol Stasiak 2019-06-26 01:56:32 +02:00
parent 8304650b3e
commit a3b21c4810
3 changed files with 180 additions and 3 deletions

View File

@ -1,12 +1,139 @@
package millfork.compiler.z80
import millfork.compiler.{AbstractStatementPreprocessor, CompilationContext}
import millfork.node.{ExecutableStatement, ForStatement}
import millfork.env.{MemoryVariable, MfArray, PointerType, Thing, Variable, VariableAllocationMethod}
import millfork.node.{Assignment, DerefDebuggingExpression, DoWhileStatement, ExecutableStatement, Expression, ExpressionStatement, ForDirection, ForEachStatement, ForStatement, FunctionCallExpression, IfStatement, IndexedExpression, IndirectFieldExpression, LhsExpression, LiteralExpression, Node, Statement, SumExpression, VariableDeclarationStatement, VariableExpression, WhileStatement}
/**
* @author Karol Stasiak
*/
class Z80StatementPreprocessor(ctx: CompilationContext, statements: List[ExecutableStatement]) extends AbstractStatementPreprocessor(ctx, statements) {
def maybeOptimizeForStatement(f: ForStatement): Option[(ExecutableStatement, VV)] = None
def findIndexedArrays(nodes: Seq[Node], variable: String): Seq[String] = nodes.flatMap(n => findIndexedArrays(n, variable))
def findIndexedArrays(node: Node, variable: String): Seq[String] = node match {
case f: ForStatement =>
if (f.variable == variable) Nil else f.getChildStatements.flatMap(s => findIndexedArrays(s, variable))
case f: Assignment => findIndexedArrays(f.destination, variable) ++ findIndexedArrays(f.source, variable)
case f: IfStatement =>
findIndexedArrays(f.condition, variable) ++ findIndexedArrays(f.thenBranch, variable) ++ findIndexedArrays(f.elseBranch, variable)
case f: WhileStatement => findIndexedArrays(f.condition, variable) ++ findIndexedArrays(f.body, variable) ++ findIndexedArrays(f.increment, variable)
case f: DoWhileStatement => findIndexedArrays(f.condition, variable) ++ findIndexedArrays(f.body, variable) ++ findIndexedArrays(f.increment, variable)
case f: ForEachStatement => findIndexedArrays(f.values.right.toOption.getOrElse(Nil), variable) ++ findIndexedArrays(f.body, variable)
case f: ExpressionStatement => findIndexedArrays(f.expression, variable)
case f: FunctionCallExpression => findIndexedArrays(f.expressions, variable)
case f: SumExpression => findIndexedArrays(f.expressions.map(_._2), variable)
case f: VariableExpression => Nil
case f: LiteralExpression => Nil
case f: DerefDebuggingExpression => Nil
case IndexedExpression(a, VariableExpression(v)) => if (v == variable) {
ctx.env.maybeGet[Thing](a + ".array") match {
case Some(_: MfArray) => Seq(a)
case _ => Nil
}
} else Nil
case IndexedExpression(_, e) => findIndexedArrays(e, variable)
case _ => Nil
}
def maybeOptimizeForStatement(f: ForStatement): Option[(ExecutableStatement, VV)] = {
// TODO: figure out when this is useful
// Currently all instances of arr[i] are replaced with arr`popt`i[0], where arr`popt`i is a new pointer variable.
// This breaks the main Millfork promise of not using hidden variables!
// This may be increase code size or runtime in certain circumstances, more experimentation is needed.
if (!optimize) return None
if (ctx.env.eval(f.start).isEmpty) return None
if (f.variable.contains(".")) return None
if (f.start.containsVariable(f.variable)) return None
if (f.end.containsVariable(f.variable)) return None
val indexVariable = env.get[Variable](f.variable)
if (indexVariable.typ.size != 1) return None
if (indexVariable.isVolatile) return None
indexVariable match {
case v: MemoryVariable =>
if (v.alloc == VariableAllocationMethod.Static) return None
case _ => return None
}
val indexedArrays = findIndexedArrays(f.body, f.variable).toSet
if (indexedArrays.isEmpty) return None
if (indexedArrays.size > 2) return None // TODO: is this the optimal limit?
for (a <- indexedArrays) {
val array = ctx.env.get[MfArray](a + ".array")
// Evil hidden memory usage:
env.registerVariable(VariableDeclarationStatement(
a + "`popt`" + f.variable,
"pointer",
None,
global = false,
stack = false,
constant = false,
volatile = false,
register = false,
None,
None,
None
), ctx.options, isPointy = true)
}
def replaceArrayIndexingsE(node: Expression): Expression = node.replaceIndexedExpression(
i => i.index match {
case VariableExpression(vn) => vn == f.variable && indexedArrays(i.name)
case _ => false
},
i => {
val array = ctx.env.get[MfArray](i.name + ".array")
optimizeExpr(IndirectFieldExpression(
FunctionCallExpression("pointer." + array.elementType.name, List(VariableExpression(i.name + "`popt`" + f.variable))),
Seq(LiteralExpression(0, 1)),
Seq()), Map())
}
)
def replaceArrayIndexingsL(node: LhsExpression): LhsExpression = replaceArrayIndexingsE(node.asInstanceOf[Expression]).asInstanceOf[LhsExpression]
def replaceArrayIndexingsS(node: ExecutableStatement): ExecutableStatement = node match {
case Assignment(t, s) => Assignment(replaceArrayIndexingsL(t), replaceArrayIndexingsE(s)).pos(node.position)
case ExpressionStatement(e) => ExpressionStatement(replaceArrayIndexingsE(e)).pos(node.position)
case IfStatement(c, t, e) => IfStatement(replaceArrayIndexingsE(c), replaceArrayIndexings(t), replaceArrayIndexings(e)).pos(node.position)
case WhileStatement(c, b, i, l) => WhileStatement(replaceArrayIndexingsE(c), replaceArrayIndexings(b), replaceArrayIndexings(i), l).pos(node.position)
case DoWhileStatement(b, i, c, l) => DoWhileStatement(replaceArrayIndexings(b), replaceArrayIndexings(i), replaceArrayIndexingsE(c), l).pos(node.position)
case _ => throw new ArrayIndexOutOfBoundsException // TODO
}
def replaceArrayIndexings(nodes: List[ExecutableStatement]): List[ExecutableStatement] = nodes.map(replaceArrayIndexingsS)
val newDirection = f.direction match {
case ForDirection.ParallelUntil => ForDirection.Until
case ForDirection.ParallelTo => ForDirection.To
case d => d
}
val operator = newDirection match {
case ForDirection.DownTo => "-="
case _ => "+="
}
try {
val newBody = replaceArrayIndexings(f.body) ++ indexedArrays.map(name => {
val array = ctx.env.get[MfArray](name + ".array")
ExpressionStatement(FunctionCallExpression(operator, List(
VariableExpression(name + "`popt`" + f.variable),
LiteralExpression(1, 1))))
})
val optStart = optimizeExpr(f.start, Map())
Some(IfStatement(VariableExpression("true"),
indexedArrays.map(name => {
val array = ctx.env.get[MfArray](name + ".array")
Assignment(
VariableExpression(name + "`popt`" + f.variable),
FunctionCallExpression("pointer." + array.elementType.name, List(
SumExpression(List(false -> VariableExpression(name + ".addr"), false -> optStart), decimal = false)
)))
}).toList :+ ForStatement(f.variable, optStart, optimizeExpr(f.end, Map()), newDirection, optimizeStmts(newBody, Map())._1),
Nil
) -> Map())
} catch {
// too complex, give up:
case _: ArrayIndexOutOfBoundsException => None
}
}
}

View File

@ -29,6 +29,7 @@ object Node {
sealed trait Expression extends Node {
def replaceVariable(variable: String, actualParam: Expression): Expression
def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression
def containsVariable(variable: String): Boolean
def getPointies: Seq[String]
def isPure: Boolean
@ -38,6 +39,7 @@ sealed trait Expression extends Node {
case class ConstantArrayElementExpression(constant: Constant) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
override def getPointies: Seq[String] = Seq.empty
override def isPure: Boolean = true
@ -46,6 +48,7 @@ case class ConstantArrayElementExpression(constant: Constant) extends Expression
case class LiteralExpression(value: Long, requiredSize: Int) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
override def getPointies: Seq[String] = Seq.empty
override def isPure: Boolean = true
@ -54,6 +57,7 @@ case class LiteralExpression(value: Long, requiredSize: Int) extends Expression
case class TextLiteralExpression(characters: List[Expression]) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
override def getPointies: Seq[String] = Seq.empty
override def isPure: Boolean = true
@ -62,6 +66,7 @@ case class TextLiteralExpression(characters: List[Expression]) extends Expressio
case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
override def getPointies: Seq[String] = Seq.empty
override def isPure: Boolean = true
@ -70,6 +75,7 @@ case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expre
case class BooleanLiteralExpression(value: Boolean) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
override def getPointies: Seq[String] = Seq.empty
override def isPure: Boolean = true
@ -80,6 +86,7 @@ sealed trait LhsExpression extends Expression
case object BlackHoleExpression extends LhsExpression {
override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
override def getPointies: Seq[String] = Seq.empty
override def isPure: Boolean = true
@ -91,6 +98,10 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx
SeparateBytesExpression(
hi.replaceVariable(variable, actualParam),
lo.replaceVariable(variable, actualParam)).pos(position)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
SeparateBytesExpression(
hi.replaceIndexedExpression(predicate, replacement),
lo.replaceIndexedExpression(predicate, replacement)).pos(position)
override def containsVariable(variable: String): Boolean = hi.containsVariable(variable) || lo.containsVariable(variable)
override def getPointies: Seq[String] = hi.getPointies ++ lo.getPointies
override def isPure: Boolean = hi.isPure && lo.isPure
@ -100,6 +111,8 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx
case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression =
SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal).pos(position)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
SumExpression(expressions.map { case (n, e) => n -> e.replaceIndexedExpression(predicate, replacement) }, decimal).pos(position)
override def containsVariable(variable: String): Boolean = expressions.exists(_._2.containsVariable(variable))
override def getPointies: Seq[String] = expressions.flatMap(_._2.getPointies)
override def isPure: Boolean = expressions.forall(_._2.isPure)
@ -111,6 +124,10 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres
FunctionCallExpression(functionName, expressions.map {
_.replaceVariable(variable, actualParam)
}).pos(position)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
FunctionCallExpression(functionName, expressions.map {
_.replaceIndexedExpression(predicate, replacement)
}).pos(position)
override def containsVariable(variable: String): Boolean = expressions.exists(_.containsVariable(variable))
override def getPointies: Seq[String] = expressions.flatMap(_.getPointies)
override def isPure: Boolean = false // TODO
@ -120,6 +137,8 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres
case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression {
override def replaceVariable(variable: String, actualParam: Expression): Expression =
HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte).pos(position)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
HalfWordExpression(expression.replaceIndexedExpression(predicate, replacement), hiByte).pos(position)
override def containsVariable(variable: String): Boolean = expression.containsVariable(variable)
override def getPointies: Seq[String] = expression.getPointies
override def isPure: Boolean = expression.isPure
@ -183,6 +202,7 @@ object ZRegister extends Enumeration {
case class VariableExpression(name: String) extends LhsExpression {
override def replaceVariable(variable: String, actualParam: Expression): Expression =
if (name == variable) actualParam else this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = name == variable
override def getPointies: Seq[String] = if (name.endsWith(".addr.lo")) Seq(name.stripSuffix(".addr.lo")) else Seq.empty
override def isPure: Boolean = true
@ -198,6 +218,9 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress
case _ => ??? // TODO
}
} else IndexedExpression(name, index.replaceVariable(variable, actualParam)).pos(position)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
if (predicate(this)) replacement(this).pos(position = position)
else IndexedExpression(name, index.replaceIndexedExpression(predicate, replacement))
override def containsVariable(variable: String): Boolean = name == variable || index.containsVariable(variable)
override def getPointies: Seq[String] = Seq(name)
override def isPure: Boolean = index.isPure
@ -210,6 +233,12 @@ case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expressio
root.replaceVariable(variable, actualParam),
firstIndices.map(_.replaceVariable(variable, actualParam)),
fields.map{case (f, i) => f -> i.map(_.replaceVariable(variable, actualParam))})
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
IndirectFieldExpression(
root.replaceIndexedExpression(predicate, replacement),
firstIndices.map(_.replaceIndexedExpression(predicate, replacement)),
fields.map{case (f, i) => f -> i.map(_.replaceIndexedExpression(predicate, replacement))})
override def containsVariable(variable: String): Boolean =
root.containsVariable(variable) ||
@ -228,6 +257,9 @@ case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expressio
case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) extends LhsExpression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefDebuggingExpression(inner.replaceVariable(variable, actualParam), preferredSize)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
DerefDebuggingExpression(inner.replaceIndexedExpression(predicate, replacement), preferredSize)
override def containsVariable(variable: String): Boolean = inner.containsVariable(variable)
@ -244,6 +276,9 @@ case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) exten
case class DerefExpression(inner: Expression, offset: Int, targetType: Type) extends LhsExpression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefExpression(inner.replaceVariable(variable, actualParam), offset, targetType)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
DerefExpression(inner.replaceIndexedExpression(predicate, replacement), offset, targetType)
override def containsVariable(variable: String): Boolean = inner.containsVariable(variable)
override def getPointies: Seq[String] = inner match {

View File

@ -370,7 +370,7 @@ class ForLoopSuite extends FunSuite with Matchers {
}
}
test("Folding loops") {
test("Folding loops on 6502") {
EmuUnoptimizedRun(
"""
|array a [100]
@ -396,4 +396,19 @@ class ForLoopSuite extends FunSuite with Matchers {
|}
""".stripMargin)
}
test("Folding loops on Z80") {
EmuOptimizedZ80Run(
"""
|array a [10] = [0,1,2,3,4,5,6,7,8,9]
|byte output @$c000
|void main() {
| byte sum
| byte i
| sum = 0
| for i,0,paralleluntil,10 { sum += a[i] }
| output = sum
|}
""".stripMargin).readByte(0xc000) should equal(45)
}
}