hir_ty/
utils.rs

//! Helper functions for working with defs that don't need to be a separate
//! query but can't be computed directly from `*Data` (i.e., they need a `db`).
3
4use std::{cell::LazyCell, iter};
5
6use base_db::Crate;
7use chalk_ir::{DebruijnIndex, fold::FallibleTypeFolder};
8use hir_def::{
9    EnumId, EnumVariantId, FunctionId, Lookup, TraitId, TypeAliasId, TypeOrConstParamId,
10    db::DefDatabase,
11    hir::generics::WherePredicate,
12    lang_item::LangItem,
13    resolver::{HasResolver, TypeNs},
14    type_ref::{TraitBoundModifier, TypeRef},
15};
16use hir_expand::name::Name;
17use intern::sym;
18use rustc_abi::TargetDataLayout;
19use rustc_hash::FxHashSet;
20use rustc_type_ir::inherent::{IntoKind, SliceLike};
21use smallvec::{SmallVec, smallvec};
22use span::Edition;
23use stdx::never;
24
25use crate::{
26    ChalkTraitId, Const, ConstScalar, GenericArg, Interner, Substitution, TargetFeatures, TraitRef,
27    TraitRefExt, Ty, WhereClause,
28    consteval::unknown_const,
29    db::HirDatabase,
30    layout::{Layout, TagEncoding},
31    mir::pad16,
32    next_solver::{
33        DbInterner,
34        mapping::{ChalkToNextSolver, convert_args_for_result},
35    },
36    to_chalk_trait_id,
37};
38
39pub(crate) fn fn_traits(db: &dyn DefDatabase, krate: Crate) -> impl Iterator<Item = TraitId> + '_ {
40    [LangItem::Fn, LangItem::FnMut, LangItem::FnOnce]
41        .into_iter()
42        .filter_map(move |lang| lang.resolve_trait(db, krate))
43}
44
45/// Returns an iterator over the direct super traits (including the trait itself).
46pub fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> SmallVec<[TraitId; 4]> {
47    let mut result = smallvec![trait_];
48    direct_super_traits_cb(db, trait_, |tt| {
49        if !result.contains(&tt) {
50            result.push(tt);
51        }
52    });
53    result
54}
55
56/// Returns an iterator over the whole super trait hierarchy (including the
57/// trait itself).
58pub fn all_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> SmallVec<[TraitId; 4]> {
59    // we need to take care a bit here to avoid infinite loops in case of cycles
60    // (i.e. if we have `trait A: B; trait B: A;`)
61
62    let mut result = smallvec![trait_];
63    let mut i = 0;
64    while let Some(&t) = result.get(i) {
65        // yeah this is quadratic, but trait hierarchies should be flat
66        // enough that this doesn't matter
67        direct_super_traits_cb(db, t, |tt| {
68            if !result.contains(&tt) {
69                result.push(tt);
70            }
71        });
72        i += 1;
73    }
74    result
75}
76
77/// Given a trait ref (`Self: Trait`), builds all the implied trait refs for
78/// super traits. The original trait ref will be included. So the difference to
79/// `all_super_traits` is that we keep track of type parameters; for example if
80/// we have `Self: Trait<u32, i32>` and `Trait<T, U>: OtherTrait<U>` we'll get
81/// `Self: OtherTrait<i32>`.
82pub(super) fn all_super_trait_refs<T>(
83    db: &dyn HirDatabase,
84    trait_ref: TraitRef,
85    cb: impl FnMut(TraitRef) -> Option<T>,
86) -> Option<T> {
87    let seen = iter::once(trait_ref.trait_id).collect();
88    SuperTraits { db, seen, stack: vec![trait_ref] }.find_map(cb)
89}
90
91struct SuperTraits<'a> {
92    db: &'a dyn HirDatabase,
93    stack: Vec<TraitRef>,
94    seen: FxHashSet<ChalkTraitId>,
95}
96
97impl SuperTraits<'_> {
98    fn elaborate(&mut self, trait_ref: &TraitRef) {
99        direct_super_trait_refs(self.db, trait_ref, |trait_ref| {
100            if !self.seen.contains(&trait_ref.trait_id) {
101                self.stack.push(trait_ref);
102            }
103        });
104    }
105}
106
107impl Iterator for SuperTraits<'_> {
108    type Item = TraitRef;
109
110    fn next(&mut self) -> Option<Self::Item> {
111        if let Some(next) = self.stack.pop() {
112            self.elaborate(&next);
113            Some(next)
114        } else {
115            None
116        }
117    }
118}
119
120pub(super) fn elaborate_clause_supertraits(
121    db: &dyn HirDatabase,
122    clauses: impl Iterator<Item = WhereClause>,
123) -> ClauseElaborator<'_> {
124    let mut elaborator = ClauseElaborator { db, stack: Vec::new(), seen: FxHashSet::default() };
125    elaborator.extend_deduped(clauses);
126
127    elaborator
128}
129
130pub(super) struct ClauseElaborator<'a> {
131    db: &'a dyn HirDatabase,
132    stack: Vec<WhereClause>,
133    seen: FxHashSet<WhereClause>,
134}
135
136impl ClauseElaborator<'_> {
137    fn extend_deduped(&mut self, clauses: impl IntoIterator<Item = WhereClause>) {
138        self.stack.extend(clauses.into_iter().filter(|c| self.seen.insert(c.clone())))
139    }
140
141    fn elaborate_supertrait(&mut self, clause: &WhereClause) {
142        if let WhereClause::Implemented(trait_ref) = clause {
143            direct_super_trait_refs(self.db, trait_ref, |t| {
144                let clause = WhereClause::Implemented(t);
145                if self.seen.insert(clause.clone()) {
146                    self.stack.push(clause);
147                }
148            });
149        }
150    }
151}
152
153impl Iterator for ClauseElaborator<'_> {
154    type Item = WhereClause;
155
156    fn next(&mut self) -> Option<Self::Item> {
157        if let Some(next) = self.stack.pop() {
158            self.elaborate_supertrait(&next);
159            Some(next)
160        } else {
161            None
162        }
163    }
164}
165
166fn direct_super_traits_cb(db: &dyn DefDatabase, trait_: TraitId, cb: impl FnMut(TraitId)) {
167    let resolver = LazyCell::new(|| trait_.resolver(db));
168    let (generic_params, store) = db.generic_params_and_store(trait_.into());
169    let trait_self = generic_params.trait_self_param();
170    generic_params
171        .where_predicates()
172        .iter()
173        .filter_map(|pred| match pred {
174            WherePredicate::ForLifetime { target, bound, .. }
175            | WherePredicate::TypeBound { target, bound } => {
176                let is_trait = match &store[*target] {
177                    TypeRef::Path(p) => p.is_self_type(),
178                    TypeRef::TypeParam(p) => Some(p.local_id()) == trait_self,
179                    _ => false,
180                };
181                match is_trait {
182                    true => bound.as_path(&store),
183                    false => None,
184                }
185            }
186            WherePredicate::Lifetime { .. } => None,
187        })
188        .filter(|(_, bound_modifier)| matches!(bound_modifier, TraitBoundModifier::None))
189        .filter_map(|(path, _)| match resolver.resolve_path_in_type_ns_fully(db, path) {
190            Some(TypeNs::TraitId(t)) => Some(t),
191            _ => None,
192        })
193        .for_each(cb);
194}
195
196fn direct_super_trait_refs(db: &dyn HirDatabase, trait_ref: &TraitRef, cb: impl FnMut(TraitRef)) {
197    let interner = DbInterner::new_with(db, None, None);
198    let generic_params = db.generic_params(trait_ref.hir_trait_id().into());
199    let trait_self = match generic_params.trait_self_param() {
200        Some(p) => TypeOrConstParamId { parent: trait_ref.hir_trait_id().into(), local_id: p },
201        None => return,
202    };
203    let trait_ref_args: crate::next_solver::GenericArgs<'_> =
204        trait_ref.substitution.to_nextsolver(interner);
205    db.generic_predicates_for_param_ns(trait_self.parent, trait_self, None)
206        .iter()
207        .filter_map(|pred| {
208            let pred = pred.kind();
209            // FIXME: how to correctly handle higher-ranked bounds here?
210            let pred = pred.no_bound_vars().expect("FIXME unexpected higher-ranked trait bound");
211            match pred {
212                rustc_type_ir::ClauseKind::Trait(t) => {
213                    let t =
214                        rustc_type_ir::EarlyBinder::bind(t).instantiate(interner, trait_ref_args);
215                    let trait_id = to_chalk_trait_id(t.def_id().0);
216
217                    let substitution =
218                        convert_args_for_result(interner, t.trait_ref.args.as_slice());
219                    let tr = chalk_ir::TraitRef { trait_id, substitution };
220                    Some(tr)
221                }
222                _ => None,
223            }
224        })
225        .for_each(cb);
226}
227
228pub(super) fn associated_type_by_name_including_super_traits(
229    db: &dyn HirDatabase,
230    trait_ref: TraitRef,
231    name: &Name,
232) -> Option<(TraitRef, TypeAliasId)> {
233    all_super_trait_refs(db, trait_ref, |t| {
234        let assoc_type = t.hir_trait_id().trait_items(db).associated_type_by_name(name)?;
235        Some((t, assoc_type))
236    })
237}
238
239/// It is a bit different from the rustc equivalent. Currently it stores:
240/// - 0..n-1: generics of the parent
241/// - n: the function signature, encoded as a function pointer type
242///
243/// and it doesn't store the closure types and fields.
244///
245/// Codes should not assume this ordering, and should always use methods available
246/// on this struct for retrieving, and `TyBuilder::substs_for_closure` for creating.
247pub(crate) struct ClosureSubst<'a>(pub(crate) &'a Substitution);
248
249impl<'a> ClosureSubst<'a> {
250    pub(crate) fn parent_subst(&self) -> &'a [GenericArg] {
251        match self.0.as_slice(Interner) {
252            [x @ .., _] => x,
253            _ => {
254                never!("Closure missing parameter");
255                &[]
256            }
257        }
258    }
259
260    pub(crate) fn sig_ty(&self) -> &'a Ty {
261        match self.0.as_slice(Interner) {
262            [.., x] => x.assert_ty_ref(Interner),
263            _ => {
264                unreachable!("Closure missing sig_ty parameter");
265            }
266        }
267    }
268}
269
/// Whether calling a function requires an `unsafe` context.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Unsafety {
    /// The call needs no `unsafe` context.
    Safe,
    /// The call requires an `unsafe` context.
    Unsafe,
    /// The function is deprecated-safe in edition 2024; in an earlier edition
    /// the call is reported as a lint rather than an error.
    DeprecatedSafe2024,
}
277
278pub fn is_fn_unsafe_to_call(
279    db: &dyn HirDatabase,
280    func: FunctionId,
281    caller_target_features: &TargetFeatures,
282    call_edition: Edition,
283) -> Unsafety {
284    let data = db.function_signature(func);
285    if data.is_unsafe() {
286        return Unsafety::Unsafe;
287    }
288
289    if data.has_target_feature() {
290        // RFC 2396 <https://rust-lang.github.io/rfcs/2396-target-feature-1.1.html>.
291        let callee_target_features =
292            TargetFeatures::from_attrs_no_implications(&db.attrs(func.into()));
293        if !caller_target_features.enabled.is_superset(&callee_target_features.enabled) {
294            return Unsafety::Unsafe;
295        }
296    }
297
298    if data.is_deprecated_safe_2024() {
299        if call_edition.at_least_2024() {
300            return Unsafety::Unsafe;
301        } else {
302            return Unsafety::DeprecatedSafe2024;
303        }
304    }
305
306    let loc = func.lookup(db);
307    match loc.container {
308        hir_def::ItemContainerId::ExternBlockId(block) => {
309            let is_intrinsic_block = block.abi(db) == Some(sym::rust_dash_intrinsic);
310            if is_intrinsic_block {
311                // legacy intrinsics
312                // extern "rust-intrinsic" intrinsics are unsafe unless they have the rustc_safe_intrinsic attribute
313                if db.attrs(func.into()).by_key(sym::rustc_safe_intrinsic).exists() {
314                    Unsafety::Safe
315                } else {
316                    Unsafety::Unsafe
317                }
318            } else {
319                // Function in an `extern` block are always unsafe to call, except when
320                // it is marked as `safe`.
321                if data.is_safe() { Unsafety::Safe } else { Unsafety::Unsafe }
322            }
323        }
324        _ => Unsafety::Safe,
325    }
326}
327
328pub(crate) struct UnevaluatedConstEvaluatorFolder<'a> {
329    pub(crate) db: &'a dyn HirDatabase,
330}
331
332impl FallibleTypeFolder<Interner> for UnevaluatedConstEvaluatorFolder<'_> {
333    type Error = ();
334
335    fn as_dyn(&mut self) -> &mut dyn FallibleTypeFolder<Interner, Error = ()> {
336        self
337    }
338
339    fn interner(&self) -> Interner {
340        Interner
341    }
342
343    fn try_fold_const(
344        &mut self,
345        constant: Const,
346        _outer_binder: DebruijnIndex,
347    ) -> Result<Const, Self::Error> {
348        if let chalk_ir::ConstValue::Concrete(c) = &constant.data(Interner).value
349            && let ConstScalar::UnevaluatedConst(id, subst) = &c.interned
350        {
351            if let Ok(eval) = self.db.const_eval(*id, subst.clone(), None) {
352                return Ok(eval);
353            } else {
354                return Ok(unknown_const(constant.data(Interner).ty.clone()));
355            }
356        }
357        Ok(constant)
358    }
359}
360
361pub(crate) fn detect_variant_from_bytes<'a>(
362    layout: &'a Layout,
363    db: &dyn HirDatabase,
364    target_data_layout: &TargetDataLayout,
365    b: &[u8],
366    e: EnumId,
367) -> Option<(EnumVariantId, &'a Layout)> {
368    let (var_id, var_layout) = match &layout.variants {
369        hir_def::layout::Variants::Empty => unreachable!(),
370        hir_def::layout::Variants::Single { index } => {
371            (e.enum_variants(db).variants[index.0].0, layout)
372        }
373        hir_def::layout::Variants::Multiple { tag, tag_encoding, variants, .. } => {
374            let size = tag.size(target_data_layout).bytes_usize();
375            let offset = layout.fields.offset(0).bytes_usize(); // The only field on enum variants is the tag field
376            let tag = i128::from_le_bytes(pad16(&b[offset..offset + size], false));
377            match tag_encoding {
378                TagEncoding::Direct => {
379                    let (var_idx, layout) =
380                        variants.iter_enumerated().find_map(|(var_idx, v)| {
381                            let def = e.enum_variants(db).variants[var_idx.0].0;
382                            (db.const_eval_discriminant(def) == Ok(tag)).then_some((def, v))
383                        })?;
384                    (var_idx, layout)
385                }
386                TagEncoding::Niche { untagged_variant, niche_start, .. } => {
387                    let candidate_tag = tag.wrapping_sub(*niche_start as i128) as usize;
388                    let variant = variants
389                        .iter_enumerated()
390                        .map(|(x, _)| x)
391                        .filter(|x| x != untagged_variant)
392                        .nth(candidate_tag)
393                        .unwrap_or(*untagged_variant);
394                    (e.enum_variants(db).variants[variant.0].0, &variants[variant])
395                }
396            }
397        }
398    };
399    Some((var_id, var_layout))
400}