1use std::{
3 fmt::{self, Debug},
4 panic, thread,
5};
6
7use ide_db::base_db::{
8 DbPanicContext,
9 salsa::{self, Cancelled},
10};
11use lsp_server::{ExtractError, Response, ResponseError};
12use serde::{Serialize, de::DeserializeOwned};
13use stdx::thread::ThreadIntent;
14
15use crate::{
16 global_state::{GlobalState, GlobalStateSnapshot},
17 lsp::LspError,
18 main_loop::Task,
19 version::version,
20};
21
22pub(crate) struct RequestDispatcher<'a> {
37 pub(crate) req: Option<lsp_server::Request>,
38 pub(crate) global_state: &'a mut GlobalState,
39}
40
41impl RequestDispatcher<'_> {
42 pub(crate) fn on_sync_mut<R>(
46 &mut self,
47 f: fn(&mut GlobalState, R::Params) -> anyhow::Result<R::Result>,
48 ) -> &mut Self
49 where
50 R: lsp_types::request::Request,
51 R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug,
52 R::Result: Serialize,
53 {
54 let (req, params, panic_context) = match self.parse::<R>() {
55 Some(it) => it,
56 None => return self,
57 };
58 let _guard =
59 tracing::info_span!("request", method = ?req.method, "request_id" = ?req.id).entered();
60 tracing::debug!(?params);
61 let result = {
62 let _pctx = DbPanicContext::enter(panic_context);
63 f(self.global_state, params)
64 };
65 if let Ok(response) = result_to_response::<R>(req.id, result) {
66 self.global_state.respond(response);
67 }
68
69 self
70 }
71
72 pub(crate) fn on_sync<R>(
74 &mut self,
75 f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
76 ) -> &mut Self
77 where
78 R: lsp_types::request::Request,
79 R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug,
80 R::Result: Serialize,
81 {
82 let (req, params, panic_context) = match self.parse::<R>() {
83 Some(it) => it,
84 None => return self,
85 };
86 let _guard =
87 tracing::info_span!("request", method = ?req.method, "request_id" = ?req.id).entered();
88 tracing::debug!(?params);
89 let global_state_snapshot = self.global_state.snapshot();
90
91 let result = panic::catch_unwind(move || {
92 let _pctx = DbPanicContext::enter(panic_context);
93 f(global_state_snapshot, params)
94 });
95
96 if let Ok(response) = thread_result_to_response::<R>(req.id, result) {
97 self.global_state.respond(response);
98 }
99
100 self
101 }
102
103 pub(crate) fn on<const ALLOW_RETRYING: bool, R>(
106 &mut self,
107 f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
108 ) -> &mut Self
109 where
110 R: lsp_types::request::Request<
111 Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
112 Result: Serialize + Default,
113 > + 'static,
114 {
115 if !self.global_state.vfs_done {
116 if let Some(lsp_server::Request { id, .. }) =
117 self.req.take_if(|it| it.method == R::METHOD)
118 {
119 self.global_state.respond(lsp_server::Response::new_ok(id, R::Result::default()));
120 }
121 return self;
122 }
123 self.on_with_thread_intent::<false, ALLOW_RETRYING, R>(
124 ThreadIntent::Worker,
125 f,
126 Self::content_modified_error,
127 )
128 }
129
130 pub(crate) fn on_with_vfs_default<R>(
133 &mut self,
134 f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
135 default: impl FnOnce() -> R::Result,
136 on_cancelled: fn() -> ResponseError,
137 ) -> &mut Self
138 where
139 R: lsp_types::request::Request<
140 Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
141 Result: Serialize,
142 > + 'static,
143 {
144 if !self.global_state.vfs_done || self.global_state.incomplete_crate_graph {
145 if let Some(lsp_server::Request { id, .. }) =
146 self.req.take_if(|it| it.method == R::METHOD)
147 {
148 self.global_state.respond(lsp_server::Response::new_ok(id, default()));
149 }
150 return self;
151 }
152 self.on_with_thread_intent::<false, false, R>(ThreadIntent::Worker, f, on_cancelled)
153 }
154
155 pub(crate) fn on_identity<const ALLOW_RETRYING: bool, R, Params>(
158 &mut self,
159 f: fn(GlobalStateSnapshot, Params) -> anyhow::Result<R::Result>,
160 ) -> &mut Self
161 where
162 R: lsp_types::request::Request<Params = Params, Result = Params> + 'static,
163 Params: Serialize + DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
164 {
165 if !self.global_state.vfs_done {
166 if let Some((request, params, _)) = self.parse::<R>() {
167 self.global_state.respond(lsp_server::Response::new_ok(request.id, ¶ms))
168 }
169 return self;
170 }
171 self.on_with_thread_intent::<false, ALLOW_RETRYING, R>(
172 ThreadIntent::Worker,
173 f,
174 Self::content_modified_error,
175 )
176 }
177
178 pub(crate) fn on_latency_sensitive<const ALLOW_RETRYING: bool, R>(
181 &mut self,
182 f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
183 ) -> &mut Self
184 where
185 R: lsp_types::request::Request<
186 Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
187 Result: Serialize + Default,
188 > + 'static,
189 {
190 if !self.global_state.vfs_done {
191 if let Some(lsp_server::Request { id, .. }) =
192 self.req.take_if(|it| it.method == R::METHOD)
193 {
194 self.global_state.respond(lsp_server::Response::new_ok(id, R::Result::default()));
195 }
196 return self;
197 }
198 self.on_with_thread_intent::<false, ALLOW_RETRYING, R>(
199 ThreadIntent::LatencySensitive,
200 f,
201 Self::content_modified_error,
202 )
203 }
204
205 pub(crate) fn on_fmt_thread<R>(
209 &mut self,
210 f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
211 ) -> &mut Self
212 where
213 R: lsp_types::request::Request + 'static,
214 R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
215 R::Result: Serialize,
216 {
217 self.on_with_thread_intent::<true, false, R>(
218 ThreadIntent::LatencySensitive,
219 f,
220 Self::content_modified_error,
221 )
222 }
223
224 pub(crate) fn finish(&mut self) {
225 if let Some(req) = self.req.take() {
226 tracing::error!("unknown request: {:?}", req);
227 let response = lsp_server::Response::new_err(
228 req.id,
229 lsp_server::ErrorCode::MethodNotFound as i32,
230 "unknown request".to_owned(),
231 );
232 self.global_state.respond(response);
233 }
234 }
235
236 fn on_with_thread_intent<const RUSTFMT: bool, const ALLOW_RETRYING: bool, R>(
237 &mut self,
238 intent: ThreadIntent,
239 f: fn(GlobalStateSnapshot, R::Params) -> anyhow::Result<R::Result>,
240 on_cancelled: fn() -> ResponseError,
241 ) -> &mut Self
242 where
243 R: lsp_types::request::Request + 'static,
244 R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
245 R::Result: Serialize,
246 {
247 let (req, params, panic_context) = match self.parse::<R>() {
248 Some(it) => it,
249 None => return self,
250 };
251 let _guard =
252 tracing::info_span!("request", method = ?req.method, "request_id" = ?req.id).entered();
253 tracing::debug!(?params);
254
255 let world = self.global_state.snapshot();
256 if RUSTFMT {
257 &mut self.global_state.fmt_pool.handle
258 } else {
259 &mut self.global_state.task_pool.handle
260 }
261 .spawn(intent, move || {
262 let result = panic::catch_unwind(move || {
263 let _pctx = DbPanicContext::enter(panic_context);
264 f(world, params)
265 });
266 match thread_result_to_response::<R>(req.id.clone(), result) {
267 Ok(response) => Task::Response(response),
268 Err(_cancelled) if ALLOW_RETRYING => Task::Retry(req),
269 Err(_cancelled) => {
270 let error = on_cancelled();
271 Task::Response(Response { id: req.id, result: None, error: Some(error) })
272 }
273 }
274 });
275
276 self
277 }
278
279 fn parse<R>(&mut self) -> Option<(lsp_server::Request, R::Params, String)>
280 where
281 R: lsp_types::request::Request,
282 R::Params: DeserializeOwned + fmt::Debug,
283 {
284 let req = self.req.take_if(|it| it.method == R::METHOD)?;
285 let res = crate::from_json(R::METHOD, &req.params);
286 match res {
287 Ok(params) => {
288 let panic_context =
289 format!("\nversion: {}\nrequest: {} {params:#?}", version(), R::METHOD);
290 Some((req, params, panic_context))
291 }
292 Err(err) => {
293 let response = lsp_server::Response::new_err(
294 req.id,
295 lsp_server::ErrorCode::InvalidParams as i32,
296 err.to_string(),
297 );
298 self.global_state.respond(response);
299 None
300 }
301 }
302 }
303
304 fn content_modified_error() -> ResponseError {
305 ResponseError {
306 code: lsp_server::ErrorCode::ContentModified as i32,
307 message: "content modified".to_owned(),
308 data: None,
309 }
310 }
311}
312
313#[derive(Debug)]
314enum HandlerCancelledError {
315 Inner(salsa::Cancelled),
316}
317
318impl std::error::Error for HandlerCancelledError {
319 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
320 match self {
321 HandlerCancelledError::Inner(cancelled) => Some(cancelled),
322 }
323 }
324}
325
326impl fmt::Display for HandlerCancelledError {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 write!(f, "Cancelled")
329 }
330}
331
332fn thread_result_to_response<R>(
333 id: lsp_server::RequestId,
334 result: thread::Result<anyhow::Result<R::Result>>,
335) -> Result<lsp_server::Response, HandlerCancelledError>
336where
337 R: lsp_types::request::Request,
338 R::Params: DeserializeOwned,
339 R::Result: Serialize,
340{
341 match result {
342 Ok(result) => result_to_response::<R>(id, result),
343 Err(panic) => {
344 let panic_message = panic
345 .downcast_ref::<String>()
346 .map(String::as_str)
347 .or_else(|| panic.downcast_ref::<&str>().copied());
348
349 let mut message = "request handler panicked".to_owned();
350 if let Some(panic_message) = panic_message {
351 message.push_str(": ");
352 message.push_str(panic_message);
353 } else if let Ok(cancelled) = panic.downcast::<Cancelled>() {
354 tracing::error!("Cancellation propagated out of salsa! This is a bug");
355 return Err(HandlerCancelledError::Inner(*cancelled));
356 };
357
358 Ok(lsp_server::Response::new_err(
359 id,
360 lsp_server::ErrorCode::InternalError as i32,
361 message,
362 ))
363 }
364 }
365}
366
367fn result_to_response<R>(
368 id: lsp_server::RequestId,
369 result: anyhow::Result<R::Result>,
370) -> Result<lsp_server::Response, HandlerCancelledError>
371where
372 R: lsp_types::request::Request,
373 R::Params: DeserializeOwned,
374 R::Result: Serialize,
375{
376 let res = match result {
377 Ok(resp) => lsp_server::Response::new_ok(id, &resp),
378 Err(e) => match e.downcast::<LspError>() {
379 Ok(lsp_error) => lsp_server::Response::new_err(id, lsp_error.code, lsp_error.message),
380 Err(e) => match e.downcast::<Cancelled>() {
381 Ok(cancelled) => return Err(HandlerCancelledError::Inner(cancelled)),
382 Err(e) => lsp_server::Response::new_err(
383 id,
384 lsp_server::ErrorCode::InternalError as i32,
385 e.to_string(),
386 ),
387 },
388 },
389 };
390 Ok(res)
391}
392
393pub(crate) struct NotificationDispatcher<'a> {
394 pub(crate) not: Option<lsp_server::Notification>,
395 pub(crate) global_state: &'a mut GlobalState,
396}
397
398impl NotificationDispatcher<'_> {
399 pub(crate) fn on_sync_mut<N>(
400 &mut self,
401 f: fn(&mut GlobalState, N::Params) -> anyhow::Result<()>,
402 ) -> &mut Self
403 where
404 N: lsp_types::notification::Notification,
405 N::Params: DeserializeOwned + Send + Debug,
406 {
407 let not = match self.not.take() {
408 Some(it) => it,
409 None => return self,
410 };
411
412 let _guard = tracing::info_span!("notification", method = ?not.method).entered();
413
414 let params = match not.extract::<N::Params>(N::METHOD) {
415 Ok(it) => it,
416 Err(ExtractError::JsonError { method, error }) => {
417 panic!("Invalid request\nMethod: {method}\n error: {error}",)
418 }
419 Err(ExtractError::MethodMismatch(not)) => {
420 self.not = Some(not);
421 return self;
422 }
423 };
424
425 tracing::debug!(?params);
426
427 let _pctx =
428 DbPanicContext::enter(format!("\nversion: {}\nnotification: {}", version(), N::METHOD));
429 if let Err(e) = f(self.global_state, params) {
430 tracing::error!(handler = %N::METHOD, error = %e, "notification handler failed");
431 }
432 self
433 }
434
435 pub(crate) fn finish(&mut self) {
436 if let Some(not) = &self.not
437 && !not.method.starts_with("$/")
438 {
439 tracing::error!("unhandled notification: {:?}", not);
440 }
441 }
442}