from __future__ import annotations

import re
from collections.abc import Iterable, Mapping
from dataclasses import dataclass

from .formatting import parse_int
from .model import Instruction

REGISTER_NAMES = tuple(f"R{idx}" for idx in range(8))
CONTROL_REGISTER_NAMES = ("CCR", "BR", "EP", "DP", "TP", "SR")

# Compare/test mnemonics: they write no register, only flags (plus any
# @Rn+ / @-Rn side effect on an addressing register).
_FLAG_ONLY_BASES = frozenset({"CMP:E", "CMP:I", "CMP:G", "TST", "BTST"})

# Pure control-flow mnemonics that leave the general registers untouched.
_CONTROL_FLOW_BASES = frozenset(
    {
        "RTE", "RTS", "RTD", "PRTS", "PRTD", "SLEEP",
        "BRA", "BHI", "BLS", "BCC", "BCS", "BNE", "BEQ", "BVC", "BVS",
        "BPL", "BMI", "BGE", "BLT", "BGT", "BLE", "BRN",
        "JMP", "PJMP", "BSR", "JSR", "PJSR",
    }
)

# SCB (subtract, compare, branch) may decrement its counter register before
# branching, so it is not pure control flow.
_SCB_BASES = frozenset({"SCB/F", "SCB/NE", "SCB/EQ"})

# Single-operand instructions that rewrite their (register) operand.
_SINGLE_OPERAND_WRITERS = frozenset(
    {
        "SWAP", "EXTS", "EXTU", "NEG", "NOT", "SHAL", "SHAR", "SHLL", "SHLR",
        "ROTL", "ROTR", "ROTXL", "ROTXR", "TAS",
    }
)

# Multi-operand instructions whose last operand is not a destination.
_NON_WRITING_BASES = frozenset({"CMP", "CMP:E", "CMP:I", "CMP:G", "BTST", "TST", "STM"})

# Instructions the tracker treats as leaving the flags alone.
_CCR_PRESERVING_BASES = frozenset(
    {"NOP", "MOV:S", "MOVTPE", "STC", "LDC", "STM", "LDM", "LINK", "UNLK"}
)

_PREDECREMENT_RE = re.compile(r"@-(R[0-7])")
_POSTINCREMENT_RE = re.compile(r"@(R[0-7])\+")
_REGISTER_RE = re.compile(r"\bR([0-7])\b")
_REGISTER_RANGE_RE = re.compile(r"\bR([0-7])\s*-\s*R([0-7])\b")


@dataclass(frozen=True)
class TrackedValue:
    """A possibly-unknown register value.

    A value is *known* only when both ``value`` and ``width`` (in bits) are set.
    ``source`` is the instruction text that produced it; ``reason`` explains why
    an unknown value is unknown.
    """

    value: int | None = None
    width: int | None = None
    source: str = ""
    reason: str = ""

    @property
    def known(self) -> bool:
        return self.value is not None and self.width is not None


# {"registers": {name: TrackedValue}, "control": {name: TrackedValue}}
State = dict[str, dict[str, TrackedValue]]


