1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-12-23 23:30:22 +00:00

Added break and continue statements

This commit is contained in:
Karol Stasiak 2018-03-03 21:34:12 +01:00
parent ccb6e35a29
commit 50ddd52786
8 changed files with 229 additions and 54 deletions

View File

@ -20,6 +20,8 @@
* Added return dispatch statements.
* Added `break` and `continue` statements.
* Added octal and quaternary literals.
* Fixed several optimization bugs.

View File

@ -14,6 +14,7 @@ object OptimizationPresets {
UnusedGlobalVariables,
)
val AssOpt: List[AssemblyOptimization] = List[AssemblyOptimization](
UnusedLabelRemoval,
AlwaysGoodOptimizations.NonetAddition,
AlwaysGoodOptimizations.PointlessSignCheck,
AlwaysGoodOptimizations.PoinlessLoadBeforeAnotherLoad,
@ -125,6 +126,7 @@ object OptimizationPresets {
)
val Good: List[AssemblyOptimization] = List[AssemblyOptimization](
UnusedLabelRemoval,
AlwaysGoodOptimizations.Adc0Optimization,
AlwaysGoodOptimizations.BitPackingUnpacking,
AlwaysGoodOptimizations.BranchInPlaceRemoval,
@ -185,6 +187,7 @@ object OptimizationPresets {
)
val QuickPreset: List[AssemblyOptimization] = List[AssemblyOptimization](
UnusedLabelRemoval,
AlwaysGoodOptimizations.Adc0Optimization,
AlwaysGoodOptimizations.BranchInPlaceRemoval,
AlwaysGoodOptimizations.CommonBranchBodyOptimization,

View File

@ -1,18 +1,32 @@
package millfork.compiler
import millfork.{CompilationFlag, CompilationOptions}
import millfork.env.{Environment, MangledFunction, NormalFunction}
import millfork.env.{Environment, Label, NormalFunction}
/**
* @author Karol Stasiak
*/
case class CompilationContext(env: Environment, function: NormalFunction, extraStackOffset: Int, options: CompilationOptions){
case class CompilationContext(env: Environment,
function: NormalFunction,
extraStackOffset: Int,
options: CompilationOptions,
breakLabels: Map[String, Label] = Map(),
continueLabels: Map[String, Label] = Map()){
def withInlinedEnv(environment: Environment): CompilationContext = {
val newEnv = new Environment(Some(env), MfCompiler.nextLabel("en"))
newEnv.things ++= environment.things
copy(env = newEnv)
}
def addLabels(names: Set[String], breakLabel: Label, continueLabel: Label): CompilationContext = {
var b = breakLabels
var c = continueLabels
names.foreach{name =>
b += (name -> breakLabel)
c += (name -> continueLabel)
}
this.copy(breakLabels = b, continueLabels = c)
}
def addStack(i: Int): CompilationContext = this.copy(extraStackOffset = extraStackOffset + i)

View File

@ -29,13 +29,15 @@ object MacroExpander {
case ReturnDispatchStatement(i,ps, bs) => ReturnDispatchStatement(i.replaceVariable(paramName, target), ps.map(fx), bs.map{
case ReturnDispatchBranch(l, fu, pps) => ReturnDispatchBranch(l, f(fu), pps.map(f))
})
case WhileStatement(c, b) => WhileStatement(f(c), b.map(gx))
case DoWhileStatement(b, c) => DoWhileStatement(b.map(gx), f(c))
case WhileStatement(c, b, i, n) => WhileStatement(f(c), b.map(gx), i.map(gx), n)
case DoWhileStatement(b, i, c, n) => DoWhileStatement(b.map(gx), i.map(gx), f(c), n)
case ForStatement(v, start, end, dir, body) => ForStatement(h(v), f(start), f(end), dir, body.map(gx))
case IfStatement(c, t, e) => IfStatement(f(c), t.map(gx), e.map(gx))
case s:AssemblyStatement => s.copy(expression = f(s.expression))
case Assignment(d,s) => Assignment(fx(d), f(s))
case BlockStatement(s) => BlockStatement(s.map(gx))
case BreakStatement(s) => if (s == paramName) BreakStatement(target.toString) else stmt
case ContinueStatement(s) => if (s == paramName) ContinueStatement(target.toString) else stmt
case _ =>
println(stmt)
???

View File

@ -111,6 +111,11 @@ object StatementCompiler {
ExpressionCompiler.compile(ctx, e, None, NoBranching)
}
case ExpressionStatement(e) =>
e match {
case VariableExpression(_) | LiteralExpression(_, _) =>
ErrorReporting.warn("Pointless expression statement", ctx.options, statement.position)
case _ =>
}
ExpressionCompiler.compile(ctx, e, None, NoBranching)
case BlockStatement(s) =>
s.flatMap(compile(ctx, _))
@ -227,74 +232,67 @@ object StatementCompiler {
ErrorReporting.error(s"Illegal type for a condition: `$condType`", condition.position)
Nil
}
case WhileStatement(condition, bodyPart) =>
case WhileStatement(condition, bodyPart, incrementPart, labels) =>
val start = MfCompiler.nextLabel("wh")
val middle = MfCompiler.nextLabel("he")
val inc = MfCompiler.nextLabel("fp")
val end = MfCompiler.nextLabel("ew")
val condType = ExpressionCompiler.getExpressionType(ctx, condition)
val bodyBlock = compile(ctx, bodyPart)
val bodyBlock = compile(ctx.addLabels(labels, Label(end), Label(inc)), bodyPart)
val incrementBlock = compile(ctx.addLabels(labels, Label(end), Label(inc)), incrementPart)
val largeBodyBlock = bodyBlock.map(_.sizeInBytes).sum > 100
condType match {
case ConstantBooleanType(_, true) =>
val conditionBlock = ExpressionCompiler.compile(ctx, condition, someRegisterA, NoBranching)
val start = MfCompiler.nextLabel("wh")
List(labelChunk(start), bodyBlock, jmpChunk(start)).flatten
List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten
case ConstantBooleanType(_, false) => Nil
case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) =>
if (largeBodyBlock) {
val start = MfCompiler.nextLabel("wh")
val middle = MfCompiler.nextLabel("he")
val end = MfCompiler.nextLabel("ew")
val conditionBlock = ExpressionCompiler.compile(ctx, condition, someRegisterA, NoBranching)
List(labelChunk(start), conditionBlock, branchChunk(jumpIfTrue, middle), jmpChunk(end), bodyBlock, jmpChunk(start), labelChunk(end)).flatten
List(labelChunk(start), conditionBlock, branchChunk(jumpIfTrue, middle), jmpChunk(end), labelChunk(middle), bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten
} else {
val start = MfCompiler.nextLabel("wh")
val end = MfCompiler.nextLabel("ew")
val conditionBlock = ExpressionCompiler.compile(ctx, condition, someRegisterA, NoBranching)
List(labelChunk(start), conditionBlock, branchChunk(jumpIfFalse, end), bodyBlock, jmpChunk(start), labelChunk(end)).flatten
List(labelChunk(start), conditionBlock, branchChunk(jumpIfFalse, end), bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten
}
case BuiltInBooleanType =>
if (largeBodyBlock) {
val start = MfCompiler.nextLabel("wh")
val middle = MfCompiler.nextLabel("he")
val end = MfCompiler.nextLabel("ew")
val conditionBlock = ExpressionCompiler.compile(ctx, condition, someRegisterA, BranchIfTrue(middle))
List(labelChunk(start), conditionBlock, jmpChunk(end), labelChunk(middle), bodyBlock, jmpChunk(start), labelChunk(end)).flatten
List(labelChunk(start), conditionBlock, jmpChunk(end), labelChunk(middle), bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten
} else {
val start = MfCompiler.nextLabel("wh")
val end = MfCompiler.nextLabel("ew")
val conditionBlock = ExpressionCompiler.compile(ctx, condition, someRegisterA, BranchIfFalse(end))
List(labelChunk(start), conditionBlock, bodyBlock, jmpChunk(start), labelChunk(end)).flatten
List(labelChunk(start), conditionBlock, bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten
}
case _ =>
ErrorReporting.error(s"Illegal type for a condition: `$condType`", condition.position)
Nil
}
case DoWhileStatement(bodyPart, condition) =>
case DoWhileStatement(bodyPart, incrementPart, condition, labels) =>
val start = MfCompiler.nextLabel("do")
val inc = MfCompiler.nextLabel("fp")
val end = MfCompiler.nextLabel("od")
val condType = ExpressionCompiler.getExpressionType(ctx, condition)
val bodyBlock = compile(ctx, bodyPart)
val bodyBlock = compile(ctx.addLabels(labels, Label(end), Label(inc)), bodyPart)
val incrementBlock = compile(ctx.addLabels(labels, Label(end), Label(inc)), incrementPart)
val largeBodyBlock = bodyBlock.map(_.sizeInBytes).sum > 100
condType match {
case ConstantBooleanType(_, true) =>
val start = MfCompiler.nextLabel("do")
val conditionBlock = ExpressionCompiler.compile(ctx, condition, someRegisterA, NoBranching)
List(labelChunk(start), bodyBlock, jmpChunk(start)).flatten
case ConstantBooleanType(_, false) => bodyBlock
List(labelChunk(start),bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten
case ConstantBooleanType(_, false) =>
List(bodyBlock, labelChunk(inc), incrementBlock, labelChunk(end)).flatten
case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) =>
val start = MfCompiler.nextLabel("do")
val conditionBlock = ExpressionCompiler.compile(ctx, condition, someRegisterA, NoBranching)
if (largeBodyBlock) {
val end = MfCompiler.nextLabel("od")
List(labelChunk(start), bodyBlock, conditionBlock, branchChunk(jumpIfFalse, end), jmpChunk(start), labelChunk(end)).flatten
List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, conditionBlock, branchChunk(jumpIfFalse, end), jmpChunk(start), labelChunk(end)).flatten
} else {
List(labelChunk(start), bodyBlock, conditionBlock, branchChunk(jumpIfTrue, start)).flatten
List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, conditionBlock, branchChunk(jumpIfTrue, start), labelChunk(end)).flatten
}
case BuiltInBooleanType =>
val start = MfCompiler.nextLabel("do")
if (largeBodyBlock) {
val end = MfCompiler.nextLabel("od")
val conditionBlock = ExpressionCompiler.compile(ctx, condition, someRegisterA, BranchIfFalse(end))
List(labelChunk(start), bodyBlock, conditionBlock, jmpChunk(start), labelChunk(end)).flatten
List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, conditionBlock, jmpChunk(start), labelChunk(end)).flatten
} else {
val conditionBlock = ExpressionCompiler.compile(ctx, condition, someRegisterA, BranchIfTrue(start))
List(labelChunk(start), bodyBlock, conditionBlock).flatten
List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, conditionBlock, labelChunk(end)).flatten
}
case _ =>
ErrorReporting.error(s"Illegal type for a condition: `$condType`", condition.position)
@ -307,32 +305,36 @@ object StatementCompiler {
val one = LiteralExpression(1, 1)
val increment = ExpressionStatement(FunctionCallExpression("+=", List(vex, one)))
val decrement = ExpressionStatement(FunctionCallExpression("-=", List(vex, one)))
val names = Set("", "for", variable)
(direction, env.eval(start), env.eval(end)) match {
case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e - 1 =>
compile(ctx, Assignment(vex, f.start) :: f.body)
val end = MfCompiler.nextLabel("of")
compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start) :: f.body) ++ labelChunk(end)
case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s >= e =>
Nil
case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e =>
compile(ctx, Assignment(vex, f.start) :: f.body)
val end = MfCompiler.nextLabel("of")
compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start) :: f.body) ++ labelChunk(end)
case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s > e =>
Nil
case (ForDirection.ParallelUntil, Some(NumericConstant(0, ssize)), Some(NumericConstant(e, _))) if e > 0 =>
compile(ctx, List(
Assignment(vex, f.end),
DoWhileStatement(decrement :: f.body, FunctionCallExpression("!=", List(vex, f.start)))
DoWhileStatement(Nil, decrement :: f.body, FunctionCallExpression("!=", List(vex, f.start)), names)
))
case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s == e =>
compile(ctx, Assignment(vex, LiteralExpression(s, ssize)) :: f.body)
val end = MfCompiler.nextLabel("of")
compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, LiteralExpression(s, ssize)) :: f.body) ++ labelChunk(end)
case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s < e =>
Nil
case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(0, esize))) if s > 0 =>
compile(ctx, List(
Assignment(vex, f.start),
DoWhileStatement(f.body :+ decrement, FunctionCallExpression("!=", List(vex, f.end)))
DoWhileStatement(f.body, List(decrement), FunctionCallExpression("!=", List(vex, f.end)), names)
))
@ -341,7 +343,7 @@ object StatementCompiler {
Assignment(vex, f.start),
WhileStatement(
FunctionCallExpression("<", List(vex, f.end)),
f.body :+ increment),
f.body, List(increment), names),
))
// case (ForDirection.To | ForDirection.ParallelTo, _, Some(NumericConstant(n, _))) if n > 0 && n < 255 =>
// compile(ctx, List(
@ -351,16 +353,17 @@ object StatementCompiler {
// f.body :+ increment),
// ))
case (ForDirection.To | ForDirection.ParallelTo, _, _) =>
val label = MfCompiler.nextLabel("to")
compile(ctx, List(
Assignment(vex, f.start),
WhileStatement(
VariableExpression("true"),
f.body :+ IfStatement(
f.body,
List(IfStatement(
FunctionCallExpression("==", List(vex, f.end)),
List(AssemblyStatement(JMP, AddrMode.Absolute, VariableExpression(label), elidable = true)),
List(increment))),
AssemblyStatement(LABEL, AddrMode.DoesNotExist, VariableExpression(label), elidable=true)
List(BreakStatement(variable)),
List(increment)
)),
names),
))
case (ForDirection.DownTo, _, _) =>
compile(ctx, List(
@ -368,12 +371,34 @@ object StatementCompiler {
IfStatement(
FunctionCallExpression(">=", List(vex, f.end)),
List(DoWhileStatement(
f.body :+ decrement,
FunctionCallExpression("!=", List(vex, f.end))
f.body,
List(decrement),
FunctionCallExpression("!=", List(vex, f.end)),
names
)),
Nil)
))
}
case BreakStatement(l) =>
ctx.breakLabels.get(l) match {
case None =>
if (l == "") ErrorReporting.error("`break` outside a loop", statement.position)
else ErrorReporting.error("Invalid label: " + l, statement.position)
Nil
case Some(label) =>
List(AssemblyLine.absolute(JMP, label))
}
case ContinueStatement(l) =>
ctx.continueLabels.get(l) match {
case None =>
if (l == "") ErrorReporting.error("`continue` outside a loop", statement.position)
else ErrorReporting.error("Invalid label: " + l, statement.position)
Nil
case Some(label) =>
List(AssemblyLine.absolute(JMP, label))
}
// TODO
}
}

