zink_abi_macro/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::{format_ident, quote};
5use std::fs;
6use std::path::Path;
7use syn::{parse_macro_input, Error};
8use zint::Contract;
9
10/// A struct to represent the function in an ERC ABI
11#[derive(serde::Deserialize, Debug)]
12struct AbiFunction {
13    name: String,
14    #[serde(default)]
15    inputs: Vec<AbiParameter>,
16    #[serde(default)]
17    outputs: Vec<AbiParameter>,
18    #[serde(default)]
19    state_mutability: String,
20    #[serde(default)]
21    constant: Option<bool>,
22    #[serde(rename = "type")]
23    fn_type: String,
24}
25
26/// A struct to represent a parameter in an ERC ABI
27#[derive(serde::Deserialize, Debug)]
28struct AbiParameter {
29    #[serde(default)]
30    name: String,
31    #[serde(rename = "type")]
32    param_type: String,
33    #[serde(default)]
34    _components: Option<Vec<AbiParameter>>,
35    #[serde(default)]
36    _indexed: Option<bool>,
37}
38
39/// Represents an Ethereum ABI
40#[derive(serde::Deserialize, Debug)]
41struct EthereumAbi {
42    #[serde(default)]
43    abi: Vec<AbiFunction>,
44}
45
46/// Maps Solidity types to Rust types and handles encoding/decoding
47fn map_type_to_rust_and_encode(solidity_type: &str) -> proc_macro2::TokenStream {
48    match solidity_type {
49        "uint256" | "int256" => quote! { ::zink::primitives::u256::U256 },
50        "uint8" | "int8" => quote! { u8 },
51        "uint16" | "int16" => quote! { u16 },
52        "uint32" | "int32" => quote! { u32 },
53        "uint64" | "int64" => quote! { u64 },
54        "uint128" | "int128" => quote! { u128 },
55        "bool" => quote! { bool },
56        "address" => quote! { ::zink::primitives::address::Address },
57        "string" => quote! { String },
58        "bytes" => quote! { Vec<u8> },
59        // Handle arrays, e.g., uint256[]
60        t if t.ends_with("[]") => {
61            let inner_type = &t[..t.len() - 2];
62            let rust_inner_type = map_type_to_rust_and_encode(inner_type);
63            quote! { Vec<#rust_inner_type> }
64        }
65        // Handle fixed size arrays, e.g., uint256[10]
66        t if t.contains('[') && t.ends_with(']') => {
67            let bracket_pos = t.find('[').unwrap();
68            let inner_type = &t[..bracket_pos];
69            let rust_inner_type = map_type_to_rust_and_encode(inner_type);
70            quote! { Vec<#rust_inner_type> }
71        }
72        // Default to bytes for any other type
73        _ => quote! { Vec<u8> },
74    }
75}
76
77/// Generate a function signature for an ABI function
78fn generate_function_signature(func: &AbiFunction) -> proc_macro2::TokenStream {
79    let fn_name = format_ident!("{}", func.name.to_case(Case::Snake));
80
81    // Generate function parameters
82    let mut params = quote! { &self };
83    for input in &func.inputs {
84        let param_name = if input.name.is_empty() {
85            format_ident!("arg{}", input.name.len())
86        } else {
87            format_ident!("{}", input.name.to_case(Case::Snake))
88        };
89
90        let param_type = map_type_to_rust_and_encode(&input.param_type);
91        params = quote! { #params, #param_name: #param_type };
92    }
93
94    // Generate function return type
95    let return_type = if func.outputs.is_empty() {
96        quote! { () }
97    } else if func.outputs.len() == 1 {
98        let output_type = map_type_to_rust_and_encode(&func.outputs[0].param_type);
99        quote! { #output_type }
100    } else {
101        let output_types = func
102            .outputs
103            .iter()
104            .map(|output| map_type_to_rust_and_encode(&output.param_type))
105            .collect::<Vec<_>>();
106        quote! { (#(#output_types),*) }
107    };
108
109    quote! {
110        pub fn #fn_name(#params) -> ::std::result::Result<#return_type, &'static str>
111    }
112}
113
114/// Generate the implementation for a contract function
115fn generate_function_implementation(func: &AbiFunction) -> proc_macro2::TokenStream {
116    let fn_signature = generate_function_signature(func);
117    let fn_name = &func.name;
118    let is_view = func.state_mutability == "view"
119        || func.state_mutability == "pure"
120        || func.constant.unwrap_or(false);
121
122    // Generate parameter names for encoding
123    let param_names = func
124        .inputs
125        .iter()
126        .enumerate()
127        .map(|(i, input)| {
128            if input.name.is_empty() {
129                format_ident!("arg{}", i)
130            } else {
131                format_ident!("{}", input.name.to_case(Case::Snake))
132            }
133        })
134        .collect::<Vec<_>>();
135
136    // Generate function selector calculation
137    let selector_str = format!(
138        "{}({})",
139        fn_name,
140        func.inputs
141            .iter()
142            .map(|i| i.param_type.clone())
143            .collect::<Vec<_>>()
144            .join(",")
145    );
146
147    // Determine which method to call (view_call or call)
148    let call_method = if is_view {
149        format_ident!("view_call")
150    } else {
151        format_ident!("call")
152    };
153
154    // Generate parameter encoding for each input
155    let param_encoding = if param_names.is_empty() {
156        quote! {
157            // No parameters to encode
158        }
159    } else {
160        let encoding_statements = param_names.iter().map(|param_name| {
161            let param_type = func
162                .inputs
163                .iter()
164                .find(|input| {
165                    if input.name.is_empty() {
166                        format_ident!("arg{}", input.name.len()) == *param_name
167                    } else {
168                        format_ident!("{}", input.name.to_case(Case::Snake)) == *param_name
169                    }
170                })
171                .map(|input| input.param_type.as_str())
172                .unwrap_or("unknown");
173
174            match param_type {
175                "address" => quote! {
176                    call_data.extend_from_slice(&zabi::encode_address(#param_name.as_bytes()));
177                },
178                "uint256" | "int256" => quote! {
179                    call_data.extend_from_slice(&zabi::encode_u256(#param_name.as_bytes()));
180                },
181                _ => quote! {
182                    call_data.extend_from_slice(&zabi::encode(#param_name));
183                },
184            }
185        });
186
187        quote! {
188            #(#encoding_statements)*
189        }
190    };
191
192    // Generate result decoding based on outputs
193    let result_decoding = if func.outputs.is_empty() {
194        quote! {
195            Ok(())
196        }
197    } else if func.outputs.len() == 1 {
198        let output_type = &func.outputs[0].param_type;
199        match output_type.as_str() {
200            "uint8" => quote! {
201                let decoded = zabi::decode::<u8>(&result)?;
202                Ok(decoded)
203            },
204            "uint256" | "int256" => {
205                quote! {
206                    let decoded_bytes = zabi::decode_u256(&result)?;
207                    Ok(::zink::primitives::u256::U256::from_be_bytes(decoded_bytes))
208                }
209            }
210            "bool" => quote! {
211                let decoded = zabi::decode::<bool>(&root)?;
212                Ok(decoded)
213            },
214            "string" => quote! {
215                let decoded = zabi::decode::<String>(&result)?;
216                Ok(decoded)
217            },
218            "address" => quote! {
219                let decoded_bytes = zabi::decode_address(&result)?;
220                Ok(::zink::primitives::address::Address::from(decoded_bytes))
221            },
222            _ => quote! {
223                // Default fallback for unknown types
224                Err("Unsupported return type")
225            },
226        }
227    } else {
228        quote! {
229            Err("Multiple return values not yet supported")
230        }
231    };
232
233    // Calculate the function selector using tiny-keccak directly
234    quote! {
235        #fn_signature {
236            let mut hasher = tiny_keccak::Keccak::v256();
237            let mut selector = [0u8; 4];
238            let signature = #selector_str;
239            hasher.update(signature.as_bytes());
240            let mut hash = [0u8; 32];
241            hasher.finalize(&mut hash);
242            selector.copy_from_slice(&hash[0..4]);
243
244            // Encode function parameters
245            let mut call_data = selector.to_vec();
246
247            #param_encoding
248
249            // Execute the call
250            let result = self.#call_method(&call_data)?;
251
252            // Decode the result
253            #result_decoding
254        }
255    }
256}
257
258/// The `import!` macro generates a Rust struct and implementation for interacting with an Ethereum
259/// smart contract based on its ABI (Application Binary Interface) and deploys the corresponding
260/// contract.
261///
262/// # Parameters
263/// - `abi_path`: A string literal specifying the path to the ABI JSON file (e.g., `"examples/ERC20.json"`).
264/// - `contract_name` (optional): A string literal specifying the name of the contract source file (e.g., `"my_erc20"`)
265///   without the `.rs` extension. If omitted, defaults to the base name of the ABI file (e.g., `"ERC20"` for `"ERC20.json"`).
266///   The file must be located in the `examples` directory or a configured search path.
267///
268/// # Generated Code
269/// The macro generates a struct named after the ABI file's base name (e.g., `ERC20` for `"ERC20.json"`) with:
270/// - An `address` field of type `::zink::primitives::address::Address` to hold the contract address.
271/// - An `evm` field of type `::zint::revm::EVM<'static>` to manage the EVM state.
272/// - A `new` method that deploys the specified contract and initializes the EVM.
273/// - Methods for each function in the ABI, which encode parameters, call the contract, and decode the results.
274///
275/// # Example
276/// ```rust
277/// #[cfg(feature = "abi-import")]
278/// use zink::import;
279///
280/// #[cfg(test)]
281/// mod tests {
282///     use zink::primitives::address::Address;
283///     use zint::revm;
284///
285///     #[test]
286///     fn test_contract() -> anyhow::Result<()> {
287///         #[cfg(feature = "abi-import")]
288///         {
289///             // Single argument: uses default contract name "ERC20"
290///             import!("examples/ERC20.json");
291///             let contract_address = Address::from(revm::CONTRACT);
292///             let token = ERC20::new(contract_address);
293///             let decimals = token.decimals()?;
294///             assert_eq!(decimals, 18);
295///
296///             // Two arguments: specifies custom contract name "my_erc20"
297///             import!("examples/ERC20.json", "my_erc20");
298///             let token = MyERC20::new(contract_address);
299///             let decimals = token.decimals()?;
300///             assert_eq!(decimals, 8);
301///         }
302///         Ok(())
303///     }
304/// }
305/// ```
306///
307/// # Requirements
308/// - The `abi-import` feature must be enabled (`--features abi-import`).
309/// - For `wasm32` targets, the `wasm-alloc` feature must be enabled (`--features wasm-alloc`) to provide a global allocator (`dlmalloc`).
310///
311/// # Notes
312/// - The contract file (defaulting to the ABI base name or specified by `contract_name`) must exist and be compilable by `zint::Contract::search`.
313/// - The EVM state is initialized with a default account (`ALICE`) and deploys the contract on `new`.
314#[proc_macro]
315pub fn import(input: TokenStream) -> TokenStream {
316    // Parse the input as a tuple of (abi_path) or (abi_path, contract_name)
317    let input = parse_macro_input!(input as syn::ExprTuple);
318    let (abi_path, contract_name) = match input.elems.len() {
319        1 => {
320            let abi_path = if let syn::Expr::Lit(syn::ExprLit {
321                lit: syn::Lit::Str(lit_str),
322                ..
323            }) = &input.elems[0]
324            {
325                lit_str.value()
326            } else {
327                return Error::new(
328                    Span::call_site(),
329                    "First argument must be a string literal for ABI path",
330                )
331                .to_compile_error()
332                .into();
333            };
334            let file_name = Path::new(&abi_path)
335                .file_stem()
336                .and_then(|s| s.to_str())
337                .unwrap_or("Contract")
338                .to_string();
339            (abi_path, file_name)
340        }
341        2 => {
342            let abi_path = if let syn::Expr::Lit(syn::ExprLit {
343                lit: syn::Lit::Str(lit_str),
344                ..
345            }) = &input.elems[0]
346            {
347                lit_str.value()
348            } else {
349                return Error::new(
350                    Span::call_site(),
351                    "First argument must be a string literal for ABI path",
352                )
353                .to_compile_error()
354                .into();
355            };
356            let contract_name = if let syn::Expr::Lit(syn::ExprLit {
357                lit: syn::Lit::Str(lit_str),
358                ..
359            }) = &input.elems[1]
360            {
361                lit_str.value()
362            } else {
363                return Error::new(
364                    Span::call_site(),
365                    "Second argument must be a string literal for contract name",
366                )
367                .to_compile_error()
368                .into();
369            };
370            (abi_path, contract_name)
371        }
372        _ => {
373            return Error::new(Span::call_site(), "import! macro expects one or two arguments: (abi_path) or (abi_path, contract_name)")
374                .to_compile_error()
375                .into();
376        }
377    };
378
379    // Attempt to locate the contract file using zint::Contract::search
380    let _contract = match Contract::search(&contract_name) {
381        Ok(contract) => contract,
382        Err(e) => {
383            return Error::new(
384                Span::call_site(),
385                format!(
386                    "Failed to find or compile contract '{}': {}",
387                    contract_name, e
388                ),
389            )
390            .to_compile_error()
391            .into();
392        }
393    };
394
395    let abi_content = match fs::read_to_string(&abi_path) {
396        Ok(content) => content,
397        Err(e) => {
398            return Error::new(Span::call_site(), format!("Failed to read ABI file: {}", e))
399                .to_compile_error()
400                .into()
401        }
402    };
403
404    // Parse the ABI JSON
405    let abi: EthereumAbi = match serde_json::from_str(&abi_content) {
406        Ok(abi) => abi,
407        Err(e) => {
408            return Error::new(
409                Span::call_site(),
410                format!("Failed to parse ABI JSON: {}", e),
411            )
412            .to_compile_error()
413            .into()
414        }
415    };
416
417    let file_name = std::path::Path::new(&abi_path)
418        .file_stem()
419        .and_then(|s| s.to_str())
420        .unwrap_or("Contract");
421
422    let struct_name = format_ident!("{}", file_name);
423
424    // Generate function implementations
425    let function_impls = abi
426        .abi
427        .iter()
428        .filter(|func| func.fn_type == "function")
429        .map(generate_function_implementation)
430        .collect::<Vec<_>>();
431
432    let expanded = quote! {
433        pub struct #struct_name {
434            address: ::zink::primitives::address::Address,
435            evm: ::zint::revm::EVM<'static>,
436        }
437
438        impl #struct_name {
439            pub fn new(address: ::zink::primitives::address::Address) -> Self {
440                use ::zint::revm;
441                use ::zink::primitives::address::Address;
442                use ::zink::primitives::u256::U256;
443                use ::zint::Contract;
444
445                let mut evm = revm::EVM::default();
446                // Initialize ALICE account with maximum balance
447                evm.db_mut().insert_account_info(
448                    revm::primitives::Address::from(Address::from(revm::ALICE)),
449                    revm::primitives::AccountInfo::from_balance(U256::MAX),
450                );
451                // Compile and deploy the specified contract
452                let contract = Contract::search(#contract_name).expect("Contract not found");
453                let bytecode = contract.compile().expect("Compilation failed").bytecode().expect("No bytecode").to_vec();
454                let deployed = evm.contract(&bytecode).deploy(&bytecode).expect("Deploy failed");
455                evm = deployed.evm;
456                evm.commit(true); // Commit the deployment
457
458                // Runtime check to ensure the contract is valid
459                if bytecode.is_empty() {
460                        panic!("Contract deployment failed: no bytecode generated");
461                    }
462
463                // Initialize ALICE's balance
464                let mut evm = evm.caller(revm::ALICE);
465                let storage_key = ::zink::storage::Mapping::<Address, U256>::storage_key(Address::from(revm::ALICE));
466                let initial_balance = U256::from(1000); // Set ALICE's balance to 1000 tokens
467                evm.db_mut().insert_storage(
468                    *address.as_bytes(),
469                    storage_key,
470                    initial_balance,
471                );
472                Self { address, evm }
473            }
474
475            fn view_call(&self, data: &[u8]) -> ::std::result::Result<Vec<u8>, &'static str> {
476                self.evm
477                    .clone()
478                    .calldata(data)
479                    .call(*self.address.as_bytes())
480                    .map(|info| info.ret)
481                    .map_err(|_| "View call failed")
482            }
483
484            fn call(&self, data: &[u8]) -> ::std::result::Result<Vec<u8>, &'static str> {
485                self.evm
486                    .clone()
487                    .calldata(data)
488                    .call(*self.address.as_bytes())
489                    .map(|info| info.ret)
490                    .map_err(|_| "Call failed")
491            }
492
493            #(#function_impls)*
494        }
495    };
496
497    expanded.into()
498}