Skip to content

Commit 4b56d5d

Browse files
committed
Fix endpoints layer
1 parent cd9f593 commit 4b56d5d

File tree

4 files changed

+31
-22
lines changed

4 files changed

+31
-22
lines changed

interprot/endpoints/sae_inference/handler.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -365,14 +365,14 @@ def get_sequence(self, x, layer_idx):
365365

366366

367367
def load_models():
368-
sae_name_to_model = {}
368+
sea_name_to_info = {}
369369
for sae_name, sae_checkpoint in SAE_NAME_TO_CHECKPOINT.items():
370370
pattern = r"plm(\d+).*?l(\d+).*?sae(\d+)"
371371
matches = re.search(pattern, sae_checkpoint)
372372
if matches:
373-
plm_dim, _, sae_dim = map(int, matches.groups())
373+
plm_dim, plm_layer, sae_dim = map(int, matches.groups())
374374
else:
375-
raise ValueError("Checkpoint file must be named in the format plm<n>_l<n>_sae<n>")
375+
raise ValueError("Checkpoint file must start with plm<n>_l<n>_sae<n>")
376376
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
377377

378378
# Load ESM2 model
@@ -403,10 +403,13 @@ def load_models():
403403
for k, v in torch.load(sae_weights)["state_dict"].items()
404404
}
405405
)
406-
sae_name_to_model[sae_name] = sae_model
406+
sea_name_to_info[sae_name] = {
407+
"model": sae_model,
408+
"plm_layer": plm_layer,
409+
}
407410

408411
logger.info("Models loaded successfully")
409-
return esm2_model, sae_name_to_model
412+
return esm2_model, sea_name_to_info
410413

411414

412415
def handler(event):
@@ -416,14 +419,15 @@ def handler(event):
416419
seq = input_data["sequence"]
417420
sae_name = input_data["sae_name"]
418421
dim = input_data.get("dim")
419-
_, esm_layer_acts = esm2_model.get_layer_activations(seq, 24)
422+
sae_info = sea_name_to_info[sae_name]
423+
sae_model = sae_info["model"]
424+
plm_layer = sae_info["plm_layer"]
425+
logger.info(f"sae_name: {sae_name}, plm_layer: {plm_layer}, dim: {dim}")
426+
427+
_, esm_layer_acts = esm2_model.get_layer_activations(seq, plm_layer)
420428
esm_layer_acts = esm_layer_acts[0].float()
421-
logger.info(f"esm_layer_acts: {esm_layer_acts.shape}")
422429

423-
sae_model = sae_name_to_model[sae_name]
424-
print(f"sae_model: {sae_model}")
425430
sae_acts = sae_model.get_acts(esm_layer_acts)[1:-1]
426-
logger.info(f"sae_acts: {sae_acts.shape}")
427431

428432
data = {}
429433
if dim is not None:
@@ -454,5 +458,5 @@ def handler(event):
454458
return {"status": "error", "error": str(e)}
455459

456460

457-
esm2_model, sae_name_to_model = load_models()
461+
esm2_model, sea_name_to_info = load_models()
458462
runpod.serverless.start({"handler": handler})

interprot/endpoints/steer_feature/handler.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,14 +365,14 @@ def get_sequence(self, x, layer_idx):
365365

366366

367367
def load_models():
368-
sae_name_to_model = {}
368+
sea_name_to_info = {}
369369
for sae_name, sae_checkpoint in SAE_NAME_TO_CHECKPOINT.items():
370370
pattern = r"plm(\d+).*?l(\d+).*?sae(\d+)"
371371
matches = re.search(pattern, sae_checkpoint)
372372
if matches:
373-
plm_dim, _, sae_dim = map(int, matches.groups())
373+
plm_dim, plm_layer, sae_dim = map(int, matches.groups())
374374
else:
375-
raise ValueError("Checkpoint file must be named in the format plm<n>_l<n>_sae<n>")
375+
raise ValueError("Checkpoint file must start with plm<n>_l<n>_sae<n>")
376376
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
377377

378378
# Load ESM2 model
@@ -403,10 +403,13 @@ def load_models():
403403
for k, v in torch.load(sae_weights)["state_dict"].items()
404404
}
405405
)
406-
sae_name_to_model[sae_name] = sae_model
406+
sea_name_to_info[sae_name] = {
407+
"model": sae_model,
408+
"plm_layer": plm_layer,
409+
}
407410

408411
logger.info("Models loaded successfully")
409-
return esm2_model, sae_name_to_model
412+
return esm2_model, sea_name_to_info
410413

411414

412415
def handler(event):
@@ -416,12 +419,14 @@ def handler(event):
416419
sae_name = input_data["sae_name"]
417420
dim = input_data["dim"]
418421
multiplier = input_data["multiplier"]
422+
logger.info(f"sae_name: {sae_name}, dim: {dim}, multiplier: {multiplier}")
419423

420-
sae_model = sae_name_to_model[sae_name]
421-
print(f"sae_model: {sae_model}")
424+
sae_info = sea_name_to_info[sae_name]
425+
sae_model = sae_info["model"]
426+
plm_layer = sae_info["plm_layer"]
422427

423428
# First, get ESM layer 24 activations, encode it with SAE to get a (L, 4096) tensor
424-
_, esm_layer_acts = esm2_model.get_layer_activations(seq, 24)
429+
_, esm_layer_acts = esm2_model.get_layer_activations(seq, plm_layer)
425430
sae_latents, mu, std = sae_model.encode(esm_layer_acts[0])
426431

427432
# Decode the SAE latents yields a (L, 1280) tensor `decoded_esm_layer_acts`,
@@ -452,5 +457,5 @@ def handler(event):
452457
return {"status": "error", "error": str(e)}
453458

454459

455-
esm2_model, sae_name_to_model = load_models()
460+
esm2_model, sea_name_to_info = load_models()
456461
runpod.serverless.start({"handler": handler})

viz/src/components/FullSeqsViewer.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ const FullSeqsViewer: React.FC<FullSeqViewerProps> = ({
8080
</span>
8181
</TooltipTrigger>
8282
<TooltipContent>
83-
Position: {pos}, SAE Activation: {chain.activations[pos]?.toFixed(3)}
83+
Position: {pos}, SAE Activation: {chain.activations[pos]}
8484
</TooltipContent>
8585
</Tooltip>
8686
</TooltipProvider>

viz/src/components/SeqsViewer.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ export default function SeqsViewer({ seqs, title }: SeqsViewerProps) {
472472
<TooltipContent>
473473
Position: {seq.sequence.slice(0, index).replace(/-/g, "").length}
474474
<br />
475-
SAE Activation: {seq.sae_acts[index]?.toFixed(3)}
475+
SAE Activation: {seq.sae_acts[index]}
476476
</TooltipContent>
477477
</Tooltip>
478478
</TooltipProvider>

0 commit comments

Comments (0)