mbe/
lib.rs

//! `mbe` (short for Macro By Example) crate contains code for handling
//! `macro_rules` macros. It uses `TokenTree` (from `tt` package) as the
//! interface, although it contains some code to bridge `SyntaxNode`s and
//! `TokenTree`s as well!
//!
//! The tests for this functionality live in another crate:
//! `hir_def::macro_expansion_tests::mbe`.
8
9#![cfg_attr(feature = "in-rust-tree", feature(rustc_private))]
10
11#[cfg(not(feature = "in-rust-tree"))]
12extern crate ra_ap_rustc_lexer as rustc_lexer;
13#[cfg(feature = "in-rust-tree")]
14extern crate rustc_lexer;
15
16mod expander;
17mod macro_call_style;
18mod parser;
19
20#[cfg(test)]
21mod benchmark;
22#[cfg(test)]
23mod tests;
24
25use span::{Edition, Span, SyntaxContext};
26use syntax_bridge::to_parser_input;
27use tt::DelimSpan;
28use tt::iter::TtIter;
29
30use std::fmt;
31use std::sync::Arc;
32
33pub use crate::macro_call_style::{MacroCallStyle, MacroCallStyles};
34use crate::parser::{MetaTemplate, MetaVarKind, Op};
35
36pub use tt::{Delimiter, DelimiterKind, Punct};
37
/// Error produced while parsing the *definition* of a declarative macro.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum ParseError {
    /// A token was found where it is not allowed; the payload is the
    /// human-readable message.
    UnexpectedToken(Box<str>),
    /// Something required was missing; the payload is the human-readable
    /// message (e.g. ``expected `;` ``).
    Expected(Box<str>),
    /// A repetition is malformed.
    InvalidRepeat,
    /// A repetition in a matcher could match an empty token tree.
    RepetitionEmptyTokenTree,
}

impl ParseError {
    /// Builds a [`ParseError::Expected`] carrying `e` as its message.
    fn expected(e: &str) -> ParseError {
        ParseError::Expected(e.into())
    }

    /// Builds a [`ParseError::UnexpectedToken`] carrying `e` as its message.
    fn unexpected(e: &str) -> ParseError {
        ParseError::UnexpectedToken(e.into())
    }
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message: &str = match self {
            // Both payload-carrying variants already hold the full message.
            ParseError::UnexpectedToken(it) | ParseError::Expected(it) => it,
            ParseError::InvalidRepeat => "invalid repeat",
            ParseError::RepetitionEmptyTokenTree => "empty token tree in repetition",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ParseError {}
66
67#[derive(Debug, PartialEq, Eq, Clone, Hash)]
68pub struct ExpandError {
69    pub inner: Arc<(Span, ExpandErrorKind)>,
70}
71#[derive(Debug, PartialEq, Eq, Clone, Hash)]
72pub enum ExpandErrorKind {
73    BindingError(Box<Box<str>>),
74    UnresolvedBinding(Box<Box<str>>),
75    LeftoverTokens,
76    LimitExceeded,
77    NoMatchingRule,
78    UnexpectedToken,
79}
80
81impl ExpandError {
82    fn new(span: Span, kind: ExpandErrorKind) -> ExpandError {
83        ExpandError { inner: Arc::new((span, kind)) }
84    }
85    fn binding_error(span: Span, e: impl Into<Box<str>>) -> ExpandError {
86        ExpandError { inner: Arc::new((span, ExpandErrorKind::BindingError(Box::new(e.into())))) }
87    }
88}
89impl fmt::Display for ExpandError {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        self.inner.1.fmt(f)
92    }
93}
94
95impl fmt::Display for ExpandErrorKind {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        match self {
98            ExpandErrorKind::NoMatchingRule => f.write_str("no rule matches input tokens"),
99            ExpandErrorKind::UnexpectedToken => f.write_str("unexpected token in input"),
100            ExpandErrorKind::BindingError(e) => f.write_str(e),
101            ExpandErrorKind::UnresolvedBinding(binding) => {
102                f.write_str("could not find binding ")?;
103                f.write_str(binding)
104            }
105            ExpandErrorKind::LimitExceeded => f.write_str("Expand exceed limit"),
106            ExpandErrorKind::LeftoverTokens => f.write_str("leftover tokens"),
107        }
108    }
109}
110
// FIXME: Showing these errors could be nicer.
/// Error raised when evaluating a `${count}` metavariable expression.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum CountError {
    OutOfBounds,
    Misplaced,
}

