1use hir::HasVisibility;
2use ide_db::{FxHashSet, path_transform::PathTransform};
3use syntax::{
4 ast::{
5 self, AstNode, HasGenericParams, HasName, HasVisibility as _,
6 edit::{AstNodeEdit, IndentLevel},
7 },
8 syntax_editor::Position,
9};
10
11use crate::{
12 AssistContext, AssistId, AssistKind, Assists, GroupLabel,
13 utils::{convert_param_list_to_arg_list, find_struct_impl},
14};
15
16pub(crate) fn generate_delegate_methods(
52 acc: &mut Assists,
53 ctx: &AssistContext<'_, '_>,
54) -> Option<()> {
55 if !ctx.config.code_action_grouping {
56 return None;
57 }
58
59 let strukt = ctx.find_node_at_offset::<ast::Struct>()?;
60 let strukt_name = strukt.name()?;
61 let current_module = ctx.sema.scope(strukt.syntax())?.module();
62 let current_edition = current_module.krate(ctx.db()).edition(ctx.db());
63
64 let (field_name, field_ty, target) = match ctx.find_node_at_offset::<ast::RecordField>() {
65 Some(field) => {
66 let field_name = field.name()?;
67 let field_ty = field.ty()?;
68 (field_name.to_string(), field_ty, field.syntax().text_range())
69 }
70 None => {
71 let field = ctx.find_node_at_offset::<ast::TupleField>()?;
72 let field_list = ctx.find_node_at_offset::<ast::TupleFieldList>()?;
73 let field_list_index = field_list.fields().position(|it| it == field)?;
74 let field_ty = field.ty()?;
75 (field_list_index.to_string(), field_ty, field.syntax().text_range())
76 }
77 };
78
79 let sema_field_ty = ctx.sema.resolve_type(&field_ty)?;
80 let mut methods = vec![];
81 let mut seen_names = FxHashSet::default();
82
83 for ty in sema_field_ty.autoderef(ctx.db()) {
84 ty.iterate_assoc_items(ctx.db(), |item| {
85 if let hir::AssocItem::Function(f) = item {
86 let name = f.name(ctx.db());
87 if f.self_param(ctx.db()).is_some()
88 && f.is_visible_from(ctx.db(), current_module)
89 && seen_names.insert(name.clone())
90 {
91 methods.push((name, f))
92 }
93 }
94 Option::<()>::None
95 });
96 }
97 methods.sort_by(|(a, _), (b, _)| a.cmp(b));
98 for (index, (name, method)) in methods.into_iter().enumerate() {
99 let adt = ast::Adt::Struct(strukt.clone());
100 let name = name.display(ctx.db(), current_edition).to_string();
101 let Some(impl_def) = find_struct_impl(ctx, &adt, std::slice::from_ref(&name)) else {
103 continue;
104 };
105
106 acc.add_group(
107 &GroupLabel("Generate delegate methods…".to_owned()),
108 AssistId("generate_delegate_methods", AssistKind::Generate, Some(index)),
109 format!("Generate delegate for `{field_name}.{name}()`",),
110 target,
111 |edit| {
112 let editor = edit.make_editor(strukt.syntax());
113 let make = editor.make();
114 let field = make
115 .field_from_idents(["self", &field_name])
116 .expect("always be a valid expression");
117 let method_source = match ctx.sema.source(method) {
119 Some(source) => {
120 let v = source.value;
121 let source_scope = ctx.sema.scope(v.syntax());
122 let target_scope = ctx.sema.scope(strukt.syntax());
123 if let (Some(s), Some(t)) = (source_scope, target_scope) {
124 ast::Fn::cast(
125 PathTransform::generic_transformation(&t, &s).apply(v.syntax()),
126 )
127 .unwrap_or(v)
128 } else {
129 v
130 }
131 }
132 None => return,
133 };
134
135 let vis = method_source.visibility();
136 let is_async = method_source.async_token().is_some();
137 let is_const = method_source.const_token().is_some();
138 let is_unsafe = method_source.unsafe_token().is_some();
139 let is_gen = method_source.gen_token().is_some();
140
141 let fn_name = make.name(&name);
142
143 let type_params = method_source.generic_param_list();
144 let where_clause = method_source.where_clause();
145 let params =
146 method_source.param_list().unwrap_or_else(|| make.param_list(None, []));
147
148 let arg_list = method_source
150 .param_list()
151 .map(|v| convert_param_list_to_arg_list(v, make))
152 .unwrap_or_else(|| make.arg_list([]));
153
154 let tail_expr = make.expr_method_call(field, make.name_ref(&name), arg_list).into();
155 let tail_expr_finished =
156 if is_async { make.expr_await(tail_expr).into() } else { tail_expr };
157 let body = make.block_expr([], Some(tail_expr_finished));
158
159 let ret_type = method_source.ret_type();
160
161 let f = make
162 .fn_(
163 None,
164 vis,
165 fn_name,
166 type_params,
167 where_clause,
168 params,
169 body,
170 ret_type,
171 is_async,
172 is_const,
173 is_unsafe,
174 is_gen,
175 )
176 .indent(IndentLevel(1));
177 let item = ast::AssocItem::Fn(f.clone());
178
179 let fn_: Option<ast::AssocItem> = match impl_def {
180 Some(impl_def) => match impl_def.assoc_item_list() {
181 Some(assoc_item_list) => {
182 let item = item.indent(IndentLevel::from_node(impl_def.syntax()));
183 assoc_item_list.add_items(&editor, vec![item.clone()]);
184 Some(item)
185 }
186 None => {
187 let assoc_item_list = make.assoc_item_list(vec![item]);
188 editor.insert(
189 Position::last_child_of(impl_def.syntax()),
190 assoc_item_list.syntax(),
191 );
192 assoc_item_list.assoc_items().next()
193 }
194 },
195 None => {
196 let name = &strukt_name.to_string();
197 let ty_params = strukt.generic_param_list();
198 let ty_args = ty_params.as_ref().map(|it| it.to_generic_args(make));
199 let where_clause = strukt.where_clause();
200 let assoc_item_list = make.assoc_item_list(vec![item]);
201
202 let impl_def = make.impl_(
203 None,
204 ty_params,
205 ty_args,
206 syntax::ast::Type::PathType(make.ty_path(make.ident_path(name))),
207 where_clause,
208 Some(assoc_item_list),
209 );
210
211 let indent = strukt.indent_level();
213 let impl_def = impl_def.indent(indent);
214
215 editor.insert_all(
217 Position::after(strukt.syntax()),
218 vec![
219 make.whitespace(&format!("\n\n{indent}")).into(),
220 impl_def.syntax().clone().into(),
221 ],
222 );
223 impl_def.assoc_item_list().and_then(|list| list.assoc_items().next())
224 }
225 };
226
227 if let Some(cap) = ctx.config.snippet_cap
228 && let Some(fn_) = fn_
229 {
230 let tabstop = edit.make_tabstop_before(cap);
231 editor.add_annotation(fn_.syntax(), tabstop);
232 }
233 edit.add_file_edits(ctx.vfs_file_id(), editor);
234 },
235 )?;
236 }
237 Some(())
238}
239
#[cfg(test)]
mod tests {
    // Fixtures below are compared verbatim against the assist's output, so the
    // whitespace inside the raw strings is significant. `$0` marks the cursor
    // in inputs and the snippet tabstop in expected outputs.
    use crate::tests::{
        check_assist, check_assist_not_applicable, check_assist_not_applicable_no_grouping,
    };

