ide_assists/handlers/convert_nested_function_to_closure.rs

1use ide_db::assists::AssistId;
2use syntax::ast::{self, HasGenericParams, HasName};
3use syntax::{AstNode, SyntaxKind};
4
5use crate::assist_context::{AssistContext, Assists};
6
7// Assist: convert_nested_function_to_closure
8//
9// Converts a function that is defined within the body of another function into a closure.
10//
11// ```
12// fn main() {
13//     fn fo$0o(label: &str, number: u64) {
14//         println!("{}: {}", label, number);
15//     }
16//
17//     foo("Bar", 100);
18// }
19// ```
20// ->
21// ```
22// fn main() {
23//     let foo = |label: &str, number: u64| {
24//         println!("{}: {}", label, number);
25//     };
26//
27//     foo("Bar", 100);
28// }
29// ```
30pub(crate) fn convert_nested_function_to_closure(
31    acc: &mut Assists,
32    ctx: &AssistContext<'_>,
33) -> Option<()> {
34    let name = ctx.find_node_at_offset::<ast::Name>()?;
35    let function = name.syntax().parent().and_then(ast::Fn::cast)?;
36
37    if !is_nested_function(&function) || is_generic(&function) || has_modifiers(&function) {
38        return None;
39    }
40
41    let target = function.syntax().text_range();
42    let body = function.body()?;
43    let name = function.name()?;
44    let param_list = function.param_list()?;
45
46    acc.add(
47        AssistId::refactor_rewrite("convert_nested_function_to_closure"),
48        "Convert nested function to closure",
49        target,
50        |edit| {
51            let params = &param_list.syntax().text().to_string();
52            let params = params.strip_prefix('(').unwrap_or(params);
53            let params = params.strip_suffix(')').unwrap_or(params);
54
55            let mut body = body.to_string();
56            if !has_semicolon(&function) {
57                body.push(';');
58            }
59            edit.replace(target, format!("let {name} = |{params}| {body}"));
60        },
61    )
62}
63
64/// Returns whether the given function is nested within the body of another function.
65fn is_nested_function(function: &ast::Fn) -> bool {
66    function.syntax().ancestors().skip(1).find_map(ast::Item::cast).is_some_and(|it| {
67        matches!(it, ast::Item::Fn(_) | ast::Item::Static(_) | ast::Item::Const(_))
68    })
69}
70
71/// Returns whether the given nested function has generic parameters.
72fn is_generic(function: &ast::Fn) -> bool {
73    function.generic_param_list().is_some()
74}
75
76/// Returns whether the given nested function has any modifiers:
77///
78/// - `async`,
79/// - `const` or
80/// - `unsafe`
81fn has_modifiers(function: &ast::Fn) -> bool {
82    function.async_token().is_some()
83        || function.const_token().is_some()
84        || function.unsafe_token().is_some()
85}
86
87/// Returns whether the given nested function has a trailing semicolon.
88fn has_semicolon(function: &ast::Fn) -> bool {
89    function
90        .syntax()
91        .next_sibling_or_token()
92        .map(|t| t.kind() == SyntaxKind::SEMICOLON)
93        .unwrap_or(false)
94}
95
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::convert_nested_function_to_closure;

    #[test]
    fn convert_nested_function_to_closure_works() {
        check_assist(
            convert_nested_function_to_closure,
            r#"
fn main() {
    fn $0foo(a: u64, b: u64) -> u64 {
        2 * (a + b)
    }

    _ = foo(3, 4);
}
            "#,
            r#"
fn main() {
    let foo = |a: u64, b: u64| {
        2 * (a + b)
    };

    _ = foo(3, 4);
}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_works_with_existing_semicolon() {
        check_assist(
            convert_nested_function_to_closure,
            r#"
fn main() {
    fn foo$0(a: u64, b: u64) -> u64 {
        2 * (a + b)
    };

    _ = foo(3, 4);
}
            "#,
            r#"
fn main() {
    let foo = |a: u64, b: u64| {
        2 * (a + b)
    };

    _ = foo(3, 4);
}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_works_without_params() {
        check_assist(
            convert_nested_function_to_closure,
            r#"
fn main() {
    fn fo$0o() {}

    foo();
}
            "#,
            r#"
fn main() {
    let foo = || {};

    foo();
}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_is_not_suggested_on_top_level_function() {
        check_assist_not_applicable(
            convert_nested_function_to_closure,
            r#"
fn ma$0in() {}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_is_not_suggested_on_method() {
        check_assist_not_applicable(
            convert_nested_function_to_closure,
            r#"
struct S;

impl S {
    fn fo$0o(&self) {}
}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_is_not_suggested_in_nested_module() {
        check_assist_not_applicable(
            convert_nested_function_to_closure,
            r#"
fn main() {
    mod m {
        fn fo$0o() {}
    }
}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_is_not_suggested_when_cursor_off_name() {
        check_assist_not_applicable(
            convert_nested_function_to_closure,
            r#"
fn main() {
    fn foo(a: u64, $0b: u64) -> u64 {
        2 * (a + b)
    }

    _ = foo(3, 4);
}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_is_not_suggested_if_function_has_generic_params() {
        check_assist_not_applicable(
            convert_nested_function_to_closure,
            r#"
fn main() {
    fn fo$0o<S: Into<String>>(s: S) -> String {
        s.into()
    }

    _ = foo("hello");
}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_is_not_suggested_if_function_has_modifier() {
        check_assist_not_applicable(
            convert_nested_function_to_closure,
            r#"
fn main() {
    const fn fo$0o(s: String) -> String {
        s
    }

    _ = foo("hello");
}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_is_not_suggested_if_function_is_async() {
        check_assist_not_applicable(
            convert_nested_function_to_closure,
            r#"
fn main() {
    async fn fo$0o() {}
}
            "#,
        );
    }

    #[test]
    fn convert_nested_function_to_closure_is_not_suggested_if_function_is_unsafe() {
        check_assist_not_applicable(
            convert_nested_function_to_closure,
            r#"
fn main() {
    unsafe fn fo$0o() {}
}
            "#,
        );
    }
}