ide_assists/handlers/
generate_trait_from_impl.rs

1use crate::assist_context::{AssistContext, Assists};
2use ide_db::assists::AssistId;
3use syntax::{
4    AstNode, SyntaxKind, T,
5    ast::{self, HasGenericParams, HasName, HasVisibility, edit_in_place::Indent, make},
6    syntax_editor::{Position, SyntaxEditor},
7};
8
// NOTE:
// We generate erroneous code if a function is declared `const` (E0379:
// functions in traits cannot be declared `const`). This is left to the user to
// correct, as our only option is to remove the function completely, which we
// should not be doing.
13
14// Assist: generate_trait_from_impl
15//
16// Generate trait for an already defined inherent impl and convert impl to a trait impl.
17//
18// ```
19// struct Foo<const N: usize>([i32; N]);
20//
21// macro_rules! const_maker {
22//     ($t:ty, $v:tt) => {
23//         const CONST: $t = $v;
24//     };
25// }
26//
27// impl<const N: usize> Fo$0o<N> {
28//     // Used as an associated constant.
29//     const CONST_ASSOC: usize = N * 4;
30//
31//     fn create() -> Option<()> {
32//         Some(())
33//     }
34//
35//     const_maker! {i32, 7}
36// }
37// ```
38// ->
39// ```
40// struct Foo<const N: usize>([i32; N]);
41//
42// macro_rules! const_maker {
43//     ($t:ty, $v:tt) => {
44//         const CONST: $t = $v;
45//     };
46// }
47//
48// trait ${0:NewTrait}<const N: usize> {
49//     // Used as an associated constant.
50//     const CONST_ASSOC: usize = N * 4;
51//
52//     fn create() -> Option<()>;
53//
54//     const_maker! {i32, 7}
55// }
56//
57// impl<const N: usize> ${0:NewTrait}<N> for Foo<N> {
58//     // Used as an associated constant.
59//     const CONST_ASSOC: usize = N * 4;
60//
61//     fn create() -> Option<()> {
62//         Some(())
63//     }
64//
65//     const_maker! {i32, 7}
66// }
67// ```
68pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
69    // Get AST Node
70    let impl_ast = ctx.find_node_at_offset::<ast::Impl>()?;
71
72    // Check if cursor is to the left of assoc item list's L_CURLY.
73    // if no L_CURLY then return.
74    let l_curly = impl_ast.assoc_item_list()?.l_curly_token()?;
75
76    let cursor_offset = ctx.offset();
77    let l_curly_offset = l_curly.text_range();
78    if cursor_offset >= l_curly_offset.start() {
79        return None;
80    }
81
82    // If impl is not inherent then we don't really need to go any further.
83    if impl_ast.for_token().is_some() {
84        return None;
85    }
86
87    let impl_assoc_items = impl_ast.assoc_item_list()?;
88    let first_element = impl_assoc_items.assoc_items().next();
89    first_element.as_ref()?;
90
91    let impl_name = impl_ast.self_ty()?;
92
93    acc.add(
94        AssistId::generate("generate_trait_from_impl"),
95        "Generate trait from impl",
96        impl_ast.syntax().text_range(),
97        |builder| {
98            let trait_items: ast::AssocItemList = {
99                let trait_items = impl_assoc_items.clone_subtree();
100                let mut trait_items_editor = SyntaxEditor::new(trait_items.syntax().clone());
101
102                trait_items.assoc_items().for_each(|item| {
103                    strip_body(&mut trait_items_editor, &item);
104                    remove_items_visibility(&mut trait_items_editor, &item);
105                });
106                ast::AssocItemList::cast(trait_items_editor.finish().new_root().clone()).unwrap()
107            };
108            let trait_ast = make::trait_(
109                false,
110                "NewTrait",
111                impl_ast.generic_param_list(),
112                impl_ast.where_clause(),
113                trait_items,
114            )
115            .clone_for_update();
116
117            let trait_name = trait_ast.name().expect("new trait should have a name");
118            let trait_name_ref = make::name_ref(&trait_name.to_string()).clone_for_update();
119
120            // Change `impl Foo` to `impl NewTrait for Foo`
121            let mut elements = vec![
122                trait_name_ref.syntax().clone().into(),
123                make::tokens::single_space().into(),
124                make::token(T![for]).into(),
125                make::tokens::single_space().into(),
126            ];
127
128            if let Some(params) = impl_ast.generic_param_list() {
129                let gen_args = &params.to_generic_args().clone_for_update();
130                elements.insert(1, gen_args.syntax().clone().into());
131            }
132
133            let mut editor = builder.make_editor(impl_ast.syntax());
134            impl_assoc_items.assoc_items().for_each(|item| {
135                remove_items_visibility(&mut editor, &item);
136            });
137
138            editor.insert_all(Position::before(impl_name.syntax()), elements);
139
140            // Insert trait before TraitImpl
141            editor.insert_all(
142                Position::before(impl_ast.syntax()),
143                vec![
144                    trait_ast.syntax().clone().into(),
145                    make::tokens::whitespace(&format!("\n\n{}", impl_ast.indent_level())).into(),
146                ],
147            );
148
149            // Link the trait name & trait ref names together as a placeholder snippet group
150            if let Some(cap) = ctx.config.snippet_cap {
151                let placeholder = builder.make_placeholder_snippet(cap);
152                editor.add_annotation(trait_name.syntax(), placeholder);
153                editor.add_annotation(trait_name_ref.syntax(), placeholder);
154            }
155
156            builder.add_file_edits(ctx.vfs_file_id(), editor);
157        },
158    );
159
160    Some(())
161}
162
163/// `E0449` Trait items always share the visibility of their trait
164fn remove_items_visibility(editor: &mut SyntaxEditor, item: &ast::AssocItem) {
165    if let Some(has_vis) = ast::AnyHasVisibility::cast(item.syntax().clone()) {
166        if let Some(vis) = has_vis.visibility()
167            && let Some(token) = vis.syntax().next_sibling_or_token()
168            && token.kind() == SyntaxKind::WHITESPACE
169        {
170            editor.delete(token);
171        }
172        if let Some(vis) = has_vis.visibility() {
173            editor.delete(vis.syntax());
174        }
175    }
176}
177
178fn strip_body(editor: &mut SyntaxEditor, item: &ast::AssocItem) {
179    if let ast::AssocItem::Fn(f) = item
180        && let Some(body) = f.body()
181    {
182        // In contrast to function bodies, we want to see no ws before a semicolon.
183        // So let's remove them if we see any.
184        if let Some(prev) = body.syntax().prev_sibling_or_token()
185            && prev.kind() == SyntaxKind::WHITESPACE
186        {
187            editor.delete(prev);
188        }
189
190        editor.replace(body.syntax(), make::tokens::semicolon());
191    };
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{check_assist, check_assist_no_snippet_cap, check_assist_not_applicable};

    /// The assist must not be offered once the cursor is past the opening `{`.
    #[test]
    fn test_trigger_when_cursor_on_header() {
        check_assist_not_applicable(
            generate_trait_from_impl,
            r#"
struct Foo(f64);

impl Foo { $0
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
        );
    }

    /// Function bodies are replaced by `;` in the generated trait.
    #[test]
    fn test_assoc_item_fn() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo(f64);

impl F$0oo {
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
            r#"
struct Foo(f64);

trait NewTrait {
    fn add(&mut self, x: f64);
}

impl NewTrait for Foo {
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
        )
    }

    /// Macro calls inside the impl are copied into the trait unchanged.
    #[test]
    fn test_assoc_item_macro() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

macro_rules! const_maker {
    ($t:ty, $v:tt) => {
        const CONST: $t = $v;
    };
}

impl F$0oo {
    const_maker! {i32, 7}
}"#,
            r#"
struct Foo;

macro_rules! const_maker {
    ($t:ty, $v:tt) => {
        const CONST: $t = $v;
    };
}

trait NewTrait {
    const_maker! {i32, 7}
}

impl NewTrait for Foo {
    const_maker! {i32, 7}
}"#,
        )
    }

    /// Associated constants keep their value in the generated trait.
    #[test]
    fn test_assoc_item_const() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

