diff --git a/crates/osstr_traits_derive/src/attr_args.rs b/crates/osstr_traits_derive/src/attr_args.rs index 3da72ce..4c82a8c 100644 --- a/crates/osstr_traits_derive/src/attr_args.rs +++ b/crates/osstr_traits_derive/src/attr_args.rs @@ -2,9 +2,11 @@ use syn::parse::{Parse, ParseStream}; use syn::{Expr, Ident, LitStr, Token, custom_keyword}; custom_keyword!(transparent); +custom_keyword!(from_display); pub enum OsDisplayAttribute { Transparent, + FromDisplay, Format(FormatArgs), } @@ -25,6 +27,13 @@ impl Parse for OsDisplayAttribute { } else { Ok(OsDisplayAttribute::Transparent) } + } else if lookahead.peek(from_display) { + input.parse::()?; + if !input.is_empty() { + Err(input.error("Unexpected tokens after `from_display` attribute.")) + } else { + Ok(OsDisplayAttribute::FromDisplay) + } } else if lookahead.peek(LitStr) { let format_string = input.parse()?; diff --git a/crates/osstr_traits_derive/src/lib.rs b/crates/osstr_traits_derive/src/lib.rs index 7eea813..dcc5062 100644 --- a/crates/osstr_traits_derive/src/lib.rs +++ b/crates/osstr_traits_derive/src/lib.rs @@ -61,6 +61,32 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { } } } + OsDisplayAttribute::FromDisplay => { + match &variant.fields { + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + quote_spanned! {variant.span() => + match self { + #name::#variant_name(value) => f.write_str(&value.to_string()), + _ => unreachable!(), // Should not happen with match arm + } + } + } + Fields::Named(fields) if fields.named.len() == 1 => { + let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap(); + quote_spanned! {variant.span() => + match self { + #name::#variant_name{#field_ident} => f.write_str(&#field_ident.to_string()), + _ => unreachable!(), // Should not happen with match arm + } + } + } + _ => { + return quote_spanned! {variant.span() => + compile_error!("#[os_display(from_display)] on an enum variant requires a single field which implements std::fmt::Display."); + }; + } + } + } OsDisplayAttribute::Format(format_args) => { let format_str_value = format_args.format_string.value(); @@ -191,6 +217,15 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { } } } + OsDisplayAttribute::FromDisplay => { + quote! { + impl #impl_generics my_os_traits::OsDisplay for #name #ty_generics #where_clause { + fn fmt_os(&self, f: &mut my_os_traits::OsStringFormatter) -> std::fmt::Result { + f.write_str(&self.to_string()) + } + } + } + } OsDisplayAttribute::Format(format_args) => { let format_str_value = format_args.format_string.value(); let positional_expressions: Vec<&Expr> = format_args.positional_args.iter().collect();