Skip to content

Commit c636fdd

Browse files
authored
Merge pull request #2822 from tomlau10/fix/type_narrow
fix: improve function type narrow by checking params' literal identical
2 parents 08dd0ca + 30deedc commit c636fdd

File tree

4 files changed

+88
-13
lines changed

4 files changed

+88
-13
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* `NEW` Setting: `Lua.docScriptPath`: Path to a script that overrides `cli.doc.export`, allowing user-specified documentation exporting.
77
* `FIX` Fix `VM.OnCompileFunctionParam` function in plugins
88
* `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter
9+
* `FIX` Improve type narrow by checking exact match on literal type params
910

1011
## 3.10.5
1112
`2024-8-19`

script/vm/function.lua

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,35 @@ local function isAllParamMatched(uri, args, params)
353353
return true
354354
end
355355

356+
---@param uri uri
357+
---@param args parser.object[]
358+
---@param func parser.object
359+
---@return number
360+
local function calcFunctionMatchScore(uri, args, func)
361+
if vm.isVarargFunctionWithOverloads(func)
362+
or not isAllParamMatched(uri, args, func.args)
363+
then
364+
return -1
365+
end
366+
local matchScore = 0
367+
for i = 1, math.min(#args, #func.args) do
368+
local arg, param = args[i], func.args[i]
369+
local defLiterals, literalsCount = vm.getLiterals(param)
370+
if defLiterals then
371+
for n in vm.compileNode(arg):eachObject() do
372+
-- if param's literals map contains arg's literal, this is narrower than a subtype match
373+
if defLiterals[guide.getLiteral(n)] then
374+
-- the more the literals defined in the param, the less bonus score will be added
375+
-- this favors matching overload param with exact literal value, over alias/enum that has many literal values
376+
matchScore = matchScore + 1/literalsCount
377+
break
378+
end
379+
end
380+
end
381+
end
382+
return matchScore
383+
end
384+
356385
---@param func parser.object
357386
---@param args? parser.object[]
358387
---@return parser.object[]?
@@ -365,21 +394,29 @@ function vm.getExactMatchedFunctions(func, args)
365394
return funcs
366395
end
367396
local uri = guide.getUri(func)
368-
local needRemove
397+
local matchScores = {}
369398
for i, n in ipairs(funcs) do
370-
if vm.isVarargFunctionWithOverloads(n)
371-
or not isAllParamMatched(uri, args, n.args) then
372-
if not needRemove then
373-
needRemove = {}
374-
end
375-
needRemove[#needRemove+1] = i
376-
end
399+
matchScores[i] = calcFunctionMatchScore(uri, args, n)
400+
end
401+
402+
local maxMatchScore = math.max(table.unpack(matchScores))
403+
if maxMatchScore == -1 then
404+
-- all should be removed
405+
return nil
377406
end
378-
if not needRemove then
407+
408+
local minMatchScore = math.min(table.unpack(matchScores))
409+
if minMatchScore == maxMatchScore then
410+
-- all should be kept
379411
return funcs
380412
end
381-
if #needRemove == #funcs then
382-
return nil
413+
414+
-- remove functions that have matchScore < maxMatchScore
415+
local needRemove = {}
416+
for i, matchScore in ipairs(matchScores) do
417+
if matchScore < maxMatchScore then
418+
needRemove[#needRemove + 1] = i
419+
end
383420
end
384421
util.tableMultiRemove(funcs, needRemove)
385422
return funcs

script/vm/value.lua

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,13 @@ end
213213

214214
---@param v vm.object
215215
---@return table<any, boolean>?
216+
---@return integer
216217
function vm.getLiterals(v)
217218
if not v then
218-
return nil
219+
return nil, 0
219220
end
220221
local map
222+
local count = 0
221223
local node = vm.compileNode(v)
222224
for n in node:eachObject() do
223225
local literal
@@ -237,7 +239,8 @@ function vm.getLiterals(v)
237239
map = {}
238240
end
239241
map[literal] = true
242+
count = count + 1
240243
end
241244
end
242-
return map
245+
return map, count
243246
end

test/type_inference/param_match.lua

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,40 @@ local function f(...) end
138138
local <?r?> = f(10)
139139
]]
140140

141+
TEST '1' [[
142+
---@overload fun(a: string): 1
143+
---@overload fun(a: 'y'): 2
144+
local function f(...) end
145+
146+
local <?r?> = f('x')
147+
]]
148+
149+
TEST '2' [[
150+
---@overload fun(a: string): 1
151+
---@overload fun(a: 'y'): 2
152+
local function f(...) end
153+
154+
local <?r?> = f('y')
155+
]]
156+
157+
TEST '1' [[
158+
---@overload fun(a: string): 1
159+
---@overload fun(a: 'y'): 2
160+
local function f(...) end
161+
162+
local v = 'x'
163+
local <?r?> = f(v)
164+
]]
165+
166+
TEST '2' [[
167+
---@overload fun(a: string): 1
168+
---@overload fun(a: 'y'): 2
169+
local function f(...) end
170+
171+
local v = 'y'
172+
local <?r?> = f(v)
173+
]]
174+
141175
TEST 'number' [[
142176
---@overload fun(a: 1, c: fun(x: number))
143177
---@overload fun(a: 2, c: fun(x: string))

0 commit comments

Comments
 (0)