improved sourceref column calculation when dealing with tabs, added more error checks

This commit is contained in:
Irmen de Jong 2018-01-14 18:02:39 +01:00
parent 07387f501a
commit 10d0dbe80b
11 changed files with 266 additions and 176 deletions

View File

@ -13,7 +13,7 @@ from typing import Optional, Tuple, Set, Dict, List, Any, no_type_check
import attr import attr
from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \ from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \
SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\ SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\
SymbolName, Dereference, AddressOf SymbolName, Dereference, AddressOf, IncrDecr, TargetRegisters
from .plylex import SourceRef, print_bold from .plylex import SourceRef, print_bold
from .optimize import optimize from .optimize import optimize
from .datatypes import DataType, VarType from .datatypes import DataType, VarType
@ -42,6 +42,7 @@ class PlyParser:
# these shall only be done on the main module after all imports have been done: # these shall only be done on the main module after all imports have been done:
self.apply_directive_options(module) self.apply_directive_options(module)
self.determine_subroutine_usage(module) self.determine_subroutine_usage(module)
self.semantic_check(module)
self.allocate_zeropage_vars(module) self.allocate_zeropage_vars(module)
except ParseError as x: except ParseError as x:
self.handle_parse_error(x) self.handle_parse_error(x)
@ -54,6 +55,18 @@ class PlyParser:
self.parse_errors += 1 self.parse_errors += 1
print_bold("ERROR: {}: {}".format(sourceref, fmtstring.format(*args))) print_bold("ERROR: {}: {}".format(sourceref, fmtstring.format(*args)))
def semantic_check(self, module: Module) -> None:
    """Perform semantic checks on the syntactic parse tree built so far.

    Walks every scope of *module* and raises ParseError when an
    increment/decrement statement targets a symbol that is defined
    as a constant in that scope.
    """
    scope_kinds = (Module, Block, Subroutine)
    for block, parent in module.all_scopes():
        assert isinstance(block, scope_kinds)
        assert parent is None or isinstance(parent, scope_kinds)
        for stmt in block.nodes:
            # only incr/decr statements on a named symbol are checked here
            if not isinstance(stmt, IncrDecr):
                continue
            if not isinstance(stmt.target, SymbolName):
                continue
            symdef = block.scope[stmt.target.name]
            if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST:
                raise ParseError("cannot modify a constant", stmt.sourceref)
def check_and_merge_zeropages(self, module: Module) -> None: def check_and_merge_zeropages(self, module: Module) -> None:
# merge all ZP blocks into one # merge all ZP blocks into one
zeropage = None zeropage = None
@ -126,7 +139,6 @@ class PlyParser:
if isinstance(node.right, Assignment): if isinstance(node.right, Assignment):
multi = reduce_right(node) multi = reduce_right(node)
assert multi is node and len(multi.left) > 1 and not isinstance(multi.right, Assignment) assert multi is node and len(multi.left) > 1 and not isinstance(multi.right, Assignment)
node.simplify_targetregisters()
def apply_directive_options(self, module: Module) -> None: def apply_directive_options(self, module: Module) -> None:
def set_save_registers(scope: Scope, save_dir: Directive) -> None: def set_save_registers(scope: Scope, save_dir: Directive) -> None:
@ -392,7 +404,7 @@ class PlyParser:
print("Error:", str(exc), file=sys.stderr) print("Error:", str(exc), file=sys.stderr)
sourcetext = linecache.getline(exc.sourceref.file, exc.sourceref.line).rstrip() sourcetext = linecache.getline(exc.sourceref.file, exc.sourceref.line).rstrip()
if sourcetext: if sourcetext:
print(" " + sourcetext.expandtabs(1), file=sys.stderr) print(" " + sourcetext.expandtabs(8), file=sys.stderr)
if exc.sourceref.column: if exc.sourceref.column:
print(' ' * (1+exc.sourceref.column) + '^', file=sys.stderr) print(' ' * (1+exc.sourceref.column) + '^', file=sys.stderr)
if sys.stderr.isatty(): if sys.stderr.isatty():

View File

