LuaComp/src/ast2.lua

474 lines
11 KiB
Lua

-- AST Generator v2: Belkan Boogaloo
-- Hopefully faster than v1
-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at https://mozilla.org/MPL/2.0/.
-- Module table; every parser entry point defined below is attached to it.
local ast = {}
do
-- Characters treated as indentation when ast.parse checks whether a directive
-- marker begins its line.
local ws = "\t "
-- Wraps the string `str` in a cursor object used by every parser in this
-- module. `pos` is the 1-based index of the next unread character and `file`
-- (default "(unknown)") names the source in error messages.
function ast.str_to_stream(str, file)
	local s = {
		str = str,
		pos = 1,
		file = file or "(unknown)"
	}
	-- Consumes and returns the next `c` characters (default 1).
	function s:next(c)
		c = c or 1
		local d = self.str:sub(self.pos, self.pos+c-1)
		self.pos = self.pos + c
		return d
	end
	-- Returns the next `c` characters (default 1) without consuming them. A
	-- negative `c` returns the -c characters just before the cursor instead;
	-- callers should not reach before the start of the input, where
	-- string.sub's index rules give surprising results.
	function s:peek(c)
		c = c or 1
		if (c < 0) then
			return self.str:sub(self.pos+c, self.pos-1)
		end
		return self.str:sub(self.pos, self.pos+c-1)
	end
	-- Moves the cursor back `c` characters (default 1); returns the new position.
	function s:rewind(c)
		c = c or 1
		self.pos = self.pos - c
		return self.pos
	end
	-- Moves the cursor forward `c` characters (default 1); returns the new position.
	function s:skip(c)
		c = c or 1
		self.pos = self.pos + c
		return self.pos
	end
	-- Moves the cursor to absolute index `c` (unchanged when nil); returns it.
	function s:set(c)
		self.pos = c or self.pos
		return self.pos
	end
	-- Returns the cursor position.
	function s:tell()
		return self.pos
	end
	-- Returns the length of the underlying string.
	function s:size()
		return #self.str
	end
	-- Searches for `pat` (a Lua pattern, or plain text when `raw` is true)
	-- starting at the cursor. On success moves the cursor just past the match
	-- and returns the matched text; otherwise leaves the cursor alone and
	-- returns nil, "not found".
	function s:next_instance(pat, raw)
		local st, en = self.str:find(pat, self.pos, raw)
		if not st then return nil, "not found" end
		self.pos = en+1
		return self.str:sub(st, en)
	end
	-- Returns the cursor's 1-based line and column, in that order.
	function s:get_yx() -- it *is* yx
		local line = 1
		local line_start = 0 -- index of the newline ending the previous line (0 on line 1)
		while true do
			local nl = self.str:find("\n", line_start + 1, true)
			-- No further newline, or the next one is at/after the cursor: the
			-- cursor is on the current line. (The old loop answered line+1, 0
			-- here whenever no newline followed the cursor.)
			if not nl or nl >= self.pos then
				return line, self.pos - line_start
			end
			line = line + 1
			line_start = nl
		end
	end
	return s
