//! Jump table implementation.
//!
//! This module defines the `Jump` enum and re-exports the `JumpTable` struct. Together they
//! track the jumps in the program: jumps to labels, calls to internal functions, and calls to
//! external functions, each registered at the code offset where it occurs.
use crate::codegen::ExtFunc;
use core::fmt::Display;
pub use table::JumpTable;
mod pc;
mod relocate;
mod table;
mod target;
/// Represents the different types of jumps in the program.
///
/// `Label` is the only variant for which [`Jump::is_label`] returns `true`; both
/// `Func` and `ExtFunc` are treated as calls by [`Jump::is_call`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Jump {
    /// Jump to a specific label, identified by its original program counter
    /// (i.e. the pc before any relocation shifts are applied).
    Label(u16),
    /// Call to an internal function, identified by its function index.
    Func(u32),
    /// Call to an external function.
    ExtFunc(ExtFunc),
}
impl Display for Jump {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Jump::Label(offset) => write!(f, "Label(0x{offset:x})"),
Jump::Func(index) => write!(f, "Func({index})"),
Jump::ExtFunc(_) => write!(f, "ExtFunc"),
}
}
}
impl Jump {
/// Checks if the target is a label.
pub fn is_label(&self) -> bool {
matches!(self, Jump::Label { .. })
}
/// Checks if the target is a function call.
pub fn is_call(&self) -> bool {
!self.is_label()
}
}
#[cfg(test)]
mod tests {
    use crate::jump::{Jump, JumpTable};
    use smallvec::smallvec;

    /// Installs a compact, timestamp-free `tracing` subscriber whose filter comes from the
    /// environment (`EnvFilter::from_default_env`), for log output while debugging a test.
    ///
    /// `try_init` fails if a global subscriber is already installed (e.g. by another test
    /// in the same process); that error is deliberately discarded with `.ok()`.
    // The allow keeps this debugging helper warning-free when no test calls it.
    #[allow(unused)]
    fn init_tracing() {
        tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .without_time()
            .compact()
            .try_init()
            .ok();
    }

    /// Asserts that shifting targets and relocating agree on the final code size.
    ///
    /// The buffer starts as zero bytes sized to the largest target *before* shifting.
    /// After `shift_targets`, the largest (shifted) target is recorded. `relocate` must
    /// then leave the buffer with exactly that length.
    fn assert_target_shift_vs_relocation(mut table: JumpTable) -> anyhow::Result<()> {
        // Initial buffer: zero bytes up to the largest unshifted target.
        let mut buffer = smallvec![0; table.max_target() as usize];
        // Perform target shifts
        table.shift_targets()?;
        // Find the maximum target after shifts
        let max_target = table.max_target();
        // Perform relocation
        table.relocate(&mut buffer)?;
        assert_eq!(buffer.len(), max_target as usize);
        Ok(())
    }

    /// Two independent label jumps, each registered before its own target.
    #[test]
    fn test_target_shift_vs_relocation() -> anyhow::Result<()> {
        let mut table = JumpTable::default();
        // Register jumps with known offsets and labels
        table.register(0x10, Jump::Label(0x20)); // Jump to label at 0x20
        table.register(0x30, Jump::Label(0x40)); // Jump to label at 0x40
        assert_target_shift_vs_relocation(table)
    }

    /// Two jumps sharing the same label target.
    #[test]
    fn test_multiple_internal_calls() -> anyhow::Result<()> {
        let mut table = JumpTable::default();
        // Simulate multiple functions calling _approve
        table.register(0x10, Jump::Label(0x100)); // approve() -> _approve
        table.register(0x20, Jump::Label(0x100)); // spend_allowance() -> _approve
        assert_target_shift_vs_relocation(table)
    }

    /// Chained jumps where one jump's pc equals another jump's target.
    ///
    /// All targets are >= 0x100, so the test expects every jump to be a PUSH2
    /// (1 opcode + 2 operand bytes = 3 bytes). Each jump is therefore displaced by
    /// 3 bytes per earlier jump: 0x100 stays put, 0x110 -> 0x113, 0x200 -> 0x206.
    #[test]
    fn test_nested_function_calls() -> anyhow::Result<()> {
        let mut table = JumpTable::default();
        // Simulate ERC20's approve -> _approve call chain
        table.register(0x100, Jump::Label(0x200)); // approve entry
        table.register(0x110, Jump::Label(0x300)); // approve -> _approve
        table.register(0x200, Jump::Label(0x400)); // _approve entry
        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;
        // Check if all jumps use correct PUSH instructions
        assert_eq!(buffer[0x100], 0x61); // PUSH2
        assert_eq!(buffer[0x113], 0x61); // PUSH2
        assert_eq!(buffer[0x206], 0x61); // PUSH2
        Ok(())
    }

    /// A label jump followed immediately by a call to function 1 (placed at 0x317).
    ///
    /// Expected layout: the label jump at 0x10 is a PUSH1 (2 bytes), and the call is a
    /// PUSH2 (3 bytes). Both precede the label target 0x12, which becomes
    /// 0x12 + 2 + 3 = 0x17 (the PUSH1 operand at 0x11). The call moves to 0x11 + 2 = 0x13.
    /// Its operand 0x317 + 5 = 0x31c is written big-endian at 0x14..=0x15.
    #[test]
    fn test_label_call_interaction() -> anyhow::Result<()> {
        init_tracing();
        let mut table = JumpTable::default();
        table.func.insert(1, 0x317);
        table.label(0x10, 0x12);
        table.call(0x11, 1);
        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;
        assert_eq!(buffer[0x11], 0x17, "{buffer:?}");
        assert_eq!(buffer[0x14], 0x03, "{buffer:?}");
        assert_eq!(buffer[0x15], 0x1c, "{buffer:?}");
        Ok(())
    }

