diff --git a/changelog.d/3710.fixed.md b/changelog.d/3710.fixed.md new file mode 100644 index 000000000..b5b64d1d2 --- /dev/null +++ b/changelog.d/3710.fixed.md @@ -0,0 +1 @@ +Restored gzip compression for Cloud Run metadata responses. diff --git a/policyengine_api/asgi_factory.py b/policyengine_api/asgi_factory.py index e81588aaa..2456ccccd 100644 --- a/policyengine_api/asgi_factory.py +++ b/policyengine_api/asgi_factory.py @@ -8,11 +8,10 @@ from a2wsgi import WSGIMiddleware from fastapi import FastAPI, HTTPException -from pydantic import BaseModel - from policyengine_api.constants import VERSION from policyengine_api.migration_logging import log_migration_request - +from pydantic import BaseModel +from starlette.middleware.gzip import GZipMiddleware FASTAPI_NATIVE_LOGGED_PATHS = frozenset( { @@ -50,6 +49,7 @@ def create_asgi_app(wsgi_app) -> FastAPI: redoc_url=None, openapi_url=None, ) + app.add_middleware(GZipMiddleware, minimum_size=1000) @app.middleware("http") async def add_cors_for_native_routes(request, call_next): diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5e2a983ad..336fe07f0 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,7 +1,15 @@ -import pytest +import os import sqlite3 -from policyengine_api.data import PolicyEngineDatabase + +import pytest + +# The legacy SQL module creates its default database object at import time. +# Keep unit tests on SQLite until the SQL layer is broadly refactored in a +# later migration stage. +os.environ.setdefault("FLASK_DEBUG", "1") + from policyengine_api.constants import REPO +from policyengine_api.data import PolicyEngineDatabase class TestPolicyEngineDatabase(PolicyEngineDatabase): @@ -39,7 +47,7 @@ def initialize(self): / "data" / f"initialise{'_local' if self.local else ''}.sql" ) - with open(init_file, "r") as f: + with open(init_file) as f: full_query = f.read() # Split and execute the queries diff --git a/tests/unit/data/test_sqlalchemy_v2.py b/tests/unit/data/test_sqlalchemy_v2.py index 3882bb0f7..1f380b9df 100644 --- a/tests/unit/data/test_sqlalchemy_v2.py +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -10,10 +10,9 @@ (dict(row) and row["key"]). """ -import pytest +import policyengine_api.data.data as data_module import sqlalchemy - -from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase +from policyengine_api.data.data import PolicyEngineDatabase, _ResultProxy class TestSQLAlchemyVersion: @@ -180,3 +179,81 @@ def test_remote_delete(self): db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)]) result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)]) assert result.fetchone() is None + + +class TestRemotePoolSetup: + """Test remote pool setup without opening a real Cloud SQL connection.""" + + def _stub_remote_pool(self, monkeypatch): + fake_connection = object() + connector_calls = [] + engine_calls = [] + + class FakeConnector: + def connect(self, **kwargs): + connector_calls.append(kwargs) + return fake_connection + + def fake_create_engine(url, creator): + engine_calls.append((url, creator)) + assert creator() is fake_connection + return "fake-engine" + + fake_connector = FakeConnector() + monkeypatch.setattr(data_module, "Connector", lambda: fake_connector) + monkeypatch.setattr(data_module.sqlalchemy, "create_engine", fake_create_engine) + return fake_connector, connector_calls, engine_calls + + def test_create_pool_uses_remote_database_config(self, monkeypatch): + fake_connector, connector_calls, engine_calls = self._stub_remote_pool( + monkeypatch + ) + monkeypatch.setenv( + "POLICYENGINE_DB_INSTANCE_CONNECTION_NAME", + "test-project:us-central1:test-db", + ) + monkeypatch.setenv("POLICYENGINE_DB_USER", "test-user") + monkeypatch.setenv("POLICYENGINE_DB_NAME", "test-db") + monkeypatch.setenv("POLICYENGINE_DB_PASSWORD", "test-password") + + db = PolicyEngineDatabase.__new__(PolicyEngineDatabase) + db._create_pool() + + assert db.connector is fake_connector + assert db.pool == "fake-engine" + assert connector_calls == [ + { + "instance_connection_string": "test-project:us-central1:test-db", + "driver": "pymysql", + "db": "test-db", + "user": "test-user", + "password": "test-password", + } + ] + assert engine_calls[0][0] == "mysql+pymysql://" + + def test_create_pool_reads_dot_dbpw_file(self, monkeypatch, tmp_path): + _, connector_calls, _ = self._stub_remote_pool(monkeypatch) + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("POLICYENGINE_DB_PASSWORD", ".dbpw") + (tmp_path / ".dbpw").write_text("file-password\n") + + db = PolicyEngineDatabase.__new__(PolicyEngineDatabase) + db._create_pool() + + assert connector_calls[0]["password"] == "file-password" + + def test_remote_constructor_initializes_pool_without_local_database( + self, monkeypatch + ): + calls = [] + + def fake_create_pool(self): + calls.append(("pool", self.local)) + + monkeypatch.setattr(PolicyEngineDatabase, "_create_pool", fake_create_pool) + + db = PolicyEngineDatabase(local=False, initialize=False) + + assert db.local is False + assert calls == [("pool", False)] diff --git a/tests/unit/test_asgi_factory.py b/tests/unit/test_asgi_factory.py index 35503edbe..766339aa8 100644 --- a/tests/unit/test_asgi_factory.py +++ b/tests/unit/test_asgi_factory.py @@ -7,9 +7,8 @@ from fastapi.testclient import TestClient from flask import Flask, Response, jsonify, make_response, request from flask_cors import CORS -from starlette.responses import Response as ASGIResponse - from policyengine_api.asgi_factory import _add_vary_origin, create_asgi_app +from starlette.responses import Response as ASGIResponse def create_test_wsgi_app() -> Flask: @@ -23,6 +22,10 @@ def fallback(): response.set_cookie("fallback-cookie", "present") return response + @app.get("/large-fallback") + def large_fallback(): + return Response("x" * 2_000, status=200, mimetype="text/plain") + @app.get("/request-echo") def request_echo(): response = jsonify( @@ -113,6 +116,20 @@ def test_flask_fallback_preserves_status_body_headers_and_cookies(): assert response.headers["content-type"].startswith("text/html") +def test_large_flask_fallback_response_supports_http_gzip(): + client = TestClient(create_asgi_app(create_test_wsgi_app())) + + response = client.get( + "/large-fallback", + headers={"Accept-Encoding": "gzip"}, + ) + + assert response.status_code == 200 + assert response.headers["content-encoding"] == "gzip" + assert "Accept-Encoding" in response.headers["vary"] + assert response.text == "x" * 2_000 + + def test_request_headers_and_cookies_pass_through_to_flask_fallback(): client = TestClient(create_asgi_app(create_test_wsgi_app())) client.cookies.set("session_id", "session-123") @@ -143,7 +160,8 @@ def test_flask_cors_behavior_is_preserved_for_fallback_routes(): response.headers["access-control-allow-origin"] == "https://app.policyengine.org" ) - assert response.headers["vary"] == "Origin" + vary_values = {value.strip() for value in response.headers["vary"].split(",")} + assert vary_values == {"Origin", "Accept-Encoding"} def test_health_route_uses_same_reflected_cors_policy():