hir_def/expr_store/
scope.rs

1//! Name resolution for expressions.
2use hir_expand::{MacroDefId, name::Name};
3use la_arena::{Arena, ArenaMap, Idx, IdxRange, RawIdx};
4use triomphe::Arc;
5
6use crate::{
7    BlockId, DefWithBodyId,
8    db::DefDatabase,
9    expr_store::{Body, ExpressionStore, HygieneId},
10    hir::{Binding, BindingId, Expr, ExprId, Item, LabelId, Pat, PatId, Statement},
11};
12
13pub type ScopeId = Idx<ScopeData>;
14
15#[derive(Debug, PartialEq, Eq)]
16pub struct ExprScopes {
17    scopes: Arena<ScopeData>,
18    scope_entries: Arena<ScopeEntry>,
19    scope_by_expr: ArenaMap<ExprId, ScopeId>,
20}
21
22#[derive(Debug, PartialEq, Eq)]
23pub struct ScopeEntry {
24    name: Name,
25    hygiene: HygieneId,
26    binding: BindingId,
27}
28
29impl ScopeEntry {
30    pub fn name(&self) -> &Name {
31        &self.name
32    }
33
34    pub(crate) fn hygiene(&self) -> HygieneId {
35        self.hygiene
36    }
37
38    pub fn binding(&self) -> BindingId {
39        self.binding
40    }
41}
42
43#[derive(Debug, PartialEq, Eq)]
44pub struct ScopeData {
45    parent: Option<ScopeId>,
46    block: Option<BlockId>,
47    label: Option<(LabelId, Name)>,
48    // FIXME: We can compress this with an enum for this and `label`/`block` if memory usage matters.
49    macro_def: Option<Box<MacroDefId>>,
50    entries: IdxRange<ScopeEntry>,
51}
52
53impl ExprScopes {
54    pub(crate) fn expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes> {
55        let body = db.body(def);
56        let mut scopes = ExprScopes::new_body(&body);
57        scopes.shrink_to_fit();
58        Arc::new(scopes)
59    }
60
61    pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
62        &self.scope_entries[self.scopes[scope].entries.clone()]
63    }
64
65    /// If `scope` refers to a block expression scope, returns the corresponding `BlockId`.
66    pub fn block(&self, scope: ScopeId) -> Option<BlockId> {
67        self.scopes[scope].block
68    }
69
70    /// If `scope` refers to a macro def scope, returns the corresponding `MacroId`.
71    #[allow(clippy::borrowed_box)] // If we return `&MacroDefId` we need to move it, this way we just clone the `Box`.
72    pub fn macro_def(&self, scope: ScopeId) -> Option<&Box<MacroDefId>> {
73        self.scopes[scope].macro_def.as_ref()
74    }
75
76    /// If `scope` refers to a labeled expression scope, returns the corresponding `Label`.
77    pub fn label(&self, scope: ScopeId) -> Option<(LabelId, Name)> {
78        self.scopes[scope].label.clone()
79    }
80
81    /// Returns the scopes in ascending order.
82    pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
83        std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
84    }
85
86    pub fn resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry> {
87        self.scope_chain(Some(scope))
88            .find_map(|scope| self.entries(scope).iter().find(|it| it.name == *name))
89    }
90
91    pub fn scope_for(&self, expr: ExprId) -> Option<ScopeId> {
92        self.scope_by_expr.get(expr).copied()
93    }
94
95    pub fn scope_by_expr(&self) -> &ArenaMap<ExprId, ScopeId> {
96        &self.scope_by_expr
97    }
98}
99
100fn empty_entries(idx: usize) -> IdxRange<ScopeEntry> {
101    IdxRange::new(Idx::from_raw(RawIdx::from(idx as u32))..Idx::from_raw(RawIdx::from(idx as u32)))
102}
103
104impl ExprScopes {
105    fn new_body(body: &Body) -> ExprScopes {
106        let mut scopes = ExprScopes {
107            scopes: Arena::default(),
108            scope_entries: Arena::default(),
109            scope_by_expr: ArenaMap::with_capacity(
110                body.expr_only.as_ref().map_or(0, |it| it.exprs.len()),
111            ),
112        };
113        let mut root = scopes.root_scope();
114        if let Some(self_param) = body.self_param {
115            scopes.add_bindings(body, root, self_param, body.binding_hygiene(self_param));
116        }
117        scopes.add_params_bindings(body, root, &body.params);
118        compute_expr_scopes(body.body_expr, body, &mut scopes, &mut root);
119        scopes
120    }
121
122    fn root_scope(&mut self) -> ScopeId {
123        self.scopes.alloc(ScopeData {
124            parent: None,
125            block: None,
126            label: None,
127            macro_def: None,
128            entries: empty_entries(self.scope_entries.len()),
129        })
130    }
131
132    fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
133        self.scopes.alloc(ScopeData {
134            parent: Some(parent),
135            block: None,
136            label: None,
137            macro_def: None,
138            entries: empty_entries(self.scope_entries.len()),
139        })
140    }
141
142    fn new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId {
143        self.scopes.alloc(ScopeData {
144            parent: Some(parent),
145            block: None,
146            label,
147            macro_def: None,
148            entries: empty_entries(self.scope_entries.len()),
149        })
150    }
151
152    fn new_block_scope(
153        &mut self,
154        parent: ScopeId,
155        block: Option<BlockId>,
156        label: Option<(LabelId, Name)>,
157    ) -> ScopeId {
158        self.scopes.alloc(ScopeData {
159            parent: Some(parent),
160            block,
161            label,
162            macro_def: None,
163            entries: empty_entries(self.scope_entries.len()),
164        })
165    }
166
167    fn new_macro_def_scope(&mut self, parent: ScopeId, macro_id: Box<MacroDefId>) -> ScopeId {
168        self.scopes.alloc(ScopeData {
169            parent: Some(parent),
170            block: None,
171            label: None,
172            macro_def: Some(macro_id),
173            entries: empty_entries(self.scope_entries.len()),
174        })
175    }
176
177    fn add_bindings(
178        &mut self,
179        store: &ExpressionStore,
180        scope: ScopeId,
181        binding: BindingId,
182        hygiene: HygieneId,
183    ) {
184        let Binding { name, .. } = &store[binding];
185        let entry = self.scope_entries.alloc(ScopeEntry { name: name.clone(), binding, hygiene });
186        self.scopes[scope].entries =
187            IdxRange::new_inclusive(self.scopes[scope].entries.start()..=entry);
188    }
189
190    fn add_pat_bindings(&mut self, store: &ExpressionStore, scope: ScopeId, pat: PatId) {
191        let pattern = &store[pat];
192        if let Pat::Bind { id, .. } = *pattern {
193            self.add_bindings(store, scope, id, store.binding_hygiene(id));
194        }
195
196        pattern.walk_child_pats(|pat| self.add_pat_bindings(store, scope, pat));
197    }
198
199    fn add_params_bindings(&mut self, store: &ExpressionStore, scope: ScopeId, params: &[PatId]) {
200        params.iter().for_each(|pat| self.add_pat_bindings(store, scope, *pat));
201    }
202
203    fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
204        self.scope_by_expr.insert(node, scope);
205    }
206
207    fn shrink_to_fit(&mut self) {
208        let ExprScopes { scopes, scope_entries, scope_by_expr } = self;
209        scopes.shrink_to_fit();
210        scope_entries.shrink_to_fit();
211        scope_by_expr.shrink_to_fit();
212    }
213}
214
215fn compute_block_scopes(
216    statements: &[Statement],
217    tail: Option<ExprId>,
218    store: &ExpressionStore,
219    scopes: &mut ExprScopes,
220    scope: &mut ScopeId,
221) {
222    for stmt in statements {
223        match stmt {
224            Statement::Let { pat, initializer, else_branch, .. } => {
225                if let Some(expr) = initializer {
226                    compute_expr_scopes(*expr, store, scopes, scope);
227                }
228                if let Some(expr) = else_branch {
229                    compute_expr_scopes(*expr, store, scopes, scope);
230                }
231
232                *scope = scopes.new_scope(*scope);
233                scopes.add_pat_bindings(store, *scope, *pat);
234            }
235            Statement::Expr { expr, .. } => {
236                compute_expr_scopes(*expr, store, scopes, scope);
237            }
238            Statement::Item(Item::MacroDef(macro_id)) => {
239                *scope = scopes.new_macro_def_scope(*scope, macro_id.clone());
240            }
241            Statement::Item(Item::Other) => (),
242        }
243    }
244    if let Some(expr) = tail {
245        compute_expr_scopes(expr, store, scopes, scope);
246    }
247}
248
249fn compute_expr_scopes(
250    expr: ExprId,
251    store: &ExpressionStore,
252    scopes: &mut ExprScopes,
253    scope: &mut ScopeId,
254) {
255    let make_label =
256        |label: &Option<LabelId>| label.map(|label| (label, store[label].name.clone()));
257
258    let compute_expr_scopes = |scopes: &mut ExprScopes, expr: ExprId, scope: &mut ScopeId| {
259        compute_expr_scopes(expr, store, scopes, scope)
260    };
261
262    scopes.set_scope(expr, *scope);
263    match &store[expr] {
264        Expr::Block { statements, tail, id, label } => {
265            let mut scope = scopes.new_block_scope(*scope, *id, make_label(label));
266            // Overwrite the old scope for the block expr, so that every block scope can be found
267            // via the block itself (important for blocks that only contain items, no expressions).
268            scopes.set_scope(expr, scope);
269            compute_block_scopes(statements, *tail, store, scopes, &mut scope);
270        }
271        Expr::Const(id) => {
272            let mut scope = scopes.root_scope();
273            compute_expr_scopes(scopes, *id, &mut scope);
274        }
275        Expr::Unsafe { id, statements, tail } | Expr::Async { id, statements, tail } => {
276            let mut scope = scopes.new_block_scope(*scope, *id, None);
277            // Overwrite the old scope for the block expr, so that every block scope can be found
278            // via the block itself (important for blocks that only contain items, no expressions).
279            scopes.set_scope(expr, scope);
280            compute_block_scopes(statements, *tail, store, scopes, &mut scope);
281        }
282        Expr::Loop { body: body_expr, label } => {
283            let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
284            compute_expr_scopes(scopes, *body_expr, &mut scope);
285        }
286        Expr::Closure { args, body: body_expr, .. } => {
287            let mut scope = scopes.new_scope(*scope);
288            scopes.add_params_bindings(store, scope, args);
289            compute_expr_scopes(scopes, *body_expr, &mut scope);
290        }
291        Expr::Match { expr, arms } => {
292            compute_expr_scopes(scopes, *expr, scope);
293            for arm in arms.iter() {
294                let mut scope = scopes.new_scope(*scope);
295                scopes.add_pat_bindings(store, scope, arm.pat);
296                if let Some(guard) = arm.guard {
297                    scope = scopes.new_scope(scope);
298                    compute_expr_scopes(scopes, guard, &mut scope);
299                }
300                compute_expr_scopes(scopes, arm.expr, &mut scope);
301            }
302        }
303        &Expr::If { condition, then_branch, else_branch } => {
304            let mut then_branch_scope = scopes.new_scope(*scope);
305            compute_expr_scopes(scopes, condition, &mut then_branch_scope);
306            compute_expr_scopes(scopes, then_branch, &mut then_branch_scope);
307            if let Some(else_branch) = else_branch {
308                compute_expr_scopes(scopes, else_branch, scope);
309            }
310        }
311        &Expr::Let { pat, expr } => {
312            compute_expr_scopes(scopes, expr, scope);
313            *scope = scopes.new_scope(*scope);
314            scopes.add_pat_bindings(store, *scope, pat);
315        }
316        _ => store.walk_child_exprs(expr, |e| compute_expr_scopes(scopes, e, scope)),
317    };
318}
319
#[cfg(test)]
mod tests {
    use base_db::RootQueryDb;
    use hir_expand::{InFile, name::AsName};
    use span::FileId;
    use syntax::{AstNode, algo::find_node_at_offset, ast};
    use test_fixture::WithFixture;
    use test_utils::{assert_eq_text, extract_offset};