impl fmt::Display for CountError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            CountError::OutOfBounds => "${count} out of bounds",
            CountError::Misplaced => "${count} misplaced",
        };
        f.write_str(message)
    }
}

impl std::error::Error for CountError {}
126
/// Index of the macro arm (rule) that matched on successful expansion, or
/// `None` when no arm matched.
pub type MatchedArmIndex = Option<u32>;
129
130/// This struct contains AST for a single `macro_rules` definition. What might
131/// be very confusing is that AST has almost exactly the same shape as
132/// `tt::TokenTree`, but there's a crucial difference: in macro rules, `$ident`
133/// and `$()*` have special meaning (see `Var` and `Repeat` data structures)
134#[derive(Clone, Debug, PartialEq, Eq)]
135pub struct DeclarativeMacro {
136    rules: Box<[Rule]>,
137    err: Option<Box<ParseError>>,
138}
139
140#[derive(Clone, Debug, PartialEq, Eq)]
141struct Rule {
142    /// Is this a normal fn-like rule, an `attr()` rule, or a `derive()` rule?
143    style: MacroCallStyle,
144    lhs: MetaTemplate,
145    rhs: MetaTemplate,
146}
147
148impl DeclarativeMacro {
149    pub fn from_err(err: ParseError) -> DeclarativeMacro {
150        DeclarativeMacro { rules: Box::default(), err: Some(Box::new(err)) }
151    }
152
153    /// The old, `macro_rules! m {}` flavor.
154    pub fn parse_macro_rules(
155        tt: &tt::TopSubtree<Span>,
156        ctx_edition: impl Copy + Fn(SyntaxContext) -> Edition,
157    ) -> DeclarativeMacro {
158        // Note: this parsing can be implemented using mbe machinery itself, by
159        // matching against `$($lhs:tt => $rhs:tt);*` pattern, but implementing
160        // manually seems easier.
161        let mut src = tt.iter();
162        let mut rules = Vec::new();
163        let mut err = None;
164
165        while !src.is_empty() {
166            let rule = match Rule::parse(ctx_edition, &mut src) {
167                Ok(it) => it,
168                Err(e) => {
169                    err = Some(Box::new(e));
170                    break;
171                }
172            };
173            rules.push(rule);
174            if let Err(()) = src.expect_char(';') {
175                if !src.is_empty() {
176                    err = Some(Box::new(ParseError::expected("expected `;`")));
177                }
178                break;
179            }
180        }
181
182        for Rule { lhs, .. } in &rules {
183            if let Err(e) = validate(lhs) {
184                err = Some(Box::new(e));
185                break;
186            }
187        }
188
189        DeclarativeMacro { rules: rules.into_boxed_slice(), err }
190    }
191
192    /// The new, unstable `macro m {}` flavor.
193    pub fn parse_macro2(
194        args: Option<&tt::TopSubtree<Span>>,
195        body: &tt::TopSubtree<Span>,
196        ctx_edition: impl Copy + Fn(SyntaxContext) -> Edition,
197    ) -> DeclarativeMacro {
198        let mut rules = Vec::new();
199        let mut err = None;
200
201        if let Some(args) = args {
202            // The presence of an argument list means that this macro uses the
203            // "simple" syntax, where the body is the RHS of a single rule.
204            cov_mark::hit!(parse_macro_def_simple);
205
206            let rule = (|| {
207                let lhs = MetaTemplate::parse_pattern(ctx_edition, args.iter())?;
208                let rhs = MetaTemplate::parse_template(ctx_edition, body.iter())?;
209
210                // In the "simple" syntax, there is apparently no way to specify
211                // that the single rule is an attribute or derive rule, so it
212                // must be a function-like rule.
213                Ok(crate::Rule { style: MacroCallStyle::FnLike, lhs, rhs })
214            })();
215
216            match rule {
217                Ok(rule) => rules.push(rule),
218                Err(e) => err = Some(Box::new(e)),
219            }
220        } else {
221            // There was no top-level argument list, so this macro uses the
222            // list-of-rules syntax, similar to `macro_rules!`.
223            cov_mark::hit!(parse_macro_def_rules);
224            let mut src = body.iter();
225            while !src.is_empty() {
226                let rule = match Rule::parse(ctx_edition, &mut src) {
227                    Ok(it) => it,
228                    Err(e) => {
229                        err = Some(Box::new(e));
230                        break;
231                    }
232                };
233                rules.push(rule);
234                if let Err(()) = src.expect_any_char(&[';', ',']) {
235                    if !src.is_empty() {
236                        err = Some(Box::new(ParseError::expected(
237                            "expected `;` or `,` to delimit rules",
238                        )));
239                    }
240                    break;
241                }
242            }
243        }
244
245        for Rule { lhs, .. } in &rules {
246            if let Err(e) = validate(lhs) {
247                err = Some(Box::new(e));
248                break;
249            }
250        }
251
252        DeclarativeMacro { rules: rules.into_boxed_slice(), err }
253    }
254
255    pub fn err(&self) -> Option<&ParseError> {
256        self.err.as_deref()
257    }
258
259    pub fn num_rules(&self) -> usize {
260        self.rules.len()
261    }
262
263    pub fn rule_styles(&self) -> MacroCallStyles {
264        if self.rules.is_empty() {
265            // No rules could be parsed, so fall back to assuming that this
266            // is intended to be a function-like macro.
267            MacroCallStyles::FN_LIKE
268        } else {
269            self.rules
270                .iter()
271                .map(|rule| MacroCallStyles::from(rule.style))
272                .fold(MacroCallStyles::empty(), |a, b| a | b)
273        }
274    }
275
276    pub fn expand(
277        &self,
278        db: &dyn salsa::Database,
279        tt: &tt::TopSubtree<Span>,
280        marker: impl Fn(&mut Span) + Copy,
281        call_style: MacroCallStyle,
282        call_site: Span,
283    ) -> ExpandResult<(tt::TopSubtree<Span>, MatchedArmIndex)> {
284        expander::expand_rules(db, &self.rules, tt, marker, call_style, call_site)
285    }
286}
287
288impl Rule {
289    fn parse(
290        edition: impl Copy + Fn(SyntaxContext) -> Edition,
291        src: &mut TtIter<'_, Span>,
292    ) -> Result<Self, ParseError> {
293        // Parse an optional `attr()` or `derive()` prefix before the LHS pattern.
294        let style = parser::parse_rule_style(src)?;
295
296        let (_, lhs) =
297            src.expect_subtree().map_err(|()| ParseError::expected("expected subtree"))?;
298        src.expect_char('=').map_err(|()| ParseError::expected("expected `=`"))?;
299        src.expect_char('>').map_err(|()| ParseError::expected("expected `>`"))?;
300        let (_, rhs) =
301            src.expect_subtree().map_err(|()| ParseError::expected("expected subtree"))?;
302
303        let lhs = MetaTemplate::parse_pattern(edition, lhs)?;
304        let rhs = MetaTemplate::parse_template(edition, rhs)?;
305
306        Ok(crate::Rule { style, lhs, rhs })
307    }
308}
309
310fn validate(pattern: &MetaTemplate) -> Result<(), ParseError> {
311    for op in pattern.iter() {
312        match op {
313            Op::Subtree { tokens, .. } => validate(tokens)?,
314            Op::Repeat { tokens: subtree, separator, .. } => {
315                // Checks that no repetition which could match an empty token
316                // https://github.com/rust-lang/rust/blob/a58b1ed44f5e06976de2bdc4d7dc81c36a96934f/src/librustc_expand/mbe/macro_rules.rs#L558
317                let lsh_is_empty_seq = separator.is_none() && subtree.iter().all(|child_op| {
318                    match *child_op {
319                        // vis is optional
320                        Op::Var { kind: Some(kind), .. } => kind == MetaVarKind::Vis,
321                        Op::Repeat {
322                            kind: parser::RepeatKind::ZeroOrMore | parser::RepeatKind::ZeroOrOne,
323                            ..
324                        } => true,
325                        _ => false,
326                    }
327                });
328                if lsh_is_empty_seq {
329                    return Err(ParseError::RepetitionEmptyTokenTree);
330                }
331                validate(subtree)?
332            }
333            _ => (),
334        }
335    }
336    Ok(())
337}
338
339pub type ExpandResult<T> = ValueResult<T, ExpandError>;
340
/// A value together with an optional error.
///
/// Unlike `Result`, a value is always present, even when `err` is set.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ValueResult<T, E> {
    pub value: T,
    pub err: Option<E>,
}

