Skip to main content

ide_assists/handlers/
generate_delegate_methods.rs

1use hir::HasVisibility;
2use ide_db::{FxHashSet, path_transform::PathTransform};
3use syntax::{
4    ast::{
5        self, AstNode, HasGenericParams, HasName, HasVisibility as _,
6        edit::{AstNodeEdit, IndentLevel},
7    },
8    syntax_editor::Position,
9};
10
11use crate::{
12    AssistContext, AssistId, AssistKind, Assists, GroupLabel,
13    utils::{convert_param_list_to_arg_list, find_struct_impl},
14};
15
16// Assist: generate_delegate_methods
17//
18// Generate delegate methods.
19//
20// ```
21// struct Age(u8);
22// impl Age {
23//     fn age(&self) -> u8 {
24//         self.0
25//     }
26// }
27//
28// struct Person {
29//     ag$0e: Age,
30// }
31// ```
32// ->
33// ```
34// struct Age(u8);
35// impl Age {
36//     fn age(&self) -> u8 {
37//         self.0
38//     }
39// }
40//
41// struct Person {
42//     age: Age,
43// }
44//
45// impl Person {
46//     $0fn age(&self) -> u8 {
47//         self.age.age()
48//     }
49// }
50// ```
51pub(crate) fn generate_delegate_methods(
52    acc: &mut Assists,
53    ctx: &AssistContext<'_, '_>,
54) -> Option<()> {
55    if !ctx.config.code_action_grouping {
56        return None;
57    }
58
59    let strukt = ctx.find_node_at_offset::<ast::Struct>()?;
60    let strukt_name = strukt.name()?;
61    let current_module = ctx.sema.scope(strukt.syntax())?.module();
62    let current_edition = current_module.krate(ctx.db()).edition(ctx.db());
63
64    let (field_name, field_ty, target) = match ctx.find_node_at_offset::<ast::RecordField>() {
65        Some(field) => {
66            let field_name = field.name()?;
67            let field_ty = field.ty()?;
68            (field_name.to_string(), field_ty, field.syntax().text_range())
69        }
70        None => {
71            let field = ctx.find_node_at_offset::<ast::TupleField>()?;
72            let field_list = ctx.find_node_at_offset::<ast::TupleFieldList>()?;
73            let field_list_index = field_list.fields().position(|it| it == field)?;
74            let field_ty = field.ty()?;
75            (field_list_index.to_string(), field_ty, field.syntax().text_range())
76        }
77    };
78
79    let sema_field_ty = ctx.sema.resolve_type(&field_ty)?;
80    let mut methods = vec![];
81    let mut seen_names = FxHashSet::default();
82
83    for ty in sema_field_ty.autoderef(ctx.db()) {
84        ty.iterate_assoc_items(ctx.db(), |item| {
85            if let hir::AssocItem::Function(f) = item {
86                let name = f.name(ctx.db());
87                if f.self_param(ctx.db()).is_some()
88                    && f.is_visible_from(ctx.db(), current_module)
89                    && seen_names.insert(name.clone())
90                {
91                    methods.push((name, f))
92                }
93            }
94            Option::<()>::None
95        });
96    }
97    methods.sort_by(|(a, _), (b, _)| a.cmp(b));
98    for (index, (name, method)) in methods.into_iter().enumerate() {
99        let adt = ast::Adt::Struct(strukt.clone());
100        let name = name.display(ctx.db(), current_edition).to_string();
101        // if `find_struct_impl` returns None, that means that a function named `name` already exists.
102        let Some(impl_def) = find_struct_impl(ctx, &adt, std::slice::from_ref(&name)) else {
103            continue;
104        };
105
106        acc.add_group(
107            &GroupLabel("Generate delegate methods…".to_owned()),
108            AssistId("generate_delegate_methods", AssistKind::Generate, Some(index)),
109            format!("Generate delegate for `{field_name}.{name}()`",),
110            target,
111            |edit| {
112                let editor = edit.make_editor(strukt.syntax());
113                let make = editor.make();
114                let field = make
115                    .field_from_idents(["self", &field_name])
116                    .expect("always be a valid expression");
117                // Create the function
118                let method_source = match ctx.sema.source(method) {
119                    Some(source) => {
120                        let v = source.value;
121                        let source_scope = ctx.sema.scope(v.syntax());
122                        let target_scope = ctx.sema.scope(strukt.syntax());
123                        if let (Some(s), Some(t)) = (source_scope, target_scope) {
124                            ast::Fn::cast(
125                                PathTransform::generic_transformation(&t, &s).apply(v.syntax()),
126                            )
127                            .unwrap_or(v)
128                        } else {
129                            v
130                        }
131                    }
132                    None => return,
133                };
134
135                let vis = method_source.visibility();
136                let is_async = method_source.async_token().is_some();
137                let is_const = method_source.const_token().is_some();
138                let is_unsafe = method_source.unsafe_token().is_some();
139                let is_gen = method_source.gen_token().is_some();
140
141                let fn_name = make.name(&name);
142
143                let type_params = method_source.generic_param_list();
144                let where_clause = method_source.where_clause();
145                let params =
146                    method_source.param_list().unwrap_or_else(|| make.param_list(None, []));
147
148                // compute the `body`
149                let arg_list = method_source
150                    .param_list()
151                    .map(|v| convert_param_list_to_arg_list(v, make))
152                    .unwrap_or_else(|| make.arg_list([]));
153
154                let tail_expr = make.expr_method_call(field, make.name_ref(&name), arg_list).into();
155                let tail_expr_finished =
156                    if is_async { make.expr_await(tail_expr).into() } else { tail_expr };
157                let body = make.block_expr([], Some(tail_expr_finished));
158
159                let ret_type = method_source.ret_type();
160
161                let f = make
162                    .fn_(
163                        None,
164                        vis,
165                        fn_name,
166                        type_params,
167                        where_clause,
168                        params,
169                        body,
170                        ret_type,
171                        is_async,
172                        is_const,
173                        is_unsafe,
174                        is_gen,
175                    )
176                    .indent(IndentLevel(1));
177                let item = ast::AssocItem::Fn(f.clone());
178
179                let fn_: Option<ast::AssocItem> = match impl_def {
180                    Some(impl_def) => match impl_def.assoc_item_list() {
181                        Some(assoc_item_list) => {
182                            let item = item.indent(IndentLevel::from_node(impl_def.syntax()));
183                            assoc_item_list.add_items(&editor, vec![item.clone()]);
184                            Some(item)
185                        }
186                        None => {
187                            let assoc_item_list = make.assoc_item_list(vec![item]);
188                            editor.insert(
189                                Position::last_child_of(impl_def.syntax()),
190                                assoc_item_list.syntax(),
191                            );
192                            assoc_item_list.assoc_items().next()
193                        }
194                    },
195                    None => {
196                        let name = &strukt_name.to_string();
197                        let ty_params = strukt.generic_param_list();
198                        let ty_args = ty_params.as_ref().map(|it| it.to_generic_args(make));
199                        let where_clause = strukt.where_clause();
200                        let assoc_item_list = make.assoc_item_list(vec![item]);
201
202                        let impl_def = make.impl_(
203                            None,
204                            ty_params,
205                            ty_args,
206                            syntax::ast::Type::PathType(make.ty_path(make.ident_path(name))),
207                            where_clause,
208                            Some(assoc_item_list),
209                        );
210
211                        // Fixup impl_def indentation
212                        let indent = strukt.indent_level();
213                        let impl_def = impl_def.indent(indent);
214
215                        // Insert the impl block.
216                        editor.insert_all(
217                            Position::after(strukt.syntax()),
218                            vec![
219                                make.whitespace(&format!("\n\n{indent}")).into(),
220                                impl_def.syntax().clone().into(),
221                            ],
222                        );
223                        impl_def.assoc_item_list().and_then(|list| list.assoc_items().next())
224                    }
225                };
226
227                if let Some(cap) = ctx.config.snippet_cap
228                    && let Some(fn_) = fn_
229                {
230                    let tabstop = edit.make_tabstop_before(cap);
231                    editor.add_annotation(fn_.syntax(), tabstop);
232                }
233                edit.add_file_edits(ctx.vfs_file_id(), editor);
234            },
235        )?;
236    }
237    Some(())
238}
239
#[cfg(test)]
mod tests {
    use crate::tests::{
        check_assist, check_assist_not_applicable, check_assist_not_applicable_no_grouping,
    };

