"""Scan raw memory dumps for expanded AES key schedules.

Every 4-byte-aligned window of each input file is decoded as 32-bit words
(both little- and big-endian) and checked against the AES key expansion
under both byte-within-word conventions; a match yields the cipher key.
"""

import argparse
import struct
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Sequence, Set, Tuple

SBOX = [
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
]

# Number of AES rounds (Nr) for each supported key length Nk (in 32-bit words).
_ROUNDS_FOR_NK = {4: 10, 6: 12, 8: 14}

# Suffixes scanned when walking a directory. Files named explicitly on the
# command line are always scanned, whatever their suffix.
_DUMP_SUFFIXES = (".dmp", ".bin", ".raw", ".mem", "")

_WordOp = Callable[[int], int]


def rot_word_be(w: int) -> int:
    """Rotate a word whose first key byte is the most significant byte."""
    return ((w << 8) & 0xFFFFFFFF) | ((w >> 24) & 0xFF)


def sub_word_be(w: int) -> int:
    """Apply the AES S-box to each byte of a 32-bit word (byte-position agnostic)."""
    return (
        (SBOX[(w >> 24) & 0xFF] << 24)
        | (SBOX[(w >> 16) & 0xFF] << 16)
        | (SBOX[(w >> 8) & 0xFF] << 8)
        | (SBOX[w & 0xFF])
    )


def rot_word_le(w: int) -> int:
    """Rotate a word whose first key byte is the least significant byte."""
    b = w.to_bytes(4, "little")
    b = b[1:] + b[:1]
    return int.from_bytes(b, "little")


def sub_word_le(w: int) -> int:
    """Apply the AES S-box to each byte of a little-endian-convention word."""
    b = w.to_bytes(4, "little")
    sb = bytes([SBOX[x] for x in b])
    return int.from_bytes(sb, "little")


def xtime(x: int) -> int:
    """Multiply *x* by 2 in GF(2^8), reducing modulo the AES polynomial 0x11B."""
    x <<= 1
    if x & 0x100:
        x ^= 0x11B
    return x & 0xFF


def rcon_word(i: int, *, endian: str) -> int:
    """Return the round constant word Rcon[i] (1-based *i*).

    The constant byte goes in the first key byte of the word: the top byte
    for ``endian="be"``, the low byte otherwise.
    """
    rc = 1
    for _ in range(1, i):
        rc = xtime(rc)
    if endian == "be":
        return rc << 24
    return rc


def total_words_for_nk(nk: int) -> int:
    """Return the expanded schedule length ``4 * (Nr + 1)`` in words.

    Raises:
        ValueError: if *nk* is not 4, 6 or 8.
    """
    nr = _ROUNDS_FOR_NK.get(nk)
    if nr is None:
        raise ValueError(f"Unsupported Nk={nk}")
    return 4 * (nr + 1)


def _word_ops(endian: str) -> Tuple[_WordOp, _WordOp]:
    """Return the (RotWord, SubWord) pair for byte convention 'be' or 'le'."""
    if endian == "be":
        return rot_word_be, sub_word_be
    if endian == "le":
        return rot_word_le, sub_word_le
    raise ValueError("endian must be 'be' or 'le'")


