hir_ty/
variance.rs

//! Module for inferring the variance of type and lifetime parameters. See the [rustc dev guide]
//! chapter for more info.
//!
//! [rustc dev guide]: https://rustc-dev-guide.rust-lang.org/variance.html
//!
//! The implementation here differs from rustc. Rustc does a crate-wide fixpoint resolution, as
//! the algorithm for determining variance is a fixpoint computation with potential cycles that
//! need to be resolved. rust-analyzer does not want a crate-wide analysis though, as that would
//! hurt incrementality too much; as such, our query works on a per-item basis.
//!
//! Unfortunately, this means we can run into query cycles, which salsa currently does not allow
//! to be resolved via a fixpoint computation. This will likely be resolved by the next salsa
//! version. If not, we will likely have to adapt and go with the rustc approach while installing
//! per-item firewall queries to prevent invalidation issues.
15
16use crate::db::HirDatabase;
17use crate::generics::{Generics, generics};
18use crate::{
19    AliasTy, Const, ConstScalar, DynTyExt, GenericArg, GenericArgData, Interner, Lifetime,
20    LifetimeData, Ty, TyKind,
21};
22use chalk_ir::Mutability;
23use hir_def::signatures::StructFlags;
24use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId};
25use std::fmt;
26use std::ops::Not;
27use stdx::never;
28use triomphe::Arc;
29
30pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Arc<[Variance]>> {
31    tracing::debug!("variances_of(def={:?})", def);
32    match def {
33        GenericDefId::FunctionId(_) => (),
34        GenericDefId::AdtId(adt) => {
35            if let AdtId::StructId(id) = adt {
36                let flags = &db.struct_signature(id).flags;
37                if flags.contains(StructFlags::IS_UNSAFE_CELL) {
38                    return Some(Arc::from_iter(vec![Variance::Invariant; 1]));
39                } else if flags.contains(StructFlags::IS_PHANTOM_DATA) {
40                    return Some(Arc::from_iter(vec![Variance::Covariant; 1]));
41                }
42            }
43        }
44        _ => return None,
45    }
46
47    let generics = generics(db, def);
48    let count = generics.len();
49    if count == 0 {
50        return None;
51    }
52    let variances = Context { generics, variances: vec![Variance::Bivariant; count], db }.solve();
53
54    variances.is_empty().not().then(|| Arc::from_iter(variances))
55}
56
57// pub(crate) fn variances_of_cycle_fn(
58//     _db: &dyn HirDatabase,
59//     _result: &Option<Arc<[Variance]>>,
60//     _count: u32,
61//     _def: GenericDefId,
62// ) -> salsa::CycleRecoveryAction<Option<Arc<[Variance]>>> {
63//     salsa::CycleRecoveryAction::Iterate
64// }
65
66pub(crate) fn variances_of_cycle_initial(
67    db: &dyn HirDatabase,
68    def: GenericDefId,
69) -> Option<Arc<[Variance]>> {
70    let generics = generics(db, def);
71    let count = generics.len();
72
73    if count == 0 {
74        return None;
75    }
76    Some(Arc::from(vec![Variance::Bivariant; count]))
77}
78
/// The variance of a generic parameter: how subtyping between two instantiations of an item
/// relates to subtyping between the arguments substituted for that parameter.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Variance {
    /// `T<A> <: T<B>` iff `A <: B` -- e.g., function return type.
    Covariant,
    /// `T<A> <: T<B>` iff `B == A` -- e.g., type of mutable cell.
    Invariant,
    /// `T<A> <: T<B>` iff `B <: A` -- e.g., function param type.
    Contravariant,
    /// `T<A> <: T<B>` regardless of `A` and `B` -- e.g., unused type parameter.
    Bivariant,
}

impl fmt::Display for Variance {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Variance::Covariant => "covariant",
            Variance::Invariant => "invariant",
            Variance::Contravariant => "contravariant",
            Variance::Bivariant => "bivariant",
        })
    }
}

