Skip to content

Commit

Permalink
Minimize direct imports from claripy.ast (#558)
Browse files Browse the repository at this point in the history
* Minimize direct imports from claripy.ast

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
twizmwazin and pre-commit-ci[bot] authored Oct 30, 2024
1 parent 6c103a7 commit 798c07c
Show file tree
Hide file tree
Showing 24 changed files with 63 additions and 76 deletions.
2 changes: 1 addition & 1 deletion claripy/algorithm/bool_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from claripy.errors import BackendError

if TYPE_CHECKING:
from claripy.ast.bool import Bool
from claripy.ast import Bool

log = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion claripy/algorithm/ite_relocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from weakref import WeakValueDictionary

import claripy
from claripy.ast.base import Base
from claripy.ast import Base

T = TypeVar("T", bound=Base)

Expand Down
2 changes: 1 addition & 1 deletion claripy/algorithm/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, TypeVar

from claripy.ast.base import Base
from claripy.ast import Base
from claripy.errors import ClaripyReplacementError

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion claripy/algorithm/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TypeVar, cast
from weakref import WeakValueDictionary

from claripy.ast.base import Base
from claripy.ast import Base

log = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion claripy/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import threading
from contextlib import suppress

from claripy.ast.base import Base
from claripy.ast import Base
from claripy.errors import BackendError, BackendUnsupportedError, ClaripyRecursionError

log = logging.getLogger(__name__)
Expand Down
15 changes: 6 additions & 9 deletions claripy/backends/backend_concrete/backend_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
import operator
from functools import reduce

from claripy.ast import Base
from claripy.ast.bool import Bool, BoolV
from claripy.ast.bv import BV, BVV
from claripy.ast.fp import FPV
from claripy.ast.strings import StringV
import claripy
from claripy.ast import BV, Base, Bool
from claripy.backends.backend import Backend
from claripy.backends.backend_concrete import bv, fp, strings
from claripy.errors import BackendError, UnsatError
Expand Down Expand Up @@ -144,13 +141,13 @@ def _convert(self, r):

def _abstract(self, e): # pylint:disable=no-self-use
if isinstance(e, bv.BVV):
return BVV(e.value, e.size())
return claripy.BVV(e.value, e.size())
if isinstance(e, bool):
return BoolV(e)
return claripy.BoolV(e)
if isinstance(e, fp.FPV):
return FPV(e.value, e.sort)
return claripy.FPV(e.value, e.sort)
if isinstance(e, strings.StringV):
return StringV(e.value)
return claripy.StringV(e.value)
raise BackendError(f"Couldn't abstract object of type {type(e)}")

def _cardinality(self, a): # pylint:disable=unused-argument
Expand Down
2 changes: 1 addition & 1 deletion claripy/backends/backend_concrete/bv.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def normalize_helper(self, o):

class BVV(BackendObject):
"""A concrete bitvector value. Used in the concrete backend for calculations.
Any use outside of claripy should use `claripy.ast.bv.BVV` instead.
Any use outside of claripy should use `claripy.BVV` instead.
"""

__slots__ = ["bits", "_value", "mod"]
Expand Down
2 changes: 1 addition & 1 deletion claripy/backends/backend_concrete/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def normalize_helper(self, o):

class FPV(BackendObject):
"""A concrete floating point value. Used in the concrete backend for
calculations. Any use outside of claripy should use `claripy.ast.fp.FPV`
calculations. Any use outside of claripy should use `claripy.FPV`
instead.
"""

Expand Down
2 changes: 1 addition & 1 deletion claripy/backends/backend_concrete/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class StringV(BackendObject):
"""A concrete string value. Used in the concrete backend for calculations.
Any use outside of claripy should use `claripy.ast.strings.StringV` instead.
Any use outside of claripy should use `claripy.StringV` instead.
"""

def __init__(self, value):
Expand Down
15 changes: 7 additions & 8 deletions claripy/backends/backend_vsa/backend_vsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

import claripy
from claripy.annotation import RegionAnnotation, StridedIntervalAnnotation
from claripy.ast.base import Base
from claripy.ast.bv import BV, BVV, ESI, SI, TSI, VS
from claripy.ast import BV, Base
from claripy.backends.backend import Backend
from claripy.backends.backend_vsa.balancer import Balancer
from claripy.backends.backend_vsa.errors import ClaripyVSAError
Expand Down Expand Up @@ -127,12 +126,12 @@ def _abstract(self, e):
return e
if isinstance(e, StridedInterval):
if e.is_top:
return TSI(e.bits, explicit_name=e.name)
return claripy.TSI(e.bits, explicit_name=e.name)
if e.is_bottom:
return ESI(e.bits)
return claripy.ESI(e.bits)
if e.stride in {0, 1} and e.lower_bound == e.upper_bound:
return BVV(e.lower_bound, e.bits)
return SI(
return claripy.BVV(e.lower_bound, e.bits)
return claripy.SI(
name=e.name,
bits=e.bits,
lower_bound=e.lower_bound,
Expand All @@ -141,10 +140,10 @@ def _abstract(self, e):
)
if isinstance(e, ValueSet):
if len(e.regions) == 0:
return VS(bits=e.bits, name=e.name)
return claripy.VS(bits=e.bits, name=e.name)
if len(e.regions) == 1:
region = next(iter(e.regions))
return VS(
return claripy.VS(
bits=e.bits,
region=region,
region_base_addr=e._region_base_addrs[region].eval(1)[0] if e._region_base_addrs else 0,
Expand Down
14 changes: 6 additions & 8 deletions claripy/backends/backend_vsa/balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

import claripy
import claripy.backends.backend_vsa as vsa
from claripy.ast.base import Base
from claripy.ast.bool import Bool
from claripy.ast.bv import BV, BVS, BVV
from claripy.ast import BV, Base, Bool
from claripy.errors import BackendError, ClaripyBalancerError, ClaripyBalancerUnsatError, ClaripyOperationError
from claripy.operations import commutative_operations, opposites

Expand Down Expand Up @@ -47,7 +45,7 @@ def _replacements_iter(self):
min_int = 0
mn = self._lower_bounds.get(k, min_int)
mx = self._upper_bounds.get(k, max_int)
bound_si = BVS("bound", len(ast)).annotate(claripy.annotation.StridedIntervalAnnotation(1, mn, mx))
bound_si = claripy.BVS("bound", len(ast)).annotate(claripy.annotation.StridedIntervalAnnotation(1, mn, mx))
l.debug("Yielding bound %s for %s.", bound_si, ast)
if ast.op == "Reverse":
yield (ast.args[0], ast.intersection(bound_si).reversed)
Expand Down Expand Up @@ -83,7 +81,7 @@ def _same_bound_bv(a):
si = claripy.backends.vsa.convert(a)
mx = Balancer._max(a)
mn = Balancer._min(a)
return BVS("bounds", len(a), min=mn, max=mx, stride=si._stride)
return claripy.BVS("bounds", len(a), min=mn, max=mx, stride=si._stride)

@staticmethod
def _cardinality(a):
Expand Down Expand Up @@ -456,15 +454,15 @@ def _balance_extract(truism):

if left_msb_zero and left_lsb_zero:
new_left = inner
new_right = claripy.Concat(BVV(0, len(left_msb)), truism.args[1], BVV(0, len(left_lsb)))
new_right = claripy.Concat(claripy.BVV(0, len(left_msb)), truism.args[1], claripy.BVV(0, len(left_lsb)))
return truism.make_like(truism.op, (new_left, new_right))
if left_msb_zero:
new_left = inner
new_right = claripy.Concat(BVV(0, len(left_msb)), truism.args[1])
new_right = claripy.Concat(claripy.BVV(0, len(left_msb)), truism.args[1])
return truism.make_like(truism.op, (new_left, new_right))
if left_lsb_zero:
new_left = inner
new_right = claripy.Concat(truism.args[1], BVV(0, len(left_lsb)))
new_right = claripy.Concat(truism.args[1], claripy.BVV(0, len(left_lsb)))
return truism.make_like(truism.op, (new_left, new_right))

if low == 0 and truism.args[1].op == "BVV" and truism.op not in {"SGE", "SLE", "SGT", "SLT"}:
Expand Down
2 changes: 1 addition & 1 deletion claripy/backends/backend_vsa/strided_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import claripy
import claripy.backends.backend_vsa as vsa
from claripy.ast.base import Base
from claripy.ast import Base
from claripy.backends.backend_concrete import BVV
from claripy.backends.backend_object import BackendObject
from claripy.errors import ClaripyOperationError
Expand Down
2 changes: 1 addition & 1 deletion claripy/backends/backend_vsa/valueset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import itertools
import numbers

from claripy.ast.base import Base
from claripy.ast import Base
from claripy.backends.backend_object import BackendObject
from claripy.errors import ClaripyValueError

Expand Down
18 changes: 8 additions & 10 deletions claripy/backends/backend_z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
import z3
from cachetools import LRUCache

from claripy.ast.bool import Bool, BoolV
from claripy.ast.bv import BV, BVV
from claripy.ast.fp import FP, FPV
from claripy.ast.strings import StringV
import claripy
from claripy.ast import BV, FP, Bool
from claripy.backends.backend import Backend
from claripy.errors import (
BackendError,
Expand Down Expand Up @@ -446,25 +444,25 @@ def _abstract_internal(self, ctx, ast, split_on=None):
append_children = True

if op_name == "True":
return BoolV(True)
return claripy.true()
if op_name == "False":
return BoolV(False)
return claripy.false()
if op_name.startswith("RM_"):
return RM(op_name)
if op_name == "INTERNAL":
return StringV(z3.SeqRef(ast).as_string())
return claripy.StringV(z3.SeqRef(ast).as_string())
if op_name == "BitVecVal":
bv_size = z3.Z3_get_bv_sort_size(ctx, z3_sort)
if z3.Z3_get_numeral_uint64(ctx, ast, self._c_uint64_p):
return BVV(self._c_uint64_p.contents.value, bv_size)
return claripy.BVV(self._c_uint64_p.contents.value, bv_size)
bv_num = int(z3.Z3_get_numeral_string(ctx, ast))
return BVV(bv_num, bv_size)
return claripy.BVV(bv_num, bv_size)
if op_name in ("FPVal", "MinusZero", "MinusInf", "PlusZero", "PlusInf", "NaN"):
ebits = z3.Z3_fpa_get_ebits(ctx, z3_sort)
sbits = z3.Z3_fpa_get_sbits(ctx, z3_sort)
sort = FSort.from_params(ebits, sbits)
val = self._abstract_fp_val(ctx, ast, op_name)
return FPV(val, sort)
return claripy.FPV(val, sort)

if op_name == "UNINTERPRETED" and num_args == 0: # symbolic value
symbol_name = _z3_decl_name_str(ctx, decl)
Expand Down
3 changes: 1 addition & 2 deletions claripy/frontend/composite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import weakref
from typing import TYPE_CHECKING

from claripy import backends
from claripy import Or, backends
from claripy.ast import Base
from claripy.ast.bool import Or
from claripy.errors import BackendError, UnsatError

from .constrained_frontend import ConstrainedFrontend
Expand Down
2 changes: 1 addition & 1 deletion claripy/frontend/constrained_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import logging

from claripy import And, Or
from claripy.algorithm import simplify
from claripy.annotation import SimplificationAvoidanceAnnotation
from claripy.ast.bool import And, Or

from .frontend import Frontend

Expand Down
5 changes: 2 additions & 3 deletions claripy/frontend/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import logging
import numbers

from claripy.ast.bool import BoolV
from claripy.ast.bv import BV, BVV
from claripy.ast.strings import String, StringV
from claripy import BVV, BoolV, StringV
from claripy.ast import BV, String

log = logging.getLogger(__name__)

Expand Down
14 changes: 6 additions & 8 deletions claripy/frontend/full_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import threading
from typing import TYPE_CHECKING, Any, overload

import claripy
from claripy import backends
from claripy.ast.bv import SGE, SLE, UGE, ULE
from claripy.errors import BackendError, ClaripyFrontendError, UnsatError

from .constrained_frontend import ConstrainedFrontend
Expand All @@ -15,9 +15,7 @@

from typing_extensions import Self

from claripy.ast.bool import Bool
from claripy.ast.bv import BV
from claripy.ast.fp import FP
from claripy.ast import BV, FP, Bool

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -228,9 +226,9 @@ def max(self, e, extra_constraints=(), signed=False, exact=None):
return two[0]

if signed:
c = (*tuple(extra_constraints), SGE(e, two[0]), SGE(e, two[1]))
c = (*tuple(extra_constraints), claripy.SGE(e, two[0]), claripy.SGE(e, two[1]))
else:
c = (*tuple(extra_constraints), UGE(e, two[0]), UGE(e, two[1]))
c = (*tuple(extra_constraints), claripy.UGE(e, two[0]), claripy.UGE(e, two[1]))
try:
return self._solver_backend.max(
e,
Expand Down Expand Up @@ -271,9 +269,9 @@ def min(self, e, extra_constraints=(), signed=False, exact=None):
return two[0]

if signed:
c = (*tuple(extra_constraints), SLE(e, two[0]), SLE(e, two[1]))
c = (*tuple(extra_constraints), claripy.SLE(e, two[0]), claripy.SLE(e, two[1]))
else:
c = (*tuple(extra_constraints), ULE(e, two[0]), ULE(e, two[1]))
c = (*tuple(extra_constraints), claripy.ULE(e, two[0]), claripy.ULE(e, two[1]))
try:
return self._solver_backend.min(
e,
Expand Down
9 changes: 4 additions & 5 deletions claripy/frontend/mixin/constraint_expansion_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import logging

from claripy.ast.bool import Or
from claripy.ast.bv import SGE, SLE, UGE, ULE
import claripy

log = logging.getLogger(__name__)

Expand All @@ -20,20 +19,20 @@ def eval(self, e, n, extra_constraints=(), exact=None):
# add constraints to help the solver out later
# TODO: does this really help?
if len(extra_constraints) == 0 and len(results) < n:
self.add([Or(*[e == v for v in results])], invalidate_cache=False)
self.add([claripy.Or(*[e == v for v in results])], invalidate_cache=False)

return results

def max(self, e, extra_constraints=(), signed=False, exact=None):
m = super().max(e, extra_constraints=extra_constraints, signed=signed, exact=exact)
if len(extra_constraints) == 0:
self.add([SLE(e, m) if signed else ULE(e, m)], invalidate_cache=False)
self.add([claripy.SLE(e, m) if signed else claripy.ULE(e, m)], invalidate_cache=False)
return m

def min(self, e, extra_constraints=(), signed=False, exact=None):
m = super().min(e, extra_constraints=extra_constraints, signed=signed, exact=exact)
if len(extra_constraints) == 0:
self.add([SGE(e, m) if signed else UGE(e, m)], invalidate_cache=False)
self.add([claripy.SGE(e, m) if signed else claripy.UGE(e, m)], invalidate_cache=False)
return m

def solution(self, e, v, extra_constraints=(), exact=None):
Expand Down
Loading

0 comments on commit 798c07c

Please sign in to comment.