1use crate::utils::Bytes32;
2use heck::AsSnakeCase;
3use proc_macro::TokenStream;
4use proc_macro2::{Literal, Span, TokenTree};
5use quote::quote;
6use std::{cell::RefCell, collections::HashSet};
7use syn::{
8 meta::{self, ParseNestedMeta},
9 parse::{Parse, ParseStream, Result},
10 parse_quote, Attribute, Ident, ItemFn, ItemStruct, Visibility,
11};
12
thread_local! {
    /// Names of every persistent storage declared so far on this thread.
    /// The set's length just before an insertion is used as the new storage's slot.
    static STORAGE_REGISTRY: RefCell<HashSet<String>> = RefCell::default();

    /// Names of every transient storage declared so far on this thread,
    /// numbered independently of `STORAGE_REGISTRY`.
    static TRANSIENT_STORAGE_REGISTRY: RefCell<HashSet<String>> = RefCell::default();
}
17
/// Which family of storage a declaration targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
    /// Regular storage, implemented through the `zink::storage` traits.
    Persistent,
    /// Transient storage, with its own slot numbering.
    Transient,
}
24
25pub struct Storage {
27 kind: StorageKind,
29 ty: StorageType,
31 target: ItemStruct,
33 getter: Option<Ident>,
35}
36
37impl Storage {
38 pub fn parse(ty: StorageType, target: ItemStruct) -> TokenStream {
40 let storage = Self::from_parts(StorageKind::Persistent, ty, target);
41 storage.expand()
42 }
43
44 pub fn parse_transient(ty: StorageType, target: ItemStruct) -> TokenStream {
46 let storage = Self::from_parts(StorageKind::Transient, ty, target);
47 storage.expand()
48 }
49
50 fn from_parts(kind: StorageKind, ty: StorageType, target: ItemStruct) -> Self {
51 let mut this = Self {
52 kind,
53 ty,
54 target,
55 getter: None,
56 };
57
58 let mut attrs: Vec<Attribute> = Default::default();
59 for attr in this.target.attrs.iter().cloned() {
60 if !attr.path().is_ident("getter") {
61 attrs.push(attr);
62 continue;
63 }
64
65 let Ok(list) = attr.meta.require_list().clone() else {
66 panic!("Invalid getter arguments");
67 };
68
69 let Some(TokenTree::Ident(getter)) = list.tokens.clone().into_iter().nth(0) else {
70 panic!("Invalid getter function name");
71 };
72
73 this.getter = Some(getter);
74 }
75
76 this.target.attrs = attrs;
77 this
78 }
79
80 fn expand(mut self) -> TokenStream {
81 match &self.ty {
82 StorageType::Value(value) => self.expand_value(value.clone()),
83 StorageType::Mapping { key, value } => self.expand_mapping(key.clone(), value.clone()),
84 StorageType::DoubleKeyMapping { key1, key2, value } => {
85 self.expand_dk_mapping(key1.clone(), key2.clone(), value.clone())
86 }
87 StorageType::Invalid => panic!("Invalid storage type"),
88 }
89 }
90
91 fn expand_value(&mut self, value: Ident) -> TokenStream {
92 let is = &self.target;
93 let name = self.target.ident.clone();
94 let slot = self.get_storage_slot(name.to_string());
95 let key = slot.to_bytes32();
96
97 let keyl = Literal::byte_string(&key);
98 let trait_path = match self.kind {
99 StorageKind::Persistent => quote!(zink::storage::Storage),
100 StorageKind::Transient => quote!(zink::storage::TransientStorage),
101 };
102
103 let mut expanded = quote! {
104 #is
105
106 impl #trait_path for #name {
107 #[cfg(not(target_family = "wasm"))]
108 const STORAGE_KEY: [u8; 32] = *#keyl;
109 const STORAGE_SLOT: i32 = #slot;
110
111 type Value = #value;
112 }
113 };
114
115 if let Some(getter) = self.getter() {
116 let gs: proc_macro2::TokenStream = parse_quote! {
117 #[allow(missing_docs)]
118 #[zink::external]
119 pub fn #getter() -> #value {
120 #name::get()
121 }
122 };
123 expanded.extend(gs);
124 }
125
126 expanded.into()
127 }
128
129 fn expand_mapping(&mut self, key: Ident, value: Ident) -> TokenStream {
130 let is = &self.target;
131 let name = self.target.ident.clone();
132 let slot = self.get_storage_slot(name.to_string());
133
134 let trait_path = match self.kind {
135 StorageKind::Persistent => quote!(zink::storage::Mapping),
136 StorageKind::Transient => quote!(zink::transient_storage::TransientMapping),
137 };
138
139 let mut expanded = quote! {
140 #is
141
142 impl #trait_path for #name {
143 const STORAGE_SLOT: i32 = #slot;
144
145 type Key = #key;
146 type Value = #value;
147
148 #[cfg(not(target_family = "wasm"))]
149 fn storage_key(key: Self::Key) -> [u8; 32] {
150 use zink::Value;
151
152 let mut seed = [0; 64];
153 seed[..32].copy_from_slice(&key.bytes32());
154 seed[32..].copy_from_slice(&Self::STORAGE_SLOT.bytes32());
155 zink::keccak256(&seed)
156 }
157 }
158 };
159
160 if let Some(getter) = self.getter() {
161 let gs: proc_macro2::TokenStream = parse_quote! {
162 #[allow(missing_docs)]
163 #[zink::external]
164 pub fn #getter(key: #key) -> #value {
165 #name::get(key)
166 }
167 };
168 expanded.extend(gs);
169 }
170
171 expanded.into()
172 }
173
174 fn expand_dk_mapping(&mut self, key1: Ident, key2: Ident, value: Ident) -> TokenStream {
175 let is = &self.target;
176 let name = self.target.ident.clone();
177 let slot = self.get_storage_slot(name.to_string());
178
179 let trait_path = match self.kind {
180 StorageKind::Persistent => quote!(zink::storage::DoubleKeyMapping),
181 StorageKind::Transient => quote!(zink::transient_storage::DoubleKeyTransientMapping),
182 };
183
184 let mut expanded = quote! {
185 #is
186
187 impl #trait_path for #name {
188 const STORAGE_SLOT: i32 = #slot;
189
190 type Key1 = #key1;
191 type Key2 = #key2;
192 type Value = #value;
193
194 #[cfg(not(target_family = "wasm"))]
195 fn storage_key(key1: Self::Key1, key2: Self::Key2) -> [u8; 32] {
196 use zink::Value;
197
198 let mut seed = [0; 64];
199 seed[..32].copy_from_slice(&key1.bytes32());
200 seed[32..].copy_from_slice(&Self::STORAGE_SLOT.bytes32());
201 let skey1 = zink::keccak256(&seed);
202 seed[..32].copy_from_slice(&skey1);
203 seed[32..].copy_from_slice(&key2.bytes32());
204 zink::keccak256(&seed)
205 }
206 }
207 };
208
209 if let Some(getter) = self.getter() {
210 let gs: proc_macro2::TokenStream = parse_quote! {
211 #[allow(missing_docs)]
212 #[zink::external]
213 pub fn #getter(key1: #key1, key2: #key2) -> #value {
214 #name::get(key1, key2)
215 }
216 };
217 expanded.extend(gs);
218 }
219
220 expanded.into()
221 }
222
223 fn get_storage_slot(&self, name: String) -> i32 {
224 match self.kind {
225 StorageKind::Persistent => STORAGE_REGISTRY.with_borrow_mut(|r| {
226 let key = r.len();
227 if !r.insert(name.clone()) {
228 panic!("Storage {name} has already been declared");
229 }
230 key
231 }) as i32,
232 StorageKind::Transient => TRANSIENT_STORAGE_REGISTRY.with_borrow_mut(|r| {
233 let key = r.len();
234 if !r.insert(name.clone()) {
235 panic!("Transient storage {name} has already been declared");
236 }
237 key
238 }) as i32,
239 }
240 }
241
242 fn getter(&mut self) -> Option<Ident> {
244 let mut getter = if matches!(self.target.vis, Visibility::Public(_)) {
245 let fname = Ident::new(
246 &AsSnakeCase(self.target.ident.to_string()).to_string(),
247 Span::call_site(),
248 );
249 Some(fname)
250 } else {
251 None
252 };
253
254 self.getter.take().or(getter)
255 }
256}
257
258#[derive(Default, Debug)]
260pub enum StorageType {
261 Value(Ident),
263 Mapping { key: Ident, value: Ident },
265 DoubleKeyMapping {
267 key1: Ident,
268 key2: Ident,
269 value: Ident,
270 },
271 #[default]
273 Invalid,
274}
275
276impl From<TokenStream> for StorageType {
277 fn from(input: TokenStream) -> Self {
278 let tokens = input.to_string();
279 let types: Vec<_> = tokens.split(',').collect();
280 match types.len() {
281 1 => StorageType::Value(Ident::new(types[0].trim(), Span::call_site())),
282 2 => StorageType::Mapping {
283 key: Ident::new(types[0].trim(), Span::call_site()),
284 value: Ident::new(types[1].trim(), Span::call_site()),
285 },
286 3 => StorageType::DoubleKeyMapping {
287 key1: Ident::new(types[0].trim(), Span::call_site()),
288 key2: Ident::new(types[1].trim(), Span::call_site()),
289 value: Ident::new(types[2].trim(), Span::call_site()),
290 },
291 _ => panic!("Invalid storage attributes"),
292 }
293 }
294}