def _iter_expanded(
    key_words: Sequence[int],
    *,
    nk: int,
    tw: int,
    endian: str,
    ops: Tuple[_WordOp, _WordOp],
) -> Iterator[int]:
    """Lazily yield schedule words w[nk] .. w[tw-1] derived from the first nk words.

    Generating words one at a time lets callers abandon a candidate at the
    first mismatch instead of expanding the whole schedule.
    """
    rot, sub = ops
    w = list(key_words[:nk])
    for i in range(nk, tw):
        temp = w[i - 1]
        if i % nk == 0:
            temp = sub(rot(temp)) ^ rcon_word(i // nk, endian=endian)
        elif nk > 6 and i % nk == 4:
            # Only for Nk=8 (AES-256): extra SubWord mid-way through each block.
            temp = sub(temp)
        nxt = w[i - nk] ^ temp
        w.append(nxt)
        yield nxt


def expand_key(words0: Sequence[int], *, nk: int, endian: str) -> List[int]:
    """Expand the first *nk* words of *words0* into a full AES key schedule.

    Args:
        words0: Key words; only the first *nk* are used.
        nk: Key length in 32-bit words (4, 6 or 8).
        endian: ``"be"`` if a word's first key byte is its most significant
            byte, ``"le"`` if it is the least significant byte.

    Returns:
        ``total_words_for_nk(nk)`` schedule words, starting with the key words.

    Raises:
        ValueError: for an unsupported *nk* or *endian*, or if *words0*
            holds fewer than *nk* words.
    """
    tw = total_words_for_nk(nk)
    ops = _word_ops(endian)
    if len(words0) < nk:
        raise ValueError(f"need at least {nk} key words, got {len(words0)}")
    w = list(words0[:nk])
    w.extend(_iter_expanded(words0, nk=nk, tw=tw, endian=endian, ops=ops))
    return w


def schedule_matches(words: Sequence[int], *, nk: int, endian: str) -> bool:
    """Return True if ``words[:tw]`` is exactly the schedule expanded from ``words[:nk]``.

    Comparison stops at the first mismatching word, so non-schedule data is
    rejected after roughly one expansion step.

    Raises:
        ValueError: for an unsupported *nk* or (when enough words are given)
            an unsupported *endian*.
    """
    tw = total_words_for_nk(nk)
    if len(words) < tw:
        return False
    ops = _word_ops(endian)
    key_words = words[:nk]
    # Key words outside the 32-bit range can never form a valid schedule.
    if min(key_words) < 0 or max(key_words) > 0xFFFFFFFF:
        return False
    expanded = _iter_expanded(key_words, nk=nk, tw=tw, endian=endian, ops=ops)
    for i, expected in enumerate(expanded, start=nk):
        if (words[i] & 0xFFFFFFFF) != expected:
            return False
    return True


def iter_files(paths: Iterable[str]) -> Iterator[Path]:
    """Yield each existing file in *paths*, recursing into directories in sorted order."""
    for p in paths:
        path = Path(p)
        if path.is_dir():
            for child in sorted(path.rglob("*")):
                if child.is_file():
                    yield child
        elif path.is_file():
            yield path


@dataclass(frozen=True)
class Hit:
    """One key schedule located in a file."""

    file: Path  # file the schedule was found in
    offset: int  # byte offset of the schedule's first word
    nk: int  # key length in 32-bit words
    word_endian: str  # "little"/"big": how raw bytes were decoded into words
    byte_endian: str  # "be"/"le": key-byte order convention within a word
    key_bytes: bytes  # recovered AES key (4 * nk bytes)


def scan_file(fp: Path, *, nk_list: Sequence[int]) -> List[Hit]:
    """Return every AES key schedule found at 4-byte-aligned offsets of *fp*.

    Each window is decoded as little- and big-endian words and tested under
    both the ``"be"`` and ``"le"`` byte conventions.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if *nk_list* contains an unsupported Nk.
    """
    data = fp.read_bytes()
    hits: List[Hit] = []
    for nk in nk_list:
        tw = total_words_for_nk(nk)
        nbytes = tw * 4
        if len(data) < nbytes:
            continue
        # One precompiled C-level decode per window instead of tw
        # int.from_bytes calls; "<I"/">I" are always 4 bytes wide.
        decoders = (
            ("little", struct.Struct(f"<{tw}I")),
            ("big", struct.Struct(f">{tw}I")),
        )
        for off in range(0, len(data) - nbytes + 1, 4):
            for word_endian, decoder in decoders:
                words = decoder.unpack_from(data, off)
                for byte_endian in ("be", "le"):
                    if not schedule_matches(words, nk=nk, endian=byte_endian):
                        continue
                    key_order = "big" if byte_endian == "be" else "little"
                    key = b"".join(w.to_bytes(4, key_order) for w in words[:nk])
                    hits.append(
                        Hit(
                            file=fp,
                            offset=off,
                            nk=nk,
                            word_endian=word_endian,
                            byte_endian=byte_endian,
                            key_bytes=key,
                        )
                    )
    return hits


def _parse_nk_list(ap: argparse.ArgumentParser, primary: int, also: str) -> List[int]:
    """Merge --nk and --also into a sorted, de-duplicated list of Nk values.

    Invalid entries abort via ``ap.error`` (exit status 2). Without this
    check, an unsupported value would make every file fail during the scan.
    """
    nk_values = {primary}
    for part in also.split(","):
        part = part.strip()
        if not part:
            continue
        try:
            nk = int(part)
        except ValueError:
            ap.error(f"--also: {part!r} is not an integer")
        if nk not in _ROUNDS_FOR_NK:
            ap.error(f"--also: unsupported Nk={nk} (choose from 4, 6, 8)")
        nk_values.add(nk)
    return sorted(nk_values)


def _iter_candidate_files(paths: Iterable[str]) -> Iterator[Path]:
    """Yield files to scan: explicit files always, directory contents filtered by suffix."""
    for p in paths:
        path = Path(p)
        if path.is_file():
            yield path
        elif path.is_dir():
            for child in iter_files([p]):
                if child.suffix.lower() in _DUMP_SUFFIXES:
                    yield child
        else:
            print(f"[!] Skipping {p}: not a file or directory", file=sys.stderr)


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Command-line entry point.

    Args:
        argv: Argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 if at least one key schedule was printed, 2 otherwise.
    """
    ap = argparse.ArgumentParser(description="Find AES key schedules in raw dumps")
    ap.add_argument("paths", nargs="+", help="Files and/or directories to scan")
    ap.add_argument(
        "--nk",
        type=int,
        default=8,
        choices=sorted(_ROUNDS_FOR_NK),
        help="AES Nk words: 4=128-bit, 6=192-bit, 8=256-bit (default: 8)",
    )
    ap.add_argument(
        "--also",
        type=str,
        default="",
        help="Comma-separated extra Nk values to scan (e.g. 4,6)",
    )
    args = ap.parse_args(argv)
    nk_list = _parse_nk_list(ap, args.nk, args.also)

    seen: Set[Tuple[int, str, str, bytes]] = set()
    total = 0
    for fp in _iter_candidate_files(args.paths):
        try:
            hits = scan_file(fp, nk_list=nk_list)
        except (OSError, MemoryError) as exc:
            # Unreadable or oversized files are reported and skipped;
            # they do not abort the whole scan.
            print(f"[!] Cannot scan {fp}: {exc!r}", file=sys.stderr)
            continue
        for h in hits:
            k = (h.nk, h.word_endian, h.byte_endian, h.key_bytes)
            if k in seen:
                continue
            seen.add(k)
            total += 1
            print(
                f"{h.file}\t0x{h.offset:X}\tNk={h.nk}\tword={h.word_endian}\tbytes={h.byte_endian}\tkey={h.key_bytes.hex()}"
            )
    if total == 0:
        print("[!] No AES key schedules found")
        return 2
    return 0


if __name__ == "__main__":
    raise SystemExit(main())