    use crate::{
        FunctionId, ModuleDefId, db::DefDatabase, nameres::crate_def_map, test_db::TestDB,
    };

    /// Returns the first scope entry of the first module backing `file_id`, which must be a
    /// function (panics otherwise).
    fn find_function(db: &TestDB, file_id: FileId) -> FunctionId {
        let def_map = crate_def_map(db, db.test_crate());
        let module_id = def_map.modules_for_file(db, file_id).next().unwrap();
        let (_, first_item) = def_map[module_id].scope.entries().next().unwrap();
        let ModuleDefId::FunctionId(function) = first_item.take_values().unwrap() else {
            panic!()
        };
        function
    }

    /// Asserts that the names in scope at the `$0` cursor (innermost scope first, one per
    /// line) are exactly `expected`.
    fn do_check(#[rust_analyzer::rust_fixture] ra_fixture: &str, expected: &[&str]) {
        // Put an identifier at the cursor so it parses as a path expression we can query.
        let (cursor, text) = extract_offset(ra_fixture);
        let (before, after) = text.split_at(usize::from(cursor));
        let code = format!("{before}$0marker{after}");

        let (db, position) = TestDB::with_position(&code);
        let file = position.file_id;
        let (file_id, _) = file.unpack(&db);

        let root = db.parse(file).syntax_node();
        let marker: ast::PathExpr = find_node_at_offset(&root, position.offset).unwrap();
        let function = find_function(&db, file_id);

        let scopes = db.expr_scopes(function.into());
        let (_body, source_map) = db.body_with_source_map(function.into());

        let marker_expr: ast::Expr = marker.into();
        let expr_id = source_map
            .node_expr(InFile { file_id: file.into(), value: &marker_expr })
            .unwrap()
            .as_expr()
            .unwrap();

        let visible: Vec<&str> = scopes
            .scope_chain(scopes.scope_for(expr_id))
            .flat_map(|scope| scopes.entries(scope).iter().map(|entry| entry.name().as_str()))
            .collect();
        let expected = expected.join("\n");
        let actual = visible.join("\n");
        assert_eq_text!(&expected, &actual);
    }

