Skip to main content

ide_assists/handlers/
generate_mut_trait_impl.rs

1use ide_db::{famous_defs::FamousDefs, traits::resolve_target_trait};
2use syntax::{
3    AstNode, SyntaxElement, SyntaxNode, T,
4    ast::{self, edit::AstNodeEdit, syntax_factory::SyntaxFactory},
5    syntax_editor::{Element, Position, SyntaxEditor},
6};
7
8use crate::{AssistContext, AssistId, Assists};
9
// FIXME: Generating a proper `index_mut` body from the `index` body may be impossible in general because of unpredictable cases [#15581].
// For now the `index_mut` body is copied from the `index` body (only a tail `&expr` is rewritten to `&mut expr`); users can adjust the rest manually.
12
13// Assist: generate_mut_trait_impl
14//
// Adds the mutable counterpart (`IndexMut`, `AsMut`, `BorrowMut` or `DerefMut`) of an `Index`, `AsRef`, `Borrow` or `Deref` impl.
16//
17// ```
18// # //- minicore: index
19// pub enum Axis { X = 0, Y = 1, Z = 2 }
20//
21// impl<T> core::ops::Index$0<Axis> for [T; 3] {
22//     type Output = T;
23//
24//     fn index(&self, index: Axis) -> &Self::Output {
25//         &self[index as usize]
26//     }
27// }
28// ```
29// ->
30// ```
31// pub enum Axis { X = 0, Y = 1, Z = 2 }
32//
33// $0impl<T> core::ops::IndexMut<Axis> for [T; 3] {
34//     fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
35//         &mut self[index as usize]
36//     }
37// }
38//
39// impl<T> core::ops::Index<Axis> for [T; 3] {
40//     type Output = T;
41//
42//     fn index(&self, index: Axis) -> &Self::Output {
43//         &self[index as usize]
44//     }
45// }
46// ```
47pub(crate) fn generate_mut_trait_impl(
48    acc: &mut Assists,
49    ctx: &AssistContext<'_, '_>,
50) -> Option<()> {
51    let impl_def = ctx.find_node_at_offset::<ast::Impl>()?;
52    let indent = impl_def.indent_level();
53
54    let ast::Type::PathType(path) = impl_def.trait_()? else {
55        return None;
56    };
57
58    let trait_name = path.path()?.segment()?.name_ref()?;
59
60    let scope = ctx.sema.scope(impl_def.trait_()?.syntax())?;
61    let famous = FamousDefs(&ctx.sema, scope.krate());
62
63    let trait_ = resolve_target_trait(&ctx.sema, &impl_def)?;
64    let trait_new = get_trait_mut(&trait_, famous)?;
65
66    let target = impl_def.syntax().text_range();
67
68    acc.add(
69        AssistId::generate("generate_mut_trait_impl"),
70        format!("Generate `{trait_new}` impl from this `{trait_name}` trait"),
71        target,
72        |edit| {
73            let (editor, impl_clone) = SyntaxEditor::with_ast_node(&impl_def.reset_indent());
74
75            apply_generate_mut_impl(&editor, &impl_clone, trait_new);
76
77            let new_root = editor.finish();
78            let new_root = new_root.new_root();
79
80            let new_impl = ast::Impl::cast(new_root.clone()).unwrap();
81
82            let new_impl = new_impl.indent(indent);
83
84            let editor = edit.make_editor(impl_def.syntax());
85            let make = editor.make();
86            editor.insert_all(
87                Position::before(impl_def.syntax()),
88                vec![
89                    new_impl.syntax().syntax_element(),
90                    make.whitespace(&format!("\n\n{indent}")).syntax_element(),
91                ],
92            );
93
94            if let Some(cap) = ctx.config.snippet_cap {
95                let tabstop_before = edit.make_tabstop_before(cap);
96                editor.add_annotation(new_impl.syntax(), tabstop_before);
97            }
98
99            edit.add_file_edits(ctx.vfs_file_id(), editor);
100        },
101    )
102}
103
104fn delete_with_trivia(editor: &SyntaxEditor, node: &SyntaxNode) {
105    let mut end: SyntaxElement = node.clone().into();
106
107    if let Some(next) = node.next_sibling_or_token()
108        && let SyntaxElement::Token(tok) = &next
109        && tok.kind().is_trivia()
110    {
111        end = next.clone();
112    }
113
114    editor.delete_all(node.clone().into()..=end);
115}
116
117fn apply_generate_mut_impl(
118    editor: &SyntaxEditor,
119    impl_def: &ast::Impl,
120    trait_new: &str,
121) -> Option<()> {
122    let make = editor.make();
123    let path =
124        impl_def.trait_().and_then(|t| t.syntax().descendants().find_map(ast::Path::cast))?;
125    let seg = path.segment()?;
126    let name_ref = seg.name_ref()?;
127
128    let new_name_ref = make.name_ref(trait_new);
129    editor.replace(name_ref.syntax(), new_name_ref.syntax());
130
131    if let Some((name, new_name)) =
132        impl_def.syntax().descendants().filter_map(ast::Name::cast).find_map(process_method_name)
133    {
134        let new_name_node = make.name(new_name);
135        editor.replace(name.syntax(), new_name_node.syntax());
136    }
137
138    if let Some(type_alias) = impl_def.syntax().descendants().find_map(ast::TypeAlias::cast) {
139        delete_with_trivia(editor, type_alias.syntax());
140    }
141
142    if let Some(self_param) = impl_def.syntax().descendants().find_map(ast::SelfParam::cast) {
143        let mut_self = make.mut_self_param();
144        editor.replace(self_param.syntax(), mut_self.syntax());
145    }
146
147    if let Some(ret_type) = impl_def.syntax().descendants().find_map(ast::RetType::cast)
148        && let Some(new_ty) = process_ret_type(make, &ret_type)
149    {
150        let new_ret = make.ret_type(new_ty);
151        editor.replace(ret_type.syntax(), new_ret.syntax())
152    }
153
154    if let Some(fn_) = impl_def.assoc_item_list().and_then(|l| {
155        l.assoc_items().find_map(|it| match it {
156            ast::AssocItem::Fn(f) => Some(f),
157            _ => None,
158        })
159    }) {
160        process_ref_mut(editor, &fn_);
161    }
162
163    Some(())
164}
165
166fn process_ref_mut(editor: &SyntaxEditor, fn_: &ast::Fn) {
167    let make = editor.make();
168    let Some(expr) = fn_.body().and_then(|b| b.tail_expr()) else { return };
169
170    let ast::Expr::RefExpr(ref_expr) = expr else { return };
171
172    if ref_expr.mut_token().is_some() {
173        return;
174    }
175
176    let Some(amp) = ref_expr.amp_token() else { return };
177
178    let mut_kw = make.token(T![mut]);
179    let space = make.whitespace(" ");
180
181    editor.insert(Position::after(amp.clone()), space.syntax_element());
182    editor.insert(Position::after(amp), mut_kw.syntax_element());
183}
184
185fn process_ret_type(factory: &SyntaxFactory, ref_ty: &ast::RetType) -> Option<ast::Type> {
186    let ty = ref_ty.ty()?;
187    let ast::Type::RefType(ref_type) = ty else {
188        return None;
189    };
190
191    let inner = ref_type.ty()?;
192    Some(factory.ty_ref(inner, true))
193}
194
195fn get_trait_mut(apply_trait: &hir::Trait, famous: FamousDefs<'_, '_>) -> Option<&'static str> {
196    let trait_ = Some(apply_trait);
197    if trait_ == famous.core_convert_Index().as_ref() {
198        return Some("IndexMut");
199    }
200    if trait_ == famous.core_convert_AsRef().as_ref() {
201        return Some("AsMut");
202    }
203    if trait_ == famous.core_borrow_Borrow().as_ref() {
204        return Some("BorrowMut");
205    }
206    if trait_ == famous.core_ops_Deref().as_ref() {
207        return Some("DerefMut");
208    }
209    None
210}
211
212fn process_method_name(name: ast::Name) -> Option<(ast::Name, &'static str)> {
213    let new_name = match &*name.text() {
214        "index" => "index_mut",
215        "as_ref" => "as_mut",
216        "borrow" => "borrow_mut",
217        "deref" => "deref_mut",
218        _ => return None,
219    };
220    Some((name, new_name))
221}
222
#[cfg(test)]
mod tests {
    use crate::{
        AssistConfig,
        tests::{TEST_CONFIG, check_assist, check_assist_not_applicable, check_assist_with_config},
    };

