mirror of
https://github.com/KarolS/millfork.git
synced 2025-01-01 06:29:53 +00:00
Optimize pointer indexing
This commit is contained in:
parent
b7300616d1
commit
3873736424
@ -53,6 +53,7 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
result += p._1
|
||||
cv = p._2
|
||||
}
|
||||
result.foreach(println(_))
|
||||
result.toList -> cv
|
||||
}
|
||||
|
||||
@ -250,7 +251,7 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
}
|
||||
}
|
||||
|
||||
def optimizeExpr(expr: Expression, currentVarValues: VV): Expression = {
|
||||
def optimizeExpr(expr: Expression, currentVarValues: VV, optimizeSum: Boolean = false): Expression = {
|
||||
val pos = expr.position
|
||||
// stdlib:
|
||||
if (optimizeStdlib) {
|
||||
@ -360,9 +361,13 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
case DerefExpression(inner, 0, _) =>
|
||||
optimizeExpr(inner, currentVarValues).pos(pos)
|
||||
case DerefExpression(inner, offset, targetType) =>
|
||||
("pointer." + targetType.name) <| (
|
||||
("pointer" <| optimizeExpr(inner, currentVarValues).pos(pos)) #+# LiteralExpression(offset, 2)
|
||||
)
|
||||
if (offset == 0) {
|
||||
("pointer." + targetType.name) <| ("pointer" <| optimizeExpr(inner, currentVarValues).pos(pos))
|
||||
} else {
|
||||
("pointer." + targetType.name) <| (
|
||||
("pointer" <| optimizeExpr(inner, currentVarValues).pos(pos)) #+# LiteralExpression(offset, 2)
|
||||
)
|
||||
}
|
||||
case IndexedExpression(name, index) =>
|
||||
ctx.log.fatal("Oops!")
|
||||
case _ =>
|
||||
@ -392,26 +397,43 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
ok = false
|
||||
LiteralExpression(0, 1)
|
||||
} else {
|
||||
val inner = optimizeExpr(result, currentVarValues).pos(pos)
|
||||
val inner = optimizeExpr(result, currentVarValues, optimizeSum = true).pos(pos)
|
||||
val fieldOffset = subvariables.head._2
|
||||
val fieldType = subvariables.head._3
|
||||
pointerWrap match {
|
||||
case 0 =>
|
||||
DerefExpression(inner, fieldOffset, fieldType)
|
||||
case 1 =>
|
||||
("pointer." + fieldType.name) <| (
|
||||
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
|
||||
)
|
||||
if (fieldOffset == 0) {
|
||||
("pointer." + fieldType.name) <| ("pointer" <| inner)
|
||||
} else {
|
||||
("pointer." + fieldType.name) <| (
|
||||
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
|
||||
)
|
||||
}
|
||||
case 2 =>
|
||||
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
|
||||
case 10 =>
|
||||
"lo" <| (
|
||||
if (fieldOffset == 0) {
|
||||
"pointer" <| inner
|
||||
} else {
|
||||
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
|
||||
)
|
||||
}
|
||||
case 10 =>
|
||||
if (fieldOffset == 0) {
|
||||
"lo" <| ("pointer" <| inner)
|
||||
} else {
|
||||
"lo" <| (
|
||||
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
|
||||
)
|
||||
}
|
||||
case 11 =>
|
||||
"hi" <| (
|
||||
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
|
||||
)
|
||||
if (fieldOffset == 0) {
|
||||
"hi" <| ("pointer" <| inner)
|
||||
} else {
|
||||
"hi" <| (
|
||||
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
|
||||
)
|
||||
}
|
||||
|
||||
case _ => throw new IllegalStateException
|
||||
}
|
||||
}
|
||||
@ -429,9 +451,9 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
}
|
||||
result
|
||||
case DerefDebuggingExpression(inner, 1) =>
|
||||
DerefExpression(optimizeExpr(inner, currentVarValues), 0, env.get[VariableType]("byte")).pos(pos)
|
||||
DerefExpression(optimizeExpr(inner, currentVarValues, optimizeSum = true), 0, env.get[VariableType]("byte")).pos(pos)
|
||||
case DerefDebuggingExpression(inner, 2) =>
|
||||
DerefExpression(optimizeExpr(inner, currentVarValues), 0, env.get[VariableType]("word")).pos(pos)
|
||||
DerefExpression(optimizeExpr(inner, currentVarValues, optimizeSum = true), 0, env.get[VariableType]("word")).pos(pos)
|
||||
case e@TextLiteralExpression(characters) =>
|
||||
val name = ctx.env.getTextLiteralArrayName(e)
|
||||
VariableExpression(name).pos(pos)
|
||||
@ -447,6 +469,18 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
if optimize && pointlessCast(t1, arg) =>
|
||||
ctx.log.debug(s"Pointless cast $t1(...)", pos)
|
||||
optimizeExpr(arg, currentVarValues)
|
||||
case FunctionCallExpression(op@("<<" | "<<'"), List(l, r)) =>
|
||||
env.eval(l) match {
|
||||
case Some(c) if c.isProvablyZero =>
|
||||
env.eval(r) match {
|
||||
case Some(rc) =>
|
||||
LiteralExpression(0, c.requiredSize)
|
||||
case _ =>
|
||||
FunctionCallExpression(op, List(optimizeExpr(l, currentVarValues), optimizeExpr(r, currentVarValues)))
|
||||
}
|
||||
case _ =>
|
||||
FunctionCallExpression(op, List(optimizeExpr(l, currentVarValues), optimizeExpr(r, currentVarValues)))
|
||||
}
|
||||
case FunctionCallExpression("nonet", args) =>
|
||||
// Eliminating variables may eliminate carry
|
||||
FunctionCallExpression("nonet", args.map(arg => optimizeExpr(arg, Map()))).pos(pos)
|
||||
@ -457,6 +491,12 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
case _ =>
|
||||
FunctionCallExpression(name, args.map(arg => optimizeExpr(arg, currentVarValues))).pos(pos)
|
||||
}
|
||||
case SumExpression(expressions, false) if optimizeSum =>
|
||||
SumExpression(expressions.map{
|
||||
case (minus, arg) => minus -> optimizeExpr(arg, currentVarValues)
|
||||
}.filterNot{
|
||||
case (_, e) => env.eval(e).exists(_.isProvablyZero)
|
||||
}, decimal = false)
|
||||
case SumExpression(expressions, decimal) =>
|
||||
// don't collapse additions, let the later stages deal with it
|
||||
// expecially important when inside a nonet operation
|
||||
@ -467,31 +507,50 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
targetType.size match {
|
||||
case 1 => IndexedExpression(name, optimizeExpr(index, Map())).pos(pos)
|
||||
case _ =>
|
||||
val arraySizeInBytes = pointy match {
|
||||
case p:ConstantPointy => p.sizeInBytes
|
||||
val constantOffset: Option[Long] = env.eval(index) match {
|
||||
case Some(z) if z.isProvablyZero => Some(0L)
|
||||
case Some(NumericConstant(n, _)) =>
|
||||
if (targetType.size * (n+1) <= 256) Some(targetType.size * n) else None
|
||||
case _ => None
|
||||
}
|
||||
val scaledIndex = arraySizeInBytes match {
|
||||
case Some(n) if n <= 256 => targetType.size match {
|
||||
case 1 => "byte" <| index
|
||||
case 2 => "<<" <| ("byte" <| index, LiteralExpression(1, 1))
|
||||
case 4 => "<<" <| ("byte" <| index, LiteralExpression(2, 1))
|
||||
case 8 => "<<" <| ("byte" <| index, LiteralExpression(3, 1))
|
||||
case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.size, 1))
|
||||
}
|
||||
case Some(n) if n <= 512 && targetType.size == 2 =>
|
||||
"nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1)))
|
||||
case _ => targetType.size match {
|
||||
case 1 => "word" <| index
|
||||
case 2 => "<<" <| ("word" <| index, LiteralExpression(1, 1))
|
||||
case 4 => "<<" <| ("word" <| index, LiteralExpression(2, 1))
|
||||
case 8 => "<<" <| ("word" <| index, LiteralExpression(3, 1))
|
||||
case _ => "*" <| ("word" <| index, LiteralExpression(targetType.size, 1))
|
||||
}
|
||||
constantOffset match {
|
||||
case Some(o) if o >= 0 && o <= 256 - targetType.size =>
|
||||
if (pointy.isArray) {
|
||||
DerefExpression(
|
||||
"pointer" <| VariableExpression(name).pos(pos),
|
||||
o.toInt, pointy.elementType).pos(pos)
|
||||
} else {
|
||||
DerefExpression(
|
||||
VariableExpression(name).pos(pos),
|
||||
o.toInt, pointy.elementType).pos(pos)
|
||||
}
|
||||
case _ =>
|
||||
val arraySizeInBytes = pointy match {
|
||||
case p: ConstantPointy => p.sizeInBytes
|
||||
case _ => None
|
||||
}
|
||||
val scaledIndex = arraySizeInBytes match {
|
||||
case Some(n) if n <= 256 => targetType.size match {
|
||||
case 1 => "byte" <| index
|
||||
case 2 => "<<" <| ("byte" <| index, LiteralExpression(1, 1))
|
||||
case 4 => "<<" <| ("byte" <| index, LiteralExpression(2, 1))
|
||||
case 8 => "<<" <| ("byte" <| index, LiteralExpression(3, 1))
|
||||
case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.size, 1))
|
||||
}
|
||||
case Some(n) if n <= 512 && targetType.size == 2 =>
|
||||
"nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1)))
|
||||
case _ => targetType.size match {
|
||||
case 1 => "word" <| index
|
||||
case 2 => "<<" <| ("word" <| index, LiteralExpression(1, 1))
|
||||
case 4 => "<<" <| ("word" <| index, LiteralExpression(2, 1))
|
||||
case 8 => "<<" <| ("word" <| index, LiteralExpression(3, 1))
|
||||
case _ => "*" <| ("word" <| index, LiteralExpression(targetType.size, 1))
|
||||
}
|
||||
}
|
||||
DerefExpression(
|
||||
("pointer" <| VariableExpression(name).pos(pos)) #+# optimizeExpr(scaledIndex, Map()),
|
||||
0, pointy.elementType).pos(pos)
|
||||
}
|
||||
DerefExpression(
|
||||
("pointer" <| VariableExpression(name).pos(pos)) #+# optimizeExpr(scaledIndex, Map()),
|
||||
0, pointy.elementType).pos(pos)
|
||||
}
|
||||
case _ => expr // TODO
|
||||
}
|
||||
|
7
src/main/scala/millfork/env/Constant.scala
vendored
7
src/main/scala/millfork/env/Constant.scala
vendored
@ -394,7 +394,7 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co
|
||||
case MathOperator.Or | MathOperator.Exor | MathOperator.Plus | MathOperator.Minus =>
|
||||
lhs.isProvablyDivisibleBy256 && rhs.isProvablyDivisibleBy256
|
||||
case MathOperator.Shl =>
|
||||
rhs.isProvablyGreaterOrEqualThan(NumericConstant(8, 1))
|
||||
rhs.isProvablyGreaterOrEqualThan(NumericConstant(8, 1)) || lhs.isProvablyDivisibleBy256
|
||||
case _ => false
|
||||
}
|
||||
|
||||
@ -454,6 +454,11 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co
|
||||
case MathOperator.Modulo => Constant.Zero
|
||||
case _ => CompoundConstant(operator, l, r)
|
||||
}
|
||||
case (NumericConstant(0, _), c) =>
|
||||
operator match {
|
||||
case MathOperator.Shl => l
|
||||
case _ => CompoundConstant(operator, l, r)
|
||||
}
|
||||
case (c, NumericConstant(0, 1)) =>
|
||||
operator match {
|
||||
case MathOperator.Plus => c
|
||||
|
7
src/main/scala/millfork/env/Pointy.scala
vendored
7
src/main/scala/millfork/env/Pointy.scala
vendored
@ -7,16 +7,19 @@ trait Pointy {
|
||||
def indexType: VariableType
|
||||
def elementType: VariableType
|
||||
def readOnly: Boolean
|
||||
def isArray: Boolean
|
||||
}
|
||||
|
||||
case class StackVariablePointy(offset: Int, indexType: VariableType, elementType: VariableType) extends Pointy {
|
||||
override def name: Option[String] = None
|
||||
override def readOnly: Boolean = false
|
||||
override def isArray: Boolean = false
|
||||
}
|
||||
|
||||
case class VariablePointy(addr: Constant, indexType: VariableType, elementType: VariableType, zeropage: Boolean) extends Pointy {
|
||||
override def name: Option[String] = None
|
||||
override def readOnly: Boolean = false
|
||||
override def isArray: Boolean = false
|
||||
}
|
||||
|
||||
case class ConstantPointy(value: Constant,
|
||||
@ -26,4 +29,6 @@ case class ConstantPointy(value: Constant,
|
||||
indexType: VariableType,
|
||||
elementType: VariableType,
|
||||
alignment: MemoryAlignment,
|
||||
override val readOnly: Boolean) extends Pointy
|
||||
override val readOnly: Boolean) extends Pointy {
|
||||
override def isArray: Boolean = elementCount.isDefined
|
||||
}
|
||||
|
@ -322,4 +322,21 @@ class PointerSuite extends FunSuite with Matchers with AppendedClues {
|
||||
m.readWord(0xc000) should equal(5)
|
||||
}
|
||||
}
|
||||
|
||||
test("Fast pointer indexing") {
|
||||
EmuUnoptimizedCrossPlatformRun(Cpu.Mos) (
|
||||
"""
|
||||
|pointer.word p
|
||||
|array(word) input [6]
|
||||
|word output @$c000
|
||||
|void main () {
|
||||
| input[3] = 555
|
||||
| output = f(input.pointer)
|
||||
|}
|
||||
|noinline word f(pointer.word p) = p[3]
|
||||
""".stripMargin
|
||||
){ m =>
|
||||
m.readWord(0xc000) should equal(555)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user