diff --git a/runtime/lua/vim/diagnostic.lua b/runtime/lua/vim/diagnostic.lua index a0c37eea80..002d9a63cb 100644 --- a/runtime/lua/vim/diagnostic.lua +++ b/runtime/lua/vim/diagnostic.lua @@ -1013,13 +1013,55 @@ local function set_list(loclist, opts) end end +--- @param a vim.Diagnostic +--- @param b vim.Diagnostic +--- @param primary_key string Primary sort key ('severity', 'col', etc) +--- @param reverse boolean Whether to reverse primary comparison +--- @param col_fn (fun(diagnostic: vim.Diagnostic): integer)? Optional function to get column value +--- @return boolean +local function diagnostic_cmp(a, b, primary_key, reverse, col_fn) + local a_val, b_val --- @type integer, integer + if col_fn then + a_val, b_val = col_fn(a), col_fn(b) + else + a_val = a[primary_key] --[[@as integer]] + b_val = b[primary_key] --[[@as integer]] + end + + local cmp = function(x, y) + if reverse then + return x > y + else + return x < y + end + end + + if a_val ~= b_val then + return cmp(a_val, b_val) + end + if a.lnum ~= b.lnum then + return cmp(a.lnum, b.lnum) + end + if a.col ~= b.col then + return cmp(a.col, b.col) + end + if a.end_lnum ~= b.end_lnum then + return cmp(a.end_lnum, b.end_lnum) + end + if a.end_col ~= b.end_col then + return cmp(a.end_col, b.end_col) + end + + return cmp(a._extmark_id or 0, b._extmark_id or 0) +end + --- Jump to the diagnostic with the highest severity. First sort the --- diagnostics by severity. The first diagnostic then contains the highest severity, and we can --- discard all diagnostics with a lower severity. --- @param diagnostics vim.Diagnostic[] local function filter_highest(diagnostics) table.sort(diagnostics, function(a, b) - return a.severity < b.severity + return diagnostic_cmp(a, b, 'severity', false) end) -- Find the first diagnostic where the severity does not match the highest severity, and remove @@ -1077,7 +1119,7 @@ local function next_diagnostic(search_forward, opts, use_logical_pos) --- @param diagnostic vim.Diagnostic --- @return integer - local function col(diagnostic) + local function col_fn(diagnostic) return use_logical_pos and select(2, get_logical_pos(diagnostic)) or diagnostic.col end @@ -1097,17 +1139,17 @@ local function next_diagnostic(search_forward, opts, use_logical_pos) local sort_diagnostics, is_next if search_forward then sort_diagnostics = function(a, b) - return col(a) < col(b) + return diagnostic_cmp(a, b, 'col', false, col_fn) end is_next = function(d) - return math.min(col(d), math.max(line_length - 1, 0)) > position[2] + return math.min(col_fn(d), math.max(line_length - 1, 0)) > position[2] end else sort_diagnostics = function(a, b) - return col(a) > col(b) + return diagnostic_cmp(a, b, 'col', true, col_fn) end is_next = function(d) - return math.min(col(d), math.max(line_length - 1, 0)) < position[2] + return math.min(col_fn(d), math.max(line_length - 1, 0)) < position[2] end end table.sort(line_diagnostics[lnum], sort_diagnostics) @@ -1909,11 +1951,7 @@ end --- @param diagnostics vim.Diagnostic[] local function render_virtual_lines(namespace, bufnr, diagnostics) table.sort(diagnostics, function(d1, d2) - if d1.lnum == d2.lnum then - return d1.col < d2.col - else - return d1.lnum < d2.lnum - end + return diagnostic_cmp(d1, d2, 'lnum', false) end) api.nvim_buf_clear_namespace(bufnr, namespace, 0, -1) @@ -2309,11 +2347,11 @@ function M.show(namespace, bufnr, diagnostics, opts) if opts_res.severity_sort then if type(opts_res.severity_sort) == 'table' and opts_res.severity_sort.reverse then table.sort(diagnostics, function(a, b) - return a.severity < b.severity + return diagnostic_cmp(a, b, 'severity', false) end) else table.sort(diagnostics, function(a, b) - return a.severity > b.severity + return diagnostic_cmp(a, b, 'severity', true) end) end end @@ -2407,11 +2445,11 @@ function M.open_float(opts, ...) if severity_sort then if type(severity_sort) == 'table' and severity_sort.reverse then table.sort(diagnostics, function(a, b) - return a.severity > b.severity + return diagnostic_cmp(a, b, 'severity', true) end) else table.sort(diagnostics, function(a, b) - return a.severity < b.severity + return diagnostic_cmp(a, b, 'severity', false) end) end end diff --git a/test/functional/lua/diagnostic_spec.lua b/test/functional/lua/diagnostic_spec.lua index aec11b6f70..e37471273f 100644 --- a/test/functional/lua/diagnostic_spec.lua +++ b/test/functional/lua/diagnostic_spec.lua @@ -2538,6 +2538,31 @@ describe('vim.diagnostic', function() end) eq('Error here!', result[1][3][1]) end) + + it('sorts by severity with stable tiebreaker #37137', function() + local result = exec_lua(function() + vim.diagnostic.config({ severity_sort = true, virtual_lines = { current_line = true } }) + local m = 100 + local diagnostics = { + { end_col = m, lnum = 0, message = 'a', severity = 2 }, + { end_col = m, lnum = 0, message = 'b', severity = 2 }, + { end_col = m, lnum = 0, message = 'c', severity = 2 }, + { end_col = m, lnum = 2, message = 'd', severity = 2 }, + { end_col = m, lnum = 2, message = 'e', severity = 2 }, + { end_col = m, lnum = 2, message = 'f', severity = 2 }, + } + vim.diagnostic.set(_G.diagnostic_ns, _G.diagnostic_bufnr, diagnostics, {}) + vim.diagnostic.show(_G.diagnostic_ns, _G.diagnostic_bufnr) + vim.api.nvim_win_set_cursor(0, { 1, 0 }) + local extmarks = _G.get_virt_lines_extmarks(_G.diagnostic_ns) + local result = {} + for _, d in ipairs(extmarks[1][4].virt_lines) do + table.insert(result, d[3][1]) + end + return result + end) + eq({ 'c', 'b', 'a' }, result) + end) end) describe('set()', function()