1use hir::HasVisibility;
2use ide_db::{FxHashSet, path_transform::PathTransform};
3use syntax::{
4 ast::{
5 self, AstNode, HasGenericParams, HasName, HasVisibility as _,
6 edit::{AstNodeEdit, IndentLevel},
7 make,
8 },
9 syntax_editor::Position,
10};
11
12use crate::{
13 AssistContext, AssistId, AssistKind, Assists, GroupLabel,
14 utils::{convert_param_list_to_arg_list, find_struct_impl},
15};
16
17pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
53 if !ctx.config.code_action_grouping {
54 return None;
55 }
56
57 let strukt = ctx.find_node_at_offset::<ast::Struct>()?;
58 let strukt_name = strukt.name()?;
59 let current_module = ctx.sema.scope(strukt.syntax())?.module();
60 let current_edition = current_module.krate(ctx.db()).edition(ctx.db());
61
62 let (field_name, field_ty, target) = match ctx.find_node_at_offset::<ast::RecordField>() {
63 Some(field) => {
64 let field_name = field.name()?;
65 let field_ty = field.ty()?;
66 (field_name.to_string(), field_ty, field.syntax().text_range())
67 }
68 None => {
69 let field = ctx.find_node_at_offset::<ast::TupleField>()?;
70 let field_list = ctx.find_node_at_offset::<ast::TupleFieldList>()?;
71 let field_list_index = field_list.fields().position(|it| it == field)?;
72 let field_ty = field.ty()?;
73 (field_list_index.to_string(), field_ty, field.syntax().text_range())
74 }
75 };
76
77 let sema_field_ty = ctx.sema.resolve_type(&field_ty)?;
78 let mut methods = vec![];
79 let mut seen_names = FxHashSet::default();
80
81 for ty in sema_field_ty.autoderef(ctx.db()) {
82 ty.iterate_assoc_items(ctx.db(), |item| {
83 if let hir::AssocItem::Function(f) = item {
84 let name = f.name(ctx.db());
85 if f.self_param(ctx.db()).is_some()
86 && f.is_visible_from(ctx.db(), current_module)
87 && seen_names.insert(name.clone())
88 {
89 methods.push((name, f))
90 }
91 }
92 Option::<()>::None
93 });
94 }
95 methods.sort_by(|(a, _), (b, _)| a.cmp(b));
96 for (index, (name, method)) in methods.into_iter().enumerate() {
97 let adt = ast::Adt::Struct(strukt.clone());
98 let name = name.display(ctx.db(), current_edition).to_string();
99 let Some(impl_def) = find_struct_impl(ctx, &adt, std::slice::from_ref(&name)) else {
101 continue;
102 };
103 let field = make::ext::field_from_idents(["self", &field_name])?;
104
105 acc.add_group(
106 &GroupLabel("Generate delegate methods…".to_owned()),
107 AssistId("generate_delegate_methods", AssistKind::Generate, Some(index)),
108 format!("Generate delegate for `{field_name}.{name}()`",),
109 target,
110 |edit| {
111 let method_source = match ctx.sema.source(method) {
113 Some(source) => {
114 let v = source.value.clone_for_update();
115 let source_scope = ctx.sema.scope(v.syntax());
116 let target_scope = ctx.sema.scope(strukt.syntax());
117 if let (Some(s), Some(t)) = (source_scope, target_scope) {
118 ast::Fn::cast(
119 PathTransform::generic_transformation(&t, &s).apply(v.syntax()),
120 )
121 .unwrap_or(v)
122 } else {
123 v
124 }
125 }
126 None => return,
127 };
128
129 let vis = method_source.visibility();
130 let is_async = method_source.async_token().is_some();
131 let is_const = method_source.const_token().is_some();
132 let is_unsafe = method_source.unsafe_token().is_some();
133 let is_gen = method_source.gen_token().is_some();
134
135 let fn_name = make::name(&name);
136
137 let type_params = method_source.generic_param_list();
138 let where_clause = method_source.where_clause();
139 let params =
140 method_source.param_list().unwrap_or_else(|| make::param_list(None, []));
141
142 let arg_list = method_source
144 .param_list()
145 .map(convert_param_list_to_arg_list)
146 .unwrap_or_else(|| make::arg_list([]));
147
148 let tail_expr =
149 make::expr_method_call(field, make::name_ref(&name), arg_list).into();
150 let tail_expr_finished =
151 if is_async { make::expr_await(tail_expr) } else { tail_expr };
152 let body = make::block_expr([], Some(tail_expr_finished));
153
154 let ret_type = method_source.ret_type();
155
156 let f = make::fn_(
157 None,
158 vis,
159 fn_name,
160 type_params,
161 where_clause,
162 params,
163 body,
164 ret_type,
165 is_async,
166 is_const,
167 is_unsafe,
168 is_gen,
169 )
170 .indent(IndentLevel(1));
171 let item = ast::AssocItem::Fn(f.clone());
172
173 let mut editor = edit.make_editor(strukt.syntax());
174 let fn_: Option<ast::AssocItem> = match impl_def {
175 Some(impl_def) => match impl_def.assoc_item_list() {
176 Some(assoc_item_list) => {
177 let item = item.indent(IndentLevel::from_node(impl_def.syntax()));
178 assoc_item_list.add_items(&mut editor, vec![item.clone()]);
179 Some(item)
180 }
181 None => {
182 let assoc_item_list = make::assoc_item_list(Some(vec![item]));
183 editor.insert(
184 Position::last_child_of(impl_def.syntax()),
185 assoc_item_list.syntax(),
186 );
187 assoc_item_list.assoc_items().next()
188 }
189 },
190 None => {
191 let name = &strukt_name.to_string();
192 let ty_params = strukt.generic_param_list();
193 let ty_args = ty_params.as_ref().map(|it| it.to_generic_args());
194 let where_clause = strukt.where_clause();
195 let assoc_item_list = make::assoc_item_list(Some(vec![item]));
196
197 let impl_def = make::impl_(
198 None,
199 ty_params,
200 ty_args,
201 make::ty_path(make::ext::ident_path(name)),
202 where_clause,
203 Some(assoc_item_list),
204 )
205 .clone_for_update();
206
207 let indent = strukt.indent_level();
209 let impl_def = impl_def.indent(indent);
210
211 let strukt = edit.make_mut(strukt.clone());
213 editor.insert_all(
214 Position::after(strukt.syntax()),
215 vec![
216 make::tokens::whitespace(&format!("\n\n{indent}")).into(),
217 impl_def.syntax().clone().into(),
218 ],
219 );
220 impl_def.assoc_item_list().and_then(|list| list.assoc_items().next())
221 }
222 };
223
224 if let Some(cap) = ctx.config.snippet_cap
225 && let Some(fn_) = fn_
226 {
227 let tabstop = edit.make_tabstop_before(cap);
228 editor.add_annotation(fn_.syntax(), tabstop);
229 }
230 edit.add_file_edits(ctx.vfs_file_id(), editor);
231 },
232 )?;
233 }
234 Some(())
235}
236
#[cfg(test)]
mod tests {
    // Fixture convention: in the "before" text `$0` marks the cursor position; in the "after"
    // text `$0` marks the snippet tabstop placed before the generated method.
    use crate::tests::{
        check_assist, check_assist_not_applicable, check_assist_not_applicable_no_grouping,
    };

