ide_db/syntax_helpers/
tree_diff.rs

1//! Basic tree diffing functionality.
2use rustc_hash::FxHashMap;
3use syntax::{NodeOrToken, SyntaxElement, SyntaxNode};
4
5use crate::{FxIndexMap, text_edit::TextEditBuilder};
6
7#[derive(Debug, Hash, PartialEq, Eq)]
8enum TreeDiffInsertPos {
9    After(SyntaxElement),
10    AsFirstChild(SyntaxElement),
11}
12
13#[derive(Debug)]
14pub struct TreeDiff {
15    replacements: FxHashMap<SyntaxElement, SyntaxElement>,
16    deletions: Vec<SyntaxElement>,
17    // the vec as well as the indexmap are both here to preserve order
18    insertions: FxIndexMap<TreeDiffInsertPos, Vec<SyntaxElement>>,
19}
20
21impl TreeDiff {
22    pub fn into_text_edit(&self, builder: &mut TextEditBuilder) {
23        let _p = tracing::info_span!("into_text_edit").entered();
24
25        for (anchor, to) in &self.insertions {
26            let offset = match anchor {
27                TreeDiffInsertPos::After(it) => it.text_range().end(),
28                TreeDiffInsertPos::AsFirstChild(it) => it.text_range().start(),
29            };
30            to.iter().for_each(|to| builder.insert(offset, to.to_string()));
31        }
32        for (from, to) in &self.replacements {
33            builder.replace(from.text_range(), to.to_string());
34        }
35        for text_range in self.deletions.iter().map(SyntaxElement::text_range) {
36            builder.delete(text_range);
37        }
38    }
39
40    pub fn is_empty(&self) -> bool {
41        self.replacements.is_empty() && self.deletions.is_empty() && self.insertions.is_empty()
42    }
43}
44
45/// Finds a (potentially minimal) diff, which, applied to `from`, will result in `to`.
46///
47/// Specifically, returns a structure that consists of a replacements, insertions and deletions
48/// such that applying this map on `from` will result in `to`.
49///
50/// This function tries to find a fine-grained diff.
51pub fn diff(from: &SyntaxNode, to: &SyntaxNode) -> TreeDiff {
52    let _p = tracing::info_span!("diff").entered();
53
54    let mut diff = TreeDiff {
55        replacements: FxHashMap::default(),
56        insertions: FxIndexMap::default(),
57        deletions: Vec::new(),
58    };
59    let (from, to) = (from.clone().into(), to.clone().into());
60
61    if !syntax_element_eq(&from, &to) {
62        go(&mut diff, from, to);
63    }
64    return diff;
65
66    fn syntax_element_eq(lhs: &SyntaxElement, rhs: &SyntaxElement) -> bool {
67        lhs.kind() == rhs.kind()
68            && lhs.text_range().len() == rhs.text_range().len()
69            && match (&lhs, &rhs) {
70                (NodeOrToken::Node(lhs), NodeOrToken::Node(rhs)) => {
71                    lhs == rhs || lhs.text() == rhs.text()
72                }
73                (NodeOrToken::Token(lhs), NodeOrToken::Token(rhs)) => lhs.text() == rhs.text(),
74                _ => false,
75            }
76    }
77
78    // FIXME: this is horribly inefficient. I bet there's a cool algorithm to diff trees properly.
79    fn go(diff: &mut TreeDiff, lhs: SyntaxElement, rhs: SyntaxElement) {
80        let (lhs, rhs) = match lhs.as_node().zip(rhs.as_node()) {
81            Some((lhs, rhs)) => (lhs, rhs),
82            _ => {
83                cov_mark::hit!(diff_node_token_replace);
84                diff.replacements.insert(lhs, rhs);
85                return;
86            }
87        };
88
89        let mut look_ahead_scratch = Vec::default();
90
91        let mut rhs_children = rhs.children_with_tokens();
92        let mut lhs_children = lhs.children_with_tokens();
93        let mut last_lhs = None;
94        loop {
95            let lhs_child = lhs_children.next();
96            match (lhs_child.clone(), rhs_children.next()) {
97                (None, None) => break,
98                (None, Some(element)) => {
99                    let insert_pos = match last_lhs.clone() {
100                        Some(prev) => {
101                            cov_mark::hit!(diff_insert);
102                            TreeDiffInsertPos::After(prev)
103                        }
104                        // first iteration, insert into out parent as the first child
105                        None => {
106                            cov_mark::hit!(diff_insert_as_first_child);
107                            TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
108                        }
109                    };
110                    diff.insertions.entry(insert_pos).or_default().push(element);
111                }
112                (Some(element), None) => {
113                    cov_mark::hit!(diff_delete);
114                    diff.deletions.push(element);
115                }
116                (Some(ref lhs_ele), Some(ref rhs_ele)) if syntax_element_eq(lhs_ele, rhs_ele) => {}
117                (Some(lhs_ele), Some(rhs_ele)) => {
118                    // nodes differ, look for lhs_ele in rhs, if its found we can mark everything up
119                    // until that element as insertions. This is important to keep the diff minimal
120                    // in regards to insertions that have been actually done, this is important for
121                    // use insertions as we do not want to replace the entire module node.
122                    look_ahead_scratch.push(rhs_ele.clone());
123                    let mut rhs_children_clone = rhs_children.clone();
124                    let mut insert = false;
125                    for rhs_child in &mut rhs_children_clone {
126                        if syntax_element_eq(&lhs_ele, &rhs_child) {
127                            cov_mark::hit!(diff_insertions);
128                            insert = true;
129                            break;
130                        }
131                        look_ahead_scratch.push(rhs_child);
132                    }
133                    let drain = look_ahead_scratch.drain(..);
134                    if insert {
135                        let insert_pos = if let Some(prev) = last_lhs.clone().filter(|_| insert) {
136                            TreeDiffInsertPos::After(prev)
137                        } else {
138                            cov_mark::hit!(insert_first_child);
139                            TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
140                        };
141
142                        diff.insertions.entry(insert_pos).or_default().extend(drain);
143                        rhs_children = rhs_children_clone;
144                    } else {
145                        go(diff, lhs_ele, rhs_ele);
146                    }
147                }
148            }
149            last_lhs = lhs_child.or(last_lhs);
150        }
151    }
152}
153
#[cfg(test)]
mod tests {
    use expect_test::{Expect, expect};
    use itertools::Itertools;
    use parser::{Edition, SyntaxKind};
    use syntax::{AstNode, SourceFile, SyntaxElement};

