Conversation
This reverts commit 8d1113d.
This reverts commit da2713f.
Add xshapes parameter to gather() function to validate gradient tensor shapes during the transition period, rejecting sharded uploads that don't match expected model dimensions.

- Add xshapes parameter to gather() function signature
- Validate vals tensor shape prefix against expected xshapes
- Log warning and reject responses with shape mismatches
- Update miner.py to pass xshapes=self.xshapes to gather
- Update validator.py to pass xshapes=self.xshapes to gather

This prevents sharded gradients from being accepted during the transition period before full shard support is enabled.
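The shape check described above can be sketched as follows; `vals_shape_ok` is a hypothetical stand-in for the comparison inside `gather()`, assuming compression keeps every dimension of the parameter shape except the last (top-k) one.

```python
# Hypothetical sketch of the xshapes prefix check: all dimensions of the
# received vals tensor except the last (top-k) one must match the expected
# unsharded parameter shape, so tensor-parallel shards are rejected.
def vals_shape_ok(vals_shape, expected_xshape):
    """Return True if every dimension except the last matches."""
    return tuple(vals_shape[:-1]) == tuple(expected_xshape[:-1])

# Full gradient for a (1024, 4096) weight compressed to top-k=32 values:
assert vals_shape_ok((1024, 32), (1024, 4096))
# A shard holding only half of the rows fails the prefix match:
assert not vals_shape_ok((512, 32), (1024, 4096))
```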
Walkthrough

This PR removes tensor-parallel (TP/DTensor) handling across the codebase, unifies cross-rank reductions to DDP-based ops, simplifies gradient reconstruction/metadata, adds shape validation to comms.gather, and trims several TP-specific utilities and scripts. No public API signatures were broadly expanded.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 1 inconclusive)
✨ Finishing touches
📜 Recent review details

Configuration used: Organization UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (2)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Codecov Report

❌ Patch coverage is 81.01%.

❌ Your patch status has failed because the patch coverage (81.01%) is below the target coverage (85.00%). You can increase the patch coverage or adjust the target coverage.

@@ Coverage Diff @@
## dev #671 +/- ##
==========================================
+ Coverage 56.45% 57.72% +1.27%
==========================================
Files 27 27
Lines 5165 4975 -190
==========================================
- Hits 2916 2872 -44
+ Misses 2249 2103 -146
Actionable comments posted: 3
In @neurons/miner.py:
- Around line 815-843: The debug slice for model parameter samples currently
uses indices [10:12] when populating debug_dict (entries named name + "_debug"),
but the comparer (compare_model_with_debug_dict) and catchup code expect the
first two elements (index_range=(0,2)); update the population logic in the loop
inside Miner (where debug_dict is filled for self.model.named_parameters()) to
use the same index_range (slice [0:2] or parameterize it via index_range) so the
debug samples and compare_model_with_debug_dict use the same indices and the
avg_steps_behind comparison can succeed.
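A minimal sketch of the fix described above, using hypothetical names (`fill_debug_dict`, plain lists standing in for flattened parameters): the producer and the comparer share one `index_range` so both read the same elements.

```python
# Sketch of keeping the debug-sample slice and the comparison slice in sync.
# fill_debug_dict and compare_with_debug are illustrative, not the project's
# real APIs; the point is that both sides use index_range=(0, 2), not [10:12].
def fill_debug_dict(named_flat_params, index_range=(0, 2)):
    start, end = index_range
    return {name + "_debug": flat[start:end] for name, flat in named_flat_params}

def compare_with_debug(named_flat_params, debug_dict, index_range=(0, 2)):
    start, end = index_range
    return all(
        flat[start:end] == debug_dict[name + "_debug"]
        for name, flat in named_flat_params
    )

params = [("w1", [0.1, 0.2, 0.3, 0.4]), ("w2", [1.0, 2.0, 3.0])]
debug = fill_debug_dict(params)
assert debug["w1_debug"] == [0.1, 0.2]
assert compare_with_debug(params, debug)  # same slice on both sides
```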
In @scripts/cleanup_bucket.py:
- Around line 75-83: The cleanup script was left excluding "aggregator",
"debug", and "peers" while the app still creates those objects; update the
cleanup filter in scripts/cleanup_bucket.py to include those prefixes (add
"aggregator", "debug", and "peers" / "peers_" to the startswith tuple) so
objects produced by neurons.py (key="aggregator"), comms.py
s3_put_object()/s3_get_object() (debug key), and create_and_post_peers()
(peers_/peers) are deleted, or alternatively remove the creation of those keys
in neurons.py and comms.py to match the current exclusion—pick one approach and
make the corresponding change consistently across the cleanup filter and the
creators.
In @tests/unit/test_neurons.py:
- Around line 576-611: Tests in TestCheckUidIndexOverlap exercise a simple
pairwise helper but the production function tplr.neurons.check_uid_index_overlap
has a different signature/return used by Validator.run and slash_from_overlap;
fix by adding a new helper (e.g., _pairwise_index_overlap or
pairwise_index_overlap(uids, uid_to_indices, window=0)) that implements the
simple (uids, uid_to_indices, window) → {(uid_i, uid_j): set(indices)} contract
and update TestCheckUidIndexOverlap to import and call that helper, or
alternatively update the tests to construct a minimal neuron/gather_result and
call the existing async check_uid_index_overlap(neuron, gather_result, window,
overlap_threshold=...) so the tests validate the real production API (refer to
symbols: TestCheckUidIndexOverlap, check_uid_index_overlap, slash_from_overlap,
Validator.run, tplr.neurons).
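The simple pairwise contract the comment describes could look like the following; `pairwise_index_overlap` is the hypothetical helper name suggested above, not an existing function in `tplr.neurons`.

```python
# Sketch of the suggested pairwise helper:
# (uids, uid_to_indices) -> {(uid_i, uid_j): shared index set}.
from itertools import combinations

def pairwise_index_overlap(uids, uid_to_indices):
    return {
        (a, b): set(uid_to_indices[a]) & set(uid_to_indices[b])
        for a, b in combinations(sorted(uids), 2)
    }

uid_to_indices = {1: [0, 1, 2], 2: [2, 3], 3: [9]}
overlap = pairwise_index_overlap([1, 2, 3], uid_to_indices)
assert overlap[(1, 2)] == {2}       # uids 1 and 2 share index 2
assert overlap[(1, 3)] == set()     # no overlap between 1 and 3
```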
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
neurons/trainer.py (1)
899-975: Remove pre-unscale gradient clipping to follow the standard GradScaler pattern

The code clips gradients twice within the optimizer step:

- Line 916: before scaler.unscale_() (operates on scaled gradients)
- Line 937: after scaler.unscale_() (operates on unscaled gradients)

Clipping before unscale violates PyTorch AMP best practice and can over-constrain updates. The standard pattern is unscale → clip → step, not clip → unscale → clip → step.
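The recommended call order can be illustrated with a stub scaler (a stand-in, not PyTorch's real `GradScaler`); the point is only the sequence unscale → clip → step:

```python
# Stub illustrating the AMP call order the review recommends; StubScaler is
# an illustrative stand-in, not torch.cuda.amp.GradScaler.
class StubScaler:
    def __init__(self):
        self.calls = []
    def unscale_(self, optimizer):
        self.calls.append("unscale")   # grads return to true scale here
    def step(self, optimizer):
        self.calls.append("step")      # the real scaler skips step on inf/nan
    def update(self):
        self.calls.append("update")

scaler, optimizer = StubScaler(), object()
scaler.unscale_(optimizer)             # 1. unscale first
scaler.calls.append("clip")            # 2. clip exactly once, on unscaled grads
scaler.step(optimizer)                 # 3. then step
scaler.update()
assert scaler.calls == ["unscale", "clip", "step", "update"]
```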
Remove the first clip call at line 916 to keep only the post-unscale clipping at line 937. The distributed reductions for global_tokens_step, global_loss_step, log_loss, and accum_batch_size are correctly implemented.

neurons/validator.py (1)
3564-3740: update_model_with_gradient decompression/distribution is mostly solid, but missing/invalid param data now hard-fails

The new path (using torch.empty_like(p) as ref and self.xshapes[n]/self.totalks[n] for decompression) is correct and avoids copying parameter contents. Broadcasting has_valid_gradient before any barrier is also the right pattern for distributed error consensus. However:

- Any parameter with vals is None or quant_params is None now flips has_valid_gradient=False, causing a ValueError for that UID even if other parameters are fine.
- That means any peer omitting a parameter from its state dict (or drifting top-k layout) will have the entire gradient rejected, rather than having that parameter treated as zero-update.
If this is deliberate for the transition period, it would be safer to:
- Document this as a hard contract on the gradient schema, and/or
- Downgrade missing-per-param data to a logged skip (no update on that param) instead of raising, while still raising for genuine corruption (bad indices, NaN/Inf, shape mismatch).
Also consider guarding the self.xshapes[n]/self.totalks[n] lookups with an explicit KeyError → ValueError conversion so failures produce a clear “unknown parameter in gradient” message.

src/tplr/chain.py (1)
372-395: Early-return when active_peers is empty may leave stale eval peers

The new branch:

if not active_peers:
    logger.warning("No active peers found. Skipping update.")
    return

prevents using gather peers as a fallback, which is good. But it also skips rebuilding self.eval_peers, so any previous eval_peers dict remains in place even though active_peers is now empty. You might want to explicitly clear self.eval_peers (and possibly self.peers) in this branch to avoid evaluating or gathering from stale peer sets when the chain reports no active peers.
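One way to implement the suggestion, sketched with plain dicts (the attribute names mirror the review, but this is not the project's real code):

```python
# Sketch: clear stale peer state when the chain reports no active peers,
# instead of returning early and leaving the previous dicts in place.
def refresh_peers(state, active_peers):
    if not active_peers:
        state["eval_peers"] = {}
        state["peers"] = []
        return state
    state["eval_peers"] = {uid: 1 for uid in active_peers}
    state["peers"] = list(active_peers)
    return state

state = {"eval_peers": {7: 1, 9: 1}, "peers": [7, 9]}
refresh_peers(state, [])
assert state["eval_peers"] == {} and state["peers"] == []
```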
🧹 Nitpick comments (6)
src/tplr/comms.py (1)
98-101: Unify temp directory handling to avoid scattered /tmp paths (optional)
self.temp_dir is now /tmp/templar_{self.uid}, while put() creates a separate /tmp/{self.uid} per instance for temp files. Functionally this works, but using a single convention (e.g. always self.temp_dir) would simplify reasoning about cleanup and avoid leaving stray empty /tmp/{uid} dirs around.

Also applies to: 1314-1318
src/tplr/sharded_sampler.py (1)
67-69: DP-only grad accumulation and rank slicing look correct; consider enforcing divisibility

Using denom = micro_bs * world_size and self._local = global_indices[self.rank :: self.world_size] correctly implements pure-DP sharding and grad accumulation. The only subtlety is that if batch_size is not an exact multiple of micro_bs * world_size, grad_accum_steps will floor-divide and you’ll get an effective global batch smaller than configured. Consider asserting or logging when batch_size % (micro_bs * world_size) != 0 so misconfigurations are caught early.

Also applies to: 75-80
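The divisibility guard could be as simple as the following sketch (hypothetical names, mirroring the review's description):

```python
# Sketch: enforce that the global batch splits exactly into micro-batches
# across ranks, and show the rank-strided index slicing the review describes.
def grad_accum_steps(batch_size, micro_bs, world_size):
    denom = micro_bs * world_size
    if batch_size % denom != 0:
        raise ValueError(
            f"batch_size={batch_size} is not a multiple of "
            f"micro_bs * world_size = {denom}"
        )
    return batch_size // denom

assert grad_accum_steps(64, 4, 4) == 4

# Pure-DP sharding: rank r takes every world_size-th index.
global_indices = list(range(8))
rank, world_size = 1, 4
local = global_indices[rank::world_size]
assert local == [1, 5]
```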
neurons/trainer.py (1)
51-80: Sampler wiring matches TP-less samplers; minor redundancy in micro_bs (optional)

set_dataloader now cleanly routes through MinerSampler/EvalSampler without TP, and the shared args look consistent. The only nit is that shared_args already contains micro_bs, so re-passing micro_bs=self.hparams.micro_batch_size in the miner kwargs is redundant; you can drop it to avoid confusion.

src/tplr/compress.py (1)
236-256: encode now hard-requires pre-populated shape_dict entries

Switching to:

n1 = self.shape_dict[x.shape[0]]
n2 = self.shape_dict[x.shape[1]]

(and similarly for 1D) removes the old dynamic “slow path” and will raise KeyError if a new dimension appears that wasn’t seen in __init__.

This is fine if ChunkingTransformer is only ever used on parameters of the passed model. If there’s any chance of calling encode on ad-hoc tensors (e.g., external grads, reshaped buffers), consider replacing direct indexing with an explicit check that raises a clearer error such as:

if x.shape[0] not in self.shape_dict:
    raise ValueError(f"Unsupported DCT size {x.shape[0]} – not in shape_dict")

so callers get a diagnostic instead of a generic KeyError.

tests/test_model_comparison.py (1)
196-279: Well-structured test for custom index_range functionality.

The test properly validates:
- Default index range (0, 2) matching behavior
- Custom index range (5, 7) with proper skip logic for small parameters
- Mismatch detection with expected difference calculation
Minor observation: The variable num_elements on lines 214 and 234 is assigned but never used. Consider removing these unused assignments.

🔎 Suggested cleanup

  for name, param in model.named_parameters():
      param_flat = param.flatten()
-     num_elements = param_flat.numel()
      debug_dict[name + "_debug"] = param_flat[:2].detach().cpu().tolist()

  for name, param in model.named_parameters():
      param_flat = param.flatten()
-     num_elements = param_flat.numel()
      # Only include parameters with enough elements for the custom range

src/tplr/model_factory.py (1)
275-298: Redundant int() coercion for tp_degree.

Line 275 already casts tp_degree to int, but line 296 casts it again. The same pattern exists for other degrees but is less obvious since their initial extraction doesn't include int(). Consider consolidating the coercion:

🔎 Suggested consolidation

- tp_degree = int(getattr(tt, "tp_degree", 1))
- pp_degree = int(getattr(tt, "pp_degree", 1))
- cp_degree = int(getattr(tt, "cp_degree", 1))
- dp_replicate = getattr(tt, "dp_replicate", 1)
- dp_shard = getattr(tt, "dp_shard", 1)
+ tp_degree = getattr(tt, "tp_degree", 1)
+ pp_degree = getattr(tt, "pp_degree", 1)
+ cp_degree = getattr(tt, "cp_degree", 1)
+ dp_replicate = getattr(tt, "dp_replicate", 1)
+ dp_shard = getattr(tt, "dp_shard", 1)

- # Ensure divisors are not zero before coercion to 1 and modulo operations
+ # Ensure divisors are not zero
  if dp_replicate == 0:
      raise ValueError("dp_replicate cannot be zero.")
  # ... other checks ...

- # Coerce to int after zero checks
+ # Coerce all to int
  dp_replicate = int(dp_replicate)
  dp_shard = int(dp_shard)
  tp_degree = int(tp_degree)
  pp_degree = int(pp_degree)
  cp_degree = int(cp_degree)
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
- .gitignore
- neurons/evaluator.py
- neurons/miner.py
- neurons/trainer.py
- neurons/validator.py
- scripts/abort_multipart_uploads.py
- scripts/cleanup_bucket.py
- src/tplr/chain.py
- src/tplr/comms.py
- src/tplr/compress.py
- src/tplr/distributed.py
- src/tplr/hparams.py
- src/tplr/model_factory.py
- src/tplr/neurons.py
- src/tplr/sharded_sampler.py
- tests/test_comms.py
- tests/test_evaluator.py
- tests/test_model_comparison.py
- tests/test_prepare_gradient_dict.py
- tests/unit/test_model_factory.py
- tests/unit/test_neurons.py
💤 Files with no reviewable changes (5)
- tests/test_evaluator.py
- src/tplr/distributed.py
- tests/test_comms.py
- .gitignore
- scripts/abort_multipart_uploads.py
🧰 Additional context used
🧬 Code graph analysis (9)
src/tplr/neurons.py (5)
src/tplr/compress.py (2)
- batch_decompress (484-580)
- decode (260-292)

tests/test_comms.py (2)
- batch_decompress (954-956)
- decode (940-942)

tests/conftest.py (1)
- totalks (141-144)

tests/test_prepare_gradient_dict.py (3)
- decode (37-38)
- decode (174-175)
- decode (268-270)

tests/test_dcp_checkpoint.py (1)
- warning (117-118)
neurons/trainer.py (1)
src/tplr/distributed.py (1)
ddp_reduce(97-131)
src/tplr/chain.py (1)
tests/test_dcp_checkpoint.py (1)
warning(117-118)
neurons/evaluator.py (1)
src/tplr/distributed.py (1)
ddp_reduce(97-131)
neurons/miner.py (3)
src/tplr/logging.py (2)
- T (41-48)
- P (51-62)

src/tplr/comms.py (1)
- put (1261-1352)

src/tplr/neurons.py (1)
- outer_step (225-476)
tests/unit/test_model_factory.py (1)
src/tplr/model_factory.py (1)
create_parallel_dims(215-332)
src/tplr/comms.py (1)
tests/test_dcp_checkpoint.py (1)
warning(117-118)
tests/unit/test_neurons.py (1)
src/tplr/neurons.py (1)
check_uid_index_overlap(1298-1441)
tests/test_model_comparison.py (1)
src/tplr/neurons.py (1)
compare_model_with_debug_dict(1194-1294)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (30)
scripts/cleanup_bucket.py (1)
86-89: LGTM! Log messages are consistently updated.

All log messages and the docstring accurately reflect the narrowed scope of the cleanup script. The messaging is clear and consistent throughout.
Also applies to: 108-110, 118-118
neurons/evaluator.py (1)
799-811: DDP reduction for custom eval metrics looks consistent

Using dist_helper.ddp_reduce(..., op=dist.ReduceOp.SUM, device=self.device) on local_loss, local_tokens, and local_bytes to form global totals is correct and matches the later average/perplexity and bpb computations.

src/tplr/hparams.py (1)
77-85: Shallow dict.update changes nested hparams override semantics

Switching from deep merges to:

- full_hparams = DEFAULT_HPARAMS.copy(); full_hparams.update(hparams) in create_namespace, and
- successive hparams.update(...) calls in load_hparams

means any nested dicts (e.g. optimizer, anneal_mode, torchtitan) from higher-priority files now replace lower-priority ones entirely, rather than merging per key. That’s fine if all override JSONs carry complete nested sections; just ensure no configs rely on “partial” nested overrides that previously depended on deep_merge behavior.

Also applies to: 145-195
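The behavioral difference can be seen with a small deep-merge sketch (`deep_merge` here is illustrative; the removed project helper may have differed in details):

```python
# Shallow update replaces nested dicts wholesale; a deep merge keeps
# per-key overrides within nested sections.
def deep_merge(base, override):
    out = dict(base)
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = deep_merge(out[k], v)
        else:
            out[k] = v
    return out

defaults = {"optimizer": {"lr": 1e-4, "betas": [0.9, 0.95]}}
override = {"optimizer": {"lr": 3e-4}}

shallow = dict(defaults)
shallow.update(override)
assert "betas" not in shallow["optimizer"]        # betas silently dropped

deep = deep_merge(defaults, override)
assert deep["optimizer"] == {"lr": 3e-4, "betas": [0.9, 0.95]}
```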
src/tplr/comms.py (1)
1606-1621: xshapes-based vals shape check should correctly filter sharded gradients

The new xshapes argument and the (vals.shape[:-1] == xshapes[base_name][:-1]) check give a cheap, robust guard against receiving sharded or otherwise mismatched gradients: assuming xshapes[base_name] holds the unsharded parameter shape and compression is along the last dimension, any TP-style shards will fail the prefix match and the UID is skipped with a clear warning. Please just confirm that:

- xshapes keys use the raw parameter names (matching param_name + "vals" / param_name + "idxs"), and
- the TopK compressor indeed produces vals with shape xshape[:-1] + [topk],

so valid gradients aren’t accidentally rejected.
Also applies to: 1650-1652, 1855-1875
neurons/trainer.py (2)
427-479: Evaluation loss and batch aggregation via ddp_reduce is coherent

Summing total_loss and n_batches across ranks with dist_helper.ddp_reduce(..., device=device) yields global totals; callers can recover a true mean as total_loss / n_batches. The guard on world_size > 1 and dist_helper.is_distributed() keeps single-GPU behavior unchanged.
591-631: Adam metrics: distributed reductions and max tracking are implemented correctly

The new block that:

- sums local squared norms and moment buffers with ddp_reduce(..., device=self.device),
- uses op=ReduceOp.MAX for local_grad_norm_max/local_update_norm_max, and
- aggregates ratio sums and counts before dividing,
produces consistent, global metrics across ranks. The subsequent square-roots/ratios use these globals correctly.
neurons/validator.py (4)
1491-1511: Passing xshapes into gather_with_reserve looks correct

Wiring xshapes=self.xshapes next to totalks=self.totalks keeps the validator-side expectations consistent with the new shape-aware gather validation in comms. Assuming self.xshapes/self.totalks are constructed for every named parameter (as in __init__), this call site looks sound.
1878-1893: Early-exit on window change during eval is reasonable

The new master-only log + break when self.current_window != eval_window prevents partially-evaluated windows without leaving other ranks behind (barriers come after the loop). This is a sensible guard and should avoid inconsistent scoring when the chain advances mid-evaluation.
4040-4055: Fixing miner digest to world_size=1 is consistent but assumes single-rank semantics

Passing world_size=1 into MinerSampler for the validator-side digest makes the reconstructed index pool independent of validator DDP/TP topology, which is desirable for matching miner behavior.

Just ensure the miner-side _training_pool_digest (or equivalent) uses the same world_size=1 convention; otherwise, digest mismatches could appear even for honest miners purely due to parallel layout differences.
2031-2068: Gradient-apply failure handling ensures group consensus by rejecting whole peer gradients

The current design explicitly invalidates the entire peer gradient if any parameter is missing or undecompressible, rather than treating missing per-parameter contributions as zeros. This is intentional: any per-parameter failure (missing vals/quant_params, decompression errors, NaN/Inf in decompressed data) raises an exception at the parameter loop level, broadcasts consensus to all ranks, then triggers a slash + full model-state restore for that UID.
Confirm this stricter schema—rejecting partially-sparse or schema-drifted gradients outright—aligns with your expected behavior, since the alternative (zero-filling missing parameters) would be more permissive.
tests/test_prepare_gradient_dict.py (1)
107-110: Metadata tests now correctly enforce a minimal {"window": step_window} contract

Both tests now assert that gradient["metadata"] is exactly {"window": step_window}. This matches the simplified metadata semantics and guards against accidentally reintroducing extra keys like xshapes/totalks into the payload.

Also applies to: 134-138
tests/unit/test_model_factory.py (1)
242-247: Parallel-dims tests correctly mirror new validation rules

The updated miner tests:

- Use torchtitan.tp_degree=2 for the custom-TP happy path.
- Assert the new error message for dp_replicate > 1 and dp_shard > 1.
- Assert the refined message for dp_replicate with non-trivial tp.
- Check that both dp_shard and tp_degree mis-divisibility raise the generic “world_size … must be divisible by the product of all parallel degrees” error.

These are consistent with create_parallel_dims in model_factory.py and should catch future regressions in the validation logic.

Also applies to: 248-281
tests/unit/test_neurons.py (2)
48-76: compare_model_with_debug_dict tests now match the 3-element slice behavior

Switching param1 to a MagicMock with .data and .dtype, and expanding the debug slices and index_range=(0, 3), lines up with the current implementation which operates on a 3-element slice. The perfect-match and mismatch cases look correct and will regress if the comparison logic changes.
476-481: Broadcast call-count expectations adjusted to new catchup behavior

The updated mock_broadcast.call_count assertions (3, 4, and 2 respectively) reflect the revised catchup_with_aggregation_server control flow (skipped windows, fallbacks, and early exits). As long as these numbers were obtained by tracing the new implementation, the tests look aligned and will help catch future changes to the broadcast protocol.

Also applies to: 521-524, 566-569
tests/test_model_comparison.py (1)
308-314: LGTM!

Explicitly passing index_range=(0, 2) in tests makes the intent clear, even though it matches the default. This improves test readability and ensures tests won't break if the default changes.
254-270: Hardcoded dp_shard=4 may cause failures for non-standard validator setups.

The validator role now requires world_size % 4 == 0. This is a breaking change for validators running with world sizes like 1, 2, 3, 5, 6, 7, etc.

Consider whether this is intentional. If validators must run with exactly 4 GPUs (or multiples), document this requirement. Otherwise, consider a more flexible approach:

🔎 Suggested alternative

  elif role == "validator":
      # Validator: pipeline parallelism with data parallel replication
-     # Ensure dp_shard is at least 1 to prevent division by zero
-     dp_shard = 4
+     # Use dp_shard that evenly divides world_size (prefer 4 if possible)
+     dp_shard = 4 if world_size >= 4 and world_size % 4 == 0 else max(1, min(4, world_size))
      if world_size % dp_shard != 0:

Is the requirement for validators to run with world_size divisible by 4 intentional? If so, this should be documented in deployment guides.
300-322: LGTM!

The validation logic is comprehensive and well-structured:

- Mutual exclusivity between dp_replicate > 1 and dp_shard > 1
- Constraint that dp_replicate > 1 requires all other parallelism degrees to be 1
- Clear error messages with the actual values for debugging
The safeguard at lines 314-317 provides defense-in-depth even though individual zero checks should prevent this case.
src/tplr/neurons.py (7)
63-69: LGTM!

Good defensive pattern for DTensor detection with a graceful fallback to string-based type checking. This reduces hard dependencies on DTensor internals.
120-131: Verify assertion behavior for DTensor parameters without gradients.

Lines 128-130 logic: if a DTensor parameter has grad=None, the code falls through to the assertion at line 130, which will fail.

This is likely intentional (DTensor params should always have grads after backward in a distributed setting), but consider whether there are edge cases where a DTensor param might legitimately have no gradient (e.g., frozen parameters, parameters not used in forward pass).

If this assertion is expected behavior, consider adding a more descriptive error message:

- assert g is not None, f"p.grad is None for {n}"
+ assert g is not None, f"p.grad is None for DTensor param {n} - all DTensor params must have gradients"
138-152: LGTM!

Clean simplification of error feedback handling:
- Initializes to zeros on first use
- Properly handles device placement
- Clear null-round vs normal momentum update logic
220-221: LGTM!

Metadata simplified to contain only window. The xshapes and totalks are now passed separately to gather() calls, which provides better separation of concerns.
353-371: LGTM!

Good memory optimization using empty_like instead of copying the parameter tensor. The decompression path is simplified by using xshapes[name] and totalks[name] directly from function parameters.
1145-1170: LGTM!

Clean refactoring to use the shared compare_model_with_debug_dict helper with proper parameter passing:

- index_range=(0, 2) for debug snippet comparison
- param_avg_change for step-size estimation
- Appropriate error handling and logging
656-658: LGTM!

Informative log message for the edge case where bootstrap start_window cannot be fetched.
neurons/miner.py (6)
207-212: Consider removing unused parallel degree attributes.
tp_degree (line 207) is stored but no longer used for ownership computation after reverting TP support. The ownership is now determined purely by idx % world_size (line 236).

If TP is not supported, consider removing these attributes or documenting that they're for future use:

self.tp_degree  # Stored but unused
self.pp_degree  # Stored but unused
self.cp_degree  # Stored but unused

Are these attributes used elsewhere in the miner, or should they be removed as part of the TP revert?
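The DP-only ownership rule mentioned above (parameter index modulo world size) can be sketched as follows; `owned_params` and the parameter names are illustrative:

```python
# Sketch of pure-DP parameter ownership: rank r owns every parameter whose
# enumeration index satisfies idx % world_size == r.
def owned_params(param_names, rank, world_size):
    return [n for idx, n in enumerate(param_names) if idx % world_size == rank]

names = ["embed", "attn.q", "attn.k", "attn.v", "mlp.up", "mlp.down"]
assert owned_params(names, 0, 2) == ["embed", "attn.k", "mlp.up"]
assert owned_params(names, 1, 2) == ["attn.q", "attn.v", "mlp.down"]
```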
239-243: LGTM!

Simplified error feedback initialization:
- Lazy initialization (None) for actual tensors
- Pre-allocated pinned CPU buffers for efficient GPU↔CPU transfers
- Uses parameter shape directly, which works for both regular and DTensor params
638-667: LGTM!

Clean simplification of gradient merging:

- Simple dict.update() instead of TP-specific reconstruction
- Good memory management by freeing shards immediately after use
- Metadata correctly includes window and sample receipt information
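The merge-and-free pattern reads roughly like this sketch (shard contents and key names are illustrative):

```python
# Sketch of dict.update()-based gradient merging with eager memory release.
shards = [
    {"layer0.weightvals": [0.1], "layer0.weightidxs": [3]},
    {"layer1.weightvals": [0.2], "layer1.weightidxs": [7]},
]

merged = {}
for shard in shards:
    merged.update(shard)   # flat merge, no TP reconstruction needed
    shard.clear()          # drop the shard's references immediately

assert merged["layer0.weightvals"] == [0.1]
assert merged["layer1.weightidxs"] == [7]
assert all(len(s) == 0 for s in shards)
```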
691-711: LGTM!

Upload pathway is clean:
- Master-only upload with proper async handling
- Useful upload size logging for debugging/monitoring
- Consistent error handling via the comms.put interface
750-766: LGTM!

The addition of xshapes=self.xshapes to gather_with_reserve enables the gradient shape validation mentioned in the PR objectives. This helps prevent accepting sharded gradients during transition periods.
931-938: LGTM!

Good observability addition - gradient fingerprint metrics provide visibility into the optimization process and can help diagnose training issues.
Add three new test cases to verify xshapes-based gradient shape validation in comms.gather():

- test_gather_rejects_sharded_gradient_shape: Verifies that gradients with mismatched dimensions (e.g., sharded instead of full) are properly rejected when xshapes is provided
- test_gather_accepts_correct_gradient_shape: Ensures gradients with correct shapes are accepted when xshapes validation is enabled
- test_gather_without_xshapes_accepts_all: Confirms that shape validation is skipped entirely when xshapes parameter is None

These tests protect against miners uploading incorrectly shaped gradients (particularly sharded gradients when full gradients are expected).
Description
Related Issue(s)
Type of Change
Branch Naming
Commit Messages
Code Quality
Testing
Documentation
If this is a breaking change
Screenshots/Examples
Additional Notes
Summary by CodeRabbit
Refactor
New Features
Bug Fixes
Tests
Chores