// zingen/jump/table.rs
//! Jump Table
//!
//! This module defines the `JumpTable` struct, which manages the jump table, function
//! table, and code section. It provides methods to register jumps, functions, and
//! labels, as well as to merge jump tables.

use crate::{codegen::ExtFunc, jump::Jump, Code, Error, Result};
use std::collections::BTreeMap;

/// Bookkeeping for every jump emitted so far, together with the function
/// offsets and the code section those jumps may refer to.
#[derive(Debug, Clone, Default)]
pub struct JumpTable {
    /// Jumps keyed by the program counter they were registered at.
    pub(crate) jump: BTreeMap<u16, Jump>,
    /// Offsets keyed by function index.
    pub(crate) func: BTreeMap<u32, u16>,
    /// Code section that belongs to this table.
    pub(crate) code: Code,
}

impl JumpTable {
    /// Registers a function in the jump table.
    ///
    /// This function associates a program counter with a function.
    pub fn call(&mut self, pc: u16, func: u32) {
        self.jump.insert(pc, Jump::Func(func));
    }

    /// Registers a program counter to the function table.
    ///
    /// This function associates a function with a specific offset in the function table.
    pub fn call_offset(&mut self, func: u32, offset: u16) -> Result<()> {
        if self.func.insert(func, offset).is_some() {
            return Err(Error::DuplicateFunc(func));
        }

        Ok(())
    }

    /// Registers the start of the program counter for the code section.
    pub fn code_offset(&mut self, offset: u16) {
        self.code.shift(offset);
    }

    /// Registers an external function in the jump table.
    pub fn ext(&mut self, pc: u16, func: ExtFunc) {
        self.code.try_add_func(func.clone());
        self.jump.insert(pc, Jump::ExtFunc(func));
    }

    /// Registers a label in the jump table.
    pub fn label(&mut self, pc: u16, label: u16) {
        self.jump.insert(pc, Jump::Label(label));
    }

    /// Merges another jump table into this one.
    ///
    /// This function updates the program counters of the target jump table and
    /// handles any potential duplicates.
    pub fn merge(&mut self, mut table: Self, pc: u16) -> Result<()> {
        if pc != 0 {
            table.shift_pc(0, pc)?;
        }

        for (pc, jump) in table.jump.into_iter() {
            if self.jump.insert(pc, jump).is_some() {
                return Err(Error::DuplicateJump(pc));
            }
        }

        for (func, offset) in table.func.into_iter() {
            if self.func.insert(func, offset).is_some() {
                return Err(Error::DuplicateFunc(func));
            }
        }

        for func in table.code.funcs() {
            self.code.try_add_func(func);
        }

        Ok(())
    }

    /// register jump to program counter
    pub fn register(&mut self, pc: u16, jump: Jump) {
        self.jump.insert(pc, jump);
    }

    /// Get the max target from the current jump table
    pub fn max_target(&self) -> u16 {
        self.jump
            .iter()
            .filter_map(|(_, jump)| self.target(jump).ok())
            .max()
            .unwrap_or(0)
    }
}

/// Two jumps that share one label must both resolve to the same shifted target.
#[test]
fn test_multiple_jumps_same_target() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // Both entries point at label 0x100.
    table.register(0x10, Jump::Label(0x100));
    table.register(0x20, Jump::Label(0x100));
    table.shift_targets()?;

    // Resolve the final target of the jump stored at a given pc.
    let resolved = |pc: u16| table.target(table.jump.get(&pc).unwrap());

    assert_eq!(resolved(0x10)?, 0x106);
    assert_eq!(resolved(0x20)?, 0x106);
    Ok(())
}

/// A jump whose target is itself the location of another jump.
#[test]
fn test_nested_jumps() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // 0x10 and 0x20 both jump to the middle (0x100); the middle jumps to the end (0x200).
    table.register(0x10, Jump::Label(0x100));
    table.register(0x100, Jump::Label(0x200));
    table.register(0x20, Jump::Label(0x100));

    table.shift_targets()?;

    // Resolve the final target of the jump stored at a given pc.
    let resolved = |pc: u16| table.target(table.jump.get(&pc).unwrap());

    assert_eq!(resolved(0x10)?, 0x106);
    assert_eq!(resolved(0x100)?, 0x209);
    assert_eq!(resolved(0x20)?, 0x106);
    Ok(())
}

