1use crate::{
4 AstToken, Direction, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, T,
5 algo::neighbor,
6 ast::{
7 self, AstNode, Fn, GenericParam, HasGenericParams, HasName, edit::IndentLevel, make,
8 syntax_factory::SyntaxFactory,
9 },
10 syntax_editor::{Position, SyntaxEditor},
11};
12
13impl SyntaxEditor {
14 pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) {
16 match function.generic_param_list() {
17 Some(generic_param_list) => match generic_param_list.generic_params().last() {
18 Some(last_param) => {
19 let position = generic_param_list.r_angle_token().map_or_else(
21 || Position::last_child_of(function.syntax()),
22 Position::before,
23 );
24
25 if last_param
26 .syntax()
27 .next_sibling_or_token()
28 .is_some_and(|it| it.kind() == SyntaxKind::COMMA)
29 {
30 self.insert(
31 Position::after(last_param.syntax()),
32 new_param.syntax().clone(),
33 );
34 self.insert(
35 Position::after(last_param.syntax()),
36 make::token(SyntaxKind::WHITESPACE),
37 );
38 self.insert(
39 Position::after(last_param.syntax()),
40 make::token(SyntaxKind::COMMA),
41 );
42 } else {
43 let elements = vec![
44 make::token(SyntaxKind::COMMA).into(),
45 make::token(SyntaxKind::WHITESPACE).into(),
46 new_param.syntax().clone().into(),
47 ];
48 self.insert_all(position, elements);
49 }
50 }
51 None => {
52 let position = Position::after(generic_param_list.l_angle_token().unwrap());
54 self.insert(position, new_param.syntax());
55 }
56 },
57 None => {
58 let position = if let Some(name) = function.name() {
60 Position::after(name.syntax)
61 } else if let Some(fn_token) = function.fn_token() {
62 Position::after(fn_token)
63 } else if let Some(param_list) = function.param_list() {
64 Position::before(param_list.syntax)
65 } else {
66 Position::last_child_of(function.syntax())
67 };
68 let elements = vec![
69 make::token(SyntaxKind::L_ANGLE).into(),
70 new_param.syntax().clone().into(),
71 make::token(SyntaxKind::R_ANGLE).into(),
72 ];
73 self.insert_all(position, elements);
74 }
75 }
76 }
77}
78
79fn get_or_insert_comma_after(editor: &mut SyntaxEditor, syntax: &SyntaxNode) -> SyntaxToken {
80 let make = SyntaxFactory::without_mappings();
81 match syntax
82 .siblings_with_tokens(Direction::Next)
83 .filter_map(|it| it.into_token())
84 .find(|it| it.kind() == T![,])
85 {
86 Some(it) => it,
87 None => {
88 let comma = make.token(T![,]);
89 editor.insert(Position::after(syntax), &comma);
90 comma
91 }
92 }
93}
94
95impl ast::AssocItemList {
96 pub fn add_items(&self, editor: &mut SyntaxEditor, items: Vec<ast::AssocItem>) {
101 let (indent, position, whitespace) = match self.assoc_items().last() {
102 Some(last_item) => (
103 IndentLevel::from_node(last_item.syntax()),
104 Position::after(last_item.syntax()),
105 "\n\n",
106 ),
107 None => match self.l_curly_token() {
108 Some(l_curly) => {
109 normalize_ws_between_braces(editor, self.syntax());
110 (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly), "\n")
111 }
112 None => (IndentLevel::single(), Position::last_child_of(self.syntax()), "\n"),
113 },
114 };
115
116 let elements: Vec<SyntaxElement> = items
117 .into_iter()
118 .enumerate()
119 .flat_map(|(i, item)| {
120 let whitespace = if i != 0 { "\n\n" } else { whitespace };
121 vec![
122 make::tokens::whitespace(&format!("{whitespace}{indent}")).into(),
123 item.syntax().clone().into(),
124 ]
125 })
126 .collect();
127 editor.insert_all(position, elements);
128 }
129}
130
131impl ast::VariantList {
132 pub fn add_variant(&self, editor: &mut SyntaxEditor, variant: &ast::Variant) {
133 let make = SyntaxFactory::without_mappings();
134 let (indent, position) = match self.variants().last() {
135 Some(last_item) => (
136 IndentLevel::from_node(last_item.syntax()),
137 Position::after(get_or_insert_comma_after(editor, last_item.syntax())),
138 ),
139 None => match self.l_curly_token() {
140 Some(l_curly) => {
141 normalize_ws_between_braces(editor, self.syntax());
142 (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly))
143 }
144 None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
145 },
146 };
147 let elements: Vec<SyntaxElement> = vec![
148 make.whitespace(&format!("{}{indent}", "\n")).into(),
149 variant.syntax().clone().into(),
150 make.token(T![,]).into(),
151 ];
152 editor.insert_all(position, elements);
153 }
154}
155
156impl ast::Fn {
157 pub fn replace_or_insert_body(&self, editor: &mut SyntaxEditor, body: ast::BlockExpr) {
158 if let Some(old_body) = self.body() {
159 editor.replace(old_body.syntax(), body.syntax());
160 } else {
161 let single_space = make::tokens::single_space();
162 let elements = vec![single_space.into(), body.syntax().clone().into()];
163
164 if let Some(semicolon) = self.semicolon_token() {
165 editor.replace_with_many(semicolon, elements);
166 } else {
167 editor.insert_all(Position::last_child_of(self.syntax()), elements);
168 }
169 }
170 }
171}
172
173fn normalize_ws_between_braces(editor: &mut SyntaxEditor, node: &SyntaxNode) -> Option<()> {
174 let make = SyntaxFactory::without_mappings();
175 let l = node
176 .children_with_tokens()
177 .filter_map(|it| it.into_token())
178 .find(|it| it.kind() == T!['{'])?;
179 let r = node
180 .children_with_tokens()
181 .filter_map(|it| it.into_token())
182 .find(|it| it.kind() == T!['}'])?;
183
184 let indent = IndentLevel::from_node(node);
185
186 match l.next_sibling_or_token() {
187 Some(ws) if ws.kind() == SyntaxKind::WHITESPACE => {
188 if ws.next_sibling_or_token()?.into_token()? == r {
189 editor.replace(ws, make.whitespace(&format!("\n{indent}")));
190 }
191 }
192 Some(ws) if ws.kind() == T!['}'] => {
193 editor.insert(Position::after(l), make.whitespace(&format!("\n{indent}")));
194 }
195 _ => (),
196 }
197 Some(())
198}
199
/// An AST node that knows how to record its own removal on a [`SyntaxEditor`].
pub trait Removable: AstNode {
    /// Records on `editor` the edits that remove `self` from the tree being edited.
    fn remove(&self, editor: &mut SyntaxEditor);
}
203
204impl Removable for ast::TypeBoundList {
205 fn remove(&self, editor: &mut SyntaxEditor) {
206 match self.syntax().siblings_with_tokens(Direction::Prev).find(|it| it.kind() == T![:]) {
207 Some(colon) => editor.delete_all(colon..=self.syntax().clone().into()),
208 None => editor.delete(self.syntax()),
209 }
210 }
211}
212
213impl Removable for ast::Use {
214 fn remove(&self, editor: &mut SyntaxEditor) {
215 let make = SyntaxFactory::without_mappings();
216
217 let next_ws = self
218 .syntax()
219 .next_sibling_or_token()
220 .and_then(|it| it.into_token())
221 .and_then(ast::Whitespace::cast);
222 if let Some(next_ws) = next_ws {
223 let ws_text = next_ws.syntax().text();
224 if let Some(rest) = ws_text.strip_prefix('\n') {
225 if rest.is_empty() {
226 editor.delete(next_ws.syntax());
227 } else {
228 editor.replace(next_ws.syntax(), make.whitespace(rest));
229 }
230 }
231 }
232
233 editor.delete(self.syntax());
234 }
235}
236
237impl Removable for ast::UseTree {
238 fn remove(&self, editor: &mut SyntaxEditor) {
239 for dir in [Direction::Next, Direction::Prev] {
240 if let Some(next_use_tree) = neighbor(self, dir) {
241 let separators = self
242 .syntax()
243 .siblings_with_tokens(dir)
244 .skip(1)
245 .take_while(|it| it.as_node() != Some(next_use_tree.syntax()));
246 for sep in separators {
247 editor.delete(sep);
248 }
249 break;
250 }
251 }
252 editor.delete(self.syntax());
253 }
254}
255
// Tests for `ast::VariantList::add_variant`.
#[cfg(test)]
mod tests {
    use parser::Edition;
    use stdx::trim_indent;
    use test_utils::assert_eq_text;

