|
21 | 21 | PACKAGE_LOCK_FILE_NAME, |
22 | 22 | PACKAGE_LOCK_HASH_KEY, |
23 | 23 | PACKAGES_FILE_NAME, |
| 24 | + VARS_FILE_NAME, |
24 | 25 | ) |
25 | 26 | from dbt.contracts.project import PackageConfig |
26 | 27 | from dbt.contracts.project import Project as ProjectContract |
@@ -108,6 +109,31 @@ def load_yml_dict(file_path): |
108 | 109 | return ret |
109 | 110 |
|
110 | 111 |
|
| 112 | +def vars_data_from_root(project_root: str) -> Dict[str, Any]: |
| 113 | + """Load vars from vars.yml file if it exists. |
| 114 | +
|
| 115 | + Returns the contents of the 'vars' key, or empty dict if file doesn't exist or has no vars key. |
| 116 | + """ |
| 117 | + vars_yml_path = os.path.join(project_root, VARS_FILE_NAME) |
| 118 | + vars_file_dict = load_yml_dict(vars_yml_path) |
| 119 | + if not vars_file_dict: |
| 120 | + return {} |
| 121 | + return vars_file_dict.get("vars", {}) |
| 122 | + |
| 123 | + |
| 124 | +def validate_vars_not_in_both( |
| 125 | + project_dict: Dict[str, Any], |
| 126 | + has_vars_file: bool, |
| 127 | +) -> None: |
| 128 | + """Raise error if vars defined in both vars.yml and dbt_project.yml.""" |
| 129 | + has_project_vars = "vars" in project_dict and project_dict["vars"] |
| 130 | + |
| 131 | + if has_vars_file and has_project_vars: |
| 132 | + raise DbtProjectError( |
| 133 | + f"Variables cannot be defined in both {VARS_FILE_NAME} and {DBT_PROJECT_FILE_NAME}. " |
| 134 | + ) |
| 135 | + |
| 136 | + |
111 | 137 | def package_and_project_data_from_root(project_root): |
112 | 138 | packages_yml_dict = load_yml_dict(f"{project_root}/{PACKAGES_FILE_NAME}") |
113 | 139 | dependencies_yml_dict = load_yml_dict(f"{project_root}/{DEPENDENCIES_FILE_NAME}") |
@@ -338,10 +364,14 @@ def get_rendered( |
338 | 364 | ) |
339 | 365 |
|
340 | 366 | # Called by Project.from_project_root which first calls PartialProject.from_project_root |
341 | | - def render(self, renderer: DbtProjectYamlRenderer) -> "Project": |
| 367 | + def render( |
| 368 | + self, |
| 369 | + renderer: DbtProjectYamlRenderer, |
| 370 | + vars_from_file: Optional[Dict[str, Any]] = None, |
| 371 | + ) -> "Project": |
342 | 372 | try: |
343 | 373 | rendered = self.get_rendered(renderer) |
344 | | - return self.create_project(rendered) |
| 374 | + return self.create_project(rendered, vars_from_file=vars_from_file) |
345 | 375 | except DbtProjectError as exc: |
346 | 376 | if exc.path is None: |
347 | 377 | exc.path = os.path.join(self.project_root, DBT_PROJECT_FILE_NAME) |
@@ -376,7 +406,11 @@ def check_config_path( |
376 | 406 | kwargs.update({"exp_path": expected_path}) |
377 | 407 | deprecations.warn(f"project-config-{deprecated_path}", **kwargs) |
378 | 408 |
|
379 | | - def create_project(self, rendered: RenderComponents) -> "Project": |
| 409 | + def create_project( |
| 410 | + self, |
| 411 | + rendered: RenderComponents, |
| 412 | + vars_from_file: Optional[Dict[str, Any]] = None, |
| 413 | + ) -> "Project": |
380 | 414 | unrendered = RenderComponents( |
381 | 415 | project_dict=self.project_dict, |
382 | 416 | packages_dict=self.packages_dict, |
@@ -485,10 +519,13 @@ def create_project(self, rendered: RenderComponents) -> "Project": |
485 | 519 | saved_queries = cfg.saved_queries |
486 | 520 | exposures = cfg.exposures |
487 | 521 | functions = cfg.functions |
488 | | - if cfg.vars is None: |
489 | | - vars_dict: Dict[str, Any] = {} |
| 522 | + |
| 523 | + # Use vars from vars.yml if provided, otherwise use vars from dbt_project.yml |
| 524 | + # Mutual exclusivity ensures only one source is populated |
| 525 | + if vars_from_file: |
| 526 | + vars_dict = vars_from_file |
490 | 527 | else: |
491 | | - vars_dict = cfg.vars |
| 528 | + vars_dict = cfg.vars or {} |
492 | 529 |
|
493 | 530 | vars_value = VarProvider(vars_dict) |
494 | 531 | # There will never be any project_env_vars when it's first created |
@@ -557,6 +594,7 @@ def create_project(self, rendered: RenderComponents) -> "Project": |
557 | 594 | restrict_access=cfg.restrict_access, |
558 | 595 | dbt_cloud=dbt_cloud, |
559 | 596 | flags=flags, |
| 597 | + vars_from_file=vars_from_file or {}, |
560 | 598 | ) |
561 | 599 | # sanity check - this means an internal issue |
562 | 600 | project.validate() |
@@ -676,6 +714,7 @@ class Project: |
676 | 714 | restrict_access: bool |
677 | 715 | dbt_cloud: Dict[str, Any] |
678 | 716 | flags: Dict[str, Any] |
| 717 | + vars_from_file: Dict[str, Any] |
679 | 718 |
|
680 | 719 | @property |
681 | 720 | def all_source_paths(self) -> List[str]: |
@@ -786,11 +825,16 @@ def from_project_root( |
786 | 825 | *, |
787 | 826 | verify_version: bool = False, |
788 | 827 | validate: bool = False, |
| 828 | + vars_from_file: Optional[Dict[str, Any]] = None, |
789 | 829 | ) -> "Project": |
790 | 830 | partial = PartialProject.from_project_root( |
791 | 831 | project_root, verify_version=verify_version, validate=validate |
792 | 832 | ) |
793 | | - return partial.render(renderer) |
| 833 | + |
| 834 | + # Check mutual exclusivity before rendering |
| 835 | + validate_vars_not_in_both(partial.project_dict, bool(vars_from_file)) |
| 836 | + |
| 837 | + return partial.render(renderer, vars_from_file=vars_from_file) |
794 | 838 |
|
795 | 839 | def hashed_name(self): |
796 | 840 | return md5(self.project_name) |
|
0 commit comments