// ide_assists/handlers/generate_trait_from_impl.rs

1use crate::assist_context::{AssistContext, Assists};
2use ide_db::assists::AssistId;
3use syntax::{
4    AstNode, AstToken, SyntaxKind, T,
5    ast::{
6        self, HasDocComments, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit,
7        syntax_factory::SyntaxFactory,
8    },
9    syntax_editor::{Position, SyntaxEditor},
10};
11
// NOTES:
// We generate erroneous code if a function is declared `const` (E0379).
// This is left to the user to correct, as our only option would be to remove
// the function completely, which we should not do.
16
// Assist: generate_trait_from_impl
//
// Generates a trait from an already defined inherent impl and converts the impl into a trait impl.
//
// ```
// struct Foo<const N: usize>([i32; N]);
//
// macro_rules! const_maker {
//     ($t:ty, $v:tt) => {
//         const CONST: $t = $v;
//     };
// }
//
// impl<const N: usize> Fo$0o<N> {
//     // Used as an associated constant.
//     const CONST_ASSOC: usize = N * 4;
//
//     fn create() -> Option<()> {
//         Some(())
//     }
//
//     const_maker! {i32, 7}
// }
// ```
// ->
// ```
// struct Foo<const N: usize>([i32; N]);
//
// macro_rules! const_maker {
//     ($t:ty, $v:tt) => {
//         const CONST: $t = $v;
//     };
// }
//
// trait ${0:Create}<const N: usize> {
//     // Used as an associated constant.
//     const CONST_ASSOC: usize = N * 4;
//
//     fn create() -> Option<()>;
//
//     const_maker! {i32, 7}
// }
//
// impl<const N: usize> ${0:Create}<N> for Foo<N> {
//     // Used as an associated constant.
//     const CONST_ASSOC: usize = N * 4;
//
//     fn create() -> Option<()> {
//         Some(())
//     }
//
//     const_maker! {i32, 7}
// }
// ```
71pub(crate) fn generate_trait_from_impl(
72    acc: &mut Assists,
73    ctx: &AssistContext<'_, '_>,
74) -> Option<()> {
75    // Get AST Node
76    let impl_ast = ctx.find_node_at_offset::<ast::Impl>()?;
77
78    // Check if cursor is to the left of assoc item list's L_CURLY.
79    // if no L_CURLY then return.
80    let l_curly = impl_ast.assoc_item_list()?.l_curly_token()?;
81
82    let cursor_offset = ctx.offset();
83    let l_curly_offset = l_curly.text_range();
84    if cursor_offset >= l_curly_offset.start() {
85        return None;
86    }
87
88    // If impl is not inherent then we don't really need to go any further.
89    if impl_ast.for_token().is_some() {
90        return None;
91    }
92
93    let impl_assoc_items = impl_ast.assoc_item_list()?;
94    let first_element = impl_assoc_items.assoc_items().next();
95    first_element.as_ref()?;
96
97    let impl_name = impl_ast.self_ty()?;
98
99    acc.add(
100        AssistId::generate("generate_trait_from_impl"),
101        "Generate trait from impl",
102        impl_ast.syntax().text_range(),
103        |builder| {
104            let trait_items: ast::AssocItemList = {
105                let (trait_items_editor, trait_items) =
106                    SyntaxEditor::with_ast_node(&impl_assoc_items);
107
108                trait_items.assoc_items().for_each(|item| {
109                    strip_body(&trait_items_editor, &item);
110                    remove_items_visibility(&trait_items_editor, &item);
111                });
112                ast::AssocItemList::cast(trait_items_editor.finish().new_root().clone()).unwrap()
113            };
114
115            let editor = builder.make_editor(impl_ast.syntax());
116            let make = editor.make();
117            let trait_ast = make.trait_(
118                false,
119                &trait_name(&impl_assoc_items, make).text(),
120                impl_ast.generic_param_list(),
121                impl_ast.where_clause(),
122                trait_items,
123            );
124
125            let trait_name = trait_ast.name().expect("new trait should have a name");
126            let trait_name_ref = make.name_ref(&trait_name.to_string());
127
128            // Change `impl Foo` to `impl NewTrait for Foo`
129            let mut elements = vec![
130                trait_name_ref.syntax().clone().into(),
131                make.whitespace(" ").into(),
132                make.token(T![for]).into(),
133                make.whitespace(" ").into(),
134            ];
135
136            if let Some(params) = impl_ast.generic_param_list() {
137                let gen_args = &params.to_generic_args(make);
138                elements.insert(1, gen_args.syntax().clone().into());
139            }
140
141            impl_assoc_items.assoc_items().for_each(|item| {
142                remove_items_visibility(&editor, &item);
143                remove_doc_comments(&editor, &item);
144            });
145
146            editor.insert_all(Position::before(impl_name.syntax()), elements);
147
148            // Insert trait before TraitImpl
149            editor.insert_all(
150                Position::before(impl_ast.syntax()),
151                vec![
152                    trait_ast.syntax().clone().into(),
153                    make.whitespace(&format!("\n\n{}", impl_ast.indent_level())).into(),
154                ],
155            );
156
157            // Link the trait name & trait ref names together as a placeholder snippet group
158            if let Some(cap) = ctx.config.snippet_cap {
159                let placeholder = builder.make_placeholder_snippet(cap);
160                editor.add_annotation(trait_name.syntax(), placeholder);
161                editor.add_annotation(trait_name_ref.syntax(), placeholder);
162            }
163            builder.add_file_edits(ctx.vfs_file_id(), editor);
164        },
165    );
166
167    Some(())
168}
169
170fn trait_name(items: &ast::AssocItemList, make: &SyntaxFactory) -> ast::Name {
171    let mut fn_names = items
172        .assoc_items()
173        .filter_map(|x| if let ast::AssocItem::Fn(f) = x { f.name() } else { None });
174    fn_names
175        .next()
176        .and_then(|name| {
177            fn_names.next().is_none().then(|| make.name(&stdx::to_camel_case(&name.text())))
178        })
179        .unwrap_or_else(|| make.name("NewTrait"))
180}
181
182/// `E0449` Trait items always share the visibility of their trait
183fn remove_items_visibility(editor: &SyntaxEditor, item: &ast::AssocItem) {
184    if let Some(has_vis) = ast::AnyHasVisibility::cast(item.syntax().clone()) {
185        if let Some(vis) = has_vis.visibility()
186            && let Some(token) = vis.syntax().next_sibling_or_token()
187            && token.kind() == SyntaxKind::WHITESPACE
188        {
189            editor.delete(token);
190        }
191        if let Some(vis) = has_vis.visibility() {
192            editor.delete(vis.syntax());
193        }
194    }
195}
196
197fn remove_doc_comments(editor: &SyntaxEditor, item: &ast::AssocItem) {
198    for doc in item.doc_comments() {
199        if let Some(next) = doc.syntax().next_token()
200            && next.kind() == SyntaxKind::WHITESPACE
201        {
202            editor.delete(next);
203        }
204        editor.delete(doc.syntax());
205    }
206}
207
208fn strip_body(editor: &SyntaxEditor, item: &ast::AssocItem) {
209    let make = editor.make();
210    if let ast::AssocItem::Fn(f) = item
211        && let Some(body) = f.body()
212    {
213        // In contrast to function bodies, we want to see no ws before a semicolon.
214        // So let's remove them if we see any.
215        if let Some(prev) = body.syntax().prev_sibling_or_token()
216            && prev.kind() == SyntaxKind::WHITESPACE
217        {
218            editor.delete(prev);
219        }
220
221        editor.replace(body.syntax(), make.token(T![;]));
222    };
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{check_assist, check_assist_no_snippet_cap, check_assist_not_applicable};

    #[test]
    fn test_trigger_when_cursor_on_header() {
        // The assist must only trigger from the impl header; with the cursor inside the
        // item list (after the `{`) it must not be offered.
        check_assist_not_applicable(
            generate_trait_from_impl,
            r#"
struct Foo(f64);

impl Foo { $0
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
        );
    }

    #[test]
    fn test_not_applicable_on_trait_impl() {
        // Only inherent impls can be turned into a trait + trait impl.
        check_assist_not_applicable(
            generate_trait_from_impl,
            r#"
trait Add {
    fn add(&mut self, x: f64);
}

struct Foo(f64);

impl Add for F$0oo {
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
        );
    }

    #[test]
    fn test_assoc_item_fn() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo(f64);

impl F$0oo {
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
            r#"
struct Foo(f64);

trait Add {
    fn add(&mut self, x: f64);
}

impl Add for Foo {
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
        )
    }

    #[test]
    fn test_remove_doc_comments() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo(f64);

impl F$0oo {
    /// Add `x`
    ///
    /// # Examples
    #[cfg(true)]
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
            r#"
struct Foo(f64);

trait Add {
    /// Add `x`
    ///
    /// # Examples
    #[cfg(true)]
    fn add(&mut self, x: f64);
}

impl Add for Foo {
    #[cfg(true)]
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
        )
    }

    #[test]
    fn test_assoc_item_macro() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

macro_rules! const_maker {
    ($t:ty, $v:tt) => {
        const CONST: $t = $v;
    };
}

impl F$0oo {
    const_maker! {i32, 7}
}"#,
            r#"
struct Foo;

macro_rules! const_maker {
    ($t:ty, $v:tt) => {
        const CONST: $t = $v;
    };
}

trait NewTrait {
    const_maker! {i32, 7}
}

impl NewTrait for Foo {
    const_maker! {i32, 7}
}"#,
        )
    }

    #[test]
    fn test_assoc_item_const() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

impl F$0oo {
    const ABC: i32 = 3;
}"#,
            r#"
struct Foo;

trait NewTrait {
    const ABC: i32 = 3;
}

impl NewTrait for Foo {
    const ABC: i32 = 3;
}"#,
        )
    }

    #[test]
    fn test_impl_with_generics() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo<const N: usize>([i32; N]);

