1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-11 12:29:46 +00:00

Prevent all functions with return dispatch from being inlined

This commit is contained in:
Karol Stasiak 2018-03-27 00:38:40 +02:00
parent 0231e4c4fd
commit cb3d848d0a
2 changed files with 23 additions and 5 deletions

View File

@ -152,6 +152,10 @@ case class FunctionDeclarationStatement(name: String,
sealed trait ExecutableStatement extends Statement
sealed trait CompoundStatement extends ExecutableStatement {
def getChildStatements: Seq[Statement]
}
case class ExpressionStatement(expression: Expression) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = List(expression)
}
@ -188,24 +192,32 @@ case class AssemblyStatement(opcode: Opcode.Value, addrMode: AddrMode.Value, exp
override def getAllExpressions: List[Expression] = List(expression)
}
case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends ExecutableStatement {
case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends CompoundStatement {
override def getAllExpressions: List[Expression] = condition :: (thenBranch ++ elseBranch).flatMap(_.getAllExpressions)
override def getChildStatements: Seq[Statement] = thenBranch ++ elseBranch
}
case class WhileStatement(condition: Expression, body: List[ExecutableStatement], increment: List[ExecutableStatement], labels: Set[String] = Set("", "while")) extends ExecutableStatement {
case class WhileStatement(condition: Expression, body: List[ExecutableStatement], increment: List[ExecutableStatement], labels: Set[String] = Set("", "while")) extends CompoundStatement {
override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions)
override def getChildStatements: Seq[Statement] = body ++ increment
}
object ForDirection extends Enumeration {
val To, Until, DownTo, ParallelTo, ParallelUntil = Value
}
case class ForStatement(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: List[ExecutableStatement]) extends ExecutableStatement {
case class ForStatement(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: List[ExecutableStatement]) extends CompoundStatement {
override def getAllExpressions: List[Expression] = VariableExpression(variable) :: start :: end :: body.flatMap(_.getAllExpressions)
override def getChildStatements: Seq[Statement] = body
}
case class DoWhileStatement(body: List[ExecutableStatement], increment: List[ExecutableStatement], condition: Expression, labels: Set[String] = Set("", "do")) extends ExecutableStatement {
case class DoWhileStatement(body: List[ExecutableStatement], increment: List[ExecutableStatement], condition: Expression, labels: Set[String] = Set("", "do")) extends CompoundStatement {
override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions)
override def getChildStatements: Seq[Statement] = body ++ increment
}
case class BreakStatement(label: String) extends ExecutableStatement {

View File

@ -42,7 +42,7 @@ object InliningCalculator {
|| f.interrupt
|| f.reentrant
|| f.name == "main"
|| f.statements.exists(_.lastOption.exists(_.isInstanceOf[ReturnDispatchStatement]))) badFunctions += f.name
|| containsReturnDispatch(f.statements.getOrElse(Nil))) badFunctions += f.name
case _ =>
}
allFunctions --= badFunctions
@ -55,6 +55,12 @@ object InliningCalculator {
InliningResult(map, badFunctions.toSet)
}
private def containsReturnDispatch(statements: Seq[Statement]): Boolean = statements.exists {
case _: ReturnDispatchStatement => true
case c: CompoundStatement => containsReturnDispatch(c.getChildStatements)
case _ => false
}
private def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap {
case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList)
case ReturnDispatchStatement(index, params, branches) =>