    use super::*;

    // With no existing impl for the struct, a new `impl Person` block is created
    // after the struct definition.
    #[test]
    fn test_generate_delegate_create_impl_block() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    age: Age,
}

impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // A newly created impl block inside a module uses the struct's indentation.
    #[test]
    fn test_generate_delegate_create_impl_block_match_indent() {
        check_assist(
            generate_delegate_methods,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        ag$0e: Age,
    }
}"#,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        age: Age,
    }

    impl Person {
        $0fn age(&self) -> u8 {
            self.age.age()
        }
    }
}"#,
        );
    }

    // An existing (empty) impl block is extended instead of creating a new one.
    #[test]
    fn test_generate_delegate_update_impl_block() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}

impl Person {}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    age: Age,
}

impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // Methods added to an existing, indented impl block follow that block's indentation.
    #[test]
    fn test_generate_delegate_update_impl_block_match_indent() {
        check_assist(
            generate_delegate_methods,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        ag$0e: Age,
    }

    impl Person {}
}"#,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        age: Age,
    }

    impl Person {
        $0fn age(&self) -> u8 {
            self.age.age()
        }
    }
}"#,
        );
    }

    // Tuple-struct fields are accessed by positional index (`self.0`).
    #[test]
    fn test_generate_delegate_tuple_struct() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person(A$0ge);"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person(Age);