@ -6,7 +6,23 @@ Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
""" """
from typing import Callable from typing import Callable
from ..plyparse import LiteralValue, Assignment, AugAssignment from ..plyparse import AstNode, Scope, VarDef, Dereference, Register, TargetRegisters,\
LiteralValue, Assignment, AugAssignment
from ..datatypes import DataType
from ..plyparse import SymbolName
def datatype_of(assignmenttarget: AstNode, scope: Scope) -> DataType:
    """Return the datatype of an assignment target node.

    VarDef, Dereference and Register nodes carry their own datatype.
    A SymbolName is looked up in *scope* and resolved when it names a VarDef.
    A TargetRegisters node is resolved only when it holds exactly one register.
    Anything else raises TypeError.
    """
    node = assignmenttarget
    if isinstance(node, (VarDef, Dereference, Register)):
        return node.datatype
    if isinstance(node, SymbolName):
        definition = scope[node.name]
        if isinstance(definition, VarDef):
            return definition.datatype
    elif isinstance(node, TargetRegisters) and len(node.registers) == 1:
        return datatype_of(node.registers[0], scope)
    raise TypeError("cannot determine datatype", assignmenttarget)
def generate_assignment(out: Callable, stmt: Assignment) -> None: def generate_assignment(out: Callable, stmt: Assignment) -> None:

View File

@ -15,4 +15,3 @@ def generate_goto(out: Callable, stmt: Goto) -> None:
def generate_subcall(out: Callable, stmt: SubCall) -> None: def generate_subcall(out: Callable, stmt: SubCall) -> None:
pass # @todo pass # @todo

View File

@ -9,7 +9,7 @@ import os
import datetime import datetime
from typing import TextIO, Callable from typing import TextIO, Callable
from ..plylex import print_bold from ..plylex import print_bold
from ..plyparse import Module, ProgramFormat, Block, Directive, VarDef, Label, Subroutine, AstNode, ZpOptions, \ from ..plyparse import Module, Scope, ProgramFormat, Block, Directive, VarDef, Label, Subroutine, AstNode, ZpOptions, \
InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr
from . import CodeError, to_hex from . import CodeError, to_hex
from .variables import generate_block_init, generate_block_vars from .variables import generate_block_init, generate_block_vars
@ -160,7 +160,7 @@ class AssemblyGenerator:
for stmt in block.scope.nodes: for stmt in block.scope.nodes:
if isinstance(stmt, (VarDef, Subroutine)): if isinstance(stmt, (VarDef, Subroutine)):
continue # should have been handled already or will be later continue # should have been handled already or will be later
self.generate_statement(out, stmt) self.generate_statement(out, stmt, block.scope)
if block.name == "main" and isinstance(stmt, Label) and stmt.name == "start": if block.name == "main" and isinstance(stmt, Label) and stmt.name == "start":
# make sure the main.start routine clears the decimal and carry flags as first steps # make sure the main.start routine clears the decimal and carry flags as first steps
out("\vcld\n\vclc\n\vclv") out("\vcld\n\vclc\n\vclv")
@ -177,15 +177,14 @@ class AssemblyGenerator:
out("\v; params: {}\n\v; returns: {} clobbers: {}".format(params or "-", returns or "-", clobbers or "-")) out("\v; params: {}\n\v; returns: {} clobbers: {}".format(params or "-", returns or "-", clobbers or "-"))
cur_block = self.cur_block cur_block = self.cur_block
self.cur_block = subdef.scope self.cur_block = subdef.scope
print(subdef.scope.nodes)
for stmt in subdef.scope.nodes: for stmt in subdef.scope.nodes:
self.generate_statement(out, stmt) self.generate_statement(out, stmt, subdef.scope)
self.cur_block = cur_block self.cur_block = cur_block
out("") out("")
out("; -- end block subroutines") out("; -- end block subroutines")
out("\n\v.pend\n") out("\n\v.pend\n")
def generate_statement(self, out: Callable, stmt: AstNode) -> None: def generate_statement(self, out: Callable, stmt: AstNode, scope: Scope) -> None:
if isinstance(stmt, Label): if isinstance(stmt, Label):
out("\n{:s}\v\t\t; {:s}".format(stmt.name, stmt.lineref)) out("\n{:s}\v\t\t; {:s}".format(stmt.name, stmt.lineref))
elif isinstance(stmt, Return): elif isinstance(stmt, Return):
@ -207,7 +206,7 @@ class AssemblyGenerator:
out(stmt.assembly) out(stmt.assembly)
out("\v; end inline asm, " + stmt.lineref + "\n") out("\v; end inline asm, " + stmt.lineref + "\n")
elif isinstance(stmt, IncrDecr): elif isinstance(stmt, IncrDecr):
generate_incrdecr(out, stmt) generate_incrdecr(out, stmt, scope)
elif isinstance(stmt, Goto): elif isinstance(stmt, Goto):
generate_goto(out, stmt) generate_goto(out, stmt)
elif isinstance(stmt, SubCall): elif isinstance(stmt, SubCall):

View File

