Skip to content

Commit

Permalink
AILBlockWalker: Support MultiStatementExpression objects. (#166)
Browse files Browse the repository at this point in the history
* AILBlockWalker: Support MultiStatementExpression objects.

* AILBlockWalker: Fix missing block is None checks.
  • Loading branch information
ltfish authored Sep 20, 2023
1 parent 81e5d4e commit 1733c41
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
64 changes: 59 additions & 5 deletions ailment/block_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Tmp,
Register,
Const,
MultiStatementExpression,
)


Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(self, stmt_handlers=None, expr_handlers=None):
Tmp: self._handle_Tmp,
Register: self._handle_Register,
Const: self._handle_Const,
MultiStatementExpression: self._handle_MultiStatementExpression,
}

self.stmt_handlers: Dict[Type, Callable] = stmt_handlers if stmt_handlers else _default_stmt_handlers
Expand Down Expand Up @@ -149,6 +151,13 @@ def _handle_Register(self, expr_idx: int, expr: Register, stmt_idx: int, stmt: S
def _handle_Const(self, expr_idx: int, expr: Const, stmt_idx: int, stmt: Statement, block: Optional[Block]):
pass

def _handle_MultiStatementExpression(
self, expr_idx, expr: MultiStatementExpression, stmt_idx: int, stmt: Statement, block: Optional[Block]
):
for idx, stmt_ in enumerate(expr.stmts):
self._handle_stmt(idx, stmt_, None)
self._handle_expr(0, expr.expr, stmt_idx, stmt, block)

def _handle_DirtyExpression(
self, expr_idx: int, expr: DirtyExpression, stmt_idx: int, stmt: Statement, block: Optional[Block]
):
Expand Down Expand Up @@ -218,7 +227,10 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Optional[Bl
if changed:
# update the statement directly in the block
new_stmt = Assignment(stmt.idx, dst, src, **stmt.tags)
block.statements[stmt_idx] = new_stmt
if block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None

def _handle_Call(self, stmt_idx: int, stmt: Call, block: Optional[Block]):
if stmt.args:
Expand Down Expand Up @@ -250,7 +262,10 @@ def _handle_Call(self, stmt_idx: int, stmt: Call, block: Optional[Block]):
ret_expr=stmt.ret_expr,
**stmt.tags,
)
block.statements[stmt_idx] = new_stmt
if block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None

def _handle_Store(self, stmt_idx: int, stmt: Store, block: Optional[Block]):
changed = False
Expand Down Expand Up @@ -280,7 +295,10 @@ def _handle_Store(self, stmt_idx: int, stmt: Store, block: Optional[Block]):
offset=stmt.offset,
**stmt.tags,
)
block.statements[stmt_idx] = new_stmt
if block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None

def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: Optional[Block]):
changed = False
Expand All @@ -305,7 +323,10 @@ def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: O

if changed:
new_stmt = ConditionalJump(stmt.idx, condition, true_target, false_target, **stmt.tags)
block.statements[stmt_idx] = new_stmt
if block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None

def _handle_Return(self, stmt_idx: int, stmt: Return, block: Optional[Block]):
if stmt.ret_exprs:
Expand All @@ -323,7 +344,14 @@ def _handle_Return(self, stmt_idx: int, stmt: Return, block: Optional[Block]):

if changed:
new_stmt = Return(stmt.idx, new_ret_exprs, **stmt.tags)
block.statements[stmt_idx] = new_stmt
if block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None

#
# Expression handlers
#

def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt: Statement, block: Optional[Block]):
addr = self._handle_expr(0, expr.addr, stmt_idx, stmt, block)
Expand Down Expand Up @@ -454,3 +482,29 @@ def _handle_VEXCCallExpression(
new_expr.operands = tuple(new_operands)
return new_expr
return None

def _handle_MultiStatementExpression(
self, expr_idx, expr: MultiStatementExpression, stmt_idx: int, stmt: Statement, block: Optional[Block]
):
changed = False
new_statements = []
for idx, stmt_ in enumerate(expr.stmts):
new_stmt = self._handle_stmt(idx, stmt_, None)
if new_stmt is not None and new_stmt is not stmt_:
changed = True
new_statements.append(new_stmt)
else:
new_statements.append(stmt_)

new_expr = self._handle_expr(0, expr.expr, stmt_idx, stmt, block)
if new_expr is not None and new_expr is not expr.expr:
changed = True
else:
new_expr = expr.expr

if changed:
expr_ = expr.copy()
expr_.expr = new_expr
expr_.stmts = new_statements
return expr_
return None
3 changes: 3 additions & 0 deletions ailment/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,9 @@ def replace(self, old_expr, new_expr):
)
return False, self

def copy(self) -> "MultiStatementExpression":
return MultiStatementExpression(self.idx, self.stmts[::], self.expr, **self.tags)


#
# Special (Dummy) expressions
Expand Down

0 comments on commit 1733c41

Please sign in to comment.