    use super::*;

    // With no `impl Person` present, a new impl block is created after the struct.
    #[test]
    fn test_generate_delegate_create_impl_block() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    age: Age,
}

impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // The newly created impl block follows the indentation of a nested struct.
    #[test]
    fn test_generate_delegate_create_impl_block_match_indent() {
        check_assist(
            generate_delegate_methods,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        ag$0e: Age,
    }
}"#,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        age: Age,
    }

    impl Person {
        $0fn age(&self) -> u8 {
            self.age.age()
        }
    }
}"#,
        );
    }

    // An existing empty `impl Person {}` is extended instead of creating a new one.
    #[test]
    fn test_generate_delegate_update_impl_block() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}

impl Person {}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    age: Age,
}

impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // Extending an existing nested impl block keeps its indentation.
    #[test]
    fn test_generate_delegate_update_impl_block_match_indent() {
        check_assist(
            generate_delegate_methods,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        ag$0e: Age,
    }

    impl Person {}
}"#,
            r#"
mod indent {
    struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }

    struct Person {
        age: Age,
    }

    impl Person {
        $0fn age(&self) -> u8 {
            self.age.age()
        }
    }
}"#,
        );
    }

    // Tuple fields are delegated through their positional index (`self.0`).
    #[test]
    fn test_generate_delegate_tuple_struct() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person(A$0ge);"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person(Age);

impl Person {
    $0fn age(&self) -> u8 {
        self.0.age()
    }
}"#,
        );
    }

    // Visibility, `async`, generics, lifetimes and arguments are all carried
    // over; `async` methods get `.await` appended to the forwarded call.
    #[test]
    fn test_generate_delegate_enable_all_attributes() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Age<T>(T);
impl<T> Age<T> {
    pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.0
    }
}

struct Person<T> {
    ag$0e: Age<T>,
}"#,
            r#"
struct Age<T>(T);
impl<T> Age<T> {
    pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.0
    }
}

struct Person<T> {
    age: Age<T>,
}

impl<T> Person<T> {
    $0pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
        self.age.age(ty, arg).await
    }
}"#,
        );
    }

    // Methods reachable through a `Deref` target are offered as well.
    #[test]
    fn test_generates_delegate_autoderef() {
        check_assist(
            generate_delegate_methods,
            r#"
//- minicore: deref
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct AgeDeref(Age);
impl core::ops::Deref for AgeDeref { type Target = Age; }
struct Person {
    ag$0e: AgeDeref,
}
impl Person {}"#,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct AgeDeref(Age);
impl core::ops::Deref for AgeDeref { type Target = Age; }
struct Person {
    age: AgeDeref,
}
impl Person {
    $0fn age(&self) -> u8 {
        self.age.age()
    }
}"#,
        );
    }

    // The delegated-to method's where clause is copied into the new method.
    #[test]
    fn test_preserve_where_clause() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Inner<T>(T);
