hir/term_search/
expr.rs

//! Type tree for term search.
2
3use hir_def::ImportPathConfig;
4use hir_expand::mod_path::ModPath;
5use hir_ty::{
6    db::HirDatabase,
7    display::{DisplaySourceCodeError, DisplayTarget, HirDisplay},
8};
9use itertools::Itertools;
10use span::Edition;
11
12use crate::{
13    Adt, AsAssocItem, AssocItemContainer, Const, ConstParam, Field, Function, Local, ModuleDef,
14    SemanticsScope, Static, Struct, StructKind, Trait, Type, Variant,
15};
16
17/// Helper function to get path to `ModuleDef`
18fn mod_item_path(
19    sema_scope: &SemanticsScope<'_>,
20    def: &ModuleDef,
21    cfg: ImportPathConfig,
22) -> Option<ModPath> {
23    let db = sema_scope.db;
24    let m = sema_scope.module();
25    m.find_path(db, *def, cfg)
26}
27
28/// Helper function to get path to `ModuleDef` as string
29fn mod_item_path_str(
30    sema_scope: &SemanticsScope<'_>,
31    def: &ModuleDef,
32    cfg: ImportPathConfig,
33    edition: Edition,
34) -> Result<String, DisplaySourceCodeError> {
35    let path = mod_item_path(sema_scope, def, cfg);
36    path.map(|it| it.display(sema_scope.db, edition).to_string())
37        .ok_or(DisplaySourceCodeError::PathNotFound)
38}
39
40/// Type tree shows how can we get from set of types to some type.
41///
42/// Consider the following code as an example
43/// ```ignore
44/// fn foo(x: i32, y: bool) -> Option<i32> { None }
45/// fn bar() {
46///    let a = 1;
47///    let b = true;
48///    let c: Option<i32> = _;
49/// }
50/// ```
51/// If we generate type tree in the place of `_` we get
52/// ```txt
53///       Option<i32>
54///           |
55///     foo(i32, bool)
56///      /        \
57///  a: i32      b: bool
58/// ```
59/// So in short it pretty much gives us a way to get type `Option<i32>` using the items we have in
60/// scope.
61#[derive(Debug, Clone, Eq, Hash, PartialEq)]
62pub enum Expr<'db> {
63    /// Constant
64    Const(Const),
65    /// Static variable
66    Static(Static),
67    /// Local variable
68    Local(Local),
69    /// Constant generic parameter
70    ConstParam(ConstParam),
71    /// Well known type (such as `true` for bool)
72    FamousType { ty: Type<'db>, value: &'static str },
73    /// Function call (does not take self param)
74    Function { func: Function, generics: Vec<Type<'db>>, params: Vec<Expr<'db>> },
75    /// Method call (has self param)
76    Method {
77        func: Function,
78        generics: Vec<Type<'db>>,
79        target: Box<Expr<'db>>,
80        params: Vec<Expr<'db>>,
81    },
82    /// Enum variant construction
83    Variant { variant: Variant, generics: Vec<Type<'db>>, params: Vec<Expr<'db>> },
84    /// Struct construction
85    Struct { strukt: Struct, generics: Vec<Type<'db>>, params: Vec<Expr<'db>> },
86    /// Tuple construction
87    Tuple { ty: Type<'db>, params: Vec<Expr<'db>> },
88    /// Struct field access
89    Field { expr: Box<Expr<'db>>, field: Field },
90    /// Passing type as reference (with `&`)
91    Reference(Box<Expr<'db>>),
92    /// Indicates possibility of many different options that all evaluate to `ty`
93    Many(Type<'db>),
94}
95
96impl<'db> Expr<'db> {
97    /// Generate source code for type tree.
98    ///
99    /// Note that trait imports are not added to generated code.
100    /// To make sure that the code is valid, callee has to also ensure that all the traits listed
101    /// by `traits_used` method are also imported.
102    pub fn gen_source_code(
103        &self,
104        sema_scope: &SemanticsScope<'db>,
105        many_formatter: &mut dyn FnMut(&Type<'db>) -> String,
106        cfg: ImportPathConfig,
107        display_target: DisplayTarget,
108    ) -> Result<String, DisplaySourceCodeError> {
109        let db = sema_scope.db;
110        let edition = display_target.edition;
111        let mod_item_path_str = |s, def| mod_item_path_str(s, def, cfg, edition);
112        match self {
113            Expr::Const(it) => match it.as_assoc_item(db).map(|it| it.container(db)) {
114                Some(container) => {
115                    let container_name =
116                        container_name(container, sema_scope, cfg, edition, display_target)?;
117                    let const_name = it
118                        .name(db)
119                        .map(|c| c.display(db, edition).to_string())
120                        .unwrap_or(String::new());
121                    Ok(format!("{container_name}::{const_name}"))
122                }
123                None => mod_item_path_str(sema_scope, &ModuleDef::Const(*it)),
124            },
125            Expr::Static(it) => mod_item_path_str(sema_scope, &ModuleDef::Static(*it)),
126            Expr::Local(it) => Ok(it.name(db).display(db, edition).to_string()),
127            Expr::ConstParam(it) => Ok(it.name(db).display(db, edition).to_string()),
128            Expr::FamousType { value, .. } => Ok(value.to_string()),
129            Expr::Function { func, params, .. } => {
130                let args = params
131                    .iter()
132                    .map(|f| f.gen_source_code(sema_scope, many_formatter, cfg, display_target))
133                    .collect::<Result<Vec<String>, DisplaySourceCodeError>>()?
134                    .into_iter()
135                    .join(", ");
136
137                match func.as_assoc_item(db).map(|it| it.container(db)) {
138                    Some(container) => {
139                        let container_name =
140                            container_name(container, sema_scope, cfg, edition, display_target)?;
141                        let fn_name = func.name(db).display(db, edition).to_string();
142                        Ok(format!("{container_name}::{fn_name}({args})"))
143                    }
144                    None => {
145                        let fn_name = mod_item_path_str(sema_scope, &ModuleDef::Function(*func))?;
146                        Ok(format!("{fn_name}({args})"))
147                    }
148                }
149            }
150            Expr::Method { func, target, params, .. } => {
151                if self.contains_many_in_illegal_pos(db) {
152                    return Ok(many_formatter(&target.ty(db)));
153                }
154
155                let func_name = func.name(db).display(db, edition).to_string();
156                let self_param = func.self_param(db).unwrap();
157                let target_str =
158                    target.gen_source_code(sema_scope, many_formatter, cfg, display_target)?;
159                let args = params
160                    .iter()
161                    .map(|f| f.gen_source_code(sema_scope, many_formatter, cfg, display_target))
162                    .collect::<Result<Vec<String>, DisplaySourceCodeError>>()?
163                    .into_iter()
164                    .join(", ");
165
166                match func.as_assoc_item(db).and_then(|it| it.container_or_implemented_trait(db)) {
167                    Some(trait_) => {
168                        let trait_name = mod_item_path_str(sema_scope, &ModuleDef::Trait(trait_))?;
169                        let target = match self_param.access(db) {
170                            crate::Access::Shared if !target.is_many() => format!("&{target_str}"),
171                            crate::Access::Exclusive if !target.is_many() => {
172                                format!("&mut {target_str}")
173                            }
174                            crate::Access::Owned => target_str,
175                            _ => many_formatter(&target.ty(db)),
176                        };
177                        let res = match args.is_empty() {
178                            true => format!("{trait_name}::{func_name}({target})",),
179                            false => format!("{trait_name}::{func_name}({target}, {args})",),
180                        };
181                        Ok(res)
182                    }
183                    None => Ok(format!("{target_str}.{func_name}({args})")),
184                }
185            }
186            Expr::Variant { variant, params, .. } => {
187                let inner = match variant.kind(db) {
188                    StructKind::Tuple => {
189                        let args = params
190                            .iter()
191                            .map(|f| {
192                                f.gen_source_code(sema_scope, many_formatter, cfg, display_target)
193                            })
194                            .collect::<Result<Vec<String>, DisplaySourceCodeError>>()?
195                            .into_iter()
196                            .join(", ");
197                        format!("({args})")
198                    }
199                    StructKind::Record => {
200                        let fields = variant.fields(db);
201                        let args = params
202                            .iter()
203                            .zip(fields.iter())
204                            .map(|(a, f)| {
205                                let tmp = format!(
206                                    "{}: {}",
207                                    f.name(db).display(db, edition),
208                                    a.gen_source_code(
209                                        sema_scope,
210                                        many_formatter,
211                                        cfg,
212                                        display_target
213                                    )?
214                                );
215                                Ok(tmp)
216                            })
217                            .collect::<Result<Vec<String>, DisplaySourceCodeError>>()?
218                            .into_iter()
219                            .join(", ");
220                        format!("{{ {args} }}")
221                    }
222                    StructKind::Unit => String::new(),
223                };
224
225                let prefix = mod_item_path_str(sema_scope, &ModuleDef::Variant(*variant))?;
226                Ok(format!("{prefix}{inner}"))
227            }
228            Expr::Struct { strukt, params, .. } => {
229                let inner = match strukt.kind(db) {
230                    StructKind::Tuple => {
231                        let args = params
232                            .iter()
233                            .map(|a| {
234                                a.gen_source_code(sema_scope, many_formatter, cfg, display_target)
235                            })
236                            .collect::<Result<Vec<String>, DisplaySourceCodeError>>()?
237                            .into_iter()
238                            .join(", ");
239                        format!("({args})")
240                    }
241                    StructKind::Record => {
242                        let fields = strukt.fields(db);
243                        let args = params
244                            .iter()
245                            .zip(fields.iter())
246                            .map(|(a, f)| {
247                                let tmp = format!(
248                                    "{}: {}",
249                                    f.name(db).display(db, edition),
250                                    a.gen_source_code(
251                                        sema_scope,
252                                        many_formatter,
253                                        cfg,
254                                        display_target
255                                    )?
256                                );
257                                Ok(tmp)
258                            })
259                            .collect::<Result<Vec<String>, DisplaySourceCodeError>>()?
260                            .into_iter()
261                            .join(", ");
262                        format!(" {{ {args} }}")
263                    }
264                    StructKind::Unit => String::new(),
265                };
266
267                let prefix = mod_item_path_str(sema_scope, &ModuleDef::Adt(Adt::Struct(*strukt)))?;
268                Ok(format!("{prefix}{inner}"))
269            }
270            Expr::Tuple { params, .. } => {
271                let args = params
272                    .iter()
273                    .map(|a| a.gen_source_code(sema_scope, many_formatter, cfg, display_target))
274                    .collect::<Result<Vec<String>, DisplaySourceCodeError>>()?
275                    .into_iter()
276                    .join(", ");
277                let res = format!("({args})");
278                Ok(res)
279            }
280            Expr::Field { expr, field } => {
281                if expr.contains_many_in_illegal_pos(db) {
282                    return Ok(many_formatter(&expr.ty(db)));
283                }
284
285                let strukt =
286                    expr.gen_source_code(sema_scope, many_formatter, cfg, display_target)?;
287                let field = field.name(db).display(db, edition).to_string();
288                Ok(format!("{strukt}.{field}"))
289            }
290            Expr::Reference(expr) => {
291                if expr.contains_many_in_illegal_pos(db) {
292                    return Ok(many_formatter(&expr.ty(db)));
293                }
294
295                let inner =
296                    expr.gen_source_code(sema_scope, many_formatter, cfg, display_target)?;
297                Ok(format!("&{inner}"))
298            }
299            Expr::Many(ty) => Ok(many_formatter(ty)),
300        }
301    }
302
303    /// Get type of the type tree.
304    ///
305    /// Same as getting the type of root node
306    pub fn ty(&self, db: &'db dyn HirDatabase) -> Type<'db> {
307        match self {
308            Expr::Const(it) => it.ty(db),
309            Expr::Static(it) => it.ty(db),
310            Expr::Local(it) => it.ty(db),
311            Expr::ConstParam(it) => it.ty(db),
312            Expr::FamousType { ty, .. } => ty.clone(),
313            Expr::Function { func, generics, .. } => {
314                func.ret_type_with_args(db, generics.iter().cloned())
315            }
316            Expr::Method { func, generics, target, .. } => func.ret_type_with_args(
317                db,
318                target.ty(db).type_arguments().chain(generics.iter().cloned()),
319            ),
320            Expr::Variant { variant, generics, .. } => {
321                Adt::from(variant.parent_enum(db)).ty_with_args(db, generics.iter().cloned())
322            }
323            Expr::Struct { strukt, generics, .. } => {
324                Adt::from(*strukt).ty_with_args(db, generics.iter().cloned())
325            }
326            Expr::Tuple { ty, .. } => ty.clone(),
327            Expr::Field { expr, field } => field.ty_with_args(db, expr.ty(db).type_arguments()),
328            Expr::Reference(it) => it.ty(db),
329            Expr::Many(ty) => ty.clone(),
330        }
331    }
332
333    /// List the traits used in type tree
334    pub fn traits_used(&self, db: &dyn HirDatabase) -> Vec<Trait> {
335        let mut res = Vec::new();
336
337        if let Expr::Method { func, params, .. } = self {
338            res.extend(params.iter().flat_map(|it| it.traits_used(db)));
339            if let Some(it) = func.as_assoc_item(db)
340                && let Some(it) = it.container_or_implemented_trait(db)
341            {
342                res.push(it);
343            }
344        }
345
346        res
347    }
348
349    /// Check in the tree contains `Expr::Many` variant in illegal place to insert `todo`,
350    /// `unimplemented` or similar macro
351    ///
352    /// Some examples are following
353    /// ```no_compile
354    /// macro!().foo
355    /// macro!().bar()
356    /// &macro!()
357    /// ```
358    fn contains_many_in_illegal_pos(&self, db: &dyn HirDatabase) -> bool {
359        match self {
360            Expr::Method { target, func, .. } => {
361                match func.as_assoc_item(db).and_then(|it| it.container_or_implemented_trait(db)) {
362                    Some(_) => false,
363                    None => target.is_many(),
364                }
365            }
366            Expr::Field { expr, .. } => expr.contains_many_in_illegal_pos(db),
367            Expr::Reference(target) => target.is_many(),
368            Expr::Many(_) => true,
369            _ => false,
370        }
371    }
372
373    /// Helper function to check if outermost type tree is `Expr::Many` variant
374    pub fn is_many(&self) -> bool {
375        matches!(self, Expr::Many(_))
376    }
377}
378
379/// Helper function to find name of container
380fn container_name(
381    container: AssocItemContainer,
382    sema_scope: &SemanticsScope<'_>,
383    cfg: ImportPathConfig,
384    edition: Edition,
385    display_target: DisplayTarget,
386) -> Result<String, DisplaySourceCodeError> {
387    let container_name = match container {
388        crate::AssocItemContainer::Trait(trait_) => {
389            mod_item_path_str(sema_scope, &ModuleDef::Trait(trait_), cfg, edition)?
390        }
391        crate::AssocItemContainer::Impl(imp) => {
392            let self_ty = imp.self_ty(sema_scope.db);
393            // Should it be guaranteed that `mod_item_path` always exists?
394            match self_ty.as_adt().and_then(|adt| mod_item_path(sema_scope, &adt.into(), cfg)) {
395                Some(path) => path.display(sema_scope.db, edition).to_string(),
396                None => self_ty.display(sema_scope.db, display_target).to_string(),
397            }
398        }
399    };
400    Ok(container_name)
401}