Add support for protobuf map fields
Update the generator to detect map fields and use MapFieldIterator. Implement MapFieldIterator in the runtime to handle key-value pair extraction and add write_map_entry to ProtoBuilder. Add tests to verify that map-bearing messages generate and compile correctly.
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
/target
|
/target
|
||||||
test_gen_project
|
test_gen_project
|
||||||
test_types_gen_project
|
test_types_gen_project
|
||||||
|
test_map_gen_project
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
|
||||||
|
«
|
||||||
|
codegen/data/test_map.proto roto.test"y
|
||||||
|
MapTest4
|
||||||
|
my_map (2.roto.test.MapTest.MyMapEntryRmyMap8
|
||||||
|
|
||||||
|
MyMapEntry
|
||||||
|
key ( Rkey
|
||||||
|
value (Rvalue:8bproto3
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package roto.test;
|
||||||
|
|
||||||
|
message MapTest {
|
||||||
|
map<string, int32> my_map = 1;
|
||||||
|
}
|
||||||
+32
-10
@@ -1,6 +1,6 @@
|
|||||||
use crate::google::protobuf::descriptor::{
|
use crate::google::protobuf::descriptor::{
|
||||||
DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
|
DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
|
||||||
FileDescriptorSet,
|
FileDescriptorSet, MessageOptions,
|
||||||
};
|
};
|
||||||
use roto_runtime::ProtoAccessor;
|
use roto_runtime::ProtoAccessor;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
@@ -33,11 +33,16 @@ pub fn to_snake_case(s: &str) -> String {
|
|||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn map_type_to_rust_accessor(field_type: i32, label: i32) -> (String, String) {
|
fn map_type_to_rust_accessor(field_type: i32, label: i32, is_map: bool) -> (String, String) {
|
||||||
if label == 3 {
|
if label == 3 {
|
||||||
// LABEL_REPEATED
|
// LABEL_REPEATED
|
||||||
|
let iterator_type = if is_map {
|
||||||
|
"roto_runtime::MapFieldIterator<'a>"
|
||||||
|
} else {
|
||||||
|
"roto_runtime::RepeatedFieldIterator<'a>"
|
||||||
|
};
|
||||||
return (
|
return (
|
||||||
"roto_runtime::RepeatedFieldIterator<'a>".to_string(),
|
iterator_type.to_string(),
|
||||||
"".to_string(), // Not used for repeated fields in the same way
|
"".to_string(), // Not used for repeated fields in the same way
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -159,14 +164,23 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
|
|||||||
let tag = field_proto.number().unwrap();
|
let tag = field_proto.number().unwrap();
|
||||||
let f_type = field_proto.r#type().unwrap() as i32;
|
let f_type = field_proto.r#type().unwrap() as i32;
|
||||||
let f_label = field_proto.label().unwrap() as i32;
|
let f_label = field_proto.label().unwrap() as i32;
|
||||||
|
let is_map = field_proto
|
||||||
|
.options()
|
||||||
|
.map(|opt| {
|
||||||
|
MessageOptions::new(opt)
|
||||||
|
.unwrap()
|
||||||
|
.map_entry()
|
||||||
|
.unwrap_or(false)
|
||||||
|
})
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
fields_info.push((field_name.to_string(), tag, f_type, f_label));
|
fields_info.push((field_name.to_string(), tag, f_type, f_label, is_map));
|
||||||
}
|
}
|
||||||
|
|
||||||
output.push_str(&format!("pub struct {}<'a> {{\n", msg_name));
|
output.push_str(&format!("pub struct {}<'a> {{\n", msg_name));
|
||||||
output.push_str(" accessor: roto_runtime::ProtoAccessor<'a>,\n");
|
output.push_str(" accessor: roto_runtime::ProtoAccessor<'a>,\n");
|
||||||
|
|
||||||
for (field_name, _tag, _f_type, f_label) in &fields_info {
|
for (field_name, _tag, _f_type, f_label, _is_map) in &fields_info {
|
||||||
if *f_label == 3 {
|
if *f_label == 3 {
|
||||||
output.push_str(&format!(" {}_start: Option<usize>,\n", field_name));
|
output.push_str(&format!(" {}_start: Option<usize>,\n", field_name));
|
||||||
output.push_str(&format!(" {}_end: Option<usize>,\n", field_name));
|
output.push_str(&format!(" {}_end: Option<usize>,\n", field_name));
|
||||||
@@ -180,7 +194,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
|
|||||||
output.push_str(" pub fn new(data: &'a [u8]) -> roto_runtime::Result<Self> {\n");
|
output.push_str(" pub fn new(data: &'a [u8]) -> roto_runtime::Result<Self> {\n");
|
||||||
output.push_str(" let accessor = roto_runtime::ProtoAccessor::new(data)?;\n");
|
output.push_str(" let accessor = roto_runtime::ProtoAccessor::new(data)?;\n");
|
||||||
if !fields_info.is_empty() {
|
if !fields_info.is_empty() {
|
||||||
for (name, _, _, label) in &fields_info {
|
for (name, _, _, label, _is_map) in &fields_info {
|
||||||
if *label == 3 {
|
if *label == 3 {
|
||||||
output.push_str(&format!(" let mut {}_start = None;\n", name));
|
output.push_str(&format!(" let mut {}_start = None;\n", name));
|
||||||
output.push_str(&format!(" let mut {}_end = None;\n", name));
|
output.push_str(&format!(" let mut {}_end = None;\n", name));
|
||||||
@@ -192,7 +206,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
|
|||||||
output.push_str(" for item in accessor.fields() {\n");
|
output.push_str(" for item in accessor.fields() {\n");
|
||||||
output.push_str(" let (offset, tag, _) = item?;\n");
|
output.push_str(" let (offset, tag, _) = item?;\n");
|
||||||
|
|
||||||
for (name, tag, _, label) in &fields_info {
|
for (name, tag, _, label, _is_map) in &fields_info {
|
||||||
if *label == 3 {
|
if *label == 3 {
|
||||||
output.push_str(&format!(" if tag.field_number == {} {{\n", tag));
|
output.push_str(&format!(" if tag.field_number == {} {{\n", tag));
|
||||||
output.push_str(&format!(
|
output.push_str(&format!(
|
||||||
@@ -213,7 +227,7 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
|
|||||||
|
|
||||||
output.push_str(" Ok(Self {\n");
|
output.push_str(" Ok(Self {\n");
|
||||||
output.push_str(" accessor,\n");
|
output.push_str(" accessor,\n");
|
||||||
for (name, _, _, label) in &fields_info {
|
for (name, _, _, label, _is_map) in &fields_info {
|
||||||
if *label == 3 {
|
if *label == 3 {
|
||||||
output.push_str(&format!("{}_start, {}_end,\n", name, name));
|
output.push_str(&format!("{}_start, {}_end,\n", name, name));
|
||||||
} else {
|
} else {
|
||||||
@@ -222,8 +236,8 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
|
|||||||
}
|
}
|
||||||
output.push_str(" })\n }\n\n");
|
output.push_str(" })\n }\n\n");
|
||||||
|
|
||||||
for (field_name, tag, f_type, f_label) in fields_info {
|
for (field_name, tag, f_type, f_label, is_map) in fields_info {
|
||||||
let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label);
|
let (rust_type, logic) = map_type_to_rust_accessor(f_type, f_label, is_map);
|
||||||
let safe_name = if field_name == "type" {
|
let safe_name = if field_name == "type" {
|
||||||
format!("r#{}", field_name)
|
format!("r#{}", field_name)
|
||||||
} else {
|
} else {
|
||||||
@@ -239,11 +253,19 @@ fn write_message(msg_proto: &DescriptorProto, output: &mut String) {
|
|||||||
" match (self.{}_start, self.{}_end) {{\n",
|
" match (self.{}_start, self.{}_end) {{\n",
|
||||||
field_name, field_name
|
field_name, field_name
|
||||||
));
|
));
|
||||||
|
if is_map {
|
||||||
|
output.push_str(&format!(" (Some(start), Some(end)) => roto_runtime::MapFieldIterator::new(self.accessor.iter_repeated_range({}, start, end)),\n", tag));
|
||||||
|
output.push_str(&format!(
|
||||||
|
" _ => roto_runtime::MapFieldIterator::new(self.accessor.iter_repeated({})),\n",
|
||||||
|
tag
|
||||||
|
));
|
||||||
|
} else {
|
||||||
output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag));
|
output.push_str(&format!(" (Some(start), Some(end)) => self.accessor.iter_repeated_range({}, start, end),\n", tag));
|
||||||
output.push_str(&format!(
|
output.push_str(&format!(
|
||||||
" _ => self.accessor.iter_repeated({}),\n",
|
" _ => self.accessor.iter_repeated({}),\n",
|
||||||
tag
|
tag
|
||||||
));
|
));
|
||||||
|
}
|
||||||
output.push_str(" }\n }\n\n");
|
output.push_str(" }\n }\n\n");
|
||||||
} else {
|
} else {
|
||||||
output.push_str(&format!(
|
output.push_str(&format!(
|
||||||
|
|||||||
@@ -0,0 +1,67 @@
|
|||||||
|
use roto_codegen::google::protobuf::descriptor::FileDescriptorSet;
|
||||||
|
use std::fs;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_map_generated_code_builds() {
|
||||||
|
// 1. Load FileDescriptorSet from data/test_map.desc
|
||||||
|
let desc_path = "data/test_map.desc";
|
||||||
|
let data = fs::read(desc_path).expect("Failed to read test_map.desc");
|
||||||
|
let set = FileDescriptorSet::new(&data)
|
||||||
|
.expect("Failed to create FileDescriptorSet from test_map.desc");
|
||||||
|
|
||||||
|
let generated_files = roto_codegen::generator::generate_rust_code(&set, None, false);
|
||||||
|
assert!(
|
||||||
|
!generated_files.is_empty(),
|
||||||
|
"Generated code should not be empty"
|
||||||
|
);
|
||||||
|
|
||||||
|
// 2. Setup a temporary Cargo project to verify the code builds
|
||||||
|
let root = std::env::current_dir().expect("Failed to get current directory");
|
||||||
|
let temp_project_dir = root.join("test_map_gen_project");
|
||||||
|
|
||||||
|
// Clean up previous runs
|
||||||
|
if temp_project_dir.exists() {
|
||||||
|
fs::remove_dir_all(&temp_project_dir).expect("Failed to clean up temp project directory");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new library project
|
||||||
|
let status = Command::new("cargo")
|
||||||
|
.args(["new", "--lib", "test_map_gen_project"])
|
||||||
|
.current_dir(&root)
|
||||||
|
.status()
|
||||||
|
.expect("Failed to run cargo new");
|
||||||
|
assert!(status.success(), "cargo new failed");
|
||||||
|
|
||||||
|
// 3. Configure the project to depend on the current roto crate
|
||||||
|
let cargo_toml_path = temp_project_dir.join("Cargo.toml");
|
||||||
|
let cargo_toml_content =
|
||||||
|
fs::read_to_string(&cargo_toml_path).expect("Failed to read Cargo.toml");
|
||||||
|
let updated_cargo_toml = format!(
|
||||||
|
"{}\n\nroto-codegen = {{ path = \"..\" }}\nroto-runtime = {{ path = \"../../runtime\" }}\n\n[workspace]\n",
|
||||||
|
cargo_toml_content
|
||||||
|
);
|
||||||
|
fs::write(cargo_toml_path, updated_cargo_toml).expect("Failed to write Cargo.toml");
|
||||||
|
|
||||||
|
// 4. Write the generated code to src/lib.rs
|
||||||
|
let mut all_code = String::new();
|
||||||
|
for (_, content) in generated_files {
|
||||||
|
all_code.push_str(&content);
|
||||||
|
all_code.push_str("\n");
|
||||||
|
}
|
||||||
|
let final_code = all_code.replace("use crate::", "use roto::");
|
||||||
|
let lib_path = temp_project_dir.join("src/lib.rs");
|
||||||
|
fs::write(lib_path, final_code).expect("Failed to write generated code to src/lib.rs");
|
||||||
|
|
||||||
|
// 5. Attempt to build the project
|
||||||
|
let build_status = Command::new("cargo")
|
||||||
|
.args(["--offline", "build"])
|
||||||
|
.current_dir(&temp_project_dir)
|
||||||
|
.status()
|
||||||
|
.expect("Failed to run cargo build");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
build_status.success(),
|
||||||
|
"The generated Rust code for test_map.proto failed to build in a standalone project!"
|
||||||
|
);
|
||||||
|
}
|
||||||
+46
-1
@@ -1,6 +1,33 @@
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
pub struct MapFieldIterator<'a> {
|
||||||
|
inner: RepeatedFieldIterator<'a>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> MapFieldIterator<'a> {
|
||||||
|
pub fn new(inner: RepeatedFieldIterator<'a>) -> Self {
|
||||||
|
Self { inner }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Iterator for MapFieldIterator<'a> {
|
||||||
|
type Item = Result<(&'a [u8], &'a [u8])>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
match self.inner.next() {
|
||||||
|
Some(Ok((value, _wire_type))) => {
|
||||||
|
let accessor = ProtoAccessor::new(value).ok()?;
|
||||||
|
let (key_bytes, _) = accessor.get_value(1).ok()?;
|
||||||
|
let (val_bytes, _) = accessor.get_value(2).ok()?;
|
||||||
|
Some(Ok((key_bytes, val_bytes)))
|
||||||
|
}
|
||||||
|
Some(Err(e)) => Some(Err(e)),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub enum RotoError {
|
pub enum RotoError {
|
||||||
UnexpectedEndOfBuffer,
|
UnexpectedEndOfBuffer,
|
||||||
InvalidVarint,
|
InvalidVarint,
|
||||||
@@ -769,6 +796,24 @@ impl<'a> ProtoBuilder<'a> {
|
|||||||
self.append_bytes(raw_bytes)
|
self.append_bytes(raw_bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn write_map_entry(
|
||||||
|
&mut self,
|
||||||
|
field_number: u32,
|
||||||
|
key_encoded: &[u8],
|
||||||
|
value_encoded: &[u8],
|
||||||
|
) -> Result<()> {
|
||||||
|
let entry_len = key_encoded.len() + value_encoded.len();
|
||||||
|
self.write_tag(field_number, WireType::LengthDelimited)?;
|
||||||
|
|
||||||
|
let mut len_buf = [0u8; 10];
|
||||||
|
let len_len = write_varint(entry_len as u64, &mut len_buf)?;
|
||||||
|
self.append_bytes(&len_buf[..len_len])?;
|
||||||
|
|
||||||
|
self.append_bytes(key_encoded)?;
|
||||||
|
self.append_bytes(value_encoded)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn finish(self) -> Result<&'a mut [u8]> {
|
pub fn finish(self) -> Result<&'a mut [u8]> {
|
||||||
Ok(&mut self.buf[..self.pos])
|
Ok(&mut self.buf[..self.pos])
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user