diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a3b548..d30e842 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,13 @@ repos: additional_dependencies: - "flake8-bugbear==24.8.19" + - repo: "https://github.com/pre-commit/mirrors-mypy" + rev: "v1.19.1" + hooks: + - id: "mypy" + additional_dependencies: + - "flask>=3.0" + - repo: "https://github.com/python-jsonschema/check-jsonschema" rev: "0.36.0" hooks: diff --git a/flask_compress/__init__.py b/flask_compress/__init__.py index 38109c7..2b424e2 100644 --- a/flask_compress/__init__.py +++ b/flask_compress/__init__.py @@ -1,4 +1,4 @@ -from .flask_compress import Compress, DictCache +from .flask_compress import CacheBackend, Compress, DictCache # _version.py is generated by setuptools_scm when building the package. # It is not version-controlled, so if it is missing, this likely means that @@ -9,4 +9,4 @@ __version__ = "0" -__all__ = ("Compress", "DictCache") +__all__ = ("CacheBackend", "Compress", "DictCache") diff --git a/flask_compress/flask_compress.py b/flask_compress/flask_compress.py index b87234f..5808921 100644 --- a/flask_compress/flask_compress.py +++ b/flask_compress/flask_compress.py @@ -2,34 +2,44 @@ # Copyright (c) 2013-2017 William Fagan # License: The MIT License (MIT) +from __future__ import annotations + import functools from collections import defaultdict +from collections.abc import Callable, Iterator from functools import lru_cache +from typing import Any, Protocol try: import brotlicffi as brotli except ImportError: import brotli -from flask import after_this_request, current_app, request, stream_with_context +from flask import Flask, after_this_request, current_app, request, stream_with_context +from flask.wrappers import Response from .compat import compression +class CacheBackend(Protocol): + def get(self, key: str) -> bytes | None: ... + def set(self, key: str, value: bytes) -> bool | None: ... + + class DictCache: - def __init__(self): - self.data = {} + def __init__(self) -> None: + self.data: dict[str, bytes] = {} - def get(self, key): + def get(self, key: str) -> bytes | None: return self.data.get(key) - def set(self, key, value): + def set(self, key: str, value: bytes) -> None: self.data[key] = value @lru_cache(maxsize=128) -def _choose_algorithm(algorithms, accept_encoding): +def _choose_algorithm(algorithms: tuple[str, ...], accept_encoding: str) -> str | None: """ Determine which compression algorithm we're going to use based on the client request. The `Accept-Encoding` header may list one or more desired @@ -46,7 +56,7 @@ def _choose_algorithm(algorithms, accept_encoding): fallback_to_any = False # Map quality factors to requested algorithm names. - algos_by_quality = defaultdict(set) + algos_by_quality: defaultdict[float, set[str | None]] = defaultdict(set) # Set of supported algorithms server_algos_set = set(algorithms) @@ -96,7 +106,7 @@ def _choose_algorithm(algorithms, accept_encoding): return None -def _format(algo): +def _format(algo: str | list[str]) -> tuple[str, ...]: """Format the algorithm configuration into a tuple of strings. >>> _format("gzip, deflate, br") @@ -122,7 +132,14 @@ class Compress: :type app: :class:`flask.Flask` or None """ - def __init__(self, app=None): + cache: CacheBackend | None + cache_key: Callable[..., str] | None + compress_mimetypes_set: set[str] + enabled_algorithms: tuple[str, ...] + streaming_algorithms: tuple[str, ...] + streaming_endpoint_with_conditional: set[str] + + def __init__(self, app: Flask | None = None) -> None: """ An alternative way to pass your :class:`flask.Flask` application object to Flask-Compress. :meth:`init_app` also takes care of some @@ -134,7 +151,7 @@ def __init__(self, app=None): if app is not None: self.init_app(app) - def init_app(self, app): + def init_app(self, app: Flask) -> None: defaults = [ ( "COMPRESS_MIMETYPES", @@ -202,7 +219,7 @@ def init_app(self, app): if app.config["COMPRESS_REGISTER"] and app.config["COMPRESS_MIMETYPES"]: app.after_request(self.after_request) - def after_request(self, response): + def after_request(self, response: Response) -> Response: app = self.app or current_app vary = response.headers.get("Vary") @@ -247,6 +264,7 @@ def after_request(self, response): response.headers.pop("Content-Length", None) else: if self.cache is not None: + assert self.cache_key is not None key = f"{chosen_algorithm};{self.cache_key(request)}" compressed_content = self.cache.get(key) if compressed_content is None: @@ -265,7 +283,7 @@ def after_request(self, response): etag, is_weak = response.get_etag() if etag and not is_weak: - response.set_etag(f"{etag}:{chosen_algorithm}", weak=is_weak) + response.set_etag(f"{etag}:{chosen_algorithm}", weak=False) if ( app.config["COMPRESS_EVALUATE_CONDITIONAL_REQUEST"] @@ -276,12 +294,12 @@ def after_request(self, response): return response - def compressed(self): - def decorator(f): + def compressed(self) -> Callable[..., Callable[..., Any]]: + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(f) - def decorated_function(*args, **kwargs): + def decorated_function(*args: Any, **kwargs: Any) -> Any: @after_this_request - def compressor(response): + def compressor(response: Response) -> Response: return self.after_request(response) return f(*args, **kwargs) @@ -291,18 +309,24 @@ def compressor(response): return decorator -def _compress_data(app, data, algorithm): +def _compress_data(app: Flask, data: bytes, algorithm: str) -> bytes: if algorithm == "zstd": - return compression.zstd.compress(data, app.config["COMPRESS_ZSTD_LEVEL"]) + return compression.zstd.compress( # type: ignore[no-any-return] + data, app.config["COMPRESS_ZSTD_LEVEL"] + ) if algorithm == "gzip": - return compression.gzip.compress(data, app.config["COMPRESS_LEVEL"]) + return compression.gzip.compress( # type: ignore[no-any-return] + data, app.config["COMPRESS_LEVEL"] + ) if algorithm == "deflate": - return compression.zlib.compress(data, app.config["COMPRESS_DEFLATE_LEVEL"]) + return compression.zlib.compress( # type: ignore[no-any-return] + data, app.config["COMPRESS_DEFLATE_LEVEL"] + ) if algorithm == "br": - return brotli.compress( + return brotli.compress( # type: ignore[no-any-return] data, mode=app.config["COMPRESS_BR_MODE"], quality=app.config["COMPRESS_BR_LEVEL"], @@ -313,21 +337,23 @@ def _compress_data(app, data, algorithm): raise ValueError(f"Unknown compression algorithm: {algorithm}") -def _uncompress_data(data, algorithm): +def _uncompress_data(data: bytes, algorithm: str) -> bytes: # This is used for tests purposes only. if algorithm == "zstd": - return compression.zstd.decompress(data) + return compression.zstd.decompress(data) # type: ignore[no-any-return] if algorithm == "gzip": - return compression.gzip.decompress(data) + return compression.gzip.decompress(data) # type: ignore[no-any-return] if algorithm == "deflate": - return compression.zlib.decompress(data) + return compression.zlib.decompress(data) # type: ignore[no-any-return] if algorithm == "br": - return brotli.decompress(data) + return brotli.decompress(data) # type: ignore[no-any-return] raise ValueError(f"Unknown compression algorithm: {algorithm}") -def _compress_chunks(app, chunks, algorithm): +def _compress_chunks( + app: Flask, chunks: Iterator[bytes], algorithm: str +) -> Iterator[bytes]: if algorithm == "zstd": level = app.config["COMPRESS_ZSTD_LEVEL"] compressor = compression.zstd.ZstdCompressor(level=level) diff --git a/flask_compress/py.typed b/flask_compress/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 0154742..dbd4455 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,3 +33,14 @@ fail_under = 94 [tool.coverage.html] directory = "htmlcov/" skip_covered = false + + +# type checking +# -------------- + +[tool.mypy] +strict = true + +[[tool.mypy.overrides]] +module = ["brotli", "brotlicffi", "backports.*", "flask_caching"] +ignore_missing_imports = true diff --git a/setup.py b/setup.py index 1944998..4bc7bd3 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ long_description=LONG_DESCRIPTION, long_description_content_type="text/markdown", packages=find_packages(exclude=["tests"]), + package_data={"flask_compress": ["py.typed"]}, include_package_data=True, platforms="any", python_requires=">=3.9", diff --git a/tests/test_flask_compress.py b/tests/test_flask_compress.py index b56d595..5da23da 100644 --- a/tests/test_flask_compress.py +++ b/tests/test_flask_compress.py @@ -2,15 +2,19 @@ import os import tempfile import unittest +from collections.abc import Iterator from flask import ( Flask, + Request, + Response, make_response, render_template, request, stream_with_context, ) from flask_caching import Cache +from werkzeug.test import TestResponse from flask_compress import Compress, DictCache from flask_compress.flask_compress import _choose_algorithm, _uncompress_data @@ -19,13 +23,13 @@ class DefaultsTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = Flask(__name__) self.app.testing = True Compress(self.app) - def test_mimetypes_default(self): + def test_mimetypes_default(self) -> None: """Tests COMPRESS_MIMETYPES default value is correctly set.""" defaults = [ "text/html", @@ -57,60 +61,60 @@ def test_mimetypes_default(self): ] self.assertEqual(self.app.config["COMPRESS_MIMETYPES"], defaults) - def test_level_default(self): + def test_level_default(self) -> None: """Tests COMPRESS_LEVEL default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_LEVEL"], 6) - def test_min_size_default(self): + def test_min_size_default(self) -> None: """Tests COMPRESS_MIN_SIZE default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_MIN_SIZE"], 500) - def test_algorithm_default(self): + def test_algorithm_default(self) -> None: """Tests COMPRESS_ALGORITHM default value is correctly set.""" self.assertEqual( self.app.config["COMPRESS_ALGORITHM"], ["zstd", "br", "gzip", "deflate"] ) - def test_algorithm_streaming(self): + def test_algorithm_streaming(self) -> None: """Tests COMPRESS_ALGORITHM_STREAMING default value is correctly set.""" self.assertEqual( self.app.config["COMPRESS_ALGORITHM_STREAMING"], ["zstd", "br", "deflate"] ) - def test_default_deflate_settings(self): + def test_default_deflate_settings(self) -> None: """Tests COMPRESS_DELATE_LEVEL default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_DEFLATE_LEVEL"], -1) - def test_mode_default(self): + def test_mode_default(self) -> None: """Tests COMPRESS_BR_MODE default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_BR_MODE"], 0) - def test_quality_level_default(self): + def test_quality_level_default(self) -> None: """Tests COMPRESS_BR_LEVEL default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_BR_LEVEL"], 4) - def test_window_size_default(self): + def test_window_size_default(self) -> None: """Tests COMPRESS_BR_WINDOW default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_BR_WINDOW"], 22) - def test_block_size_default(self): + def test_block_size_default(self) -> None: """Tests COMPRESS_BR_BLOCK default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_BR_BLOCK"], 0) - def test_stream(self): + def test_stream(self) -> None: """Tests COMPRESS_STREAMS default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_STREAMS"], True) - def test_quality_level_default_zstd(self): + def test_quality_level_default_zstd(self) -> None: """Tests COMPRESS_ZSTD_LEVEL default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_ZSTD_LEVEL"], 3) - def test_evaluate_conditional_request(self): + def test_evaluate_conditional_request(self) -> None: """Tests COMPRESS_EVALUATE_CONDITIONAL_REQUEST default value is correctly set.""" self.assertEqual(self.app.config["COMPRESS_EVALUATE_CONDITIONAL_REQUEST"], True) - def test_streaming_endpoint_conditional(self): + def test_streaming_endpoint_conditional(self) -> None: """Tests COMPRESS_STREAMING_ENDPOINT_CONDITIONAL default value is correctly set.""" self.assertEqual( @@ -119,20 +123,20 @@ def test_streaming_endpoint_conditional(self): class InitTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = Flask(__name__) self.app.testing = True - def test_constructor_init(self): + def test_constructor_init(self) -> None: Compress(self.app) - def test_delayed_init(self): + def test_delayed_init(self) -> None: compress = Compress() compress.init_app(self.app) class UrlTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = Flask(__name__) self.app.testing = True @@ -145,14 +149,14 @@ def setUp(self): Compress(self.app) @self.app.route("/small/") - def small(): + def small() -> str: return render_template("small.html") @self.app.route("/large/") - def large(): + def large() -> str: return render_template("large.html") - def test_compressed_content(self): + def test_compressed_content(self) -> None: client = self.app.test_client() with open(self.large_path, "rb") as f: original_data = f.read().rstrip() # flask strips trailing newline @@ -166,13 +170,13 @@ def test_compressed_content(self): self.assertGreater(self.large_size, len(response.data)) self.assertEqual(original_data, _uncompress_data(response.data, algorithm)) - def client_get(self, ufs): + def client_get(self, ufs: str) -> TestResponse: client = self.app.test_client() response = client.get(ufs, headers=[("Accept-Encoding", "gzip")]) self.assertEqual(response.status_code, 200) return response - def test_br_algorithm(self): + def test_br_algorithm(self) -> None: client = self.app.test_client() headers = [("Accept-Encoding", "br")] @@ -182,7 +186,7 @@ def test_br_algorithm(self): response = client.options("/large/", headers=headers) self.assertEqual(response.status_code, 200) - def test_zstd_algorithm(self): + def test_zstd_algorithm(self) -> None: client = self.app.test_client() headers = [("Accept-Encoding", "zstd")] @@ -192,7 +196,7 @@ def test_zstd_algorithm(self): response = client.options("/large/", headers=headers) self.assertEqual(response.status_code, 200) - def test_compress_min_size(self): + def test_compress_min_size(self) -> None: """Tests COMPRESS_MIN_SIZE correctly affects response data.""" response = self.client_get("/small/") self.assertEqual(self.small_size, len(response.data)) @@ -200,18 +204,18 @@ def test_compress_min_size(self): response = self.client_get("/large/") self.assertNotEqual(self.large_size, len(response.data)) - def test_mimetype_mismatch(self): + def test_mimetype_mismatch(self) -> None: """Tests if mimetype not in COMPRESS_MIMETYPES.""" response = self.client_get("/static/1.png") self.assertEqual(response.mimetype, "image/png") - def test_content_length_options(self): + def test_content_length_options(self) -> None: client = self.app.test_client() headers = [("Accept-Encoding", "gzip")] response = client.options("/small/", headers=headers) self.assertEqual(response.status_code, 200) - def test_gzip_compression_level(self): + def test_gzip_compression_level(self) -> None: """Tests COMPRESS_LEVEL correctly affects response data.""" self.app.config["COMPRESS_LEVEL"] = 1 client = self.app.test_client() @@ -225,7 +229,7 @@ def test_gzip_compression_level(self): self.assertNotEqual(response1_size, response6_size) - def test_br_compression_level(self): + def test_br_compression_level(self) -> None: """Tests that COMPRESS_BR_LEVEL correctly affects response data.""" self.app.config["COMPRESS_BR_LEVEL"] = 4 client = self.app.test_client() @@ -239,7 +243,7 @@ def test_br_compression_level(self): self.assertNotEqual(response4_size, response11_size) - def test_deflate_compression_level(self): + def test_deflate_compression_level(self) -> None: """Tests COMPRESS_DELATE_LEVEL correctly affects response data.""" self.app.config["COMPRESS_DEFLATE_LEVEL"] = -1 client = self.app.test_client() @@ -253,7 +257,7 @@ def test_deflate_compression_level(self): self.assertNotEqual(response_size, response1_size) - def test_zstd_compression_level(self): + def test_zstd_compression_level(self) -> None: """Tests that COMPRESS_ZSTD_LEVEL correctly affects response data.""" self.app.config["COMPRESS_ZSTD_LEVEL"] = 1 client = self.app.test_client() @@ -276,7 +280,7 @@ class CompressionAlgoTests(unittest.TestCase): supported by this extension. """ - def setUp(self): + def setUp(self) -> None: super().setUp() # Create the app here but don't call `Compress()` on it just yet; @@ -290,10 +294,10 @@ def setUp(self): self.small_size = os.path.getsize(small_path) - 1 @self.app.route("/small/") - def small(): + def small() -> str: return render_template("small.html") - def test_setting_compress_algorithm_simple_string(self): + def test_setting_compress_algorithm_simple_string(self) -> None: """Test that a single entry in `COMPRESS_ALGORITHM` still works. This is a backwards-compatibility test.""" @@ -301,19 +305,19 @@ def test_setting_compress_algorithm_simple_string(self): c = Compress(self.app) self.assertTupleEqual(c.enabled_algorithms, ("gzip",)) - def test_setting_compress_algorithm_cs_string(self): + def test_setting_compress_algorithm_cs_string(self) -> None: """Test that `COMPRESS_ALGORITHM` can be a comma-separated string""" self.app.config["COMPRESS_ALGORITHM"] = "gzip, br, zstd" c = Compress(self.app) self.assertTupleEqual(c.enabled_algorithms, ("gzip", "br", "zstd")) - def test_setting_compress_algorithm_list(self): + def test_setting_compress_algorithm_list(self) -> None: """Test that `COMPRESS_ALGORITHM` can be a list of strings""" self.app.config["COMPRESS_ALGORITHM"] = ["gzip", "br", "deflate"] c = Compress(self.app) self.assertTupleEqual(c.enabled_algorithms, ("gzip", "br", "deflate")) - def test_one_algo_supported(self): + def test_one_algo_supported(self) -> None: """Tests requesting a single supported compression algorithm""" accept_encoding = "gzip" self.app.config["COMPRESS_ALGORITHM"] = ["br", "gzip"] @@ -321,7 +325,7 @@ def test_one_algo_supported(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertEqual(chosen_algorithm, "gzip") - def test_one_algo_unsupported(self): + def test_one_algo_unsupported(self) -> None: """Tests requesting single unsupported compression algorithm""" accept_encoding = "some-alien-algorithm" self.app.config["COMPRESS_ALGORITHM"] = ["br", "gzip"] @@ -329,7 +333,7 @@ def test_one_algo_unsupported(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertIsNone(chosen_algorithm) - def test_multiple_algos_supported(self): + def test_multiple_algos_supported(self) -> None: """Tests requesting multiple supported compression algorithms""" accept_encoding = "br, gzip, zstd" self.app.config["COMPRESS_ALGORITHM"] = ["zstd", "br", "gzip"] @@ -338,7 +342,7 @@ def test_multiple_algos_supported(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertEqual(chosen_algorithm, "zstd") - def test_multiple_algos_unsupported(self): + def test_multiple_algos_unsupported(self) -> None: """Tests requesting multiple unsupported compression algorithms""" accept_encoding = "future-algo, alien-algo, forbidden-algo" self.app.config["COMPRESS_ALGORITHM"] = ["zstd", "br", "gzip"] @@ -346,7 +350,7 @@ def test_multiple_algos_unsupported(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertIsNone(chosen_algorithm) - def test_multiple_algos_with_wildcard(self): + def test_multiple_algos_with_wildcard(self) -> None: """Request multiple unsupported compression algorithms and a wildcard""" accept_encoding = "future-algo, alien-algo, forbidden-algo, *" self.app.config["COMPRESS_ALGORITHM"] = ["zstd", "br", "gzip"] @@ -355,7 +359,7 @@ def test_multiple_algos_with_wildcard(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertEqual(chosen_algorithm, "zstd") - def test_multiple_algos_with_different_quality(self): + def test_multiple_algos_with_different_quality(self) -> None: """Request multiple supported compression algorithms with different q-factors""" accept_encoding = "zstd;q=0.8, br;q=0.9, gzip;q=0.5" self.app.config["COMPRESS_ALGORITHM"] = ["zstd", "br", "gzip"] @@ -363,7 +367,7 @@ def test_multiple_algos_with_different_quality(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertEqual(chosen_algorithm, "br") - def test_multiple_algos_with_equal_quality(self): + def test_multiple_algos_with_equal_quality(self) -> None: """Request multiple supported compression algorithms with equal q-factors""" accept_encoding = "zstd;q=0.5, br;q=0.5, gzip;q=0.5" self.app.config["COMPRESS_ALGORITHM"] = ["gzip", "br", "zstd"] @@ -372,7 +376,7 @@ def test_multiple_algos_with_equal_quality(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertEqual(chosen_algorithm, "gzip") - def test_default_quality_is_1(self): + def test_default_quality_is_1(self) -> None: """Tests that when making mixed-quality requests, the default q-factor is 1.0""" accept_encoding = "deflate, br;q=0.999, gzip;q=0.5" self.app.config["COMPRESS_ALGORITHM"] = ["gzip", "br", "deflate"] @@ -380,7 +384,7 @@ def test_default_quality_is_1(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertEqual(chosen_algorithm, "deflate") - def test_default_wildcard_quality_is_0(self): + def test_default_wildcard_quality_is_0(self) -> None: """Tests that a wildcard has a default q-factor of 0.0""" accept_encoding = "br;q=0.001, *" self.app.config["COMPRESS_ALGORITHM"] = ["gzip", "br", "deflate"] @@ -388,7 +392,7 @@ def test_default_wildcard_quality_is_0(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertEqual(chosen_algorithm, "br") - def test_wildcard_quality(self): + def test_wildcard_quality(self) -> None: """Tests that a wildcard with q=0 is discarded""" accept_encoding = "*;q=0" self.app.config["COMPRESS_ALGORITHM"] = ["gzip", "br", "deflate"] @@ -396,7 +400,7 @@ def test_wildcard_quality(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertIsNone(chosen_algorithm) - def test_identity(self): + def test_identity(self) -> None: """Tests that identity is understood""" accept_encoding = "identity;q=1, br;q=0.5, *;q=0" self.app.config["COMPRESS_ALGORITHM"] = ["gzip", "br", "deflate"] @@ -404,7 +408,7 @@ def test_identity(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertIsNone(chosen_algorithm) - def test_chrome_ranged_requests(self): + def test_chrome_ranged_requests(self) -> None: """Tests that Chrome ranged requests behave as expected""" accept_encoding = "identity;q=1, *;q=0" self.app.config["COMPRESS_ALGORITHM"] = ["gzip", "br", "deflate"] @@ -412,7 +416,7 @@ def test_chrome_ranged_requests(self): chosen_algorithm = _choose_algorithm(c.enabled_algorithms, accept_encoding) self.assertIsNone(chosen_algorithm) - def test_content_encoding_is_correct(self): + def test_content_encoding_is_correct(self) -> None: """Test that the `Content-Encoding` header matches the compression algorithm""" self.app.config["COMPRESS_ALGORITHM"] = ["zstd", "br", "gzip", "deflate"] Compress(self.app) @@ -443,7 +447,7 @@ def test_content_encoding_is_correct(self): class CompressionPerViewTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = Flask(__name__) self.app.testing = True self.app.config["COMPRESS_REGISTER"] = False @@ -451,15 +455,15 @@ def setUp(self): compress.init_app(self.app) @self.app.route("/route1/") - def view_1(): + def view_1() -> str: return render_template("large.html") @self.app.route("/route2/") @compress.compressed() - def view_2(): + def view_2() -> str: return render_template("large.html") - def test_compression(self): + def test_compression(self) -> None: client = self.app.test_client() headers = [("Accept-Encoding", "deflate")] @@ -474,7 +478,7 @@ def test_compression(self): class StreamTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = Flask(__name__) self.app.testing = True @@ -489,14 +493,14 @@ def setUp(self): self.file_size = os.path.getsize(self.file_path) @self.app.route("/stream/large") - def stream(): - def _stream(): + def stream() -> Response: + def _stream() -> Iterator[str]: with open(self.file_path) as f: yield from f.readlines() return self.app.response_class(_stream(), mimetype="text/html") - def test_no_compression_stream(self): + def test_no_compression_stream(self) -> None: """Tests compression is skipped when COMPRESS_STREAMS is False""" Compress(self.app) self.app.config["COMPRESS_STREAMS"] = False @@ -509,7 +513,7 @@ def test_no_compression_stream(self): self.assertEqual(response.is_streamed, True) self.assertEqual(self.file_size, len(response.data)) - def test_compression_stream(self): + def test_compression_stream(self) -> None: Compress(self.app) client = self.app.test_client() with open(self.file_path, "rb") as f: @@ -534,7 +538,7 @@ def test_compression_stream(self): class StreamTestsWithETags(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = Flask(__name__, static_folder="web", static_url_path="/path") self.app.testing = True @@ -546,8 +550,8 @@ def setUp(self): Compress(self.app) @self.app.route("/stream/large") - def stream(): - def _stream(): + def stream() -> Response: + def _stream() -> Iterator[str]: with open(self.file_path) as f: yield from f.readlines() @@ -556,7 +560,7 @@ def _stream(): rv.set_etag("stream-etag", weak=False) return rv - def test_conditionals_are_skipped_on_streaming(self): + def test_conditionals_are_skipped_on_streaming(self) -> None: client = self.app.test_client() r1 = client.get("/stream/large", headers=[("Accept-Encoding", "br")]) @@ -579,7 +583,7 @@ def test_conditionals_are_skipped_on_streaming(self): self.assertEqual(r1.headers.get("Content-Encoding"), "br") r2.close() - def test_static_exception_to_conditionals(self): + def test_static_exception_to_conditionals(self) -> None: # Here we test that the static endpoint, which is using streaming responses, # still respects conditional requests even when streaming. # We use a custom static folder and static url path to show that the test @@ -596,6 +600,7 @@ def test_static_exception_to_conditionals(self): self.assertIn("Content-Encoding", r1.headers) self.assertEqual(r1.headers.get("Content-Encoding"), "br") self.assertIsNotNone(tag) + assert tag is not None self.assertEqual(tag[-3:], ":br") self.assertFalse(is_weak) r1.close() @@ -612,7 +617,7 @@ def test_static_exception_to_conditionals(self): class CachingCompressionTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # We keep track of the number of times the view is called self.view_calls = 0 self.tmpdir = tempfile.TemporaryDirectory() @@ -628,7 +633,7 @@ def setUp(self): ) cache.init_app(self.app) - def get_cache_key(request): + def get_cache_key(request: Request) -> str: return request.url compress = Compress() @@ -638,14 +643,14 @@ def get_cache_key(request): compress.cache_key = get_cache_key @self.app.route("/route/") - def view(): + def view() -> str: self.view_calls += 1 return render_template("large.html") - def tearDown(self): + def tearDown(self) -> None: self.tmpdir.cleanup() - def test_compression(self): + def test_compression(self) -> None: # Here we are testing cache pollution where the same query is cached # but with different compression algorithms. The cache key should include # the compression algorithm so that the cache is not polluted. @@ -669,13 +674,13 @@ def test_compression(self): class DictCacheTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # We keep track of the number of times the cache key function is called self.cache_key_calls = 0 self.app = Flask(__name__) self.app.testing = True - def get_cache_key(request): + def get_cache_key(request: Request) -> str: self.cache_key_calls += 1 return request.url @@ -686,10 +691,10 @@ def get_cache_key(request): compress.cache_key = get_cache_key @self.app.route("/route/") - def view(): + def view() -> str: return render_template("large.html") - def test_compression(self): + def test_compression(self) -> None: client = self.app.test_client() headers = [("Accept-Encoding", "deflate")] @@ -708,7 +713,7 @@ def test_compression(self): class ETagTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.app = Flask(__name__) self.app.testing = True self.app.config["COMPRESS_ALGORITHM"] = ["gzip"] @@ -717,30 +722,32 @@ def setUp(self): Compress(self.app) @self.app.route("/strong/") - def strong(): + def strong() -> Response: rv = make_response(render_template("large.html")) rv.set_etag("abc123", weak=False) - return rv.make_conditional(request) + rv.make_conditional(request) + return rv @self.app.route("/strong-compress-conditional/") - def strong_compress_conditional(): + def strong_compress_conditional() -> Response: rv = make_response(render_template("large.html")) rv.set_etag("abc123", weak=False) return rv @self.app.route("/weak/") - def weak(): + def weak() -> Response: rv = make_response(render_template("large.html")) rv.set_etag("abc123", weak=True) - return rv.make_conditional(request) + rv.make_conditional(request) + return rv @self.app.route("/weak-compress-conditional/") - def weak_compress_conditional(): + def weak_compress_conditional() -> Response: rv = make_response(render_template("large.html")) rv.set_etag("abc123", weak=True) return rv - def test_strong_etag_is_mutated_with_suffix_and_remains_strong(self): + def test_strong_etag_is_mutated_with_suffix_and_remains_strong(self) -> None: client = self.app.test_client() r = client.get("/strong/", headers=[("Accept-Encoding", "gzip")]) self.assertEqual(r.status_code, 200) @@ -751,7 +758,7 @@ def test_strong_etag_is_mutated_with_suffix_and_remains_strong(self): self.assertEqual(tag, "abc123:gzip") self.assertEqual(int(r.headers["Content-Length"]), len(r.data)) - def test_weak_etag_is_preserved(self): + def test_weak_etag_is_preserved(self) -> None: client = self.app.test_client() r = client.get("/weak/", headers=[("Accept-Encoding", "gzip")]) self.assertEqual(r.status_code, 200) @@ -762,7 +769,7 @@ def test_weak_etag_is_preserved(self): # No :gzip suffix when flag is False self.assertEqual(tag, "abc123") - def test_conditional_get_uses_strong_compressed_representation(self): + def test_conditional_get_uses_strong_compressed_representation(self) -> None: self.app.config["COMPRESS_EVALUATE_CONDITIONAL_REQUEST"] = False client = self.app.test_client() r1 = client.get("/strong/", headers=[("Accept-Encoding", "gzip")]) @@ -779,7 +786,7 @@ def test_conditional_get_uses_strong_compressed_representation(self): # We would expect a 304 but it does not because of etag mismatch self.assertEqual(r2.status_code, 200) - def test_conditional_get_uses_weak_compressed_representation(self): + def test_conditional_get_uses_weak_compressed_representation(self) -> None: self.app.config["COMPRESS_EVALUATE_CONDITIONAL_REQUEST"] = False client = self.app.test_client() r1 = client.get("/weak/", headers=[("Accept-Encoding", "gzip")]) @@ -798,7 +805,7 @@ def test_conditional_get_uses_weak_compressed_representation(self): def test_conditional_get_uses_strong_compressed_representation_evaluate_conditional( self, - ): + ) -> None: client = self.app.test_client() r1 = client.get( "/strong-compress-conditional/", headers=[("Accept-Encoding", "gzip")] @@ -818,7 +825,7 @@ def test_conditional_get_uses_strong_compressed_representation_evaluate_conditio def test_conditional_get_uses_weak_compressed_representation_evaluate_conditional( self, - ): + ) -> None: client = self.app.test_client() r1 = client.get( "/weak-compress-conditional/", headers=[("Accept-Encoding", "gzip")]