zink/primitives/
numeric.rs

1use crate::{ffi, primitives::U256};
2
/// Modular arithmetic: every operation is reduced by a caller-supplied modulus `n`.
///
/// Implementations are free to choose how a zero modulus is handled.
pub trait Numeric
where
    Self: Copy,
{
    /// Modular addition: `self + other`, reduced modulo `n`.
    fn addmod(self, other: Self, n: Self) -> Self;

    /// Modular multiplication: `self * other`, reduced modulo `n`.
    fn mulmod(self, other: Self, n: Self) -> Self;
}
8
/// Bounded arithmetic intended to revert instead of silently wrapping.
///
/// `max`/`min` expose the representable range of the implementing type. Each `safe_*`
/// operation is meant to stop execution when its true result would fall outside that range
/// (or, for division, when the divisor is zero).
pub trait SafeNumeric
where
    Self: Copy + PartialOrd + Sized,
{
    /// Largest representable value.
    fn max() -> Self;
    /// Smallest representable value.
    fn min() -> Self;

    /// Checked `self + rhs`.
    fn safe_add(self, rhs: Self) -> Self;
    /// Checked `self - rhs`.
    fn safe_sub(self, rhs: Self) -> Self;
    /// Checked `self * rhs`.
    fn safe_mul(self, rhs: Self) -> Self;
    /// Checked `self / rhs`.
    fn safe_div(self, rhs: Self) -> Self;
}
19
// Issues a revert carrying the message `$msg` through `revert1`.
//
// On `wasm32` the call goes to `crate::ffi::asm::revert1` inside an `unsafe` block. On any
// other target it goes to `crate::ffi::asm::asm::revert1`, which is called without `unsafe`.
// Exactly one of the two arms survives `cfg` stripping, so `$msg` is evaluated once.
macro_rules! local_revert {
    ($msg:expr) => {{
        #[cfg(not(target_arch = "wasm32"))]
        {
            crate::ffi::asm::asm::revert1($msg);
        }
        #[cfg(target_arch = "wasm32")]
        unsafe {
            crate::ffi::asm::revert1($msg);
        }
    }};
}
30
// Implements `Numeric` by forwarding to modular-arithmetic intrinsics.
//
// Accepted forms:
// - `U256, addmod_fn, mulmod_fn` forwards to `ffi::addmod_fn` / `ffi::mulmod_fn`.
// - `ty, addmod_fn, mulmod_fn; ...` forwards, per type, to `ffi::asm::addmod_fn` on
//   `wasm32` (inside `unsafe`), and to `ffi::asm::asm::addmod_fn` elsewhere.
//
// Every intrinsic is called as `(n, other, self)`, i.e. modulus first and `self` last.
macro_rules! impl_numeric {
    // Special case for U256. This arm must come before the generic one. Macro arms are
    // tried in order, and `U256` also parses as a `ty`. If the generic arm came first, it
    // would always match and this arm would be unreachable.
    (U256, $addmod_fn:ident, $mulmod_fn:ident $(;)?) => {
        impl Numeric for U256 {
            #[inline(always)]
            fn addmod(self, other: Self, n: Self) -> Self {
                unsafe { ffi::$addmod_fn(n, other, self) }
            }
            #[inline(always)]
            fn mulmod(self, other: Self, n: Self) -> Self {
                unsafe { ffi::$mulmod_fn(n, other, self) }
            }
        }
    };
    ($($t:ty, $addmod_fn:ident, $mulmod_fn:ident);* $(;)?) => {
        $(
            impl Numeric for $t {
                #[inline(always)]
                fn addmod(self, other: Self, n: Self) -> Self {
                    #[cfg(target_arch = "wasm32")]
                    unsafe { ffi::asm::$addmod_fn(n, other, self) }
                    #[cfg(not(target_arch = "wasm32"))]
                    ffi::asm::asm::$addmod_fn(n, other, self)
                }
                #[inline(always)]
                fn mulmod(self, other: Self, n: Self) -> Self {
                    #[cfg(target_arch = "wasm32")]
                    unsafe { ffi::asm::$mulmod_fn(n, other, self) }
                    #[cfg(not(target_arch = "wasm32"))]
                    ffi::asm::asm::$mulmod_fn(n, other, self)
                }
            }
        )*
    };
}
66
// Signed types (i8, i16, i32, i64)
//
// Each operation computes the wrapping result and then inspects it with plain comparisons
// to decide whether the true result was out of range.
macro_rules! impl_safe_numeric_signed {
    ($($t:ty);* $(;)?) => {
        $(
            impl SafeNumeric for $t {
                #[inline(always)]
                fn max() -> Self { <$t>::MAX }
                #[inline(always)]
                fn min() -> Self { <$t>::MIN }

                /// `self + rhs`, reverting with "addition overflow" if the sum is out of range.
                #[inline(always)]
                fn safe_add(self, rhs: Self) -> Self {
                    let result = self.wrapping_add(rhs);
                    // Only same-signed operands can overflow. A wrapped sum then lands on the
                    // wrong side of `self`.
                    if (self > 0 && rhs > 0 && result < self)
                        || (self < 0 && rhs < 0 && result > self)
                    {
                        local_revert!("addition overflow");
                    }
                    result
                }

                /// `self - rhs`, reverting with "subtraction overflow" if the difference is
                /// out of range.
                #[inline(always)]
                fn safe_sub(self, rhs: Self) -> Self {
                    let result = self.wrapping_sub(rhs);
                    // Subtracting a negative value must not decrease `self`, and subtracting a
                    // positive value must not increase it. A wrapped difference moves the other
                    // way. Both directions are checked: `MIN - 1` wraps to `MAX` and
                    // `MAX - (-1)` wraps to `MIN`.
                    if (rhs < 0 && result < self) || (rhs > 0 && result > self) {
                        local_revert!("subtraction overflow");
                    }
                    result
                }

                /// `self * rhs`, reverting with "multiplication overflow" if the product is
                /// out of range.
                #[inline(always)]
                fn safe_mul(self, rhs: Self) -> Self {
                    let result = self.wrapping_mul(rhs);
                    // `MIN * -1` wraps back to `MIN`, and checking it by dividing `MIN` by `-1`
                    // would itself overflow (a builtin panic, not a revert). Reject that pair
                    // up front. For every other pair, an un-wrapped product divides back to
                    // `self` exactly.
                    if (self == <$t>::MIN && rhs == -1)
                        || (rhs != 0 && result / rhs != self)
                    {
                        local_revert!("multiplication overflow");
                    }
                    result
                }

                /// `self / rhs` (truncating). Reverts with "division by zero" if `rhs == 0`,
                /// and with "division overflow" for `MIN / -1`.
                #[inline(always)]
                fn safe_div(self, rhs: Self) -> Self {
                    if rhs == 0 {
                        local_revert!("division by zero");
                    }
                    // Checked before dividing so the overflowing quotient is never computed.
                    if self == <$t>::MIN && rhs == -1 {
                        local_revert!("division overflow");
                    }
                    self.wrapping_div(rhs)
                }
            }
        )*
    };
}
120
// Unsigned types (u8, u16, u32, u64)
//
// For unsigned integers a wrapped result either lands on the wrong side of `self` or fails
// to divide back to it; each check below looks for exactly that.
macro_rules! impl_safe_numeric_unsigned {
    ($($t:ty);* $(;)?) => {$(
        impl SafeNumeric for $t {
            /// Upper bound of the type.
            #[inline(always)]
            fn max() -> Self {
                <$t>::MAX
            }

            /// Lower bound of the type.
            #[inline(always)]
            fn min() -> Self {
                <$t>::MIN
            }

            /// `self + rhs`; reverts with "addition overflow" when the sum wraps.
            #[inline(always)]
            fn safe_add(self, rhs: Self) -> Self {
                let sum = self.wrapping_add(rhs);
                // A wrapped sum is strictly smaller than `self`.
                let overflowed = sum < self;
                if overflowed {
                    local_revert!("addition overflow");
                }
                sum
            }

            /// `self - rhs`; reverts with "subtraction overflow" when `rhs > self`.
            #[inline(always)]
            fn safe_sub(self, rhs: Self) -> Self {
                let difference = self.wrapping_sub(rhs);
                // A wrapped difference is strictly larger than `self`.
                let underflowed = difference > self;
                if underflowed {
                    local_revert!("subtraction overflow");
                }
                difference
            }

            /// `self * rhs`; reverts with "multiplication overflow" when the product wraps.
            #[inline(always)]
            fn safe_mul(self, rhs: Self) -> Self {
                let product = self.wrapping_mul(rhs);
                // An un-wrapped product divides back to `self` exactly.
                let wrapped = rhs != 0 && product / rhs != self;
                if wrapped {
                    local_revert!("multiplication overflow");
                }
                product
            }

            /// `self / rhs`; reverts with "division by zero" when `rhs` is zero.
            #[inline(always)]
            fn safe_div(self, rhs: Self) -> Self {
                let divisor_is_zero = rhs == 0;
                if divisor_is_zero {
                    local_revert!("division by zero");
                }
                self / rhs
            }
        }
    )*};
}
169
170// U256 special case
171impl SafeNumeric for U256 {
172    #[inline(always)]
173    fn max() -> Self {
174        unsafe { ffi::u256_max() }
175    }
176    #[inline(always)]
177    fn min() -> Self {
178        U256::empty()
179    }
180
181    #[inline(always)]
182    fn safe_add(self, rhs: Self) -> Self {
183        let result = unsafe { ffi::u256_add(self, rhs) };
184        if result < self {
185            local_revert!("addition overflow");
186        }
187        result
188    }
189
190    #[inline(always)]
191    fn safe_sub(self, rhs: Self) -> Self {
192        let result = unsafe { ffi::u256_sub(self, rhs) };
193        if result > self {
194            local_revert!("subtraction overflow");
195        }
196        result
197    }
198
199    #[inline(always)]
200    fn safe_mul(self, rhs: Self) -> Self {
201        let max = Self::max();
202        let result = unsafe { ffi::u256_mulmod(self, rhs, max) };
203        // Check if result exceeds max when rhs > 1
204        if rhs > Self::min() && result > self && result > rhs && result > max - self {
205            local_revert!("multiplication overflow");
206        }
207        result
208    }
209
210    #[inline(always)]
211    fn safe_div(self, rhs: Self) -> Self {
212        if rhs == Self::min() {
213            local_revert!("division by zero");
214        }
215        unsafe { ffi::u256_div(self, rhs) }
216    }
217}
218
219impl_numeric! {
220    i8, addmod_i8, mulmod_i8;
221    u8, addmod_u8, mulmod_u8;
222    i16, addmod_i16, mulmod_i16;
223    u16, addmod_u16, mulmod_u16;
224    i32, addmod_i32, mulmod_i32;
225    u32, addmod_u32, mulmod_u32;
226    i64, addmod_i64, mulmod_i64;
227    u64, addmod_u64, mulmod_u64;
228}
229
230impl_safe_numeric_signed! {
231    i8;
232    i16;
233    i32;
234    i64;
235}
236
237impl_safe_numeric_unsigned! {
238    u8;
239    u16;
240    u32;
241    u64;
242}