use crate::{codegen::ExtFunc, jump::Jump, Code, Error, Result};
use std::collections::BTreeMap;
/// Jump table: bookkeeping for jumps, function offsets and the code section.
#[derive(Clone, Default, Debug)]
pub struct JumpTable {
/// Registered jumps, keyed by the program counter they were registered at.
pub(crate) jump: BTreeMap<u16, Jump>,
/// Function index mapped to the offset recorded for it via `call_offset`.
pub(crate) func: BTreeMap<u32, u16>,
/// Code section; receives external functions (`try_add_func`) and offset
/// shifts (`shift`).
pub(crate) code: Code,
}
impl JumpTable {
pub fn call(&mut self, pc: u16, func: u32) {
self.jump.insert(pc, Jump::Func(func));
}
pub fn call_offset(&mut self, func: u32, offset: u16) -> Result<()> {
if self.func.insert(func, offset).is_some() {
return Err(Error::DuplicateFunc(func));
}
Ok(())
}
pub fn code_offset(&mut self, offset: u16) {
self.code.shift(offset);
}
pub fn ext(&mut self, pc: u16, func: ExtFunc) {
self.code.try_add_func(func.clone());
self.jump.insert(pc, Jump::ExtFunc(func));
}
pub fn label(&mut self, pc: u16, label: u16) {
self.jump.insert(pc, Jump::Label(label));
}
pub fn merge(&mut self, mut table: Self, pc: u16) -> Result<()> {
if pc != 0 {
table.shift_pc(0, pc)?;
}
for (pc, jump) in table.jump.into_iter() {
if self.jump.insert(pc, jump).is_some() {
return Err(Error::DuplicateJump(pc));
}
}
for (func, offset) in table.func.into_iter() {
if self.func.insert(func, offset).is_some() {
return Err(Error::DuplicateFunc(func));
}
}
for func in table.code.funcs() {
self.code.try_add_func(func);
}
Ok(())
}
pub fn register(&mut self, pc: u16, jump: Jump) {
self.jump.insert(pc, jump);
}
pub fn max_target(&self) -> u16 {
self.jump
.iter()
.filter_map(|(_, jump)| self.target(jump).ok())
.max()
.unwrap_or(0)
}
}
/// Two jumps registered against the same label must both resolve to the same
/// target (`0x106`) after `shift_targets`.
#[test]
fn test_multiple_jumps_same_target() -> anyhow::Result<()> {
    let mut table = JumpTable::default();
    for pc in [0x10, 0x20] {
        table.register(pc, Jump::Label(0x100));
    }

    table.shift_targets()?;

    for pc in [0x10u16, 0x20] {
        let jump = table.jump.get(&pc).unwrap();
        assert_eq!(table.target(jump)?, 0x106);
    }
    Ok(())
}
/// A jump whose label is itself the pc of another jump: checks the resolved
/// target of every entry after `shift_targets`.
#[test]
fn test_nested_jumps() -> anyhow::Result<()> {
    let mut table = JumpTable::default();
    table.register(0x10, Jump::Label(0x100));
    table.register(0x100, Jump::Label(0x200));
    table.register(0x20, Jump::Label(0x100));

    table.shift_targets()?;

    // (pc of the jump, expected resolved target)
    let expected: [(u16, u16); 3] = [(0x10, 0x106), (0x100, 0x209), (0x20, 0x106)];
    for (pc, target) in expected {
        let jump = table.jump.get(&pc).unwrap();
        assert_eq!(table.target(jump)?, target);
    }
    Ok(())
}
/// A chain of jumps where each label is the pc of the next jump: checks the
/// resolved target of every link after `shift_targets`.
#[test]
fn test_sequential_jumps() -> anyhow::Result<()> {
    let mut table = JumpTable::default();
    for (pc, label) in [(0x10, 0x20), (0x20, 0x30), (0x30, 0x40)] {
        table.register(pc, Jump::Label(label));
    }

    table.shift_targets()?;

    // (pc of the jump, expected resolved target)
    let expected: [(u16, u16); 3] = [(0x10, 0x22), (0x20, 0x34), (0x30, 0x46)];
    for (pc, target) in expected {
        let jump = table.jump.get(&pc).unwrap();
        assert_eq!(table.target(jump)?, target);
    }
    Ok(())
}
/// A forward jump (from `0x10`) and a backward jump (from `0x30`) to the same
/// label `0x20` must both resolve to `0x22` after `shift_targets`.
#[test]
fn test_jump_backwards() -> anyhow::Result<()> {
    let mut table = JumpTable::default();
    for pc in [0x10, 0x30] {
        table.register(pc, Jump::Label(0x20));
    }

    table.shift_targets()?;

    for pc in [0x10u16, 0x30] {
        let jump = table.jump.get(&pc).unwrap();
        assert_eq!(table.target(jump)?, 0x22);
    }
    Ok(())
}
/// `shift_targets` must keep the number of registered jumps unchanged and
/// leave every entry a `Jump::Label`.
#[test]
fn test_jump_table_state_consistency() -> anyhow::Result<()> {
    let mut table = JumpTable::default();
    table.register(0x10, Jump::Label(0x100));
    table.register(0x20, Jump::Label(0x100));
    let entries_before = table.jump.len();

    table.shift_targets()?;

    assert_eq!(table.jump.len(), entries_before);
    let all_labels = table
        .jump
        .values()
        .all(|jump| matches!(jump, Jump::Label(_)));
    assert!(all_labels);
    Ok(())
}
/// Registers three jumps to one label in descending pc order and runs
/// `shift_targets`.
///
/// NOTE(review): the before/after pairs are collected but never asserted on,
/// so this test only checks that `shift_targets` returns `Ok`.
#[test]
fn test_jump_target_ordering() -> anyhow::Result<()> {
    let mut table = JumpTable::default();
    for pc in [0x30, 0x20, 0x10] {
        table.register(pc, Jump::Label(0x100));
    }

    // Snapshot of the table before shifting, to pair old and new jumps.
    let snapshot = table.clone();
    table.shift_targets()?;

    let _shifts: Vec<(&Jump, &Jump)> = snapshot
        .jump
        .values()
        .zip(table.jump.values())
        .collect();
    Ok(())
}
/// With a function jump and a label jump registered together,
/// `shift_targets` must succeed and keep the number of jumps unchanged.
#[test]
fn test_mixed_jump_types() -> anyhow::Result<()> {
    let mut table = JumpTable::default();
    table.func.insert(1, 0x100);
    table.call(0x10, 1);
    table.register(0x20, Jump::Label(0x100));
    let jumps_before = table.jump.len();

    table.shift_targets()?;

    let jumps_after = table.jump.len();
    assert_eq!(jumps_before, jumps_after);
    Ok(())
}