hir_ty/diagnostics/unsafe_check.rs

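//! Provides validations for unsafe code. Currently checks whether unsafe operations (calls to
//! unsafe functions, raw pointer dereferences, union field reads, accesses to mutable or extern
//! statics, and inline assembly) are missing an enclosing `unsafe` block.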
3
4use std::mem;
5
6use either::Either;
7use hir_def::{
8    AdtId, CallableDefId, DefWithBodyId, FieldId, FunctionId, VariantId,
9    expr_store::{Body, path::Path},
10    hir::{AsmOperand, Expr, ExprId, ExprOrPatId, InlineAsmKind, Pat, PatId, Statement, UnaryOp},
11    resolver::{HasResolver, ResolveValueResult, Resolver, ValueNs},
12    signatures::StaticFlags,
13    type_ref::Rawness,
14};
15use rustc_type_ir::inherent::IntoKind;
16use span::Edition;
17
18use crate::{
19    InferenceResult, TargetFeatures,
20    db::HirDatabase,
21    next_solver::{CallableIdWrapper, TyKind, abi::Safety},
22    utils::{TargetFeatureIsSafeInTarget, is_fn_unsafe_to_call, target_feature_is_safe_in_target},
23};
24
25#[derive(Debug, Default)]
26pub struct MissingUnsafeResult {
27    pub unsafe_exprs: Vec<(ExprOrPatId, UnsafetyReason)>,
28    /// If `fn_is_unsafe` is false, `unsafe_exprs` are hard errors. If true, they're `unsafe_op_in_unsafe_fn`.
29    pub fn_is_unsafe: bool,
30    pub deprecated_safe_calls: Vec<ExprId>,
31}
32
33pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> MissingUnsafeResult {
34    let _p = tracing::info_span!("missing_unsafe").entered();
35
36    let is_unsafe = match def {
37        DefWithBodyId::FunctionId(it) => db.function_signature(it).is_unsafe(),
38        DefWithBodyId::StaticId(_) | DefWithBodyId::ConstId(_) | DefWithBodyId::VariantId(_) => {
39            false
40        }
41    };
42
43    let mut res = MissingUnsafeResult { fn_is_unsafe: is_unsafe, ..MissingUnsafeResult::default() };
44    let body = db.body(def);
45    let infer = InferenceResult::for_body(db, def);
46    let mut callback = |diag| match diag {
47        UnsafeDiagnostic::UnsafeOperation { node, inside_unsafe_block, reason } => {
48            if inside_unsafe_block == InsideUnsafeBlock::No {
49                res.unsafe_exprs.push((node, reason));
50            }
51        }
52        UnsafeDiagnostic::DeprecatedSafe2024 { node, inside_unsafe_block } => {
53            if inside_unsafe_block == InsideUnsafeBlock::No {
54                res.deprecated_safe_calls.push(node)
55            }
56        }
57    };
58    let mut visitor = UnsafeVisitor::new(db, infer, &body, def, &mut callback);
59    visitor.walk_expr(body.body_expr);
60
61    if !is_unsafe {
62        // Unsafety in function parameter patterns (that can only be union destructuring)
63        // cannot be inserted into an unsafe block, so even with `unsafe_op_in_unsafe_fn`
64        // it is turned off for unsafe functions.
65        for &param in &body.params {
66            visitor.walk_pat(param);
67        }
68    }
69
70    res
71}
72
/// Why an operation requires an `unsafe` context.
#[derive(Clone, Copy, Debug)]
pub enum UnsafetyReason {
    /// Reading a union field, including destructuring a union in a pattern.
    UnionField,
    /// Calling an `unsafe` function, method, or function pointer.
    UnsafeFnCall,
    /// An `asm!` invocation.
    InlineAsm,
    /// Dereferencing a raw pointer.
    RawPtrDeref,
    /// Accessing a `static mut`.
    MutableStatic,
    /// Accessing an `extern` static that is not explicitly declared `safe`.
    ExternStatic,
}
82
/// Whether an operation is covered by an enclosing `unsafe { ... }` block.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InsideUnsafeBlock {
    /// No `unsafe` block covers the operation.
    No,
    /// An `unsafe` block covers the operation.
    Yes,
}
88
89#[derive(Debug)]
90enum UnsafeDiagnostic {
91    UnsafeOperation {
92        node: ExprOrPatId,
93        inside_unsafe_block: InsideUnsafeBlock,
94        reason: UnsafetyReason,
95    },
96    /// A lint.
97    DeprecatedSafe2024 { node: ExprId, inside_unsafe_block: InsideUnsafeBlock },
98}
99
100pub fn unsafe_operations_for_body<'db>(
101    db: &'db dyn HirDatabase,
102    infer: &InferenceResult<'db>,
103    def: DefWithBodyId,
104    body: &Body,
105    callback: &mut dyn FnMut(ExprOrPatId),
106) {
107    let mut visitor_callback = |diag| {
108        if let UnsafeDiagnostic::UnsafeOperation { node, .. } = diag {
109            callback(node);
110        }
111    };
112    let mut visitor = UnsafeVisitor::new(db, infer, body, def, &mut visitor_callback);
113    visitor.walk_expr(body.body_expr);
114    for &param in &body.params {
115        visitor.walk_pat(param);
116    }
117}
118
119pub fn unsafe_operations<'db>(
120    db: &'db dyn HirDatabase,
121    infer: &InferenceResult<'db>,
122    def: DefWithBodyId,
123    body: &Body,
124    current: ExprId,
125    callback: &mut dyn FnMut(ExprOrPatId, InsideUnsafeBlock),
126) {
127    let mut visitor_callback = |diag| {
128        if let UnsafeDiagnostic::UnsafeOperation { inside_unsafe_block, node, .. } = diag {
129            callback(node, inside_unsafe_block);
130        }
131    };
132    let mut visitor = UnsafeVisitor::new(db, infer, body, def, &mut visitor_callback);
133    _ = visitor.resolver.update_to_inner_scope(db, def, current);
134    visitor.walk_expr(current);
135}
136
137struct UnsafeVisitor<'db> {
138    db: &'db dyn HirDatabase,
139    infer: &'db InferenceResult<'db>,
140    body: &'db Body,
141    resolver: Resolver<'db>,
142    def: DefWithBodyId,
143    inside_unsafe_block: InsideUnsafeBlock,
144    inside_assignment: bool,
145    inside_union_destructure: bool,
146    callback: &'db mut dyn FnMut(UnsafeDiagnostic),
147    def_target_features: TargetFeatures<'db>,
148    // FIXME: This needs to be the edition of the span of each call.
149    edition: Edition,
150    /// On some targets (WASM), calling safe functions with `#[target_feature]` is always safe, even when
151    /// the target feature is not enabled. This flag encodes that.
152    target_feature_is_safe: TargetFeatureIsSafeInTarget,
153}
154
155impl<'db> UnsafeVisitor<'db> {
156    fn new(
157        db: &'db dyn HirDatabase,
158        infer: &'db InferenceResult<'db>,
159        body: &'db Body,
160        def: DefWithBodyId,
161        unsafe_expr_cb: &'db mut dyn FnMut(UnsafeDiagnostic),
162    ) -> Self {
163        let resolver = def.resolver(db);
164        let def_target_features = match def {
165            DefWithBodyId::FunctionId(func) => TargetFeatures::from_fn(db, func),
166            _ => TargetFeatures::default(),
167        };
168        let krate = resolver.krate();
169        let edition = krate.data(db).edition;
170        let target_feature_is_safe = match &krate.workspace_data(db).target {
171            Ok(target) => target_feature_is_safe_in_target(target),
172            Err(_) => TargetFeatureIsSafeInTarget::No,
173        };
174        Self {
175            db,
176            infer,
177            body,
178            resolver,
179            def,
180            inside_unsafe_block: InsideUnsafeBlock::No,
181            inside_assignment: false,
182            inside_union_destructure: false,
183            callback: unsafe_expr_cb,
184            def_target_features,
185            edition,
186            target_feature_is_safe,
187        }
188    }
189
190    fn on_unsafe_op(&mut self, node: ExprOrPatId, reason: UnsafetyReason) {
191        (self.callback)(UnsafeDiagnostic::UnsafeOperation {
192            node,
193            inside_unsafe_block: self.inside_unsafe_block,
194            reason,
195        });
196    }
197
198    fn check_call(&mut self, node: ExprId, func: FunctionId) {
199        let unsafety = is_fn_unsafe_to_call(
200            self.db,
201            func,
202            &self.def_target_features,
203            self.edition,
204            self.target_feature_is_safe,
205        );
206        match unsafety {
207            crate::utils::Unsafety::Safe => {}
208            crate::utils::Unsafety::Unsafe => {
209                self.on_unsafe_op(node.into(), UnsafetyReason::UnsafeFnCall)
210            }
211            crate::utils::Unsafety::DeprecatedSafe2024 => {
212                (self.callback)(UnsafeDiagnostic::DeprecatedSafe2024 {
213                    node,
214                    inside_unsafe_block: self.inside_unsafe_block,
215                })
216            }
217        }
218    }
219
220    fn with_inside_unsafe_block<R>(
221        &mut self,
222        inside_unsafe_block: InsideUnsafeBlock,
223        f: impl FnOnce(&mut Self) -> R,
224    ) -> R {
225        let old = mem::replace(&mut self.inside_unsafe_block, inside_unsafe_block);
226        let result = f(self);
227        self.inside_unsafe_block = old;
228        result
229    }
230
231    fn walk_pats_top(&mut self, pats: impl Iterator<Item = PatId>, parent_expr: ExprId) {
232        let guard = self.resolver.update_to_inner_scope(self.db, self.def, parent_expr);
233        pats.for_each(|pat| self.walk_pat(pat));
234        self.resolver.reset_to_guard(guard);
235    }
236
237    fn walk_pat(&mut self, current: PatId) {
238        let pat = &self.body[current];
239
240        if self.inside_union_destructure {
241            match pat {
242                Pat::Tuple { .. }
243                | Pat::Record { .. }
244                | Pat::Range { .. }
245                | Pat::Slice { .. }
246                | Pat::Path(..)
247                | Pat::Lit(..)
248                | Pat::Bind { .. }
249                | Pat::TupleStruct { .. }
250                | Pat::Ref { .. }
251                | Pat::Box { .. }
252                | Pat::Expr(..)
253                | Pat::ConstBlock(..) => {
254                    self.on_unsafe_op(current.into(), UnsafetyReason::UnionField)
255                }
256                // `Or` only wraps other patterns, and `Missing`/`Wild` do not constitute a read.
257                Pat::Missing | Pat::Wild | Pat::Or(_) => {}
258            }
259        }
260
261        match pat {
262            Pat::Record { .. } => {
263                if let Some((AdtId::UnionId(_), _)) = self.infer[current].as_adt() {
264                    let old_inside_union_destructure =
265                        mem::replace(&mut self.inside_union_destructure, true);
266                    self.body.walk_pats_shallow(current, |pat| self.walk_pat(pat));
267                    self.inside_union_destructure = old_inside_union_destructure;
268                    return;
269                }
270            }
271            Pat::Path(path) => self.mark_unsafe_path(current.into(), path),
272            &Pat::ConstBlock(expr) => {
273                let old_inside_assignment = mem::replace(&mut self.inside_assignment, false);
274                self.walk_expr(expr);
275                self.inside_assignment = old_inside_assignment;
276            }
277            &Pat::Expr(expr) => self.walk_expr(expr),
278            _ => {}
279        }
280
281        self.body.walk_pats_shallow(current, |pat| self.walk_pat(pat));
282    }
283
284    fn walk_expr(&mut self, current: ExprId) {
285        let expr = &self.body[current];
286        let inside_assignment = mem::replace(&mut self.inside_assignment, false);
287        match expr {
288            &Expr::Call { callee, .. } => {
289                let callee = self.infer[callee];
290                if let TyKind::FnDef(CallableIdWrapper(CallableDefId::FunctionId(func)), _) =
291                    callee.kind()
292                {
293                    self.check_call(current, func);
294                }
295                if let TyKind::FnPtr(_, hdr) = callee.kind()
296                    && hdr.safety == Safety::Unsafe
297                {
298                    self.on_unsafe_op(current.into(), UnsafetyReason::UnsafeFnCall);
299                }
300            }
301            Expr::Path(path) => {
302                let guard = self.resolver.update_to_inner_scope(self.db, self.def, current);
303                self.mark_unsafe_path(current.into(), path);
304                self.resolver.reset_to_guard(guard);
305            }
306            Expr::Ref { expr, rawness: Rawness::RawPtr, mutability: _ } => {
307                match self.body[*expr] {
308                    // Do not report unsafe for `addr_of[_mut]!(EXTERN_OR_MUT_STATIC)`,
309                    // see https://github.com/rust-lang/rust/pull/125834.
310                    Expr::Path(_) => return,
311                    // https://github.com/rust-lang/rust/pull/129248
312                    // Taking a raw ref to a deref place expr is always safe.
313                    Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
314                        self.body
315                            .walk_child_exprs_without_pats(expr, |child| self.walk_expr(child));
316
317                        return;
318                    }
319                    _ => (),
320                }
321
322                let mut peeled = *expr;
323                while let Expr::Field { expr: lhs, .. } = &self.body[peeled] {
324                    if let Some(Either::Left(FieldId { parent: VariantId::UnionId(_), .. })) =
325                        self.infer.field_resolution(peeled)
326                    {
327                        peeled = *lhs;
328                    } else {
329                        break;
330                    }
331                }
332
333                // Walk the peeled expression (the LHS of the union field chain)
334                self.walk_expr(peeled);
335                // Return so we don't recurse directly onto the union field access(es)
336                return;
337            }
338            Expr::MethodCall { .. } => {
339                if let Some((func, _)) = self.infer.method_resolution(current) {
340                    self.check_call(current, func);
341                }
342            }
343            Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
344                if let TyKind::RawPtr(..) = self.infer[*expr].kind() {
345                    self.on_unsafe_op(current.into(), UnsafetyReason::RawPtrDeref);
346                }
347            }
348            &Expr::Assignment { target, value: _ } => {
349                let old_inside_assignment = mem::replace(&mut self.inside_assignment, true);
350                self.walk_pats_top(std::iter::once(target), current);
351                self.inside_assignment = old_inside_assignment;
352            }
353            Expr::InlineAsm(asm) => {
354                if asm.kind == InlineAsmKind::Asm {
355                    // `naked_asm!()` requires `unsafe` on the attribute (`#[unsafe(naked)]`),
356                    // and `global_asm!()` doesn't require it at all.
357                    self.on_unsafe_op(current.into(), UnsafetyReason::InlineAsm);
358                }
359
360                asm.operands.iter().for_each(|(_, op)| match op {
361                    AsmOperand::In { expr, .. }
362                    | AsmOperand::Out { expr: Some(expr), .. }
363                    | AsmOperand::InOut { expr, .. }
364                    | AsmOperand::Const(expr) => self.walk_expr(*expr),
365                    AsmOperand::SplitInOut { in_expr, out_expr, .. } => {
366                        self.walk_expr(*in_expr);
367                        if let Some(out_expr) = out_expr {
368                            self.walk_expr(*out_expr);
369                        }
370                    }
371                    AsmOperand::Out { expr: None, .. } | AsmOperand::Sym(_) => (),
372                    AsmOperand::Label(expr) => {
373                        // Inline asm labels are considered safe even when inside unsafe blocks.
374                        self.with_inside_unsafe_block(InsideUnsafeBlock::No, |this| {
375                            this.walk_expr(*expr)
376                        });
377                    }
378                });
379                return;
380            }
381            // rustc allows union assignment to propagate through field accesses and casts.
382            Expr::Cast { .. } => self.inside_assignment = inside_assignment,
383            Expr::Field { .. } => {
384                self.inside_assignment = inside_assignment;
385                if !inside_assignment
386                    && let Some(Either::Left(FieldId { parent: VariantId::UnionId(_), .. })) =
387                        self.infer.field_resolution(current)
388                {
389                    self.on_unsafe_op(current.into(), UnsafetyReason::UnionField);
390                }
391            }
392            Expr::Unsafe { statements, .. } => {
393                self.with_inside_unsafe_block(InsideUnsafeBlock::Yes, |this| {
394                    this.walk_pats_top(
395                        statements.iter().filter_map(|statement| match statement {
396                            &Statement::Let { pat, .. } => Some(pat),
397                            _ => None,
398                        }),
399                        current,
400                    );
401                    this.body.walk_child_exprs_without_pats(current, |child| this.walk_expr(child));
402                });
403                return;
404            }
405            Expr::Block { statements, .. } | Expr::Async { statements, .. } => {
406                self.walk_pats_top(
407                    statements.iter().filter_map(|statement| match statement {
408                        &Statement::Let { pat, .. } => Some(pat),
409                        _ => None,
410                    }),
411                    current,
412                );
413            }
414            Expr::Match { arms, .. } => {
415                self.walk_pats_top(arms.iter().map(|arm| arm.pat), current);
416            }
417            &Expr::Let { pat, .. } => {
418                self.walk_pats_top(std::iter::once(pat), current);
419            }
420            Expr::Closure { args, .. } => {
421                self.walk_pats_top(args.iter().copied(), current);
422            }
423            Expr::Const(e) => self.walk_expr(*e),
424            _ => {}
425        }
426
427        self.body.walk_child_exprs_without_pats(current, |child| self.walk_expr(child));
428    }
429
430    fn mark_unsafe_path(&mut self, node: ExprOrPatId, path: &Path) {
431        let hygiene = self.body.expr_or_pat_path_hygiene(node);
432        let value_or_partial = self.resolver.resolve_path_in_value_ns(self.db, path, hygiene);
433        if let Some(ResolveValueResult::ValueNs(ValueNs::StaticId(id), _)) = value_or_partial {
434            let static_data = self.db.static_signature(id);
435            if static_data.flags.contains(StaticFlags::MUTABLE) {
436                self.on_unsafe_op(node, UnsafetyReason::MutableStatic);
437            } else if static_data.flags.contains(StaticFlags::EXTERN)
438                && !static_data.flags.contains(StaticFlags::EXPLICIT_SAFE)
439            {
440                self.on_unsafe_op(node, UnsafetyReason::ExternStatic);
441            }
442        }
443    }
444}