diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..30421de --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "roto" +version = "0.1.0" diff --git a/README.md b/README.md index a08dc6e..55da5ba 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # roto -Rust protos without the P. +Rust protos without the pointers. The codegen is different; we don't create data structures. We mark what where each field is, and only read it when asked. @@ -28,7 +28,7 @@ message Hello { fn parse_proto(data: &[u8]) -> Result { // Scans the data, marks where each flag is as an offset // into the proto. - let accessor = HelloProto::new(accessor)?; + let accessor = HelloProto::new(data)?; // Load the hello world string; returns bytes, not // a Rust string. let hello_world = accessor.hello_world()?; @@ -40,6 +40,16 @@ fn parse_proto(data: &[u8]) -> Result { } ``` +### Sample builder usage + +```rust +let mut buf = [0u8; 1024]; +let mut builder = ProtoBuilder::new(&mut buf); +builder.write_string(1, "hello world")?; +builder.write_int32(2, 42)?; +let data = builder.finish()?; // returns the used slice of the buffer +``` + ### High level design The runtime library offers an iterator over the fields in a message, using the protobuf wire format provide diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..d7333e1 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,465 @@ +use std::fmt; + +#[derive(Debug, PartialEq, Eq)] +pub enum RotoError { + UnexpectedEndOfBuffer, + InvalidVarint, + InvalidWireType(u8), + BufferOverflow, + FieldNotFound, + WireFormatViolation, +} + +impl fmt::Display for RotoError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RotoError::UnexpectedEndOfBuffer => write!(f, "Unexpected end of buffer"), + RotoError::InvalidVarint => write!(f, "Invalid varint encoding"), + RotoError::InvalidWireType(t) => write!(f, "Invalid wire type: {t}"), + RotoError::BufferOverflow => write!(f, "Buffer overflow during write"), + RotoError::FieldNotFound => write!(f, "Requested field not found in message"), + RotoError::WireFormatViolation => write!(f, "Wire format violation"), + } + } +} + +impl std::error::Error for RotoError {} + +pub type Result = std::result::Result; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum WireType { + Varint = 0, + Fixed64 = 1, + LengthDelimited = 2, + StartGroup = 3, // Deprecated + EndGroup = 4, // Deprecated + Fixed32 = 5, +} + +impl WireType { + pub fn from_u8(value: u8) -> Result { + match value { + 0 => Ok(WireType::Varint), + 1 => Ok(WireType::Fixed64), + 2 => Ok(WireType::LengthDelimited), + 3 => Ok(WireType::StartGroup), + 4 => Ok(WireType::EndGroup), + 5 => Ok(WireType::Fixed32), + _ => Err(RotoError::InvalidWireType(value)), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Tag { + pub field_number: u32, + pub wire_type: WireType, +} + +impl Tag { + /// Decodes a tag from the buffer, returning the tag and the number of bytes read. + pub fn decode(data: &[u8]) -> Result<(Self, usize)> { + let (val, len) = read_varint(data)?; + let wire_type_raw = (val & 0x7) as u8; + let field_number = (val >> 3) as u32; + + Ok(( + Tag { + field_number, + wire_type: WireType::from_u8(wire_type_raw)?, + }, + len, + )) + } + + /// Encodes a tag into the provided buffer. + pub fn encode(field_number: u32, wire_type: WireType, buf: &mut [u8]) -> Result { + let val = ((field_number as u64) << 3) | (wire_type as u64); + write_varint(val, buf) + } +} + +/// Reads a varint from the start of the buffer. +pub fn read_varint(data: &[u8]) -> Result<(u64, usize)> { + let mut result = 0u64; + let mut shift = 0; + let mut bytes_read = 0; + + for &byte in data { + bytes_read += 1; + if bytes_read > 10 { + return Err(RotoError::InvalidVarint); + } + + let value = (byte & 0x7F) as u64; + if shift >= 64 { + return Err(RotoError::InvalidVarint); + } + result |= value << shift; + shift += 7; + + if (byte & 0x80) == 0 { + return Ok((result, bytes_read)); + } + } + + Err(RotoError::UnexpectedEndOfBuffer) +} + +/// Writes a varint into the buffer. +pub fn write_varint(mut value: u64, buf: &mut [u8]) -> Result { + let mut bytes_written = 0; + while value >= 0x80 { + if bytes_written >= buf.len() { + return Err(RotoError::BufferOverflow); + } + buf[bytes_written] = (value as u8 & 0x7F) | 0x80; + value >>= 7; + bytes_written += 1; + } + + if bytes_written >= buf.len() { + return Err(RotoError::BufferOverflow); + } + buf[bytes_written] = value as u8; + bytes_written += 1; + Ok(bytes_written) +} + +/// Returns the number of bytes that should be skipped for a given wire type and the current data slice. +pub fn skip_value(wire_type: WireType, data: &[u8]) -> Result { + match wire_type { + WireType::Varint => { + let (_, len) = read_varint(data)?; + Ok(len) + } + WireType::Fixed64 => { + if data.len() < 8 { + return Err(RotoError::UnexpectedEndOfBuffer); + } + Ok(8) + } + WireType::LengthDelimited => { + let (len, varint_len) = read_varint(data)?; + let total_len = varint_len + len as usize; + if data.len() < total_len { + return Err(RotoError::UnexpectedEndOfBuffer); + } + Ok(total_len) + } + WireType::Fixed32 => { + if data.len() < 4 { + return Err(RotoError::UnexpectedEndOfBuffer); + } + Ok(4) + } + WireType::StartGroup | WireType::EndGroup => { + // These are deprecated and not fully supported in this runtime. + Err(RotoError::WireFormatViolation) + } + } +} + +pub struct ProtoAccessor<'a> { + data: &'a [u8], +} + +impl<'a> ProtoAccessor<'a> { + pub fn new(data: &'a [u8]) -> Result { + Ok(Self { data }) + } + + /// Returns an iterator over all fields in the message. + pub fn fields(&self) -> FieldIterator<'a> { + FieldIterator { + data: self.data, + cursor: 0, + } + } + + /// Returns the value and wire type of the last occurrence of the specified field. + pub fn get_value(&self, field_number: u32) -> Result<(&[u8], WireType)> { + let mut last_value = None; + for item in self.fields() { + let (tag, value) = item?; + if tag.field_number == field_number { + last_value = Some((value, tag.wire_type)); + } + } + last_value.ok_or(RotoError::FieldNotFound) + } + + /// Returns an iterator that scans the entire buffer for all occurrences of the specified field. + pub fn iter_repeated(&self, field_number: u32) -> RepeatedFieldIterator<'a> { + RepeatedFieldIterator::new(self.data, field_number) + } +} + +pub struct FieldIterator<'a> { + data: &'a [u8], + cursor: usize, +} + +impl<'a> Iterator for FieldIterator<'a> { + type Item = Result<(Tag, &'a [u8])>; + + fn next(&mut self) -> Option { + if self.cursor >= self.data.len() { + return None; + } + + let (tag, tag_len) = match Tag::decode(&self.data[self.cursor..]) { + Ok(t) => t, + Err(e) => { + self.cursor = self.data.len(); + return Some(Err(e)); + } + }; + + let cursor_after_tag = self.cursor + tag_len; + if cursor_after_tag > self.data.len() { + self.cursor = self.data.len(); + return Some(Err(RotoError::UnexpectedEndOfBuffer)); + } + + let value_len = match skip_value(tag.wire_type, &self.data[cursor_after_tag..]) { + Ok(l) => l, + Err(e) => { + self.cursor = self.data.len(); + return Some(Err(e)); + } + }; + + let (value_offset, actual_value_len) = match tag.wire_type { + WireType::LengthDelimited => { + let (_, varint_len) = match read_varint(&self.data[cursor_after_tag..]) { + Ok(v) => v, + Err(e) => { + self.cursor = self.data.len(); + return Some(Err(e)); + } + }; + (cursor_after_tag + varint_len, value_len - varint_len) + } + _ => (cursor_after_tag, value_len), + }; + + self.cursor = cursor_after_tag + value_len; + + Some(Ok((tag, &self.data[value_offset..value_offset + actual_value_len]))) + } +} + +pub struct RepeatedFieldIterator<'a> { + iterator: FieldIterator<'a>, + field_number: u32, +} + +impl<'a> RepeatedFieldIterator<'a> { + fn new(data: &'a [u8], field_number: u32) -> Self { + Self { + iterator: FieldIterator { + data, + cursor: 0, + }, + field_number, + } + } +} + +impl<'a> Iterator for RepeatedFieldIterator<'a> { + type Item = Result<(&'a [u8], WireType)>; + + fn next(&mut self) -> Option { + while let Some(item) = self.iterator.next() { + match item { + Ok((tag, value)) if tag.field_number == self.field_number => { + return Some(Ok((value, tag.wire_type))); + } + Ok(_) => continue, + Err(e) => return Some(Err(e)), + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_varint_read_write() { + let mut buf = [0u8; 10]; + let val = 300u64; + let len = write_varint(val, &mut buf).unwrap(); + assert_eq!(len, 2); + assert_eq!(&buf[..2], &[0xAC, 0x02]); + + let (read_val, read_len) = read_varint(&buf[..2]).unwrap(); + assert_eq!(read_val, val); + assert_eq!(read_len, 2); + } + + #[test] + fn test_tag_decode() { + // Field 1, WireType Varint: (1 << 3) | 0 = 8 + let data = [8u8]; + let (tag, len) = Tag::decode(&data).unwrap(); + assert_eq!(tag.field_number, 1); + assert_eq!(tag.wire_type, WireType::Varint); + assert_eq!(len, 1); + + // Field 15, WireType LengthDelimited: (15 << 3) | 2 = 120 | 2 = 122 + let data2 = [122u8]; + let (tag2, len2) = Tag::decode(&data2).unwrap(); + assert_eq!(tag2.field_number, 15); + assert_eq!(tag2.wire_type, WireType::LengthDelimited); + assert_eq!(len2, 1); + } + + #[test] + fn test_skip_value() { + // Varint: 300 (2 bytes) + let data_varint = [0xAC, 0x02]; + assert_eq!(skip_value(WireType::Varint, &data_varint).unwrap(), 2); + + // Fixed32: 4 bytes + let data_fixed32 = [0u8; 4]; + assert_eq!(skip_value(WireType::Fixed32, &data_fixed32).unwrap(), 4); + + // Length delimited: len=3, data=[1,2,3] (1 byte varint for length + 3 bytes) + let data_len = [3, 1, 2, 3]; + assert_eq!(skip_value(WireType::LengthDelimited, &data_len).unwrap(), 4); + } + + #[test] + fn test_accessor_basic() { + // Field 1 (Varint): 150 + // Tag: (1 << 3) | 0 = 8. Value: 150 = [150, 1] + // Field 2 (LengthDelimited): "hi" + // Tag: (2 << 3) | 2 = 18. Length: 2. Value: [104, 105] + let data = [8, 150, 1, 18, 2, 104, 105]; + let acc = ProtoAccessor::new(&data).unwrap(); + + let (val1, type1) = acc.get_value(1).unwrap(); + assert_eq!(type1, WireType::Varint); + assert_eq!(val1, &[150, 1]); + + let (val2, type2) = acc.get_value(2).unwrap(); + assert_eq!(type2, WireType::LengthDelimited); + assert_eq!(val2, &[104, 105]); + } + + #[test] + fn test_accessor_repeated() { + // Field 1: 10, Field 1: 20, Field 1: 30 + // Tags: 8, 8, 8. Values: 10, 20, 30 + let data = [8, 10, 8, 20, 8, 30]; + let acc = ProtoAccessor::new(&data).unwrap(); + + // Last value should be 30 + let (val, _) = acc.get_value(1).unwrap(); + assert_eq!(val, &[30]); + + // Iteration should find all three + let results: Vec<_> = acc.iter_repeated(1).collect(); + assert_eq!(results.len(), 3); + assert_eq!(results[0].as_ref().unwrap().0, &[10]); + assert_eq!(results[1].as_ref().unwrap().0, &[20]); + assert_eq!(results[2].as_ref().unwrap().0, &[30]); + } + + #[test] + fn test_builder_basic() { + let mut buf = [0u8; 1024]; + let mut builder = ProtoBuilder::new(&mut buf); + builder.write_string(1, "hello").unwrap(); + builder.write_int32(2, 42).unwrap(); + let data = builder.finish().unwrap(); + + let acc = ProtoAccessor::new(data).unwrap(); + let (val1, _) = acc.get_value(1).unwrap(); + assert_eq!(val1, "hello".as_bytes()); + let (val2, _) = acc.get_value(2).unwrap(); + assert_eq!(val2, &[42]); + } + + #[test] + fn test_builder_overflow() { + let mut buf = [0u8; 2]; + let mut builder = ProtoBuilder::new(&mut buf); + let result = builder.write_string(1, "too long"); + assert_eq!(result, Err(RotoError::BufferOverflow)); + } +} + +pub struct ProtoBuilder<'a> { + buf: &'a mut [u8], + pos: usize, +} + +impl<'a> ProtoBuilder<'a> { + pub fn new(buf: &'a mut [u8]) -> Self { + Self { buf, pos: 0 } + } + + fn write_tag(&mut self, field_number: u32, wire_type: WireType) -> Result<()> { + let mut temp = [0u8; 10]; + let len = Tag::encode(field_number, wire_type, &mut temp)?; + self.append_bytes(&temp[..len]) + } + + fn append_bytes(&mut self, bytes: &[u8]) -> Result<()> { + if self.pos + bytes.len() > self.buf.len() { + return Err(RotoError::BufferOverflow); + } + self.buf[self.pos..self.pos + bytes.len()].copy_from_slice(bytes); + self.pos += bytes.len(); + Ok(()) + } + + pub fn write_varint(&mut self, field_number: u32, value: u64) -> Result<()> { + self.write_tag(field_number, WireType::Varint)?; + let mut temp = [0u8; 10]; + let len = write_varint(value, &mut temp)?; + self.append_bytes(&temp[..len]) + } + + pub fn write_int32(&mut self, field_number: u32, value: i32) -> Result<()> { + self.write_varint(field_number, value as u64) + } + + pub fn write_string(&mut self, field_number: u32, value: &str) -> Result<()> { + self.write_tag(field_number, WireType::LengthDelimited)?; + let bytes = value.as_bytes(); + let mut len_buf = [0u8; 10]; + let len_len = write_varint(bytes.len() as u64, &mut len_buf)?; + self.append_bytes(&len_buf[..len_len])?; + self.append_bytes(bytes) + } + + pub fn write_fixed32(&mut self, field_number: u32, value: u32) -> Result<()> { + self.write_tag(field_number, WireType::Fixed32)?; + self.append_bytes(&value.to_le_bytes()) + } + + pub fn write_fixed64(&mut self, field_number: u32, value: u64) -> Result<()> { + self.write_tag(field_number, WireType::Fixed64)?; + self.append_bytes(&value.to_le_bytes()) + } + + pub fn write_bytes(&mut self, field_number: u32, value: &[u8]) -> Result<()> { + self.write_tag(field_number, WireType::LengthDelimited)?; + let mut len_buf = [0u8; 10]; + let len_len = write_varint(value.len() as u64, &mut len_buf)?; + self.append_bytes(&len_buf[..len_len])?; + self.append_bytes(value) + } + + pub fn finish(self) -> Result<&'a mut [u8]> { + Ok(&mut self.buf[..self.pos]) + } +}