1use crate::helpers::mod_path_to_ast_with_factory;
4use either::Either;
5use hir::{
6 AsAssocItem, FindPathConfig, HirDisplay, HirFileId, ModuleDef, SemanticsScope,
7 prettify_macro_expansion,
8};
9use itertools::Itertools;
10use rustc_hash::FxHashMap;
11use span::Edition;
12use syntax::{
13 NodeOrToken, SyntaxNode,
14 ast::{self, AstNode, HasGenericArgs, HasName, make},
15 syntax_editor::{self, SyntaxEditor},
16};
17
/// Generic arguments collected syntactically from a generic argument list.
#[derive(Default, Debug)]
struct AstSubsts {
    /// Type and const arguments, in the order they were written.
    types_and_consts: Vec<TypeOrConst>,
    /// Lifetime arguments, in the order they were written.
    lifetimes: Vec<ast::LifetimeArg>,
}
23
/// A single non-lifetime generic argument as it was parsed.
#[derive(Debug)]
enum TypeOrConst {
    /// Parsed as a type argument. `build_ctx` accepts this variant for both
    /// type parameters and const parameters.
    Either(ast::TypeArg),
    /// Parsed as a const argument.
    Const(ast::ConstArg),
}
29
/// Textual name of a lifetime; used as the key of `Ctx::lifetime_substs`.
type LifetimeName = String;
/// A generic parameter whose substitution was taken from its default:
/// `Left` is a type parameter, `Right` is a const parameter.
type DefaultedParam = Either<hir::TypeParam, hir::ConstParam>;
32
/// Rewrites syntax that was written in one scope so that it can be placed in
/// another: generic parameters are replaced by the recorded substitutions and
/// resolved paths are re-qualified relative to the target module.
pub struct PathTransform<'a> {
    /// Definition whose generic parameters are substituted. `None` means no
    /// parameters are substituted and only paths are re-qualified.
    generic_def: Option<hir::GenericDef>,
    /// Syntactic generic arguments paired with the parameters of `generic_def`.
    substs: AstSubsts,
    /// Scope the transformed syntax is destined for; paths are searched from
    /// its module.
    target_scope: &'a SemanticsScope<'a>,
    /// Scope the syntax was written in; paths are resolved against it.
    source_scope: &'a SemanticsScope<'a>,
}
62
63impl<'a> PathTransform<'a> {
64 pub fn trait_impl(
65 target_scope: &'a SemanticsScope<'a>,
66 source_scope: &'a SemanticsScope<'a>,
67 trait_: hir::Trait,
68 impl_: ast::Impl,
69 ) -> PathTransform<'a> {
70 PathTransform {
71 source_scope,
72 target_scope,
73 generic_def: Some(trait_.into()),
74 substs: get_syntactic_substs(impl_).unwrap_or_default(),
75 }
76 }
77
78 pub fn function_call(
79 target_scope: &'a SemanticsScope<'a>,
80 source_scope: &'a SemanticsScope<'a>,
81 function: hir::Function,
82 generic_arg_list: ast::GenericArgList,
83 ) -> PathTransform<'a> {
84 PathTransform {
85 source_scope,
86 target_scope,
87 generic_def: Some(function.into()),
88 substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
89 }
90 }
91
92 pub fn impl_transformation(
93 target_scope: &'a SemanticsScope<'a>,
94 source_scope: &'a SemanticsScope<'a>,
95 impl_: hir::Impl,
96 generic_arg_list: ast::GenericArgList,
97 ) -> PathTransform<'a> {
98 PathTransform {
99 source_scope,
100 target_scope,
101 generic_def: Some(impl_.into()),
102 substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
103 }
104 }
105
106 pub fn adt_transformation(
107 target_scope: &'a SemanticsScope<'a>,
108 source_scope: &'a SemanticsScope<'a>,
109 adt: hir::Adt,
110 generic_arg_list: ast::GenericArgList,
111 ) -> PathTransform<'a> {
112 PathTransform {
113 source_scope,
114 target_scope,
115 generic_def: Some(adt.into()),
116 substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
117 }
118 }
119
120 pub fn generic_transformation(
121 target_scope: &'a SemanticsScope<'a>,
122 source_scope: &'a SemanticsScope<'a>,
123 ) -> PathTransform<'a> {
124 PathTransform {
125 source_scope,
126 target_scope,
127 generic_def: None,
128 substs: AstSubsts::default(),
129 }
130 }
131
132 #[must_use]
133 pub fn apply(&self, syntax: &SyntaxNode) -> SyntaxNode {
134 self.build_ctx().apply(syntax)
135 }
136
137 #[must_use]
138 pub fn apply_all<'b>(
139 &self,
140 nodes: impl IntoIterator<Item = &'b SyntaxNode>,
141 ) -> Vec<SyntaxNode> {
142 let ctx = self.build_ctx();
143 nodes.into_iter().map(|node| ctx.apply(&node.clone())).collect()
144 }
145
146 fn prettify_target_node(&self, node: SyntaxNode) -> SyntaxNode {
147 match self.target_scope.file_id() {
148 HirFileId::FileId(_) => node,
149 HirFileId::MacroFile(file_id) => {
150 let db = self.target_scope.db;
151 prettify_macro_expansion(
152 db,
153 node,
154 db.expansion_span_map(file_id),
155 self.target_scope.module().krate(db).into(),
156 )
157 }
158 }
159 }
160
161 fn prettify_target_ast<N: AstNode>(&self, node: N) -> N {
162 N::cast(self.prettify_target_node(node.syntax().clone())).unwrap()
163 }
164
    /// Builds the substitution context used by `apply` and `apply_all`.
    ///
    /// Pairs the type, const and lifetime parameters of `generic_def` (if any)
    /// with the syntactic arguments stored in `substs`. Type and const
    /// parameters that have no written argument fall back to their default
    /// when one is available, and those defaults are then transformed too.
    fn build_ctx(&self) -> Ctx<'a> {
        let db = self.source_scope.db;
        let target_module = self.target_scope.module();
        let source_module = self.source_scope.module();
        // For a trait, the first type-or-const parameter is not paired with a
        // written argument.
        // NOTE(review): presumably that parameter is the implicit `Self` — confirm.
        let skip = match self.generic_def {
            Some(hir::GenericDef::Trait(_)) => 1,
            _ => 0,
        };
        let mut type_substs: FxHashMap<hir::TypeParam, ast::Type> = Default::default();
        let mut const_substs: FxHashMap<hir::ConstParam, SyntaxNode> = Default::default();
        let mut defaulted_params: Vec<DefaultedParam> = Default::default();
        let target_edition = target_module.krate(db).edition(self.source_scope.db);
        self.generic_def
            .into_iter()
            .flat_map(|it| it.type_or_const_params(db))
            .skip(skip)
            // The written argument list may be shorter than the parameter list.
            // Padding it with an endless run of `None` keeps the trailing
            // parameters in the zip so their defaults can be looked up below.
            .zip(self.substs.types_and_consts.iter().map(Some).chain(std::iter::repeat(None)))
            .for_each(|(k, v)| match (k.split(db), v) {
                // Type parameter with a written type argument.
                (Either::Right(k), Some(TypeOrConst::Either(v))) => {
                    if let Some(ty) = v.ty() {
                        type_substs.insert(k, self.prettify_target_ast(ty));
                    }
                }
                // Type parameter without an argument: use its default, rendered
                // as source code relative to the source module.
                (Either::Right(k), None) => {
                    if let Some(default) = k.default(db)
                        && let Some(default) =
                            &default.display_source_code(db, source_module.into(), false).ok()
                    {
                        type_substs.insert(k, make::ty(default));
                        defaulted_params.push(Either::Left(k));
                    }
                }
                // Const parameter whose argument was parsed as a type.
                (Either::Left(k), Some(TypeOrConst::Either(v))) => {
                    if let Some(ty) = v.ty() {
                        const_substs.insert(k, self.prettify_target_node(ty.syntax().clone()));
                    }
                }
                // Const parameter whose argument was parsed as a const.
                (Either::Left(k), Some(TypeOrConst::Const(v))) => {
                    if let Some(expr) = v.expr() {
                        const_substs.insert(k, self.prettify_target_node(expr.syntax().clone()));
                    }
                }
                // Const parameter without an argument: use the expression of its
                // default, when it has one.
                (Either::Left(k), None) => {
                    if let Some(default) = k.default_source_code(db, target_module)
                        && let Some(default) = default.expr()
                    {
                        const_substs.insert(k, default.syntax().clone());
                        defaulted_params.push(Either::Right(k));
                    }
                }
                // Remaining combination (type parameter paired with a const
                // argument): no substitution is recorded.
                _ => (),
            });
        // Lifetimes are keyed by their displayed name; pairs whose argument has
        // no lifetime node are dropped. No default handling happens here.
        let lifetime_substs: FxHashMap<_, _> = self
            .generic_def
            .into_iter()
            .flat_map(|it| it.lifetime_params(db))
            .zip(self.substs.lifetimes.clone())
            .filter_map(|(k, v)| {
                Some((k.name(db).display(db, target_edition).to_string(), v.lifetime()?))
            })
            .collect();
        let mut ctx = Ctx {
            type_substs,
            const_substs,
            lifetime_substs,
            target_module,
            source_scope: self.source_scope,
            same_self_type: self.target_scope.has_same_self_type(self.source_scope),
            target_edition,
        };
        // Defaults were inserted as written at their declaration; run them
        // through the transform as well before handing the context out.
        ctx.transform_default_values(defaulted_params);
        ctx
    }
