Add raw field iterator and with builder method

- Implement RawFieldIterator and ProtoAccessor::raw_fields that yield
  (field_number, raw_bytes) pairs for each field
- Extend Builder with per-field _written flags and add a with()
  method to copy unseen fields from a source message
- Add ProtoBuilder::write_raw to copy pre-encoded field bytes
- Add tests for raw-field iteration, verbatim copying, and with()
This commit is contained in:
2026-05-04 19:03:56 -07:00
parent 6821bd1cca
commit 05e4c275bb
3 changed files with 297 additions and 14 deletions
+59 -14
View File
@@ -257,36 +257,81 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
output.push_str(" }\n\n"); output.push_str(" }\n\n");
} }
} }
// raw_fields() convenience on the message struct (before closing the impl)
output.push_str(" pub fn raw_fields(&self) -> crate::RawFieldIterator<'a> {\n");
output.push_str(" self.accessor.raw_fields()\n");
output.push_str(" }\n\n");
output.push_str("}\n\n"); output.push_str("}\n\n");
// Builder // Collect builder field info so we can use it multiple times below.
output.push_str(&format!( // Tuple: (field_name, safe_name, tag, rust_type, write_method)
"pub struct {}Builder<'b> {{\n builder: crate::ProtoBuilder<'b>,\n}}\n\nimpl<'b> {}Builder<'b> {{\n", let mut builder_fields: Vec<(String, String, u32, String, String)> = Vec::new();
msg_name, msg_name
));
output.push_str(&format!(
" pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n builder: crate::ProtoBuilder::new(buf),\n }}\n }}\n\n",
msg_name, msg_name
));
for field_res in msg_proto.field() { for field_res in msg_proto.field() {
let (field_data, _) = field_res.expect("Failed to iterate field"); let (field_data, _) = field_res.expect("Failed to iterate field");
let field_proto = let field_proto =
FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto");
let field_name = field_proto.name().unwrap(); let field_name = field_proto.name().unwrap().to_string();
let safe_name = if field_name == "type" { let safe_name = if field_name == "type" {
format!("r#{}", field_name) format!("r#{}", field_name)
} else { } else {
field_name.to_string() field_name.clone()
}; };
let tag = field_proto.number().unwrap(); let tag = field_proto.number().unwrap();
let f_type = field_proto.r#type().unwrap() as i32; let f_type = field_proto.r#type().unwrap() as i32;
let (rust_type, method) = map_type_to_rust_builder(f_type); let (rust_type, method) = map_type_to_rust_builder(f_type);
builder_fields.push((field_name, safe_name, tag, rust_type, method));
}
// Builder struct — one `_written: bool` flag per field
output.push_str(&format!("pub struct {}Builder<'b> {{\n", msg_name));
output.push_str(" builder: crate::ProtoBuilder<'b>,\n");
for (field_name, _, _, _, _) in &builder_fields {
output.push_str(&format!(" {}_written: bool,\n", field_name));
}
output.push_str(&format!("}}\.\n\nimpl<'b> {}Builder<'b> {{\n", msg_name));
// Constructor — initialise every flag to false
output.push_str(&format!( output.push_str(&format!(
" pub fn {}(mut self, value: {}) -> crate::Result<Self> {{\n self.builder.{}({}, value)?;\n Ok(self)\n }}\n\n", " pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n",
safe_name, rust_type, method, tag msg_name, msg_name
));
output.push_str(" builder: crate::ProtoBuilder::new(buf),\n");
for (field_name, _, _, _, _) in &builder_fields {
output.push_str(&format!(" {}_written: false,\n", field_name));
}
output.push_str(" }\n }\n\n");
// Per-field setters — mark field as written
for (field_name, safe_name, tag, rust_type, method) in &builder_fields {
output.push_str(&format!(
" pub fn {}(mut self, value: {}) -> crate::Result<Self> {{\n self.builder.{}({}, value)?;\n self.{}_written = true;\n Ok(self)\n }}\n\n",
safe_name, rust_type, method, tag, field_name
)); ));
} }
// with() — copies unseen fields from an existing message
output.push_str(&format!(
" pub fn with(mut self, msg: &{}<'_>) -> crate::Result<Self> {{\n",
msg_name
));
output.push_str(" for item in msg.raw_fields() {\n");
output.push_str(" let (field_number, raw_bytes) = item?;\n");
output.push_str(" let is_written = match field_number {\n");
for (field_name, _, tag, _, _) in &builder_fields {
output.push_str(&format!(
" {} => self.{}_written,\n",
tag, field_name
));
}
output.push_str(" _ => false,\n");
output.push_str(" };\n");
output.push_str(" if !is_written {\n");
output.push_str(" self.builder.write_raw(raw_bytes)?;\n");
output.push_str(" }\n");
output.push_str(" }\n");
output.push_str(" Ok(self)\n");
output.push_str(" }\n\n");
output.push_str(&format!(" pub fn finish(self) -> crate::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n")); output.push_str(&format!(" pub fn finish(self) -> crate::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n"));
let mut nested_enums = Vec::new(); let mut nested_enums = Vec::new();
+158
View File
@@ -236,6 +236,17 @@ impl<'a> ProtoAccessor<'a> {
) -> RepeatedFieldIterator<'a> { ) -> RepeatedFieldIterator<'a> {
RepeatedFieldIterator::new_range(self.data, field_number, start, end) RepeatedFieldIterator::new_range(self.data, field_number, start, end)
} }
/// Walks every field of the message from the first byte onwards, yielding
/// `(field_number, raw_bytes)` pairs.
///
/// Each `raw_bytes` slice is the field's complete on-wire encoding (tag plus
/// value, including any length prefix), so it can be handed straight to
/// `ProtoBuilder::write_raw`.
pub fn raw_fields(&self) -> RawFieldIterator<'a> {
    let data = self.data;
    RawFieldIterator { data, cursor: 0 }
}
} }
pub struct FieldIterator<'a> { pub struct FieldIterator<'a> {
@@ -346,6 +357,48 @@ impl<'a> Iterator for RepeatedFieldIterator<'a> {
} }
} }
/// Iterator over the fields of an encoded protobuf message, yielding
/// `(field_number, raw_bytes)` pairs.
///
/// `raw_bytes` covers the whole on-wire encoding of one field: the tag varint
/// followed by the value bytes (including the length prefix for
/// length-delimited fields). That is exactly the slice
/// `ProtoBuilder::write_raw` needs to copy a field verbatim.
///
/// Once an error has been yielded, every later call to `next` returns `None`.
pub struct RawFieldIterator<'a> {
    /// Encoded message being walked.
    data: &'a [u8],
    /// Offset into `data` at which the next field begins.
    cursor: usize,
}

