from abc import ABC, abstractmethod
from typing import Callable, Optional

import torch
from torch import Tensor
from einops import rearrange


def x2prob(x: Tensor, vocab_size: int) -> Tensor:
    """Converts a sequence of tokens to a one-hot class distribution representation
    """
    return torch.nn.functional.one_hot(x, num_classes=vocab_size).float()


def sample_p(pt: Tensor, temperature: float = 1.0) -> Tensor:
    """Samples a protein sequence from a class distribution representation
    """
    b, l, _ = pt.shape
    pt = rearrange(pt, 'b l c -> (b l) c')
    # Apply temperature in log space: torch.multinomial renormalizes each row,
    # so dividing the probabilities by a scalar temperature would be a no-op.
    pt = torch.softmax(pt.log() / temperature, dim=-1)
    xt = torch.multinomial(pt, 1)
    return xt.reshape(b, l)
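
# Illustrative round trip between the two helpers above (a sketch, not part of
# the library API; for exact one-hot rows the multinomial draw is deterministic):
#   >>> x = torch.randint(0, 128, (2, 5))
#   >>> p = x2prob(x, vocab_size=128)   # (2, 5, 128) one-hot floats
#   >>> sample_p(p).equal(x)
#   True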


class Coupling(ABC):
    """Base class for couplings that pair a prior sequence x0 with a target x1
    """
    @abstractmethod
    def sample(self, x1: Tensor) -> tuple[Tensor, Tensor]:
        raise NotImplementedError


class EmptyCoupling(Coupling):
    """A coupling that samples empty prior sequences
    """
    def sample(self, x1: Tensor):
        x0 = torch.empty((x1.shape[0], 0), dtype=torch.long, device=x1.device)
        return x0, x1
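
# Note: the prior sampled here has shape (batch_size, 0); the assumption (not
# verified against downstream callers) is that an empty x0 means "generate from
# scratch".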


class GeneratorCoupling(Coupling):
    """A coupling that samples prior sequences from a generator function
    """
    def __init__(self, generator_fn: Callable[[Optional[Tensor]], Tensor]):
        self.generator_fn = generator_fn

    def sample(self, x1: Tensor):
        x0 = self.generator_fn(x1)
        return x0, x1
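
# Minimal usage sketch (random_prior is a hypothetical callable, shown only to
# illustrate the expected signature: it receives the target batch and returns a
# prior batch):
#   >>> def random_prior(x1):
#   ...     return torch.randint(0, 128, x1.shape, dtype=x1.dtype, device=x1.device)
#   >>> coupling = GeneratorCoupling(random_prior)
#   >>> x0, x1 = coupling.sample(torch.randint(0, 128, (4, 20)))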


class ExtendedCoupling(Coupling):
    """A coupling that randomly inserts tokens into the target sequence
    """
    def __init__(self, n_insert: int = 10, vocab_size: int = 128, pad_token: int = 129):
        self.n_insert = n_insert
        self.vocab_size = vocab_size
        self.pad_token = pad_token

    def sample(self, x1: Tensor):
        batch_size, x1_seq_len = x1.shape
        x1_pad_mask = (x1 == self.pad_token)
        x1_seq_lengths = (~x1_pad_mask).sum(dim=1).tolist()
        # Draw n_insert insertion positions per sequence, restricted to the
        # unpadded part of x1, and sort them so the index arithmetic below holds.
        ins_positions = torch.stack([
            torch.randint(0, seqlen + 1, size=(self.n_insert,), dtype=torch.long, device=x1.device)
            for seqlen in x1_seq_lengths
        ])
        ins_positions = torch.sort(ins_positions, dim=1)[0]  # (batch_size, n_insert)
        max_new_len = self.n_insert + x1_seq_len
        x0 = torch.full((batch_size, max_new_len), self.pad_token, dtype=x1.dtype, device=x1.device)  # (batch_size, max_new_len)
        batch_indices = torch.arange(batch_size, device=x1.device).unsqueeze(1)  # (batch_size, 1)
        orig_positions = torch.arange(x1_seq_len, device=x1.device).unsqueeze(0).expand(batch_size, -1)  # (batch_size, x1_seq_len)
        # Each original token shifts right by the number of insertions at or
        # before its position.
        num_insert_before = (ins_positions.unsqueeze(2) <= orig_positions.unsqueeze(1)).sum(dim=1)  # (batch_size, x1_seq_len)
        new_orig_positions = orig_positions + num_insert_before  # (batch_size, x1_seq_len)
        x0[batch_indices, new_orig_positions] = x1
        # The k-th (sorted) insertion is itself shifted right by the k earlier
        # insertions, so insertion slots never collide with shifted originals.
        ins_new_positions = ins_positions + torch.arange(self.n_insert, device=x1.device).unsqueeze(0)  # (batch_size, n_insert)
        ins_tokens = torch.randint(0, self.vocab_size, size=(batch_size, self.n_insert), dtype=x1.dtype, device=x1.device)
        x0[batch_indices, ins_new_positions] = ins_tokens
        return x0, x1
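
# Worked micro-example of the index arithmetic above (illustrative only): with
# x1 = [A, B, C] and sorted insertion positions [1, 1], the originals shift to
# slots [0, 3, 4] (B and C each move past two insertions) and the insertions
# land in slots [1, 2], giving x0 = [A, i0, i1, B, C].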


class UniformCoupling(Coupling):
    """A coupling that samples uniform prior sequences within a given length range
    """
    def __init__(
        self,
        min_len: int = 0,
        max_len: int = 100,
        vocab_size: int = 128,
        mirror_len: bool = False,
        pad_token: int = 129,
    ):
        self.min_len = min_len
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.mirror_len = mirror_len
        self.pad_token = pad_token

    def sample(self, x1: Tensor):
        batch_size, _ = x1.shape
        x1_pad_mask = (x1 == self.pad_token)
        if self.mirror_len:
            x0_pad_mask = x1_pad_mask
            x0_max_len = x1.shape[1]
        else:
            # Sample lengths on x1's device so the comparison below does not mix devices.
            x0_seq_len = torch.randint(self.min_len, self.max_len + 1, size=(batch_size,), device=x1.device)
            x0_max_len = int(x0_seq_len.max().item())
            x0_pad_mask = torch.arange(x0_max_len, device=x1.device).expand(batch_size, -1) >= x0_seq_len.unsqueeze(1)
        x0 = torch.randint(0, self.vocab_size, size=(batch_size, x0_max_len), dtype=x1.dtype, device=x1.device)
        x0[x0_pad_mask] = self.pad_token
        return x0, x1
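
# Usage sketch (illustrative): with mirror_len=True the prior reuses x1's
# padding layout, so x0 and x1 align position-for-position; otherwise each
# prior length is drawn uniformly from [min_len, max_len].
#   >>> coupling = UniformCoupling(min_len=5, max_len=15, vocab_size=128)
#   >>> x0, x1 = coupling.sample(torch.randint(0, 128, (4, 20)))
#   >>> x0.shape[1] <= 15
#   True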


class KappaScheduler(ABC):
    """Base class for kappa schedulers
    """
    @abstractmethod
    def __call__(self, t: Tensor) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def derivative(self, t: Tensor) -> Tensor:
        raise NotImplementedError


class CubicScheduler(KappaScheduler):
    """Cubic Hermite schedule on [0, 1]: kappa(0) = 0, kappa(1) = 1, with
    endpoint slopes kappa'(0) = a and kappa'(1) = b.
    """
    def __init__(self, a: float = 2.0, b: float = 0.5) -> None:
        self.a = a
        self.b = b

    def __call__(self, t: Tensor) -> Tensor:
        return -2 * (t**3) + 3 * (t**2) + self.a * (t**3 - 2 * t**2 + t) + self.b * (t**3 - t**2)

    def derivative(self, t: Tensor) -> Tensor:
        return -6 * (t**2) + 6 * t + self.a * (3 * t**2 - 4 * t + 1) + self.b * (3 * t**2 - 2 * t)


if __name__ == "__main__":
    from utils import pretty_parse

    coupling = ExtendedCoupling(n_insert=10, vocab_size=128, pad_token=129)

    # make a padded batch of sequences
    x1 = torch.randint(0, 128, size=(4, 20), dtype=torch.long)
    x1[0, 15:] = 129
    x1[1, 18:] = 129
    x1[2, 4:] = 129

    x0, _ = coupling.sample(x1)
    print(x1.shape)
    print(x0.shape)
    print()
    for i in range(x0.shape[0]):
        print(f"x0[{i}]: {pretty_parse(x0[i], compact=True)}")
        print(f"x1[{i}]: {pretty_parse(x1[i], compact=True)}")
        print()

    # Inserted priors must keep all padding contiguous at the end.
    x0_pad_mask = (x0 == coupling.pad_token)  # (batch_size, x0_max_len)
    x0_seq_lengths = (~x0_pad_mask).sum(dim=1)  # (batch_size,)
    for i in range(x0.shape[0]):
        assert not x0_pad_mask[i, :x0_seq_lengths[i]].any(), f"x0[{i}] has padding in the sequence part"
        assert x0_pad_mask[i, x0_seq_lengths[i]:].all(), f"x0[{i}] is not padded correctly"
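
    # Sanity check of CubicScheduler's boundary values (added for illustration;
    # from the polynomial: kappa(0) = 0, kappa(1) = 1, kappa'(0) = a, kappa'(1) = b).
    scheduler = CubicScheduler(a=2.0, b=0.5)
    t = torch.tensor([0.0, 1.0])
    assert torch.allclose(scheduler(t), torch.tensor([0.0, 1.0]))
    assert torch.allclose(scheduler.derivative(t), torch.tensor([2.0, 0.5]))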