ide_assists/handlers/
wrap_unwrap_cfg_attr.rs

1use ide_db::source_change::SourceChangeBuilder;
2use itertools::Itertools;
3use syntax::{
4    NodeOrToken, SyntaxToken, T, TextRange, algo,
5    ast::{self, AstNode, make, syntax_factory::SyntaxFactory},
6};
7
8use crate::{AssistContext, AssistId, Assists};
9
10// Assist: wrap_unwrap_cfg_attr
11//
// Wraps an attribute in a `cfg_attr` attribute, or unwraps a `cfg_attr` attribute into its inner attributes.
13//
14// ```
15// #[derive$0(Debug)]
16// struct S {
17//    field: i32
18// }
19// ```
20// ->
21// ```
22// #[cfg_attr($0, derive(Debug))]
23// struct S {
24//    field: i32
25// }
26// ```
27
28enum WrapUnwrapOption {
29    WrapDerive { derive: TextRange, attr: ast::Attr },
30    WrapAttr(ast::Attr),
31}
32
33/// Attempts to get the derive attribute from a derive attribute list
34///
35/// This will collect all the tokens in the "path" within the derive attribute list
36/// But a derive attribute list doesn't have paths. So we need to collect all the tokens before and after the ident
37///
38/// If this functions return None just map to WrapAttr
39fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption {
40    let attempt_attr = || {
41        {
42            let mut derive = ident.text_range();
43            // TokenTree is all the tokens between the `(` and `)`. They do not have paths. So a path `serde::Serialize` would be [Ident Colon Colon Ident]
44            // So lets say we have derive(Debug, serde::Serialize, Copy) ident would be on Serialize
45            // We need to grab all previous tokens until we find a `,` or `(` and all following tokens until we find a `,` or `)`
46            // We also want to consume the following comma if it exists
47
48            let mut prev = algo::skip_trivia_token(
49                ident.prev_sibling_or_token()?.into_token()?,
50                syntax::Direction::Prev,
51            )?;
52            let mut following = algo::skip_trivia_token(
53                ident.next_sibling_or_token()?.into_token()?,
54                syntax::Direction::Next,
55            )?;
56            if (prev.kind() == T![,] || prev.kind() == T!['('])
57                && (following.kind() == T![,] || following.kind() == T![')'])
58            {
59                // This would be a single ident such as Debug. As no path is present
60                if following.kind() == T![,] {
61                    derive = derive.cover(following.text_range());
62                } else if following.kind() == T![')'] && prev.kind() == T![,] {
63                    derive = derive.cover(prev.text_range());
64                }
65
66                Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() })
67            } else {
68                let mut consumed_comma = false;
69                // Collect the path
70                while let Some(prev_token) = algo::skip_trivia_token(prev, syntax::Direction::Prev)
71                {
72                    let kind = prev_token.kind();
73                    if kind == T![,] {
74                        consumed_comma = true;
75                        derive = derive.cover(prev_token.text_range());
76                        break;
77                    } else if kind == T!['('] {
78                        break;
79                    } else {
80                        derive = derive.cover(prev_token.text_range());
81                    }
82                    prev = prev_token.prev_sibling_or_token()?.into_token()?;
83                }
84                while let Some(next_token) =
85                    algo::skip_trivia_token(following.clone(), syntax::Direction::Next)
86                {
87                    let kind = next_token.kind();
88                    match kind {
89                        T![,] if !consumed_comma => {
90                            derive = derive.cover(next_token.text_range());
91                            break;
92                        }
93                        T![')'] | T![,] => break,
94                        _ => derive = derive.cover(next_token.text_range()),
95                    }
96                    following = next_token.next_sibling_or_token()?.into_token()?;
97                }
98                Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() })
99            }
100        }
101    };
102    if ident.parent().and_then(ast::TokenTree::cast).is_none()
103        || !attr.simple_name().map(|v| v.eq("derive")).unwrap_or_default()
104    {
105        WrapUnwrapOption::WrapAttr(attr)
106    } else {
107        attempt_attr().unwrap_or(WrapUnwrapOption::WrapAttr(attr))
108    }
109}
110pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
111    let option = if ctx.has_empty_selection() {
112        let ident = ctx.find_token_syntax_at_offset(T![ident]);
113        let attr = ctx.find_node_at_offset::<ast::Attr>();
114        match (attr, ident) {
115            (Some(attr), Some(ident))
116                if attr.simple_name().map(|v| v.eq("derive")).unwrap_or_default() =>
117            {
118                Some(attempt_get_derive(attr, ident))
119            }
120
121            (Some(attr), _) => Some(WrapUnwrapOption::WrapAttr(attr)),
122            _ => None,
123        }
124    } else {
125        let covering_element = ctx.covering_element();
126        match covering_element {
127            NodeOrToken::Node(node) => ast::Attr::cast(node).map(WrapUnwrapOption::WrapAttr),
128            NodeOrToken::Token(ident) if ident.kind() == syntax::T![ident] => {
129                let attr = ident.parent_ancestors().find_map(ast::Attr::cast)?;
130                Some(attempt_get_derive(attr, ident))
131            }
132            _ => None,
133        }
134    }?;
135    match option {
136        WrapUnwrapOption::WrapAttr(attr) if attr.simple_name().as_deref() == Some("cfg_attr") => {
137            unwrap_cfg_attr(acc, attr)
138        }
139        WrapUnwrapOption::WrapAttr(attr) => wrap_cfg_attr(acc, ctx, attr),
140        WrapUnwrapOption::WrapDerive { derive, attr } => wrap_derive(acc, ctx, attr, derive),
141    }
142}
143
144fn wrap_derive(
145    acc: &mut Assists,
146    ctx: &AssistContext<'_>,
147    attr: ast::Attr,
148    derive_element: TextRange,
149) -> Option<()> {
150    let range = attr.syntax().text_range();
151    let token_tree = attr.token_tree()?;
152    let mut path_text = String::new();
153
154    let mut cfg_derive_tokens = Vec::new();
155    let mut new_derive = Vec::new();
156
157    for tt in token_tree.token_trees_and_tokens() {
158        let NodeOrToken::Token(token) = tt else {
159            continue;
160        };
161        if token.kind() == T!['('] || token.kind() == T![')'] {
162            continue;
163        }
164
165        if derive_element.contains_range(token.text_range()) {
166            if token.kind() != T![,] && token.kind() != syntax::SyntaxKind::WHITESPACE {
167                path_text.push_str(token.text());
168                cfg_derive_tokens.push(NodeOrToken::Token(token));
169            }
170        } else {
171            new_derive.push(NodeOrToken::Token(token));
172        }
173    }
174    let handle_source_change = |edit: &mut SourceChangeBuilder| {
175        let make = SyntaxFactory::with_mappings();
176        let mut editor = edit.make_editor(attr.syntax());
177        let new_derive = make.attr_outer(
178            make.meta_token_tree(make.ident_path("derive"), make.token_tree(T!['('], new_derive)),
179        );
180        let meta = make.meta_token_tree(
181            make.ident_path("cfg_attr"),
182            make.token_tree(
183                T!['('],
184                vec![
185                    NodeOrToken::Token(make.token(T![,])),
186                    NodeOrToken::Token(make.whitespace(" ")),
187                    NodeOrToken::Token(make.ident("derive")),
188                    NodeOrToken::Node(make.token_tree(T!['('], cfg_derive_tokens)),
189                ],
190            ),
191        );
192
193        let cfg_attr = make.attr_outer(meta);
194        editor.replace_with_many(
195            attr.syntax(),
196            vec![
197                new_derive.syntax().clone().into(),
198                make.whitespace("\n").into(),
199                cfg_attr.syntax().clone().into(),
200            ],
201        );
202
203        if let Some(snippet_cap) = ctx.config.snippet_cap
204            && let Some(first_meta) =
205                cfg_attr.meta().and_then(|meta| meta.token_tree()).and_then(|tt| tt.l_paren_token())
206        {
207            let tabstop = edit.make_tabstop_after(snippet_cap);
208            editor.add_annotation(first_meta, tabstop);
209        }
210
211        editor.add_mappings(make.finish_with_mappings());
212        edit.add_file_edits(ctx.vfs_file_id(), editor);
213    };
214
215    acc.add(
216        AssistId::refactor("wrap_unwrap_cfg_attr"),
217        format!("Wrap #[derive({path_text})] in `cfg_attr`",),
218        range,
219        handle_source_change,
220    );
221    Some(())
222}
223fn wrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>, attr: ast::Attr) -> Option<()> {
224    let range = attr.syntax().text_range();
225    let path = attr.path()?;
226    let handle_source_change = |edit: &mut SourceChangeBuilder| {
227        let make = SyntaxFactory::with_mappings();
228        let mut editor = edit.make_editor(attr.syntax());
229        let mut raw_tokens =
230            vec![NodeOrToken::Token(make.token(T![,])), NodeOrToken::Token(make.whitespace(" "))];
231        path.syntax().descendants_with_tokens().for_each(|it| {
232            if let NodeOrToken::Token(token) = it {
233                raw_tokens.push(NodeOrToken::Token(token));
234            }
235        });
236        if let Some(meta) = attr.meta() {
237            if let (Some(eq), Some(expr)) = (meta.eq_token(), meta.expr()) {
238                raw_tokens.push(NodeOrToken::Token(make.whitespace(" ")));
239                raw_tokens.push(NodeOrToken::Token(eq));
240                raw_tokens.push(NodeOrToken::Token(make.whitespace(" ")));
241
242                expr.syntax().descendants_with_tokens().for_each(|it| {
243                    if let NodeOrToken::Token(token) = it {
244                        raw_tokens.push(NodeOrToken::Token(token));
245                    }
246                });
247            } else if let Some(tt) = meta.token_tree() {
248                raw_tokens.extend(tt.token_trees_and_tokens());
249            }
250        }
251        let meta =
252            make.meta_token_tree(make.ident_path("cfg_attr"), make.token_tree(T!['('], raw_tokens));
253        let cfg_attr =
254            if attr.excl_token().is_some() { make.attr_inner(meta) } else { make.attr_outer(meta) };
255
256        editor.replace(attr.syntax(), cfg_attr.syntax());
257
258        if let Some(snippet_cap) = ctx.config.snippet_cap
259            && let Some(first_meta) =
260                cfg_attr.meta().and_then(|meta| meta.token_tree()).and_then(|tt| tt.l_paren_token())
261        {
262            let tabstop = edit.make_tabstop_after(snippet_cap);
263            editor.add_annotation(first_meta, tabstop);
264        }
265
266        editor.add_mappings(make.finish_with_mappings());
267        edit.add_file_edits(ctx.vfs_file_id(), editor);
268    };
269    acc.add(
270        AssistId::refactor("wrap_unwrap_cfg_attr"),
271        "Convert to `cfg_attr`",
272        range,
273        handle_source_change,
274    );
275    Some(())
276}
277fn unwrap_cfg_attr(acc: &mut Assists, attr: ast::Attr) -> Option<()> {
278    let range = attr.syntax().text_range();
279    let meta = attr.meta()?;
280    let meta_tt = meta.token_tree()?;
281    let mut inner_attrs = Vec::with_capacity(1);
282    let mut found_comma = false;
283    let mut iter = meta_tt.token_trees_and_tokens().skip(1).peekable();
284    while let Some(tt) = iter.next() {
285        if let NodeOrToken::Token(token) = &tt {
286            if token.kind() == T![')'] {
287                break;
288            }
289            if token.kind() == T![,] {
290                found_comma = true;
291                continue;
292            }
293        }
294        if !found_comma {
295            continue;
296        }
297        let Some(attr_name) = tt.into_token().and_then(|token| {
298            if token.kind() == T![ident] { Some(make::ext::ident_path(token.text())) } else { None }
299        }) else {
300            continue;
301        };
302        let next_tt = iter.next()?;
303        let meta = match next_tt {
304            NodeOrToken::Node(tt) => make::meta_token_tree(attr_name, tt),
305            NodeOrToken::Token(token) if token.kind() == T![,] || token.kind() == T![')'] => {
306                make::meta_path(attr_name)
307            }
308            NodeOrToken::Token(token) => {
309                let equals = algo::skip_trivia_token(token, syntax::Direction::Next)?;
310                if equals.kind() != T![=] {
311                    return None;
312                }
313                let expr_token =
314                    algo::skip_trivia_token(equals.next_token()?, syntax::Direction::Next)
315                        .and_then(|it| {
316                            if it.kind().is_literal() {
317                                Some(make::expr_literal(it.text()))
318                            } else {
319                                None
320                            }
321                        })?;
322                make::meta_expr(attr_name, ast::Expr::Literal(expr_token))
323            }
324        };
325        if attr.excl_token().is_some() {
326            inner_attrs.push(make::attr_inner(meta));
327        } else {
328            inner_attrs.push(make::attr_outer(meta));
329        }
330    }
331    if inner_attrs.is_empty() {
332        return None;
333    }
334    let handle_source_change = |f: &mut SourceChangeBuilder| {
335        let inner_attrs = inner_attrs.iter().map(|it| it.to_string()).join("\n");
336        f.replace(range, inner_attrs);
337    };
338    acc.add(
339        AssistId::refactor("wrap_unwrap_cfg_attr"),
340        "Extract Inner Attributes from `cfg_attr`",
341        range,
342        handle_source_change,
343    );
344    Some(())
345}
#[cfg(test)]
mod tests {
    use crate::tests::check_assist;

