Skip to content

Commit

Permalink
more refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hvitved committed Mar 3, 2025
1 parent 3757199 commit f4d6920
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 102 deletions.
146 changes: 77 additions & 69 deletions rust/ql/lib/codeql/rust/elements/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ private module Input1 implements InputSig1<Location> {
class TypeParameter = Types::TypeParameter;

private newtype TTypeArgumentPosition =
// method type parameters are matched by position instead of by type
// parameter entity, to avoid extra recursion through method call resolution
TMethodTypeArgumentPosition(int pos) {
exists(any(MethodCallExpr mce).getGenericArgList().getTypeArgument(pos))
} or
Expand Down Expand Up @@ -535,6 +537,7 @@ private module Input2 implements InputSig2 {
private import Input2
import Make2<Input2>

/** Gets the type annotation that applies to `n`, if any. */
private TypeMention getTypeAnnotation(AstNode n) {
exists(LetStmt let |
n = let.getPat() and
Expand Down Expand Up @@ -650,7 +653,7 @@ private Type resolveImplicitSelfType(SelfParam self, TypePath path) {
* A matching configuration for resolving types of record expressions
* like `Foo { bar = baz }`.
*/
private module RecordFieldMatchingInput implements MatchingInputSig {
private module RecordExprMatchingInput implements MatchingInputSig {
abstract class Declaration extends AstNode {
abstract TypeParam getATypeParam();

Expand All @@ -660,12 +663,31 @@ private module RecordFieldMatchingInput implements MatchingInputSig {
}

abstract RecordField getField(string name);

Type getDeclaredType(DeclarationPosition pos, TypePath path) {
exists(TypeReprMention tp |
tp = this.getField(pos.asFieldPos()).getTypeRepr() and
result = tp.resolveTypeAt(path)
)
or
pos = TDeclPos() and
result = this.getTypeParameter(_) and
path = typePath(result)
}
}

private class StructDecl extends Declaration, Struct {
override TypeParam getATypeParam() { result = this.getGenericParamList().getATypeParam() }

override RecordField getField(string name) { result = this.getRecordField(name) }

override Type getDeclaredType(DeclarationPosition pos, TypePath path) {
result = super.getDeclaredType(pos, path)
or
pos = TDeclPos() and
path.isEmpty() and
result = TStruct(this)
}
}

private class VariantDecl extends Declaration, Variant {
Expand All @@ -676,6 +698,14 @@ private module RecordFieldMatchingInput implements MatchingInputSig {
}

override RecordField getField(string name) { result = this.getRecordField(name) }

override Type getDeclaredType(DeclarationPosition pos, TypePath path) {
result = super.getDeclaredType(pos, path)
or
pos = TDeclPos() and
path.isEmpty() and
result = TEnum(this.getEnum())
}
}

class Access extends RecordExpr {
Expand Down Expand Up @@ -717,32 +747,14 @@ private module RecordFieldMatchingInput implements MatchingInputSig {
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition ppos) {
apos = ppos
}

predicate parameterType(Declaration decl, DeclarationPosition pos, TypePath path, Type t) {
exists(TypeReprMention tp | tp = decl.getField(pos.asFieldPos()).getTypeRepr() |
t = tp.resolveTypeAt(path)
)
or
pos = TDeclPos() and
(
t = TStruct(decl) and
path.isEmpty()
or
t = TEnum(decl.(VariantDecl).getEnum()) and
path.isEmpty()
or
t = decl.getTypeParameter(_) and
path = typePath(t)
)
}
}

private module RecordFieldMatching = Matching<RecordFieldMatchingInput>;
private module RecordExprMatching = Matching<RecordExprMatchingInput>;

private Type resolveRecordExprType(AstNode n, TypePath path) {
exists(RecordFieldMatchingInput::Access a, RecordFieldMatchingInput::AccessPosition apos |
exists(RecordExprMatchingInput::Access a, RecordExprMatchingInput::AccessPosition apos |
n = a.getNode(apos) and
result = RecordFieldMatching::resolveAccessType(a, apos, _, path)
result = RecordExprMatching::resolveAccessType(a, apos, _, path)
)
}

Expand All @@ -761,7 +773,7 @@ private Type resolvePathExprType(PathExpr pe, TypePath path) {
* A matching configuration for resolving types of call expressions
* like `foo::bar(baz)` and `foo.bar(baz)`.
*/
private module FunctionMatchingInput implements MatchingInputSig {
private module CallExprBaseMatchingInput implements MatchingInputSig {
private import codeql.util.Boolean

private predicate positionalDeclarationPosition(ParamList pl, Param p, int pos, boolean inMethod) {
Expand Down Expand Up @@ -795,17 +807,17 @@ private module FunctionMatchingInput implements MatchingInputSig {
}
}

private predicate argPos(CallExprBase call, Expr e, int pos, boolean inMethod) {
private predicate argPos(CallExprBase call, Expr e, int pos, boolean isMethodCall) {
exists(ArgList al |
e = al.getArg(pos) and
call.getArgList() = al and
if call instanceof MethodCallExpr then inMethod = true else inMethod = false
if call instanceof MethodCallExpr then isMethodCall = true else isMethodCall = false
)
}

private newtype TAccessPosition =
TSelfAccessPosition() or
TPositionalAccessPosition(int pos, Boolean inMethod) { argPos(_, _, pos, inMethod) } or
TPositionalAccessPosition(int pos, Boolean isMethodCall) { argPos(_, _, pos, isMethodCall) } or
TReturnAccessPosition()

class AccessPosition extends TAccessPosition {
Expand All @@ -815,8 +827,8 @@ private module FunctionMatchingInput implements MatchingInputSig {
this = TSelfAccessPosition() and
result = "self"
or
exists(int pos, Boolean inMethod |
this = TPositionalAccessPosition(pos, inMethod) and
exists(int pos |
this = TPositionalAccessPosition(pos, _) and
result = pos.toString()
)
or
Expand Down Expand Up @@ -858,6 +870,13 @@ private module FunctionMatchingInput implements MatchingInputSig {
abstract Type getParameterType(DeclarationPosition pos, TypePath path);

abstract Type getReturnType(TypePath path);

Type getDeclaredType(DeclarationPosition pos, TypePath path) {
result = this.getParameterType(pos, path)
or
pos = TReturnDeclarationPosition() and
result = this.getReturnType(path)
}
}

private class StructDecl extends Declaration, Struct {
Expand Down Expand Up @@ -951,9 +970,9 @@ private module FunctionMatchingInput implements MatchingInputSig {
}

AstNode getNode(AccessPosition apos) {
exists(int p, boolean inMethod |
argPos(this, result, p, inMethod) and
apos = TPositionalAccessPosition(p, inMethod)
exists(int p, boolean isMethodCall |
argPos(this, result, p, isMethodCall) and
apos = TPositionalAccessPosition(p, isMethodCall)
)
or
result = this.(MethodCallExpr).getReceiver() and
Expand All @@ -978,13 +997,11 @@ private module FunctionMatchingInput implements MatchingInputSig {
}
}

bindingset[a, apos, target, path, t]
bindingset[apos, target, path, t]
pragma[inline_late]
predicate adjustArgType(
Access a, AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj,
Type tAdj
predicate adjustAccessType(
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
) {
exists(a) and
if apos.isSelf() and target.getParameterType(TSelfDeclarationPosition(), "") = TRefType()
then
// implicit borrow
Expand Down Expand Up @@ -1023,20 +1040,13 @@ private module FunctionMatchingInput implements MatchingInputSig {
tAdj = t
)
}

predicate parameterType(Declaration decl, DeclarationPosition pos, TypePath path, Type t) {
t = decl.getParameterType(pos, path)
or
pos = TReturnDeclarationPosition() and
t = decl.getReturnType(path)
}
}

private module FunctionMatching = Matching<FunctionMatchingInput>;
private module CallExprBaseMatching = Matching<CallExprBaseMatchingInput>;

pragma[nomagic]
private Type resolveReceiverType(AstNode n) {
exists(FunctionMatchingInput::Access a, FunctionMatchingInput::AccessPosition apos |
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
result = resolveType(n) and
n = a.getNode(apos) and
apos.isSelf()
Expand All @@ -1046,11 +1056,11 @@ private Type resolveReceiverType(AstNode n) {
pragma[nomagic]
private Type resolveCallExprBaseType(AstNode n, TypePath path) {
exists(
FunctionMatchingInput::Access a, FunctionMatchingInput::AccessPosition apos,
FunctionMatchingInput::DeclarationPosition ppos, TypePath path0, Type res
CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos,
CallExprBaseMatchingInput::DeclarationPosition ppos, TypePath path0, Type res

Check warning

Code scanning / CodeQL

Omittable 'exists' variable Warning

This exists variable can be omitted by using a don't-care expression
in this argument
.
|
n = a.getNode(apos) and
res = FunctionMatching::resolveAccessType(a, apos, ppos, path0)
res = CallExprBaseMatching::resolveAccessType(a, apos, ppos, path0)
|
result = res and
path = path0 and
Expand Down Expand Up @@ -1094,6 +1104,22 @@ private module FieldExprMatchingInput implements MatchingInputSig {
TypeParameter getTypeParameter(TypeParameterPosition ppos) { none() }

Check warning

Code scanning / CodeQL

Use of 'if' with a 'none()' branch. Warning

Use a conjunction instead.

abstract TypeRepr getTypeRepr();

Type getDeclaredType(DeclarationPosition pos, TypePath path) {
pos = TSelfDeclarationPosition() and
exists(Struct s | s.getRecordField(_) = this or s.getTupleField(_) = this |
result = TStruct(s) and
path.isEmpty()
or
exists(int i |

Check warning

Code scanning / CodeQL

Omittable 'exists' variable Warning

This exists variable can be omitted by using a don't-care expression
in this argument
.
result = TTypeParamTypeParameter(s.getGenericParamList().getTypeParam(i)) and
path = typePath(result)
)
)
or
pos = TReturnPos() and
result = this.getTypeRepr().(TypeReprMention).resolveTypeAt(path)
}
}

private class RecordFieldDecl extends Declaration instanceof RecordField {
Expand Down Expand Up @@ -1125,13 +1151,11 @@ private module FieldExprMatchingInput implements MatchingInputSig {
}
}

bindingset[a, apos, target, path, t]
bindingset[apos, target, path, t]
pragma[inline_late]
predicate adjustArgType(
Access a, AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj,
Type tAdj
predicate adjustAccessType(
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
) {
exists(a) and
exists(target) and
if apos.isSelf()
then
Expand Down Expand Up @@ -1175,22 +1199,6 @@ private module FieldExprMatchingInput implements MatchingInputSig {
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition ppos) {
apos = ppos
}

predicate parameterType(Declaration decl, DeclarationPosition pos, TypePath path, Type t) {
pos = TSelfDeclarationPosition() and
exists(Struct s | s.getRecordField(_) = decl or s.getTupleField(_) = decl |
t = TStruct(s) and
path.isEmpty()
or
exists(int i |
t = TTypeParamTypeParameter(s.getGenericParamList().getTypeParam(i)) and
path = typePath(t)
)
)
or
pos = TReturnPos() and
t = decl.getTypeRepr().(TypeReprMention).resolveTypeAt(path)
}
}

Check warning

Code scanning / CodeQL

Omittable 'exists' variable Warning

This exists variable can be omitted by using a don't-care expression
in this argument
.

private module FieldExprMatching = Matching<FieldExprMatchingInput>;
Expand Down
Loading

0 comments on commit f4d6920

Please # to comment.