# You are viewing an unreleased or outdated version of the documentation.
#
# Source code for dagster_databricks.databricks

import base64
import logging
import re
import time
from typing import IO, Any, Mapping, Optional, Tuple, Union, cast

import dagster
import dagster._check as check
import dagster_pyspark
import databricks_api
import databricks_cli.sdk
import requests.exceptions
from dagster._annotations import deprecated, public
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import compute, jobs
from typing_extensions import Final

import dagster_databricks

from .types import (
    DatabricksRunState,
)
from .version import __version__

# Default upper bound on how long to wait for a Databricks run to finish:
# 24 hours, expressed in seconds.
DEFAULT_RUN_MAX_WAIT_TIME_SEC: Final = 86_400


class DatabricksError(Exception):
    """Raised when a Databricks run fails, times out, or otherwise cannot be completed."""
class DatabricksClient:
    """A thin wrapper over the Databricks REST API."""

    def __init__(self, host: str, token: str, workspace_id: Optional[str] = None):
        self.host = host
        self.workspace_id = workspace_id

        self._workspace_client = WorkspaceClient(
            host=host,
            token=token,
            product="dagster-databricks",
            product_version=__version__,
        )

        # TODO: This is the old shim client that we were previously using. Arguably this is
        # confusing for users to use since this is an unofficial wrapper around the documented
        # Databricks REST API. We should consider removing this in the next minor release.
        self._client = databricks_api.DatabricksAPI(host=host, token=token)
        self.__setup_user_agent(self._client.client)

        # TODO: This is the old `databricks_cli` client that was previous recommended by Databricks.
        # It is no longer supported and should be removed in favour of `databricks-sdk` in the next
        # minor release.
        self._api_client = databricks_cli.sdk.ApiClient(host=host, token=token)
        self.__setup_user_agent(self._api_client)

    def __setup_user_agent(
        self,
        client: Union[WorkspaceClient, databricks_api.DatabricksAPI, databricks_cli.sdk.ApiClient],
    ) -> None:
        """Overrides the user agent for the Databricks API client."""
        client.default_headers["user-agent"] = f"dagster-databricks/{__version__}"

    @deprecated(
        breaking_version="0.21.0", additional_warn_text="Use `workspace_client` property instead."
    )
    @public
    @property
    def client(self) -> databricks_api.DatabricksAPI:
        """Retrieve the legacy Databricks API client."""
        return self._client

    @client.setter
    def client(self, value: databricks_api.DatabricksAPI) -> None:
        self._client = value

    @deprecated(
        breaking_version="0.21.0", additional_warn_text="Use `workspace_client` property instead."
    )
    @public
    @property
    def api_client(self) -> databricks_cli.sdk.ApiClient:
        """Retrieve a reference to the underlying Databricks API client. For more information,
        see the `Databricks Python API <https://docs.databricks.com/dev-tools/python-api.html>`_.

        **Examples:**

        .. code-block:: python

            from dagster import op
            from databricks_cli.jobs.api import JobsApi
            from databricks_cli.runs.api import RunsApi

            @op(required_resource_keys={"databricks_client"})
            def op1(context):
                # Initialize the legacy Databricks Jobs and Runs APIs
                jobs_client = JobsApi(context.resources.databricks_client.api_client)
                runs_client = RunsApi(context.resources.databricks_client.api_client)

                # Example 1: Run a Databricks job with some parameters.
                jobs_client.run_now(...)

                # Example 2: Trigger a one-time run of a Databricks workload.
                runs_client.submit_run(...)

                # Example 3: Get an existing run.
                runs_client.get_run(...)

                # Example 4: Cancel a run.
                runs_client.cancel_run(...)

        Returns:
            ApiClient: The authenticated Databricks API client.
        """
        return self._api_client

    @public
    @property
    def workspace_client(self) -> WorkspaceClient:
        """Retrieve a reference to the underlying Databricks Workspace client. For more information,
        see the `Databricks SDK for Python <https://docs.databricks.com/dev-tools/sdk-python.html>`_.

        **Examples:**

        .. code-block:: python

            from dagster import op
            from databricks.sdk import WorkspaceClient

            @op(required_resource_keys={"databricks_client"})
            def op1(context):
                # Retrieve the Databricks SDK Workspace client
                client = context.resources.databricks_client.workspace_client

                # Example 1: Run a Databricks job with some parameters.
                client.jobs.run_now(...)

                # Example 2: Trigger a one-time run of a Databricks workload.
                client.jobs.submit(...)

                # Example 3: Get an existing run.
                client.jobs.get_run(...)

                # Example 4: Cancel a run.
                client.jobs.cancel_run(...)

        Returns:
            WorkspaceClient: The authenticated Databricks SDK Workspace Client.
        """
        return self._workspace_client

    def read_file(self, dbfs_path: str, block_size: int = 1024**2) -> bytes:
        """Read a file from DBFS to a **byte string**.

        Args:
            dbfs_path: Path of the file on DBFS. A leading ``dbfs://`` prefix is stripped.
            block_size: Number of bytes requested per DBFS ``read`` call.

        Returns:
            bytes: The full contents of the file.
        """
        if dbfs_path.startswith("dbfs://"):
            dbfs_path = dbfs_path[7:]

        dbfs_service = self.workspace_client.dbfs

        # Collect decoded chunks and join once at the end; repeated ``bytes +=`` is quadratic.
        chunks = []
        bytes_read = 0
        jdoc = dbfs_service.read(path=dbfs_path, length=block_size)
        # ``or ""`` guards against a response that carries no payload at all.
        chunks.append(base64.b64decode(jdoc.data or ""))
        # Keep reading while each request returned a full block; a short read ends the loop.
        while jdoc.bytes_read == block_size:
            bytes_read += jdoc.bytes_read
            jdoc = dbfs_service.read(path=dbfs_path, offset=bytes_read, length=block_size)
            chunks.append(base64.b64decode(jdoc.data or ""))

        return b"".join(chunks)

    def put_file(
        self, file_obj: IO, dbfs_path: str, overwrite: bool = False, block_size: int = 1024**2
    ) -> None:
        """Upload an arbitrary large file to DBFS.

        This doesn't use the DBFS `Put` API because that endpoint is limited to 1MB.

        Args:
            file_obj: Binary file-like object to read from.
            dbfs_path: Destination path on DBFS. A leading ``dbfs://`` prefix is stripped.
            overwrite: Whether to overwrite an existing file at ``dbfs_path``.
            block_size: Number of bytes read from ``file_obj`` per uploaded block.
        """
        if dbfs_path.startswith("dbfs://"):
            dbfs_path = dbfs_path[7:]

        dbfs_service = self.workspace_client.dbfs

        # Open a streaming handle, append base64-encoded blocks, then close the handle.
        create_response = dbfs_service.create(path=dbfs_path, overwrite=overwrite)
        handle = create_response.handle

        block = file_obj.read(block_size)
        while block:
            data = base64.b64encode(block).decode("utf-8")
            dbfs_service.add_block(data=data, handle=handle)
            block = file_obj.read(block_size)

        dbfs_service.close(handle=handle)

    def get_run_state(self, databricks_run_id: int) -> "DatabricksRunState":
        """Get the state of a run by Databricks run ID.

        Return a `DatabricksRunState` object. Note that the `result_state`
        attribute may be `None` if the run hasn't yet terminated.
        """
        run = self.workspace_client.jobs.get_run(databricks_run_id)
        return DatabricksRunState.from_databricks(run.state)

    def poll_run_state(
        self,
        logger: logging.Logger,
        start_poll_time: float,
        databricks_run_id: int,
        max_wait_time_sec: float,
        verbose_logs: bool = True,
    ) -> bool:
        """Check the state of a Databricks run once.

        Args:
            logger: Logger used to report progress and failures.
            start_poll_time: ``time.time()`` value at which polling started.
            databricks_run_id: ID of the run to check.
            max_wait_time_sec: Maximum seconds since ``start_poll_time`` before giving up.
            verbose_logs: Whether to log the run state at debug level while it is running.

        Returns:
            bool: ``True`` if the run terminated successfully or was skipped, ``False`` if it
            has not terminated yet.

        Raises:
            DatabricksError: If the run terminated unsuccessfully, or if it is still running
                more than ``max_wait_time_sec`` seconds after ``start_poll_time``.
        """
        run_state = self.get_run_state(databricks_run_id)

        if run_state.has_terminated():
            if run_state.is_successful():
                logger.info(f"Run `{databricks_run_id}` completed successfully.")
                return True
            if run_state.is_skipped():
                logger.info(f"Run `{databricks_run_id}` was skipped.")
                return True
            error_message = (
                f"Run `{databricks_run_id}` failed with result state:"
                f" `{run_state.result_state}`. Message: {run_state.state_message}."
            )
            logger.error(error_message)
            raise DatabricksError(error_message)

        if verbose_logs:
            logger.debug(f"Run `{databricks_run_id}` in state {run_state}.")
        # The timeout is only checked while the run is still in progress.
        if time.time() - start_poll_time > max_wait_time_sec:
            raise DatabricksError(
                f"Run `{databricks_run_id}` took more than {max_wait_time_sec}s to complete."
                " Failing the run."
            )
        return False

    def wait_for_run_to_complete(
        self,
        logger: logging.Logger,
        databricks_run_id: int,
        poll_interval_sec: float,
        max_wait_time_sec: int,
        verbose_logs: bool = True,
    ) -> None:
        """Block until a Databricks run terminates, polling every ``poll_interval_sec`` seconds.

        Raises:
            DatabricksError: Propagated from ``poll_run_state`` if the run fails or exceeds
                ``max_wait_time_sec``.
        """
        logger.info(f"Waiting for Databricks run `{databricks_run_id}` to complete...")

        start_poll_time = time.time()
        while True:
            if self.poll_run_state(
                logger=logger,
                start_poll_time=start_poll_time,
                databricks_run_id=databricks_run_id,
                max_wait_time_sec=max_wait_time_sec,
                verbose_logs=verbose_logs,
            ):
                return

            time.sleep(poll_interval_sec)
