diff --git a/src/main/java/org/openrewrite/java/template/internal/ImportDetector.java b/src/main/java/org/openrewrite/java/template/internal/ImportDetector.java index 17fe393b..ca8537ca 100644 --- a/src/main/java/org/openrewrite/java/template/internal/ImportDetector.java +++ b/src/main/java/org/openrewrite/java/template/internal/ImportDetector.java @@ -27,6 +27,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Set; +import java.util.function.Predicate; public class ImportDetector { /** @@ -38,7 +39,11 @@ public class ImportDetector { * @return The list of imports to add. */ public static List imports(JCTree.JCMethodDecl methodDecl) { - ImportScanner importScanner = new ImportScanner(); + return imports(methodDecl, t -> true); + } + + public static List imports(JCTree.JCMethodDecl methodDecl, Predicate scopePredicate) { + ImportScanner importScanner = new ImportScanner(scopePredicate); importScanner.scan(methodDecl.getBody()); importScanner.scan(methodDecl.getReturnType()); for (JCTree.JCVariableDecl param : methodDecl.getParameters()) { @@ -57,16 +62,25 @@ public static List imports(JCTree.JCMethodDecl methodDecl) { * @return The list of imports to add. */ public static List imports(JCTree tree) { - ImportScanner importScanner = new ImportScanner(); + ImportScanner importScanner = new ImportScanner(t -> true); importScanner.scan(tree); return new ArrayList<>(importScanner.imports); } private static class ImportScanner extends TreeScanner { final Set imports = new LinkedHashSet<>(); + private final Predicate scopePredicate; + + public ImportScanner(Predicate scopePredicate) { + this.scopePredicate = scopePredicate; + } @Override public void scan(JCTree tree) { + if (!scopePredicate.test(tree)) { + return; + } + JCTree maybeFieldAccess = tree; if (maybeFieldAccess instanceof JCFieldAccess && ((JCFieldAccess) maybeFieldAccess).sym instanceof Symbol.ClassSymbol && @@ -113,6 +127,11 @@ public void scan(JCTree tree) { && ((JCIdent) ((JCFieldAccess) tree).selected).sym instanceof Symbol.ClassSymbol && !(((JCIdent) ((JCFieldAccess) tree).selected).sym.type instanceof Type.ErrorType)) { imports.add(((JCIdent) ((JCFieldAccess) tree).selected).sym); + } else if (tree instanceof JCTree.JCFieldAccess && + ((JCTree.JCFieldAccess) tree).sym instanceof Symbol.ClassSymbol) { + if (tree.toString().equals(((JCTree.JCFieldAccess) tree).sym.toString())) { + imports.add(((JCTree.JCFieldAccess) tree).sym); + } } super.scan(tree); diff --git a/src/main/java/org/openrewrite/java/template/internal/TemplateCode.java b/src/main/java/org/openrewrite/java/template/internal/TemplateCode.java index 5a9a46ac..f0a84668 100644 --- a/src/main/java/org/openrewrite/java/template/internal/TemplateCode.java +++ b/src/main/java/org/openrewrite/java/template/internal/TemplateCode.java @@ -16,6 +16,7 @@ package org.openrewrite.java.template.internal; import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.Type; import com.sun.tools.javac.tree.JCTree; import com.sun.tools.javac.tree.JCTree.JCIdent; import com.sun.tools.javac.tree.Pretty; @@ -94,12 +95,19 @@ public void visitIdent(JCIdent jcIdent) { if (param.isPresent()) { print("#{" + sym.name); if (seenParameters.add(param.get())) { - String type = param.get().sym.type.toString(); - if (param.get().getModifiers().getAnnotations().stream() - .anyMatch(a -> a.attribute.type.tsym.getQualifiedName().toString().equals(PRIMITIVE_ANNOTATION))) { - type = getUnboxedPrimitive(type); + Type type = param.get().sym.type; + String typeString; + if (type instanceof Type.ArrayType) { + print(":anyArray(" + ((Type.ArrayType) type).elemtype.toString() + ")"); + } else { + if (param.get().getModifiers().getAnnotations().stream() + .anyMatch(a -> a.attribute.type.tsym.getQualifiedName().toString().equals(PRIMITIVE_ANNOTATION))) { + typeString = getUnboxedPrimitive(type.toString()); + } else { + typeString = type.toString(); + } + print(":any(" + typeString + ")"); } - print(":any(" + type + ")"); } print("}"); } else if (sym != null) { diff --git a/src/main/java/org/openrewrite/java/template/internal/UsedMethodDetector.java b/src/main/java/org/openrewrite/java/template/internal/UsedMethodDetector.java index 4378f113..903e6b92 100644 --- a/src/main/java/org/openrewrite/java/template/internal/UsedMethodDetector.java +++ b/src/main/java/org/openrewrite/java/template/internal/UsedMethodDetector.java @@ -25,16 +25,21 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Set; +import java.util.function.Predicate; public class UsedMethodDetector { public static List usedMethods(JCTree input) { + return usedMethods(input, t -> true); + } + + public static List usedMethods(JCTree input, Predicate scopePredicate) { Set imports = new LinkedHashSet<>(); new TreeScanner() { @Override public void scan(JCTree tree) { - if (tree instanceof JCTree.JCAnnotation) { + if (tree instanceof JCTree.JCAnnotation || !scopePredicate.test(tree)) { // completely skip annotations for now return; } diff --git a/src/main/java/org/openrewrite/java/template/processor/RefasterTemplateProcessor.java b/src/main/java/org/openrewrite/java/template/processor/RefasterTemplateProcessor.java index 29a3c1db..734b0de4 100644 --- a/src/main/java/org/openrewrite/java/template/processor/RefasterTemplateProcessor.java +++ b/src/main/java/org/openrewrite/java/template/processor/RefasterTemplateProcessor.java @@ -23,10 +23,8 @@ import com.sun.tools.javac.code.Type; import com.sun.tools.javac.code.TypeTag; import com.sun.tools.javac.parser.Tokens; -import com.sun.tools.javac.tree.JCTree; +import com.sun.tools.javac.tree.*; import com.sun.tools.javac.tree.JCTree.JCCompilationUnit; -import com.sun.tools.javac.tree.TreeMaker; -import com.sun.tools.javac.tree.TreeScanner; import com.sun.tools.javac.util.Context; import org.openrewrite.internal.lang.Nullable; import org.openrewrite.java.template.internal.ImportDetector; @@ -45,11 +43,13 @@ import java.io.IOException; import java.io.Writer; import java.util.*; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.*; import static org.openrewrite.java.template.processor.RefasterTemplateProcessor.AFTER_TEMPLATE; import static org.openrewrite.java.template.processor.RefasterTemplateProcessor.BEFORE_TEMPLATE; @@ -59,6 +59,7 @@ */ @SupportedAnnotationTypes({BEFORE_TEMPLATE, AFTER_TEMPLATE}) public class RefasterTemplateProcessor extends TypeAwareProcessor { + static final String BEFORE_TEMPLATE = "com.google.errorprone.refaster.annotation.BeforeTemplate"; static final String AFTER_TEMPLATE = "com.google.errorprone.refaster.annotation.AfterTemplate"; static Set UNSUPPORTED_ANNOTATIONS = Stream.of( @@ -71,7 +72,7 @@ public class RefasterTemplateProcessor extends TypeAwareProcessor { "com.google.errorprone.refaster.annotation.OfKind", "com.google.errorprone.refaster.annotation.Placeholder", "com.google.errorprone.refaster.annotation.Repeated" - ).collect(Collectors.toSet()); + ).collect(toSet()); static ClassValue> LST_TYPE_MAP = new ClassValue>() { @Override @@ -120,27 +121,27 @@ void maybeGenerateTemplateSources(JCCompilationUnit cu) { Context context = javacProcessingEnv.getContext(); new TreeScanner() { - final Map> imports = new HashMap<>(); - final Map> staticImports = new HashMap<>(); + final Map> imports = new HashMap<>(); + final Map> staticImports = new HashMap<>(); final Map recipes = new LinkedHashMap<>(); @Override public void visitClassDef(JCTree.JCClassDecl classDecl) { super.visitClassDef(classDecl); - TemplateDescriptor descriptor = getTemplateDescriptor(classDecl, context, cu); + RuleDescriptor descriptor = getRuleDescriptor(classDecl, context, cu); if (descriptor != null) { TreeMaker treeMaker = TreeMaker.instance(context).forToplevel(cu); List membersWithoutConstructor = classDecl.getMembers().stream() .filter(m -> !(m instanceof JCTree.JCMethodDecl) || !((JCTree.JCMethodDecl) m).name.contentEquals("")) - .collect(Collectors.toList()); + .collect(toList()); JCTree.JCClassDecl copy = treeMaker.ClassDef(classDecl.mods, classDecl.name, classDecl.typarams, classDecl.extending, classDecl.implementing, com.sun.tools.javac.util.List.from(membersWithoutConstructor)); String templateFqn = classDecl.sym.fullname.toString() + "Recipe"; String templateCode = copy.toString().trim(); - for (JCTree.JCMethodDecl template : descriptor.beforeTemplates) { - for (Symbol anImport : ImportDetector.imports(template)) { + for (TemplateDescriptor template : descriptor.beforeTemplates) { + for (Symbol anImport : ImportDetector.imports(template.method)) { if (anImport instanceof Symbol.ClassSymbol) { imports.computeIfAbsent(template, k -> new TreeSet<>()) .add(anImport.getQualifiedName().toString().replace('$', '.')); @@ -152,7 +153,7 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) { } } } - for (Symbol anImport : ImportDetector.imports(descriptor.afterTemplate)) { + for (Symbol anImport : ImportDetector.imports(descriptor.afterTemplate.method)) { if (anImport instanceof Symbol.ClassSymbol) { imports.computeIfAbsent(descriptor.afterTemplate, k -> new TreeSet<>()) .add(anImport.getQualifiedName().toString().replace('$', '.')); @@ -167,15 +168,16 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) { for (Set imports : imports.values()) { imports.removeIf(i -> { int endIndex = i.lastIndexOf('.'); - return endIndex < 0 || "java.lang".equals(i.substring(0, endIndex)); + return endIndex < 0 || "java.lang".equals(i.substring(0, endIndex)) || "com.google.errorprone.refaster".equals(i.substring(0, endIndex)); }); - imports.remove(BEFORE_TEMPLATE); - imports.remove(AFTER_TEMPLATE); + } + for (Set imports : staticImports.values()) { + imports.removeIf(i -> i.startsWith("java.lang.") || i.startsWith("com.google.errorprone.refaster.")); } - Map beforeTemplates = new LinkedHashMap<>(); - for (JCTree.JCMethodDecl templ : descriptor.beforeTemplates) { - String name = templ.getName().toString(); + Map beforeTemplates = new LinkedHashMap<>(); + for (TemplateDescriptor templ : descriptor.beforeTemplates) { + String name = templ.method.name.toString(); if (beforeTemplates.containsKey(name)) { String base = name; for (int i = 0; ; i++) { @@ -187,7 +189,7 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) { } beforeTemplates.put(name, templ); } - String after = descriptor.afterTemplate.getName().toString(); + String after = descriptor.afterTemplate.method.name.toString(); StringBuilder recipe = new StringBuilder(); Symbol.PackageSymbol pkg = classDecl.sym.packge(); @@ -210,21 +212,24 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) { recipe.append(" @Override\n"); recipe.append(" public TreeVisitor getVisitor() {\n"); recipe.append(" JavaVisitor javaVisitor = new AbstractRefasterJavaVisitor() {\n"); - for (Map.Entry entry : beforeTemplates.entrySet()) { - recipe.append(" final JavaTemplate ") - .append(entry.getKey()) - .append(" = ") - .append(toJavaTemplateBuilder(entry.getValue(), descriptor.resolvedParameters)) - .append("\n .build();\n"); + for (Map.Entry entry : beforeTemplates.entrySet()) { + int arity = entry.getValue().getArity(); + for (int i = 0; i < arity; i++) { + recipe.append(" final JavaTemplate ") + .append(entry.getKey()).append(arity > 1 ? "$" + i : "") + .append(" = ") + .append(entry.getValue().toJavaTemplateBuilder(i)) + .append("\n .build();\n"); + } } recipe.append(" final JavaTemplate ") .append(after) .append(" = ") - .append(toJavaTemplateBuilder(descriptor.afterTemplate, descriptor.resolvedParameters)) + .append(descriptor.afterTemplate.toJavaTemplateBuilder()) .append("\n .build();\n"); recipe.append("\n"); - List lstTypes = LST_TYPE_MAP.get(getType(descriptor.beforeTemplates.get(0))); + List lstTypes = LST_TYPE_MAP.get(getType(descriptor.beforeTemplates.get(0).method)); String parameters = parameters(descriptor); for (String lstType : lstTypes) { String methodSuffix = lstType.startsWith("J.") ? lstType.substring(2) : lstType; @@ -238,54 +243,57 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) { } recipe.append(" JavaTemplate.Matcher matcher;\n"); - for (Map.Entry entry : beforeTemplates.entrySet()) { - recipe.append(" if (" + "(matcher = ").append(entry.getKey()).append(".matcher(getCursor())).find()").append(") {\n"); - com.sun.tools.javac.util.List jcVariableDecls = entry.getValue().getParameters(); - for (int i = 0; i < jcVariableDecls.size(); i++) { - JCTree.JCVariableDecl param = jcVariableDecls.get(i); - com.sun.tools.javac.util.List annotations = param.getModifiers().getAnnotations(); - for (JCTree.JCAnnotation jcAnnotation : annotations) { - String annotationType = jcAnnotation.attribute.type.tsym.getQualifiedName().toString(); - if (annotationType.equals("org.openrewrite.java.template.NotMatches")) { - String matcher = ((Type.ClassType) jcAnnotation.attribute.getValue().values.get(0).snd.getValue()).tsym.getQualifiedName().toString(); - recipe.append(" if (new ").append(matcher).append("().matches((Expression) matcher.parameter(").append(i).append("))) {\n"); - recipe.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n"); - recipe.append(" }\n"); - } else if (annotationType.equals("org.openrewrite.java.template.Matches")) { - String matcher = ((Type.ClassType) jcAnnotation.attribute.getValue().values.get(0).snd.getValue()).tsym.getQualifiedName().toString(); - recipe.append(" if (!new ").append(matcher).append("().matches((Expression) matcher.parameter(").append(i).append("))) {\n"); - recipe.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n"); - recipe.append(" }\n"); + for (Map.Entry entry : beforeTemplates.entrySet()) { + int arity = entry.getValue().getArity(); + for (int i = 0; i < arity; i++) { + recipe.append(" if (" + "(matcher = ").append(entry.getKey()).append(arity > 1 ? "$" + i : "").append(".matcher(getCursor())).find()").append(") {\n"); + com.sun.tools.javac.util.List jcVariableDecls = entry.getValue().method.getParameters(); + for (int j = 0; j < jcVariableDecls.size(); j++) { + JCTree.JCVariableDecl param = jcVariableDecls.get(j); + com.sun.tools.javac.util.List annotations = param.getModifiers().getAnnotations(); + for (JCTree.JCAnnotation jcAnnotation : annotations) { + String annotationType = jcAnnotation.attribute.type.tsym.getQualifiedName().toString(); + if (annotationType.equals("org.openrewrite.java.template.NotMatches")) { + String matcher = ((Type.ClassType) jcAnnotation.attribute.getValue().values.get(0).snd.getValue()).tsym.getQualifiedName().toString(); + recipe.append(" if (new ").append(matcher).append("().matches((Expression) matcher.parameter(").append(j).append("))) {\n"); + recipe.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n"); + recipe.append(" }\n"); + } else if (annotationType.equals("org.openrewrite.java.template.Matches")) { + String matcher = ((Type.ClassType) jcAnnotation.attribute.getValue().values.get(0).snd.getValue()).tsym.getQualifiedName().toString(); + recipe.append(" if (!new ").append(matcher).append("().matches((Expression) matcher.parameter(").append(j).append("))) {\n"); + recipe.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n"); + recipe.append(" }\n"); + } } } - } - maybeRemoveImports(imports, recipe, entry.getValue(), descriptor.afterTemplate); - maybeRemoveImports(staticImports, recipe, entry.getValue(), descriptor.afterTemplate); + maybeRemoveImports(imports, recipe, entry.getValue(), i, descriptor.afterTemplate); + maybeRemoveStaticImports(staticImports, recipe, entry.getValue(), i, descriptor.afterTemplate); - List embedOptions = new ArrayList<>(); - JCTree.JCExpression afterReturn = getReturnExpression(descriptor.afterTemplate); - if (afterReturn instanceof JCTree.JCParens || - afterReturn instanceof JCTree.JCUnary && ((JCTree.JCUnary) afterReturn).getExpression() instanceof JCTree.JCParens) { - embedOptions.add("REMOVE_PARENS"); - } - // TODO check if after template contains type or member references - embedOptions.add("SHORTEN_NAMES"); - if (simplifyBooleans(descriptor.afterTemplate)) { - embedOptions.add("SIMPLIFY_BOOLEANS"); - } + List embedOptions = new ArrayList<>(); + JCTree.JCExpression afterReturn = getReturnExpression(descriptor.afterTemplate.method); + if (afterReturn instanceof JCTree.JCParens || + afterReturn instanceof JCTree.JCUnary && ((JCTree.JCUnary) afterReturn).getExpression() instanceof JCTree.JCParens) { + embedOptions.add("REMOVE_PARENS"); + } + // TODO check if after template contains type or member references + embedOptions.add("SHORTEN_NAMES"); + if (simplifyBooleans(descriptor.afterTemplate.method)) { + embedOptions.add("SIMPLIFY_BOOLEANS"); + } - recipe.append(" return embed(\n"); - recipe.append(" ").append(after).append(".apply(getCursor(), elem.getCoordinates().replace()"); - if (!parameters.isEmpty()) { - recipe.append(", ").append(parameters); + recipe.append(" return embed(\n"); + recipe.append(" ").append(after).append(".apply(getCursor(), elem.getCoordinates().replace()"); + if (!parameters.isEmpty()) { + recipe.append(", ").append(parameters); + } + recipe.append("),\n"); + recipe.append(" getCursor(),\n"); + recipe.append(" ctx,\n"); + recipe.append(" ").append(String.join(", ", embedOptions)).append("\n"); + recipe.append(" );\n"); + recipe.append(" }\n"); } - recipe.append("),\n"); - recipe.append(" getCursor(),\n"); - recipe.append(" ctx,\n"); - recipe.append(" ").append(String.join(", ", embedOptions)).append("\n"); - recipe.append(" );\n"); - recipe.append(" }\n"); } recipe.append(" return super.visit").append(methodSuffix).append("(elem, ctx);\n"); recipe.append(" }\n"); @@ -293,7 +301,7 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) { } recipe.append(" };\n"); - String preconditions = generatePreconditions(descriptor.beforeTemplates, imports, 16); + String preconditions = generatePreconditions(descriptor.beforeTemplates, 16); if (preconditions == null) { recipe.append(" return javaVisitor;\n"); } else { @@ -339,19 +347,6 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) { out.write("\n"); - if (!imports.isEmpty()) { - for (String anImport : imports.values().stream().flatMap(Set::stream).collect(Collectors.toSet())) { - out.write("import " + anImport + ";\n"); - } - out.write("\n"); - } - if (!staticImports.isEmpty()) { - for (String anImport : staticImports.values().stream().flatMap(Set::stream).collect(Collectors.toSet())) { - out.write("import static " + anImport + ";\n"); - } - out.write("\n"); - } - if (outerClassRequired) { out.write("/**\n * OpenRewrite recipes created for Refaster template {@code " + inputOuterFQN + "}.\n */\n"); String outerClassName = className.substring(className.lastIndexOf('.') + 1); @@ -393,22 +388,6 @@ public void visitClassDef(JCTree.JCClassDecl classDecl) { } } - private String toJavaTemplateBuilder(JCTree.JCMethodDecl methodDecl, - Map resolvedParameters) { - JCTree tree = methodDecl.getBody().getStatements().get(0); - if (tree instanceof JCTree.JCReturn) { - tree = ((JCTree.JCReturn) tree).getExpression(); - } - - List mappedParameters = methodDecl.getParameters().stream() - .map(resolvedParameters::get) - .map(JCTree.JCVariableDecl.class::cast) - .collect(Collectors.toList()); - - String javaTemplateBuilder = TemplateCode.process(tree, mappedParameters, true); - return TemplateCode.indent(javaTemplateBuilder, 16); - } - private boolean simplifyBooleans(JCTree.JCMethodDecl template) { if (template.getReturnType().type.getTag() == TypeTag.BOOLEAN) { return true; @@ -519,64 +498,55 @@ private String recipeDescriptor(JCTree.JCClassDecl classDecl, String defaultDisp return recipeDescriptor; } - private void maybeRemoveImports(Map> importsByTemplate, StringBuilder recipe, JCTree.JCMethodDecl beforeTemplate, JCTree.JCMethodDecl afterTemplate) { - Set beforeImports = getBeforeImportsAsStrings(importsByTemplate, beforeTemplate); + private void maybeRemoveImports(Map> importsByTemplate, StringBuilder recipe, TemplateDescriptor beforeTemplate, int pos, TemplateDescriptor afterTemplate) { + Set beforeImports = beforeTemplate.usedTypes(pos).stream().map(sym -> sym.fullname.toString()).collect(toCollection(LinkedHashSet::new)); beforeImports.removeAll(getImportsAsStrings(importsByTemplate, afterTemplate)); - beforeImports.removeIf(i -> i.startsWith("java.lang.")); + beforeImports.removeIf(i -> i.startsWith("java.lang.") || i.startsWith("com.google.errorprone.refaster.")); beforeImports.forEach(anImport -> recipe.append(" maybeRemoveImport(\"").append(anImport).append("\");\n")); } - private Set getBeforeImportsAsStrings(Map> importsByTemplate, JCTree.JCMethodDecl templateMethod) { - Set beforeImports = getImportsAsStrings(importsByTemplate, templateMethod); - for (JCTree.JCMethodDecl beforeTemplate : importsByTemplate.keySet()) { - // add fully qualified imports inside the template to the "before imports" set, - // since in the code that is being matched the type may not be fully qualified - new TreeScanner() { - @Override - public void scan(JCTree tree) { - if (tree instanceof JCTree.JCFieldAccess && - ((JCTree.JCFieldAccess) tree).sym instanceof Symbol.ClassSymbol) { - if (tree.toString().equals(((JCTree.JCFieldAccess) tree).sym.toString())) { - beforeImports.add(((JCTree.JCFieldAccess) tree).sym.toString()); - } - } - super.scan(tree); - } - }.scan(beforeTemplate.getBody()); - } - return beforeImports; + private void maybeRemoveStaticImports(Map> importsByTemplate, StringBuilder recipe, TemplateDescriptor beforeTemplate, int pos, TemplateDescriptor afterTemplate) { + Set beforeImports = beforeTemplate.usedMembers(pos).stream().map(symbol -> symbol.owner.getQualifiedName() + "." + symbol.name).collect(toCollection(LinkedHashSet::new)); + beforeImports.removeAll(getImportsAsStrings(importsByTemplate, afterTemplate)); + beforeImports.removeIf(i -> i.startsWith("java.lang.") || i.startsWith("com.google.errorprone.refaster.")); + beforeImports.forEach(anImport -> recipe.append(" maybeRemoveImport(\"").append(anImport).append("\");\n")); } - private Set getImportsAsStrings(Map> importsByTemplate, JCTree.JCMethodDecl templateMethod) { + private Set getImportsAsStrings(Map> importsByTemplate, TemplateDescriptor templateMethod) { return importsByTemplate.entrySet().stream() .filter(e -> templateMethod == e.getKey()) .map(Map.Entry::getValue) .flatMap(Set::stream) - .collect(Collectors.toSet()); + .collect(toSet()); } /* Generate the minimal precondition that would allow to match each before template individually. */ @SuppressWarnings("SameParameterValue") @Nullable - private String generatePreconditions(List beforeTemplates, - Map> imports, - int indent) { - Map> preconditions = new LinkedHashMap<>(); - for (JCTree.JCMethodDecl beforeTemplate : beforeTemplates) { - Set usesVisitors = new LinkedHashSet<>(); - - Set localImports = imports.getOrDefault(beforeTemplate, Collections.emptySet()); - for (String anImport : localImports) { - usesVisitors.add("new UsesType<>(\"" + anImport + "\", true)"); - } - List usedMethods = UsedMethodDetector.usedMethods(beforeTemplate); - for (Symbol.MethodSymbol method : usedMethods) { - String methodName = method.name.toString(); - methodName = methodName.equals("") ? "" : methodName; - usesVisitors.add("new UsesMethod<>(\"" + method.owner.getQualifiedName().toString() + ' ' + methodName + "(..)\")"); - } + private String generatePreconditions(List beforeTemplates, int indent) { + Map> preconditions = new LinkedHashMap<>(); + for (TemplateDescriptor beforeTemplate : beforeTemplates) { + int arity = beforeTemplate.getArity(); + for (int i = 0; i < arity; i++) { + Set usesVisitors = new LinkedHashSet<>(); + + for (Symbol.ClassSymbol usedType : beforeTemplate.usedTypes(i)) { + String name = usedType.getQualifiedName().toString().replace('$', '.'); + if (!name.startsWith("java.lang.") && !name.startsWith("com.google.errorprone.refaster.")) { + usesVisitors.add("new UsesType<>(\"" + name + "\", true)"); + } + } + for (Symbol.MethodSymbol method : beforeTemplate.usedMethods(i)) { + if (method.owner.getQualifiedName().toString().startsWith("com.google.errorprone.refaster.")) { + continue; + } + String methodName = method.name.toString(); + methodName = methodName.equals("") ? "" : methodName; + usesVisitors.add("new UsesMethod<>(\"" + method.owner.getQualifiedName().toString() + ' ' + methodName + "(..)\")"); + } - preconditions.put(beforeTemplate, usesVisitors); + preconditions.put(beforeTemplate.method.name.toString() + (arity == 1 ? "" : "$" + i), usesVisitors); + } } if (preconditions.size() == 1) { @@ -592,10 +562,10 @@ private String generatePreconditions(List beforeTemplates, preconditions.values().removeIf(Collection::isEmpty); if (common.isEmpty()) { - return joinPreconditions(preconditions.values().stream().map(v -> joinPreconditions(v, "and", indent + 4)).collect(Collectors.toList()), "or", indent + 4); + return joinPreconditions(preconditions.values().stream().map(v -> joinPreconditions(v, "and", indent + 4)).collect(toList()), "or", indent + 4); } else { if (!preconditions.isEmpty()) { - String uniqueConditions = joinPreconditions(preconditions.values().stream().map(v -> joinPreconditions(v, "and", indent + 12)).collect(Collectors.toList()), "or", indent + 8); + String uniqueConditions = joinPreconditions(preconditions.values().stream().map(v -> joinPreconditions(v, "and", indent + 12)).collect(toList()), "or", indent + 8); common.add(uniqueConditions); } return joinPreconditions(common, "and", indent + 4); @@ -622,7 +592,7 @@ private String escape(String string) { return string.replace("\\", "\\\\").replace("\"", "\\\"").replaceAll("\\R", "\\\\n"); } - private String parameters(TemplateDescriptor descriptor) { + private String parameters(RuleDescriptor descriptor) { List afterParams = new ArrayList<>(); Set seenParams = new HashSet<>(); new TreeScanner() { @@ -639,7 +609,7 @@ public void scan(JCTree jcTree) { } super.scan(jcTree); } - }.scan(descriptor.afterTemplate.body); + }.scan(descriptor.afterTemplate.method.body); StringJoiner joiner = new StringJoiner(", "); for (Integer param : afterParams) { @@ -665,8 +635,8 @@ private JCTree.JCExpression getReturnExpression(JCTree.JCMethodDecl method) { } @Nullable - private TemplateDescriptor getTemplateDescriptor(JCTree.JCClassDecl tree, Context context, JCCompilationUnit cu) { - TemplateDescriptor result = new TemplateDescriptor(tree); + private RuleDescriptor getRuleDescriptor(JCTree.JCClassDecl tree, Context context, JCCompilationUnit cu) { + RuleDescriptor result = new RuleDescriptor(tree, cu, context); for (JCTree member : tree.getMembers()) { if (member instanceof JCTree.JCMethodDecl) { JCTree.JCMethodDecl method = (JCTree.JCMethodDecl) member; @@ -680,21 +650,24 @@ private TemplateDescriptor getTemplateDescriptor(JCTree.JCClassDecl tree, Contex } } } - return result.validate(context, cu); + return result.validate(); } - class TemplateDescriptor { + class RuleDescriptor { final JCTree.JCClassDecl classDecl; - final List beforeTemplates = new ArrayList<>(); - JCTree.JCMethodDecl afterTemplate; - Map resolvedParameters = new IdentityHashMap<>(); + private final JCCompilationUnit cu; + private final Context context; + final List beforeTemplates = new ArrayList<>(); + TemplateDescriptor afterTemplate; - public TemplateDescriptor(JCTree.JCClassDecl classDecl) { + public RuleDescriptor(JCTree.JCClassDecl classDecl, JCCompilationUnit cu, Context context) { this.classDecl = classDecl; + this.cu = cu; + this.context = context; } @Nullable - private TemplateDescriptor validate(Context context, JCCompilationUnit cu) { + private RefasterTemplateProcessor.RuleDescriptor validate() { if (beforeTemplates.isEmpty() || afterTemplate == null) { return null; } @@ -705,7 +678,7 @@ private TemplateDescriptor validate(Context context, JCCompilationUnit cu) { } for (JCTree member : classDecl.getMembers()) { - if (member instanceof JCTree.JCMethodDecl && !beforeTemplates.contains(member) && member != afterTemplate) { + if (member instanceof JCTree.JCMethodDecl && beforeTemplates.stream().noneMatch(t -> t.method == member) && member != afterTemplate.method) { for (JCTree.JCAnnotation annotation : getTemplateAnnotations(((JCTree.JCMethodDecl) member), UNSUPPORTED_ANNOTATIONS::contains)) { printNoteOnce("@" + annotation.annotationType + " is currently not supported", classDecl.sym); return null; @@ -714,26 +687,138 @@ private TemplateDescriptor validate(Context context, JCCompilationUnit cu) { } // resolve so that we can inspect the template body - boolean valid = resolve(context, cu); + boolean valid = resolve(); if (valid) { - for (JCTree.JCMethodDecl template : beforeTemplates) { - valid = valid && validateTemplateMethod(template); + for (TemplateDescriptor template : beforeTemplates) { + valid = valid && template.validate(); } - valid = valid && validateTemplateMethod(afterTemplate); + valid = valid && afterTemplate.validate(); } return valid ? this : null; } - private boolean validateTemplateMethod(JCTree.JCMethodDecl template) { - if (template.typarams != null && !template.typarams.isEmpty()) { + public void beforeTemplate(JCTree.JCMethodDecl method) { + beforeTemplates.add(new TemplateDescriptor(method, classDecl, cu, context)); + } + + public void afterTemplate(JCTree.JCMethodDecl method) { + afterTemplate = new TemplateDescriptor(method, classDecl, cu, context); + } + + private boolean resolve() { + boolean valid = true; + try { + for (TemplateDescriptor beforeTemplate : beforeTemplates) { + valid &= beforeTemplate.resolve(); + } + valid &= afterTemplate.resolve(); + } catch (Throwable t) { + processingEnv.getMessager().printMessage(Kind.WARNING, "Had trouble type attributing the template."); + valid = false; + } + return valid; + } + } + + class TemplateDescriptor { + JCTree.JCMethodDecl method; + private final JCTree.JCClassDecl classDecl; + private final JCCompilationUnit cu; + private final Context context; + + public TemplateDescriptor(JCTree.JCMethodDecl method, JCTree.JCClassDecl classDecl, JCCompilationUnit cu, Context context) { + this.classDecl = classDecl; + this.method = method; + this.cu = cu; + this.context = context; + } + + public int getArity() { + AtomicReference anyOfCall = new AtomicReference<>(); + new TreeScanner() { + @Override + public void visitApply(JCTree.JCMethodInvocation jcMethodInvocation) { + if (isAnyOfCall(jcMethodInvocation)) { + anyOfCall.set(jcMethodInvocation); + return; + } + super.visitApply(jcMethodInvocation); + } + }.scan(method); + return Optional.ofNullable(anyOfCall.get()).map(call -> call.args.size()).orElse(1); + } + + private boolean isAnyOfCall(JCTree.JCMethodInvocation call) { + JCTree.JCExpression meth = call.meth; + if (meth instanceof JCTree.JCFieldAccess) { + JCTree.JCFieldAccess fieldAccess = (JCTree.JCFieldAccess) meth; + if (fieldAccess.name.toString().equals("anyOf") && ((JCTree.JCIdent) fieldAccess.selected).name.toString().equals("Refaster")) { + return true; + } + } + return false; + } + + private String toJavaTemplateBuilder() { + JCTree tree = method.getBody().getStatements().get(0); + if (tree instanceof JCTree.JCReturn) { + tree = ((JCTree.JCReturn) tree).getExpression(); + } + + String javaTemplateBuilder = TemplateCode.process(tree, method.getParameters(), true); + return TemplateCode.indent(javaTemplateBuilder, 16); + } + + private String toJavaTemplateBuilder(int pos) { + if (getArity() == 1) { + assert pos == 0; + return toJavaTemplateBuilder(); + } + + JCTree tree = method.getBody().getStatements().get(0); + if (tree instanceof JCTree.JCReturn) { + tree = ((JCTree.JCReturn) tree).getExpression(); + } + + AtomicReference original = new AtomicReference<>(); + new TreeScanner() { + @Override + public void visitApply(JCTree.JCMethodInvocation jcMethodInvocation) { + if (isAnyOfCall(jcMethodInvocation)) { + original.set(jcMethodInvocation.args.get(pos)); + return; + } + super.visitApply(jcMethodInvocation); + } + }.scan(tree); + + TreeCopier copier = new TreeCopier<>(TreeMaker.instance(context).forToplevel(cu)); + JCTree copied = copier.copy(tree); + JCTree translated = new TreeTranslator() { + @Override + public void visitApply(JCTree.JCMethodInvocation jcMethodInvocation) { + if (isAnyOfCall(jcMethodInvocation)) { + result = original.get(); + return; + } + super.visitApply(jcMethodInvocation); + } + }.translate(copied); + + String javaTemplateBuilder = TemplateCode.process(translated, method.getParameters(), true); + return TemplateCode.indent(javaTemplateBuilder, 16); + } + + boolean validate() { + if (method.typarams != null && !method.typarams.isEmpty()) { printNoteOnce("Generic type parameters are currently not supported", classDecl.sym); return false; } - for (JCTree.JCAnnotation annotation : getTemplateAnnotations(template, UNSUPPORTED_ANNOTATIONS::contains)) { + for (JCTree.JCAnnotation annotation : getTemplateAnnotations(method, UNSUPPORTED_ANNOTATIONS::contains)) { printNoteOnce("@" + annotation.annotationType + " is currently not supported", classDecl.sym); return false; } - for (JCTree.JCVariableDecl parameter : template.getParameters()) { + for (JCTree.JCVariableDecl parameter : method.getParameters()) { for (JCTree.JCAnnotation annotation : getTemplateAnnotations(parameter, UNSUPPORTED_ANNOTATIONS::contains)) { printNoteOnce("@" + annotation.annotationType + " is currently not supported", classDecl.sym); return false; @@ -743,16 +828,16 @@ private boolean validateTemplateMethod(JCTree.JCMethodDecl template) { return false; } } - if (template.restype.type instanceof Type.TypeVar) { + if (method.restype.type instanceof Type.TypeVar) { printNoteOnce("Generic type parameters are currently not supported", classDecl.sym); return false; } - if (template.body.stats.get(0) instanceof JCTree.JCIf) { + if (method.body.stats.get(0) instanceof JCTree.JCIf) { printNoteOnce("If statements are currently not supported", classDecl.sym); return false; } - if (template.body.stats.get(0) instanceof JCTree.JCReturn) { - JCTree.JCExpression expr = ((JCTree.JCReturn) template.body.stats.get(0)).expr; + if (method.body.stats.get(0) instanceof JCTree.JCReturn) { + JCTree.JCExpression expr = ((JCTree.JCReturn) method.body.stats.get(0)).expr; if (expr instanceof JCTree.JCLambda) { printNoteOnce("Lambdas are currently not supported", classDecl.sym); return false; @@ -763,12 +848,27 @@ private boolean validateTemplateMethod(JCTree.JCMethodDecl template) { } return new TreeScanner() { boolean valid = true; + int anyOfCount = 0; boolean validate(JCTree tree) { scan(tree); return valid; } + @Override + public void visitSelect(JCTree.JCFieldAccess jcFieldAccess) { + if (jcFieldAccess.selected.type.tsym.toString().equals("com.google.errorprone.refaster.Refaster") && + jcFieldAccess.name.toString().equals("anyOf")) { + // exception for `Refaster.anyOf()` + if (++anyOfCount > 1) { + printNoteOnce("Refaster.anyOf() can only be used once per template", classDecl.sym); + valid = false; + } + return; + } + super.visitSelect(jcFieldAccess); + } + @Override public void visitIdent(JCTree.JCIdent jcIdent) { if (valid @@ -778,47 +878,101 @@ public void visitIdent(JCTree.JCIdent jcIdent) { valid = false; } } - }.validate(template.getBody()); + }.validate(method.getBody()); } - public void beforeTemplate(JCTree.JCMethodDecl method) { - beforeTemplates.add(method); - } - - public void afterTemplate(JCTree.JCMethodDecl method) { - afterTemplate = method; + private boolean resolve() { + method = resolve(method); + return method != null; } - private boolean resolve(Context context, JCCompilationUnit cu) { + @Nullable + private JCTree.JCMethodDecl resolve(JCTree.JCMethodDecl method) { JavacResolution res = new JavacResolution(context); try { - // Resolve parameters - for (JCTree.JCMethodDecl beforeTemplate : beforeTemplates) { - if (!beforeTemplate.getParameters().isEmpty()) { - for (Map.Entry e : res.resolveAll(context, cu, beforeTemplate.getParameters()).entrySet()) { - if (e.getKey() instanceof JCTree.JCVariableDecl && e.getValue() instanceof JCTree.JCVariableDecl) { - resolvedParameters.put((JCTree.JCVariableDecl) e.getValue(), (JCTree.JCVariableDecl) e.getKey()); + classDecl.defs = classDecl.defs.prepend(method); + JCTree.JCMethodDecl resolvedMethod = (JCTree.JCMethodDecl) res.resolveAll(context, cu, singletonList(method)).get(method); + classDecl.defs = classDecl.defs.tail; + resolvedMethod.params = method.params; + method = resolvedMethod; + return method; + } catch (Throwable t) { + processingEnv.getMessager().printMessage(Kind.WARNING, "Had trouble type attributing the template method: " + method.name); + } + return null; + } + + public List usedTypes(int i) { + List imports; + if (getArity() == 1) { + imports = ImportDetector.imports(method); + } else { + Set skip = new HashSet<>(); + new TreeScanner() { + @Override + public void visitApply(JCTree.JCMethodInvocation jcMethodInvocation) { + if (isAnyOfCall(jcMethodInvocation)) { + for (int j = 0; j < jcMethodInvocation.args.size(); j++) { + if (j != i) { + skip.add(jcMethodInvocation.args.get(j)); + } } + return; } + super.visitApply(jcMethodInvocation); } - } - if (!afterTemplate.getParameters().isEmpty()) { - for (Map.Entry e : res.resolveAll(context, cu, afterTemplate.getParameters()).entrySet()) { - if (e.getKey() instanceof JCTree.JCVariableDecl && e.getValue() instanceof JCTree.JCVariableDecl) { - resolvedParameters.put((JCTree.JCVariableDecl) e.getValue(), (JCTree.JCVariableDecl) e.getKey()); + }.scan(method); + imports = ImportDetector.imports(method, t -> !skip.contains(t)); + } + return imports.stream().filter(Symbol.ClassSymbol.class::isInstance).map(Symbol.ClassSymbol.class::cast).collect(toList()); + } + + public List usedMembers(int i) { + List imports; + if (getArity() == 1) { + imports = ImportDetector.imports(method); + } else { + Set skip = new HashSet<>(); + new TreeScanner() { + @Override + public void visitApply(JCTree.JCMethodInvocation jcMethodInvocation) { + if (isAnyOfCall(jcMethodInvocation)) { + for (int j = 0; j < jcMethodInvocation.args.size(); j++) { + if (j != i) { + skip.add(jcMethodInvocation.args.get(j)); + } + } + return; } + super.visitApply(jcMethodInvocation); } - } + }.scan(method); + imports = ImportDetector.imports(method, t -> !skip.contains(t)); + } + return imports.stream().filter(sym -> sym instanceof Symbol.VarSymbol || sym instanceof Symbol.MethodSymbol).collect(toList()); + } - // Resolve templates - Map resolvedBeforeTemplates = res.resolveAll(context, cu, beforeTemplates); - beforeTemplates.replaceAll(key -> (JCTree.JCMethodDecl) resolvedBeforeTemplates.get(key)); - afterTemplate = (JCTree.JCMethodDecl) res.resolveAll(context, cu, singletonList(afterTemplate)).get(afterTemplate); - } catch (Throwable t) { - processingEnv.getMessager().printMessage(Kind.WARNING, "Had trouble type attributing the template."); - return false; + public List usedMethods(int i) { + if (getArity() == 1) { + return UsedMethodDetector.usedMethods(method); + } else { + Set skip = new HashSet<>(); + new TreeScanner() { + @Override + public void visitApply(JCTree.JCMethodInvocation jcMethodInvocation) { + if (isAnyOfCall(jcMethodInvocation)) { + for (int j = 0; j < jcMethodInvocation.args.size(); j++) { + if (j != i) { + skip.add(jcMethodInvocation.args.get(j)); + } + } + return; + } + super.visitApply(jcMethodInvocation); + } + }.scan(method); + return UsedMethodDetector.usedMethods(method, t -> !skip.contains(t)); } - return true; } } diff --git a/src/test/java/org/openrewrite/java/template/RefasterTemplateProcessorTest.java b/src/test/java/org/openrewrite/java/template/RefasterTemplateProcessorTest.java index 9cfe0f5d..18f584ab 100644 --- a/src/test/java/org/openrewrite/java/template/RefasterTemplateProcessorTest.java +++ b/src/test/java/org/openrewrite/java/template/RefasterTemplateProcessorTest.java @@ -68,7 +68,6 @@ void generateRecipeInDefaultPackage() { @ParameterizedTest @ValueSource(strings = { "OrElseGetGet", - "RefasterAnyOf", }) void skipRecipeGeneration(String recipeName) { Compilation compilation = compileResource("refaster/" + recipeName + ".java"); @@ -85,6 +84,7 @@ void skipRecipeGeneration(String recipeName) { "ShouldAddImports", "ShouldSupportNestedClasses", "SimplifyTernary", + "RefasterAnyOf", }) void nestedRecipes(String recipeName) { Compilation compilation = compileResource("refaster/" + recipeName + ".java"); diff --git a/src/test/resources/refaster/ArraysRecipe.java b/src/test/resources/refaster/ArraysRecipe.java index f4283a73..4a3434ec 100644 --- a/src/test/resources/refaster/ArraysRecipe.java +++ b/src/test/resources/refaster/ArraysRecipe.java @@ -61,10 +61,10 @@ public String getDescription() { public TreeVisitor getVisitor() { JavaVisitor javaVisitor = new AbstractRefasterJavaVisitor() { final JavaTemplate before = JavaTemplate - .builder("String.join(\", \", #{strings:any(java.lang.String[])})") + .builder("String.join(\", \", #{strings:anyArray(java.lang.String)})") .build(); final JavaTemplate after = JavaTemplate - .builder("String.join(\":\", #{strings:any(java.lang.String[])})") + .builder("String.join(\":\", #{strings:anyArray(java.lang.String)})") .build(); @Override diff --git a/src/test/resources/refaster/EscapesRecipes.java b/src/test/resources/refaster/EscapesRecipes.java index f66fc224..599ec413 100644 --- a/src/test/resources/refaster/EscapesRecipes.java +++ b/src/test/resources/refaster/EscapesRecipes.java @@ -25,7 +25,6 @@ import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.search.*; import org.openrewrite.java.template.Primitive; - import org.openrewrite.java.template.function.*; import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor; import org.openrewrite.java.tree.*; @@ -34,9 +33,6 @@ import static org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor.EmbeddingOption.*; -import com.sun.tools.javac.util.Convert; -import com.sun.tools.javac.util.Constants; - /** * OpenRewrite recipes created for Refaster template {@code foo.Escapes}. */ diff --git a/src/test/resources/refaster/GenericsRecipes.java b/src/test/resources/refaster/GenericsRecipes.java index bd5153fb..29560fab 100644 --- a/src/test/resources/refaster/GenericsRecipes.java +++ b/src/test/resources/refaster/GenericsRecipes.java @@ -25,7 +25,6 @@ import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.search.*; import org.openrewrite.java.template.Primitive; - import org.openrewrite.java.template.function.*; import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor; import org.openrewrite.java.tree.*; @@ -34,8 +33,6 @@ import static org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor.EmbeddingOption.*; -import java.util.List; - /** * OpenRewrite recipes created for Refaster template {@code foo.Generics}. */ diff --git a/src/test/resources/refaster/MethodThrowsRecipe.java b/src/test/resources/refaster/MethodThrowsRecipe.java index c9c1171a..ede66c67 100644 --- a/src/test/resources/refaster/MethodThrowsRecipe.java +++ b/src/test/resources/refaster/MethodThrowsRecipe.java @@ -25,7 +25,6 @@ import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.search.*; import org.openrewrite.java.template.Primitive; - import org.openrewrite.java.template.function.*; import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor; import org.openrewrite.java.tree.*; @@ -34,10 +33,6 @@ import static org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor.EmbeddingOption.*; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.charset.StandardCharsets; - /** * OpenRewrite recipe created for Refaster template {@code MethodThrows}. */ @@ -88,8 +83,8 @@ public J visitMethodInvocation(J.MethodInvocation elem, ExecutionContext ctx) { }; return Preconditions.check( Preconditions.and( - new UsesType<>("java.nio.charset.StandardCharsets", true), new UsesType<>("java.nio.file.Files", true), + new UsesType<>("java.nio.charset.StandardCharsets", true), new UsesType<>("java.nio.file.Path", true), new UsesMethod<>("java.nio.file.Files readAllLines(..)") ), diff --git a/src/test/resources/refaster/MultipleDereferencesRecipes.java b/src/test/resources/refaster/MultipleDereferencesRecipes.java index 3eb865f1..5e71f8d3 100644 --- a/src/test/resources/refaster/MultipleDereferencesRecipes.java +++ b/src/test/resources/refaster/MultipleDereferencesRecipes.java @@ -25,7 +25,6 @@ import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.search.*; import org.openrewrite.java.template.Primitive; - import org.openrewrite.java.template.function.*; import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor; import org.openrewrite.java.tree.*; @@ -34,9 +33,6 @@ import static org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor.EmbeddingOption.*; -import java.nio.file.Files; -import java.nio.file.Path; - /** * OpenRewrite recipes created for Refaster template {@code foo.MultipleDereferences}. */ diff --git a/src/test/resources/refaster/NestedPreconditionsRecipe.java b/src/test/resources/refaster/NestedPreconditionsRecipe.java index 8f9835f2..0a8385d2 100644 --- a/src/test/resources/refaster/NestedPreconditionsRecipe.java +++ b/src/test/resources/refaster/NestedPreconditionsRecipe.java @@ -25,7 +25,6 @@ import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.search.*; import org.openrewrite.java.template.Primitive; - import org.openrewrite.java.template.function.*; import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor; import org.openrewrite.java.tree.*; @@ -34,11 +33,6 @@ import static org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor.EmbeddingOption.*; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.HashMap; -import java.util.Hashtable; - /** * OpenRewrite recipe created for Refaster template {@code NestedPreconditions}. */ diff --git a/src/test/resources/refaster/RefasterAnyOf.java b/src/test/resources/refaster/RefasterAnyOf.java index dca96c95..26e978b3 100644 --- a/src/test/resources/refaster/RefasterAnyOf.java +++ b/src/test/resources/refaster/RefasterAnyOf.java @@ -18,21 +18,44 @@ import com.google.errorprone.refaster.Refaster; import com.google.errorprone.refaster.annotation.AfterTemplate; import com.google.errorprone.refaster.annotation.BeforeTemplate; -import org.openrewrite.java.template.RecipeDescriptor; -@RecipeDescriptor( - name = "Use `String.isEmpty()`", - description = "Use `String#isEmpty()` instead of String length comparison.", - tags = {"sast", "strings"} -) +import java.util.LinkedList; +import java.util.List; + public class RefasterAnyOf { - @BeforeTemplate - boolean before(String s) { - return Refaster.anyOf(s.length() < 1, s.length() == 0); + public static class StringIsEmpty { + @BeforeTemplate + boolean before(String s) { + return Refaster.anyOf(s.length() < 1, s.length() == 0); + } + + @AfterTemplate + boolean after(String s) { + return s.isEmpty(); + } + } + + public static class EmptyList { + @BeforeTemplate + List before() { + return Refaster.anyOf(new LinkedList(), java.util.Collections.emptyList()); + } + + @AfterTemplate + List after() { + return new java.util.ArrayList(); + } } - @AfterTemplate - boolean after(String s) { - return s.isEmpty(); + public static class NewStringFromCharArraySubSequence { + @BeforeTemplate + String before(char[] data, int offset, int count) { + return Refaster.anyOf(String.valueOf(data, offset, count), String.copyValueOf(data, offset, count)); + } + + @AfterTemplate + String after(char[] data, int offset, int count) { + return new String(data, offset, count); + } } } diff --git a/src/test/resources/refaster/RefasterAnyOfRecipes.java b/src/test/resources/refaster/RefasterAnyOfRecipes.java new file mode 100644 index 00000000..0890cbd3 --- /dev/null +++ b/src/test/resources/refaster/RefasterAnyOfRecipes.java @@ -0,0 +1,277 @@ +/* + * Copyright 2023 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package foo; + +import org.openrewrite.ExecutionContext; +import org.openrewrite.Preconditions; +import org.openrewrite.Recipe; +import org.openrewrite.TreeVisitor; +import org.openrewrite.internal.lang.NonNullApi; +import org.openrewrite.java.JavaParser; +import org.openrewrite.java.JavaTemplate; +import org.openrewrite.java.JavaVisitor; +import org.openrewrite.java.search.*; +import org.openrewrite.java.template.Primitive; +import org.openrewrite.java.template.function.*; +import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor; +import org.openrewrite.java.tree.*; + +import java.util.*; + +import static org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor.EmbeddingOption.*; + +/** + * OpenRewrite recipes created for Refaster template {@code foo.RefasterAnyOf}. + */ +@SuppressWarnings("all") +public class RefasterAnyOfRecipes extends Recipe { + /** + * Instantiates a new instance. + */ + public RefasterAnyOfRecipes() {} + + @Override + public String getDisplayName() { + return "`RefasterAnyOf` Refaster recipes"; + } + + @Override + public String getDescription() { + return "Refaster template recipes for `foo.RefasterAnyOf`."; + } + + @Override + public List getRecipeList() { + return Arrays.asList( + new StringIsEmptyRecipe(), + new EmptyListRecipe(), + new NewStringFromCharArraySubSequenceRecipe() + ); + } + + /** + * OpenRewrite recipe created for Refaster template {@code RefasterAnyOf.StringIsEmpty}. + */ + @SuppressWarnings("all") + @NonNullApi + public static class StringIsEmptyRecipe extends Recipe { + + /** + * Instantiates a new instance. + */ + public StringIsEmptyRecipe() {} + + @Override + public String getDisplayName() { + return "Refaster template `RefasterAnyOf.StringIsEmpty`"; + } + + @Override + public String getDescription() { + return "Recipe created for the following Refaster template:\n```java\npublic static class StringIsEmpty {\n \n @BeforeTemplate()\n boolean before(String s) {\n return Refaster.anyOf(s.length() < 1, s.length() == 0);\n }\n \n @AfterTemplate()\n boolean after(String s) {\n return s.isEmpty();\n }\n}\n```\n."; + } + + @Override + public TreeVisitor getVisitor() { + JavaVisitor javaVisitor = new AbstractRefasterJavaVisitor() { + final JavaTemplate before$0 = JavaTemplate + .builder("#{s:any(java.lang.String)}.length() < 1") + .build(); + final JavaTemplate before$1 = JavaTemplate + .builder("#{s:any(java.lang.String)}.length() == 0") + .build(); + final JavaTemplate after = JavaTemplate + .builder("#{s:any(java.lang.String)}.isEmpty()") + .build(); + + @Override + public J visitMethodInvocation(J.MethodInvocation elem, ExecutionContext ctx) { + JavaTemplate.Matcher matcher; + if ((matcher = before$0.matcher(getCursor())).find()) { + return embed( + after.apply(getCursor(), elem.getCoordinates().replace(), matcher.parameter(0)), + getCursor(), + ctx, + SHORTEN_NAMES, SIMPLIFY_BOOLEANS + ); + } + if ((matcher = before$1.matcher(getCursor())).find()) { + return embed( + after.apply(getCursor(), elem.getCoordinates().replace(), matcher.parameter(0)), + getCursor(), + ctx, + SHORTEN_NAMES, SIMPLIFY_BOOLEANS + ); + } + return super.visitMethodInvocation(elem, ctx); + } + + }; + return Preconditions.check( + new UsesMethod<>("java.lang.String length(..)"), + javaVisitor + ); + } + } + + /** + * OpenRewrite recipe created for Refaster template {@code RefasterAnyOf.EmptyList}. + */ + @SuppressWarnings("all") + @NonNullApi + public static class EmptyListRecipe extends Recipe { + + /** + * Instantiates a new instance. + */ + public EmptyListRecipe() {} + + @Override + public String getDisplayName() { + return "Refaster template `RefasterAnyOf.EmptyList`"; + } + + @Override + public String getDescription() { + return "Recipe created for the following Refaster template:\n```java\npublic static class EmptyList {\n \n @BeforeTemplate()\n List before() {\n return Refaster.anyOf(new LinkedList(), java.util.Collections.emptyList());\n }\n \n @AfterTemplate()\n List after() {\n return new java.util.ArrayList();\n }\n}\n```\n."; + } + + @Override + public TreeVisitor getVisitor() { + JavaVisitor javaVisitor = new AbstractRefasterJavaVisitor() { + final JavaTemplate before$0 = JavaTemplate + .builder("new java.util.LinkedList()") + .build(); + final JavaTemplate before$1 = JavaTemplate + .builder("java.util.Collections.emptyList()") + .build(); + final JavaTemplate after = JavaTemplate + .builder("new java.util.ArrayList()") + .build(); + + @Override + public J visitMethodInvocation(J.MethodInvocation elem, ExecutionContext ctx) { + JavaTemplate.Matcher matcher; + if ((matcher = before$0.matcher(getCursor())).find()) { + maybeRemoveImport("java.util.LinkedList"); + return embed( + after.apply(getCursor(), elem.getCoordinates().replace()), + getCursor(), + ctx, + SHORTEN_NAMES + ); + } + if ((matcher = before$1.matcher(getCursor())).find()) { + maybeRemoveImport("java.util.Collections"); + return embed( + after.apply(getCursor(), elem.getCoordinates().replace()), + getCursor(), + ctx, + SHORTEN_NAMES + ); + } + return super.visitMethodInvocation(elem, ctx); + } + + }; + return Preconditions.check( + Preconditions.and( + new UsesType<>("java.util.List", true), + Preconditions.or( + Preconditions.and( + new UsesType<>("java.util.LinkedList", true), + new UsesMethod<>("java.util.LinkedList (..)") + ), + Preconditions.and( + new UsesType<>("java.util.Collections", true), + new UsesMethod<>("java.util.Collections emptyList(..)") + ) + ) + ), + javaVisitor + ); + } + } + + /** + * OpenRewrite recipe created for Refaster template {@code RefasterAnyOf.NewStringFromCharArraySubSequence}. + */ + @SuppressWarnings("all") + @NonNullApi + public static class NewStringFromCharArraySubSequenceRecipe extends Recipe { + + /** + * Instantiates a new instance. + */ + public NewStringFromCharArraySubSequenceRecipe() {} + + @Override + public String getDisplayName() { + return "Refaster template `RefasterAnyOf.NewStringFromCharArraySubSequence`"; + } + + @Override + public String getDescription() { + return "Recipe created for the following Refaster template:\n```java\npublic static class NewStringFromCharArraySubSequence {\n \n @BeforeTemplate()\n String before(char[] data, int offset, int count) {\n return Refaster.anyOf(String.valueOf(data, offset, count), String.copyValueOf(data, offset, count));\n }\n \n @AfterTemplate()\n String after(char[] data, int offset, int count) {\n return new String(data, offset, count);\n }\n}\n```\n."; + } + + @Override + public TreeVisitor getVisitor() { + JavaVisitor javaVisitor = new AbstractRefasterJavaVisitor() { + final JavaTemplate before$0 = JavaTemplate + .builder("String.valueOf(#{data:anyArray(char)}, #{offset:any(int)}, #{count:any(int)})") + .build(); + final JavaTemplate before$1 = JavaTemplate + .builder("String.copyValueOf(#{data:anyArray(char)}, #{offset:any(int)}, #{count:any(int)})") + .build(); + final JavaTemplate after = JavaTemplate + .builder("new String(#{data:anyArray(char)}, #{offset:any(int)}, #{count:any(int)})") + .build(); + + @Override + public J visitMethodInvocation(J.MethodInvocation elem, ExecutionContext ctx) { + JavaTemplate.Matcher matcher; + if ((matcher = before$0.matcher(getCursor())).find()) { + return embed( + after.apply(getCursor(), elem.getCoordinates().replace(), matcher.parameter(0), matcher.parameter(1), matcher.parameter(2)), + getCursor(), + ctx, + SHORTEN_NAMES + ); + } + if ((matcher = before$1.matcher(getCursor())).find()) { + return embed( + after.apply(getCursor(), elem.getCoordinates().replace(), matcher.parameter(0), matcher.parameter(1), matcher.parameter(2)), + getCursor(), + ctx, + SHORTEN_NAMES + ); + } + return super.visitMethodInvocation(elem, ctx); + } + + }; + return Preconditions.check( + Preconditions.or( + new UsesMethod<>("java.lang.String valueOf(..)"), + new UsesMethod<>("java.lang.String copyValueOf(..)") + ), + javaVisitor + ); + } + } + +} diff --git a/src/test/resources/refaster/ShouldAddImportsRecipes.java b/src/test/resources/refaster/ShouldAddImportsRecipes.java index 904843f3..c6a068b4 100644 --- a/src/test/resources/refaster/ShouldAddImportsRecipes.java +++ b/src/test/resources/refaster/ShouldAddImportsRecipes.java @@ -25,7 +25,6 @@ import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.search.*; import org.openrewrite.java.template.Primitive; - import org.openrewrite.java.template.function.*; import org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor; import org.openrewrite.java.tree.*; @@ -34,12 +33,6 @@ import static org.openrewrite.java.template.internal.AbstractRefasterJavaVisitor.EmbeddingOption.*; -import java.util.Objects; -import java.nio.file.Path; - -import static java.nio.file.Files.exists; -import static java.util.Objects.hash; - /** * OpenRewrite recipes created for Refaster template {@code foo.ShouldAddImports}. */