    #[test]
    fn test_lambda_scope() {
        do_check(
            r"
            fn quux(foo: i32) {
                let f = |bar, baz: i32| {
                    $0
                };
            }",
            &["bar", "baz", "foo"],
        );
    }

    #[test]
    fn test_call_scope() {
        do_check(
            r"
            fn quux() {
                f(|x| $0 );
            }",
            &["x"],
        );
    }

    #[test]
    fn test_method_call_scope() {
        do_check(
            r"
            fn quux() {
                z.f(|x| $0 );
            }",
            &["x"],
        );
    }

    #[test]
    fn test_loop_scope() {
        do_check(
            r"
            fn quux() {
                loop {
                    let x = ();
                    $0
                };
            }",
            &["x"],
        );
    }

    #[test]
    fn test_match() {
        do_check(
            r"
            fn quux() {
                match () {
                    Some(x) => {
                        $0
                    }
                };
            }",
            &["x"],
        );
    }

    #[test]
    fn test_shadow_variable() {
        do_check(
            r"
            fn foo(x: String) {
                let x : &str = &x$0;
            }",
            &["x"],
        );
    }

    #[test]
    fn test_bindings_after_at() {
        do_check(
            r"
fn foo() {
    match Some(()) {
        opt @ Some(unit) => {
            $0
        }
        _ => {}
    }
}
",
            &["opt", "unit"],
        );
    }

