Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use getClientAddress instead of getClientIP. #12599

Merged
merged 4 commits into from
May 4, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12599.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `getClientAddress` instead of the deprecated `getClientIP`.
4 changes: 2 additions & 2 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def _wrapped_get_user_by_req(
Once get_user_by_req has set up the opentracing span, this does the actual work.
"""
try:
ip_addr = request.getClientIP()
ip_addr = request.getClientAddress().host
user_agent = get_request_user_agent(request)

access_token = self.get_access_token_from_request(request)
Expand Down Expand Up @@ -356,7 +356,7 @@ async def _get_appservice_user_id_and_device_id(
return None, None, None

if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientIP())
ip_address = IPAddress(request.getClientAddress().host)
if ip_address not in app_service.ip_range_whitelist:
return None, None, None

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ async def check_ui_auth(
await self.store.set_ui_auth_clientdict(sid, clientdict)

user_agent = get_request_user_agent(request)
clientip = request.getClientIP()
clientip = request.getClientAddress().host

await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def ratelimit_request_token_requests(
"""

await self._3pid_validation_ratelimiter_ip.ratelimit(
None, (medium, request.getClientIP())
None, (medium, request.getClientAddress().host)
)
await self._3pid_validation_ratelimiter_address.ratelimit(
None, (medium, address)
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ async def complete_sso_login_request(
auth_provider_id,
remote_user_id,
get_request_user_agent(request),
request.getClientIP(),
request.getClientAddress().host,
)
new_user = True
elif self._sso_update_profile_information:
Expand Down Expand Up @@ -928,7 +928,7 @@ async def register_sso_user(self, request: Request, session_id: str) -> None:
session.auth_provider_id,
session.remote_user_id,
get_request_user_agent(request),
request.getClientIP(),
request.getClientAddress().host,
)

logger.info(
Expand Down
6 changes: 3 additions & 3 deletions synapse/http/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def render(self, resrc: Resource) -> None:
request_id,
request=ContextRequest(
request_id=request_id,
ip_address=self.getClientIP(),
ip_address=self.getClientAddress().host,
site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point.
requester=None,
Expand Down Expand Up @@ -381,7 +381,7 @@ def _started_processing(self, servlet_name: str) -> None:

self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.getClientAddress().host,
self.synapse_site.site_tag,
self.get_method(),
self.get_redacted_uri(),
Expand Down Expand Up @@ -429,7 +429,7 @@ def _finished_processing(self) -> None:
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
self.getClientAddress().host,
self.synapse_site.site_tag,
requester,
processing_time,
Expand Down
2 changes: 1 addition & 1 deletion synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
tags.PEER_HOST_IPV6: request.getClientAddress().host,
}

request_name = request.request_metrics.name
Expand Down
8 changes: 5 additions & 3 deletions synapse/rest/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def on_POST(self, request: Request, stagetype: str) -> None:

try:
await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, request.getClientIP()
LoginType.RECAPTCHA, authdict, request.getClientAddress().host
)
except LoginError as e:
# Authentication failed, let user try again
Expand All @@ -132,7 +132,7 @@ async def on_POST(self, request: Request, stagetype: str) -> None:

try:
await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, request.getClientIP()
LoginType.TERMS, authdict, request.getClientAddress().host
)
except LoginError as e:
# Authentication failed, let user try again
Expand Down Expand Up @@ -161,7 +161,9 @@ async def on_POST(self, request: Request, stagetype: str) -> None:

try:
await self.auth_handler.add_oob_auth(
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
LoginType.REGISTRATION_TOKEN,
authdict,
request.getClientAddress().host,
)
except LoginError as e:
html = self.registration_token_template.render(
Expand Down
14 changes: 10 additions & 4 deletions synapse/rest/client/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:

if appservice.is_rate_limited():
await self._address_ratelimiter.ratelimit(
None, request.getClientIP()
None, request.getClientAddress().host
)

result = await self._do_appservice_login(
Expand All @@ -195,19 +195,25 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
)
else:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
Expand Down
6 changes: 3 additions & 3 deletions synapse/rest/client/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if self.inhibit_user_in_use_error:
return 200, {"available": True}

ip = request.getClientIP()
ip = request.getClientAddress().host
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred

Expand Down Expand Up @@ -394,7 +394,7 @@ def __init__(self, hs: "HomeServer"):
)

async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))

if not self.hs.config.registration.enable_registration:
raise SynapseError(
Expand Down Expand Up @@ -441,7 +441,7 @@ def __init__(self, hs: "HomeServer"):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)

client_addr = request.getClientIP()
client_addr = request.getClientAddress().host