@ -6,217 +6,206 @@ Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
""" """
from typing import Callable from typing import Callable
from ..plyparse import Scope, AstNode, Register, IncrDecr, TargetRegisters, SymbolName, Dereference from ..plyparse import Scope, VarType, VarDef, Register, TargetRegisters, IncrDecr, SymbolName, Dereference
from ..datatypes import DataType, REGISTER_BYTES from ..datatypes import DataType, REGISTER_BYTES
from . import CodeError, to_hex, preserving_registers from . import CodeError, to_hex, preserving_registers
from .assignment import datatype_of
def datatype_of(node: AstNode, scope: Scope) -> DataType: def generate_incrdecr(out: Callable, stmt: IncrDecr, scope: Scope) -> None:
if isinstance(node, (Dereference, Register)):
return node.datatype
if isinstance(node, SymbolName):
symdef = scope[node.name]
raise TypeError("cannot determine datatype", node)
def generate_incrdecr(out: Callable, stmt: IncrDecr) -> None:
assert isinstance(stmt.howmuch, (int, float)) and stmt.howmuch >= 0 assert isinstance(stmt.howmuch, (int, float)) and stmt.howmuch >= 0
assert stmt.operator in ("++", "--") assert stmt.operator in ("++", "--")
target = stmt.target target = stmt.target # one of Register/SymbolName/Dereference
if isinstance(target, TargetRegisters): if isinstance(target, SymbolName):
if len(target.registers) != 1: symdef = scope[target.name]
raise CodeError("incr/decr can operate on one register at a time only") if isinstance(symdef, VarDef):
target = target[0] target = symdef
# target = Register/SymbolName/Dereference else:
raise CodeError("cannot incr/decr this", symdef)
if stmt.howmuch > 255: if stmt.howmuch > 255:
if isinstance(stmt.target, TargetRegisters) if datatype_of(target, scope) != DataType.FLOAT:
if stmt.what.datatype != DataType.FLOAT and not stmt.value.name and stmt.value.value > 0xff: raise CodeError("only supports integer incr/decr by up to 255 for now")
raise CodeError("only supports integer incr/decr by up to 255 for now") # XXX howmuch_str = str(stmt.howmuch)
howmuch = stmt.value.value
value_str = stmt.value.name or str(howmuch) if isinstance(target, Register):
if isinstance(stmt.what, RegisterValue): reg = target.name
reg = stmt.what.register
# note: these operations below are all checked to be ok # note: these operations below are all checked to be ok
if stmt.operator == "++": if stmt.operator == "++":
if reg == 'A': if reg == 'A':
# a += 1..255 # a += 1..255
out("\t\tclc") out("\vclc")
out("\t\tadc #" + value_str) out("\vadc #" + howmuch_str)
elif reg in REGISTER_BYTES: elif reg in REGISTER_BYTES:
if howmuch == 1: if stmt.howmuch == 1:
# x/y += 1 # x/y += 1
out("\t\tin{:s}".format(reg.lower())) out("\vin{:s}".format(reg.lower()))
else: else:
# x/y += 2..255 # x/y += 2..255
with preserving_registers({'A'}): with preserving_registers({'A'}, scope, out):
out("\t\tt{:s}a".format(reg.lower())) out("\vt{:s}a".format(reg.lower()))
out("\t\tclc") out("\vclc")
out("\t\tadc #" + value_str) out("\vadc #" + howmuch_str)
out("\t\tta{:s}".format(reg.lower())) out("\vta{:s}".format(reg.lower()))
elif reg == "AX": elif reg == "AX":
# AX += 1..255 # AX += 1..255
out("\t\tclc") out("\vclc")
out("\t\tadc #" + value_str) out("\vadc #" + howmuch_str)
out("\t\tbcc +") out("\vbcc +")
out("\t\tinx") out("\vinx")
out("+") out("+")
elif reg == "AY": elif reg == "AY":
# AY += 1..255 # AY += 1..255
out("\t\tclc") out("\vclc")
out("\t\tadc # " + value_str) out("\vadc # " + howmuch_str)
out("\t\tbcc +") out("\vbcc +")
out("\t\tiny") out("\viny")
out("+") out("+")
elif reg == "XY": elif reg == "XY":
if howmuch == 1: if stmt.howmuch == 1:
# XY += 1 # XY += 1
out("\t\tinx") out("\vinx")
out("\t\tbne +") out("\vbne +")
out("\t\tiny") out("\viny")
out("+") out("+")
else: else:
# XY += 2..255 # XY += 2..255
with preserving_registers({'A'}): with preserving_registers({'A'}, scope, out):
out("\t\ttxa") out("\vtxa")
out("\t\tclc") out("\vclc")
out("\t\tadc #" + value_str) out("\vadc #" + howmuch_str)
out("\t\ttax") out("\vtax")
out("\t\tbcc +") out("\vbcc +")
out("\t\tiny") out("\viny")
out("+") out("+")
else: else:
raise CodeError("invalid incr register: " + reg) raise CodeError("invalid incr register: " + reg)
else: else:
if reg == 'A': if reg == 'A':
# a -= 1..255 # a -= 1..255
out("\t\tsec") out("\vsec")
out("\t\tsbc #" + value_str) out("\vsbc #" + howmuch_str)
elif reg in REGISTER_BYTES: elif reg in REGISTER_BYTES:
if howmuch == 1: if stmt.howmuch == 1:
# x/y -= 1 # x/y -= 1
out("\t\tde{:s}".format(reg.lower())) out("\vde{:s}".format(reg.lower()))
else: else:
# x/y -= 2..255 # x/y -= 2..255
with preserving_registers({'A'}): with preserving_registers({'A'}, scope, out):
out("\t\tt{:s}a".format(reg.lower())) out("\vt{:s}a".format(reg.lower()))
out("\t\tsec") out("\vsec")
out("\t\tsbc #" + value_str) out("\vsbc #" + howmuch_str)
out("\t\tta{:s}".format(reg.lower())) out("\vta{:s}".format(reg.lower()))
elif reg == "AX": elif reg == "AX":
# AX -= 1..255 # AX -= 1..255
out("\t\tsec") out("\vsec")
out("\t\tsbc #" + value_str) out("\vsbc #" + howmuch_str)
out("\t\tbcs +") out("\vbcs +")
out("\t\tdex") out("\vdex")
out("+") out("+")
elif reg == "AY": elif reg == "AY":
# AY -= 1..255 # AY -= 1..255
out("\t\tsec") out("\vsec")
out("\t\tsbc #" + value_str) out("\vsbc #" + howmuch_str)
out("\t\tbcs +") out("\vbcs +")
out("\t\tdey") out("\vdey")
out("+") out("+")
elif reg == "XY": elif reg == "XY":
if howmuch == 1: if stmt.howmuch == 1:
# XY -= 1 # XY -= 1
out("\t\tcpx #0") out("\vcpx #0")
out("\t\tbne +") out("\vbne +")
out("\t\tdey") out("\vdey")
out("+\t\tdex") out("+\t\tdex")
else: else:
# XY -= 2..255 # XY -= 2..255
with preserving_registers({'A'}): with preserving_registers({'A'}, scope, out):
out("\t\ttxa") out("\vtxa")
out("\t\tsec") out("\vsec")
out("\t\tsbc #" + value_str) out("\vsbc #" + howmuch_str)
out("\t\ttax") out("\vtax")
out("\t\tbcs +") out("\vbcs +")
out("\t\tdey") out("\vdey")
out("+") out("+")
else: else:
raise CodeError("invalid decr register: " + reg) raise CodeError("invalid decr register: " + reg)
elif isinstance(stmt.what, (MemMappedValue, IndirectValue)):
what = stmt.what elif isinstance(target, VarDef):
if isinstance(what, IndirectValue): if target.vartype == VarType.CONST:
if isinstance(what.value, IntegerValue): raise CodeError("cannot modify a constant", target)
what_str = what.value.name or to_hex(what.value.value) what_str = target.name
if target.datatype == DataType.BYTE:
if stmt.howmuch == 1:
out("\v{:s} {:s}".format("inc" if stmt.operator == "++" else "dec", what_str))
else: else:
raise CodeError("invalid incr indirect type", what.value) with preserving_registers({'A'}, scope, out):
else: out("\vlda " + what_str)
what_str = what.name or to_hex(what.address)
if what.datatype == DataType.BYTE:
if howmuch == 1:
out("\t\t{:s} {:s}".format("inc" if stmt.operator == "++" else "dec", what_str))
else:
with preserving_registers({'A'}):
out("\t\tlda " + what_str)
if stmt.operator == "++": if stmt.operator == "++":
out("\t\tclc") out("\vclc")
out("\t\tadc #" + value_str) out("\vadc #" + howmuch_str)
else: else:
out("\t\tsec") out("\vsec")
out("\t\tsbc #" + value_str) out("\vsbc #" + howmuch_str)
out("\t\tsta " + what_str) out("\vsta " + what_str)
elif what.datatype == DataType.WORD: elif target.datatype == DataType.WORD:
if howmuch == 1: if stmt.howmuch == 1:
# mem.word +=/-= 1 # mem.word +=/-= 1
if stmt.operator == "++": if stmt.operator == "++":
out("\t\tinc " + what_str) out("\vinc " + what_str)
out("\t\tbne +") out("\vbne +")
out("\t\tinc {:s}+1".format(what_str)) out("\vinc {:s}+1".format(what_str))
out("+") out("+")
else: else:
with preserving_registers({'A'}): with preserving_registers({'A'}, scope, out):
out("\t\tlda " + what_str) out("\vlda " + what_str)
out("\t\tbne +") out("\vbne +")
out("\t\tdec {:s}+1".format(what_str)) out("\vdec {:s}+1".format(what_str))
out("+\t\tdec " + what_str) out("+\t\tdec " + what_str)
else: else:
# mem.word +=/-= 2..255 # mem.word +=/-= 2..255
if stmt.operator == "++": if stmt.operator == "++":
with preserving_registers({'A'}): with preserving_registers({'A'}, scope, out):
out("\t\tclc") out("\vclc")
out("\t\tlda " + what_str) out("\vlda " + what_str)
out("\t\tadc #" + value_str) out("\vadc #" + howmuch_str)
out("\t\tsta " + what_str) out("\vsta " + what_str)
out("\t\tbcc +") out("\vbcc +")
out("\t\tinc {:s}+1".format(what_str)) out("\vinc {:s}+1".format(what_str))
out("+") out("+")
else: else:
with preserving_registers({'A'}): with preserving_registers({'A'}, scope, out):
out("\t\tsec") out("\vsec")
out("\t\tlda " + what_str) out("\vlda " + what_str)
out("\t\tsbc #" + value_str) out("\vsbc #" + howmuch_str)
out("\t\tsta " + what_str) out("\vsta " + what_str)
out("\t\tbcs +") out("\vbcs +")
out("\t\tdec {:s}+1".format(what_str)) out("\vdec {:s}+1".format(what_str))
out("+") out("+")
elif what.datatype == DataType.FLOAT: elif target.datatype == DataType.FLOAT:
if howmuch == 1.0: if stmt.howmuch == 1.0:
# special case for +/-1 # special case for +/-1
with preserving_registers({'A', 'X', 'Y'}, loads_a_within=True): with preserving_registers({'A', 'X', 'Y'}, scope, out, loads_a_within=True):
out("\t\tldx #<" + what_str) out("\vldx #<" + what_str)
out("\t\tldy #>" + what_str) out("\vldy #>" + what_str)
if stmt.operator == "++": if stmt.operator == "++":
out("\t\tjsr c64flt.float_add_one") out("\vjsr c64flt.float_add_one")
else: else:
out("\t\tjsr c64flt.float_sub_one") out("\vjsr c64flt.float_sub_one")
elif stmt.value.name: elif stmt.value.name: # XXX
with preserving_registers({'A', 'X', 'Y'}, loads_a_within=True): with preserving_registers({'A', 'X', 'Y'}, scope, out, loads_a_within=True):
out("\t\tlda #<" + stmt.value.name) out("\vlda #<" + stmt.value.name)
out("\t\tsta c64.SCRATCH_ZPWORD1") out("\vsta c64.SCRATCH_ZPWORD1")
out("\t\tlda #>" + stmt.value.name) out("\vlda #>" + stmt.value.name)
out("\t\tsta c64.SCRATCH_ZPWORD1+1") out("\vsta c64.SCRATCH_ZPWORD1+1")
out("\t\tldx #<" + what_str) out("\vldx #<" + what_str)
out("\t\tldy #>" + what_str) out("\vldy #>" + what_str)
if stmt.operator == "++": if stmt.operator == "++":
out("\t\tjsr c64flt.float_add_SW1_to_XY") out("\vjsr c64flt.float_add_SW1_to_XY")
else: else:
out("\t\tjsr c64flt.float_sub_SW1_from_XY") out("\vjsr c64flt.float_sub_SW1_from_XY")
else: else:
raise CodeError("incr/decr missing float constant definition") raise CodeError("incr/decr missing float constant definition")
else: else:
raise CodeError("cannot in/decrement memory of type " + str(what.datatype), howmuch) raise CodeError("cannot in/decrement memory of type " + str(target.datatype), stmt.howmuch)
else: else:
raise CodeError("cannot in/decrement " + str(stmt.what)) raise CodeError("cannot in/decrement", target)