    #[test]
    fn macro_inner_item() {
        do_check(
            r"
            macro_rules! mac {
                () => {{
                    fn inner() {}
                    inner();
                }};
            }

            fn foo() {
                mac!();
                $0
            }
        ",
            &[],
        );
    }

    #[test]
    fn broken_inner_item() {
        do_check(
            r"
            fn foo() {
                trait {}
                $0
            }
        ",
            &[],
        );
    }

    /// Resolves the name reference at `$0` and asserts that the pattern it resolves to is the
    /// `ast::Name` found at `expected_offset`.
    fn do_check_local_name(#[rust_analyzer::rust_fixture] ra_fixture: &str, expected_offset: u32) {
        let (db, position) = TestDB::with_position(ra_fixture);
        let file = position.file_id;
        let (file_id, _) = file.unpack(&db);

        let tree = db.parse(file).ok().unwrap();
        let expected_name = find_node_at_offset::<ast::Name>(tree.syntax(), expected_offset.into())
            .expect("failed to find a name at the target offset");
        let name_ref: ast::NameRef = find_node_at_offset(tree.syntax(), position.offset).unwrap();

        let function = find_function(&db, file_id);
        let scopes = db.expr_scopes(function.into());
        let (_, source_map) = db.body_with_source_map(function.into());

        // Resolve from the scope of the innermost expression containing the reference.
        let enclosing_expr = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap();
        let expr_id = source_map
            .node_expr(InFile { file_id: file.into(), value: &enclosing_expr })
            .unwrap()
            .as_expr()
            .unwrap();
        let scope = scopes.scope_for(expr_id).unwrap();

        let entry = scopes.resolve_name_in_scope(scope, &name_ref.as_name()).unwrap();
        let first_pat = source_map.patterns_for_binding(entry.binding())[0];
        let pat_src = source_map.pat_syntax(first_pat).unwrap();

        let declared_at = pat_src.value.syntax_node_ptr().to_node(tree.syntax());
        assert_eq!(declared_at.text_range(), expected_name.syntax().text_range());
    }

