Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[🚀 Feature] [py]: Better compatibility with Appium-python #14587

Merged
merged 18 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions py/selenium/webdriver/common/by.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
# under the License.
"""The By implementation."""

from typing import Dict
from typing import Literal
from typing import Optional


class By:
Expand All @@ -31,5 +33,19 @@ class By:
CLASS_NAME = "class name"
CSS_SELECTOR = "css selector"

_custom_finders: Dict[str, str] = {}

@classmethod
def register_custom_finder(cls, name: str, strategy: str) -> None:
cls._custom_finders[name] = strategy

@classmethod
def get_finder(cls, name: str) -> Optional[str]:
return cls._custom_finders.get(name) or getattr(cls, name.upper(), None)

@classmethod
def clear_custom_finders(cls) -> None:
cls._custom_finders.clear()


ByType = Literal["id", "xpath", "link text", "partial link text", "name", "tag name", "class name", "css selector"]
28 changes: 28 additions & 0 deletions py/selenium/webdriver/remote/locator_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


class LocatorConverter:
def convert(self, by, value):
# Default conversion logic
if by == "id":
return "css selector", f'[id="{value}"]'
elif by == "class name":
return "css selector", f".{value}"
elif by == "name":
return "css selector", f'[name="{value}"]'
return by, value
44 changes: 35 additions & 9 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ class RemoteConnection:
)
_ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where()

system = platform.system().lower()
if system == "darwin":
system = "mac"

# Class variables for headers
extra_headers = None
user_agent = f"selenium/{__version__} (python {system}"
navin772 marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def get_timeout(cls):
""":Returns:
Expand Down Expand Up @@ -196,14 +204,10 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
- keep_alive (Boolean) - Is this a keep-alive connection (default: False)
"""

system = platform.system().lower()
if system == "darwin":
system = "mac"

headers = {
"Accept": "application/json",
"Content-Type": "application/json;charset=UTF-8",
"User-Agent": f"selenium/{__version__} (python {system})",
"User-Agent": cls.user_agent,
}

if parsed_url.username:
Expand All @@ -213,6 +217,9 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
if keep_alive:
headers.update({"Connection": "keep-alive"})

if cls.extra_headers:
headers.update(cls.extra_headers)

return headers

def _get_proxy_url(self):
Expand All @@ -234,9 +241,14 @@ def _separate_http_proxy_auth(self):
proxy_without_auth = protocol + no_protocol[len(auth) + 1 :]
return proxy_without_auth, auth

def _get_connection_manager(self):
def _get_connection_manager(self, **pool_manager_kwargs):
pool_manager_init_args = {"timeout": self.get_timeout()}
if self._ca_certs:
pool_manager_init_args.update(pool_manager_kwargs.get("init_args_for_pool_manager", {}))

if self._ignore_certificates:
pool_manager_init_args["cert_reqs"] = "CERT_NONE"
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
elif self._ca_certs:
pool_manager_init_args["cert_reqs"] = "CERT_REQUIRED"
pool_manager_init_args["ca_certs"] = self._ca_certs

Expand All @@ -252,9 +264,16 @@ def _get_connection_manager(self):

return urllib3.PoolManager(**pool_manager_init_args)

def __init__(self, remote_server_addr: str, keep_alive: bool = False, ignore_proxy: bool = False):
def __init__(
self,
remote_server_addr: str,
keep_alive: bool = False,
ignore_proxy: bool = False,
ignore_certificates: bool = False,
navin772 marked this conversation as resolved.
Show resolved Hide resolved
navin772 marked this conversation as resolved.
Show resolved Hide resolved
):
self.keep_alive = keep_alive
self._url = remote_server_addr
self._ignore_certificates = ignore_certificates