def analyze_dataflow(
    instructions: Mapping[int, Instruction],
    labels: Mapping[int, str] | None = None,
    functions: object | None = None,
) -> dict[str, object]:
    """Track simple register values through conservative linear basic blocks.

    The pass intentionally avoids merging states at branch targets. Each basic
    block starts with unknown register state, then instructions are interpreted
    linearly until a branch, jump, return, label, function entry, or address gap.

    Args:
        instructions: Decoded instructions keyed by address.
        labels: Optional address -> label mapping; each labelled address that
            holds an instruction starts a new block.
        functions: Optional function description (the shapes accepted are
            those handled by ``_function_entries``); function entries start
            new blocks.

    Returns:
        A dict with ``"instructions"`` (per-address records holding the
        before/after state, the changes and notes), ``"blocks"`` (start, end,
        end_exclusive and instruction addresses of each block), and the
        ``"registers"`` / ``"control_registers"`` name tuples.
    """
    ordered = sorted(instructions)
    block_starts = _find_block_starts(instructions, labels, functions)
    instruction_records: dict[int, dict[str, object]] = {}
    blocks: list[dict[str, object]] = []
    state: State | None = None
    current_block: dict[str, object] | None = None

    for index, address in enumerate(ordered):
        ins = instructions[address]
        next_address = ordered[index + 1] if index + 1 < len(ordered) else None

        starts_new_block = state is None or current_block is None or address in block_starts
        if index > 0:
            previous = instructions[ordered[index - 1]]
            starts_new_block = starts_new_block or not _is_contiguous(previous, address)
        if starts_new_block:
            if current_block is not None:
                blocks.append(current_block)
            state = _initial_state()
            current_block = {"start": address, "instructions": []}

        # Type narrowing only: both are (re)initialised just above.
        assert state is not None
        assert current_block is not None

        before = _copy_state(state)
        after, notes = _transfer(ins, before)
        changes = _state_changes(before, after)
        block_start = int(current_block["start"])

        record = {
            "address": address,
            "text": ins.text,
            "mnemonic": ins.mnemonic,
            "operands": ins.operands,
            "kind": ins.kind,
            "block": block_start,
            "before": _public_state(before),
            "after": _public_state(after),
            "changes": changes,
            "notes": notes,
        }
        instruction_records[address] = record

        block_addresses = current_block["instructions"]
        assert isinstance(block_addresses, list)
        block_addresses.append(address)
        current_block["end"] = address
        current_block["end_exclusive"] = address + max(ins.size, 1)
        state = after

        if _ends_basic_block(ins, next_address):
            blocks.append(current_block)
            current_block = None
            state = None

    if current_block is not None:
        blocks.append(current_block)

    return {
        "instructions": instruction_records,
        "blocks": blocks,
        "registers": REGISTER_NAMES,
        "control_registers": CONTROL_REGISTER_NAMES,
    }


track_registers = analyze_dataflow


def state_for_instruction(analysis: Mapping[str, object] | None, address: int) -> dict[str, object]:
    """Return the record ``analyze_dataflow`` produced for ``address``, or ``{}``."""
    if not analysis:
        return {}
    instructions = analysis.get("instructions")
    if not isinstance(instructions, Mapping):
        return {}
    record = instructions.get(address)
    return record if isinstance(record, dict) else {}


def _find_block_starts(
    instructions: Mapping[int, Instruction],
    labels: Mapping[int, str] | None,
    functions: object | None,
) -> set[int]:
    """Collect addresses that must begin a new block.

    These are the first instruction, labels, function entries, branch/jump
    targets and the fall-through address of every branch.
    """
    addresses = set(instructions)
    starts: set[int] = set()
    if addresses:
        starts.add(min(addresses))
    if labels:
        starts.update(address for address in labels if address in addresses)
    starts.update(address for address in _function_entries(functions) if address in addresses)
    for address, ins in instructions.items():
        starts.update(target for target in ins.targets if target in addresses)
        if ins.kind == "branch" and ins.fallthrough:
            fallthrough = address + max(ins.size, 1)
            if fallthrough in addresses:
                starts.add(fallthrough)
    return starts


def _function_entries(functions: object | None) -> set[int]:
    """Extract entry addresses from the loosely-typed ``functions`` argument.

    Accepted shapes: ``None``; a mapping with a ``"nodes"`` entry (recursed
    into) or a ``"start"`` entry; a mapping keyed by int addresses and/or with
    values carrying ``"start"``; an iterable of ints or of mappings carrying
    ``"start"``. Anything else yields no entries.
    """
    if functions is None:
        return set()
    if isinstance(functions, Mapping):
        if "nodes" in functions:
            return _function_entries(functions.get("nodes"))
        if "start" in functions:
            value = functions.get("start")
            return {int(value)} if value is not None else set()
        entries: set[int] = set()
        for key, value in functions.items():
            if isinstance(key, int):
                entries.add(key)
            if isinstance(value, Mapping) and "start" in value:
                entries.add(int(value["start"]))
        return entries
    if isinstance(functions, Iterable) and not isinstance(functions, (str, bytes)):
        entries = set()
        for item in functions:
            if isinstance(item, int):
                entries.add(item)
            elif isinstance(item, Mapping) and "start" in item:
                entries.add(int(item["start"]))
        return entries
    return set()


def _initial_state(reason: str = "block_entry") -> State:
    """Return a state where every tracked register is unknown for ``reason``."""
    return {
        "registers": {name: _unknown(reason) for name in REGISTER_NAMES},
        "control": {name: _unknown(reason) for name in CONTROL_REGISTER_NAMES},
    }