252}
253
/// Resolved substitution tables plus the scopes needed to rewrite paths.
struct Ctx<'a> {
    /// Replacement type for each substituted type parameter.
    type_substs: FxHashMap<hir::TypeParam, ast::Type>,
    /// Replacement syntax for each substituted const parameter.
    const_substs: FxHashMap<hir::ConstParam, SyntaxNode>,
    /// Replacement lifetime, keyed by the text of the lifetime it replaces.
    lifetime_substs: FxHashMap<LifetimeName, ast::Lifetime>,
    /// Module from which replacement paths are searched.
    target_module: hir::Module,
    /// Scope against which original paths are resolved.
    source_scope: &'a SemanticsScope<'a>,
    /// When `true`, paths resolving to the `Self` type are left untouched.
    same_self_type: bool,
    /// Edition passed along when module paths are rendered to syntax.
    target_edition: Edition,
}
263
264fn preorder_rev(item: &SyntaxNode) -> impl Iterator<Item = SyntaxNode> {
265 let x = item
266 .preorder()
267 .filter_map(|event| match event {
268 syntax::WalkEvent::Enter(node) => Some(node),
269 syntax::WalkEvent::Leave(_) => None,
270 })
271 .collect_vec();
272 x.into_iter().rev()
273}
274
275impl Ctx<'_> {
276 fn apply(&self, item: &SyntaxNode) -> SyntaxNode {
277 let (editor, item) = SyntaxEditor::new(self.transform_path(item));
281 preorder_rev(&item).filter_map(ast::Lifetime::cast).for_each(|lifetime| {
282 if let Some(subst) = self.lifetime_substs.get(&lifetime.syntax().text().to_string()) {
283 editor.replace(lifetime.syntax(), subst.clone().syntax());
284 }
285 });
286
287 editor.finish().new_root().clone()
288 }
289
290 fn transform_default_values(&mut self, defaulted_params: Vec<DefaultedParam>) {
291 for param in defaulted_params {
295 let value = match ¶m {
296 Either::Left(k) => self.type_substs.get(k).unwrap().syntax(),
297 Either::Right(k) => self.const_substs.get(k).unwrap(),
298 };
299 let new_value = self.transform_path(value);
303 match param {
304 Either::Left(k) => {
305 self.type_substs.insert(k, ast::Type::cast(new_value.clone()).unwrap());
306 }
307 Either::Right(k) => {
308 self.const_substs.insert(k, new_value.clone());
309 }
310 }
311 }
312 }
313
314 fn transform_path(&self, path: &SyntaxNode) -> SyntaxNode {
315 fn find_child_paths_and_ident_pats(
316 root_path: &SyntaxNode,
317 ) -> Vec<Either<ast::Path, ast::IdentPat>> {
318 let mut result: Vec<Either<ast::Path, ast::IdentPat>> = Vec::new();
319 for child in root_path.children() {
320 if let Some(child_path) = ast::Path::cast(child.clone()) {
321 result.push(either::Left(child_path));
322 } else if let Some(child_ident_pat) = ast::IdentPat::cast(child.clone()) {
323 result.push(either::Right(child_ident_pat));
324 } else {
325 result.extend(find_child_paths_and_ident_pats(&child));
326 }
327 }
328 result
329 }
330
331 let (editor, root_path) = SyntaxEditor::new(path.clone());
332 let result = find_child_paths_and_ident_pats(&root_path);
333 for sub_path in result {
334 let new = self.transform_path(sub_path.syntax());
335 editor.replace(sub_path.syntax(), new);
336 }
337 let (editor, update_sub_item) = SyntaxEditor::new(editor.finish().new_root().clone());
338 let item = find_child_paths_and_ident_pats(&update_sub_item);
339 for sub_path in item {
340 self.transform_path_or_ident_pat(&editor, &sub_path);
341 }
342 editor.finish().new_root().clone()
343 }
344 fn transform_path_or_ident_pat(
345 &self,
346 editor: &SyntaxEditor,
347 item: &Either<ast::Path, ast::IdentPat>,
348 ) -> Option<()> {
349 match item {
350 Either::Left(path) => self.transform_path_(editor, path),
351 Either::Right(ident_pat) => self.transform_ident_pat(editor, ident_pat),
352 }
353 }
354
355 fn transform_path_(&self, editor: &SyntaxEditor, path: &ast::Path) -> Option<()> {
356 let make = editor.make();
357 if path.qualifier().is_some() {
358 return None;
359 }
360 if path.segment().is_some_and(|s| {
361 s.parenthesized_arg_list().is_some()
362 || (s.self_token().is_some() && path.parent_path().is_none())
363 }) {
364 return None;
367 }
368 let resolution = self.source_scope.speculative_resolve(path)?;
369
370 match resolution {
371 hir::PathResolution::TypeParam(tp) => {
372 if let Some(subst) = self.type_substs.get(&tp) {
373 let parent = path.syntax().parent()?;
374 if let Some(parent) = ast::Path::cast(parent.clone()) {
375 let trait_ref = find_trait_for_assoc_item(
384 self.source_scope,
385 tp,
386 parent.segment()?.name_ref()?,
387 )
388 .and_then(|trait_ref| {
389 let cfg = FindPathConfig {
390 prefer_no_std: false,
391 prefer_prelude: true,
392 prefer_absolute: false,
393 allow_unstable: true,
394 };
395 let found_path = self.target_module.find_path(
396 self.source_scope.db,
397 hir::ModuleDef::Trait(trait_ref),
398 cfg,
399 )?;
400 match make
401 .ty_path(mod_path_to_ast_with_factory(
402 make,
403 &found_path,
404 self.target_edition,
405 ))
406 .into()
407 {
408 ast::Type::PathType(path_ty) => Some(path_ty),
409 _ => None,
410 }
411 });
412
413 let segment = make::path_segment_ty(subst.clone(), trait_ref);
414 let qualified = make::path_from_segments(std::iter::once(segment), false);
415 editor.replace(path.syntax(), qualified.clone().syntax());
416 } else if let Some(path_ty) = ast::PathType::cast(parent) {
417 let old = path_ty.syntax();
418
419 if old.parent().is_some() {
420 editor.replace(old, subst.clone().syntax());
421 } else {
422 let start = path_ty.syntax().first_child().map(NodeOrToken::Node)?;
423 let end = path_ty.syntax().last_child().map(NodeOrToken::Node)?;
424 editor.replace_all(
425 start..=end,
426 subst
427 .clone()
428 .syntax()
429 .children()
430 .map(NodeOrToken::Node)
431 .collect::<Vec<_>>(),
432 );
433 }
434 } else {
435 editor.replace(path.syntax(), subst.clone().syntax());
436 }
437 }
438 }
439 hir::PathResolution::Def(def) if def.as_assoc_item(self.source_scope.db).is_none() => {
440 if let hir::ModuleDef::Trait(_) = def
441 && matches!(path.segment()?.kind()?, ast::PathSegmentKind::Type { .. })
442 {
443 return None;
448 }
449
450 let cfg = FindPathConfig {
451 prefer_no_std: false,
452 prefer_prelude: true,
453 prefer_absolute: false,
454 allow_unstable: true,
455 };
456 let found_path = self.target_module.find_path(self.source_scope.db, def, cfg)?;
457 let res = mod_path_to_ast_with_factory(make, &found_path, self.target_edition);
458 let (res_editor, res) = SyntaxEditor::with_ast_node(&res);
459 if let Some(args) = path.segment().and_then(|it| it.generic_arg_list())
460 && let Some(segment) = res.segment()
461 {
462 if let Some(old) = segment.generic_arg_list() {
463 res_editor.replace(old.syntax(), args.syntax().clone())
464 } else {
465 res_editor.insert(
466 syntax_editor::Position::last_child_of(segment.syntax()),
467 args.syntax().clone(),
468 );
469 }
470 }
471 let res = res_editor.finish().new_root().clone();
472 editor.replace(path.syntax().clone(), res);
473 }
474 hir::PathResolution::ConstParam(cp) => {
475 if let Some(subst) = self.const_substs.get(&cp) {
476 editor.replace(path.syntax(), subst.clone());
477 }
478 }
479 hir::PathResolution::SelfType(imp) => {
480 if self.same_self_type {
482 return None;
483 }
484
485 let ty = imp.self_ty(self.source_scope.db);
486 let ty_str = &ty
487 .display_source_code(
488 self.source_scope.db,
489 self.source_scope.module().into(),
490 true,
491 )
492 .ok()?;
493 let ast_ty = make::ty(ty_str);
494
495 if let Some(adt) = ty.as_adt()
496 && let ast::Type::PathType(path_ty) = &ast_ty
497 {
498 let cfg = FindPathConfig {
499 prefer_no_std: false,
500 prefer_prelude: true,
501 prefer_absolute: false,
502 allow_unstable: true,
503 };
504 let found_path = self.target_module.find_path(
505 self.source_scope.db,
506 ModuleDef::from(adt),
507 cfg,
508 )?;
509
510 if let Some(qual) =
511 mod_path_to_ast_with_factory(make, &found_path, self.target_edition)
512 .qualifier()
513 {
514 editor.replace(
515 path.syntax(),
516 make::path_concat(qual, path_ty.path()?).syntax(),
517 );
518 return Some(());
519 }
520 }
521
522 editor.replace(path.syntax(), ast_ty.syntax());
523 }
524 hir::PathResolution::Local(_)
525 | hir::PathResolution::Def(_)
526 | hir::PathResolution::BuiltinAttr(_)
527 | hir::PathResolution::ToolModule(_)
528 | hir::PathResolution::DeriveHelper(_) => (),
529 }
530 Some(())
531 }
532
533 fn transform_ident_pat(&self, editor: &SyntaxEditor, ident_pat: &ast::IdentPat) -> Option<()> {
534 let name = ident_pat.name()?;
535 let make = editor.make();
536
537 let temp_path = make.path_from_text(&name.text());
538
539 let resolution = self.source_scope.speculative_resolve(&temp_path)?;
540
541 match resolution {
542 hir::PathResolution::Def(def) if def.as_assoc_item(self.source_scope.db).is_none() => {
543 if matches!(def, hir::ModuleDef::Macro(_)) {
547 return None;
548 }
549
550 if matches!(def, hir::ModuleDef::Module(_)) {
552 return None;
553 }
554
555 if matches!(
556 def,
557 hir::ModuleDef::Function(_)
558 | hir::ModuleDef::Trait(_)
559 | hir::ModuleDef::TypeAlias(_)
560 ) {
561 return None;
562 }
563
564 if let hir::ModuleDef::Adt(adt) = def {
565 match adt {
566 hir::Adt::Struct(s)
567 if s.kind(self.source_scope.db) != hir::StructKind::Unit =>
568 {
569 return None;
570 }
571 hir::Adt::Union(_) => return None,
572 hir::Adt::Enum(_) => return None,
573 _ => (),
574 }
575 }
576
577 if let hir::ModuleDef::EnumVariant(v) = def
578 && v.kind(self.source_scope.db) != hir::StructKind::Unit
579 {
580 return None;
581 }
582
583 let cfg = FindPathConfig {
584 prefer_no_std: false,
585 prefer_prelude: true,
586 prefer_absolute: false,
587 allow_unstable: true,
588 };
589 let found_path = self.target_module.find_path(self.source_scope.db, def, cfg)?;
590 editor.replace(
591 ident_pat.syntax(),
592 mod_path_to_ast_with_factory(make, &found_path, self.target_edition).syntax(),
593 );
594 Some(())
595 }
596 _ => None,
597 }
598 }
599}
600
601fn get_syntactic_substs(impl_def: ast::Impl) -> Option<AstSubsts> {
604 let target_trait = impl_def.trait_()?;
605 let path_type = match target_trait {
606 ast::Type::PathType(path) => path,
607 _ => return None,
608 };
609 let generic_arg_list = path_type.path()?.segment()?.generic_arg_list()?;
610
611 get_type_args_from_arg_list(generic_arg_list)
612}
613
614fn get_type_args_from_arg_list(generic_arg_list: ast::GenericArgList) -> Option<AstSubsts> {
615 let mut result = AstSubsts::default();
616 generic_arg_list.generic_args().for_each(|generic_arg| match generic_arg {
617 ast::GenericArg::TypeArg(type_arg) => {
621 result.types_and_consts.push(TypeOrConst::Either(type_arg))
622 }
623 ast::GenericArg::ConstArg(const_arg) => {
625 result.types_and_consts.push(TypeOrConst::Const(const_arg));
626 }
627 ast::GenericArg::LifetimeArg(l_arg) => result.lifetimes.push(l_arg),
628 _ => (),
629 });
630
631 Some(result)
632}
633
634fn find_trait_for_assoc_item(
635 scope: &SemanticsScope<'_>,
636 type_param: hir::TypeParam,
637 assoc_item: ast::NameRef,
638) -> Option<hir::Trait> {
639 let db = scope.db;
640 let trait_bounds = type_param.trait_bounds(db);
641
642 let assoc_item_name = assoc_item.text();
643
644 for trait_ in trait_bounds {
645 let names = trait_.items(db).into_iter().filter_map(|item| match item {
646 hir::AssocItem::TypeAlias(ta) => Some(ta.name(db)),
647 hir::AssocItem::Const(cst) => cst.name(db),
648 _ => None,
649 });
650
651 for name in names {
652 if assoc_item_name.as_str() == name.as_str() {
653 return Some(trait_);
658 }
659 }
660 }
661
662 None
663}
664
#[cfg(test)]
mod tests {
    use crate::RootDatabase;
    use crate::path_transform::PathTransform;
    use hir::Semantics;
    use syntax::{AstNode, ast::HasName};
    use test_fixture::WithFixture;
    use test_utils::assert_eq_text;

