rust_analyzer/handlers/
dispatch.rs

1//! See [RequestDispatcher].
2use std::{
3    fmt::{self, Debug},
4    panic, thread,
5};
6
7use ide_db::base_db::{
8    DbPanicContext,
9    salsa::{self, Cancelled},
10};
11use lsp_server::{ExtractError, Response, ResponseError};
12use serde::{Serialize, de::DeserializeOwned};
13use stdx::thread::ThreadIntent;
14
15use crate::{
16    global_state::{GlobalState, GlobalStateSnapshot},
17    lsp::LspError,
18    main_loop::Task,
19    version::version,
20};
21
22/// A visitor for routing a raw JSON request to an appropriate handler function.
23///
24/// Most requests are read-only and async and are handled on the threadpool
25/// (`on` method).
26///
27/// Some read-only requests are latency sensitive, and are immediately handled
28/// on the main loop thread (`on_sync`). These are typically typing-related
29/// requests.
30///
31/// Some requests modify the state, and are run on the main thread to get
32/// `&mut` (`on_sync_mut`).
33///
34/// Read-only requests are wrapped into `catch_unwind` -- they don't modify the
35/// state, so it's OK to recover from their failures.
36pub(crate) struct RequestDispatcher<'a> {
37    pub(crate) req: Option<lsp_server::Request>,
38    pub(crate) global_state: &'a mut GlobalState,
39}
40
41impl RequestDispatcher<'_> {
42    /// Dispatches the request onto the current thread, given full access to
43    /// mutable global state. Unlike all other methods here, this one isn't
44    /// guarded by `catch_unwind`, so, please, don't make bugs :-)
45    pub(crate) fn on_sync_mut<R>(
46        &mut self,
47        f: fn(&mut GlobalState, R::Params) -> anyhow::Result<R::Result>,
48    ) -> &mut Self
49    where
50        R: lsp_types::request::Request,
51        R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug,
52        R::Result: Serialize,
53    {
54        let (req, params, panic_context) = match self.parse::<R>() {
55            Some(it) => it,
56            None => return self,
57        };
58        let _guard =
59            tracing::info_span!("request", method = ?req.method, "request_id" = ?req.id).entered();
60        tracing::debug!(?params);
61        let result = {
62            let _pctx = DbPanicContext::enter(panic_context);
63            f(self.global_state, params)
64        };
65        if let Ok(response) = result_to_response::<R>(req.id, result) {
66            self.global_state.respond(response);
67        }
68
69        self
70    }
71
72    /// Dispatches the request onto the current thread.
73    pub(crate) fn on_sync<R>(
74        &mut self,
75        f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
76    ) -> &mut Self
77    where
78        R: lsp_types::request::Request,
79        R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug,
80        R::Result: Serialize,
81    {
82        let (req, params, panic_context) = match self.parse::<R>() {
83            Some(it) => it,
84            None => return self,
85        };
86        let _guard =
87            tracing::info_span!("request", method = ?req.method, "request_id" = ?req.id).entered();
88        tracing::debug!(?params);
89        let global_state_snapshot = self.global_state.snapshot();
90
91        let result = panic::catch_unwind(move || {
92            let _pctx = DbPanicContext::enter(panic_context);
93            f(global_state_snapshot, params)
94        });
95
96        if let Ok(response) = thread_result_to_response::<R>(req.id, result) {
97            self.global_state.respond(response);
98        }
99
100        self
101    }
102
103    /// Dispatches a non-latency-sensitive request onto the thread pool. When the VFS is marked not
104    /// ready this will return a default constructed [`R::Result`].
105    pub(crate) fn on<const ALLOW_RETRYING: bool, R>(
106        &mut self,
107        f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
108    ) -> &mut Self
109    where
110        R: lsp_types::request::Request<
111                Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
112                Result: Serialize + Default,
113            > + 'static,
114    {
115        if !self.global_state.vfs_done {
116            if let Some(lsp_server::Request { id, .. }) =
117                self.req.take_if(|it| it.method == R::METHOD)
118            {
119                self.global_state.respond(lsp_server::Response::new_ok(id, R::Result::default()));
120            }
121            return self;
122        }
123        self.on_with_thread_intent::<false, ALLOW_RETRYING, R>(
124            ThreadIntent::Worker,
125            f,
126            Self::content_modified_error,
127        )
128    }
129
130    /// Dispatches a non-latency-sensitive request onto the thread pool. When the VFS is marked not
131    /// ready this will return a `default` constructed [`R::Result`].
132    pub(crate) fn on_with_vfs_default<R>(
133        &mut self,
134        f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
135        default: impl FnOnce() -> R::Result,
136        on_cancelled: fn() -> ResponseError,
137    ) -> &mut Self
138    where
139        R: lsp_types::request::Request<
140                Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
141                Result: Serialize,
142            > + 'static,
143    {
144        if !self.global_state.vfs_done || self.global_state.incomplete_crate_graph {
145            if let Some(lsp_server::Request { id, .. }) =
146                self.req.take_if(|it| it.method == R::METHOD)
147            {
148                self.global_state.respond(lsp_server::Response::new_ok(id, default()));
149            }
150            return self;
151        }
152        self.on_with_thread_intent::<false, false, R>(ThreadIntent::Worker, f, on_cancelled)
153    }
154
155    /// Dispatches a non-latency-sensitive request onto the thread pool. When the VFS is marked not
156    /// ready this will return the parameter as is.
157    pub(crate) fn on_identity<const ALLOW_RETRYING: bool, R, Params>(
158        &mut self,
159        f: fn(GlobalStateSnapshot, Params) -> anyhow::Result<R::Result>,
160    ) -> &mut Self
161    where
162        R: lsp_types::request::Request<Params = Params, Result = Params> + 'static,
163        Params: Serialize + DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
164    {
165        if !self.global_state.vfs_done {
166            if let Some((request, params, _)) = self.parse::<R>() {
167                self.global_state.respond(lsp_server::Response::new_ok(request.id, &params))
168            }
169            return self;
170        }
171        self.on_with_thread_intent::<false, ALLOW_RETRYING, R>(
172            ThreadIntent::Worker,
173            f,
174            Self::content_modified_error,
175        )
176    }
177
178    /// Dispatches a latency-sensitive request onto the thread pool. When the VFS is marked not
179    /// ready this will return a default constructed [`R::Result`].
180    pub(crate) fn on_latency_sensitive<const ALLOW_RETRYING: bool, R>(
181        &mut self,
182        f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
183    ) -> &mut Self
184    where
185        R: lsp_types::request::Request<
186                Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
187                Result: Serialize + Default,
188            > + 'static,
189    {
190        if !self.global_state.vfs_done {
191            if let Some(lsp_server::Request { id, .. }) =
192                self.req.take_if(|it| it.method == R::METHOD)
193            {
194                self.global_state.respond(lsp_server::Response::new_ok(id, R::Result::default()));
195            }
196            return self;
197        }
198        self.on_with_thread_intent::<false, ALLOW_RETRYING, R>(
199            ThreadIntent::LatencySensitive,
200            f,
201            Self::content_modified_error,
202        )
203    }
204
205    /// Formatting requests should never block on waiting a for task thread to open up, editors will wait
206    /// on the response and a late formatting update might mess with the document and user.
207    /// We can't run this on the main thread though as we invoke rustfmt which may take arbitrary time to complete!
208    pub(crate) fn on_fmt_thread<R>(
209        &mut self,
210        f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
211    ) -> &mut Self
212    where
213        R: lsp_types::request::Request + 'static,
214        R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
215        R::Result: Serialize,
216    {
217        self.on_with_thread_intent::<true, false, R>(
218            ThreadIntent::LatencySensitive,
219            f,
220            Self::content_modified_error,
221        )
222    }
223
224    pub(crate) fn finish(&mut self) {
225        if let Some(req) = self.req.take() {
226            tracing::error!("unknown request: {:?}", req);
227            let response = lsp_server::Response::new_err(
228                req.id,
229                lsp_server::ErrorCode::MethodNotFound as i32,
230                "unknown request".to_owned(),
231            );
232            self.global_state.respond(response);
233        }
234    }
235
236    fn on_with_thread_intent<const RUSTFMT: bool, const ALLOW_RETRYING: bool, R>(
237        &mut self,
238        intent: ThreadIntent,
239        f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
240        on_cancelled: fn() -> ResponseError,
241    ) -> &mut Self
242    where
243        R: lsp_types::request::Request + 'static,
244        R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
245        R::Result: Serialize,
246    {
247        let (req, params, panic_context) = match self.parse::<R>() {
248            Some(it) => it,
249            None => return self,
250        };
251        let _guard =
252            tracing::info_span!("request", method = ?req.method, "request_id" = ?req.id).entered();
253        tracing::debug!(?params);
254
255        let world = self.global_state.snapshot();
256        if RUSTFMT {
257            &mut self.global_state.fmt_pool.handle
258        } else {
259            &mut self.global_state.task_pool.handle
260        }
261        .spawn(intent, move || {
262            let result = panic::catch_unwind(move || {
263                let _pctx = DbPanicContext::enter(panic_context);
264                f(world, params)
265            });
266            match thread_result_to_response::<R>(req.id.clone(), result) {
267                Ok(response) => Task::Response(response),
268                Err(_cancelled) if ALLOW_RETRYING => Task::Retry(req),
269                Err(_cancelled) => {
270                    let error = on_cancelled();
271                    Task::Response(Response { id: req.id, result: None, error: Some(error) })
272                }
273            }
274        });
275
276        self
277    }
278
279    fn parse<R>(&mut self) -> Option<(lsp_server::Request, R::Params, String)>
280    where
281        R: lsp_types::request::Request,
282        R::Params: DeserializeOwned + fmt::Debug,
283    {
284        let req = self.req.take_if(|it| it.method == R::METHOD)?;
285        let res = crate::from_json(R::METHOD, &req.params);
286        match res {
287            Ok(params) => {
288                let panic_context =
289                    format!("\nversion: {}\nrequest: {} {params:#?}", version(), R::METHOD);
290                Some((req, params, panic_context))
291            }
292            Err(err) => {
293                let response = lsp_server::Response::new_err(
294                    req.id,
295                    lsp_server::ErrorCode::InvalidParams as i32,
296                    err.to_string(),
297                );
298                self.global_state.respond(response);
299                None
300            }
301        }
302    }
303
304    fn content_modified_error() -> ResponseError {
305        ResponseError {
306            code: lsp_server::ErrorCode::ContentModified as i32,
307            message: "content modified".to_owned(),
308            data: None,
309        }
310    }
311}
312
313#[derive(Debug)]
314enum HandlerCancelledError {
315    Inner(salsa::Cancelled),
316}
317
318impl std::error::Error for HandlerCancelledError {
319    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
320        match self {
321            HandlerCancelledError::Inner(cancelled) => Some(cancelled),
322        }
323    }
324}
325
326impl fmt::Display for HandlerCancelledError {
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        write!(f, "Cancelled")
329    }
330}
331
332fn thread_result_to_response<R>(
333    id: lsp_server::RequestId,
334    result: thread::Result<anyhow::Result<R::Result>>,
335) -> Result<lsp_server::Response, HandlerCancelledError>
336where
337    R: lsp_types::request::Request,
338    R::Params: DeserializeOwned,
339    R::Result: Serialize,
340{
341    match result {
342        Ok(result) => result_to_response::<R>(id, result),
343        Err(panic) => {
344            let panic_message = panic
345                .downcast_ref::<String>()
346                .map(String::as_str)
347                .or_else(|| panic.downcast_ref::<&str>().copied());
348
349            let mut message = "request handler panicked".to_owned();
350            if let Some(panic_message) = panic_message {
351                message.push_str(": ");
352                message.push_str(panic_message);
353            } else if let Ok(cancelled) = panic.downcast::<Cancelled>() {
354                tracing::error!("Cancellation propagated out of salsa! This is a bug");
355                return Err(HandlerCancelledError::Inner(*cancelled));
356            };
357
358            Ok(lsp_server::Response::new_err(
359                id,
360                lsp_server::ErrorCode::InternalError as i32,
361                message,
362            ))
363        }
364    }
365}
366
367fn result_to_response<R>(
368    id: lsp_server::RequestId,
369    result: anyhow::Result<R::Result>,
370) -> Result<lsp_server::Response, HandlerCancelledError>
371where
372    R: lsp_types::request::Request,
373    R::Params: DeserializeOwned,
374    R::Result: Serialize,
375{
376    let res = match result {
377        Ok(resp) => lsp_server::Response::new_ok(id, &resp),
378        Err(e) => match e.downcast::<LspError>() {
379            Ok(lsp_error) => lsp_server::Response::new_err(id, lsp_error.code, lsp_error.message),
380            Err(e) => match e.downcast::<Cancelled>() {
381                Ok(cancelled) => return Err(HandlerCancelledError::Inner(cancelled)),
382                Err(e) => lsp_server::Response::new_err(
383                    id,
384                    lsp_server::ErrorCode::InternalError as i32,
385                    e.to_string(),
386                ),
387            },
388        },
389    };
390    Ok(res)
391}
392
393pub(crate) struct NotificationDispatcher<'a> {
394    pub(crate) not: Option<lsp_server::Notification>,
395    pub(crate) global_state: &'a mut GlobalState,
396}
397
398impl NotificationDispatcher<'_> {
399    pub(crate) fn on_sync_mut<N>(
400        &mut self,
401        f: fn(&mut GlobalState, N::Params) -> anyhow::Result<()>,
402    ) -> &mut Self
403    where
404        N: lsp_types::notification::Notification,
405        N::Params: DeserializeOwned + Send + Debug,
406    {
407        let not = match self.not.take() {
408            Some(it) => it,
409            None => return self,
410        };
411
412        let _guard = tracing::info_span!("notification", method = ?not.method).entered();
413
414        let params = match not.extract::<N::Params>(N::METHOD) {
415            Ok(it) => it,
416            Err(ExtractError::JsonError { method, error }) => {
417                panic!("Invalid request\nMethod: {method}\n error: {error}",)
418            }
419            Err(ExtractError::MethodMismatch(not)) => {
420                self.not = Some(not);
421                return self;
422            }
423        };
424
425        tracing::debug!(?params);
426
427        let _pctx =
428            DbPanicContext::enter(format!("\nversion: {}\nnotification: {}", version(), N::METHOD));
429        if let Err(e) = f(self.global_state, params) {
430            tracing::error!(handler = %N::METHOD, error = %e, "notification handler failed");
431        }
432        self
433    }
434
435    pub(crate) fn finish(&mut self) {
436        if let Some(not) = &self.not
437            && !not.method.starts_with("$/")
438        {
439            tracing::error!("unhandled notification: {:?}", not);
440        }
441    }
442}