View File

@ -5,7 +5,7 @@ This is the optimizer that applies various optimizations to the parse tree.
Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0 Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
""" """
from .plyparse import Module, Subroutine, Block, Directive, Assignment, AugAssignment, Goto, Expression from .plyparse import Module, Subroutine, Block, Directive, Assignment, AugAssignment, Goto, Expression, IncrDecr
from .plylex import print_warning, print_bold from .plylex import print_warning, print_bold
@ -16,16 +16,17 @@ class Optimizer:
def optimize(self) -> None: def optimize(self) -> None:
self.num_warnings = 0 self.num_warnings = 0
self.remove_useless_assigns() self.optimize_assignments()
self.combine_assignments_into_multi() self.combine_assignments_into_multi()
self.optimize_multiassigns() self.optimize_multiassigns()
self.remove_unused_subroutines() self.remove_unused_subroutines()
self.optimize_compare_with_zero() self.optimize_compare_with_zero()
self.remove_empty_blocks() self.remove_empty_blocks()
def remove_useless_assigns(self): def optimize_assignments(self):
# remove assignment statements that do nothing (A=A) # remove assignment statements that do nothing (A=A)
# and augmented assignments that have no effect (A+=0) # and augmented assignments that have no effect (A+=0)
# convert augmented assignments to simple incr/decr if possible (A+=10 => A++ by 10)
# @todo remove or simplify logical aug assigns like A |= 0, A |= true, A |= false (or perhaps turn them into byte values first?) # @todo remove or simplify logical aug assigns like A |= 0, A |= true, A |= false (or perhaps turn them into byte values first?)
for block, parent in self.module.all_scopes(): for block, parent in self.module.all_scopes():
for assignment in list(block.nodes): for assignment in list(block.nodes):
@ -45,6 +46,10 @@ class Optimizer:
print("{}: shifting result is always zero".format(assignment.sourceref)) print("{}: shifting result is always zero".format(assignment.sourceref))
new_stmt = Assignment(left=[assignment.left], right=0, sourceref=assignment.sourceref) new_stmt = Assignment(left=[assignment.left], right=0, sourceref=assignment.sourceref)
block.scope.replace_node(assignment, new_stmt) block.scope.replace_node(assignment, new_stmt)
if assignment.operator in ("+=", "-=") and 0 < assignment.right < 256:
new_stmt = IncrDecr(target=assignment.left, operator="++" if assignment.operator == "+=" else "--",
howmuch=assignment.right, sourceref=assignment.sourceref)
block.scope.replace_node(assignment, new_stmt)
def combine_assignments_into_multi(self): def combine_assignments_into_multi(self):
# fold multiple consecutive assignments with the same rvalue into one multi-assignment # fold multiple consecutive assignments with the same rvalue into one multi-assignment

