pipes also as expressions, cleanup codegen, fix various typecasting issues

This commit is contained in:
Irmen de Jong
2022-01-08 13:45:00 +01:00
parent 75d857027e
commit 7dd7e562bc
21 changed files with 368 additions and 118 deletions

View File

@@ -47,7 +47,7 @@ class AsmGen(private val program: Program,
private val forloopsAsmGen = ForLoopsAsmGen(program, this) private val forloopsAsmGen = ForLoopsAsmGen(program, this)
private val postincrdecrAsmGen = PostIncrDecrAsmGen(program, this) private val postincrdecrAsmGen = PostIncrDecrAsmGen(program, this)
private val functioncallAsmGen = FunctionCallAsmGen(program, this) private val functioncallAsmGen = FunctionCallAsmGen(program, this)
private val expressionsAsmGen = ExpressionsAsmGen(program, this) private val expressionsAsmGen = ExpressionsAsmGen(program, this, functioncallAsmGen)
private val assignmentAsmGen = AssignmentAsmGen(program, this) private val assignmentAsmGen = AssignmentAsmGen(program, this)
private val builtinFunctionsAsmGen = BuiltinFunctionsAsmGen(program, this, assignmentAsmGen) private val builtinFunctionsAsmGen = BuiltinFunctionsAsmGen(program, this, assignmentAsmGen)
internal val loopEndLabels = ArrayDeque<String>() internal val loopEndLabels = ArrayDeque<String>()
@@ -850,7 +850,7 @@ class AsmGen(private val program: Program,
is RepeatLoop -> translate(stmt) is RepeatLoop -> translate(stmt)
is When -> translate(stmt) is When -> translate(stmt)
is AnonymousScope -> translate(stmt) is AnonymousScope -> translate(stmt)
is Pipe -> translate(stmt) is Pipe -> expressionsAsmGen.translatePipeExpression(stmt.expressions, stmt,true)
is BuiltinFunctionPlaceholder -> throw AssemblyError("builtin function should not have placeholder anymore") is BuiltinFunctionPlaceholder -> throw AssemblyError("builtin function should not have placeholder anymore")
is UntilLoop -> throw AssemblyError("do..until should have been converted to jumps") is UntilLoop -> throw AssemblyError("do..until should have been converted to jumps")
is WhileLoop -> throw AssemblyError("while should have been converted to jumps") is WhileLoop -> throw AssemblyError("while should have been converted to jumps")
@@ -1628,31 +1628,7 @@ $label nop""")
assemblyLines.add(assembly) assemblyLines.add(assembly)
} }
private fun translate(pipe: Pipe) { internal fun returnRegisterOfFunction(it: IdentifierReference): RegisterOrPair? {
// TODO more efficient code generation to avoid needless assignments to the temp var
var valueDt = pipe.valueDatatype(program)
var valueVar = getTempVarName(valueDt)
val subroutine = pipe.definingSubroutine
assignExpressionToVariable(pipe.expressions.first(), valueVar.joinToString("."), valueDt, subroutine)
pipe.expressions.drop(1).dropLast(1).forEach {
valueDt = functioncallAsmGen.translateFunctionCall(it as IdentifierReference, listOf(IdentifierReference(valueVar, it.position)), pipe)
// assign result value from the functioncall back to the temp var:
valueVar = getTempVarName(valueDt)
val valueVarTarget = AsmAssignTarget(TargetStorageKind.VARIABLE, program, this, valueDt, subroutine, variableAsmName = valueVar.joinToString("."))
val returnRegister = returnRegisterOfFunction(it)!!
assignRegister(returnRegister, valueVarTarget)
}
// the last term in the pipe, don't care about return var:
functioncallAsmGen.translateFunctionCallStatement(
pipe.expressions.last() as IdentifierReference,
listOf(IdentifierReference(valueVar, pipe.expressions.last().position)),
pipe
)
}
private fun returnRegisterOfFunction(it: IdentifierReference): RegisterOrPair? {
return when (val targetRoutine = it.targetStatement(program)!!) { return when (val targetRoutine = it.targetStatement(program)!!) {
is BuiltinFunctionPlaceholder -> { is BuiltinFunctionPlaceholder -> {
when (BuiltinFunctions.getValue(targetRoutine.name).known_returntype) { when (BuiltinFunctions.getValue(targetRoutine.name).known_returntype) {

View File

@@ -1,5 +1,6 @@
package prog8.codegen.target.cpu6502.codegen package prog8.codegen.target.cpu6502.codegen
import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
@@ -7,11 +8,13 @@ import prog8.ast.statements.BuiltinFunctionPlaceholder
import prog8.ast.statements.Subroutine import prog8.ast.statements.Subroutine
import prog8.ast.toHex import prog8.ast.toHex
import prog8.codegen.target.AssemblyError import prog8.codegen.target.AssemblyError
import prog8.codegen.target.cpu6502.codegen.assignment.AsmAssignTarget
import prog8.codegen.target.cpu6502.codegen.assignment.TargetStorageKind
import prog8.compilerinterface.BuiltinFunctions import prog8.compilerinterface.BuiltinFunctions
import prog8.compilerinterface.CpuType import prog8.compilerinterface.CpuType
import kotlin.math.absoluteValue import kotlin.math.absoluteValue
internal class ExpressionsAsmGen(private val program: Program, private val asmgen: AsmGen) { internal class ExpressionsAsmGen(private val program: Program, private val asmgen: AsmGen, private val functioncallAsmGen: FunctionCallAsmGen) {
@Deprecated("avoid calling this as it generates slow evalstack based code") @Deprecated("avoid calling this as it generates slow evalstack based code")
internal fun translateExpression(expression:Expression) { internal fun translateExpression(expression:Expression) {
@@ -37,6 +40,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
is NumericLiteralValue -> translateExpression(expression) is NumericLiteralValue -> translateExpression(expression)
is IdentifierReference -> translateExpression(expression) is IdentifierReference -> translateExpression(expression)
is FunctionCallExpression -> translateFunctionCallResultOntoStack(expression) is FunctionCallExpression -> translateFunctionCallResultOntoStack(expression)
is PipeExpression -> translatePipeExpression(expression.expressions, expression,false)
is ContainmentCheck -> throw AssemblyError("containment check as complex expression value is not supported") is ContainmentCheck -> throw AssemblyError("containment check as complex expression value is not supported")
is ArrayLiteralValue, is StringLiteralValue -> throw AssemblyError("no asm gen for string/array literal value assignment - should have been replaced by a variable") is ArrayLiteralValue, is StringLiteralValue -> throw AssemblyError("no asm gen for string/array literal value assignment - should have been replaced by a variable")
is RangeExpression -> throw AssemblyError("range expression should have been changed into array values") is RangeExpression -> throw AssemblyError("range expression should have been changed into array values")
@@ -789,4 +793,38 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
} }
asmgen.out(" dex") asmgen.out(" dex")
} }
internal fun translatePipeExpression(expressions: Iterable<Expression>, scope: Node, isStatement: Boolean) {
// TODO more efficient code generation to avoid needless assignments to the temp var
val subroutine = scope.definingSubroutine
var valueDt = expressions.first().inferType(program).getOrElse { throw FatalAstException("invalid dt") }
var valueVar = asmgen.getTempVarName(valueDt)
asmgen.assignExpressionToVariable(expressions.first(), valueVar.joinToString("."), valueDt, subroutine)
expressions.drop(1).dropLast(1).forEach {
valueDt = functioncallAsmGen.translateFunctionCall(it as IdentifierReference, listOf(IdentifierReference(valueVar, it.position)), scope)
// assign result value from the functioncall back to the temp var:
valueVar = asmgen.getTempVarName(valueDt)
val valueVarTarget = AsmAssignTarget(TargetStorageKind.VARIABLE, program, asmgen, valueDt, subroutine, variableAsmName = valueVar.joinToString("."))
val returnRegister = asmgen.returnRegisterOfFunction(it)!!
asmgen.assignRegister(returnRegister, valueVarTarget)
}
if(isStatement) {
// the last term in the pipe, don't care about return var:
functioncallAsmGen.translateFunctionCallStatement(
expressions.last() as IdentifierReference,
listOf(IdentifierReference(valueVar, expressions.last().position)),
scope
)
} else {
// the last term in the pipe, regular function call with returnvalue:
functioncallAsmGen.translateFunctionCall(
expressions.last() as IdentifierReference,
listOf(IdentifierReference(valueVar, expressions.last().position)),
scope
)
}
}
} }

View File

@@ -361,16 +361,36 @@ class ExpressionSimplifier(private val program: Program, private val errors: IEr
return noModifications return noModifications
} }
override fun after(pipe: Pipe, parent: Node): Iterable<IAstModification> { override fun after(pipeExpr: PipeExpression, parent: Node): Iterable<IAstModification> {
val firstValue = pipe.expressions.first() val expressions = pipeExpr.expressions
val firstValue = expressions.first()
if(firstValue.isSimple) { if(firstValue.isSimple) {
val funcname = pipe.expressions[1] as IdentifierReference val funcname = expressions[1] as IdentifierReference
val first = FunctionCallExpression(funcname.copy(), mutableListOf(firstValue), firstValue.position) val first = FunctionCallExpression(funcname.copy(), mutableListOf(firstValue), firstValue.position)
val newExprs = mutableListOf<Expression>(first) val newExprs = mutableListOf<Expression>(first)
newExprs.addAll(pipe.expressions.drop(2)) newExprs.addAll(expressions.drop(2))
return listOf(IAstModification.ReplaceNode(pipeExpr, PipeExpression(newExprs, pipeExpr.position), parent))
}
val singleExpr = expressions.singleOrNull()
if(singleExpr!=null) {
val callExpr = singleExpr as FunctionCallExpression
val call = FunctionCallExpression(callExpr.target, callExpr.args, callExpr.position)
return listOf(IAstModification.ReplaceNode(pipeExpr, call, parent))
}
return noModifications
}
override fun after(pipe: Pipe, parent: Node): Iterable<IAstModification> {
val expressions = pipe.expressions
val firstValue = expressions.first()
if(firstValue.isSimple) {
val funcname = expressions[1] as IdentifierReference
val first = FunctionCallExpression(funcname.copy(), mutableListOf(firstValue), firstValue.position)
val newExprs = mutableListOf<Expression>(first)
newExprs.addAll(expressions.drop(2))
return listOf(IAstModification.ReplaceNode(pipe, Pipe(newExprs, pipe.position), parent)) return listOf(IAstModification.ReplaceNode(pipe, Pipe(newExprs, pipe.position), parent))
} }
val singleExpr = pipe.expressions.singleOrNull() val singleExpr = expressions.singleOrNull()
if(singleExpr!=null) { if(singleExpr!=null) {
val callExpr = singleExpr as FunctionCallExpression val callExpr = singleExpr as FunctionCallExpression
val call = FunctionCallStatement(callExpr.target, callExpr.args, true, callExpr.position) val call = FunctionCallStatement(callExpr.target, callExpr.args, true, callExpr.position)

View File

@@ -1224,17 +1224,44 @@ internal class AstChecker(private val program: Program,
super.visit(containment) super.visit(containment)
} }
override fun visit(pipe: PipeExpression) {
processPipe(pipe.expressions, pipe)
val last = pipe.expressions.last() as IdentifierReference
val target = last.targetStatement(program)!!
when(target) {
is BuiltinFunctionPlaceholder -> {
if(BuiltinFunctions.getValue(target.name).known_returntype==null)
errors.err("invalid pipe expression; last term doesn't return a value", last.position)
}
is Subroutine -> {
if(target.returntypes.isEmpty())
errors.err("invalid pipe expression; last term doesn't return a value", last.position)
else if(target.returntypes.size!=1)
errors.err("invalid pipe expression; last term doesn't return a single value", last.position)
}
else -> errors.err("invalid pipe expression; last term doesn't return a value", last.position)
}
super.visit(pipe)
}
override fun visit(pipe: Pipe) { override fun visit(pipe: Pipe) {
processPipe(pipe.expressions, pipe)
super.visit(pipe)
}
private fun processPipe(expressions: List<Expression>, scope: Node) {
// first expression is just any expression producing a value // first expression is just any expression producing a value
// all other expressions should be the name of a unary function that returns a single value // all other expressions should be the name of a unary function that returns a single value
// the last expression should be the name of a unary function whose return value we don't care about. // the last expression should be the name of a unary function whose return value we don't care about.
if (pipe.expressions.size < 2) { if (expressions.size < 2) {
errors.err("pipe is missing one or more expressions", pipe.position) errors.err("pipe is missing one or more expressions", scope.position)
} else { } else {
// invalid size and other issues will be handled by the ast checker later. // invalid size and other issues will be handled by the ast checker later.
var valueDt = pipe.expressions[0].inferType(program).getOrElse { throw FatalAstException("invalid dt") } var valueDt = expressions[0].inferType(program).getOrElse {
throw FatalAstException("invalid dt ${expressions[0]} @ ${scope.position}")
}
for(expr in pipe.expressions.drop(1)) { // just keep the first expression value as-is for(expr in expressions.drop(1)) { // just keep the first expression value as-is
val functionName = expr as? IdentifierReference val functionName = expr as? IdentifierReference
val function = functionName?.targetStatement(program) val function = functionName?.targetStatement(program)
if(functionName!=null && function!=null) { if(functionName!=null && function!=null) {
@@ -1243,7 +1270,7 @@ internal class AstChecker(private val program: Program,
val func = BuiltinFunctions.getValue(function.name) val func = BuiltinFunctions.getValue(function.name)
if(func.parameters.size!=1) if(func.parameters.size!=1)
errors.err("can only use unary function", expr.position) errors.err("can only use unary function", expr.position)
else if(func.known_returntype==null && expr !== pipe.expressions.last()) else if(func.known_returntype==null && expr !== expressions.last())
errors.err("function must return a single value", expr.position) errors.err("function must return a single value", expr.position)
val paramDts = func.parameters.firstOrNull()?.possibleDatatypes val paramDts = func.parameters.firstOrNull()?.possibleDatatypes
@@ -1255,7 +1282,7 @@ internal class AstChecker(private val program: Program,
is Subroutine -> { is Subroutine -> {
if(function.parameters.size!=1) if(function.parameters.size!=1)
errors.err("can only use unary function", expr.position) errors.err("can only use unary function", expr.position)
else if(function.returntypes.size!=1 && expr !== pipe.expressions.last()) else if(function.returntypes.size!=1 && expr !== expressions.last())
errors.err("function must return a single value", expr.position) errors.err("function must return a single value", expr.position)
val paramDt = function.parameters.firstOrNull()?.type val paramDt = function.parameters.firstOrNull()?.type
@@ -1277,11 +1304,8 @@ internal class AstChecker(private val program: Program,
} }
} }
} }
return super.visit(pipe)
} }
private fun checkFunctionOrLabelExists(target: IdentifierReference, statement: Statement): Statement? { private fun checkFunctionOrLabelExists(target: IdentifierReference, statement: Statement): Statement? {
when (val targetStatement = target.targetStatement(program)) { when (val targetStatement = target.targetStatement(program)) {
is Label, is Subroutine, is BuiltinFunctionPlaceholder -> return targetStatement is Label, is Subroutine, is BuiltinFunctionPlaceholder -> return targetStatement

View File

@@ -33,11 +33,9 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
if(valueDt isNotAssignableTo decl.datatype) if(valueDt isNotAssignableTo decl.datatype)
return noModifications return noModifications
return listOf(IAstModification.ReplaceNode( val modifications = mutableListOf<IAstModification>()
declValue, addTypecastOrCastedValueModification(modifications, declValue, decl.datatype, decl)
TypecastExpression(declValue, decl.datatype, true, declValue.position), return modifications
decl
))
} }
} }
return noModifications return noModifications
@@ -71,13 +69,13 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
// determine common datatype and add typecast as required to make left and right equal types // determine common datatype and add typecast as required to make left and right equal types
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.getOr(DataType.UNDEFINED), rightDt.getOr(DataType.UNDEFINED), expr.left, expr.operator, expr.right) val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.getOr(DataType.UNDEFINED), rightDt.getOr(DataType.UNDEFINED), expr.left, expr.operator, expr.right)
if(toFix!=null) { if(toFix!=null) {
return when { val modifications = mutableListOf<IAstModification>()
toFix===expr.left -> listOf(IAstModification.ReplaceNode( when {
expr.left, TypecastExpression(expr.left, commonDt, true, expr.left.position), expr)) toFix===expr.left -> addTypecastOrCastedValueModification(modifications, expr.left, commonDt, expr)
toFix===expr.right -> listOf(IAstModification.ReplaceNode( toFix===expr.right -> addTypecastOrCastedValueModification(modifications, expr.right, commonDt, expr)
expr.right, TypecastExpression(expr.right, commonDt, true, expr.right.position), expr))
else -> throw FatalAstException("confused binary expression side") else -> throw FatalAstException("confused binary expression side")
} }
return modifications
} }
} }
return noModifications return noModifications
@@ -95,10 +93,9 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
if(valuetype in IterableDatatypes && targettype==DataType.UWORD) if(valuetype in IterableDatatypes && targettype==DataType.UWORD)
// special case, don't typecast STR/arrays to UWORD, we support those assignments "directly" // special case, don't typecast STR/arrays to UWORD, we support those assignments "directly"
return noModifications return noModifications
return listOf(IAstModification.ReplaceNode( val modifications = mutableListOf<IAstModification>()
assignment.value, addTypecastOrCastedValueModification(modifications, assignment.value, targettype, assignment)
TypecastExpression(assignment.value, targettype, true, assignment.value.position), return modifications
assignment))
} else { } else {
fun castLiteral(cvalue2: NumericLiteralValue): List<IAstModification.ReplaceNode> { fun castLiteral(cvalue2: NumericLiteralValue): List<IAstModification.ReplaceNode> {
val cast = cvalue2.cast(targettype) val cast = cvalue2.cast(targettype)
@@ -154,17 +151,14 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
if (argtype isAssignableTo requiredType) { if (argtype isAssignableTo requiredType) {
// don't need a cast for pass-by-reference types that are assigned to UWORD // don't need a cast for pass-by-reference types that are assigned to UWORD
if(requiredType!=DataType.UWORD || argtype !in PassByReferenceDatatypes) if(requiredType!=DataType.UWORD || argtype !in PassByReferenceDatatypes)
modifications += IAstModification.ReplaceNode( addTypecastOrCastedValueModification(modifications, pair.second, requiredType, call as Node)
call.args[index],
TypecastExpression(pair.second, requiredType, true, pair.second.position),
call as Node)
} else if(requiredType == DataType.UWORD && argtype in PassByReferenceDatatypes) { } else if(requiredType == DataType.UWORD && argtype in PassByReferenceDatatypes) {
// We allow STR/ARRAY values in place of UWORD parameters. // We allow STR/ARRAY values in place of UWORD parameters.
// Take their address instead, UNLESS it's a str parameter in the containing subroutine // Take their address instead, UNLESS it's a str parameter in the containing subroutine
val identifier = pair.second as? IdentifierReference val identifier = pair.second as? IdentifierReference
if(identifier?.isSubroutineParameter(program)==false) { if(identifier?.isSubroutineParameter(program)==false) {
modifications += IAstModification.ReplaceNode( modifications += IAstModification.ReplaceNode(
call.args[index], identifier,
AddressOf(identifier, pair.second.position), AddressOf(identifier, pair.second.position),
call as Node) call as Node)
} }
@@ -172,7 +166,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
val cast = (pair.second as NumericLiteralValue).cast(requiredType) val cast = (pair.second as NumericLiteralValue).cast(requiredType)
if(cast.isValid) if(cast.isValid)
modifications += IAstModification.ReplaceNode( modifications += IAstModification.ReplaceNode(
call.args[index], pair.second,
cast.valueOrZero(), cast.valueOrZero(),
call as Node) call as Node)
} }
@@ -189,10 +183,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
if (pair.first.possibleDatatypes.all { argtype != it }) { if (pair.first.possibleDatatypes.all { argtype != it }) {
for (possibleType in pair.first.possibleDatatypes) { for (possibleType in pair.first.possibleDatatypes) {
if (argtype isAssignableTo possibleType) { if (argtype isAssignableTo possibleType) {
modifications += IAstModification.ReplaceNode( addTypecastOrCastedValueModification(modifications, pair.second, possibleType, call as Node)
call.args[index],
TypecastExpression(pair.second, possibleType, true, pair.second.position),
call as Node)
break break
} }
else if(DataType.UWORD in pair.first.possibleDatatypes && argtype in PassByReferenceDatatypes) { else if(DataType.UWORD in pair.first.possibleDatatypes && argtype in PassByReferenceDatatypes) {
@@ -226,29 +217,36 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
else else
errors.err("integer implicitly converted to float but floating point is not enabled via options", typecast.position) errors.err("integer implicitly converted to float but floating point is not enabled via options", typecast.position)
} }
return noModifications return noModifications
} }
override fun after(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> { override fun after(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> {
// make sure the memory address is an uword // make sure the memory address is an uword
val modifications = mutableListOf<IAstModification>()
val dt = memread.addressExpression.inferType(program) val dt = memread.addressExpression.inferType(program)
if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) { if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) {
val typecast = (memread.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)?.valueOrZero() val castedValue = (memread.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)?.valueOrZero()
?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position) if(castedValue!=null)
return listOf(IAstModification.ReplaceNode(memread.addressExpression, typecast, memread)) modifications += IAstModification.ReplaceNode(memread.addressExpression, castedValue, memread)
else
addTypecastOrCastedValueModification(modifications, memread.addressExpression, DataType.UWORD, memread)
} }
return noModifications return modifications
} }
override fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> { override fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> {
// make sure the memory address is an uword // make sure the memory address is an uword
val modifications = mutableListOf<IAstModification>()
val dt = memwrite.addressExpression.inferType(program) val dt = memwrite.addressExpression.inferType(program)
if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) { if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) {
val typecast = (memwrite.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)?.valueOrZero() val castedValue = (memwrite.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)?.valueOrZero()
?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position) if(castedValue!=null)
return listOf(IAstModification.ReplaceNode(memwrite.addressExpression, typecast, memwrite)) modifications += IAstModification.ReplaceNode(memwrite.addressExpression, castedValue, memwrite)
else
addTypecastOrCastedValueModification(modifications, memwrite.addressExpression, DataType.UWORD, memwrite)
} }
return noModifications return modifications
} }
override fun after(returnStmt: Return, parent: Node): Iterable<IAstModification> { override fun after(returnStmt: Return, parent: Node): Iterable<IAstModification> {
@@ -265,13 +263,32 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
if(cast.isValid) if(cast.isValid)
returnStmt.value = cast.valueOrZero() returnStmt.value = cast.valueOrZero()
} else { } else {
return listOf(IAstModification.ReplaceNode( val modifications = mutableListOf<IAstModification>()
returnValue, addTypecastOrCastedValueModification(modifications, returnValue, subReturnType, returnStmt)
TypecastExpression(returnValue, subReturnType, true, returnValue.position), return modifications
returnStmt))
} }
} }
} }
return noModifications return noModifications
} }
private fun addTypecastOrCastedValueModification(
modifications: MutableList<IAstModification>,
expressionToCast: Expression,
requiredType: DataType,
parent: Node
) {
val sourceDt = expressionToCast.inferType(program).getOr(DataType.UNDEFINED)
if(sourceDt == requiredType)
return
if(expressionToCast is NumericLiteralValue && expressionToCast.type!=DataType.FLOAT) { // refuse to automatically truncate floats
val castedValue = expressionToCast.cast(requiredType)
if (castedValue.isValid) {
modifications += IAstModification.ReplaceNode(expressionToCast, castedValue.valueOrZero(), parent)
return
}
}
val cast = TypecastExpression(expressionToCast, requiredType, true, expressionToCast.position)
modifications += IAstModification.ReplaceNode(expressionToCast, cast, parent)
}
} }

View File

@@ -36,13 +36,13 @@ class TestNumericLiteralValue: FunSpec({
test("test rounding") { test("test rounding") {
shouldThrow<ExpressionError> { shouldThrow<ExpressionError> {
NumericLiteralValue(DataType.BYTE, -2.345, dummyPos) NumericLiteralValue(DataType.BYTE, -2.345, dummyPos)
}.message shouldContain "refused silent rounding" }.message shouldContain "refused rounding"
shouldThrow<ExpressionError> { shouldThrow<ExpressionError> {
NumericLiteralValue(DataType.BYTE, -2.6, dummyPos) NumericLiteralValue(DataType.BYTE, -2.6, dummyPos)
}.message shouldContain "refused silent rounding" }.message shouldContain "refused rounding"
shouldThrow<ExpressionError> { shouldThrow<ExpressionError> {
NumericLiteralValue(DataType.UWORD, 2222.345, dummyPos) NumericLiteralValue(DataType.UWORD, 2222.345, dummyPos)
}.message shouldContain "refused silent rounding" }.message shouldContain "refused rounding"
NumericLiteralValue(DataType.UBYTE, 2.0, dummyPos).number shouldBe 2.0 NumericLiteralValue(DataType.UBYTE, 2.0, dummyPos).number shouldBe 2.0
NumericLiteralValue(DataType.BYTE, -2.0, dummyPos).number shouldBe -2.0 NumericLiteralValue(DataType.BYTE, -2.0, dummyPos).number shouldBe -2.0
NumericLiteralValue(DataType.UWORD, 2222.0, dummyPos).number shouldBe 2222.0 NumericLiteralValue(DataType.UWORD, 2222.0, dummyPos).number shouldBe 2222.0

View File

@@ -590,9 +590,8 @@ class TestOptimization: FunSpec({
""" """
val errors = ErrorReporterForTests() val errors = ErrorReporterForTests()
compileText(C64Target, optimize=true, src, writeAssembly=false, errors = errors).assertFailure() compileText(C64Target, optimize=true, src, writeAssembly=false, errors = errors).assertFailure()
errors.errors.size shouldBe 2 errors.errors.size shouldBe 1
errors.errors[0] shouldContain "type of value BYTE doesn't match target UBYTE" errors.errors[0] shouldContain "type of value BYTE doesn't match target UBYTE"
errors.errors[1] shouldContain "value '-1' out of range for unsigned byte"
} }
test("test augmented expression asmgen") { test("test augmented expression asmgen") {

View File

@@ -6,6 +6,8 @@ import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.types.instanceOf import io.kotest.matchers.types.instanceOf
import prog8.ast.expressions.FunctionCallExpression import prog8.ast.expressions.FunctionCallExpression
import prog8.ast.expressions.IdentifierReference import prog8.ast.expressions.IdentifierReference
import prog8.ast.expressions.PipeExpression
import prog8.ast.statements.Assignment
import prog8.ast.statements.Pipe import prog8.ast.statements.Pipe
import prog8.codegen.target.C64Target import prog8.codegen.target.C64Target
import prog8tests.helpers.ErrorReporterForTests import prog8tests.helpers.ErrorReporterForTests
@@ -16,7 +18,7 @@ import prog8tests.helpers.compileText
class TestPipes: FunSpec({ class TestPipes: FunSpec({
test("correct pipes") { test("correct pipe statements") {
val text = """ val text = """
%import floats %import floats
%import textio %import textio
@@ -54,7 +56,7 @@ class TestPipes: FunSpec({
pipew.expressions[1] shouldBe instanceOf<IdentifierReference>() pipew.expressions[1] shouldBe instanceOf<IdentifierReference>()
} }
test("incorrect type in pipe") { test("incorrect type in pipe statement") {
val text = """ val text = """
%option enable_floats %option enable_floats
@@ -79,4 +81,66 @@ class TestPipes: FunSpec({
errors.errors[0] shouldContain "incompatible" errors.errors[0] shouldContain "incompatible"
} }
test("correct pipe expressions") {
val text = """
%import floats
%import textio
main {
sub start() {
float @shared fl = 1.234 |> addfloat
|> addfloat
uword @shared ww = 9999 |> addword
|> addword
}
sub addfloat(float fl) -> float {
return fl+2.22
}
sub addword(uword ww) -> uword {
return ww+2222
}
}
"""
val result = compileText(C64Target, true, text, writeAssembly = true).assertSuccess()
val stmts = result.program.entrypoint.statements
stmts.size shouldBe 5
val assignf = stmts[1] as Assignment
val pipef = assignf.value as PipeExpression
pipef.expressions.size shouldBe 2
pipef.expressions[0] shouldBe instanceOf<FunctionCallExpression>()
pipef.expressions[1] shouldBe instanceOf<IdentifierReference>()
val assignw = stmts[3] as Assignment
val pipew = assignw.value as PipeExpression
pipew.expressions.size shouldBe 2
pipew.expressions[0] shouldBe instanceOf<FunctionCallExpression>()
pipew.expressions[1] shouldBe instanceOf<IdentifierReference>()
}
test("incorrect type in pipe expression") {
val text = """
%option enable_floats
main {
sub start() {
uword result = 1.234 |> addfloat
|> addword |> addword
}
sub addfloat(float fl) -> float {
return fl+2.22
}
sub addword(uword ww) -> uword {
return ww+2222
}
}
"""
val errors = ErrorReporterForTests()
compileText(C64Target, false, text, errors=errors).assertFailure()
errors.errors.size shouldBe 1
errors.errors[0] shouldContain "incompatible"
}
}) })

View File

@@ -9,7 +9,6 @@ import io.kotest.matchers.types.instanceOf
import prog8.ast.base.DataType import prog8.ast.base.DataType
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.compiler.printProgram
import prog8.codegen.target.C64Target import prog8.codegen.target.C64Target
import prog8tests.helpers.ErrorReporterForTests import prog8tests.helpers.ErrorReporterForTests
import prog8tests.helpers.assertFailure import prog8tests.helpers.assertFailure

View File

@@ -3,10 +3,6 @@ package prog8tests
import io.kotest.core.spec.style.FunSpec import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.types.instanceOf
import prog8.ast.expressions.FunctionCallExpression
import prog8.ast.expressions.IdentifierReference
import prog8.ast.statements.Pipe
import prog8.codegen.target.C64Target import prog8.codegen.target.C64Target
import prog8tests.helpers.ErrorReporterForTests import prog8tests.helpers.ErrorReporterForTests
import prog8tests.helpers.assertFailure import prog8tests.helpers.assertFailure

View File

@@ -459,10 +459,18 @@ class AstToSourceTextConverter(val output: (text: String) -> Unit, val program:
} }
override fun visit(pipe: Pipe) { override fun visit(pipe: Pipe) {
pipe.expressions.first().accept(this) printPipe(pipe.expressions)
}
override fun visit(pipe: PipeExpression) {
printPipe(pipe.expressions)
}
private fun printPipe(expressions: Iterable<Expression>) {
expressions.first().accept(this)
outputln("") outputln("")
scopelevel++ scopelevel++
pipe.expressions.drop(1).forEach { expressions.drop(1).forEach {
outputi("|> ") outputi("|> ")
it.accept(this) it.accept(this)
outputln("") outputln("")

View File

@@ -9,6 +9,7 @@ import prog8.parser.Prog8ANTLRParser
import prog8.parser.SourceCode import prog8.parser.SourceCode
import java.nio.file.Path import java.nio.file.Path
import kotlin.io.path.isRegularFile import kotlin.io.path.isRegularFile
import kotlin.math.exp
/***************** Antlr Extension methods to create AST ****************/ /***************** Antlr Extension methods to create AST ****************/
@@ -481,6 +482,9 @@ private fun Prog8ANTLRParser.ExpressionContext.toAst() : Expression {
if(addressof()!=null) if(addressof()!=null)
return AddressOf(addressof().scoped_identifier().toAst(), toPosition()) return AddressOf(addressof().scoped_identifier().toAst(), toPosition())
if(pipe!=null)
return PipeExpression(pipesource.toAst(), pipetarget.toAst(), toPosition())
throw FatalAstException(text) throw FatalAstException(text)
} }
@@ -617,6 +621,7 @@ private fun Prog8ANTLRParser.VardeclContext.toAst(): VarDecl {
} }
private fun Prog8ANTLRParser.PipestmtContext.toAst(): Pipe { private fun Prog8ANTLRParser.PipestmtContext.toAst(): Pipe {
val expressions = expression().map { it.toAst() } val source = this.source.toAst()
return Pipe(expressions.toMutableList(), toPosition()) val target = this.target.toAst()
return Pipe(source, target, toPosition())
} }

View File

@@ -1076,6 +1076,52 @@ class ContainmentCheck(var element: Expression,
} }
} }
class PipeExpression(val expressions: MutableList<Expression>, override val position: Position): Expression() {
override lateinit var parent: Node
constructor(source: Expression, target: Expression, position: Position) : this(mutableListOf(), position) {
if(source is PipeExpression) {
expressions.addAll(source.expressions)
expressions.add(target)
} else {
expressions.add(source)
expressions.add(target)
}
}
override val isSimple = false
override fun linkParents(parent: Node) {
this.parent=parent
expressions.forEach { it.linkParents(this) }
}
override fun copy(): PipeExpression = PipeExpression(expressions.map {it.copy()}.toMutableList(), position)
override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
override fun referencesIdentifier(nameInSource: List<String>) =
expressions.any { it.referencesIdentifier(nameInSource) }
override fun inferType(program: Program): InferredTypes.InferredType {
val last = expressions.last()
val type = last.inferType(program)
if(type.isKnown)
return type
val identifier = last as? IdentifierReference
if(identifier!=null) {
val call = FunctionCallExpression(identifier, mutableListOf(), identifier.position)
return call.inferType(program)
}
return InferredTypes.InferredType.unknown()
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(node is Expression)
require(replacement is Expression)
val idx = expressions.indexOf(node)
expressions[idx] = replacement
}
}
fun invertCondition(cond: Expression): BinaryExpression? { fun invertCondition(cond: Expression): BinaryExpression? {
if(cond is BinaryExpression) { if(cond is BinaryExpression) {

View File

@@ -1019,6 +1019,18 @@ class DirectMemoryWrite(var addressExpression: Expression, override val position
class Pipe(val expressions: MutableList<Expression>, override val position: Position): Statement() { class Pipe(val expressions: MutableList<Expression>, override val position: Position): Statement() {
override lateinit var parent: Node override lateinit var parent: Node
constructor(source: Expression, target: Expression, position: Position) : this(mutableListOf(), position) {
if(source is PipeExpression)
expressions.addAll(source.expressions)
else
expressions.add(source)
if(target is PipeExpression)
expressions.addAll(target.expressions)
else
expressions.add(target)
}
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
this.parent = parent this.parent = parent
expressions.forEach { it.linkParents(this) } expressions.forEach { it.linkParents(this) }
@@ -1034,7 +1046,4 @@ class Pipe(val expressions: MutableList<Expression>, override val position: Posi
val idx = expressions.indexOf(node) val idx = expressions.indexOf(node)
expressions[idx] = replacement expressions[idx] = replacement
} }
fun valueDatatype(program: Program): DataType =
expressions.first().inferType(program).getOrElse { throw FatalAstException("invalid dt") }
} }

View File

@@ -121,6 +121,7 @@ abstract class AstWalker {
open fun before(whenStmt: When, parent: Node): Iterable<IAstModification> = noModifications open fun before(whenStmt: When, parent: Node): Iterable<IAstModification> = noModifications
open fun before(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> = noModifications open fun before(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> = noModifications
open fun before(pipe: Pipe, parent: Node): Iterable<IAstModification> = noModifications open fun before(pipe: Pipe, parent: Node): Iterable<IAstModification> = noModifications
open fun before(pipeExpr: PipeExpression, parent: Node): Iterable<IAstModification> = noModifications
open fun after(addressOf: AddressOf, parent: Node): Iterable<IAstModification> = noModifications open fun after(addressOf: AddressOf, parent: Node): Iterable<IAstModification> = noModifications
open fun after(array: ArrayLiteralValue, parent: Node): Iterable<IAstModification> = noModifications open fun after(array: ArrayLiteralValue, parent: Node): Iterable<IAstModification> = noModifications
@@ -165,6 +166,7 @@ abstract class AstWalker {
open fun after(whenStmt: When, parent: Node): Iterable<IAstModification> = noModifications open fun after(whenStmt: When, parent: Node): Iterable<IAstModification> = noModifications
open fun after(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> = noModifications open fun after(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> = noModifications
open fun after(pipe: Pipe, parent: Node): Iterable<IAstModification> = noModifications open fun after(pipe: Pipe, parent: Node): Iterable<IAstModification> = noModifications
open fun after(pipeExpr: PipeExpression, parent: Node): Iterable<IAstModification> = noModifications
protected val modifications = mutableListOf<Triple<IAstModification, Node, Node>>() protected val modifications = mutableListOf<Triple<IAstModification, Node, Node>>()
@@ -465,5 +467,11 @@ abstract class AstWalker {
pipe.expressions.forEach { it.accept(this, pipe) } pipe.expressions.forEach { it.accept(this, pipe) }
track(after(pipe, parent), pipe, parent) track(after(pipe, parent), pipe, parent)
} }
fun visit(pipe: PipeExpression, parent: Node) {
track(before(pipe, parent), pipe, parent)
pipe.expressions.forEach { it.accept(this, pipe) }
track(after(pipe, parent), pipe, parent)
}
} }

View File

@@ -185,4 +185,8 @@ interface IAstVisitor {
fun visit(pipe: Pipe) { fun visit(pipe: Pipe) {
pipe.expressions.forEach { it.accept(this) } pipe.expressions.forEach { it.accept(this) }
} }
fun visit(pipe: PipeExpression) {
pipe.expressions.forEach { it.accept(this) }
}
} }

View File

@@ -120,8 +120,18 @@ class CallGraph(private val program: Program) : IAstVisitor {
allAssemblyNodes.add(inlineAssembly) allAssemblyNodes.add(inlineAssembly)
} }
override fun visit(pipe: PipeExpression) {
processPipe(pipe.expressions, pipe)
super.visit(pipe)
}
override fun visit(pipe: Pipe) { override fun visit(pipe: Pipe) {
pipe.expressions.forEach { processPipe(pipe.expressions, pipe)
super.visit(pipe)
}
private fun processPipe(expressions: Iterable<Expression>, pipe: Node) {
expressions.forEach {
if(it is IdentifierReference){ if(it is IdentifierReference){
val otherSub = it.targetSubroutine(program) val otherSub = it.targetSubroutine(program)
if(otherSub!=null) { if(otherSub!=null) {
@@ -132,7 +142,6 @@ class CallGraph(private val program: Program) : IAstVisitor {
} }
} }
} }
super.visit(pipe)
} }
fun checkRecursiveCalls(errors: IErrorReporter) { fun checkRecursiveCalls(errors: IErrorReporter) {

View File

@@ -533,12 +533,11 @@ pipe: ``|>``
|> add_bonus |> add_bonus
|> txt.print_uw |> txt.print_uw
or even:: It also works for expressions that return a value, for example ``uword score = add_bonus(determine_score(get_player(1)))`` ::
1 |> get_player uword score = get_player(1)
|> determine_score |> determine_score
|> add_bonus |> add_bonus
|> txt.print_uw
address of: ``&`` address of: ``&``
This is a prefix operator that can be applied to a string or array variable or literal value. This is a prefix operator that can be applied to a string or array variable or literal value.

View File

@@ -3,7 +3,7 @@ TODO
For next compiler release (7.7) For next compiler release (7.7)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- make pipe statement an expression so that it DOES return a result value (possibly) so you can assign it ?? - make pipe statement also an expression so that it can return a result value
- optimize codegen of pipe operator to avoid needless assigns to temp var - optimize codegen of pipe operator to avoid needless assigns to temp var
- copying floats around: do it with a subroutine rather than 5 lda/sta pairs . - copying floats around: do it with a subroutine rather than 5 lda/sta pairs .
is slower but floats are very slow already anyway and this should take a lot less program size. is slower but floats are very slow already anyway and this should take a lot less program size.
@@ -25,7 +25,7 @@ Blocked by an official Commander-x16 r39 release
Future Things and Ideas Future Things and Ideas
^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^
- can we promise a left-to-right function call argument evaluation? without sacrificing performance - can we promise a left-to-right function call argument evaluation? without sacrificing performance
- for the pipe operator: recognise a placeholder (? or % or _) in a non-unary function call to allow things as 4 |> mkword(?, $44) |> print_uw - for the pipe operator: recognise a placeholder (``?`` or ``%`` or ``_``) in a non-unary function call to allow things as ``4 |> mkword(?, $44) |> print_uw``
- make it possible to use cpu opcodes such as 'nop' as variable names by prefixing all asm vars with something such as ``v_`` - make it possible to use cpu opcodes such as 'nop' as variable names by prefixing all asm vars with something such as ``v_``
then we can get rid of the instruction lists in the machinedefinitions as well? then we can get rid of the instruction lists in the machinedefinitions as well?
- fix the asm-labels problem (github issue #62) - fix the asm-labels problem (github issue #62)

View File

@@ -1,20 +1,49 @@
%import textio %import textio
%import floats %import floats
%import test_stack
%zeropage basicsafe %zeropage basicsafe
main { main {
sub start() { sub start() {
1.234 |> addfloat |> addfloat |> floats.print_f ; float @shared f1
;
; f1 = 1.234 |> addfloat1 |> addfloat2 |> addfloat3 ; TODO fix that the value is actually returned
; floats.print_f(f1)
; txt.nl()
; 1.234 |> addfloat1
; |> addfloat2 |> addfloat3 |> floats.print_f
; txt.nl()
;
; 9+3 |> assemblything
; |> sin8u
; |> add_one
; |> times_two
; |> txt.print_uw
; txt.nl()
test_stack.test()
; TODO fix that the value is actually returned (398) and that X register is preserved:
uword @shared uw = 9+3 |> assemblything
|> sin8u
|> add_one
|> times_two
txt.print_uw(uw)
txt.nl() txt.nl()
9 * 3 |> assemblything test_stack.test()
|> sin8u
|> add_one |> times_two
|> txt.print_uw
} }
sub addfloat(float fl) -> float { sub addfloat1(float fl) -> float {
return fl+1.11 return fl+1.11
} }
sub addfloat2(float fl) -> float {
return fl+2.22
}
sub addfloat3(float fl) -> float {
return fl+3.33
}
sub add_one(ubyte input) -> ubyte { sub add_one(ubyte input) -> ubyte {
return input+1 return input+1
} }

View File

@@ -101,8 +101,8 @@ statement :
| repeatloop | repeatloop
| whenstmt | whenstmt
| breakstmt | breakstmt
| pipestmt
| labeldef | labeldef
| pipestmt
; ;
@@ -160,7 +160,8 @@ assign_target:
postincrdecr : assign_target operator = ('++' | '--') ; postincrdecr : assign_target operator = ('++' | '--') ;
expression : expression :
functioncall '(' expression ')'
| functioncall
| <assoc=right> prefix = ('+'|'-'|'~') expression | <assoc=right> prefix = ('+'|'-'|'~') expression
| left = expression EOL? bop = '**' EOL? right = expression | left = expression EOL? bop = '**' EOL? right = expression
| left = expression EOL? bop = ('*' | '/' | '%' ) EOL? right = expression | left = expression EOL? bop = ('*' | '/' | '%' ) EOL? right = expression
@@ -183,13 +184,12 @@ expression :
| directmemory | directmemory
| addressof | addressof
| expression typecast | expression typecast
| '(' expression ')' | pipesource = expression EOL? pipe=PIPE EOL? pipetarget = expression
; ;
typecast : 'as' datatype; typecast : 'as' datatype;
arrayindexed : scoped_identifier arrayindex ; arrayindexed : scoped_identifier arrayindex ;
directmemory : '@' '(' expression ')'; directmemory : '@' '(' expression ')';
@@ -209,6 +209,8 @@ returnstmt : 'return' expression? ;
breakstmt : 'break'; breakstmt : 'break';
pipestmt: source=expression pipe=PIPE EOL? target=expression ;
identifier : NAME ; identifier : NAME ;
scoped_identifier : NAME ('.' NAME)* ; scoped_identifier : NAME ('.' NAME)* ;
@@ -300,5 +302,3 @@ repeatloop: 'repeat' expression? EOL? (statement | statement_block) ;
whenstmt: 'when' expression '{' EOL (when_choice | EOL) * '}' EOL? ; whenstmt: 'when' expression '{' EOL (when_choice | EOL) * '}' EOL? ;
when_choice: (expression_list | 'else' ) '->' (statement | statement_block ) ; when_choice: (expression_list | 'else' ) '->' (statement | statement_block ) ;
pipestmt: expression EOL? PIPE expression ( EOL? PIPE expression)* ;