Skip to main content

ide_assists/
utils.rs

1//! Assorted functions shared by several assists.
2
3use std::slice;
4
5pub(crate) use gen_trait_fn_body::gen_trait_fn_body;
6use hir::{
7    HasAttrs as HirHasAttrs, HirDisplay, InFile, ModuleDef, PathResolution, Semantics,
8    db::{ExpandDatabase, HirDatabase},
9};
10use ide_db::{
11    RootDatabase,
12    assists::ExprFillDefaultMode,
13    famous_defs::FamousDefs,
14    path_transform::PathTransform,
15    syntax_helpers::{node_ext::preorder_expr, prettify_macro_expansion},
16};
17use itertools::Itertools;
18use syntax::{
19    AstNode, AstToken, Direction, NodeOrToken, SourceFile,
20    SyntaxKind::*,
21    SyntaxNode, SyntaxToken, T, TextRange, TextSize, WalkEvent,
22    ast::{
23        self, HasArgList, HasAttrs, HasGenericParams, HasName, HasTypeBounds, Whitespace,
24        edit::{AstNodeEdit, AttrsOwnerEdit, IndentLevel},
25        make,
26        prec::ExprPrecedence,
27        syntax_factory::SyntaxFactory,
28    },
29    syntax_editor::{Element, Removable, SyntaxEditor},
30};
31
32use crate::{
33    AssistConfig,
34    assist_context::{AssistContext, SourceChangeBuilder},
35};
36
37mod gen_trait_fn_body;
38pub(crate) mod ref_field_expr;
39
40pub(crate) fn unwrap_trivial_block(block_expr: ast::BlockExpr) -> ast::Expr {
41    extract_trivial_expression(&block_expr)
42        .filter(|expr| !expr.syntax().text().contains_char('\n'))
43        .unwrap_or_else(|| block_expr.into())
44}
45
46pub fn extract_trivial_expression(block_expr: &ast::BlockExpr) -> Option<ast::Expr> {
47    if block_expr.modifier().is_some() {
48        return None;
49    }
50    let stmt_list = block_expr.stmt_list()?;
51    let has_anything_else = |thing: &SyntaxNode| -> bool {
52        let mut non_trivial_children =
53            stmt_list.syntax().children_with_tokens().filter(|it| match it.kind() {
54                WHITESPACE | T!['{'] | T!['}'] => false,
55                _ => it.as_node() != Some(thing),
56            });
57        non_trivial_children.next().is_some()
58    };
59    if stmt_list
60        .syntax()
61        .children_with_tokens()
62        .filter_map(NodeOrToken::into_token)
63        .any(|token| token.kind() == syntax::SyntaxKind::COMMENT)
64    {
65        return None;
66    }
67
68    if let Some(expr) = stmt_list.tail_expr() {
69        if has_anything_else(expr.syntax()) {
70            return None;
71        }
72        return Some(expr);
73    }
74    // Unwrap `{ continue; }`
75    let stmt = stmt_list.statements().next()?;
76    if let ast::Stmt::ExprStmt(expr_stmt) = stmt {
77        if has_anything_else(expr_stmt.syntax()) {
78            return None;
79        }
80        let expr = expr_stmt.expr()?;
81        if matches!(expr.syntax().kind(), CONTINUE_EXPR | BREAK_EXPR | RETURN_EXPR) {
82            return Some(expr);
83        }
84    }
85    None
86}
87
88pub(crate) fn wrap_block(expr: &ast::Expr, make: &SyntaxFactory) -> ast::BlockExpr {
89    if let ast::Expr::BlockExpr(block) = expr
90        && let Some(first) = block.syntax().first_token()
91        && first.kind() == T!['{']
92    {
93        block.reset_indent()
94    } else {
95        make.block_expr(None, Some(expr.reset_indent().indent(1.into())))
96    }
97}
98
99pub(crate) fn wrap_paren(expr: ast::Expr, make: &SyntaxFactory, prec: ExprPrecedence) -> ast::Expr {
100    if expr.precedence().needs_parentheses_in(prec) { make.expr_paren(expr).into() } else { expr }
101}
102
103pub(crate) fn wrap_paren_in_call(expr: ast::Expr, make: &SyntaxFactory) -> ast::Expr {
104    if needs_parens_in_call(make, &expr) { make.expr_paren(expr).into() } else { expr }
105}
106
107fn needs_parens_in_call(make: &SyntaxFactory, param: &ast::Expr) -> bool {
108    let call = make.expr_call(make.expr_unit(), make.arg_list(Vec::new()));
109    let callable = call.expr().expect("invalid make call");
110    param.needs_parens_in_place_of(call.syntax(), callable.syntax())
111}
112
113/// This is a method with a heuristics to support test methods annotated with custom test annotations, such as
114/// `#[test_case(...)]`, `#[tokio::test]` and similar.
115/// Also a regular `#[test]` annotation is supported.
116///
117/// It may produce false positives, for example, `#[wasm_bindgen_test]` requires a different command to run the test,
118/// but it's better than not to have the runnables for the tests at all.
119pub fn test_related_attribute_syn(fn_def: &ast::Fn) -> Option<ast::Attr> {
120    fn_def.attrs().find_map(|attr| {
121        let path = attr.path()?;
122        let text = path.syntax().text().to_string();
123        if text.starts_with("test") || text.ends_with("test") { Some(attr) } else { None }
124    })
125}
126
/// Returns whether `attrs` mark their owner as a test, as reported by `is_test()`.
///
/// Unlike `test_related_attribute_syn`, this works on the HIR attributes of an item instead
/// of matching attribute path text.
pub fn has_test_related_attribute(attrs: &hir::AttrsWithOwner) -> bool {
    attrs.is_test()
}
130
/// Controls whether [`filter_assoc_items`] drops associated items marked doc-hidden.
#[derive(Clone, Copy, PartialEq)]
pub enum IgnoreAssocItems {
    /// Drop items whose attributes are doc-hidden, except functions without a body.
    DocHiddenAttrPresent,
    /// Keep doc-hidden items.
    No,
}
136
/// Selects which associated items [`filter_assoc_items`] keeps, based on whether an item
/// provides a default: a function body, a const value, or a type alias's type.
#[derive(Copy, Clone, PartialEq)]
pub enum DefaultMethods {
    /// Keep only items that provide a default.
    Only,
    /// Keep only items that do not provide a default.
    No,
}
142
143pub fn filter_assoc_items(
144    sema: &Semantics<'_, RootDatabase>,
145    items: &[hir::AssocItem],
146    default_methods: DefaultMethods,
147    ignore_items: IgnoreAssocItems,
148) -> Vec<InFile<ast::AssocItem>> {
149    return items
150        .iter()
151        .copied()
152        .filter(|assoc_item| {
153            if ignore_items == IgnoreAssocItems::DocHiddenAttrPresent
154                && assoc_item.attrs(sema.db).is_doc_hidden()
155            {
156                if let hir::AssocItem::Function(f) = assoc_item
157                    && !f.has_body(sema.db)
158                {
159                    return true;
160                }
161                return false;
162            }
163
164            true
165        })
166        // Note: This throws away items with no source.
167        .filter_map(|assoc_item| {
168            let item = match assoc_item {
169                hir::AssocItem::Function(it) => sema.source(it)?.map(ast::AssocItem::Fn),
170                hir::AssocItem::TypeAlias(it) => sema.source(it)?.map(ast::AssocItem::TypeAlias),
171                hir::AssocItem::Const(it) => sema.source(it)?.map(ast::AssocItem::Const),
172            };
173            Some(item)
174        })
175        .filter(has_def_name)
176        .filter(|it| match &it.value {
177            ast::AssocItem::Fn(def) => matches!(
178                (default_methods, def.body()),
179                (DefaultMethods::Only, Some(_)) | (DefaultMethods::No, None)
180            ),
181            ast::AssocItem::Const(def) => matches!(
182                (default_methods, def.body()),
183                (DefaultMethods::Only, Some(_)) | (DefaultMethods::No, None)
184            ),
185            ast::AssocItem::TypeAlias(def) => matches!(
186                (default_methods, def.ty()),
187                (DefaultMethods::Only, Some(_)) | (DefaultMethods::No, None)
188            ),
189            ast::AssocItem::MacroCall(_) => unreachable!(),
190        })
191        .collect();
192
193    fn has_def_name(item: &InFile<ast::AssocItem>) -> bool {
194        match &item.value {
195            ast::AssocItem::Fn(def) => def.name(),
196            ast::AssocItem::TypeAlias(def) => def.name(),
197            ast::AssocItem::Const(def) => def.name(),
198            ast::AssocItem::MacroCall(_) => None,
199        }
200        .is_some()
201    }
202}
203
/// Prepares `original_items` for use in `impl_` and returns the resulting items.
/// The items usually come from the trait definition, via [`filter_assoc_items()`].
///
/// Each item is processed as follows:
/// - An item from a macro expansion is prettified. If the result casts back to an `AssocItem`,
///   it skips the next step (see the NOTE in the body).
/// - Otherwise the item's indentation is reset, its attributes and doc comments are removed,
///   and paths are rewritten with [`PathTransform::trait_impl`] when the item's scope is known.
/// - A function without a body gets a placeholder body: `make.expr_todo()` for the `Todo` and
///   `Default` modes of `config.expr_fill_default`, `make.expr_underscore()` for `Underscore`.
/// - A type alias loses its type bound list.
/// - An item whose rewritten tree does not cast back to an `AssocItem` is dropped.
/// - Finally every item is indented one level deeper than `impl_`.
///
/// Nothing is inserted into `impl_` here; it is only used for indentation and path
/// transformation, and the caller is responsible for placing the returned items.
#[must_use]
pub fn add_trait_assoc_items_to_impl(
    make: &SyntaxFactory,
    sema: &Semantics<'_, RootDatabase>,
    config: &AssistConfig,
    original_items: &[InFile<ast::AssocItem>],
    trait_: hir::Trait,
    impl_: &ast::Impl,
    target_scope: &hir::SemanticsScope<'_>,
) -> Vec<ast::AssocItem> {
    let new_indent_level = IndentLevel::from_node(impl_.syntax()) + 1;
    original_items
        .iter()
        .map(|InFile { file_id, value: original_item }| {
            let mut cloned_item = {
                if let Some(macro_file) = file_id.macro_file() {
                    let span_map = sema.db.expansion_span_map(macro_file);
                    let item_prettified = prettify_macro_expansion(
                        sema.db,
                        original_item.syntax().clone(),
                        span_map,
                        target_scope.krate().into(),
                    );
                    if let Some(formatted) = ast::AssocItem::cast(item_prettified) {
                        // NOTE(review): this `return` leaves the enclosing `map` closure, so a
                        // prettified macro item skips `reset_indent`, the path transformation
                        // and the attribute/doc removal below. Confirm that this is intended.
                        return formatted;
                    } else {
                        stdx::never!("formatted `AssocItem` could not be cast back to `AssocItem`");
                    }
                }
                original_item
            }
            .reset_indent();

            if let Some(source_scope) = sema.scope(original_item.syntax()) {
                // FIXME: Paths in nested macros are not handled well. See
                // `add_missing_impl_members::paths_in_nested_macro_should_get_transformed` test.
                let transform =
                    PathTransform::trait_impl(target_scope, &source_scope, trait_, impl_.clone());
                cloned_item = ast::AssocItem::cast(transform.apply(cloned_item.syntax())).unwrap();
            }
            // Strip attributes and doc comments from the copy.
            let (editor, cloned_item) = SyntaxEditor::with_ast_node(&cloned_item);
            cloned_item.remove_attrs_and_docs(&editor);
            ast::AssocItem::cast(editor.finish().new_root().clone()).unwrap()
        })
        .filter_map(|item| match item {
            // Required methods get a placeholder body.
            ast::AssocItem::Fn(fn_) if fn_.body().is_none() => {
                let (fn_editor, fn_) = SyntaxEditor::with_ast_node(&fn_);
                let fill_expr: ast::Expr = match config.expr_fill_default {
                    ExprFillDefaultMode::Todo | ExprFillDefaultMode::Default => make.expr_todo(),
                    ExprFillDefaultMode::Underscore => make.expr_underscore().into(),
                };
                let new_body = make.block_expr(None::<ast::Stmt>, Some(fill_expr));
                fn_.replace_or_insert_body(&fn_editor, new_body);
                let new_fn_ = fn_editor.finish().new_root().clone();
                ast::AssocItem::cast(new_fn_)
            }
            // Type aliases keep their name and generics but lose their `: Bounds` list.
            ast::AssocItem::TypeAlias(type_alias) => {
                let (type_alias_editor, type_alias) = SyntaxEditor::with_ast_node(&type_alias);
                if let Some(type_bound_list) = type_alias.type_bound_list() {
                    type_bound_list.remove(&type_alias_editor);
                };
                let type_alias = type_alias_editor.finish().new_root().clone();
                ast::AssocItem::cast(type_alias)
            }
            item => Some(item),
        })
        .map(|item| AstNodeEdit::indent(&item, new_indent_level))
        .collect()
}
277
278pub(crate) fn vis_offset(node: &SyntaxNode) -> TextSize {
279    node.children_with_tokens()
280        .find(|it| !matches!(it.kind(), WHITESPACE | COMMENT | ATTR))
281        .map(|it| it.text_range().start())
282        .unwrap_or_else(|| node.text_range().start())
283}
284
285pub(crate) fn invert_boolean_expression(make: &SyntaxFactory, expr: ast::Expr) -> ast::Expr {
286    invert_special_case(make, &expr).unwrap_or_else(|| make.expr_prefix(T![!], expr).into())
287}
288
289fn invert_special_case(make: &SyntaxFactory, expr: &ast::Expr) -> Option<ast::Expr> {
290    match expr {
291        ast::Expr::BinExpr(bin) => {
292            let op_kind = bin.op_kind()?;
293            let rev_kind = match op_kind {
294                ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated }) => {
295                    ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: !negated })
296                }
297                ast::BinaryOp::CmpOp(ast::CmpOp::Ord { ordering: ast::Ordering::Less, strict }) => {
298                    ast::BinaryOp::CmpOp(ast::CmpOp::Ord {
299                        ordering: ast::Ordering::Greater,
300                        strict: !strict,
301                    })
302                }
303                ast::BinaryOp::CmpOp(ast::CmpOp::Ord {
304                    ordering: ast::Ordering::Greater,
305                    strict,
306                }) => ast::BinaryOp::CmpOp(ast::CmpOp::Ord {
307                    ordering: ast::Ordering::Less,
308                    strict: !strict,
309                }),
310                // Parenthesize other expressions before prefixing `!`
311                _ => {
312                    return Some(
313                        make.expr_prefix(T![!], make.expr_paren(expr.clone()).into()).into(),
314                    );
315                }
316            };
317
318            Some(make.expr_bin(bin.lhs()?, rev_kind, bin.rhs()?).into())
319        }
320        ast::Expr::MethodCallExpr(mce) => {
321            let receiver = mce.receiver()?;
322            let method = mce.name_ref()?;
323            let arg_list = mce.arg_list()?;
324
325            let method = match method.text().as_str() {
326                "is_some" => "is_none",
327                "is_none" => "is_some",
328                "is_ok" => "is_err",
329                "is_err" => "is_ok",
330                _ => return None,
331            };
332
333            Some(make.expr_method_call(receiver, make.name_ref(method), arg_list).into())
334        }
335        ast::Expr::PrefixExpr(pe) if pe.op_kind()? == ast::UnaryOp::Not => match pe.expr()? {
336            ast::Expr::ParenExpr(parexpr) => parexpr.expr(),
337            _ => pe.expr(),
338        },
339        ast::Expr::Literal(lit) => match lit.kind() {
340            ast::LiteralKind::Bool(b) => match b {
341                true => Some(ast::Expr::Literal(make.expr_literal("false"))),
342                false => Some(ast::Expr::Literal(make.expr_literal("true"))),
343            },
344            _ => None,
345        },
346        _ => None,
347    }
348}
349
350pub(crate) fn insert_attributes(
351    before: impl Element,
352    editor: &SyntaxEditor,
353    attrs: impl IntoIterator<Item = ast::Attr>,
354) {
355    let make = editor.make();
356    let mut attrs = attrs.into_iter().peekable();
357    if attrs.peek().is_none() {
358        return;
359    }
360    let elem = before.syntax_element();
361    let indent = IndentLevel::from_element(&elem);
362    let whitespace = format!("\n{indent}");
363    let elements: Vec<syntax::SyntaxElement> = attrs
364        .flat_map(|attr| [attr.syntax().clone().into(), make.whitespace(&whitespace).into()])
365        .collect();
366    editor.insert_all(syntax::syntax_editor::Position::before(elem), elements);
367}
368
369pub(crate) fn next_prev() -> impl Iterator<Item = Direction> {
370    [Direction::Next, Direction::Prev].into_iter()
371}
372
373pub(crate) fn does_pat_match_variant(pat: &ast::Pat, var: &ast::Pat) -> bool {
374    let first_node_text = |pat: &ast::Pat| pat.syntax().first_child().map(|node| node.text());
375
376    let pat_head = match pat {
377        ast::Pat::IdentPat(bind_pat) => match bind_pat.pat() {
378            Some(p) => first_node_text(&p),
379            None => return pat.syntax().text() == var.syntax().text(),
380        },
381        pat => first_node_text(pat),
382    };
383
384    let var_head = first_node_text(var);
385
386    pat_head == var_head
387}
388
/// Checks whether `pat` is "nested or literal" by running
/// `check_pat_variant_nested_or_literal_with_depth` with an initial depth of 0.
/// See that function for the exact per-pattern rules.
pub(crate) fn does_pat_variant_nested_or_literal(
    ctx: &AssistContext<'_, '_>,
    pat: &ast::Pat,
) -> bool {
    check_pat_variant_nested_or_literal_with_depth(ctx, pat, 0)
}
395
396fn check_pat_variant_from_enum(ctx: &AssistContext<'_, '_>, pat: &ast::Pat) -> bool {
397    ctx.sema.type_of_pat(pat).is_none_or(|ty: hir::TypeInfo<'_>| {
398        ty.adjusted().as_adt().is_some_and(|adt| matches!(adt, hir::Adt::Enum(_)))
399    })
400}
401
/// Recursive worker for [`does_pat_variant_nested_or_literal`].
///
/// `depth_after_refutable` counts the record and tuple-struct patterns entered so far whose
/// type is an enum, or whose type is unknown (see `check_pat_variant_from_enum`). Once it
/// exceeds 1, the result is `true`.
///
/// Per pattern kind:
/// - `..`, `_` and `&pat`: `false`. The inner pattern of `&pat` is not inspected.
/// - literal, range, macro, path, box, deref, not-null and const-block patterns: `true`.
/// - `ident @ sub`: the result for `sub`. A plain binding without a sub-pattern is `false`.
/// - parenthesized patterns: the result for the inner pattern, or `true` if it is missing.
/// - tuple and or-patterns: `true` if any element or alternative is.
/// - record and tuple-struct patterns: `true` if any field is. Fields are checked one level
///   deeper when the pattern's type is an enum. A record pattern without a field list, or a
///   record field without a pattern, counts as `true`.
/// - slice patterns: `false` only for `[..]`. `[]` and every other slice pattern are `true`.
fn check_pat_variant_nested_or_literal_with_depth(
    ctx: &AssistContext<'_, '_>,
    pat: &ast::Pat,
    depth_after_refutable: usize,
) -> bool {
    // Already inside two enum-typed (refutable) struct patterns.
    if depth_after_refutable > 1 {
        return true;
    }

    match pat {
        ast::Pat::RestPat(_) | ast::Pat::WildcardPat(_) | ast::Pat::RefPat(_) => false,

        ast::Pat::LiteralPat(_)
        | ast::Pat::RangePat(_)
        | ast::Pat::MacroPat(_)
        | ast::Pat::PathPat(_)
        | ast::Pat::BoxPat(_)
        | ast::Pat::DerefPat(_)
        | ast::Pat::NotNull(_)
        | ast::Pat::ConstBlockPat(_) => true,

        ast::Pat::IdentPat(ident_pat) => ident_pat.pat().is_some_and(|pat| {
            check_pat_variant_nested_or_literal_with_depth(ctx, &pat, depth_after_refutable)
        }),
        ast::Pat::ParenPat(paren_pat) => paren_pat.pat().is_none_or(|pat| {
            check_pat_variant_nested_or_literal_with_depth(ctx, &pat, depth_after_refutable)
        }),
        ast::Pat::TuplePat(tuple_pat) => tuple_pat.fields().any(|pat| {
            check_pat_variant_nested_or_literal_with_depth(ctx, &pat, depth_after_refutable)
        }),
        ast::Pat::RecordPat(record_pat) => {
            // Entering an enum variant's record pattern counts as one refutable level.
            let adjusted_next_depth =
                depth_after_refutable + if check_pat_variant_from_enum(ctx, pat) { 1 } else { 0 };
            record_pat.record_pat_field_list().is_none_or(|pat| {
                pat.fields().any(|pat| {
                    pat.pat().is_none_or(|pat| {
                        check_pat_variant_nested_or_literal_with_depth(
                            ctx,
                            &pat,
                            adjusted_next_depth,
                        )
                    })
                })
            })
        }
        ast::Pat::OrPat(or_pat) => or_pat.pats().any(|pat| {
            check_pat_variant_nested_or_literal_with_depth(ctx, &pat, depth_after_refutable)
        }),
        ast::Pat::TupleStructPat(tuple_struct_pat) => {
            // Entering an enum variant's tuple-struct pattern counts as one refutable level.
            let adjusted_next_depth =
                depth_after_refutable + if check_pat_variant_from_enum(ctx, pat) { 1 } else { 0 };
            tuple_struct_pat.fields().any(|pat| {
                check_pat_variant_nested_or_literal_with_depth(ctx, &pat, adjusted_next_depth)
            })
        }
        ast::Pat::SlicePat(slice_pat) => {
            // Only `[..]` (a single rest pattern) is `false`; `[]` and anything else is `true`.
            let mut pats = slice_pat.pats();
            pats.next()
                .is_none_or(|pat| !matches!(pat, ast::Pat::RestPat(_)) || pats.next().is_some())
        }
    }
}
464
465pub(crate) fn expr_fill_default(config: &AssistConfig) -> ast::Expr {
466    let make = SyntaxFactory::without_mappings();
467    match config.expr_fill_default {
468        ExprFillDefaultMode::Todo => make.expr_todo(),
469        ExprFillDefaultMode::Underscore => make.expr_underscore().into(),
470        ExprFillDefaultMode::Default => make.expr_todo(),
471    }
472}
473
474// Uses a syntax-driven approach to find any impl blocks for the struct that
475// exist within the module/file
476//
477// Returns `None` if we've found an existing fn
478//
479// FIXME: change the new fn checking to a more semantic approach when that's more
480// viable (e.g. we process proc macros, etc)
481// FIXME: this partially overlaps with `find_impl_block_*`
482
483/// `find_struct_impl` looks for impl of a struct, but this also has additional feature
484/// where it takes a list of function names and check if they exist inside impl_, if
485/// even one match is found, it returns None.
486///
487/// That means this function can have 3 potential return values:
488///  - `None`: an impl exists, but one of the function names within the impl matches one of the provided names.
489///  - `Some(None)`: no impl exists.
490///  - `Some(Some(_))`: an impl exists, with no matching function names.
491pub(crate) fn find_struct_impl(
492    ctx: &AssistContext<'_, '_>,
493    adt: &ast::Adt,
494    names: &[String],
495) -> Option<Option<ast::Impl>> {
496    let db = ctx.db();
497    let module = adt.syntax().parent()?;
498
499    let struct_def = ctx.sema.to_def(adt)?;
500
501    let block = module.descendants().filter_map(ast::Impl::cast).find_map(|impl_blk| {
502        let blk = ctx.sema.to_def(&impl_blk)?;
503
504        // FIXME: handle e.g. `struct S<T>; impl<U> S<U> {}`
505        // (we currently use the wrong type parameter)
506        // also we wouldn't want to use e.g. `impl S<u32>`
507
508        let same_ty = match blk.self_ty(db).as_adt() {
509            Some(def) => def == struct_def,
510            None => false,
511        };
512        let not_trait_impl = blk.trait_(db).is_none();
513
514        if !(same_ty && not_trait_impl) { None } else { Some(impl_blk) }
515    });
516
517    if let Some(ref impl_blk) = block
518        && has_any_fn(impl_blk, names)
519    {
520        return None;
521    }
522
523    Some(block)
524}
525
526fn has_any_fn(imp: &ast::Impl, names: &[String]) -> bool {
527    if let Some(il) = imp.assoc_item_list() {
528        for item in il.assoc_items() {
529            if let ast::AssocItem::Fn(f) = item
530                && let Some(name) = f.name()
531                && names.iter().any(|n| n.eq_ignore_ascii_case(&name.text()))
532            {
533                return true;
534            }
535        }
536    }
537
538    false
539}
540
/// Generates the corresponding `impl Type {}` including lifetime, type and const
/// parameters, with `body` (if any) as the impl's item list.
pub(crate) fn generate_impl_with_item(
    make: &SyntaxFactory,
    adt: &ast::Adt,
    body: Option<ast::AssocItemList>,
) -> ast::Impl {
    generate_impl_inner(make, false, adt, None, true, body)
}