impl Variance {
    /// `a.xform(b)` combines the variance of a context with the
    /// variance of a type with the following meaning. If we are in a
    /// context with variance `a`, and we encounter a type argument in
    /// a position with variance `b`, then `a.xform(b)` is the new
    /// variance with which the argument appears.
    ///
    /// Example 1:
    /// ```ignore (illustrative)
    /// *mut Vec<i32>
    /// ```
    /// Here, the "ambient" variance starts as covariant. `*mut T` is
    /// invariant with respect to `T`, so the variance in which the
    /// `Vec<i32>` appears is `Covariant.xform(Invariant)`, which
    /// yields `Invariant`. Now, the type `Vec<T>` is covariant with
    /// respect to its type argument `T`, and hence the variance of
    /// the `i32` here is `Invariant.xform(Covariant)`, which results
    /// (again) in `Invariant`.
    ///
    /// Example 2:
    /// ```ignore (illustrative)
    /// fn(*const Vec<i32>, *mut Vec<i32>)
    /// ```
    /// The ambient variance is covariant. A `fn` type is
    /// contravariant with respect to its parameters, so the variance
    /// within which both pointer types appear is
    /// `Covariant.xform(Contravariant)`, or `Contravariant`. `*const
    /// T` is covariant with respect to `T`, so the variance within
    /// which the first `Vec<i32>` appears is
    /// `Contravariant.xform(Covariant)` or `Contravariant`. The same
    /// is true for its `i32` argument. In the `*mut T` case, the
    /// variance of `Vec<i32>` is `Contravariant.xform(Invariant)`,
    /// and hence the outermost type is `Invariant` with respect to
    /// `Vec<i32>` (and its `i32` argument).
    ///
    /// Source: Figure 1 of "Taming the Wildcards:
    /// Combining Definition- and Use-Site Variance" published in PLDI'11.
    fn xform(self, v: Variance) -> Variance {
        match (self, v) {
            // Figure 1, column 1.
            (Variance::Covariant, Variance::Covariant) => Variance::Covariant,
            (Variance::Covariant, Variance::Contravariant) => Variance::Contravariant,
            (Variance::Covariant, Variance::Invariant) => Variance::Invariant,
            (Variance::Covariant, Variance::Bivariant) => Variance::Bivariant,

            // Figure 1, column 2.
            (Variance::Contravariant, Variance::Covariant) => Variance::Contravariant,
            (Variance::Contravariant, Variance::Contravariant) => Variance::Covariant,
            (Variance::Contravariant, Variance::Invariant) => Variance::Invariant,
            (Variance::Contravariant, Variance::Bivariant) => Variance::Bivariant,

            // Figure 1, column 3.
            (Variance::Invariant, _) => Variance::Invariant,

            // Figure 1, column 4.
            (Variance::Bivariant, _) => Variance::Bivariant,
        }
    }

    /// Greatest lower bound of `self` and `v` in the variance lattice
    /// defined in the paper cited on `xform`:
    ///
    /// ```text
    ///       *
    ///    -     +
    ///       o
    /// ```
    ///
    /// where `*` is bivariant, `+` covariant, `-` contravariant and `o`
    /// invariant. Combining two constraints on the same parameter yields
    /// their `glb`, so `Bivariant` is the identity and `Invariant`
    /// absorbs everything.
    fn glb(self, v: Variance) -> Variance {
        match (self, v) {
            (Variance::Invariant, _) | (_, Variance::Invariant) => Variance::Invariant,

            (Variance::Covariant, Variance::Contravariant) => Variance::Invariant,
            (Variance::Contravariant, Variance::Covariant) => Variance::Invariant,

            (Variance::Covariant, Variance::Covariant) => Variance::Covariant,

            (Variance::Contravariant, Variance::Contravariant) => Variance::Contravariant,

            (x, Variance::Bivariant) | (Variance::Bivariant, x) => x,
        }
    }

    /// The variance of an invariant position inside a context of variance `self`
    /// (shorthand for `self.xform(Variance::Invariant)`).
    pub fn invariant(self) -> Self {
        self.xform(Variance::Invariant)
    }

    /// The variance of a covariant position inside a context of variance `self`
    /// (shorthand for `self.xform(Variance::Covariant)`).
    pub fn covariant(self) -> Self {
        self.xform(Variance::Covariant)
    }

