diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 018374528..661d5b551 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -406,8 +406,8 @@ impl<'b> CodeGenerator<'_, 'b> { } fn append_field(&mut self, fq_message_name: &str, field: &Field) { - let type_ = field.descriptor.r#type(); - let repeated = field.descriptor.label() == Label::Repeated; + let type_ = field.descriptor.type_or_default(); + let repeated = field.descriptor.label_or_default() == Label::Repeated; let deprecated = self.deprecated(&field.descriptor); let optional = self.optional(&field.descriptor); let boxed = self @@ -442,7 +442,7 @@ impl<'b> CodeGenerator<'_, 'b> { .push_str(&format!(" = {:?}", bytes_type.annotation())); } - match field.descriptor.label() { + match field.descriptor.label_or_default() { Label::Optional => { if optional { self.buf.push_str(", optional"); @@ -946,7 +946,7 @@ impl<'b> CodeGenerator<'_, 'b> { } fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { - match field.r#type() { + match field.type_or_default() { Type::Float => String::from("f32"), Type::Double => String::from("f64"), Type::Uint32 | Type::Fixed32 => String::from("u32"), @@ -1003,7 +1003,7 @@ impl<'b> CodeGenerator<'_, 'b> { } fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { - match field.r#type() { + match field.type_or_default() { Type::Float => Cow::Borrowed("float"), Type::Double => Cow::Borrowed("double"), Type::Int32 => Cow::Borrowed("int32"), @@ -1029,7 +1029,7 @@ impl<'b> CodeGenerator<'_, 'b> { } fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { - match field.r#type() { + match field.type_or_default() { Type::Enum => Cow::Owned(format!( "enumeration({})", self.resolve_ident(field.type_name()) @@ -1043,11 +1043,11 @@ impl<'b> CodeGenerator<'_, 'b> { return true; } - if field.label() != Label::Optional { + if field.label_or_default() != Label::Optional { return false; } - match field.r#type() { + match field.type_or_default() { Type::Message => true, _ => self.syntax == Syntax::Proto2, } @@ -1074,7 +1074,7 @@ impl<'b> CodeGenerator<'_, 'b> { /// Returns `true` if the repeated field type can be packed. fn can_pack(field: &FieldDescriptorProto) -> bool { matches!( - field.r#type(), + field.type_or_default(), Type::Float | Type::Double | Type::Int32 diff --git a/prost-build/src/context.rs b/prost-build/src/context.rs index f4dee041b..2374c3fb1 100644 --- a/prost-build/src/context.rs +++ b/prost-build/src/context.rs @@ -141,11 +141,11 @@ impl<'a> Context<'a> { oneof: Option<&str>, field: &FieldDescriptorProto, ) -> bool { - if field.label() == Label::Repeated { + if field.label_or_default() == Label::Repeated { // Repeated field are stored in Vec, therefore it is already heap allocated return false; } - let fd_type = field.r#type(); + let fd_type = field.type_or_default(); if (fd_type == Type::Message || fd_type == Type::Group) && self .message_graph @@ -188,9 +188,9 @@ impl<'a> Context<'a> { assert_eq!(".", &fq_message_name[..1]); // repeated field cannot derive Copy - if field.label() == Label::Repeated { + if field.label_or_default() == Label::Repeated { false - } else if field.r#type() == Type::Message { + } else if field.type_or_default() == Type::Message { // nested and boxed messages cannot derive Copy if self .message_graph @@ -210,7 +210,7 @@ impl<'a> Context<'a> { } } else { matches!( - field.r#type(), + field.type_or_default(), Type::Float | Type::Double | Type::Int32 @@ -243,8 +243,8 @@ impl<'a> Context<'a> { pub fn can_field_derive_eq(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool { assert_eq!(".", &fq_message_name[..1]); - if field.r#type() == Type::Message { - if field.label() == Label::Repeated + if field.type_or_default() == Type::Message { + if field.label_or_default() == Label::Repeated || self .message_graph .is_nested(field.type_name(), fq_message_name) @@ -255,7 +255,7 @@ impl<'a> Context<'a> { } } else { matches!( - field.r#type(), + field.type_or_default(), Type::Int32 | Type::Int64 | Type::Uint32 diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index 92486c742..80a2646d4 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -58,7 +58,9 @@ impl MessageGraph { let msg_index = self.get_or_insert_index(msg_name.clone()); for field in &msg.field { - if field.r#type() == Type::Message && field.label() != Label::Repeated { + if field.type_or_default() == Type::Message + && field.label_or_default() != Label::Repeated + { let field_index = self.get_or_insert_index(field.type_name.clone().unwrap()); self.graph.add_edge(msg_index, field_index, ()); } diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 45caeaa17..cf6cb74f6 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -2,7 +2,7 @@ use std::fmt; use anyhow::{anyhow, bail, Error}; use proc_macro2::{Span, TokenStream}; -use quote::{quote, ToTokens, TokenStreamExt}; +use quote::{format_ident, quote, ToTokens, TokenStreamExt}; use syn::{parse_str, Expr, ExprLit, Ident, Index, Lit, LitByteStr, Meta, MetaNameValue, Path}; use crate::field::{bool_attr, set_option, tag_attr, Label}; @@ -283,6 +283,16 @@ impl Field { } Err(_) => quote!(#ident), }; + let get_or_default = match syn::parse_str::(&ident_str) { + Ok(index) => { + let get = Ident::new( + &format!("get_{}_or_default", index.index), + Span::call_site(), + ); + quote!(#get) + } + Err(_) => format_ident!("{ident}_or_default").to_token_stream(), + }; if let Ty::Enumeration(ref ty) = self.ty { let set = Ident::new(&format!("set_{ident_str}"), Span::call_site()); @@ -307,16 +317,25 @@ impl Field { } Kind::Optional(ref default) => { let get_doc = format!( + "Returns the enum value of `{ident_str}`, \ + or `None` if the field is unset or set to an invalid enum value." + ); + let get_or_default_doc = format!( "Returns the enum value of `{ident_str}`, \ or the default if the field is unset or set to an invalid enum value." ); quote! { #[doc=#get_doc] - pub fn #get(&self) -> #ty { + pub fn #get(&self) -> ::core::option::Option<#ty> { self.#ident.and_then(|x| { let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); result.ok() - }).unwrap_or(#default) + }) + } + + #[doc=#get_or_default_doc] + pub fn #get_or_default(&self) -> #ty { + self.#get().unwrap_or(#default) } #[doc=#set_doc] diff --git a/tests/src/default_enum_value.rs b/tests/src/default_enum_value.rs index dc831022f..dd65fa45d 100644 --- a/tests/src/default_enum_value.rs +++ b/tests/src/default_enum_value.rs @@ -9,16 +9,19 @@ include!(concat!(env!("OUT_DIR"), "/default_enum_value.rs")); #[test] fn test_default_enum() { let msg = Test::default(); - assert_eq!(msg.privacy_level_1(), PrivacyLevel::One); - assert_eq!(msg.privacy_level_3(), PrivacyLevel::PrivacyLevelThree); + assert_eq!(msg.privacy_level_1_or_default(), PrivacyLevel::One); assert_eq!( - msg.privacy_level_4(), + msg.privacy_level_3_or_default(), + PrivacyLevel::PrivacyLevelThree + ); + assert_eq!( + msg.privacy_level_4_or_default(), PrivacyLevel::PrivacyLevelprivacyLevelFour ); let msg = CMsgRemoteClientBroadcastHeader::default(); assert_eq!( - msg.msg_type(), + msg.msg_type_or_default(), ERemoteClientBroadcastMsg::KERemoteClientBroadcastMsgDiscovery ); }