diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 38b3e7f8..d77930df 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -62,16 +62,19 @@ class AbstractExpressionCompiler[T <: AbstractCode] { ctx.log.error("Long multiplication not supported", params.head.position) } if (inPlace) { - if (lSize == 2 && rSize == 1 && rType.isSigned) { - ctx.log.error("Signed multiplication not supported", params.head.position) + if (params.size > 2) { + ctx.log.error("Too many arguments for *=", params.head.position) } +// if (lSize == 2 && rSize == 1 && rType.isSigned) { +// ctx.log.error("Signed multiplication not supported", params.head.position) +// } } else { - if (lSize == 2 && rSize == 1 && rType.isSigned) { - ctx.log.error("Signed multiplication not supported", params.head.position) - } - if (rSize == 2 && lSize == 1 && lType.isSigned) { - ctx.log.error("Signed multiplication not supported", params.head.position) - } +// if (lSize == 2 && rSize == 1 && rType.isSigned) { +// ctx.log.error("Signed multiplication not supported", params.head.position) +// } +// if (rSize == 2 && lSize == 1 && lType.isSigned) { +// ctx.log.error("Signed multiplication not supported", params.head.position) +// } } } @@ -503,7 +506,23 @@ object AbstractExpressionCompiler { case List(n, _) if n >= 3 => env.get[Type]("int" + n * 8) case _ => log.error(s"Invalid parameters to %%", expr.position); w } - case FunctionCallExpression(op@("*" | "|" | "&" | "^" | "/"), params) => params.map { e => getExpressionTypeImpl(env, log, e, loosely).size }.max match { + case FunctionCallExpression(op@("*"), params) => + val paramTypes = params.map { e => getExpressionTypeImpl(env, log, e, loosely) } + val signed = paramTypes.exists(_.isSigned) + val unsigned = paramTypes.exists{ + case t: DerivedPlainType => !t.isSigned + case _ => false + } + if (signed && unsigned) { + log.error("Mixing signed and explicitly unsigned types in multiplication", expr.position) + } + paramTypes.map(_.size).max match { + case 0 | 1 => if (signed) env.get[Type]("sbyte") else b + case 2 => if (signed) env.get[Type]("signed16") else w + case n if n >= 3 => env.get[Type]("int" + n * 8) + case _ => log.error(s"Invalid parameters to " + op, expr.position); w + } + case FunctionCallExpression(op@("|" | "&" | "^" | "/"), params) => params.map { e => getExpressionTypeImpl(env, log, e, loosely).size }.max match { case 0 | 1 => b case 2 => w case n if n >= 3 => env.get[Type]("int" + n * 8) diff --git a/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala b/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala index 35059b03..e089b36b 100644 --- a/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala @@ -270,6 +270,7 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] { Nil } case "*" => + assertSizesForMultiplication(ctx, params, inPlace = false) getArithmeticParamMaxSize(ctx, params) match { case 1 => M6809MulDiv.compileByteMultiplication(ctx, params, updateDerefX = false) ++ targetifyB(ctx, target, isSigned = false) case 2 => M6809MulDiv.compileWordMultiplication(ctx, params, updateDerefX = false) ++ targetifyD(ctx, target) @@ -474,6 +475,7 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] { case _ => M6809LargeBuiltins.modifyInPlaceViaX(ctx, l, r, SUBA) } case "*=" => + assertSizesForMultiplication(ctx, params, inPlace = true) val (l, r, size) = assertArithmeticAssignmentLike(ctx, params) size match { case 1 => compileAddressToX(ctx, l) ++ M6809MulDiv.compileByteMultiplication(ctx, List(r), updateDerefX = true) diff --git a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala index 77c04ff8..71e8db8b 100644 --- a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala @@ -718,11 +718,16 @@ object PseudoregisterBuiltIns { ctx.log.error("Variable word-byte multiplication requires the zeropage pseudoregister of size at least 3", param1OrRegister.flatMap(_.position)) return Nil } - (param1OrRegister.fold(2)(e => AbstractExpressionCompiler.getExpressionType(ctx, e).size), - AbstractExpressionCompiler.getExpressionType(ctx, param2).size) match { + val lType = param1OrRegister.map(e => AbstractExpressionCompiler.getExpressionType(ctx, e)) + val rType = AbstractExpressionCompiler.getExpressionType(ctx, param2) + (lType.fold(2)(_.size), rType.size) match { case (2, 2) => return compileWordWordMultiplication(ctx, param1OrRegister, param2) case (1, 2) => return compileWordMultiplication(ctx, Some(param2), param1OrRegister.get, storeInRegLo) - case (2 | 1, 1) => // ok + case (2, 1) => + if (rType.isSigned) { + return compileWordWordMultiplication(ctx, param1OrRegister, param2) + } + case (1, 1) => // ok case _ => ctx.log.fatal("Invalid code path", param2.position) } if (!storeInRegLo && param1OrRegister.isDefined) { diff --git a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala index 5a30237c..917854ce 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala @@ -317,11 +317,13 @@ object Z80Multiply { * Calculate HL = l * r */ def compile16BitMultiplyToHL(ctx: CompilationContext, l: Expression, r: Expression): List[ZLine] = { - (AbstractExpressionCompiler.getExpressionType(ctx, l).size, - AbstractExpressionCompiler.getExpressionType(ctx, r).size) match { + val lType = AbstractExpressionCompiler.getExpressionType(ctx, l) + val rType = AbstractExpressionCompiler.getExpressionType(ctx, r) + (lType.size, rType.size) match { case (2, 2) => return compile16x16BitMultiplyToHL(ctx, l, r) case (1, 2) => return compile16BitMultiplyToHL(ctx, r, l) - case (2 | 1, 1) => // ok + case (2, 1) => if (rType.isSigned) return compile16x16BitMultiplyToHL(ctx, l, r) + case (1, 1) => // ok case _ => ctx.log.fatal("Invalid code path", l.position) } ctx.env.eval(r) match { diff --git a/src/test/scala/millfork/test/WordMathSuite.scala b/src/test/scala/millfork/test/WordMathSuite.scala index eebd4bd2..4df69c4c 100644 --- a/src/test/scala/millfork/test/WordMathSuite.scala +++ b/src/test/scala/millfork/test/WordMathSuite.scala @@ -814,4 +814,56 @@ class WordMathSuite extends FunSuite with Matchers with AppendedClues { } } } + + test("Signed multiplication with type promotion") { + for { + (t1, t2) <- Seq("sbyte" -> "word", "sbyte" -> "signed16", "byte" -> "signed16", "signed16" -> "word") + x <- Seq(0, -1, 1, 120, -120) + x2 <- Seq(0, -1, 1, 120, -120) + } { + val x1 = if (t1 == "byte") x & 0xff else x + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + s""" + | import zp_reg + | signed16 output @$$c000 + | bool typeOk @$$c002 + | void main () { + | $t1 v1 + | v1 = $x1 + | $t2 v2 + | v2 = $x2 + | memory_barrier() + | output = v1 * v2 + | typeOk = typeof(v1 * v2) == typeof(signed16) + | }""". + stripMargin) { m => + m.readWord(0xc000).toShort should equal((x1 * x2).toShort) withClue s"= $t1($x1) * $t2($x2)" + m.readByte(0xc002) should equal(1) withClue s"= typeof($t1 * $t2)" + } + } + } + + test("Signed multiplication with type promotion 2") { + for { + t2 <- Seq("sbyte", "signed16", "byte", "word") + x1 <- Seq(0, -1, 1, 120, -120) + x <- Seq(0, -1, 1, 120, -120) + } { + val x2 = if (t2.startsWith("s")) x else if (t2.startsWith("w")) x & 0xffff else x & 0xff + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + s""" + | import zp_reg + | signed16 output @$$c000 + | void main () { + | output = $x1 + | $t2 v2 + | v2 = $x2 + | memory_barrier() + | output *= v2 + | }""". + stripMargin) { m => + m.readWord(0xc000).toShort should equal((x1 * x2).toShort) withClue s"= signed16($x1) * $t2($x2)" + } + } + } }