From 025be8cb7cb91aedb3cf06df28f51cc7025873ac Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Mon, 29 Jul 2019 22:06:59 +0200 Subject: [PATCH] fix infinte loop in constantfolding of when choices --- compiler/src/prog8/ast/AstToSourceCode.kt | 9 +-- .../src/prog8/ast/processing/AstChecker.kt | 2 +- .../ast/processing/IAstModifyingVisitor.kt | 13 ++-- .../ast/processing/StatementReorderer.kt | 5 +- .../src/prog8/ast/statements/AstStatements.kt | 9 ++- .../src/prog8/optimizer/ConstantFolding.kt | 18 ++--- compiler/src/prog8/vm/astvm/AstVm.kt | 2 +- examples/test.p8 | 71 +++++++++++++++++-- 8 files changed, 98 insertions(+), 31 deletions(-) diff --git a/compiler/src/prog8/ast/AstToSourceCode.kt b/compiler/src/prog8/ast/AstToSourceCode.kt index 3b33f6d05..dd7bbf48d 100644 --- a/compiler/src/prog8/ast/AstToSourceCode.kt +++ b/compiler/src/prog8/ast/AstToSourceCode.kt @@ -343,7 +343,7 @@ class AstToSourceCode(val output: (text: String) -> Unit, val program: Program): } override fun visit(repeatLoop: RepeatLoop) { - outputln("repeat ") + output("repeat ") repeatLoop.body.accept(this) output(" until ") repeatLoop.untilCondition.accept(this) @@ -427,13 +427,14 @@ class AstToSourceCode(val output: (text: String) -> Unit, val program: Program): } override fun visit(whenChoice: WhenChoice) { - if(whenChoice.values==null) + val choiceValues = whenChoice.values + if(choiceValues==null) outputi("else -> ") else { outputi("") - for(value in whenChoice.values) { + for(value in choiceValues) { value.accept(this) - if(value !== whenChoice.values.last()) + if(value !== choiceValues.last()) output(",") } output(" -> ") diff --git a/compiler/src/prog8/ast/processing/AstChecker.kt b/compiler/src/prog8/ast/processing/AstChecker.kt index c4976c036..17674df63 100644 --- a/compiler/src/prog8/ast/processing/AstChecker.kt +++ b/compiler/src/prog8/ast/processing/AstChecker.kt @@ -962,7 +962,7 @@ internal class AstChecker(private val program: Program, val whenStmt = whenChoice.parent as WhenStatement if(whenChoice.values!=null) { val conditionType = whenStmt.condition.inferType(program) - val constvalues = whenChoice.values.map { it.constValue(program) } + val constvalues = whenChoice.values!!.map { it.constValue(program) } for(constvalue in constvalues) { when { constvalue == null -> checkResult.add(SyntaxError("choice value must be a constant", whenChoice.position)) diff --git a/compiler/src/prog8/ast/processing/IAstModifyingVisitor.kt b/compiler/src/prog8/ast/processing/IAstModifyingVisitor.kt index ff382d69b..f3c7d71a1 100644 --- a/compiler/src/prog8/ast/processing/IAstModifyingVisitor.kt +++ b/compiler/src/prog8/ast/processing/IAstModifyingVisitor.kt @@ -12,7 +12,7 @@ interface IAstModifyingVisitor { } fun visit(module: Module) { - module.statements = module.statements.asSequence().map { it.accept(this) }.toMutableList() + module.statements = module.statements.map { it.accept(this) }.toMutableList() } fun visit(expr: PrefixExpression): Expression { @@ -31,18 +31,18 @@ interface IAstModifyingVisitor { } fun visit(block: Block): Statement { - block.statements = block.statements.asSequence().map { it.accept(this) }.toMutableList() + block.statements = block.statements.map { it.accept(this) }.toMutableList() return block } fun visit(decl: VarDecl): Statement { decl.value = decl.value?.accept(this) - decl.arraysize = decl.arraysize?.accept(this) + decl.arraysize?.accept(this) return decl } fun visit(subroutine: Subroutine): Statement { - subroutine.statements = subroutine.statements.asSequence().map { it.accept(this) }.toMutableList() + subroutine.statements = subroutine.statements.map { it.accept(this) }.toMutableList() return subroutine } @@ -191,7 +191,7 @@ interface IAstModifyingVisitor { } fun visit(scope: AnonymousScope): Statement { - scope.statements = scope.statements.asSequence().map { it.accept(this) }.toMutableList() + scope.statements = scope.statements.map { it.accept(this) }.toMutableList() return scope } @@ -241,12 +241,13 @@ interface IAstModifyingVisitor { } fun visit(whenChoice: WhenChoice) { - whenChoice.values?.forEach { it.accept(this) } + whenChoice.values = whenChoice.values?.map { it.accept(this) } val stmt = whenChoice.statements.accept(this) if(stmt is AnonymousScope) whenChoice.statements = stmt else { whenChoice.statements = AnonymousScope(mutableListOf(stmt), stmt.position) + whenChoice.statements.linkParents(whenChoice) } } diff --git a/compiler/src/prog8/ast/processing/StatementReorderer.kt b/compiler/src/prog8/ast/processing/StatementReorderer.kt index eb6a4bf7a..6955cd2c4 100644 --- a/compiler/src/prog8/ast/processing/StatementReorderer.kt +++ b/compiler/src/prog8/ast/processing/StatementReorderer.kt @@ -333,9 +333,10 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi // make sure all choices are just for one single value val choices = whenStatement.choices.toList() for(choice in choices) { - if(choice.values==null || choice.values.size==1) + val choiceValues = choice.values + if(choiceValues==null || choiceValues.size==1) continue - for(v in choice.values) { + for(v in choiceValues) { val newchoice=WhenChoice(listOf(v), choice.statements, choice.position) newchoice.parent = choice.parent whenStatement.choices.add(newchoice) diff --git a/compiler/src/prog8/ast/statements/AstStatements.kt b/compiler/src/prog8/ast/statements/AstStatements.kt index f34e58138..7e8dc41f4 100644 --- a/compiler/src/prog8/ast/statements/AstStatements.kt +++ b/compiler/src/prog8/ast/statements/AstStatements.kt @@ -303,9 +303,8 @@ class ArrayIndex(var index: Expression, override val position: Position) : Node } } - fun accept(visitor: IAstModifyingVisitor): ArrayIndex { + fun accept(visitor: IAstModifyingVisitor) { index = index.accept(visitor) - return this } fun accept(visitor: IAstVisitor) { @@ -750,7 +749,7 @@ class RepeatLoop(var body: AnonymousScope, } class WhenStatement(var condition: Expression, - val choices: MutableList, + var choices: MutableList, override val position: Position): Statement() { override lateinit var parent: Node override val expensiveToInline: Boolean = true @@ -768,7 +767,7 @@ class WhenStatement(var condition: Expression, if(choice.values==null) result.add(null to choice) else { - val values = choice.values.map { it.constValue(program)?.number?.toInt() } + val values = choice.values!!.map { it.constValue(program)?.number?.toInt() } if(values.contains(null)) result.add(null to choice) else @@ -782,7 +781,7 @@ class WhenStatement(var condition: Expression, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) } -class WhenChoice(val values: List?, // if null, this is the 'else' part +class WhenChoice(var values: List?, // if null, this is the 'else' part var statements: AnonymousScope, override val position: Position) : Node { override lateinit var parent: Node diff --git a/compiler/src/prog8/optimizer/ConstantFolding.kt b/compiler/src/prog8/optimizer/ConstantFolding.kt index df4ff36a2..3df40d952 100644 --- a/compiler/src/prog8/optimizer/ConstantFolding.kt +++ b/compiler/src/prog8/optimizer/ConstantFolding.kt @@ -224,14 +224,16 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor { */ override fun visit(expr: PrefixExpression): Expression { return try { - super.visit(expr) + val prefixExpr=super.visit(expr) + if(prefixExpr !is PrefixExpression) + return prefixExpr - val subexpr = expr.expression + val subexpr = prefixExpr.expression if (subexpr is NumericLiteralValue) { // accept prefixed literal values (such as -3, not true) return when { - expr.operator == "+" -> subexpr - expr.operator == "-" -> when { + prefixExpr.operator == "+" -> subexpr + prefixExpr.operator == "-" -> when { subexpr.type in IntegerDatatypes -> { optimizationsDone++ NumericLiteralValue.optimalNumeric(-subexpr.number.toInt(), subexpr.position) @@ -242,21 +244,21 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor { } else -> throw ExpressionError("can only take negative of int or float", subexpr.position) } - expr.operator == "~" -> when { + prefixExpr.operator == "~" -> when { subexpr.type in IntegerDatatypes -> { optimizationsDone++ NumericLiteralValue.optimalNumeric(subexpr.number.toInt().inv(), subexpr.position) } else -> throw ExpressionError("can only take bitwise inversion of int", subexpr.position) } - expr.operator == "not" -> { + prefixExpr.operator == "not" -> { optimizationsDone++ NumericLiteralValue.fromBoolean(subexpr.number.toDouble() == 0.0, subexpr.position) } - else -> throw ExpressionError(expr.operator, subexpr.position) + else -> throw ExpressionError(prefixExpr.operator, subexpr.position) } } - return expr + return prefixExpr } catch (ax: AstException) { addError(ax) expr diff --git a/compiler/src/prog8/vm/astvm/AstVm.kt b/compiler/src/prog8/vm/astvm/AstVm.kt index 3f27aabc1..86e3e1a12 100644 --- a/compiler/src/prog8/vm/astvm/AstVm.kt +++ b/compiler/src/prog8/vm/astvm/AstVm.kt @@ -518,7 +518,7 @@ class AstVm(val program: Program) { executeAnonymousScope(choice.statements) break } else { - val value = choice.values.single().constValue(evalCtx.program) ?: throw VmExecutionException("can only use const values in when choices ${choice.position}") + val value = choice.values!!.single().constValue(evalCtx.program) ?: throw VmExecutionException("can only use const values in when choices ${choice.position}") val rtval = RuntimeValue.fromLv(value) if(condition==rtval) { executeAnonymousScope(choice.statements) diff --git a/examples/test.p8 b/examples/test.p8 index 76e1deee5..5ac4fea35 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -6,10 +6,73 @@ sub start() { uword target = 4444 - @($d020) = A - @($d020) = A+4 - @(target) = A+4 - @(target+4) = A+4 +; @($d020) = A +; @($d020) = A+4 +; @(target) = A+4 +; @(target+4) = A+4 + whenubyte(20) + whenubyte(111) + whenbyte(-10) + whenbyte(-111) + whenbyte(0) + + whenuword(500) + whenuword(44) + whenword(-3000) + whenword(-44) + whenword(0) + + sub whenbyte(byte value) { + when value { + -4 -> c64scr.print("minusfour") + -5 -> c64scr.print("minusfive") + -10,-20,-30 -> { + c64scr.print("minusten or twenty or thirty") + } + -99 -> c64scr.print("minusninetynine") + else -> c64scr.print("don't know") + } + c64.CHROUT('\n') + } + + sub whenubyte(ubyte value) { + when value { + 4 -> c64scr.print("four") + 5 -> c64scr.print("five") + 10,20,30 -> { + c64scr.print("ten or twenty or thirty") + } + 99 -> c64scr.print("ninetynine") + else -> c64scr.print("don't know") + } + c64.CHROUT('\n') + } + + sub whenuword(uword value) { + when value { + 400 -> c64scr.print("four100") + 500 -> c64scr.print("five100") + 1000,2000,3000 -> { + c64scr.print("thousand 2thousand or 3thousand") + } + 9999 -> c64scr.print("ninetynine99") + else -> c64scr.print("don't know") + } + c64.CHROUT('\n') + } + + sub whenword(word value) { + when value { + -400 -> c64scr.print("minusfour100") + -500 -> c64scr.print("minusfive100") + -1000,-2000,-3000 -> { + c64scr.print("minusthousand 2thousand or 3thousand") + } + -9999 -> c64scr.print("minusninetynine99") + else -> c64scr.print("don't know") + } + c64.CHROUT('\n') + } } }