    use super::*;

    /// Runs the `wrap_unwrap_cfg_attr` assist on `before` and asserts it produces `after`.
    #[track_caller]
    fn check(before: &str, after: &str) {
        check_assist(wrap_unwrap_cfg_attr, before, after);
    }

    #[test]
    fn test_basic_to_from_cfg_attr() {
        // Cursor on the attribute name: the whole derive is wrapped.
        check(
            r#"
            #[derive$0(Debug)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[cfg_attr($0, derive(Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
        // An existing `cfg_attr` is unwrapped back into its inner attribute.
        check(
            r#"
            #[cfg_attr(debug_assertions, $0 derive(Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(Debug)]
            pub struct Test {
                test: u32,
            }
            "#,
        );
    }

    #[test]
    fn to_from_path_attr() {
        check(
            r#"
            pub struct Test {
                #[foo$0]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[cfg_attr($0, foo)]
                test: u32,
            }
            "#,
        );
        check(
            r#"
            pub struct Test {
                #[cfg_attr(debug_assertions$0, foo)]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[foo]
                test: u32,
            }
            "#,
        );
    }

    #[test]
    fn to_from_eq_attr() {
        check(
            r#"
            pub struct Test {
                #[foo = "bar"$0]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[cfg_attr($0, foo = "bar")]
                test: u32,
            }
            "#,
        );
        check(
            r#"
            pub struct Test {
                #[cfg_attr(debug_assertions$0, foo = "bar")]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[foo = "bar"]
                test: u32,
            }
            "#,
        );
    }

    #[test]
    fn inner_attrs() {
        check(
            r#"
            #![no_std$0]
            "#,
            r#"
            #![cfg_attr($0, no_std)]
            "#,
        );
        check(
            r#"
            #![cfg_attr(not(feature = "std")$0, no_std)]
            "#,
            r#"
            #![no_std]
            "#,
        );
    }

    #[test]
    fn test_derive_wrap() {
        // Cursor on a derive entry: only that entry moves into the new `cfg_attr`.
        check(
            r#"
            #[derive(Debug$0, Clone, Copy)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive( Clone, Copy)]
            #[cfg_attr($0, derive(Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
        check(
            r#"
            #[derive(Clone, Debug$0, Copy)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(Clone,  Copy)]
            #[cfg_attr($0, derive(Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
    }

    #[test]
    fn test_derive_wrap_with_path() {
        check(
            r#"
            #[derive(std::fmt::Debug$0, Clone, Copy)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive( Clone, Copy)]
            #[cfg_attr($0, derive(std::fmt::Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
        check(
            r#"
            #[derive(Clone, std::fmt::Debug$0, Copy)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(Clone, Copy)]
            #[cfg_attr($0, derive(std::fmt::Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
    }

    #[test]
    fn test_derive_wrap_at_end() {
        check(
            r#"
            #[derive(std::fmt::Debug, Clone, Cop$0y)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(std::fmt::Debug, Clone)]
            #[cfg_attr($0, derive(Copy))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
        check(
            r#"
            #[derive(Clone, Copy, std::fmt::D$0ebug)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(Clone, Copy)]
            #[cfg_attr($0, derive(std::fmt::Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
    }
}