Generate modules to make easy importing
This commit is contained in:
+13
-3
@@ -11,9 +11,13 @@ struct Args {
|
|||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
input: PathBuf,
|
input: PathBuf,
|
||||||
|
|
||||||
/// Path to the output Rust file (.rs)
|
/// Path to the output directory
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
output: PathBuf,
|
output: PathBuf,
|
||||||
|
|
||||||
|
/// Files to generate. If omitted, all files are generated.
|
||||||
|
#[arg(short, long, value_delimiter = ',')]
|
||||||
|
files: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
@@ -21,8 +25,14 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let data = fs::read(&args.input)?;
|
let data = fs::read(&args.input)?;
|
||||||
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
|
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
|
||||||
|
|
||||||
let output = generate_rust_code(&set);
|
let files = generate_rust_code(&set, args.files.as_deref(), true);
|
||||||
|
|
||||||
fs::write(&args.output, output)?;
|
for (filename, content) in files {
|
||||||
|
let path = args.output.join(filename);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
fs::write(path, content)?;
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
+14
-15
@@ -67,36 +67,35 @@ fn handle_request(request: &CodeGeneratorRequest) -> std::result::Result<Vec<u8>
|
|||||||
|
|
||||||
let set = FileDescriptorSet::new(&set_buf)?;
|
let set = FileDescriptorSet::new(&set_buf)?;
|
||||||
|
|
||||||
|
let files_to_generate: Vec<String> = request.file_to_generate()
|
||||||
|
.filter_map(|res| {
|
||||||
|
let (bytes, _) = res.ok()?;
|
||||||
|
std::str::from_utf8(bytes).ok().map(|s| s.to_string())
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
// Generate the Rust code
|
// Generate the Rust code
|
||||||
info!("Generating Rust code from descriptor set...");
|
info!("Generating Rust code from descriptor set...");
|
||||||
let generated_code = generate_rust_code(&set);
|
let generated_files = generate_rust_code(&set, Some(&files_to_generate), false);
|
||||||
|
|
||||||
// Determine the output filename
|
|
||||||
let mut output_filename = "roto_generated.rs".to_string();
|
|
||||||
if let Some(first_file) = request.file_to_generate().next() {
|
|
||||||
if let Ok((name_bytes, _)) = first_file {
|
|
||||||
if let Ok(name) = std::str::from_utf8(name_bytes) {
|
|
||||||
output_filename = format!("{}.rs", name.replace(".proto", ""));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Construct the response
|
// Construct the response
|
||||||
let mut response_buf = vec![0u8; 1024 * 1024 * 2]; // Allocate 2MB for response
|
let mut response_buf = vec![0u8; 1024 * 1024 * 2]; // Allocate 2MB for response
|
||||||
let mut resp_builder = CodeGeneratorResponse::builder(&mut response_buf);
|
let mut resp_builder = CodeGeneratorResponse::builder(&mut response_buf);
|
||||||
|
|
||||||
|
for (filename, content) in generated_files {
|
||||||
let mut file_buf = vec![0u8; 1024 * 1024 * 2];
|
let mut file_buf = vec![0u8; 1024 * 1024 * 2];
|
||||||
let final_file = ResponseFile::builder(&mut file_buf)
|
let final_file = ResponseFile::builder(&mut file_buf)
|
||||||
.name(&output_filename)?
|
.name(&filename)?
|
||||||
.content(&generated_code)?
|
.content(&content)?
|
||||||
.finish()
|
.finish()
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
error!("Failed to build ResponseFile: {:?}", e);
|
error!("Failed to build ResponseFile {}: {:?}", filename, e);
|
||||||
e
|
e
|
||||||
})?;
|
})?;
|
||||||
|
resp_builder = resp_builder.add_file(final_file)?;
|
||||||
|
}
|
||||||
|
|
||||||
let final_response_slice = resp_builder
|
let final_response_slice = resp_builder
|
||||||
.add_file(final_file)?
|
|
||||||
.finish()
|
.finish()
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
error!("Failed to finish CodeGeneratorResponse: {:?}", e);
|
error!("Failed to finish CodeGeneratorResponse: {:?}", e);
|
||||||
|
|||||||
+76
-5
@@ -3,6 +3,7 @@ use crate::proto_gen::google::protobuf::descriptor::{
|
|||||||
};
|
};
|
||||||
use crate::{ProtoAccessor, Result, RotoError};
|
use crate::{ProtoAccessor, Result, RotoError};
|
||||||
use std::str;
|
use std::str;
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
pub fn to_pascal_case(s: &str) -> String {
|
pub fn to_pascal_case(s: &str) -> String {
|
||||||
s.split('_')
|
s.split('_')
|
||||||
@@ -75,14 +76,37 @@ fn map_type_to_rust_builder(field_type: i32) -> (String, String) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn generate_rust_code(set: &FileDescriptorSet) -> String {
|
pub fn generate_rust_code(
|
||||||
let mut output = String::new();
|
set: &FileDescriptorSet,
|
||||||
output.push_str("use crate::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator};\n");
|
files_to_generate: Option<&[String]>,
|
||||||
output.push_str("use std::str;\n\n");
|
generate_mod_files: bool,
|
||||||
|
) -> Vec<(String, String)> {
|
||||||
|
let mut generated_files = Vec::new();
|
||||||
|
|
||||||
for file_res in set.file() {
|
for file_res in set.file() {
|
||||||
let (file_data, _) = file_res.expect("Failed to iterate file");
|
let (file_data, _) = file_res.expect("Failed to iterate file");
|
||||||
let file_proto = FileDescriptorProto::new(file_data).expect("Failed to parse FileDescriptorProto");
|
let file_proto = FileDescriptorProto::new(file_data).expect("Failed to parse FileDescriptorProto");
|
||||||
|
let proto_name = file_proto.name().expect("File proto name missing");
|
||||||
|
|
||||||
|
if let Some(filter) = files_to_generate {
|
||||||
|
if !filter.contains(&proto_name.to_string()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let rust_file_name = format!("{}.rs", proto_name.replace(".proto", ""));
|
||||||
|
|
||||||
|
let mut output = String::new();
|
||||||
|
output.push_str("use crate::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator};\n");
|
||||||
|
output.push_str("use std::str;\n\n");
|
||||||
|
|
||||||
|
for dep_res in file_proto.dependency() {
|
||||||
|
let (dep_data, _) = dep_res.expect("Failed to iterate dependency");
|
||||||
|
let dep_name = str::from_utf8(dep_data).expect("Dependency name invalid utf8");
|
||||||
|
let dep_mod_path = dep_name.replace(".proto", "").replace('/', "::");
|
||||||
|
output.push_str(&format!("use crate::{};\n", dep_mod_path));
|
||||||
|
}
|
||||||
|
output.push_str("\n");
|
||||||
|
|
||||||
// Enums
|
// Enums
|
||||||
for enum_res in file_proto.enum_type() {
|
for enum_res in file_proto.enum_type() {
|
||||||
@@ -293,7 +317,54 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String {
|
|||||||
" pub fn finish(self) -> Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n"
|
" pub fn finish(self) -> Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n"
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
generated_files.push((rust_file_name, output));
|
||||||
}
|
}
|
||||||
|
|
||||||
output
|
if !generate_mod_files {
|
||||||
|
return generated_files;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut all_paths: Vec<String> = generated_files.iter().map(|(p, _)| p.clone()).collect();
|
||||||
|
all_paths.sort();
|
||||||
|
|
||||||
|
let mut mod_files: HashMap<String, HashSet<String>> = HashMap::new();
|
||||||
|
for path in &all_paths {
|
||||||
|
let parts: Vec<&str> = path.split('/').collect();
|
||||||
|
let mut current_dir = String::new();
|
||||||
|
for i in 0..parts.len() - 1 {
|
||||||
|
if !current_dir.is_empty() {
|
||||||
|
current_dir.push('/');
|
||||||
|
}
|
||||||
|
current_dir.push_str(parts[i]);
|
||||||
|
let mod_path = format!("{}/mod.rs", current_dir);
|
||||||
|
let sub_mod = parts[i + 1].replace(".rs", "");
|
||||||
|
mod_files.entry(mod_path).or_default().insert(sub_mod);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut root_mods = HashSet::new();
|
||||||
|
for path in &all_paths {
|
||||||
|
let parts: Vec<&str> = path.split('/').collect();
|
||||||
|
root_mods.insert(parts[0].replace(".rs", ""));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut root_mod_content = String::new();
|
||||||
|
let mut sorted_root_mods: Vec<_> = root_mods.into_iter().collect();
|
||||||
|
sorted_root_mods.sort();
|
||||||
|
for m in sorted_root_mods {
|
||||||
|
root_mod_content.push_str(&format!("pub mod {};\n", m));
|
||||||
|
}
|
||||||
|
generated_files.push(("mod.rs".to_string(), root_mod_content));
|
||||||
|
|
||||||
|
for (mod_path, sub_mods) in mod_files {
|
||||||
|
let mut content = String::new();
|
||||||
|
let mut sorted_subs: Vec<_> = sub_mods.into_iter().collect();
|
||||||
|
sorted_subs.sort();
|
||||||
|
for sub in sorted_subs {
|
||||||
|
content.push_str(&format!("pub mod {};\n", sub));
|
||||||
|
}
|
||||||
|
generated_files.push((mod_path, content));
|
||||||
|
}
|
||||||
|
|
||||||
|
generated_files
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user