1use crate::helpers::mod_path_to_ast;
4use either::Either;
5use hir::{
6 AsAssocItem, FindPathConfig, HirDisplay, HirFileId, ModuleDef, SemanticsScope,
7 prettify_macro_expansion,
8};
9use itertools::Itertools;
10use rustc_hash::FxHashMap;
11use span::Edition;
12use syntax::{
13 NodeOrToken, SyntaxNode,
14 ast::{self, AstNode, HasGenericArgs, HasName, make},
15 syntax_editor::{self, SyntaxEditor},
16};
17
18#[derive(Default, Debug)]
19struct AstSubsts {
20 types_and_consts: Vec<TypeOrConst>,
21 lifetimes: Vec<ast::LifetimeArg>,
22}
23
24#[derive(Debug)]
25enum TypeOrConst {
26 Either(ast::TypeArg), Const(ast::ConstArg),
28}
29
30type LifetimeName = String;
31type DefaultedParam = Either<hir::TypeParam, hir::ConstParam>;
32
33pub struct PathTransform<'a> {
57 generic_def: Option<hir::GenericDef>,
58 substs: AstSubsts,
59 target_scope: &'a SemanticsScope<'a>,
60 source_scope: &'a SemanticsScope<'a>,
61}
62
63impl<'a> PathTransform<'a> {
64 pub fn trait_impl(
65 target_scope: &'a SemanticsScope<'a>,
66 source_scope: &'a SemanticsScope<'a>,
67 trait_: hir::Trait,
68 impl_: ast::Impl,
69 ) -> PathTransform<'a> {
70 PathTransform {
71 source_scope,
72 target_scope,
73 generic_def: Some(trait_.into()),
74 substs: get_syntactic_substs(impl_).unwrap_or_default(),
75 }
76 }
77
78 pub fn function_call(
79 target_scope: &'a SemanticsScope<'a>,
80 source_scope: &'a SemanticsScope<'a>,
81 function: hir::Function,
82 generic_arg_list: ast::GenericArgList,
83 ) -> PathTransform<'a> {
84 PathTransform {
85 source_scope,
86 target_scope,
87 generic_def: Some(function.into()),
88 substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
89 }
90 }
91
92 pub fn impl_transformation(
93 target_scope: &'a SemanticsScope<'a>,
94 source_scope: &'a SemanticsScope<'a>,
95 impl_: hir::Impl,
96 generic_arg_list: ast::GenericArgList,
97 ) -> PathTransform<'a> {
98 PathTransform {
99 source_scope,
100 target_scope,
101 generic_def: Some(impl_.into()),
102 substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
103 }
104 }
105
106 pub fn adt_transformation(
107 target_scope: &'a SemanticsScope<'a>,
108 source_scope: &'a SemanticsScope<'a>,
109 adt: hir::Adt,
110 generic_arg_list: ast::GenericArgList,
111 ) -> PathTransform<'a> {
112 PathTransform {
113 source_scope,
114 target_scope,
115 generic_def: Some(adt.into()),
116 substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
117 }
118 }
119
120 pub fn generic_transformation(
121 target_scope: &'a SemanticsScope<'a>,
122 source_scope: &'a SemanticsScope<'a>,
123 ) -> PathTransform<'a> {
124 PathTransform {
125 source_scope,
126 target_scope,
127 generic_def: None,
128 substs: AstSubsts::default(),
129 }
130 }
131
132 #[must_use]
133 pub fn apply(&self, syntax: &SyntaxNode) -> SyntaxNode {
134 self.build_ctx().apply(syntax)
135 }
136
137 #[must_use]
138 pub fn apply_all<'b>(
139 &self,
140 nodes: impl IntoIterator<Item = &'b SyntaxNode>,
141 ) -> Vec<SyntaxNode> {
142 let ctx = self.build_ctx();
143 nodes.into_iter().map(|node| ctx.apply(&node.clone())).collect()
144 }
145
146 fn prettify_target_node(&self, node: SyntaxNode) -> SyntaxNode {
147 match self.target_scope.file_id() {
148 HirFileId::FileId(_) => node,
149 HirFileId::MacroFile(file_id) => {
150 let db = self.target_scope.db;
151 prettify_macro_expansion(
152 db,
153 node,
154 &db.expansion_span_map(file_id),
155 self.target_scope.module().krate(db).into(),
156 )
157 }
158 }
159 }
160
161 fn prettify_target_ast<N: AstNode>(&self, node: N) -> N {
162 N::cast(self.prettify_target_node(node.syntax().clone())).unwrap()
163 }
164
165 fn build_ctx(&self) -> Ctx<'a> {
166 let db = self.source_scope.db;
167 let target_module = self.target_scope.module();
168 let source_module = self.source_scope.module();
169 let skip = match self.generic_def {
170 Some(hir::GenericDef::Trait(_)) => 1,
172 _ => 0,
173 };
174 let mut type_substs: FxHashMap<hir::TypeParam, ast::Type> = Default::default();
175 let mut const_substs: FxHashMap<hir::ConstParam, SyntaxNode> = Default::default();
176 let mut defaulted_params: Vec<DefaultedParam> = Default::default();
177 let target_edition = target_module.krate(db).edition(self.source_scope.db);
178 self.generic_def
179 .into_iter()
180 .flat_map(|it| it.type_or_const_params(db))
181 .skip(skip)
182 .zip(self.substs.types_and_consts.iter().map(Some).chain(std::iter::repeat(None)))
189 .for_each(|(k, v)| match (k.split(db), v) {
190 (Either::Right(k), Some(TypeOrConst::Either(v))) => {
191 if let Some(ty) = v.ty() {
192 type_substs.insert(k, self.prettify_target_ast(ty));
193 }
194 }
195 (Either::Right(k), None) => {
196 if let Some(default) = k.default(db)
197 && let Some(default) =
198 &default.display_source_code(db, source_module.into(), false).ok()
199 {
200 type_substs.insert(k, make::ty(default));
201 defaulted_params.push(Either::Left(k));
202 }
203 }
204 (Either::Left(k), Some(TypeOrConst::Either(v))) => {
205 if let Some(ty) = v.ty() {
206 const_substs.insert(k, self.prettify_target_node(ty.syntax().clone()));
207 }
208 }
209 (Either::Left(k), Some(TypeOrConst::Const(v))) => {
210 if let Some(expr) = v.expr() {
211 const_substs.insert(k, self.prettify_target_node(expr.syntax().clone()));
218 }
219 }
220 (Either::Left(k), None) => {
221 if let Some(default) =
222 k.default(db, target_module.krate(db).to_display_target(db))
223 && let Some(default) = default.expr()
224 {
225 const_substs.insert(k, default.syntax().clone());
226 defaulted_params.push(Either::Right(k));
227 }
228 }
229 _ => (), });
231 let lifetime_substs: FxHashMap<_, _> = self
233 .generic_def
234 .into_iter()
235 .flat_map(|it| it.lifetime_params(db))
236 .zip(self.substs.lifetimes.clone())
237 .filter_map(|(k, v)| {
238 Some((k.name(db).display(db, target_edition).to_string(), v.lifetime()?))
239 })
240 .collect();
241 let mut ctx = Ctx {
242 type_substs,
243 const_substs,
244 lifetime_substs,
245 target_module,
246 source_scope: self.source_scope,
247 same_self_type: self.target_scope.has_same_self_type(self.source_scope),
248 target_edition,
249 };
250 ctx.transform_default_values(defaulted_params);
251 ctx
252 }
253}
254
255struct Ctx<'a> {
256 type_substs: FxHashMap<hir::TypeParam, ast::Type>,
257 const_substs: FxHashMap<hir::ConstParam, SyntaxNode>,
258 lifetime_substs: FxHashMap<LifetimeName, ast::Lifetime>,
259 target_module: hir::Module,
260 source_scope: &'a SemanticsScope<'a>,
261 same_self_type: bool,
262 target_edition: Edition,
263}
264
265fn preorder_rev(item: &SyntaxNode) -> impl Iterator<Item = SyntaxNode> {
266 let x = item
267 .preorder()
268 .filter_map(|event| match event {
269 syntax::WalkEvent::Enter(node) => Some(node),
270 syntax::WalkEvent::Leave(_) => None,
271 })
272 .collect_vec();
273 x.into_iter().rev()
274}
275
276impl Ctx<'_> {
277 fn apply(&self, item: &SyntaxNode) -> SyntaxNode {
278 let (editor, item) = SyntaxEditor::new(self.transform_path(item));
282 preorder_rev(&item).filter_map(ast::Lifetime::cast).for_each(|lifetime| {
283 if let Some(subst) = self.lifetime_substs.get(&lifetime.syntax().text().to_string()) {
284 editor.replace(lifetime.syntax(), subst.clone().syntax());
285 }
286 });
287
288 editor.finish().new_root().clone()
289 }
290
291 fn transform_default_values(&mut self, defaulted_params: Vec<DefaultedParam>) {
292 for param in defaulted_params {
296 let value = match ¶m {
297 Either::Left(k) => self.type_substs.get(k).unwrap().syntax(),
298 Either::Right(k) => self.const_substs.get(k).unwrap(),
299 };
300 let new_value = self.transform_path(value);
304 match param {
305 Either::Left(k) => {
306 self.type_substs.insert(k, ast::Type::cast(new_value.clone()).unwrap());
307 }
308 Either::Right(k) => {
309 self.const_substs.insert(k, new_value.clone());
310 }
311 }
312 }
313 }
314
315 fn transform_path(&self, path: &SyntaxNode) -> SyntaxNode {
316 fn find_child_paths_and_ident_pats(
317 root_path: &SyntaxNode,
318 ) -> Vec<Either<ast::Path, ast::IdentPat>> {
319 let mut result: Vec<Either<ast::Path, ast::IdentPat>> = Vec::new();
320 for child in root_path.children() {
321 if let Some(child_path) = ast::Path::cast(child.clone()) {
322 result.push(either::Left(child_path));
323 } else if let Some(child_ident_pat) = ast::IdentPat::cast(child.clone()) {
324 result.push(either::Right(child_ident_pat));
325 } else {
326 result.extend(find_child_paths_and_ident_pats(&child));
327 }
328 }
329 result
330 }
331
332 let (editor, root_path) = SyntaxEditor::new(path.clone());
333 let result = find_child_paths_and_ident_pats(&root_path);
334 for sub_path in result {
335 let new = self.transform_path(sub_path.syntax());
336 editor.replace(sub_path.syntax(), new);
337 }
338 let (editor, update_sub_item) = SyntaxEditor::new(editor.finish().new_root().clone());
339 let item = find_child_paths_and_ident_pats(&update_sub_item);
340 for sub_path in item {
341 self.transform_path_or_ident_pat(&editor, &sub_path);
342 }
343 editor.finish().new_root().clone()
344 }
345 fn transform_path_or_ident_pat(
346 &self,
347 editor: &SyntaxEditor,
348 item: &Either<ast::Path, ast::IdentPat>,
349 ) -> Option<()> {
350 match item {
351 Either::Left(path) => self.transform_path_(editor, path),
352 Either::Right(ident_pat) => self.transform_ident_pat(editor, ident_pat),
353 }
354 }
355
356 fn transform_path_(&self, editor: &SyntaxEditor, path: &ast::Path) -> Option<()> {
357 if path.qualifier().is_some() {
358 return None;
359 }
360 if path.segment().is_some_and(|s| {
361 s.parenthesized_arg_list().is_some()
362 || (s.self_token().is_some() && path.parent_path().is_none())
363 }) {
364 return None;
367 }
368 let resolution = self.source_scope.speculative_resolve(path)?;
369
370 match resolution {
371 hir::PathResolution::TypeParam(tp) => {
372 if let Some(subst) = self.type_substs.get(&tp) {
373 let parent = path.syntax().parent()?;
374 if let Some(parent) = ast::Path::cast(parent.clone()) {
375 let trait_ref = find_trait_for_assoc_item(
384 self.source_scope,
385 tp,
386 parent.segment()?.name_ref()?,
387 )
388 .and_then(|trait_ref| {
389 let cfg = FindPathConfig {
390 prefer_no_std: false,
391 prefer_prelude: true,
392 prefer_absolute: false,
393 allow_unstable: true,
394 };
395 let found_path = self.target_module.find_path(
396 self.source_scope.db,
397 hir::ModuleDef::Trait(trait_ref),
398 cfg,
399 )?;
400 match make::ty_path(mod_path_to_ast(&found_path, self.target_edition)) {
401 ast::Type::PathType(path_ty) => Some(path_ty),
402 _ => None,
403 }
404 });
405
406 let segment = make::path_segment_ty(subst.clone(), trait_ref);
407 let qualified = make::path_from_segments(std::iter::once(segment), false);
408 editor.replace(path.syntax(), qualified.clone().syntax());
409 } else if let Some(path_ty) = ast::PathType::cast(parent) {
410 let old = path_ty.syntax();
411
412 if old.parent().is_some() {
413 editor.replace(old, subst.clone().syntax());
414 } else {
415 let start = path_ty.syntax().first_child().map(NodeOrToken::Node)?;
416 let end = path_ty.syntax().last_child().map(NodeOrToken::Node)?;
417 editor.replace_all(
418 start..=end,
419 subst
420 .clone()
421 .syntax()
422 .children()
423 .map(NodeOrToken::Node)
424 .collect::<Vec<_>>(),
425 );
426 }
427 } else {
428 editor.replace(path.syntax(), subst.clone().syntax());
429 }
430 }
431 }
432 hir::PathResolution::Def(def) if def.as_assoc_item(self.source_scope.db).is_none() => {
433 if let hir::ModuleDef::Trait(_) = def
434 && matches!(path.segment()?.kind()?, ast::PathSegmentKind::Type { .. })
435 {
436 return None;
441 }
442
443 let cfg = FindPathConfig {
444 prefer_no_std: false,
445 prefer_prelude: true,
446 prefer_absolute: false,
447 allow_unstable: true,
448 };
449 let found_path = self.target_module.find_path(self.source_scope.db, def, cfg)?;
450 let res = mod_path_to_ast(&found_path, self.target_edition);
451 let (res_editor, res) = SyntaxEditor::with_ast_node(&res);
452 if let Some(args) = path.segment().and_then(|it| it.generic_arg_list())
453 && let Some(segment) = res.segment()
454 {
455 if let Some(old) = segment.generic_arg_list() {
456 res_editor.replace(old.syntax(), args.syntax().clone())
457 } else {
458 res_editor.insert(
459 syntax_editor::Position::last_child_of(segment.syntax()),
460 args.syntax().clone(),
461 );
462 }
463 }
464 let res = res_editor.finish().new_root().clone();
465 editor.replace(path.syntax().clone(), res);
466 }
467 hir::PathResolution::ConstParam(cp) => {
468 if let Some(subst) = self.const_substs.get(&cp) {
469 editor.replace(path.syntax(), subst.clone());
470 }
471 }
472 hir::PathResolution::SelfType(imp) => {
473 if self.same_self_type {
475 return None;
476 }
477
478 let ty = imp.self_ty(self.source_scope.db);
479 let ty_str = &ty
480 .display_source_code(
481 self.source_scope.db,
482 self.source_scope.module().into(),
483 true,
484 )
485 .ok()?;
486 let ast_ty = make::ty(ty_str);
487
488 if let Some(adt) = ty.as_adt()
489 && let ast::Type::PathType(path_ty) = &ast_ty
490 {
491 let cfg = FindPathConfig {
492 prefer_no_std: false,
493 prefer_prelude: true,
494 prefer_absolute: false,
495 allow_unstable: true,
496 };
497 let found_path = self.target_module.find_path(
498 self.source_scope.db,
499 ModuleDef::from(adt),
500 cfg,
501 )?;
502
503 if let Some(qual) =
504 mod_path_to_ast(&found_path, self.target_edition).qualifier()
505 {
506 editor.replace(
507 path.syntax(),
508 make::path_concat(qual, path_ty.path()?).syntax(),
509 );
510 return Some(());
511 }
512 }
513
514 editor.replace(path.syntax(), ast_ty.syntax());
515 }
516 hir::PathResolution::Local(_)
517 | hir::PathResolution::Def(_)
518 | hir::PathResolution::BuiltinAttr(_)
519 | hir::PathResolution::ToolModule(_)
520 | hir::PathResolution::DeriveHelper(_) => (),
521 }
522 Some(())
523 }
524
525 fn transform_ident_pat(&self, editor: &SyntaxEditor, ident_pat: &ast::IdentPat) -> Option<()> {
526 let name = ident_pat.name()?;
527
528 let temp_path = make::path_from_text(&name.text());
529
530 let resolution = self.source_scope.speculative_resolve(&temp_path)?;
531
532 match resolution {
533 hir::PathResolution::Def(def) if def.as_assoc_item(self.source_scope.db).is_none() => {
534 if matches!(def, hir::ModuleDef::Macro(_)) {
538 return None;
539 }
540
541 if matches!(def, hir::ModuleDef::Module(_)) {
543 return None;
544 }
545
546 if matches!(
547 def,
548 hir::ModuleDef::Function(_)
549 | hir::ModuleDef::Trait(_)
550 | hir::ModuleDef::TypeAlias(_)
551 ) {
552 return None;
553 }
554
555 if let hir::ModuleDef::Adt(adt) = def {
556 match adt {
557 hir::Adt::Struct(s)
558 if s.kind(self.source_scope.db) != hir::StructKind::Unit =>
559 {
560 return None;
561 }
562 hir::Adt::Union(_) => return None,
563 hir::Adt::Enum(_) => return None,
564 _ => (),
565 }
566 }
567
568 if let hir::ModuleDef::EnumVariant(v) = def
569 && v.kind(self.source_scope.db) != hir::StructKind::Unit
570 {
571 return None;
572 }
573
574 let cfg = FindPathConfig {
575 prefer_no_std: false,
576 prefer_prelude: true,
577 prefer_absolute: false,
578 allow_unstable: true,
579 };
580 let found_path = self.target_module.find_path(self.source_scope.db, def, cfg)?;
581 editor.replace(
582 ident_pat.syntax(),
583 mod_path_to_ast(&found_path, self.target_edition).syntax(),
584 );
585 Some(())
586 }
587 _ => None,
588 }
589 }
590}
591
592fn get_syntactic_substs(impl_def: ast::Impl) -> Option<AstSubsts> {
595 let target_trait = impl_def.trait_()?;
596 let path_type = match target_trait {
597 ast::Type::PathType(path) => path,
598 _ => return None,
599 };
600 let generic_arg_list = path_type.path()?.segment()?.generic_arg_list()?;
601
602 get_type_args_from_arg_list(generic_arg_list)
603}
604
605fn get_type_args_from_arg_list(generic_arg_list: ast::GenericArgList) -> Option<AstSubsts> {
606 let mut result = AstSubsts::default();
607 generic_arg_list.generic_args().for_each(|generic_arg| match generic_arg {
608 ast::GenericArg::TypeArg(type_arg) => {
612 result.types_and_consts.push(TypeOrConst::Either(type_arg))
613 }
614 ast::GenericArg::ConstArg(const_arg) => {
616 result.types_and_consts.push(TypeOrConst::Const(const_arg));
617 }
618 ast::GenericArg::LifetimeArg(l_arg) => result.lifetimes.push(l_arg),
619 _ => (),
620 });
621
622 Some(result)
623}
624
625fn find_trait_for_assoc_item(
626 scope: &SemanticsScope<'_>,
627 type_param: hir::TypeParam,
628 assoc_item: ast::NameRef,
629) -> Option<hir::Trait> {
630 let db = scope.db;
631 let trait_bounds = type_param.trait_bounds(db);
632
633 let assoc_item_name = assoc_item.text();
634
635 for trait_ in trait_bounds {
636 let names = trait_.items(db).into_iter().filter_map(|item| match item {
637 hir::AssocItem::TypeAlias(ta) => Some(ta.name(db)),
638 hir::AssocItem::Const(cst) => cst.name(db),
639 _ => None,
640 });
641
642 for name in names {
643 if assoc_item_name.as_str() == name.as_str() {
644 return Some(trait_);
649 }
650 }
651 }
652
653 None
654}
655
#[cfg(test)]
mod tests {
    use crate::RootDatabase;
    use crate::path_transform::PathTransform;
    use hir::Semantics;
    use syntax::{AstNode, ast::HasName};
    use test_fixture::WithFixture;
    use test_utils::assert_eq_text;