    use super::*;

    // One case per supported shared trait. In each case the generated impl is inserted before
    // the original one, with `$0` marking the tabstop.
    #[test]
    fn test_generate_mut_trait_impl() {
        // `Index` with a tail `&expr`: the type alias is dropped, and the signature and tail
        // expression become `&mut`.
        check_assist(
            generate_mut_trait_impl,
            r#"
//- minicore: index
pub enum Axis { X = 0, Y = 1, Z = 2 }

impl<T> core::ops::Index$0<Axis> for [T; 3] {
    type Output = T;

    fn index(&self, index: Axis) -> &Self::Output {
        &self[index as usize]
    }
}
"#,
            r#"
pub enum Axis { X = 0, Y = 1, Z = 2 }

$0impl<T> core::ops::IndexMut<Axis> for [T; 3] {
    fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
        &mut self[index as usize]
    }
}

impl<T> core::ops::Index<Axis> for [T; 3] {
    type Output = T;

    fn index(&self, index: Axis) -> &Self::Output {
        &self[index as usize]
    }
}
"#,
        );

        // `Index` with a where-clause and a tail that is not a `&expr`: the where-clause is
        // kept and the body is copied unchanged (see the FIXME at the top of the file).
        check_assist(
            generate_mut_trait_impl,
            r#"
//- minicore: index
pub enum Axis { X = 0, Y = 1, Z = 2 }

impl<T> core::ops::Index$0<Axis> for [T; 3] where T: Copy {
    type Output = T;

    fn index(&self, index: Axis) -> &Self::Output {
        let var_name = &self[index as usize];
        var_name
    }
}
"#,
            r#"
pub enum Axis { X = 0, Y = 1, Z = 2 }

$0impl<T> core::ops::IndexMut<Axis> for [T; 3] where T: Copy {
    fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
        let var_name = &self[index as usize];
        var_name
    }
}

impl<T> core::ops::Index<Axis> for [T; 3] where T: Copy {
    type Output = T;

