1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-11 12:29:46 +00:00

Function pointers – initial version

This commit is contained in:
Karol Stasiak 2019-07-27 00:58:10 +02:00
parent be9d27e2ee
commit 35ba36ce11
21 changed files with 399 additions and 26 deletions

View File

@ -8,6 +8,8 @@
* Added `bool` type. * Added `bool` type.
* Added function pointers so far quite limited.
* Added arrays of elements of size greater than byte. * Added arrays of elements of size greater than byte.
* Improved passing of register parameters to assembly functions. * Improved passing of register parameters to assembly functions.

View File

@ -265,11 +265,9 @@ Other kinds of expressions than the above (even `nonet(byte + byte + byte)`) wil
* `hi`, `lo`: most/least significant byte of a word * `hi`, `lo`: most/least significant byte of a word
`hi(word)` `hi(word)`
Furthermore, any type that can be assigned to a variable can be used to convert Furthermore, any type that can be assigned to a variable can be used to convert
either from one type either to another type of the same size, either from one type either to another type of the same size,
or from a 1-byte integer type to a compatible 2-byte integer type. or from a 1-byte integer type to a compatible 2-byte integer type.
`byte``word` `byte``word`
`word``pointer` `word``pointer`
some enum → `byte` some enum → `byte`
@ -281,4 +279,12 @@ some enum → `word`
* `sizeof`: size of the argument in bytes; the argument can be an expression or a type, * `sizeof`: size of the argument in bytes; the argument can be an expression or a type,
and the result is a constant of either `byte` or `word` type, depending on situation and the result is a constant of either `byte` or `word` type, depending on situation
* `call`: calls a function via a pointer;
the first argument is the pointer to the function;
the second argument, if present, is the argument to the called function.
The function can have max one parameter, of size max 1 byte, and may return a value of size max 2 bytes.
You can't create typed pointers to other kinds of functions anyway.
If the pointed-to function returns a value, then the result of `call(...)` is the result of the function.
Using `call` on 6502 targets requires at least 4 bytes of zeropage pseudoregister.

View File

@ -75,6 +75,25 @@ Its actual value is defined using the feature `NULLPTR`, by default it's 0.
`nullptr` isn't directly assignable to non-pointer types. `nullptr` isn't directly assignable to non-pointer types.
## Function pointers
For every type `A` of size 1 (or `void`) and every type `B` of size 1 or 2 (or `void`),
there is a pointer type defined called `function.A.to.B`, which represents functions with a signature like this:
B function_name(A parameter)
B function_name() // if A is void
Examples:
word i
function.void.to.word p1 = f1.pointer
i = call(p1)
function.byte.to.byte p2 = f2.pointer
i += call(p2, 7)
Using `call` on 6502 requires at least 4 bytes of zeropage pseudoregister.
## Boolean types ## Boolean types
TODO TODO

View File

@ -6,6 +6,8 @@
* [Fizzbuzz](crossplatform/fizzbuzz.mfk) (C64/C16/PET/VIC-20/PET/Atari/Apple II/BBC Micro/ZX Spectrum/PC-88/Armstrad CPC/MSX) everyone's favourite programming task * [Fizzbuzz](crossplatform/fizzbuzz.mfk) (C64/C16/PET/VIC-20/PET/Atari/Apple II/BBC Micro/ZX Spectrum/PC-88/Armstrad CPC/MSX) everyone's favourite programming task
* [Fizzbuzz 2](crossplatform/fizzbuzz2.mfk) (C64/C16/PET/VIC-20/PET/Atari/Apple II/BBC Micro/ZX Spectrum/PC-88/Armstrad CPC/MSX) an alternative, more extensible implemententation of fizzbuzz
* [Text encodings](crossplatform/text_encodings.mfk) (C64/ZX Spectrum) examples of text encoding features * [Text encodings](crossplatform/text_encodings.mfk) (C64/ZX Spectrum) examples of text encoding features
* [Echo](crossplatform/echo.mfk) (C64/C16/ZX Spectrum/PC-88/MSX) simple text input and output * [Echo](crossplatform/echo.mfk) (C64/C16/ZX Spectrum/PC-88/MSX) simple text input and output

View File

@ -0,0 +1,38 @@
import stdio
bool divisible3(byte x) = x %% 3 == 0
bool divisible5(byte x) = x %% 5 == 0
struct stage {
function.byte.to.bool predicate
pointer text
}
// can't put text literals directly in struct constructors yet
array fizz = "fizz"z
array buzz = "buzz"z
array(stage) stages = [
stage(divisible3.pointer, fizz),
stage(divisible5.pointer, buzz)
]
void main() {
byte i, s
bool printed
for i,1,to,100 {
printed = false
for s,0,until,stages.length {
if call(stages[s].predicate, i) {
printed = true
putstrz(stages[js].text)
}
}
if not(printed) {
putword(i)
}
putchar(' ')
}
}

