syntax/syntax_editor/
edits.rs

1//! Structural editing for ast using `SyntaxEditor`
2
3use crate::{
4    AstToken, Direction, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, T,
5    algo::neighbor,
6    ast::{
7        self, AstNode, Fn, GenericParam, HasGenericParams, HasName, edit::IndentLevel, make,
8        syntax_factory::SyntaxFactory,
9    },
10    syntax_editor::{Position, SyntaxEditor},
11};
12
13impl SyntaxEditor {
14    /// Adds a new generic param to the function using `SyntaxEditor`
15    pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) {
16        match function.generic_param_list() {
17            Some(generic_param_list) => match generic_param_list.generic_params().last() {
18                Some(last_param) => {
19                    // There exists a generic param list and it's not empty
20                    let position = generic_param_list.r_angle_token().map_or_else(
21                        || Position::last_child_of(function.syntax()),
22                        Position::before,
23                    );
24
25                    if last_param
26                        .syntax()
27                        .next_sibling_or_token()
28                        .is_some_and(|it| it.kind() == SyntaxKind::COMMA)
29                    {
30                        self.insert(
31                            Position::after(last_param.syntax()),
32                            new_param.syntax().clone(),
33                        );
34                        self.insert(
35                            Position::after(last_param.syntax()),
36                            make::token(SyntaxKind::WHITESPACE),
37                        );
38                        self.insert(
39                            Position::after(last_param.syntax()),
40                            make::token(SyntaxKind::COMMA),
41                        );
42                    } else {
43                        let elements = vec![
44                            make::token(SyntaxKind::COMMA).into(),
45                            make::token(SyntaxKind::WHITESPACE).into(),
46                            new_param.syntax().clone().into(),
47                        ];
48                        self.insert_all(position, elements);
49                    }
50                }
51                None => {
52                    // There exists a generic param list but it's empty
53                    let position = Position::after(generic_param_list.l_angle_token().unwrap());
54                    self.insert(position, new_param.syntax());
55                }
56            },
57            None => {
58                // There was no generic param list
59                let position = if let Some(name) = function.name() {
60                    Position::after(name.syntax)
61                } else if let Some(fn_token) = function.fn_token() {
62                    Position::after(fn_token)
63                } else if let Some(param_list) = function.param_list() {
64                    Position::before(param_list.syntax)
65                } else {
66                    Position::last_child_of(function.syntax())
67                };
68                let elements = vec![
69                    make::token(SyntaxKind::L_ANGLE).into(),
70                    new_param.syntax().clone().into(),
71                    make::token(SyntaxKind::R_ANGLE).into(),
72                ];
73                self.insert_all(position, elements);
74            }
75        }
76    }
77}
78
79fn get_or_insert_comma_after(editor: &mut SyntaxEditor, syntax: &SyntaxNode) -> SyntaxToken {
80    let make = SyntaxFactory::without_mappings();
81    match syntax
82        .siblings_with_tokens(Direction::Next)
83        .filter_map(|it| it.into_token())
84        .find(|it| it.kind() == T![,])
85    {
86        Some(it) => it,
87        None => {
88            let comma = make.token(T![,]);
89            editor.insert(Position::after(syntax), &comma);
90            comma
91        }
92    }
93}
94
95impl ast::AssocItemList {
96    /// Adds a new associated item after all of the existing associated items.
97    ///
98    /// Attention! This function does align the first line of `item` with respect to `self`,
99    /// but it does _not_ change indentation of other lines (if any).
100    pub fn add_items(&self, editor: &mut SyntaxEditor, items: Vec<ast::AssocItem>) {
101        let (indent, position, whitespace) = match self.assoc_items().last() {
102            Some(last_item) => (
103                IndentLevel::from_node(last_item.syntax()),
104                Position::after(last_item.syntax()),
105                "\n\n",
106            ),
107            None => match self.l_curly_token() {
108                Some(l_curly) => {
109                    normalize_ws_between_braces(editor, self.syntax());
110                    (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly), "\n")
111                }
112                None => (IndentLevel::single(), Position::last_child_of(self.syntax()), "\n"),
113            },
114        };
115
116        let elements: Vec<SyntaxElement> = items
117            .into_iter()
118            .enumerate()
119            .flat_map(|(i, item)| {
120                let whitespace = if i != 0 { "\n\n" } else { whitespace };
121                vec![
122                    make::tokens::whitespace(&format!("{whitespace}{indent}")).into(),
123                    item.syntax().clone().into(),
124                ]
125            })
126            .collect();
127        editor.insert_all(position, elements);
128    }
129}
130
131impl ast::VariantList {
132    pub fn add_variant(&self, editor: &mut SyntaxEditor, variant: &ast::Variant) {
133        let make = SyntaxFactory::without_mappings();
134        let (indent, position) = match self.variants().last() {
135            Some(last_item) => (
136                IndentLevel::from_node(last_item.syntax()),
137                Position::after(get_or_insert_comma_after(editor, last_item.syntax())),
138            ),
139            None => match self.l_curly_token() {
140                Some(l_curly) => {
141                    normalize_ws_between_braces(editor, self.syntax());
142                    (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly))
143                }
144                None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
145            },
146        };
147        let elements: Vec<SyntaxElement> = vec![
148            make.whitespace(&format!("{}{indent}", "\n")).into(),
149            variant.syntax().clone().into(),
150            make.token(T![,]).into(),
151        ];
152        editor.insert_all(position, elements);
153    }
154}
155
156impl ast::Fn {
157    pub fn replace_or_insert_body(&self, editor: &mut SyntaxEditor, body: ast::BlockExpr) {
158        if let Some(old_body) = self.body() {
159            editor.replace(old_body.syntax(), body.syntax());
160        } else {
161            let single_space = make::tokens::single_space();
162            let elements = vec![single_space.into(), body.syntax().clone().into()];
163
164            if let Some(semicolon) = self.semicolon_token() {
165                editor.replace_with_many(semicolon, elements);
166            } else {
167                editor.insert_all(Position::last_child_of(self.syntax()), elements);
168            }
169        }
170    }
171}
172
173fn normalize_ws_between_braces(editor: &mut SyntaxEditor, node: &SyntaxNode) -> Option<()> {
174    let make = SyntaxFactory::without_mappings();
175    let l = node
176        .children_with_tokens()
177        .filter_map(|it| it.into_token())
178        .find(|it| it.kind() == T!['{'])?;
179    let r = node
180        .children_with_tokens()
181        .filter_map(|it| it.into_token())
182        .find(|it| it.kind() == T!['}'])?;
183
184    let indent = IndentLevel::from_node(node);
185
186    match l.next_sibling_or_token() {
187        Some(ws) if ws.kind() == SyntaxKind::WHITESPACE => {
188            if ws.next_sibling_or_token()?.into_token()? == r {
189                editor.replace(ws, make.whitespace(&format!("\n{indent}")));
190            }
191        }
192        Some(ws) if ws.kind() == T!['}'] => {
193            editor.insert(Position::after(l), make.whitespace(&format!("\n{indent}")));
194        }
195        _ => (),
196    }
197    Some(())
198}
199
/// AST nodes that can queue their own removal on a [`SyntaxEditor`].
pub trait Removable: AstNode {
    /// Records on `editor` the edits that remove this node.
    ///
    /// Implementations may also remove or adjust neighbouring tokens (for example separators
    /// or whitespace) that belong to the node being removed.
    fn remove(&self, editor: &mut SyntaxEditor);
}
203
204impl Removable for ast::TypeBoundList {
205    fn remove(&self, editor: &mut SyntaxEditor) {
206        match self.syntax().siblings_with_tokens(Direction::Prev).find(|it| it.kind() == T![:]) {
207            Some(colon) => editor.delete_all(colon..=self.syntax().clone().into()),
208            None => editor.delete(self.syntax()),
209        }
210    }
211}
212
213impl Removable for ast::Use {
214    fn remove(&self, editor: &mut SyntaxEditor) {
215        let make = SyntaxFactory::without_mappings();
216
217        let next_ws = self
218            .syntax()
219            .next_sibling_or_token()
220            .and_then(|it| it.into_token())
221            .and_then(ast::Whitespace::cast);
222        if let Some(next_ws) = next_ws {
223            let ws_text = next_ws.syntax().text();
224            if let Some(rest) = ws_text.strip_prefix('\n') {
225                if rest.is_empty() {
226                    editor.delete(next_ws.syntax());
227                } else {
228                    editor.replace(next_ws.syntax(), make.whitespace(rest));
229                }
230            }
231        }
232
233        editor.delete(self.syntax());
234    }
235}
236
237impl Removable for ast::UseTree {
238    fn remove(&self, editor: &mut SyntaxEditor) {
239        for dir in [Direction::Next, Direction::Prev] {
240            if let Some(next_use_tree) = neighbor(self, dir) {
241                let separators = self
242                    .syntax()
243                    .siblings_with_tokens(dir)
244                    .skip(1)
245                    .take_while(|it| it.as_node() != Some(next_use_tree.syntax()));
246                for sep in separators {
247                    editor.delete(sep);
248                }
249                break;
250            }
251        }
252        editor.delete(self.syntax());
253    }
254}
255
#[cfg(test)]
mod tests {
    use parser::Edition;
    use stdx::trim_indent;
    use test_utils::assert_eq_text;

