Skip to content

Commit

Permalink
Parameter reuse supported in Refaster templates
Browse files Browse the repository at this point in the history
  • Loading branch information
jkschneider committed Sep 30, 2023
1 parent 691e526 commit 3eadd4e
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -485,14 +485,16 @@ private String escape(String string) {

private String parameters(TemplateDescriptor descriptor) {
List<Integer> afterParams = new ArrayList<>();
Set<Symbol> seenParams = new HashSet<>();
new TreeScanner() {
@Override
public void scan(JCTree jcTree) {
if (jcTree instanceof JCTree.JCIdent) {
JCTree.JCIdent jcIdent = (JCTree.JCIdent) jcTree;
if (jcIdent.sym instanceof Symbol.VarSymbol
&& jcIdent.sym.owner instanceof Symbol.MethodSymbol
&& ((Symbol.MethodSymbol) jcIdent.sym.owner).params.contains(jcIdent.sym)) {
&& ((Symbol.MethodSymbol) jcIdent.sym.owner).params.contains(jcIdent.sym)
&& seenParams.add(jcIdent.sym)) {
afterParams.add(((Symbol.MethodSymbol) jcIdent.sym.owner).params.indexOf(jcIdent.sym));
}
}
Expand Down Expand Up @@ -544,38 +546,9 @@ private String toLambda(JCTree.JCMethodDecl method) {
StringJoiner joiner = new StringJoiner(", ", "(", ")");
for (JCTree.JCVariableDecl parameter : method.getParameters()) {
String paramType = parameter.getType().type.tsym.getQualifiedName().toString();

switch (paramType) {
case "boolean":
paramType = "@Primitive Boolean";
break;
case "byte":
paramType = "@Primitive Byte";
break;
case "char":
paramType = "@Primitive Character";
break;
case "double":
paramType = "@Primitive Double";
break;
case "float":
paramType = "@Primitive Float";
break;
case "int":
paramType = "@Primitive Integer";
break;
case "long":
paramType = "@Primitive Long";
break;
case "short":
paramType = "@Primitive Short";
break;
case "void":
paramType = "@Primitive Void";
break;
}

if (paramType.startsWith("java.lang.")) {
if (!getBoxedPrimitive(paramType).equals(paramType)) {
paramType = "@Primitive " + getBoxedPrimitive(paramType);
} else if (paramType.startsWith("java.lang.")) {
paramType = paramType.substring("java.lang.".length());
}
joiner.add(paramType + " " + parameter.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,6 @@
@SupportedAnnotationTypes("*")
public class TemplateProcessor extends TypeAwareProcessor {
private static final String PRIMITIVE_ANNOTATION = "org.openrewrite.java.template.Primitive";
private static final Map<String, String> PRIMITIVE_TYPE_MAP = new HashMap<>();

static {
PRIMITIVE_TYPE_MAP.put(Boolean.class.getName(), boolean.class.getName());
PRIMITIVE_TYPE_MAP.put(Byte.class.getName(), byte.class.getName());
PRIMITIVE_TYPE_MAP.put(Character.class.getName(), char.class.getName());
PRIMITIVE_TYPE_MAP.put(Short.class.getName(), short.class.getName());
PRIMITIVE_TYPE_MAP.put(Integer.class.getName(), int.class.getName());
PRIMITIVE_TYPE_MAP.put(Long.class.getName(), long.class.getName());
PRIMITIVE_TYPE_MAP.put(Float.class.getName(), float.class.getName());
PRIMITIVE_TYPE_MAP.put(Double.class.getName(), double.class.getName());
PRIMITIVE_TYPE_MAP.put(Void.class.getName(), void.class.getName());
}

private final String javaFileContent;

Expand Down Expand Up @@ -153,16 +140,26 @@ public void visitIdent(JCTree.JCIdent ident) {

for (Map.Entry<Integer, JCTree.JCVariableDecl> paramPos : parameterPositions.descendingMap().entrySet()) {
JCTree.JCVariableDecl param = paramPos.getValue();
String type = param.type.toString();
for (JCTree.JCAnnotation annotation : param.getModifiers().getAnnotations()) {
if (annotation.type.tsym.getQualifiedName().contentEquals(PRIMITIVE_ANNOTATION)) {
type = PRIMITIVE_TYPE_MAP.get(param.type.toString());
// don't generate the annotation into the source code
param.mods.annotations = com.sun.tools.javac.util.List.filter(param.mods.annotations, annotation);

String typeDef = "";

// identify whether this is the leftmost occurrence of this parameter name
if (Objects.equals(parameterPositions.entrySet().stream().filter(p -> p.getValue() == param)
.map(Map.Entry::getKey)
.findFirst().orElse(null), paramPos.getKey())) {
String type = param.type.toString();
for (JCTree.JCAnnotation annotation : param.getModifiers().getAnnotations()) {
if (annotation.type.tsym.getQualifiedName().contentEquals(PRIMITIVE_ANNOTATION)) {
type = getBoxedPrimitive(param.type.toString());
// don't generate the annotation into the source code
param.mods.annotations = com.sun.tools.javac.util.List.filter(param.mods.annotations, annotation);
}
}
typeDef = ":any(" + type + ")";
}

templateSource = templateSource.substring(0, paramPos.getKey() - template.getBody().getStartPosition()) +
"#{any(" + type + ")}" +
"#{" + param.getName().toString() + typeDef + "}" +
templateSource.substring((paramPos.getKey() - template.getBody().getStartPosition()) +
param.name.length());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,28 @@ protected Object tryGetProxyDelegateToField(Object instance) {
return null;
}
}

protected String getBoxedPrimitive(String paramType) {
switch (paramType) {
case "boolean":
return "Boolean";
case "byte":
return "Byte";
case "char":
return "Character";
case "double":
return "Double";
case "float":
return "Float";
case "int":
return "Integer";
case "long":
return "Long";
case "short":
return "Short";
case "void":
return "Void";
}
return paramType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class RefasterTemplateProcessorTest {
@ParameterizedTest
@ValueSource(strings = {
"UseStringIsEmpty",
"NestedPreconditions"
"NestedPreconditions",
"ParameterReuse",
})
void generateRecipe(String recipeName) {
// As per https://github.com/google/compile-testing/blob/v0.21.0/src/main/java/com/google/testing/compile/package-info.java#L53-L55
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import com.google.testing.compile.Compilation;
import com.google.testing.compile.JavaFileObjects;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.openrewrite.java.template.processor.RefasterTemplateProcessor;
Expand All @@ -33,7 +34,7 @@ class TemplateProcessorTest {
"Unqualified",
"FullyQualified",
})
void generateRecipeTemplates(String qualifier) {
void qualification(String qualifier) {
// As per https://github.com/google/compile-testing/blob/v0.21.0/src/main/java/com/google/testing/compile/package-info.java#L53-L55
Compilation compilation = javac()
.withProcessors(new RefasterTemplateProcessor(), new TemplateProcessor())
Expand All @@ -49,4 +50,16 @@ void generateRecipeTemplates(String qualifier) {
.hasSourceEquivalentTo(JavaFileObjects.forResource("template/ShouldAddClasspathRecipe$" + qualifier + "Recipe$1_after.java"));
}

@Test
void parameterReuse() {
Compilation compilation = javac()
.withProcessors(new RefasterTemplateProcessor(), new TemplateProcessor())
.withClasspath(classpath())
.compile(JavaFileObjects.forResource("template/ParameterReuse.java"));
assertThat(compilation).succeeded();
compilation.generatedSourceFiles().forEach(System.out::println);
assertThat(compilation)
.generatedSourceFile("foo/ParameterReuseRecipe$1_before")
.hasSourceEquivalentTo(JavaFileObjects.forResource("template/ParameterReuseRecipe$1_before.java"));
}
}
19 changes: 19 additions & 0 deletions src/test/resources/refaster/ParameterReuse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package foo;

import com.google.errorprone.refaster.annotation.AfterTemplate;
import com.google.errorprone.refaster.annotation.BeforeTemplate;
import org.openrewrite.java.template.Matches;
import org.openrewrite.java.template.MethodInvocationMatcher;
import org.openrewrite.java.template.NotMatches;

public class ParameterReuse {
@BeforeTemplate
boolean before(String s) {
return s == s;
}

@AfterTemplate
boolean after(String s) {
return s.equals(s);
}
}
55 changes: 55 additions & 0 deletions src/test/resources/refaster/ParameterReuseRecipe.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package foo;

import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.lang.NonNullApi;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.search.*;
import org.openrewrite.java.template.Primitive;
import org.openrewrite.java.template.Semantics;
import org.openrewrite.java.template.function.*;
import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor;
import org.openrewrite.java.tree.*;

import java.util.function.Supplier;


@NonNullApi
public class ParameterReuseRecipe extends Recipe {

@Override
public String getDisplayName() {
return "Refaster template `ParameterReuse`";
}

@Override
public String getDescription() {
return "Recipe created for the following Refaster template:\n```java\npublic class ParameterReuse {\n \n @BeforeTemplate()\n boolean before(String s) {\n return s == s;\n }\n \n @AfterTemplate()\n boolean after(String s) {\n return s.equals(s);\n }\n}\n```\n.";
}

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
JavaVisitor<ExecutionContext> javaVisitor = new AbstractRefasterJavaVisitor() {
final Supplier<JavaTemplate> before = memoize(() -> Semantics.expression(this, "before", (String s) -> s == s).build());
final Supplier<JavaTemplate> after = memoize(() -> Semantics.expression(this, "after", (String s) -> s.equals(s)).build());

@Override
public J visitBinary(J.Binary elem, ExecutionContext ctx) {
JavaTemplate.Matcher matcher;
if ((matcher = matcher(before, getCursor())).find()) {
return embed(
apply(after, getCursor(), elem.getCoordinates().replace(), matcher.parameter(0)),
getCursor(),
ctx
);
}
return super.visitBinary(elem, ctx);
}

};
return javaVisitor;
}
}
19 changes: 19 additions & 0 deletions src/test/resources/template/ParameterReuse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package foo;

import com.google.errorprone.refaster.annotation.AfterTemplate;
import com.google.errorprone.refaster.annotation.BeforeTemplate;
import org.openrewrite.java.template.Matches;
import org.openrewrite.java.template.MethodInvocationMatcher;
import org.openrewrite.java.template.NotMatches;

public class ParameterReuse {
@BeforeTemplate
boolean before(String s) {
return s.equals(s);
}

@AfterTemplate
boolean after() {
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package foo;
import org.openrewrite.java.*;

public class ParameterReuseRecipe$1_before {
public static JavaTemplate.Builder getTemplate() {
return JavaTemplate
.builder("#{s:any(java.lang.String)}.equals(#{s})");
}
}

0 comments on commit 3eadd4e

Please sign in to comment.