Files
roto/src/generator.rs
T

489 lines
19 KiB
Rust
Raw Normal View History

2026-05-04 13:46:05 -07:00
use crate::ProtoAccessor;
2026-05-04 10:45:08 -07:00
use crate::google::protobuf::descriptor::{
2026-05-04 13:46:05 -07:00
DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
FileDescriptorSet,
2026-05-04 10:45:08 -07:00
};
2026-05-03 20:44:07 -07:00
use std::collections::{HashMap, HashSet};
2026-05-04 13:46:05 -07:00
use std::str;
2026-05-02 22:48:03 -07:00
/// Joins the `_`-separated pieces of `s`, upper-casing the first character of
/// each piece.
///
/// Only the leading character of a piece is changed; the remainder is copied
/// through verbatim (so `"TYPE_DOUBLE"` becomes `"TYPEDOUBLE"`). Empty pieces
/// produced by leading, trailing or doubled underscores contribute nothing.
pub fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            // `to_uppercase` may expand to several chars (e.g. 'ß' -> "SS").
            pascal.extend(first.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
/// Lower-cases every upper-case character of `s`, inserting `_` in front of it
/// unless it is the very first character.
///
/// Each upper-case character is treated on its own, so a run such as `"HTTP"`
/// turns into `"h_t_t_p"`. Lower-casing is ASCII-only: a non-ASCII upper-case
/// character still receives the separator but is copied through unchanged.
pub fn to_snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::new(), |mut snake, (idx, ch)| {
            if !ch.is_uppercase() {
                snake.push(ch);
                return snake;
            }
            if idx != 0 {
                snake.push('_');
            }
            snake.push(ch.to_ascii_lowercase());
            snake
        })
}
2026-05-02 22:48:03 -07:00
/// Maps a `FieldDescriptorProto` type/label pair to the Rust return type of the
/// generated getter and the expression that decodes the field's raw `bytes`.
///
/// * `field_type` – numeric `FieldDescriptorProto.Type` value.
/// * `label` – numeric `FieldDescriptorProto.Label` value; `3` (repeated)
///   short-circuits to the repeated-field iterator type with an empty decode
///   expression, because repeated getters do not decode a single value.
///
/// Returns `(rust_type, decode_expression)`. The expression is spliced into a
/// generated method body where `bytes` holds the field's value bytes, and it
/// evaluates to a `crate::Result<rust_type>`.
///
/// The arms follow the numbering in `descriptor.proto`:
/// 1 double, 2 float, 3 int64, 4 uint64, 5 int32, 6 fixed64, 7 fixed32,
/// 8 bool, 9 string, 10 group, 11 message, 12 bytes, 13 uint32, 14 enum,
/// 15 sfixed32, 16 sfixed64, 17 sint32, 18 sint64.
fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) {
    const WIRE_ERR: &str = "crate::RotoError::WireFormatViolation";

    if label == 3 {
        // LABEL_REPEATED
        return (
            "crate::RepeatedFieldIterator<'a>".to_string(),
            "".to_string(), // Not used for repeated fields in the same way
        );
    }

    // Varint-encoded scalar: `convert` turns the decoded integer `v` into the
    // final Rust value.
    let varint = |convert: &str| {
        format!(
            "crate::read_varint(bytes).map(|(v, _)| {}).map_err(|_| {})",
            convert, WIRE_ERR
        )
    };
    // Fixed-width little-endian scalar (4 or 8 value bytes).
    let fixed = |ty: &str| {
        format!(
            "Ok({}::from_le_bytes(bytes.try_into().map_err(|_| {})?))",
            ty, WIRE_ERR
        )
    };

    let (rust_type, logic) = match field_type {
        1 => ("f64", fixed("f64")),             // TYPE_DOUBLE
        2 => ("f32", fixed("f32")),             // TYPE_FLOAT
        3 => ("i64", varint("v as i64")),       // TYPE_INT64
        4 => ("u64", varint("v as u64")),       // TYPE_UINT64
        5 | 14 => ("i32", varint("v as i32")),  // TYPE_INT32 / TYPE_ENUM
        6 => ("u64", fixed("u64")),             // TYPE_FIXED64
        7 => ("u32", fixed("u32")),             // TYPE_FIXED32
        8 => ("bool", varint("v != 0")),        // TYPE_BOOL
        9 => (
            "&'a str",
            format!("str::from_utf8(bytes).map_err(|_| {})", WIRE_ERR),
        ), // TYPE_STRING
        13 => ("u32", varint("v as u32")),      // TYPE_UINT32
        15 => ("i32", fixed("i32")),            // TYPE_SFIXED32
        16 => ("i64", fixed("i64")),            // TYPE_SFIXED64
        // sint32/sint64 are zigzag-encoded on the wire; undo that here.
        17 => (
            "i32",
            varint("(((v as u32) >> 1) as i32) ^ -(((v as u32) & 1) as i32)"),
        ), // TYPE_SINT32
        18 => (
            "i64",
            varint("(((v as u64) >> 1) as i64) ^ -(((v as u64) & 1) as i64)"),
        ), // TYPE_SINT64
        // MESSAGE (11), BYTES (12), GROUP (10) and anything unrecognised are
        // handed back as the raw value bytes.
        _ => ("&'a [u8]", "Ok(bytes)".to_string()),
    };
    (rust_type.to_string(), logic)
}
fn write_enum(enum_proto: &EnumDescriptorProto, output: &mut String) {
let enum_name = to_pascal_case(enum_proto.name().unwrap());
output.push_str(&format!(
"#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n#[repr(i32)]\npub enum {} {{\n",
enum_name
));
2026-05-04 10:45:08 -07:00
let mut values = enum_proto.value();
let mut zero_variant_name = None;
while let Some(val_res) = values.next() {
let (val_data, _) = val_res.expect("Failed to iterate enum");
2026-05-04 13:46:05 -07:00
let accessor =
ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto");
let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing");
let name = str::from_utf8(name_bytes).expect("Enum value name invalid utf8");
let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing");
let (num, _) = crate::read_varint(num_bytes).expect("Enum value number invalid varint");
let pascal_name = to_pascal_case(name);
if num == 0 {
zero_variant_name = Some(pascal_name.clone());
}
output.push_str(&format!(" {} = {},\n", pascal_name, num));
}
if zero_variant_name.is_none() {
output.push_str(" Unknown = 0,\n");
zero_variant_name = Some("Unknown".to_string());
}
output.push_str("}\n\n");
output.push_str(&format!(
"impl {} {{\n pub fn from_i32(value: i32) -> Self {{\n match value {{\n",
enum_name
));
2026-05-04 10:45:08 -07:00
let mut values = enum_proto.value();
while let Some(val_res) = values.next() {
let (val_data, _) = val_res.expect("Failed to read enum value");
2026-05-04 13:46:05 -07:00
let accessor =
ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto");
let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing");
let name = str::from_utf8(name_bytes).expect("Enum value name invalid utf8");
let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing");
let (num, _) = crate::read_varint(num_bytes).expect("Enum value number invalid varint");
2026-05-04 13:46:05 -07:00
output.push_str(&format!(
" {} => {}::{},\n",
num,
enum_name,
to_pascal_case(name)
));
}
2026-05-04 13:46:05 -07:00
output.push_str(&format!(
" _ => {}::{},\n",
enum_name,
zero_variant_name.as_ref().unwrap()
));
output.push_str(" }\n }\n}\n\n");
}
fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
let msg_name = to_pascal_case(msg_proto.name().unwrap());
let mut fields_info = Vec::new();
for field_res in msg_proto.field() {
let (field_data, _) = field_res.expect("Failed to iterate field");
2026-05-04 13:46:05 -07:00
let field_proto =
FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto");
let field_name = field_proto.name().unwrap();
let tag = field_proto.number().unwrap();
2026-05-04 10:45:08 -07:00
let f_type = field_proto.r#type().unwrap() as i32;
let f_label = field_proto.label().unwrap() as i32;
fields_info.push((field_name.to_string(), tag, f_type, f_label));
2026-05-04 11:14:57 -07:00
}
2026-05-04 13:46:05 -07:00
output.push_str(&format!("pub struct {}<'a> {{\n", msg_name));
output.push_str(" accessor: crate::ProtoAccessor<'a>,\n");
2026-05-04 11:14:57 -07:00
for (field_name, _tag, _f_type, f_label) in &fields_info {
if *f_label == 3 {
output.push_str(&format!(" {}_start: Option<usize>,\n", field_name));
output.push_str(&format!(" {}_end: Option<usize>,\n", field_name));
} else {
output.push_str(&format!(" {}_offset: Option<usize>,\n", field_name));
}
}
output.push_str("}\n\n");
output.push_str(&format!("impl<'a> {}<'a> {{\n", msg_name));
output.push_str(" pub fn new(data: &'a [u8]) -> crate::Result<Self> {\n");
2026-05-04 13:46:05 -07:00
output.push_str(" let accessor = crate::ProtoAccessor::new(data)?;\n");
2026-05-04 11:14:57 -07:00
if !fields_info.is_empty() {
for (name, _, _, label) in &fields_info {
if *label == 3 {
output.push_str(&format!(" let mut {}_start = None;\n", name));
output.push_str(&format!(" let mut {}_end = None;\n", name));
} else {
output.push_str(&format!(" let mut {}_offset = None;\n", name));
}
}
2026-05-04 11:14:57 -07:00
output.push_str(" for item in accessor.fields() {\n");
output.push_str(" let (offset, tag, _) = item?;\n");
for (name, tag, _, label) in &fields_info {
if *label == 3 {
output.push_str(&format!(" if tag.field_number == {} {{\n", tag));
2026-05-04 13:46:05 -07:00
output.push_str(&format!(
" if {}_start.is_none() {{ {}_start = Some(offset); }}\n",
name, name
));
2026-05-04 11:14:57 -07:00
output.push_str(&format!(" {}_end = Some(offset);\n", name));
output.push_str(" }\n");
} else {
2026-05-04 13:46:05 -07:00
output.push_str(&format!(
" if tag.field_number == {} {{ {}_offset = Some(offset); }}\n",
tag, name
));
2026-05-04 11:14:57 -07:00
}
}
2026-05-04 11:14:57 -07:00
output.push_str(" }\n\n");
}
output.push_str(" Ok(Self {\n");
2026-05-04 13:46:05 -07:00
output.push_str(" accessor,\n");
for (name, _, _, label) in &fields_info {
if *label == 3 {
output.push_str(&format!("{}_start, {}_end,\n", name, name));
} else {
output.push_str(&format!("{}_offset,\n", name));
}
}
output.push_str(" })\n }\n\n");
for (field_name, tag, f_type, f_label) in fields_info {
let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label);
2026-05-04 13:46:05 -07:00
let safe_name = if field_name == "type" {
format!("r#{}", field_name)
} else {
field_name.clone()
};
if f_label == 3 {
2026-05-04 13:46:05 -07:00
output.push_str(&format!(
" pub fn {}(&self) -> {} {{\n",
safe_name, rust_type
));
output.push_str(&format!(
" match (self.{}_start, self.{}_end) {{\n",
field_name, field_name
));
output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag));
2026-05-04 13:46:05 -07:00
output.push_str(&format!(
" _ => self.accessor.iter_repeated({}),\n",
tag
));
output.push_str(" }\n }\n\n");
} else {
2026-05-04 13:46:05 -07:00
output.push_str(&format!(
" pub fn {}(&self) -> crate::Result<{}> {{\n",
safe_name, rust_type
));
output.push_str(&format!(
" let offset = self.{}_offset.ok_or(crate::RotoError::FieldNotFound)?;\n",
field_name
));
output.push_str(" let (bytes, _) = self.accessor.get_value_at(offset)?;\n");
output.push_str(&format!(" {}\n", logic));
output.push_str(" }\n\n");
}
}
// raw_fields() convenience on the message struct (before closing the impl)
output.push_str(" pub fn raw_fields(&self) -> roto::RawFieldIterator<'a> {\n");
output.push_str(" self.accessor.raw_fields()\n");
output.push_str(" }\n\n");
output.push_str("}\n\n");
// Collect builder field info so we can use it multiple times below.
// Tuple: (field_name, safe_name, tag, rust_type, write_method)
let mut builder_fields: Vec<(String, String, u32, String, String)> = Vec::new();
for field_res in msg_proto.field() {
let (field_data, _) = field_res.expect("Failed to iterate field");
2026-05-04 13:46:05 -07:00
let field_proto =
FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto");
let field_name = field_proto.name().unwrap().to_string();
2026-05-04 13:46:05 -07:00
let safe_name = if field_name == "type" {
format!("r#{}", field_name)
} else {
field_name.clone()
2026-05-04 13:46:05 -07:00
};
let tag = field_proto.number().unwrap();
2026-05-04 10:45:08 -07:00
let f_type = field_proto.r#type().unwrap() as i32;
let (rust_type, method) = map_type_to_rust_builder(f_type);
2026-05-04 20:11:54 -07:00
builder_fields.push((field_name, safe_name, tag as u32, rust_type, method));
}
// Builder struct — one `_written: bool` flag per field
output.push_str(&format!("pub struct {}Builder<'b> {{\n", msg_name));
output.push_str(" builder: crate::ProtoBuilder<'b>,\n");
for (field_name, _, _, _, _) in &builder_fields {
output.push_str(&format!(" {}_written: bool,\n", field_name));
}
2026-05-04 20:11:54 -07:00
output.push_str(&format!("}}\n\nimpl<'b> {}Builder<'b> {{\n", msg_name));
// Constructor — initialise every flag to false
output.push_str(&format!(
" pub fn builder(buf: &mut [u8]) -> {}Builder<'_> {{\n {}Builder {{\n",
msg_name, msg_name
));
output.push_str(" builder: crate::ProtoBuilder::new(buf),\n");
for (field_name, _, _, _, _) in &builder_fields {
output.push_str(&format!(" {}_written: false,\n", field_name));
}
output.push_str(" }\n }\n\n");
// Per-field setters — mark field as written
for (field_name, safe_name, tag, rust_type, method) in &builder_fields {
output.push_str(&format!(
" pub fn {}(mut self, value: {}) -> crate::Result<Self> {{\n self.builder.{}({}, value)?;\n self.{}_written = true;\n Ok(self)\n }}\n\n",
safe_name, rust_type, method, tag, field_name
));
}
// with() — copies unseen fields from an existing message
output.push_str(&format!(
" pub fn with(mut self, msg: &{}<'_>) -> crate::Result<Self> {{\n",
msg_name
));
output.push_str(" for item in msg.raw_fields() {\n");
output.push_str(" let (field_number, raw_bytes) = item?;\n");
output.push_str(" let is_written = match field_number {\n");
for (field_name, _, tag, _, _) in &builder_fields {
output.push_str(&format!(
" {} => self.{}_written,\n",
tag, field_name
));
}
output.push_str(" _ => false,\n");
output.push_str(" };\n");
output.push_str(" if !is_written {\n");
output.push_str(" self.builder.write_raw(raw_bytes)?;\n");
output.push_str(" }\n");
output.push_str(" }\n");
output.push_str(" Ok(self)\n");
output.push_str(" }\n\n");
output.push_str(&format!(" pub fn finish(self) -> crate::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n"));
let mut nested_enums = Vec::new();
for e_res in msg_proto.enum_type() {
2026-05-04 13:46:05 -07:00
if let Ok((e, _)) = e_res {
nested_enums.push(e);
}
}
let mut nested_msgs = Vec::new();
for m_res in msg_proto.nested_type() {
2026-05-04 13:46:05 -07:00
if let Ok((m, _)) = m_res {
nested_msgs.push(m);
}
}
if !nested_enums.is_empty() || !nested_msgs.is_empty() {
let mod_name = to_snake_case(msg_proto.name().unwrap());
output.push_str(&format!("pub mod {} {{\n", mod_name));
for e_data in nested_enums {
2026-05-04 13:46:05 -07:00
write_enum(
&EnumDescriptorProto::new(e_data)
.expect("Failed to parse nested EnumDescriptorProto"),
output,
);
}
for m_data in nested_msgs {
2026-05-04 13:46:05 -07:00
write_message(
&DescriptorProto::new(m_data).expect("Failed to parse nested DescriptorProto"),
output,
);
}
output.push_str("}\n\n");
}
}
2026-05-02 22:48:03 -07:00
/// Maps a numeric `FieldDescriptorProto.Type` value to the parameter type of
/// the generated builder setter and the name of the `ProtoBuilder` method the
/// setter calls.
///
/// Returns `(value_type, write_method)`. Types without a dedicated arm
/// (including message and bytes) take `&[u8]` and use `write_bytes`.
///
/// NOTE(review): double (1) and float (2) have no arm and therefore also end
/// up on `write_bytes`; confirm that this is the intended encoding for them.
fn map_type_to_rust_builder(field_type: i32) -> (String, String) {
    let (value_type, write_method) = match field_type {
        9 => ("&str", "write_string"),
        5 | 17 => ("i32", "write_int32"),
        3 | 4 | 8 | 13 | 14 | 18 => ("u64", "write_varint"),
        7 | 15 => ("u32", "write_fixed32"),
        6 | 16 => ("u64", "write_fixed64"),
        // 11 (message), 12 (bytes) and every other value share the fallback.
        _ => ("&[u8]", "write_bytes"),
    };
    (value_type.to_owned(), write_method.to_owned())
}
2026-05-03 20:44:07 -07:00
pub fn generate_rust_code(
set: &FileDescriptorSet,
files_to_generate: Option<&[String]>,
generate_mod_files: bool,
) -> Vec<(String, String)> {
let mut generated_files = Vec::new();
2026-05-02 22:48:03 -07:00
for file_res in set.file() {
let (file_data, _) = file_res.expect("Failed to iterate file");
2026-05-04 13:46:05 -07:00
let file_proto =
FileDescriptorProto::new(file_data).expect("Failed to parse FileDescriptorProto");
2026-05-03 20:44:07 -07:00
let proto_name = file_proto.name().expect("File proto name missing");
if let Some(filter) = files_to_generate {
if !filter.contains(&proto_name.to_string()) {
continue;
}
}
let rust_file_name = format!("{}.rs", proto_name.replace(".proto", ""));
let mut output = String::new();
2026-05-04 20:19:40 -07:00
output.push_str("// @generated by protoc-gen-roto — do not edit\n");
output.push_str("#![allow(unused_imports)]\n\n");
2026-05-03 20:44:07 -07:00
output.push_str("use crate::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator};\n");
output.push_str("use std::str;\n\n");
for dep_res in file_proto.dependency() {
let (dep_data, _) = dep_res.expect("Failed to iterate dependency");
let dep_name = str::from_utf8(dep_data).expect("Dependency name invalid utf8");
let dep_mod_path = dep_name.replace(".proto", "").replace('/', "::");
output.push_str(&format!("use crate::{};\n", dep_mod_path));
}
output.push_str("\n");
2026-05-02 22:48:03 -07:00
// Enums
for enum_res in file_proto.enum_type() {
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
2026-05-04 13:46:05 -07:00
write_enum(
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
&mut output,
);
2026-05-02 22:48:03 -07:00
}
// Messages
for msg_res in file_proto.message_type() {
let (msg_data, _) = msg_res.expect("Failed to iterate message");
2026-05-04 13:46:05 -07:00
write_message(
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
&mut output,
);
2026-05-02 22:48:03 -07:00
}
2026-05-03 20:44:07 -07:00
generated_files.push((rust_file_name, output));
}
if !generate_mod_files {
return generated_files;
}
let mut all_paths: Vec<String> = generated_files.iter().map(|(p, _)| p.clone()).collect();
all_paths.sort();
let mut mod_files: HashMap<String, HashSet<String>> = HashMap::new();
for path in &all_paths {
let parts: Vec<&str> = path.split('/').collect();
let mut current_dir = String::new();
for i in 0..parts.len() - 1 {
if !current_dir.is_empty() {
current_dir.push('/');
}
current_dir.push_str(parts[i]);
let mod_path = format!("{}/mod.rs", current_dir);
let sub_mod = parts[i + 1].replace(".rs", "");
mod_files.entry(mod_path).or_default().insert(sub_mod);
}
}
let mut root_mods = HashSet::new();
for path in &all_paths {
let parts: Vec<&str> = path.split('/').collect();
root_mods.insert(parts[0].replace(".rs", ""));
}
let mut root_mod_content = String::new();
2026-05-04 20:19:40 -07:00
root_mod_content.push_str("// @generated by protoc-gen-roto — do not edit\n");
root_mod_content.push_str("#![allow(unused_imports)]\n\n");
2026-05-03 20:44:07 -07:00
let mut sorted_root_mods: Vec<_> = root_mods.into_iter().collect();
sorted_root_mods.sort();
for m in sorted_root_mods {
root_mod_content.push_str(&format!("pub mod {};\n", m));
}
generated_files.push(("mod.rs".to_string(), root_mod_content));
for (mod_path, sub_mods) in mod_files {
let mut content = String::new();
2026-05-04 20:19:40 -07:00
content.push_str("// @generated by protoc-gen-roto — do not edit\n");
content.push_str("#![allow(unused_imports)]\n\n");
2026-05-03 20:44:07 -07:00
let mut sorted_subs: Vec<_> = sub_mods.into_iter().collect();
sorted_subs.sort();
for sub in sorted_subs {
content.push_str(&format!("pub mod {};\n", sub));
}
generated_files.push((mod_path, content));
2026-05-02 22:48:03 -07:00
}
2026-05-03 20:44:07 -07:00
generated_files
2026-05-02 22:48:03 -07:00
}