diff --git a/cpapi/mgmt_api.py b/cpapi/mgmt_api.py index 62972a4..1e8e3ce 100644 --- a/cpapi/mgmt_api.py +++ b/cpapi/mgmt_api.py @@ -38,9 +38,10 @@ class APIClientArgs: # port is set to None by default, but it gets replaced with 443 if not specified # context possible values - web_api (default) or gaia_api + # single_conn is set to True by default, when work on parallel set to False def __init__(self, port=None, fingerprint=None, sid=None, server="127.0.0.1", http_debug_level=0, api_calls=None, debug_file="", proxy_host=None, proxy_port=8080, - api_version=None, unsafe=False, unsafe_auto_accept=False, context="web_api"): + api_version=None, unsafe=False, unsafe_auto_accept=False, context="web_api", single_conn=True): self.port = port # management server fingerprint self.fingerprint = fingerprint @@ -66,6 +67,8 @@ def __init__(self, port=None, fingerprint=None, sid=None, server="127.0.0.1", ht self.unsafe_auto_accept = unsafe_auto_accept # The context of using the client - defaults to web_api self.context = context + # Indicates that the client should use single HTTPS connection + self.single_conn = single_conn class APIClient: @@ -108,6 +111,10 @@ def __init__(self, api_client_args=None): self.unsafe_auto_accept = api_client_args.unsafe_auto_accept # The context of using the client - defaults to web_api self.context = api_client_args.context + # HTTPS connection + self.conn = None + # Indicates that the client should use single HTTPS connection + self.single_conn = api_client_args.single_conn def __enter__(self): return self @@ -117,6 +124,7 @@ def __exit__(self, exc_type, exc_value, traceback): # if sid is not empty (the login api was called), then call logout if self.sid: self.api_call("logout") + self.close_connection() # save debug data with api calls to disk self.save_debug_data() @@ -265,8 +273,8 @@ def api_call(self, command, payload=None, sid=None, wait_for_task=True, timeout= :side-effects: updates the class's uid and server variables """ timeout_start = time.time() - - self.check_fingerprint() + if self.check_fingerprint() is False: + return APIResponse("", False, err_message="Invalid fingerprint") if payload is None: payload = {} # Convert the json payload to a string if needed @@ -292,23 +300,8 @@ def api_call(self, command, payload=None, sid=None, wait_for_task=True, timeout= if sid is not None: _headers["X-chkp-sid"] = sid - # Create ssl context with no ssl verification, we do it by ourselves - context = ssl.create_default_context() - context.check_hostname = False - context.verify_mode = ssl.CERT_NONE - - # create https connection - if self.proxy_host and self.proxy_port: - conn = HTTPSConnection(self.proxy_host, self.proxy_port, context=context) - conn.set_tunnel(self.server, self.get_port()) - else: - conn = HTTPSConnection(self.server, self.get_port(), context=context) - - # Set fingerprint - conn.fingerprint = self.fingerprint - - # Set debug level - conn.set_debuglevel(self.http_debug_level) + # init https connection. if single connection is True, use last connection + conn = self.get_https_connection() url = "/" + self.context + "/" + (("v" + str(self.api_version) + "/") if self.api_version else "") + command response = None try: @@ -328,7 +321,8 @@ def api_call(self, command, payload=None, sid=None, wait_for_task=True, timeout= except Exception as err: res = APIResponse("", False, err_message=err) finally: - conn.close() + if not self.single_conn: + conn.close() if response: res.status_code = response.status @@ -464,21 +458,13 @@ def gen_api_query(self, command, details_level="standard", container_keys=None, def get_server_fingerprint(self): """ - Initiates an HTTPS connection to the server and extracts the SHA1 fingerprint from the server's certificate. + Initiates an HTTPS connection to the server if need and extracts the SHA1 fingerprint from the server's certificate. :return: string with SHA1 fingerprint (all uppercase letters) """ - context = ssl.create_default_context() - context.check_hostname = False - context.verify_mode = ssl.CERT_NONE - - if self.proxy_host and self.proxy_port: - conn = HTTPSConnection(self.proxy_host, self.proxy_port, context=context) - conn.set_tunnel(self.server, self.get_port()) - else: - conn = HTTPSConnection(self.server, self.get_port(), context=context) - + conn = self.get_https_connection(set_fingerprint=False, set_debug_level=False) fingerprint_hash = conn.get_fingerprint_hash() - conn.close() + if not self.single_conn: + conn.close() return fingerprint_hash def __wait_for_task(self, task_id, timeout=-1): @@ -723,22 +709,50 @@ def read_fingerprint_from_file(server, filename="fingerprints.txt"): return json_dict[server] return "" + def create_https_connection(self, set_fingerprint, set_debug_level): + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + # create https connection + if self.proxy_host and self.proxy_port: + conn = HTTPSConnection(self.proxy_host, self.proxy_port, context=context) + conn.set_tunnel(self.server, self.get_port()) + else: + conn = HTTPSConnection(self.server, self.get_port(), context=context) + + # Set fingerprint + if set_fingerprint: + conn.fingerprint = self.fingerprint + + # Set debug level + if set_debug_level: + conn.set_debuglevel(self.http_debug_level) + conn.connect() + return conn + + def get_https_connection(self, set_fingerprint=True, set_debug_level=True): + if self.single_conn: + if self.conn is None: + self.conn = self.create_https_connection(set_fingerprint, set_debug_level) + return self.conn + return self.create_https_connection(set_fingerprint, set_debug_level) + + def close_connection(self): + if self.conn: + self.conn.close() + class HTTPSConnection(http_client.HTTPSConnection): """ A class for making HTTPS connections that overrides the default HTTPS checks (e.g. not accepting self-signed-certificates) and replaces them with a server fingerprint check. """ - def connect(self): http_client.HTTPConnection.connect(self) self.sock = ssl.wrap_socket(self.sock, self.key_file, self.cert_file, cert_reqs=ssl.CERT_NONE) def get_fingerprint_hash(self): - try: - http_client.HTTPConnection.connect(self) - self.sock = ssl.wrap_socket(self.sock, self.key_file, self.cert_file, cert_reqs=ssl.CERT_NONE) - except Exception: - return "" + if self.sock is None: + self.connect() fingerprint = hashlib.new("SHA1", self.sock.getpeercert(True)).hexdigest() return fingerprint.upper()