1use crate::assist_context::{AssistContext, Assists};
2use ide_db::assists::AssistId;
3use syntax::{
4 AstNode, AstToken, SyntaxKind, T,
5 ast::{
6 self, HasDocComments, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit,
7 syntax_factory::SyntaxFactory,
8 },
9 syntax_editor::{Position, SyntaxEditor},
10};
11
12pub(crate) fn generate_trait_from_impl(
72 acc: &mut Assists,
73 ctx: &AssistContext<'_, '_>,
74) -> Option<()> {
75 let impl_ast = ctx.find_node_at_offset::<ast::Impl>()?;
77
78 let l_curly = impl_ast.assoc_item_list()?.l_curly_token()?;
81
82 let cursor_offset = ctx.offset();
83 let l_curly_offset = l_curly.text_range();
84 if cursor_offset >= l_curly_offset.start() {
85 return None;
86 }
87
88 if impl_ast.for_token().is_some() {
90 return None;
91 }
92
93 let impl_assoc_items = impl_ast.assoc_item_list()?;
94 let first_element = impl_assoc_items.assoc_items().next();
95 first_element.as_ref()?;
96
97 let impl_name = impl_ast.self_ty()?;
98
99 acc.add(
100 AssistId::generate("generate_trait_from_impl"),
101 "Generate trait from impl",
102 impl_ast.syntax().text_range(),
103 |builder| {
104 let trait_items: ast::AssocItemList = {
105 let (trait_items_editor, trait_items) =
106 SyntaxEditor::with_ast_node(&impl_assoc_items);
107
108 trait_items.assoc_items().for_each(|item| {
109 strip_body(&trait_items_editor, &item);
110 remove_items_visibility(&trait_items_editor, &item);
111 });
112 ast::AssocItemList::cast(trait_items_editor.finish().new_root().clone()).unwrap()
113 };
114
115 let editor = builder.make_editor(impl_ast.syntax());
116 let make = editor.make();
117 let trait_ast = make.trait_(
118 false,
119 &trait_name(&impl_assoc_items, make).text(),
120 impl_ast.generic_param_list(),
121 impl_ast.where_clause(),
122 trait_items,
123 );
124
125 let trait_name = trait_ast.name().expect("new trait should have a name");
126 let trait_name_ref = make.name_ref(&trait_name.to_string());
127
128 let mut elements = vec![
130 trait_name_ref.syntax().clone().into(),
131 make.whitespace(" ").into(),
132 make.token(T![for]).into(),
133 make.whitespace(" ").into(),
134 ];
135
136 if let Some(params) = impl_ast.generic_param_list() {
137 let gen_args = ¶ms.to_generic_args(make);
138 elements.insert(1, gen_args.syntax().clone().into());
139 }
140
141 impl_assoc_items.assoc_items().for_each(|item| {
142 remove_items_visibility(&editor, &item);
143 remove_doc_comments(&editor, &item);
144 });
145
146 editor.insert_all(Position::before(impl_name.syntax()), elements);
147
148 editor.insert_all(
150 Position::before(impl_ast.syntax()),
151 vec![
152 trait_ast.syntax().clone().into(),
153 make.whitespace(&format!("\n\n{}", impl_ast.indent_level())).into(),
154 ],
155 );
156
157 if let Some(cap) = ctx.config.snippet_cap {
159 let placeholder = builder.make_placeholder_snippet(cap);
160 editor.add_annotation(trait_name.syntax(), placeholder);
161 editor.add_annotation(trait_name_ref.syntax(), placeholder);
162 }
163 builder.add_file_edits(ctx.vfs_file_id(), editor);
164 },
165 );
166
167 Some(())
168}
169
170fn trait_name(items: &ast::AssocItemList, make: &SyntaxFactory) -> ast::Name {
171 let mut fn_names = items
172 .assoc_items()
173 .filter_map(|x| if let ast::AssocItem::Fn(f) = x { f.name() } else { None });
174 fn_names
175 .next()
176 .and_then(|name| {
177 fn_names.next().is_none().then(|| make.name(&stdx::to_camel_case(&name.text())))
178 })
179 .unwrap_or_else(|| make.name("NewTrait"))
180}
181
182fn remove_items_visibility(editor: &SyntaxEditor, item: &ast::AssocItem) {
184 if let Some(has_vis) = ast::AnyHasVisibility::cast(item.syntax().clone()) {
185 if let Some(vis) = has_vis.visibility()
186 && let Some(token) = vis.syntax().next_sibling_or_token()
187 && token.kind() == SyntaxKind::WHITESPACE
188 {
189 editor.delete(token);
190 }
191 if let Some(vis) = has_vis.visibility() {
192 editor.delete(vis.syntax());
193 }
194 }
195}
196
197fn remove_doc_comments(editor: &SyntaxEditor, item: &ast::AssocItem) {
198 for doc in item.doc_comments() {
199 if let Some(next) = doc.syntax().next_token()
200 && next.kind() == SyntaxKind::WHITESPACE
201 {
202 editor.delete(next);
203 }
204 editor.delete(doc.syntax());
205 }
206}
207
208fn strip_body(editor: &SyntaxEditor, item: &ast::AssocItem) {
209 let make = editor.make();
210 if let ast::AssocItem::Fn(f) = item
211 && let Some(body) = f.body()
212 {
213 if let Some(prev) = body.syntax().prev_sibling_or_token()
216 && prev.kind() == SyntaxKind::WHITESPACE
217 {
218 editor.delete(prev);
219 }
220
221 editor.replace(body.syntax(), make.token(T![;]));
222 };
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{check_assist, check_assist_no_snippet_cap, check_assist_not_applicable};

    // The assist must only fire on the impl header; a cursor inside the body disables it.
    #[test]
    fn test_not_applicable_when_cursor_in_body() {
        check_assist_not_applicable(
            generate_trait_from_impl,
            r#"
struct Foo(f64);

impl Foo { $0
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
        );
    }

    // Trait impls already implement a trait, so there is nothing to generate.
    #[test]
    fn test_not_applicable_for_trait_impl() {
        check_assist_not_applicable(
            generate_trait_from_impl,
            r#"
trait Tr {
    fn f();
}

struct Foo;

impl T$0r for Foo {
    fn f() {}
}
"#,
        );
    }

    #[test]
    fn test_assoc_item_fn() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo(f64);

impl F$0oo {
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
            r#"
struct Foo(f64);

trait Add {
    fn add(&mut self, x: f64);
}

impl Add for Foo {
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
        )
    }

    #[test]
    fn test_remove_doc_comments() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo(f64);

impl F$0oo {
    /// Add `x`
    ///
    /// # Examples
    #[cfg(true)]
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
            r#"
struct Foo(f64);

trait Add {
    /// Add `x`
    ///
    /// # Examples
    #[cfg(true)]
    fn add(&mut self, x: f64);
}

impl Add for Foo {
    #[cfg(true)]
    fn add(&mut self, x: f64) {
        self.0 += x;
    }
}"#,
        )
    }

    #[test]
    fn test_assoc_item_macro() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

macro_rules! const_maker {
    ($t:ty, $v:tt) => {
        const CONST: $t = $v;
    };
}

impl F$0oo {
    const_maker! {i32, 7}
}"#,
            r#"
struct Foo;

macro_rules! const_maker {
    ($t:ty, $v:tt) => {
        const CONST: $t = $v;
    };
}

trait NewTrait {
    const_maker! {i32, 7}
}

impl NewTrait for Foo {
    const_maker! {i32, 7}
}"#,
        )
    }

    #[test]
    fn test_assoc_item_const() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