def _copy_state(state: State) -> State:
    # Shallow copies suffice: TrackedValue is frozen.
    return {
        "registers": dict(state["registers"]),
        "control": dict(state["control"]),
    }


def _public_state(state: State) -> dict[str, dict[str, dict[str, object]]]:
    """Render a state as plain dicts suitable for serialisation."""
    return {
        "registers": {name: _public_value(value) for name, value in state["registers"].items()},
        "control": {name: _public_value(value) for name, value in state["control"].items()},
    }


def _public_value(value: TrackedValue) -> dict[str, object]:
    """Render one tracked value; unknown values only carry their reason."""
    if not value.known:
        result: dict[str, object] = {"known": False}
        if value.reason:
            result["reason"] = value.reason
        return result
    assert value.value is not None
    assert value.width is not None
    digits = 2 if value.width <= 8 else 4
    result = {
        "known": True,
        "value": value.value,
        "hex": f"0x{value.value:0{digits}X}",
        "width": value.width,
    }
    if value.source:
        result["source"] = value.source
    return result


def _unknown(reason: str = "") -> TrackedValue:
    return TrackedValue(reason=reason)


def _known(value: int, width: int, source: str) -> TrackedValue:
    # Values are stored masked to their width, so they are never negative.
    return TrackedValue(value=value & _mask(width), width=width, source=source)


def _transfer(ins: Instruction, state: State) -> tuple[State, list[str]]:
    """Return the state after ``ins`` executes, plus human-readable notes.

    ``state`` itself is not modified; a copy is updated and returned.
    """
    after = _copy_state(state)
    notes: list[str] = []
    mnemonic = ins.mnemonic
    base = _mnemonic_base(mnemonic)
    width = _mnemonic_width(mnemonic)
    ops = split_operands(ins.operands)

    if ins.kind == "call":
        _unknown_all(after, "call")
        notes.append("call clobbers tracked register state")
        return after, notes
    if ins.kind == "jump" and not ins.targets:
        _unknown_all(after, "indirect_jump")
        notes.append("indirect jump ends known register state")
        return after, notes

    if ins.writes_br:
        if ins.br_value is None:
            _set_control_unknown(after, "BR", "control_load")
        else:
            _set_control_known(after, "BR", ins.br_value, 8, ins.text)
        notes.append("tracked BR write")

    if base == "NOP":
        return after, notes

    if base in _FLAG_ONLY_BASES:
        # No register is written, but @Rn+ / @-Rn operands still move the
        # addressing register.
        _apply_addressing_side_effects(after, ops, width)
        _unknown_ccr(after, "flags")
        return after, notes

    if base in {"MOV:I", "MOV:E", "MOV:G"} and len(ops) == 2:
        _apply_mov(after, ops[0], ops[1], width, ins, notes)
        _unknown_ccr(after, "flags")
        return after, notes

    if base in {"MOV:L", "MOV:F", "MOVFPE"} and len(ops) == 2:
        if _is_register(ops[1]):
            _set_register_unknown(after, ops[1], "memory_load")
            notes.append(f"{ops[1]} unknown after memory load")
        _apply_addressing_side_effects(after, ops, width)
        _unknown_ccr(after, "flags")
        return after, notes

    if base in {"MOV:S", "MOVTPE"}:
        _apply_addressing_side_effects(after, ops, width)
        _unknown_ccr(after, "flags")
        return after, notes

    if base == "CLR" and len(ops) == 1:
        if _is_register(ops[0]):
            _set_register_known(after, ops[0], 0, width or 16, ins.text)
            notes.append(f"{ops[0]} cleared")
        else:
            _apply_addressing_side_effects(after, ops, width)
        _unknown_ccr(after, "flags")
        return after, notes

    if base in {"ADD", "ADD:G", "ADD:Q", "ADDS", "SUB", "SUBS"} and len(ops) == 2:
        _apply_add_sub(after, base, ops[0], ops[1], width, ins, notes)
        _unknown_ccr(after, "flags")
        return after, notes

    if base == "LDC" and len(ops) == 2:
        _apply_ldc(after, ops[0], ops[1], width, ins, notes)
        return after, notes

    if base == "STC" and len(ops) == 2:
        _apply_stc(after, ops[0], ops[1], width, ins, notes)
        return after, notes

    if base in {"ORC", "ANDC", "XORC"} and len(ops) == 2:
        _apply_control_binary(after, base, ops[0], ops[1], width, ins, notes)
        return after, notes

    _apply_unsupported(after, base, ops, width, ins, notes)
    return after, notes


