Skip to content

Commit

Permalink
[🚀 Feature] [py]: Better compatibility with Appium-python (#14587)
Browse files Browse the repository at this point in the history
* [py] override default locator converter for python

* Support registering custom finders in py

* Support registering extra HTTP commands and methods in python

* Support overriding User-Agent in python

* Support registering extra headers

* [py] Support ignore certificates

* Support using custom element classes

* tests for custom element test

* address review comments

* close parenthesis

Co-authored-by: Kazuaki Matsuo <[email protected]>

* pass `init_args_for_pool_manager` in constructor

* use existing driver fixture in tests

* convert `add_command` to instance method

---------

Co-authored-by: Sri Harsha <[email protected]>
Co-authored-by: Kazuaki Matsuo <[email protected]>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent 6ea5651 commit 1345193
Show file tree
Hide file tree
Showing 9 changed files with 280 additions and 49 deletions.
16 changes: 16 additions & 0 deletions py/selenium/webdriver/common/by.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
# under the License.
"""The By implementation."""

from typing import Dict
from typing import Literal
from typing import Optional


class By:
Expand All @@ -31,5 +33,19 @@ class By:
CLASS_NAME = "class name"
CSS_SELECTOR = "css selector"

_custom_finders: Dict[str, str] = {}

@classmethod
def register_custom_finder(cls, name: str, strategy: str) -> None:
cls._custom_finders[name] = strategy

@classmethod
def get_finder(cls, name: str) -> Optional[str]:
return cls._custom_finders.get(name) or getattr(cls, name.upper(), None)

@classmethod
def clear_custom_finders(cls) -> None:
cls._custom_finders.clear()


ByType = Literal["id", "xpath", "link text", "partial link text", "name", "tag name", "class name", "css selector"]
28 changes: 28 additions & 0 deletions py/selenium/webdriver/remote/locator_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


class LocatorConverter:
def convert(self, by, value):
# Default conversion logic
if by == "id":
return "css selector", f'[id="{value}"]'
elif by == "class name":
return "css selector", f".{value}"
elif by == "name":
return "css selector", f'[name="{value}"]'
return by, value
47 changes: 39 additions & 8 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ class RemoteConnection:
)
_ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where()

system = platform.system().lower()
if system == "darwin":
system = "mac"

# Class variables for headers
extra_headers = None
user_agent = f"selenium/{__version__} (python {system})"

@classmethod
def get_timeout(cls):
""":Returns:
Expand Down Expand Up @@ -196,14 +204,10 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
- keep_alive (Boolean) - Is this a keep-alive connection (default: False)
"""

system = platform.system().lower()
if system == "darwin":
system = "mac"

headers = {
"Accept": "application/json",
"Content-Type": "application/json;charset=UTF-8",
"User-Agent": f"selenium/{__version__} (python {system})",
"User-Agent": cls.user_agent,
}

if parsed_url.username:
Expand All @@ -213,6 +217,9 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
if keep_alive:
headers.update({"Connection": "keep-alive"})

if cls.extra_headers:
headers.update(cls.extra_headers)

return headers

def _get_proxy_url(self):
Expand All @@ -236,7 +243,12 @@ def _separate_http_proxy_auth(self):

def _get_connection_manager(self):
pool_manager_init_args = {"timeout": self.get_timeout()}
if self._ca_certs:
pool_manager_init_args.update(self._init_args_for_pool_manager.get("init_args_for_pool_manager", {}))

if self._ignore_certificates:
pool_manager_init_args["cert_reqs"] = "CERT_NONE"
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
elif self._ca_certs:
pool_manager_init_args["cert_reqs"] = "CERT_REQUIRED"
pool_manager_init_args["ca_certs"] = self._ca_certs

Expand All @@ -252,9 +264,18 @@ def _get_connection_manager(self):

return urllib3.PoolManager(**pool_manager_init_args)

def __init__(self, remote_server_addr: str, keep_alive: bool = False, ignore_proxy: bool = False):
def __init__(
self,
remote_server_addr: str,
keep_alive: bool = False,
ignore_proxy: bool = False,
ignore_certificates: bool = False,
init_args_for_pool_manager: dict = None,
):
self.keep_alive = keep_alive
self._url = remote_server_addr
self._ignore_certificates = ignore_certificates
self._init_args_for_pool_manager = init_args_for_pool_manager or {}

