Skip to main content

ide_assists/handlers/
convert_closure_to_fn.rs

1use either::Either;
2use hir::{CaptureKind, ClosureCapture, FileRangeWrapper, HirDisplay};
3use ide_db::{
4    FxHashSet, assists::AssistId, base_db::SourceDatabase, defs::Definition,
5    search::FileReferenceNode, source_change::SourceChangeBuilder,
6};
7use stdx::format_to;
8use syntax::{
9    AstNode, Direction, SyntaxKind, SyntaxNode, T, TextSize, ToSmolStr,
10    algo::{skip_trivia_token, skip_whitespace_token},
11    ast::{
12        self, HasArgList, HasGenericParams, HasName,
13        edit::{AstNodeEdit, IndentLevel},
14        syntax_factory::SyntaxFactory,
15    },
16    hacks::parse_expr_from_str,
17    syntax_editor::SyntaxEditor,
18};
19
20use crate::assist_context::{AssistContext, Assists};
21
22// Assist: convert_closure_to_fn
23//
24// This converts a closure to a freestanding function, changing all captures to parameters.
25//
26// ```
27// # //- minicore: copy, fn
28// # struct String;
29// # impl String {
30// #     fn new() -> Self {}
31// #     fn push_str(&mut self, s: &str) {}
32// # }
33// fn main() {
34//     let mut s = String::new();
35//     let closure = |$0a| s.push_str(a);
36//     closure("abc");
37// }
38// ```
39// ->
40// ```
41// # struct String;
42// # impl String {
43// #     fn new() -> Self {}
44// #     fn push_str(&mut self, s: &str) {}
45// # }
46// fn main() {
47//     let mut s = String::new();
48//     fn closure(a: &str, s: &mut String) {
49//         s.push_str(a)
50//     }
51//     closure("abc", &mut s);
52// }
53// ```
54pub(crate) fn convert_closure_to_fn(acc: &mut Assists, ctx: &AssistContext<'_, '_>) -> Option<()> {
    // Turns the closure under the cursor into a nested `fn` item: every capture becomes an
    // extra parameter, capture usages inside the body are rewritten to use those parameters,
    // and call sites get the captured places appended as arguments. Returns `None` (assist not
    // offered) when the innermost expression at the cursor is not the closure itself, or when
    // any of the semantic/syntactic lookups below fail.
55    let closure = ctx.find_node_at_offset::<ast::ClosureExpr>()?;
56    if ctx.find_node_at_offset::<ast::Expr>() != Some(ast::Expr::ClosureExpr(closure.clone())) {
57        // Not inside the parameter list.
58        return None;
59    }
    // When the closure is the direct initializer of `let <ident> = |..| ..;`, remember the
    // `let` statement, its identifier pattern and the bound name: the whole `let` is later
    // replaced by a `fn` with that name. Otherwise this is `None` and `fun_name` is used.
60    let closure_name = closure.syntax().parent().and_then(|parent| {
61        let closure_decl = ast::LetStmt::cast(parent)?;
62        match closure_decl.pat()? {
63            ast::Pat::IdentPat(pat) => Some((closure_decl, pat.clone(), pat.name()?)),
64            _ => None,
65        }
66    });
67
    // Body edits (capture rewrites, removal of `async`/`gen` block keywords) are recorded in
    // this editor; `source_root` is the tree that capture-usage pointers are resolved against.
68    let (editor, source_root) = SyntaxEditor::new(ctx.source_file().syntax().clone());
69    let make = editor.make();
70
71    let module = ctx.sema.scope(closure.syntax())?.module();
72    let closure_ty = ctx.sema.type_of_expr(&closure.clone().into())?;
73    let callable = closure_ty.original.as_callable(ctx.db())?;
74    let closure_ty = closure_ty.original.as_closure()?;
75
    // Generic parameters mentioned by the return, parameter and (later) capture types decide
    // the generic parameter list of the generated fn.
76    let mut ret_ty = callable.return_type();
77    let mut closure_mentioned_generic_params = ret_ty.generic_params(ctx.db());
78
    // The closure's own parameters come first. Parameters without a type annotation get the
    // inferred type written out (`_` if it cannot be displayed). If any parameter's source or
    // pattern is unavailable, the whole assist bails out via the `collect` + `?`.
79    let mut params = callable
80        .params()
81        .into_iter()
82        .map(|param| {
83            let node = ctx.sema.source(param.clone())?.value.right()?;
84            let param_ty = param.ty();
85            closure_mentioned_generic_params.extend(param_ty.generic_params(ctx.db()));
86            match node.ty() {
87                Some(_) => Some(node),
88                None => {
89                    let ty = param_ty
90                        .display_source_code(ctx.db(), module.into(), true)
91                        .unwrap_or_else(|_| "_".to_owned());
92                    Some(make.param(node.pat()?, make.ty(&ty)))
93                }
94            }
95        })
96        .collect::<Option<Vec<_>>>()?;
    // A captured `self` is inserted at this index so it precedes the other capture parameters.
97    let capture_params_start = params.len();
98
99    let closure_param_list = syntax::AstPtr::new(&closure.param_list()?);
100    let body = closure.body()?;
101    let mut is_gen = false;
    // An `async` closure returns a future; the generated `async fn` declares the future's
    // output type instead.
102    let mut is_async = closure.async_token().is_some();
103    if is_async {
104        ret_ty = ret_ty.future_output(ctx.db())?;
105    }
106    // We defer the wrapping of the body in the block, because `make::block()` will generate a new node,
107    // but we need to locate `AstPtr`s inside the body.
108    let mut wrap_body_in_block = true;
109    if let ast::Expr::BlockExpr(block) = &body {
        // `|| async { .. }`: delete the `async` keyword (and the whitespace after it) from the
        // block and make the fn itself `async`.
110        if let Some(async_token) = block.async_token()
111            && !is_async
112        {
113            is_async = true;
114            ret_ty = ret_ty.future_output(ctx.db())?;
115            let end = async_token
116                .siblings_with_tokens(Direction::Next)
117                .skip(1)
118                .take_while(|it| it.kind() == SyntaxKind::WHITESPACE)
119                .last()
120                .unwrap_or_else(|| async_token.clone().into());
121            editor.delete_all(async_token.into()..=end);
122        }
        // Likewise `|| gen { .. }` becomes a `gen fn` returning the iterator's item type.
123        if let Some(gen_token) = block.gen_token() {
124            is_gen = true;
125            ret_ty = ret_ty.iterator_item(ctx.db())?;
126            let end = gen_token
127                .siblings_with_tokens(Direction::Next)
128                .skip(1)
129                .take_while(|it| it.kind() == SyntaxKind::WHITESPACE)
130                .last()
131                .unwrap_or_else(|| gen_token.clone().into());
132            editor.delete_all(gen_token.into()..=end);
133        }
134
        // A plain block can serve as the fn body directly; `try`, `unsafe`, labeled and
        // `const` blocks must stay nested inside a new block.
135        if block.try_block_modifier().is_none()
136            && block.unsafe_token().is_none()
137            && block.label().is_none()
138            && block.const_token().is_none()
139        {
140            wrap_body_in_block = false;
141        }
142    };
143
144    acc.add(
145        AssistId::refactor_rewrite("convert_closure_to_fn"),
146        "Convert closure to fn",
147        closure.param_list()?.syntax().text_range(),
148        |builder| {
149            let make = editor.make();
150            let closure_name_or_default = closure_name
151                .as_ref()
152                .map(|(_, _, it)| it.clone())
153                .unwrap_or_else(|| make.name("fun_name"));
            // Each captured place yields one new parameter and one call-site argument.
154            let captures = closure_ty.captured_items(ctx.db());
155            let capture_tys =
156                captures.iter().map(|capture| capture.captured_ty(ctx.db())).collect::<Vec<_>>();
157
158            let mut captures_as_args = Vec::with_capacity(captures.len());
159
160            // We need to defer this work because otherwise the text range of elements is being messed up, and
161            // replacements for the next captures won't work.
162            let mut capture_usages_replacement_map = Vec::with_capacity(captures.len());
163
164            for (capture, capture_ty) in std::iter::zip(&captures, &capture_tys) {
                // A capture of plain `self` (no field projection) becomes a parameter named
                // `this`; any other capture is named after its captured place.
165                // FIXME: Allow configuring the replacement of `self`.
166                let is_self = capture.local().is_self(ctx.db()) && !capture.has_field_projections();
167                let capture_name = if is_self {
168                    make.name("this")
169                } else {
170                    make.name(&capture.place_to_name(ctx.db(), ctx.edition()))
171                };
172
173                closure_mentioned_generic_params.extend(capture_ty.generic_params(ctx.db()));
174
175                let capture_ty = capture_ty
176                    .display_source_code(ctx.db(), module.into(), true)
177                    .unwrap_or_else(|_| "_".to_owned());
178                let param = make.param(
179                    ast::Pat::IdentPat(make.ident_pat(false, false, capture_name.clone())),
180                    make.ty(&capture_ty),
181                );
182                if is_self {
183                    // Always put `this` first.
184                    params.insert(capture_params_start, param);
185                } else {
186                    params.push(param);
187                }
188
                // Queue a rewrite for every usage of this capture inside the closure body.
189                for capture_usage in capture.usages().sources(ctx.db()) {
190                    if capture_usage.file_id() != ctx.file_id() {
191                        // This is from a macro, don't change it.
192                        continue;
193                    }
194
195                    let capture_usage_source = capture_usage.source();
196                    let capture_usage_source = capture_usage_source.to_node(&source_root);
197                    let mut expr = match capture_usage_source {
198                        Either::Left(expr) => expr,
199                        Either::Right(pat) => {
200                            let Some(expr) = expr_of_pat(pat) else { continue };
201                            expr
202                        }
203                    };
                    // Unless the usage is flagged `is_ref`, a directly enclosing `&`/`&mut`
                    // is folded into the replaced expression (see `peel_ref`).
204                    if !capture_usage.is_ref() {
205                        expr = peel_ref(expr);
206                    }
207                    let replacement = wrap_capture_in_deref_if_needed(
208                        make,
209                        &expr,
210                        &capture_name,
211                        capture.kind(),
212                        matches!(expr, ast::Expr::RefExpr(_)) || capture_usage.is_ref(),
213                    );
214                    capture_usages_replacement_map.push((expr, replacement));
215                }
216
                // The call-site argument for this capture; `self` goes first to match the
                // parameter order above.
217                let capture_as_arg = capture_as_arg(make, ctx, capture);
218                if is_self {
219                    captures_as_args.insert(0, capture_as_arg);
220                } else {
221                    captures_as_args.push(capture_as_arg);
222                }
223            }
224
225            let (closure_type_params, closure_where_clause) =
226                compute_closure_type_params(make, ctx, closure_mentioned_generic_params, &closure);
227
228            for (old, new) in capture_usages_replacement_map {
229                editor.replace(old.syntax(), new.syntax());
230            }
231
            // Re-locate the (now edited) closure body in the editor's output tree through
            // the saved pointer to the closure's parameter list.
232            let body = closure_param_list
233                .to_node(editor.finish().new_root())
234                .syntax()
235                .parent()
236                .and_then(ast::ClosureExpr::cast)
237                .and_then(|closure| closure.body())
238                .unwrap();
239
            // NOTE(review): a mapping-free factory is used from here on, presumably because
            // the edited body no longer belongs to `editor`'s tree — confirm.
240            let make = SyntaxFactory::without_mappings();
241
242            let body = if wrap_body_in_block {
243                make.block_expr([], Some(body.reset_indent().indent(1.into())))
244            } else {
245                ast::BlockExpr::cast(body.syntax().clone()).unwrap()
246            };
247
            // A unit return type is omitted from the signature.
248            let params = make.param_list(None, params);
249            let ret_ty = if ret_ty.is_unit() {
250                None
251            } else {
252                let ret_ty = ret_ty
253                    .display_source_code(ctx.db(), module.into(), true)
254                    .unwrap_or_else(|_| "_".to_owned());
255                Some(make.ret_type(make.ty(&ret_ty)))
256            };
257            let mut fn_ = make.fn_(
258                None,
259                None,
260                closure_name_or_default.clone(),
261                closure_type_params,
262                closure_where_clause,
263                params,
264                body,
265                ret_ty,
266                is_async,
267                false,
268                false,
269                is_gen,
270            );
            // Normalize the fn's indentation (using the indent level of its last token); it
            // is re-indented for its final position below.
271            fn_ = fn_.dedent(IndentLevel::from_token(&fn_.syntax().last_token().unwrap()));
272
            // Named closure: replace the whole `let` statement with the fn.
            // Unnamed closure: replace the closure with the fn's name and insert the fn before
            // the nearest enclosing statement, into an enclosing closure's body, or before the
            // tail expression of an enclosing block — whichever ancestor is found first.
273            match &closure_name {
274                Some((closure_decl, _, _)) => {
275                    fn_ = fn_.indent(closure_decl.indent_level());
276                    builder.replace(closure_decl.syntax().text_range(), fn_.to_string());
277                }
278                None => {
279                    let Some(top_stmt) =
280                        closure.syntax().ancestors().skip(1).find_map(|ancestor| {
281                            ast::Stmt::cast(ancestor.clone()).map(Either::Left).or_else(|| {
282                                ast::ClosureExpr::cast(ancestor.clone())
283                                    .map(Either::Left)
284                                    .or_else(|| ast::BlockExpr::cast(ancestor).map(Either::Right))
285                                    .map(Either::Right)
286                            })
287                        })
288                    else {
289                        return;
290                    };
                    // NOTE(review): the closure is replaced before the insertion point below is
                    // validated, so the early `return`s in the closure/block arms leave a
                    // reference to a fn that is never inserted.
291                    builder.replace(
292                        closure.syntax().text_range(),
293                        closure_name_or_default.to_string(),
294                    );
295                    match top_stmt {
                        // Insert the fn on its own line right before the statement, i.e. after
                        // the last non-whitespace token preceding it.
296                        Either::Left(stmt) => {
297                            let indent = stmt.indent_level();
298                            fn_ = fn_.indent(indent);
299                            let range = stmt
300                                .syntax()
301                                .first_token()
302                                .and_then(|token| {
303                                    skip_whitespace_token(token.prev_token()?, Direction::Prev)
304                                })
305                                .map(|it| it.text_range().end())
306                                .unwrap_or_else(|| stmt.syntax().text_range().start());
307                            builder.insert(range, format!("\n{indent}{fn_}"));
308                        }
                        // Wrap the enclosing closure's body in a block that starts with the fn.
309                        Either::Right(Either::Left(closure_inside_closure)) => {
310                            let Some(closure_body) = closure_inside_closure.body() else { return };
311                            // FIXME: Maybe we can indent this properly, adding newlines and all, but this is hard.
312                            builder.insert(
313                                closure_body.syntax().text_range().start(),
314                                format!("{{ {fn_} "),
315                            );
316                            builder
317                                .insert(closure_body.syntax().text_range().end(), " }".to_owned());
318                        }
                        // Insert the fn on its own line before the block's tail expression.
319                        Either::Right(Either::Right(block_expr)) => {
320                            let Some(tail_expr) = block_expr.tail_expr() else { return };
321                            let Some(insert_in) =
322                                tail_expr.syntax().first_token().and_then(|token| {
323                                    skip_whitespace_token(token.prev_token()?, Direction::Prev)
324                                })
325                            else {
326                                return;
327                            };
328                            let indent = tail_expr.indent_level();
329                            fn_ = fn_.indent(indent);
330                            builder
331                                .insert(insert_in.text_range().end(), format!("\n{indent}{fn_}"));
332                        }
333                    }
334                }
335            }
336
            // Append the capture arguments at every call of the closure.
337            handle_calls(
338                builder,
339                ctx,
340                closure_name.as_ref().map(|(_, it, _)| it),
341                &captures_as_args,
342                &closure,
343            );
344
345            // FIXME: Place the cursor at `fun_name`, like rename does.
346        },
347    )?;
348    Some(())
349}
350
351fn compute_closure_type_params(
352    make: &SyntaxFactory,
353    ctx: &AssistContext<'_, '_>,
354    mentioned_generic_params: FxHashSet<hir::GenericParam>,
355    closure: &ast::ClosureExpr,
356) -> (Option<ast::GenericParamList>, Option<ast::WhereClause>) {
357    if mentioned_generic_params.is_empty() {
358        return (None, None);
359    }
360
361    let mut mentioned_names = mentioned_generic_params
362        .iter()
363        .filter_map(|param| match param {
364            hir::GenericParam::TypeParam(param) => Some(param.name(ctx.db()).as_str().to_smolstr()),
365            hir::GenericParam::ConstParam(param) => {
366                Some(param.name(ctx.db()).as_str().to_smolstr())
367            }
368            hir::GenericParam::LifetimeParam(_) => None,
369        })
370        .collect::<FxHashSet<_>>();
371
372    let Some((container_params, container_where, container)) =
373        closure.syntax().ancestors().find_map(ast::AnyHasGenericParams::cast).and_then(
374            |container| {
375                Some((container.generic_param_list()?, container.where_clause(), container))
376            },
377        )
378    else {
379        return (None, None);
380    };
381    let containing_impl = if ast::AssocItem::can_cast(container.syntax().kind()) {
382        container
383            .syntax()
384            .ancestors()
385            .find_map(ast::Impl::cast)
386            .and_then(|impl_| Some((impl_.generic_param_list()?, impl_.where_clause())))
387    } else {
388        None
389    };
390
391    let all_params = container_params
392        .type_or_const_params()
393        .chain(containing_impl.iter().flat_map(|(param_list, _)| param_list.type_or_const_params()))
394        .filter_map(|param| Some(param.name()?.text().to_smolstr()))
395        .collect::<FxHashSet<_>>();
396
397    // A fixpoint algorithm to detect (very roughly) if we need to include a generic parameter
398    // by checking if it is mentioned by another parameter we need to include.
399    let mut reached_fixpoint = false;
400    let mut container_where_bounds_indices = Vec::new();
401    let mut impl_where_bounds_indices = Vec::new();
402    while !reached_fixpoint {
403        reached_fixpoint = true;
404
405        let mut insert_name = |syntax: &SyntaxNode| {
406            let has_name = syntax
407                .descendants()
408                .filter_map(ast::NameOrNameRef::cast)
409                .any(|name| mentioned_names.contains(name.text().trim_start_matches("r#")));
410            let mut has_new_params = false;
411            if has_name {
412                syntax
413                    .descendants()
414                    .filter_map(ast::NameOrNameRef::cast)
415                    .filter(|name| all_params.contains(name.text().trim_start_matches("r#")))
416                    .for_each(|name| {
417                        if mentioned_names.insert(name.text().trim_start_matches("r#").to_smolstr())
418                        {
419                            // We do this here so we don't do it if there are only matches that are not in `all_params`.
420                            has_new_params = true;
421                            reached_fixpoint = false;
422                        }
423                    });
424            }
425            has_new_params
426        };
427
428        for param in container_params.type_or_const_params() {
429            insert_name(param.syntax());
430        }
431        for (pred_index, pred) in container_where.iter().flat_map(|it| it.predicates()).enumerate()
432        {
433            if insert_name(pred.syntax()) {
434                container_where_bounds_indices.push(pred_index);
435            }
436        }
437        if let Some((impl_params, impl_where)) = &containing_impl {
438            for param in impl_params.type_or_const_params() {
439                insert_name(param.syntax());
440            }
441            for (pred_index, pred) in impl_where.iter().flat_map(|it| it.predicates()).enumerate() {
442                if insert_name(pred.syntax()) {
443                    impl_where_bounds_indices.push(pred_index);
444                }
445            }
446        }
447    }
448
449    // Order matters here (for beauty). First the outer impl parameters, then the direct container's.
450    let include_params = containing_impl
451        .iter()
452        .flat_map(|(impl_params, _)| {
453            impl_params.type_or_const_params().filter(|param| {
454                param.name().is_some_and(|name| {
455                    mentioned_names.contains(name.text().trim_start_matches("r#"))
456                })
457            })
458        })
459        .chain(container_params.type_or_const_params().filter(|param| {
460            param
461                .name()
462                .is_some_and(|name| mentioned_names.contains(name.text().trim_start_matches("r#")))
463        }))
464        .map(ast::TypeOrConstParam::into);
465    let include_where_bounds = containing_impl
466        .as_ref()
467        .and_then(|(_, it)| it.as_ref())
468        .into_iter()
469        .flat_map(|where_| {
470            impl_where_bounds_indices.iter().filter_map(|&index| where_.predicates().nth(index))
471        })
472        .chain(container_where.iter().flat_map(|where_| {
473            container_where_bounds_indices
474                .iter()
475                .filter_map(|&index| where_.predicates().nth(index))
476        }))
477        .collect::<Vec<_>>();
478    let where_clause =
479        (!include_where_bounds.is_empty()).then(|| make.where_clause(include_where_bounds));
480
481    // FIXME: Consider generic parameters that do not appear in params/return type/captures but
482    // written explicitly inside the closure.
483    (Some(make.generic_param_list(include_params)), where_clause)
484}
485
486fn peel_parens(mut expr: ast::Expr) -> ast::Expr {
487    loop {
488        if ast::ParenExpr::can_cast(expr.syntax().kind()) {
489            let Some(parent) = expr.syntax().parent().and_then(ast::Expr::cast) else { break };
490            expr = parent;
491        } else {
492            break;
493        }
494    }
495    expr
496}
497
498fn peel_ref(mut expr: ast::Expr) -> ast::Expr {
499    expr = peel_parens(expr);
500    expr.syntax().parent().and_then(ast::RefExpr::cast).map(Into::into).unwrap_or(expr)
501}
502
503fn wrap_capture_in_deref_if_needed(
504    make: &SyntaxFactory,
505    expr: &ast::Expr,
506    capture_name: &ast::Name,
507    capture_kind: CaptureKind,
508    is_ref: bool,
509) -> ast::Expr {
510    let capture_name = make.expr_path(make.path_from_text(&capture_name.text()));
511    if capture_kind == CaptureKind::Move || is_ref {
512        return capture_name;
513    }
514    let expr_parent = expr.syntax().parent().and_then(ast::Expr::cast);
515    let expr_parent_peeled_parens = expr_parent.map(peel_parens);
516    let does_autoderef = match expr_parent_peeled_parens {
517        Some(
518            ast::Expr::AwaitExpr(_)
519            | ast::Expr::CallExpr(_)
520            | ast::Expr::FieldExpr(_)
521            | ast::Expr::FormatArgsExpr(_)
522            | ast::Expr::MethodCallExpr(_),
523        ) => true,
524        Some(ast::Expr::IndexExpr(parent_expr)) if parent_expr.base().as_ref() == Some(expr) => {
525            true
526        }
527        _ => false,
528    };
529    if does_autoderef {
530        return capture_name;
531    }
532    make.expr_prefix(T![*], capture_name).into()
533}
534
535fn capture_as_arg(
536    make: &SyntaxFactory,
537    ctx: &AssistContext<'_, '_>,
538    capture: &ClosureCapture<'_>,
539) -> ast::Expr {
540    let place = parse_expr_from_str(
541        &capture.display_place_source_code(ctx.db(), ctx.edition()),
542        ctx.edition(),
543    )
544    .expect("`display_place_source_code()` produced an invalid expr");
545    let needs_mut = match capture.kind() {
546        CaptureKind::SharedRef => false,
547        CaptureKind::MutableRef | CaptureKind::UniqueSharedRef => true,
548        CaptureKind::Move => return place,
549    };
550    if let ast::Expr::PrefixExpr(expr) = &place
551        && expr.op_kind() == Some(ast::UnaryOp::Deref)
552    {
553        return expr.expr().expect("`display_place_source_code()` produced an invalid expr");
554    }
555    make.expr_ref(place, needs_mut)
556}
557
558fn handle_calls(
559    builder: &mut SourceChangeBuilder,
560    ctx: &AssistContext<'_, '_>,
561    closure_name: Option<&ast::IdentPat>,
562    captures_as_args: &[ast::Expr],
563    closure: &ast::ClosureExpr,
564) {
565    if captures_as_args.is_empty() {
566        return;
567    }
568
569    match closure_name {
570        Some(closure_name) => {
571            let Some(closure_def) = ctx.sema.to_def(closure_name) else { return };
572            let closure_usages = Definition::from(closure_def).usages(&ctx.sema).all();
573            for (_, usages) in closure_usages {
574                for usage in usages {
575                    let name = match usage.name {
576                        FileReferenceNode::Name(name) => name.syntax().clone(),
577                        FileReferenceNode::NameRef(name_ref) => name_ref.syntax().clone(),
578                        FileReferenceNode::FormatStringEntry(..) => continue,
579                        FileReferenceNode::Lifetime(_) => {
580                            unreachable!("impossible usage")
581                        }
582                    };
583                    let Some(expr) = name.parent().and_then(|it| {
584                        ast::Expr::cast(
585                            ast::PathSegment::cast(it)?.parent_path().syntax().parent()?,
586                        )
587                    }) else {
588                        continue;
589                    };
590                    handle_call(builder, ctx, expr, captures_as_args);
591                }
592            }
593        }
594        None => {
595            handle_call(builder, ctx, ast::Expr::ClosureExpr(closure.clone()), captures_as_args);
596        }
597    }
598}
599
600fn handle_call(
601    builder: &mut SourceChangeBuilder,
602    ctx: &AssistContext<'_, '_>,
603    closure_ref: ast::Expr,
604    captures_as_args: &[ast::Expr],
605) -> Option<()> {
606    let call =
607        ast::CallExpr::cast(peel_blocks_and_refs_and_parens(closure_ref).syntax().parent()?)?;
608    let args = call.arg_list()?;
609    // The really last token is `)`; we need one before that.
610    let has_trailing_comma = args.syntax().last_token()?.prev_token().is_some_and(|token| {
611        skip_trivia_token(token, Direction::Prev).is_some_and(|token| token.kind() == T![,])
612    });
613    let has_existing_args = args.args().next().is_some();
614
615    let FileRangeWrapper { file_id, range } = ctx.sema.original_range_opt(args.syntax())?;
616    let first_arg_indent = args.args().next().map(|it| it.indent_level());
617    let arg_list_indent = args.indent_level();
618    let insert_newlines =
619        first_arg_indent.is_some_and(|first_arg_indent| first_arg_indent != arg_list_indent);
620    let indent =
621        if insert_newlines { first_arg_indent.unwrap().to_string() } else { String::new() };
622    // FIXME: This text manipulation seems risky.
623    let text = ctx.db().file_text(file_id.file_id(ctx.db())).text(ctx.db());
624    let mut text = text[..u32::from(range.end()).try_into().unwrap()].trim_end();
625    if !text.ends_with(')') {
626        return None;
627    }
628    text = text[..text.len() - 1].trim_end();
629    let offset = TextSize::new(text.len().try_into().unwrap());
630
631    let mut to_insert = String::new();
632    if has_existing_args && !has_trailing_comma {
633        to_insert.push(',');
634    }
635    if insert_newlines {
636        to_insert.push('\n');
637    }
638    let (last_arg, rest_args) =
639        captures_as_args.split_last().expect("already checked has captures");
640    if !insert_newlines && has_existing_args {
641        to_insert.push(' ');
642    }
643    if let Some((first_arg, rest_args)) = rest_args.split_first() {
644        format_to!(to_insert, "{indent}{first_arg},",);
645        if insert_newlines {
646            to_insert.push('\n');
647        }
648        for new_arg in rest_args {
649            if !insert_newlines {
650                to_insert.push(' ');
651            }
652            format_to!(to_insert, "{indent}{new_arg},",);
653            if insert_newlines {
654                to_insert.push('\n');
655            }
656        }
657        if !insert_newlines {
658            to_insert.push(' ');
659        }
660    }
661    format_to!(to_insert, "{indent}{last_arg}");
662    if has_trailing_comma {
663        to_insert.push(',');
664    }
665
666    builder.edit_file(file_id.file_id(ctx.db()));
667    builder.insert(offset, to_insert);
668
669    Some(())
670}
671
672fn peel_blocks_and_refs_and_parens(mut expr: ast::Expr) -> ast::Expr {
673    loop {
674        let Some(parent) = expr.syntax().parent() else { break };
675        if matches!(parent.kind(), SyntaxKind::PAREN_EXPR | SyntaxKind::REF_EXPR) {
676            expr = ast::Expr::cast(parent).unwrap();
677            continue;
678        }
679        if let Some(stmt_list) = ast::StmtList::cast(parent)
680            && let Some(block) = stmt_list.syntax().parent().and_then(ast::BlockExpr::cast)
681        {
682            expr = ast::Expr::BlockExpr(block);
683            continue;
684        }
685        break;
686    }
687    expr
688}
689
690// FIXME:
691// Somehow handle the case of `let Struct { field, .. } = capture`.
692// Replacing `capture` with `capture_field` won't work.
693fn expr_of_pat(pat: ast::Pat) -> Option<ast::Expr> {
694    'find_expr: {
695        for ancestor in pat.syntax().ancestors() {
696            if let Some(let_stmt) = ast::LetStmt::cast(ancestor.clone()) {
697                break 'find_expr let_stmt.initializer();
698            }
699            if ast::MatchArm::can_cast(ancestor.kind())
700                && let Some(match_) =
701                    ancestor.parent().and_then(|it| it.parent()).and_then(ast::MatchExpr::cast)
702            {
703                break 'find_expr match_.expr();
704            }
705            if ast::ExprStmt::can_cast(ancestor.kind()) {
706                break;
707            }
708        }
709        None
710    }
711}
712
713#[cfg(test)]
714mod tests {
715    use crate::tests::{check_assist, check_assist_not_applicable};
716
717    use super::*;
718
719    #[test]
720    fn handles_unique_captures() {
721        check_assist(
722            convert_closure_to_fn,
723            r#"
724//- minicore: copy, fn
725fn main() {
726    let s = &mut true;
727    let closure = |$0| { *s = false; };
728    closure();
729}
730"#,
731            r#"
732fn main() {
733    let s = &mut true;
734    fn closure(s: &mut bool) { *s = false; }
735    closure(s);
736}
737"#,
738        );
739    }
740
741    #[test]
742    fn multiple_capture_usages() {
743        check_assist(
744            convert_closure_to_fn,
745            r#"
746//- minicore: copy, fn
747struct A { a: i32, b: bool }
748fn main() {
749    let mut a = A { a: 123, b: false };
750    let closure = |$0| {
751        let b = a.b;
752        a = A { a: 456, b: true };
753    };
754    closure();
755}
756"#,
757            r#"
758struct A { a: i32, b: bool }
759fn main() {
760    let mut a = A { a: 123, b: false };
761    fn closure(a: &mut A) {
762        let b = a.b;
763        *a = A { a: 456, b: true };
764    }
765    closure(&mut a);
766}
767"#,
768        );
769    }
770
771    #[test]
772    fn changes_names_of_place() {
773        check_assist(
774            convert_closure_to_fn,
775            r#"
776//- minicore: copy, fn
777struct A { b: &'static mut B, c: i32 }
778struct B(bool, i32);
779struct C;
780impl C {
781    fn foo(&self) {
782        let a = A { b: &B(false, 0), c: 123 };
783        let closure = |$0| {
784            let b = a.b.1;
785            let c = &*self;
786        };
787        closure();
788    }
789}
790"#,
791            r#"
792struct A { b: &'static mut B, c: i32 }
793struct B(bool, i32);
794struct C;
795impl C {
796    fn foo(&self) {
797        let a = A { b: &B(false, 0), c: 123 };
798        fn closure(this: &C, a_b_1: &i32) {
799            let b = *a_b_1;
800            let c = this;
801        }
802        closure(self, &a.b.1);
803    }
804}
805"#,
806        );
807    }
808
809    #[test]
810    fn self_with_fields_does_not_change_to_this() {
811        check_assist(
812            convert_closure_to_fn,
813            r#"
814//- minicore: copy, fn
815struct A { b: &'static B, c: i32 }
816struct B(bool, i32);
817impl A {
818    fn foo(&self) {
819        let closure = |$0| {
820            let b = self.b.1;
821        };
822        closure();
823    }
824}
825"#,
826            r#"
827struct A { b: &'static B, c: i32 }
828struct B(bool, i32);
829impl A {
830    fn foo(&self) {
831        fn closure(self_b: &B) {
832            let b = self_b.1;
833        }
834        closure(self.b);
835    }
836}
837"#,
838        );
839    }
840
841    #[test]
842    fn replaces_async_closure_with_async_fn() {
843        check_assist(
844            convert_closure_to_fn,
845            r#"
846//- minicore: copy, future, async_fn
847fn foo(&self) {
848    let closure = async |$0| 1;
849    closure();
850}
851"#,
852            r#"
853fn foo(&self) {
854    async fn closure() -> i32 {
855        1
856    }
857    closure();
858}
859"#,
860        );
861    }
862
863    #[test]
864    fn replaces_async_block_with_async_fn() {
865        check_assist(
866            convert_closure_to_fn,
867            r#"
868//- minicore: copy, future, fn
869fn foo() {
870    let closure = |$0| async { 1 };
871    closure();
872}
873"#,
874            r#"
875fn foo() {
876    async fn closure() -> i32 { 1 }
877    closure();
878}
879"#,
880        );
881    }
882
883    #[test]
884    #[ignore = "FIXME: we do not do type inference for gen blocks yet"]
885    fn replaces_gen_block_with_gen_fn() {
886        check_assist(
887            convert_closure_to_fn,
888            r#"
889//- minicore: copy, iterator
890//- /lib.rs edition:2024
891fn foo() {
892    let closure = |$0| gen {
893        yield 1;
894    };
895    closure();
896}
897"#,
898            r#"
899fn foo() {
900    gen fn closure() -> i32 {
901        yield 1;
902    }
903    closure();
904}
905"#,
906        );
907    }
908
909    #[test]
910    fn leaves_block_in_place() {
911        check_assist(
912            convert_closure_to_fn,
913            r#"
914//- minicore: copy, fn
915fn foo() {
916    let closure = |$0| {};
917    closure();
918}
919"#,
920            r#"
921fn foo() {
922    fn closure() {}
923    closure();
924}
925"#,
926        );
927    }
928
929    #[test]
930    fn wraps_in_block_if_needed() {
931        check_assist(
932            convert_closure_to_fn,
933            r#"
934//- minicore: copy, fn
935fn foo() {
936    let a = 1;
937    let closure = |$0| a;
938    closure();
939}
940"#,
941            r#"
942fn foo() {
943    let a = 1;
944    fn closure(a: &i32) -> i32 {
945        *a
946    }
947    closure(&a);
948}
949"#,
950        );
951        check_assist(
952            convert_closure_to_fn,
953            r#"
954//- minicore: copy, fn
955fn foo() {
956    let closure = |$0| 'label: {};
957    closure();
958}
959"#,
960            r#"
961fn foo() {
962    fn closure() {
963        'label: {}
964    }
965    closure();
966}
967"#,
968        );
969        check_assist(
970            convert_closure_to_fn,
971            r#"
972//- minicore: copy, fn
973fn foo() {
974    let closure = |$0| {
975        const { () }
976    };
977    closure();
978}
979"#,
980            r#"
981fn foo() {
982    fn closure() {
983        const { () }
984    }
985    closure();
986}
987"#,
988        );
989        check_assist(
990            convert_closure_to_fn,
991            r#"
992//- minicore: copy, fn
993fn foo() {
994    let closure = |$0| unsafe { };
995    closure();
996}
997"#,
998            r#"
999fn foo() {
1000    fn closure() {
1001        unsafe { }
1002    }
1003    closure();
1004}
1005"#,
1006        );
1007        check_assist(
1008            convert_closure_to_fn,
1009            r#"
1010//- minicore: copy, fn
1011fn foo() {
1012    {
1013        let closure = |$0| match () {
1014            () => {},
1015        };
1016        closure();
1017    }
1018}
1019"#,
1020            r#"
1021fn foo() {
1022    {
1023        fn closure() {
1024            match () {
1025                () => {},
1026            }
1027        }
1028        closure();
1029    }
1030}
1031"#,
1032        );
1033    }
1034
1035    #[test]
1036    fn closure_in_closure() {
1037        check_assist(
1038            convert_closure_to_fn,
1039            r#"
1040//- minicore: copy
1041fn foo() {
1042    let a = 1;
1043    || |$0| { let b = &a; };
1044}
1045"#,
1046            r#"
1047fn foo() {
1048    let a = 1;
1049    || { fn fun_name(a: &i32) { let b = a; } fun_name };
1050}
1051"#,
1052        );
1053    }
1054
1055    #[test]
1056    fn closure_in_block() {
1057        check_assist(
1058            convert_closure_to_fn,
1059            r#"
1060//- minicore: copy
1061fn foo() {
1062    {
1063        let a = 1;
1064        |$0| { let b = &a; }
1065    };
1066}
1067"#,
1068            r#"
1069fn foo() {
1070    {
1071        let a = 1;
1072        fn fun_name(a: &i32) { let b = a; }
1073        fun_name
1074    };
1075}
1076"#,
1077        );
1078    }
1079
1080    #[test]
1081    fn finds_pat_for_expr() {
1082        check_assist(
1083            convert_closure_to_fn,
1084            r#"
1085//- minicore: copy, fn
1086struct A { b: B }
1087struct B(bool, i32);
1088fn foo() {
1089    let mut a = A { b: B(true, 0) };
1090    let closure = |$0| {
1091        let A { b: B(_, ref mut c) } = a;
1092    };
1093    closure();
1094}
1095"#,
1096            r#"
1097struct A { b: B }
1098struct B(bool, i32);
1099fn foo() {
1100    let mut a = A { b: B(true, 0) };
1101    fn closure(a_b_1: &mut i32) {
1102        let A { b: B(_, ref mut c) } = *a_b_1;
1103    }
1104    closure(&mut a.b.1);
1105}
1106"#,
1107        );
1108    }
1109
1110    #[test]
1111    fn with_existing_params() {
1112        check_assist(
1113            convert_closure_to_fn,
1114            r#"
1115//- minicore: copy, fn
1116fn foo() {
1117    let (mut a, b) = (0.1, "abc");
1118    let closure = |$0p1: i32, p2: &mut bool| {
1119        a = 1.2;
1120        let c = b;
1121    };
1122    closure(0, &mut false);
1123}
1124"#,
1125            r#"
1126fn foo() {
1127    let (mut a, b) = (0.1, "abc");
1128    fn closure(p1: i32, p2: &mut bool, a: &mut f64, b: &&str) {
1129        *a = 1.2;
1130        let c = *b;
1131    }
1132    closure(0, &mut false, &mut a, &b);
1133}
1134"#,
1135        );
1136    }
1137
1138    #[test]
1139    fn with_existing_params_newlines() {
1140        check_assist(
1141            convert_closure_to_fn,
1142            r#"
1143//- minicore: copy, fn
1144fn foo() {
1145    let (mut a, b) = (0.1, "abc");
1146    let closure = |$0p1: i32, p2| {
1147        let _: &mut bool = p2;
1148        a = 1.2;
1149        let c = b;
1150    };
1151    closure(
1152        0,
1153        &mut false
1154    );
1155}
1156"#,
1157            r#"
1158fn foo() {
1159    let (mut a, b) = (0.1, "abc");
1160    fn closure(p1: i32, p2: &mut bool, a: &mut f64, b: &&str) {
1161        let _: &mut bool = p2;
1162        *a = 1.2;
1163        let c = *b;
1164    }
1165    closure(
1166        0,
1167        &mut false,
1168        &mut a,
1169        &b
1170    );
1171}
1172"#,
1173        );
1174    }
1175
1176    #[test]
1177    fn with_existing_params_trailing_comma() {
1178        check_assist(
1179            convert_closure_to_fn,
1180            r#"
1181//- minicore: copy, fn
1182fn foo() {
1183    let (mut a, b) = (0.1, "abc");
1184    let closure = |$0p1: i32, p2| {
1185        let _: &mut bool = p2;
1186        a = 1.2;
1187        let c = b;
1188    };
1189    closure(
1190        0,
1191        &mut false,
1192    );
1193}
1194"#,
1195            r#"
1196fn foo() {
1197    let (mut a, b) = (0.1, "abc");
1198    fn closure(p1: i32, p2: &mut bool, a: &mut f64, b: &&str) {
1199        let _: &mut bool = p2;
1200        *a = 1.2;
1201        let c = *b;
1202    }
1203    closure(
1204        0,
1205        &mut false,
1206        &mut a,
1207        &b,
1208    );
1209}
1210"#,
1211        );
1212    }
1213
1214    #[test]
1215    fn closure_using_generic_params() {
1216        check_assist(
1217            convert_closure_to_fn,
1218            r#"
1219//- minicore: copy, from
1220struct Foo<A, B, const C: usize>(A, B);
1221impl<A, B: From<A>, const C: usize> Foo<A, B, C> {
1222    fn foo<D, E, F, G>(a: A, b: D)
1223    where
1224        E: From<D>,
1225    {
1226        let closure = |$0c: F| {
1227            let a = B::from(a);
1228            let b = E::from(b);
1229        };
1230    }
1231}
1232"#,
1233            r#"
1234struct Foo<A, B, const C: usize>(A, B);
1235impl<A, B: From<A>, const C: usize> Foo<A, B, C> {
1236    fn foo<D, E, F, G>(a: A, b: D)
1237    where
1238        E: From<D>,
1239    {
1240        fn closure<A, B: From<A>, D, E, F>(c: F, a: A, b: D) where E: From<D> {
1241            let a = B::from(a);
1242            let b = E::from(b);
1243        }
1244    }
1245}
1246"#,
1247        );
1248    }
1249
1250    #[test]
1251    fn closure_in_stmt() {
1252        check_assist(
1253            convert_closure_to_fn,
1254            r#"
1255//- minicore: copy
1256fn bar(_: impl FnOnce() -> i32) {}
1257fn foo() {
1258    let a = 123;
1259    bar(|$0| a);
1260}
1261"#,
1262            r#"
1263fn bar(_: impl FnOnce() -> i32) {}
1264fn foo() {
1265    let a = 123;
1266    fn fun_name(a: &i32) -> i32 {
1267        *a
1268    }
1269    bar(fun_name);
1270}
1271"#,
1272        );
1273    }
1274
1275    #[test]
1276    fn unique_and_imm() {
1277        check_assist(
1278            convert_closure_to_fn,
1279            r#"
1280//- minicore: copy, fn
1281fn main() {
1282    let a = &mut true;
1283    let closure = |$0| {
1284        let b = &a;
1285        *a = false;
1286    };
1287    closure();
1288}
1289"#,
1290            r#"
1291fn main() {
1292    let a = &mut true;
1293    fn closure(a: &mut &mut bool) {
1294        let b = a;
1295        **a = false;
1296    }
1297    closure(&mut a);
1298}
1299"#,
1300        );
1301    }
1302
1303    #[test]
1304    fn only_applicable_in_param_list() {
1305        check_assist_not_applicable(
1306            convert_closure_to_fn,
1307            r#"
1308//- minicore:copy
1309fn main() {
1310    let closure = || { $0 };
1311}
1312"#,
1313        );
1314        check_assist_not_applicable(
1315            convert_closure_to_fn,
1316            r#"
1317//- minicore:copy
1318fn main() {
1319    let $0closure = || { };
1320}
1321"#,
1322        );
1323    }
1324}