fix problem with typechecking of const arrays

This commit is contained in:
Irmen de Jong 2019-08-17 21:43:48 +02:00
parent d4a17dfad1
commit cbb7083307
4 changed files with 62 additions and 48 deletions

View File

@ -1 +1,2 @@
1.52
1.53-dev

View File

@ -7,7 +7,6 @@ import prog8.ast.Program
import prog8.ast.base.*
import prog8.ast.expressions.*
import prog8.ast.statements.*
import prog8.compiler.HeapValues
import prog8.compiler.target.c64.AssemblyProgram
import prog8.functions.BuiltinFunctions
@ -255,7 +254,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
if(array is ArrayLiteralValue) {
val vardecl = array.parent as? VarDecl
return if (vardecl!=null) {
fixupArrayDatatype(array, vardecl, program.heap)
fixupArrayDatatype(array, vardecl, program)
} else if(array.heapId!=null) {
// fix the datatype of the array (also on the heap) to the 'biggest' datatype in the array
// (we don't know the desired datatype here exactly so we guess)
@ -408,14 +407,35 @@ internal fun fixupArrayDatatype(array: ArrayLiteralValue, program: Program): Arr
return array2
}
internal fun fixupArrayDatatype(array: ArrayLiteralValue, vardecl: VarDecl, heap: HeapValues): ArrayLiteralValue {
internal fun fixupArrayDatatype(array: ArrayLiteralValue, vardecl: VarDecl, program: Program): ArrayLiteralValue {
if(array.heapId!=null) {
val arrayDt = array.type
if(arrayDt!=vardecl.datatype) {
// fix the datatype of the array (also on the heap) to match the vardecl
val litval2 =
try {
array.cast(vardecl.datatype)!!
val result = array.cast(vardecl.datatype)
if(result==null) {
val constElements = array.value.mapNotNull { it.constValue(program) }
val elementDts = constElements.map { it.type }
if(DataType.FLOAT in elementDts) {
array.cast(DataType.ARRAY_F) ?: ArrayLiteralValue(DataType.ARRAY_F, array.value, array.heapId, array.position)
} else {
val numbers = constElements.map { it.number.toInt() }
val minValue = numbers.min()!!
val maxValue = numbers.max()!!
if (minValue >= 0) {
// only positive values, so uword or ubyte
val dt = if(maxValue<256) DataType.ARRAY_UB else DataType.ARRAY_UW
array.cast(dt) ?: ArrayLiteralValue(dt, array.value, array.heapId, array.position)
} else {
// negative value present, so word or byte
val dt = if(minValue >= -128 && maxValue<=127) DataType.ARRAY_B else DataType.ARRAY_W
array.cast(dt) ?: ArrayLiteralValue(dt, array.value, array.heapId, array.position)
}
}
}
else result
} catch(x: ExpressionError) {
// couldn't cast permanently.
// instead, simply adjust the array type and trust the AstChecker to report the exact error
@ -423,11 +443,11 @@ internal fun fixupArrayDatatype(array: ArrayLiteralValue, vardecl: VarDecl, heap
}
vardecl.value = litval2
litval2.linkParents(vardecl)
litval2.addToHeap(heap)
litval2.addToHeap(program.heap)
return litval2
}
} else {
array.addToHeap(heap)
array.addToHeap(program.heap)
}
return array
}

View File

@ -603,7 +603,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
if(array is ArrayLiteralValue) {
val vardecl = array.parent as? VarDecl
return if (vardecl!=null) {
fixupArrayDatatype(array, vardecl, program.heap)
fixupArrayDatatype(array, vardecl, program)
} else {
// it's not an array associated with a vardecl, attempt to guess the data type from the array values
fixupArrayDatatype(array, program)

View File

@ -6,50 +6,43 @@
main {
sub start() {
ubyte ubarr = max([10,0,2,8,5,4,3,9])
uword uwarr = max([1000,0,200,8000,50,40000,3,900])
byte barr = max([-10,0,-2,8,5,4,-3,9])
word warr = max([-1000,0,-200,8000,50,2000,3,-900])
float flarr = max([-2.2, 1.1, 3.3, 0.0])
ubarr = min([10,0,2,8,5,4,3,9])
uwarr = min([1000,0,200,8000,50,40000,3,900])
barr = min([-10,0,-2,8,5,4,-3,9])
warr = min([-1000,0,-200,8000,50,2000,3,-900])
flarr = min([-2.2, 1.1, 3.3, 0.0])
uwarr = sum([10,0,2,8,5,4,3,9])
uwarr = sum([1000,0,200,8000,50,40000,3,900])
warr = sum([-10,0,-2,8,5,4,-3,9])
warr = sum([-1000,0,-200,8000,50,2000,3,-900])
flarr = sum([-2.2, 1.1, 3.3, 0.0])
ubarr = any([10,0,2,8,5,4,3,9])
ubarr = any([1000,0,200,8000,50,40000,3,900])
ubarr = any([-10,0,-2,8,5,4,-3,9])
ubarr = any([-1000,0,-200,8000,50,2000,3,-900])
ubarr = any([-2.2, 1.1, 3.3, 0.0])
ubarr = all([10,0,2,8,5,4,3,9])
ubarr = all([1000,0,200,8000,50,40000,3,900])
ubarr = all([-10,0,-2,8,5,4,-3,9])
ubarr = all([-1000,0,-200,8000,50,2000,3,-900])
ubarr = all([-2.2, 1.1, 3.3, 0.0])
ubarr = len([10,0,2,8,5,4,3,9])
A = len([1000,0,200,8000,50,40000,3,900])
ubarr = len([-10,0,-2,8,5,4,-3,9])
A = len([-1000,0,-200,8000,50,2000,3,-900])
ubarr = len([-2.2, 1.1, 3.3, 0.0])
; ubyte[] uba = sort([10,0,2,8,5,4,3,9])
; uword[] uwa = sort([1000,0,200,8000,50,40000,3,900])
; byte[] ba = sort([-10,0,-2,8,5,4,-3,9])
; word[] wa = sort([-1000,0,-200,8000,50,40000,3,-900])
; float[] fla = sort([-2.2, 1.1, 3.3, 0.0])
ubyte[] uba = [10,0,2,8,5,4,3,9]
uword[] uwa = [1000,0,200,8000,50,40000,3,900]
byte[] ba = [-10,0,-2,8,5,4,-3,9]
word[] wa = [-1000,0,-200,8000,50,31111,3,-900]
float[] fla = [-2.2, 1.1, 3.3, 0.0]
for ubyte ub in uba {
c64scr.print_ub(ub)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for uword uw in uwa {
c64scr.print_uw(uw)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for byte bb in ba {
c64scr.print_b(bb)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for word ww in wa {
c64scr.print_w(ww)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for ubyte i in 0 to len(fla)-1 {
c64flt.print_f(fla[i])
c64.CHROUT(',')
}
c64.CHROUT('\n')
}
}