import logging
import uuid
import zlib
from abc import abstractmethod
from collections import defaultdict
from datetime import datetime
from enum import Enum
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import pendulum
import sqlalchemy as db
from dagster import check
from dagster.core.errors import (
DagsterInvariantViolationError,
DagsterRunAlreadyExists,
DagsterRunNotFoundError,
DagsterSnapshotDoesNotExist,
)
from dagster.core.events import EVENT_TYPE_TO_PIPELINE_RUN_STATUS, DagsterEvent, DagsterEventType
from dagster.core.execution.backfill import BulkActionStatus, PartitionBackfill
from dagster.core.snap import (
ExecutionPlanSnapshot,
PipelineSnapshot,
create_execution_plan_snapshot_id,
create_pipeline_snapshot_id,
)
from dagster.core.storage.tags import PARTITION_NAME_TAG, PARTITION_SET_TAG, ROOT_RUN_ID_TAG
from dagster.daemon.types import DaemonHeartbeat
from dagster.serdes import (
deserialize_as,
deserialize_json_to_dagster_namedtuple,
serialize_dagster_namedtuple,
)
from dagster.seven import JSONDecodeError
from dagster.utils import merge_dicts, utc_datetime_from_timestamp
from ..pipeline_run import JobBucket, PipelineRun, RunRecord, RunsFilter, TagBucket
from .base import RunStorage
from .migration import OPTIONAL_DATA_MIGRATIONS, REQUIRED_DATA_MIGRATIONS, RUN_PARTITIONS
from .schema import (
BulkActionsTable,
DaemonHeartbeatsTable,
InstanceInfo,
RunTagsTable,
RunsTable,
SecondaryIndexMigrationTable,
SnapshotsTable,
)
class SnapshotType(Enum):
    """Kinds of snapshot whose value is written to the snapshots table."""

    # Snapshot of a pipeline.
    PIPELINE = "PIPELINE"
    # Snapshot of an execution plan.
    EXECUTION_PLAN = "EXECUTION_PLAN"
class SqlRunStorage(RunStorage): # pylint: disable=no-init
"""Base class for SQL based run storages"""
@abstractmethod
def connect(self):
    """Return a context manager that yields a ``sqlalchemy.engine.Connection``.

    The query helpers of this class obtain their connection through this
    method, so concrete storages must provide it.
    """
@abstractmethod
def upgrade(self):
    """Bring an out-of-date instance of the storage up to date.

    Implementations should run whatever schema or data migrations are needed.
    """
def fetchall(self, query):
    """Execute ``query`` and return every row it produces.

    The result proxy is closed explicitly once the rows have been read.
    """
    with self.connect() as connection:
        proxy = connection.execute(query)
        all_rows = proxy.fetchall()
        proxy.close()

    return all_rows
def fetchone(self, query):
    """Execute ``query`` and return what the result proxy's ``fetchone`` yields.

    The result proxy is closed explicitly once the row has been read.
    """
    with self.connect() as connection:
        proxy = connection.execute(query)
        first_row = proxy.fetchone()
        proxy.close()

    return first_row
def add_run(self, pipeline_run: PipelineRun) -> PipelineRun:
    """Persist a new run row and one tag row per run tag.

    Args:
        pipeline_run (PipelineRun): The run to store.

    Returns:
        PipelineRun: The run that was passed in.

    Raises:
        DagsterSnapshotDoesNotExist: If the run names a pipeline snapshot id
            that this storage does not contain.
        DagsterRunAlreadyExists: If inserting the run row raises an
            ``IntegrityError``.
    """
    check.inst_param(pipeline_run, "pipeline_run", PipelineRun)

    snapshot_id = pipeline_run.pipeline_snapshot_id
    if snapshot_id and not self.has_pipeline_snapshot(snapshot_id):
        raise DagsterSnapshotDoesNotExist(
            "Snapshot {ss_id} does not exist in run storage".format(ss_id=snapshot_id)
        )

    run_tags = pipeline_run.tags
    has_tags = run_tags and len(run_tags) > 0

    # The partition columns mirror the partition tags when they are present.
    if has_tags:
        partition = run_tags.get(PARTITION_NAME_TAG)
        partition_set = run_tags.get(PARTITION_SET_TAG)
    else:
        partition = None
        partition_set = None

    runs_insert = RunsTable.insert().values(  # pylint: disable=no-value-for-parameter
        run_id=pipeline_run.run_id,
        pipeline_name=pipeline_run.pipeline_name,
        status=pipeline_run.status.value,
        run_body=serialize_dagster_namedtuple(pipeline_run),
        snapshot_id=snapshot_id,
        partition=partition,
        partition_set=partition_set,
    )

    with self.connect() as conn:
        try:
            conn.execute(runs_insert)
        except db.exc.IntegrityError as exc:
            raise DagsterRunAlreadyExists from exc

        if has_tags:
            tag_rows = [
                dict(run_id=pipeline_run.run_id, key=tag_key, value=tag_value)
                for tag_key, tag_value in run_tags.items()
            ]
            conn.execute(
                RunTagsTable.insert(),  # pylint: disable=no-value-for-parameter
                tag_rows,
            )

    return pipeline_run
