Skip to content

Commit 046aac0

Browse files
cms42CloudRipple
and committed
refactor: update DacVAE Architecture, Configuration, ComponentLoader and revert necessary changes (sgl-project#20)
* refactor: enhance KL divergence method in DiagonalGaussianDistribution for flexible dimension handling and clean up DAC class by removing unused code * refactor: update DacVAE architecture, configuration and its customized loader. * Revert "fix: update adjust_frames parameter to False for improved multi-GPU compatibility" * revert changes in base pipeline configs * revert changes in configs/sample/__init__.py * [Feature] Remove weight norm in DAC * [Fix] Use legacy weight norm, which can be removed * [Fix] remove weight norm at the right place * [Chore] update test script * Revert "[Fix] remove weight norm at the right place" This reverts commit 3a0accbae41650e926c5828025323a12454827a4. * Revert "[Fix] Use legacy weight norm, which can be removed" This reverts commit eb93f20f134888adba4a5124fa1d167b93d180e7. * Revert "[Feature] Remove weight norm in DAC" This reverts commit aaa64abbc25112a706bf3d3604ffeac390a1d8a8. * [Feature] Remove all weight norm from DAC modeling --------- Co-authored-by: CloudRipple <[email protected]>
1 parent 32dbb5d commit 046aac0

8 files changed

Lines changed: 324 additions & 792 deletions

File tree

python/sglang/multimodal_gen/configs/models/vaes/dac.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,29 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from dataclasses import dataclass, field
5+
from typing import List
56

6-
from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig
7+
from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig
78

89

910
@dataclass
10-
class DacVAEArchConfig(VAEArchConfig):
11-
sample_rate: int = 44100
12-
hop_length: int = 2048
11+
class DacVAEArchConfig(ArchConfig):
12+
codebook_dim: int = 8
13+
codebook_size: int = 1024
14+
continuous: bool = True
15+
decoder_dim: int = 2048
16+
decoder_rates: List[int] = field(default_factory=lambda: [8, 5, 4, 3, 2])
17+
encoder_dim: int = 128
18+
encoder_rates: List[int] = field(default_factory=lambda: [2, 3, 4, 5, 8])
19+
hop_length: int = 3840
1320
latent_dim: int = 128
21+
n_codebooks: int = 9
22+
quantizer_dropout: bool = False
23+
sample_rate: int = 48000
1424

1525

1626
@dataclass
17-
class DacVAEConfig(VAEConfig):
18-
arch_config: VAEArchConfig = field(default_factory=DacVAEArchConfig)
19-
load_encoder: bool = False
27+
class DacVAEConfig(ModelConfig):
28+
arch_config: DacVAEArchConfig = field(default_factory=DacVAEArchConfig)
29+
load_encoder: bool = True
2030
load_decoder: bool = True

python/sglang/multimodal_gen/configs/pipeline_configs/base.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -536,29 +536,9 @@ def from_kwargs(
536536
)
537537
from sglang.multimodal_gen.registry import get_pipeline_config_classes
538538

539-
pipeline_config_cls = None
540-
541-
# If users explicitly specify a pipeline class name, try to resolve
542-
# pipeline config classes directly without relying on model_index.json.
543-
if pipeline_class_name:
544-
config_classes = get_pipeline_config_classes(pipeline_class_name)
545-
if config_classes is not None:
546-
pipeline_config_cls, _ = config_classes
547-
logger.info(
548-
"Using pipeline_class_name '%s' to resolve PipelineConfig: %s",
549-
pipeline_class_name,
550-
pipeline_config_cls.__name__,
551-
)
552-
else:
553-
logger.warning(
554-
"pipeline_class_name '%s' not found in pipeline config registry; "
555-
"falling back to model auto-detection.",
556-
pipeline_class_name,
557-
)
558-
559539
# If model_path is a safetensors file and pipeline_class_name is specified,
560540
# try to get PipelineConfig from the registry first
561-
if pipeline_config_cls is None and is_safetensors_file and pipeline_class_name:
541+
if is_safetensors_file and pipeline_class_name:
562542
config_classes = get_pipeline_config_classes(pipeline_class_name)
563543
if config_classes is not None:
564544
pipeline_config_cls, _ = config_classes
@@ -582,7 +562,7 @@ def from_kwargs(
582562
f"Available pipelines with config classes: {available_pipelines}"
583563
)
584564
pipeline_config_cls = model_info.pipeline_config_cls
585-
elif pipeline_config_cls is None:
565+
else:
586566
model_info = get_model_info(model_path)
587567
if model_info is None:
588568
raise ValueError(

python/sglang/multimodal_gen/configs/sample/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from sglang.multimodal_gen.configs.sample.diffusers_generic import (
44
DiffusersGenericSamplingParams,
55
)
6-
from sglang.multimodal_gen.configs.sample.mova import MovaSamplingParams
76
from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams
87

9-
__all__ = ["SamplingParams", "DiffusersGenericSamplingParams", "MovaSamplingParams"]
8+
__all__ = ["SamplingParams", "DiffusersGenericSamplingParams"]

python/sglang/multimodal_gen/configs/sample/sampling_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class SamplingParams:
148148
# if True, disallow user params to override subclass-defined protected fields
149149
no_override_protected_fields: bool = False
150150
# whether to adjust num_frames for multi-GPU friendly splitting (default: True)
151-
adjust_frames: bool = False
151+
adjust_frames: bool = True
152152

153153
def _set_output_file_ext(self):
154154
# add extension if needed

python/sglang/multimodal_gen/runtime/loader/component_loader.py

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -704,40 +704,67 @@ def should_offload(
704704
return server_args.vae_cpu_offload
705705

706706
def load_customized(
707-
self, component_model_path: str, server_args: ServerArgs, module_name: str
707+
self,
708+
component_model_path: str,
709+
server_args: ServerArgs,
710+
module_name: str | None = None,
708711
):
709-
from sglang.multimodal_gen.runtime.models.vaes.dac import DAC
710712

711-
server_args.model_paths[module_name] = component_model_path
712-
# Prefer diffusers-style directory if present
713-
config_path = os.path.join(component_model_path, "config.json")
714-
if os.path.isfile(config_path):
715-
audio_vae = DAC.from_pretrained(component_model_path)
716-
return audio_vae.eval()
717-
718-
# Fallback: load from a single checkpoint file
719-
if os.path.isfile(component_model_path):
720-
if component_model_path.endswith(".dac"):
721-
return DAC.load(component_model_path).eval()
722-
state_dict = torch.load(component_model_path, map_location="cpu")
723-
audio_vae = DAC()
724-
audio_vae.load_state_dict(state_dict, strict=False)
725-
return audio_vae.eval()
726-
727-
# Attempt to load any supported file in directory
728-
for candidate in ("model.safetensors", "pytorch_model.bin", "model.pth"):
729-
candidate_path = os.path.join(component_model_path, candidate)
730-
if os.path.isfile(candidate_path):
731-
if candidate_path.endswith(".safetensors"):
732-
state_dict = safetensors_load_file(candidate_path)
733-
else:
734-
state_dict = torch.load(candidate_path, map_location="cpu")
735-
audio_vae = DAC()
736-
audio_vae.load_state_dict(state_dict, strict=False)
737-
return audio_vae.eval()
738-
raise FileNotFoundError(
739-
f"Cannot locate audio VAE weights in {component_model_path}"
740-
)
713+
config = get_diffusers_component_config(model_path=component_model_path)
714+
class_name = config.pop("_class_name", None)
715+
assert (
716+
class_name is not None
717+
), "Model config does not contain a _class_name attribute. Only diffusers format is supported."
718+
719+
module_key = "audio_vae"
720+
if module_name in ("audio_vae",):
721+
module_key = module_name
722+
server_args.model_paths[module_key] = component_model_path
723+
logger.info("HF model config: %s", config)
724+
725+
audio_vae_config = server_args.pipeline_config.audio_vae_config
726+
audio_vae_config.update_model_arch(config)
727+
728+
should_offload = self.should_offload(server_args)
729+
target_device = self.target_device(should_offload)
730+
731+
# Check for auto_map first (custom VAE classes)
732+
auto_map = config.get("auto_map", {})
733+
auto_model_map = auto_map.get("AutoModel")
734+
if auto_model_map:
735+
module_path, cls_name = auto_model_map.rsplit(".", 1)
736+
custom_module_file = os.path.join(component_model_path, f"{module_path}.py")
737+
spec = importlib.util.spec_from_file_location("_custom", custom_module_file)
738+
custom_module = importlib.util.module_from_spec(spec)
739+
spec.loader.exec_module(custom_module)
740+
vae_cls = getattr(custom_module, cls_name)
741+
vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
742+
with set_default_torch_dtype(vae_dtype):
743+
vae = vae_cls.from_pretrained(
744+
component_model_path,
745+
revision=server_args.revision,
746+
trust_remote_code=server_args.trust_remote_code,
747+
)
748+
vae = vae.to(device=target_device, dtype=vae_dtype)
749+
return vae.eval()
750+
751+
# Load from ModelRegistry (standard VAE classes)
752+
with (
753+
set_default_torch_dtype(
754+
PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
755+
),
756+
skip_init_modules(),
757+
):
758+
audio_vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
759+
audio_vae = audio_vae_cls(audio_vae_config).to(target_device)
760+
761+
safetensors_list = _list_safetensors_files(component_model_path)
762+
assert (
763+
len(safetensors_list) == 1
764+
), f"Found {len(safetensors_list)} safetensors files in {component_model_path}"
765+
loaded = safetensors_load_file(safetensors_list[0])
766+
audio_vae.load_state_dict(loaded, strict=False)
767+
return audio_vae.eval()
741768

742769

743770
class MovaDiTLoader(ComponentLoader):

python/sglang/multimodal_gen/runtime/models/vaes/common.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -611,15 +611,17 @@ def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
611611
return x
612612

613613
def kl(
614-
self, other: Optional["DiagonalGaussianDistribution"] = None
614+
self,
615+
other: Optional["DiagonalGaussianDistribution"] = None,
616+
dims: tuple[int, ...] = (1, 2, 3),
615617
) -> torch.Tensor:
616618
if self.deterministic:
617619
return torch.Tensor([0.0])
618620
else:
619621
if other is None:
620622
return 0.5 * torch.sum(
621623
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
622-
dim=[1, 2, 3],
624+
dim=dims,
623625
)
624626
else:
625627
return 0.5 * torch.sum(
@@ -628,7 +630,7 @@ def kl(
628630
- 1.0
629631
- self.logvar
630632
+ other.logvar,
631-
dim=[1, 2, 3],
633+
dim=dims,
632634
)
633635

634636
def nll(

0 commit comments

Comments (0)