diff --git a/compiler/examples/numbergame.p8 b/compiler/examples/numbergame.p8 index e989e8b5a..bd2a92c72 100644 --- a/compiler/examples/numbergame.p8 +++ b/compiler/examples/numbergame.p8 @@ -2,14 +2,15 @@ ~ main { sub start() -> () { - str name = "?" * 20 - str guess = "?" * 20 + str name = " " + str guess = "0000000000" byte secretnumber = 0 byte attempts_left = 10 ; greeting + _vm_write_str("Let's play a number guessing game!\n") _vm_write_str("Enter your name: ") - ; _vm_input_str(name) + _vm_input_str(name) _vm_write_char($8d) _vm_write_char($8d) _vm_write_str("Hello, ") @@ -17,8 +18,16 @@ _vm_write_char($2e) _vm_write_char($8d) + secretnumber = make_number() return + + sub make_number() -> (X) { + byte number + number = rnd() + return rnd() + return number + } ; ; create a secret random number from 1-100 ; c64.RNDA(0) ; fac = rnd(0) ; c64.MUL10() ; fac *= 10 diff --git a/compiler/examples/test.p8 b/compiler/examples/test.p8 index a1cce409c..2555673ad 100644 --- a/compiler/examples/test.p8 +++ b/compiler/examples/test.p8 @@ -16,59 +16,18 @@ sub start() -> () { byte[2,3] matrix1 = [1,2, 3,4, 5,6] byte[2,3] matrix2 = [1,2, 3,4, 5,6] byte[2,3] matrix3 = [11,22, 33,44, 55,66] - byte num1 = 99 - word num2 = 12345 + byte num1 = sin(2.0) + word num2 = rndw() float num3 = 98.555 + num1=rnd() - num1 = max(msg3) - _vm_write_str(num1) - _vm_write_char($8d) - num1 = min(msg3) - _vm_write_str(num1) - _vm_write_char($8d) + num1=thing() + return - num3 = avg(array3) - _vm_write_str(num3) - _vm_write_char($8d) - num1 = all(array3) - _vm_write_str(num1) - _vm_write_char($8d) - num1 = any(array3) - _vm_write_str(num1) - _vm_write_char($8d) - - num1 = max(array3) - _vm_write_str(num1) - _vm_write_char($8d) - num1 = min(array3) - _vm_write_str(num1) - _vm_write_char($8d) - num2 = max(array5) - _vm_write_str(num2) - _vm_write_char($8d) - num2 = min(array5) - _vm_write_str(num2) - _vm_write_char($8d) - num2 = sum(array3) - _vm_write_str(num2) - _vm_write_char($8d) - num2 = sum(array5) - _vm_write_str(num2) - _vm_write_char($8d) - - num3 = avg(msg3) - _vm_write_str(num3) - _vm_write_char($8d) - num2 = sum(msg3) - _vm_write_str(num2) - _vm_write_char($8d) - num1 = all(msg3) - _vm_write_str(num1) - _vm_write_char($8d) - num1 = any(msg3) - _vm_write_str(num1) - _vm_write_char($8d) + sub thing() -> (X) { + return 99 + } } } diff --git a/compiler/src/prog8/CompilerMain.kt b/compiler/src/prog8/CompilerMain.kt index ce31d1d1f..56e2deab5 100644 --- a/compiler/src/prog8/CompilerMain.kt +++ b/compiler/src/prog8/CompilerMain.kt @@ -51,13 +51,14 @@ fun main(args: Array) { options.any{ it.name=="enable_floats"}) - // perform syntax checks and optimizations - moduleAst.checkIdentifiers() - - println("Optimizing...") + // perform initial syntax checks and constant folding val heap = HeapValues() + moduleAst.checkIdentifiers() moduleAst.constantFold(namespace, heap) moduleAst.checkValid(namespace, compilerOptions, heap) // check if tree is valid + + // optimize the parse tree + println("Optimizing...") val allScopedSymbolDefinitions = moduleAst.checkIdentifiers() while(true) { // keep optimizing expressions and statements until no more steps remain diff --git a/compiler/src/prog8/ast/AST.kt b/compiler/src/prog8/ast/AST.kt index 4860f389e..08e9b04a2 100644 --- a/compiler/src/prog8/ast/AST.kt +++ b/compiler/src/prog8/ast/AST.kt @@ -4,7 +4,9 @@ import org.antlr.v4.runtime.ParserRuleContext import org.antlr.v4.runtime.tree.TerminalNode import prog8.compiler.HeapValues import prog8.compiler.target.c64.Petscii -import prog8.functions.* +import prog8.functions.BuiltinFunctions +import prog8.functions.NotConstArgumentException +import prog8.functions.builtinFunctionReturnType import prog8.parser.prog8Parser import java.nio.file.Paths import kotlin.math.abs @@ -367,8 +369,7 @@ object BuiltinFunctionScopePlaceholder : INameScope { override fun registerUsedName(name: String) = throw NotImplementedError("not implemented on sub-scopes") } -object BuiltinFunctionStatementPlaceholder : IStatement { - override val position = Position("<>", 0, 0, 0) +class BuiltinFunctionStatementPlaceholder(val name: String, override val position: Position) : IStatement { override var parent: Node = ParentSentinel override fun linkParents(parent: Node) {} override fun process(processor: IAstProcessor): IStatement = this @@ -405,7 +406,7 @@ private class GlobalNamespace(override val name: String, private val scopedNamesUsed: MutableSet = mutableSetOf("main", "main.start") // main and main.start are always used override fun lookup(scopedName: List, statement: Node): IStatement? { - if(BuiltinFunctionNames.contains(scopedName.last())) { + if(scopedName.last() in BuiltinFunctions) { // builtin functions always exist, return a dummy statement for them val builtinPlaceholder = Label("builtin::${scopedName.last()}", statement.position) builtinPlaceholder.parent = ParentSentinel @@ -428,7 +429,7 @@ private class GlobalNamespace(override val name: String, override fun registerUsedName(name: String) { // make sure to also register each scope separately scopedNamesUsed.add(name) - if(name.contains('.')) + if('.' in name) registerUsedName(name.substringBeforeLast('.')) } } @@ -922,7 +923,7 @@ class LiteralValue(val type: DataType, override fun resultingDatatype(namespace: INameScope, heap: HeapValues) = type - override fun isIterable(namespace: INameScope, heap: HeapValues): Boolean = IterableDatatypes.contains(type) + override fun isIterable(namespace: INameScope, heap: HeapValues): Boolean = type in IterableDatatypes override fun hashCode(): Int { val bh = bytevalue?.hashCode() ?: 0x10001234 @@ -1056,8 +1057,8 @@ data class IdentifierReference(val nameInSource: List, override val posi override lateinit var parent: Node fun targetStatement(namespace: INameScope) = - if(nameInSource.size==1 && BuiltinFunctionNames.contains(nameInSource[0])) - BuiltinFunctionStatementPlaceholder + if(nameInSource.size==1 && nameInSource[0] in BuiltinFunctions) + BuiltinFunctionStatementPlaceholder(nameInSource[0], position) else namespace.lookup(nameInSource, this) @@ -1093,7 +1094,7 @@ data class IdentifierReference(val nameInSource: List, override val posi } } - override fun isIterable(namespace: INameScope, heap: HeapValues): Boolean = IterableDatatypes.contains(resultingDatatype(namespace, heap)) + override fun isIterable(namespace: INameScope, heap: HeapValues): Boolean = resultingDatatype(namespace, heap) in IterableDatatypes } @@ -1149,37 +1150,16 @@ class FunctionCall(override var target: IdentifierReference, // if the function is a built-in function and the args are consts, should try to const-evaluate! if(target.nameInSource.size>1) return null try { - val resultValue = when (target.nameInSource[0]) { - "sin" -> builtinSin(arglist, position, namespace, heap) - "cos" -> builtinCos(arglist, position, namespace, heap) - "abs" -> builtinAbs(arglist, position, namespace, heap) - "acos" -> builtinAcos(arglist, position, namespace, heap) - "asin" -> builtinAsin(arglist, position, namespace, heap) - "tan" -> builtinTan(arglist, position, namespace, heap) - "atan" -> builtinAtan(arglist, position, namespace, heap) - "ln" -> builtinLn(arglist, position, namespace, heap) - "log2" -> builtinLog2(arglist, position, namespace, heap) - "log10" -> builtinLog10(arglist, position, namespace, heap) - "sqrt" -> builtinSqrt(arglist, position, namespace, heap) - "max" -> builtinMax(arglist, position, namespace, heap) - "min" -> builtinMin(arglist, position, namespace, heap) - "round" -> builtinRound(arglist, position, namespace, heap) - "rad" -> builtinRad(arglist, position, namespace, heap) - "deg" -> builtinDeg(arglist, position, namespace, heap) - "sum" -> builtinSum(arglist, position, namespace, heap) - "avg" -> builtinAvg(arglist, position, namespace, heap) - "len" -> builtinLen(arglist, position, namespace, heap) - "lsb" -> builtinLsb(arglist, position, namespace, heap) - "msb" -> builtinMsb(arglist, position, namespace, heap) - "flt" -> builtinFlt(arglist, position, namespace, heap) - "any" -> builtinAny(arglist, position, namespace, heap) - "all" -> builtinAll(arglist, position, namespace, heap) - "floor" -> builtinFloor(arglist, position, namespace, heap) - "ceil" -> builtinCeil(arglist, position, namespace, heap) - "lsl", "lsr", "rol", "rol2", "ror", "ror2", "set_carry", "clear_carry", "set_irqd", "clear_irqd" -> + var resultValue: LiteralValue? = null + val func = BuiltinFunctions[target.nameInSource[0]] + if(func!=null) { + val exprfunc = func.expressionFunc + if(exprfunc!=null) + resultValue = exprfunc(arglist, position, namespace, heap) + else if(func.returnvalues.isEmpty()) throw ExpressionError("builtin function ${target.nameInSource[0]} can't be used in expressions because it doesn't return a value", position) - else -> null } + if(withDatatypeCheck) { val resultDt = this.resultingDatatype(namespace, heap) if(resultValue==null || resultDt == resultValue.type) @@ -1234,12 +1214,10 @@ class FunctionCall(override var target: IdentifierReference, } is Label -> return null } - TODO("datatype of functioncall to $stmt") + return null // calling something we don't recognise... } - override fun isIterable(namespace: INameScope, heap: HeapValues) : Boolean { - TODO("isIterable of function call result") - } + override fun isIterable(namespace: INameScope, heap: HeapValues) = resultingDatatype(namespace, heap) in IterableDatatypes } diff --git a/compiler/src/prog8/ast/AstChecker.kt b/compiler/src/prog8/ast/AstChecker.kt index 223c0e8f2..daf52f815 100644 --- a/compiler/src/prog8/ast/AstChecker.kt +++ b/compiler/src/prog8/ast/AstChecker.kt @@ -2,7 +2,7 @@ package prog8.ast import prog8.compiler.CompilationOptions import prog8.compiler.HeapValues -import prog8.functions.BuiltinFunctionNames +import prog8.functions.BuiltinFunctions import prog8.parser.ParsingFailedError /** @@ -20,7 +20,7 @@ fun printErrors(errors: List, moduleName: String) { val reportedMessages = mutableSetOf() errors.forEach { val msg = it.toString() - if(!reportedMessages.contains(msg)) { + if(msg !in reportedMessages) { System.err.println(msg) reportedMessages.add(msg) } @@ -179,7 +179,7 @@ class AstChecker(private val namespace: INameScope, checkResult.add(SyntaxError(msg, subroutine.position)) } - if(BuiltinFunctionNames.contains(subroutine.name)) + if(subroutine.name in BuiltinFunctions) err("cannot redefine a built-in function") val uniqueNames = subroutine.parameters.asSequence().map { it.name }.toSet() @@ -204,8 +204,7 @@ class AstChecker(private val namespace: INameScope, .asSequence() .filter { it is InlineAssembly } .map {(it as InlineAssembly).assembly} - .count { it.contains(" rts") || it.contains("\trts") || - it.contains(" jmp") || it.contains("\tjmp")} + .count { "rts" in it || "\trts" in it || "jmp" in it || "\tjmp" in it } if(amount==0 && subroutine.returnvalues.isNotEmpty()) err("subroutine has result value(s) and thus must have at least one 'return' or 'goto' in it (or 'rts' / 'jmp' in case of %asm)") // @todo validate return values versus subroutine signature @@ -222,7 +221,7 @@ class AstChecker(private val namespace: INameScope, private fun checkSubroutinesPrecededByReturnOrJumpAndFollowedByLabelOrSub(statements: MutableList) { // @todo hmm, or move all the subroutines at the end of the block? (no fall-through execution) - var preceding: IStatement = BuiltinFunctionStatementPlaceholder + var preceding: IStatement = BuiltinFunctionStatementPlaceholder("dummy", Position("<>", 0, 0,0 )) var checkNext = false for (stmt in statements) { if(checkNext) { @@ -510,29 +509,34 @@ class AstChecker(private val namespace: INameScope, val targetStatement = checkFunctionOrLabelExists(functionCall.target, stmtOfExpression) if(targetStatement!=null) - checkBuiltinFunctionCall(functionCall, functionCall.position) - - if(targetStatement is Label && functionCall.arglist.isNotEmpty()) - checkResult.add(SyntaxError("cannot use arguments when calling a label", functionCall.position)) - - // todo check subroutine call parameters against signature - + checkFunctionCall(targetStatement, functionCall.arglist, functionCall.position) return super.process(functionCall) } override fun process(functionCall: FunctionCallStatement): IStatement { val targetStatement = checkFunctionOrLabelExists(functionCall.target, functionCall) if(targetStatement!=null) - checkBuiltinFunctionCall(functionCall, functionCall.position) - - if(targetStatement is Label && functionCall.arglist.isNotEmpty()) - checkResult.add(SyntaxError("cannot use arguments when calling a label", functionCall.position)) - - // todo check subroutine call parameters against signature - + checkFunctionCall(targetStatement, functionCall.arglist, functionCall.position) return super.process(functionCall) } + private fun checkFunctionCall(target: IStatement, args: List, position: Position) { + if(target is Label && args.isNotEmpty()) + checkResult.add(SyntaxError("cannot use arguments when calling a label", position)) + + if(target is BuiltinFunctionStatementPlaceholder) { + // it's a cal to a builtin function. + // todo make the signature checking similar to when calling a user-defined function + if(target.name=="set_carry" || target.name=="set_irqd" || target.name=="clear_carry" || target.name=="clear_irqd") { + // these functions have zero arguments + if(args.isNotEmpty()) + checkResult.add(SyntaxError("${target.name} has zero arguments", position)) + } + } else { + // @todo check call (params against signature) to user function + } + } + override fun process(postIncrDecr: PostIncrDecr): IStatement { if(postIncrDecr.target.register==null) { val targetName = postIncrDecr.target.identifier!!.nameInSource @@ -558,18 +562,6 @@ class AstChecker(private val namespace: INameScope, return null } - private fun checkBuiltinFunctionCall(call: IFunctionCall, position: Position) { - if(call.target.nameInSource.size==1 && BuiltinFunctionNames.contains(call.target.nameInSource[0])) { - val functionName = call.target.nameInSource[0] - if(functionName=="set_carry" || functionName=="set_irqd" || functionName=="clear_carry" || functionName=="clear_irqd") { - // these functions have zero arguments - if(call.arglist.isNotEmpty()) - checkResult.add(SyntaxError("$functionName has zero arguments", position)) - } - } - } - - private fun checkValueTypeAndRange(targetDt: DataType, arrayspec: ArraySpec?, range: RangeExpr) : Boolean { val from = range.from.constValue(namespace, heap) val to = range.to.constValue(namespace, heap) diff --git a/compiler/src/prog8/ast/AstIdentifiersChecker.kt b/compiler/src/prog8/ast/AstIdentifiersChecker.kt index 4f2c2e233..6a6963ea6 100644 --- a/compiler/src/prog8/ast/AstIdentifiersChecker.kt +++ b/compiler/src/prog8/ast/AstIdentifiersChecker.kt @@ -1,6 +1,6 @@ package prog8.ast -import prog8.functions.BuiltinFunctionNames +import prog8.functions.BuiltinFunctions /** * Checks the validity of all identifiers (no conflicts) @@ -47,7 +47,7 @@ class AstIdentifiersChecker : IAstProcessor { decl.datatypeErrors.forEach { checkResult.add(it) } // now check the identifier - if(BuiltinFunctionNames.contains(decl.name)) + if(decl.name in BuiltinFunctions) // the builtin functions can't be redefined checkResult.add(NameError("builtin function cannot be redefined", decl.position)) @@ -62,11 +62,11 @@ class AstIdentifiersChecker : IAstProcessor { } override fun process(subroutine: Subroutine): IStatement { - if(BuiltinFunctionNames.contains(subroutine.name)) { + if(subroutine.name in BuiltinFunctions) { // the builtin functions can't be redefined checkResult.add(NameError("builtin function cannot be redefined", subroutine.position)) } else { - if (subroutine.parameters.any { BuiltinFunctionNames.contains(it.name) }) + if (subroutine.parameters.any { it.name in BuiltinFunctions }) checkResult.add(NameError("builtin function name cannot be used as parameter", subroutine.position)) val scopedName = subroutine.scopedname @@ -80,7 +80,7 @@ class AstIdentifiersChecker : IAstProcessor { // check that there are no local variables that redefine the subroutine's parameters val definedNames = subroutine.labelsAndVariables() val paramNames = subroutine.parameters.map { it.name } - val definedNamesCorrespondingToParameters = definedNames.filter { paramNames.contains(it.key) } + val definedNamesCorrespondingToParameters = definedNames.filter { it.key in paramNames } for(name in definedNamesCorrespondingToParameters) { if(name.value.position != subroutine.position) nameError(name.key, name.value.position, subroutine) @@ -99,7 +99,7 @@ class AstIdentifiersChecker : IAstProcessor { } override fun process(label: Label): IStatement { - if(BuiltinFunctionNames.contains(label.name)) { + if(label.name in BuiltinFunctions) { // the builtin functions can't be redefined checkResult.add(NameError("builtin function cannot be redefined", label.position)) } else { diff --git a/compiler/src/prog8/ast/ImportedAstChecker.kt b/compiler/src/prog8/ast/ImportedAstChecker.kt index 1c7f8c390..4fe3e345a 100644 --- a/compiler/src/prog8/ast/ImportedAstChecker.kt +++ b/compiler/src/prog8/ast/ImportedAstChecker.kt @@ -31,7 +31,7 @@ class ImportedAstChecker : IAstProcessor { for (sourceStmt in module.statements) { val stmt = sourceStmt.process(this) if(stmt is Directive && stmt.parent is Module) { - if(moduleLevelDirectives.contains(stmt.directive)) { + if(stmt.directive in moduleLevelDirectives) { printWarning("ignoring module directive because it was imported", stmt.position, stmt.directive) continue } diff --git a/compiler/src/prog8/ast/StmtReorderer.kt b/compiler/src/prog8/ast/StmtReorderer.kt index bf34885b5..1314db9bf 100644 --- a/compiler/src/prog8/ast/StmtReorderer.kt +++ b/compiler/src/prog8/ast/StmtReorderer.kt @@ -18,7 +18,7 @@ class StatementReorderer: IAstProcessor { val varDecls = module.statements.filter { it is VarDecl } module.statements.removeAll(varDecls) module.statements.addAll(0, varDecls) - val directives = module.statements.filter {it is Directive && directivesToMove.contains(it.directive)} + val directives = module.statements.filter {it is Directive && it.directive in directivesToMove} module.statements.removeAll(directives) module.statements.addAll(0, directives) super.process(module) @@ -33,7 +33,7 @@ class StatementReorderer: IAstProcessor { val varDecls = block.statements.filter { it is VarDecl } block.statements.removeAll(varDecls) block.statements.addAll(0, varDecls) - val directives = block.statements.filter {it is Directive && directivesToMove.contains(it.directive)} + val directives = block.statements.filter {it is Directive && it.directive in directivesToMove} block.statements.removeAll(directives) block.statements.addAll(0, directives) return super.process(block) @@ -43,7 +43,7 @@ class StatementReorderer: IAstProcessor { val varDecls = subroutine.statements.filter { it is VarDecl } subroutine.statements.removeAll(varDecls) subroutine.statements.addAll(0, varDecls) - val directives = subroutine.statements.filter {it is Directive && directivesToMove.contains(it.directive)} + val directives = subroutine.statements.filter {it is Directive && it.directive in directivesToMove} subroutine.statements.removeAll(directives) subroutine.statements.addAll(0, directives) return super.process(subroutine) diff --git a/compiler/src/prog8/compiler/Zeropage.kt b/compiler/src/prog8/compiler/Zeropage.kt index fa7153077..a2358e65d 100644 --- a/compiler/src/prog8/compiler/Zeropage.kt +++ b/compiler/src/prog8/compiler/Zeropage.kt @@ -70,12 +70,6 @@ abstract class Zeropage(private val options: CompilationOptions) { return location } - private fun loneByte(location: Int): Boolean { - return free.contains(location) && !free.contains(location-1) && !free.contains(location+1) - } - - private fun sequentialFree(location: Int, size: Int): Boolean { - return free.containsAll((location until location+size).toList()) - } - + private fun loneByte(location: Int) = location in free && location-1 !in free && location+1 !in free + private fun sequentialFree(location: Int, size: Int) = free.containsAll((location until location+size).toList()) } diff --git a/compiler/src/prog8/compiler/target/c64/Commodore64.kt b/compiler/src/prog8/compiler/target/c64/Commodore64.kt index 74d1d6e69..9f89d5c50 100644 --- a/compiler/src/prog8/compiler/target/c64/Commodore64.kt +++ b/compiler/src/prog8/compiler/target/c64/Commodore64.kt @@ -43,10 +43,10 @@ class C64Zeropage(options: CompilationOptions) : Zeropage(options) { 0x12, 0x2a, 0x52, 0x94, 0x95, 0xa7, 0xa8, 0xa9, 0xaa, 0xb5, 0xb6, 0xf7, 0xf8, 0xf9, 0xfa)) } - assert(!free.contains(SCRATCH_B1)) - assert(!free.contains(SCRATCH_B2)) - assert(!free.contains(SCRATCH_W1)) - assert(!free.contains(SCRATCH_W2)) + assert(SCRATCH_B1 !in free) + assert(SCRATCH_B2 !in free) + assert(SCRATCH_W1 !in free) + assert(SCRATCH_W2 !in free) } } diff --git a/compiler/src/prog8/functions/BuiltinFunctions.kt b/compiler/src/prog8/functions/BuiltinFunctions.kt index 34c264a1a..6c91d1dd6 100644 --- a/compiler/src/prog8/functions/BuiltinFunctions.kt +++ b/compiler/src/prog8/functions/BuiltinFunctions.kt @@ -5,19 +5,86 @@ import prog8.compiler.HeapValues import kotlin.math.log2 -val BuiltinFunctionNames = setOf( - "set_carry", "clear_carry", "set_irqd", "clear_irqd", "rol", "ror", "rol2", "ror2", "lsl", "lsr", - "sin", "cos", "abs", "acos", "asin", "tan", "atan", "rnd", "rndw", "rndf", - "ln", "log2", "log10", "sqrt", "rad", "deg", "round", "floor", "ceil", - "max", "min", "avg", "sum", "len", "any", "all", "lsb", "msb", "flt", - "_vm_write_memchr", "_vm_write_memstr", "_vm_write_num", "_vm_write_char", - "_vm_write_str", "_vm_input_str", "_vm_gfx_clearscr", "_vm_gfx_pixel", "_vm_gfx_text" - ) +class FunctionSignature(val pure: Boolean, // does it have side effects? + val parameters: List, + val returnvalues: List, + val type: DataType?, + val expressionFunc: ((args: List, position: Position, namespace: INameScope, heap: HeapValues) -> LiteralValue)?) { + companion object { + private val dummyPos = Position("dummy", 0, 0, 0) -val BuiltinFunctionsWithoutSideEffects = BuiltinFunctionNames - setOf( - "set_carry", "clear_carry", "set_irqd", "clear_irqd", "lsl", "lsr", "rol", "ror", "rol2", "ror2", - "_vm_write_memchr", "_vm_write_memstr", "_vm_write_num", "_vm_write_char", - "_vm_write_str", "_vm_gfx_clearscr", "_vm_gfx_pixel", "_vm_gfx_text") + fun sig(pure: Boolean, + args: List, + hasReturnValue: Boolean, + type: DataType?, + expressionFunc: ((args: List, position: Position, namespace: INameScope, heap: HeapValues) -> LiteralValue)? = null + ) : FunctionSignature { + if(!hasReturnValue && expressionFunc!=null) + throw IllegalArgumentException("can't have expression func when hasReturnValue is false") + return FunctionSignature(pure, + args.map { SubroutineParameter(it, null, null, dummyPos) }, + if(hasReturnValue) + listOf(SubroutineReturnvalue(null, null, false, dummyPos)) + else + emptyList(), + type, + expressionFunc + ) + } + } +} + + +val BuiltinFunctions = mapOf( + "set_carry" to FunctionSignature.sig(false, emptyList(), false, null), + "clear_carry" to FunctionSignature.sig(false, emptyList(), false, null), + "set_irqd" to FunctionSignature.sig(false, emptyList(), false, null), + "clear_irqd" to FunctionSignature.sig(false, emptyList(), false, null), + "rol" to FunctionSignature.sig(false, listOf("item"), false, null), + "ror" to FunctionSignature.sig(false, listOf("item"), false, null), + "rol2" to FunctionSignature.sig(false, listOf("item"), false, null), + "ror2" to FunctionSignature.sig(false, listOf("item"), false, null), + "lsl" to FunctionSignature.sig(false, listOf("item"), false, null), + "lsr" to FunctionSignature.sig(false, listOf("item"), false, null), + "sin" to FunctionSignature.sig(true, listOf("rads"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::sin) }, + "cos" to FunctionSignature.sig(true, listOf("rads"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::cos) }, + "acos" to FunctionSignature.sig(true, listOf("rads"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::acos) }, + "asin" to FunctionSignature.sig(true, listOf("rads"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::asin) }, + "tan" to FunctionSignature.sig(true, listOf("rads"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::tan) }, + "atan" to FunctionSignature.sig(true, listOf("rads"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::atan) }, + "rnd" to FunctionSignature.sig(true, emptyList(), true, DataType.BYTE), + "rndw" to FunctionSignature.sig(true, emptyList(), true, DataType.WORD), + "rndf" to FunctionSignature.sig(true, emptyList(), true, DataType.FLOAT), + "ln" to FunctionSignature.sig(true, listOf("value"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::log) }, + "log2" to FunctionSignature.sig(true, listOf("value"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, ::log2) }, + "log10" to FunctionSignature.sig(true, listOf("value"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::log10) }, + "sqrt" to FunctionSignature.sig(true, listOf("value"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::sqrt) }, + "rad" to FunctionSignature.sig(true, listOf("value"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::toRadians) }, + "deg" to FunctionSignature.sig(true, listOf("value"), true, DataType.FLOAT) { a, p, n, h -> oneDoubleArg(a, p, n, h, Math::toDegrees) }, + "avg" to FunctionSignature.sig(true, listOf("values"), true, DataType.FLOAT, ::builtinAvg), + "abs" to FunctionSignature.sig(true, listOf("value"), true, null, ::builtinAbs), // type depends on arg + "round" to FunctionSignature.sig(true, listOf("value"), true, null) { a, p, n, h -> oneDoubleArgOutputInt(a, p, n, h, Math::round) }, // type depends on arg + "floor" to FunctionSignature.sig(true, listOf("value"), true, null) { a, p, n, h -> oneDoubleArgOutputInt(a, p, n, h, Math::floor) }, // type depends on arg + "ceil" to FunctionSignature.sig(true, listOf("value"), true, null) { a, p, n, h -> oneDoubleArgOutputInt(a, p, n, h, Math::ceil) }, // type depends on arg + "max" to FunctionSignature.sig(true, listOf("values"), true, null) { a, p, n, h -> collectionArgOutputNumber(a, p, n, h) { it.max()!! }}, // type depends on args + "min" to FunctionSignature.sig(true, listOf("values"), true, null) { a, p, n, h -> collectionArgOutputNumber(a, p, n, h) { it.min()!! }}, // type depends on args + "sum" to FunctionSignature.sig(true, listOf("values"), true, null) { a, p, n, h -> collectionArgOutputNumber(a, p, n, h) { it.sum() }}, // type depends on args + "len" to FunctionSignature.sig(true, listOf("values"), true, null, ::builtinLen), // type depends on args + "any" to FunctionSignature.sig(true, listOf("values"), true, DataType.BYTE) { a, p, n, h -> collectionArgOutputBoolean(a, p, n, h) { it.any { v -> v != 0.0} }}, + "all" to FunctionSignature.sig(true, listOf("values"), true, DataType.BYTE) { a, p, n, h -> collectionArgOutputBoolean(a, p, n, h) { it.all { v -> v != 0.0} }}, + "lsb" to FunctionSignature.sig(true, listOf("value"), true, DataType.BYTE) { a, p, n, h -> oneIntArgOutputInt(a, p, n, h) { x: Int -> x and 255 }}, + "msb" to FunctionSignature.sig(true, listOf("value"), true, DataType.BYTE) { a, p, n, h -> oneIntArgOutputInt(a, p, n, h) { x: Int -> x ushr 8 and 255}}, + "flt" to FunctionSignature.sig(true, listOf("value"), true, DataType.FLOAT, ::builtinFlt), + "_vm_write_memchr" to FunctionSignature.sig(false, emptyList(), false, null), + "_vm_write_memstr" to FunctionSignature.sig(false, emptyList(), false, null), + "_vm_write_num" to FunctionSignature.sig(false, emptyList(), false, null), + "_vm_write_char" to FunctionSignature.sig(false, emptyList(), false, null), + "_vm_write_str" to FunctionSignature.sig(false, emptyList(), false, null), + "_vm_input_str" to FunctionSignature.sig(false, emptyList(), false, null), + "_vm_gfx_clearscr" to FunctionSignature.sig(false, emptyList(), false, null), + "_vm_gfx_pixel" to FunctionSignature.sig(false, emptyList(), false, null), + "_vm_gfx_text" to FunctionSignature.sig(false, emptyList(), false, null) +) fun builtinFunctionReturnType(function: String, args: List, namespace: INameScope, heap: HeapValues): DataType? { @@ -57,12 +124,14 @@ fun builtinFunctionReturnType(function: String, args: List, namespa throw FatalAstException("function requires one argument which is an array $function") } + val func = BuiltinFunctions[function]!! + if(func.returnvalues.isEmpty()) + return null + if(func.type!=null) + return func.type + // function has return values, but the return type depends on the arguments + return when (function) { - "sin", "cos", "tan", "asin", "acos", "atan", "ln", "log2", "log10", - "sqrt", "rad", "deg", "avg", "rndf", "flt" -> DataType.FLOAT - "lsb", "msb", "any", "all", "rnd" -> DataType.BYTE - "rndw" -> DataType.WORD - "rol", "rol2", "ror", "ror2", "lsl", "lsr", "set_carry", "clear_carry", "set_irqd", "clear_irqd" -> null // no return value so no datatype "abs" -> args.single().resultingDatatype(namespace, heap) "max", "min" -> { val dt = datatypeFromListArg(args.single()) @@ -106,10 +175,7 @@ fun builtinFunctionReturnType(function: String, args: List, namespa else -> DataType.WORD } } - "_vm_write_memchr", "_vm_write_memstr", "_vm_write_num", "_vm_write_char", - "_vm_write_str", "_vm_gfx_clearscr", "_vm_gfx_pixel", "_vm_gfx_text" -> null // no return value for these - "_vm_input_str" -> DataType.STR - else -> throw FatalAstException("invalid builtin function $function") + else -> throw FatalAstException("unknown result type for builtin function $function") } } @@ -155,11 +221,11 @@ private fun collectionArgOutputNumber(args: List, position: Positio function: (arg: Collection)->Number): LiteralValue { if(args.size!=1) throw SyntaxError("builtin function requires one non-scalar argument", position) - var iterable = args[0].constValue(namespace, heap) ?: throw NotConstArgumentException() + val iterable = args[0].constValue(namespace, heap) ?: throw NotConstArgumentException() val result = if(iterable.arrayvalue != null) { - val constants = iterable.arrayvalue!!.map { it.constValue(namespace, heap)?.asNumericValue } - if(constants.contains(null)) + val constants = iterable.arrayvalue.map { it.constValue(namespace, heap)?.asNumericValue } + if(null in constants) throw NotConstArgumentException() function(constants.map { it!!.toDouble() }).toDouble() } else { @@ -174,11 +240,11 @@ private fun collectionArgOutputBoolean(args: List, position: Positi function: (arg: Collection)->Boolean): LiteralValue { if(args.size!=1) throw SyntaxError("builtin function requires one non-scalar argument", position) - var iterable = args[0].constValue(namespace, heap) ?: throw NotConstArgumentException() + val iterable = args[0].constValue(namespace, heap) ?: throw NotConstArgumentException() val result = if(iterable.arrayvalue != null) { - val constants = iterable.arrayvalue!!.map { it.constValue(namespace, heap)?.asNumericValue } - if(constants.contains(null)) + val constants = iterable.arrayvalue.map { it.constValue(namespace, heap)?.asNumericValue } + if(null in constants) throw NotConstArgumentException() function(constants.map { it!!.toDouble() }) } else { @@ -188,52 +254,7 @@ private fun collectionArgOutputBoolean(args: List, position: Positi return LiteralValue.fromBoolean(result, position) } -fun builtinRound(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArgOutputInt(args, position, namespace, heap, Math::round) - -fun builtinFloor(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArgOutputInt(args, position, namespace, heap, Math::floor) - -fun builtinCeil(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArgOutputInt(args, position, namespace, heap, Math::ceil) - -fun builtinSin(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::sin) - -fun builtinCos(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::cos) - -fun builtinAcos(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::acos) - -fun builtinAsin(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::asin) - -fun builtinTan(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::tan) - -fun builtinAtan(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::atan) - -fun builtinLn(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::log) - -fun builtinLog2(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, ::log2) - -fun builtinLog10(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::log10) - -fun builtinSqrt(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::sqrt) - -fun builtinRad(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::toRadians) - -fun builtinDeg(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneDoubleArg(args, position, namespace, heap, Math::toDegrees) - -fun builtinFlt(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue { +private fun builtinFlt(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue { // 1 numeric arg, convert to float if(args.size!=1) throw SyntaxError("flt requires one numeric argument", position) @@ -243,7 +264,7 @@ fun builtinFlt(args: List, position: Position, namespace:INameScope return LiteralValue(DataType.FLOAT, floatvalue = number.toDouble(), position = position) } -fun builtinAbs(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue { +private fun builtinAbs(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue { // 1 arg, type = float or int, result type= same as argument type if(args.size!=1) throw SyntaxError("abs requires one numeric argument", position) @@ -257,30 +278,14 @@ fun builtinAbs(args: List, position: Position, namespace:INameScope } } - -fun builtinLsb(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneIntArgOutputInt(args, position, namespace, heap) { x: Int -> x and 255 } - -fun builtinMsb(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = oneIntArgOutputInt(args, position, namespace, heap) { x: Int -> x ushr 8 and 255} - -fun builtinMin(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = collectionArgOutputNumber(args, position, namespace, heap) { it.min()!! } - -fun builtinMax(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = collectionArgOutputNumber(args, position, namespace, heap) { it.max()!! } - -fun builtinSum(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = collectionArgOutputNumber(args, position, namespace, heap) { it.sum() } - -fun builtinAvg(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue { +private fun builtinAvg(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue { if(args.size!=1) throw SyntaxError("avg requires array/matrix argument", position) - var iterable = args[0].constValue(namespace, heap) ?: throw NotConstArgumentException() + val iterable = args[0].constValue(namespace, heap) ?: throw NotConstArgumentException() val result = if(iterable.arrayvalue!=null) { - val constants = iterable.arrayvalue!!.map { it.constValue(namespace, heap)?.asNumericValue } - if (constants.contains(null)) + val constants = iterable.arrayvalue.map { it.constValue(namespace, heap)?.asNumericValue } + if (null in constants) throw NotConstArgumentException() (constants.map { it!!.toDouble() }).average() } @@ -291,7 +296,7 @@ fun builtinAvg(args: List, position: Position, namespace:INameScope return numericLiteral(result, args[0].position) } -fun builtinLen(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue { +private fun builtinLen(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue { if(args.size!=1) throw SyntaxError("len requires one argument", position) var argument = args[0].constValue(namespace, heap) @@ -314,13 +319,6 @@ fun builtinLen(args: List, position: Position, namespace:INameScope } } -fun builtinAny(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = collectionArgOutputBoolean(args, position, namespace, heap) { it.any { v -> v != 0.0} } - -fun builtinAll(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue - = collectionArgOutputBoolean(args, position, namespace, heap) { it.all { v -> v != 0.0} } - - private fun numericLiteral(value: Number, position: Position): LiteralValue { val floatNum=value.toDouble() val tweakedValue: Number = diff --git a/compiler/src/prog8/optimizing/ConstantFolding.kt b/compiler/src/prog8/optimizing/ConstantFolding.kt index b7b43563f..b7c44d7e8 100644 --- a/compiler/src/prog8/optimizing/ConstantFolding.kt +++ b/compiler/src/prog8/optimizing/ConstantFolding.kt @@ -13,7 +13,7 @@ class ConstantFolding(private val namespace: INameScope, private val heap: HeapV fun addError(x: AstException) { // check that we don't add the same error more than once - if(!reportedErrorMessages.contains(x.toString())) { + if(x.toString() !in reportedErrorMessages) { reportedErrorMessages.add(x.toString()) errors.add(x) } diff --git a/compiler/src/prog8/optimizing/SimplifyExpressions.kt b/compiler/src/prog8/optimizing/SimplifyExpressions.kt index 50d7b2a1d..b7a1b77a4 100644 --- a/compiler/src/prog8/optimizing/SimplifyExpressions.kt +++ b/compiler/src/prog8/optimizing/SimplifyExpressions.kt @@ -116,7 +116,7 @@ class SimplifyExpressions(private val namespace: INameScope, private val heap: H private data class ReorderedAssociativeBinaryExpr(val expr: BinaryExpression, val leftVal: LiteralValue?, val rightVal: LiteralValue?) private fun reorderAssociative(expr: BinaryExpression, leftVal: LiteralValue?): ReorderedAssociativeBinaryExpr { - if(associativeOperators.contains(expr.operator) && leftVal!=null) { + if(expr.operator in associativeOperators && leftVal!=null) { // swap left and right so that right is always the constant val tmp = expr.left expr.left = expr.right diff --git a/compiler/src/prog8/optimizing/StatementOptimizer.kt b/compiler/src/prog8/optimizing/StatementOptimizer.kt index 256981241..b7533d5ed 100644 --- a/compiler/src/prog8/optimizing/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizing/StatementOptimizer.kt @@ -2,8 +2,7 @@ package prog8.optimizing import prog8.ast.* import prog8.compiler.HeapValues -import prog8.functions.BuiltinFunctionNames -import prog8.functions.BuiltinFunctionsWithoutSideEffects +import prog8.functions.BuiltinFunctions /* @@ -30,11 +29,12 @@ class StatementOptimizer(private val globalNamespace: INameScope, private val he private set private var statementsToRemove = mutableListOf() + private val pureBuiltinFunctions = BuiltinFunctions.filter { it.value.pure } override fun process(functionCall: FunctionCallStatement): IStatement { - if(functionCall.target.nameInSource.size==1 && BuiltinFunctionNames.contains(functionCall.target.nameInSource[0])) { + if(functionCall.target.nameInSource.size==1 && functionCall.target.nameInSource[0] in BuiltinFunctions) { val functionName = functionCall.target.nameInSource[0] - if (BuiltinFunctionsWithoutSideEffects.contains(functionName)) { + if (functionName in pureBuiltinFunctions) { printWarning("statement has no effect (function return value is discarded)", functionCall.position) statementsToRemove.add(functionCall) } diff --git a/compiler/src/prog8/stackvm/Program.kt b/compiler/src/prog8/stackvm/Program.kt index 39612579c..72029d44e 100644 --- a/compiler/src/prog8/stackvm/Program.kt +++ b/compiler/src/prog8/stackvm/Program.kt @@ -63,7 +63,7 @@ class Program (val name: String, DataType.ARRAY_W, DataType.MATRIX -> { val numbers = it.third.substring(1, it.third.length-1).split(',') - val intarray = numbers.map{it.trim().toInt()}.toIntArray() + val intarray = numbers.map{number->number.trim().toInt()}.toIntArray() heap.add(it.second, intarray) } else -> throw VmExecutionException("invalid heap value type $it.second") @@ -151,7 +151,7 @@ class Program (val name: String, if(line=="%end_variables") return vars val (name, typeStr, valueStr) = line.split(splitpattern, limit = 3) - if(valueStr[0] !='"' && !valueStr.contains(':')) + if(valueStr[0] !='"' && ':' !in valueStr) throw VmExecutionException("missing value type character") val type = DataType.valueOf(typeStr.toUpperCase()) val value = when(type) { diff --git a/compiler/src/prog8/stackvm/StackVm.kt b/compiler/src/prog8/stackvm/StackVm.kt index 8e3ee8230..8386e1b60 100644 --- a/compiler/src/prog8/stackvm/StackVm.kt +++ b/compiler/src/prog8/stackvm/StackVm.kt @@ -184,7 +184,7 @@ open class Instruction(val opcode: Opcode, val syscall = Syscall.values().find { it.callNr==arg!!.numericValue() } "syscall $syscall" } - opcodesWithVarArgument.contains(opcode) -> { + opcode in opcodesWithVarArgument -> { // opcodes that manipulate a variable "${opcode.toString().toLowerCase()} $callLabel" } @@ -253,12 +253,8 @@ class StackVm(private var traceOutputFile: String?) { for(variable in program.variables.flatMap { e->e.value.entries }) variables[variable.key] = variable.value - if(variables.contains("A") || - variables.contains("X") || - variables.contains("Y") || - variables.contains("XY") || - variables.contains("AX") || - variables.contains("AY")) + if("A" in variables || "X" in variables || "Y" in variables || + "XY" in variables || "AX" in variables ||"AY" in variables) throw VmExecutionException("program contains variable(s) for the reserved registers A,X,...") // define the 'registers' variables["A"] = Value(DataType.BYTE, 0) diff --git a/compiler/test/UnitTests.kt b/compiler/test/UnitTests.kt index ab657c108..5576dbdcc 100644 --- a/compiler/test/UnitTests.kt +++ b/compiler/test/UnitTests.kt @@ -178,7 +178,7 @@ class TestZeropage { assertEquals(239, zp.available()) val loc = zp.allocate(VarDecl(VarDeclType.VAR, DataType.FLOAT, null, "", null, dummypos)) assertTrue(loc > 3) - assertFalse(zp.free.contains(loc)) + assertFalse(loc in zp.free) val num = zp.available() / 5 val rest = zp.available() % 5