add os_display overrides

This commit is contained in:
Rowan 2025-07-09 14:23:34 -04:00
parent a6f13837e1
commit 13f63b29a7
2 changed files with 119 additions and 71 deletions

View file

@ -66,7 +66,7 @@ impl_borrowed!(std::ffi::OsStr);
impl OsDisplay for std::path::Path { impl OsDisplay for std::path::Path {
fn fmt_os(&self, f: &mut OsStringFormatter) -> std::fmt::Result { fn fmt_os(&self, f: &mut OsStringFormatter) -> std::fmt::Result {
f.write_os_str(&self.as_os_str()) f.write_os_str(self.as_os_str())
} }
} }
@ -74,7 +74,7 @@ impl_borrowed!(std::path::Path);
impl OsDisplay for std::path::PathBuf { impl OsDisplay for std::path::PathBuf {
fn fmt_os(&self, f: &mut OsStringFormatter) -> std::fmt::Result { fn fmt_os(&self, f: &mut OsStringFormatter) -> std::fmt::Result {
f.write_os_str(&self.as_os_str()) f.write_os_str(self.as_os_str())
} }
} }

View file

@ -5,10 +5,10 @@ extern crate proc_macro;
use crate::attr_args::OsDisplayAttribute; use crate::attr_args::OsDisplayAttribute;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::{quote, quote_spanned}; use quote::{quote, quote_spanned};
use std::collections::HashMap;
use syn::parse::Parse; use syn::parse::Parse;
use syn::spanned::Spanned; use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Expr, Fields, Ident, LitStr, parse_macro_input}; use syn::{Data, DeriveInput, Expr, Fields, Ident, LitStr, parse_macro_input};
use std::collections::HashMap;
#[proc_macro_derive(OsDisplay, attributes(os_display))] #[proc_macro_derive(OsDisplay, attributes(os_display))]
pub fn os_display_derive(input: TokenStream) -> TokenStream { pub fn os_display_derive(input: TokenStream) -> TokenStream {
@ -19,30 +19,54 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
let os_display_impl = match &input.data { let os_display_impl = match &input.data {
Data::Enum(data_enum) => { Data::Enum(data_enum) => {
let enum_os_display_attr = input
.attrs
.iter()
.find(|attr| attr.path().is_ident("os_display"));
let top_level_from_display = if let Some(attr) = enum_os_display_attr {
let parsed_attr = attr
.parse_args_with(OsDisplayAttribute::parse)
.expect("Failed to parse top-level #[os_display] attribute arguments");
match parsed_attr {
OsDisplayAttribute::FromDisplay => true,
_ => {
return quote_spanned! {name.span() =>
compile_error!("Only #[os_display(from_display)] is currently supported at the enum level for defaulting variant behavior.");
}.into();
}
}
} else {
false
};
let variant_arms: Vec<_> = data_enum let variant_arms: Vec<_> = data_enum
.variants .variants
.iter() .iter()
.map(|variant| { .map(|variant| {
let variant_name = &variant.ident; let variant_name = &variant.ident;
let os_display_attr = variant let variant_os_display_attr = variant
.attrs .attrs
.iter() .iter()
.find(|attr| attr.path().is_ident("os_display")); .find(|attr| attr.path().is_ident("os_display"));
let format_tokens = if let Some(attr) = os_display_attr {
let format_tokens = if let Some(attr) = variant_os_display_attr {
let parsed_attr = attr let parsed_attr = attr
.parse_args_with(OsDisplayAttribute::parse) .parse_args_with(OsDisplayAttribute::parse)
.expect("Failed to parse #[os_display] attribute arguments"); .expect("Failed to parse #[os_display] attribute arguments for variant");
match parsed_attr { match parsed_attr {
OsDisplayAttribute::Transparent => { OsDisplayAttribute::Transparent => {
match &variant.fields { match &variant.fields {
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
quote_spanned! {variant.span() => quote_spanned! {variant.span() =>
match self { match self {
#name::#variant_name(value) => value.fmt_os(f), #name::#variant_name(value) => value.fmt_os(f),
_ => unreachable!(), _ => unreachable!(),
} }
} }
} }
@ -51,7 +75,7 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
quote_spanned! {variant.span() => quote_spanned! {variant.span() =>
match self { match self {
#name::#variant_name{#field_ident} => #field_ident.fmt_os(f), #name::#variant_name{#field_ident} => #field_ident.fmt_os(f),
_ => unreachable!(), _ => unreachable!(),
} }
} }
} }
@ -62,15 +86,13 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
} }
} }
} }
OsDisplayAttribute::FromDisplay => { OsDisplayAttribute::FromDisplay => {
match &variant.fields { match &variant.fields {
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
quote_spanned! {variant.span() => quote_spanned! {variant.span() =>
match self { match self {
#name::#variant_name(value) => f.write_str(&value.to_string()), #name::#variant_name(value) => f.write_str(&value.to_string()),
_ => unreachable!(), _ => unreachable!(),
} }
} }
} }
@ -79,7 +101,7 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
quote_spanned! {variant.span() => quote_spanned! {variant.span() =>
match self { match self {
#name::#variant_name{#field_ident} => f.write_str(&#field_ident.to_string()), #name::#variant_name{#field_ident} => f.write_str(&#field_ident.to_string()),
_ => unreachable!(), _ => unreachable!(),
} }
} }
} }
@ -91,11 +113,8 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
} }
} }
OsDisplayAttribute::Format(format_args) => { OsDisplayAttribute::Format(format_args) => {
let format_str_value = format_args.format_string.value(); let format_str_value = format_args.format_string.value();
let positional_expressions: Vec<&Expr> = format_args.positional_args.iter().collect(); let positional_expressions: Vec<&Expr> = format_args.positional_args.iter().collect();
let named_expressions: HashMap<String, Expr> = format_args.named_args let named_expressions: HashMap<String, Expr> = format_args.named_args
.into_iter() .into_iter()
.map(|(ident, expr)| (ident.to_string(), expr)) .map(|(ident, expr)| (ident.to_string(), expr))
@ -109,13 +128,36 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
) )
} }
} }
} else if top_level_from_display {
match &variant.fields {
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
quote_spanned! {variant.span() =>
match self {
#name::#variant_name(value) => f.write_str(&value.to_string()),
_ => unreachable!(),
}
}
}
Fields::Named(fields) if fields.named.len() == 1 => {
let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap();
quote_spanned! {variant.span() =>
match self {
#name::#variant_name{#field_ident} => f.write_str(&#field_ident.to_string()),
_ => unreachable!(),
}
}
}
_ => {
return quote_spanned! {variant.span() =>
compile_error!("Enum has #[os_display(from_display)], but variant `#variant_name` is not a single-field variant that can inherit `from_display` behavior. Specify an explicit #[os_display] for this variant or ensure it has a single field.");
};
}
}
} else { } else {
let variant_name_str = format!("{variant_name}"); let variant_name_str = format!("{variant_name}");
quote! { f.write_str(#variant_name_str)?; } quote! { f.write_str(#variant_name_str)?; }
}; };
match &variant.fields { match &variant.fields {
Fields::Unit => { Fields::Unit => {
quote! { quote! {
@ -130,21 +172,22 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
.map(|(i, _)| Ident::new(&format!("_{i}"), variant.span())) .map(|(i, _)| Ident::new(&format!("_{i}"), variant.span()))
.collect(); .collect();
let should_capture_value = if let Some(attr) = variant_os_display_attr {
let should_capture_value = if let Some(attr) = os_display_attr {
if let Ok(parsed) = attr.parse_args_with(OsDisplayAttribute::parse) { if let Ok(parsed) = attr.parse_args_with(OsDisplayAttribute::parse) {
matches!(parsed, OsDisplayAttribute::Transparent | OsDisplayAttribute::FromDisplay) matches!(parsed, OsDisplayAttribute::Transparent | OsDisplayAttribute::FromDisplay)
} else { false } } else { false }
} else { false }; } else {
top_level_from_display && fields.unnamed.len() == 1
};
if should_capture_value { if should_capture_value {
quote! { quote! {
#name::#variant_name(value) => { #format_tokens } #name::#variant_name(value) => { #format_tokens }
} }
} else { } else {
quote! { quote! {
#name::#variant_name(#(#field_idents),*) => { #format_tokens } #name::#variant_name(#(#field_idents),*) => { #format_tokens }
} }
@ -157,20 +200,23 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
.map(|f| f.ident.as_ref().unwrap().clone()) .map(|f| f.ident.as_ref().unwrap().clone())
.collect(); .collect();
let should_capture_value = if let Some(attr) = os_display_attr { let should_capture_value = if let Some(attr) = variant_os_display_attr {
if let Ok(parsed) = attr.parse_args_with(OsDisplayAttribute::parse) { if let Ok(parsed) = attr.parse_args_with(OsDisplayAttribute::parse) {
matches!(parsed, OsDisplayAttribute::Transparent | OsDisplayAttribute::FromDisplay) matches!(parsed, OsDisplayAttribute::Transparent | OsDisplayAttribute::FromDisplay)
} else { false } } else { false }
} else { false }; } else {
top_level_from_display && fields.named.len() == 1
};
if should_capture_value { if should_capture_value {
let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap(); let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap();
quote! { quote! {
#name::#variant_name{#field_ident} => { #format_tokens } #name::#variant_name{#field_ident} => { #format_tokens }
} }
} else { } else {
quote! { quote! {
#name::#variant_name{#(#field_idents),*} => { #format_tokens } #name::#variant_name{#(#field_idents),*} => { #format_tokens }
} }
@ -178,7 +224,7 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
} }
} }
}) })
.collect(); .collect();
quote! { quote! {
impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause {
@ -203,56 +249,51 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
.expect("Failed to parse #[os_display] attribute arguments"); .expect("Failed to parse #[os_display] attribute arguments");
match parsed_attr { match parsed_attr {
OsDisplayAttribute::Transparent => { OsDisplayAttribute::Transparent => match &data_struct.fields {
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
match &data_struct.fields { quote! {
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause {
quote! { fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result {
impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { self.0.fmt_os(f)
fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result {
self.0.fmt_os(f)
}
} }
} }
} }
Fields::Named(fields) if fields.named.len() == 1 => {
let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap();
quote! {
impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause {
fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result {
self.#field_ident.fmt_os(f)
}
}
}
}
_ => {
quote_spanned! {name.span() =>
compile_error!("#[os_display(transparent)] can only be used on single-field structs (newtypes).");
}
}
} }
} Fields::Named(fields) if fields.named.len() == 1 => {
OsDisplayAttribute::FromDisplay => { let field_ident = fields.named.first().unwrap().ident.as_ref().unwrap();
quote! {
impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause {
fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result {
self.#field_ident.fmt_os(f)
}
}
}
}
_ => {
quote_spanned! {name.span() =>
compile_error!("#[os_display(transparent)] can only be used on single-field structs (newtypes).");
}
}
},
OsDisplayAttribute::FromDisplay => {
quote! { quote! {
impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause { impl #impl_generics osstr_traits::OsDisplay for #name #ty_generics #where_clause {
fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result { fn fmt_os(&self, f: &mut osstr_traits::OsStringFormatter) -> std::fmt::Result {
f.write_str(&self.to_string()) f.write_str(&self.to_string())
} }
} }
} }
} }
OsDisplayAttribute::Format(format_args) => { OsDisplayAttribute::Format(format_args) => {
let format_str_value = format_args.format_string.value(); let format_str_value = format_args.format_string.value();
let positional_expressions: Vec<&Expr> = format_args.positional_args.iter().collect(); let positional_expressions: Vec<&Expr> =
let named_expressions: HashMap<String, Expr> = format_args.named_args format_args.positional_args.iter().collect();
let named_expressions: HashMap<String, Expr> = format_args
.named_args
.into_iter() .into_iter()
.map(|(ident, expr)| (ident.to_string(), expr)) .map(|(ident, expr)| (ident.to_string(), expr))
.collect(); .collect();
let generated_code = parse_os_display_format_string( let generated_code = parse_os_display_format_string(
&format_str_value, &format_str_value,
&positional_expressions, &positional_expressions,
@ -262,18 +303,27 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
let field_bindings = match &data_struct.fields { let field_bindings = match &data_struct.fields {
Fields::Named(fields) => { Fields::Named(fields) => {
let idents: Vec<&Ident> = fields.named.iter().filter_map(|f| f.ident.as_ref()).collect(); let idents: Vec<&Ident> = fields
.named
.iter()
.filter_map(|f| f.ident.as_ref())
.collect();
quote! { quote! {
let Self { #(#idents),* } = self; let Self { #(#idents),* } = self;
} }
}, }
Fields::Unnamed(fields) => { Fields::Unnamed(fields) => {
let idents: Vec<Ident> = fields.unnamed.iter().enumerate().map(|(i, _)| Ident::new(&format!("_{i}"), name.span())).collect(); let idents: Vec<Ident> = fields
.unnamed
.iter()
.enumerate()
.map(|(i, _)| Ident::new(&format!("_{i}"), name.span()))
.collect();
quote! { quote! {
let Self(#(#idents),*) = self; let Self(#(#idents),*) = self;
} }
}, }
Fields::Unit => quote!{}, Fields::Unit => quote! {},
}; };
quote! { quote! {
@ -288,7 +338,6 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
} }
} }
} else { } else {
quote_spanned! {name.span() => quote_spanned! {name.span() =>
compile_error!("OsDisplay derive macro is not yet implemented for structs without an #[os_display] attribute. Consider adding #[os_display(transparent)], #[os_display(from_display)], or specifying a format string using #[os_display(\"...\")] syntax."); compile_error!("OsDisplay derive macro is not yet implemented for structs without an #[os_display] attribute. Consider adding #[os_display(transparent)], #[os_display(from_display)], or specifying a format string using #[os_display(\"...\")] syntax.");
} }
@ -304,7 +353,6 @@ pub fn os_display_derive(input: TokenStream) -> TokenStream {
os_display_impl.into() os_display_impl.into()
} }
fn parse_os_display_format_string( fn parse_os_display_format_string(
format_str: &str, format_str: &str,
positional_expressions: &[&syn::Expr], positional_expressions: &[&syn::Expr],
@ -329,7 +377,7 @@ fn parse_os_display_format_string(
} }
let mut placeholder_content = String::new(); let mut placeholder_content = String::new();
for p in &mut chars { while let Some(p) = chars.next() {
if p == '}' { if p == '}' {
break; break;
} }