diff --git a/docs/api/command-line.md b/docs/api/command-line.md index baa7e551..e7e69149 100644 --- a/docs/api/command-line.md +++ b/docs/api/command-line.md @@ -218,7 +218,7 @@ Default: yes. * `-foptimize-stdlib`, `-fno-optimize-stdlib` – Whether should replace some standard library calls with constant parameters with more efficient variants. -Currently affects `putstrz` and `strzlen`, but may affect more functions in the future. +Currently affects `putstrz`, `putpstr`, `strzlen`, `scrstrlen` and `pstrlen`, but may affect more functions in the future. `.ini` equivalent: `optimize_stdlib`. Default: no. diff --git a/docs/stdlib/string.md b/docs/stdlib/string.md index 69492adb..55a3e046 100644 --- a/docs/stdlib/string.md +++ b/docs/stdlib/string.md @@ -54,7 +54,7 @@ It contains functions for handling strings in the screen encoding with the same ## pstring -The `scrstring` module automatically imports the [`err` module](./other.md). +The `pstring` module automatically imports the [`err` module](./other.md). It contains functions for handling length-prefixed strings in any 8-bit encoding. @@ -62,6 +62,14 @@ It contains functions for handling length-prefixed strings in any 8-bit encoding #### `sbyte pstrcmp(pointer str1, pointer str2)` #### `void pstrcopy(pointer dest, pointer src)` #### `void pstrpaste(pointer dest, pointer src)` -#### `word pstr2word(pointer str)` #### `void pstrappend(pointer buffer, pointer str)` #### `void pstrappendchar(pointer buffer, byte char)` +#### `word pstr2word(pointer str)` + +Converts a length-prefixed string to a number. Uses the default encoding. +Sets `errno`. + +#### `word pscrstr2word(pointer str)` + +Converts a length-prefixed string to a number. Uses the screen encoding. +Sets `errno`. diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index cea0de06..95343860 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -73,27 +73,36 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte // stdlib: if (optimizeStdlib) { stmt match { - case ExpressionStatement(FunctionCallExpression("putstrz", List(TextLiteralExpression(text)))) => - text.lastOption match { - case Some(LiteralExpression(0, _)) => - text.size match { - case 1 => - ctx.log.debug("Removing putstrz with empty argument", stmt.position) - return EmptyStatement(Nil) -> currentVarValues - case 2 => + case ExpressionStatement(FunctionCallExpression("putstrz", List(TextLiteralExpression(text)))) + if StdLibOptUtils.isValidNulTerminated(ctx.options.platform.defaultCodec, text) => + text.size match { + case 1 => + ctx.log.debug("Removing putstrz with empty argument", stmt.position) + return EmptyStatement(Nil) -> currentVarValues + case 2 => + ctx.log.debug("Replacing putstrz with putchar", stmt.position) + return ExpressionStatement(FunctionCallExpression("putchar", List(text.head))) -> currentVarValues + case 3 => + if (ctx.options.platform.cpuFamily == CpuFamily.M6502) { ctx.log.debug("Replacing putstrz with putchar", stmt.position) - return ExpressionStatement(FunctionCallExpression("putchar", List(text.head))) -> currentVarValues - case 3 => - if (ctx.options.platform.cpuFamily == CpuFamily.M6502) { - ctx.log.debug("Replacing putstrz with putchar", stmt.position) - return IfStatement(FunctionCallExpression("==", List(LiteralExpression(1, 1), LiteralExpression(1, 1))), List( - ExpressionStatement(FunctionCallExpression("putchar", List(text.head))), - ExpressionStatement(FunctionCallExpression("putchar", List(text(1)))) - ), Nil) -> currentVarValues - } - case _ => - } - } + return IfStatement(FunctionCallExpression("==", List(LiteralExpression(1, 1), LiteralExpression(1, 1))), List( + ExpressionStatement(FunctionCallExpression("putchar", List(text.head))), + ExpressionStatement(FunctionCallExpression("putchar", List(text(1)))) + ), Nil) -> currentVarValues + } + case _ => + } + case ExpressionStatement(FunctionCallExpression("putpstr", List(TextLiteralExpression(text)))) + if StdLibOptUtils.isValidPascal(text) => + text.length match { + case 1 => + ctx.log.debug("Removing putpstr with empty argument", stmt.position) + return EmptyStatement(Nil) -> currentVarValues + case 2 => + ctx.log.debug("Replacing putpstr with putchar", stmt.position) + return ExpressionStatement(FunctionCallExpression("putchar", List(text(1)))) -> currentVarValues + case _ => + } case _ => } } @@ -360,12 +369,11 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte if (optimizeStdlib) { expr match { case FunctionCallExpression("strzlen", List(TextLiteralExpression(text))) => - text.lastOption match { - case Some(LiteralExpression(0, _)) if text.size <= 256 => - ctx.log.debug("Replacing strzlen with constant argument", expr.position) - return LiteralExpression(text.size - 1, 1) - case _ => - } + StdLibOptUtils.evalStrzLen(ctx, expr, ctx.options.platform.defaultCodec , text).foreach(return _) + case FunctionCallExpression("scrstrzlen", List(TextLiteralExpression(text))) => + StdLibOptUtils.evalStrzLen(ctx, expr, ctx.options.platform.defaultCodec, text).foreach(return _) + case FunctionCallExpression("pstrlen", List(TextLiteralExpression(text))) => + StdLibOptUtils.evalPStrLen(ctx, expr, text).foreach(return _) case _ => } } diff --git a/src/main/scala/millfork/compiler/StdLibOptUtils.scala b/src/main/scala/millfork/compiler/StdLibOptUtils.scala new file mode 100644 index 00000000..db67f8ab --- /dev/null +++ b/src/main/scala/millfork/compiler/StdLibOptUtils.scala @@ -0,0 +1,67 @@ +package millfork.compiler + +import millfork.node.{Expression, FunctionCallExpression, LiteralExpression} +import millfork.parser.TextCodec + +/** + * @author Karol Stasiak + */ +object StdLibOptUtils { + + + private def fn(expr: Expression) = expr match { + case f: FunctionCallExpression => f.functionName + case _ => "library function call" + } + + def isValidNulTerminated(codec: TextCodec, text: List[Expression], minLength: Int = 0, maxLength: Int = 255): Boolean = { + if (codec.stringTerminator.length != 1) return false + val Nul = codec.stringTerminator.head + if (text.init.forall { + case LiteralExpression(c, _) => c != Nul + case _ => false + }) { + text.lastOption match { + case Some(LiteralExpression(Nul, _)) if text.size >= minLength + 1 && text.size <= maxLength + 1 => return true + case _ => + } + } + false + } + + def isValidPascal(text: List[Expression]): Boolean = { + text.headOption match { + case Some(LiteralExpression(l, _)) if text.size == l + 1 && text.size <= 256 => + true + case _ => + false + } + } + + def evalStrzLen(ctx: CompilationContext, expr: Expression, codec: TextCodec, text: List[Expression]): Option[Expression] = { + if (codec.stringTerminator.length != 1) return None + val Nul = codec.stringTerminator.head + if (text.init.forall { + case LiteralExpression(c, _) => c != Nul + case _ => false + }) { + text.lastOption match { + case Some(LiteralExpression(Nul, _)) if text.size <= 256 => + ctx.log.debug(s"Replacing ${fn(expr)} with constant argument", expr.position) + return Some(LiteralExpression(text.size - 1, 1)) + case _ => + } + } + None + } + + def evalPStrLen(ctx: CompilationContext, expr: Expression, text: List[Expression]): Option[Expression] = { + text.headOption match { + case Some(LiteralExpression(l, _)) if text.size == l + 1 && text.size <= 256 => + ctx.log.debug(s"Replacing ${fn(expr)} with constant argument", expr.position) + return Some(LiteralExpression(l, 1)) + case _ => + } + None + } +} diff --git a/src/test/resources/include/dummy_stdio.mfk b/src/test/resources/include/dummy_stdio.mfk index d24cb465..e7b406a0 100644 --- a/src/test/resources/include/dummy_stdio.mfk +++ b/src/test/resources/include/dummy_stdio.mfk @@ -1,10 +1,22 @@ noinline void putchar(byte b) { } noinline void putstrz(pointer p) { putchar(0) } -byte strzlen(pointer str) { +noinline void putpstr(pointer p) { putchar(0) } +noinline byte strzlen(pointer str) { byte index index = 0 - while str[index] != 0 { + while str[index] != nullchar { index += 1 } return index } +noinline byte scrstrzlen(pointer str) { + byte index + index = 0 + while str[index] != nullchar_scr { + index += 1 + } + return index +} +noinline byte pstrlen(pointer str) { + return str[0] +} diff --git a/src/test/scala/millfork/test/StatementOptimizationSuite.scala b/src/test/scala/millfork/test/StatementOptimizationSuite.scala index 74e023c5..4e7e3a19 100644 --- a/src/test/scala/millfork/test/StatementOptimizationSuite.scala +++ b/src/test/scala/millfork/test/StatementOptimizationSuite.scala @@ -62,16 +62,24 @@ class StatementOptimizationSuite extends FunSuite with Matchers { """ | import stdio | byte output @$c000 + | byte output2 @$c002 + | byte output3 @$c003 | void main() { | output = strzlen("test"z) + | output2 = scrstrzlen("test"z) + | output3 = pstrlen("test"p) | putstrz(""z) | putstrz("a"z) + | putpstr(""p) + | putpstr("a"p) | putstrz("bc"z) | putstrz("def"z) | } """.stripMargin ) { m => m.readByte(0xc000) should equal(4) + m.readByte(0xc002) should equal(4) + m.readByte(0xc003) should equal(4) } } }