zingen/jump/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
//! Jump table implementation.
//!
//! This module defines the `Jump` enum and re-exports the `JumpTable` struct, which are used to
//! manage the various types of jumps in the program: labels, function calls, and external
//! functions.

use crate::codegen::ExtFunc;
use core::fmt::Display;
pub use table::JumpTable;

mod pc;
mod relocate;
mod table;
mod target;

/// Represents the different types of jumps in the program.
///
/// Each variant names one kind of jump destination and carries the value that
/// identifies it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Jump {
    /// Jump to a specific label, which corresponds to the original program counter.
    ///
    /// The `Display` implementation renders this value in hexadecimal.
    Label(u16),
    /// Jump to a function identified by its index.
    Func(u32),
    /// Jump to an external function.
    ///
    /// The payload is the `ExtFunc` type imported from `crate::codegen`.
    ExtFunc(ExtFunc),
}

impl Display for Jump {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Jump::Label(offset) => write!(f, "Label(0x{offset:x})"),
            Jump::Func(index) => write!(f, "Func({index})"),
            Jump::ExtFunc(_) => write!(f, "ExtFunc"),
        }
    }
}

impl Jump {
    /// Checks if the target is a label.
    pub fn is_label(&self) -> bool {
        matches!(self, Jump::Label { .. })
    }

    /// Checks if the target is a function call.
    pub fn is_call(&self) -> bool {
        !self.is_label()
    }
}

#[cfg(test)]
mod tests {
    use crate::jump::{Jump, JumpTable};
    use smallvec::smallvec;

    /// Installs a compact, timestamp-free `tracing` subscriber filtered by the
    /// environment. Failure to install is ignored.
    #[allow(unused)]
    fn init_tracing() {
        let builder = tracing_subscriber::fmt();
        let env_filter = tracing_subscriber::EnvFilter::from_default_env();
        let configured = builder
            .with_env_filter(env_filter)
            .without_time()
            .compact();

        // The result is deliberately discarded, exactly as a trailing `.ok()` would.
        let _ = configured.try_init();
    }

    /// Shared check used by several tests: after shifting targets and relocating, the
    /// buffer must be exactly as long as the table's largest post-shift target.
    fn assert_target_shift_vs_relocation(mut table: JumpTable) -> anyhow::Result<()> {
        // Size the scratch buffer from the largest target known before any shifting.
        let initial_len = table.max_target() as usize;
        let mut bytecode = smallvec![0; initial_len];

        // Shift first, then sample the maximum again: this is the length the
        // relocated buffer is expected to end up with.
        table.shift_targets()?;
        let expected_len = table.max_target() as usize;

        // Relocation operates on the buffer in place.
        table.relocate(&mut bytecode)?;

        assert_eq!(bytecode.len(), expected_len);
        Ok(())
    }

    #[test]
    fn test_target_shift_vs_relocation() -> anyhow::Result<()> {
        // Two independent jumps, written as (jump pc, destination label).
        let jumps = [(0x10, 0x20), (0x30, 0x40)];

        let mut table = JumpTable::default();
        for (pc, label) in jumps {
            table.register(pc, Jump::Label(label));
        }

        assert_target_shift_vs_relocation(table)
    }

    #[test]
    fn test_multiple_internal_calls() -> anyhow::Result<()> {
        // Label shared by every caller; stands in for `_approve`.
        const APPROVE_IMPL: u16 = 0x100;

        let mut table = JumpTable::default();

        // Callers of `_approve`: `approve()` at 0x10 and `spend_allowance()` at 0x20.
        for caller_pc in [0x10, 0x20] {
            table.register(caller_pc, Jump::Label(APPROVE_IMPL));
        }

        assert_target_shift_vs_relocation(table)
    }

    #[test]
    fn test_nested_function_calls() -> anyhow::Result<()> {
        // ERC20-style call chain, written as (jump pc, destination label):
        // approve entry, approve -> _approve, _approve entry.
        let chain = [(0x100, 0x200), (0x110, 0x300), (0x200, 0x400)];

        let mut table = JumpTable::default();
        for (pc, label) in chain {
            table.register(pc, Jump::Label(label));
        }

        let mut code = smallvec![0; table.max_target() as usize];
        table.relocate(&mut code)?;

        // Every jump must be encoded with PUSH2 (0x61).
        // NOTE(review): the indices 0x113 and 0x206 presumably account for 3 bytes
        // inserted by each earlier jump — confirm against `relocate`.
        assert_eq!(code[0x100], 0x61);
        assert_eq!(code[0x113], 0x61);
        assert_eq!(code[0x206], 0x61);

        Ok(())
    }

    /// Relocates a table holding one label jump and one function call and checks the
    /// bytes written for both.
    #[test]
    fn test_label_call_interaction() -> anyhow::Result<()> {
        init_tracing();
        let mut table = JumpTable::default();

        // Function 1 is recorded at 0x317; the label jump sits at pc 0x10 (to 0x12)
        // and the call to function 1 sits at pc 0x11.
        table.func.insert(1, 0x317);
        table.label(0x10, 0x12);
        table.call(0x11, 1);

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // NOTE(review): 0x17 == 0x12 + 5 and 0x031c == 0x317 + 5; the 5 is presumably the
        // bytes inserted for the two jumps (2 + 3) — confirm against `relocate`.
        // The buffer is dumped in the failure message to ease debugging.
        assert_eq!(buffer[0x11], 0x17, "{buffer:?}");
        assert_eq!(buffer[0x14], 0x03, "{buffer:?}");
        assert_eq!(buffer[0x15], 0x1c, "{buffer:?}");
        Ok(())
    }

