Skip to main content

ide_assists/handlers/
extract_expressions_from_format_string.rs

1use crate::{AssistContext, Assists, utils};
2use ide_db::{
3    assists::{AssistId, AssistKind},
4    syntax_helpers::format_string_exprs::{Arg, parse_format_exprs},
5};
6use itertools::Itertools;
7use syntax::{
8    AstNode, AstToken, NodeOrToken,
9    SyntaxKind::WHITESPACE,
10    SyntaxToken, T,
11    ast::{self, TokenTree},
12};
13
14// Assist: extract_expressions_from_format_string
15//
16// Move an expression out of a format string.
17//
18// ```
19// # //- minicore: fmt
20// fn main() {
21//     print!("{var} {x + 1}$0");
22// }
23// ```
24// ->
25// ```
26// fn main() {
27//     print!("{var} {}"$0, x + 1);
28// }
29// ```
30
31pub(crate) fn extract_expressions_from_format_string(
32    acc: &mut Assists,
33    ctx: &AssistContext<'_, '_>,
34) -> Option<()> {
35    let fmt_string = ctx.find_token_at_offset::<ast::String>()?;
36    let tt = fmt_string.syntax().parent().and_then(ast::TokenTree::cast)?;
37    let tt_delimiter = tt.left_delimiter_token()?.kind();
38
39    let _ = ctx.sema.as_format_args_parts(&fmt_string)?;
40
41    let (new_fmt, extracted_args) = parse_format_exprs(fmt_string.text()).ok()?;
42    if extracted_args.is_empty() {
43        return None;
44    }
45
46    acc.add(
47        AssistId(
48            "extract_expressions_from_format_string",
49            // if there aren't any expressions, then make the assist a RefactorExtract
50            if extracted_args.iter().filter(|f| matches!(f, Arg::Expr(_))).count() == 0 {
51                AssistKind::RefactorExtract
52            } else {
53                AssistKind::QuickFix
54            },
55            None,
56        ),
57        "Extract format expressions",
58        tt.syntax().text_range(),
59        |edit| {
60            let editor = edit.make_editor(tt.syntax());
61            let make = editor.make();
62            // Extract existing arguments in macro
63            let mut raw_tokens = tt.token_trees_and_tokens().skip(1).collect_vec();
64            let format_string_index = format_str_index(&raw_tokens, &fmt_string);
65            let tokens = raw_tokens.split_off(format_string_index);
66
67            let existing_args = if let [
68                NodeOrToken::Token(_format_string),
69                _args_start_comma,
70                tokens @ ..,
71                NodeOrToken::Token(_end_bracket),
72            ] = tokens.as_slice()
73            {
74                let args = tokens
75                    .split(|it| matches!(it, NodeOrToken::Token(t) if t.kind() == T![,]))
76                    .map(|arg| {
77                        // Strip off leading and trailing whitespace tokens
78                        let arg = match arg.split_first() {
79                            Some((NodeOrToken::Token(t), rest)) if t.kind() == WHITESPACE => rest,
80                            _ => arg,
81                        };
82
83                        match arg.split_last() {
84                            Some((NodeOrToken::Token(t), rest)) if t.kind() == WHITESPACE => rest,
85                            _ => arg,
86                        }
87                    });
88
89                args.collect()
90            } else {
91                vec![]
92            };
93
94            // Start building the new args
95            let mut existing_args = existing_args.into_iter();
96            let mut new_tt_bits = raw_tokens;
97            let mut placeholder_indexes = vec![];
98
99            new_tt_bits.push(NodeOrToken::Token(make.expr_literal(&new_fmt).token().clone()));
100
101            for arg in extracted_args {
102                if matches!(arg, Arg::Expr(_) | Arg::Placeholder) {
103                    // insert ", " before each arg
104                    new_tt_bits.extend_from_slice(&[
105                        NodeOrToken::Token(make.token(T![,])),
106                        NodeOrToken::Token(make.whitespace(" ")),
107                    ]);
108                }
109
110                match arg {
111                    Arg::Expr(s) => {
112                        // insert arg
113                        let expr = ast::Expr::parse(&s, ctx.edition()).syntax_node();
114                        let mut expr_tt = utils::tt_from_syntax(expr, make);
115                        new_tt_bits.append(&mut expr_tt);
116                    }
117                    Arg::Placeholder => {
118                        // try matching with existing argument
119                        match existing_args.next() {
120                            Some(arg) => {
121                                new_tt_bits.extend_from_slice(arg);
122                            }
123                            None => {
124                                placeholder_indexes.push(new_tt_bits.len());
125                                new_tt_bits.push(NodeOrToken::Token(make.token(T![_])));
126                            }
127                        }
128                    }
129                    Arg::Ident(_s) => (),
130                }
131            }
132
133            // Insert new args
134            let new_tt = make.token_tree(tt_delimiter, new_tt_bits);
135            editor.replace(tt.syntax(), new_tt.syntax());
136
137            if let Some(cap) = ctx.config.snippet_cap {
138                // Add placeholder snippets over placeholder args
139                for pos in placeholder_indexes {
140                    // Skip the opening delimiter
141                    let Some(NodeOrToken::Token(placeholder)) =
142                        new_tt.token_trees_and_tokens().skip(1).nth(pos)
143                    else {
144                        continue;
145                    };
146
147                    if stdx::always!(placeholder.kind() == T![_]) {
148                        let annotation = edit.make_placeholder_snippet(cap);
149                        editor.add_annotation(placeholder, annotation);
150                    }
151                }
152
153                // Add the final tabstop after the format literal
154                if let Some(NodeOrToken::Token(literal)) =
155                    new_tt.token_trees_and_tokens().nth(1 + format_string_index)
156                {
157                    let annotation = edit.make_tabstop_after(cap);
158                    editor.add_annotation(literal, annotation);
159                }
160            }
161            edit.add_file_edits(ctx.vfs_file_id(), editor);
162        },
163    );
164
165    Some(())
166}
167
168fn format_str_index(
169    raw_tokens: &[NodeOrToken<TokenTree, SyntaxToken>],
170    fmt_string: &ast::String,
171) -> usize {
172    let fmt_string = fmt_string.syntax();
173    raw_tokens
174        .iter()
175        .position(|tt| tt.as_token().is_some_and(|tt| tt == fmt_string))
176        .unwrap_or_default()
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{check_assist, check_assist_no_snippet_cap};

    /// Applies the assist (with snippet support) to `ra_fixture_before` and
    /// compares the outcome with `ra_fixture_after`.
    #[track_caller]
    fn check(ra_fixture_before: &str, ra_fixture_after: &str) {
        check_assist(extract_expressions_from_format_string, ra_fixture_before, ra_fixture_after);
    }

    #[test]
    fn multiple_middle_arg() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{} {x + 1:b} {}$0", y + 2, 2);
}
"#,
            r#"
