fixed bit shifts, added sgn() function

This commit is contained in:
Irmen de Jong 2019-08-17 15:28:06 +02:00
parent 3ae2597261
commit 27f987f0ae
9 changed files with 250 additions and 58 deletions

View File

@ -954,6 +954,16 @@ func_sum_f .proc
bne -
+ jmp push_fac1_as_result
.pend
sign_f .proc
jsr pop_float_fac1
jsr SIGN
sta c64.ESTACK_LO,x
dex
rts
.pend
}}
} ; ------ end of block c64flt

View File

@ -643,3 +643,40 @@ mul_word_40 .proc
sta c64.ESTACK_HI+1,x
rts
.pend
sign_b .proc
lda c64.ESTACK_LO+1,x
beq _sign_zero
bmi _sign_neg
_sign_pos lda #1
sta c64.ESTACK_LO+1,x
rts
_sign_neg lda #-1
_sign_zero sta c64.ESTACK_LO+1,x
rts
.pend
sign_ub .proc
lda c64.ESTACK_LO+1,x
beq sign_b._sign_zero
bne sign_b._sign_pos
.pend
sign_w .proc
lda c64.ESTACK_HI+1,x
bmi sign_b._sign_neg
beq sign_ub
bne sign_b._sign_pos
.pend
sign_uw .proc
lda c64.ESTACK_HI+1,x
beq _sign_possibly_zero
_sign_pos lda #1
sta c64.ESTACK_LO+1,x
rts
_sign_possibly_zero lda c64.ESTACK_LO+1,x
bne _sign_pos
sta c64.ESTACK_LO+1,x
rts
.pend

View File

@ -101,6 +101,18 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, private val
else -> throw AssemblyError("weird type $dt")
}
}
"sgn" -> {
translateFunctionArguments(fcall.arglist, func)
val dt = fcall.arglist.single().inferType(program)
when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> asmgen.out(" jsr math.sign_ub")
DataType.BYTE -> asmgen.out(" jsr math.sign_b")
DataType.UWORD -> asmgen.out(" jsr math.sign_uw")
DataType.WORD -> asmgen.out(" jsr math.sign_w")
DataType.FLOAT -> asmgen.out(" jsr c64flt.sign_f")
else -> throw AssemblyError("weird type $dt")
}
}
"sin", "cos", "tan", "atan",
"ln", "log2", "sqrt", "rad",
"deg", "round", "floor", "ceil",
@ -200,7 +212,10 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, private val
is ArrayIndexedExpression -> TODO("lsr sbyte $what")
is DirectMemoryRead -> TODO("lsr sbyte $what")
is RegisterExpr -> TODO("lsr sbyte $what")
is IdentifierReference -> TODO("lsr sbyte $what")
is IdentifierReference -> {
val variable = asmgen.asmIdentifierName(what)
asmgen.out(" lda $variable | asl a | ror $variable")
}
else -> throw AssemblyError("weird type")
}
}
@ -217,7 +232,10 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, private val
DataType.WORD -> {
when (what) {
is ArrayIndexedExpression -> TODO("lsr sword $what")
is IdentifierReference -> TODO("lsr sword $what")
is IdentifierReference -> {
val variable = asmgen.asmIdentifierName(what)
asmgen.out(" lda $variable+1 | asl a | ror $variable+1 | ror $variable")
}
else -> throw AssemblyError("weird type")
}
}

View File

