1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
use super::RecordedItemId;
use crate::RustIrDatabase;
use chalk_ir::{
    interner::Interner,
    visit::TypeVisitor,
    visit::{TypeSuperVisitable, TypeVisitable},
    AliasTy, DebruijnIndex, TyKind, WhereClause,
};
use std::ops::ControlFlow;

use indexmap::IndexSet;

/// Collects the identifiers needed to resolve all the names for a given
/// set of identifers, excluding identifiers we already have.
///
/// When recording identifiers to print, the `LoggingRustIrDatabase` only
/// records identifiers the solver uses. But the solver assumes well-formedness,
/// and thus skips over many names referenced in the definitions.
///
/// For instance, if we have:
///
/// ```rust,ignore
/// struct S {}
///
/// trait Parent {}
/// trait Child where Self: Parent {}
/// impl Parent for S {}
/// impl Child for S {}
/// ```
///
/// And our goal is `S: Child`, we will only render `S`, `impl Child for S`, and
/// `trait Child`. This will not parse because the `Child` trait's definition
/// references parent. IdCollector solves this by collecting all of the directly
/// related identifiers, allowing those to be rendered as well, ensuring name
/// resolution is successful.
pub fn collect_unrecorded_ids<I: Interner, DB: RustIrDatabase<I>>(
    db: &DB,
    identifiers: &'_ IndexSet<RecordedItemId<I>>,
) -> IndexSet<RecordedItemId<I>> {
    let mut collector = IdCollector {
        db,
        found_identifiers: IndexSet::new(),
    };
    for id in identifiers {
        match *id {
            RecordedItemId::Adt(adt_id) => {
                collector
                    .db
                    .adt_datum(adt_id)
                    .visit_with(&mut collector, DebruijnIndex::INNERMOST);
            }
            RecordedItemId::FnDef(fn_def) => {
                collector
                    .db
                    .fn_def_datum(fn_def)
                    .visit_with(&mut collector, DebruijnIndex::INNERMOST);
            }
            RecordedItemId::Coroutine(_coroutine_id) => unimplemented!(),
            RecordedItemId::Trait(trait_id) => {
                let trait_datum = collector.db.trait_datum(trait_id);

                trait_datum.visit_with(&mut collector, DebruijnIndex::INNERMOST);
                for assoc_ty_id in &trait_datum.associated_ty_ids {
                    let assoc_ty_datum = collector.db.associated_ty_data(*assoc_ty_id);
                    assoc_ty_datum
                        .bounds_on_self(collector.db.interner())
                        .visit_with(&mut collector, DebruijnIndex::INNERMOST);
                    assoc_ty_datum.visit_with(&mut collector, DebruijnIndex::INNERMOST);
                }
            }
            RecordedItemId::OpaqueTy(opaque_id) => {
                collector
                    .db
                    .opaque_ty_data(opaque_id)
                    .visit_with(&mut collector, DebruijnIndex::INNERMOST);
                collector
                    .db
                    .hidden_opaque_type(opaque_id)
                    .visit_with(&mut collector, DebruijnIndex::INNERMOST);
            }
            RecordedItemId::Impl(impl_id) => {
                let impl_datum = collector.db.impl_datum(impl_id);
                for id in &impl_datum.associated_ty_value_ids {
                    let assoc_ty_value = collector.db.associated_ty_value(*id);
                    assoc_ty_value.visit_with(&mut collector, DebruijnIndex::INNERMOST);
                }
                impl_datum.visit_with(&mut collector, DebruijnIndex::INNERMOST);
            }
        }
    }
    collector
        .found_identifiers
        .difference(identifiers)
        .copied()
        .collect()
}

struct IdCollector<'i, I: Interner, DB: RustIrDatabase<I>> {
    db: &'i DB,
    found_identifiers: IndexSet<RecordedItemId<I>>,
}

impl<'i, I: Interner, DB: RustIrDatabase<I>> IdCollector<'i, I, DB> {
    fn record(&mut self, id: impl Into<RecordedItemId<I>>) {
        self.found_identifiers.insert(id.into());
    }

    fn visit_alias(&mut self, alias: &AliasTy<I>) {
        match alias {
            AliasTy::Projection(projection_ty) => {
                let assoc_ty_datum = self.db.associated_ty_data(projection_ty.associated_ty_id);
                self.record(assoc_ty_datum.trait_id)
            }
            AliasTy::Opaque(opaque_ty) => self.record(opaque_ty.opaque_ty_id),
        }
    }
}

impl<'i, I: Interner, DB: RustIrDatabase<I>> TypeVisitor<I> for IdCollector<'i, I, DB> {
    type BreakTy = ();

    fn as_dyn(&mut self) -> &mut dyn TypeVisitor<I, BreakTy = Self::BreakTy> {
        self
    }
    fn interner(&self) -> I {
        self.db.interner()
    }

    fn visit_ty(
        &mut self,
        ty: &chalk_ir::Ty<I>,
        outer_binder: chalk_ir::DebruijnIndex,
    ) -> ControlFlow<()> {
        match ty.kind(self.db.interner()) {
            TyKind::Adt(adt, _) => self.record(*adt),
            TyKind::FnDef(fn_def, _) => self.record(*fn_def),
            TyKind::OpaqueType(opaque, _) => self.record(*opaque),
            TyKind::Alias(alias) => self.visit_alias(alias),
            TyKind::BoundVar(..) => (),
            TyKind::Dyn(..) => (),
            TyKind::Function(..) => (),
            TyKind::InferenceVar(..) => (),
            TyKind::Placeholder(..) => (),
            _ => {}
        }
        ty.super_visit_with(self, outer_binder)
    }

    fn visit_where_clause(
        &mut self,
        where_clause: &WhereClause<I>,
        outer_binder: DebruijnIndex,
    ) -> ControlFlow<()> {
        match where_clause {
            WhereClause::Implemented(trait_ref) => self.record(trait_ref.trait_id),
            WhereClause::AliasEq(alias_eq) => self.visit_alias(&alias_eq.alias),
            WhereClause::LifetimeOutlives(_lifetime_outlives) => (),
            WhereClause::TypeOutlives(_type_outlives) => (),
        }
        where_clause.super_visit_with(self.as_dyn(), outer_binder)
    }
}