1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-05-31 18:41:30 +00:00
This commit is contained in:
Adam Gastineau 2021-02-26 15:06:18 -05:00 committed by GitHub
commit 2871f20a00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 2169 additions and 88 deletions

4
.gitignore vendored
View File

@ -1,6 +1,9 @@
# various directories
target/
.bloop/
.bsp/
.idea/
.metals/
project/target
project/project/target/
stuff
@ -14,6 +17,7 @@ include-*/
# hidden files
*.~
.DS_Store
#tools
*.bat

1
.scalafmt.conf Normal file
View File

@ -0,0 +1 @@
version = "2.6.4"

18
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,18 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "scala",
"name": "Debug",
"request": "launch",
"mainClass": "millfork.Main",
// optional jvm properties to use
"jvmOptions": [],
"args": []
},
]
}

5
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"files.watcherExclude": {
"**/target": true
}
}

17
.vscode/tasks.json vendored Normal file
View File

@ -0,0 +1,17 @@
{
// See https://go.microsoft.com/fwlink/?LinkId=733558
// for the documentation about the tasks.json format
"version": "2.0.0",
"tasks": [
{
"label": "Compile Millfork",
"type": "shell",
"command": "sbt -DskipTests=true compile && sbt -DskipTests=true assembly",
"problemMatcher": [],
"group": {
"kind": "build",
"isDefault": true
}
}
]
}

View File

