ide_assists/handlers/extract_expressions_from_format_string.rs

1use crate::{AssistContext, Assists, utils};
2use ide_db::{
3    assists::{AssistId, AssistKind},
4    syntax_helpers::format_string_exprs::{Arg, parse_format_exprs},
5};
6use itertools::Itertools;
7use syntax::{
8    AstNode, AstToken, NodeOrToken,
9    SyntaxKind::WHITESPACE,
10    SyntaxToken, T,
11    ast::{self, TokenTree, make, syntax_factory::SyntaxFactory},
12};
13
14// Assist: extract_expressions_from_format_string
15//
16// Move an expression out of a format string.
17//
18// ```
19// # //- minicore: fmt
20// fn main() {
21//     print!("{var} {x + 1}$0");
22// }
23// ```
24// ->
25// ```
26// fn main() {
27//     print!("{var} {}"$0, x + 1);
28// }
29// ```
30
31pub(crate) fn extract_expressions_from_format_string(
32    acc: &mut Assists,
33    ctx: &AssistContext<'_>,
34) -> Option<()> {
35    let fmt_string = ctx.find_token_at_offset::<ast::String>()?;
36    let tt = fmt_string.syntax().parent().and_then(ast::TokenTree::cast)?;
37    let tt_delimiter = tt.left_delimiter_token()?.kind();
38
39    let _ = ctx.sema.as_format_args_parts(&fmt_string)?;
40
41    let (new_fmt, extracted_args) = parse_format_exprs(fmt_string.text()).ok()?;
42    if extracted_args.is_empty() {
43        return None;
44    }
45
46    acc.add(
47        AssistId(
48            "extract_expressions_from_format_string",
49            // if there aren't any expressions, then make the assist a RefactorExtract
50            if extracted_args.iter().filter(|f| matches!(f, Arg::Expr(_))).count() == 0 {
51                AssistKind::RefactorExtract
52            } else {
53                AssistKind::QuickFix
54            },
55            None,
56        ),
57        "Extract format expressions",
58        tt.syntax().text_range(),
59        |edit| {
60            // Extract existing arguments in macro
61            let mut raw_tokens = tt.token_trees_and_tokens().skip(1).collect_vec();
62            let format_string_index = format_str_index(&raw_tokens, &fmt_string);
63            let tokens = raw_tokens.split_off(format_string_index);
64
65            let existing_args = if let [
66                NodeOrToken::Token(_format_string),
67                _args_start_comma,
68                tokens @ ..,
69                NodeOrToken::Token(_end_bracket),
70            ] = tokens.as_slice()
71            {
72                let args = tokens
73                    .split(|it| matches!(it, NodeOrToken::Token(t) if t.kind() == T![,]))
74                    .map(|arg| {
75                        // Strip off leading and trailing whitespace tokens
76                        let arg = match arg.split_first() {
77                            Some((NodeOrToken::Token(t), rest)) if t.kind() == WHITESPACE => rest,
78                            _ => arg,
79                        };
80
81                        match arg.split_last() {
82                            Some((NodeOrToken::Token(t), rest)) if t.kind() == WHITESPACE => rest,
83                            _ => arg,
84                        }
85                    });
86
87                args.collect()
88            } else {
89                vec![]
90            };
91
92            // Start building the new args
93            let mut existing_args = existing_args.into_iter();
94            let mut new_tt_bits = raw_tokens;
95            let mut placeholder_indexes = vec![];
96
97            new_tt_bits.push(NodeOrToken::Token(make::tokens::literal(&new_fmt)));
98
99            for arg in extracted_args {
100                if matches!(arg, Arg::Expr(_) | Arg::Placeholder) {
101                    // insert ", " before each arg
102                    new_tt_bits.extend_from_slice(&[
103                        NodeOrToken::Token(make::token(T![,])),
104                        NodeOrToken::Token(make::tokens::single_space()),
105                    ]);
106                }
107
108                match arg {
109                    Arg::Expr(s) => {
110                        // insert arg
111                        let expr = ast::Expr::parse(&s, ctx.edition()).syntax_node();
112                        let mut expr_tt = utils::tt_from_syntax(expr);
113                        new_tt_bits.append(&mut expr_tt);
114                    }
115                    Arg::Placeholder => {
116                        // try matching with existing argument
117                        match existing_args.next() {
118                            Some(arg) => {
119                                new_tt_bits.extend_from_slice(arg);
120                            }
121                            None => {
122                                placeholder_indexes.push(new_tt_bits.len());
123                                new_tt_bits.push(NodeOrToken::Token(make::token(T![_])));
124                            }
125                        }
126                    }
127                    Arg::Ident(_s) => (),
128                }
129            }
130
131            // Insert new args
132            let make = SyntaxFactory::with_mappings();
133            let new_tt = make.token_tree(tt_delimiter, new_tt_bits);
134            let mut editor = edit.make_editor(tt.syntax());
135            editor.replace(tt.syntax(), new_tt.syntax());
136
137            if let Some(cap) = ctx.config.snippet_cap {
138                // Add placeholder snippets over placeholder args
139                for pos in placeholder_indexes {
140                    // Skip the opening delimiter
141                    let Some(NodeOrToken::Token(placeholder)) =
142                        new_tt.token_trees_and_tokens().skip(1).nth(pos)
143                    else {
144                        continue;
145                    };
146
147                    if stdx::always!(placeholder.kind() == T![_]) {
148                        let annotation = edit.make_placeholder_snippet(cap);
149                        editor.add_annotation(placeholder, annotation);
150                    }
151                }
152
153                // Add the final tabstop after the format literal
154                if let Some(NodeOrToken::Token(literal)) =
155                    new_tt.token_trees_and_tokens().nth(1 + format_string_index)
156                {
157                    let annotation = edit.make_tabstop_after(cap);
158                    editor.add_annotation(literal, annotation);
159                }
160            }
161            editor.add_mappings(make.finish_with_mappings());
162            edit.add_file_edits(ctx.vfs_file_id(), editor);
163        },
164    );
165
166    Some(())
167}
168
169fn format_str_index(
170    raw_tokens: &[NodeOrToken<TokenTree, SyntaxToken>],
171    fmt_string: &ast::String,
172) -> usize {
173    let fmt_string = fmt_string.syntax();
174    raw_tokens
175        .iter()
176        .position(|tt| tt.as_token().is_some_and(|tt| tt == fmt_string))
177        .unwrap_or_default()
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{check_assist, check_assist_no_snippet_cap};

    /// Applies the assist to `ra_fixture_before` (snippets enabled) and asserts the
    /// result equals `ra_fixture_after`.
    #[track_caller]
    fn check(ra_fixture_before: &str, ra_fixture_after: &str) {
        check_assist(extract_expressions_from_format_string, ra_fixture_before, ra_fixture_after);
    }

    /// Same as `check`, but with snippet support disabled in the client config.
    #[track_caller]
    fn check_no_snippet(ra_fixture_before: &str, ra_fixture_after: &str) {
        check_assist_no_snippet_cap(
            extract_expressions_from_format_string,
            ra_fixture_before,
            ra_fixture_after,
        );
    }

    /// An expression between two `{}` placeholders is moved between their arguments.
    #[test]
    fn multiple_middle_arg() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{} {x + 1:b} {}$0", y + 2, 2);
}
"#,
            r#"