await self.ratelimiter.ratelimit(None, client_addr, update=False)

Expand Down
18 changes: 9 additions & 9 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_get_user_by_req_appservice_valid_token(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
Expand All @@ -124,7 +124,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
request.getClientAddress.return_value.host = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
Expand All @@ -143,7 +143,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.getClientAddress.return_value.host = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
f = self.get_failure(
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
Expand All @@ -209,7 +209,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
Expand All @@ -236,7 +236,7 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
self.store.get_device = simple_async_mock({"hidden": False})

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
Expand Down Expand Up @@ -268,7 +268,7 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
self.store.get_device = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
Expand All @@ -288,7 +288,7 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
)
self.store.insert_client_ip = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request))
Expand All @@ -305,7 +305,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self):
)
self.store.insert_client_ip = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request))
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _mock_request():
mock = Mock(
spec=[
"finish",
"getClientIP",
"getClientAddress",
"getHeader",
"setHeader",
"setResponseCode",
Expand Down
4 changes: 2 additions & 2 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,7 @@ def _build_callback_request(
"getCookie",
"cookies",
"requestHeaders",
"getClientIP",
"getClientAddress",
"getHeader",
]
)
Expand All @@ -1310,5 +1310,5 @@ def _build_callback_request(
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.getClientAddress.return_value.host = ip_address
return request
2 changes: 1 addition & 1 deletion tests/handlers/test_saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def _mock_request():
mock = Mock(
spec=[
"finish",
"getClientIP",
"getClientAddress",
"getHeader",
"setHeader",
"setResponseCode",
Expand Down
16 changes: 10 additions & 6 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,12 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
self.assertEqual(port, 8765)

# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
client_protocol = client_factory.buildProtocol(client_address)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious why the change from calling buildProtocol with client_address as the parameter vs None?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The address is supposed to be passed, None is not a legal value there: https://twistedmatrix.com/documents/current/api/twisted.internet.interfaces.IProtocolFactory.html#buildProtocol

Although now that I check it is supposed to be a tuple version, not an IAddress instance. 😢

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was it just not working then, or did it work but you're changing it to be more consistent with the proper usage?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was it just not working then, or did it work but you're changing it to be more consistent with the proper usage?

I think (currently) the value never gets used, I was hoping this would fix the bug, but the "real" fix was adding it to the transport. It seemed best to leave it as it was more correct to pass a value here than None.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool thanks for answering!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM assuming tests pass

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool thanks for answering!

No problem!

I think the tl;dr is "It was wrong, but fixing it doesn't fix it. Let's include the fix anyway" 😆


# Set up the server side protocol
channel = self.site.buildProtocol(None)
server_address = IPv4Address("TCP", host, port)
channel = self.site.buildProtocol(server_address)

# hook into the channel's request factory so that we can keep a record
# of the requests
Expand All @@ -173,12 +175,12 @@ def request_factory(*args, **kwargs):

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
channel, self.reactor, client_protocol, server_address, client_address
)
client_protocol.makeConnection(client_to_server_transport)

server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
client_protocol, self.reactor, channel, client_address, server_address
)
channel.makeConnection(server_to_client_transport)

Expand Down Expand Up @@ -406,19 +408,21 @@ def _handle_http_replication_attempt(self, hs, repl_port):
self.assertEqual(port, repl_port)

# Set up client side protocol
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
client_protocol = client_factory.buildProtocol(None)

# Set up the server side protocol
server_address = IPv4Address("TCP", host, port)
channel = self._hs_to_site[hs].buildProtocol(None)

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
channel, self.reactor, client_protocol, server_address, client_address
)
client_protocol.makeConnection(client_to_server_transport)

server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
client_protocol, self.reactor, channel, client_address, server_address
)
channel.makeConnection(server_to_client_transport)

Expand Down
13 changes: 8 additions & 5 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def requestDone(self, _self):
self.resource_usage = _self.logcontext.get_resource_usage()

def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# We give an address so that getClientAddress/getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address("TCP", self._ip, 3423)

Expand Down Expand Up @@ -562,7 +562,10 @@ class FakeTransport:
"""

_peer_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returend by getPeer"""
"""The value to be returned by getPeer"""

_host_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returned by getHost"""

disconnecting = False
disconnected = False
Expand All @@ -571,11 +574,11 @@ class FakeTransport:
producer = attr.ib(default=None)
autoflush = attr.ib(default=True)

def getPeer(self):
def getPeer(self) -> Optional[IAddress]:
return self._peer_address

def getHost(self):
return None
def getHost(self) -> Optional[IAddress]:
return self._host_address

def loseConnection(self, reason=None):
if not self.disconnecting:
Expand Down