1use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::{format_ident, quote};
5use syn::{parse_macro_input, Fields, Ident, ItemStruct, Type};
6
7pub struct ContractStorage {
9 target: ItemStruct,
10}
11
12impl ContractStorage {
13 pub fn new(input: ItemStruct) -> Self {
15 Self { target: input }
16 }
17
18 pub fn parse(input: TokenStream) -> TokenStream {
20 let input = parse_macro_input!(input as ItemStruct);
21 let storage = Self::new(input);
22 storage.expand()
23 }
24
25 fn expand(&self) -> TokenStream {
27 let Fields::Named(fields) = &self.target.fields else {
28 return syn::Error::new(
29 Span::call_site(),
30 "Storage derive only supports structs with named fields",
31 )
32 .to_compile_error()
33 .into();
34 };
35
36 let struct_name = &self.target.ident;
37 let generics = &self.target.generics;
38 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40 let mut slot_counter = 0;
41 let field_structs: Vec<_> = fields.named.iter().map(|field| {
42 let field_name = field.ident.as_ref().unwrap();
43 let field_ty = &field.ty;
44 let slot = slot_counter;
45 slot_counter += 1;
46 let struct_name = format_ident!("{}{}", struct_name, field_name.to_upper_camel_case());
47
48 match classify_field_type(field_ty) {
49 FieldType::Simple => {
50 quote! {
51 pub struct #struct_name;
52 impl #impl_generics zink::storage::Storage for #struct_name #ty_generics #where_clause {
53 const STORAGE_SLOT: i32 = #slot;
54 type Value = #field_ty;
55
56 #[cfg(not(target_family = "wasm"))]
57 const STORAGE_KEY: [u8; 32] = [0u8; 32];
58
59 fn get() -> Self::Value {
60 zink::Asm::push(Self::STORAGE_SLOT);
61 <Self::Value as zink::storage::StorageValue>::sload()
62 }
63
64 fn set(value: Self::Value) {
65 value.push();
66 zink::Asm::push(Self::STORAGE_SLOT);
67 unsafe { zink::ffi::evm::sstore(); }
68 }
69 }
70 }
71 }
72 FieldType::Mapping => {
73 let (key_ty, value_ty) = extract_mapping_types(field_ty).unwrap_or_else(|| {
74 panic!("Mapping type must be of form Mapping<K, V>");
75 });
76 quote! {
77 pub struct #struct_name;
78 impl #impl_generics zink::storage::Mapping for #struct_name #ty_generics #where_clause {
79 const STORAGE_SLOT: i32 = #slot;
80 type Key = #key_ty;
81 type Value = #value_ty;
82
83 #[cfg(not(target_family = "wasm"))]
84 fn storage_key(key: Self::Key) -> [u8; 32] {
85 [0u8; 32]
86 }
87
88 fn get(key: Self::Key) -> Self::Value {
89 zink::storage::mapping::load_key(key, Self::STORAGE_SLOT);
90 <Self::Value as zink::storage::StorageValue>::sload()
91 }
92
93 fn set(key: Self::Key, value: Self::Value) {
94 value.push();
95 zink::storage::mapping::load_key(key, Self::STORAGE_SLOT);
96 unsafe { zink::ffi::evm::sstore(); }
97 }
98 }
99 }
100 }
101 FieldType::DoubleKeyMapping => {
102 let (key1_ty, key2_ty, value_ty) = extract_double_key_mapping_types(field_ty).unwrap_or_else(|| {
103 panic!("DoubleKeyMapping type must be of form DoubleKeyMapping<K1, K2, V>");
104 });
105 quote! {
106 pub struct #struct_name;
107 impl #impl_generics zink::storage::DoubleKeyMapping for #struct_name #ty_generics #where_clause {
108 const STORAGE_SLOT: i32 = #slot;
109 type Key1 = #key1_ty;
110 type Key2 = #key2_ty;
111 type Value = #value_ty;
112
113 #[cfg(not(target_family = "wasm"))]
114 fn storage_key(key1: Self::Key1, key2: Self::Key2) -> [u8; 32] {
115 [0u8; 32]
116 }
117
118 fn get(key1: Self::Key1, key2: Self::Key2) -> Self::Value {
119 zink::storage::dkmapping::load_double_key(key1, key2, Self::STORAGE_SLOT);
120 <Self::Value as zink::storage::StorageValue>::sload()
121 }
122
123 fn set(key1: Self::Key1, key2: Self::Key2, value: Self::Value) {
124 value.push();
125 zink::storage::dkmapping::load_double_key(key1, key2, Self::STORAGE_SLOT);
126 unsafe { zink::ffi::evm::sstore(); }
127 }
128 }
129 }
130 }
131 FieldType::Unknown => {
132 syn::Error::new_spanned(field_ty, "Unsupported storage type").to_compile_error()
133 }
134 }
135 }).collect();
136
137 let method_impls: Vec<_> = fields.named.iter().map(|field| {
138 let field_name = field.ident.as_ref().unwrap();
139 let field_ty = &field.ty;
140 let setter_name = format_ident!("set_{}", field_name);
141 let field_struct = format_ident!("{}{}", struct_name, field_name.to_upper_camel_case());
142
143 match classify_field_type(field_ty) {
144 FieldType::Simple => {
145 quote! {
146 pub fn #field_name(&self) -> #field_ty {
147 #field_struct::get()
148 }
149
150 pub fn #setter_name(&self, value: #field_ty) {
151 #field_struct::set(value);
152 }
153 }
154 }
155 FieldType::Mapping => {
156 let (key_ty, value_ty) = extract_mapping_types(field_ty).unwrap();
157 quote! {
158 pub fn #field_name(&self, key: #key_ty) -> #value_ty {
159 #field_struct::get(key)
160 }
161
162 pub fn #setter_name(&self, key: #key_ty, value: #value_ty) {
163 #field_struct::set(key, value);
164 }
165 }
166 }
167 FieldType::DoubleKeyMapping => {
168 let (key1_ty, key2_ty, value_ty) = extract_double_key_mapping_types(field_ty).unwrap();
169 quote! {
170 pub fn #field_name(&self, key1: #key1_ty, key2: #key2_ty) -> #value_ty {
171 #field_struct::get(key1, key2)
172 }
173
174 pub fn #setter_name(&self, key1: #key1_ty, key2: #key2_ty, value: #value_ty) {
175 #field_struct::set(key1, key2, value);
176 }
177 }
178 }
179 FieldType::Unknown => {
180 syn::Error::new_spanned(field_ty, "Unsupported storage type").to_compile_error()
181 }
182 }
183 }).collect();
184
185 let expanded = quote! {
186 use zink::Asm;
187 #(#field_structs)*
188 impl #impl_generics #struct_name #ty_generics #where_clause {
189 #(#method_impls)*
190 }
191 };
192
193 TokenStream::from(expanded)
194 }
195}
196
/// Renders an identifier as an `UpperCamelCase` string; the `Storage` derive
/// uses it to name the per-field marker types.
trait ToUpperCamelCase {
    /// Returns `self` converted to `UpperCamelCase`.
    fn to_upper_camel_case(&self) -> String;
}
200
201impl ToUpperCamelCase for Ident {
202 fn to_upper_camel_case(&self) -> String {
203 let s = self.to_string();
204 let mut result = String::new();
205 let mut capitalize_next = true;
206
207 for c in s.chars() {
208 if c == '_' {
209 capitalize_next = true;
210 } else if capitalize_next {
211 result.push(c.to_ascii_uppercase());
212 capitalize_next = false;
213 } else {
214 result.push(c);
215 }
216 }
217 result
218 }
219}
220
/// Storage shape of a struct field, as decided by `classify_field_type`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FieldType {
    /// Any path type other than the mapping types: a single value at the
    /// field's slot.
    Simple,
    /// A path type whose last segment is named `Mapping` (expected `Mapping<K, V>`).
    Mapping,
    /// A path type whose last segment is named `DoubleKeyMapping`
    /// (expected `DoubleKeyMapping<K1, K2, V>`).
    DoubleKeyMapping,
    /// A type the derive cannot handle: not a path type, or a path with no segments.
    Unknown,
}
227
228fn classify_field_type(ty: &Type) -> FieldType {
229 if let Type::Path(type_path) = ty {
230 let path = &type_path.path;
231 if let Some(segment) = path.segments.last() {
232 match segment.ident.to_string().as_str() {
233 "Mapping" => FieldType::Mapping,
234 "DoubleKeyMapping" => FieldType::DoubleKeyMapping,
235 _ => FieldType::Simple,
236 }
237 } else {
238 FieldType::Unknown
239 }
240 } else {
241 FieldType::Unknown
242 }
243}
244
245fn extract_mapping_types(ty: &Type) -> Option<(Type, Type)> {
247 if let Type::Path(type_path) = ty {
248 if let Some(segment) = type_path.path.segments.last() {
249 if segment.ident == "Mapping" {
250 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
251 let args: Vec<_> = args.args.iter().collect();
252 if args.len() == 2 {
253 if let (
254 syn::GenericArgument::Type(key_ty),
255 syn::GenericArgument::Type(value_ty),
256 ) = (&args[0], &args[1])
257 {
258 return Some((key_ty.clone(), value_ty.clone()));
259 }
260 }
261 }
262 }
263 }
264 }
265 None
266}
267
268fn extract_double_key_mapping_types(ty: &Type) -> Option<(Type, Type, Type)> {
270 if let Type::Path(type_path) = ty {
271 if let Some(segment) = type_path.path.segments.last() {
272 if segment.ident == "DoubleKeyMapping" {
273 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
274 let args: Vec<_> = args.args.iter().collect();
275 if args.len() == 3 {
276 if let (
277 syn::GenericArgument::Type(key1_ty),
278 syn::GenericArgument::Type(key2_ty),
279 syn::GenericArgument::Type(value_ty),
280 ) = (&args[0], &args[1], &args[2])
281 {
282 return Some((key1_ty.clone(), key2_ty.clone(), value_ty.clone()));
283 }
284 }
285 }
286 }
287 }
288 }
289 None
290}