/// Generates the corresponding `impl Type {}` including lifetime, type and const
/// parameters, without an item list (`body` is passed as `None`).
pub(crate) fn generate_impl(make: &SyntaxFactory, adt: &ast::Adt) -> ast::Impl {
    generate_impl_inner(make, false, adt, None, true, None)
}

/// Generates the corresponding `impl <trait> for Type {}` including type
/// and lifetime parameters, with `<trait>` appended to `impl`'s generic parameters' bounds.
///
/// This is useful for traits like `PartialEq`, since `impl<T> PartialEq for U<T>` often requires `T: PartialEq`.
pub(crate) fn generate_trait_impl(
    make: &SyntaxFactory,
    is_unsafe: bool,
    adt: &ast::Adt,
    trait_: ast::Type,
) -> ast::Impl {
    generate_impl_inner(make, is_unsafe, adt, Some(trait_), true, None)
}

/// Generates the corresponding `impl <trait> for Type {}` including type
/// and lifetime parameters, with `impl`'s generic parameters' bounds kept as-is.
///
/// This is useful for traits like `From<T>`, since `impl<T> From<T> for U<T>` doesn't require `T: From<T>`.
pub(crate) fn generate_trait_impl_intransitive(
    make: &SyntaxFactory,
    adt: &ast::Adt,
    trait_: ast::Type,
) -> ast::Impl {
    generate_impl_inner(make, false, adt, Some(trait_), false, None)
}

