Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,46 @@ class AutoPopulate:
_key_source = None
_allow_insert = False
_jobs = None
_upstream = None # set per-make() by _populate_one; see `upstream` property below

@property
def upstream(self):
"""
Pre-restricted ancestor view for the current ``make(self, key)`` call.

Inside ``make()``, ``self.upstream`` is a ``Diagram`` constructed via
:meth:`Diagram.trace(self & key) <datajoint.Diagram.trace>`. Use
``self.upstream[T]`` to obtain a pre-restricted ``QueryExpression``
(or ``FreeTable``, when indexed by a string) for any ancestor of
``self``.

Reading via ``self.upstream`` is the provenance-safe pattern: the
framework guarantees the restriction matches the current ``key``,
and indexing a non-ancestor table raises ``DataJointError``. See
:doc:`reference/specs/provenance` for the contract.

Raises
------
DataJointError
If accessed outside ``make()`` execution. To construct a trace
explicitly, use ``dj.Diagram.trace(self & key)``.

Examples
--------
::

def make(self, key):
date = self.upstream[Session].fetch1("session_date")
traces = self.upstream[ExtractTraces].to_arrays("trace")
self.insert1({**key, "summary": compute(traces, date)})
"""
if self._upstream is None:
raise DataJointError(
"self.upstream is only available inside make(). "
"Outside make(), construct a trace explicitly: "
"dj.Diagram.trace(self & key)."
)
return self._upstream

class _JobsDescriptor:
"""Descriptor allowing jobs access on both class and instance."""
Expand Down Expand Up @@ -611,6 +651,13 @@ def _populate1(
logger.jobs(f"Making {key} -> {self.full_table_name}")
self.__class__._allow_insert = True

# Pre-construct the upstream view for this make() call. Lazy — only
# `dj.Diagram.trace(self & key)` runs here (graph copy); the
# expensive SQL fetch fires when the user accesses self.upstream[T].
from .diagram import Diagram

self._upstream = Diagram.trace(self & dict(key))

try:
if not is_generator:
make(dict(key), **(make_kwargs or {}))
Expand Down Expand Up @@ -668,6 +715,10 @@ def _populate1(
return True
finally:
self.__class__._allow_insert = False
# Clear the per-make() upstream view so subsequent attribute
# access raises a clear error rather than silently using a
# stale trace from the previous make() call.
self._upstream = None

def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int]:
"""
Expand Down
169 changes: 169 additions & 0 deletions tests/integration/test_autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,175 @@ def make_insert(self, key, result, scale):
assert row["result"] == 1000 # 200 * 5


# =========================================================================
# #1424: self.upstream pre-restricted ancestor access in make()
# =========================================================================


def test_upstream_provides_pre_restricted_ancestor(prefix, connection_test):
"""make() can read self.upstream[Ancestor] and get pre-restricted data."""
schema = dj.Schema(f"{prefix}_upstream_basic", connection=connection_test)

@schema
class Subject(dj.Lookup):
definition = """
subject_id : int32
---
name : varchar(64)
"""
contents = [(1, "alice"), (2, "bob")]

@schema
class Greeting(dj.Computed):
definition = """
-> Subject
---
greeting : varchar(128)
"""

def make(self, key):
# Provenance-safe read: self.upstream pre-restricted to current key
name = self.upstream[Subject].fetch1("name")
self.insert1({**key, "greeting": f"Hello, {name}!"})

Greeting.populate()
assert (Greeting & {"subject_id": 1}).fetch1("greeting") == "Hello, alice!"
assert (Greeting & {"subject_id": 2}).fetch1("greeting") == "Hello, bob!"


def test_upstream_rejects_non_ancestor(prefix, connection_test):
"""self.upstream[T] for a non-ancestor table raises inside make()."""
schema = dj.Schema(f"{prefix}_upstream_non_ancestor", connection=connection_test)

@schema
class Subject(dj.Lookup):
definition = """
subject_id : int32
"""
contents = [(1,)]

@schema
class Unrelated(dj.Lookup):
definition = """
u_id : int32
"""
contents = [(99,)]

captured_errors: list[Exception] = []

@schema
class Bad(dj.Computed):
definition = """
-> Subject
---
ok : tinyint
"""

def make(self, key):
try:
self.upstream[Unrelated]
except DataJointError as exc:
captured_errors.append(exc)
# Insert anyway so populate doesn't fail
self.insert1({**key, "ok": 1})

Bad.populate()
assert len(captured_errors) == 1
assert "not in this trace" in str(captured_errors[0]).lower()


def test_upstream_unset_outside_make(prefix, connection_test):
"""Accessing self.upstream outside of make() raises a clear error."""
schema = dj.Schema(f"{prefix}_upstream_outside_make", connection=connection_test)

@schema
class Source(dj.Lookup):
definition = """
source_id : int32
"""
contents = [(1,)]

@schema
class Derived(dj.Computed):
definition = """
-> Source
---
val : int32
"""

def make(self, key):
self.insert1({**key, "val": 0})

with pytest.raises(DataJointError, match="only available inside make"):
Derived().upstream


def test_upstream_cleared_after_make(prefix, connection_test):
"""After a make() call completes, self.upstream is reset (no stale state)."""
schema = dj.Schema(f"{prefix}_upstream_cleared", connection=connection_test)

@schema
class Source(dj.Lookup):
definition = """
source_id : int32
"""
contents = [(1,)]

@schema
class Derived(dj.Computed):
definition = """
-> Source
---
val : int32
"""

def make(self, key):
self.insert1({**key, "val": 0})

Derived.populate()
# The class attribute defaults to None; the per-instance _upstream
# set during make() must have been cleared by the finally block.
# Probe via the public property — should raise the "outside make" error.
with pytest.raises(DataJointError, match="only available inside make"):
Derived().upstream


def test_upstream_seen_across_tripartite_make(prefix, connection_test):
"""The tripartite make() invocation pattern sees the same self.upstream
across all three phases (fetch / compute / insert)."""
schema = dj.Schema(f"{prefix}_upstream_tripartite", connection=connection_test)

@schema
class Source(dj.Lookup):
definition = """
source_id : int32
---
value : int32
"""
contents = [(1, 100), (2, 200)]

@schema
class TriComputed(dj.Computed):
definition = """
-> Source
---
result : int32
"""

def make_fetch(self, key):
return (self.upstream[Source].fetch1("value"),)

def make_compute(self, key, value):
return (value * 2,)

def make_insert(self, key, doubled):
self.insert1({**key, "result": doubled})

TriComputed.populate()
assert (TriComputed & {"source_id": 1}).fetch1("result") == 200
assert (TriComputed & {"source_id": 2}).fetch1("result") == 400


def test_populate_reserve_jobs_respects_restrictions(clean_autopopulate, subject, experiment):
"""Regression test for #1413: populate() with reserve_jobs=True must honour restrictions.

Expand Down
Loading