def handle_run_event(self, run_id: str, event: DagsterEvent):
    """Update a stored run's status in response to a run-status event.

    Events whose type has no entry in ``EVENT_TYPE_TO_PIPELINE_RUN_STATUS`` are
    ignored, as are events for runs that are not in storage.

    Args:
        run_id (str): Id of the run the event belongs to.
        event (DagsterEvent): The event to apply.
    """
    check.str_param(run_id, "run_id")
    check.inst_param(event, "event", DagsterEvent)

    event_type = event.event_type
    if event_type not in EVENT_TYPE_TO_PIPELINE_RUN_STATUS:
        return

    run = self.get_run_by_id(run_id)
    if not run:
        # TODO log?
        return

    new_status = EVENT_TYPE_TO_PIPELINE_RUN_STATUS[event_type]
    has_stats_columns = self.has_run_stats_index_cols()

    # consider changing the `handle_run_event` signature to get timestamp off of the
    # EventLogEntry instead of the DagsterEvent, for consistency
    now = pendulum.now("UTC")

    # start_time / end_time are only written when the table has those columns.
    timing_values = {}
    if has_stats_columns:
        if event_type == DagsterEventType.PIPELINE_START:
            timing_values["start_time"] = now.timestamp()
        terminal_types = {
            DagsterEventType.PIPELINE_CANCELED,
            DagsterEventType.PIPELINE_FAILURE,
            DagsterEventType.PIPELINE_SUCCESS,
        }
        if event_type in terminal_types:
            timing_values["end_time"] = now.timestamp()

    with self.connect() as conn:
        update_statement = (
            RunsTable.update()  # pylint: disable=no-value-for-parameter
            .where(RunsTable.c.run_id == run_id)
            .values(
                status=new_status.value,
                run_body=serialize_dagster_namedtuple(run.with_status(new_status)),
                update_timestamp=now,
                **timing_values,
            )
        )
        conn.execute(update_statement)
def _row_to_run(self, row: Tuple) -> PipelineRun:
    """Deserialize the run stored in the first column of ``row``."""
    serialized_run = row[0]
    return deserialize_as(serialized_run, PipelineRun)
def _rows_to_runs(self, rows: Iterable[Tuple]) -> List[PipelineRun]:
return list(map(self._row_to_run, rows))
def _add_cursor_limit_to_query(
    self,
    query,
    cursor: Optional[str],
    limit: Optional[int],
    order_by: Optional[str],
    ascending: Optional[bool],
):
    """Apply cursor, limit and ordering pagination arguments to ``query``.

    Args:
        query: The select over ``RunsTable`` to paginate.
        cursor (Optional[str]): Run id marking the page boundary. Only rows on
            the far side of that run's storage id, in the requested direction,
            are kept.
        limit (Optional[int]): Maximum number of rows; a falsy value means no limit.
        order_by (Optional[str]): Name of the ``RunsTable`` column to sort on;
            defaults to the ``id`` column.
        ascending (Optional[bool]): Sort ascending when truthy, descending otherwise.

    Returns:
        The query with the pagination clauses applied.
    """
    if cursor:
        cursor_query = db.select([RunsTable.c.id]).where(RunsTable.c.run_id == cursor)
        # The cursor comparison has to follow the sort direction: when paging
        # ascending, the next page holds ids *greater* than the cursor's id.
        if ascending:
            query = query.where(RunsTable.c.id > cursor_query)
        else:
            query = query.where(RunsTable.c.id < cursor_query)

    if limit:
        query = query.limit(limit)

    sorting_column = getattr(RunsTable.c, order_by) if order_by else RunsTable.c.id
    direction = db.asc if ascending else db.desc
    return query.order_by(direction(sorting_column))
def _add_filters_to_query(self, query, filters: RunsFilter):
    """Add a WHERE clause to ``query`` for every field set on ``filters``.

    For tag filters the query is additionally grouped by run and restricted, via
    HAVING, to runs whose number of matching tag rows equals the number of
    requested tags. The caller is expected to have joined ``RunTagsTable``.

    Args:
        query: The select to constrain.
        filters (RunsFilter): The filter values to apply.

    Returns:
        The constrained query.
    """
    check.inst_param(filters, "filters", RunsFilter)
    runs = RunsTable.c

    if filters.run_ids:
        query = query.where(runs.run_id.in_(filters.run_ids))

    if filters.job_name:
        query = query.where(runs.pipeline_name == filters.job_name)

    if filters.mode:
        query = query.where(runs.mode == filters.mode)

    if filters.statuses:
        status_values = [status.value for status in filters.statuses]
        query = query.where(runs.status.in_(status_values))

    if filters.tags:
        # A run matches only if every requested (key, value) pair is present:
        # OR the pairs together, then require one matching row per pair.
        tag_clauses = [
            db.and_(RunTagsTable.c.key == key, RunTagsTable.c.value == value)
            for key, value in filters.tags.items()
        ]
        query = (
            query.where(db.or_(*tag_clauses))
            .group_by(runs.run_body, runs.id)
            .having(db.func.count(runs.run_id) == len(filters.tags))
        )

    if filters.snapshot_id:
        query = query.where(runs.snapshot_id == filters.snapshot_id)

    if filters.updated_after:
        query = query.where(runs.update_timestamp > filters.updated_after)

    if filters.created_before:
        query = query.where(runs.create_timestamp < filters.created_before)

    return query
def _runs_query(
    self,
    filters: Optional[RunsFilter] = None,
    cursor: Optional[str] = None,
    limit: Optional[int] = None,
    columns: Optional[List[str]] = None,
    order_by: Optional[str] = None,
    ascending: bool = False,
    bucket_by: Optional[Union[JobBucket, TagBucket]] = None,
):
    """Build the select over ``RunsTable`` shared by the run-fetching methods.

    Args:
        filters (Optional[RunsFilter]): Filter values; defaults to an empty filter.
        cursor (Optional[str]): Run id used as the pagination cursor.
        limit (Optional[int]): Maximum number of rows.
        columns (Optional[List[str]]): ``RunsTable`` column names to select;
            defaults to ``["run_body"]``.
        order_by (Optional[str]): Column name to sort on.
        ascending (bool): Sort direction.
        bucket_by (Optional[Union[JobBucket, TagBucket]]): When given, the query
            is delegated to ``_bucketed_runs_query``; it cannot be combined with
            ``limit`` or ``cursor``.

    Returns:
        The constructed select.
    """
    filters = check.opt_inst_param(filters, "filters", RunsFilter, default=RunsFilter())
    check.opt_str_param(cursor, "cursor")
    check.opt_int_param(limit, "limit")
    check.opt_list_param(columns, "columns")
    check.opt_str_param(order_by, "order_by")
    check.opt_bool_param(ascending, "ascending")

    selected_names = ["run_body"] if columns is None else columns

    if bucket_by:
        if limit or cursor:
            check.failed("cannot specify bucket_by and limit/cursor at the same time")
        return self._bucketed_runs_query(bucket_by, filters, selected_names, order_by, ascending)

    selected_columns = [getattr(RunsTable.c, name) for name in selected_names]

    # Tag filters need the tags table joined in; otherwise select from runs alone.
    source = RunsTable
    if filters.tags:
        source = RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)

    filtered_query = self._add_filters_to_query(
        db.select(selected_columns).select_from(source), filters
    )
    return self._add_cursor_limit_to_query(filtered_query, cursor, limit, order_by, ascending)
