Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ install-dev-dbt-%:
$(MAKE) install-dev; \
if [ "$$version" = "1.6.0" ]; then \
echo "Applying overrides for dbt 1.6.0"; \
$(PIP) install 'pydantic>=2.0.0' 'google-cloud-bigquery==3.30.0' 'databricks-sdk==0.28.0' --reinstall; \
$(PIP) install 'pydantic>=2.0.0' 'google-cloud-bigquery==3.30.0' 'databricks-sdk==0.28.0' \
'pyOpenSSL>=24.0.0' --reinstall; \
fi; \
if [ "$$version" = "1.7.0" ]; then \
echo "Applying overrides for dbt 1.7.0"; \
$(PIP) install 'databricks-sdk==0.28.0' --reinstall; \
$(PIP) install 'databricks-sdk==0.28.0' \
'pyOpenSSL>=24.0.0' --reinstall; \
fi; \
if [ "$$version" = "1.5.0" ]; then \
echo "Applying overrides for dbt 1.5.0"; \
Expand Down
5 changes: 5 additions & 0 deletions sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,11 @@ def run(ctx: click.Context, environment: t.Optional[str] = None, **kwargs: t.Any
is_flag=True,
help="Wait for the environment to be deleted before returning. If not specified, the environment will be deleted asynchronously by the janitor process. This option requires a connection to the data warehouse.",
)
@click.option(
"--cleanup-snapshots",
is_flag=True,
help="After invalidating, immediately delete physical snapshot tables that are exclusively owned by this environment (not referenced by any other environment). Cleanup runs synchronously regardless of --sync.",
)
@click.pass_context
@error_handler
@cli_analytics
Expand Down
41 changes: 39 additions & 2 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
Snapshot,
SnapshotEvaluator,
SnapshotFingerprint,
SnapshotId,
missing_intervals,
to_table_mapping,
)
Expand All @@ -108,7 +109,11 @@
StateReader,
StateSync,
)
from sqlmesh.core.janitor import cleanup_expired_views, delete_expired_snapshots
from sqlmesh.core.janitor import (
cleanup_expired_views,
delete_expired_snapshots,
delete_snapshots_for_environment,
)
from sqlmesh.core.table_diff import TableDiff
from sqlmesh.core.test import (
ModelTextTestResult,
Expand Down Expand Up @@ -1835,18 +1840,50 @@ def apply(
)

@python_api_analytics
def invalidate_environment(self, name: str, sync: bool = False) -> None:
def invalidate_environment(
    self, name: str, sync: bool = False, cleanup_snapshots: bool = False
) -> None:
    """Expire the target environment now and optionally remove it and its exclusive snapshots.

    Args:
        name: The name of the environment to invalidate. It is sanitized before use.
        sync: If True, the environment is cleaned up before this call returns. Otherwise only
            the expiration is recorded and removal is left to the janitor process.
        cleanup_snapshots: If True, the snapshot IDs referenced by the environment are captured
            before invalidation and handed to the scoped snapshot cleanup afterwards. Enabling
            this forces the synchronous path even when ``sync`` is False.

    Raises:
        SQLMeshError: If the scoped snapshot cleanup reports failures and the janitor config
            does not have ``warn_on_delete_failure`` enabled.
    """
    name = Environment.sanitize_name(name)
    # Requesting snapshot cleanup always takes the synchronous branch below.
    sync = sync or cleanup_snapshots

    owned_snapshot_ids: t.Set[SnapshotId] = set()
    if cleanup_snapshots:
        # The IDs have to be read while the environment record still exists.
        environment = self.state_sync.get_environment(name)
        if environment is None:
            logger.warning("Environment '%s' does not exist; skipping snapshot cleanup.", name)
            return
        owned_snapshot_ids = {snapshot.snapshot_id for snapshot in environment.snapshots}

    self.state_sync.invalidate_environment(name)

    if not sync:
        self.console.log_success(f"Environment '{name}' invalidated.")
        return

    self._cleanup_environments(name=name)

    if cleanup_snapshots and owned_snapshot_ids:
        failures = delete_snapshots_for_environment(
            self.state_sync,
            self.snapshot_evaluator,
            owned_snapshot_ids,
            console=self.console,
        )
        if failures:
            summary = "\n".join(failures)
            # Failures are either surfaced as a warning or escalated, depending on config.
            if not self.config.janitor.warn_on_delete_failure:
                raise SQLMeshError(f"Snapshot cleanup completed with failures:\n{summary}")
            self.console.log_warning(f"Snapshot cleanup completed with failures:\n{summary}")

    self.console.log_success(f"Environment '{name}' deleted.")
Expand Down
71 changes: 70 additions & 1 deletion sqlmesh/core/janitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlmesh.core.console import Console
from sqlmesh.core.dialect import schema_
from sqlmesh.core.environment import Environment
from sqlmesh.core.snapshot import SnapshotEvaluator
from sqlmesh.core.snapshot import SnapshotEvaluator, SnapshotId
from sqlmesh.core.state_sync import StateSync
from sqlmesh.core.state_sync.common import (
logger,
Expand Down Expand Up @@ -193,3 +193,72 @@ def delete_expired_snapshots(
failures.append(message)
logger.info("Cleaned up %s expired snapshots", num_expired_snapshots)
return failures


def delete_snapshots_for_environment(
    state_sync: StateSync,
    snapshot_evaluator: SnapshotEvaluator,
    target_snapshot_ids: t.Collection[SnapshotId],
    *,
    force_delete: bool = False,
    console: t.Optional[Console] = None,
) -> t.List[str]:
    """Run a scoped cleanup limited to the given snapshot IDs.

    The state sync is asked for expired snapshots with the TTL ignored and the lookup
    restricted to ``target_snapshot_ids``. Whatever it returns is cleaned up through the
    snapshot evaluator and then removed from state.

    Args:
        state_sync: StateSync instance used to look up and delete snapshot state.
        snapshot_evaluator: SnapshotEvaluator instance used to run the cleanup tasks.
        target_snapshot_ids: The snapshot IDs to consider for deletion (typically those of the
            environment that was just invalidated/deleted).
        force_delete: If True, snapshot state records are deleted even when the evaluator
            cleanup raised.
        console: Optional console whose ``update_cleanup_progress`` receives progress callbacks.

    Returns:
        The failure messages collected during cleanup; empty when nothing went wrong or when
        there was nothing to do.
    """
    errors: t.List[str] = []
    if not target_snapshot_ids:
        return errors

    expired = state_sync.get_expired_snapshots(
        ignore_ttl=True,
        batch_range=ExpiredBatchRange.all_batch_range(),
        target_snapshot_ids=target_snapshot_ids,
    )
    if expired is None:
        return errors

    logger.info(
        "Cleaning up %s snapshots exclusively owned by invalidated environment",
        len(expired.expired_snapshot_ids),
    )

    tables_cleaned = True
    if expired.cleanup_tasks:
        try:
            snapshot_evaluator.cleanup(
                target_snapshots=expired.cleanup_tasks,
                on_complete=console.update_cleanup_progress if console else None,
            )
        except Exception as cleanup_error:
            # Best effort: record the failure and let the caller decide how to surface it.
            tables_cleaned = False
            cleanup_message = f"Failed to clean up: {cleanup_error}"
            logger.warning(cleanup_message)
            errors.append(cleanup_message)

    # State records are kept after a failed cleanup unless the caller forces deletion.
    if not (tables_cleaned or force_delete):
        return errors

    try:
        state_sync.delete_snapshots(expired.expired_snapshot_ids)
        logger.info(
            "Cleaned up %s snapshots from invalidated environment",
            len(expired.expired_snapshot_ids),
        )
    except Exception as state_error:
        state_message = f"Failed to delete snapshot state records: {state_error}"
        logger.warning(state_message)
        errors.append(state_message)

    return errors
6 changes: 6 additions & 0 deletions sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,16 @@ def get_expired_snapshots(
batch_range: ExpiredBatchRange,
current_ts: t.Optional[int] = None,
ignore_ttl: bool = False,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> t.Optional[ExpiredSnapshotBatch]:
"""Returns a single batch of expired snapshots ordered by (updated_ts, name, identifier).

Args:
current_ts: Timestamp used to evaluate expiration.
ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced).
batch_range: The range of the batch to fetch.
target_snapshot_ids: If provided, only consider snapshots with these IDs. Useful for
scoped cleanup after environment invalidation.

Returns:
A batch describing expired snapshots or None if no snapshots are pending cleanup.
Expand Down Expand Up @@ -368,6 +371,7 @@ def delete_expired_snapshots(
batch_range: ExpiredBatchRange,
ignore_ttl: bool = False,
current_ts: t.Optional[int] = None,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> None:
"""Removes expired snapshots.

Expand All @@ -379,6 +383,8 @@ def delete_expired_snapshots(
ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting
all snapshots that are not referenced in any environment
current_ts: Timestamp used to evaluate expiration.
target_snapshot_ids: If provided, only delete snapshots with these IDs. Useful for
scoped cleanup after environment invalidation.
"""

@abc.abstractmethod
Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/core/state_sync/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,14 @@ def delete_expired_snapshots(
batch_range: ExpiredBatchRange,
ignore_ttl: bool = False,
current_ts: t.Optional[int] = None,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> None:
self.snapshot_cache.clear()
self.state_sync.delete_expired_snapshots(
batch_range=batch_range,
ignore_ttl=ignore_ttl,
current_ts=current_ts,
target_snapshot_ids=target_snapshot_ids,
)

def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
Expand Down
4 changes: 4 additions & 0 deletions sqlmesh/core/state_sync/db/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,15 @@ def get_expired_snapshots(
batch_range: ExpiredBatchRange,
current_ts: t.Optional[int] = None,
ignore_ttl: bool = False,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> t.Optional[ExpiredSnapshotBatch]:
current_ts = current_ts or now_timestamp()
return self.snapshot_state.get_expired_snapshots(
environments=self.environment_state.get_environments(),
current_ts=current_ts,
ignore_ttl=ignore_ttl,
batch_range=batch_range,
target_snapshot_ids=target_snapshot_ids,
)

def get_expired_environments(
Expand All @@ -287,11 +289,13 @@ def delete_expired_snapshots(
batch_range: ExpiredBatchRange,
ignore_ttl: bool = False,
current_ts: t.Optional[int] = None,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> None:
batch = self.get_expired_snapshots(
ignore_ttl=ignore_ttl,
current_ts=current_ts,
batch_range=batch_range,
target_snapshot_ids=target_snapshot_ids,
)
if batch and batch.expired_snapshot_ids:
self.snapshot_state.delete_snapshots(batch.expired_snapshot_ids)
Expand Down
11 changes: 11 additions & 0 deletions sqlmesh/core/state_sync/db/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def get_expired_snapshots(
current_ts: int,
ignore_ttl: bool,
batch_range: ExpiredBatchRange,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> t.Optional[ExpiredSnapshotBatch]:
expired_query = exp.select("name", "identifier", "version", "updated_ts").from_(
self.snapshots_table
Expand All @@ -180,6 +181,16 @@ def get_expired_snapshots(
(exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts
)

if target_snapshot_ids is not None:
target_conditions = list(
snapshot_id_filter(
self.engine_adapter,
target_snapshot_ids,
batch_size=self.SNAPSHOT_BATCH_SIZE,
)
)
expired_query = expired_query.where(exp.or_(*target_conditions))

expired_query = expired_query.where(batch_range.where_filter)

promoted_snapshot_ids = {
Expand Down
4 changes: 1 addition & 3 deletions sqlmesh/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,7 @@ def wrap(*args: t.Any, **kwargs: t.Any) -> t.Any:


class classproperty(property):
    """Property variant whose getter is invoked with the owning class instead of an instance."""

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        # Wrap the getter as a classmethod, bind it to the owner, then call it.
        class_bound = classmethod(self.fget).__get__(None, owner)  # type: ignore
        return class_bound()
Expand Down
70 changes: 70 additions & 0 deletions tests/core/integration/test_aux_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,76 @@ def test_invalidating_environment(sushi_context: Context):
assert start_schemas - schemas_after_janitor == {"sushi__dev"}


def test_invalidate_environment_cleanup_snapshots_scoped(tmp_path: Path):
    """Test that --cleanup-snapshots only deletes snapshots exclusively owned by the invalidated env.

    Both models are planned into prod and then into dev with ``include_unmodified=True``.
    After dev is invalidated with ``cleanup_snapshots=True``, the snapshots that prod also
    references must still be present in state.
    """
    models_dir = tmp_path / "models"
    models_dir.mkdir()
    (models_dir / "model1.sql").write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col")
    (models_dir / "model2.sql").write_text("MODEL(name test.model2, kind FULL); SELECT 2 AS col")

    ctx = Context(
        paths=[tmp_path],
        config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")),
    )

    # Apply both models to prod and dev.
    ctx.plan("prod", no_prompts=True, auto_apply=True)
    ctx.plan("dev", no_prompts=True, auto_apply=True, include_unmodified=True)

    prod_env = ctx.state_sync.get_environment("prod")
    dev_env = ctx.state_sync.get_environment("dev")
    assert prod_env is not None
    assert dev_env is not None

    prod_snapshot_ids = {s.snapshot_id for s in prod_env.snapshots}
    dev_snapshot_ids = {s.snapshot_id for s in dev_env.snapshots}

    # In a virtual environment, dev shares snapshots with prod.
    # Shared snapshots must NOT be deleted when invalidating dev with --cleanup-snapshots.
    shared_snapshot_ids = prod_snapshot_ids & dev_snapshot_ids
    # Guard against a vacuous pass: with nothing shared, the checks below would prove nothing.
    assert shared_snapshot_ids

    ctx.invalidate_environment("dev", cleanup_snapshots=True)

    # The dev environment record should be gone.
    assert ctx.state_sync.get_environment("dev") is None

    # Shared snapshots (also in prod) must still exist.
    remaining_snapshots = ctx.state_sync.get_snapshots(list(shared_snapshot_ids))
    assert set(remaining_snapshots.keys()) == shared_snapshot_ids

    # Prod environment should be unaffected.
    assert ctx.state_sync.get_environment("prod") is not None


def test_invalidate_environment_cleanup_snapshots_exclusive(tmp_path: Path):
    """Check that cleanup_snapshots removes snapshot records referenced only by the invalidated env."""
    project_models = tmp_path / "models"
    project_models.mkdir()
    model_file = project_models / "model1.sql"
    model_file.write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col")

    context = Context(
        paths=[tmp_path],
        config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")),
    )

    # Only dev is planned; prod is never created, so dev is the sole reference to these snapshots.
    context.plan("dev", no_prompts=True, auto_apply=True)

    environment = context.state_sync.get_environment("dev")
    assert environment is not None
    owned_ids = {snapshot.snapshot_id for snapshot in environment.snapshots}
    assert owned_ids

    context.invalidate_environment("dev", cleanup_snapshots=True)

    # The environment record itself must no longer be present.
    assert context.state_sync.get_environment("dev") is None

    # None of the previously captured snapshot IDs may still resolve in state.
    leftover = context.state_sync.get_snapshots(list(owned_ids))
    assert not leftover


@time_machine.travel("2023-01-08 15:00:00 UTC")
def test_evaluate_uncategorized_snapshot(init_and_plan_context: t.Callable):
context, plan = init_and_plan_context("examples/sushi")
Expand Down
Loading
Loading