// ide_assists/handlers/extract_type_alias.rs

1use either::Either;
2use hir::HirDisplay;
3use ide_db::syntax_helpers::node_ext::walk_ty;
4use syntax::{
5    ast::{self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel, make},
6    syntax_editor,
7};
8
9use crate::{AssistContext, AssistId, Assists};
10
11// Assist: extract_type_alias
12//
13// Extracts the selected type as a type alias.
14//
15// ```
16// struct S {
17//     field: $0(u8, u8, u8)$0,
18// }
19// ```
20// ->
21// ```
22// type $0Type = (u8, u8, u8);
23//
24// struct S {
25//     field: Type,
26// }
27// ```
28pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
29    if ctx.has_empty_selection() {
30        return None;
31    }
32
33    let ty = ctx.find_node_at_range::<ast::Type>()?;
34    let item = ty.syntax().ancestors().find_map(ast::Item::cast)?;
35    let assoc_owner =
36        item.syntax().ancestors().nth(2).and_then(Either::<ast::Trait, ast::Impl>::cast);
37    let node = assoc_owner.as_ref().map_or_else(
38        || item.syntax(),
39        |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax),
40    );
41    let target = ty.syntax().text_range();
42
43    let resolved_ty = ctx.sema.resolve_type(&ty)?;
44    let resolved_ty = if !resolved_ty.contains_unknown() {
45        let module = ctx.sema.scope(ty.syntax())?.module();
46        let resolved_ty = resolved_ty.display_source_code(ctx.db(), module.into(), false).ok()?;
47        make::ty(&resolved_ty)
48    } else {
49        ty.clone()
50    };
51
52    acc.add(
53        AssistId::refactor_extract("extract_type_alias"),
54        "Extract type as type alias",
55        target,
56        |builder| {
57            let mut edit = builder.make_editor(node);
58
59            let mut known_generics = match item.generic_param_list() {
60                Some(it) => it.generic_params().collect(),
61                None => Vec::new(),
62            };
63            if let Some(it) = assoc_owner.as_ref().and_then(|it| match it {
64                Either::Left(it) => it.generic_param_list(),
65                Either::Right(it) => it.generic_param_list(),
66            }) {
67                known_generics.extend(it.generic_params());
68            }
69            let generics = collect_used_generics(&ty, &known_generics);
70            let generic_params =
71                generics.map(|it| make::generic_param_list(it.into_iter().cloned()));
72
73            // Replace original type with the alias
74            let ty_args = generic_params.as_ref().map(|it| it.to_generic_args().generic_args());
75            let new_ty = if let Some(ty_args) = ty_args {
76                make::generic_ty_path_segment(make::name_ref("Type"), ty_args)
77            } else {
78                make::path_segment(make::name_ref("Type"))
79            }
80            .clone_for_update();
81            edit.replace(ty.syntax(), new_ty.syntax());
82
83            // Insert new alias
84            let ty_alias =
85                make::ty_alias(None, "Type", generic_params, None, None, Some((resolved_ty, None)))
86                    .clone_for_update();
87
88            if let Some(cap) = ctx.config.snippet_cap
89                && let Some(name) = ty_alias.name()
90            {
91                edit.add_annotation(name.syntax(), builder.make_tabstop_before(cap));
92            }
93
94            let indent = IndentLevel::from_node(node);
95            edit.insert_all(
96                syntax_editor::Position::before(node),
97                vec![
98                    ty_alias.syntax().clone().into(),
99                    make::tokens::whitespace(&format!("\n\n{indent}")).into(),
100                ],
101            );
102
103            builder.add_file_edits(ctx.vfs_file_id(), edit);
104        },
105    )
106}
107
108fn collect_used_generics<'gp>(
109    ty: &ast::Type,
110    known_generics: &'gp [ast::GenericParam],
111) -> Option<Vec<&'gp ast::GenericParam>> {
112    // can't use a closure -> closure here cause lifetime inference fails for that
113    fn find_lifetime(text: &str) -> impl Fn(&&ast::GenericParam) -> bool + '_ {
114        move |gp: &&ast::GenericParam| match gp {
115            ast::GenericParam::LifetimeParam(lp) => {
116                lp.lifetime().is_some_and(|lt| lt.text() == text)
117            }
118            _ => false,
119        }
120    }
121
122    let mut generics = Vec::new();
123    walk_ty(ty, &mut |ty| {
124        match ty {
125            ast::Type::PathType(ty) => {
126                if let Some(path) = ty.path() {
127                    if let Some(name_ref) = path.as_single_name_ref()
128                        && let Some(param) = known_generics.iter().find(|gp| {
129                            match gp {
130                                ast::GenericParam::ConstParam(cp) => cp.name(),
131                                ast::GenericParam::TypeParam(tp) => tp.name(),
132                                _ => None,
133                            }
134                            .is_some_and(|n| n.text() == name_ref.text())
135                        })
136                    {
137                        generics.push(param);
138                    }
139                    generics.extend(
140                        path.segments()
141                            .filter_map(|seg| seg.generic_arg_list())
142                            .flat_map(|it| it.generic_args())
143                            .filter_map(|it| match it {
144                                ast::GenericArg::LifetimeArg(lt) => {
145                                    let lt = lt.lifetime()?;
146                                    known_generics.iter().find(find_lifetime(&lt.text()))
147                                }
148                                _ => None,
149                            }),
150                    );
151                }
152            }
153            ast::Type::ImplTraitType(impl_ty) => {
154                if let Some(it) = impl_ty.type_bound_list() {
155                    generics.extend(
156                        it.bounds()
157                            .filter_map(|it| it.lifetime())
158                            .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
159                    );
160                }
161            }
162            ast::Type::DynTraitType(dyn_ty) => {
163                if let Some(it) = dyn_ty.type_bound_list() {
164                    generics.extend(
165                        it.bounds()
166                            .filter_map(|it| it.lifetime())
167                            .filter_map(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
168                    );
169                }
170            }
171            ast::Type::RefType(ref_) => generics.extend(
172                ref_.lifetime()
173                    .and_then(|lt| known_generics.iter().find(find_lifetime(&lt.text()))),
174            ),
175            ast::Type::ArrayType(ar) => {
176                if let Some(ast::Expr::PathExpr(p)) = ar.const_arg().and_then(|x| x.expr())
177                    && let Some(path) = p.path()
178                    && let Some(name_ref) = path.as_single_name_ref()
179                    && let Some(param) = known_generics.iter().find(|gp| {
180                        if let ast::GenericParam::ConstParam(cp) = gp {
181                            cp.name().is_some_and(|n| n.text() == name_ref.text())
182                        } else {
183                            false
184                        }
185                    })
186                {
187                    generics.push(param);
188                }
189            }
190            _ => (),
191        };
192        false
193    });
194    // stable resort to lifetime, type, const
195    generics.sort_by_key(|gp| match gp {
196        ast::GenericParam::ConstParam(_) => 2,
197        ast::GenericParam::LifetimeParam(_) => 0,
198        ast::GenericParam::TypeParam(_) => 1,
199    });
200
201    Some(generics).filter(|it| !it.is_empty())
202}
203
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    /// Applies `extract_type_alias` to `ra_fixture_before` and compares with `ra_fixture_after`.
    #[track_caller]
    fn check(ra_fixture_before: &str, ra_fixture_after: &str) {
        check_assist(extract_type_alias, ra_fixture_before, ra_fixture_after);
    }

    /// Asserts that `extract_type_alias` is not offered for `ra_fixture`.
    #[track_caller]
    fn check_not_applicable(ra_fixture: &str) {
        check_assist_not_applicable(extract_type_alias, ra_fixture);
    }

    #[test]
    fn test_not_applicable_without_selection() {
        // A cursor without a selected range must not offer the assist.
        check_not_applicable(
            r"
struct S {
    field: $0(u8, u8, u8),
}
            ",
        );
    }

    #[test]
    fn test_simple_types() {
        check(
            r"
struct S {
    field: $0u8$0,
}
            ",
            r#"
type $0Type = u8;

struct S {
    field: Type,
}
            "#,
        );
    }

    #[test]
    fn test_generic_type_arg() {
        // The selected type is a turbofish argument inside a function body.
        check(
            r"
fn generic<T>() {}

fn f() {
    generic::<$0()$0>();
}
            ",
            r#"
fn generic<T>() {}

type $0Type = ();

fn f() {
    generic::<Type>();
}
            "#,
        );
    }

    #[test]
    fn test_inner_type_arg() {
        // Only the selected, nested type argument is extracted.
        check(
            r"
struct Vec<T> {}
struct S {
    v: Vec<Vec<$0Vec<u8>$0>>,
}
            ",
            r#"
struct Vec<T> {}
type $0Type = Vec<u8>;

struct S {
    v: Vec<Vec<Type>>,
}
            "#,
        );
    }

    #[test]
    fn test_extract_inner_type() {
        check(
            r"
struct S {
    field: ($0u8$0,),
}
            ",
            r#"
type $0Type = u8;

struct S {
    field: (Type,),
}
            "#,
        );
    }

    #[test]
    fn extract_from_impl_or_trait() {
        // When invoked in an impl/trait, extracted type alias should be placed next to the
        // impl/trait, not inside.
        check(
            r#"
impl S {
    fn f() -> $0(u8, u8)$0 {}
}
            "#,
            r#"
type $0Type = (u8, u8);

impl S {
    fn f() -> Type {}
}
            "#,
        );
        check(
            r#"
trait Tr {
    fn f() -> $0(u8, u8)$0 {}
}
            "#,
            r#"
type $0Type = (u8, u8);

trait Tr {
    fn f() -> Type {}
}
            "#,
        );
    }

    #[test]
    fn indentation() {
        // The alias picks up the indentation of the item it is inserted in front of.
        check(
            r#"
mod m {
    fn f() -> $0u8$0 {}
}
            "#,
            r#"
mod m {
    type $0Type = u8;

    fn f() -> Type {}
}
            "#,
        );
    }

    #[test]
    fn generics() {
        // Lifetime, type and const parameters from both the fn and the impl are forwarded.
        check(
            r#"
struct Struct<const C: usize>;
impl<'outer, Outer, const OUTER: usize> () {
    fn func<'inner, Inner, const INNER: usize>(_: $0&(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ())$0) {}
}
"#,
            r#"
struct Struct<const C: usize>;
type $0Type<'inner, 'outer, Outer, Inner, const INNER: usize, const OUTER: usize> = &(Struct<INNER>, Struct<OUTER>, Outer, &'inner (), Inner, &'outer ());

impl<'outer, Outer, const OUTER: usize> () {
    fn func<'inner, Inner, const INNER: usize>(_: Type<'inner, 'outer, Outer, Inner, INNER, OUTER>) {}
}
"#,
        );
    }

    #[test]
    fn issue_11197() {
        // A const parameter used as an array length must become an alias parameter.
        check(
            r#"
struct Foo<T, const N: usize>
where
    [T; N]: Sized,
{
    arr: $0[T; N]$0,
}
            "#,
            r#"
type $0Type<T, const N: usize> = [T; N];

struct Foo<T, const N: usize>
where
    [T; N]: Sized,
{
    arr: Type<T, N>,
}
            "#,
        );
    }

    #[test]
    fn inferred_generic_type_parameter() {
        // `_` inside the selected type is replaced by the inferred type in the alias.
        check(
            r#"
struct Wrap<T>(T);

fn main() {
    let wrap: $0Wrap<_>$0 = Wrap::<_>(3i32);
}
            "#,
            r#"
struct Wrap<T>(T);

type $0Type = Wrap<i32>;

fn main() {
    let wrap: Type = Wrap::<_>(3i32);
}
            "#,
        );
    }

    #[test]
    fn inferred_type() {
        // A selected bare `_` becomes an alias of the inferred type.
        check(
            r#"
struct Wrap<T>(T);

fn main() {
    let wrap: Wrap<$0_$0> = Wrap::<_>(3i32);
}
            "#,
            r#"
struct Wrap<T>(T);

type $0Type = i32;

fn main() {
    let wrap: Wrap<Type> = Wrap::<_>(3i32);
}
            "#,
        );
    }
}