fix infinte loop in constantfolding of when choices

This commit is contained in:
Irmen de Jong 2019-07-29 22:06:59 +02:00
parent 3aea32551b
commit 025be8cb7c
8 changed files with 98 additions and 31 deletions

View File

@ -343,7 +343,7 @@ class AstToSourceCode(val output: (text: String) -> Unit, val program: Program):
} }
override fun visit(repeatLoop: RepeatLoop) { override fun visit(repeatLoop: RepeatLoop) {
outputln("repeat ") output("repeat ")
repeatLoop.body.accept(this) repeatLoop.body.accept(this)
output(" until ") output(" until ")
repeatLoop.untilCondition.accept(this) repeatLoop.untilCondition.accept(this)
@ -427,13 +427,14 @@ class AstToSourceCode(val output: (text: String) -> Unit, val program: Program):
} }
override fun visit(whenChoice: WhenChoice) { override fun visit(whenChoice: WhenChoice) {
if(whenChoice.values==null) val choiceValues = whenChoice.values
if(choiceValues==null)
outputi("else -> ") outputi("else -> ")
else { else {
outputi("") outputi("")
for(value in whenChoice.values) { for(value in choiceValues) {
value.accept(this) value.accept(this)
if(value !== whenChoice.values.last()) if(value !== choiceValues.last())
output(",") output(",")
} }
output(" -> ") output(" -> ")

View File

@ -962,7 +962,7 @@ internal class AstChecker(private val program: Program,
val whenStmt = whenChoice.parent as WhenStatement val whenStmt = whenChoice.parent as WhenStatement
if(whenChoice.values!=null) { if(whenChoice.values!=null) {
val conditionType = whenStmt.condition.inferType(program) val conditionType = whenStmt.condition.inferType(program)
val constvalues = whenChoice.values.map { it.constValue(program) } val constvalues = whenChoice.values!!.map { it.constValue(program) }
for(constvalue in constvalues) { for(constvalue in constvalues) {
when { when {
constvalue == null -> checkResult.add(SyntaxError("choice value must be a constant", whenChoice.position)) constvalue == null -> checkResult.add(SyntaxError("choice value must be a constant", whenChoice.position))

View File

@ -12,7 +12,7 @@ interface IAstModifyingVisitor {
} }
fun visit(module: Module) { fun visit(module: Module) {
module.statements = module.statements.asSequence().map { it.accept(this) }.toMutableList() module.statements = module.statements.map { it.accept(this) }.toMutableList()
} }
fun visit(expr: PrefixExpression): Expression { fun visit(expr: PrefixExpression): Expression {
@ -31,18 +31,18 @@ interface IAstModifyingVisitor {
} }
fun visit(block: Block): Statement { fun visit(block: Block): Statement {
block.statements = block.statements.asSequence().map { it.accept(this) }.toMutableList() block.statements = block.statements.map { it.accept(this) }.toMutableList()
return block return block
} }
fun visit(decl: VarDecl): Statement { fun visit(decl: VarDecl): Statement {
decl.value = decl.value?.accept(this) decl.value = decl.value?.accept(this)
decl.arraysize = decl.arraysize?.accept(this) decl.arraysize?.accept(this)
return decl return decl
} }
fun visit(subroutine: Subroutine): Statement { fun visit(subroutine: Subroutine): Statement {
subroutine.statements = subroutine.statements.asSequence().map { it.accept(this) }.toMutableList() subroutine.statements = subroutine.statements.map { it.accept(this) }.toMutableList()
return subroutine return subroutine
} }
@ -191,7 +191,7 @@ interface IAstModifyingVisitor {
} }
fun visit(scope: AnonymousScope): Statement { fun visit(scope: AnonymousScope): Statement {
scope.statements = scope.statements.asSequence().map { it.accept(this) }.toMutableList() scope.statements = scope.statements.map { it.accept(this) }.toMutableList()
return scope return scope
} }
@ -241,12 +241,13 @@ interface IAstModifyingVisitor {
} }
fun visit(whenChoice: WhenChoice) { fun visit(whenChoice: WhenChoice) {
whenChoice.values?.forEach { it.accept(this) } whenChoice.values = whenChoice.values?.map { it.accept(this) }
val stmt = whenChoice.statements.accept(this) val stmt = whenChoice.statements.accept(this)
if(stmt is AnonymousScope) if(stmt is AnonymousScope)
whenChoice.statements = stmt whenChoice.statements = stmt
else { else {
whenChoice.statements = AnonymousScope(mutableListOf(stmt), stmt.position) whenChoice.statements = AnonymousScope(mutableListOf(stmt), stmt.position)
whenChoice.statements.linkParents(whenChoice)
} }
} }

View File

@ -333,9 +333,10 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
// make sure all choices are just for one single value // make sure all choices are just for one single value
val choices = whenStatement.choices.toList() val choices = whenStatement.choices.toList()
for(choice in choices) { for(choice in choices) {
if(choice.values==null || choice.values.size==1) val choiceValues = choice.values
if(choiceValues==null || choiceValues.size==1)
continue continue
for(v in choice.values) { for(v in choiceValues) {
val newchoice=WhenChoice(listOf(v), choice.statements, choice.position) val newchoice=WhenChoice(listOf(v), choice.statements, choice.position)
newchoice.parent = choice.parent newchoice.parent = choice.parent
whenStatement.choices.add(newchoice) whenStatement.choices.add(newchoice)

View File

@ -303,9 +303,8 @@ class ArrayIndex(var index: Expression, override val position: Position) : Node
} }
} }
fun accept(visitor: IAstModifyingVisitor): ArrayIndex { fun accept(visitor: IAstModifyingVisitor) {
index = index.accept(visitor) index = index.accept(visitor)
return this
} }
fun accept(visitor: IAstVisitor) { fun accept(visitor: IAstVisitor) {
@ -750,7 +749,7 @@ class RepeatLoop(var body: AnonymousScope,
} }
class WhenStatement(var condition: Expression, class WhenStatement(var condition: Expression,
val choices: MutableList<WhenChoice>, var choices: MutableList<WhenChoice>,
override val position: Position): Statement() { override val position: Position): Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline: Boolean = true override val expensiveToInline: Boolean = true
@ -768,7 +767,7 @@ class WhenStatement(var condition: Expression,
if(choice.values==null) if(choice.values==null)
result.add(null to choice) result.add(null to choice)
else { else {
val values = choice.values.map { it.constValue(program)?.number?.toInt() } val values = choice.values!!.map { it.constValue(program)?.number?.toInt() }
if(values.contains(null)) if(values.contains(null))
result.add(null to choice) result.add(null to choice)
else else
@ -782,7 +781,7 @@ class WhenStatement(var condition: Expression,
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
} }
class WhenChoice(val values: List<Expression>?, // if null, this is the 'else' part class WhenChoice(var values: List<Expression>?, // if null, this is the 'else' part
var statements: AnonymousScope, var statements: AnonymousScope,
override val position: Position) : Node { override val position: Position) : Node {
override lateinit var parent: Node override lateinit var parent: Node

View File

@ -224,14 +224,16 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
*/ */
override fun visit(expr: PrefixExpression): Expression { override fun visit(expr: PrefixExpression): Expression {
return try { return try {
super.visit(expr) val prefixExpr=super.visit(expr)
if(prefixExpr !is PrefixExpression)
return prefixExpr
val subexpr = expr.expression val subexpr = prefixExpr.expression
if (subexpr is NumericLiteralValue) { if (subexpr is NumericLiteralValue) {
// accept prefixed literal values (such as -3, not true) // accept prefixed literal values (such as -3, not true)
return when { return when {
expr.operator == "+" -> subexpr prefixExpr.operator == "+" -> subexpr
expr.operator == "-" -> when { prefixExpr.operator == "-" -> when {
subexpr.type in IntegerDatatypes -> { subexpr.type in IntegerDatatypes -> {
optimizationsDone++ optimizationsDone++
NumericLiteralValue.optimalNumeric(-subexpr.number.toInt(), subexpr.position) NumericLiteralValue.optimalNumeric(-subexpr.number.toInt(), subexpr.position)
@ -242,21 +244,21 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
} }
else -> throw ExpressionError("can only take negative of int or float", subexpr.position) else -> throw ExpressionError("can only take negative of int or float", subexpr.position)
} }
expr.operator == "~" -> when { prefixExpr.operator == "~" -> when {
subexpr.type in IntegerDatatypes -> { subexpr.type in IntegerDatatypes -> {
optimizationsDone++ optimizationsDone++
NumericLiteralValue.optimalNumeric(subexpr.number.toInt().inv(), subexpr.position) NumericLiteralValue.optimalNumeric(subexpr.number.toInt().inv(), subexpr.position)
} }
else -> throw ExpressionError("can only take bitwise inversion of int", subexpr.position) else -> throw ExpressionError("can only take bitwise inversion of int", subexpr.position)
} }
expr.operator == "not" -> { prefixExpr.operator == "not" -> {
optimizationsDone++ optimizationsDone++
NumericLiteralValue.fromBoolean(subexpr.number.toDouble() == 0.0, subexpr.position) NumericLiteralValue.fromBoolean(subexpr.number.toDouble() == 0.0, subexpr.position)
} }
else -> throw ExpressionError(expr.operator, subexpr.position) else -> throw ExpressionError(prefixExpr.operator, subexpr.position)
} }
} }
return expr return prefixExpr
} catch (ax: AstException) { } catch (ax: AstException) {
addError(ax) addError(ax)
expr expr

View File

@ -518,7 +518,7 @@ class AstVm(val program: Program) {
executeAnonymousScope(choice.statements) executeAnonymousScope(choice.statements)
break break
} else { } else {
val value = choice.values.single().constValue(evalCtx.program) ?: throw VmExecutionException("can only use const values in when choices ${choice.position}") val value = choice.values!!.single().constValue(evalCtx.program) ?: throw VmExecutionException("can only use const values in when choices ${choice.position}")
val rtval = RuntimeValue.fromLv(value) val rtval = RuntimeValue.fromLv(value)
if(condition==rtval) { if(condition==rtval) {
executeAnonymousScope(choice.statements) executeAnonymousScope(choice.statements)

View File

@ -6,10 +6,73 @@
sub start() { sub start() {
uword target = 4444 uword target = 4444
@($d020) = A ; @($d020) = A
@($d020) = A+4 ; @($d020) = A+4
@(target) = A+4 ; @(target) = A+4
@(target+4) = A+4 ; @(target+4) = A+4
whenubyte(20)
whenubyte(111)
whenbyte(-10)
whenbyte(-111)
whenbyte(0)
whenuword(500)
whenuword(44)
whenword(-3000)
whenword(-44)
whenword(0)
sub whenbyte(byte value) {
when value {
-4 -> c64scr.print("minusfour")
-5 -> c64scr.print("minusfive")
-10,-20,-30 -> {
c64scr.print("minusten or twenty or thirty")
}
-99 -> c64scr.print("minusninetynine")
else -> c64scr.print("don't know")
}
c64.CHROUT('\n')
}
sub whenubyte(ubyte value) {
when value {
4 -> c64scr.print("four")
5 -> c64scr.print("five")
10,20,30 -> {
c64scr.print("ten or twenty or thirty")
}
99 -> c64scr.print("ninetynine")
else -> c64scr.print("don't know")
}
c64.CHROUT('\n')
}
sub whenuword(uword value) {
when value {
400 -> c64scr.print("four100")
500 -> c64scr.print("five100")
1000,2000,3000 -> {
c64scr.print("thousand 2thousand or 3thousand")
}
9999 -> c64scr.print("ninetynine99")
else -> c64scr.print("don't know")
}
c64.CHROUT('\n')
}
sub whenword(word value) {
when value {
-400 -> c64scr.print("minusfour100")
-500 -> c64scr.print("minusfive100")
-1000,-2000,-3000 -> {
c64scr.print("minusthousand 2thousand or 3thousand")
}
-9999 -> c64scr.print("minusninetynine99")
else -> c64scr.print("don't know")
}
c64.CHROUT('\n')
}
} }
} }