diff --git a/compiler/src/prog8/ast/base/Extensions.kt b/compiler/src/prog8/ast/base/Extensions.kt index 00c4b9dcf..f1285926c 100644 --- a/compiler/src/prog8/ast/base/Extensions.kt +++ b/compiler/src/prog8/ast/base/Extensions.kt @@ -38,6 +38,12 @@ internal fun Program.addTypecasts(errors: ErrorReporter) { caster.applyModifications() } +internal fun Program.simplifyNumericCasts() { + val fixer = TypecastsSimplifier(this) + fixer.visit(this) + fixer.applyModifications() +} + internal fun Program.transformAssignments(errors: ErrorReporter) { val transform = AssignmentTransformer(this, errors) transform.visit(this) diff --git a/compiler/src/prog8/ast/processing/TypecastsAdder.kt b/compiler/src/prog8/ast/processing/TypecastsAdder.kt index 967044771..3f5fe465f 100644 --- a/compiler/src/prog8/ast/processing/TypecastsAdder.kt +++ b/compiler/src/prog8/ast/processing/TypecastsAdder.kt @@ -7,6 +7,7 @@ import prog8.ast.Program import prog8.ast.base.* import prog8.ast.expressions.* import prog8.ast.statements.* +import prog8.compiler.CompilerException import prog8.functions.BuiltinFunctions @@ -125,21 +126,18 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke } is BuiltinFunctionStatementPlaceholder -> { val func = BuiltinFunctions.getValue(sub.name) - if(func.pure) { - // non-pure functions don't get automatic typecasts because sometimes they act directly on their parameters - for (arg in func.parameters.zip(call.args.withIndex())) { - val argItype = arg.second.value.inferType(program) - if (argItype.isKnown) { - val argtype = argItype.typeOrElse(DataType.STRUCT) - if (arg.first.possibleDatatypes.any { argtype == it }) - continue - for (possibleType in arg.first.possibleDatatypes) { - if (argtype isAssignableTo possibleType) { - modifications += IAstModification.ReplaceNode( - call.args[arg.second.index], - TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position), - call as Node) - } + for (arg in func.parameters.zip(call.args.withIndex())) { + val argItype = arg.second.value.inferType(program) + if (argItype.isKnown) { + val argtype = argItype.typeOrElse(DataType.STRUCT) + if (arg.first.possibleDatatypes.any { argtype == it }) + continue + for (possibleType in arg.first.possibleDatatypes) { + if (argtype isAssignableTo possibleType) { + modifications += IAstModification.ReplaceNode( + call.args[arg.second.index], + TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position), + call as Node) } } } @@ -249,3 +247,52 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke return noModifications } } + + + +class TypecastsSimplifier(val program: Program) : AstWalker() { + /* + * Typecasts of a numeric literal value can be replaced by the numeric value of the type directly. + */ + + private val noModifications = emptyList() + + override fun before(typecast: TypecastExpression, parent: Node): Iterable { + if(typecast.expression is NumericLiteralValue) { + val value = (typecast.expression as NumericLiteralValue).cast(typecast.type) + return listOf(IAstModification.ReplaceNode(typecast, value, parent)) + } + + return noModifications + } + + override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable + = checkCallArgTypes(functionCallStatement as IFunctionCall, functionCallStatement.definingScope()) + + override fun after(functionCall: FunctionCall, parent: Node): Iterable + = checkCallArgTypes(functionCall as IFunctionCall, functionCall.definingScope()) + + private fun checkCallArgTypes(call: IFunctionCall, scope: INameScope): Iterable { + val argtypes = call.args.map { it.inferType(program).typeOrElse(DataType.STRUCT) } + val target = call.target.targetStatement(scope) + when(target) { + is Subroutine -> { + val paramtypes = target.parameters.map { it.type } + if(argtypes!=paramtypes) + throw CompilerException("parameter type mismatch $call") + } + is BuiltinFunctionStatementPlaceholder -> { + val func = BuiltinFunctions.getValue(target.name) + val paramtypes = func.parameters.map { it.possibleDatatypes } + for(x in argtypes.zip(paramtypes)) { + if(x.first !in x.second) + throw CompilerException("parameter type mismatch $call") + } + } + else -> {} + } + println("**** $target") + return noModifications + } + +} diff --git a/compiler/src/prog8/compiler/Main.kt b/compiler/src/prog8/compiler/Main.kt index c732300bb..c39317dfa 100644 --- a/compiler/src/prog8/compiler/Main.kt +++ b/compiler/src/prog8/compiler/Main.kt @@ -180,6 +180,7 @@ private fun postprocessAst(programAst: Program, errors: ErrorReporter, compilerO errors.handle() programAst.addTypecasts(errors) errors.handle() + programAst.simplifyNumericCasts() programAst.removeNopsFlattenAnonScopes() programAst.checkValid(compilerOptions, errors) // check if final tree is still valid errors.handle() diff --git a/examples/tehtriz.p8 b/examples/tehtriz.p8 index fab634674..e0538edb1 100644 --- a/examples/tehtriz.p8 +++ b/examples/tehtriz.p8 @@ -9,8 +9,7 @@ -; TODO fix crash when piece reaches bottom. (codegen issue). -; TODO fix wrong behavior when compiled without optimizations (codegen issue). +; TODO fix wrong block behavior at bottom when compiled without optimizations (codegen issue). main {