Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/3710.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Restored gzip compression for Cloud Run metadata responses.
6 changes: 3 additions & 3 deletions policyengine_api/asgi_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

from a2wsgi import WSGIMiddleware
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from policyengine_api.constants import VERSION
from policyengine_api.migration_logging import log_migration_request

from pydantic import BaseModel
from starlette.middleware.gzip import GZipMiddleware

FASTAPI_NATIVE_LOGGED_PATHS = frozenset(
{
Expand Down Expand Up @@ -50,6 +49,7 @@ def create_asgi_app(wsgi_app) -> FastAPI:
redoc_url=None,
openapi_url=None,
)
app.add_middleware(GZipMiddleware, minimum_size=1000)

@app.middleware("http")
async def add_cors_for_native_routes(request, call_next):
Expand Down
14 changes: 11 additions & 3 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import pytest
import os
import sqlite3
from policyengine_api.data import PolicyEngineDatabase

import pytest

# The legacy SQL module creates its default database object at import time.
# Keep unit tests on SQLite until the SQL layer is broadly refactored in a
# later migration stage.
os.environ.setdefault("FLASK_DEBUG", "1")

from policyengine_api.constants import REPO
from policyengine_api.data import PolicyEngineDatabase


class TestPolicyEngineDatabase(PolicyEngineDatabase):
Expand Down Expand Up @@ -39,7 +47,7 @@ def initialize(self):
/ "data"
/ f"initialise{'_local' if self.local else ''}.sql"
)
with open(init_file, "r") as f:
with open(init_file) as f:
full_query = f.read()

# Split and execute the queries
Expand Down
83 changes: 80 additions & 3 deletions tests/unit/data/test_sqlalchemy_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
(dict(row) and row["key"]).
"""

import pytest
import policyengine_api.data.data as data_module
import sqlalchemy

from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase
from policyengine_api.data.data import PolicyEngineDatabase, _ResultProxy


class TestSQLAlchemyVersion:
Expand Down Expand Up @@ -180,3 +179,81 @@ def test_remote_delete(self):
db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)])
result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)])
assert result.fetchone() is None


class TestRemotePoolSetup:
"""Test remote pool setup without opening a real Cloud SQL connection."""

def _stub_remote_pool(self, monkeypatch):
fake_connection = object()
connector_calls = []
engine_calls = []

class FakeConnector:
def connect(self, **kwargs):
connector_calls.append(kwargs)
return fake_connection

def fake_create_engine(url, creator):
engine_calls.append((url, creator))
assert creator() is fake_connection
return "fake-engine"

fake_connector = FakeConnector()
monkeypatch.setattr(data_module, "Connector", lambda: fake_connector)
monkeypatch.setattr(data_module.sqlalchemy, "create_engine", fake_create_engine)
return fake_connector, connector_calls, engine_calls

def test_create_pool_uses_remote_database_config(self, monkeypatch):
fake_connector, connector_calls, engine_calls = self._stub_remote_pool(
monkeypatch
)
monkeypatch.setenv(
"POLICYENGINE_DB_INSTANCE_CONNECTION_NAME",
"test-project:us-central1:test-db",
)
monkeypatch.setenv("POLICYENGINE_DB_USER", "test-user")
monkeypatch.setenv("POLICYENGINE_DB_NAME", "test-db")
monkeypatch.setenv("POLICYENGINE_DB_PASSWORD", "test-password")

db = PolicyEngineDatabase.__new__(PolicyEngineDatabase)
db._create_pool()

assert db.connector is fake_connector
assert db.pool == "fake-engine"
assert connector_calls == [
{
"instance_connection_string": "test-project:us-central1:test-db",
"driver": "pymysql",
"db": "test-db",
"user": "test-user",
"password": "test-password",
}
]
assert engine_calls[0][0] == "mysql+pymysql://"

def test_create_pool_reads_dot_dbpw_file(self, monkeypatch, tmp_path):
_, connector_calls, _ = self._stub_remote_pool(monkeypatch)
monkeypatch.chdir(tmp_path)
monkeypatch.setenv("POLICYENGINE_DB_PASSWORD", ".dbpw")
(tmp_path / ".dbpw").write_text("file-password\n")

db = PolicyEngineDatabase.__new__(PolicyEngineDatabase)
db._create_pool()

assert connector_calls[0]["password"] == "file-password"

def test_remote_constructor_initializes_pool_without_local_database(
self, monkeypatch
):
calls = []

def fake_create_pool(self):
calls.append(("pool", self.local))

monkeypatch.setattr(PolicyEngineDatabase, "_create_pool", fake_create_pool)

db = PolicyEngineDatabase(local=False, initialize=False)

assert db.local is False
assert calls == [("pool", False)]
24 changes: 21 additions & 3 deletions tests/unit/test_asgi_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from fastapi.testclient import TestClient
from flask import Flask, Response, jsonify, make_response, request
from flask_cors import CORS
from starlette.responses import Response as ASGIResponse

from policyengine_api.asgi_factory import _add_vary_origin, create_asgi_app
from starlette.responses import Response as ASGIResponse


def create_test_wsgi_app() -> Flask:
Expand All @@ -23,6 +22,10 @@ def fallback():
response.set_cookie("fallback-cookie", "present")
return response

@app.get("/large-fallback")
def large_fallback():
return Response("x" * 2_000, status=200, mimetype="text/plain")

@app.get("/request-echo")
def request_echo():
response = jsonify(
Expand Down Expand Up @@ -113,6 +116,20 @@ def test_flask_fallback_preserves_status_body_headers_and_cookies():
assert response.headers["content-type"].startswith("text/html")


def test_large_flask_fallback_response_supports_http_gzip():
client = TestClient(create_asgi_app(create_test_wsgi_app()))

response = client.get(
"/large-fallback",
headers={"Accept-Encoding": "gzip"},
)

assert response.status_code == 200
assert response.headers["content-encoding"] == "gzip"
assert "Accept-Encoding" in response.headers["vary"]
assert response.text == "x" * 2_000


def test_request_headers_and_cookies_pass_through_to_flask_fallback():
client = TestClient(create_asgi_app(create_test_wsgi_app()))
client.cookies.set("session_id", "session-123")
Expand Down Expand Up @@ -143,7 +160,8 @@ def test_flask_cors_behavior_is_preserved_for_fallback_routes():
response.headers["access-control-allow-origin"]
== "https://app.policyengine.org"
)
assert response.headers["vary"] == "Origin"
vary_values = {value.strip() for value in response.headers["vary"].split(",")}
assert vary_values == {"Origin", "Accept-Encoding"}


def test_health_route_uses_same_reflected_cors_policy():
Expand Down
Loading