Skip to main content

ide_assists/handlers/
wrap_unwrap_cfg_attr.rs

1use ide_db::source_change::SourceChangeBuilder;
2use syntax::{
3    NodeOrToken, SyntaxToken, T, TextRange, algo,
4    ast::{self, AstNode, edit::AstNodeEdit},
5};
6
7use crate::{AssistContext, AssistId, Assists};
8
9// Assist: wrap_unwrap_cfg_attr
10//
11// Wraps an attribute to a cfg_attr attribute or unwraps a cfg_attr attribute to the inner attributes.
12//
13// ```
14// #[derive$0(Debug)]
15// struct S {
16//    field: i32
17// }
18// ```
19// ->
20// ```
21// #[cfg_attr(${0:cfg}, derive(Debug))]
22// struct S {
23//    field: i32
24// }
25// ```
26
27enum WrapUnwrapOption {
28    WrapDerive { derive: TextRange, attr: ast::Attr },
29    WrapAttr(Vec<ast::Attr>),
30}
31
32/// Attempts to get the derive attribute from a derive attribute list
33///
34/// This will collect all the tokens in the "path" within the derive attribute list
35/// But a derive attribute list doesn't have paths. So we need to collect all the tokens before and after the ident
36///
37/// If this functions return None just map to WrapAttr
38fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption {
39    let attempt_attr = || {
40        {
41            let mut derive = ident.text_range();
42            // TokenTree is all the tokens between the `(` and `)`. They do not have paths. So a path `serde::Serialize` would be [Ident Colon Colon Ident]
43            // So lets say we have derive(Debug, serde::Serialize, Copy) ident would be on Serialize
44            // We need to grab all previous tokens until we find a `,` or `(` and all following tokens until we find a `,` or `)`
45            // We also want to consume the following comma if it exists
46
47            let mut prev = algo::skip_trivia_token(
48                ident.prev_sibling_or_token()?.into_token()?,
49                syntax::Direction::Prev,
50            )?;
51            let mut following = algo::skip_trivia_token(
52                ident.next_sibling_or_token()?.into_token()?,
53                syntax::Direction::Next,
54            )?;
55            if (prev.kind() == T![,] || prev.kind() == T!['('])
56                && (following.kind() == T![,] || following.kind() == T![')'])
57            {
58                // This would be a single ident such as Debug. As no path is present
59                if following.kind() == T![,] {
60                    derive = derive.cover(following.text_range());
61                } else if following.kind() == T![')'] && prev.kind() == T![,] {
62                    derive = derive.cover(prev.text_range());
63                }
64
65                Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() })
66            } else {
67                let mut consumed_comma = false;
68                // Collect the path
69                while let Some(prev_token) = algo::skip_trivia_token(prev, syntax::Direction::Prev)
70                {
71                    let kind = prev_token.kind();
72                    if kind == T![,] {
73                        consumed_comma = true;
74                        derive = derive.cover(prev_token.text_range());
75                        break;
76                    } else if kind == T!['('] {
77                        break;
78                    } else {
79                        derive = derive.cover(prev_token.text_range());
80                    }
81                    prev = prev_token.prev_sibling_or_token()?.into_token()?;
82                }
83                while let Some(next_token) =
84                    algo::skip_trivia_token(following.clone(), syntax::Direction::Next)
85                {
86                    let kind = next_token.kind();
87                    match kind {
88                        T![,] if !consumed_comma => {
89                            derive = derive.cover(next_token.text_range());
90                            break;
91                        }
92                        T![')'] | T![,] => break,
93                        _ => derive = derive.cover(next_token.text_range()),
94                    }
95                    following = next_token.next_sibling_or_token()?.into_token()?;
96                }
97                Some(WrapUnwrapOption::WrapDerive { derive, attr: attr.clone() })
98            }
99        }
100    };
101    if ident.parent().and_then(ast::TokenTree::cast).is_none()
102        || !attr.simple_name().map(|v| v.eq("derive")).unwrap_or_default()
103    {
104        WrapUnwrapOption::WrapAttr(vec![attr])
105    } else {
106        attempt_attr().unwrap_or_else(|| WrapUnwrapOption::WrapAttr(vec![attr]))
107    }
108}
109pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_, '_>) -> Option<()> {
110    let option = if ctx.has_empty_selection() {
111        let ident = ctx.find_token_syntax_at_offset(T![ident]);
112        let attr = ctx.find_node_at_offset::<ast::Attr>();
113        match (attr, ident) {
114            (Some(attr), Some(ident))
115                if attr.simple_name().map(|v| v.eq("derive")).unwrap_or_default() =>
116            {
117                Some(attempt_get_derive(attr, ident))
118            }
119
120            (Some(attr), _) => Some(WrapUnwrapOption::WrapAttr(vec![attr])),
121            _ => None,
122        }
123    } else {
124        let covering_element = ctx.covering_element();
125        match covering_element {
126            NodeOrToken::Node(node) => {
127                if let Some(attr) = ast::Attr::cast(node.clone()) {
128                    Some(WrapUnwrapOption::WrapAttr(vec![attr]))
129                } else {
130                    let attrs = node
131                        .children()
132                        .filter(|it| it.text_range().intersect(ctx.selection_trimmed()).is_some())
133                        .map(ast::Attr::cast)
134                        .collect::<Option<Vec<_>>>()?;
135                    if attrs.is_empty() {
136                        return None;
137                    }
138                    Some(WrapUnwrapOption::WrapAttr(attrs))
139                }
140            }
141            NodeOrToken::Token(ident) if ident.kind() == syntax::T![ident] => {
142                let attr = ident.parent_ancestors().find_map(ast::Attr::cast)?;
143                Some(attempt_get_derive(attr, ident))
144            }
145            _ => None,
146        }
147    }?;
148    match option {
149        WrapUnwrapOption::WrapAttr(attrs) => {
150            if let [attr] = &attrs[..]
151                && let Some(ast::Meta::CfgAttrMeta(meta)) = attr.meta()
152            {
153                unwrap_cfg_attr(acc, ctx, meta)
154            } else {
155                wrap_cfg_attrs(acc, ctx, attrs)
156            }
157        }
158        WrapUnwrapOption::WrapDerive { derive, attr } => wrap_derive(acc, ctx, attr, derive),
159    }
160}
161
162fn wrap_derive(
163    acc: &mut Assists,
164    ctx: &AssistContext<'_, '_>,
165    attr: ast::Attr,
166    derive_element: TextRange,
167) -> Option<()> {
168    let range = attr.syntax().text_range();
169    let ast::Meta::TokenTreeMeta(meta) = attr.meta()? else { return None };
170    let token_tree = meta.token_tree()?;
171    let mut path_text = String::new();
172
173    let mut cfg_derive_tokens = Vec::new();
174    let mut new_derive = Vec::new();
175
176    for tt in token_tree.token_trees_and_tokens() {
177        let NodeOrToken::Token(token) = tt else {
178            continue;
179        };
180        if token.kind() == T!['('] || token.kind() == T![')'] {
181            continue;
182        }
183
184        if derive_element.contains_range(token.text_range()) {
185            if token.kind() != T![,] && token.kind() != syntax::SyntaxKind::WHITESPACE {
186                path_text.push_str(token.text());
187                cfg_derive_tokens.push(NodeOrToken::Token(token));
188            }
189        } else {
190            new_derive.push(NodeOrToken::Token(token));
191        }
192    }
193    let handle_source_change = |edit: &mut SourceChangeBuilder| {
194        let editor = edit.make_editor(attr.syntax());
195        let make = editor.make();
196        let new_derive = make.attr_outer(
197            make.meta_token_tree(make.ident_path("derive"), make.token_tree(T!['('], new_derive)),
198        );
199        let meta = make.cfg_attr_meta(
200            make.cfg_flag("cfg"),
201            [make.meta_token_tree(
202                make.ident_path("derive"),
203                make.token_tree(T!['('], cfg_derive_tokens),
204            )],
205        );
206
207        let cfg_attr = make.attr_outer(meta.clone().into());
208        editor.replace_with_many(
209            attr.syntax(),
210            vec![
211                new_derive.syntax().clone().into(),
212                make.whitespace("\n").into(),
213                cfg_attr.syntax().clone().into(),
214            ],
215        );
216
217        if let Some(snippet_cap) = ctx.config.snippet_cap
218            && let Some(cfg_predicate) = meta.cfg_predicate()
219        {
220            let tabstop = edit.make_placeholder_snippet(snippet_cap);
221            editor.add_annotation(cfg_predicate.syntax(), tabstop);
222        }
223        edit.add_file_edits(ctx.vfs_file_id(), editor);
224    };
225
226    acc.add(
227        AssistId::refactor("wrap_unwrap_cfg_attr"),
228        format!("Wrap #[derive({path_text})] in `cfg_attr`",),
229        range,
230        handle_source_change,
231    );
232    Some(())
233}
234
235fn wrap_cfg_attrs(
236    acc: &mut Assists,
237    ctx: &AssistContext<'_, '_>,
238    attrs: Vec<ast::Attr>,
239) -> Option<()> {
240    let (first_attr, last_attr) = (attrs.first()?, attrs.last()?);
241    let range = first_attr.syntax().text_range().cover(last_attr.syntax().text_range());
242    let handle_source_change = |edit: &mut SourceChangeBuilder| {
243        let editor = edit.make_editor(first_attr.syntax());
244        let make = editor.make();
245        let meta =
246            make.cfg_attr_meta(make.cfg_flag("cfg"), attrs.iter().filter_map(|attr| attr.meta()));
247        let cfg_attr = if first_attr.excl_token().is_some() {
248            make.attr_inner(meta.clone().into())
249        } else {
250            make.attr_outer(meta.clone().into())
251        };
252
253        let syntax_range = first_attr.syntax().clone().into()..=last_attr.syntax().clone().into();
254        editor.replace_all(syntax_range, vec![cfg_attr.syntax().clone().into()]);
255
256        if let Some(snippet_cap) = ctx.config.snippet_cap
257            && let Some(cfg_flag) = meta.cfg_predicate()
258        {
259            let tabstop = edit.make_placeholder_snippet(snippet_cap);
260            editor.add_annotation(cfg_flag.syntax(), tabstop);
261        }
262        edit.add_file_edits(ctx.vfs_file_id(), editor);
263    };
264    acc.add(
265        AssistId::refactor("wrap_unwrap_cfg_attr"),
266        "Convert to `cfg_attr`",
267        range,
268        handle_source_change,
269    );
270    Some(())
271}
272
273fn unwrap_cfg_attr(
274    acc: &mut Assists,
275    ctx: &AssistContext<'_, '_>,
276    meta: ast::CfgAttrMeta,
277) -> Option<()> {
278    let top_attr = ast::Meta::from(meta.clone()).parent_attr()?;
279    let range = top_attr.syntax().text_range();
280    let inner_metas: Vec<ast::Meta> = meta.metas().collect();
281    if inner_metas.is_empty() {
282        return None;
283    }
284    let is_inner = top_attr.excl_token().is_some();
285    let indent = top_attr.indent_level();
286    acc.add(
287        AssistId::refactor("wrap_unwrap_cfg_attr"),
288        "Extract Inner Attributes from `cfg_attr`",
289        range,
290        |builder: &mut SourceChangeBuilder| {
291            let editor = builder.make_editor(top_attr.syntax());
292            let make = editor.make();
293            let mut elements = vec![];
294            for (i, meta) in inner_metas.into_iter().enumerate() {
295                if i > 0 {
296                    elements.push(make.whitespace(&format!("\n{indent}")).into());
297                }
298                let attr = if is_inner { make.attr_inner(meta) } else { make.attr_outer(meta) };
299                elements.push(attr.syntax().clone().into());
300            }
301            editor.replace_with_many(top_attr.syntax(), elements);
302            builder.add_file_edits(ctx.vfs_file_id(), editor);
303        },
304    );
305    Some(())
306}
307
#[cfg(test)]
mod tests {
    use crate::tests::check_assist;