    use crate::SourceFile;

    use super::*;

    /// Parses `text` and returns a detached copy of the first descendant that casts to `N`.
    ///
    /// Panics if no such node exists, or if the detached node does not start at offset 0.
    fn ast_from_text<N: AstNode>(text: &str) -> N {
        let parse = SourceFile::parse(text, Edition::CURRENT);
        let found = parse.tree().syntax().descendants().find_map(N::cast);
        let Some(node) = found else {
            let node = std::any::type_name::<N>();
            panic!("Failed to make ast node `{node}` from text {text}")
        };
        let node = node.clone_subtree();
        assert_eq!(node.syntax().text_range().start(), 0.into());
        node
    }

    /// Adds `variant` to the enum parsed from `before` and compares the result with
    /// `expected` (both sides trimmed and with common indentation removed).
    fn check_add_variant(before: &str, expected: &str, variant: ast::Variant) {
        let enum_ = ast_from_text::<ast::Enum>(before);
        let mut editor = SyntaxEditor::new(enum_.syntax().clone());
        if let Some(variant_list) = enum_.variant_list() {
            variant_list.add_variant(&mut editor, &variant);
        }
        let after = editor.finish().new_root.to_string();
        assert_eq_text!(&trim_indent(expected.trim()), &trim_indent(after.trim()));
    }