def split_operands(operands: str) -> list[str]:
    """Split an operand string on top-level commas.

    Commas nested inside ``(...)`` or ``{...}`` do not split. Parts are
    stripped, and empty parts are dropped.
    """
    if not operands:
        return []
    parts: list[str] = []
    start = 0
    depth = 0
    for idx, char in enumerate(operands):
        if char in "({":
            depth += 1
        elif char in ")}" and depth:
            depth -= 1
        elif char == "," and depth == 0:
            parts.append(operands[start:idx].strip())
            start = idx + 1
    parts.append(operands[start:].strip())
    return [part for part in parts if part]


def _apply_mov(
    state: State,
    source: str,
    dest: str,
    width: int | None,
    ins: Instruction,
    notes: list[str],
) -> None:
    """Model ``MOV source, dest``; only register destinations are tracked."""
    effective_width = width or 16
    _apply_addressing_side_effects(state, (source, dest), effective_width)
    if not _is_register(dest):
        return
    if source.startswith("@"):
        _set_register_unknown(state, dest, "memory_load")
        notes.append(f"{dest} unknown after memory load")
        return
    operand = _operand_value(state, source, effective_width)
    if operand is None:
        _set_register_unknown(state, dest, "unknown_operand")
        notes.append(f"{dest} unknown after MOV source")
        return
    _set_register_known(state, dest, operand, effective_width, ins.text)
    notes.append(f"{dest} = {_format_known(operand, effective_width)}")


def _apply_add_sub(
    state: State,
    base: str,
    source: str,
    dest: str,
    width: int | None,
    ins: Instruction,
    notes: list[str],
) -> None:
    """Model ``dest = dest +/- source`` for register destinations."""
    effective_width = width or 16
    _apply_addressing_side_effects(state, (source, dest), effective_width)
    if not _is_register(dest):
        return
    if source.startswith("@"):
        _set_register_unknown(state, dest, "memory_load")
        notes.append(f"{dest} unknown after arithmetic memory source")
        return
    left = _operand_value(state, dest, effective_width)
    right = _operand_value(state, source, effective_width)
    if left is None or right is None:
        _set_register_unknown(state, dest, "unknown_operand")
        notes.append(f"{dest} unknown after arithmetic")
        return
    result = left - right if base.startswith("SUB") else left + right
    _set_register_known(state, dest, result, effective_width, ins.text)
    notes.append(f"{dest} = {_format_known(result, effective_width)}")


def _apply_ldc(
    state: State,
    source: str,
    dest: str,
    width: int | None,
    ins: Instruction,
    notes: list[str],
) -> None:
    """Model ``LDC source, control``."""
    # An @Rn+ / @-Rn source moves its pointer even when the destination is
    # not a tracked control register.
    _apply_addressing_side_effects(state, (source,), width)
    control = _control_name(dest)
    if control is None:
        return
    effective_width = _control_width(control, width)
    if source.startswith("@"):
        _set_control_unknown(state, control, "memory_load")
        notes.append(f"{control} unknown after memory load")
        return
    value = _operand_value(state, source, effective_width)
    if value is None:
        _set_control_unknown(state, control, "unknown_operand")
        notes.append(f"{control} unknown after LDC source")
        return
    _set_control_known(state, control, value, effective_width, ins.text)
    notes.append(f"{control} = {_format_known(value, effective_width)}")


def _apply_stc(
    state: State,
    source: str,
    dest: str,
    width: int | None,
    ins: Instruction,
    notes: list[str],
) -> None:
    """Model ``STC control, dest``; memory destinations only get side effects."""
    control = _control_name(source)
    if control is None:
        return
    effective_width = _control_width(control, width)
    value = _control_value(state, control, effective_width)
    if _is_register(dest):
        if value is None:
            _set_register_unknown(state, dest, "unknown_operand")
            notes.append(f"{dest} unknown after STC source")
        else:
            _set_register_known(state, dest, value, effective_width, ins.text)
            notes.append(f"{dest} = {_format_known(value, effective_width)}")
    else:
        _apply_addressing_side_effects(state, (dest,), effective_width)


