Skip to content
Open
37 changes: 2 additions & 35 deletions src/python/pants/backend/java/compile/javac.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
import logging
from itertools import chain

from pants.backend.java.dependency_inference.rules import (
JavaInferredDependenciesAndExportsRequest,
infer_java_dependencies_and_exports_via_source_analysis,
)
from pants.backend.java.dependency_inference.rules import rules as java_dep_inference_rules
from pants.backend.java.subsystems.javac import JavacSubsystem
from pants.backend.java.target_types import JavaFieldSet, JavaGeneratorFieldSet, JavaSourceField
Expand Down Expand Up @@ -83,26 +79,6 @@ async def compile_java_source(
exit_code=1,
)

# Capture just the `ClasspathEntry` objects that are listed as `export` types by source analysis
deps_to_classpath_entries = dict(
zip(request.component.dependencies, direct_dependency_classpath_entries or ())
)
# Re-request inferred dependencies to get a list of export dependency addresses
inferred_dependencies = await concurrently(
infer_java_dependencies_and_exports_via_source_analysis(
JavaInferredDependenciesAndExportsRequest(tgt[JavaSourceField]), **implicitly()
)
for tgt in request.component.members
if JavaFieldSet.is_applicable(tgt)
)
flat_exports = {export for i in inferred_dependencies for export in i.exports}

export_classpath_entries = [
classpath_entry
for coarsened_target, classpath_entry in deps_to_classpath_entries.items()
if any(m.address in flat_exports for m in coarsened_target.members)
]

