ide_assists/utils/
gen_trait_fn_body.rs

1//! This module contains functions to generate default trait impl function bodies where possible.
2
3use hir::TraitRef;
4use syntax::ast::{self, AstNode, BinaryOp, CmpOp, HasName, LogicOp, edit::AstNodeEdit, make};
5
6/// Generate custom trait bodies without default implementation where possible.
7///
8/// If `func` is defined within an existing impl block, pass [`TraitRef`]. Otherwise pass `None`.
9///
10/// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
11/// `None` means that generating a custom trait body failed, and the body will remain
12/// as `todo!` instead.
13pub(crate) fn gen_trait_fn_body(
14    func: &ast::Fn,
15    trait_path: &ast::Path,
16    adt: &ast::Adt,
17    trait_ref: Option<TraitRef<'_>>,
18) -> Option<ast::BlockExpr> {
19    let _ = func.body()?;
20    match trait_path.segment()?.name_ref()?.text().as_str() {
21        "Clone" => {
22            stdx::always!(func.name().is_some_and(|name| name.text() == "clone"));
23            gen_clone_impl(adt)
24        }
25        "Debug" => gen_debug_impl(adt),
26        "Default" => gen_default_impl(adt),
27        "Hash" => {
28            stdx::always!(func.name().is_some_and(|name| name.text() == "hash"));
29            gen_hash_impl(adt)
30        }
31        "PartialEq" => {
32            stdx::always!(func.name().is_some_and(|name| name.text() == "eq"));
33            gen_partial_eq(adt, trait_ref)
34        }
35        "PartialOrd" => {
36            stdx::always!(func.name().is_some_and(|name| name.text() == "partial_cmp"));
37            gen_partial_ord(adt, trait_ref)
38        }
39        _ => None,
40    }
41}
42
43/// Generate a `Clone` impl based on the fields and members of the target type.
44fn gen_clone_impl(adt: &ast::Adt) -> Option<ast::BlockExpr> {
45    fn gen_clone_call(target: ast::Expr) -> ast::Expr {
46        let method = make::name_ref("clone");
47        make::expr_method_call(target, method, make::arg_list(None)).into()
48    }
49    let expr = match adt {
50        // `Clone` cannot be derived for unions, so no default impl can be provided.
51        ast::Adt::Union(_) => return None,
52        ast::Adt::Enum(enum_) => {
53            let list = enum_.variant_list()?;
54            let mut arms = vec![];
55            for variant in list.variants() {
56                let name = variant.name()?;
57                let variant_name = make::ext::path_from_idents(["Self", &format!("{name}")])?;
58
59                match variant.field_list() {
60                    // => match self { Self::Name { x } => Self::Name { x: x.clone() } }
61                    Some(ast::FieldList::RecordFieldList(list)) => {
62                        let mut pats = vec![];
63                        let mut fields = vec![];
64                        for field in list.fields() {
65                            let field_name = field.name()?;
66                            let pat = make::ident_pat(false, false, field_name.clone());
67                            pats.push(pat.into());
68
69                            let path = make::ext::ident_path(&field_name.to_string());
70                            let method_call = gen_clone_call(make::expr_path(path));
71                            let name_ref = make::name_ref(&field_name.to_string());
72                            let field = make::record_expr_field(name_ref, Some(method_call));
73                            fields.push(field);
74                        }
75                        let pat = make::record_pat(variant_name.clone(), pats.into_iter());
76                        let fields = make::record_expr_field_list(fields);
77                        let record_expr = make::record_expr(variant_name, fields).into();
78                        arms.push(make::match_arm(pat.into(), None, record_expr));
79                    }
80
81                    // => match self { Self::Name(arg1) => Self::Name(arg1.clone()) }
82                    Some(ast::FieldList::TupleFieldList(list)) => {
83                        let mut pats = vec![];
84                        let mut fields = vec![];
85                        for (i, _) in list.fields().enumerate() {
86                            let field_name = format!("arg{i}");
87                            let pat = make::ident_pat(false, false, make::name(&field_name));
88                            pats.push(pat.into());
89
90                            let f_path = make::expr_path(make::ext::ident_path(&field_name));
91                            fields.push(gen_clone_call(f_path));
92                        }
93                        let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter());
94                        let struct_name = make::expr_path(variant_name);
95                        let tuple_expr =
96                            make::expr_call(struct_name, make::arg_list(fields)).into();
97                        arms.push(make::match_arm(pat.into(), None, tuple_expr));
98                    }
99
100                    // => match self { Self::Name => Self::Name }
101                    None => {
102                        let pattern = make::path_pat(variant_name.clone());
103                        let variant_expr = make::expr_path(variant_name);
104                        arms.push(make::match_arm(pattern, None, variant_expr));
105                    }
106                }
107            }
108
109            let match_target = make::expr_path(make::ext::ident_path("self"));
110            let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
111            make::expr_match(match_target, list).into()
112        }
113        ast::Adt::Struct(strukt) => {
114            match strukt.field_list() {
115                // => Self { name: self.name.clone() }
116                Some(ast::FieldList::RecordFieldList(field_list)) => {
117                    let mut fields = vec![];
118                    for field in field_list.fields() {
119                        let base = make::expr_path(make::ext::ident_path("self"));
120                        let target = make::expr_field(base, &field.name()?.to_string());
121                        let method_call = gen_clone_call(target);
122                        let name_ref = make::name_ref(&field.name()?.to_string());
123                        let field = make::record_expr_field(name_ref, Some(method_call));
124                        fields.push(field);
125                    }
126                    let struct_name = make::ext::ident_path("Self");
127                    let fields = make::record_expr_field_list(fields);
128                    make::record_expr(struct_name, fields).into()
129                }
130                // => Self(self.0.clone(), self.1.clone())
131                Some(ast::FieldList::TupleFieldList(field_list)) => {
132                    let mut fields = vec![];
133                    for (i, _) in field_list.fields().enumerate() {
134                        let f_path = make::expr_path(make::ext::ident_path("self"));
135                        let target = make::expr_field(f_path, &format!("{i}"));
136                        fields.push(gen_clone_call(target));
137                    }
138                    let struct_name = make::expr_path(make::ext::ident_path("Self"));
139                    make::expr_call(struct_name, make::arg_list(fields)).into()
140                }
141                // => Self { }
142                None => {
143                    let struct_name = make::ext::ident_path("Self");
144                    let fields = make::record_expr_field_list(None);
145                    make::record_expr(struct_name, fields).into()
146                }
147            }
148        }
149    };
150    let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
151    Some(body)
152}
153
154/// Generate a `Debug` impl based on the fields and members of the target type.
155fn gen_debug_impl(adt: &ast::Adt) -> Option<ast::BlockExpr> {
156    let annotated_name = adt.name()?;
157    match adt {
158        // `Debug` cannot be derived for unions, so no default impl can be provided.
159        ast::Adt::Union(_) => None,
160
161        // => match self { Self::Variant => write!(f, "Variant") }
162        ast::Adt::Enum(enum_) => {
163            let list = enum_.variant_list()?;
164            let mut arms = vec![];
165            for variant in list.variants() {
166                let name = variant.name()?;
167                let variant_name = make::ext::path_from_idents(["Self", &format!("{name}")])?;
168                let target = make::expr_path(make::ext::ident_path("f"));
169
170                match variant.field_list() {
171                    Some(ast::FieldList::RecordFieldList(list)) => {
172                        // => f.debug_struct(name)
173                        let target = make::expr_path(make::ext::ident_path("f"));
174                        let method = make::name_ref("debug_struct");
175                        let struct_name = format!("\"{name}\"");
176                        let args = make::arg_list(Some(make::expr_literal(&struct_name).into()));
177                        let mut expr = make::expr_method_call(target, method, args).into();
178
179                        let mut pats = vec![];
180                        for field in list.fields() {
181                            let field_name = field.name()?;
182
183                            // create a field pattern for use in `MyStruct { fields.. }`
184                            let pat = make::ident_pat(false, false, field_name.clone());
185                            pats.push(pat.into());
186
187                            // => <expr>.field("field_name", field)
188                            let method_name = make::name_ref("field");
189                            let name = make::expr_literal(&(format!("\"{field_name}\""))).into();
190                            let path = &format!("{field_name}");
191                            let path = make::expr_path(make::ext::ident_path(path));
192                            let args = make::arg_list(vec![name, path]);
193                            expr = make::expr_method_call(expr, method_name, args).into();
194                        }
195
196                        // => <expr>.finish()
197                        let method = make::name_ref("finish");
198                        let expr =
199                            make::expr_method_call(expr, method, make::arg_list(None)).into();
200
201                        // => MyStruct { fields.. } => f.debug_struct("MyStruct")...finish(),
202                        let pat = make::record_pat(variant_name.clone(), pats.into_iter());
203                        arms.push(make::match_arm(pat.into(), None, expr));
204                    }
205                    Some(ast::FieldList::TupleFieldList(list)) => {
206                        // => f.debug_tuple(name)
207                        let target = make::expr_path(make::ext::ident_path("f"));
208                        let method = make::name_ref("debug_tuple");
209                        let struct_name = format!("\"{name}\"");
210                        let args = make::arg_list(Some(make::expr_literal(&struct_name).into()));
211                        let mut expr = make::expr_method_call(target, method, args).into();
212
213                        let mut pats = vec![];
214                        for (i, _) in list.fields().enumerate() {
215                            let name = format!("arg{i}");
216
217                            // create a field pattern for use in `MyStruct(fields..)`
218                            let field_name = make::name(&name);
219                            let pat = make::ident_pat(false, false, field_name.clone());
220                            pats.push(pat.into());
221
222                            // => <expr>.field(field)
223                            let method_name = make::name_ref("field");
224                            let field_path = &name.to_string();
225                            let field_path = make::expr_path(make::ext::ident_path(field_path));
226                            let args = make::arg_list(vec![field_path]);
227                            expr = make::expr_method_call(expr, method_name, args).into();
228                        }
229
230                        // => <expr>.finish()
231                        let method = make::name_ref("finish");
232                        let expr =
233                            make::expr_method_call(expr, method, make::arg_list(None)).into();
234
235                        // => MyStruct (fields..) => f.debug_tuple("MyStruct")...finish(),
236                        let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter());
237                        arms.push(make::match_arm(pat.into(), None, expr));
238                    }
239                    None => {
240                        let fmt_string = make::expr_literal(&(format!("\"{name}\""))).into();
241                        let args = make::ext::token_tree_from_node(
242                            make::arg_list([target, fmt_string]).syntax(),
243                        );
244                        let macro_name = make::ext::ident_path("write");
245                        let macro_call = make::expr_macro(macro_name, args);
246
247                        let variant_name = make::path_pat(variant_name);
248                        arms.push(make::match_arm(variant_name, None, macro_call.into()));
249                    }
250                }
251            }
252
253            let match_target = make::expr_path(make::ext::ident_path("self"));
254            let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
255            let match_expr = make::expr_match(match_target, list);
256
257            let body = make::block_expr(None, Some(match_expr.into()));
258            let body = body.indent(ast::edit::IndentLevel(1));
259            Some(body)
260        }
261
262        ast::Adt::Struct(strukt) => {
263            let name = format!("\"{annotated_name}\"");
264            let args = make::arg_list(Some(make::expr_literal(&name).into()));
265            let target = make::expr_path(make::ext::ident_path("f"));
266
267            let expr = match strukt.field_list() {
268                // => f.debug_struct("Name").finish()
269                None => make::expr_method_call(target, make::name_ref("debug_struct"), args).into(),
270
271                // => f.debug_struct("Name").field("foo", &self.foo).finish()
272                Some(ast::FieldList::RecordFieldList(field_list)) => {
273                    let method = make::name_ref("debug_struct");
274                    let mut expr = make::expr_method_call(target, method, args).into();
275                    for field in field_list.fields() {
276                        let name = field.name()?;
277                        let f_name = make::expr_literal(&(format!("\"{name}\""))).into();
278                        let f_path = make::expr_path(make::ext::ident_path("self"));
279                        let f_path = make::expr_ref(f_path, false);
280                        let f_path = make::expr_field(f_path, &format!("{name}"));
281                        let args = make::arg_list([f_name, f_path]);
282                        expr = make::expr_method_call(expr, make::name_ref("field"), args).into();
283                    }
284                    expr
285                }
286
287                // => f.debug_tuple("Name").field(self.0).finish()
288                Some(ast::FieldList::TupleFieldList(field_list)) => {
289                    let method = make::name_ref("debug_tuple");
290                    let mut expr = make::expr_method_call(target, method, args).into();
291                    for (i, _) in field_list.fields().enumerate() {
292                        let f_path = make::expr_path(make::ext::ident_path("self"));
293                        let f_path = make::expr_ref(f_path, false);
294                        let f_path = make::expr_field(f_path, &format!("{i}"));
295                        let method = make::name_ref("field");
296                        expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path)))
297                            .into();
298                    }
299                    expr
300                }
301            };
302
303            let method = make::name_ref("finish");
304            let expr = make::expr_method_call(expr, method, make::arg_list(None)).into();
305            let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
306            Some(body)
307        }
308    }
309}
310
311/// Generate a `Debug` impl based on the fields and members of the target type.
312fn gen_default_impl(adt: &ast::Adt) -> Option<ast::BlockExpr> {
313    fn gen_default_call() -> Option<ast::Expr> {
314        let fn_name = make::ext::path_from_idents(["Default", "default"])?;
315        Some(make::expr_call(make::expr_path(fn_name), make::arg_list(None)).into())
316    }
317    match adt {
318        // `Debug` cannot be derived for unions, so no default impl can be provided.
319        ast::Adt::Union(_) => None,
320        // Deriving `Debug` for enums is not stable yet.
321        ast::Adt::Enum(_) => None,
322        ast::Adt::Struct(strukt) => {
323            let expr = match strukt.field_list() {
324                Some(ast::FieldList::RecordFieldList(field_list)) => {
325                    let mut fields = vec![];
326                    for field in field_list.fields() {
327                        let method_call = gen_default_call()?;
328                        let name_ref = make::name_ref(&field.name()?.to_string());
329                        let field = make::record_expr_field(name_ref, Some(method_call));
330                        fields.push(field);
331                    }
332                    let struct_name = make::ext::ident_path("Self");
333                    let fields = make::record_expr_field_list(fields);
334                    make::record_expr(struct_name, fields).into()
335                }
336                Some(ast::FieldList::TupleFieldList(field_list)) => {
337                    let struct_name = make::expr_path(make::ext::ident_path("Self"));
338                    let fields = field_list
339                        .fields()
340                        .map(|_| gen_default_call())
341                        .collect::<Option<Vec<ast::Expr>>>()?;
342                    make::expr_call(struct_name, make::arg_list(fields)).into()
343                }
344                None => {
345                    let struct_name = make::ext::ident_path("Self");
346                    let fields = make::record_expr_field_list(None);
347                    make::record_expr(struct_name, fields).into()
348                }
349            };
350            let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
351            Some(body)
352        }
353    }
354}
355
356/// Generate a `Hash` impl based on the fields and members of the target type.
357fn gen_hash_impl(adt: &ast::Adt) -> Option<ast::BlockExpr> {
358    fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
359        let method = make::name_ref("hash");
360        let arg = make::expr_path(make::ext::ident_path("state"));
361        let expr = make::expr_method_call(target, method, make::arg_list(Some(arg))).into();
362        make::expr_stmt(expr).into()
363    }
364
365    let body = match adt {
366        // `Hash` cannot be derived for unions, so no default impl can be provided.
367        ast::Adt::Union(_) => return None,
368
369        // => std::mem::discriminant(self).hash(state);
370        ast::Adt::Enum(_) => {
371            let fn_name = make_discriminant()?;
372
373            let arg = make::expr_path(make::ext::ident_path("self"));
374            let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg))).into();
375            let stmt = gen_hash_call(fn_call);
376
377            make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1))
378        }
379        ast::Adt::Struct(strukt) => match strukt.field_list() {
380            // => self.<field>.hash(state);
381            Some(ast::FieldList::RecordFieldList(field_list)) => {
382                let mut stmts = vec![];
383                for field in field_list.fields() {
384                    let base = make::expr_path(make::ext::ident_path("self"));
385                    let target = make::expr_field(base, &field.name()?.to_string());
386                    stmts.push(gen_hash_call(target));
387                }
388                make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
389            }
390
391            // => self.<field_index>.hash(state);
392            Some(ast::FieldList::TupleFieldList(field_list)) => {
393                let mut stmts = vec![];
394                for (i, _) in field_list.fields().enumerate() {
395                    let base = make::expr_path(make::ext::ident_path("self"));
396                    let target = make::expr_field(base, &format!("{i}"));
397                    stmts.push(gen_hash_call(target));
398                }
399                make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
400            }
401
402            // No fields in the body means there's nothing to hash.
403            None => return None,
404        },
405    };
406
407    Some(body)
408}
409
410/// Generate a `PartialEq` impl based on the fields and members of the target type.
411fn gen_partial_eq(adt: &ast::Adt, trait_ref: Option<TraitRef<'_>>) -> Option<ast::BlockExpr> {
412    fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
413        match expr {
414            Some(expr) => Some(make::expr_bin_op(expr, BinaryOp::LogicOp(LogicOp::And), cmp)),
415            None => Some(cmp),
416        }
417    }
418
419    fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
420        let pat = make::ext::simple_ident_pat(make::name(pat_name));
421        let name_ref = make::name_ref(field_name);
422        make::record_pat_field(name_ref, pat.into())
423    }
424
425    fn gen_record_pat(record_name: ast::Path, fields: Vec<ast::RecordPatField>) -> ast::RecordPat {
426        let list = make::record_pat_field_list(fields, None);
427        make::record_pat_with_fields(record_name, list)
428    }
429
430    fn gen_variant_path(variant: &ast::Variant) -> Option<ast::Path> {
431        make::ext::path_from_idents(["Self", &variant.name()?.to_string()])
432    }
433
434    fn gen_tuple_field(field_name: &str) -> ast::Pat {
435        ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
436    }
437
438    // Check that self type and rhs type match. We don't know how to implement the method
439    // automatically otherwise.
440    if let Some(trait_ref) = trait_ref {
441        let self_ty = trait_ref.self_ty();
442        let rhs_ty = trait_ref.get_type_argument(1)?;
443        if self_ty != rhs_ty {
444            return None;
445        }
446    }
447
448    let body = match adt {
449        // `PartialEq` cannot be derived for unions, so no default impl can be provided.
450        ast::Adt::Union(_) => return None,
451
452        ast::Adt::Enum(enum_) => {
453            // => std::mem::discriminant(self) == std::mem::discriminant(other)
454            let lhs_name = make::expr_path(make::ext::ident_path("self"));
455            let lhs = make::expr_call(make_discriminant()?, make::arg_list(Some(lhs_name.clone())))
456                .into();
457            let rhs_name = make::expr_path(make::ext::ident_path("other"));
458            let rhs = make::expr_call(make_discriminant()?, make::arg_list(Some(rhs_name.clone())))
459                .into();
460            let eq_check =
461                make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
462
463            let mut n_cases = 0;
464            let mut arms = vec![];
465            for variant in enum_.variant_list()?.variants() {
466                n_cases += 1;
467                match variant.field_list() {
468                    // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
469                    Some(ast::FieldList::RecordFieldList(list)) => {
470                        let mut expr = None;
471                        let mut l_fields = vec![];
472                        let mut r_fields = vec![];
473
474                        for field in list.fields() {
475                            let field_name = field.name()?.to_string();
476
477                            let l_name = &format!("l_{field_name}");
478                            l_fields.push(gen_record_pat_field(&field_name, l_name));
479
480                            let r_name = &format!("r_{field_name}");
481                            r_fields.push(gen_record_pat_field(&field_name, r_name));
482
483                            let lhs = make::expr_path(make::ext::ident_path(l_name));
484                            let rhs = make::expr_path(make::ext::ident_path(r_name));
485                            let cmp = make::expr_bin_op(
486                                lhs,
487                                BinaryOp::CmpOp(CmpOp::Eq { negated: false }),
488                                rhs,
489                            );
490                            expr = gen_eq_chain(expr, cmp);
491                        }
492
493                        let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
494                        let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
495                        let tuple = make::tuple_pat(vec![left.into(), right.into()]);
496
497                        if let Some(expr) = expr {
498                            arms.push(make::match_arm(tuple.into(), None, expr));
499                        }
500                    }
501
502                    Some(ast::FieldList::TupleFieldList(list)) => {
503                        let mut expr = None;
504                        let mut l_fields = vec![];
505                        let mut r_fields = vec![];
506
507                        for (i, _) in list.fields().enumerate() {
508                            let field_name = format!("{i}");
509
510                            let l_name = format!("l{field_name}");
511                            l_fields.push(gen_tuple_field(&l_name));
512
513                            let r_name = format!("r{field_name}");
514                            r_fields.push(gen_tuple_field(&r_name));
515
516                            let lhs = make::expr_path(make::ext::ident_path(&l_name));
517                            let rhs = make::expr_path(make::ext::ident_path(&r_name));
518                            let cmp = make::expr_bin_op(
519                                lhs,
520                                BinaryOp::CmpOp(CmpOp::Eq { negated: false }),
521                                rhs,
522                            );
523                            expr = gen_eq_chain(expr, cmp);
524                        }
525
526                        let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
527                        let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
528                        let tuple = make::tuple_pat(vec![left.into(), right.into()]);
529
530                        if let Some(expr) = expr {
531                            arms.push(make::match_arm(tuple.into(), None, expr));
532                        }
533                    }
534                    None => continue,
535                }
536            }
537
538            let expr = match arms.len() {
539                0 => eq_check,
540                arms_len => {
541                    // Generate the fallback arm when this enum has >1 variants.
542                    // The fallback arm will be `_ => false,` if we've already gone through every case where the variants of self and other match,
543                    // and `_ => std::mem::discriminant(self) == std::mem::discriminant(other),` otherwise.
544                    if n_cases > 1 {
545                        let lhs = make::wildcard_pat().into();
546                        let rhs = if arms_len == n_cases {
547                            make::expr_literal("false").into()
548                        } else {
549                            eq_check
550                        };
551                        arms.push(make::match_arm(lhs, None, rhs));
552                    }
553
554                    let match_target = make::expr_tuple([lhs_name, rhs_name]).into();
555                    let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
556                    make::expr_match(match_target, list).into()
557                }
558            };
559
560            make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
561        }
562        ast::Adt::Struct(strukt) => match strukt.field_list() {
563            Some(ast::FieldList::RecordFieldList(field_list)) => {
564                let mut expr = None;
565                for field in field_list.fields() {
566                    let lhs = make::expr_path(make::ext::ident_path("self"));
567                    let lhs = make::expr_field(lhs, &field.name()?.to_string());
568                    let rhs = make::expr_path(make::ext::ident_path("other"));
569                    let rhs = make::expr_field(rhs, &field.name()?.to_string());
570                    let cmp =
571                        make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
572                    expr = gen_eq_chain(expr, cmp);
573                }
574                make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
575            }
576
577            Some(ast::FieldList::TupleFieldList(field_list)) => {
578                let mut expr = None;
579                for (i, _) in field_list.fields().enumerate() {
580                    let idx = format!("{i}");
581                    let lhs = make::expr_path(make::ext::ident_path("self"));
582                    let lhs = make::expr_field(lhs, &idx);
583                    let rhs = make::expr_path(make::ext::ident_path("other"));
584                    let rhs = make::expr_field(rhs, &idx);
585                    let cmp =
586                        make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
587                    expr = gen_eq_chain(expr, cmp);
588                }
589                make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
590            }
591
592            // No fields in the body means there's nothing to compare.
593            None => {
594                let expr = make::expr_literal("true").into();
595                make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
596            }
597        },
598    };
599
600    Some(body)
601}
602
603fn gen_partial_ord(adt: &ast::Adt, trait_ref: Option<TraitRef<'_>>) -> Option<ast::BlockExpr> {
604    fn gen_partial_eq_match(match_target: ast::Expr) -> Option<ast::Stmt> {
605        let mut arms = vec![];
606
607        let variant_name =
608            make::path_pat(make::ext::path_from_idents(["core", "cmp", "Ordering", "Equal"])?);
609        let lhs = make::tuple_struct_pat(make::ext::path_from_idents(["Some"])?, [variant_name]);
610        arms.push(make::match_arm(lhs.into(), None, make::expr_empty_block().into()));
611
612        arms.push(make::match_arm(
613            make::ident_pat(false, false, make::name("ord")).into(),
614            None,
615            make::expr_return(Some(make::expr_path(make::ext::ident_path("ord")))),
616        ));
617        let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
618        Some(make::expr_stmt(make::expr_match(match_target, list).into()).into())
619    }
620
621    fn gen_partial_cmp_call(lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr {
622        let rhs = make::expr_ref(rhs, false);
623        let method = make::name_ref("partial_cmp");
624        make::expr_method_call(lhs, method, make::arg_list(Some(rhs))).into()
625    }
626
627    // Check that self type and rhs type match. We don't know how to implement the method
628    // automatically otherwise.
629    if let Some(trait_ref) = trait_ref {
630        let self_ty = trait_ref.self_ty();
631        let rhs_ty = trait_ref.get_type_argument(1)?;
632        if self_ty != rhs_ty {
633            return None;
634        }
635    }
636
637    let body = match adt {
638        // `PartialOrd` cannot be derived for unions, so no default impl can be provided.
639        ast::Adt::Union(_) => return None,
640        // `core::mem::Discriminant` does not implement `PartialOrd` in stable Rust today.
641        ast::Adt::Enum(_) => return None,
642        ast::Adt::Struct(strukt) => match strukt.field_list() {
643            Some(ast::FieldList::RecordFieldList(field_list)) => {
644                let mut exprs = vec![];
645                for field in field_list.fields() {
646                    let lhs = make::expr_path(make::ext::ident_path("self"));
647                    let lhs = make::expr_field(lhs, &field.name()?.to_string());
648                    let rhs = make::expr_path(make::ext::ident_path("other"));
649                    let rhs = make::expr_field(rhs, &field.name()?.to_string());
650                    let ord = gen_partial_cmp_call(lhs, rhs);
651                    exprs.push(ord);
652                }
653
654                let tail = exprs.pop();
655                let stmts = exprs
656                    .into_iter()
657                    .map(gen_partial_eq_match)
658                    .collect::<Option<Vec<ast::Stmt>>>()?;
659                make::block_expr(stmts, tail).indent(ast::edit::IndentLevel(1))
660            }
661
662            Some(ast::FieldList::TupleFieldList(field_list)) => {
663                let mut exprs = vec![];
664                for (i, _) in field_list.fields().enumerate() {
665                    let idx = format!("{i}");
666                    let lhs = make::expr_path(make::ext::ident_path("self"));
667                    let lhs = make::expr_field(lhs, &idx);
668                    let rhs = make::expr_path(make::ext::ident_path("other"));
669                    let rhs = make::expr_field(rhs, &idx);
670                    let ord = gen_partial_cmp_call(lhs, rhs);
671                    exprs.push(ord);
672                }
673                let tail = exprs.pop();
674                let stmts = exprs
675                    .into_iter()
676                    .map(gen_partial_eq_match)
677                    .collect::<Option<Vec<ast::Stmt>>>()?;
678                make::block_expr(stmts, tail).indent(ast::edit::IndentLevel(1))
679            }
680
681            // No fields in the body means there's nothing to compare.
682            None => {
683                let expr = make::expr_literal("true").into();
684                make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
685            }
686        },
687    };
688
689    Some(body)
690}
691
692fn make_discriminant() -> Option<ast::Expr> {
693    Some(make::expr_path(make::ext::path_from_idents(["core", "mem", "discriminant"])?))
694}