pub mod proto_gen; pub mod generator; use std::fmt; #[derive(Debug, PartialEq, Eq)] pub enum RotoError { UnexpectedEndOfBuffer, InvalidVarint, InvalidWireType(u8), BufferOverflow, FieldNotFound, WireFormatViolation, } impl fmt::Display for RotoError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { RotoError::UnexpectedEndOfBuffer => write!(f, "Unexpected end of buffer"), RotoError::InvalidVarint => write!(f, "Invalid varint encoding"), RotoError::InvalidWireType(t) => write!(f, "Invalid wire type: {t}"), RotoError::BufferOverflow => write!(f, "Buffer overflow during write"), RotoError::FieldNotFound => write!(f, "Requested field not found in message"), RotoError::WireFormatViolation => write!(f, "Wire format violation"), } } } impl std::error::Error for RotoError {} pub type Result = std::result::Result; #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum WireType { Varint = 0, Fixed64 = 1, LengthDelimited = 2, StartGroup = 3, // Deprecated EndGroup = 4, // Deprecated Fixed32 = 5, } impl WireType { pub fn from_u8(value: u8) -> Result { match value { 0 => Ok(WireType::Varint), 1 => Ok(WireType::Fixed64), 2 => Ok(WireType::LengthDelimited), 3 => Ok(WireType::StartGroup), 4 => Ok(WireType::EndGroup), 5 => Ok(WireType::Fixed32), _ => Err(RotoError::InvalidWireType(value)), } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Tag { pub field_number: u32, pub wire_type: WireType, } impl Tag { /// Decodes a tag from the buffer, returning the tag and the number of bytes read. pub fn decode(data: &[u8]) -> Result<(Self, usize)> { let (val, len) = read_varint(data)?; let wire_type_raw = (val & 0x7) as u8; let field_number = (val >> 3) as u32; Ok(( Tag { field_number, wire_type: WireType::from_u8(wire_type_raw)?, }, len, )) } /// Encodes a tag into the provided buffer. pub fn encode(field_number: u32, wire_type: WireType, buf: &mut [u8]) -> Result { let val = ((field_number as u64) << 3) | (wire_type as u64); write_varint(val, buf) } } /// Reads a varint from the start of the buffer. pub fn read_varint(data: &[u8]) -> Result<(u64, usize)> { let mut result = 0u64; let mut shift = 0; let mut bytes_read = 0; for &byte in data { bytes_read += 1; if bytes_read > 10 { return Err(RotoError::InvalidVarint); } let value = (byte & 0x7F) as u64; if shift >= 64 { return Err(RotoError::InvalidVarint); } result |= value << shift; shift += 7; if (byte & 0x80) == 0 { return Ok((result, bytes_read)); } } Err(RotoError::UnexpectedEndOfBuffer) } /// Writes a varint into the buffer. pub fn write_varint(mut value: u64, buf: &mut [u8]) -> Result { let mut bytes_written = 0; while value >= 0x80 { if bytes_written >= buf.len() { return Err(RotoError::BufferOverflow); } buf[bytes_written] = (value as u8 & 0x7F) | 0x80; value >>= 7; bytes_written += 1; } if bytes_written >= buf.len() { return Err(RotoError::BufferOverflow); } buf[bytes_written] = value as u8; bytes_written += 1; Ok(bytes_written) } /// Returns the number of bytes that should be skipped for a given wire type and the current data slice. pub fn skip_value(wire_type: WireType, data: &[u8]) -> Result { match wire_type { WireType::Varint => { let (_, len) = read_varint(data)?; Ok(len) } WireType::Fixed64 => { if data.len() < 8 { return Err(RotoError::UnexpectedEndOfBuffer); } Ok(8) } WireType::LengthDelimited => { let (len, varint_len) = read_varint(data)?; let total_len = varint_len + len as usize; if data.len() < total_len { return Err(RotoError::UnexpectedEndOfBuffer); } Ok(total_len) } WireType::Fixed32 => { if data.len() < 4 { return Err(RotoError::UnexpectedEndOfBuffer); } Ok(4) } WireType::StartGroup | WireType::EndGroup => { // These are deprecated and not fully supported in this runtime. Err(RotoError::WireFormatViolation) } } } pub struct ProtoAccessor<'a> { data: &'a [u8], } impl<'a> ProtoAccessor<'a> { pub fn new(data: &'a [u8]) -> Result { Ok(Self { data }) } /// Returns an iterator over all fields in the message. pub fn fields(&self) -> FieldIterator<'a> { FieldIterator { data: self.data, cursor: 0, } } /// Returns the value and wire type of the last occurrence of the specified field. pub fn get_value(&self, field_number: u32) -> Result<(&'a [u8], WireType)> { let mut last_value = None; for item in self.fields() { let (tag, value) = item?; if tag.field_number == field_number { last_value = Some((value, tag.wire_type)); } } last_value.ok_or(RotoError::FieldNotFound) } /// Returns an iterator that scans the entire buffer for all occurrences of the specified field. pub fn iter_repeated(&self, field_number: u32) -> RepeatedFieldIterator<'a> { RepeatedFieldIterator::new(self.data, field_number) } } pub struct FieldIterator<'a> { data: &'a [u8], cursor: usize, } impl<'a> Iterator for FieldIterator<'a> { type Item = Result<(Tag, &'a [u8])>; fn next(&mut self) -> Option { if self.cursor >= self.data.len() { return None; } let (tag, tag_len) = match Tag::decode(&self.data[self.cursor..]) { Ok(t) => t, Err(e) => { self.cursor = self.data.len(); return Some(Err(e)); } }; let cursor_after_tag = self.cursor + tag_len; if cursor_after_tag > self.data.len() { self.cursor = self.data.len(); return Some(Err(RotoError::UnexpectedEndOfBuffer)); } let value_len = match skip_value(tag.wire_type, &self.data[cursor_after_tag..]) { Ok(l) => l, Err(e) => { self.cursor = self.data.len(); return Some(Err(e)); } }; let (value_offset, actual_value_len) = match tag.wire_type { WireType::LengthDelimited => { let (_, varint_len) = match read_varint(&self.data[cursor_after_tag..]) { Ok(v) => v, Err(e) => { self.cursor = self.data.len(); return Some(Err(e)); } }; (cursor_after_tag + varint_len, value_len - varint_len) } _ => (cursor_after_tag, value_len), }; self.cursor = cursor_after_tag + value_len; Some(Ok((tag, &self.data[value_offset..value_offset + actual_value_len]))) } } pub struct RepeatedFieldIterator<'a> { iterator: FieldIterator<'a>, field_number: u32, } impl<'a> RepeatedFieldIterator<'a> { fn new(data: &'a [u8], field_number: u32) -> Self { Self { iterator: FieldIterator { data, cursor: 0, }, field_number, } } } impl<'a> Iterator for RepeatedFieldIterator<'a> { type Item = Result<(&'a [u8], WireType)>; fn next(&mut self) -> Option { while let Some(item) = self.iterator.next() { match item { Ok((tag, value)) if tag.field_number == self.field_number => { return Some(Ok((value, tag.wire_type))); } Ok(_) => continue, Err(e) => return Some(Err(e)), } } None } } #[cfg(test)] mod tests { use super::*; #[test] fn test_varint_read_write() { let mut buf = [0u8; 10]; let val = 300u64; let len = write_varint(val, &mut buf).unwrap(); assert_eq!(len, 2); assert_eq!(&buf[..2], &[0xAC, 0x02]); let (read_val, read_len) = read_varint(&buf[..2]).unwrap(); assert_eq!(read_val, val); assert_eq!(read_len, 2); } #[test] fn test_tag_decode() { // Field 1, WireType Varint: (1 << 3) | 0 = 8 let data = [8u8]; let (tag, len) = Tag::decode(&data).unwrap(); assert_eq!(tag.field_number, 1); assert_eq!(tag.wire_type, WireType::Varint); assert_eq!(len, 1); // Field 15, WireType LengthDelimited: (15 << 3) | 2 = 120 | 2 = 122 let data2 = [122u8]; let (tag2, len2) = Tag::decode(&data2).unwrap(); assert_eq!(tag2.field_number, 15); assert_eq!(tag2.wire_type, WireType::LengthDelimited); assert_eq!(len2, 1); } #[test] fn test_skip_value() { // Varint: 300 (2 bytes) let data_varint = [0xAC, 0x02]; assert_eq!(skip_value(WireType::Varint, &data_varint).unwrap(), 2); // Fixed32: 4 bytes let data_fixed32 = [0u8; 4]; assert_eq!(skip_value(WireType::Fixed32, &data_fixed32).unwrap(), 4); // Length delimited: len=3, data=[1,2,3] (1 byte varint for length + 3 bytes) let data_len = [3, 1, 2, 3]; assert_eq!(skip_value(WireType::LengthDelimited, &data_len).unwrap(), 4); } #[test] fn test_accessor_basic() { // Field 1 (Varint): 150 // Tag: (1 << 3) | 0 = 8. Value: 150 = [150, 1] // Field 2 (LengthDelimited): "hi" // Tag: (2 << 3) | 2 = 18. Length: 2. Value: [104, 105] let data = [8, 150, 1, 18, 2, 104, 105]; let acc = ProtoAccessor::new(&data).unwrap(); let (val1, type1) = acc.get_value(1).unwrap(); assert_eq!(type1, WireType::Varint); assert_eq!(val1, &[150, 1]); let (val2, type2) = acc.get_value(2).unwrap(); assert_eq!(type2, WireType::LengthDelimited); assert_eq!(val2, &[104, 105]); } #[test] fn test_accessor_repeated() { // Field 1: 10, Field 1: 20, Field 1: 30 // Tags: 8, 8, 8. Values: 10, 20, 30 let data = [8, 10, 8, 20, 8, 30]; let acc = ProtoAccessor::new(&data).unwrap(); // Last value should be 30 let (val, _) = acc.get_value(1).unwrap(); assert_eq!(val, &[30]); // Iteration should find all three let results: Vec<_> = acc.iter_repeated(1).collect(); assert_eq!(results.len(), 3); assert_eq!(results[0].as_ref().unwrap().0, &[10]); assert_eq!(results[1].as_ref().unwrap().0, &[20]); assert_eq!(results[2].as_ref().unwrap().0, &[30]); } #[test] fn test_builder_basic() { let mut buf = [0u8; 1024]; let mut builder = ProtoBuilder::new(&mut buf); builder.write_string(1, "hello").unwrap(); builder.write_int32(2, 42).unwrap(); let data = builder.finish().unwrap(); let acc = ProtoAccessor::new(data).unwrap(); let (val1, _) = acc.get_value(1).unwrap(); assert_eq!(val1, "hello".as_bytes()); let (val2, _) = acc.get_value(2).unwrap(); assert_eq!(val2, &[42]); } #[test] fn test_builder_overflow() { let mut buf = [0u8; 2]; let mut builder = ProtoBuilder::new(&mut buf); let result = builder.write_string(1, "too long"); assert_eq!(result, Err(RotoError::BufferOverflow)); } #[test] fn test_protoc_binary_compatibility() { let data = include_bytes!("../data/test_data.pb"); let acc = ProtoAccessor::new(data).unwrap(); // 1. Varints (Integers, Booleans, Enums) let (val_i32, type_i32) = acc.get_value(3).expect("i32_val not found"); assert_eq!(type_i32, WireType::Varint); let (v, _) = read_varint(val_i32).unwrap(); assert_eq!(v, 42); let (val_b, type_b) = acc.get_value(13).expect("b_val not found"); assert_eq!(type_b, WireType::Varint); let (v_b, _) = read_varint(val_b).unwrap(); assert_eq!(v_b, 1); // true let (val_status, type_status) = acc.get_value(16).expect("status not found"); assert_eq!(type_status, WireType::Varint); let (v_s, _) = read_varint(val_status).unwrap(); assert_eq!(v_s, 1); // ACTIVE // 2. Length Delimited (Strings, Bytes) let (val_s, type_s) = acc.get_value(14).expect("s_val not found"); assert_eq!(type_s, WireType::LengthDelimited); assert_eq!(val_s, "Hello Roto!".as_bytes()); // 3. Fixed Width (Floats) let (val_f, type_f) = acc.get_value(2).expect("f_val not found"); assert_eq!(type_f, WireType::Fixed32); let f_val = f32::from_le_bytes(val_f.try_into().expect("Expected 4 bytes for f32")); assert!((f_val - 2.71828).abs() < 1e-5); // 4. Repeated Fields // Note: primitive repeated fields are packed in proto3, so we iterate over the blob let mut i32_vals = Vec::new(); for item in acc.iter_repeated(17) { let (blob, _) = item.expect("Failed to decode repeated i32"); let mut cursor = 0; while cursor < blob.len() { let (v, len) = read_varint(&blob[cursor..]).unwrap(); i32_vals.push(v); cursor += len; } } assert_eq!(i32_vals, vec![1, 2, 3, 4, 5]); let repeated_strings: Vec<_> = acc.iter_repeated(18) .map(|r| { let (val, _) = r.expect("Failed to decode repeated string"); std::str::from_utf8(val).expect("Invalid utf8") }) .collect(); assert_eq!(repeated_strings, vec!["one", "two", "three"]); let repeated_nested: Vec<_> = acc.iter_repeated(19) .map(|r| { let (val, _) = r.expect("Failed to decode repeated nested"); let nested_acc = ProtoAccessor::new(val).unwrap(); let (id_val, _) = nested_acc.get_value(1).expect("Nested id not found"); let (id, _) = read_varint(id_val).unwrap(); id }) .collect(); assert_eq!(repeated_nested, vec![101, 102]); // 5. Single Nested Message let (val_nested, type_nested) = acc.get_value(20).expect("single_nested not found"); assert_eq!(type_nested, WireType::LengthDelimited); let nested_acc = ProtoAccessor::new(val_nested).unwrap(); let (val_id, _) = nested_acc.get_value(1).expect("Nested id not found"); let (id, _) = read_varint(val_id).unwrap(); assert_eq!(id, 200); // Validate that fields appear in the expected relative order let field_numbers: Vec = acc.fields() .map(|r| r.expect("Failed to decode field").0.field_number) .collect(); let essential_fields = [1, 2, 3, 14, 16, 20]; let mut last_field = 0; let mut found_count = 0; for &f in &field_numbers { if essential_fields.contains(&f) { assert!(f >= last_field, "Fields appeared out of order: {} came after {}", f, last_field); last_field = f; found_count += 1; } } assert_eq!(found_count, essential_fields.len()); } } pub struct ProtoBuilder<'a> { buf: &'a mut [u8], pos: usize, } impl<'a> ProtoBuilder<'a> { pub fn new(buf: &'a mut [u8]) -> Self { Self { buf, pos: 0 } } fn write_tag(&mut self, field_number: u32, wire_type: WireType) -> Result<()> { let mut temp = [0u8; 10]; let len = Tag::encode(field_number, wire_type, &mut temp)?; self.append_bytes(&temp[..len]) } fn append_bytes(&mut self, bytes: &[u8]) -> Result<()> { if self.pos + bytes.len() > self.buf.len() { return Err(RotoError::BufferOverflow); } self.buf[self.pos..self.pos + bytes.len()].copy_from_slice(bytes); self.pos += bytes.len(); Ok(()) } pub fn write_varint(&mut self, field_number: u32, value: u64) -> Result<()> { self.write_tag(field_number, WireType::Varint)?; let mut temp = [0u8; 10]; let len = write_varint(value, &mut temp)?; self.append_bytes(&temp[..len]) } pub fn write_int32(&mut self, field_number: u32, value: i32) -> Result<()> { self.write_varint(field_number, value as u64) } pub fn write_string(&mut self, field_number: u32, value: &str) -> Result<()> { self.write_tag(field_number, WireType::LengthDelimited)?; let bytes = value.as_bytes(); let mut len_buf = [0u8; 10]; let len_len = write_varint(bytes.len() as u64, &mut len_buf)?; self.append_bytes(&len_buf[..len_len])?; self.append_bytes(bytes) } pub fn write_fixed32(&mut self, field_number: u32, value: u32) -> Result<()> { self.write_tag(field_number, WireType::Fixed32)?; self.append_bytes(&value.to_le_bytes()) } pub fn write_fixed64(&mut self, field_number: u32, value: u64) -> Result<()> { self.write_tag(field_number, WireType::Fixed64)?; self.append_bytes(&value.to_le_bytes()) } pub fn write_bytes(&mut self, field_number: u32, value: &[u8]) -> Result<()> { self.write_tag(field_number, WireType::LengthDelimited)?; let mut len_buf = [0u8; 10]; let len_len = write_varint(value.len() as u64, &mut len_buf)?; self.append_bytes(&len_buf[..len_len])?; self.append_bytes(value) } pub fn finish(self) -> Result<&'a mut [u8]> { Ok(&mut self.buf[..self.pos]) } }