View File

@ -230,6 +230,10 @@ def t_BOOLEAN(t):
def t_DOTTEDNAME(t):
    r"[a-zA-Z_]\w*(\.[a-zA-Z_]\w*)+"
    # The string above is the token's regular expression, not documentation; keep it unchanged.
    # It matches two OR MORE dot-separated name parts, so every part is checked
    # against the reserved words. Unpacking into exactly two names raised
    # ValueError for input such as "a.b.c".
    if any(part in reserved for part in t.value.split(".")):
        custom_error(t, "reserved word as part of dotted name")
        return None     # no token is produced for an invalid dotted name
    return t
@ -321,10 +325,22 @@ def t_error(t):
t.lexer.skip(1) t.lexer.skip(1)
def custom_error(t, message):
# Report `message` as an error located at token `t`.
# The location is a SourceRef built from the lexer's file name, the token's
# line number and the column computed by find_tok_column().
line, col = t.lineno, find_tok_column(t)
# fall back to a placeholder name when the lexer has no source_filename attribute
filename = getattr(t.lexer, "source_filename", "<unknown-file>")
sref = SourceRef(filename, line, col)
# use the error_function hook on the lexer when it has one, otherwise print to stderr
if hasattr(t.lexer, "error_function"):
t.lexer.error_function(sref, message)
else:
print(sref, message, file=sys.stderr)
# NOTE(review): indentation is lost in this view; this skip presumably runs
# unconditionally after reporting (not only in the else branch) - confirm.
t.lexer.skip(1)
def find_tok_column(token):
    """ Find the column of the token in its line.

    Returns the 1-based column of the token, with tabs in front of it
    expanded to tab stops (the str.expandtabs() default of 8).
    """
    # rfind() returns -1 when the token is on the first line; adding 1 makes
    # line_start the index of the first character of the token's line in
    # both cases. Slicing from -1 itself would select from the END of the
    # text and yield an empty chunk, i.e. column 0 for first-line tokens.
    line_start = lexer.lexdata.rfind('\n', 0, token.lexpos) + 1
    chunk = lexer.lexdata[line_start:token.lexpos]
    return len(chunk.expandtabs()) + 1
