ide_assists/handlers/
number_representation.rs

1use syntax::{AstToken, ast, ast::Radix};
2
3use crate::{AssistContext, AssistId, Assists, GroupLabel};
4
5const MIN_NUMBER_OF_DIGITS_TO_FORMAT: usize = 5;
6
7// Assist: reformat_number_literal
8//
9// Adds or removes separators from integer literal.
10//
11// ```
12// const _: i32 = 1012345$0;
13// ```
14// ->
15// ```
16// const _: i32 = 1_012_345;
17// ```
18pub(crate) fn reformat_number_literal(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
19    let literal = ctx.find_node_at_offset::<ast::Literal>()?;
20    let literal = match literal.kind() {
21        ast::LiteralKind::IntNumber(it) => it,
22        _ => return None,
23    };
24
25    let text = literal.text();
26    if text.contains('_') {
27        return remove_separators(acc, literal);
28    }
29
30    let (prefix, value, suffix) = literal.split_into_parts();
31    if value.len() < MIN_NUMBER_OF_DIGITS_TO_FORMAT {
32        return None;
33    }
34
35    let radix = literal.radix();
36    let mut converted = prefix.to_owned();
37    converted.push_str(&add_group_separators(value, group_size(radix)));
38    converted.push_str(suffix);
39
40    let group_id = GroupLabel("Reformat number literal".into());
41    let label = format!("Convert {literal} to {converted}");
42    let range = literal.syntax().text_range();
43    acc.add_group(
44        &group_id,
45        AssistId::refactor_inline("reformat_number_literal"),
46        label,
47        range,
48        |builder| builder.replace(range, converted),
49    )
50}
51
52fn remove_separators(acc: &mut Assists, literal: ast::IntNumber) -> Option<()> {
53    let group_id = GroupLabel("Reformat number literal".into());
54    let range = literal.syntax().text_range();
55    acc.add_group(
56        &group_id,
57        AssistId::refactor_inline("reformat_number_literal"),
58        "Remove digit separators",
59        range,
60        |builder| builder.replace(range, literal.text().replace('_', "")),
61    )
62}
63
64const fn group_size(r: Radix) -> usize {
65    match r {
66        Radix::Binary => 4,
67        Radix::Octal => 3,
68        Radix::Decimal => 3,
69        Radix::Hexadecimal => 4,
70    }
71}
72
/// Returns `s` with a `_` inserted between every `group_size` digits, counted
/// from the rightmost digit, e.g. `("1234567", 3)` -> `"1_234_567"`.
///
/// Any `_` already present in `s` is discarded first, so existing separators
/// are re-normalised rather than duplicated. A `group_size` of `0` means
/// "no grouping" and yields the bare digits (rather than panicking on a
/// remainder by zero).
fn add_group_separators(s: &str, group_size: usize) -> String {
    let digits = || s.chars().filter(|&ch| ch != '_');
    if group_size == 0 {
        return digits().collect();
    }

    let len = digits().count();
    // One separator per internal group boundary (exact for ASCII digits).
    let mut grouped = String::with_capacity(len + len.saturating_sub(1) / group_size);
    for (i, ch) in digits().enumerate() {
        // `len - i` digits remain, including `ch`; a group boundary falls right
        // before `ch` whenever that count is a multiple of `group_size`.
        if i > 0 && (len - i) % group_size == 0 {
            grouped.push('_');
        }
        grouped.push(ch);
    }
    grouped
}
84
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist_by_label, check_assist_not_applicable, check_assist_target};

    use super::*;

    /// `(input, group size, expected)` triples for the pure grouping helper.
    #[test]
    fn group_separators() {
        let cases = [
            ("", 4, ""),
            ("1", 4, "1"),
            ("12", 4, "12"),
            ("123", 4, "123"),
            ("1234", 4, "1234"),
            ("12345", 4, "1_2345"),
            ("123456", 4, "12_3456"),
            ("1234567", 4, "123_4567"),
            ("12345678", 4, "1234_5678"),
            ("123456789", 4, "1_2345_6789"),
            ("1234567890", 4, "12_3456_7890"),
            ("1_2_3_4_5_6_7_8_9_0_", 4, "12_3456_7890"),
            ("1234567890", 3, "1_234_567_890"),
            ("1234567890", 2, "12_34_56_78_90"),
            ("1234567890", 1, "1_2_3_4_5_6_7_8_9_0"),
        ];

        for (input, group_size, expected) in cases {
            assert_eq!(add_group_separators(input, group_size), expected)
        }
    }

    /// The assist's target range covers the whole literal token.
    #[test]
    fn good_targets() {
        let cases = [
            ("const _: i32 = 0b11111$0", "0b11111"),
            ("const _: i32 = 0o77777$0;", "0o77777"),
            ("const _: i32 = 10000$0;", "10000"),
            ("const _: i32 = 0xFFFFF$0;", "0xFFFFF"),
            ("const _: i32 = 10000i32$0;", "10000i32"),
            ("const _: i32 = 0b_10_0i32$0;", "0b_10_0i32"),
        ];

        for (fixture, target) in cases {
            check_assist_target(reformat_number_literal, fixture, target);
        }
    }

    /// Literals with too few digits are not offered the assist.
    #[test]
    fn bad_targets() {
        let cases = [
            "const _: i32 = 0b111$0",
            "const _: i32 = 0b1111$0",
            "const _: i32 = 0o77$0;",
            "const _: i32 = 0o777$0;",
            "const _: i32 = 10$0;",
            "const _: i32 = 999$0;",
            "const _: i32 = 0xFF$0;",
            "const _: i32 = 0xFFFF$0;",
        ];

        for fixture in cases {
            check_assist_not_applicable(reformat_number_literal, fixture);
        }
    }

    /// Each `(before, after, label)` pins both the edit and its user-visible label.
    #[test]
    fn labels() {
        let cases = [
            ("const _: i32 = 10000$0", "const _: i32 = 10_000", "Convert 10000 to 10_000"),
            (
                "const _: i32 = 0xFF0000$0;",
                "const _: i32 = 0xFF_0000;",
                "Convert 0xFF0000 to 0xFF_0000",
            ),
            (
                "const _: i32 = 0b11111111$0;",
                "const _: i32 = 0b1111_1111;",
                "Convert 0b11111111 to 0b1111_1111",
            ),
            (
                "const _: i32 = 0o377211$0;",
                "const _: i32 = 0o377_211;",
                "Convert 0o377211 to 0o377_211",
            ),
            (
                "const _: i32 = 10000i32$0;",
                "const _: i32 = 10_000i32;",
                "Convert 10000i32 to 10_000i32",
            ),
            ("const _: i32 = 1_0_0_0_i32$0;", "const _: i32 = 1000i32;", "Remove digit separators"),
        ];

        for (before, after, label) in cases {
            check_assist_by_label(reformat_number_literal, before, after, label);
        }
    }
}