diff --git a/src/datajoint/adapters/base.py b/src/datajoint/adapters/base.py index da4779543..e79a5d4df 100644 --- a/src/datajoint/adapters/base.py +++ b/src/datajoint/adapters/base.py @@ -850,6 +850,29 @@ def find_downstream_schemas_sql(self, schemas_list: str) -> str: raise NotImplementedError ... + def find_upstream_schemas_sql(self, schemas_list: str) -> str: + """ + Generate query to find schemas that the given schemas reference via FK. + + Used to discover unloaded schemas that the loaded ones depend on + (the upstream / ancestor direction). Symmetric to + :meth:`find_downstream_schemas_sql`. + + Parameters + ---------- + schemas_list : str + Comma-separated, quoted schema names for an IN clause. + + Returns + ------- + str + SQL query returning rows with a single column ``schema_name`` + containing distinct schema names that are referenced by the + given schemas. + """ + raise NotImplementedError + ... + @abstractmethod def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """ diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py index f035ba87f..4d2d4ca73 100644 --- a/src/datajoint/adapters/mysql.py +++ b/src/datajoint/adapters/mysql.py @@ -696,6 +696,16 @@ def find_downstream_schemas_sql(self, schemas_list: str) -> str: f"AND table_schema NOT IN ({schemas_list})" ) + def find_upstream_schemas_sql(self, schemas_list: str) -> str: + """Find schemas that the given schemas reference via FK.""" + return ( + f"SELECT DISTINCT referenced_table_schema as schema_name " + f"FROM information_schema.key_column_usage " + f"WHERE table_schema IN ({schemas_list}) " + f"AND referenced_table_schema IS NOT NULL " + f"AND referenced_table_schema NOT IN ({schemas_list})" + ) + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """Query to get FK constraint details from information_schema.""" return ( diff --git a/src/datajoint/adapters/postgres.py 
b/src/datajoint/adapters/postgres.py index 543e972d3..1dc062bda 100644 --- a/src/datajoint/adapters/postgres.py +++ b/src/datajoint/adapters/postgres.py @@ -861,6 +861,20 @@ def find_downstream_schemas_sql(self, schemas_list: str) -> str: f"AND ns1.nspname NOT IN ({schemas_list})" ) + def find_upstream_schemas_sql(self, schemas_list: str) -> str: + """Find schemas that the given schemas reference via FK.""" + return ( + f"SELECT DISTINCT ns2.nspname as schema_name " + f"FROM pg_constraint c " + f"JOIN pg_class cl1 ON c.conrelid = cl1.oid " + f"JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid " + f"JOIN pg_class cl2 ON c.confrelid = cl2.oid " + f"JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid " + f"WHERE c.contype = 'f' " + f"AND ns1.nspname IN ({schemas_list}) " + f"AND ns2.nspname NOT IN ({schemas_list})" + ) + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """ Query to get FK constraint details from information_schema. diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py index 08fb50e1b..9b67c00d0 100644 --- a/src/datajoint/dependencies.py +++ b/src/datajoint/dependencies.py @@ -259,6 +259,35 @@ def load_all_downstream(self) -> None: self.load(force=True, schema_names=known_schemas) + def load_all_upstream(self) -> None: + """ + Load dependencies including all upstream schemas referenced via FK chains. + + Iteratively discovers schemas that the currently loaded schemas + reference, expanding the dependency graph until no new schemas + are found. This ensures that upstream restriction propagation + (``Diagram.trace()``) reaches all ancestor tables, including + those in schemas the user has not explicitly activated. + + Called automatically by ``Diagram.trace()``. Symmetric to + :meth:`load_all_downstream`. 
+ """ + adapter = self._conn.adapter + known_schemas = set(self._conn.schemas) + if not known_schemas: + self.load() + return + + while True: + schemas_list = ", ".join(adapter.quote_string(s) for s in known_schemas) + result = self._conn.query(adapter.find_upstream_schemas_sql(schemas_list)) + new_schemas = {row[0] for row in result} - known_schemas + if not new_schemas: + break + known_schemas |= new_schemas + + self.load(force=True, schema_names=known_schemas) + def topo_sort(self) -> list[str]: """ Return table names in topological order. diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index aacf4ed61..b2572cfaf 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -365,6 +365,17 @@ def cascade(cls, table_expr, part_integrity="enforce"): # Propagate downstream result._propagate_restrictions(node, mode="cascade", part_integrity=part_integrity) + # part_integrity="cascade" may pull in nodes that aren't descendants of + # the seed (e.g. the master of a seed Part, plus the master's other + # Parts). Expand nodes_to_show to include any restricted node and the + # descendants of any newly-restricted ancestor. See #1429. + restricted_nodes = set(result._cascade_restrictions) + expanded = set(result.nodes_to_show) | restricted_nodes + for n in restricted_nodes - result.nodes_to_show: + expanded.update(nx.descendants(result, n)) + result.nodes_to_show = expanded & set(result.nodes()) + result._expanded_nodes = set(result.nodes_to_show) + # Trim graph to cascade subgraph: only restricted tables # (seed + descendants) plus alias nodes connecting them. keep = set(result._cascade_restrictions) @@ -376,6 +387,245 @@ def cascade(cls, table_expr, part_integrity="enforce"): result._expanded_nodes &= keep return result + @classmethod + def trace(cls, table_expr): + """ + Create an upstream-trace diagram for a (restricted) table expression. + + The upstream mirror of :meth:`cascade`. 
Walks the FK graph upward + from the seed, propagating the restriction to every ancestor with + OR convergence — an ancestor entity is included if reachable through + *any* FK path from the seed. + + Reuses the upward propagation rules + (``_apply_propagation_rule_upward``) defined alongside the cascade + engine, applied here in a generalized form (any child → any parent, + not just Part → Master). + + Parameters + ---------- + table_expr : QueryExpression + A (possibly restricted) table expression + (e.g., ``Spectrum & key``). + + Returns + ------- + Diagram + New Diagram restricted to the seed and its ancestors, with + per-ancestor restrictions accumulated through the FK graph. + Use ``diagram[T]`` to obtain a pre-restricted + ``QueryExpression`` for an ancestor, or ``diagram.counts()`` + to preview row counts per ancestor. + + Examples + -------- + >>> trace = dj.Diagram.trace(Spectrum & {"recording_id": 5}) + >>> trace[Session].fetch1("session_date") + >>> trace.counts() # entity counts per ancestor + >>> trace["`schema`.`session`"] # FreeTable via quoted full name, when class isn't in scope + + See Also + -------- + :meth:`cascade` — the downstream mirror. 
+ """ + conn = table_expr.connection + conn.dependencies.load_all_upstream() + node = table_expr.full_table_name + + result = cls.__new__(cls) + nx.DiGraph.__init__(result, conn.dependencies) + result._connection = conn + result.context = {} + result._cascade_restrictions = {} # trace uses cascade-shape storage (OR semantics) + result._restrict_conditions = {} + result._restriction_attrs = {} + result._mode = "trace" + + # Include seed + all ancestors + ancestors = set(nx.ancestors(result, node)) | {node} + result.nodes_to_show = ancestors + result._expanded_nodes = set(ancestors) + + # Seed restriction + restriction = AndList(table_expr.restriction) + result._cascade_restrictions[node] = [restriction] if restriction else [] + result._restriction_attrs[node] = set(table_expr.restriction_attributes) + + # Propagate upstream + result._propagate_restrictions_upstream(node) + + # Trim graph to trace subgraph: only restricted tables (seed + ancestors) + # plus alias nodes connecting them. + keep = set(result._cascade_restrictions) + for alias in (n for n in result.nodes() if n.isdigit()): + if set(result.predecessors(alias)) & keep and set(result.successors(alias)) & keep: + keep.add(alias) + result.remove_nodes_from(set(result.nodes()) - keep) + result.nodes_to_show &= keep + result._expanded_nodes &= keep + return result + + def _propagate_restrictions_upstream(self, start_node): + """ + Propagate the seed's restriction upstream through the FK graph. + + Symmetric to :meth:`_propagate_restrictions` but walks ``in_edges`` + instead of ``out_edges`` and applies the upward rules + (``_apply_propagation_rule_upward``) at each real edge. Multiple + passes until no new ancestor is restricted; termination is + guaranteed because the dependency graph is a DAG. 
+ """ + sorted_nodes = topo_sort(self) + # Only propagate through ancestors of start_node + allowed_nodes = {start_node} | set(nx.ancestors(self, start_node)) + propagated_edges = set() + + restrictions = self._cascade_restrictions + + any_new = True + while any_new: + any_new = False + + # Walk in reverse topological order so children are processed + # before their parents — when we reach a parent, its restriction + # accumulates from all of its (already-processed) children. + for node in reversed(sorted_nodes): + if node not in restrictions or node not in allowed_nodes: + continue + + child_ft = self._restricted_table(node) + child_attrs = self._restriction_attrs.get(node, set()) + + for parent, _, edge_props in self.in_edges(node, data=True): + edge_key = (parent, node) + if edge_key in propagated_edges: + continue + propagated_edges.add(edge_key) + + if parent not in allowed_nodes: + continue + + if isinstance(parent, str) and parent.isdigit(): + # Alias node — find the real parent on the far side. + # The alias has its own in_edges; the props on both + # half-edges are identical, so we can use either. 
+ for real_parent, _, real_edge_props in self.in_edges(parent, data=True): + real_edge_key = (real_parent, parent, node) + if real_edge_key in propagated_edges: + continue + propagated_edges.add(real_edge_key) + if real_parent not in allowed_nodes: + continue + attr_map = real_edge_props.get("attr_map", {}) + aliased = real_edge_props.get("aliased", False) + was_new = real_parent not in restrictions + self._apply_propagation_rule_upward( + child_ft, + child_attrs, + real_parent, + attr_map, + aliased, + "cascade", # OR semantics for trace + restrictions, + ) + if was_new and real_parent in restrictions: + any_new = True + else: + attr_map = edge_props.get("attr_map", {}) + aliased = edge_props.get("aliased", False) + was_new = parent not in restrictions + self._apply_propagation_rule_upward( + child_ft, + child_attrs, + parent, + attr_map, + aliased, + "cascade", + restrictions, + ) + if was_new and parent in restrictions: + any_new = True + + def __getitem__(self, key): + """ + Return a pre-restricted query expression (or FreeTable) for an + ancestor table in this trace. + + Only meaningful for trace diagrams (constructed via + :meth:`Diagram.trace`). For ordinary diagrams, defers to + :class:`networkx.DiGraph`'s adjacency-dict lookup. + + Parameters + ---------- + key : type or str + A Table subclass (e.g. ``Session``) — returns a pre-restricted + ``QueryExpression``. Or a string giving the table's class name + or fully-qualified SQL name — returns a pre-restricted + ``FreeTable``. + + Returns + ------- + QueryExpression or FreeTable + The ancestor's table restricted to rows reachable via FK from + the seed of this trace. + + Raises + ------ + DataJointError + If the requested table is not in the trace's subgraph (i.e. + not an ancestor of the seed, and not the seed itself). 
+ + Examples + -------- + >>> trace = dj.Diagram.trace(Spectrum & key) + >>> trace[Session].fetch1("session_date") # class index + >>> trace["`my_schema`.`session`"].to_dicts() # quoted full name → FreeTable + """ + # Non-trace diagrams: defer to networkx adjacency lookup so existing + # `diagram[node_name]` patterns (used in diagram algebra, ERD tests) + # keep working. + if getattr(self, "_mode", None) != "trace": + return super().__getitem__(key) + + from .table import Table + + # Resolve `key` to a full table name + if isinstance(key, type) and issubclass(key, Table): + full_name = key.full_table_name + elif isinstance(key, Table): + full_name = key.full_table_name + elif isinstance(key, str): + # Accept either a class name (tail-matched against node names) or a quoted full SQL name + if "`" in key or '"' in key: + full_name = key + else: + # Class name — search graph nodes for a matching tail + candidates = [ + n + for n in self.nodes() + if not (isinstance(n, str) and n.isdigit()) and n.lower().rstrip('`"').endswith(key.lower()) + ] + if not candidates: + raise DataJointError(f"Table {key!r} is not in this trace's subgraph " f"(not an ancestor of the seed).") + if len(candidates) > 1: + raise DataJointError( + f"Ambiguous table reference {key!r}: matches " f"{', '.join(candidates)}. Use a fully-qualified name." + ) + full_name = candidates[0] + else: + raise DataJointError(f"trace[...] expects a Table class, Table instance, or string; " f"got {type(key).__name__}.") + + if full_name not in self._cascade_restrictions: + raise DataJointError(f"Table {full_name} is not in this trace's subgraph " f"(not an ancestor of the seed).") + + # Both branches return the FreeTable built by ``_restricted_table``; + # class-typed keys are not re-wrapped in their class. + if isinstance(key, (type, Table)): + ft = self._restricted_table(full_name) + return ft + else: + return self._restricted_table(full_name) + def _restricted_table(self, node): """ Return a FreeTable for ``node`` with this diagram's restrictions applied. 
@@ -443,7 +693,6 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"): propagation rules at each edge. Only processes descendants of start_node to avoid duplicate propagation when chaining. """ - from .table import FreeTable sorted_nodes = topo_sort(self) # Only propagate through descendants of start_node @@ -453,6 +702,18 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"): restrictions = self._cascade_restrictions if mode == "cascade" else self._restrict_conditions + # Seed-is-Part case: when the seed itself is a Part and part_integrity="cascade", + # the main loop's part_integrity block (which fires inside `out_edges`) + # cannot trigger from the seed because a leaf Part has no out-edges. + # Trigger the upward propagation explicitly for the seed. See #1429. + if part_integrity == "cascade" and mode == "cascade": + seed_master = extract_master(start_node) + if seed_master and seed_master in self.nodes() and seed_master not in visited_masters: + visited_masters.add(seed_master) + if self._propagate_part_to_master(start_node, seed_master, mode, restrictions): + allowed_nodes.add(seed_master) + allowed_nodes.update(nx.descendants(self, seed_master)) + # Multiple passes to handle part_integrity="cascade" upward propagation. # When a part table triggers its master to join the cascade, the master's # other descendants need processing in a subsequent pass. The loop @@ -512,29 +773,19 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"): any_new = True # part_integrity="cascade": propagate up from part to master + # via the actual FK graph path, applying upward propagation + # rules at each edge. Handles Part-of-Part chains and + # renamed FKs (via .proj()), unlike the prior implementation + # which assumed shared PK attribute names. See #1429. 
if part_integrity == "cascade" and mode == "cascade": master_name = extract_master(target) - if ( - master_name - and master_name in self.nodes() - and master_name not in restrictions - and master_name not in visited_masters - ): + if master_name and master_name in self.nodes() and master_name not in visited_masters: visited_masters.add(master_name) - child_ft = self._restricted_table(target) - master_ft = FreeTable(self._connection, master_name) - from .condition import make_condition - - master_restr = make_condition( - master_ft, - (master_ft.proj() & child_ft.proj()).to_arrays(), - master_ft.restriction_attributes, - ) - restrictions[master_name] = [master_restr] - self._restriction_attrs[master_name] = set() - allowed_nodes.add(master_name) - allowed_nodes.update(nx.descendants(self, master_name)) - any_new = True + propagated = self._propagate_part_to_master(target, master_name, mode, restrictions) + if propagated: + allowed_nodes.add(master_name) + allowed_nodes.update(nx.descendants(self, master_name)) + any_new = True def _apply_propagation_rule( self, @@ -590,6 +841,183 @@ def _apply_propagation_rule( self._restriction_attrs.setdefault(child_node, set()).update(child_attrs) + def _apply_propagation_rule_upward(self, child_ft, child_attrs, parent_node, attr_map, aliased, mode, restrictions): + """ + Apply the symmetric (upward) propagation rule to a parent←child edge. + + Inverts `_apply_propagation_rule`: derives a restriction on the parent + from a restriction on the child, following the FK chain in reverse. + Used by part_integrity="cascade" to propagate a Part's restriction up + to its Master, transparently handling renamed FKs (via .proj()) and + Part-of-Part chains. See #1429. + + Edge metadata convention (matches `_apply_propagation_rule`): + - `attr_map`: dict mapping child column → parent (referenced) column. + - `aliased`: True iff any column was renamed across the FK. 
+ + Rules (symmetric to the forward rules in `_apply_propagation_rule`): + + 1. Non-aliased AND child restriction attrs ⊆ parent PK: + Copy child restriction directly (attrs are shared by name). + 2. Aliased FK (attr_map renames columns): + ``child.proj(**{parent: child for child, parent in attr_map.items()})`` + — reverses the renaming so the result has parent's column names. + 3. Non-aliased AND child restriction attrs ⊄ parent PK: + ``child.proj()`` — project child to parent's PK columns. + """ + parent_pk = self.nodes[parent_node].get("primary_key", set()) + + if not aliased and child_attrs and child_attrs <= parent_pk: + # Backward Rule 1: copy child restriction directly + child_restr = restrictions.get( + child_ft.full_table_name, + [] if mode == "cascade" else AndList(), + ) + if mode == "cascade": + restrictions.setdefault(parent_node, []).extend(child_restr) + else: + restrictions.setdefault(parent_node, AndList()).extend(child_restr) + parent_attrs = set(child_attrs) + elif aliased: + # Backward Rule 2: reverse rename + parent_item = child_ft.proj(**{pk: fk for fk, pk in attr_map.items()}) + if mode == "cascade": + restrictions.setdefault(parent_node, []).append(parent_item) + else: + restrictions.setdefault(parent_node, AndList()).append(parent_item) + parent_attrs = set(attr_map.values()) # parent's PK column names + else: + # Backward Rule 3: project child to its FK columns (which by name + # match parent's PK columns in the non-aliased case). For primary + # FKs (attr_map.keys() ⊆ child_pk) this is a no-op since + # ``proj()`` already returns the PK. For non-primary FKs this + # explicitly carries the FK columns into the projection so the + # subsequent restriction on the parent joins on the right columns. 
+ parent_item = child_ft.proj(*attr_map.keys()) + if mode == "cascade": + restrictions.setdefault(parent_node, []).append(parent_item) + else: + restrictions.setdefault(parent_node, AndList()).append(parent_item) + parent_attrs = set(attr_map.values()) + + self._restriction_attrs.setdefault(parent_node, set()).update(parent_attrs) + + def _propagate_part_to_master(self, part_node, master_name, mode, restrictions): + """ + Walk the FK graph from `part_node` up to `master_name`, applying + `_apply_propagation_rule_upward` at each real edge along the path. + + Returns True if any propagation occurred. Handles Part-of-Part chains + by walking the full path (intermediate Parts get restricted too) and + renamed FKs via the upward rules. + + Alias nodes (integer-named graph nodes inserted for aliased edges) + are transparent — both half-edges carry the same `attr_map` props, + so we read props from one and skip the alias node when walking. + + After the walk, the master's restriction is **materialized** to a + literal value tuple via ``to_arrays()``. Without materialization, a + subsequent forward cascade from the master back down to its parts + would produce a self-referential subquery (MySQL error 1093, since + the master's restriction depends on the same Part being deleted). + Materializing converts the restriction into a static value set, so + the forward cascade generates ``WHERE ... IN (literal-list)`` rather + than ``WHERE ... IN (SELECT ... FROM <part>)``. + + Limitations + ----------- + - **Single FK path**: ``nx.shortest_path`` returns *one* path from + ``master_name`` to ``part_node``. If a Part is reachable from its + Master through multiple distinct FK chains (e.g. references two + different intermediate Parts), restrictions through the + non-shortest paths are not applied. This pattern is unusual; if a + schema hits it, the user is responsible for restricting the + additional paths explicitly via ``part_integrity="ignore"`` plus + manual ``delete()`` calls. 
+ - **Memory cost of materialization**: ``master_ft.proj().to_arrays()`` + pulls the matching master primary keys into Python memory. Cost is + bounded by the count of *distinct* master rows referenced by the + matching parts — typically small for surgical cascades, but can + grow with bulk cascades on tables with many master rows. Cascade + *preview* (``Diagram.cascade(...).counts()``) pays the same cost. + """ + try: + path = nx.shortest_path(self, master_name, part_node) + except (nx.NetworkXNoPath, nx.NodeNotFound): + return False + + # Strip alias nodes; what remains is the sequence of real tables. + real_path = [n for n in path if not (isinstance(n, str) and n.isdigit())] + if len(real_path) < 2 or real_path[-1] != part_node or real_path[0] != master_name: + return False + + # Walk real_path in reverse (child → parent direction). For each + # adjacent (parent, child) pair, look up the edge props — direct + # edge if non-aliased, via alias node if aliased. + any_propagated = False + for i in range(len(real_path) - 1, 0, -1): + child = real_path[i] + parent = real_path[i - 1] + edge_props = self._find_real_edge_props(parent, child) + if edge_props is None: + return any_propagated # Path broken (shouldn't happen if shortest_path succeeded) + + attr_map = edge_props.get("attr_map", {}) + aliased = edge_props.get("aliased", False) + child_ft = self._restricted_table(child) + child_attrs = self._restriction_attrs.get(child, set()) + + self._apply_propagation_rule_upward( + child_ft, + child_attrs, + parent, + attr_map, + aliased, + mode, + restrictions, + ) + any_propagated = True + + # Materialize the master's restriction so subsequent forward cascade + # doesn't produce self-referential subqueries. Replace the master's + # accumulated query restrictions with a literal value tuple. 
+ if any_propagated and master_name in restrictions: + from .condition import make_condition + from .table import FreeTable + + master_ft = self._restricted_table(master_name) + master_pk_values = master_ft.proj().to_arrays() + if mode == "cascade": + bare_master = FreeTable(self._connection, master_name) + if len(master_pk_values) > 0: + materialized = make_condition( + bare_master, + master_pk_values, + bare_master.restriction_attributes, + ) + restrictions[master_name] = [materialized] + else: + # No matching master rows — false restriction so master is + # included with zero matches in counts/iter. + restrictions[master_name] = [False] + self._restriction_attrs.setdefault(master_name, set()) + + return any_propagated + + def _find_real_edge_props(self, parent, child): + """ + Return edge props for parent → child, transparently traversing the + integer-named alias node that the graph inserts for aliased FKs. + Returns None if no such edge or alias-mediated edge exists. + """ + if self.has_edge(parent, child): + return self.edges[parent, child] + for _, mid, _ in self.out_edges(parent, data=True): + if isinstance(mid, str) and mid.isdigit() and self.has_edge(mid, child): + # Both half-edges carry the same attr_map / aliased props + return self.edges[parent, mid] + return None + def counts(self): """ Return affected row counts per table without modifying data. 
diff --git a/tests/integration/test_cascade_delete.py b/tests/integration/test_cascade_delete.py index 3bc3dc73b..607669124 100644 --- a/tests/integration/test_cascade_delete.py +++ b/tests/integration/test_cascade_delete.py @@ -292,3 +292,190 @@ class Child(dj.Manual): connection_by_backend.query(f"DROP DATABASE IF EXISTS {qi(name)}") except Exception: pass + + +# ========================================================================= +# Issue #1429: cascade with part_integrity="cascade" must traverse the FK +# chain through intermediate Parts (and renamed FKs), not assume that the +# Part shares PK attribute names with its Master. +# ========================================================================= + + +def test_cascade_part_of_part_no_master_reference(schema_by_backend): + """ + Case 2 from #1429: PartB references PartA directly (no -> Master). + Restricting PartB with part_integrity="cascade" must restrict both + PartA and Master (PartA via the direct FK, Master via the master-part + FK chained through PartA). + """ + + @schema_by_backend + class Master(dj.Manual): + definition = """ + master_id : int32 + """ + + class PartA(dj.Part): + definition = """ + -> master + part_a_id : int32 + """ + + class PartB(dj.Part): + definition = """ + -> Master.PartA + part_b_id : int32 + """ + + Master.insert([(1,), (2,)]) + Master.PartA.insert([(1, 10), (1, 11), (2, 20)]) + Master.PartB.insert([(1, 10, 100), (1, 10, 101), (1, 11, 110), (2, 20, 200)]) + + # Cascade preview: deleting one PartB row must propagate up to PartA and Master. + counts = dj.Diagram.cascade( + Master.PartB & {"master_id": 1, "part_a_id": 10, "part_b_id": 100}, + part_integrity="cascade", + ).counts() + + # Master row (1,) is the originating Part's master — must appear with count 1 + assert counts.get(Master.full_table_name, 0) == 1, ( + f"Master restricted by 1 row; got {counts.get(Master.full_table_name)}. 
" + "Indicates the Part→Master upward propagation did not reach the Master " + "through the intermediate PartA." + ) + # Master cascades back down to ALL of master_id=1's Parts + assert counts.get(Master.PartA.full_table_name, 0) == 2 # rows 10, 11 + assert counts.get(Master.PartB.full_table_name, 0) == 3 # rows under master_id=1 + + +def test_cascade_part_of_part_renamed_fk(schema_by_backend): + """ + Case 1 from #1429: PartB references PartA via a renamed FK (`.proj()`). + PartB has no attribute named `master_id` (renamed to `src_master`). The + upward propagation must use the FK metadata, not assume shared attribute + names. + """ + + @schema_by_backend + class Master(dj.Manual): + definition = """ + master_id : int32 + """ + + class PartA(dj.Part): + definition = """ + -> master + part_a_id : int32 + """ + + class PartB(dj.Part): + definition = """ + -> Master.PartA.proj(src_master='master_id', src_part='part_a_id') + part_b_id : int32 + """ + + Master.insert([(1,), (2,)]) + Master.PartA.insert([(1, 10), (2, 20)]) + Master.PartB.insert([(1, 10, 100), (2, 20, 200)]) + + # PartB has columns: src_master, src_part, part_b_id — NOT master_id. + counts = dj.Diagram.cascade( + Master.PartB & {"src_master": 1, "src_part": 10, "part_b_id": 100}, + part_integrity="cascade", + ).counts() + + assert counts.get(Master.full_table_name, 0) == 1, ( + f"Master restricted by 1 row; got {counts.get(Master.full_table_name)}. " + "Renamed FK was not reversed when propagating up to Master." + ) + assert counts.get(Master.PartA.full_table_name, 0) == 1 + assert counts.get(Master.PartB.full_table_name, 0) == 1 + + +def test_cascade_three_level_part_chain(schema_by_backend): + """ + Three-hop chain (#1429 follow-up review): PartC → PartB → PartA → Master. + Verify intermediate Parts (PartA, PartB) are restricted at every hop, not + just the first, and the master cascades back down to all siblings. 
+ """ + + @schema_by_backend + class Master(dj.Manual): + definition = """ + master_id : int32 + """ + + class PartA(dj.Part): + definition = """ + -> master + part_a_id : int32 + """ + + class PartB(dj.Part): + definition = """ + -> Master.PartA + part_b_id : int32 + """ + + class PartC(dj.Part): + definition = """ + -> Master.PartB + part_c_id : int32 + """ + + Master.insert([(1,), (2,)]) + Master.PartA.insert([(1, 10), (1, 11), (2, 20)]) + Master.PartB.insert([(1, 10, 100), (1, 11, 110), (2, 20, 200)]) + Master.PartC.insert([(1, 10, 100, 1000), (1, 11, 110, 1100), (2, 20, 200, 2000)]) + + counts = dj.Diagram.cascade( + Master.PartC & {"master_id": 1, "part_a_id": 10, "part_b_id": 100, "part_c_id": 1000}, + part_integrity="cascade", + ).counts() + + # Master pulled in via the 3-hop upward walk + assert counts.get(Master.full_table_name, 0) == 1, ( + "Master restriction lost across 3-hop chain — the per-edge upward walk " "did not reach Master through PartA + PartB." + ) + # Master forward-cascades back down to all rows under master_id=1 + assert counts.get(Master.PartA.full_table_name, 0) == 2 # both PartA rows under master 1 + assert counts.get(Master.PartB.full_table_name, 0) == 2 # both PartB rows under master 1 + assert counts.get(Master.PartC.full_table_name, 0) == 2 # both PartC rows under master 1 + + +def test_cascade_part_of_part_actual_delete(schema_by_backend): + """ + End-to-end: actually run delete() with part_integrity="cascade" through + a Part-of-Part chain. Verifies the upward propagation produces SQL that + executes (no MySQL 1093 self-reference; correct row removal). 
+ """ + + @schema_by_backend + class Master(dj.Manual): + definition = """ + master_id : int32 + """ + + class PartA(dj.Part): + definition = """ + -> master + part_a_id : int32 + """ + + class PartB(dj.Part): + definition = """ + -> Master.PartA + part_b_id : int32 + """ + + Master.insert([(1,), (2,)]) + Master.PartA.insert([(1, 10), (2, 20)]) + Master.PartB.insert([(1, 10, 100), (2, 20, 200)]) + + (Master.PartB & {"master_id": 1}).delete(part_integrity="cascade") + + # master_id=1 chain is entirely gone; master_id=2 chain intact. + assert len(Master()) == 1 + assert Master().fetch1("master_id") == 2 + assert len(Master.PartA()) == 1 + assert len(Master.PartB()) == 1 diff --git a/tests/integration/test_trace.py b/tests/integration/test_trace.py new file mode 100644 index 000000000..787635bfe --- /dev/null +++ b/tests/integration/test_trace.py @@ -0,0 +1,293 @@ +""" +Integration tests for ``Diagram.trace()`` — upstream restriction propagation. + +The upstream mirror of ``Diagram.cascade()``. Walks the FK graph from a +restricted seed to every ancestor with OR convergence. Reuses the upward +propagation rules (U1/U2/U3 in cascade.md) added by #1468. 
+""" + +import pytest + +import datajoint as dj +from datajoint.errors import DataJointError + + +@pytest.fixture(scope="function") +def schema_by_backend(connection_by_backend, db_creds_by_backend, request): + """Create a fresh schema for each trace test.""" + backend = db_creds_by_backend["backend"] + import time + + test_id = str(int(time.time() * 1000))[-8:] + schema_name = f"djtest_trace_{backend}_{test_id}"[:64] + + if connection_by_backend.is_connected: + try: + connection_by_backend.query( + f"DROP DATABASE IF EXISTS {connection_by_backend.adapter.quote_identifier(schema_name)}" + ) + except Exception: + pass + + schema = dj.Schema(schema_name, connection=connection_by_backend) + yield schema + + if connection_by_backend.is_connected: + try: + connection_by_backend.query( + f"DROP DATABASE IF EXISTS {connection_by_backend.adapter.quote_identifier(schema_name)}" + ) + except Exception: + pass + + +def test_trace_single_hop(schema_by_backend): + """trace(Child & key)[Parent] returns Parent restricted via the FK.""" + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + parent_id : int32 + --- + name : varchar(64) + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int32 + """ + + Parent.insert([(1, "alice"), (2, "bob")]) + Child.insert([(1, 10), (1, 11), (2, 20)]) + + trace = dj.Diagram.trace(Child & {"parent_id": 1, "child_id": 10}) + + # Seed itself + assert len(trace[Child]) == 1 + + # Ancestor: Parent restricted to the rows that contributed to the seed + assert len(trace[Parent]) == 1 + assert trace[Parent].fetch1("parent_id") == 1 + + +def test_trace_multi_hop(schema_by_backend): + """trace walks through intermediate ancestors (Grandparent ← Parent ← Child).""" + + @schema_by_backend + class Grandparent(dj.Manual): + definition = """ + gp_id : int32 + """ + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + -> Grandparent + parent_id : int32 + """ + + @schema_by_backend + 
class Child(dj.Manual): + definition = """ + -> Parent + child_id : int32 + """ + + Grandparent.insert([(1,), (2,)]) + Parent.insert([(1, 10), (1, 11), (2, 20)]) + Child.insert([(1, 10, 100), (1, 11, 110), (2, 20, 200)]) + + trace = dj.Diagram.trace(Child & {"gp_id": 1, "parent_id": 10, "child_id": 100}) + + # All three ancestors restricted to the one contributing tuple per level + assert len(trace[Child]) == 1 + assert len(trace[Parent]) == 1 + assert len(trace[Grandparent]) == 1 + assert trace[Grandparent].fetch1("gp_id") == 1 + + +def test_trace_renamed_fk(schema_by_backend): + """Renamed FK (.proj(...)) — the upward rule reverses the rename.""" + + @schema_by_backend + class Animal(dj.Manual): + definition = """ + animal_id : int32 + --- + species : varchar(64) + """ + + @schema_by_backend + class Observation(dj.Manual): + definition = """ + obs_id : int32 + --- + -> Animal.proj(subject_id='animal_id') + measurement : float64 + """ + + Animal.insert([(1, "Mouse"), (2, "Rat")]) + Observation.insert([(10, 1, 1.5), (11, 1, 2.5), (20, 2, 3.0)]) + + # Observation columns: obs_id, subject_id (renamed), measurement. + # No `animal_id` column on Observation — the upward walk must reverse the rename. 
+ trace = dj.Diagram.trace(Observation & {"obs_id": 10}) + + assert len(trace[Animal]) == 1 + assert trace[Animal].fetch1("animal_id") == 1 + assert trace[Animal].fetch1("species") == "Mouse" + + +def test_trace_or_convergence_two_paths(schema_by_backend): + """Two FK paths from child to the same ancestor → OR (union) at the ancestor.""" + + @schema_by_backend + class Source(dj.Manual): + definition = """ + source_id : int32 + """ + + @schema_by_backend + class Downstream(dj.Manual): + definition = """ + downstream_id : int32 + --- + -> Source + -> Source.proj(comparison_src='source_id') + """ + + Source.insert([(1,), (2,), (3,)]) + # Downstream rows reference Source via two columns; OR convergence means the + # ancestor is restricted to the UNION of contributors across both FK paths. + Downstream.insert( + [ + (100, 1, 2), # primary source=1, comparison_src=2 + (101, 3, 3), # primary source=3, comparison_src=3 + ] + ) + + trace = dj.Diagram.trace(Downstream & {"downstream_id": 100}) + + # Source is restricted via BOTH FK paths from row 100 → {1, 2} + contributing = set(trace[Source].fetch("source_id")) + assert contributing == {1, 2} + + +def test_trace_rejects_non_ancestor(schema_by_backend): + """Indexing into a table that isn't in the trace's subgraph raises.""" + + @schema_by_backend + class A(dj.Manual): + definition = """ + a_id : int32 + """ + + @schema_by_backend + class B(dj.Manual): + definition = """ + b_id : int32 + """ + + @schema_by_backend + class C(dj.Manual): + definition = """ + -> A + c_id : int32 + """ + + A.insert([(1,)]) + B.insert([(99,)]) + C.insert([(1, 10)]) + + trace = dj.Diagram.trace(C & {"a_id": 1, "c_id": 10}) + + # A is an ancestor — OK + assert len(trace[A]) == 1 + + # B is unrelated — should raise + with pytest.raises(DataJointError, match="not in this trace"): + trace[B] + + +def test_trace_string_indexing_returns_freetable(schema_by_backend): + """trace[str] returns a FreeTable (no class needed in caller scope).""" + from 
datajoint.table import FreeTable + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + parent_id : int32 + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int32 + """ + + Parent.insert([(1,), (2,)]) + Child.insert([(1, 10), (2, 20)]) + + trace = dj.Diagram.trace(Child & {"parent_id": 1, "child_id": 10}) + + # String accepts the SQL-quoted full name + parent_via_string = trace[Parent.full_table_name] + assert isinstance(parent_via_string, FreeTable) + assert len(parent_via_string) == 1 + + +def test_trace_counts(schema_by_backend): + """trace.counts() reports per-ancestor row counts under the seed's restriction.""" + + @schema_by_backend + class Grandparent(dj.Manual): + definition = """ + gp_id : int32 + """ + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + -> Grandparent + parent_id : int32 + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int32 + """ + + Grandparent.insert([(1,), (2,)]) + Parent.insert([(1, 10), (1, 11), (2, 20)]) + Child.insert([(1, 10, 100), (1, 11, 110), (2, 20, 200)]) + + trace = dj.Diagram.trace(Child & {"gp_id": 1}) + counts = trace.counts() + + assert counts[Grandparent.full_table_name] == 1 + assert counts[Parent.full_table_name] == 2 + assert counts[Child.full_table_name] == 2 + + +def test_trace_seed_with_no_ancestors(schema_by_backend): + """Tracing from a table with no FK parents → trace contains only the seed.""" + + @schema_by_backend + class Standalone(dj.Manual): + definition = """ + std_id : int32 + """ + + Standalone.insert([(1,), (2,)]) + + trace = dj.Diagram.trace(Standalone & {"std_id": 1}) + + # Only the seed is in the trace + assert len(trace[Standalone]) == 1 + counts = trace.counts() + assert counts == {Standalone.full_table_name: 1}