impl Person {
    $0fn age(&self) -> u8 {
        self.0.age()
    }
}"#,
        );
    }

    // Visibility, `async`, generics, lifetimes and parameters are copied; async
    // delegates `.await` the inner call; the new impl carries the struct's generics.
    #[test]
    fn test_generate_delegate_enable_all_attributes() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age<T>(T);
impl<T> Age<T> {
    pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.0
    }
}

struct Person<T> {
    ag$0e: Age<T>,
}"#,
            r#"
struct Age<T>(T);
impl<T> Age<T> {
    pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.0
    }
}

struct Person<T> {
    age: Age<T>,
}

impl<T> Person<T> {
    $0pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.age.age(ty, arg).await
    }
}"#,
        );
    }

    // Methods of the `Deref` target of the field type are also offered.
    #[test]
    fn test_generates_delegate_autoderef() {
        check_assist(
            generate_delegate_methods,
            r#"
//- minicore: deref
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct AgeDeref(Age);
impl core::ops::Deref for AgeDeref { type Target = Age; }
struct Person {
    ag$0e: AgeDeref,
}
impl Person {}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct AgeDeref(Age);
impl core::ops::Deref for AgeDeref { type Target = Age; }
struct Person {
    age: AgeDeref,
}
impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // The delegated method's `where` clause is carried over (formatting as generated).
    #[test]
    fn test_preserve_where_clause() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Inner<T>(T);
impl<T> Inner<T> {
    fn get(&self) -> T
    where
        T: Copy,
        T: PartialEq,
    {
        self.0
    }
}

struct Struct<T> {
    $0field: Inner<T>,
}
"#,
            r#"
struct Inner<T>(T);
impl<T> Inner<T> {
    fn get(&self) -> T
    where
        T: Copy,
        T: PartialEq,
    {
        self.0
    }
}

struct Struct<T> {
    field: Inner<T>,
}

impl<T> Struct<T> {
    $0fn get(&self) -> T where
            T: Copy,
            T: PartialEq, {
        self.field.get()
    }
}
"#,
        );
    }

    // `Self` in the delegated signature is rewritten to the field's concrete type.
    #[test]
    fn test_fixes_basic_self_references() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo {
    field: $0Bar,
}

