// ide_assists/handlers/generate_delegate_methods.rs

1use hir::HasVisibility;
2use ide_db::{FxHashSet, path_transform::PathTransform};
3use syntax::{
4    ast::{
5        self, AstNode, HasGenericParams, HasName, HasVisibility as _,
6        edit::{AstNodeEdit, IndentLevel},
7        make,
8    },
9    syntax_editor::Position,
10};
11
12use crate::{
13    AssistContext, AssistId, AssistKind, Assists, GroupLabel,
14    utils::{convert_param_list_to_arg_list, find_struct_impl},
15};
16
// Assist: generate_delegate_methods
//
// Generate delegate methods.
//
// ```
// struct Age(u8);
// impl Age {
//     fn age(&self) -> u8 {
//         self.0
//     }
// }
//
// struct Person {
//     ag$0e: Age,
// }
// ```
// ->
// ```
// struct Age(u8);
// impl Age {
//     fn age(&self) -> u8 {
//         self.0
//     }
// }
//
// struct Person {
//     age: Age,
// }
//
// impl Person {
//     $0fn age(&self) -> u8 {
//         self.age.age()
//     }
// }
// ```
52pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
53    if !ctx.config.code_action_grouping {
54        return None;
55    }
56
57    let strukt = ctx.find_node_at_offset::<ast::Struct>()?;
58    let strukt_name = strukt.name()?;
59    let current_module = ctx.sema.scope(strukt.syntax())?.module();
60    let current_edition = current_module.krate(ctx.db()).edition(ctx.db());
61
62    let (field_name, field_ty, target) = match ctx.find_node_at_offset::<ast::RecordField>() {
63        Some(field) => {
64            let field_name = field.name()?;
65            let field_ty = field.ty()?;
66            (field_name.to_string(), field_ty, field.syntax().text_range())
67        }
68        None => {
69            let field = ctx.find_node_at_offset::<ast::TupleField>()?;
70            let field_list = ctx.find_node_at_offset::<ast::TupleFieldList>()?;
71            let field_list_index = field_list.fields().position(|it| it == field)?;
72            let field_ty = field.ty()?;
73            (field_list_index.to_string(), field_ty, field.syntax().text_range())
74        }
75    };
76
77    let sema_field_ty = ctx.sema.resolve_type(&field_ty)?;
78    let mut methods = vec![];
79    let mut seen_names = FxHashSet::default();
80
81    for ty in sema_field_ty.autoderef(ctx.db()) {
82        ty.iterate_assoc_items(ctx.db(), |item| {
83            if let hir::AssocItem::Function(f) = item {
84                let name = f.name(ctx.db());
85                if f.self_param(ctx.db()).is_some()
86                    && f.is_visible_from(ctx.db(), current_module)
87                    && seen_names.insert(name.clone())
88                {
89                    methods.push((name, f))
90                }
91            }
92            Option::<()>::None
93        });
94    }
95    methods.sort_by(|(a, _), (b, _)| a.cmp(b));
96    for (index, (name, method)) in methods.into_iter().enumerate() {
97        let adt = ast::Adt::Struct(strukt.clone());
98        let name = name.display(ctx.db(), current_edition).to_string();
99        // if `find_struct_impl` returns None, that means that a function named `name` already exists.
100        let Some(impl_def) = find_struct_impl(ctx, &adt, std::slice::from_ref(&name)) else {
101            continue;
102        };
103        let field = make::ext::field_from_idents(["self", &field_name])?;
104
105        acc.add_group(
106            &GroupLabel("Generate delegate methods…".to_owned()),
107            AssistId("generate_delegate_methods", AssistKind::Generate, Some(index)),
108            format!("Generate delegate for `{field_name}.{name}()`",),
109            target,
110            |edit| {
111                // Create the function
112                let method_source = match ctx.sema.source(method) {
113                    Some(source) => {
114                        let v = source.value.clone_for_update();
115                        let source_scope = ctx.sema.scope(v.syntax());
116                        let target_scope = ctx.sema.scope(strukt.syntax());
117                        if let (Some(s), Some(t)) = (source_scope, target_scope) {
118                            ast::Fn::cast(
119                                PathTransform::generic_transformation(&t, &s).apply(v.syntax()),
120                            )
121                            .unwrap_or(v)
122                        } else {
123                            v
124                        }
125                    }
126                    None => return,
127                };
128
129                let vis = method_source.visibility();
130                let is_async = method_source.async_token().is_some();
131                let is_const = method_source.const_token().is_some();
132                let is_unsafe = method_source.unsafe_token().is_some();
133                let is_gen = method_source.gen_token().is_some();
134
135                let fn_name = make::name(&name);
136
137                let type_params = method_source.generic_param_list();
138                let where_clause = method_source.where_clause();
139                let params =
140                    method_source.param_list().unwrap_or_else(|| make::param_list(None, []));
141
142                // compute the `body`
143                let arg_list = method_source
144                    .param_list()
145                    .map(convert_param_list_to_arg_list)
146                    .unwrap_or_else(|| make::arg_list([]));
147
148                let tail_expr =
149                    make::expr_method_call(field, make::name_ref(&name), arg_list).into();
150                let tail_expr_finished =
151                    if is_async { make::expr_await(tail_expr) } else { tail_expr };
152                let body = make::block_expr([], Some(tail_expr_finished));
153
154                let ret_type = method_source.ret_type();
155
156                let f = make::fn_(
157                    None,
158                    vis,
159                    fn_name,
160                    type_params,
161                    where_clause,
162                    params,
163                    body,
164                    ret_type,
165                    is_async,
166                    is_const,
167                    is_unsafe,
168                    is_gen,
169                )
170                .indent(IndentLevel(1));
171                let item = ast::AssocItem::Fn(f.clone());
172
173                let mut editor = edit.make_editor(strukt.syntax());
174                let fn_: Option<ast::AssocItem> = match impl_def {
175                    Some(impl_def) => match impl_def.assoc_item_list() {
176                        Some(assoc_item_list) => {
177                            let item = item.indent(IndentLevel::from_node(impl_def.syntax()));
178                            assoc_item_list.add_items(&mut editor, vec![item.clone()]);
179                            Some(item)
180                        }
181                        None => {
182                            let assoc_item_list = make::assoc_item_list(Some(vec![item]));
183                            editor.insert(
184                                Position::last_child_of(impl_def.syntax()),
185                                assoc_item_list.syntax(),
186                            );
187                            assoc_item_list.assoc_items().next()
188                        }
189                    },
190                    None => {
191                        let name = &strukt_name.to_string();
192                        let ty_params = strukt.generic_param_list();
193                        let ty_args = ty_params.as_ref().map(|it| it.to_generic_args());
194                        let where_clause = strukt.where_clause();
195                        let assoc_item_list = make::assoc_item_list(Some(vec![item]));
196
197                        let impl_def = make::impl_(
198                            None,
199                            ty_params,
200                            ty_args,
201                            make::ty_path(make::ext::ident_path(name)),
202                            where_clause,
203                            Some(assoc_item_list),
204                        )
205                        .clone_for_update();
206
207                        // Fixup impl_def indentation
208                        let indent = strukt.indent_level();
209                        let impl_def = impl_def.indent(indent);
210
211                        // Insert the impl block.
212                        let strukt = edit.make_mut(strukt.clone());
213                        editor.insert_all(
214                            Position::after(strukt.syntax()),
215                            vec![
216                                make::tokens::whitespace(&format!("\n\n{indent}")).into(),
217                                impl_def.syntax().clone().into(),
218                            ],
219                        );
220                        impl_def.assoc_item_list().and_then(|list| list.assoc_items().next())
221                    }
222                };
223
224                if let Some(cap) = ctx.config.snippet_cap
225                    && let Some(fn_) = fn_
226                {
227                    let tabstop = edit.make_tabstop_before(cap);
228                    editor.add_annotation(fn_.syntax(), tabstop);
229                }
230                edit.add_file_edits(ctx.vfs_file_id(), editor);
231            },
232        )?;
233    }
234    Some(())
235}
236
#[cfg(test)]
mod tests {
    use crate::tests::{
        check_assist, check_assist_not_applicable, check_assist_not_applicable_no_grouping,
    };

