diff --git a/src/main/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatements.java b/src/main/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatements.java index 7ec17a38a..d48b2cc2b 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatements.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatements.java @@ -100,8 +100,13 @@ private boolean isGroupableAssertion(J.MethodInvocation assertion) { // Only match method invocations where the select is an assertThat, containing a non-method call argument if (ASSERT_THAT.matches(assertion.getSelect())) { J.MethodInvocation assertThat = (J.MethodInvocation) assertion.getSelect(); - if (assertThat != null && !(assertThat.getArguments().get(0) instanceof MethodCall)) { - return TypeUtils.isOfType(assertThat.getType(), assertion.getType()); + if (assertThat != null) { + Expression assertThatArgument = assertThat.getArguments().get(0); + if (!(assertThatArgument instanceof MethodCall)) { + JavaType assertThatType = assertThat.getType(); + JavaType assertionType = assertion.getType(); + return TypeUtils.isOfType(assertThatType, assertionType); + } } } return false; diff --git a/src/test/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatementsTest.java b/src/test/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatementsTest.java index a5faa2337..d740f2fb8 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatementsTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatementsTest.java @@ -15,6 +15,7 @@ */ package org.openrewrite.java.testing.assertj; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.InMemoryExecutionContext; @@ -192,6 +193,37 @@ private int[] notification() { ); } + @Disabled("Not yet implemented") + @Test + void collapseAssertThatsOnInteger() { + //language=java + rewriteRun( + java( + """ + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void test(Integer i) { + assertThat(i).isNotNull(); + assertThat(i).isEqualTo(2); + } + } + """, + """ + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void test(Integer i) { + assertThat(i) + .isNotNull() + .isEqualTo(2); + } + } + """ + ) + ); + } + @Test void ignoreIfAssertThatOnDifferentVariables() { //language=java