# Then collect the component's sources.
component_members_with_sources = tuple(
t for t in request.component.members if t.has_field(SourcesField)
Expand Down Expand Up @@ -157,8 +133,8 @@ async def compile_java_source(

usercp = "__cp"
user_classpath = Classpath(direct_dependency_classpath_entries, request.resolve)
classpath_arg = ":".join(user_classpath.root_immutable_inputs_args(prefix=usercp))
immutable_input_digests = dict(user_classpath.root_immutable_inputs(prefix=usercp))
classpath_arg = ":".join(user_classpath.immutable_inputs_args(prefix=usercp))
immutable_input_digests = dict(user_classpath.immutable_inputs(prefix=usercp))

# Compile.
compile_result = await execute_process(
Expand Down Expand Up @@ -235,15 +211,6 @@ async def compile_java_source(
jar_output_digest, output_files, direct_dependency_classpath_entries
)

if export_classpath_entries:
merged_export_digest = await merge_digests(
MergeDigests((output_classpath.digest, *(i.digest for i in export_classpath_entries)))
)
merged_classpath = ClasspathEntry.merge(
merged_export_digest, (output_classpath, *export_classpath_entries)
)
output_classpath = merged_classpath

return FallibleClasspathEntry.from_fallible_process_result(
str(request.component),
compile_result,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,17 @@ class CompilationUnitAnalysis {
Optional<String> declaredPackage,
ArrayList<Import> imports,
ArrayList<String> topLevelTypes,
ArrayList<String> consumedTypes,
ArrayList<String> exportTypes) {
ArrayList<String> consumedTypes) {
this.declaredPackage = declaredPackage;
this.imports = imports;
this.topLevelTypes = topLevelTypes;
this.consumedTypes = consumedTypes;
this.exportTypes = exportTypes;
}

public final Optional<String> declaredPackage;
public final ArrayList<Import> imports;
public final ArrayList<String> topLevelTypes;
public final ArrayList<String> consumedTypes;
public final ArrayList<String> exportTypes;
}

/**
Expand Down Expand Up @@ -151,20 +148,13 @@ public static void main(String[] args) throws Exception {
.collect(Collectors.toList()));

HashSet<Type> candidateConsumedTypes = new HashSet<>();
HashSet<Type> candidateExportTypes = new HashSet<>();

Consumer<Type> consumed =
(type) -> {
candidateConsumedTypes.add(type);
};
Consumer<Type> export =
(type) -> {
candidateConsumedTypes.add(type);
candidateExportTypes.add(type);
};

HashSet<String> consumedIdentifiers = new HashSet<>();
HashSet<String> exportIdentifiers = new HashSet<>();

cu.walk(
new Consumer<Node>() {
Expand All @@ -180,16 +170,16 @@ public void accept(Node node) {
}
if (node instanceof MethodDeclaration) {
MethodDeclaration methodDecl = (MethodDeclaration) node;
export.accept(methodDecl.getType());
consumed.accept(methodDecl.getType());
Comment on lines -183 to +173
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The export consumer was adding to both lists. Since exports aren't needed anymore lines like this just call consumed.accept instead of being removed entirely

for (Parameter param : methodDecl.getParameters()) {
export.accept(param.getType());
consumed.accept(param.getType());
}
methodDecl.getThrownExceptions().stream().forEach(consumed);
}
if (node instanceof ClassOrInterfaceDeclaration) {
ClassOrInterfaceDeclaration classOrIntfDecl = (ClassOrInterfaceDeclaration) node;
classOrIntfDecl.getExtendedTypes().stream().forEach(export);
classOrIntfDecl.getImplementedTypes().stream().forEach(export);
classOrIntfDecl.getExtendedTypes().stream().forEach(consumed);
classOrIntfDecl.getImplementedTypes().stream().forEach(consumed);
}
if (node instanceof AnnotationExpr) {
AnnotationExpr annoExpr = (AnnotationExpr) node;
Expand Down Expand Up @@ -220,16 +210,12 @@ public void accept(Node node) {
for (Type type : candidateConsumedTypes) {
List<String> identifiersForType = unwrapIdentifiersForType(type);
consumedIdentifiers.addAll(identifiersForType);
if (candidateExportTypes.contains(type)) {
exportIdentifiers.addAll(identifiersForType);
}
}

ArrayList<String> consumedTypes = new ArrayList<>(consumedIdentifiers);
ArrayList<String> exportTypes = new ArrayList<>(exportIdentifiers);
CompilationUnitAnalysis analysis =
new CompilationUnitAnalysis(
declaredPackage, imports, topLevelTypes, consumedTypes, exportTypes);
declaredPackage, imports, topLevelTypes, consumedTypes);
ObjectMapper mapper = new ObjectMapper();
mapper.registerModule(new Jdk8Module());
mapper.writeValue(new File(analysisOutputPath), analysis);
Expand Down
146 changes: 106 additions & 40 deletions src/python/pants/backend/java/dependency_inference/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass

from pants.backend.java.dependency_inference import symbol_mapper
Expand Down Expand Up @@ -35,6 +34,13 @@
from pants.jvm.target_types import JvmResolveField
from pants.util.ordered_set import FrozenOrderedSet, OrderedSet

# Java standard library package prefixes - types starting with these are always fully qualified
JAVA_STDLIB_PREFIXES = frozenset([
"java.", "javax.", "jakarta.", "jdk.",
"com.sun.", "sun.", # Oracle internal
"org.w3c.", "org.xml.", "org.ietf.", "org.omg.", # Standards
])


@dataclass(frozen=True)
class JavaSourceDependenciesInferenceFieldSet(FieldSet):
Expand All @@ -50,7 +56,6 @@ class InferJavaSourceDependencies(InferDependenciesRequest):
@dataclass(frozen=True)
class JavaInferredDependencies:
dependencies: FrozenOrderedSet[Address]
exports: FrozenOrderedSet[Address]


@dataclass(frozen=True)
Expand All @@ -66,7 +71,7 @@ async def infer_java_dependencies_and_exports_via_source_analysis(
symbol_mapping: SymbolMapping,
) -> JavaInferredDependencies:
if not java_infer_subsystem.imports and not java_infer_subsystem.consumed_types:
return JavaInferredDependencies(FrozenOrderedSet([]), FrozenOrderedSet([]))
return JavaInferredDependencies(FrozenOrderedSet([]))

address = request.source.address

Expand All @@ -91,38 +96,19 @@ async def infer_java_dependencies_and_exports_via_source_analysis(
if java_infer_subsystem.consumed_types:
package = analysis.declared_package

# 13545: `analysis.consumed_types` may be unqualified (package-local or imported) or qualified
# (prefixed by package name). Heuristic for now is that if there's a `.` in the type name, it's
# probably fully qualified. This is probably fine for now.
maybe_qualify_types = (
f"{package}.{consumed_type}" if package and "." not in consumed_type else consumed_type
for consumed_type in analysis.consumed_types
)

types.update(maybe_qualify_types)
# Qualify each consumed type, potentially generating multiple candidates
type_candidates: OrderedSet[str] = OrderedSet()
for consumed_type in analysis.consumed_types:
candidates = qualify_consumed_type(consumed_type, package, analysis.imports)
type_candidates.update(candidates)

# Resolve the export types into (probable) types:
# First produce a map of known consumed unqualified types to possible qualified names
consumed_type_mapping: dict[str, set[str]] = defaultdict(set)
for typ in types:
unqualified = typ.rpartition(".")[2] # `"org.foo.Java"` -> `("org.foo", ".", "Java")`
consumed_type_mapping[unqualified].add(typ)

# Now take the list of unqualified export types and convert them to possible
# qualified names based on the guesses we made for consumed types
export_types = {
i for typ in analysis.export_types for i in consumed_type_mapping.get(typ, set())
}
# Finally, if there's a `.` in the name, it's probably fully qualified,
# so just add it unaltered
export_types.update(typ for typ in analysis.export_types if "." in typ)
types.update(type_candidates)

resolve = tgt[JvmResolveField].normalized_value(jvm)

dependencies: OrderedSet[Address] = OrderedSet()
exports: OrderedSet[Address] = OrderedSet()
for typ in types:
for matches in symbol_mapping.addresses_for_symbol(typ, resolve).values():
for matches in lookup_type_with_fallback(typ, symbol_mapping, resolve).values():
explicitly_provided_deps.maybe_warn_of_ambiguous_dependency_inference(
matches,
address,
Expand All @@ -133,18 +119,8 @@ async def infer_java_dependencies_and_exports_via_source_analysis(

if maybe_disambiguated:
dependencies.add(maybe_disambiguated)
if typ in export_types:
exports.add(maybe_disambiguated)
else:
# Exports from explicitly provided dependencies:
explicitly_provided_exports = set(matches) & set(explicitly_provided_deps.includes)
exports.update(explicitly_provided_exports)

# Files do not export themselves. Don't be silly.
if address in exports:
exports.remove(address)

return JavaInferredDependencies(FrozenOrderedSet(dependencies), FrozenOrderedSet(exports))
return JavaInferredDependencies(FrozenOrderedSet(dependencies))


@rule(desc="Inferring Java dependencies by source analysis")
Expand All @@ -157,6 +133,96 @@ async def infer_java_dependencies_via_source_analysis(
return InferredDependencies(jids.dependencies)


def qualify_consumed_type(
type_name: str,
source_package: str | None,
imports: tuple[JavaImport, ...],
) -> tuple[str, ...]:
"""
Qualify a consumed type name, returning possible qualified names to try.

Returns a tuple of candidates in priority order. The symbol map should be checked
for each candidate until a match is found.

Args:
type_name: The type name as it appears in the source (may be qualified or unqualified)
source_package: The package of the source file, or None if unnamed package
imports: The imports declared in the source file

Returns:
Tuple of possible fully-qualified type names, in priority order
"""
# Case 1: No dots → definitely unqualified, needs package prefix
if "." not in type_name:
if source_package:
return (f"{source_package}.{type_name}",)
else:
return (type_name,) # Unnamed package

# Case 2: Known JDK/stdlib type → already fully qualified
if any(type_name.startswith(prefix) for prefix in JAVA_STDLIB_PREFIXES):
return (type_name,)

# Case 3: Type fully qualified name appears in imports → already resolved
import_names = {imp.name for imp in imports}
if type_name in import_names:
return (type_name,)

# Case 4: Outer class is imported → resolve inner class through import
# E.g., "B.InnerB" where "com.other.B" is imported → "com.other.B.InnerB"
first_part = type_name.split(".")[0]
for imp in imports:
if imp.name.endswith(f".{first_part}"):
# Found import for outer class, construct fully qualified inner class name
qualified = imp.name + type_name[len(first_part):]
return (qualified,)

# Case 5: Ambiguous - has dots but not stdlib, not imported
# Most likely: same-package inner class like "B.InnerB" → "com.example.B.InnerB"
# Less likely: third-party FQTN without import
if source_package:
# Try same-package first (most common), then as-is (fallback for third-party)
return (f"{source_package}.{type_name}", type_name)
else:
return (type_name,)


def lookup_type_with_fallback(
typ: str,
symbol_mapping: SymbolMapping,
resolve: str
) -> dict[str, FrozenOrderedSet[Address]]:
"""
Look up a type in the symbol map, with fallback to parent types for inner classes.

Args:
typ: Fully qualified type name (e.g., "com.example.B.InnerB")
symbol_mapping: The symbol map to search
resolve: The JVM resolve to search within

Returns:
Dict mapping namespaces to addresses, empty if no match found
"""
# Try exact match first
matches = symbol_mapping.addresses_for_symbol(typ, resolve)
if matches:
return matches

# If not found and typ looks like it might be an inner class (has dots after package)
# Try stripping inner class parts one by one
# E.g., "com.example.B.InnerB" → try "com.example.B"
# "com.example.Outer.Middle.Inner" → try "com.example.Outer.Middle", then "com.example.Outer"
parts = typ.split(".")
if len(parts) > 2: # At least package + outer + inner
for i in range(len(parts) - 1, 1, -1): # Don't try single-part names
parent_type = ".".join(parts[:i])
matches = symbol_mapping.addresses_for_symbol(parent_type, resolve)
if matches:
return matches

return {}


def dependency_name(imp: JavaImport):
if imp.is_static and not imp.is_asterisk:
return imp.name.rsplit(".", maxsplit=1)[0]
Expand Down
Loading