    use super::*;

    // No `impl Person` exists: a new impl block is created after the struct.
    #[test]
    fn test_generate_delegate_create_impl_block() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    age: Age,
}

impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // Same as above inside a module: the created impl follows the struct's indentation.
    #[test]
    fn test_generate_delegate_create_impl_block_match_indent() {
        check_assist(
            generate_delegate_methods,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        ag$0e: Age,
    }
}"#,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        age: Age,
    }

    impl Person {
        $0fn age(&self) -> u8 {
            self.age.age()
        }
    }
}"#,
        );
    }

    // An empty `impl Person {}` already exists: the method is added to it.
    #[test]
    fn test_generate_delegate_update_impl_block() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}

impl Person {}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    age: Age,
}

impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // Existing impl inside a module: the added method follows the impl's indentation.
    #[test]
    fn test_generate_delegate_update_impl_block_match_indent() {
        check_assist(
            generate_delegate_methods,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        ag$0e: Age,
    }

    impl Person {}
}"#,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        age: Age,
    }

    impl Person {
        $0fn age(&self) -> u8 {
            self.age.age()
        }
    }
}"#,
        );
    }

    // Tuple-struct field: the delegate goes through the positional field `self.0`.
    #[test]
    fn test_generate_delegate_tuple_struct() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person(A$0ge);"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person(Age);

