Skip to content

Commit bb1244f

Browse files
committed
resolve #871 infer called function by params num
1 parent f76cd50 commit bb1244f

File tree

5 files changed

+282
-21
lines changed

5 files changed

+282
-21
lines changed

changelog.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@
66
local function f() end
77
local x = f() -- `x` is `nil` instead of `unknown`
88
```
9+
* `CHG` infer called function by params num
10+
```lua
11+
---@overload fun(number, number):string
12+
---@overload fun(number):number
13+
---@return boolean
14+
local function f() end
15+
16+
local n1 = f() -- `n1` is `boolean`
17+
local n2 = f(0) -- `n2` is `number`
18+
local n3 = f(0, 0) -- `n3` is `string`
19+
```
920

1021
## 3.3.1
1122
`2022-6-17`

script/vm/compiler.lua

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -554,32 +554,29 @@ local function getReturn(func, index, args)
554554
end
555555
return vm.compileNode(ast)
556556
end
557-
local node = vm.compileNode(func)
557+
local funcs = vm.getMatchedFunctions(func, args)
558558
---@type vm.node?
559559
local result
560-
for cnode in node:eachObject() do
561-
if cnode.type == 'function'
562-
or cnode.type == 'doc.type.function' then
563-
local returnObject = vm.getReturnOfFunction(cnode, index)
564-
if returnObject then
565-
local returnNode = vm.compileNode(returnObject)
560+
for _, mfunc in ipairs(funcs) do
561+
local returnObject = vm.getReturnOfFunction(mfunc, index)
562+
if returnObject then
563+
local returnNode = vm.compileNode(returnObject)
564+
for rnode in returnNode:eachObject() do
565+
if rnode.type == 'generic' then
566+
returnNode = rnode:resolve(guide.getUri(func), args)
567+
break
568+
end
569+
end
570+
if returnNode then
566571
for rnode in returnNode:eachObject() do
567-
if rnode.type == 'generic' then
568-
returnNode = rnode:resolve(guide.getUri(func), args)
569-
break
572+
-- TODO: narrow type
573+
if rnode.type ~= 'doc.generic.name' then
574+
result = result or vm.createNode()
575+
result:merge(rnode)
570576
end
571577
end
572-
if returnNode then
573-
for rnode in returnNode:eachObject() do
574-
-- TODO: narrow type
575-
if rnode.type ~= 'doc.generic.name' then
576-
result = result or vm.createNode()
577-
result:merge(rnode)
578-
end
579-
end
580-
if result and returnNode:isOptional() then
581-
result:addOptional()
582-
end
578+
if result and returnNode:isOptional() then
579+
result:addOptional()
583580
end
584581
end
585582
end

script/vm/function.lua

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
---@class vm
2+
local vm = require 'vm.vm'
3+
4+
---@param func parser.object
5+
---@return integer min
6+
---@return integer max
7+
function vm.countParamsOfFunction(func)
8+
local min = 0
9+
local max = 0
10+
if func.type == 'function'
11+
or func.type == 'doc.type.function' then
12+
if func.args then
13+
max = #func.args
14+
min = max
15+
for i = #func.args, 1, -1 do
16+
local arg = func.args[i]
17+
if arg.type == '...'
18+
or (arg.name and arg.name[1] =='...') then
19+
max = math.huge
20+
elseif not vm.compileNode(arg):isNullable() then
21+
min = i
22+
break
23+
end
24+
end
25+
end
26+
end
27+
return min, max
28+
end
29+
30+
---@param func parser.object
31+
---@return integer min
32+
---@return integer max
33+
function vm.countReturnsOfFunction(func)
34+
if func.type == 'function' then
35+
if not func.returns then
36+
return 0, 0
37+
end
38+
local min, max
39+
for _, ret in ipairs(func.returns) do
40+
local rmin, rmax = vm.countList(ret)
41+
if not min or rmin < min then
42+
min = rmin
43+
end
44+
if not max or rmax > max then
45+
max = rmax
46+
end
47+
end
48+
return min, max
49+
end
50+
if func.type == 'doc.type.function' then
51+
return vm.countList(func.returns)
52+
end
53+
return 0, 0
54+
end
55+
56+
---@param func parser.object
57+
---@return integer min
58+
---@return integer max
59+
function vm.countReturnsOfCall(func, args)
60+
local funcs = vm.getMatchedFunctions(func, args)
61+
local min
62+
local max
63+
for _, f in ipairs(funcs) do
64+
local rmin, rmax = vm.countReturnsOfFunction(f)
65+
if not min or rmin < min then
66+
min = rmin
67+
end
68+
if not max or rmax > max then
69+
max = rmax
70+
end
71+
end
72+
return min or 0, max or 0
73+
end
74+
75+
---@param list parser.object[]?
76+
---@return integer min
77+
---@return integer max
78+
function vm.countList(list)
79+
if not list then
80+
return 0, 0
81+
end
82+
local lastArg = list[#list]
83+
if not lastArg then
84+
return 0, 0
85+
end
86+
if lastArg.type == '...' then
87+
return #list - 1, math.huge
88+
end
89+
if lastArg.type == 'call' then
90+
local rmin, rmax = vm.countReturnsOfCall(lastArg.node, lastArg.args)
91+
return #list - 1 + rmin, #list - 1 + rmax
92+
end
93+
return #list, #list
94+
end
95+
96+
---@param func parser.object
97+
---@param args parser.object[]?
98+
---@return parser.object[]
99+
function vm.getMatchedFunctions(func, args)
100+
local funcs = {}
101+
local node = vm.compileNode(func)
102+
for n in node:eachObject() do
103+
if n.type == 'function'
104+
or n.type == 'doc.type.function' then
105+
funcs[#funcs+1] = n
106+
end
107+
end
108+
if #funcs <= 1 then
109+
return funcs
110+
end
111+
112+
local amin, amax = vm.countList(args)
113+
114+
local matched = {}
115+
for _, n in ipairs(funcs) do
116+
local min, max = vm.countParamsOfFunction(n)
117+
if amin >= min and amax <= max then
118+
matched[#matched+1] = n
119+
end
120+
end
121+
122+
if #matched == 0 then
123+
return funcs
124+
else
125+
return matched
126+
end
127+
end

script/vm/init.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ require 'vm.generic'
1717
require 'vm.sign'
1818
require 'vm.local-id'
1919
require 'vm.global'
20+
require 'vm.function'
2021
return vm

test/type_inference/init.lua

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,3 +2835,128 @@ local b
28352835
28362836
local <?x?> = echo(b)
28372837
]]
2838+
2839+
TEST 'boolean' [[
2840+
---@return boolean
2841+
function f()
2842+
end
2843+
2844+
---@param x integer
2845+
---@return number
2846+
function f(x)
2847+
end
2848+
2849+
local <?x?> = f()
2850+
]]
2851+
2852+
TEST 'number' [[
2853+
---@return boolean
2854+
function f()
2855+
end
2856+
2857+
---@param x integer
2858+
---@return number
2859+
function f(x)
2860+
end
2861+
2862+
local <?x?> = f(1)
2863+
]]
2864+
2865+
TEST 'boolean' [[
2866+
---@return boolean
2867+
function f()
2868+
end
2869+
2870+
---@param x integer
2871+
---@return number
2872+
function f(x)
2873+
end
2874+
2875+
function r0()
2876+
return
2877+
end
2878+
2879+
local <?x?> = f(r0())
2880+
]]
2881+
2882+
TEST 'number' [[
2883+
---@return boolean
2884+
function f()
2885+
end
2886+
2887+
---@param x integer
2888+
---@return number
2889+
function f(x)
2890+
end
2891+
2892+
function r1()
2893+
return 1
2894+
end
2895+
2896+
local <?x?> = f(r1())
2897+
]]
2898+
2899+
TEST 'boolean' [[
2900+
---@return boolean
2901+
function f()
2902+
end
2903+
2904+
---@param x integer
2905+
---@return number
2906+
function f(x)
2907+
end
2908+
2909+
---@type fun()
2910+
local r0
2911+
2912+
local <?x?> = f(r0())
2913+
]]
2914+
2915+
TEST 'number' [[
2916+
---@return boolean
2917+
function f()
2918+
end
2919+
2920+
---@param x integer
2921+
---@return number
2922+
function f(x)
2923+
end
2924+
2925+
---@type fun():integer
2926+
local r1
2927+
2928+
local <?x?> = f(r1())
2929+
]]
2930+
2931+
TEST 'boolean' [[
2932+
---@overload fun(number, number):string
2933+
---@overload fun(number):number
2934+
---@return boolean
2935+
local function f() end
2936+
2937+
local <?n1?> = f()
2938+
local n2 = f(0)
2939+
local n3 = f(0, 0)
2940+
]]
2941+
2942+
TEST 'number' [[
2943+
---@overload fun(number, number):string
2944+
---@overload fun(number):number
2945+
---@return boolean
2946+
local function f() end
2947+
2948+
local n1 = f()
2949+
local <?n2?> = f(0)
2950+
local n3 = f(0, 0)
2951+
]]
2952+
2953+
TEST 'string' [[
2954+
---@overload fun(number, number):string
2955+
---@overload fun(number):number
2956+
---@return boolean
2957+
local function f() end
2958+
2959+
local n1 = f()
2960+
local n2 = f(0)
2961+
local <?n3?> = f(0, 0)
2962+
]]

0 commit comments

Comments
 (0)