diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index b2683d08..2d5c0f7b 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -116,6 +116,17 @@ class AbstractExpressionCompiler[T <: AbstractCode] { case (err: Expression, _) => ErrorReporting.fatal("Invalid left-hand-side expression", err.position) } } + + def isUpToOneVar(params: List[(Boolean, Expression)]): Boolean = { + var count = 0 + params.foreach { + case (false, VariableExpression(_)) => count += 1 + case (_, _: LiteralExpression) => + case (_, _: GeneratedConstantExpression) => + case _ => return false + } + count <= 1 + } } object AbstractExpressionCompiler { @@ -133,6 +144,7 @@ object AbstractExpressionCompiler { case 3 => env.get[Type]("farword") case 4 => env.get[Type]("long") } + case GeneratedConstantExpression(c, t) => t case VariableExpression(name) => env.get[TypedThing](name, expr.position).typ case HalfWordExpression(param, _) => diff --git a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala index 01e7ad21..2e4f5f25 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala @@ -11,7 +11,14 @@ import millfork.node._ */ abstract class AbstractStatementCompiler[T <: AbstractCode] { - def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): List[T] + def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): List[T] = { + getStatementPreprocessor(ctx, statements)().flatMap(s => compile(ctx, s)) + } + + def getStatementPreprocessor(ctx: CompilationContext, statements: List[ExecutableStatement]) = + new AbstractStatementPreprocessor(ctx, statements) + + def compile(ctx: CompilationContext, statement: ExecutableStatement): List[T] def nextLabel(prefix: String): String diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala new file mode 100644 index 00000000..2a82d37f --- /dev/null +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -0,0 +1,193 @@ +package millfork.compiler + +import millfork.CompilationFlag +import millfork.env._ +import millfork.error.ErrorReporting +import millfork.node._ +import AbstractExpressionCompiler.getExpressionType +import millfork.compiler.AbstractStatementPreprocessor.hiddenEffectFreeFunctions + +import scala.collection.mutable.ListBuffer + +/** + * @author Karol Stasiak + */ +class AbstractStatementPreprocessor(ctx: CompilationContext, statements: List[ExecutableStatement]) { + type VV = Map[String, Constant] + private val optimize = true // TODO + private val env: Environment = ctx.env + private val localPrefix = ctx.function.name + "$" + private val usedIdentifiers = if (optimize) statements.flatMap(_.getAllExpressions).flatMap(_.getAllIdentifiers) else Set() + private val trackableVars: Set[String] = if (optimize) { + env.getAllLocalVariables.map(_.name.stripPrefix(localPrefix)) + .filterNot(_.contains(".")) + .filterNot(_.contains("$")) + .filterNot { vname => + val prefix = vname + "." + usedIdentifiers.exists(_.startsWith(prefix)) + }.toSet + } else Set() // TODO + if (ErrorReporting.traceEnabled && trackableVars.nonEmpty) { + ErrorReporting.trace("Tracking local variables: " + trackableVars.mkString(", ")) + } + private val reentrantVars: Set[String] = trackableVars.filter(v => env.get[Variable](v) match { + case _: StackVariable => true + case UninitializedMemoryVariable(_, _, VariableAllocationMethod.Auto, _) => ctx.options.flag(CompilationFlag.DangerousOptimizations) + case _ => false + }) + private val nonreentrantVars: Set[String] = trackableVars -- reentrantVars + + + def apply(): List[ExecutableStatement] = { + optimizeStmts(statements, Map())._1 + } + + def optimizeStmts(stmts: Seq[ExecutableStatement], currentVarValues: VV): (List[ExecutableStatement], VV) = { + val result = ListBuffer[ExecutableStatement]() + var cv = currentVarValues + for(stmt <- stmts){ + val p = optimizeStmt(stmt, cv) + result += p._1 + cv = p._2 + } + result.toList -> cv + } + + def optimizeStmt(stmt: ExecutableStatement, currentVarValues: VV): (ExecutableStatement, VV) = { + var cv = currentVarValues + val pos = stmt.position + stmt match { + case Assignment(ve@VariableExpression(v), arg) if trackableVars(v) => + cv = search(arg, cv) + + Assignment(ve, optimizeExpr(arg, cv)).pos(pos) -> (env.eval(arg, currentVarValues) match { + case Some(c) => cv + (v -> c) + case None => cv - v + }) + case ExpressionStatement(expr@FunctionCallExpression("+=", List(VariableExpression(v), arg))) + if currentVarValues.contains(v) => + cv = search(arg, cv) + ExpressionStatement(optimizeExpr(expr, cv)).pos(pos) -> (env.eval(expr, currentVarValues) match { + case Some(c) => if (cv.contains(v)) cv + (v -> (cv(v) + c)) else cv + case None => cv - v + }) + case ExpressionStatement(expr@FunctionCallExpression(op, List(VariableExpression(v), arg))) + if op.endsWith("=") && op != ">=" && op != "<=" && op != ":=" => + cv = search(arg, cv) + ExpressionStatement(optimizeExpr(expr, cv)).pos(pos) -> (cv - v) + case ExpressionStatement(expr) => + cv = search(expr, cv) + ExpressionStatement(optimizeExpr(expr, cv)).pos(pos) -> cv + case IfStatement(cond, th, el) => + cv = search(cond, cv) + val c = optimizeExpr(cond, cv) + val (t, vt) = optimizeStmts(th, cv) + val (e, ve) = optimizeStmts(el, cv) + IfStatement(c, t, e).pos(pos) -> commonVV(vt, ve) + case WhileStatement(cond, body, inc, labels) => + cv = search(cond, cv) + val c = optimizeExpr(cond, cv) + val (b, _) = optimizeStmts(body, Map()) + val (i, _) = optimizeStmts(inc, Map()) + WhileStatement(c, b, i, labels).pos(pos) -> Map() + case DoWhileStatement(body, inc, cond, labels) => + val c = optimizeExpr(cond, Map()) + val (b, _) = optimizeStmts(body, Map()) + val (i, _) = optimizeStmts(inc, Map()) + DoWhileStatement(b, i, c, labels).pos(pos) -> Map() + case ForStatement(v, st, en, dir, body) => + val s = optimizeExpr(st, Map()) + val e = optimizeExpr(en, Map()) + val (b, _) = optimizeStmts(body, Map()) + ForStatement(v, s, e, dir, b).pos(pos) -> Map() + case _ => stmt -> Map() + } + } + + def search(expr: Expression, cv: VV): VV = { + expr match { + case FunctionCallExpression(op, List(VariableExpression(v), arg)) if op.endsWith("=") && op != "<=" && op != ">=" => + search(arg, cv - v) + case FunctionCallExpression(name, params) + if hiddenEffectFreeFunctions(name) || env.maybeGet[Type](name).isDefined => + params.map(p => search(p, cv)).reduce(commonVV) + case FunctionCallExpression(_, _) => cv -- nonreentrantVars + case SumExpression(params, _) => params.map(p => search(p._2, cv)).reduce(commonVV) + case HalfWordExpression(arg, _) => search(arg, cv) + case IndexedExpression(_, arg) => search(arg, cv) + case _ => cv // TODO + } + } + + def commonVV(a: VV, b: VV): VV = { + if (a.isEmpty) return a + if (b.isEmpty) return b + val keys = a.keySet & b.keySet + if (keys.isEmpty) return Map() + keys.flatMap{ k => + val aa = a(k) + val bb = b(k) + if (aa == bb) Some(k -> aa) else None + }.toMap + } + + def isHiddenEffectFree(expr: Expression): Boolean = { + expr match { + case _: VariableExpression => true + case _: LiteralExpression => true + case _: ConstantArrayElementExpression => true + case _: GeneratedConstantExpression => true + case FunctionCallExpression(name, params) => + hiddenEffectFreeFunctions(name) && params.forall(isHiddenEffectFree) + case _ => false // TODO + } + } + + def optimizeExpr(expr: Expression, currentVarValues: VV): Expression = { + val pos = expr.position + expr match { + case FunctionCallExpression("->", List(handle, VariableExpression(field))) => + expr + case FunctionCallExpression("->", List(handle, FunctionCallExpression(method, params))) => + expr + case VariableExpression(v) if currentVarValues.contains(v) => + val constant = currentVarValues(v) + ErrorReporting.debug(s"Using node flow to replace $v with $constant", pos) + GeneratedConstantExpression(constant, getExpressionType(ctx, expr)).pos(pos) + case FunctionCallExpression(t1, List(FunctionCallExpression(t2, List(arg)))) + if optimize && pointlessDoubleCast(t1, t2, expr) => + ErrorReporting.debug(s"Pointless double cast $t1($t2(...))", pos) + optimizeExpr(FunctionCallExpression(t1, List(arg)), currentVarValues) + case FunctionCallExpression(t1, List(arg)) + if optimize && pointlessCast(t1, expr) => + ErrorReporting.debug(s"Pointless cast $t1(...)", pos) + optimizeExpr(arg, currentVarValues) + case _ => expr // TODO + } + } + + def pointlessCast(t1: String, expr: Expression): Boolean = { + val typ1 = env.maybeGet[Type](t1).getOrElse(return false) + val typ2 = getExpressionType(ctx, expr) + typ1.name == typ2.name + } + + def pointlessDoubleCast(t1: String, t2: String, expr: Expression): Boolean = { + val s1 = env.maybeGet[Type](t1).getOrElse(return false).size + val s2 = env.maybeGet[Type](t2).getOrElse(return false).size + if (s1 != s2) return false + val s3 = AbstractExpressionCompiler.getExpressionType(ctx, expr).size + s1 == s3 + } +} + +object AbstractStatementPreprocessor { + val hiddenEffectFreeFunctions = Set( + "+", "+'", "-", "-'", + "*", "*'", + "<<", "<<'", ">>", ">>'", ">>>>", + "&", "&&", "||", "|", "^", + "==", "!=", "<", ">", ">=", "<=", + "not", "hi", "lo", "nonet" + ) +} \ No newline at end of file diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index 9875e651..8b9b61ff 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -158,6 +158,7 @@ object BuiltIns { case None => expr match { case VariableExpression(_) => 'V' case IndexedExpression(_, LiteralExpression(_, _)) => 'K' + case IndexedExpression(_, GeneratedConstantExpression(_, _)) => 'K' case IndexedExpression(_, expr@VariableExpression(v)) => env.eval(expr) match { case Some(_) => 'K' diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index 79820e1a..9435d2fd 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -309,6 +309,11 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { assertCompatible(exprType, target.typ) compileConstant(ctx, NumericConstant(value, size), target) } + case GeneratedConstantExpression(value, _) => + exprTypeAndVariable.fold(noop) { case (exprType, target) => + assertCompatible(exprType, target.typ) + compileConstant(ctx, value, target) + } case VariableExpression(name) => exprTypeAndVariable.fold(noop) { case (exprType, target) => assertCompatible(exprType, target.typ) @@ -1110,16 +1115,11 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { if (ctx.env.eval(e).isEmpty) e match { case VariableExpression(_) => case LiteralExpression(_, _) => + case GeneratedConstantExpression(_, _) => case IndexedExpression(_, VariableExpression(_)) => case IndexedExpression(_, LiteralExpression(_, _)) => - case IndexedExpression(_, SumExpression(List( - (_, LiteralExpression(_, _)), - (false, VariableExpression(_)) - ), false)) => - case IndexedExpression(_, SumExpression(List( - (false, VariableExpression(_)), - (_, LiteralExpression(_, _)) - ), false)) => + case IndexedExpression(_, GeneratedConstantExpression(_, _)) => + case IndexedExpression(_, SumExpression(params, false)) if isUpToOneVar(params) => case _ => ErrorReporting.warn("A complex expression may be evaluated multiple times", ctx.options, e.position) } diff --git a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala index 65d1cb44..2faa8575 100644 --- a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala @@ -28,10 +28,6 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { MosExpressionCompiler.compile(ctx, expr, Some(b, RegisterVariable(MosRegister.A, b)), branching) } - def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): List[AssemblyLine] = { - statements.flatMap(s => compile(ctx, s)) - } - def compile(ctx: CompilationContext, statement: ExecutableStatement): List[AssemblyLine] = { val env = ctx.env val m = ctx.function @@ -143,6 +139,9 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { }) } statement match { + case EmptyStatement(stmts) => + stmts.foreach(s => compile(ctx, s)) + Nil case MosAssemblyStatement(o, a, x, e) => val c: Constant = x match { // TODO: hmmm @@ -183,7 +182,7 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { } case ExpressionStatement(e) => e match { - case VariableExpression(_) | LiteralExpression(_, _) => + case VariableExpression(_) | LiteralExpression(_, _) | _:GeneratedConstantExpression => ErrorReporting.warn("Pointless expression statement", ctx.options, statement.position) case _ => } diff --git a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala index 6cc502ce..72d3ffa7 100644 --- a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala @@ -294,6 +294,7 @@ object PseudoregisterBuiltIns { case None => expr match { case VariableExpression(_) => 'V' case IndexedExpression(_, LiteralExpression(_, _)) => 'K' + case IndexedExpression(_, GeneratedConstantExpression(_, _)) => 'K' case IndexedExpression(_, VariableExpression(_)) => 'J' case IndexedExpression(_, _) => 'I' case _ => 'A' diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index d2683c4b..79ebf266 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -169,6 +169,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case None => expression match { case LiteralExpression(value, _) => ??? + case GeneratedConstantExpression(_, _) => ??? case VariableExpression(name) => env.get[Variable](name) match { case v: VariableInMemory => @@ -778,16 +779,11 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { if (ctx.env.eval(e).isEmpty) e match { case VariableExpression(_) => case LiteralExpression(_, _) => + case GeneratedConstantExpression(_, _) => case IndexedExpression(_, VariableExpression(_)) => case IndexedExpression(_, LiteralExpression(_, _)) => - case IndexedExpression(_, SumExpression(List( - (_, LiteralExpression(_, _)), - (false, VariableExpression(_)) - ), false)) => - case IndexedExpression(_, SumExpression(List( - (false, VariableExpression(_)), - (_, LiteralExpression(_, _)) - ), false)) => + case IndexedExpression(_, GeneratedConstantExpression(_, _)) => + case IndexedExpression(_, SumExpression(params, false)) if isUpToOneVar(params) => case _ => ErrorReporting.warn("A complex expression may be evaluated multiple times", ctx.options, e.position) } diff --git a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala index 69ec7822..c52817e7 100644 --- a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala @@ -13,14 +13,14 @@ import millfork.error.ErrorReporting */ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { - def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): List[ZLine] = { - statements.flatMap(s => compile(ctx, s)) - } def compile(ctx: CompilationContext, statement: ExecutableStatement): List[ZLine] = { val options = ctx.options val env = ctx.env statement match { + case EmptyStatement(stmts) => + stmts.foreach(s => compile(ctx, s)) + Nil case ReturnStatement(None) => fixStackOnReturn(ctx) ++ (ctx.function.returnType match { case _: BooleanType => diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 2597c47e..bf5849bf 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -379,6 +379,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa else variable -> constant case Some(IndexedExpression(_, _)) => variable -> constant case Some(LiteralExpression(_, _)) => variable -> constant + case Some(GeneratedConstantExpression(_, _)) => variable -> constant case Some(SumExpression(List(negative@(true, _)), false)) => Some(SumExpression(List(false -> LiteralExpression(0xff, 1), negative), decimal = false)) -> (constant - 255).quickSimplify case Some(FunctionCallExpression( @@ -409,18 +410,26 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } + def eval(e: Expression, vars: Map[String, Constant]): Option[Constant] = evalImpl(e, Some(vars)) + + def eval(e: Expression): Option[Constant] = evalImpl(e, None) + //noinspection ScalaUnnecessaryParentheses,ZeroIndexToHead - def eval(e: Expression): Option[Constant] = { + private def evalImpl(e: Expression, vv: Option[Map[String, Constant]]): Option[Constant] = { e match { case LiteralExpression(value, size) => Some(NumericConstant(value, size)) case ConstantArrayElementExpression(c) => Some(c) + case GeneratedConstantExpression(c, t) => Some(c) case VariableExpression(name) => - maybeGet[ConstantThing](name).map(_.value) + vv match { + case Some(m) if m.contains(name) => Some(m(name)) + case _ => maybeGet[ConstantThing](name).map(_.value) + } case IndexedExpression(_, _) => None - case HalfWordExpression(param, hi) => eval(e).map(c => if (hi) c.hiByte else c.loByte) + case HalfWordExpression(param, hi) => evalImpl(e, vv).map(c => if (hi) c.hiByte else c.loByte) case SumExpression(params, decimal) => params.map { - case (minus, param) => (minus, eval(param)) + case (minus, param) => (minus, evalImpl(param, vv)) }.foldLeft(Some(Constant.Zero).asInstanceOf[Option[Constant]]) { (oc, pair) => oc.flatMap { c => pair match { @@ -436,8 +445,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } case SeparateBytesExpression(h, l) => for { - lc <- eval(l) - hc <- eval(h) + lc <- evalImpl(l, vv) + hc <- evalImpl(h, vv) } yield hc.asl(8) + lc case FunctionCallExpression(name, params) => name match { @@ -1204,6 +1213,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case BlackHoleExpression => () case _:BooleanLiteralExpression => () case _:LiteralExpression => () + case _:GeneratedConstantExpression => () case VariableExpression(name) => checkName[VariableLikeThing]("Variable or constant", name, node.position) case IndexedExpression(name, index) => diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 7b0d7fc5..90b769db 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -2,7 +2,7 @@ package millfork.node import millfork.assembly.mos.{AddrMode, Opcode} import millfork.assembly.z80.{ZOpcode, ZRegisters} -import millfork.env.{Constant, ParamPassingConvention} +import millfork.env.{Constant, ParamPassingConvention, Type} case class Position(filename: String, line: Int, column: Int, cursor: Int) @@ -29,24 +29,35 @@ sealed trait Expression extends Node { def replaceVariable(variable: String, actualParam: Expression): Expression def containsVariable(variable: String): Boolean def isPure: Boolean + def getAllIdentifiers: Set[String] } case class ConstantArrayElementExpression(constant: Constant) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def containsVariable(variable: String): Boolean = false override def isPure: Boolean = true + override def getAllIdentifiers: Set[String] = Set.empty } case class LiteralExpression(value: Long, requiredSize: Int) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def containsVariable(variable: String): Boolean = false override def isPure: Boolean = true + override def getAllIdentifiers: Set[String] = Set.empty +} + +case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = this + override def containsVariable(variable: String): Boolean = false + override def isPure: Boolean = true + override def getAllIdentifiers: Set[String] = Set.empty } case class BooleanLiteralExpression(value: Boolean) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def containsVariable(variable: String): Boolean = false override def isPure: Boolean = true + override def getAllIdentifiers: Set[String] = Set.empty } sealed trait LhsExpression extends Expression @@ -55,6 +66,7 @@ case object BlackHoleExpression extends LhsExpression { override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this override def containsVariable(variable: String): Boolean = false override def isPure: Boolean = true + override def getAllIdentifiers: Set[String] = Set.empty } case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsExpression { @@ -64,6 +76,7 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx lo.replaceVariable(variable, actualParam)) override def containsVariable(variable: String): Boolean = hi.containsVariable(variable) || lo.containsVariable(variable) override def isPure: Boolean = hi.isPure && lo.isPure + override def getAllIdentifiers: Set[String] = hi.getAllIdentifiers ++ lo.getAllIdentifiers } case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression { @@ -71,6 +84,7 @@ case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Bool SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal) override def containsVariable(variable: String): Boolean = expressions.exists(_._2.containsVariable(variable)) override def isPure: Boolean = expressions.forall(_._2.isPure) + override def getAllIdentifiers: Set[String] = expressions.map(_._2.getAllIdentifiers).fold(Set[String]())(_ ++ _) } case class FunctionCallExpression(functionName: String, expressions: List[Expression]) extends Expression { @@ -80,6 +94,7 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres }) override def containsVariable(variable: String): Boolean = expressions.exists(_.containsVariable(variable)) override def isPure: Boolean = false // TODO + override def getAllIdentifiers: Set[String] = expressions.map(_.getAllIdentifiers).fold(Set[String]())(_ ++ _) + functionName } case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression { @@ -87,6 +102,7 @@ case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends E HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte) override def containsVariable(variable: String): Boolean = expression.containsVariable(variable) override def isPure: Boolean = expression.isPure + override def getAllIdentifiers: Set[String] = expression.getAllIdentifiers } sealed class NiceFunctionProperty(override val toString: String) @@ -145,6 +161,7 @@ case class VariableExpression(name: String) extends LhsExpression { if (name == variable) actualParam else this override def containsVariable(variable: String): Boolean = name == variable override def isPure: Boolean = true + override def getAllIdentifiers: Set[String] = Set(name) } case class IndexedExpression(name: String, index: Expression) extends LhsExpression { @@ -157,6 +174,7 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress } else IndexedExpression(name, index.replaceVariable(variable, actualParam)) override def containsVariable(variable: String): Boolean = name == variable || index.containsVariable(variable) override def isPure: Boolean = index.isPure + override def getAllIdentifiers: Set[String] = index.getAllIdentifiers + name } sealed trait Statement extends Node { @@ -287,6 +305,10 @@ case class ReturnStatement(value: Option[Expression]) extends ExecutableStatemen override def getAllExpressions: List[Expression] = value.toList } +case class EmptyStatement(toTypecheck: List[ExecutableStatement]) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = toTypecheck.flatMap(_.getAllExpressions) +} + trait ReturnDispatchLabel extends Node { def getAllExpressions: List[Expression] } diff --git a/src/main/scala/millfork/output/AbstractInliningCalculator.scala b/src/main/scala/millfork/output/AbstractInliningCalculator.scala index ef43bde0..da2bf25b 100644 --- a/src/main/scala/millfork/output/AbstractInliningCalculator.scala +++ b/src/main/scala/millfork/output/AbstractInliningCalculator.scala @@ -79,6 +79,7 @@ abstract class AbstractInliningCalculator[T <: AbstractCode] { s.name.stripSuffix(".addr.lo"), s.name.stripSuffix(".addr.hi")).toList.map(_ -> true) case s: LiteralExpression => Nil + case s: GeneratedConstantExpression => Nil case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil) case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2)) case FunctionCallExpression(name, xs) => (name -> false) :: getAllCalledFunctions(xs) diff --git a/src/test/scala/millfork/test/StatementOptimizationSuite.scala b/src/test/scala/millfork/test/StatementOptimizationSuite.scala new file mode 100644 index 00000000..58ce1cc5 --- /dev/null +++ b/src/test/scala/millfork/test/StatementOptimizationSuite.scala @@ -0,0 +1,58 @@ +package millfork.test + +import millfork.Cpu +import millfork.test.emu.EmuCrossPlatformBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class StatementOptimizationSuite extends FunSuite with Matchers { + + test("Statement optimization 1") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( + """ + | array output[10] @$c000 + | void main() { + | byte i + | for i,0,paralleluntil,output.length { + | output[i] = f(i) + | } + | } + | noinline byte f(byte a) { + | byte b + | byte c + | byte d + | byte e + | byte f + | b = a + | c = 5 + | d = c + | b += c + | if a > 4 { + | b += 1 + | } else { + | b += 2 + | } + | e = 4 + | f = e + | while a > 0 { + | d += f + | a -= 1 + | } + | return b + d + | } + """.stripMargin) { m=> + m.readByte(0xc000) should equal(12) + m.readByte(0xc001) should equal(17) + m.readByte(0xc002) should equal(22) + m.readByte(0xc003) should equal(27) + m.readByte(0xc004) should equal(32) + m.readByte(0xc005) should equal(36) + m.readByte(0xc006) should equal(41) + m.readByte(0xc007) should equal(46) + m.readByte(0xc008) should equal(51) + m.readByte(0xc009) should equal(56) + } + } +}