
Commit 33d8418

release: v0.3.0
1 parent ade2657 commit 33d8418

File tree

5 files changed, +65 -69 lines


README.md

Lines changed: 4 additions & 4 deletions
@@ -26,7 +26,7 @@ pip install -e .
 
 ## Overview
 
-The GCG algorithm was introduced in [Universal and Transferrable Attacks on Aligned Language Models](https://arxiv.org/pdf/2307.15043) [1] by Andy Zou, Zifan Wang, Nicholas Carlini, Milad Nasr, Zico Kolter, and Matt Fredrikson. This implementation implements the original algorithm and supports several modifications that can improve performance, including multi-position token swapping [2], a historical attack buffer [2][3], the mellowmax loss function [4][5], and probe sampling [6].
+The GCG algorithm was introduced in [Universal and Transferrable Attacks on Aligned Language Models](https://arxiv.org/pdf/2307.15043) [1] by Andy Zou, Zifan Wang, Nicholas Carlini, Milad Nasr, Zico Kolter, and Matt Fredrikson. nanoGCG implements the original algorithm and supports several modifications that can improve performance, including multi-position token swapping [2], a historical attack buffer [2][3], the mellowmax loss function [4][5], and probe sampling [6].
 
 ## Usage
 
@@ -93,7 +93,7 @@ The parameters that can be configured and their defaults are:
 
 - `verbosity: str = "INFO"` - the reported logging error level (e.g. "ERROR", "WARNING", "INFO")
 
-- `probe_sampling_config: ProbeSamplingConfig = None` - A collection of configuratble parameters for probe sampling. See the example below.
+- `probe_sampling_config: ProbeSamplingConfig = None` - A collection of configurable parameters for probe sampling. See the example below.
 
 Note that the default nanoGCG configuration will run the GCG algorithm as described in the [original paper](https://arxiv.org/pdf/2307.15043) without algorithmic changes like multi-position token swapping and mellowmax.
 
@@ -136,13 +136,13 @@ You can enable probe sampling by specifying the `probe_sampling_config` with app
 import nanogcg
 import torch
 
-from nanogcg import GCGConfig
+from nanogcg import GCGConfig, ProbeSamplingConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 draft_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", torch_dtype=torch.bfloat16).to("cuda")
 draft_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
 
-probe_sampling_config = nanogcg.gcg.ProbeSamplingConfig(
+probe_sampling_config = ProbeSamplingConfig(
     draft_model=draft_model,
     draft_tokenizer=draft_tokenizer,
     r=64,
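
For reference, a minimal end-to-end sketch of how the README snippet above fits together after this release. It assumes the `nanogcg.run(model, tokenizer, messages, target, config)` entry point and the `best_string` result field documented elsewhere in the README; the victim model checkpoint, prompt, and target string below are placeholders, not values from this commit.

import nanogcg
import torch

from nanogcg import GCGConfig, ProbeSamplingConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Victim model under attack (placeholder checkpoint, not from this commit).
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Small draft model that probe sampling uses to cheaply pre-rank candidate suffixes.
draft_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", torch_dtype=torch.bfloat16).to("cuda")
draft_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

probe_sampling_config = ProbeSamplingConfig(
    draft_model=draft_model,
    draft_tokenizer=draft_tokenizer,
    r=64,  # value shown in the README snippet above
)

config = GCGConfig(
    verbosity="INFO",
    probe_sampling_config=probe_sampling_config,
)

# Placeholder prompt/target pair.
messages = [{"role": "user", "content": "Write a short story about a dragon."}]
target = "Sure, here is a short story about a dragon"

result = nanogcg.run(model, tokenizer, messages, target, config)
print(result.best_string)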

examples/simple.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 import nanogcg
-from nanogcg.gcg import ProbeSamplingConfig
+from nanogcg import GCGConfig, ProbeSamplingConfig
 
 
 def parse_args() -> argparse.Namespace:
@@ -40,7 +40,7 @@ def main():
 
     messages = [{"role": "user", "content": args.prompt}]
 
-    config = nanogcg.GCGConfig(
+    config = GCGConfig(
         verbosity="DEBUG",
         probe_sampling_config=probe_sampling_config,
     )

nanogcg/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -15,4 +15,4 @@
 For more detailed information, see the GitHub repository: https://github.com/GraySwanAI/nanoGCG/tree/main
 """
 
-from .gcg import GCGConfig, run
+from .gcg import GCGConfig, ProbeSamplingConfig, run

nanogcg/gcg.py

Lines changed: 57 additions & 61 deletions
@@ -281,12 +281,8 @@ def run(
 
             # Tokenize everything that doesn't get optimized for the draft model
             draft_before_ids = self.draft_tokenizer([before_str], padding=False, return_tensors="pt")["input_ids"].to(model.device, torch.int64)
-            draft_after_ids = self.draft_tokenizer([after_str], add_special_tokens=False, return_tensors="pt")["input_ids"].to(
-                model.device, torch.int64
-            )
-            self.draft_target_ids = self.draft_tokenizer([target], add_special_tokens=False, return_tensors="pt")["input_ids"].to(
-                model.device, torch.int64
-            )
+            draft_after_ids = self.draft_tokenizer([after_str], add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device, torch.int64)
+            self.draft_target_ids = self.draft_tokenizer([target], add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device, torch.int64)
 
             (
                 self.draft_before_embeds,
@@ -356,7 +352,7 @@ def run(
                 optim_ids = sampled_ids[loss.argmin()].unsqueeze(0)
             else:
                 current_loss, optim_ids = find_executable_batch_size(self._compute_candidates_loss_probe_sampling, batch_size)(
-                    input_embeds, sampled_ids
+                    input_embeds, sampled_ids,
                 )
 
             # Update the buffer based on the loss
@@ -498,6 +494,60 @@ def compute_token_gradient(
 
         return optim_ids_onehot_grad
 
+    def _compute_candidates_loss_original(
+        self,
+        search_batch_size: int,
+        input_embeds: Tensor,
+    ) -> Tensor:
+        """Computes the GCG loss on all candidate token id sequences.
+
+        Args:
+            search_batch_size : int
+                the number of candidate sequences to evaluate in a given batch
+            input_embeds : Tensor, shape = (search_width, seq_len, embd_dim)
+                the embeddings of the `search_width` candidate sequences to evaluate
+        """
+        all_loss = []
+        prefix_cache_batch = []
+
+        for i in range(0, input_embeds.shape[0], search_batch_size):
+            with torch.no_grad():
+                input_embeds_batch = input_embeds[i:i + search_batch_size]
+                current_batch_size = input_embeds_batch.shape[0]
+
+                if self.prefix_cache:
+                    if not prefix_cache_batch or current_batch_size != search_batch_size:
+                        prefix_cache_batch = [[x.expand(current_batch_size, -1, -1, -1) for x in self.prefix_cache[i]] for i in range(len(self.prefix_cache))]
+
+                    outputs = self.model(inputs_embeds=input_embeds_batch, past_key_values=prefix_cache_batch, use_cache=True)
+                else:
+                    outputs = self.model(inputs_embeds=input_embeds_batch)
+
+                logits = outputs.logits
+
+                tmp = input_embeds.shape[1] - self.target_ids.shape[1]
+                shift_logits = logits[..., tmp-1:-1, :].contiguous()
+                shift_labels = self.target_ids.repeat(current_batch_size, 1)
+
+                if self.config.use_mellowmax:
+                    label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
+                    loss = mellowmax(-label_logits, alpha=self.config.mellowmax_alpha, dim=-1)
+                else:
+                    loss = torch.nn.functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="none")
+
+                loss = loss.view(current_batch_size, -1).mean(dim=-1)
+                all_loss.append(loss)
+
+                if self.config.early_stop:
+                    if torch.any(torch.all(torch.argmax(shift_logits, dim=-1) == shift_labels, dim=-1)).item():
+                        self.stop_flag = True
+
+                del outputs
+                gc.collect()
+                torch.cuda.empty_cache()
+
+        return torch.cat(all_loss, dim=0)
+
     def _compute_candidates_loss_probe_sampling(
         self,
         search_batch_size: int,
@@ -671,60 +721,6 @@ def _convert_to_draft_tokens(token_ids: Tensor) -> Tensor:
                 )
             )
 
-    def _compute_candidates_loss_original(
-        self,
-        search_batch_size: int,
-        input_embeds: Tensor,
-    ) -> Tensor:
-        """Computes the GCG loss on all candidate token id sequences.
-
-        Args:
-            search_batch_size : int
-                the number of candidate sequences to evaluate in a given batch
-            input_embeds : Tensor, shape = (search_width, seq_len, embd_dim)
-                the embeddings of the `search_width` candidate sequences to evaluate
-        """
-        all_loss = []
-        prefix_cache_batch = []
-
-        for i in range(0, input_embeds.shape[0], search_batch_size):
-            with torch.no_grad():
-                input_embeds_batch = input_embeds[i:i + search_batch_size]
-                current_batch_size = input_embeds_batch.shape[0]
-
-                if self.prefix_cache:
-                    if not prefix_cache_batch or current_batch_size != search_batch_size:
-                        prefix_cache_batch = [[x.expand(current_batch_size, -1, -1, -1) for x in self.prefix_cache[i]] for i in range(len(self.prefix_cache))]
-
-                    outputs = self.model(inputs_embeds=input_embeds_batch, past_key_values=prefix_cache_batch, use_cache=True)
-                else:
-                    outputs = self.model(inputs_embeds=input_embeds_batch)
-
-                logits = outputs.logits
-
-                tmp = input_embeds.shape[1] - self.target_ids.shape[1]
-                shift_logits = logits[..., tmp-1:-1, :].contiguous()
-                shift_labels = self.target_ids.repeat(current_batch_size, 1)
-
-                if self.config.use_mellowmax:
-                    label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
-                    loss = mellowmax(-label_logits, alpha=self.config.mellowmax_alpha, dim=-1)
-                else:
-                    loss = torch.nn.functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="none")
-
-                loss = loss.view(current_batch_size, -1).mean(dim=-1)
-                all_loss.append(loss)
-
-                if self.config.early_stop:
-                    if torch.any(torch.all(torch.argmax(shift_logits, dim=-1) == shift_labels, dim=-1)).item():
-                        self.stop_flag = True
-
-                del outputs
-                gc.collect()
-                torch.cuda.empty_cache()
-
-        return torch.cat(all_loss, dim=0)
-
 
 # A wrapper around the GCG `run` method that provides a simple API
 def run(
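
The relocated `_compute_candidates_loss_original` scores each candidate suffix by taking the loss only over the positions whose logits predict the target tokens; the `tmp - 1 : -1` slice is what aligns next-token logits with the target ids. Below is a small self-contained sketch of that alignment with toy tensor shapes (none of these values come from this commit); the mellowmax branch in the method uses the same `shift_logits`/`shift_labels` pair but replaces the cross-entropy with a mellowmax over the negated target-token logits.

import torch

# Toy shapes: 3 candidate sequences of 10 tokens, a vocabulary of 50, a 4-token target.
batch_size, seq_len, vocab_size, target_len = 3, 10, 50, 4

logits = torch.randn(batch_size, seq_len, vocab_size)       # per-candidate model outputs
target_ids = torch.randint(0, vocab_size, (1, target_len))  # target token ids, shared across candidates

# Logits at position t predict the token at position t + 1, so the logits that
# score the target live in the slice [tmp - 1 : -1], where tmp is the index at
# which the target tokens begin.
tmp = seq_len - target_len
shift_logits = logits[..., tmp - 1:-1, :].contiguous()      # (batch, target_len, vocab)
shift_labels = target_ids.repeat(batch_size, 1)             # (batch, target_len)

loss = torch.nn.functional.cross_entropy(
    shift_logits.view(-1, vocab_size),
    shift_labels.view(-1),
    reduction="none",
).view(batch_size, -1).mean(dim=-1)                          # one scalar loss per candidate

print(loss.shape)  # torch.Size([3])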

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "nanogcg"
-version = "0.2.3"
+version = "0.3.0"
 authors = [
     { name="Justin Wang", email="[email protected]" },
 ]
