ide_assists/handlers/
generate_enum_variant.rs

1use hir::{HasSource, HirDisplay, InRealFile};
2use ide_db::assists::AssistId;
3use syntax::{
4    AstNode, SyntaxNode,
5    ast::{self, HasArgList, syntax_factory::SyntaxFactory},
6    match_ast,
7};
8
9use crate::assist_context::{AssistContext, Assists};
10
11// Assist: generate_enum_variant
12//
13// Adds a variant to an enum.
14//
15// ```
16// enum Countries {
17//     Ghana,
18// }
19//
20// fn main() {
21//     let country = Countries::Lesotho$0;
22// }
23// ```
24// ->
25// ```
26// enum Countries {
27//     Ghana,
28//     Lesotho,
29// }
30//
31// fn main() {
32//     let country = Countries::Lesotho;
33// }
34// ```
35pub(crate) fn generate_enum_variant(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
36    let path: ast::Path = ctx.find_node_at_offset()?;
37    let parent = PathParent::new(&path)?;
38
39    if ctx.sema.resolve_path(&path).is_some() {
40        // No need to generate anything if the path resolves
41        return None;
42    }
43
44    let name_ref = path.segment()?.name_ref()?;
45    if name_ref.text().starts_with(char::is_lowercase) {
46        // Don't suggest generating variant if the name starts with a lowercase letter
47        return None;
48    }
49
50    let Some(hir::PathResolution::Def(hir::ModuleDef::Adt(hir::Adt::Enum(e)))) =
51        ctx.sema.resolve_path(&path.qualifier()?)
52    else {
53        return None;
54    };
55
56    let target = path.syntax().text_range();
57    let name_ref: &ast::NameRef = &name_ref;
58    let db = ctx.db();
59    let InRealFile { file_id, value: enum_node } = e.source(db)?.original_ast_node_rooted(db)?;
60
61    acc.add(AssistId::generate("generate_enum_variant"), "Generate variant", target, |builder| {
62        let mut editor = builder.make_editor(enum_node.syntax());
63        let make = SyntaxFactory::with_mappings();
64        let field_list = parent.make_field_list(ctx, &make);
65        let variant = make.variant(None, make.name(&name_ref.text()), field_list, None);
66        if let Some(it) = enum_node.variant_list() {
67            it.add_variant(&mut editor, &variant);
68        }
69        builder.add_file_edits(file_id.file_id(ctx.db()), editor);
70    })
71}
72
73#[derive(Debug)]
74enum PathParent {
75    PathExpr(ast::PathExpr),
76    RecordExpr(ast::RecordExpr),
77    PathPat(ast::PathPat),
78    UseTree(ast::UseTree),
79}
80
81impl PathParent {
82    fn new(path: &ast::Path) -> Option<Self> {
83        let parent = path.syntax().parent()?;
84
85        match_ast! {
86            match parent {
87                ast::PathExpr(it) => Some(PathParent::PathExpr(it)),
88                ast::RecordExpr(it) => Some(PathParent::RecordExpr(it)),
89                ast::PathPat(it) => Some(PathParent::PathPat(it)),
90                ast::UseTree(it) => Some(PathParent::UseTree(it)),
91                _ => None
92            }
93        }
94    }
95
96    fn syntax(&self) -> &SyntaxNode {
97        match self {
98            PathParent::PathExpr(it) => it.syntax(),
99            PathParent::RecordExpr(it) => it.syntax(),
100            PathParent::PathPat(it) => it.syntax(),
101            PathParent::UseTree(it) => it.syntax(),
102        }
103    }
104
105    fn make_field_list(
106        &self,
107        ctx: &AssistContext<'_>,
108        make: &SyntaxFactory,
109    ) -> Option<ast::FieldList> {
110        let scope = ctx.sema.scope(self.syntax())?;
111
112        match self {
113            PathParent::PathExpr(it) => {
114                let call_expr = ast::CallExpr::cast(it.syntax().parent()?)?;
115                let args = call_expr.arg_list()?.args();
116                let tuple_fields = args.map(|arg| {
117                    let ty =
118                        expr_ty(ctx, make, arg, &scope).unwrap_or_else(|| make.ty_infer().into());
119                    make.tuple_field(None, ty)
120                });
121                Some(make.tuple_field_list(tuple_fields).into())
122            }
123            PathParent::RecordExpr(it) => {
124                let fields = it.record_expr_field_list()?.fields();
125                let record_fields = fields.map(|field| {
126                    let name = name_from_field(make, &field);
127
128                    let ty = field
129                        .expr()
130                        .and_then(|it| expr_ty(ctx, make, it, &scope))
131                        .unwrap_or_else(|| make.ty_infer().into());
132
133                    make.record_field(None, name, ty)
134                });
135                Some(make.record_field_list(record_fields).into())
136            }
137            PathParent::UseTree(_) | PathParent::PathPat(_) => None,
138        }
139    }
140}
141
142fn name_from_field(make: &SyntaxFactory, field: &ast::RecordExprField) -> ast::Name {
143    let text = match field.name_ref() {
144        Some(it) => it.to_string(),
145        None => name_from_field_shorthand(field).unwrap_or("unknown".to_owned()),
146    };
147    make.name(&text)
148}
149
150fn name_from_field_shorthand(field: &ast::RecordExprField) -> Option<String> {
151    let path = match field.expr()? {
152        ast::Expr::PathExpr(path_expr) => path_expr.path(),
153        _ => None,
154    }?;
155    Some(path.as_single_name_ref()?.to_string())
156}
157
158fn expr_ty(
159    ctx: &AssistContext<'_>,
160    make: &SyntaxFactory,
161    arg: ast::Expr,
162    scope: &hir::SemanticsScope<'_>,
163) -> Option<ast::Type> {
164    let ty = ctx.sema.type_of_expr(&arg).map(|it| it.adjusted())?;
165    let text = ty.display_source_code(ctx.db(), scope.module().into(), false).ok()?;
166    Some(make.ty(&text))
167}
168
#[cfg(test)]
mod tests {
    // Fixture conventions used below: `$0` marks the cursor offset, and in multi-file
    // fixtures each `//- /path` line starts a new file. For multi-file fixtures the
    // expected text is only the file that receives the edit (the one defining the enum).
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // A unit variant is added to an enum with no variants yet.
    #[test]
    fn generate_basic_enum_variant_in_empty_enum() {
        check_assist(
            generate_enum_variant,
            r"
enum Foo {}
fn main() {
    Foo::Bar$0
}
",
            r"
enum Foo {
    Bar,
}
fn main() {
    Foo::Bar
}
",
        )
    }

