Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fa4d8cca83 | |||
| 956993d1d0 | |||
| b2c5639338 | |||
| 33f3e58f74 |
@@ -0,0 +1,3 @@
|
||||
[submodule "grpc_bench"]
|
||||
path = grpc_bench
|
||||
url = https://github.com/Lesnyrumcajs/grpc_bench.git
|
||||
Generated
+7
@@ -832,6 +832,13 @@ version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084"
|
||||
|
||||
[[package]]
|
||||
name = "no_std_test"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"roto-runtime",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
|
||||
@@ -6,8 +6,15 @@ members = [
|
||||
"benches",
|
||||
"roto-tonic",
|
||||
"examples/hello_world",
|
||||
"examples/no_std_test",
|
||||
]
|
||||
|
||||
exclude = [
|
||||
"test_gen_project"
|
||||
]
|
||||
|
||||
[profile.dev]
|
||||
panic = "abort"
|
||||
|
||||
[profile.release]
|
||||
panic = "abort"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use clap::Parser;
|
||||
use roto_codegen::generator::generate_rust_code;
|
||||
use roto_codegen::generator::generate_protobuf_code;
|
||||
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
@@ -29,7 +29,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let data = fs::read(&args.input)?;
|
||||
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
|
||||
|
||||
let files = generate_rust_code(&set, args.files.as_deref(), true);
|
||||
let files = generate_protobuf_code(&set, args.files.as_deref(), true);
|
||||
|
||||
for (filename, content) in files {
|
||||
let path = args.output.join(filename);
|
||||
@@ -0,0 +1,42 @@
|
||||
use clap::Parser;
|
||||
use roto_codegen::generator::generate_service_code;
|
||||
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(
|
||||
author,
|
||||
version,
|
||||
about = "Generates Rust gRPC service code from a protobuf descriptor set"
|
||||
)]
|
||||
struct Args {
|
||||
/// Path to the descriptor set file (.desc)
|
||||
#[arg(short, long)]
|
||||
input: PathBuf,
|
||||
|
||||
/// Path to the output directory
|
||||
#[arg(short, long)]
|
||||
output: PathBuf,
|
||||
|
||||
/// Files to generate. If omitted, all files are generated.
|
||||
#[arg(short, long, value_delimiter = ',')]
|
||||
files: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args = Args::parse();
|
||||
let data = fs::read(&args.input)?;
|
||||
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
|
||||
|
||||
let files = generate_service_code(&set, args.files.as_deref(), true);
|
||||
|
||||
for (filename, content) in files {
|
||||
let path = args.output.join(filename);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(path, content)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
+138
-53
@@ -6,6 +6,9 @@ use roto_runtime::ProtoAccessor;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::str;
|
||||
|
||||
const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse core::str;\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n";
|
||||
const SERVICE_IMPORTS: &str = "use tonic::{Request, Response, Status};\nuse tokio_stream::Stream;\nuse std::pin::Pin;\nuse std::sync::Arc;\nuse std::task::{Context, Poll};\nuse std::future::Future;\nuse tonic::body::BoxBody;\nuse tower::Service;\nuse futures_util::StreamExt;\nuse http_body_util::BodyExt;\nuse http_body::Body;\nuse crate::{BufferPool, StatusBody};\n";
|
||||
|
||||
pub fn to_pascal_case(s: &str) -> String {
|
||||
s.split('_')
|
||||
.map(|word| {
|
||||
@@ -437,18 +440,18 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
|
||||
|
||||
output.push_str(&format!(" pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n"));
|
||||
|
||||
output.push_str(&format!("pub struct Owned{} {{\n", msg_name));
|
||||
output.push_str(&format!("#[cfg(feature = \"alloc\")]\npub struct Owned{} {{\n", msg_name));
|
||||
output.push_str(" pub data: bytes::Bytes,\n");
|
||||
output.push_str("}\n\n");
|
||||
|
||||
output.push_str(&format!("impl roto_runtime::RotoOwned for Owned{} {{\n", msg_name));
|
||||
output.push_str(&format!("#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoOwned for Owned{} {{\n", msg_name));
|
||||
output.push_str(&format!(" type Reader<'a> = {}<'a>;\n", msg_name));
|
||||
output.push_str(&format!(" fn reader(&self) -> {}<'_> {{\n", msg_name));
|
||||
output.push_str(&format!(" {}::new(&self.data).expect(\"failed to create reader\")\n", msg_name));
|
||||
output.push_str(" }\n");
|
||||
output.push_str("}\n\n");
|
||||
|
||||
output.push_str(&format!("impl roto_runtime::RotoMessage for Owned{} {{\n", msg_name));
|
||||
output.push_str(&format!("#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoMessage for Owned{} {{\n", msg_name));
|
||||
output.push_str(" fn decode(buf: bytes::Bytes) -> roto_runtime::Result<Self> {\n");
|
||||
output.push_str(&format!(" Ok(Owned{} {{ data: buf }})\n", msg_name));
|
||||
output.push_str(" }\n\n");
|
||||
@@ -524,11 +527,16 @@ fn map_type_to_rust_builder(field_type: i32) -> (String, String) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_rust_code(
|
||||
fn generate_files_common<F>(
|
||||
set: &FileDescriptorSet,
|
||||
files_to_generate: Option<&[String]>,
|
||||
generate_mod_files: bool,
|
||||
) -> Vec<(String, String)> {
|
||||
imports: &str,
|
||||
mut content_gen: F,
|
||||
) -> Vec<(String, String)>
|
||||
where
|
||||
F: FnMut(&FileDescriptorProto, &mut String),
|
||||
{
|
||||
let mut generated_files = Vec::new();
|
||||
|
||||
for file_res in set.file() {
|
||||
@@ -548,21 +556,7 @@ pub fn generate_rust_code(
|
||||
let mut output = String::new();
|
||||
output.push_str("// @generated by protoc-gen-roto — do not edit\n");
|
||||
output.push_str("#[allow(unused_imports)]\n\n");
|
||||
output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\n");
|
||||
output.push_str("use std::str;\n");
|
||||
output.push_str("use bytes::{Bytes, BytesMut, Buf, BufMut};\n");
|
||||
output.push_str("use tonic::{Request, Response, Status};\n");
|
||||
output.push_str("use tokio_stream::Stream;\n");
|
||||
output.push_str("use std::pin::Pin;\n");
|
||||
output.push_str("use std::sync::Arc;\n");
|
||||
output.push_str("use std::task::{Context, Poll};\n");
|
||||
output.push_str("use std::future::Future;\n");
|
||||
output.push_str("use tonic::body::BoxBody;\n");
|
||||
output.push_str("use tower::Service;\n");
|
||||
output.push_str("use futures_util::StreamExt;\n");
|
||||
output.push_str("use http_body_util::BodyExt;\n");
|
||||
output.push_str("use http_body::Body;\n");
|
||||
output.push_str("use crate::{BufferPool, StatusBody};\n\n");
|
||||
output.push_str(imports);
|
||||
|
||||
for dep_res in file_proto.dependency() {
|
||||
let (dep_data, _) = dep_res.expect("Failed to iterate dependency");
|
||||
@@ -572,33 +566,8 @@ pub fn generate_rust_code(
|
||||
}
|
||||
output.push_str("\n");
|
||||
|
||||
// Enums
|
||||
for enum_res in file_proto.enum_type() {
|
||||
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
|
||||
write_enum(
|
||||
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
|
||||
&mut output,
|
||||
);
|
||||
}
|
||||
content_gen(&file_proto, &mut output);
|
||||
|
||||
// Messages
|
||||
for msg_res in file_proto.message_type() {
|
||||
let (msg_data, _) = msg_res.expect("Failed to iterate message");
|
||||
write_message(
|
||||
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
|
||||
&mut output,
|
||||
);
|
||||
}
|
||||
|
||||
// Services
|
||||
for svc_res in file_proto.service() {
|
||||
let (svc_data, _) = svc_res.expect("Failed to iterate service");
|
||||
write_service(
|
||||
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
|
||||
file_proto.package().unwrap_or(""),
|
||||
&mut output,
|
||||
);
|
||||
}
|
||||
generated_files.push((rust_file_name, output));
|
||||
}
|
||||
|
||||
@@ -655,6 +624,121 @@ pub fn generate_rust_code(
|
||||
generated_files
|
||||
}
|
||||
|
||||
pub fn generate_protobuf_code(
|
||||
set: &FileDescriptorSet,
|
||||
files_to_generate: Option<&[String]>,
|
||||
generate_mod_files: bool,
|
||||
) -> Vec<(String, String)> {
|
||||
generate_files_common(
|
||||
set,
|
||||
files_to_generate,
|
||||
generate_mod_files,
|
||||
DATA_IMPORTS,
|
||||
|file_proto, output| {
|
||||
// Enums
|
||||
for enum_res in file_proto.enum_type() {
|
||||
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
|
||||
write_enum(
|
||||
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
|
||||
output,
|
||||
);
|
||||
}
|
||||
|
||||
// Messages
|
||||
for msg_res in file_proto.message_type() {
|
||||
let (msg_data, _) = msg_res.expect("Failed to iterate message");
|
||||
write_message(
|
||||
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
|
||||
output,
|
||||
);
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn generate_service_code(
|
||||
set: &FileDescriptorSet,
|
||||
files_to_generate: Option<&[String]>,
|
||||
generate_mod_files: bool,
|
||||
) -> Vec<(String, String)> {
|
||||
generate_files_common(
|
||||
set,
|
||||
files_to_generate,
|
||||
generate_mod_files,
|
||||
SERVICE_IMPORTS,
|
||||
|file_proto, output| {
|
||||
let package = file_proto.package().unwrap_or("").to_string();
|
||||
// Services
|
||||
for svc_res in file_proto.service() {
|
||||
let (svc_data, _) = svc_res.expect("Failed to iterate service");
|
||||
write_service(
|
||||
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
|
||||
&package,
|
||||
output,
|
||||
);
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn generate_rust_code(
|
||||
set: &FileDescriptorSet,
|
||||
files_to_generate: Option<&[String]>,
|
||||
generate_mod_files: bool,
|
||||
) -> Vec<(String, String)> {
|
||||
let protobuf_files = generate_protobuf_code(set, files_to_generate, false);
|
||||
let service_files = generate_service_code(set, files_to_generate, false);
|
||||
|
||||
let mut combined_files: HashMap<String, String> = HashMap::new();
|
||||
|
||||
for (filename, content) in protobuf_files {
|
||||
combined_files.insert(filename, content);
|
||||
}
|
||||
|
||||
for (filename, content) in service_files {
|
||||
if let Some(existing_content) = combined_files.get_mut(&filename) {
|
||||
let stripped = strip_boilerplate(&content);
|
||||
existing_content.push_str("\n");
|
||||
existing_content.push_str(&stripped);
|
||||
} else {
|
||||
combined_files.insert(filename, content);
|
||||
}
|
||||
}
|
||||
|
||||
let mut result = combined_files.into_iter().collect::<Vec<_>>();
|
||||
result.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
if generate_mod_files {
|
||||
let mods = generate_files_common(
|
||||
set,
|
||||
files_to_generate,
|
||||
true,
|
||||
"",
|
||||
|_, _| {},
|
||||
);
|
||||
for (filename, content) in mods {
|
||||
if filename == "mod.rs" || filename.contains("/mod.rs") {
|
||||
result.push((filename, content));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn strip_boilerplate(content: &str) -> String {
|
||||
// Find the first occurrence of a service definition or a trait
|
||||
// In our case, the services start after the dependency imports and a newline.
|
||||
if let Some(idx) = content.find("pub trait ") {
|
||||
return content[idx..].to_string();
|
||||
}
|
||||
if let Some(idx) = content.find("pub struct ") {
|
||||
// This might be a message, but generate_service_code only generates services (and their server structs)
|
||||
return content[idx..].to_string();
|
||||
}
|
||||
content.to_string()
|
||||
}
|
||||
|
||||
fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) {
|
||||
let svc_name = to_pascal_case(svc_proto.name().unwrap());
|
||||
output.push_str(&format!("#[tonic::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name));
|
||||
@@ -754,22 +838,23 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
|
||||
for method_res in svc_proto.method() {
|
||||
let (method_data, _) = method_res.expect("Failed to iterate method");
|
||||
let method_proto = MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto");
|
||||
let method_name = to_snake_case(method_proto.name().unwrap());
|
||||
let original_method_name = method_proto.name().unwrap().to_string();
|
||||
let method_name = to_snake_case(&original_method_name);
|
||||
let input_full_name = method_proto.input_type().unwrap();
|
||||
let input_type = input_full_name.split('.').last().unwrap();
|
||||
let input_owned = format!("Owned{}", input_type);
|
||||
let server_streaming = method_proto.server_streaming().unwrap_or(false);
|
||||
methods.push((method_name, input_owned, server_streaming));
|
||||
methods.push((original_method_name, method_name, input_owned, server_streaming));
|
||||
}
|
||||
|
||||
for (method_name, input_owned, server_streaming) in methods {
|
||||
for (original_method_name, method_name, input_owned, server_streaming) in methods {
|
||||
if server_streaming {
|
||||
// For streaming RPCs, we don't implement the server logic yet.
|
||||
// We just make it compile by returning a "not implemented" response.
|
||||
let full_path = if package.is_empty() {
|
||||
format!("/{}/{}", svc_proto.name().unwrap(), method_name)
|
||||
format!("/{}/{}", svc_proto.name().unwrap(), original_method_name)
|
||||
} else {
|
||||
format!("/{}.{}/{}", package, svc_proto.name().unwrap(), method_name)
|
||||
format!("/{}.{}/{}", package, svc_proto.name().unwrap(), original_method_name)
|
||||
};
|
||||
output.push_str(&format!(" if path == \"{}\" {{\n", full_path));
|
||||
output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n");
|
||||
@@ -778,9 +863,9 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
|
||||
continue;
|
||||
}
|
||||
let full_path = if package.is_empty() {
|
||||
format!("/{}/{}", svc_proto.name().unwrap(), method_name)
|
||||
format!("/{}/{}", svc_proto.name().unwrap(), original_method_name)
|
||||
} else {
|
||||
format!("/{}.{}/{}", package, svc_proto.name().unwrap(), method_name)
|
||||
format!("/{}.{}/{}", package, svc_proto.name().unwrap(), original_method_name)
|
||||
};
|
||||
output.push_str(&format!(" if path == \"{}\" {{\n", full_path));
|
||||
output.push_str(&format!(" let request_msg = match {}::decode(payload) {{\n", input_owned));
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
[package]
|
||||
name = "no_std_test"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
roto-runtime = { path = "../../runtime", default-features = false }
|
||||
|
||||
[profile.dev]
|
||||
panic = "abort"
|
||||
|
||||
[profile.release]
|
||||
panic = "abort"
|
||||
@@ -0,0 +1,16 @@
|
||||
#![no_std]
|
||||
#![no_main]
|
||||
|
||||
use roto_runtime::ProtoAccessor;
|
||||
|
||||
#[panic_handler]
|
||||
fn panic(_info: &core::panic::PanicInfo) -> ! {
|
||||
loop {}
|
||||
}
|
||||
|
||||
#[unsafe(no_mangle)]
|
||||
pub extern "C" fn _start() -> ! {
|
||||
let _data = [0u8; 0];
|
||||
let _ = ProtoAccessor::new(&_data);
|
||||
loop {}
|
||||
}
|
||||
Submodule
+1
Submodule grpc_bench added at c645de5855
@@ -463,7 +463,6 @@ impl Service<http::Request<BoxBody>> for InteropServiceServer {
|
||||
let payload = bytes_vec.slice(5..);
|
||||
let mut routed = false;
|
||||
|
||||
|
||||
if path == "/interop.InteropService/UnaryCall" {
|
||||
let request_msg = match OwnedUnaryRequest::decode(payload) {
|
||||
Ok(msg) => msg,
|
||||
@@ -495,7 +494,7 @@ impl Service<http::Request<BoxBody>> for InteropServiceServer {
|
||||
routed = true;
|
||||
return Ok(http::Response::builder().status(200).header("content-type", "application/grpc").body(res_body).unwrap());
|
||||
}
|
||||
if path == "/interop.InteropService/streaming_call" {
|
||||
if path == "/interop.InteropService/StreamingCall" {
|
||||
let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));
|
||||
return Ok(http::Response::builder().status(200).body(res_body).unwrap());
|
||||
}
|
||||
|
||||
+6
-1
@@ -4,4 +4,9 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
bytes = "1.7"
|
||||
bytes = { version = "1.7", default-features = false }
|
||||
|
||||
[features]
|
||||
default = ["std", "alloc"]
|
||||
std = []
|
||||
alloc = []
|
||||
|
||||
+9
-3
@@ -1,4 +1,9 @@
|
||||
use std::fmt;
|
||||
#![no_std]
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
extern crate std;
|
||||
|
||||
use core::fmt;
|
||||
use bytes::BufMut;
|
||||
|
||||
pub struct MapFieldIterator<'a> {
|
||||
@@ -51,9 +56,10 @@ impl fmt::Display for RotoError {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for RotoError {}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RotoError>;
|
||||
pub type Result<T> = core::result::Result<T, RotoError>;
|
||||
|
||||
pub trait RotoOwned {
|
||||
type Reader<'a> where Self: 'a;
|
||||
@@ -686,7 +692,7 @@ mod tests {
|
||||
.iter_repeated(18)
|
||||
.map(|r| {
|
||||
let (val, _) = r.expect("Failed to decode repeated string");
|
||||
std::str::from_utf8(val).expect("Invalid utf8")
|
||||
core::str::from_utf8(val).expect("Invalid utf8")
|
||||
})
|
||||
.collect();
|
||||
assert_eq!(repeated_strings, vec!["one", "two", "three"]);
|
||||
|
||||
Reference in New Issue
Block a user