    use crate::SourceFile;

    use super::*;

    /// Parses `text` as a source file and returns a `clone_subtree` copy of the first
    /// descendant that casts to `N`.
    ///
    /// Panics if no descendant casts to `N`, or if the copied node does not start at
    /// offset 0.
    fn ast_from_text<N: AstNode>(text: &str) -> N {
        let parse = SourceFile::parse(text, Edition::CURRENT);
        let node = match parse.tree().syntax().descendants().find_map(N::cast) {
            Some(it) => it,
            None => {
                let node = std::any::type_name::<N>();
                panic!("Failed to make ast node `{node}` from text {text}")
            }
        };
        let node = node.clone_subtree();
        assert_eq!(node.syntax().text_range().start(), 0.into());
        node
    }

    // An empty `{}` is expanded so the new variant sits on its own indented line.
    #[test]
    fn add_variant_to_empty_enum() {
        let make = SyntaxFactory::without_mappings();
        let variant = make.variant(None, make.name("Bar"), None, None);

        check_add_variant(
            r#"
enum Foo {}
"#,
            r#"
enum Foo {
    Bar,
}
"#,
            variant,
        );
    }

    // The new variant is appended after the existing one, reusing its trailing comma.
    #[test]
    fn add_variant_to_non_empty_enum() {
        let make = SyntaxFactory::without_mappings();
        let variant = make.variant(None, make.name("Baz"), None, None);

        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz,
}
"#,
            variant,
        );
    }

    // Same as above, with a variant that carries a tuple field list.
    #[test]
    fn add_variant_with_tuple_field_list() {
        let make = SyntaxFactory::without_mappings();
        let variant = make.variant(
            None,
            make.name("Baz"),
            Some(make.tuple_field_list([make.tuple_field(None, make.ty("bool"))]).into()),
            None,
        );

        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz(bool),
}
"#,
            variant,
        );
    }

    // Same as above, with a variant that carries a record field list.
    #[test]
    fn add_variant_with_record_field_list() {
        let make = SyntaxFactory::without_mappings();
        let variant = make.variant(
            None,
            make.name("Baz"),
            Some(
                make.record_field_list([make.record_field(None, make.name("x"), make.ty("bool"))])
                    .into(),
            ),
            None,
        );

        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz { x: bool },
}
"#,
            variant,
        );
    }

    /// Parses the enum in `before`, adds `variant` to its variant list (if it has one)
    /// through a `SyntaxEditor`, and asserts that the edited tree's text equals `expected`.
    ///
    /// Both texts are trimmed and passed through `trim_indent` before being compared.
    fn check_add_variant(before: &str, expected: &str, variant: ast::Variant) {
        let enum_ = ast_from_text::<ast::Enum>(before);
        let mut editor = SyntaxEditor::new(enum_.syntax().clone());
        if let Some(it) = enum_.variant_list() {
            it.add_variant(&mut editor, &variant)
        }
        let edit = editor.finish();
        let after = edit.new_root.to_string();
        assert_eq_text!(&trim_indent(expected.trim()), &trim_indent(after.trim()));
    }
}