def _bucket_rank_column(self, bucket_by, order_by, ascending):
    """Return a ``rank()`` window column, labelled ``"rank"``, for bucketed queries.

    Rows are partitioned by pipeline name for a ``JobBucket`` and by tag value
    for a ``TagBucket``, and ranked by ``order_by`` (default: the ``id`` column)
    in the requested direction.
    """
    check.inst_param(bucket_by, "bucket_by", (JobBucket, TagBucket))
    check.invariant(
        self.supports_bucket_queries, "Bucket queries are not supported by this storage layer"
    )

    if order_by:
        sort_column = getattr(RunsTable.c, order_by)
    else:
        sort_column = RunsTable.c.id
    ordering = db.asc(sort_column) if ascending else db.desc(sort_column)

    if isinstance(bucket_by, JobBucket):
        partition_column = RunsTable.c.pipeline_name
    else:
        partition_column = RunTagsTable.c.value

    rank_window = db.func.rank().over(order_by=ordering, partition_by=partition_column)
    return rank_window.label("rank")
def _bucketed_runs_query(
self,
bucket_by: Union[JobBucket, TagBucket],
filters: RunsFilter,
columns: List[str],
order_by: Optional[str] = None,
ascending: bool = False,
):
"""Build a query that returns runs ranked within buckets.

Each selected row carries a rank computed by ``_bucket_rank_column``. The
ranked query is wrapped in a subquery so that the outer query can return only
the requested ``columns``, order by rank, and keep rows with
``rank <= bucket_by.bucket_limit`` when a bucket limit is set.

Args:
    bucket_by: A ``JobBucket`` (restricts to ``bucket_by.job_names``) or a
        ``TagBucket`` (restricts to ``bucket_by.tag_key`` / ``bucket_by.tag_values``).
    filters: Additional filters, applied through ``_add_filters_to_query``.
    columns: Names of the ``RunsTable`` columns to return.
    order_by: Column name used for ranking, forwarded to ``_bucket_rank_column``.
    ascending: Ranking direction, forwarded to ``_bucket_rank_column``.

Returns:
    The constructed select.
"""
bucket_rank = self._bucket_rank_column(bucket_by, order_by, ascending)
query_columns = [getattr(RunsTable.c, column) for column in columns] + [bucket_rank]
if isinstance(bucket_by, JobBucket):
# bucketing by job
base_query = (
db.select(query_columns)
.select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
if filters.tags
else RunsTable
)
.where(RunsTable.c.pipeline_name.in_(bucket_by.job_names))
)
base_query = self._add_filters_to_query(base_query, filters)
elif not filters.tags:
# bucketing by tag, no tag filters
base_query = (
db.select(query_columns)
.select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
)
.where(RunTagsTable.c.key == bucket_by.tag_key)
.where(RunTagsTable.c.value.in_(bucket_by.tag_values))
)
base_query = self._add_filters_to_query(base_query, filters)
else:
# there are tag filters as well as tag buckets, so we have to apply the tag filters in
# a separate join
# The filtered run ids are computed in their own aliased subquery and joined back in.
filtered_query = db.select([RunsTable.c.run_id]).select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
)
filtered_query = self._add_filters_to_query(filtered_query, filters)
filtered_query = filtered_query.alias("filtered_query")
base_query = (
db.select(query_columns)
.select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id).join(
filtered_query, RunsTable.c.run_id == filtered_query.c.run_id
)
)
.where(RunTagsTable.c.key == bucket_by.tag_key)
.where(RunTagsTable.c.value.in_(bucket_by.tag_values))
)
subquery = base_query.alias("subquery")
# select all the columns, but skip the bucket_rank column, which is only used for applying
# the limit / order
subquery_columns = [getattr(subquery.c, column) for column in columns]
query = db.select(subquery_columns).order_by(subquery.c.rank.asc())
if bucket_by.bucket_limit:
query = query.where(subquery.c.rank <= bucket_by.bucket_limit)
return query
def get_runs(
    self,
    filters: Optional[RunsFilter] = None,
    cursor: Optional[str] = None,
    limit: Optional[int] = None,
    bucket_by: Optional[Union[JobBucket, TagBucket]] = None,
) -> List[PipelineRun]:
    """Fetch the runs selected by the given filter, pagination and bucket arguments.

    All arguments are forwarded to ``_runs_query``; the resulting rows are
    converted with ``_rows_to_runs``.
    """
    runs_query = self._runs_query(
        filters=filters, cursor=cursor, limit=limit, bucket_by=bucket_by
    )
    return self._rows_to_runs(self.fetchall(runs_query))
def get_runs_count(self, filters: Optional[RunsFilter] = None) -> int:
    """Count the runs that match ``filters``.

    Args:
        filters (Optional[RunsFilter]): Filter values; ``None`` counts every run.

    Returns:
        int: The number of matching runs.
    """
    # We use an alias here because Postgres requires subqueries to be
    # aliased. A single alias is enough; aliasing the alias again adds nothing.
    subquery = self._runs_query(filters=filters).alias("subquery")

    count_query = db.select([db.func.count()]).select_from(subquery)
    rows = self.fetchall(count_query)
    return rows[0][0]
