Skip to content

Commit

Permalink
refactor: Update DynamicMessage to store the Protobuf fields as map k…
Browse files Browse the repository at this point in the history
…eyed by field number #73
  • Loading branch information
rholshausen committed Oct 31, 2024
1 parent 8452e25 commit cfb54b4
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 66 deletions.
146 changes: 81 additions & 65 deletions src/dynamic_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@ impl Codec for PactCodec {
}
}

/// Dynamic message support based on a vector of ProtobufField field values
#[derive(Debug, Clone)]
/// Dynamic message support based on a vector of ProtobufField fields
pub struct DynamicMessage {
fields: Vec<ProtobufField>,
descriptors: FileDescriptorSet
fields: HashMap<u32, Vec<ProtobufField>>,
descriptors: FileDescriptorSet,
message_descriptor: DescriptorProto
}

impl DynamicMessage {
Expand All @@ -81,75 +82,79 @@ impl DynamicMessage {
descriptors: &FileDescriptorSet
) -> DynamicMessage {
DynamicMessage {
fields: field_data.to_vec(),
fields: field_data.iter().map(|f| (f.field_num, f.clone())).into_group_map(),
message_descriptor: message_descriptor.clone(),
descriptors: descriptors.clone()
}
}

/// Return a slice of the fields
pub fn proto_fields(&self) -> &[ProtobufField] {
self.fields.as_slice()
/// Return a vector of the fields
pub fn proto_fields(&self) -> Vec<ProtobufField> {
self.fields.values().flatten().cloned().collect()
}

/// Encode this message to the provided buffer
pub fn write_to<B>(&self, buffer: &mut B) -> anyhow::Result<()> where B: BufMut {
for field in self.fields.iter().sorted_by(|a, b| Ord::cmp(&a.field_num, &b.field_num)) {
trace!(field = field.to_string().as_str(), "Writing");
encode_key(field.field_num, field.wire_type, buffer);
match field.wire_type {
WireType::Varint => match &field.data {
ProtobufFieldData::Boolean(b) => encode_varint(*b as u64, buffer),
ProtobufFieldData::UInteger32(n) => encode_varint(*n as u64, buffer),
ProtobufFieldData::Integer32(n) => encode_varint(*n as u64, buffer),
ProtobufFieldData::UInteger64(n) => encode_varint(*n, buffer),
ProtobufFieldData::Integer64(n) => encode_varint(*n as u64, buffer),
ProtobufFieldData::Enum(n, _) => encode_varint(*n as u64, buffer),
ProtobufFieldData::Unknown(b) => {
debug!("Writing unknown field {}", field.data);
buffer.put_slice(b.as_slice());
for (field_num, values) in self.fields.iter()
.sorted_by(|(a, _), (b, _)| Ord::cmp(a, b)) {
for field in values {
trace!(%field_num, field = field.to_string().as_str(), "Writing");
encode_key(field.field_num, field.wire_type, buffer);
match field.wire_type {
WireType::Varint => match &field.data {
ProtobufFieldData::Boolean(b) => encode_varint(*b as u64, buffer),
ProtobufFieldData::UInteger32(n) => encode_varint(*n as u64, buffer),
ProtobufFieldData::Integer32(n) => encode_varint(*n as u64, buffer),
ProtobufFieldData::UInteger64(n) => encode_varint(*n, buffer),
ProtobufFieldData::Integer64(n) => encode_varint(*n as u64, buffer),
ProtobufFieldData::Enum(n, _) => encode_varint(*n as u64, buffer),
ProtobufFieldData::Unknown(b) => {
debug!("Writing unknown field {}", field.data);
buffer.put_slice(b.as_slice());
},
_ => return Err(anyhow!("Expected a varint, but field is {}", field.data))
},
_ => return Err(anyhow!("Expected a varint, but field is {}", field.data))
},
WireType::SixtyFourBit => match &field.data {
ProtobufFieldData::UInteger64(n) => buffer.put_u64_le(*n),
ProtobufFieldData::Integer64(n) => buffer.put_i64_le(*n),
ProtobufFieldData::Double(n) => buffer.put_f64_le(*n),
ProtobufFieldData::Unknown(b) => {
debug!("Writing unknown field {}", field.data);
buffer.put_slice(b.as_slice());
}
_ => return Err(anyhow!("Expected a 64 bit value, but field is {}", field.data))
}
WireType::LengthDelimited => match &field.data {
ProtobufFieldData::String(s) => {
encode_varint(s.len() as u64, buffer);
buffer.put_slice(s.as_bytes());
WireType::SixtyFourBit => match &field.data {
ProtobufFieldData::UInteger64(n) => buffer.put_u64_le(*n),
ProtobufFieldData::Integer64(n) => buffer.put_i64_le(*n),
ProtobufFieldData::Double(n) => buffer.put_f64_le(*n),
ProtobufFieldData::Unknown(b) => {
debug!("Writing unknown field {}", field.data);
buffer.put_slice(b.as_slice());
}
_ => return Err(anyhow!("Expected a 64 bit value, but field is {}", field.data))
}
ProtobufFieldData::Bytes(b) => {
encode_varint(b.len() as u64, buffer);
buffer.put_slice(b.as_slice());
WireType::LengthDelimited => match &field.data {
ProtobufFieldData::String(s) => {
encode_varint(s.len() as u64, buffer);
buffer.put_slice(s.as_bytes());
}
ProtobufFieldData::Bytes(b) => {
encode_varint(b.len() as u64, buffer);
buffer.put_slice(b.as_slice());
}
ProtobufFieldData::Message(m, _) => {
encode_varint(m.len() as u64, buffer);
buffer.put_slice(m.as_slice());
}
ProtobufFieldData::Unknown(b) => {
debug!("Writing unknown field {}", field.data);
buffer.put_slice(b.as_slice());
},
_ => return Err(anyhow!("Expected a length delimited value, but field is {}", field.data))
}
ProtobufFieldData::Message(m, _) => {
encode_varint(m.len() as u64, buffer);
buffer.put_slice(m.as_slice());
WireType::ThirtyTwoBit => match &field.data {
ProtobufFieldData::UInteger32(n) => buffer.put_u32_le(*n),
ProtobufFieldData::Integer32(n) => buffer.put_i32_le(*n),
ProtobufFieldData::Float(n) => buffer.put_f32_le(*n),
ProtobufFieldData::Unknown(b) => {
debug!("Writing unknown field {}", field.data);
buffer.put_slice(b.as_slice());
},
_ => return Err(anyhow!("Expected a 32 bit value, but field is {}", field.data))
}
ProtobufFieldData::Unknown(b) => {
debug!("Writing unknown field {}", field.data);
buffer.put_slice(b.as_slice());
},
_ => return Err(anyhow!("Expected a length delimited value, but field is {}", field.data))
}
WireType::ThirtyTwoBit => match &field.data {
ProtobufFieldData::UInteger32(n) => buffer.put_u32_le(*n),
ProtobufFieldData::Integer32(n) => buffer.put_i32_le(*n),
ProtobufFieldData::Float(n) => buffer.put_f32_le(*n),
ProtobufFieldData::Unknown(b) => {
debug!("Writing unknown field {}", field.data);
buffer.put_slice(b.as_slice());
},
_ => return Err(anyhow!("Expected a 32 bit value, but field is {}", field.data))
_ => return Err(anyhow!("Groups are not supported"))
}
_ => return Err(anyhow!("Groups are not supported"))
}
}
Ok(())
Expand Down Expand Up @@ -224,9 +229,9 @@ impl DynamicMessage {
}
}
}
PathToken::Index(_) => todo!(),
PathToken::Star => todo!(),
PathToken::StarIndex => todo!(),
PathToken::Index(_) => todo!("Support for index paths is not supported yet"),
PathToken::Star => todo!("Support for * in paths is not supported yet"),
PathToken::StarIndex => todo!("Support for [*] in paths is not supported yet"),
_ => ()
}
}
Expand Down Expand Up @@ -260,9 +265,20 @@ impl DynamicMessage {
}
}

fn find_field<'a>(fields: &'a mut [ProtobufField], field_name: &str) -> Option<&'a mut ProtobufField> {
// TODO: This only supports the first value, needs to deal with repeated fields
fn find_field<'a>(
fields: &'a mut HashMap<u32, Vec<ProtobufField>>,
field_name: &str
) -> Option<&'a mut ProtobufField> {
fields.iter_mut()
.find(|field| field.field_name == field_name)
.find(|(_, fields)| fields.iter().any(|field| field.field_name == field_name))
.map(|(_, fields)| {
if fields.len() > 1 {
warn!("There is more than one field value");
}
fields.get_mut(0)
})
.flatten()
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -508,7 +524,7 @@ mod tests {
};

expect!(message.apply_generators(Some(&generators), &GeneratorTestMode::Provider, &hashmap!{})).to(be_ok());
expect!(message.fields[0].data.as_i64().unwrap()).to_not(be_equal_to(100));
expect!(message.proto_fields()[0].data.as_i64().unwrap()).to_not(be_equal_to(100));
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion src/mock_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl MockService {
let context = CoreMatchingContext::new(DiffConfig::NoUnexpectedKeys,
&self.message.request.matching_rules.rules_for_category("body").unwrap_or_default(),
&plugin_config);
let mismatches = compare(&message_descriptor, &expected_message, request.proto_fields(), &context,
let mismatches = compare(&message_descriptor, &expected_message, request.proto_fields().as_slice(), &context,
&expected_message_bytes, &self.file_descriptor_set);

// 2. Compare any metadata from the incoming message
Expand Down

0 comments on commit cfb54b4

Please sign in to comment.