from typing import Set, Dict
from collections import defaultdict
import logging
import ailment
from .optimization_pass import OptimizationPass, OptimizationPassStage
_l = logging.getLogger(name=__name__)
[docs]def s2u(s, bits):
if s > 0:
return s
return (1 << bits) + s
[docs]class StackCanarySimplifier(OptimizationPass):
"""
Removes stack canary checks from decompilation results.
"""
ARCHES = [
"X86",
"AMD64",
] # TODO: fs is x86 only. Figure out how stack canary is loaded in other architectures
PLATFORMS = ["cgc", "linux"]
STAGE = OptimizationPassStage.AFTER_GLOBAL_SIMPLIFICATION
NAME = "Simplify stack canaries"
DESCRIPTION = __doc__.strip()
[docs] def __init__(self, func, **kwargs):
super().__init__(func, **kwargs)
self.analyze()
def _check(self):
# Check the first block and see if there is any statement reading data from fs:0x28h
init_stmt = self._find_canary_init_stmt()
return init_stmt is not None, {"init_stmt": init_stmt}
def _analyze(self, cache=None):
init_stmt = None
if cache is not None:
init_stmt = cache.get("init_stmt", None)
if init_stmt is None:
init_stmt = self._find_canary_init_stmt()
if init_stmt is None:
return
# Look for the statement that loads back canary value from the stack
first_block, stmt_idx = init_stmt
canary_init_stmt = first_block.statements[stmt_idx]
# where is the stack canary stored?
if not isinstance(canary_init_stmt.addr, ailment.Expr.StackBaseOffset):
_l.debug(
"Unsupported canary storing location %s. Expects an ailment.Expr.StackBaseOffset.",
canary_init_stmt.addr,
)
return
store_offset = canary_init_stmt.addr.offset
if not isinstance(store_offset, int):
_l.debug("Unsupported canary storing offset %s. Expects an int.", store_offset)
# The function should have at least one end point with an if-else statement
# Find all nodes with 0 out-degrees
all_endpoint_addrs = [node.addr for node in self._func.graph.nodes() if self._func.graph.out_degree(node) == 0]
# Before node duplication, each pair of canary-check-success and canary-check-failure nodes have a common
# predecessor.
# map endpoint addrs to their common predecessors
pred_addr_to_endpoint_addrs: Dict[int, Set[int]] = defaultdict(set)
for node_addr in all_endpoint_addrs:
preds = self._func.graph.predecessors(self._func.get_node(node_addr))
for pred in preds:
pred_addr_to_endpoint_addrs[pred.addr].add(node_addr)
found_endpoints = False
for pred_addr in pred_addr_to_endpoint_addrs:
endpoint_addrs = pred_addr_to_endpoint_addrs[pred_addr]
if len(endpoint_addrs) != 2:
# we expect there to be only two nodes: one for canary-check-success, and the other for
# canary-check-failure. if not, we check the next predecessor
continue
# because other optimization passes may duplicate nodes, we may have more than one node for each function
# endpoint.
endpoint_addrs_list = list(endpoint_addrs)
endnodes_0 = list(self._get_blocks(endpoint_addrs_list[0]))
endnodes_1 = list(self._get_blocks(endpoint_addrs_list[1]))
if not endnodes_0 or not endnodes_1:
_l.warning("Unexpected situation: endnodes_0 or endnodes_1 is empty.")
continue
# One of the end nodes calls __stack_chk_fail
stack_chk_fail_callers = None
other_nodes = None
for endnodes, o in [(endnodes_0, endnodes_1), (endnodes_1, endnodes_0)]:
if self._calls_stack_chk_fail(endnodes[0]):
stack_chk_fail_callers = endnodes
other_nodes = o
break
else:
_l.debug("Cannot find the node that calls __stack_chk_fail().")
continue
# Match stack_chk_fail_caller, ret_node, and predecessor
nodes_to_process = []
for stack_chk_fail_caller in stack_chk_fail_callers:
all_preds = set(self._graph.predecessors(stack_chk_fail_caller))
preds_for_other_nodes = set()
for o in other_nodes:
preds_for_other_nodes |= set(self._graph.predecessors(o))
preds = list(all_preds.intersection(preds_for_other_nodes))
if len(preds) != 1:
_l.debug("Expect 1 predecessor. Found %d.", len(preds))
continue
pred = preds[0]
# More sanity checks
if len(pred.statements) < 1:
_l.debug("The predecessor node is empty.")
continue
# Check the last statement
if not isinstance(pred.statements[-1], ailment.Stmt.ConditionalJump):
_l.debug("The predecessor does not end with a conditional jump.")
continue
# Find the statement that computes real canary value xor stored canary value
canary_check_stmt_idx = self._find_canary_comparison_stmt(pred, store_offset)
if canary_check_stmt_idx is None:
_l.debug("Cannot find the canary check statement in the predecessor.")
continue
succs = list(self._graph.successors(pred))
if len(succs) != 2:
_l.debug("Expect 2 successors. Found %d.", len(succs))
continue
if stack_chk_fail_caller is succs[0]:
ret_node = succs[1]
else:
ret_node = succs[0]
nodes_to_process.append((pred, canary_check_stmt_idx, stack_chk_fail_caller, ret_node))
# Awesome. Now patch this function.
for pred, canary_check_stmt_idx, stack_chk_fail_caller, ret_node in nodes_to_process:
# Patch the pred so that it jumps to the one that is not stack_chk_fail_caller
pred_copy = pred.copy()
pred_copy.statements[-1] = ailment.Stmt.Jump(
len(pred_copy.statements) - 1,
ailment.Expr.Const(None, None, ret_node.addr, self.project.arch.bits),
ins_addr=pred_copy.statements[-1].ins_addr,
)
self._graph.remove_edge(pred, stack_chk_fail_caller)
self._update_block(pred, pred_copy)
if self._graph.in_degree[stack_chk_fail_caller] == 0:
# Remove the block that calls stack_chk_fail_caller
self._remove_block(stack_chk_fail_caller)
found_endpoints = True
if found_endpoints:
# Remove the statement that loads the stack canary from fs
first_block_copy = first_block.copy()
first_block_copy.statements.pop(stmt_idx)
self._update_block(first_block, first_block_copy)
# Done!
def _find_canary_init_stmt(self):
block_addr = self._func.addr
traversed = set()
while True:
traversed.add(block_addr)
first_block = self._get_block(block_addr)
if first_block is None:
break
for idx, stmt in enumerate(first_block.statements):
if (
isinstance(stmt, ailment.Stmt.Store)
and isinstance(stmt.addr, ailment.Expr.StackBaseOffset)
and isinstance(stmt.data, ailment.Expr.Load)
and self._is_add(stmt.data.addr)
):
# Check addr: must be fs+0x28
op0, op1 = stmt.data.addr.operands
if isinstance(op1, ailment.Expr.Register):
op0, op1 = op1, op0
if isinstance(op0, ailment.Expr.Register) and isinstance(op1, ailment.Expr.Const):
if op0.reg_offset == self.project.arch.get_register_offset("fs") and op1.value == 0x28:
return first_block, idx
succs = list(self._graph.successors(first_block))
if len(succs) == 1:
block_addr = succs[0].addr
if block_addr not in traversed:
continue
break
return None
def _find_canary_comparison_stmt(self, block, canary_value_stack_offset):
for idx, stmt in enumerate(block.statements):
if isinstance(stmt, ailment.Stmt.ConditionalJump):
if isinstance(stmt.condition, ailment.Expr.UnaryOp) and stmt.condition.op == "Not":
# !(s_10 ^ fs:0x28 == 0)
negated = True
condition = stmt.condition.operand
else:
negated = False
condition = stmt.condition
if isinstance(condition, ailment.Expr.BinaryOp) and (
not negated and condition.op == "CmpEQ" or negated and condition.op == "CmpNE"
):
pass
else:
continue
expr0, expr1 = condition.operands
if isinstance(expr0, ailment.Expr.BinaryOp) and expr0.op == "Xor":
# a ^ b
op0, op1 = expr0.operands
if not (
self._is_stack_canary_load_expr(op0, self.project.arch.bits, canary_value_stack_offset)
and self._is_random_number_load_expr(op1, self.project.arch.get_register_offset("fs"))
or (
self._is_stack_canary_load_expr(op1, self.project.arch.bits, canary_value_stack_offset)
and self._is_random_number_load_expr(op0, self.project.arch.get_register_offset("fs"))
)
):
continue
elif (
isinstance(expr0, ailment.Expr.Load)
and isinstance(expr1, ailment.Expr.Load)
and condition.op == "CmpEQ"
):
# a == b
if not (
self._is_stack_canary_load_expr(expr0, self.project.arch.bits, canary_value_stack_offset)
and self._is_random_number_load_expr(expr1, self.project.arch.get_register_offset("fs"))
or (
self._is_stack_canary_load_expr(expr1, self.project.arch.bits, canary_value_stack_offset)
and self._is_random_number_load_expr(expr0, self.project.arch.get_register_offset("fs"))
)
):
continue
# Found it
return idx
return None
def _calls_stack_chk_fail(self, node):
for stmt in node.statements:
if isinstance(stmt, ailment.Stmt.Call) and isinstance(stmt.target, ailment.Expr.Const):
const_target = stmt.target.value
if const_target in self.kb.functions:
func = self.kb.functions.function(addr=const_target)
if func.name == "__stack_chk_fail":
return True
return False
@staticmethod
def _is_stack_canary_load_expr(expr, bits: int, canary_value_stack_offset: int) -> bool:
if not (isinstance(expr, ailment.Expr.Load) and isinstance(expr.addr, ailment.Expr.StackBaseOffset)):
return False
if s2u(expr.addr.offset, bits) != s2u(canary_value_stack_offset, bits):
return False
return True
@staticmethod
def _is_random_number_load_expr(expr, fs_reg_offset: int) -> bool:
return (
isinstance(expr, ailment.Expr.Load)
and isinstance(expr.addr, ailment.Expr.BinaryOp)
and expr.addr.op == "Add"
and isinstance(expr.addr.operands[0], ailment.Expr.Const)
and expr.addr.operands[0].value == 0x28
and isinstance(expr.addr.operands[1], ailment.Expr.Register)
and expr.addr.operands[1].reg_offset == fs_reg_offset
)