Skip to content

use torch.gather() to make batched_gather() more efficient#135

Open
OccupyMars2025 wants to merge 1 commit into aqlaboratory:main from
OccupyMars2025:my-modification-20260314
Open

use torch.gather() to make batched_gather() more efficient#135
OccupyMars2025 wants to merge 1 commit into aqlaboratory:main from
OccupyMars2025:my-modification-20260314

Conversation

@OccupyMars2025
Copy link

@OccupyMars2025 OccupyMars2025 commented Mar 14, 2026

You can refer to this PR: bytedance/Protenix#269

Note: Yes, I know batched_gather() is not used in openfold-3

@jandom
Copy link
Collaborator

jandom commented Mar 14, 2026

hi there @OccupyMars2025 thanks for this contribution – I'm guessing this is more of an "idea", rather than a drop-in replacement? The equivalent change in ProteniX lead to an improvement, so if batched_gather was used in openfold (it's not right now, no callers) it could be adapted in a similar fashion?

What'd be your recommended next step here?

@OccupyMars2025
Copy link
Author

OccupyMars2025 commented Mar 14, 2026

The following code compares the execution time of the original implementation and my implementation of batched_gather():

import torch

def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Gather slices of ``data`` along ``dim`` with per-batch indices.

    The first ``no_batch_dims`` dimensions of ``data`` are treated as batch
    dimensions; ``inds`` selects entries along ``dim`` independently for every
    batch element via advanced (fancy) indexing.

    Args:
        data: source tensor whose leading dims are batch dims.
        inds: index tensor; its leading dims line up with the batch dims.
        dim: dimension of ``data`` (batch dims included) to gather along.
        no_batch_dims: number of leading batch dimensions.

    Returns:
        Tensor of gathered values; dimension ``dim`` takes on the shape
        contributed by the trailing dims of ``inds``.
    """
    n_index_dims = len(inds.shape)
    # One broadcastable arange per batch dimension, shaped so advanced
    # indexing pairs each batch element with its own row of indices.
    batch_ranges = [
        torch.arange(size).view(
            [1] * axis + [-1] + [1] * (n_index_dims - axis - 1)
        )
        for axis, size in enumerate(data.shape[:no_batch_dims])
    ]

    # Identity slices for the non-batch dims, with inds dropped in at the
    # (batch-relative) target dimension.
    trailing = [slice(None)] * (len(data.shape) - no_batch_dims)
    trailing[dim - no_batch_dims if dim >= 0 else dim] = inds
    # Index with a tuple rather than a list to avoid a PyTorch UserWarning.
    return data[tuple(batch_ranges + trailing)]


def batched_gather_my_version(
    data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0
) -> torch.Tensor:
    """Gather ``data`` along the dim right after the batch dims via torch.gather.

    Drop-in replacement for ``batched_gather`` restricted to the common layout
    where the gathered dimension immediately follows the batch dimensions,
    i.e. ``dim == no_batch_dims == inds.dim() - 1`` (after normalizing a
    negative ``dim``). A single ``torch.gather`` call on an expanded (zero-copy)
    index view replaces the advanced-indexing machinery of the original.

    Args:
        data (torch.Tensor): the input data
            [..., K, ...]
        inds (torch.Tensor): the indices for gathering data
            [..., N]
        dim (int): dimension of ``data`` to gather along; must equal
            ``inds.dim() - 1`` once normalized.
        no_batch_dims (int): number of leading batch dimensions; must equal
            ``inds.dim() - 1``.

    Returns:
        torch.Tensor: gathered data with the same number of dimensions as
            ``data``; only the size of dimension ``inds.dim() - 1`` changes to N
            [..., N, ...]

    Raises:
        ValueError: if the shapes or the ``dim``/``no_batch_dims`` arguments do
            not match the supported layout. (Previously these were silently
            ignored, which could return wrong results without warning.)
    """
    n_batch = inds.dim() - 1
    if inds.dim() > data.dim():
        raise ValueError("inds must have less or equal dimensions than data")
    if inds.shape[:n_batch] != data.shape[:n_batch]:
        raise ValueError("Batch dimensions must match between data and inds")
    norm_dim = dim if dim >= 0 else dim + data.dim()
    if norm_dim != n_batch or no_batch_dims != n_batch:
        raise ValueError(
            "batched_gather_my_version only supports gathering along the "
            "dimension immediately after the batch dimensions "
            "(dim == no_batch_dims == inds.dim() - 1)"
        )

    if inds.dim() == data.dim():
        # No trailing channel dims: gather directly along the last dimension.
        return torch.gather(data, dim=-1, index=inds)

    n_channel_dims = data.dim() - inds.dim()
    # Append singleton dims after the index dim, then expand (a view, no copy)
    # so the index tensor matches data's rank and trailing sizes, as
    # torch.gather requires index and input to have the same rank.
    inds_broadcasted = inds.reshape(inds.shape + (1,) * n_channel_dims)
    inds_broadcasted = inds_broadcasted.expand(
        inds.shape + data.shape[inds.dim():]
    )
    return torch.gather(data, dim=n_batch, index=inds_broadcasted)

# --- Correctness check: both implementations must agree on a random case ---
n_tokens = 100
n_atoms = 200
batch_dims = (20, 30, 4)
token_channels = 16
data = torch.randn(*batch_dims, n_tokens, token_channels)
inds = torch.randint(0, n_tokens, (*batch_dims, n_atoms))
# Fall back to CPU when no GPU is present so the script runs everywhere;
# the original unconditionally moved to cuda:0 and crashed on CPU-only hosts.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
inds = inds.to(device)
dim = len(inds.shape) - 1
no_batch_dims = len(batch_dims)
result = batched_gather(data, inds, dim=dim, no_batch_dims=no_batch_dims)
result_my_version = batched_gather_my_version(data, inds, dim=dim, no_batch_dims=no_batch_dims)
assert torch.equal(result, result_my_version), "The results from both versions should be the same"




# ===== compare the execution time of the two versions =====
import torch
import time
import torch.utils.benchmark as benchmark

# ----------------------
# Your two functions
# ----------------------
fun1 = batched_gather
fun2 = batched_gather_my_version

# ----------------------
# Setup: Create test tensor (match your real use case!)
# ----------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Warmup: VERY important for GPU (compiles kernels, allocates memory)
fun1(data, inds, dim=dim, no_batch_dims=no_batch_dims)
fun2(data, inds, dim=dim, no_batch_dims=no_batch_dims)
if DEVICE == "cuda":
    torch.cuda.synchronize()

# ==============================================================================
# Method 1: Manual timing (simple, good for quick checks)
# ==============================================================================
def time_function(func, repeats=1000):
    """Return the average wall time per call of ``func`` over ``repeats`` runs.

    CUDA kernels launch asynchronously, so we synchronize before starting the
    clock and again before stopping it to make sure all queued GPU work is
    actually included in the measurement.
    """
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    # perf_counter() is monotonic and high-resolution; time.time() is the
    # adjustable wall clock and can jump, skewing benchmark results.
    start = time.perf_counter()

    for _ in range(repeats):
        func(data, inds, dim=dim, no_batch_dims=no_batch_dims)

    if DEVICE == "cuda":
        torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - start) / repeats  # average time per call

# Run
t1 = time_function(fun1)
t2 = time_function(fun2)

print("=" * 50)
print("Manual Timing (avg per call)")
print(f"fun1: {t1*10**6:.4f} us")
print(f"fun2: {t2*10**6:.4f} us")
print(f"Faster by: {max(t1/t2, t2/t1):.2f}x")

# ==============================================================================
# Method 2: Torch Built-in Benchmark (RECOMMENDED - most accurate)
# Handles warmup, CUDA sync, statistics, outliers
# ==============================================================================
print("\n" + "=" * 50)
print("Torch Benchmark (official, accurate)")

t_fun1 = benchmark.Timer(
    stmt='fun1(data, inds, dim, no_batch_dims)',
    globals={'fun1': fun1, 'data': data, 'inds': inds, 'dim': dim, 'no_batch_dims': no_batch_dims}
)
t_fun2 = benchmark.Timer(
    stmt='fun2(data, inds, dim, no_batch_dims)',
    globals={'fun2': fun2, 'data': data, 'inds': inds, 'dim': dim, 'no_batch_dims': no_batch_dims}
)


res1 = t_fun1.timeit(1000)
res2 = t_fun2.timeit(1000)

print(res1)
print(res2)
print(f"Faster function: {'fun1' if res1.mean < res2.mean else 'fun2'}")
print(f"Faster by: {max(res1.mean/res2.mean, res2.mean/res1.mean):.2f}x")

conclusion:

On colab, H100 GPU, my implementation(fun2) is about 4x faster

==================================================
Manual Timing (avg per call)
fun1: 139.8790 us
fun2: 35.8293 us
Faster by: 3.90x

==================================================
Torch Benchmark (official, accurate)
<torch.utils.benchmark.utils.common.Measurement object at 0x793ee253d370>
fun1(data, inds, dim, no_batch_dims)
  140.54 us
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x793ee2189ac0>
fun2(data, inds, dim, no_batch_dims)
  35.76 us
  1 measurement, 1000 runs , 1 thread
Faster function: fun2
Faster by: 3.93x

On CPU, the efficiency differences between these two implementations are not obvious

==================================================
Manual Timing (avg per call)
fun1: 1036.5732 us
fun2: 812.4716 us
Faster by: 1.28x

==================================================
Torch Benchmark (official, accurate)
<torch.utils.benchmark.utils.common.Measurement object at 0x793ee2d3aa20>
fun1(data, inds, dim, no_batch_dims)
  11.49 ms
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x793ee2d0bd40>
fun2(data, inds, dim, no_batch_dims)
  9.65 ms
  1 measurement, 1000 runs , 1 thread
Faster function: fun2
Faster by: 1.19x

@OccupyMars2025
Copy link
Author

OccupyMars2025 commented Mar 14, 2026

I think my implementation is a more efficient drop-in replacement.

TODO: My recommended next step: In Protenix, batched_gather() is used to broadcast a token embedding to an atom embedding (task 1) and to gather frame atom coordinates (task 2). But in openfold-3, batched_gather() is NOT used at the moment. Maybe we can compare how Protenix and openfold-3 handle these 2 tasks.

Still you can refer to bytedance/Protenix#269 (comment)

Currently, I don't have time to do this. I hope to have time for it in two or three weeks.

@jandom jandom added enhancement New feature or request training Relating to the training pipeline model Related to the model definition sketch This is an idea/sketch rather than a mergeable PR, significant work needed to advance needed labels Mar 17, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request model Related to the model definition sketch This is an idea/sketch rather than a mergeable PR, significant work needed to advance needed training Relating to the training pipeline

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants