replace most common subexpressions by a single temp variable

This commit is contained in:
Irmen de Jong 2024-01-01 14:55:29 +01:00
parent 07feb5c925
commit d1f8ee1e56
10 changed files with 225 additions and 25 deletions

View File

@ -31,7 +31,10 @@ sealed class PtNode(val position: Position) {
}
class PtNodeGroup : PtNode(Position.DUMMY)
sealed interface IPtStatementContainer
class PtNodeGroup : PtNode(Position.DUMMY), IPtStatementContainer
sealed class PtNamedNode(var name: String, position: Position): PtNode(position) {
@ -75,7 +78,7 @@ class PtBlock(name: String,
val source: SourceCode, // taken from the module the block is defined in.
val options: Options,
position: Position
) : PtNamedNode(name, position) {
) : PtNamedNode(name, position), IPtStatementContainer {
enum class BlockAlignment {
NONE,
WORD,

View File

@ -36,7 +36,14 @@ sealed class PtExpression(val type: DataType, position: Position) : PtNode(posit
return arrayIndexExpr!! isSameAs other.arrayIndexExpr!!
}
is PtArrayIndexer -> other is PtArrayIndexer && other.type==type && other.variable isSameAs variable && other.index isSameAs index && other.splitWords==splitWords
is PtBinaryExpression -> other is PtBinaryExpression && other.left isSameAs left && other.right isSameAs right
is PtBinaryExpression -> {
if(other !is PtBinaryExpression || other.operator!=operator)
false
else if(operator in AssociativeOperators)
(other.left isSameAs left && other.right isSameAs right) || (other.left isSameAs right && other.right isSameAs left)
else
other.left isSameAs left && other.right isSameAs right
}
is PtContainmentCheck -> other is PtContainmentCheck && other.type==type && other.element isSameAs element && other.iterable isSameAs iterable
is PtIdentifier -> other is PtIdentifier && other.type==type && other.name==name
is PtMachineRegister -> other is PtMachineRegister && other.type==type && other.register==register

View File

@ -23,7 +23,7 @@ class PtSub(
val parameters: List<PtSubroutineParameter>,
val returntype: DataType?,
position: Position
) : PtNamedNode(name, position), IPtSubroutine {
) : PtNamedNode(name, position), IPtSubroutine, IPtStatementContainer {
init {
// params and return value should not be str
if(parameters.any{ it.type !in NumericDatatypes })

View File

@ -0,0 +1,120 @@
package prog8.code.optimize
import prog8.code.ast.*
import prog8.code.core.*
fun optimizeIntermediateAst(program: PtProgram, options: CompilationOptions, errors: IErrorReporter) {
if (!options.optimize)
return
while(errors.noErrors() && optimizeCommonSubExpressions(program, errors)>0) {
// keep rolling
}
}
private fun walkAst(root: PtNode, act: (node: PtNode, depth: Int) -> Boolean) {
fun recurse(node: PtNode, depth: Int) {
if(act(node, depth))
node.children.forEach { recurse(it, depth+1) }
}
recurse(root, 0)
}
private fun optimizeCommonSubExpressions(program: PtProgram, errors: IErrorReporter): Int {
fun extractableSubExpr(expr: PtExpression): Boolean {
return if(expr is PtBinaryExpression)
!expr.left.isSimple() || !expr.right.isSimple() || (expr.operator !in LogicalOperators && expr.operator !in BitwiseOperators)
else
!expr.isSimple()
}
// for each Binaryexpression, recurse to find a common subexpression pair therein.
val commons = mutableMapOf<PtBinaryExpression, Pair<PtExpression, PtExpression>>()
walkAst(program) { node: PtNode, depth: Int ->
if(node is PtBinaryExpression) {
val subExpressions = mutableListOf<PtExpression>()
walkAst(node.left) { subNode: PtNode, subDepth: Int ->
if (subNode is PtExpression) {
if(extractableSubExpr(subNode)) subExpressions.add(subNode)
true
} else false
}
walkAst(node.right) { subNode: PtNode, subDepth: Int ->
if (subNode is PtExpression) {
if(extractableSubExpr(subNode)) subExpressions.add(subNode)
true
} else false
}
outer@for (first in subExpressions) {
for (second in subExpressions) {
if (first!==second && first isSameAs second) {
commons[node] = first to second
break@outer // do only 1 replacement at a time per binaryexpression
}
}
}
false
} else true
}
// replace common subexpressions by a temp variable that is assigned only once.
commons.forEach { binexpr, (occurrence1, occurrence2) ->
val (stmtContainer, stmt) = findContainingStatements(binexpr)
val occurrence1idx = occurrence1.parent.children.indexOf(occurrence1)
val occurrence2idx = occurrence2.parent.children.indexOf(occurrence2)
val containerScopedName = findScopeName(stmtContainer)
val tempvarName = "subexprvar_line${binexpr.position.line}_${binexpr.hashCode().toUInt()}"
// TODO: some tempvars could be reused, if they are from different lines
val datatype = occurrence1.type
val singleReplacement1 = PtIdentifier("$containerScopedName.$tempvarName", datatype, occurrence1.position)
val singleReplacement2 = PtIdentifier("$containerScopedName.$tempvarName", datatype, occurrence2.position)
occurrence1.parent.children[occurrence1idx] = singleReplacement1
singleReplacement1.parent = occurrence1.parent
occurrence2.parent.children[occurrence2idx] = singleReplacement2
singleReplacement2.parent = occurrence2.parent
val tempassign = PtAssignment(binexpr.position).also { assign ->
assign.add(PtAssignTarget(binexpr.position).also { tgt->
tgt.add(PtIdentifier("$containerScopedName.$tempvarName", datatype, binexpr.position))
})
assign.add(occurrence1)
occurrence1.parent = assign
}
stmtContainer.children.add(stmtContainer.children.indexOf(stmt), tempassign)
tempassign.parent = stmtContainer
val tempvar = PtVariable(tempvarName, datatype, ZeropageWish.DONTCARE, null, null, binexpr.position)
stmtContainer.add(0, tempvar)
tempvar.parent = stmtContainer
errors.info("common subexpressions replaced by a tempvar, maybe simplify the expression manually", binexpr.position)
}
return commons.size
}
internal fun findScopeName(node: PtNode): String {
var parent=node
while(parent !is PtNamedNode)
parent = parent.parent
return parent.scopedName
}
internal fun findContainingStatements(node: PtNode): Pair<PtNode, PtNode> { // returns (parentstatementcontainer, childstatement)
var parent = node.parent
var child = node
while(true) {
if(parent is IPtStatementContainer) {
return parent to child
}
child=parent
parent=parent.parent
}
}

View File

@ -10,6 +10,7 @@ import prog8.ast.statements.Directive
import prog8.code.SymbolTableMaker
import prog8.code.ast.PtProgram
import prog8.code.core.*
import prog8.code.optimize.optimizeIntermediateAst
import prog8.code.target.*
import prog8.codegen.vm.VmCodeGen
import prog8.compiler.astprocessing.*
@ -102,7 +103,6 @@ fun compileProgram(args: CompilerArguments): CompilationResult? {
compilationOptions,
args.errors,
BuiltinFunctionsFacade(BuiltinFunctions),
compTarget
)
}
postprocessAst(program, args.errors, compilationOptions)
@ -124,6 +124,9 @@ fun compileProgram(args: CompilerArguments): CompilationResult? {
// printProgram(program)
val intermediateAst = IntermediateAstMaker(program, args.errors).transform()
// printAst(intermediateAst, true) { println(it) }
optimizeIntermediateAst(intermediateAst, compilationOptions, args.errors)
args.errors.report()
// println("*********** AST RIGHT BEFORE ASM GENERATION *************")
// printAst(intermediateAst, true, ::println)
@ -378,13 +381,13 @@ private fun processAst(program: Program, errors: IErrorReporter, compilerOptions
errors.report()
}
private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, errors: IErrorReporter, functions: IBuiltinFunctions, compTarget: ICompilationTarget) {
val remover = UnusedCodeRemover(program, errors, compTarget)
private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, errors: IErrorReporter, functions: IBuiltinFunctions) {
val remover = UnusedCodeRemover(program, errors, compilerOptions.compTarget)
remover.visit(program)
remover.applyModifications()
while (true) {
// keep optimizing expressions and statements until no more steps remain
val optsDone1 = program.simplifyExpressions(errors, compTarget)
val optsDone1 = program.simplifyExpressions(errors, compilerOptions.compTarget)
val optsDone2 = program.optimizeStatements(errors, functions, compilerOptions)
val optsDone3 = program.inlineSubroutines(compilerOptions)
program.constantFold(errors, compilerOptions) // because simplified statements and expressions can result in more constants that can be folded away
@ -395,7 +398,7 @@ private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, e
if (optsDone1 + optsDone2 + optsDone3 == 0)
break
}
val remover2 = UnusedCodeRemover(program, errors, compTarget)
val remover2 = UnusedCodeRemover(program, errors, compilerOptions.compTarget)
remover2.visit(program)
remover2.applyModifications()
if(errors.noErrors())

View File

@ -6,6 +6,7 @@ import io.kotest.matchers.ints.shouldBeGreaterThan
import io.kotest.matchers.shouldBe
import prog8.code.ast.*
import prog8.code.core.DataType
import prog8.code.core.Position
import prog8.code.target.C64Target
import prog8.compiler.astprocessing.IntermediateAstMaker
import prog8tests.helpers.ErrorReporterForTests
@ -60,4 +61,31 @@ class TestIntermediateAst: FunSpec({
fcall.type shouldBe DataType.UBYTE
}
test("isSame on binaryExpressions") {
val expr1 = PtBinaryExpression("/", DataType.UBYTE, Position.DUMMY)
expr1.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
expr1.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
val expr2 = PtBinaryExpression("/", DataType.UBYTE, Position.DUMMY)
expr2.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
expr2.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
(expr1 isSameAs expr2) shouldBe true
val expr3 = PtBinaryExpression("/", DataType.UBYTE, Position.DUMMY)
expr3.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
expr3.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
(expr1 isSameAs expr3) shouldBe false
}
test("isSame on binaryExpressions with associative operators") {
val expr1 = PtBinaryExpression("+", DataType.UBYTE, Position.DUMMY)
expr1.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
expr1.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
val expr2 = PtBinaryExpression("+", DataType.UBYTE, Position.DUMMY)
expr2.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
expr2.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
(expr1 isSameAs expr2) shouldBe true
val expr3 = PtBinaryExpression("+", DataType.UBYTE, Position.DUMMY)
expr3.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
expr3.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
(expr1 isSameAs expr3) shouldBe true
}
})

View File

@ -511,5 +511,33 @@ main {
val value = (st[5] as Assignment).value as BinaryExpression
value.operator shouldBe "%"
}
test("isSame on binary expressions") {
val left1 = NumericLiteral.optimalInteger(1, Position.DUMMY)
val right1 = NumericLiteral.optimalInteger(2, Position.DUMMY)
val expr1 = BinaryExpression(left1, "/", right1, Position.DUMMY)
val left2 = NumericLiteral.optimalInteger(1, Position.DUMMY)
val right2 = NumericLiteral.optimalInteger(2, Position.DUMMY)
val expr2 = BinaryExpression(left2, "/", right2, Position.DUMMY)
(expr1 isSameAs expr2) shouldBe true
val left3 = NumericLiteral.optimalInteger(2, Position.DUMMY)
val right3 = NumericLiteral.optimalInteger(1, Position.DUMMY)
val expr3 = BinaryExpression(left3, "/", right3, Position.DUMMY)
(expr1 isSameAs expr3) shouldBe false
}
test("isSame on binary expressions with associative operators") {
val left1 = NumericLiteral.optimalInteger(1, Position.DUMMY)
val right1 = NumericLiteral.optimalInteger(2, Position.DUMMY)
val expr1 = BinaryExpression(left1, "+", right1, Position.DUMMY)
val left2 = NumericLiteral.optimalInteger(1, Position.DUMMY)
val right2 = NumericLiteral.optimalInteger(2, Position.DUMMY)
val expr2 = BinaryExpression(left2, "+", right2, Position.DUMMY)
(expr1 isSameAs expr2) shouldBe true
val left3 = NumericLiteral.optimalInteger(2, Position.DUMMY)
val right3 = NumericLiteral.optimalInteger(1, Position.DUMMY)
val expr3 = BinaryExpression(left3, "+", right3, Position.DUMMY)
(expr1 isSameAs expr3) shouldBe true
}
})

View File

@ -34,11 +34,14 @@ sealed class Expression: Node {
(other is IdentifierReference && other.nameInSource==nameInSource)
is PrefixExpression ->
(other is PrefixExpression && other.operator==operator && other.expression isSameAs expression)
is BinaryExpression ->
(other is BinaryExpression && other.operator==operator
&& other.left isSameAs left
&& other.right isSameAs right
&& other.isChainedComparison() == isChainedComparison())
is BinaryExpression -> {
if(other !is BinaryExpression || other.operator!=operator || other.isChainedComparison()!=isChainedComparison())
false
else if(operator in AssociativeOperators)
(other.left isSameAs left && other.right isSameAs right) || (other.left isSameAs right && other.right isSameAs left)
else
other.left isSameAs left && other.right isSameAs right
}
is ArrayIndexedExpression -> {
(other is ArrayIndexedExpression && other.arrayvar.nameInSource == arrayvar.nameInSource
&& other.indexer isSameAs indexer)

View File

@ -1,10 +1,18 @@
TODO
====
PtAst/IR: attempt more complex common subexpression eliminations.
for any "top level" PtExpression enumerate all subexpressions and find commons, replace them by a tempvar
for walking the ast see walkAst() but it should not recurse into the "top level" PtExpression again
- why is the right term of cx16.r0 = (cx16.r1+cx16.r2) + (cx16.r1+cx16.r2) flipped around but the left term isn't?
- Revert or fix current "desugar chained comparisons" it causes problems with if statements.
ubyte @shared n=20
ubyte @shared L1=10
ubyte @shared L2=100
if n < L1 {
;txt.print("bing")
} else {
txt.print("boom") ; no longer triggers
}
...
@ -36,7 +44,6 @@ Compiler:
global initialization values are simply a list of LOAD instructions.
Variables replaced include all subroutine parameters! So the only variables that remain as variables are arrays and strings.
- ir: add more optimizations in IRPeepholeOptimizer
- ir: for expressions with array indexes that occur multiple times, can we avoid loading them into new virtualregs everytime and just reuse a single virtualreg as indexer? (this is a form of common subexpression elimination)
- ir: the @split arrays are currently also split in _lsb/_msb arrays in the IR, and operations take multiple (byte) instructions that may lead to verbose and slow operation and machine code generation down the line.
maybe another representation is needed once actual codegeneration is done from the IR...?
- [problematic due to using 64tass:] better support for building library programs, where unused .proc shouldn't be deleted from the assembly?

View File

@ -1,13 +1,14 @@
%import textio
%zeropage basicsafe
; Note: this program can be compiled for multiple target systems.
main {
sub start() {
cx16.r0L=1
while cx16.r0L < 10 and cx16.r0L>0 {
cx16.r0L++
}
ubyte[10] array1
ubyte[10] array2
ubyte @shared xx
cx16.r0 = (cx16.r1+cx16.r2) / (cx16.r2+cx16.r1)
cx16.r1 = 4*(cx16.r1+cx16.r2) + 3*(cx16.r1+cx16.r2)
cx16.r2 = array1[xx+20]==10 or array2[xx+20]==20 or array1[xx+20]==30 or array2[xx+20]==40
}
}