zingen/jump/
table.rs

1//! Jump Table
2//!
3//! This module defines the `JumpTable` struct, which manages the jump table, function
4//! table, and code section. It provides methods to register jumps, functions, and
5//! labels, as well as to merge jump tables.
6
7use crate::{codegen::ExtFunc, jump::Jump, Code, Error, Result};
8use std::collections::BTreeMap;
9
10/// Jump table implementation.
11#[derive(Clone, Default, Debug)]
12pub struct JumpTable {
13    /// Jump table mapping program counters to jump types.
14    pub(crate) jump: BTreeMap<u16, Jump>,
15    /// Function table mapping function indices to program counters.
16    pub(crate) func: BTreeMap<u32, u16>,
17    /// Code section associated with the jump table.
18    pub(crate) code: Code,
19}
20
21impl JumpTable {
22    /// Registers a function in the jump table.
23    ///
24    /// This function associates a program counter with a function.
25    pub fn call(&mut self, pc: u16, func: u32) {
26        self.jump.insert(pc, Jump::Func(func));
27    }
28
29    /// Registers a program counter to the function table.
30    ///
31    /// This function associates a function with a specific offset in the function table.
32    pub fn call_offset(&mut self, func: u32, offset: u16) -> Result<()> {
33        if self.func.insert(func, offset).is_some() {
34            return Err(Error::DuplicateFunc(func));
35        }
36
37        Ok(())
38    }
39
40    /// Registers the start of the program counter for the code section.
41    pub fn code_offset(&mut self, offset: u16) {
42        self.code.shift(offset);
43    }
44
45    /// Registers an external function in the jump table.
46    pub fn ext(&mut self, pc: u16, func: ExtFunc) {
47        self.code.try_add_func(func.clone());
48        self.jump.insert(pc, Jump::ExtFunc(func));
49    }
50
51    /// Registers a label in the jump table.
52    pub fn label(&mut self, pc: u16, label: u16) {
53        self.jump.insert(pc, Jump::Label(label));
54    }
55
56    /// Merges another jump table into this one.
57    ///
58    /// This function updates the program counters of the target jump table and
59    /// handles any potential duplicates.
60    pub fn merge(&mut self, mut table: Self, pc: u16) -> Result<()> {
61        if pc != 0 {
62            table.shift_pc(0, pc)?;
63        }
64
65        for (pc, jump) in table.jump.into_iter() {
66            if self.jump.insert(pc, jump).is_some() {
67                return Err(Error::DuplicateJump(pc));
68            }
69        }
70
71        for (func, offset) in table.func.into_iter() {
72            if self.func.insert(func, offset).is_some() {
73                return Err(Error::DuplicateFunc(func));
74            }
75        }
76
77        for func in table.code.funcs() {
78            self.code.try_add_func(func);
79        }
80
81        Ok(())
82    }
83
84    /// register jump to program counter
85    pub fn register(&mut self, pc: u16, jump: Jump) {
86        self.jump.insert(pc, jump);
87    }
88
89    /// Get the max target from the current jump table
90    pub fn max_target(&self) -> u16 {
91        self.jump
92            .iter()
93            .filter_map(|(_, jump)| self.target(jump).ok())
94            .max()
95            .unwrap_or(0)
96    }
97}
98
#[test]
fn test_multiple_jumps_same_target() -> anyhow::Result<()> {
    let sources: [u16; 2] = [0x10, 0x20];
    let mut table = JumpTable::default();

    // Two different program counters jump to one shared label.
    for pc in sources {
        table.register(pc, Jump::Label(0x100));
    }
    table.shift_targets()?;

    // Both jumps must resolve to the same shifted target.
    for pc in sources {
        let jump = table.jump.get(&pc).unwrap();
        assert_eq!(table.target(jump)?, 0x106);
    }
    Ok(())
}
113
#[test]
fn test_nested_jumps() -> anyhow::Result<()> {
    // (pc, label before shifting, expected target after shifting)
    let cases: [(u16, u16, u16); 3] = [
        (0x10, 0x100, 0x106),  // jump to the middle
        (0x100, 0x200, 0x209), // the middle jumps on to the end
        (0x20, 0x100, 0x106),  // a second jump to the middle
    ];

    let mut table = JumpTable::default();
    for (pc, label, _) in cases {
        table.register(pc, Jump::Label(label));
    }

    table.shift_targets()?;

    // Every jump resolves to its expected shifted target.
    for (pc, _, expected) in cases {
        let jump = table.jump.get(&pc).unwrap();
        assert_eq!(table.target(jump)?, expected);
    }
    Ok(())
}
131
#[test]
fn test_sequential_jumps() -> anyhow::Result<()> {
    // A chain where each jump lands on the next one:
    // (pc, label before shifting, expected target after shifting)
    let chain: [(u16, u16, u16); 3] = [
        (0x10, 0x20, 0x22),
        (0x20, 0x30, 0x34),
        (0x30, 0x40, 0x46),
    ];

    let mut table = JumpTable::default();
    for (pc, label, _) in chain {
        table.register(pc, Jump::Label(label));
    }

    table.shift_targets()?;

    // The shift applied to a target grows along the chain.
    for (pc, _, expected) in chain {
        let jump = table.jump.get(&pc).unwrap();
        assert_eq!(table.target(jump)?, expected);
    }
    Ok(())
}
149
#[test]
fn test_jump_backwards() -> anyhow::Result<()> {
    // One jump sits before the label (forward) and one after it (backward).
    let sources: [u16; 2] = [0x10, 0x30];
    let mut table = JumpTable::default();
    for pc in sources {
        table.register(pc, Jump::Label(0x20));
    }

    table.shift_targets()?;

    // Forward and backward jumps agree on the shifted target.
    for pc in sources {
        let jump = table.jump.get(&pc).unwrap();
        assert_eq!(table.target(jump)?, 0x22);
    }
    Ok(())
}
163
#[test]
fn test_jump_table_state_consistency() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // Two jumps to one shared target, mirroring the ERC20 pattern.
    for pc in [0x10, 0x20] {
        table.register(pc, Jump::Label(0x100));
    }

    let entries_before = table.jump.len();
    table.shift_targets()?;

    // Shifting neither adds nor drops entries, and every entry stays a label.
    assert_eq!(table.jump.len(), entries_before);
    let only_labels = table
        .jump
        .values()
        .all(|jump| matches!(jump, Jump::Label(_)));
    assert!(only_labels);
    Ok(())
}
182
#[test]
fn test_jump_target_ordering() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // Register jumps in reverse order; the table is keyed by program
    // counter, so registration order must not affect the outcome.
    table.register(0x30, Jump::Label(0x100));
    table.register(0x20, Jump::Label(0x100));
    table.register(0x10, Jump::Label(0x100));

    table.shift_targets()?;

    // Entries stay keyed by their original program counters, in ascending order.
    let pcs: Vec<u16> = table.jump.keys().copied().collect();
    assert_eq!(pcs, [0x10, 0x20, 0x30]);

    // Shifting must not change the kind of any jump.
    assert!(table
        .jump
        .values()
        .all(|jump| matches!(jump, Jump::Label(_))));

    // All three jumps pointed at the same label before shifting, so they
    // must still agree on a single target afterwards.
    let mut targets = Vec::new();
    for jump in table.jump.values() {
        targets.push(table.target(jump)?);
    }
    assert_eq!(targets.len(), 3);
    assert!(targets.windows(2).all(|pair| pair[0] == pair[1]));

    Ok(())
}
206
#[test]
fn test_mixed_jump_types() -> anyhow::Result<()> {
    let mut table = JumpTable::default();

    // A function call and a label jump that share one target, as in ERC20.
    table.func.insert(1, 0x100);
    table.call(0x10, 1);
    table.register(0x20, Jump::Label(0x100));

    let entries_before = table.jump.len();
    table.shift_targets()?;
    let entries_after = table.jump.len();

    // Shifting keeps the number of registered jumps unchanged.
    assert_eq!(entries_before, entries_after);

    Ok(())
}