// hir_ty/next_solver/generic_arg.rs

1//! Things related to generic args in the next-trait-solver.
2
3use hir_def::GenericParamId;
4use intern::{Interned, Symbol};
5use rustc_type_ir::{
6    ClosureArgs, CollectAndApply, ConstVid, CoroutineArgs, CoroutineClosureArgs, FnSig, FnSigTys,
7    GenericArgKind, IntTy, Interner, TermKind, TyKind, TyVid, TypeFoldable, TypeVisitable,
8    Variance,
9    inherent::{GenericArg as _, GenericsOf, IntoKind, SliceLike, Term as _, Ty as _},
10    relate::{Relate, VarianceDiagInfo},
11};
12use smallvec::SmallVec;
13
14use crate::db::HirDatabase;
15
16use super::{
17    Const, DbInterner, EarlyParamRegion, ErrorGuaranteed, ParamConst, Region, SolverDefId, Ty, Tys,
18    generics::{GenericParamDef, Generics},
19    interned_vec_db,
20};
21
22#[derive(Copy, Clone, PartialEq, Eq, Hash)]
23pub enum GenericArg<'db> {
24    Ty(Ty<'db>),
25    Lifetime(Region<'db>),
26    Const(Const<'db>),
27}
28
29impl<'db> std::fmt::Debug for GenericArg<'db> {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            Self::Ty(t) => std::fmt::Debug::fmt(t, f),
33            Self::Lifetime(r) => std::fmt::Debug::fmt(r, f),
34            Self::Const(c) => std::fmt::Debug::fmt(c, f),
35        }
36    }
37}
38
39impl<'db> GenericArg<'db> {
40    pub fn ty(self) -> Option<Ty<'db>> {
41        match self.kind() {
42            GenericArgKind::Type(ty) => Some(ty),
43            _ => None,
44        }
45    }
46
47    pub fn expect_ty(self) -> Ty<'db> {
48        match self.kind() {
49            GenericArgKind::Type(ty) => ty,
50            _ => panic!("Expected ty, got {self:?}"),
51        }
52    }
53
54    pub fn region(self) -> Option<Region<'db>> {
55        match self.kind() {
56            GenericArgKind::Lifetime(r) => Some(r),
57            _ => None,
58        }
59    }
60}
61
62impl<'db> From<Term<'db>> for GenericArg<'db> {
63    fn from(value: Term<'db>) -> Self {
64        match value {
65            Term::Ty(ty) => GenericArg::Ty(ty),
66            Term::Const(c) => GenericArg::Const(c),
67        }
68    }
69}
70
71#[derive(Copy, Clone, PartialEq, Eq, Hash)]
72pub enum Term<'db> {
73    Ty(Ty<'db>),
74    Const(Const<'db>),
75}
76
77impl<'db> std::fmt::Debug for Term<'db> {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        match self {
80            Self::Ty(t) => std::fmt::Debug::fmt(t, f),
81            Self::Const(c) => std::fmt::Debug::fmt(c, f),
82        }
83    }
84}
85
86impl<'db> Term<'db> {
87    pub fn expect_type(&self) -> Ty<'db> {
88        self.as_type().expect("expected a type, but found a const")
89    }
90
91    pub fn is_trivially_wf(&self, tcx: DbInterner<'db>) -> bool {
92        match self.kind() {
93            TermKind::Ty(ty) => ty.is_trivially_wf(tcx),
94            TermKind::Const(ct) => ct.is_trivially_wf(),
95        }
96    }
97}
98
99impl<'db> From<Ty<'db>> for GenericArg<'db> {
100    fn from(value: Ty<'db>) -> Self {
101        Self::Ty(value)
102    }
103}
104
105impl<'db> From<Region<'db>> for GenericArg<'db> {
106    fn from(value: Region<'db>) -> Self {
107        Self::Lifetime(value)
108    }
109}
110
111impl<'db> From<Const<'db>> for GenericArg<'db> {
112    fn from(value: Const<'db>) -> Self {
113        Self::Const(value)
114    }
115}
116
117impl<'db> IntoKind for GenericArg<'db> {
118    type Kind = GenericArgKind<DbInterner<'db>>;
119
120    fn kind(self) -> Self::Kind {
121        match self {
122            GenericArg::Ty(ty) => GenericArgKind::Type(ty),
123            GenericArg::Lifetime(region) => GenericArgKind::Lifetime(region),
124            GenericArg::Const(c) => GenericArgKind::Const(c),
125        }
126    }
127}
128
129impl<'db> TypeVisitable<DbInterner<'db>> for GenericArg<'db> {
130    fn visit_with<V: rustc_type_ir::TypeVisitor<DbInterner<'db>>>(
131        &self,
132        visitor: &mut V,
133    ) -> V::Result {
134        match self {
135            GenericArg::Lifetime(lt) => lt.visit_with(visitor),
136            GenericArg::Ty(ty) => ty.visit_with(visitor),
137            GenericArg::Const(ct) => ct.visit_with(visitor),
138        }
139    }
140}
141
142impl<'db> TypeFoldable<DbInterner<'db>> for GenericArg<'db> {
143    fn try_fold_with<F: rustc_type_ir::FallibleTypeFolder<DbInterner<'db>>>(
144        self,
145        folder: &mut F,
146    ) -> Result<Self, F::Error> {
147        match self.kind() {
148            GenericArgKind::Lifetime(lt) => lt.try_fold_with(folder).map(Into::into),
149            GenericArgKind::Type(ty) => ty.try_fold_with(folder).map(Into::into),
150            GenericArgKind::Const(ct) => ct.try_fold_with(folder).map(Into::into),
151        }
152    }
153    fn fold_with<F: rustc_type_ir::TypeFolder<DbInterner<'db>>>(self, folder: &mut F) -> Self {
154        match self.kind() {
155            GenericArgKind::Lifetime(lt) => lt.fold_with(folder).into(),
156            GenericArgKind::Type(ty) => ty.fold_with(folder).into(),
157            GenericArgKind::Const(ct) => ct.fold_with(folder).into(),
158        }
159    }
160}
161
162impl<'db> Relate<DbInterner<'db>> for GenericArg<'db> {
163    fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>(
164        relation: &mut R,
165        a: Self,
166        b: Self,
167    ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> {
168        match (a.kind(), b.kind()) {
169            (GenericArgKind::Lifetime(a_lt), GenericArgKind::Lifetime(b_lt)) => {
170                Ok(relation.relate(a_lt, b_lt)?.into())
171            }
172            (GenericArgKind::Type(a_ty), GenericArgKind::Type(b_ty)) => {
173                Ok(relation.relate(a_ty, b_ty)?.into())
174            }
175            (GenericArgKind::Const(a_ct), GenericArgKind::Const(b_ct)) => {
176                Ok(relation.relate(a_ct, b_ct)?.into())
177            }
178            (GenericArgKind::Lifetime(unpacked), x) => {
179                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
180            }
181            (GenericArgKind::Type(unpacked), x) => {
182                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
183            }
184            (GenericArgKind::Const(unpacked), x) => {
185                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
186            }
187        }
188    }
189}
190
191interned_vec_db!(GenericArgs, GenericArg);
192
193impl<'db> rustc_type_ir::inherent::GenericArg<DbInterner<'db>> for GenericArg<'db> {}
194
195impl<'db> GenericArgs<'db> {
196    /// Creates an `GenericArgs` for generic parameter definitions,
197    /// by calling closures to obtain each kind.
198    /// The closures get to observe the `GenericArgs` as they're
199    /// being built, which can be used to correctly
200    /// replace defaults of generic parameters.
201    pub fn for_item<F>(
202        interner: DbInterner<'db>,
203        def_id: SolverDefId,
204        mut mk_kind: F,
205    ) -> GenericArgs<'db>
206    where
207        F: FnMut(&Symbol, u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
208    {
209        let defs = interner.generics_of(def_id);
210        let count = defs.count();
211        let mut args = SmallVec::with_capacity(count);
212        Self::fill_item(&mut args, interner, defs, &mut mk_kind);
213        interner.mk_args(&args)
214    }
215
216    fn fill_item<F>(
217        args: &mut SmallVec<[GenericArg<'db>; 8]>,
218        interner: DbInterner<'_>,
219        defs: Generics,
220        mk_kind: &mut F,
221    ) where
222        F: FnMut(&Symbol, u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
223    {
224        let self_len = defs.own_params.len() as u32;
225        if let Some(def_id) = defs.parent {
226            let parent_defs = interner.generics_of(def_id.into());
227            Self::fill_item(args, interner, parent_defs, mk_kind);
228        }
229        Self::fill_single(args, &defs, mk_kind);
230    }
231
232    fn fill_single<F>(args: &mut SmallVec<[GenericArg<'db>; 8]>, defs: &Generics, mk_kind: &mut F)
233    where
234        F: FnMut(&Symbol, u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
235    {
236        let start_len = args.len();
237        args.reserve(defs.own_params.len());
238        for param in &defs.own_params {
239            let kind = mk_kind(&param.name, args.len() as u32, param.id, args);
240            args.push(kind);
241        }
242    }
243}
244
245impl<'db> rustc_type_ir::relate::Relate<DbInterner<'db>> for GenericArgs<'db> {
246    fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>(
247        relation: &mut R,
248        a: Self,
249        b: Self,
250    ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> {
251        let interner = relation.cx();
252        CollectAndApply::collect_and_apply(
253            std::iter::zip(a.iter(), b.iter()).map(|(a, b)| {
254                relation.relate_with_variance(
255                    Variance::Invariant,
256                    VarianceDiagInfo::default(),
257                    a,
258                    b,
259                )
260            }),
261            |g| GenericArgs::new_from_iter(interner, g.iter().cloned()),
262        )
263    }
264}
265
266impl<'db> rustc_type_ir::inherent::GenericArgs<DbInterner<'db>> for GenericArgs<'db> {
267    fn as_closure(self) -> ClosureArgs<DbInterner<'db>> {
268        ClosureArgs { args: self }
269    }
270    fn as_coroutine(self) -> CoroutineArgs<DbInterner<'db>> {
271        CoroutineArgs { args: self }
272    }
273    fn as_coroutine_closure(self) -> CoroutineClosureArgs<DbInterner<'db>> {
274        CoroutineClosureArgs { args: self }
275    }
276    fn rebase_onto(
277        self,
278        interner: DbInterner<'db>,
279        source_def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId,
280        target: <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs,
281    ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs {
282        let defs = interner.generics_of(source_def_id);
283        interner.mk_args_from_iter(target.iter().chain(self.iter().skip(defs.count())))
284    }
285
286    fn identity_for_item(
287        interner: DbInterner<'db>,
288        def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId,
289    ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs {
290        Self::for_item(interner, def_id, |name, index, kind, _| {
291            mk_param(interner, index, name, kind)
292        })
293    }
294
295    fn extend_with_error(
296        interner: DbInterner<'db>,
297        def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId,
298        original_args: &[<DbInterner<'db> as rustc_type_ir::Interner>::GenericArg],
299    ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs {
300        Self::for_item(interner, def_id, |name, index, kind, _| {
301            if let Some(arg) = original_args.get(index as usize) {
302                *arg
303            } else {
304                error_for_param_kind(kind, interner)
305            }
306        })
307    }
308    fn type_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Ty {
309        self.inner()
310            .get(i)
311            .and_then(|g| g.as_type())
312            .unwrap_or_else(|| Ty::new_error(DbInterner::conjure(), ErrorGuaranteed))
313    }
314
315    fn region_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Region {
316        self.inner()
317            .get(i)
318            .and_then(|g| g.as_region())
319            .unwrap_or_else(|| Region::error(DbInterner::conjure()))
320    }
321
322    fn const_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Const {
323        self.inner()
324            .get(i)
325            .and_then(|g| g.as_const())
326            .unwrap_or_else(|| Const::error(DbInterner::conjure()))
327    }
328
329    fn split_closure_args(self) -> rustc_type_ir::ClosureArgsParts<DbInterner<'db>> {
330        // FIXME: should use `ClosureSubst` when possible
331        match self.inner().as_slice() {
332            [parent_args @ .., sig_ty] => {
333                let interner = DbInterner::conjure();
334                // This is stupid, but the next solver expects the first input to actually be a tuple
335                let sig_ty = match sig_ty.expect_ty().kind() {
336                    TyKind::FnPtr(sig_tys, header) => Ty::new(
337                        interner,
338                        TyKind::FnPtr(
339                            sig_tys.map_bound(|s| {
340                                let inputs = Ty::new_tup_from_iter(interner, s.inputs().iter());
341                                let output = s.output();
342                                FnSigTys {
343                                    inputs_and_output: Tys::new_from_iter(
344                                        interner,
345                                        [inputs, output],
346                                    ),
347                                }
348                            }),
349                            header,
350                        ),
351                    ),
352                    _ => unreachable!("sig_ty should be last"),
353                };
354                rustc_type_ir::ClosureArgsParts {
355                    parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()),
356                    closure_sig_as_fn_ptr_ty: sig_ty,
357                    closure_kind_ty: Ty::new(interner, TyKind::Int(IntTy::I8)),
358                    tupled_upvars_ty: Ty::new_unit(interner),
359                }
360            }
361            _ => {
362                unreachable!("unexpected closure sig");
363            }
364        }
365    }
366
367    fn split_coroutine_closure_args(
368        self,
369    ) -> rustc_type_ir::CoroutineClosureArgsParts<DbInterner<'db>> {
370        match self.inner().as_slice() {
371            [
372                parent_args @ ..,
373                closure_kind_ty,
374                signature_parts_ty,
375                tupled_upvars_ty,
376                coroutine_captures_by_ref_ty,
377                coroutine_witness_ty,
378            ] => rustc_type_ir::CoroutineClosureArgsParts {
379                parent_args: GenericArgs::new_from_iter(
380                    DbInterner::conjure(),
381                    parent_args.iter().cloned(),
382                ),
383                closure_kind_ty: closure_kind_ty.expect_ty(),
384                signature_parts_ty: signature_parts_ty.expect_ty(),
385                tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
386                coroutine_captures_by_ref_ty: coroutine_captures_by_ref_ty.expect_ty(),
387                coroutine_witness_ty: coroutine_witness_ty.expect_ty(),
388            },
389            _ => panic!("GenericArgs were likely not for a CoroutineClosure."),
390        }
391    }
392
393    fn split_coroutine_args(self) -> rustc_type_ir::CoroutineArgsParts<DbInterner<'db>> {
394        let interner = DbInterner::conjure();
395        match self.inner().as_slice() {
396            [parent_args @ .., resume_ty, yield_ty, return_ty] => {
397                rustc_type_ir::CoroutineArgsParts {
398                    parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()),
399                    kind_ty: Ty::new_unit(interner),
400                    resume_ty: resume_ty.expect_ty(),
401                    yield_ty: yield_ty.expect_ty(),
402                    return_ty: return_ty.expect_ty(),
403                    witness: Ty::new_unit(interner),
404                    tupled_upvars_ty: Ty::new_unit(interner),
405                }
406            }
407            _ => panic!("GenericArgs were likely not for a Coroutine."),
408        }
409    }
410}
411
412pub fn mk_param<'db>(
413    interner: DbInterner<'db>,
414    index: u32,
415    name: &Symbol,
416    id: GenericParamId,
417) -> GenericArg<'db> {
418    let name = name.clone();
419    match id {
420        GenericParamId::LifetimeParamId(id) => {
421            Region::new_early_param(interner, EarlyParamRegion { index, id }).into()
422        }
423        GenericParamId::TypeParamId(id) => Ty::new_param(interner, id, index, name).into(),
424        GenericParamId::ConstParamId(id) => {
425            Const::new_param(interner, ParamConst { index, id }).into()
426        }
427    }
428}
429
430pub fn error_for_param_kind<'db>(id: GenericParamId, interner: DbInterner<'db>) -> GenericArg<'db> {
431    match id {
432        GenericParamId::LifetimeParamId(_) => Region::error(interner).into(),
433        GenericParamId::TypeParamId(_) => Ty::new_error(interner, ErrorGuaranteed).into(),
434        GenericParamId::ConstParamId(_) => Const::error(interner).into(),
435    }
436}
437
438impl<'db> IntoKind for Term<'db> {
439    type Kind = TermKind<DbInterner<'db>>;
440
441    fn kind(self) -> Self::Kind {
442        match self {
443            Term::Ty(ty) => TermKind::Ty(ty),
444            Term::Const(c) => TermKind::Const(c),
445        }
446    }
447}
448
449impl<'db> From<Ty<'db>> for Term<'db> {
450    fn from(value: Ty<'db>) -> Self {
451        Self::Ty(value)
452    }
453}
454
455impl<'db> From<Const<'db>> for Term<'db> {
456    fn from(value: Const<'db>) -> Self {
457        Self::Const(value)
458    }
459}
460
461impl<'db> TypeVisitable<DbInterner<'db>> for Term<'db> {
462    fn visit_with<V: rustc_type_ir::TypeVisitor<DbInterner<'db>>>(
463        &self,
464        visitor: &mut V,
465    ) -> V::Result {
466        match self {
467            Term::Ty(ty) => ty.visit_with(visitor),
468            Term::Const(ct) => ct.visit_with(visitor),
469        }
470    }
471}
472
473impl<'db> TypeFoldable<DbInterner<'db>> for Term<'db> {
474    fn try_fold_with<F: rustc_type_ir::FallibleTypeFolder<DbInterner<'db>>>(
475        self,
476        folder: &mut F,
477    ) -> Result<Self, F::Error> {
478        match self.kind() {
479            TermKind::Ty(ty) => ty.try_fold_with(folder).map(Into::into),
480            TermKind::Const(ct) => ct.try_fold_with(folder).map(Into::into),
481        }
482    }
483    fn fold_with<F: rustc_type_ir::TypeFolder<DbInterner<'db>>>(self, folder: &mut F) -> Self {
484        match self.kind() {
485            TermKind::Ty(ty) => ty.fold_with(folder).into(),
486            TermKind::Const(ct) => ct.fold_with(folder).into(),
487        }
488    }
489}
490
491impl<'db> Relate<DbInterner<'db>> for Term<'db> {
492    fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>(
493        relation: &mut R,
494        a: Self,
495        b: Self,
496    ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> {
497        match (a.kind(), b.kind()) {
498            (TermKind::Ty(a_ty), TermKind::Ty(b_ty)) => Ok(relation.relate(a_ty, b_ty)?.into()),
499            (TermKind::Const(a_ct), TermKind::Const(b_ct)) => {
500                Ok(relation.relate(a_ct, b_ct)?.into())
501            }
502            (TermKind::Ty(unpacked), x) => {
503                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
504            }
505            (TermKind::Const(unpacked), x) => {
506                unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
507            }
508        }
509    }
510}
511
512impl<'db> rustc_type_ir::inherent::Term<DbInterner<'db>> for Term<'db> {}
513
514#[derive(Clone, Eq, PartialEq, Debug)]
515pub enum TermVid {
516    Ty(TyVid),
517    Const(ConstVid),
518}
519
520impl From<TyVid> for TermVid {
521    fn from(value: TyVid) -> Self {
522        TermVid::Ty(value)
523    }
524}
525
526impl From<ConstVid> for TermVid {
527    fn from(value: ConstVid) -> Self {
528        TermVid::Const(value)
529    }
530}
531
532impl<'db> DbInterner<'db> {
533    pub(super) fn mk_args(self, args: &[GenericArg<'db>]) -> GenericArgs<'db> {
534        GenericArgs::new_from_iter(self, args.iter().cloned())
535    }
536
537    pub(super) fn mk_args_from_iter<I, T>(self, iter: I) -> T::Output
538    where
539        I: Iterator<Item = T>,
540        T: rustc_type_ir::CollectAndApply<GenericArg<'db>, GenericArgs<'db>>,
541    {
542        T::collect_and_apply(iter, |xs| self.mk_args(xs))
543    }
544
545    pub(super) fn check_args_compatible(self, def_id: SolverDefId, args: GenericArgs<'db>) -> bool {
546        // TODO
547        true
548    }
549
550    pub(super) fn debug_assert_args_compatible(self, def_id: SolverDefId, args: GenericArgs<'db>) {
551        // TODO
552    }
553}