impl Person {
    $0fn age(&self) -> u8 {
        self.0.age()
    }
}"#,
        );
    }

    // Visibility, `async`, generic/lifetime parameters and arguments are carried over,
    // and the forwarded call is awaited.
    #[test]
    fn test_generate_delegate_enable_all_attributes() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age<T>(T);
impl<T> Age<T> {
    pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.0
    }
}

struct Person<T> {
    ag$0e: Age<T>,
}"#,
            r#"
struct Age<T>(T);
impl<T> Age<T> {
    pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.0
    }
}

struct Person<T> {
    age: Age<T>,
}

impl<T> Person<T> {
    $0pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.age.age(ty, arg).await
    }
}"#,
        );
    }

    // Methods reachable through the field type's `Deref` target are offered too.
    #[test]
    fn test_generates_delegate_autoderef() {
        check_assist(
            generate_delegate_methods,
            r#"
//- minicore: deref
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct AgeDeref(Age);
impl core::ops::Deref for AgeDeref { type Target = Age; }
struct Person {
    ag$0e: AgeDeref,
}
impl Person {}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct AgeDeref(Age);
impl core::ops::Deref for AgeDeref { type Target = Age; }
struct Person {
    age: AgeDeref,
}
impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // The method's where clause is kept on the delegate (expected text pins the
    // formatting the assist currently produces).
    #[test]
    fn test_preserve_where_clause() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Inner<T>(T);
impl<T> Inner<T> {
    fn get(&self) -> T
    where
        T: Copy,
        T: PartialEq,
    {
        self.0
    }
}

struct Struct<T> {
    $0field: Inner<T>,
}
"#,
            r#"
struct Inner<T>(T);
impl<T> Inner<T> {
    fn get(&self) -> T
    where
        T: Copy,
        T: PartialEq,
    {
        self.0
    }
}

struct Struct<T> {
    field: Inner<T>,
}

impl<T> Struct<T> {
    $0fn get(&self) -> T where
            T: Copy,
            T: PartialEq, {
        self.field.get()
    }
}
"#,
        );
    }

    // `Self` in the delegated signature is replaced by the concrete type `Bar`.
    #[test]
    fn test_fixes_basic_self_references() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo {
    field: $0Bar,
}