    // The new variant is appended after the existing ones.
    #[test]
    fn generate_basic_enum_variant_in_non_empty_enum() {
        check_assist(
            generate_enum_variant,
            r"
enum Foo {
    Bar,
}
fn main() {
    Foo::Baz$0
}
",
            r"
enum Foo {
    Bar,
    Baz,
}
fn main() {
    Foo::Baz
}
",
        )
    }

    // The enum is defined in another file; that file is the one edited.
    #[test]
    fn generate_basic_enum_variant_in_different_file() {
        check_assist(
            generate_enum_variant,
            r"
//- /main.rs
mod foo;
use foo::Foo;

fn main() {
    Foo::Baz$0
}

//- /foo.rs
pub enum Foo {
    Bar,
}
",
            r"
pub enum Foo {
    Bar,
    Baz,
}
",
        )
    }

    // A path that already resolves to a variant must not trigger the assist.
    #[test]
    fn not_applicable_for_existing_variant() {
        check_assist_not_applicable(
            generate_enum_variant,
            r"
enum Foo {
    Bar,
}
fn main() {
    Foo::Bar$0
}
",
        )
    }

    // Lowercase names (e.g. a missing associated fn like `new`) are not treated as variants.
    #[test]
    fn not_applicable_for_lowercase() {
        check_assist_not_applicable(
            generate_enum_variant,
            r"
enum Foo {
    Bar,
}
fn main() {
    Foo::new$0
}
",
        )
    }

    // The new variant follows the indentation of an enum nested inside a module.
    #[test]
    fn indentation_level_is_correct() {
        check_assist(
            generate_enum_variant,
            r"
mod m {
    pub enum Foo {
        Bar,
    }
}
fn main() {
    m::Foo::Baz$0
}
",
            r"
mod m {
    pub enum Foo {
        Bar,
        Baz,
    }
}
fn main() {
    m::Foo::Baz
}
",
        )
    }