def get_run_by_id(self, run_id: str) -> Optional[PipelineRun]:
    """Look up a single run by its id.

    Args:
        run_id (str): The id of the run

    Returns:
        Optional[PipelineRun]: The deserialized run, or ``None`` when no row matches.
    """
    check.str_param(run_id, "run_id")

    matching_rows = self.fetchall(
        db.select([RunsTable.c.run_body]).where(RunsTable.c.run_id == run_id)
    )
    if not len(matching_rows):
        return None

    return deserialize_as(matching_rows[0][0], PipelineRun)
def get_run_records(
    self,
    filters: Optional[RunsFilter] = None,
    limit: Optional[int] = None,
    order_by: Optional[str] = None,
    ascending: bool = False,
    cursor: Optional[str] = None,
    bucket_by: Optional[Union[JobBucket, TagBucket]] = None,
) -> List[RunRecord]:
    """Fetch runs together with their storage metadata as ``RunRecord`` objects.

    ``start_time`` and ``end_time`` are only selected when
    ``has_run_stats_index_cols`` reports those columns exist; otherwise the
    records carry ``None`` for both.
    """
    filters = check.opt_inst_param(filters, "filters", RunsFilter, default=RunsFilter())
    check.opt_int_param(limit, "limit")

    # only fetch columns we use to build RunRecord
    record_columns = ["id", "run_body", "create_timestamp", "update_timestamp"]
    if self.has_run_stats_index_cols():
        record_columns.extend(["start_time", "end_time"])

    query = self._runs_query(
        filters=filters,
        limit=limit,
        columns=record_columns,
        order_by=order_by,
        ascending=ascending,
        cursor=cursor,
        bucket_by=bucket_by,
    )

    def _to_record(row) -> RunRecord:
        # Validate each column's type while assembling the record.
        return RunRecord(
            storage_id=check.int_param(row["id"], "id"),
            pipeline_run=deserialize_as(
                check.str_param(row["run_body"], "run_body"), PipelineRun
            ),
            create_timestamp=check.inst(row["create_timestamp"], datetime),
            update_timestamp=check.inst(row["update_timestamp"], datetime),
            start_time=(
                check.opt_inst(row["start_time"], float) if "start_time" in row else None
            ),
            end_time=check.opt_inst(row["end_time"], float) if "end_time" in row else None,
        )

    return [_to_record(row) for row in self.fetchall(query)]
def get_run_tags(self) -> List[Tuple[str, Set[str]]]:
    """Return every tag key in storage with the set of values seen for it.

    Returns:
        List[Tuple[str, Set[str]]]: ``(key, values)`` pairs sorted by key.
    """
    # Plain DISTINCT over the two selected columns; passing columns to
    # ``distinct()`` is the PostgreSQL-only DISTINCT ON form.
    query = db.select([RunTagsTable.c.key, RunTagsTable.c.value]).distinct()

    values_by_key = defaultdict(set)
    for row in self.fetchall(query):
        values_by_key[row[0]].add(row[1])

    return sorted(values_by_key.items(), key=lambda item: item[0])
def add_run_tags(self, run_id: str, new_tags: Dict[str, str]):
    """Merge ``new_tags`` into a stored run's tags.

    The run body and partition columns are rewritten with the merged tags; tag
    rows are updated for keys the run already had and inserted for new keys.

    Args:
        run_id (str): Id of the run to tag.
        new_tags (Dict[str, str]): Tags to add or overwrite.

    Raises:
        DagsterRunNotFoundError: If no run with ``run_id`` is stored.
    """
    check.str_param(run_id, "run_id")
    check.dict_param(new_tags, "new_tags", key_type=str, value_type=str)

    run = self.get_run_by_id(run_id)
    if not run:
        raise DagsterRunNotFoundError(
            f"Run {run_id} was not found in instance.", invalid_run_id=run_id
        )

    current_tags = run.tags if run.tags else {}
    all_tags = merge_dicts(current_tags, new_tags)

    with self.connect() as conn:
        run_update = (
            RunsTable.update()  # pylint: disable=no-value-for-parameter
            .where(RunsTable.c.run_id == run_id)
            .values(
                run_body=serialize_dagster_namedtuple(run.with_tags(all_tags)),
                partition=all_tags.get(PARTITION_NAME_TAG),
                partition_set=all_tags.get(PARTITION_SET_TAG),
                update_timestamp=pendulum.now("UTC"),
            )
        )
        conn.execute(run_update)

        # Split the incoming keys into ones to overwrite and ones to insert.
        current_keys = set(current_tags.keys())
        incoming_keys = set(new_tags.keys())
        keys_to_update = current_keys & incoming_keys
        keys_to_insert = incoming_keys.difference(keys_to_update)

        for tag_key in keys_to_update:
            conn.execute(
                RunTagsTable.update()  # pylint: disable=no-value-for-parameter
                .where(db.and_(RunTagsTable.c.run_id == run_id, RunTagsTable.c.key == tag_key))
                .values(value=new_tags[tag_key])
            )

        if keys_to_insert:
            conn.execute(
                RunTagsTable.insert(),  # pylint: disable=no-value-for-parameter
                [
                    dict(run_id=run_id, key=tag_key, value=new_tags[tag_key])
                    for tag_key in keys_to_insert
                ],
            )