    /// Transforms the body of `main` from its own scope into the scope of
    /// `bar::anchor` and checks which `let` bindings get qualified: the unit
    /// struct, the unit variant, the const and the static are, while every
    /// other binding keeps its bare name.
    #[test]
    fn test_transform_ident_pat() {
        let (db, file_id) = RootDatabase::with_single_file(
            r#"
mod foo {
    pub struct UnitStruct;
    pub struct RecordStruct {}
    pub enum Enum { UnitVariant, RecordVariant {} }
    pub fn function() {}
    pub const CONST: i32 = 0;
    pub static STATIC: i32 = 0;
    pub type Alias = i32;
    pub union Union { f: i32 }
}

mod bar {
    fn anchor() {}
}

fn main() {
    use foo::*;
    use foo::Enum::*;
    let UnitStruct = ();
    let RecordStruct = ();
    let Enum = ();
    let UnitVariant = ();
    let RecordVariant = ();
    let function = ();
    let CONST = ();
    let STATIC = ();
    let Alias = ();
    let Union = ();
}
"#,
        );
        let sema = Semantics::new(&db);
        let source_file = sema.parse(file_id);

        // Source scope: the body of `main`, where the glob imports are visible.
        let function = source_file
            .syntax()
            .descendants()
            .filter_map(syntax::ast::Fn::cast)
            .find(|it| it.name().unwrap().text() == "main")
            .unwrap();
        let source_scope = sema.scope(function.body().unwrap().syntax()).unwrap();

        // Target scope: the body of `bar::anchor`, which has no such imports.
        let anchor = source_file
            .syntax()
            .descendants()
            .filter_map(syntax::ast::Fn::cast)
            .find(|it| it.name().unwrap().text() == "anchor")
            .unwrap();
        let target_scope = sema.scope(anchor.body().unwrap().syntax()).unwrap();

        let transform = PathTransform::generic_transformation(&target_scope, &source_scope);
        let transformed = transform.apply(function.body().unwrap().syntax());

        let expected = r#"{
    use crate::foo::*;
    use crate::foo::Enum::*;
    let crate::foo::UnitStruct = ();
    let RecordStruct = ();
    let Enum = ();
    let crate::foo::Enum::UnitVariant = ();
    let RecordVariant = ();
    let function = ();
    let crate::foo::CONST = ();
    let crate::foo::STATIC = ();
    let Alias = ();
    let Union = ();
}"#;
        assert_eq_text!(expected, &transformed.to_string());
    }
}