Cache field offsets in generated accessors

Update the generator to store field offsets and ranges within the
generated accessor structs, avoiding repeated buffer scans.
This commit is contained in:
2026-05-03 13:32:39 -07:00
parent 642a78e4e7
commit 38367227ed
2 changed files with 139 additions and 16 deletions
+90 -11
View File
@@ -21,7 +21,7 @@ fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) {
// LABEL_REPEATED // LABEL_REPEATED
return ( return (
"crate::RepeatedFieldIterator<'a>".to_string(), "crate::RepeatedFieldIterator<'a>".to_string(),
"self.0.iter_repeated(%d)".to_string(), "".to_string(), // Not used for repeated fields in the same way
); );
} }
@@ -142,15 +142,13 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String {
let msg_proto = DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"); let msg_proto = DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto");
let msg_name = to_pascal_case(msg_proto.name().unwrap()); let msg_name = to_pascal_case(msg_proto.name().unwrap());
// Accessor // Accessor Struct Definition
output.push_str(&format!( output.push_str(&format!(
"pub struct {}<'a>(ProtoAccessor<'a>);\n\nimpl<'a> {}<'a> {{\n", "pub struct {}<'a> {{\n accessor: ProtoAccessor<'a>,\n",
msg_name, msg_name msg_name
));
output.push_str(&format!(
" pub fn new(data: &'a [u8]) -> Result<Self> {{\n Ok(Self(ProtoAccessor::new(data)?))\n }}\n\n"
)); ));
let mut fields_info = Vec::new();
for field_res in msg_proto.field() { for field_res in msg_proto.field() {
let (field_data, _) = field_res.expect("Failed to iterate field"); let (field_data, _) = field_res.expect("Failed to iterate field");
let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto");
@@ -159,18 +157,99 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String {
let f_type = field_proto.field_type().unwrap() as i32; let f_type = field_proto.field_type().unwrap() as i32;
let f_label = field_proto.label().unwrap() as i32; let f_label = field_proto.label().unwrap() as i32;
fields_info.push((field_name.to_string(), tag, f_type, f_label));
if f_label == 3 {
output.push_str(&format!(" {}_start: Option<usize>,\n", field_name));
output.push_str(&format!(" {}_end: Option<usize>,\n", field_name));
} else {
output.push_str(&format!(" {}_offset: Option<usize>,\n", field_name));
}
}
output.push_str("}\n\n");
// Accessor Implementation
output.push_str(&format!("impl<'a> {}<'a> {{\n", msg_name));
// new() method
output.push_str(" pub fn new(data: &'a [u8]) -> Result<Self> {\n");
output.push_str(" let accessor = ProtoAccessor::new(data)?;\n");
for (name, _, _, label) in &fields_info {
if *label == 3 {
output.push_str(&format!(" let mut {}_start = None;\n", name));
output.push_str(&format!(" let mut {}_end = None;\n", name));
} else {
output.push_str(&format!(" let mut {}_offset = None;\n", name));
}
}
output.push_str(" for item in accessor.fields() {\n");
output.push_str(" let (offset, tag, _) = item?;\n");
for (name, tag, _, label) in &fields_info {
if *label == 3 {
output.push_str(&format!(
" if tag.field_number == {} {{\n", tag
));
output.push_str(&format!(" if {}_start.is_none() {{ {}_start = Some(offset); }}\n", name, name));
output.push_str(&format!(" {}_end = Some(offset);\n", name));
output.push_str(" }\n");
} else {
output.push_str(&format!(
" if tag.field_number == {} {{ {}_offset = Some(offset); }}\n", tag, name
));
}
}
output.push_str(" }\n\n");
output.push_str(" Ok(Self {\n");
output.push_str(" accessor,\n");
for (name, _, _, label) in &fields_info {
if *label == 3 {
output.push_str(&format!("{}_start, {}_end,\n", name, name));
} else {
output.push_str(&format!("{}_offset,\n", name));
}
}
output.push_str(" })\n }\n\n");
// Field Accessors
for (field_name, tag, f_type, f_label) in fields_info {
let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label); let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label);
if f_label == 3 { if f_label == 3 {
output.push_str(&format!( output.push_str(&format!(
" pub fn {}(&self) -> {} {{\n {}\n }}\n\n", " pub fn {}(&self) -> {} {{\n",
field_name, rust_type, logic.replace("%d", &tag.to_string()) field_name, rust_type
)); ));
output.push_str(&format!(
" match (self.{}_start, self.{}_end) {{\n",
field_name, field_name
));
output.push_str(&format!(
" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n",
tag
));
output.push_str(&format!(
" _ => self.accessor.iter_repeated({}),\n",
tag
));
output.push_str(" }\n }\n\n");
} else { } else {
output.push_str(&format!( output.push_str(&format!(
" pub fn {}(&self) -> Result<{}> {{\n let (bytes, _) = self.0.get_value({})?;\n {}\n }}\n\n", " pub fn {}(&self) -> Result<{}> {{\n",
field_name, rust_type, tag, logic field_name, rust_type
)); ));
output.push_str(&format!(
" let offset = self.{}_offset.ok_or(RotoError::FieldNotFound)?;\n",
field_name
));
output.push_str(&format!(
" let (bytes, _) = self.accessor.get_value_at(offset)?;\n",
));
output.push_str(&format!(" {}\n", logic));
output.push_str(" }\n\n");
} }
} }
output.push_str("}\n\n"); output.push_str("}\n\n");
+49 -5
View File
@@ -185,7 +185,7 @@ impl<'a> ProtoAccessor<'a> {
pub fn get_value(&self, field_number: u32) -> Result<(&'a [u8], WireType)> { pub fn get_value(&self, field_number: u32) -> Result<(&'a [u8], WireType)> {
let mut last_value = None; let mut last_value = None;
for item in self.fields() { for item in self.fields() {
let (tag, value) = item?; let (_offset, tag, value) = item?;
if tag.field_number == field_number { if tag.field_number == field_number {
last_value = Some((value, tag.wire_type)); last_value = Some((value, tag.wire_type));
} }
@@ -197,6 +197,32 @@ impl<'a> ProtoAccessor<'a> {
pub fn iter_repeated(&self, field_number: u32) -> RepeatedFieldIterator<'a> { pub fn iter_repeated(&self, field_number: u32) -> RepeatedFieldIterator<'a> {
RepeatedFieldIterator::new(self.data, field_number) RepeatedFieldIterator::new(self.data, field_number)
} }
/// Decodes the single field whose tag begins at byte `offset` of the buffer.
///
/// Returns the field's payload bytes together with its wire type. For
/// length-delimited fields the varint length prefix is excluded from the
/// returned slice.
///
/// # Errors
///
/// Returns `RotoError::UnexpectedEndOfBuffer` when `offset` does not fall
/// inside the buffer, and propagates any error produced while decoding the
/// tag, skipping the value, or reading the length prefix.
pub fn get_value_at(&self, offset: usize) -> Result<(&'a [u8], WireType)> {
    let buf = self.data;
    if offset >= buf.len() {
        return Err(RotoError::UnexpectedEndOfBuffer);
    }

    let (tag, tag_len) = Tag::decode(&buf[offset..])?;
    let payload_start = offset + tag_len;
    if payload_start > buf.len() {
        return Err(RotoError::UnexpectedEndOfBuffer);
    }

    let rest = &buf[payload_start..];
    // Encoded size of the value, including the length prefix when there is one.
    let encoded_len = skip_value(tag.wire_type, rest)?;
    // Only length-delimited values carry a varint prefix ahead of the payload.
    let prefix_len = if matches!(tag.wire_type, WireType::LengthDelimited) {
        read_varint(rest)?.1
    } else {
        0
    };

    let begin = payload_start + prefix_len;
    let end = begin + (encoded_len - prefix_len);
    Ok((&buf[begin..end], tag.wire_type))
}
/// Returns an iterator over the occurrences of `field_number`, scanning the buffer from byte offset `start`.
///
/// Delegates to `RepeatedFieldIterator::new_range`, handing it this accessor's buffer together with
/// `field_number`, `start` and `end` unchanged.
///
/// NOTE(review): `end` appears to be an inclusive bound on the offset at which a matching field may
/// begin (i.e. the offset of the last occurrence, not one past it) — confirm against the iterator.
pub fn iter_repeated_range(&self, field_number: u32, start: usize, end: usize) -> RepeatedFieldIterator<'a> {
RepeatedFieldIterator::new_range(self.data, field_number, start, end)
}
} }
pub struct FieldIterator<'a> { pub struct FieldIterator<'a> {
@@ -205,7 +231,7 @@ pub struct FieldIterator<'a> {
} }
impl<'a> Iterator for FieldIterator<'a> { impl<'a> Iterator for FieldIterator<'a> {
type Item = Result<(Tag, &'a [u8])>; type Item = Result<(usize, Tag, &'a [u8])>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.cursor >= self.data.len() { if self.cursor >= self.data.len() {
@@ -250,23 +276,36 @@ impl<'a> Iterator for FieldIterator<'a> {
self.cursor = cursor_after_tag + value_len; self.cursor = cursor_after_tag + value_len;
Some(Ok((tag, &self.data[value_offset..value_offset + actual_value_len]))) Some(Ok((self.cursor - tag_len - value_len, tag, &self.data[value_offset..value_offset + actual_value_len])))
} }
} }
pub struct RepeatedFieldIterator<'a> { pub struct RepeatedFieldIterator<'a> {
iterator: FieldIterator<'a>, iterator: FieldIterator<'a>,
field_number: u32, field_number: u32,
end_offset: Option<usize>,
} }
impl<'a> RepeatedFieldIterator<'a> { impl<'a> RepeatedFieldIterator<'a> {
fn new(data: &'a [u8], field_number: u32) -> Self { pub fn new(data: &'a [u8], field_number: u32) -> Self {
Self { Self {
iterator: FieldIterator { iterator: FieldIterator {
data, data,
cursor: 0, cursor: 0,
}, },
field_number, field_number,
end_offset: None,
}
}
pub fn new_range(data: &'a [u8], field_number: u32, start: usize, end: usize) -> Self {
Self {
iterator: FieldIterator {
data,
cursor: start,
},
field_number,
end_offset: Some(end),
} }
} }
} }
@@ -277,7 +316,12 @@ impl<'a> Iterator for RepeatedFieldIterator<'a> {
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
while let Some(item) = self.iterator.next() { while let Some(item) = self.iterator.next() {
match item { match item {
Ok((tag, value)) if tag.field_number == self.field_number => { Ok((offset, tag, value)) if tag.field_number == self.field_number => {
if let Some(end) = self.end_offset {
if offset > end {
return None;
}
}
return Some(Ok((value, tag.wire_type))); return Some(Ok((value, tag.wire_type)));
} }
Ok(_) => continue, Ok(_) => continue,