def _apply_control_binary(
    state: State,
    base: str,
    source: str,
    dest: str,
    width: int | None,
    ins: Instruction,
    notes: list[str],
) -> None:
    """Model ORC / ANDC / XORC applied to a control register."""
    control = _control_name(dest)
    if control is None:
        return
    effective_width = _control_width(control, width)
    left = _control_value(state, control, effective_width)
    right = _operand_value(state, source, effective_width)
    if left is None or right is None:
        _set_control_unknown(state, control, "unknown_operand")
        notes.append(f"{control} unknown after {base}")
        return
    if base == "ORC":
        result = left | right
    elif base == "ANDC":
        result = left & right
    else:
        result = left ^ right
    _set_control_known(state, control, result, effective_width, ins.text)
    notes.append(f"{control} = {_format_known(result, effective_width)}")


def _apply_unsupported(
    state: State,
    base: str,
    ops: list[str],
    width: int | None,
    ins: Instruction,
    notes: list[str],
) -> None:
    """Conservatively model an instruction without a dedicated transfer rule."""
    reason = f"unsupported:{ins.mnemonic}"
    if base in _SCB_BASES:
        # The counter register may be decremented; the label operand is not a
        # register, so only real register operands are invalidated.
        counters = list(dict.fromkeys(op for op in ops if _is_register(op)))
        for register in counters:
            _set_register_unknown(state, register, reason)
        if counters:
            notes.append(f"{', '.join(counters)} unknown after {base}")
        return
    if base in _CONTROL_FLOW_BASES:
        return
    affected = _written_registers(base, ops)
    for register in affected:
        _set_register_unknown(state, register, reason)
    _apply_addressing_side_effects(state, ops, width)
    if affected:
        notes.append(f"unsupported operation invalidated {', '.join(affected)}")
    if _may_update_ccr(base):
        _unknown_ccr(state, "flags")


def _operand_value(state: State, operand: str, width: int) -> int | None:
    """Evaluate an immediate, register or control operand at ``width`` bits."""
    operand = operand.strip()
    immediate = _parse_immediate(operand)
    if immediate is not None:
        return immediate & _mask(width)
    if _is_register(operand):
        return _narrow(state["registers"][operand], width)
    control = _control_name(operand)
    if control is not None:
        return _control_value(state, control, width)
    return None


def _control_value(state: State, control: str, width: int) -> int | None:
    return _narrow(state["control"][control], width)


def _narrow(value: TrackedValue, width: int) -> int | None:
    """Return the low ``width`` bits of a known value no narrower than ``width``."""
    if not value.known or value.value is None or value.width is None:
        return None
    if width <= value.width:
        return value.value & _mask(width)
    return None


def _parse_immediate(operand: str) -> int | None:
    """Parse ``#value`` / ``#-value``; return ``None`` if it is not a valid immediate."""
    if not operand.startswith("#"):
        return None
    text = operand[1:].strip()
    negative = text.startswith("-")
    if negative:
        text = text[1:]
    if not text:
        return None
    # Both signs share one guarded parse, so a malformed negative immediate
    # yields None instead of propagating ValueError.
    try:
        magnitude = parse_int(text)
    except ValueError:
        return None
    return -magnitude if negative else magnitude


def _set_register_known(state: State, register: str, value: int, width: int, source: str) -> None:
    state["registers"][register] = _known(value, width, source)


def _set_register_unknown(state: State, register: str, reason: str) -> None:
    state["registers"][register] = _unknown(reason)


def _set_control_known(state: State, control: str, value: int, width: int, source: str) -> None:
    state["control"][control] = _known(value, width, source)
    _sync_status_views(state, control)


def _set_control_unknown(state: State, control: str, reason: str) -> None:
    state["control"][control] = _unknown(reason)
    _sync_status_views(state, control)


def _sync_status_views(state: State, written: str) -> None:
    """Keep CCR and SR coherent after one of them was written.

    On the H8/500 the CCR is the low byte of SR, so a write through either
    name also changes the other view.
    """
    controls = state["control"]
    if written == "SR":
        status = controls["SR"]
        low = _narrow(status, 8)
        if low is not None:
            controls["CCR"] = _known(low, 8, status.source)
        elif controls["CCR"].known:
            controls["CCR"] = _unknown(status.reason)
    elif written == "CCR":
        status = controls["SR"]
        if not status.known:
            return
        flags_value = controls["CCR"]
        high = _narrow(status, 16)
        flags = _narrow(flags_value, 8)
        if high is None or flags is None:
            controls["SR"] = _unknown(flags_value.reason or "ccr_write")
        else:
            controls["SR"] = _known((high & 0xFF00) | flags, 16, flags_value.source)


