zingen/jump/
mod.rs

//! Jump table implementation.
//!
//! This module defines the `Jump` enum and re-exports the `JumpTable` struct (defined in the
//! `table` submodule). Together they manage the jumps in the program: jumps to labels, calls
//! to internal functions, and calls to external functions.
6
7use crate::codegen::ExtFunc;
8use core::fmt::Display;
9pub use table::JumpTable;
10
11mod pc;
12mod relocate;
13mod table;
14mod target;
15
16/// Represents the different types of jumps in the program.
17#[derive(Clone, Debug, PartialEq, Eq)]
18pub enum Jump {
19    /// Jump to a specific label, which corresponds to the original program counter.
20    Label(u16),
21    /// Jump to a function identified by its index.
22    Func(u32),
23    /// Jump to an external function.
24    ExtFunc(ExtFunc),
25}
26
27impl Display for Jump {
28    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
29        match self {
30            Jump::Label(offset) => write!(f, "Label(0x{offset:x})"),
31            Jump::Func(index) => write!(f, "Func({index})"),
32            Jump::ExtFunc(_) => write!(f, "ExtFunc"),
33        }
34    }
35}
36
37impl Jump {
38    /// Checks if the target is a label.
39    pub fn is_label(&self) -> bool {
40        matches!(self, Jump::Label { .. })
41    }
42
43    /// Checks if the target is a function call.
44    pub fn is_call(&self) -> bool {
45        !self.is_label()
46    }
47}
48
#[cfg(test)]
mod tests {
    use crate::jump::{Jump, JumpTable};
    use smallvec::smallvec;

    /// Installs a compact, timestamp-free `tracing` subscriber whose filter is read
    /// from the environment (`EnvFilter::from_default_env`, i.e. `RUST_LOG`).
    ///
    /// `try_init().ok()` discards the error returned when a global subscriber is
    /// already installed, so several tests may call this without panicking.
    // `allow(unused)`: keeps the helper warning-free even when no test calls it,
    // so it can be dropped into any test while debugging.
    #[allow(unused)]
    fn init_tracing() {
        tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .without_time()
            .compact()
            .try_init()
            .ok();
    }

    /// Shifts the targets of `table`, relocates into a buffer, and asserts that the
    /// resulting buffer length equals the table's maximum target *after* shifting.
    ///
    /// The buffer is created with the *pre-shift* maximum target as its length, so
    /// whenever shifting increases the maximum target the final assertion also
    /// requires `relocate` to have grown the buffer to match.
    fn assert_target_shift_vs_relocation(mut table: JumpTable) -> anyhow::Result<()> {
        // Zero-filled initial buffer, sized to the maximum target *before* shifting.
        let mut buffer = smallvec![0; table.max_target() as usize];

        // Perform target shifts
        table.shift_targets()?;

        // Maximum target after shifting; this is the length expected after relocation.
        let max_target = table.max_target();

        // Perform relocation
        table.relocate(&mut buffer)?;

        assert_eq!(buffer.len(), max_target as usize);
        Ok(())
    }

    /// Two independent forward jumps to distinct labels.
    #[test]
    fn test_target_shift_vs_relocation() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Register jumps with known offsets and labels
        table.register(0x10, Jump::Label(0x20)); // Jump to label at 0x20
        table.register(0x30, Jump::Label(0x40)); // Jump to label at 0x40

