// ra_salsa_macros/database_storage.rs
use heck::ToSnakeCase;
use proc_macro::TokenStream;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{Ident, ItemStruct, Path, Token};
/// Comma-punctuated sequence of [`QueryGroup`] entries.
type PunctuatedQueryGroups = Punctuated<QueryGroup, Token![,]>;
/// Generates the storage and plumbing code for a database struct.
///
/// `args` is parsed as a comma-separated list of query-group paths
/// ([`QueryGroupList`]) and `input` as the struct the macro is applied to.
/// The returned token stream contains, in this order:
///
/// 1. the input struct, re-emitted unchanged;
/// 2. a hidden `__SalsaDatabaseStorage` struct with one field per query group
///    (field name = snake-cased last path segment of the group), plus a
///    `Default` impl that builds each field with `GroupStorage::new(index)`;
/// 3. `impl ra_salsa::plumbing::DatabaseStorageTypes` for the struct;
/// 4. `impl ra_salsa::plumbing::DatabaseOps` for the struct, dispatching on
///    the group index of the incoming `DatabaseKeyIndex`;
/// 5. one `impl ra_salsa::plumbing::HasQueryGroup<Group>` per query group.
///
/// The generated code accesses `self.storage`, so the annotated struct must
/// have a field with that name.
///
/// If either `args` or `input` fails to parse, `syn::parse_macro_input!`
/// makes this function return the corresponding compile error tokens.
pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream {
    let args = syn::parse_macro_input!(args as QueryGroupList);
    let input = syn::parse_macro_input!(input as ItemStruct);

    let query_groups = &args.query_groups;
    let database_name = &input.ident;
    let visibility = &input.vis;
    // Name of the field on the user's struct that the generated code reads.
    let db_storage_field = quote! { storage };

    // Per-group fragments, accumulated in a single pass over the groups and
    // spliced into the final items below.
    let mut storage_fields = proc_macro2::TokenStream::new();
    let mut storage_initializers = proc_macro2::TokenStream::new();
    let mut has_group_impls = proc_macro2::TokenStream::new();
    let mut fmt_ops = proc_macro2::TokenStream::new();
    let mut maybe_changed_ops = proc_macro2::TokenStream::new();
    let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new();
    let mut for_each_ops = proc_macro2::TokenStream::new();

    // The index is a `u16` so that it is interpolated as a `u16` literal,
    // both as the argument to `GroupStorage::new` and as a match pattern
    // against `input.group_index()`.
    for (query_group, group_index) in query_groups.iter().zip(0_u16..) {
        let group_path = &query_group.group_path;

        // Storage field name: the group's last path segment in snake case,
        // keeping the span of the original identifier.
        let group_name = query_group.name();
        let group_name_snake =
            Ident::new(&group_name.to_string().to_snake_case(), group_name.span());

        let group_storage = quote! {
            <#group_path as ra_salsa::plumbing::QueryGroup>::GroupStorage
        };

        storage_fields.extend(quote! {
            #group_name_snake: #group_storage,
        });

        storage_initializers.extend(quote! {
            #group_name_snake: #group_storage::new(#group_index),
        });

        has_group_impls.extend(quote! {
            impl ra_salsa::plumbing::HasQueryGroup<#group_path> for #database_name {
                fn group_storage(&self) -> &#group_storage {
                    &self.#db_storage_field.query_store().#group_name_snake
                }

                fn group_storage_mut(&mut self) -> (&#group_storage, &mut ra_salsa::Runtime) {
                    let (query_store_mut, runtime) = self.#db_storage_field.query_store_mut();
                    (&query_store_mut.#group_name_snake, runtime)
                }
            }
        });

        fmt_ops.extend(quote! {
            #group_index => {
                let storage: &#group_storage =
                    <Self as ra_salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
                storage.fmt_index(self, input, fmt)
            }
        });

        maybe_changed_ops.extend(quote! {
            #group_index => {
                let storage: &#group_storage =
                    <Self as ra_salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
                storage.maybe_changed_after(self, input, revision)
            }
        });

        cycle_recovery_strategy_ops.extend(quote! {
            #group_index => {
                let storage: &#group_storage =
                    <Self as ra_salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
                storage.cycle_recovery_strategy(self, input)
            }
        });

        for_each_ops.extend(quote! {
            let storage: &#group_storage =
                <Self as ra_salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
            storage.for_each_query(runtime, &mut op);
        });
    }

    let mut output = proc_macro2::TokenStream::new();

    // Re-emit the annotated struct untouched.
    output.extend(quote! { #input });

    // The aggregate storage struct and its `Default` impl.
    output.extend(quote! {
        #[doc(hidden)]
        #visibility struct __SalsaDatabaseStorage {
            #storage_fields
        }

        impl Default for __SalsaDatabaseStorage {
            fn default() -> Self {
                Self {
                    #storage_initializers
                }
            }
        }
    });

    output.extend(quote! {
        impl ra_salsa::plumbing::DatabaseStorageTypes for #database_name {
            type DatabaseStorage = __SalsaDatabaseStorage;
        }
    });

    // Each `match` gets one arm per group; any other index panics.
    output.extend(quote! {
        impl ra_salsa::plumbing::DatabaseOps for #database_name {
            fn ops_database(&self) -> &dyn ra_salsa::Database {
                self
            }

            fn ops_salsa_runtime(&self) -> &ra_salsa::Runtime {
                self.#db_storage_field.salsa_runtime()
            }

            fn synthetic_write(&mut self, durability: ra_salsa::Durability) {
                self.#db_storage_field.salsa_runtime_mut().synthetic_write(durability)
            }

            fn fmt_index(
                &self,
                input: ra_salsa::DatabaseKeyIndex,
                fmt: &mut std::fmt::Formatter<'_>,
            ) -> std::fmt::Result {
                match input.group_index() {
                    #fmt_ops
                    i => panic!("ra_salsa: invalid group index {}", i)
                }
            }

            fn maybe_changed_after(
                &self,
                input: ra_salsa::DatabaseKeyIndex,
                revision: ra_salsa::Revision
            ) -> bool {
                match input.group_index() {
                    #maybe_changed_ops
                    i => panic!("ra_salsa: invalid group index {}", i)
                }
            }

            fn cycle_recovery_strategy(
                &self,
                input: ra_salsa::DatabaseKeyIndex,
            ) -> ra_salsa::plumbing::CycleRecoveryStrategy {
                match input.group_index() {
                    #cycle_recovery_strategy_ops
                    i => panic!("ra_salsa: invalid group index {}", i)
                }
            }

            fn for_each_query(
                &self,
                mut op: &mut dyn FnMut(&dyn ra_salsa::plumbing::QueryStorageMassOps),
            ) {
                let runtime = ra_salsa::Database::salsa_runtime(self);
                #for_each_ops
            }
        }
    });

    output.extend(has_group_impls);

    output.into()
}
/// Parsed form of the macro arguments: the query groups, in the order written.
#[derive(Clone, Debug)]
struct QueryGroupList {
    /// The query groups together with their separating commas.
    query_groups: PunctuatedQueryGroups,
}
impl Parse for QueryGroupList {
    /// Consumes the entire input as `QueryGroup` entries separated by `,`.
    ///
    /// Any error reported by `parse_terminated` is returned unchanged.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        input
            .parse_terminated(QueryGroup::parse, Token![,])
            .map(|query_groups| QueryGroupList { query_groups })
    }
}
/// A single macro argument: the path naming one query group.
#[derive(Clone, Debug)]
struct QueryGroup {
    /// The path exactly as parsed from the macro arguments.
    group_path: Path,
}
impl QueryGroup {
    /// Returns a clone of the identifier of the last segment of `group_path`.
    ///
    /// # Panics
    ///
    /// Panics if `group_path` has no segments.
    fn name(&self) -> Ident {
        let segments = &self.group_path.segments;
        let final_segment = segments.last().unwrap();
        final_segment.ident.clone()
    }
}
impl Parse for QueryGroup {
    /// Parses one `syn::Path` from `input` and wraps it in a `QueryGroup`.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        input
            .parse::<Path>()
            .map(|group_path| QueryGroup { group_path })
    }
}