# Env var NO_PROXY will override this part of the code
_no_proxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY"))
Expand All @@ -280,6 +299,13 @@ def __init__(self, remote_server_addr: str, keep_alive: bool = False, ignore_pro
self._conn = self._get_connection_manager()
self._commands = remote_commands

extra_commands = {}

@classmethod
def add_command(cls, name, method, url):
navin772 marked this conversation as resolved.
Show resolved Hide resolved
"""Register a new command."""
cls.extra_commands[name] = (method, url)

def execute(self, command, params):
"""Send a command to the remote server.

Expand All @@ -291,7 +317,7 @@ def execute(self, command, params):
- params - A dictionary of named parameters to send with the command as
its JSON payload.
"""
command_info = self._commands[command]
command_info = self._commands.get(command) or self.extra_commands.get(command)
assert command_info is not None, f"Unrecognised command {command}"
path_string = command_info[1]
path = string.Template(path_string).substitute(params)
Expand Down
31 changes: 11 additions & 20 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from .errorhandler import ErrorHandler
from .file_detector import FileDetector
from .file_detector import LocalFileDetector
from .locator_converter import LocatorConverter
from .mobile import Mobile
from .remote_connection import RemoteConnection
from .script_key import ScriptKey
Expand Down Expand Up @@ -171,6 +172,8 @@ def __init__(
keep_alive: bool = True,
file_detector: Optional[FileDetector] = None,
options: Optional[Union[BaseOptions, List[BaseOptions]]] = None,
locator_converter: Optional[LocatorConverter] = None,
navin772 marked this conversation as resolved.
Show resolved Hide resolved
web_element_cls: Optional[type] = None,
navin772 marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Create a new driver that will issue commands using the wire
protocol.
Expand All @@ -183,6 +186,8 @@ def __init__(
- file_detector - Pass custom file detector object during instantiation. If None,
then default LocalFileDetector() will be used.
- options - instance of a driver options.Options class
- locator_converter - Custom locator converter to use. Defaults to None.
- web_element_cls - Custom class to use for web elements. Defaults to WebElement.
"""

if isinstance(options, list):
Expand All @@ -207,6 +212,8 @@ def __init__(
self._switch_to = SwitchTo(self)
self._mobile = Mobile(self)
self.file_detector = file_detector or LocalFileDetector()
self.locator_converter = locator_converter or LocatorConverter()
self._web_element_cls = web_element_cls or self._web_element_cls
self._authenticator_id = None
self.start_client()
self.start_session(capabilities)
Expand Down Expand Up @@ -729,22 +736,14 @@ def find_element(self, by=By.ID, value: Optional[str] = None) -> WebElement:

:rtype: WebElement
"""
by, value = self.locator_converter.convert(by, value)

if isinstance(by, RelativeBy):
elements = self.find_elements(by=by, value=value)
if not elements:
raise NoSuchElementException(f"Cannot locate relative element with: {by.root}")
return elements[0]

if by == By.ID:
navin772 marked this conversation as resolved.
Show resolved Hide resolved
by = By.CSS_SELECTOR
value = f'[id="{value}"]'
elif by == By.CLASS_NAME:
by = By.CSS_SELECTOR
value = f".{value}"
elif by == By.NAME:
by = By.CSS_SELECTOR
value = f'[name="{value}"]'

return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"]

def find_elements(self, by=By.ID, value: Optional[str] = None) -> List[WebElement]:
Expand All @@ -757,22 +756,14 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> List[WebElemen

:rtype: list of WebElement
"""
by, value = self.locator_converter.convert(by, value)

if isinstance(by, RelativeBy):
_pkg = ".".join(__name__.split(".")[:-1])
raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8")
find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);"
return self.execute_script(find_element_js, by.to_dict())

if by == By.ID:
by = By.CSS_SELECTOR
value = f'[id="{value}"]'
elif by == By.CLASS_NAME:
by = By.CSS_SELECTOR
value = f".{value}"
elif by == By.NAME:
by = By.CSS_SELECTOR
value = f'[name="{value}"]'

# Return empty list if driver returns null
# See https://github.com/SeleniumHQ/selenium/issues/4555
return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or []
Expand Down
22 changes: 2 additions & 20 deletions py/selenium/webdriver/remote/webelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,16 +404,7 @@ def find_element(self, by=By.ID, value=None) -> WebElement:

:rtype: WebElement
"""
if by == By.ID:
by = By.CSS_SELECTOR
value = f'[id="{value}"]'
elif by == By.CLASS_NAME:
by = By.CSS_SELECTOR
value = f".{value}"
elif by == By.NAME:
by = By.CSS_SELECTOR
value = f'[name="{value}"]'

by, value = self._parent.locator_converter.convert(by, value)
return self._execute(Command.FIND_CHILD_ELEMENT, {"using": by, "value": value})["value"]

def find_elements(self, by=By.ID, value=None) -> List[WebElement]:
Expand All @@ -426,16 +417,7 @@ def find_elements(self, by=By.ID, value=None) -> List[WebElement]:

:rtype: list of WebElement
"""
if by == By.ID:
by = By.CSS_SELECTOR
value = f'[id="{value}"]'
elif by == By.CLASS_NAME:
by = By.CSS_SELECTOR
value = f".{value}"
elif by == By.NAME:
by = By.CSS_SELECTOR
value = f'[name="{value}"]'

by, value = self._parent.locator_converter.convert(by, value)
return self._execute(Command.FIND_CHILD_ELEMENTS, {"using": by, "value": value})["value"]

def __hash__(self) -> int:
Expand Down
18 changes: 18 additions & 0 deletions py/test/selenium/webdriver/common/driver_element_finding_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,21 @@ def test_should_not_be_able_to_find_an_element_on_a_blank_page(driver, pages):
driver.get("about:blank")
with pytest.raises(NoSuchElementException):
driver.find_element(By.TAG_NAME, "a")


# custom finders tests


def test_register_and_get_custom_finder():
By.register_custom_finder("custom", "custom strategy")
assert By.get_finder("custom") == "custom strategy"


def test_get_nonexistent_finder():
assert By.get_finder("nonexistent") is None


def test_clear_custom_finders():
By.register_custom_finder("custom", "custom strategy")
By.clear_custom_finders()
assert By.get_finder("custom") is None
61 changes: 61 additions & 0 deletions py/test/selenium/webdriver/remote/custom_element_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webelement import WebElement


# Custom element class
class MyCustomElement(WebElement):
def custom_method(self):
return "Custom element method"


@pytest.fixture
def driver():
options = webdriver.ChromeOptions()
navin772 marked this conversation as resolved.
Show resolved Hide resolved
driver = webdriver.Chrome(options=options)
yield driver
driver.quit()


def test_find_element_with_custom_class(driver, pages):
"""Test to ensure custom element class is used for a single element."""
driver._web_element_cls = MyCustomElement
pages.load("simpleTest.html")
element = driver.find_element(By.TAG_NAME, "body")
assert isinstance(element, MyCustomElement)
assert element.custom_method() == "Custom element method"


def test_find_elements_with_custom_class(driver, pages):
"""Test to ensure custom element class is used for multiple elements."""
driver._web_element_cls = MyCustomElement
pages.load("simpleTest.html")
elements = driver.find_elements(By.TAG_NAME, "div")
assert all(isinstance(el, MyCustomElement) for el in elements)
assert all(el.custom_method() == "Custom element method" for el in elements)


def test_default_element_class(driver, pages):
"""Test to ensure default WebElement class is used."""
pages.load("simpleTest.html")
element = driver.find_element(By.TAG_NAME, "body")
assert isinstance(element, WebElement)
assert not hasattr(element, "custom_method")
Loading
Loading