ide_assists/handlers/
reorder_fields.rs

1use either::Either;
2use ide_db::FxHashMap;
3use itertools::Itertools;
4use syntax::{AstNode, SmolStr, SyntaxElement, ToSmolStr, ast, syntax_editor::SyntaxEditor};
5
6use crate::{AssistContext, AssistId, Assists};
7
8// Assist: reorder_fields
9//
10// Reorder the fields of record literals and record patterns in the same order as in
11// the definition.
12//
13// ```
14// struct Foo {foo: i32, bar: i32};
15// const test: Foo = $0Foo {bar: 0, foo: 1}
16// ```
17// ->
18// ```
19// struct Foo {foo: i32, bar: i32};
20// const test: Foo = Foo {foo: 1, bar: 0}
21// ```
22pub(crate) fn reorder_fields(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
23    let path = ctx.find_node_at_offset::<ast::Path>()?;
24    let record =
25        path.syntax().parent().and_then(<Either<ast::RecordExpr, ast::RecordPat>>::cast)?;
26
27    let parent_node = match ctx.covering_element() {
28        SyntaxElement::Node(n) => n,
29        SyntaxElement::Token(t) => t.parent()?,
30    };
31
32    let ranks = compute_fields_ranks(&path, ctx)?;
33    let get_rank_of_field = |of: Option<SmolStr>| {
34        *ranks.get(of.unwrap_or_default().trim_start_matches("r#")).unwrap_or(&usize::MAX)
35    };
36
37    let field_list = match &record {
38        Either::Left(it) => Either::Left(it.record_expr_field_list()?),
39        Either::Right(it) => Either::Right(it.record_pat_field_list()?),
40    };
41    let fields = match field_list {
42        Either::Left(it) => Either::Left((
43            it.fields()
44                .sorted_unstable_by_key(|field| {
45                    get_rank_of_field(field.field_name().map(|it| it.to_smolstr()))
46                })
47                .collect::<Vec<_>>(),
48            it,
49        )),
50        Either::Right(it) => Either::Right((
51            it.fields()
52                .sorted_unstable_by_key(|field| {
53                    get_rank_of_field(field.field_name().map(|it| it.to_smolstr()))
54                })
55                .collect::<Vec<_>>(),
56            it,
57        )),
58    };
59
60    let is_sorted = fields.as_ref().either(
61        |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
62        |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
63    );
64    if is_sorted {
65        cov_mark::hit!(reorder_sorted_fields);
66        return None;
67    }
68    let target = record.as_ref().either(AstNode::syntax, AstNode::syntax).text_range();
69    acc.add(
70        AssistId::refactor_rewrite("reorder_fields"),
71        "Reorder record fields",
72        target,
73        |builder| {
74            let mut editor = builder.make_editor(&parent_node);
75
76            match fields {
77                Either::Left((sorted, field_list)) => {
78                    replace(&mut editor, field_list.fields(), sorted)
79                }
80                Either::Right((sorted, field_list)) => {
81                    replace(&mut editor, field_list.fields(), sorted)
82                }
83            }
84
85            builder.add_file_edits(ctx.vfs_file_id(), editor);
86        },
87    )
88}
89
90fn replace<T: AstNode + PartialEq>(
91    editor: &mut SyntaxEditor,
92    fields: impl Iterator<Item = T>,
93    sorted_fields: impl IntoIterator<Item = T>,
94) {
95    fields
96        .zip(sorted_fields)
97        .for_each(|(field, sorted_field)| editor.replace(field.syntax(), sorted_field.syntax()));
98}
99
100fn compute_fields_ranks(
101    path: &ast::Path,
102    ctx: &AssistContext<'_>,
103) -> Option<FxHashMap<String, usize>> {
104    let strukt = match ctx.sema.resolve_path(path) {
105        Some(hir::PathResolution::Def(hir::ModuleDef::Adt(hir::Adt::Struct(it)))) => it,
106        _ => return None,
107    };
108
109    let res = strukt
110        .fields(ctx.db())
111        .into_iter()
112        .enumerate()
113        .map(|(idx, field)| (field.name(ctx.db()).as_str().to_owned(), idx))
114        .collect();
115
116    Some(res)
117}
118
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    #[test]
    fn reorder_sorted_fields() {
        cov_mark::check!(reorder_sorted_fields);
        check_assist_not_applicable(
            reorder_fields,
            r#"
struct Foo { foo: i32, bar: i32 }
const test: Foo = $0Foo { foo: 0, bar: 0 };
"#,
        )
    }

    #[test]
    fn trivial_empty_fields() {
        check_assist_not_applicable(
            reorder_fields,
            r#"
struct Foo {}
const test: Foo = $0Foo {};
"#,
        )
    }

    #[test]
    fn reorder_struct_fields() {
        check_assist(
            reorder_fields,
            r#"
struct Foo { foo: i32, bar: i32 }
const test: Foo = $0Foo { bar: 0, foo: 1 };
"#,
            r#"
struct Foo { foo: i32, bar: i32 }
const test: Foo = Foo { foo: 1, bar: 0 };
"#,
        )
    }

    #[test]
    fn reorder_struct_pattern() {
        check_assist(
            reorder_fields,
            r#"
struct Foo { foo: i64, bar: i64, baz: i64 }

fn f(f: Foo) -> {
    match f {
        $0Foo { baz: 0, ref mut bar, .. } => (),
        _ => ()
    }
}
"#,
            r#"
struct Foo { foo: i64, bar: i64, baz: i64 }

fn f(f: Foo) -> {
    match f {
        Foo { ref mut bar, baz: 0, .. } => (),
        _ => ()
    }
}
"#,
        )
    }

    #[test]
    fn reorder_with_extra_field() {
        check_assist(
            reorder_fields,
            r#"
struct Foo { foo: String, bar: String }

impl Foo {
    fn new() -> Foo {
        let foo = String::new();
        $0Foo {
            bar: foo.clone(),
            extra: "Extra field",
            foo,
        }
    }
}
"#,
            r#"
struct Foo { foo: String, bar: String }

impl Foo {
    fn new() -> Foo {
        let foo = String::new();
        Foo {
            foo,
            bar: foo.clone(),
            extra: "Extra field",
        }
    }
}
"#,
        )
    }

    // Fields unknown to the definition all share the same rank; they must move to the end
    // while keeping their relative order.
    #[test]
    fn reorder_keeps_relative_order_of_unknown_fields() {
        check_assist(
            reorder_fields,
            r#"
struct Foo { foo: i32, bar: i32 }
const test: Foo = $0Foo { extra1: 1, foo: 1, extra2: 2, bar: 0 };
"#,
            r#"
struct Foo { foo: i32, bar: i32 }
const test: Foo = Foo { foo: 1, bar: 0, extra1: 1, extra2: 2 };
"#,
        )
    }

    // Unknown fields that already trail the known ones, in any order, leave nothing to do.
    #[test]
    fn trailing_unknown_fields_are_already_sorted() {
        check_assist_not_applicable(
            reorder_fields,
            r#"
struct Foo { foo: i32, bar: i32 }
const test: Foo = $0Foo { foo: 0, bar: 0, extra2: 0, extra1: 0 };
"#,
        )
    }
}