1use crate::{asm, primitives::U256};
4
/// Overflow-aware arithmetic shared by the numeric types in this module.
///
/// Every method takes its operands by value (implementors must be `Copy`) and
/// returns a value of the same type. How an out-of-range result is reported is
/// left to each implementation.
pub trait SafeNumeric: Sized + Copy + PartialOrd {
    /// Returns the largest value the implementing type can represent.
    fn max() -> Self;

    /// Returns the smallest value the implementing type can represent.
    fn min() -> Self;

    /// Computes `self + rhs`, checking for overflow.
    fn safe_add(self, rhs: Self) -> Self;

    /// Computes `self - rhs`, checking for overflow.
    fn safe_sub(self, rhs: Self) -> Self;

    /// Computes `self * rhs`, checking for overflow.
    fn safe_mul(self, rhs: Self) -> Self;

    /// Computes `self / rhs`, checking for an invalid divisor or overflow.
    fn safe_div(self, rhs: Self) -> Self;
}
15
/// Expands to an `unsafe` block that calls `crate::asm::ext::revert1` with the
/// given message expression.
///
/// NOTE(review): `revert1` is presumably a diverging "abort execution" helper,
/// but neither its signature nor its safety contract is visible from this
/// file — confirm before relying on code after the call being unreachable.
macro_rules! local_revert {
    ($reason:expr) => {
        unsafe {
            crate::asm::ext::revert1($reason);
        }
    };
}
23
/// Implements [`SafeNumeric`] for each listed built-in signed integer type.
///
/// The input is a `;`-separated list of types with an optional trailing `;`.
///
/// Every operation is computed with the standard `overflowing_*` method, which
/// yields the two's-complement wrapped result together with an overflow flag.
/// When the flag is set, `local_revert!` is invoked with a message naming the
/// operation. The wrapped result remains the function's value afterwards, so
/// the generated code type-checks whether or not the revert helper returns.
macro_rules! impl_safe_numeric_signed {
    ($($t:ty);* $(;)?) => {
        $(
            impl SafeNumeric for $t {
                #[inline(always)]
                fn max() -> Self { <$t>::MAX }
                #[inline(always)]
                fn min() -> Self { <$t>::MIN }

                /// `self + rhs`; reverts with "addition overflow" when the sum
                /// leaves the representable range in either direction.
                #[inline(always)]
                fn safe_add(self, rhs: Self) -> Self {
                    let (result, overflowed) = self.overflowing_add(rhs);
                    if overflowed {
                        local_revert!("addition overflow");
                    }
                    result
                }

                /// `self - rhs`; reverts with "subtraction overflow" when the
                /// difference leaves the representable range.
                #[inline(always)]
                fn safe_sub(self, rhs: Self) -> Self {
                    // The flag covers both directions: a negative `rhs` pushing
                    // the result above MAX and a positive `rhs` pushing it
                    // below MIN (e.g. `MIN - 1`).
                    let (result, overflowed) = self.overflowing_sub(rhs);
                    if overflowed {
                        local_revert!("subtraction overflow");
                    }
                    result
                }

                /// `self * rhs`; reverts with "multiplication overflow" when the
                /// product leaves the representable range.
                #[inline(always)]
                fn safe_mul(self, rhs: Self) -> Self {
                    // Verifying by dividing the wrapped product back would
                    // itself panic for `MIN * -1` (`MIN / -1` overflows), so the
                    // overflow flag is used instead.
                    let (result, overflowed) = self.overflowing_mul(rhs);
                    if overflowed {
                        local_revert!("multiplication overflow");
                    }
                    result
                }

                /// `self / rhs`; reverts with "division by zero" when `rhs` is
                /// zero and with "division overflow" for `MIN / -1`.
                #[inline(always)]
                fn safe_div(self, rhs: Self) -> Self {
                    if rhs == 0 {
                        local_revert!("division by zero");
                    }
                    // `MIN / -1` is the only signed division that overflows; it
                    // is reported through the flag rather than panicking.
                    let (result, overflowed) = self.overflowing_div(rhs);
                    if overflowed {
                        local_revert!("division overflow");
                    }
                    result
                }
            }
        )*
    };
}
77
/// Implements [`SafeNumeric`] for each listed built-in unsigned integer type.
///
/// The input is a `;`-separated list of types with an optional trailing `;`.
/// Each arithmetic method produces the wrapped result and invokes
/// `local_revert!` with a message naming the operation when that result does
/// not equal the mathematical one.
macro_rules! impl_safe_numeric_unsigned {
    ($($t:ty);* $(;)?) => {
        $(
            impl SafeNumeric for $t {
                #[inline(always)]
                fn max() -> Self { Self::MAX }
                #[inline(always)]
                fn min() -> Self { Self::MIN }

                /// `self + rhs`; reverts with "addition overflow" on wrap-around.
                #[inline(always)]
                fn safe_add(self, rhs: Self) -> Self {
                    let (sum, wrapped) = self.overflowing_add(rhs);
                    if wrapped {
                        local_revert!("addition overflow");
                    }
                    sum
                }

                /// `self - rhs`; reverts with "subtraction overflow" when
                /// `rhs` exceeds `self`.
                #[inline(always)]
                fn safe_sub(self, rhs: Self) -> Self {
                    let (difference, wrapped) = self.overflowing_sub(rhs);
                    if wrapped {
                        local_revert!("subtraction overflow");
                    }
                    difference
                }

                /// `self * rhs`; reverts with "multiplication overflow" on
                /// wrap-around.
                #[inline(always)]
                fn safe_mul(self, rhs: Self) -> Self {
                    let (product, wrapped) = self.overflowing_mul(rhs);
                    if wrapped {
                        local_revert!("multiplication overflow");
                    }
                    product
                }

                /// `self / rhs`; reverts with "division by zero" when `rhs`
                /// is zero.
                #[inline(always)]
                fn safe_div(self, rhs: Self) -> Self {
                    if rhs == 0 {
                        local_revert!("division by zero");
                    }
                    self.wrapping_div(rhs)
                }
            }
        )*
    };
}
126
127impl SafeNumeric for U256 {
129 #[inline(always)]
130 fn max() -> Self {
131 unsafe { asm::ext::u256_max() }
132 }
133 #[inline(always)]
134 fn min() -> Self {
135 U256::empty()
136 }
137
138 #[inline(always)]
139 fn safe_add(self, rhs: Self) -> Self {
140 let result = unsafe { asm::ext::u256_add(self, rhs) };
141 if result < self {
142 local_revert!("addition overflow");
143 }
144 result
145 }
146
147 #[inline(always)]
148 fn safe_sub(self, rhs: Self) -> Self {
149 let result = unsafe { asm::ext::u256_sub(self, rhs) };
150 if result > self {
151 local_revert!("subtraction overflow");
152 }
153 result
154 }
155
156 #[inline(always)]
157 fn safe_mul(self, rhs: Self) -> Self {
158 let max = Self::max();
159 let result = unsafe { asm::ext::u256_mulmod(self, rhs, max) };
160 if rhs > Self::min() && result > self && result > rhs && result > max - self {
162 local_revert!("multiplication overflow");
163 }
164 result
165 }
166
167 #[inline(always)]
168 fn safe_div(self, rhs: Self) -> Self {
169 if rhs == Self::min() {
170 local_revert!("division by zero");
171 }
172 unsafe { asm::ext::u256_div(self, rhs) }
173 }
174}
175
176impl_safe_numeric_signed! {
177 i8;
178 i16;
179 i32;
180 i64;
181}
182
183impl_safe_numeric_unsigned! {
184 u8;
185 u16;
186 u32;
187 u64;
188}