diff --git a/codegen/src/bin/protobuf-generator.rs b/codegen/src/bin/protobuf-generator.rs new file mode 100644 index 0000000..9f8b8f0 --- /dev/null +++ b/codegen/src/bin/protobuf-generator.rs @@ -0,0 +1,42 @@ +use clap::Parser; +use roto_codegen::generator::generate_protobuf_code; +use roto_codegen::google::protobuf::descriptor::FileDescriptorSet; +use std::fs; +use std::path::PathBuf; + +#[derive(Parser)] +#[command( + author, + version, + about = "Generates Rust accessor and builder code from a protobuf descriptor set" +)] +struct Args { + /// Path to the descriptor set file (.desc) + #[arg(short, long)] + input: PathBuf, + + /// Path to the output directory + #[arg(short, long)] + output: PathBuf, + + /// Files to generate. If omitted, all files are generated. + #[arg(short, long, value_delimiter = ',')] + files: Option>, +} + +fn main() -> Result<(), Box> { + let args = Args::parse(); + let data = fs::read(&args.input)?; + let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet"); + + let files = generate_protobuf_code(&set, args.files.as_deref(), true); + + for (filename, content) in files { + let path = args.output.join(filename); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(path, content)?; + } + Ok(()) +} diff --git a/codegen/src/bin/service-generator.rs b/codegen/src/bin/service-generator.rs new file mode 100644 index 0000000..828d989 --- /dev/null +++ b/codegen/src/bin/service-generator.rs @@ -0,0 +1,42 @@ +use clap::Parser; +use roto_codegen::generator::generate_service_code; +use roto_codegen::google::protobuf::descriptor::FileDescriptorSet; +use std::fs; +use std::path::PathBuf; + +#[derive(Parser)] +#[command( + author, + version, + about = "Generates Rust gRPC service code from a protobuf descriptor set" +)] +struct Args { + /// Path to the descriptor set file (.desc) + #[arg(short, long)] + input: PathBuf, + + /// Path to the output directory + #[arg(short, long)] + output: PathBuf, + + /// Files to generate. If omitted, all files are generated. + #[arg(short, long, value_delimiter = ',')] + files: Option>, +} + +fn main() -> Result<(), Box> { + let args = Args::parse(); + let data = fs::read(&args.input)?; + let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet"); + + let files = generate_service_code(&set, args.files.as_deref(), true); + + for (filename, content) in files { + let path = args.output.join(filename); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(path, content)?; + } + Ok(()) +} diff --git a/codegen/src/generator.rs b/codegen/src/generator.rs index 764bc12..c9e2cf5 100644 --- a/codegen/src/generator.rs +++ b/codegen/src/generator.rs @@ -6,6 +6,9 @@ use roto_runtime::ProtoAccessor; use std::collections::{HashMap, HashSet}; use std::str; +const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse std::str;\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n"; +const SERVICE_IMPORTS: &str = "use tonic::{Request, Response, Status};\nuse tokio_stream::Stream;\nuse std::pin::Pin;\nuse std::sync::Arc;\nuse std::task::{Context, Poll};\nuse std::future::Future;\nuse tonic::body::BoxBody;\nuse tower::Service;\nuse futures_util::StreamExt;\nuse http_body_util::BodyExt;\nuse http_body::Body;\nuse crate::{BufferPool, StatusBody};\n"; + pub fn to_pascal_case(s: &str) -> String { s.split('_') .map(|word| { @@ -524,11 +527,16 @@ fn map_type_to_rust_builder(field_type: i32) -> (String, String) { } } -pub fn generate_rust_code( +fn generate_files_common( set: &FileDescriptorSet, files_to_generate: Option<&[String]>, generate_mod_files: bool, -) -> Vec<(String, String)> { + imports: &str, + mut content_gen: F, +) -> Vec<(String, String)> +where + F: FnMut(&FileDescriptorProto, &mut String), +{ let mut generated_files = Vec::new(); for file_res in set.file() { @@ -548,21 +556,7 @@ pub fn generate_rust_code( let mut output = String::new(); output.push_str("// @generated by protoc-gen-roto — do not edit\n"); output.push_str("#[allow(unused_imports)]\n\n"); - output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\n"); - output.push_str("use std::str;\n"); - output.push_str("use bytes::{Bytes, BytesMut, Buf, BufMut};\n"); - output.push_str("use tonic::{Request, Response, Status};\n"); - output.push_str("use tokio_stream::Stream;\n"); - output.push_str("use std::pin::Pin;\n"); - output.push_str("use std::sync::Arc;\n"); - output.push_str("use std::task::{Context, Poll};\n"); - output.push_str("use std::future::Future;\n"); - output.push_str("use tonic::body::BoxBody;\n"); - output.push_str("use tower::Service;\n"); - output.push_str("use futures_util::StreamExt;\n"); - output.push_str("use http_body_util::BodyExt;\n"); - output.push_str("use http_body::Body;\n"); - output.push_str("use crate::{BufferPool, StatusBody};\n\n"); + output.push_str(imports); for dep_res in file_proto.dependency() { let (dep_data, _) = dep_res.expect("Failed to iterate dependency"); @@ -572,33 +566,8 @@ pub fn generate_rust_code( } output.push_str("\n"); - // Enums - for enum_res in file_proto.enum_type() { - let (enum_data, _) = enum_res.expect("Failed to iterate enum"); - write_enum( - &EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"), - &mut output, - ); - } + content_gen(&file_proto, &mut output); - // Messages - for msg_res in file_proto.message_type() { - let (msg_data, _) = msg_res.expect("Failed to iterate message"); - write_message( - &DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"), - &mut output, - ); - } - - // Services - for svc_res in file_proto.service() { - let (svc_data, _) = svc_res.expect("Failed to iterate service"); - write_service( - &ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"), - file_proto.package().unwrap_or(""), - &mut output, - ); - } generated_files.push((rust_file_name, output)); } @@ -655,6 +624,121 @@ pub fn generate_rust_code( generated_files } +pub fn generate_protobuf_code( + set: &FileDescriptorSet, + files_to_generate: Option<&[String]>, + generate_mod_files: bool, +) -> Vec<(String, String)> { + generate_files_common( + set, + files_to_generate, + generate_mod_files, + DATA_IMPORTS, + |file_proto, output| { + // Enums + for enum_res in file_proto.enum_type() { + let (enum_data, _) = enum_res.expect("Failed to iterate enum"); + write_enum( + &EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"), + output, + ); + } + + // Messages + for msg_res in file_proto.message_type() { + let (msg_data, _) = msg_res.expect("Failed to iterate message"); + write_message( + &DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"), + output, + ); + } + }, + ) +} + +pub fn generate_service_code( + set: &FileDescriptorSet, + files_to_generate: Option<&[String]>, + generate_mod_files: bool, +) -> Vec<(String, String)> { + generate_files_common( + set, + files_to_generate, + generate_mod_files, + SERVICE_IMPORTS, + |file_proto, output| { + let package = file_proto.package().unwrap_or("").to_string(); + // Services + for svc_res in file_proto.service() { + let (svc_data, _) = svc_res.expect("Failed to iterate service"); + write_service( + &ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"), + &package, + output, + ); + } + }, + ) +} + +pub fn generate_rust_code( + set: &FileDescriptorSet, + files_to_generate: Option<&[String]>, + generate_mod_files: bool, +) -> Vec<(String, String)> { + let protobuf_files = generate_protobuf_code(set, files_to_generate, false); + let service_files = generate_service_code(set, files_to_generate, false); + + let mut combined_files: HashMap = HashMap::new(); + + for (filename, content) in protobuf_files { + combined_files.insert(filename, content); + } + + for (filename, content) in service_files { + if let Some(existing_content) = combined_files.get_mut(&filename) { + let stripped = strip_boilerplate(&content); + existing_content.push_str("\n"); + existing_content.push_str(&stripped); + } else { + combined_files.insert(filename, content); + } + } + + let mut result = combined_files.into_iter().collect::>(); + result.sort_by(|a, b| a.0.cmp(&b.0)); + + if generate_mod_files { + let mods = generate_files_common( + set, + files_to_generate, + true, + "", + |_, _| {}, + ); + for (filename, content) in mods { + if filename == "mod.rs" || filename.contains("/mod.rs") { + result.push((filename, content)); + } + } + } + + result +} + +fn strip_boilerplate(content: &str) -> String { + // Find the first occurrence of a service definition or a trait + // In our case, the services start after the dependency imports and a newline. + if let Some(idx) = content.find("pub trait ") { + return content[idx..].to_string(); + } + if let Some(idx) = content.find("pub struct ") { + // This might be a message, but generate_service_code only generates services (and their server structs) + return content[idx..].to_string(); + } + content.to_string() +} + fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) { let svc_name = to_pascal_case(svc_proto.name().unwrap()); output.push_str(&format!("#[tonic::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name));