hir_ty/
variance.rs

//! Module for inferring the variance of type and lifetime parameters. See the [rustc dev guide]
//! chapter for more info.
//!
//! [rustc dev guide]: https://rustc-dev-guide.rust-lang.org/variance.html
//!
//! The implementation here differs from rustc. Rustc does a crate-wide fixpoint resolution,
//! as the algorithm for determining variance is a fixpoint computation with potential cycles that
//! need to be resolved. rust-analyzer does not want a crate-wide analysis, though, as that would
//! hurt incrementality too much; as such, our query works on a per-item basis.
//!
//! Unfortunately, this means we can run into query cycles, which salsa currently does not allow
//! to be resolved via a fixpoint computation. This will likely be resolved by the next salsa
//! version. If not, we will likely have to adapt and go with the rustc approach while installing
//! per-item firewall queries to prevent invalidation issues.
15
16use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId, signatures::StructFlags};
17use rustc_ast_ir::Mutability;
18use rustc_type_ir::{
19    Variance,
20    inherent::{AdtDef, IntoKind, SliceLike},
21};
22use stdx::never;
23
24use crate::{
25    db::HirDatabase,
26    generics::{Generics, generics},
27    next_solver::{
28        Const, ConstKind, DbInterner, ExistentialPredicate, GenericArg, GenericArgs, Region,
29        RegionKind, Term, Ty, TyKind, VariancesOf,
30    },
31};
32
33pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> VariancesOf<'_> {
34    tracing::debug!("variances_of(def={:?})", def);
35    let interner = DbInterner::new_no_crate(db);
36    match def {
37        GenericDefId::FunctionId(_) => (),
38        GenericDefId::AdtId(adt) => {
39            if let AdtId::StructId(id) = adt {
40                let flags = &db.struct_signature(id).flags;
41                if flags.contains(StructFlags::IS_UNSAFE_CELL) {
42                    return VariancesOf::new_from_iter(interner, [Variance::Invariant]);
43                } else if flags.contains(StructFlags::IS_PHANTOM_DATA) {
44                    return VariancesOf::new_from_iter(interner, [Variance::Covariant]);
45                }
46            }
47        }
48        _ => return VariancesOf::new_from_iter(interner, []),
49    }
50
51    let generics = generics(db, def);
52    let count = generics.len();
53    if count == 0 {
54        return VariancesOf::new_from_iter(interner, []);
55    }
56    let mut variances =
57        Context { generics, variances: vec![Variance::Bivariant; count], db }.solve();
58
59    // FIXME(next-solver): This is *not* the correct behavior. I don't know if it has an actual effect,
60    // since bivariance is prohibited in Rust, but rustc definitely does not fallback bivariance.
61    // So why do we do this? Because, with the new solver, the effects of bivariance are catastrophic:
62    // it leads to not relating types properly, and to very, very hard to debug bugs (speaking from experience).
63    // Furthermore, our variance infra is known to not handle cycles properly. Therefore, at least until we fix
64    // cycles, and perhaps forever at least for out tests, not allowing bivariance makes sense.
65    // Why specifically invariance? I don't have a strong reason, mainly that invariance is a stronger relationship
66    // (therefore, less room for mistakes) and that IMO incorrect covariance can be more problematic that incorrect
67    // bivariance, at least while we don't handle lifetimes anyway.
68    for variance in &mut variances {
69        if *variance == Variance::Bivariant {
70            *variance = Variance::Invariant;
71        }
72    }
73
74    VariancesOf::new_from_iter(interner, variances)
75}
76
77// pub(crate) fn variances_of_cycle_fn(
78//     _db: &dyn HirDatabase,
79//     _result: &Option<Arc<[Variance]>>,
80//     _count: u32,
81//     _def: GenericDefId,
82// ) -> salsa::CycleRecoveryAction<Option<Arc<[Variance]>>> {
83//     salsa::CycleRecoveryAction::Iterate
84// }
85
86fn glb(v1: Variance, v2: Variance) -> Variance {
87    // Greatest lower bound of the variance lattice as defined in The Paper:
88    //
89    //       *
90    //    -     +
91    //       o
92    match (v1, v2) {
93        (Variance::Invariant, _) | (_, Variance::Invariant) => Variance::Invariant,
94
95        (Variance::Covariant, Variance::Contravariant) => Variance::Invariant,
96        (Variance::Contravariant, Variance::Covariant) => Variance::Invariant,
97
98        (Variance::Covariant, Variance::Covariant) => Variance::Covariant,
99
100        (Variance::Contravariant, Variance::Contravariant) => Variance::Contravariant,
101
102        (x, Variance::Bivariant) | (Variance::Bivariant, x) => x,
103    }
104}
105
106pub(crate) fn variances_of_cycle_initial(
107    db: &dyn HirDatabase,
108    def: GenericDefId,
109) -> VariancesOf<'_> {
110    let interner = DbInterner::new_no_crate(db);
111    let generics = generics(db, def);
112    let count = generics.len();
113
114    // FIXME(next-solver): Returns `Invariance` and not `Bivariance` here, see the comment in the main query.
115    VariancesOf::new_from_iter(interner, std::iter::repeat_n(Variance::Invariant, count))
116}
117
118struct Context<'db> {
119    db: &'db dyn HirDatabase,
120    generics: Generics,
121    variances: Vec<Variance>,
122}
123
124impl<'db> Context<'db> {
125    fn solve(mut self) -> Vec<Variance> {
126        tracing::debug!("solve(generics={:?})", self.generics);
127        match self.generics.def() {
128            GenericDefId::AdtId(adt) => {
129                let db = self.db;
130                let mut add_constraints_from_variant = |variant| {
131                    for (_, field) in db.field_types(variant).iter() {
132                        self.add_constraints_from_ty(
133                            field.instantiate_identity(),
134                            Variance::Covariant,
135                        );
136                    }
137                };
138                match adt {
139                    AdtId::StructId(s) => add_constraints_from_variant(VariantId::StructId(s)),
140                    AdtId::UnionId(u) => add_constraints_from_variant(VariantId::UnionId(u)),
141                    AdtId::EnumId(e) => {
142                        e.enum_variants(db).variants.iter().for_each(|&(variant, _, _)| {
143                            add_constraints_from_variant(VariantId::EnumVariantId(variant))
144                        });
145                    }
146                }
147            }
148            GenericDefId::FunctionId(f) => {
149                let sig =
150                    self.db.callable_item_signature(f.into()).instantiate_identity().skip_binder();
151                self.add_constraints_from_sig(sig.inputs_and_output.iter(), Variance::Covariant);
152            }
153            _ => {}
154        }
155        let mut variances = self.variances;
156
157        // Const parameters are always invariant.
158        // Make all const parameters invariant.
159        for (idx, param) in self.generics.iter_id().enumerate() {
160            if let GenericParamId::ConstParamId(_) = param {
161                variances[idx] = Variance::Invariant;
162            }
163        }
164
165        // Functions are permitted to have unused generic parameters: make those invariant.
166        if let GenericDefId::FunctionId(_) = self.generics.def() {
167            variances
168                .iter_mut()
169                .filter(|&&mut v| v == Variance::Bivariant)
170                .for_each(|v| *v = Variance::Invariant);
171        }
172
173        variances
174    }
175
176    /// Adds constraints appropriate for an instance of `ty` appearing
177    /// in a context with the generics defined in `generics` and
178    /// ambient variance `variance`
179    fn add_constraints_from_ty(&mut self, ty: Ty<'db>, variance: Variance) {
180        tracing::debug!("add_constraints_from_ty(ty={:?}, variance={:?})", ty, variance);
181        match ty.kind() {
182            TyKind::Int(_)
183            | TyKind::Uint(_)
184            | TyKind::Float(_)
185            | TyKind::Char
186            | TyKind::Bool
187            | TyKind::Never
188            | TyKind::Str
189            | TyKind::Foreign(..) => {
190                // leaf type -- noop
191            }
192            TyKind::FnDef(..)
193            | TyKind::Coroutine(..)
194            | TyKind::CoroutineClosure(..)
195            | TyKind::Closure(..) => {
196                never!("Unexpected unnameable type in variance computation: {:?}", ty);
197            }
198            TyKind::Ref(lifetime, ty, mutbl) => {
199                self.add_constraints_from_region(lifetime, variance);
200                self.add_constraints_from_mt(ty, mutbl, variance);
201            }
202            TyKind::Array(typ, len) => {
203                self.add_constraints_from_const(len);
204                self.add_constraints_from_ty(typ, variance);
205            }
206            TyKind::Slice(typ) => {
207                self.add_constraints_from_ty(typ, variance);
208            }
209            TyKind::RawPtr(ty, mutbl) => {
210                self.add_constraints_from_mt(ty, mutbl, variance);
211            }
212            TyKind::Tuple(subtys) => {
213                for subty in subtys {
214                    self.add_constraints_from_ty(subty, variance);
215                }
216            }
217            TyKind::Adt(def, args) => {
218                self.add_constraints_from_args(def.def_id().0.into(), args, variance);
219            }
220            TyKind::Alias(_, alias) => {
221                // FIXME: Probably not correct wrt. opaques.
222                self.add_constraints_from_invariant_args(alias.args);
223            }
224            TyKind::Dynamic(bounds, region) => {
225                // The type `dyn Trait<T> +'a` is covariant w/r/t `'a`:
226                self.add_constraints_from_region(region, variance);
227
228                for bound in bounds {
229                    match bound.skip_binder() {
230                        ExistentialPredicate::Trait(trait_ref) => {
231                            self.add_constraints_from_invariant_args(trait_ref.args)
232                        }
233                        ExistentialPredicate::Projection(projection) => {
234                            self.add_constraints_from_invariant_args(projection.args);
235                            match projection.term {
236                                Term::Ty(ty) => {
237                                    self.add_constraints_from_ty(ty, Variance::Invariant)
238                                }
239                                Term::Const(konst) => self.add_constraints_from_const(konst),
240                            }
241                        }
242                        ExistentialPredicate::AutoTrait(_) => {}
243                    }
244                }
245            }
246
247            // Chalk has no params, so use placeholders for now?
248            TyKind::Param(param) => self.constrain(param.index as usize, variance),
249            TyKind::FnPtr(sig, _) => {
250                self.add_constraints_from_sig(sig.skip_binder().inputs_and_output.iter(), variance);
251            }
252            TyKind::Error(_) => {
253                // we encounter this when walking the trait references for object
254                // types, where we use Error as the Self type
255            }
256            TyKind::Bound(..) => {}
257            TyKind::CoroutineWitness(..)
258            | TyKind::Placeholder(..)
259            | TyKind::Infer(..)
260            | TyKind::UnsafeBinder(..)
261            | TyKind::Pat(..) => {
262                never!("unexpected type encountered in variance inference: {:?}", ty)
263            }
264        }
265    }
266
267    fn add_constraints_from_invariant_args(&mut self, args: GenericArgs<'db>) {
268        for k in args.iter() {
269            match k {
270                GenericArg::Lifetime(lt) => {
271                    self.add_constraints_from_region(lt, Variance::Invariant)
272                }
273                GenericArg::Ty(ty) => self.add_constraints_from_ty(ty, Variance::Invariant),
274                GenericArg::Const(val) => self.add_constraints_from_const(val),
275            }
276        }
277    }
278
279    /// Adds constraints appropriate for a nominal type (enum, struct,
280    /// object, etc) appearing in a context with ambient variance `variance`
281    fn add_constraints_from_args(
282        &mut self,
283        def_id: GenericDefId,
284        args: GenericArgs<'db>,
285        variance: Variance,
286    ) {
287        if args.is_empty() {
288            return;
289        }
290        let variances = self.db.variances_of(def_id);
291
292        for (k, v) in args.iter().zip(variances) {
293            match k {
294                GenericArg::Lifetime(lt) => self.add_constraints_from_region(lt, variance.xform(v)),
295                GenericArg::Ty(ty) => self.add_constraints_from_ty(ty, variance.xform(v)),
296                GenericArg::Const(val) => self.add_constraints_from_const(val),
297            }
298        }
299    }
300
301    /// Adds constraints appropriate for a const expression `val`
302    /// in a context with ambient variance `variance`
303    fn add_constraints_from_const(&mut self, c: Const<'db>) {
304        match c.kind() {
305            ConstKind::Unevaluated(c) => self.add_constraints_from_invariant_args(c.args),
306            _ => {}
307        }
308    }
309
310    /// Adds constraints appropriate for a function with signature
311    /// `sig` appearing in a context with ambient variance `variance`
312    fn add_constraints_from_sig(
313        &mut self,
314        mut sig_tys: impl DoubleEndedIterator<Item = Ty<'db>>,
315        variance: Variance,
316    ) {
317        let contra = variance.xform(Variance::Contravariant);
318        let Some(output) = sig_tys.next_back() else {
319            return never!("function signature has no return type");
320        };
321        self.add_constraints_from_ty(output, variance);
322        for input in sig_tys {
323            self.add_constraints_from_ty(input, contra);
324        }
325    }
326
327    /// Adds constraints appropriate for a region appearing in a
328    /// context with ambient variance `variance`
329    fn add_constraints_from_region(&mut self, region: Region<'db>, variance: Variance) {
330        tracing::debug!(
331            "add_constraints_from_region(region={:?}, variance={:?})",
332            region,
333            variance
334        );
335        match region.kind() {
336            RegionKind::ReEarlyParam(param) => self.constrain(param.index as usize, variance),
337            RegionKind::ReStatic => {}
338            RegionKind::ReBound(..) => {
339                // Either a higher-ranked region inside of a type or a
340                // late-bound function parameter.
341                //
342                // We do not compute constraints for either of these.
343            }
344            RegionKind::ReError(_) => {}
345            RegionKind::ReLateParam(..)
346            | RegionKind::RePlaceholder(..)
347            | RegionKind::ReVar(..)
348            | RegionKind::ReErased => {
349                // We don't expect to see anything but 'static or bound
350                // regions when visiting member types or method types.
351                never!(
352                    "unexpected region encountered in variance \
353                      inference: {:?}",
354                    region
355                );
356            }
357        }
358    }
359
360    /// Adds constraints appropriate for a mutability-type pair
361    /// appearing in a context with ambient variance `variance`
362    fn add_constraints_from_mt(&mut self, ty: Ty<'db>, mt: Mutability, variance: Variance) {
363        self.add_constraints_from_ty(
364            ty,
365            match mt {
366                Mutability::Mut => Variance::Invariant,
367                Mutability::Not => variance,
368            },
369        );
370    }
371
372    fn constrain(&mut self, index: usize, variance: Variance) {
373        tracing::debug!(
374            "constrain(index={:?}, variance={:?}, to={:?})",
375            index,
376            self.variances[index],
377            variance
378        );
379        self.variances[index] = glb(self.variances[index], variance);
380    }
381}
382
#[cfg(test)]
mod tests {
    //! Snapshot tests for variance inference. Each fixture's functions and ADTs are
    //! printed as `name[param: variance, ...]` in source order; items with no
    //! generic parameters are omitted.

