
revert/TP#671

Merged
joellidin merged 5 commits into dev from revert/TP on Jan 5, 2026

Conversation

Collaborator

@joellidin joellidin commented Jan 5, 2026

  • Revert "(neurons) Revert debug dict sampling indices"
  • Revert "feat: implement Tensor Parallelism (TP) support"
  • (comms) Add gradient shape validation to gather

Description

Related Issue(s)

  • Closes #[issue number]

Type of Change

  • Feature (adding new functionality)
  • Fix (resolving a bug or issue)
  • Docs (documentation updates)
  • Refactor (code changes that don't affect functionality)
  • Maintenance (dependency updates or other maintenance)
  • Tests (adding or improving tests)
  • Breaking change (fix or feature with incompatible API changes)
  • Other: _____

Branch Naming

  • My branch follows the project's naming convention (e.g., feature/add-new-capability)

Commit Messages

  • My commits are small, atomic, and have proper commit messages
  • Commit messages are in imperative mood with a capitalized summary under 50 chars

Code Quality

  • I've performed a self-review of my code
  • I've added appropriate docstrings following the project's conventions
  • I've added proper logging where necessary (without trailing periods)
  • I've applied linting and formatting with Ruff
  • My code generates no new warnings

Testing

  • I've added tests for new functionality or bug fixes
  • All tests pass locally with my changes
  • Test coverage has not decreased

Documentation

  • I've updated documentation to reflect my changes
  • I've updated comments in hard-to-understand areas

If this is a breaking change

Screenshots/Examples

Additional Notes

Summary by CodeRabbit

  • Refactor

    • Removed tensor-parallel-specific branches and unified distributed reduction, sharding, and gradient merge flows for simpler single-path behavior.
  • New Features

    • Added gradient shape validation during gather and added parameter index-range support for model comparisons.
  • Bug Fixes

    • Improved cross-rank metric aggregation, evaluation stability, logging, and handling of invalid/non-finite gradients.
  • Tests

    • Expanded and adjusted tests to exercise shape validation, index-range comparisons, UID overlap checks, and updated distributed expectations.
  • Chores

    • Version bumped; removed an unused multipart-upload utility and narrowed bucket-cleanup scope.


Add xshapes parameter to gather() function to validate gradient tensor
shapes during the transition period, rejecting sharded uploads that
don't match expected model dimensions.

- Add xshapes parameter to gather() function signature
- Validate vals tensor shape prefix against expected xshapes
- Log warning and reject responses with shape mismatches
- Update miner.py to pass xshapes=self.xshapes to gather
- Update validator.py to pass xshapes=self.xshapes to gather

This prevents sharded gradients from being accepted during the
transition period before full shard support is enabled.
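The shape check described in this commit message can be sketched in isolation. The helper below is illustrative, not the actual `comms.gather` implementation; it only captures the prefix comparison between a received `vals` tensor shape and the expected unsharded parameter shape (the last dimension is the top-k count and is not compared).

```python
# Hedged sketch of the xshapes prefix check. Shapes are plain tuples here;
# in comms.gather they come from torch tensors and self.xshapes.

def shapes_compatible(vals_shape, expected_shape):
    """Accept vals whose leading dims match the expected (unsharded)
    parameter shape; the trailing top-k dimension is not compared."""
    return tuple(vals_shape[:-1]) == tuple(expected_shape[:-1])

# A full gradient for a (1024, 4096) weight with topk=32 passes:
assert shapes_compatible((1024, 32), (1024, 4096))
# A TP-sharded upload (rows split across 2 ranks) is rejected:
assert not shapes_compatible((512, 32), (1024, 4096))
```

A mismatch would trigger the logged warning and cause the whole response for that UID to be rejected, as the commit message describes.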

coderabbitai bot commented Jan 5, 2026

Walkthrough

This PR removes tensor-parallel (TP/DTensor) handling across the codebase, unifies cross-rank reductions to DDP-based ops, simplifies gradient reconstruction/metadata, adds shape validation to comms.gather, and trims several TP-specific utilities and scripts. No public API signatures were broadly expanded.

Changes

Cohort / File(s) Summary
Repository config
\.gitignore
Removed validator-state-*.pt ignore rule; validator-state-*.npz remains ignored.
Neurons: eval/train/miner/validator
neurons/evaluator.py, neurons/trainer.py, neurons/miner.py, neurons/validator.py
Eliminated TP_DEGREE/env TP handling; removed TP-specific reconstruction and per-shard logic; consolidated all-reduce paths to ddp_reduce (SUM/AVG) without TP scaling; simplified error_feedback/metadata and logging; removed timing/tp-conditional branches.
Comms & distributed helpers
src/tplr/comms.py, src/tplr/distributed.py
Added optional xshapes param to gather and validate gradient shapes against expected shapes; removed DTensor safe-globals init; removed batched_all_reduce; simplified temp-dir usage to /tmp.
Sampling & model utils
src/tplr/sharded_sampler.py, src/tplr/neurons.py, src/tplr/chain.py
Removed tp_degree from samplers and TP-aware sharding; simplified DTensor detection and gradient flows; removed fallback peer update path when no active peers.
Model/config handling
src/tplr/hparams.py, src/tplr/model_factory.py
Replaced recursive deep_merge with shallow .update(); removed deep_merge; simplified/strict parallel-dims validation (fixed dp_shard for validator, guarded miner degree extraction).
Compression
src/tplr/compress.py
Removed dynamic shape-enforcement helper _ensure_shape_in_dict; encode assumes pre-populated shape_dict entries (no dynamic TP fallback).
Scripts
scripts/abort_multipart_uploads.py, scripts/cleanup_bucket.py
Deleted abort_multipart_uploads.py; cleanup_bucket.py now restricts deletable prefixes to checkpoint, gradient, and start_window.
Tests
tests/test_comms.py, tests/test_evaluator.py, tests/test_prepare_gradient_dict.py, tests/test_model_comparison.py, tests/unit/test_model_factory.py, tests/unit/test_neurons.py, tests/...
Added gather shape-validation tests; removed explicit evaluator.tp_degree assignment; adjusted gradient metadata expectations to only {"window": ...}; added/updated compare_model_with_debug_dict index_range behavior and tests; added check_uid_index_overlap tests; updated model-factory test assertions.
Package version
src/tplr/__init__.py
Version bumped from 2.1.19 to 2.1.20.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • v2.1.19 #670 — Modifies many of the same TP/DTensor codepaths (neurons, sharded_sampler, comms, tplr modules) with opposing TP-related changes.
  • feat/anneal #669 — Touches the same compare_model_with_debug_dict API and index-range sampling behavior.
  • feat: Implement Tensor Parallelism (TP) support #621 — Alters parallel/TP validation and distributed reduction patterns in the same modules (hparams, model_factory, neurons).

Suggested reviewers

  • shivam-MBZUAI
  • amiiir-sarfi

Poem

"🐇 I hopped through shards and threads,
Dusted TP crumbs from my bed,
Merged the ranks with steady hum,
DDP now leads the drum,
Clean gradients stride — hooray, onward we tread!"

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 1 inconclusive)
Check name Status Explanation Resolution
Description check ⚠️ Warning The description only lists three high-level change summaries at the top but leaves all template sections blank (unchecked boxes and empty comments), providing minimal context about what is being reverted or why. Complete the description template by filling in the 'Type of Change' section, adding a 'Related Issue(s)' reference, confirming branch/commit conventions, and providing context about the reversions and their impact.
Docstring Coverage ⚠️ Warning Docstring coverage is 50.70% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Title check ❓ Inconclusive The title 'revert/TP' is vague and does not clearly convey the specific changes being made, using unclear abbreviation 'TP' without context about what is being reverted. Use a more descriptive title that explains the main changes, such as 'Revert Tensor Parallelism and add gradient shape validation' or 'Remove TP support and add comms gradient validation'.
✨ Finishing touches
  • 📝 Generate docstrings

📜 Recent review details

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 90a41fa and 6d65c7c.

📒 Files selected for processing (2)
  • src/tplr/__init__.py
  • tests/test_comms.py
✅ Files skipped from review due to trivial changes (1)
  • src/tplr/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/test_comms.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: test (3.12)



codecov bot commented Jan 5, 2026

Codecov Report

❌ Patch coverage is 81.01266% with 15 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/tplr/neurons.py 66.66% 9 Missing ⚠️
src/tplr/chain.py 0.00% 2 Missing ⚠️
src/tplr/model_factory.py 92.59% 2 Missing ⚠️
src/tplr/compress.py 66.66% 1 Missing ⚠️
src/tplr/hparams.py 80.00% 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (81.01%) is below the target coverage (85.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project status has failed because the head coverage (57.72%) is below the target coverage (85.00%). You can increase the head coverage or adjust the target coverage.

Impacted file tree graph

@@            Coverage Diff             @@
##              dev     #671      +/-   ##
==========================================
+ Coverage   56.45%   57.72%   +1.27%     
==========================================
  Files          27       27              
  Lines        5165     4975     -190     
==========================================
- Hits         2916     2872      -44     
+ Misses       2249     2103     -146     
Files with missing lines Coverage Δ
src/tplr/__init__.py 100.00% <100.00%> (ø)
src/tplr/comms.py 65.39% <100.00%> (+2.22%) ⬆️
src/tplr/distributed.py 17.64% <ø> (+1.11%) ⬆️
src/tplr/sharded_sampler.py 92.53% <100.00%> (+5.95%) ⬆️
src/tplr/compress.py 80.96% <66.66%> (+3.39%) ⬆️
src/tplr/hparams.py 60.52% <80.00%> (-1.67%) ⬇️
src/tplr/chain.py 36.02% <0.00%> (+0.65%) ⬆️
src/tplr/model_factory.py 61.08% <92.59%> (+1.22%) ⬆️
src/tplr/neurons.py 75.40% <66.66%> (+4.59%) ⬆️


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Fix all issues with AI Agents 🤖
In @neurons/miner.py:
- Around line 815-843: The debug slice for model parameter samples currently
uses indices [10:12] when populating debug_dict (entries named name + "_debug"),
but the comparer (compare_model_with_debug_dict) and catchup code expect the
first two elements (index_range=(0,2)); update the population logic in the loop
inside Miner (where debug_dict is filled for self.model.named_parameters()) to
use the same index_range (slice [0:2] or parameterize it via index_range) so the
debug samples and compare_model_with_debug_dict use the same indices and the
avg_steps_behind comparison can succeed.
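The fix the comment asks for amounts to a single source of truth for the sampled slice. A hypothetical sketch, with `index_range` and the `"_debug"` suffix taken from the comment and everything else illustrative:

```python
# One shared constant so the miner's debug sampling and
# compare_model_with_debug_dict always use the same slice.
INDEX_RANGE = (0, 2)

def sample_debug_values(named_flat_params, index_range=INDEX_RANGE):
    """Populate debug_dict entries from the same indices the comparer reads."""
    lo, hi = index_range
    return {name + "_debug": flat[lo:hi] for name, flat in named_flat_params.items()}

params = {"w": [0.1, 0.2, 0.3, 0.4]}
assert sample_debug_values(params) == {"w_debug": [0.1, 0.2]}
```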

In @scripts/cleanup_bucket.py:
- Around line 75-83: The cleanup script was left excluding "aggregator",
"debug", and "peers" while the app still creates those objects; update the
cleanup filter in scripts/cleanup_bucket.py to include those prefixes (add
"aggregator", "debug", and "peers" / "peers_" to the startswith tuple) so
objects produced by neurons.py (key="aggregator"), comms.py
s3_put_object()/s3_get_object() (debug key), and create_and_post_peers()
(peers_/peers) are deleted, or alternatively remove the creation of those keys
in neurons.py and comms.py to match the current exclusion—pick one approach and
make the corresponding change consistently across the cleanup filter and the
creators.

In @tests/unit/test_neurons.py:
- Around line 576-611: Tests in TestCheckUidIndexOverlap exercise a simple
pairwise helper but the production function tplr.neurons.check_uid_index_overlap
has a different signature/return used by Validator.run and slash_from_overlap;
fix by adding a new helper (e.g., _pairwise_index_overlap or
pairwise_index_overlap(uids, uid_to_indices, window=0)) that implements the
simple (uids, uid_to_indices, window) → {(uid_i, uid_j): set(indices)} contract
and update TestCheckUidIndexOverlap to import and call that helper, or
alternatively update the tests to construct a minimal neuron/gather_result and
call the existing async check_uid_index_overlap(neuron, gather_result, window,
overlap_threshold=...) so the tests validate the real production API (refer to
symbols: TestCheckUidIndexOverlap, check_uid_index_overlap, slash_from_overlap,
Validator.run, tplr.neurons).

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
neurons/trainer.py (1)

899-975: Remove pre-unscale gradient clipping to follow standard GradScaler pattern

The code clips gradients twice within the optimizer step:

  • Line 916: before scaler.unscale_() (operates on scaled gradients)
  • Line 937: after scaler.unscale_() (operates on unscaled gradients)

Clipping before unscale violates PyTorch AMP best practice and can over-constrain updates. The standard pattern is unscale → clip → step, not clip → unscale → clip → step.

Remove the first clip call at line 916 to keep only the post-unscale clipping at line 937. The distributed reductions for global_tokens_step, global_loss_step, log_loss, and accum_batch_size are correctly implemented.
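The recommended unscale → clip → step ordering can be sketched with toy objects. This is a minimal CPU-friendly illustration of the standard AMP pattern, not the trainer's actual code; the scaler is disabled here so it runs without CUDA, and there is exactly one clip call, after `unscale_()`.

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # enable on CUDA in real use

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()

scaler.unscale_(opt)                                                  # 1. unscale first
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # 2. clip once, on unscaled grads
scaler.step(opt)                                                      # 3. then step
scaler.update()
```

Clipping before `unscale_()` would measure scaled gradient norms against an unscaled threshold, which is why the pre-unscale clip at line 916 should go.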

neurons/validator.py (1)

3564-3740: update_model_with_gradient decompression/distribution is mostly solid, but missing/invalid param data now hard-fails

The new path (using torch.empty_like(p) as ref and self.xshapes[n] / self.totalks[n] for decompression) is correct and avoids copying parameter contents. Broadcasting has_valid_gradient before any barrier is also the right pattern for distributed error consensus.

However:

  • Any parameter with vals is None or quant_params is None now flips has_valid_gradient=False, causing a ValueError for that UID even if other parameters are fine.
  • That means any peer omitting a parameter from its state dict (or drifting top‑k layout) will have the entire gradient rejected, rather than having that parameter treated as zero-update.

If this is deliberate for the transition period, it would be safer to:

  • Document this as a hard contract on the gradient schema, and/or
  • Downgrade missing-per-param data to a logged skip (no update on that param) instead of raising, while still raising for genuine corruption (bad indices, NaN/Inf, shape mismatch).

Also consider guarding the self.xshapes[n] / self.totalks[n] lookups with an explicit KeyError → ValueError conversion so failures produce a clear “unknown parameter in gradient” message.

src/tplr/chain.py (1)

372-395: Early-return when active_peers is empty may leave stale eval peers

The new branch:

if not active_peers:
    logger.warning("No active peers found. Skipping update.")
    return

prevents using gather peers as a fallback, which is good. But it also skips rebuilding self.eval_peers, so any previous eval_peers dict remains in place even though active_peers is now empty.

You might want to explicitly clear self.eval_peers (and possibly self.peers) in this branch to avoid evaluating or gathering from stale peer sets when the chain reports no active peers.
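The suggested fix is straightforward to sketch. The class below is hypothetical; `eval_peers` and `peers` mirror the attributes named in the comment, and the point is simply that the empty-peers branch clears stale state instead of returning over it.

```python
import logging

logger = logging.getLogger(__name__)

class PeerTracker:
    """Toy stand-in for the chain-manager state discussed above."""

    def __init__(self):
        self.peers = ["uid-1"]
        self.eval_peers = {"uid-1": 1}

    def update_peers(self, active_peers):
        if not active_peers:
            logger.warning("No active peers found. Clearing stale peer state.")
            self.eval_peers = {}   # don't evaluate against stale peers
            self.peers = []        # don't gather from stale peers either
            return
        self.peers = list(active_peers)
        self.eval_peers = {uid: 1 for uid in active_peers}

tracker = PeerTracker()
tracker.update_peers([])
assert tracker.eval_peers == {} and tracker.peers == []
```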

🧹 Nitpick comments (6)
src/tplr/comms.py (1)

98-101: Unify temp directory handling to avoid scattered /tmp paths (optional)

self.temp_dir is now /tmp/templar_{self.uid}, while put() creates a separate /tmp/{self.uid} per instance for temp files. Functionally this works, but using a single convention (e.g. always self.temp_dir) would simplify reasoning about cleanup and avoid leaving stray empty /tmp/{uid} dirs around.

Also applies to: 1314-1318

src/tplr/sharded_sampler.py (1)

67-69: DP-only grad accumulation and rank slicing look correct; consider enforcing divisibility

Using denom = micro_bs * world_size and self._local = global_indices[self.rank :: self.world_size] correctly implements pure-DP sharding and grad accumulation. The only subtlety is that if batch_size is not an exact multiple of micro_bs * world_size, grad_accum_steps will floor-divide and you’ll get an effective global batch smaller than configured. Consider asserting or logging when batch_size % (micro_bs * world_size) != 0 so misconfigurations are caught early.

Also applies to: 75-80
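The divisibility guard suggested above could look like the following. The function is hypothetical (the real sampler computes `grad_accum_steps` inline); names mirror the comment.

```python
def grad_accum_steps(batch_size: int, micro_bs: int, world_size: int) -> int:
    """Fail fast when the global batch doesn't divide evenly across ranks."""
    denom = micro_bs * world_size
    if batch_size % denom != 0:
        raise ValueError(
            f"batch_size={batch_size} is not divisible by "
            f"micro_bs*world_size={denom}; effective global batch would shrink"
        )
    return batch_size // denom

assert grad_accum_steps(batch_size=64, micro_bs=4, world_size=4) == 4
```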

neurons/trainer.py (1)

51-80: Sampler wiring matches TP-less samplers; minor redundancy in micro_bs (optional)

set_dataloader now cleanly routes through MinerSampler / EvalSampler without TP, and the shared args look consistent. The only nit is that shared_args already contains micro_bs, so re-passing micro_bs=self.hparams.micro_batch_size in the miner kwargs is redundant; you can drop it to avoid confusion.

src/tplr/compress.py (1)

236-256: encode now hard-requires pre-populated shape_dict entries

Switching to:

n1 = self.shape_dict[x.shape[0]]
n2 = self.shape_dict[x.shape[1]]

(and similarly for 1D) removes the old dynamic “slow path” and will raise KeyError if a new dimension appears that wasn’t seen in __init__.

This is fine if ChunkingTransformer is only ever used on parameters of the passed model. If there’s any chance of calling encode on ad‑hoc tensors (e.g., external grads, reshaped buffers), consider replacing direct indexing with an explicit check that raises a clearer error such as:

if x.shape[0] not in self.shape_dict:
    raise ValueError(f"Unsupported DCT size {x.shape[0]} – not in shape_dict")

so callers get a diagnostic instead of a generic KeyError.

tests/test_model_comparison.py (1)

196-279: Well-structured test for custom index_range functionality.

The test properly validates:

  1. Default index range (0, 2) matching behavior
  2. Custom index range (5, 7) with proper skip logic for small parameters
  3. Mismatch detection with expected difference calculation

Minor observation: The variable num_elements on lines 214 and 234 is assigned but never used. Consider removing these unused assignments.

🔎 Suggested cleanup
     for name, param in model.named_parameters():
         param_flat = param.flatten()
-        num_elements = param_flat.numel()
         debug_dict[name + "_debug"] = param_flat[:2].detach().cpu().tolist()
     for name, param in model.named_parameters():
         param_flat = param.flatten()
-        num_elements = param_flat.numel()
 
         # Only include parameters with enough elements for the custom range
src/tplr/model_factory.py (1)

275-298: Redundant int() coercion for tp_degree.

Line 275 already casts tp_degree to int, but line 296 casts it again. The same pattern exists for other degrees but is less obvious since their initial extraction doesn't include int().

Consider consolidating the coercion:

🔎 Suggested consolidation
-        tp_degree = int(getattr(tt, "tp_degree", 1))
-        pp_degree = int(getattr(tt, "pp_degree", 1))
-        cp_degree = int(getattr(tt, "cp_degree", 1))
-        dp_replicate = getattr(tt, "dp_replicate", 1)
-        dp_shard = getattr(tt, "dp_shard", 1)
+        tp_degree = getattr(tt, "tp_degree", 1)
+        pp_degree = getattr(tt, "pp_degree", 1)
+        cp_degree = getattr(tt, "cp_degree", 1)
+        dp_replicate = getattr(tt, "dp_replicate", 1)
+        dp_shard = getattr(tt, "dp_shard", 1)
 
-        # Ensure divisors are not zero before coercion to 1 and modulo operations
+        # Ensure divisors are not zero
         if dp_replicate == 0:
             raise ValueError("dp_replicate cannot be zero.")
         # ... other checks ...
 
-        # Coerce to int after zero checks
+        # Coerce all to int
         dp_replicate = int(dp_replicate)
         dp_shard = int(dp_shard)
         tp_degree = int(tp_degree)
         pp_degree = int(pp_degree)
         cp_degree = int(cp_degree)
📜 Review details

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bc7e48e and 90a41fa.

📒 Files selected for processing (21)
  • .gitignore
  • neurons/evaluator.py
  • neurons/miner.py
  • neurons/trainer.py
  • neurons/validator.py
  • scripts/abort_multipart_uploads.py
  • scripts/cleanup_bucket.py
  • src/tplr/chain.py
  • src/tplr/comms.py
  • src/tplr/compress.py
  • src/tplr/distributed.py
  • src/tplr/hparams.py
  • src/tplr/model_factory.py
  • src/tplr/neurons.py
  • src/tplr/sharded_sampler.py
  • tests/test_comms.py
  • tests/test_evaluator.py
  • tests/test_model_comparison.py
  • tests/test_prepare_gradient_dict.py
  • tests/unit/test_model_factory.py
  • tests/unit/test_neurons.py
💤 Files with no reviewable changes (5)
  • tests/test_evaluator.py
  • src/tplr/distributed.py
  • tests/test_comms.py
  • .gitignore
  • scripts/abort_multipart_uploads.py
🧰 Additional context used
🧬 Code graph analysis (9)
src/tplr/neurons.py (5)
src/tplr/compress.py (2)
  • batch_decompress (484-580)
  • decode (260-292)
tests/test_comms.py (2)
  • batch_decompress (954-956)
  • decode (940-942)
tests/conftest.py (1)
  • totalks (141-144)
tests/test_prepare_gradient_dict.py (3)
  • decode (37-38)
  • decode (174-175)
  • decode (268-270)
tests/test_dcp_checkpoint.py (1)
  • warning (117-118)
neurons/trainer.py (1)
src/tplr/distributed.py (1)
  • ddp_reduce (97-131)
src/tplr/chain.py (1)
tests/test_dcp_checkpoint.py (1)
  • warning (117-118)
neurons/evaluator.py (1)
src/tplr/distributed.py (1)
  • ddp_reduce (97-131)
neurons/miner.py (3)
src/tplr/logging.py (2)
  • T (41-48)
  • P (51-62)
src/tplr/comms.py (1)
  • put (1261-1352)
src/tplr/neurons.py (1)
  • outer_step (225-476)
tests/unit/test_model_factory.py (1)
src/tplr/model_factory.py (1)
  • create_parallel_dims (215-332)
src/tplr/comms.py (1)
tests/test_dcp_checkpoint.py (1)
  • warning (117-118)
tests/unit/test_neurons.py (1)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1298-1441)
tests/test_model_comparison.py (1)
src/tplr/neurons.py (1)
  • compare_model_with_debug_dict (1194-1294)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (30)
scripts/cleanup_bucket.py (1)

86-89: LGTM! Log messages are consistently updated.

All log messages and the docstring accurately reflect the narrowed scope of the cleanup script. The messaging is clear and consistent throughout.

Also applies to: 108-110, 118-118

neurons/evaluator.py (1)

799-811: DDP reduction for custom eval metrics looks consistent

Using dist_helper.ddp_reduce(..., op=dist.ReduceOp.SUM, device=self.device) on local_loss, local_tokens, and local_bytes to form global totals is correct and matches the later average/perplexity and bpb computations.

src/tplr/hparams.py (1)

77-85: Shallow dict.update changes nested hparams override semantics

Switching from deep merges to:

  • full_hparams = DEFAULT_HPARAMS.copy(); full_hparams.update(hparams) in create_namespace, and
  • successive hparams.update(...) calls in load_hparams

means any nested dicts (e.g. optimizer, anneal_mode, torchtitan) from higher‑priority files now replace lower‑priority ones entirely, rather than merging per key. That’s fine if all override JSONs carry complete nested sections; just ensure no configs rely on “partial” nested overrides that previously depended on deep_merge behavior.

Also applies to: 145-195
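The semantic difference is easiest to see side by side. `deep_merge` below is a generic recursive sketch (not necessarily the removed helper), and the `optimizer` config is illustrative:

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Per-key recursive merge: nested dicts are merged, not replaced."""
    out = dict(base)
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = deep_merge(out[k], v)
        else:
            out[k] = v
    return out

defaults = {"optimizer": {"lr": 1e-4, "betas": [0.9, 0.95]}}
override = {"optimizer": {"lr": 3e-4}}  # a "partial" nested override

# Old behavior: per-key merge keeps "betas".
assert deep_merge(defaults, override)["optimizer"] == {"lr": 3e-4, "betas": [0.9, 0.95]}

# New behavior: shallow .update() replaces the whole nested dict.
shallow = dict(defaults)
shallow.update(override)
assert shallow["optimizer"] == {"lr": 3e-4}  # "betas" is gone
```

Any override JSON that previously relied on the first behavior must now ship complete nested sections.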

src/tplr/comms.py (1)

1606-1621: xshapes-based vals shape check should correctly filter sharded gradients

The new xshapes argument and the (vals.shape[:-1] == xshapes[base_name][:-1]) check give a cheap, robust guard against receiving sharded or otherwise mismatched gradients: assuming xshapes[base_name] holds the unsharded parameter shape and compression is along the last dimension, any TP-style shards will fail the prefix match and the UID is skipped with a clear warning. Please just confirm that:

  • xshapes keys use the raw parameter names (matching param_name + "vals" / param_name + "idxs"), and
  • the TopK compressor indeed produces vals with shape xshape[:-1] + [topk],

so valid gradients aren’t accidentally rejected.

Also applies to: 1650-1652, 1855-1875

neurons/trainer.py (2)

427-479: Evaluation loss and batch aggregation via ddp_reduce is coherent

Summing total_loss and n_batches across ranks with dist_helper.ddp_reduce(..., device=device) yields global totals; callers can recover a true mean as total_loss / n_batches. The guard on world_size > 1 and dist_helper.is_distributed() keeps single-GPU behavior unchanged.
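The aggregation can be illustrated with plain scalars. In the trainer this is `dist_helper.ddp_reduce` with `ReduceOp.SUM`; the lists below just stand in for per-rank values:

```python
# Per-rank partial sums, as each rank would hold them before the reduce.
rank_losses = [12.0, 9.0, 15.0]   # total_loss on each rank
rank_batches = [4, 3, 5]          # n_batches on each rank

total_loss = sum(rank_losses)     # what ddp_reduce(..., SUM) produces
n_batches = sum(rank_batches)
global_mean = total_loss / n_batches

assert global_mean == 3.0
```

Summing both numerator and denominator (rather than averaging per-rank means) is what makes the recovered mean exact even when ranks process different batch counts.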


591-631: Adam metrics: distributed reductions and max tracking are implemented correctly

The new block that:

  • sums local squared norms and moment buffers with ddp_reduce(..., device=self.device),
  • uses op=ReduceOp.MAX for local_grad_norm_max / local_update_norm_max, and
  • aggregates ratio sums and counts before dividing,

produces consistent, global metrics across ranks. The subsequent square-roots/ratios use these globals correctly.

neurons/validator.py (4)

1491-1511: Passing xshapes into gather_with_reserve looks correct

Wiring xshapes=self.xshapes next to totalks=self.totalks keeps the validator-side expectations consistent with the new shape-aware gather validation in comms. Assuming self.xshapes / self.totalks are constructed for every named parameter (as in __init__), this call site looks sound.


1878-1893: Early-exit on window change during eval is reasonable

The new master-only log + break when self.current_window != eval_window prevents partially-evaluated windows without leaving other ranks behind (barriers come after the loop). This is a sensible guard and should avoid inconsistent scoring when the chain advances mid-evaluation.


4040-4055: Fixing miner digest to world_size=1 is consistent but assumes single-rank semantics

Passing world_size=1 into MinerSampler for the validator-side digest makes the reconstructed index pool independent of validator DDP/TP topology, which is desirable for matching miner behavior.

Just ensure the miner-side _training_pool_digest (or equivalent) uses the same world_size=1 convention; otherwise, digest mismatches could appear even for honest miners purely due to parallel layout differences.


2031-2068: Gradient-apply failure handling ensures group consensus by rejecting whole peer gradients

The current design explicitly invalidates the entire peer gradient if any parameter is missing or undecompressible, rather than treating missing per-parameter contributions as zeros. This is intentional: any per-parameter failure (missing vals/quant_params, decompression errors, NaN/Inf in decompressed data) raises an exception at the parameter loop level, broadcasts consensus to all ranks, then triggers a slash + full model-state restore for that UID.

Confirm this stricter schema—rejecting partially-sparse or schema-drifted gradients outright—aligns with your expected behavior, since the alternative (zero-filling missing parameters) would be more permissive.
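The consensus pattern can be sketched without a process group. In the validator this is a real `torch.distributed` broadcast of `has_valid_gradient`; here the ranks are simulated in-process and the broadcast is modeled as an AND across ranks (analogous to an all-reduce MIN over booleans):

```python
def apply_gradient_with_consensus(per_rank_valid: list[bool]) -> bool:
    """Every rank must reach the same decision so barriers and the
    subsequent state restore stay in lockstep across the group."""
    return all(per_rank_valid)  # any rank seeing bad data vetoes the apply

# One rank hit a NaN during decompression: the gradient is rejected on
# every rank, so no rank's model state diverges.
assert apply_gradient_with_consensus([True, False, True]) is False
assert apply_gradient_with_consensus([True, True, True]) is True
```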

tests/test_prepare_gradient_dict.py (1)

107-110: Metadata tests now correctly enforce minimal {"window": step_window} contract

Both tests now assert that gradient["metadata"] is exactly {"window": step_window}. This matches the simplified metadata semantics and guards against accidentally reintroducing extra keys like xshapes/totalks into the payload.

Also applies to: 134-138

tests/unit/test_model_factory.py (1)

242-247: Parallel-dims tests correctly mirror new validation rules

The updated miner tests:

  • Use torchtitan.tp_degree=2 for the custom-TP happy path.
  • Assert the new error message for dp_replicate > 1 and dp_shard > 1.
  • Assert the refined message for dp_replicate with non‑trivial tp.
  • Check that both dp_shard and tp_degree mis-divisibility raise the generic “world_size … must be divisible by the product of all parallel degrees” error.

These are consistent with create_parallel_dims in model_factory.py and should catch future regressions in the validation logic.

Also applies to: 248-281

tests/unit/test_neurons.py (2)

48-76: compare_model_with_debug_dict tests now match the 3-element slice behavior

Switching param1 to a MagicMock with .data and .dtype, and expanding the debug slices and index_range=(0, 3), lines up with the current implementation which operates on a 3‑element slice. The perfect-match and mismatch cases look correct and will regress if the comparison logic changes.


476-481: Broadcast call-count expectations adjusted to new catchup behavior

The updated mock_broadcast.call_count assertions (3, 4, and 2 respectively) reflect the revised catchup_with_aggregation_server control flow (skipped windows, fallbacks, and early exits). As long as these numbers were obtained by tracing the new implementation, the tests look aligned and will help catch future changes to the broadcast protocol.

Also applies to: 521-524, 566-569

tests/test_model_comparison.py (1)

308-314: LGTM!

Explicitly passing index_range=(0, 2) in tests makes the intent clear, even though it matches the default. This improves test readability and ensures tests won't break if the default changes.

src/tplr/model_factory.py (2)

254-270: Hardcoded dp_shard=4 may cause failures for non-standard validator setups.

The validator role now requires world_size % 4 == 0. This is a breaking change for validators running with world sizes like 1, 2, 3, 5, 6, 7, etc.

Consider whether this is intentional. If validators must run with exactly 4 GPUs (or multiples), document this requirement. Otherwise, consider a more flexible approach:

🔎 Suggested alternative
     elif role == "validator":
         # Validator: pipeline parallelism with data parallel replication
-        # Ensure dp_shard is at least 1 to prevent division by zero
-        dp_shard = 4
+        # Use dp_shard that evenly divides world_size (prefer 4 if possible)
+        dp_shard = 4 if world_size >= 4 and world_size % 4 == 0 else max(1, min(4, world_size))
         if world_size % dp_shard != 0:

Is the requirement for validators to run with world_size divisible by 4 intentional? If so, this should be documented in deployment guides.


300-322: LGTM!

The validation logic is comprehensive and well-structured:

  • Mutual exclusivity between dp_replicate > 1 and dp_shard > 1
  • Constraint that dp_replicate > 1 requires all other parallelism degrees to be 1
  • Clear error messages with the actual values for debugging

The safeguard at lines 314-317 provides defense-in-depth even though individual zero checks should prevent this case.

src/tplr/neurons.py (7)

63-69: LGTM!

Good defensive pattern for DTensor detection with a graceful fallback to string-based type checking. This reduces hard dependencies on DTensor internals.
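The pattern described here mirrors the following shape; this sketch is illustrative (the import path and fallback are assumptions, not the exact tplr code):

```python
# Prefer the real DTensor class when importable; fall back to a
# string-based type-name check otherwise.
try:
    from torch.distributed.tensor import DTensor  # may be absent or moved
except Exception:
    DTensor = None

def is_dtensor(t) -> bool:
    if DTensor is not None:
        return isinstance(t, DTensor)
    return type(t).__name__ == "DTensor"  # graceful string fallback

class FakeTensor:  # plain object, clearly not a DTensor
    pass

assert is_dtensor(FakeTensor()) is False
```

The string fallback keeps the check working on builds where the DTensor module is unavailable, at the cost of matching any class that happens to be named `DTensor`.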


120-131: Verify assertion behavior for DTensor parameters without gradients.

Line 128-130 logic: if a DTensor parameter has grad=None, the code falls through to the assertion at line 130, which will fail.

This is likely intentional (DTensor params should always have grads after backward in a distributed setting), but consider whether there are edge cases where a DTensor param might legitimately have no gradient (e.g., frozen parameters, parameters not used in forward pass).

If this assertion is expected behavior, consider adding a more descriptive error message:

-            assert g is not None, f"p.grad is None for {n}"
+            assert g is not None, f"p.grad is None for DTensor param {n} - all DTensor params must have gradients"

138-152: LGTM!

Clean simplification of error feedback handling:

  • Initializes to zeros on first use
  • Properly handles device placement
  • Clear null-round vs normal momentum update logic

220-221: LGTM!

Metadata simplified to contain only window. The xshapes and totalks are now passed separately to gather() calls, which provides better separation of concerns.


353-371: LGTM!

Good memory optimization using empty_like instead of copying the parameter tensor. The decompression path is simplified by using xshapes[name] and totalks[name] directly from function parameters.


1145-1170: LGTM!

Clean refactoring to use the shared compare_model_with_debug_dict helper with proper parameter passing:

  • index_range=(0, 2) for debug snippet comparison
  • param_avg_change for step-size estimation
  • Appropriate error handling and logging

656-658: LGTM!

Informative log message for the edge case where bootstrap start_window cannot be fetched.

neurons/miner.py (6)

207-212: Consider removing unused parallel degree attributes.

tp_degree (line 207) is stored but no longer used for ownership computation after reverting TP support. The ownership is now determined purely by idx % world_size (line 236).

If TP is not supported, consider removing these attributes or documenting that they're for future use:

self.tp_degree  # Stored but unused
self.pp_degree  # Stored but unused  
self.cp_degree  # Stored but unused

Are these attributes used elsewhere in the miner, or should they be removed as part of the TP revert?


239-243: LGTM!

Simplified error feedback initialization:

  • Lazy initialization (None) for actual tensors
  • Pre-allocated pinned CPU buffers for efficient GPU↔CPU transfers
  • Uses parameter shape directly, which works for both regular and DTensor params
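The described pattern, sketched with a toy parameter (shapes and names are illustrative; pinning is only requested when CUDA is present so the sketch runs on CPU-only builds):

```python
import torch

param = torch.nn.Parameter(torch.randn(8, 4))

# Lazy error feedback: allocated on first use, not up front.
error_feedback = None

# Pre-allocated CPU staging buffer for GPU<->CPU transfers,
# pinned when a CUDA device is available.
staging = torch.empty(
    param.shape, dtype=param.dtype,
    pin_memory=torch.cuda.is_available(),
)

if error_feedback is None:  # first use
    error_feedback = torch.zeros_like(param)

assert staging.shape == param.shape
```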

638-667: LGTM!

Clean simplification of gradient merging:

  • Simple dict.update() instead of TP-specific reconstruction
  • Good memory management by freeing shards immediately after use
  • Metadata correctly includes window and sample receipt information

691-711: LGTM!

Upload pathway is clean:

  • Master-only upload with proper async handling
  • Useful upload size logging for debugging/monitoring
  • Consistent error handling via the comms.put interface

750-766: LGTM!

The addition of xshapes=self.xshapes to gather_with_reserve enables the gradient shape validation mentioned in the PR objectives. This helps prevent accepting sharded gradients during transition periods.


931-938: LGTM!

Good observability addition - gradient fingerprint metrics provide visibility into the optimization process and can help diagnose training issues.

Add three new test cases to verify xshapes-based gradient shape
validation in comms.gather():

- test_gather_rejects_sharded_gradient_shape: Verifies that gradients
  with mismatched dimensions (e.g., sharded instead of full) are
  properly rejected when xshapes is provided

- test_gather_accepts_correct_gradient_shape: Ensures gradients with
  correct shapes are accepted when xshapes validation is enabled

- test_gather_without_xshapes_accepts_all: Confirms that shape
  validation is skipped entirely when xshapes parameter is None

These tests protect against miners uploading incorrectly shaped
gradients (particularly sharded gradients when full gradients are
expected).
@joellidin joellidin merged commit 92b66a9 into dev Jan 5, 2026
6 of 8 checks passed
@joellidin joellidin deleted the revert/TP branch January 5, 2026 17:32