impl F$0oo {
    const ABC: i32 = 3;
}"#,
            r#"
struct Foo;

trait NewTrait {
    const ABC: i32 = 3;
}

impl NewTrait for Foo {
    const ABC: i32 = 3;
}"#,
        )
    }

    /// Generic parameters of the impl are carried over to the trait and
    /// forwarded as generic arguments in the impl header.
    #[test]
    fn test_impl_with_generics() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo<const N: usize>([i32; N]);

impl<const N: usize> F$0oo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
            "#,
            r#"
struct Foo<const N: usize>([i32; N]);

trait NewTrait<const N: usize> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}

impl<const N: usize> NewTrait<N> for Foo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
            "#,
        )
    }

    /// Visibility qualifiers are removed from both the trait and the impl items.
    #[test]
    fn test_trait_items_should_not_have_vis() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

impl F$0oo {
    pub fn a_func() -> Option<()> {
        Some(())
    }
}"#,
            r#"
struct Foo;

trait NewTrait {
    fn a_func() -> Option<()>;
}

impl NewTrait for Foo {
    fn a_func() -> Option<()> {
        Some(())
    }
}"#,
        )
    }

    /// An impl without associated items does not offer the assist.
    #[test]
    fn test_empty_inherent_impl() {
        check_assist_not_applicable(
            generate_trait_from_impl,
            r#"
impl Emp$0tyImpl{}
"#,
        )
    }

    /// An impl that already implements a trait (has a `for`) is not inherent,
    /// so the assist does not apply.
    #[test]
    fn test_trait_impl_is_not_applicable() {
        check_assist_not_applicable(
            generate_trait_from_impl,
            r#"
struct Foo;

trait Bar {
    fn baz();
}

impl Bar for F$0oo {
    fn baz() {}
}"#,
        )
    }

    /// The generated trait follows the indentation of a nested impl.
    #[test]
    fn test_not_top_level_impl() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
mod a {
    impl S$0 {
        fn foo() {}
    }
}"#,
            r#"
mod a {
    trait NewTrait {
        fn foo();
    }

    impl NewTrait for S {
        fn foo() {}
    }
}"#,
        )
    }

    /// With snippet support, both occurrences of the trait name share one placeholder.
    #[test]
    fn test_snippet_cap_is_some() {
        check_assist(
            generate_trait_from_impl,
            r#"
struct Foo<const N: usize>([i32; N]);

impl<const N: usize> F$0oo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
            "#,
            r#"
struct Foo<const N: usize>([i32; N]);

trait ${0:NewTrait}<const N: usize> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}

impl<const N: usize> ${0:NewTrait}<N> for Foo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
            "#,
        )
    }
}