impl<T: Default, E> Default for ValueResult<T, E> {
    /// The default value of `T` with no error.
    fn default() -> Self {
        Self { value: T::default(), err: None }
    }
}

impl<T, E> ValueResult<T, E> {
    /// Creates a result that carries both a value and an error.
    pub fn new(value: T, err: E) -> Self {
        Self { value, err: Some(err) }
    }

    /// Creates a result with a value and no error.
    pub fn ok(value: T) -> Self {
        Self { value, err: None }
    }

    /// Creates a result carrying `err` and the default value of `T`.
    pub fn only_err(err: E) -> Self
    where
        T: Default,
    {
        Self { value: T::default(), err: Some(err) }
    }

    /// Pairs the value with `other`, keeping the error.
    pub fn zip_val<U>(self, other: U) -> ValueResult<(T, U), E> {
        let Self { value, err } = self;
        ValueResult { value: (value, other), err }
    }

    /// Transforms the value with `f`, keeping the error.
    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> ValueResult<U, E> {
        let Self { value, err } = self;
        ValueResult { value: f(value), err }
    }

    /// Transforms the error (if any) with `f`, keeping the value.
    pub fn map_err<E2>(self, f: impl FnOnce(E) -> E2) -> ValueResult<T, E2> {
        let Self { value, err } = self;
        ValueResult { value, err: err.map(f) }
    }