    use super::*;

    // No impl block exists for the struct: a new `impl Person` is appended after it.
    #[test]
    fn test_generate_delegate_create_impl_block() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    age: Age,
}

impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // A newly created impl block takes the indentation of the (nested) struct.
    #[test]
    fn test_generate_delegate_create_impl_block_match_indent() {
        check_assist(
            generate_delegate_methods,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        ag$0e: Age,
    }
}"#,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        age: Age,
    }

    impl Person {
        $0fn age(&self) -> u8 {
            self.age.age()
        }
    }
}"#,
        );
    }

    // An existing (empty) impl block for the struct is reused instead of creating a new one.
    #[test]
    fn test_generate_delegate_update_impl_block() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}

impl Person {}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    age: Age,
}

impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // Methods added to an existing nested impl block follow that block's indentation.
    #[test]
    fn test_generate_delegate_update_impl_block_match_indent() {
        check_assist(
            generate_delegate_methods,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        ag$0e: Age,
    }

    impl Person {}
}"#,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        age: Age,
    }

    impl Person {
        $0fn age(&self) -> u8 {
            self.age.age()
        }
    }
}"#,
        );
    }

    // Tuple fields are addressed by their positional index (`self.0`).
    #[test]
    fn test_generate_delegate_tuple_struct() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person(A$0ge);"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person(Age);

impl Person {
    $0fn age(&self) -> u8 {
        self.0.age()
    }
}"#,
        );
    }

    // Visibility, `async` (with `.await` on the call), generic params, and parameters are all
    // carried over; the struct's own generics are reused for the new impl.
    #[test]
    fn test_generate_delegate_enable_all_attributes() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age<T>(T);
impl<T> Age<T> {
    pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.0
    }
}

struct Person<T> {
    ag$0e: Age<T>,
}"#,
            r#"
struct Age<T>(T);
impl<T> Age<T> {
    pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.0
    }
}

struct Person<T> {
    age: Age<T>,
}

