Add raw field iterator and with builder method
- Implement RawFieldIterator and ProtoAccessor::raw_fields, which yield (field_number, raw_bytes) pairs for each field.
- Extend Builder with per-field _written flags and add a with() method to copy unseen fields from a source message.
- Add ProtoBuilder::write_raw to copy pre-encoded field bytes.
- Add tests for raw-field iteration, verbatim copying, and with().
This commit is contained in:
+59
-14
@@ -257,36 +257,81 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
|
||||
output.push_str(" }\n\n");
|
||||
}
|
||||
}
|
||||
// raw_fields() convenience on the message struct (before closing the impl)
|
||||
output.push_str(" pub fn raw_fields(&self) -> crate::RawFieldIterator<'a> {\n");
|
||||
output.push_str(" self.accessor.raw_fields()\n");
|
||||
output.push_str(" }\n\n");
|
||||
output.push_str("}\n\n");
|
||||
|
||||
// Builder
|
||||
output.push_str(&format!(
|
||||
"pub struct {}Builder<'b> {{\n builder: crate::ProtoBuilder<'b>,\n}}\n\nimpl<'b> {}Builder<'b> {{\n",
|
||||
msg_name, msg_name
|
||||
));
|
||||
output.push_str(&format!(
|
||||
" pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n builder: crate::ProtoBuilder::new(buf),\n }}\n }}\n\n",
|
||||
msg_name, msg_name
|
||||
));
|
||||
|
||||
// Collect builder field info so we can use it multiple times below.
|
||||
// Tuple: (field_name, safe_name, tag, rust_type, write_method)
|
||||
let mut builder_fields: Vec<(String, String, u32, String, String)> = Vec::new();
|
||||
for field_res in msg_proto.field() {
|
||||
let (field_data, _) = field_res.expect("Failed to iterate field");
|
||||
let field_proto =
|
||||
FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto");
|
||||
let field_name = field_proto.name().unwrap();
|
||||
let field_name = field_proto.name().unwrap().to_string();
|
||||
let safe_name = if field_name == "type" {
|
||||
format!("r#{}", field_name)
|
||||
} else {
|
||||
field_name.to_string()
|
||||
field_name.clone()
|
||||
};
|
||||
let tag = field_proto.number().unwrap();
|
||||
let f_type = field_proto.r#type().unwrap() as i32;
|
||||
let (rust_type, method) = map_type_to_rust_builder(f_type);
|
||||
builder_fields.push((field_name, safe_name, tag, rust_type, method));
|
||||
}
|
||||
|
||||
// Builder struct — one `_written: bool` flag per field
|
||||
output.push_str(&format!("pub struct {}Builder<'b> {{\n", msg_name));
|
||||
output.push_str(" builder: crate::ProtoBuilder<'b>,\n");
|
||||
for (field_name, _, _, _, _) in &builder_fields {
|
||||
output.push_str(&format!(" {}_written: bool,\n", field_name));
|
||||
}
|
||||
output.push_str(&format!("}}\.\n\nimpl<'b> {}Builder<'b> {{\n", msg_name));
|
||||
|
||||
// Constructor — initialise every flag to false
|
||||
output.push_str(&format!(
|
||||
" pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n",
|
||||
msg_name, msg_name
|
||||
));
|
||||
output.push_str(" builder: crate::ProtoBuilder::new(buf),\n");
|
||||
for (field_name, _, _, _, _) in &builder_fields {
|
||||
output.push_str(&format!(" {}_written: false,\n", field_name));
|
||||
}
|
||||
output.push_str(" }\n }\n\n");
|
||||
|
||||
// Per-field setters — mark field as written
|
||||
for (field_name, safe_name, tag, rust_type, method) in &builder_fields {
|
||||
output.push_str(&format!(
|
||||
" pub fn {}(mut self, value: {}) -> crate::Result<Self> {{\n self.builder.{}({}, value)?;\n Ok(self)\n }}\n\n",
|
||||
safe_name, rust_type, method, tag
|
||||
" pub fn {}(mut self, value: {}) -> crate::Result<Self> {{\n self.builder.{}({}, value)?;\n self.{}_written = true;\n Ok(self)\n }}\n\n",
|
||||
safe_name, rust_type, method, tag, field_name
|
||||
));
|
||||
}
|
||||
|
||||
// with() — copies unseen fields from an existing message
|
||||
output.push_str(&format!(
|
||||
" pub fn with(mut self, msg: &{}<'_>) -> crate::Result<Self> {{\n",
|
||||
msg_name
|
||||
));
|
||||
output.push_str(" for item in msg.raw_fields() {\n");
|
||||
output.push_str(" let (field_number, raw_bytes) = item?;\n");
|
||||
output.push_str(" let is_written = match field_number {\n");
|
||||
for (field_name, _, tag, _, _) in &builder_fields {
|
||||
output.push_str(&format!(
|
||||
" {} => self.{}_written,\n",
|
||||
tag, field_name
|
||||
));
|
||||
}
|
||||
output.push_str(" _ => false,\n");
|
||||
output.push_str(" };\n");
|
||||
output.push_str(" if !is_written {\n");
|
||||
output.push_str(" self.builder.write_raw(raw_bytes)?;\n");
|
||||
output.push_str(" }\n");
|
||||
output.push_str(" }\n");
|
||||
output.push_str(" Ok(self)\n");
|
||||
output.push_str(" }\n\n");
|
||||
|
||||
output.push_str(&format!(" pub fn finish(self) -> crate::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n"));
|
||||
|
||||
let mut nested_enums = Vec::new();
|
||||
|
||||
+158
@@ -236,6 +236,17 @@ impl<'a> ProtoAccessor<'a> {
|
||||
) -> RepeatedFieldIterator<'a> {
|
||||
RepeatedFieldIterator::new_range(self.data, field_number, start, end)
|
||||
}
|
||||
|
||||
/// Returns an iterator that yields `(field_number, raw_bytes)` for every
|
||||
/// field in the message. `raw_bytes` is the complete on-wire encoding
|
||||
/// (tag + value, including any length prefix), suitable for passing
|
||||
/// directly to `ProtoBuilder::write_raw`.
|
||||
pub fn raw_fields(&self) -> RawFieldIterator<'a> {
|
||||
RawFieldIterator {
|
||||
data: self.data,
|
||||
cursor: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FieldIterator<'a> {
|
||||
@@ -346,6 +357,48 @@ impl<'a> Iterator for RepeatedFieldIterator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Walks a protobuf message field by field, yielding `(field_number, raw_bytes)`.
///
/// `raw_bytes` covers the whole on-wire encoding of one field: the tag varint
/// plus the value bytes (and the length prefix for length-delimited fields).
/// That is exactly the slice `ProtoBuilder::write_raw` needs in order to copy
/// a field verbatim.
pub struct RawFieldIterator<'a> {
    /// The encoded message being walked.
    data: &'a [u8],
    /// Offset into `data` of the next field to decode.
    cursor: usize,
}
|
||||
|
||||
impl<'a> Iterator for RawFieldIterator<'a> {
|
||||
type Item = Result<(u32, &'a [u8])>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.cursor >= self.data.len() {
|
||||
return None;
|
||||
}
|
||||
let field_start = self.cursor;
|
||||
let (tag, tag_len) = match Tag::decode(&self.data[self.cursor..]) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(e));
|
||||
}
|
||||
};
|
||||
let cursor_after_tag = self.cursor + tag_len;
|
||||
if cursor_after_tag > self.data.len() {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(RotoError::UnexpectedEndOfBuffer));
|
||||
}
|
||||
let value_len = match skip_value(tag.wire_type, &self.data[cursor_after_tag..]) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(e));
|
||||
}
|
||||
};
|
||||
self.cursor = cursor_after_tag + value_len;
|
||||
Some(Ok((tag.field_number, &self.data[field_start..self.cursor])))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -455,6 +508,104 @@ mod tests {
|
||||
assert_eq!(result, Err(RotoError::BufferOverflow));
|
||||
}
|
||||
|
||||
#[test]
fn test_raw_field_iterator_yields_correct_bytes() {
    // Encode a two-field message: field 1 = string "hi", field 2 = int32 42.
    let mut storage = [0u8; 64];
    let mut writer = ProtoBuilder::new(&mut storage);
    writer.write_string(1, "hi").unwrap();
    writer.write_int32(2, 42).unwrap();
    let encoded = writer.finish().unwrap().to_vec();

    let accessor = ProtoAccessor::new(&encoded).unwrap();
    let fields: Vec<_> = accessor.raw_fields().collect();
    assert_eq!(fields.len(), 2);

    // First field: tag (1 << 3) | 2 = 0x0A, length 0x02, then the bytes of "hi".
    let (first_number, first_bytes) = fields[0].as_ref().unwrap();
    assert_eq!(*first_number, 1);
    assert_eq!(*first_bytes, [0x0A, 0x02, b'h', b'i']);

    // Second field: tag (2 << 3) | 0 = 0x10, then varint 42 = 0x2A.
    let (second_number, second_bytes) = fields[1].as_ref().unwrap();
    assert_eq!(*second_number, 2);
    assert_eq!(*second_bytes, [0x10, 0x2A]);
}
|
||||
|
||||
#[test]
fn test_write_raw_copies_field_verbatim() {
    // Source message: field 1 = string "hello", field 2 = int32 99.
    let mut source_storage = [0u8; 64];
    let mut source_writer = ProtoBuilder::new(&mut source_storage);
    source_writer.write_string(1, "hello").unwrap();
    source_writer.write_int32(2, 99).unwrap();
    let source_bytes = source_writer.finish().unwrap().to_vec();

    // Replay every raw field, untouched, into a second buffer.
    let source = ProtoAccessor::new(&source_bytes).unwrap();
    let mut target_storage = [0u8; 64];
    let mut target_writer = ProtoBuilder::new(&mut target_storage);
    for entry in source.raw_fields() {
        let (_, raw_bytes) = entry.unwrap();
        target_writer.write_raw(raw_bytes).unwrap();
    }
    let target_bytes = target_writer.finish().unwrap();

    // A verbatim copy has to match the source byte for byte.
    assert_eq!(target_bytes, source_bytes.as_slice());
}
|
||||
|
||||
#[test]
fn test_with_pattern_copies_unseen_fields() {
    // Source message carrying three fields.
    let mut source_storage = [0u8; 128];
    let mut source_writer = ProtoBuilder::new(&mut source_storage);
    source_writer.write_string(1, "original").unwrap();
    source_writer.write_int32(2, 99).unwrap();
    source_writer.write_varint(3, 1u64).unwrap(); // bool
    let source_bytes = source_writer.finish().unwrap().to_vec();
    let source = ProtoAccessor::new(&source_bytes).unwrap();

    // Mirror the bookkeeping of a generated `with` method: field 1 has been
    // set explicitly, while fields 2 and 3 must be taken from the source.
    let field1_written = true;
    let field2_written = false;
    let field3_written = false;
    let already_written = |field_number: u32| match field_number {
        1 => field1_written,
        2 => field2_written,
        3 => field3_written,
        _ => false,
    };

    let mut target_storage = [0u8; 128];
    let mut target_writer = ProtoBuilder::new(&mut target_storage);
    target_writer.write_string(1, "updated").unwrap();

    for entry in source.raw_fields() {
        let (field_number, raw_bytes) = entry.unwrap();
        if already_written(field_number) {
            continue;
        }
        target_writer.write_raw(raw_bytes).unwrap();
    }
    let target_bytes = target_writer.finish().unwrap();
    let target = ProtoAccessor::new(target_bytes).unwrap();

    // Field 1 keeps the explicitly written value.
    let (name_value, _) = target.get_value(1).unwrap();
    assert_eq!(name_value, b"updated");

    // Field 2 was carried over from the source.
    let (raw_two, _) = target.get_value(2).unwrap();
    let (decoded_two, _) = read_varint(raw_two).unwrap();
    assert_eq!(decoded_two as i32, 99);

    // Field 3 was carried over from the source.
    let (raw_three, _) = target.get_value(3).unwrap();
    let (decoded_three, _) = read_varint(raw_three).unwrap();
    assert_eq!(decoded_three, 1u64);
}
|
||||
|
||||
#[test]
|
||||
fn test_protoc_binary_compatibility() {
|
||||
let data = include_bytes!("../data/test_data.pb");
|
||||
@@ -618,6 +769,13 @@ impl<'a> ProtoBuilder<'a> {
|
||||
self.append_bytes(value)
|
||||
}
|
||||
|
||||
/// Appends a pre-encoded field (tag + value bytes) verbatim into the
|
||||
/// buffer. Use this together with `ProtoAccessor::raw_fields` to copy
|
||||
/// fields from an existing message into a builder without re-encoding them.
|
||||
pub fn write_raw(&mut self, raw_bytes: &[u8]) -> Result<()> {
|
||||
self.append_bytes(raw_bytes)
|
||||
}
|
||||
|
||||
/// Consumes the builder and hands back the first `pos` bytes of its buffer.
///
/// As written, this always returns `Ok`.
pub fn finish(self) -> Result<&'a mut [u8]> {
    let written = self.pos;
    Ok(&mut self.buf[..written])
}
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
use roto::generator::generate_rust_code;
|
||||
use roto::google::protobuf::compiler::plugin::CodeGeneratorRequest;
|
||||
use roto::google::protobuf::descriptor::FileDescriptorSet;
|
||||
use std::fs;
|
||||
|
||||
fn load_generated_code() -> String {
|
||||
let data = fs::read("data/request.bin").expect("Failed to read data/request.bin");
|
||||
let request = CodeGeneratorRequest::new(&data).expect("Failed to parse CodeGeneratorRequest");
|
||||
|
||||
let mut set_buf = Vec::new();
|
||||
for file_res in request.proto_file() {
|
||||
let (file_data, _) = file_res.expect("Failed to iterate proto_file");
|
||||
set_buf.push(10u8);
|
||||
let len = file_data.len() as u64;
|
||||
let mut len_buf = [0u8; 10];
|
||||
let len_size = roto::write_varint(len, &mut len_buf).unwrap();
|
||||
set_buf.extend_from_slice(&len_buf[..len_size]);
|
||||
set_buf.extend_from_slice(file_data);
|
||||
}
|
||||
let set = FileDescriptorSet::new(&set_buf).expect("Failed to create FileDescriptorSet");
|
||||
|
||||
generate_rust_code(&set, None, false)
|
||||
.into_iter()
|
||||
.map(|(_, content)| content)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// The generated builder structs must declare a `_written` flag field.
#[test]
fn test_builder_structs_have_written_flags() {
    let generated = load_generated_code();
    let has_flag_fields = generated.contains("_written: bool");
    assert!(
        has_flag_fields,
        "Builder structs should contain `_written: bool` fields for each proto field"
    );
}
|
||||
|
||||
/// The generated constructors must start with the `_written` flags cleared.
#[test]
fn test_builder_constructor_initialises_written_flags_to_false() {
    let generated = load_generated_code();
    let initialises_flags = generated.contains("_written: false");
    assert!(
        initialises_flags,
        "Builder constructors should initialise every `_written` flag to false"
    );
}
|
||||
|
||||
/// The generated setters must flip a `_written` flag to true.
#[test]
fn test_builder_setters_mark_field_as_written() {
    let generated = load_generated_code();
    let marks_written = generated.contains("_written = true");
    assert!(
        marks_written,
        "Each builder setter should set its `_written` flag to true"
    );
}
|
||||
|
||||
/// The generated builder impls must contain a public `with` method.
#[test]
fn test_builder_has_with_method() {
    let generated = load_generated_code();
    let has_with = generated.contains("pub fn with(");
    assert!(
        has_with,
        "Each builder impl should expose a `with` method"
    );
}
|
||||
|
||||
/// The generated message impls must contain a public `raw_fields` method.
#[test]
fn test_message_structs_have_raw_fields_method() {
    let generated = load_generated_code();
    let has_raw_fields = generated.contains("pub fn raw_fields(");
    assert!(
        has_raw_fields,
        "Each message struct impl should expose a `raw_fields` method"
    );
}
|
||||
|
||||
/// The generated code must call `write_raw(raw_bytes)` to copy field bytes.
#[test]
fn test_with_method_uses_write_raw() {
    let generated = load_generated_code();
    let calls_write_raw = generated.contains("write_raw(raw_bytes)");
    assert!(
        calls_write_raw,
        "The `with` method should call `write_raw` to copy field bytes"
    );
}
|
||||
Reference in New Issue
Block a user