[networking] Add strict Request extension checking (#7604)

Authored by: coletdjnz
Co-authored-by: pukkandan <pukkandan.ytdlp@gmail.com>
This commit is contained in:
coletdjnz 2023-07-23 17:17:15 +12:00 committed by GitHub
parent 11de6fec9c
commit 86aea0d3a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 35 deletions

View file

@ -804,10 +804,10 @@ def test_httplib_validation_errors(self, handler):
assert not isinstance(exc_info.value, TransportError) assert not isinstance(exc_info.value, TransportError)
def run_validation(handler, fail, req, **handler_kwargs): def run_validation(handler, error, req, **handler_kwargs):
with handler(**handler_kwargs) as rh: with handler(**handler_kwargs) as rh:
if fail: if error:
with pytest.raises(UnsupportedRequest): with pytest.raises(error):
rh.validate(req) rh.validate(req)
else: else:
rh.validate(req) rh.validate(req)
@ -824,6 +824,9 @@ class NoCheckRH(ValidationRH):
_SUPPORTED_PROXY_SCHEMES = None _SUPPORTED_PROXY_SCHEMES = None
_SUPPORTED_URL_SCHEMES = None _SUPPORTED_URL_SCHEMES = None
def _check_extensions(self, extensions):
extensions.clear()
class HTTPSupportedRH(ValidationRH): class HTTPSupportedRH(ValidationRH):
_SUPPORTED_URL_SCHEMES = ('http',) _SUPPORTED_URL_SCHEMES = ('http',)
@ -834,26 +837,26 @@ class HTTPSupportedRH(ValidationRH):
('https', False, {}), ('https', False, {}),
('data', False, {}), ('data', False, {}),
('ftp', False, {}), ('ftp', False, {}),
('file', True, {}), ('file', UnsupportedRequest, {}),
('file', False, {'enable_file_urls': True}), ('file', False, {'enable_file_urls': True}),
]), ]),
(NoCheckRH, [('http', False, {})]), (NoCheckRH, [('http', False, {})]),
(ValidationRH, [('http', True, {})]) (ValidationRH, [('http', UnsupportedRequest, {})])
] ]
PROXY_SCHEME_TESTS = [ PROXY_SCHEME_TESTS = [
# scheme, expected to fail # scheme, expected to fail
('Urllib', [ ('Urllib', [
('http', False), ('http', False),
('https', True), ('https', UnsupportedRequest),
('socks4', False), ('socks4', False),
('socks4a', False), ('socks4a', False),
('socks5', False), ('socks5', False),
('socks5h', False), ('socks5h', False),
('socks', True), ('socks', UnsupportedRequest),
]), ]),
(NoCheckRH, [('http', False)]), (NoCheckRH, [('http', False)]),
(HTTPSupportedRH, [('http', True)]), (HTTPSupportedRH, [('http', UnsupportedRequest)]),
] ]
PROXY_KEY_TESTS = [ PROXY_KEY_TESTS = [
@ -863,8 +866,22 @@ class HTTPSupportedRH(ValidationRH):
('unrelated', False), ('unrelated', False),
]), ]),
(NoCheckRH, [('all', False)]), (NoCheckRH, [('all', False)]),
(HTTPSupportedRH, [('all', True)]), (HTTPSupportedRH, [('all', UnsupportedRequest)]),
(HTTPSupportedRH, [('no', True)]), (HTTPSupportedRH, [('no', UnsupportedRequest)]),
]
EXTENSION_TESTS = [
('Urllib', [
({'cookiejar': 'notacookiejar'}, AssertionError),
({'cookiejar': CookieJar()}, False),
({'timeout': 1}, False),
({'timeout': 'notatimeout'}, AssertionError),
({'unsupported': 'value'}, UnsupportedRequest),
]),
(NoCheckRH, [
({'cookiejar': 'notacookiejar'}, False),
({'somerandom': 'test'}, False), # but any extension is allowed through
]),
] ]
@pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [ @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
@ -907,15 +924,16 @@ def test_empty_proxy(self, handler):
@pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1']) @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1'])
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True) @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_missing_proxy_scheme(self, handler, proxy_url): def test_missing_proxy_scheme(self, handler, proxy_url):
run_validation(handler, True, Request('http://', proxies={'http': 'example.com'})) run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': 'example.com'}))
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True) @pytest.mark.parametrize('handler,extensions,fail', [
def test_cookiejar_extension(self, handler): (handler_tests[0], extensions, fail)
run_validation(handler, True, Request('http://', extensions={'cookiejar': 'notacookiejar'})) for handler_tests in EXTENSION_TESTS
for extensions, fail in handler_tests[1]
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True) ], indirect=['handler'])
def test_timeout_extension(self, handler): def test_extension(self, handler, extensions, fail):
run_validation(handler, True, Request('http://', extensions={'timeout': 'notavalidtimeout'})) run_validation(
handler, fail, Request('http://', extensions=extensions))
def test_invalid_request_type(self): def test_invalid_request_type(self):
rh = self.ValidationRH(logger=FakeLogger()) rh = self.ValidationRH(logger=FakeLogger())

