query_group_macro/
lib.rs

1//! A macro that mimics the old Salsa-style `#[query_group]` macro.
2
3use core::fmt;
4use std::vec;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use queries::{
9    GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup,
10    Queries, SetterKind, TrackedQuery, Transparent,
11};
12use quote::{ToTokens, format_ident, quote};
13use syn::parse::{Parse, ParseStream};
14use syn::punctuated::Punctuated;
15use syn::spanned::Spanned;
16use syn::visit_mut::VisitMut;
17use syn::{
18    Attribute, FnArg, ItemTrait, Path, Token, TraitItem, TraitItemFn, parse_quote,
19    parse_quote_spanned,
20};
21
22mod queries;
23
24#[proc_macro_attribute]
25pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
26    match query_group_impl(args, input.clone()) {
27        Ok(tokens) => tokens,
28        Err(e) => token_stream_with_error(input, e),
29    }
30}
31
32#[derive(Debug)]
33struct InputStructField {
34    name: proc_macro2::TokenStream,
35    ty: proc_macro2::TokenStream,
36}
37
38impl fmt::Display for InputStructField {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        write!(f, "{}", self.name)
41    }
42}
43
44struct SalsaAttr {
45    name: String,
46    tts: TokenStream,
47    span: Span,
48}
49
50impl std::fmt::Debug for SalsaAttr {
51    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        write!(fmt, "{:?}", self.name)
53    }
54}
55
56impl TryFrom<syn::Attribute> for SalsaAttr {
57    type Error = syn::Attribute;
58
59    fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
60        if is_not_salsa_attr_path(attr.path()) {
61            return Err(attr);
62        }
63
64        let span = attr.span();
65
66        let name = attr.path().segments[1].ident.to_string();
67        let tts = match attr.meta {
68            syn::Meta::Path(path) => path.into_token_stream(),
69            syn::Meta::List(ref list) => {
70                let tts = list
71                    .into_token_stream()
72                    .into_iter()
73                    .skip(attr.path().to_token_stream().into_iter().count());
74                proc_macro2::TokenStream::from_iter(tts)
75            }
76            syn::Meta::NameValue(nv) => nv.into_token_stream(),
77        }
78        .into();
79
80        Ok(SalsaAttr { name, tts, span })
81    }
82}
83
84fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
85    path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
86}
87
88fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
89    let mut other = vec![];
90    let mut salsa = vec![];
91    // Leave non-salsa attributes untouched. These are
92    // attributes that don't start with `salsa::` or don't have
93    // exactly two segments in their path.
94    for attr in attrs {
95        match SalsaAttr::try_from(attr) {
96            Ok(it) => salsa.push(it),
97            Err(it) => other.push(it),
98        }
99    }
100    (other, salsa)
101}
102
/// How a query-group trait method is expanded, chosen from its `#[salsa::...]` attributes.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
enum QueryKind {
    /// `#[salsa::input]`: expanded into an `InputQuery` plus plain and with-durability setters.
    Input,
    /// `#[salsa::invoke_interned(...)]`: a `TrackedQuery` built with the generated
    /// `{Trait}Data` input struct.
    Tracked,
    /// The default kind; also selected by `#[salsa::invoke(...)]` and by
    /// `#[salsa::tracked]` on a method with a default body.
    TrackedWithSalsaStruct,
    /// `#[salsa::transparent]`; a later `#[salsa::invoke(...)]` does not override it.
    Transparent,
    /// `#[salsa::interned]`: expanded into an `Intern` method plus a `Lookup`.
    Interned,
}
111
112#[derive(Default, Debug, Clone)]
113struct Cycle {
114    cycle_fn: Option<(syn::Ident, Path)>,
115    cycle_initial: Option<(syn::Ident, Path)>,
116    cycle_result: Option<(syn::Ident, Path)>,
117}
118
119impl Parse for Cycle {
120    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
121        let options = Punctuated::<Option, Token![,]>::parse_terminated(input)?;
122        let mut cycle_fn = None;
123        let mut cycle_initial = None;
124        let mut cycle_result = None;
125        for option in options {
126            let name = option.name.to_string();
127            match &*name {
128                "cycle_fn" => {
129                    if cycle_fn.is_some() {
130                        return Err(syn::Error::new_spanned(&option.name, "duplicate option"));
131                    }
132                    cycle_fn = Some((option.name, option.value));
133                }
134                "cycle_initial" => {
135                    if cycle_initial.is_some() {
136                        return Err(syn::Error::new_spanned(&option.name, "duplicate option"));
137                    }
138                    cycle_initial = Some((option.name, option.value));
139                }
140                "cycle_result" => {
141                    if cycle_result.is_some() {
142                        return Err(syn::Error::new_spanned(&option.name, "duplicate option"));
143                    }
144                    cycle_result = Some((option.name, option.value));
145                }
146                _ => {
147                    return Err(syn::Error::new_spanned(
148                        &option.name,
149                        "unknown cycle option. Accepted values: `cycle_result`, `cycle_fn`, `cycle_initial`",
150                    ));
151                }
152            }
153        }
154        return Ok(Self { cycle_fn, cycle_initial, cycle_result });
155
156        struct Option {
157            name: syn::Ident,
158            value: Path,
159        }
160
161        impl Parse for Option {
162            fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
163                let name = input.parse()?;
164                input.parse::<Token![=]>()?;
165                let value = input.parse()?;
166                Ok(Self { name, value })
167            }
168        }
169    }
170}
171
172pub(crate) fn query_group_impl(
173    _args: proc_macro::TokenStream,
174    input: proc_macro::TokenStream,
175) -> Result<proc_macro::TokenStream, syn::Error> {
176    let mut item_trait = syn::parse::<ItemTrait>(input)?;
177
178    let supertraits = &item_trait.supertraits;
179
180    let db_attr: Attribute = parse_quote! {
181        #[salsa_macros::db]
182    };
183    item_trait.attrs.push(db_attr);
184
185    let trait_name_ident = &item_trait.ident.clone();
186    let input_struct_name = format_ident!("{}Data", trait_name_ident);
187    let create_data_ident = format_ident!("create_data_{}", trait_name_ident);
188
189    let mut input_struct_fields: Vec<InputStructField> = vec![];
190    let mut trait_methods = vec![];
191    let mut setter_trait_methods = vec![];
192    let mut lookup_signatures = vec![];
193    let mut lookup_methods = vec![];
194
195    for item in &mut item_trait.items {
196        if let syn::TraitItem::Fn(method) = item {
197            let method_name = &method.sig.ident;
198            let signature = &method.sig;
199
200            let (_attrs, salsa_attrs) = filter_attrs(method.attrs.clone());
201
202            let mut query_kind = QueryKind::TrackedWithSalsaStruct;
203            let mut invoke = None;
204            let mut cycle = None;
205            let mut interned_struct_path = None;
206            let mut lru = None;
207
208            let params: Vec<FnArg> = signature.inputs.clone().into_iter().collect();
209            let pat_and_tys = params
210                .into_iter()
211                .filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_)))
212                .map(|fn_arg| match fn_arg {
213                    FnArg::Typed(pat_type) => pat_type,
214                    FnArg::Receiver(_) => unreachable!("this should have been filtered out"),
215                })
216                .collect::<Vec<syn::PatType>>();
217
218            for SalsaAttr { name, tts, span } in salsa_attrs {
219                match name.as_str() {
220                    "cycle" => {
221                        let c = syn::parse::<Parenthesized<Cycle>>(tts)?;
222                        cycle = Some(c.0);
223                    }
224                    "input" => {
225                        if !pat_and_tys.is_empty() {
226                            return Err(syn::Error::new(
227                                span,
228                                "input methods cannot have a parameter",
229                            ));
230                        }
231                        query_kind = QueryKind::Input;
232                    }
233                    "interned" => {
234                        let syn::ReturnType::Type(_, ty) = &signature.output else {
235                            return Err(syn::Error::new(
236                                span,
237                                "interned queries must have return type",
238                            ));
239                        };
240                        let syn::Type::Path(path) = &**ty else {
241                            return Err(syn::Error::new(
242                                span,
243                                "interned queries must have return type",
244                            ));
245                        };
246                        interned_struct_path = Some(path.path.clone());
247                        query_kind = QueryKind::Interned;
248                    }
249                    "invoke_interned" => {
250                        let path = syn::parse::<Parenthesized<Path>>(tts)?;
251                        invoke = Some(path.0.clone());
252                        query_kind = QueryKind::Tracked;
253                    }
254                    "invoke" => {
255                        let path = syn::parse::<Parenthesized<Path>>(tts)?;
256                        invoke = Some(path.0.clone());
257                        if query_kind != QueryKind::Transparent {
258                            query_kind = QueryKind::TrackedWithSalsaStruct;
259                        }
260                    }
261                    "tracked" if method.default.is_some() => {
262                        query_kind = QueryKind::TrackedWithSalsaStruct;
263                    }
264                    "lru" => {
265                        let lru_count = syn::parse::<Parenthesized<syn::LitInt>>(tts)?;
266                        let lru_count = lru_count.0.base10_parse::<u32>()?;
267
268                        lru = Some(lru_count);
269                    }
270                    "transparent" => {
271                        query_kind = QueryKind::Transparent;
272                    }
273                    _ => return Err(syn::Error::new(span, format!("unknown attribute `{name}`"))),
274                }
275            }
276
277            let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else {
278                return Err(syn::Error::new(signature.span(), "Queries must have a return type"));
279            };
280
281            if let syn::Type::Path(ref ty_path) = *return_ty
282                && matches!(query_kind, QueryKind::Input)
283            {
284                let field = InputStructField {
285                    name: method_name.to_token_stream(),
286                    ty: ty_path.path.to_token_stream(),
287                };
288
289                input_struct_fields.push(field);
290            }
291
292            if let Some(block) = &mut method.default {
293                SelfToDbRewriter.visit_block_mut(block);
294            }
295
296            match (query_kind, invoke) {
297                // input
298                (QueryKind::Input, None) => {
299                    let query = InputQuery {
300                        signature: method.sig.clone(),
301                        create_data_ident: create_data_ident.clone(),
302                    };
303                    let value = Queries::InputQuery(query);
304                    trait_methods.push(value);
305
306                    let setter = InputSetter {
307                        signature: method.sig.clone(),
308                        return_type: *return_ty.clone(),
309                        create_data_ident: create_data_ident.clone(),
310                    };
311                    setter_trait_methods.push(SetterKind::Plain(setter));
312
313                    let setter = InputSetterWithDurability {
314                        signature: method.sig.clone(),
315                        return_type: *return_ty.clone(),
316                        create_data_ident: create_data_ident.clone(),
317                    };
318                    setter_trait_methods.push(SetterKind::WithDurability(setter));
319                }
320                (QueryKind::Interned, None) => {
321                    let interned_struct_path = interned_struct_path.unwrap();
322                    let method = Intern {
323                        signature: signature.clone(),
324                        pat_and_tys: pat_and_tys.clone(),
325                        interned_struct_path: interned_struct_path.clone(),
326                    };
327
328                    trait_methods.push(Queries::Intern(method));
329
330                    let mut method = Lookup {
331                        signature: signature.clone(),
332                        pat_and_tys: pat_and_tys.clone(),
333                        return_ty: *return_ty,
334                        interned_struct_path,
335                    };
336                    method.prepare_signature();
337
338                    lookup_signatures
339                        .push(TraitItem::Fn(make_trait_method(method.signature.clone())));
340                    lookup_methods.push(method);
341                }
342                // tracked function. it might have an invoke, or might not.
343                (QueryKind::Tracked, invoke) => {
344                    let method = TrackedQuery {
345                        trait_name: trait_name_ident.clone(),
346                        generated_struct: Some(GeneratedInputStruct {
347                            input_struct_name: input_struct_name.clone(),
348                            create_data_ident: create_data_ident.clone(),
349                        }),
350                        signature: signature.clone(),
351                        pat_and_tys: pat_and_tys.clone(),
352                        invoke,
353                        cycle,
354                        lru,
355                        default: method.default.take(),
356                    };
357
358                    trait_methods.push(Queries::TrackedQuery(method));
359                }
360                (QueryKind::TrackedWithSalsaStruct, invoke) => {
361                    let method = TrackedQuery {
362                        trait_name: trait_name_ident.clone(),
363                        generated_struct: None,
364                        signature: signature.clone(),
365                        pat_and_tys: pat_and_tys.clone(),
366                        invoke,
367                        cycle,
368                        lru,
369                        default: method.default.take(),
370                    };
371
372                    trait_methods.push(Queries::TrackedQuery(method))
373                }
374                (QueryKind::Transparent, invoke) => {
375                    let method = Transparent {
376                        signature: method.sig.clone(),
377                        pat_and_tys: pat_and_tys.clone(),
378                        invoke,
379                        default: method.default.take(),
380                    };
381                    trait_methods.push(Queries::Transparent(method));
382                }
383                // error/invalid constructions
384                (QueryKind::Interned, Some(path)) => {
385                    return Err(syn::Error::new(
386                        path.span(),
387                        "Interned queries cannot be used with an `#[invoke]`".to_string(),
388                    ));
389                }
390                (QueryKind::Input, Some(path)) => {
391                    return Err(syn::Error::new(
392                        path.span(),
393                        "Inputs cannot be used with an `#[invoke]`".to_string(),
394                    ));
395                }
396            }
397        }
398    }
399
400    let fields = input_struct_fields
401        .into_iter()
402        .map(|input| {
403            let name = input.name;
404            let ret = input.ty;
405            quote! { #name: Option<#ret> }
406        })
407        .collect::<Vec<proc_macro2::TokenStream>>();
408
409    let input_struct = quote! {
410        #[salsa_macros::input]
411        pub(crate) struct #input_struct_name {
412            #(#fields),*
413        }
414    };
415
416    let field_params = std::iter::repeat_n(quote! { None }, fields.len())
417        .collect::<Vec<proc_macro2::TokenStream>>();
418
419    let create_data_method = quote! {
420        #[allow(non_snake_case)]
421        #[salsa_macros::tracked]
422        fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name {
423            #input_struct_name::new(db, #(#field_params),*)
424        }
425    };
426
427    let mut setter_signatures = vec![];
428    let mut setter_methods = vec![];
429    for trait_item in setter_trait_methods
430        .iter()
431        .map(|method| method.to_token_stream())
432        .map(|tokens| syn::parse2::<syn::TraitItemFn>(tokens).unwrap())
433    {
434        let mut methods_sans_body = trait_item.clone();
435        methods_sans_body.default = None;
436        methods_sans_body.semi_token = Some(syn::Token![;](trait_item.span()));
437
438        setter_signatures.push(TraitItem::Fn(methods_sans_body));
439        setter_methods.push(TraitItem::Fn(trait_item));
440    }
441
442    item_trait.items.append(&mut setter_signatures);
443    item_trait.items.append(&mut lookup_signatures);
444
445    let trait_impl = quote! {
446        #[salsa_macros::db]
447        impl<DB> #trait_name_ident for DB
448        where
449            DB: #supertraits,
450        {
451            #(#trait_methods)*
452
453            #(#setter_methods)*
454
455            #(#lookup_methods)*
456        }
457    };
458    RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait);
459
460    let out = quote! {
461        #item_trait
462
463        #trait_impl
464
465        #input_struct
466
467        #create_data_method
468    }
469    .into();
470
471    Ok(out)
472}
473
474/// Parenthesis helper
475pub(crate) struct Parenthesized<T>(pub(crate) T);
476
477impl<T> syn::parse::Parse for Parenthesized<T>
478where
479    T: syn::parse::Parse,
480{
481    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
482        let content;
483        syn::parenthesized!(content in input);
484        content.parse::<T>().map(Parenthesized)
485    }
486}
487
488fn make_trait_method(sig: syn::Signature) -> TraitItemFn {
489    TraitItemFn {
490        attrs: vec![],
491        sig: sig.clone(),
492        semi_token: Some(syn::Token![;](sig.span())),
493        default: None,
494    }
495}
496
497struct RemoveAttrsFromTraitMethods;
498
499impl VisitMut for RemoveAttrsFromTraitMethods {
500    fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) {
501        for item in &mut i.items {
502            if let TraitItem::Fn(trait_item_fn) = item {
503                trait_item_fn.attrs = vec![];
504            }
505        }
506    }
507}
508
509pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
510    tokens.extend(TokenStream::from(error.into_compile_error()));
511    tokens
512}
513
514struct SelfToDbRewriter;
515
516impl VisitMut for SelfToDbRewriter {
517    fn visit_expr_path_mut(&mut self, i: &mut syn::ExprPath) {
518        if i.path.is_ident("self") {
519            i.path = parse_quote_spanned!(i.path.span() => db);
520        }
521    }
522}