hir_ty/
utils.rs

//! Helper functions for working with defs that don't need to be a separate
//! query, but can't be computed directly from `*Data` (i.e., they need a `db`).
3
4use std::cell::LazyCell;
5
6use base_db::target::{self, TargetData};
7use hir_def::{
8    EnumId, EnumVariantId, FunctionId, Lookup, TraitId,
9    attrs::AttrFlags,
10    db::DefDatabase,
11    hir::generics::WherePredicate,
12    lang_item::LangItems,
13    resolver::{HasResolver, TypeNs},
14    type_ref::{TraitBoundModifier, TypeRef},
15};
16use intern::sym;
17use rustc_abi::TargetDataLayout;
18use smallvec::{SmallVec, smallvec};
19use span::Edition;
20
21use crate::{
22    TargetFeatures,
23    db::HirDatabase,
24    layout::{Layout, TagEncoding},
25    mir::pad16,
26};
27
/// Stores `new_value` into `*old_pointer` unless it compares equal to the current value.
///
/// Returns `true` when the value was overwritten and `false` when it was left untouched.
///
/// # Safety
///
/// `old_pointer` must be valid for unique writes.
pub(crate) unsafe fn unsafe_update_eq<T>(old_pointer: *mut T, new_value: T) -> bool
where
    T: PartialEq,
{
    // SAFETY: the caller guarantees `old_pointer` is valid for unique writes.
    let current: &mut T = unsafe { &mut *old_pointer };

    // Subtle but important: `PartialEq` impls can be buggy or define equality in
    // surprising ways. When they report "unchanged", the existing value is kept as-is,
    // so downstream code never observes `new_value` and no revision bump is needed.
    let changed = *current != new_value;
    if changed {
        *current = new_value;
    }
    changed
}
47
48pub(crate) fn fn_traits(lang_items: &LangItems) -> impl Iterator<Item = TraitId> + '_ {
49    [lang_items.Fn, lang_items.FnMut, lang_items.FnOnce].into_iter().flatten()
50}
51
52/// Returns an iterator over the direct super traits (including the trait itself).
53pub fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> SmallVec<[TraitId; 4]> {
54    let mut result = smallvec![trait_];
55    direct_super_traits_cb(db, trait_, |tt| {
56        if !result.contains(&tt) {
57            result.push(tt);
58        }
59    });
60    result
61}
62
63/// Returns an iterator over the whole super trait hierarchy (including the
64/// trait itself).
65pub fn all_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> SmallVec<[TraitId; 4]> {
66    // we need to take care a bit here to avoid infinite loops in case of cycles
67    // (i.e. if we have `trait A: B; trait B: A;`)
68
69    let mut result = smallvec![trait_];
70    let mut i = 0;
71    while let Some(&t) = result.get(i) {
72        // yeah this is quadratic, but trait hierarchies should be flat
73        // enough that this doesn't matter
74        direct_super_traits_cb(db, t, |tt| {
75            if !result.contains(&tt) {
76                result.push(tt);
77            }
78        });
79        i += 1;
80    }
81    result
82}
83
84fn direct_super_traits_cb(db: &dyn DefDatabase, trait_: TraitId, cb: impl FnMut(TraitId)) {
85    let resolver = LazyCell::new(|| trait_.resolver(db));
86    let (generic_params, store) = db.generic_params_and_store(trait_.into());
87    let trait_self = generic_params.trait_self_param();
88    generic_params
89        .where_predicates()
90        .iter()
91        .filter_map(|pred| match pred {
92            WherePredicate::ForLifetime { target, bound, .. }
93            | WherePredicate::TypeBound { target, bound } => {
94                let is_trait = match &store[*target] {
95                    TypeRef::Path(p) => p.is_self_type(),
96                    TypeRef::TypeParam(p) => Some(p.local_id()) == trait_self,
97                    _ => false,
98                };
99                match is_trait {
100                    true => bound.as_path(&store),
101                    false => None,
102                }
103            }
104            WherePredicate::Lifetime { .. } => None,
105        })
106        .filter(|(_, bound_modifier)| matches!(bound_modifier, TraitBoundModifier::None))
107        .filter_map(|(path, _)| match resolver.resolve_path_in_type_ns_fully(db, path) {
108            Some(TypeNs::TraitId(t)) => Some(t),
109            _ => None,
110        })
111        .for_each(cb);
112}
113
/// How a call must be treated with respect to `unsafe` (see `is_fn_unsafe_to_call`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Unsafety {
    /// The call does not require `unsafe`.
    Safe,
    /// The call requires `unsafe`.
    Unsafe,
    /// The callee is marked deprecated-safe for the 2024 edition and the call happens in an
    /// earlier edition; reported as a lint.
    DeprecatedSafe2024,
}
121
/// Whether functions with `#[target_feature]` skip the target-feature safety check on the
/// current target.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TargetFeatureIsSafeInTarget {
    /// Calls are checked against the caller's enabled target features.
    No,
    /// Target features never make a call unsafe on this target.
    Yes,
}
127
128pub fn target_feature_is_safe_in_target(target: &TargetData) -> TargetFeatureIsSafeInTarget {
129    match target.arch {
130        target::Arch::Wasm32 | target::Arch::Wasm64 => TargetFeatureIsSafeInTarget::Yes,
131        _ => TargetFeatureIsSafeInTarget::No,
132    }
133}
134
135pub fn is_fn_unsafe_to_call(
136    db: &dyn HirDatabase,
137    func: FunctionId,
138    caller_target_features: &TargetFeatures<'_>,
139    call_edition: Edition,
140    target_feature_is_safe: TargetFeatureIsSafeInTarget,
141) -> Unsafety {
142    let data = db.function_signature(func);
143    if data.is_unsafe() {
144        return Unsafety::Unsafe;
145    }
146
147    if data.has_target_feature() && target_feature_is_safe == TargetFeatureIsSafeInTarget::No {
148        // RFC 2396 <https://rust-lang.github.io/rfcs/2396-target-feature-1.1.html>.
149        let callee_target_features = TargetFeatures::from_fn_no_implications(db, func);
150        if !caller_target_features.enabled.is_superset(&callee_target_features.enabled) {
151            return Unsafety::Unsafe;
152        }
153    }
154
155    if data.is_deprecated_safe_2024() {
156        if call_edition.at_least_2024() {
157            return Unsafety::Unsafe;
158        } else {
159            return Unsafety::DeprecatedSafe2024;
160        }
161    }
162
163    let loc = func.lookup(db);
164    match loc.container {
165        hir_def::ItemContainerId::ExternBlockId(block) => {
166            let is_intrinsic_block = block.abi(db) == Some(sym::rust_dash_intrinsic);
167            if is_intrinsic_block {
168                // legacy intrinsics
169                // extern "rust-intrinsic" intrinsics are unsafe unless they have the rustc_safe_intrinsic attribute
170                if AttrFlags::query(db, func.into()).contains(AttrFlags::RUSTC_SAFE_INTRINSIC) {
171                    Unsafety::Safe
172                } else {
173                    Unsafety::Unsafe
174                }
175            } else {
176                // Function in an `extern` block are always unsafe to call, except when
177                // it is marked as `safe`.
178                if data.is_safe() { Unsafety::Safe } else { Unsafety::Unsafe }
179            }
180        }
181        _ => Unsafety::Safe,
182    }
183}
184
185pub(crate) fn detect_variant_from_bytes<'a>(
186    layout: &'a Layout,
187    db: &dyn HirDatabase,
188    target_data_layout: &TargetDataLayout,
189    b: &[u8],
190    e: EnumId,
191) -> Option<(EnumVariantId, &'a Layout)> {
192    let (var_id, var_layout) = match &layout.variants {
193        hir_def::layout::Variants::Empty => unreachable!(),
194        hir_def::layout::Variants::Single { index } => {
195            (e.enum_variants(db).variants[index.0].0, layout)
196        }
197        hir_def::layout::Variants::Multiple { tag, tag_encoding, variants, .. } => {
198            let size = tag.size(target_data_layout).bytes_usize();
199            let offset = layout.fields.offset(0).bytes_usize(); // The only field on enum variants is the tag field
200            let tag = i128::from_le_bytes(pad16(&b[offset..offset + size], false));
201            match tag_encoding {
202                TagEncoding::Direct => {
203                    let (var_idx, layout) =
204                        variants.iter_enumerated().find_map(|(var_idx, v)| {
205                            let def = e.enum_variants(db).variants[var_idx.0].0;
206                            (db.const_eval_discriminant(def) == Ok(tag)).then_some((def, v))
207                        })?;
208                    (var_idx, layout)
209                }
210                TagEncoding::Niche { untagged_variant, niche_start, .. } => {
211                    let candidate_tag = tag.wrapping_sub(*niche_start as i128) as usize;
212                    let variant = variants
213                        .iter_enumerated()
214                        .map(|(x, _)| x)
215                        .filter(|x| x != untagged_variant)
216                        .nth(candidate_tag)
217                        .unwrap_or(*untagged_variant);
218                    (e.enum_variants(db).variants[variant.0].0, &variants[variant])
219                }
220            }
221        }
222    };
223    Some((var_id, var_layout))
224}