swc_eq_ignore_macros/
lib.rs

1use proc_macro2::Span;
2use quote::quote;
3use syn::{
4    parse, parse_quote, punctuated::Punctuated, spanned::Spanned, Arm, BinOp, Block, Data,
5    DeriveInput, Expr, ExprBinary, ExprBlock, Field, FieldPat, Fields, Ident, Index, Member, Pat,
6    PatIdent, PatRest, PatStruct, PatTuple, Path, Stmt, Token,
7};
8
9/// Derives `swc_common::TypeEq`.
10///
11/// - Field annotated with `#[use_eq]` will be compared using `==`.
12/// - Field annotated with `#[not_type]` will be ignored
13#[proc_macro_derive(TypeEq, attributes(not_type, use_eq, use_eq_ignore_span))]
14pub fn derive_type_eq(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
15    Deriver {
16        trait_name: Ident::new("TypeEq", Span::call_site()),
17        method_name: Ident::new("type_eq", Span::call_site()),
18        ignore_field: Box::new(|field| {
19            // Search for `#[not_type]`.
20            for attr in &field.attrs {
21                if attr.path().is_ident("not_type") {
22                    return true;
23                }
24            }
25
26            false
27        }),
28    }
29    .derive(item)
30}
31
32/// Derives `swc_common::EqIgnoreSpan`.
33///
34///
35/// Fields annotated with `#[not_spanned]` or `#[use_eq]` will use` ==` instead
36/// of `eq_ignore_span`.
37#[proc_macro_derive(EqIgnoreSpan, attributes(not_spanned, use_eq))]
38pub fn derive_eq_ignore_span(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
39    Deriver {
40        trait_name: Ident::new("EqIgnoreSpan", Span::call_site()),
41        method_name: Ident::new("eq_ignore_span", Span::call_site()),
42        ignore_field: Box::new(|_field| {
43            // We call eq_ignore_span for all fields.
44            false
45        }),
46    }
47    .derive(item)
48}
49
50struct Deriver {
51    trait_name: Ident,
52    method_name: Ident,
53    ignore_field: Box<dyn Fn(&Field) -> bool>,
54}
55
56impl Deriver {
57    fn derive(&self, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
58        let input: DeriveInput = parse(item).unwrap();
59
60        let body = self.make_body(&input.data);
61
62        let trait_name = &self.trait_name;
63        let ty = &input.ident;
64        let method_name = &self.method_name;
65        quote!(
66            #[automatically_derived]
67            impl ::swc_common::#trait_name for #ty {
68                #[allow(non_snake_case)]
69                fn #method_name(&self, other: &Self) -> bool {
70                    #body
71                }
72            }
73        )
74        .into()
75    }
76
77    fn make_body(&self, data: &Data) -> Expr {
78        match data {
79            Data::Struct(s) => {
80                let arm = self.make_arm_from_fields(parse_quote!(Self), &s.fields);
81
82                parse_quote!(match (self, other) { #arm })
83            }
84            Data::Enum(e) => {
85                //
86                let mut arms = Punctuated::<_, Token![,]>::default();
87                for v in &e.variants {
88                    let vi = &v.ident;
89                    let arm = self.make_arm_from_fields(parse_quote!(Self::#vi), &v.fields);
90
91                    arms.push(arm);
92                }
93
94                arms.push(parse_quote!(_ => false));
95
96                parse_quote!(match (self, other) { #arms })
97            }
98            Data::Union(_) => unimplemented!("union"),
99        }
100    }
101
102    fn make_arm_from_fields(&self, pat_path: Path, fields: &Fields) -> Arm {
103        let mut l_pat_fields = Punctuated::<_, Token![,]>::default();
104        let mut r_pat_fields = Punctuated::<_, Token![,]>::default();
105        let mut exprs = Vec::new();
106
107        for (i, field) in fields
108            .iter()
109            .enumerate()
110            .filter(|(_, f)| !(self.ignore_field)(f))
111        {
112            let method_name =
113                if field.attrs.iter().any(|attr| {
114                    attr.path().is_ident("not_spanned") || attr.path().is_ident("use_eq")
115                }) {
116                    Ident::new("eq", Span::call_site())
117                } else if field
118                    .attrs
119                    .iter()
120                    .any(|attr| attr.path().is_ident("use_eq_ignore_span"))
121                {
122                    Ident::new("eq_ignore_span", Span::call_site())
123                } else {
124                    self.method_name.clone()
125                };
126
127            let base = field
128                .ident
129                .clone()
130                .unwrap_or_else(|| Ident::new(&format!("_{}", i), field.ty.span()));
131            //
132            let l_binding_ident = Ident::new(&format!("_l_{}", base), base.span());
133            let r_binding_ident = Ident::new(&format!("_r_{}", base), base.span());
134
135            let make_pat_field = |ident: &Ident| FieldPat {
136                attrs: Default::default(),
137                member: match &field.ident {
138                    Some(v) => Member::Named(v.clone()),
139                    None => Member::Unnamed(Index {
140                        index: i as _,
141                        span: field.ty.span(),
142                    }),
143                },
144                colon_token: Some(Token![:](ident.span())),
145                pat: Box::new(Pat::Ident(PatIdent {
146                    attrs: Default::default(),
147                    by_ref: Some(Token![ref](ident.span())),
148                    mutability: None,
149                    ident: ident.clone(),
150                    subpat: None,
151                })),
152            };
153
154            l_pat_fields.push(make_pat_field(&l_binding_ident));
155            r_pat_fields.push(make_pat_field(&r_binding_ident));
156
157            exprs.push(parse_quote!(#l_binding_ident.#method_name(#r_binding_ident)));
158        }
159
160        // true && a.type_eq(&other.a) && b.type_eq(&other.b)
161        let mut expr: Expr = parse_quote!(true);
162
163        for expr_el in exprs {
164            expr = Expr::Binary(ExprBinary {
165                attrs: Default::default(),
166                left: Box::new(expr),
167                op: BinOp::And(Token![&&](Span::call_site())),
168                right: Box::new(expr_el),
169            });
170        }
171
172        Arm {
173            attrs: Default::default(),
174            pat: Pat::Tuple(PatTuple {
175                attrs: Default::default(),
176                paren_token: Default::default(),
177                elems: {
178                    let mut elems = Punctuated::default();
179                    elems.push(Pat::Struct(PatStruct {
180                        attrs: Default::default(),
181                        qself: None,
182                        path: pat_path.clone(),
183                        brace_token: Default::default(),
184                        fields: l_pat_fields,
185                        rest: Some(PatRest {
186                            attrs: Default::default(),
187                            dot2_token: Token![..](Span::call_site()),
188                        }),
189                    }));
190                    elems.push(Pat::Struct(PatStruct {
191                        attrs: Default::default(),
192                        qself: None,
193                        path: pat_path,
194                        brace_token: Default::default(),
195                        fields: r_pat_fields,
196                        rest: Some(PatRest {
197                            attrs: Default::default(),
198                            dot2_token: Token![..](Span::call_site()),
199                        }),
200                    }));
201                    elems
202                },
203            }),
204            guard: Default::default(),
205            fat_arrow_token: Token![=>](Span::call_site()),
206            body: Box::new(Expr::Block(ExprBlock {
207                attrs: Default::default(),
208                label: Default::default(),
209                block: Block {
210                    brace_token: Default::default(),
211                    stmts: vec![Stmt::Expr(expr, None)],
212                },
213            })),
214            comma: Default::default(),
215        }
216    }
217}