Skip to content

Commit

Permalink
feat: added support for WebFlux with Spring JUnit 5 tests #1373
Browse files Browse the repository at this point in the history
  • Loading branch information
Ronald Holshausen committed Jun 16, 2021
1 parent 4abd4ff commit d81b2b8
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 29 deletions.
2 changes: 1 addition & 1 deletion provider/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies {
"org.apache.httpcomponents:httpclient:${project.httpClientVersion}"
implementation "org.slf4j:slf4j-api:${project.slf4jVersion}"
implementation "org.scala-lang:scala-library:${project.scalaVersion}"
implementation 'io.github.classgraph:classgraph:4.8.105'
api 'io.github.classgraph:classgraph:4.8.105'
implementation "org.codehaus.groovy:groovy:${project.groovyVersion}"
api 'com.michael-bull.kotlin-result:kotlin-result:1.1.6'
implementation 'com.github.ajalt:mordant:1.2.1'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import au.com.dius.pact.core.model.DirectorySource
import au.com.dius.pact.core.model.Interaction
import au.com.dius.pact.core.model.PactBrokerSource
import au.com.dius.pact.core.model.PactSource
import au.com.dius.pact.core.model.RequestResponseInteraction
import au.com.dius.pact.core.model.SynchronousRequestResponse
import au.com.dius.pact.core.model.generators.GeneratorTestMode
import au.com.dius.pact.core.model.messaging.Message
Expand Down
1 change: 1 addition & 0 deletions provider/junit5spring/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ dependencies {
api 'org.springframework:spring-context:5.2.3.RELEASE'
api 'org.springframework:spring-test:5.2.3.RELEASE'
api 'org.springframework:spring-web:5.2.3.RELEASE'
api 'org.springframework:spring-webflux:5.2.3.RELEASE'
api 'javax.servlet:javax.servlet-api:3.1.0'
api 'org.hamcrest:hamcrest:2.1'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import au.com.dius.pact.core.model.ContentType
import au.com.dius.pact.core.model.IRequest
import au.com.dius.pact.core.model.Interaction
import au.com.dius.pact.core.model.PactSource
import au.com.dius.pact.core.model.Request
import au.com.dius.pact.core.model.RequestResponseInteraction
import au.com.dius.pact.core.model.SynchronousRequestResponse
import au.com.dius.pact.core.model.generators.GeneratorTestMode
import au.com.dius.pact.provider.IProviderVerifier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import au.com.dius.pact.provider.junit5.PactVerificationContext
import au.com.dius.pact.provider.junit5.PactVerificationExtension
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.api.extension.ParameterContext
import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder

open class PactVerificationSpringExtension(
Expand All @@ -24,6 +25,7 @@ open class PactVerificationSpringExtension(
val testContext = store.get("interactionContext") as PactVerificationContext
return when (parameterContext.parameter.type) {
MockHttpServletRequestBuilder::class.java -> testContext.target is MockMvcTestTarget
WebTestClient.RequestHeadersSpec::class.java -> testContext.target is WebFluxTarget
else -> super.supportsParameter(parameterContext, extensionContext)
}
}
Expand All @@ -32,6 +34,7 @@ open class PactVerificationSpringExtension(
val store = extensionContext.getStore(ExtensionContext.Namespace.create("pact-jvm"))
return when (parameterContext.parameter.type) {
MockHttpServletRequestBuilder::class.java -> store.get("request")
WebTestClient.RequestHeadersSpec::class.java -> store.get("request")
else -> super.resolveParameter(parameterContext, extensionContext)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package au.com.dius.pact.provider.spring.junit5

import au.com.dius.pact.core.model.ContentType
import au.com.dius.pact.core.model.IRequest
import au.com.dius.pact.core.model.Interaction
import au.com.dius.pact.core.model.PactSource
import au.com.dius.pact.core.model.SynchronousRequestResponse
import au.com.dius.pact.core.model.generators.GeneratorTestMode
import au.com.dius.pact.provider.IProviderVerifier
import au.com.dius.pact.provider.ProviderInfo
import au.com.dius.pact.provider.ProviderResponse
import au.com.dius.pact.provider.junit5.TestTarget
import org.apache.commons.lang3.StringUtils
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpMethod
import org.springframework.http.MediaType
import org.springframework.http.client.MultipartBodyBuilder
import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.reactive.function.BodyInserters
import org.springframework.web.reactive.function.server.RouterFunction
import org.springframework.web.util.UriComponentsBuilder
import javax.mail.internet.ContentDisposition
import javax.mail.internet.MimeMultipart
import javax.mail.util.ByteArrayDataSource

class WebFluxTarget(private val routerFunction: RouterFunction<*>) : TestTarget {
override fun getProviderInfo(serviceName: String, pactSource: PactSource?) = ProviderInfo(serviceName)

override fun prepareRequest(interaction: Interaction, context: MutableMap<String, Any>): Pair<WebTestClient.RequestHeadersSpec<*>, WebTestClient> {
if (interaction is SynchronousRequestResponse) {
val request = interaction.request.generatedRequest(context, GeneratorTestMode.Provider)
val webClient = WebTestClient.bindToRouterFunction(routerFunction).build()
return toWebFluxRequestBuilder(webClient, request) to webClient
}
throw UnsupportedOperationException("Only request/response interactions can be used with an MockMvc test target")
}

private fun toWebFluxRequestBuilder(webClient: WebTestClient, request: IRequest): WebTestClient.RequestHeadersSpec<*> {
return if (request.body.isPresent()) {
if (request.isMultipartFileUpload()) {
val multipart = MimeMultipart(ByteArrayDataSource(request.body.unwrap(), request.contentTypeHeader()))

val bodyBuilder = MultipartBodyBuilder()
var i = 0
while (i < multipart.count) {
val bodyPart = multipart.getBodyPart(i)
val contentDisposition = ContentDisposition(bodyPart.getHeader("Content-Disposition").first())
val name = StringUtils.defaultString(contentDisposition.getParameter("name"), "file")
val filename = contentDisposition.getParameter("filename").orEmpty()

bodyBuilder
.part(name, bodyPart.content)
.filename(filename)
.contentType(MediaType.valueOf(bodyPart.contentType))
.header("Content-Disposition", "form-data; name=$name; filename=$filename")

i++
}

webClient
.method(HttpMethod.POST)
.uri(requestUriString(request))
.body(BodyInserters.fromMultipartData(bodyBuilder.build()))
.headers { request.headers.forEach { (k, v) -> it.addAll(k, v) } }
} else {
webClient
.method(HttpMethod.valueOf(request.method))
.uri(requestUriString(request))
.bodyValue(request.body.value!!)
.headers {
request.headers.forEach { (k, v) -> it.addAll(k, v) }
if (!request.headers.containsKey(HttpHeaders.CONTENT_TYPE)) {
it.set(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
}
}
}
} else {
webClient
.method(HttpMethod.valueOf(request.method))
.uri(requestUriString(request))
.headers {
request.headers.forEach { (k, v) -> it.addAll(k, v) }
}
}
}

private fun requestUriString(request: IRequest): String {
val uriBuilder = UriComponentsBuilder.fromPath(request.path)

request.query.forEach { (key, value) ->
uriBuilder.queryParam(key, value)
}

return uriBuilder.toUriString()
}

override fun isHttpTarget() = true

override fun executeInteraction(client: Any?, request: Any?): ProviderResponse {
val requestBuilder = request as WebTestClient.RequestHeadersSpec<*>
val exchangeResult = requestBuilder.exchange().expectBody().returnResult()

val headers = mutableMapOf<String, List<String>>()
exchangeResult.responseHeaders.forEach { header ->
headers[header.key] = header.value
}

val contentTypeHeader = exchangeResult.responseHeaders.contentType
val contentType = if (contentTypeHeader == null) {
ContentType.JSON
} else {
ContentType.fromString(contentTypeHeader.toString())
}

return ProviderResponse(
exchangeResult.status.value(),
headers,
contentType,
exchangeResult.responseBody?.let { String(it) }
)
}

override fun prepareVerifier(verifier: IProviderVerifier, testInstance: Any) {
/* NO-OP */
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package au.com.dius.pact.provider.spring.junit5

import au.com.dius.pact.core.model.OptionalBody
import au.com.dius.pact.core.model.Request
import au.com.dius.pact.core.model.RequestResponseInteraction
import org.springframework.http.MediaType
import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.reactive.function.BodyInserters
import org.springframework.web.reactive.function.server.RequestPredicates
import org.springframework.web.reactive.function.server.RouterFunction
import org.springframework.web.reactive.function.server.RouterFunctions
import org.springframework.web.reactive.function.server.ServerResponse
import spock.lang.Specification

import java.nio.charset.StandardCharsets

@SuppressWarnings('ClosureAsLastMethodParameter')
class WebFluxTargetSpec extends Specification {
RouterFunction routerFunction = RouterFunctions.route(RequestPredicates.GET('/data'), { req ->
ServerResponse.ok().contentType(MediaType.APPLICATION_JSON)
.body(BodyInserters.fromValue('{"id":1234}'))
})

def 'should prepare get request'() {
given:
WebFluxTarget webFluxTarget = new WebFluxTarget(routerFunction)
def request = new Request('GET', '/data', [id: ['1234']])
def interaction = new RequestResponseInteraction('some description', [], request)

when:
def requestAndClient = webFluxTarget.prepareRequest(interaction, [:])
def requestBuilder = requestAndClient.first
def builtRequest = requestBuilder.exchange().expectBody().returnResult()

then:
requestBuilder instanceof WebTestClient.RequestHeadersSpec
builtRequest.url.path == '/data'
builtRequest.method.toString() == 'GET'
new String(builtRequest.responseBody) == '{"id":1234}'
}

def 'should prepare post request'() {
given:
RouterFunction postRouterFunction = RouterFunctions.route(RequestPredicates.POST('/data'), { req ->
assert req.queryParams() == [id: ['1234']]
def reqBody = req.bodyToMono(String).doOnNext({ s -> assert s == '{"foo":"bar"}' })
ServerResponse.ok().build(reqBody)
})
WebFluxTarget webFluxTarget = new WebFluxTarget(postRouterFunction)
def request = new Request('POST', '/data', [id: ['1234']], [:],
OptionalBody.body('{"foo":"bar"}'.getBytes(StandardCharsets.UTF_8)))
def interaction = new RequestResponseInteraction('some description', [], request)

when:
def requestAndClient = webFluxTarget.prepareRequest(interaction, [:])
def requestBuilder = requestAndClient.first

then:
requestBuilder instanceof WebTestClient.RequestHeadersSpec
def builtRequest = requestBuilder.exchange().expectBody().returnResult()
builtRequest.url.path == '/data'
builtRequest.method.toString() == 'POST'
builtRequest.rawStatusCode == 200
}

def 'should execute interaction'() {
given:
def request = new Request('GET', '/data', [id: ['1234']])
def interaction = new RequestResponseInteraction('some description', [], request)
WebFluxTarget webFluxTarget = new WebFluxTarget(routerFunction)
def requestAndClient = webFluxTarget.prepareRequest(interaction, [:])
def requestBuilder = requestAndClient.first

when:
def response = webFluxTarget.executeInteraction(requestAndClient.second, requestBuilder)

then:
response.statusCode == 200
response.contentType.toString() == 'application/json'
response.body == '{"id":1234}'
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ import au.com.dius.pact.core.model.Pact
import au.com.dius.pact.core.model.PactReader
import au.com.dius.pact.core.model.PactSource
import au.com.dius.pact.core.model.ProviderState
import au.com.dius.pact.core.model.RequestResponseInteraction
import au.com.dius.pact.core.model.SynchronousRequestResponse
import au.com.dius.pact.core.model.UrlPactSource
import au.com.dius.pact.core.model.UrlSource
import au.com.dius.pact.core.model.generators.GeneratorTestMode
import au.com.dius.pact.core.model.messaging.MessageInteraction
import au.com.dius.pact.core.model.messaging.Message
import au.com.dius.pact.core.model.messaging.MessageInteraction
import au.com.dius.pact.core.pactbroker.IPactBrokerClient
import au.com.dius.pact.core.support.expressions.SystemPropertyResolver
import au.com.dius.pact.core.support.hasProperty
Expand Down Expand Up @@ -237,12 +236,16 @@ interface IProviderVerifier {
fun generateErrorStringFromVerificationResult(result: List<VerificationResult.Failed>): String

fun reportStateChangeFailed(providerState: ProviderState, error: Exception, isSetup: Boolean)

fun initialiseReporters(provider: IProviderInfo)

fun reportVerificationForConsumer(consumer: IConsumerInfo, provider: IProviderInfo, pactSource: PactSource?)
}

/**
* Verifies the providers against the defined consumers in the context of a build plugin
*/
@Suppress("TooManyFunctions")
@Suppress("TooManyFunctions", "LongParameterList")
open class ProviderVerifier @JvmOverloads constructor (
override var pactLoadFailureMessage: Any? = null,
override var checkBuildSpecificTask: Function<Any, Boolean> = Function { false },
Expand Down Expand Up @@ -294,26 +297,7 @@ open class ProviderVerifier @JvmOverloads constructor (
): VerificationResult {
val interactionId = interaction.interactionId
try {
val classGraph = ClassGraph().enableAllInfo()
if (System.getProperty("pact.verifier.classpathscan.verbose") != null) {
classGraph.verbose()
}

val classLoader = projectClassLoader?.get()
if (classLoader == null) {
val urls = projectClasspath.get()
logger.debug { "projectClasspath = $urls" }
if (urls.isNotEmpty()) {
classGraph.overrideClassLoaders(URLClassLoader(urls.toTypedArray()))
}
} else {
classGraph.overrideClassLoaders(classLoader)
}

val scan = ProviderUtils.packagesToScan(providerInfo, consumer)
if (scan.isNotEmpty()) {
classGraph.whitelistPackages(*scan.toTypedArray())
}
val classGraph = setupClassGraph(providerInfo, consumer)

val methodsAnnotatedWith = classGraph.scan().use { scanResult ->
scanResult.getClassesWithMethodAnnotation(PactVerifyProvider::class.qualifiedName)
Expand Down Expand Up @@ -365,6 +349,30 @@ open class ProviderVerifier @JvmOverloads constructor (
}
}

private fun setupClassGraph(providerInfo: IProviderInfo, consumer: IConsumerInfo): ClassGraph {
val classGraph = ClassGraph().enableAllInfo()
if (System.getProperty("pact.verifier.classpathscan.verbose") != null) {
classGraph.verbose()
}

val classLoader = projectClassLoader?.get()
if (classLoader == null) {
val urls = projectClasspath.get()
logger.debug { "projectClasspath = $urls" }
if (urls.isNotEmpty()) {
classGraph.overrideClassLoaders(URLClassLoader(urls.toTypedArray()))
}
} else {
classGraph.overrideClassLoaders(classLoader)
}

val scan = ProviderUtils.packagesToScan(providerInfo, consumer)
if (scan.isNotEmpty()) {
classGraph.whitelistPackages(*scan.toTypedArray())
}
return classGraph
}

private fun emitEvent(event: Event) {
reporters.forEach { it.receive(event) }
}
Expand Down Expand Up @@ -678,7 +686,7 @@ open class ProviderVerifier @JvmOverloads constructor (
}
}

fun initialiseReporters(provider: IProviderInfo) {
override fun initialiseReporters(provider: IProviderInfo) {
reporters.forEach {
if (it.hasProperty("displayFullDiff")) {
(it.property("displayFullDiff") as KMutableProperty1<VerifierReporter, Boolean>)
Expand Down Expand Up @@ -732,7 +740,11 @@ open class ProviderVerifier @JvmOverloads constructor (
}
}

fun reportVerificationForConsumer(consumer: IConsumerInfo, provider: IProviderInfo, pactSource: PactSource?) {
override fun reportVerificationForConsumer(
consumer: IConsumerInfo,
provider: IProviderInfo,
pactSource: PactSource?
) {
when (pactSource) {
is BrokerUrlSource -> reporters.forEach { reporter ->
reporter.reportVerificationForConsumer(consumer, provider, pactSource.tag)
Expand Down

0 comments on commit d81b2b8

Please sign in to comment.