/// A chain of jumps where each one targets the location of the next.
#[test]
fn test_sequential_jumps() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // 0x10 -> 0x20 -> 0x30 -> 0x40
    table.register(0x10, Jump::Label(0x20));
    table.register(0x20, Jump::Label(0x30));
    table.register(0x30, Jump::Label(0x40));

    table.shift_targets()?;

    // Resolve the final target of the jump stored at a given pc.
    let resolved = |pc: u16| table.target(table.jump.get(&pc).unwrap());

    // Every target moves by the offset accumulated up to it.
    assert_eq!(resolved(0x10)?, 0x22);
    assert_eq!(resolved(0x20)?, 0x34);
    assert_eq!(resolved(0x30)?, 0x46);
    Ok(())
}

/// A forward jump and a backward jump to the same label resolve identically.
#[test]
fn test_jump_backwards() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // 0x10 jumps forward to 0x20, 0x30 jumps back to it.
    table.register(0x10, Jump::Label(0x20));
    table.register(0x30, Jump::Label(0x20));

    table.shift_targets()?;

    // Resolve the final target of the jump stored at a given pc.
    let resolved = |pc: u16| table.target(table.jump.get(&pc).unwrap());

    assert_eq!(resolved(0x10)?, 0x22);
    assert_eq!(resolved(0x30)?, 0x22);
    Ok(())
}

/// Shifting targets must neither add nor drop entries, nor change their variant.
#[test]
fn test_jump_table_state_consistency() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // Two jumps to one target, the pattern the original author attributes to ERC20.
    table.register(0x10, Jump::Label(0x100));
    table.register(0x20, Jump::Label(0x100));

    let entries_before = table.jump.len();
    table.shift_targets()?;

    // Same number of entries, and all of them are still labels.
    assert_eq!(table.jump.len(), entries_before);
    assert!(table
        .jump
        .values()
        .all(|entry| matches!(entry, Jump::Label(_))));
    Ok(())
}

/// Jumps registered in descending pc order must end up in the same consistent
/// state as any other registration order.
///
/// The previous version of this test collected (original, shifted) pairs into
/// a vector and never asserted anything, so it could only fail if
/// `shift_targets` returned an error. The checks below make the intended
/// "target consistency" explicit.
#[test]
fn test_jump_target_ordering() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // Register jumps in reverse order
    table.register(0x30, Jump::Label(0x100));
    table.register(0x20, Jump::Label(0x100));
    table.register(0x10, Jump::Label(0x100));

    table.shift_targets()?;

    // The entries stay keyed by their registration pcs, in ascending order,
    // regardless of the order they were registered in.
    let pcs: Vec<u16> = table.jump.keys().copied().collect();
    assert_eq!(pcs, vec![0x10, 0x20, 0x30]);

    // Shifting must not change the kind of any jump.
    assert!(table
        .jump
        .values()
        .all(|entry| matches!(entry, Jump::Label(_))));

    // All three jumps were registered against one label, so they must still
    // agree on a single final target, and that target must not have moved
    // backwards from the original 0x100.
    let mut targets = Vec::with_capacity(table.jump.len());
    for entry in table.jump.values() {
        targets.push(table.target(entry)?);
    }
    assert!(targets.windows(2).all(|pair| pair[0] == pair[1]));
    assert!(targets.iter().all(|&target| target >= 0x100));

    Ok(())
}

/// A function call and a label jump that share one target survive shifting.
#[test]
fn test_mixed_jump_types() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // Function 1 lives at 0x100; one call and one label jump both aim there.
    table.func.insert(1, 0x100);
    table.call(0x10, 1);
    table.register(0x20, Jump::Label(0x100));

    let entries_before = table.jump.len();
    table.shift_targets()?;

    // Shifting must not change how many jumps are registered.
    assert_eq!(entries_before, table.jump.len());

    Ok(())
}