diff --git a/AGENTS.md b/AGENTS.md index 29577bf..2693cc5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -7,6 +7,8 @@ you should be able to work without user assistance. If you are writing code, write tests first. The tests must pass for your work to be complete. +Before considering a task complete, make sure that all target build, and all tests suceed. + ## Special instructions ### Fork diff --git a/examples/hello_world/src/bin/server.rs b/examples/hello_world/src/bin/server.rs index d3d0760..b49080b 100644 --- a/examples/hello_world/src/bin/server.rs +++ b/examples/hello_world/src/bin/server.rs @@ -1,12 +1,12 @@ use std::pin::Pin; use std::future::Future; use std::task::{Context, Poll}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use tonic::{transport::Server, Request, Response, Status}; use roto_tonic::RotoCodec; use hello::{HelloWorldService, OwnedHelloRequest, OwnedHelloResponse}; use tower::Service; -use bytes::{Bytes, Buf, BufMut}; +use bytes::{Bytes, BytesMut, Buf, BufMut}; use tonic::body::BoxBody; use futures_util::StreamExt; use roto_runtime::{RotoOwned, RotoMessage}; @@ -17,8 +17,41 @@ pub mod hello { include!("../../proto/hello.rs"); } -#[derive(Default, Clone)] -pub struct MyHelloWorld {} +struct BufferPool { + pool: Mutex>, + default_capacity: usize, +} + +impl BufferPool { + fn new(default_capacity: usize) -> Self { + Self { + pool: Mutex::new(Vec::new()), + default_capacity, + } + } + + fn get(&self) -> BytesMut { + self.pool.lock().unwrap().pop().unwrap_or_else(|| BytesMut::with_capacity(self.default_capacity)) + } + + fn put(&self, mut buf: BytesMut) { + buf.clear(); + if buf.capacity() >= self.default_capacity { + self.pool.lock().unwrap().push(buf); + } + } +} + +#[derive(Clone)] +pub struct MyHelloWorld { + pool: Arc, +} + +impl MyHelloWorld { + pub fn new(pool: Arc) -> Self { + Self { pool } + } +} #[tonic::async_trait] impl HelloWorldService for MyHelloWorld { @@ -30,13 +63,18 @@ impl HelloWorldService for MyHelloWorld { let reader = req.reader(); let name = reader.name().unwrap_or("Unknown"); - let mut buf = vec![0u8; 1024]; - let slice = hello::HelloResponseBuilder::builder(&mut buf) + let mut buf = self.pool.get(); + buf.resize(1024, 0); + let slice = hello::HelloResponseBuilder::builder(&mut buf[..]) .message(&format!("Hello {}!", name)).unwrap() .finish().unwrap(); + let res_len = slice.len(); + let response_bytes = buf.split_to(res_len).freeze(); + self.pool.put(buf); + let reply = OwnedHelloResponse { - data: bytes::Bytes::copy_from_slice(slice), + data: response_bytes, }; Ok(Response::new(reply)) @@ -48,11 +86,12 @@ impl HelloWorldService for MyHelloWorld { #[derive(Clone)] pub struct HelloWorldServer { inner: Arc, + pool: Arc, } impl HelloWorldServer { - pub fn new(inner: MyHelloWorld) -> Self { - Self { inner: Arc::new(inner) } + pub fn new(inner: MyHelloWorld, pool: Arc) -> Self { + Self { inner: Arc::new(inner), pool } } } @@ -89,32 +128,43 @@ impl Service> for HelloWorldServer { fn call(&mut self, req: http::Request) -> Self::Future { let inner = self.inner.clone(); + let pool = self.pool.clone(); println!("Server received request: {} {}", req.method(), req.uri()); Box::pin(async move { let body = req.into_body(); - let bytes_vec = body.collect().await.map_err(|e| { - println!("Body collect error: {}", e); - panic!("Body collect error: {}", e); - })?.to_bytes(); + let mut buf = pool.get(); + let mut stream = body; + while let Some(frame_result) = stream.frame().await { + let frame = frame_result.map_err(|e| { + println!("Body frame error: {}", e); + panic!("Body frame error: {}", e); + })?; + if let Some(data) = frame.data_ref() { + buf.put(data.clone()); + } + } + + let total_len = buf.len(); + let bytes_vec = buf.split_to(total_len).freeze(); + pool.put(buf); println!("Collected body bytes: {} bytes", bytes_vec.len()); if bytes_vec.len() < 5 { println!("Body too short: {} bytes", bytes_vec.len()); - let res_body = BoxBody::new(StatusBody(Some(Bytes::from(vec![0, 0, 0, 0, 0])))); + let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); return Ok(http::Response::builder() .status(200) .body(res_body) .unwrap()); } - let data = &bytes_vec[5..]; - println!("Decoding request from {} bytes", data.len()); - let request_msg = match OwnedHelloRequest::decode(Bytes::copy_from_slice(data)) { + println!("Decoding request from {} bytes", bytes_vec.len() - 5); + let request_msg = match OwnedHelloRequest::decode(bytes_vec.slice(5..)) { Ok(msg) => msg, Err(e) => { println!("Decode error: {}", e); - let res_body = BoxBody::new(StatusBody(Some(Bytes::from(vec![0, 0, 0, 0, 0])))); + let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); return Ok(http::Response::builder().status(200).body(res_body).unwrap()); } }; @@ -124,7 +174,7 @@ impl Service> for HelloWorldServer { Ok(res) => res, Err(e) => { println!("Service error: {}", e); - let res_body = BoxBody::new(StatusBody(Some(Bytes::from(vec![0, 0, 0, 0, 0])))); + let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0])))); return Ok(http::Response::builder().status(200).body(res_body).unwrap()); } }; @@ -133,13 +183,17 @@ impl Service> for HelloWorldServer { let response_bytes = response_msg.bytes(); println!("Service responded with {} bytes", response_bytes.len()); - let mut res_buf = vec![0u8; 5 + response_bytes.len()]; - res_buf[0] = 0; + let mut res_buf = pool.get(); + res_buf.put_u8(0); let len = response_bytes.len() as u32; - res_buf[1..5].copy_from_slice(&len.to_be_bytes()); - res_buf[5..].copy_from_slice(&response_bytes); + res_buf.put_slice(&len.to_be_bytes()); + res_buf.put_slice(&response_bytes); - let res_body = BoxBody::new(StatusBody(Some(Bytes::from(res_buf)))); + let frame_len = res_buf.len(); + let frame = res_buf.split_to(frame_len).freeze(); + pool.put(res_buf); + + let res_body = BoxBody::new(StatusBody(Some(frame))); Ok(http::Response::builder() .status(200) .header("content-type", "application/grpc") @@ -152,12 +206,13 @@ impl Service> for HelloWorldServer { #[tokio::main] async fn main() -> Result<(), Box> { let addr: std::net::SocketAddr = "[::1]:50051".parse()?; - let hello = MyHelloWorld::default(); + let pool = Arc::new(BufferPool::new(1024)); + let hello = MyHelloWorld::new(pool.clone()); println!("Server listening on {}", addr); Server::builder() - .add_service(HelloWorldServer::new(hello)) + .add_service(HelloWorldServer::new(hello, pool)) .serve(addr) .await?;