        assert_target_shift_vs_relocation(table)
    }

    /// Two jumps from different call sites that share the same label target (0x100).
    #[test]
    fn test_multiple_internal_calls() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Simulate multiple functions calling _approve
        table.register(0x10, Jump::Label(0x100)); // approve() -> _approve
        table.register(0x20, Jump::Label(0x100)); // spend_allowance() -> _approve

        assert_target_shift_vs_relocation(table)
    }

    /// Three jumps whose targets are all above 0xff must each be emitted as PUSH2 (0x61).
    ///
    /// The expected opcode positions equal each jump's original pc plus 3 bytes per
    /// earlier jump: 0x100 + 0, 0x110 + 3 = 0x113, and 0x200 + 2 * 3 = 0x206.
    #[test]
    fn test_nested_function_calls() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Simulate ERC20's approve -> _approve call chain
        table.register(0x100, Jump::Label(0x200)); // approve entry
        table.register(0x110, Jump::Label(0x300)); // approve -> _approve
        table.register(0x200, Jump::Label(0x400)); // _approve entry

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Check if all jumps use correct PUSH instructions
        assert_eq!(buffer[0x100], 0x61); // PUSH2
        assert_eq!(buffer[0x113], 0x61); // PUSH2
        assert_eq!(buffer[0x206], 0x61); // PUSH2

        Ok(())
    }

    /// A label jump (pc 0x10 -> label 0x12) directly followed by a call (pc 0x11 ->
    /// function 1, registered at 0x317).
    ///
    /// Both resolved targets are expected 5 bytes past their original values
    /// (0x12 -> 0x17, 0x317 -> 0x31c). This is consistent with a 2-byte PUSH1 and a
    /// 3-byte PUSH2 being inserted before both targets. The call's operand is checked
    /// big-endian at 0x14..=0x15.
    #[test]
    fn test_label_call_interaction() -> anyhow::Result<()> {
        init_tracing();
        let mut table = JumpTable::default();

        // Function index 1 -> pc 0x317 (presumably the function's start pc; verify
        // against `JumpTable::func`).
        table.func.insert(1, 0x317);
        table.label(0x10, 0x12);
        table.call(0x11, 1);

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        assert_eq!(buffer[0x11], 0x17, "{buffer:?}");
        assert_eq!(buffer[0x14], 0x03, "{buffer:?}");
        assert_eq!(buffer[0x15], 0x1c, "{buffer:?}");
        Ok(())
    }

    /// Mixes a target that fits in one byte (0x80) with targets that need two
    /// (0x100, 0x1000).
    ///
    /// All three targets lie after every registered jump, so each is expected to be
    /// shifted by 2 + 3 + 3 = 8 bytes: 0x80 -> 0x88, 0x100 -> 0x108, 0x1000 -> 0x1008.
    #[test]
    fn test_large_target_offset_calculation() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Register a jump with target < 0xff
        table.register(0x10, Jump::Label(0x80));

        // Register a jump with target > 0xff
        table.register(0x20, Jump::Label(0x100));

        // Register a jump with target > 0xfff
        table.register(0x30, Jump::Label(0x1000));

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Check if offsets are correctly calculated
        // For target 0x80: PUSH1 (1 byte) + target (1 byte)
        // For target 0x100: PUSH2 (1 byte) + target (2 bytes)
        // For target 0x1000: PUSH2 (1 byte) + target (2 bytes)
        assert_eq!(buffer[0x11], 0x88); // Small target
        assert_eq!(buffer[0x23], 0x01); // First byte of large target
        assert_eq!(buffer[0x24], 0x08); // Second byte of large target
        assert_eq!(buffer[0x36], 0x10); // First byte of large target
        assert_eq!(buffer[0x37], 0x08); // Second byte of large target

        Ok(())
    }

    /// Twenty consecutive jumps (pcs 0x10..=0x23) to labels spaced 0x20 apart,
    /// starting at 0x100.
    ///
    /// Every jump is expected to be PUSH2, so every target is shifted by 20 * 3 bytes
    /// and the i-th jump's opcode moves forward by i * 3 bytes (see `last_idx`). The
    /// bare `assert_eq!`s on constants document how the expected bytes were derived.
    #[test]
    fn test_sequential_large_jumps() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Register multiple sequential jumps with increasing targets
        // This mirrors the ERC20 pattern where we have many functions
        for i in 0..20 {
            let target = 0x100 + (i * 0x20);
            table.register(0x10 + i, Jump::Label(target));
        }

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Check first jump (should use PUSH2)
        assert_eq!(buffer[0x10], 0x61); // PUSH2
        assert_eq!(buffer[0x11], 0x01); // First byte
        assert_eq!(buffer[0x12], 0x3c); // Second byte
        assert_eq!(0x013c, 0x100 + 20 * 3);

        // Check last jump (should still use PUSH2 but with adjusted offset)
        let last_idx = 0x10 + 19 + 19 * 3;
        assert_eq!(buffer[last_idx], 0x61); // PUSH2
        assert_eq!(buffer[last_idx + 1], 0x03); // First byte should be larger
        assert_eq!(buffer[last_idx + 2], 0x9c); // Second byte accounts for all previous jumps
        assert_eq!(0x039c, 0x100 + 0x20 * 19 + 20 * 3);

        Ok(())
    }

    /// Mimics a selector dispatcher. For each of 5 selectors it registers a check
    /// jump at `check_pc` to `check_pc + 0x10`, and at that location a jump to the
    /// function body at `target_pc`.
    ///
    /// The verification loop recomputes where each jump lands after relocation. It
    /// accumulates the bytes inserted by earlier jumps in `total_offset`: the check
    /// jump's push size plus 3 for the function jump, which is asserted to be PUSH2.
    #[test]
    fn test_dispatcher_jump_targets() -> anyhow::Result<()> {
        let mut table = JumpTable::default();
        let selectors = 5;

        // Register jumps for each selector check
        for i in 0..selectors {
            let i = i as u16;
            let check_pc = 0x10 + i * 0x20;
            let target_pc = 0x100 + i * 0x40;

            // Register both the comparison jump and function jump
            table.register(check_pc, Jump::Label(check_pc + 0x10));
            table.register(check_pc + 0x10, Jump::Label(target_pc));
        }

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Verify each selector's jump chain
        let mut total_offset = 0;
        for i in 0..selectors {
            // Relocated position of this selector's check jump.
            let check_pc = 0x10 + i * 0x20 + total_offset;
            // Encoded size (opcode + operand) assumed for the check jump.
            // NOTE(review): sized from `check_pc + 0x10`, while the opcode assertion
            // below uses `func_pc`; confirm both mirror the table's PUSH-size choice.
            let check_pc_offset = if check_pc + 0x10 > 0xff { 3 } else { 2 };

            // Relocated position of the function jump that follows the check target.
            let func_pc = check_pc + 0x10 + check_pc_offset;

            let check_jump = buffer[check_pc];
            let func_jump = buffer[func_pc];

            // 0x60 = PUSH1, 0x61 = PUSH2 (EVM opcodes).
            assert_eq!(check_jump, if func_pc > 0xff { 0x61 } else { 0x60 });
            assert_eq!(func_jump, 0x61);

            // Update total offset for next iteration
            total_offset += check_pc_offset + 3;
        }

        Ok(())
    }
}