@ -10,6 +10,10 @@ libraryDependencies += "com.lihaoyi" %% "fastparse" % "1.0.0"
libraryDependencies += "org.apache.commons" % "commons-configuration2" % "2.2"
libraryDependencies += "org.eclipse.lsp4j" % "org.eclipse.lsp4j" % "0.9.0"
libraryDependencies += "net.liftweb" %% "lift-json" % "3.4.2"
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.8" % "test"
val testDependencies = Seq(
@ -33,7 +37,7 @@ val testDependencies = Seq(
val includesTests = System.getProperty("skipTests") == null
libraryDependencies ++=(
libraryDependencies ++= (
if (includesTests) {
println("Including test dependencies")
testDependencies
@ -43,26 +47,27 @@ libraryDependencies ++=(
)
(if (!includesTests) {
// Disable assembling tests
sbt.internals.DslEntry.fromSettingsDef(test in assembly := {})
} else {
sbt.internals.DslEntry.fromSettingsDef(Seq[sbt.Def.Setting[_]]())
})
// Disable assembling tests
sbt.internal.DslEntry.fromSettingsDef(test in assembly := {})
} else {
sbt.internal.DslEntry.fromSettingsDef(Seq[sbt.Def.Setting[_]]())
})
mainClass in Compile := Some("millfork.Main")
assemblyJarName := "millfork.jar"
lazy val root = (project in file(".")).
enablePlugins(BuildInfoPlugin).
settings(
lazy val root = (project in file("."))
.enablePlugins(BuildInfoPlugin)
.settings(
buildInfoKeys := Seq[BuildInfoKey](name, version, scalaVersion, sbtVersion),
buildInfoPackage := "millfork.buildinfo"
)
import sbtassembly.AssemblyKeys
val releaseDist = TaskKey[File]("release-dist", "Creates a distributable zip file.")
val releaseDist =
TaskKey[File]("release-dist", "Creates a distributable zip file.")
releaseDist := {
val jar = AssemblyKeys.assembly.value
@ -79,7 +84,10 @@ releaseDist := {
IO.createDirectory(distDir)
IO.copyFile(jar, distDir / jar.name)
IO.copyFile(base / "LICENSE", distDir / "LICENSE")
IO.copyFile(base / "src/3rd-party-licenses.txt", distDir / "3rd-party-licenses.txt")
IO.copyFile(
base / "src/3rd-party-licenses.txt",
distDir / "3rd-party-licenses.txt"
)
IO.copyFile(base / "CHANGELOG.md", distDir / "CHANGELOG.md")
IO.copyFile(base / "README.md", distDir / "README.md")
IO.copyFile(base / "COMPILING.md", distDir / "COMPILING.md")
@ -89,8 +97,14 @@ releaseDist := {
}
copyDir("include")
copyDir("docs")
def entries(f: File): List[File] = f :: (if (f.isDirectory) IO.listFiles(f).toList.flatMap(entries) else Nil)
IO.zip(entries(distDir).map(d => (d, d.getAbsolutePath.substring(distDir.getParent.length + 1))), zipFile)
def entries(f: File): List[File] =
f :: (if (f.isDirectory) IO.listFiles(f).toList.flatMap(entries) else Nil)
IO.zip(
entries(distDir).map(d =>
(d, d.getAbsolutePath.substring(distDir.getParent.length + 1))
),
zipFile
)
IO.delete(distDir)
zipFile
}

View File

@ -1 +1 @@
sbt.version = 0.13.18
sbt.version = 1.4.0

4
project/metals.sbt Normal file
View File

@ -0,0 +1,4 @@
// DO NOT EDIT! This file is auto-generated.
// This file enables sbt-bloop to create bloop config files.
addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.4-13-408f4d80")

View File

@ -0,0 +1,4 @@
// DO NOT EDIT! This file is auto-generated.
// This file enables sbt-bloop to create bloop config files.
addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.4-13-408f4d80")

View File

@ -0,0 +1,4 @@
// DO NOT EDIT! This file is auto-generated.
// This file enables sbt-bloop to create bloop config files.
addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.4-13-408f4d80")

View File

@ -9,6 +9,7 @@ import millfork.error.Logger
case class Context(errorReporting: Logger,
inputFileNames: List[String],
outputFileName: Option[String] = None,
configFilePath: Option[String] = None,
runFileName: Option[String] = None,
runParams: Seq[String] = Vector(),
optimizationLevel: Option[Int] = None,
@ -21,7 +22,8 @@ case class Context(errorReporting: Logger,
extraIncludePath: Seq[String] = IndexedSeq(),
flags: Map[CompilationFlag.Value, Boolean] = Map(),
features: Map[String, Long] = Map(),
verbosity: Option[Int] = None) {
verbosity: Option[Int] = None,
languageServer: Boolean = false) {
def changeFlag(f: CompilationFlag.Value, b: Boolean): Context = {
if (flags.contains(f)) {
if (flags(f) != b) {

View File

@ -19,11 +19,15 @@ import millfork.node.StandardCallGraph
import millfork.output._
import millfork.parser.{MSourceLoadingQueue, MosSourceLoadingQueue, TextCodecRepository, ZSourceLoadingQueue}
import millfork.language.{MfLanguageServer,MfLanguageClient,LanguageServerLogger}
import org.eclipse.lsp4j.services.LanguageServer
import org.eclipse.lsp4j.jsonrpc.Launcher
import java.util.concurrent.Executors
import java.io.PrintWriter
import millfork.cli.JsonConfigParser
object Main {
def main(args: Array[String]): Unit = {
val errorReporting = new ConsoleLogger
implicit val __implicitLogger: Logger = errorReporting
@ -34,6 +38,7 @@ object Main {
val startTime = System.nanoTime()
val (status, c0) = parser(errorReporting).parse(Context(errorReporting, Nil), args.toList)
val c1 = JsonConfigParser.parseConfig(c0, errorReporting)
status match {
case CliStatus.Quit => return
case CliStatus.Failed =>
@ -41,8 +46,8 @@ object Main {
case CliStatus.Ok => ()
}
errorReporting.assertNoErrors("Invalid command line")
errorReporting.verbosity = c0.verbosity.getOrElse(0)
if (c0.inputFileNames.isEmpty) {
errorReporting.verbosity = c1.verbosity.getOrElse(0)
if (c1.inputFileNames.isEmpty && !c1.languageServer) {
errorReporting.fatalQuit("No input files")
}
@ -51,14 +56,14 @@ object Main {
errorReporting.trace("This program comes with ABSOLUTELY NO WARRANTY.")
errorReporting.trace("This is free software, and you are welcome to redistribute it under certain conditions")
errorReporting.trace("You should have received a copy of the GNU General Public License along with this program. If not, see https://www.gnu.org/licenses/")
val c = fixMissingIncludePath(c0).filloutFlags()
val c = fixMissingIncludePath(c1).filloutFlags()
if (c.includePath.isEmpty) {
errorReporting.warn("Failed to detect the default include directory, consider using the -I option")
}
val textCodecRepository = new TextCodecRepository("." :: c.includePath)
val platform = Platform.lookupPlatformFile("." :: c.includePath, c.platform.getOrElse {
errorReporting.info("No platform selected, defaulting to `c64`")
if (!c1.languageServer) errorReporting.info("No platform selected, defaulting to `c64`")
"c64"
}, textCodecRepository)
val options = CompilationOptions(platform, c.flags, c.outputFileName, c.zpRegisterSize.getOrElse(platform.zpRegisterSize), c.features, textCodecRepository, JobContext(errorReporting, new LabelGenerator))
@ -67,6 +72,25 @@ object Main {
case (f, b) => errorReporting.debug(f" $f%-30s : $b%s")
}
if (c1.languageServer) {
// We cannot log anything to stdout when starting the language server (otherwise it's a protocol violation)
errorReporting.setOutput(true)
val server = new MfLanguageServer(c, options)
val exec = Executors.newCachedThreadPool()
val launcher = new Launcher.Builder[MfLanguageClient]()
.setExecutorService(exec)
.setInput(System.in)
.setOutput(System.out)
.setRemoteInterface(classOf[MfLanguageClient])
.setLocalService(server)
.create()
val clientProxy = launcher.getRemoteProxy
server.client = Some(clientProxy)
launcher.startListening().get()
}
val output = c.outputFileName match {
case Some(ofn) => ofn
case None => c.inputFileNames match {
@ -252,7 +276,7 @@ object Main {
val unoptimized = new MosSourceLoadingQueue(
initialFilenames = c.inputFileNames,
includePath = c.includePath,
options = options).run()
options = options).run().compilationOrderProgram
val program = if (optLevel > 0) {
OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt, options))
@ -306,7 +330,7 @@ object Main {
val unoptimized = new ZSourceLoadingQueue(
initialFilenames = c.inputFileNames,
includePath = c.includePath,
options = options).run()
options = options).run().compilationOrderProgram
val program = if (optLevel > 0) {
OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt, options))
@ -346,7 +370,7 @@ object Main {
val unoptimized = new MSourceLoadingQueue(
initialFilenames = c.inputFileNames,
includePath = c.includePath,
options = options).run()
options = options).run().compilationOrderProgram
val program = if (optLevel > 0) {
OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt, options))
@ -376,7 +400,7 @@ object Main {
val unoptimized = new ZSourceLoadingQueue(
initialFilenames = c.inputFileNames,
includePath = c.includePath,
options = options).run()
options = options).run().compilationOrderProgram
val program = if (optLevel > 0) {
OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt, options))
@ -429,6 +453,15 @@ object Main {
c.copy(outputLabels = true, outputLabelsFormatOverride = Some(f))
}.description("Generate also the label file in the given format. Available options: vice, nesasm, sym.")
flag("-lsp").action { c =>
c.copy(languageServer = true)
}.description("Start the Millfork language server. Does not start compilation.")
parameter("-c", "--config").placeholder("<file>").action { (p, c) =>
assertNone(c.outputFileName, "Config file already defined")
c.copy(configFilePath = Some(p))
}.description("The Millfork config file. Suppliments the provided CLI options.")
boolean("-fbreakpoints", "-fno-breakpoints").action((c,v) =>
c.changeFlag(CompilationFlag.EnableBreakpoints, v)
).description("Include breakpoints in the label file. Requires either -g or -G.")

View File

@ -0,0 +1,65 @@
package millfork.cli
import net.liftweb.json._
import java.nio.file.Files
import java.nio.file.Paths
import java.nio.charset.StandardCharsets
import scala.collection.mutable
import scala.collection.convert.ImplicitConversionsToScala._
import java.io.InputStreamReader
import millfork.Context
import millfork.error.ConsoleLogger
case class JsonConfig(
include: Option[List[String]],
platform: Option[String],
inputFiles: Option[List[String]]
)
object JsonConfigParser {
implicit val formats = DefaultFormats
def parseConfig(context: Context, logger: ConsoleLogger): Context = {
var newContext = context
var defaultConfig = false
val filePath = context.configFilePath.getOrElse({
defaultConfig = true
".millforkrc.json"
})
val path = Paths.get(filePath)
try {
val jsonString =
Files
.readAllLines(path, StandardCharsets.UTF_8)
.toIndexedSeq
.mkString("")
val result = parse(jsonString).extract[JsonConfig]
if (context.inputFileNames.length < 1 && result.inputFiles.isDefined) {
newContext = newContext.copy(inputFileNames = result.inputFiles.get)
}
if (context.includePath.length < 1 && result.include.isDefined) {
newContext =
newContext.copy(extraIncludePath = result.include.get.toSeq)
}
if (context.platform.isEmpty && result.platform.isDefined) {
newContext = newContext.copy(platform = Some(result.platform.get))
}
} catch {
case default: Throwable => {
if (!defaultConfig) {
// Only throw error if not default config
logger.fatalQuit("Invalid config file")
}
}
}
newContext
}
}

View File

@ -4,10 +4,12 @@ import millfork.assembly.SourceLine
import millfork.node.Position
import scala.collection.mutable
import java.io.PrintStream
class ConsoleLogger extends Logger {
FatalErrorReporting.considerAsGlobal(this)
private var defaultUseStderr = false
var verbosity = 0
var fatalWarnings = false
@ -15,6 +17,10 @@ class ConsoleLogger extends Logger {
this.fatalWarnings = fatalWarnings
}
def setOutput(useStderr: Boolean): Unit = {
this.defaultUseStderr = useStderr
}
var hasErrors = false
private val sourceLines: mutable.Map[String, IndexedSeq[String]] = mutable.Map()
@ -27,11 +33,11 @@ class ConsoleLogger extends Logger {
val line = lines.apply(lineIx)
val column = pos.get.column - 1
val margin = " "
print(margin)
println(line)
print(margin)
print(" " * column)
println("^")
this.print(margin)
this.println(line)
this.print(margin)
this.print(" " * column)
this.println("^")
}
}
}
@ -42,14 +48,14 @@ class ConsoleLogger extends Logger {
override def info(msg: String, position: Option[Position] = None): Unit = {
if (verbosity < 0) return
println("INFO: " + f(position) + msg)
this.println("INFO: " + f(position) + msg)
printErrorContext(position)
flushOutput()
}
override def debug(msg: String, position: Option[Position] = None): Unit = {
if (verbosity < 1) return
println("DEBUG: " + f(position) + msg)
this.println("DEBUG: " + f(position) + msg)
flushOutput()
}
@ -59,7 +65,7 @@ class ConsoleLogger extends Logger {
override def trace(msg: String, position: Option[Position] = None): Unit = {
if (verbosity < 2) return
println("TRACE: " + f(position) + msg)
this.println("TRACE: " + f(position) + msg)
flushOutput()
}
@ -71,7 +77,7 @@ class ConsoleLogger extends Logger {
override def warn(msg: String, position: Option[Position] = None): Unit = {
if (verbosity < 0) return
println("WARN: " + f(position) + msg)
this.println("WARN: " + f(position) + msg)
printErrorContext(position)
flushOutput()
if (fatalWarnings) {
@ -81,14 +87,14 @@ class ConsoleLogger extends Logger {
override def error(msg: String, position: Option[Position] = None): Unit = {
hasErrors = true
println("ERROR: " + f(position) + msg)
this.println("ERROR: " + f(position) + msg)
printErrorContext(position)
flushOutput()
}
override def fatal(msg: String, position: Option[Position] = None): Nothing = {
hasErrors = true
println("FATAL: " + f(position) + msg)
this.println("FATAL: " + f(position) + msg)
printErrorContext(position)
flushOutput()
throw new AssertionError(msg)
@ -96,7 +102,7 @@ class ConsoleLogger extends Logger {
override def fatalQuit(msg: String, position: Option[Position] = None): Nothing = {
hasErrors = true
println("FATAL: " + f(position) + msg)
this.println("FATAL: " + f(position) + msg)
printErrorContext(position)
flushOutput()
System.exit(1)
@ -128,4 +134,10 @@ class ConsoleLogger extends Logger {
file <- sourceLines.get(line.moduleName)
line <- file.lift(line.line - 1)
} yield line
private def getOutputStream: PrintStream = if (this.defaultUseStderr) System.err else System.out
private def print(x: String): Unit = getOutputStream.print(x)
private def println(x: String): Unit = getOutputStream.println(x)
}

View File

@ -0,0 +1,35 @@
package millfork.language
import millfork.error.Logger
import millfork.node.Position
import millfork.assembly.SourceLine
class LanguageServerLogger extends Logger {
// TODO: Unused. Complete stub to send diagnostics to client
override def setFatalWarnings(fatalWarnings: Boolean): Unit = {}
override def info(msg: String, position: Option[Position]): Unit = {}
override def debug(msg: String, position: Option[Position]): Unit = {}
override def trace(msg: String, position: Option[Position]): Unit = {}
override def traceEnabled: Boolean = false
override def debugEnabled: Boolean = false
override def warn(msg: String, position: Option[Position]): Unit = {}
override def error(msg: String, position: Option[Position]): Unit = {}
override def fatal(msg: String, position: Option[Position]): Nothing = ???
override def fatalQuit(msg: String, position: Option[Position]): Nothing = ???
override def assertNoErrors(msg: String): Unit = {}
override def addSource(filename: String, lines: IndexedSeq[String]): Unit = {}
override def getLine(line: SourceLine): Option[String] = None
}

View File

@ -0,0 +1,69 @@
package millfork.language
import org.eclipse.lsp4j.services.LanguageClient
import org.eclipse.lsp4j.jsonrpc.services.JsonNotification
import org.eclipse.lsp4j.MessageType
import org.eclipse.lsp4j.MessageParams
trait MfLanguageClient extends LanguageClient {
/**
* Display message in the editor "status bar", which should be displayed somewhere alongside the buffer.
*
* The status bar should always be visible to the user.
*
* - VS Code: https://code.visualstudio.com/docs/extensionAPI/vscode-api#StatusBarItem
*/
// @JsonNotification("metals/status")
// def metalsStatus(params: MetalsStatusParams): Unit
/**
* Starts a long running task with no estimate for how long it will take to complete.
*
* - request cancellation from the server indicates that the task has completed
* - response with cancel=true indicates the client wishes to cancel the slow task
*/
// @JsonRequest("metals/slowTask")
// def metalsSlowTask(
// params: MetalsSlowTaskParams
// ): CompletableFuture[MetalsSlowTaskResult]
// @JsonNotification("metals/executeClientCommand")
// def metalsExecuteClientCommand(params: ExecuteCommandParams): Unit
final def refreshModel(): Unit = {
// val command = ClientCommands.RefreshModel.id
// val params = new ExecuteCommandParams(command, Nil.asJava)
// metalsExecuteClientCommand(params)
}
/**
* Opens an input box to ask the user for input.
*
* @return the user provided input. The future can be cancelled, meaning
* the input box should be dismissed in the editor.
*/
// @JsonRequest("metals/inputBox")
// def metalsInputBox(
// params: MetalsInputBoxParams
// ): CompletableFuture[MetalsInputBoxResult]
/**
* Opens an menu to ask the user to pick one of the suggested options.
*
* @return the user provided pick. The future can be cancelled, meaning
* the input box should be dismissed in the editor.
*/
// @JsonRequest("metals/quickPick")
// def metalsQuickPick(
// params: MetalsQuickPickParams
// ): CompletableFuture[MetalsQuickPickResult]
final def showMessage(messageType: MessageType, message: String): Unit = {
val params = new MessageParams(messageType, message)
showMessage(params)
}
def shutdown(): Unit = {}
}

View File

@ -0,0 +1,483 @@
package millfork.language
import millfork.CompilationOptions
import millfork.parser.{
MosSourceLoadingQueue,
ZSourceLoadingQueue,
MSourceLoadingQueue,
ParsedProgram
}
import millfork.Context
import millfork.node.{
FunctionDeclarationStatement,
ParameterDeclaration,
Position,
Node
}
import org.eclipse.lsp4j.services.{
LanguageServer,
TextDocumentService,
WorkspaceService
}
import org.eclipse.lsp4j.{
InitializeParams,
InitializeResult,
ServerCapabilities,
Range
}
import org.eclipse.lsp4j.jsonrpc.services.JsonRequest
import java.util.concurrent.CompletableFuture
import org.eclipse.lsp4j.TextDocumentPositionParams
import org.eclipse.lsp4j.Hover
import org.eclipse.lsp4j.jsonrpc.messages.Either
import java.{util => ju}
import scala.collection.mutable
import org.eclipse.lsp4j.MarkedString
import org.eclipse.lsp4j.jsonrpc.services.JsonNotification
import org.eclipse.lsp4j.InitializedParams
import org.eclipse.lsp4j.MarkupContent
import org.eclipse.lsp4j.DefinitionParams
import org.eclipse.lsp4j.Location
import net.liftweb.json._
import net.liftweb.json.Serialization.{read, write}
import org.eclipse.lsp4j.MessageParams
import org.eclipse.lsp4j.MessageType
import org.eclipse.lsp4j.DidOpenTextDocumentParams
import org.eclipse.lsp4j.TextDocumentSyncKind
import org.eclipse.lsp4j.DidChangeTextDocumentParams
import java.nio.file.Path
import java.nio.file.Paths
import org.eclipse.lsp4j.VersionedTextDocumentIdentifier
import millfork.node.Program
import millfork.node.ImportStatement
import org.eclipse.lsp4j.ReferenceParams
import millfork.node.DeclarationStatement
import scala.collection.JavaConverters._
import millfork.CpuFamily
import millfork.parser.ZSourceLoadingQueue
import millfork.parser.MSourceLoadingQueue
class MfLanguageServer(context: Context, options: CompilationOptions) {
var client: Option[MfLanguageClient] = None
val cachedModules: mutable.Map[String, Program] = mutable.Map()
private var cachedProgram: Option[ParsedProgram] = None
private val moduleNames: mutable.Map[String, String] = mutable.Map()
private val modulePaths: mutable.Map[String, Path] = mutable.Map()
@JsonRequest("initialize")
def initialize(
params: InitializeParams
): CompletableFuture[
InitializeResult
] =
CompletableFuture.completedFuture {
val capabilities = new ServerCapabilities()
capabilities.setHoverProvider(true)
capabilities.setDefinitionProvider(true)
capabilities.setTextDocumentSync(TextDocumentSyncKind.Full)
capabilities.setReferencesProvider(true)
new InitializeResult(capabilities)
}
@JsonNotification("initialized")
def initialized(params: InitializedParams): CompletableFuture[Unit] =
CompletableFuture.completedFuture {
populateProgramForPath()
}
// @JsonRequest("getTextDocumentService")
// def getTextDocumentService(): CompletableFuture[TextDocumentService] = {
// val completableFuture = new CompletableFuture[InitializeResult]()
// completableFuture.complete(new TextDocumentService())
// completableFuture
// }
// @JsonRequest("getWorkspaceService")
// def getWorkspaceService(): CompletableFuture[WorkspaceService] = ???
@JsonRequest("exit")
def exit(): CompletableFuture[Unit] = ???
@JsonRequest("shutdown")
def shutdown(): CompletableFuture[Object] = ???
@JsonRequest("textDocument/didOpen")
def textDocumentDidOpen(
params: DidOpenTextDocumentParams
): CompletableFuture[Unit] =
CompletableFuture.completedFuture {
val textDocument = params.getTextDocument()
val pathString = trimDocumentUri(textDocument.getUri())
val documentText = textDocument.getText().split("\n").toSeq
rebuildASTForFile(pathString, documentText)
}
@JsonRequest("textDocument/didChange")
def textDocumentDidChange(
params: DidChangeTextDocumentParams
): CompletableFuture[Unit] =
CompletableFuture.completedFuture {
val pathString = trimDocumentUri(params.getTextDocument().getUri())
val documentText =
params.getContentChanges().get(0).getText().split("\n").toSeq
rebuildASTForFile(pathString, documentText)
}
def rebuildASTForFile(pathString: String, text: Seq[String]) = {
logEvent(TelemetryEvent("Rebuilding AST for module at path", pathString))
val path = Paths.get(pathString)
val moduleName = queue.extractName(pathString)
logEvent(
TelemetryEvent(
"Path",
Map("path" -> path.toString(), "module" -> moduleName)
)
)
val newProgram = queue.parseModuleWithLines(
moduleName,
path,
text,
context.includePath,
Left(None),
Nil
)
if (newProgram.isDefined) {
cachedModules.put(moduleName, newProgram.get)
moduleNames.put(pathString, moduleName)
modulePaths.put(moduleName, Paths.get(pathString))
logEvent(
TelemetryEvent(
"Finished rebuilding AST for module at path",
pathString
)
)
} else {
logEvent(
TelemetryEvent("Failed to rebuild AST for module at path", pathString)
)
}
}
@JsonRequest("textDocument/definition")
def textDocumentDefinition(
params: DefinitionParams
): CompletableFuture[Location] =
CompletableFuture.completedFuture {
val activePosition = params.getPosition()
val statement = findExpressionAtPosition(
trimDocumentUri(params.getTextDocument().getUri()),
Position(
"",
activePosition.getLine() + 1,
activePosition.getCharacter() + 2,
0
)
)
if (statement.isDefined) {
val (module, declaration) = statement.get
locationForExpression(declaration, module)
} else null
}
@JsonRequest("textDocument/references")
def textDocumentReferences(
params: ReferenceParams
): CompletableFuture[ju.List[Location]] =
CompletableFuture.completedFuture {
val activePosition = params.getPosition()
val statement = findExpressionAtPosition(
trimDocumentUri(params.getTextDocument().getUri()),
Position(
"",
activePosition.getLine() + 1,
activePosition.getCharacter() + 2,
0
)
)
if (statement.isDefined) {
val (declarationModule, declarationContent) = statement.get
logEvent(
TelemetryEvent("Attempting to find references")
)
if (
declarationContent
.isInstanceOf[DeclarationStatement] || declarationContent
.isInstanceOf[ParameterDeclaration]
) {
val matchingExpressions =
// Only include declaration if params specify it
(if (params.getContext().isIncludeDeclaration())
List((declarationModule, declarationContent))
else List()) ++ NodeFinder
.matchingExpressionsForDeclaration(
cachedModules.toStream,
declarationContent
)
logEvent(
TelemetryEvent("Prepping references", matchingExpressions)
)
matchingExpressions
.sortBy {
case (_, expression) =>
expression.position match {
case Some(value) => value.line
case None => 0
}
}
.map {
case (module, expression) => {
try {
locationForExpression(expression, module)
} catch {
case _: Throwable => null
}
}
}
.filter(e => e != null)
.asJava
} else {
null
}
} else null
}
@JsonRequest("textDocument/hover")
def textDocumentHover(
params: TextDocumentPositionParams
): CompletableFuture[Hover] =
CompletableFuture.completedFuture {
val hoverPosition = params.getPosition()
val statement = findExpressionAtPosition(
trimDocumentUri(params.getTextDocument().getUri()),
Position(
"",
// Millfork positions start at 1,2, rather than 0,0, so add to each coord
hoverPosition.getLine() + 1,
hoverPosition.getCharacter() + 2,
0
)
)
if (statement.isDefined) {
val (_, declarationContent) = statement.get
val formatting = NodeFormatter.symbol(declarationContent)
val docstring = NodeFormatter.docstring(declarationContent)
if (formatting.isDefined)
new Hover(
new MarkupContent(
"markdown",
NodeFormatter.hover(
formatting.get,
docstring.getOrElse("")
)
)
)
else null
} else null
}
/**
* Builds the AST for the entire program, based on the configured "inputFileNames"
*/
private def populateProgramForPath() = {
logEvent(
TelemetryEvent("Building program AST")
)
var program = queue.run()
logEvent(
TelemetryEvent("Finished building AST")
)
cachedProgram = Some(program)
program.parsedModules.foreach {
case (moduleName, program) =>
cachedModules.put(moduleName, program)
}
program.modulePaths.foreach {
case (moduleName, path) => modulePaths.put(moduleName, path)
}
}
private def moduleNameForPath(documentPath: String) =
moduleNames.get(documentPath).getOrElse {
throw new Exception("Cannot find module at " + documentPath)
}
private def findExpressionAtPosition(
documentPath: String,
position: Position
): Option[(String, Node)] = {
val moduleName = moduleNameForPath(documentPath)
val currentModuleDeclarations = cachedModules.get(moduleName)
if (currentModuleDeclarations.isEmpty) {
return None
}
val (node, enclosingDeclarations) = NodeFinder.findNodeAtPosition(
currentModuleDeclarations.get,
position
)
if (node.isDefined) {
logEvent(TelemetryEvent("Found node at position", node))
// Build ordered scopes to search through
// First, our current enclosing scope, then the current module (which contains the current scope), then all other modules
val orderedScopes = List(
(moduleName, enclosingDeclarations.get),
(moduleName, currentModuleDeclarations.get.declarations)
) ++ cachedModules.toList
.filter {
case (cachedModuleName, program) => cachedModuleName != moduleName
}
.map {
case (cachedModuleName, program) =>
(cachedModuleName, program.declarations)
}
val usage =
NodeFinder.findDeclarationForUsage(orderedScopes, node.get)
if (usage.isDefined) {
logEvent(TelemetryEvent("Found original declaration", usage))
usage
} else Some((moduleName, node.get))
} else {
logEvent(TelemetryEvent("Cannot find node for position", position))
None
}
}
/**
* Builds highlighted `Location` of a declaration or usage
*/
private def locationForExpression(
expression: Node,
module: String
): Location = {
val name = NodeFinder.extractNodeName(expression)
val position = expression.position.get
val modulePath = modulePaths.getOrElse(
module, {
logEvent(
TelemetryEvent(
"Could not find path for module",
module
)
)
null
}
)
if (expression.isInstanceOf[ImportStatement]) {
// ImportStatement declaration is the entire "file". Set position to 1,1
val importPosition = Position(module, 1, 1, 0)
return new Location(
modulePath.toUri().toString(),
new Range(
mfPositionToLSP4j(importPosition),
mfPositionToLSP4j(importPosition)
)
)
}
val endPosition = if (name.isDefined) {
Position(
module,
position.line,
position.column + name.get.length,
0
)
} else position
new Location(
modulePath.toUri().toString(),
new Range(
mfPositionToLSP4j(position),
mfPositionToLSP4j(endPosition)
)
)
}
private def queue() =
CpuFamily.forType(options.platform.cpu) match {
case CpuFamily.M6502 =>
new MosSourceLoadingQueue(
initialFilenames = context.inputFileNames,
includePath = context.includePath,
options = options
)
case CpuFamily.I80 | CpuFamily.I86 =>
new ZSourceLoadingQueue(
initialFilenames = context.inputFileNames,
includePath = context.includePath,
options = options
)
case CpuFamily.M6809 =>
new MSourceLoadingQueue(
initialFilenames = context.inputFileNames,
includePath = context.includePath,
options = options
)
}
private def mfPositionToLSP4j(
position: Position
): org.eclipse.lsp4j.Position =
new org.eclipse.lsp4j.Position(
position.line - 1,
// If subtracting 1 would be < 0, set to 0
if (position.column < 1) 0 else position.column - 1
)
private def logEvent(event: TelemetryEvent): Unit = {
val languageClient = client.getOrElse {
// Language client not registered
return
}
implicit val formats = Serialization.formats(NoTypeHints)
val serializedEvent = write(event)
languageClient.logMessage(
new MessageParams(MessageType.Log, serializedEvent)
)
}
private def trimDocumentUri(uri: String): String =
uri
.replaceFirst("file:(//)?", "")
// Trim Windows path oddities provided by VSCode (may not be for all LSP clients)
.replaceFirst("%3A", ":")
.replaceFirst("/([A-Za-z]):", "$1:")
}
case class TelemetryEvent(message: String, data: Any = None)

View File

@ -0,0 +1,406 @@
package millfork.language
import millfork.node.{
DeclarationStatement,
Expression,
FunctionCallExpression,
FunctionDeclarationStatement,
Node,
Program,
Position,
VariableDeclarationStatement,
VariableExpression
}
import millfork.parser.ParsedProgram
import scala.collection.mutable
import millfork.node.ExpressionStatement
import millfork.node.ImportStatement
import millfork.node.ParameterDeclaration
import millfork.env.ByConstant
import millfork.env.ByReference
import millfork.env.ByVariable
import millfork.node.ArrayDeclarationStatement
import millfork.node.AliasDefinitionStatement
import millfork.node.IndexedExpression
import millfork.node.Statement
import millfork.node.SumExpression
import millfork.env.ByLazilyEvaluableExpressionVariable
import millfork.node.EnumDefinitionStatement
import millfork.node.LabelStatement
import millfork.node.StructDefinitionStatement
import millfork.node.TypeDefinitionStatement
import millfork.node.UnionDefinitionStatement
object NodeFinder {
/**
* Finds the declaration matching the provided node
*
* @param orderedScopes A list, ordered by decreasing scope (function, local module, global module),
* of tuples containing the module name and all declarations contained wherein
* @param node The node to find the source declaration for
* @return A tuple containing the module name and "declaration" (could be `ParameterDeclaration`, hence
* the `Node` type); `None` otherwise
*/
def findDeclarationForUsage(
orderedScopes: List[(String, List[DeclarationStatement])],
node: Node
): Option[(String, Node)] = {
node match {
case importStatement: ImportStatement =>
Some((importStatement.filename, importStatement))
case expression: Expression => {
for ((moduleName, scopedDeclarations) <- orderedScopes) {
val declaration =
matchingDeclarationForExpression(
expression,
scopedDeclarations
)
if (declaration.isDefined) {
return Some((moduleName, declaration.get))
}
}
return None
}
case default => None
}
}
/**
* Searches for the declaration matching the type and name of the provided expression
*
* @param expression The expression to find the root declaration for
* @param declarations The declarations to check
* @return The matching declaration if found; `None` otherwise
*/
private def matchingDeclarationForExpression(
expression: Expression,
declarations: List[DeclarationStatement]
): Option[Node] =
expression match {
case FunctionCallExpression(name, expressions) =>
declarations
.filter(d => d.isInstanceOf[FunctionDeclarationStatement])
.find(d => d.name == name)
case VariableExpression(name) =>
matchVariableExpressionName(name, declarations)
case IndexedExpression(name, index) =>
matchVariableExpressionName(name, declarations)
case default => None
}
/**
* Searches for the declaration matching a variable name
*
* @param name The name of the variable
* @param declarations The declarations to check
* @return The matching declaration if found; `None` otherwise
*/
private def matchVariableExpressionName(
name: String,
declarations: List[DeclarationStatement]
) =
declarations
.flatMap(d =>
d match {
// Extract nested declarations (and `ParameterDeclaration`s, which do not extend `DeclarationStatement`)
// from functions
case functionDeclaration: FunctionDeclarationStatement =>
recursivelyFlatten(functionDeclaration)
.filter(e =>
e.isInstanceOf[DeclarationStatement] || e
.isInstanceOf[ParameterDeclaration]
)
case default => List(default)
}
)
.find(d =>
d match {
case variableDeclaration: VariableDeclarationStatement =>
variableDeclaration.name == name
case ParameterDeclaration(typ, assemblyParamPassingConvention) =>
assemblyParamPassingConvention match {
case ByConstant(pName) => pName == name
case ByReference(pName) => pName == name
case ByVariable(pName) => pName == name
case default => false
}
case arrayDeclaration: ArrayDeclarationStatement =>
arrayDeclaration.name == name
case AliasDefinitionStatement(aName, target, important) =>
aName == name
case default => false
}
)
/**
* Finds all expressions referencing a declaration
*
* @param parsedModules All program modules
* @param declaration The declaration to find all references for
* @return A list of tuples, containing the module name and the corresponding expression
*/
def matchingExpressionsForDeclaration(
parsedModules: Stream[(String, Program)],
declaration: Node
): List[(String, Node)] = {
parsedModules.toStream.flatMap {
case (module, program) => {
val allDeclarations =
program.declarations
.flatMap(d => d.getAllExpressions)
.flatMap(flattenNestedExpressions)
declaration match {
case f: FunctionDeclarationStatement =>
allDeclarations
.filter(d => d.isInstanceOf[FunctionCallExpression])
.map(d => d.asInstanceOf[FunctionCallExpression])
.filter(d => d.functionName == f.name)
.map(d => (module, d))
case v: VariableDeclarationStatement =>
allDeclarations
.filter(d => d.isInstanceOf[VariableExpression])
.map(d => d.asInstanceOf[VariableExpression])
.filter(d => d.name == v.name)
.map(d => (module, d))
case a: ArrayDeclarationStatement =>
allDeclarations
.filter(d => d.isInstanceOf[IndexedExpression])
.map(d => d.asInstanceOf[IndexedExpression])
.filter(d => d.name == a.name)
.map(d => (module, d))
case p: ParameterDeclaration => {
val pName = p.assemblyParamPassingConvention match {
case ByConstant(name) => Some(name)
case ByReference(name) => Some(name)
case ByVariable(name) => Some(name)
case ByLazilyEvaluableExpressionVariable(name) =>
Some(name)
case _ => None
}
if (pName.isDefined) {
allDeclarations
.filter(d => extractNodeName(d) == pName)
.map(d => (module, d))
} else List()
}
case default => List()
}
}
}.toList
}
/**
* Finds the node and enclosing declaration scope for a given position
*
* @param program The program containing the position
* @param position The position of the node to find
* @return A tuple containing the found node, and a list of enclosing declaration scopes
*/
def findNodeAtPosition(
program: Program,
position: Position
): (Option[Node], Option[List[DeclarationStatement]]) = {
val line = position.line
val column = position.column
val declarations =
findEnclosingDeclarationsAtLine(program.declarations, line)
if (declarations.isEmpty) {
return (None, None)
}
if (lineOrNegOne(declarations.get.head.position) != line) {
// Declaration is a function or similar wrapper
// Find inner expressions
if (declarations.get.length > 1) {
throw new Exception("Unexpected number of declarations")
}
return (
findNodeAtColumn(
recursivelyFlatten(declarations.get.head),
line,
column
),
declarations
)
}
// All declarations are current line, find matching node for column
(
findNodeAtColumn(
declarations.get
.flatMap(recursivelyFlatten),
line,
column
),
declarations
)
}
/**
* Finds the narrowest top level declaration scope for the given line (typically enclosing function)
*
* @param declarations All program declarations in the file
* @param line The line to search for
*/
private def findEnclosingDeclarationsAtLine(
declarations: List[DeclarationStatement],
line: Int
): Option[List[DeclarationStatement]] = {
var lastDeclarations: Option[List[DeclarationStatement]] =
if (declarations.length > 0) Some(declarations.take(1)) else None
for ((nextDeclaration, i) <- declarations.view.zipWithIndex) {
val nextLine = lineOrNegOne(nextDeclaration.position)
if (nextLine == line) {
// Declaration is on this line
// Check for additional declarations on this line
val newDeclarations = mutable.MutableList(nextDeclaration)
for (declarationIndex <- i to declarations.length - 1) {
val checkDeclaration = declarations(declarationIndex)
if (
checkDeclaration.position.isDefined && checkDeclaration.position.get.line == line
) {
newDeclarations += checkDeclaration
} else {
// Line doesn't match, done with this line
return Some(newDeclarations.toList)
}
}
return Some(newDeclarations.toList)
} else if (nextLine < line) {
// Closer to desired line
lastDeclarations = Some(List(nextDeclaration))
}
}
lastDeclarations
}
/**
* Searches for closest column index less than the selected column on a given line
*/
private def findNodeAtColumn(
nodes: List[Node],
line: Int,
column: Int
): Option[Node] = {
var lastNode: Option[Node] = None
var lastPosition: Option[Position] = None
// Only consider nodes on this line (if we're opening a declaration, it could span multiple lines)
var flattenedNodes = nodes.flatMap(flattenNestedExpressions)
for (nextNode <- flattenedNodes)
if (lineOrNegOne(nextNode.position) == line) {
if (nextNode.position.isEmpty) {
throw new Error("Missing position for node " + nextNode.toString())
}
if (
colOrNegOne(nextNode.position) < column && colOrNegOne(
nextNode.position
// Allow equality, because later nodes are of higher specificity
) >= colOrNegOne(lastPosition)
) {
lastNode = Some(nextNode)
lastPosition = nextNode.position
}
}
lastNode
}
private def lineOrNegOne(position: Option[Position]): Int =
position match {
case Some(pos) => pos.line
case None => -1
}
private def colOrNegOne(position: Option[Position]): Int =
position match {
case Some(pos) => pos.column
case None => -1
}
/**
* Recursively flattens a node tree into a tree containing the node itself and all of its "owned" nodes.
* In particular, function declarations don't call `getAllExpressions`, instead opting to manually pull out
* `params` and `statements` to allow for proper access
*
* @param node The root of the tree to process
* @return A flatten list of all nodes
*/
private def recursivelyFlatten(node: Node): List[Node] =
node match {
case functionDeclaration: FunctionDeclarationStatement =>
List(
functionDeclaration
) ++ functionDeclaration.params ++ functionDeclaration.statements
.getOrElse(List())
.flatMap(recursivelyFlatten)
case statement: Statement =>
List(statement) ++ statement.getAllExpressions.flatMap(
recursivelyFlatten
)
case expression: Expression => List(expression)
case default => List(default)
}
/**
* Returns all of the expressions contained within an expression, including itself
*/
private def flattenNestedExpressions(node: Node): List[Node] =
node match {
case statement: ExpressionStatement => {
val innerExpressions = flattenNestedExpressions(
statement.expression
)
List(statement.expression) ++ innerExpressions
}
case functionExpression: FunctionCallExpression =>
List(functionExpression) ++
functionExpression.expressions
.flatMap(flattenNestedExpressions)
case indexExpression: IndexedExpression =>
List(indexExpression) ++
flattenNestedExpressions(indexExpression.index)
case sumExpression: SumExpression =>
List(sumExpression) ++ sumExpression.expressions.flatMap(e =>
flattenNestedExpressions(e._2)
)
case default => List(default)
}
/**
* Returns the name of the node, if it exists
*/
def extractNodeName(node: Node): Option[String] =
node match {
case a: AliasDefinitionStatement => Some(a.name)
case a: ArrayDeclarationStatement => Some(a.name)
case e: EnumDefinitionStatement => Some(e.name)
case f: FunctionCallExpression => Some(f.functionName)
case f: FunctionDeclarationStatement => Some(f.name)
case l: LabelStatement => Some(l.name)
case s: StructDefinitionStatement => Some(s.name)
case t: TypeDefinitionStatement => Some(t.name)
case u: UnionDefinitionStatement => Some(u.name)
case v: VariableDeclarationStatement => Some(v.name)
case v: VariableExpression => Some(v.name)
case _ => None
}
}

View File

@ -0,0 +1,293 @@
package millfork.language
import millfork.node.Node
import millfork.node.DeclarationStatement
import millfork.node.FunctionDeclarationStatement
import millfork.node.ParameterDeclaration
import millfork.node.Expression
import millfork.node.VariableExpression
import millfork.node.VariableDeclarationStatement
import millfork.node.LiteralExpression
import millfork.env.ParamPassingConvention
import millfork.env.ByConstant
import millfork.env.ByVariable
import millfork.env.ByReference
import millfork.node.ImportStatement
import millfork.env.ByLazilyEvaluableExpressionVariable
import millfork.env.ByMosRegister
import millfork.node.MosRegister
import millfork.env.ByZRegister
import millfork.node.ZRegister
import millfork.env.ByM6809Register
import millfork.node.M6809Register
import millfork.node.ArrayDeclarationStatement
import millfork.output.MemoryAlignment
import millfork.output.NoAlignment
import millfork.output.DivisibleAlignment
import millfork.output.WithinPageAlignment
import millfork.node.AliasDefinitionStatement
import java.util.regex.Pattern
import scala.collection.mutable.ListBuffer
object NodeFormatter {
val docstringAsteriskPattern =
Pattern.compile("^\\s*\\*? *", Pattern.MULTILINE)
val docstringParamPattern =
Pattern.compile("@param (\\w+) +(.*)$", Pattern.MULTILINE)
val docstringReturnsPattern =
Pattern.compile("@returns +(.*)$", Pattern.MULTILINE)
// TODO: Remove Option
def symbol(node: Node): Option[String] =
node match {
case statement: DeclarationStatement =>
statement match {
case functionStatement: FunctionDeclarationStatement => {
val builder = new StringBuilder()
if (functionStatement.constPure) {
builder.append("const ")
}
if (functionStatement.interrupt) {
builder.append("interrupt ")
}
if (functionStatement.kernalInterrupt) {
builder.append("kernal_interrupt ")
}
if (functionStatement.assembly) {
builder.append("asm ")
}
// Cannot have both "macro" and "inline"
if (functionStatement.isMacro) {
builder.append("macro ")
} else if (
functionStatement.inlinable.isDefined && functionStatement.inlinable.get
) {
builder.append("inline ")
}
builder.append(
s"""${functionStatement.resultType} ${functionStatement.name}(${functionStatement.params
.map(symbol)
.filter(n => n.isDefined)
.map(n => n.get)
.mkString(", ")})"""
)
Some(builder.toString())
}
case variableStatement: VariableDeclarationStatement => {
val builder = new StringBuilder()
if (variableStatement.constant) {
builder.append("const ")
}
if (variableStatement.volatile) {
builder.append("volatile ")
}
builder.append(
s"""${variableStatement.typ} ${variableStatement.name}"""
)
if (variableStatement.initialValue.isDefined) {
val formattedInitialValue = symbol(
variableStatement.initialValue.get
)
if (formattedInitialValue.isDefined) {
builder.append(s""" = ${formattedInitialValue.get}""")
}
}
Some(builder.toString())
}
case importStatement: ImportStatement =>
Some(s"""import ${importStatement.filename}""")
case arrayStatement: ArrayDeclarationStatement => {
val builder = new StringBuilder()
if (arrayStatement.const) {
builder.append("const ")
}
builder.append(
s"""array(${arrayStatement.elementType}) ${arrayStatement.name}"""
)
if (arrayStatement.length.isDefined) {
val formattedLength = symbol(arrayStatement.length.get)
if (formattedLength.isDefined) {
builder.append(s""" [${formattedLength.get}]""")
}
}
if (arrayStatement.alignment.isDefined) {
val formattedAlignment = symbol(
arrayStatement.alignment.get
)
if (formattedAlignment.isDefined) {
builder.append(s""" align(${formattedAlignment.get})""")
}
}
if (arrayStatement.address.isDefined) {
val formattedAddress = symbol(arrayStatement.address.get)
if (formattedAddress.isDefined) {
builder.append(s""" @ ${formattedAddress.get}""")
}
}
if (arrayStatement.elements.isDefined) {
val formattedInitialValue = arrayStatement.elements.get
.getAllExpressions(false)
.map(e => symbol(e))
.filter(e => e.isDefined)
.map(e => e.get)
.mkString(", ")
builder.append(s""" = [${formattedInitialValue}]""")
}
Some(builder.toString())
}
case AliasDefinitionStatement(name, target, important) => {
val builder = new StringBuilder()
builder.append(s"""alias ${name} = ${target}""")
if (important) {
builder.append("!")
}
Some(builder.toString())
}
// TODO: Finish
case default => None
}
case ParameterDeclaration(typ, assemblyParamPassingConvention) =>
Some(s"""${typ} ${symbol(assemblyParamPassingConvention)}""")
case expression: Expression =>
expression match {
case LiteralExpression(value, _) => Some(s"""${value}""")
case VariableExpression(name) => Some(s"""${name}""")
case default => None
}
case default => None
}
def symbol(paramConvention: ParamPassingConvention): String =
paramConvention match {
case ByConstant(name) => name
case ByVariable(name) => name
case ByReference(name) => name
case ByLazilyEvaluableExpressionVariable(name) => name
case ByMosRegister(register) =>
MosRegister.toString(register).getOrElse("")
case ByZRegister(register) => ZRegister.toString(register).getOrElse("")
case ByM6809Register(register) =>
M6809Register.toString(register).getOrElse("")
}
def symbol(alignment: MemoryAlignment): Option[String] =
alignment match {
case NoAlignment => None
// TOOD: Improve
case DivisibleAlignment(divisor) => Some(s"""${divisor}""")
case WithinPageAlignment => Some("Within page")
}
def docstring(node: Node): Option[String] = {
val docComment = node match {
case f: FunctionDeclarationStatement => f.docComment
case v: VariableDeclarationStatement => v.docComment
case a: ArrayDeclarationStatement => a.docComment
case _ => None
}
if (docComment.isEmpty) {
return None
}
val baseString = docComment.get.text
var strippedString = docstringAsteriskPattern
.matcher(baseString.stripSuffix("*/"))
.replaceAll("")
val matchGroups = new ListBuffer[(String, String, Range)]()
val paramMatcher = docstringParamPattern.matcher(strippedString)
while (paramMatcher.find()) {
matchGroups += (
(
paramMatcher.group(1),
paramMatcher
.group(2),
Range(paramMatcher.start(), paramMatcher.end())
)
)
}
val builder = new StringBuilder(strippedString)
for (param <- matchGroups.reverse) {
val (paramName, description, range) = param
builder.replace(
range.start,
range.end,
s"\n_@param_ `${paramName}` \u2014 ${description.trim()}\n"
)
}
val returnMatch = docstringReturnsPattern.matcher(builder.toString())
if (returnMatch.find()) {
builder.replace(
returnMatch.start(),
returnMatch.end(),
s"\n_@returns_ \u2014 ${returnMatch.group(1).trim()}"
)
}
return Some(builder.toString())
}
/**
* Render the textDocument/hover result into markdown.
*
* @param symbolSignature The signature of the symbol over the cursor, for example
* "def map[B](fn: A => B): Option[B]"
* @param docstring The Markdown documentation string for the symbol.
*/
def hover(
symbolSignature: String,
docstring: String
): String = {
val markdown = new StringBuilder()
if (symbolSignature.nonEmpty) {
markdown
.append("```mfk\n")
.append(symbolSignature)
.append("\n```")
}
if (docstring.nonEmpty)
markdown
.append("\n---\n")
.append(docstring)
.append("\n")
markdown.toString()
}
}

View File

@ -253,12 +253,46 @@ object M6809NiceFunctionProperty {
object MosRegister extends Enumeration {
val A, X, Y, AX, AY, YA, XA, XY, YX, AW = Value
private val registerStringToValue = Map[String, MosRegister.Value](
"xy" -> MosRegister.XY,
"yx" -> MosRegister.YX,
"ax" -> MosRegister.AX,
"ay" -> MosRegister.AY,
"xa" -> MosRegister.XA,
"ya" -> MosRegister.YA,
"a" -> MosRegister.A,
"x" -> MosRegister.X,
"y" -> MosRegister.Y,
)
private val registerValueToString = registerStringToValue.map { case (key, value) => (value, key)}.toMap
def fromString(name: String): Option[MosRegister.Value] = registerStringToValue.get(name)
def toString(value: MosRegister.Value): Option[String] = registerValueToString.get(value)
}
object ZRegister extends Enumeration {
val A, B, C, D, E, H, L, AF, BC, HL, DE, SP, IXH, IXL, IYH, IYL, IX, IY, R, I, MEM_HL, MEM_BC, MEM_DE, MEM_IX_D, MEM_IY_D, MEM_ABS_8, MEM_ABS_16, IMM_8, IMM_16 = Value
val registerStringToValue = Map[String, ZRegister.Value](
"hl" -> ZRegister.HL,
"bc" -> ZRegister.BC,
"de" -> ZRegister.DE,
"a" -> ZRegister.A,
"b" -> ZRegister.B,
"c" -> ZRegister.C,
"d" -> ZRegister.D,
"e" -> ZRegister.E,
"h" -> ZRegister.H,
"l" -> ZRegister.L,
)
private val registerValueToString = registerStringToValue.map { case (key, value) => (value, key)}.toMap
def fromString(name: String): Option[ZRegister.Value] = registerStringToValue.get(name)
def toString(value: ZRegister.Value): Option[String] = registerValueToString.get(value)
def registerSize(reg: Value): Int = reg match {
case AF | BC | DE | HL | IX | IY | IMM_16 => 2
case A | B | C | D | E | H | L | IXH | IXL | IYH | IYL | R | I | IMM_8 => 1
@ -280,6 +314,24 @@ object ZRegister extends Enumeration {
object M6809Register extends Enumeration {
val A, B, D, DP, X, Y, U, S, PC, CC = Value
val registerStringToValue = Map[String, M6809Register.Value](
"x" -> M6809Register.X,
"y" -> M6809Register.Y,
"s" -> M6809Register.S,
"u" -> M6809Register.U,
"a" -> M6809Register.A,
"b" -> M6809Register.B,
"d" -> M6809Register.D,
"dp" -> M6809Register.DP,
"pc" -> M6809Register.PC,
"cc" -> M6809Register.CC,
)
private val registerValueToString = registerStringToValue.map { case (key, value) => (value, key)}.toMap
def fromString(name: String): Option[M6809Register.Value] = registerStringToValue.get(name)
def toString(value: M6809Register.Value): Option[String] = registerValueToString.get(value)
def registerSize(reg: Value): Int = reg match {
case D | X | Y | U | S | PC => 2
case A | B | DP | CC => 1
@ -445,6 +497,18 @@ sealed trait Statement extends Node {
sealed trait DeclarationStatement extends Statement {
def name: String
var docComment: Option[DocComment] = None
}
object DeclarationStatement {
implicit class DeclarationStatementOps[D<:DeclarationStatement](val declaration: D) extends AnyVal {
def docComment(comment: Option[DocComment]): D = {
if (comment.isDefined) {
declaration.docComment = comment
}
declaration
}
}
}
sealed trait BankedDeclarationStatement extends DeclarationStatement {
@ -833,4 +897,6 @@ object MosAssemblyStatement {
def implied(opcode: Opcode.Value, elidability: Elidability.Value) = MosAssemblyStatement(opcode, AddrMode.Implied, LiteralExpression(0, 1), elidability)
def nonexistent(opcode: Opcode.Value) = MosAssemblyStatement(opcode, AddrMode.DoesNotExist, LiteralExpression(0, 1), elidability = Elidability.Elidable)
}
}
case class DocComment(text: String) extends Node {}

View File

@ -1,7 +1,7 @@
package millfork.parser
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
import java.nio.file.{Files, Path, Paths}
import fastparse.core.Parsed.{Failure, Success}
import millfork.{CompilationFlag, CompilationOptions, Tarjan}
@ -10,11 +10,14 @@ import millfork.node.{AliasDefinitionStatement, DeclarationStatement, ImportStat
import scala.collection.mutable
import scala.collection.convert.ImplicitConversionsToScala._
case class ParsedProgram(compilationOrderProgram: Program, parsedModules: Map[String, Program], modulePaths: Map[String, Path])
abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String],
val includePath: List[String],
val options: CompilationOptions) {
protected val parsedModules: mutable.Map[String, Program] = mutable.Map[String, Program]()
protected val modulePaths: mutable.Map[String, Path] = mutable.Map[String, Path]()
protected val moduleDependecies: mutable.Set[(String, String)] = mutable.Set[(String, String)]()
protected val moduleQueue: mutable.Queue[() => Unit] = mutable.Queue[() => Unit]()
val extension: String = ".mfk"
@ -41,7 +44,12 @@ abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String],
encodingConversionAliases
}
def run(): Program = {
/**
* Tokenizes and parses the configured source file and modules
*
* @return A ParsedProgram containing an ordered set of statements in order of compilation dependencies, and each individual parsed module
*/
def run(): ParsedProgram = {
for {
initialFilename <- initialFilenames
startingModule <- options.platform.startingModules
@ -76,7 +84,8 @@ abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String],
options.log.assertNoErrors("Parse failed")
val compilationOrder = Tarjan.sort(parsedModules.keys, moduleDependecies)
options.log.debug("Compilation order: " + compilationOrder.mkString(", "))
compilationOrder.filter(parsedModules.contains).map(parsedModules).reduce(_ + _).applyImportantAliases
ParsedProgram(compilationOrder.filter(parsedModules.contains).map(parsedModules).reduce(_ + _).applyImportantAliases, parsedModules.toMap, modulePaths.toMap)
}
def lookupModuleFile(includePath: List[String], moduleName: String, position: Option[Position]): String = {
@ -98,13 +107,24 @@ abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String],
if (templateParams.isEmpty) moduleNameBase else moduleNameBase + templateParams.mkString("<", ",", ">")
}
/**
* Finds module path and builds module AST, adding to `parsedModules`
*/
def parseModule(moduleName: String, includePath: List[String], why: Either[Option[Position], String], templateParams: List[String]): Unit = {
val filename: String = why.fold(p => lookupModuleFile(includePath, moduleName, p), s => s)
options.log.debug(s"Parsing $filename")
val path = Paths.get(filename)
modulePaths.put(moduleName, path)
parseModuleWithLines(moduleName, path, Files.readAllLines(path, StandardCharsets.UTF_8).toIndexedSeq, includePath, why, templateParams)
}
def parseModuleWithLines(moduleName: String, path: Path, lines: Seq[String], includePath: List[String], why: Either[Option[Position], String], templateParams: List[String]): Option[Program] = {
val parentDir = path.toFile.getAbsoluteFile.getParent
val shortFileName = path.getFileName.toString
val PreprocessingResult(src, featureConstants, pragmas) = Preprocessor(options, shortFileName, Files.readAllLines(path, StandardCharsets.UTF_8).toIndexedSeq, templateParams)
val PreprocessingResult(src, featureConstants, pragmas) = Preprocessor(options, shortFileName, lines, templateParams)
for (pragma <- pragmas) {
if (!supportedPragmas(pragma._1) && options.flag(CompilationFlag.BuggyCodeWarning)) {
options.log.warn(s"Unsupported pragma: #pragma ${pragma._1}", Some(Position(moduleName, pragma._2, 1, 0)))
@ -126,16 +146,19 @@ abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String],
case _ => ()
}
}
Some(prog)
case f@Failure(a, b, d) =>
options.log.error(s"Failed to parse the module `$moduleName` in $filename", Some(parser.indexToPosition(f.index, parser.lastLabel)))
options.log.error(s"Failed to parse the module `$moduleName` in ${path.toString()}", Some(parser.indexToPosition(f.index, parser.lastLabel)))
if (parser.lastLabel != "") {
options.log.error(s"Syntax error: ${parser.lastLabel} expected", Some(parser.lastPosition))
} else {
options.log.error("Syntax error", Some(parser.lastPosition))
}
None
}
}
// TODO: Separate from Queue
def extractName(i: String): String = {
val noExt = i.stripSuffix(extension)
val lastSlash = noExt.lastIndexOf('/') max noExt.lastIndexOf('\\')

View File

@ -45,21 +45,10 @@ case class M6809Parser(filename: String,
val asmOpcode: P[(MOpcode.Value, Option[MAddrMode])] =
(position() ~ (letter.rep ~ ("2" | "3").?).! ).map { case (p, o) => MOpcode.lookup(o, Some(p), log) }
private def mapRegister(p: (Position, String)): M6809Register.Value = p._2.toLowerCase(Locale.ROOT) match {
case "x" => M6809Register.X
case "y" => M6809Register.Y
case "s" => M6809Register.S
case "u" => M6809Register.U
case "a" => M6809Register.A
case "b" => M6809Register.B
case "d" => M6809Register.D
case "dp" => M6809Register.DP
case "pc" => M6809Register.PC
case "cc" => M6809Register.CC
case _ =>
private def mapRegister(p: (Position, String)): M6809Register.Value = M6809Register.fromString(p._2.toLowerCase(Locale.ROOT)).getOrElse({
log.error("Invalid register " + p._2, Some(p._1))
M6809Register.D
}
})
// only used for TFR, EXG, PSHS, PULS, PSHU, PULU, so it is allowed to accept any register name in order to let parsing continue:
val anyRegister: P[M6809Register.Value] = P(position() ~ identifier.!).map(mapRegister)

View File

@ -83,11 +83,20 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
val comment: P[Unit] = P("//" ~ CharsWhile(c => c != '\n' && c != '\r', min = 0) ~ ("\r\n" | "\r" | "\n"))
val recursiveMultilineCommentContent: P[Unit] = P(CharsWhile(c => c != '*', min = 0) ~ ("*/" | ("*" ~ recursiveMultilineCommentContent)))
val docComment: P[DocComment] = for {
p <- position()
text <- ("/**" ~ recursiveMultilineCommentContent.!)
} yield DocComment(text).pos(p)
val multilineComment: P[Unit] = P("/*" ~ !"*" ~ recursiveMultilineCommentContent)
val semicolon: P[Unit] = P(";" ~ CharsWhileIn("; \t", min = 0) ~ position("line break after a semicolon").map(_ => ()) ~ (comment | "\r\n" | "\r" | "\n").opaque("<line break>"))
val semicolonComment: P[Unit] = P(";" ~ CharsWhile(c => c != '\n' && c != '\r' && c != '{' && c != '}', min = 0) ~ position("line break instead of braces").map(_ => ()) ~ ("\r\n" | "\r" | "\n").opaque("<line break>"))
val AWS: P[Unit] = P((CharIn(" \t\n\r") | semicolon | comment).rep(min = 0)).opaque("<any whitespace>")
val AWS: P[Unit] = P((CharIn(" \t\n\r") | semicolon | comment | multilineComment).rep(min = 0)).opaque("<any whitespace>")
val AWS_asm: P[Unit] = P((CharIn(" \t\n\r") | semicolonComment | comment).rep(min = 0)).opaque("<any whitespace>")
@ -205,9 +214,9 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
case x => x.toString
} | textLiteralAtom.!
val importStatement: P[Seq[ImportStatement]] = ("import" ~ !letterOrDigit ~/ SWS ~/
val importStatement: P[Seq[ImportStatement]] = (position() ~ "import" ~ !letterOrDigit ~/ SWS ~/
identifier.rep(min = 1, sep = "/") ~ HWS ~ ("<" ~/ HWS ~/ quotedAtom.rep(min = 1, sep = HWS ~ "," ~/ HWS) ~/ HWS ~/ ">" ~/ Pass).?).
map{case (name, params) => Seq(ImportStatement(name.mkString("/"), params.getOrElse(Nil).toList))}
map{case (p, name, params) => Seq(ImportStatement(name.mkString("/"), params.getOrElse(Nil).toList).pos(p))}
val optimizationHintsDeclaration: P[Set[String]] =
if (options.flag(CompilationFlag.EnableInternalTestSyntax)) {
@ -235,6 +244,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
}
def variableDefinition(implicitlyGlobal: Boolean): P[Seq[BankedDeclarationStatement]] = for {
docComment <- (docComment ~ EOL).?
p <- position()
bank <- bankDeclaration
flags <- variableFlags ~ HWS
@ -250,7 +260,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
constant = flags("const"),
volatile = flags("volatile"),
register = flags("register"),
initialValue, addr, optimizationHints, alignment).pos(p)
initialValue, addr, optimizationHints, alignment).pos(p).docComment(docComment)
}
}
@ -428,6 +438,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
}
val arrayDefinition: P[Seq[ArrayDeclarationStatement]] = for {
docComment <- (docComment ~ EOL).?
p <- position()
bank <- bankDeclaration
const <- ("const".! ~ HWS).?
@ -443,7 +454,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
} yield {
if (alignment1.isDefined && alignment2.isDefined) log.error(s"Cannot define the alignment multiple times", Some(p))
val alignment = alignment1.orElse(alignment2)
Seq(ArrayDeclarationStatement(name, bank, length, elementType.getOrElse("byte"), addr, const.isDefined, contents, optimizationHints, alignment, options.isBigEndian).pos(p))
Seq(ArrayDeclarationStatement(name, bank, length, elementType.getOrElse("byte"), addr, const.isDefined, contents, optimizationHints, alignment, options.isBigEndian).pos(p).docComment(docComment))
}
def tightMfExpression(allowIntelHex: Boolean, allowTopLevelIndexing: Boolean): P[Expression] = {
@ -695,6 +706,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
} yield Seq(DoWhileStatement(body.toList, Nil, condition))
val functionDefinition: P[Seq[BankedDeclarationStatement]] = for {
docComment <- (docComment ~ AWS).?
p <- position()
bank <- bankDeclaration
flags <- functionFlags ~ HWS
@ -754,7 +766,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
flags("interrupt"),
flags("kernal_interrupt"),
flags("const") && !flags("asm"),
flags("reentrant")).pos(p))
flags("reentrant")).pos(p).docComment(docComment))
}
def validateAsmFunctionBody(p: Position, flags: Set[String], name: String, statements: Option[List[Statement]])

View File

@ -115,19 +115,8 @@ case class MosParser(filename: String, input: String, currentDirectory: String,
val asmStatement: P[ExecutableStatement] = (position("assembly statement") ~ P(asmLabel | asmMacro | arrayContentsForAsm | asmInstruction)).map { case (p, s) => s.pos(p) } // TODO: macros
override val appcRegister: P[ParamPassingConvention] = P(("xy" | "yx" | "ax" | "ay" | "xa" | "ya" | "a" | "x" | "y") ~ !letterOrDigit).!.map {
case "xy" => ByMosRegister(MosRegister.XY)
case "yx" => ByMosRegister(MosRegister.YX)
case "ax" => ByMosRegister(MosRegister.AX)
case "ay" => ByMosRegister(MosRegister.AY)
case "xa" => ByMosRegister(MosRegister.XA)
case "ya" => ByMosRegister(MosRegister.YA)
case "a" => ByMosRegister(MosRegister.A)
case "x" => ByMosRegister(MosRegister.X)
case "y" => ByMosRegister(MosRegister.Y)
case x => log.fatal(s"Unknown assembly parameter passing convention: `$x`")
}
override val appcRegister: P[ParamPassingConvention] = P(("xy" | "yx" | "ax" | "ay" | "xa" | "ya" | "a" | "x" | "y") ~ !letterOrDigit).!
.map(name => ByMosRegister(MosRegister.fromString(name).getOrElse(log.fatal(s"Unknown assembly parameter passing convention: `$name`"))))
def validateAsmFunctionBody(p: Position, flags: Set[String], name: String, statements: Option[List[Statement]]): Unit = {
if (!options.flag(CompilationFlag.BuggyCodeWarning)) return

View File

@ -31,19 +31,8 @@ case class Z80Parser(filename: String,
private val zero = LiteralExpression(0, 1)
override val appcRegister: P[ParamPassingConvention] = (P("hl" | "bc" | "de" | "a" | "b" | "c" | "d" | "e" | "h" | "l").! ~ !letterOrDigit).map {
case "a" => ByZRegister(ZRegister.A)
case "b" => ByZRegister(ZRegister.B)
case "c" => ByZRegister(ZRegister.C)
case "d" => ByZRegister(ZRegister.D)
case "e" => ByZRegister(ZRegister.E)
case "h" => ByZRegister(ZRegister.H)
case "l" => ByZRegister(ZRegister.L)
case "hl" => ByZRegister(ZRegister.HL)
case "bc" => ByZRegister(ZRegister.BC)
case "de" => ByZRegister(ZRegister.DE)
case x => log.fatal(s"Unknown assembly parameter passing convention: `$x`")
}
override val appcRegister: P[ParamPassingConvention] = (P("hl" | "bc" | "de" | "a" | "b" | "c" | "d" | "e" | "h" | "l").! ~ !letterOrDigit)
.map(name => ByZRegister(ZRegister.fromString(name).getOrElse(log.fatal(s"Unknown assembly parameter passing convention: `$name`"))))
override val asmParamDefinition: P[ParameterDeclaration] = for {
p <- position()

View File

@ -0,0 +1,143 @@
package millfork.test.language
import org.scalatest.{AppendedClues, FunSpec, Matchers}
import millfork.test.language.util._
import org.eclipse.lsp4j.DidOpenTextDocumentParams
import org.eclipse.lsp4j.TextDocumentItem
import org.eclipse.lsp4j.HoverParams
import org.eclipse.lsp4j.TextDocumentIdentifier
import org.eclipse.lsp4j.Position
import java.util.regex.Pattern
import scala.collection.mutable
class MfLanguageServerSuite extends FunSpec with Matchers with AppendedClues {
describe("hover") {
it("should find node under cursor, and its root declaration") {
val server = LanguageHelper.createServer
LanguageHelper.openDocument(
server,
"file.mfk",
"""
| byte test
| array(byte) foo[4]
| void main() {
| test = test + 1
| foo[1] = test
| }
"""
)
{
// Select `test` variable usage
val hoverParams = new HoverParams(
new TextDocumentIdentifier("file.mfk"),
new Position(4, 3)
)
val response = server.textDocumentHover(hoverParams)
val hover = response.get
val contents = hover.getContents().getRight()
contents should not equal (null)
contents.getValue() should equal(
LanguageHelper.formatHover("byte test")
)
}
{
// Select `main` function
val hoverParams = new HoverParams(
new TextDocumentIdentifier("file.mfk"),
new Position(3, 3)
)
val response = server.textDocumentHover(hoverParams)
val hover = response.get
val contents = hover.getContents().getRight()
contents should not equal (null)
contents.getValue() should equal(
LanguageHelper.formatHover("void main()")
)
}
{
// Select `foo` array usage
val hoverParams = new HoverParams(
new TextDocumentIdentifier("file.mfk"),
new Position(5, 6)
)
val response = server.textDocumentHover(hoverParams)
val hover = response.get
val contents = hover.getContents().getRight()
contents should not equal (null)
contents.getValue() should equal(
LanguageHelper.formatHover("array(byte) foo [4]")
)
}
}
describe("should always produce value") {
val server = LanguageHelper.createServer
val text = """
| byte test
| array(byte) foo[4]
| void main() {
| test += test
| foo[1] = test
| func()
| }
| byte func() {
| byte i
| byte innerValue
| innerValue = 2
| innerValue += innerValue
| return innerValue
| }
""".stripMargin
LanguageHelper.openDocument(
server,
"file.mfk",
text
)
val lines = text.split("\n")
val pattern = Pattern.compile("(return|byte)")
for ((line, i) <- lines.zipWithIndex) {
val matcher = pattern.matcher(line)
val ignoreRanges = mutable.MutableList[Range]()
while (matcher.find()) {
ignoreRanges += Range(matcher.start(), matcher.end())
}
for ((character, column) <- line.toCharArray().zipWithIndex) {
if (
Character.isLetter(character) &&
// Ignore sections of string matching pattern
ignoreRanges.filter(r => r.contains(column)).length == 0
) {
it(s"""should work on ${i}, ${column} contents "${line}" """) {
val hoverParams = new HoverParams(
new TextDocumentIdentifier("file.mfk"),
new Position(i, column + 2)
)
val response = server.textDocumentHover(hoverParams)
val hover = response.get
hover should not equal (null)
val contents = hover.getContents().getRight()
info(contents.toString())
contents should not equal (null)
}
}
}
}
}
}
}

View File

@ -0,0 +1,356 @@
package millfork.test.language
import org.scalatest.{AppendedClues, FunSpec, Matchers}
import millfork.test.language.util._
import org.eclipse.lsp4j.DidOpenTextDocumentParams
import org.eclipse.lsp4j.TextDocumentItem
import org.eclipse.lsp4j.HoverParams
import org.eclipse.lsp4j.TextDocumentIdentifier
import millfork.language.NodeFinder
import millfork.node.Position
import millfork.node.FunctionDeclarationStatement
import millfork.node.ExpressionStatement
import millfork.node.FunctionCallExpression
import millfork.node.Assignment
import java.util.regex.Pattern
import millfork.node.Program
import millfork.node.IndexedExpression
import millfork.node.SumExpression
class NodeFinderSuite extends FunSpec with Matchers with AppendedClues {
def createProgram(text: String): Program = {
val server = LanguageHelper.createServer
LanguageHelper
.openDocument(
server,
"file.mfk",
text
)
server.cachedModules.get("file").get
}
describe("nodeAtPosition") {
val text = """
|
| byte test
| array(byte) foo[4]
| void main() {
| test += test
| foo[1] = test
| func()
| }
| byte func(byte arg) {
| byte i
| byte innerValue
| innerValue = 2
| innerValue += innerValue
| innerValue += arg
| return innerValue
| }
""".stripMargin
val program = createProgram(text)
def findRangeOfString(
text: String,
textMatch: String,
afterLine: Int = 0
): (Int, Range) = {
val pattern = Pattern.compile(s"(${Pattern.quote(textMatch)})")
val lines = text.split("\n")
for ((line, i) <- lines.zipWithIndex) {
if (i >= afterLine) {
val matcher = pattern.matcher(line)
if (matcher.find()) {
return (i + 1, Range(matcher.start() + 2, matcher.end() + 2))
}
}
}
throw new Error(s"Cound not find pattern ${textMatch}")
}
it("should find root variable declarations") {
val (line, range) = findRangeOfString(text, "test")
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._2
.get(0) should equal(
program.declarations(0)
)
}
}
it("should find root array declarations") {
val (line, range) = findRangeOfString(text, "foo[4]")
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._2
.get(0) should equal(
program.declarations(1)
)
}
}
it("should find function declarations") {
val (line, range) = findRangeOfString(text, "main()")
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._2
.get(0) should equal(
program.declarations(2)
)
}
}
it("should find variable expression within function") {
val (line, range) = findRangeOfString(text, "test", 4)
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
program
.declarations(2)
.asInstanceOf[FunctionDeclarationStatement]
.statements
.get(0)
.asInstanceOf[ExpressionStatement]
.expression
.asInstanceOf[FunctionCallExpression]
.expressions(0)
)
}
}
it("should find array expression within function") {
val (line, range) = findRangeOfString(text, "foo", 4)
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
program
.declarations(2)
.asInstanceOf[FunctionDeclarationStatement]
.statements
.get(1)
.asInstanceOf[Assignment]
.destination
)
}
}
it("should find right hand side of assignment") {
val (line, range) = findRangeOfString(text, "test", 5)
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
program
.declarations(2)
.asInstanceOf[FunctionDeclarationStatement]
.statements
.get(1)
.asInstanceOf[Assignment]
.source
)
}
}
it("should find function call") {
val (line, range) = findRangeOfString(text, "func()")
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
program
.declarations(2)
.asInstanceOf[FunctionDeclarationStatement]
.statements
.get(2)
.asInstanceOf[ExpressionStatement]
.expression
)
}
}
it("should find function argument") {
val (line, range) = findRangeOfString(text, "arg")
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
program
.declarations(3)
.asInstanceOf[FunctionDeclarationStatement]
.params(0)
)
}
}
it("should find function nested variable declarations") {
val (line, range) = findRangeOfString(text, "i", 7)
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
program
.declarations(3)
.asInstanceOf[FunctionDeclarationStatement]
.statements
.get(0)
)
}
}
it("should find variable used to index array") {
val innerText = """
| byte root
| array(byte) anArray[10]
| void main() {
| byte index
| index = 4
| root = anArray[index]
| index = anArray[root+1]
| }
""".stripMargin
val program = createProgram(innerText)
{
// Standard indexing
val (line, range) = findRangeOfString(innerText, "index", 6)
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
program
.declarations(2)
.asInstanceOf[FunctionDeclarationStatement]
.statements
.get(2)
.asInstanceOf[Assignment]
.source
.asInstanceOf[IndexedExpression]
.index
)
}
}
{
// Indexing within sum expression
val (line, range) = findRangeOfString(innerText, "root", 7)
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
program
.declarations(2)
.asInstanceOf[FunctionDeclarationStatement]
.statements
.get(3)
.asInstanceOf[Assignment]
.source
.asInstanceOf[IndexedExpression]
.index
.asInstanceOf[SumExpression]
.expressions(0)
._2
)
}
}
}
it("should find variables within a sum") {
val innerText = """
| byte valA
| byte valB
| byte valC
| byte output
| void main() {
| output = valB + valA - 102 + valC
| }
""".stripMargin
val program = createProgram(innerText)
val sumExpression = program
.declarations(4)
.asInstanceOf[FunctionDeclarationStatement]
.statements
.get(0)
.asInstanceOf[Assignment]
.source
.asInstanceOf[SumExpression]
{
val (line, range) = findRangeOfString(innerText, "valA", 5)
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
sumExpression.expressions(1)._2
)
}
}
{
val (line, range) = findRangeOfString(innerText, "valB", 5)
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
sumExpression.expressions(0)._2
)
}
}
{
val (line, range) = findRangeOfString(innerText, "valC", 5)
for (column <- range) {
NodeFinder
.findNodeAtPosition(program, Position("", line, column, 0))
._1
.get should equal(
sumExpression.expressions(3)._2
)
}
}
}
// TODO: Additional tests:
// Fields on array indexing: spawn_info[index].hi
// Struct type: Player player
// Struct fields: player1.pos
// Messed up hover positions (in `nes_reset_joy.mfk`, each variable assignment to 0)
// Alias references
// Pointers: obj_ptr->xvel
}
}