impl<'a> Iterator for RawFieldIterator<'a> {
    type Item = Result<(u32, &'a [u8])>;

    fn next(&mut self) -> Option<Self::Item> {
        let total = self.data.len();
        let start = self.cursor;
        if start >= total {
            return None;
        }

        // Park the cursor at the end up front: every error path below then
        // leaves the iterator exhausted, and only a fully decoded field moves
        // the cursor back to a real position.
        self.cursor = total;

        let (tag, tag_len) = match Tag::decode(&self.data[start..]) {
            Ok(decoded) => decoded,
            Err(err) => return Some(Err(err)),
        };

        let value_start = start + tag_len;
        if value_start > total {
            return Some(Err(RotoError::UnexpectedEndOfBuffer));
        }

        let value_len = match skip_value(tag.wire_type, &self.data[value_start..]) {
            Ok(len) => len,
            Err(err) => return Some(Err(err)),
        };

        let end = value_start + value_len;
        self.cursor = end;
        Some(Ok((tag.field_number, &self.data[start..end])))
    }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -455,6 +508,104 @@ mod tests {
assert_eq!(result, Err(RotoError::BufferOverflow)); assert_eq!(result, Err(RotoError::BufferOverflow));
} }
#[test]
fn test_raw_field_iterator_yields_correct_bytes() {
    // Encode a two-field message: field 1 = string "hi", field 2 = int32 42.
    let mut storage = [0u8; 64];
    let mut writer = ProtoBuilder::new(&mut storage);
    writer.write_string(1, "hi").unwrap();
    writer.write_int32(2, 42).unwrap();
    let encoded = writer.finish().unwrap().to_vec();

    let accessor = ProtoAccessor::new(&encoded).unwrap();
    let fields: Vec<(u32, &[u8])> = accessor
        .raw_fields()
        .map(|item| item.unwrap())
        .collect();
    assert_eq!(fields.len(), 2);

    // Field 1: tag = (1 << 3) | 2 = 0x0A, length varint = 0x02, then "hi".
    let (first_number, first_bytes) = fields[0];
    assert_eq!(first_number, 1);
    assert_eq!(first_bytes, [0x0A, 0x02, b'h', b'i']);

    // Field 2: tag = (2 << 3) | 0 = 0x10, varint 42 = 0x2A.
    let (second_number, second_bytes) = fields[1];
    assert_eq!(second_number, 2);
    assert_eq!(second_bytes, [0x10, 0x2A]);
}
#[test]
fn test_write_raw_copies_field_verbatim() {
    // Source message: field 1 = string "hello", field 2 = int32 99.
    let mut source_storage = [0u8; 64];
    let mut source_writer = ProtoBuilder::new(&mut source_storage);
    source_writer.write_string(1, "hello").unwrap();
    source_writer.write_int32(2, 99).unwrap();
    let source_bytes = source_writer.finish().unwrap().to_vec();

    // Replay each raw field, untouched, into a fresh builder.
    let source = ProtoAccessor::new(&source_bytes).unwrap();
    let mut copy_storage = [0u8; 64];
    let mut copy_writer = ProtoBuilder::new(&mut copy_storage);
    for raw_bytes in source.raw_fields().map(|item| item.unwrap().1) {
        copy_writer.write_raw(raw_bytes).unwrap();
    }
    let copied = copy_writer.finish().unwrap();

    // A verbatim copy must reproduce the source byte for byte.
    assert_eq!(&copied[..], &source_bytes[..]);
}
#[test]
fn test_with_pattern_copies_unseen_fields() {
    // Existing source message carrying three fields.
    let mut source_storage = [0u8; 128];
    let mut source_writer = ProtoBuilder::new(&mut source_storage);
    source_writer.write_string(1, "original").unwrap();
    source_writer.write_int32(2, 99).unwrap();
    source_writer.write_varint(3, 1u64).unwrap(); // bool
    let source_bytes = source_writer.finish().unwrap().to_vec();
    let source = ProtoAccessor::new(&source_bytes).unwrap();

    // Mirrors the bookkeeping of a generated `with` method: only field 1 has
    // been written explicitly, so fields 2 and 3 (and any unknown field)
    // must be taken from the source.
    let already_written = |field_number: u32| field_number == 1;

    let mut merged_storage = [0u8; 128];
    let mut merged_writer = ProtoBuilder::new(&mut merged_storage);
    merged_writer.write_string(1, "updated").unwrap();

    for item in source.raw_fields() {
        let (field_number, raw_bytes) = item.unwrap();
        if already_written(field_number) {
            continue;
        }
        merged_writer.write_raw(raw_bytes).unwrap();
    }

    let merged_bytes = merged_writer.finish().unwrap();
    let merged = ProtoAccessor::new(merged_bytes).unwrap();

    // Field 1 keeps the explicitly written value.
    let (name_value, _) = merged.get_value(1).unwrap();
    assert_eq!(name_value, b"updated");

    // Field 2 was carried over from the source.
    let (int_value, _) = merged.get_value(2).unwrap();
    let (decoded_int, _) = read_varint(int_value).unwrap();
    assert_eq!(decoded_int as i32, 99);

    // Field 3 was carried over from the source.
    let (flag_value, _) = merged.get_value(3).unwrap();
    let (decoded_flag, _) = read_varint(flag_value).unwrap();
    assert_eq!(decoded_flag, 1u64);
}
#[test] #[test]
fn test_protoc_binary_compatibility() { fn test_protoc_binary_compatibility() {
let data = include_bytes!("../data/test_data.pb"); let data = include_bytes!("../data/test_data.pb");
@@ -618,6 +769,13 @@ impl<'a> ProtoBuilder<'a> {
self.append_bytes(value) self.append_bytes(value)
} }
/// Appends a pre-encoded field (tag + value bytes) verbatim into the
/// buffer. Use this together with `ProtoAccessor::raw_fields` to copy
/// fields from an existing message into a builder without re-encoding them.
pub fn write_raw(&mut self, raw_bytes: &[u8]) -> Result<()> {
self.append_bytes(raw_bytes)
}
pub fn finish(self) -> Result<&'a mut [u8]> { pub fn finish(self) -> Result<&'a mut [u8]> {
Ok(&mut self.buf[..self.pos]) Ok(&mut self.buf[..self.pos])
} }
+80
View File
@@ -0,0 +1,80 @@
use roto::generator::generate_rust_code;
use roto::google::protobuf::compiler::plugin::CodeGeneratorRequest;
use roto::google::protobuf::descriptor::FileDescriptorSet;
use std::fs;
/// Runs the code generator over the request stored in `data/request.bin` and
/// returns all generated file contents concatenated into one string.
///
/// Each proto file from the request is re-framed as a length-delimited entry
/// (tag byte, length varint, payload) so the result parses as a
/// `FileDescriptorSet`.
fn load_generated_code() -> String {
    // 10 == (1 << 3) | 2: field number 1 with the length-delimited wire type.
    // NOTE(review): presumably `FileDescriptorSet.file` — confirm against descriptor.proto.
    const ENTRY_TAG: u8 = 10u8;

    let request_bytes = fs::read("data/request.bin").expect("Failed to read data/request.bin");
    let request =
        CodeGeneratorRequest::new(&request_bytes).expect("Failed to parse CodeGeneratorRequest");

    let mut framed = Vec::new();
    for entry in request.proto_file() {
        let (payload, _) = entry.expect("Failed to iterate proto_file");

        // Ten bytes is the scratch size used for the length varint.
        let mut varint_scratch = [0u8; 10];
        let varint_len = roto::write_varint(payload.len() as u64, &mut varint_scratch).unwrap();

        framed.push(ENTRY_TAG);
        framed.extend_from_slice(&varint_scratch[..varint_len]);
        framed.extend_from_slice(payload);
    }

    let descriptor_set =
        FileDescriptorSet::new(&framed).expect("Failed to create FileDescriptorSet");
    let generated = generate_rust_code(&descriptor_set, None, false);
    generated.into_iter().map(|(_, content)| content).collect()
}
/// The generated builder structs must declare per-field `_written` flags.
#[test]
fn test_builder_structs_have_written_flags() {
    let generated = load_generated_code();
    let has_flag_fields = generated.contains("_written: bool");
    assert!(
        has_flag_fields,
        "Builder structs should contain `_written: bool` fields for each proto field"
    );
}
/// The generated constructors must start every `_written` flag at `false`.
#[test]
fn test_builder_constructor_initialises_written_flags_to_false() {
    let generated = load_generated_code();
    let has_flag_init = generated.contains("_written: false");
    assert!(
        has_flag_init,
        "Builder constructors should initialise every `_written` flag to false"
    );
}
/// The generated setters must flip their `_written` flag to `true`.
#[test]
fn test_builder_setters_mark_field_as_written() {
    let generated = load_generated_code();
    let marks_written = generated.contains("_written = true");
    assert!(
        marks_written,
        "Each builder setter should set its `_written` flag to true"
    );
}
/// The generated builder impls must contain a public `with` method.
#[test]
fn test_builder_has_with_method() {
    let generated = load_generated_code();
    let has_with = generated.contains("pub fn with(");
    assert!(
        has_with,
        "Each builder impl should expose a `with` method"
    );
}
/// The generated message impls must contain a public `raw_fields` method.
#[test]
fn test_message_structs_have_raw_fields_method() {
    let generated = load_generated_code();
    let has_raw_fields = generated.contains("pub fn raw_fields(");
    assert!(
        has_raw_fields,
        "Each message struct impl should expose a `raw_fields` method"
    );
}
/// The generated `with` method must copy field bytes through `write_raw`.
#[test]
fn test_with_method_uses_write_raw() {
    let generated = load_generated_code();
    let calls_write_raw = generated.contains("write_raw(raw_bytes)");
    assert!(
        calls_write_raw,
        "The `with` method should call `write_raw` to copy field bytes"
    );
}