hir_ty/next_solver/
generic_arg.rs

1//! Things related to generic args in the next-trait-solver.
2
3use hir_def::{GenericDefId, GenericParamId};
4use macros::{TypeFoldable, TypeVisitable};
5use rustc_type_ir::{
6    ClosureArgs, CollectAndApply, ConstVid, CoroutineArgs, CoroutineClosureArgs, FnSigTys,
7    GenericArgKind, Interner, TermKind, TyKind, TyVid, Variance,
8    inherent::{GenericArg as _, GenericsOf, IntoKind, SliceLike, Term as _, Ty as _},
9    relate::{Relate, VarianceDiagInfo},
10    walk::TypeWalker,
11};
12use smallvec::SmallVec;
13
14use crate::next_solver::{PolyFnSig, interned_vec_db};
15
16use super::{
17    Const, DbInterner, EarlyParamRegion, ErrorGuaranteed, ParamConst, Region, SolverDefId, Ty, Tys,
18    generics::Generics,
19};
20
21#[derive(Copy, Clone, PartialEq, Eq, Hash, TypeVisitable, TypeFoldable, salsa::Supertype)]
22pub enum GenericArg<'db> {
23    Ty(Ty<'db>),
24    Lifetime(Region<'db>),
25    Const(Const<'db>),
26}
27
28impl<'db> std::fmt::Debug for GenericArg<'db> {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        match self {
31            Self::Ty(t) => std::fmt::Debug::fmt(t, f),
32            Self::Lifetime(r) => std::fmt::Debug::fmt(r, f),
33            Self::Const(c) => std::fmt::Debug::fmt(c, f),
34        }
35    }
36}
37
38impl<'db> GenericArg<'db> {
39    pub fn ty(self) -> Option<Ty<'db>> {
40        match self.kind() {
41            GenericArgKind::Type(ty) => Some(ty),
42            _ => None,
43        }
44    }
45
46    pub fn expect_ty(self) -> Ty<'db> {
47        match self.kind() {
48            GenericArgKind::Type(ty) => ty,
49            _ => panic!("Expected ty, got {self:?}"),
50        }
51    }
52
53    pub fn konst(self) -> Option<Const<'db>> {
54        match self.kind() {
55            GenericArgKind::Const(konst) => Some(konst),
56            _ => None,
57        }
58    }
59
60    pub fn region(self) -> Option<Region<'db>> {
61        match self.kind() {
62            GenericArgKind::Lifetime(r) => Some(r),
63            _ => None,
64        }
65    }
66
67    #[inline]
68    pub(crate) fn expect_region(self) -> Region<'db> {
69        match self {
70            GenericArg::Lifetime(region) => region,
71            _ => panic!("expected a region, got {self:?}"),
72        }
73    }
74
75    pub fn error_from_id(interner: DbInterner<'db>, id: GenericParamId) -> GenericArg<'db> {
76        match id {
77            GenericParamId::TypeParamId(_) => Ty::new_error(interner, ErrorGuaranteed).into(),
78            GenericParamId::ConstParamId(_) => Const::error(interner).into(),
79            GenericParamId::LifetimeParamId(_) => Region::error(interner).into(),
80        }
81    }
82
83    #[inline]
84    pub fn walk(self) -> TypeWalker<DbInterner<'db>> {
85        TypeWalker::new(self)
86    }
87}
88
89impl<'db> From<Term<'db>> for GenericArg<'db> {
90    fn from(value: Term<'db>) -> Self {
91        match value {
92            Term::Ty(ty) => GenericArg::Ty(ty),
93            Term::Const(c) => GenericArg::Const(c),
94        }
95    }
96}
97
98#[derive(Copy, Clone, PartialEq, Eq, Hash, TypeVisitable, TypeFoldable)]
99pub enum Term<'db> {
100    Ty(Ty<'db>),
101    Const(Const<'db>),
102}
103
104impl<'db> std::fmt::Debug for Term<'db> {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        match self {
107            Self::Ty(t) => std::fmt::Debug::fmt(t, f),
108            Self::Const(c) => std::fmt::Debug::fmt(c, f),
109        }
110    }
111}
112
113impl<'db> Term<'db> {
114    pub fn expect_type(&self) -> Ty<'db> {
115        self.as_type().expect("expected a type, but found a const")
116    }
117
118    pub fn is_trivially_wf(&self, tcx: DbInterner<'db>) -> bool {
119        match self.kind() {
120            TermKind::Ty(ty) => ty.is_trivially_wf(tcx),
121            TermKind::Const(ct) => ct.is_trivially_wf(),
122        }
123    }
124}
125
126impl<'db> From<Ty<'db>> for GenericArg<'db> {
127    fn from(value: Ty<'db>) -> Self {
128        Self::Ty(value)
129    }
130}
131
132impl<'db> From<Region<'db>> for GenericArg<'db> {
133    fn from(value: Region<'db>) -> Self {
134        Self::Lifetime(value)
135    }
136}
137
138impl<'db> From<Const<'db>> for GenericArg<'db> {
139    fn from(value: Const<'db>) -> Self {
140        Self::Const(value)
141    }
142}
143
144impl<'db> IntoKind for GenericArg<'db> {
145    type Kind = GenericArgKind<DbInterner<'db>>;
146
147    fn kind(self) -> Self::Kind {
148        match self {
149            GenericArg::Ty(ty) => GenericArgKind::Type(ty),
150            GenericArg::Lifetime(region) => GenericArgKind::Lifetime(region),
151            GenericArg::Const(c) => GenericArgKind::Const(c),
152        }
153    }
154}
155
156impl<'db> Relate<DbInterner<'db>> for GenericArg<'db> {
157    fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>(
158        relation: &mut R,
159        a: Self,
160        b: Self,
161    ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> {
162        match (a.kind(), b.kind()) {
163            (GenericArgKind::Lifetime(a_lt), GenericArgKind::Lifetime(b_lt)) => {
164                Ok(relation.relate(a_lt, b_lt)?.into())
165            }
166            (GenericArgKind::Type(a_ty), GenericArgKind::Type(b_ty)) => {
167                Ok(relation.relate(a_ty, b_ty)?.into())
168            }
169            (GenericArgKind::Const(a_ct), GenericArgKind::Const(b_ct)) => {
170                Ok(relation.relate(a_ct, b_ct)?.into())
171            }
172            (GenericArgKind::Lifetime(unpacked), x) => {
173                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
174            }
175            (GenericArgKind::Type(unpacked), x) => {
176                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
177            }
178            (GenericArgKind::Const(unpacked), x) => {
179                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
180            }
181        }
182    }
183}
184
185interned_vec_db!(GenericArgs, GenericArg);
186
187impl<'db> rustc_type_ir::inherent::GenericArg<DbInterner<'db>> for GenericArg<'db> {}
188
189impl<'db> GenericArgs<'db> {
190    /// Creates an `GenericArgs` for generic parameter definitions,
191    /// by calling closures to obtain each kind.
192    /// The closures get to observe the `GenericArgs` as they're
193    /// being built, which can be used to correctly
194    /// replace defaults of generic parameters.
195    pub fn for_item<F>(
196        interner: DbInterner<'db>,
197        def_id: SolverDefId,
198        mut mk_kind: F,
199    ) -> GenericArgs<'db>
200    where
201        F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
202    {
203        let defs = interner.generics_of(def_id);
204        let count = defs.count();
205
206        if count == 0 {
207            return Default::default();
208        }
209
210        let mut args = SmallVec::with_capacity(count);
211        Self::fill_item(&mut args, interner, defs, &mut mk_kind);
212        interner.mk_args(&args)
213    }
214
215    /// Creates an all-error `GenericArgs`.
216    pub fn error_for_item(interner: DbInterner<'db>, def_id: SolverDefId) -> GenericArgs<'db> {
217        GenericArgs::for_item(interner, def_id, |_, id, _| GenericArg::error_from_id(interner, id))
218    }
219
220    /// Like `for_item`, but prefers the default of a parameter if it has any.
221    pub fn for_item_with_defaults<F>(
222        interner: DbInterner<'db>,
223        def_id: GenericDefId,
224        mut fallback: F,
225    ) -> GenericArgs<'db>
226    where
227        F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
228    {
229        let defaults = interner.db.generic_defaults(def_id);
230        Self::for_item(interner, def_id.into(), |idx, id, prev| match defaults.get(idx as usize) {
231            Some(default) => default.instantiate(interner, prev),
232            None => fallback(idx, id, prev),
233        })
234    }
235
236    /// Like `for_item()`, but calls first uses the args from `first`.
237    pub fn fill_rest<F>(
238        interner: DbInterner<'db>,
239        def_id: SolverDefId,
240        first: impl IntoIterator<Item = GenericArg<'db>>,
241        mut fallback: F,
242    ) -> GenericArgs<'db>
243    where
244        F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
245    {
246        let mut iter = first.into_iter();
247        Self::for_item(interner, def_id, |idx, id, prev| {
248            iter.next().unwrap_or_else(|| fallback(idx, id, prev))
249        })
250    }
251
252    /// Appends default param values to `first` if needed. Params without default will call `fallback()`.
253    pub fn fill_with_defaults<F>(
254        interner: DbInterner<'db>,
255        def_id: GenericDefId,
256        first: impl IntoIterator<Item = GenericArg<'db>>,
257        mut fallback: F,
258    ) -> GenericArgs<'db>
259    where
260        F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
261    {
262        let defaults = interner.db.generic_defaults(def_id);
263        Self::fill_rest(interner, def_id.into(), first, |idx, id, prev| {
264            defaults
265                .get(idx as usize)
266                .map(|default| default.instantiate(interner, prev))
267                .unwrap_or_else(|| fallback(idx, id, prev))
268        })
269    }
270
271    fn fill_item<F>(
272        args: &mut SmallVec<[GenericArg<'db>; 8]>,
273        interner: DbInterner<'_>,
274        defs: Generics,
275        mk_kind: &mut F,
276    ) where
277        F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
278    {
279        if let Some(def_id) = defs.parent {
280            let parent_defs = interner.generics_of(def_id.into());
281            Self::fill_item(args, interner, parent_defs, mk_kind);
282        }
283        Self::fill_single(args, &defs, mk_kind);
284    }
285
286    fn fill_single<F>(args: &mut SmallVec<[GenericArg<'db>; 8]>, defs: &Generics, mk_kind: &mut F)
287    where
288        F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
289    {
290        args.reserve(defs.own_params.len());
291        for param in &defs.own_params {
292            let kind = mk_kind(args.len() as u32, param.id, args);
293            args.push(kind);
294        }
295    }
296
297    pub fn closure_sig_untupled(self) -> PolyFnSig<'db> {
298        let TyKind::FnPtr(inputs_and_output, hdr) =
299            self.split_closure_args_untupled().closure_sig_as_fn_ptr_ty.kind()
300        else {
301            unreachable!("not a function pointer")
302        };
303        inputs_and_output.with(hdr)
304    }
305
306    /// A "sensible" `.split_closure_args()`, where the arguments are not in a tuple.
307    pub fn split_closure_args_untupled(self) -> rustc_type_ir::ClosureArgsParts<DbInterner<'db>> {
308        // FIXME: should use `ClosureSubst` when possible
309        match self.inner().as_slice() {
310            [parent_args @ .., closure_kind_ty, sig_ty, tupled_upvars_ty] => {
311                let interner = DbInterner::conjure();
312                rustc_type_ir::ClosureArgsParts {
313                    parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()),
314                    closure_sig_as_fn_ptr_ty: sig_ty.expect_ty(),
315                    closure_kind_ty: closure_kind_ty.expect_ty(),
316                    tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
317                }
318            }
319            _ => {
320                unreachable!("unexpected closure sig");
321            }
322        }
323    }
324
325    pub fn types(self) -> impl Iterator<Item = Ty<'db>> {
326        self.iter().filter_map(|it| it.as_type())
327    }
328
329    pub fn consts(self) -> impl Iterator<Item = Const<'db>> {
330        self.iter().filter_map(|it| it.as_const())
331    }
332
333    pub fn regions(self) -> impl Iterator<Item = Region<'db>> {
334        self.iter().filter_map(|it| it.as_region())
335    }
336}
337
338impl<'db> rustc_type_ir::relate::Relate<DbInterner<'db>> for GenericArgs<'db> {
339    fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>(
340        relation: &mut R,
341        a: Self,
342        b: Self,
343    ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> {
344        let interner = relation.cx();
345        CollectAndApply::collect_and_apply(
346            std::iter::zip(a.iter(), b.iter()).map(|(a, b)| {
347                relation.relate_with_variance(
348                    Variance::Invariant,
349                    VarianceDiagInfo::default(),
350                    a,
351                    b,
352                )
353            }),
354            |g| GenericArgs::new_from_iter(interner, g.iter().cloned()),
355        )
356    }
357}
358
359impl<'db> rustc_type_ir::inherent::GenericArgs<DbInterner<'db>> for GenericArgs<'db> {
360    fn as_closure(self) -> ClosureArgs<DbInterner<'db>> {
361        ClosureArgs { args: self }
362    }
363    fn as_coroutine(self) -> CoroutineArgs<DbInterner<'db>> {
364        CoroutineArgs { args: self }
365    }
366    fn as_coroutine_closure(self) -> CoroutineClosureArgs<DbInterner<'db>> {
367        CoroutineClosureArgs { args: self }
368    }
369    fn rebase_onto(
370        self,
371        interner: DbInterner<'db>,
372        source_def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId,
373        target: <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs,
374    ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs {
375        let defs = interner.generics_of(source_def_id);
376        interner.mk_args_from_iter(target.iter().chain(self.iter().skip(defs.count())))
377    }
378
379    fn identity_for_item(
380        interner: DbInterner<'db>,
381        def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId,
382    ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs {
383        Self::for_item(interner, def_id, |index, kind, _| mk_param(interner, index, kind))
384    }
385
386    fn extend_with_error(
387        interner: DbInterner<'db>,
388        def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId,
389        original_args: &[<DbInterner<'db> as rustc_type_ir::Interner>::GenericArg],
390    ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs {
391        Self::for_item(interner, def_id, |index, kind, _| {
392            if let Some(arg) = original_args.get(index as usize) {
393                *arg
394            } else {
395                error_for_param_kind(kind, interner)
396            }
397        })
398    }
399    fn type_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Ty {
400        self.inner()
401            .get(i)
402            .and_then(|g| g.as_type())
403            .unwrap_or_else(|| Ty::new_error(DbInterner::conjure(), ErrorGuaranteed))
404    }
405
406    fn region_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Region {
407        self.inner()
408            .get(i)
409            .and_then(|g| g.as_region())
410            .unwrap_or_else(|| Region::error(DbInterner::conjure()))
411    }
412
413    fn const_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Const {
414        self.inner()
415            .get(i)
416            .and_then(|g| g.as_const())
417            .unwrap_or_else(|| Const::error(DbInterner::conjure()))
418    }
419
420    fn split_closure_args(self) -> rustc_type_ir::ClosureArgsParts<DbInterner<'db>> {
421        // FIXME: should use `ClosureSubst` when possible
422        match self.inner().as_slice() {
423            [parent_args @ .., closure_kind_ty, sig_ty, tupled_upvars_ty] => {
424                let interner = DbInterner::conjure();
425                // This is stupid, but the next solver expects the first input to actually be a tuple
426                let sig_ty = match sig_ty.expect_ty().kind() {
427                    TyKind::FnPtr(sig_tys, header) => Ty::new(
428                        interner,
429                        TyKind::FnPtr(
430                            sig_tys.map_bound(|s| {
431                                let inputs = Ty::new_tup_from_iter(interner, s.inputs().iter());
432                                let output = s.output();
433                                FnSigTys {
434                                    inputs_and_output: Tys::new_from_iter(
435                                        interner,
436                                        [inputs, output],
437                                    ),
438                                }
439                            }),
440                            header,
441                        ),
442                    ),
443                    _ => unreachable!("sig_ty should be last"),
444                };
445                rustc_type_ir::ClosureArgsParts {
446                    parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()),
447                    closure_sig_as_fn_ptr_ty: sig_ty,
448                    closure_kind_ty: closure_kind_ty.expect_ty(),
449                    tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
450                }
451            }
452            _ => {
453                unreachable!("unexpected closure sig");
454            }
455        }
456    }
457
458    fn split_coroutine_closure_args(
459        self,
460    ) -> rustc_type_ir::CoroutineClosureArgsParts<DbInterner<'db>> {
461        match self.inner().as_slice() {
462            [
463                parent_args @ ..,
464                closure_kind_ty,
465                signature_parts_ty,
466                tupled_upvars_ty,
467                coroutine_captures_by_ref_ty,
468            ] => rustc_type_ir::CoroutineClosureArgsParts {
469                parent_args: GenericArgs::new_from_iter(
470                    DbInterner::conjure(),
471                    parent_args.iter().cloned(),
472                ),
473                closure_kind_ty: closure_kind_ty.expect_ty(),
474                signature_parts_ty: signature_parts_ty.expect_ty(),
475                tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
476                coroutine_captures_by_ref_ty: coroutine_captures_by_ref_ty.expect_ty(),
477            },
478            _ => panic!("GenericArgs were likely not for a CoroutineClosure."),
479        }
480    }
481
482    fn split_coroutine_args(self) -> rustc_type_ir::CoroutineArgsParts<DbInterner<'db>> {
483        let interner = DbInterner::conjure();
484        match self.inner().as_slice() {
485            [parent_args @ .., kind_ty, resume_ty, yield_ty, return_ty, tupled_upvars_ty] => {
486                rustc_type_ir::CoroutineArgsParts {
487                    parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()),
488                    kind_ty: kind_ty.expect_ty(),
489                    resume_ty: resume_ty.expect_ty(),
490                    yield_ty: yield_ty.expect_ty(),
491                    return_ty: return_ty.expect_ty(),
492                    tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
493                }
494            }
495            _ => panic!("GenericArgs were likely not for a Coroutine."),
496        }
497    }
498}
499
500pub fn mk_param<'db>(interner: DbInterner<'db>, index: u32, id: GenericParamId) -> GenericArg<'db> {
501    match id {
502        GenericParamId::LifetimeParamId(id) => {
503            Region::new_early_param(interner, EarlyParamRegion { index, id }).into()
504        }
505        GenericParamId::TypeParamId(id) => Ty::new_param(interner, id, index).into(),
506        GenericParamId::ConstParamId(id) => {
507            Const::new_param(interner, ParamConst { index, id }).into()
508        }
509    }
510}
511
512pub fn error_for_param_kind<'db>(id: GenericParamId, interner: DbInterner<'db>) -> GenericArg<'db> {
513    match id {
514        GenericParamId::LifetimeParamId(_) => Region::error(interner).into(),
515        GenericParamId::TypeParamId(_) => Ty::new_error(interner, ErrorGuaranteed).into(),
516        GenericParamId::ConstParamId(_) => Const::error(interner).into(),
517    }
518}
519
520impl<'db> IntoKind for Term<'db> {
521    type Kind = TermKind<DbInterner<'db>>;
522
523    fn kind(self) -> Self::Kind {
524        match self {
525            Term::Ty(ty) => TermKind::Ty(ty),
526            Term::Const(c) => TermKind::Const(c),
527        }
528    }
529}
530
531impl<'db> From<Ty<'db>> for Term<'db> {
532    fn from(value: Ty<'db>) -> Self {
533        Self::Ty(value)
534    }
535}
536
537impl<'db> From<Const<'db>> for Term<'db> {
538    fn from(value: Const<'db>) -> Self {
539        Self::Const(value)
540    }
541}
542
543impl<'db> Relate<DbInterner<'db>> for Term<'db> {
544    fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>(
545        relation: &mut R,
546        a: Self,
547        b: Self,
548    ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> {
549        match (a.kind(), b.kind()) {
550            (TermKind::Ty(a_ty), TermKind::Ty(b_ty)) => Ok(relation.relate(a_ty, b_ty)?.into()),
551            (TermKind::Const(a_ct), TermKind::Const(b_ct)) => {
552                Ok(relation.relate(a_ct, b_ct)?.into())
553            }
554            (TermKind::Ty(unpacked), x) => {
555                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
556            }
557            (TermKind::Const(unpacked), x) => {
558                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
559            }
560        }
561    }
562}
563
564impl<'db> rustc_type_ir::inherent::Term<DbInterner<'db>> for Term<'db> {}
565
566#[derive(Clone, Eq, PartialEq, Debug)]
567pub enum TermVid {
568    Ty(TyVid),
569    Const(ConstVid),
570}
571
572impl From<TyVid> for TermVid {
573    fn from(value: TyVid) -> Self {
574        TermVid::Ty(value)
575    }
576}
577
578impl From<ConstVid> for TermVid {
579    fn from(value: ConstVid) -> Self {
580        TermVid::Const(value)
581    }
582}
583
584impl<'db> DbInterner<'db> {
585    pub(super) fn mk_args(self, args: &[GenericArg<'db>]) -> GenericArgs<'db> {
586        GenericArgs::new_from_iter(self, args.iter().cloned())
587    }
588
589    pub(super) fn mk_args_from_iter<I, T>(self, iter: I) -> T::Output
590    where
591        I: Iterator<Item = T>,
592        T: rustc_type_ir::CollectAndApply<GenericArg<'db>, GenericArgs<'db>>,
593    {
594        T::collect_and_apply(iter, |xs| self.mk_args(xs))
595    }
596}