    use super::*;

    /// Applies `wrap_unwrap_cfg_attr` to `ra_fixture_before` and asserts that the
    /// result equals `ra_fixture_after`.
    fn check(ra_fixture_before: &str, ra_fixture_after: &str) {
        check_assist(wrap_unwrap_cfg_attr, ra_fixture_before, ra_fixture_after);
    }

    #[test]
    fn test_basic_to_from_cfg_attr() {
        // Cursor on the attribute name wraps the whole `derive`.
        check(
            r#"
            #[derive$0(Debug)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[cfg_attr(${0:cfg}, derive(Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
        // A `cfg_attr` is unwrapped back into its inner attribute.
        check(
            r#"
            #[cfg_attr(debug_assertions, $0 derive(Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(Debug)]
            pub struct Test {
                test: u32,
            }
            "#,
        );
    }
    #[test]
    fn to_from_path_attr() {
        // Single path attribute on a field.
        check(
            r#"
            pub struct Test {
                #[foo$0]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[cfg_attr(${0:cfg}, foo)]
                test: u32,
            }
            "#,
        );
        check(
            r#"
            pub struct Test {
                #[cfg_attr(debug_assertions$0, foo)]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[foo]
                test: u32,
            }
            "#,
        );
        // A selection over several attributes wraps only the selected ones.
        check(
            r#"
            pub struct Test {
                #[other_attr]
                $0#[foo]
                #[bar]$0
                #[other_attr]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[other_attr]
                #[cfg_attr(${0:cfg}, foo, bar)]
                #[other_attr]
                test: u32,
            }
            "#,
        );
        // Unwrapping a multi-meta `cfg_attr` yields one attribute per line.
        check(
            r#"
            pub struct Test {
                #[cfg_attr(debug_assertions$0, foo, bar)]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[foo]
                #[bar]
                test: u32,
            }
            "#,
        );
    }
    #[test]
    fn to_from_eq_attr() {
        check(
            r#"
            pub struct Test {
                #[foo = "bar"$0]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[cfg_attr(${0:cfg}, foo = "bar")]
                test: u32,
            }
            "#,
        );
        check(
            r#"
            pub struct Test {
                #[cfg_attr(debug_assertions$0, foo = "bar")]
                test: u32,
            }
            "#,
            r#"
            pub struct Test {
                #[foo = "bar"]
                test: u32,
            }
            "#,
        );
    }
    #[test]
    fn inner_attrs() {
        // Inner attributes (`#![...]`) stay inner in both directions.
        check(
            r#"
            #![no_std$0]
            "#,
            r#"
            #![cfg_attr(${0:cfg}, no_std)]
            "#,
        );
        check(
            r#"
            #![cfg_attr(not(feature = "std")$0, no_std)]
            "#,
            r#"
            #![no_std]
            "#,
        );
    }
    #[test]
    fn test_derive_wrap() {
        // Splitting a bare ident out of a derive list.
        check(
            r#"
            #[derive(Debug$0, Clone, Copy)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive( Clone, Copy)]
            #[cfg_attr(${0:cfg}, derive(Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
        check(
            r#"
            #[derive(Clone, Debug$0, Copy)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(Clone,  Copy)]
            #[cfg_attr(${0:cfg}, derive(Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
    }
    #[test]
    fn test_derive_wrap_with_path() {
        // Splitting a multi-segment path out of a derive list.
        check(
            r#"
            #[derive(std::fmt::Debug$0, Clone, Copy)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive( Clone, Copy)]
            #[cfg_attr(${0:cfg}, derive(std::fmt::Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
        check(
            r#"
            #[derive(Clone, std::fmt::Debug$0, Copy)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(Clone, Copy)]
            #[cfg_attr(${0:cfg}, derive(std::fmt::Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
    }
    #[test]
    fn test_derive_wrap_at_end() {
        // The last entry takes its leading comma with it.
        check(
            r#"
            #[derive(std::fmt::Debug, Clone, Cop$0y)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(std::fmt::Debug, Clone)]
            #[cfg_attr(${0:cfg}, derive(Copy))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
        check(
            r#"
            #[derive(Clone, Copy, std::fmt::D$0ebug)]
            pub struct Test {
                test: u32,
            }
            "#,
            r#"
            #[derive(Clone, Copy)]
            #[cfg_attr(${0:cfg}, derive(std::fmt::Debug))]
            pub struct Test {
                test: u32,
            }
            "#,
        );
    }
}