impl<T> Person<T> {
    $0pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.age.age(ty, arg).await
    }
}"#,
        );
    }

    // Methods reachable only through `Deref` on the field type are offered too.
    #[test]
    fn test_generates_delegate_autoderef() {
        check_assist(
            generate_delegate_methods,
            r#"
//- minicore: deref
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct AgeDeref(Age);
impl core::ops::Deref for AgeDeref { type Target = Age; }
struct Person {
    ag$0e: AgeDeref,
}
impl Person {}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct AgeDeref(Age);
impl core::ops::Deref for AgeDeref { type Target = Age; }
struct Person {
    age: AgeDeref,
}
impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // The delegate's where clause is copied verbatim (including its original line breaks).
    #[test]
    fn test_preserve_where_clause() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Inner<T>(T);
impl<T> Inner<T> {
    fn get(&self) -> T
    where
        T: Copy,
        T: PartialEq,
    {
        self.0
    }
}

struct Struct<T> {
    $0field: Inner<T>,
}
"#,
            r#"
struct Inner<T>(T);
impl<T> Inner<T> {
    fn get(&self) -> T
    where
        T: Copy,
        T: PartialEq,
    {
        self.0
    }
}

struct Struct<T> {
    field: Inner<T>,
}

impl<T> Struct<T> {
    $0fn get(&self) -> T where
            T: Copy,
            T: PartialEq, {
        self.field.get()
    }
}
"#,
        );
    }

    // `Self` in the delegate's signature is rewritten to the field's concrete type.
    #[test]
    fn test_fixes_basic_self_references() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo {
    field: $0Bar,
}

struct Bar;

impl Bar {
    fn bar(&self, other: Self) -> Self {
        other
    }
}
"#,
            r#"
struct Foo {
    field: Bar,
}

impl Foo {
    $0fn bar(&self, other: Bar) -> Bar {
        self.field.bar(other)
    }
}

struct Bar;

impl Bar {
    fn bar(&self, other: Self) -> Self {
        other
    }
}
"#,
        );
    }

    // `Self` nested inside tuples, arrays, and generic arguments is rewritten as well.
    #[test]
    fn test_fixes_nested_self_references() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo {
    field: $0Bar,
}

struct Bar;

impl Bar {
    fn bar(&mut self, a: (Self, [Self; 4]), b: Vec<Self>) {}
}
"#,
            r#"
struct Foo {
    field: Bar,
}

impl Foo {
    $0fn bar(&mut self, a: (Bar, [Bar; 4]), b: Vec<Bar>) {
        self.field.bar(a, b)
    }
}

struct Bar;

impl Bar {
    fn bar(&mut self, a: (Self, [Self; 4]), b: Vec<Self>) {}
}
"#,
        );
    }

    // `Self` expands to the full generic type, lifetimes included.
    #[test]
    fn test_fixes_self_references_with_lifetimes_and_generics() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo<'a, T> {
    $0field: Bar<'a, T>,
}

struct Bar<'a, T>(&'a T);

impl<'a, T> Bar<'a, T> {
    fn bar(self, mut b: Vec<&'a Self>) -> &'a Self {
        b.pop().unwrap()
    }
}
"#,
            r#"
struct Foo<'a, T> {
    field: Bar<'a, T>,
}

impl<'a, T> Foo<'a, T> {
    $0fn bar(self, mut b: Vec<&'a Bar<'a, T>>) -> &'a Bar<'a, T> {
        self.field.bar(b)
    }
}

struct Bar<'a, T>(&'a T);

impl<'a, T> Bar<'a, T> {
    fn bar(self, mut b: Vec<&'a Self>) -> &'a Self {
        b.pop().unwrap()
    }
}
"#,
        );
    }

    // Methods defined by a macro in another module get paths qualified for the struct's
    // module; macro-expanded text loses its original spacing.
    #[test]
    fn test_fixes_self_references_across_macros() {
        check_assist(
            generate_delegate_methods,
            r#"
//- /bar.rs
macro_rules! test_method {
    () => {
        pub fn test(self, b: Bar) -> Self {
            self
        }
    };
}

pub struct Bar;

impl Bar {
    test_method!();
}

//- /main.rs
mod bar;

struct Foo {
    $0bar: bar::Bar,
}
"#,
            r#"
mod bar;

struct Foo {
    bar: bar::Bar,
}

impl Foo {
    $0pub fn test(self,b:bar::Bar) ->bar::Bar {
        self.bar.test(b)
    }
}
"#,
        );
    }

    // Methods not visible from the struct's module (here: private to `mod m`) are not offered.
    #[test]
    fn test_generate_delegate_visibility() {
        check_assist_not_applicable(
            generate_delegate_methods,
            r#"
mod m {
    pub struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }
}

struct Person {
    ag$0e: m::Age,
}"#,
        )
    }

    // A method whose name already exists on the struct is skipped.
    #[test]
    fn test_generate_not_eligible_if_fn_exists() {
        check_assist_not_applicable(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}
impl Person {
    fn age(&self) -> u8 { 0 }
}
"#,
        );
    }

    // Without client support for grouped code actions the assist is not offered at all.
    #[test]
    fn delegate_method_skipped_when_no_grouping() {
        check_assist_not_applicable_no_grouping(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct Person {
    ag$0e: Age,
}"#,
        );
    }
}