using sealed class instead of interface

This commit is contained in:
Irmen de Jong
2019-07-17 02:27:13 +02:00
parent e03c68b632
commit a8898a5993
22 changed files with 498 additions and 505 deletions

View File

@@ -1,18 +1,167 @@
package prog8.ast package prog8.ast
import prog8.ast.base.FatalAstException import prog8.ast.base.*
import prog8.ast.base.NameError import prog8.ast.expressions.Expression
import prog8.ast.base.ParentSentinel import prog8.ast.expressions.IdentifierReference
import prog8.ast.base.Position import prog8.ast.statements.*
import prog8.ast.statements.Block
import prog8.ast.statements.Label
import prog8.ast.statements.Subroutine
import prog8.ast.statements.VarDecl
import prog8.compiler.HeapValues import prog8.compiler.HeapValues
import prog8.functions.BuiltinFunctions import prog8.functions.BuiltinFunctions
import java.nio.file.Path import java.nio.file.Path
interface Node {
val position: Position
var parent: Node // will be linked correctly later (late init)
fun linkParents(parent: Node)
fun definingModule(): Module {
if(this is Module)
return this
return findParentNode<Module>(this)!!
}
fun definingSubroutine(): Subroutine? = findParentNode<Subroutine>(this)
fun definingScope(): INameScope {
val scope = findParentNode<INameScope>(this)
if(scope!=null) {
return scope
}
if(this is Label && this.name.startsWith("builtin::")) {
return BuiltinFunctionScopePlaceholder
}
if(this is GlobalNamespace)
return this
throw FatalAstException("scope missing from $this")
}
}
interface IFunctionCall {
var target: IdentifierReference
var arglist: MutableList<Expression>
}
interface INameScope {
val name: String
val position: Position
val statements: MutableList<Statement>
val parent: Node
fun linkParents(parent: Node)
fun subScopes(): Map<String, INameScope> {
val subscopes = mutableMapOf<String, INameScope>()
for(stmt in statements) {
when(stmt) {
// NOTE: if other nodes are introduced that are a scope, or contain subscopes, they must be added here!
is ForLoop -> subscopes[stmt.body.name] = stmt.body
is RepeatLoop -> subscopes[stmt.body.name] = stmt.body
is WhileLoop -> subscopes[stmt.body.name] = stmt.body
is BranchStatement -> {
subscopes[stmt.truepart.name] = stmt.truepart
if(stmt.elsepart.containsCodeOrVars())
subscopes[stmt.elsepart.name] = stmt.elsepart
}
is IfStatement -> {
subscopes[stmt.truepart.name] = stmt.truepart
if(stmt.elsepart.containsCodeOrVars())
subscopes[stmt.elsepart.name] = stmt.elsepart
}
is WhenStatement -> {
stmt.choices.forEach { subscopes[it.statements.name] = it.statements }
}
is INameScope -> subscopes[stmt.name] = stmt
else -> {}
}
}
return subscopes
}
fun getLabelOrVariable(name: String): Statement? {
// this is called A LOT and could perhaps be optimized a bit more,
// but adding a memoization cache didn't make much of a practical runtime difference
for (stmt in statements) {
if (stmt is VarDecl && stmt.name==name) return stmt
if (stmt is Label && stmt.name==name) return stmt
if (stmt is AnonymousScope) {
val sub = stmt.getLabelOrVariable(name)
if(sub!=null)
return sub
}
}
return null
}
fun allDefinedSymbols(): List<Pair<String, Statement>> {
return statements.mapNotNull {
when (it) {
is Label -> it.name to it
is VarDecl -> it.name to it
is Subroutine -> it.name to it
is Block -> it.name to it
else -> null
}
}
}
fun lookup(scopedName: List<String>, localContext: Node) : Statement? {
if(scopedName.size>1) {
// a scoped name can a) refer to a member of a struct, or b) refer to a name in another module.
// try the struct first.
val thing = lookup(scopedName.dropLast(1), localContext) as? VarDecl
val struct = thing?.struct
if (struct != null) {
if(struct.statements.any { (it as VarDecl).name == scopedName.last()}) {
// return ref to the mangled name variable
val mangled = mangledStructMemberName(thing.name, scopedName.last())
return thing.definingScope().getLabelOrVariable(mangled)
}
}
// it's a qualified name, look it up from the root of the module's namespace (consider all modules in the program)
for(module in localContext.definingModule().program.modules) {
var scope: INameScope? = module
for(name in scopedName.dropLast(1)) {
scope = scope?.subScopes()?.get(name)
if(scope==null)
break
}
if(scope!=null) {
val result = scope.getLabelOrVariable(scopedName.last())
if(result!=null)
return result
return scope.subScopes()[scopedName.last()] as Statement?
}
}
return null
} else {
// unqualified name, find the scope the localContext is in, look in that first
var statementScope = localContext
while(statementScope !is ParentSentinel) {
val localScope = statementScope.definingScope()
val result = localScope.getLabelOrVariable(scopedName[0])
if (result != null)
return result
val subscope = localScope.subScopes()[scopedName[0]] as Statement?
if (subscope != null)
return subscope
// not found in this scope, look one higher up
statementScope = statementScope.parent
}
return null
}
}
fun containsCodeOrVars() = statements.any { it !is Directive || it.directive == "%asminclude" || it.directive == "%asm"}
fun containsNoCodeNorVars() = !containsCodeOrVars()
fun remove(stmt: Statement) {
if(!statements.remove(stmt))
throw FatalAstException("stmt to remove wasn't found in scope")
}
}
/*********** Everything starts from here, the Program; zero or more modules *************/ /*********** Everything starts from here, the Program; zero or more modules *************/
class Program(val name: String, val modules: MutableList<Module>) { class Program(val name: String, val modules: MutableList<Module>) {
@@ -35,7 +184,7 @@ class Program(val name: String, val modules: MutableList<Module>) {
} }
class Module(override val name: String, class Module(override val name: String,
override var statements: MutableList<IStatement>, override var statements: MutableList<Statement>,
override val position: Position, override val position: Position,
val isLibraryModule: Boolean, val isLibraryModule: Boolean,
val source: Path) : Node, INameScope { val source: Path) : Node, INameScope {
@@ -59,14 +208,14 @@ class Module(override val name: String,
class GlobalNamespace(val modules: List<Module>): Node, INameScope { class GlobalNamespace(val modules: List<Module>): Node, INameScope {
override val name = "<<<global>>>" override val name = "<<<global>>>"
override val position = Position("<<<global>>>", 0, 0, 0) override val position = Position("<<<global>>>", 0, 0, 0)
override val statements = mutableListOf<IStatement>() override val statements = mutableListOf<Statement>()
override var parent: Node = ParentSentinel override var parent: Node = ParentSentinel
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
modules.forEach { it.linkParents(this) } modules.forEach { it.linkParents(this) }
} }
override fun lookup(scopedName: List<String>, localContext: Node): IStatement? { override fun lookup(scopedName: List<String>, localContext: Node): Statement? {
if (scopedName.size == 1 && scopedName[0] in BuiltinFunctions) { if (scopedName.size == 1 && scopedName[0] in BuiltinFunctions) {
// builtin functions always exist, return a dummy localContext for them // builtin functions always exist, return a dummy localContext for them
val builtinPlaceholder = Label("builtin::${scopedName.last()}", localContext.position) val builtinPlaceholder = Label("builtin::${scopedName.last()}", localContext.position)
@@ -100,7 +249,7 @@ class GlobalNamespace(val modules: List<Module>): Node, INameScope {
object BuiltinFunctionScopePlaceholder : INameScope { object BuiltinFunctionScopePlaceholder : INameScope {
override val name = "<<builtin-functions-scope-placeholder>>" override val name = "<<builtin-functions-scope-placeholder>>"
override val position = Position("<<placeholder>>", 0, 0, 0) override val position = Position("<<placeholder>>", 0, 0, 0)
override var statements = mutableListOf<IStatement>() override var statements = mutableListOf<Statement>()
override var parent: Node = ParentSentinel override var parent: Node = ParentSentinel
override fun linkParents(parent: Node) {} override fun linkParents(parent: Node) {}
} }

View File

@@ -1,224 +0,0 @@
package prog8.ast
import prog8.ast.base.*
import prog8.ast.expressions.*
import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.processing.IAstVisitor
import prog8.ast.statements.*
// TODO sealed classes instead??
interface Node {
val position: Position
var parent: Node // will be linked correctly later (late init)
fun linkParents(parent: Node)
fun definingModule(): Module {
if(this is Module)
return this
return findParentNode<Module>(this)!!
}
fun definingSubroutine(): Subroutine? = findParentNode<Subroutine>(this)
fun definingScope(): INameScope {
val scope = findParentNode<INameScope>(this)
if(scope!=null) {
return scope
}
if(this is Label && this.name.startsWith("builtin::")) {
return BuiltinFunctionScopePlaceholder
}
if(this is GlobalNamespace)
return this
throw FatalAstException("scope missing from $this")
}
}
interface IStatement : Node {
fun accept(visitor: IAstModifyingVisitor) : IStatement
fun accept(visitor: IAstVisitor)
fun makeScopedName(name: String): String {
// easy way out is to always return the full scoped name.
// it would be nicer to find only the minimal prefixed scoped name, but that's too much hassle for now.
// and like this, we can cache the name even,
// like in a lazy property on the statement object itself (label, subroutine, vardecl)
val scope = mutableListOf<String>()
var statementScope = this.parent
while(statementScope !is ParentSentinel && statementScope !is Module) {
if(statementScope is INameScope) {
scope.add(0, statementScope.name)
}
statementScope = statementScope.parent
}
if(name.isNotEmpty())
scope.add(name)
return scope.joinToString(".")
}
val expensiveToInline: Boolean
fun definingBlock(): Block {
if(this is Block)
return this
return findParentNode<Block>(this)!!
}
}
interface IFunctionCall {
var target: IdentifierReference
var arglist: MutableList<IExpression>
}
interface INameScope {
val name: String
val position: Position
val statements: MutableList<IStatement>
val parent: Node
fun linkParents(parent: Node)
fun subScopes(): Map<String, INameScope> {
val subscopes = mutableMapOf<String, INameScope>()
for(stmt in statements) {
when(stmt) {
// NOTE: if other nodes are introduced that are a scope, or contain subscopes, they must be added here!
is ForLoop -> subscopes[stmt.body.name] = stmt.body
is RepeatLoop -> subscopes[stmt.body.name] = stmt.body
is WhileLoop -> subscopes[stmt.body.name] = stmt.body
is BranchStatement -> {
subscopes[stmt.truepart.name] = stmt.truepart
if(stmt.elsepart.containsCodeOrVars())
subscopes[stmt.elsepart.name] = stmt.elsepart
}
is IfStatement -> {
subscopes[stmt.truepart.name] = stmt.truepart
if(stmt.elsepart.containsCodeOrVars())
subscopes[stmt.elsepart.name] = stmt.elsepart
}
is WhenStatement -> {
stmt.choices.forEach { subscopes[it.statements.name] = it.statements }
}
is INameScope -> subscopes[stmt.name] = stmt
}
}
return subscopes
}
fun getLabelOrVariable(name: String): IStatement? {
// this is called A LOT and could perhaps be optimized a bit more,
// but adding a memoization cache didn't make much of a practical runtime difference
for (stmt in statements) {
if (stmt is VarDecl && stmt.name==name) return stmt
if (stmt is Label && stmt.name==name) return stmt
if (stmt is AnonymousScope) {
val sub = stmt.getLabelOrVariable(name)
if(sub!=null)
return sub
}
}
return null
}
fun allDefinedSymbols(): List<Pair<String, IStatement>> {
return statements.mapNotNull {
when (it) {
is Label -> it.name to it
is VarDecl -> it.name to it
is Subroutine -> it.name to it
is Block -> it.name to it
else -> null
}
}
}
fun lookup(scopedName: List<String>, localContext: Node) : IStatement? {
if(scopedName.size>1) {
// a scoped name can a) refer to a member of a struct, or b) refer to a name in another module.
// try the struct first.
val thing = lookup(scopedName.dropLast(1), localContext) as? VarDecl
val struct = thing?.struct
if (struct != null) {
if(struct.statements.any { (it as VarDecl).name == scopedName.last()}) {
// return ref to the mangled name variable
val mangled = mangledStructMemberName(thing.name, scopedName.last())
return thing.definingScope().getLabelOrVariable(mangled)
}
}
// it's a qualified name, look it up from the root of the module's namespace (consider all modules in the program)
for(module in localContext.definingModule().program.modules) {
var scope: INameScope? = module
for(name in scopedName.dropLast(1)) {
scope = scope?.subScopes()?.get(name)
if(scope==null)
break
}
if(scope!=null) {
val result = scope.getLabelOrVariable(scopedName.last())
if(result!=null)
return result
return scope.subScopes()[scopedName.last()] as IStatement?
}
}
return null
} else {
// unqualified name, find the scope the localContext is in, look in that first
var statementScope = localContext
while(statementScope !is ParentSentinel) {
val localScope = statementScope.definingScope()
val result = localScope.getLabelOrVariable(scopedName[0])
if (result != null)
return result
val subscope = localScope.subScopes()[scopedName[0]] as IStatement?
if (subscope != null)
return subscope
// not found in this scope, look one higher up
statementScope = statementScope.parent
}
return null
}
}
fun containsCodeOrVars() = statements.any { it !is Directive || it.directive == "%asminclude" || it.directive == "%asm"}
fun containsNoCodeNorVars() = !containsCodeOrVars()
fun remove(stmt: IStatement) {
if(!statements.remove(stmt))
throw FatalAstException("stmt to remove wasn't found in scope")
}
}
interface IExpression: Node {
fun constValue(program: Program): NumericLiteralValue?
fun accept(visitor: IAstModifyingVisitor): IExpression
fun accept(visitor: IAstVisitor)
fun referencesIdentifiers(vararg name: String): Boolean // todo: remove and use calltree instead
fun inferType(program: Program): DataType?
infix fun isSameAs(other: IExpression): Boolean {
if(this===other)
return true
when(this) {
is RegisterExpr ->
return (other is RegisterExpr && other.register==register)
is IdentifierReference ->
return (other is IdentifierReference && other.nameInSource==nameInSource)
is PrefixExpression ->
return (other is PrefixExpression && other.operator==operator && other.expression isSameAs expression)
is BinaryExpression ->
return (other is BinaryExpression && other.operator==operator
&& other.left isSameAs left
&& other.right isSameAs right)
is ArrayIndexedExpression -> {
return (other is ArrayIndexedExpression && other.identifier.nameInSource == identifier.nameInSource
&& other.arrayspec.index isSameAs arrayspec.index)
}
is NumericLiteralValue -> return other==this
is ReferenceLiteralValue -> return other==this
}
return false
}
}

View File

@@ -3,8 +3,6 @@ package prog8.ast.antlr
import org.antlr.v4.runtime.IntStream import org.antlr.v4.runtime.IntStream
import org.antlr.v4.runtime.ParserRuleContext import org.antlr.v4.runtime.ParserRuleContext
import org.antlr.v4.runtime.tree.TerminalNode import org.antlr.v4.runtime.tree.TerminalNode
import prog8.ast.IExpression
import prog8.ast.IStatement
import prog8.ast.Module import prog8.ast.Module
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
@@ -41,7 +39,7 @@ private fun ParserRuleContext.toPosition() : Position {
} }
private fun prog8Parser.ModulestatementContext.toAst(isInLibrary: Boolean) : IStatement { private fun prog8Parser.ModulestatementContext.toAst(isInLibrary: Boolean) : Statement {
val directive = directive()?.toAst() val directive = directive()?.toAst()
if(directive!=null) return directive if(directive!=null) return directive
@@ -52,15 +50,15 @@ private fun prog8Parser.ModulestatementContext.toAst(isInLibrary: Boolean) : ISt
} }
private fun prog8Parser.BlockContext.toAst(isInLibrary: Boolean) : IStatement = private fun prog8Parser.BlockContext.toAst(isInLibrary: Boolean) : Statement =
Block(identifier().text, integerliteral()?.toAst()?.number?.toInt(), statement_block().toAst(), isInLibrary, toPosition()) Block(identifier().text, integerliteral()?.toAst()?.number?.toInt(), statement_block().toAst(), isInLibrary, toPosition())
private fun prog8Parser.Statement_blockContext.toAst(): MutableList<IStatement> = private fun prog8Parser.Statement_blockContext.toAst(): MutableList<Statement> =
statement().asSequence().map { it.toAst() }.toMutableList() statement().asSequence().map { it.toAst() }.toMutableList()
private fun prog8Parser.StatementContext.toAst() : IStatement { private fun prog8Parser.StatementContext.toAst() : Statement {
vardecl()?.let { return it.toAst() } vardecl()?.let { return it.toAst() }
varinitializer()?.let { varinitializer()?.let {
@@ -216,7 +214,7 @@ private fun prog8Parser.StatementContext.toAst() : IStatement {
throw FatalAstException("unprocessed source text (are we missing ast conversion rules for parser elements?): $text") throw FatalAstException("unprocessed source text (are we missing ast conversion rules for parser elements?): $text")
} }
private fun prog8Parser.AsmsubroutineContext.toAst(): IStatement { private fun prog8Parser.AsmsubroutineContext.toAst(): Statement {
val name = identifier().text val name = identifier().text
val address = asmsub_address()?.address?.toAst()?.number?.toInt() val address = asmsub_address()?.address?.toAst()?.number?.toInt()
val params = asmsub_params()?.toAst() ?: emptyList() val params = asmsub_params()?.toAst() ?: emptyList()
@@ -265,7 +263,7 @@ private fun prog8Parser.Asmsub_paramsContext.toAst(): List<AsmSubroutineParamete
private fun prog8Parser.StatusregisterContext.toAst() = Statusflag.valueOf(text) private fun prog8Parser.StatusregisterContext.toAst() = Statusflag.valueOf(text)
private fun prog8Parser.Functioncall_stmtContext.toAst(): IStatement { private fun prog8Parser.Functioncall_stmtContext.toAst(): Statement {
val location = scoped_identifier().toAst() val location = scoped_identifier().toAst()
return if(expression_list() == null) return if(expression_list() == null)
FunctionCallStatement(location, mutableListOf(), toPosition()) FunctionCallStatement(location, mutableListOf(), toPosition())
@@ -298,7 +296,7 @@ private fun prog8Parser.UnconditionaljumpContext.toAst(): Jump {
} }
private fun prog8Parser.LabeldefContext.toAst(): IStatement = private fun prog8Parser.LabeldefContext.toAst(): Statement =
Label(children[0].text, toPosition()) Label(children[0].text, toPosition())
@@ -411,7 +409,7 @@ private fun prog8Parser.IntegerliteralContext.toAst(): NumericLiteral {
} }
private fun prog8Parser.ExpressionContext.toAst() : IExpression { private fun prog8Parser.ExpressionContext.toAst() : Expression {
val litval = literalvalue() val litval = literalvalue()
if(litval!=null) { if(litval!=null) {
@@ -521,7 +519,7 @@ private fun prog8Parser.BooleanliteralContext.toAst() = when(text) {
} }
private fun prog8Parser.ArrayliteralContext.toAst() : Array<IExpression> = private fun prog8Parser.ArrayliteralContext.toAst() : Array<Expression> =
expression().map { it.toAst() }.toTypedArray() expression().map { it.toAst() }.toTypedArray()
private fun prog8Parser.If_stmtContext.toAst(): IfStatement { private fun prog8Parser.If_stmtContext.toAst(): IfStatement {
@@ -534,7 +532,7 @@ private fun prog8Parser.If_stmtContext.toAst(): IfStatement {
return IfStatement(condition, trueScope, elseScope, toPosition()) return IfStatement(condition, trueScope, elseScope, toPosition())
} }
private fun prog8Parser.Else_partContext.toAst(): MutableList<IStatement> { private fun prog8Parser.Else_partContext.toAst(): MutableList<Statement> {
return statement_block()?.toAst() ?: mutableListOf(statement().toAst()) return statement_block()?.toAst() ?: mutableListOf(statement().toAst())
} }

View File

@@ -1,11 +1,17 @@
package prog8.ast.expressions package prog8.ast.expressions
import prog8.ast.* import prog8.ast.IFunctionCall
import prog8.ast.INameScope
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.antlr.escape import prog8.ast.antlr.escape
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.processing.IAstVisitor import prog8.ast.processing.IAstVisitor
import prog8.ast.statements.* import prog8.ast.statements.ArrayIndex
import prog8.ast.statements.BuiltinFunctionStatementPlaceholder
import prog8.ast.statements.Subroutine
import prog8.ast.statements.VarDecl
import prog8.compiler.HeapValues import prog8.compiler.HeapValues
import prog8.compiler.IntegerOrAddressOf import prog8.compiler.IntegerOrAddressOf
import prog8.compiler.target.c64.Petscii import prog8.compiler.target.c64.Petscii
@@ -18,7 +24,38 @@ import kotlin.math.abs
val associativeOperators = setOf("+", "*", "&", "|", "^", "or", "and", "xor", "==", "!=") val associativeOperators = setOf("+", "*", "&", "|", "^", "or", "and", "xor", "==", "!=")
class PrefixExpression(val operator: String, var expression: IExpression, override val position: Position) : IExpression { sealed class Expression: Node {
abstract fun constValue(program: Program): NumericLiteralValue?
abstract fun accept(visitor: IAstModifyingVisitor): Expression
abstract fun accept(visitor: IAstVisitor)
abstract fun referencesIdentifiers(vararg name: String): Boolean // todo: remove and use calltree instead
abstract fun inferType(program: Program): DataType?
infix fun isSameAs(other: Expression): Boolean {
if(this===other)
return true
when(this) {
is RegisterExpr ->
return (other is RegisterExpr && other.register==register)
is IdentifierReference ->
return (other is IdentifierReference && other.nameInSource==nameInSource)
is PrefixExpression ->
return (other is PrefixExpression && other.operator==operator && other.expression isSameAs expression)
is BinaryExpression ->
return (other is BinaryExpression && other.operator==operator
&& other.left isSameAs left
&& other.right isSameAs right)
is ArrayIndexedExpression -> {
return (other is ArrayIndexedExpression && other.identifier.nameInSource == identifier.nameInSource
&& other.arrayspec.index isSameAs arrayspec.index)
}
else -> return other==this
}
}
}
class PrefixExpression(val operator: String, var expression: Expression, override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -37,7 +74,7 @@ class PrefixExpression(val operator: String, var expression: IExpression, overri
} }
} }
class BinaryExpression(var left: IExpression, var operator: String, var right: IExpression, override val position: Position) : IExpression { class BinaryExpression(var left: Expression, var operator: String, var right: Expression, override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -148,7 +185,7 @@ class BinaryExpression(var left: IExpression, var operator: String, var right: I
} }
fun commonDatatype(leftDt: DataType, rightDt: DataType, fun commonDatatype(leftDt: DataType, rightDt: DataType,
left: IExpression, right: IExpression): Pair<DataType, IExpression?> { left: Expression, right: Expression): Pair<DataType, Expression?> {
// byte + byte -> byte // byte + byte -> byte
// byte + word -> word // byte + word -> word
// word + byte -> word // word + byte -> word
@@ -212,7 +249,7 @@ class BinaryExpression(var left: IExpression, var operator: String, var right: I
class ArrayIndexedExpression(val identifier: IdentifierReference, class ArrayIndexedExpression(val identifier: IdentifierReference,
var arrayspec: ArrayIndex, var arrayspec: ArrayIndex,
override val position: Position) : IExpression { override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
this.parent = parent this.parent = parent
@@ -243,7 +280,7 @@ class ArrayIndexedExpression(val identifier: IdentifierReference,
} }
} }
class TypecastExpression(var expression: IExpression, var type: DataType, val implicit: Boolean, override val position: Position) : IExpression { class TypecastExpression(var expression: Expression, var type: DataType, val implicit: Boolean, override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -267,7 +304,7 @@ class TypecastExpression(var expression: IExpression, var type: DataType, val im
} }
} }
data class AddressOf(val identifier: IdentifierReference, override val position: Position) : IExpression { data class AddressOf(val identifier: IdentifierReference, override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -283,7 +320,7 @@ data class AddressOf(val identifier: IdentifierReference, override val position:
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
} }
class DirectMemoryRead(var addressExpression: IExpression, override val position: Position) : IExpression { class DirectMemoryRead(var addressExpression: Expression, override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -304,7 +341,7 @@ class DirectMemoryRead(var addressExpression: IExpression, override val position
class NumericLiteralValue(val type: DataType, // only numerical types allowed class NumericLiteralValue(val type: DataType, // only numerical types allowed
val number: Number, // can be byte, word or float depending on the type val number: Number, // can be byte, word or float depending on the type
override val position: Position) : IExpression { override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
companion object { companion object {
@@ -420,8 +457,8 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed
} }
} }
class StructLiteralValue(var values: List<IExpression>, class StructLiteralValue(var values: List<Expression>,
override val position: Position): IExpression { override val position: Position): Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -442,10 +479,10 @@ class StructLiteralValue(var values: List<IExpression>,
class ReferenceLiteralValue(val type: DataType, // only reference types allowed here class ReferenceLiteralValue(val type: DataType, // only reference types allowed here
val str: String? = null, val str: String? = null,
val array: Array<IExpression>? = null, val array: Array<Expression>? = null,
// actually, at the moment, we don't have struct literals in the language // actually, at the moment, we don't have struct literals in the language
initHeapId: Int? =null, initHeapId: Int? =null,
override val position: Position) : IExpression { override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun referencesIdentifiers(vararg name: String) = array?.any { it.referencesIdentifiers(*name) } ?: false override fun referencesIdentifiers(vararg name: String) = array?.any { it.referencesIdentifiers(*name) } ?: false
@@ -558,10 +595,10 @@ class ReferenceLiteralValue(val type: DataType, // only reference types allo
} }
} }
class RangeExpr(var from: IExpression, class RangeExpr(var from: Expression,
var to: IExpression, var to: Expression,
var step: IExpression, var step: Expression,
override val position: Position) : IExpression { override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -635,7 +672,7 @@ class RangeExpr(var from: IExpression,
} }
} }
class RegisterExpr(val register: Register, override val position: Position) : IExpression { class RegisterExpr(val register: Register, override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -653,7 +690,7 @@ class RegisterExpr(val register: Register, override val position: Position) : IE
override fun inferType(program: Program) = DataType.UBYTE override fun inferType(program: Program) = DataType.UBYTE
} }
data class IdentifierReference(val nameInSource: List<String>, override val position: Position) : IExpression { data class IdentifierReference(val nameInSource: List<String>, override val position: Position) : Expression() {
override lateinit var parent: Node override lateinit var parent: Node
fun targetStatement(namespace: INameScope) = fun targetStatement(namespace: INameScope) =
@@ -710,8 +747,8 @@ data class IdentifierReference(val nameInSource: List<String>, override val posi
} }
class FunctionCall(override var target: IdentifierReference, class FunctionCall(override var target: IdentifierReference,
override var arglist: MutableList<IExpression>, override var arglist: MutableList<Expression>,
override val position: Position) : IExpression, IFunctionCall { override val position: Position) : Expression(), IFunctionCall {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -780,8 +817,7 @@ class FunctionCall(override var target: IdentifierReference,
return stmt.returntypes[0] return stmt.returntypes[0]
return null // has multiple return types... so not a single resulting datatype possible return null // has multiple return types... so not a single resulting datatype possible
} }
is Label -> return null else -> return null
} }
return null // calling something we don't recognise...
} }
} }

View File

@@ -1,6 +1,8 @@
package prog8.ast.processing package prog8.ast.processing
import prog8.ast.* import prog8.ast.INameScope
import prog8.ast.Module
import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.* import prog8.ast.statements.*
@@ -357,7 +359,7 @@ internal class AstChecker(private val program: Program,
checkResult.add(ExpressionError("address out of range", assignTarget.position)) checkResult.add(ExpressionError("address out of range", assignTarget.position))
} }
val assignment = assignTarget.parent as IStatement val assignment = assignTarget.parent as Statement
if (assignTarget.identifier != null) { if (assignTarget.identifier != null) {
val targetName = assignTarget.identifier.nameInSource val targetName = assignTarget.identifier.nameInSource
val targetSymbol = program.namespace.lookup(targetName, assignment) val targetSymbol = program.namespace.lookup(targetName, assignment)
@@ -798,7 +800,7 @@ internal class AstChecker(private val program: Program,
override fun visit(functionCall: FunctionCall) { override fun visit(functionCall: FunctionCall) {
// this function call is (part of) an expression, which should be in a statement somewhere. // this function call is (part of) an expression, which should be in a statement somewhere.
val stmtOfExpression = findParentNode<IStatement>(functionCall) val stmtOfExpression = findParentNode<Statement>(functionCall)
?: throw FatalAstException("cannot determine statement scope of function call expression at ${functionCall.position}") ?: throw FatalAstException("cannot determine statement scope of function call expression at ${functionCall.position}")
val targetStatement = checkFunctionOrLabelExists(functionCall.target, stmtOfExpression) val targetStatement = checkFunctionOrLabelExists(functionCall.target, stmtOfExpression)
@@ -820,7 +822,7 @@ internal class AstChecker(private val program: Program,
super.visit(functionCallStatement) super.visit(functionCallStatement)
} }
private fun checkFunctionCall(target: IStatement, args: List<IExpression>, position: Position) { private fun checkFunctionCall(target: Statement, args: List<Expression>, position: Position) {
if(target is Label && args.isNotEmpty()) if(target is Label && args.isNotEmpty())
checkResult.add(SyntaxError("cannot use arguments when calling a label", position)) checkResult.add(SyntaxError("cannot use arguments when calling a label", position))
@@ -995,7 +997,7 @@ internal class AstChecker(private val program: Program,
} }
} }
private fun checkFunctionOrLabelExists(target: IdentifierReference, statement: IStatement): IStatement? { private fun checkFunctionOrLabelExists(target: IdentifierReference, statement: Statement): Statement? {
val targetStatement = target.targetStatement(program.namespace) val targetStatement = target.targetStatement(program.namespace)
if(targetStatement is Label || targetStatement is Subroutine || targetStatement is BuiltinFunctionStatementPlaceholder) if(targetStatement is Label || targetStatement is Subroutine || targetStatement is BuiltinFunctionStatementPlaceholder)
return targetStatement return targetStatement
@@ -1253,7 +1255,7 @@ internal class AstChecker(private val program: Program,
private fun checkAssignmentCompatible(targetDatatype: DataType, private fun checkAssignmentCompatible(targetDatatype: DataType,
target: AssignTarget, target: AssignTarget,
sourceDatatype: DataType, sourceDatatype: DataType,
sourceValue: IExpression, sourceValue: Expression,
position: Position) : Boolean { position: Position) : Boolean {
if(sourceValue is RangeExpr) if(sourceValue is RangeExpr)

View File

@@ -1,6 +1,9 @@
package prog8.ast.processing package prog8.ast.processing
import prog8.ast.* import prog8.ast.INameScope
import prog8.ast.Module
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.* import prog8.ast.statements.*
@@ -18,7 +21,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return checkResult return checkResult
} }
private fun nameError(name: String, position: Position, existing: IStatement) { private fun nameError(name: String, position: Position, existing: Statement) {
checkResult.add(NameError("name conflict '$name', also defined in ${existing.position.file} line ${existing.position.line}", position)) checkResult.add(NameError("name conflict '$name', also defined in ${existing.position.file} line ${existing.position.line}", position))
} }
@@ -33,7 +36,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
} }
} }
override fun visit(block: Block): IStatement { override fun visit(block: Block): Statement {
val existing = blocks[block.name] val existing = blocks[block.name]
if(existing!=null) if(existing!=null)
nameError(block.name, block.position, existing) nameError(block.name, block.position, existing)
@@ -43,7 +46,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return super.visit(block) return super.visit(block)
} }
override fun visit(functionCall: FunctionCall): IExpression { override fun visit(functionCall: FunctionCall): Expression {
if(functionCall.target.nameInSource.size==1 && functionCall.target.nameInSource[0]=="lsb") { if(functionCall.target.nameInSource.size==1 && functionCall.target.nameInSource[0]=="lsb") {
// lsb(...) is just an alias for type cast to ubyte, so replace with "... as ubyte" // lsb(...) is just an alias for type cast to ubyte, so replace with "... as ubyte"
val typecast = TypecastExpression(functionCall.arglist.single(), DataType.UBYTE, false, functionCall.position) val typecast = TypecastExpression(functionCall.arglist.single(), DataType.UBYTE, false, functionCall.position)
@@ -53,7 +56,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return super.visit(functionCall) return super.visit(functionCall)
} }
override fun visit(decl: VarDecl): IStatement { override fun visit(decl: VarDecl): Statement {
// first, check if there are datatype errors on the vardecl // first, check if there are datatype errors on the vardecl
decl.datatypeErrors.forEach { checkResult.add(it) } decl.datatypeErrors.forEach { checkResult.add(it) }
@@ -95,7 +98,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return super.visit(decl) return super.visit(decl)
} }
override fun visit(subroutine: Subroutine): IStatement { override fun visit(subroutine: Subroutine): Statement {
if(subroutine.name in BuiltinFunctions) { if(subroutine.name in BuiltinFunctions) {
// the builtin functions can't be redefined // the builtin functions can't be redefined
checkResult.add(NameError("builtin function cannot be redefined", subroutine.position)) checkResult.add(NameError("builtin function cannot be redefined", subroutine.position))
@@ -142,7 +145,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return super.visit(subroutine) return super.visit(subroutine)
} }
override fun visit(label: Label): IStatement { override fun visit(label: Label): Statement {
if(label.name in BuiltinFunctions) { if(label.name in BuiltinFunctions) {
// the builtin functions can't be redefined // the builtin functions can't be redefined
checkResult.add(NameError("builtin function cannot be redefined", label.position)) checkResult.add(NameError("builtin function cannot be redefined", label.position))
@@ -154,7 +157,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return super.visit(label) return super.visit(label)
} }
override fun visit(forLoop: ForLoop): IStatement { override fun visit(forLoop: ForLoop): Statement {
// If the for loop has a decltype, it means to declare the loopvar inside the loop body // If the for loop has a decltype, it means to declare the loopvar inside the loop body
// rather than reusing an already declared loopvar from an outer scope. // rather than reusing an already declared loopvar from an outer scope.
// For loops that loop over an interable variable (instead of a range of numbers) get an // For loops that loop over an interable variable (instead of a range of numbers) get an
@@ -200,13 +203,13 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return super.visit(assignTarget) return super.visit(assignTarget)
} }
override fun visit(returnStmt: Return): IStatement { override fun visit(returnStmt: Return): Statement {
if(returnStmt.value!=null) { if(returnStmt.value!=null) {
// possibly adjust any literal values returned, into the desired returning data type // possibly adjust any literal values returned, into the desired returning data type
val subroutine = returnStmt.definingSubroutine()!! val subroutine = returnStmt.definingSubroutine()!!
if(subroutine.returntypes.size!=1) if(subroutine.returntypes.size!=1)
return returnStmt // mismatch in number of return values, error will be printed later. return returnStmt // mismatch in number of return values, error will be printed later.
val newValue: IExpression val newValue: Expression
val lval = returnStmt.value as? NumericLiteralValue val lval = returnStmt.value as? NumericLiteralValue
if(lval!=null) { if(lval!=null) {
val adjusted = lval.cast(subroutine.returntypes.single()) val adjusted = lval.cast(subroutine.returntypes.single())
@@ -220,7 +223,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return super.visit(returnStmt) return super.visit(returnStmt)
} }
override fun visit(refLiteral: ReferenceLiteralValue): IExpression { override fun visit(refLiteral: ReferenceLiteralValue): Expression {
if(refLiteral.parent !is VarDecl) { if(refLiteral.parent !is VarDecl) {
return makeIdentifierFromRefLv(refLiteral) return makeIdentifierFromRefLv(refLiteral)
} }
@@ -240,14 +243,14 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return identifier return identifier
} }
override fun visit(addressOf: AddressOf): IExpression { override fun visit(addressOf: AddressOf): Expression {
// register the scoped name of the referenced identifier // register the scoped name of the referenced identifier
val variable= addressOf.identifier.targetVarDecl(program.namespace) ?: return addressOf val variable= addressOf.identifier.targetVarDecl(program.namespace) ?: return addressOf
addressOf.scopedname = variable.scopedname addressOf.scopedname = variable.scopedname
return super.visit(addressOf) return super.visit(addressOf)
} }
override fun visit(structDecl: StructDecl): IStatement { override fun visit(structDecl: StructDecl): Statement {
for(member in structDecl.statements){ for(member in structDecl.statements){
val decl = member as? VarDecl val decl = member as? VarDecl
if(decl!=null && decl.datatype !in NumericDatatypes) if(decl!=null && decl.datatype !in NumericDatatypes)
@@ -257,7 +260,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
return super.visit(structDecl) return super.visit(structDecl)
} }
override fun visit(expr: BinaryExpression): IExpression { override fun visit(expr: BinaryExpression): Expression {
return when { return when {
expr.left is ReferenceLiteralValue -> expr.left is ReferenceLiteralValue ->
processBinaryExprWithReferenceVal(expr.left as ReferenceLiteralValue, expr.right, expr) processBinaryExprWithReferenceVal(expr.left as ReferenceLiteralValue, expr.right, expr)
@@ -267,7 +270,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
} }
} }
private fun processBinaryExprWithReferenceVal(refLv: ReferenceLiteralValue, operand: IExpression, expr: BinaryExpression): IExpression { private fun processBinaryExprWithReferenceVal(refLv: ReferenceLiteralValue, operand: Expression, expr: BinaryExpression): Expression {
// expressions on strings or arrays // expressions on strings or arrays
if(refLv.isString) { if(refLv.isString) {
val constvalue = operand.constValue(program) val constvalue = operand.constValue(program)

View File

@@ -1,7 +1,5 @@
package prog8.ast.processing package prog8.ast.processing
import prog8.ast.IExpression
import prog8.ast.IStatement
import prog8.ast.Module import prog8.ast.Module
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.expressions.* import prog8.ast.expressions.*
@@ -16,38 +14,38 @@ interface IAstModifyingVisitor {
module.statements = module.statements.asSequence().map { it.accept(this) }.toMutableList() module.statements = module.statements.asSequence().map { it.accept(this) }.toMutableList()
} }
fun visit(expr: PrefixExpression): IExpression { fun visit(expr: PrefixExpression): Expression {
expr.expression = expr.expression.accept(this) expr.expression = expr.expression.accept(this)
return expr return expr
} }
fun visit(expr: BinaryExpression): IExpression { fun visit(expr: BinaryExpression): Expression {
expr.left = expr.left.accept(this) expr.left = expr.left.accept(this)
expr.right = expr.right.accept(this) expr.right = expr.right.accept(this)
return expr return expr
} }
fun visit(directive: Directive): IStatement { fun visit(directive: Directive): Statement {
return directive return directive
} }
fun visit(block: Block): IStatement { fun visit(block: Block): Statement {
block.statements = block.statements.asSequence().map { it.accept(this) }.toMutableList() block.statements = block.statements.asSequence().map { it.accept(this) }.toMutableList()
return block return block
} }
fun visit(decl: VarDecl): IStatement { fun visit(decl: VarDecl): Statement {
decl.value = decl.value?.accept(this) decl.value = decl.value?.accept(this)
decl.arraysize?.accept(this) decl.arraysize?.accept(this)
return decl return decl
} }
fun visit(subroutine: Subroutine): IStatement { fun visit(subroutine: Subroutine): Statement {
subroutine.statements = subroutine.statements.asSequence().map { it.accept(this) }.toMutableList() subroutine.statements = subroutine.statements.asSequence().map { it.accept(this) }.toMutableList()
return subroutine return subroutine
} }
fun visit(functionCall: FunctionCall): IExpression { fun visit(functionCall: FunctionCall): Expression {
val newtarget = functionCall.target.accept(this) val newtarget = functionCall.target.accept(this)
if(newtarget is IdentifierReference) if(newtarget is IdentifierReference)
functionCall.target = newtarget functionCall.target = newtarget
@@ -55,7 +53,7 @@ interface IAstModifyingVisitor {
return functionCall return functionCall
} }
fun visit(functionCallStatement: FunctionCallStatement): IStatement { fun visit(functionCallStatement: FunctionCallStatement): Statement {
val newtarget = functionCallStatement.target.accept(this) val newtarget = functionCallStatement.target.accept(this)
if(newtarget is IdentifierReference) if(newtarget is IdentifierReference)
functionCallStatement.target = newtarget functionCallStatement.target = newtarget
@@ -63,13 +61,13 @@ interface IAstModifyingVisitor {
return functionCallStatement return functionCallStatement
} }
fun visit(identifier: IdentifierReference): IExpression { fun visit(identifier: IdentifierReference): Expression {
// note: this is an identifier that is used in an expression. // note: this is an identifier that is used in an expression.
// other identifiers are simply part of the other statements (such as jumps, subroutine defs etc) // other identifiers are simply part of the other statements (such as jumps, subroutine defs etc)
return identifier return identifier
} }
fun visit(jump: Jump): IStatement { fun visit(jump: Jump): Statement {
if(jump.identifier!=null) { if(jump.identifier!=null) {
val ident = jump.identifier.accept(this) val ident = jump.identifier.accept(this)
if(ident is IdentifierReference && ident!==jump.identifier) { if(ident is IdentifierReference && ident!==jump.identifier) {
@@ -79,27 +77,27 @@ interface IAstModifyingVisitor {
return jump return jump
} }
fun visit(ifStatement: IfStatement): IStatement { fun visit(ifStatement: IfStatement): Statement {
ifStatement.condition = ifStatement.condition.accept(this) ifStatement.condition = ifStatement.condition.accept(this)
ifStatement.truepart = ifStatement.truepart.accept(this) as AnonymousScope ifStatement.truepart = ifStatement.truepart.accept(this) as AnonymousScope
ifStatement.elsepart = ifStatement.elsepart.accept(this) as AnonymousScope ifStatement.elsepart = ifStatement.elsepart.accept(this) as AnonymousScope
return ifStatement return ifStatement
} }
fun visit(branchStatement: BranchStatement): IStatement { fun visit(branchStatement: BranchStatement): Statement {
branchStatement.truepart = branchStatement.truepart.accept(this) as AnonymousScope branchStatement.truepart = branchStatement.truepart.accept(this) as AnonymousScope
branchStatement.elsepart = branchStatement.elsepart.accept(this) as AnonymousScope branchStatement.elsepart = branchStatement.elsepart.accept(this) as AnonymousScope
return branchStatement return branchStatement
} }
fun visit(range: RangeExpr): IExpression { fun visit(range: RangeExpr): Expression {
range.from = range.from.accept(this) range.from = range.from.accept(this)
range.to = range.to.accept(this) range.to = range.to.accept(this)
range.step = range.step.accept(this) range.step = range.step.accept(this)
return range return range
} }
fun visit(label: Label): IStatement { fun visit(label: Label): Statement {
return label return label
} }
@@ -107,7 +105,7 @@ interface IAstModifyingVisitor {
return literalValue return literalValue
} }
fun visit(refLiteral: ReferenceLiteralValue): IExpression { fun visit(refLiteral: ReferenceLiteralValue): Expression {
if(refLiteral.array!=null) { if(refLiteral.array!=null) {
for(av in refLiteral.array.withIndex()) { for(av in refLiteral.array.withIndex()) {
val newvalue = av.value.accept(this) val newvalue = av.value.accept(this)
@@ -117,50 +115,50 @@ interface IAstModifyingVisitor {
return refLiteral return refLiteral
} }
fun visit(assignment: Assignment): IStatement { fun visit(assignment: Assignment): Statement {
assignment.target = assignment.target.accept(this) assignment.target = assignment.target.accept(this)
assignment.value = assignment.value.accept(this) assignment.value = assignment.value.accept(this)
return assignment return assignment
} }
fun visit(postIncrDecr: PostIncrDecr): IStatement { fun visit(postIncrDecr: PostIncrDecr): Statement {
postIncrDecr.target = postIncrDecr.target.accept(this) postIncrDecr.target = postIncrDecr.target.accept(this)
return postIncrDecr return postIncrDecr
} }
fun visit(contStmt: Continue): IStatement { fun visit(contStmt: Continue): Statement {
return contStmt return contStmt
} }
fun visit(breakStmt: Break): IStatement { fun visit(breakStmt: Break): Statement {
return breakStmt return breakStmt
} }
fun visit(forLoop: ForLoop): IStatement { fun visit(forLoop: ForLoop): Statement {
forLoop.loopVar?.accept(this) forLoop.loopVar?.accept(this)
forLoop.iterable = forLoop.iterable.accept(this) forLoop.iterable = forLoop.iterable.accept(this)
forLoop.body = forLoop.body.accept(this) as AnonymousScope forLoop.body = forLoop.body.accept(this) as AnonymousScope
return forLoop return forLoop
} }
fun visit(whileLoop: WhileLoop): IStatement { fun visit(whileLoop: WhileLoop): Statement {
whileLoop.condition = whileLoop.condition.accept(this) whileLoop.condition = whileLoop.condition.accept(this)
whileLoop.body = whileLoop.body.accept(this) as AnonymousScope whileLoop.body = whileLoop.body.accept(this) as AnonymousScope
return whileLoop return whileLoop
} }
fun visit(repeatLoop: RepeatLoop): IStatement { fun visit(repeatLoop: RepeatLoop): Statement {
repeatLoop.untilCondition = repeatLoop.untilCondition.accept(this) repeatLoop.untilCondition = repeatLoop.untilCondition.accept(this)
repeatLoop.body = repeatLoop.body.accept(this) as AnonymousScope repeatLoop.body = repeatLoop.body.accept(this) as AnonymousScope
return repeatLoop return repeatLoop
} }
fun visit(returnStmt: Return): IStatement { fun visit(returnStmt: Return): Statement {
returnStmt.value = returnStmt.value?.accept(this) returnStmt.value = returnStmt.value?.accept(this)
return returnStmt return returnStmt
} }
fun visit(arrayIndexedExpression: ArrayIndexedExpression): IExpression { fun visit(arrayIndexedExpression: ArrayIndexedExpression): Expression {
arrayIndexedExpression.identifier.accept(this) arrayIndexedExpression.identifier.accept(this)
arrayIndexedExpression.arrayspec.accept(this) arrayIndexedExpression.arrayspec.accept(this)
return arrayIndexedExpression return arrayIndexedExpression
@@ -173,17 +171,17 @@ interface IAstModifyingVisitor {
return assignTarget return assignTarget
} }
fun visit(scope: AnonymousScope): IStatement { fun visit(scope: AnonymousScope): Statement {
scope.statements = scope.statements.asSequence().map { it.accept(this) }.toMutableList() scope.statements = scope.statements.asSequence().map { it.accept(this) }.toMutableList()
return scope return scope
} }
fun visit(typecast: TypecastExpression): IExpression { fun visit(typecast: TypecastExpression): Expression {
typecast.expression = typecast.expression.accept(this) typecast.expression = typecast.expression.accept(this)
return typecast return typecast
} }
fun visit(memread: DirectMemoryRead): IExpression { fun visit(memread: DirectMemoryRead): Expression {
memread.addressExpression = memread.addressExpression.accept(this) memread.addressExpression = memread.addressExpression.accept(this)
return memread return memread
} }
@@ -192,28 +190,28 @@ interface IAstModifyingVisitor {
memwrite.addressExpression = memwrite.addressExpression.accept(this) memwrite.addressExpression = memwrite.addressExpression.accept(this)
} }
fun visit(addressOf: AddressOf): IExpression { fun visit(addressOf: AddressOf): Expression {
addressOf.identifier.accept(this) addressOf.identifier.accept(this)
return addressOf return addressOf
} }
fun visit(inlineAssembly: InlineAssembly): IStatement { fun visit(inlineAssembly: InlineAssembly): Statement {
return inlineAssembly return inlineAssembly
} }
fun visit(registerExpr: RegisterExpr): IExpression { fun visit(registerExpr: RegisterExpr): Expression {
return registerExpr return registerExpr
} }
fun visit(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder): IStatement { fun visit(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder): Statement {
return builtinFunctionStatementPlaceholder return builtinFunctionStatementPlaceholder
} }
fun visit(nopStatement: NopStatement): IStatement { fun visit(nopStatement: NopStatement): Statement {
return nopStatement return nopStatement
} }
fun visit(whenStatement: WhenStatement): IStatement { fun visit(whenStatement: WhenStatement): Statement {
whenStatement.condition.accept(this) whenStatement.condition.accept(this)
whenStatement.choices.forEach { it.accept(this) } whenStatement.choices.forEach { it.accept(this) }
return whenStatement return whenStatement
@@ -224,12 +222,12 @@ interface IAstModifyingVisitor {
whenChoice.statements.accept(this) whenChoice.statements.accept(this)
} }
fun visit(structDecl: StructDecl): IStatement { fun visit(structDecl: StructDecl): Statement {
structDecl.statements = structDecl.statements.map{ it.accept(this) }.toMutableList() structDecl.statements = structDecl.statements.map{ it.accept(this) }.toMutableList()
return structDecl return structDecl
} }
fun visit(structLv: StructLiteralValue): IExpression { fun visit(structLv: StructLiteralValue): Expression {
structLv.values = structLv.values.map { it.accept(this) } structLv.values = structLv.values.map { it.accept(this) }
return structLv return structLv
} }

View File

@@ -1,10 +1,10 @@
package prog8.ast.processing package prog8.ast.processing
import prog8.ast.IStatement
import prog8.ast.Module import prog8.ast.Module
import prog8.ast.base.SyntaxError import prog8.ast.base.SyntaxError
import prog8.ast.base.printWarning import prog8.ast.base.printWarning
import prog8.ast.statements.Directive import prog8.ast.statements.Directive
import prog8.ast.statements.Statement
internal class ImportedModuleDirectiveRemover : IAstModifyingVisitor { internal class ImportedModuleDirectiveRemover : IAstModifyingVisitor {
private val checkResult: MutableList<SyntaxError> = mutableListOf() private val checkResult: MutableList<SyntaxError> = mutableListOf()
@@ -18,7 +18,7 @@ internal class ImportedModuleDirectiveRemover : IAstModifyingVisitor {
*/ */
override fun visit(module: Module) { override fun visit(module: Module) {
super.visit(module) super.visit(module)
val newStatements : MutableList<IStatement> = mutableListOf() val newStatements : MutableList<Statement> = mutableListOf()
val moduleLevelDirectives = listOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address") val moduleLevelDirectives = listOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address")
for (sourceStmt in module.statements) { for (sourceStmt in module.statements) {

View File

@@ -97,7 +97,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
module.statements.addAll(0, directives) module.statements.addAll(0, directives)
} }
override fun visit(block: Block): IStatement { override fun visit(block: Block): Statement {
val subroutines = block.statements.filterIsInstance<Subroutine>() val subroutines = block.statements.filterIsInstance<Subroutine>()
var numSubroutinesAtEnd = 0 var numSubroutinesAtEnd = 0
@@ -161,7 +161,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return super.visit(block) return super.visit(block)
} }
override fun visit(subroutine: Subroutine): IStatement { override fun visit(subroutine: Subroutine): Statement {
super.visit(subroutine) super.visit(subroutine)
val varDecls = subroutine.statements.filterIsInstance<VarDecl>() val varDecls = subroutine.statements.filterIsInstance<VarDecl>()
@@ -186,7 +186,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return subroutine return subroutine
} }
override fun visit(expr: BinaryExpression): IExpression { override fun visit(expr: BinaryExpression): Expression {
val leftDt = expr.left.inferType(program) val leftDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program) val rightDt = expr.right.inferType(program)
if(leftDt!=null && rightDt!=null && leftDt!=rightDt) { if(leftDt!=null && rightDt!=null && leftDt!=rightDt) {
@@ -209,7 +209,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return super.visit(expr) return super.visit(expr)
} }
override fun visit(assignment: Assignment): IStatement { override fun visit(assignment: Assignment): Statement {
val assg = super.visit(assignment) val assg = super.visit(assignment)
if(assg !is Assignment) if(assg !is Assignment)
return assg return assg
@@ -246,7 +246,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
if(assg.aug_op!=null) { if(assg.aug_op!=null) {
// transform augmented assg into normal assg so we have one case less to deal with later // transform augmented assg into normal assg so we have one case less to deal with later
val newTarget: IExpression = val newTarget: Expression =
when { when {
assg.target.register != null -> RegisterExpr(assg.target.register!!, assg.target.position) assg.target.register != null -> RegisterExpr(assg.target.register!!, assg.target.position)
assg.target.identifier != null -> assg.target.identifier!! assg.target.identifier != null -> assg.target.identifier!!
@@ -265,12 +265,12 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return assg return assg
} }
override fun visit(functionCallStatement: FunctionCallStatement): IStatement { override fun visit(functionCallStatement: FunctionCallStatement): Statement {
checkFunctionCallArguments(functionCallStatement, functionCallStatement.definingScope()) checkFunctionCallArguments(functionCallStatement, functionCallStatement.definingScope())
return super.visit(functionCallStatement) return super.visit(functionCallStatement)
} }
override fun visit(functionCall: FunctionCall): IExpression { override fun visit(functionCall: FunctionCall): Expression {
checkFunctionCallArguments(functionCall, functionCall.definingScope()) checkFunctionCallArguments(functionCall, functionCall.definingScope())
return super.visit(functionCall) return super.visit(functionCall)
} }
@@ -321,7 +321,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
} }
} }
override fun visit(typecast: TypecastExpression): IExpression { override fun visit(typecast: TypecastExpression): Expression {
// warn about any implicit type casts to Float, because that may not be intended // warn about any implicit type casts to Float, because that may not be intended
if(typecast.implicit && typecast.type in setOf(DataType.FLOAT, DataType.ARRAY_F)) { if(typecast.implicit && typecast.type in setOf(DataType.FLOAT, DataType.ARRAY_F)) {
printWarning("byte or word value implicitly converted to float. Suggestion: use explicit cast as float, a float number, or revert to integer arithmetic", typecast.position) printWarning("byte or word value implicitly converted to float. Suggestion: use explicit cast as float, a float number, or revert to integer arithmetic", typecast.position)
@@ -329,7 +329,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return super.visit(typecast) return super.visit(typecast)
} }
override fun visit(whenStatement: WhenStatement): IStatement { override fun visit(whenStatement: WhenStatement): Statement {
// make sure all choices are just for one single value // make sure all choices are just for one single value
val choices = whenStatement.choices.toList() val choices = whenStatement.choices.toList()
for(choice in choices) { for(choice in choices) {
@@ -349,7 +349,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return super.visit(whenStatement) return super.visit(whenStatement)
} }
override fun visit(memread: DirectMemoryRead): IExpression { override fun visit(memread: DirectMemoryRead): Expression {
// make sure the memory address is an uword // make sure the memory address is an uword
val dt = memread.addressExpression.inferType(program) val dt = memread.addressExpression.inferType(program)
if(dt!=DataType.UWORD) { if(dt!=DataType.UWORD) {
@@ -378,7 +378,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
super.visit(memwrite) super.visit(memwrite)
} }
override fun visit(structLv: StructLiteralValue): IExpression { override fun visit(structLv: StructLiteralValue): Expression {
val litval = super.visit(structLv) val litval = super.visit(structLv)
if(litval !is StructLiteralValue) if(litval !is StructLiteralValue)
return litval return litval

View File

@@ -1,6 +1,8 @@
package prog8.ast.processing package prog8.ast.processing
import prog8.ast.* import prog8.ast.INameScope
import prog8.ast.Module
import prog8.ast.Node
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.* import prog8.ast.statements.*
@@ -31,7 +33,7 @@ internal class VarInitValueAndAddressOfCreator(private val namespace: INameScope
} }
} }
override fun visit(decl: VarDecl): IStatement { override fun visit(decl: VarDecl): Statement {
super.visit(decl) super.visit(decl)
if(decl.isArray && decl.value==null) { if(decl.isArray && decl.value==null) {
@@ -70,25 +72,25 @@ internal class VarInitValueAndAddressOfCreator(private val namespace: INameScope
return decl return decl
} }
override fun visit(functionCall: FunctionCall): IExpression { override fun visit(functionCall: FunctionCall): Expression {
val targetStatement = functionCall.target.targetSubroutine(namespace) val targetStatement = functionCall.target.targetSubroutine(namespace)
if(targetStatement!=null) { if(targetStatement!=null) {
var node: Node = functionCall var node: Node = functionCall
while(node !is IStatement) while(node !is Statement)
node=node.parent node=node.parent
addAddressOfExprIfNeeded(targetStatement, functionCall.arglist, node) addAddressOfExprIfNeeded(targetStatement, functionCall.arglist, node)
} }
return functionCall return functionCall
} }
override fun visit(functionCallStatement: FunctionCallStatement): IStatement { override fun visit(functionCallStatement: FunctionCallStatement): Statement {
val targetStatement = functionCallStatement.target.targetSubroutine(namespace) val targetStatement = functionCallStatement.target.targetSubroutine(namespace)
if(targetStatement!=null) if(targetStatement!=null)
addAddressOfExprIfNeeded(targetStatement, functionCallStatement.arglist, functionCallStatement) addAddressOfExprIfNeeded(targetStatement, functionCallStatement.arglist, functionCallStatement)
return functionCallStatement return functionCallStatement
} }
private fun addAddressOfExprIfNeeded(subroutine: Subroutine, arglist: MutableList<IExpression>, parent: IStatement) { private fun addAddressOfExprIfNeeded(subroutine: Subroutine, arglist: MutableList<Expression>, parent: Statement) {
// functions that accept UWORD and are given an array type, or string, will receive the AddressOf (memory location) of that value instead. // functions that accept UWORD and are given an array type, or string, will receive the AddressOf (memory location) of that value instead.
for(argparam in subroutine.parameters.withIndex().zip(arglist)) { for(argparam in subroutine.parameters.withIndex().zip(arglist)) {
if(argparam.first.value.type==DataType.UWORD || argparam.first.value.type in StringDatatypes) { if(argparam.first.value.type==DataType.UWORD || argparam.first.value.type in StringDatatypes) {

View File

@@ -9,7 +9,38 @@ import prog8.compiler.HeapValues
import prog8.compiler.target.c64.MachineDefinition import prog8.compiler.target.c64.MachineDefinition
class BuiltinFunctionStatementPlaceholder(val name: String, override val position: Position) : IStatement { sealed class Statement : Node {
abstract fun accept(visitor: IAstModifyingVisitor) : Statement
abstract fun accept(visitor: IAstVisitor)
fun makeScopedName(name: String): String {
// easy way out is to always return the full scoped name.
// it would be nicer to find only the minimal prefixed scoped name, but that's too much hassle for now.
// and like this, we can cache the name even,
// like in a lazy property on the statement object itself (label, subroutine, vardecl)
val scope = mutableListOf<String>()
var statementScope = this.parent
while(statementScope !is ParentSentinel && statementScope !is Module) {
if(statementScope is INameScope) {
scope.add(0, statementScope.name)
}
statementScope = statementScope.parent
}
if(name.isNotEmpty())
scope.add(name)
return scope.joinToString(".")
}
abstract val expensiveToInline: Boolean
fun definingBlock(): Block {
if(this is Block)
return this
return findParentNode<Block>(this)!!
}
}
class BuiltinFunctionStatementPlaceholder(val name: String, override val position: Position) : Statement() {
override var parent: Node = ParentSentinel override var parent: Node = ParentSentinel
override fun linkParents(parent: Node) {} override fun linkParents(parent: Node) {}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
@@ -24,9 +55,9 @@ data class RegisterOrStatusflag(val registerOrPair: RegisterOrPair?, val statusf
class Block(override val name: String, class Block(override val name: String,
val address: Int?, val address: Int?,
override var statements: MutableList<IStatement>, override var statements: MutableList<Statement>,
val isInLibrary: Boolean, val isInLibrary: Boolean,
override val position: Position) : IStatement, INameScope { override val position: Position) : Statement(), INameScope {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline override val expensiveToInline
get() = statements.any { it.expensiveToInline } get() = statements.any { it.expensiveToInline }
@@ -46,7 +77,7 @@ class Block(override val name: String,
fun options() = statements.filter { it is Directive && it.directive == "%option" }.flatMap { (it as Directive).args }.map {it.name!!}.toSet() fun options() = statements.filter { it is Directive && it.directive == "%option" }.flatMap { (it as Directive).args }.map {it.name!!}.toSet()
} }
data class Directive(val directive: String, val args: List<DirectiveArg>, override val position: Position) : IStatement { data class Directive(val directive: String, val args: List<DirectiveArg>, override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = false override val expensiveToInline = false
@@ -67,7 +98,7 @@ data class DirectiveArg(val str: String?, val name: String?, val int: Int?, over
} }
} }
data class Label(val name: String, override val position: Position) : IStatement { data class Label(val name: String, override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = false override val expensiveToInline = false
@@ -85,7 +116,7 @@ data class Label(val name: String, override val position: Position) : IStatement
val scopedname: String by lazy { makeScopedName(name) } val scopedname: String by lazy { makeScopedName(name) }
} }
open class Return(var value: IExpression?, override val position: Position) : IStatement { open class Return(var value: Expression?, override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = value!=null && value !is NumericLiteralValue override val expensiveToInline = value!=null && value !is NumericLiteralValue
@@ -111,7 +142,7 @@ class ReturnFromIrq(override val position: Position) : Return(null, position) {
} }
} }
class Continue(override val position: Position) : IStatement { class Continue(override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = false override val expensiveToInline = false
@@ -123,7 +154,7 @@ class Continue(override val position: Position) : IStatement {
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
} }
class Break(override val position: Position) : IStatement { class Break(override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = false override val expensiveToInline = false
@@ -149,10 +180,10 @@ class VarDecl(val type: VarDeclType,
var arraysize: ArrayIndex?, var arraysize: ArrayIndex?,
val name: String, val name: String,
private val structName: String?, private val structName: String?,
var value: IExpression?, var value: Expression?,
val isArray: Boolean, val isArray: Boolean,
val autogeneratedDontRemove: Boolean, val autogeneratedDontRemove: Boolean,
override val position: Position) : IStatement { override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
var struct: StructDecl? = null // set later (because at parse time, we only know the name) var struct: StructDecl? = null // set later (because at parse time, we only know the name)
private set private set
@@ -231,7 +262,7 @@ class VarDecl(val type: VarDeclType,
return decl return decl
} }
fun flattenStructMembers(): MutableList<IStatement> { fun flattenStructMembers(): MutableList<Statement> {
val result = struct!!.statements.withIndex().map { val result = struct!!.statements.withIndex().map {
val member = it.value as VarDecl val member = it.value as VarDecl
val initvalue = if(value!=null) (value as StructLiteralValue).values[it.index] else null val initvalue = if(value!=null) (value as StructLiteralValue).values[it.index] else null
@@ -246,14 +277,14 @@ class VarDecl(val type: VarDeclType,
member.isArray, member.isArray,
true, true,
member.position member.position
) as IStatement ) as Statement
}.toMutableList() }.toMutableList()
structHasBeenFlattened = true structHasBeenFlattened = true
return result return result
} }
} }
class ArrayIndex(var index: IExpression, override val position: Position) : Node { class ArrayIndex(var index: Expression, override val position: Position) : Node {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -282,7 +313,7 @@ class ArrayIndex(var index: IExpression, override val position: Position) : Node
fun size() = (index as? NumericLiteralValue)?.number?.toInt() fun size() = (index as? NumericLiteralValue)?.number?.toInt()
} }
open class Assignment(var target: AssignTarget, val aug_op : String?, var value: IExpression, override val position: Position) : IStatement { open class Assignment(var target: AssignTarget, val aug_op : String?, var value: Expression, override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline override val expensiveToInline
get() = value !is NumericLiteralValue get() = value !is NumericLiteralValue
@@ -303,7 +334,7 @@ open class Assignment(var target: AssignTarget, val aug_op : String?, var value:
// This is a special class so the compiler can see if the assignments are for initializing the vars in the scope, // This is a special class so the compiler can see if the assignments are for initializing the vars in the scope,
// or just a regular assignment. It may optimize the initialization step from this. // or just a regular assignment. It may optimize the initialization step from this.
class VariableInitializationAssignment(target: AssignTarget, aug_op: String?, value: IExpression, position: Position) class VariableInitializationAssignment(target: AssignTarget, aug_op: String?, value: Expression, position: Position)
: Assignment(target, aug_op, value, position) : Assignment(target, aug_op, value, position)
data class AssignTarget(val register: Register?, data class AssignTarget(val register: Register?,
@@ -324,7 +355,7 @@ data class AssignTarget(val register: Register?,
fun accept(visitor: IAstVisitor) = visitor.visit(this) fun accept(visitor: IAstVisitor) = visitor.visit(this)
companion object { companion object {
fun fromExpr(expr: IExpression): AssignTarget { fun fromExpr(expr: Expression): AssignTarget {
return when (expr) { return when (expr) {
is RegisterExpr -> AssignTarget(expr.register, null, null, null, expr.position) is RegisterExpr -> AssignTarget(expr.register, null, null, null, expr.position)
is IdentifierReference -> AssignTarget(null, expr, null, null, expr.position) is IdentifierReference -> AssignTarget(null, expr, null, null, expr.position)
@@ -335,7 +366,7 @@ data class AssignTarget(val register: Register?,
} }
} }
fun inferType(program: Program, stmt: IStatement): DataType? { fun inferType(program: Program, stmt: Statement): DataType? {
if(register!=null) if(register!=null)
return DataType.UBYTE return DataType.UBYTE
@@ -356,7 +387,7 @@ data class AssignTarget(val register: Register?,
return null return null
} }
infix fun isSameAs(value: IExpression): Boolean { infix fun isSameAs(value: Expression): Boolean {
return when { return when {
this.memoryAddress!=null -> false this.memoryAddress!=null -> false
this.register!=null -> value is RegisterExpr && value.register==register this.register!=null -> value is RegisterExpr && value.register==register
@@ -411,7 +442,7 @@ data class AssignTarget(val register: Register?,
} }
} }
class PostIncrDecr(var target: AssignTarget, val operator: String, override val position: Position) : IStatement { class PostIncrDecr(var target: AssignTarget, val operator: String, override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = false override val expensiveToInline = false
@@ -431,7 +462,7 @@ class PostIncrDecr(var target: AssignTarget, val operator: String, override val
class Jump(val address: Int?, class Jump(val address: Int?,
val identifier: IdentifierReference?, val identifier: IdentifierReference?,
val generatedLabel: String?, // used in code generation scenarios val generatedLabel: String?, // used in code generation scenarios
override val position: Position) : IStatement { override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = false override val expensiveToInline = false
@@ -449,8 +480,8 @@ class Jump(val address: Int?,
} }
class FunctionCallStatement(override var target: IdentifierReference, class FunctionCallStatement(override var target: IdentifierReference,
override var arglist: MutableList<IExpression>, override var arglist: MutableList<Expression>,
override val position: Position) : IStatement, IFunctionCall { override val position: Position) : Statement(), IFunctionCall {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline override val expensiveToInline
get() = arglist.any { it !is NumericLiteralValue } get() = arglist.any { it !is NumericLiteralValue }
@@ -469,7 +500,7 @@ class FunctionCallStatement(override var target: IdentifierReference,
} }
} }
class InlineAssembly(val assembly: String, override val position: Position) : IStatement { class InlineAssembly(val assembly: String, override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = true override val expensiveToInline = true
@@ -481,8 +512,8 @@ class InlineAssembly(val assembly: String, override val position: Position) : IS
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
} }
class AnonymousScope(override var statements: MutableList<IStatement>, class AnonymousScope(override var statements: MutableList<Statement>,
override val position: Position) : INameScope, IStatement { override val position: Position) : INameScope, Statement() {
override val name: String override val name: String
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline override val expensiveToInline
@@ -506,7 +537,7 @@ class AnonymousScope(override var statements: MutableList<IStatement>,
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
} }
class NopStatement(override val position: Position): IStatement { class NopStatement(override val position: Position): Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = false override val expensiveToInline = false
@@ -518,7 +549,7 @@ class NopStatement(override val position: Position): IStatement {
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
companion object { companion object {
fun insteadOf(stmt: IStatement): NopStatement { fun insteadOf(stmt: Statement): NopStatement {
val nop = NopStatement(stmt.position) val nop = NopStatement(stmt.position)
nop.parent = stmt.parent nop.parent = stmt.parent
return nop return nop
@@ -537,8 +568,8 @@ class Subroutine(override val name: String,
val asmClobbers: Set<Register>, val asmClobbers: Set<Register>,
val asmAddress: Int?, val asmAddress: Int?,
val isAsmSubroutine: Boolean, val isAsmSubroutine: Boolean,
override var statements: MutableList<IStatement>, override var statements: MutableList<Statement>,
override val position: Position) : IStatement, INameScope { override val position: Position) : Statement(), INameScope {
var keepAlways: Boolean = false var keepAlways: Boolean = false
override val expensiveToInline override val expensiveToInline
@@ -614,10 +645,10 @@ open class SubroutineParameter(val name: String,
} }
} }
class IfStatement(var condition: IExpression, class IfStatement(var condition: Expression,
var truepart: AnonymousScope, var truepart: AnonymousScope,
var elsepart: AnonymousScope, var elsepart: AnonymousScope,
override val position: Position) : IStatement { override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline: Boolean override val expensiveToInline: Boolean
get() = truepart.expensiveToInline || elsepart.expensiveToInline get() = truepart.expensiveToInline || elsepart.expensiveToInline
@@ -636,7 +667,7 @@ class IfStatement(var condition: IExpression,
class BranchStatement(var condition: BranchCondition, class BranchStatement(var condition: BranchCondition,
var truepart: AnonymousScope, var truepart: AnonymousScope,
var elsepart: AnonymousScope, var elsepart: AnonymousScope,
override val position: Position) : IStatement { override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline: Boolean override val expensiveToInline: Boolean
get() = truepart.expensiveToInline || elsepart.expensiveToInline get() = truepart.expensiveToInline || elsepart.expensiveToInline
@@ -655,9 +686,9 @@ class ForLoop(val loopRegister: Register?,
val decltype: DataType?, val decltype: DataType?,
val zeropage: ZeropageWish, val zeropage: ZeropageWish,
val loopVar: IdentifierReference?, val loopVar: IdentifierReference?,
var iterable: IExpression, var iterable: Expression,
var body: AnonymousScope, var body: AnonymousScope,
override val position: Position) : IStatement { override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = true override val expensiveToInline = true
@@ -680,9 +711,9 @@ class ForLoop(val loopRegister: Register?,
} }
} }
class WhileLoop(var condition: IExpression, class WhileLoop(var condition: Expression,
var body: AnonymousScope, var body: AnonymousScope,
override val position: Position) : IStatement { override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = true override val expensiveToInline = true
@@ -697,8 +728,8 @@ class WhileLoop(var condition: IExpression,
} }
class RepeatLoop(var body: AnonymousScope, class RepeatLoop(var body: AnonymousScope,
var untilCondition: IExpression, var untilCondition: Expression,
override val position: Position) : IStatement { override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline = true override val expensiveToInline = true
@@ -712,9 +743,9 @@ class RepeatLoop(var body: AnonymousScope,
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
} }
class WhenStatement(val condition: IExpression, class WhenStatement(val condition: Expression,
val choices: MutableList<WhenChoice>, val choices: MutableList<WhenChoice>,
override val position: Position): IStatement { override val position: Position): Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline: Boolean = true override val expensiveToInline: Boolean = true
@@ -745,7 +776,7 @@ class WhenStatement(val condition: IExpression,
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
} }
class WhenChoice(val values: List<IExpression>?, // if null, this is the 'else' part class WhenChoice(val values: List<Expression>?, // if null, this is the 'else' part
val statements: AnonymousScope, val statements: AnonymousScope,
override val position: Position) : Node { override val position: Position) : Node {
override lateinit var parent: Node override lateinit var parent: Node
@@ -766,8 +797,8 @@ class WhenChoice(val values: List<IExpression>?, // if null, this is
class StructDecl(override val name: String, class StructDecl(override val name: String,
override var statements: MutableList<IStatement>, // actually, only vardecls here override var statements: MutableList<Statement>, // actually, only vardecls here
override val position: Position): IStatement, INameScope { override val position: Position): Statement(), INameScope {
override lateinit var parent: Node override lateinit var parent: Node
override val expensiveToInline: Boolean = true override val expensiveToInline: Boolean = true
@@ -797,7 +828,7 @@ class StructDecl(override val name: String,
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
} }
class DirectMemoryWrite(var addressExpression: IExpression, override val position: Position) : Node { class DirectMemoryWrite(var addressExpression: Expression, override val position: Position) : Node {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {

View File

@@ -1,6 +1,8 @@
package prog8.compiler package prog8.compiler
import prog8.ast.* import prog8.ast.IFunctionCall
import prog8.ast.Module
import prog8.ast.Program
import prog8.ast.antlr.escape import prog8.ast.antlr.escape
import prog8.ast.base.DataType import prog8.ast.base.DataType
import prog8.ast.base.NumericDatatypes import prog8.ast.base.NumericDatatypes
@@ -176,7 +178,7 @@ class AstToSourceCode(val output: (text: String) -> Unit, val program: Program):
} }
} }
private fun outputStatements(statements: List<IStatement>) { private fun outputStatements(statements: List<Statement>) {
for(stmt in statements) { for(stmt in statements) {
if(stmt is VarDecl && stmt.autogeneratedDontRemove) if(stmt is VarDecl && stmt.autogeneratedDontRemove)
continue // skip autogenerated decls (to avoid generating a newline) continue // skip autogenerated decls (to avoid generating a newline)
@@ -267,7 +269,7 @@ class AstToSourceCode(val output: (text: String) -> Unit, val program: Program):
} }
} }
private fun outputListMembers(array: Sequence<IExpression>, openchar: Char, closechar: Char) { private fun outputListMembers(array: Sequence<Expression>, openchar: Char, closechar: Char) {
var counter = 0 var counter = 0
output(openchar.toString()) output(openchar.toString())
scopelevel++ scopelevel++

View File

@@ -1,9 +1,11 @@
package prog8.compiler package prog8.compiler
import prog8.ast.* import prog8.ast.INameScope
import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.base.RegisterOrPair.* import prog8.ast.base.RegisterOrPair.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.mangledStructMemberName
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.compiler.intermediate.IntermediateProgram import prog8.compiler.intermediate.IntermediateProgram
import prog8.compiler.intermediate.Opcode import prog8.compiler.intermediate.Opcode
@@ -179,8 +181,8 @@ internal class Compiler(private val program: Program) {
processVariables(subscope.value) processVariables(subscope.value)
} }
private fun translate(statements: List<IStatement>) { private fun translate(statements: List<Statement>) {
for (stmt: IStatement in statements) { for (stmt: Statement in statements) {
generatedLabelSequenceNumber++ generatedLabelSequenceNumber++
when (stmt) { when (stmt) {
is Label -> translate(stmt) is Label -> translate(stmt)
@@ -486,7 +488,7 @@ internal class Compiler(private val program: Program) {
} }
} }
private fun makeLabel(scopeStmt: IStatement, postfix: String): String { private fun makeLabel(scopeStmt: Statement, postfix: String): String {
generatedLabelSequenceNumber++ generatedLabelSequenceNumber++
return "${scopeStmt.makeScopedName("")}.<s-$generatedLabelSequenceNumber-$postfix>" return "${scopeStmt.makeScopedName("")}.<s-$generatedLabelSequenceNumber-$postfix>"
} }
@@ -551,7 +553,7 @@ internal class Compiler(private val program: Program) {
prog.instr(Opcode.NOP) prog.instr(Opcode.NOP)
} }
private fun translate(expr: IExpression) { private fun translate(expr: Expression) {
when(expr) { when(expr) {
is RegisterExpr -> { is RegisterExpr -> {
prog.instr(Opcode.PUSH_VAR_BYTE, callLabel = expr.register.name) prog.instr(Opcode.PUSH_VAR_BYTE, callLabel = expr.register.name)
@@ -716,7 +718,7 @@ internal class Compiler(private val program: Program) {
} }
} }
private fun translateBuiltinFunctionCall(funcname: String, args: List<IExpression>) { private fun translateBuiltinFunctionCall(funcname: String, args: List<Expression>) {
// some builtin functions are implemented directly as vm opcodes // some builtin functions are implemented directly as vm opcodes
if(funcname == "swap") { if(funcname == "swap") {
@@ -900,7 +902,7 @@ internal class Compiler(private val program: Program) {
} }
} }
private fun translateSwap(args: List<IExpression>) { private fun translateSwap(args: List<Expression>) {
// swap(x,y) is treated differently, it's not a normal function call // swap(x,y) is treated differently, it's not a normal function call
if (args.size != 2) if (args.size != 2)
throw AstException("swap requires 2 arguments") throw AstException("swap requires 2 arguments")
@@ -923,7 +925,7 @@ internal class Compiler(private val program: Program) {
return return
} }
private fun translateSubroutineCall(subroutine: Subroutine, arguments: List<IExpression>, callPosition: Position) { private fun translateSubroutineCall(subroutine: Subroutine, arguments: List<Expression>, callPosition: Position) {
// evaluate the arguments and assign them into the subroutine's argument variables. // evaluate the arguments and assign them into the subroutine's argument variables.
var restoreX = Register.X in subroutine.asmClobbers var restoreX = Register.X in subroutine.asmClobbers
if(restoreX) if(restoreX)
@@ -970,7 +972,7 @@ internal class Compiler(private val program: Program) {
} }
} }
private fun translateAsmSubCallArguments(subroutine: Subroutine, arguments: List<IExpression>, callPosition: Position, restoreXinitial: Boolean): Boolean { private fun translateAsmSubCallArguments(subroutine: Subroutine, arguments: List<Expression>, callPosition: Position, restoreXinitial: Boolean): Boolean {
var restoreX = restoreXinitial var restoreX = restoreXinitial
if (subroutine.parameters.size != subroutine.asmParameterRegisters.size) if (subroutine.parameters.size != subroutine.asmParameterRegisters.size)
TODO("no support yet for mix of register and non-register subroutine arguments") TODO("no support yet for mix of register and non-register subroutine arguments")
@@ -1009,8 +1011,8 @@ internal class Compiler(private val program: Program) {
prog.instr(Opcode.RSAVEX) prog.instr(Opcode.RSAVEX)
restoreX = true restoreX = true
} }
val valueA: IExpression val valueA: Expression
val valueX: IExpression val valueX: Expression
val paramDt = arg.first.inferType(program) val paramDt = arg.first.inferType(program)
when (paramDt) { when (paramDt) {
DataType.UBYTE -> { DataType.UBYTE -> {
@@ -1032,8 +1034,8 @@ internal class Compiler(private val program: Program) {
} }
} }
AY -> { AY -> {
val valueA: IExpression val valueA: Expression
val valueY: IExpression val valueY: Expression
val paramDt = arg.first.inferType(program) val paramDt = arg.first.inferType(program)
when (paramDt) { when (paramDt) {
DataType.UBYTE -> { DataType.UBYTE -> {
@@ -1059,8 +1061,8 @@ internal class Compiler(private val program: Program) {
prog.instr(Opcode.RSAVEX) prog.instr(Opcode.RSAVEX)
restoreX = true restoreX = true
} }
val valueX: IExpression val valueX: Expression
val valueY: IExpression val valueY: Expression
val paramDt = arg.first.inferType(program) val paramDt = arg.first.inferType(program)
when (paramDt) { when (paramDt) {
DataType.UBYTE -> { DataType.UBYTE -> {
@@ -1482,7 +1484,7 @@ internal class Compiler(private val program: Program) {
popValueIntoTarget(stmt.target, datatype) popValueIntoTarget(stmt.target, datatype)
} }
private fun pushHeapVarAddress(value: IExpression, removeLastOpcode: Boolean) { private fun pushHeapVarAddress(value: Expression, removeLastOpcode: Boolean) {
if (value is IdentifierReference) { if (value is IdentifierReference) {
val vardecl = value.targetVarDecl(program.namespace)!! val vardecl = value.targetVarDecl(program.namespace)!!
if(removeLastOpcode) prog.removeLastInstruction() if(removeLastOpcode) prog.removeLastInstruction()
@@ -1491,7 +1493,7 @@ internal class Compiler(private val program: Program) {
else throw CompilerException("can only take address of a literal string value or a string/array variable") else throw CompilerException("can only take address of a literal string value or a string/array variable")
} }
private fun pushFloatAddress(value: IExpression) { private fun pushFloatAddress(value: Expression) {
if (value is IdentifierReference) { if (value is IdentifierReference) {
val vardecl = value.targetVarDecl(program.namespace)!! val vardecl = value.targetVarDecl(program.namespace)!!
prog.instr(Opcode.PUSH_ADDR_HEAPVAR, callLabel = vardecl.scopedname) prog.instr(Opcode.PUSH_ADDR_HEAPVAR, callLabel = vardecl.scopedname)
@@ -1499,7 +1501,7 @@ internal class Compiler(private val program: Program) {
else throw CompilerException("can only take address of a the float as constant literal or variable") else throw CompilerException("can only take address of a the float as constant literal or variable")
} }
private fun pushStructAddress(value: IExpression) { private fun pushStructAddress(value: Expression) {
if (value is IdentifierReference) { if (value is IdentifierReference) {
// notice that the mangled name of the first struct member is the start address of this struct var // notice that the mangled name of the first struct member is the start address of this struct var
val vardecl = value.targetVarDecl(program.namespace)!! val vardecl = value.targetVarDecl(program.namespace)!!

View File

@@ -1,12 +1,8 @@
package prog8.functions package prog8.functions
import prog8.ast.IExpression
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.DirectMemoryRead import prog8.ast.expressions.*
import prog8.ast.expressions.IdentifierReference
import prog8.ast.expressions.NumericLiteralValue
import prog8.ast.expressions.ReferenceLiteralValue
import prog8.compiler.CompilerException import prog8.compiler.CompilerException
import kotlin.math.* import kotlin.math.*
@@ -14,7 +10,7 @@ import kotlin.math.*
class BuiltinFunctionParam(val name: String, val possibleDatatypes: Set<DataType>) class BuiltinFunctionParam(val name: String, val possibleDatatypes: Set<DataType>)
typealias ConstExpressionCaller = (args: List<IExpression>, position: Position, program: Program) -> NumericLiteralValue typealias ConstExpressionCaller = (args: List<Expression>, position: Position, program: Program) -> NumericLiteralValue
class FunctionSignature(val pure: Boolean, // does it have side effects? class FunctionSignature(val pure: Boolean, // does it have side effects?
@@ -117,9 +113,9 @@ val BuiltinFunctions = mapOf(
) )
fun builtinFunctionReturnType(function: String, args: List<IExpression>, program: Program): DataType? { fun builtinFunctionReturnType(function: String, args: List<Expression>, program: Program): DataType? {
fun datatypeFromIterableArg(arglist: IExpression): DataType { fun datatypeFromIterableArg(arglist: Expression): DataType {
if(arglist is ReferenceLiteralValue) { if(arglist is ReferenceLiteralValue) {
if(arglist.type== DataType.ARRAY_UB || arglist.type== DataType.ARRAY_UW || arglist.type== DataType.ARRAY_F) { if(arglist.type== DataType.ARRAY_UB || arglist.type== DataType.ARRAY_UW || arglist.type== DataType.ARRAY_F) {
val dt = arglist.array!!.map {it.inferType(program)} val dt = arglist.array!!.map {it.inferType(program)}
@@ -189,7 +185,7 @@ fun builtinFunctionReturnType(function: String, args: List<IExpression>, program
class NotConstArgumentException: AstException("not a const argument to a built-in function") class NotConstArgumentException: AstException("not a const argument to a built-in function")
private fun oneDoubleArg(args: List<IExpression>, position: Position, program: Program, function: (arg: Double)->Number): NumericLiteralValue { private fun oneDoubleArg(args: List<Expression>, position: Position, program: Program, function: (arg: Double)->Number): NumericLiteralValue {
if(args.size!=1) if(args.size!=1)
throw SyntaxError("built-in function requires one floating point argument", position) throw SyntaxError("built-in function requires one floating point argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -197,7 +193,7 @@ private fun oneDoubleArg(args: List<IExpression>, position: Position, program: P
return numericLiteral(function(float), args[0].position) return numericLiteral(function(float), args[0].position)
} }
private fun oneDoubleArgOutputWord(args: List<IExpression>, position: Position, program: Program, function: (arg: Double)->Number): NumericLiteralValue { private fun oneDoubleArgOutputWord(args: List<Expression>, position: Position, program: Program, function: (arg: Double)->Number): NumericLiteralValue {
if(args.size!=1) if(args.size!=1)
throw SyntaxError("built-in function requires one floating point argument", position) throw SyntaxError("built-in function requires one floating point argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -205,7 +201,7 @@ private fun oneDoubleArgOutputWord(args: List<IExpression>, position: Position,
return NumericLiteralValue(DataType.WORD, function(float).toInt(), args[0].position) return NumericLiteralValue(DataType.WORD, function(float).toInt(), args[0].position)
} }
private fun oneIntArgOutputInt(args: List<IExpression>, position: Position, program: Program, function: (arg: Int)->Number): NumericLiteralValue { private fun oneIntArgOutputInt(args: List<Expression>, position: Position, program: Program, function: (arg: Int)->Number): NumericLiteralValue {
if(args.size!=1) if(args.size!=1)
throw SyntaxError("built-in function requires one integer argument", position) throw SyntaxError("built-in function requires one integer argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -216,7 +212,7 @@ private fun oneIntArgOutputInt(args: List<IExpression>, position: Position, prog
return numericLiteral(function(integer).toInt(), args[0].position) return numericLiteral(function(integer).toInt(), args[0].position)
} }
private fun collectionArgNeverConst(args: List<IExpression>, position: Position): NumericLiteralValue { private fun collectionArgNeverConst(args: List<Expression>, position: Position): NumericLiteralValue {
if(args.size!=1) if(args.size!=1)
throw SyntaxError("builtin function requires one non-scalar argument", position) throw SyntaxError("builtin function requires one non-scalar argument", position)
@@ -224,7 +220,7 @@ private fun collectionArgNeverConst(args: List<IExpression>, position: Position)
throw NotConstArgumentException() throw NotConstArgumentException()
} }
private fun builtinAbs(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinAbs(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
// 1 arg, type = float or int, result type= isSameAs as argument type // 1 arg, type = float or int, result type= isSameAs as argument type
if(args.size!=1) if(args.size!=1)
throw SyntaxError("abs requires one numeric argument", position) throw SyntaxError("abs requires one numeric argument", position)
@@ -237,7 +233,7 @@ private fun builtinAbs(args: List<IExpression>, position: Position, program: Pro
} }
} }
private fun builtinStrlen(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinStrlen(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1) if (args.size != 1)
throw SyntaxError("strlen requires one argument", position) throw SyntaxError("strlen requires one argument", position)
val argument = args[0].constValue(program) ?: throw NotConstArgumentException() val argument = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -247,7 +243,7 @@ private fun builtinStrlen(args: List<IExpression>, position: Position, program:
throw NotConstArgumentException() // this function is not considering the string argument a constant throw NotConstArgumentException() // this function is not considering the string argument a constant
} }
private fun builtinLen(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinLen(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
// note: in some cases the length is > 255 and then we have to return a UWORD type instead of a UBYTE. // note: in some cases the length is > 255 and then we have to return a UWORD type instead of a UBYTE.
if(args.size!=1) if(args.size!=1)
throw SyntaxError("len requires one argument", position) throw SyntaxError("len requires one argument", position)
@@ -288,7 +284,7 @@ private fun builtinLen(args: List<IExpression>, position: Position, program: Pro
} }
private fun builtinMkword(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinMkword(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 2) if (args.size != 2)
throw SyntaxError("mkword requires lsb and msb arguments", position) throw SyntaxError("mkword requires lsb and msb arguments", position)
val constLsb = args[0].constValue(program) ?: throw NotConstArgumentException() val constLsb = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -297,7 +293,7 @@ private fun builtinMkword(args: List<IExpression>, position: Position, program:
return NumericLiteralValue(DataType.UWORD, result, position) return NumericLiteralValue(DataType.UWORD, result, position)
} }
private fun builtinSin8(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinSin8(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1) if (args.size != 1)
throw SyntaxError("sin8 requires one argument", position) throw SyntaxError("sin8 requires one argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -305,7 +301,7 @@ private fun builtinSin8(args: List<IExpression>, position: Position, program: Pr
return NumericLiteralValue(DataType.BYTE, (127.0 * sin(rad)).toShort(), position) return NumericLiteralValue(DataType.BYTE, (127.0 * sin(rad)).toShort(), position)
} }
private fun builtinSin8u(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinSin8u(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1) if (args.size != 1)
throw SyntaxError("sin8u requires one argument", position) throw SyntaxError("sin8u requires one argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -313,7 +309,7 @@ private fun builtinSin8u(args: List<IExpression>, position: Position, program: P
return NumericLiteralValue(DataType.UBYTE, (128.0 + 127.5 * sin(rad)).toShort(), position) return NumericLiteralValue(DataType.UBYTE, (128.0 + 127.5 * sin(rad)).toShort(), position)
} }
private fun builtinCos8(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinCos8(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1) if (args.size != 1)
throw SyntaxError("cos8 requires one argument", position) throw SyntaxError("cos8 requires one argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -321,7 +317,7 @@ private fun builtinCos8(args: List<IExpression>, position: Position, program: Pr
return NumericLiteralValue(DataType.BYTE, (127.0 * cos(rad)).toShort(), position) return NumericLiteralValue(DataType.BYTE, (127.0 * cos(rad)).toShort(), position)
} }
private fun builtinCos8u(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinCos8u(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1) if (args.size != 1)
throw SyntaxError("cos8u requires one argument", position) throw SyntaxError("cos8u requires one argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -329,7 +325,7 @@ private fun builtinCos8u(args: List<IExpression>, position: Position, program: P
return NumericLiteralValue(DataType.UBYTE, (128.0 + 127.5 * cos(rad)).toShort(), position) return NumericLiteralValue(DataType.UBYTE, (128.0 + 127.5 * cos(rad)).toShort(), position)
} }
private fun builtinSin16(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinSin16(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1) if (args.size != 1)
throw SyntaxError("sin16 requires one argument", position) throw SyntaxError("sin16 requires one argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -337,7 +333,7 @@ private fun builtinSin16(args: List<IExpression>, position: Position, program: P
return NumericLiteralValue(DataType.WORD, (32767.0 * sin(rad)).toInt(), position) return NumericLiteralValue(DataType.WORD, (32767.0 * sin(rad)).toInt(), position)
} }
private fun builtinSin16u(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinSin16u(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1) if (args.size != 1)
throw SyntaxError("sin16u requires one argument", position) throw SyntaxError("sin16u requires one argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -345,7 +341,7 @@ private fun builtinSin16u(args: List<IExpression>, position: Position, program:
return NumericLiteralValue(DataType.UWORD, (32768.0 + 32767.5 * sin(rad)).toInt(), position) return NumericLiteralValue(DataType.UWORD, (32768.0 + 32767.5 * sin(rad)).toInt(), position)
} }
private fun builtinCos16(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinCos16(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1) if (args.size != 1)
throw SyntaxError("cos16 requires one argument", position) throw SyntaxError("cos16 requires one argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
@@ -353,7 +349,7 @@ private fun builtinCos16(args: List<IExpression>, position: Position, program: P
return NumericLiteralValue(DataType.WORD, (32767.0 * cos(rad)).toInt(), position) return NumericLiteralValue(DataType.WORD, (32767.0 * cos(rad)).toInt(), position)
} }
private fun builtinCos16u(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue { private fun builtinCos16u(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1) if (args.size != 1)
throw SyntaxError("cos16u requires one argument", position) throw SyntaxError("cos16u requires one argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()

View File

@@ -1,6 +1,9 @@
package prog8.optimizer package prog8.optimizer
import prog8.ast.* import prog8.ast.INameScope
import prog8.ast.Module
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.DataType import prog8.ast.base.DataType
import prog8.ast.base.ParentSentinel import prog8.ast.base.ParentSentinel
import prog8.ast.base.VarDeclType import prog8.ast.base.VarDeclType
@@ -19,7 +22,7 @@ class CallGraph(private val program: Program): IAstVisitor {
val subroutinesCalling = mutableMapOf<INameScope, List<Subroutine>>().withDefault { mutableListOf() } val subroutinesCalling = mutableMapOf<INameScope, List<Subroutine>>().withDefault { mutableListOf() }
val subroutinesCalledBy = mutableMapOf<Subroutine, List<Node>>().withDefault { mutableListOf() } val subroutinesCalledBy = mutableMapOf<Subroutine, List<Node>>().withDefault { mutableListOf() }
// TODO add dataflow graph: what statements use what variables // TODO add dataflow graph: what statements use what variables
val usedSymbols = mutableSetOf<IStatement>() val usedSymbols = mutableSetOf<Statement>()
init { init {
visit(program) visit(program)
@@ -94,11 +97,11 @@ class CallGraph(private val program: Program): IAstVisitor {
super.visit(identifier) super.visit(identifier)
} }
private fun addNodeAndParentScopes(stmt: IStatement) { private fun addNodeAndParentScopes(stmt: Statement) {
usedSymbols.add(stmt) usedSymbols.add(stmt)
var node: Node=stmt var node: Node=stmt
do { do {
if(node is INameScope && node is IStatement) { if(node is INameScope && node is Statement) {
usedSymbols.add(node) usedSymbols.add(node)
} }
node=node.parent node=node.parent
@@ -176,7 +179,7 @@ class CallGraph(private val program: Program): IAstVisitor {
super.visit(inlineAssembly) super.visit(inlineAssembly)
} }
private fun scanAssemblyCode(asm: String, context: IStatement, scope: INameScope) { private fun scanAssemblyCode(asm: String, context: Statement, scope: INameScope) {
val asmJumpRx = Regex("""[\-+a-zA-Z0-9_ \t]+(jmp|jsr)[ \t]+(\S+).*""", RegexOption.IGNORE_CASE) val asmJumpRx = Regex("""[\-+a-zA-Z0-9_ \t]+(jmp|jsr)[ \t]+(\S+).*""", RegexOption.IGNORE_CASE)
val asmRefRx = Regex("""[\-+a-zA-Z0-9_ \t]+(...)[ \t]+(\S+).*""", RegexOption.IGNORE_CASE) val asmRefRx = Regex("""[\-+a-zA-Z0-9_ \t]+(...)[ \t]+(\S+).*""", RegexOption.IGNORE_CASE)
asm.lines().forEach { line -> asm.lines().forEach { line ->

View File

@@ -1,14 +1,14 @@
package prog8.optimizer package prog8.optimizer
import prog8.ast.IExpression
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.Expression
import prog8.ast.expressions.NumericLiteralValue import prog8.ast.expressions.NumericLiteralValue
import kotlin.math.pow import kotlin.math.pow
class ConstExprEvaluator { class ConstExprEvaluator {
fun evaluate(left: NumericLiteralValue, operator: String, right: NumericLiteralValue): IExpression { fun evaluate(left: NumericLiteralValue, operator: String, right: NumericLiteralValue): Expression {
return when(operator) { return when(operator) {
"+" -> plus(left, right) "+" -> plus(left, right)
"-" -> minus(left, right) "-" -> minus(left, right)
@@ -34,7 +34,7 @@ class ConstExprEvaluator {
} }
} }
private fun shiftedright(left: NumericLiteralValue, amount: NumericLiteralValue): IExpression { private fun shiftedright(left: NumericLiteralValue, amount: NumericLiteralValue): Expression {
if(left.type !in IntegerDatatypes || amount.type !in IntegerDatatypes) if(left.type !in IntegerDatatypes || amount.type !in IntegerDatatypes)
throw ExpressionError("cannot compute $left >> $amount", left.position) throw ExpressionError("cannot compute $left >> $amount", left.position)
val result = val result =
@@ -45,7 +45,7 @@ class ConstExprEvaluator {
return NumericLiteralValue(left.type, result, left.position) return NumericLiteralValue(left.type, result, left.position)
} }
private fun shiftedleft(left: NumericLiteralValue, amount: NumericLiteralValue): IExpression { private fun shiftedleft(left: NumericLiteralValue, amount: NumericLiteralValue): Expression {
if(left.type !in IntegerDatatypes || amount.type !in IntegerDatatypes) if(left.type !in IntegerDatatypes || amount.type !in IntegerDatatypes)
throw ExpressionError("cannot compute $left << $amount", left.position) throw ExpressionError("cannot compute $left << $amount", left.position)
val result = left.number.toInt().shl(amount.number.toInt()) val result = left.number.toInt().shl(amount.number.toInt())

View File

@@ -1,8 +1,6 @@
package prog8.optimizer package prog8.optimizer
import prog8.ast.IExpression
import prog8.ast.IFunctionCall import prog8.ast.IFunctionCall
import prog8.ast.IStatement
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
@@ -28,7 +26,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
} }
} }
override fun visit(decl: VarDecl): IStatement { override fun visit(decl: VarDecl): Statement {
// the initializer value can't refer to the variable itself (recursive definition) // the initializer value can't refer to the variable itself (recursive definition)
// TODO: use call tree for this? // TODO: use call tree for this?
if(decl.value?.referencesIdentifiers(decl.name) == true || decl.arraysize?.index?.referencesIdentifiers(decl.name) == true) { if(decl.value?.referencesIdentifiers(decl.name) == true || decl.arraysize?.index?.referencesIdentifiers(decl.name) == true) {
@@ -193,7 +191,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
/** /**
* replace identifiers that refer to const value, with the value itself (if it's a simple type) * replace identifiers that refer to const value, with the value itself (if it's a simple type)
*/ */
override fun visit(identifier: IdentifierReference): IExpression { override fun visit(identifier: IdentifierReference): Expression {
return try { return try {
val cval = identifier.constValue(program) ?: return identifier val cval = identifier.constValue(program) ?: return identifier
return when { return when {
@@ -211,7 +209,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
} }
} }
override fun visit(functionCall: FunctionCall): IExpression { override fun visit(functionCall: FunctionCall): Expression {
return try { return try {
super.visit(functionCall) super.visit(functionCall)
typeCastConstArguments(functionCall) typeCastConstArguments(functionCall)
@@ -222,7 +220,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
} }
} }
override fun visit(functionCallStatement: FunctionCallStatement): IStatement { override fun visit(functionCallStatement: FunctionCallStatement): Statement {
super.visit(functionCallStatement) super.visit(functionCallStatement)
typeCastConstArguments(functionCallStatement) typeCastConstArguments(functionCallStatement)
return functionCallStatement return functionCallStatement
@@ -246,7 +244,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
} }
} }
override fun visit(memread: DirectMemoryRead): IExpression { override fun visit(memread: DirectMemoryRead): Expression {
// @( &thing ) --> thing // @( &thing ) --> thing
val addrOf = memread.addressExpression as? AddressOf val addrOf = memread.addressExpression as? AddressOf
if(addrOf!=null) if(addrOf!=null)
@@ -259,7 +257,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
* Compile-time constant sub expressions will be evaluated on the spot. * Compile-time constant sub expressions will be evaluated on the spot.
* For instance, the expression for "- 4.5" will be optimized into the float literal -4.5 * For instance, the expression for "- 4.5" will be optimized into the float literal -4.5
*/ */
override fun visit(expr: PrefixExpression): IExpression { override fun visit(expr: PrefixExpression): Expression {
return try { return try {
super.visit(expr) super.visit(expr)
@@ -317,7 +315,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
* (X / c1) * c2 -> X / (c2/c1) * (X / c1) * c2 -> X / (c2/c1)
* (X + c1) - c2 -> X + (c1-c2) * (X + c1) - c2 -> X + (c1-c2)
*/ */
override fun visit(expr: BinaryExpression): IExpression { override fun visit(expr: BinaryExpression): Expression {
return try { return try {
super.visit(expr) super.visit(expr)
@@ -364,7 +362,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
leftIsConst: Boolean, leftIsConst: Boolean,
rightIsConst: Boolean, rightIsConst: Boolean,
subleftIsConst: Boolean, subleftIsConst: Boolean,
subrightIsConst: Boolean): IExpression subrightIsConst: Boolean): Expression
{ {
// @todo this implements only a small set of possible reorderings for now // @todo this implements only a small set of possible reorderings for now
if(expr.operator==subExpr.operator) { if(expr.operator==subExpr.operator) {
@@ -551,13 +549,13 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
} }
} }
override fun visit(forLoop: ForLoop): IStatement { override fun visit(forLoop: ForLoop): Statement {
fun adjustRangeDt(rangeFrom: NumericLiteralValue, targetDt: DataType, rangeTo: NumericLiteralValue, stepLiteral: NumericLiteralValue?, range: RangeExpr): RangeExpr { fun adjustRangeDt(rangeFrom: NumericLiteralValue, targetDt: DataType, rangeTo: NumericLiteralValue, stepLiteral: NumericLiteralValue?, range: RangeExpr): RangeExpr {
val newFrom = rangeFrom.cast(targetDt) val newFrom = rangeFrom.cast(targetDt)
val newTo = rangeTo.cast(targetDt) val newTo = rangeTo.cast(targetDt)
if (newFrom != null && newTo != null) { if (newFrom != null && newTo != null) {
val newStep: IExpression = val newStep: Expression =
if (stepLiteral != null) (stepLiteral.cast(targetDt) ?: stepLiteral) else range.step if (stepLiteral != null) (stepLiteral.cast(targetDt) ?: stepLiteral) else range.step
return RangeExpr(newFrom, newTo, newStep, range.position) return RangeExpr(newFrom, newTo, newStep, range.position)
} }
@@ -605,7 +603,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
return resultStmt return resultStmt
} }
override fun visit(refLiteral: ReferenceLiteralValue): IExpression { override fun visit(refLiteral: ReferenceLiteralValue): Expression {
val litval = super.visit(refLiteral) val litval = super.visit(refLiteral)
if(litval is ReferenceLiteralValue) { if(litval is ReferenceLiteralValue) {
if (litval.isString) { if (litval.isString) {
@@ -670,7 +668,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
return litval return litval
} }
override fun visit(assignment: Assignment): IStatement { override fun visit(assignment: Assignment): Statement {
super.visit(assignment) super.visit(assignment)
val lv = assignment.value as? NumericLiteralValue val lv = assignment.value as? NumericLiteralValue
if(lv!=null) { if(lv!=null) {

View File

@@ -1,7 +1,5 @@
package prog8.optimizer package prog8.optimizer
import prog8.ast.IExpression
import prog8.ast.IStatement
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.AstException import prog8.ast.base.AstException
import prog8.ast.base.DataType import prog8.ast.base.DataType
@@ -10,6 +8,7 @@ import prog8.ast.base.NumericDatatypes
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.statements.Assignment import prog8.ast.statements.Assignment
import prog8.ast.statements.Statement
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.log2 import kotlin.math.log2
@@ -23,13 +22,13 @@ import kotlin.math.log2
internal class SimplifyExpressions(private val program: Program) : IAstModifyingVisitor { internal class SimplifyExpressions(private val program: Program) : IAstModifyingVisitor {
var optimizationsDone: Int = 0 var optimizationsDone: Int = 0
override fun visit(assignment: Assignment): IStatement { override fun visit(assignment: Assignment): Statement {
if (assignment.aug_op != null) if (assignment.aug_op != null)
throw AstException("augmented assignments should have been converted to normal assignments before this optimizer") throw AstException("augmented assignments should have been converted to normal assignments before this optimizer")
return super.visit(assignment) return super.visit(assignment)
} }
override fun visit(memread: DirectMemoryRead): IExpression { override fun visit(memread: DirectMemoryRead): Expression {
// @( &thing ) --> thing // @( &thing ) --> thing
val addrOf = memread.addressExpression as? AddressOf val addrOf = memread.addressExpression as? AddressOf
if(addrOf!=null) if(addrOf!=null)
@@ -37,7 +36,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return super.visit(memread) return super.visit(memread)
} }
override fun visit(typecast: TypecastExpression): IExpression { override fun visit(typecast: TypecastExpression): Expression {
var tc = typecast var tc = typecast
// try to statically convert a literal value into one of the desired type // try to statically convert a literal value into one of the desired type
@@ -83,7 +82,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
} }
} }
override fun visit(expr: PrefixExpression): IExpression { override fun visit(expr: PrefixExpression): Expression {
if (expr.operator == "+") { if (expr.operator == "+") {
// +X --> X // +X --> X
optimizationsDone++ optimizationsDone++
@@ -130,7 +129,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return super.visit(expr) return super.visit(expr)
} }
override fun visit(expr: BinaryExpression): IExpression { override fun visit(expr: BinaryExpression): Expression {
super.visit(expr) super.visit(expr)
val leftVal = expr.left.constValue(program) val leftVal = expr.left.constValue(program)
val rightVal = expr.right.constValue(program) val rightVal = expr.right.constValue(program)
@@ -343,7 +342,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return expr return expr
} }
private fun determineY(x: IExpression, subBinExpr: BinaryExpression): IExpression? { private fun determineY(x: Expression, subBinExpr: BinaryExpression): Expression? {
return when { return when {
subBinExpr.left isSameAs x -> subBinExpr.right subBinExpr.left isSameAs x -> subBinExpr.right
subBinExpr.right isSameAs x -> subBinExpr.left subBinExpr.right isSameAs x -> subBinExpr.left
@@ -450,7 +449,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return ReorderedAssociativeBinaryExpr(expr, leftVal, expr.right.constValue(program)) return ReorderedAssociativeBinaryExpr(expr, leftVal, expr.right.constValue(program))
} }
private fun optimizeAdd(pexpr: BinaryExpression, pleftVal: NumericLiteralValue?, prightVal: NumericLiteralValue?): IExpression { private fun optimizeAdd(pexpr: BinaryExpression, pleftVal: NumericLiteralValue?, prightVal: NumericLiteralValue?): Expression {
if(pleftVal==null && prightVal==null) if(pleftVal==null && prightVal==null)
return pexpr return pexpr
@@ -471,7 +470,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return expr return expr
} }
private fun optimizeSub(expr: BinaryExpression, leftVal: NumericLiteralValue?, rightVal: NumericLiteralValue?): IExpression { private fun optimizeSub(expr: BinaryExpression, leftVal: NumericLiteralValue?, rightVal: NumericLiteralValue?): Expression {
if(leftVal==null && rightVal==null) if(leftVal==null && rightVal==null)
return expr return expr
@@ -500,7 +499,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return expr return expr
} }
private fun optimizePower(expr: BinaryExpression, leftVal: NumericLiteralValue?, rightVal: NumericLiteralValue?): IExpression { private fun optimizePower(expr: BinaryExpression, leftVal: NumericLiteralValue?, rightVal: NumericLiteralValue?): Expression {
if(leftVal==null && rightVal==null) if(leftVal==null && rightVal==null)
return expr return expr
@@ -580,7 +579,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return expr return expr
} }
private fun optimizeRemainder(expr: BinaryExpression, leftVal: NumericLiteralValue?, rightVal: NumericLiteralValue?): IExpression { private fun optimizeRemainder(expr: BinaryExpression, leftVal: NumericLiteralValue?, rightVal: NumericLiteralValue?): Expression {
if(leftVal==null && rightVal==null) if(leftVal==null && rightVal==null)
return expr return expr
@@ -604,7 +603,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
} }
private fun optimizeDivision(expr: BinaryExpression, leftVal: NumericLiteralValue?, rightVal: NumericLiteralValue?): IExpression { private fun optimizeDivision(expr: BinaryExpression, leftVal: NumericLiteralValue?, rightVal: NumericLiteralValue?): Expression {
if(leftVal==null && rightVal==null) if(leftVal==null && rightVal==null)
return expr return expr
@@ -676,14 +675,14 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return expr return expr
} }
private fun optimizeMultiplication(pexpr: BinaryExpression, pleftVal: NumericLiteralValue?, prightVal: NumericLiteralValue?): IExpression { private fun optimizeMultiplication(pexpr: BinaryExpression, pleftVal: NumericLiteralValue?, prightVal: NumericLiteralValue?): Expression {
if(pleftVal==null && prightVal==null) if(pleftVal==null && prightVal==null)
return pexpr return pexpr
val (expr, _, rightVal) = reorderAssociative(pexpr, pleftVal) val (expr, _, rightVal) = reorderAssociative(pexpr, pleftVal)
if(rightVal!=null) { if(rightVal!=null) {
// right value is a constant, see if we can optimize // right value is a constant, see if we can optimize
val leftValue: IExpression = expr.left val leftValue: Expression = expr.left
val rightConst: NumericLiteralValue = rightVal val rightConst: NumericLiteralValue = rightVal
when(val cv = rightConst.number.toDouble()) { when(val cv = rightConst.number.toDouble()) {
-1.0 -> { -1.0 -> {

View File

@@ -59,7 +59,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
val scope = caller.definingScope() val scope = caller.definingScope()
if(sub.calledBy.count { it.definingScope()===scope } > 1) if(sub.calledBy.count { it.definingScope()===scope } > 1)
return return
if(caller !is IFunctionCall || caller !is IStatement || sub.statements.any { it is Subroutine }) if(caller !is IFunctionCall || caller !is Statement || sub.statements.any { it is Subroutine })
return return
if(sub.parameters.isEmpty() && sub.returntypes.isEmpty()) { if(sub.parameters.isEmpty() && sub.returntypes.isEmpty()) {
@@ -136,7 +136,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
} }
} }
override fun visit(block: Block): IStatement { override fun visit(block: Block): Statement {
if("force_output" !in block.options()) { if("force_output" !in block.options()) {
if (block.containsNoCodeNorVars()) { if (block.containsNoCodeNorVars()) {
optimizationsDone++ optimizationsDone++
@@ -154,7 +154,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return super.visit(block) return super.visit(block)
} }
override fun visit(subroutine: Subroutine): IStatement { override fun visit(subroutine: Subroutine): Statement {
super.visit(subroutine) super.visit(subroutine)
val forceOutput = "force_output" in subroutine.definingBlock().options() val forceOutput = "force_output" in subroutine.definingBlock().options()
if(subroutine.asmAddress==null && !forceOutput) { if(subroutine.asmAddress==null && !forceOutput) {
@@ -191,7 +191,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return subroutine return subroutine
} }
override fun visit(decl: VarDecl): IStatement { override fun visit(decl: VarDecl): Statement {
val forceOutput = "force_output" in decl.definingBlock().options() val forceOutput = "force_output" in decl.definingBlock().options()
if(decl !in callgraph.usedSymbols && !forceOutput) { if(decl !in callgraph.usedSymbols && !forceOutput) {
if(decl.type == VarDeclType.VAR) if(decl.type == VarDeclType.VAR)
@@ -203,7 +203,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return super.visit(decl) return super.visit(decl)
} }
private fun deduplicateAssignments(statements: List<IStatement>): MutableList<Int> { private fun deduplicateAssignments(statements: List<Statement>): MutableList<Int> {
// removes 'duplicate' assignments that assign the isSameAs target // removes 'duplicate' assignments that assign the isSameAs target
val linesToRemove = mutableListOf<Int>() val linesToRemove = mutableListOf<Int>()
var previousAssignmentLine: Int? = null var previousAssignmentLine: Int? = null
@@ -228,7 +228,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return linesToRemove return linesToRemove
} }
override fun visit(functionCallStatement: FunctionCallStatement): IStatement { override fun visit(functionCallStatement: FunctionCallStatement): Statement {
if(functionCallStatement.target.nameInSource.size==1 && functionCallStatement.target.nameInSource[0] in BuiltinFunctions) { if(functionCallStatement.target.nameInSource.size==1 && functionCallStatement.target.nameInSource[0] in BuiltinFunctions) {
val functionName = functionCallStatement.target.nameInSource[0] val functionName = functionCallStatement.target.nameInSource[0]
if (functionName in pureBuiltinFunctions) { if (functionName in pureBuiltinFunctions) {
@@ -286,7 +286,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return super.visit(functionCallStatement) return super.visit(functionCallStatement)
} }
override fun visit(functionCall: FunctionCall): IExpression { override fun visit(functionCall: FunctionCall): Expression {
// if it calls a subroutine, // if it calls a subroutine,
// and the first instruction in the subroutine is a jump, call that jump target instead // and the first instruction in the subroutine is a jump, call that jump target instead
// if the first instruction in the subroutine is a return statement with constant value, replace with the constant value // if the first instruction in the subroutine is a return statement with constant value, replace with the constant value
@@ -306,7 +306,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return super.visit(functionCall) return super.visit(functionCall)
} }
override fun visit(ifStatement: IfStatement): IStatement { override fun visit(ifStatement: IfStatement): Statement {
super.visit(ifStatement) super.visit(ifStatement)
if(ifStatement.truepart.containsNoCodeNorVars() && ifStatement.elsepart.containsNoCodeNorVars()) { if(ifStatement.truepart.containsNoCodeNorVars() && ifStatement.elsepart.containsNoCodeNorVars()) {
@@ -340,7 +340,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return ifStatement return ifStatement
} }
override fun visit(forLoop: ForLoop): IStatement { override fun visit(forLoop: ForLoop): Statement {
super.visit(forLoop) super.visit(forLoop)
if(forLoop.body.containsNoCodeNorVars()) { if(forLoop.body.containsNoCodeNorVars()) {
// remove empty for loop // remove empty for loop
@@ -370,7 +370,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return forLoop return forLoop
} }
override fun visit(whileLoop: WhileLoop): IStatement { override fun visit(whileLoop: WhileLoop): Statement {
super.visit(whileLoop) super.visit(whileLoop)
val constvalue = whileLoop.condition.constValue(program) val constvalue = whileLoop.condition.constValue(program)
if(constvalue!=null) { if(constvalue!=null) {
@@ -396,7 +396,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return whileLoop return whileLoop
} }
override fun visit(repeatLoop: RepeatLoop): IStatement { override fun visit(repeatLoop: RepeatLoop): Statement {
super.visit(repeatLoop) super.visit(repeatLoop)
val constvalue = repeatLoop.untilCondition.constValue(program) val constvalue = repeatLoop.untilCondition.constValue(program)
if(constvalue!=null) { if(constvalue!=null) {
@@ -426,7 +426,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return repeatLoop return repeatLoop
} }
override fun visit(whenStatement: WhenStatement): IStatement { override fun visit(whenStatement: WhenStatement): Statement {
val choices = whenStatement.choices.toList() val choices = whenStatement.choices.toList()
for(choice in choices) { for(choice in choices) {
if(choice.statements.containsNoCodeNorVars()) if(choice.statements.containsNoCodeNorVars())
@@ -441,12 +441,12 @@ internal class StatementOptimizer(private val program: Program, private val opti
{ {
var count=0 var count=0
override fun visit(breakStmt: Break): IStatement { override fun visit(breakStmt: Break): Statement {
count++ count++
return super.visit(breakStmt) return super.visit(breakStmt)
} }
override fun visit(contStmt: Continue): IStatement { override fun visit(contStmt: Continue): Statement {
count++ count++
return super.visit(contStmt) return super.visit(contStmt)
} }
@@ -460,7 +460,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return s.count > 0 return s.count > 0
} }
override fun visit(jump: Jump): IStatement { override fun visit(jump: Jump): Statement {
val subroutine = jump.identifier?.targetSubroutine(program.namespace) val subroutine = jump.identifier?.targetSubroutine(program.namespace)
if(subroutine!=null) { if(subroutine!=null) {
// if the first instruction in the subroutine is another jump, shortcut this one // if the first instruction in the subroutine is another jump, shortcut this one
@@ -484,7 +484,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return jump return jump
} }
override fun visit(assignment: Assignment): IStatement { override fun visit(assignment: Assignment): Statement {
if(assignment.aug_op!=null) if(assignment.aug_op!=null)
throw AstException("augmented assignments should have been converted to normal assignments before this optimizer") throw AstException("augmented assignments should have been converted to normal assignments before this optimizer")
@@ -615,7 +615,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return super.visit(assignment) return super.visit(assignment)
} }
override fun visit(scope: AnonymousScope): IStatement { override fun visit(scope: AnonymousScope): Statement {
val linesToRemove = deduplicateAssignments(scope.statements) val linesToRemove = deduplicateAssignments(scope.statements)
if(linesToRemove.isNotEmpty()) { if(linesToRemove.isNotEmpty()) {
linesToRemove.reversed().forEach{scope.statements.removeAt(it)} linesToRemove.reversed().forEach{scope.statements.removeAt(it)}
@@ -623,7 +623,7 @@ internal class StatementOptimizer(private val program: Program, private val opti
return super.visit(scope) return super.visit(scope)
} }
override fun visit(label: Label): IStatement { override fun visit(label: Label): Statement {
// remove duplicate labels // remove duplicate labels
val stmts = label.definingScope().statements val stmts = label.definingScope().statements
val startIdx = stmts.indexOf(label) val startIdx = stmts.indexOf(label)
@@ -644,7 +644,7 @@ internal class FlattenAnonymousScopesAndRemoveNops: IAstVisitor {
super.visit(program) super.visit(program)
for(scope in scopesToFlatten.reversed()) { for(scope in scopesToFlatten.reversed()) {
val namescope = scope.parent as INameScope val namescope = scope.parent as INameScope
val idx = namescope.statements.indexOf(scope as IStatement) val idx = namescope.statements.indexOf(scope as Statement)
if(idx>=0) { if(idx>=0) {
val nop = NopStatement.insteadOf(namescope.statements[idx]) val nop = NopStatement.insteadOf(namescope.statements[idx])
nop.parent = namescope as Node nop.parent = namescope as Node

View File

@@ -1,10 +1,9 @@
package prog8.vm.astvm package prog8.vm.astvm
import prog8.ast.IExpression
import prog8.ast.INameScope import prog8.ast.INameScope
import prog8.ast.IStatement
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.Expression
import prog8.ast.expressions.IdentifierReference import prog8.ast.expressions.IdentifierReference
import prog8.ast.expressions.NumericLiteralValue import prog8.ast.expressions.NumericLiteralValue
import prog8.ast.statements.* import prog8.ast.statements.*
@@ -314,7 +313,7 @@ class AstVm(val program: Program) {
} }
private fun executeStatement(sub: INameScope, stmt: IStatement) { private fun executeStatement(sub: INameScope, stmt: Statement) {
instructionCounter++ instructionCounter++
if (instructionCounter % 200 == 0) if (instructionCounter % 200 == 0)
Thread.sleep(1) Thread.sleep(1)
@@ -545,7 +544,7 @@ class AstVm(val program: Program) {
performAssignment(target2, value1, swap, evalCtx) performAssignment(target2, value1, swap, evalCtx)
} }
fun performAssignment(target: AssignTarget, value: RuntimeValue, contextStmt: IStatement, evalCtx: EvalContext) { fun performAssignment(target: AssignTarget, value: RuntimeValue, contextStmt: Statement, evalCtx: EvalContext) {
when { when {
target.identifier != null -> { target.identifier != null -> {
val decl = contextStmt.definingScope().lookup(target.identifier.nameInSource, contextStmt) as? VarDecl val decl = contextStmt.definingScope().lookup(target.identifier.nameInSource, contextStmt) as? VarDecl
@@ -642,7 +641,7 @@ class AstVm(val program: Program) {
executeAnonymousScope(stmt.body) executeAnonymousScope(stmt.body)
} }
private fun evaluate(args: List<IExpression>) = args.map { evaluate(it, evalCtx) } private fun evaluate(args: List<Expression>) = args.map { evaluate(it, evalCtx) }
private fun performSyscall(sub: Subroutine, args: List<RuntimeValue>): RuntimeValue? { private fun performSyscall(sub: Subroutine, args: List<RuntimeValue>): RuntimeValue? {
var result: RuntimeValue? = null var result: RuntimeValue? = null

View File

@@ -1,6 +1,5 @@
package prog8.vm.astvm package prog8.vm.astvm
import prog8.ast.IExpression
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.ArrayElementTypes import prog8.ast.base.ArrayElementTypes
import prog8.ast.base.DataType import prog8.ast.base.DataType
@@ -25,7 +24,7 @@ class EvalContext(val program: Program, val mem: Memory, val statusflags: Status
val performBuiltinFunction: BuiltinfunctionCaller, val performBuiltinFunction: BuiltinfunctionCaller,
val executeSubroutine: SubroutineCaller) val executeSubroutine: SubroutineCaller)
fun evaluate(expr: IExpression, ctx: EvalContext): RuntimeValue { fun evaluate(expr: Expression, ctx: EvalContext): RuntimeValue {
val constval = expr.constValue(ctx.program) val constval = expr.constValue(ctx.program)
if(constval!=null) if(constval!=null)
return RuntimeValue.fromLv(constval) return RuntimeValue.fromLv(constval)

View File

@@ -1,6 +1,5 @@
package prog8.vm.astvm package prog8.vm.astvm
import prog8.ast.IStatement
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.DataType import prog8.ast.base.DataType
import prog8.ast.base.Position import prog8.ast.base.Position
@@ -9,6 +8,7 @@ import prog8.ast.base.VarDeclType
import prog8.ast.expressions.NumericLiteralValue import prog8.ast.expressions.NumericLiteralValue
import prog8.ast.expressions.ReferenceLiteralValue import prog8.ast.expressions.ReferenceLiteralValue
import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.statements.Statement
import prog8.ast.statements.StructDecl import prog8.ast.statements.StructDecl
import prog8.ast.statements.VarDecl import prog8.ast.statements.VarDecl
import prog8.ast.statements.ZeropageWish import prog8.ast.statements.ZeropageWish
@@ -40,7 +40,7 @@ class VariablesCreator(private val runtimeVariables: RuntimeVariables, private v
super.visit(program) super.visit(program)
} }
override fun visit(decl: VarDecl): IStatement { override fun visit(decl: VarDecl): Statement {
// if the decl is part of a struct, just skip it // if the decl is part of a struct, just skip it
if(decl.parent !is StructDecl) { if(decl.parent !is StructDecl) {
when (decl.type) { when (decl.type) {
@@ -67,7 +67,7 @@ class VariablesCreator(private val runtimeVariables: RuntimeVariables, private v
return super.visit(decl) return super.visit(decl)
} }
// override fun accept(assignment: Assignment): IStatement { // override fun accept(assignment: Assignment): Statement {
// if(assignment is VariableInitializationAssignment) { // if(assignment is VariableInitializationAssignment) {
// println("INIT VAR $assignment") // println("INIT VAR $assignment")
// } // }