diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index 8f0946a06..d33e6ccf0 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -658,6 +658,34 @@ def _populate1( self._upstream = Diagram.trace(self & dict(key)) + # If strict_provenance is on, push the active-make context so the + # runtime gates in expression.cursor / table.insert can check this + # make()'s reads and writes. The context is popped in the finally + # block below. + strict_token = None + if self.connection._config.get("strict_provenance", False): + from .provenance import push_strict_make_context + from .user_tables import Part + + allowed_tables = set(self._upstream._cascade_restrictions.keys()) | {self.full_table_name} + # Add Part tables of self to the allowed set. Use class __dict__ + # (not dir/getattr) to avoid triggering descriptors like the + # _JobsDescriptor that lazy-declares the ~~ job table. + for cls in type(self).__mro__: + for attr_name, attr in cls.__dict__.items(): + if attr_name.startswith("_"): + continue + if isinstance(attr, type) and issubclass(attr, Part): + # Instantiate to get full_table_name resolved against + # this schema. The Part class is already attached via + # @schema decoration of the master. + try: + part_ftn = attr().full_table_name + allowed_tables.add(part_ftn) + except Exception: + pass + strict_token = push_strict_make_context(self, frozenset(allowed_tables), dict(key)) + try: if not is_generator: make(dict(key), **(make_kwargs or {})) @@ -719,6 +747,11 @@ def _populate1( # access raises a clear error rather than silently using a # stale trace from the previous make() call. self._upstream = None + # Pop the strict-make context, if any. + if strict_token is not None: + from .provenance import pop_strict_make_context + + pop_strict_make_context(strict_token) def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int]: """ diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 1b5f5ac9e..f380b3b52 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -1242,6 +1242,12 @@ def cursor(self, as_dict=False): cursor Database query cursor. """ + # Strict-provenance read gate. No-op outside make() or when the + # config flag is off. See src/datajoint/provenance.py. + from .provenance import assert_read_allowed + + assert_read_allowed(self) + sql = self.make_sql() logger.debug(sql) return self.connection.query(sql, as_dict=as_dict) diff --git a/src/datajoint/provenance.py b/src/datajoint/provenance.py new file mode 100644 index 000000000..e124d1160 --- /dev/null +++ b/src/datajoint/provenance.py @@ -0,0 +1,193 @@ +""" +Runtime gates for ``dj.config["strict_provenance"]``. + +When the flag is enabled, this module's context (set by ``AutoPopulate._populate_one``) +tracks which tables and primary key the currently-executing ``make()`` is +allowed to read and write. The read gate in :func:`assert_read_allowed` +fires inside ``QueryExpression.cursor``; the write gate in +:func:`assert_write_allowed` fires inside ``Table.insert``. + +The contract is documented in +``datajoint-docs/src/reference/specs/provenance.md`` §3. + +Implementation note: the active-make context is stored in a +``contextvars.ContextVar`` so it propagates correctly across threads +that share the parent's context (e.g. the populate-in-subprocess path +which uses ``multiprocessing`` workers, each of which inherits its +parent's contextvar binding at fork time). +""" + +from __future__ import annotations + +from contextvars import ContextVar +from typing import TYPE_CHECKING, Optional, Tuple + +from .errors import DataJointError + +if TYPE_CHECKING: + from .table import Table + + +# Active context: (the target table, the set of allowed full table names, the current key dict) +_active_strict_make: ContextVar[Optional[Tuple["Table", frozenset[str], dict]]] = ContextVar( + "_dj_active_strict_make", default=None +) + + +def push_strict_make_context(target: "Table", allowed_tables: frozenset[str], key: dict): + """ + Push a strict-make context for the duration of one ``make()`` invocation. + + Returns a token that the caller must pass to :func:`pop_strict_make_context` + in a ``finally`` block. + """ + return _active_strict_make.set((target, allowed_tables, key)) + + +def pop_strict_make_context(token) -> None: + """Pop the strict-make context using a token from :func:`push_strict_make_context`.""" + _active_strict_make.reset(token) + + +def get_active_context(): + """Return the currently-active strict-make context, or None.""" + return _active_strict_make.get() + + +def _base_tables(query_expression) -> set[str]: + """ + Return the set of base-table SQL names that a QueryExpression reads from. + + For a single-table expression (FreeTable / Table / restricted variants), + returns ``{full_table_name}``. For compound expressions (joins, + projections of joins), traverses ``support`` recursively. + """ + # FreeTable / Table: has full_table_name directly + ftn = getattr(query_expression, "full_table_name", None) + if isinstance(ftn, str): + return {ftn} + + bases: set[str] = set() + support = getattr(query_expression, "_support", None) or [] + for s in support: + if isinstance(s, str): + # Direct table name in the support list + bases.add(s) + else: + # Subquery — recurse + bases.update(_base_tables(s)) + return bases + + +def assert_read_allowed(query_expression) -> None: + """ + Verify a fetch is allowed under the active strict-make context. + + Called from ``QueryExpression.cursor`` before SQL is issued. No-op when + no strict-make context is active (i.e. outside ``make()`` or when + ``strict_provenance`` is False). + + Allowed reads: + + - Any table in the active context's ``allowed_tables`` set. The set is + built from ``self.upstream`` (the ancestor graph) plus the target + table and its Parts. + + Anything else raises ``DataJointError``. + + Known limitation (will sharpen in a follow-up): the check does not + distinguish reads that came *through* ``self.upstream`` from reads of + the same ancestor via a direct expression. Both are allowed if the + table is in the allowed set. The intent is to catch reads from + *undeclared* dependencies; tightening the "must come through + ``self.upstream``" path requires propagating an attribution marker + through QueryExpression composition and is deferred. + """ + ctx = _active_strict_make.get() + if ctx is None: + return # strict mode off, or outside make() + + _target, allowed_tables, _key = ctx + bases = _base_tables(query_expression) + if not bases: + return # nothing to check (e.g. dj.U expressions) + + disallowed = bases - allowed_tables + if disallowed: + raise DataJointError( + f"strict_provenance=True: read from undeclared table(s) " + f"{sorted(disallowed)} is not permitted inside make(). " + f"Use self.upstream[T] for declared ancestors, or declare a " + f"foreign-key dependency on the table you want to read." + ) + + +def assert_write_allowed(target_table, rows) -> None: + """ + Verify an insert is allowed under the active strict-make context. + + Called from ``Table.insert`` after the existing ``_allow_insert`` check. + No-op when no strict-make context is active. + + Allowed writes: + + - Target is the current ``make()`` target (``self``) or one of its Part + tables. + - Every row's primary-key columns that overlap with the current ``key`` + must equal ``key``'s values. + + Anything else raises ``DataJointError``. + """ + ctx = _active_strict_make.get() + if ctx is None: + return + + make_target, _allowed_tables, key = ctx + + # 1. Target must be `make_target` (self) or one of its Parts. + target_name = getattr(target_table, "full_table_name", None) + target_set = {make_target.full_table_name} + # Collect Part tables of make_target via class __dict__ (not dir/getattr, + # which would trigger descriptors like the _JobsDescriptor). + from .user_tables import Part # local import to avoid circular dep + + for cls in type(make_target).__mro__: + for attr_name, attr in cls.__dict__.items(): + if attr_name.startswith("_"): + continue + if isinstance(attr, type) and issubclass(attr, Part): + try: + part_ftn = attr().full_table_name + target_set.add(part_ftn) + except Exception: + pass + + if target_name not in target_set: + raise DataJointError( + f"strict_provenance=True: insert into {target_name!r} is not permitted " + f"inside make() for {make_target.full_table_name!r}. Only the target " + f"table and its Part tables may be written." + ) + + # 2. Each row's key columns that overlap with the current key must match. + if isinstance(rows, dict): + _check_row_key(rows, key) + else: + try: + for row in rows: + if isinstance(row, dict): + _check_row_key(row, key) + # Non-dict rows (tuples, etc.) bypass — older API; can't check. + except TypeError: + pass # not iterable; let downstream code handle + + +def _check_row_key(row: dict, current_key: dict) -> None: + """Raise if any row attribute overlapping with the current key has a different value.""" + for k, v in current_key.items(): + if k in row and row[k] != v: + raise DataJointError( + f"strict_provenance=True: inserted row's {k!r}={row[k]!r} does not " + f"match the current make() key's {k!r}={v!r}. Inserts must be " + f"consistent with the key being populated." + ) diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index 7a035f6d8..6ae23478b 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -69,6 +69,7 @@ "database.database_prefix": "DJ_DATABASE_PREFIX", "database.create_tables": "DJ_CREATE_TABLES", "loglevel": "DJ_LOG_LEVEL", + "strict_provenance": "DJ_STRICT_PROVENANCE", "display.diagram_direction": "DJ_DIAGRAM_DIRECTION", } @@ -361,6 +362,16 @@ class Config(BaseSettings): "*New in 2.2.3.*", ) + strict_provenance: bool = Field( + default=False, + validation_alias="DJ_STRICT_PROVENANCE", + description="If True, enforces the upstream-only convention inside make(): " + "reads must go through self.upstream[Ancestor], writes must target self " + "or self's Part tables with primary keys consistent with the current key. " + "Off by default; opt-in for deployments that need runtime provenance " + "guarantees backing downstream lineage / CDC tooling. *New in 2.3.*", + ) + # Cache path for query results query_cache: Path | None = None diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 7f8cbaf70..944bb1b63 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -797,6 +797,12 @@ def insert( " To override, set keyword argument allow_direct_insert=True." ) + # Strict-provenance write gate. No-op outside make() or when the + # config flag is off. See src/datajoint/provenance.py. + from .provenance import assert_write_allowed + + assert_write_allowed(self, rows) + if inspect.isclass(rows) and issubclass(rows, QueryExpression): rows = rows() # instantiate if a class if isinstance(rows, QueryExpression): diff --git a/tests/integration/test_strict_provenance.py b/tests/integration/test_strict_provenance.py new file mode 100644 index 000000000..ce3a0e5b9 --- /dev/null +++ b/tests/integration/test_strict_provenance.py @@ -0,0 +1,244 @@ +""" +Integration tests for ``dj.config["strict_provenance"]`` (#1425). + +Strict mode gates reads (``QueryExpression.cursor``) and writes +(``Table.insert``) inside ``make()`` to the declared upstream graph +and the target table + its Parts. Off by default; opt-in. +""" + +import pytest + +import datajoint as dj +from datajoint import DataJointError + + +@pytest.fixture +def strict_mode(connection_test): + """Enable strict_provenance for the duration of one test.""" + config = connection_test._config + previous = config.get("strict_provenance", False) + config["strict_provenance"] = True + try: + yield + finally: + config["strict_provenance"] = previous + + +def test_strict_compliant_make_passes(prefix, connection_test, strict_mode): + """A make() that reads via self.upstream and writes to self with key consistency runs cleanly.""" + schema = dj.Schema(f"{prefix}_strict_compliant", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + --- + name : varchar(64) + """ + contents = [(1, "alice"), (2, "bob")] + + @schema + class Greeting(dj.Computed): + definition = """ + -> Subject + --- + greeting : varchar(128) + """ + + def make(self, key): + name = self.upstream[Subject].fetch1("name") + self.insert1({**key, "greeting": f"Hello, {name}!"}) + + Greeting.populate() + assert (Greeting & {"subject_id": 1}).fetch1("greeting") == "Hello, alice!" + assert (Greeting & {"subject_id": 2}).fetch1("greeting") == "Hello, bob!" + + +def test_strict_blocks_read_from_undeclared_table(prefix, connection_test, strict_mode): + """Reading from a table NOT in the trace's ancestor set raises.""" + schema = dj.Schema(f"{prefix}_strict_undeclared", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class Unrelated(dj.Lookup): + definition = """ + u_id : int32 + --- + secret : varchar(64) + """ + contents = [(42, "should-not-read")] + + captured: list[Exception] = [] + + @schema + class Bad(dj.Computed): + definition = """ + -> Subject + --- + val : int32 + """ + + def make(self, key): + try: + Unrelated.fetch() # not in declared upstream of Bad + except DataJointError as e: + captured.append(e) + # Insert anyway so populate doesn't fail + self.insert1({**key, "val": 0}) + + Bad.populate() + assert len(captured) == 1 + assert "strict_provenance" in str(captured[0]).lower() + assert "undeclared" in str(captured[0]).lower() + + +def test_strict_blocks_write_to_other_table(prefix, connection_test, strict_mode): + """Writing into a table other than self / self.Parts raises.""" + schema = dj.Schema(f"{prefix}_strict_other_target", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class AuditLog(dj.Manual): + definition = """ + log_id : int32 + --- + event : varchar(64) + """ + + captured: list[Exception] = [] + + @schema + class Derived(dj.Computed): + definition = """ + -> Subject + --- + val : int32 + """ + + def make(self, key): + try: + AuditLog.insert1({"log_id": 1, "event": "side-effect"}, allow_direct_insert=True) + except DataJointError as e: + captured.append(e) + self.insert1({**key, "val": 1}) + + Derived.populate() + assert len(captured) == 1 + assert "strict_provenance" in str(captured[0]).lower() + assert "not permitted" in str(captured[0]).lower() + + +def test_strict_blocks_write_with_mismatched_key(prefix, connection_test, strict_mode): + """Writing a row whose PK columns disagree with the current key raises.""" + schema = dj.Schema(f"{prefix}_strict_key_mismatch", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,), (2,)] + + captured: list[Exception] = [] + + @schema + class Wrong(dj.Computed): + definition = """ + -> Subject + --- + val : int32 + """ + + def make(self, key): + try: + # Try to insert a row for a DIFFERENT subject than the current key + bogus_key = {"subject_id": 99} + self.insert1({**bogus_key, "val": 0}) + except DataJointError as e: + captured.append(e) + # Insert correctly to let populate complete + self.insert1({**key, "val": 1}) + + Wrong.populate() + assert len(captured) == 2 # fires for both subjects + assert all("does not match the current make() key" in str(e) for e in captured) + + +def test_strict_writes_to_part_table_pass(prefix, connection_test, strict_mode): + """Writing into self.Parts (with key consistency) is allowed.""" + schema = dj.Schema(f"{prefix}_strict_parts", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class Master(dj.Computed): + definition = """ + -> Subject + --- + summary : varchar(32) + """ + + class Bin(dj.Part): + definition = """ + -> master + bin_id : int32 + --- + energy : float64 + """ + + def make(self, key): + self.insert1({**key, "summary": "ok"}) + self.Bin.insert([{**key, "bin_id": i, "energy": float(i)} for i in range(3)]) + + Master.populate() + assert (Master & {"subject_id": 1}).fetch1("summary") == "ok" + assert len(Master.Bin & {"subject_id": 1}) == 3 + + +def test_strict_off_by_default_no_change(prefix, connection_test): + """With strict_provenance unset (default False), existing patterns work unchanged.""" + schema = dj.Schema(f"{prefix}_strict_default_off", connection=connection_test) + + @schema + class Subject(dj.Lookup): + definition = """ + subject_id : int32 + """ + contents = [(1,)] + + @schema + class DerivedLegacy(dj.Computed): + definition = """ + -> Subject + --- + val : int32 + """ + + def make(self, key): + # Direct ancestor fetch — would be flagged in strict mode (read from + # undeclared, but Subject IS an ancestor — actually allowed under + # the current "table in allowed set" rule even in strict mode). + # In default-off mode, this must work either way. + (Subject & key).fetch1("subject_id") + self.insert1({**key, "val": 0}) + + # No strict_mode fixture — default-off + DerivedLegacy.populate() + assert (DerivedLegacy & {"subject_id": 1}).fetch1("val") == 0