impl<T> Inner<T> {
    fn get(&self) -> T
    where
        T: Copy,
        T: PartialEq,
    {
        self.0
    }
}

struct Struct<T> {
    $0field: Inner<T>,
}
"#,
            r#"
struct Inner<T>(T);
impl<T> Inner<T> {
    fn get(&self) -> T
    where
        T: Copy,
        T: PartialEq,
    {
        self.0
    }
}

struct Struct<T> {
    field: Inner<T>,
}

impl<T> Struct<T> {
    $0fn get(&self) -> T where
            T: Copy,
            T: PartialEq, {
        self.field.get()
    }
}
"#,
        );
    }

    // `Self` in the delegated-to signature is rewritten to the concrete field type.
    #[test]
    fn test_fixes_basic_self_references() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo {
    field: $0Bar,
}

struct Bar;

impl Bar {
    fn bar(&self, other: Self) -> Self {
        other
    }
}
"#,
            r#"
struct Foo {
    field: Bar,
}

impl Foo {
    $0fn bar(&self, other: Bar) -> Bar {
        self.field.bar(other)
    }
}

struct Bar;

impl Bar {
    fn bar(&self, other: Self) -> Self {
        other
    }
}
"#,
        );
    }

    // `Self` nested inside tuples, arrays and generic arguments is rewritten too.
    #[test]
    fn test_fixes_nested_self_references() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo {
    field: $0Bar,
}

struct Bar;

impl Bar {
    fn bar(&mut self, a: (Self, [Self; 4]), b: Vec<Self>) {}
}
"#,
            r#"
struct Foo {
    field: Bar,
}

impl Foo {
    $0fn bar(&mut self, a: (Bar, [Bar; 4]), b: Vec<Bar>) {
        self.field.bar(a, b)
    }
}

struct Bar;

impl Bar {
    fn bar(&mut self, a: (Self, [Self; 4]), b: Vec<Self>) {}
}
"#,
        );
    }

    // `Self` rewriting keeps lifetime and type arguments of the field type.
    #[test]
    fn test_fixes_self_references_with_lifetimes_and_generics() {
        check_assist(
            generate_delegate_methods,
            r#"
struct Foo<'a, T> {
    $0field: Bar<'a, T>,
}

struct Bar<'a, T>(&'a T);

impl<'a, T> Bar<'a, T> {
    fn bar(self, mut b: Vec<&'a Self>) -> &'a Self {
        b.pop().unwrap()
    }
}
"#,
            r#"
struct Foo<'a, T> {
    field: Bar<'a, T>,
}

impl<'a, T> Foo<'a, T> {
    $0fn bar(self, mut b: Vec<&'a Bar<'a, T>>) -> &'a Bar<'a, T> {
        self.field.bar(b)
    }
}

struct Bar<'a, T>(&'a T);

impl<'a, T> Bar<'a, T> {
    fn bar(self, mut b: Vec<&'a Self>) -> &'a Self {
        b.pop().unwrap()
    }
}
"#,
        );
    }

    // Methods produced by a macro in another module get their paths qualified;
    // the expected output reflects the macro-expanded token spacing.
    #[test]
    fn test_fixes_self_references_across_macros() {
        check_assist(
            generate_delegate_methods,
            r#"
//- /bar.rs
macro_rules! test_method {
    () => {
        pub fn test(self, b: Bar) -> Self {
            self
        }
    };
}

pub struct Bar;

impl Bar {
    test_method!();
}

//- /main.rs
mod bar;

struct Foo {
    $0bar: bar::Bar,
}
"#,
            r#"
mod bar;

struct Foo {
    bar: bar::Bar,
}

impl Foo {
    $0pub fn test(self,b:bar::Bar) ->bar::Bar {
        self.bar.test(b)
    }
}
"#,
        );
    }

    // Private methods of a type defined in another module are not offered.
    #[test]
    fn test_generate_delegate_visibility() {
        check_assist_not_applicable(
            generate_delegate_methods,
            r#"
mod m {
    pub struct Age(u8);
    impl Age {
        fn age(&self) -> u8 {
            self.0
        }
    }
}

struct Person {
    ag$0e: m::Age,
}"#,
        )
    }

    // No delegate is offered when the struct already has a method of that name.
    #[test]
    fn test_generate_not_eligible_if_fn_exists() {
        check_assist_not_applicable(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}

struct Person {
    ag$0e: Age,
}
impl Person {
    fn age(&self) -> u8 { 0 }
}
"#,
        );
    }

    // The assist is suppressed entirely when code action grouping is disabled.
    #[test]
    fn delegate_method_skipped_when_no_grouping() {
        check_assist_not_applicable_no_grouping(
            generate_delegate_methods,
            r#"
struct Age(u8);
impl Age {
    fn age(&self) -> u8 {
        self.0
    }
}
struct Person {
    ag$0e: Age,
}"#,
        );
    }
}