class DatabricksJobRunner:
    """Submits jobs created using Dagster config to Databricks, and monitors their progress.

    Attributes:
        host (str): Databricks host, e.g. https://uksouth.azuredatabricks.net.
        token (str): Databricks authentication token.
        poll_interval_sec (float): How often to poll Databricks for run status.
        max_wait_time_sec (int): How long to wait for a run to complete before failing.
    """

    def __init__(
        self,
        host: str,
        token: str,
        poll_interval_sec: float = 5,
        max_wait_time_sec: int = DEFAULT_RUN_MAX_WAIT_TIME_SEC,
    ):
        self.host = check.str_param(host, "host")
        self.token = check.str_param(token, "token")
        self.poll_interval_sec = check.numeric_param(poll_interval_sec, "poll_interval_sec")
        self.max_wait_time_sec = check.int_param(max_wait_time_sec, "max_wait_time_sec")

        self._client: DatabricksClient = DatabricksClient(host=self.host, token=self.token)

    @property
    def client(self) -> DatabricksClient:
        """Return the underlying `DatabricksClient` object."""
        return self._client

    @staticmethod
    def _pypi_package_name(requirement: str) -> str:
        """Return the normalized (PEP 503) distribution name from a pip requirement string."""
        # Cut at the first version specifier, extras bracket, marker, URL `@`, or whitespace,
        # so e.g. "dagster>=1.5", "dagster[extra]==1.5" and "dagster @ git+..." all yield "dagster".
        name = re.split(r"[\s<>=!~;@\[(]", requirement.strip(), maxsplit=1)[0]
        return re.sub(r"[-_.]+", "-", name).lower()

    def submit_run(self, run_config: Mapping[str, Any], task: Mapping[str, Any]) -> int:
        """Submit a new run using the 'Runs submit' API.

        Args:
            run_config: Dagster-side run configuration. Reads ``cluster`` (with ``existing``
                and/or ``new``), and optionally ``libraries``, ``install_default_libraries``
                and ``run_name``.
            task: Task definition; exactly one of ``notebook_task``, ``spark_python_task``,
                ``spark_jar_task`` or ``spark_submit_task`` must be set.

        Returns:
            int: The Databricks run ID of the submitted run.
        """
        existing_cluster_id = run_config["cluster"].get("existing")

        new_cluster = run_config["cluster"].get("new")

        # The Databricks API needs different keys to be present in API calls depending
        # on new/existing cluster, so we need to process the new_cluster
        # config first.
        if new_cluster:
            # Shallow copy: nested values are re-created below before being modified.
            new_cluster = new_cluster.copy()

            nodes = new_cluster.pop("nodes")
            if "instance_pool_id" in nodes:
                new_cluster["instance_pool_id"] = nodes["instance_pool_id"]
            else:
                node_types = nodes["node_types"]
                new_cluster["node_type_id"] = node_types["node_type_id"]
                if "driver_node_type_id" in node_types:
                    new_cluster["driver_node_type_id"] = node_types["driver_node_type_id"]

            cluster_size = new_cluster.pop("size")
            if "num_workers" in cluster_size:
                new_cluster["num_workers"] = cluster_size["num_workers"]
            else:
                new_cluster["autoscale"] = cluster_size["autoscale"]

            tags = new_cluster.get("custom_tags") or {}
            if isinstance(tags, list):
                tags = {x["key"]: x["value"] for x in tags}
            else:
                # Copy so that adding the version tag never mutates the caller's run_config.
                tags = dict(tags)
            tags["__dagster_version"] = dagster.__version__
            new_cluster["custom_tags"] = tags

        check.invariant(
            existing_cluster_id is not None or new_cluster is not None,
            "Invalid value for run_config.cluster",
        )

        # We'll always need some libraries, namely dagster/dagster_databricks/dagster_pyspark,
        # since they're imported by our scripts.
        # Add them if they're not already added by users in config.
        libraries = list(run_config.get("libraries", []))
        install_default_libraries = run_config.get("install_default_libraries", True)
        if install_default_libraries:
            python_libraries = {
                self._pypi_package_name(x["pypi"]["package"]) for x in libraries if "pypi" in x
            }
            for library_name, library in [
                ("dagster", dagster),
                ("dagster-databricks", dagster_databricks),
                ("dagster-pyspark", dagster_pyspark),
            ]:
                if library_name not in python_libraries:
                    libraries.append(
                        {"pypi": {"package": f"{library_name}=={library.__version__}"}}
                    )

        # Only one task should be able to be chosen really; make sure of that here.
        check.invariant(
            sum(
                task.get(key) is not None
                for key in [
                    "notebook_task",
                    "spark_python_task",
                    "spark_jar_task",
                    "spark_submit_task",
                ]
            )
            == 1,
            "Multiple tasks specified in Databricks run",
        )

        # The run ID is read from the bound arguments of the waiter returned by `jobs.submit`.
        return self.client.workspace_client.jobs.submit(
            run_name=run_config.get("run_name"),
            tasks=[
                jobs.SubmitTask.from_dict(
                    {
                        "new_cluster": new_cluster,
                        "existing_cluster_id": existing_cluster_id,
                        "libraries": libraries,
                        **task,
                        "task_key": "dagster-task",
                    },
                )
            ],
        ).bind()["run_id"]

    def retrieve_logs_for_run_id(
        self, log: logging.Logger, databricks_run_id: int
    ) -> Optional[Tuple[Optional[str], Optional[str]]]:
        """Retrieve the stdout and stderr logs for a run.

        Returns:
            A ``(stdout, stderr)`` tuple when the run's cluster writes logs to DBFS, otherwise
            ``None`` (no log configuration, S3 logs, or an unsupported destination).
        """
        run = self.client.workspace_client.jobs.get_run(databricks_run_id)
        cluster = self.client.workspace_client.clusters.get(run.cluster_instance.cluster_id)
        log_config = cluster.cluster_log_conf
        if log_config is None:
            log.warning(
                "Logs not configured for cluster {cluster} used for run {run}".format(
                    cluster=cluster.cluster_id, run=databricks_run_id
                )
            )
            return None

        if log_config.s3 is not None:
            log.warning("Retrieving S3 logs not yet implemented")
            return None

        if log_config.dbfs is not None:
            logs_prefix = log_config.dbfs.destination
            stdout = self.wait_for_dbfs_logs(log, logs_prefix, cluster.cluster_id, "stdout")
            stderr = self.wait_for_dbfs_logs(log, logs_prefix, cluster.cluster_id, "stderr")
            return stdout, stderr

        # Any other log destination is not supported here.
        return None

    def wait_for_dbfs_logs(
        self,
        log: logging.Logger,
        prefix: str,
        cluster_id: str,
        filename: str,
        waiter_delay: int = 10,
        waiter_max_attempts: int = 10,
    ) -> Optional[str]:
        """Attempt up to `waiter_max_attempts` attempts to get logs from DBFS.

        Args:
            log: Logger used to report progress.
            prefix: DBFS log destination configured on the cluster.
            cluster_id: ID of the cluster whose driver logs are read.
            filename: Driver log file name, e.g. ``"stdout"`` or ``"stderr"``.
            waiter_delay: Seconds to sleep between failed attempts.
            waiter_max_attempts: Maximum number of read attempts.

        Returns:
            The decoded log contents, or ``None`` if every attempt failed.
        """
        path = "/".join([prefix, cluster_id, "driver", filename])
        log.info(f"Retrieving logs from {path}")
        for attempt in range(1, waiter_max_attempts + 1):
            try:
                logs = self.client.read_file(path)
            except OSError:
                # Both `requests.exceptions.HTTPError` (legacy clients) and the databricks-sdk
                # `DatabricksError` raised by `read_file` derive from IOError/OSError.
                if attempt < waiter_max_attempts:
                    time.sleep(waiter_delay)
            else:
                return logs.decode("utf-8")

        log.warning("Could not retrieve cluster logs!")
        return None