    /// Checks the bytes written for jump targets of three different magnitudes
    /// (0x80, 0x100 and 0x1000).
    #[test]
    fn test_large_target_offset_calculation() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Register a jump with target < 0xff
        table.register(0x10, Jump::Label(0x80));

        // Register a jump with target > 0xff
        table.register(0x20, Jump::Label(0x100));

        // Register a jump with target > 0xfff
        table.register(0x30, Jump::Label(0x1000));

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Check if offsets are correctly calculated
        // For target 0x80: PUSH1 (1 byte) + target (1 byte)
        // For target 0x100: PUSH2 (1 byte) + target (2 bytes)
        // For target 0x1000: PUSH2 (1 byte) + target (2 bytes)
        // NOTE(review): every expected target is the registered one plus 8, presumably the
        // total bytes inserted by the three jumps (2 + 3 + 3) — confirm against `relocate`.
        assert_eq!(buffer[0x11], 0x88); // Small target: 0x80 + 8
        assert_eq!(buffer[0x23], 0x01); // High byte of 0x0108
        assert_eq!(buffer[0x24], 0x08); // Low byte of 0x0108
        assert_eq!(buffer[0x36], 0x10); // High byte of 0x1008
        assert_eq!(buffer[0x37], 0x08); // Low byte of 0x1008

        Ok(())
    }

    #[test]
    fn test_sequential_large_jumps() -> anyhow::Result<()> {
        let mut table = JumpTable::default();

        // Twenty back-to-back jumps (pc 0x10, 0x11, ...) whose targets start at 0x100
        // and grow by 0x20 each, mirroring an ERC20-style contract with many functions.
        for step in 0..20u16 {
            table.register(0x10 + step, Jump::Label(0x100 + step * 0x20));
        }

        let mut code = smallvec![0; table.max_target() as usize];
        table.relocate(&mut code)?;

        // First jump: PUSH2 followed by the two target bytes 0x01 0x3c.
        let first_idx = 0x10;
        assert_eq!(code[first_idx], 0x61); // PUSH2
        assert_eq!(code[first_idx + 1], 0x01); // High byte
        assert_eq!(code[first_idx + 2], 0x3c); // Low byte
        // Sanity check on the constant: original target plus 3 bytes per jump.
        assert_eq!(0x013c, 0x100 + 20 * 3);

        // Last jump: still PUSH2, located after the 19 earlier jumps (1 + 3 bytes each).
        let last_idx = 0x10 + 19 + 19 * 3;
        assert_eq!(code[last_idx], 0x61); // PUSH2
        assert_eq!(code[last_idx + 1], 0x03); // High byte
        assert_eq!(code[last_idx + 2], 0x9c); // Low byte, includes all inserted bytes
        assert_eq!(0x039c, 0x100 + 0x20 * 19 + 20 * 3);

        Ok(())
    }

    /// Registers a dispatcher-like pattern (two chained jumps per selector) and checks
    /// the PUSH opcode found at each relocated jump position.
    #[test]
    fn test_dispatcher_jump_targets() -> anyhow::Result<()> {
        let mut table = JumpTable::default();
        let selectors = 5;

        // Register jumps for each selector check
        for i in 0..selectors {
            // `selectors` is also used for buffer indexing below, so narrow it here
            // for the `u16` label arithmetic.
            let i = i as u16;
            let check_pc = 0x10 + i * 0x20;
            let target_pc = 0x100 + i * 0x40;

            // Register both the comparison jump and function jump
            table.register(check_pc, Jump::Label(check_pc + 0x10));
            table.register(check_pc + 0x10, Jump::Label(target_pc));
        }

        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;

        // Verify each selector's jump chain
        // `total_offset` accumulates the bytes the test expects earlier jumps to have
        // added, so positions are computed in relocated coordinates.
        let mut total_offset = 0;
        for i in 0..selectors {
            let check_pc = 0x10 + i * 0x20 + total_offset;
            // Expected size of the comparison jump: 3 bytes once its target exceeds
            // 0xff, otherwise 2.
            let check_pc_offset = if check_pc + 0x10 > 0xff { 3 } else { 2 };

            let func_pc = check_pc + 0x10 + check_pc_offset;

            let check_jump = buffer[check_pc];
            let func_jump = buffer[func_pc];

            // 0x61 is PUSH2 and 0x60 is PUSH1; the function jump is always expected to
            // be PUSH2.
            // NOTE(review): this opcode check keys on `func_pc` while the size above keys
            // on `check_pc + 0x10` — confirm the two conditions are meant to differ.
            assert_eq!(check_jump, if func_pc > 0xff { 0x61 } else { 0x60 });
            assert_eq!(func_jump, 0x61);

            // Update total offset for next iteration
            total_offset += check_pc_offset + 3;
        }

        Ok(())
    }
}