@ -34,6 +34,7 @@ val BuiltinFunctions = mapOf(
"abs" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", NumericDatatypes)), null, ::builtinAbs), // type depends on argument
"len" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", IterableDatatypes)), null, ::builtinLen), // type is UBYTE or UWORD depending on actual length
// normal functions follow:
"sgn" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", NumericDatatypes)), DataType.BYTE, ::builtinSgn ),
"sin" to FunctionSignature(true, listOf(BuiltinFunctionParam("rads", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::sin) },
"sin8" to FunctionSignature(true, listOf(BuiltinFunctionParam("angle8", setOf(DataType.UBYTE))), DataType.BYTE, ::builtinSin8 ),
"sin8u" to FunctionSignature(true, listOf(BuiltinFunctionParam("angle8", setOf(DataType.UBYTE))), DataType.UBYTE, ::builtinSin8u ),
@ -184,7 +185,7 @@ fun builtinFunctionReturnType(function: String, args: List<Expression>, program:
}
class NotConstArgumentException: AstException("not a const argument to a built-in function")
class NotConstArgumentException: AstException("not a const argument to a built-in function") // TODO: ugly, remove throwing exceptions for control flow
private fun oneDoubleArg(args: List<Expression>, position: Position, program: Program, function: (arg: Double)->Number): NumericLiteralValue {
@ -359,6 +360,13 @@ private fun builtinCos16u(args: List<Expression>, position: Position, program: P
return NumericLiteralValue(DataType.UWORD, (32768.0 + 32767.5 * cos(rad)).toInt(), position)
}
private fun builtinSgn(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1)
throw SyntaxError("sgn requires one argument", position)
val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
return NumericLiteralValue(DataType.BYTE, constval.number.toDouble().sign.toShort(), position)
}
private fun numericLiteral(value: Number, position: Position): NumericLiteralValue {
val floatNum=value.toDouble()
val tweakedValue: Number =

View File

@ -8,6 +8,7 @@ import prog8.ast.statements.Assignment
import prog8.ast.statements.Statement
import kotlin.math.abs
import kotlin.math.log2
import kotlin.math.pow
/*
todo advanced expression optimization: common (sub) expression elimination (turn common expressions into single subroutine call + introduce variable to hold it)
@ -216,12 +217,6 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
}
}
// WORD >> 8 --> msb(WORD)
if(expr.operator == ">>" && leftDt in WordDatatypes && rightVal?.number == 8) {
optimizationsDone++
return FunctionCall(IdentifierReference(listOf("msb"), expr.position), mutableListOf(expr.left), expr.position)
}
if (expr.operator == "+" || expr.operator == "-"
&& leftVal == null && rightVal == null
&& leftDt in NumericDatatypes && rightDt in NumericDatatypes) {
@ -346,6 +341,8 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
"-" -> return optimizeSub(expr, leftVal, rightVal)
"**" -> return optimizePower(expr, leftVal, rightVal)
"%" -> return optimizeRemainder(expr, leftVal, rightVal)
">>" -> return optimizeShiftRight(expr, rightVal)
"<<" -> return optimizeShiftLeft(expr, rightVal)
}
return expr
}
@ -611,12 +608,16 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
}
private val powersOfTwo = (1 .. 16).map { (2.0).pow(it) }
private val negativePowersOfTwo = powersOfTwo.map { -it }
private fun optimizeDivision(expr: BinaryExpression, leftVal: NumericLiteralValue?, rightVal: NumericLiteralValue?): Expression {
if(leftVal==null && rightVal==null)
return expr
// cannot shuffle assiciativity with division!
// TODO fix bug in this routine!
// cannot shuffle assiciativity with division!
if(rightVal!=null) {
// right value is a constant, see if we can optimize
val rightConst: NumericLiteralValue = rightVal
@ -640,7 +641,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return expr.left
}
}
2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0, 8192.0, 16384.0, 32768.0, 65536.0 -> {
in powersOfTwo -> {
if(leftDt in IntegerDatatypes) {
// divided by a power of two => shift right
optimizationsDone++
@ -648,7 +649,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return BinaryExpression(expr.left, ">>", NumericLiteralValue.optimalInteger(numshifts, expr.position), expr.position)
}
}
-2.0, -4.0, -8.0, -16.0, -32.0, -64.0, -128.0, -256.0, -512.0, -1024.0, -2048.0, -4096.0, -8192.0, -16384.0, -32768.0, -65536.0 -> {
in negativePowersOfTwo -> {
if(leftDt in IntegerDatatypes) {
// divided by a negative power of two => negate, then shift right
optimizationsDone++
@ -733,4 +734,97 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return expr
}
private fun optimizeShiftLeft(expr: BinaryExpression, amountLv: NumericLiteralValue?): Expression {
if(amountLv==null)
return expr
val amount=amountLv.number.toInt()
if(amount==0) {
optimizationsDone++
return expr.left
}
val targetDt = expr.left.inferType(program).typeOrElse(DataType.STRUCT)
when(targetDt) {
DataType.UBYTE, DataType.BYTE -> {
if(amount>=8) {
optimizationsDone++
return NumericLiteralValue.optimalInteger(0, expr.position)
}
}
DataType.UWORD, DataType.WORD -> {
if(amount>=16) {
optimizationsDone++
return NumericLiteralValue.optimalInteger(0, expr.position)
}
else if(amount>=8) {
optimizationsDone++
val lsb=TypecastExpression(expr.left, DataType.UBYTE, true, expr.position)
if(amount==8) {
return FunctionCall(IdentifierReference(listOf("mkword"), expr.position), mutableListOf(NumericLiteralValue.optimalInteger(0, expr.position), lsb), expr.position)
}
val shifted = BinaryExpression(lsb, "<<", NumericLiteralValue.optimalInteger(amount-8, expr.position), expr.position)
return FunctionCall(IdentifierReference(listOf("mkword"), expr.position), mutableListOf(NumericLiteralValue.optimalInteger(0, expr.position), shifted), expr.position)
}
}
else -> {}
}
return expr
}
private fun optimizeShiftRight(expr: BinaryExpression, amountLv: NumericLiteralValue?): Expression {
if(amountLv==null)
return expr
val amount=amountLv.number.toInt()
if(amount==0) {
optimizationsDone++
return expr.left
}
val targetDt = expr.left.inferType(program).typeOrElse(DataType.STRUCT)
when(targetDt) {
DataType.UBYTE -> {
if(amount>=8) {
optimizationsDone++
return NumericLiteralValue.optimalInteger(0, expr.position)
}
}
DataType.BYTE -> {
if(amount>8) {
expr.right = NumericLiteralValue.optimalInteger(8, expr.right.position)
return expr
}
}
DataType.UWORD -> {
if(amount>=16) {
optimizationsDone++
return NumericLiteralValue.optimalInteger(0, expr.position)
}
else if(amount>=8) {
optimizationsDone++
val msb=FunctionCall(IdentifierReference(listOf("msb"), expr.position), mutableListOf(expr.left), expr.position)
if(amount==8)
return msb
return BinaryExpression(msb, ">>", NumericLiteralValue.optimalInteger(amount-8, expr.position), expr.position)
}
}
DataType.WORD -> {
if(amount>16) {
expr.right = NumericLiteralValue.optimalInteger(16, expr.right.position)
return expr
} else if(amount>=8) {
optimizationsDone++
val msbAsByte = TypecastExpression(
FunctionCall(IdentifierReference(listOf("msb"), expr.position), mutableListOf(expr.left), expr.position),
DataType.BYTE,
true, expr.position)
if(amount==8)
return msbAsByte
return BinaryExpression(msbAsByte, ">>", NumericLiteralValue.optimalInteger(amount-8, expr.position), expr.position)
}
}
else -> {}
}
return expr
}
}

View File

@ -516,8 +516,7 @@ internal class StatementOptimizer(private val program: Program) : IAstModifyingV
optimizationsDone++
return NopStatement.insteadOf(assignment)
}
if (((targetDt == DataType.UWORD || targetDt == DataType.WORD) && cv > 15.0) ||
((targetDt == DataType.UBYTE || targetDt == DataType.BYTE) && cv > 7.0)) {
if ((targetDt == DataType.UWORD && cv > 15.0) || (targetDt == DataType.UBYTE && cv > 7.0)) {
assignment.value = NumericLiteralValue.optimalInteger(0, assignment.value.position)
assignment.value.linkParents(assignment)
optimizationsDone++

View File

@ -727,6 +727,9 @@ lsb(x)
msb(x)
Get the most significant byte of the word x.
sgn(x)
Get the sign of the value. Result is -1, 0 or 1 (negative, zero, positive).
mkword(lsb, msb)
Efficiently create a word value from two bytes (the lsb and the msb). Avoids multiplication and shifting.

View File

@ -6,55 +6,78 @@
main {
sub start() {
byte ub = 100
byte ub2
word uw = 22222
word uw2
uword start=1027
uword stop=2020
uword i
ubyte ib
ub = -100
c64scr.print_b(ub >> 1)
c64.CHROUT('\n')
c64scr.print_b(ub >> 2)
c64.CHROUT('\n')
c64scr.print_b(ub >> 7)
c64.CHROUT('\n')
c64scr.print_b(ub >> 8)
c64.CHROUT('\n')
c64scr.print_b(ub >> 9)
c64.CHROUT('\n')
c64scr.print_b(ub >> 16)
c64.CHROUT('\n')
c64scr.print_b(ub >> 26)
c64.CHROUT('\n')
c64.CHROUT('\n')
c64scr.print("\n\n\n\n\n\n\n\n")
memset($0400, 40*25, 30)
ubyte ibstart = 1
for ib in ibstart to 255-ibstart {
@(ib+1024) = 44
}
for ib in 253 to 2 step -1 {
@(ib+1024) = 3
}
ibstart = 3
for ib in 255-ibstart to ibstart step -1 {
@(ib+1024) = 45
}
ub = 100
c64scr.print_b(ub >> 1)
c64.CHROUT('\n')
c64scr.print_b(ub >> 2)
c64.CHROUT('\n')
c64scr.print_b(ub >> 7)
c64.CHROUT('\n')
c64scr.print_b(ub >> 8)
c64.CHROUT('\n')
c64scr.print_b(ub >> 9)
c64.CHROUT('\n')
c64scr.print_b(ub >> 16)
c64.CHROUT('\n')
c64scr.print_b(ub >> 26)
c64.CHROUT('\n')
c64.CHROUT('\n')
for i in 1025 to 2022 {
@(i) = 1
}
uw = -22222
c64scr.print_w(uw >> 1)
c64.CHROUT('\n')
c64scr.print_w(uw >> 7)
c64.CHROUT('\n')
c64scr.print_w(uw >> 8)
c64.CHROUT('\n')
c64scr.print_w(uw >> 9)
c64.CHROUT('\n')
c64scr.print_w(uw >> 15)
c64.CHROUT('\n')
c64scr.print_w(uw >> 16)
c64.CHROUT('\n')
c64scr.print_w(uw >> 26)
c64.CHROUT('\n')
c64.CHROUT('\n')
for i in 2021 to 1026 step -1 {
@(i) = 92
}
for i in start to stop {
@(i) = 0
}
for i in stop-1 to start+1 step -1 {
@(i) = 91
}
ubyte xx=X
c64scr.print_ub(xx)
; for i in stop to start {
; c64scr.print_uw(i)
; c64.CHROUT(',')
; }
uw = 22222
c64scr.print_w(uw >> 1)
c64.CHROUT('\n')
c64scr.print_w(uw >> 7)
c64.CHROUT('\n')
c64scr.print_w(uw >> 8)
c64.CHROUT('\n')
c64scr.print_w(uw >> 9)
c64.CHROUT('\n')
c64scr.print_w(uw >> 15)
c64.CHROUT('\n')
c64scr.print_w(uw >> 16)
c64.CHROUT('\n')
c64scr.print_w(uw >> 26)
c64.CHROUT('\n')
}
}