struct Bar;

impl Bar {
    fn bar(&self, other: Self) -> Self {
        other
    }
}
"#,
            r#"
struct Foo {
    field: Bar,
}

impl Foo {
    $0fn bar(&self, other: Bar) -> Bar {
        self.field.bar(other)
    }
}

struct Bar;

impl Bar {
    fn bar(&self, other: Self) -> Self {
        other
    }
}
"#,
        );
    }

    // `Self` nested inside tuples, arrays and generic arguments is replaced as well.
    #[test]
    fn test_fixes_nested_self_references() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo {
    field: $0Bar,
}

struct Bar;

impl Bar {
    fn bar(&mut self, a: (Self, [Self; 4]), b: Vec<Self>) {}
}
"#,
            r#"
struct Foo {
    field: Bar,
}

impl Foo {
    $0fn bar(&mut self, a: (Bar, [Bar; 4]), b: Vec<Bar>) {
        self.field.bar(a, b)
    }
}

struct Bar;

impl Bar {
    fn bar(&mut self, a: (Self, [Self; 4]), b: Vec<Self>) {}
}
"#,
        );
    }

    // `Self` is replaced by the type including its lifetime and generic arguments.
    #[test]
    fn test_fixes_self_references_with_lifetimes_and_generics() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo<'a, T> {
    $0field: Bar<'a, T>,
}

struct Bar<'a, T>(&'a T);

impl<'a, T> Bar<'a, T> {
    fn bar(self, mut b: Vec<&'a Self>) -> &'a Self {
        b.pop().unwrap()
    }
}
"#,
            r#"
struct Foo<'a, T> {
    field: Bar<'a, T>,
}

impl<'a, T> Foo<'a, T> {
    $0fn bar(self, mut b: Vec<&'a Bar<'a, T>>) -> &'a Bar<'a, T> {
        self.field.bar(b)
    }
}

struct Bar<'a, T>(&'a T);

impl<'a, T> Bar<'a, T> {
    fn bar(self, mut b: Vec<&'a Self>) -> &'a Self {
        b.pop().unwrap()
    }
}
"#,
        );
    }

    // Method defined by a macro in another module: types in the delegate are qualified
    // relative to the struct's module (`bar::Bar`). The expected text pins the spacing
    // the assist currently emits for macro-expanded signatures.
    #[test]
    fn test_fixes_self_references_across_macros() {
        check_assist(
            generate_delegate_methods,
            r#"
//- /bar.rs
macro_rules! test_method {
    () => {
        pub fn test(self, b: Bar) -> Self {
            self
        }
    };
}

pub struct Bar;

impl Bar {
    test_method!();
}

//- /main.rs
mod bar;

struct Foo {
    $0bar: bar::Bar,
}
"#,
            r#"
mod bar;

struct Foo {
    bar: bar::Bar,
}

impl Foo {
    $0pub fn test(self,b:bar::Bar) ->bar::Bar {
        self.bar.test(b)
    }
}
"#,
        );
    }

    // A private method of a type in another module is not visible from `Person`'s
    // module, so nothing is offered.
    #[test]
    fn test_generate_delegate_visibility() {
        check_assist_not_applicable(
            generate_delegate_methods,
            r#"
mod m {
    pub struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }
}

struct Person {
    ag$0e: m::Age,
}"#,
        )
    }

    // `Person` already has a method named `age`, so no delegate is offered for it.
    #[test]
    fn test_generate_not_eligible_if_fn_exists() {
        check_assist_not_applicable(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}
impl Person {
    fn age(&self) -> u8 { 0 }
}
"#,
        );
    }

    // The assist bails out when code-action grouping is disabled.
    #[test]
    fn delegate_method_skipped_when_no_grouping() {
        check_assist_not_applicable_no_grouping(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct Person {
    ag$0e: Age,
}"#,
        );
    }
}