From 20e4fb909b053727d8eac896c9867399e9df77c3 Mon Sep 17 00:00:00 2001 From: charles Date: Sun, 3 May 2026 20:44:07 -0700 Subject: [PATCH] Generate modules to make easy importing --- src/bin/generator.rs | 16 ++++++-- src/bin/protoc-gen-roto.rs | 41 ++++++++++--------- src/generator.rs | 81 +++++++++++++++++++++++++++++++++++--- 3 files changed, 109 insertions(+), 29 deletions(-) diff --git a/src/bin/generator.rs b/src/bin/generator.rs index 0241754..24a3780 100644 --- a/src/bin/generator.rs +++ b/src/bin/generator.rs @@ -11,9 +11,13 @@ struct Args { #[arg(short, long)] input: PathBuf, - /// Path to the output Rust file (.rs) + /// Path to the output directory #[arg(short, long)] output: PathBuf, + + /// Files to generate. If omitted, all files are generated. + #[arg(short, long, value_delimiter = ',')] + files: Option>, } fn main() -> Result<(), Box> { @@ -21,8 +25,14 @@ fn main() -> Result<(), Box> { let data = fs::read(&args.input)?; let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet"); - let output = generate_rust_code(&set); + let files = generate_rust_code(&set, args.files.as_deref(), true); - fs::write(&args.output, output)?; + for (filename, content) in files { + let path = args.output.join(filename); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(path, content)?; + } Ok(()) } diff --git a/src/bin/protoc-gen-roto.rs b/src/bin/protoc-gen-roto.rs index 1f7a10a..f0f111c 100644 --- a/src/bin/protoc-gen-roto.rs +++ b/src/bin/protoc-gen-roto.rs @@ -67,36 +67,35 @@ fn handle_request(request: &CodeGeneratorRequest) -> std::result::Result let set = FileDescriptorSet::new(&set_buf)?; + let files_to_generate: Vec = request.file_to_generate() + .filter_map(|res| { + let (bytes, _) = res.ok()?; + std::str::from_utf8(bytes).ok().map(|s| s.to_string()) + }) + .collect(); + // Generate the Rust code info!("Generating Rust code from descriptor set..."); - let generated_code = generate_rust_code(&set); - - // Determine the output filename - let mut output_filename = "roto_generated.rs".to_string(); - if let Some(first_file) = request.file_to_generate().next() { - if let Ok((name_bytes, _)) = first_file { - if let Ok(name) = std::str::from_utf8(name_bytes) { - output_filename = format!("{}.rs", name.replace(".proto", "")); - } - } - } + let generated_files = generate_rust_code(&set, Some(&files_to_generate), false); // Construct the response let mut response_buf = vec![0u8; 1024 * 1024 * 2]; // Allocate 2MB for response let mut resp_builder = CodeGeneratorResponse::builder(&mut response_buf); - let mut file_buf = vec![0u8; 1024 * 1024 * 2]; - let final_file = ResponseFile::builder(&mut file_buf) - .name(&output_filename)? - .content(&generated_code)? - .finish() - .map_err(|e| { - error!("Failed to build ResponseFile: {:?}", e); - e - })?; + for (filename, content) in generated_files { + let mut file_buf = vec![0u8; 1024 * 1024 * 2]; + let final_file = ResponseFile::builder(&mut file_buf) + .name(&filename)? + .content(&content)? + .finish() + .map_err(|e| { + error!("Failed to build ResponseFile {}: {:?}", filename, e); + e + })?; + resp_builder = resp_builder.add_file(final_file)?; + } let final_response_slice = resp_builder - .add_file(final_file)? .finish() .map_err(|e| { error!("Failed to finish CodeGeneratorResponse: {:?}", e); diff --git a/src/generator.rs b/src/generator.rs index 12944dd..e7734ae 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -3,6 +3,7 @@ use crate::proto_gen::google::protobuf::descriptor::{ }; use crate::{ProtoAccessor, Result, RotoError}; use std::str; +use std::collections::{HashMap, HashSet}; pub fn to_pascal_case(s: &str) -> String { s.split('_') @@ -75,14 +76,37 @@ fn map_type_to_rust_builder(field_type: i32) -> (String, String) { } } -pub fn generate_rust_code(set: &FileDescriptorSet) -> String { - let mut output = String::new(); - output.push_str("use crate::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator};\n"); - output.push_str("use std::str;\n\n"); +pub fn generate_rust_code( + set: &FileDescriptorSet, + files_to_generate: Option<&[String]>, + generate_mod_files: bool, +) -> Vec<(String, String)> { + let mut generated_files = Vec::new(); for file_res in set.file() { let (file_data, _) = file_res.expect("Failed to iterate file"); let file_proto = FileDescriptorProto::new(file_data).expect("Failed to parse FileDescriptorProto"); + let proto_name = file_proto.name().expect("File proto name missing"); + + if let Some(filter) = files_to_generate { + if !filter.contains(&proto_name.to_string()) { + continue; + } + } + + let rust_file_name = format!("{}.rs", proto_name.replace(".proto", "")); + + let mut output = String::new(); + output.push_str("use crate::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator};\n"); + output.push_str("use std::str;\n\n"); + + for dep_res in file_proto.dependency() { + let (dep_data, _) = dep_res.expect("Failed to iterate dependency"); + let dep_name = str::from_utf8(dep_data).expect("Dependency name invalid utf8"); + let dep_mod_path = dep_name.replace(".proto", "").replace('/', "::"); + output.push_str(&format!("use crate::{};\n", dep_mod_path)); + } + output.push_str("\n"); // Enums for enum_res in file_proto.enum_type() { @@ -293,7 +317,54 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String { " pub fn finish(self) -> Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n" )); } + generated_files.push((rust_file_name, output)); } - output + if !generate_mod_files { + return generated_files; + } + + let mut all_paths: Vec = generated_files.iter().map(|(p, _)| p.clone()).collect(); + all_paths.sort(); + + let mut mod_files: HashMap> = HashMap::new(); + for path in &all_paths { + let parts: Vec<&str> = path.split('/').collect(); + let mut current_dir = String::new(); + for i in 0..parts.len() - 1 { + if !current_dir.is_empty() { + current_dir.push('/'); + } + current_dir.push_str(parts[i]); + let mod_path = format!("{}/mod.rs", current_dir); + let sub_mod = parts[i + 1].replace(".rs", ""); + mod_files.entry(mod_path).or_default().insert(sub_mod); + } + } + + let mut root_mods = HashSet::new(); + for path in &all_paths { + let parts: Vec<&str> = path.split('/').collect(); + root_mods.insert(parts[0].replace(".rs", "")); + } + + let mut root_mod_content = String::new(); + let mut sorted_root_mods: Vec<_> = root_mods.into_iter().collect(); + sorted_root_mods.sort(); + for m in sorted_root_mods { + root_mod_content.push_str(&format!("pub mod {};\n", m)); + } + generated_files.push(("mod.rs".to_string(), root_mod_content)); + + for (mod_path, sub_mods) in mod_files { + let mut content = String::new(); + let mut sorted_subs: Vec<_> = sub_mods.into_iter().collect(); + sorted_subs.sort(); + for sub in sorted_subs { + content.push_str(&format!("pub mod {};\n", sub)); + } + generated_files.push((mod_path, content)); + } + + generated_files }