1use crate::helpers::mod_path_to_ast;
4use either::Either;
5use hir::{
6 AsAssocItem, HirDisplay, HirFileId, ImportPathConfig, ModuleDef, SemanticsScope,
7 prettify_macro_expansion,
8};
9use itertools::Itertools;
10use rustc_hash::FxHashMap;
11use span::Edition;
12use syntax::{
13 NodeOrToken, SyntaxNode,
14 ast::{self, AstNode, HasGenericArgs, HasName, make},
15 syntax_editor::{self, SyntaxEditor},
16};
17
18#[derive(Default, Debug)]
19struct AstSubsts {
20 types_and_consts: Vec<TypeOrConst>,
21 lifetimes: Vec<ast::LifetimeArg>,
22}
23
24#[derive(Debug)]
25enum TypeOrConst {
26 Either(ast::TypeArg), Const(ast::ConstArg),
28}
29
/// Textual name of a lifetime parameter, rendered for the target edition.
///
/// Used as the lookup key for lifetime substitutions.
type LifetimeName = String;
31type DefaultedParam = Either<hir::TypeParam, hir::ConstParam>;
32
33pub struct PathTransform<'a> {
57 generic_def: Option<hir::GenericDef>,
58 substs: AstSubsts,
59 target_scope: &'a SemanticsScope<'a>,
60 source_scope: &'a SemanticsScope<'a>,
61}
62
63impl<'a> PathTransform<'a> {
64 pub fn trait_impl(
65 target_scope: &'a SemanticsScope<'a>,
66 source_scope: &'a SemanticsScope<'a>,
67 trait_: hir::Trait,
68 impl_: ast::Impl,
69 ) -> PathTransform<'a> {
70 PathTransform {
71 source_scope,
72 target_scope,
73 generic_def: Some(trait_.into()),
74 substs: get_syntactic_substs(impl_).unwrap_or_default(),
75 }
76 }
77
78 pub fn function_call(
79 target_scope: &'a SemanticsScope<'a>,
80 source_scope: &'a SemanticsScope<'a>,
81 function: hir::Function,
82 generic_arg_list: ast::GenericArgList,
83 ) -> PathTransform<'a> {
84 PathTransform {
85 source_scope,
86 target_scope,
87 generic_def: Some(function.into()),
88 substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
89 }
90 }
91
92 pub fn impl_transformation(
93 target_scope: &'a SemanticsScope<'a>,
94 source_scope: &'a SemanticsScope<'a>,
95 impl_: hir::Impl,
96 generic_arg_list: ast::GenericArgList,
97 ) -> PathTransform<'a> {
98 PathTransform {
99 source_scope,
100 target_scope,
101 generic_def: Some(impl_.into()),
102 substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
103 }
104 }
105
106 pub fn adt_transformation(
107 target_scope: &'a SemanticsScope<'a>,
108 source_scope: &'a SemanticsScope<'a>,
109 adt: hir::Adt,
110 generic_arg_list: ast::GenericArgList,
111 ) -> PathTransform<'a> {
112 PathTransform {
113 source_scope,
114 target_scope,
115 generic_def: Some(adt.into()),
116 substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
117 }
118 }
119
120 pub fn generic_transformation(
121 target_scope: &'a SemanticsScope<'a>,
122 source_scope: &'a SemanticsScope<'a>,
123 ) -> PathTransform<'a> {
124 PathTransform {
125 source_scope,
126 target_scope,
127 generic_def: None,
128 substs: AstSubsts::default(),
129 }
130 }
131
132 #[must_use]
133 pub fn apply(&self, syntax: &SyntaxNode) -> SyntaxNode {
134 self.build_ctx().apply(syntax)
135 }
136
137 #[must_use]
138 pub fn apply_all<'b>(
139 &self,
140 nodes: impl IntoIterator<Item = &'b SyntaxNode>,
141 ) -> Vec<SyntaxNode> {
142 let ctx = self.build_ctx();
143 nodes.into_iter().map(|node| ctx.apply(&node.clone())).collect()
144 }
145
146 fn prettify_target_node(&self, node: SyntaxNode) -> SyntaxNode {
147 match self.target_scope.file_id() {
148 HirFileId::FileId(_) => node,
149 HirFileId::MacroFile(file_id) => {
150 let db = self.target_scope.db;
151 prettify_macro_expansion(
152 db,
153 node,
154 &db.expansion_span_map(file_id),
155 self.target_scope.module().krate().into(),
156 )
157 }
158 }
159 }
160
161 fn prettify_target_ast<N: AstNode>(&self, node: N) -> N {
162 N::cast(self.prettify_target_node(node.syntax().clone())).unwrap()
163 }
164
165 fn build_ctx(&self) -> Ctx<'a> {
166 let db = self.source_scope.db;
167 let target_module = self.target_scope.module();
168 let source_module = self.source_scope.module();
169 let skip = match self.generic_def {
170 Some(hir::GenericDef::Trait(_)) => 1,
172 _ => 0,
173 };
174 let mut type_substs: FxHashMap<hir::TypeParam, ast::Type> = Default::default();
175 let mut const_substs: FxHashMap<hir::ConstParam, SyntaxNode> = Default::default();
176 let mut defaulted_params: Vec<DefaultedParam> = Default::default();
177 let target_edition = target_module.krate().edition(self.source_scope.db);
178 self.generic_def
179 .into_iter()
180 .flat_map(|it| it.type_or_const_params(db))
181 .skip(skip)
182 .zip(self.substs.types_and_consts.iter().map(Some).chain(std::iter::repeat(None)))
189 .for_each(|(k, v)| match (k.split(db), v) {
190 (Either::Right(k), Some(TypeOrConst::Either(v))) => {
191 if let Some(ty) = v.ty() {
192 type_substs.insert(k, self.prettify_target_ast(ty));
193 }
194 }
195 (Either::Right(k), None) => {
196 if let Some(default) = k.default(db)
197 && let Some(default) =
198 &default.display_source_code(db, source_module.into(), false).ok()
199 {
200 type_substs.insert(k, make::ty(default).clone_for_update());
201 defaulted_params.push(Either::Left(k));
202 }
203 }
204 (Either::Left(k), Some(TypeOrConst::Either(v))) => {
205 if let Some(ty) = v.ty() {
206 const_substs.insert(k, self.prettify_target_node(ty.syntax().clone()));
207 }
208 }
209 (Either::Left(k), Some(TypeOrConst::Const(v))) => {
210 if let Some(expr) = v.expr() {
211 const_substs.insert(k, self.prettify_target_node(expr.syntax().clone()));
218 }
219 }
220 (Either::Left(k), None) => {
221 if let Some(default) =
222 k.default(db, target_module.krate().to_display_target(db))
223 && let Some(default) = default.expr()
224 {
225 const_substs.insert(k, default.syntax().clone_for_update());
226 defaulted_params.push(Either::Right(k));
227 }
228 }
229 _ => (), });
231 let lifetime_substs: FxHashMap<_, _> = self
233 .generic_def
234 .into_iter()
235 .flat_map(|it| it.lifetime_params(db))
236 .zip(self.substs.lifetimes.clone())
237 .filter_map(|(k, v)| {
238 Some((k.name(db).display(db, target_edition).to_string(), v.lifetime()?))
239 })
240 .collect();
241 let mut ctx = Ctx {
242 type_substs,
243 const_substs,
244 lifetime_substs,
245 target_module,
246 source_scope: self.source_scope,
247 same_self_type: self.target_scope.has_same_self_type(self.source_scope),
248 target_edition,
249 };
250 ctx.transform_default_values(defaulted_params);
251 ctx
252 }
253}
254
255struct Ctx<'a> {
256 type_substs: FxHashMap<hir::TypeParam, ast::Type>,
257 const_substs: FxHashMap<hir::ConstParam, SyntaxNode>,
258 lifetime_substs: FxHashMap<LifetimeName, ast::Lifetime>,
259 target_module: hir::Module,
260 source_scope: &'a SemanticsScope<'a>,
261 same_self_type: bool,
262 target_edition: Edition,
263}
264
265fn preorder_rev(item: &SyntaxNode) -> impl Iterator<Item = SyntaxNode> {
266 let x = item
267 .preorder()
268 .filter_map(|event| match event {
269 syntax::WalkEvent::Enter(node) => Some(node),
270 syntax::WalkEvent::Leave(_) => None,
271 })
272 .collect_vec();
273 x.into_iter().rev()
274}
275
276impl Ctx<'_> {
277 fn apply(&self, item: &SyntaxNode) -> SyntaxNode {
278 let item = self.transform_path(item).clone_subtree();
282 let mut editor = SyntaxEditor::new(item.clone());
283 preorder_rev(&item).filter_map(ast::Lifetime::cast).for_each(|lifetime| {
284 if let Some(subst) = self.lifetime_substs.get(&lifetime.syntax().text().to_string()) {
285 editor
286 .replace(lifetime.syntax(), subst.clone_subtree().clone_for_update().syntax());
287 }
288 });
289
290 editor.finish().new_root().clone()
291 }
292
293 fn transform_default_values(&mut self, defaulted_params: Vec<DefaultedParam>) {
294 for param in defaulted_params {
298 let value = match ¶m {
299 Either::Left(k) => self.type_substs.get(k).unwrap().syntax(),
300 Either::Right(k) => self.const_substs.get(k).unwrap(),
301 };
302 let new_value = self.transform_path(value);
306 match param {
307 Either::Left(k) => {
308 self.type_substs.insert(k, ast::Type::cast(new_value.clone()).unwrap());
309 }
310 Either::Right(k) => {
311 self.const_substs.insert(k, new_value.clone());
312 }
313 }
314 }
315 }
316
317 fn transform_path(&self, path: &SyntaxNode) -> SyntaxNode {
318 fn find_child_paths_and_ident_pats(
319 root_path: &SyntaxNode,
320 ) -> Vec<Either<ast::Path, ast::IdentPat>> {
321 let mut result: Vec<Either<ast::Path, ast::IdentPat>> = Vec::new();
322 for child in root_path.children() {
323 if let Some(child_path) = ast::Path::cast(child.clone()) {
324 result.push(either::Left(child_path));
325 } else if let Some(child_ident_pat) = ast::IdentPat::cast(child.clone()) {
326 result.push(either::Right(child_ident_pat));
327 } else {
328 result.extend(find_child_paths_and_ident_pats(&child));
329 }
330 }
331 result
332 }
333
334 let root_path = path.clone_subtree();
335
336 let result = find_child_paths_and_ident_pats(&root_path);
337 let mut editor = SyntaxEditor::new(root_path.clone());
338 for sub_path in result {
339 let new = self.transform_path(sub_path.syntax());
340 editor.replace(sub_path.syntax(), new);
341 }
342
343 let update_sub_item = editor.finish().new_root().clone().clone_subtree();
344 let item = find_child_paths_and_ident_pats(&update_sub_item);
345 let mut editor = SyntaxEditor::new(update_sub_item);
346 for sub_path in item {
347 self.transform_path_or_ident_pat(&mut editor, &sub_path);
348 }
349 editor.finish().new_root().clone()
350 }
351 fn transform_path_or_ident_pat(
352 &self,
353 editor: &mut SyntaxEditor,
354 item: &Either<ast::Path, ast::IdentPat>,
355 ) -> Option<()> {
356 match item {
357 Either::Left(path) => self.transform_path_(editor, path),
358 Either::Right(ident_pat) => self.transform_ident_pat(editor, ident_pat),
359 }
360 }
361
362 fn transform_path_(&self, editor: &mut SyntaxEditor, path: &ast::Path) -> Option<()> {
363 if path.qualifier().is_some() {
364 return None;
365 }
366 if path.segment().is_some_and(|s| {
367 s.parenthesized_arg_list().is_some()
368 || (s.self_token().is_some() && path.parent_path().is_none())
369 }) {
370 return None;
373 }
374 let resolution = self.source_scope.speculative_resolve(path)?;
375
376 match resolution {
377 hir::PathResolution::TypeParam(tp) => {
378 if let Some(subst) = self.type_substs.get(&tp) {
379 let parent = path.syntax().parent()?;
380 if let Some(parent) = ast::Path::cast(parent.clone()) {
381 let trait_ref = find_trait_for_assoc_item(
390 self.source_scope,
391 tp,
392 parent.segment()?.name_ref()?,
393 )
394 .and_then(|trait_ref| {
395 let cfg = ImportPathConfig {
396 prefer_no_std: false,
397 prefer_prelude: true,
398 prefer_absolute: false,
399 allow_unstable: true,
400 };
401 let found_path = self.target_module.find_path(
402 self.source_scope.db,
403 hir::ModuleDef::Trait(trait_ref),
404 cfg,
405 )?;
406 match make::ty_path(mod_path_to_ast(&found_path, self.target_edition)) {
407 ast::Type::PathType(path_ty) => Some(path_ty),
408 _ => None,
409 }
410 });
411
412 let segment = make::path_segment_ty(subst.clone(), trait_ref);
413 let qualified = make::path_from_segments(std::iter::once(segment), false);
414 editor.replace(path.syntax(), qualified.clone_for_update().syntax());
415 } else if let Some(path_ty) = ast::PathType::cast(parent) {
416 let old = path_ty.syntax();
417
418 if old.parent().is_some() {
419 editor.replace(old, subst.clone_subtree().clone_for_update().syntax());
420 } else {
421 let new = subst.clone_subtree().clone_for_update();
426 if !matches!(new, ast::Type::PathType(..)) {
427 return None;
428 }
429 let start = path_ty.syntax().first_child().map(NodeOrToken::Node)?;
430 let end = path_ty.syntax().last_child().map(NodeOrToken::Node)?;
431 editor.replace_all(
432 start..=end,
433 new.syntax().children().map(NodeOrToken::Node).collect::<Vec<_>>(),
434 );
435 }
436 } else {
437 editor.replace(
438 path.syntax(),
439 subst.clone_subtree().clone_for_update().syntax(),
440 );
441 }
442 }
443 }
444 hir::PathResolution::Def(def) if def.as_assoc_item(self.source_scope.db).is_none() => {
445 if let hir::ModuleDef::Trait(_) = def
446 && matches!(path.segment()?.kind()?, ast::PathSegmentKind::Type { .. })
447 {
448 return None;
453 }
454
455 let cfg = ImportPathConfig {
456 prefer_no_std: false,
457 prefer_prelude: true,
458 prefer_absolute: false,
459 allow_unstable: true,
460 };
461 let found_path = self.target_module.find_path(self.source_scope.db, def, cfg)?;
462 let res = mod_path_to_ast(&found_path, self.target_edition).clone_for_update();
463 let mut res_editor = SyntaxEditor::new(res.syntax().clone_subtree());
464 if let Some(args) = path.segment().and_then(|it| it.generic_arg_list())
465 && let Some(segment) = res.segment()
466 {
467 if let Some(old) = segment.generic_arg_list() {
468 res_editor
469 .replace(old.syntax(), args.clone_subtree().syntax().clone_for_update())
470 } else {
471 res_editor.insert(
472 syntax_editor::Position::last_child_of(segment.syntax()),
473 args.clone_subtree().syntax().clone_for_update(),
474 );
475 }
476 }
477 let res = res_editor.finish().new_root().clone();
478 editor.replace(path.syntax().clone(), res);
479 }
480 hir::PathResolution::ConstParam(cp) => {
481 if let Some(subst) = self.const_substs.get(&cp) {
482 editor.replace(path.syntax(), subst.clone_subtree().clone_for_update());
483 }
484 }
485 hir::PathResolution::SelfType(imp) => {
486 if self.same_self_type {
488 return None;
489 }
490
491 let ty = imp.self_ty(self.source_scope.db);
492 let ty_str = &ty
493 .display_source_code(
494 self.source_scope.db,
495 self.source_scope.module().into(),
496 true,
497 )
498 .ok()?;
499 let ast_ty = make::ty(ty_str).clone_for_update();
500
501 if let Some(adt) = ty.as_adt()
502 && let ast::Type::PathType(path_ty) = &ast_ty
503 {
504 let cfg = ImportPathConfig {
505 prefer_no_std: false,
506 prefer_prelude: true,
507 prefer_absolute: false,
508 allow_unstable: true,
509 };
510 let found_path = self.target_module.find_path(
511 self.source_scope.db,
512 ModuleDef::from(adt),
513 cfg,
514 )?;
515
516 if let Some(qual) =
517 mod_path_to_ast(&found_path, self.target_edition).qualifier()
518 {
519 let res = make::path_concat(qual, path_ty.path()?).clone_for_update();
520 editor.replace(path.syntax(), res.syntax());
521 return Some(());
522 }
523 }
524
525 editor.replace(path.syntax(), ast_ty.syntax());
526 }
527 hir::PathResolution::Local(_)
528 | hir::PathResolution::Def(_)
529 | hir::PathResolution::BuiltinAttr(_)
530 | hir::PathResolution::ToolModule(_)
531 | hir::PathResolution::DeriveHelper(_) => (),
532 }
533 Some(())
534 }
535
536 fn transform_ident_pat(
537 &self,
538 editor: &mut SyntaxEditor,
539 ident_pat: &ast::IdentPat,
540 ) -> Option<()> {
541 let name = ident_pat.name()?;
542
543 let temp_path = make::path_from_text(&name.text());
544
545 let resolution = self.source_scope.speculative_resolve(&temp_path)?;
546
547 match resolution {
548 hir::PathResolution::Def(def) if def.as_assoc_item(self.source_scope.db).is_none() => {
549 let cfg = ImportPathConfig {
550 prefer_no_std: false,
551 prefer_prelude: true,
552 prefer_absolute: false,
553 allow_unstable: true,
554 };
555 let found_path = self.target_module.find_path(self.source_scope.db, def, cfg)?;
556 let res = mod_path_to_ast(&found_path, self.target_edition).clone_for_update();
557 editor.replace(ident_pat.syntax(), res.syntax());
558 Some(())
559 }
560 _ => None,
561 }
562 }
563}
564
565fn get_syntactic_substs(impl_def: ast::Impl) -> Option<AstSubsts> {
568 let target_trait = impl_def.trait_()?;
569 let path_type = match target_trait {
570 ast::Type::PathType(path) => path,
571 _ => return None,
572 };
573 let generic_arg_list = path_type.path()?.segment()?.generic_arg_list()?;
574
575 get_type_args_from_arg_list(generic_arg_list)
576}
577
578fn get_type_args_from_arg_list(generic_arg_list: ast::GenericArgList) -> Option<AstSubsts> {
579 let mut result = AstSubsts::default();
580 generic_arg_list.generic_args().for_each(|generic_arg| match generic_arg {
581 ast::GenericArg::TypeArg(type_arg) => {
585 result.types_and_consts.push(TypeOrConst::Either(type_arg))
586 }
587 ast::GenericArg::ConstArg(const_arg) => {
589 result.types_and_consts.push(TypeOrConst::Const(const_arg));
590 }
591 ast::GenericArg::LifetimeArg(l_arg) => result.lifetimes.push(l_arg),
592 _ => (),
593 });
594
595 Some(result)
596}
597
598fn find_trait_for_assoc_item(
599 scope: &SemanticsScope<'_>,
600 type_param: hir::TypeParam,
601 assoc_item: ast::NameRef,
602) -> Option<hir::Trait> {
603 let db = scope.db;
604 let trait_bounds = type_param.trait_bounds(db);
605
606 let assoc_item_name = assoc_item.text();
607
608 for trait_ in trait_bounds {
609 let names = trait_.items(db).into_iter().filter_map(|item| match item {
610 hir::AssocItem::TypeAlias(ta) => Some(ta.name(db)),
611 hir::AssocItem::Const(cst) => cst.name(db),
612 _ => None,
613 });
614
615 for name in names {
616 if assoc_item_name.as_str() == name.as_str() {
617 return Some(trait_);
622 }
623 }
624 }
625
626 None
627}