Split service and proto gen
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
use clap::Parser;
|
||||
use roto_codegen::generator::generate_protobuf_code;
|
||||
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(
|
||||
author,
|
||||
version,
|
||||
about = "Generates Rust accessor and builder code from a protobuf descriptor set"
|
||||
)]
|
||||
struct Args {
|
||||
/// Path to the descriptor set file (.desc)
|
||||
#[arg(short, long)]
|
||||
input: PathBuf,
|
||||
|
||||
/// Path to the output directory
|
||||
#[arg(short, long)]
|
||||
output: PathBuf,
|
||||
|
||||
/// Files to generate. If omitted, all files are generated.
|
||||
#[arg(short, long, value_delimiter = ',')]
|
||||
files: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args = Args::parse();
|
||||
let data = fs::read(&args.input)?;
|
||||
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
|
||||
|
||||
let files = generate_protobuf_code(&set, args.files.as_deref(), true);
|
||||
|
||||
for (filename, content) in files {
|
||||
let path = args.output.join(filename);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(path, content)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
use clap::Parser;
|
||||
use roto_codegen::generator::generate_service_code;
|
||||
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(
|
||||
author,
|
||||
version,
|
||||
about = "Generates Rust gRPC service code from a protobuf descriptor set"
|
||||
)]
|
||||
struct Args {
|
||||
/// Path to the descriptor set file (.desc)
|
||||
#[arg(short, long)]
|
||||
input: PathBuf,
|
||||
|
||||
/// Path to the output directory
|
||||
#[arg(short, long)]
|
||||
output: PathBuf,
|
||||
|
||||
/// Files to generate. If omitted, all files are generated.
|
||||
#[arg(short, long, value_delimiter = ',')]
|
||||
files: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args = Args::parse();
|
||||
let data = fs::read(&args.input)?;
|
||||
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
|
||||
|
||||
let files = generate_service_code(&set, args.files.as_deref(), true);
|
||||
|
||||
for (filename, content) in files {
|
||||
let path = args.output.join(filename);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(path, content)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
+127
-43
@@ -6,6 +6,9 @@ use roto_runtime::ProtoAccessor;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::str;
|
||||
|
||||
const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse std::str;\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n";
|
||||
const SERVICE_IMPORTS: &str = "use tonic::{Request, Response, Status};\nuse tokio_stream::Stream;\nuse std::pin::Pin;\nuse std::sync::Arc;\nuse std::task::{Context, Poll};\nuse std::future::Future;\nuse tonic::body::BoxBody;\nuse tower::Service;\nuse futures_util::StreamExt;\nuse http_body_util::BodyExt;\nuse http_body::Body;\nuse crate::{BufferPool, StatusBody};\n";
|
||||
|
||||
pub fn to_pascal_case(s: &str) -> String {
|
||||
s.split('_')
|
||||
.map(|word| {
|
||||
@@ -524,11 +527,16 @@ fn map_type_to_rust_builder(field_type: i32) -> (String, String) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_rust_code(
|
||||
fn generate_files_common<F>(
|
||||
set: &FileDescriptorSet,
|
||||
files_to_generate: Option<&[String]>,
|
||||
generate_mod_files: bool,
|
||||
) -> Vec<(String, String)> {
|
||||
imports: &str,
|
||||
mut content_gen: F,
|
||||
) -> Vec<(String, String)>
|
||||
where
|
||||
F: FnMut(&FileDescriptorProto, &mut String),
|
||||
{
|
||||
let mut generated_files = Vec::new();
|
||||
|
||||
for file_res in set.file() {
|
||||
@@ -548,21 +556,7 @@ pub fn generate_rust_code(
|
||||
let mut output = String::new();
|
||||
output.push_str("// @generated by protoc-gen-roto — do not edit\n");
|
||||
output.push_str("#[allow(unused_imports)]\n\n");
|
||||
output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\n");
|
||||
output.push_str("use std::str;\n");
|
||||
output.push_str("use bytes::{Bytes, BytesMut, Buf, BufMut};\n");
|
||||
output.push_str("use tonic::{Request, Response, Status};\n");
|
||||
output.push_str("use tokio_stream::Stream;\n");
|
||||
output.push_str("use std::pin::Pin;\n");
|
||||
output.push_str("use std::sync::Arc;\n");
|
||||
output.push_str("use std::task::{Context, Poll};\n");
|
||||
output.push_str("use std::future::Future;\n");
|
||||
output.push_str("use tonic::body::BoxBody;\n");
|
||||
output.push_str("use tower::Service;\n");
|
||||
output.push_str("use futures_util::StreamExt;\n");
|
||||
output.push_str("use http_body_util::BodyExt;\n");
|
||||
output.push_str("use http_body::Body;\n");
|
||||
output.push_str("use crate::{BufferPool, StatusBody};\n\n");
|
||||
output.push_str(imports);
|
||||
|
||||
for dep_res in file_proto.dependency() {
|
||||
let (dep_data, _) = dep_res.expect("Failed to iterate dependency");
|
||||
@@ -572,33 +566,8 @@ pub fn generate_rust_code(
|
||||
}
|
||||
output.push_str("\n");
|
||||
|
||||
// Enums
|
||||
for enum_res in file_proto.enum_type() {
|
||||
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
|
||||
write_enum(
|
||||
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
|
||||
&mut output,
|
||||
);
|
||||
}
|
||||
content_gen(&file_proto, &mut output);
|
||||
|
||||
// Messages
|
||||
for msg_res in file_proto.message_type() {
|
||||
let (msg_data, _) = msg_res.expect("Failed to iterate message");
|
||||
write_message(
|
||||
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
|
||||
&mut output,
|
||||
);
|
||||
}
|
||||
|
||||
// Services
|
||||
for svc_res in file_proto.service() {
|
||||
let (svc_data, _) = svc_res.expect("Failed to iterate service");
|
||||
write_service(
|
||||
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
|
||||
file_proto.package().unwrap_or(""),
|
||||
&mut output,
|
||||
);
|
||||
}
|
||||
generated_files.push((rust_file_name, output));
|
||||
}
|
||||
|
||||
@@ -655,6 +624,121 @@ pub fn generate_rust_code(
|
||||
generated_files
|
||||
}
|
||||
|
||||
pub fn generate_protobuf_code(
|
||||
set: &FileDescriptorSet,
|
||||
files_to_generate: Option<&[String]>,
|
||||
generate_mod_files: bool,
|
||||
) -> Vec<(String, String)> {
|
||||
generate_files_common(
|
||||
set,
|
||||
files_to_generate,
|
||||
generate_mod_files,
|
||||
DATA_IMPORTS,
|
||||
|file_proto, output| {
|
||||
// Enums
|
||||
for enum_res in file_proto.enum_type() {
|
||||
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
|
||||
write_enum(
|
||||
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
|
||||
output,
|
||||
);
|
||||
}
|
||||
|
||||
// Messages
|
||||
for msg_res in file_proto.message_type() {
|
||||
let (msg_data, _) = msg_res.expect("Failed to iterate message");
|
||||
write_message(
|
||||
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
|
||||
output,
|
||||
);
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn generate_service_code(
|
||||
set: &FileDescriptorSet,
|
||||
files_to_generate: Option<&[String]>,
|
||||
generate_mod_files: bool,
|
||||
) -> Vec<(String, String)> {
|
||||
generate_files_common(
|
||||
set,
|
||||
files_to_generate,
|
||||
generate_mod_files,
|
||||
SERVICE_IMPORTS,
|
||||
|file_proto, output| {
|
||||
let package = file_proto.package().unwrap_or("").to_string();
|
||||
// Services
|
||||
for svc_res in file_proto.service() {
|
||||
let (svc_data, _) = svc_res.expect("Failed to iterate service");
|
||||
write_service(
|
||||
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
|
||||
&package,
|
||||
output,
|
||||
);
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn generate_rust_code(
|
||||
set: &FileDescriptorSet,
|
||||
files_to_generate: Option<&[String]>,
|
||||
generate_mod_files: bool,
|
||||
) -> Vec<(String, String)> {
|
||||
let protobuf_files = generate_protobuf_code(set, files_to_generate, false);
|
||||
let service_files = generate_service_code(set, files_to_generate, false);
|
||||
|
||||
let mut combined_files: HashMap<String, String> = HashMap::new();
|
||||
|
||||
for (filename, content) in protobuf_files {
|
||||
combined_files.insert(filename, content);
|
||||
}
|
||||
|
||||
for (filename, content) in service_files {
|
||||
if let Some(existing_content) = combined_files.get_mut(&filename) {
|
||||
let stripped = strip_boilerplate(&content);
|
||||
existing_content.push_str("\n");
|
||||
existing_content.push_str(&stripped);
|
||||
} else {
|
||||
combined_files.insert(filename, content);
|
||||
}
|
||||
}
|
||||
|
||||
let mut result = combined_files.into_iter().collect::<Vec<_>>();
|
||||
result.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
if generate_mod_files {
|
||||
let mods = generate_files_common(
|
||||
set,
|
||||
files_to_generate,
|
||||
true,
|
||||
"",
|
||||
|_, _| {},
|
||||
);
|
||||
for (filename, content) in mods {
|
||||
if filename == "mod.rs" || filename.contains("/mod.rs") {
|
||||
result.push((filename, content));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn strip_boilerplate(content: &str) -> String {
|
||||
// Find the first occurrence of a service definition or a trait
|
||||
// In our case, the services start after the dependency imports and a newline.
|
||||
if let Some(idx) = content.find("pub trait ") {
|
||||
return content[idx..].to_string();
|
||||
}
|
||||
if let Some(idx) = content.find("pub struct ") {
|
||||
// This might be a message, but generate_service_code only generates services (and their server structs)
|
||||
return content[idx..].to_string();
|
||||
}
|
||||
content.to_string()
|
||||
}
|
||||
|
||||
fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) {
|
||||
let svc_name = to_pascal_case(svc_proto.name().unwrap());
|
||||
output.push_str(&format!("#[tonic::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name));
|
||||
|
||||
Reference in New Issue
Block a user