From d81b2b8041e4b7ac01a480eaaa0c283066839686 Mon Sep 17 00:00:00 2001 From: Ronald Holshausen Date: Wed, 16 Jun 2021 15:59:32 +1000 Subject: [PATCH] feat: added support for WebFlux with Spring JUnit 5 tests #1373 --- provider/build.gradle | 2 +- .../dius/pact/provider/junit5/TestTarget.kt | 1 - provider/junit5spring/build.gradle | 1 + .../spring/junit5/MockMvcTestTarget.kt | 2 - .../junit5/PactVerificationSpringExtension.kt | 3 + .../provider/spring/junit5/WebFluxTarget.kt | 126 ++++++++++++++++++ .../spring/junit5/WebFluxTargetSpec.groovy | 82 ++++++++++++ .../dius/pact/provider/ProviderVerifier.kt | 62 +++++---- 8 files changed, 250 insertions(+), 29 deletions(-) create mode 100644 provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/WebFluxTarget.kt create mode 100644 provider/junit5spring/src/test/groovy/au/com/dius/pact/provider/spring/junit5/WebFluxTargetSpec.groovy diff --git a/provider/build.gradle b/provider/build.gradle index ba59f63242..c99f255d1f 100644 --- a/provider/build.gradle +++ b/provider/build.gradle @@ -11,7 +11,7 @@ dependencies { "org.apache.httpcomponents:httpclient:${project.httpClientVersion}" implementation "org.slf4j:slf4j-api:${project.slf4jVersion}" implementation "org.scala-lang:scala-library:${project.scalaVersion}" - implementation 'io.github.classgraph:classgraph:4.8.105' + api 'io.github.classgraph:classgraph:4.8.105' implementation "org.codehaus.groovy:groovy:${project.groovyVersion}" api 'com.michael-bull.kotlin-result:kotlin-result:1.1.6' implementation 'com.github.ajalt:mordant:1.2.1' diff --git a/provider/junit5/src/main/kotlin/au/com/dius/pact/provider/junit5/TestTarget.kt b/provider/junit5/src/main/kotlin/au/com/dius/pact/provider/junit5/TestTarget.kt index c6a7ce3a80..8c33c87888 100644 --- a/provider/junit5/src/main/kotlin/au/com/dius/pact/provider/junit5/TestTarget.kt +++ b/provider/junit5/src/main/kotlin/au/com/dius/pact/provider/junit5/TestTarget.kt @@ -4,7 +4,6 @@ import au.com.dius.pact.core.model.DirectorySource import au.com.dius.pact.core.model.Interaction import au.com.dius.pact.core.model.PactBrokerSource import au.com.dius.pact.core.model.PactSource -import au.com.dius.pact.core.model.RequestResponseInteraction import au.com.dius.pact.core.model.SynchronousRequestResponse import au.com.dius.pact.core.model.generators.GeneratorTestMode import au.com.dius.pact.core.model.messaging.Message diff --git a/provider/junit5spring/build.gradle b/provider/junit5spring/build.gradle index 949b12a2d2..085ced5c38 100644 --- a/provider/junit5spring/build.gradle +++ b/provider/junit5spring/build.gradle @@ -3,6 +3,7 @@ dependencies { api 'org.springframework:spring-context:5.2.3.RELEASE' api 'org.springframework:spring-test:5.2.3.RELEASE' api 'org.springframework:spring-web:5.2.3.RELEASE' + api 'org.springframework:spring-webflux:5.2.3.RELEASE' api 'javax.servlet:javax.servlet-api:3.1.0' api 'org.hamcrest:hamcrest:2.1' diff --git a/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/MockMvcTestTarget.kt b/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/MockMvcTestTarget.kt index 749ee81919..f9a3675d5b 100644 --- a/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/MockMvcTestTarget.kt +++ b/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/MockMvcTestTarget.kt @@ -4,8 +4,6 @@ import au.com.dius.pact.core.model.ContentType import au.com.dius.pact.core.model.IRequest import au.com.dius.pact.core.model.Interaction import au.com.dius.pact.core.model.PactSource -import au.com.dius.pact.core.model.Request -import au.com.dius.pact.core.model.RequestResponseInteraction import au.com.dius.pact.core.model.SynchronousRequestResponse import au.com.dius.pact.core.model.generators.GeneratorTestMode import au.com.dius.pact.provider.IProviderVerifier diff --git a/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/PactVerificationSpringExtension.kt b/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/PactVerificationSpringExtension.kt index cadf997daf..2d60c55dcc 100644 --- a/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/PactVerificationSpringExtension.kt +++ b/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/PactVerificationSpringExtension.kt @@ -7,6 +7,7 @@ import au.com.dius.pact.provider.junit5.PactVerificationContext import au.com.dius.pact.provider.junit5.PactVerificationExtension import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.api.extension.ParameterContext +import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder open class PactVerificationSpringExtension( @@ -24,6 +25,7 @@ open class PactVerificationSpringExtension( val testContext = store.get("interactionContext") as PactVerificationContext return when (parameterContext.parameter.type) { MockHttpServletRequestBuilder::class.java -> testContext.target is MockMvcTestTarget + WebTestClient.RequestHeadersSpec::class.java -> testContext.target is WebFluxTarget else -> super.supportsParameter(parameterContext, extensionContext) } } @@ -32,6 +34,7 @@ open class PactVerificationSpringExtension( val store = extensionContext.getStore(ExtensionContext.Namespace.create("pact-jvm")) return when (parameterContext.parameter.type) { MockHttpServletRequestBuilder::class.java -> store.get("request") + WebTestClient.RequestHeadersSpec::class.java -> store.get("request") else -> super.resolveParameter(parameterContext, extensionContext) } } diff --git a/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/WebFluxTarget.kt b/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/WebFluxTarget.kt new file mode 100644 index 0000000000..4fd6b89020 --- /dev/null +++ b/provider/junit5spring/src/main/kotlin/au/com/dius/pact/provider/spring/junit5/WebFluxTarget.kt @@ -0,0 +1,126 @@ +package au.com.dius.pact.provider.spring.junit5 + +import au.com.dius.pact.core.model.ContentType +import au.com.dius.pact.core.model.IRequest +import au.com.dius.pact.core.model.Interaction +import au.com.dius.pact.core.model.PactSource +import au.com.dius.pact.core.model.SynchronousRequestResponse +import au.com.dius.pact.core.model.generators.GeneratorTestMode +import au.com.dius.pact.provider.IProviderVerifier +import au.com.dius.pact.provider.ProviderInfo +import au.com.dius.pact.provider.ProviderResponse +import au.com.dius.pact.provider.junit5.TestTarget +import org.apache.commons.lang3.StringUtils +import org.springframework.http.HttpHeaders +import org.springframework.http.HttpMethod +import org.springframework.http.MediaType +import org.springframework.http.client.MultipartBodyBuilder +import org.springframework.test.web.reactive.server.WebTestClient +import org.springframework.web.reactive.function.BodyInserters +import org.springframework.web.reactive.function.server.RouterFunction +import org.springframework.web.util.UriComponentsBuilder +import javax.mail.internet.ContentDisposition +import javax.mail.internet.MimeMultipart +import javax.mail.util.ByteArrayDataSource + +class WebFluxTarget(private val routerFunction: RouterFunction<*>) : TestTarget { + override fun getProviderInfo(serviceName: String, pactSource: PactSource?) = ProviderInfo(serviceName) + + override fun prepareRequest(interaction: Interaction, context: MutableMap): Pair, WebTestClient> { + if (interaction is SynchronousRequestResponse) { + val request = interaction.request.generatedRequest(context, GeneratorTestMode.Provider) + val webClient = WebTestClient.bindToRouterFunction(routerFunction).build() + return toWebFluxRequestBuilder(webClient, request) to webClient + } + throw UnsupportedOperationException("Only request/response interactions can be used with an MockMvc test target") + } + + private fun toWebFluxRequestBuilder(webClient: WebTestClient, request: IRequest): WebTestClient.RequestHeadersSpec<*> { + return if (request.body.isPresent()) { + if (request.isMultipartFileUpload()) { + val multipart = MimeMultipart(ByteArrayDataSource(request.body.unwrap(), request.contentTypeHeader())) + + val bodyBuilder = MultipartBodyBuilder() + var i = 0 + while (i < multipart.count) { + val bodyPart = multipart.getBodyPart(i) + val contentDisposition = ContentDisposition(bodyPart.getHeader("Content-Disposition").first()) + val name = StringUtils.defaultString(contentDisposition.getParameter("name"), "file") + val filename = contentDisposition.getParameter("filename").orEmpty() + + bodyBuilder + .part(name, bodyPart.content) + .filename(filename) + .contentType(MediaType.valueOf(bodyPart.contentType)) + .header("Content-Disposition", "form-data; name=$name; filename=$filename") + + i++ + } + + webClient + .method(HttpMethod.POST) + .uri(requestUriString(request)) + .body(BodyInserters.fromMultipartData(bodyBuilder.build())) + .headers { request.headers.forEach { (k, v) -> it.addAll(k, v) } } + } else { + webClient + .method(HttpMethod.valueOf(request.method)) + .uri(requestUriString(request)) + .bodyValue(request.body.value!!) + .headers { + request.headers.forEach { (k, v) -> it.addAll(k, v) } + if (!request.headers.containsKey(HttpHeaders.CONTENT_TYPE)) { + it.set(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + } + } + } + } else { + webClient + .method(HttpMethod.valueOf(request.method)) + .uri(requestUriString(request)) + .headers { + request.headers.forEach { (k, v) -> it.addAll(k, v) } + } + } + } + + private fun requestUriString(request: IRequest): String { + val uriBuilder = UriComponentsBuilder.fromPath(request.path) + + request.query.forEach { (key, value) -> + uriBuilder.queryParam(key, value) + } + + return uriBuilder.toUriString() + } + + override fun isHttpTarget() = true + + override fun executeInteraction(client: Any?, request: Any?): ProviderResponse { + val requestBuilder = request as WebTestClient.RequestHeadersSpec<*> + val exchangeResult = requestBuilder.exchange().expectBody().returnResult() + + val headers = mutableMapOf>() + exchangeResult.responseHeaders.forEach { header -> + headers[header.key] = header.value + } + + val contentTypeHeader = exchangeResult.responseHeaders.contentType + val contentType = if (contentTypeHeader == null) { + ContentType.JSON + } else { + ContentType.fromString(contentTypeHeader.toString()) + } + + return ProviderResponse( + exchangeResult.status.value(), + headers, + contentType, + exchangeResult.responseBody?.let { String(it) } + ) + } + + override fun prepareVerifier(verifier: IProviderVerifier, testInstance: Any) { + /* NO-OP */ + } +} diff --git a/provider/junit5spring/src/test/groovy/au/com/dius/pact/provider/spring/junit5/WebFluxTargetSpec.groovy b/provider/junit5spring/src/test/groovy/au/com/dius/pact/provider/spring/junit5/WebFluxTargetSpec.groovy new file mode 100644 index 0000000000..0d6dab5826 --- /dev/null +++ b/provider/junit5spring/src/test/groovy/au/com/dius/pact/provider/spring/junit5/WebFluxTargetSpec.groovy @@ -0,0 +1,82 @@ +package au.com.dius.pact.provider.spring.junit5 + +import au.com.dius.pact.core.model.OptionalBody +import au.com.dius.pact.core.model.Request +import au.com.dius.pact.core.model.RequestResponseInteraction +import org.springframework.http.MediaType +import org.springframework.test.web.reactive.server.WebTestClient +import org.springframework.web.reactive.function.BodyInserters +import org.springframework.web.reactive.function.server.RequestPredicates +import org.springframework.web.reactive.function.server.RouterFunction +import org.springframework.web.reactive.function.server.RouterFunctions +import org.springframework.web.reactive.function.server.ServerResponse +import spock.lang.Specification + +import java.nio.charset.StandardCharsets + +@SuppressWarnings('ClosureAsLastMethodParameter') +class WebFluxTargetSpec extends Specification { + RouterFunction routerFunction = RouterFunctions.route(RequestPredicates.GET('/data'), { req -> + ServerResponse.ok().contentType(MediaType.APPLICATION_JSON) + .body(BodyInserters.fromValue('{"id":1234}')) + }) + + def 'should prepare get request'() { + given: + WebFluxTarget webFluxTarget = new WebFluxTarget(routerFunction) + def request = new Request('GET', '/data', [id: ['1234']]) + def interaction = new RequestResponseInteraction('some description', [], request) + + when: + def requestAndClient = webFluxTarget.prepareRequest(interaction, [:]) + def requestBuilder = requestAndClient.first + def builtRequest = requestBuilder.exchange().expectBody().returnResult() + + then: + requestBuilder instanceof WebTestClient.RequestHeadersSpec + builtRequest.url.path == '/data' + builtRequest.method.toString() == 'GET' + new String(builtRequest.responseBody) == '{"id":1234}' + } + + def 'should prepare post request'() { + given: + RouterFunction postRouterFunction = RouterFunctions.route(RequestPredicates.POST('/data'), { req -> + assert req.queryParams() == [id: ['1234']] + def reqBody = req.bodyToMono(String).doOnNext({ s -> assert s == '{"foo":"bar"}' }) + ServerResponse.ok().build(reqBody) + }) + WebFluxTarget webFluxTarget = new WebFluxTarget(postRouterFunction) + def request = new Request('POST', '/data', [id: ['1234']], [:], + OptionalBody.body('{"foo":"bar"}'.getBytes(StandardCharsets.UTF_8))) + def interaction = new RequestResponseInteraction('some description', [], request) + + when: + def requestAndClient = webFluxTarget.prepareRequest(interaction, [:]) + def requestBuilder = requestAndClient.first + + then: + requestBuilder instanceof WebTestClient.RequestHeadersSpec + def builtRequest = requestBuilder.exchange().expectBody().returnResult() + builtRequest.url.path == '/data' + builtRequest.method.toString() == 'POST' + builtRequest.rawStatusCode == 200 + } + + def 'should execute interaction'() { + given: + def request = new Request('GET', '/data', [id: ['1234']]) + def interaction = new RequestResponseInteraction('some description', [], request) + WebFluxTarget webFluxTarget = new WebFluxTarget(routerFunction) + def requestAndClient = webFluxTarget.prepareRequest(interaction, [:]) + def requestBuilder = requestAndClient.first + + when: + def response = webFluxTarget.executeInteraction(requestAndClient.second, requestBuilder) + + then: + response.statusCode == 200 + response.contentType.toString() == 'application/json' + response.body == '{"id":1234}' + } +} diff --git a/provider/src/main/kotlin/au/com/dius/pact/provider/ProviderVerifier.kt b/provider/src/main/kotlin/au/com/dius/pact/provider/ProviderVerifier.kt index 177d376034..4705d5f7f1 100644 --- a/provider/src/main/kotlin/au/com/dius/pact/provider/ProviderVerifier.kt +++ b/provider/src/main/kotlin/au/com/dius/pact/provider/ProviderVerifier.kt @@ -16,13 +16,12 @@ import au.com.dius.pact.core.model.Pact import au.com.dius.pact.core.model.PactReader import au.com.dius.pact.core.model.PactSource import au.com.dius.pact.core.model.ProviderState -import au.com.dius.pact.core.model.RequestResponseInteraction import au.com.dius.pact.core.model.SynchronousRequestResponse import au.com.dius.pact.core.model.UrlPactSource import au.com.dius.pact.core.model.UrlSource import au.com.dius.pact.core.model.generators.GeneratorTestMode -import au.com.dius.pact.core.model.messaging.MessageInteraction import au.com.dius.pact.core.model.messaging.Message +import au.com.dius.pact.core.model.messaging.MessageInteraction import au.com.dius.pact.core.pactbroker.IPactBrokerClient import au.com.dius.pact.core.support.expressions.SystemPropertyResolver import au.com.dius.pact.core.support.hasProperty @@ -237,12 +236,16 @@ interface IProviderVerifier { fun generateErrorStringFromVerificationResult(result: List): String fun reportStateChangeFailed(providerState: ProviderState, error: Exception, isSetup: Boolean) + + fun initialiseReporters(provider: IProviderInfo) + + fun reportVerificationForConsumer(consumer: IConsumerInfo, provider: IProviderInfo, pactSource: PactSource?) } /** * Verifies the providers against the defined consumers in the context of a build plugin */ -@Suppress("TooManyFunctions") +@Suppress("TooManyFunctions", "LongParameterList") open class ProviderVerifier @JvmOverloads constructor ( override var pactLoadFailureMessage: Any? = null, override var checkBuildSpecificTask: Function = Function { false }, @@ -294,26 +297,7 @@ open class ProviderVerifier @JvmOverloads constructor ( ): VerificationResult { val interactionId = interaction.interactionId try { - val classGraph = ClassGraph().enableAllInfo() - if (System.getProperty("pact.verifier.classpathscan.verbose") != null) { - classGraph.verbose() - } - - val classLoader = projectClassLoader?.get() - if (classLoader == null) { - val urls = projectClasspath.get() - logger.debug { "projectClasspath = $urls" } - if (urls.isNotEmpty()) { - classGraph.overrideClassLoaders(URLClassLoader(urls.toTypedArray())) - } - } else { - classGraph.overrideClassLoaders(classLoader) - } - - val scan = ProviderUtils.packagesToScan(providerInfo, consumer) - if (scan.isNotEmpty()) { - classGraph.whitelistPackages(*scan.toTypedArray()) - } + val classGraph = setupClassGraph(providerInfo, consumer) val methodsAnnotatedWith = classGraph.scan().use { scanResult -> scanResult.getClassesWithMethodAnnotation(PactVerifyProvider::class.qualifiedName) @@ -365,6 +349,30 @@ open class ProviderVerifier @JvmOverloads constructor ( } } + private fun setupClassGraph(providerInfo: IProviderInfo, consumer: IConsumerInfo): ClassGraph { + val classGraph = ClassGraph().enableAllInfo() + if (System.getProperty("pact.verifier.classpathscan.verbose") != null) { + classGraph.verbose() + } + + val classLoader = projectClassLoader?.get() + if (classLoader == null) { + val urls = projectClasspath.get() + logger.debug { "projectClasspath = $urls" } + if (urls.isNotEmpty()) { + classGraph.overrideClassLoaders(URLClassLoader(urls.toTypedArray())) + } + } else { + classGraph.overrideClassLoaders(classLoader) + } + + val scan = ProviderUtils.packagesToScan(providerInfo, consumer) + if (scan.isNotEmpty()) { + classGraph.whitelistPackages(*scan.toTypedArray()) + } + return classGraph + } + private fun emitEvent(event: Event) { reporters.forEach { it.receive(event) } } @@ -678,7 +686,7 @@ open class ProviderVerifier @JvmOverloads constructor ( } } - fun initialiseReporters(provider: IProviderInfo) { + override fun initialiseReporters(provider: IProviderInfo) { reporters.forEach { if (it.hasProperty("displayFullDiff")) { (it.property("displayFullDiff") as KMutableProperty1) @@ -732,7 +740,11 @@ open class ProviderVerifier @JvmOverloads constructor ( } } - fun reportVerificationForConsumer(consumer: IConsumerInfo, provider: IProviderInfo, pactSource: PactSource?) { + override fun reportVerificationForConsumer( + consumer: IConsumerInfo, + provider: IProviderInfo, + pactSource: PactSource? + ) { when (pactSource) { is BrokerUrlSource -> reporters.forEach { reporter -> reporter.reportVerificationForConsumer(consumer, provider, pactSource.tag)