1515 */
1616package org .openrewrite .staticanalysis ;
1717
18- import org .openrewrite .ExecutionContext ;
19- import org .openrewrite .Incubating ;
20- import org .openrewrite .Recipe ;
21- import org .openrewrite .TreeVisitor ;
18+ import org .openrewrite .*;
19+ import org .openrewrite .java .AnnotationMatcher ;
20+ import org .openrewrite .java .JavaIsoVisitor ;
21+ import org .openrewrite .java .JavaTemplate ;
22+ import org .openrewrite .java .MethodMatcher ;
23+ import org .openrewrite .java .service .AnnotationService ;
24+ import org .openrewrite .java .tree .J ;
25+ import org .openrewrite .java .tree .JavaType ;
2226
2327import java .time .Duration ;
2428import java .util .Collections ;
29+ import java .util .Comparator ;
2530import java .util .Set ;
31+ import java .util .stream .Stream ;
2632
2733@ Incubating (since = "7.0.0" )
2834public class CovariantEquals extends Recipe {
@@ -35,7 +41,7 @@ public String getDisplayName() {
3541 @ Override
3642 public String getDescription () {
3743 return "Checks that classes and records which define a covariant `equals()` method also override method `equals(Object)`. " +
38- "Covariant `equals()` means a method that is similar to `equals(Object)`, but with a covariant parameter type (any subtype of `Object`)." ;
44+ "Covariant `equals()` means a method that is similar to `equals(Object)`, but with a covariant parameter type (any subtype of `Object`)." ;
3945 }
4046
4147 @ Override
@@ -50,6 +56,98 @@ public Duration getEstimatedEffortPerOccurrence() {
5056
5157 @ Override
5258 public TreeVisitor <?, ExecutionContext > getVisitor () {
53- return new CovariantEqualsVisitor <>();
59+ MethodMatcher objectEquals = new MethodMatcher ("* equals(java.lang.Object)" );
60+ return new JavaIsoVisitor <ExecutionContext >() {
61+
62+ @ Override
63+ public J .ClassDeclaration visitClassDeclaration (J .ClassDeclaration classDecl , ExecutionContext ctx ) {
64+ J .ClassDeclaration cd = super .visitClassDeclaration (classDecl , ctx );
65+ Stream <J .MethodDeclaration > mds = cd .getBody ().getStatements ().stream ()
66+ .filter (J .MethodDeclaration .class ::isInstance )
67+ .map (J .MethodDeclaration .class ::cast );
68+ if (cd .getKind () != J .ClassDeclaration .Kind .Type .Interface && mds .noneMatch (m -> objectEquals .matches (m , classDecl ))) {
69+ cd = (J .ClassDeclaration ) new ChangeCovariantEqualsMethodVisitor (cd ).visit (cd , ctx , getCursor ().getParentOrThrow ());
70+ assert cd != null ;
71+ }
72+ return cd ;
73+ }
74+
75+ class ChangeCovariantEqualsMethodVisitor extends JavaIsoVisitor <ExecutionContext > {
76+ private final AnnotationMatcher OVERRIDE_ANNOTATION = new AnnotationMatcher ("@java.lang.Override" );
77+
78+ private final J .ClassDeclaration enclosingClass ;
79+
80+ public ChangeCovariantEqualsMethodVisitor (J .ClassDeclaration enclosingClass ) {
81+ this .enclosingClass = enclosingClass ;
82+ }
83+
84+ @ Override
85+ public J .MethodDeclaration visitMethodDeclaration (J .MethodDeclaration method , ExecutionContext ctx ) {
86+ J .MethodDeclaration m = super .visitMethodDeclaration (method , ctx );
87+ updateCursor (m );
88+
89+ /*
90+ * Looking for "public boolean equals(EnclosingClassType)" as the method signature match.
91+ * We'll replace it with "public boolean equals(Object)"
92+ */
93+ JavaType .FullyQualified type = enclosingClass .getType ();
94+ if (type == null || type instanceof JavaType .Unknown ) {
95+ return m ;
96+ }
97+
98+ String ecfqn = type .getFullyQualifiedName ();
99+ if (m .hasModifier (J .Modifier .Type .Public ) &&
100+ m .getReturnTypeExpression () != null &&
101+ JavaType .Primitive .Boolean .equals (m .getReturnTypeExpression ().getType ()) &&
102+ new MethodMatcher (ecfqn + " equals(" + ecfqn + ")" ).matches (m , enclosingClass )) {
103+
104+ if (!service (AnnotationService .class ).matches (getCursor (), OVERRIDE_ANNOTATION )) {
105+ m = JavaTemplate .builder ("@Override" ).build ()
106+ .apply (updateCursor (m ),
107+ m .getCoordinates ().addAnnotation (Comparator .comparing (J .Annotation ::getSimpleName )));
108+ }
109+
110+ /*
111+ * Change parameter type to Object, and maybe change input parameter name representing the other object.
112+ * This is because we prepend these type-checking replacement statements to the existing "equals(..)" body.
113+ * Therefore we don't want to collide with any existing variable names.
114+ */
115+ J .VariableDeclarations .NamedVariable oldParamName = ((J .VariableDeclarations ) m .getParameters ().get (0 )).getVariables ().get (0 );
116+ String paramName = "obj" .equals (oldParamName .getSimpleName ()) ? "other" : "obj" ;
117+ m = JavaTemplate .builder ("Object #{}" ).build ()
118+ .apply (updateCursor (m ),
119+ m .getCoordinates ().replaceParameters (),
120+ paramName );
121+
122+ /*
123+ * We'll prepend this type-check and type-cast to the beginning of the existing
124+ * equals(..) method body statements, and let the existing equals(..) method definition continue
125+ * with the logic doing what it was doing.
126+ */
127+ String equalsBodyPrefixTemplate = "if (#{} == this) return true;\n " +
128+ "if (#{} == null || getClass() != #{}.getClass()) return false;\n " +
129+ "#{} #{} = (#{}) #{};\n " ;
130+ JavaTemplate equalsBodySnippet = JavaTemplate .builder (equalsBodyPrefixTemplate ).contextSensitive ().build ();
131+
132+ assert m .getBody () != null ;
133+ Object [] params = new Object []{
134+ paramName ,
135+ paramName ,
136+ paramName ,
137+ enclosingClass .getSimpleName (),
138+ oldParamName .getSimpleName (),
139+ enclosingClass .getSimpleName (),
140+ paramName
141+ };
142+
143+ m = equalsBodySnippet .apply (new Cursor (getCursor ().getParent (), m ),
144+ m .getBody ().getStatements ().get (0 ).getCoordinates ().before (),
145+ params );
146+ }
147+
148+ return m ;
149+ }
150+ }
151+ };
54152 }
55153}
0 commit comments