Skip to main content

ide_assists/utils/
gen_trait_fn_body.rs

//! This module contains functions to generate default trait impl function bodies where possible.
2
3use hir::TraitRef;
4use syntax::ast::{
5    self, AstNode, BinaryOp, CmpOp, HasName, LogicOp, edit::AstNodeEdit,
6    syntax_factory::SyntaxFactory,
7};
8
9/// Generate custom trait bodies without default implementation where possible.
10///
11/// If `func` is defined within an existing impl block, pass [`TraitRef`]. Otherwise pass `None`.
12///
13/// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
14/// `None` means that generating a custom trait body failed, and the body will remain
15/// as `todo!` instead.
16pub(crate) fn gen_trait_fn_body(
17    make: &SyntaxFactory,
18    func: &ast::Fn,
19    trait_path: &ast::Path,
20    adt: &ast::Adt,
21    trait_ref: Option<TraitRef<'_>>,
22) -> Option<ast::BlockExpr> {
23    let _ = func.body()?;
24    match trait_path.segment()?.name_ref()?.text().as_str() {
25        "Clone" => {
26            stdx::always!(func.name().is_some_and(|name| name.text() == "clone"));
27            gen_clone_impl(make, adt)
28        }
29        "Debug" => gen_debug_impl(make, adt),
30        "Default" => gen_default_impl(make, adt),
31        "Hash" => {
32            stdx::always!(func.name().is_some_and(|name| name.text() == "hash"));
33            gen_hash_impl(make, adt)
34        }
35        "PartialEq" => {
36            stdx::always!(func.name().is_some_and(|name| name.text() == "eq"));
37            gen_partial_eq(make, adt, trait_ref)
38        }
39        "PartialOrd" => {
40            stdx::always!(func.name().is_some_and(|name| name.text() == "partial_cmp"));
41            gen_partial_ord(make, adt, trait_ref)
42        }
43        _ => None,
44    }
45}
46
47/// Generate a `Clone` impl based on the fields and members of the target type.
48fn gen_clone_impl(make: &SyntaxFactory, adt: &ast::Adt) -> Option<ast::BlockExpr> {
49    let gen_clone_call = |target: ast::Expr| -> ast::Expr {
50        let method = make.name_ref("clone");
51        make.expr_method_call(target, method, make.arg_list([])).into()
52    };
53    let expr = match adt {
54        // `Clone` cannot be derived for unions, so no default impl can be provided.
55        ast::Adt::Union(_) => return None,
56        ast::Adt::Enum(enum_) => {
57            let list = enum_.variant_list()?;
58            let mut arms = vec![];
59            for variant in list.variants() {
60                let name = variant.name()?;
61                let variant_name = make.path_from_idents(["Self", &format!("{name}")])?;
62
63                match variant.field_list() {
64                    // => match self { Self::Name { x } => Self::Name { x: x.clone() } }
65                    Some(ast::FieldList::RecordFieldList(list)) => {
66                        let mut pats = vec![];
67                        let mut fields = vec![];
68                        for field in list.fields() {
69                            let field_name = field.name()?;
70                            let pat = make.ident_pat(false, false, field_name.clone());
71                            pats.push(make.record_pat_field_shorthand(pat.into()));
72
73                            let path = make.ident_path(&field_name.to_string());
74                            let method_call = gen_clone_call(make.expr_path(path));
75                            let name_ref = make.name_ref(&field_name.to_string());
76                            let field = make.record_expr_field(name_ref, Some(method_call));
77                            fields.push(field);
78                        }
79                        let pat_field_list = make.record_pat_field_list(pats, None);
80                        let pat = make.record_pat_with_fields(variant_name.clone(), pat_field_list);
81                        let fields = make.record_expr_field_list(fields);
82                        let record_expr = make.record_expr(variant_name, fields).into();
83                        arms.push(make.match_arm(pat.into(), None, record_expr));
84                    }
85
86                    // => match self { Self::Name(arg1) => Self::Name(arg1.clone()) }
87                    Some(ast::FieldList::TupleFieldList(list)) => {
88                        let mut pats = vec![];
89                        let mut fields = vec![];
90                        for (i, _) in list.fields().enumerate() {
91                            let field_name = format!("arg{i}");
92                            let pat = make.ident_pat(false, false, make.name(&field_name));
93                            pats.push(pat.into());
94
95                            let f_path = make.expr_path(make.ident_path(&field_name));
96                            fields.push(gen_clone_call(f_path));
97                        }
98                        let pat = make.tuple_struct_pat(variant_name.clone(), pats);
99                        let struct_name = make.expr_path(variant_name);
100                        let tuple_expr = make.expr_call(struct_name, make.arg_list(fields)).into();
101                        arms.push(make.match_arm(pat.into(), None, tuple_expr));
102                    }
103
104                    // => match self { Self::Name => Self::Name }
105                    None => {
106                        let pattern = make.path_pat(variant_name.clone());
107                        let variant_expr = make.expr_path(variant_name);
108                        arms.push(make.match_arm(pattern, None, variant_expr));
109                    }
110                }
111            }
112
113            let match_target = make.expr_path(make.ident_path("self"));
114            let list = make.match_arm_list(arms).indent(ast::edit::IndentLevel(1));
115            make.expr_match(match_target, list).into()
116        }
117        ast::Adt::Struct(strukt) => {
118            match strukt.field_list() {
119                // => Self { name: self.name.clone() }
120                Some(ast::FieldList::RecordFieldList(field_list)) => {
121                    let mut fields = vec![];
122                    for field in field_list.fields() {
123                        let base = make.expr_path(make.ident_path("self"));
124                        let target = make.expr_field(base, &field.name()?.to_string()).into();
125                        let method_call = gen_clone_call(target);
126                        let name_ref = make.name_ref(&field.name()?.to_string());
127                        let field = make.record_expr_field(name_ref, Some(method_call));
128                        fields.push(field);
129                    }
130                    let struct_name = make.ident_path("Self");
131                    let fields = make.record_expr_field_list(fields);
132                    make.record_expr(struct_name, fields).into()
133                }
134                // => Self(self.0.clone(), self.1.clone())
135                Some(ast::FieldList::TupleFieldList(field_list)) => {
136                    let mut fields = vec![];
137                    for (i, _) in field_list.fields().enumerate() {
138                        let f_path = make.expr_path(make.ident_path("self"));
139                        let target = make.expr_field(f_path, &format!("{i}")).into();
140                        fields.push(gen_clone_call(target));
141                    }
142                    let struct_name = make.expr_path(make.ident_path("Self"));
143                    make.expr_call(struct_name, make.arg_list(fields)).into()
144                }
145                // => Self { }
146                None => {
147                    let struct_name = make.ident_path("Self");
148                    let fields = make.record_expr_field_list([]);
149                    make.record_expr(struct_name, fields).into()
150                }
151            }
152        }
153    };
154    let body = make.block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
155    Some(body)
156}
157
158/// Generate a `Debug` impl based on the fields and members of the target type.
159fn gen_debug_impl(make: &SyntaxFactory, adt: &ast::Adt) -> Option<ast::BlockExpr> {
160    let annotated_name = adt.name()?;
161    match adt {
162        // `Debug` cannot be derived for unions, so no default impl can be provided.
163        ast::Adt::Union(_) => None,
164
165        // => match self { Self::Variant => write!(f, "Variant") }
166        ast::Adt::Enum(enum_) => {
167            let list = enum_.variant_list()?;
168            let mut arms = vec![];
169            for variant in list.variants() {
170                let name = variant.name()?;
171                let variant_name = make.path_from_idents(["Self", &format!("{name}")])?;
172                let target = make.expr_path(make.ident_path("f"));
173
174                match variant.field_list() {
175                    Some(ast::FieldList::RecordFieldList(list)) => {
176                        // => f.debug_struct(name)
177                        let target = make.expr_path(make.ident_path("f"));
178                        let method = make.name_ref("debug_struct");
179                        let struct_name = format!("\"{name}\"");
180                        let args = make.arg_list([make.expr_literal(&struct_name).into()]);
181                        let mut expr = make.expr_method_call(target, method, args).into();
182
183                        let mut pats = vec![];
184                        for field in list.fields() {
185                            let field_name = field.name()?;
186
187                            // create a field pattern for use in `MyStruct { fields.. }`
188                            let pat = make.ident_pat(false, false, field_name.clone());
189                            pats.push(make.record_pat_field_shorthand(pat.into()));
190
191                            // => <expr>.field("field_name", field)
192                            let method_name = make.name_ref("field");
193                            let name = make.expr_literal(&(format!("\"{field_name}\""))).into();
194                            let path = &format!("{field_name}");
195                            let path = make.expr_path(make.ident_path(path));
196                            let args = make.arg_list([name, path]);
197                            expr = make.expr_method_call(expr, method_name, args).into();
198                        }
199
200                        // => <expr>.finish()
201                        let method = make.name_ref("finish");
202                        let expr = make.expr_method_call(expr, method, make.arg_list([])).into();
203
204                        // => MyStruct { fields.. } => f.debug_struct("MyStruct")...finish(),
205                        let pat_field_list = make.record_pat_field_list(pats, None);
206                        let pat = make.record_pat_with_fields(variant_name.clone(), pat_field_list);
207                        arms.push(make.match_arm(pat.into(), None, expr));
208                    }
209                    Some(ast::FieldList::TupleFieldList(list)) => {
210                        // => f.debug_tuple(name)
211                        let target = make.expr_path(make.ident_path("f"));
212                        let method = make.name_ref("debug_tuple");
213                        let struct_name = format!("\"{name}\"");
214                        let args = make.arg_list([make.expr_literal(&struct_name).into()]);
215                        let mut expr = make.expr_method_call(target, method, args).into();
216
217                        let mut pats = vec![];
218                        for (i, _) in list.fields().enumerate() {
219                            let name = format!("arg{i}");
220
221                            // create a field pattern for use in `MyStruct(fields..)`
222                            let field_name = make.name(&name);
223                            let pat = make.ident_pat(false, false, field_name.clone());
224                            pats.push(pat.into());
225
226                            // => <expr>.field(field)
227                            let method_name = make.name_ref("field");
228                            let field_path = &name.to_string();
229                            let field_path = make.expr_path(make.ident_path(field_path));
230                            let args = make.arg_list([field_path]);
231                            expr = make.expr_method_call(expr, method_name, args).into();
232                        }
233
234                        // => <expr>.finish()
235                        let method = make.name_ref("finish");
236                        let expr = make.expr_method_call(expr, method, make.arg_list([])).into();
237
238                        // => MyStruct (fields..) => f.debug_tuple("MyStruct")...finish(),
239                        let pat = make.tuple_struct_pat(variant_name.clone(), pats);
240                        arms.push(make.match_arm(pat.into(), None, expr));
241                    }
242                    None => {
243                        let fmt_string = make.expr_literal(&(format!("\"{name}\""))).into();
244                        let args =
245                            make.token_tree_from_node(make.arg_list([target, fmt_string]).syntax());
246                        let macro_name = make.ident_path("write");
247                        let macro_call = make.expr_macro(macro_name, args);
248
249                        let variant_name = make.path_pat(variant_name);
250                        arms.push(make.match_arm(variant_name, None, macro_call.into()));
251                    }
252                }
253            }
254
255            let match_target = make.expr_path(make.ident_path("self"));
256            let list = make.match_arm_list(arms).indent(ast::edit::IndentLevel(1));
257            let match_expr = make.expr_match(match_target, list);
258
259            let body = make.block_expr(None::<ast::Stmt>, Some(match_expr.into()));
260            let body = body.indent(ast::edit::IndentLevel(1));
261            Some(body)
262        }
263
264        ast::Adt::Struct(strukt) => {
265            let name = format!("\"{annotated_name}\"");
266            let args = make.arg_list([make.expr_literal(&name).into()]);
267            let target = make.expr_path(make.ident_path("f"));
268
269            let expr = match strukt.field_list() {
270                // => f.debug_struct("Name").finish()
271                None => make.expr_method_call(target, make.name_ref("debug_struct"), args).into(),
272
273                // => f.debug_struct("Name").field("foo", &self.foo).finish()
274                Some(ast::FieldList::RecordFieldList(field_list)) => {
275                    let method = make.name_ref("debug_struct");
276                    let mut expr = make.expr_method_call(target, method, args).into();
277                    for field in field_list.fields() {
278                        let name = field.name()?;
279                        let f_name = make.expr_literal(&(format!("\"{name}\""))).into();
280                        let f_path = make.expr_path(make.ident_path("self"));
281                        let f_path = make.expr_field(f_path, &format!("{name}")).into();
282                        let f_path = make.expr_ref(f_path, false);
283                        let args = make.arg_list([f_name, f_path]);
284                        expr = make.expr_method_call(expr, make.name_ref("field"), args).into();
285                    }
286                    expr
287                }
288
289                // => f.debug_tuple("Name").field(&self.0).finish()
290                Some(ast::FieldList::TupleFieldList(field_list)) => {
291                    let method = make.name_ref("debug_tuple");
292                    let mut expr = make.expr_method_call(target, method, args).into();
293                    for (i, _) in field_list.fields().enumerate() {
294                        let f_path = make.expr_path(make.ident_path("self"));
295                        let f_path = make.expr_field(f_path, &format!("{i}")).into();
296                        let f_path = make.expr_ref(f_path, false);
297                        let method = make.name_ref("field");
298                        expr = make.expr_method_call(expr, method, make.arg_list([f_path])).into();
299                    }
300                    expr
301                }
302            };
303
304            let method = make.name_ref("finish");
305            let expr = make.expr_method_call(expr, method, make.arg_list([])).into();
306            let body =
307                make.block_expr(None::<ast::Stmt>, Some(expr)).indent(ast::edit::IndentLevel(1));
308            Some(body)
309        }
310    }
311}
312
313/// Generate a `Default` impl based on the fields and members of the target type.
314fn gen_default_impl(make: &SyntaxFactory, adt: &ast::Adt) -> Option<ast::BlockExpr> {
315    let gen_default_call = || -> Option<ast::Expr> {
316        let fn_name = make.path_from_idents(["Default", "default"])?;
317        Some(make.expr_call(make.expr_path(fn_name), make.arg_list([])).into())
318    };
319    match adt {
320        // `Debug` cannot be derived for unions, so no default impl can be provided.
321        ast::Adt::Union(_) => None,
322        // Deriving `Debug` for enums is not stable yet.
323        ast::Adt::Enum(_) => None,
324        ast::Adt::Struct(strukt) => {
325            let expr = match strukt.field_list() {
326                Some(ast::FieldList::RecordFieldList(field_list)) => {
327                    let mut fields = vec![];
328                    for field in field_list.fields() {
329                        let method_call = gen_default_call()?;
330                        let name_ref = make.name_ref(&field.name()?.to_string());
331                        let field = make.record_expr_field(name_ref, Some(method_call));
332                        fields.push(field);
333                    }
334                    let struct_name = make.ident_path("Self");
335                    let fields = make.record_expr_field_list(fields);
336                    make.record_expr(struct_name, fields).into()
337                }
338                Some(ast::FieldList::TupleFieldList(field_list)) => {
339                    let struct_name = make.expr_path(make.ident_path("Self"));
340                    let fields = field_list
341                        .fields()
342                        .map(|_| gen_default_call())
343                        .collect::<Option<Vec<ast::Expr>>>()?;
344                    make.expr_call(struct_name, make.arg_list(fields)).into()
345                }
346                None => {
347                    let struct_name = make.ident_path("Self");
348                    let fields = make.record_expr_field_list([]);
349                    make.record_expr(struct_name, fields).into()
350                }
351            };
352            let body =
353                make.block_expr(None::<ast::Stmt>, Some(expr)).indent(ast::edit::IndentLevel(1));
354            Some(body)
355        }
356    }
357}
358
359/// Generate a `Hash` impl based on the fields and members of the target type.
360fn gen_hash_impl(make: &SyntaxFactory, adt: &ast::Adt) -> Option<ast::BlockExpr> {
361    let gen_hash_call = |target: ast::Expr| -> ast::Stmt {
362        let method = make.name_ref("hash");
363        let arg = make.expr_path(make.ident_path("state"));
364        let expr = make.expr_method_call(target, method, make.arg_list([arg])).into();
365        make.expr_stmt(expr).into()
366    };
367
368    let body = match adt {
369        // `Hash` cannot be derived for unions, so no default impl can be provided.
370        ast::Adt::Union(_) => return None,
371
372        // => std::mem::discriminant(self).hash(state);
373        ast::Adt::Enum(_) => {
374            let fn_name = make_discriminant(make)?;
375
376            let arg = make.expr_path(make.ident_path("self"));
377            let fn_call: ast::Expr = make.expr_call(fn_name, make.arg_list([arg])).into();
378            let stmt = gen_hash_call(fn_call);
379
380            make.block_expr([stmt], None).indent(ast::edit::IndentLevel(1))
381        }
382        ast::Adt::Struct(strukt) => match strukt.field_list() {
383            // => self.<field>.hash(state);
384            Some(ast::FieldList::RecordFieldList(field_list)) => {
385                let mut stmts = vec![];
386                for field in field_list.fields() {
387                    let base = make.expr_path(make.ident_path("self"));
388                    let target = make.expr_field(base, &field.name()?.to_string()).into();
389                    stmts.push(gen_hash_call(target));
390                }
391                make.block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
392            }
393
394            // => self.<field_index>.hash(state);
395            Some(ast::FieldList::TupleFieldList(field_list)) => {
396                let mut stmts = vec![];
397                for (i, _) in field_list.fields().enumerate() {
398                    let base = make.expr_path(make.ident_path("self"));
399                    let target = make.expr_field(base, &format!("{i}")).into();
400                    stmts.push(gen_hash_call(target));
401                }
402                make.block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
403            }
404
405            // No fields in the body means there's nothing to hash.
406            None => return None,
407        },
408    };
409
410    Some(body)
411}
412
413/// Generate a `PartialEq` impl based on the fields and members of the target type.
414fn gen_partial_eq(
415    make: &SyntaxFactory,
416    adt: &ast::Adt,
417    trait_ref: Option<TraitRef<'_>>,
418) -> Option<ast::BlockExpr> {
419    let gen_eq_chain = |expr: Option<ast::Expr>, cmp: ast::Expr| -> Option<ast::Expr> {
420        match expr {
421            Some(expr) => Some(make.expr_bin_op(expr, BinaryOp::LogicOp(LogicOp::And), cmp)),
422            None => Some(cmp),
423        }
424    };
425
426    let gen_record_pat_field = |field_name: &str, pat_name: &str| -> ast::RecordPatField {
427        let pat = make.ident_pat(false, false, make.name(pat_name));
428        let name_ref = make.name_ref(field_name);
429        make.record_pat_field(name_ref, pat.into())
430    };
431
432    let gen_record_pat =
433        |record_name: ast::Path, fields: Vec<ast::RecordPatField>| -> ast::RecordPat {
434            let list = make.record_pat_field_list(fields, None);
435            make.record_pat_with_fields(record_name, list)
436        };
437
438    let gen_variant_path = |variant: &ast::Variant| -> Option<ast::Path> {
439        make.path_from_idents(["Self", &variant.name()?.to_string()])
440    };
441
442    let gen_tuple_field = |field_name: &str| -> ast::Pat {
443        ast::Pat::IdentPat(make.ident_pat(false, false, make.name(field_name)))
444    };
445
446    // Check that self type and rhs type match. We don't know how to implement the method
447    // automatically otherwise.
448    if let Some(trait_ref) = trait_ref {
449        let self_ty = trait_ref.self_ty();
450        let rhs_ty = trait_ref.get_type_argument(1)?;
451        if self_ty != rhs_ty {
452            return None;
453        }
454    }
455
456    let body = match adt {
457        // `PartialEq` cannot be derived for unions, so no default impl can be provided.
458        ast::Adt::Union(_) => return None,
459
460        ast::Adt::Enum(enum_) => {
461            // => std::mem::discriminant(self) == std::mem::discriminant(other)
462            let lhs_name = make.expr_path(make.ident_path("self"));
463            let lhs =
464                make.expr_call(make_discriminant(make)?, make.arg_list([lhs_name.clone()])).into();
465            let rhs_name = make.expr_path(make.ident_path("other"));
466            let rhs =
467                make.expr_call(make_discriminant(make)?, make.arg_list([rhs_name.clone()])).into();
468            let eq_check =
469                make.expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
470
471            let mut n_cases = 0;
472            let mut arms = vec![];
473            for variant in enum_.variant_list()?.variants() {
474                n_cases += 1;
475                match variant.field_list() {
476                    // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
477                    Some(ast::FieldList::RecordFieldList(list)) => {
478                        let mut expr = None;
479                        let mut l_fields = vec![];
480                        let mut r_fields = vec![];
481
482                        for field in list.fields() {
483                            let field_name = field.name()?.to_string();
484
485                            let l_name = &format!("l_{field_name}");
486                            l_fields.push(gen_record_pat_field(&field_name, l_name));
487
488                            let r_name = &format!("r_{field_name}");
489                            r_fields.push(gen_record_pat_field(&field_name, r_name));
490
491                            let lhs = make.expr_path(make.ident_path(l_name));
492                            let rhs = make.expr_path(make.ident_path(r_name));
493                            let cmp = make.expr_bin_op(
494                                lhs,
495                                BinaryOp::CmpOp(CmpOp::Eq { negated: false }),
496                                rhs,
497                            );
498                            expr = gen_eq_chain(expr, cmp);
499                        }
500
501                        let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
502                        let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
503                        let tuple = make.tuple_pat(vec![left.into(), right.into()]);
504
505                        if let Some(expr) = expr {
506                            arms.push(make.match_arm(tuple.into(), None, expr));
507                        }
508                    }
509
510                    Some(ast::FieldList::TupleFieldList(list)) => {
511                        let mut expr = None;
512                        let mut l_fields = vec![];
513                        let mut r_fields = vec![];
514
515                        for (i, _) in list.fields().enumerate() {
516                            let field_name = format!("{i}");
517
518                            let l_name = format!("l{field_name}");
519                            l_fields.push(gen_tuple_field(&l_name));
520
521                            let r_name = format!("r{field_name}");
522                            r_fields.push(gen_tuple_field(&r_name));
523
524                            let lhs = make.expr_path(make.ident_path(&l_name));
525                            let rhs = make.expr_path(make.ident_path(&r_name));
526                            let cmp = make.expr_bin_op(
527                                lhs,
528                                BinaryOp::CmpOp(CmpOp::Eq { negated: false }),
529                                rhs,
530                            );
531                            expr = gen_eq_chain(expr, cmp);
532                        }
533
534                        let left = make.tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
535                        let right = make.tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
536                        let tuple = make.tuple_pat(vec![left.into(), right.into()]);
537
538                        if let Some(expr) = expr {
539                            arms.push(make.match_arm(tuple.into(), None, expr));
540                        }
541                    }
542                    None => continue,
543                }
544            }
545
546            let expr = match arms.len() {
547                0 => eq_check,
548                arms_len => {
549                    // Generate the fallback arm when this enum has >1 variants.
550                    // The fallback arm will be `_ => false,` if we've already gone through every case where the variants of self and other match,
551                    // and `_ => std::mem::discriminant(self) == std::mem::discriminant(other),` otherwise.
552                    if n_cases > 1 {
553                        let lhs = make.wildcard_pat().into();
554                        let rhs = if arms_len == n_cases {
555                            make.expr_literal("false").into()
556                        } else {
557                            eq_check
558                        };
559                        arms.push(make.match_arm(lhs, None, rhs));
560                    }
561
562                    let match_target = make.expr_tuple([lhs_name, rhs_name]).into();
563                    let list = make.match_arm_list(arms).indent(ast::edit::IndentLevel(1));
564                    make.expr_match(match_target, list).into()
565                }
566            };
567
568            make.block_expr(None::<ast::Stmt>, Some(expr)).indent(ast::edit::IndentLevel(1))
569        }
570        ast::Adt::Struct(strukt) => match strukt.field_list() {
571            Some(ast::FieldList::RecordFieldList(field_list)) => {
572                let mut expr = None;
573                for field in field_list.fields() {
574                    let lhs = make.expr_path(make.ident_path("self"));
575                    let lhs = make.expr_field(lhs, &field.name()?.to_string()).into();
576                    let rhs = make.expr_path(make.ident_path("other"));
577                    let rhs = make.expr_field(rhs, &field.name()?.to_string()).into();
578                    let cmp =
579                        make.expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
580                    expr = gen_eq_chain(expr, cmp);
581                }
582                make.block_expr(None, expr).indent(ast::edit::IndentLevel(1))
583            }
584
585            Some(ast::FieldList::TupleFieldList(field_list)) => {
586                let mut expr = None;
587                for (i, _) in field_list.fields().enumerate() {
588                    let idx = format!("{i}");
589                    let lhs = make.expr_path(make.ident_path("self"));
590                    let lhs = make.expr_field(lhs, &idx).into();
591                    let rhs = make.expr_path(make.ident_path("other"));
592                    let rhs = make.expr_field(rhs, &idx).into();
593                    let cmp =
594                        make.expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
595                    expr = gen_eq_chain(expr, cmp);
596                }
597                make.block_expr(None::<ast::Stmt>, expr).indent(ast::edit::IndentLevel(1))
598            }
599
600            // No fields in the body means there's nothing to compare.
601            None => {
602                let expr = make.expr_literal("true").into();
603                make.block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
604            }
605        },
606    };
607
608    Some(body)
609}
610
611fn gen_partial_ord(
612    make: &SyntaxFactory,
613    adt: &ast::Adt,
614    trait_ref: Option<TraitRef<'_>>,
615) -> Option<ast::BlockExpr> {
616    let gen_partial_eq_match = |match_target: ast::Expr| -> Option<ast::Stmt> {
617        let mut arms = vec![];
618
619        let variant_name =
620            make.path_pat(make.path_from_idents(["core", "cmp", "Ordering", "Equal"])?);
621        let lhs = make.tuple_struct_pat(make.path_from_idents(["Some"])?, [variant_name]);
622        arms.push(make.match_arm(lhs.into(), None, make.expr_empty_block().into()));
623
624        arms.push(make.match_arm(
625            make.ident_pat(false, false, make.name("ord")).into(),
626            None,
627            make.expr_return(Some(make.expr_path(make.ident_path("ord")))).into(),
628        ));
629        let list = make.match_arm_list(arms).indent(ast::edit::IndentLevel(1));
630        Some(make.expr_stmt(make.expr_match(match_target, list).into()).into())
631    };
632
633    let gen_partial_cmp_call = |lhs: ast::Expr, rhs: ast::Expr| -> ast::Expr {
634        let rhs = make.expr_ref(rhs, false);
635        let method = make.name_ref("partial_cmp");
636        make.expr_method_call(lhs, method, make.arg_list([rhs])).into()
637    };
638
639    // Check that self type and rhs type match. We don't know how to implement the method
640    // automatically otherwise.
641    if let Some(trait_ref) = trait_ref {
642        let self_ty = trait_ref.self_ty();
643        let rhs_ty = trait_ref.get_type_argument(1)?;
644        if self_ty != rhs_ty {
645            return None;
646        }
647    }
648
649    let body = match adt {
650        // `PartialOrd` cannot be derived for unions, so no default impl can be provided.
651        ast::Adt::Union(_) => return None,
652        // `core::mem::Discriminant` does not implement `PartialOrd` in stable Rust today.
653        ast::Adt::Enum(_) => return None,
654        ast::Adt::Struct(strukt) => match strukt.field_list() {
655            Some(ast::FieldList::RecordFieldList(field_list)) => {
656                let mut exprs = vec![];
657                for field in field_list.fields() {
658                    let lhs = make.expr_path(make.ident_path("self"));
659                    let lhs = make.expr_field(lhs, &field.name()?.to_string()).into();
660                    let rhs = make.expr_path(make.ident_path("other"));
661                    let rhs = make.expr_field(rhs, &field.name()?.to_string()).into();
662                    let ord = gen_partial_cmp_call(lhs, rhs);
663                    exprs.push(ord);
664                }
665
666                let tail = exprs.pop();
667                let stmts = exprs
668                    .into_iter()
669                    .map(gen_partial_eq_match)
670                    .collect::<Option<Vec<ast::Stmt>>>()?;
671                make.block_expr(stmts, tail).indent(ast::edit::IndentLevel(1))
672            }
673
674            Some(ast::FieldList::TupleFieldList(field_list)) => {
675                let mut exprs = vec![];
676                for (i, _) in field_list.fields().enumerate() {
677                    let idx = format!("{i}");
678                    let lhs = make.expr_path(make.ident_path("self"));
679                    let lhs = make.expr_field(lhs, &idx).into();
680                    let rhs = make.expr_path(make.ident_path("other"));
681                    let rhs = make.expr_field(rhs, &idx).into();
682                    let ord = gen_partial_cmp_call(lhs, rhs);
683                    exprs.push(ord);
684                }
685                let tail = exprs.pop();
686                let stmts = exprs
687                    .into_iter()
688                    .map(gen_partial_eq_match)
689                    .collect::<Option<Vec<ast::Stmt>>>()?;
690                make.block_expr(stmts, tail).indent(ast::edit::IndentLevel(1))
691            }
692
693            // No fields in the body means there's nothing to compare.
694            None => {
695                let expr = make.expr_literal("true").into();
696                make.block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
697            }
698        },
699    };
700
701    Some(body)
702}
703
704fn make_discriminant(make: &SyntaxFactory) -> Option<ast::Expr> {
705    Some(make.expr_path(make.path_from_idents(["core", "mem", "discriminant"])?))
706}