diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 00fec4ed..644a3853 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -152,6 +152,10 @@ case class FunctionDeclarationStatement(name: String, sealed trait ExecutableStatement extends Statement +sealed trait CompoundStatement extends ExecutableStatement { + def getChildStatements: Seq[Statement] +} + case class ExpressionStatement(expression: Expression) extends ExecutableStatement { override def getAllExpressions: List[Expression] = List(expression) } @@ -188,24 +192,32 @@ case class AssemblyStatement(opcode: Opcode.Value, addrMode: AddrMode.Value, exp override def getAllExpressions: List[Expression] = List(expression) } -case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends ExecutableStatement { +case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends CompoundStatement { override def getAllExpressions: List[Expression] = condition :: (thenBranch ++ elseBranch).flatMap(_.getAllExpressions) + + override def getChildStatements: Seq[Statement] = thenBranch ++ elseBranch } -case class WhileStatement(condition: Expression, body: List[ExecutableStatement], increment: List[ExecutableStatement], labels: Set[String] = Set("", "while")) extends ExecutableStatement { +case class WhileStatement(condition: Expression, body: List[ExecutableStatement], increment: List[ExecutableStatement], labels: Set[String] = Set("", "while")) extends CompoundStatement { override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions) + + override def getChildStatements: Seq[Statement] = body ++ increment } object ForDirection extends Enumeration { val To, Until, DownTo, ParallelTo, ParallelUntil = Value } -case class ForStatement(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: List[ExecutableStatement]) extends ExecutableStatement { +case class ForStatement(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: List[ExecutableStatement]) extends CompoundStatement { override def getAllExpressions: List[Expression] = VariableExpression(variable) :: start :: end :: body.flatMap(_.getAllExpressions) + + override def getChildStatements: Seq[Statement] = body } -case class DoWhileStatement(body: List[ExecutableStatement], increment: List[ExecutableStatement], condition: Expression, labels: Set[String] = Set("", "do")) extends ExecutableStatement { +case class DoWhileStatement(body: List[ExecutableStatement], increment: List[ExecutableStatement], condition: Expression, labels: Set[String] = Set("", "do")) extends CompoundStatement { override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions) + + override def getChildStatements: Seq[Statement] = body ++ increment } case class BreakStatement(label: String) extends ExecutableStatement { diff --git a/src/main/scala/millfork/output/InliningCalculator.scala b/src/main/scala/millfork/output/InliningCalculator.scala index 72a91dea..97345305 100644 --- a/src/main/scala/millfork/output/InliningCalculator.scala +++ b/src/main/scala/millfork/output/InliningCalculator.scala @@ -42,7 +42,7 @@ object InliningCalculator { || f.interrupt || f.reentrant || f.name == "main" - || f.statements.exists(_.lastOption.exists(_.isInstanceOf[ReturnDispatchStatement]))) badFunctions += f.name + || containsReturnDispatch(f.statements.getOrElse(Nil))) badFunctions += f.name case _ => } allFunctions --= badFunctions @@ -55,6 +55,12 @@ object InliningCalculator { InliningResult(map, badFunctions.toSet) } + private def containsReturnDispatch(statements: Seq[Statement]): Boolean = statements.exists { + case _: ReturnDispatchStatement => true + case c: CompoundStatement => containsReturnDispatch(c.getChildStatements) + case _ => false + } + private def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap { case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList) case ReturnDispatchStatement(index, params, branches) =>