fn main() {
    print!("{} {:b} {}"$0, y + 2, x + 1, 2);
}
"#,
        );
    }

    /// Tokens before the format string (the writer) are preserved.
    #[test]
    fn multiple_middle_arg_on_write() {
        check(
            r#"
//- minicore: write
fn main() {
    write!(writer(), "{} {x + 1:b} {}$0", y + 2, 2);
}
"#,
            r#"
fn main() {
    write!(writer(), "{} {:b} {}"$0, y + 2, x + 1, 2);
}
"#,
        );
    }

    /// A lone trailing comma after the format string does not produce an extra argument.
    #[test]
    fn single_arg() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{obj.value:b}$0",);
}
"#,
            r#"
fn main() {
    print!("{:b}"$0, obj.value);
}
"#,
        );
    }

    /// A `{}` without a matching existing argument gets a `_` snippet placeholder.
    #[test]
    fn multiple_middle_placeholders_arg() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{} {x + 1:b} {} {}$0", y + 2, 2);
}
"#,
            r#"
fn main() {
    print!("{} {:b} {} {}"$0, y + 2, x + 1, 2, ${1:_});
}
"#,
        );
    }

    /// Several trailing expressions, including one containing commas, are extracted in order.
    #[test]
    fn multiple_trailing_args() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{:b} {x + 1:b} {Struct(1, 2)}$0", 1);
}
"#,
            r#"
fn main() {
    print!("{:b} {:b} {}"$0, 1, x + 1, Struct(1, 2));
}
"#,
        );
    }

    /// A trailing comma after the existing arguments is normalised away.
    #[test]
    fn improper_commas() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("{} {x + 1:b} {Struct(1, 2)}$0", 1,);
}
"#,
            r#"
fn main() {
    print!("{} {:b} {}"$0, 1, x + 1, Struct(1, 2));
}
"#,
        );
    }

    /// Existing arguments that contain nested token trees are copied intact.
    #[test]
    fn nested_tt() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("My name is {} {x$0 + x}", stringify!(Paperino))
}
"#,
            r#"
fn main() {
    print!("My name is {} {}"$0, stringify!(Paperino), x + x)
}
"#,
        );
    }

    /// Plain identifiers stay inline; only real expressions are extracted.
    #[test]
    fn extract_only_expressions() {
        check(
            r#"
//- minicore: fmt
fn main() {
    let var = 1 + 1;
    print!("foobar {var} {var:?} {x$0 + x}")
}
"#,
            r#"
fn main() {
    let var = 1 + 1;
    print!("foobar {var} {var:?} {}"$0, x + x)
}
"#,
        );
    }

    /// Escape sequences and `$` in the literal survive the rewrite unchanged.
    #[test]
    fn escaped_literals() {
        check(
            r#"
//- minicore: fmt
fn main() {
    print!("\n$ {x + 1}$0");
}
            "#,
            r#"
fn main() {
    print!("\n$ {}"$0, x + 1);
}
            "#,
        );
    }

    /// Without snippet support, placeholders are plain `_` and no tabstop is emitted.
    #[test]
    fn without_snippets() {
        check_no_snippet(
            r#"
//- minicore: fmt
fn main() {
    print!("{} {x + 1:b} {} {}$0", y + 2, 2);
}
"#,
            r#"
fn main() {
    print!("{} {:b} {} {}", y + 2, x + 1, 2, _);
}
"#,
        );
    }
}