def print_warning(text: str, sourceref: SourceRef = None) -> None: def print_warning(text: str, sourceref: SourceRef = None) -> None:

View File

@ -101,7 +101,7 @@ class Scope(AstNode):
symbols = attr.ib(init=False) symbols = attr.ib(init=False)
name = attr.ib(init=False) # will be set by enclosing block, or subroutine etc. name = attr.ib(init=False) # will be set by enclosing block, or subroutine etc.
parent_scope = attr.ib(init=False, default=None) # will be wired up later parent_scope = attr.ib(init=False, default=None) # will be wired up later
save_registers = attr.ib(type=bool, default=None, init=False) # None = look in parent scope's setting # @todo property that does that save_registers = attr.ib(type=bool, default=None, init=False) # None = look in parent scope's setting @todo property that does that
def __attrs_post_init__(self): def __attrs_post_init__(self):
# populate the symbol table for this scope for fast lookups via scope["name"] or scope["dotted.name"] # populate the symbol table for this scope for fast lookups via scope["name"] or scope["dotted.name"]
@ -319,19 +319,6 @@ class Assignment(AstNode):
left = attr.ib(type=list) # type: List[Union[str, TargetRegisters, Dereference]] left = attr.ib(type=list) # type: List[Union[str, TargetRegisters, Dereference]]
right = attr.ib() right = attr.ib()
def __attrs_post_init__(self):
self.simplify_targetregisters()
def simplify_targetregisters(self) -> None:
# optimize TargetRegisters down to single Register if it's just one register
new_targets = []
assert isinstance(self.left, (list, tuple)), "assignment lvalue must be sequence"
for t in self.left:
if isinstance(t, TargetRegisters) and len(t.registers) == 1:
t = t.registers[0]
new_targets.append(t)
self.left = new_targets
def process_expressions(self, scope: Scope) -> None: def process_expressions(self, scope: Scope) -> None:
self.right = process_expression(self.right, scope, self.right.sourceref) self.right = process_expression(self.right, scope, self.right.sourceref)
@ -393,6 +380,10 @@ class Return(AstNode):
@attr.s(cmp=False, repr=False) @attr.s(cmp=False, repr=False)
class TargetRegisters(AstNode): class TargetRegisters(AstNode):
# This is a tuple of 1 or more registers.
# In its multiple-register form it is only used to be able to parse
# the result of a subroutine call such as A,X = sub().
# It will be replaced by a regular Register node if it contains just one register.
registers = attr.ib(type=list) registers = attr.ib(type=list)
def add(self, register: str) -> None: def add(self, register: str) -> None:
@ -527,6 +518,8 @@ class Dereference(AstNode):
elif isinstance(self.datatype, DatatypeNode): elif isinstance(self.datatype, DatatypeNode):
assert self.size is None assert self.size is None
self.size = self.datatype.dimensions self.size = self.datatype.dimensions
if not self.datatype.to_enum().isnumeric():
raise ParseError("dereference target value must be byte, word, float", self.datatype.sourceref)
self.datatype = self.datatype.to_enum() self.datatype = self.datatype.to_enum()
@ -545,6 +538,7 @@ class AddressOf(AstNode):
@attr.s(cmp=False, repr=False) @attr.s(cmp=False, repr=False)
class IncrDecr(AstNode): class IncrDecr(AstNode):
# increment or decrement something by a constant value (1 or more)
target = attr.ib() target = attr.ib()
operator = attr.ib(type=str, validator=attr.validators.in_(["++", "--"])) operator = attr.ib(type=str, validator=attr.validators.in_(["++", "--"]))
howmuch = attr.ib(default=1) howmuch = attr.ib(default=1)
@ -554,6 +548,11 @@ class IncrDecr(AstNode):
if self.howmuch < 0: if self.howmuch < 0:
self.howmuch = -self.howmuch self.howmuch = -self.howmuch
self.operator = "++" if self.operator == "--" else "--" self.operator = "++" if self.operator == "--" else "--"
if isinstance(self.target, Register):
if self.target.name not in REGISTER_BYTES | REGISTER_WORDS:
raise ParseError("cannot incr/decr that register", self.sourceref)
if isinstance(self.target, TargetRegisters):
raise ParseError("cannot incr/decr multiple registers at once", self.sourceref)
@attr.s(cmp=False, repr=False) @attr.s(cmp=False, repr=False)
@ -1087,7 +1086,7 @@ def p_incrdecr(p):
incrdecr : assignment_target INCR incrdecr : assignment_target INCR
| assignment_target DECR | assignment_target DECR
""" """
p[0] = IncrDecr(target=p[1], operator=p[2], sourceref=_token_sref(p, 1)) p[0] = IncrDecr(target=p[1], operator=p[2], sourceref=_token_sref(p, 2))
def p_call_subroutine(p): def p_call_subroutine(p):
@ -1316,6 +1315,11 @@ def p_assignment_target(p):
| symbolname | symbolname
| dereference | dereference
""" """
if isinstance(p[1], TargetRegisters):
# if the target registers is just a single register, use that instead
if len(p[1].registers) == 1:
assert isinstance(p[1].registers[0], Register)
p[1] = p[1].registers[0]
p[0] = p[1] p[0] = p[1]
@ -1356,7 +1360,8 @@ def _token_sref(p, token_idx):
last_cr = p.lexer.lexdata.rfind('\n', 0, p.lexpos(token_idx)) last_cr = p.lexer.lexdata.rfind('\n', 0, p.lexpos(token_idx))
if last_cr < 0: if last_cr < 0:
last_cr = -1 last_cr = -1
column = (p.lexpos(token_idx) - last_cr) chunk = p.lexer.lexdata[last_cr:p.lexpos(token_idx)]
column = len(chunk.expandtabs())
return SourceRef(p.lexer.source_filename, p.lineno(token_idx), column) return SourceRef(p.lexer.source_filename, p.lineno(token_idx), column)