View File

@ -47,4 +47,8 @@ macro asm void panic() {
? JSR _panic ? JSR _panic
} }
array __constant8 = [8] const array __constant8 = [8]
#if ZPREG_SIZE < 4
const array call = [2]
#endif

View File

@ -53,3 +53,8 @@ _lo_nibble_to_hex_lbl:
macro asm void panic() { macro asm void panic() {
? CALL _panic ? CALL _panic
} }
noinline asm word call(word de) {
PUSH DE
RET
}

View File

@ -96,3 +96,11 @@ asm word __div_u16u8u16u8() {
} }
#endif #endif
#if ZPREG_SIZE >= 4
noinline asm word call(word ax) {
JMP ((__reg + 2))
}
#endif

View File

@ -13,6 +13,7 @@ import millfork.error.FatalErrorReporting
object ZeropageRegisterOptimizations { object ZeropageRegisterOptimizations {
val functionsThatUsePseudoregisterAsInput: Map[String, Set[Int]] = Map( val functionsThatUsePseudoregisterAsInput: Map[String, Set[Int]] = Map(
"call" -> Set(2, 3),
"__mul_u8u8u8" -> Set(0, 1), "__mul_u8u8u8" -> Set(0, 1),
"__mod_u8u8u8u8" -> Set(0, 1), "__mod_u8u8u8u8" -> Set(0, 1),
"__div_u8u8u8u8" -> Set(0, 1), "__div_u8u8u8u8" -> Set(0, 1),

View File

@ -183,13 +183,13 @@ object ReverseFlowAnalyzer {
val cache = new FlowCache[ZLine, CpuImportance]("z80 reverse") val cache = new FlowCache[ZLine, CpuImportance]("z80 reverse")
val readsA: Set[String] = Set("__mul_u8u8u8", "__mul_u16u8u16") val readsA: Set[String] = Set("__mul_u8u8u8", "__mul_u16u8u16", "call")
val readsB: Set[String] = Set("") val readsB: Set[String] = Set("")
val readsC: Set[String] = Set("") val readsC: Set[String] = Set("")
val readsD: Set[String] = Set("__mul_u8u8u8","__mul_u16u8u16", "__divmod_u16u8u16u8") val readsD: Set[String] = Set("__mul_u8u8u8","__mul_u16u8u16", "__divmod_u16u8u16u8", "call")
val readsE: Set[String] = Set("__mul_u16u8u16") val readsE: Set[String] = Set("__mul_u16u8u16", "call")
val readsH: Set[String] = Set("__divmod_u16u8u16u8") val readsH: Set[String] = Set("__divmod_u16u8u16u8", "call")
val readsL: Set[String] = Set("__divmod_u16u8u16u8") val readsL: Set[String] = Set("__divmod_u16u8u16u8", "call")
//noinspection RedundantNewCaseClass //noinspection RedundantNewCaseClass
def analyze(f: NormalFunction, code: List[ZLine]): List[CpuImportance] = { def analyze(f: NormalFunction, code: List[ZLine]): List[CpuImportance] = {

View File

@ -260,7 +260,7 @@ object AbstractExpressionCompiler {
val v = env.get[Type]("void") val v = env.get[Type]("void")
val w = env.get[Type]("word") val w = env.get[Type]("word")
val t = expr match { val t: Type = expr match {
case LiteralExpression(_, size) => case LiteralExpression(_, size) =>
size match { size match {
case 1 => b case 1 => b
@ -372,6 +372,34 @@ object AbstractExpressionCompiler {
case Some(List(x)) => toType(!x) case Some(List(x)) => toType(!x)
case _ => bool case _ => bool
} }
case f@FunctionCallExpression("call", params) =>
params match {
case List(fp) =>
getExpressionType(env, log, fp) match {
case fpt@FunctionPointerType(_, _, _, Some(v), Some(r)) =>
if (v.name != "void"){
log.error(s"Invalid function pointer type: $fpt", fp.position)
}
r
case fpt =>
log.error(s"Not a function pointer type: $fpt", fp.position)
v
}
case List(fp, pp) =>
getExpressionType(env, log, fp) match {
case fpt@FunctionPointerType(_, _, _, Some(p), Some(r)) =>
if (!getExpressionType(env, log, pp).isAssignableTo(p)){
log.error(s"Invalid function pointer type: $fpt", fp.position)
}
r
case fpt =>
log.error(s"Not a function pointer type: $fpt", fp.position)
v
}
case _ =>
log.error("Invalid call(...) syntax; use either 1 or 2 arguments", f.position)
v
}
case FunctionCallExpression("hi", _) => b case FunctionCallExpression("hi", _) => b
case FunctionCallExpression("lo", _) => b case FunctionCallExpression("lo", _) => b
case FunctionCallExpression("sin", params) => if (params.size < 2) b else getExpressionType(env, log, params(1)) case FunctionCallExpression("sin", params) => if (params.size < 2) b else getExpressionType(env, log, params(1))

View File

@ -514,6 +514,12 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
compile(ctx, expr, Some(p -> env.get[Variable]("__reg.loword")), BranchSpec.None) compile(ctx, expr, Some(p -> env.get[Variable]("__reg.loword")), BranchSpec.None)
} }
def compileToZReg2(ctx: CompilationContext, expr: Expression): List[AssemblyLine] = {
val env = ctx.env
val p = env.get[Type]("pointer")
compile(ctx, expr, Some(p -> env.get[Variable]("__reg.b2b3")), BranchSpec.None)
}
def getPhysicalPointerForDeref(ctx: CompilationContext, pointerExpression: Expression): (List[AssemblyLine], Constant, AddrMode.Value) = { def getPhysicalPointerForDeref(ctx: CompilationContext, pointerExpression: Expression): (List[AssemblyLine], Constant, AddrMode.Value) = {
pointerExpression match { pointerExpression match {
case VariableExpression(name) => case VariableExpression(name) =>
@ -1162,6 +1168,36 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
var zeroExtend = false var zeroExtend = false
var resultVariable = "" var resultVariable = ""
val calculate: List[AssemblyLine] = name match { val calculate: List[AssemblyLine] = name match {
case "call" =>
params match {
case List(fp) =>
getExpressionType(ctx, fp) match {
case FunctionPointerType(_, _, _, _, Some(v)) if (v.name == "void") =>
compileToZReg2(ctx, fp) :+ AssemblyLine.absolute(JSR, env.get[ThingInMemory]("call"))
case _ =>
ctx.log.error("Not a function pointer", fp.position)
compile(ctx, fp, None, BranchSpec.None)
}
case List(fp, param) =>
getExpressionType(ctx, fp) match {
case FunctionPointerType(_, _, _, Some(pt), Some(v)) =>
if (pt.size != 1) {
ctx.log.error("Invalid parameter type", param.position)
compile(ctx, fp, None, BranchSpec.None) ++ compile(ctx, param, None, BranchSpec.None)
} else if (getExpressionType(ctx, param).isAssignableTo(pt)) {
compileToA(ctx, param) ++ preserveRegisterIfNeeded(ctx, MosRegister.A, compileToZReg2(ctx, fp)) :+ AssemblyLine.absolute(JSR, env.get[ThingInMemory]("call"))
} else {
ctx.log.error("Invalid parameter type", param.position)
compile(ctx, fp, None, BranchSpec.None) ++ compile(ctx, param, None, BranchSpec.None)
}
case _ =>
ctx.log.error("Not a function pointer", fp.position)
compile(ctx, fp, None, BranchSpec.None) ++ compile(ctx, param, None, BranchSpec.None)
}
case _ =>
ctx.log.error("Invalid call syntax", f.position)
Nil
}
case "not" => case "not" =>
assertBool(ctx, "not", params, 1) assertBool(ctx, "not", params, 1)
compile(ctx, params.head, exprTypeAndVariable, branches.flip) compile(ctx, params.head, exprTypeAndVariable, branches.flip)
@ -2062,7 +2098,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
val reg = env.get[ThingInMemory]("__reg.loword") val reg = env.get[ThingInMemory]("__reg.loword")
(addr, addrSource) match { (addr, addrSource) match {
case (MemoryAddressConstant(th1: Thing), MemoryAddressConstant(th2: Thing)) case (MemoryAddressConstant(th1: Thing), MemoryAddressConstant(th2: Thing))
if (th1.name == "__reg.loword" || th1.name == "__reg") && (th2.name == "__reg.loword" || th2.name == "__reg") => if (th1.name == "__reg.loword" || th1.name == "__reg.b2b3" || th1.name == "__reg") && (th2.name == "__reg.loword" || th2.name == "__reg.b2b3" || th2.name == "__reg") =>
(MosExpressionCompiler.changesZpreg(prepareSource, 2) || MosExpressionCompiler.changesZpreg(prepareSource, 3), (MosExpressionCompiler.changesZpreg(prepareSource, 2) || MosExpressionCompiler.changesZpreg(prepareSource, 3),
MosExpressionCompiler.changesZpreg(prepareSource, 2) || MosExpressionCompiler.changesZpreg(prepareSource, 3)) match { MosExpressionCompiler.changesZpreg(prepareSource, 2) || MosExpressionCompiler.changesZpreg(prepareSource, 3)) match {
case (_, false) => case (_, false) =>

View File

@ -738,6 +738,37 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
}) })
case f@FunctionCallExpression(name, params) => case f@FunctionCallExpression(name, params) =>
name match { name match {
case "call" =>
val callLine = ZLine(CALL, NoRegisters, env.get[ThingInMemory]("call").toAddress)
params match {
case List(fp) =>
getExpressionType(ctx, fp) match {
case FunctionPointerType(_, _, _, _, Some(v)) if (v.name == "void") =>
compileToDE(ctx, fp) :+ callLine
case _ =>
ctx.log.error("Not a function pointer", fp.position)
compile(ctx, fp, ZExpressionTarget.NOTHING, BranchSpec.None)
}
case List(fp, param) =>
getExpressionType(ctx, fp) match {
case FunctionPointerType(_, _, _, Some(pt), Some(v)) =>
if (pt.size != 1) {
ctx.log.error("Invalid parameter type", param.position)
compileToHL(ctx, fp) ++ compile(ctx, param, ZExpressionTarget.NOTHING)
} else if (getExpressionType(ctx, param).isAssignableTo(pt)) {
compileToDE(ctx, fp) ++ stashDEIfChanged(ctx, compileToA(ctx, param)) :+ callLine
} else {
ctx.log.error("Invalid parameter type", param.position)
compileToHL(ctx, fp) ++ compile(ctx, param, ZExpressionTarget.NOTHING)
}
case _ =>
ctx.log.error("Not a function pointer", fp.position)
compile(ctx, fp, ZExpressionTarget.NOTHING) ++ compile(ctx, param, ZExpressionTarget.NOTHING)
}
case _ =>
ctx.log.error("Invalid call syntax", f.position)
Nil
}
case "not" => case "not" =>
assertBool(ctx, "not", params, 1) assertBool(ctx, "not", params, 1)
compile(ctx, params.head, target, branches.flip) compile(ctx, params.head, target, branches.flip)

View File

@ -284,6 +284,12 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
} }
def get[T <: Thing : Manifest](name: String, position: Option[Position] = None): T = { def get[T <: Thing : Manifest](name: String, position: Option[Position] = None): T = {
if (name.startsWith("function.") && implicitly[Manifest[T]].runtimeClass.isAssignableFrom(classOf[PointerType])) {
val tokens = name.stripPrefix("function.").split("\\.to\\.", 2)
if (tokens.length == 2) {
return FunctionPointerType(name, tokens(0), tokens(1), maybeGet[Type](tokens(0)), maybeGet[Type](tokens(1))).asInstanceOf[T]
}
}
if (name.startsWith("pointer.") && implicitly[Manifest[T]].runtimeClass.isAssignableFrom(classOf[PointerType])) { if (name.startsWith("pointer.") && implicitly[Manifest[T]].runtimeClass.isAssignableFrom(classOf[PointerType])) {
val targetName = name.stripPrefix("pointer.") val targetName = name.stripPrefix("pointer.")
val target = maybeGet[VariableType](targetName) val target = maybeGet[VariableType](targetName)
@ -1147,9 +1153,16 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
} }
} }
private def getFunctionPointerType(f: FunctionInMemory) = f.params.types match {
case List() =>
get[Type]("function.void.to." + f.returnType.name)
case p :: _ => // TODO: this only handles one type though!
get[Type]("function." + p.name + ".to." + f.returnType.name)
}
private def registerAddressConstant(thing: ThingInMemory, position: Option[Position], options: CompilationOptions, targetType: Option[Type]): Unit = { private def registerAddressConstant(thing: ThingInMemory, position: Option[Position], options: CompilationOptions, targetType: Option[Type]): Unit = {
val b = get[Type]("byte")
if (!thing.zeropage && options.flag(CompilationFlag.LUnixRelocatableCode)) { if (!thing.zeropage && options.flag(CompilationFlag.LUnixRelocatableCode)) {
val b = get[Type]("byte")
val w = get[Type]("word") val w = get[Type]("word")
val relocatable = UninitializedMemoryVariable(thing.name + ".addr", w, VariableAllocationMethod.Static, None, defaultVariableAlignment(options, 2), isVolatile = false) val relocatable = UninitializedMemoryVariable(thing.name + ".addr", w, VariableAllocationMethod.Static, None, defaultVariableAlignment(options, 2), isVolatile = false)
val addr = relocatable.toAddress val addr = relocatable.toAddress
@ -1166,20 +1179,36 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
addThing(ConstantThing(thing.name + ".rawaddr", rawaddr, get[Type]("pointer")), position) addThing(ConstantThing(thing.name + ".rawaddr", rawaddr, get[Type]("pointer")), position)
addThing(ConstantThing(thing.name + ".rawaddr.hi", rawaddr.hiByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".rawaddr.hi", rawaddr.hiByte, get[Type]("byte")), position)
addThing(ConstantThing(thing.name + ".rawaddr.lo", rawaddr.loByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".rawaddr.lo", rawaddr.loByte, get[Type]("byte")), position)
thing match {
case f: FunctionInMemory if f.canBePointedTo =>
val typedPointer = RelativeVariable(thing.name + ".pointer", addr, getFunctionPointerType(f), zeropage = false, None, isVolatile = false)
addThing(typedPointer, position)
addThing(RelativeVariable(thing.name + ".pointer.hi", addr + 1, b, zeropage = false, None, isVolatile = false), position)
addThing(RelativeVariable(thing.name + ".pointer.lo", addr, b, zeropage = false, None, isVolatile = false), position)
case _ =>
}
} else { } else {
val addr = thing.toAddress val addr = thing.toAddress
addThing(ConstantThing(thing.name + ".addr", addr, get[Type]("pointer")), position) addThing(ConstantThing(thing.name + ".addr", addr, get[Type]("pointer")), position)
addThing(ConstantThing(thing.name + ".addr.hi", addr.hiByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".addr.hi", addr.hiByte, b), position)
addThing(ConstantThing(thing.name + ".addr.lo", addr.loByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".addr.lo", addr.loByte, b), position)
addThing(ConstantThing(thing.name + ".rawaddr", addr, get[Type]("pointer")), position) addThing(ConstantThing(thing.name + ".rawaddr", addr, get[Type]("pointer")), position)
addThing(ConstantThing(thing.name + ".rawaddr.hi", addr.hiByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".rawaddr.hi", addr.hiByte, b), position)
addThing(ConstantThing(thing.name + ".rawaddr.lo", addr.loByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".rawaddr.lo", addr.loByte, b), position)
targetType.foreach { tt => targetType.foreach { tt =>
val pointerType = PointerType("pointer." + tt.name, tt.name, Some(tt)) val pointerType = PointerType("pointer." + tt.name, tt.name, Some(tt))
val typedPointer = RelativeVariable(thing.name + ".pointer", addr, pointerType, zeropage = false, None, isVolatile = false) val typedPointer = RelativeVariable(thing.name + ".pointer", addr, pointerType, zeropage = false, None, isVolatile = false)
addThing(ConstantThing(thing.name + ".pointer", addr, pointerType), position) addThing(ConstantThing(thing.name + ".pointer", addr, pointerType), position)
addThing(ConstantThing(thing.name + ".pointer.hi", addr.hiByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".pointer.hi", addr.hiByte, b), position)
addThing(ConstantThing(thing.name + ".pointer.lo", addr.loByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".pointer.lo", addr.loByte, b), position)
}
thing match {
case f: FunctionInMemory if f.canBePointedTo =>
val pointerType = getFunctionPointerType(f)
addThing(ConstantThing(thing.name + ".pointer", addr, pointerType), position)
addThing(ConstantThing(thing.name + ".pointer.hi", addr.hiByte, b), position)
addThing(ConstantThing(thing.name + ".pointer.lo", addr.loByte, b), position)
case _ =>
} }
} }
} }
@ -1698,6 +1727,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
if (function.params.length != actualParams.length) { if (function.params.length != actualParams.length) {
log.error(s"Invalid number of parameters for function `$name`", actualParams.headOption.flatMap(_._2.position)) log.error(s"Invalid number of parameters for function `$name`", actualParams.headOption.flatMap(_._2.position))
} }
if (name == "call") return Some(function)
function.params match { function.params match {
case NormalParamSignature(params) => case NormalParamSignature(params) =>
function.params.types.zip(actualParams).zip(params).foreach { case ((required, (actual, expr)), m) => function.params.types.zip(actualParams).zip(params).foreach { case ((required, (actual, expr)), m) =>

View File

@ -82,6 +82,14 @@ case class PointerType(name: String, targetName: String, var target: Option[Type
override def pointerTargetName: String = targetName override def pointerTargetName: String = targetName
} }
case class FunctionPointerType(name: String, paramTypeName:String, returnTypeName: String, var paramType: Option[Type], var returnType: Option[Type]) extends VariableType {
def size = 2
override def isSigned: Boolean = false
override def isPointy: Boolean = false
}
case object NullType extends VariableType { case object NullType extends VariableType {
override def size: Int = 2 override def size: Int = 2
@ -91,9 +99,9 @@ case object NullType extends VariableType {
override def isPointy: Boolean = true override def isPointy: Boolean = true
override def isSubtypeOf(other: Type): Boolean = this == other || other.isPointy && other.size == 2 override def isSubtypeOf(other: Type): Boolean = this == other || (other.isPointy || other.isInstanceOf[FunctionPointerType]) && other.size == 2
override def isAssignableTo(targetType: Type): Boolean = this == targetType || targetType.isPointy && targetType.size == 2 override def isAssignableTo(targetType: Type): Boolean = this == targetType || (targetType.isPointy || targetType.isInstanceOf[FunctionPointerType]) && targetType.size == 2
} }
case class EnumType(name: String, count: Option[Int]) extends VariableType { case class EnumType(name: String, count: Option[Int]) extends VariableType {
@ -340,6 +348,8 @@ sealed trait MangledFunction extends CallableThing {
def params: ParamSignature def params: ParamSignature
def interrupt: Boolean def interrupt: Boolean
def canBePointedTo: Boolean
} }
case class EmptyFunction(name: String, case class EmptyFunction(name: String,
@ -348,6 +358,8 @@ case class EmptyFunction(name: String,
override def params = EmptyFunctionParamSignature(paramType) override def params = EmptyFunctionParamSignature(paramType)
override def interrupt = false override def interrupt = false
override def canBePointedTo: Boolean = false
} }
case class MacroFunction(name: String, case class MacroFunction(name: String,
@ -356,6 +368,8 @@ case class MacroFunction(name: String,
environment: Environment, environment: Environment,
code: List[ExecutableStatement]) extends MangledFunction { code: List[ExecutableStatement]) extends MangledFunction {
override def interrupt = false override def interrupt = false
override def canBePointedTo: Boolean = false
} }
sealed trait FunctionInMemory extends MangledFunction with ThingInMemory { sealed trait FunctionInMemory extends MangledFunction with ThingInMemory {
@ -381,6 +395,8 @@ case class ExternFunction(name: String,
override def zeropage: Boolean = false override def zeropage: Boolean = false
override def isVolatile: Boolean = false override def isVolatile: Boolean = false
override def canBePointedTo: Boolean = !interrupt && returnType.size <= 2 && params.canBePointedTo && name !="call"
} }
case class NormalFunction(name: String, case class NormalFunction(name: String,
@ -402,6 +418,8 @@ case class NormalFunction(name: String,
override def zeropage: Boolean = false override def zeropage: Boolean = false
override def isVolatile: Boolean = false override def isVolatile: Boolean = false
override def canBePointedTo: Boolean = !interrupt && returnType.size <= 2 && params.canBePointedTo && name !="call"
} }
case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing with VariableLikeThing with IndexableThing { case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing with VariableLikeThing with IndexableThing {
@ -416,12 +434,16 @@ trait ParamSignature {
def types: List[Type] def types: List[Type]
def length: Int def length: Int
def canBePointedTo: Boolean
} }
case class NormalParamSignature(params: List[VariableInMemory]) extends ParamSignature { case class NormalParamSignature(params: List[VariableInMemory]) extends ParamSignature {
override def length: Int = params.length override def length: Int = params.length
override def types: List[Type] = params.map(_.typ) override def types: List[Type] = params.map(_.typ)
def canBePointedTo: Boolean = params.size <= 1 && params.forall(_.typ.size.<=(1))
} }
sealed trait ParamPassingConvention { sealed trait ParamPassingConvention {
@ -464,17 +486,27 @@ object AssemblyParameterPassingBehaviour extends Enumeration {
val Copy, ByReference, ByConstant = Value val Copy, ByReference, ByConstant = Value
} }
case class AssemblyParam(typ: Type, variable: TypedThing, behaviour: AssemblyParameterPassingBehaviour.Value) case class AssemblyParam(typ: Type, variable: TypedThing, behaviour: AssemblyParameterPassingBehaviour.Value) {
def canBePointedTo: Boolean = behaviour == AssemblyParameterPassingBehaviour.Copy && (variable match {
case RegisterVariable(MosRegister.A | MosRegister.AX, _) => true
case ZRegisterVariable(ZRegister.A | ZRegister.HL, _) => true
case _ => false
})
}
case class AssemblyParamSignature(params: List[AssemblyParam]) extends ParamSignature { case class AssemblyParamSignature(params: List[AssemblyParam]) extends ParamSignature {
override def length: Int = params.length override def length: Int = params.length
override def types: List[Type] = params.map(_.typ) override def types: List[Type] = params.map(_.typ)
def canBePointedTo: Boolean = params.size <= 1 && params.forall(_.canBePointedTo)
} }
case class EmptyFunctionParamSignature(paramType: Type) extends ParamSignature { case class EmptyFunctionParamSignature(paramType: Type) extends ParamSignature {
override def length: Int = 1 override def length: Int = 1
override def types: List[Type] = List(paramType) override def types: List[Type] = List(paramType)
def canBePointedTo: Boolean = false
} }

View File

@ -59,12 +59,16 @@ abstract class CallGraph(program: Program, log: Logger) {
case s: SumExpression => case s: SumExpression =>
s.expressions.foreach(expr => add(currentFunction, callingFunctions, expr._2)) s.expressions.foreach(expr => add(currentFunction, callingFunctions, expr._2))
case x: VariableExpression => case x: VariableExpression =>
val varName = x.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr") val varName = x.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr").stripSuffix(".pointer")
everCalledFunctions += varName everCalledFunctions += varName
entryPoints += varName // TODO: figure out how to interpret pointed-to functions
case i: IndexedExpression => case i: IndexedExpression =>
val varName = i.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr") val varName = i.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr").stripSuffix(".pointer")
everCalledFunctions += varName everCalledFunctions += varName
entryPoints += varName
add(currentFunction, callingFunctions, i.index) add(currentFunction, callingFunctions, i.index)
case i: DerefExpression =>
add(currentFunction, callingFunctions, i.inner)
case i: DerefDebuggingExpression => case i: DerefDebuggingExpression =>
add(currentFunction, callingFunctions, i.inner) add(currentFunction, callingFunctions, i.inner)
case IndirectFieldExpression(root, firstIndices, fields) => case IndirectFieldExpression(root, firstIndices, fields) =>

View File

@ -77,6 +77,9 @@ abstract class AbstractInliningCalculator[T <: AbstractCode] {
case s: Statement => getAllCalledFunctions(s.getAllExpressions) case s: Statement => getAllCalledFunctions(s.getAllExpressions)
case s: VariableExpression => Set( case s: VariableExpression => Set(
s.name, s.name,
s.name.stripSuffix(".pointer"),
s.name.stripSuffix(".pointer.lo"),
s.name.stripSuffix(".pointer.hi"),
s.name.stripSuffix(".addr"), s.name.stripSuffix(".addr"),
s.name.stripSuffix(".hi"), s.name.stripSuffix(".hi"),
s.name.stripSuffix(".lo"), s.name.stripSuffix(".lo"),

View File

@ -0,0 +1,102 @@
package millfork.test
import millfork.Cpu
import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun, ShouldNotCompile}
import org.scalatest.{AppendedClues, FunSuite, Matchers}
/**
* @author Karol Stasiak
*/
class FunctionPointerSuite extends FunSuite with Matchers with AppendedClues{
test("Function pointers 1") {
EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Cmos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)(
"""
|
| byte output @$c000
| void f1() {
| output = 100
| }
|
| void main() {
| function.void.to.void p1
| p1 = f1.pointer
| call(p1)
| }
|
""".stripMargin) { m =>
m.readByte(0xc000) should equal(100)
}
}
test("Function pointers 2") {
EmuUnoptimizedCrossPlatformRun (Cpu.Mos, Cpu.Cmos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)(
"""
| const byte COUNT = 128
| array output0[COUNT] @$c000
| array output1[COUNT] @$c100
| array output2[COUNT] @$c200
| array output3[COUNT] @$c300
|
| void tabulate(pointer target, function.byte.to.byte f) {
| byte i
| for i,0,until,COUNT {
| target[i] = call(f, i)
| }
| }
|
| byte double(byte x) = x*2
| byte negate(byte x) = 0-x
| byte zero(byte x) = 0
| byte id(byte x) = x
|
| void main() {
| tabulate(output0, zero.pointer)
| tabulate(output1, id.pointer)
| tabulate(output2, double.pointer)
| tabulate(output3, negate.pointer)
| }
|
""".stripMargin) { m =>
for (i <- 0 until 0x80) {
m.readByte(0xc000 + i) should equal(0) withClue ("zero " + i)
m.readByte(0xc100 + i) should equal(i) withClue ("id " + i)
m.readByte(0xc200 + i) should equal(i * 2) withClue ("double " + i)
m.readByte(0xc300 + i) should equal((256 - i) & 0xff) withClue ("negate " + i)
}
}
}
test("Function pointers: invalid types") {
ShouldNotCompile(
"""
|void main() {
| call(main.pointer, 1)
|}
|""".stripMargin)
ShouldNotCompile(
"""
|void f(byte a) = 0
|void main() {
| call(f.pointer)
|}
|""".stripMargin)
ShouldNotCompile(
"""
|enum e {}
|void f(e a) = 0
|void main() {
| call(f.pointer, 0)
|}
|""".stripMargin)
ShouldNotCompile(
"""
|enum e {}
|void f(byte a) = 0
|void main() {
| call(f.pointer, e(7))
|}
|""".stripMargin)
}
}

View File

@ -156,6 +156,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
if (native16 && platform.cpu != millfork.Cpu.Sixteen) throw new IllegalStateException if (native16 && platform.cpu != millfork.Cpu.Sixteen) throw new IllegalStateException
var effectiveSource = source var effectiveSource = source
if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}" if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}"
if (source.contains("call(")) effectiveSource += "\nnoinline asm word call(word ax) {\nJMP ((__reg.b2b3))\n}\n"
if (native16) effectiveSource += if (native16) effectiveSource +=
""" """
| |

View File

@ -92,6 +92,11 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio
log.verbosity = 999 log.verbosity = 999
var effectiveSource = source var effectiveSource = source
if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}" if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}"
if (source.contains("call(")) {
if (options.flag(CompilationFlag.UseIntelSyntaxForInput))
effectiveSource += "\nnoinline asm word call(word de) {\npush d\nret\n}\n"
else effectiveSource += "\nnoinline asm word call(word de) {\npush de\nret\n}\n"
}
log.setSource(Some(effectiveSource.linesIterator.toIndexedSeq)) log.setSource(Some(effectiveSource.linesIterator.toIndexedSeq))
val PreprocessingResult(preprocessedSource, features, pragmas) = Preprocessor.preprocessForTest(options, effectiveSource) val PreprocessingResult(preprocessedSource, features, pragmas) = Preprocessor.preprocessForTest(options, effectiveSource)
// tests use Intel syntax only when forced to: // tests use Intel syntax only when forced to:

View File

@ -8,7 +8,7 @@ import millfork.compiler.{CompilationContext, LabelGenerator}
import millfork.compiler.mos.MosCompiler import millfork.compiler.mos.MosCompiler
import millfork.env.{Environment, InitializedArray, InitializedMemoryVariable, NormalFunction} import millfork.env.{Environment, InitializedArray, InitializedMemoryVariable, NormalFunction}
import millfork.node.StandardCallGraph import millfork.node.StandardCallGraph
import millfork.parser.{MosParser, PreprocessingResult, Preprocessor} import millfork.parser.{MosParser, PreprocessingResult, Preprocessor, Z80Parser}
import millfork._ import millfork._
import millfork.compiler.m6809.M6809Compiler import millfork.compiler.m6809.M6809Compiler
import millfork.compiler.z80.Z80Compiler import millfork.compiler.z80.Z80Compiler
@ -36,11 +36,27 @@ object ShouldNotCompile extends Matchers {
log.verbosity = 999 log.verbosity = 999
var effectiveSource = source var effectiveSource = source
if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}" if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}"
if (source.contains("call(")) {
platform.cpuFamily match {
case CpuFamily.M6502 =>
effectiveSource += "\nnoinline asm word call(word ax) {\nJMP ((__reg.b2b3))\n}\n"
case CpuFamily.I80 =>
if (options.flag(CompilationFlag.UseIntelSyntaxForInput))
effectiveSource += "\nnoinline asm word call(word de) {\npush d\nret\n}\n"
else effectiveSource += "\nnoinline asm word call(word de) {\npush de\nret\n}\n"
}
}
if (source.contains("import zp_reg")) if (source.contains("import zp_reg"))
effectiveSource += Files.readAllLines(Paths.get("include/zp_reg.mfk"), StandardCharsets.US_ASCII).asScala.mkString("\n", "\n", "") effectiveSource += Files.readAllLines(Paths.get("include/zp_reg.mfk"), StandardCharsets.US_ASCII).asScala.mkString("\n", "\n", "")
log.setSource(Some(effectiveSource.linesIterator.toIndexedSeq)) log.setSource(Some(effectiveSource.linesIterator.toIndexedSeq))
val PreprocessingResult(preprocessedSource, features, _) = Preprocessor.preprocessForTest(options, effectiveSource) val PreprocessingResult(preprocessedSource, features, _) = Preprocessor.preprocessForTest(options, effectiveSource)
val parserF = MosParser("", preprocessedSource, "", options, features) val parserF =
platform.cpuFamily match {
case CpuFamily.M6502 =>
MosParser("", preprocessedSource, "", options, features)
case CpuFamily.I80 =>
Z80Parser("", preprocessedSource, "", options, features, options.flag(CompilationFlag.UseIntelSyntaxForInput))
}
parserF.toAst match { parserF.toAst match {
case Success(program, _) => case Success(program, _) =>
log.assertNoErrors("Parse failed") log.assertNoErrors("Parse failed")