use std::pin::Pin; use std::future::Future; use std::task::{Context, Poll}; use std::sync::{Arc, Mutex}; use tonic::{transport::Server, Request, Response, Status}; use roto_tonic::RotoCodec; use hello::{HelloWorldService, OwnedHelloRequest, OwnedHelloResponse}; use tower::Service; use bytes::{Bytes, BytesMut, Buf, BufMut}; use tonic::body::BoxBody; use futures_util::StreamExt; use roto_runtime::{RotoOwned, RotoMessage}; use http_body_util::BodyExt; use http_body::Body; pub mod hello { include!(concat!(env!("OUT_DIR"), "/hello.rs")); } struct BufferPool { pool: Mutex>, default_capacity: usize, } impl BufferPool { fn new(default_capacity: usize) -> Self { Self { pool: Mutex::new(Vec::new()), default_capacity, } } fn get(&self) -> BytesMut { self.pool.lock().unwrap().pop().unwrap_or_else(|| BytesMut::with_capacity(self.default_capacity)) } fn put(&self, mut buf: BytesMut) { buf.clear(); if buf.capacity() >= self.default_capacity { self.pool.lock().unwrap().push(buf); } } } #[derive(Clone)] pub struct MyHelloWorld { pool: Arc, } impl MyHelloWorld { pub fn new(pool: Arc) -> Self { Self { pool } } } #[tonic::async_trait] impl HelloWorldService for MyHelloWorld { async fn hello_world( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let reader = req.reader(); let name = reader.name().unwrap_or("Unknown"); let mut buf = self.pool.get(); buf.resize(1024, 0); let slice = hello::HelloResponseBuilder::builder(&mut buf[..]) .message(&format!("Hello {}!", name)).unwrap() .finish().unwrap(); let res_len = slice.len(); let response_bytes = buf.split_to(res_len).freeze(); self.pool.put(buf); let reply = OwnedHelloResponse { data: response_bytes, }; Ok(Response::new(reply)) } } // --- Tonic Glue --- #[derive(Clone)] pub struct HelloWorldServer { inner: Arc, pool: Arc, } impl HelloWorldServer { pub fn new(inner: MyHelloWorld, pool: Arc) -> Self { Self { inner: Arc::new(inner), pool } } } impl tonic::server::NamedService for HelloWorldServer { const NAME: &'static str = "hello.HelloWorldService"; } struct StatusBody(Option); impl Body for StatusBody { type Data = Bytes; type Error = Status; fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { if let Some(data) = self.0.take() { Poll::Ready(Some(Ok(http_body::Frame::data(data)))) } else { Poll::Ready(None) } } } impl Service> for HelloWorldServer { type Response = http::Response; type Error = std::convert::Infallible; type Future = Pin> + Send>>; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: http::Request) -> Self::Future { let inner = self.inner.clone(); let pool = self.pool.clone(); println!("Server received request: {} {}", req.method(), req.uri()); Box::pin(async move { let body = req.into_body(); let mut buf = pool.get(); let mut stream = body; while let Some(frame_result) = stream.frame().await { let frame = frame_result.map_err(|e| { println!("Body frame error: {}", e); panic!("Body frame error: {}", e); })?; if let Some(data) = frame.data_ref() { buf.put(data.clone()); } } let total_len = buf.len(); let bytes_vec = buf.split_to(total_len).freeze(); pool.put(buf); println!("Collected body bytes: {} bytes", bytes_vec.len()); if bytes_vec.len() < 5 { println!("Body too short: {} bytes", bytes_vec.len()); let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); return Ok(http::Response::builder() .status(200) .body(res_body) .unwrap()); } println!("Decoding request from {} bytes", bytes_vec.len() - 5); let request_msg = match OwnedHelloRequest::decode(bytes_vec.slice(5..)) { Ok(msg) => msg, Err(e) => { println!("Decode error: {}", e); let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); return Ok(http::Response::builder().status(200).body(res_body).unwrap()); } }; println!("Request decoded successfully"); let response = match inner.hello_world(Request::new(request_msg)).await { Ok(res) => res, Err(e) => { println!("Service error: {}", e); let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); return Ok(http::Response::builder().status(200).body(res_body).unwrap()); } }; let response_msg = response.into_inner(); let response_bytes = response_msg.bytes(); println!("Service responded with {} bytes", response_bytes.len()); let mut res_buf = pool.get(); res_buf.put_u8(0); let len = response_bytes.len() as u32; res_buf.put_slice(&len.to_be_bytes()); res_buf.put_slice(&response_bytes); let frame_len = res_buf.len(); let frame = res_buf.split_to(frame_len).freeze(); pool.put(res_buf); let res_body = BoxBody::new(StatusBody(Some(frame))); Ok(http::Response::builder() .status(200) .header("content-type", "application/grpc") .body(res_body) .unwrap()) }) } } #[tokio::main] async fn main() -> Result<(), Box> { let addr: std::net::SocketAddr = "[::1]:50051".parse()?; let pool = Arc::new(BufferPool::new(1024)); let hello = MyHelloWorld::new(pool.clone()); println!("Server listening on {}", addr); Server::builder() .add_service(HelloWorldServer::new(hello, pool)) .serve(addr) .await?; Ok(()) }