Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hvitved committed Mar 3, 2025
1 parent 432d2c8 commit 3757199
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 145 deletions.
101 changes: 56 additions & 45 deletions rust/ql/lib/codeql/rust/elements/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ private module Types {
import Types

private module Input1 implements InputSig1<Location> {
private import rust as Rust
private import codeql.rust.elements.internal.generated.Raw
private import codeql.rust.elements.internal.generated.Synth

Expand Down Expand Up @@ -361,8 +360,6 @@ private module Input1 implements InputSig1<Location> {
tp0 order by kind, id
)
}

class AstNode = Rust::AstNode;
}

private import Input1
Expand Down Expand Up @@ -516,15 +513,11 @@ private module TypeMentions {

private import TypeMentions

private predicate resolveTypeAlias = resolveType/2;

private module Input2 implements InputSig2 {
class TypeMention = TypeMentions::TypeMention;

TypeMention getABaseTypeMention(Type t) { result = t.getABaseTypeMention() }

predicate resolveType = resolveTypeAlias/2;

additional TypeMention getExplicitTypeArgMention(Path path, TypeParam tp) {
exists(int i |
result = path.getPart().getGenericArgList().getTypeArgument(i) and
Expand Down Expand Up @@ -690,6 +683,17 @@ private module RecordFieldMatchingInput implements MatchingInputSig {
result = getExplicitTypeArgMention(this.getPath(), apos.asTypeParam()).resolveTypeAt(path)
}

AstNode getNode(AccessPosition apos) {
result = this.getFieldExpr(apos.asFieldPos()).getExpr()
or
result = this and
apos = TDeclPos()
}

Type getResolvedType(AccessPosition apos, TypePath path) {
result = resolveType(this.getNode(apos), path)
}

Declaration getTarget() { result = resolvePath(this.getPath()) }
}

Expand All @@ -714,13 +718,6 @@ private module RecordFieldMatchingInput implements MatchingInputSig {
apos = ppos
}

AstNode getArg(Access a, AccessPosition pos) {
result = a.getFieldExpr(pos.asFieldPos()).getExpr()
or
result = a and
pos = TDeclPos()
}

predicate parameterType(Declaration decl, DeclarationPosition pos, TypePath path, Type t) {
exists(TypeReprMention tp | tp = decl.getField(pos.asFieldPos()).getTypeRepr() |
t = tp.resolveTypeAt(path)
Expand All @@ -743,7 +740,10 @@ private module RecordFieldMatchingInput implements MatchingInputSig {
private module RecordFieldMatching = Matching<RecordFieldMatchingInput>;

private Type resolveRecordExprType(AstNode n, TypePath path) {
result = RecordFieldMatching::resolveArgType(n, _, _, path)
exists(RecordFieldMatchingInput::Access a, RecordFieldMatchingInput::AccessPosition apos |
n = a.getNode(apos) and
result = RecordFieldMatching::resolveAccessType(a, apos, _, path)
)
}

pragma[nomagic]
Expand Down Expand Up @@ -950,6 +950,23 @@ private module FunctionMatchingInput implements MatchingInputSig {
)
}

AstNode getNode(AccessPosition apos) {
exists(int p, boolean inMethod |
argPos(this, result, p, inMethod) and
apos = TPositionalAccessPosition(p, inMethod)
)
or
result = this.(MethodCallExpr).getReceiver() and
apos = TSelfAccessPosition()
or
result = this and
apos = TReturnAccessPosition()
}

Type getResolvedType(AccessPosition apos, TypePath path) {
result = resolveType(this.getNode(apos), path)
}

Declaration getTarget() {
result =
[
Expand Down Expand Up @@ -1007,19 +1024,6 @@ private module FunctionMatchingInput implements MatchingInputSig {
)
}

AstNode getArg(Access a, AccessPosition pos) {
exists(int p, boolean inMethod |
argPos(a, result, p, inMethod) and
pos = TPositionalAccessPosition(p, inMethod)
)
or
result = a.(MethodCallExpr).getReceiver() and
pos = TSelfAccessPosition()
or
result = a and
pos = TReturnAccessPosition()
}

predicate parameterType(Declaration decl, DeclarationPosition pos, TypePath path, Type t) {
t = decl.getParameterType(pos, path)
or
Expand All @@ -1032,26 +1036,28 @@ private module FunctionMatching = Matching<FunctionMatchingInput>;

pragma[nomagic]
private Type resolveReceiverType(AstNode n) {
exists(FunctionMatchingInput::AccessPosition apos |
exists(FunctionMatchingInput::Access a, FunctionMatchingInput::AccessPosition apos |
result = resolveType(n) and
n = FunctionMatchingInput::getArg(_, apos) and
n = a.getNode(apos) and
apos.isSelf()
)
}

pragma[nomagic]
private Type resolveCallExprBaseType(AstNode n, TypePath path) {
exists(
FunctionMatchingInput::AccessPosition apos, FunctionMatchingInput::DeclarationPosition ppos
FunctionMatchingInput::Access a, FunctionMatchingInput::AccessPosition apos,
FunctionMatchingInput::DeclarationPosition ppos, TypePath path0, Type res
|
n = a.getNode(apos) and
res = FunctionMatching::resolveAccessType(a, apos, ppos, path0)
|
result = FunctionMatching::resolveArgType(n, apos, ppos, path) and
result = res and
path = path0 and
not apos.isSelf()
or
apos.isSelf() and
exists(Type receiverType, Type res, TypePath path0 |
receiverType = resolveReceiverType(n) and
res = FunctionMatching::resolveArgType(n, apos, ppos, path0)
|
exists(Type receiverType | receiverType = resolveReceiverType(n) |
if receiverType = TRefType()
then
result = res and
Expand Down Expand Up @@ -1101,6 +1107,17 @@ private module FieldExprMatchingInput implements MatchingInputSig {
class Access extends FieldExpr {
Type getExplicitTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }

AstNode getNode(AccessPosition apos) {
result = this.getExpr() and apos = TSelfDeclarationPosition()
or
result = this and
apos = TReturnPos()
}

Type getResolvedType(AccessPosition apos, TypePath path) {
result = resolveType(this.getNode(apos), path)
}

Declaration getTarget() {
result = resolveRecordFieldExpr(this)
or
Expand Down Expand Up @@ -1159,13 +1176,6 @@ private module FieldExprMatchingInput implements MatchingInputSig {
apos = ppos
}

AstNode getArg(Access a, AccessPosition pos) {
result = a.getExpr() and pos = TSelfDeclarationPosition()
or
result = a and
pos = TReturnPos()
}

predicate parameterType(Declaration decl, DeclarationPosition pos, TypePath path, Type t) {
pos = TSelfDeclarationPosition() and
exists(Struct s | s.getRecordField(_) = decl or s.getTupleField(_) = decl |
Expand All @@ -1186,8 +1196,9 @@ private module FieldExprMatchingInput implements MatchingInputSig {
private module FieldExprMatching = Matching<FieldExprMatchingInput>;

private Type resolveFieldExprType(AstNode n, TypePath path) {
exists(FieldExprMatchingInput::AccessPosition apos |
result = FieldExprMatching::resolveArgType(n, apos, _, path) and
exists(FieldExprMatchingInput::Access a, FieldExprMatchingInput::AccessPosition apos |
n = a.getNode(apos) and
result = FieldExprMatching::resolveAccessType(a, apos, _, path) and
apos.isReturn()
)
}
Expand Down
Loading

0 comments on commit 3757199

Please # to comment.