1use crate::db::HirDatabase;
17use crate::generics::{Generics, generics};
18use crate::{
19 AliasTy, Const, ConstScalar, DynTyExt, GenericArg, GenericArgData, Interner, Lifetime,
20 LifetimeData, Ty, TyKind,
21};
22use chalk_ir::Mutability;
23use hir_def::signatures::StructFlags;
24use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId};
25use std::fmt;
26use std::ops::Not;
27use stdx::never;
28use triomphe::Arc;
29
30pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Arc<[Variance]>> {
31 tracing::debug!("variances_of(def={:?})", def);
32 match def {
33 GenericDefId::FunctionId(_) => (),
34 GenericDefId::AdtId(adt) => {
35 if let AdtId::StructId(id) = adt {
36 let flags = &db.struct_signature(id).flags;
37 if flags.contains(StructFlags::IS_UNSAFE_CELL) {
38 return Some(Arc::from_iter(vec![Variance::Invariant; 1]));
39 } else if flags.contains(StructFlags::IS_PHANTOM_DATA) {
40 return Some(Arc::from_iter(vec![Variance::Covariant; 1]));
41 }
42 }
43 }
44 _ => return None,
45 }
46
47 let generics = generics(db, def);
48 let count = generics.len();
49 if count == 0 {
50 return None;
51 }
52 let variances = Context { generics, variances: vec![Variance::Bivariant; count], db }.solve();
53
54 variances.is_empty().not().then(|| Arc::from_iter(variances))
55}
56
57pub(crate) fn variances_of_cycle_initial(
67 db: &dyn HirDatabase,
68 def: GenericDefId,
69) -> Option<Arc<[Variance]>> {
70 let generics = generics(db, def);
71 let count = generics.len();
72
73 if count == 0 {
74 return None;
75 }
76 Some(Arc::from(vec![Variance::Bivariant; count]))
77}
78
/// The variance of a generic parameter: how subtyping of the argument relates to
/// subtyping of the type that uses it.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Variance {
    /// `T<A> <: T<B>` iff `A <: B` (e.g. a function's return type).
    Covariant,
    /// `T<A> <: T<B>` iff `A == B` (e.g. the pointee of `&mut`).
    Invariant,
    /// `T<A> <: T<B>` iff `B <: A` (e.g. a function parameter type).
    Contravariant,
    /// `T<A> <: T<B>` for any `A`, `B` (e.g. an unused parameter).
    Bivariant,
}
86
87impl fmt::Display for Variance {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 match self {
90 Variance::Covariant => write!(f, "covariant"),
91 Variance::Invariant => write!(f, "invariant"),
92 Variance::Contravariant => write!(f, "contravariant"),
93 Variance::Bivariant => write!(f, "bivariant"),
94 }
95 }
96}
97
98impl Variance {
99 fn xform(self, v: Variance) -> Variance {
136 match (self, v) {
137 (Variance::Covariant, Variance::Covariant) => Variance::Covariant,
139 (Variance::Covariant, Variance::Contravariant) => Variance::Contravariant,
140 (Variance::Covariant, Variance::Invariant) => Variance::Invariant,
141 (Variance::Covariant, Variance::Bivariant) => Variance::Bivariant,
142
143 (Variance::Contravariant, Variance::Covariant) => Variance::Contravariant,
145 (Variance::Contravariant, Variance::Contravariant) => Variance::Covariant,
146 (Variance::Contravariant, Variance::Invariant) => Variance::Invariant,
147 (Variance::Contravariant, Variance::Bivariant) => Variance::Bivariant,
148
149 (Variance::Invariant, _) => Variance::Invariant,
151
152 (Variance::Bivariant, _) => Variance::Bivariant,
154 }
155 }
156
157 fn glb(self, v: Variance) -> Variance {
158 match (self, v) {
165 (Variance::Invariant, _) | (_, Variance::Invariant) => Variance::Invariant,
166
167 (Variance::Covariant, Variance::Contravariant) => Variance::Invariant,
168 (Variance::Contravariant, Variance::Covariant) => Variance::Invariant,
169
170 (Variance::Covariant, Variance::Covariant) => Variance::Covariant,
171
172 (Variance::Contravariant, Variance::Contravariant) => Variance::Contravariant,
173
174 (x, Variance::Bivariant) | (Variance::Bivariant, x) => x,
175 }
176 }
177
178 pub fn invariant(self) -> Self {
179 self.xform(Variance::Invariant)
180 }
181
182 pub fn covariant(self) -> Self {
183 self.xform(Variance::Covariant)
184 }
185
186 pub fn contravariant(self) -> Self {
187 self.xform(Variance::Contravariant)
188 }
189}
190
/// Working state for inferring the variances of one definition's generic parameters.
struct Context<'db> {
    db: &'db dyn HirDatabase,
    /// Generic parameters of the definition being analysed.
    generics: Generics,
    /// Current variance of each generic parameter, indexed in the same order as
    /// `generics`; narrowed via `Variance::glb` in `constrain`.
    variances: Vec<Variance>,
}
196
197impl Context<'_> {
198 fn solve(mut self) -> Vec<Variance> {
199 tracing::debug!("solve(generics={:?})", self.generics);
200 match self.generics.def() {
201 GenericDefId::AdtId(adt) => {
202 let db = self.db;
203 let mut add_constraints_from_variant = |variant| {
204 let subst = self.generics.placeholder_subst(db);
205 for (_, field) in db.field_types(variant).iter() {
206 self.add_constraints_from_ty(
207 &field.clone().substitute(Interner, &subst),
208 Variance::Covariant,
209 );
210 }
211 };
212 match adt {
213 AdtId::StructId(s) => add_constraints_from_variant(VariantId::StructId(s)),
214 AdtId::UnionId(u) => add_constraints_from_variant(VariantId::UnionId(u)),
215 AdtId::EnumId(e) => {
216 e.enum_variants(db).variants.iter().for_each(|&(variant, _, _)| {
217 add_constraints_from_variant(VariantId::EnumVariantId(variant))
218 });
219 }
220 }
221 }
222 GenericDefId::FunctionId(f) => {
223 let subst = self.generics.placeholder_subst(self.db);
224 self.add_constraints_from_sig(
225 self.db
226 .callable_item_signature(f.into())
227 .substitute(Interner, &subst)
228 .params_and_return
229 .iter(),
230 Variance::Covariant,
231 );
232 }
233 _ => {}
234 }
235 let mut variances = self.variances;
236
237 for (idx, param) in self.generics.iter_id().enumerate() {
240 if let GenericParamId::ConstParamId(_) = param {
241 variances[idx] = Variance::Invariant;
242 }
243 }
244
245 if let GenericDefId::FunctionId(_) = self.generics.def() {
247 variances
248 .iter_mut()
249 .filter(|&&mut v| v == Variance::Bivariant)
250 .for_each(|v| *v = Variance::Invariant);
251 }
252
253 variances
254 }
255
256 fn add_constraints_from_ty(&mut self, ty: &Ty, variance: Variance) {
260 tracing::debug!("add_constraints_from_ty(ty={:?}, variance={:?})", ty, variance);
261 match ty.kind(Interner) {
262 TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => {
263 }
265 TyKind::FnDef(..) | TyKind::Coroutine(..) | TyKind::Closure(..) => {
266 never!("Unexpected unnameable type in variance computation: {:?}", ty);
267 }
268 TyKind::Ref(mutbl, lifetime, ty) => {
269 self.add_constraints_from_region(lifetime, variance);
270 self.add_constraints_from_mt(ty, *mutbl, variance);
271 }
272 TyKind::Array(typ, len) => {
273 self.add_constraints_from_const(len, variance);
274 self.add_constraints_from_ty(typ, variance);
275 }
276 TyKind::Slice(typ) => {
277 self.add_constraints_from_ty(typ, variance);
278 }
279 TyKind::Raw(mutbl, ty) => {
280 self.add_constraints_from_mt(ty, *mutbl, variance);
281 }
282 TyKind::Tuple(_, subtys) => {
283 for subty in subtys.type_parameters(Interner) {
284 self.add_constraints_from_ty(&subty, variance);
285 }
286 }
287 TyKind::Adt(def, args) => {
288 self.add_constraints_from_args(def.0.into(), args.as_slice(Interner), variance);
289 }
290 TyKind::Alias(AliasTy::Opaque(opaque)) => {
291 self.add_constraints_from_invariant_args(
292 opaque.substitution.as_slice(Interner),
293 variance,
294 );
295 }
296 TyKind::Alias(AliasTy::Projection(proj)) => {
297 self.add_constraints_from_invariant_args(
298 proj.substitution.as_slice(Interner),
299 variance,
300 );
301 }
302 TyKind::AssociatedType(_, subst) => {
304 self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
305 }
306 TyKind::OpaqueType(_, subst) => {
308 self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
309 }
310 TyKind::Dyn(it) => {
311 self.add_constraints_from_region(&it.lifetime, variance);
313
314 if let Some(trait_ref) = it.principal() {
315 self.add_constraints_from_invariant_args(
317 trait_ref
318 .map(|it| it.map(|it| it.substitution.clone()))
319 .substitute(
320 Interner,
321 &[GenericArg::new(
322 Interner,
323 chalk_ir::GenericArgData::Ty(TyKind::Error.intern(Interner)),
324 )],
325 )
326 .skip_binders()
327 .as_slice(Interner),
328 variance,
329 );
330 }
331
332 }
344
345 TyKind::Placeholder(index) => {
347 let idx = crate::from_placeholder_idx(self.db, *index).0;
348 let index = self.generics.type_or_const_param_idx(idx).unwrap();
349 self.constrain(index, variance);
350 }
351 TyKind::Function(f) => {
352 self.add_constraints_from_sig(
353 f.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)),
354 variance,
355 );
356 }
357 TyKind::Error => {
358 }
361 TyKind::CoroutineWitness(..) | TyKind::BoundVar(..) | TyKind::InferenceVar(..) => {
362 never!("unexpected type encountered in variance inference: {:?}", ty)
363 }
364 }
365 }
366
367 fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) {
368 let variance_i = variance.invariant();
369
370 for k in args {
371 match k.data(Interner) {
372 GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i),
373 GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i),
374 GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i),
375 }
376 }
377 }
378
379 fn add_constraints_from_args(
382 &mut self,
383 def_id: GenericDefId,
384 args: &[GenericArg],
385 variance: Variance,
386 ) {
387 if args.is_empty() {
389 return;
390 }
391 let Some(variances) = self.db.variances_of(def_id) else {
392 return;
393 };
394
395 for (i, k) in args.iter().enumerate() {
396 match k.data(Interner) {
397 GenericArgData::Lifetime(lt) => {
398 self.add_constraints_from_region(lt, variance.xform(variances[i]))
399 }
400 GenericArgData::Ty(ty) => {
401 self.add_constraints_from_ty(ty, variance.xform(variances[i]))
402 }
403 GenericArgData::Const(val) => self.add_constraints_from_const(val, variance),
404 }
405 }
406 }
407
408 fn add_constraints_from_const(&mut self, c: &Const, variance: Variance) {
411 match &c.data(Interner).value {
412 chalk_ir::ConstValue::Concrete(c) => {
413 if let ConstScalar::UnevaluatedConst(_, subst) = &c.interned {
414 self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
415 }
416 }
417 _ => {}
418 }
419 }
420
421 fn add_constraints_from_sig<'a>(
424 &mut self,
425 mut sig_tys: impl DoubleEndedIterator<Item = &'a Ty>,
426 variance: Variance,
427 ) {
428 let contra = variance.contravariant();
429 let Some(output) = sig_tys.next_back() else {
430 return never!("function signature has no return type");
431 };
432 self.add_constraints_from_ty(output, variance);
433 for input in sig_tys {
434 self.add_constraints_from_ty(input, contra);
435 }
436 }
437
438 fn add_constraints_from_region(&mut self, region: &Lifetime, variance: Variance) {
441 tracing::debug!(
442 "add_constraints_from_region(region={:?}, variance={:?})",
443 region,
444 variance
445 );
446 match region.data(Interner) {
447 LifetimeData::Placeholder(index) => {
448 let idx = crate::lt_from_placeholder_idx(self.db, *index).0;
449 let inferred = self.generics.lifetime_idx(idx).unwrap();
450 self.constrain(inferred, variance);
451 }
452 LifetimeData::Static => {}
453 LifetimeData::BoundVar(..) => {
454 }
459 LifetimeData::Error => {}
460 LifetimeData::Phantom(..) | LifetimeData::InferenceVar(..) | LifetimeData::Erased => {
461 never!(
464 "unexpected region encountered in variance \
465 inference: {:?}",
466 region
467 );
468 }
469 }
470 }
471
472 fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: Variance) {
475 self.add_constraints_from_ty(
476 ty,
477 match mt {
478 Mutability::Mut => variance.invariant(),
479 Mutability::Not => variance,
480 },
481 );
482 }
483
484 fn constrain(&mut self, index: usize, variance: Variance) {
485 tracing::debug!(
486 "constrain(index={:?}, variance={:?}, to={:?})",
487 index,
488 self.variances[index],
489 variance
490 );
491 self.variances[index] = self.variances[index].glb(variance);
492 }
493}
494
#[cfg(test)]
mod tests {
    // Each test runs `check` on a fixture and compares the rendered variances of the
    // items in it. Fixture comments of the form `//~ ERROR [...]` are the variances
    // rustc reports, in rustc's notation: `+` covariant, `-` contravariant,
    // `o` invariant, `*` bivariant.
    use expect_test::{Expect, expect};
    use hir_def::{
        AdtId, GenericDefId, ModuleDefId, hir::generics::GenericParamDataRef, src::HasSource,
    };
    use itertools::Itertools;
    use stdx::format_to;
    use syntax::{AstNode, ast::HasName};
    use test_fixture::WithFixture;

