1use either::Either;
4use intern::sym;
5use itertools::{Itertools, izip};
6use parser::SyntaxKind;
7use rustc_hash::FxHashSet;
8use span::{Edition, Span, SyntaxContext};
9use stdx::never;
10use syntax_bridge::DocCommentDesugarMode;
11use tracing::debug;
12
13use crate::{
14 ExpandError, ExpandResult, MacroCallId,
15 builtin::quote::dollar_crate,
16 db::ExpandDatabase,
17 hygiene::span_with_def_site_ctxt,
18 name::{self, AsName, Name},
19 span_map::ExpansionSpanMap,
20 tt,
21};
22use syntax::{
23 ast::{
24 self, AstNode, FieldList, HasAttrs, HasGenericArgs, HasGenericParams, HasModuleItem,
25 HasName, HasTypeBounds, edit_in_place::GenericParamsOwnerEdit, make,
26 },
27 ted,
28};
29
// Declares `BuiltinDeriveExpander`, an enum with one variant per builtin derivable trait, plus
// two lookups over it:
// - `expander` maps each variant to the function that performs its expansion;
// - `find_by_name` maps a name to the variant whose `sym::<Trait>` symbol equals it.
// Every `Trait => expand_fn` pair in the invocation contributes one variant and one arm to each
// lookup, so `sym::<Trait>` must exist for every registered trait.
macro_rules! register_builtin {
    ( $($trait:ident => $expand:ident),* ) => {
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub enum BuiltinDeriveExpander {
            $($trait),*
        }

        impl BuiltinDeriveExpander {
            // Returns the expansion function registered for this derive.
            pub fn expander(&self) -> fn(&dyn ExpandDatabase, Span, &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
                match *self {
                    $( BuiltinDeriveExpander::$trait => $expand, )*
                }
            }

            // Resolves a derive name to its builtin expander; `None` when no builtin matches.
            fn find_by_name(name: &name::Name) -> Option<Self> {
                match name {
                    $( id if id == &sym::$trait => Some(BuiltinDeriveExpander::$trait), )*
                     _ => None,
                }
            }
        }

    };
}
54
55impl BuiltinDeriveExpander {
56 pub fn expand(
57 &self,
58 db: &dyn ExpandDatabase,
59 id: MacroCallId,
60 tt: &tt::TopSubtree,
61 span: Span,
62 ) -> ExpandResult<tt::TopSubtree> {
63 let span = span_with_def_site_ctxt(db, span, id.into(), Edition::CURRENT);
64 self.expander()(db, span, tt)
65 }
66}
67
// The builtin derives: each entry pairs a variant name (which must also exist as `sym::<Name>`)
// with the function that expands it. The order here is the enum's variant order.
register_builtin! {
    Copy => copy_expand,
    Clone => clone_expand,
    Default => default_expand,
    Debug => debug_expand,
    Hash => hash_expand,
    Ord => ord_expand,
    PartialOrd => partial_ord_expand,
    Eq => eq_expand,
    PartialEq => partial_eq_expand,
    CoercePointee => coerce_pointee_expand
}
80
/// Looks up the builtin derive expander registered under `ident`.
///
/// Returns `None` when `ident` does not name one of the builtin derives.
pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander> {
    BuiltinDeriveExpander::find_by_name(ident)
}
84
/// Field layout of a struct or of a single enum variant.
#[derive(Clone)]
enum VariantShape {
    /// Record fields: the tokens of their names, in declaration order.
    Struct(Vec<tt::Ident>),
    /// Tuple fields: only their count is recorded.
    Tuple(usize),
    /// No field list at all.
    Unit,
}
91
92fn tuple_field_iterator(span: Span, n: usize) -> impl Iterator<Item = tt::Ident> {
93 (0..n).map(move |it| tt::Ident::new(&format!("f{it}"), span))
94}
95
96impl VariantShape {
97 fn as_pattern(&self, path: tt::TopSubtree, span: Span) -> tt::TopSubtree {
98 self.as_pattern_map(path, span, |it| quote!(span => #it))
99 }
100
101 fn field_names(&self, span: Span) -> Vec<tt::Ident> {
102 match self {
103 VariantShape::Struct(s) => s.clone(),
104 VariantShape::Tuple(n) => tuple_field_iterator(span, *n).collect(),
105 VariantShape::Unit => vec![],
106 }
107 }
108
109 fn as_pattern_map(
110 &self,
111 path: tt::TopSubtree,
112 span: Span,
113 field_map: impl Fn(&tt::Ident) -> tt::TopSubtree,
114 ) -> tt::TopSubtree {
115 match self {
116 VariantShape::Struct(fields) => {
117 let fields = fields.iter().map(|it| {
118 let mapped = field_map(it);
119 quote! {span => #it : #mapped , }
120 });
121 quote! {span =>
122 #path { # #fields }
123 }
124 }
125 &VariantShape::Tuple(n) => {
126 let fields = tuple_field_iterator(span, n).map(|it| {
127 let mapped = field_map(&it);
128 quote! {span =>
129 #mapped ,
130 }
131 });
132 quote! {span =>
133 #path ( # #fields )
134 }
135 }
136 VariantShape::Unit => path,
137 }
138 }
139
140 fn from(
141 call_site: Span,
142 tm: &ExpansionSpanMap,
143 value: Option<FieldList>,
144 ) -> Result<Self, ExpandError> {
145 let r = match value {
146 None => VariantShape::Unit,
147 Some(FieldList::RecordFieldList(it)) => VariantShape::Struct(
148 it.fields()
149 .map(|it| it.name())
150 .map(|it| name_to_token(call_site, tm, it))
151 .collect::<Result<_, _>>()?,
152 ),
153 Some(FieldList::TupleFieldList(it)) => VariantShape::Tuple(it.fields().count()),
154 };
155 Ok(r)
156 }
157}
158
/// Shape of the item a derive is applied to.
#[derive(Clone)]
enum AdtShape {
    /// A struct with the given field layout.
    Struct(VariantShape),
    /// An enum.
    ///
    /// `variants` holds each variant's name and field layout in declaration order.
    /// `default_variant` is the index into `variants` of the first variant carrying a
    /// `#[default]` attribute, if any.
    Enum { variants: Vec<(tt::Ident, VariantShape)>, default_variant: Option<usize> },
    /// A union; its fields are not recorded.
    Union,
}
165
166impl AdtShape {
167 fn as_pattern(&self, span: Span, name: &tt::Ident) -> Vec<tt::TopSubtree> {
168 self.as_pattern_map(name, |it| quote!(span =>#it), span)
169 }
170
171 fn field_names(&self, span: Span) -> Vec<Vec<tt::Ident>> {
172 match self {
173 AdtShape::Struct(s) => {
174 vec![s.field_names(span)]
175 }
176 AdtShape::Enum { variants, .. } => {
177 variants.iter().map(|(_, fields)| fields.field_names(span)).collect()
178 }
179 AdtShape::Union => {
180 never!("using fields of union in derive is always wrong");
181 vec![]
182 }
183 }
184 }
185
186 fn as_pattern_map(
187 &self,
188 name: &tt::Ident,
189 field_map: impl Fn(&tt::Ident) -> tt::TopSubtree,
190 span: Span,
191 ) -> Vec<tt::TopSubtree> {
192 match self {
193 AdtShape::Struct(s) => {
194 vec![s.as_pattern_map(quote! {span => #name }, span, field_map)]
195 }
196 AdtShape::Enum { variants, .. } => variants
197 .iter()
198 .map(|(v, fields)| {
199 fields.as_pattern_map(quote! {span => #name :: #v }, span, &field_map)
200 })
201 .collect(),
202 AdtShape::Union => {
203 never!("pattern matching on union is always wrong");
204 vec![quote! {span => un }]
205 }
206 }
207 }
208}
209
/// Everything the builtin derives need to know about the annotated ADT.
#[derive(Clone)]
struct BasicAdtInfo {
    /// Name of the struct, enum or union.
    name: tt::Ident,
    /// Field/variant layout.
    shape: AdtShape,
    /// Type and const generic parameters in declaration order.
    ///
    /// They are gathered via `type_or_const_params`, so lifetime parameters are not included.
    param_types: Vec<AdtParam>,
    /// Each predicate of the ADT's `where` clause, as its own token tree.
    where_clause: Vec<tt::TopSubtree>,
    /// Path types inside the fields whose single-segment qualifier names one of the generic
    /// parameters (e.g. `T::Item`).
    associated_types: Vec<tt::TopSubtree>,
}
221
/// A single type or const generic parameter of the annotated ADT.
#[derive(Clone)]
struct AdtParam {
    /// The parameter's name (an empty token tree if the parser produced no name).
    name: tt::TopSubtree,
    /// `Some(ty)` for a const parameter `const N: ty`, `None` for a type parameter.
    ///
    /// `ty` is an empty token tree when the const parameter's type is missing.
    const_ty: Option<tt::TopSubtree>,
    /// Inline bounds of a type parameter (the `Bound` in `T: Bound`), if any.
    ///
    /// Always `None` for const parameters.
    bounds: Option<tt::TopSubtree>,
}
229
230fn parse_adt(
232 db: &dyn ExpandDatabase,
233 tt: &tt::TopSubtree,
234 call_site: Span,
235) -> Result<BasicAdtInfo, ExpandError> {
236 let (adt, tm) = to_adt_syntax(db, tt, call_site)?;
237 parse_adt_from_syntax(&adt, &tm, call_site)
238}
239
240fn parse_adt_from_syntax(
241 adt: &ast::Adt,
242 tm: &span::SpanMap<SyntaxContext>,
243 call_site: Span,
244) -> Result<BasicAdtInfo, ExpandError> {
245 let (name, generic_param_list, where_clause, shape) = match &adt {
246 ast::Adt::Struct(it) => (
247 it.name(),
248 it.generic_param_list(),
249 it.where_clause(),
250 AdtShape::Struct(VariantShape::from(call_site, tm, it.field_list())?),
251 ),
252 ast::Adt::Enum(it) => {
253 let default_variant = it
254 .variant_list()
255 .into_iter()
256 .flat_map(|it| it.variants())
257 .position(|it| it.attrs().any(|it| it.simple_name() == Some("default".into())));
258 (
259 it.name(),
260 it.generic_param_list(),
261 it.where_clause(),
262 AdtShape::Enum {
263 default_variant,
264 variants: it
265 .variant_list()
266 .into_iter()
267 .flat_map(|it| it.variants())
268 .map(|it| {
269 Ok((
270 name_to_token(call_site, tm, it.name())?,
271 VariantShape::from(call_site, tm, it.field_list())?,
272 ))
273 })
274 .collect::<Result<_, ExpandError>>()?,
275 },
276 )
277 }
278 ast::Adt::Union(it) => {
279 (it.name(), it.generic_param_list(), it.where_clause(), AdtShape::Union)
280 }
281 };
282
283 let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
284 let param_types = generic_param_list
285 .into_iter()
286 .flat_map(|param_list| param_list.type_or_const_params())
287 .map(|param| {
288 let name = {
289 let this = param.name();
290 match this {
291 Some(it) => {
292 param_type_set.insert(it.as_name());
293 syntax_bridge::syntax_node_to_token_tree(
294 it.syntax(),
295 tm,
296 call_site,
297 DocCommentDesugarMode::ProcMacro,
298 )
299 }
300 None => {
301 tt::TopSubtree::empty(::tt::DelimSpan { open: call_site, close: call_site })
302 }
303 }
304 };
305 let bounds = match ¶m {
306 ast::TypeOrConstParam::Type(it) => it.type_bound_list().map(|it| {
307 syntax_bridge::syntax_node_to_token_tree(
308 it.syntax(),
309 tm,
310 call_site,
311 DocCommentDesugarMode::ProcMacro,
312 )
313 }),
314 ast::TypeOrConstParam::Const(_) => None,
315 };
316 let const_ty = if let ast::TypeOrConstParam::Const(param) = param {
317 let ty = param
318 .ty()
319 .map(|ty| {
320 syntax_bridge::syntax_node_to_token_tree(
321 ty.syntax(),
322 tm,
323 call_site,
324 DocCommentDesugarMode::ProcMacro,
325 )
326 })
327 .unwrap_or_else(|| {
328 tt::TopSubtree::empty(::tt::DelimSpan { open: call_site, close: call_site })
329 });
330 Some(ty)
331 } else {
332 None
333 };
334 AdtParam { name, const_ty, bounds }
335 })
336 .collect();
337
338 let where_clause = if let Some(w) = where_clause {
339 w.predicates()
340 .map(|it| {
341 syntax_bridge::syntax_node_to_token_tree(
342 it.syntax(),
343 tm,
344 call_site,
345 DocCommentDesugarMode::ProcMacro,
346 )
347 })
348 .collect()
349 } else {
350 vec![]
351 };
352
353 let field_list = match adt {
365 ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
366 ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
367 ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
368 };
369 let associated_types = field_list
370 .into_iter()
371 .flat_map(|it| it.descendants())
372 .filter_map(ast::PathType::cast)
373 .filter_map(|p| {
374 let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
375 param_type_set.contains(&name).then_some(p)
376 })
377 .map(|it| {
378 syntax_bridge::syntax_node_to_token_tree(
379 it.syntax(),
380 tm,
381 call_site,
382 DocCommentDesugarMode::ProcMacro,
383 )
384 })
385 .collect();
386 let name_token = name_to_token(call_site, tm, name)?;
387 Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types })
388}
389
390fn to_adt_syntax(
391 db: &dyn ExpandDatabase,
392 tt: &tt::TopSubtree,
393 call_site: Span,
394) -> Result<(ast::Adt, span::SpanMap<SyntaxContext>), ExpandError> {
395 let (parsed, tm) = crate::db::token_tree_to_syntax_node(db, tt, crate::ExpandTo::Items);
396 let macro_items = ast::MacroItems::cast(parsed.syntax_node())
397 .ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?;
398 let item =
399 macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?;
400 let adt = ast::Adt::cast(item.syntax().clone())
401 .ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?;
402 Ok((adt, tm))
403}
404
405fn name_to_token(
406 call_site: Span,
407 token_map: &ExpansionSpanMap,
408 name: Option<ast::Name>,
409) -> Result<tt::Ident, ExpandError> {
410 let name = name.ok_or_else(|| {
411 debug!("parsed item has no name");
412 ExpandError::other(call_site, "missing name")
413 })?;
414 let span = token_map.span_at(name.syntax().text_range().start());
415
416 let name_token = tt::Ident::new(name.text().as_ref(), span);
417 Ok(name_token)
418}
419
420fn expand_simple_derive(
452 db: &dyn ExpandDatabase,
453 invoc_span: Span,
454 tt: &tt::TopSubtree,
455 trait_path: tt::TopSubtree,
456 allow_unions: bool,
457 make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::TopSubtree,
458) -> ExpandResult<tt::TopSubtree> {
459 let info = match parse_adt(db, tt, invoc_span) {
460 Ok(info) => info,
461 Err(e) => {
462 return ExpandResult::new(
463 tt::TopSubtree::empty(tt::DelimSpan { open: invoc_span, close: invoc_span }),
464 e,
465 );
466 }
467 };
468 if !allow_unions && matches!(info.shape, AdtShape::Union) {
469 return ExpandResult::new(
470 tt::TopSubtree::empty(tt::DelimSpan::from_single(invoc_span)),
471 ExpandError::other(invoc_span, "this trait cannot be derived for unions"),
472 );
473 }
474 ExpandResult::ok(expand_simple_derive_with_parsed(
475 invoc_span,
476 info,
477 trait_path,
478 make_trait_body,
479 true,
480 tt::TopSubtree::empty(tt::DelimSpan::from_single(invoc_span)),
481 ))
482}
483
484fn expand_simple_derive_with_parsed(
485 invoc_span: Span,
486 info: BasicAdtInfo,
487 trait_path: tt::TopSubtree,
488 make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::TopSubtree,
489 constrain_to_trait: bool,
490 extra_impl_params: tt::TopSubtree,
491) -> tt::TopSubtree {
492 let trait_body = make_trait_body(&info);
493 let mut where_block: Vec<_> =
494 info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect();
495 let (params, args): (Vec<_>, Vec<_>) = info
496 .param_types
497 .into_iter()
498 .map(|param| {
499 let ident = param.name;
500 if let Some(b) = param.bounds {
501 let ident2 = ident.clone();
502 where_block.push(quote! {invoc_span => #ident2 : #b , });
503 }
504 if let Some(ty) = param.const_ty {
505 let ident2 = ident.clone();
506 (quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident2 , })
507 } else {
508 let bound = trait_path.clone();
509 let ident2 = ident.clone();
510 let param = if constrain_to_trait {
511 quote! {invoc_span => #ident : #bound , }
512 } else {
513 quote! {invoc_span => #ident , }
514 };
515 (param, quote! {invoc_span => #ident2 , })
516 }
517 })
518 .unzip();
519
520 if constrain_to_trait {
521 where_block.extend(info.associated_types.iter().map(|it| {
522 let it = it.clone();
523 let bound = trait_path.clone();
524 quote! {invoc_span => #it : #bound , }
525 }));
526 }
527
528 let name = info.name;
529 quote! {invoc_span =>
530 impl < # #params #extra_impl_params > #trait_path for #name < # #args > where # #where_block { #trait_body }
531 }
532}
533
534fn copy_expand(
535 db: &dyn ExpandDatabase,
536 span: Span,
537 tt: &tt::TopSubtree,
538) -> ExpandResult<tt::TopSubtree> {
539 let krate = dollar_crate(span);
540 expand_simple_derive(
541 db,
542 span,
543 tt,
544 quote! {span => #krate::marker::Copy },
545 true,
546 |_| quote! {span =>},
547 )
548}
549
550fn clone_expand(
551 db: &dyn ExpandDatabase,
552 span: Span,
553 tt: &tt::TopSubtree,
554) -> ExpandResult<tt::TopSubtree> {
555 let krate = dollar_crate(span);
556 expand_simple_derive(db, span, tt, quote! {span => #krate::clone::Clone }, true, |adt| {
557 if matches!(adt.shape, AdtShape::Union) {
558 let star = tt::Punct { char: '*', spacing: ::tt::Spacing::Alone, span };
559 return quote! {span =>
560 fn clone(&self) -> Self {
561 #star self
562 }
563 };
564 }
565 if matches!(&adt.shape, AdtShape::Enum { variants, .. } if variants.is_empty()) {
566 let star = tt::Punct { char: '*', spacing: ::tt::Spacing::Alone, span };
567 return quote! {span =>
568 fn clone(&self) -> Self {
569 match #star self {}
570 }
571 };
572 }
573 let name = &adt.name;
574 let patterns = adt.shape.as_pattern(span, name);
575 let exprs = adt.shape.as_pattern_map(name, |it| quote! {span => #it .clone() }, span);
576 let arms = patterns.into_iter().zip(exprs).map(|(pat, expr)| {
577 let fat_arrow = fat_arrow(span);
578 quote! {span =>
579 #pat #fat_arrow #expr,
580 }
581 });
582
583 quote! {span =>
584 fn clone(&self) -> Self {
585 match self {
586 # #arms
587 }
588 }
589 }
590 })
591}
592
593fn fat_arrow(span: Span) -> tt::TopSubtree {
595 let eq = tt::Punct { char: '=', spacing: ::tt::Spacing::Joint, span };
596 quote! {span => #eq> }
597}
598
599fn and_and(span: Span) -> tt::TopSubtree {
601 let and = tt::Punct { char: '&', spacing: ::tt::Spacing::Joint, span };
602 quote! {span => #and& }
603}
604
605fn default_expand(
606 db: &dyn ExpandDatabase,
607 span: Span,
608 tt: &tt::TopSubtree,
609) -> ExpandResult<tt::TopSubtree> {
610 let krate = &dollar_crate(span);
611 let adt = match parse_adt(db, tt, span) {
612 Ok(info) => info,
613 Err(e) => {
614 return ExpandResult::new(
615 tt::TopSubtree::empty(tt::DelimSpan { open: span, close: span }),
616 e,
617 );
618 }
619 };
620 let (body, constrain_to_trait) = match &adt.shape {
621 AdtShape::Struct(fields) => {
622 let name = &adt.name;
623 let body = fields.as_pattern_map(
624 quote!(span =>#name),
625 span,
626 |_| quote!(span =>#krate::default::Default::default()),
627 );
628 (body, true)
629 }
630 AdtShape::Enum { default_variant, variants } => {
631 if let Some(d) = default_variant {
632 let (name, fields) = &variants[*d];
633 let adt_name = &adt.name;
634 let body = fields.as_pattern_map(
635 quote!(span =>#adt_name :: #name),
636 span,
637 |_| quote!(span =>#krate::default::Default::default()),
638 );
639 (body, false)
640 } else {
641 return ExpandResult::new(
642 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
643 ExpandError::other(span, "`#[derive(Default)]` on enum with no `#[default]`"),
644 );
645 }
646 }
647 AdtShape::Union => {
648 return ExpandResult::new(
649 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
650 ExpandError::other(span, "this trait cannot be derived for unions"),
651 );
652 }
653 };
654 ExpandResult::ok(expand_simple_derive_with_parsed(
655 span,
656 adt,
657 quote! {span => #krate::default::Default },
658 |_adt| {
659 quote! {span =>
660 fn default() -> Self {
661 #body
662 }
663 }
664 },
665 constrain_to_trait,
666 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
667 ))
668}
669
670fn debug_expand(
671 db: &dyn ExpandDatabase,
672 span: Span,
673 tt: &tt::TopSubtree,
674) -> ExpandResult<tt::TopSubtree> {
675 let krate = &dollar_crate(span);
676 expand_simple_derive(db, span, tt, quote! {span => #krate::fmt::Debug }, false, |adt| {
677 let for_variant = |name: String, v: &VariantShape| match v {
678 VariantShape::Struct(fields) => {
679 let for_fields = fields.iter().map(|it| {
680 let x_string = it.to_string();
681 quote! {span =>
682 .field(#x_string, & #it)
683 }
684 });
685 quote! {span =>
686 f.debug_struct(#name) # #for_fields .finish()
687 }
688 }
689 VariantShape::Tuple(n) => {
690 let for_fields = tuple_field_iterator(span, *n).map(|it| {
691 quote! {span =>
692 .field( & #it)
693 }
694 });
695 quote! {span =>
696 f.debug_tuple(#name) # #for_fields .finish()
697 }
698 }
699 VariantShape::Unit => quote! {span =>
700 f.write_str(#name)
701 },
702 };
703 if matches!(&adt.shape, AdtShape::Enum { variants, .. } if variants.is_empty()) {
704 let star = tt::Punct { char: '*', spacing: ::tt::Spacing::Alone, span };
705 return quote! {span =>
706 fn fmt(&self, f: &mut #krate::fmt::Formatter) -> #krate::fmt::Result {
707 match #star self {}
708 }
709 };
710 }
711 let arms = match &adt.shape {
712 AdtShape::Struct(fields) => {
713 let fat_arrow = fat_arrow(span);
714 let name = &adt.name;
715 let pat = fields.as_pattern(quote!(span =>#name), span);
716 let expr = for_variant(name.to_string(), fields);
717 vec![quote! {span => #pat #fat_arrow #expr }]
718 }
719 AdtShape::Enum { variants, .. } => variants
720 .iter()
721 .map(|(name, v)| {
722 let fat_arrow = fat_arrow(span);
723 let adt_name = &adt.name;
724 let pat = v.as_pattern(quote!(span =>#adt_name :: #name), span);
725 let expr = for_variant(name.to_string(), v);
726 quote! {span =>
727 #pat #fat_arrow #expr ,
728 }
729 })
730 .collect(),
731 AdtShape::Union => unreachable!(),
732 };
733 quote! {span =>
734 fn fmt(&self, f: &mut #krate::fmt::Formatter) -> #krate::fmt::Result {
735 match self {
736 # #arms
737 }
738 }
739 }
740 })
741}
742
743fn hash_expand(
744 db: &dyn ExpandDatabase,
745 span: Span,
746 tt: &tt::TopSubtree,
747) -> ExpandResult<tt::TopSubtree> {
748 let krate = &dollar_crate(span);
749 expand_simple_derive(db, span, tt, quote! {span => #krate::hash::Hash }, false, |adt| {
750 if matches!(&adt.shape, AdtShape::Enum { variants, .. } if variants.is_empty()) {
751 let star = tt::Punct { char: '*', spacing: ::tt::Spacing::Alone, span };
752 return quote! {span =>
753 fn hash<H: #krate::hash::Hasher>(&self, ra_expand_state: &mut H) {
754 match #star self {}
755 }
756 };
757 }
758 let arms =
759 adt.shape.as_pattern(span, &adt.name).into_iter().zip(adt.shape.field_names(span)).map(
760 |(pat, names)| {
761 let expr = {
762 let it =
763 names.iter().map(|it| quote! {span => #it . hash(ra_expand_state); });
764 quote! {span => {
765 # #it
766 } }
767 };
768 let fat_arrow = fat_arrow(span);
769 quote! {span =>
770 #pat #fat_arrow #expr ,
771 }
772 },
773 );
774 let check_discriminant = if matches!(&adt.shape, AdtShape::Enum { .. }) {
775 quote! {span => #krate::mem::discriminant(self).hash(ra_expand_state); }
776 } else {
777 quote! {span =>}
778 };
779 quote! {span =>
780 fn hash<H: #krate::hash::Hasher>(&self, ra_expand_state: &mut H) {
781 #check_discriminant
782 match self {
783 # #arms
784 }
785 }
786 }
787 })
788}
789
790fn eq_expand(
791 db: &dyn ExpandDatabase,
792 span: Span,
793 tt: &tt::TopSubtree,
794) -> ExpandResult<tt::TopSubtree> {
795 let krate = dollar_crate(span);
796 expand_simple_derive(
797 db,
798 span,
799 tt,
800 quote! {span => #krate::cmp::Eq },
801 true,
802 |_| quote! {span =>},
803 )
804}
805
806fn partial_eq_expand(
807 db: &dyn ExpandDatabase,
808 span: Span,
809 tt: &tt::TopSubtree,
810) -> ExpandResult<tt::TopSubtree> {
811 let krate = dollar_crate(span);
812 expand_simple_derive(db, span, tt, quote! {span => #krate::cmp::PartialEq }, false, |adt| {
813 let name = &adt.name;
814
815 let (self_patterns, other_patterns) = self_and_other_patterns(adt, name, span);
816 let arms = izip!(self_patterns, other_patterns, adt.shape.field_names(span)).map(
817 |(pat1, pat2, names)| {
818 let fat_arrow = fat_arrow(span);
819 let body = match &*names {
820 [] => {
821 quote!(span =>true)
822 }
823 [first, rest @ ..] => {
824 let rest = rest.iter().map(|it| {
825 let t1 = tt::Ident::new(&format!("{}_self", it.sym), it.span);
826 let t2 = tt::Ident::new(&format!("{}_other", it.sym), it.span);
827 let and_and = and_and(span);
828 quote!(span =>#and_and #t1 .eq( #t2 ))
829 });
830 let first = {
831 let t1 = tt::Ident::new(&format!("{}_self", first.sym), first.span);
832 let t2 = tt::Ident::new(&format!("{}_other", first.sym), first.span);
833 quote!(span =>#t1 .eq( #t2 ))
834 };
835 quote!(span =>#first # #rest)
836 }
837 };
838 quote! {span => ( #pat1 , #pat2 ) #fat_arrow #body , }
839 },
840 );
841
842 let fat_arrow = fat_arrow(span);
843 quote! {span =>
844 fn eq(&self, other: &Self) -> bool {
845 match (self, other) {
846 # #arms
847 _unused #fat_arrow false
848 }
849 }
850 }
851 })
852}
853
854fn self_and_other_patterns(
855 adt: &BasicAdtInfo,
856 name: &tt::Ident,
857 span: Span,
858) -> (Vec<tt::TopSubtree>, Vec<tt::TopSubtree>) {
859 let self_patterns = adt.shape.as_pattern_map(
860 name,
861 |it| {
862 let t = tt::Ident::new(&format!("{}_self", it.sym), it.span);
863 quote!(span =>#t)
864 },
865 span,
866 );
867 let other_patterns = adt.shape.as_pattern_map(
868 name,
869 |it| {
870 let t = tt::Ident::new(&format!("{}_other", it.sym), it.span);
871 quote!(span =>#t)
872 },
873 span,
874 );
875 (self_patterns, other_patterns)
876}
877
878fn ord_expand(
879 db: &dyn ExpandDatabase,
880 span: Span,
881 tt: &tt::TopSubtree,
882) -> ExpandResult<tt::TopSubtree> {
883 let krate = &dollar_crate(span);
884 expand_simple_derive(db, span, tt, quote! {span => #krate::cmp::Ord }, false, |adt| {
885 fn compare(
886 krate: &tt::Ident,
887 left: tt::TopSubtree,
888 right: tt::TopSubtree,
889 rest: tt::TopSubtree,
890 span: Span,
891 ) -> tt::TopSubtree {
892 let fat_arrow1 = fat_arrow(span);
893 let fat_arrow2 = fat_arrow(span);
894 quote! {span =>
895 match #left.cmp(&#right) {
896 #krate::cmp::Ordering::Equal #fat_arrow1 {
897 #rest
898 }
899 c #fat_arrow2 return c,
900 }
901 }
902 }
903 let (self_patterns, other_patterns) = self_and_other_patterns(adt, &adt.name, span);
904 let arms = izip!(self_patterns, other_patterns, adt.shape.field_names(span)).map(
905 |(pat1, pat2, fields)| {
906 let mut body = quote!(span =>#krate::cmp::Ordering::Equal);
907 for f in fields.into_iter().rev() {
908 let t1 = tt::Ident::new(&format!("{}_self", f.sym), f.span);
909 let t2 = tt::Ident::new(&format!("{}_other", f.sym), f.span);
910 body = compare(krate, quote!(span =>#t1), quote!(span =>#t2), body, span);
911 }
912 let fat_arrow = fat_arrow(span);
913 quote! {span => ( #pat1 , #pat2 ) #fat_arrow #body , }
914 },
915 );
916 let fat_arrow = fat_arrow(span);
917 let mut body = quote! {span =>
918 match (self, other) {
919 # #arms
920 _unused #fat_arrow #krate::cmp::Ordering::Equal
921 }
922 };
923 if matches!(&adt.shape, AdtShape::Enum { .. }) {
924 let left = quote!(span =>#krate::intrinsics::discriminant_value(self));
925 let right = quote!(span =>#krate::intrinsics::discriminant_value(other));
926 body = compare(krate, left, right, body, span);
927 }
928 quote! {span =>
929 fn cmp(&self, other: &Self) -> #krate::cmp::Ordering {
930 #body
931 }
932 }
933 })
934}
935
936fn partial_ord_expand(
937 db: &dyn ExpandDatabase,
938 span: Span,
939 tt: &tt::TopSubtree,
940) -> ExpandResult<tt::TopSubtree> {
941 let krate = &dollar_crate(span);
942 expand_simple_derive(db, span, tt, quote! {span => #krate::cmp::PartialOrd }, false, |adt| {
943 fn compare(
944 krate: &tt::Ident,
945 left: tt::TopSubtree,
946 right: tt::TopSubtree,
947 rest: tt::TopSubtree,
948 span: Span,
949 ) -> tt::TopSubtree {
950 let fat_arrow1 = fat_arrow(span);
951 let fat_arrow2 = fat_arrow(span);
952 quote! {span =>
953 match #left.partial_cmp(&#right) {
954 #krate::option::Option::Some(#krate::cmp::Ordering::Equal) #fat_arrow1 {
955 #rest
956 }
957 c #fat_arrow2 return c,
958 }
959 }
960 }
961 let left = quote!(span =>#krate::intrinsics::discriminant_value(self));
962 let right = quote!(span =>#krate::intrinsics::discriminant_value(other));
963
964 let (self_patterns, other_patterns) = self_and_other_patterns(adt, &adt.name, span);
965 let arms = izip!(self_patterns, other_patterns, adt.shape.field_names(span)).map(
966 |(pat1, pat2, fields)| {
967 let mut body =
968 quote!(span =>#krate::option::Option::Some(#krate::cmp::Ordering::Equal));
969 for f in fields.into_iter().rev() {
970 let t1 = tt::Ident::new(&format!("{}_self", f.sym), f.span);
971 let t2 = tt::Ident::new(&format!("{}_other", f.sym), f.span);
972 body = compare(krate, quote!(span =>#t1), quote!(span =>#t2), body, span);
973 }
974 let fat_arrow = fat_arrow(span);
975 quote! {span => ( #pat1 , #pat2 ) #fat_arrow #body , }
976 },
977 );
978 let fat_arrow = fat_arrow(span);
979 let body = compare(
980 krate,
981 left,
982 right,
983 quote! {span =>
984 match (self, other) {
985 # #arms
986 _unused #fat_arrow #krate::option::Option::Some(#krate::cmp::Ordering::Equal)
987 }
988 },
989 span,
990 );
991 quote! {span =>
992 fn partial_cmp(&self, other: &Self) -> #krate::option::Option<#krate::cmp::Ordering> {
993 #body
994 }
995 }
996 })
997}
998
999fn coerce_pointee_expand(
1000 db: &dyn ExpandDatabase,
1001 span: Span,
1002 tt: &tt::TopSubtree,
1003) -> ExpandResult<tt::TopSubtree> {
1004 let (adt, _span_map) = match to_adt_syntax(db, tt, span) {
1005 Ok(it) => it,
1006 Err(err) => {
1007 return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err);
1008 }
1009 };
1010 let adt = adt.clone_for_update();
1011 let ast::Adt::Struct(strukt) = &adt else {
1012 return ExpandResult::new(
1013 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1014 ExpandError::other(span, "`CoercePointee` can only be derived on `struct`s"),
1015 );
1016 };
1017 let has_at_least_one_field = strukt
1018 .field_list()
1019 .map(|it| match it {
1020 ast::FieldList::RecordFieldList(it) => it.fields().next().is_some(),
1021 ast::FieldList::TupleFieldList(it) => it.fields().next().is_some(),
1022 })
1023 .unwrap_or(false);
1024 if !has_at_least_one_field {
1025 return ExpandResult::new(
1026 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1027 ExpandError::other(
1028 span,
1029 "`CoercePointee` can only be derived on `struct`s with at least one field",
1030 ),
1031 );
1032 }
1033 let is_repr_transparent = strukt.attrs().any(|attr| {
1034 attr.as_simple_call().is_some_and(|(name, tt)| {
1035 name == "repr"
1036 && tt.syntax().children_with_tokens().any(|it| {
1037 it.into_token().is_some_and(|it| {
1038 it.kind() == SyntaxKind::IDENT && it.text() == "transparent"
1039 })
1040 })
1041 })
1042 });
1043 if !is_repr_transparent {
1044 return ExpandResult::new(
1045 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1046 ExpandError::other(
1047 span,
1048 "`CoercePointee` can only be derived on `struct`s with `#[repr(transparent)]`",
1049 ),
1050 );
1051 }
1052 let type_params = strukt
1053 .generic_param_list()
1054 .into_iter()
1055 .flat_map(|generics| {
1056 generics.generic_params().filter_map(|param| match param {
1057 ast::GenericParam::TypeParam(param) => Some(param),
1058 _ => None,
1059 })
1060 })
1061 .collect_vec();
1062 if type_params.is_empty() {
1063 return ExpandResult::new(
1064 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1065 ExpandError::other(
1066 span,
1067 "`CoercePointee` can only be derived on `struct`s that are generic over at least one type",
1068 ),
1069 );
1070 }
1071 let (pointee_param, pointee_param_idx) = if type_params.len() == 1 {
1072 (type_params[0].clone(), 0)
1074 } else {
1075 let mut pointees = type_params.iter().cloned().enumerate().filter(|(_, param)| {
1076 param.attrs().any(|attr| {
1077 let is_pointee = attr.as_simple_atom().is_some_and(|name| name == "pointee");
1078 if is_pointee {
1079 ted::remove(attr.syntax());
1082 }
1083 is_pointee
1084 })
1085 });
1086 match (pointees.next(), pointees.next()) {
1087 (Some((pointee_idx, pointee)), None) => (pointee, pointee_idx),
1088 (None, _) => {
1089 return ExpandResult::new(
1090 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1091 ExpandError::other(
1092 span,
1093 "exactly one generic type parameter must be marked \
1094 as `#[pointee]` to derive `CoercePointee` traits",
1095 ),
1096 );
1097 }
1098 (Some(_), Some(_)) => {
1099 return ExpandResult::new(
1100 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1101 ExpandError::other(
1102 span,
1103 "only one type parameter can be marked as `#[pointee]` \
1104 when deriving `CoercePointee` traits",
1105 ),
1106 );
1107 }
1108 }
1109 };
1110 let (Some(struct_name), Some(pointee_param_name)) = (strukt.name(), pointee_param.name())
1111 else {
1112 return ExpandResult::new(
1113 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1114 ExpandError::other(span, "invalid item"),
1115 );
1116 };
1117
1118 {
1119 let mut pointee_has_maybe_sized_bound = false;
1120 if let Some(bounds) = pointee_param.type_bound_list() {
1121 pointee_has_maybe_sized_bound |= bounds.bounds().any(is_maybe_sized_bound);
1122 }
1123 if let Some(where_clause) = strukt.where_clause() {
1124 pointee_has_maybe_sized_bound |= where_clause.predicates().any(|pred| {
1125 let Some(ast::Type::PathType(ty)) = pred.ty() else { return false };
1126 let is_not_pointee = ty.path().is_none_or(|path| {
1127 let is_pointee = path
1128 .as_single_name_ref()
1129 .is_some_and(|name| name.text() == pointee_param_name.text());
1130 !is_pointee
1131 });
1132 if is_not_pointee {
1133 return false;
1134 }
1135 pred.type_bound_list()
1136 .is_some_and(|bounds| bounds.bounds().any(is_maybe_sized_bound))
1137 })
1138 }
1139 if !pointee_has_maybe_sized_bound {
1140 return ExpandResult::new(
1141 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1142 ExpandError::other(
1143 span,
1144 format!(
1145 "`derive(CoercePointee)` requires `{pointee_param_name}` to be marked `?Sized`"
1146 ),
1147 ),
1148 );
1149 }
1150 }
1151
1152 const ADDED_PARAM: &str = "__S";
1153
1154 let where_clause = strukt.get_or_create_where_clause();
1155
1156 {
1157 let mut new_predicates = Vec::new();
1158
1159 for param in &type_params {
1181 let Some(param_name) = param.name() else { continue };
1182 if let Some(bounds) = param.type_bound_list() {
1183 let is_pointee = param_name.text() == pointee_param_name.text();
1186 let new_bounds = bounds
1187 .bounds()
1188 .map(|bound| bound.clone_subtree().clone_for_update())
1189 .filter(|bound| {
1190 bound.ty().is_some_and(|ty| {
1191 substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
1192 || is_pointee
1193 })
1194 });
1195 let new_bounds_target = if is_pointee {
1196 make::name_ref(ADDED_PARAM)
1197 } else {
1198 make::name_ref(¶m_name.text())
1199 };
1200 new_predicates.push(
1201 make::where_pred(
1202 Either::Right(make::ty_path(make::path_from_segments(
1203 [make::path_segment(new_bounds_target)],
1204 false,
1205 ))),
1206 new_bounds,
1207 )
1208 .clone_for_update(),
1209 );
1210 }
1211 }
1212
1213 for predicate in where_clause.predicates() {
1240 let predicate = predicate.clone_subtree().clone_for_update();
1241 let Some(pred_target) = predicate.ty() else { continue };
1242
1243 if substitute_type_in_bound(
1246 pred_target.clone(),
1247 &pointee_param_name.text(),
1248 ADDED_PARAM,
1249 ) {
1250 if let Some(bounds) = predicate.type_bound_list() {
1251 for bound in bounds.bounds() {
1252 if let Some(ty) = bound.ty() {
1253 substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM);
1254 }
1255 }
1256 }
1257
1258 new_predicates.push(predicate);
1259 } else if let Some(bounds) = predicate.type_bound_list() {
1260 let new_bounds = bounds
1261 .bounds()
1262 .map(|bound| bound.clone_subtree().clone_for_update())
1263 .filter(|bound| {
1264 bound.ty().is_some_and(|ty| {
1265 substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
1266 })
1267 });
1268 new_predicates.push(
1269 make::where_pred(Either::Right(pred_target), new_bounds).clone_for_update(),
1270 );
1271 }
1272 }
1273
1274 for new_predicate in new_predicates {
1275 where_clause.add_predicate(new_predicate);
1276 }
1277 }
1278
1279 {
1280 where_clause.add_predicate(
1284 make::where_pred(
1285 Either::Right(make::ty_path(make::path_from_segments(
1286 [make::path_segment(make::name_ref(&pointee_param_name.text()))],
1287 false,
1288 ))),
1289 [make::type_bound(make::ty_path(make::path_from_segments(
1290 [
1291 make::path_segment(make::name_ref("core")),
1292 make::path_segment(make::name_ref("marker")),
1293 make::generic_ty_path_segment(
1294 make::name_ref("Unsize"),
1295 [make::type_arg(make::ty_path(make::path_from_segments(
1296 [make::path_segment(make::name_ref(ADDED_PARAM))],
1297 false,
1298 )))
1299 .into()],
1300 ),
1301 ],
1302 true,
1303 )))],
1304 )
1305 .clone_for_update(),
1306 );
1307 }
1308
1309 let self_for_traits = {
1310 let mut type_param_idx = 0;
1312 let self_params_for_traits = strukt
1313 .generic_param_list()
1314 .into_iter()
1315 .flat_map(|params| params.generic_params())
1316 .filter_map(|param| {
1317 Some(match param {
1318 ast::GenericParam::ConstParam(param) => {
1319 ast::GenericArg::ConstArg(make::expr_const_value(¶m.name()?.text()))
1320 }
1321 ast::GenericParam::LifetimeParam(param) => {
1322 make::lifetime_arg(param.lifetime()?).into()
1323 }
1324 ast::GenericParam::TypeParam(param) => {
1325 let name = if pointee_param_idx == type_param_idx {
1326 make::name_ref(ADDED_PARAM)
1327 } else {
1328 make::name_ref(¶m.name()?.text())
1329 };
1330 type_param_idx += 1;
1331 make::type_arg(make::ty_path(make::path_from_segments(
1332 [make::path_segment(name)],
1333 false,
1334 )))
1335 .into()
1336 }
1337 })
1338 });
1339
1340 make::path_from_segments(
1341 [make::generic_ty_path_segment(
1342 make::name_ref(&struct_name.text()),
1343 self_params_for_traits,
1344 )],
1345 false,
1346 )
1347 .clone_for_update()
1348 };
1349
1350 let mut span_map = span::SpanMap::empty();
1351 span_map.push(adt.syntax().text_range().end(), span);
1353
1354 let self_for_traits = syntax_bridge::syntax_node_to_token_tree(
1355 self_for_traits.syntax(),
1356 &span_map,
1357 span,
1358 DocCommentDesugarMode::ProcMacro,
1359 );
1360 let info = match parse_adt_from_syntax(&adt, &span_map, span) {
1361 Ok(it) => it,
1362 Err(err) => {
1363 return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err);
1364 }
1365 };
1366
1367 let self_for_traits2 = self_for_traits.clone();
1368 let krate = dollar_crate(span);
1369 let krate2 = krate.clone();
1370 let dispatch_from_dyn = expand_simple_derive_with_parsed(
1371 span,
1372 info.clone(),
1373 quote! {span => #krate2::ops::DispatchFromDyn<#self_for_traits2> },
1374 |_adt| quote! {span => },
1375 false,
1376 quote! {span => __S },
1377 );
1378 let coerce_unsized = expand_simple_derive_with_parsed(
1379 span,
1380 info,
1381 quote! {span => #krate::ops::CoerceUnsized<#self_for_traits> },
1382 |_adt| quote! {span => },
1383 false,
1384 quote! {span => __S },
1385 );
1386 return ExpandResult::ok(quote! {span => #dispatch_from_dyn #coerce_unsized });
1387
1388 fn is_maybe_sized_bound(bound: ast::TypeBound) -> bool {
1389 if bound.question_mark_token().is_none() {
1390 return false;
1391 }
1392 let Some(ast::Type::PathType(ty)) = bound.ty() else {
1393 return false;
1394 };
1395 let Some(path) = ty.path() else {
1396 return false;
1397 };
1398 return segments_eq(&path, &["Sized"])
1399 || segments_eq(&path, &["core", "marker", "Sized"])
1400 || segments_eq(&path, &["std", "marker", "Sized"]);
1401
1402 fn segments_eq(path: &ast::Path, expected: &[&str]) -> bool {
1403 path.segments().zip_longest(expected.iter().copied()).all(|value| {
1404 value.both().is_some_and(|(segment, expected)| {
1405 segment.name_ref().is_some_and(|name| name.text() == expected)
1406 })
1407 })
1408 }
1409 }
1410
1411 fn substitute_type_in_bound(ty: ast::Type, param_name: &str, replacement: &str) -> bool {
1413 return match ty {
1414 ast::Type::ArrayType(ty) => {
1415 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1416 }
1417 ast::Type::DynTraitType(ty) => go_bounds(ty.type_bound_list(), param_name, replacement),
1418 ast::Type::FnPtrType(ty) => any_long(
1419 ty.param_list()
1420 .into_iter()
1421 .flat_map(|params| params.params().filter_map(|param| param.ty()))
1422 .chain(ty.ret_type().and_then(|it| it.ty())),
1423 |ty| substitute_type_in_bound(ty, param_name, replacement),
1424 ),
1425 ast::Type::ForType(ty) => {
1426 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1427 }
1428 ast::Type::ImplTraitType(ty) => {
1429 go_bounds(ty.type_bound_list(), param_name, replacement)
1430 }
1431 ast::Type::ParenType(ty) => {
1432 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1433 }
1434 ast::Type::PathType(ty) => ty.path().is_some_and(|path| {
1435 if path.as_single_name_ref().is_some_and(|name| name.text() == param_name) {
1436 ted::replace(
1437 path.syntax(),
1438 make::path_from_segments(
1439 [make::path_segment(make::name_ref(replacement))],
1440 false,
1441 )
1442 .clone_for_update()
1443 .syntax(),
1444 );
1445 return true;
1446 }
1447
1448 any_long(
1449 path.segments()
1450 .filter_map(|segment| segment.generic_arg_list())
1451 .flat_map(|it| it.generic_args())
1452 .filter_map(|generic_arg| match generic_arg {
1453 ast::GenericArg::TypeArg(ty) => ty.ty(),
1454 _ => None,
1455 }),
1456 |ty| substitute_type_in_bound(ty, param_name, replacement),
1457 )
1458 }),
1459 ast::Type::PtrType(ty) => {
1460 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1461 }
1462 ast::Type::RefType(ty) => {
1463 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1464 }
1465 ast::Type::SliceType(ty) => {
1466 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1467 }
1468 ast::Type::TupleType(ty) => {
1469 any_long(ty.fields(), |ty| substitute_type_in_bound(ty, param_name, replacement))
1470 }
1471 ast::Type::InferType(_) | ast::Type::MacroType(_) | ast::Type::NeverType(_) => false,
1472 };
1473
1474 fn go_bounds(
1475 bounds: Option<ast::TypeBoundList>,
1476 param_name: &str,
1477 replacement: &str,
1478 ) -> bool {
1479 bounds.is_some_and(|bounds| {
1480 any_long(bounds.bounds(), |bound| {
1481 bound
1482 .ty()
1483 .is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1484 })
1485 })
1486 }
1487
1488 fn any_long<I: Iterator, F: FnMut(I::Item) -> bool>(iter: I, mut f: F) -> bool {
1490 let mut result = false;
1491 iter.for_each(|item| result |= f(item));
1492 result
1493 }
1494 }
1495}