From 37571999374bd00e63a65e797a033a7fc7900501 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Mon, 3 Mar 2025 14:06:26 +0100 Subject: [PATCH] refactor --- .../rust/elements/internal/TypeInference.qll | 101 +++++----- .../elements/internal/TypeInferenceShared.qll | 175 ++++++++---------- 2 files changed, 131 insertions(+), 145 deletions(-) diff --git a/rust/ql/lib/codeql/rust/elements/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/elements/internal/TypeInference.qll index 3083bb46da02..d2b744828151 100644 --- a/rust/ql/lib/codeql/rust/elements/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/elements/internal/TypeInference.qll @@ -306,7 +306,6 @@ private module Types { import Types private module Input1 implements InputSig1 { - private import rust as Rust private import codeql.rust.elements.internal.generated.Raw private import codeql.rust.elements.internal.generated.Synth @@ -361,8 +360,6 @@ private module Input1 implements InputSig1 { tp0 order by kind, id ) } - - class AstNode = Rust::AstNode; } private import Input1 @@ -516,15 +513,11 @@ private module TypeMentions { private import TypeMentions -private predicate resolveTypeAlias = resolveType/2; - private module Input2 implements InputSig2 { class TypeMention = TypeMentions::TypeMention; TypeMention getABaseTypeMention(Type t) { result = t.getABaseTypeMention() } - predicate resolveType = resolveTypeAlias/2; - additional TypeMention getExplicitTypeArgMention(Path path, TypeParam tp) { exists(int i | result = path.getPart().getGenericArgList().getTypeArgument(i) and @@ -690,6 +683,17 @@ private module RecordFieldMatchingInput implements MatchingInputSig { result = getExplicitTypeArgMention(this.getPath(), apos.asTypeParam()).resolveTypeAt(path) } + AstNode getNode(AccessPosition apos) { + result = this.getFieldExpr(apos.asFieldPos()).getExpr() + or + result = this and + apos = TDeclPos() + } + + Type getResolvedType(AccessPosition apos, TypePath path) { + result = resolveType(this.getNode(apos), path) + } + Declaration getTarget() { result = resolvePath(this.getPath()) } } @@ -714,13 +718,6 @@ private module RecordFieldMatchingInput implements MatchingInputSig { apos = ppos } - AstNode getArg(Access a, AccessPosition pos) { - result = a.getFieldExpr(pos.asFieldPos()).getExpr() - or - result = a and - pos = TDeclPos() - } - predicate parameterType(Declaration decl, DeclarationPosition pos, TypePath path, Type t) { exists(TypeReprMention tp | tp = decl.getField(pos.asFieldPos()).getTypeRepr() | t = tp.resolveTypeAt(path) @@ -743,7 +740,10 @@ private module RecordFieldMatchingInput implements MatchingInputSig { private module RecordFieldMatching = Matching; private Type resolveRecordExprType(AstNode n, TypePath path) { - result = RecordFieldMatching::resolveArgType(n, _, _, path) + exists(RecordFieldMatchingInput::Access a, RecordFieldMatchingInput::AccessPosition apos | + n = a.getNode(apos) and + result = RecordFieldMatching::resolveAccessType(a, apos, _, path) + ) } pragma[nomagic] @@ -950,6 +950,23 @@ private module FunctionMatchingInput implements MatchingInputSig { ) } + AstNode getNode(AccessPosition apos) { + exists(int p, boolean inMethod | + argPos(this, result, p, inMethod) and + apos = TPositionalAccessPosition(p, inMethod) + ) + or + result = this.(MethodCallExpr).getReceiver() and + apos = TSelfAccessPosition() + or + result = this and + apos = TReturnAccessPosition() + } + + Type getResolvedType(AccessPosition apos, TypePath path) { + result = resolveType(this.getNode(apos), path) + } + Declaration getTarget() { result = [ @@ -1007,19 +1024,6 @@ private module FunctionMatchingInput implements MatchingInputSig { ) } - AstNode getArg(Access a, AccessPosition pos) { - exists(int p, boolean inMethod | - argPos(a, result, p, inMethod) and - pos = TPositionalAccessPosition(p, inMethod) - ) - or - result = a.(MethodCallExpr).getReceiver() and - pos = TSelfAccessPosition() - or - result = a and - pos = TReturnAccessPosition() - } - predicate parameterType(Declaration decl, DeclarationPosition pos, TypePath path, Type t) { t = decl.getParameterType(pos, path) or @@ -1032,9 +1036,9 @@ private module FunctionMatching = Matching; pragma[nomagic] private Type resolveReceiverType(AstNode n) { - exists(FunctionMatchingInput::AccessPosition apos | + exists(FunctionMatchingInput::Access a, FunctionMatchingInput::AccessPosition apos | result = resolveType(n) and - n = FunctionMatchingInput::getArg(_, apos) and + n = a.getNode(apos) and apos.isSelf() ) } @@ -1042,16 +1046,18 @@ private Type resolveReceiverType(AstNode n) { pragma[nomagic] private Type resolveCallExprBaseType(AstNode n, TypePath path) { exists( - FunctionMatchingInput::AccessPosition apos, FunctionMatchingInput::DeclarationPosition ppos + FunctionMatchingInput::Access a, FunctionMatchingInput::AccessPosition apos, + FunctionMatchingInput::DeclarationPosition ppos, TypePath path0, Type res + | + n = a.getNode(apos) and + res = FunctionMatching::resolveAccessType(a, apos, ppos, path0) | - result = FunctionMatching::resolveArgType(n, apos, ppos, path) and + result = res and + path = path0 and not apos.isSelf() or apos.isSelf() and - exists(Type receiverType, Type res, TypePath path0 | - receiverType = resolveReceiverType(n) and - res = FunctionMatching::resolveArgType(n, apos, ppos, path0) - | + exists(Type receiverType | receiverType = resolveReceiverType(n) | if receiverType = TRefType() then result = res and @@ -1101,6 +1107,17 @@ private module FieldExprMatchingInput implements MatchingInputSig { class Access extends FieldExpr { Type getExplicitTypeArgument(TypeArgumentPosition apos, TypePath path) { none() } + AstNode getNode(AccessPosition apos) { + result = this.getExpr() and apos = TSelfDeclarationPosition() + or + result = this and + apos = TReturnPos() + } + + Type getResolvedType(AccessPosition apos, TypePath path) { + result = resolveType(this.getNode(apos), path) + } + Declaration getTarget() { result = resolveRecordFieldExpr(this) or @@ -1159,13 +1176,6 @@ private module FieldExprMatchingInput implements MatchingInputSig { apos = ppos } - AstNode getArg(Access a, AccessPosition pos) { - result = a.getExpr() and pos = TSelfDeclarationPosition() - or - result = a and - pos = TReturnPos() - } - predicate parameterType(Declaration decl, DeclarationPosition pos, TypePath path, Type t) { pos = TSelfDeclarationPosition() and exists(Struct s | s.getRecordField(_) = decl or s.getTupleField(_) = decl | @@ -1186,8 +1196,9 @@ private module FieldExprMatchingInput implements MatchingInputSig { private module FieldExprMatching = Matching; private Type resolveFieldExprType(AstNode n, TypePath path) { - exists(FieldExprMatchingInput::AccessPosition apos | - result = FieldExprMatching::resolveArgType(n, apos, _, path) and + exists(FieldExprMatchingInput::Access a, FieldExprMatchingInput::AccessPosition apos | + n = a.getNode(apos) and + result = FieldExprMatching::resolveAccessType(a, apos, _, path) and apos.isReturn() ) } diff --git a/rust/ql/lib/codeql/rust/elements/internal/TypeInferenceShared.qll b/rust/ql/lib/codeql/rust/elements/internal/TypeInferenceShared.qll index f93a03f6dc49..e498c4c15362 100644 --- a/rust/ql/lib/codeql/rust/elements/internal/TypeInferenceShared.qll +++ b/rust/ql/lib/codeql/rust/elements/internal/TypeInferenceShared.qll @@ -64,23 +64,12 @@ signature module InputSig1 { bindingset[apos] bindingset[ppos] predicate typeArgumentParameterPositionMatch(TypeArgumentPosition apos, TypeParameterPosition ppos); - - /** An AST node. */ - class AstNode { - /** Gets a textual representation of this node. */ - string toString(); - - /** Gets the location of this node. */ - Location getLocation(); - } } module Make1 Input1> { private import Input1 private import codeql.util.DenseRank - final private class AstNodeFinal = AstNode; - private module DenseRankInput implements DenseRankInputSig { class Ranked = TypeParameter; @@ -227,9 +216,6 @@ module Make1 Input1> { * the class ``C`1`` has the base type mention of `Base`. */ TypeMention getABaseTypeMention(Type t); - - /** Gets the type that `n` resolves to at `path`. */ - Type resolveType(AstNode n, TypePath path); } module Make2 { @@ -238,10 +224,7 @@ module Make1 Input1> { pragma[nomagic] private Type resolveTypeMentionRoot(TypeMention tm) { result = tm.resolveTypeAt("") } - /** - * Provides the parameterized module `NodeBaseType` for computing base types - * of AST nodes. - */ + /** Provides logic for computing base types. */ private module BaseTypes { /** * Holds if `baseMention` is a (transitive) base type mention of `sub`, @@ -265,7 +248,7 @@ module Make1 Input1> { * - `T4` is mentioned at `0.0.0` for transitive base type `Base` of `Sub`. */ pragma[nomagic] - private predicate baseTypeMentionHasTypeParameterAt( + predicate baseTypeMentionHasTypeParameterAt( Type sub, TypeMention baseMention, TypePath path, TypeParameter tp ) { exists(TypeMention immediateBaseMention, TypePath pathToTypeParam | @@ -318,7 +301,7 @@ module Make1 Input1> { * - ``C`1`` is mentioned at `0` and `0.0` for transitive base type `Base` of `Sub`. */ pragma[nomagic] - private predicate baseTypeMentionHasNonTypeParameterAt( + predicate baseTypeMentionHasNonTypeParameterAt( Type sub, TypeMention baseMention, TypePath path, Type t ) { not t instanceof TypeParameter and @@ -354,63 +337,6 @@ module Make1 Input1> { ) ) } - - /** Holds if `n` is relevant for computing base type information. */ - signature predicate relevantNodeSig(AstNode n); - - module NodeBaseType { - private class ReleventNode extends AstNodeFinal { - ReleventNode() { relevantNode(this) } - } - - pragma[nomagic] - private Type resolveRootType(ReleventNode n) { result = resolveType(n, "") } - - pragma[nomagic] - private Type resolveTypeAt(ReleventNode n, TypeParameter tp, TypePath suffix) { - exists(TypePath path0 | - result = resolveType(n, path0) and - path0.startsWith(tp, suffix) - ) - } - - /** - * Holds if `baseMention` is a (transitive) base type mention of the type of - * `n`, and `t` is mentioned (implicitly) at `path` inside `base`. For example, - * in - * - * ```csharp - * class C { } - * - * class Base { } - * - * class Mid : Base> { } - * - * class Sub : Mid> { } - * - * new Sub(); - * ``` - * - * for the node `new Sub()`: - * - * - ``C`1`` is mentioned at `0` for immediate base type ``Mid`1``, - * - `int` is mentioned at `0.1` for immediate base type ``Mid`1``, - * - ``C`1`` is mentioned at `0` and `0.0` for transitive base type ``Base`1``, and - * - `int` is mentioned at `0.0.1` for transitive base type ``Base`1``. - */ - pragma[nomagic] - predicate hasBaseType(ReleventNode n, TypeMention baseMention, TypePath path, Type t) { - exists(Type sub | sub = resolveRootType(n) | - baseTypeMentionHasNonTypeParameterAt(sub, baseMention, path, t) - or - exists(TypePath prefix, TypePath suffix, TypeParameter i | - baseTypeMentionHasTypeParameterAt(sub, baseMention, prefix, i) and - t = resolveTypeAt(n, i, suffix) and - path = prefix.append(suffix) - ) - ) - } - } } private import BaseTypes @@ -444,6 +370,14 @@ module Make1 Input1> { */ Type getExplicitTypeArgument(TypeArgumentPosition apos, TypePath path); + /** + * Gets the type at `path` at argument position `apos` of this access. + * + * For example, if this access is the method call `M(42)`, then the type + * at position `0` is `int`. + */ + Type getResolvedType(AccessPosition apos, TypePath path); + /** Gets the declaration that this access targets. */ Declaration getTarget(); } @@ -484,8 +418,6 @@ module Make1 Input1> { tAdj = t } - AstNode getArg(Access a, AccessPosition apos); - predicate parameterType(Declaration decl, DeclarationPosition ppos, TypePath path, Type t); } @@ -497,14 +429,11 @@ module Make1 Input1> { private import Input pragma[nomagic] - private predicate argumentType0( + private predicate argumentTypeUnadjusted( Access a, AccessPosition apos, Declaration target, TypePath path, Type t ) { target = a.getTarget() and - exists(AstNode arg | - arg = getArg(a, apos) and - t = resolveType(arg, path) - ) + t = a.getResolvedType(apos, path) } pragma[nomagic] @@ -512,7 +441,7 @@ module Make1 Input1> { Access a, AccessPosition apos, Declaration target, TypePath path, Type t ) { exists(TypePath path0, Type t0 | - argumentType0(a, apos, target, path0, t0) and + argumentTypeUnadjusted(a, apos, target, path0, t0) and adjustArgType(a, apos, target, path0, t0, path, t) ) } @@ -546,21 +475,77 @@ module Make1 Input1> { ) } - private predicate relevantNode(AstNode n) { - exists(Access a, AccessPosition apos, Declaration target | - n = getArg(a, apos) and + private predicate relevantNode(Access a, AccessPosition apos) { + exists(Declaration target | + // n = getArg(a, apos) and argumentType(a, apos, target, _, _) and parameterType(target, _, _, any(TypeParameter tp)) ) } + private module Foo { + pragma[nomagic] + private Type resolveRootType(Access a, AccessPosition apos) { + relevantNode(a, apos) and + result = a.getResolvedType(apos, "") + } + + pragma[nomagic] + private Type resolveTypeAt(Access a, AccessPosition apos, TypeParameter tp, TypePath suffix) { + relevantNode(a, apos) and + exists(TypePath path0 | + result = a.getResolvedType(apos, path0) and + path0.startsWith(tp, suffix) + ) + } + + /** + * Holds if `baseMention` is a (transitive) base type mention of the type of + * `n`, and `t` is mentioned (implicitly) at `path` inside `base`. For example, + * in + * + * ```csharp + * class C { } + * + * class Base { } + * + * class Mid : Base> { } + * + * class Sub : Mid> { } + * + * new Sub(); + * ``` + * + * for the node `new Sub()`: + * + * - ``C`1`` is mentioned at `0` for immediate base type ``Mid`1``, + * - `int` is mentioned at `0.1` for immediate base type ``Mid`1``, + * - ``C`1`` is mentioned at `0` and `0.0` for transitive base type ``Base`1``, and + * - `int` is mentioned at `0.0.1` for transitive base type ``Base`1``. + */ + pragma[nomagic] + predicate hasBaseType( + Access a, AccessPosition apos, TypeMention baseMention, TypePath path, Type t + ) { + exists(Type sub | sub = resolveRootType(a, apos) | + baseTypeMentionHasNonTypeParameterAt(sub, baseMention, path, t) + or + exists(TypePath prefix, TypePath suffix, TypeParameter i | + baseTypeMentionHasTypeParameterAt(sub, baseMention, prefix, i) and + t = resolveTypeAt(a, apos, i, suffix) and + path = prefix.append(suffix) + ) + ) + } + } + pragma[nomagic] private predicate argumentBaseTypeAt( Access a, AccessPosition apos, Declaration target, Type base, TypePath path, Type t ) { exists(TypeMention tm | target = a.getTarget() and - NodeBaseType::hasBaseType(getArg(a, apos), tm, path, t) and + Foo::hasBaseType(a, apos, tm, path, t) and base = resolveTypeMentionRoot(tm) ) } @@ -620,9 +605,7 @@ module Make1 Input1> { } pragma[nomagic] - private Type resolveAccess( - Access a, AccessPosition apos, DeclarationPosition ppos, TypePath path - ) { + Type resolveAccessType(Access a, AccessPosition apos, DeclarationPosition ppos, TypePath path) { accessDeclarationPositionMatch(apos, ppos) and ( exists(Declaration target, TypePath prefix, TypeParameter tp, TypePath suffix | @@ -638,14 +621,6 @@ module Make1 Input1> { ) ) } - - pragma[nomagic] - Type resolveArgType(AstNode arg, AccessPosition apos, DeclarationPosition ppos, TypePath path) { - exists(Access a | - arg = getArg(a, apos) and - result = resolveAccess(a, apos, ppos, path) - ) - } } /** Provides consitency checks. */