Skip to content

Commit

Permalink
[py] Set user_agent and extra_headers via ClientConfig (#14718)
Browse files Browse the repository at this point in the history

---------

Signed-off-by: Viet Nguyen Duc <[email protected]>
  • Loading branch information
VietND96 authored Nov 9, 2024
1 parent 3e1cb0c commit e202389
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 30 deletions.
55 changes: 44 additions & 11 deletions py/selenium/webdriver/remote/client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import base64
import os
import socket
from enum import Enum
from typing import Optional
from urllib import parse

Expand All @@ -26,6 +27,12 @@
from selenium.webdriver.common.proxy import ProxyType


class AuthType(Enum):
BASIC = "Basic"
BEARER = "Bearer"
X_API_KEY = "X-API-Key"


class ClientConfig:
def __init__(
self,
Expand All @@ -38,8 +45,10 @@ def __init__(
ca_certs: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
auth_type: Optional[str] = "Basic",
auth_type: Optional[AuthType] = AuthType.BASIC,
token: Optional[str] = None,
user_agent: Optional[str] = None,
extra_headers: Optional[dict] = None,
) -> None:
self.remote_server_addr = remote_server_addr
self.keep_alive = keep_alive
Expand All @@ -51,6 +60,8 @@ def __init__(
self.password = password
self.auth_type = auth_type
self.token = token
self.user_agent = user_agent
self.extra_headers = extra_headers

self.timeout = (
(
Expand Down Expand Up @@ -198,14 +209,17 @@ def password(self, value: str) -> None:
self._password = value

@property
def auth_type(self) -> str:
def auth_type(self) -> AuthType:
"""Returns the type of authentication to the remote server."""
return self._auth_type

@auth_type.setter
def auth_type(self, value: str) -> None:
def auth_type(self, value: AuthType) -> None:
"""Sets the type of authentication to the remote server if it is not
using basic with username and password."""
using basic with username and password.
:Args: value - AuthType enum value. For others, please use `extra_headers` instead
"""
self._auth_type = value

@property
Expand All @@ -219,6 +233,26 @@ def token(self, value: str) -> None:
auth_type is not basic."""
self._token = value

@property
def user_agent(self) -> str:
"""Returns user agent to be added to the request headers."""
return self._user_agent

@user_agent.setter
def user_agent(self, value: str) -> None:
"""Sets user agent to be added to the request headers."""
self._user_agent = value

@property
def extra_headers(self) -> dict:
"""Returns extra headers to be added to the request."""
return self._extra_headers

@extra_headers.setter
def extra_headers(self, value: dict) -> None:
"""Sets extra headers to be added to the request."""
self._extra_headers = value

def get_proxy_url(self) -> Optional[str]:
"""Returns the proxy URL to use for the connection."""
proxy_type = self.proxy.proxy_type
Expand Down Expand Up @@ -246,13 +280,12 @@ def get_proxy_url(self) -> Optional[str]:

def get_auth_header(self) -> Optional[dict]:
"""Returns the authorization to add to the request headers."""
auth_type = self.auth_type.lower()
if auth_type == "basic" and self.username and self.password:
if self.auth_type is AuthType.BASIC and self.username and self.password:
credentials = f"{self.username}:{self.password}"
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
return {"Authorization": f"Basic {encoded_credentials}"}
if auth_type == "bearer" and self.token:
return {"Authorization": f"Bearer {self.token}"}
if auth_type == "oauth" and self.token:
return {"Authorization": f"OAuth {self.token}"}
return {"Authorization": f"{AuthType.BASIC.value} {encoded_credentials}"}
if self.auth_type is AuthType.BEARER and self.token:
return {"Authorization": f"{AuthType.BEARER.value} {self.token}"}
if self.auth_type is AuthType.X_API_KEY and self.token:
return {f"{AuthType.X_API_KEY.value}": f"{self.token}"}
return None
20 changes: 12 additions & 8 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from base64 import b64encode
from typing import Optional
from urllib import parse
from urllib.parse import urlparse

import urllib3

Expand Down Expand Up @@ -243,6 +244,9 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
}

if parsed_url.username:
warnings.warn(
"Embedding username and password in URL could be insecure, use ClientConfig instead", stacklevel=2
)
base64string = b64encode(f"{parsed_url.username}:{parsed_url.password}".encode())
headers.update({"Authorization": f"Basic {base64string.decode()}"})

Expand All @@ -255,16 +259,14 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
return headers

def _identify_http_proxy_auth(self):
url = self._proxy_url
url = url[url.find(":") + 3 :]
return "@" in url and len(url[: url.find("@")]) > 0
parsed_url = urlparse(self._proxy_url)
if parsed_url.username and parsed_url.password:
return True

def _separate_http_proxy_auth(self):
url = self._proxy_url
protocol = url[: url.find(":") + 3]
no_protocol = url[len(protocol) :]
auth = no_protocol[: no_protocol.find("@")]
proxy_without_auth = protocol + no_protocol[len(auth) + 1 :]
parsed_url = urlparse(self._proxy_url)
proxy_without_auth = f"{parsed_url.scheme}://{parsed_url.hostname}:{parsed_url.port}"
auth = f"{parsed_url.username}:{parsed_url.password}"
return proxy_without_auth, auth

def _get_connection_manager(self):
Expand Down Expand Up @@ -312,6 +314,8 @@ def __init__(
RemoteConnection._timeout = self._client_config.timeout
RemoteConnection._ca_certs = self._client_config.ca_certs
RemoteConnection._client_config = self._client_config
RemoteConnection.extra_headers = self._client_config.extra_headers or RemoteConnection.extra_headers
RemoteConnection.user_agent = self._client_config.user_agent or RemoteConnection.user_agent

if remote_server_addr:
warnings.warn(
Expand Down
Loading

0 comments on commit e202389

Please sign in to comment.