Skip to main content

ide_assists/handlers/
extract_type_alias.rs

1use either::Either;
2use hir::HirDisplay;
3use ide_db::syntax_helpers::{node_ext::walk_ty, suggest_name::NameGenerator};
4use syntax::{
5    ast::{self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel},
6    syntax_editor,
7};
8
9use crate::{AssistContext, AssistId, Assists};
10
11// Assist: extract_type_alias
12//
13// Extracts the selected type as a type alias.
14//
15// ```
16// struct S {
17//     field: $0(u8, u8, u8)$0,
18// }
19// ```
20// ->
21// ```
22// type $0Type = (u8, u8, u8);
23//
24// struct S {
25//     field: Type,
26// }
27// ```
28pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_, '_>) -> Option<()> {
29    if ctx.has_empty_selection() {
30        return None;
31    }
32
33    let ty = ctx.find_node_at_range::<ast::Type>()?;
34    let item = ty.syntax().ancestors().find_map(ast::Item::cast)?;
35    let assoc_owner =
36        item.syntax().ancestors().nth(2).and_then(Either::<ast::Trait, ast::Impl>::cast);
37    let node = assoc_owner.as_ref().map_or_else(
38        || item.syntax(),
39        |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax),
40    );
41    let target = ty.syntax().text_range();
42
43    let scope = ctx.sema.scope(ty.syntax())?;
44    let resolved_ty = ctx.sema.resolve_type(&ty)?;
45    let resolved_ty = if !resolved_ty.contains_unknown() {
46        let module = scope.module();
47        resolved_ty.display_source_code(ctx.db(), module.into(), false).ok()?
48    } else {
49        ty.to_string()
50    };
51
52    acc.add(
53        AssistId::refactor_extract("extract_type_alias"),
54        "Extract type as type alias",
55        target,
56        |builder| {
57            let editor = builder.make_editor(node);
58            let make = editor.make();
59
60            let resolved_ty = make.ty(&resolved_ty);
61            let name = &NameGenerator::new_from_scope_non_locals(Some(scope)).suggest_name("Type");
62
63            let mut known_generics = match item.generic_param_list() {
64                Some(it) => it.generic_params().collect(),
65                None => Vec::new(),
66            };
67            if let Some(it) = assoc_owner.as_ref().and_then(|it| match it {
68                Either::Left(it) => it.generic_param_list(),
69                Either::Right(it) => it.generic_param_list(),
70            }) {
71                known_generics.extend(it.generic_params());
72            }
73            let generics = collect_used_generics(&ty, &known_generics);
74            let generic_params =
75                generics.map(|it| make.generic_param_list(it.into_iter().cloned()));
76
77            // Replace original type with the alias
78            let ty_args = generic_params.as_ref().map(|it| it.to_generic_args(make).generic_args());
79            let new_ty = if let Some(ty_args) = ty_args {
80                make.generic_ty_path_segment(make.name_ref(name), ty_args)
81            } else {
82                make.path_segment(make.name_ref(name))
83            };
84            editor.replace(ty.syntax(), new_ty.syntax());
85
86            // Insert new alias
87            let ty_alias =
88                make.ty_alias(None, name, generic_params, None, None, Some((resolved_ty, None)));
89
90            if let Some(cap) = ctx.config.snippet_cap
91                && let Some(name) = ty_alias.name()
92            {
93                editor.add_annotation(name.syntax(), builder.make_tabstop_before(cap));
94            }
95
96            let indent = IndentLevel::from_node(node);
97            editor.insert_all(
98                syntax_editor::Position::before(node),
99                vec![
100                    ty_alias.syntax().clone().into(),
101                    make.whitespace(&format!("\n\n{indent}")).into(),
102                ],
103            );
104
105            builder.add_file_edits(ctx.vfs_file_id(), editor);
106        },
107    )
108}
109
110fn collect_used_generics<'gp>(
111    ty: &ast::Type,
112    known_generics: &'gp [ast::GenericParam],
113) -> Option<Vec<&'gp ast::GenericParam>> {
114    // can't use a closure -> closure here cause lifetime inference fails for that
115    fn find_lifetime(text: &str) -> impl Fn(&&ast::GenericParam) -> bool + '_ {
116        move |gp: &&ast::GenericParam| match gp {
117            ast::GenericParam::LifetimeParam(lp) => {
118                lp.lifetime().is_some_and(|lt| lt.text() == text)
119            }
120            _ => false,
121        }
122    }
123
124    let mut generics = Vec::new();
125    walk_ty(ty, &mut |ty| {
126        match ty {
127            ast::Type::PathType(ty) => {
128                if let Some(path) = ty.path() {
129                    if let Some(name_ref) = path.as_single_name_ref()
130                        && let Some(param) = known_generics.iter().find(|gp| {
131                            match gp {
132                                ast::GenericParam::ConstParam(cp) => cp.name(),
133                                ast::GenericParam::TypeParam(tp) => tp.name(),
134                                _ => None,
135                            }
136                            .is_some_and(|n| n.text() == name_ref.text())
137                        })
138                    {
139                        generics.push(param);
140                    }
141                    generics.extend(
142                        path.segments()
143                            .filter_map(|seg| seg.generic_arg_list())
144                            .flat_map(|it| it.generic_args())
145                            .filter_map(|it| match it {
146                                ast::GenericArg::LifetimeArg(lt) => {
147                                    let lt = lt.lifetime()?;
148                                    known_generics.iter().find(find_lifetime(&lt.text()))
149                                }
150                                _ => None,
151                            }),
152                    );
153                }
154            }
155            ast::Type::ImplTraitType(impl_ty) => {
156                if let Some(it) = impl_ty.type_bound_list() {
157                    generics.extend(
158                        it.bounds()
159                            .filter_map(|it| it.lifetime())
160                            .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
161                    );
162                }
163            }
164            ast::Type::DynTraitType(dyn_ty) => {
165                if let Some(it) = dyn_ty.type_bound_list() {
166                    generics.extend(
167                        it.bounds()
168                            .filter_map(|it| it.lifetime())
169                            .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
170                    );
171                }
172            }
173            ast::Type::RefType(ref_) => generics.extend(
174                ref_.lifetime()
175                    .and_then(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
176            ),
177            ast::Type::ArrayType(ar) => {
178                if let Some(ast::Expr::PathExpr(p)) = ar.const_arg().and_then(|x| x.expr())
179                    && let Some(path) = p.path()
180                    && let Some(name_ref) = path.as_single_name_ref()
181                    && let Some(param) = known_generics.iter().find(|gp| {
182                        if let ast::GenericParam::ConstParam(cp) = gp {
183                            cp.name().is_some_and(|n| n.text() == name_ref.text())
184                        } else {
185                            false
186                        }
187                    })
188                {
189                    generics.push(param);
190                }
191            }
192            _ => (),
193        };
194        false
195    });
196    // stable resort to lifetime, type, const
197    generics.sort_by_key(|gp| match gp {
198        ast::GenericParam::ConstParam(_) => 2,
199        ast::GenericParam::LifetimeParam(_) => 0,
200        ast::GenericParam::TypeParam(_) => 1,
201    });
202
203    Some(generics).filter(|it| !it.is_empty())
204}
205
// Fixture-based tests for the `extract_type_alias` assist. `$0` marks the selection in the
// input fixture and the snippet tabstop in the expected output.
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // A cursor without a selection must not offer the assist.
    #[test]
    fn test_not_applicable_without_selection() {
        check_assist_not_applicable(
            extract_type_alias,
            r"
struct S {
    field: $0(u8, u8, u8),
}
            ",
        );
    }

    // Simplest case: a primitive field type is extracted into an alias placed before the struct.
    #[test]
    fn test_simple_types() {
        check_assist(
            extract_type_alias,
            r"
struct S {
    field: $0u8$0,
}
            ",
            r#"
type $0Type = u8;

struct S {
    field: Type,
}
            "#,
        );
    }

    // A type used as an explicit generic argument in an expression can be extracted; the
    // alias goes before the enclosing fn.
    #[test]
    fn test_generic_type_arg() {
        check_assist(
            extract_type_alias,
            r"
fn generic<T>() {}

fn f() {
    generic::<$0()$0>();
}
            ",
            r#"
fn generic<T>() {}

type $0Type = ();

fn f() {
    generic::<Type>();
}
            "#,
        );
    }

    // Only the selected inner type is extracted; the surrounding generic types are kept.
    #[test]
    fn test_inner_type_arg() {
        check_assist(
            extract_type_alias,
            r"
struct Vec<T> {}
struct S {
    v: Vec<Vec<$0Vec<u8>$0>>,
}
            ",
            r#"
struct Vec<T> {}
type $0Type = Vec<u8>;

struct S {
    v: Vec<Vec<Type>>,
}
            "#,
        );
    }

    // An element of a tuple type can be extracted on its own.
    #[test]
    fn test_extract_inner_type() {
        check_assist(
            extract_type_alias,
            r"
struct S {
    field: ($0u8$0,),
}
            ",
            r#"
type $0Type = u8;

struct S {
    field: (Type,),
}
            "#,
        );
    }

    #[test]
    fn extract_from_impl_or_trait() {
        // When invoked in an impl/trait, extracted type alias should be placed next to the
        // impl/trait, not inside.
        check_assist(
            extract_type_alias,
            r#"
impl S {
    fn f() -> $0(u8, u8)$0 {}
}
            "#,
            r#"
type $0Type = (u8, u8);

impl S {
    fn f() -> Type {}
}
            "#,
        );
        check_assist(
            extract_type_alias,
            r#"
trait Tr {
    fn f() -> $0(u8, u8)$0 {}
}
            "#,
            r#"
type $0Type = (u8, u8);

trait Tr {
    fn f() -> Type {}
}
            "#,
        );
    }

    // The alias takes the indentation of the item it is inserted before.
    #[test]
    fn indentation() {
        check_assist(
            extract_type_alias,
            r#"
mod m {
    fn f() -> $0u8$0 {}
}
            "#,
            r#"
mod m {
    type $0Type = u8;

    fn f() -> Type {}
}
            "#,
        );
    }

    // Lifetime, type and const parameters from both the fn and its impl become alias
    // parameters. They are ordered lifetimes, then types, then consts.
    #[test]
    fn generics() {
        check_assist(
            extract_type_alias,
            r#"
struct Struct<const C: usize>;
impl<'outer, Outer, const OUTER: usize> () {
    fn func<'inner, Inner, const INNER: usize>(_: $0&(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ())$0) {}
}
"#,
            r#"
struct Struct<const C: usize>;
type $0Type<'inner, 'outer, Outer, Inner, const INNER: usize, const OUTER: usize> = &(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ());

impl<'outer, Outer, const OUTER: usize> () {
    fn func<'inner, Inner, const INNER: usize>(_: Type<'inner, 'outer, Outer, Inner, INNER, OUTER>) {}
}
"#,
        );
    }

    // Regression test for issue #11197: a const parameter used as an array length is
    // carried over to the alias.
    #[test]
    fn issue_11197() {
        check_assist(
            extract_type_alias,
            r#"
struct Foo<T, const N: usize>
where
    [T; N]: Sized,
{
    arr: $0[T; N]$0,
}
            "#,
            r#"
type $0Type<T, const N: usize> = [T; N];

struct Foo<T, const N: usize>
where
    [T; N]: Sized,
{
    arr: Type<T, N>,
}
            "#,
        );
    }

    // Inferred `_` arguments are replaced by the type inference resolved.
    #[test]
    fn inferred_generic_type_parameter() {
        check_assist(
            extract_type_alias,
            r#"
struct Wrap<T>(T);

fn main() {
    let wrap: $0Wrap<_>$0 = Wrap::<_>(3i32);
}
            "#,
            r#"
struct Wrap<T>(T);

type $0Type = Wrap<i32>;

fn main() {
    let wrap: Type = Wrap::<_>(3i32);
}
            "#,
        )
    }

    // Selecting a lone `_` extracts the inferred type itself.
    #[test]
    fn inferred_type() {
        check_assist(
            extract_type_alias,
            r#"
struct Wrap<T>(T);

fn main() {
    let wrap: Wrap<$0_$0> = Wrap::<_>(3i32);
}
            "#,
            r#"
struct Wrap<T>(T);

type $0Type = i32;

fn main() {
    let wrap: Wrap<Type> = Wrap::<_>(3i32);
}
            "#,
        )
    }

    // The suggested name avoids clashing with an existing item named `Type` or with a
    // generic parameter named `Type`.
    #[test]
    fn duplicate_names() {
        check_assist(
            extract_type_alias,
            r"
struct Type;
struct S {
    field: $0u8$0,
}
            ",
            r#"
struct Type;
type $0Type1 = u8;

struct S {
    field: Type1,
}
            "#,
        );

        check_assist(
            extract_type_alias,
            r"
struct S<Type> {
    field: $0u8$0,
}
            ",
            r#"
type $0Type1 = u8;

struct S<Type> {
    field: Type1,
}
            "#,
        );
    }
}