Skip to main content

ide_assists/handlers/
move_bounds.rs

1use either::Either;
2use syntax::{
3    ast::{self, AstNode, HasName, HasTypeBounds, syntax_factory::SyntaxFactory},
4    match_ast,
5    syntax_editor::{GetOrCreateWhereClause, Removable},
6};
7
8use crate::{AssistContext, AssistId, Assists};
9
// Assist: move_bounds_to_where_clause
//
// Moves inline type and lifetime bounds from a generic parameter list to a where clause.
//
// ```
// fn apply<T, U, $0F: FnOnce(T) -> U>(f: F, x: T) -> U {
//     f(x)
// }
// ```
// ->
// ```
// fn apply<T, U, F>(f: F, x: T) -> U where F: FnOnce(T) -> U {
//     f(x)
// }
// ```
25pub(crate) fn move_bounds_to_where_clause(
26    acc: &mut Assists,
27    ctx: &AssistContext<'_, '_>,
28) -> Option<()> {
29    let type_param_list = ctx.find_node_at_offset::<ast::GenericParamList>()?;
30
31    let mut type_params = type_param_list.generic_params();
32    if type_params.all(|p| match p {
33        ast::GenericParam::TypeParam(t) => t.type_bound_list().is_none(),
34        ast::GenericParam::LifetimeParam(l) => l.type_bound_list().is_none(),
35        ast::GenericParam::ConstParam(_) => true,
36    }) {
37        return None;
38    }
39
40    let parent = type_param_list.syntax().parent()?;
41
42    let target = type_param_list.syntax().text_range();
43    acc.add(
44        AssistId::refactor_rewrite("move_bounds_to_where_clause"),
45        "Move to where clause",
46        target,
47        |builder| {
48            let editor = builder.make_editor(&parent);
49
50            let new_preds: Vec<ast::WherePred> = type_param_list
51                .generic_params()
52                .filter_map(|param| build_predicate(param, editor.make()))
53                .collect();
54
55            match_ast! {
56                match (&parent) {
57                    ast::Fn(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()),
58                    ast::Trait(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()),
59                    ast::Impl(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()),
60                    ast::Enum(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()),
61                    ast::Struct(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()),
62                    ast::TypeAlias(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()),
63                    _ => return,
64                }
65            };
66
67            for generic_param in type_param_list.generic_params() {
68                let param: &dyn HasTypeBounds = match &generic_param {
69                    ast::GenericParam::TypeParam(t) => t,
70                    ast::GenericParam::LifetimeParam(l) => l,
71                    ast::GenericParam::ConstParam(_) => continue,
72                };
73                if let Some(tbl) = param.type_bound_list() {
74                    tbl.remove(&editor);
75                }
76            }
77
78            builder.add_file_edits(ctx.vfs_file_id(), editor);
79        },
80    )
81}
82
83fn build_predicate(param: ast::GenericParam, make: &SyntaxFactory) -> Option<ast::WherePred> {
84    let target = match &param {
85        ast::GenericParam::TypeParam(t) => Either::Right(make.ty(&t.name()?.to_string())),
86        ast::GenericParam::LifetimeParam(l) => Either::Left(l.lifetime()?),
87        ast::GenericParam::ConstParam(_) => return None,
88    };
89    let predicate = make.where_pred(
90        target,
91        match param {
92            ast::GenericParam::TypeParam(t) => t.type_bound_list()?,
93            ast::GenericParam::LifetimeParam(l) => l.type_bound_list()?,
94            ast::GenericParam::ConstParam(_) => return None,
95        }
96        .bounds(),
97    );
98    Some(predicate)
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::check_assist;

    /// Applies `move_bounds_to_where_clause` to `before` (the `$0` marker gives the offset
    /// the assist is invoked at) and asserts the rewritten source equals `after`.
    #[track_caller]
    fn check(before: &str, after: &str) {
        check_assist(move_bounds_to_where_clause, before, after);
    }

    #[test]
    fn move_bounds_to_where_clause_fn() {
        check(
            r#"fn foo<T: u32, $0F: FnOnce(T) -> T>() {}"#,
            r#"fn foo<T, F>() where T: u32, F: FnOnce(T) -> T {}"#,
        );
    }

    #[test]
    fn move_bounds_to_where_clause_impl() {
        check(r#"impl<U: u32, $0T> A<U, T> {}"#, r#"impl<U, T> A<U, T> where U: u32 {}"#);
    }

    #[test]
    fn move_bounds_to_where_clause_struct() {
        check(
            r#"struct A<$0T: Iterator<Item = u32>> {}"#,
            r#"struct A<T> where T: Iterator<Item = u32> {}"#,
        );
    }

    #[test]
    fn move_bounds_to_where_clause_tuple_struct() {
        check(r#"struct Pair<$0T: u32>(T, T);"#, r#"struct Pair<T>(T, T) where T: u32;"#);
    }

    #[test]
    fn move_bounds_to_where_clause_trait() {
        check(
            r#"trait T<'a: 'static, $0T: u32> {}"#,
            r#"trait T<'a, T> where 'a: 'static, T: u32 {}"#,
        );
    }
}