|
3 | 3 | // Use of this source code is governed by a BSD-style |
4 | 4 | // license that can be found in the LICENSE file. |
5 | 5 |
|
| 6 | +use parse_attrs::has_argh_attrs; |
6 | 7 | use syn::ext::IdentExt as _; |
7 | 8 |
|
8 | 9 | /// Implementation of the `FromArgs` and `argh(...)` derive attributes. |
@@ -34,6 +35,14 @@ pub fn argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { |
34 | 35 | gen.into() |
35 | 36 | } |
36 | 37 |
|
| 38 | +/// Entrypoint for `#[derive(FromArgValue)]`. |
| 39 | +#[proc_macro_derive(FromArgValue, attributes(argh))] |
| 40 | +pub fn argh_value_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { |
| 41 | + let ast = syn::parse_macro_input!(input as syn::DeriveInput); |
| 42 | + let gen = impl_from_arg_value(&ast); |
| 43 | + gen.into() |
| 44 | +} |
| 45 | + |
37 | 46 | /// Entrypoint for `#[derive(ArgsInfo)]`. |
38 | 47 | #[proc_macro_derive(ArgsInfo, attributes(argh))] |
39 | 48 | pub fn args_info_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { |
@@ -63,6 +72,25 @@ fn impl_from_args(input: &syn::DeriveInput) -> TokenStream { |
63 | 72 | output_tokens |
64 | 73 | } |
65 | 74 |
|
| 75 | +fn impl_from_arg_value(input: &syn::DeriveInput) -> TokenStream { |
| 76 | + let errors = &Errors::default(); |
| 77 | + let mut output_tokens = match &input.data { |
| 78 | + syn::Data::Enum(de) => impl_from_arg_value_enum(errors, &input.ident, &input.generics, de), |
| 79 | + _ => { |
| 80 | + errors.err(input, "`#[derive(FromArgValue)]` can only be applied to `enum`s"); |
| 81 | + TokenStream::new() |
| 82 | + } |
| 83 | + }; |
| 84 | + if has_argh_attrs(&input.attrs) { |
| 85 | + errors.err( |
| 86 | + &input.ident, |
| 87 | + "`#[derive(FromArgValue)]` `enum`s do not support `#[argh(...)]` attributes", |
| 88 | + ); |
| 89 | + } |
| 90 | + errors.to_tokens(&mut output_tokens); |
| 91 | + output_tokens |
| 92 | +} |
| 93 | + |
66 | 94 | /// The kind of optionality a parameter has. |
67 | 95 | enum Optionality { |
68 | 96 | None, |
@@ -1178,3 +1206,101 @@ fn enum_only_single_field_unnamed_variants<'a>( |
1178 | 1206 | } |
1179 | 1207 | } |
1180 | 1208 | } |
| 1209 | + |
| 1210 | +/// Implements `FromArgValue` for a `#![derive(FromArgValue)]` enum (a choice enum). |
| 1211 | +fn impl_from_arg_value_enum( |
| 1212 | + errors: &Errors, |
| 1213 | + name: &syn::Ident, |
| 1214 | + generic_args: &syn::Generics, |
| 1215 | + de: &syn::DataEnum, |
| 1216 | +) -> TokenStream { |
| 1217 | + // An enum variant like `<name>` |
| 1218 | + struct ChoiceVariant<'a> { |
| 1219 | + ident: &'a syn::Ident, |
| 1220 | + name: syn::LitStr, |
| 1221 | + } |
| 1222 | + |
| 1223 | + let variants: Vec<ChoiceVariant<'_>> = de |
| 1224 | + .variants |
| 1225 | + .iter() |
| 1226 | + .map(|variant| { |
| 1227 | + let ident = &variant.ident; |
| 1228 | + choice_enum_only_fieldless_variant(errors, &variant.fields); |
| 1229 | + let attrs = parse_attrs::ChoiceVariantAttrs::parse(errors, variant); |
| 1230 | + let name = match attrs.name_override { |
| 1231 | + Some(lit) => lit, |
| 1232 | + None => { |
| 1233 | + let name_str = pascal_to_snake_case(&format!("{}", ident)); |
| 1234 | + syn::LitStr::new(&name_str, ident.span()) |
| 1235 | + } |
| 1236 | + }; |
| 1237 | + ChoiceVariant { ident, name } |
| 1238 | + }) |
| 1239 | + .collect(); |
| 1240 | + |
| 1241 | + if variants.is_empty() { |
| 1242 | + errors.err(&de.variants, "Choice enums must have at least one variant"); |
| 1243 | + } |
| 1244 | + |
| 1245 | + let name_repeating = std::iter::repeat(name.clone()); |
| 1246 | + let variant_idents = variants.iter().map(|x| x.ident); |
| 1247 | + let variant_names = variants.iter().map(|x| &x.name).collect::<Vec<_>>(); |
| 1248 | + let err_literal = { |
| 1249 | + let mut err = "expected ".to_string(); |
| 1250 | + for (i, name) in variant_names.iter().enumerate() { |
| 1251 | + if i == 0 { |
| 1252 | + } else if i == variant_names.len() - 1 { |
| 1253 | + err.push_str(" or "); |
| 1254 | + } else { |
| 1255 | + err.push_str(", "); |
| 1256 | + } |
| 1257 | + err.push_str(&format!("{:?}", name.value())); |
| 1258 | + } |
| 1259 | + LitStr::new(&err, name.span()) |
| 1260 | + }; |
| 1261 | + let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl(); |
| 1262 | + quote! { |
| 1263 | + impl #impl_generics argh::FromArgValue for #name #ty_generics #where_clause { |
| 1264 | + fn from_arg_value(value: &str) |
| 1265 | + -> std::result::Result<Self, String> |
| 1266 | + { |
| 1267 | + Ok(match value { |
| 1268 | + #( |
| 1269 | + #variant_names => #name_repeating::#variant_idents, |
| 1270 | + )* |
| 1271 | + _ => { |
| 1272 | + return Err(#err_literal.to_owned()) |
| 1273 | + } |
| 1274 | + }) |
| 1275 | + } |
| 1276 | + } |
| 1277 | + } |
| 1278 | +} |
| 1279 | + |
| 1280 | +/// Generates an error if the variant is not a field-less variant like `Foo`. |
| 1281 | +fn choice_enum_only_fieldless_variant(errors: &Errors, variant_fields: &syn::Fields) { |
| 1282 | + match variant_fields { |
| 1283 | + syn::Fields::Unit => {} |
| 1284 | + _ => { |
| 1285 | + errors.err( |
| 1286 | + variant_fields, |
| 1287 | + "Choice `enum`s tagged with `#![derive(FromArgValue)]` do not support variants with associated data.", |
| 1288 | + ); |
| 1289 | + } |
| 1290 | + } |
| 1291 | +} |
| 1292 | + |
| 1293 | +fn pascal_to_snake_case(camel: &str) -> String { |
| 1294 | + let mut out = String::with_capacity(camel.len() + 8); |
| 1295 | + for (i, c) in camel.chars().enumerate() { |
| 1296 | + if c.is_uppercase() { |
| 1297 | + if i != 0 { |
| 1298 | + out.push('_'); |
| 1299 | + } |
| 1300 | + out.extend(c.to_lowercase()); |
| 1301 | + } else { |
| 1302 | + out.push(c); |
| 1303 | + } |
| 1304 | + } |
| 1305 | + out |
| 1306 | +} |
0 commit comments