-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprovider.py
More file actions
351 lines (290 loc) · 13.1 KB
/
provider.py
File metadata and controls
351 lines (290 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
from __future__ import annotations
from typing import List, Dict, Tuple, Set
import logging
import os
from gpu_logging_utils import (
log_gpu_memory_nvidia_smi,
log_cuda_memory_pytorch,
flush_cuda_cache,
)
from nltk.grammar import Nonterminal
import torch
from torch.nn import functional as F
from local_llm import LocalLLM
from dataclasses import dataclass, field
os.makedirs("logs", exist_ok=True)


def _make_file_logger(name: str, level: int) -> Tuple[logging.Logger, logging.FileHandler]:
    """Create a logger that writes to ``logs/<name>.log``.

    Factors out the handler/formatter boilerplate that was previously
    copy-pasted once per logger.

    Args:
        name: Logger name; also used as the log file's base name.
        level: Logging level threshold (e.g. ``logging.WARNING``).

    Returns:
        The configured logger and its file handler.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    handler = logging.FileHandler(f"logs/{name}.log")
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    )
    logger.addHandler(handler)
    return logger, handler


# Logger for provider events
provider_logger, provider_handler = _make_file_logger("provider", logging.WARNING)
# Logger for logits events
logits_logger, logits_handler = _make_file_logger("logits", logging.WARNING)
# Logger for performance metrics
performance_logger, performance_handler = _make_file_logger("performance", logging.INFO)
@dataclass
class _TrieNode:
    """Node in the prefix trie of non-terminal token-id sequences."""
    # Maps the next token id to the child node for that continuation.
    children: Dict[int, "_TrieNode"] = field(default_factory=dict)
    # Non-terminal symbols whose full token sequence ends at this node.
    nts: List[str] = field(default_factory=list)
class TokenLevelProbabilityProvider:
    """
    Uses the LLM's token-level probabilities to calculate grammar rule probabilities.

    Each non-terminal symbol is tokenised once and the resulting token-id
    sequences are stored in a prefix trie.  Span probabilities are then
    computed by walking the trie while querying the model for next-token
    distributions at every branch.
    """

    def __init__(
        self, llm: LocalLLM, nonterminals: Set[str], *, cache_size: int = 2048
    ):
        # NOTE(review): cache_size is accepted for interface compatibility but
        # is never read anywhere in this class — confirm whether a bounded
        # result cache was intended.
        self._llm = llm
        self._nonterminals = nonterminals
        log_gpu_memory_nvidia_smi("provider init")
        log_cuda_memory_pytorch("provider init")
        # Map non-terminals to the sequence of token ids that represent them in
        # the LLM vocabulary. Some symbols (e.g. "NP-SBJ") are tokenised into
        # multiple tokens, so we keep the full sequence to compute probabilities
        # correctly when querying the model.
        self._nt_token_seqs = self._map_nonterminals_to_token_seqs()
        self._trie = self._build_trie()
        flush_cuda_cache()
        try:
            self._space_token_id = self._llm.tokenizer.encode(
                " ", add_special_tokens=False
            )[0]
        except Exception:
            # No plain-space token available: _predict_recursive falls back to
            # the "leftover probability mass" estimate instead.
            self._space_token_id = None
        # Metrics for performance logging (reset per sentence in
        # set_text_and_precompute).
        self._token_generation_count = 0
        self._final_nt_token_count = 0
        self._prediction_count = 0

    def _map_nonterminals_to_token_seqs(self) -> Dict[str, List[int]]:
        """Map non-terminals to their corresponding token id sequences."""
        nt_token_ids: Dict[str, List[int]] = {}
        log_gpu_memory_nvidia_smi("map_nonterminals_start")
        log_cuda_memory_pytorch("map_nonterminals_start")
        for nt in self._nonterminals:
            # Add a space before the symbol to mimic generation from the model
            token_ids = self._llm.tokenizer.encode(f" {nt}", add_special_tokens=False)
            nt_token_ids[nt] = token_ids
            provider_logger.info(f"Non-terminal '{nt}' mapped to token IDs {token_ids}")
        return nt_token_ids

    def _build_trie(self) -> _TrieNode:
        """Build a prefix trie of non-terminal token sequences."""
        root = _TrieNode()
        for nt, seq in self._nt_token_seqs.items():
            node = root
            for tok in seq:
                node = node.children.setdefault(tok, _TrieNode())
            # The symbol is recorded on the node where its sequence ends.
            node.nts.append(nt)
        flush_cuda_cache()
        return root

    def _predict_recursive(
        self,
        context: torch.Tensor,
        node: _TrieNode,
        prob: float,
        depth: int = 0,
        result_cache: Dict[Tuple[int, Tuple[int, ...]], Dict[str, float]] | None = None,
    ) -> Dict[str, float]:
        """Accumulate non-terminal probabilities by walking the trie.

        Args:
            context: Token-id tensor of shape (1, seq_len) fed to the model.
            node: Current trie node.
            prob: Probability mass accumulated along the path so far.
            depth: Recursion depth, used to throttle CUDA cache flushes.
            result_cache: Memoises sub-results keyed by (last token id,
                child-token set). None (the default) starts a fresh cache.

        Returns:
            Mapping from non-terminal symbol to accumulated probability.
        """
        if result_cache is None:
            result_cache = {}
        # Create a cache key from the last token ID and current node's children keys.
        # NOTE(review): the key ignores earlier context tokens — assumes the
        # next-token distribution over these children depends only on the last
        # token; confirm this approximation is intended.
        cache_key = (context[0, -1].item(), tuple(sorted(node.children.keys())))
        if cache_key in result_cache:
            # Cached values are stored prob-normalised; rescale for this path.
            return {k: v * prob for k, v in result_cache[cache_key].items()}
        log_cuda_memory_pytorch("predict_recursive_start")
        device = self._llm.model.device
        results: Dict[str, float] = {}
        if not node.children:
            # Leaf: every symbol ending here receives the full path mass.
            for nt in node.nts:
                results[nt] = results.get(nt, 0.0) + prob
            return results
        log_cuda_memory_pytorch("before_model_call")
        # Count tokens generated by each model call
        self._token_generation_count += 1
        out = self._llm.model(context)
        logits = out.logits[0, -1, :]
        log_cuda_memory_pytorch("after_model_call")
        token_probs = F.softmax(logits, dim=0)
        # Free memory after getting token probabilities
        del out
        if depth % 3 == 0:  # Flush cache every 3 levels of recursion
            flush_cuda_cache()
        total_child_prob = 0.0
        # Process children in batches to manage memory
        batch_size = 5  # Adjust based on your GPU memory
        child_items = list(node.children.items())
        for batch_idx in range(0, len(child_items), batch_size):
            batch = child_items[batch_idx:batch_idx + batch_size]
            for tok, child in batch:
                p = token_probs[tok].item()
                if p <= 0:
                    continue
                total_child_prob += p
                # Create new context tensor extended with this child token
                new_ctx = torch.cat([context, torch.tensor([[tok]], device=device)], dim=1)
                # Recursive call with depth tracking
                sub_nodes = self._predict_recursive(new_ctx, child, prob * p, depth + 1, result_cache)
                # Update results
                for nt, v in sub_nodes.items():
                    results[nt] = results.get(nt, 0.0) + v
                # Free memory
                del new_ctx
            # Flush cache after each batch
            flush_cuda_cache()
        # Probability that generation stops here (symbol complete): use the
        # space-token probability when known, otherwise the leftover mass.
        if self._space_token_id is not None:
            leftover = token_probs[self._space_token_id].item()
        else:
            leftover = max(1.0 - total_child_prob, 0.0)
        for nt in node.nts:
            results[nt] = results.get(nt, 0.0) + prob * leftover
        # Cache the normalized results (without the prob multiplier)
        if prob > 0:
            result_cache[cache_key] = {k: v / prob for k, v in results.items()}
        return results

    def get_span_probabilities(
        self, text: str, spans: List[Tuple[int, int]]
    ) -> Dict[Tuple[int, int], Dict[Nonterminal, float]]:
        """
        Get probabilities of nonterminals for each text span.

        Args:
            text: Whitespace-tokenised sentence.
            spans: (start, end) word-index pairs, end exclusive.

        Returns:
            Mapping from each span to its non-terminal probability dict.
        """
        log_gpu_memory_nvidia_smi("get_span_probabilities_start")
        log_cuda_memory_pytorch("get_span_probabilities_start")
        result: Dict[Tuple[int, int], Dict[Nonterminal, float]] = {}
        history: Dict[Tuple[int, int], str] = {}
        # Hoisted out of the loop: previously text.split() ran once per span.
        words = text.split()
        for start, end in spans:
            # Count a prediction for each span
            self._prediction_count += 1
            span_text = words[start:end]
            span_str = " ".join(span_text)
            # Previously recognised spans are fed back as context.
            context_parts = [f"{s}:{e}->{cat}" for (s, e), cat in history.items()]
            context_str = "; ".join(context_parts)
            prompt = (
                f"Given the possible nonterminal symbols {', '.join(self._nonterminals)}, "
                f"previously recognised spans: {context_str}. "
                "Using the nonterminal symbols of the Penn Treebank corpus, "
                f"the phrase '{span_str}' in the text '{text}' forms the "
                f"syntactic category of"
            )
            tokens = self._llm.tokenizer.encode(prompt, return_tensors="pt")
            device = self._llm.model.device
            span_probs_raw = self._predict_recursive(tokens.to(device), self._trie, 1.0)
            span_probs: Dict[Nonterminal, float] = {
                Nonterminal(nt): p for nt, p in span_probs_raw.items()
            }
            for nt, p in span_probs.items():
                logits_logger.info(
                    f"Probability for non-terminal '{nt.symbol()}' in the span '{span_str}': {p:.8f}"
                )
            sorted_probs = sorted(span_probs.items(), key=lambda x: x[1], reverse=True)[
                :5
            ]
            topk_results = [(nt.symbol(), p) for nt, p in sorted_probs]
            print(f"Top 5 candidates for span '{span_str}': {topk_results}")
            provider_logger.info(
                f"Top candidates for span '{span_str}': {topk_results}"
            )
            log_cuda_memory_pytorch(f"span_{start}_{end}")
            result[(start, end)] = span_probs
            if span_probs:
                best_nt = max(span_probs.items(), key=lambda x: x[1])[0]
                history[(start, end)] = best_nt.symbol()
                # Track tokens used in the final selected non-terminal
                self._final_nt_token_count += len(
                    self._nt_token_seqs.get(best_nt.symbol(), [])
                )
            provider_logger.info(f"Span: {span_text} probabilities: {span_probs}")
        log_gpu_memory_nvidia_smi("get_span_probabilities_end")
        log_cuda_memory_pytorch("get_span_probabilities_end")
        return result

    def set_text_and_precompute(self, tokens):
        """
        Set the current text being parsed and precompute span probabilities.
        This should be called before parsing.

        Args:
            tokens: List of word tokens of the sentence to parse.

        Returns:
            The precomputed span-probability mapping (also stored on self).
        """
        log_gpu_memory_nvidia_smi("set_text_start")
        log_cuda_memory_pytorch("set_text_start")
        self._text = " ".join(tokens)
        n = len(tokens)
        # Reset metrics for this sentence
        self._token_generation_count = 0
        self._final_nt_token_count = 0
        self._prediction_count = 0
        # Generate all possible spans
        spans: List[Tuple[int, int]] = [
            (start, start + length)
            for length in range(1, n + 1)
            for start in range(n - length + 1)
        ]
        # Shorter spans first, then left-to-right.
        spans.sort(key=lambda s: (s[1] - s[0], s[0]))
        # Get span probabilities from LLM
        self._span_probs = self.get_span_probabilities(self._text, spans)
        log_gpu_memory_nvidia_smi("set_text_end")
        log_cuda_memory_pytorch("set_text_end")
        # Log performance metrics
        if self._prediction_count:
            avg_generated = self._token_generation_count / self._prediction_count
            avg_final = self._final_nt_token_count / self._prediction_count
            ratio = (
                self._token_generation_count / self._final_nt_token_count
                if self._final_nt_token_count
                else 0.0
            )
        else:
            avg_generated = avg_final = ratio = 0.0
        performance_logger.info(
            f"Avg tokens generated per prediction: {avg_generated:.2f}; "
            f"Avg tokens in final non-terminal: {avg_final:.2f}; "
            f"Generation/Final ratio: {ratio:.2f}; "
            f"Total predictions: {self._prediction_count}"
        )
        return self._span_probs

    def print_trie(self):
        """Print the structure of the trie for debugging purposes."""

        # The former 'prefix' parameter was dead (never read) and is removed.
        def _print_node(node, token_id=None, depth=0):
            indent = " " * depth
            token_text = ""
            if token_id is not None:
                try:
                    token_text = f" → '{self._llm.tokenizer.decode([token_id])}'"
                except Exception:
                    # Best-effort decode: fall back to the bare token id.
                    pass
            print(f"{indent}├── Token: {token_id}{token_text}")
            if node.nts:
                nt_indent = " " * (depth + 1)
                print(f"{nt_indent}└── Non-terminals: {', '.join(node.nts)}")
            for token_id, child in sorted(node.children.items()):
                _print_node(child, token_id, depth + 1)

        print("Trie Structure:")
        print("Root")
        for token_id, child in sorted(self._trie.children.items()):
            _print_node(child, token_id, 1)
if __name__ == "__main__":
    from nltk.grammar import PCFG

    # Explicit encoding: the default is platform-dependent and the grammar
    # file may contain non-ASCII symbols.
    with open("grammar/induced_grammar.cfg", "r", encoding="utf-8") as f:
        grammar_str = f.read()
    grammar = PCFG.fromstring(grammar_str)
    # Extract all nonterminals from the grammar
    nonterminals = {str(prod.lhs()) for prod in grammar.productions()}
    llm = LocalLLM(model_name="llama3_1-70b")
    provider = TokenLevelProbabilityProvider(llm, nonterminals, cache_size=2048)
    provider.print_trie()