    /// The variance of a contravariant position inside a context of variance `self`
    /// (shorthand for `self.xform(Variance::Contravariant)`).
    pub fn contravariant(self) -> Self {
        self.xform(Variance::Contravariant)
    }
}
190
/// State for inferring the variances of a single item's generic parameters.
struct Context<'db> {
    db: &'db dyn HirDatabase,
    /// The generic parameters of the item whose variances are being computed.
    generics: Generics,
    /// One entry per generic parameter, indexed like `generics`; each entry is lowered via
    /// `Variance::glb` as constraints are discovered (see `constrain`).
    variances: Vec<Variance>,
}
196
197impl Context<'_> {
198    fn solve(mut self) -> Vec<Variance> {
199        tracing::debug!("solve(generics={:?})", self.generics);
200        match self.generics.def() {
201            GenericDefId::AdtId(adt) => {
202                let db = self.db;
203                let mut add_constraints_from_variant = |variant| {
204                    let subst = self.generics.placeholder_subst(db);
205                    for (_, field) in db.field_types(variant).iter() {
206                        self.add_constraints_from_ty(
207                            &field.clone().substitute(Interner, &subst),
208                            Variance::Covariant,
209                        );
210                    }
211                };
212                match adt {
213                    AdtId::StructId(s) => add_constraints_from_variant(VariantId::StructId(s)),
214                    AdtId::UnionId(u) => add_constraints_from_variant(VariantId::UnionId(u)),
215                    AdtId::EnumId(e) => {
216                        e.enum_variants(db).variants.iter().for_each(|&(variant, _, _)| {
217                            add_constraints_from_variant(VariantId::EnumVariantId(variant))
218                        });
219                    }
220                }
221            }
222            GenericDefId::FunctionId(f) => {
223                let subst = self.generics.placeholder_subst(self.db);
224                self.add_constraints_from_sig(
225                    self.db
226                        .callable_item_signature(f.into())
227                        .substitute(Interner, &subst)
228                        .params_and_return
229                        .iter(),
230                    Variance::Covariant,
231                );
232            }
233            _ => {}
234        }
235        let mut variances = self.variances;
236
237        // Const parameters are always invariant.
238        // Make all const parameters invariant.
239        for (idx, param) in self.generics.iter_id().enumerate() {
240            if let GenericParamId::ConstParamId(_) = param {
241                variances[idx] = Variance::Invariant;
242            }
243        }
244
245        // Functions are permitted to have unused generic parameters: make those invariant.
246        if let GenericDefId::FunctionId(_) = self.generics.def() {
247            variances
248                .iter_mut()
249                .filter(|&&mut v| v == Variance::Bivariant)
250                .for_each(|v| *v = Variance::Invariant);
251        }
252
253        variances
254    }
255
256    /// Adds constraints appropriate for an instance of `ty` appearing
257    /// in a context with the generics defined in `generics` and
258    /// ambient variance `variance`
259    fn add_constraints_from_ty(&mut self, ty: &Ty, variance: Variance) {
260        tracing::debug!("add_constraints_from_ty(ty={:?}, variance={:?})", ty, variance);
261        match ty.kind(Interner) {
262            TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => {
263                // leaf type -- noop
264            }
265            TyKind::FnDef(..) | TyKind::Coroutine(..) | TyKind::Closure(..) => {
266                never!("Unexpected unnameable type in variance computation: {:?}", ty);
267            }
268            TyKind::Ref(mutbl, lifetime, ty) => {
269                self.add_constraints_from_region(lifetime, variance);
270                self.add_constraints_from_mt(ty, *mutbl, variance);
271            }
272            TyKind::Array(typ, len) => {
273                self.add_constraints_from_const(len, variance);
274                self.add_constraints_from_ty(typ, variance);
275            }
276            TyKind::Slice(typ) => {
277                self.add_constraints_from_ty(typ, variance);
278            }
279            TyKind::Raw(mutbl, ty) => {
280                self.add_constraints_from_mt(ty, *mutbl, variance);
281            }
282            TyKind::Tuple(_, subtys) => {
283                for subty in subtys.type_parameters(Interner) {
284                    self.add_constraints_from_ty(&subty, variance);
285                }
286            }
287            TyKind::Adt(def, args) => {
288                self.add_constraints_from_args(def.0.into(), args.as_slice(Interner), variance);
289            }
290            TyKind::Alias(AliasTy::Opaque(opaque)) => {
291                self.add_constraints_from_invariant_args(
292                    opaque.substitution.as_slice(Interner),
293                    variance,
294                );
295            }
296            TyKind::Alias(AliasTy::Projection(proj)) => {
297                self.add_constraints_from_invariant_args(
298                    proj.substitution.as_slice(Interner),
299                    variance,
300                );
301            }
302            // FIXME: check this
303            TyKind::AssociatedType(_, subst) => {
304                self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
305            }
306            // FIXME: check this
307            TyKind::OpaqueType(_, subst) => {
308                self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
309            }
310            TyKind::Dyn(it) => {
311                // The type `dyn Trait<T> +'a` is covariant w/r/t `'a`:
312                self.add_constraints_from_region(&it.lifetime, variance);
313
314                if let Some(trait_ref) = it.principal() {
315                    // Trait are always invariant so we can take advantage of that.
316                    self.add_constraints_from_invariant_args(
317                        trait_ref
318                            .map(|it| it.map(|it| it.substitution.clone()))
319                            .substitute(
320                                Interner,
321                                &[GenericArg::new(
322                                    Interner,
323                                    chalk_ir::GenericArgData::Ty(TyKind::Error.intern(Interner)),
324                                )],
325                            )
326                            .skip_binders()
327                            .as_slice(Interner),
328                        variance,
329                    );
330                }
331
332                // FIXME
333                // for projection in data.projection_bounds() {
334                //     match projection.skip_binder().term.unpack() {
335                //         TyKind::TermKind::Ty(ty) => {
336                //             self.add_constraints_from_ty( ty, self.invariant);
337                //         }
338                //         TyKind::TermKind::Const(c) => {
339                //             self.add_constraints_from_const( c, self.invariant)
340                //         }
341                //     }
342                // }
343            }
344
345            // Chalk has no params, so use placeholders for now?
346            TyKind::Placeholder(index) => {
347                let idx = crate::from_placeholder_idx(self.db, *index).0;
348                let index = self.generics.type_or_const_param_idx(idx).unwrap();
349                self.constrain(index, variance);
350            }
351            TyKind::Function(f) => {
352                self.add_constraints_from_sig(
353                    f.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)),
354                    variance,
355                );
356            }
357            TyKind::Error => {
358                // we encounter this when walking the trait references for object
359                // types, where we use Error as the Self type
360            }
361            TyKind::CoroutineWitness(..) | TyKind::BoundVar(..) | TyKind::InferenceVar(..) => {
362                never!("unexpected type encountered in variance inference: {:?}", ty)
363            }
364        }
365    }
366
367    fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) {
368        let variance_i = variance.invariant();
369
370        for k in args {
371            match k.data(Interner) {
372                GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i),
373                GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i),
374                GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i),
375            }
376        }
377    }
378
379    /// Adds constraints appropriate for a nominal type (enum, struct,
380    /// object, etc) appearing in a context with ambient variance `variance`
381    fn add_constraints_from_args(
382        &mut self,
383        def_id: GenericDefId,
384        args: &[GenericArg],
385        variance: Variance,
386    ) {
387        // We don't record `inferred_starts` entries for empty generics.
388        if args.is_empty() {
389            return;
390        }
391        let Some(variances) = self.db.variances_of(def_id) else {
392            return;
393        };
394
395        for (i, k) in args.iter().enumerate() {
396            match k.data(Interner) {
397                GenericArgData::Lifetime(lt) => {
398                    self.add_constraints_from_region(lt, variance.xform(variances[i]))
399                }
400                GenericArgData::Ty(ty) => {
401                    self.add_constraints_from_ty(ty, variance.xform(variances[i]))
402                }
403                GenericArgData::Const(val) => self.add_constraints_from_const(val, variance),
404            }
405        }
406    }
407
408    /// Adds constraints appropriate for a const expression `val`
409    /// in a context with ambient variance `variance`
410    fn add_constraints_from_const(&mut self, c: &Const, variance: Variance) {
411        match &c.data(Interner).value {
412            chalk_ir::ConstValue::Concrete(c) => {
413                if let ConstScalar::UnevaluatedConst(_, subst) = &c.interned {
414                    self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
415                }
416            }
417            _ => {}
418        }
419    }
420
421    /// Adds constraints appropriate for a function with signature
422    /// `sig` appearing in a context with ambient variance `variance`
423    fn add_constraints_from_sig<'a>(
424        &mut self,
425        mut sig_tys: impl DoubleEndedIterator<Item = &'a Ty>,
426        variance: Variance,
427    ) {
428        let contra = variance.contravariant();
429        let Some(output) = sig_tys.next_back() else {
430            return never!("function signature has no return type");
431        };
432        self.add_constraints_from_ty(output, variance);
433        for input in sig_tys {
434            self.add_constraints_from_ty(input, contra);
435        }
436    }
437
438    /// Adds constraints appropriate for a region appearing in a
439    /// context with ambient variance `variance`
440    fn add_constraints_from_region(&mut self, region: &Lifetime, variance: Variance) {
441        tracing::debug!(
442            "add_constraints_from_region(region={:?}, variance={:?})",
443            region,
444            variance
445        );
446        match region.data(Interner) {
447            LifetimeData::Placeholder(index) => {
448                let idx = crate::lt_from_placeholder_idx(self.db, *index).0;
449                let inferred = self.generics.lifetime_idx(idx).unwrap();
450                self.constrain(inferred, variance);
451            }
452            LifetimeData::Static => {}
453            LifetimeData::BoundVar(..) => {
454                // Either a higher-ranked region inside of a type or a
455                // late-bound function parameter.
456                //
457                // We do not compute constraints for either of these.
458            }
459            LifetimeData::Error => {}
460            LifetimeData::Phantom(..) | LifetimeData::InferenceVar(..) | LifetimeData::Erased => {
461                // We don't expect to see anything but 'static or bound
462                // regions when visiting member types or method types.
463                never!(
464                    "unexpected region encountered in variance \
465                      inference: {:?}",
466                    region
467                );
468            }
469        }
470    }
471
472    /// Adds constraints appropriate for a mutability-type pair
473    /// appearing in a context with ambient variance `variance`
474    fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: Variance) {
475        self.add_constraints_from_ty(
476            ty,
477            match mt {
478                Mutability::Mut => variance.invariant(),
479                Mutability::Not => variance,
480            },
481        );
482    }
483
484    fn constrain(&mut self, index: usize, variance: Variance) {
485        tracing::debug!(
486            "constrain(index={:?}, variance={:?}, to={:?})",
487            index,
488            self.variances[index],
489            variance
490        );
491        self.variances[index] = self.variances[index].glb(variance);
492    }
493}
494
495#[cfg(test)]
496mod tests {
497    use expect_test::{Expect, expect};
498    use hir_def::{
499        AdtId, GenericDefId, ModuleDefId, hir::generics::GenericParamDataRef, src::HasSource,
500    };
501    use itertools::Itertools;
502    use stdx::format_to;
503    use syntax::{AstNode, ast::HasName};
504    use test_fixture::WithFixture;
505
506    use hir_def::Lookup;
507
508    use crate::{db::HirDatabase, test_db::TestDB, variance::generics};
509
    /// A struct whose only field is `PhantomData<A>` is covariant in `A`.
    #[test]
    fn phantom_data() {
        check(
            r#"
//- minicore: phantom_data

struct Covariant<A> {
    t: core::marker::PhantomData<A>
}
"#,
            expect![[r#"
                Covariant[A: covariant]
            "#]],
        );
    }
525
526    #[test]
527    fn rustc_test_variance_types() {
528        check(
529            r#"
530//- minicore: cell
531
532use core::cell::UnsafeCell;
533
534struct InvariantMut<'a,A:'a,B:'a> { //~ ERROR ['a: +, A: o, B: o]
535    t: &'a mut (A,B)
536}
537
538struct InvariantCell<A> { //~ ERROR [A: o]
539    t: UnsafeCell<A>
540}
541
542struct InvariantIndirect<A> { //~ ERROR [A: o]
543    t: InvariantCell<A>
544}
545
546struct Covariant<A> { //~ ERROR [A: +]
547    t: A, u: fn() -> A
548}
549
550struct Contravariant<A> { //~ ERROR [A: -]
551    t: fn(A)
552}
553
554enum Enum<A,B,C> { //~ ERROR [A: +, B: -, C: o]
555    Foo(Covariant<A>),
556    Bar(Contravariant<B>),`
557    Zed(Covariant<C>,Contravariant<C>)
558}
559"#,
560            expect![[r#"
561                InvariantMut['a: covariant, A: invariant, B: invariant]
562                InvariantCell[A: invariant]
563                InvariantIndirect[A: invariant]
564                Covariant[A: covariant]
565                Contravariant[A: contravariant]
566                Enum[A: covariant, B: contravariant, C: invariant]
567            "#]],
568        );
569    }
570
    /// A field of an unresolved type (`Missing<'a>`) leaves `'a` bivariant, and that result
    /// carries over to `Other` through its `Hello<'a>` field.
    #[test]
    fn type_resolve_error_two_structs_deep() {
        check(
            r#"
struct Hello<'a> {
    missing: Missing<'a>,
}

struct Other<'a> {
    hello: Hello<'a>,
}
"#,
            expect![[r#"
                Hello['a: bivariant]
                Other['a: bivariant]
            "#]],
        );
    }
589
    /// `T` is used only through an associated const in an array length. rustc expects it to be
    /// invariant (`//~ ERROR [T: o]`); we currently report it as bivariant.
    #[test]
    fn rustc_test_variance_associated_consts() {
        // FIXME: Should be invariant
        check(
            r#"
trait Trait {
    const Const: usize;
}

struct Foo<T: Trait> { //~ ERROR [T: o]
    field: [u8; <T as Trait>::Const]
}
"#,
            expect![[r#"
                Foo[T: bivariant]
            "#]],
        );
    }
608
    /// Direct uses of `'a`/`T` are covariant, while the projection `<T as Trait<'a>>::Type`
    /// makes both invariant. The trait method `method` is reported as well.
    #[test]
    fn rustc_test_variance_associated_types() {
        check(
            r#"
trait Trait<'a> {
    type Type;

    fn method(&'a self) { }
}

struct Foo<'a, T : Trait<'a>> { //~ ERROR ['a: +, T: +]
    field: (T, &'a ())
}

struct Bar<'a, T : Trait<'a>> { //~ ERROR ['a: o, T: o]
    field: <T as Trait<'a>>::Type
}

"#,
            expect![[r#"
                method[Self: contravariant, 'a: contravariant]
                Foo['a: covariant, T: covariant]
                Bar['a: invariant, T: invariant]
            "#]],
        );
    }
635
    /// A function returning a raw pointer to a trait object with an associated-type binding.
    /// `make` has no generic parameters, so nothing is reported.
    #[test]
    fn rustc_test_variance_associated_types2() {
        // FIXME: RPITs have variance, but we can't treat them as their own thing right now
        check(
            r#"
trait Foo {
    type Bar;
}

fn make() -> *const dyn Foo<Bar = &'static u32> {}
"#,
            expect![""],
        );
    }
650
    /// Trait bounds such as `T: Setter<U>` do not constrain `U`; only field types do, so a `U`
    /// that appears only in bounds stays bivariant.
    #[test]
    fn rustc_test_variance_trait_bounds() {
        check(
            r#"
trait Getter<T> {
    fn get(&self) -> T;
}

trait Setter<T> {
    fn get(&self, _: T);
}

struct TestStruct<U,T:Setter<U>> { //~ ERROR [U: +, T: +]
    t: T, u: U
}

enum TestEnum<U,T:Setter<U>> { //~ ERROR [U: *, T: +]
    //~^ ERROR: `U` is never used
    Foo(T)
}

struct TestContraStruct<U,T:Setter<U>> { //~ ERROR [U: *, T: +]
    //~^ ERROR: `U` is never used
    t: T
}

struct TestBox<U,T:Getter<U>+Setter<U>> { //~ ERROR [U: *, T: +]
    //~^ ERROR: `U` is never used
    t: T
}
"#,
            expect![[r#"
                get[Self: contravariant, T: covariant]
                get[Self: contravariant, T: contravariant]
                TestStruct[U: covariant, T: covariant]
                TestEnum[U: bivariant, T: covariant]
                TestContraStruct[U: bivariant, T: covariant]
                TestBox[U: bivariant, T: covariant]
            "#]],
        );
    }
692
    /// Trait and impl methods plus free functions whose lifetimes appear in where-clauses.
    /// A function parameter that stays unconstrained (like `'a` in `get`) is reported as
    /// invariant.
    #[test]
    fn rustc_test_variance_trait_matching() {
        check(
            r#"

trait Get<T> {
    fn get(&self) -> T;
}

struct Cloner<T:Clone> {
    t: T
}

impl<T:Clone> Get<T> for Cloner<T> {
    fn get(&self) -> T {}
}

fn get<'a, G>(get: &G) -> i32
    where G : Get<&'a i32>
{}

fn pick<'b, G>(get: &'b G, if_odd: &'b i32) -> i32
    where G : Get<&'b i32>
{}
"#,
            expect![[r#"
                get[Self: contravariant, T: covariant]
                Cloner[T: covariant]
                get[T: invariant]
                get['a: invariant, G: contravariant]
                pick['b: contravariant, G: contravariant]
            "#]],
        );
    }
727
    /// The lifetime bound of a trait object (`dyn T + 'a`) behind `*const` is covariant.
    #[test]
    fn rustc_test_variance_trait_object_bound() {
        check(
            r#"
enum Option<T> {
    Some(T),
    None
}
trait T { fn foo(&self); }

struct TOption<'a> { //~ ERROR ['a: +]
    v: Option<*const (dyn T + 'a)>,
}
"#,
            expect![[r#"
                Option[T: covariant]
                foo[Self: contravariant]
                TOption['a: covariant]
            "#]],
        );
    }
749
    /// Outlives bounds (`B: 'static`) do not affect variance. `&'static mut B`, uses in positions
    /// of conflicting variance, and trait-object arguments make parameters invariant.
    #[test]
    fn rustc_test_variance_types_bounds() {
        check(
            r#"
//- minicore: send
struct TestImm<A, B> { //~ ERROR [A: +, B: +]
    x: A,
    y: B,
}

struct TestMut<A, B:'static> { //~ ERROR [A: +, B: o]
    x: A,
    y: &'static mut B,
}

struct TestIndirect<A:'static, B:'static> { //~ ERROR [A: +, B: o]
    m: TestMut<A, B>
}

struct TestIndirect2<A:'static, B:'static> { //~ ERROR [A: o, B: o]
    n: TestMut<A, B>,
    m: TestMut<B, A>
}

trait Getter<A> {
    fn get(&self) -> A;
}

trait Setter<A> {
    fn set(&mut self, a: A);
}

struct TestObject<A, R> { //~ ERROR [A: o, R: o]
    n: *const (dyn Setter<A> + Send),
    m: *const (dyn Getter<R> + Send),
}
"#,
            expect![[r#"
                TestImm[A: covariant, B: covariant]
                TestMut[A: covariant, B: invariant]
                TestIndirect[A: covariant, B: invariant]
                TestIndirect2[A: invariant, B: invariant]
                get[Self: contravariant, A: covariant]
                set[Self: invariant, A: contravariant]
                TestObject[A: invariant, R: invariant]
            "#]],
        );
    }
798
    /// Unused lifetime parameters of ADTs are bivariant. For the trait method `foo`, the
    /// trait's unused `'a` is reported as invariant, like any unconstrained function parameter.
    #[test]
    fn rustc_test_variance_unused_region_param() {
        check(
            r#"
struct SomeStruct<'a> { x: u32 } //~ ERROR parameter `'a` is never used
enum SomeEnum<'a> { Nothing } //~ ERROR parameter `'a` is never used
trait SomeTrait<'a> { fn foo(&self); } // OK on traits.
"#,
            expect![[r#"
                SomeStruct['a: bivariant]
                SomeEnum['a: bivariant]
                foo[Self: contravariant, 'a: invariant]
            "#]],
        );
    }
814
    /// Type parameters that never meaningfully appear in a field are bivariant. This includes
    /// parameters used only recursively, only via `Self`, only in bounds, or only as a
    /// bivariant argument.
    #[test]
    fn rustc_test_variance_unused_type_param() {
        check(
            r#"
//- minicore: sized
struct SomeStruct<A> { x: u32 }
enum SomeEnum<A> { Nothing }
enum ListCell<T> {
    Cons(*const ListCell<T>),
    Nil
}

struct SelfTyAlias<T>(*const Self);
struct WithBounds<T: Sized> {}
struct WithWhereBounds<T> where T: Sized {}
struct WithOutlivesBounds<T: 'static> {}
struct DoubleNothing<T> {
    s: SomeStruct<T>,
}

"#,
            expect![[r#"
                SomeStruct[A: bivariant]
                SomeEnum[A: bivariant]
                ListCell[T: bivariant]
                SelfTyAlias[T: bivariant]
                WithBounds[T: bivariant]
                WithWhereBounds[T: bivariant]
                WithOutlivesBounds[T: bivariant]
                DoubleNothing[T: bivariant]
            "#]],
        );
    }
848
    /// Lifetimes passed through a contravariant struct flip: the argument's `'max` becomes
    /// covariant and the return type's `'min` contravariant.
    #[test]
    fn rustc_test_variance_use_contravariant_struct1() {
        check(
            r#"
struct SomeStruct<T>(fn(T));

fn foo<'min,'max>(v: SomeStruct<&'max ()>)
                  -> SomeStruct<&'min ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: contravariant]
                foo['min: contravariant, 'max: covariant]
            "#]],
        );
    }
866
    /// Same as `rustc_test_variance_use_contravariant_struct1`, with the lifetimes swapped
    /// between argument and return type.
    #[test]
    fn rustc_test_variance_use_contravariant_struct2() {
        check(
            r#"
struct SomeStruct<T>(fn(T));

fn bar<'min,'max>(v: SomeStruct<&'min ()>)
                  -> SomeStruct<&'max ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: contravariant]
                bar['min: covariant, 'max: contravariant]
            "#]],
        );
    }
884
    /// Lifetimes passed through a covariant struct keep the usual function variances: the
    /// argument's `'min` is contravariant and the return type's `'max` covariant.
    #[test]
    fn rustc_test_variance_use_covariant_struct1() {
        check(
            r#"
struct SomeStruct<T>(T);

fn foo<'min,'max>(v: SomeStruct<&'min ()>)
                  -> SomeStruct<&'max ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: covariant]
                foo['min: contravariant, 'max: covariant]
            "#]],
        );
    }
902
    /// Same as `rustc_test_variance_use_covariant_struct1`, with the lifetimes swapped between
    /// argument and return type.
    #[test]
    fn rustc_test_variance_use_covariant_struct2() {
        check(
            r#"
struct SomeStruct<T>(T);

fn foo<'min,'max>(v: SomeStruct<&'max ()>)
                  -> SomeStruct<&'min ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: covariant]
                foo['min: covariant, 'max: contravariant]
            "#]],
        );
    }
920
    #[test]
    fn rustc_test_variance_use_invariant_struct1() {
        // `*mut T` makes `SomeStruct` invariant in `T`. Invariance absorbs whatever variance the
        // surrounding position has, so every lifetime in both `foo` and `bar` is invariant
        // regardless of whether it appears in parameter or return position.
        check(
            r#"
struct SomeStruct<T>(*mut T);

fn foo<'min,'max>(v: SomeStruct<&'max ()>)
                  -> SomeStruct<&'min ()>
    where 'max : 'min
{}

fn bar<'min,'max>(v: SomeStruct<&'min ()>)
                  -> SomeStruct<&'max ()>
    where 'max : 'min
{}
"#,
            expect![[r#"
                SomeStruct[T: invariant]
                foo['min: invariant, 'max: invariant]
                bar['min: invariant, 'max: invariant]
            "#]],
        );
    }
944
    #[test]
    fn invalid_arg_counts() {
        // Generic argument lists whose length does not match the parameter count must still
        // produce a result. `S<>` passes no argument, so `S2`'s `T` is unused and bivariant.
        // `S<T, T>` passes one argument too many; `S3`'s `T` comes out covariant, the same as
        // `S`'s single declared parameter.
        check(
            r#"
struct S<T>(T);
struct S2<T>(S<>);
struct S3<T>(S<T, T>);
"#,
            expect![[r#"
                S[T: covariant]
                S2[T: bivariant]
                S3[T: covariant]
            "#]],
        );
    }
960
961    #[test]
962    fn prove_fixedpoint() {
963        check(
964            r#"
965struct FixedPoint<T, U, V>(&'static FixedPoint<(), T, U>, V);
966"#,
967            expect![[r#"
968                FixedPoint[T: bivariant, U: bivariant, V: bivariant]
969            "#]],
970        );
971    }
972
973    #[track_caller]
974    fn check(#[rust_analyzer::rust_fixture] ra_fixture: &str, expected: Expect) {
975        // use tracing_subscriber::{layer::SubscriberExt, Layer};
976        // let my_layer = tracing_subscriber::fmt::layer();
977        // let _g = tracing::subscriber::set_default(tracing_subscriber::registry().with(
978        //     my_layer.with_filter(tracing_subscriber::filter::filter_fn(|metadata| {
979        //         metadata.target().starts_with("hir_ty::variance")
980        //     })),
981        // ));
982        let (db, file_id) = TestDB::with_single_file(ra_fixture);
983
984        let mut defs: Vec<GenericDefId> = Vec::new();
985        let module = db.module_for_file_opt(file_id.file_id(&db)).unwrap();
986        let def_map = module.def_map(&db);
987        crate::tests::visit_module(&db, def_map, module.local_id, &mut |it| {
988            defs.push(match it {
989                ModuleDefId::FunctionId(it) => it.into(),
990                ModuleDefId::AdtId(it) => it.into(),
991                ModuleDefId::ConstId(it) => it.into(),
992                ModuleDefId::TraitId(it) => it.into(),
993                ModuleDefId::TypeAliasId(it) => it.into(),
994                _ => return,
995            })
996        });
997        let defs = defs
998            .into_iter()
999            .filter_map(|def| {
1000                Some((
1001                    def,
1002                    match def {
1003                        GenericDefId::FunctionId(it) => {
1004                            let loc = it.lookup(&db);
1005                            loc.source(&db).value.name().unwrap()
1006                        }
1007                        GenericDefId::AdtId(AdtId::EnumId(it)) => {
1008                            let loc = it.lookup(&db);
1009                            loc.source(&db).value.name().unwrap()
1010                        }
1011                        GenericDefId::AdtId(AdtId::StructId(it)) => {
1012                            let loc = it.lookup(&db);
1013                            loc.source(&db).value.name().unwrap()
1014                        }
1015                        GenericDefId::AdtId(AdtId::UnionId(it)) => {
1016                            let loc = it.lookup(&db);
1017                            loc.source(&db).value.name().unwrap()
1018                        }
1019                        GenericDefId::TraitId(it) => {
1020                            let loc = it.lookup(&db);
1021                            loc.source(&db).value.name().unwrap()
1022                        }
1023                        GenericDefId::TypeAliasId(it) => {
1024                            let loc = it.lookup(&db);
1025                            loc.source(&db).value.name().unwrap()
1026                        }
1027                        GenericDefId::ImplId(_) => return None,
1028                        GenericDefId::ConstId(_) => return None,
1029                        GenericDefId::StaticId(_) => return None,
1030                    },
1031                ))
1032            })
1033            .sorted_by_key(|(_, n)| n.syntax().text_range().start());
1034        let mut res = String::new();
1035        for (def, name) in defs {
1036            let Some(variances) = db.variances_of(def) else {
1037                continue;
1038            };
1039            format_to!(
1040                res,
1041                "{name}[{}]\n",
1042                generics(&db, def)
1043                    .iter()
1044                    .map(|(_, param)| match param {
1045                        GenericParamDataRef::TypeParamData(type_param_data) => {
1046                            type_param_data.name.as_ref().unwrap()
1047                        }
1048                        GenericParamDataRef::ConstParamData(const_param_data) =>
1049                            &const_param_data.name,
1050                        GenericParamDataRef::LifetimeParamData(lifetime_param_data) => {
1051                            &lifetime_param_data.name
1052                        }
1053                    })
1054                    .zip_eq(&*variances)
1055                    .format_with(", ", |(name, var), f| f(&format_args!(
1056                        "{}: {var}",
1057                        name.as_str()
1058                    )))
1059            );
1060        }
1061
1062        expected.assert_eq(&res);
1063    }
1064}