    #[test]
    fn add_variant_to_empty_enum() {
        let make = SyntaxFactory::without_mappings();
        let variant = make.variant(None, make.name("Bar"), None, None);

        let before = r#"
enum Foo {}
"#;
        let expected = r#"
enum Foo {
    Bar,
}
"#;
        check_add_variant(before, expected, variant);
    }

    #[test]
    fn add_variant_to_non_empty_enum() {
        let make = SyntaxFactory::without_mappings();
        let variant = make.variant(None, make.name("Baz"), None, None);

        let before = r#"
enum Foo {
    Bar,
}
"#;
        let expected = r#"
enum Foo {
    Bar,
    Baz,
}
"#;
        check_add_variant(before, expected, variant);
    }

    #[test]
    fn add_variant_with_tuple_field_list() {
        let make = SyntaxFactory::without_mappings();
        let tuple_fields = make.tuple_field_list([make.tuple_field(None, make.ty("bool"))]);
        let variant = make.variant(None, make.name("Baz"), Some(tuple_fields.into()), None);

        let before = r#"
enum Foo {
    Bar,
}
"#;
        let expected = r#"
enum Foo {
    Bar,
    Baz(bool),
}
"#;
        check_add_variant(before, expected, variant);
    }

    #[test]
    fn add_variant_with_record_field_list() {
        let make = SyntaxFactory::without_mappings();
        let record_fields =
            make.record_field_list([make.record_field(None, make.name("x"), make.ty("bool"))]);
        let variant = make.variant(None, make.name("Baz"), Some(record_fields.into()), None);

        let before = r#"
enum Foo {
    Bar,
}
"#;
        let expected = r#"
enum Foo {
    Bar,
    Baz { x: bool },
}
"#;
        check_add_variant(before, expected, variant);
    }
}