struct Bar;

impl Bar {
    fn bar(&self, other: Self) -> Self {
        other
    }
}
"#,
            r#"
struct Foo {
    field: Bar,
}

impl Foo {
    $0fn bar(&self, other: Bar) -> Bar {
        self.field.bar(other)
    }
}

struct Bar;

impl Bar {
    fn bar(&self, other: Self) -> Self {
        other
    }
}
"#,
        );
    }

    // `Self` nested inside tuple, array and generic types is rewritten too.
    #[test]
    fn test_fixes_nested_self_references() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo {
    field: $0Bar,
}

struct Bar;

impl Bar {
    fn bar(&mut self, a: (Self, [Self; 4]), b: Vec<Self>) {}
}
"#,
            r#"
struct Foo {
    field: Bar,
}

impl Foo {
    $0fn bar(&mut self, a: (Bar, [Bar; 4]), b: Vec<Bar>) {
        self.field.bar(a, b)
    }
}

struct Bar;

impl Bar {
    fn bar(&mut self, a: (Self, [Self; 4]), b: Vec<Self>) {}
}
"#,
        );
    }

    // `Self` is rewritten with the field type's lifetime and type arguments.
    #[test]
    fn test_fixes_self_references_with_lifetimes_and_generics() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo<'a, T> {
    $0field: Bar<'a, T>,
}

struct Bar<'a, T>(&'a T);

impl<'a, T> Bar<'a, T> {
    fn bar(self, mut b: Vec<&'a Self>) -> &'a Self {
        b.pop().unwrap()
    }
}
"#,
            r#"
struct Foo<'a, T> {
    field: Bar<'a, T>,
}

impl<'a, T> Foo<'a, T> {
    $0fn bar(self, mut b: Vec<&'a Bar<'a, T>>) -> &'a Bar<'a, T> {
        self.field.bar(b)
    }
}

struct Bar<'a, T>(&'a T);

impl<'a, T> Bar<'a, T> {
    fn bar(self, mut b: Vec<&'a Self>) -> &'a Self {
        b.pop().unwrap()
    }
}
"#,
        );
    }

    // A method produced by a macro in another module gets its paths qualified for
    // the struct's module (`bar::Bar`); spacing reflects the macro-expanded tokens.
    #[test]
    fn test_fixes_self_references_across_macros() {
        check_assist(
            generate_delegate_methods,
            r#"
//- /bar.rs
macro_rules! test_method {
    () => {
        pub fn test(self, b: Bar) -> Self {
            self
        }
    };
}

pub struct Bar;

impl Bar {
    test_method!();
}

//- /main.rs
mod bar;

struct Foo {
    $0bar: bar::Bar,
}
"#,
            r#"
mod bar;

struct Foo {
    bar: bar::Bar,
}

impl Foo {
    $0pub fn test(self,b:bar::Bar) ->bar::Bar {
        self.bar.test(b)
    }
}
"#,
        );
    }

    // A private method of a type defined in another module is not offered.
    #[test]
    fn test_generate_delegate_visibility() {
        check_assist_not_applicable(
            generate_delegate_methods,
            r#"
mod m {
    pub struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }
}

struct Person {
    ag$0e: m::Age,
}"#,
        )
    }

    // No delegate is offered when the struct already has a method with that name.
    #[test]
    fn test_generate_not_eligible_if_fn_exists() {
        check_assist_not_applicable(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}
impl Person {
    fn age(&self) -> u8 { 0 }
}
"#,
        );
    }

    // The assist is not offered when code-action grouping is disabled.
    #[test]
    fn delegate_method_skipped_when_no_grouping() {
        check_assist_not_applicable_no_grouping(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct Person {
    ag$0e: Age,
}"#,
        );
    }
}