feat(treesitter)!: consolidate query util functions

- And address more type errors.
- Removed the `concat` option from `get_node_text` since it was applied
  inconsistently and made typing awkward.
This commit is contained in:
Lewis Russell
2023-03-10 16:10:05 +00:00
parent 845efb8e12
commit 9d70fe062c
5 changed files with 93 additions and 106 deletions

View File

@@ -1,6 +1,8 @@
local a = vim.api
local language = require('vim.treesitter.language')
local Range = require('vim.treesitter._range')
---@class Query
---@field captures string[] List of captures used in query
---@field info TSQueryInfo Contains used queries, predicates, directives
@@ -56,35 +58,13 @@ local function add_included_lang(base_langs, lang, ilang)
end
---@private
---@param buf (integer)
---@param range (table)
---@param concat (boolean)
---@returns (string[]|string|nil)
local function buf_range_get_text(buf, range, concat)
local lines
local start_row, start_col, end_row, end_col = unpack(range)
local eof_row = a.nvim_buf_line_count(buf)
if start_row >= eof_row then
return nil
end
if end_col == 0 then
lines = a.nvim_buf_get_lines(buf, start_row, end_row, true)
end_col = -1
else
lines = a.nvim_buf_get_lines(buf, start_row, end_row + 1, true)
end
if #lines > 0 then
if #lines == 1 then
lines[1] = string.sub(lines[1], start_col + 1, end_col)
else
lines[1] = string.sub(lines[1], start_col + 1)
lines[#lines] = string.sub(lines[#lines], 1, end_col)
end
end
return concat and table.concat(lines, '\n') or lines
---@param buf integer
---@param range Range6
---@returns string
local function buf_range_get_text(buf, range)
local start_row, start_col, end_row, end_col = Range.unpack4(range)
local lines = a.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {})
return table.concat(lines, '\n')
end
--- Gets the list of files used to make up a query
@@ -256,14 +236,28 @@ function M.parse_query(lang, query)
local cached = query_cache[lang][query]
if cached then
return cached
else
local self = setmetatable({}, Query)
self.query = vim._ts_parse_query(lang, query)
self.info = self.query:inspect()
self.captures = self.info.captures
query_cache[lang][query] = self
return self
end
local self = setmetatable({}, Query)
self.query = vim._ts_parse_query(lang, query)
self.info = self.query:inspect()
self.captures = self.info.captures
query_cache[lang][query] = self
return self
end
---Get the range of a |TSNode|. Can also supply {source} and {metadata}
---to get the range with directives applied.
---@param node TSNode
---@param source integer|string|nil Buffer or string from which the {node} is extracted
---@param metadata TSMetadata|nil
---@return Range6
function M.get_range(node, source, metadata)
if metadata and metadata.range then
assert(source)
return Range.add_bytes(source, metadata.range)
end
return { node:range(true) }
end
--- Gets the text corresponding to a given node
@@ -271,24 +265,22 @@ end
---@param node TSNode
---@param source (integer|string) Buffer or string from which the {node} is extracted
---@param opts (table|nil) Optional parameters.
--- - concat: (boolean) Concatenate result in a string (default true)
--- - metadata (table) Metadata of a specific capture. This would be
--- set to `metadata[capture_id]` when using |vim.treesitter.add_directive()|.
---@return (string[]|string|nil)
---@return string
function M.get_node_text(node, source, opts)
opts = opts or {}
-- TODO(lewis6991): concat only works when source is number.
local concat = vim.F.if_nil(opts.concat, true)
local metadata = opts.metadata or {}
if metadata.text then
return metadata.text
elseif type(source) == 'number' then
return metadata.range and buf_range_get_text(source, metadata.range, concat)
or buf_range_get_text(source, { node:range() }, concat)
elseif type(source) == 'string' then
return source:sub(select(3, node:start()) + 1, select(3, node:end_()))
local range = M.get_range(node, source, metadata)
return buf_range_get_text(source, range)
end
---@cast source string
return source:sub(select(3, node:start()) + 1, select(3, node:end_()))
end
---@alias TSMatch table<integer,TSNode>
@@ -312,7 +304,7 @@ local predicate_handlers = {
str = predicate[3]
else
-- (#eq? @aa @bb)
str = M.get_node_text(match[predicate[3]], source) --[[@as string]]
str = M.get_node_text(match[predicate[3]], source)
end
if node_text ~= str or str == nil then
@@ -328,7 +320,7 @@ local predicate_handlers = {
return true
end
local regex = predicate[3]
return string.find(M.get_node_text(node, source) --[[@as string]], regex) ~= nil
return string.find(M.get_node_text(node, source), regex) ~= nil
end,
['match?'] = (function()
@@ -366,7 +358,7 @@ local predicate_handlers = {
if not node then
return true
end
local node_text = M.get_node_text(node, source) --[[@as string]]
local node_text = M.get_node_text(node, source)
for i = 3, #predicate do
if string.find(node_text, predicate[i], 1, true) then
@@ -404,9 +396,9 @@ local predicate_handlers = {
predicate_handlers['vim-match?'] = predicate_handlers['match?']
---@class TSMetadata
---@field range Range4|Range6
---@field [integer] TSMetadata
---@field [string] integer|string
---@field range Range4
---@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata)
@@ -465,13 +457,20 @@ local directive_handlers = {
assert(#pred == 4)
local id = pred[2]
assert(type(id) == 'number')
local node = match[id]
local text = M.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''
if not metadata[id] then
metadata[id] = {}
end
metadata[id].text = text:gsub(pred[3], pred[4])
local pattern, replacement = pred[3], pred[3]
assert(type(pattern) == 'string')
assert(type(replacement) == 'string')
metadata[id].text = text:gsub(pattern, replacement)
end,
}