def get_run_group(self, run_id: str) -> Optional[Tuple[str, Iterable[PipelineRun]]]:
    """Return the run group that ``run_id`` belongs to.

    Args:
        run_id (str): Id of any run in the group.

    Returns:
        Tuple[str, Iterable[PipelineRun]]: The root run id and a list made of
        the root run followed by the runs tagged with that root run id.

    Raises:
        DagsterRunNotFoundError: If the run, or the root run it points to, is
            not in storage.
    """
    check.str_param(run_id, "run_id")
    pipeline_run = self.get_run_by_id(run_id)
    if not pipeline_run:
        raise DagsterRunNotFoundError(
            f"Run {run_id} was not found in instance.", invalid_run_id=run_id
        )

    # find root_run
    root_run_id = pipeline_run.root_run_id if pipeline_run.root_run_id else pipeline_run.run_id
    root_run = self.get_run_by_id(root_run_id)
    if not root_run:
        # Report the id that was looked up; ``root_run`` itself is falsy here.
        raise DagsterRunNotFoundError(
            f"Run id {root_run_id} set as root run id for run {run_id} was not found in instance.",
            invalid_run_id=root_run_id,
        )

    # root_run_id to run_id 1:1 mapping
    # https://github.com/dagster-io/dagster/issues/2495
    # Note: we currently use tags to persist the run group info
    root_to_run = (
        db.select(
            [RunTagsTable.c.value.label("root_run_id"), RunTagsTable.c.run_id.label("run_id")]
        )
        .where(
            db.and_(RunTagsTable.c.key == ROOT_RUN_ID_TAG, RunTagsTable.c.value == root_run_id)
        )
        .alias("root_to_run")
    )

    # get run group
    run_group_query = (
        db.select([RunsTable.c.run_body])
        .select_from(
            root_to_run.join(
                RunsTable,
                root_to_run.c.run_id == RunsTable.c.run_id,
                isouter=True,
            )
        )
        .alias("run_group")
    )

    with self.connect() as conn:
        res = conn.execute(run_group_query)
        run_group = self._rows_to_runs(res)

    return (root_run_id, [root_run] + run_group)
def get_run_groups(
self,
filters: Optional[RunsFilter] = None,
cursor: Optional[str] = None,
limit: Optional[int] = None,
) -> Dict[str, Dict[str, Union[Iterable[PipelineRun], int]]]:
"""Return the runs matching the arguments, grouped under their root run id.

The query selects the runs that ``_runs_query`` would return for the same
arguments plus their root runs, along with a count of the tag rows that name
each of them as root. Rows are then grouped in Python by
``pipeline_run.get_root_run_id()``, falling back to the run's own id.

Returns:
    A dict keyed by root run id. Each value has a ``"runs"`` list and a
    ``"count"`` read from ``root_run_id_to_count``, which stores
    ``child_counts + 1`` under a run's own id and yields 0 for ids that were
    never stored.
"""
# The runs that would be returned by calling RunStorage.get_runs with the same arguments
runs = self._runs_query(
filters=filters, cursor=cursor, limit=limit, columns=["run_body", "run_id"]
).alias("runs")
# Gets us the run_id and associated root_run_id for every run in storage that is a
# descendant run of some root
#
# pseudosql:
# with all_descendant_runs as (
# select *
# from run_tags
# where key = @ROOT_RUN_ID_TAG
# )
all_descendant_runs = (
db.select([RunTagsTable])
.where(RunTagsTable.c.key == ROOT_RUN_ID_TAG)
.alias("all_descendant_runs")
)
# Augment the runs in our query, for those runs that are the descendant of some root run,
# with the root_run_id
#
# pseudosql:
#
# with runs_augmented as (
# select
# runs.run_id as run_id,
# all_descendant_runs.value as root_run_id
# from runs
# left outer join all_descendant_runs
# on all_descendant_runs.run_id = runs.run_id
# )
# NOTE(review): the ON clause below references RunsTable.c.run_id rather than
# runs.c.run_id (the alias being joined) -- confirm this is intentional.
runs_augmented = (
db.select(
[
runs.c.run_id.label("run_id"),
all_descendant_runs.c.value.label("root_run_id"),
]
)
.select_from(
runs.join(
all_descendant_runs,
all_descendant_runs.c.run_id == RunsTable.c.run_id,
isouter=True,
)
)
.alias("runs_augmented")
)
# Get all the runs our query will return. This includes runs as well as their root runs.
#
# pseudosql:
#
# with runs_and_root_runs as (
# select runs.run_id as run_id
# from runs, runs_augmented
# where
# runs.run_id = runs_augmented.run_id or
# runs.run_id = runs_augmented.root_run_id
# )
runs_and_root_runs = (
db.select([RunsTable.c.run_id.label("run_id")])
.select_from(runs_augmented)
.where(
db.or_(
RunsTable.c.run_id == runs_augmented.c.run_id,
RunsTable.c.run_id == runs_augmented.c.root_run_id,
)
)
.distinct(RunsTable.c.run_id)
).alias("runs_and_root_runs")
# We count the descendants of all of the runs in our query that are roots so that
# we can accurately display when a root run has more descendants than are returned by this
# query and afford a drill-down. This might be an unnecessary complication, but the
# alternative isn't obvious -- we could go and fetch *all* the runs in any group that we're
# going to return in this query, and then append those.
#
# pseudosql:
#
# select runs.run_body, count(all_descendant_runs.id) as child_counts
# from runs
# join runs_and_root_runs on runs.run_id = runs_and_root_runs.run_id
# left outer join all_descendant_runs
# on all_descendant_runs.value = runs_and_root_runs.run_id
# group by runs.run_body
# order by child_counts desc
runs_and_root_runs_with_descendant_counts = (
db.select(
[
RunsTable.c.run_body,
db.func.count(all_descendant_runs.c.id).label("child_counts"),
]
)
.select_from(
RunsTable.join(
runs_and_root_runs, RunsTable.c.run_id == runs_and_root_runs.c.run_id
).join(
all_descendant_runs,
all_descendant_runs.c.value == runs_and_root_runs.c.run_id,
isouter=True,
)
)
.group_by(RunsTable.c.run_body)
.order_by(db.desc(db.column("child_counts")))
)
with self.connect() as conn:
res = conn.execute(runs_and_root_runs_with_descendant_counts).fetchall()
# Postprocess: descendant runs get aggregated with their roots
root_run_id_to_group: Dict[str, List[PipelineRun]] = defaultdict(list)
root_run_id_to_count: Dict[str, int] = defaultdict(int)
for (run_body, count) in res:
# _row_to_run reads the serialized run from the first element of the row.
row = (run_body,)
pipeline_run = self._row_to_run(row)
root_run_id = pipeline_run.get_root_run_id()
if root_run_id is not None:
root_run_id_to_group[root_run_id].append(pipeline_run)
else:
root_run_id_to_group[pipeline_run.run_id].append(pipeline_run)
root_run_id_to_count[pipeline_run.run_id] = count + 1
return {
root_run_id: {
"runs": list(run_group),
"count": root_run_id_to_count[root_run_id],
}
for root_run_id, run_group in root_run_id_to_group.items()
}
def has_run(self, run_id: str) -> bool:
    """Report whether ``get_run_by_id`` finds a run for ``run_id``."""
    check.str_param(run_id, "run_id")
    found_run = self.get_run_by_id(run_id)
    return bool(found_run)
