//! Parsing of SSR rules and patterns. A rule such as `foo($a) ==>> bar($a)` is split on the
//! `==>>` delimiter, both halves are tokenized with the Rust lexer and scanned for `$`-prefixed
//! placeholders, and the search pattern is then parsed as each kind of fragment (expression,
//! type, item, pattern, statement) that it could match.

use ide_db::{FxHashMap, FxHashSet};
use std::{fmt::Display, str::FromStr};
use syntax::{SmolStr, SyntaxKind, SyntaxNode, T};

use crate::errors::bail;
use crate::{SsrError, SsrPattern, SsrRule, fragments};

#[derive(Debug)]
pub(crate) struct ParsedRule {
    pub(crate) placeholders_by_stand_in: FxHashMap<SmolStr, Placeholder>,
    pub(crate) pattern: SyntaxNode,
    pub(crate) template: Option<SyntaxNode>,
}

#[derive(Debug)]
pub(crate) struct RawPattern {
    tokens: Vec<PatternElement>,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum PatternElement {
    Token(Token),
    Placeholder(Placeholder),
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct Placeholder {
    /// The name of this placeholder, e.g. `a` for `$a`.
    pub(crate) ident: Var,
    /// A unique identifier substituted for the placeholder when the pattern is parsed as Rust
    /// code, e.g. `__placeholder_a` for `$a`.
    stand_in_name: String,
    pub(crate) constraints: Vec<Constraint>,
}

/// A placeholder variable such as `$a` in an SSR query.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct Var(pub(crate) String);

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum Constraint {
    Kind(NodeKind),
    Not(Box<Constraint>),
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum NodeKind {
    Literal,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Token {
    kind: SyntaxKind,
    pub(crate) text: SmolStr,
}

impl ParsedRule {
    fn new(
        pattern: &RawPattern,
        template: Option<&RawPattern>,
    ) -> Result<Vec<ParsedRule>, SsrError> {
        let raw_pattern = pattern.as_rust_code();
        let raw_template = template.map(|t| t.as_rust_code());
        let raw_template = raw_template.as_deref();
        let mut builder = RuleBuilder {
            placeholders_by_stand_in: pattern.placeholders_by_stand_in(),
            rules: Vec::new(),
        };

        // For the expression rule, prefer a template that itself parses as an expression; if it
        // doesn't, fall back to the template parsed as a statement.
        let raw_template_stmt = raw_template.map(fragments::stmt);
        if let raw_template_expr @ Some(Ok(_)) = raw_template.map(fragments::expr) {
            builder.try_add(fragments::expr(&raw_pattern), raw_template_expr);
        } else {
            builder.try_add(fragments::expr(&raw_pattern), raw_template_stmt.clone());
        }
        builder.try_add(fragments::ty(&raw_pattern), raw_template.map(fragments::ty));
        builder.try_add(fragments::item(&raw_pattern), raw_template.map(fragments::item));
        builder.try_add(fragments::pat(&raw_pattern), raw_template.map(fragments::pat));
        builder.try_add(fragments::stmt(&raw_pattern), raw_template_stmt);
        builder.build()
    }
}

struct RuleBuilder {
    placeholders_by_stand_in: FxHashMap<SmolStr, Placeholder>,
    rules: Vec<ParsedRule>,
}

impl RuleBuilder {
    fn try_add(
        &mut self,
        pattern: Result<SyntaxNode, ()>,
        template: Option<Result<SyntaxNode, ()>>,
    ) {
        match (pattern, template) {
            (Ok(pattern), Some(Ok(template))) => self.rules.push(ParsedRule {
                placeholders_by_stand_in: self.placeholders_by_stand_in.clone(),
                pattern,
                template: Some(template),
            }),
            (Ok(pattern), None) => self.rules.push(ParsedRule {
                placeholders_by_stand_in: self.placeholders_by_stand_in.clone(),
                pattern,
                template: None,
            }),
            _ => {}
        }
    }

    fn build(mut self) -> Result<Vec<ParsedRule>, SsrError> {
        if self.rules.is_empty() {
            bail!("Not a valid Rust expression, type, item, path or pattern");
        }
        // If at least one parse of the pattern contains a path, keep only the parses that do.
        // e.g. a single-segment path like `foo` also parses as an identifier pattern, and keeping
        // that parse would make the rule match things the user didn't ask for.
        if self.rules.iter().any(|rule| contains_path(&rule.pattern)) {
            let old_len = self.rules.len();
            self.rules.retain(|rule| contains_path(&rule.pattern));
            if self.rules.len() < old_len {
                cov_mark::hit!(pattern_is_a_single_segment_path);
            }
        }
        Ok(self.rules)
    }
}

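/// Returns whether `node` is itself a path or contains a path anywhere in its descendants.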
fn contains_path(node: &SyntaxNode) -> bool {
    node.kind() == SyntaxKind::PATH
        || node.descendants().any(|node| node.kind() == SyntaxKind::PATH)
}

impl FromStr for SsrRule {
    type Err = SsrError;

    fn from_str(query: &str) -> Result<SsrRule, SsrError> {
        let mut it = query.split("==>>");
        let pattern = it.next().expect("at least empty string").trim();
        let template = it
            .next()
            .ok_or_else(|| SsrError("Cannot find delimiter `==>>`".into()))?
            .trim()
            .to_owned();
        if it.next().is_some() {
            return Err(SsrError("More than one delimiter found".into()));
        }
        let raw_pattern = pattern.parse()?;
        let raw_template = template.parse()?;
        let parsed_rules = ParsedRule::new(&raw_pattern, Some(&raw_template))?;
        let rule = SsrRule { pattern: raw_pattern, template: raw_template, parsed_rules };
        validate_rule(&rule)?;
        Ok(rule)
    }
}

impl FromStr for RawPattern {
    type Err = SsrError;

    fn from_str(pattern_str: &str) -> Result<RawPattern, SsrError> {
        Ok(RawPattern { tokens: parse_pattern(pattern_str)? })
    }
}

impl RawPattern {
    /// Returns this pattern as Rust code that we can feed to the Rust parser, with each
    /// placeholder replaced by its stand-in name.
    fn as_rust_code(&self) -> String {
        let mut res = String::new();
        for t in &self.tokens {
            res.push_str(match t {
                PatternElement::Token(token) => token.text.as_str(),
                PatternElement::Placeholder(placeholder) => placeholder.stand_in_name.as_str(),
            });
        }
        res
    }

    pub(crate) fn placeholders_by_stand_in(&self) -> FxHashMap<SmolStr, Placeholder> {
        let mut res = FxHashMap::default();
        for t in &self.tokens {
            if let PatternElement::Placeholder(placeholder) = t {
                res.insert(SmolStr::new(&placeholder.stand_in_name), placeholder.clone());
            }
        }
        res
    }
}

impl FromStr for SsrPattern {
    type Err = SsrError;

    fn from_str(pattern_str: &str) -> Result<SsrPattern, SsrError> {
        let raw_pattern = pattern_str.parse()?;
        let parsed_rules = ParsedRule::new(&raw_pattern, None)?;
        Ok(SsrPattern { parsed_rules })
    }
}

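/// Tokenizes `pattern_str` and converts the token stream into a sequence of plain tokens and
/// placeholders. Placeholders are written either as `$name` or `${name:constraints}` and must
/// not repeat within a single pattern.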
fn parse_pattern(pattern_str: &str) -> Result<Vec<PatternElement>, SsrError> {
    let mut res = Vec::new();
    let mut placeholder_names = FxHashSet::default();
    let mut tokens = tokenize(pattern_str)?.into_iter();
    while let Some(token) = tokens.next() {
        if token.kind == T![$] {
            let placeholder = parse_placeholder(&mut tokens)?;
            if !placeholder_names.insert(placeholder.ident.clone()) {
                bail!("Placeholder `{}` repeats more than once", placeholder.ident);
            }
            res.push(PatternElement::Placeholder(placeholder));
        } else {
            res.push(PatternElement::Token(token));
        }
    }
    Ok(res)
}

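/// Checks a rule for errors: every placeholder used in the replacement template must be defined
/// in the search pattern, and replacement placeholders may not carry constraints.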
fn validate_rule(rule: &SsrRule) -> Result<(), SsrError> {
    let mut defined_placeholders = FxHashSet::default();
    for p in &rule.pattern.tokens {
        if let PatternElement::Placeholder(placeholder) = p {
            defined_placeholders.insert(&placeholder.ident);
        }
    }
    let mut undefined = Vec::new();
    for p in &rule.template.tokens {
        if let PatternElement::Placeholder(placeholder) = p {
            if !defined_placeholders.contains(&placeholder.ident) {
                undefined.push(placeholder.ident.to_string());
            }
            if !placeholder.constraints.is_empty() {
                bail!("Replacement placeholders cannot have constraints");
            }
        }
    }
    if !undefined.is_empty() {
        bail!("Replacement contains undefined placeholders: {}", undefined.join(", "));
    }
    Ok(())
}

fn tokenize(source: &str) -> Result<Vec<Token>, SsrError> {
    let lexed = parser::LexedStr::new(parser::Edition::CURRENT, source);
    if let Some((_, first_error)) = lexed.errors().next() {
        bail!("Failed to parse pattern: {}", first_error);
    }
    let mut tokens: Vec<Token> = Vec::new();
    for i in 0..lexed.len() {
        tokens.push(Token { kind: lexed.kind(i), text: lexed.text(i).into() });
    }
    Ok(tokens)
}

fn parse_placeholder(tokens: &mut std::vec::IntoIter<Token>) -> Result<Placeholder, SsrError> {
    let mut name = None;
    let mut constraints = Vec::new();
    if let Some(token) = tokens.next() {
        match token.kind {
            SyntaxKind::IDENT => {
                name = Some(token.text);
            }
            T!['{'] => {
                let token =
                    tokens.next().ok_or_else(|| SsrError::new("Unexpected end of placeholder"))?;
                if token.kind == SyntaxKind::IDENT {
                    name = Some(token.text);
                }
                loop {
                    let token = tokens
                        .next()
                        .ok_or_else(|| SsrError::new("Placeholder is missing closing brace '}'"))?;
                    match token.kind {
                        T![:] => {
                            constraints.push(parse_constraint(tokens)?);
                        }
                        T!['}'] => break,
                        _ => bail!("Unexpected token while parsing placeholder: '{}'", token.text),
                    }
                }
            }
            _ => {
                bail!("Placeholders should either be $name or ${{name:constraints}}");
            }
        }
    }
    let name = name.ok_or_else(|| SsrError::new("Placeholder ($) with no name"))?;
    Ok(Placeholder::new(name, constraints))
}

fn parse_constraint(tokens: &mut std::vec::IntoIter<Token>) -> Result<Constraint, SsrError> {
    let constraint_type = tokens
        .next()
        .ok_or_else(|| SsrError::new("Found end of placeholder while looking for a constraint"))?
        .text
        .to_string();
    match constraint_type.as_str() {
        "kind" => {
            expect_token(tokens, "(")?;
            let t = tokens.next().ok_or_else(|| {
                SsrError::new("Unexpected end of constraint while looking for kind")
            })?;
            if t.kind != SyntaxKind::IDENT {
                bail!("Expected ident, found {:?} while parsing kind constraint", t.kind);
            }
            expect_token(tokens, ")")?;
            Ok(Constraint::Kind(NodeKind::from(&t.text)?))
        }
        "not" => {
            expect_token(tokens, "(")?;
            let sub = parse_constraint(tokens)?;
            expect_token(tokens, ")")?;
            Ok(Constraint::Not(Box::new(sub)))
        }
        x => bail!("Unsupported constraint type '{}'", x),
    }
}

fn expect_token(tokens: &mut std::vec::IntoIter<Token>, expected: &str) -> Result<(), SsrError> {
    if let Some(t) = tokens.next() {
        if t.text == expected {
            return Ok(());
        }
        bail!("Expected {} found {}", expected, t.text);
    }
    bail!("Expected {} found end of stream", expected);
}

impl NodeKind {
    fn from(name: &SmolStr) -> Result<NodeKind, SsrError> {
        Ok(match name.as_str() {
            "literal" => NodeKind::Literal,
            _ => bail!("Unknown node kind '{}'", name),
        })
    }
}

impl Placeholder {
    fn new(name: SmolStr, constraints: Vec<Constraint>) -> Self {
        Self {
            stand_in_name: format!("__placeholder_{name}"),
            constraints,
            ident: Var(name.to_string()),
        }
    }
}

impl Display for Var {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "${}", self.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parser_happy_case() {
        fn token(kind: SyntaxKind, text: &str) -> PatternElement {
            PatternElement::Token(Token { kind, text: SmolStr::new(text) })
        }
        fn placeholder(name: &str) -> PatternElement {
            PatternElement::Placeholder(Placeholder::new(SmolStr::new(name), Vec::new()))
        }
        let result: SsrRule = "foo($a, $b) ==>> bar($b, $a)".parse().unwrap();
        assert_eq!(
            result.pattern.tokens,
            vec![
                token(SyntaxKind::IDENT, "foo"),
                token(T!['('], "("),
                placeholder("a"),
                token(T![,], ","),
                token(SyntaxKind::WHITESPACE, " "),
                placeholder("b"),
                token(T![')'], ")"),
            ]
        );
        assert_eq!(
            result.template.tokens,
            vec![
                token(SyntaxKind::IDENT, "bar"),
                token(T!['('], "("),
                placeholder("b"),
                token(T![,], ","),
                token(SyntaxKind::WHITESPACE, " "),
                placeholder("a"),
                token(T![')'], ")"),
            ]
        );
    }
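
    // The tests below are illustrative additions (not part of the original suite). They exercise
    // only behaviour visible in this file: the `==>>` delimiter handling, uniqueness of
    // search-pattern placeholders, validation of replacement placeholders, and constraint
    // parsing. Test names are ours, not established fixtures.
    #[test]
    fn parser_missing_delimiter_is_rejected() {
        // No `==>>` separator between pattern and template.
        assert!("foo($a)".parse::<SsrRule>().is_err());
    }

    #[test]
    fn parser_repeated_pattern_placeholder_is_rejected() {
        // `$a` appears twice in the search pattern.
        assert!("foo($a, $a) ==>> foo($a)".parse::<SsrRule>().is_err());
    }

    #[test]
    fn parser_undefined_replacement_placeholder_is_rejected() {
        // `$b` is used in the replacement but never defined in the search pattern.
        assert!("foo($a) ==>> foo($b)".parse::<SsrRule>().is_err());
    }

    #[test]
    fn parser_kind_constraint() {
        // `${name:constraints}` syntax should attach the parsed constraint to the placeholder.
        let pattern: RawPattern = "${a:kind(literal)}".parse().unwrap();
        match &pattern.tokens[0] {
            PatternElement::Placeholder(p) => {
                assert_eq!(p.constraints, vec![Constraint::Kind(NodeKind::Literal)]);
            }
            other => panic!("expected a placeholder, got {other:?}"),
        }
    }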
}