Skip to content

Commit 7a74484

Browse files
committed
feat(ast_tools): allow attrs on types which do not derive the trait
1 parent 0a8303d commit 7a74484

File tree

2 files changed

+52
-26
lines changed

2 files changed

+52
-26
lines changed

tasks/ast_tools/src/parse/attr.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,25 @@ bitflags! {
3030
/// Positions in which an attribute is legal.
3131
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3232
pub struct AttrPositions: u8 {
33-
/// Attribute on a struct
33+
/// Attribute on a struct which derives the trait
3434
const Struct = 1 << 0;
35-
/// Attribute on an enum
36-
const Enum = 1 << 1;
35+
/// Attribute on a struct which doesn't derive the trait
36+
const StructNotDerived = 1 << 1;
37+
/// Attribute on an enum which derives the trait
38+
const Enum = 1 << 2;
39+
/// Attribute on an enum which doesn't derive the trait
40+
const EnumNotDerived = 1 << 3;
3741
/// Attribute on a struct field
38-
const StructField = 1 << 2;
42+
const StructField = 1 << 4;
3943
/// Attribute on an enum variant
40-
const EnumVariant = 1 << 3;
44+
const EnumVariant = 1 << 5;
4145
/// Part of `#[ast]` attr e.g. `visit` in `#[ast(visit)]`
42-
const AstAttr = 1 << 4;
46+
const AstAttr = 1 << 6;
47+
48+
/// Attribute on a struct which may or may not derive the trait
49+
const StructMaybeDerived = Self::Struct.bits() | Self::StructNotDerived.bits();
50+
/// Attribute on an enum which may or may not derive the trait
51+
const EnumMaybeDerived = Self::Enum.bits() | Self::EnumNotDerived.bits();
4352
}
4453
}
4554

tasks/ast_tools/src/parse/parse.rs

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -503,42 +503,59 @@ impl<'c> Parser<'c> {
503503
}
504504

505505
if let Some((processor, positions)) = self.codegen.attr_processor(&attr_name) {
506-
// Check attribute is legal in this position
507-
match type_def {
506+
// Check attribute is legal in this position and this type has the relevant trait
507+
// `#[generate_derive]`-ed on it (unless the derive stated legal positions as
508+
// `AttrPositions::StructNotDerived` or `AttrPositions::EnumNotDerived`)
509+
let location = match type_def {
508510
TypeDef::Struct(struct_def) => {
511+
let found_in_positions = match processor {
512+
AttrProcessor::Derive(derive_id) => {
513+
let is_derived = struct_def.generates_derive(derive_id);
514+
if is_derived {
515+
AttrPositions::Struct
516+
} else {
517+
AttrPositions::StructNotDerived
518+
}
519+
}
520+
AttrProcessor::Generator(_) => AttrPositions::StructMaybeDerived,
521+
};
522+
509523
check_attr_position(
510524
positions,
511-
AttrPositions::Struct,
525+
found_in_positions,
512526
struct_def.name(),
513527
&attr_name,
514528
"struct",
515529
);
530+
531+
AttrLocation::Struct(struct_def)
516532
}
517533
TypeDef::Enum(enum_def) => {
534+
let found_in_positions = match processor {
535+
AttrProcessor::Derive(derive_id) => {
536+
let is_derived = enum_def.generates_derive(derive_id);
537+
if is_derived {
538+
AttrPositions::Enum
539+
} else {
540+
AttrPositions::EnumNotDerived
541+
}
542+
}
543+
AttrProcessor::Generator(_) => AttrPositions::EnumMaybeDerived,
544+
};
545+
518546
check_attr_position(
519547
positions,
520-
AttrPositions::Enum,
548+
found_in_positions,
521549
enum_def.name(),
522550
&attr_name,
523551
"enum",
524552
);
525-
}
526-
_ => unreachable!(),
527-
}
528-
529-
// Check this type has the relevant trait `#[generate_derive]`-ed on it
530-
check_attr_is_derived(
531-
processor,
532-
type_def.generated_derives(),
533-
type_def.name(),
534-
&attr_name,
535-
);
536553

537-
let location = match type_def {
538-
TypeDef::Struct(struct_def) => AttrLocation::Struct(struct_def),
539-
TypeDef::Enum(enum_def) => AttrLocation::Enum(enum_def),
554+
AttrLocation::Enum(enum_def)
555+
}
540556
_ => unreachable!(),
541557
};
558+
542559
let result = process_attr(processor, &attr_name, location, &attr.meta);
543560
assert!(
544561
result.is_ok(),
@@ -770,13 +787,13 @@ fn check_attr_is_derived(
770787
/// Check attribute is in a legal position.
771788
fn check_attr_position(
772789
expected_positions: AttrPositions,
773-
found_in_position: AttrPositions,
790+
found_in_positions: AttrPositions,
774791
type_name: &str,
775792
attr_name: &str,
776793
position_debug_str: &str,
777794
) {
778795
assert!(
779-
expected_positions.contains(found_in_position),
796+
expected_positions.intersects(found_in_positions),
780797
"`{type_name}` type has `#[{attr_name}]` attribute on a {position_debug_str}, \
781798
but `#[{attr_name}]` is not legal in this position."
782799
);

0 commit comments

Comments
 (0)