end
-- Escape letters that ast.unescape turns into control characters. Any other
-- escaped character (a backslash included) is passed through literally.
local esct = {
	t = "\t",
	n = "\n",
	r = "\r",
}
-- Reports a parse error at the stream's current line/column through the
-- global lc_error reporter, as "<file>(<line>:<col>): <err>".
-- The first argument to lc_error is expanded by LuaComp itself when this file
-- is built (presumably to the generator's file name) and must stay verbatim.
-- NOTE(review): callers continue without guarding after this call, so they
-- assume lc_error does not return -- confirm.
function ast.parser_error(str, err)
	local line, col = str:get_yx()
	local where = str.file
	local msg = string.format("%s(%d:%d): %s\n", where, line or 0, col or 0, err)
	lc_error("@[{_GENERATOR.fname}]", msg)
end
-- Decodes backslash escapes in `escaped_string`. The letters in esct (t, n, r)
-- become their control characters. Any other escaped character is kept as-is,
-- so a doubled backslash yields one backslash and an escaped quote yields the
-- quote. A lone trailing backslash is dropped.
function ast.unescape(escaped_string)
	-- One left-to-right gsub pass replaces the previous character-by-character
	-- concatenation, which was quadratic in the input length. `.?` lets a
	-- trailing backslash match with an empty capture, which maps to "".
	return (escaped_string:gsub("\\(.?)", function(c)
		return esct[c] or c
	end))
end
-- Deletes every backslash together with the character that follows it; a lone
-- trailing backslash is simply dropped. The quote parsers apply this to the
-- backslash run in front of a quote: an even run vanishes and leaves the quote,
-- while an odd run also consumes the quote.
function ast.remove_escapes(escaped_string)
	-- Single gsub pass instead of quadratic string concatenation.
	return (escaped_string:gsub("\\.?", ""))
end
-- Given a stream positioned just past a closing-quote candidate, returns the
-- run of backslashes immediately before that quote followed by the quote
-- itself. `start` is the index of the first character inside the literal; the
-- scan never looks before it. The cursor is not moved.
function ast.back_escape_count(str, start)
	local pos = str:tell()
	-- pos-1 is the quote; count backslashes leftwards starting at pos-2.
	local i = 2
	-- Stop at `start` rather than erroring there: a literal may consist
	-- entirely of escaped backslashes. (The old guard also called the
	-- nonexistent ast.error.)
	while pos - i >= start and str:peek(-i):sub(1, 1) == "\\" do
		i = i + 1
	end
	return str:peek(1 - i)
end
-- Shared scanner behind parse_quote and parse_dblquote. Expects the stream just
-- past an opening `q`. Consumes through the first closing `q` that is not
-- escaped by an odd run of backslashes, and returns the contents decoded by
-- ast.unescape. Literals may not span lines.
local function parse_quoted(str, q)
	local spos = str:tell()
	while true do
		if not str:next_instance(q, true) then
			ast.parser_error(str, "unclosed string")
		end
		local rpos = str:tell()
		-- Reject a newline between the opening quote and this candidate.
		str:set(spos)
		if str:next_instance("\n", true) and str:tell() < rpos then
			ast.parser_error(str, "unclosed string")
		end
		str:set(rpos)
		-- The candidate quote is at rpos-1, so the character before it is at
		-- rpos-2. The single-quote parser used to test peek(-1), which is the
		-- quote itself, so escaped single quotes were never recognised.
		if str:peek(-2):sub(1, 1) ~= "\\" then
			break
		end
		local tail = ast.remove_escapes(ast.back_escape_count(str, spos))
		if tail:sub(-1) == q then
			break -- even run of backslashes: this quote really closes the literal
		end
	end
	local epos = str:tell()
	str:set(spos)
	local esc = str:next(epos - spos - 1)
	str:skip(1)
	return ast.unescape(esc)
end

-- Parses a single-quoted string literal (stream just past the opening quote).
function ast.parse_quote(str)
	return parse_quoted(str, "'")
end

-- Parses a double-quoted string literal (stream just past the opening quote).
function ast.parse_dblquote(str)
	return parse_quoted(str, "\"")
end
-- Reads the hex digits of a 0x literal. Expects the stream just past the "0x"
-- prefix; returns the value as a number.
function ast.parse_hex(str)
	-- Anchored so the digits must start at the cursor. An unanchored search
	-- could skip over spaces and silently pick up later digits ("0x 12").
	local hex = str:next_instance("^%x+")
	if not hex then
		ast.parser_error(str, "malformed hex")
	end
	return tonumber(hex, 16)
end
-- Reads the run of decimal digits at the cursor and returns it as a number.
-- The visible caller (parse_directive) only calls this when the cursor is on a
-- digit, so failing here is an internal error.
function ast.parse_number(str)
	-- Anchored, like parse_hex, so the search can never skip ahead.
	local num = str:next_instance("^%d+")
	if not num then
		ast.parser_error(str, "internal error")
	end
	return tonumber(num, 10)
end
-- Reads a shell-variable name. Expects the stream just past the opening
-- marker. The name runs up to the next ")" on the same line; that ")" is
-- consumed and the name is returned.
function ast.parse_envvar(str)
	local spos = str:tell()
	-- Anchored and single-line. The old unanchored search skipped an immediate
	-- ")" and, with no ")" at all, swallowed the rest of the input.
	local name = str:next_instance("^[^)\n]+")
	if not name or str:peek() ~= ")" then
		str:set(spos)
		ast.parser_error(str, "unclosed shell var")
	end
	str:skip(1)
	return name
end
-- Span body: everything up to the next "}]" on the same line. Expects the
-- stream just past the opening marker; consumes the closing "}]" and returns
-- the text in between.
function ast.parse_span(str)
	local spos = str:tell()
	if not str:next_instance("}]", true) then
		-- Was "unclosed block", a copy-paste from parse_block.
		ast.parser_error(str, "unclosed span")
	end
	local rpos = str:tell()
	-- Spans must close on the line they open on.
	str:set(spos)
	if str:next_instance("\n", true) and str:tell() < rpos then
		str:set(spos)
		ast.parser_error(str, "unclosed span")
	end
	str:set(spos)
	local data = str:next(rpos - spos - 2)
	str:skip(2)
	return data
end
-- Block body: everything up to the next "]]", which may be on a later line.
-- Expects the stream just past the opening marker; consumes the closing "]]"
-- and returns the text in between.
function ast.parse_block(str)
	local start = str:tell()
	if not str:next_instance("]]", true) then
		ast.parser_error(str, "unclosed block")
	end
	local stop = str:tell() - 2 -- index of the closing "]]"
	str:set(start)
	local body = str:next(stop - start)
	str:skip(2)
	return body
end
-- True when the character at the cursor may legally follow a directive
-- argument: a space, a newline, or the end of the input.
local function at_arg_boundary(str)
	local c = str:peek()
	return c == " " or c == "\n" or c == ""
end

-- Parses one directive line. Expects the stream just past the directive marker.
-- The name is the first run of characters up to a space or newline, with
-- leading spaces skipped. Each following space-separated argument is one of:
-- a 0x hex number, a decimal number, a double- or single-quoted string, a
-- shell variable, or a Lua span. Parsing stops at the newline or end of input
-- that ends the line, and the cursor is left there.
-- Returns {type = "directive", name = <string>, args = <list>}.
function ast.parse_directive(str) -- And now we start getting more complex.
	-- Newlines are excluded: the old pattern let an argument-less directive
	-- swallow the following line(s) into its name.
	local name = str:next_instance("^ *[^ \n]+")
	if not name then
		ast.parser_error(str, "malformed directive")
	end
	name = name and name:gsub("^ +", "")
	local args = {}
	-- Anchored: arguments must follow on the same line. The old unanchored
	-- search (plus a newline check that could never fire) would jump to
	-- spaces on a later line.
	while str:next_instance("^ +") do
		local apos = str:tell()
		local c = str:peek()
		local val, bad
		if c == "\n" or c == "" then
			-- Only trailing spaces before the end of the line or input.
			break
		elseif str:peek(2) == "0x" then
			str:skip(2)
			val, bad = ast.parse_hex(str), "malformed hex"
		elseif c:find("%d") then
			val, bad = ast.parse_number(str), "malformed number"
		elseif c == "\"" then
			str:skip(1)
			val, bad = ast.parse_dblquote(str), "malformed string"
		elseif c == "'" then
			str:skip(1)
			val, bad = ast.parse_quote(str), "malformed string"
		elseif str:peek(2) == "$".."(" then -- i have to avoid the funny
			str:skip(2)
			val, bad = {type="evar", val=ast.parse_envvar(str)}, "malformed argument"
		elseif str:peek(3) == "@".."[{" then
			str:skip(3)
			val, bad = {type="lua_span", val=ast.parse_span(str)}, "malformed code block"
		else
			ast.parser_error(str, "unknown arg type")
		end
		-- Every argument must be followed by a space, a newline or end of input.
		if not at_arg_boundary(str) then
			str:set(apos)
			ast.parser_error(str, bad)
		end
		table.insert(args, val)
		if str:peek() == "\n" then
			break
		end
	end
	return {
		type = "directive",
		name = name,
		args = args
	}
end
-- Finds the earliest occurrence, at or after the cursor, of any of the
-- plain-text needles passed as varargs.
-- If `onfind` is given, it is called as onfind(str, match) with the cursor
-- just past a candidate. A falsy result rejects that occurrence, and the search
-- for that needle continues past it.
-- On success the cursor is left just past the winning match, which is returned.
-- Otherwise the cursor is restored and nil is returned.
-- When two needles match at the same position, the one listed first wins.
function ast.find_first(str, onfind, ...)
	local needles = table.pack(...)
	local spos = str:tell()
	local best, best_start, best_end = nil, math.huge, nil
	for i = 1, needles.n do
		local from = spos
		while true do
			str:set(from)
			local m = str:next_instance(needles[i], true)
			if not m then
				break
			end
			local mend = str:tell()
			local mstart = mend - #m
			-- Anything starting at or after the current best cannot win.
			if mstart >= best_start then
				break
			end
			if not onfind or onfind(str, m) then
				best, best_start, best_end = m, mstart, mend
				break
			end
			-- Rejected: keep looking for a later occurrence of this needle.
			-- Previously the needle was abandoned, hiding any later valid match.
			from = mstart + 1
		end
	end
	if best then
		str:set(best_end)
	else
		str:set(spos)
	end
	return best
end
-- When the global DEBUGGING flag is set, stamps the most recently appended
-- node of `list` with its source span. The start (sx, sy) comes from the
-- caller; the end (ex, ey) is the stream's current position; the stream's
-- file name is also recorded. Does nothing when DEBUGGING is unset.
function ast.add_debugging_info(list, str, sx, sy)
	if not DEBUGGING then
		return
	end
	local last = list[#list]
	last.sx, last.sy = sx, sy
	last.ey, last.ex = str:get_yx()
	last.file = str.file
	-- NOTE(review): this reports through luacomp.error while the rest of the
	-- file uses lc_error -- confirm luacomp.error exists.
	if not str.file then
		luacomp.error("Node has no file!\n"..debug.traceback())
	end
end
-- And now we parse

-- find_first filter for ast.parse, called with the cursor just past
-- `submatch`. The "--" .. "#" marker is accepted only when nothing but tabs or
-- spaces (the `ws` set) separate it from the previous newline or the start of
-- the input. Every other marker is accepted wherever it appears.
local function at_line_start(str, submatch)
	if submatch ~= "--".."#" then
		return true
	end
	-- Walk left starting at the character just before the marker. The explicit
	-- bound replaces the old reliance on string.sub's negative-index wraparound.
	local i = #submatch + 1
	while str:tell() - i >= 1 do
		local ch = str:peek(-i):sub(1, 1)
		if ch == "\n" then
			return true
		end
		-- Plain find: `ch` may be a pattern magic character such as "(", "%" or
		-- "[". Those used to raise a Lua error, and "." matched as whitespace.
		if not ws:find(ch, 1, true) then
			return false
		end
		i = i + 1
	end
	return true -- only indentation between the marker and the start of the input
end

-- Splits the stream into a flat list of nodes: "content" chunks of plain text
-- interleaved with directive, shell/Lua block, shell/Lua span and
-- shell-variable nodes. Whitespace-only text between markers is dropped.
function ast.parse(str)
	local cast = {}
	while true do
		local spos = str:tell()
		-- The markers are built by concatenation, presumably so that LuaComp
		-- does not treat them as real markers when it processes this file.
		local match = ast.find_first(str, at_line_start,
			"--".."#", "$".."[[", "@".."[[", "$".."[{", "@".."[{", "$".."(", "//".."##")
		-- Positions are only consumed by add_debugging_info, which does nothing
		-- unless DEBUGGING is set. get_yx rescans from the top of the input, so
		-- skip it otherwise.
		local sy, sx
		if DEBUGGING then
			sy, sx = str:get_yx()
		end
		if not match then
			table.insert(cast, {type="content", val=str:next(str:size())})
			ast.add_debugging_info(cast, str, sx, sy)
			break
		end
		-- find_first left the cursor just past the match; emit any text before it.
		local epos = str:tell()
		local size = (epos-#match)-spos
		if size > 0 then
			str:set(spos)
			local chunk = str:next(size)
			if not chunk:match("^%s+$") then
				table.insert(cast, {type="content", val=chunk})
				ast.add_debugging_info(cast, str, sx, sy)
			end
			str:skip(#match)
		end
		if match == "--".."#" or match == "//".."##" then
			table.insert(cast, ast.parse_directive(str))
		elseif match == "$".."[[" then
			table.insert(cast, {type="shell_block", val=ast.parse_block(str)})
		elseif match == "@".."[[" then
			table.insert(cast, {type="lua_block", val=ast.parse_block(str)})
		elseif match == "$".."[{" then
			table.insert(cast, {type="shell_span", val=ast.parse_span(str)})
		elseif match == "@".."[{" then
			table.insert(cast, {type="lua_span", val=ast.parse_span(str)})
		elseif match == "$".."(" then
			table.insert(cast, {type="evar", val=ast.parse_envvar(str)})
		else
			ast.parser_error(str, "internal compiler error")
		end
		ast.add_debugging_info(cast, str, sx, sy)
	end
	return cast
end
end -- closes the do-block opened after "local ast = {}", which keeps file-private locals such as ws and esct out of the outer scope