mirror of
https://github.com/KarolS/millfork.git
synced 2025-01-12 03:30:09 +00:00
A statement preprocessor with some basic node-level optimizations
This commit is contained in:
parent
07c049c168
commit
8c1492b211
@ -116,6 +116,17 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
|
||||
case (err: Expression, _) => ErrorReporting.fatal("Invalid left-hand-side expression", err.position)
|
||||
}
|
||||
}
|
||||
|
||||
def isUpToOneVar(params: List[(Boolean, Expression)]): Boolean = {
|
||||
var count = 0
|
||||
params.foreach {
|
||||
case (false, VariableExpression(_)) => count += 1
|
||||
case (_, _: LiteralExpression) =>
|
||||
case (_, _: GeneratedConstantExpression) =>
|
||||
case _ => return false
|
||||
}
|
||||
count <= 1
|
||||
}
|
||||
}
|
||||
|
||||
object AbstractExpressionCompiler {
|
||||
@ -133,6 +144,7 @@ object AbstractExpressionCompiler {
|
||||
case 3 => env.get[Type]("farword")
|
||||
case 4 => env.get[Type]("long")
|
||||
}
|
||||
case GeneratedConstantExpression(c, t) => t
|
||||
case VariableExpression(name) =>
|
||||
env.get[TypedThing](name, expr.position).typ
|
||||
case HalfWordExpression(param, _) =>
|
||||
|
@ -11,7 +11,14 @@ import millfork.node._
|
||||
*/
|
||||
abstract class AbstractStatementCompiler[T <: AbstractCode] {
|
||||
|
||||
def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): List[T]
|
||||
def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): List[T] = {
|
||||
getStatementPreprocessor(ctx, statements)().flatMap(s => compile(ctx, s))
|
||||
}
|
||||
|
||||
def getStatementPreprocessor(ctx: CompilationContext, statements: List[ExecutableStatement]) =
|
||||
new AbstractStatementPreprocessor(ctx, statements)
|
||||
|
||||
def compile(ctx: CompilationContext, statement: ExecutableStatement): List[T]
|
||||
|
||||
def nextLabel(prefix: String): String
|
||||
|
||||
|
@ -0,0 +1,193 @@
|
||||
package millfork.compiler
|
||||
|
||||
import millfork.CompilationFlag
|
||||
import millfork.env._
|
||||
import millfork.error.ErrorReporting
|
||||
import millfork.node._
|
||||
import AbstractExpressionCompiler.getExpressionType
|
||||
import millfork.compiler.AbstractStatementPreprocessor.hiddenEffectFreeFunctions
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
class AbstractStatementPreprocessor(ctx: CompilationContext, statements: List[ExecutableStatement]) {
|
||||
type VV = Map[String, Constant]
|
||||
private val optimize = true // TODO
|
||||
private val env: Environment = ctx.env
|
||||
private val localPrefix = ctx.function.name + "$"
|
||||
private val usedIdentifiers = if (optimize) statements.flatMap(_.getAllExpressions).flatMap(_.getAllIdentifiers) else Set()
|
||||
private val trackableVars: Set[String] = if (optimize) {
|
||||
env.getAllLocalVariables.map(_.name.stripPrefix(localPrefix))
|
||||
.filterNot(_.contains("."))
|
||||
.filterNot(_.contains("$"))
|
||||
.filterNot { vname =>
|
||||
val prefix = vname + "."
|
||||
usedIdentifiers.exists(_.startsWith(prefix))
|
||||
}.toSet
|
||||
} else Set() // TODO
|
||||
if (ErrorReporting.traceEnabled && trackableVars.nonEmpty) {
|
||||
ErrorReporting.trace("Tracking local variables: " + trackableVars.mkString(", "))
|
||||
}
|
||||
private val reentrantVars: Set[String] = trackableVars.filter(v => env.get[Variable](v) match {
|
||||
case _: StackVariable => true
|
||||
case UninitializedMemoryVariable(_, _, VariableAllocationMethod.Auto, _) => ctx.options.flag(CompilationFlag.DangerousOptimizations)
|
||||
case _ => false
|
||||
})
|
||||
private val nonreentrantVars: Set[String] = trackableVars -- reentrantVars
|
||||
|
||||
|
||||
def apply(): List[ExecutableStatement] = {
|
||||
optimizeStmts(statements, Map())._1
|
||||
}
|
||||
|
||||
def optimizeStmts(stmts: Seq[ExecutableStatement], currentVarValues: VV): (List[ExecutableStatement], VV) = {
|
||||
val result = ListBuffer[ExecutableStatement]()
|
||||
var cv = currentVarValues
|
||||
for(stmt <- stmts){
|
||||
val p = optimizeStmt(stmt, cv)
|
||||
result += p._1
|
||||
cv = p._2
|
||||
}
|
||||
result.toList -> cv
|
||||
}
|
||||
|
||||
def optimizeStmt(stmt: ExecutableStatement, currentVarValues: VV): (ExecutableStatement, VV) = {
|
||||
var cv = currentVarValues
|
||||
val pos = stmt.position
|
||||
stmt match {
|
||||
case Assignment(ve@VariableExpression(v), arg) if trackableVars(v) =>
|
||||
cv = search(arg, cv)
|
||||
|
||||
Assignment(ve, optimizeExpr(arg, cv)).pos(pos) -> (env.eval(arg, currentVarValues) match {
|
||||
case Some(c) => cv + (v -> c)
|
||||
case None => cv - v
|
||||
})
|
||||
case ExpressionStatement(expr@FunctionCallExpression("+=", List(VariableExpression(v), arg)))
|
||||
if currentVarValues.contains(v) =>
|
||||
cv = search(arg, cv)
|
||||
ExpressionStatement(optimizeExpr(expr, cv)).pos(pos) -> (env.eval(expr, currentVarValues) match {
|
||||
case Some(c) => if (cv.contains(v)) cv + (v -> (cv(v) + c)) else cv
|
||||
case None => cv - v
|
||||
})
|
||||
case ExpressionStatement(expr@FunctionCallExpression(op, List(VariableExpression(v), arg)))
|
||||
if op.endsWith("=") && op != ">=" && op != "<=" && op != ":=" =>
|
||||
cv = search(arg, cv)
|
||||
ExpressionStatement(optimizeExpr(expr, cv)).pos(pos) -> (cv - v)
|
||||
case ExpressionStatement(expr) =>
|
||||
cv = search(expr, cv)
|
||||
ExpressionStatement(optimizeExpr(expr, cv)).pos(pos) -> cv
|
||||
case IfStatement(cond, th, el) =>
|
||||
cv = search(cond, cv)
|
||||
val c = optimizeExpr(cond, cv)
|
||||
val (t, vt) = optimizeStmts(th, cv)
|
||||
val (e, ve) = optimizeStmts(el, cv)
|
||||
IfStatement(c, t, e).pos(pos) -> commonVV(vt, ve)
|
||||
case WhileStatement(cond, body, inc, labels) =>
|
||||
cv = search(cond, cv)
|
||||
val c = optimizeExpr(cond, cv)
|
||||
val (b, _) = optimizeStmts(body, Map())
|
||||
val (i, _) = optimizeStmts(inc, Map())
|
||||
WhileStatement(c, b, i, labels).pos(pos) -> Map()
|
||||
case DoWhileStatement(body, inc, cond, labels) =>
|
||||
val c = optimizeExpr(cond, Map())
|
||||
val (b, _) = optimizeStmts(body, Map())
|
||||
val (i, _) = optimizeStmts(inc, Map())
|
||||
DoWhileStatement(b, i, c, labels).pos(pos) -> Map()
|
||||
case ForStatement(v, st, en, dir, body) =>
|
||||
val s = optimizeExpr(st, Map())
|
||||
val e = optimizeExpr(en, Map())
|
||||
val (b, _) = optimizeStmts(body, Map())
|
||||
ForStatement(v, s, e, dir, b).pos(pos) -> Map()
|
||||
case _ => stmt -> Map()
|
||||
}
|
||||
}
|
||||
|
||||
def search(expr: Expression, cv: VV): VV = {
|
||||
expr match {
|
||||
case FunctionCallExpression(op, List(VariableExpression(v), arg)) if op.endsWith("=") && op != "<=" && op != ">=" =>
|
||||
search(arg, cv - v)
|
||||
case FunctionCallExpression(name, params)
|
||||
if hiddenEffectFreeFunctions(name) || env.maybeGet[Type](name).isDefined =>
|
||||
params.map(p => search(p, cv)).reduce(commonVV)
|
||||
case FunctionCallExpression(_, _) => cv -- nonreentrantVars
|
||||
case SumExpression(params, _) => params.map(p => search(p._2, cv)).reduce(commonVV)
|
||||
case HalfWordExpression(arg, _) => search(arg, cv)
|
||||
case IndexedExpression(_, arg) => search(arg, cv)
|
||||
case _ => cv // TODO
|
||||
}
|
||||
}
|
||||
|
||||
def commonVV(a: VV, b: VV): VV = {
|
||||
if (a.isEmpty) return a
|
||||
if (b.isEmpty) return b
|
||||
val keys = a.keySet & b.keySet
|
||||
if (keys.isEmpty) return Map()
|
||||
keys.flatMap{ k =>
|
||||
val aa = a(k)
|
||||
val bb = b(k)
|
||||
if (aa == bb) Some(k -> aa) else None
|
||||
}.toMap
|
||||
}
|
||||
|
||||
def isHiddenEffectFree(expr: Expression): Boolean = {
|
||||
expr match {
|
||||
case _: VariableExpression => true
|
||||
case _: LiteralExpression => true
|
||||
case _: ConstantArrayElementExpression => true
|
||||
case _: GeneratedConstantExpression => true
|
||||
case FunctionCallExpression(name, params) =>
|
||||
hiddenEffectFreeFunctions(name) && params.forall(isHiddenEffectFree)
|
||||
case _ => false // TODO
|
||||
}
|
||||
}
|
||||
|
||||
def optimizeExpr(expr: Expression, currentVarValues: VV): Expression = {
|
||||
val pos = expr.position
|
||||
expr match {
|
||||
case FunctionCallExpression("->", List(handle, VariableExpression(field))) =>
|
||||
expr
|
||||
case FunctionCallExpression("->", List(handle, FunctionCallExpression(method, params))) =>
|
||||
expr
|
||||
case VariableExpression(v) if currentVarValues.contains(v) =>
|
||||
val constant = currentVarValues(v)
|
||||
ErrorReporting.debug(s"Using node flow to replace $v with $constant", pos)
|
||||
GeneratedConstantExpression(constant, getExpressionType(ctx, expr)).pos(pos)
|
||||
case FunctionCallExpression(t1, List(FunctionCallExpression(t2, List(arg))))
|
||||
if optimize && pointlessDoubleCast(t1, t2, expr) =>
|
||||
ErrorReporting.debug(s"Pointless double cast $t1($t2(...))", pos)
|
||||
optimizeExpr(FunctionCallExpression(t1, List(arg)), currentVarValues)
|
||||
case FunctionCallExpression(t1, List(arg))
|
||||
if optimize && pointlessCast(t1, expr) =>
|
||||
ErrorReporting.debug(s"Pointless cast $t1(...)", pos)
|
||||
optimizeExpr(arg, currentVarValues)
|
||||
case _ => expr // TODO
|
||||
}
|
||||
}
|
||||
|
||||
def pointlessCast(t1: String, expr: Expression): Boolean = {
|
||||
val typ1 = env.maybeGet[Type](t1).getOrElse(return false)
|
||||
val typ2 = getExpressionType(ctx, expr)
|
||||
typ1.name == typ2.name
|
||||
}
|
||||
|
||||
def pointlessDoubleCast(t1: String, t2: String, expr: Expression): Boolean = {
|
||||
val s1 = env.maybeGet[Type](t1).getOrElse(return false).size
|
||||
val s2 = env.maybeGet[Type](t2).getOrElse(return false).size
|
||||
if (s1 != s2) return false
|
||||
val s3 = AbstractExpressionCompiler.getExpressionType(ctx, expr).size
|
||||
s1 == s3
|
||||
}
|
||||
}
|
||||
|
||||
object AbstractStatementPreprocessor {
|
||||
val hiddenEffectFreeFunctions = Set(
|
||||
"+", "+'", "-", "-'",
|
||||
"*", "*'",
|
||||
"<<", "<<'", ">>", ">>'", ">>>>",
|
||||
"&", "&&", "||", "|", "^",
|
||||
"==", "!=", "<", ">", ">=", "<=",
|
||||
"not", "hi", "lo", "nonet"
|
||||
)
|
||||
}
|
@ -158,6 +158,7 @@ object BuiltIns {
|
||||
case None => expr match {
|
||||
case VariableExpression(_) => 'V'
|
||||
case IndexedExpression(_, LiteralExpression(_, _)) => 'K'
|
||||
case IndexedExpression(_, GeneratedConstantExpression(_, _)) => 'K'
|
||||
case IndexedExpression(_, expr@VariableExpression(v)) =>
|
||||
env.eval(expr) match {
|
||||
case Some(_) => 'K'
|
||||
|
@ -309,6 +309,11 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
|
||||
assertCompatible(exprType, target.typ)
|
||||
compileConstant(ctx, NumericConstant(value, size), target)
|
||||
}
|
||||
case GeneratedConstantExpression(value, _) =>
|
||||
exprTypeAndVariable.fold(noop) { case (exprType, target) =>
|
||||
assertCompatible(exprType, target.typ)
|
||||
compileConstant(ctx, value, target)
|
||||
}
|
||||
case VariableExpression(name) =>
|
||||
exprTypeAndVariable.fold(noop) { case (exprType, target) =>
|
||||
assertCompatible(exprType, target.typ)
|
||||
@ -1110,16 +1115,11 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
|
||||
if (ctx.env.eval(e).isEmpty) e match {
|
||||
case VariableExpression(_) =>
|
||||
case LiteralExpression(_, _) =>
|
||||
case GeneratedConstantExpression(_, _) =>
|
||||
case IndexedExpression(_, VariableExpression(_)) =>
|
||||
case IndexedExpression(_, LiteralExpression(_, _)) =>
|
||||
case IndexedExpression(_, SumExpression(List(
|
||||
(_, LiteralExpression(_, _)),
|
||||
(false, VariableExpression(_))
|
||||
), false)) =>
|
||||
case IndexedExpression(_, SumExpression(List(
|
||||
(false, VariableExpression(_)),
|
||||
(_, LiteralExpression(_, _))
|
||||
), false)) =>
|
||||
case IndexedExpression(_, GeneratedConstantExpression(_, _)) =>
|
||||
case IndexedExpression(_, SumExpression(params, false)) if isUpToOneVar(params) =>
|
||||
case _ =>
|
||||
ErrorReporting.warn("A complex expression may be evaluated multiple times", ctx.options, e.position)
|
||||
}
|
||||
|
@ -28,10 +28,6 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] {
|
||||
MosExpressionCompiler.compile(ctx, expr, Some(b, RegisterVariable(MosRegister.A, b)), branching)
|
||||
}
|
||||
|
||||
def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): List[AssemblyLine] = {
|
||||
statements.flatMap(s => compile(ctx, s))
|
||||
}
|
||||
|
||||
def compile(ctx: CompilationContext, statement: ExecutableStatement): List[AssemblyLine] = {
|
||||
val env = ctx.env
|
||||
val m = ctx.function
|
||||
@ -143,6 +139,9 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] {
|
||||
})
|
||||
}
|
||||
statement match {
|
||||
case EmptyStatement(stmts) =>
|
||||
stmts.foreach(s => compile(ctx, s))
|
||||
Nil
|
||||
case MosAssemblyStatement(o, a, x, e) =>
|
||||
val c: Constant = x match {
|
||||
// TODO: hmmm
|
||||
@ -183,7 +182,7 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] {
|
||||
}
|
||||
case ExpressionStatement(e) =>
|
||||
e match {
|
||||
case VariableExpression(_) | LiteralExpression(_, _) =>
|
||||
case VariableExpression(_) | LiteralExpression(_, _) | _:GeneratedConstantExpression =>
|
||||
ErrorReporting.warn("Pointless expression statement", ctx.options, statement.position)
|
||||
case _ =>
|
||||
}
|
||||
|
@ -294,6 +294,7 @@ object PseudoregisterBuiltIns {
|
||||
case None => expr match {
|
||||
case VariableExpression(_) => 'V'
|
||||
case IndexedExpression(_, LiteralExpression(_, _)) => 'K'
|
||||
case IndexedExpression(_, GeneratedConstantExpression(_, _)) => 'K'
|
||||
case IndexedExpression(_, VariableExpression(_)) => 'J'
|
||||
case IndexedExpression(_, _) => 'I'
|
||||
case _ => 'A'
|
||||
|
@ -169,6 +169,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
|
||||
case None =>
|
||||
expression match {
|
||||
case LiteralExpression(value, _) => ???
|
||||
case GeneratedConstantExpression(_, _) => ???
|
||||
case VariableExpression(name) =>
|
||||
env.get[Variable](name) match {
|
||||
case v: VariableInMemory =>
|
||||
@ -778,16 +779,11 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
|
||||
if (ctx.env.eval(e).isEmpty) e match {
|
||||
case VariableExpression(_) =>
|
||||
case LiteralExpression(_, _) =>
|
||||
case GeneratedConstantExpression(_, _) =>
|
||||
case IndexedExpression(_, VariableExpression(_)) =>
|
||||
case IndexedExpression(_, LiteralExpression(_, _)) =>
|
||||
case IndexedExpression(_, SumExpression(List(
|
||||
(_, LiteralExpression(_, _)),
|
||||
(false, VariableExpression(_))
|
||||
), false)) =>
|
||||
case IndexedExpression(_, SumExpression(List(
|
||||
(false, VariableExpression(_)),
|
||||
(_, LiteralExpression(_, _))
|
||||
), false)) =>
|
||||
case IndexedExpression(_, GeneratedConstantExpression(_, _)) =>
|
||||
case IndexedExpression(_, SumExpression(params, false)) if isUpToOneVar(params) =>
|
||||
case _ =>
|
||||
ErrorReporting.warn("A complex expression may be evaluated multiple times", ctx.options, e.position)
|
||||
}
|
||||
|
@ -13,14 +13,14 @@ import millfork.error.ErrorReporting
|
||||
*/
|
||||
object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] {
|
||||
|
||||
def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): List[ZLine] = {
|
||||
statements.flatMap(s => compile(ctx, s))
|
||||
}
|
||||
|
||||
def compile(ctx: CompilationContext, statement: ExecutableStatement): List[ZLine] = {
|
||||
val options = ctx.options
|
||||
val env = ctx.env
|
||||
statement match {
|
||||
case EmptyStatement(stmts) =>
|
||||
stmts.foreach(s => compile(ctx, s))
|
||||
Nil
|
||||
case ReturnStatement(None) =>
|
||||
fixStackOnReturn(ctx) ++ (ctx.function.returnType match {
|
||||
case _: BooleanType =>
|
||||
|
22
src/main/scala/millfork/env/Environment.scala
vendored
22
src/main/scala/millfork/env/Environment.scala
vendored
@ -379,6 +379,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
else variable -> constant
|
||||
case Some(IndexedExpression(_, _)) => variable -> constant
|
||||
case Some(LiteralExpression(_, _)) => variable -> constant
|
||||
case Some(GeneratedConstantExpression(_, _)) => variable -> constant
|
||||
case Some(SumExpression(List(negative@(true, _)), false)) =>
|
||||
Some(SumExpression(List(false -> LiteralExpression(0xff, 1), negative), decimal = false)) -> (constant - 255).quickSimplify
|
||||
case Some(FunctionCallExpression(
|
||||
@ -409,18 +410,26 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
}
|
||||
}
|
||||
|
||||
def eval(e: Expression, vars: Map[String, Constant]): Option[Constant] = evalImpl(e, Some(vars))
|
||||
|
||||
def eval(e: Expression): Option[Constant] = evalImpl(e, None)
|
||||
|
||||
//noinspection ScalaUnnecessaryParentheses,ZeroIndexToHead
|
||||
def eval(e: Expression): Option[Constant] = {
|
||||
private def evalImpl(e: Expression, vv: Option[Map[String, Constant]]): Option[Constant] = {
|
||||
e match {
|
||||
case LiteralExpression(value, size) => Some(NumericConstant(value, size))
|
||||
case ConstantArrayElementExpression(c) => Some(c)
|
||||
case GeneratedConstantExpression(c, t) => Some(c)
|
||||
case VariableExpression(name) =>
|
||||
maybeGet[ConstantThing](name).map(_.value)
|
||||
vv match {
|
||||
case Some(m) if m.contains(name) => Some(m(name))
|
||||
case _ => maybeGet[ConstantThing](name).map(_.value)
|
||||
}
|
||||
case IndexedExpression(_, _) => None
|
||||
case HalfWordExpression(param, hi) => eval(e).map(c => if (hi) c.hiByte else c.loByte)
|
||||
case HalfWordExpression(param, hi) => evalImpl(e, vv).map(c => if (hi) c.hiByte else c.loByte)
|
||||
case SumExpression(params, decimal) =>
|
||||
params.map {
|
||||
case (minus, param) => (minus, eval(param))
|
||||
case (minus, param) => (minus, evalImpl(param, vv))
|
||||
}.foldLeft(Some(Constant.Zero).asInstanceOf[Option[Constant]]) { (oc, pair) =>
|
||||
oc.flatMap { c =>
|
||||
pair match {
|
||||
@ -436,8 +445,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
}
|
||||
}
|
||||
case SeparateBytesExpression(h, l) => for {
|
||||
lc <- eval(l)
|
||||
hc <- eval(h)
|
||||
lc <- evalImpl(l, vv)
|
||||
hc <- evalImpl(h, vv)
|
||||
} yield hc.asl(8) + lc
|
||||
case FunctionCallExpression(name, params) =>
|
||||
name match {
|
||||
@ -1204,6 +1213,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
case BlackHoleExpression => ()
|
||||
case _:BooleanLiteralExpression => ()
|
||||
case _:LiteralExpression => ()
|
||||
case _:GeneratedConstantExpression => ()
|
||||
case VariableExpression(name) =>
|
||||
checkName[VariableLikeThing]("Variable or constant", name, node.position)
|
||||
case IndexedExpression(name, index) =>
|
||||
|
@ -2,7 +2,7 @@ package millfork.node
|
||||
|
||||
import millfork.assembly.mos.{AddrMode, Opcode}
|
||||
import millfork.assembly.z80.{ZOpcode, ZRegisters}
|
||||
import millfork.env.{Constant, ParamPassingConvention}
|
||||
import millfork.env.{Constant, ParamPassingConvention, Type}
|
||||
|
||||
case class Position(filename: String, line: Int, column: Int, cursor: Int)
|
||||
|
||||
@ -29,24 +29,35 @@ sealed trait Expression extends Node {
|
||||
def replaceVariable(variable: String, actualParam: Expression): Expression
|
||||
def containsVariable(variable: String): Boolean
|
||||
def isPure: Boolean
|
||||
def getAllIdentifiers: Set[String]
|
||||
}
|
||||
|
||||
case class ConstantArrayElementExpression(constant: Constant) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def isPure: Boolean = true
|
||||
override def getAllIdentifiers: Set[String] = Set.empty
|
||||
}
|
||||
|
||||
case class LiteralExpression(value: Long, requiredSize: Int) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def isPure: Boolean = true
|
||||
override def getAllIdentifiers: Set[String] = Set.empty
|
||||
}
|
||||
|
||||
case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def isPure: Boolean = true
|
||||
override def getAllIdentifiers: Set[String] = Set.empty
|
||||
}
|
||||
|
||||
case class BooleanLiteralExpression(value: Boolean) extends Expression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def isPure: Boolean = true
|
||||
override def getAllIdentifiers: Set[String] = Set.empty
|
||||
}
|
||||
|
||||
sealed trait LhsExpression extends Expression
|
||||
@ -55,6 +66,7 @@ case object BlackHoleExpression extends LhsExpression {
|
||||
override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this
|
||||
override def containsVariable(variable: String): Boolean = false
|
||||
override def isPure: Boolean = true
|
||||
override def getAllIdentifiers: Set[String] = Set.empty
|
||||
}
|
||||
|
||||
case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsExpression {
|
||||
@ -64,6 +76,7 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx
|
||||
lo.replaceVariable(variable, actualParam))
|
||||
override def containsVariable(variable: String): Boolean = hi.containsVariable(variable) || lo.containsVariable(variable)
|
||||
override def isPure: Boolean = hi.isPure && lo.isPure
|
||||
override def getAllIdentifiers: Set[String] = hi.getAllIdentifiers ++ lo.getAllIdentifiers
|
||||
}
|
||||
|
||||
case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression {
|
||||
@ -71,6 +84,7 @@ case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Bool
|
||||
SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal)
|
||||
override def containsVariable(variable: String): Boolean = expressions.exists(_._2.containsVariable(variable))
|
||||
override def isPure: Boolean = expressions.forall(_._2.isPure)
|
||||
override def getAllIdentifiers: Set[String] = expressions.map(_._2.getAllIdentifiers).fold(Set[String]())(_ ++ _)
|
||||
}
|
||||
|
||||
case class FunctionCallExpression(functionName: String, expressions: List[Expression]) extends Expression {
|
||||
@ -80,6 +94,7 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres
|
||||
})
|
||||
override def containsVariable(variable: String): Boolean = expressions.exists(_.containsVariable(variable))
|
||||
override def isPure: Boolean = false // TODO
|
||||
override def getAllIdentifiers: Set[String] = expressions.map(_.getAllIdentifiers).fold(Set[String]())(_ ++ _) + functionName
|
||||
}
|
||||
|
||||
case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression {
|
||||
@ -87,6 +102,7 @@ case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends E
|
||||
HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte)
|
||||
override def containsVariable(variable: String): Boolean = expression.containsVariable(variable)
|
||||
override def isPure: Boolean = expression.isPure
|
||||
override def getAllIdentifiers: Set[String] = expression.getAllIdentifiers
|
||||
}
|
||||
|
||||
sealed class NiceFunctionProperty(override val toString: String)
|
||||
@ -145,6 +161,7 @@ case class VariableExpression(name: String) extends LhsExpression {
|
||||
if (name == variable) actualParam else this
|
||||
override def containsVariable(variable: String): Boolean = name == variable
|
||||
override def isPure: Boolean = true
|
||||
override def getAllIdentifiers: Set[String] = Set(name)
|
||||
}
|
||||
|
||||
case class IndexedExpression(name: String, index: Expression) extends LhsExpression {
|
||||
@ -157,6 +174,7 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress
|
||||
} else IndexedExpression(name, index.replaceVariable(variable, actualParam))
|
||||
override def containsVariable(variable: String): Boolean = name == variable || index.containsVariable(variable)
|
||||
override def isPure: Boolean = index.isPure
|
||||
override def getAllIdentifiers: Set[String] = index.getAllIdentifiers + name
|
||||
}
|
||||
|
||||
sealed trait Statement extends Node {
|
||||
@ -287,6 +305,10 @@ case class ReturnStatement(value: Option[Expression]) extends ExecutableStatemen
|
||||
override def getAllExpressions: List[Expression] = value.toList
|
||||
}
|
||||
|
||||
case class EmptyStatement(toTypecheck: List[ExecutableStatement]) extends ExecutableStatement {
|
||||
override def getAllExpressions: List[Expression] = toTypecheck.flatMap(_.getAllExpressions)
|
||||
}
|
||||
|
||||
trait ReturnDispatchLabel extends Node {
|
||||
def getAllExpressions: List[Expression]
|
||||
}
|
||||
|
@ -79,6 +79,7 @@ abstract class AbstractInliningCalculator[T <: AbstractCode] {
|
||||
s.name.stripSuffix(".addr.lo"),
|
||||
s.name.stripSuffix(".addr.hi")).toList.map(_ -> true)
|
||||
case s: LiteralExpression => Nil
|
||||
case s: GeneratedConstantExpression => Nil
|
||||
case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil)
|
||||
case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2))
|
||||
case FunctionCallExpression(name, xs) => (name -> false) :: getAllCalledFunctions(xs)
|
||||
|
@ -0,0 +1,58 @@
|
||||
package millfork.test
|
||||
|
||||
import millfork.Cpu
|
||||
import millfork.test.emu.EmuCrossPlatformBenchmarkRun
|
||||
import org.scalatest.{FunSuite, Matchers}
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
class StatementOptimizationSuite extends FunSuite with Matchers {
|
||||
|
||||
test("Statement optimization 1") {
|
||||
EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(
|
||||
"""
|
||||
| array output[10] @$c000
|
||||
| void main() {
|
||||
| byte i
|
||||
| for i,0,paralleluntil,output.length {
|
||||
| output[i] = f(i)
|
||||
| }
|
||||
| }
|
||||
| noinline byte f(byte a) {
|
||||
| byte b
|
||||
| byte c
|
||||
| byte d
|
||||
| byte e
|
||||
| byte f
|
||||
| b = a
|
||||
| c = 5
|
||||
| d = c
|
||||
| b += c
|
||||
| if a > 4 {
|
||||
| b += 1
|
||||
| } else {
|
||||
| b += 2
|
||||
| }
|
||||
| e = 4
|
||||
| f = e
|
||||
| while a > 0 {
|
||||
| d += f
|
||||
| a -= 1
|
||||
| }
|
||||
| return b + d
|
||||
| }
|
||||
""".stripMargin) { m=>
|
||||
m.readByte(0xc000) should equal(12)
|
||||
m.readByte(0xc001) should equal(17)
|
||||
m.readByte(0xc002) should equal(22)
|
||||
m.readByte(0xc003) should equal(27)
|
||||
m.readByte(0xc004) should equal(32)
|
||||
m.readByte(0xc005) should equal(36)
|
||||
m.readByte(0xc006) should equal(41)
|
||||
m.readByte(0xc007) should equal(46)
|
||||
m.readByte(0xc008) should equal(51)
|
||||
m.readByte(0xc009) should equal(56)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user