Skip to content

Commit

Permalink
feat: Support loading PactSource from annotations on the test class (…
Browse files Browse the repository at this point in the history
…JUnit 4) #1237
  • Loading branch information
Ronald Holshausen committed Nov 3, 2020
1 parent f05d904 commit 94275ae
Show file tree
Hide file tree
Showing 16 changed files with 242 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import au.com.dius.pact.core.model.Interaction
import au.com.dius.pact.core.model.Pact
import au.com.dius.pact.core.support.expressions.SystemPropertyResolver
import au.com.dius.pact.core.support.json.JsonException
import au.com.dius.pact.provider.ProviderUtils
import au.com.dius.pact.provider.ProviderUtils.findAnnotation
import au.com.dius.pact.provider.junit.target.HttpTarget
import au.com.dius.pact.provider.junitsupport.AllowOverridePactUrl
import au.com.dius.pact.provider.junitsupport.Consumer
Expand All @@ -26,7 +28,7 @@ import org.junit.runners.model.InitializationError
import org.junit.runners.model.TestClass
import java.io.IOException
import kotlin.reflect.full.createInstance
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.memberProperties

/**
* JUnit Runner runs pacts against provider
Expand Down Expand Up @@ -67,23 +69,30 @@ open class PactRunner<I>(private val clazz: Class<*>) : ParentRunner<Interaction
if (clazz.getAnnotation(Ignore::class.java) != null) {
logger.info("Ignore annotation detected, exiting")
} else {
val providerInfo = clazz.getAnnotation(Provider::class.java) ?: throw InitializationError(
val providerInfo = findAnnotation(clazz, Provider::class.java) ?: throw InitializationError(
"Provider name should be specified by using ${Provider::class.java.simpleName} annotation")
logger.debug { "Found annotation $providerInfo" }
val serviceName = providerInfo.value

val consumerInfo = clazz.getAnnotation(Consumer::class.java)
val consumerInfo = findAnnotation(clazz, Consumer::class.java)
if (consumerInfo != null) {
logger.debug { "Found annotation $consumerInfo" }
}
val consumerName = consumerInfo?.value

val testClass = TestClass(clazz)
val ignoreNoPactsToVerify = clazz.getAnnotation(IgnoreNoPactsToVerify::class.java)
val ignoreNoPactsToVerify = findAnnotation(clazz, IgnoreNoPactsToVerify::class.java)
if (ignoreNoPactsToVerify != null) {
logger.debug { "Found annotation $ignoreNoPactsToVerify" }
}
val ignoreIoErrors = try {
valueResolver.resolveValue(ignoreNoPactsToVerify?.ignoreIoErrors)
} catch (e: RuntimeException) {
logger.debug(e) { "Failed to resolve property value" }
ignoreNoPactsToVerify?.ignoreIoErrors
} ?: "false"

val pactLoader = getPactSource(testClass)
val pactLoader = getPactSource(testClass, consumerInfo)
val pacts = try {
filterPacts(pactLoader.load(serviceName)
.filter { p -> consumerName == null || p.consumer.name == consumerName } as List<Pact<I>>)
Expand Down Expand Up @@ -150,40 +159,57 @@ open class PactRunner<I>(private val clazz: Class<*>) : ParentRunner<Interaction
interaction.run(notifier)
}

protected open fun getPactSource(clazz: TestClass): PactLoader {
val pactSource = clazz.getAnnotation(PactSource::class.java)
val pactLoaders = clazz.annotations
.filter { annotation -> annotation.annotationClass.findAnnotation<PactSource>() != null }
if ((if (pactSource == null) 0 else 1) + pactLoaders.size != 1) {
throw InitializationError("Exactly one pact source should be set")
protected open fun getPactSource(clazz: TestClass, consumerInfo: Consumer?): PactLoader {
val pactSources = ProviderUtils.findAllPactSources(clazz.javaClass.kotlin)
if (pactSources.size > 1) {
throw InitializationError(
"Exactly one pact source should be set, found ${pactSources.size}: " +
pactSources.map { it.first }.joinToString(", "))
} else if (pactSources.isEmpty()) {
throw InitializationError("Did not find any PactSource annotations. Exactly one pact source should be set")
}

try {
val loader = if (pactSource != null) {
val pactLoaderClass = pactSource.value
try {
// Checks if there is a constructor with one argument of type Class.
val constructorWithClass = pactLoaderClass.java.getDeclaredConstructor(Class::class.java)
if (constructorWithClass != null) {
constructorWithClass.isAccessible = true
constructorWithClass.newInstance(clazz.javaClass)
} else {
pactLoaderClass.createInstance()
val (pactSource, annotation) = pactSources.first()
return try {
val pactLoaderClass = pactSource.value
val loader = try {
// Checks if there is a constructor with one argument of type Class.
val constructorWithClass = pactLoaderClass.java.getDeclaredConstructor(Class::class.java)
constructorWithClass.isAccessible = true
constructorWithClass.newInstance(clazz.javaClass)
} catch (e: NoSuchMethodException) {
logger.debug { "Pact source does not have a constructor with one argument of type Class" }
if (annotation != null) {
try {
// Check for a constructor with one argument with the type from the annotation with the PactSource
val constructor = pactLoaderClass.java.getDeclaredConstructor(annotation.annotationClass.java)
constructor.isAccessible = true
constructor.newInstance(annotation)
} catch (e: NoSuchMethodException) {
logger.debug {
"Pact loader does not have a constructor with one argument of type $pactSource"
}
try {
// Check for a constructor with one argument with the type from the PactSource annotation value
val annotationValueProp = annotation.annotationClass.memberProperties.find { it.name == "value" }
val annotationValue = annotationValueProp!!.getter.call(annotation)!!
pactLoaderClass.java.getDeclaredConstructor(annotationValue.javaClass).newInstance(annotationValue)
} catch (e: NoSuchMethodException) {
logger.debug {
"Pact loader does not have a constructor with one argument of type ${pactSource.value}"
}
pactLoaderClass.createInstance()
}
}
} catch (e: NoSuchMethodException) {
logger.error(e) { e.message }
} else {
pactLoaderClass.createInstance()
}
} else {
val annotation = pactLoaders.first()
annotation.annotationClass.findAnnotation<PactSource>()!!.value.java
.getConstructor(annotation.annotationClass.java).newInstance(annotation)
}

checkForOverriddenPactUrl(loader, clazz.getAnnotation(AllowOverridePactUrl::class.java),
clazz.getAnnotation(Consumer::class.java))
checkForOverriddenPactUrl(loader, findAnnotation(clazz.javaClass, AllowOverridePactUrl::class.java),
consumerInfo)

return loader
loader
} catch (e: ReflectiveOperationException) {
logger.error(e) { "Error while creating pact source" }
throw InitializationError(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import au.com.dius.pact.provider.IProviderInfo
import au.com.dius.pact.provider.IProviderVerifier
import au.com.dius.pact.provider.PactVerification
import au.com.dius.pact.provider.ProviderInfo
import au.com.dius.pact.provider.ProviderUtils.findAnnotation
import au.com.dius.pact.provider.ProviderVerifier
import au.com.dius.pact.provider.VerificationResult
import au.com.dius.pact.provider.junitsupport.Provider
Expand Down Expand Up @@ -91,7 +92,7 @@ open class AmqpTarget @JvmOverloads constructor(
}

override fun getProviderInfo(source: PactSource): ProviderInfo {
val provider = testClass.getAnnotation(Provider::class.java)
val provider = findAnnotation(testClass.javaClass, Provider::class.java)!!
val providerInfo = ProviderInfo(provider.value)
providerInfo.verificationType = PactVerification.ANNOTATED_METHOD
providerInfo.packagesToScan = packagesToScan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.io.File
import java.util.function.BiConsumer
import java.util.function.Supplier
import au.com.dius.pact.core.support.BuiltToolConfig.detectedBuildToolPactDirectory
import au.com.dius.pact.provider.ProviderUtils
import org.apache.commons.io.FilenameUtils

/**
Expand Down Expand Up @@ -54,8 +55,8 @@ abstract class BaseTarget : TestClassAwareTarget {
var reportDirectory = FilenameUtils.concat(detectedBuildToolPactDirectory(), "reports")
var reportingEnabled = false

val verificationReports = testClass.getAnnotation(VerificationReports::class.java)
val reports: List<String> = when {
val verificationReports = ProviderUtils.findAnnotation(testClass.javaClass, VerificationReports::class.java)
val reports: List<String> = when {
verificationReports != null -> {
reportingEnabled = true
if (verificationReports.reportDir.isNotEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import au.com.dius.pact.provider.IProviderInfo
import au.com.dius.pact.provider.IProviderVerifier
import au.com.dius.pact.provider.ProviderClient
import au.com.dius.pact.provider.ProviderInfo
import au.com.dius.pact.provider.ProviderUtils
import au.com.dius.pact.provider.ProviderVerifier
import au.com.dius.pact.provider.VerificationResult
import au.com.dius.pact.provider.junitsupport.Provider
Expand Down Expand Up @@ -104,7 +105,7 @@ open class HttpTarget
}

override fun getProviderInfo(source: PactSource): ProviderInfo {
val provider = testClass.getAnnotation(Provider::class.java)
val provider = ProviderUtils.findAnnotation(testClass.javaClass, Provider::class.java)!!
val providerInfo = ProviderInfo(provider.value)
providerInfo.port = port
providerInfo.host = host
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import au.com.dius.pact.provider.IProviderInfo
import au.com.dius.pact.provider.IProviderVerifier
import au.com.dius.pact.provider.PactVerification
import au.com.dius.pact.provider.ProviderInfo
import au.com.dius.pact.provider.ProviderUtils
import au.com.dius.pact.provider.ProviderVerifier
import au.com.dius.pact.provider.VerificationResult
import au.com.dius.pact.provider.junitsupport.Provider
Expand Down Expand Up @@ -90,7 +91,7 @@ open class MessageTarget @JvmOverloads constructor(
}

override fun getProviderInfo(source: PactSource): ProviderInfo {
val provider = testClass.getAnnotation(Provider::class.java)
val provider = ProviderUtils.findAnnotation(testClass.javaClass, Provider::class.java)!!
val providerInfo = ProviderInfo(provider.value)
providerInfo.verificationType = PactVerification.ANNOTATED_METHOD
providerInfo.packagesToScan = packagesToScan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class PactRunnerSpec extends Specification {

then:
InitializationError e = thrown()
e.causes*.message == ['Exactly one pact source should be set']
e.causes*.message == ['Did not find any PactSource annotations. Exactly one pact source should be set']
}

def 'PactRunner throws an exception if the pact source throws an IO exception'() {
Expand Down Expand Up @@ -189,7 +189,7 @@ class PactRunnerSpec extends Specification {

then:
InitializationError e = thrown()
e.causes*.message == ['Exactly one pact source should be set']
e.causes[0].message.startsWith('Exactly one pact source should be set, found 2: ')
}

def 'PactRunner handles a pact source with a pact loader that takes a class parameter'() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package au.com.dius.pact.provider.junit;

import au.com.dius.pact.provider.junit.target.HttpTarget;
import au.com.dius.pact.provider.junitsupport.Provider;
import au.com.dius.pact.provider.junitsupport.State;
import au.com.dius.pact.provider.junitsupport.TargetRequestFilter;
import au.com.dius.pact.provider.junitsupport.VerificationReports;
import au.com.dius.pact.provider.junitsupport.loader.PactFolder;
import au.com.dius.pact.provider.junitsupport.target.Target;
import au.com.dius.pact.provider.junitsupport.target.TestTarget;
import com.github.restdriver.clientdriver.ClientDriverRule;
import org.apache.http.HttpRequest;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.runner.RunWith;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;

import static com.github.restdriver.clientdriver.RestClientDriver.giveEmptyResponse;
import static com.github.restdriver.clientdriver.RestClientDriver.onRequestTo;

@RunWith(PactRunner.class)
@IsContractTest
public class CustomAnnotationContractTest {
@ClassRule
public static final ClientDriverRule embeddedService = new ClientDriverRule(8339);
private static final Logger LOGGER = LoggerFactory.getLogger(CustomAnnotationContractTest.class);

@TestTarget
public final Target target = new HttpTarget(8339);

@Before
public void before() {
embeddedService.addExpectation(
onRequestTo("/data").withAnyParams(), giveEmptyResponse()
);
}

@State("default")
public void toDefaultState() {
}

@State("state 2")
public void toSecondState(Map params) {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package au.com.dius.pact.provider.junit;

import au.com.dius.pact.provider.junitsupport.Provider;
import au.com.dius.pact.provider.junitsupport.loader.PactFolder;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(value = ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Provider("myAwesomeService")
@PactFolder("pacts")
public @interface IsContractTest {
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import au.com.dius.pact.core.model.PactSource
import au.com.dius.pact.core.model.messaging.Message
import au.com.dius.pact.provider.junit.InteractionRunner
import au.com.dius.pact.provider.junit.MessagePactRunner
import au.com.dius.pact.provider.junitsupport.Consumer
import au.com.dius.pact.provider.junitsupport.loader.PactLoader
import org.junit.runners.model.Statement
import org.junit.runners.model.TestClass
Expand Down Expand Up @@ -45,10 +46,10 @@ open class SpringMessagePactRunner(clazz: Class<*>) : MessagePactRunner<Message>
return SpringInteractionRunner(testClass, pact, pactSource, initTestContextManager(testClass.javaClass))
}

override fun getPactSource(clazz: TestClass): PactLoader {
override fun getPactSource(clazz: TestClass, consumerInfo: Consumer?): PactLoader {
initTestContextManager(clazz.javaClass)
val environment = testContextManager!!.testContext.applicationContext.environment
val pactSource = super.getPactSource(clazz)
val pactSource = super.getPactSource(clazz, consumerInfo)
pactSource.setValueResolver(SpringEnvironmentResolver(environment))
return pactSource
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import au.com.dius.pact.core.model.PactSource
import au.com.dius.pact.core.model.RequestResponseInteraction
import au.com.dius.pact.provider.junit.InteractionRunner
import au.com.dius.pact.provider.junit.RestPactRunner
import au.com.dius.pact.provider.junitsupport.Consumer
import au.com.dius.pact.provider.junitsupport.loader.PactLoader
import org.junit.runners.model.Statement
import org.junit.runners.model.TestClass
Expand All @@ -20,6 +21,7 @@ import org.springframework.test.context.web.ServletTestExecutionListener
/**
* Pact runner for REST providers that boots up the spring context
*/
@ExperimentalStdlibApi
open class SpringRestPactRunner(clazz: Class<*>) : RestPactRunner<RequestResponseInteraction>(clazz) {

private var testContextManager: TestContextManager? = null
Expand Down Expand Up @@ -52,10 +54,10 @@ open class SpringRestPactRunner(clazz: Class<*>) : RestPactRunner<RequestRespons
return SpringInteractionRunner(testClass, pact, pactSource, initTestContextManager(testClass.javaClass))
}

