Skip to content

Commit

Permalink
feat: Support setting the gRPC status for the response
Browse files Browse the repository at this point in the history
  • Loading branch information
rholshausen committed Jun 1, 2023
1 parent b25b4aa commit 564cefa
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 41 deletions.
13 changes: 6 additions & 7 deletions src/mock_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ use tonic::body::{BoxBody, empty_body};
use tonic::metadata::MetadataMap;
use tower::make::Shared;
use tower::ServiceBuilder;
use tower_http::compression::CompressionLayer;
use tower_http::ServiceBuilderExt;
use tower_service::Service;
use tracing::{debug, error, Instrument, instrument, trace, trace_span};
Expand Down Expand Up @@ -166,8 +165,6 @@ impl GrpcMockServer
let service = ServiceBuilder::new()
// High level logging of requests and responses
.trace_for_grpc()
// Compress responses
.layer(CompressionLayer::new())
// Wrap a `Service` in our middleware stack
.service(self);

Expand Down Expand Up @@ -209,8 +206,8 @@ impl GrpcMockServer
}

impl Service<Request<hyper::Body>> for GrpcMockServer {
type Response = hyper::Response<BoxBody>;
type Error = anyhow::Error;
type Response = Response<BoxBody>;
type Error = hyper::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Expand Down Expand Up @@ -243,7 +240,7 @@ impl Service<Request<hyper::Body>> for GrpcMockServer {
let method = req.method();
if method == Method::POST {
let request_path = req.uri().path();
debug!("gRPC request received {}", request_path);
debug!(?request_path, "gRPC request received");
if let Some((service, method)) = request_path[1..].split_once('/') {
let service_name = last_name(service);
let lookup = format!("{service_name}/{method}");
Expand All @@ -263,7 +260,9 @@ impl Service<Request<hyper::Body>> for GrpcMockServer {
pact
);
let mut grpc = tonic::server::Grpc::new(codec);
Ok(grpc.unary(mock_service, req).await)
let response = grpc.unary(mock_service, req).await;
trace!(?response, ">> sending response");
Ok(response)
} else {
error!("Did not find the descriptor for the output message {}", output_message_name);
Ok(failed_precondition())
Expand Down
155 changes: 129 additions & 26 deletions src/mock_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ use pact_models::v4::message_parts::MessageContents;
use pact_models::v4::sync_message::SynchronousMessage;
use pact_plugin_driver::plugin_models::PluginInteractionConfig;
use prost_types::{DescriptorProto, FileDescriptorSet, MethodDescriptorProto};
use tonic::{Request, Response, Status};
use tonic::{Code, Request, Response, Status};
use tonic::metadata::{Entry, MetadataMap};
use tower_service::Service;
use tracing::{debug, error, instrument, trace, warn};
use tracing::{debug, error, info, instrument, trace, warn};

use crate::dynamic_message::DynamicMessage;
use crate::matching::compare;
Expand All @@ -45,7 +45,7 @@ impl MockService {
request: DynamicMessage,
message_descriptor: DescriptorProto,
response_descriptor: DescriptorProto,
request_metadata: &MetadataMap
request_metadata: MetadataMap
) -> Result<Response<DynamicMessage>, Status> {
// 1. Compare the incoming message to the request message from the interaction
let mut expected_message_bytes = self.message.request.contents.value().unwrap_or_default();
Expand All @@ -69,7 +69,7 @@ impl MockService {
let md_context = CoreMatchingContext::new(DiffConfig::NoUnexpectedKeys,
&self.message.request.matching_rules.rules_for_category("metadata").unwrap_or_default(),
&plugin_config);
let md_mismatches = compare_metadata(&self.message.request.metadata, request_metadata,
let md_mismatches = compare_metadata(&self.message.request.metadata, &request_metadata,
&md_context);

trace!("Comparison result = {:?}", mismatches);
Expand All @@ -90,27 +90,34 @@ impl MockService {
}

if result.all_matched() && md_result.all_matched() {
debug!("Request matched OK, returning expected response");
debug!("Request matched OK");
let response_contents = self.message.response.first().cloned().unwrap_or_default();
let mut response_bytes = response_contents.contents.value()
.unwrap_or_default();
trace!("Response message has {} bytes", response_bytes.len());
let response_message = decode_message(&mut response_bytes, &response_descriptor, &self.file_descriptor_set)
.map_err(|err| {
error!("Failed to encode response message - {}", err);
// check for a gRPC status on the response metadata
if let Some(status) = grpc_status(&response_contents) {
info!("a gRPC status {} is set for the response, returning that", status);
Err(status)
} else {
debug!("Returning response");
let mut response_bytes = response_contents.contents.value()
.unwrap_or_default();
trace!("Response message has {} bytes", response_bytes.len());
let response_message = decode_message(&mut response_bytes, &response_descriptor, &self.file_descriptor_set)
.map_err(|err| {
error!("Failed to encode response message - {}", err);
Status::invalid_argument(err.to_string())
})?;
let mut message = DynamicMessage::new(&response_message, &self.file_descriptor_set);
self.apply_generators(&mut message, &response_contents).map_err(|err| {
error!("Failed to generate response message - {}", err);
Status::invalid_argument(err.to_string())
})?;
let mut message = DynamicMessage::new(&response_message, &self.file_descriptor_set);
self.apply_generators(&mut message, &response_contents).map_err(|err| {
error!("Failed to generate response message - {}", err);
Status::invalid_argument(err.to_string())
})?;
trace!("Sending message {message:?}");
let mut response = Response::new(message);
if !response_contents.metadata.is_empty() {
Self::set_response_metadata(response_contents, &mut response);
trace!("Sending message {message:?}");
let mut response = Response::new(message);
if !response_contents.metadata.is_empty() {
Self::set_response_metadata(response_contents, &mut response);
}
Ok(response)
}
Ok(response)
} else {
error!("Failed to match the request message - {result:?}");
Err(Status::failed_precondition(format!("Failed to match the request message - {result:?}")))
Expand All @@ -132,7 +139,8 @@ impl MockService {
for (key, value) in &response_contents.metadata {
let key = key.to_lowercase();
// exclude the content type, because that is a special value added by the Pact framework
if key != "content-type" && key != "contenttype" {
// also exclude the gRPC status, because that is handled separately
if key != "content-type" && key != "contenttype" && key != "grpc-status" {
match json_to_string(value).parse() {
Ok(parsed_val) => {
match md.entry(key.as_str()) {
Expand All @@ -159,6 +167,45 @@ impl MockService {
}
}

fn grpc_status(response_contents: &MessageContents) -> Option<Status> {
if let Some(value) = response_contents.metadata.get("grpc-status") {
let status = json_to_string(value);
let message = response_contents.metadata.get("grpc-message")
.map(json_to_string)
.unwrap_or("No message set".to_string());
match status.as_str() {
// Taken from https://grpc.github.io/grpc/core/md_doc_statuscodes.html
"OK" => None,
"CANCELLED" => Some(Status::cancelled(message)),
"UNKNOWN" => Some(Status::unknown(message)),
"INVALID_ARGUMENT" => Some(Status::invalid_argument(message)),
"DEADLINE_EXCEEDED" => Some(Status::deadline_exceeded(message)),
"NOT_FOUND" => Some(Status::not_found(message)),
"ALREADY_EXISTS" => Some(Status::already_exists(message)),
"PERMISSION_DENIED" => Some(Status::permission_denied(message)),
"RESOURCE_EXHAUSTED" => Some(Status::resource_exhausted(message)),
"FAILED_PRECONDITION" => Some(Status::failed_precondition(message)),
"ABORTED" => Some(Status::aborted(message)),
"OUT_OF_RANGE" => Some(Status::out_of_range(message)),
"UNIMPLEMENTED" => Some(Status::unimplemented(message)),
"INTERNAL" => Some(Status::internal(message)),
"UNAVAILABLE" => Some(Status::unavailable(message)),
"DATA_LOSS" => Some(Status::data_loss(message)),
"UNAUTHENTICATED" => Some(Status::unauthenticated(message)),
_ => {
let code = Code::from_bytes(status.as_bytes());
if code == Code::Ok {
None
} else {
Some(Status::new(code, message))
}
}
}
} else {
None
}
}

impl MockService {
pub(crate) fn new(
file_descriptor_set: &FileDescriptorSet,
Expand Down Expand Up @@ -217,7 +264,7 @@ impl Service<Request<DynamicMessage>> for MockService {
let response_descriptor = self.output_message.clone();
let service = self.clone();
Box::pin(async move {
service.handle_message(request, message_descriptor, response_descriptor, &request_metadata).await
service.handle_message(request, message_descriptor, response_descriptor, request_metadata).await
})
}
}
Expand All @@ -228,15 +275,18 @@ mod tests {
use base64::engine::general_purpose::STANDARD as BASE64;
use bytes::{Bytes, BytesMut};
use expectest::prelude::*;
use maplit::hashmap;
use pact_models::v4::message_parts::MessageContents;
use pact_models::v4::pact::V4Pact;
use prost::Message;
use prost_types::FileDescriptorSet;
use serde_json::json;
use tonic::Code;
use tonic::metadata::MetadataMap;

use crate::dynamic_message::DynamicMessage;
use crate::message_decoder::decode_message;

use crate::mock_service::MockService;
use crate::mock_service::{grpc_status, MockService};
use crate::protobuf::tests::DESCRIPTOR_BYTES;

#[test_log::test(tokio::test)]
Expand Down Expand Up @@ -360,11 +410,64 @@ mod tests {
};
let response = mock_service.handle_message(request,
input_message.clone(), output_message.clone(),
&MetadataMap::default()
MetadataMap::default()
).await.unwrap();
let response_message = response.into_inner();
let response_fields = response_message.proto_fields();
let area = &response_fields[0];
expect!(area.data.to_string()).to_not(be_equal_to("12"));
}

#[test]
fn grpc_status_test_no_status_set() {
let message = MessageContents {
contents: Default::default(),
metadata: hashmap!{},
matching_rules: Default::default(),
generators: Default::default(),
};
expect!(grpc_status(&message)).to(be_none());
}

fn setup_message(status: &str, message: Option<&str>) -> MessageContents {
if let Some(message) = message {
MessageContents {
metadata: hashmap!{
"grpc-status".to_string() => json!(status),
"grpc-message".to_string() => json!(message)
},
.. MessageContents::default()
}
} else {
MessageContents {
metadata: hashmap!{ "grpc-status".to_string() => json!(status) },
.. MessageContents::default()
}
}
}

#[test]
fn grpc_status_test_status_set_by_value() {
let message = setup_message("OK", None);
expect!(grpc_status(&message)).to(be_none());

let message = setup_message("CANCELLED", None);
expect!(grpc_status(&message).unwrap().code()).to(be_equal_to(Code::Cancelled));
let message = setup_message("UNKNOWN", Some("it went bang, Mate!"));
let status = grpc_status(&message).unwrap();
expect!(status.code()).to(be_equal_to(Code::Unknown));
expect!(status.message()).to(be_equal_to("it went bang, Mate!"));

let message = setup_message("10", None);
expect!(grpc_status(&message).unwrap().code()).to(be_equal_to(Code::Aborted));
}

#[test]
fn grpc_status_test_inavlid_status() {
let message = setup_message("GGGH", None);
expect!(grpc_status(&message).unwrap().code()).to(be_equal_to(Code::Unknown));

let message = setup_message("33", None);
expect!(grpc_status(&message).unwrap().code()).to(be_equal_to(Code::Unknown));
}
}
37 changes: 30 additions & 7 deletions src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,12 @@ fn response_part<'a>(
config: &'a BTreeMap<String, prost_types::Value>,
service_part: &str
) -> anyhow::Result<Vec<(BTreeMap<String, prost_types::Value>, Option<&'a prost_types::Value>)>> {
trace!(?config, ?service_part, "response_part");
if service_part == "response" {
Ok(vec![(config.clone(), None)])
} else {
Ok(config.get("response").and_then(|response_config| {
response_config.kind.as_ref().map(|kind| {
} else if let Some(response_config) = config.get("response") {
Ok(response_config.kind.as_ref()
.map(|kind| {
match kind {
Kind::StructValue(s) => {
let metadata = config.get("responseMetadata");
Expand All @@ -240,9 +241,11 @@ fn response_part<'a>(
Kind::StringValue(_) => vec![(btreemap! { "value".to_string() => response_config.clone() }, None)],
_ => vec![]
}
})
})
.unwrap_or_default())
}).unwrap_or_default())
} else if let Some(response_md_config) = config.get("responseMetadata") {
Ok(vec![(btreemap!{}, Some(response_md_config))])
} else {
Ok(vec![])
}
}

Expand Down Expand Up @@ -2142,7 +2145,7 @@ pub(crate) mod tests {
}

#[test]
fn configuring_response_part_returns_empty_map_if_there_is_no_response_element() {
fn configuring_response_part_returns_empty_map_if_there_is_no_response_elements() {
let config = btreemap!{};
let result = response_part(&config, "").unwrap();
expect!(result).to(be_equal_to(vec![]));
Expand Down Expand Up @@ -2241,4 +2244,24 @@ pub(crate) mod tests {
};
expect!(result).to(be_equal_to(vec![(response_config, Some(&expected_metadata))]));
}

#[test]
fn configuring_response_part_deals_with_the_case_where_there_is_only_metadata() {
let response_metadata_config = btreemap!{
"C".to_string() => prost_types::Value { kind: Some(StringValue("D".to_string())) }
};
let config = btreemap!{
"responseMetadata".to_string() => prost_types::Value { kind: Some(StructValue(Struct {
fields: response_metadata_config.clone()
}))
}
};
let result = response_part(&config, "").unwrap();
let expected_metadata = prost_types::Value {
kind: Some(StructValue(Struct {
fields: response_metadata_config.clone()
}))
};
expect!(result).to(be_equal_to(vec![(btreemap!{}, Some(&expected_metadata))]));
}
}
2 changes: 1 addition & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ impl PactPlugin for ProtobufPactPlugin {
}
Err(err) => {
error!("Failed to process protobuf: {}", err);
Ok(tonic::Response::new(proto::ConfigureInteractionResponse {
Ok(Response::new(proto::ConfigureInteractionResponse {
error: format!("Failed to process protobuf: {}", err),
.. proto::ConfigureInteractionResponse::default()
}))
Expand Down
1 change: 1 addition & 0 deletions tests/pact_verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ async fn init_plugin_request(request_message: &MessageRequest) -> (Status, (Cont
#[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
async fn verify_plugin() {
// Test Setup
#[allow(deprecated)]
let provider_info = ProviderInfo {
name: "plugin".to_string(),
port: Some(8000),
Expand Down

0 comments on commit 564cefa

Please sign in to comment.