diff --git a/titan/blueprint.py b/titan/blueprint.py index d77429c..d592081 100644 --- a/titan/blueprint.py +++ b/titan/blueprint.py @@ -1,21 +1,17 @@ import json import logging +import threading from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from queue import Queue -from typing import Any, Generator, Iterable, Optional, Sequence, TypeVar, Union, cast +from typing import Any, Generator, Iterable, Optional, Sequence, Set, TypeVar, Union, cast import snowflake.connector from . import data_provider, lifecycle from .blueprint_config import BlueprintConfig -from .client import ( - ALREADY_EXISTS_ERR, - DOES_NOT_EXIST_ERR, - INVALID_GRANT_ERR, - execute, - reset_cache, -) +from .client import ALREADY_EXISTS_ERR, DOES_NOT_EXIST_ERR, INVALID_GRANT_ERR, execute, reset_cache from .data_provider import SessionContext from .enums import AccountEdition, BlueprintScope, ResourceType, RunMode, resource_type_is_grant from .exceptions import ( @@ -28,10 +24,7 @@ OrphanResourceException, ) from .identifiers import URN, parse_identifier, parse_URN, resource_label_for_type -from .privs import ( - CREATE_PRIV_FOR_RESOURCE_TYPE, - system_role_for_priv, -) +from .privs import CREATE_PRIV_FOR_RESOURCE_TYPE, system_role_for_priv from .resource_name import ResourceName from .resource_tags import ResourceTags from .resources import Database, FutureGrant, Grant, GrantOnAll, RoleGrant, Schema @@ -446,7 +439,10 @@ def _merge(resource: ResourceContainer, pointer: ResourcePointer): # Create a unique identifier for the resource resource_id: ResourceRef if isinstance(resource_or_pointer, NamedResource): - resource_id = (resource_or_pointer.resource_type, resource_or_pointer.name) + resource_id = ( + resource_or_pointer.resource_type, + str(resource_or_pointer.name), + ) else: resource_id = str(resource_or_pointer.urn) @@ -476,16 +472,27 @@ def _merge(resource: ResourceContainer, pointer: ResourcePointer): return 
list(namespace.values()) -def _get_databases(resource: ResourceContainer) -> list[Union[Database, ResourcePointer]]: - return cast(list[Union[Database, ResourcePointer]], resource.items(resource_type=ResourceType.DATABASE)) +def _get_databases( + resource: ResourceContainer, +) -> list[Union[Database, ResourcePointer]]: + return cast( + list[Union[Database, ResourcePointer]], + resource.items(resource_type=ResourceType.DATABASE), + ) def _get_schemas(resource: ResourceContainer) -> list[Union[Schema, ResourcePointer]]: - return cast(list[Union[Schema, ResourcePointer]], resource.items(resource_type=ResourceType.SCHEMA)) + return cast( + list[Union[Schema, ResourcePointer]], + resource.items(resource_type=ResourceType.SCHEMA), + ) def _get_schema_by_name(resource: ResourceContainer, name: Union[ResourceName, str]) -> Union[Schema, ResourcePointer]: - return cast(Union[Schema, ResourcePointer], resource.find(name=name, resource_type=ResourceType.SCHEMA)) + return cast( + Union[Schema, ResourcePointer], + resource.find(name=name, resource_type=ResourceType.SCHEMA), + ) def _get_public_schema(resource: ResourceContainer) -> Union[Schema, ResourcePointer]: @@ -525,22 +532,25 @@ def __init__( scope: Optional[str] = None, database: Optional[str] = None, schema: Optional[str] = None, + threads: int = 8, ) -> None: - self._config: BlueprintConfig = BlueprintConfig( + self._config = BlueprintConfig( name=name, resources=resources, run_mode=RunMode(run_mode) if run_mode else RunMode.CREATE_OR_UPDATE, - dry_run=False if dry_run is None else dry_run, + dry_run=dry_run, allowlist=[ResourceType(item) for item in allowlist] if allowlist else None, vars=vars or {}, vars_spec=vars_spec or [], scope=BlueprintScope(scope) if scope else None, database=ResourceName(database) if database else None, schema=ResourceName(schema) if schema else None, + threads=max(1, threads), # Ensure at least 1 thread ) - self._finalized: bool = False - self._staged: list[Resource] = [] - self._root: 
ResourcePointer = ResourcePointer(name="MISSING", resource_type=ResourceType.ACCOUNT) + self._finalized = False + self._staged = [] + self._root = ResourcePointer(name="ACCOUNT", resource_type=ResourceType.ACCOUNT) + self._levels = {} # Store dependency levels self.add(resources or []) @classmethod @@ -548,7 +558,7 @@ def from_config(cls, config: BlueprintConfig): blueprint = cls.__new__(cls) blueprint._config = config blueprint._staged = [] - blueprint._root = ResourcePointer(name="MISSING", resource_type=ResourceType.ACCOUNT) + blueprint._root = ResourcePointer(name="ACCOUNT", resource_type=ResourceType.ACCOUNT) blueprint._finalized = False blueprint.add(config.resources or []) return blueprint @@ -593,82 +603,74 @@ def _raise_for_nonconforming_plan(self, session_ctx: SessionContext, plan: Plan) exception_block = "\n".join(exceptions) raise NonConformingPlanException("Non-conforming actions found in plan:\n" + exception_block) - def _plan(self, remote_state: State, manifest: Manifest) -> Plan: - additive_changes: list[ResourceChange] = [] - destructive_changes: list[ResourceChange] = [] - - for resource_change in diff(remote_state, manifest): - if isinstance(resource_change, (CreateResource, UpdateResource, TransferOwnership)): - additive_changes.append(resource_change) - elif isinstance(resource_change, DropResource): - destructive_changes.append(resource_change) - - # Generate a list of all URNs - resource_set = set(manifest.urns + list(remote_state.keys())) - for ref in manifest.refs: - resource_set.add(ref[0]) - resource_set.add(ref[1]) - # Calculate a topological sort order for the URNs - sort_order = topological_sort(resource_set, set(manifest.refs)) - plan = sorted(additive_changes, key=lambda change: sort_order[change.urn]) + _sort_destructive_changes( - destructive_changes, sort_order - ) - return plan - def fetch_remote_state(self, session, manifest: Manifest) -> State: - state: State = {} + """Fetch remote state with parallel resource retrieval.""" + 
state = {} + logger = logging.getLogger(__name__) session_ctx = data_provider.fetch_session(session) data_provider.use_secondary_roles(session, all=True) + urns = manifest.urns if self._config.run_mode == RunMode.SYNC: - if self._config.allowlist: - for resource_type in self._config.allowlist: - for fqn in data_provider.list_resource(session, resource_label_for_type(resource_type)): - # FIXME - if self._config.scope == BlueprintScope.DATABASE and fqn.database != self._config.database: - continue - elif self._config.scope == BlueprintScope.SCHEMA and fqn.schema != self._config.schema: - continue - urn = URN(resource_type=resource_type, fqn=fqn, account_locator=session_ctx["account_locator"]) - data = data_provider.fetch_resource(session, urn) - if data is None: - raise MissingResourceException(f"Resource could not be found: {urn}") - resource_cls = Resource.resolve_resource_cls(urn.resource_type, data) - state[urn] = resource_cls.spec(**data).to_dict(session_ctx["account_edition"]) - else: + if not self._config.allowlist: raise RuntimeError("Sync mode requires an allowlist") + urns = [URN.from_resource(account_locator=manifest._account_locator, resource=self._root)] + for resource_type in self._config.allowlist: + for fqn in data_provider.list_resource(session, resource_label_for_type(resource_type)): + if self._config.scope == BlueprintScope.DATABASE and fqn.database != self._config.database: + continue + if self._config.scope == BlueprintScope.SCHEMA and fqn.schema != self._config.schema: + continue + + urns.append( + URN( + resource_type=resource_type, + fqn=fqn, + account_locator=session_ctx["account_locator"], + ) + ) - for urn, manifest_item in manifest.items(): - data = data_provider.fetch_resource(session, urn) - if data is not None: - if isinstance(manifest_item, ResourcePointer): - resource_cls = Resource.resolve_resource_cls(urn.resource_type, data) - else: - resource_cls = manifest_item.resource_cls - - state[urn] = 
resource_cls.spec(**data).to_dict(session_ctx["account_edition"]) - - # check for existence of resource refs - for parent, reference in manifest.refs: - if reference in manifest: - continue - - is_public_schema = reference.resource_type == ResourceType.SCHEMA and reference.fqn.name == ResourceName( - "PUBLIC" - ) - - try: - data = data_provider.fetch_resource(session, reference) - except Exception: - data = None + with ThreadPoolExecutor(max_workers=self._config.threads) as executor: + future_to_urn = {executor.submit(data_provider.fetch_resource, session, urn): urn for urn in urns} + for future in as_completed(future_to_urn): + urn = future_to_urn[future] + try: + data = future.result() + if data: + if self._config.run_mode == RunMode.SYNC: + resource_cls = Resource.resolve_resource_cls(urn.resource_type, data) + else: + item = manifest[urn] + resource_cls = ( + item.resource_cls + if isinstance(item, ManifestResource) + else Resource.resolve_resource_cls(urn.resource_type, data) + ) - if data is None and not is_public_schema: - # logger.error(manifest.to_dict(session_ctx)) - raise MissingResourceException( - f"Resource {reference} required by {parent} not found or failed to fetch" + state[urn] = resource_cls.spec(**data).to_dict(session_ctx["account_edition"]) + else: + if self._config.run_mode == RunMode.SYNC: + raise MissingResourceException(f"Resource {urn} not found") + except Exception as e: + logger.error(f"Failed to fetch resource {urn}: {e}") + raise # Stop processing if any fetch fails + + if self._config.run_mode != RunMode.SYNC: + for parent, reference in manifest.refs: + if reference in manifest or reference in state: + continue + is_public_schema = ( + reference.resource_type == ResourceType.SCHEMA and reference.fqn.name == ResourceName("PUBLIC") ) - + try: + data = data_provider.fetch_resource(session, reference) + if data is None and not is_public_schema: + raise MissingResourceException(f"Resource {reference} required by {parent} not found") + 
except Exception as e: + if not is_public_schema: + logger.error(f"Error fetching reference {reference}: {e}") + raise return state def _resolve_vars(self): @@ -691,8 +693,6 @@ def _build_resource_graph(self, session_ctx: SessionContext) -> None: # Create root node of the resource graph if len(org_scoped) > 0: raise Exception("Blueprint cannot contain an Account resource") - else: - self._root = ResourcePointer(name="ACCOUNT", resource_type=ResourceType.ACCOUNT) # Merge account scoped pointers into their proper resource acct_scoped = _merge_pointers(acct_scoped) @@ -945,7 +945,26 @@ def generate_manifest(self, session_ctx: SessionContext) -> Manifest: return manifest + def _execute_change(self, session, commands: list[str]) -> None: + """Execute a list of SQL commands for a single change.""" + logger = logging.getLogger(__name__) + for sql in commands: + if not self._config.dry_run: + try: + execute(session, sql) + except snowflake.connector.errors.ProgrammingError as err: + if err.errno == ALREADY_EXISTS_ERR: + logger.warning(f"Resource already exists: {sql}, skipping...") + elif err.errno == INVALID_GRANT_ERR: + logger.warning(f"Invalid grant: {sql}, skipping...") + elif err.errno == DOES_NOT_EXIST_ERR and sql.startswith(("REVOKE", "DROP")): + logger.warning(f"Resource does not exist: {sql}, skipping...") + else: + raise + def plan(self, session) -> Plan: + """Generate and store the plan, computing dependency levels.""" + logger = logging.getLogger(__name__) reset_cache() logger.debug("Using blueprint vars:") for key in self._config.vars.keys(): @@ -954,63 +973,96 @@ def plan(self, session) -> Plan: manifest = self.generate_manifest(session_ctx) remote_state = self.fetch_remote_state(session, manifest) try: - finished_plan = self._plan(remote_state, manifest) + finished_plan = diff(remote_state, manifest) + # Compute dependency levels + resource_set = set(manifest.urns + list(remote_state.keys())) + for ref in manifest.refs: + resource_set.add(ref[0]) + 
resource_set.add(ref[1]) + self._levels = compute_levels(resource_set, set(manifest.refs)) except Exception as e: logger.error("~" * 80 + "REMOTE STATE") logger.error(remote_state) logger.error("~" * 80 + "MANIFEST") logger.error(manifest) - - raise e + raise self._raise_for_nonconforming_plan(session_ctx, finished_plan) return finished_plan - def apply(self, session, plan: Optional[Plan] = None): - if plan is None: - plan = self.plan(session) + def apply(self, session, plan: Optional[Plan] = None) -> None: + """Apply the plan with parallel execution of independent additive changes. + + At this point, we have a list of actions as a part of the plan. Each action is one of: + 1. ADD action (CREATE command) + 2. CHANGE action (one or many ALTER or SET PARAMETER commands) + 3. REMOVE action (DROP command, REVOKE command, or a rename operation) + 4. TRANSFER action (GRANT OWNERSHIP command) + + Each action requires: + • a set of privileges necessary to run commands + • the appropriate role to execute commands + + Once we've determined those things, we can compare the list of required roles and privileges + against what we have access to in the session and the role tree.""" + + def execute_commands_in_parallel(commands): + """Execute a list of SQL commands in parallel using a thread pool.""" + with ThreadPoolExecutor(max_workers=self._config.threads) as executor: + future_to_change = { + executor.submit( + self._execute_change, + session, + c["commands"], + ): c["change"] + for c in commands + } + for future in as_completed(future_to_change): + change = future_to_change[future] + try: + future.result() + except Exception as e: + logger.error(f"Failed to execute change {change}: {e}") + raise # TODO: cursor setup, including query tag - """ - At this point, we have a list of actions as a part of the plan. Each action is one of: - 1. ADD action (CREATE command) - 2. CHANGE action (one or many ALTER or SET PARAMETER commands) - 3. 
REMOVE action (DROP command, REVOKE command, or a rename operation) - 4. TRANSFER action (GRANT OWNERSHIP command) - - Each action requires: - • a set of privileges necessary to run commands - • the appropriate role to execute commands - - Once we've determined those things, we can compare the list of required roles and privileges - against what we have access to in the session and the role tree. - """ + logger = logging.getLogger(__name__) + if plan is None: + plan = self.plan(session) session_ctx = data_provider.fetch_session(session) - _raise_if_plan_would_drop_session_user(session_ctx, plan) - action_queue = compile_plan_to_sql(session_ctx, plan) - actions_taken = [] + sql_commands_per_change = compile_plan_to_sql(session_ctx, plan) + roles = [] + additive_commands = [] + destructive_commands = [] + for command in sql_commands_per_change: + roles.append(command["role"]) + if isinstance(command["change"], (CreateResource, UpdateResource, TransferOwnership)): + additive_commands.append(command) + elif isinstance(command["change"], DropResource): + destructive_commands.append(command) + roles = set(roles) + + # Map changes to their levels + levels = { + c["change"].urn: self._levels[c["change"].urn] for c in additive_commands if c["change"].urn in self._levels + } + max_level = max(levels.values()) if levels else -1 - while action_queue: - sql = action_queue.pop(0) - actions_taken.append(sql) - try: - if not self._config.dry_run: - execute(session, sql) - except snowflake.connector.errors.ProgrammingError as err: - if err.errno == ALREADY_EXISTS_ERR: - logger.error(f"Resource already exists: {sql}, skipping...") - elif err.errno == INVALID_GRANT_ERR: - logger.error(f"Invalid grant: {sql}, skipping...") - elif err.errno == DOES_NOT_EXIST_ERR and sql.startswith("REVOKE"): - logger.error(f"Resource does not exist: {sql}, skipping...") - elif err.errno == DOES_NOT_EXIST_ERR and sql.startswith("DROP"): - logger.error(f"Resource does not exist: {sql}, skipping...") - 
else: - raise err - return actions_taken + # Execute additive changes by level + for level in reversed(range(max_level + 1)): + commands_at_level = [c for c in additive_commands if levels.get(c["change"].urn, -1) == level] + for role in roles: + # Execute additive changes in current level by role + commands_at_role_level = [c for c in commands_at_level if c["role"] == role] + if commands_at_role_level: + logger.debug(f"Executing level {level} role {role} with {len(commands_at_role_level)} changes") + execute_commands_in_parallel(commands_at_role_level) + + # Execute destructive changes + execute_commands_in_parallel(destructive_commands) def _add(self, resource: Resource): if self._finalized: @@ -1059,7 +1111,7 @@ def execution_strategy_for_change( if isinstance(change, CreateResource) and change.urn.resource_type == ResourceType.GRANT: execution_role = system_role_for_priv(change.after["priv"]) if execution_role and execution_role in available_roles: - return execution_role, False + return ResourceName(execution_role), False if "SECURITYADMIN" in available_roles: return ResourceName("SECURITYADMIN"), False @@ -1110,7 +1162,7 @@ def execution_strategy_for_change( system_role = system_role_for_priv(create_priv) if system_role and system_role in available_roles: transfer_ownership = system_role != change_owner - return system_role, transfer_ownership + return ResourceName(system_role), transfer_ownership raise MissingPrivilegeException(f"{system_role} isnt available to execute {change}") elif isinstance(change.resource_cls.scope, (DatabaseScope, SchemaScope)) and change.container: container_owner = ResourceName(change.container[1]) @@ -1128,7 +1180,7 @@ def sql_commands_for_change( change: ResourceChange, available_roles: list[ResourceName], default_role: ResourceName, -): +) -> tuple[ResourceName, list[str]]: """ In Snowflake's RBAC model, a session has an active role, and zero or more secondary roles. 
@@ -1159,7 +1211,6 @@ def sql_commands_for_change( available_roles, default_role, ) - before_change_cmd.append(f"USE ROLE {execution_role}") if isinstance(change, CreateResource): @@ -1214,88 +1265,63 @@ def sql_commands_for_change( copy_current_grants=True, ) - return before_change_cmd + [change_cmd] + after_change_cmd + return execution_role, before_change_cmd + [change_cmd] + after_change_cmd -def compile_plan_to_sql(session_ctx: SessionContext, plan: Plan): - sql_commands = [] - - sql_commands.append("USE SECONDARY ROLES ALL") +def compile_plan_to_sql(session_ctx: SessionContext, plan: Plan) -> list[dict]: + """Compile the plan into a list of SQL command lists, one per change.""" + sql_commands_per_change = [] available_roles = session_ctx["available_roles"].copy() default_role = session_ctx["role"] for change in plan: - # Generate SQL commands - commands = sql_commands_for_change( - change, - available_roles, - default_role, - ) - sql_commands.extend(commands) - + role, commands = sql_commands_for_change(change, available_roles, default_role) + sql_commands_per_change.append({"role": role, "commands": commands, "change": change}) if isinstance(change, CreateResource): if change.urn.resource_type == ResourceType.ROLE: available_roles.append(ResourceName(change.after["name"])) elif change.urn.resource_type == ResourceType.ROLE_GRANT: if change.after["to_role"] in available_roles: available_roles.append(ResourceName(change.after["role"])) - - return sql_commands - - -def topological_sort(resource_set: set[T], references: set[tuple[T, T]]) -> dict[T, int]: - # Kahn's algorithm - - # Compute in-degree (# of inbound edges) for each node - in_degrees: dict[T, int] = {} - outgoing_edges: dict[T, set[T]] = {} - - for node in resource_set: - in_degrees[node] = 0 - outgoing_edges[node] = set() - - for node, ref in references: - in_degrees[ref] += 1 - outgoing_edges[node].add(ref) - - # Put all nodes with 0 in-degree in a queue - queue: Queue = Queue() - for node, 
in_degree in in_degrees.items(): - if in_degree == 0: - queue.put(node) - - # Create an empty node list - nodes = [] - - while not queue.empty(): - node = queue.get() - nodes.append(node) - - # For each of node's outgoing edges - empty_neighbors = set() - for edge in outgoing_edges[node]: - in_degrees[edge] -= 1 - if in_degrees[edge] == 0: - queue.put(edge) - empty_neighbors.add(edge) - - # Remove edges to empty neighbors - outgoing_edges[node].difference_update(empty_neighbors) - nodes.reverse() - if len(nodes) != len(resource_set): - raise NotADAGException("Graph is not a DAG") - return {value: index for index, value in enumerate(nodes)} - - -def diff(remote_state: State, manifest: Manifest): - - def _container_descriptor(resource_urn: URN) -> Optional[ContainerDescriptor]: + return sql_commands_per_change + + +def compute_levels(resource_set: Set[URN], references: Set[tuple[URN, URN]]) -> dict[URN, int]: + """Compute the dependency level for each URN based on references.""" + in_degrees = {urn: 0 for urn in resource_set} + for _, ref in references: + if ref in in_degrees: + in_degrees[ref] += 1 + levels = {} + queue = [urn for urn in resource_set if in_degrees[urn] == 0] + current_level = 0 + while queue: + next_queue = [] + for urn in queue: + levels[urn] = current_level + for parent, ref in references: + # if the parent is the current node and the ref is in the set of resources + if parent == urn and ref in in_degrees: + in_degrees[ref] -= 1 + if in_degrees[ref] == 0: + next_queue.append(ref) + queue = next_queue + current_level += 1 + if len(levels) != len(resource_set): + raise NotADAGException("Dependency graph contains cycles") + return levels + + +def diff(remote_state: State, manifest: Manifest) -> list: + """Compute the differences between remote state and manifest""" + + def _container_descriptor(urn: URN) -> Optional[ContainerDescriptor]: """ Given the URN of a resource, return a descriptor of the container that owns it. 
""" - if isinstance(RESOURCE_SCOPES[resource_urn.resource_type], AccountScope): + if isinstance(RESOURCE_SCOPES[urn.resource_type], AccountScope): return None - container_urn = _container_urn(resource_urn) + container_urn = _container_urn(urn) if container_urn in remote_state: if "owner" in remote_state[container_urn]: container_owner = remote_state[container_urn]["owner"] @@ -1313,10 +1339,6 @@ def _container_descriptor(resource_urn: URN) -> Optional[ContainerDescriptor]: return (container_urn, container_owner) def _diff_resource_data(lhs: dict, rhs: dict) -> dict: - - if not isinstance(lhs, dict) or not isinstance(rhs, dict): - raise TypeError("diff_resources requires two dictionaries") - delta = {} for field_name in lhs.keys(): lhs_value = lhs[field_name] @@ -1325,48 +1347,41 @@ def _diff_resource_data(lhs: dict, rhs: dict) -> dict: delta[field_name] = rhs_value return delta + changes = [] state_urns = set(remote_state.keys()) manifest_urns = set(manifest.urns) # Resources in remote state but not in the manifest should be removed for urn in state_urns - manifest_urns: - yield DropResource(urn, remote_state[urn]) + changes.append(DropResource(urn, remote_state[urn])) # Resources in the manifest but not in remote state should be added for urn in manifest_urns - state_urns: manifest_item = manifest[urn] if isinstance(manifest_item, ResourcePointer): - raise MissingResourceException( - f"Blueprint has pointer to resource that doesn't exist or isn't visible in session: {urn}" - ) - elif isinstance(manifest_item, ManifestResource): - # We don't create implicit resources - if manifest[urn].implicit: - continue - yield CreateResource( - urn, - manifest_item.resource_cls, - _container_descriptor(urn), - manifest_item.data, + raise MissingResourceException(f"Missing resource: {urn}") + elif isinstance(manifest_item, ManifestResource) and not manifest_item.implicit: + changes.append( + CreateResource( + urn, + manifest_item.resource_cls, + _container_descriptor(urn), + 
manifest_item.data, + ) ) - else: - raise Exception(f"Unknown type in manifest: {manifest_item}") # Resources in both should be compared for urn in state_urns & manifest_urns: manifest_item = manifest[urn] - - # We don't diff resource pointers if isinstance(manifest_item, ResourcePointer): continue - delta = _diff_resource_data(remote_state[urn], manifest_item.data) owner_attr = delta.pop("owner", None) - # TODO: do we care about implicit resources? replace_resource = False create_resource = False ignore_fields = set() + for attr in delta.keys(): attr_metadata = manifest_item.resource_cls.spec.get_metadata(attr) change_requires_replacement = attr_metadata.triggers_replacement @@ -1389,27 +1404,28 @@ def _diff_resource_data(lhs: dict, rhs: dict) -> dict: if replace_resource: raise NotImplementedError("replace_resource") - # yield DropResource(urn, remote_state[urn]) - # yield CreateResource(urn, manifest_item.resource_cls, manifest_item.data) - # continue if create_resource: - yield CreateResource( - urn, - manifest_item.resource_cls, - _container_descriptor(urn), - manifest_item.data, + changes.append( + CreateResource( + urn, + manifest_item.resource_cls, + _container_descriptor(urn), + manifest_item.data, + ) ) continue delta = {k: v for k, v in delta.items() if k not in ignore_fields} if delta: - yield UpdateResource( - urn, - manifest_item.resource_cls, - remote_state[urn], - manifest_item.data, - delta, + changes.append( + UpdateResource( + urn, + manifest_item.resource_cls, + remote_state[urn], + manifest_item.data, + delta, + ) ) # Force transfers to occur after all other attribute changes @@ -1423,16 +1439,20 @@ def _diff_resource_data(lhs: dict, rhs: dict) -> dict: if not owner_is_fetchable or owner_changes_should_be_ignored: continue - yield TransferOwnership( - urn, - manifest_item.resource_cls, - from_owner=remote_state[urn]["owner"], - to_owner=manifest_item.data["owner"], + changes.append( + TransferOwnership( + urn, + manifest_item.resource_cls, + 
remote_state[urn]["owner"], + manifest_item.data["owner"], + ) ) + return changes + def _sort_destructive_changes( - destructive_changes: list[ResourceChange], sort_order: dict[URN, int] + destructive_changes: Sequence[ResourceChange], sort_order: dict[URN, int] ) -> list[ResourceChange]: # Not quite right but close enough for now. def sort_key(change: ResourceChange) -> tuple: diff --git a/titan/blueprint_config.py index bc68462..6318e75 100644 --- a/titan/blueprint_config.py +++ b/titan/blueprint_config.py @@ -30,6 +30,7 @@ class BlueprintConfig: scope: Optional[BlueprintScope] = None database: Optional[ResourceName] = None schema: Optional[ResourceName] = None + threads: int = 8 def __post_init__(self): @@ -102,6 +103,10 @@ def __post_init__(self): raise ValueError( f"Cannot specify a database or schema when using ACCOUNT scope (database={repr(self.database)}, schema={repr(self.schema)})" ) + if not isinstance(self.threads, int) or isinstance(self.threads, bool): + raise ValueError(f"Threads must be an integer, got: {self.threads!r}") + if self.threads < 1: + raise ValueError(f"Threads must be at least 1, got: {self.threads}") def set_vars_defaults(vars_spec: list[dict], vars: dict) -> dict: