ide_assists/handlers/move_bounds.rs

1use either::Either;
2use syntax::{
3    ast::{
4        self, AstNode, HasName, HasTypeBounds,
5        edit_in_place::{GenericParamsOwnerEdit, Removable},
6        make,
7    },
8    match_ast,
9};
10
11use crate::{AssistContext, AssistId, Assists};
12
13// Assist: move_bounds_to_where_clause
14//
15// Moves inline type bounds to a where clause.
16//
17// ```
18// fn apply<T, U, $0F: FnOnce(T) -> U>(f: F, x: T) -> U {
19//     f(x)
20// }
21// ```
22// ->
23// ```
24// fn apply<T, U, F>(f: F, x: T) -> U where F: FnOnce(T) -> U {
25//     f(x)
26// }
27// ```
28pub(crate) fn move_bounds_to_where_clause(
29    acc: &mut Assists,
30    ctx: &AssistContext<'_>,
31) -> Option<()> {
32    let type_param_list = ctx.find_node_at_offset::<ast::GenericParamList>()?;
33
34    let mut type_params = type_param_list.generic_params();
35    if type_params.all(|p| match p {
36        ast::GenericParam::TypeParam(t) => t.type_bound_list().is_none(),
37        ast::GenericParam::LifetimeParam(l) => l.type_bound_list().is_none(),
38        ast::GenericParam::ConstParam(_) => true,
39    }) {
40        return None;
41    }
42
43    let parent = type_param_list.syntax().parent()?;
44
45    let target = type_param_list.syntax().text_range();
46    acc.add(
47        AssistId::refactor_rewrite("move_bounds_to_where_clause"),
48        "Move to where clause",
49        target,
50        |edit| {
51            let type_param_list = edit.make_mut(type_param_list);
52            let parent = edit.make_syntax_mut(parent);
53
54            let where_clause: ast::WhereClause = match_ast! {
55                match parent {
56                    ast::Fn(it) => it.get_or_create_where_clause(),
57                    ast::Trait(it) => it.get_or_create_where_clause(),
58                    ast::Impl(it) => it.get_or_create_where_clause(),
59                    ast::Enum(it) => it.get_or_create_where_clause(),
60                    ast::Struct(it) => it.get_or_create_where_clause(),
61                    ast::TypeAlias(it) => it.get_or_create_where_clause(),
62                    _ => return,
63                }
64            };
65
66            for generic_param in type_param_list.generic_params() {
67                let param: &dyn HasTypeBounds = match &generic_param {
68                    ast::GenericParam::TypeParam(t) => t,
69                    ast::GenericParam::LifetimeParam(l) => l,
70                    ast::GenericParam::ConstParam(_) => continue,
71                };
72                if let Some(tbl) = param.type_bound_list() {
73                    if let Some(predicate) = build_predicate(generic_param) {
74                        where_clause.add_predicate(predicate)
75                    }
76                    tbl.remove()
77                }
78            }
79        },
80    )
81}
82
83fn build_predicate(param: ast::GenericParam) -> Option<ast::WherePred> {
84    let target = match &param {
85        ast::GenericParam::TypeParam(t) => {
86            Either::Right(make::ty_path(make::ext::ident_path(&t.name()?.to_string())))
87        }
88        ast::GenericParam::LifetimeParam(l) => Either::Left(l.lifetime()?),
89        ast::GenericParam::ConstParam(_) => return None,
90    };
91    let predicate = make::where_pred(
92        target,
93        match param {
94            ast::GenericParam::TypeParam(t) => t.type_bound_list()?,
95            ast::GenericParam::LifetimeParam(l) => l.type_bound_list()?,
96            ast::GenericParam::ConstParam(_) => return None,
97        }
98        .bounds(),
99    );
100    Some(predicate.clone_for_update())
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::{check_assist, check_assist_not_applicable};

    #[test]
    fn move_bounds_to_where_clause_fn() {
        check_assist(
            move_bounds_to_where_clause,
            r#"fn foo<T: u32, $0F: FnOnce(T) -> T>() {}"#,
            r#"fn foo<T, F>() where T: u32, F: FnOnce(T) -> T {}"#,
        );
    }

    #[test]
    fn move_bounds_to_where_clause_impl() {
        check_assist(
            move_bounds_to_where_clause,
            r#"impl<U: u32, $0T> A<U, T> {}"#,
            r#"impl<U, T> A<U, T> where U: u32 {}"#,
        );
    }

    #[test]
    fn move_bounds_to_where_clause_struct() {
        check_assist(
            move_bounds_to_where_clause,
            r#"struct A<$0T: Iterator<Item = u32>> {}"#,
            r#"struct A<T> where T: Iterator<Item = u32> {}"#,
        );
    }

    #[test]
    fn move_bounds_to_where_clause_tuple_struct() {
        check_assist(
            move_bounds_to_where_clause,
            r#"struct Pair<$0T: u32>(T, T);"#,
            r#"struct Pair<T>(T, T) where T: u32;"#,
        );
    }

    #[test]
    fn move_bounds_to_where_clause_trait() {
        check_assist(
            move_bounds_to_where_clause,
            r#"trait T<'a: 'static, $0T: u32> {}"#,
            r#"trait T<'a, T> where 'a: 'static, T: u32 {}"#,
        );
    }

    // Const parameters have no bounds to move and must be left untouched.
    #[test]
    fn move_bounds_to_where_clause_skips_const_params() {
        check_assist(
            move_bounds_to_where_clause,
            r#"fn foo<const N: usize, $0T: Copy>() {}"#,
            r#"fn foo<const N: usize, T>() where T: Copy {}"#,
        );
    }

    // With no inline bounds anywhere in the list, the assist must not be offered.
    #[test]
    fn move_bounds_to_where_clause_not_applicable_without_bounds() {
        check_assist_not_applicable(move_bounds_to_where_clause, r#"fn foo<$0T, U>() {}"#);
    }
}