def delete_run(self, run_id: str):
    """Delete the ``RunsTable`` row whose run id is ``run_id``."""
    check.str_param(run_id, "run_id")
    delete_statement = db.delete(RunsTable).where(RunsTable.c.run_id == run_id)
    with self.connect() as connection:
        connection.execute(delete_statement)
def has_pipeline_snapshot(self, pipeline_snapshot_id: str) -> bool:
    """Report whether a snapshot row exists for ``pipeline_snapshot_id``."""
    check.str_param(pipeline_snapshot_id, "pipeline_snapshot_id")
    exists = self._has_snapshot_id(pipeline_snapshot_id)
    return exists
def add_pipeline_snapshot(
    self, pipeline_snapshot: PipelineSnapshot, snapshot_id: Optional[str] = None
) -> str:
    """Store a pipeline snapshot and return the id it was stored under.

    Args:
        pipeline_snapshot (PipelineSnapshot): The snapshot to store.
        snapshot_id (Optional[str]): Id to store it under; when falsy, the id is
            computed with ``create_pipeline_snapshot_id``.

    Returns:
        str: The snapshot id.
    """
    check.inst_param(pipeline_snapshot, "pipeline_snapshot", PipelineSnapshot)
    check.opt_str_param(snapshot_id, "snapshot_id")

    resolved_id = snapshot_id or create_pipeline_snapshot_id(pipeline_snapshot)

    return self._add_snapshot(
        snapshot_id=resolved_id,
        snapshot_obj=pipeline_snapshot,
        snapshot_type=SnapshotType.PIPELINE,
    )
def get_pipeline_snapshot(self, pipeline_snapshot_id: str) -> PipelineSnapshot:
    """Load the snapshot stored under ``pipeline_snapshot_id`` via ``_get_snapshot``."""
    check.str_param(pipeline_snapshot_id, "pipeline_snapshot_id")
    snapshot = self._get_snapshot(pipeline_snapshot_id)
    return snapshot
def has_execution_plan_snapshot(self, execution_plan_snapshot_id: str) -> bool:
    """Report whether a snapshot row exists for ``execution_plan_snapshot_id``.

    Uses the same id-only lookup as ``has_pipeline_snapshot`` rather than
    fetching, decompressing and deserializing the snapshot body just to test
    for existence.
    """
    check.str_param(execution_plan_snapshot_id, "execution_plan_snapshot_id")
    return self._has_snapshot_id(execution_plan_snapshot_id)
def add_execution_plan_snapshot(
    self, execution_plan_snapshot: ExecutionPlanSnapshot, snapshot_id: Optional[str] = None
) -> str:
    """Store an execution plan snapshot and return the id it was stored under.

    Args:
        execution_plan_snapshot (ExecutionPlanSnapshot): The snapshot to store.
        snapshot_id (Optional[str]): Id to store it under; when falsy, the id is
            computed with ``create_execution_plan_snapshot_id``.

    Returns:
        str: The snapshot id.
    """
    check.inst_param(execution_plan_snapshot, "execution_plan_snapshot", ExecutionPlanSnapshot)
    check.opt_str_param(snapshot_id, "snapshot_id")

    resolved_id = snapshot_id or create_execution_plan_snapshot_id(execution_plan_snapshot)

    return self._add_snapshot(
        snapshot_id=resolved_id,
        snapshot_obj=execution_plan_snapshot,
        snapshot_type=SnapshotType.EXECUTION_PLAN,
    )
def get_execution_plan_snapshot(self, execution_plan_snapshot_id: str) -> ExecutionPlanSnapshot:
    """Load the snapshot stored under ``execution_plan_snapshot_id`` via ``_get_snapshot``."""
    check.str_param(execution_plan_snapshot_id, "execution_plan_snapshot_id")
    snapshot = self._get_snapshot(execution_plan_snapshot_id)
    return snapshot
def _add_snapshot(self, snapshot_id: str, snapshot_obj, snapshot_type: SnapshotType) -> str:
    """Insert a snapshot row and return ``snapshot_id``.

    The snapshot object is serialized, UTF-8 encoded and zlib-compressed before
    being stored as the snapshot body.
    """
    check.str_param(snapshot_id, "snapshot_id")
    check.not_none_param(snapshot_obj, "snapshot_obj")
    check.inst_param(snapshot_type, "snapshot_type", SnapshotType)

    with self.connect() as conn:
        serialized = serialize_dagster_namedtuple(snapshot_obj)
        compressed_body = zlib.compress(serialized.encode("utf-8"))
        conn.execute(
            SnapshotsTable.insert().values(  # pylint: disable=no-value-for-parameter
                snapshot_id=snapshot_id,
                snapshot_body=compressed_body,
                snapshot_type=snapshot_type.value,
            )
        )

    return snapshot_id
def get_run_storage_id(self) -> str:
    """Return the stored run storage id, creating one on first use.

    When ``InstanceInfo`` has no row yet, a new UUID4 string is inserted and
    returned.
    """
    existing_row = self.fetchone(db.select([InstanceInfo.c.run_storage_id]))
    if existing_row:
        return existing_row[0]

    new_storage_id = str(uuid.uuid4())
    with self.connect() as conn:
        conn.execute(InstanceInfo.insert().values(run_storage_id=new_storage_id))
    return new_storage_id
def _has_snapshot_id(self, snapshot_id: str) -> bool:
    """Report whether ``SnapshotsTable`` has a row with this snapshot id."""
    lookup = db.select([SnapshotsTable.c.snapshot_id]).where(
        SnapshotsTable.c.snapshot_id == snapshot_id
    )
    return bool(self.fetchone(lookup))
