diff --git a/src/generator.rs b/src/generator.rs index f70a1de..77016c3 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -67,7 +67,7 @@ fn map_type_to_rust_builder(field_type: i32) -> (String, String) { match field_type { 9 => ("&str".to_string(), "write_string".to_string()), 5 | 17 => ("i32".to_string(), "write_int32".to_string()), - 3 | 4 | 8 | 13 | 14 | 18 => ("i64".to_string(), "write_varint".to_string()), + 3 | 4 | 8 | 13 | 14 | 18 => ("u64".to_string(), "write_varint".to_string()), 7 | 15 => ("u32".to_string(), "write_fixed32".to_string()), 6 | 16 => ("u64".to_string(), "write_fixed64".to_string()), 11 | 12 => ("&[u8]".to_string(), "write_bytes".to_string()), @@ -97,20 +97,26 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String { let mut values = enum_proto.enum_value(); let mut variant_count = 0; + let mut zero_variant_name = None; while let Some(val_res) = values.next() { - let (val_data, _) = val_res.expect("Failed to read enum value"); + let (val_data, _) = val_res.expect("Failed to iterate enum"); let accessor = ProtoAccessor::new(val_data).expect("Failed to parse EnumValueDescriptorProto"); let (name_bytes, _) = accessor.get_value(1).expect("Enum value name missing"); let name = str::from_utf8(name_bytes).expect("Enum value name invalid utf8"); let (num_bytes, _) = accessor.get_value(2).expect("Enum value number missing"); let (num, _) = crate::read_varint(num_bytes).expect("Enum value number invalid varint"); - output.push_str(&format!(" {}, = {},\n", to_pascal_case(name), num)); + let pascal_name = to_pascal_case(name); + if num == 0 { + zero_variant_name = Some(pascal_name.clone()); + } + output.push_str(&format!(" {} = {},\n", pascal_name, num)); variant_count += 1; } - if variant_count == 0 { + if zero_variant_name.is_none() { output.push_str(" Unknown = 0,\n"); + zero_variant_name = Some("Unknown".to_string()); } output.push_str("}\n\n"); @@ -132,8 +138,8 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String { output.push_str(&format!(" {} => {}::{},\n", num, enum_name, to_pascal_case(name))); } - output.push_str(&format!(" _ => {}::Unknown,\n", enum_name)); - output.push_str(" }\n }}\n}\n\n"); + output.push_str(&format!(" _ => {}::{},\n", enum_name, zero_variant_name.as_ref().unwrap())); + output.push_str(" }\n }\n}\n\n"); } // Messages @@ -155,6 +161,7 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String { let (field_data, _) = field_res.expect("Failed to iterate field"); let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); let field_name = field_proto.name().unwrap(); + let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { field_name.to_string() }; let tag = field_proto.number().unwrap(); let f_type = field_proto.field_type().unwrap() as i32; let f_label = field_proto.label().unwrap() as i32; @@ -164,12 +171,12 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String { if f_label == 3 { output.push_str(&format!( " pub fn {}(&self) -> {} {{\n {}\n }}\n\n", - field_name, rust_type, logic.replace("%d", &tag.to_string()) + safe_name, rust_type, logic.replace("%d", &tag.to_string()) )); } else { output.push_str(&format!( " pub fn {}(&self) -> Result<{}> {{\n let (bytes, _) = self.0.get_value({})?;\n {}\n }}\n\n", - field_name, rust_type, tag, logic + safe_name, rust_type, tag, logic )); } } @@ -189,6 +196,7 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String { let (field_data, _) = field_res.expect("Failed to iterate field"); let field_proto = FieldDescriptorProto::new(field_data).expect("Failed to parse FieldDescriptorProto"); let field_name = field_proto.name().unwrap(); + let safe_name = if field_name == "type" { format!("r#{}", field_name) } else { field_name.to_string() }; let tag = field_proto.number().unwrap(); let f_type = field_proto.field_type().unwrap() as i32; @@ -196,7 +204,7 @@ pub fn generate_rust_code(set: &FileDescriptorSet) -> String { output.push_str(&format!( " pub fn {}(mut self, value: {}) -> Result {{\n self.builder.{}({}, value)?;\n Ok(self)\n }}\n\n", - field_name, rust_type, method, tag + safe_name, rust_type, method, tag )); } diff --git a/src/lib.rs b/src/lib.rs index a3e1d41..4d2d139 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,9 @@ pub mod proto_gen; pub mod generator; +// Uncomment this to check if the code compiles +#[path = "../proto/google/protobuf/descriptor.rs"] +pub mod descriptor; + use std::fmt; #[derive(Debug, PartialEq, Eq)]