Compare commits

...

4 Commits

Author SHA1 Message Date
charles fa4d8cca83 Support no_std in roto-runtime
Add #![no_std] to the runtime crate and introduce optional std and
alloc features. Update the code generator to be compatible and add a
no_std_test example. Remove the generator binary.
2026-05-17 19:55:44 -07:00
charles 956993d1d0 Split service and proto gen 2026-05-17 18:53:00 -07:00
charles b2c5639338 add: grpc_bench 2026-05-17 16:45:44 -07:00
charles 33f3e58f74 Use original method names for gRPC paths
Stop converting method names to snake_case when generating the gRPC
service paths to maintain compatibility with protobuf definitions.
2026-05-17 10:44:07 -07:00
12 changed files with 245 additions and 61 deletions
+3
View File
@@ -0,0 +1,3 @@
[submodule "grpc_bench"]
path = grpc_bench
url = https://github.com/Lesnyrumcajs/grpc_bench.git
Generated
+7
View File
@@ -832,6 +832,13 @@ version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084"
[[package]]
name = "no_std_test"
version = "0.1.0"
dependencies = [
"roto-runtime",
]
[[package]] [[package]]
name = "num-traits" name = "num-traits"
version = "0.2.19" version = "0.2.19"
+7
View File
@@ -6,8 +6,15 @@ members = [
"benches", "benches",
"roto-tonic", "roto-tonic",
"examples/hello_world", "examples/hello_world",
"examples/no_std_test",
] ]
exclude = [ exclude = [
"test_gen_project" "test_gen_project"
] ]
[profile.dev]
panic = "abort"
[profile.release]
panic = "abort"
@@ -1,5 +1,5 @@
use clap::Parser; use clap::Parser;
use roto_codegen::generator::generate_rust_code; use roto_codegen::generator::generate_protobuf_code;
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet; use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
use std::fs; use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
@@ -29,7 +29,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let data = fs::read(&args.input)?; let data = fs::read(&args.input)?;
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet"); let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
let files = generate_rust_code(&set, args.files.as_deref(), true); let files = generate_protobuf_code(&set, args.files.as_deref(), true);
for (filename, content) in files { for (filename, content) in files {
let path = args.output.join(filename); let path = args.output.join(filename);
+42
View File
@@ -0,0 +1,42 @@
use clap::Parser;
use roto_codegen::generator::generate_service_code;
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
use std::fs;
use std::path::PathBuf;
#[derive(Parser)]
#[command(
author,
version,
about = "Generates Rust gRPC service code from a protobuf descriptor set"
)]
struct Args {
/// Path to the descriptor set file (.desc)
#[arg(short, long)]
input: PathBuf,
/// Path to the output directory
#[arg(short, long)]
output: PathBuf,
/// Files to generate. If omitted, all files are generated.
#[arg(short, long, value_delimiter = ',')]
files: Option<Vec<String>>,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
let data = fs::read(&args.input)?;
let set = FileDescriptorSet::new(&data).expect("Failed to parse FileDescriptorSet");
let files = generate_service_code(&set, args.files.as_deref(), true);
for (filename, content) in files {
let path = args.output.join(filename);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(path, content)?;
}
Ok(())
}
+138 -53
View File
@@ -6,6 +6,9 @@ use roto_runtime::ProtoAccessor;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::str; use std::str;
const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse core::str;\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n";
const SERVICE_IMPORTS: &str = "use tonic::{Request, Response, Status};\nuse tokio_stream::Stream;\nuse std::pin::Pin;\nuse std::sync::Arc;\nuse std::task::{Context, Poll};\nuse std::future::Future;\nuse tonic::body::BoxBody;\nuse tower::Service;\nuse futures_util::StreamExt;\nuse http_body_util::BodyExt;\nuse http_body::Body;\nuse crate::{BufferPool, StatusBody};\n";
pub fn to_pascal_case(s: &str) -> String { pub fn to_pascal_case(s: &str) -> String {
s.split('_') s.split('_')
.map(|word| { .map(|word| {
@@ -437,18 +440,18 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
output.push_str(&format!(" pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n")); output.push_str(&format!(" pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n"));
output.push_str(&format!("pub struct Owned{} {{\n", msg_name)); output.push_str(&format!("#[cfg(feature = \"alloc\")]\npub struct Owned{} {{\n", msg_name));
output.push_str(" pub data: bytes::Bytes,\n"); output.push_str(" pub data: bytes::Bytes,\n");
output.push_str("}\n\n"); output.push_str("}\n\n");
output.push_str(&format!("impl roto_runtime::RotoOwned for Owned{} {{\n", msg_name)); output.push_str(&format!("#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoOwned for Owned{} {{\n", msg_name));
output.push_str(&format!(" type Reader<'a> = {}<'a>;\n", msg_name)); output.push_str(&format!(" type Reader<'a> = {}<'a>;\n", msg_name));
output.push_str(&format!(" fn reader(&self) -> {}<'_> {{\n", msg_name)); output.push_str(&format!(" fn reader(&self) -> {}<'_> {{\n", msg_name));
output.push_str(&format!(" {}::new(&self.data).expect(\"failed to create reader\")\n", msg_name)); output.push_str(&format!(" {}::new(&self.data).expect(\"failed to create reader\")\n", msg_name));
output.push_str(" }\n"); output.push_str(" }\n");
output.push_str("}\n\n"); output.push_str("}\n\n");
output.push_str(&format!("impl roto_runtime::RotoMessage for Owned{} {{\n", msg_name)); output.push_str(&format!("#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoMessage for Owned{} {{\n", msg_name));
output.push_str(" fn decode(buf: bytes::Bytes) -> roto_runtime::Result<Self> {\n"); output.push_str(" fn decode(buf: bytes::Bytes) -> roto_runtime::Result<Self> {\n");
output.push_str(&format!(" Ok(Owned{} {{ data: buf }})\n", msg_name)); output.push_str(&format!(" Ok(Owned{} {{ data: buf }})\n", msg_name));
output.push_str(" }\n\n"); output.push_str(" }\n\n");
@@ -524,11 +527,16 @@ fn map_type_to_rust_builder(field_type: i32) -> (String, String) {
} }
} }
pub fn generate_rust_code( fn generate_files_common<F>(
set: &FileDescriptorSet, set: &FileDescriptorSet,
files_to_generate: Option<&[String]>, files_to_generate: Option<&[String]>,
generate_mod_files: bool, generate_mod_files: bool,
) -> Vec<(String, String)> { imports: &str,
mut content_gen: F,
) -> Vec<(String, String)>
where
F: FnMut(&FileDescriptorProto, &mut String),
{
let mut generated_files = Vec::new(); let mut generated_files = Vec::new();
for file_res in set.file() { for file_res in set.file() {
@@ -548,21 +556,7 @@ pub fn generate_rust_code(
let mut output = String::new(); let mut output = String::new();
output.push_str("// @generated by protoc-gen-roto — do not edit\n"); output.push_str("// @generated by protoc-gen-roto — do not edit\n");
output.push_str("#[allow(unused_imports)]\n\n"); output.push_str("#[allow(unused_imports)]\n\n");
output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\n"); output.push_str(imports);
output.push_str("use std::str;\n");
output.push_str("use bytes::{Bytes, BytesMut, Buf, BufMut};\n");
output.push_str("use tonic::{Request, Response, Status};\n");
output.push_str("use tokio_stream::Stream;\n");
output.push_str("use std::pin::Pin;\n");
output.push_str("use std::sync::Arc;\n");
output.push_str("use std::task::{Context, Poll};\n");
output.push_str("use std::future::Future;\n");
output.push_str("use tonic::body::BoxBody;\n");
output.push_str("use tower::Service;\n");
output.push_str("use futures_util::StreamExt;\n");
output.push_str("use http_body_util::BodyExt;\n");
output.push_str("use http_body::Body;\n");
output.push_str("use crate::{BufferPool, StatusBody};\n\n");
for dep_res in file_proto.dependency() { for dep_res in file_proto.dependency() {
let (dep_data, _) = dep_res.expect("Failed to iterate dependency"); let (dep_data, _) = dep_res.expect("Failed to iterate dependency");
@@ -572,33 +566,8 @@ pub fn generate_rust_code(
} }
output.push_str("\n"); output.push_str("\n");
// Enums content_gen(&file_proto, &mut output);
for enum_res in file_proto.enum_type() {
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
write_enum(
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
&mut output,
);
}
// Messages
for msg_res in file_proto.message_type() {
let (msg_data, _) = msg_res.expect("Failed to iterate message");
write_message(
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
&mut output,
);
}
// Services
for svc_res in file_proto.service() {
let (svc_data, _) = svc_res.expect("Failed to iterate service");
write_service(
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
file_proto.package().unwrap_or(""),
&mut output,
);
}
generated_files.push((rust_file_name, output)); generated_files.push((rust_file_name, output));
} }
@@ -655,6 +624,121 @@ pub fn generate_rust_code(
generated_files generated_files
} }
pub fn generate_protobuf_code(
set: &FileDescriptorSet,
files_to_generate: Option<&[String]>,
generate_mod_files: bool,
) -> Vec<(String, String)> {
generate_files_common(
set,
files_to_generate,
generate_mod_files,
DATA_IMPORTS,
|file_proto, output| {
// Enums
for enum_res in file_proto.enum_type() {
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
write_enum(
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
output,
);
}
// Messages
for msg_res in file_proto.message_type() {
let (msg_data, _) = msg_res.expect("Failed to iterate message");
write_message(
&DescriptorProto::new(msg_data).expect("Failed to parse DescriptorProto"),
output,
);
}
},
)
}
pub fn generate_service_code(
set: &FileDescriptorSet,
files_to_generate: Option<&[String]>,
generate_mod_files: bool,
) -> Vec<(String, String)> {
generate_files_common(
set,
files_to_generate,
generate_mod_files,
SERVICE_IMPORTS,
|file_proto, output| {
let package = file_proto.package().unwrap_or("").to_string();
// Services
for svc_res in file_proto.service() {
let (svc_data, _) = svc_res.expect("Failed to iterate service");
write_service(
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
&package,
output,
);
}
},
)
}
pub fn generate_rust_code(
set: &FileDescriptorSet,
files_to_generate: Option<&[String]>,
generate_mod_files: bool,
) -> Vec<(String, String)> {
let protobuf_files = generate_protobuf_code(set, files_to_generate, false);
let service_files = generate_service_code(set, files_to_generate, false);
let mut combined_files: HashMap<String, String> = HashMap::new();
for (filename, content) in protobuf_files {
combined_files.insert(filename, content);
}
for (filename, content) in service_files {
if let Some(existing_content) = combined_files.get_mut(&filename) {
let stripped = strip_boilerplate(&content);
existing_content.push_str("\n");
existing_content.push_str(&stripped);
} else {
combined_files.insert(filename, content);
}
}
let mut result = combined_files.into_iter().collect::<Vec<_>>();
result.sort_by(|a, b| a.0.cmp(&b.0));
if generate_mod_files {
let mods = generate_files_common(
set,
files_to_generate,
true,
"",
|_, _| {},
);
for (filename, content) in mods {
if filename == "mod.rs" || filename.contains("/mod.rs") {
result.push((filename, content));
}
}
}
result
}
fn strip_boilerplate(content: &str) -> String {
// Find the first occurrence of a service definition or a trait
// In our case, the services start after the dependency imports and a newline.
if let Some(idx) = content.find("pub trait ") {
return content[idx..].to_string();
}
if let Some(idx) = content.find("pub struct ") {
// This might be a message, but generate_service_code only generates services (and their server structs)
return content[idx..].to_string();
}
content.to_string()
}
fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) { fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) {
let svc_name = to_pascal_case(svc_proto.name().unwrap()); let svc_name = to_pascal_case(svc_proto.name().unwrap());
output.push_str(&format!("#[tonic::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name)); output.push_str(&format!("#[tonic::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name));
@@ -754,22 +838,23 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
for method_res in svc_proto.method() { for method_res in svc_proto.method() {
let (method_data, _) = method_res.expect("Failed to iterate method"); let (method_data, _) = method_res.expect("Failed to iterate method");
let method_proto = MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto"); let method_proto = MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto");
let method_name = to_snake_case(method_proto.name().unwrap()); let original_method_name = method_proto.name().unwrap().to_string();
let method_name = to_snake_case(&original_method_name);
let input_full_name = method_proto.input_type().unwrap(); let input_full_name = method_proto.input_type().unwrap();
let input_type = input_full_name.split('.').last().unwrap(); let input_type = input_full_name.split('.').last().unwrap();
let input_owned = format!("Owned{}", input_type); let input_owned = format!("Owned{}", input_type);
let server_streaming = method_proto.server_streaming().unwrap_or(false); let server_streaming = method_proto.server_streaming().unwrap_or(false);
methods.push((method_name, input_owned, server_streaming)); methods.push((original_method_name, method_name, input_owned, server_streaming));
} }
for (method_name, input_owned, server_streaming) in methods { for (original_method_name, method_name, input_owned, server_streaming) in methods {
if server_streaming { if server_streaming {
// For streaming RPCs, we don't implement the server logic yet. // For streaming RPCs, we don't implement the server logic yet.
// We just make it compile by returning a "not implemented" response. // We just make it compile by returning a "not implemented" response.
let full_path = if package.is_empty() { let full_path = if package.is_empty() {
format!("/{}/{}", svc_proto.name().unwrap(), method_name) format!("/{}/{}", svc_proto.name().unwrap(), original_method_name)
} else { } else {
format!("/{}.{}/{}", package, svc_proto.name().unwrap(), method_name) format!("/{}.{}/{}", package, svc_proto.name().unwrap(), original_method_name)
}; };
output.push_str(&format!(" if path == \"{}\" {{\n", full_path)); output.push_str(&format!(" if path == \"{}\" {{\n", full_path));
output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n");
@@ -778,9 +863,9 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
continue; continue;
} }
let full_path = if package.is_empty() { let full_path = if package.is_empty() {
format!("/{}/{}", svc_proto.name().unwrap(), method_name) format!("/{}/{}", svc_proto.name().unwrap(), original_method_name)
} else { } else {
format!("/{}.{}/{}", package, svc_proto.name().unwrap(), method_name) format!("/{}.{}/{}", package, svc_proto.name().unwrap(), original_method_name)
}; };
output.push_str(&format!(" if path == \"{}\" {{\n", full_path)); output.push_str(&format!(" if path == \"{}\" {{\n", full_path));
output.push_str(&format!(" let request_msg = match {}::decode(payload) {{\n", input_owned)); output.push_str(&format!(" let request_msg = match {}::decode(payload) {{\n", input_owned));
+13
View File
@@ -0,0 +1,13 @@
[package]
name = "no_std_test"
version = "0.1.0"
edition = "2024"
[dependencies]
roto-runtime = { path = "../../runtime", default-features = false }
[profile.dev]
panic = "abort"
[profile.release]
panic = "abort"
+16
View File
@@ -0,0 +1,16 @@
#![no_std]
#![no_main]
use roto_runtime::ProtoAccessor;
#[panic_handler]
fn panic(_info: &core::panic::PanicInfo) -> ! {
loop {}
}
#[unsafe(no_mangle)]
pub extern "C" fn _start() -> ! {
let _data = [0u8; 0];
let _ = ProtoAccessor::new(&_data);
loop {}
}
Submodule
+1
Submodule grpc_bench added at c645de5855
+1 -2
View File
@@ -463,7 +463,6 @@ impl Service<http::Request<BoxBody>> for InteropServiceServer {
let payload = bytes_vec.slice(5..); let payload = bytes_vec.slice(5..);
let mut routed = false; let mut routed = false;
if path == "/interop.InteropService/UnaryCall" { if path == "/interop.InteropService/UnaryCall" {
let request_msg = match OwnedUnaryRequest::decode(payload) { let request_msg = match OwnedUnaryRequest::decode(payload) {
Ok(msg) => msg, Ok(msg) => msg,
@@ -495,7 +494,7 @@ impl Service<http::Request<BoxBody>> for InteropServiceServer {
routed = true; routed = true;
return Ok(http::Response::builder().status(200).header("content-type", "application/grpc").body(res_body).unwrap()); return Ok(http::Response::builder().status(200).header("content-type", "application/grpc").body(res_body).unwrap());
} }
if path == "/interop.InteropService/streaming_call" { if path == "/interop.InteropService/StreamingCall" {
let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0)); let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));
return Ok(http::Response::builder().status(200).body(res_body).unwrap()); return Ok(http::Response::builder().status(200).body(res_body).unwrap());
} }
+6 -1
View File
@@ -4,4 +4,9 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
bytes = "1.7" bytes = { version = "1.7", default-features = false }
[features]
default = ["std", "alloc"]
std = []
alloc = []
+9 -3
View File
@@ -1,4 +1,9 @@
use std::fmt; #![no_std]
#[cfg(feature = "std")]
extern crate std;
use core::fmt;
use bytes::BufMut; use bytes::BufMut;
pub struct MapFieldIterator<'a> { pub struct MapFieldIterator<'a> {
@@ -51,9 +56,10 @@ impl fmt::Display for RotoError {
} }
} }
#[cfg(feature = "std")]
impl std::error::Error for RotoError {} impl std::error::Error for RotoError {}
pub type Result<T> = std::result::Result<T, RotoError>; pub type Result<T> = core::result::Result<T, RotoError>;
pub trait RotoOwned { pub trait RotoOwned {
type Reader<'a> where Self: 'a; type Reader<'a> where Self: 'a;
@@ -686,7 +692,7 @@ mod tests {
.iter_repeated(18) .iter_repeated(18)
.map(|r| { .map(|r| {
let (val, _) = r.expect("Failed to decode repeated string"); let (val, _) = r.expect("Failed to decode repeated string");
std::str::from_utf8(val).expect("Invalid utf8") core::str::from_utf8(val).expect("Invalid utf8")
}) })
.collect(); .collect();
assert_eq!(repeated_strings, vec!["one", "two", "three"]); assert_eq!(repeated_strings, vec!["one", "two", "three"]);