View File

@ -176,7 +176,7 @@ case class IfStatement(condition: Expression, thenBranch: List[ExecutableStateme
override def getAllExpressions: List[Expression] = condition :: (thenBranch ++ elseBranch).flatMap(_.getAllExpressions)
}
case class WhileStatement(condition: Expression, body: List[ExecutableStatement]) extends ExecutableStatement {
case class WhileStatement(condition: Expression, body: List[ExecutableStatement], increment: List[ExecutableStatement], labels: Set[String] = Set("", "while")) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions)
}
@ -188,7 +188,7 @@ case class ForStatement(variable: String, start: Expression, end: Expression, di
override def getAllExpressions: List[Expression] = VariableExpression(variable) :: start :: end :: body.flatMap(_.getAllExpressions)
}
case class DoWhileStatement(body: List[ExecutableStatement], condition: Expression) extends ExecutableStatement {
case class DoWhileStatement(body: List[ExecutableStatement], increment: List[ExecutableStatement], condition: Expression, labels: Set[String] = Set("", "do")) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions)
}
@ -196,6 +196,14 @@ case class BlockStatement(body: List[ExecutableStatement]) extends ExecutableSta
override def getAllExpressions: List[Expression] = body.flatMap(_.getAllExpressions)
}
case class BreakStatement(label: String) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = Nil
}
case class ContinueStatement(label: String) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = Nil
}
object AssemblyStatement {
def implied(opcode: Opcode.Value, elidable: Boolean) = AssemblyStatement(opcode, AddrMode.Implied, LiteralExpression(0, 1), elidable)

View File

@ -312,7 +312,16 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o
case (p, l, r) => Assignment(l, r).pos(p)
}
def keywordStatement: P[ExecutableStatement] = P(returnOrDispatchStatement | ifStatement | whileStatement | forStatement | doWhileStatement | inlineAssembly | assignmentStatement)
def keywordStatement: P[ExecutableStatement] = P(
returnOrDispatchStatement |
ifStatement |
whileStatement |
forStatement |
doWhileStatement |
breakStatement |
continueStatement |
inlineAssembly |
assignmentStatement)
def executableStatement: P[ExecutableStatement] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.pos(p) }
@ -404,6 +413,10 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o
def returnOrDispatchStatement: P[ExecutableStatement] = "return" ~ !letterOrDigit ~/ HWS ~ (dispatchStatementBody | mlExpression(nonStatementLevel).?.map(ReturnStatement))
def breakStatement: P[ExecutableStatement] = ("break" ~ !letterOrDigit ~/ HWS ~ identifier.?).map(l => BreakStatement(l.getOrElse("")))
def continueStatement: P[ExecutableStatement] = ("continue" ~ !letterOrDigit ~/ HWS ~ identifier.?).map(l => ContinueStatement(l.getOrElse("")))
def ifStatement: P[ExecutableStatement] = for {
condition <- "if" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)
thenBranch <- AWS ~/ executableStatements
@ -413,7 +426,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o
def whileStatement: P[ExecutableStatement] = for {
condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)
body <- AWS ~ executableStatements
} yield WhileStatement(condition, body.toList)
} yield WhileStatement(condition, body.toList, Nil)
def forDirection: P[ForDirection.Value] =
("parallel" ~ HWS ~ "to").!.map(_ => ForDirection.ParallelTo) |
@ -439,7 +452,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o
def doWhileStatement: P[ExecutableStatement] = for {
body <- "do" ~ !letterOrDigit ~/ AWS ~ executableStatements ~/ AWS
condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)
} yield DoWhileStatement(body.toList, condition)
} yield DoWhileStatement(body.toList, Nil, condition)
def functionDefinition: P[DeclarationStatement] = for {
p <- position()

View File

@ -0,0 +1,108 @@
package millfork.test
import millfork.test.emu.{EmuBenchmarkRun, EmuUnoptimizedRun}
import org.scalatest.{FunSuite, Matchers}
/**
* @author Karol Stasiak
*/
class BreakContinueSuite extends FunSuite with Matchers {
test("Break from one-iteration loop 1") {
EmuBenchmarkRun(
"""
| byte output @$c000
| void main () {
| output = 0
| do {
| break
| output += 1
| } while false
| }
""".stripMargin)(_.readByte(0xc000) should equal(0))
}
test("Break from one-iteration loop 2") {
EmuBenchmarkRun(
"""
| byte output @$c000
| void main () {
| output = 0
| byte i
| for i,0,to,0 {
| break
| output += 1
| }
| }
""".stripMargin)(_.readByte(0xc000) should equal(0))
}
test("Break from infinite loop 1") {
EmuBenchmarkRun(
"""
| byte output @$c000
| void main () {
| output = 0
| while true {
| output += 1
| break
| output += 1
| }
| }
""".stripMargin)(_.readByte(0xc000) should equal(1))
}
test("Break and continue from infinite loop 1") {
EmuBenchmarkRun(
"""
| byte output @$c000
| void main () {
| output = 0
| while true {
| if output != 0 { break }
| output += 1
| continue
| output += 1
| }
| }
""".stripMargin)(_.readByte(0xc000) should equal(1))
}
test("Nested break") {
EmuBenchmarkRun(
"""
| byte output @$c000
| void main () {
| output = 0
| do {
| output += 1
| while true {
| break while
| }
| output += 1
| } while false
| }
""".stripMargin)(_.readByte(0xc000) should equal(2))
}
test("Continue in for loop 1") {
EmuBenchmarkRun(
"""
| byte output @$c000
| byte counter @$c001
| void main () {
| output = 0
| byte i
| for i,0,paralleluntil,50 {
| counter += 1
| if i != 30 { continue }
| output = i
| break
| }
| }
""".stripMargin){m =>
m.readByte(0xc000) should equal(30)
m.readByte(0xc001) should be > 10
}
}
}