From cfb54b418b8efd970b126448718ff58dfb3ebc03 Mon Sep 17 00:00:00 2001 From: Ronald Holshausen Date: Thu, 31 Oct 2024 16:56:56 +1100 Subject: [PATCH] refactor: Update DynamicMessage to store the Protobuf fields as map keyed by field number #73 --- src/dynamic_message.rs | 146 +++++++++++++++++++++++------------------ src/mock_service.rs | 2 +- 2 files changed, 82 insertions(+), 66 deletions(-) diff --git a/src/dynamic_message.rs b/src/dynamic_message.rs index 3be5069..82d9867 100644 --- a/src/dynamic_message.rs +++ b/src/dynamic_message.rs @@ -66,11 +66,12 @@ impl Codec for PactCodec { } } +/// Dynamic message support based on a vector of ProtobufField field values #[derive(Debug, Clone)] -/// Dynamic message support based on a vector of ProtobufField fields pub struct DynamicMessage { - fields: Vec, - descriptors: FileDescriptorSet + fields: HashMap>, + descriptors: FileDescriptorSet, + message_descriptor: DescriptorProto } impl DynamicMessage { @@ -81,75 +82,79 @@ impl DynamicMessage { descriptors: &FileDescriptorSet ) -> DynamicMessage { DynamicMessage { - fields: field_data.to_vec(), + fields: field_data.iter().map(|f| (f.field_num, f.clone())).into_group_map(), + message_descriptor: message_descriptor.clone(), descriptors: descriptors.clone() } } - /// Return a slice of the fields - pub fn proto_fields(&self) -> &[ProtobufField] { - self.fields.as_slice() + /// Return a vector of the fields + pub fn proto_fields(&self) -> Vec { + self.fields.values().flatten().cloned().collect() } /// Encode this message to the provided buffer pub fn write_to(&self, buffer: &mut B) -> anyhow::Result<()> where B: BufMut { - for field in self.fields.iter().sorted_by(|a, b| Ord::cmp(&a.field_num, &b.field_num)) { - trace!(field = field.to_string().as_str(), "Writing"); - encode_key(field.field_num, field.wire_type, buffer); - match field.wire_type { - WireType::Varint => match &field.data { - ProtobufFieldData::Boolean(b) => encode_varint(*b as u64, buffer), - ProtobufFieldData::UInteger32(n) => encode_varint(*n as u64, buffer), - ProtobufFieldData::Integer32(n) => encode_varint(*n as u64, buffer), - ProtobufFieldData::UInteger64(n) => encode_varint(*n, buffer), - ProtobufFieldData::Integer64(n) => encode_varint(*n as u64, buffer), - ProtobufFieldData::Enum(n, _) => encode_varint(*n as u64, buffer), - ProtobufFieldData::Unknown(b) => { - debug!("Writing unknown field {}", field.data); - buffer.put_slice(b.as_slice()); + for (field_num, values) in self.fields.iter() + .sorted_by(|(a, _), (b, _)| Ord::cmp(a, b)) { + for field in values { + trace!(%field_num, field = field.to_string().as_str(), "Writing"); + encode_key(field.field_num, field.wire_type, buffer); + match field.wire_type { + WireType::Varint => match &field.data { + ProtobufFieldData::Boolean(b) => encode_varint(*b as u64, buffer), + ProtobufFieldData::UInteger32(n) => encode_varint(*n as u64, buffer), + ProtobufFieldData::Integer32(n) => encode_varint(*n as u64, buffer), + ProtobufFieldData::UInteger64(n) => encode_varint(*n, buffer), + ProtobufFieldData::Integer64(n) => encode_varint(*n as u64, buffer), + ProtobufFieldData::Enum(n, _) => encode_varint(*n as u64, buffer), + ProtobufFieldData::Unknown(b) => { + debug!("Writing unknown field {}", field.data); + buffer.put_slice(b.as_slice()); + }, + _ => return Err(anyhow!("Expected a varint, but field is {}", field.data)) }, - _ => return Err(anyhow!("Expected a varint, but field is {}", field.data)) - }, - WireType::SixtyFourBit => match &field.data { - ProtobufFieldData::UInteger64(n) => buffer.put_u64_le(*n), - ProtobufFieldData::Integer64(n) => buffer.put_i64_le(*n), - ProtobufFieldData::Double(n) => buffer.put_f64_le(*n), - ProtobufFieldData::Unknown(b) => { - debug!("Writing unknown field {}", field.data); - buffer.put_slice(b.as_slice()); - } - _ => return Err(anyhow!("Expected a 64 bit value, but field is {}", field.data)) - } - WireType::LengthDelimited => match &field.data { - ProtobufFieldData::String(s) => { - encode_varint(s.len() as u64, buffer); - buffer.put_slice(s.as_bytes()); + WireType::SixtyFourBit => match &field.data { + ProtobufFieldData::UInteger64(n) => buffer.put_u64_le(*n), + ProtobufFieldData::Integer64(n) => buffer.put_i64_le(*n), + ProtobufFieldData::Double(n) => buffer.put_f64_le(*n), + ProtobufFieldData::Unknown(b) => { + debug!("Writing unknown field {}", field.data); + buffer.put_slice(b.as_slice()); + } + _ => return Err(anyhow!("Expected a 64 bit value, but field is {}", field.data)) } - ProtobufFieldData::Bytes(b) => { - encode_varint(b.len() as u64, buffer); - buffer.put_slice(b.as_slice()); + WireType::LengthDelimited => match &field.data { + ProtobufFieldData::String(s) => { + encode_varint(s.len() as u64, buffer); + buffer.put_slice(s.as_bytes()); + } + ProtobufFieldData::Bytes(b) => { + encode_varint(b.len() as u64, buffer); + buffer.put_slice(b.as_slice()); + } + ProtobufFieldData::Message(m, _) => { + encode_varint(m.len() as u64, buffer); + buffer.put_slice(m.as_slice()); + } + ProtobufFieldData::Unknown(b) => { + debug!("Writing unknown field {}", field.data); + buffer.put_slice(b.as_slice()); + }, + _ => return Err(anyhow!("Expected a length delimited value, but field is {}", field.data)) } - ProtobufFieldData::Message(m, _) => { - encode_varint(m.len() as u64, buffer); - buffer.put_slice(m.as_slice()); + WireType::ThirtyTwoBit => match &field.data { + ProtobufFieldData::UInteger32(n) => buffer.put_u32_le(*n), + ProtobufFieldData::Integer32(n) => buffer.put_i32_le(*n), + ProtobufFieldData::Float(n) => buffer.put_f32_le(*n), + ProtobufFieldData::Unknown(b) => { + debug!("Writing unknown field {}", field.data); + buffer.put_slice(b.as_slice()); + }, + _ => return Err(anyhow!("Expected a 32 bit value, but field is {}", field.data)) } - ProtobufFieldData::Unknown(b) => { - debug!("Writing unknown field {}", field.data); - buffer.put_slice(b.as_slice()); - }, - _ => return Err(anyhow!("Expected a length delimited value, but field is {}", field.data)) - } - WireType::ThirtyTwoBit => match &field.data { - ProtobufFieldData::UInteger32(n) => buffer.put_u32_le(*n), - ProtobufFieldData::Integer32(n) => buffer.put_i32_le(*n), - ProtobufFieldData::Float(n) => buffer.put_f32_le(*n), - ProtobufFieldData::Unknown(b) => { - debug!("Writing unknown field {}", field.data); - buffer.put_slice(b.as_slice()); - }, - _ => return Err(anyhow!("Expected a 32 bit value, but field is {}", field.data)) + _ => return Err(anyhow!("Groups are not supported")) } - _ => return Err(anyhow!("Groups are not supported")) } } Ok(()) @@ -224,9 +229,9 @@ impl DynamicMessage { } } } - PathToken::Index(_) => todo!(), - PathToken::Star => todo!(), - PathToken::StarIndex => todo!(), + PathToken::Index(_) => todo!("Support for index paths is not supported yet"), + PathToken::Star => todo!("Support for * in paths is not supported yet"), + PathToken::StarIndex => todo!("Support for [*] in paths is not supported yet"), _ => () } } @@ -260,9 +265,20 @@ impl DynamicMessage { } } -fn find_field<'a>(fields: &'a mut [ProtobufField], field_name: &str) -> Option<&'a mut ProtobufField> { +// TODO: This only supports the first value, needs to deal with repeated fields +fn find_field<'a>( + fields: &'a mut HashMap>, + field_name: &str +) -> Option<&'a mut ProtobufField> { fields.iter_mut() - .find(|field| field.field_name == field_name) + .find(|(_, fields)| fields.iter().any(|field| field.field_name == field_name)) + .map(|(_, fields)| { + if fields.len() > 1 { + warn!("There is more than one field value"); + } + fields.get_mut(0) + }) + .flatten() } #[derive(Debug, Clone)] @@ -508,7 +524,7 @@ mod tests { }; expect!(message.apply_generators(Some(&generators), &GeneratorTestMode::Provider, &hashmap!{})).to(be_ok()); - expect!(message.fields[0].data.as_i64().unwrap()).to_not(be_equal_to(100)); + expect!(message.proto_fields()[0].data.as_i64().unwrap()).to_not(be_equal_to(100)); } #[test] diff --git a/src/mock_service.rs b/src/mock_service.rs index 878f57e..93e8d6c 100644 --- a/src/mock_service.rs +++ b/src/mock_service.rs @@ -77,7 +77,7 @@ impl MockService { let context = CoreMatchingContext::new(DiffConfig::NoUnexpectedKeys, &self.message.request.matching_rules.rules_for_category("body").unwrap_or_default(), &plugin_config); - let mismatches = compare(&message_descriptor, &expected_message, request.proto_fields(), &context, + let mismatches = compare(&message_descriptor, &expected_message, request.proto_fields().as_slice(), &context, &expected_message_bytes, &self.file_descriptor_set); // 2. Compare any metadata from the incoming message