1pub(crate) mod analysis;
4
5use std::{iter, mem, ops::ControlFlow};
6
7use hir_def::{
8 TraitId,
9 hir::{ClosureKind, ExprId, PatId},
10 type_ref::TypeRefId,
11};
12use rustc_type_ir::{
13 ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, CoroutineClosureArgs,
14 CoroutineClosureArgsParts, Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt,
15 TypeVisitor,
16 inherent::{BoundExistentialPredicates, GenericArgs as _, IntoKind, SliceLike, Ty as _},
17};
18use tracing::debug;
19
20use crate::{
21 FnAbi,
22 db::{InternedClosure, InternedCoroutine},
23 infer::{BreakableKind, Diverges, coerce::CoerceMany},
24 next_solver::{
25 AliasTy, Binder, BoundRegionKind, BoundVarKind, BoundVarKinds, ClauseKind, DbInterner,
26 ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig, PolyProjectionPredicate, Predicate,
27 PredicateKind, SolverDefId, Ty, TyKind,
28 abi::Safety,
29 infer::{
30 BoundRegionConversionTime, InferOk, InferResult,
31 traits::{ObligationCause, PredicateObligations},
32 },
33 },
34 traits::FnTrait,
35};
36
37use super::{Expectation, InferenceContext};
38
/// A closure's signature, kept both with and without its binder.
#[derive(Debug)]
struct ClosureSignatures<'db> {
    /// The signature including its binder; `infer_closure` wraps this in a
    /// fn-pointer type for the closure's signature and takes the body's return
    /// type and the parameter types from it.
    bound_sig: PolyFnSig<'db>,
    /// The signature with the binder skipped (not liberated). After a successful
    /// merge with user annotations, its inputs and output are the user-supplied
    /// types rather than the expected ones.
    liberated_sig: FnSig<'db>,
}
49
50impl<'db> InferenceContext<'_, 'db> {
51 pub(super) fn infer_closure(
52 &mut self,
53 body: ExprId,
54 args: &[PatId],
55 ret_type: Option<TypeRefId>,
56 arg_types: &[Option<TypeRefId>],
57 closure_kind: ClosureKind,
58 tgt_expr: ExprId,
59 expected: &Expectation<'db>,
60 ) -> Ty<'db> {
61 assert_eq!(args.len(), arg_types.len());
62
63 let interner = self.interner();
64 let (expected_sig, expected_kind) = match expected.to_option(&mut self.table) {
65 Some(expected_ty) => self.deduce_closure_signature(expected_ty, closure_kind),
66 None => (None, None),
67 };
68
69 let ClosureSignatures { bound_sig, liberated_sig } =
70 self.sig_of_closure(arg_types, ret_type, expected_sig);
71 let body_ret_ty = bound_sig.output().skip_binder();
72 let sig_ty = Ty::new_fn_ptr(interner, bound_sig);
73
74 let parent_args = GenericArgs::identity_for_item(interner, self.generic_def.into());
75 let tupled_upvars_ty = self.types.unit;
77 let (id, ty, resume_yield_tys) = match closure_kind {
78 ClosureKind::Coroutine(_) => {
79 let yield_ty = self.table.next_ty_var();
80 let resume_ty = liberated_sig.inputs().get(0).unwrap_or(self.types.unit);
81
82 let parts = CoroutineArgsParts {
84 parent_args,
85 kind_ty: self.types.unit,
86 resume_ty,
87 yield_ty,
88 return_ty: body_ret_ty,
89 tupled_upvars_ty,
90 };
91
92 let coroutine_id =
93 self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into();
94 let coroutine_ty = Ty::new_coroutine(
95 interner,
96 coroutine_id,
97 CoroutineArgs::new(interner, parts).args,
98 );
99
100 (None, coroutine_ty, Some((resume_ty, yield_ty)))
101 }
102 ClosureKind::Closure => {
103 let closure_id = self.db.intern_closure(InternedClosure(self.owner, tgt_expr));
104 match expected_kind {
105 Some(kind) => {
106 self.result.closure_info.insert(
107 closure_id,
108 (
109 Vec::new(),
110 match kind {
111 rustc_type_ir::ClosureKind::Fn => FnTrait::Fn,
112 rustc_type_ir::ClosureKind::FnMut => FnTrait::FnMut,
113 rustc_type_ir::ClosureKind::FnOnce => FnTrait::FnOnce,
114 },
115 ),
116 );
117 }
118 None => {}
119 };
120 let parts = ClosureArgsParts {
122 parent_args,
123 closure_kind_ty: Ty::from_closure_kind(
124 interner,
125 expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn),
126 ),
127 closure_sig_as_fn_ptr_ty: sig_ty,
128 tupled_upvars_ty,
129 };
130 let closure_ty = Ty::new_closure(
131 interner,
132 closure_id.into(),
133 ClosureArgs::new(interner, parts).args,
134 );
135 self.deferred_closures.entry(closure_id).or_default();
136 self.add_current_closure_dependency(closure_id);
137 (Some(closure_id), closure_ty, None)
138 }
139 ClosureKind::Async => {
140 let bound_return_ty = bound_sig.skip_binder().output();
143 let bound_yield_ty = self.types.unit;
144 let resume_ty = self.types.unit;
146
147 let closure_kind_ty = Ty::from_closure_kind(
149 interner,
150 expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn),
151 );
152
153 let coroutine_captures_by_ref_ty = Ty::new_fn_ptr(
156 interner,
157 Binder::bind_with_vars(
158 interner.mk_fn_sig([], self.types.unit, false, Safety::Safe, FnAbi::Rust),
159 BoundVarKinds::new_from_iter(
160 interner,
161 [BoundVarKind::Region(BoundRegionKind::ClosureEnv)],
162 ),
163 ),
164 );
165 let closure_args = CoroutineClosureArgs::new(
166 interner,
167 CoroutineClosureArgsParts {
168 parent_args,
169 closure_kind_ty,
170 signature_parts_ty: Ty::new_fn_ptr(
171 interner,
172 bound_sig.map_bound(|sig| {
173 interner.mk_fn_sig(
174 [
175 resume_ty,
176 Ty::new_tup_from_iter(interner, sig.inputs().iter()),
177 ],
178 Ty::new_tup(interner, &[bound_yield_ty, bound_return_ty]),
179 sig.c_variadic,
180 sig.safety,
181 sig.abi,
182 )
183 }),
184 ),
185 tupled_upvars_ty,
186 coroutine_captures_by_ref_ty,
187 },
188 );
189
190 let coroutine_id =
191 self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into();
192 (None, Ty::new_coroutine_closure(interner, coroutine_id, closure_args.args), None)
193 }
194 };
195
196 for (arg_pat, arg_ty) in args.iter().zip(bound_sig.skip_binder().inputs()) {
198 self.infer_top_pat(*arg_pat, arg_ty, None);
199 }
200
201 let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
203 let prev_closure = mem::replace(&mut self.current_closure, id);
204 let prev_ret_ty = mem::replace(&mut self.return_ty, body_ret_ty);
205 let prev_ret_coercion = self.return_coercion.replace(CoerceMany::new(body_ret_ty));
206 let prev_resume_yield_tys = mem::replace(&mut self.resume_yield_tys, resume_yield_tys);
207
208 self.with_breakable_ctx(BreakableKind::Border, None, None, |this| {
209 this.infer_return(body);
210 });
211
212 self.diverges = prev_diverges;
213 self.return_ty = prev_ret_ty;
214 self.return_coercion = prev_ret_coercion;
215 self.current_closure = prev_closure;
216 self.resume_yield_tys = prev_resume_yield_tys;
217
218 ty
219 }
220
221 fn fn_trait_kind_from_def_id(&self, trait_id: TraitId) -> Option<rustc_type_ir::ClosureKind> {
222 match trait_id {
223 _ if self.lang_items.Fn == Some(trait_id) => Some(rustc_type_ir::ClosureKind::Fn),
224 _ if self.lang_items.FnMut == Some(trait_id) => Some(rustc_type_ir::ClosureKind::FnMut),
225 _ if self.lang_items.FnOnce == Some(trait_id) => {
226 Some(rustc_type_ir::ClosureKind::FnOnce)
227 }
228 _ => None,
229 }
230 }
231
232 fn async_fn_trait_kind_from_def_id(
233 &self,
234 trait_id: TraitId,
235 ) -> Option<rustc_type_ir::ClosureKind> {
236 match trait_id {
237 _ if self.lang_items.AsyncFn == Some(trait_id) => Some(rustc_type_ir::ClosureKind::Fn),
238 _ if self.lang_items.AsyncFnMut == Some(trait_id) => {
239 Some(rustc_type_ir::ClosureKind::FnMut)
240 }
241 _ if self.lang_items.AsyncFnOnce == Some(trait_id) => {
242 Some(rustc_type_ir::ClosureKind::FnOnce)
243 }
244 _ => None,
245 }
246 }
247
248 fn deduce_closure_signature(
251 &mut self,
252 expected_ty: Ty<'db>,
253 closure_kind: ClosureKind,
254 ) -> (Option<PolyFnSig<'db>>, Option<rustc_type_ir::ClosureKind>) {
255 match expected_ty.kind() {
256 TyKind::Alias(rustc_type_ir::Opaque, AliasTy { def_id, args, .. }) => self
257 .deduce_closure_signature_from_predicates(
258 expected_ty,
259 closure_kind,
260 def_id
261 .expect_opaque_ty()
262 .predicates(self.db)
263 .iter_instantiated_copied(self.interner(), args.as_slice())
264 .map(|clause| clause.as_predicate()),
265 ),
266 TyKind::Dynamic(object_type, ..) => {
267 let sig = object_type.projection_bounds().into_iter().find_map(|pb| {
268 let pb = pb.with_self_ty(self.interner(), Ty::new_unit(self.interner()));
269 self.deduce_sig_from_projection(closure_kind, pb)
270 });
271 let kind = object_type
272 .principal_def_id()
273 .and_then(|did| self.fn_trait_kind_from_def_id(did.0));
274 (sig, kind)
275 }
276 TyKind::Infer(rustc_type_ir::TyVar(vid)) => self
277 .deduce_closure_signature_from_predicates(
278 Ty::new_var(self.interner(), self.table.infer_ctxt.root_var(vid)),
279 closure_kind,
280 self.table.obligations_for_self_ty(vid).into_iter().map(|obl| obl.predicate),
281 ),
282 TyKind::FnPtr(sig_tys, hdr) => match closure_kind {
283 ClosureKind::Closure => {
284 let expected_sig = sig_tys.with(hdr);
285 (Some(expected_sig), Some(rustc_type_ir::ClosureKind::Fn))
286 }
287 ClosureKind::Coroutine(_) | ClosureKind::Async => (None, None),
288 },
289 _ => (None, None),
290 }
291 }
292
293 fn deduce_closure_signature_from_predicates(
294 &mut self,
295 expected_ty: Ty<'db>,
296 closure_kind: ClosureKind,
297 predicates: impl DoubleEndedIterator<Item = Predicate<'db>>,
298 ) -> (Option<PolyFnSig<'db>>, Option<rustc_type_ir::ClosureKind>) {
299 let mut expected_sig = None;
300 let mut expected_kind = None;
301
302 for pred in rustc_type_ir::elaborate::elaborate(
303 self.interner(),
304 predicates.rev(),
308 )
309 .filter_only_self()
311 {
312 debug!(?pred);
313 let bound_predicate = pred.kind();
314
315 if expected_sig.is_none()
318 && let PredicateKind::Clause(ClauseKind::Projection(proj_predicate)) =
319 bound_predicate.skip_binder()
320 {
321 let inferred_sig = self.deduce_sig_from_projection(
322 closure_kind,
323 bound_predicate.rebind(proj_predicate),
324 );
325
326 struct MentionsTy<'db> {
330 expected_ty: Ty<'db>,
331 }
332 impl<'db> TypeVisitor<DbInterner<'db>> for MentionsTy<'db> {
333 type Result = ControlFlow<()>;
334
335 fn visit_ty(&mut self, t: Ty<'db>) -> Self::Result {
336 if t == self.expected_ty {
337 ControlFlow::Break(())
338 } else {
339 t.super_visit_with(self)
340 }
341 }
342 }
343
344 if let Some(inferred_sig) = inferred_sig {
347 let generalized_fnptr_sig = self.table.next_ty_var();
365 let inferred_fnptr_sig = Ty::new_fn_ptr(self.interner(), inferred_sig);
366 _ = self
368 .table
369 .infer_ctxt
370 .at(&ObligationCause::new(), self.table.param_env)
371 .eq(inferred_fnptr_sig, generalized_fnptr_sig)
372 .map(|infer_ok| self.table.register_infer_ok(infer_ok));
373
374 let resolved_sig =
375 self.table.infer_ctxt.resolve_vars_if_possible(generalized_fnptr_sig);
376
377 if resolved_sig.visit_with(&mut MentionsTy { expected_ty }).is_continue() {
378 expected_sig = Some(resolved_sig.fn_sig(self.interner()));
379 }
380 } else if inferred_sig.visit_with(&mut MentionsTy { expected_ty }).is_continue() {
381 expected_sig = inferred_sig;
382 }
383 }
384
385 let trait_def_id = match bound_predicate.skip_binder() {
390 PredicateKind::Clause(ClauseKind::Projection(data)) => {
391 Some(data.projection_term.trait_def_id(self.interner()).0)
392 }
393 PredicateKind::Clause(ClauseKind::Trait(data)) => Some(data.def_id().0),
394 _ => None,
395 };
396
397 if let Some(trait_def_id) = trait_def_id {
398 let found_kind = match closure_kind {
399 ClosureKind::Closure => self.fn_trait_kind_from_def_id(trait_def_id),
400 ClosureKind::Async => self
401 .async_fn_trait_kind_from_def_id(trait_def_id)
402 .or_else(|| self.fn_trait_kind_from_def_id(trait_def_id)),
403 _ => None,
404 };
405
406 if let Some(found_kind) = found_kind {
407 match (expected_kind, found_kind) {
409 (None, _) => expected_kind = Some(found_kind),
410 (
411 Some(rustc_type_ir::ClosureKind::FnMut),
412 rustc_type_ir::ClosureKind::Fn,
413 ) => expected_kind = Some(rustc_type_ir::ClosureKind::Fn),
414 (
415 Some(rustc_type_ir::ClosureKind::FnOnce),
416 rustc_type_ir::ClosureKind::Fn | rustc_type_ir::ClosureKind::FnMut,
417 ) => expected_kind = Some(found_kind),
418 _ => {}
419 }
420 }
421 }
422 }
423
424 (expected_sig, expected_kind)
425 }
426
427 fn deduce_sig_from_projection(
434 &mut self,
435 closure_kind: ClosureKind,
436 projection: PolyProjectionPredicate<'db>,
437 ) -> Option<PolyFnSig<'db>> {
438 let SolverDefId::TypeAliasId(def_id) = projection.item_def_id() else { unreachable!() };
439
440 match closure_kind {
443 ClosureKind::Closure if Some(def_id) == self.lang_items.FnOnceOutput => {
444 self.extract_sig_from_projection(projection)
445 }
446 ClosureKind::Async if Some(def_id) == self.lang_items.AsyncFnOnceOutput => {
447 self.extract_sig_from_projection(projection)
448 }
449 ClosureKind::Async if Some(def_id) == self.lang_items.FnOnceOutput => {
453 self.extract_sig_from_projection_and_future_bound(projection)
454 }
455 _ => None,
456 }
457 }
458
459 fn extract_sig_from_projection(
462 &self,
463 projection: PolyProjectionPredicate<'db>,
464 ) -> Option<PolyFnSig<'db>> {
465 let projection = self.table.infer_ctxt.resolve_vars_if_possible(projection);
466
467 let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1);
468 debug!(?arg_param_ty);
469
470 let TyKind::Tuple(input_tys) = arg_param_ty.kind() else {
471 return None;
472 };
473
474 let ret_param_ty = projection.skip_binder().term.expect_type();
476 debug!(?ret_param_ty);
477
478 let sig = projection.rebind(self.interner().mk_fn_sig(
479 input_tys,
480 ret_param_ty,
481 false,
482 Safety::Safe,
483 FnAbi::Rust,
484 ));
485
486 Some(sig)
487 }
488
489 fn extract_sig_from_projection_and_future_bound(
512 &mut self,
513 projection: PolyProjectionPredicate<'db>,
514 ) -> Option<PolyFnSig<'db>> {
515 let projection = self.table.infer_ctxt.resolve_vars_if_possible(projection);
516
517 let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1);
518 debug!(?arg_param_ty);
519
520 let TyKind::Tuple(input_tys) = arg_param_ty.kind() else {
521 return None;
522 };
523
524 let TyKind::Infer(rustc_type_ir::TyVar(return_vid)) =
530 projection.skip_binder().term.expect_type().kind()
531 else {
532 return None;
533 };
534
535 let mut return_ty = None;
537 for bound in self.table.obligations_for_self_ty(return_vid) {
538 if let PredicateKind::Clause(ClauseKind::Projection(ret_projection)) =
539 bound.predicate.kind().skip_binder()
540 && let ret_projection = bound.predicate.kind().rebind(ret_projection)
541 && let Some(ret_projection) = ret_projection.no_bound_vars()
542 && let SolverDefId::TypeAliasId(assoc_type) = ret_projection.def_id()
543 && Some(assoc_type) == self.lang_items.FutureOutput
544 {
545 return_ty = Some(ret_projection.term.expect_type());
546 break;
547 }
548 }
549
550 let return_ty = return_ty.unwrap_or_else(|| self.table.next_ty_var());
565
566 let sig = projection.rebind(self.interner().mk_fn_sig(
567 input_tys,
568 return_ty,
569 false,
570 Safety::Safe,
571 FnAbi::Rust,
572 ));
573
574 Some(sig)
575 }
576
577 fn sig_of_closure(
578 &mut self,
579 decl_inputs: &[Option<TypeRefId>],
580 decl_output: Option<TypeRefId>,
581 expected_sig: Option<PolyFnSig<'db>>,
582 ) -> ClosureSignatures<'db> {
583 if let Some(e) = expected_sig {
584 self.sig_of_closure_with_expectation(decl_inputs, decl_output, e)
585 } else {
586 self.sig_of_closure_no_expectation(decl_inputs, decl_output)
587 }
588 }
589
590 fn sig_of_closure_no_expectation(
593 &mut self,
594 decl_inputs: &[Option<TypeRefId>],
595 decl_output: Option<TypeRefId>,
596 ) -> ClosureSignatures<'db> {
597 let bound_sig = self.supplied_sig_of_closure(decl_inputs, decl_output);
598
599 self.closure_sigs(bound_sig)
600 }
601
602 fn sig_of_closure_with_expectation(
650 &mut self,
651 decl_inputs: &[Option<TypeRefId>],
652 decl_output: Option<TypeRefId>,
653 expected_sig: PolyFnSig<'db>,
654 ) -> ClosureSignatures<'db> {
655 if expected_sig.c_variadic() {
659 return self.sig_of_closure_no_expectation(decl_inputs, decl_output);
660 } else if expected_sig.skip_binder().inputs_and_output.len() != decl_inputs.len() + 1 {
661 return self
662 .sig_of_closure_with_mismatched_number_of_arguments(decl_inputs, decl_output);
663 }
664
665 assert!(!expected_sig.skip_binder().has_vars_bound_above(rustc_type_ir::INNERMOST));
669 let bound_sig = expected_sig.map_bound(|sig| {
670 self.interner().mk_fn_sig(
671 sig.inputs(),
672 sig.output(),
673 sig.c_variadic,
674 Safety::Safe,
675 FnAbi::RustCall,
676 )
677 });
678
679 let bound_sig = self.interner().anonymize_bound_vars(bound_sig);
683
684 let closure_sigs = self.closure_sigs(bound_sig);
685
686 match self.merge_supplied_sig_with_expectation(decl_inputs, decl_output, closure_sigs) {
692 Ok(infer_ok) => self.table.register_infer_ok(infer_ok),
693 Err(_) => self.sig_of_closure_no_expectation(decl_inputs, decl_output),
694 }
695 }
696
697 fn sig_of_closure_with_mismatched_number_of_arguments(
698 &mut self,
699 decl_inputs: &[Option<TypeRefId>],
700 decl_output: Option<TypeRefId>,
701 ) -> ClosureSignatures<'db> {
702 let error_sig = self.error_sig_of_closure(decl_inputs, decl_output);
703
704 self.closure_sigs(error_sig)
705 }
706
707 fn merge_supplied_sig_with_expectation(
711 &mut self,
712 decl_inputs: &[Option<TypeRefId>],
713 decl_output: Option<TypeRefId>,
714 mut expected_sigs: ClosureSignatures<'db>,
715 ) -> InferResult<'db, ClosureSignatures<'db>> {
716 let supplied_sig = self.supplied_sig_of_closure(decl_inputs, decl_output);
721
722 debug!(?supplied_sig);
723
724 self.table.commit_if_ok(|table| {
739 let mut all_obligations = PredicateObligations::new();
740 let supplied_sig = table.infer_ctxt.instantiate_binder_with_fresh_vars(
741 BoundRegionConversionTime::FnCall,
742 supplied_sig,
743 );
744
745 for (supplied_ty, expected_ty) in
748 iter::zip(supplied_sig.inputs(), expected_sigs.liberated_sig.inputs())
749 {
750 let cause = ObligationCause::new();
752 let InferOk { value: (), obligations } =
753 table.infer_ctxt.at(&cause, table.param_env).eq(expected_ty, supplied_ty)?;
754 all_obligations.extend(obligations);
755 }
756
757 let supplied_output_ty = supplied_sig.output();
758 let cause = ObligationCause::new();
759 let InferOk { value: (), obligations } =
760 table
761 .infer_ctxt
762 .at(&cause, table.param_env)
763 .eq(expected_sigs.liberated_sig.output(), supplied_output_ty)?;
764 all_obligations.extend(obligations);
765
766 let inputs = supplied_sig
767 .inputs()
768 .into_iter()
769 .map(|ty| table.infer_ctxt.resolve_vars_if_possible(ty));
770
771 expected_sigs.liberated_sig = table.interner().mk_fn_sig(
772 inputs,
773 supplied_output_ty,
774 expected_sigs.liberated_sig.c_variadic,
775 Safety::Safe,
776 FnAbi::RustCall,
777 );
778
779 Ok(InferOk { value: expected_sigs, obligations: all_obligations })
780 })
781 }
782
783 fn supplied_sig_of_closure(
788 &mut self,
789 decl_inputs: &[Option<TypeRefId>],
790 decl_output: Option<TypeRefId>,
791 ) -> PolyFnSig<'db> {
792 let interner = self.interner();
793
794 let supplied_return = match decl_output {
795 Some(output) => {
796 let output = self.make_body_ty(output);
797 self.process_user_written_ty(output)
798 }
799 None => self.table.next_ty_var(),
800 };
801 let supplied_arguments = decl_inputs.iter().map(|&input| match input {
803 Some(input) => {
804 let input = self.make_body_ty(input);
805 self.process_user_written_ty(input)
806 }
807 None => self.table.next_ty_var(),
808 });
809
810 Binder::dummy(interner.mk_fn_sig(
811 supplied_arguments,
812 supplied_return,
813 false,
814 Safety::Safe,
815 FnAbi::RustCall,
816 ))
817 }
818
819 fn error_sig_of_closure(
823 &mut self,
824 decl_inputs: &[Option<TypeRefId>],
825 decl_output: Option<TypeRefId>,
826 ) -> PolyFnSig<'db> {
827 let interner = self.interner();
828 let err_ty = Ty::new_error(interner, ErrorGuaranteed);
829
830 if let Some(output) = decl_output {
831 self.make_body_ty(output);
832 }
833 let supplied_arguments = decl_inputs.iter().map(|&input| match input {
834 Some(input) => {
835 self.make_body_ty(input);
836 err_ty
837 }
838 None => err_ty,
839 });
840
841 let result = Binder::dummy(interner.mk_fn_sig(
842 supplied_arguments,
843 err_ty,
844 false,
845 Safety::Safe,
846 FnAbi::RustCall,
847 ));
848
849 debug!("supplied_sig_of_closure: result={:?}", result);
850
851 result
852 }
853
854 fn closure_sigs(&self, bound_sig: PolyFnSig<'db>) -> ClosureSignatures<'db> {
855 let liberated_sig = bound_sig.skip_binder();
856 ClosureSignatures { bound_sig, liberated_sig }
858 }
859}