def _get_snapshot(self, snapshot_id: str):
    """Fetch and unpack the snapshot body stored under ``snapshot_id``.

    Returns ``None`` when no row is found; otherwise returns whatever
    ``defensively_unpack_pipeline_snapshot_query`` produces for the row.
    """
    lookup = db.select([SnapshotsTable.c.snapshot_body]).where(
        SnapshotsTable.c.snapshot_id == snapshot_id
    )
    row = self.fetchone(lookup)
    if not row:
        return None
    return defensively_unpack_pipeline_snapshot_query(logging, row)
def _get_partition_runs(
    self, partition_set_name: str, partition_name: str
) -> List[PipelineRun]:
    """Return the runs for one partition of one partition set.

    Utility method to help test reads off of the partition column: the
    partition columns are queried when the ``RUN_PARTITIONS`` index has been
    built, and the partition tags are used otherwise.
    """
    if self.has_built_index(RUN_PARTITIONS):
        partition_query = (
            self._runs_query()
            .where(RunsTable.c.partition == partition_name)
            .where(RunsTable.c.partition_set == partition_set_name)
        )
        return self._rows_to_runs(self.fetchall(partition_query))

    # query by tags
    tag_filter = RunsFilter(
        tags={
            PARTITION_SET_TAG: partition_set_name,
            PARTITION_NAME_TAG: partition_name,
        }
    )
    return self.get_runs(filters=tag_filter)
# Tracking data migrations over secondary indexes
def _execute_data_migrations(
self, migrations, print_fn: Optional[Callable] = None, force_rebuild_all: bool = False
):
for migration_name, migration_fn in migrations.items():
if self.has_built_index(migration_name):
if not force_rebuild_all:
continue
if print_fn:
print_fn(f"Starting data migration: {migration_name}")
migration_fn()(self, print_fn)
self.mark_index_built(migration_name)
if print_fn:
print_fn(f"Finished data migration: {migration_name}")
def migrate(self, print_fn: Optional[Callable] = None, force_rebuild_all: bool = False):
    """Run the migrations listed in ``REQUIRED_DATA_MIGRATIONS``."""
    self._execute_data_migrations(
        REQUIRED_DATA_MIGRATIONS,
        print_fn,
        force_rebuild_all,
    )
def optimize(self, print_fn: Optional[Callable] = None, force_rebuild_all: bool = False):
    """Run the migrations listed in ``OPTIONAL_DATA_MIGRATIONS``."""
    self._execute_data_migrations(
        OPTIONAL_DATA_MIGRATIONS,
        print_fn,
        force_rebuild_all,
    )
def has_built_index(self, migration_name: str) -> bool:
    """Report whether the named migration has a completion timestamp recorded."""
    migration_columns = SecondaryIndexMigrationTable.c
    completed_lookup = (
        db.select([1])
        .where(migration_columns.name == migration_name)
        # `!= None` is deliberate: it builds a SQL expression, so `is not None` would not work.
        .where(migration_columns.migration_completed != None)
        .limit(1)
    )
    with self.connect() as conn:
        matches = conn.execute(completed_lookup).fetchall()

    return len(matches) > 0
def mark_index_built(self, migration_name: str):
    """Record the current time as the completion time of the named migration.

    A new row is inserted; if that raises an ``IntegrityError`` the existing
    row's completion time is updated instead.
    """
    insert_statement = (
        SecondaryIndexMigrationTable.insert().values(  # pylint: disable=no-value-for-parameter
            name=migration_name,
            migration_completed=datetime.now(),
        )
    )
    with self.connect() as conn:
        try:
            conn.execute(insert_statement)
        except db.exc.IntegrityError:
            update_statement = (
                SecondaryIndexMigrationTable.update()  # pylint: disable=no-value-for-parameter
                .where(SecondaryIndexMigrationTable.c.name == migration_name)
                .values(migration_completed=datetime.now())
            )
            conn.execute(update_statement)
# Checking for migrations
def has_run_stats_index_cols(self):
    """Report whether the runs table has both ``start_time`` and ``end_time`` columns."""
    with self.connect() as conn:
        inspected_columns = db.inspect(conn).get_columns(RunsTable.name)
        column_names = [column.get("name") for column in inspected_columns]

    has_start = "start_time" in column_names
    return has_start and "end_time" in column_names
# Daemon heartbeats
def add_daemon_heartbeat(self, daemon_heartbeat: DaemonHeartbeat):
    """Store a daemon heartbeat, replacing the stored one for the same daemon type.

    An insert is attempted first; on ``IntegrityError`` the row matching the
    heartbeat's daemon type is updated instead.
    """
    with self.connect() as conn:
        # insert, or update if already present
        try:
            insert_statement = (
                DaemonHeartbeatsTable.insert().values(  # pylint: disable=no-value-for-parameter
                    timestamp=utc_datetime_from_timestamp(daemon_heartbeat.timestamp),
                    daemon_type=daemon_heartbeat.daemon_type,
                    daemon_id=daemon_heartbeat.daemon_id,
                    body=serialize_dagster_namedtuple(daemon_heartbeat),
                )
            )
            conn.execute(insert_statement)
        except db.exc.IntegrityError:
            update_statement = (
                DaemonHeartbeatsTable.update()  # pylint: disable=no-value-for-parameter
                .where(DaemonHeartbeatsTable.c.daemon_type == daemon_heartbeat.daemon_type)
                .values(  # pylint: disable=no-value-for-parameter
                    timestamp=utc_datetime_from_timestamp(daemon_heartbeat.timestamp),
                    daemon_id=daemon_heartbeat.daemon_id,
                    body=serialize_dagster_namedtuple(daemon_heartbeat),
                )
            )
            conn.execute(update_statement)