View File

@ -0,0 +1,45 @@
package millfork.test.language.util
import millfork.error.Logger
import millfork.{NullLogger, Platform}
import millfork.language.MfLanguageServer
import millfork.Context
import millfork.CompilationOptions
import millfork.test.emu.{EmuPlatform, TestErrorReporting}
import millfork.Cpu
import millfork.JobContext
import millfork.compiler.LabelGenerator
import org.eclipse.lsp4j.DidOpenTextDocumentParams
import org.eclipse.lsp4j.TextDocumentItem
import org.eclipse.lsp4j.HoverParams
import org.eclipse.lsp4j.TextDocumentIdentifier
import org.eclipse.lsp4j.Position
object LanguageHelper {
def createServer(): MfLanguageServer = {
implicit val logger: Logger = new NullLogger()
val platform = EmuPlatform.get(Cpu.Mos)
val jobContext = JobContext(TestErrorReporting.log, new LabelGenerator)
new MfLanguageServer(
new Context(logger, List()),
new CompilationOptions(
platform,
Map(),
None,
0,
Map(),
EmuPlatform.textCodecRepository,
jobContext
)
)
}
def openDocument(server: MfLanguageServer, name: String, text: String) = {
val textDocument =
new TextDocumentItem(name, "millfork", 1, text.stripMargin)
val openParams = new DidOpenTextDocumentParams(textDocument)
server.textDocumentDidOpen(openParams)
}
def formatHover(text: String): String = s"""```mfk\n${text}\n```"""
}