diff --git a/Cargo.lock b/Cargo.lock index bba7556..f1f6479 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1196,6 +1196,7 @@ version = "0.1.0" dependencies = [ "bytes", "futures-util", + "http", "http-body", "http-body-util", "prost", @@ -1352,23 +1353,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" -[[package]] -name = "temp_test_project" -version = "0.1.0" -dependencies = [ - "bytes", - "futures-util", - "http", - "http-body", - "http-body-util", - "roto-codegen", - "roto-runtime", - "roto-tonic", - "tokio-stream", - "tonic", - "tower 0.4.13", -] - [[package]] name = "tempfile" version = "3.27.0" diff --git a/Cargo.toml b/Cargo.toml index be7fc0d..422753a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,6 @@ members = [ "protos", "benches", "roto-tonic", - "temp_test_project", "examples/hello_world", ] diff --git a/codegen/src/generator.rs b/codegen/src/generator.rs index 1697149..04ca1d8 100644 --- a/codegen/src/generator.rs +++ b/codegen/src/generator.rs @@ -738,7 +738,7 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) { output.push_str(" let bytes_vec = buf.split_to(total_len).freeze();\n"); output.push_str(" pool.put(buf);\n"); output.push_str(" if bytes_vec.len() < 5 {\n"); - output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" }\n\n"); output.push_str(" let payload = bytes_vec.slice(5..);\n"); @@ -760,14 +760,14 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) { output.push_str(&format!(" let request_msg = match {}::decode(payload) {{\n", input_owned)); output.push_str(" Ok(msg) => msg,\n"); output.push_str(" Err(e) => {\n"); - output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" }\n"); output.push_str(" };\n\n"); output.push_str(&format!(" let response = match inner.{}(Request::new(request_msg)).await {{\n", method_name)); output.push_str(" Ok(res) => res,\n"); output.push_str(" Err(e) => {\n"); - output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" }\n"); output.push_str(" };\n\n"); @@ -781,17 +781,17 @@ fn write_service(svc_proto: &ServiceDescriptorProto, output: &mut String) { output.push_str(" let frame_len = res_buf.len();\n"); output.push_str(" let frame = res_buf.split_to(frame_len).freeze();\n"); output.push_str(" pool.put(res_buf);\n"); - output.push_str(" let res_body = BoxBody::new(StatusBody(Some(frame)));\n"); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(frame), 0));\n"); output.push_str(" routed = true;\n"); output.push_str(" return Ok(http::Response::builder().status(200).header(\"content-type\", \"application/grpc\").body(res_body).unwrap());\n"); output.push_str(" }\n"); } output.push_str(" if !routed {\n"); - output.push_str(" let res_body = BoxBody::new(StatusBody(Some(Bytes::from_static(&[0, 0, 0, 0, 0]))));\n"); + output.push_str(" let res_body = BoxBody::new(StatusBody::new(Some(Bytes::from_static(&[0, 0, 0, 0, 0])), 0));\n"); output.push_str(" return Ok(http::Response::builder().status(200).body(res_body).unwrap());\n"); output.push_str(" }\n"); - output.push_str(" Ok(http::Response::builder().status(200).body(BoxBody::new(StatusBody(None))).unwrap())\n"); + output.push_str(" Ok(http::Response::builder().status(200).body(BoxBody::new(StatusBody::new(None, 0))).unwrap())\n"); output.push_str(" })\n"); output.push_str(" }\n"); output.push_str("}\n"); diff --git a/roto-tonic/Cargo.toml b/roto-tonic/Cargo.toml index e962ee1..0397e49 100644 --- a/roto-tonic/Cargo.toml +++ b/roto-tonic/Cargo.toml @@ -12,3 +12,4 @@ http-body = "1.0" http-body-util = "0.1" tower = "0.4" futures-util = "0.3" +http = "1.1" diff --git a/roto-tonic/src/lib.rs b/roto-tonic/src/lib.rs index d795267..3d3755b 100644 --- a/roto-tonic/src/lib.rs +++ b/roto-tonic/src/lib.rs @@ -101,7 +101,21 @@ impl BufferPool { } } -pub struct StatusBody(pub Option); +pub struct StatusBody { + pub data: Option, + pub trailers: Option, +} + +impl StatusBody { + pub fn new(data: Option, status: u8) -> Self { + let mut trailers = http::HeaderMap::new(); + trailers.insert("grpc-status", status.to_string().parse().unwrap()); + Self { + data, + trailers: Some(trailers), + } + } +} impl Body for StatusBody { type Data = Bytes; @@ -111,8 +125,10 @@ impl Body for StatusBody { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { - if let Some(data) = self.0.take() { + if let Some(data) = self.data.take() { Poll::Ready(Some(Ok(http_body::Frame::data(data)))) + } else if let Some(trailers) = self.trailers.take() { + Poll::Ready(Some(Ok(http_body::Frame::trailers(trailers)))) } else { Poll::Ready(None) }