Refactor crate into multiple subcrates
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
[package]
|
||||
name = "roto-runtime"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
@@ -0,0 +1 @@
|
||||
bench/
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,39 @@
|
||||
d_val: 3.1415926535
|
||||
f_val: 2.71828
|
||||
i32_val: 42
|
||||
i64_val: 123456789012345
|
||||
u32_val: 1000
|
||||
u64_val: 18446744073709551615
|
||||
si32_val: -42
|
||||
si64_val: -123456789012345
|
||||
fx32_val: 123456
|
||||
fx64_val: 1234567890123456789
|
||||
sfx32_val: -123456
|
||||
sfx64_val: -1234567890123456789
|
||||
b_val: true
|
||||
s_val: "Hello Roto!"
|
||||
bytes_val: "SGVsbG8gUm90byE="
|
||||
status: ACTIVE
|
||||
repeated_i32: 1
|
||||
repeated_i32: 2
|
||||
repeated_i32: 3
|
||||
repeated_i32: 4
|
||||
repeated_i32: 5
|
||||
repeated_string: "one"
|
||||
repeated_string: "two"
|
||||
repeated_string: "three"
|
||||
repeated_nested {
|
||||
id: 101
|
||||
name: "Nested 1"
|
||||
active: true
|
||||
}
|
||||
repeated_nested {
|
||||
id: 102
|
||||
name: "Nested 2"
|
||||
active: false
|
||||
}
|
||||
single_nested {
|
||||
id: 200
|
||||
name: "Single Nested"
|
||||
active: true
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,53 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package roto.test;
|
||||
|
||||
// A comprehensive message containing all primitive types and complex structures
|
||||
// to test the proto-to-rust codegen and runtime accessors.
|
||||
message ComplexMessage {
|
||||
// --- Floating Point ---
|
||||
double d_val = 1;
|
||||
float f_val = 2;
|
||||
|
||||
// --- Integers (Variable Length) ---
|
||||
int32 i32_val = 3;
|
||||
int64 i64_val = 4;
|
||||
uint32 u32_val = 5;
|
||||
uint64 u64_val = 6;
|
||||
sint32 si32_val = 7;
|
||||
sint64 si64_val = 8;
|
||||
|
||||
// --- Integers (Fixed Length) ---
|
||||
fixed32 fx32_val = 9;
|
||||
fixed64 fx64_val = 10;
|
||||
sfixed32 sfx32_val = 11;
|
||||
sfixed64 sfx64_val = 12;
|
||||
|
||||
// --- Other Primitives ---
|
||||
bool b_val = 13;
|
||||
string s_val = 14;
|
||||
bytes bytes_val = 15;
|
||||
|
||||
// --- Enumerations ---
|
||||
enum Status {
|
||||
UNKNOWN = 0;
|
||||
ACTIVE = 1;
|
||||
INACTIVE = 2;
|
||||
DELETED = 3;
|
||||
}
|
||||
Status status = 16;
|
||||
|
||||
// --- Repeated Fields ---
|
||||
// Testing packed primitives and non-packed types
|
||||
repeated int32 repeated_i32 = 17;
|
||||
repeated string repeated_string = 18;
|
||||
repeated NestedMessage repeated_nested = 19;
|
||||
|
||||
// --- Nested Messages ---
|
||||
message NestedMessage {
|
||||
int32 id = 1;
|
||||
string name = 2;
|
||||
bool active = 3;
|
||||
}
|
||||
NestedMessage single_nested = 20;
|
||||
}
|
||||
@@ -0,0 +1,775 @@
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum RotoError {
|
||||
UnexpectedEndOfBuffer,
|
||||
InvalidVarint,
|
||||
InvalidWireType(u8),
|
||||
BufferOverflow,
|
||||
FieldNotFound,
|
||||
WireFormatViolation,
|
||||
}
|
||||
|
||||
impl fmt::Display for RotoError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
RotoError::UnexpectedEndOfBuffer => write!(f, "Unexpected end of buffer"),
|
||||
RotoError::InvalidVarint => write!(f, "Invalid varint encoding"),
|
||||
RotoError::InvalidWireType(t) => write!(f, "Invalid wire type: {t}"),
|
||||
RotoError::BufferOverflow => write!(f, "Buffer overflow during write"),
|
||||
RotoError::FieldNotFound => write!(f, "Requested field not found in message"),
|
||||
RotoError::WireFormatViolation => write!(f, "Wire format violation"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for RotoError {}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RotoError>;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum WireType {
|
||||
Varint = 0,
|
||||
Fixed64 = 1,
|
||||
LengthDelimited = 2,
|
||||
StartGroup = 3, // Deprecated
|
||||
EndGroup = 4, // Deprecated
|
||||
Fixed32 = 5,
|
||||
}
|
||||
|
||||
impl WireType {
|
||||
pub fn from_u8(value: u8) -> Result<Self> {
|
||||
match value {
|
||||
0 => Ok(WireType::Varint),
|
||||
1 => Ok(WireType::Fixed64),
|
||||
2 => Ok(WireType::LengthDelimited),
|
||||
3 => Ok(WireType::StartGroup),
|
||||
4 => Ok(WireType::EndGroup),
|
||||
5 => Ok(WireType::Fixed32),
|
||||
_ => Err(RotoError::InvalidWireType(value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Tag {
|
||||
pub field_number: u32,
|
||||
pub wire_type: WireType,
|
||||
}
|
||||
|
||||
impl Tag {
|
||||
/// Decodes a tag from the buffer, returning the tag and the number of bytes read.
|
||||
pub fn decode(data: &[u8]) -> Result<(Self, usize)> {
|
||||
let (val, len) = read_varint(data)?;
|
||||
let wire_type_raw = (val & 0x7) as u8;
|
||||
let field_number = (val >> 3) as u32;
|
||||
|
||||
Ok((
|
||||
Tag {
|
||||
field_number,
|
||||
wire_type: WireType::from_u8(wire_type_raw)?,
|
||||
},
|
||||
len,
|
||||
))
|
||||
}
|
||||
|
||||
/// Encodes a tag into the provided buffer.
|
||||
pub fn encode(field_number: u32, wire_type: WireType, buf: &mut [u8]) -> Result<usize> {
|
||||
let val = ((field_number as u64) << 3) | (wire_type as u64);
|
||||
write_varint(val, buf)
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads a varint from the start of the buffer.
|
||||
pub fn read_varint(data: &[u8]) -> Result<(u64, usize)> {
|
||||
let mut result = 0u64;
|
||||
let mut shift = 0;
|
||||
let mut bytes_read = 0;
|
||||
|
||||
for &byte in data {
|
||||
bytes_read += 1;
|
||||
if bytes_read > 10 {
|
||||
return Err(RotoError::InvalidVarint);
|
||||
}
|
||||
|
||||
let value = (byte & 0x7F) as u64;
|
||||
if shift >= 64 {
|
||||
return Err(RotoError::InvalidVarint);
|
||||
}
|
||||
result |= value << shift;
|
||||
shift += 7;
|
||||
|
||||
if (byte & 0x80) == 0 {
|
||||
return Ok((result, bytes_read));
|
||||
}
|
||||
}
|
||||
|
||||
Err(RotoError::UnexpectedEndOfBuffer)
|
||||
}
|
||||
|
||||
/// Writes a varint into the buffer.
|
||||
pub fn write_varint(mut value: u64, buf: &mut [u8]) -> Result<usize> {
|
||||
let mut bytes_written = 0;
|
||||
while value >= 0x80 {
|
||||
if bytes_written >= buf.len() {
|
||||
return Err(RotoError::BufferOverflow);
|
||||
}
|
||||
buf[bytes_written] = (value as u8 & 0x7F) | 0x80;
|
||||
value >>= 7;
|
||||
bytes_written += 1;
|
||||
}
|
||||
|
||||
if bytes_written >= buf.len() {
|
||||
return Err(RotoError::BufferOverflow);
|
||||
}
|
||||
buf[bytes_written] = value as u8;
|
||||
bytes_written += 1;
|
||||
Ok(bytes_written)
|
||||
}
|
||||
|
||||
/// Returns the number of bytes that should be skipped for a given wire type and the current data slice.
|
||||
pub fn skip_value(wire_type: WireType, data: &[u8]) -> Result<usize> {
|
||||
match wire_type {
|
||||
WireType::Varint => {
|
||||
let (_, len) = read_varint(data)?;
|
||||
Ok(len)
|
||||
}
|
||||
WireType::Fixed64 => {
|
||||
if data.len() < 8 {
|
||||
return Err(RotoError::UnexpectedEndOfBuffer);
|
||||
}
|
||||
Ok(8)
|
||||
}
|
||||
WireType::LengthDelimited => {
|
||||
let (len, varint_len) = read_varint(data)?;
|
||||
let total_len = varint_len + len as usize;
|
||||
if data.len() < total_len {
|
||||
return Err(RotoError::UnexpectedEndOfBuffer);
|
||||
}
|
||||
Ok(total_len)
|
||||
}
|
||||
WireType::Fixed32 => {
|
||||
if data.len() < 4 {
|
||||
return Err(RotoError::UnexpectedEndOfBuffer);
|
||||
}
|
||||
Ok(4)
|
||||
}
|
||||
WireType::StartGroup | WireType::EndGroup => {
|
||||
// These are deprecated and not fully supported in this runtime.
|
||||
Err(RotoError::WireFormatViolation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProtoAccessor<'a> {
|
||||
data: &'a [u8],
|
||||
}
|
||||
|
||||
impl<'a> ProtoAccessor<'a> {
|
||||
pub fn new(data: &'a [u8]) -> Result<Self> {
|
||||
Ok(Self { data })
|
||||
}
|
||||
|
||||
/// Returns an iterator over all fields in the message.
|
||||
pub fn fields(&self) -> FieldIterator<'a> {
|
||||
FieldIterator {
|
||||
data: self.data,
|
||||
cursor: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the value and wire type of the last occurrence of the specified field.
|
||||
pub fn get_value(&self, field_number: u32) -> Result<(&'a [u8], WireType)> {
|
||||
let mut last_value = None;
|
||||
for item in self.fields() {
|
||||
let (_offset, tag, value) = item?;
|
||||
if tag.field_number == field_number {
|
||||
last_value = Some((value, tag.wire_type));
|
||||
}
|
||||
}
|
||||
last_value.ok_or(RotoError::FieldNotFound)
|
||||
}
|
||||
|
||||
/// Returns an iterator that scans the entire buffer for all occurrences of the specified field.
|
||||
pub fn iter_repeated(&self, field_number: u32) -> RepeatedFieldIterator<'a> {
|
||||
RepeatedFieldIterator::new(self.data, field_number)
|
||||
}
|
||||
|
||||
/// Returns the value and wire type of a field at a specific offset.
|
||||
pub fn get_value_at(&self, offset: usize) -> Result<(&'a [u8], WireType)> {
|
||||
if offset >= self.data.len() {
|
||||
return Err(RotoError::UnexpectedEndOfBuffer);
|
||||
}
|
||||
let (tag, tag_len) = Tag::decode(&self.data[offset..])?;
|
||||
let cursor_after_tag = offset + tag_len;
|
||||
if cursor_after_tag > self.data.len() {
|
||||
return Err(RotoError::UnexpectedEndOfBuffer);
|
||||
}
|
||||
let value_len = skip_value(tag.wire_type, &self.data[cursor_after_tag..])?;
|
||||
let (value_offset, actual_value_len) = match tag.wire_type {
|
||||
WireType::LengthDelimited => {
|
||||
let (_, varint_len) = read_varint(&self.data[cursor_after_tag..])?;
|
||||
(cursor_after_tag + varint_len, value_len - varint_len)
|
||||
}
|
||||
_ => (cursor_after_tag, value_len),
|
||||
};
|
||||
Ok((
|
||||
&self.data[value_offset..value_offset + actual_value_len],
|
||||
tag.wire_type,
|
||||
))
|
||||
}
|
||||
|
||||
/// Returns an iterator that scans a specific range of the buffer for all occurrences of the specified field.
|
||||
pub fn iter_repeated_range(
|
||||
&self,
|
||||
field_number: u32,
|
||||
start: usize,
|
||||
end: usize,
|
||||
) -> RepeatedFieldIterator<'a> {
|
||||
RepeatedFieldIterator::new_range(self.data, field_number, start, end)
|
||||
}
|
||||
|
||||
/// Returns an iterator that yields `(field_number, raw_bytes)` for every
|
||||
/// field in the message. `raw_bytes` is the complete on-wire encoding
|
||||
/// (tag + value, including any length prefix), suitable for passing
|
||||
/// directly to `ProtoBuilder::write_raw`.
|
||||
pub fn raw_fields(&self) -> RawFieldIterator<'a> {
|
||||
RawFieldIterator {
|
||||
data: self.data,
|
||||
cursor: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FieldIterator<'a> {
|
||||
data: &'a [u8],
|
||||
cursor: usize,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for FieldIterator<'a> {
|
||||
type Item = Result<(usize, Tag, &'a [u8])>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.cursor >= self.data.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let (tag, tag_len) = match Tag::decode(&self.data[self.cursor..]) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(e));
|
||||
}
|
||||
};
|
||||
|
||||
let cursor_after_tag = self.cursor + tag_len;
|
||||
if cursor_after_tag > self.data.len() {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(RotoError::UnexpectedEndOfBuffer));
|
||||
}
|
||||
|
||||
let value_len = match skip_value(tag.wire_type, &self.data[cursor_after_tag..]) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(e));
|
||||
}
|
||||
};
|
||||
|
||||
let (value_offset, actual_value_len) = match tag.wire_type {
|
||||
WireType::LengthDelimited => {
|
||||
let (_, varint_len) = match read_varint(&self.data[cursor_after_tag..]) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(e));
|
||||
}
|
||||
};
|
||||
(cursor_after_tag + varint_len, value_len - varint_len)
|
||||
}
|
||||
_ => (cursor_after_tag, value_len),
|
||||
};
|
||||
|
||||
self.cursor = cursor_after_tag + value_len;
|
||||
|
||||
Some(Ok((
|
||||
self.cursor - tag_len - value_len,
|
||||
tag,
|
||||
&self.data[value_offset..value_offset + actual_value_len],
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RepeatedFieldIterator<'a> {
|
||||
iterator: FieldIterator<'a>,
|
||||
field_number: u32,
|
||||
end_offset: Option<usize>,
|
||||
}
|
||||
|
||||
impl<'a> RepeatedFieldIterator<'a> {
|
||||
pub fn new(data: &'a [u8], field_number: u32) -> Self {
|
||||
Self {
|
||||
iterator: FieldIterator { data, cursor: 0 },
|
||||
field_number,
|
||||
end_offset: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_range(data: &'a [u8], field_number: u32, start: usize, end: usize) -> Self {
|
||||
Self {
|
||||
iterator: FieldIterator {
|
||||
data,
|
||||
cursor: start,
|
||||
},
|
||||
field_number,
|
||||
end_offset: Some(end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for RepeatedFieldIterator<'a> {
|
||||
type Item = Result<(&'a [u8], WireType)>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while let Some(item) = self.iterator.next() {
|
||||
match item {
|
||||
Ok((offset, tag, value)) if tag.field_number == self.field_number => {
|
||||
if let Some(end) = self.end_offset {
|
||||
if offset > end {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
return Some(Ok((value, tag.wire_type)));
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => return Some(Err(e)),
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator that yields `(field_number, raw_bytes)` for every field in a
|
||||
/// protobuf message, where `raw_bytes` is the complete on-wire encoding of the
|
||||
/// field: tag varint + value bytes (including the length prefix for
|
||||
/// length-delimited fields). This is the slice needed by
|
||||
/// `ProtoBuilder::write_raw` to copy a field verbatim.
|
||||
pub struct RawFieldIterator<'a> {
|
||||
data: &'a [u8],
|
||||
cursor: usize,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for RawFieldIterator<'a> {
|
||||
type Item = Result<(u32, &'a [u8])>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.cursor >= self.data.len() {
|
||||
return None;
|
||||
}
|
||||
let field_start = self.cursor;
|
||||
let (tag, tag_len) = match Tag::decode(&self.data[self.cursor..]) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(e));
|
||||
}
|
||||
};
|
||||
let cursor_after_tag = self.cursor + tag_len;
|
||||
if cursor_after_tag > self.data.len() {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(RotoError::UnexpectedEndOfBuffer));
|
||||
}
|
||||
let value_len = match skip_value(tag.wire_type, &self.data[cursor_after_tag..]) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
self.cursor = self.data.len();
|
||||
return Some(Err(e));
|
||||
}
|
||||
};
|
||||
self.cursor = cursor_after_tag + value_len;
|
||||
Some(Ok((tag.field_number, &self.data[field_start..self.cursor])))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_varint_read_write() {
|
||||
let mut buf = [0u8; 10];
|
||||
let val = 300u64;
|
||||
let len = write_varint(val, &mut buf).unwrap();
|
||||
assert_eq!(len, 2);
|
||||
assert_eq!(&buf[..2], &[0xAC, 0x02]);
|
||||
|
||||
let (read_val, read_len) = read_varint(&buf[..2]).unwrap();
|
||||
assert_eq!(read_val, val);
|
||||
assert_eq!(read_len, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tag_decode() {
|
||||
// Field 1, WireType Varint: (1 << 3) | 0 = 8
|
||||
let data = [8u8];
|
||||
let (tag, len) = Tag::decode(&data).unwrap();
|
||||
assert_eq!(tag.field_number, 1);
|
||||
assert_eq!(tag.wire_type, WireType::Varint);
|
||||
assert_eq!(len, 1);
|
||||
|
||||
// Field 15, WireType LengthDelimited: (15 << 3) | 2 = 120 | 2 = 122
|
||||
let data2 = [122u8];
|
||||
let (tag2, len2) = Tag::decode(&data2).unwrap();
|
||||
assert_eq!(tag2.field_number, 15);
|
||||
assert_eq!(tag2.wire_type, WireType::LengthDelimited);
|
||||
assert_eq!(len2, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skip_value() {
|
||||
// Varint: 300 (2 bytes)
|
||||
let data_varint = [0xAC, 0x02];
|
||||
assert_eq!(skip_value(WireType::Varint, &data_varint).unwrap(), 2);
|
||||
|
||||
// Fixed32: 4 bytes
|
||||
let data_fixed32 = [0u8; 4];
|
||||
assert_eq!(skip_value(WireType::Fixed32, &data_fixed32).unwrap(), 4);
|
||||
|
||||
// Length delimited: len=3, data=[1,2,3] (1 byte varint for length + 3 bytes)
|
||||
let data_len = [3, 1, 2, 3];
|
||||
assert_eq!(skip_value(WireType::LengthDelimited, &data_len).unwrap(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_accessor_basic() {
|
||||
// Field 1 (Varint): 150
|
||||
// Tag: (1 << 3) | 0 = 8. Value: 150 = [150, 1]
|
||||
// Field 2 (LengthDelimited): "hi"
|
||||
// Tag: (2 << 3) | 2 = 18. Length: 2. Value: [104, 105]
|
||||
let data = [8, 150, 1, 18, 2, 104, 105];
|
||||
let acc = ProtoAccessor::new(&data).unwrap();
|
||||
|
||||
let (val1, type1) = acc.get_value(1).unwrap();
|
||||
assert_eq!(type1, WireType::Varint);
|
||||
assert_eq!(val1, &[150, 1]);
|
||||
|
||||
let (val2, type2) = acc.get_value(2).unwrap();
|
||||
assert_eq!(type2, WireType::LengthDelimited);
|
||||
assert_eq!(val2, &[104, 105]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_accessor_repeated() {
|
||||
// Field 1: 10, Field 1: 20, Field 1: 30
|
||||
// Tags: 8, 8, 8. Values: 10, 20, 30
|
||||
let data = [8, 10, 8, 20, 8, 30];
|
||||
let acc = ProtoAccessor::new(&data).unwrap();
|
||||
|
||||
// Last value should be 30
|
||||
let (val, _) = acc.get_value(1).unwrap();
|
||||
assert_eq!(val, &[30]);
|
||||
|
||||
// Iteration should find all three
|
||||
let results: Vec<_> = acc.iter_repeated(1).collect();
|
||||
assert_eq!(results.len(), 3);
|
||||
assert_eq!(results[0].as_ref().unwrap().0, &[10]);
|
||||
assert_eq!(results[1].as_ref().unwrap().0, &[20]);
|
||||
assert_eq!(results[2].as_ref().unwrap().0, &[30]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_basic() {
|
||||
let mut buf = [0u8; 1024];
|
||||
let mut builder = ProtoBuilder::new(&mut buf);
|
||||
builder.write_string(1, "hello").unwrap();
|
||||
builder.write_int32(2, 42).unwrap();
|
||||
let data = builder.finish().unwrap();
|
||||
|
||||
let acc = ProtoAccessor::new(data).unwrap();
|
||||
let (val1, _) = acc.get_value(1).unwrap();
|
||||
assert_eq!(val1, "hello".as_bytes());
|
||||
let (val2, _) = acc.get_value(2).unwrap();
|
||||
assert_eq!(val2, &[42]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_overflow() {
|
||||
let mut buf = [0u8; 2];
|
||||
let mut builder = ProtoBuilder::new(&mut buf);
|
||||
let result = builder.write_string(1, "too long");
|
||||
assert_eq!(result, Err(RotoError::BufferOverflow));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_raw_field_iterator_yields_correct_bytes() {
|
||||
// Build: field 1 = string "hi", field 2 = int32 42
|
||||
let mut buf = [0u8; 64];
|
||||
let mut builder = ProtoBuilder::new(&mut buf);
|
||||
builder.write_string(1, "hi").unwrap();
|
||||
builder.write_int32(2, 42).unwrap();
|
||||
let data = builder.finish().unwrap().to_vec();
|
||||
|
||||
let acc = ProtoAccessor::new(&data).unwrap();
|
||||
let raw: Vec<_> = acc.raw_fields().collect();
|
||||
assert_eq!(raw.len(), 2);
|
||||
|
||||
// Field 1: tag = (1 << 3) | 2 = 0x0A, len varint = 0x02, "hi" = [0x68, 0x69]
|
||||
let (fn1, bytes1) = raw[0].as_ref().unwrap();
|
||||
assert_eq!(*fn1, 1);
|
||||
assert_eq!(*bytes1, [0x0A, 0x02, b'h', b'i']);
|
||||
|
||||
// Field 2: tag = (2 << 3) | 0 = 0x10, varint 42 = 0x2A
|
||||
let (fn2, bytes2) = raw[1].as_ref().unwrap();
|
||||
assert_eq!(*fn2, 2);
|
||||
assert_eq!(*bytes2, [0x10, 0x2A]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_raw_copies_field_verbatim() {
|
||||
// Build source: field 1 = string "hello", field 2 = int32 99
|
||||
let mut src_buf = [0u8; 64];
|
||||
let mut src_builder = ProtoBuilder::new(&mut src_buf);
|
||||
src_builder.write_string(1, "hello").unwrap();
|
||||
src_builder.write_int32(2, 99).unwrap();
|
||||
let src_data = src_builder.finish().unwrap().to_vec();
|
||||
|
||||
// Copy every raw field verbatim into a new buffer
|
||||
let src_acc = ProtoAccessor::new(&src_data).unwrap();
|
||||
let mut dst_buf = [0u8; 64];
|
||||
let mut dst_builder = ProtoBuilder::new(&mut dst_buf);
|
||||
for item in src_acc.raw_fields() {
|
||||
let (_, raw_bytes) = item.unwrap();
|
||||
dst_builder.write_raw(raw_bytes).unwrap();
|
||||
}
|
||||
let dst_data = dst_builder.finish().unwrap();
|
||||
|
||||
// The copy must be byte-identical to the source
|
||||
assert_eq!(dst_data, src_data.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_pattern_copies_unseen_fields() {
|
||||
// Build an existing source message with 3 fields
|
||||
let mut src_buf = [0u8; 128];
|
||||
let mut src_builder = ProtoBuilder::new(&mut src_buf);
|
||||
src_builder.write_string(1, "original").unwrap();
|
||||
src_builder.write_int32(2, 99).unwrap();
|
||||
src_builder.write_varint(3, 1u64).unwrap(); // bool
|
||||
let src_data = src_builder.finish().unwrap().to_vec();
|
||||
let src_acc = ProtoAccessor::new(&src_data).unwrap();
|
||||
|
||||
// Simulate what a generated `with` method does:
|
||||
// field 1 was explicitly written; fields 2 and 3 come from source.
|
||||
let field1_written = true;
|
||||
let field2_written = false;
|
||||
let field3_written = false;
|
||||
|
||||
let mut dst_buf = [0u8; 128];
|
||||
let mut dst_builder = ProtoBuilder::new(&mut dst_buf);
|
||||
dst_builder.write_string(1, "updated").unwrap();
|
||||
|
||||
for item in src_acc.raw_fields() {
|
||||
let (field_number, raw_bytes) = item.unwrap();
|
||||
let is_written = match field_number {
|
||||
1 => field1_written,
|
||||
2 => field2_written,
|
||||
3 => field3_written,
|
||||
_ => false,
|
||||
};
|
||||
if !is_written {
|
||||
dst_builder.write_raw(raw_bytes).unwrap();
|
||||
}
|
||||
}
|
||||
let dst_data = dst_builder.finish().unwrap();
|
||||
let dst_acc = ProtoAccessor::new(dst_data).unwrap();
|
||||
|
||||
// Field 1: overridden value
|
||||
let (val1, _) = dst_acc.get_value(1).unwrap();
|
||||
assert_eq!(val1, b"updated");
|
||||
|
||||
// Field 2: copied from source
|
||||
let (val2, _) = dst_acc.get_value(2).unwrap();
|
||||
let (v2, _) = read_varint(val2).unwrap();
|
||||
assert_eq!(v2 as i32, 99);
|
||||
|
||||
// Field 3: copied from source
|
||||
let (val3, _) = dst_acc.get_value(3).unwrap();
|
||||
let (v3, _) = read_varint(val3).unwrap();
|
||||
assert_eq!(v3, 1u64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_protoc_binary_compatibility() {
|
||||
let data = include_bytes!("../data/test_data.pb");
|
||||
let acc = ProtoAccessor::new(data).unwrap();
|
||||
|
||||
// 1. Varints (Integers, Booleans, Enums)
|
||||
let (val_i32, type_i32) = acc.get_value(3).expect("i32_val not found");
|
||||
assert_eq!(type_i32, WireType::Varint);
|
||||
let (v, _) = read_varint(val_i32).unwrap();
|
||||
assert_eq!(v, 42);
|
||||
|
||||
let (val_b, type_b) = acc.get_value(13).expect("b_val not found");
|
||||
assert_eq!(type_b, WireType::Varint);
|
||||
let (v_b, _) = read_varint(val_b).unwrap();
|
||||
assert_eq!(v_b, 1); // true
|
||||
|
||||
let (val_status, type_status) = acc.get_value(16).expect("status not found");
|
||||
assert_eq!(type_status, WireType::Varint);
|
||||
let (v_s, _) = read_varint(val_status).unwrap();
|
||||
assert_eq!(v_s, 1); // ACTIVE
|
||||
|
||||
// 2. Length Delimited (Strings, Bytes)
|
||||
let (val_s, type_s) = acc.get_value(14).expect("s_val not found");
|
||||
assert_eq!(type_s, WireType::LengthDelimited);
|
||||
assert_eq!(val_s, "Hello Roto!".as_bytes());
|
||||
|
||||
// 3. Fixed Width (Floats)
|
||||
let (val_f, type_f) = acc.get_value(2).expect("f_val not found");
|
||||
assert_eq!(type_f, WireType::Fixed32);
|
||||
let f_val = f32::from_le_bytes(val_f.try_into().expect("Expected 4 bytes for f32"));
|
||||
assert!((f_val - 2.71828).abs() < 1e-5);
|
||||
|
||||
// 4. Repeated Fields
|
||||
// Note: primitive repeated fields are packed in proto3, so we iterate over the blob
|
||||
let mut i32_vals = Vec::new();
|
||||
for item in acc.iter_repeated(17) {
|
||||
let (blob, _) = item.expect("Failed to decode repeated i32");
|
||||
let mut cursor = 0;
|
||||
while cursor < blob.len() {
|
||||
let (v, len) = read_varint(&blob[cursor..]).unwrap();
|
||||
i32_vals.push(v);
|
||||
cursor += len;
|
||||
}
|
||||
}
|
||||
assert_eq!(i32_vals, vec![1, 2, 3, 4, 5]);
|
||||
|
||||
let repeated_strings: Vec<_> = acc
|
||||
.iter_repeated(18)
|
||||
.map(|r| {
|
||||
let (val, _) = r.expect("Failed to decode repeated string");
|
||||
std::str::from_utf8(val).expect("Invalid utf8")
|
||||
})
|
||||
.collect();
|
||||
assert_eq!(repeated_strings, vec!["one", "two", "three"]);
|
||||
|
||||
let repeated_nested: Vec<_> = acc
|
||||
.iter_repeated(19)
|
||||
.map(|r| {
|
||||
let (val, _) = r.expect("Failed to decode repeated nested");
|
||||
let nested_acc = ProtoAccessor::new(val).unwrap();
|
||||
let (id_val, _) = nested_acc.get_value(1).expect("Nested id not found");
|
||||
let (id, _) = read_varint(id_val).unwrap();
|
||||
id
|
||||
})
|
||||
.collect();
|
||||
assert_eq!(repeated_nested, vec![101, 102]);
|
||||
|
||||
// 5. Single Nested Message
|
||||
let (val_nested, type_nested) = acc.get_value(20).expect("single_nested not found");
|
||||
assert_eq!(type_nested, WireType::LengthDelimited);
|
||||
let nested_acc = ProtoAccessor::new(val_nested).unwrap();
|
||||
let (val_id, _) = nested_acc.get_value(1).expect("Nested id not found");
|
||||
let (id, _) = read_varint(val_id).unwrap();
|
||||
assert_eq!(id, 200);
|
||||
|
||||
// Validate that fields appear in the expected relative order
|
||||
let field_numbers: Vec<u32> = acc
|
||||
.fields()
|
||||
.map(|r| r.expect("Failed to decode field").1.field_number)
|
||||
.collect();
|
||||
|
||||
let essential_fields = [1, 2, 3, 14, 16, 20];
|
||||
let mut last_field = 0;
|
||||
let mut found_count = 0;
|
||||
for &f in &field_numbers {
|
||||
if essential_fields.contains(&f) {
|
||||
assert!(
|
||||
f >= last_field,
|
||||
"Fields appeared out of order: {} came after {}",
|
||||
f,
|
||||
last_field
|
||||
);
|
||||
last_field = f;
|
||||
found_count += 1;
|
||||
}
|
||||
}
|
||||
assert_eq!(found_count, essential_fields.len());
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProtoBuilder<'a> {
|
||||
buf: &'a mut [u8],
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
impl<'a> ProtoBuilder<'a> {
|
||||
pub fn new(buf: &'a mut [u8]) -> Self {
|
||||
Self { buf, pos: 0 }
|
||||
}
|
||||
|
||||
fn write_tag(&mut self, field_number: u32, wire_type: WireType) -> Result<()> {
|
||||
let mut temp = [0u8; 10];
|
||||
let len = Tag::encode(field_number, wire_type, &mut temp)?;
|
||||
self.append_bytes(&temp[..len])
|
||||
}
|
||||
|
||||
fn append_bytes(&mut self, bytes: &[u8]) -> Result<()> {
|
||||
if self.pos + bytes.len() > self.buf.len() {
|
||||
return Err(RotoError::BufferOverflow);
|
||||
}
|
||||
self.buf[self.pos..self.pos + bytes.len()].copy_from_slice(bytes);
|
||||
self.pos += bytes.len();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_varint(&mut self, field_number: u32, value: u64) -> Result<()> {
|
||||
self.write_tag(field_number, WireType::Varint)?;
|
||||
let mut temp = [0u8; 10];
|
||||
let len = write_varint(value, &mut temp)?;
|
||||
self.append_bytes(&temp[..len])
|
||||
}
|
||||
|
||||
pub fn write_int32(&mut self, field_number: u32, value: i32) -> Result<()> {
|
||||
self.write_varint(field_number, value as u64)
|
||||
}
|
||||
|
||||
pub fn write_string(&mut self, field_number: u32, value: &str) -> Result<()> {
|
||||
self.write_tag(field_number, WireType::LengthDelimited)?;
|
||||
let bytes = value.as_bytes();
|
||||
let mut len_buf = [0u8; 10];
|
||||
let len_len = write_varint(bytes.len() as u64, &mut len_buf)?;
|
||||
self.append_bytes(&len_buf[..len_len])?;
|
||||
self.append_bytes(bytes)
|
||||
}
|
||||
|
||||
pub fn write_fixed32(&mut self, field_number: u32, value: u32) -> Result<()> {
|
||||
self.write_tag(field_number, WireType::Fixed32)?;
|
||||
self.append_bytes(&value.to_le_bytes())
|
||||
}
|
||||
|
||||
pub fn write_fixed64(&mut self, field_number: u32, value: u64) -> Result<()> {
|
||||
self.write_tag(field_number, WireType::Fixed64)?;
|
||||
self.append_bytes(&value.to_le_bytes())
|
||||
}
|
||||
|
||||
pub fn write_bytes(&mut self, field_number: u32, value: &[u8]) -> Result<()> {
|
||||
self.write_tag(field_number, WireType::LengthDelimited)?;
|
||||
let mut len_buf = [0u8; 10];
|
||||
let len_len = write_varint(value.len() as u64, &mut len_buf)?;
|
||||
self.append_bytes(&len_buf[..len_len])?;
|
||||
self.append_bytes(value)
|
||||
}
|
||||
|
||||
/// Appends a pre-encoded field (tag + value bytes) verbatim into the
|
||||
/// buffer. Use this together with `ProtoAccessor::raw_fields` to copy
|
||||
/// fields from an existing message into a builder without re-encoding them.
|
||||
pub fn write_raw(&mut self, raw_bytes: &[u8]) -> Result<()> {
|
||||
self.append_bytes(raw_bytes)
|
||||
}
|
||||
|
||||
pub fn finish(self) -> Result<&'a mut [u8]> {
|
||||
Ok(&mut self.buf[..self.pos])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user