ide_assists/handlers/replace_named_generic_with_impl.rs

1use hir::{FileRange, Semantics};
2use ide_db::text_edit::TextRange;
3use ide_db::{
4    EditionedFileId, RootDatabase,
5    defs::Definition,
6    search::{SearchScope, UsageSearchResult},
7};
8use syntax::{
9    AstNode,
10    ast::{self, HasGenericParams, HasName, HasTypeBounds, Name, NameLike, PathType, make},
11    match_ast,
12};
13
14use crate::{AssistContext, AssistId, Assists};
15
16// Assist: replace_named_generic_with_impl
17//
18// Replaces named generic with an `impl Trait` in function argument.
19//
20// ```
21// fn new<P$0: AsRef<Path>>(location: P) -> Self {}
22// ```
23// ->
24// ```
25// fn new(location: impl AsRef<Path>) -> Self {}
26// ```
27pub(crate) fn replace_named_generic_with_impl(
28    acc: &mut Assists,
29    ctx: &AssistContext<'_>,
30) -> Option<()> {
31    // finds `<P: AsRef<Path>>`
32    let type_param = ctx.find_node_at_offset::<ast::TypeParam>()?;
33    // returns `P`
34    let type_param_name = type_param.name()?;
35
36    // The list of type bounds / traits: `AsRef<Path>`
37    let type_bound_list = type_param.type_bound_list()?;
38
39    let fn_ = type_param.syntax().ancestors().find_map(ast::Fn::cast)?;
40    let param_list_text_range = fn_.param_list()?.syntax().text_range();
41
42    let type_param_hir_def = ctx.sema.to_def(&type_param)?;
43    let type_param_def = Definition::GenericParam(hir::GenericParam::TypeParam(type_param_hir_def));
44
45    // get all usage references for the type param
46    let usage_refs = find_usages(&ctx.sema, &fn_, type_param_def, ctx.file_id());
47    if usage_refs.is_empty() {
48        return None;
49    }
50
51    // All usage references need to be valid (inside the function param list)
52    if !check_valid_usages(&usage_refs, param_list_text_range) {
53        return None;
54    }
55
56    let mut path_types_to_replace = Vec::new();
57    for (_a, refs) in usage_refs.iter() {
58        for usage_ref in refs {
59            let Some(name_like) = usage_ref.name.clone().into_name_like() else {
60                continue;
61            };
62            let param_node = find_path_type(&ctx.sema, &type_param_name, &name_like)?;
63            path_types_to_replace.push(param_node);
64        }
65    }
66
67    let target = type_param.syntax().text_range();
68
69    acc.add(
70        AssistId::refactor_rewrite("replace_named_generic_with_impl"),
71        "Replace named generic with impl trait",
72        target,
73        |edit| {
74            let mut editor = edit.make_editor(type_param.syntax());
75
76            // remove trait from generic param list
77            if let Some(generic_params) = fn_.generic_param_list() {
78                let params: Vec<ast::GenericParam> = generic_params
79                    .clone()
80                    .generic_params()
81                    .filter(|it| it.syntax() != type_param.syntax())
82                    .collect();
83                if params.is_empty() {
84                    editor.delete(generic_params.syntax());
85                } else {
86                    let new_generic_param_list = make::generic_param_list(params);
87                    editor.replace(
88                        generic_params.syntax(),
89                        new_generic_param_list.syntax().clone_for_update(),
90                    );
91                }
92            }
93
94            let new_bounds = make::impl_trait_type(type_bound_list);
95            for path_type in path_types_to_replace.iter().rev() {
96                editor.replace(path_type.syntax(), new_bounds.clone_for_update().syntax());
97            }
98            edit.add_file_edits(ctx.vfs_file_id(), editor);
99        },
100    )
101}
102
103fn find_path_type(
104    sema: &Semantics<'_, RootDatabase>,
105    type_param_name: &Name,
106    param: &NameLike,
107) -> Option<PathType> {
108    let path_type =
109        sema.ancestors_with_macros(param.syntax().clone()).find_map(ast::PathType::cast)?;
110
111    // Ignore any path types that look like `P::Assoc`
112    if path_type.path()?.as_single_name_ref()?.text() != type_param_name.text() {
113        return None;
114    }
115
116    let ancestors = sema.ancestors_with_macros(path_type.syntax().clone());
117
118    let mut in_generic_arg_list = false;
119    let mut is_associated_type = false;
120
121    // walking the ancestors checks them in a heuristic way until the `Fn` node is reached.
122    for ancestor in ancestors {
123        match_ast! {
124            match ancestor {
125                ast::PathSegment(ps) => {
126                    match ps.kind()? {
127                        ast::PathSegmentKind::Name(_name_ref) => (),
128                        ast::PathSegmentKind::Type { .. } => return None,
129                        _ => return None,
130                    }
131                },
132                ast::GenericArgList(_) => {
133                    in_generic_arg_list = true;
134                },
135                ast::AssocTypeArg(_) => {
136                    is_associated_type = true;
137                },
138                ast::ImplTraitType(_) => {
139                    if in_generic_arg_list && !is_associated_type {
140                        return None;
141                    }
142                },
143                ast::DynTraitType(_) => {
144                    if !is_associated_type {
145                        return None;
146                    }
147                },
148                ast::Fn(_) => return Some(path_type),
149                _ => (),
150            }
151        }
152    }
153
154    None
155}
156
157/// Returns all usage references for the given type parameter definition.
158fn find_usages(
159    sema: &Semantics<'_, RootDatabase>,
160    fn_: &ast::Fn,
161    type_param_def: Definition,
162    file_id: EditionedFileId,
163) -> UsageSearchResult {
164    let file_range = FileRange { file_id, range: fn_.syntax().text_range() };
165    type_param_def.usages(sema).in_scope(&SearchScope::file_range(file_range)).all()
166}
167
168fn check_valid_usages(usages: &UsageSearchResult, param_list_range: TextRange) -> bool {
169    usages
170        .iter()
171        .flat_map(|(_, usage_refs)| usage_refs)
172        .all(|usage_ref| param_list_range.contains_range(usage_ref.range))
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::{check_assist, check_assist_not_applicable};

    /// Runs the assist on `before` and asserts that the rewritten source equals `after`.
    #[track_caller]
    fn check(before: &str, after: &str) {
        check_assist(replace_named_generic_with_impl, before, after);
    }

    /// Asserts that the assist is not offered for `fixture`.
    #[track_caller]
    fn check_not_applicable(fixture: &str) {
        check_assist_not_applicable(replace_named_generic_with_impl, fixture);
    }

    #[test]
    fn replace_generic_moves_into_function() {
        check(
            r#"fn new<T$0: ToString>(input: T) -> Self {}"#,
            r#"fn new(input: impl ToString) -> Self {}"#,
        );
    }

    #[test]
    fn replace_generic_with_inner_associated_type() {
        check(
            r#"fn new<P$0: AsRef<Path>>(input: P) -> Self {}"#,
            r#"fn new(input: impl AsRef<Path>) -> Self {}"#,
        );
    }

    #[test]
    fn replace_generic_trait_applies_to_all_matching_params() {
        check(
            r#"fn new<T$0: ToString>(a: T, b: T) -> Self {}"#,
            r#"fn new(a: impl ToString, b: impl ToString) -> Self {}"#,
        );
    }

    #[test]
    fn replace_generic_trait_applies_to_generic_arguments_in_params() {
        check(
            r#"
            fn foo<P$0: Trait>(
                _: P,
                _: Option<P>,
                _: Option<Option<P>>,
                _: impl Iterator<Item = P>,
                _: &dyn Iterator<Item = P>,
            ) {}
            "#,
            r#"
            fn foo(
                _: impl Trait,
                _: Option<impl Trait>,
                _: Option<Option<impl Trait>>,
                _: impl Iterator<Item = impl Trait>,
                _: &dyn Iterator<Item = impl Trait>,
            ) {}
            "#,
        );
    }

    #[test]
    fn replace_generic_not_applicable_when_one_param_type_is_invalid() {
        check_not_applicable(
            r#"
            fn foo<P$0: Trait>(
                _: i32,
                _: Option<P>,
                _: Option<Option<P>>,
                _: impl Iterator<Item = P>,
                _: &dyn Iterator<Item = P>,
                _: <P as Trait>::Assoc,
            ) {}
            "#,
        );
    }

    #[test]
    fn replace_generic_not_applicable_when_referenced_in_where_clause() {
        check_not_applicable(r#"fn foo<P$0: Trait, I>() where I: FromRef<P> {}"#);
    }

    #[test]
    fn replace_generic_not_applicable_when_used_with_type_alias() {
        check_not_applicable(r#"fn foo<P$0: Trait>(p: <P as Trait>::Assoc) {}"#);
    }

    #[test]
    fn replace_generic_not_applicable_when_used_as_argument_in_outer_trait_alias() {
        check_not_applicable(r#"fn foo<P$0: Trait>(_: <() as OtherTrait<P>>::Assoc) {}"#);
    }

    #[test]
    fn replace_generic_not_applicable_with_inner_associated_type() {
        check_not_applicable(r#"fn foo<P$0: Trait>(_: P::Assoc) {}"#);
    }

    #[test]
    fn replace_generic_not_applicable_when_passed_into_outer_impl_trait() {
        check_not_applicable(r#"fn foo<P$0: Trait>(_: impl OtherTrait<P>) {}"#);
    }

    #[test]
    fn replace_generic_not_applicable_when_used_in_passed_function_parameter() {
        check_not_applicable(r#"fn foo<P$0: Trait>(_: &dyn Fn(P)) {}"#);
    }

    #[test]
    fn replace_generic_with_multiple_generic_params() {
        check(
            r#"fn new<P: AsRef<Path>, T$0: ToString>(t: T, p: P) -> Self {}"#,
            r#"fn new<P: AsRef<Path>>(t: impl ToString, p: P) -> Self {}"#,
        );
        check(
            r#"fn new<T$0: ToString, P: AsRef<Path>>(t: T, p: P) -> Self {}"#,
            r#"fn new<P: AsRef<Path>>(t: impl ToString, p: P) -> Self {}"#,
        );
        check(
            r#"fn new<A: Send, B$0: ToString, C: Debug>(a: A, b: B, c: C) -> Self {}"#,
            r#"fn new<A: Send, C: Debug>(a: A, b: impl ToString, c: C) -> Self {}"#,
        );
    }

    #[test]
    fn replace_generic_with_multiple_trait_bounds() {
        check(
            r#"fn new<P$0: Send + Sync>(p: P) -> Self {}"#,
            r#"fn new(p: impl Send + Sync) -> Self {}"#,
        );
    }

    #[test]
    fn replace_generic_not_applicable_if_param_used_as_return_type() {
        check_not_applicable(r#"fn new<P$0: Send + Sync>(p: P) -> P {}"#);
    }

    #[test]
    fn replace_generic_not_applicable_if_param_used_in_fn_body() {
        check_not_applicable(r#"fn new<P$0: ToString>(p: P) { let x: &dyn P = &O; }"#);
    }

    #[test]
    fn replace_generic_ignores_another_function_with_same_param_type() {
        check(
            r#"
            fn new<P$0: Send + Sync>(p: P) {}
            fn hello<P: Debug>(p: P) { println!("{:?}", p); }
            "#,
            r#"
            fn new(p: impl Send + Sync) {}
            fn hello<P: Debug>(p: P) { println!("{:?}", p); }
            "#,
        );
    }
}