    /// Converts into a `Result`: `Err` if an error is present (the value is
    /// dropped), `Ok(value)` otherwise.
    pub fn result(self) -> Result<T, E> {
        match self.err {
            Some(err) => Err(err),
            None => Ok(self.value),
        }
    }
}

impl<T: Default, E> From<Result<T, E>> for ValueResult<T, E> {
    /// `Ok(v)` becomes a value without error; `Err(e)` becomes the default
    /// value of `T` together with `e`.
    fn from(result: Result<T, E>) -> Self {
        match result {
            Ok(value) => Self::ok(value),
            Err(err) => Self::only_err(err),
        }
    }
}
391
392pub fn expect_fragment<'t>(
393    db: &dyn salsa::Database,
394    tt_iter: &mut TtIter<'t, Span>,
395    entry_point: ::parser::PrefixEntryPoint,
396    delim_span: DelimSpan<Span>,
397) -> ExpandResult<tt::TokenTreesView<'t, Span>> {
398    use ::parser;
399    let buffer = tt_iter.remaining();
400    let parser_input = to_parser_input(buffer, &mut |ctx| ctx.edition(db));
401    let tree_traversal = entry_point.parse(&parser_input);
402    let mut cursor = buffer.cursor();
403    let mut error = false;
404    for step in tree_traversal.iter() {
405        match step {
406            parser::Step::Token { kind, mut n_input_tokens } => {
407                if kind == ::parser::SyntaxKind::LIFETIME_IDENT {
408                    n_input_tokens = 2;
409                }
410                for _ in 0..n_input_tokens {
411                    cursor.bump_or_end();
412                }
413            }
414            parser::Step::FloatSplit { .. } => {
415                // FIXME: We need to split the tree properly here, but mutating the token trees
416                // in the buffer is somewhat tricky to pull off.
417                cursor.bump_or_end();
418            }
419            parser::Step::Enter { .. } | parser::Step::Exit => (),
420            parser::Step::Error { .. } => error = true,
421        }
422    }
423
424    let err = if error || !cursor.is_root() {
425        Some(ExpandError::binding_error(
426            buffer.cursor().token_tree().map_or(delim_span.close, |tt| tt.first_span()),
427            format!("expected {entry_point:?}"),
428        ))
429    } else {
430        None
431    };
432
433    while !cursor.is_root() {
434        cursor.bump_or_end();
435    }
436
437    let res = cursor.crossed();
438    tt_iter.flat_advance(res.len());
439
440    ExpandResult { value: res, err }
441}