1use hir_def::{GenericDefId, GenericParamId};
4use macros::{TypeFoldable, TypeVisitable};
5use rustc_type_ir::{
6 ClosureArgs, CollectAndApply, ConstVid, CoroutineArgs, CoroutineClosureArgs, FnSigTys,
7 GenericArgKind, Interner, TermKind, TyKind, TyVid, Variance,
8 inherent::{GenericArg as _, GenericsOf, IntoKind, SliceLike, Term as _, Ty as _},
9 relate::{Relate, VarianceDiagInfo},
10 walk::TypeWalker,
11};
12use smallvec::SmallVec;
13
14use crate::next_solver::{PolyFnSig, interned_vec_db};
15
16use super::{
17 Const, DbInterner, EarlyParamRegion, ErrorGuaranteed, ParamConst, Region, SolverDefId, Ty, Tys,
18 generics::Generics,
19};
20
21#[derive(Copy, Clone, PartialEq, Eq, Hash, TypeVisitable, TypeFoldable, salsa::Supertype)]
22pub enum GenericArg<'db> {
23 Ty(Ty<'db>),
24 Lifetime(Region<'db>),
25 Const(Const<'db>),
26}
27
28impl<'db> std::fmt::Debug for GenericArg<'db> {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 match self {
31 Self::Ty(t) => std::fmt::Debug::fmt(t, f),
32 Self::Lifetime(r) => std::fmt::Debug::fmt(r, f),
33 Self::Const(c) => std::fmt::Debug::fmt(c, f),
34 }
35 }
36}
37
38impl<'db> GenericArg<'db> {
39 pub fn ty(self) -> Option<Ty<'db>> {
40 match self.kind() {
41 GenericArgKind::Type(ty) => Some(ty),
42 _ => None,
43 }
44 }
45
46 pub fn expect_ty(self) -> Ty<'db> {
47 match self.kind() {
48 GenericArgKind::Type(ty) => ty,
49 _ => panic!("Expected ty, got {self:?}"),
50 }
51 }
52
53 pub fn konst(self) -> Option<Const<'db>> {
54 match self.kind() {
55 GenericArgKind::Const(konst) => Some(konst),
56 _ => None,
57 }
58 }
59
60 pub fn region(self) -> Option<Region<'db>> {
61 match self.kind() {
62 GenericArgKind::Lifetime(r) => Some(r),
63 _ => None,
64 }
65 }
66
67 #[inline]
68 pub(crate) fn expect_region(self) -> Region<'db> {
69 match self {
70 GenericArg::Lifetime(region) => region,
71 _ => panic!("expected a region, got {self:?}"),
72 }
73 }
74
75 pub fn error_from_id(interner: DbInterner<'db>, id: GenericParamId) -> GenericArg<'db> {
76 match id {
77 GenericParamId::TypeParamId(_) => Ty::new_error(interner, ErrorGuaranteed).into(),
78 GenericParamId::ConstParamId(_) => Const::error(interner).into(),
79 GenericParamId::LifetimeParamId(_) => Region::error(interner).into(),
80 }
81 }
82
83 #[inline]
84 pub fn walk(self) -> TypeWalker<DbInterner<'db>> {
85 TypeWalker::new(self)
86 }
87}
88
89impl<'db> From<Term<'db>> for GenericArg<'db> {
90 fn from(value: Term<'db>) -> Self {
91 match value {
92 Term::Ty(ty) => GenericArg::Ty(ty),
93 Term::Const(c) => GenericArg::Const(c),
94 }
95 }
96}
97
98#[derive(Copy, Clone, PartialEq, Eq, Hash, TypeVisitable, TypeFoldable)]
99pub enum Term<'db> {
100 Ty(Ty<'db>),
101 Const(Const<'db>),
102}
103
104impl<'db> std::fmt::Debug for Term<'db> {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 match self {
107 Self::Ty(t) => std::fmt::Debug::fmt(t, f),
108 Self::Const(c) => std::fmt::Debug::fmt(c, f),
109 }
110 }
111}
112
113impl<'db> Term<'db> {
114 pub fn expect_type(&self) -> Ty<'db> {
115 self.as_type().expect("expected a type, but found a const")
116 }
117
118 pub fn is_trivially_wf(&self, tcx: DbInterner<'db>) -> bool {
119 match self.kind() {
120 TermKind::Ty(ty) => ty.is_trivially_wf(tcx),
121 TermKind::Const(ct) => ct.is_trivially_wf(),
122 }
123 }
124}
125
126impl<'db> From<Ty<'db>> for GenericArg<'db> {
127 fn from(value: Ty<'db>) -> Self {
128 Self::Ty(value)
129 }
130}
131
132impl<'db> From<Region<'db>> for GenericArg<'db> {
133 fn from(value: Region<'db>) -> Self {
134 Self::Lifetime(value)
135 }
136}
137
138impl<'db> From<Const<'db>> for GenericArg<'db> {
139 fn from(value: Const<'db>) -> Self {
140 Self::Const(value)
141 }
142}
143
144impl<'db> IntoKind for GenericArg<'db> {
145 type Kind = GenericArgKind<DbInterner<'db>>;
146
147 fn kind(self) -> Self::Kind {
148 match self {
149 GenericArg::Ty(ty) => GenericArgKind::Type(ty),
150 GenericArg::Lifetime(region) => GenericArgKind::Lifetime(region),
151 GenericArg::Const(c) => GenericArgKind::Const(c),
152 }
153 }
154}
155
156impl<'db> Relate<DbInterner<'db>> for GenericArg<'db> {
157 fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>(
158 relation: &mut R,
159 a: Self,
160 b: Self,
161 ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> {
162 match (a.kind(), b.kind()) {
163 (GenericArgKind::Lifetime(a_lt), GenericArgKind::Lifetime(b_lt)) => {
164 Ok(relation.relate(a_lt, b_lt)?.into())
165 }
166 (GenericArgKind::Type(a_ty), GenericArgKind::Type(b_ty)) => {
167 Ok(relation.relate(a_ty, b_ty)?.into())
168 }
169 (GenericArgKind::Const(a_ct), GenericArgKind::Const(b_ct)) => {
170 Ok(relation.relate(a_ct, b_ct)?.into())
171 }
172 (GenericArgKind::Lifetime(unpacked), x) => {
173 unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
174 }
175 (GenericArgKind::Type(unpacked), x) => {
176 unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
177 }
178 (GenericArgKind::Const(unpacked), x) => {
179 unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
180 }
181 }
182 }
183}
184
185interned_vec_db!(GenericArgs, GenericArg);
186
187impl<'db> rustc_type_ir::inherent::GenericArg<DbInterner<'db>> for GenericArg<'db> {}
188
189impl<'db> GenericArgs<'db> {
190 pub fn for_item<F>(
196 interner: DbInterner<'db>,
197 def_id: SolverDefId,
198 mut mk_kind: F,
199 ) -> GenericArgs<'db>
200 where
201 F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
202 {
203 let defs = interner.generics_of(def_id);
204 let count = defs.count();
205
206 if count == 0 {
207 return Default::default();
208 }
209
210 let mut args = SmallVec::with_capacity(count);
211 Self::fill_item(&mut args, interner, defs, &mut mk_kind);
212 interner.mk_args(&args)
213 }
214
215 pub fn error_for_item(interner: DbInterner<'db>, def_id: SolverDefId) -> GenericArgs<'db> {
217 GenericArgs::for_item(interner, def_id, |_, id, _| GenericArg::error_from_id(interner, id))
218 }
219
220 pub fn for_item_with_defaults<F>(
222 interner: DbInterner<'db>,
223 def_id: GenericDefId,
224 mut fallback: F,
225 ) -> GenericArgs<'db>
226 where
227 F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
228 {
229 let defaults = interner.db.generic_defaults(def_id);
230 Self::for_item(interner, def_id.into(), |idx, id, prev| match defaults.get(idx as usize) {
231 Some(default) => default.instantiate(interner, prev),
232 None => fallback(idx, id, prev),
233 })
234 }
235
236 pub fn fill_rest<F>(
238 interner: DbInterner<'db>,
239 def_id: SolverDefId,
240 first: impl IntoIterator<Item = GenericArg<'db>>,
241 mut fallback: F,
242 ) -> GenericArgs<'db>
243 where
244 F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
245 {
246 let mut iter = first.into_iter();
247 Self::for_item(interner, def_id, |idx, id, prev| {
248 iter.next().unwrap_or_else(|| fallback(idx, id, prev))
249 })
250 }
251
252 pub fn fill_with_defaults<F>(
254 interner: DbInterner<'db>,
255 def_id: GenericDefId,
256 first: impl IntoIterator<Item = GenericArg<'db>>,
257 mut fallback: F,
258 ) -> GenericArgs<'db>
259 where
260 F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
261 {
262 let defaults = interner.db.generic_defaults(def_id);
263 Self::fill_rest(interner, def_id.into(), first, |idx, id, prev| {
264 defaults
265 .get(idx as usize)
266 .map(|default| default.instantiate(interner, prev))
267 .unwrap_or_else(|| fallback(idx, id, prev))
268 })
269 }
270
271 fn fill_item<F>(
272 args: &mut SmallVec<[GenericArg<'db>; 8]>,
273 interner: DbInterner<'_>,
274 defs: Generics,
275 mk_kind: &mut F,
276 ) where
277 F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
278 {
279 if let Some(def_id) = defs.parent {
280 let parent_defs = interner.generics_of(def_id.into());
281 Self::fill_item(args, interner, parent_defs, mk_kind);
282 }
283 Self::fill_single(args, &defs, mk_kind);
284 }
285
286 fn fill_single<F>(args: &mut SmallVec<[GenericArg<'db>; 8]>, defs: &Generics, mk_kind: &mut F)
287 where
288 F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>,
289 {
290 args.reserve(defs.own_params.len());
291 for param in &defs.own_params {
292 let kind = mk_kind(args.len() as u32, param.id, args);
293 args.push(kind);
294 }
295 }
296
297 pub fn closure_sig_untupled(self) -> PolyFnSig<'db> {
298 let TyKind::FnPtr(inputs_and_output, hdr) =
299 self.split_closure_args_untupled().closure_sig_as_fn_ptr_ty.kind()
300 else {
301 unreachable!("not a function pointer")
302 };
303 inputs_and_output.with(hdr)
304 }
305
306 pub fn split_closure_args_untupled(self) -> rustc_type_ir::ClosureArgsParts<DbInterner<'db>> {
308 match self.inner().as_slice() {
310 [parent_args @ .., closure_kind_ty, sig_ty, tupled_upvars_ty] => {
311 let interner = DbInterner::conjure();
312 rustc_type_ir::ClosureArgsParts {
313 parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()),
314 closure_sig_as_fn_ptr_ty: sig_ty.expect_ty(),
315 closure_kind_ty: closure_kind_ty.expect_ty(),
316 tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
317 }
318 }
319 _ => {
320 unreachable!("unexpected closure sig");
321 }
322 }
323 }
324
325 pub fn types(self) -> impl Iterator<Item = Ty<'db>> {
326 self.iter().filter_map(|it| it.as_type())
327 }
328
329 pub fn consts(self) -> impl Iterator<Item = Const<'db>> {
330 self.iter().filter_map(|it| it.as_const())
331 }
332
333 pub fn regions(self) -> impl Iterator<Item = Region<'db>> {
334 self.iter().filter_map(|it| it.as_region())
335 }
336}
337
338impl<'db> rustc_type_ir::relate::Relate<DbInterner<'db>> for GenericArgs<'db> {
339 fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>(
340 relation: &mut R,
341 a: Self,
342 b: Self,
343 ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> {
344 let interner = relation.cx();
345 CollectAndApply::collect_and_apply(
346 std::iter::zip(a.iter(), b.iter()).map(|(a, b)| {
347 relation.relate_with_variance(
348 Variance::Invariant,
349 VarianceDiagInfo::default(),
350 a,
351 b,
352 )
353 }),
354 |g| GenericArgs::new_from_iter(interner, g.iter().cloned()),
355 )
356 }
357}
358
359impl<'db> rustc_type_ir::inherent::GenericArgs<DbInterner<'db>> for GenericArgs<'db> {
360 fn as_closure(self) -> ClosureArgs<DbInterner<'db>> {
361 ClosureArgs { args: self }
362 }
363 fn as_coroutine(self) -> CoroutineArgs<DbInterner<'db>> {
364 CoroutineArgs { args: self }
365 }
366 fn as_coroutine_closure(self) -> CoroutineClosureArgs<DbInterner<'db>> {
367 CoroutineClosureArgs { args: self }
368 }
369 fn rebase_onto(
370 self,
371 interner: DbInterner<'db>,
372 source_def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId,
373 target: <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs,
374 ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs {
375 let defs = interner.generics_of(source_def_id);
376 interner.mk_args_from_iter(target.iter().chain(self.iter().skip(defs.count())))
377 }
378
379 fn identity_for_item(
380 interner: DbInterner<'db>,
381 def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId,
382 ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs {
383 Self::for_item(interner, def_id, |index, kind, _| mk_param(interner, index, kind))
384 }
385
386 fn extend_with_error(
387 interner: DbInterner<'db>,
388 def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId,
389 original_args: &[<DbInterner<'db> as rustc_type_ir::Interner>::GenericArg],
390 ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs {
391 Self::for_item(interner, def_id, |index, kind, _| {
392 if let Some(arg) = original_args.get(index as usize) {
393 *arg
394 } else {
395 error_for_param_kind(kind, interner)
396 }
397 })
398 }
399 fn type_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Ty {
400 self.inner()
401 .get(i)
402 .and_then(|g| g.as_type())
403 .unwrap_or_else(|| Ty::new_error(DbInterner::conjure(), ErrorGuaranteed))
404 }
405
406 fn region_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Region {
407 self.inner()
408 .get(i)
409 .and_then(|g| g.as_region())
410 .unwrap_or_else(|| Region::error(DbInterner::conjure()))
411 }
412
413 fn const_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Const {
414 self.inner()
415 .get(i)
416 .and_then(|g| g.as_const())
417 .unwrap_or_else(|| Const::error(DbInterner::conjure()))
418 }
419
420 fn split_closure_args(self) -> rustc_type_ir::ClosureArgsParts<DbInterner<'db>> {
421 match self.inner().as_slice() {
423 [parent_args @ .., closure_kind_ty, sig_ty, tupled_upvars_ty] => {
424 let interner = DbInterner::conjure();
425 let sig_ty = match sig_ty.expect_ty().kind() {
427 TyKind::FnPtr(sig_tys, header) => Ty::new(
428 interner,
429 TyKind::FnPtr(
430 sig_tys.map_bound(|s| {
431 let inputs = Ty::new_tup_from_iter(interner, s.inputs().iter());
432 let output = s.output();
433 FnSigTys {
434 inputs_and_output: Tys::new_from_iter(
435 interner,
436 [inputs, output],
437 ),
438 }
439 }),
440 header,
441 ),
442 ),
443 _ => unreachable!("sig_ty should be last"),
444 };
445 rustc_type_ir::ClosureArgsParts {
446 parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()),
447 closure_sig_as_fn_ptr_ty: sig_ty,
448 closure_kind_ty: closure_kind_ty.expect_ty(),
449 tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
450 }
451 }
452 _ => {
453 unreachable!("unexpected closure sig");
454 }
455 }
456 }
457
458 fn split_coroutine_closure_args(
459 self,
460 ) -> rustc_type_ir::CoroutineClosureArgsParts<DbInterner<'db>> {
461 match self.inner().as_slice() {
462 [
463 parent_args @ ..,
464 closure_kind_ty,
465 signature_parts_ty,
466 tupled_upvars_ty,
467 coroutine_captures_by_ref_ty,
468 ] => rustc_type_ir::CoroutineClosureArgsParts {
469 parent_args: GenericArgs::new_from_iter(
470 DbInterner::conjure(),
471 parent_args.iter().cloned(),
472 ),
473 closure_kind_ty: closure_kind_ty.expect_ty(),
474 signature_parts_ty: signature_parts_ty.expect_ty(),
475 tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
476 coroutine_captures_by_ref_ty: coroutine_captures_by_ref_ty.expect_ty(),
477 },
478 _ => panic!("GenericArgs were likely not for a CoroutineClosure."),
479 }
480 }
481
482 fn split_coroutine_args(self) -> rustc_type_ir::CoroutineArgsParts<DbInterner<'db>> {
483 let interner = DbInterner::conjure();
484 match self.inner().as_slice() {
485 [parent_args @ .., kind_ty, resume_ty, yield_ty, return_ty, tupled_upvars_ty] => {
486 rustc_type_ir::CoroutineArgsParts {
487 parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()),
488 kind_ty: kind_ty.expect_ty(),
489 resume_ty: resume_ty.expect_ty(),
490 yield_ty: yield_ty.expect_ty(),
491 return_ty: return_ty.expect_ty(),
492 tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
493 }
494 }
495 _ => panic!("GenericArgs were likely not for a Coroutine."),
496 }
497 }
498}
499
500pub fn mk_param<'db>(interner: DbInterner<'db>, index: u32, id: GenericParamId) -> GenericArg<'db> {
501 match id {
502 GenericParamId::LifetimeParamId(id) => {
503 Region::new_early_param(interner, EarlyParamRegion { index, id }).into()
504 }
505 GenericParamId::TypeParamId(id) => Ty::new_param(interner, id, index).into(),
506 GenericParamId::ConstParamId(id) => {
507 Const::new_param(interner, ParamConst { index, id }).into()
508 }
509 }
510}
511
512pub fn error_for_param_kind<'db>(id: GenericParamId, interner: DbInterner<'db>) -> GenericArg<'db> {
513 match id {
514 GenericParamId::LifetimeParamId(_) => Region::error(interner).into(),
515 GenericParamId::TypeParamId(_) => Ty::new_error(interner, ErrorGuaranteed).into(),
516 GenericParamId::ConstParamId(_) => Const::error(interner).into(),
517 }
518}
519
520impl<'db> IntoKind for Term<'db> {
521 type Kind = TermKind<DbInterner<'db>>;
522
523 fn kind(self) -> Self::Kind {
524 match self {
525 Term::Ty(ty) => TermKind::Ty(ty),
526 Term::Const(c) => TermKind::Const(c),
527 }
528 }
529}
530
531impl<'db> From<Ty<'db>> for Term<'db> {
532 fn from(value: Ty<'db>) -> Self {
533 Self::Ty(value)
534 }
535}
536
537impl<'db> From<Const<'db>> for Term<'db> {
538 fn from(value: Const<'db>) -> Self {
539 Self::Const(value)
540 }
541}
542
543impl<'db> Relate<DbInterner<'db>> for Term<'db> {
544 fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>(
545 relation: &mut R,
546 a: Self,
547 b: Self,
548 ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> {
549 match (a.kind(), b.kind()) {
550 (TermKind::Ty(a_ty), TermKind::Ty(b_ty)) => Ok(relation.relate(a_ty, b_ty)?.into()),
551 (TermKind::Const(a_ct), TermKind::Const(b_ct)) => {
552 Ok(relation.relate(a_ct, b_ct)?.into())
553 }
554 (TermKind::Ty(unpacked), x) => {
555 unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
556 }
557 (TermKind::Const(unpacked), x) => {
558 unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x)
559 }
560 }
561 }
562}
563
564impl<'db> rustc_type_ir::inherent::Term<DbInterner<'db>> for Term<'db> {}
565
566#[derive(Clone, Eq, PartialEq, Debug)]
567pub enum TermVid {
568 Ty(TyVid),
569 Const(ConstVid),
570}
571
572impl From<TyVid> for TermVid {
573 fn from(value: TyVid) -> Self {
574 TermVid::Ty(value)
575 }
576}
577
578impl From<ConstVid> for TermVid {
579 fn from(value: ConstVid) -> Self {
580 TermVid::Const(value)
581 }
582}
583
584impl<'db> DbInterner<'db> {
585 pub(super) fn mk_args(self, args: &[GenericArg<'db>]) -> GenericArgs<'db> {
586 GenericArgs::new_from_iter(self, args.iter().cloned())
587 }
588
589 pub(super) fn mk_args_from_iter<I, T>(self, iter: I) -> T::Output
590 where
591 I: Iterator<Item = T>,
592 T: rustc_type_ir::CollectAndApply<GenericArg<'db>, GenericArgs<'db>>,
593 {
594 T::collect_and_apply(iter, |xs| self.mk_args(xs))
595 }
596}