mod attr_args; extern crate proc_macro; use crate::attr_args::{OsDisplayAttribute}; use proc_macro::TokenStream; use quote::{quote, quote_spanned}; use syn::parse::Parse; use syn::spanned::Spanned; use syn::{Data, DeriveInput, Expr, Fields, Ident, LitStr, parse_macro_input}; use std::collections::HashMap; #[proc_macro_derive(OsDisplay, attributes(os_display))] pub fn os_display_derive(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let name = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let os_display_impl = match &input.data { Data::Enum(data_enum) => { let variant_arms: Vec<_> = data_enum .variants .iter() .map(|variant| { let variant_name = &variant.ident; let os_display_attr = variant .attrs .iter() .find(|attr| attr.path().is_ident("os_display")); let format_tokens = if let Some(attr) = os_display_attr { let parsed_attr = attr .parse_args_with(OsDisplayAttribute::parse) .expect("Failed to parse #[os_display] attribute arguments"); match parsed_attr { OsDisplayAttribute::Transparent => { match &variant.fields { Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { quote_spanned! {variant.span() => match self { #name::#variant_name(value) => value.fmt_os(f), _ => unreachable!(), } } } Fields::Named(fields) if fields.named.len() == 1 => { let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap(); quote_spanned! {variant.span() => match self { #name::#variant_name{#field_ident} => #field_ident.fmt_os(f), _ => unreachable!(), } } } _ => { return quote_spanned! {variant.span() => compile_error!("#[os_display(transparent)] can only be used on single-field enum variants."); }; } } } OsDisplayAttribute::Format(format_args) => { let format_str_value = format_args.format_string.value(); let positional_expressions: Vec<&Expr> = format_args.positional_args.iter().collect(); let named_expressions: HashMap = format_args.named_args .into_iter() .map(|(ident, expr)| (ident.to_string(), expr)) .collect(); parse_os_display_format_string( &format_str_value, &positional_expressions, &named_expressions, variant.span(), ) } } } else { let variant_name_str = format!("{variant_name}"); quote! { f.write_str(#variant_name_str)?; } }; match &variant.fields { Fields::Unit => { quote! { #name::#variant_name => { #format_tokens } } } Fields::Unnamed(fields) => { let field_idents: Vec = fields .unnamed .iter() .enumerate() .map(|(i, _)| Ident::new(&format!("_{i}"), variant.span())) .collect(); if let Some(attr) = os_display_attr { if let Ok(OsDisplayAttribute::Transparent) = attr.parse_args_with(OsDisplayAttribute::parse) { quote! { #name::#variant_name(value) => { #format_tokens } } } else { quote! { #name::#variant_name(#(#field_idents),*) => { #format_tokens } } } } else { quote! { #name::#variant_name(#(#field_idents),*) => { #format_tokens } } } } Fields::Named(fields) => { let field_idents: Vec = fields .named .iter() .map(|f| f.ident.as_ref().unwrap().clone()) .collect(); if let Some(attr) = os_display_attr { if let Ok(OsDisplayAttribute::Transparent) = attr.parse_args_with(OsDisplayAttribute::parse) { let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap(); quote! { #name::#variant_name{#field_ident} => { #format_tokens } } } else { quote! { #name::#variant_name{#(#field_idents),*} => { #format_tokens } } } } else { quote! { #name::#variant_name{#(#field_idents),*} => { #format_tokens } } } } } }) .collect(); quote! { impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { match self { #(#variant_arms),* } Ok(()) } } } } Data::Struct(data_struct) => { let os_display_attr = input .attrs .iter() .find(|attr| attr.path().is_ident("os_display")); if let Some(attr) = os_display_attr { let parsed_attr = attr .parse_args_with(OsDisplayAttribute::parse) .expect("Failed to parse #[os_display] attribute arguments"); match parsed_attr { OsDisplayAttribute::Transparent => { match &data_struct.fields { Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { quote! { impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { self.0.fmt_os(f) } } } } Fields::Named(fields) if fields.named.len() == 1 => { let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap(); quote! { impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { self.#field_ident.fmt_os(f) } } } } _ => { quote_spanned! {name.span() => compile_error!("#[os_display(transparent)] can only be used on single-field structs (newtypes)."); } } } } OsDisplayAttribute::Format(format_args) => { let format_str_value = format_args.format_string.value(); let positional_expressions: Vec<&Expr> = format_args.positional_args.iter().collect(); let named_expressions: HashMap = format_args.named_args .into_iter() .map(|(ident, expr)| (ident.to_string(), expr)) .collect(); let generated_code = parse_os_display_format_string( &format_str_value, &positional_expressions, &named_expressions, name.span(), ); let field_bindings = match &data_struct.fields { Fields::Named(fields) => { let idents: Vec<&Ident> = fields.named.iter().filter_map(|f| f.ident.as_ref()).collect(); quote! { let Self { #(#idents),* } = self; } }, Fields::Unnamed(fields) => { let idents: Vec = fields.unnamed.iter().enumerate().map(|(i, _)| Ident::new(&format!("_{i}"), name.span())).collect(); quote! { let Self(#(#idents),*) = self; } }, Fields::Unit => quote!{}, }; quote! { impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { #field_bindings #generated_code Ok(()) } } } } } } else { quote_spanned! {name.span() => compile_error!("OsDisplay derive macro is not yet implemented for structs without #[os_display] attribute. Consider adding #[os_display(transparent)] for newtypes or specifying a format string using #[os_display(\"...\")] syntax."); } } } Data::Union(_) => { quote_spanned! {name.span() => compile_error!("OsDisplay derive macro does not support unions"); } } }; os_display_impl.into() } fn parse_os_display_format_string( format_str: &str, positional_expressions: &[&syn::Expr], named_expressions: &HashMap, span: proc_macro2::Span, ) -> proc_macro2::TokenStream { let mut generated_code_parts = Vec::new(); let mut current_literal = String::new(); let mut chars = format_str.chars().peekable(); let mut positional_arg_index = 0; while let Some(c) = chars.next() { if c == '{' { if let Some('{') = chars.peek() { current_literal.push(c); chars.next(); } else { if !current_literal.is_empty() { let lit_str = LitStr::new(¤t_literal, span); generated_code_parts.push(quote! { f.write_str(#lit_str)?; }); current_literal.clear(); } let mut placeholder_content = String::new(); for p in &mut chars { if p == '}' { break; } placeholder_content.push(p); } let expr_to_format: syn::Expr = if placeholder_content.is_empty() { if let Some(expr_ref) = positional_expressions.get(positional_arg_index) { positional_arg_index += 1; (**expr_ref).clone() } else { return quote_spanned! {span => compile_error!("Not enough positional arguments for format string: missing argument for empty '{}' placeholder.", #placeholder_content); }; } } else { let parsed_ident_res: syn::Result = syn::parse_str(&placeholder_content); if let Ok(ident) = parsed_ident_res { if let Some(expr) = named_expressions.get(&ident.to_string()) { expr.clone() } else { match syn::parse_str(&placeholder_content) { Ok(e) => e, Err(e) => { let error_message = e.to_string(); return quote_spanned! {span => compile_error!(format!("Invalid placeholder content '{}'. Error: {}. Named arguments must be simple identifiers provided in the attribute, or full expressions.", #placeholder_content, #error_message)); }; } } } } else { match syn::parse_str(&placeholder_content) { Ok(e) => e, Err(e) => { let error_message = e.to_string(); return quote_spanned! {span => compile_error!(format!("Invalid expression in os_display attribute: {}. Error: {}", #placeholder_content, #error_message)); }; } } } }; generated_code_parts.push(quote! { (#expr_to_format).fmt_os(f)?; }); } } else if c == '}' { if let Some('}') = chars.peek() { current_literal.push(c); chars.next(); } else { return quote_spanned! {span => compile_error!("Mismatched closing brace `}}` in os_display attribute."); }; } } else { current_literal.push(c); } } if !current_literal.is_empty() { let lit_str = LitStr::new(¤t_literal, span); generated_code_parts.push(quote! { f.write_str(#lit_str)?; }); } if positional_arg_index < positional_expressions.len() { return quote_spanned! {span => compile_error!("Too many positional arguments for format string: unused arguments provided."); }; } quote! { #(#generated_code_parts)* } }