// hir_ty/diagnostics/unsafe_check.rs

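//! Provides validations for unsafe code. Detects operations that require an `unsafe` context
//! (calls to unsafe functions, raw pointer dereferences, union field accesses, uses of mutable
//! or extern statics, and inline assembly) and records whether each is inside an `unsafe` block.
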
use std::mem;

use either::Either;
use hir_def::{
    AdtId, DefWithBodyId, FieldId, FunctionId, VariantId,
    expr_store::{Body, path::Path},
    hir::{AsmOperand, Expr, ExprId, ExprOrPatId, InlineAsmKind, Pat, PatId, Statement, UnaryOp},
    resolver::{HasResolver, ResolveValueResult, Resolver, ValueNs},
    signatures::StaticFlags,
    type_ref::Rawness,
};
use span::Edition;

use crate::{
    InferenceResult, Interner, TargetFeatures, TyExt, TyKind,
    db::HirDatabase,
    utils::{Unsafety, is_fn_unsafe_to_call},
};

22#[derive(Debug, Default)]
23pub struct MissingUnsafeResult {
24    pub unsafe_exprs: Vec<(ExprOrPatId, UnsafetyReason)>,
25    /// If `fn_is_unsafe` is false, `unsafe_exprs` are hard errors. If true, they're `unsafe_op_in_unsafe_fn`.
26    pub fn_is_unsafe: bool,
27    pub deprecated_safe_calls: Vec<ExprId>,
28}
29
30pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> MissingUnsafeResult {
31    let _p = tracing::info_span!("missing_unsafe").entered();
32
33    let is_unsafe = match def {
34        DefWithBodyId::FunctionId(it) => db.function_signature(it).is_unsafe(),
35        DefWithBodyId::StaticId(_) | DefWithBodyId::ConstId(_) | DefWithBodyId::VariantId(_) => {
36            false
37        }
38    };
39
40    let mut res = MissingUnsafeResult { fn_is_unsafe: is_unsafe, ..MissingUnsafeResult::default() };
41    let body = db.body(def);
42    let infer = db.infer(def);
43    let mut callback = |diag| match diag {
44        UnsafeDiagnostic::UnsafeOperation { node, inside_unsafe_block, reason } => {
45            if inside_unsafe_block == InsideUnsafeBlock::No {
46                res.unsafe_exprs.push((node, reason));
47            }
48        }
49        UnsafeDiagnostic::DeprecatedSafe2024 { node, inside_unsafe_block } => {
50            if inside_unsafe_block == InsideUnsafeBlock::No {
51                res.deprecated_safe_calls.push(node)
52            }
53        }
54    };
55    let mut visitor = UnsafeVisitor::new(db, &infer, &body, def, &mut callback);
56    visitor.walk_expr(body.body_expr);
57
58    if !is_unsafe {
59        // Unsafety in function parameter patterns (that can only be union destructuring)
60        // cannot be inserted into an unsafe block, so even with `unsafe_op_in_unsafe_fn`
61        // it is turned off for unsafe functions.
62        for &param in &body.params {
63            visitor.walk_pat(param);
64        }
65    }
66
67    res
68}
69
/// Why an operation requires an `unsafe` context.
#[derive(Debug, Clone, Copy)]
pub enum UnsafetyReason {
    /// Accessing a union field, through a field expression (other than as an assignment
    /// target) or by destructuring a union in a pattern.
    UnionField,
    /// Calling an unsafe function or method, or calling through an unsafe function pointer.
    UnsafeFnCall,
    /// Using `asm!` inline assembly (`naked_asm!` and `global_asm!` are not reported).
    InlineAsm,
    /// Dereferencing a raw pointer.
    RawPtrDeref,
    /// Accessing a mutable static.
    MutableStatic,
    /// Accessing an `extern` static that is not explicitly marked safe.
    ExternStatic,
}

/// Whether an unsafe operation occurs inside an `unsafe` block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InsideUnsafeBlock {
    /// The operation is not enclosed in an `unsafe` block.
    No,
    /// The operation is enclosed in an `unsafe` block.
    Yes,
}

86#[derive(Debug)]
87enum UnsafeDiagnostic {
88    UnsafeOperation {
89        node: ExprOrPatId,
90        inside_unsafe_block: InsideUnsafeBlock,
91        reason: UnsafetyReason,
92    },
93    /// A lint.
94    DeprecatedSafe2024 { node: ExprId, inside_unsafe_block: InsideUnsafeBlock },
95}
96
97pub fn unsafe_operations_for_body(
98    db: &dyn HirDatabase,
99    infer: &InferenceResult,
100    def: DefWithBodyId,
101    body: &Body,
102    callback: &mut dyn FnMut(ExprOrPatId),
103) {
104    let mut visitor_callback = |diag| {
105        if let UnsafeDiagnostic::UnsafeOperation { node, .. } = diag {
106            callback(node);
107        }
108    };
109    let mut visitor = UnsafeVisitor::new(db, infer, body, def, &mut visitor_callback);
110    visitor.walk_expr(body.body_expr);
111    for &param in &body.params {
112        visitor.walk_pat(param);
113    }
114}
115
116pub fn unsafe_operations(
117    db: &dyn HirDatabase,
118    infer: &InferenceResult,
119    def: DefWithBodyId,
120    body: &Body,
121    current: ExprId,
122    callback: &mut dyn FnMut(ExprOrPatId, InsideUnsafeBlock),
123) {
124    let mut visitor_callback = |diag| {
125        if let UnsafeDiagnostic::UnsafeOperation { inside_unsafe_block, node, .. } = diag {
126            callback(node, inside_unsafe_block);
127        }
128    };
129    let mut visitor = UnsafeVisitor::new(db, infer, body, def, &mut visitor_callback);
130    _ = visitor.resolver.update_to_inner_scope(db, def, current);
131    visitor.walk_expr(current);
132}
133
134struct UnsafeVisitor<'db> {
135    db: &'db dyn HirDatabase,
136    infer: &'db InferenceResult,
137    body: &'db Body,
138    resolver: Resolver<'db>,
139    def: DefWithBodyId,
140    inside_unsafe_block: InsideUnsafeBlock,
141    inside_assignment: bool,
142    inside_union_destructure: bool,
143    callback: &'db mut dyn FnMut(UnsafeDiagnostic),
144    def_target_features: TargetFeatures,
145    // FIXME: This needs to be the edition of the span of each call.
146    edition: Edition,
147}
148
149impl<'db> UnsafeVisitor<'db> {
150    fn new(
151        db: &'db dyn HirDatabase,
152        infer: &'db InferenceResult,
153        body: &'db Body,
154        def: DefWithBodyId,
155        unsafe_expr_cb: &'db mut dyn FnMut(UnsafeDiagnostic),
156    ) -> Self {
157        let resolver = def.resolver(db);
158        let def_target_features = match def {
159            DefWithBodyId::FunctionId(func) => TargetFeatures::from_attrs(&db.attrs(func.into())),
160            _ => TargetFeatures::default(),
161        };
162        let edition = resolver.module().krate().data(db).edition;
163        Self {
164            db,
165            infer,
166            body,
167            resolver,
168            def,
169            inside_unsafe_block: InsideUnsafeBlock::No,
170            inside_assignment: false,
171            inside_union_destructure: false,
172            callback: unsafe_expr_cb,
173            def_target_features,
174            edition,
175        }
176    }
177
178    fn on_unsafe_op(&mut self, node: ExprOrPatId, reason: UnsafetyReason) {
179        (self.callback)(UnsafeDiagnostic::UnsafeOperation {
180            node,
181            inside_unsafe_block: self.inside_unsafe_block,
182            reason,
183        });
184    }
185
186    fn check_call(&mut self, node: ExprId, func: FunctionId) {
187        let unsafety = is_fn_unsafe_to_call(self.db, func, &self.def_target_features, self.edition);
188        match unsafety {
189            crate::utils::Unsafety::Safe => {}
190            crate::utils::Unsafety::Unsafe => {
191                self.on_unsafe_op(node.into(), UnsafetyReason::UnsafeFnCall)
192            }
193            crate::utils::Unsafety::DeprecatedSafe2024 => {
194                (self.callback)(UnsafeDiagnostic::DeprecatedSafe2024 {
195                    node,
196                    inside_unsafe_block: self.inside_unsafe_block,
197                })
198            }
199        }
200    }
201
202    fn with_inside_unsafe_block<R>(
203        &mut self,
204        inside_unsafe_block: InsideUnsafeBlock,
205        f: impl FnOnce(&mut Self) -> R,
206    ) -> R {
207        let old = mem::replace(&mut self.inside_unsafe_block, inside_unsafe_block);
208        let result = f(self);
209        self.inside_unsafe_block = old;
210        result
211    }
212
213    fn walk_pats_top(&mut self, pats: impl Iterator<Item = PatId>, parent_expr: ExprId) {
214        let guard = self.resolver.update_to_inner_scope(self.db, self.def, parent_expr);
215        pats.for_each(|pat| self.walk_pat(pat));
216        self.resolver.reset_to_guard(guard);
217    }
218
219    fn walk_pat(&mut self, current: PatId) {
220        let pat = &self.body[current];
221
222        if self.inside_union_destructure {
223            match pat {
224                Pat::Tuple { .. }
225                | Pat::Record { .. }
226                | Pat::Range { .. }
227                | Pat::Slice { .. }
228                | Pat::Path(..)
229                | Pat::Lit(..)
230                | Pat::Bind { .. }
231                | Pat::TupleStruct { .. }
232                | Pat::Ref { .. }
233                | Pat::Box { .. }
234                | Pat::Expr(..)
235                | Pat::ConstBlock(..) => {
236                    self.on_unsafe_op(current.into(), UnsafetyReason::UnionField)
237                }
238                // `Or` only wraps other patterns, and `Missing`/`Wild` do not constitute a read.
239                Pat::Missing | Pat::Wild | Pat::Or(_) => {}
240            }
241        }
242
243        match pat {
244            Pat::Record { .. } => {
245                if let Some((AdtId::UnionId(_), _)) = self.infer[current].as_adt() {
246                    let old_inside_union_destructure =
247                        mem::replace(&mut self.inside_union_destructure, true);
248                    self.body.walk_pats_shallow(current, |pat| self.walk_pat(pat));
249                    self.inside_union_destructure = old_inside_union_destructure;
250                    return;
251                }
252            }
253            Pat::Path(path) => self.mark_unsafe_path(current.into(), path),
254            &Pat::ConstBlock(expr) => {
255                let old_inside_assignment = mem::replace(&mut self.inside_assignment, false);
256                self.walk_expr(expr);
257                self.inside_assignment = old_inside_assignment;
258            }
259            &Pat::Expr(expr) => self.walk_expr(expr),
260            _ => {}
261        }
262
263        self.body.walk_pats_shallow(current, |pat| self.walk_pat(pat));
264    }
265
266    fn walk_expr(&mut self, current: ExprId) {
267        let expr = &self.body[current];
268        let inside_assignment = mem::replace(&mut self.inside_assignment, false);
269        match expr {
270            &Expr::Call { callee, .. } => {
271                let callee = &self.infer[callee];
272                if let Some(func) = callee.as_fn_def(self.db) {
273                    self.check_call(current, func);
274                }
275                if let TyKind::Function(fn_ptr) = callee.kind(Interner)
276                    && fn_ptr.sig.safety == chalk_ir::Safety::Unsafe
277                {
278                    self.on_unsafe_op(current.into(), UnsafetyReason::UnsafeFnCall);
279                }
280            }
281            Expr::Path(path) => {
282                let guard = self.resolver.update_to_inner_scope(self.db, self.def, current);
283                self.mark_unsafe_path(current.into(), path);
284                self.resolver.reset_to_guard(guard);
285            }
286            Expr::Ref { expr, rawness: Rawness::RawPtr, mutability: _ } => {
287                match self.body[*expr] {
288                    // Do not report unsafe for `addr_of[_mut]!(EXTERN_OR_MUT_STATIC)`,
289                    // see https://github.com/rust-lang/rust/pull/125834.
290                    Expr::Path(_) => return,
291                    // https://github.com/rust-lang/rust/pull/129248
292                    // Taking a raw ref to a deref place expr is always safe.
293                    Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
294                        self.body
295                            .walk_child_exprs_without_pats(expr, |child| self.walk_expr(child));
296
297                        return;
298                    }
299                    _ => (),
300                }
301            }
302            Expr::MethodCall { .. } => {
303                if let Some((func, _)) = self.infer.method_resolution(current) {
304                    self.check_call(current, func);
305                }
306            }
307            Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
308                if let TyKind::Raw(..) = &self.infer[*expr].kind(Interner) {
309                    self.on_unsafe_op(current.into(), UnsafetyReason::RawPtrDeref);
310                }
311            }
312            &Expr::Assignment { target, value: _ } => {
313                let old_inside_assignment = mem::replace(&mut self.inside_assignment, true);
314                self.walk_pats_top(std::iter::once(target), current);
315                self.inside_assignment = old_inside_assignment;
316            }
317            Expr::InlineAsm(asm) => {
318                if asm.kind == InlineAsmKind::Asm {
319                    // `naked_asm!()` requires `unsafe` on the attribute (`#[unsafe(naked)]`),
320                    // and `global_asm!()` doesn't require it at all.
321                    self.on_unsafe_op(current.into(), UnsafetyReason::InlineAsm);
322                }
323
324                asm.operands.iter().for_each(|(_, op)| match op {
325                    AsmOperand::In { expr, .. }
326                    | AsmOperand::Out { expr: Some(expr), .. }
327                    | AsmOperand::InOut { expr, .. }
328                    | AsmOperand::Const(expr) => self.walk_expr(*expr),
329                    AsmOperand::SplitInOut { in_expr, out_expr, .. } => {
330                        self.walk_expr(*in_expr);
331                        if let Some(out_expr) = out_expr {
332                            self.walk_expr(*out_expr);
333                        }
334                    }
335                    AsmOperand::Out { expr: None, .. } | AsmOperand::Sym(_) => (),
336                    AsmOperand::Label(expr) => {
337                        // Inline asm labels are considered safe even when inside unsafe blocks.
338                        self.with_inside_unsafe_block(InsideUnsafeBlock::No, |this| {
339                            this.walk_expr(*expr)
340                        });
341                    }
342                });
343                return;
344            }
345            // rustc allows union assignment to propagate through field accesses and casts.
346            Expr::Cast { .. } => self.inside_assignment = inside_assignment,
347            Expr::Field { .. } => {
348                self.inside_assignment = inside_assignment;
349                if !inside_assignment
350                    && let Some(Either::Left(FieldId { parent: VariantId::UnionId(_), .. })) =
351                        self.infer.field_resolution(current)
352                {
353                    self.on_unsafe_op(current.into(), UnsafetyReason::UnionField);
354                }
355            }
356            Expr::Unsafe { statements, .. } => {
357                self.with_inside_unsafe_block(InsideUnsafeBlock::Yes, |this| {
358                    this.walk_pats_top(
359                        statements.iter().filter_map(|statement| match statement {
360                            &Statement::Let { pat, .. } => Some(pat),
361                            _ => None,
362                        }),
363                        current,
364                    );
365                    this.body.walk_child_exprs_without_pats(current, |child| this.walk_expr(child));
366                });
367                return;
368            }
369            Expr::Block { statements, .. } | Expr::Async { statements, .. } => {
370                self.walk_pats_top(
371                    statements.iter().filter_map(|statement| match statement {
372                        &Statement::Let { pat, .. } => Some(pat),
373                        _ => None,
374                    }),
375                    current,
376                );
377            }
378            Expr::Match { arms, .. } => {
379                self.walk_pats_top(arms.iter().map(|arm| arm.pat), current);
380            }
381            &Expr::Let { pat, .. } => {
382                self.walk_pats_top(std::iter::once(pat), current);
383            }
384            Expr::Closure { args, .. } => {
385                self.walk_pats_top(args.iter().copied(), current);
386            }
387            Expr::Const(e) => self.walk_expr(*e),
388            _ => {}
389        }
390
391        self.body.walk_child_exprs_without_pats(current, |child| self.walk_expr(child));
392    }
393
394    fn mark_unsafe_path(&mut self, node: ExprOrPatId, path: &Path) {
395        let hygiene = self.body.expr_or_pat_path_hygiene(node);
396        let value_or_partial = self.resolver.resolve_path_in_value_ns(self.db, path, hygiene);
397        if let Some(ResolveValueResult::ValueNs(ValueNs::StaticId(id), _)) = value_or_partial {
398            let static_data = self.db.static_signature(id);
399            if static_data.flags.contains(StaticFlags::MUTABLE) {
400                self.on_unsafe_op(node, UnsafetyReason::MutableStatic);
401            } else if static_data.flags.contains(StaticFlags::EXTERN)
402                && !static_data.flags.contains(StaticFlags::EXPLICIT_SAFE)
403            {
404                self.on_unsafe_op(node, UnsafetyReason::ExternStatic);
405            }
406        }
407    }
408}