View File

@ -165,10 +165,11 @@ For most other types this prefix is not supported.
**Indirect addressing:** The ``[address]`` syntax means: the contents of the memory at address, or "indirect addressing". **Indirect addressing:** The ``[address]`` syntax means: the contents of the memory at address, or "indirect addressing".
By default, if not otherwise known, a single byte is assumed. You can add the ``.byte`` or ``.word`` or ``.float`` By default, if not otherwise known, a single byte is assumed. You can add the ``.byte`` or ``.word`` or ``.float``
type identifier suffix to make it clear what data type the address points to. type identifier, inside the bracket, to make it clear what data type the address points to.
This addressing mode is only supported for constant (integer) addresses and not for variable types, For instance: ``[address .word]`` (notice the space, to distinguish this from a dotted symbol name).
unless it is part of a subroutine call statement. For an indirect goto call, the 6502 CPU has a special instruction For an indirect goto call, the 6502 CPU has a special instruction
(``jmp`` indirect) and an indirect subroutine call (``jsr`` indirect) is synthesized using a couple of instructions. (``jmp`` indirect) and an indirect subroutine call (``jsr`` indirect) is emitted
using a couple of instructions.
Program Structure Program Structure

View File

@ -180,3 +180,29 @@ def test_parser_2():
assert isinstance(call.target.target, SymbolName) assert isinstance(call.target.target, SymbolName)
assert call.target.target.name == "zz" assert call.target.target.name == "zz"
assert call.target.address_of is True assert call.target.address_of is True
test_source_3 = """
~ {
goto.XY = 5
AX.text = 5
[$c000.word] = 5
[AX.word] = 5
}
"""
def test_typespec():
    """Parse test_source_3 and check that each of its four assignments has the value 5."""
    lexer.lineno = 1
    lexer.source_filename = "sourcefile"
    token_filter = TokenFilter(lexer)
    result = parser.parse(input=test_source_3, tokenfunc=token_filter.token)
    # the block must contain exactly four statements
    assignment1, assignment2, assignment3, assignment4 = result.nodes[0].nodes
    for assignment in (assignment1, assignment2, assignment3, assignment4):
        assert assignment.right.value == 5
    # show the parsed assignment targets for manual inspection
    print("A1", assignment1.left)
    print("A2", assignment2.left)
    print("A3", assignment3.left)
    print("A4", assignment4.left)

View File

@ -5,6 +5,7 @@
var zp1_1 = 200 var zp1_1 = 200
var zp1_2 = 200 var zp1_2 = 200
var .float zpf1
var .text zp_s1 = "hello\n" var .text zp_s1 = "hello\n"
var .ptext zp_s2 = "goodbye\n" var .ptext zp_s2 = "goodbye\n"
var .stext zp_s3 = "welcome\n" var .stext zp_s3 = "welcome\n"
@ -13,10 +14,31 @@
var .array(20) arr1 = $ea var .array(20) arr1 = $ea
var .wordarray(20) arr2 = $ea var .wordarray(20) arr2 = $ea
memory border = $d020
const .word cword = 2
start: start:
%breakpoint abc,def %breakpoint abc,def
A++
X--
A+=1
X-=2
border++
zp1_1++
zpf1++
[AX]++
[AX .byte]++
[AX .word]++
[AX .float]++
[$ccc0]++
[$ccc0 .byte]++
[$ccc0 .word]++
[$ccc0 .float]++
A+=2
XY+=666
foobar() foobar()
return 44 return 44