def get_daemon_heartbeats(self) -> Dict[str, DaemonHeartbeat]:
    """Return the stored heartbeats keyed by daemon type."""
    with self.connect() as conn:
        heartbeat_rows = conn.execute(db.select(DaemonHeartbeatsTable.columns))
        heartbeats = [deserialize_as(row.body, DaemonHeartbeat) for row in heartbeat_rows]
        return {heartbeat.daemon_type: heartbeat for heartbeat in heartbeats}
def wipe(self):
    """Clears the run storage.

    Deletes every row from the runs, run tags, snapshots and daemon heartbeats
    tables, in that order, on a single connection.
    """
    tables_to_clear = (RunsTable, RunTagsTable, SnapshotsTable, DaemonHeartbeatsTable)
    with self.connect() as conn:
        # https://stackoverflow.com/a/54386260/324449
        for table in tables_to_clear:
            conn.execute(table.delete())  # pylint: disable=no-value-for-parameter
def wipe_daemon_heartbeats(self):
    """Delete every row from the daemon heartbeats table."""
    delete_all = DaemonHeartbeatsTable.delete()  # pylint: disable=no-value-for-parameter
    with self.connect() as conn:
        # https://stackoverflow.com/a/54386260/324449
        conn.execute(delete_all)
def get_backfills(
    self,
    status: Optional[BulkActionStatus] = None,
    cursor: Optional[str] = None,
    limit: Optional[int] = None,
) -> List[PartitionBackfill]:
    """Fetch stored backfills ordered by descending storage id.

    Args:
        status (Optional[BulkActionStatus]): Only return backfills with this status.
        cursor (Optional[str]): Backfill key; only backfills with a smaller
            storage id than that backfill are returned.
        limit (Optional[int]): Maximum number of backfills; falsy means no limit.

    Returns:
        List[PartitionBackfill]: The deserialized backfills.
    """
    check.opt_inst_param(status, "status", BulkActionStatus)

    bulk_actions = BulkActionsTable.c
    query = db.select([bulk_actions.body])

    if status:
        query = query.where(bulk_actions.status == status.value)

    if cursor:
        cursor_id_query = db.select([bulk_actions.id]).where(bulk_actions.key == cursor)
        query = query.where(bulk_actions.id < cursor_id_query)

    if limit:
        query = query.limit(limit)

    query = query.order_by(bulk_actions.id.desc())
    return [deserialize_as(row[0], PartitionBackfill) for row in self.fetchall(query)]
def get_backfill(self, backfill_id: str) -> Optional[PartitionBackfill]:
    """Return the backfill stored under ``backfill_id``, or ``None`` if there is none."""
    check.str_param(backfill_id, "backfill_id")

    lookup = db.select([BulkActionsTable.c.body]).where(BulkActionsTable.c.key == backfill_id)
    row = self.fetchone(lookup)
    if not row:
        return None
    return deserialize_as(row[0], PartitionBackfill)
def add_backfill(self, partition_backfill: PartitionBackfill):
    """Insert a new bulk-actions row for ``partition_backfill``."""
    check.inst_param(partition_backfill, "partition_backfill", PartitionBackfill)

    with self.connect() as conn:
        insert_statement = (
            BulkActionsTable.insert().values(  # pylint: disable=no-value-for-parameter
                key=partition_backfill.backfill_id,
                status=partition_backfill.status.value,
                timestamp=utc_datetime_from_timestamp(partition_backfill.backfill_timestamp),
                body=serialize_dagster_namedtuple(partition_backfill),
            )
        )
        conn.execute(insert_statement)
def update_backfill(self, partition_backfill: PartitionBackfill):
    """Overwrite the stored status and body of an existing backfill.

    Raises:
        DagsterInvariantViolationError: If no backfill with this id is stored.
    """
    check.inst_param(partition_backfill, "partition_backfill", PartitionBackfill)

    backfill_id = partition_backfill.backfill_id
    existing = self.get_backfill(backfill_id)
    if not existing:
        raise DagsterInvariantViolationError(
            f"Backfill {backfill_id} is not present in storage"
        )

    with self.connect() as conn:
        update_statement = (
            BulkActionsTable.update()  # pylint: disable=no-value-for-parameter
            .where(BulkActionsTable.c.key == backfill_id)
            .values(
                status=partition_backfill.status.value,
                body=serialize_dagster_namedtuple(partition_backfill),
            )
        )
        conn.execute(update_statement)
GET_PIPELINE_SNAPSHOT_QUERY_ID = "get-pipeline-snapshot"


def defensively_unpack_pipeline_snapshot_query(logger, row):
    """Decode a snapshot row, logging a warning and returning ``None`` on failure.

    The first entry of ``row`` is expected to be zlib-compressed, UTF-8 encoded
    serialized JSON. Each decoding stage that fails is reported through
    ``logger.warning`` with the ``GET_PIPELINE_SNAPSHOT_QUERY_ID`` prefix.

    Args:
        logger: Object exposing a ``warning(msg)`` callable.
        row: Indexable row whose first entry holds the stored snapshot body.

    Returns:
        The deserialized snapshot, or ``None`` if any stage failed.
    """
    # no checking here because sqlalchemy returns a special
    # row proxy and don't want to instance check on an internal
    # implementation detail

    def _warn(msg):
        # Use the module constant rather than repeating the literal prefix.
        logger.warning(f"{GET_PIPELINE_SNAPSHOT_QUERY_ID}: {msg}")

    snapshot_body = row[0]
    if not isinstance(snapshot_body, bytes):
        _warn("First entry in row is not a binary type.")
        return None

    try:
        uncompressed_bytes = zlib.decompress(snapshot_body)
    except zlib.error:
        _warn("Could not decompress bytes stored in snapshot table.")
        return None

    try:
        decoded_str = uncompressed_bytes.decode("utf-8")
    except UnicodeDecodeError:
        _warn("Could not unicode decode decompressed bytes stored in snapshot table.")
        return None

    try:
        return deserialize_json_to_dagster_namedtuple(decoded_str)
    except JSONDecodeError:
        _warn("Could not parse json in snapshot table.")
        return None