// zink/primitives/numeric.rs
use crate::{ffi, primitives::U256};
/// Modular arithmetic on numeric values.
///
/// The implementations in this module forward to `ffi` routines (see
/// `impl_numeric!`) rather than computing the result in Rust. Behavior for
/// `n == 0` and the precision of the intermediate result are whatever the
/// backing routine provides.
pub trait Numeric: Copy {
    /// Modular addition: `self + other`, reduced modulo `n`.
    fn addmod(self, other: Self, n: Self) -> Self;

    /// Modular multiplication: `self * other`, reduced modulo `n`.
    fn mulmod(self, other: Self, n: Self) -> Self;
}
/// Bounds-aware arithmetic that aborts instead of silently wrapping.
///
/// The implementations in this module are meant to call `local_revert!`
/// (which forwards to `revert1`) when a result does not fit in the type or
/// when dividing by zero, rather than returning a wrapped value.
pub trait SafeNumeric: Copy + PartialOrd + Sized {
    /// The largest value the type can represent.
    fn max() -> Self;
    /// The smallest value the type can represent.
    fn min() -> Self;

    /// `self + rhs`, reverting if the sum does not fit.
    fn safe_add(self, rhs: Self) -> Self;
    /// `self - rhs`, reverting if the difference does not fit.
    fn safe_sub(self, rhs: Self) -> Self;
    /// `self * rhs`, reverting if the product does not fit.
    fn safe_mul(self, rhs: Self) -> Self;
    /// `self / rhs`, reverting on a zero divisor (and, for signed types,
    /// on the unrepresentable quotient `MIN / -1`).
    fn safe_div(self, rhs: Self) -> Self;
}
/// Aborts the current call with `$reason` via `revert1`.
///
/// On `wasm32` this calls `crate::ffi::asm::revert1`, which requires
/// `unsafe`; on every other target it calls `crate::ffi::asm::asm::revert1`,
/// which is called without `unsafe`.
///
/// The expansion is two cfg-gated statements, and their order matters. The
/// block-like `unsafe { .. }` arm must come first because it needs no
/// trailing `;`. The plain call arm then picks up the `;` written after
/// `local_revert!(..)` at the use site.
macro_rules! local_revert {
    ($reason:expr) => {
        #[cfg(target_arch = "wasm32")]
        unsafe { crate::ffi::asm::revert1($reason) }
        #[cfg(not(target_arch = "wasm32"))]
        crate::ffi::asm::asm::revert1($reason)
    };
}
/// Implements `Numeric` by forwarding `addmod` / `mulmod` to `ffi` routines.
///
/// Two invocation forms are accepted:
///
/// * `impl_numeric! { U256, addmod_fn, mulmod_fn }` implements `Numeric` for
///   `U256`. It calls `ffi::addmod_fn` / `ffi::mulmod_fn` inside `unsafe` on
///   every target.
/// * `impl_numeric! { ty, addmod_fn, mulmod_fn; ... }` covers each listed type.
///   On `wasm32` it calls `ffi::asm::addmod_fn` (inside `unsafe`); elsewhere
///   it calls `ffi::asm::asm::addmod_fn`. `mulmod` is handled the same way.
///
/// The `U256` arm must be listed first. `macro_rules!` tries arms in order,
/// and `U256` is also a valid `$t:ty`, so with the generic arm first the
/// `U256` arm could never match. `U256` must also be passed on its own;
/// putting it inside a multi-type list routes it through the generic arm.
///
/// Every routine receives its arguments as `(n, other, self)`, the reverse of
/// the trait signature. NOTE(review): presumably this is the operand order the
/// host primitive expects; confirm against `ffi` before changing it.
macro_rules! impl_numeric {
    (U256, $addmod_fn:ident, $mulmod_fn:ident $(;)?) => {
        impl Numeric for U256 {
            #[inline(always)]
            fn addmod(self, other: Self, n: Self) -> Self {
                unsafe { ffi::$addmod_fn(n, other, self) }
            }

            #[inline(always)]
            fn mulmod(self, other: Self, n: Self) -> Self {
                unsafe { ffi::$mulmod_fn(n, other, self) }
            }
        }
    };
    ($($t:ty, $addmod_fn:ident, $mulmod_fn:ident);* $(;)?) => {
        $(
            impl Numeric for $t {
                #[inline(always)]
                fn addmod(self, other: Self, n: Self) -> Self {
                    #[cfg(target_arch = "wasm32")]
                    unsafe { ffi::asm::$addmod_fn(n, other, self) }
                    #[cfg(not(target_arch = "wasm32"))]
                    ffi::asm::asm::$addmod_fn(n, other, self)
                }

                #[inline(always)]
                fn mulmod(self, other: Self, n: Self) -> Self {
                    #[cfg(target_arch = "wasm32")]
                    unsafe { ffi::asm::$mulmod_fn(n, other, self) }
                    #[cfg(not(target_arch = "wasm32"))]
                    ffi::asm::asm::$mulmod_fn(n, other, self)
                }
            }
        )*
    };
}
/// Implements `SafeNumeric` for the listed signed primitive integer types.
///
/// Each operation first computes the wrapping result. It then infers from the
/// operand signs whether that result overflowed, and calls `local_revert!`
/// with a fixed message when it did.
macro_rules! impl_safe_numeric_signed {
    ($($t:ty);* $(;)?) => {
        $(
            impl SafeNumeric for $t {
                #[inline(always)]
                fn max() -> Self { <$t>::MAX }
                #[inline(always)]
                fn min() -> Self { <$t>::MIN }

                #[inline(always)]
                fn safe_add(self, rhs: Self) -> Self {
                    let result = self.wrapping_add(rhs);
                    // Only operands of the same sign can overflow. A wrapped
                    // sum then lands on the wrong side of `self`.
                    if (self > 0 && rhs > 0 && result < self)
                        || (self < 0 && rhs < 0 && result > self)
                    {
                        local_revert!("addition overflow");
                    }
                    result
                }

                #[inline(always)]
                fn safe_sub(self, rhs: Self) -> Self {
                    let result = self.wrapping_sub(rhs);
                    // Subtracting a negative value must move `self` up, and
                    // subtracting a positive value must move it down. A
                    // wrapped result moves the other way. Both directions
                    // need checking: `MAX - (-1)` wraps to `MIN`, and
                    // `MIN - 1` wraps to `MAX`.
                    if (rhs < 0 && result < self) || (rhs > 0 && result > self) {
                        local_revert!("subtraction overflow");
                    }
                    result
                }

                #[inline(always)]
                fn safe_mul(self, rhs: Self) -> Self {
                    let result = self.wrapping_mul(rhs);
                    // A product that did not wrap divides back to `self`.
                    // That division would itself overflow (and panic) for
                    // `MIN / -1`, which only arises from `MIN * -1`. Since
                    // that product is an overflow anyway, the pair is
                    // rejected before dividing.
                    if (rhs == -1 && self == <$t>::MIN)
                        || (rhs != 0 && result / rhs != self)
                    {
                        local_revert!("multiplication overflow");
                    }
                    result
                }

                #[inline(always)]
                fn safe_div(self, rhs: Self) -> Self {
                    if rhs == 0 {
                        local_revert!("division by zero");
                    }
                    let result = self.wrapping_div(rhs);
                    // `MIN / -1` is the only quotient that does not fit, and
                    // `wrapping_div` yields `MIN` for it, so reject it here.
                    if self == <Self as SafeNumeric>::min() && rhs == -1 {
                        local_revert!("division overflow");
                    }
                    result
                }
            }
        )*
    };
}
/// Implements `SafeNumeric` for the listed unsigned primitive integer types.
///
/// Overflow is detected after the fact from the wrapping result:
///
/// * an unsigned sum can only become smaller than `self` by wrapping;
/// * a difference can only become larger than `self` by wrapping;
/// * a wrapped product no longer divides back to `self`.
macro_rules! impl_safe_numeric_unsigned {
    ($($ty:ty);* $(;)?) => {$(
        impl SafeNumeric for $ty {
            #[inline(always)]
            fn max() -> Self {
                Self::MAX
            }

            #[inline(always)]
            fn min() -> Self {
                Self::MIN
            }

            #[inline(always)]
            fn safe_add(self, rhs: Self) -> Self {
                let sum = self.wrapping_add(rhs);
                if self > sum {
                    local_revert!("addition overflow");
                }
                sum
            }

            #[inline(always)]
            fn safe_sub(self, rhs: Self) -> Self {
                let difference = self.wrapping_sub(rhs);
                if self < difference {
                    local_revert!("subtraction overflow");
                }
                difference
            }

            #[inline(always)]
            fn safe_mul(self, rhs: Self) -> Self {
                let product = self.wrapping_mul(rhs);
                let overflowed = rhs != 0 && product / rhs != self;
                if overflowed {
                    local_revert!("multiplication overflow");
                }
                product
            }

            #[inline(always)]
            fn safe_div(self, rhs: Self) -> Self {
                if rhs == 0 {
                    local_revert!("division by zero");
                }
                self / rhs
            }
        }
    )*};
}
169
170impl SafeNumeric for U256 {
172 #[inline(always)]
173 fn max() -> Self {
174 unsafe { ffi::u256_max() }
175 }
176 #[inline(always)]
177 fn min() -> Self {
178 U256::empty()
179 }
180
181 #[inline(always)]
182 fn safe_add(self, rhs: Self) -> Self {
183 let result = unsafe { ffi::u256_add(self, rhs) };
184 if result < self {
185 local_revert!("addition overflow");
186 }
187 result
188 }
189
190 #[inline(always)]
191 fn safe_sub(self, rhs: Self) -> Self {
192 let result = unsafe { ffi::u256_sub(self, rhs) };
193 if result > self {
194 local_revert!("subtraction overflow");
195 }
196 result
197 }
198
199 #[inline(always)]
200 fn safe_mul(self, rhs: Self) -> Self {
201 let max = Self::max();
202 let result = unsafe { ffi::u256_mulmod(self, rhs, max) };
203 if rhs > Self::min() && result > self && result > rhs && result > max - self {
205 local_revert!("multiplication overflow");
206 }
207 result
208 }
209
210 #[inline(always)]
211 fn safe_div(self, rhs: Self) -> Self {
212 if rhs == Self::min() {
213 local_revert!("division by zero");
214 }
215 unsafe { ffi::u256_div(self, rhs) }
216 }
217}
218
219impl_numeric! {
220 i8, addmod_i8, mulmod_i8;
221 u8, addmod_u8, mulmod_u8;
222 i16, addmod_i16, mulmod_i16;
223 u16, addmod_u16, mulmod_u16;
224 i32, addmod_i32, mulmod_i32;
225 u32, addmod_u32, mulmod_u32;
226 i64, addmod_i64, mulmod_i64;
227 u64, addmod_u64, mulmod_u64;
228}
229
230impl_safe_numeric_signed! {
231 i8;
232 i16;
233 i32;
234 i64;
235}
236
237impl_safe_numeric_unsigned! {
238 u8;
239 u16;
240 u32;
241 u64;
242}