    use crate::text_edit::TextEdit;

    /// A node/token mismatch is recorded as a replacement of the whole element.
    #[test]
    fn replace_node_token() {
        cov_mark::check!(diff_node_token_replace);
        check_diff(
            r#"use node;"#,
            r#"ident"#,
            expect![[r#"
                insertions:



                replacements:

                Line 0: Token(USE_KW@0..3 "use") -> ident

                deletions:

                Line 1: " "
                Line 1: node
                Line 1: ;
            "#]],
        );
    }

    /// Diffing from an empty file inserts everything as the first child of the root.
    #[test]
    fn replace_parent() {
        cov_mark::check!(diff_insert_as_first_child);
        check_diff(
            r#""#,
            r#"use foo::bar;"#,
            expect![[r#"
                insertions:

                Line 0: AsFirstChild(Node(SOURCE_FILE@0..0))
                -> use foo::bar;

                replacements:



                deletions:


            "#]],
        );
    }

    /// Appending an item results in an insertion after the last existing child.
    #[test]
    fn insert_last() {
        cov_mark::check!(diff_insert);
        check_diff(
            r#"
use foo;
use bar;"#,
            r#"
use foo;
use bar;
use baz;"#,
            expect![[r#"
                insertions:

                Line 2: After(Node(USE@10..18))
                -> "\n"
                -> use baz;

                replacements:



                deletions:


            "#]],
        );
    }

    /// Adding an item between two existing ones only produces insertions.
    #[test]
    fn insert_middle() {
        check_diff(
            r#"
use foo;
use baz;"#,
            r#"
use foo;
use bar;
use baz;"#,
            expect![[r#"
                insertions:

                Line 2: After(Token(WHITESPACE@9..10 "\n"))
                -> use bar;
                -> "\n"

                replacements:



                deletions:


            "#]],
        )
    }

    /// Adding an item before the existing ones only produces insertions.
    #[test]
    fn insert_first() {
        check_diff(
            r#"
use bar;
use baz;"#,
            r#"
use foo;
use bar;
use baz;"#,
            expect![[r#"
                insertions:

                Line 0: After(Token(WHITESPACE@0..1 "\n"))
                -> use foo;
                -> "\n"

                replacements:



                deletions:


            "#]],
        )
    }

    /// A look-ahead match on the very first child anchors the insertion on the parent.
    #[test]
    fn first_child_insertion() {
        cov_mark::check!(insert_first_child);
        check_diff(
            r#"fn main() {
        stdi
    }"#,
            r#"use foo::bar;

    fn main() {
        stdi
    }"#,
            expect![[r#"
                insertions:

                Line 0: AsFirstChild(Node(SOURCE_FILE@0..30))
                -> use foo::bar;
                -> "\n\n    "

                replacements:



                deletions:


            "#]],
        );
    }

    /// Removing the trailing item results in deletions only.
    #[test]
    fn delete_last() {
        cov_mark::check!(diff_delete);
        check_diff(
            r#"use foo;
            use bar;"#,
            r#"use foo;"#,
            expect![[r#"
                insertions:



                replacements:



                deletions:

                Line 1: "\n            "
                Line 2: use bar;
            "#]],
        );
    }

    /// Removing an item from the middle; the expectation pins the current output.
    #[test]
    fn delete_middle() {
        cov_mark::check!(diff_insertions);
        check_diff(
            r#"
use expect_test::{expect, Expect};
use text_edit::TextEdit;

use crate::AstNode;
"#,
            r#"
use expect_test::{expect, Expect};

use crate::AstNode;
"#,
            expect![[r#"
                insertions:

                Line 1: After(Node(USE@1..35))
                -> "\n\n"
                -> use crate::AstNode;

                replacements:



                deletions:

                Line 2: use text_edit::TextEdit;
                Line 3: "\n\n"
                Line 4: use crate::AstNode;
                Line 5: "\n"
            "#]],
        )
    }

    /// Removing the first item; the expectation pins the current output.
    #[test]
    fn delete_first() {
        check_diff(
            r#"
use text_edit::TextEdit;

use crate::AstNode;
"#,
            r#"
use crate::AstNode;
"#,
            expect![[r#"
                insertions:



                replacements:

                Line 2: Token(IDENT@5..14 "text_edit") -> crate
                Line 2: Token(IDENT@16..24 "TextEdit") -> AstNode
                Line 2: Token(WHITESPACE@25..27 "\n\n") -> "\n"

                deletions:

                Line 3: use crate::AstNode;
                Line 4: "\n"
            "#]],
        )
    }

    /// Turning one nested `use` tree into several flat `use` items.
    #[test]
    fn merge_use() {
        check_diff(
            r#"
use std::{
    fmt,
    hash::BuildHasherDefault,
    ops::{self, RangeInclusive},
};
"#,
            r#"
use std::fmt;
use std::hash::BuildHasherDefault;
use std::ops::{self, RangeInclusive};
"#,
            expect![[r#"
                insertions:

                Line 2: After(Node(PATH_SEGMENT@5..8))
                -> ::
                -> fmt
                Line 6: After(Token(WHITESPACE@86..87 "\n"))
                -> use std::hash::BuildHasherDefault;
                -> "\n"
                -> use std::ops::{self, RangeInclusive};
                -> "\n"

                replacements:

                Line 2: Token(IDENT@5..8 "std") -> std

                deletions:

                Line 2: ::
                Line 2: {
                    fmt,
                    hash::BuildHasherDefault,
                    ops::{self, RangeInclusive},
                }
            "#]],
        )
    }

    /// Rewriting an `if let` into a `let ... = match` with an early return.
    #[test]
    fn early_return_assist() {
        check_diff(
            r#"
fn main() {
    if let Ok(x) = Err(92) {
        foo(x);
    }
}
            "#,
            r#"
fn main() {
    let x = match Err(92) {
        Ok(it) => it,
        _ => return,
    };
    foo(x);
}
            "#,
            expect![[r#"
                insertions:

                Line 3: After(Node(BLOCK_EXPR@40..63))
                -> " "
                -> match Err(92) {
                        Ok(it) => it,
                        _ => return,
                    }
                -> ;
                Line 3: After(Node(IF_EXPR@17..63))
                -> "\n    "
                -> foo(x);

                replacements:

                Line 3: Token(IF_KW@17..19 "if") -> let
                Line 3: Token(LET_KW@20..23 "let") -> x
                Line 3: Node(BLOCK_EXPR@40..63) -> =

                deletions:

                Line 3: " "
                Line 3: Ok(x)
                Line 3: " "
                Line 3: =
                Line 3: " "
                Line 3: Err(92)
            "#]],
        )
    }

    /// Parses `from` and `to`, diffs them, compares a textual rendering of the diff against
    /// `expected_diff`, and finally checks that applying the diff as a text edit to `from`
    /// yields exactly `to`.
    fn check_diff(from: &str, to: &str, expected_diff: Expect) {
        let old_root = SourceFile::parse(from, Edition::CURRENT).tree().syntax().clone();
        let new_root = SourceFile::parse(to, Edition::CURRENT).tree().syntax().clone();
        let diff = super::diff(&old_root, &new_root);

        // Number of lines of `from` that lie before the start of the element.
        let line_number = |syn: &SyntaxElement| {
            let start = usize::from(syn.text_range().start());
            from[..start].lines().count()
        };

        // Whitespace is rendered in debug form so that it stays visible in the expectation.
        let fmt_syntax = |syn: &SyntaxElement| {
            if syn.kind() == SyntaxKind::WHITESPACE {
                format!("{:?}", syn.to_string())
            } else {
                syn.to_string()
            }
        };

        let insertions = diff
            .insertions
            .iter()
            .map(|(pos, elements)| {
                let anchor = match pos {
                    super::TreeDiffInsertPos::After(syn) => syn,
                    super::TreeDiffInsertPos::AsFirstChild(syn) => syn,
                };
                let rendered = elements.iter().map(|it| fmt_syntax(it)).join("\n-> ");
                format!("Line {}: {:?}\n-> {}", line_number(anchor), pos, rendered)
            })
            .join("\n");

        // The replacement map is unordered, so sort by position for a stable rendering.
        let replacements = diff
            .replacements
            .iter()
            .sorted_by_key(|(syntax, _)| syntax.text_range().start())
            .map(|(old, new)| format!("Line {}: {old:?} -> {}", line_number(old), fmt_syntax(new)))
            .join("\n");

        let deletions = diff
            .deletions
            .iter()
            .map(|removed| format!("Line {}: {}", line_number(removed), fmt_syntax(removed)))
            .join("\n");

        let actual = format!(
            "insertions:\n\n{insertions}\n\nreplacements:\n\n{replacements}\n\ndeletions:\n\n{deletions}\n"
        );
        expected_diff.assert_eq(&actual);

        let mut patched = from.to_owned();
        let mut text_edit = TextEdit::builder();
        diff.into_text_edit(&mut text_edit);
        text_edit.finish().apply(&mut patched);
        assert_eq!(&*patched, to, "diff did not turn `from` to `to`");
    }
}