impl<const N: usize> F$0oo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
            "#,
            r#"
struct Foo<const N: usize>([i32; N]);

trait NewTrait<const N: usize> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}

impl<const N: usize> NewTrait<N> for Foo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
            "#,
        )
    }

    #[test]
    fn test_trait_items_should_not_have_vis() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

impl F$0oo {
    pub fn a_func() -> Option<()> {
        Some(())
    }
}"#,
            r#"
struct Foo;

trait AFunc {
    fn a_func() -> Option<()>;
}

impl AFunc for Foo {
    fn a_func() -> Option<()> {
        Some(())
    }
}"#,
        )
    }

    #[test]
    fn test_empty_inherent_impl() {
        check_assist_not_applicable(
            generate_trait_from_impl,
            r#"
impl Emp$0tyImpl{}
"#,
        )
    }

    #[test]
    fn test_not_top_level_impl() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
mod a {
    impl S$0 {
        fn foo() {}
    }
}"#,
            r#"
mod a {
    trait Foo {
        fn foo();
    }

    impl Foo for S {
        fn foo() {}
    }
}"#,
        )
    }

    #[test]
    fn test_multi_fn_impl_not_suggest_trait_name() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
impl S$0 {
    fn foo() {}
    fn bar() {}
}"#,
            r#"
trait NewTrait {
    fn foo();
    fn bar();
}

impl NewTrait for S {
    fn foo() {}
    fn bar() {}
}"#,
        )
    }

    #[test]
    fn test_snippet_cap_is_some() {
        check_assist(
            generate_trait_from_impl,
            r#"
struct Foo<const N: usize>([i32; N]);

impl<const N: usize> F$0oo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
            "#,
            r#"
struct Foo<const N: usize>([i32; N]);

trait ${0:NewTrait}<const N: usize> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}

impl<const N: usize> ${0:NewTrait}<N> for Foo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
            "#,
        )
    }
}