Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 401b3177da
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e915f6e0a2
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| @abstractmethod | ||
| def default_generation_config(self) -> LatentTTSGenerationConfig: ... |
There was a problem hiding this comment.
i think that's a nice abstraction, maybe let's have the same for normal TTS too?
| def create_tokenizer(self, model_path: Path | str) -> Tokenizer: | ||
| return Tokenizer.from_file(str(Path(model_path) / "tokenizer.json")) |
There was a problem hiding this comment.
this is a nontrivial decision, write a comment explaining why we need this (some private models need to override create_tokenizer)? i think?
| def create_message_processor(self, config: TTSMessageProcessorConfig, tokenizer: Tokenizer) -> TTSMessageProcessor: | ||
| return TTSMessageProcessor(config, tokenizer) |
There was a problem hiding this comment.
no, wait, this is what needs to be overridden. Do we really need a create_tokenizer method?
| Current class is reserved for future usage of audio prompts | ||
| to condition style of generated audio | ||
| """ | ||
| waveform: Float[Array, "*"] |
There was a problem hiding this comment.
waveform: Float[Array, " audio_samples"]
| def import_weights(self, weights: ParameterTree[Array]) -> Self: | ||
| weights = require_mapping(weights) | ||
| block_weights = weights["decoder_blocks"] | ||
| assert isinstance(block_weights, Sequence) |
| d_out = spatial_params.d_out | ||
|
|
||
| num_keys = 2 + len(rates) | ||
| keys = jax.random.split(key, num_keys) |
There was a problem hiding this comment.
let's maybe do this as first_conv_key, final_conv_key, *decoder_keys =
| input_channel = spatial_params.input_channel | ||
| channels = spatial_params.channels | ||
| rates = spatial_params.rates | ||
| d_out = spatial_params.d_out |
There was a problem hiding this comment.
one letter variable name is meh
| mlp_ratio: float = 4.0 | ||
| kernel_size: int = 7 | ||
| dilation: int = 1 |
There was a problem hiding this comment.
remove the defaults
| y = self.act2(y) | ||
| y = self.conv2(y) | ||
|
|
||
| pad = x.shape[1] - y.shape[1] |
There was a problem hiding this comment.
please use shape unpacking instead of direct indexing
| class SnakeBetaConfig: | ||
| precision: DTypeLike | ||
| alpha_init: float = 1.0 | ||
| no_div_by_zero: float = 1e-9 |
There was a problem hiding this comment.
that's a weird var name, call it maybe eps or something like that?
| output = self._call_causal(x) | ||
| case Conv1dPadding.SYMMETRIC: | ||
| output = self._call_symmetric(x) | ||
|
|
There was a problem hiding this comment.
hard raise on anything else
| ) -> Float[Array, "batch sequence_out out_channels"]: | ||
| length = x.shape[1] # sequence dimension is axis 1 | ||
| pad = self.padding | ||
| length = x.shape[1] |
There was a problem hiding this comment.
let's not use direct indexing, unpack the shape tuple
lalamo/modules/audio/__init__.py
Outdated
| @@ -1 +1,13 @@ | |||
| # TODO @peter.glushkov: think carefully what to export once audio submodule is more stable | |||
There was a problem hiding this comment.
reassign todo to knyazer
| key: PRNGKeyArray, | ||
| ) -> Int[Array, " batch"]: | ||
| processed_logits = vmap(sampling_policy.process_logits)(logits) | ||
| sample_keys = jax.random.split(key, logits.shape[0]) |
There was a problem hiding this comment.
please don't index into shapes directly, use tuple unpacking
| @classmethod | ||
| def format_instruction(cls, style: str) -> str: | ||
| return f"<|im_start|>user\n{style}<|im_end|>\n" |
There was a problem hiding this comment.
errr doesn't this function contradict default definition of how formatting is supposed to work? the hardcoded prompt, no?
| dtype=jnp.int32, | ||
| ) | ||
| special_hidden = self._project_text_embeddings(special_text_tokens) | ||
| tts_bos_embed, tts_eos_embed, tts_pad_embed = jnp.split(special_hidden, 3, axis=1) |
There was a problem hiding this comment.
this seems to be very fragile, can't we just tokenize the whole formatted prompt directly instead of doing this surgery on raw embeddings
| key: PRNGKeyArray, | ||
| ) -> "VectorQuantization": | ||
| key_codebook, key_project = jax.random.split(key) | ||
| codebook_dim = dim if codebook_dim is None else codebook_dim |
There was a problem hiding this comment.
is codebook_dim ever actually None? check — if not, update this line and the type annotation
| def export_weights(self) -> ParameterTree[Array]: | ||
| project_out_weights: ParameterTree[Array] | ||
| if self.project_out is None: | ||
| project_out_weights = {} |
There was a problem hiding this comment.
follow convention in other files, first of all use assert require_mapping(project_out_weights), then use project_out_weights or None when exporting
lalamo/model_import/common.py
Outdated
| weights_paths = origin.resolve_weights(progress_callback) | ||
| config_path = origin.resolve_file(model_spec.configs.model_config, progress_callback) | ||
| extra_config_paths = tuple(origin.resolve_file(ec, progress_callback) for ec in model_spec.configs.extra_configs) |
There was a problem hiding this comment.
should we move configs under Origin?
| ) | ||
|
|
||
|
|
||
| class Origin(RegistryABC): |
| @property | ||
| @abstractmethod | ||
| def description(self) -> str: ... | ||
|
|
There was a problem hiding this comment.
uv run lalamo convert mars8nano --path blahblah
ModelSpec(
origin: Mars8NanoOrigin("../weights.pth", TORCH)
)
uv run lalamo convert mars8nano --custom-origin {origin: local, blahblah}
uv run lalamo convert mars8nano
each origin wants a list of keys that it wants to get to be resolved, okay
cli args back/forth with arbitrary
import models gets stuff from kwargs
cli origin? ENV origin?
No description provided.