defer is now done *after* calculating a return value

This commit is contained in:
Irmen de Jong 2024-10-18 20:53:30 +02:00
parent d8f1822c12
commit 2a52241f1c
3 changed files with 92 additions and 15 deletions

View File

@ -1,9 +1,9 @@
package prog8.compiler.astprocessing
import prog8.ast.base.FatalAstException
import prog8.code.SymbolTable
import prog8.code.ast.*
import prog8.code.core.DataType
import prog8.code.core.IErrorReporter
import prog8.code.core.*
internal fun postprocessIntermediateAst(program: PtProgram, st: SymbolTable, errors: IErrorReporter) {
coalesceDefers(program)
@ -39,7 +39,8 @@ private fun coalesceDefers(program: PtProgram) {
private fun integrateDefers(program: PtProgram, st: SymbolTable) {
val exitsToAugment = mutableListOf<PtNode>()
val jumpsToAugment = mutableListOf<PtJump>()
val returnsToAugment = mutableListOf<PtReturn>()
val subEndsToAugment = mutableListOf<PtSub>()
walkAst(program) { node, _ ->
@ -49,10 +50,10 @@ private fun integrateDefers(program: PtProgram, st: SymbolTable) {
val stNode = st.lookup(node.identifier!!.name)!!
val targetSub = stNode.astNode.definingSub()
if(targetSub!=node.definingSub())
exitsToAugment.add(node)
jumpsToAugment.add(node)
}
}
is PtReturn -> exitsToAugment.add(node)
is PtReturn -> returnsToAugment.add(node)
is PtSub -> {
val lastStmt = node.children.lastOrNull { it !is PtDefer }
if(lastStmt != null && lastStmt !is PtReturn && lastStmt !is PtJump)
@ -62,12 +63,66 @@ private fun integrateDefers(program: PtProgram, st: SymbolTable) {
}
}
for(exit in exitsToAugment) {
val defer = exit.definingSub()!!.children.singleOrNull { it is PtDefer }
fun invokedeferbefore(node: PtNode) {
val defer = node.definingSub()!!.children.singleOrNull { it is PtDefer }
if (defer != null) {
val idx = exit.parent.children.indexOf(exit)
val invokedefer = PtBuiltinFunctionCall("invoke_defer", true, false, DataType.UNDEFINED, exit.position)
exit.parent.add(idx, invokedefer)
val idx = node.parent.children.indexOf(node)
val invokedefer = PtBuiltinFunctionCall("invoke_defer", true, false, DataType.UNDEFINED, node.position)
node.parent.add(idx, invokedefer)
}
}
for(exit in jumpsToAugment) {
invokedeferbefore(exit)
}
for(ret in returnsToAugment) {
val defer = ret.definingSub()!!.children.singleOrNull { it is PtDefer }
if(defer == null)
continue
if(ret.children.size>1)
TODO("support defer on multi return values")
if(!ret.hasValue || ret.value!!.isSimple()) {
invokedeferbefore(ret)
} else {
val value = ret.value!!
var typecast: DataType? = null
var pushWord = false
var pushFloat = false
when(value.type) {
DataType.BOOL -> typecast = DataType.BOOL
DataType.BYTE -> typecast = DataType.BYTE
DataType.WORD -> {
pushWord = true
typecast = DataType.WORD
}
DataType.UBYTE -> {}
DataType.UWORD, in PassByReferenceDatatypes -> pushWord = true
DataType.FLOAT -> pushFloat = true
else -> throw FatalAstException("unsupported return value type ${value.type} with defer")
}
val pushFunc = if(pushFloat) "floats.push" else if(pushWord) "sys.pushw" else "sys.push"
val popFunc = if(pushFloat) "floats.pop" else if(pushWord) "sys.popw" else "sys.pop"
val pushCall = PtFunctionCall(pushFunc, true, value.type, value.position)
pushCall.add(value)
val popCall = if(typecast!=null) {
PtTypeCast(typecast, value.position).also {
it.add(PtFunctionCall(popFunc, false, value.type, value.position))
}
} else
PtFunctionCall(popFunc, false, value.type, value.position)
val newRet = PtReturn(ret.position)
newRet.add(popCall)
val group = PtNodeGroup()
group.add(pushCall)
group.add(PtBuiltinFunctionCall("invoke_defer", true, false, DataType.UNDEFINED, ret.position))
group.add(newRet)
group.parent = ret.parent
val idx = ret.parent.children.indexOf(ret)
ret.parent.children[idx] = group
}
}

View File

@ -1,7 +1,8 @@
TODO
====
- fix defer that a return <expression> is evaluated first (and saved), then the defer is called, then a simple return <value> is done
- check defer stmt/block to not contain a return or jump statement!
- run defer also before sys.exit/exit2/exit3
- unit test for defer
- describe defer in the manual

View File

@ -1,4 +1,5 @@
%import textio
%import floats
%option no_sysinit
%zeropage basicsafe
@ -8,6 +9,25 @@ main {
txt.print("result from call=")
txt.print_ub(x)
txt.nl()
float f = testdeferf()
txt.print("result from fcall=")
floats.print(f)
txt.nl()
floats.push(f)
txt.print("pushed f")
f = floats.pop()
floats.print(f)
txt.nl()
}
sub testdeferf() -> float {
defer {
txt.print("defer in floats\n")
}
float @shared zz = 111.111
cx16.r0++
return 123.456 + zz
}
sub testdefer() -> ubyte {
@ -22,7 +42,7 @@ main {
if var==22 {
var = 88
return var
return var + other()
}
else {
var++
@ -35,7 +55,8 @@ main {
}
sub other() {
cx16.r0++
sub other() -> ubyte {
txt.print("other()\n")
return 11
}
}