From 0977f70f4dd3d14175697d5e6568d4019019506f Mon Sep 17 00:00:00 2001 From: Riley Bruins Date: Sun, 13 Apr 2025 14:22:17 -0700 Subject: [PATCH] fix(treesitter): injected lang ranges may cross capture boundaries #32549 Problem: treesitter injected language ranges sometimes cross over the capture boundaries when `@combined`. Solution: Clip child regions to not spill out of parent regions within languagetree.lua, and only apply highlights within those regions in highlighter.lua. Co-authored-by: Cormac Relf --- runtime/lua/vim/treesitter/_range.lua | 13 ++ runtime/lua/vim/treesitter/highlighter.lua | 83 +++++++------ runtime/lua/vim/treesitter/languagetree.lua | 45 ++++++- test/functional/treesitter/highlight_spec.lua | 114 ++++++++++++++++++ 4 files changed, 213 insertions(+), 42 deletions(-) diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua index 82ab8517aa..b1595cf8f6 100644 --- a/runtime/lua/vim/treesitter/_range.lua +++ b/runtime/lua/vim/treesitter/_range.lua @@ -114,6 +114,19 @@ function M.intercepts(r1, r2) return true end +---@private +---@param r1 Range6 +---@param r2 Range6 +---@return Range6? +function M.intersection(r1, r2) + if not M.intercepts(r1, r2) then + return nil + end + local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1 + local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1 + return { rs[1], rs[2], rs[3], re[4], re[5], re[6] } +end + ---@private ---@param r Range ---@return integer, integer, integer, integer diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index 9969cb9dce..60d4a89278 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -322,6 +322,8 @@ local function on_line_impl(self, buf, line, on_spell, on_conceal) return end + local tree_region = state.tstree:included_ranges(true) + if state.iter == nil or state.next_row < line then -- Mainly used to skip over folds @@ -336,56 +338,63 @@ local function on_line_impl(self, buf, line, on_spell, on_conceal) while line >= state.next_row do local capture, node, metadata, match = state.iter(line) - local range = { root_end_row + 1, 0, root_end_row + 1, 0 } + local outer_range = { root_end_row + 1, 0, root_end_row + 1, 0 } if node then - range = vim.treesitter.get_range(node, buf, metadata and metadata[capture]) + outer_range = vim.treesitter.get_range(node, buf, metadata and metadata[capture]) end - local start_row, start_col, end_row, end_col = Range.unpack4(range) + local outer_range_start_row = outer_range[1] - if capture then - local hl = state.highlighter_query:get_hl_from_capture(capture) + for _, range in ipairs(tree_region) do + local intersection = Range.intersection(range, outer_range) + if intersection then + local start_row, start_col, end_row, end_col = Range.unpack4(intersection) - local capture_name = captures[capture] + if capture then + local hl = state.highlighter_query:get_hl_from_capture(capture) - local spell, spell_pri_offset = get_spell(capture_name) + local capture_name = captures[capture] - -- The "priority" attribute can be set at the pattern level or on a particular capture - local priority = ( - tonumber(metadata.priority or metadata[capture] and metadata[capture].priority) - or vim.hl.priorities.treesitter - ) + spell_pri_offset + local spell, spell_pri_offset = get_spell(capture_name) - -- The "conceal" attribute can be set at the pattern level or on a particular capture - local conceal = metadata.conceal or metadata[capture] and metadata[capture].conceal + -- The "priority" attribute can be set at the pattern level or on a particular capture + local priority = ( + tonumber(metadata.priority or metadata[capture] and metadata[capture].priority) + or vim.hl.priorities.treesitter + ) + spell_pri_offset - local url = get_url(match, buf, capture, metadata) + -- The "conceal" attribute can be set at the pattern level or on a particular capture + local conceal = metadata.conceal or metadata[capture] and metadata[capture].conceal - if hl and end_row >= line and not on_conceal and (not on_spell or spell ~= nil) then - api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { - end_line = end_row, - end_col = end_col, - hl_group = hl, - ephemeral = true, - priority = priority, - conceal = conceal, - spell = spell, - url = url, - }) - end + local url = get_url(match, buf, capture, metadata) - if - (metadata.conceal_lines or metadata[capture] and metadata[capture].conceal_lines) - and #api.nvim_buf_get_extmarks(buf, ns, { start_row, 0 }, { start_row, 0 }, {}) == 0 - then - api.nvim_buf_set_extmark(buf, ns, start_row, 0, { - end_line = end_row, - conceal_lines = '', - }) + if hl and end_row >= line and not on_conceal and (not on_spell or spell ~= nil) then + api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { + end_line = end_row, + end_col = end_col, + hl_group = hl, + ephemeral = true, + priority = priority, + conceal = conceal, + spell = spell, + url = url, + }) + end + + if + (metadata.conceal_lines or metadata[capture] and metadata[capture].conceal_lines) + and #api.nvim_buf_get_extmarks(buf, ns, { start_row, 0 }, { start_row, 0 }, {}) == 0 + then + api.nvim_buf_set_extmark(buf, ns, start_row, 0, { + end_line = end_row, + conceal_lines = '', + }) + end + end end end - if start_row > line then - state.next_row = start_row + if outer_range_start_row > line then + state.next_row = outer_range_start_row end end end) diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 6f0e377d2f..31cf64b54a 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -874,6 +874,39 @@ local function get_node_ranges(node, source, metadata, include_children) return ranges end +---Finds the intersection between two regions, assuming they are sorted in ascending order by +---starting point. +---@param region1 Range6[] +---@param region2 Range6[]? +---@return Range6[] +local function clip_regions(region1, region2) + if not region2 then + return region1 + end + + local result = {} + local i, j = 1, 1 + + while i <= #region1 and j <= #region2 do + local r1 = region1[i] + local r2 = region2[j] + + local intersection = Range.intersection(r1, r2) + if intersection then + table.insert(result, intersection) + end + + -- Advance the range that ends earlier + if Range.cmp_pos.le(r1[3], r1[4], r2[3], r2[4]) then + i = i + 1 + else + j = j + 1 + end + end + + return result +end + ---@nodoc ---@class vim.treesitter.languagetree.InjectionElem ---@field combined boolean @@ -886,8 +919,9 @@ end ---@param lang string ---@param combined boolean ---@param ranges Range6[] +---@param parent_ranges Range6[]? ---@param result table -local function add_injection(t, pattern, lang, combined, ranges, result) +local function add_injection(t, pattern, lang, combined, ranges, parent_ranges, result) if #ranges == 0 then -- Make sure not to add an empty range set as this is interpreted to mean the whole buffer. return @@ -898,7 +932,7 @@ local function add_injection(t, pattern, lang, combined, ranges, result) end if not combined then - table.insert(result[lang], ranges) + table.insert(result[lang], clip_regions(ranges, parent_ranges)) return end @@ -914,7 +948,7 @@ local function add_injection(t, pattern, lang, combined, ranges, result) table.insert(result[lang], regions) end - for _, range in ipairs(ranges) do + for _, range in ipairs(clip_regions(ranges, parent_ranges)) do table.insert(t[lang][pattern], range) end end @@ -1007,10 +1041,11 @@ function LanguageTree:_get_injections(range, thread_state) local full_scan = range == true or self._injection_query.has_combined_injections - for _, tree in pairs(self._trees) do + for tree_index, tree in pairs(self._trees) do ---@type vim.treesitter.languagetree.Injection local injections = {} local root_node = tree:root() + local parent_ranges = self._regions and self._regions[tree_index] or nil local start_line, end_line ---@type integer, integer if full_scan then start_line, _, end_line = root_node:range() @@ -1023,7 +1058,7 @@ function LanguageTree:_get_injections(range, thread_state) do local lang, combined, ranges = self:_get_injection(match, metadata) if lang then - add_injection(injections, pattern, lang, combined, ranges, result) + add_injection(injections, pattern, lang, combined, ranges, parent_ranges, result) else self:_log('match from injection query failed for pattern', pattern) end diff --git a/test/functional/treesitter/highlight_spec.lua b/test/functional/treesitter/highlight_spec.lua index 5b9a060fb5..9bf32bf8a7 100644 --- a/test/functional/treesitter/highlight_spec.lua +++ b/test/functional/treesitter/highlight_spec.lua @@ -513,6 +513,120 @@ describe('treesitter highlighting (C)', function() screen:expect { grid = injection_grid_expected_c } end) + it('supports combined injections #31777', function() + insert([=[ + -- print([[ + -- some + -- random + -- text + -- here]]) + ]=]) + + exec_lua(function() + local parser = vim.treesitter.get_parser(0, 'lua', { + injections = { + lua = [[ + ; query + ((comment_content) @injection.content + (#set! injection.self) + (#set! injection.combined)) + ]], + }, + }) + local highlighter = vim.treesitter.highlighter + highlighter.new(parser, { + queries = { + lua = [[ + ; query + (string) @string + (comment) @comment + (function_call (identifier) @function.call) + [ "(" ")" ] @punctuation.bracket + ]], + }, + }) + end) + + screen:expect([=[ + {18:-- }{25:print}{16:(}{26:[[} | + {18:--}{26: some} | + {18:-- random} | + {18:-- text} | + {18:-- here]])} | + ^ | + {1:~ }|*11 + | + ]=]) + -- NOTE: Once #31777 is fixed, this test case should be updated to the following: + -- screen:expect([=[ + -- {18:-- }{25:print}{16:(}{26:[[} | + -- {18:--}{26: some} | + -- {18:--}{26: random} | + -- {18:--}{26: text} | + -- {18:--}{26: here]]}{16:)} | + -- ^ | + -- {1:~ }|*11 + -- | + -- ]=]) + end) + + it('supports complicated combined injections', function() + insert([[ + -- # Markdown here + -- + -- ```c + -- int main() { + -- printf("Hello, world!"); + -- } + -- ``` + ]]) + + exec_lua(function() + local parser = vim.treesitter.get_parser(0, 'lua', { + injections = { + lua = [[ + ; query + ((comment) @injection.content + (#offset! @injection.content 0 3 0 1) + (#lua-match? @injection.content "[-][-] ") + (#set! injection.combined) + (#set! injection.include-children) + (#set! injection.language "markdown")) + ]], + }, + }) + local highlighter = vim.treesitter.highlighter + highlighter.new(parser, { + queries = { + lua = [[ + ; query + (string) @string + (comment) @comment + (function_call (identifier) @function.call) + [ "(" ")" ] @punctuation.bracket + ]], + }, + }) + end) + + screen:add_extra_attr_ids({ + [131] = { foreground = Screen.colors.Fuchsia, bold = true }, + }) + + screen:expect([[ + {18:-- }{131:# Markdown here} | + {18:--} | + {18:-- ```}{15:c} | + {18:-- }{16:int}{18: }{25:main}{16:()}{18: }{16:{} | + {18:-- }{25:printf}{16:(}{26:"Hello, world!"}{16:);} | + {18:-- }{16:}} | + {18:-- ```} | + ^ | + {1:~ }|*9 + | + ]]) + end) + it("supports injecting by ft name in metadata['injection.language']", function() insert(injection_text_c)