def _unknown_all(state: State, reason: str) -> None:
    for register in REGISTER_NAMES:
        _set_register_unknown(state, register, reason)
    for control in CONTROL_REGISTER_NAMES:
        _set_control_unknown(state, control, reason)


def _unknown_ccr(state: State, reason: str) -> None:
    _set_control_unknown(state, "CCR", reason)


def _apply_addressing_side_effects(state: State, operands: Iterable[str], width: int | None) -> None:
    """Invalidate registers used with pre-decrement (@-Rn) or post-increment (@Rn+)."""
    del width  # accepted for call-site symmetry; the step size is not modelled
    for operand in operands:
        match = _PREDECREMENT_RE.fullmatch(operand) or _POSTINCREMENT_RE.fullmatch(operand)
        if match:
            _set_register_unknown(state, match.group(1), "addressing_side_effect")


def _register_list(text: str) -> list[str]:
    """Return the registers named in a register list, expanding ``Rn-Rm`` ranges."""
    selected = {int(idx) for idx in _REGISTER_RE.findall(text)}
    for low, high in _REGISTER_RANGE_RE.findall(text):
        first, last = sorted((int(low), int(high)))
        selected.update(range(first, last + 1))
    return [name for idx, name in enumerate(REGISTER_NAMES) if idx in selected]


def _written_registers(base: str, ops: list[str]) -> list[str]:
    """Guess which general registers an otherwise-unmodelled instruction writes."""
    if base == "LDM" and len(ops) == 2:
        return _register_list(ops[1])
    if base == "XCH" and len(ops) == 2:
        # Exchange rewrites both operands.
        return list(dict.fromkeys(op for op in ops if _is_register(op)))
    if base in _SINGLE_OPERAND_WRITERS and ops:
        return [ops[0]] if _is_register(ops[0]) else []
    if len(ops) >= 2 and base not in _NON_WRITING_BASES:
        dest = ops[-1]
        return [dest] if _is_register(dest) else []
    return []


def _may_update_ccr(base: str) -> bool:
    return base not in _CCR_PRESERVING_BASES


def _state_changes(before: State, after: State) -> list[dict[str, object]]:
    """List every register/control whose tracked value differs between states."""
    changes: list[dict[str, object]] = []
    for group_name, public_name in (("registers", "register"), ("control", "control")):
        for name in before[group_name]:
            if before[group_name][name] == after[group_name][name]:
                continue
            changes.append(
                {
                    "kind": public_name,
                    "name": name,
                    "before": _public_value(before[group_name][name]),
                    "after": _public_value(after[group_name][name]),
                }
            )
    return changes


def _ends_basic_block(ins: Instruction, next_address: int | None) -> bool:
    if next_address is None:
        return True
    if ins.kind in {"branch", "jump", "return", "rte", "sleep"}:
        return True
    if not ins.fallthrough:
        return True
    return not _is_contiguous(previous_instruction=ins, address=next_address)


def _is_contiguous(previous_instruction: Instruction, address: int) -> bool:
    return previous_instruction.address + max(previous_instruction.size, 1) == address


def _mnemonic_base(mnemonic: str) -> str:
    """Strip a trailing ``.B`` / ``.W`` style size suffix."""
    return mnemonic.rsplit(".", 1)[0] if "." in mnemonic else mnemonic


def _mnemonic_width(mnemonic: str) -> int | None:
    """Return the operation width in bits implied by the mnemonic, if any."""
    suffix = mnemonic.rsplit(".", 1)[-1] if "." in mnemonic else ""
    if suffix == "B":
        return 8
    if suffix == "W":
        return 16
    if mnemonic.endswith(":I"):
        return 16
    if mnemonic.endswith(":E"):
        return 8
    return None


def _control_width(control: str, mnemonic_width: int | None) -> int:
    if control == "SR":
        return 16
    return mnemonic_width or 8


def _mask(width: int) -> int:
    return (1 << width) - 1


def _format_known(value: int, width: int) -> str:
    digits = 2 if width <= 8 else 4
    return f"0x{value & _mask(width):0{digits}X}"


def _is_register(operand: str) -> bool:
    return operand in REGISTER_NAMES


def _control_name(operand: str) -> str | None:
    operand = operand.strip()
    return operand if operand in CONTROL_REGISTER_NAMES else None