override fun getPactSource(clazz: TestClass): PactLoader {
override fun getPactSource(clazz: TestClass, consumerInfo: Consumer?): PactLoader {
initTestContextManager(clazz.javaClass)
val environment = testContextManager!!.testContext.applicationContext.environment
val pactSource = super.getPactSource(clazz)
val pactSource = super.getPactSource(clazz, consumerInfo)
pactSource.setValueResolver(SpringEnvironmentResolver(environment))
return pactSource
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import au.com.dius.pact.provider.IProviderVerifier
import au.com.dius.pact.provider.ProviderInfo
import au.com.dius.pact.provider.ProviderVerifier
import au.com.dius.pact.provider.PactVerification
import au.com.dius.pact.provider.ProviderUtils
import au.com.dius.pact.provider.VerificationResult
import au.com.dius.pact.provider.junit.target.BaseTarget
import au.com.dius.pact.provider.junitsupport.Provider
Expand All @@ -21,7 +22,7 @@ abstract class MockTestingTarget(
) : BaseTarget() {

override fun getProviderInfo(source: PactSource): ProviderInfo {
val provider = testClass.getAnnotation(Provider::class.java)
val provider = ProviderUtils.findAnnotation(testClass.javaClass, Provider::class.java)!!
val providerInfo = ProviderInfo(provider.value)

val methods = testClass.getAnnotatedMethods(TargetRequestFilter::class.java)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package au.com.dius.pact.provider.junitsupport.loader;

import au.com.dius.pact.provider.junitsupport.loader.PactFolderLoader;

import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
Expand Down
Loading

0 comments on commit 94275ae

Please sign in to comment.