Fix generated code and update example integration

Resolve `Result` ambiguity and lifetime issues in generated services.
Use `file_stem` for proto filenames. Make `StatusBody` public in
`roto-tonic` and update the `hello_world` build process.
This commit is contained in:
2026-05-15 10:29:54 -07:00
parent 00b3dcd9a6
commit da7ba47505
6 changed files with 323 additions and 11 deletions
Generated
+1
View File
@@ -552,6 +552,7 @@ dependencies = [
"http-body", "http-body",
"http-body-util", "http-body-util",
"prost", "prost",
"roto-codegen",
"roto-runtime", "roto-runtime",
"roto-tonic", "roto-tonic",
"tokio", "tokio",
+7 -7
View File
@@ -540,12 +540,12 @@ pub fn generate_rust_code(
} }
} }
let rust_file_name = format!("{}.rs", proto_name.replace(".proto", "")); let rust_file_name = format!("{}.rs", std::path::Path::new(proto_name).file_stem().unwrap().to_str().unwrap());
let mut output = String::new(); let mut output = String::new();
output.push_str("// @generated by protoc-gen-roto — do not edit\n"); output.push_str("// @generated by protoc-gen-roto — do not edit\n");
output.push_str("#[allow(unused_imports)]\n\n"); output.push_str("#[allow(unused_imports)]\n\n");
output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator};\n"); output.push_str("use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator, RotoMessage};\n");
output.push_str("use std::str;\n"); output.push_str("use std::str;\n");
output.push_str("use bytes::{Bytes, BytesMut, Buf, BufMut};\n"); output.push_str("use bytes::{Bytes, BytesMut, Buf, BufMut};\n");
output.push_str("use tonic::{Request, Response, Status};\n"); output.push_str("use tonic::{Request, Response, Status};\n");
@@ -710,9 +710,9 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) {
output.push_str(&format!("impl Service<http::Request<BoxBody>> for {} {{\n", server_name)); output.push_str(&format!("impl Service<http::Request<BoxBody>> for {} {{\n", server_name));
output.push_str(" type Response = http::Response<BoxBody>;\n"); output.push_str(" type Response = http::Response<BoxBody>;\n");
output.push_str(" type Error = std::convert::Infallible;\n"); output.push_str(" type Error = std::convert::Infallible;\n");
output.push_str(" type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;\n\n"); output.push_str(" type Future = Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;\n\n");
output.push_str(" fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {\n"); output.push_str(" fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {\n");
output.push_str(" Poll::Ready(Ok(()))\n"); output.push_str(" Poll::Ready(Ok(()))\n");
output.push_str(" }\n\n"); output.push_str(" }\n\n");
@@ -720,6 +720,7 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) {
output.push_str(" let inner = self.inner.clone();\n"); output.push_str(" let inner = self.inner.clone();\n");
output.push_str(" let pool = self.pool.clone();\n"); output.push_str(" let pool = self.pool.clone();\n");
output.push_str(" Box::pin(async move {\n"); output.push_str(" Box::pin(async move {\n");
output.push_str(" let path = req.uri().path().to_string();\n");
output.push_str(" let body = req.into_body();\n"); output.push_str(" let body = req.into_body();\n");
output.push_str(" let mut buf = pool.get();\n"); output.push_str(" let mut buf = pool.get();\n");
output.push_str(" let mut stream = body;\n"); output.push_str(" let mut stream = body;\n");
@@ -740,7 +741,6 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) {
output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n");
output.push_str(" }\n\n"); output.push_str(" }\n\n");
output.push_str(" let payload = bytes_vec.slice(5..);\n"); output.push_str(" let payload = bytes_vec.slice(5..);\n");
output.push_str(" let path = req.uri().path();\n");
output.push_str(" let mut routed = false;\n\n"); output.push_str(" let mut routed = false;\n\n");
let mut methods = Vec::new(); let mut methods = Vec::new();
@@ -762,14 +762,14 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) {
output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n");
output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n");
output.push_str(" }\n"); output.push_str(" }\n");
output.push_str(" }};\n\n"); output.push_str(" };\n\n");
output.push_str(&format!(" let response = match inner.{}(Request::new(request_msg)).await {{\n", method_name)); output.push_str(&format!(" let response = match inner.{}(Request::new(request_msg)).await {{\n", method_name));
output.push_str(" Ok(res) => res,\n"); output.push_str(" Ok(res) => res,\n");
output.push_str(" Err(e) => {\n"); output.push_str(" Err(e) => {\n");
output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n");
output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n");
output.push_str(" }\n"); output.push_str(" }\n");
output.push_str(" }};\n\n"); output.push_str(" };\n\n");
output.push_str(" let response_msg = response.into_inner();\n"); output.push_str(" let response_msg = response.into_inner();\n");
output.push_str(" let response_bytes = response_msg.bytes();\n"); output.push_str(" let response_bytes = response_msg.bytes();\n");
output.push_str(" let mut res_buf = pool.get();\n"); output.push_str(" let mut res_buf = pool.get();\n");
+1
View File
@@ -27,3 +27,4 @@ http-body = "1.0"
[build-dependencies] [build-dependencies]
tonic-build = "0.12" tonic-build = "0.12"
roto-codegen = { path = "../../codegen" }
+3 -3
View File
@@ -4,9 +4,9 @@ fn main() {
let dest_path = std::path::Path::new(&out_dir).join("hello.rs"); let dest_path = std::path::Path::new(&out_dir).join("hello.rs");
// Find the protoc-gen-roto binary // Find the protoc-gen-roto binary
// In a real scenario, this should be passed as an environment variable or found in PATH // Since we added roto-codegen to build-dependencies, it will be built.
// For this example, we'll try to find it in the target directory let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
let target_dir = std::env::current_dir().unwrap().join("../../target/debug"); let target_dir = std::path::Path::new(&manifest_dir).join("../../target/debug");
let plugin_path = target_dir.join("protoc-gen-roto"); let plugin_path = target_dir.join("protoc-gen-roto");
if !plugin_path.exists() { if !plugin_path.exists() {
+310
View File
@@ -0,0 +1,310 @@
// @generated by protoc-gen-roto — do not edit
#[allow(unused_imports)]
use roto_runtime::{ProtoAccessor, ProtoBuilder, Result, RotoError, read_varint, RepeatedFieldIterator};
use std::str;
use bytes::{Bytes, BytesMut, Buf, BufMut};
use tonic::{Request, Response, Status};
use tokio_stream::Stream;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::future::Future;
use tonic::body::BoxBody;
use tower::Service;
use futures_util::StreamExt;
use http_body_util::BodyExt;
use http_body::Body;
use roto_tonic::{BufferPool, StatusBody};
pub struct HelloRequest<'a> {
accessor: roto_runtime::ProtoAccessor<'a>,
name_offset: Option<usize>,
}
impl<'a> HelloRequest<'a> {
pub fn new(data: &'a [u8]) -> roto_runtime::Result<Self> {
let accessor = roto_runtime::ProtoAccessor::new(data)?;
let mut name_offset = None;
for item in accessor.fields() {
let (offset, tag, _) = item?;
if tag.field_number == 1 { name_offset = Some(offset); }
}
Ok(Self {
accessor,
name_offset,
})
}
pub fn name(&self) -> roto_runtime::Result<&'a str> {
let offset = self.name_offset.ok_or(roto_runtime::RotoError::FieldNotFound)?;
let (bytes, _) = self.accessor.get_value_at(offset)?;
str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation)
}
pub fn name_or_default(&self) -> roto_runtime::Result<&'a str> {
self.name().or(Ok(""))
}
pub fn has_name(&self) -> bool { self.name_offset.is_some() }
pub fn raw_fields(&self) -> roto_runtime::RawFieldIterator<'a> {
self.accessor.raw_fields()
}
}
pub struct HelloRequestBuilder<'b> {
builder: roto_runtime::ProtoBuilder<'b>,
name_written: bool,
}
impl<'b> HelloRequestBuilder<'b> {
pub fn builder(buf: &mut [u8]) -> HelloRequestBuilder<'_> {
HelloRequestBuilder {
builder: roto_runtime::ProtoBuilder::new(buf),
name_written: false,
}
}
pub fn name(mut self, value: &str) -> roto_runtime::Result<Self> {
self.builder.write_string(1, value)?;
self.name_written = true;
Ok(self)
}
pub fn with(mut self, msg: &HelloRequest<'_>) -> roto_runtime::Result<Self> {
for item in msg.raw_fields() {
let (field_number, raw_bytes) = item?;
let is_written = match field_number {
1 => self.name_written,
_ => false,
};
if !is_written {
self.builder.write_raw(raw_bytes)?;
}
}
Ok(self)
}
pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> {
self.builder.finish()
}
}
pub struct OwnedHelloRequest {
pub data: bytes::Bytes,
}
impl roto_runtime::RotoOwned for OwnedHelloRequest {
type Reader<'a> = HelloRequest<'a>;
fn reader(&self) -> HelloRequest<'_> {
HelloRequest::new(&self.data).expect("failed to create reader")
}
}
impl roto_runtime::RotoMessage for OwnedHelloRequest {
fn decode(buf: bytes::Bytes) -> roto_runtime::Result<Self> {
Ok(OwnedHelloRequest { data: buf })
}
fn bytes(&self) -> bytes::Bytes {
self.data.clone()
}
}
pub struct HelloResponse<'a> {
accessor: roto_runtime::ProtoAccessor<'a>,
message_offset: Option<usize>,
}
impl<'a> HelloResponse<'a> {
pub fn new(data: &'a [u8]) -> roto_runtime::Result<Self> {
let accessor = roto_runtime::ProtoAccessor::new(data)?;
let mut message_offset = None;
for item in accessor.fields() {
let (offset, tag, _) = item?;
if tag.field_number == 1 { message_offset = Some(offset); }
}
Ok(Self {
accessor,
message_offset,
})
}
pub fn message(&self) -> roto_runtime::Result<&'a str> {
let offset = self.message_offset.ok_or(roto_runtime::RotoError::FieldNotFound)?;
let (bytes, _) = self.accessor.get_value_at(offset)?;
str::from_utf8(bytes).map_err(|_| roto_runtime::RotoError::WireFormatViolation)
}
pub fn message_or_default(&self) -> roto_runtime::Result<&'a str> {
self.message().or(Ok(""))
}
pub fn has_message(&self) -> bool { self.message_offset.is_some() }
pub fn raw_fields(&self) -> roto_runtime::RawFieldIterator<'a> {
self.accessor.raw_fields()
}
}
pub struct HelloResponseBuilder<'b> {
builder: roto_runtime::ProtoBuilder<'b>,
message_written: bool,
}
impl<'b> HelloResponseBuilder<'b> {
pub fn builder(buf: &mut [u8]) -> HelloResponseBuilder<'_> {
HelloResponseBuilder {
builder: roto_runtime::ProtoBuilder::new(buf),
message_written: false,
}
}
pub fn message(mut self, value: &str) -> roto_runtime::Result<Self> {
self.builder.write_string(1, value)?;
self.message_written = true;
Ok(self)
}
pub fn with(mut self, msg: &HelloResponse<'_>) -> roto_runtime::Result<Self> {
for item in msg.raw_fields() {
let (field_number, raw_bytes) = item?;
let is_written = match field_number {
1 => self.message_written,
_ => false,
};
if !is_written {
self.builder.write_raw(raw_bytes)?;
}
}
Ok(self)
}
pub fn finish(self) -> roto_runtime::Result<&'b mut [u8]> {
self.builder.finish()
}
}
pub struct OwnedHelloResponse {
pub data: bytes::Bytes,
}
impl roto_runtime::RotoOwned for OwnedHelloResponse {
type Reader<'a> = HelloResponse<'a>;
fn reader(&self) -> HelloResponse<'_> {
HelloResponse::new(&self.data).expect("failed to create reader")
}
}
impl roto_runtime::RotoMessage for OwnedHelloResponse {
fn decode(buf: bytes::Bytes) -> roto_runtime::Result<Self> {
Ok(OwnedHelloResponse { data: buf })
}
fn bytes(&self) -> bytes::Bytes {
self.data.clone()
}
}
#[tonic::async_trait]
pub trait HelloWorldService: Send + Sync + 'static {
async fn hello_world(&self, request: Request<OwnedHelloRequest>) -> std::result::Result<Response<OwnedHelloResponse>, Status>;
}
pub struct HelloWorldServiceServer {
inner: Arc<dyn HelloWorldService>,
pool: Arc<BufferPool>,
}
impl HelloWorldServiceServer {
pub fn new(inner: Arc<dyn HelloWorldService>, pool: Arc<BufferPool>) -> Self {
Self { inner, pool }
}
}
impl tonic::server::NamedService for HelloWorldServiceServer {
const NAME: &'static str = "HelloWorldService";
}
impl Service<http::Request<BoxBody>> for HelloWorldServiceServer {
type Response = http::Response<BoxBody>;
type Error = std::convert::Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: http::Request<BoxBody>) -> Self::Future {
let inner = self.inner.clone();
let pool = self.pool.clone();
Box::pin(async move {
let body = req.into_body();
let mut buf = pool.get();
let mut stream = body;
while let Some(frame_result) = stream.frame().await {
let frame = frame_result.map_err(|e| {
panic!("Body frame error: {}", e);
})?;
if let Some(data) = frame.data_ref() {
buf.put(data.clone());
}
}
let total_len = buf.len();
let bytes_vec = buf.split_to(total_len).freeze();
pool.put(buf);
if bytes_vec.len() < 5 {
let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
return Ok(http::Response::builder().status(200).body(res_body).unwrap());
}
let payload = bytes_vec.slice(5..);
let path = req.uri().path();
let mut routed = false;
if path == "/HelloWorldService/hello_world" {
let request_msg = match OwnedHelloRequest::decode(payload) {
Ok(msg) => msg,
Err(e) => {
let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
return Ok(http::Response::builder().status(200).body(res_body).unwrap());
}
}};
let response = match inner.hello_world(Request::new(request_msg)).await {
Ok(res) => res,
Err(e) => {
let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
return Ok(http::Response::builder().status(200).body(res_body).unwrap());
}
}};
let response_msg = response.into_inner();
let response_bytes = response_msg.bytes();
let mut res_buf = pool.get();
res_buf.put_u8(0);
let len = response_bytes.len() as u32;
res_buf.put_slice(&len.to_be_bytes());
res_buf.put_slice(&response_bytes);
let frame_len = res_buf.len();
let frame = res_buf.split_to(frame_len).freeze();
pool.put(res_buf);
let res_body = BoxBody::new(StatusBody(Some(frame)));
routed = true;
return Ok(http::Response::builder().status(200).header("content-type", "application/grpc").body(res_body).unwrap());
}
if !routed {
let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));
return Ok(http::Response::builder().status(200).body(res_body).unwrap());
}
Ok(http::Response::builder().status(200).body(BoxBody::new(StatusBody(None))).unwrap())
})
}
}
+1 -1
View File
@@ -101,7 +101,7 @@ impl BufferPool {
} }
} }
pub struct StatusBody(pub(crate) Option<Bytes>); pub struct StatusBody(pub Option<Bytes>);
impl Body for StatusBody { impl Body for StatusBody {
type Data = Bytes; type Data = Bytes;