Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,34 @@ def _populate1(

self._upstream = Diagram.trace(self & dict(key))

# If strict_provenance is on, push the active-make context so the
# runtime gates in expression.cursor / table.insert can check this
# make()'s reads and writes. The context is popped in the finally
# block below.
strict_token = None
if self.connection._config.get("strict_provenance", False):
from .provenance import push_strict_make_context
from .user_tables import Part

allowed_tables = set(self._upstream._cascade_restrictions.keys()) | {self.full_table_name}
# Add Part tables of self to the allowed set. Use class __dict__
# (not dir/getattr) to avoid triggering descriptors like the
# _JobsDescriptor that lazy-declares the ~~ job table.
for cls in type(self).__mro__:
for attr_name, attr in cls.__dict__.items():
if attr_name.startswith("_"):
continue
if isinstance(attr, type) and issubclass(attr, Part):
# Instantiate to get full_table_name resolved against
# this schema. The Part class is already attached via
# @schema decoration of the master.
try:
part_ftn = attr().full_table_name
allowed_tables.add(part_ftn)
except Exception:
pass
strict_token = push_strict_make_context(self, frozenset(allowed_tables), dict(key))

try:
if not is_generator:
make(dict(key), **(make_kwargs or {}))
Expand Down Expand Up @@ -719,6 +747,11 @@ def _populate1(
# access raises a clear error rather than silently using a
# stale trace from the previous make() call.
self._upstream = None
# Pop the strict-make context, if any.
if strict_token is not None:
from .provenance import pop_strict_make_context

pop_strict_make_context(strict_token)

def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int]:
"""
Expand Down
6 changes: 6 additions & 0 deletions src/datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,12 @@ def cursor(self, as_dict=False):
cursor
Database query cursor.
"""
# Strict-provenance read gate. No-op outside make() or when the
# config flag is off. See src/datajoint/provenance.py.
from .provenance import assert_read_allowed

assert_read_allowed(self)

sql = self.make_sql()
logger.debug(sql)
return self.connection.query(sql, as_dict=as_dict)
Expand Down
193 changes: 193 additions & 0 deletions src/datajoint/provenance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""
Runtime gates for ``dj.config["strict_provenance"]``.

When the flag is enabled, this module's context (set by ``AutoPopulate._populate1``)
tracks which tables and primary key the currently-executing ``make()`` is
allowed to read and write. The read gate in :func:`assert_read_allowed`
fires inside ``QueryExpression.cursor``; the write gate in
:func:`assert_write_allowed` fires inside ``Table.insert``.

The contract is documented in
``datajoint-docs/src/reference/specs/provenance.md`` §3.

Implementation note: the active-make context is stored in a
``contextvars.ContextVar``, so the binding is scoped to the execution
context that runs ``make()``; code running in a context that shares (or was
copied from) that context sees the same binding. On the
populate-in-subprocess path, which uses ``multiprocessing`` workers, each
worker starts with whatever contextvar binding its parent had at fork time.
"""

from __future__ import annotations

from contextvars import ContextVar
from typing import TYPE_CHECKING, Optional, Tuple

from .errors import DataJointError

if TYPE_CHECKING:
from .table import Table


# Active strict-make context, or ``None`` (the default) when no context has
# been pushed. When set by ``push_strict_make_context`` the value is a 3-tuple:
# (the target table, the set of allowed full table names, the current key dict)
_active_strict_make: ContextVar[Optional[Tuple["Table", frozenset[str], dict]]] = ContextVar(
"_dj_active_strict_make", default=None
)


def push_strict_make_context(target: "Table", allowed_tables: frozenset[str], key: dict):
    """
    Activate a strict-make context for a single ``make()`` call.

    The three arguments are stored together as one tuple
    ``(target, allowed_tables, key)`` in the module's context variable, where
    the read and write gates look them up.

    Parameters
    ----------
    target : Table
        The table whose ``make()`` is about to run.
    allowed_tables : frozenset[str]
        Full table names that may be read while the context is active.
    key : dict
        The key being populated.

    Returns
    -------
    token
        The token produced by ``ContextVar.set``. The caller must hand it to
        :func:`pop_strict_make_context` in a ``finally`` block.
    """
    context = (target, allowed_tables, key)
    token = _active_strict_make.set(context)
    return token


def pop_strict_make_context(token) -> None:
    """
    Deactivate a strict-make context.

    Parameters
    ----------
    token
        The value returned by the matching :func:`push_strict_make_context`
        call; it is passed to ``ContextVar.reset`` to restore the previous
        binding.
    """
    _active_strict_make.reset(token)
    return None


def get_active_context():
    """
    Look up the strict-make context that is active right now.

    Returns
    -------
    tuple or None
        The ``(target, allowed_tables, key)`` tuple stored by
        :func:`push_strict_make_context`, or ``None`` when nothing is active.
    """
    active = _active_strict_make.get()
    return active


def _base_tables(query_expression) -> set[str]:
"""
Return the set of base-table SQL names that a QueryExpression reads from.

For a single-table expression (FreeTable / Table / restricted variants),
returns ``{full_table_name}``. For compound expressions (joins,
projections of joins), traverses ``support`` recursively.
"""
# FreeTable / Table: has full_table_name directly
ftn = getattr(query_expression, "full_table_name", None)
if isinstance(ftn, str):
return {ftn}

bases: set[str] = set()
support = getattr(query_expression, "_support", None) or []
for s in support:
if isinstance(s, str):
# Direct table name in the support list
bases.add(s)
else:
# Subquery — recurse
bases.update(_base_tables(s))
return bases


def assert_read_allowed(query_expression) -> None:
    """
    Read gate: reject fetches from tables the active ``make()`` has not declared.

    Invoked from ``QueryExpression.cursor`` before any SQL is sent. When no
    strict-make context is active (outside ``make()``, or with
    ``strict_provenance`` off) the call returns immediately.

    A read is accepted when every base table of ``query_expression`` is in
    the active context's ``allowed_tables`` set, which the populate code
    builds from the upstream trace plus the target table and its Parts.
    Expressions with no resolvable base table are accepted as well.

    Known limitation (to be tightened in a follow-up): a declared ancestor
    may be read either through ``self.upstream`` or through a direct
    expression, and the gate cannot tell the two apart. Only reads from
    *undeclared* dependencies are caught; enforcing the "must come through
    ``self.upstream``" path needs an attribution marker carried through
    QueryExpression composition and is deferred.

    Parameters
    ----------
    query_expression
        The expression about to be executed.

    Raises
    ------
    DataJointError
        If the expression reads from a table outside the allowed set.
    """
    active = _active_strict_make.get()
    if active is None:
        # Strict mode is off, or we are not inside make().
        return

    _, permitted, _ = active
    referenced = _base_tables(query_expression)
    if not referenced:
        # Nothing to check (e.g. dj.U expressions).
        return

    undeclared = referenced - permitted
    if not undeclared:
        return

    raise DataJointError(
        f"strict_provenance=True: read from undeclared table(s) "
        f"{sorted(undeclared)} is not permitted inside make(). "
        f"Use self.upstream[T] for declared ancestors, or declare a "
        f"foreign-key dependency on the table you want to read."
    )


def _part_table_names(master) -> set[str]:
    """Return the full table names of the Part classes declared on ``master``'s class hierarchy."""
    from .user_tables import Part  # local import to avoid circular dep

    names: set[str] = set()
    # Walk class __dict__ entries (not dir/getattr), which would trigger
    # descriptors like the _JobsDescriptor.
    for cls in type(master).__mro__:
        for attr_name, attr in cls.__dict__.items():
            if attr_name.startswith("_"):
                continue
            if isinstance(attr, type) and issubclass(attr, Part):
                try:
                    names.add(attr().full_table_name)
                except Exception:
                    # Best-effort: a Part class whose name cannot be resolved
                    # is simply left out of the set.
                    continue
    return names


def assert_write_allowed(target_table, rows) -> None:
    """
    Verify an insert is allowed under the active strict-make context.

    Called from ``Table.insert`` after the existing ``_allow_insert`` check.
    No-op when no strict-make context is active.

    Allowed writes:

    - Target is the current ``make()`` target (``self``) or one of its Part
      tables.
    - Every row's primary-key columns that overlap with the current ``key``
      must equal ``key``'s values.

    Row checking is limited to what can be inspected without side effects:

    - Only ``dict`` rows are checked; other row types (tuples, etc.) bypass.
    - One-shot iterators (generators, ``map`` objects, ...) are NOT checked.
      Iterating them here would exhaust them, and the insert that follows
      would then see no rows.
    - Non-iterable ``rows`` are left for downstream code to handle.

    Parameters
    ----------
    target_table
        The table being inserted into; its ``full_table_name`` is compared
        against the make target and the make target's Part tables.
    rows
        The ``rows`` argument of ``Table.insert``, unmodified.

    Raises
    ------
    DataJointError
        If the target is not permitted, or a checked row conflicts with the
        current key.
    """
    from collections.abc import Iterator

    ctx = _active_strict_make.get()
    if ctx is None:
        return

    make_target, _allowed_tables, key = ctx

    # 1. Target must be `make_target` (self) or one of its Parts. Parts are
    #    only resolved when the target is not the make target itself, so the
    #    common case avoids instantiating every Part class on each insert.
    target_name = getattr(target_table, "full_table_name", None)
    if target_name != make_target.full_table_name and target_name not in _part_table_names(make_target):
        raise DataJointError(
            f"strict_provenance=True: insert into {target_name!r} is not permitted "
            f"inside make() for {make_target.full_table_name!r}. Only the target "
            f"table and its Part tables may be written."
        )

    # 2. Each row's key columns that overlap with the current key must match.
    if isinstance(rows, dict):
        _check_row_key(rows, key)
        return

    if isinstance(rows, Iterator):
        # One-shot iterator: checking would consume the rows before the
        # insert reads them, so leave it untouched.
        return

    try:
        row_iter = iter(rows)
    except TypeError:
        return  # not iterable; let downstream code handle

    for row in row_iter:
        if isinstance(row, dict):
            _check_row_key(row, key)
        # Non-dict rows (tuples, etc.) bypass — older API; can't check.


def _check_row_key(row: dict, current_key: dict) -> None:
"""Raise if any row attribute overlapping with the current key has a different value."""
for k, v in current_key.items():
if k in row and row[k] != v:
raise DataJointError(
f"strict_provenance=True: inserted row's {k!r}={row[k]!r} does not "
f"match the current make() key's {k!r}={v!r}. Inserts must be "
f"consistent with the key being populated."
)
11 changes: 11 additions & 0 deletions src/datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"database.database_prefix": "DJ_DATABASE_PREFIX",
"database.create_tables": "DJ_CREATE_TABLES",
"loglevel": "DJ_LOG_LEVEL",
"strict_provenance": "DJ_STRICT_PROVENANCE",
"display.diagram_direction": "DJ_DIAGRAM_DIRECTION",
}

Expand Down Expand Up @@ -361,6 +362,16 @@ class Config(BaseSettings):
"*New in 2.2.3.*",
)

strict_provenance: bool = Field(
default=False,
validation_alias="DJ_STRICT_PROVENANCE",
description="If True, enforces the upstream-only convention inside make(): "
"reads must go through self.upstream[Ancestor], writes must target self "
"or self's Part tables with primary keys consistent with the current key. "
"Off by default; opt-in for deployments that need runtime provenance "
"guarantees backing downstream lineage / CDC tooling. *New in 2.3.*",
)

# Cache path for query results
query_cache: Path | None = None

Expand Down
6 changes: 6 additions & 0 deletions src/datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,12 @@ def insert(
" To override, set keyword argument allow_direct_insert=True."
)

# Strict-provenance write gate. No-op outside make() or when the
# config flag is off. See src/datajoint/provenance.py.
from .provenance import assert_write_allowed

assert_write_allowed(self, rows)

if inspect.isclass(rows) and issubclass(rows, QueryExpression):
rows = rows() # instantiate if a class
if isinstance(rows, QueryExpression):
Expand Down
Loading
Loading