zink_codegen/
contract.rs

1//! Derive macro for contract storage
2use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::{format_ident, quote};
5use syn::{parse_macro_input, Fields, Ident, ItemStruct, Type};
6
7// Represents the contract storage derivation
8pub struct ContractStorage {
9    target: ItemStruct,
10}
11
12impl ContractStorage {
13    /// Create a new ContractStorage from an input struct
14    pub fn new(input: ItemStruct) -> Self {
15        Self { target: input }
16    }
17
18    /// Parse and validate the input, returning a TokenStream
19    pub fn parse(input: TokenStream) -> TokenStream {
20        let input = parse_macro_input!(input as ItemStruct);
21        let storage = Self::new(input);
22        storage.expand()
23    }
24
25    /// Generate the expanded TokenStream
26    fn expand(&self) -> TokenStream {
27        let Fields::Named(fields) = &self.target.fields else {
28            return syn::Error::new(
29                Span::call_site(),
30                "Storage derive only supports structs with named fields",
31            )
32            .to_compile_error()
33            .into();
34        };
35
36        let struct_name = &self.target.ident;
37        let generics = &self.target.generics;
38        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40        let mut slot_counter = 0;
41        let field_structs: Vec<_> = fields.named.iter().map(|field| {
42            let field_name = field.ident.as_ref().unwrap();
43            let field_ty = &field.ty;
44            let slot = slot_counter;
45            slot_counter += 1;
46            let struct_name = format_ident!("{}{}", struct_name, field_name.to_upper_camel_case());
47
48            match classify_field_type(field_ty) {
49                FieldType::Simple => {
50                    quote! {
51                        pub struct #struct_name;
52                        impl #impl_generics zink::storage::Storage for #struct_name #ty_generics #where_clause {
53                            const STORAGE_SLOT: i32 = #slot;
54                            type Value = #field_ty;
55
56                            #[cfg(not(target_family = "wasm"))]
57                            const STORAGE_KEY: [u8; 32] = [0u8; 32];
58
59                            fn get() -> Self::Value {
60                                zink::Asm::push(Self::STORAGE_SLOT);
61                                <Self::Value as zink::storage::StorageValue>::sload()
62                            }
63
64                            fn set(value: Self::Value) {
65                                value.push();
66                                zink::Asm::push(Self::STORAGE_SLOT);
67                                unsafe { zink::ffi::evm::sstore(); }
68                            }
69                        }
70                    }
71                }
72                FieldType::Mapping => {
73                    let (key_ty, value_ty) = extract_mapping_types(field_ty).unwrap_or_else(|| {
74                        panic!("Mapping type must be of form Mapping<K, V>");
75                    });
76                    quote! {
77                        pub struct #struct_name;
78                        impl #impl_generics zink::storage::Mapping for #struct_name #ty_generics #where_clause {
79                            const STORAGE_SLOT: i32 = #slot;
80                            type Key = #key_ty;
81                            type Value = #value_ty;
82
83                            #[cfg(not(target_family = "wasm"))]
84                            fn storage_key(key: Self::Key) -> [u8; 32] {
85                                [0u8; 32]
86                            }
87
88                            fn get(key: Self::Key) -> Self::Value {
89                                zink::storage::mapping::load_key(key, Self::STORAGE_SLOT);
90                                <Self::Value as zink::storage::StorageValue>::sload()
91                            }
92
93                            fn set(key: Self::Key, value: Self::Value) {
94                                value.push();
95                                zink::storage::mapping::load_key(key, Self::STORAGE_SLOT);
96                                unsafe { zink::ffi::evm::sstore(); }
97                            }
98                        }
99                    }
100                }
101                FieldType::DoubleKeyMapping => {
102                    let (key1_ty, key2_ty, value_ty) = extract_double_key_mapping_types(field_ty).unwrap_or_else(|| {
103                        panic!("DoubleKeyMapping type must be of form DoubleKeyMapping<K1, K2, V>");
104                    });
105                    quote! {
106                        pub struct #struct_name;
107                        impl #impl_generics zink::storage::DoubleKeyMapping for #struct_name #ty_generics #where_clause {
108                            const STORAGE_SLOT: i32 = #slot;
109                            type Key1 = #key1_ty;
110                            type Key2 = #key2_ty;
111                            type Value = #value_ty;
112
113                            #[cfg(not(target_family = "wasm"))]
114                            fn storage_key(key1: Self::Key1, key2: Self::Key2) -> [u8; 32] {
115                                [0u8; 32]
116                            }
117
118                            fn get(key1: Self::Key1, key2: Self::Key2) -> Self::Value {
119                                zink::storage::dkmapping::load_double_key(key1, key2, Self::STORAGE_SLOT);
120                                <Self::Value as zink::storage::StorageValue>::sload()
121                            }
122
123                            fn set(key1: Self::Key1, key2: Self::Key2, value: Self::Value) {
124                                value.push();
125                                zink::storage::dkmapping::load_double_key(key1, key2, Self::STORAGE_SLOT);
126                                unsafe { zink::ffi::evm::sstore(); }
127                            }
128                        }
129                    }
130                }
131                FieldType::Unknown => {
132                    syn::Error::new_spanned(field_ty, "Unsupported storage type").to_compile_error()
133                }
134            }
135        }).collect();
136
137        let method_impls: Vec<_> = fields.named.iter().map(|field| {
138            let field_name = field.ident.as_ref().unwrap();
139            let field_ty = &field.ty;
140            let setter_name = format_ident!("set_{}", field_name);
141            let field_struct = format_ident!("{}{}", struct_name, field_name.to_upper_camel_case());
142
143            match classify_field_type(field_ty) {
144                FieldType::Simple => {
145                    quote! {
146                        pub fn #field_name(&self) -> #field_ty {
147                            #field_struct::get()
148                        }
149
150                        pub fn #setter_name(&self, value: #field_ty) {
151                            #field_struct::set(value);
152                        }
153                    }
154                }
155                FieldType::Mapping => {
156                    let (key_ty, value_ty) = extract_mapping_types(field_ty).unwrap();
157                    quote! {
158                        pub fn #field_name(&self, key: #key_ty) -> #value_ty {
159                            #field_struct::get(key)
160                        }
161
162                        pub fn #setter_name(&self, key: #key_ty, value: #value_ty) {
163                            #field_struct::set(key, value);
164                        }
165                    }
166                }
167                FieldType::DoubleKeyMapping => {
168                    let (key1_ty, key2_ty, value_ty) = extract_double_key_mapping_types(field_ty).unwrap();
169                    quote! {
170                        pub fn #field_name(&self, key1: #key1_ty, key2: #key2_ty) -> #value_ty {
171                            #field_struct::get(key1, key2)
172                        }
173
174                        pub fn #setter_name(&self, key1: #key1_ty, key2: #key2_ty, value: #value_ty) {
175                            #field_struct::set(key1, key2, value);
176                        }
177                    }
178                }
179                FieldType::Unknown => {
180                    syn::Error::new_spanned(field_ty, "Unsupported storage type").to_compile_error()
181                }
182            }
183        }).collect();
184
185        let expanded = quote! {
186            use zink::Asm;
187            #(#field_structs)*
188            impl #impl_generics #struct_name #ty_generics #where_clause {
189                #(#method_impls)*
190            }
191        };
192
193        TokenStream::from(expanded)
194    }
195}
196
197trait ToUpperCamelCase {
198    fn to_upper_camel_case(&self) -> String;
199}
200
201impl ToUpperCamelCase for Ident {
202    fn to_upper_camel_case(&self) -> String {
203        let s = self.to_string();
204        let mut result = String::new();
205        let mut capitalize_next = true;
206
207        for c in s.chars() {
208            if c == '_' {
209                capitalize_next = true;
210            } else if capitalize_next {
211                result.push(c.to_ascii_uppercase());
212                capitalize_next = false;
213            } else {
214                result.push(c);
215            }
216        }
217        result
218    }
219}
220
/// Storage flavour chosen for a struct field by `classify_field_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FieldType {
    /// Any path type other than `Mapping`/`DoubleKeyMapping`: one value per slot.
    Simple,
    /// A `Mapping<K, V>` field.
    Mapping,
    /// A `DoubleKeyMapping<K1, K2, V>` field.
    DoubleKeyMapping,
    /// A type the derive cannot store (non-path types such as references or tuples).
    Unknown,
}
227
228fn classify_field_type(ty: &Type) -> FieldType {
229    if let Type::Path(type_path) = ty {
230        let path = &type_path.path;
231        if let Some(segment) = path.segments.last() {
232            match segment.ident.to_string().as_str() {
233                "Mapping" => FieldType::Mapping,
234                "DoubleKeyMapping" => FieldType::DoubleKeyMapping,
235                _ => FieldType::Simple,
236            }
237        } else {
238            FieldType::Unknown
239        }
240    } else {
241        FieldType::Unknown
242    }
243}
244
245/// Extract generic types from Mapping<K, V>
246fn extract_mapping_types(ty: &Type) -> Option<(Type, Type)> {
247    if let Type::Path(type_path) = ty {
248        if let Some(segment) = type_path.path.segments.last() {
249            if segment.ident == "Mapping" {
250                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
251                    let args: Vec<_> = args.args.iter().collect();
252                    if args.len() == 2 {
253                        if let (
254                            syn::GenericArgument::Type(key_ty),
255                            syn::GenericArgument::Type(value_ty),
256                        ) = (&args[0], &args[1])
257                        {
258                            return Some((key_ty.clone(), value_ty.clone()));
259                        }
260                    }
261                }
262            }
263        }
264    }
265    None
266}
267
268/// Extract generic types from DoubleKeyMapping<K1, K2, V>
269fn extract_double_key_mapping_types(ty: &Type) -> Option<(Type, Type, Type)> {
270    if let Type::Path(type_path) = ty {
271        if let Some(segment) = type_path.path.segments.last() {
272            if segment.ident == "DoubleKeyMapping" {
273                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
274                    let args: Vec<_> = args.args.iter().collect();
275                    if args.len() == 3 {
276                        if let (
277                            syn::GenericArgument::Type(key1_ty),
278                            syn::GenericArgument::Type(key2_ty),
279                            syn::GenericArgument::Type(value_ty),
280                        ) = (&args[0], &args[1], &args[2])
281                        {
282                            return Some((key1_ty.clone(), key2_ty.clone(), value_ty.clone()));
283                        }
284                    }
285                }
286            }
287        }
288    }
289    None
290}