Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions torchao/prototype/models/gpt_fused/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
## gpt-fused

A more handwritten version of [gpt-fast](https://github.com/pytorch-labs/gpt-fast)'s model.py for us to experiment with.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdym by more handwritten?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could use this to try various fused kernels (Triton or CUDA).


Requires the use of gpt-fast to use, but is a drop-in replacement.

To use it just set to PYTHONPATH environment variable to `PYTHONPATH=<path to ao repository>/ao/torchao/prototype/models/gpt_fused` and delete gpt-fast's model.py.

For example

```
PYTHONPATH=/home/cpuhrsch/local/ao/torchao/prototype/models/gpt_fused CUDA_VISIBLE_DEVICES=0 numactl --membind 0 --cpubind 0 python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's going on here lol, why do i need to set the python path?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that the import statements in gpt-fast pick up on the location of model.py in torchao

```
255 changes: 255 additions & 0 deletions torchao/prototype/models/gpt_fused/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a file in test or benchmark script that would just sanity check that the script works with real and random weights

Also why prototype namespace? I think torchao.models.gpt is better, i expect a lot of people will use this as is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using prototype just to get started. Yes, we can add a benchmark script. I'll work on that next.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)

@dataclass
class ModelArgs:
block_size: int = 2048
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
intermediate_size: int = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5

def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head

@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]

# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
# take longer name (as it have more symbols matched)
if len(config) > 1:
config.sort(key=len, reverse=True)
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match

return cls(**transformer_configs[config[0]])


transformer_configs = {
"CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
"7B": dict(n_layer=32, n_head=32, dim=4096),
"13B": dict(n_layer=40, n_head=40, dim=5120),
"30B": dict(n_layer=60, n_head=52, dim=6656),
"34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf
"70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
"Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
"stories15M": dict(n_layer=6, n_head=6, dim=288),
"stories110M": dict(n_layer=12, n_head=12, dim=768),
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
}

class KVCache(nn.Module):
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[0] == k_val.shape[2]

k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out

class Transformer(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config

self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

self.freqs_cis: Optional[Tensor] = None
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1

def setup_caches(self, max_batch_size, max_seq_length):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
dtype = self.output.weight.dtype
# For quantized layers, dtype is encoded in scales
if hasattr(self.output, "scales"):
dtype = self.output.scales.dtype
elif hasattr(self.output, "scales_and_zeros"):
dtype = self.output.scales_and_zeros.dtype
for b in self.layers:
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)

for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis, mask)
x = self.norm(x)
logits = self.output(x)
return logits

@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
return out


class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0

total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None

self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(self, state_dict, prefix, *args):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)

q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

if self.kv_cache is not None:
k, v = self.kv_cache.update(input_pos, k, v)

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

y = self.wo(y)
return y


class FeedForward(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight


def precompute_freqs_cis(
seq_len: int, n_elem: int, base: int = 10000,
dtype: torch.dtype = torch.bfloat16
) -> Tensor:
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
t = torch.arange(seq_len, device=freqs.device)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
],
-1,
)

x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)