diff --git a/crates/osstr_traits/src/impls.rs b/crates/osstr_traits/src/impls.rs index f8ab5d3..ba3d4b5 100644 --- a/crates/osstr_traits/src/impls.rs +++ b/crates/osstr_traits/src/impls.rs @@ -66,7 +66,7 @@ impl_borrowed!(std::ffi::OsStr); impl OsDisplay for std::path::Path { fn fmt_os(&self, f: &mut OsStringFormatter) -> std::fmt::Result { - f.write_os_str(&self.as_os_str()) + f.write_os_str(self.as_os_str()) } } @@ -74,7 +74,7 @@ impl_borrowed!(std::path::Path); impl OsDisplay for std::path::PathBuf { fn fmt_os(&self, f: &mut OsStringFormatter) -> std::fmt::Result { - f.write_os_str(&self.as_os_str()) + f.write_os_str(self.as_os_str()) } } diff --git a/crates/osstr_traits_derive/src/lib.rs b/crates/osstr_traits_derive/src/lib.rs index 9e1a77d..e5d97df 100644 --- a/crates/osstr_traits_derive/src/lib.rs +++ b/crates/osstr_traits_derive/src/lib.rs @@ -5,10 +5,10 @@ extern crate proc_macro; use crate::attr_args::OsDisplayAttribute; use proc_macro::TokenStream; use quote::{quote, quote_spanned}; +use std::collections::HashMap; use syn::parse::Parse; use syn::spanned::Spanned; use syn::{Data, DeriveInput, Expr, Fields, Ident, LitStr, parse_macro_input}; -use std::collections::HashMap; #[proc_macro_derive(OsDisplay, attributes(os_display))] pub fn os_display_derive(input: TokenStream) -> TokenStream { @@ -19,30 +19,54 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { let os_display_impl = match &input.data { Data::Enum(data_enum) => { + let enum_os_display_attr = input + .attrs + .iter() + .find(|attr| attr.path().is_ident("os_display")); + + let top_level_from_display = if let Some(attr) = enum_os_display_attr { + let parsed_attr = attr + .parse_args_with(OsDisplayAttribute::parse) + .expect("Failed to parse top-level #[os_display] attribute arguments"); + + match parsed_attr { + OsDisplayAttribute::FromDisplay => true, + _ => { + return quote_spanned! {name.span() => + compile_error!("Only #[os_display(from_display)] is currently supported at the enum level for defaulting variant behavior."); + }.into(); + } + } + } else { + false + }; + let variant_arms: Vec<_> = data_enum .variants .iter() .map(|variant| { let variant_name = &variant.ident; - let os_display_attr = variant + let variant_os_display_attr = variant .attrs .iter() .find(|attr| attr.path().is_ident("os_display")); - let format_tokens = if let Some(attr) = os_display_attr { + + let format_tokens = if let Some(attr) = variant_os_display_attr { + let parsed_attr = attr .parse_args_with(OsDisplayAttribute::parse) - .expect("Failed to parse #[os_display] attribute arguments"); + .expect("Failed to parse #[os_display] attribute arguments for variant"); match parsed_attr { OsDisplayAttribute::Transparent => { - + match &variant.fields { Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { quote_spanned! {variant.span() => match self { #name::#variant_name(value) => value.fmt_os(f), - _ => unreachable!(), + _ => unreachable!(), } } } @@ -51,7 +75,7 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { quote_spanned! {variant.span() => match self { #name::#variant_name{#field_ident} => #field_ident.fmt_os(f), - _ => unreachable!(), + _ => unreachable!(), } } } @@ -62,15 +86,13 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { } } } - OsDisplayAttribute::FromDisplay => { - - + OsDisplayAttribute::FromDisplay => { match &variant.fields { Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { quote_spanned! {variant.span() => match self { #name::#variant_name(value) => f.write_str(&value.to_string()), - _ => unreachable!(), + _ => unreachable!(), } } } @@ -79,7 +101,7 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { quote_spanned! {variant.span() => match self { #name::#variant_name{#field_ident} => f.write_str(&#field_ident.to_string()), - _ => unreachable!(), + _ => unreachable!(), } } } @@ -91,11 +113,8 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { } } OsDisplayAttribute::Format(format_args) => { - let format_str_value = format_args.format_string.value(); - let positional_expressions: Vec<&Expr> = format_args.positional_args.iter().collect(); - let named_expressions: HashMap = format_args.named_args .into_iter() .map(|(ident, expr)| (ident.to_string(), expr)) @@ -109,13 +128,36 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { ) } } + } else if top_level_from_display { + match &variant.fields { + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + quote_spanned! {variant.span() => + match self { + #name::#variant_name(value) => f.write_str(&value.to_string()), + _ => unreachable!(), + } + } + } + Fields::Named(fields) if fields.named.len() == 1 => { + let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap(); + quote_spanned! {variant.span() => + match self { + #name::#variant_name{#field_ident} => f.write_str(&#field_ident.to_string()), + _ => unreachable!(), + } + } + } + _ => { + return quote_spanned! {variant.span() => + compile_error!("Enum has #[os_display(from_display)], but variant `#variant_name` is not a single-field variant that can inherit `from_display` behavior. Specify an explicit #[os_display] for this variant or ensure it has a single field."); + }; + } + } } else { - let variant_name_str = format!("{variant_name}"); quote! { f.write_str(#variant_name_str)?; } }; - match &variant.fields { Fields::Unit => { quote! { @@ -130,21 +172,22 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { .map(|(i, _)| Ident::new(&format!("_{i}"), variant.span())) .collect(); - - let should_capture_value = if let Some(attr) = os_display_attr { + let should_capture_value = if let Some(attr) = variant_os_display_attr { if let Ok(parsed) = attr.parse_args_with(OsDisplayAttribute::parse) { matches!(parsed, OsDisplayAttribute::Transparent | OsDisplayAttribute::FromDisplay) } else { false } - } else { false }; + } else { + + top_level_from_display && fields.unnamed.len() == 1 + }; if should_capture_value { - - + quote! { #name::#variant_name(value) => { #format_tokens } } } else { - + quote! { #name::#variant_name(#(#field_idents),*) => { #format_tokens } } @@ -157,20 +200,23 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { .map(|f| f.ident.as_ref().unwrap().clone()) .collect(); - let should_capture_value = if let Some(attr) = os_display_attr { + let should_capture_value = if let Some(attr) = variant_os_display_attr { if let Ok(parsed) = attr.parse_args_with(OsDisplayAttribute::parse) { matches!(parsed, OsDisplayAttribute::Transparent | OsDisplayAttribute::FromDisplay) } else { false } - } else { false }; + } else { + + top_level_from_display && fields.named.len() == 1 + }; if should_capture_value { - + let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap(); quote! { #name::#variant_name{#field_ident} => { #format_tokens } } } else { - + quote! { #name::#variant_name{#(#field_idents),*} => { #format_tokens } } @@ -178,7 +224,7 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { } } }) - .collect(); + .collect(); quote! { impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { @@ -203,56 +249,51 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { .expect("Failed to parse #[os_display] attribute arguments"); match parsed_attr { - OsDisplayAttribute::Transparent => { - - match &data_struct.fields { - Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { - quote! { - impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { - fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { - self.0.fmt_os(f) - } + OsDisplayAttribute::Transparent => match &data_struct.fields { + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + quote! { + impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { + fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { + self.0.fmt_os(f) } } } - Fields::Named(fields) if fields.named.len() == 1 => { - let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap(); - quote! { - impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { - fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { - self.#field_ident.fmt_os(f) - } - } - } - } - _ => { - quote_spanned! {name.span() => - compile_error!("#[os_display(transparent)] can only be used on single-field structs (newtypes)."); - } - } } - } - OsDisplayAttribute::FromDisplay => { - - + Fields::Named(fields) if fields.named.len() == 1 => { + let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap(); + quote! { + impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { + fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { + self.#field_ident.fmt_os(f) + } + } + } + } + _ => { + quote_spanned! {name.span() => + compile_error!("#[os_display(transparent)] can only be used on single-field structs (newtypes)."); + } + } + }, + OsDisplayAttribute::FromDisplay => { quote! { impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { - f.write_str(&self.to_string()) } } } } OsDisplayAttribute::Format(format_args) => { - let format_str_value = format_args.format_string.value(); - let positional_expressions: Vec<&Expr> = format_args.positional_args.iter().collect(); - let named_expressions: HashMap = format_args.named_args + let positional_expressions: Vec<&Expr> = + format_args.positional_args.iter().collect(); + let named_expressions: HashMap = format_args + .named_args .into_iter() .map(|(ident, expr)| (ident.to_string(), expr)) .collect(); - + let generated_code = parse_os_display_format_string( &format_str_value, &positional_expressions, @@ -262,18 +303,27 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { let field_bindings = match &data_struct.fields { Fields::Named(fields) => { - let idents: Vec<&Ident> = fields.named.iter().filter_map(|f| f.ident.as_ref()).collect(); + let idents: Vec<&Ident> = fields + .named + .iter() + .filter_map(|f| f.ident.as_ref()) + .collect(); quote! { let Self { #(#idents),* } = self; } - }, + } Fields::Unnamed(fields) => { - let idents: Vec = fields.unnamed.iter().enumerate().map(|(i, _)| Ident::new(&format!("_{i}"), name.span())).collect(); + let idents: Vec = fields + .unnamed + .iter() + .enumerate() + .map(|(i, _)| Ident::new(&format!("_{i}"), name.span())) + .collect(); quote! { let Self(#(#idents),*) = self; } - }, - Fields::Unit => quote!{}, + } + Fields::Unit => quote! {}, }; quote! { @@ -288,7 +338,6 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { } } } else { - quote_spanned! {name.span() => compile_error!("OsDisplay derive macro is not yet implemented for structs without an #[os_display] attribute. Consider adding #[os_display(transparent)], #[os_display(from_display)], or specifying a format string using #[os_display(\"...\")] syntax."); } @@ -304,7 +353,6 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream { os_display_impl.into() } - fn parse_os_display_format_string( format_str: &str, positional_expressions: &[&syn::Expr], @@ -329,7 +377,7 @@ fn parse_os_display_format_string( } let mut placeholder_content = String::new(); - for p in &mut chars { + while let Some(p) = chars.next() { if p == '}' { break; }