    use hir_def::Lookup;

    use crate::{db::HirDatabase, test_db::TestDB, variance::generics};

    #[test]
    fn phantom_data() {
        check(
            r#"
//- minicore: phantom_data

struct Covariant<A> {
    t: core::marker::PhantomData<A>
}
"#,
            expect![[r#"
                Covariant[A: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_types() {
        // NOTE(review): the fixture contains a stray backtick after the `Bar` variant;
        // the expectation still lists `Enum`, so parsing appears to recover from it.
        check(
            r#"
//- minicore: cell

use core::cell::UnsafeCell;

struct InvariantMut<'a,A:'a,B:'a> { //~ ERROR ['a: +, A: o, B: o]
    t: &'a mut (A,B)
}

struct InvariantCell<A> { //~ ERROR [A: o]
    t: UnsafeCell<A>
}

struct InvariantIndirect<A> { //~ ERROR [A: o]
    t: InvariantCell<A>
}

struct Covariant<A> { //~ ERROR [A: +]
    t: A, u: fn() -> A
}

struct Contravariant<A> { //~ ERROR [A: -]
    t: fn(A)
}

enum Enum<A,B,C> { //~ ERROR [A: +, B: -, C: o]
    Foo(Covariant<A>),
    Bar(Contravariant<B>),`
    Zed(Covariant<C>,Contravariant<C>)
}
"#,
            expect![[r#"
                InvariantMut['a: covariant, A: invariant, B: invariant]
                InvariantCell[A: invariant]
                InvariantIndirect[A: invariant]
                Covariant[A: covariant]
                Contravariant[A: contravariant]
                Enum[A: covariant, B: contravariant, C: invariant]
            "#]],
        );
    }

    #[test]
    fn type_resolve_error_two_structs_deep() {
        // `Missing` does not resolve, so it contributes no constraints and `'a` stays
        // bivariant, also through the indirection in `Other`.
        check(
            r#"
struct Hello<'a> {
    missing: Missing<'a>,
}

struct Other<'a> {
    hello: Hello<'a>,
}
"#,
            expect![[r#"
                Hello['a: bivariant]
                Other['a: bivariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_associated_consts() {
        // NOTE(review): rustc reports `T` as invariant here, but this implementation
        // currently yields bivariant for the associated-const array length.
        check(
            r#"
trait Trait {
    const Const: usize;
}

struct Foo<T: Trait> { //~ ERROR [T: o]
    field: [u8; <T as Trait>::Const]
}
"#,
            expect![[r#"
                Foo[T: bivariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_associated_types() {
        check(
            r#"
trait Trait<'a> {
    type Type;

    fn method(&'a self) { }
}

struct Foo<'a, T : Trait<'a>> { //~ ERROR ['a: +, T: +]
    field: (T, &'a ())
}

struct Bar<'a, T : Trait<'a>> { //~ ERROR ['a: o, T: o]
    field: <T as Trait<'a>>::Type
}

"#,
            expect![[r#"
                method[Self: contravariant, 'a: contravariant]
                Foo['a: covariant, T: covariant]
                Bar['a: invariant, T: invariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_associated_types2() {
        // `make` has no generic parameters, so nothing is printed.
        check(
            r#"
trait Foo {
    type Bar;
}

fn make() -> *const dyn Foo<Bar = &'static u32> {}
"#,
            expect![""],
        );
    }

    #[test]
    fn rustc_test_variance_trait_bounds() {
        check(
            r#"
trait Getter<T> {
    fn get(&self) -> T;
}

trait Setter<T> {
    fn get(&self, _: T);
}

struct TestStruct<U,T:Setter<U>> { //~ ERROR [U: +, T: +]
    t: T, u: U
}

enum TestEnum<U,T:Setter<U>> { //~ ERROR [U: *, T: +]
    //~^ ERROR: `U` is never used
    Foo(T)
}

struct TestContraStruct<U,T:Setter<U>> { //~ ERROR [U: *, T: +]
    //~^ ERROR: `U` is never used
    t: T
}

struct TestBox<U,T:Getter<U>+Setter<U>> { //~ ERROR [U: *, T: +]
    //~^ ERROR: `U` is never used
    t: T
}
"#,
            expect![[r#"
                get[Self: contravariant, T: covariant]
                get[Self: contravariant, T: contravariant]
                TestStruct[U: covariant, T: covariant]
                TestEnum[U: bivariant, T: covariant]
                TestContraStruct[U: bivariant, T: covariant]
                TestBox[U: bivariant, T: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_trait_matching() {
        check(
            r#"

trait Get<T> {
    fn get(&self) -> T;
}

struct Cloner<T:Clone> {
    t: T
}

impl<T:Clone> Get<T> for Cloner<T> {
    fn get(&self) -> T {}
}

fn get<'a, G>(get: &G) -> i32
    where G : Get<&'a i32>
{}

fn pick<'b, G>(get: &'b G, if_odd: &'b i32) -> i32
    where G : Get<&'b i32>
{}
"#,
            expect![[r#"
                get[Self: contravariant, T: covariant]
                Cloner[T: covariant]
                get[T: invariant]
                get['a: invariant, G: contravariant]
                pick['b: contravariant, G: contravariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_trait_object_bound() {
        check(
            r#"
enum Option<T> {
    Some(T),
    None
}
trait T { fn foo(&self); }

struct TOption<'a> { //~ ERROR ['a: +]
    v: Option<*const (dyn T + 'a)>,
}
"#,
            expect![[r#"
                Option[T: covariant]
                foo[Self: contravariant]
                TOption['a: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_types_bounds() {
        check(
            r#"
//- minicore: send
struct TestImm<A, B> { //~ ERROR [A: +, B: +]
    x: A,
    y: B,
}

struct TestMut<A, B:'static> { //~ ERROR [A: +, B: o]
    x: A,
    y: &'static mut B,
}

struct TestIndirect<A:'static, B:'static> { //~ ERROR [A: +, B: o]
    m: TestMut<A, B>
}

struct TestIndirect2<A:'static, B:'static> { //~ ERROR [A: o, B: o]
    n: TestMut<A, B>,
    m: TestMut<B, A>
}

trait Getter<A> {
    fn get(&self) -> A;
}

trait Setter<A> {
    fn set(&mut self, a: A);
}

struct TestObject<A, R> { //~ ERROR [A: o, R: o]
    n: *const (dyn Setter<A> + Send),
    m: *const (dyn Getter<R> + Send),
}
"#,
            expect![[r#"
                TestImm[A: covariant, B: covariant]
                TestMut[A: covariant, B: invariant]
                TestIndirect[A: covariant, B: invariant]
                TestIndirect2[A: invariant, B: invariant]
                get[Self: contravariant, A: covariant]
                set[Self: invariant, A: contravariant]
                TestObject[A: invariant, R: invariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_unused_region_param() {
        check(
            r#"
struct SomeStruct<'a> { x: u32 } //~ ERROR parameter `'a` is never used
enum SomeEnum<'a> { Nothing } //~ ERROR parameter `'a` is never used
trait SomeTrait<'a> { fn foo(&self); } // OK on traits.
"#,
            expect![[r#"
                SomeStruct['a: bivariant]
                SomeEnum['a: bivariant]
                foo[Self: contravariant, 'a: invariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_unused_type_param() {
        check(
            r#"
//- minicore: sized
struct SomeStruct<A> { x: u32 }
enum SomeEnum<A> { Nothing }
enum ListCell<T> {
    Cons(*const ListCell<T>),
    Nil
}

struct SelfTyAlias<T>(*const Self);
struct WithBounds<T: Sized> {}
struct WithWhereBounds<T> where T: Sized {}
struct WithOutlivesBounds<T: 'static> {}
struct DoubleNothing<T> {
    s: SomeStruct<T>,
}

"#,
            expect![[r#"
                SomeStruct[A: bivariant]
                SomeEnum[A: bivariant]
                ListCell[T: bivariant]
                SelfTyAlias[T: bivariant]
                WithBounds[T: bivariant]
                WithWhereBounds[T: bivariant]
                WithOutlivesBounds[T: bivariant]
                DoubleNothing[T: bivariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_contravariant_struct1() {
        check(
            r#"
struct SomeStruct<T>(fn(T));

fn foo<'min,'max>(v: SomeStruct<&'max ()>)
    -> SomeStruct<&'min ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: contravariant]
                foo['min: contravariant, 'max: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_contravariant_struct2() {
        check(
            r#"
struct SomeStruct<T>(fn(T));

fn bar<'min,'max>(v: SomeStruct<&'min ()>)
    -> SomeStruct<&'max ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: contravariant]
                bar['min: covariant, 'max: contravariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_covariant_struct1() {
        check(
            r#"
struct SomeStruct<T>(T);

fn foo<'min,'max>(v: SomeStruct<&'min ()>)
    -> SomeStruct<&'max ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: covariant]
                foo['min: contravariant, 'max: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_covariant_struct2() {
        check(
            r#"
struct SomeStruct<T>(T);

fn foo<'min,'max>(v: SomeStruct<&'max ()>)
    -> SomeStruct<&'min ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: covariant]
                foo['min: covariant, 'max: contravariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_invariant_struct1() {
        check(
            r#"
struct SomeStruct<T>(*mut T);

fn foo<'min,'max>(v: SomeStruct<&'max ()>)
    -> SomeStruct<&'min ()>
    where 'max : 'min
{}

fn bar<'min,'max>(v: SomeStruct<&'min ()>)
    -> SomeStruct<&'max ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: invariant]
                foo['min: invariant, 'max: invariant]
                bar['min: invariant, 'max: invariant]
            "#]],
        );
    }

    #[test]
    fn invalid_arg_counts() {
        // Too few / too many generic arguments must not crash variance inference.
        check(
            r#"
struct S<T>(T);
struct S2<T>(S<>);
struct S3<T>(S<T, T>);
"#,
            expect![[r#"
                S[T: covariant]
                S2[T: bivariant]
                S3[T: covariant]
            "#]],
        );
    }

    #[test]
    fn prove_fixedpoint() {
        // The self-referential field makes `variances_of(FixedPoint)` depend on itself.
        // NOTE(review): the expectation is all-bivariant even for the directly used `V`,
        // which suggests the cycle fallback (`variances_of_cycle_initial`) is reported
        // for the whole cycle rather than iterating to a fixpoint; confirm against the
        // query's cycle configuration.
        check(
            r#"
struct FixedPoint<T, U, V>(&'static FixedPoint<(), T, U>, V);
"#,
            expect![[r#"
                FixedPoint[T: bivariant, U: bivariant, V: bivariant]
            "#]],
        );
    }

    /// Builds a test database from `ra_fixture`, collects the definitions visited by
    /// `crate::tests::visit_module` in the root module, orders them by source position,
    /// and renders one `name[param: variance, ...]` line for every definition whose
    /// `variances_of` result is `Some`. The output is compared against `expected`.
    #[track_caller]
    fn check(#[rust_analyzer::rust_fixture] ra_fixture: &str, expected: Expect) {
        let (db, file_id) = TestDB::with_single_file(ra_fixture);

        let mut defs: Vec<GenericDefId> = Vec::new();
        let module = db.module_for_file_opt(file_id.file_id(&db)).unwrap();
        let def_map = module.def_map(&db);
        crate::tests::visit_module(&db, def_map, module.local_id, &mut |it| {
            defs.push(match it {
                ModuleDefId::FunctionId(it) => it.into(),
                ModuleDefId::AdtId(it) => it.into(),
                ModuleDefId::ConstId(it) => it.into(),
                ModuleDefId::TraitId(it) => it.into(),
                ModuleDefId::TypeAliasId(it) => it.into(),
                _ => return,
            })
        });
        // Pair each definition with its name node so the output can be sorted by
        // where the definition appears in the fixture.
        let defs = defs
            .into_iter()
            .filter_map(|def| {
                Some((
                    def,
                    match def {
                        GenericDefId::FunctionId(it) => {
                            let loc = it.lookup(&db);
                            loc.source(&db).value.name().unwrap()
                        }
                        GenericDefId::AdtId(AdtId::EnumId(it)) => {
                            let loc = it.lookup(&db);
                            loc.source(&db).value.name().unwrap()
                        }
                        GenericDefId::AdtId(AdtId::StructId(it)) => {
                            let loc = it.lookup(&db);
                            loc.source(&db).value.name().unwrap()
                        }
                        GenericDefId::AdtId(AdtId::UnionId(it)) => {
                            let loc = it.lookup(&db);
                            loc.source(&db).value.name().unwrap()
                        }
                        GenericDefId::TraitId(it) => {
                            let loc = it.lookup(&db);
                            loc.source(&db).value.name().unwrap()
                        }
                        GenericDefId::TypeAliasId(it) => {
                            let loc = it.lookup(&db);
                            loc.source(&db).value.name().unwrap()
                        }
                        GenericDefId::ImplId(_) => return None,
                        GenericDefId::ConstId(_) => return None,
                        GenericDefId::StaticId(_) => return None,
                    },
                ))
            })
            .sorted_by_key(|(_, n)| n.syntax().text_range().start());
        let mut res = String::new();
        for (def, name) in defs {
            // Definitions without variances (no generics, or not a function/ADT) are
            // omitted from the output.
            let Some(variances) = db.variances_of(def) else {
                continue;
            };
            format_to!(
                res,
                "{name}[{}]\n",
                generics(&db, def)
                    .iter()
                    .map(|(_, param)| match param {
                        GenericParamDataRef::TypeParamData(type_param_data) => {
                            type_param_data.name.as_ref().unwrap()
                        }
                        GenericParamDataRef::ConstParamData(const_param_data) =>
                            &const_param_data.name,
                        GenericParamDataRef::LifetimeParamData(lifetime_param_data) => {
                            &lifetime_param_data.name
                        }
                    })
                    // `zip_eq` panics if the variance count differs from the parameter count.
                    .zip_eq(&*variances)
                    .format_with(", ", |(name, var), f| f(&format_args!(
                        "{}: {var}",
                        name.as_str()
                    )))
            );
        }

        expected.assert_eq(&res);
    }
}