1use rustc_hash::FxHashMap;
3use syntax::{NodeOrToken, SyntaxElement, SyntaxNode};
4
5use crate::{FxIndexMap, text_edit::TextEditBuilder};
6
/// Anchor in the `from` tree at which a run of inserted elements is placed.
#[derive(Debug, Hash, PartialEq, Eq)]
enum TreeDiffInsertPos {
    /// Insert right after this element, i.e. at the end of its text range.
    After(SyntaxElement),
    /// Insert as the first child of this element, i.e. at the start of its text range.
    AsFirstChild(SyntaxElement),
}
12
/// The edits that turn one syntax tree into another, as produced by [`diff`].
///
/// Replaced elements, deleted elements and insertion anchors all belong to the `from` tree;
/// replacement values and inserted elements belong to the `to` tree and contribute their text.
#[derive(Debug)]
pub struct TreeDiff {
    /// `from` element -> `to` element whose text replaces it.
    replacements: FxHashMap<SyntaxElement, SyntaxElement>,
    /// `from` elements to remove, in the order they were encountered.
    deletions: Vec<SyntaxElement>,
    /// Anchor in the `from` tree -> `to` elements to insert there, in order. An index map
    /// (not a hash map) is used so anchors keep the order in which they were discovered.
    insertions: FxIndexMap<TreeDiffInsertPos, Vec<SyntaxElement>>,
}
20
21impl TreeDiff {
22 pub fn into_text_edit(&self, builder: &mut TextEditBuilder) {
23 let _p = tracing::info_span!("into_text_edit").entered();
24
25 for (anchor, to) in &self.insertions {
26 let offset = match anchor {
27 TreeDiffInsertPos::After(it) => it.text_range().end(),
28 TreeDiffInsertPos::AsFirstChild(it) => it.text_range().start(),
29 };
30 to.iter().for_each(|to| builder.insert(offset, to.to_string()));
31 }
32 for (from, to) in &self.replacements {
33 builder.replace(from.text_range(), to.to_string());
34 }
35 for text_range in self.deletions.iter().map(SyntaxElement::text_range) {
36 builder.delete(text_range);
37 }
38 }
39
40 pub fn is_empty(&self) -> bool {
41 self.replacements.is_empty() && self.deletions.is_empty() && self.insertions.is_empty()
42 }
43}
44
45pub fn diff(from: &SyntaxNode, to: &SyntaxNode) -> TreeDiff {
52 let _p = tracing::info_span!("diff").entered();
53
54 let mut diff = TreeDiff {
55 replacements: FxHashMap::default(),
56 insertions: FxIndexMap::default(),
57 deletions: Vec::new(),
58 };
59 let (from, to) = (from.clone().into(), to.clone().into());
60
61 if !syntax_element_eq(&from, &to) {
62 go(&mut diff, from, to);
63 }
64 return diff;
65
66 fn syntax_element_eq(lhs: &SyntaxElement, rhs: &SyntaxElement) -> bool {
67 lhs.kind() == rhs.kind()
68 && lhs.text_range().len() == rhs.text_range().len()
69 && match (&lhs, &rhs) {
70 (NodeOrToken::Node(lhs), NodeOrToken::Node(rhs)) => {
71 lhs == rhs || lhs.text() == rhs.text()
72 }
73 (NodeOrToken::Token(lhs), NodeOrToken::Token(rhs)) => lhs.text() == rhs.text(),
74 _ => false,
75 }
76 }
77
78 fn go(diff: &mut TreeDiff, lhs: SyntaxElement, rhs: SyntaxElement) {
80 let (lhs, rhs) = match lhs.as_node().zip(rhs.as_node()) {
81 Some((lhs, rhs)) => (lhs, rhs),
82 _ => {
83 cov_mark::hit!(diff_node_token_replace);
84 diff.replacements.insert(lhs, rhs);
85 return;
86 }
87 };
88
89 let mut look_ahead_scratch = Vec::default();
90
91 let mut rhs_children = rhs.children_with_tokens();
92 let mut lhs_children = lhs.children_with_tokens();
93 let mut last_lhs = None;
94 loop {
95 let lhs_child = lhs_children.next();
96 match (lhs_child.clone(), rhs_children.next()) {
97 (None, None) => break,
98 (None, Some(element)) => {
99 let insert_pos = match last_lhs.clone() {
100 Some(prev) => {
101 cov_mark::hit!(diff_insert);
102 TreeDiffInsertPos::After(prev)
103 }
104 None => {
106 cov_mark::hit!(diff_insert_as_first_child);
107 TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
108 }
109 };
110 diff.insertions.entry(insert_pos).or_default().push(element);
111 }
112 (Some(element), None) => {
113 cov_mark::hit!(diff_delete);
114 diff.deletions.push(element);
115 }
116 (Some(ref lhs_ele), Some(ref rhs_ele)) if syntax_element_eq(lhs_ele, rhs_ele) => {}
117 (Some(lhs_ele), Some(rhs_ele)) => {
118 look_ahead_scratch.push(rhs_ele.clone());
123 let mut rhs_children_clone = rhs_children.clone();
124 let mut insert = false;
125 for rhs_child in &mut rhs_children_clone {
126 if syntax_element_eq(&lhs_ele, &rhs_child) {
127 cov_mark::hit!(diff_insertions);
128 insert = true;
129 break;
130 }
131 look_ahead_scratch.push(rhs_child);
132 }
133 let drain = look_ahead_scratch.drain(..);
134 if insert {
135 let insert_pos = if let Some(prev) = last_lhs.clone().filter(|_| insert) {
136 TreeDiffInsertPos::After(prev)
137 } else {
138 cov_mark::hit!(insert_first_child);
139 TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
140 };
141
142 diff.insertions.entry(insert_pos).or_default().extend(drain);
143 rhs_children = rhs_children_clone;
144 } else {
145 go(diff, lhs_ele, rhs_ele);
146 }
147 }
148 }
149 last_lhs = lhs_child.or(last_lhs);
150 }
151 }
152}
153
// Snapshot tests for `diff`: each case checks a textual rendering of the computed diff and
// that applying the diff to `from` reproduces `to`.
#[cfg(test)]
mod tests {
    use expect_test::{Expect, expect};
    use itertools::Itertools;
    use parser::{Edition, SyntaxKind};
    use syntax::{AstNode, SourceFile, SyntaxElement};