/// Like [`generate_trait_impl_intransitive`], but with `body` as the impl's item list.
pub(crate) fn generate_trait_impl_intransitive_with_item(
    make: &SyntaxFactory,
    adt: &ast::Adt,
    trait_: ast::Type,
    body: ast::AssocItemList,
) -> ast::Impl {
    generate_impl_inner(make, false, adt, Some(trait_), false, Some(body))
}

/// Like [`generate_trait_impl`], but with `body` as the impl's item list.
pub(crate) fn generate_trait_impl_with_item(
    make: &SyntaxFactory,
    is_unsafe: bool,
    adt: &ast::Adt,
    trait_: ast::Type,
    body: ast::AssocItemList,
) -> ast::Impl {
    generate_impl_inner(make, is_unsafe, adt, Some(trait_), true, Some(body))
}
598
/// Shared implementation of the `generate_*impl*` helpers.
///
/// Builds an impl for `adt`: an inherent impl when `trait_` is `None`, otherwise
/// `impl trait_ for Type`.
/// - Generic parameters are copied from `adt`, lifetimes first, then type and const
///   parameters. Type and const parameters are rebuilt, which drops their defaults. A
///   parameter without a name, or a const parameter without a type, is skipped.
/// - When `trait_` is given and `trait_is_transitive` is `true`, every type parameter also
///   gets `trait_` as an extra bound.
/// - When `trait_` is given and `adt` has generic parameters, `where` predicates from
///   `generic_param_associated_bounds` are added. This happens regardless of
///   `trait_is_transitive`.
/// - Only `#[cfg]` attributes and the where clause of `adt` are carried over.
/// - `is_unsafe` is only passed on for trait impls.
/// - `body` becomes the impl's item list.
///
/// # Panics
///
/// Panics if `adt` has no name.
fn generate_impl_inner(
    make: &SyntaxFactory,
    is_unsafe: bool,
    adt: &ast::Adt,
    trait_: Option<ast::Type>,
    trait_is_transitive: bool,
    body: Option<ast::AssocItemList>,
) -> ast::Impl {
    // Ensure lifetime params are before type & const params
    let generic_params = adt.generic_param_list().map(|generic_params| {
        let lifetime_params =
            generic_params.lifetime_params().map(ast::GenericParam::LifetimeParam);
        let ty_or_const_params = generic_params.type_or_const_params().filter_map(|param| {
            let param = match param {
                ast::TypeOrConstParam::Type(param) => {
                    // remove defaults since they can't be specified in impls
                    // (the param is rebuilt from its name and bounds only)
                    let mut bounds =
                        param.type_bound_list().map_or_else(Vec::new, |it| it.bounds().collect());
                    if let Some(trait_) = &trait_ {
                        // Add the current trait to `bounds` if the trait is transitive,
                        // meaning `impl<T> Trait for U<T>` requires `T: Trait`.
                        if trait_is_transitive {
                            bounds.push(make.type_bound(trait_.clone()));
                        }
                    };
                    // `{ty_param}: {bounds}`
                    let param = make.type_param(param.name()?, make.type_bound_list(bounds));
                    ast::GenericParam::TypeParam(param)
                }
                ast::TypeOrConstParam::Const(param) => {
                    // remove defaults since they can't be specified in impls
                    let param = make.const_param(param.name()?, param.ty()?);
                    ast::GenericParam::ConstParam(param)
                }
            };
            Some(param)
        });

        make.generic_param_list(itertools::chain(lifetime_params, ty_or_const_params))
    });
    // The self type's generic arguments mirror the impl's own generic parameters.
    let generic_args = generic_params.as_ref().map(|params| params.to_generic_args(make));
    let adt_assoc_bounds = trait_
        .as_ref()
        .zip(generic_params.as_ref())
        .and_then(|(trait_, params)| generic_param_associated_bounds(make, adt, trait_, params));

    // Bare ADT name; `generic_args` are passed to the impl constructors separately.
    let ty: ast::Type = make.ty_path(make.ident_path(&adt.name().unwrap().text())).into();

    // Only `#[cfg]`-style attributes are copied from the ADT onto the impl.
    let cfg_attrs = adt.attrs().filter(|attr| matches!(attr.meta(), Some(ast::Meta::CfgMeta(_))));
    match trait_ {
        Some(trait_) => make.impl_trait(
            cfg_attrs,
            is_unsafe,
            None,
            None,
            generic_params,
            generic_args,
            false,
            trait_,
            ty,
            adt_assoc_bounds,
            adt.where_clause(),
            body,
        ),
        None => make.impl_(cfg_attrs, generic_params, generic_args, ty, adt.where_clause(), body),
    }
}
666
/// Collects `where` predicates of the form `T::Assoc: trait_` for `adt`.
///
/// The body of `adt` (enum variant list, struct field list or union record field list) is
/// scanned for paths whose qualifier is a single segment naming one of the type parameters in
/// `generic_params`. The parameter may be named directly (`T::Assoc`) or through a type anchor
/// (`<T>::Assoc`; any `as Trait` part of the anchor is not inspected). Each such path yields
/// the predicate `T::Assoc: trait_`, and duplicates (compared by text) are removed.
///
/// Returns `None` when no predicate was found.
fn generic_param_associated_bounds(
    make: &SyntaxFactory,
    adt: &ast::Adt,
    trait_: &ast::Type,
    generic_params: &ast::GenericParamList,
) -> Option<ast::WhereClause> {
    // Whether `name` is one of the *type* parameters (lifetimes and consts are ignored).
    let in_type_params = |name: &ast::NameRef| {
        generic_params
            .generic_params()
            .filter_map(|param| match param {
                ast::GenericParam::TypeParam(type_param) => type_param.name(),
                _ => None,
            })
            .any(|param| param.text() == name.text())
    };
    let adt_body = match adt {
        ast::Adt::Enum(e) => e.variant_list().map(|it| it.syntax().clone()),
        ast::Adt::Struct(s) => s.field_list().map(|it| it.syntax().clone()),
        ast::Adt::Union(u) => u.record_field_list().map(|it| it.syntax().clone()),
    };
    let mut trait_where_clause = adt_body
        .into_iter()
        .flat_map(|it| it.descendants())
        .filter_map(ast::Path::cast)
        .filter_map(|path| {
            // `T::Assoc` or `<T>::Assoc`, where `T` is a type parameter.
            let qualifier = path.qualifier()?.as_single_segment()?;
            let qualifier = qualifier
                .name_ref()
                .or_else(|| match qualifier.type_anchor()?.ty()? {
                    ast::Type::PathType(path_type) => path_type.path()?.as_single_name_ref(),
                    _ => None,
                })
                .filter(in_type_params)?;
            Some((qualifier, path.segment()?.name_ref()?))
        })
        .map(|(qualifier, assoc_name)| {
            // Build `T::Assoc: trait_`.
            let segments = [qualifier, assoc_name].map(|nr| make.path_segment(nr));
            let path = make.path_from_segments(segments, false);
            let bounds = [make.type_bound(trait_.clone())];
            make.where_pred(either::Either::Right(make.ty_path(path).into()), bounds)
        })
        .unique_by(|it| it.syntax().to_string())
        .peekable();
    trait_where_clause.peek().is_some().then(|| make.where_clause(trait_where_clause))
}
712
/// Describes how a value of type `ty` can be handed out by reference, e.g. by a generated getter.
#[derive(Debug)]
pub(crate) struct ReferenceConversion<'db> {
    /// Which conversion applies.
    conversion: ReferenceConversionType,
    /// The type before conversion. Its type arguments are used to render the converted type.
    ty: hir::Type<'db>,
    /// When set, getters borrow the field with `&`; otherwise they call `.as_ref()`.
    /// The name suggests this records whether the type implements `Deref`.
    impls_deref: bool,
}
719
/// The kinds of by-reference conversion understood by [`ReferenceConversion`].
#[derive(Debug)]
enum ReferenceConversionType {
    /// Reference can be stripped if the type is Copy.
    Copy,
    /// `&String` -> `&str`
    AsRefStr,
    /// `&Vec<T>` -> `&[T]`
    AsRefSlice,
    /// `&Box<T>` -> `&T`
    Dereferenced,
    /// `&Option<T>` -> `Option<&T>`
    Option,
    /// `&Result<T, E>` -> `Result<&T, &E>`
    Result,
}
735
736impl<'db> ReferenceConversion<'db> {
737    fn type_to_string(&self, db: &'db dyn HirDatabase, module: hir::Module) -> String {
738        match self.conversion {
739            ReferenceConversionType::Copy => self
740                .ty
741                .display_source_code(db, module.into(), true)
742                .unwrap_or_else(|_| "_".to_owned()),
743            ReferenceConversionType::AsRefStr => "&str".to_owned(),
744            ReferenceConversionType::AsRefSlice => {
745                let type_argument_name = self
746                    .ty
747                    .type_arguments()
748                    .next()
749                    .unwrap()
750                    .display_source_code(db, module.into(), true)
751                    .unwrap_or_else(|_| "_".to_owned());
752                format!("&[{type_argument_name}]")
753            }
754            ReferenceConversionType::Dereferenced => {
755                let type_argument_name = self
756                    .ty
757                    .type_arguments()
758                    .next()
759                    .unwrap()
760                    .display_source_code(db, module.into(), true)
761                    .unwrap_or_else(|_| "_".to_owned());
762                format!("&{type_argument_name}")
763            }
764            ReferenceConversionType::Option => {
765                let type_argument_name = self
766                    .ty
767                    .type_arguments()
768                    .next()
769                    .unwrap()
770                    .display_source_code(db, module.into(), true)
771                    .unwrap_or_else(|_| "_".to_owned());
772                format!("Option<&{type_argument_name}>")
773            }
774            ReferenceConversionType::Result => {
775                let mut type_arguments = self.ty.type_arguments();
776                let first_type_argument_name = type_arguments
777                    .next()
778                    .unwrap()
779                    .display_source_code(db, module.into(), true)
780                    .unwrap_or_else(|_| "_".to_owned());
781                let second_type_argument_name = type_arguments
782                    .next()
783                    .unwrap()
784                    .display_source_code(db, module.into(), true)
785                    .unwrap_or_else(|_| "_".to_owned());
786                format!("Result<&{first_type_argument_name}, &{second_type_argument_name}>")
787            }
788        }
789    }
790
791    pub(crate) fn convert_type(&self, db: &'db dyn HirDatabase, module: hir::Module) -> ast::Type {
792        let ty = self.type_to_string(db, module);
793        make::ty(&ty)
794    }
795
796    pub(crate) fn convert_type_with_factory(
797        &self,
798        make: &SyntaxFactory,
799        db: &'db dyn HirDatabase,
800        module: hir::Module,
801    ) -> ast::Type {
802        let ty = self.type_to_string(db, module);
803        make.ty(&ty)
804    }
805
806    pub(crate) fn getter(&self, make: &SyntaxFactory, field_name: String) -> ast::Expr {
807        let expr = make.expr_field(make.expr_self(), &field_name);
808
809        match self.conversion {
810            ReferenceConversionType::Copy => expr.into(),
811            ReferenceConversionType::AsRefStr
812            | ReferenceConversionType::AsRefSlice
813            | ReferenceConversionType::Dereferenced
814            | ReferenceConversionType::Option
815            | ReferenceConversionType::Result => {
816                if self.impls_deref {
817                    make.expr_ref(expr.into(), false)
818                } else {
819                    make.expr_method_call(expr.into(), make.name_ref("as_ref"), make.arg_list([]))
820                        .into()
821                }
822            }
823        }
824    }
825}
826
// FIXME: It should return a new hir::Type, but currently constructing new types is too cumbersome
//        and all users of this function operate on string type names, so they can do the
//        conversion themselves.
830pub(crate) fn convert_reference_type<'db>(
831    ty: hir::Type<'db>,
832    db: &'db RootDatabase,
833    famous_defs: &FamousDefs<'_, 'db>,
834) -> Option<ReferenceConversion<'db>> {
835    handle_copy(&ty, db)
836        .or_else(|| handle_as_ref_str(&ty, db, famous_defs))
837        .or_else(|| handle_as_ref_slice(&ty, db, famous_defs))
838        .or_else(|| handle_dereferenced(&ty, db, famous_defs))
839        .or_else(|| handle_option_as_ref(&ty, db, famous_defs))
840        .or_else(|| handle_result_as_ref(&ty, db, famous_defs))
841        .map(|(conversion, impls_deref)| ReferenceConversion { ty, conversion, impls_deref })
842}
843
844fn could_deref_to_target(ty: &hir::Type<'_>, target: &hir::Type<'_>, db: &dyn HirDatabase) -> bool {
845    let ty_ref = ty.add_reference(db, hir::Mutability::Shared);
846    let target_ref = target.add_reference(db, hir::Mutability::Shared);
847    ty_ref.could_coerce_to(db, &target_ref)
848}
849
850fn handle_copy(
851    ty: &hir::Type<'_>,
852    db: &dyn HirDatabase,
853) -> Option<(ReferenceConversionType, bool)> {
854    ty.is_copy(db).then_some((ReferenceConversionType::Copy, true))
855}
856
857fn handle_as_ref_str(
858    ty: &hir::Type<'_>,
859    db: &dyn HirDatabase,
860    famous_defs: &FamousDefs<'_, '_>,
861) -> Option<(ReferenceConversionType, bool)> {
862    let str_type = hir::BuiltinType::str().ty(db);
863
864    ty.impls_trait(db, famous_defs.core_convert_AsRef()?, slice::from_ref(&str_type))
865        .then_some((ReferenceConversionType::AsRefStr, could_deref_to_target(ty, &str_type, db)))
866}
867
868fn handle_as_ref_slice(
869    ty: &hir::Type<'_>,
870    db: &dyn HirDatabase,
871    famous_defs: &FamousDefs<'_, '_>,
872) -> Option<(ReferenceConversionType, bool)> {
873    let type_argument = ty.type_arguments().next()?;
874    let slice_type = hir::Type::new_slice(db, type_argument);
875
876    ty.impls_trait(db, famous_defs.core_convert_AsRef()?, slice::from_ref(&slice_type)).then_some((
877        ReferenceConversionType::AsRefSlice,
878        could_deref_to_target(ty, &slice_type, db),
879    ))
880}
881
882fn handle_dereferenced(
883    ty: &hir::Type<'_>,
884    db: &dyn HirDatabase,
885    famous_defs: &FamousDefs<'_, '_>,
886) -> Option<(ReferenceConversionType, bool)> {
887    let type_argument = ty.type_arguments().next()?;
888
889    ty.impls_trait(db, famous_defs.core_convert_AsRef()?, slice::from_ref(&type_argument))
890        .then_some((
891            ReferenceConversionType::Dereferenced,
892            could_deref_to_target(ty, &type_argument, db),
893        ))
894}
895
896fn handle_option_as_ref(
897    ty: &hir::Type<'_>,
898    db: &dyn HirDatabase,
899    famous_defs: &FamousDefs<'_, '_>,
900) -> Option<(ReferenceConversionType, bool)> {
901    if ty.as_adt() == famous_defs.core_option_Option()?.ty(db).as_adt() {
902        Some((ReferenceConversionType::Option, false))
903    } else {
904        None
905    }
906}
907
908fn handle_result_as_ref(
909    ty: &hir::Type<'_>,
910    db: &dyn HirDatabase,
911    famous_defs: &FamousDefs<'_, '_>,
912) -> Option<(ReferenceConversionType, bool)> {
913    if ty.as_adt() == famous_defs.core_result_Result()?.ty(db).as_adt() {
914        Some((ReferenceConversionType::Result, false))
915    } else {
916        None
917    }
918}
919
920pub(crate) fn get_methods(items: &ast::AssocItemList) -> Vec<ast::Fn> {
921    items
922        .assoc_items()
923        .flat_map(|i| match i {
924            ast::AssocItem::Fn(f) => Some(f),
925            _ => None,
926        })
927        .filter(|f| f.name().is_some())
928        .collect()
929}
930
931/// Trim(remove leading and trailing whitespace) `initial_range` in `source_file`, return the trimmed range.
932pub(crate) fn trimmed_text_range(source_file: &SourceFile, initial_range: TextRange) -> TextRange {
933    let mut trimmed_range = initial_range;
934    while source_file
935        .syntax()
936        .token_at_offset(trimmed_range.start())
937        .find_map(Whitespace::cast)
938        .is_some()
939        && trimmed_range.start() < trimmed_range.end()
940    {
941        let start = trimmed_range.start() + TextSize::from(1);
942        trimmed_range = TextRange::new(start, trimmed_range.end());
943    }
944    while source_file
945        .syntax()
946        .token_at_offset(trimmed_range.end())
947        .find_map(Whitespace::cast)
948        .is_some()
949        && trimmed_range.start() < trimmed_range.end()
950    {
951        let end = trimmed_range.end() - TextSize::from(1);
952        trimmed_range = TextRange::new(trimmed_range.start(), end);
953    }
954    trimmed_range
955}
956
957/// Convert a list of function params to a list of arguments that can be passed
958/// into a function call.
959pub(crate) fn convert_param_list_to_arg_list(
960    list: ast::ParamList,
961    make: &SyntaxFactory,
962) -> ast::ArgList {
963    let mut args = vec![];
964    for param in list.params() {
965        if let Some(ast::Pat::IdentPat(pat)) = param.pat()
966            && let Some(name) = pat.name()
967        {
968            let name = name.to_string();
969            let expr = make.expr_path(make.ident_path(&name));
970            args.push(expr);
971        }
972    }
973    make.arg_list(args)
974}
975
/// Returns the minimum number of `#`s a raw string literal needs in order to contain `s`.
///
/// A raw string with `n` hashes is terminated by `"` followed by `n` `#`s. So after every `"`
/// in `s` we need one more hash than the run of `#`s that follows it. Text without `"` needs 0.
pub(crate) fn required_hashes(s: &str) -> usize {
    s.match_indices('"')
        .map(|(quote_at, _)| s[quote_at + 1..].chars().take_while(|&c| c == '#').count() + 1)
        .max()
        .unwrap_or(0)
}
#[test]
fn test_required_hashes() {
    let cases: [(usize, &str); 7] = [
        (0, "abc"),
        (0, "###"),
        (1, "\""),
        (2, "\"#abc"),
        (0, "#abc"),
        (3, "#ab\"##c"),
        (5, "#ab\"##\"####c"),
    ];
    for (expected, input) in cases {
        assert_eq!(expected, required_hashes(input));
    }
}
996
/// Returns the suffix of a string or char literal: the text after its last `"`, `'` or `#`
/// (e.g. `"i32"` for `"abc"i32`). Returns `None` if `s` contains none of those delimiters.
pub(crate) fn string_suffix(s: &str) -> Option<&str> {
    let last_delimiter = s.rfind(|c: char| matches!(c, '"' | '\'' | '#'))?;
    // Every delimiter is a single byte, so `+ 1` stays on a char boundary.
    Some(&s[last_delimiter + 1..])
}
#[test]
fn test_string_suffix() {
    let cases: [(&str, &str); 6] = [
        ("", r#""abc""#),
        ("", r#""""#),
        ("a", r#"""a"#),
        ("i32", r#"""i32"#),
        ("i32", r#"r""i32"#),
        ("i32", r##"r#""#i32"##),
    ];
    for (expected, literal) in cases {
        assert_eq!(Some(expected), string_suffix(literal));
    }
}
1010
/// Returns the prefix of a string or char literal: the text before its first `"`, `'` or `#`
/// (e.g. `"cr"` for `cr"abc"`). Returns `None` if `s` contains none of those delimiters.
pub(crate) fn string_prefix(s: &str) -> Option<&str> {
    let first_delimiter = s.find(|c: char| matches!(c, '"' | '\'' | '#'))?;
    Some(&s[..first_delimiter])
}
#[test]
fn test_string_prefix() {
    let cases: [(&str, &str); 7] = [
        ("", r#""abc""#),
        ("", r#""""#),
        ("", r#"""suffix"#),
        ("c", r#"c"""#),
        ("r", r#"r"""#),
        ("cr", r#"cr"""#),
        ("r", r##"r#""#"##),
    ];
    for (expected, literal) in cases {
        assert_eq!(Some(expected), string_prefix(literal));
    }
}
1025
/// Re-inserts `_` separators into the digits of `s`, grouping `group_size` characters counted
/// from the right.
///
/// Existing `_`s are dropped first. No separator is placed directly after a leading `-`.
pub(crate) fn add_group_separators(s: &str, group_size: usize) -> String {
    // Build the output back to front so groups are counted from the least significant end.
    let mut reversed = String::with_capacity(s.len() * 2);
    let mut position = 0usize;
    for c in s.chars().rev() {
        if c == '_' {
            continue;
        }
        if position != 0 && position % group_size == 0 && c != '-' {
            reversed.push('_');
        }
        reversed.push(c);
        position += 1;
    }
    reversed.chars().rev().collect()
}
1037
1038/// Replaces the record expression, handling field shorthands including inside macros.
1039pub(crate) fn replace_record_field_expr(
1040    ctx: &AssistContext<'_, '_>,
1041    edit: &mut SourceChangeBuilder,
1042    record_field: ast::RecordExprField,
1043    initializer: ast::Expr,
1044) {
1045    if let Some(ast::Expr::PathExpr(path_expr)) = record_field.expr() {
1046        // replace field shorthand
1047        let file_range = ctx.sema.original_range(path_expr.syntax());
1048        edit.insert(file_range.range.end(), format!(": {}", initializer.syntax().text()))
1049    } else if let Some(expr) = record_field.expr() {
1050        // just replace expr
1051        let file_range = ctx.sema.original_range(expr.syntax());
1052        edit.replace(file_range.range, initializer.syntax().text());
1053    }
1054}
1055
/// Creates a token tree list from a syntax node, creating the needed delimited sub token trees.
/// Assumes that the input syntax node is a valid syntax tree.
///
/// Every token of `node` is visited in preorder.
/// - Ordinary tokens are copied into the current level.
/// - Each `(`/`{`/`[` ... closer pair becomes a nested `ast::TokenTree` built with
///   `make.token_tree`. The delimiter tokens themselves are not pushed as plain tokens.
///
/// # Panics
///
/// Panics if a closing delimiter has no matching opener (it would pop the top-level frame).
/// Mismatched pairs such as `(` closed by `]` are only reported through `stdx::always!`.
///
/// NOTE(review): if an opener is never closed, the final `pop` returns only the innermost
/// unclosed frame, so the tokens collected before it are dropped.
pub(crate) fn tt_from_syntax(
    node: SyntaxNode,
    make: &SyntaxFactory,
) -> Vec<NodeOrToken<ast::TokenTree, SyntaxToken>> {
    // Stack of (opening delimiter, collected children). The bottom frame has no delimiter and
    // holds the top-level result.
    let mut tt_stack = vec![(None, vec![])];

    for element in node.descendants_with_tokens() {
        // Only tokens matter; node structure is rebuilt from the delimiters.
        let NodeOrToken::Token(token) = element else { continue };

        match token.kind() {
            T!['('] | T!['{'] | T!['['] => {
                // Found an opening delimiter, start a new sub token tree
                tt_stack.push((Some(token.kind()), vec![]));
            }
            T![')'] | T!['}'] | T![']'] => {
                // Closing a subtree
                let (delimiter, tt) = tt_stack.pop().expect("unbalanced delimiters");
                let (_, parent_tt) = tt_stack
                    .last_mut()
                    .expect("parent token tree was closed before it was completed");
                // The closer expected for the frame's opener, used only for the sanity check below.
                let closing_delimiter = delimiter.map(|it| match it {
                    T!['('] => T![')'],
                    T!['{'] => T!['}'],
                    T!['['] => T![']'],
                    _ => unreachable!(),
                });
                stdx::always!(
                    closing_delimiter == Some(token.kind()),
                    "mismatched opening and closing delimiters"
                );

                let sub_tt = make.token_tree(delimiter.expect("unbalanced delimiters"), tt);
                parent_tt.push(NodeOrToken::Node(sub_tt));
            }
            _ => {
                let (_, current_tt) = tt_stack.last_mut().expect("unmatched delimiters");
                current_tt.push(NodeOrToken::Token(token))
            }
        }
    }

    // Return the children of the last remaining frame (the top-level frame when balanced).
    tt_stack.pop().expect("parent token tree was closed before it was completed").1
}
1101
/// Finds the operand of the `&&` chain `expr` that fully contains `range`.
///
/// `expr` is split on its `&&` operators by repeatedly taking the right-hand side and continuing
/// with the left-hand side, so operands are checked from right to left. The first operand whose
/// text range contains `range` is returned. An `expr` that is not an `&&` binary expression is
/// its own single operand.
///
/// Returns `None` when `range` is not inside `expr`, or when no single operand contains it
/// (e.g. the range straddles an `&&`).
pub(crate) fn cover_let_chain(mut expr: ast::Expr, range: TextRange) -> Option<ast::Expr> {
    if !expr.syntax().text_range().contains_range(range) {
        return None;
    }
    loop {
        // `chain_expr` is the operand to test now; `rest` is what remains to the left of it.
        let (chain_expr, rest) = if let ast::Expr::BinExpr(bin_expr) = &expr
            && bin_expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And))
        {
            (bin_expr.rhs(), bin_expr.lhs())
        } else {
            // Leftmost operand (or a non-`&&` expression): nothing remains after it.
            (Some(expr), None)
        };

        if let Some(chain_expr) = chain_expr
            && chain_expr.syntax().text_range().contains_range(range)
        {
            break Some(chain_expr);
        }
        // No operands left to try: the range is not covered by a single operand.
        expr = rest?;
    }
}
1123
1124pub(crate) fn cover_edit_range(
1125    source: &SyntaxNode,
1126    range: TextRange,
1127) -> std::ops::RangeInclusive<syntax::SyntaxElement> {
1128    let node = match source.covering_element(range) {
1129        NodeOrToken::Node(node) => node,
1130        NodeOrToken::Token(t) => t.parent().unwrap(),
1131    };
1132    let mut iter = node.children_with_tokens().filter(|it| range.contains_range(it.text_range()));
1133    let first = iter.next().unwrap_or(node.into());
1134    let last = iter.last().unwrap_or_else(|| first.clone());
1135    first..=last
1136}
1137
1138pub(crate) fn is_selected(
1139    it: &impl AstNode,
1140    selection: syntax::TextRange,
1141    allow_empty: bool,
1142) -> bool {
1143    selection.intersect(it.syntax().text_range()).is_some_and(|it| !it.is_empty())
1144        || allow_empty && it.syntax().text_range().contains_range(selection)
1145}
1146
/// Heuristically decides whether `expr` only uses constructs allowed in a const context.
///
/// `expr` is treated as non-const when it contains any of the following:
/// - a call to a resolved function that is not `const fn`;
/// - a method call that resolves to a non-const method;
/// - a `for`, `return`, `?`, `yield` or `.await` expression.
///
/// Calls that cannot be resolved are given the benefit of the doubt and count as const.
pub fn is_body_const(sema: &Semantics<'_, RootDatabase>, expr: &ast::Expr) -> bool {
    let mut is_const = true;
    // NOTE(review): the callback's `bool` result is presumably a "stop/skip descending" request
    // to `preorder_expr`; it returns `true` once a non-const construct has been seen. Confirm
    // against `preorder_expr`.
    preorder_expr(expr, &mut |ev| {
        let expr = match ev {
            WalkEvent::Enter(_) if !is_const => return true,
            WalkEvent::Enter(expr) => expr,
            WalkEvent::Leave(_) => return false,
        };
        match expr {
            ast::Expr::CallExpr(call) => {
                // Only direct calls of a path that resolves to a function are checked.
                if let Some(ast::Expr::PathExpr(path_expr)) = call.expr()
                    && let Some(PathResolution::Def(ModuleDef::Function(func))) =
                        path_expr.path().and_then(|path| sema.resolve_path(&path))
                {
                    is_const &= func.is_const(sema.db);
                }
            }
            ast::Expr::MethodCallExpr(call) => {
                // Unresolved method calls are assumed to be const.
                is_const &=
                    sema.resolve_method_call(&call).map(|it| it.is_const(sema.db)).unwrap_or(true)
            }
            ast::Expr::ForExpr(_)
            | ast::Expr::ReturnExpr(_)
            | ast::Expr::TryExpr(_)
            | ast::Expr::YieldExpr(_)
            | ast::Expr::AwaitExpr(_) => is_const = false,
            _ => (),
        }
        !is_const
    });
    is_const
}
1179
1180// FIXME: #20460 When hir-ty can analyze the `never` statement at the end of block, remove it
1181pub(crate) fn is_never_block(
1182    sema: &Semantics<'_, RootDatabase>,
1183    block_expr: &ast::BlockExpr,
1184) -> bool {
1185    if let Some(tail_expr) = block_expr.tail_expr() {
1186        sema.type_of_expr(&tail_expr).is_some_and(|ty| ty.original.is_never())
1187    } else if let Some(ast::Stmt::ExprStmt(expr_stmt)) = block_expr.statements().last()
1188        && let Some(expr) = expr_stmt.expr()
1189    {
1190        sema.type_of_expr(&expr).is_some_and(|ty| ty.original.is_never())
1191    } else {
1192        false
1193    }
1194}