zink/num/safe.rs

1//! Numeric primitives
2
3use crate::{asm, primitives::U256};
4
/// Arithmetic on numeric primitives that reverts instead of silently wrapping.
///
/// Implementors expose their representable range through [`SafeNumeric::max`] and
/// [`SafeNumeric::min`]; the `safe_*` operations are expected to revert when the exact
/// result does not fit in that range (or, for division, when the divisor is zero).
pub trait SafeNumeric
where
    Self: Copy + PartialOrd + Sized,
{
    /// Largest value representable by the implementing type.
    fn max() -> Self;
    /// Smallest value representable by the implementing type.
    fn min() -> Self;

    /// `self + rhs`, reverting on overflow.
    fn safe_add(self, rhs: Self) -> Self;
    /// `self - rhs`, reverting on overflow.
    fn safe_sub(self, rhs: Self) -> Self;
    /// `self * rhs`, reverting on overflow.
    fn safe_mul(self, rhs: Self) -> Self;
    /// `self / rhs`, reverting on division by zero (and on overflow where applicable).
    fn safe_div(self, rhs: Self) -> Self;
}
15
/// Calls `crate::asm::ext::revert1` with the given message expression.
///
/// NOTE(review): presumably this performs an EVM revert carrying `$reason`; whether
/// `revert1` diverges is not visible here, and the callers in this file still produce a
/// value after invoking the macro, so they do not depend on it diverging.
macro_rules! local_revert {
    ($reason:expr) => {
        unsafe {
            crate::asm::ext::revert1($reason);
        }
    };
}
23
/// Implements [`SafeNumeric`] for signed primitive integers (`i8`, `i16`, `i32`, `i64`).
///
/// Each operation is evaluated with the matching `wrapping_*` method, and the wrapped result
/// is then checked against the operands' signs. Any overflow invokes `local_revert!` with a
/// message naming the operation.
macro_rules! impl_safe_numeric_signed {
    ($($t:ty);* $(;)?) => {
        $(
            impl SafeNumeric for $t {
                #[inline(always)]
                fn max() -> Self { <$t>::MAX }
                #[inline(always)]
                fn min() -> Self { <$t>::MIN }

                #[inline(always)]
                fn safe_add(self, rhs: Self) -> Self {
                    let result = self.wrapping_add(rhs);
                    // Only same-sign operands can overflow; the wrapped result then lands on
                    // the wrong side of `self`.
                    if (self > 0 && rhs > 0 && result < self)
                        || (self < 0 && rhs < 0 && result > self)
                    {
                        local_revert!("addition overflow");
                    }
                    result
                }

                #[inline(always)]
                fn safe_sub(self, rhs: Self) -> Self {
                    let result = self.wrapping_sub(rhs);
                    // Subtracting a positive value must not increase `self`, and subtracting
                    // a negative value must not decrease it. Both directions are checked so
                    // that underflow (e.g. `MIN - 1`) is caught as well as overflow.
                    if (rhs > 0 && result > self) || (rhs < 0 && result < self) {
                        local_revert!("subtraction overflow");
                    }
                    result
                }

                #[inline(always)]
                fn safe_mul(self, rhs: Self) -> Self {
                    let result = self.wrapping_mul(rhs);
                    // `MIN * -1` is the one overflow the division round-trip cannot check:
                    // it would evaluate `MIN / -1`, which itself panics. Handle `rhs == -1`
                    // explicitly; for every other non-zero `rhs`, dividing the wrapped
                    // product back recovers `self` only when no overflow occurred.
                    let overflow = if rhs == -1 {
                        self == <$t>::MIN
                    } else {
                        rhs != 0 && result / rhs != self
                    };
                    if overflow {
                        local_revert!("multiplication overflow");
                    }
                    result
                }

                #[inline(always)]
                fn safe_div(self, rhs: Self) -> Self {
                    if rhs == 0 {
                        local_revert!("division by zero");
                    }
                    // `wrapping_div` avoids the built-in `MIN / -1` panic; that case is
                    // reported as an overflow revert instead.
                    let result = self.wrapping_div(rhs);
                    if self == <Self as SafeNumeric>::min() && rhs == -1 {
                        local_revert!("division overflow");
                    }
                    result
                }
            }
        )*
    };
}
77
/// Implements [`SafeNumeric`] for unsigned primitive integers (`u8`, `u16`, `u32`, `u64`).
///
/// Results are produced with `wrapping_*` arithmetic. A wrapped result is recognized by
/// comparing it back against the operands, and in that case `local_revert!` is invoked.
macro_rules! impl_safe_numeric_unsigned {
    ($($ty:ty);* $(;)?) => {
        $(
            impl SafeNumeric for $ty {
                #[inline(always)]
                fn max() -> Self {
                    <$ty>::MAX
                }

                #[inline(always)]
                fn min() -> Self {
                    <$ty>::MIN
                }

                #[inline(always)]
                fn safe_add(self, rhs: Self) -> Self {
                    let sum = self.wrapping_add(rhs);
                    // A wrapped unsigned sum always ends up below the left operand.
                    if self > sum {
                        local_revert!("addition overflow");
                    }
                    sum
                }

                #[inline(always)]
                fn safe_sub(self, rhs: Self) -> Self {
                    let difference = self.wrapping_sub(rhs);
                    // Borrowing past zero wraps around to a value above the minuend.
                    if self < difference {
                        local_revert!("subtraction overflow");
                    }
                    difference
                }

                #[inline(always)]
                fn safe_mul(self, rhs: Self) -> Self {
                    let product = self.wrapping_mul(rhs);
                    // Dividing an exact product by `rhs` recovers `self`; a wrapped one does not.
                    let wrapped = rhs != 0 && product / rhs != self;
                    if wrapped {
                        local_revert!("multiplication overflow");
                    }
                    product
                }

                #[inline(always)]
                fn safe_div(self, rhs: Self) -> Self {
                    let divisor_is_zero = rhs == 0;
                    if divisor_is_zero {
                        local_revert!("division by zero");
                    }
                    self / rhs
                }
            }
        )*
    };
}
126
127// U256 special case
128impl SafeNumeric for U256 {
129    #[inline(always)]
130    fn max() -> Self {
131        unsafe { asm::ext::u256_max() }
132    }
133    #[inline(always)]
134    fn min() -> Self {
135        U256::empty()
136    }
137
138    #[inline(always)]
139    fn safe_add(self, rhs: Self) -> Self {
140        let result = unsafe { asm::ext::u256_add(self, rhs) };
141        if result < self {
142            local_revert!("addition overflow");
143        }
144        result
145    }
146
147    #[inline(always)]
148    fn safe_sub(self, rhs: Self) -> Self {
149        let result = unsafe { asm::ext::u256_sub(self, rhs) };
150        if result > self {
151            local_revert!("subtraction overflow");
152        }
153        result
154    }
155
156    #[inline(always)]
157    fn safe_mul(self, rhs: Self) -> Self {
158        let max = Self::max();
159        let result = unsafe { asm::ext::u256_mulmod(self, rhs, max) };
160        // Check if result exceeds max when rhs > 1
161        if rhs > Self::min() && result > self && result > rhs && result > max - self {
162            local_revert!("multiplication overflow");
163        }
164        result
165    }
166
167    #[inline(always)]
168    fn safe_div(self, rhs: Self) -> Self {
169        if rhs == Self::min() {
170            local_revert!("division by zero");
171        }
172        unsafe { asm::ext::u256_div(self, rhs) }
173    }
174}
175
176impl_safe_numeric_signed! {
177    i8;
178    i16;
179    i32;
180    i64;
181}
182
183impl_safe_numeric_unsigned! {
184    u8;
185    u16;
186    u32;
187    u64;
188}