    use crate::text_edit::TextEdit;

    #[test]
    fn replace_node_token() {
        cov_mark::check!(diff_node_token_replace);
        check_diff(
            r#"use node;"#,
            r#"ident"#,
            expect![[r#"
                insertions:



                replacements:

                Line 0: Token(USE_KW@0..3 "use") -> ident

                deletions:

                Line 1: " "
                Line 1: node
                Line 1: ;
            "#]],
        );
    }

    #[test]
    fn replace_parent() {
        cov_mark::check!(diff_insert_as_first_child);
        check_diff(
            r#""#,
            r#"use foo::bar;"#,
            expect![[r#"
                insertions:

                Line 0: AsFirstChild(Node(SOURCE_FILE@0..0))
                -> use foo::bar;

                replacements:



                deletions:


            "#]],
        );
    }

    #[test]
    fn insert_last() {
        cov_mark::check!(diff_insert);
        check_diff(
            r#"
use foo;
use bar;"#,
            r#"
use foo;
use bar;
use baz;"#,
            expect![[r#"
                insertions:

                Line 2: After(Node(USE@10..18))
                -> "\n"
                -> use baz;

                replacements:



                deletions:


            "#]],
        );
    }

    #[test]
    fn insert_middle() {
        check_diff(
            r#"
use foo;
use baz;"#,
            r#"
use foo;
use bar;
use baz;"#,
            expect![[r#"
                insertions:

                Line 2: After(Token(WHITESPACE@9..10 "\n"))
                -> use bar;
                -> "\n"

                replacements:



                deletions:


            "#]],
        )
    }

    #[test]
    fn insert_first() {
        check_diff(
            r#"
use bar;
use baz;"#,
            r#"
use foo;
use bar;
use baz;"#,
            expect![[r#"
                insertions:

                Line 0: After(Token(WHITESPACE@0..1 "\n"))
                -> use foo;
                -> "\n"

                replacements:



                deletions:


            "#]],
        )
    }

    #[test]
    fn first_child_insertion() {
        cov_mark::check!(insert_first_child);
        check_diff(
            r#"fn main() {
        stdi
    }"#,
            r#"use foo::bar;

    fn main() {
        stdi
    }"#,
            expect![[r#"
                insertions:

                Line 0: AsFirstChild(Node(SOURCE_FILE@0..30))
                -> use foo::bar;
                -> "\n\n    "

                replacements:



                deletions:


            "#]],
        );
    }

    #[test]
    fn delete_last() {
        cov_mark::check!(diff_delete);
        check_diff(
            r#"use foo;
            use bar;"#,
            r#"use foo;"#,
            expect![[r#"
                insertions:



                replacements:



                deletions:

                Line 1: "\n            "
                Line 2: use bar;
            "#]],
        );
    }

    // The expected diff here is not minimal: the trailing items are both deleted and
    // re-inserted, but the round-trip check still holds.
    #[test]
    fn delete_middle() {
        cov_mark::check!(diff_insertions);
        check_diff(
            r#"
use expect_test::{expect, Expect};
use text_edit::TextEdit;

use crate::AstNode;
"#,
            r#"
use expect_test::{expect, Expect};

use crate::AstNode;
"#,
            expect![[r#"
                insertions:

                Line 1: After(Node(USE@1..35))
                -> "\n\n"
                -> use crate::AstNode;

                replacements:



                deletions:

                Line 2: use text_edit::TextEdit;
                Line 3: "\n\n"
                Line 4: use crate::AstNode;
                Line 5: "\n"
            "#]],
        )
    }

    #[test]
    fn delete_first() {
        check_diff(
            r#"
use text_edit::TextEdit;

use crate::AstNode;
"#,
            r#"
use crate::AstNode;
"#,
            expect![[r#"
                insertions:



                replacements:

                Line 2: Token(IDENT@5..14 "text_edit") -> crate
                Line 2: Token(IDENT@16..24 "TextEdit") -> AstNode
                Line 2: Token(WHITESPACE@25..27 "\n\n") -> "\n"

                deletions:

                Line 3: use crate::AstNode;
                Line 4: "\n"
            "#]],
        )
    }

    #[test]
    fn merge_use() {
        check_diff(
            r#"
use std::{
    fmt,
    hash::BuildHasherDefault,
    ops::{self, RangeInclusive},
};
"#,
            r#"
use std::fmt;
use std::hash::BuildHasherDefault;
use std::ops::{self, RangeInclusive};
"#,
            expect![[r#"
                insertions:

                Line 2: After(Node(PATH_SEGMENT@5..8))
                -> ::
                -> fmt
                Line 6: After(Token(WHITESPACE@86..87 "\n"))
                -> use std::hash::BuildHasherDefault;
                -> "\n"
                -> use std::ops::{self, RangeInclusive};
                -> "\n"

                replacements:

                Line 2: Token(IDENT@5..8 "std") -> std

                deletions:

                Line 2: ::
                Line 2: {
                    fmt,
                    hash::BuildHasherDefault,
                    ops::{self, RangeInclusive},
                }
            "#]],
        )
    }

    #[test]
    fn early_return_assist() {
        check_diff(
            r#"
fn main() {
    if let Ok(x) = Err(92) {
        foo(x);
    }
}
            "#,
            r#"
fn main() {
    let x = match Err(92) {
        Ok(it) => it,
        _ => return,
    };
    foo(x);
}
            "#,
            expect![[r#"
                insertions:

                Line 3: After(Node(BLOCK_EXPR@40..63))
                -> " "
                -> match Err(92) {
                        Ok(it) => it,
                        _ => return,
                    }
                -> ;
                Line 3: After(Node(IF_EXPR@17..63))
                -> "\n    "
                -> foo(x);

                replacements:

                Line 3: Token(IF_KW@17..19 "if") -> let
                Line 3: Token(LET_KW@20..23 "let") -> x
                Line 3: Node(BLOCK_EXPR@40..63) -> =

                deletions:

                Line 3: " "
                Line 3: Ok(x)
                Line 3: " "
                Line 3: =
                Line 3: " "
                Line 3: Err(92)
            "#]],
        )
    }

    /// Diffs `from` against `to`, compares a textual rendering of the diff with
    /// `expected_diff`, and asserts that applying the diff to `from` yields exactly `to`.
    fn check_diff(from: &str, to: &str, expected_diff: Expect) {
        let from_node = SourceFile::parse(from, Edition::CURRENT).tree().syntax().clone();
        let to_node = SourceFile::parse(to, Edition::CURRENT).tree().syntax().clone();
        let diff = super::diff(&from_node, &to_node);

        // Rough locator: the number of lines in `from` before the element's start offset.
        let line_number =
            |syn: &SyntaxElement| from[..syn.text_range().start().into()].lines().count();

        // Whitespace is rendered with `Debug` so newlines and indentation stay visible.
        let fmt_syntax = |syn: &SyntaxElement| match syn.kind() {
            SyntaxKind::WHITESPACE => format!("{:?}", syn.to_string()),
            _ => format!("{syn}"),
        };

        let insertions =
            diff.insertions.iter().format_with("\n", |(k, v), f| -> Result<(), std::fmt::Error> {
                f(&format!(
                    "Line {}: {:?}\n-> {}",
                    line_number(match k {
                        super::TreeDiffInsertPos::After(syn) => syn,
                        super::TreeDiffInsertPos::AsFirstChild(syn) => syn,
                    }),
                    k,
                    v.iter().format_with("\n-> ", |v, f| f(&fmt_syntax(v)))
                ))
            });

        // Replacements live in a hash map, so sort them by offset for deterministic output.
        let replacements = diff
            .replacements
            .iter()
            .sorted_by_key(|(syntax, _)| syntax.text_range().start())
            .format_with("\n", |(k, v), f| {
                f(&format!("Line {}: {k:?} -> {}", line_number(k), fmt_syntax(v)))
            });

        let deletions = diff
            .deletions
            .iter()
            .format_with("\n", |v, f| f(&format!("Line {}: {}", line_number(v), fmt_syntax(v))));

        let actual = format!(
            "insertions:\n\n{insertions}\n\nreplacements:\n\n{replacements}\n\ndeletions:\n\n{deletions}\n"
        );
        expected_diff.assert_eq(&actual);

        // Round-trip check: the diff must really turn `from` into `to`.
        let mut from = from.to_owned();
        let mut text_edit = TextEdit::builder();
        diff.into_text_edit(&mut text_edit);
        text_edit.finish().apply(&mut from);
        assert_eq!(&*from, to, "diff did not turn `from` to `to`");
    }
}