    fn index(&self, index: Axis) -> &Self::Output {
        let var_name = &self[index as usize];
        var_name
    }
}
"#,
        );

        // `AsRef` -> `AsMut`: no associated type to remove.
        check_assist(
            generate_mut_trait_impl,
            r#"
//- minicore: as_ref
struct Foo(i32);

impl core::convert::AsRef$0<i32> for Foo {
    fn as_ref(&self) -> &i32 {
        &self.0
    }
}
"#,
            r#"
struct Foo(i32);

$0impl core::convert::AsMut<i32> for Foo {
    fn as_mut(&mut self) -> &mut i32 {
        &mut self.0
    }
}

impl core::convert::AsRef<i32> for Foo {
    fn as_ref(&self) -> &i32 {
        &self.0
    }
}
"#,
        );

        // `Deref` -> `DerefMut`: `type Target` is dropped from the generated impl.
        check_assist(
            generate_mut_trait_impl,
            r#"
//- minicore: deref
struct Foo(i32);

impl core::ops::Deref$0 for Foo {
    type Target = i32;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
"#,
            r#"
struct Foo(i32);

$0impl core::ops::DerefMut for Foo {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl core::ops::Deref for Foo {
    type Target = i32;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
"#,
        );
    }

    // The generated impl must be re-indented to the nesting level of the original impl.
    #[test]
    fn test_generate_mut_trait_impl_non_zero_indent() {
        // One level of module nesting.
        check_assist(
            generate_mut_trait_impl,
            r#"
//- minicore: index
mod foo {
    pub enum Axis { X = 0, Y = 1, Z = 2 }

    impl<T> core::ops::Index$0<Axis> for [T; 3] where T: Copy {
        type Output = T;

        fn index(&self, index: Axis) -> &Self::Output {
            let var_name = &self[index as usize];
            var_name
        }
    }
}
"#,
            r#"
mod foo {
    pub enum Axis { X = 0, Y = 1, Z = 2 }

    $0impl<T> core::ops::IndexMut<Axis> for [T; 3] where T: Copy {
        fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
            let var_name = &self[index as usize];
            var_name
        }
    }

    impl<T> core::ops::Index<Axis> for [T; 3] where T: Copy {
        type Output = T;

        fn index(&self, index: Axis) -> &Self::Output {
            let var_name = &self[index as usize];
            var_name
        }
    }
}
"#,
        );

        // Two levels of module nesting.
        check_assist(
            generate_mut_trait_impl,
            r#"
//- minicore: index
mod foo {
    mod bar {
        pub enum Axis { X = 0, Y = 1, Z = 2 }

        impl<T> core::ops::Index$0<Axis> for [T; 3] where T: Copy {
            type Output = T;

            fn index(&self, index: Axis) -> &Self::Output {
                let var_name = &self[index as usize];
                var_name
            }
        }
    }
}
"#,
            r#"
mod foo {
    mod bar {
        pub enum Axis { X = 0, Y = 1, Z = 2 }

        $0impl<T> core::ops::IndexMut<Axis> for [T; 3] where T: Copy {
            fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
                let var_name = &self[index as usize];
                var_name
            }
        }

        impl<T> core::ops::Index<Axis> for [T; 3] where T: Copy {
            type Output = T;

            fn index(&self, index: Axis) -> &Self::Output {
                let var_name = &self[index as usize];
                var_name
            }
        }
    }
}
"#,
        );
    }

    // Locally declared traits that merely share a name with the core traits must not trigger
    // the assist; only the traits resolved through `FamousDefs` qualify.
    #[test]
    fn test_generate_mut_trait_impl_not_applicable() {
        check_assist_not_applicable(
            generate_mut_trait_impl,
            r#"
pub trait Index<Idx: ?Sized> {}

impl<T> Index$0<i32> for [T; 3] {}
"#,
        );
        check_assist_not_applicable(
            generate_mut_trait_impl,
            r#"
pub trait AsRef<T: ?Sized> {}

impl AsRef$0<i32> for [T; 3] {}
"#,
        );
    }

    // Without snippet support, no `$0` tabstop is emitted before the generated impl.
    #[test]
    fn no_snippets() {
        check_assist_with_config(
            generate_mut_trait_impl,
            AssistConfig { snippet_cap: None, ..TEST_CONFIG },
            r#"
//- minicore: index
pub enum Axis { X = 0, Y = 1, Z = 2 }

impl<T> core::ops::Index$0<Axis> for [T; 3] {
    type Output = T;

    fn index(&self, index: Axis) -> &Self::Output {
        &self[index as usize]
    }
}
"#,
            r#"
pub enum Axis { X = 0, Y = 1, Z = 2 }

impl<T> core::ops::IndexMut<Axis> for [T; 3] {
    fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
        &mut self[index as usize]
    }
}

impl<T> core::ops::Index<Axis> for [T; 3] {
    type Output = T;

    fn index(&self, index: Axis) -> &Self::Output {
        &self[index as usize]
    }
}
"#,
        );
    }
}