View file

@ -385,6 +385,11 @@ def __init__(self, *, enable_file_urls: bool = False, **kwargs):
if self.enable_file_urls: if self.enable_file_urls:
self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file') self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file')
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
extensions.pop('cookiejar', None)
extensions.pop('timeout', None)
def _create_instance(self, proxies, cookiejar): def _create_instance(self, proxies, cookiejar):
opener = urllib.request.OpenerDirector() opener = urllib.request.OpenerDirector()
handlers = [ handlers = [

View file

@ -21,6 +21,7 @@
TransportError, TransportError,
UnsupportedRequest, UnsupportedRequest,
) )
from ..compat.types import NoneType
from ..utils import ( from ..utils import (
bug_reports_message, bug_reports_message,
classproperty, classproperty,
@ -147,6 +148,7 @@ class RequestHandler(abc.ABC):
a proxy url with an url scheme not in this list will raise an UnsupportedRequest. a proxy url with an url scheme not in this list will raise an UnsupportedRequest.
- `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum. - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
The above may be set to None to disable the checks. The above may be set to None to disable the checks.
Parameters: Parameters:
@ -169,9 +171,14 @@ class RequestHandler(abc.ABC):
Requests may have additional optional parameters defined as extensions. Requests may have additional optional parameters defined as extensions.
RequestHandler subclasses may choose to support custom extensions. RequestHandler subclasses may choose to support custom extensions.
If an extension is supported, subclasses should extend _check_extensions(extensions)
to pop and validate the extension.
- Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised.
The following extensions are defined for RequestHandler: The following extensions are defined for RequestHandler:
- `cookiejar`: Cookiejar to use for this request - `cookiejar`: Cookiejar to use for this request.
- `timeout`: socket timeout to use for this request - `timeout`: socket timeout to use for this request.
To enable these, add extensions.pop('<extension>', None) to _check_extensions
Apart from the url protocol, proxies dict may contain the following keys: Apart from the url protocol, proxies dict may contain the following keys:
- `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol. - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
@ -263,26 +270,19 @@ def _check_proxies(self, proxies):
if scheme not in self._SUPPORTED_PROXY_SCHEMES: if scheme not in self._SUPPORTED_PROXY_SCHEMES:
raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"') raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
def _check_cookiejar_extension(self, extensions):
if not extensions.get('cookiejar'):
return
if not isinstance(extensions['cookiejar'], CookieJar):
raise UnsupportedRequest('cookiejar is not a CookieJar')
def _check_timeout_extension(self, extensions):
if extensions.get('timeout') is None:
return
if not isinstance(extensions['timeout'], (float, int)):
raise UnsupportedRequest('timeout is not a float or int')
def _check_extensions(self, extensions): def _check_extensions(self, extensions):
self._check_cookiejar_extension(extensions) """Check extensions for unsupported extensions. Subclasses should extend this."""
self._check_timeout_extension(extensions) assert isinstance(extensions.get('cookiejar'), (CookieJar, NoneType))
assert isinstance(extensions.get('timeout'), (float, int, NoneType))
def _validate(self, request): def _validate(self, request):
self._check_url_scheme(request) self._check_url_scheme(request)
self._check_proxies(request.proxies or self.proxies) self._check_proxies(request.proxies or self.proxies)
self._check_extensions(request.extensions) extensions = request.extensions.copy()
self._check_extensions(extensions)
if extensions:
# TODO: add support for optional extensions
raise UnsupportedRequest(f'Unsupported extensions: {", ".join(extensions.keys())}')
@wrap_request_errors @wrap_request_errors
def validate(self, request: Request): def validate(self, request: Request):