Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ module Impl {
}
}

/** Holds if the call expression dispatches to a trait method. */
/** Holds if the call expression dispatches to a method. */
private predicate callIsMethodCall(CallExpr call, Path qualifier, string methodName) {
exists(Path path, Function f |
path = call.getFunction().(PathExpr).getPath() and
Expand Down
24 changes: 17 additions & 7 deletions rust/ql/lib/codeql/rust/internal/PathResolution.qll
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ abstract class ItemNode extends Locatable {
exists(ItemNode node |
this = node.(ImplItemNode).resolveSelfTy() and
result = node.getASuccessorRec(name) and
result instanceof AssocItemNode
result instanceof AssocItemNode and
not result instanceof TypeAlias
)
or
// trait items with default implementations made available in an implementation
Expand All @@ -181,6 +182,10 @@ abstract class ItemNode extends Locatable {
result = this.(TypeParamItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
or
result = this.(ImplTraitTypeReprItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
or
result = this.(TypeAliasItemNode).resolveAlias().getASuccessorRec(name) and
// type parameters defined in the RHS are not available in the LHS
not result instanceof TypeParam
}

/**
Expand Down Expand Up @@ -289,6 +294,8 @@ abstract class ItemNode extends Locatable {
Location getLocation() { result = super.getLocation() }
}

abstract class TypeItemNode extends ItemNode { }

/** A module or a source file. */
abstract private class ModuleLikeNode extends ItemNode {
/** Gets an item that may refer directly to items defined in this module. */
Expand Down Expand Up @@ -438,7 +445,7 @@ private class ConstItemNode extends AssocItemNode instanceof Const {
override TypeParam getTypeParam(int i) { none() }
}

private class EnumItemNode extends ItemNode instanceof Enum {
private class EnumItemNode extends TypeItemNode instanceof Enum {
override string getName() { result = Enum.super.getName().getText() }

override Namespace getNamespace() { result.isType() }
Expand Down Expand Up @@ -746,7 +753,7 @@ private class ModuleItemNode extends ModuleLikeNode instanceof Module {
}
}

private class StructItemNode extends ItemNode instanceof Struct {
private class StructItemNode extends TypeItemNode instanceof Struct {
override string getName() { result = Struct.super.getName().getText() }

override Namespace getNamespace() {
Expand Down Expand Up @@ -781,7 +788,7 @@ private class StructItemNode extends ItemNode instanceof Struct {
}
}

class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
class TraitItemNode extends ImplOrTraitItemNode, TypeItemNode instanceof Trait {
pragma[nomagic]
Path getABoundPath() {
result = super.getTypeBoundList().getABound().getTypeRepr().(PathTypeRepr).getPath()
Expand Down Expand Up @@ -838,7 +845,10 @@ class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
}
}

class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
class TypeAliasItemNode extends TypeItemNode, AssocItemNode instanceof TypeAlias {
pragma[nomagic]
ItemNode resolveAlias() { result = resolvePathFull(super.getTypeRepr().(PathTypeRepr).getPath()) }

override string getName() { result = TypeAlias.super.getName().getText() }

override predicate hasImplementation() { super.hasTypeRepr() }
Expand All @@ -854,7 +864,7 @@ class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
override string getCanonicalPath(Crate c) { none() }
}

private class UnionItemNode extends ItemNode instanceof Union {
private class UnionItemNode extends TypeItemNode instanceof Union {
override string getName() { result = Union.super.getName().getText() }

override Namespace getNamespace() { result.isType() }
Expand Down Expand Up @@ -912,7 +922,7 @@ private class BlockExprItemNode extends ItemNode instanceof BlockExpr {
override string getCanonicalPath(Crate c) { none() }
}

class TypeParamItemNode extends ItemNode instanceof TypeParam {
class TypeParamItemNode extends TypeItemNode instanceof TypeParam {
private WherePred getAWherePred() {
exists(ItemNode declaringItem |
this = resolveTypeParamPathTypeRepr(result.getTypeRepr()) and
Expand Down
19 changes: 0 additions & 19 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,6 @@ class TraitType extends Type, TTrait {

override TypeParameter getTypeParameter(int i) {
result = TTypeParamTypeParameter(trait.getGenericParamList().getTypeParam(i))
or
result =
any(AssociatedTypeTypeParameter param | param.getTrait() = trait and param.getIndex() = i)
}

override TypeMention getTypeParameterDefault(int i) {
Expand Down Expand Up @@ -299,20 +296,6 @@ class TypeParamTypeParameter extends TypeParameter, TTypeParamTypeParameter {
override Location getLocation() { result = typeParam.getLocation() }
}

/**
* Gets the type alias that is the `i`th type parameter of `trait`. Type aliases
* are numbered consecutively but in arbitrary order, starting from the index
* following the last ordinary type parameter.
*/
predicate traitAliasIndex(Trait trait, int i, TypeAlias typeAlias) {
typeAlias =
rank[i + 1 - trait.getNumberOfGenericParams()](TypeAlias alias |
trait.(TraitItemNode).getADescendant() = alias
|
alias order by idOfTypeParameterAstNode(alias)
)
}

/**
* A type parameter corresponding to an associated type in a trait.
*
Expand Down Expand Up @@ -341,8 +324,6 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
/** Gets the trait that contains this associated type declaration. */
TraitItemNode getTrait() { result.getAnAssocItem() = typeAlias }

int getIndex() { traitAliasIndex(_, result, typeAlias) }

override string toString() { result = typeAlias.getName().getText() }

override Location getLocation() { result = typeAlias.getLocation() }
Expand Down
138 changes: 105 additions & 33 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ private import codeql.typeinference.internal.TypeInference
private import codeql.rust.frameworks.stdlib.Stdlib
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
private import codeql.rust.elements.Call
private import codeql.rust.elements.internal.CallImpl::Impl as CallImpl

class Type = T::Type;

Expand Down Expand Up @@ -353,19 +354,6 @@ private Type inferImplicitSelfType(SelfParam self, TypePath path) {
)
}

/**
* Gets any of the types mentioned in `path` that corresponds to the type
* parameter `tp`.
*/
private TypeMention getExplicitTypeArgMention(Path path, TypeParam tp) {
exists(int i |
result = path.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i)) and
tp = resolvePath(path).getTypeParam(pragma[only_bind_into](i))
)
or
result = getExplicitTypeArgMention(path.getQualifier(), tp)
}

/**
* A matching configuration for resolving types of struct expressions
* like `Foo { bar = baz }`.
Expand Down Expand Up @@ -452,9 +440,7 @@ private module StructExprMatchingInput implements MatchingInputSig {
class AccessPosition = DeclarationPosition;

class Access extends StructExpr {
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
result = getExplicitTypeArgMention(this.getPath(), apos.asTypeParam()).resolveTypeAt(path)
}
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }

AstNode getNodeAt(AccessPosition apos) {
result = this.getFieldExpr(apos.asFieldPos()).getExpr()
Expand All @@ -465,6 +451,16 @@ private module StructExprMatchingInput implements MatchingInputSig {

Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
or
// The struct type is supplied explicitly as a type qualifier, e.g.
// `Foo<Bar>::Variant { ... }`.
apos.isStructPos() and
exists(Path p, TypeMention tm |
p = this.getPath() and
if resolvePath(p) instanceof Variant then tm = p.getQualifier() else tm = p
|
result = tm.resolveTypeAt(path)
)
}

Declaration getTarget() { result = resolvePath(this.getPath()) }
Expand Down Expand Up @@ -537,15 +533,24 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {

abstract Type getReturnType(TypePath path);

final Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = this.getParameterType(dpos, path)
or
dpos.isReturn() and
result = this.getReturnType(path)
}
}

private class TupleStructDecl extends Declaration, Struct {
abstract private class TupleDeclaration extends Declaration {
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = super.getDeclaredType(dpos, path)
or
dpos.isSelf() and
result = this.getReturnType(path)
}
}

private class TupleStructDecl extends TupleDeclaration, Struct {
TupleStructDecl() { this.isTuple() }

override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
Expand All @@ -568,7 +573,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
}

private class TupleVariantDecl extends Declaration, Variant {
private class TupleVariantDecl extends TupleDeclaration, Variant {
TupleVariantDecl() { this.isTuple() }

override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
Expand Down Expand Up @@ -597,13 +602,13 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
or
exists(TraitItemNode trait | this = trait.getAnAssocItem() |
typeParamMatchPosition(trait.getTypeParam(_), result, ppos)
exists(ImplOrTraitItemNode i | this = i.getAnAssocItem() |
typeParamMatchPosition(i.getTypeParam(_), result, ppos)
or
ppos.isImplicit() and result = TSelfTypeParameter(trait)
ppos.isImplicit() and result = TSelfTypeParameter(i)
or
ppos.isImplicit() and
result.(AssociatedTypeTypeParameter).getTrait() = trait
result.(AssociatedTypeTypeParameter).getTrait() = i
)
or
ppos.isImplicit() and
Expand All @@ -625,6 +630,33 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
or
result = inferImplicitSelfType(self, path) // `self` parameter without type annotation
)
or
// For associated functions, we may also need to match type arguments against
// the `Self` type. For example, in
//
// ```rust
// struct Foo<T>(T);
//
// impl<T : Default> Foo<T> {
// fn default() -> Self {
// Foo(Default::default())
// }
// }
//
// Foo::<i32>::default();
// ```
//
// we need to match `i32` against the type parameter `T` of the `impl` block.
exists(ImplOrTraitItemNode i |
this = i.getAnAssocItem() and
dpos.isSelf() and
not this.getParamList().hasSelfParam()
|
result = TSelfTypeParameter(i) and
path.isEmpty()
or
result = resolveImplSelfType(i, path)
)
}

private Type resolveRetType(TypePath path) {
Expand Down Expand Up @@ -670,9 +702,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl

final class Access extends Call {
pragma[nomagic]
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam())
exists(Path p, int i |
p = CallExprImpl::getFunctionPath(this) and
arg = p.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i)) and
apos.asTypeParam() = resolvePath(p).getTypeParam(pragma[only_bind_into](i))
)
or
arg =
this.(MethodCallExpr).getGenericArgList().getTypeArg(apos.asMethodTypeArgumentPosition())
Expand All @@ -696,6 +733,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {

Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
or
// The `Self` type is supplied explicitly as a type qualifier, e.g. `Foo::<Bar>::baz()`
apos = TArgumentAccessPosition(CallImpl::TSelfArgumentPosition(), false, false) and
exists(PathExpr pe, TypeMention tm |
pe = this.(CallExpr).getFunction() and
tm = pe.getPath().getQualifier() and
result = tm.resolveTypeAt(path)
)
}

Declaration getTarget() {
Expand Down Expand Up @@ -1110,12 +1155,7 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
}

final class MethodCall extends Call {
MethodCall() {
exists(this.getReceiver()) and
// We want the method calls that don't have a path to a concrete method in
// an impl block. We need to exclude calls like `MyType::my_method(..)`.
(this instanceof CallExpr implies exists(this.getTrait()))
}
MethodCall() { exists(this.getReceiver()) }

/** Gets the type of the receiver of the method call at `path`. */
Type getTypeAt(TypePath path) {
Expand Down Expand Up @@ -1582,19 +1622,51 @@ private module Debug {
result = resolveMethodCallTarget(mce)
}

predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
self = getRelevantLocatable() and
t = inferImplicitSelfType(self, path)
}

predicate debugInferCallExprBaseType(AstNode n, TypePath path, Type t) {
n = getRelevantLocatable() and
t = inferCallExprBaseType(n, path)
}

predicate debugTypeMention(TypeMention tm, TypePath path, Type type) {
tm = getRelevantLocatable() and
tm.resolveTypeAt(path) = type
}

pragma[nomagic]
private int countTypes(AstNode n, TypePath path, Type t) {
private int countTypesAtPath(AstNode n, TypePath path, Type t) {
t = inferType(n, path) and
result = strictcount(Type t0 | t0 = inferType(n, path))
}

predicate maxTypes(AstNode n, TypePath path, Type t, int c) {
c = countTypes(n, path, t) and
c = max(countTypes(_, _, _))
c = countTypesAtPath(n, path, t) and
c = max(countTypesAtPath(_, _, _))
}

pragma[nomagic]
private predicate typePathLength(AstNode n, TypePath path, Type t, int len) {
t = inferType(n, path) and
len = path.length()
}

predicate maxTypePath(AstNode n, TypePath path, Type t, int len) {
typePathLength(n, path, t, len) and
len = max(int i | typePathLength(_, _, _, i))
}

pragma[nomagic]
private int countTypePaths(AstNode n, TypePath path, Type t) {
t = inferType(n, path) and
result = strictcount(TypePath path0, Type t0 | t0 = inferType(n, path0))
}

predicate maxTypePaths(AstNode n, TypePath path, Type t, int c) {
c = countTypePaths(n, path, t) and
c = max(countTypePaths(_, _, _))
}
}
Loading