fn main() {
    print!("{} {:b} {}"$0, y + 2, x + 1, 2);
}
"#,
        );
    }

    #[test]
    fn multiple_middle_arg_on_write() {
        check(
            r#"
//- minicore: write
fn main() {
    write!(writer(), "{} {x + 1:b} {}$0", y + 2, 2);
}
"#,
            r#"
fn main() {
    write!(writer(), "{} {:b} {}"$0, y + 2, x + 1, 2);
}
"#,
        );
    }

    #[test]
    fn single_arg() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{obj.value:b}$0",);
}
"#,
            r#"
fn main() {
    print!("{:b}"$0, obj.value);
}
"#,
        );
    }

    #[test]
    fn multiple_middle_placeholders_arg() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{} {x + 1:b} {} {}$0", y + 2, 2);
}
"#,
            r#"
fn main() {
    print!("{} {:b} {} {}"$0, y + 2, x + 1, 2, ${1:_});
}
"#,
        );
    }

    #[test]
    fn multiple_trailing_args() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{:b} {x + 1:b} {Struct(1, 2)}$0", 1);
}
"#,
            r#"
fn main() {
    print!("{:b} {:b} {}"$0, 1, x + 1, Struct(1, 2));
}
"#,
        );
    }

    #[test]
    fn improper_commas() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{} {x + 1:b} {Struct(1, 2)}$0", 1,);
}
"#,
            r#"
fn main() {
    print!("{} {:b} {}"$0, 1, x + 1, Struct(1, 2));
}
"#,
        );
    }

    #[test]
    fn nested_tt() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("My name is {} {x$0 + x}", stringify!(Paperino))
}
"#,
            r#"
fn main() {
    print!("My name is {} {}"$0, stringify!(Paperino), x + x)
}
"#,
        );
    }

    #[test]
    fn extract_only_expressions() {
        check(
            r#"
//- minicore: fmt
fn main() {
    let var = 1 + 1;
    print!("foobar {var} {var:?} {x$0 + x}")
}
"#,
            r#"
fn main() {
    let var = 1 + 1;
    print!("foobar {var} {var:?} {}"$0, x + x)
}
"#,
        );
    }

    #[test]
    fn escaped_literals() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("\n$ {x + 1}$0");
}
            "#,
            r#"
fn main() {
    print!("\n$ {}"$0, x + 1);
}
            "#,
        );
    }

    #[test]
    fn without_snippets() {
        // Snippet support disabled: the `_` placeholder and `$0` tabstop are plain text.
        check_assist_no_snippet_cap(
            extract_expressions_from_format_string,
            r#"
//- minicore: fmt
fn main() {
    print!("{} {x + 1:b} {} {}$0", y + 2, 2);
}
"#,
            r#"
fn main() {
    print!("{} {:b} {} {}", y + 2, x + 1, 2, _);
}
"#,
        );
    }
}