    /// Operand width is chosen from the (relocated) target size.
    ///
    /// The jumps at 0x10/0x20/0x30 take 2 + 3 + 3 = 8 bytes and all precede every
    /// target, so each target shifts by 8: 0x80 -> 0x88, 0x100 -> 0x108, 0x1000 -> 0x1008.
    /// The second jump moves to 0x22 (operand at 0x23..=0x24). The third moves to
    /// 0x30 + 5 = 0x35 (operand at 0x36..=0x37).
    #[test]
    fn test_large_target_offset_calculation() -> anyhow::Result<()> {
        let mut table = JumpTable::default();
        // Register a jump with target < 0xff
        table.register(0x10, Jump::Label(0x80));
        // Register a jump with target > 0xff
        table.register(0x20, Jump::Label(0x100));
        // Register a jump with target > 0xfff
        table.register(0x30, Jump::Label(0x1000));
        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;
        // Check if offsets are correctly calculated
        // For target 0x80: PUSH1 (1 byte) + target (1 byte)
        // For target 0x100: PUSH2 (1 byte) + target (2 bytes)
        // For target 0x1000: PUSH2 (1 byte) + target (2 bytes)
        assert_eq!(buffer[0x11], 0x88); // Small target
        assert_eq!(buffer[0x23], 0x01); // First (high) byte of large target
        assert_eq!(buffer[0x24], 0x08); // Second (low) byte of large target
        assert_eq!(buffer[0x36], 0x10); // First (high) byte of large target
        assert_eq!(buffer[0x37], 0x08); // Second (low) byte of large target
        Ok(())
    }

    /// Twenty back-to-back jumps, all PUSH2 (3 bytes each).
    ///
    /// Every jump precedes every target, so each target shifts by 20 * 3 = 0x3c. Jump `i`
    /// (originally at 0x10 + i) moves to 0x10 + i + 3 * i.
    #[test]
    fn test_sequential_large_jumps() -> anyhow::Result<()> {
        let mut table = JumpTable::default();
        // Register multiple sequential jumps with increasing targets
        // This mirrors the ERC20 pattern where we have many functions
        for i in 0..20 {
            let target = 0x100 + (i * 0x20);
            table.register(0x10 + i, Jump::Label(target));
        }
        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;
        // Check first jump (should use PUSH2)
        assert_eq!(buffer[0x10], 0x61); // PUSH2
        assert_eq!(buffer[0x11], 0x01); // First byte
        assert_eq!(buffer[0x12], 0x3c); // Second byte
        // Documents how the expected operand above is derived.
        assert_eq!(0x013c, 0x100 + 20 * 3);
        // Check last jump (should still use PUSH2 but with adjusted offset)
        let last_idx = 0x10 + 19 + 19 * 3;
        assert_eq!(buffer[last_idx], 0x61); // PUSH2
        assert_eq!(buffer[last_idx + 1], 0x03); // First byte should be larger
        assert_eq!(buffer[last_idx + 2], 0x9c); // Second byte accounts for all previous jumps
        // Documents how the expected operand above is derived.
        assert_eq!(0x039c, 0x100 + 0x20 * 19 + 20 * 3);
        Ok(())
    }

    /// Dispatcher-like layout: for each selector, a "check" jump to a nearby pc, and at
    /// that pc a jump into the function body.
    ///
    /// The verification loop recomputes the relocated positions incrementally.
    /// `total_offset` accumulates the bytes inserted by all earlier selectors (check jump +
    /// 3-byte PUSH2 function jump).
    #[test]
    fn test_dispatcher_jump_targets() -> anyhow::Result<()> {
        let mut table = JumpTable::default();
        let selectors = 5;
        // Register jumps for each selector check
        for i in 0..selectors {
            let i = i as u16;
            let check_pc = 0x10 + i * 0x20;
            let target_pc = 0x100 + i * 0x40;
            // Register both the comparison jump and function jump
            table.register(check_pc, Jump::Label(check_pc + 0x10));
            table.register(check_pc + 0x10, Jump::Label(target_pc));
        }
        let mut buffer = smallvec![0; table.max_target() as usize];
        table.relocate(&mut buffer)?;
        // Verify each selector's jump chain
        let mut total_offset = 0;
        for i in 0..selectors {
            // Relocated position of this selector's check jump.
            let check_pc = 0x10 + i * 0x20 + total_offset;
            // NOTE(review): the size of the check jump is decided here from
            // `check_pc + 0x10`, but the opcode assertion below uses `func_pc > 0xff`.
            // The two predicates agree for this data (every check target stays below
            // 0x100), but they could diverge for other layouts. Consider deriving both
            // from the same relocated target.
            let check_pc_offset = if check_pc + 0x10 > 0xff { 3 } else { 2 };
            // Relocated position of the function jump (the check jump's target).
            let func_pc = check_pc + 0x10 + check_pc_offset;
            let check_jump = buffer[check_pc];
            let func_jump = buffer[func_pc];
            assert_eq!(check_jump, if func_pc > 0xff { 0x61 } else { 0x60 });
            // Function targets are all >= 0x100, so PUSH2 is expected.
            assert_eq!(func_jump, 0x61);
            // Update total offset for next iteration
            total_offset += check_pc_offset + 3;
        }
        Ok(())
    }
}