    #[test]
    fn test_resolve_local_name() {
        do_check_local_name(
            r#"
fn foo(x: i32, y: u32) {
    {
        let z = x * 2;
    }
    {
        let t = x$0 * 3;
    }
}
"#,
            7,
        );
    }

    #[test]
    fn test_resolve_local_name_declaration() {
        do_check_local_name(
            r#"
fn foo(x: String) {
    let x : &str = &x$0;
}
"#,
            7,
        );
    }

    #[test]
    fn test_resolve_local_name_shadow() {
        do_check_local_name(
            r"
fn foo(x: String) {
    let x : &str = &x;
    x$0
}
",
            28,
        );
    }

    #[test]
    fn ref_patterns_contribute_bindings() {
        do_check_local_name(
            r"
fn foo() {
    if let Some(&from) = bar() {
        from$0;
    }
}
",
            28,
        );
    }

    #[test]
    fn while_let_adds_binding() {
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<f32> = None;
    while let Option::Some(spam) = foo {
        spam$0
    }
}
"#,
            75,
        );
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<f32> = None;
    while (((let Option::Some(_) = foo))) && let Option::Some(spam) = foo {
        spam$0
    }
}
"#,
            107,
        );
    }

    #[test]
    fn match_guard_if_let() {
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<f32> = None;
    match foo {
        _ if let Option::Some(spam) = foo => spam$0,
    }
}
"#,
            93,
        );
    }

    #[test]
    fn let_chains_can_reference_previous_lets() {
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<i32> = None;
    if let Some(spam) = foo && spa$0m > 1 && let Some(spam) = foo && spam > 1 {}
}
"#,
            61,
        );
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<i32> = None;
    if let Some(spam) = foo && spam > 1 && let Some(spam) = foo && sp$0am > 1 {}
}
"#,
            100,
        );
    }
}