    use expect_test::{Expect, expect};
    use hir_def::{
        AdtId, GenericDefId, ModuleDefId, hir::generics::GenericParamDataRef, src::HasSource,
    };
    use itertools::Itertools;
    use rustc_type_ir::{Variance, inherent::SliceLike};
    use stdx::format_to;
    use syntax::{AstNode, ast::HasName};
    use test_fixture::WithFixture;

    use hir_def::Lookup;

    use crate::{db::HirDatabase, test_db::TestDB, variance::generics};

    #[test]
    fn phantom_data() {
        check(
            r#"
//- minicore: phantom_data

struct Covariant<A> {
    t: core::marker::PhantomData<A>
}
"#,
            expect![[r#"
                Covariant[A: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_types() {
        check(
            r#"
//- minicore: cell

use core::cell::UnsafeCell;

struct InvariantMut<'a,A:'a,B:'a> { //~ ERROR ['a: +, A: o, B: o]
    t: &'a mut (A,B)
}

struct InvariantCell<A> { //~ ERROR [A: o]
    t: UnsafeCell<A>
}

struct InvariantIndirect<A> { //~ ERROR [A: o]
    t: InvariantCell<A>
}

struct Covariant<A> { //~ ERROR [A: +]
    t: A, u: fn() -> A
}

struct Contravariant<A> { //~ ERROR [A: -]
    t: fn(A)
}

enum Enum<A,B,C> { //~ ERROR [A: +, B: -, C: o]
    Foo(Covariant<A>),
    Bar(Contravariant<B>),
    Zed(Covariant<C>,Contravariant<C>)
}
"#,
            expect![[r#"
                InvariantMut['a: covariant, A: invariant, B: invariant]
                InvariantCell[A: invariant]
                InvariantIndirect[A: invariant]
                Covariant[A: covariant]
                Contravariant[A: contravariant]
                Enum[A: covariant, B: contravariant, C: invariant]
            "#]],
        );
    }

    #[test]
    fn type_resolve_error_two_structs_deep() {
        check(
            r#"
struct Hello<'a> {
    missing: Missing<'a>,
}

struct Other<'a> {
    hello: Hello<'a>,
}
"#,
            expect![[r#"
                Hello['a: invariant]
                Other['a: invariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_associated_consts() {
        // FIXME: `T` is reported invariant (matching rustc), but this may only be due to the
        // bivariant-to-invariant fallback in `variances_of` rather than the associated const
        // actually constraining `T`; verify once the fallback is removed.
        check(
            r#"
trait Trait {
    const Const: usize;
}

struct Foo<T: Trait> { //~ ERROR [T: o]
    field: [u8; <T as Trait>::Const]
}
"#,
            expect![[r#"
                Foo[T: invariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_associated_types() {
        check(
            r#"
trait Trait<'a> {
    type Type;

    fn method(&'a self) { }
}

struct Foo<'a, T : Trait<'a>> { //~ ERROR ['a: +, T: +]
    field: (T, &'a ())
}

struct Bar<'a, T : Trait<'a>> { //~ ERROR ['a: o, T: o]
    field: <T as Trait<'a>>::Type
}

"#,
            expect![[r#"
                method[Self: contravariant, 'a: contravariant]
                Foo['a: covariant, T: covariant]
                Bar['a: invariant, T: invariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_associated_types2() {
        // FIXME: RPITs have variance, but we can't treat them as their own thing right now
        check(
            r#"
trait Foo {
    type Bar;
}

fn make() -> *const dyn Foo<Bar = &'static u32> {}
"#,
            expect![""],
        );
    }

    #[test]
    fn rustc_test_variance_trait_bounds() {
        check(
            r#"
trait Getter<T> {
    fn get(&self) -> T;
}

trait Setter<T> {
    fn get(&self, _: T);
}

struct TestStruct<U,T:Setter<U>> { //~ ERROR [U: +, T: +]
    t: T, u: U
}

enum TestEnum<U,T:Setter<U>> { //~ ERROR [U: *, T: +]
    //~^ ERROR: `U` is never used
    Foo(T)
}

struct TestContraStruct<U,T:Setter<U>> { //~ ERROR [U: *, T: +]
    //~^ ERROR: `U` is never used
    t: T
}

struct TestBox<U,T:Getter<U>+Setter<U>> { //~ ERROR [U: *, T: +]
    //~^ ERROR: `U` is never used
    t: T
}
"#,
            expect![[r#"
                get[Self: contravariant, T: covariant]
                get[Self: contravariant, T: contravariant]
                TestStruct[U: covariant, T: covariant]
                TestEnum[U: invariant, T: covariant]
                TestContraStruct[U: invariant, T: covariant]
                TestBox[U: invariant, T: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_trait_matching() {
        check(
            r#"

trait Get<T> {
    fn get(&self) -> T;
}

struct Cloner<T:Clone> {
    t: T
}

impl<T:Clone> Get<T> for Cloner<T> {
    fn get(&self) -> T {}
}

fn get<'a, G>(get: &G) -> i32
    where G : Get<&'a i32>
{}

fn pick<'b, G>(get: &'b G, if_odd: &'b i32) -> i32
    where G : Get<&'b i32>
{}
"#,
            expect![[r#"
                get[Self: contravariant, T: covariant]
                Cloner[T: covariant]
                get[T: invariant]
                get['a: invariant, G: contravariant]
                pick['b: contravariant, G: contravariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_trait_object_bound() {
        check(
            r#"
enum Option<T> {
    Some(T),
    None
}
trait T { fn foo(&self); }

struct TOption<'a> { //~ ERROR ['a: +]
    v: Option<*const (dyn T + 'a)>,
}
"#,
            expect![[r#"
                Option[T: covariant]
                foo[Self: contravariant]
                TOption['a: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_types_bounds() {
        check(
            r#"
//- minicore: send
struct TestImm<A, B> { //~ ERROR [A: +, B: +]
    x: A,
    y: B,
}

struct TestMut<A, B:'static> { //~ ERROR [A: +, B: o]
    x: A,
    y: &'static mut B,
}

struct TestIndirect<A:'static, B:'static> { //~ ERROR [A: +, B: o]
    m: TestMut<A, B>
}

struct TestIndirect2<A:'static, B:'static> { //~ ERROR [A: o, B: o]
    n: TestMut<A, B>,
    m: TestMut<B, A>
}

trait Getter<A> {
    fn get(&self) -> A;
}

trait Setter<A> {
    fn set(&mut self, a: A);
}

struct TestObject<A, R> { //~ ERROR [A: o, R: o]
    n: *const (dyn Setter<A> + Send),
    m: *const (dyn Getter<R> + Send),
}
"#,
            expect![[r#"
                TestImm[A: covariant, B: covariant]
                TestMut[A: covariant, B: invariant]
                TestIndirect[A: covariant, B: invariant]
                TestIndirect2[A: invariant, B: invariant]
                get[Self: contravariant, A: covariant]
                set[Self: invariant, A: contravariant]
                TestObject[A: invariant, R: invariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_unused_region_param() {
        check(
            r#"
struct SomeStruct<'a> { x: u32 } //~ ERROR parameter `'a` is never used
enum SomeEnum<'a> { Nothing } //~ ERROR parameter `'a` is never used
trait SomeTrait<'a> { fn foo(&self); } // OK on traits.
"#,
            expect![[r#"
                SomeStruct['a: invariant]
                SomeEnum['a: invariant]
                foo[Self: contravariant, 'a: invariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_unused_type_param() {
        check(
            r#"
//- minicore: sized
struct SomeStruct<A> { x: u32 }
enum SomeEnum<A> { Nothing }
enum ListCell<T> {
    Cons(*const ListCell<T>),
    Nil
}

struct SelfTyAlias<T>(*const Self);
struct WithBounds<T: Sized> {}
struct WithWhereBounds<T> where T: Sized {}
struct WithOutlivesBounds<T: 'static> {}
struct DoubleNothing<T> {
    s: SomeStruct<T>,
}

"#,
            expect![[r#"
                SomeStruct[A: invariant]
                SomeEnum[A: invariant]
                ListCell[T: invariant]
                SelfTyAlias[T: invariant]
                WithBounds[T: invariant]
                WithWhereBounds[T: invariant]
                WithOutlivesBounds[T: invariant]
                DoubleNothing[T: invariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_contravariant_struct1() {
        check(
            r#"
struct SomeStruct<T>(fn(T));

fn foo<'min,'max>(v: SomeStruct<&'max ()>)
                  -> SomeStruct<&'min ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: contravariant]
                foo['min: contravariant, 'max: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_contravariant_struct2() {
        check(
            r#"
struct SomeStruct<T>(fn(T));

fn bar<'min,'max>(v: SomeStruct<&'min ()>)
                  -> SomeStruct<&'max ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: contravariant]
                bar['min: covariant, 'max: contravariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_covariant_struct1() {
        check(
            r#"
struct SomeStruct<T>(T);

fn foo<'min,'max>(v: SomeStruct<&'min ()>)
                  -> SomeStruct<&'max ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: covariant]
                foo['min: contravariant, 'max: covariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_covariant_struct2() {
        check(
            r#"
struct SomeStruct<T>(T);

fn foo<'min,'max>(v: SomeStruct<&'max ()>)
                  -> SomeStruct<&'min ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: covariant]
                foo['min: covariant, 'max: contravariant]
            "#]],
        );
    }

    #[test]
    fn rustc_test_variance_use_invariant_struct1() {
        check(
            r#"
struct SomeStruct<T>(*mut T);

fn foo<'min,'max>(v: SomeStruct<&'max ()>)
                  -> SomeStruct<&'min ()>
    where 'max : 'min
{}

fn bar<'min,'max>(v: SomeStruct<&'min ()>)
                  -> SomeStruct<&'max ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: invariant]
                foo['min: invariant, 'max: invariant]
                bar['min: invariant, 'max: invariant]
            "#]],
        );
    }

    #[test]
    fn invalid_arg_counts() {
        check(
            r#"
struct S<T>(T);
struct S2<T>(S<>);
struct S3<T>(S<T, T>);
"#,
            expect![[r#"
                S[T: covariant]
                S2[T: invariant]
                S3[T: covariant]
            "#]],
        );
    }

    #[test]
    fn prove_fixedpoint() {
        check(
            r#"
struct FixedPoint<T, U, V>(&'static FixedPoint<(), T, U>, V);
"#,
            expect![[r#"
                FixedPoint[T: invariant, U: invariant, V: invariant]
            "#]],
        );
    }

    /// Computes `variances_of` for every function and ADT in `ra_fixture` and
    /// compares the rendered result (sorted by source position) against `expected`.
    #[track_caller]
    fn check(#[rust_analyzer::rust_fixture] ra_fixture: &str, expected: Expect) {
        // Uncomment to get tracing output from this module while debugging a test:
        // use tracing_subscriber::{layer::SubscriberExt, Layer};
        // let my_layer = tracing_subscriber::fmt::layer();
        // let _g = tracing::subscriber::set_default(tracing_subscriber::registry().with(
        //     my_layer.with_filter(tracing_subscriber::filter::filter_fn(|metadata| {
        //         metadata.target().starts_with("hir_ty::variance")
        //     })),
        // ));
        let (db, file_id) = TestDB::with_single_file(ra_fixture);

        crate::attach_db(&db, || {
            let mut defs: Vec<GenericDefId> = Vec::new();
            let module = db.module_for_file_opt(file_id.file_id(&db)).unwrap();
            let def_map = module.def_map(&db);
            crate::tests::visit_module(&db, def_map, module, &mut |it| {
                defs.push(match it {
                    ModuleDefId::FunctionId(it) => it.into(),
                    ModuleDefId::AdtId(it) => it.into(),
                    ModuleDefId::ConstId(it) => it.into(),
                    ModuleDefId::TraitId(it) => it.into(),
                    ModuleDefId::TypeAliasId(it) => it.into(),
                    _ => return,
                })
            });
            let defs = defs
                .into_iter()
                .filter_map(|def| {
                    Some((
                        def,
                        match def {
                            GenericDefId::FunctionId(it) => {
                                let loc = it.lookup(&db);
                                loc.source(&db).value.name().unwrap()
                            }
                            GenericDefId::AdtId(AdtId::EnumId(it)) => {
                                let loc = it.lookup(&db);
                                loc.source(&db).value.name().unwrap()
                            }
                            GenericDefId::AdtId(AdtId::StructId(it)) => {
                                let loc = it.lookup(&db);
                                loc.source(&db).value.name().unwrap()
                            }
                            GenericDefId::AdtId(AdtId::UnionId(it)) => {
                                let loc = it.lookup(&db);
                                loc.source(&db).value.name().unwrap()
                            }
                            // Only functions and ADTs have variances worth printing.
                            GenericDefId::TraitId(_)
                            | GenericDefId::TypeAliasId(_)
                            | GenericDefId::ImplId(_)
                            | GenericDefId::ConstId(_)
                            | GenericDefId::StaticId(_) => return None,
                        },
                    ))
                })
                .sorted_by_key(|(_, n)| n.syntax().text_range().start());
            let mut res = String::new();
            for (def, name) in defs {
                let variances = db.variances_of(def);
                if variances.is_empty() {
                    continue;
                }
                format_to!(
                    res,
                    "{name}[{}]\n",
                    generics(&db, def)
                        .iter()
                        .map(|(_, param)| match param {
                            GenericParamDataRef::TypeParamData(type_param_data) => {
                                type_param_data.name.as_ref().unwrap()
                            }
                            GenericParamDataRef::ConstParamData(const_param_data) =>
                                &const_param_data.name,
                            GenericParamDataRef::LifetimeParamData(lifetime_param_data) => {
                                &lifetime_param_data.name
                            }
                        })
                        .zip_eq(variances)
                        .format_with(", ", |(name, var), f| f(&format_args!(
                            "{}: {}",
                            name.as_str(),
                            match var {
                                Variance::Covariant => "covariant",
                                Variance::Invariant => "invariant",
                                Variance::Contravariant => "contravariant",
                                Variance::Bivariant => "bivariant",
                            },
                        )))
                );
            }

            expected.assert_eq(&res);
        })
    }
}