Introduce alloc feature for optional allocation

Wrap heap-allocated types and service generation in the `alloc` feature
flag to support environments without a memory allocator.
This commit is contained in:
2026-05-19 21:55:18 -07:00
parent 6910f11d69
commit 117cbf812b
7 changed files with 201 additions and 52 deletions
+117 -33
View File
@@ -1,13 +1,26 @@
use crate::google::protobuf::descriptor::{
DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
FileDescriptorSet, MessageOptions, MethodDescriptorProto, OneofDescriptorProto, ServiceDescriptorProto,
FileDescriptorSet, MessageOptions, MethodDescriptorProto, OneofDescriptorProto,
ServiceDescriptorProto,
};
use roto_runtime::ProtoAccessor;
use std::collections::{HashMap, HashSet};
use std::str;
const DATA_IMPORTS: &str = "use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\nuse core::str;\n#[cfg(feature = \"alloc\")]\nuse bytes::{Bytes, BytesMut, Buf, BufMut};\n";
const SERVICE_IMPORTS: &str = "use tonic::{Request, Response, Status};\nuse tokio_stream::Stream;\nuse std::pin::Pin;\nuse std::sync::Arc;\nuse std::task::{Context, Poll};\nuse std::future::Future;\nuse tonic::body::BoxBody;\nuse tower::Service;\nuse futures_util::StreamExt;\nuse http_body_util::BodyExt;\nuse http_body::Body;\nuse crate::{BufferPool, StatusBody};\n";
const SERVICE_IMPORTS: &str =
"#[cfg(feature = \"alloc\")]\nuse tonic::{Request, Response, Status};\n\
#[cfg(feature = \"alloc\")]\nuse tokio_stream::Stream;\n\
#[cfg(feature = \"alloc\")]\nuse std::pin::Pin;\n\
#[cfg(feature = \"alloc\")]\nuse std::sync::Arc;\n\
#[cfg(feature = \"alloc\")]\nuse std::task::{Context, Poll};\n\
#[cfg(feature = \"alloc\")]\nuse std::future::Future;\n\
#[cfg(feature = \"alloc\")]\nuse tonic::body::BoxBody;\n\
#[cfg(feature = \"alloc\")]\nuse tower::Service;\n\
#[cfg(feature = \"alloc\")]\nuse futures_util::StreamExt;\n\
#[cfg(feature = \"alloc\")]\nuse http_body_util::BodyExt;\n\
#[cfg(feature = \"alloc\")]\nuse http_body::Body;\n\
#[cfg(feature = \"alloc\")]\nuse crate::{BufferPool, StatusBody};\n";
pub fn to_pascal_case(s: &str) -> String {
s.split('_')
@@ -36,7 +49,11 @@ pub fn to_snake_case(s: &str) -> String {
result
}
fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (String, String, String) {
fn map_type_to_rust_accessor(
field_type: i32,
label: i32,
is_map: bool,
) -> (String, String, String) {
if label == 3 {
// LABEL_REPEATED
let iterator_type = if is_map {
@@ -314,7 +331,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
output.push_str(&format!(
" pub fn {}_or_default(&self) -> roto_runtime::Result<{}> {{\n",
safe_name, rust_type
safe_name, rust_type
));
output.push_str(&format!(
" self.{}().or(Ok({}))\n",
@@ -440,18 +457,30 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
output.push_str(&format!(" pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> {{\n self.builder.finish()\n }}\n}}\n\n"));
output.push_str(&format!("#[cfg(feature = \"alloc\")]\npub struct Owned{} {{\n", msg_name));
output.push_str(&format!(
"#[cfg(feature = \"alloc\")]\npub struct Owned{} {{\n",
msg_name
));
output.push_str(" pub data: bytes::Bytes,\n");
output.push_str("}\n\n");
output.push_str(&format!("#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoOwned for Owned{} {{\n", msg_name));
output.push_str(&format!(
"#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoOwned for Owned{} {{\n",
msg_name
));
output.push_str(&format!(" type Reader<'a> = {}<'a>;\n", msg_name));
output.push_str(&format!(" fn reader(&self) -> {}<'_> {{\n", msg_name));
output.push_str(&format!(" {}::new(&self.data).expect(\"failed to create reader\")\n", msg_name));
output.push_str(&format!(
" {}::new(&self.data).expect(\"failed to create reader\")\n",
msg_name
));
output.push_str(" }\n");
output.push_str("}\n\n");
output.push_str(&format!("#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoMessage for Owned{} {{\n", msg_name));
output.push_str(&format!(
"#[cfg(feature = \"alloc\")]\nimpl roto_runtime::RotoMessage for Owned{} {{\n",
msg_name
));
output.push_str(" fn decode(buf: bytes::Bytes) -> roto_runtime::Result<Self> {\n");
output.push_str(&format!(" Ok(Owned{} {{ data: buf }})\n", msg_name));
output.push_str(" }\n\n");
@@ -551,7 +580,14 @@ where
}
}
let rust_file_name = format!("{}.rs", std::path::Path::new(proto_name).file_stem().unwrap().to_str().unwrap());
let rust_file_name = format!(
"{}.rs",
std::path::Path::new(proto_name)
.file_stem()
.unwrap()
.to_str()
.unwrap()
);
let mut output = String::new();
output.push_str("// @generated by protoc-gen-roto — do not edit\n");
@@ -639,7 +675,8 @@ pub fn generate_protobuf_code(
for enum_res in file_proto.enum_type() {
let (enum_data, _) = enum_res.expect("Failed to iterate enum");
write_enum(
&EnumDescriptorProto::new(enum_data).expect("Failed to parse EnumDescriptorProto"),
&EnumDescriptorProto::new(enum_data)
.expect("Failed to parse EnumDescriptorProto"),
output,
);
}
@@ -672,7 +709,8 @@ pub fn generate_service_code(
for svc_res in file_proto.service() {
let (svc_data, _) = svc_res.expect("Failed to iterate service");
write_service(
&ServiceDescriptorProto::new(svc_data).expect("Failed to parse ServiceDescriptorProto"),
&ServiceDescriptorProto::new(svc_data)
.expect("Failed to parse ServiceDescriptorProto"),
&package,
output,
);
@@ -709,13 +747,7 @@ pub fn generate_rust_code(
result.sort_by(|a, b| a.0.cmp(&b.0));
if generate_mod_files {
let mods = generate_files_common(
set,
files_to_generate,
true,
"",
|_, _| {},
);
let mods = generate_files_common(set, files_to_generate, true, "", |_, _| {});
for (filename, content) in mods {
if filename == "mod.rs" || filename.contains("/mod.rs") {
result.push((filename, content));
@@ -732,12 +764,18 @@ fn strip_boilerplate(content: &str) -> String {
fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut String) {
output.push_str(SERVICE_IMPORTS);
output.push_str("\n");
let svc_name = to_pascal_case(svc_proto.name().unwrap());
output.push_str(&format!("#[async_trait::async_trait]\npub trait {}: Send + Sync + 'static {{\n", svc_name));
output.push_str("#[cfg(feature = \"alloc\")]\n");
output.push_str(&format!(
"#[async_trait::async_trait]\npub trait {}: Send + Sync + 'static {{\n",
svc_name
));
for method_res in svc_proto.method() {
let (method_data, _) = method_res.expect("Failed to iterate method");
let method_proto = MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto");
let method_proto =
MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto");
let method_name = to_snake_case(method_proto.name().unwrap());
let input_full_name = method_proto.input_type().unwrap();
@@ -759,7 +797,10 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
};
let resp_type = if server_streaming {
format!("Response<Pin<Box<dyn Stream<Item = std::result::Result<{}, Status>> + Send>>>", output_owned)
format!(
"Response<Pin<Box<dyn Stream<Item = std::result::Result<{}, Status>> + Send>>>",
output_owned
)
} else {
format!("Response<{}>", output_owned)
};
@@ -772,27 +813,46 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
output.push_str("}\n\n");
let server_name = format!("{}Server", svc_name);
output.push_str(&format!("#[derive(Clone)]\npub struct {} {{\n", server_name));
output.push_str("#[cfg(feature = \"alloc\")]\n");
output.push_str(&format!(
"#[derive(Clone)]\npub struct {} {{\n",
server_name
));
output.push_str(&format!(" inner: Arc<dyn {}>,\n", svc_name));
output.push_str(" pool: Arc<BufferPool>,\n");
output.push_str("}\n\n");
output.push_str("#[cfg(feature = \"alloc\")]\n");
output.push_str(&format!("impl {} {{\n", server_name));
output.push_str(&format!(" pub fn new(inner: Arc<dyn {}>, pool: Arc<BufferPool>) -> Self {{\n", svc_name));
output.push_str(&format!(
" pub fn new(inner: Arc<dyn {}>, pool: Arc<BufferPool>) -> Self {{\n",
svc_name
));
output.push_str(" Self { inner, pool }\n");
output.push_str(" }\n");
output.push_str("}\n\n");
output.push_str(&format!("impl tonic::server::NamedService for {} {{\n", server_name));
output.push_str("#[cfg(feature = \"alloc\")]\n");
output.push_str(&format!(
"impl tonic::server::NamedService for {} {{\n",
server_name
));
let full_svc_name = if package.is_empty() {
svc_proto.name().unwrap().to_string()
} else {
format!("{}.{}", package, svc_proto.name().unwrap())
};
output.push_str(&format!(" const NAME: &'static str = \"{}\";\n", full_svc_name));
output.push_str(&format!(
" const NAME: &'static str = \"{}\";\n",
full_svc_name
));
output.push_str("}\n\n");
output.push_str(&format!("impl Service<http::Request<BoxBody>> for {} {{\n", server_name));
output.push_str("#[cfg(feature = \"alloc\")]\n");
output.push_str(&format!(
"impl Service<http::Request<BoxBody>> for {} {{\n",
server_name
));
output.push_str(" type Response = http::Response<BoxBody>;\n");
output.push_str(" type Error = std::convert::Infallible;\n");
output.push_str(" type Future = Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;\n\n");
@@ -829,14 +889,20 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
let mut methods = Vec::new();
for method_res in svc_proto.method() {
let (method_data, _) = method_res.expect("Failed to iterate method");
let method_proto = MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto");
let method_proto =
MethodDescriptorProto::new(method_data).expect("Failed to parse MethodDescriptorProto");
let original_method_name = method_proto.name().unwrap().to_string();
let method_name = to_snake_case(&original_method_name);
let input_full_name = method_proto.input_type().unwrap();
let input_type = input_full_name.split('.').last().unwrap();
let input_owned = format!("Owned{}", input_type);
let server_streaming = method_proto.server_streaming().unwrap_or(false);
methods.push((original_method_name, method_name, input_owned, server_streaming));
methods.push((
original_method_name,
method_name,
input_owned,
server_streaming,
));
}
for (original_method_name, method_name, input_owned, server_streaming) in methods {
@@ -846,7 +912,12 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
let full_path = if package.is_empty() {
format!("/{}/{}", svc_proto.name().unwrap(), original_method_name)
} else {
format!("/{}.{}/{}", package, svc_proto.name().unwrap(), original_method_name)
format!(
"/{}.{}/{}",
package,
svc_proto.name().unwrap(),
original_method_name
)
};
output.push_str(&format!(" if path == \"{}\" {{\n", full_path));
output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n");
@@ -857,17 +928,28 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
let full_path = if package.is_empty() {
format!("/{}/{}", svc_proto.name().unwrap(), original_method_name)
} else {
format!("/{}.{}/{}", package, svc_proto.name().unwrap(), original_method_name)
format!(
"/{}.{}/{}",
package,
svc_proto.name().unwrap(),
original_method_name
)
};
output.push_str(&format!(" if path == \"{}\" {{\n", full_path));
output.push_str(&format!(" let request_msg = match {}::decode(payload) {{\n", input_owned));
output.push_str(&format!(
" let request_msg = match {}::decode(payload) {{\n",
input_owned
));
output.push_str(" Ok(msg) => msg,\n");
output.push_str(" Err(e) => {\n");
output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n");
output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n");
output.push_str(" }\n");
output.push_str(" };\n\n");
output.push_str(&format!(" let response = match inner.{}(Request::new(request_msg)).await {{\n", method_name));
output.push_str(&format!(
" let response = match inner.{}(Request::new(request_msg)).await {{\n",
method_name
));
output.push_str(" Ok(res) => res,\n");
output.push_str(" Err(e) => {\n");
output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n");
@@ -884,7 +966,9 @@ fn write_service(svc_proto: &ServiceDescriptorProto, package: &str, output: &mut
output.push_str(" let frame_len = res_buf.len();\n");
output.push_str(" let frame = res_buf.split_to(frame_len).freeze();\n");
output.push_str(" pool.put(res_buf);\n");
output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(frame), 0));\n");
output.push_str(
" let res_body = BoxBody::new(StatusBody::new(Some(frame), 0));\n",
);
output.push_str(" routed = true;\n");
output.push_str(" return Ok(http::Response::builder().status(200).header(\"content-type\", \"application/grpc\").body(res_body).unwrap());\n");
output.push_str(" }\n");