1use either::Either;
4use intern::sym;
5use itertools::{Itertools, izip};
6use parser::SyntaxKind;
7use rustc_hash::FxHashSet;
8use span::{Edition, Span, SyntaxContext};
9use stdx::never;
10use syntax_bridge::DocCommentDesugarMode;
11use tracing::debug;
12
13use crate::{
14 ExpandError, ExpandResult, MacroCallId,
15 builtin::quote::{dollar_crate, quote},
16 db::ExpandDatabase,
17 hygiene::span_with_def_site_ctxt,
18 name::{self, AsName, Name},
19 span_map::ExpansionSpanMap,
20 tt,
21};
22use syntax::{
23 ast::{
24 self, AstNode, FieldList, HasAttrs, HasGenericArgs, HasGenericParams, HasModuleItem,
25 HasName, HasTypeBounds, edit_in_place::GenericParamsOwnerEdit, make,
26 },
27 ted,
28};
29
/// Declares the `BuiltinDeriveExpander` enum together with its lookup helpers.
///
/// Every `Trait => expand_fn` pair produces one enum variant named after the
/// trait, one arm of `expander()` that yields `expand_fn`, and one entry of
/// `find_by_name()` that compares against the symbol `sym::Trait`.
macro_rules! register_builtin {
    ( $($trait:ident => $expand:ident),* ) => {
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub enum BuiltinDeriveExpander {
            $($trait),*
        }

        impl BuiltinDeriveExpander {
            /// Returns the expansion function registered for this derive.
            pub fn expander(&self) -> fn(&dyn ExpandDatabase, Span, &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
                match *self {
                    $( Self::$trait => $expand, )*
                }
            }

            /// Returns the expander whose trait symbol equals `name`, if any.
            fn find_by_name(name: &name::Name) -> Option<Self> {
                $(
                    if name == &sym::$trait {
                        return Some(Self::$trait);
                    }
                )*
                None
            }
        }

    };
}
54
55impl BuiltinDeriveExpander {
56 pub fn expand(
57 &self,
58 db: &dyn ExpandDatabase,
59 id: MacroCallId,
60 tt: &tt::TopSubtree,
61 span: Span,
62 ) -> ExpandResult<tt::TopSubtree> {
63 let span = span_with_def_site_ctxt(db, span, id.into(), Edition::CURRENT);
64 self.expander()(db, span, tt)
65 }
66}
67
// Registration table: each trait name on the left becomes a
// `BuiltinDeriveExpander` variant, dispatched to the function on the right.
register_builtin! {
    Copy => copy_expand,
    Clone => clone_expand,
    Default => default_expand,
    Debug => debug_expand,
    Hash => hash_expand,
    Ord => ord_expand,
    PartialOrd => partial_ord_expand,
    Eq => eq_expand,
    PartialEq => partial_eq_expand,
    CoercePointee => coerce_pointee_expand
}
80
/// Returns the builtin derive expander registered under the name `ident`,
/// or `None` when no builtin derive has that name.
pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander> {
    BuiltinDeriveExpander::find_by_name(ident)
}
84
/// Field layout of a struct or of a single enum variant.
#[derive(Clone)]
enum VariantShape {
    /// Named fields; holds the identifier of every field.
    Struct(Vec<tt::Ident>),
    /// Positional fields; holds the number of fields.
    Tuple(usize),
    /// No field list at all.
    Unit,
}
91
92fn tuple_field_iterator(span: Span, n: usize) -> impl Iterator<Item = tt::Ident> {
93 (0..n).map(move |it| tt::Ident::new(&format!("f{it}"), span))
94}
95
96impl VariantShape {
97 fn as_pattern(&self, path: tt::TopSubtree, span: Span) -> tt::TopSubtree {
98 self.as_pattern_map(path, span, |it| quote!(span => #it))
99 }
100
101 fn field_names(&self, span: Span) -> Vec<tt::Ident> {
102 match self {
103 VariantShape::Struct(s) => s.clone(),
104 VariantShape::Tuple(n) => tuple_field_iterator(span, *n).collect(),
105 VariantShape::Unit => vec![],
106 }
107 }
108
109 fn as_pattern_map(
110 &self,
111 path: tt::TopSubtree,
112 span: Span,
113 field_map: impl Fn(&tt::Ident) -> tt::TopSubtree,
114 ) -> tt::TopSubtree {
115 match self {
116 VariantShape::Struct(fields) => {
117 let fields = fields.iter().map(|it| {
118 let mapped = field_map(it);
119 quote! {span => #it : #mapped , }
120 });
121 quote! {span =>
122 #path { # #fields }
123 }
124 }
125 &VariantShape::Tuple(n) => {
126 let fields = tuple_field_iterator(span, n).map(|it| {
127 let mapped = field_map(&it);
128 quote! {span =>
129 #mapped ,
130 }
131 });
132 quote! {span =>
133 #path ( # #fields )
134 }
135 }
136 VariantShape::Unit => path,
137 }
138 }
139
140 fn from(
141 call_site: Span,
142 tm: &ExpansionSpanMap,
143 value: Option<FieldList>,
144 ) -> Result<Self, ExpandError> {
145 let r = match value {
146 None => VariantShape::Unit,
147 Some(FieldList::RecordFieldList(it)) => VariantShape::Struct(
148 it.fields()
149 .map(|it| it.name())
150 .map(|it| name_to_token(call_site, tm, it))
151 .collect::<Result<_, _>>()?,
152 ),
153 Some(FieldList::TupleFieldList(it)) => VariantShape::Tuple(it.fields().count()),
154 };
155 Ok(r)
156 }
157}
158
/// Overall layout of the item a derive is applied to.
#[derive(Clone)]
enum AdtShape {
    /// A struct with the given field layout.
    Struct(VariantShape),
    /// An enum: `variants` pairs every variant name with its field layout,
    /// and `default_variant` is the index into `variants` of the first variant
    /// carrying a `default` attribute, if there is one.
    Enum { variants: Vec<(tt::Ident, VariantShape)>, default_variant: Option<usize> },
    /// A union; its fields are not recorded.
    Union,
}
165
166impl AdtShape {
167 fn as_pattern(&self, span: Span, name: &tt::Ident) -> Vec<tt::TopSubtree> {
168 self.as_pattern_map(name, |it| quote!(span =>#it), span)
169 }
170
171 fn field_names(&self, span: Span) -> Vec<Vec<tt::Ident>> {
172 match self {
173 AdtShape::Struct(s) => {
174 vec![s.field_names(span)]
175 }
176 AdtShape::Enum { variants, .. } => {
177 variants.iter().map(|(_, fields)| fields.field_names(span)).collect()
178 }
179 AdtShape::Union => {
180 never!("using fields of union in derive is always wrong");
181 vec![]
182 }
183 }
184 }
185
186 fn as_pattern_map(
187 &self,
188 name: &tt::Ident,
189 field_map: impl Fn(&tt::Ident) -> tt::TopSubtree,
190 span: Span,
191 ) -> Vec<tt::TopSubtree> {
192 match self {
193 AdtShape::Struct(s) => {
194 vec![s.as_pattern_map(quote! {span => #name }, span, field_map)]
195 }
196 AdtShape::Enum { variants, .. } => variants
197 .iter()
198 .map(|(v, fields)| {
199 fields.as_pattern_map(quote! {span => #name :: #v }, span, &field_map)
200 })
201 .collect(),
202 AdtShape::Union => {
203 never!("pattern matching on union is always wrong");
204 vec![quote! {span => un }]
205 }
206 }
207 }
208}
209
/// Everything the derive expanders need to know about the annotated item.
#[derive(Clone)]
struct BasicAdtInfo {
    /// Name of the struct, enum or union.
    name: tt::Ident,
    /// Field layout of the item.
    shape: AdtShape,
    /// The item's type and const generic parameters.
    param_types: Vec<AdtParam>,
    /// The predicates of the item's `where` clause, one token tree each.
    where_clause: Vec<tt::TopSubtree>,
    /// Path types found inside the fields whose qualifier is a single name
    /// equal to one of the item's type or const parameters (e.g. `T::Assoc`).
    associated_types: Vec<tt::TopSubtree>,
}
221
/// One type or const generic parameter of the annotated item.
#[derive(Clone)]
struct AdtParam {
    /// The parameter's name (an empty token tree when the name is missing).
    name: tt::TopSubtree,
    /// `Some` for const parameters: the declared type, or an empty token tree
    /// when the type is missing. `None` for type parameters.
    const_ty: Option<tt::TopSubtree>,
    /// The bound list written on a type parameter, if any; always `None` for
    /// const parameters.
    bounds: Option<tt::TopSubtree>,
}
229
230fn parse_adt(
232 db: &dyn ExpandDatabase,
233 tt: &tt::TopSubtree,
234 call_site: Span,
235) -> Result<BasicAdtInfo, ExpandError> {
236 let (adt, tm) = to_adt_syntax(db, tt, call_site)?;
237 parse_adt_from_syntax(&adt, &tm, call_site)
238}
239
240fn parse_adt_from_syntax(
241 adt: &ast::Adt,
242 tm: &span::SpanMap<SyntaxContext>,
243 call_site: Span,
244) -> Result<BasicAdtInfo, ExpandError> {
245 let (name, generic_param_list, where_clause, shape) = match &adt {
246 ast::Adt::Struct(it) => (
247 it.name(),
248 it.generic_param_list(),
249 it.where_clause(),
250 AdtShape::Struct(VariantShape::from(call_site, tm, it.field_list())?),
251 ),
252 ast::Adt::Enum(it) => {
253 let default_variant = it
254 .variant_list()
255 .into_iter()
256 .flat_map(|it| it.variants())
257 .position(|it| it.attrs().any(|it| it.simple_name() == Some("default".into())));
258 (
259 it.name(),
260 it.generic_param_list(),
261 it.where_clause(),
262 AdtShape::Enum {
263 default_variant,
264 variants: it
265 .variant_list()
266 .into_iter()
267 .flat_map(|it| it.variants())
268 .map(|it| {
269 Ok((
270 name_to_token(call_site, tm, it.name())?,
271 VariantShape::from(call_site, tm, it.field_list())?,
272 ))
273 })
274 .collect::<Result<_, ExpandError>>()?,
275 },
276 )
277 }
278 ast::Adt::Union(it) => {
279 (it.name(), it.generic_param_list(), it.where_clause(), AdtShape::Union)
280 }
281 };
282
283 let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
284 let param_types = generic_param_list
285 .into_iter()
286 .flat_map(|param_list| param_list.type_or_const_params())
287 .map(|param| {
288 let name = {
289 let this = param.name();
290 match this {
291 Some(it) => {
292 param_type_set.insert(it.as_name());
293 syntax_bridge::syntax_node_to_token_tree(
294 it.syntax(),
295 tm,
296 call_site,
297 DocCommentDesugarMode::ProcMacro,
298 )
299 }
300 None => {
301 tt::TopSubtree::empty(::tt::DelimSpan { open: call_site, close: call_site })
302 }
303 }
304 };
305 let bounds = match ¶m {
306 ast::TypeOrConstParam::Type(it) => it.type_bound_list().map(|it| {
307 syntax_bridge::syntax_node_to_token_tree(
308 it.syntax(),
309 tm,
310 call_site,
311 DocCommentDesugarMode::ProcMacro,
312 )
313 }),
314 ast::TypeOrConstParam::Const(_) => None,
315 };
316 let const_ty = if let ast::TypeOrConstParam::Const(param) = param {
317 let ty = param
318 .ty()
319 .map(|ty| {
320 syntax_bridge::syntax_node_to_token_tree(
321 ty.syntax(),
322 tm,
323 call_site,
324 DocCommentDesugarMode::ProcMacro,
325 )
326 })
327 .unwrap_or_else(|| {
328 tt::TopSubtree::empty(::tt::DelimSpan { open: call_site, close: call_site })
329 });
330 Some(ty)
331 } else {
332 None
333 };
334 AdtParam { name, const_ty, bounds }
335 })
336 .collect();
337
338 let where_clause = if let Some(w) = where_clause {
339 w.predicates()
340 .map(|it| {
341 syntax_bridge::syntax_node_to_token_tree(
342 it.syntax(),
343 tm,
344 call_site,
345 DocCommentDesugarMode::ProcMacro,
346 )
347 })
348 .collect()
349 } else {
350 vec![]
351 };
352
353 let field_list = match adt {
365 ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
366 ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
367 ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
368 };
369 let associated_types = field_list
370 .into_iter()
371 .flat_map(|it| it.descendants())
372 .filter_map(ast::PathType::cast)
373 .filter_map(|p| {
374 let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
375 param_type_set.contains(&name).then_some(p)
376 })
377 .map(|it| {
378 syntax_bridge::syntax_node_to_token_tree(
379 it.syntax(),
380 tm,
381 call_site,
382 DocCommentDesugarMode::ProcMacro,
383 )
384 })
385 .collect();
386 let name_token = name_to_token(call_site, tm, name)?;
387 Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types })
388}
389
390fn to_adt_syntax(
391 db: &dyn ExpandDatabase,
392 tt: &tt::TopSubtree,
393 call_site: Span,
394) -> Result<(ast::Adt, span::SpanMap<SyntaxContext>), ExpandError> {
395 let (parsed, tm) = crate::db::token_tree_to_syntax_node(
396 db,
397 tt,
398 crate::ExpandTo::Items,
399 parser::Edition::CURRENT_FIXME,
400 );
401 let macro_items = ast::MacroItems::cast(parsed.syntax_node())
402 .ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?;
403 let item =
404 macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?;
405 let adt = ast::Adt::cast(item.syntax().clone())
406 .ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?;
407 Ok((adt, tm))
408}
409
410fn name_to_token(
411 call_site: Span,
412 token_map: &ExpansionSpanMap,
413 name: Option<ast::Name>,
414) -> Result<tt::Ident, ExpandError> {
415 let name = name.ok_or_else(|| {
416 debug!("parsed item has no name");
417 ExpandError::other(call_site, "missing name")
418 })?;
419 let span = token_map.span_at(name.syntax().text_range().start());
420
421 let name_token = tt::Ident::new(name.text().as_ref(), span);
422 Ok(name_token)
423}
424
425fn expand_simple_derive(
457 db: &dyn ExpandDatabase,
458 invoc_span: Span,
459 tt: &tt::TopSubtree,
460 trait_path: tt::TopSubtree,
461 allow_unions: bool,
462 make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::TopSubtree,
463) -> ExpandResult<tt::TopSubtree> {
464 let info = match parse_adt(db, tt, invoc_span) {
465 Ok(info) => info,
466 Err(e) => {
467 return ExpandResult::new(
468 tt::TopSubtree::empty(tt::DelimSpan { open: invoc_span, close: invoc_span }),
469 e,
470 );
471 }
472 };
473 if !allow_unions && matches!(info.shape, AdtShape::Union) {
474 return ExpandResult::new(
475 tt::TopSubtree::empty(tt::DelimSpan::from_single(invoc_span)),
476 ExpandError::other(invoc_span, "this trait cannot be derived for unions"),
477 );
478 }
479 ExpandResult::ok(expand_simple_derive_with_parsed(
480 invoc_span,
481 info,
482 trait_path,
483 make_trait_body,
484 true,
485 tt::TopSubtree::empty(tt::DelimSpan::from_single(invoc_span)),
486 ))
487}
488
489fn expand_simple_derive_with_parsed(
490 invoc_span: Span,
491 info: BasicAdtInfo,
492 trait_path: tt::TopSubtree,
493 make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::TopSubtree,
494 constrain_to_trait: bool,
495 extra_impl_params: tt::TopSubtree,
496) -> tt::TopSubtree {
497 let trait_body = make_trait_body(&info);
498 let mut where_block: Vec<_> =
499 info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect();
500 let (params, args): (Vec<_>, Vec<_>) = info
501 .param_types
502 .into_iter()
503 .map(|param| {
504 let ident = param.name;
505 if let Some(b) = param.bounds {
506 let ident2 = ident.clone();
507 where_block.push(quote! {invoc_span => #ident2 : #b , });
508 }
509 if let Some(ty) = param.const_ty {
510 let ident2 = ident.clone();
511 (quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident2 , })
512 } else {
513 let bound = trait_path.clone();
514 let ident2 = ident.clone();
515 let param = if constrain_to_trait {
516 quote! {invoc_span => #ident : #bound , }
517 } else {
518 quote! {invoc_span => #ident , }
519 };
520 (param, quote! {invoc_span => #ident2 , })
521 }
522 })
523 .unzip();
524
525 if constrain_to_trait {
526 where_block.extend(info.associated_types.iter().map(|it| {
527 let it = it.clone();
528 let bound = trait_path.clone();
529 quote! {invoc_span => #it : #bound , }
530 }));
531 }
532
533 let name = info.name;
534 quote! {invoc_span =>
535 impl < # #params #extra_impl_params > #trait_path for #name < # #args > where # #where_block { #trait_body }
536 }
537}
538
539fn copy_expand(
540 db: &dyn ExpandDatabase,
541 span: Span,
542 tt: &tt::TopSubtree,
543) -> ExpandResult<tt::TopSubtree> {
544 let krate = dollar_crate(span);
545 expand_simple_derive(
546 db,
547 span,
548 tt,
549 quote! {span => #krate::marker::Copy },
550 true,
551 |_| quote! {span =>},
552 )
553}
554
555fn clone_expand(
556 db: &dyn ExpandDatabase,
557 span: Span,
558 tt: &tt::TopSubtree,
559) -> ExpandResult<tt::TopSubtree> {
560 let krate = dollar_crate(span);
561 expand_simple_derive(db, span, tt, quote! {span => #krate::clone::Clone }, true, |adt| {
562 if matches!(adt.shape, AdtShape::Union) {
563 let star = tt::Punct { char: '*', spacing: ::tt::Spacing::Alone, span };
564 return quote! {span =>
565 fn clone(&self) -> Self {
566 #star self
567 }
568 };
569 }
570 if matches!(&adt.shape, AdtShape::Enum { variants, .. } if variants.is_empty()) {
571 let star = tt::Punct { char: '*', spacing: ::tt::Spacing::Alone, span };
572 return quote! {span =>
573 fn clone(&self) -> Self {
574 match #star self {}
575 }
576 };
577 }
578 let name = &adt.name;
579 let patterns = adt.shape.as_pattern(span, name);
580 let exprs = adt.shape.as_pattern_map(name, |it| quote! {span => #it .clone() }, span);
581 let arms = patterns.into_iter().zip(exprs).map(|(pat, expr)| {
582 let fat_arrow = fat_arrow(span);
583 quote! {span =>
584 #pat #fat_arrow #expr,
585 }
586 });
587
588 quote! {span =>
589 fn clone(&self) -> Self {
590 match self {
591 # #arms
592 }
593 }
594 }
595 })
596}
597
598fn fat_arrow(span: Span) -> tt::TopSubtree {
600 let eq = tt::Punct { char: '=', spacing: ::tt::Spacing::Joint, span };
601 quote! {span => #eq> }
602}
603
604fn and_and(span: Span) -> tt::TopSubtree {
606 let and = tt::Punct { char: '&', spacing: ::tt::Spacing::Joint, span };
607 quote! {span => #and& }
608}
609
610fn default_expand(
611 db: &dyn ExpandDatabase,
612 span: Span,
613 tt: &tt::TopSubtree,
614) -> ExpandResult<tt::TopSubtree> {
615 let krate = &dollar_crate(span);
616 let adt = match parse_adt(db, tt, span) {
617 Ok(info) => info,
618 Err(e) => {
619 return ExpandResult::new(
620 tt::TopSubtree::empty(tt::DelimSpan { open: span, close: span }),
621 e,
622 );
623 }
624 };
625 let (body, constrain_to_trait) = match &adt.shape {
626 AdtShape::Struct(fields) => {
627 let name = &adt.name;
628 let body = fields.as_pattern_map(
629 quote!(span =>#name),
630 span,
631 |_| quote!(span =>#krate::default::Default::default()),
632 );
633 (body, true)
634 }
635 AdtShape::Enum { default_variant, variants } => {
636 if let Some(d) = default_variant {
637 let (name, fields) = &variants[*d];
638 let adt_name = &adt.name;
639 let body = fields.as_pattern_map(
640 quote!(span =>#adt_name :: #name),
641 span,
642 |_| quote!(span =>#krate::default::Default::default()),
643 );
644 (body, false)
645 } else {
646 return ExpandResult::new(
647 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
648 ExpandError::other(span, "`#[derive(Default)]` on enum with no `#[default]`"),
649 );
650 }
651 }
652 AdtShape::Union => {
653 return ExpandResult::new(
654 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
655 ExpandError::other(span, "this trait cannot be derived for unions"),
656 );
657 }
658 };
659 ExpandResult::ok(expand_simple_derive_with_parsed(
660 span,
661 adt,
662 quote! {span => #krate::default::Default },
663 |_adt| {
664 quote! {span =>
665 fn default() -> Self {
666 #body
667 }
668 }
669 },
670 constrain_to_trait,
671 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
672 ))
673}
674
675fn debug_expand(
676 db: &dyn ExpandDatabase,
677 span: Span,
678 tt: &tt::TopSubtree,
679) -> ExpandResult<tt::TopSubtree> {
680 let krate = &dollar_crate(span);
681 expand_simple_derive(db, span, tt, quote! {span => #krate::fmt::Debug }, false, |adt| {
682 let for_variant = |name: String, v: &VariantShape| match v {
683 VariantShape::Struct(fields) => {
684 let for_fields = fields.iter().map(|it| {
685 let x_string = it.to_string();
686 quote! {span =>
687 .field(#x_string, & #it)
688 }
689 });
690 quote! {span =>
691 f.debug_struct(#name) # #for_fields .finish()
692 }
693 }
694 VariantShape::Tuple(n) => {
695 let for_fields = tuple_field_iterator(span, *n).map(|it| {
696 quote! {span =>
697 .field( & #it)
698 }
699 });
700 quote! {span =>
701 f.debug_tuple(#name) # #for_fields .finish()
702 }
703 }
704 VariantShape::Unit => quote! {span =>
705 f.write_str(#name)
706 },
707 };
708 if matches!(&adt.shape, AdtShape::Enum { variants, .. } if variants.is_empty()) {
709 let star = tt::Punct { char: '*', spacing: ::tt::Spacing::Alone, span };
710 return quote! {span =>
711 fn fmt(&self, f: &mut #krate::fmt::Formatter) -> #krate::fmt::Result {
712 match #star self {}
713 }
714 };
715 }
716 let arms = match &adt.shape {
717 AdtShape::Struct(fields) => {
718 let fat_arrow = fat_arrow(span);
719 let name = &adt.name;
720 let pat = fields.as_pattern(quote!(span =>#name), span);
721 let expr = for_variant(name.to_string(), fields);
722 vec![quote! {span => #pat #fat_arrow #expr }]
723 }
724 AdtShape::Enum { variants, .. } => variants
725 .iter()
726 .map(|(name, v)| {
727 let fat_arrow = fat_arrow(span);
728 let adt_name = &adt.name;
729 let pat = v.as_pattern(quote!(span =>#adt_name :: #name), span);
730 let expr = for_variant(name.to_string(), v);
731 quote! {span =>
732 #pat #fat_arrow #expr ,
733 }
734 })
735 .collect(),
736 AdtShape::Union => unreachable!(),
737 };
738 quote! {span =>
739 fn fmt(&self, f: &mut #krate::fmt::Formatter) -> #krate::fmt::Result {
740 match self {
741 # #arms
742 }
743 }
744 }
745 })
746}
747
748fn hash_expand(
749 db: &dyn ExpandDatabase,
750 span: Span,
751 tt: &tt::TopSubtree,
752) -> ExpandResult<tt::TopSubtree> {
753 let krate = &dollar_crate(span);
754 expand_simple_derive(db, span, tt, quote! {span => #krate::hash::Hash }, false, |adt| {
755 if matches!(&adt.shape, AdtShape::Enum { variants, .. } if variants.is_empty()) {
756 let star = tt::Punct { char: '*', spacing: ::tt::Spacing::Alone, span };
757 return quote! {span =>
758 fn hash<H: #krate::hash::Hasher>(&self, ra_expand_state: &mut H) {
759 match #star self {}
760 }
761 };
762 }
763 let arms =
764 adt.shape.as_pattern(span, &adt.name).into_iter().zip(adt.shape.field_names(span)).map(
765 |(pat, names)| {
766 let expr = {
767 let it =
768 names.iter().map(|it| quote! {span => #it . hash(ra_expand_state); });
769 quote! {span => {
770 # #it
771 } }
772 };
773 let fat_arrow = fat_arrow(span);
774 quote! {span =>
775 #pat #fat_arrow #expr ,
776 }
777 },
778 );
779 let check_discriminant = if matches!(&adt.shape, AdtShape::Enum { .. }) {
780 quote! {span => #krate::mem::discriminant(self).hash(ra_expand_state); }
781 } else {
782 quote! {span =>}
783 };
784 quote! {span =>
785 fn hash<H: #krate::hash::Hasher>(&self, ra_expand_state: &mut H) {
786 #check_discriminant
787 match self {
788 # #arms
789 }
790 }
791 }
792 })
793}
794
795fn eq_expand(
796 db: &dyn ExpandDatabase,
797 span: Span,
798 tt: &tt::TopSubtree,
799) -> ExpandResult<tt::TopSubtree> {
800 let krate = dollar_crate(span);
801 expand_simple_derive(
802 db,
803 span,
804 tt,
805 quote! {span => #krate::cmp::Eq },
806 true,
807 |_| quote! {span =>},
808 )
809}
810
811fn partial_eq_expand(
812 db: &dyn ExpandDatabase,
813 span: Span,
814 tt: &tt::TopSubtree,
815) -> ExpandResult<tt::TopSubtree> {
816 let krate = dollar_crate(span);
817 expand_simple_derive(db, span, tt, quote! {span => #krate::cmp::PartialEq }, false, |adt| {
818 let name = &adt.name;
819
820 let (self_patterns, other_patterns) = self_and_other_patterns(adt, name, span);
821 let arms = izip!(self_patterns, other_patterns, adt.shape.field_names(span)).map(
822 |(pat1, pat2, names)| {
823 let fat_arrow = fat_arrow(span);
824 let body = match &*names {
825 [] => {
826 quote!(span =>true)
827 }
828 [first, rest @ ..] => {
829 let rest = rest.iter().map(|it| {
830 let t1 = tt::Ident::new(&format!("{}_self", it.sym), it.span);
831 let t2 = tt::Ident::new(&format!("{}_other", it.sym), it.span);
832 let and_and = and_and(span);
833 quote!(span =>#and_and #t1 .eq( #t2 ))
834 });
835 let first = {
836 let t1 = tt::Ident::new(&format!("{}_self", first.sym), first.span);
837 let t2 = tt::Ident::new(&format!("{}_other", first.sym), first.span);
838 quote!(span =>#t1 .eq( #t2 ))
839 };
840 quote!(span =>#first # #rest)
841 }
842 };
843 quote! {span => ( #pat1 , #pat2 ) #fat_arrow #body , }
844 },
845 );
846
847 let fat_arrow = fat_arrow(span);
848 quote! {span =>
849 fn eq(&self, other: &Self) -> bool {
850 match (self, other) {
851 # #arms
852 _unused #fat_arrow false
853 }
854 }
855 }
856 })
857}
858
859fn self_and_other_patterns(
860 adt: &BasicAdtInfo,
861 name: &tt::Ident,
862 span: Span,
863) -> (Vec<tt::TopSubtree>, Vec<tt::TopSubtree>) {
864 let self_patterns = adt.shape.as_pattern_map(
865 name,
866 |it| {
867 let t = tt::Ident::new(&format!("{}_self", it.sym), it.span);
868 quote!(span =>#t)
869 },
870 span,
871 );
872 let other_patterns = adt.shape.as_pattern_map(
873 name,
874 |it| {
875 let t = tt::Ident::new(&format!("{}_other", it.sym), it.span);
876 quote!(span =>#t)
877 },
878 span,
879 );
880 (self_patterns, other_patterns)
881}
882
883fn ord_expand(
884 db: &dyn ExpandDatabase,
885 span: Span,
886 tt: &tt::TopSubtree,
887) -> ExpandResult<tt::TopSubtree> {
888 let krate = &dollar_crate(span);
889 expand_simple_derive(db, span, tt, quote! {span => #krate::cmp::Ord }, false, |adt| {
890 fn compare(
891 krate: &tt::Ident,
892 left: tt::TopSubtree,
893 right: tt::TopSubtree,
894 rest: tt::TopSubtree,
895 span: Span,
896 ) -> tt::TopSubtree {
897 let fat_arrow1 = fat_arrow(span);
898 let fat_arrow2 = fat_arrow(span);
899 quote! {span =>
900 match #left.cmp(&#right) {
901 #krate::cmp::Ordering::Equal #fat_arrow1 {
902 #rest
903 }
904 c #fat_arrow2 return c,
905 }
906 }
907 }
908 let (self_patterns, other_patterns) = self_and_other_patterns(adt, &adt.name, span);
909 let arms = izip!(self_patterns, other_patterns, adt.shape.field_names(span)).map(
910 |(pat1, pat2, fields)| {
911 let mut body = quote!(span =>#krate::cmp::Ordering::Equal);
912 for f in fields.into_iter().rev() {
913 let t1 = tt::Ident::new(&format!("{}_self", f.sym), f.span);
914 let t2 = tt::Ident::new(&format!("{}_other", f.sym), f.span);
915 body = compare(krate, quote!(span =>#t1), quote!(span =>#t2), body, span);
916 }
917 let fat_arrow = fat_arrow(span);
918 quote! {span => ( #pat1 , #pat2 ) #fat_arrow #body , }
919 },
920 );
921 let fat_arrow = fat_arrow(span);
922 let mut body = quote! {span =>
923 match (self, other) {
924 # #arms
925 _unused #fat_arrow #krate::cmp::Ordering::Equal
926 }
927 };
928 if matches!(&adt.shape, AdtShape::Enum { .. }) {
929 let left = quote!(span =>#krate::intrinsics::discriminant_value(self));
930 let right = quote!(span =>#krate::intrinsics::discriminant_value(other));
931 body = compare(krate, left, right, body, span);
932 }
933 quote! {span =>
934 fn cmp(&self, other: &Self) -> #krate::cmp::Ordering {
935 #body
936 }
937 }
938 })
939}
940
941fn partial_ord_expand(
942 db: &dyn ExpandDatabase,
943 span: Span,
944 tt: &tt::TopSubtree,
945) -> ExpandResult<tt::TopSubtree> {
946 let krate = &dollar_crate(span);
947 expand_simple_derive(db, span, tt, quote! {span => #krate::cmp::PartialOrd }, false, |adt| {
948 fn compare(
949 krate: &tt::Ident,
950 left: tt::TopSubtree,
951 right: tt::TopSubtree,
952 rest: tt::TopSubtree,
953 span: Span,
954 ) -> tt::TopSubtree {
955 let fat_arrow1 = fat_arrow(span);
956 let fat_arrow2 = fat_arrow(span);
957 quote! {span =>
958 match #left.partial_cmp(&#right) {
959 #krate::option::Option::Some(#krate::cmp::Ordering::Equal) #fat_arrow1 {
960 #rest
961 }
962 c #fat_arrow2 return c,
963 }
964 }
965 }
966 let left = quote!(span =>#krate::intrinsics::discriminant_value(self));
967 let right = quote!(span =>#krate::intrinsics::discriminant_value(other));
968
969 let (self_patterns, other_patterns) = self_and_other_patterns(adt, &adt.name, span);
970 let arms = izip!(self_patterns, other_patterns, adt.shape.field_names(span)).map(
971 |(pat1, pat2, fields)| {
972 let mut body =
973 quote!(span =>#krate::option::Option::Some(#krate::cmp::Ordering::Equal));
974 for f in fields.into_iter().rev() {
975 let t1 = tt::Ident::new(&format!("{}_self", f.sym), f.span);
976 let t2 = tt::Ident::new(&format!("{}_other", f.sym), f.span);
977 body = compare(krate, quote!(span =>#t1), quote!(span =>#t2), body, span);
978 }
979 let fat_arrow = fat_arrow(span);
980 quote! {span => ( #pat1 , #pat2 ) #fat_arrow #body , }
981 },
982 );
983 let fat_arrow = fat_arrow(span);
984 let body = compare(
985 krate,
986 left,
987 right,
988 quote! {span =>
989 match (self, other) {
990 # #arms
991 _unused #fat_arrow #krate::option::Option::Some(#krate::cmp::Ordering::Equal)
992 }
993 },
994 span,
995 );
996 quote! {span =>
997 fn partial_cmp(&self, other: &Self) -> #krate::option::Option<#krate::cmp::Ordering> {
998 #body
999 }
1000 }
1001 })
1002}
1003
1004fn coerce_pointee_expand(
1005 db: &dyn ExpandDatabase,
1006 span: Span,
1007 tt: &tt::TopSubtree,
1008) -> ExpandResult<tt::TopSubtree> {
1009 let (adt, _span_map) = match to_adt_syntax(db, tt, span) {
1010 Ok(it) => it,
1011 Err(err) => {
1012 return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err);
1013 }
1014 };
1015 let adt = adt.clone_for_update();
1016 let ast::Adt::Struct(strukt) = &adt else {
1017 return ExpandResult::new(
1018 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1019 ExpandError::other(span, "`CoercePointee` can only be derived on `struct`s"),
1020 );
1021 };
1022 let has_at_least_one_field = strukt
1023 .field_list()
1024 .map(|it| match it {
1025 ast::FieldList::RecordFieldList(it) => it.fields().next().is_some(),
1026 ast::FieldList::TupleFieldList(it) => it.fields().next().is_some(),
1027 })
1028 .unwrap_or(false);
1029 if !has_at_least_one_field {
1030 return ExpandResult::new(
1031 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1032 ExpandError::other(
1033 span,
1034 "`CoercePointee` can only be derived on `struct`s with at least one field",
1035 ),
1036 );
1037 }
1038 let is_repr_transparent = strukt.attrs().any(|attr| {
1039 attr.as_simple_call().is_some_and(|(name, tt)| {
1040 name == "repr"
1041 && tt.syntax().children_with_tokens().any(|it| {
1042 it.into_token().is_some_and(|it| {
1043 it.kind() == SyntaxKind::IDENT && it.text() == "transparent"
1044 })
1045 })
1046 })
1047 });
1048 if !is_repr_transparent {
1049 return ExpandResult::new(
1050 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1051 ExpandError::other(
1052 span,
1053 "`CoercePointee` can only be derived on `struct`s with `#[repr(transparent)]`",
1054 ),
1055 );
1056 }
1057 let type_params = strukt
1058 .generic_param_list()
1059 .into_iter()
1060 .flat_map(|generics| {
1061 generics.generic_params().filter_map(|param| match param {
1062 ast::GenericParam::TypeParam(param) => Some(param),
1063 _ => None,
1064 })
1065 })
1066 .collect_vec();
1067 if type_params.is_empty() {
1068 return ExpandResult::new(
1069 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1070 ExpandError::other(
1071 span,
1072 "`CoercePointee` can only be derived on `struct`s that are generic over at least one type",
1073 ),
1074 );
1075 }
1076 let (pointee_param, pointee_param_idx) = if type_params.len() == 1 {
1077 (type_params[0].clone(), 0)
1079 } else {
1080 let mut pointees = type_params.iter().cloned().enumerate().filter(|(_, param)| {
1081 param.attrs().any(|attr| {
1082 let is_pointee = attr.as_simple_atom().is_some_and(|name| name == "pointee");
1083 if is_pointee {
1084 ted::remove(attr.syntax());
1087 }
1088 is_pointee
1089 })
1090 });
1091 match (pointees.next(), pointees.next()) {
1092 (Some((pointee_idx, pointee)), None) => (pointee, pointee_idx),
1093 (None, _) => {
1094 return ExpandResult::new(
1095 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1096 ExpandError::other(
1097 span,
1098 "exactly one generic type parameter must be marked \
1099 as `#[pointee]` to derive `CoercePointee` traits",
1100 ),
1101 );
1102 }
1103 (Some(_), Some(_)) => {
1104 return ExpandResult::new(
1105 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1106 ExpandError::other(
1107 span,
1108 "only one type parameter can be marked as `#[pointee]` \
1109 when deriving `CoercePointee` traits",
1110 ),
1111 );
1112 }
1113 }
1114 };
1115 let (Some(struct_name), Some(pointee_param_name)) = (strukt.name(), pointee_param.name())
1116 else {
1117 return ExpandResult::new(
1118 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1119 ExpandError::other(span, "invalid item"),
1120 );
1121 };
1122
1123 {
1124 let mut pointee_has_maybe_sized_bound = false;
1125 if let Some(bounds) = pointee_param.type_bound_list() {
1126 pointee_has_maybe_sized_bound |= bounds.bounds().any(is_maybe_sized_bound);
1127 }
1128 if let Some(where_clause) = strukt.where_clause() {
1129 pointee_has_maybe_sized_bound |= where_clause.predicates().any(|pred| {
1130 let Some(ast::Type::PathType(ty)) = pred.ty() else { return false };
1131 let is_not_pointee = ty.path().is_none_or(|path| {
1132 let is_pointee = path
1133 .as_single_name_ref()
1134 .is_some_and(|name| name.text() == pointee_param_name.text());
1135 !is_pointee
1136 });
1137 if is_not_pointee {
1138 return false;
1139 }
1140 pred.type_bound_list()
1141 .is_some_and(|bounds| bounds.bounds().any(is_maybe_sized_bound))
1142 })
1143 }
1144 if !pointee_has_maybe_sized_bound {
1145 return ExpandResult::new(
1146 tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
1147 ExpandError::other(
1148 span,
1149 format!(
1150 "`derive(CoercePointee)` requires `{pointee_param_name}` to be marked `?Sized`"
1151 ),
1152 ),
1153 );
1154 }
1155 }
1156
1157 const ADDED_PARAM: &str = "__S";
1158
1159 let where_clause = strukt.get_or_create_where_clause();
1160
1161 {
1162 let mut new_predicates = Vec::new();
1163
1164 for param in &type_params {
1186 let Some(param_name) = param.name() else { continue };
1187 if let Some(bounds) = param.type_bound_list() {
1188 let is_pointee = param_name.text() == pointee_param_name.text();
1191 let new_bounds = bounds
1192 .bounds()
1193 .map(|bound| bound.clone_subtree().clone_for_update())
1194 .filter(|bound| {
1195 bound.ty().is_some_and(|ty| {
1196 substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
1197 || is_pointee
1198 })
1199 });
1200 let new_bounds_target = if is_pointee {
1201 make::name_ref(ADDED_PARAM)
1202 } else {
1203 make::name_ref(¶m_name.text())
1204 };
1205 new_predicates.push(
1206 make::where_pred(
1207 Either::Right(make::ty_path(make::path_from_segments(
1208 [make::path_segment(new_bounds_target)],
1209 false,
1210 ))),
1211 new_bounds,
1212 )
1213 .clone_for_update(),
1214 );
1215 }
1216 }
1217
1218 for predicate in where_clause.predicates() {
1245 let predicate = predicate.clone_subtree().clone_for_update();
1246 let Some(pred_target) = predicate.ty() else { continue };
1247
1248 if substitute_type_in_bound(
1251 pred_target.clone(),
1252 &pointee_param_name.text(),
1253 ADDED_PARAM,
1254 ) {
1255 if let Some(bounds) = predicate.type_bound_list() {
1256 for bound in bounds.bounds() {
1257 if let Some(ty) = bound.ty() {
1258 substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM);
1259 }
1260 }
1261 }
1262
1263 new_predicates.push(predicate);
1264 } else if let Some(bounds) = predicate.type_bound_list() {
1265 let new_bounds = bounds
1266 .bounds()
1267 .map(|bound| bound.clone_subtree().clone_for_update())
1268 .filter(|bound| {
1269 bound.ty().is_some_and(|ty| {
1270 substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
1271 })
1272 });
1273 new_predicates.push(
1274 make::where_pred(Either::Right(pred_target), new_bounds).clone_for_update(),
1275 );
1276 }
1277 }
1278
1279 for new_predicate in new_predicates {
1280 where_clause.add_predicate(new_predicate);
1281 }
1282 }
1283
1284 {
1285 where_clause.add_predicate(
1289 make::where_pred(
1290 Either::Right(make::ty_path(make::path_from_segments(
1291 [make::path_segment(make::name_ref(&pointee_param_name.text()))],
1292 false,
1293 ))),
1294 [make::type_bound(make::ty_path(make::path_from_segments(
1295 [
1296 make::path_segment(make::name_ref("core")),
1297 make::path_segment(make::name_ref("marker")),
1298 make::generic_ty_path_segment(
1299 make::name_ref("Unsize"),
1300 [make::type_arg(make::ty_path(make::path_from_segments(
1301 [make::path_segment(make::name_ref(ADDED_PARAM))],
1302 false,
1303 )))
1304 .into()],
1305 ),
1306 ],
1307 true,
1308 )))],
1309 )
1310 .clone_for_update(),
1311 );
1312 }
1313
1314 let self_for_traits = {
1315 let mut type_param_idx = 0;
1317 let self_params_for_traits = strukt
1318 .generic_param_list()
1319 .into_iter()
1320 .flat_map(|params| params.generic_params())
1321 .filter_map(|param| {
1322 Some(match param {
1323 ast::GenericParam::ConstParam(param) => {
1324 ast::GenericArg::ConstArg(make::expr_const_value(¶m.name()?.text()))
1325 }
1326 ast::GenericParam::LifetimeParam(param) => {
1327 make::lifetime_arg(param.lifetime()?).into()
1328 }
1329 ast::GenericParam::TypeParam(param) => {
1330 let name = if pointee_param_idx == type_param_idx {
1331 make::name_ref(ADDED_PARAM)
1332 } else {
1333 make::name_ref(¶m.name()?.text())
1334 };
1335 type_param_idx += 1;
1336 make::type_arg(make::ty_path(make::path_from_segments(
1337 [make::path_segment(name)],
1338 false,
1339 )))
1340 .into()
1341 }
1342 })
1343 });
1344
1345 make::path_from_segments(
1346 [make::generic_ty_path_segment(
1347 make::name_ref(&struct_name.text()),
1348 self_params_for_traits,
1349 )],
1350 false,
1351 )
1352 .clone_for_update()
1353 };
1354
1355 let mut span_map = span::SpanMap::empty();
1356 span_map.push(adt.syntax().text_range().end(), span);
1358
1359 let self_for_traits = syntax_bridge::syntax_node_to_token_tree(
1360 self_for_traits.syntax(),
1361 &span_map,
1362 span,
1363 DocCommentDesugarMode::ProcMacro,
1364 );
1365 let info = match parse_adt_from_syntax(&adt, &span_map, span) {
1366 Ok(it) => it,
1367 Err(err) => {
1368 return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err);
1369 }
1370 };
1371
1372 let self_for_traits2 = self_for_traits.clone();
1373 let krate = dollar_crate(span);
1374 let krate2 = krate.clone();
1375 let dispatch_from_dyn = expand_simple_derive_with_parsed(
1376 span,
1377 info.clone(),
1378 quote! {span => #krate2::ops::DispatchFromDyn<#self_for_traits2> },
1379 |_adt| quote! {span => },
1380 false,
1381 quote! {span => __S },
1382 );
1383 let coerce_unsized = expand_simple_derive_with_parsed(
1384 span,
1385 info,
1386 quote! {span => #krate::ops::CoerceUnsized<#self_for_traits> },
1387 |_adt| quote! {span => },
1388 false,
1389 quote! {span => __S },
1390 );
1391 return ExpandResult::ok(quote! {span => #dispatch_from_dyn #coerce_unsized });
1392
1393 fn is_maybe_sized_bound(bound: ast::TypeBound) -> bool {
1394 if bound.question_mark_token().is_none() {
1395 return false;
1396 }
1397 let Some(ast::Type::PathType(ty)) = bound.ty() else {
1398 return false;
1399 };
1400 let Some(path) = ty.path() else {
1401 return false;
1402 };
1403 return segments_eq(&path, &["Sized"])
1404 || segments_eq(&path, &["core", "marker", "Sized"])
1405 || segments_eq(&path, &["std", "marker", "Sized"]);
1406
1407 fn segments_eq(path: &ast::Path, expected: &[&str]) -> bool {
1408 path.segments().zip_longest(expected.iter().copied()).all(|value| {
1409 value.both().is_some_and(|(segment, expected)| {
1410 segment.name_ref().is_some_and(|name| name.text() == expected)
1411 })
1412 })
1413 }
1414 }
1415
1416 fn substitute_type_in_bound(ty: ast::Type, param_name: &str, replacement: &str) -> bool {
1418 return match ty {
1419 ast::Type::ArrayType(ty) => {
1420 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1421 }
1422 ast::Type::DynTraitType(ty) => go_bounds(ty.type_bound_list(), param_name, replacement),
1423 ast::Type::FnPtrType(ty) => any_long(
1424 ty.param_list()
1425 .into_iter()
1426 .flat_map(|params| params.params().filter_map(|param| param.ty()))
1427 .chain(ty.ret_type().and_then(|it| it.ty())),
1428 |ty| substitute_type_in_bound(ty, param_name, replacement),
1429 ),
1430 ast::Type::ForType(ty) => {
1431 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1432 }
1433 ast::Type::ImplTraitType(ty) => {
1434 go_bounds(ty.type_bound_list(), param_name, replacement)
1435 }
1436 ast::Type::ParenType(ty) => {
1437 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1438 }
1439 ast::Type::PathType(ty) => ty.path().is_some_and(|path| {
1440 if path.as_single_name_ref().is_some_and(|name| name.text() == param_name) {
1441 ted::replace(
1442 path.syntax(),
1443 make::path_from_segments(
1444 [make::path_segment(make::name_ref(replacement))],
1445 false,
1446 )
1447 .clone_for_update()
1448 .syntax(),
1449 );
1450 return true;
1451 }
1452
1453 any_long(
1454 path.segments()
1455 .filter_map(|segment| segment.generic_arg_list())
1456 .flat_map(|it| it.generic_args())
1457 .filter_map(|generic_arg| match generic_arg {
1458 ast::GenericArg::TypeArg(ty) => ty.ty(),
1459 _ => None,
1460 }),
1461 |ty| substitute_type_in_bound(ty, param_name, replacement),
1462 )
1463 }),
1464 ast::Type::PtrType(ty) => {
1465 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1466 }
1467 ast::Type::RefType(ty) => {
1468 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1469 }
1470 ast::Type::SliceType(ty) => {
1471 ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1472 }
1473 ast::Type::TupleType(ty) => {
1474 any_long(ty.fields(), |ty| substitute_type_in_bound(ty, param_name, replacement))
1475 }
1476 ast::Type::InferType(_) | ast::Type::MacroType(_) | ast::Type::NeverType(_) => false,
1477 };
1478
1479 fn go_bounds(
1480 bounds: Option<ast::TypeBoundList>,
1481 param_name: &str,
1482 replacement: &str,
1483 ) -> bool {
1484 bounds.is_some_and(|bounds| {
1485 any_long(bounds.bounds(), |bound| {
1486 bound
1487 .ty()
1488 .is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
1489 })
1490 })
1491 }
1492
1493 fn any_long<I: Iterator, F: FnMut(I::Item) -> bool>(iter: I, mut f: F) -> bool {
1495 let mut result = false;
1496 iter.for_each(|item| result |= f(item));
1497 result
1498 }
1499 }
1500}