1use core::fmt;
4use std::vec;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use queries::{
9 GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup,
10 Queries, SetterKind, TrackedQuery, Transparent,
11};
12use quote::{ToTokens, format_ident, quote};
13use syn::parse::{Parse, ParseStream};
14use syn::punctuated::Punctuated;
15use syn::spanned::Spanned;
16use syn::visit_mut::VisitMut;
17use syn::{
18 Attribute, FnArg, ItemTrait, Path, Token, TraitItem, TraitItemFn, parse_quote,
19 parse_quote_spanned,
20};
21
22mod queries;
23
24#[proc_macro_attribute]
25pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
26 match query_group_impl(args, input.clone()) {
27 Ok(tokens) => tokens,
28 Err(e) => token_stream_with_error(input, e),
29 }
30}
31
32#[derive(Debug)]
33struct InputStructField {
34 name: proc_macro2::TokenStream,
35 ty: proc_macro2::TokenStream,
36}
37
38impl fmt::Display for InputStructField {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 write!(f, "{}", self.name)
41 }
42}
43
44struct SalsaAttr {
45 name: String,
46 tts: TokenStream,
47 span: Span,
48}
49
50impl std::fmt::Debug for SalsaAttr {
51 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(fmt, "{:?}", self.name)
53 }
54}
55
56impl TryFrom<syn::Attribute> for SalsaAttr {
57 type Error = syn::Attribute;
58
59 fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
60 if is_not_salsa_attr_path(attr.path()) {
61 return Err(attr);
62 }
63
64 let span = attr.span();
65
66 let name = attr.path().segments[1].ident.to_string();
67 let tts = match attr.meta {
68 syn::Meta::Path(path) => path.into_token_stream(),
69 syn::Meta::List(ref list) => {
70 let tts = list
71 .into_token_stream()
72 .into_iter()
73 .skip(attr.path().to_token_stream().into_iter().count());
74 proc_macro2::TokenStream::from_iter(tts)
75 }
76 syn::Meta::NameValue(nv) => nv.into_token_stream(),
77 }
78 .into();
79
80 Ok(SalsaAttr { name, tts, span })
81 }
82}
83
84fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
85 path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
86}
87
88fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
89 let mut other = vec![];
90 let mut salsa = vec![];
91 for attr in attrs {
95 match SalsaAttr::try_from(attr) {
96 Ok(it) => salsa.push(it),
97 Err(it) => other.push(it),
98 }
99 }
100 (other, salsa)
101}
102
/// The kind of code generated for one query method, chosen from its
/// `#[salsa::...]` attributes.
#[derive(Clone, Debug, Eq, PartialEq)]
enum QueryKind {
    /// `#[salsa::input]`: a field on the generated `<Trait>Data` struct plus setters.
    Input,
    /// `#[salsa::invoke_interned]`: a tracked query that is given the generated
    /// `<Trait>Data` struct.
    Tracked,
    /// Default kind (also `#[salsa::invoke]` / `#[salsa::tracked]`): a tracked query
    /// without the generated struct.
    TrackedWithSalsaStruct,
    /// `#[salsa::transparent]`.
    Transparent,
    /// `#[salsa::interned]`: an intern method plus a lookup method.
    Interned,
}
111
112#[derive(Default, Debug, Clone)]
113struct Cycle {
114 cycle_fn: Option<(syn::Ident, Path)>,
115 cycle_initial: Option<(syn::Ident, Path)>,
116 cycle_result: Option<(syn::Ident, Path)>,
117}
118
119impl Parse for Cycle {
120 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
121 let options = Punctuated::<Option, Token![,]>::parse_terminated(input)?;
122 let mut cycle_fn = None;
123 let mut cycle_initial = None;
124 let mut cycle_result = None;
125 for option in options {
126 let name = option.name.to_string();
127 match &*name {
128 "cycle_fn" => {
129 if cycle_fn.is_some() {
130 return Err(syn::Error::new_spanned(&option.name, "duplicate option"));
131 }
132 cycle_fn = Some((option.name, option.value));
133 }
134 "cycle_initial" => {
135 if cycle_initial.is_some() {
136 return Err(syn::Error::new_spanned(&option.name, "duplicate option"));
137 }
138 cycle_initial = Some((option.name, option.value));
139 }
140 "cycle_result" => {
141 if cycle_result.is_some() {
142 return Err(syn::Error::new_spanned(&option.name, "duplicate option"));
143 }
144 cycle_result = Some((option.name, option.value));
145 }
146 _ => {
147 return Err(syn::Error::new_spanned(
148 &option.name,
149 "unknown cycle option. Accepted values: `cycle_result`, `cycle_fn`, `cycle_initial`",
150 ));
151 }
152 }
153 }
154 return Ok(Self { cycle_fn, cycle_initial, cycle_result });
155
156 struct Option {
157 name: syn::Ident,
158 value: Path,
159 }
160
161 impl Parse for Option {
162 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
163 let name = input.parse()?;
164 input.parse::<Token![=]>()?;
165 let value = input.parse()?;
166 Ok(Self { name, value })
167 }
168 }
169 }
170}
171
172pub(crate) fn query_group_impl(
173 _args: proc_macro::TokenStream,
174 input: proc_macro::TokenStream,
175) -> Result<proc_macro::TokenStream, syn::Error> {
176 let mut item_trait = syn::parse::<ItemTrait>(input)?;
177
178 let supertraits = &item_trait.supertraits;
179
180 let db_attr: Attribute = parse_quote! {
181 #[salsa_macros::db]
182 };
183 item_trait.attrs.push(db_attr);
184
185 let trait_name_ident = &item_trait.ident.clone();
186 let input_struct_name = format_ident!("{}Data", trait_name_ident);
187 let create_data_ident = format_ident!("create_data_{}", trait_name_ident);
188
189 let mut input_struct_fields: Vec<InputStructField> = vec![];
190 let mut trait_methods = vec![];
191 let mut setter_trait_methods = vec![];
192 let mut lookup_signatures = vec![];
193 let mut lookup_methods = vec![];
194
195 for item in &mut item_trait.items {
196 if let syn::TraitItem::Fn(method) = item {
197 let method_name = &method.sig.ident;
198 let signature = &method.sig;
199
200 let (_attrs, salsa_attrs) = filter_attrs(method.attrs.clone());
201
202 let mut query_kind = QueryKind::TrackedWithSalsaStruct;
203 let mut invoke = None;
204 let mut cycle = None;
205 let mut interned_struct_path = None;
206 let mut lru = None;
207
208 let params: Vec<FnArg> = signature.inputs.clone().into_iter().collect();
209 let pat_and_tys = params
210 .into_iter()
211 .filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_)))
212 .map(|fn_arg| match fn_arg {
213 FnArg::Typed(pat_type) => pat_type,
214 FnArg::Receiver(_) => unreachable!("this should have been filtered out"),
215 })
216 .collect::<Vec<syn::PatType>>();
217
218 for SalsaAttr { name, tts, span } in salsa_attrs {
219 match name.as_str() {
220 "cycle" => {
221 let c = syn::parse::<Parenthesized<Cycle>>(tts)?;
222 cycle = Some(c.0);
223 }
224 "input" => {
225 if !pat_and_tys.is_empty() {
226 return Err(syn::Error::new(
227 span,
228 "input methods cannot have a parameter",
229 ));
230 }
231 query_kind = QueryKind::Input;
232 }
233 "interned" => {
234 let syn::ReturnType::Type(_, ty) = &signature.output else {
235 return Err(syn::Error::new(
236 span,
237 "interned queries must have return type",
238 ));
239 };
240 let syn::Type::Path(path) = &**ty else {
241 return Err(syn::Error::new(
242 span,
243 "interned queries must have return type",
244 ));
245 };
246 interned_struct_path = Some(path.path.clone());
247 query_kind = QueryKind::Interned;
248 }
249 "invoke_interned" => {
250 let path = syn::parse::<Parenthesized<Path>>(tts)?;
251 invoke = Some(path.0.clone());
252 query_kind = QueryKind::Tracked;
253 }
254 "invoke" => {
255 let path = syn::parse::<Parenthesized<Path>>(tts)?;
256 invoke = Some(path.0.clone());
257 if query_kind != QueryKind::Transparent {
258 query_kind = QueryKind::TrackedWithSalsaStruct;
259 }
260 }
261 "tracked" if method.default.is_some() => {
262 query_kind = QueryKind::TrackedWithSalsaStruct;
263 }
264 "lru" => {
265 let lru_count = syn::parse::<Parenthesized<syn::LitInt>>(tts)?;
266 let lru_count = lru_count.0.base10_parse::<u32>()?;
267
268 lru = Some(lru_count);
269 }
270 "transparent" => {
271 query_kind = QueryKind::Transparent;
272 }
273 _ => return Err(syn::Error::new(span, format!("unknown attribute `{name}`"))),
274 }
275 }
276
277 let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else {
278 return Err(syn::Error::new(signature.span(), "Queries must have a return type"));
279 };
280
281 if let syn::Type::Path(ref ty_path) = *return_ty
282 && matches!(query_kind, QueryKind::Input)
283 {
284 let field = InputStructField {
285 name: method_name.to_token_stream(),
286 ty: ty_path.path.to_token_stream(),
287 };
288
289 input_struct_fields.push(field);
290 }
291
292 if let Some(block) = &mut method.default {
293 SelfToDbRewriter.visit_block_mut(block);
294 }
295
296 match (query_kind, invoke) {
297 (QueryKind::Input, None) => {
299 let query = InputQuery {
300 signature: method.sig.clone(),
301 create_data_ident: create_data_ident.clone(),
302 };
303 let value = Queries::InputQuery(query);
304 trait_methods.push(value);
305
306 let setter = InputSetter {
307 signature: method.sig.clone(),
308 return_type: *return_ty.clone(),
309 create_data_ident: create_data_ident.clone(),
310 };
311 setter_trait_methods.push(SetterKind::Plain(setter));
312
313 let setter = InputSetterWithDurability {
314 signature: method.sig.clone(),
315 return_type: *return_ty.clone(),
316 create_data_ident: create_data_ident.clone(),
317 };
318 setter_trait_methods.push(SetterKind::WithDurability(setter));
319 }
320 (QueryKind::Interned, None) => {
321 let interned_struct_path = interned_struct_path.unwrap();
322 let method = Intern {
323 signature: signature.clone(),
324 pat_and_tys: pat_and_tys.clone(),
325 interned_struct_path: interned_struct_path.clone(),
326 };
327
328 trait_methods.push(Queries::Intern(method));
329
330 let mut method = Lookup {
331 signature: signature.clone(),
332 pat_and_tys: pat_and_tys.clone(),
333 return_ty: *return_ty,
334 interned_struct_path,
335 };
336 method.prepare_signature();
337
338 lookup_signatures
339 .push(TraitItem::Fn(make_trait_method(method.signature.clone())));
340 lookup_methods.push(method);
341 }
342 (QueryKind::Tracked, invoke) => {
344 let method = TrackedQuery {
345 trait_name: trait_name_ident.clone(),
346 generated_struct: Some(GeneratedInputStruct {
347 input_struct_name: input_struct_name.clone(),
348 create_data_ident: create_data_ident.clone(),
349 }),
350 signature: signature.clone(),
351 pat_and_tys: pat_and_tys.clone(),
352 invoke,
353 cycle,
354 lru,
355 default: method.default.take(),
356 };
357
358 trait_methods.push(Queries::TrackedQuery(method));
359 }
360 (QueryKind::TrackedWithSalsaStruct, invoke) => {
361 let method = TrackedQuery {
362 trait_name: trait_name_ident.clone(),
363 generated_struct: None,
364 signature: signature.clone(),
365 pat_and_tys: pat_and_tys.clone(),
366 invoke,
367 cycle,
368 lru,
369 default: method.default.take(),
370 };
371
372 trait_methods.push(Queries::TrackedQuery(method))
373 }
374 (QueryKind::Transparent, invoke) => {
375 let method = Transparent {
376 signature: method.sig.clone(),
377 pat_and_tys: pat_and_tys.clone(),
378 invoke,
379 default: method.default.take(),
380 };
381 trait_methods.push(Queries::Transparent(method));
382 }
383 (QueryKind::Interned, Some(path)) => {
385 return Err(syn::Error::new(
386 path.span(),
387 "Interned queries cannot be used with an `#[invoke]`".to_string(),
388 ));
389 }
390 (QueryKind::Input, Some(path)) => {
391 return Err(syn::Error::new(
392 path.span(),
393 "Inputs cannot be used with an `#[invoke]`".to_string(),
394 ));
395 }
396 }
397 }
398 }
399
400 let fields = input_struct_fields
401 .into_iter()
402 .map(|input| {
403 let name = input.name;
404 let ret = input.ty;
405 quote! { #name: Option<#ret> }
406 })
407 .collect::<Vec<proc_macro2::TokenStream>>();
408
409 let input_struct = quote! {
410 #[salsa_macros::input]
411 pub(crate) struct #input_struct_name {
412 #(#fields),*
413 }
414 };
415
416 let field_params = std::iter::repeat_n(quote! { None }, fields.len())
417 .collect::<Vec<proc_macro2::TokenStream>>();
418
419 let create_data_method = quote! {
420 #[allow(non_snake_case)]
421 #[salsa_macros::tracked]
422 fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name {
423 #input_struct_name::new(db, #(#field_params),*)
424 }
425 };
426
427 let mut setter_signatures = vec![];
428 let mut setter_methods = vec![];
429 for trait_item in setter_trait_methods
430 .iter()
431 .map(|method| method.to_token_stream())
432 .map(|tokens| syn::parse2::<syn::TraitItemFn>(tokens).unwrap())
433 {
434 let mut methods_sans_body = trait_item.clone();
435 methods_sans_body.default = None;
436 methods_sans_body.semi_token = Some(syn::Token));
437
438 setter_signatures.push(TraitItem::Fn(methods_sans_body));
439 setter_methods.push(TraitItem::Fn(trait_item));
440 }
441
442 item_trait.items.append(&mut setter_signatures);
443 item_trait.items.append(&mut lookup_signatures);
444
445 let trait_impl = quote! {
446 #[salsa_macros::db]
447 impl<DB> #trait_name_ident for DB
448 where
449 DB: #supertraits,
450 {
451 #(#trait_methods)*
452
453 #(#setter_methods)*
454
455 #(#lookup_methods)*
456 }
457 };
458 RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait);
459
460 let out = quote! {
461 #item_trait
462
463 #trait_impl
464
465 #input_struct
466
467 #create_data_method
468 }
469 .into();
470
471 Ok(out)
472}
473
474pub(crate) struct Parenthesized<T>(pub(crate) T);
476
477impl<T> syn::parse::Parse for Parenthesized<T>
478where
479 T: syn::parse::Parse,
480{
481 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
482 let content;
483 syn::parenthesized!(content in input);
484 content.parse::<T>().map(Parenthesized)
485 }
486}
487
488fn make_trait_method(sig: syn::Signature) -> TraitItemFn {
489 TraitItemFn {
490 attrs: vec![],
491 sig: sig.clone(),
492 semi_token: Some(syn::Token)),
493 default: None,
494 }
495}
496
497struct RemoveAttrsFromTraitMethods;
498
499impl VisitMut for RemoveAttrsFromTraitMethods {
500 fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) {
501 for item in &mut i.items {
502 if let TraitItem::Fn(trait_item_fn) = item {
503 trait_item_fn.attrs = vec![];
504 }
505 }
506 }
507}
508
509pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
510 tokens.extend(TokenStream::from(error.into_compile_error()));
511 tokens
512}
513
514struct SelfToDbRewriter;
515
516impl VisitMut for SelfToDbRewriter {
517 fn visit_expr_path_mut(&mut self, i: &mut syn::ExprPath) {
518 if i.path.is_ident("self") {
519 i.path = parse_quote_spanned!(i.path.span() => db);
520 }
521 }
522}