    // A call with one argument produces a one-field tuple variant with the inferred type.
    #[test]
    fn associated_single_element_tuple() {
        check_assist(
            generate_enum_variant,
            r"
enum Foo {}
fn main() {
    Foo::Bar$0(true)
}
",
            r"
enum Foo {
    Bar(bool),
}
fn main() {
    Foo::Bar(true)
}
",
        )
    }

    // `x` is unresolved, so its type is unknown and the `_` placeholder is emitted.
    #[test]
    fn associated_single_element_tuple_unknown_type() {
        check_assist(
            generate_enum_variant,
            r"
enum Foo {}
fn main() {
    Foo::Bar$0(x)
}
",
            r"
enum Foo {
    Bar(_),
}
fn main() {
    Foo::Bar(x)
}
",
        )
    }

    // Known and unknown argument types mix: one tuple field per argument, in order.
    #[test]
    fn associated_multi_element_tuple() {
        check_assist(
            generate_enum_variant,
            r"
struct Struct {}
enum Foo {}
fn main() {
    Foo::Bar$0(true, x, Struct {})
}
",
            r"
struct Struct {}
enum Foo {
    Bar(bool, _, Struct),
}
fn main() {
    Foo::Bar(true, x, Struct {})
}
",
        )
    }

    // A struct literal produces a record variant with named, typed fields.
    #[test]
    fn associated_record() {
        check_assist(
            generate_enum_variant,
            r"
enum Foo {}
fn main() {
    Foo::$0Bar { x: true }
}
",
            r"
enum Foo {
    Bar { x: bool },
}
fn main() {
    Foo::Bar { x: true }
}
",
        )
    }

    // An unresolved field initializer yields the `_` placeholder type.
    #[test]
    fn associated_record_unknown_type() {
        check_assist(
            generate_enum_variant,
            r"
enum Foo {}
fn main() {
    Foo::$0Bar { x: y }
}
",
            r"
enum Foo {
    Bar { x: _ },
}
fn main() {
    Foo::Bar { x: y }
}
",
        )
    }

    // Shorthand field `{ x }`: the name comes from the identifier, the type from `x`.
    #[test]
    fn associated_record_field_shorthand() {
        check_assist(
            generate_enum_variant,
            r"
enum Foo {}
fn main() {
    let x = true;
    Foo::$0Bar { x }
}
",
            r"
enum Foo {
    Bar { x: bool },
}
fn main() {
    let x = true;
    Foo::Bar { x }
}
",
        )
    }

    // Shorthand field referring to an unresolved name: named `x`, typed `_`.
    #[test]
    fn associated_record_field_shorthand_unknown_type() {
        check_assist(
            generate_enum_variant,
            r"
enum Foo {}
fn main() {
    Foo::$0Bar { x }
}
",
            r"
enum Foo {
    Bar { x: _ },
}
fn main() {
    Foo::Bar { x }
}
",
        )
    }

    // Shorthand, explicit and struct-typed fields together, in source order.
    #[test]
    fn associated_record_field_multiple_fields() {
        check_assist(
            generate_enum_variant,
            r"
struct Struct {}
enum Foo {}
fn main() {
    Foo::$0Bar { x, y: x, s: Struct {} }
}
",
            r"
struct Struct {}
enum Foo {
    Bar { x: _, y: _, s: Struct },
}
fn main() {
    Foo::Bar { x, y: x, s: Struct {} }
}
",
        )
    }

    // Importing a missing variant in a `use` adds it as a unit variant.
    #[test]
    fn use_tree() {
        check_assist(
            generate_enum_variant,
            r"
//- /main.rs
mod foo;
use foo::Foo::Bar$0;

//- /foo.rs
pub enum Foo {}
",
            r"
pub enum Foo {
    Bar,
}
",
        )
    }

    // A path in type position is not a supported `PathParent`, so no assist is offered.
    #[test]
    fn not_applicable_for_path_type() {
        check_assist_not_applicable(
            generate_enum_variant,
            r"
enum Foo {}
impl Foo::Bar$0 {}
",
        )
    }

    // A path pattern in a `match` arm adds a unit variant.
    #[test]
    fn path_pat() {
        check_assist(
            generate_enum_variant,
            r"
enum Foo {}
fn foo(x: Foo) {
    match x {
        Foo::Bar$0 =>
    }
}
",
            r"
enum Foo {
    Bar,
}
fn foo(x: Foo) {
    match x {
        Foo::Bar =>
    }
}
",
        )
    }
}