# Env var NO_PROXY will override this part of the code
_no_proxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY"))
Expand All @@ -280,6 +301,16 @@ def __init__(self, remote_server_addr: str, keep_alive: bool = False, ignore_pro
self._conn = self._get_connection_manager()
self._commands = remote_commands

extra_commands = {}

def add_command(self, name, method, url):
"""Register a new command."""
self._commands[name] = (method, url)

def get_command(self, name: str):
"""Retrieve a command if it exists."""
return self._commands.get(name)

def execute(self, command, params):
"""Send a command to the remote server.
Expand All @@ -291,7 +322,7 @@ def execute(self, command, params):
- params - A dictionary of named parameters to send with the command as
its JSON payload.
"""
command_info = self._commands[command]
command_info = self._commands.get(command) or self.extra_commands.get(command)
assert command_info is not None, f"Unrecognised command {command}"
path_string = command_info[1]
path = string.Template(path_string).substitute(params)
Expand Down
31 changes: 11 additions & 20 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from .errorhandler import ErrorHandler
from .file_detector import FileDetector
from .file_detector import LocalFileDetector
from .locator_converter import LocatorConverter
from .mobile import Mobile
from .remote_connection import RemoteConnection
from .script_key import ScriptKey
Expand Down Expand Up @@ -171,6 +172,8 @@ def __init__(
keep_alive: bool = True,
file_detector: Optional[FileDetector] = None,
options: Optional[Union[BaseOptions, List[BaseOptions]]] = None,
locator_converter: Optional[LocatorConverter] = None,
web_element_cls: Optional[type] = None,
) -> None:
"""Create a new driver that will issue commands using the wire
protocol.
Expand All @@ -183,6 +186,8 @@ def __init__(
- file_detector - Pass custom file detector object during instantiation. If None,
then default LocalFileDetector() will be used.
- options - instance of a driver options.Options class
- locator_converter - Custom locator converter to use. Defaults to None.
- web_element_cls - Custom class to use for web elements. Defaults to WebElement.
"""

if isinstance(options, list):
Expand All @@ -207,6 +212,8 @@ def __init__(
self._switch_to = SwitchTo(self)
self._mobile = Mobile(self)
self.file_detector = file_detector or LocalFileDetector()
self.locator_converter = locator_converter or LocatorConverter()
self._web_element_cls = web_element_cls or self._web_element_cls
self._authenticator_id = None
self.start_client()
self.start_session(capabilities)
Expand Down Expand Up @@ -729,22 +736,14 @@ def find_element(self, by=By.ID, value: Optional[str] = None) -> WebElement:
:rtype: WebElement
"""
by, value = self.locator_converter.convert(by, value)

if isinstance(by, RelativeBy):
elements = self.find_elements(by=by, value=value)
if not elements:
raise NoSuchElementException(f"Cannot locate relative element with: {by.root}")
return elements[0]

if by == By.ID:
by = By.CSS_SELECTOR
value = f'[id="{value}"]'
elif by == By.CLASS_NAME:
by = By.CSS_SELECTOR
value = f".{value}"
elif by == By.NAME:
by = By.CSS_SELECTOR
value = f'[name="{value}"]'

return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"]

def find_elements(self, by=By.ID, value: Optional[str] = None) -> List[WebElement]:
Expand All @@ -757,22 +756,14 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> List[WebElemen
:rtype: list of WebElement
"""
by, value = self.locator_converter.convert(by, value)

if isinstance(by, RelativeBy):
_pkg = ".".join(__name__.split(".")[:-1])
raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8")
find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);"
return self.execute_script(find_element_js, by.to_dict())

if by == By.ID:
by = By.CSS_SELECTOR
value = f'[id="{value}"]'
elif by == By.CLASS_NAME:
by = By.CSS_SELECTOR
value = f".{value}"
elif by == By.NAME:
by = By.CSS_SELECTOR
value = f'[name="{value}"]'

# Return empty list if driver returns null
# See https://github.com/SeleniumHQ/selenium/issues/4555
return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or []
Expand Down
22 changes: 2 additions & 20 deletions py/selenium/webdriver/remote/webelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,16 +404,7 @@ def find_element(self, by=By.ID, value=None) -> WebElement:
:rtype: WebElement
"""
if by == By.ID:
by = By.CSS_SELECTOR
value = f'[id="{value}"]'
elif by == By.CLASS_NAME:
by = By.CSS_SELECTOR
value = f".{value}"
elif by == By.NAME:
by = By.CSS_SELECTOR
value = f'[name="{value}"]'

by, value = self._parent.locator_converter.convert(by, value)
return self._execute(Command.FIND_CHILD_ELEMENT, {"using": by, "value": value})["value"]

def find_elements(self, by=By.ID, value=None) -> List[WebElement]:
Expand All @@ -426,16 +417,7 @@ def find_elements(self, by=By.ID, value=None) -> List[WebElement]:
:rtype: list of WebElement
"""
if by == By.ID:
by = By.CSS_SELECTOR
value = f'[id="{value}"]'
elif by == By.CLASS_NAME:
by = By.CSS_SELECTOR
value = f".{value}"
elif by == By.NAME:
by = By.CSS_SELECTOR
value = f'[name="{value}"]'

by, value = self._parent.locator_converter.convert(by, value)
return self._execute(Command.FIND_CHILD_ELEMENTS, {"using": by, "value": value})["value"]

def __hash__(self) -> int:
Expand Down
18 changes: 18 additions & 0 deletions py/test/selenium/webdriver/common/driver_element_finding_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,21 @@ def test_should_not_be_able_to_find_an_element_on_a_blank_page(driver, pages):
driver.get("about:blank")
with pytest.raises(NoSuchElementException):
driver.find_element(By.TAG_NAME, "a")


# custom finders tests


def test_register_and_get_custom_finder():
By.register_custom_finder("custom", "custom strategy")
assert By.get_finder("custom") == "custom strategy"


def test_get_nonexistent_finder():
assert By.get_finder("nonexistent") is None


def test_clear_custom_finders():
By.register_custom_finder("custom", "custom strategy")
By.clear_custom_finders()
assert By.get_finder("custom") is None
50 changes: 50 additions & 0 deletions py/test/selenium/webdriver/remote/custom_element_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webelement import WebElement


# Custom element class
class MyCustomElement(WebElement):
def custom_method(self):
return "Custom element method"


def test_find_element_with_custom_class(driver, pages):
"""Test to ensure custom element class is used for a single element."""
driver._web_element_cls = MyCustomElement
pages.load("simpleTest.html")
element = driver.find_element(By.TAG_NAME, "body")
assert isinstance(element, MyCustomElement)
assert element.custom_method() == "Custom element method"


def test_find_elements_with_custom_class(driver, pages):
"""Test to ensure custom element class is used for multiple elements."""
driver._web_element_cls = MyCustomElement
pages.load("simpleTest.html")
elements = driver.find_elements(By.TAG_NAME, "div")
assert all(isinstance(el, MyCustomElement) for el in elements)
assert all(el.custom_method() == "Custom element method" for el in elements)


def test_default_element_class(driver, pages):
"""Test to ensure default WebElement class is used."""
pages.load("simpleTest.html")
element = driver.find_element(By.TAG_NAME, "body")
assert isinstance(element, WebElement)
assert not hasattr(element, "custom_method")
40 changes: 40 additions & 0 deletions py/test/selenium/webdriver/remote/remote_custom_locator_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.remote.locator_converter import LocatorConverter


class CustomLocatorConverter(LocatorConverter):
def convert(self, by, value):
# Custom conversion logic
if by == "custom":
return "css selector", f'[custom-attr="{value}"]'
return super().convert(by, value)


def test_find_element_with_custom_locator(driver):
driver.get("data:text/html,<div custom-attr='example'>Test</div>")
element = driver.find_element("custom", "example")
assert element is not None
assert element.text == "Test"


def test_find_elements_with_custom_locator(driver):
driver.get("data:text/html,<div custom-attr='example'>Test1</div><div custom-attr='example'>Test2</div>")
elements = driver.find_elements("custom", "example")
assert len(elements) == 2
assert elements[0].text == "Test1"
assert elements[1].text == "Test2"
Loading

0 comments on commit 1345193

Please sign in to comment.