ide_db/syntax_helpers/node_ext.rs

1//! Various helper functions to work with SyntaxNodes.
2use std::ops::ControlFlow;
3
4use itertools::Itertools;
5use parser::T;
6use span::Edition;
7use syntax::{
8    AstNode, AstToken, Preorder, RustLanguage, WalkEvent,
9    ast::{self, HasLoopBody, MacroCall, PathSegmentKind, VisibilityKind},
10};
11
12pub fn expr_as_name_ref(expr: &ast::Expr) -> Option<ast::NameRef> {
13    if let ast::Expr::PathExpr(expr) = expr {
14        let path = expr.path()?;
15        path.as_single_name_ref()
16    } else {
17        None
18    }
19}
20
21pub fn full_path_of_name_ref(name_ref: &ast::NameRef) -> Option<ast::Path> {
22    let mut ancestors = name_ref.syntax().ancestors();
23    let _ = ancestors.next()?; // skip self
24    let _ = ancestors.next().filter(|it| ast::PathSegment::can_cast(it.kind()))?; // skip self
25    ancestors.take_while(|it| ast::Path::can_cast(it.kind())).last().and_then(ast::Path::cast)
26}
27
28pub fn block_as_lone_tail(block: &ast::BlockExpr) -> Option<ast::Expr> {
29    block.statements().next().is_none().then(|| block.tail_expr()).flatten()
30}
31
32/// Preorder walk all the expression's child expressions.
33pub fn walk_expr(expr: &ast::Expr, cb: &mut dyn FnMut(ast::Expr)) {
34    preorder_expr(expr, &mut |ev| {
35        if let WalkEvent::Enter(expr) = ev {
36            cb(expr);
37        }
38        false
39    })
40}
41
42pub fn is_closure_or_blk_with_modif(expr: &ast::Expr) -> bool {
43    match expr {
44        ast::Expr::BlockExpr(block_expr) => {
45            matches!(
46                block_expr.modifier(),
47                Some(
48                    ast::BlockModifier::Async(_)
49                        | ast::BlockModifier::Try(_)
50                        | ast::BlockModifier::Const(_)
51                )
52            )
53        }
54        ast::Expr::ClosureExpr(_) => true,
55        _ => false,
56    }
57}
58
59/// Preorder walk all the expression's child expressions preserving events.
60/// If the callback returns true on an [`WalkEvent::Enter`], the subtree of the expression will be skipped.
61/// Note that the subtree may already be skipped due to the context analysis this function does.
62pub fn preorder_expr(start: &ast::Expr, cb: &mut dyn FnMut(WalkEvent<ast::Expr>) -> bool) {
63    preorder_expr_with_ctx_checker(start, &is_closure_or_blk_with_modif, cb);
64}
65
/// Preorder walk of `start` and its descendant expressions, like [`preorder_expr`], but with
/// a caller-supplied `check_ctx` deciding which nested expressions form a different context.
///
/// `cb` receives a [`WalkEvent::Enter`] for every expression entered, and a
/// [`WalkEvent::Leave`] for every expression node the underlying preorder leaves. After an
/// `Enter`, the expression's subtree is skipped if `cb` returned `true`. It is also skipped
/// if `check_ctx` returns `true` for that expression and it is not `start` itself, since
/// `start` is always walked.
///
/// Regardless of `check_ctx`, the walk never descends into:
/// - inner items,
/// - generic arguments,
/// - any child of a `let` statement other than its initializer and `else` branch.
pub fn preorder_expr_with_ctx_checker(
    start: &ast::Expr,
    check_ctx: &dyn Fn(&ast::Expr) -> bool,
    cb: &mut dyn FnMut(WalkEvent<ast::Expr>) -> bool,
) {
    let mut preorder = start.syntax().preorder();
    while let Some(event) = preorder.next() {
        let node = match event {
            WalkEvent::Enter(node) => node,
            WalkEvent::Leave(node) => {
                // Only expression nodes are reported on leave; other node kinds are ignored.
                if let Some(expr) = ast::Expr::cast(node) {
                    cb(WalkEvent::Leave(expr));
                }
                continue;
            }
        };
        if let Some(let_stmt) = node.parent().and_then(ast::LetStmt::cast)
            && let_stmt.initializer().map(|it| it.syntax() != &node).unwrap_or(true)
            && let_stmt.let_else().map(|it| it.syntax() != &node).unwrap_or(true)
        {
            // skipping potential const pat expressions in let statements: every child of the
            // `let` except its initializer and `else` branch (e.g. the pattern) is skipped
            preorder.skip_subtree();
            continue;
        }

        match ast::Stmt::cast(node.clone()) {
            // Don't skip subtree since we want to process the expression child next
            Some(ast::Stmt::ExprStmt(_)) | Some(ast::Stmt::LetStmt(_)) => (),
            // skip inner items which might have their own expressions
            Some(ast::Stmt::Item(_)) => preorder.skip_subtree(),
            None => {
                // skip const args, those expressions are a different context
                if ast::GenericArg::can_cast(node.kind()) {
                    preorder.skip_subtree();
                } else if let Some(expr) = ast::Expr::cast(node) {
                    // Computed before `cb` takes ownership of `expr`. The starting
                    // expression is never treated as a different context.
                    let is_different_context = check_ctx(&expr) && expr.syntax() != start.syntax();
                    let skip = cb(WalkEvent::Enter(expr));
                    if skip || is_different_context {
                        preorder.skip_subtree();
                    }
                }
            }
        }
    }
}
111
112/// Preorder walk all the expression's child patterns.
113pub fn walk_patterns_in_expr(start: &ast::Expr, cb: &mut dyn FnMut(ast::Pat)) {
114    let mut preorder = start.syntax().preorder();
115    while let Some(event) = preorder.next() {
116        let node = match event {
117            WalkEvent::Enter(node) => node,
118            WalkEvent::Leave(_) => continue,
119        };
120        match ast::Stmt::cast(node.clone()) {
121            Some(ast::Stmt::LetStmt(l)) => {
122                if let Some(pat) = l.pat() {
123                    _ = walk_pat(&pat, &mut |pat| {
124                        cb(pat);
125                        ControlFlow::<(), ()>::Continue(())
126                    });
127                }
128                if let Some(expr) = l.initializer() {
129                    walk_patterns_in_expr(&expr, cb);
130                }
131                preorder.skip_subtree();
132            }
133            // Don't skip subtree since we want to process the expression child next
134            Some(ast::Stmt::ExprStmt(_)) => (),
135            // skip inner items which might have their own patterns
136            Some(ast::Stmt::Item(_)) => preorder.skip_subtree(),
137            None => {
138                // skip const args, those are a different context
139                if ast::GenericArg::can_cast(node.kind()) {
140                    preorder.skip_subtree();
141                } else if let Some(expr) = ast::Expr::cast(node.clone()) {
142                    let is_different_context = match &expr {
143                        ast::Expr::BlockExpr(block_expr) => {
144                            matches!(
145                                block_expr.modifier(),
146                                Some(
147                                    ast::BlockModifier::Async(_)
148                                        | ast::BlockModifier::Try(_)
149                                        | ast::BlockModifier::Const(_)
150                                )
151                            )
152                        }
153                        ast::Expr::ClosureExpr(_) => true,
154                        _ => false,
155                    } && expr.syntax() != start.syntax();
156                    if is_different_context {
157                        preorder.skip_subtree();
158                    }
159                } else if let Some(pat) = ast::Pat::cast(node) {
160                    preorder.skip_subtree();
161                    _ = walk_pat(&pat, &mut |pat| {
162                        cb(pat);
163                        ControlFlow::<(), ()>::Continue(())
164                    });
165                }
166            }
167        }
168    }
169}
170
171/// Preorder walk all the pattern's sub patterns.
172pub fn walk_pat<T>(
173    pat: &ast::Pat,
174    cb: &mut dyn FnMut(ast::Pat) -> ControlFlow<T>,
175) -> ControlFlow<T> {
176    let mut preorder = pat.syntax().preorder();
177    while let Some(event) = preorder.next() {
178        let node = match event {
179            WalkEvent::Enter(node) => node,
180            WalkEvent::Leave(_) => continue,
181        };
182        let kind = node.kind();
183        match ast::Pat::cast(node) {
184            Some(pat @ ast::Pat::ConstBlockPat(_)) => {
185                preorder.skip_subtree();
186                cb(pat)?;
187            }
188            Some(pat) => {
189                cb(pat)?;
190            }
191            // skip const args
192            None if ast::GenericArg::can_cast(kind) => {
193                preorder.skip_subtree();
194            }
195            None => (),
196        }
197    }
198    ControlFlow::Continue(())
199}
200
201/// Preorder walk all the type's sub types.
202// FIXME: Make the control flow more proper
203pub fn walk_ty(ty: &ast::Type, cb: &mut dyn FnMut(ast::Type) -> bool) {
204    let mut preorder = ty.syntax().preorder();
205    while let Some(event) = preorder.next() {
206        let node = match event {
207            WalkEvent::Enter(node) => node,
208            WalkEvent::Leave(_) => continue,
209        };
210        let kind = node.kind();
211        match ast::Type::cast(node) {
212            Some(ty @ ast::Type::MacroType(_)) => {
213                preorder.skip_subtree();
214                cb(ty);
215            }
216            Some(ty) => {
217                if cb(ty) {
218                    preorder.skip_subtree();
219                }
220            }
221            // skip const args
222            None if ast::ConstArg::can_cast(kind) => {
223                preorder.skip_subtree();
224            }
225            None => (),
226        }
227    }
228}
229
230pub fn vis_eq(this: &ast::Visibility, other: &ast::Visibility) -> bool {
231    match (this.kind(), other.kind()) {
232        (VisibilityKind::In(this), VisibilityKind::In(other)) => {
233            stdx::iter_eq_by(this.segments(), other.segments(), |lhs, rhs| {
234                lhs.kind().zip(rhs.kind()).is_some_and(|it| match it {
235                    (PathSegmentKind::CrateKw, PathSegmentKind::CrateKw)
236                    | (PathSegmentKind::SelfKw, PathSegmentKind::SelfKw)
237                    | (PathSegmentKind::SuperKw, PathSegmentKind::SuperKw) => true,
238                    (PathSegmentKind::Name(lhs), PathSegmentKind::Name(rhs)) => {
239                        lhs.text() == rhs.text()
240                    }
241                    _ => false,
242                })
243            })
244        }
245        (VisibilityKind::PubSelf, VisibilityKind::PubSelf)
246        | (VisibilityKind::PubSuper, VisibilityKind::PubSuper)
247        | (VisibilityKind::PubCrate, VisibilityKind::PubCrate)
248        | (VisibilityKind::Pub, VisibilityKind::Pub) => true,
249        _ => false,
250    }
251}
252
253/// Returns the `let` only if there is exactly one (that is, `let pat = expr`
254/// or `((let pat = expr))`, but not `let pat = expr && expr` or `non_let_expr`).
255pub fn single_let(expr: ast::Expr) -> Option<ast::LetExpr> {
256    match expr {
257        ast::Expr::ParenExpr(expr) => expr.expr().and_then(single_let),
258        ast::Expr::LetExpr(expr) => Some(expr),
259        _ => None,
260    }
261}
262
263pub fn is_pattern_cond(expr: ast::Expr) -> bool {
264    match expr {
265        ast::Expr::BinExpr(expr)
266            if expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And)) =>
267        {
268            expr.lhs()
269                .map(is_pattern_cond)
270                .or_else(|| expr.rhs().map(is_pattern_cond))
271                .unwrap_or(false)
272        }
273        ast::Expr::ParenExpr(expr) => expr.expr().is_some_and(is_pattern_cond),
274        ast::Expr::LetExpr(_) => true,
275        _ => false,
276    }
277}
278
/// Calls `cb` on each expression inside `expr` that is at "tail position".
/// Does not walk into `break` or `return` expressions.
///
/// Tail positions are found by descending through:
/// - block tails,
/// - every `if` / `else if` / `else` branch,
/// - every `match` arm.
///
/// For `loop`/`while`/`for` expressions, the `break` expressions that exit them are
/// reported. For labeled blocks, the breaks exiting them are reported in addition to the
/// block's tail. `async`, `try` and `const` blocks are reported as a whole without
/// descending into them. Any other kind of expression is itself reported.
///
/// Note that modifying the tree while iterating it will cause undefined iteration which might
/// potentially result in an out of bounds panic.
pub fn for_each_tail_expr(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) {
    // A loop has no tail expression of its own; the `break`s that exit it play that role.
    let walk_loop = |cb: &mut dyn FnMut(&ast::Expr), label, body: Option<ast::BlockExpr>| {
        for_each_break_expr(label, body.and_then(|it| it.stmt_list()), &mut |b| {
            cb(&ast::Expr::BreakExpr(b))
        })
    };
    match expr {
        ast::Expr::BlockExpr(b) => {
            match b.modifier() {
                // These blocks are a context of their own: report the block itself.
                Some(
                    ast::BlockModifier::Async(_)
                    | ast::BlockModifier::Try(_)
                    | ast::BlockModifier::Const(_),
                ) => return cb(expr),

                // `break 'label value` exits a labeled block, so such breaks are tail
                // positions too (in addition to the block's own tail handled below).
                Some(ast::BlockModifier::Label(label)) => {
                    for_each_break_expr(Some(label), b.stmt_list(), &mut |b| {
                        cb(&ast::Expr::BreakExpr(b))
                    });
                }
                Some(ast::BlockModifier::Unsafe(_)) => (),
                Some(ast::BlockModifier::Gen(_)) => (),
                Some(ast::BlockModifier::AsyncGen(_)) => (),
                None => (),
            }
            if let Some(stmt_list) = b.stmt_list()
                && let Some(e) = stmt_list.tail_expr()
            {
                for_each_tail_expr(&e, cb);
            }
        }
        ast::Expr::IfExpr(if_) => {
            // Walk the `if` / `else if` / `else` chain iteratively, recursing into each branch.
            let mut if_ = if_.clone();
            loop {
                if let Some(block) = if_.then_branch() {
                    for_each_tail_expr(&ast::Expr::BlockExpr(block), cb);
                }
                match if_.else_branch() {
                    Some(ast::ElseBranch::IfExpr(it)) => if_ = it,
                    Some(ast::ElseBranch::Block(block)) => {
                        for_each_tail_expr(&ast::Expr::BlockExpr(block), cb);
                        break;
                    }
                    None => break,
                }
            }
        }
        ast::Expr::LoopExpr(l) => walk_loop(cb, l.label(), l.loop_body()),
        ast::Expr::WhileExpr(w) => walk_loop(cb, w.label(), w.loop_body()),
        ast::Expr::ForExpr(f) => walk_loop(cb, f.label(), f.loop_body()),
        ast::Expr::MatchExpr(m) => {
            if let Some(arms) = m.match_arm_list() {
                arms.arms().filter_map(|arm| arm.expr()).for_each(|e| for_each_tail_expr(&e, cb));
            }
        }
        // Every remaining expression kind is itself the tail expression. The variants are
        // listed explicitly (no `_` arm) so that a newly added `ast::Expr` variant must be
        // classified here.
        ast::Expr::ArrayExpr(_)
        | ast::Expr::AwaitExpr(_)
        | ast::Expr::BinExpr(_)
        | ast::Expr::BreakExpr(_)
        | ast::Expr::CallExpr(_)
        | ast::Expr::CastExpr(_)
        | ast::Expr::ClosureExpr(_)
        | ast::Expr::ContinueExpr(_)
        | ast::Expr::FieldExpr(_)
        | ast::Expr::IndexExpr(_)
        | ast::Expr::Literal(_)
        | ast::Expr::MacroExpr(_)
        | ast::Expr::MethodCallExpr(_)
        | ast::Expr::ParenExpr(_)
        | ast::Expr::PathExpr(_)
        | ast::Expr::PrefixExpr(_)
        | ast::Expr::RangeExpr(_)
        | ast::Expr::RecordExpr(_)
        | ast::Expr::RefExpr(_)
        | ast::Expr::ReturnExpr(_)
        | ast::Expr::BecomeExpr(_)
        | ast::Expr::TryExpr(_)
        | ast::Expr::TupleExpr(_)
        | ast::Expr::LetExpr(_)
        | ast::Expr::UnderscoreExpr(_)
        | ast::Expr::YieldExpr(_)
        | ast::Expr::YeetExpr(_)
        | ast::Expr::OffsetOfExpr(_)
        | ast::Expr::FormatArgsExpr(_)
        | ast::Expr::AsmExpr(_) => cb(expr),
    }
}
370
371pub fn for_each_break_and_continue_expr(
372    label: Option<ast::Label>,
373    body: Option<ast::StmtList>,
374    cb: &mut dyn FnMut(ast::Expr),
375) {
376    let label = label.and_then(|lbl| lbl.lifetime());
377    if let Some(b) = body {
378        let tree_depth_iterator = TreeWithDepthIterator::new(b);
379        for (expr, depth) in tree_depth_iterator {
380            match expr {
381                ast::Expr::BreakExpr(b)
382                    if (depth == 0 && b.lifetime().is_none())
383                        || eq_label_lt(&label, &b.lifetime()) =>
384                {
385                    cb(ast::Expr::BreakExpr(b));
386                }
387                ast::Expr::ContinueExpr(c)
388                    if (depth == 0 && c.lifetime().is_none())
389                        || eq_label_lt(&label, &c.lifetime()) =>
390                {
391                    cb(ast::Expr::ContinueExpr(c));
392                }
393                _ => (),
394            }
395        }
396    }
397}
398
399fn for_each_break_expr(
400    label: Option<ast::Label>,
401    body: Option<ast::StmtList>,
402    cb: &mut dyn FnMut(ast::BreakExpr),
403) {
404    let label = label.and_then(|lbl| lbl.lifetime());
405    if let Some(b) = body {
406        let tree_depth_iterator = TreeWithDepthIterator::new(b);
407        for (expr, depth) in tree_depth_iterator {
408            match expr {
409                ast::Expr::BreakExpr(b)
410                    if (depth == 0 && b.lifetime().is_none())
411                        || eq_label_lt(&label, &b.lifetime()) =>
412                {
413                    cb(b);
414                }
415                _ => (),
416            }
417        }
418    }
419}
420
421pub fn eq_label_lt(lt1: &Option<ast::Lifetime>, lt2: &Option<ast::Lifetime>) -> bool {
422    lt1.as_ref().zip(lt2.as_ref()).is_some_and(|(lt, lbl)| lt.text() == lbl.text())
423}
424
425struct TreeWithDepthIterator {
426    preorder: Preorder<RustLanguage>,
427    depth: u32,
428}
429
430impl TreeWithDepthIterator {
431    fn new(body: ast::StmtList) -> Self {
432        let preorder = body.syntax().preorder();
433        Self { preorder, depth: 0 }
434    }
435}
436
437impl Iterator for TreeWithDepthIterator {
438    type Item = (ast::Expr, u32);
439
440    fn next(&mut self) -> Option<Self::Item> {
441        while let Some(event) = self.preorder.find_map(|ev| match ev {
442            WalkEvent::Enter(it) => ast::Expr::cast(it).map(WalkEvent::Enter),
443            WalkEvent::Leave(it) => ast::Expr::cast(it).map(WalkEvent::Leave),
444        }) {
445            match event {
446                WalkEvent::Enter(
447                    ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_),
448                ) => {
449                    self.depth += 1;
450                }
451                WalkEvent::Leave(
452                    ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_),
453                ) => {
454                    self.depth -= 1;
455                }
456                WalkEvent::Enter(ast::Expr::BlockExpr(e)) if e.label().is_some() => {
457                    self.depth += 1;
458                }
459                WalkEvent::Leave(ast::Expr::BlockExpr(e)) if e.label().is_some() => {
460                    self.depth -= 1;
461                }
462                WalkEvent::Enter(expr) => return Some((expr, self.depth)),
463                _ => (),
464            }
465        }
466        None
467    }
468}
469
470/// Parses the input token tree as comma separated plain paths.
471pub fn parse_tt_as_comma_sep_paths(
472    input: ast::TokenTree,
473    edition: Edition,
474) -> Option<Vec<ast::Path>> {
475    let r_paren = input.r_paren_token();
476    let tokens =
477        input.syntax().children_with_tokens().skip(1).map_while(|it| match it.into_token() {
478            // seeing a keyword means the attribute is unclosed so stop parsing here
479            Some(tok) if tok.kind().is_keyword(edition) => None,
480            // don't include the right token tree parenthesis if it exists
481            tok @ Some(_) if tok == r_paren => None,
482            // only nodes that we can find are other TokenTrees, those are unexpected in this parse though
483            None => None,
484            Some(tok) => Some(tok),
485        });
486    let input_expressions = tokens.chunk_by(|tok| tok.kind() == T![,]);
487    let paths = input_expressions
488        .into_iter()
489        .filter_map(|(is_sep, group)| (!is_sep).then_some(group))
490        .filter_map(|mut tokens| {
491            syntax::hacks::parse_expr_from_str(&tokens.join(""), Edition::CURRENT).and_then(
492                |expr| match expr {
493                    ast::Expr::PathExpr(it) => it.path(),
494                    _ => None,
495                },
496            )
497        })
498        .collect();
499    Some(paths)
500}
501
502pub fn macro_call_for_string_token(string: &ast::String) -> Option<MacroCall> {
503    let macro_call = string.syntax().parent_ancestors().find_map(ast::MacroCall::cast)?;
504    Some(macro_call)
505}