impl F$0oo {
    const ABC: i32 = 3;
}"#,
            r#"
struct Foo;

trait NewTrait {
    const ABC: i32 = 3;
}

impl NewTrait for Foo {
    const ABC: i32 = 3;
}"#,
        )
    }

    #[test]
    fn test_impl_with_generics() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo<const N: usize>([i32; N]);

impl<const N: usize> F$0oo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
        "#,
            r#"
struct Foo<const N: usize>([i32; N]);

trait NewTrait<const N: usize> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}

impl<const N: usize> NewTrait<N> for Foo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
        "#,
        )
    }

    #[test]
    fn test_trait_items_should_not_have_vis() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

impl F$0oo {
    pub fn a_func() -> Option<()> {
        Some(())
    }
}"#,
            r#"
struct Foo;

trait AFunc {
    fn a_func() -> Option<()>;
}

impl AFunc for Foo {
    fn a_func() -> Option<()> {
        Some(())
    }
}"#,
        )
    }

    // Restricted visibilities are stripped from both the trait and the trait impl.
    #[test]
    fn test_restricted_visibility_is_removed() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
struct Foo;

impl F$0oo {
    pub(crate) const ABC: i32 = 3;
}"#,
            r#"
struct Foo;

trait NewTrait {
    const ABC: i32 = 3;
}

impl NewTrait for Foo {
    const ABC: i32 = 3;
}"#,
        )
    }

    #[test]
    fn test_empty_inherent_impl() {
        check_assist_not_applicable(
            generate_trait_from_impl,
            r#"
impl Emp$0tyImpl{}
"#,
        )
    }

    #[test]
    fn test_not_top_level_impl() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
mod a {
    impl S$0 {
        fn foo() {}
    }
}"#,
            r#"
mod a {
    trait Foo {
        fn foo();
    }

    impl Foo for S {
        fn foo() {}
    }
}"#,
        )
    }

    #[test]
    fn test_multi_fn_impl_not_suggest_trait_name() {
        check_assist_no_snippet_cap(
            generate_trait_from_impl,
            r#"
impl S$0 {
    fn foo() {}
    fn bar() {}
}"#,
            r#"
trait NewTrait {
    fn foo();
    fn bar();
}

impl NewTrait for S {
    fn foo() {}
    fn bar() {}
}"#,
        )
    }

    #[test]
    fn test_snippet_cap_is_some() {
        check_assist(
            generate_trait_from_impl,
            r#"
struct Foo<const N: usize>([i32; N]);

impl<const N: usize> F$0oo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
        "#,
            r#"
struct Foo<const N: usize>([i32; N]);

trait ${0:NewTrait}<const N: usize> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}

impl<const N: usize> ${0:NewTrait}<N> for Foo<N> {
    // Used as an associated constant.
    const CONST: usize = N * 4;
}
        "#,
        )
    }
}