    /// Identifier patterns naming a unit struct, a unit variant, a const or a static are
    /// qualified for the target module. Record structs, enums, record variants, functions,
    /// type aliases and unions stay plain bindings.
    #[test]
    fn test_transform_ident_pat() {
        let (db, file_id) = RootDatabase::with_single_file(
            r#"
mod foo {
    pub struct UnitStruct;
    pub struct RecordStruct {}
    pub enum Enum { UnitVariant, RecordVariant {} }
    pub fn function() {}
    pub const CONST: i32 = 0;
    pub static STATIC: i32 = 0;
    pub type Alias = i32;
    pub union Union { f: i32 }
}

mod bar {
    fn anchor() {}
}

fn main() {
    use foo::*;
    use foo::Enum::*;
    let UnitStruct = ();
    let RecordStruct = ();
    let Enum = ();
    let UnitVariant = ();
    let RecordVariant = ();
    let function = ();
    let CONST = ();
    let STATIC = ();
    let Alias = ();
    let Union = ();
}
"#,
        );
        let sema = Semantics::new(&db);
        let source_file = sema.parse(file_id);
        let fn_named = |wanted: &str| {
            source_file
                .syntax()
                .descendants()
                .filter_map(syntax::ast::Fn::cast)
                .find(|it| it.name().unwrap().text() == wanted)
                .unwrap()
        };

        let main_fn = fn_named("main");
        let source_scope = sema.scope(main_fn.body().unwrap().syntax()).unwrap();

        let anchor_fn = fn_named("anchor");
        let target_scope = sema.scope(anchor_fn.body().unwrap().syntax()).unwrap();

        let transformed = PathTransform::generic_transformation(&target_scope, &source_scope)
            .apply(main_fn.body().unwrap().syntax());

        let expected = r#"{
    use crate::foo::*;
    use crate::foo::Enum::*;
    let crate::foo::UnitStruct = ();
    let RecordStruct = ();
    let Enum = ();
    let crate::foo::Enum::UnitVariant = ();
    let RecordVariant = ();
    let function = ();
    let crate::foo::CONST = ();
    let crate::foo::STATIC = ();
    let Alias = ();
    let Union = ();
}"#;
        assert_eq_text!(expected, &transformed.to_string());
    }
}