zink_codegen/
storage.rs

1use crate::utils::Bytes32;
2use heck::AsSnakeCase;
3use proc_macro::TokenStream;
4use proc_macro2::{Literal, Span, TokenTree};
5use quote::quote;
6use std::{cell::RefCell, collections::HashSet};
7use syn::{
8    meta::{self, ParseNestedMeta},
9    parse::{Parse, ParseStream, Result},
10    parse_quote, Attribute, Ident, ItemFn, ItemStruct, Visibility,
11};
12
thread_local! {
    /// Names of the persistent storages declared so far in this thread.
    ///
    /// A storage's slot is the registry size at the moment its name is inserted, so slots are
    /// handed out in declaration order starting at 0.
    static STORAGE_REGISTRY: RefCell<HashSet<String>> = RefCell::default();
    /// Names of the transient storages declared so far in this thread; slots are numbered
    /// independently of [`STORAGE_REGISTRY`].
    ///
    /// NOTE(review): proc-macro state persisting across invocations is relied upon here —
    /// confirm this holds for the build setups (incremental / IDE expansion) in use.
    static TRANSIENT_STORAGE_REGISTRY: RefCell<HashSet<String>> = RefCell::default();
}
17
/// Which family of storage traits a declaration expands against.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
    /// Persistent storage (e.g. `zink::storage::Storage`, `zink::storage::Mapping`).
    Persistent,
    /// Transient storage (e.g. `TransientStorage`, `TransientMapping`).
    Transient,
}
24
25/// Storage attributes parser
26pub struct Storage {
27    /// Storage kind (persistent or transient)
28    kind: StorageKind,
29    /// kind of the storage
30    ty: StorageType,
31    /// The source and the target storage struct
32    target: ItemStruct,
33    /// Getter function of storage
34    getter: Option<Ident>,
35}
36
37impl Storage {
38    /// Parse from proc_macro attribute for persistent storage
39    pub fn parse(ty: StorageType, target: ItemStruct) -> TokenStream {
40        let storage = Self::from_parts(StorageKind::Persistent, ty, target);
41        storage.expand()
42    }
43
44    /// Parse from proc_macro attribute for transient storage
45    pub fn parse_transient(ty: StorageType, target: ItemStruct) -> TokenStream {
46        let storage = Self::from_parts(StorageKind::Transient, ty, target);
47        storage.expand()
48    }
49
50    fn from_parts(kind: StorageKind, ty: StorageType, target: ItemStruct) -> Self {
51        let mut this = Self {
52            kind,
53            ty,
54            target,
55            getter: None,
56        };
57
58        let mut attrs: Vec<Attribute> = Default::default();
59        for attr in this.target.attrs.iter().cloned() {
60            if !attr.path().is_ident("getter") {
61                attrs.push(attr);
62                continue;
63            }
64
65            let Ok(list) = attr.meta.require_list().clone() else {
66                panic!("Invalid getter arguments");
67            };
68
69            let Some(TokenTree::Ident(getter)) = list.tokens.clone().into_iter().nth(0) else {
70                panic!("Invalid getter function name");
71            };
72
73            this.getter = Some(getter);
74        }
75
76        this.target.attrs = attrs;
77        this
78    }
79
80    fn expand(mut self) -> TokenStream {
81        match &self.ty {
82            StorageType::Value(value) => self.expand_value(value.clone()),
83            StorageType::Mapping { key, value } => self.expand_mapping(key.clone(), value.clone()),
84            StorageType::DoubleKeyMapping { key1, key2, value } => {
85                self.expand_dk_mapping(key1.clone(), key2.clone(), value.clone())
86            }
87            StorageType::Invalid => panic!("Invalid storage type"),
88        }
89    }
90
91    fn expand_value(&mut self, value: Ident) -> TokenStream {
92        let is = &self.target;
93        let name = self.target.ident.clone();
94        let slot = self.get_storage_slot(name.to_string());
95        let key = slot.to_bytes32();
96
97        let keyl = Literal::byte_string(&key);
98        let trait_path = match self.kind {
99            StorageKind::Persistent => quote!(zink::storage::Storage),
100            StorageKind::Transient => quote!(zink::storage::TransientStorage),
101        };
102
103        let mut expanded = quote! {
104            #is
105
106            impl #trait_path for #name {
107                #[cfg(not(target_family = "wasm"))]
108                const STORAGE_KEY: [u8; 32] = *#keyl;
109                const STORAGE_SLOT: i32 = #slot;
110
111                type Value = #value;
112            }
113        };
114
115        if let Some(getter) = self.getter() {
116            let gs: proc_macro2::TokenStream = parse_quote! {
117                #[allow(missing_docs)]
118                #[zink::external]
119                pub fn #getter() -> #value {
120                    #name::get()
121                }
122            };
123            expanded.extend(gs);
124        }
125
126        expanded.into()
127    }
128
129    fn expand_mapping(&mut self, key: Ident, value: Ident) -> TokenStream {
130        let is = &self.target;
131        let name = self.target.ident.clone();
132        let slot = self.get_storage_slot(name.to_string());
133
134        let trait_path = match self.kind {
135            StorageKind::Persistent => quote!(zink::storage::Mapping),
136            StorageKind::Transient => quote!(zink::transient_storage::TransientMapping),
137        };
138
139        let mut expanded = quote! {
140            #is
141
142            impl #trait_path for #name {
143                const STORAGE_SLOT: i32 = #slot;
144
145                type Key = #key;
146                type Value = #value;
147
148                #[cfg(not(target_family = "wasm"))]
149                fn storage_key(key: Self::Key) -> [u8; 32] {
150                    use zink::Value;
151
152                    let mut seed = [0; 64];
153                    seed[..32].copy_from_slice(&key.bytes32());
154                    seed[32..].copy_from_slice(&Self::STORAGE_SLOT.bytes32());
155                    zink::keccak256(&seed)
156                }
157            }
158        };
159
160        if let Some(getter) = self.getter() {
161            let gs: proc_macro2::TokenStream = parse_quote! {
162                #[allow(missing_docs)]
163                #[zink::external]
164                pub fn #getter(key: #key) -> #value {
165                    #name::get(key)
166                }
167            };
168            expanded.extend(gs);
169        }
170
171        expanded.into()
172    }
173
174    fn expand_dk_mapping(&mut self, key1: Ident, key2: Ident, value: Ident) -> TokenStream {
175        let is = &self.target;
176        let name = self.target.ident.clone();
177        let slot = self.get_storage_slot(name.to_string());
178
179        let trait_path = match self.kind {
180            StorageKind::Persistent => quote!(zink::storage::DoubleKeyMapping),
181            StorageKind::Transient => quote!(zink::transient_storage::DoubleKeyTransientMapping),
182        };
183
184        let mut expanded = quote! {
185            #is
186
187            impl #trait_path for #name {
188                const STORAGE_SLOT: i32 = #slot;
189
190                type Key1 = #key1;
191                type Key2 = #key2;
192                type Value = #value;
193
194                #[cfg(not(target_family = "wasm"))]
195                fn storage_key(key1: Self::Key1, key2: Self::Key2) -> [u8; 32] {
196                    use zink::Value;
197
198                    let mut seed = [0; 64];
199                    seed[..32].copy_from_slice(&key1.bytes32());
200                    seed[32..].copy_from_slice(&Self::STORAGE_SLOT.bytes32());
201                    let skey1 = zink::keccak256(&seed);
202                    seed[..32].copy_from_slice(&skey1);
203                    seed[32..].copy_from_slice(&key2.bytes32());
204                    zink::keccak256(&seed)
205                }
206            }
207        };
208
209        if let Some(getter) = self.getter() {
210            let gs: proc_macro2::TokenStream = parse_quote! {
211                #[allow(missing_docs)]
212                #[zink::external]
213                pub fn #getter(key1: #key1, key2: #key2) -> #value {
214                    #name::get(key1, key2)
215                }
216            };
217            expanded.extend(gs);
218        }
219
220        expanded.into()
221    }
222
223    fn get_storage_slot(&self, name: String) -> i32 {
224        match self.kind {
225            StorageKind::Persistent => STORAGE_REGISTRY.with_borrow_mut(|r| {
226                let key = r.len();
227                if !r.insert(name.clone()) {
228                    panic!("Storage {name} has already been declared");
229                }
230                key
231            }) as i32,
232            StorageKind::Transient => TRANSIENT_STORAGE_REGISTRY.with_borrow_mut(|r| {
233                let key = r.len();
234                if !r.insert(name.clone()) {
235                    panic!("Transient storage {name} has already been declared");
236                }
237                key
238            }) as i32,
239        }
240    }
241
242    /// Get the getter of this storage
243    fn getter(&mut self) -> Option<Ident> {
244        let mut getter = if matches!(self.target.vis, Visibility::Public(_)) {
245            let fname = Ident::new(
246                &AsSnakeCase(self.target.ident.to_string()).to_string(),
247                Span::call_site(),
248            );
249            Some(fname)
250        } else {
251            None
252        };
253
254        self.getter.take().or(getter)
255    }
256}
257
258/// Zink storage type parser
259#[derive(Default, Debug)]
260pub enum StorageType {
261    /// Single value storage
262    Value(Ident),
263    /// Mapping storage
264    Mapping { key: Ident, value: Ident },
265    /// Double key mapping storage
266    DoubleKeyMapping {
267        key1: Ident,
268        key2: Ident,
269        value: Ident,
270    },
271    /// Invalid storage type
272    #[default]
273    Invalid,
274}
275
276impl From<TokenStream> for StorageType {
277    fn from(input: TokenStream) -> Self {
278        let tokens = input.to_string();
279        let types: Vec<_> = tokens.split(',').collect();
280        match types.len() {
281            1 => StorageType::Value(Ident::new(types[0].trim(), Span::call_site())),
282            2 => StorageType::Mapping {
283                key: Ident::new(types[0].trim(), Span::call_site()),
284                value: Ident::new(types[1].trim(), Span::call_site()),
285            },
286            3 => StorageType::DoubleKeyMapping {
287                key1: Ident::new(types[0].trim(), Span::call_site()),
288                key2: Ident::new(types[1].trim(), Span::call_site()),
289                value: Ident::new(types[2].trim(), Span::call_site()),
290            },
291            _ => panic!("Invalid storage attributes"),
292        }
293    }
294}