Notes on implementing the 2025 paper, "Attacking Large Language Models with Projected Gradient Descent"
This post delves into the paper, Attacking LLMs with Projected Gradient Descent, my experience implementing its methodology in PyTorch, and subsequent explorations.
All references to methodology are drawn from the paper unless otherwise noted.
The paper highlights that if the rate of adoption and application of LLMs outpaces efforts towards ensuring their alignment and safety, their vulnerabilities will manifest with harmful repercussions.
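The 2025 paper by Geisler et al. targets aligned LLMs. Given an input \(I\) requesting undesired content \(C\), such a model typically refuses with a response like:
I'm a helpful assistant and cannot assist with that request...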
However, by appending a tuned ‘adversarial’ prefix and/or suffix \(P\), \(S\) to the input, forming \(P \circ I \circ S\), the LLM can be manipulated into producing the undesired content \(C\).
\[P(C \mid P \circ I \circ S) \gg P(C \mid I)\]
Note: \(\circ\) above is the concatenation operator.
The attack draws inspiration from adding noise as adversarial perturbations to images in computer vision tasks, resulting in misclassifications. However, unlike images, where pixel values are continuous, text inputs to LLMs are discrete integer indices drawn from a vocabulary, making direct gradient-based perturbations infeasible.
So given an open-weight LLM, a malicious user prompt \(U\), and a target prefix \(T\) representative of a continuation of malicious content, the task is to find an adversarial suffix \(X\) that, when attached to the user prompt, “jailbreaks” the LLM into generating the target \(T\).
*Note that a prefix, a suffix, or both can be used.
Mathematically,
\[\arg\max_{X} P(LLM(U \circ X) = T \circ T')\]
Here, the optimization occurs against the ground truth target \(T\). Empirically, steering towards generation of \(T\) is often sufficient for the LLM to continue producing harmful content \(T'\). We want to learn a suffix \(X\) that, when appended to \(U\), maximizes the likelihood of generating \(T\).
The likelihood of generating \(T\) given \(U \circ X\) can be expressed as: \(P(LLM(U \circ X) = T) = \prod_{i=1}^{|T|} P_{LLM}(t_i | U \circ X \circ t_{<i})\)
For numerical stability, the log-likelihood is maximized (as the logarithm function is monotonically increasing): \(\arg\max_{X} \sum_{i=1}^{|T|} \log P_{LLM}(t_i | U \circ X \circ t_{<i})\)
This is equivalent to minimizing the negative log-likelihood (NLL): \(\arg\min_{X} - \sum_{i=1}^{|T|} \log P_{LLM}(t_i | U \circ X \circ t_{<i})\)
This is the same as minimizing the cross-entropy loss between the predicted token distribution and the target tokens expressed as one-hot encodings over the model’s vocabulary.
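To make the chain of equivalences above concrete, here is a small dependency-free sketch in plain Python rather than PyTorch. The per-step probabilities and the 3-token vocabulary are made-up illustrative numbers, not outputs of any model.

```python
import math

# Hypothetical per-step probabilities that a model assigns to the correct
# target token t_i, given U∘X and the previous target tokens t_<i.
step_probs = [0.9, 0.5, 0.25, 0.8]

# Likelihood of the whole target: product of the per-step probabilities.
likelihood = math.prod(step_probs)

# Negative log-likelihood: sum of the negative log-probabilities.
nll = -sum(math.log(p) for p in step_probs)


def cross_entropy(pred, one_hot):
    """Cross-entropy between a predicted distribution and a one-hot target."""
    return -sum(t * math.log(p) for p, t in zip(pred, one_hot) if t > 0)


# One toy step over a 3-token vocabulary: the target is token index 1.
pred = [0.3, 0.5, 0.2]
target = [0.0, 1.0, 0.0]

print(likelihood, nll, cross_entropy(pred, target))
```

With a one-hot target, the cross-entropy collapses to the negative log-probability of the target token, so summing it over the target positions gives exactly the NLL.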
The adversarial tokens we wish to learn can be expressed as a sequence of \(L\) tokens, each a one-hot vector of dimensionality \(|V|\), where \(V\) is the model vocabulary. Thus, \(X \in \{0,1\}^{L \times |V|}\) represents the discrete token sequence we wish to learn.
In its discrete form, this combinatorial optimization problem is computationally intractable with complexity \(O(|V|^L)\). A relaxation is applied on the ‘hard’ token representations, \(X \in \{0,1\}^{L \times \vert V \vert}\), to ‘soft’ representations, \(X \in [0,1]^{L \times \vert V \vert}\), with \(\forall i, \sum_{j=1}^{\vert V \vert} X_{i,j} = 1\).
Each row of \(X\) can be interpreted as a categorical distribution over the model’s vocabulary. This makes the loss function differentiable with respect to \(X\), allowing us to apply gradient descent.
Each row of \(X\) belongs to the probability simplex in \(\mathbb{R}^{\vert V \vert}\), which is the set of all vectors \(x \in \mathbb{R}^{\vert V \vert}\) with \(x_i \in [0,1]\) and \(\sum x_i = 1\) (matching our definition of soft token representation above).
To map back to discrete token space, each row is decoded greedily by taking its largest entry: \(t_i = \arg\max_j X_{i,j}\).
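As a rough illustration of the relaxation, the sketch below (plain Python, with a made-up 4-token vocabulary and 2-dimensional embeddings) shows that multiplying a relaxed row by the embedding matrix yields a weighted average of token embeddings, that a one-hot row recovers the exact embedding, and that argmax maps rows back to token ids. The names `soft_embed` and `discretize` are hypothetical helpers for this example only.

```python
# Toy embedding matrix: |V| = 4 tokens, embedding dimension D = 2.
E = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [-1.0, 0.5],
]


def soft_embed(row, emb):
    """Return row @ emb: a weighted average of the embedding vectors."""
    dim = len(emb[0])
    return [sum(w * emb[j][d] for j, w in enumerate(row)) for d in range(dim)]


def discretize(X):
    """Map each relaxed row back to a token id via argmax."""
    return [max(range(len(row)), key=row.__getitem__) for row in X]


# L = 2 relaxed tokens; each row lies on the probability simplex.
X = [
    [0.0, 0.0, 1.0, 0.0],  # one-hot: exactly token 2
    [0.6, 0.4, 0.0, 0.0],  # soft: a mix of tokens 0 and 1
]

embedded = [soft_embed(row, E) for row in X]
token_ids = discretize(X)
print(embedded, token_ids)
```

Because the embedding of a soft row is linear in that row, the loss becomes differentiable with respect to \(X\), which is what makes the gradient step possible.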
After each gradient descent step, the updated \(X\) may no longer satisfy the simplex constraints, i.e., be a probability vector. To enforce this, \(X\) is row-wise projected onto the probability simplex using the algorithm from Efficient Projections onto the l1-Ball for Learning in High Dimensions by Duchi et al.
The projection is defined as finding the probability vector \(x_p\) closest in Euclidean distance to the updated point \(x_{t+1}\): \(x_p = \arg\min_{x \in \Delta} \lVert x - x_{t+1} \rVert_2\), where \(\Delta\) is the probability simplex.
A natural question here is the choice of simplex projection over naive normalization or softmax. Two interpretations intuitively support this choice:
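A small worked example of this definition, written in plain Python as a sketch of the sort-based method rather than the batched PyTorch version. The input vector is arbitrary, and `project_to_simplex` is a name chosen for this example. The softmax of the same vector is printed only for comparison.

```python
import math


def project_to_simplex(v):
    """Euclidean projection of v onto the probability simplex (sort-based)."""
    u = sorted(v, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        t = (cumsum - 1.0) / k
        if u_k - t > 0:  # holds for a leading run of the sorted values
            tau = t
    return [max(x - tau, 0.0) for x in v]


def softmax(v):
    exps = [math.exp(x) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def distance(a, b):
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))


x_next = [0.5, 0.8, -0.2]  # a point that left the simplex after a step
projected = project_to_simplex(x_next)
soft = softmax(x_next)
print(projected, distance(x_next, projected))
print(soft, distance(x_next, soft))
```

Here the projection is approximately `[0.35, 0.65, 0.0]`: it subtracts a common threshold and clips at zero, so it can produce exact zeros, whereas softmax always returns a dense vector and lands farther from the original point.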
Lastly, an additional projection is conditionally applied to enforce a target Tsallis Entropy with \(q=2\) on the rows of \(X\). This is done if the entropy difference between the current token distribution and the representative uniform distribution is sufficiently small.
The Tsallis entropy for a probability vector \(p\) simplifies to \(H_2(p) = 1 - \sum_i p_i^2\), which is one minus the squared \(\ell_2\) norm of \(p\). For a uniform distribution over \(n\) entries, this is maximal at \(1 - \frac{1}{n}\) (approaching 1 as \(n\) grows), and for a one-hot vector (a case of a shifted delta distribution) it is minimal at 0, matching the interpretation of entropy as a measure of uncertainty.
However, we should note that the formula on its own also admits a seemingly contradictory case: a vector with value \(\frac{1}{\sqrt{n}}\) everywhere evaluates to an entropy of 0 yet is uniform. Such a vector sums to \(\sqrt{n}\) rather than 1, so for \(n > 1\) it is not a probability vector, and the simplex constraint rules it out.
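The cases above can be checked numerically. This is a plain-Python sketch with \(n = 4\) chosen arbitrarily; it also verifies the identity that links the entropy to the squared distance from the uniform center, \(\lVert p - c \rVert^2 = 1 - H_2(p) - \frac{1}{n}\), for a vector on the simplex.

```python
import math


def tsallis_entropy_q2(p):
    """Tsallis entropy with q = 2: H_2(p) = 1 - sum_i p_i^2."""
    return 1.0 - sum(x * x for x in p)


n = 4
uniform = [1.0 / n] * n
one_hot = [1.0, 0.0, 0.0, 0.0]
not_a_distribution = [1.0 / math.sqrt(n)] * n  # sums to sqrt(n) = 2, not 1

print(tsallis_entropy_q2(uniform))  # equals 1 - 1/n
print(tsallis_entropy_q2(one_hot))
print(tsallis_entropy_q2(not_a_distribution), sum(not_a_distribution))

# Squared distance from the uniform center c = (1/n, ..., 1/n).
p = [0.7, 0.2, 0.1, 0.0]
dist_sq = sum((x - 1.0 / n) ** 2 for x in p)
print(dist_sq, 1.0 - tsallis_entropy_q2(p) - 1.0 / n)
```

The last line prints the same value twice (up to rounding), which is why fixing a target entropy is the same as fixing a radius around the center.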
The attack is implemented in PyTorch using the HF transformers library.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Label value ignored by the cross-entropy loss (used in compute_loss)
IGNORE_LOSS_ID = -100


class PGD:
    def __init__(self, model_name: str, device: str = "cuda"):
        self.device = device  # later methods read self.device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        self.learning_rate = 0.325
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Freeze all model parameters, we only want to learn the adversarial suffix
        for param in self.model.parameters():
            param.requires_grad = False
        # Embedding matrix for soft token projection
        self.embedding = self.model.get_input_embeddings()
        self.V, _ = self.embedding.weight.shape
Some choices of relaxed adversarial tokens are:
Creating our learnable adversarial matrix for (1):
# class PGD:
    def initialize_prompt_from_normal(
            self,
            user_prompt: str,
            num_prefix_tokens: int = 20,
            num_suffix_tokens: int = 20,
            temperature: float = 0.6) -> torch.Tensor:
        self.num_prefix = num_prefix_tokens
        self.num_suffix = num_suffix_tokens
        # The user prompt is fixed: encode it as exact one-hot rows
        user_prompt_token_ids = self.tokenizer.encode(user_prompt, return_tensors="pt", add_special_tokens=False).to(self.device)
        self.user_prompt_tokens = F.one_hot(user_prompt_token_ids, num_classes=self.V).float().to(self.device)
        # Prefix and suffix start as random points on the simplex (softmax of Gaussian noise)
        prefix_relaxed = torch.randn((1, num_prefix_tokens, self.V), device=self.device)
        prefix_relaxed = F.softmax(prefix_relaxed / temperature, dim=-1)
        suffix_relaxed = torch.randn((1, num_suffix_tokens, self.V), device=self.device)
        suffix_relaxed = F.softmax(suffix_relaxed / temperature, dim=-1)
        # Shape (1, num_prefix + prompt_len + num_suffix, V)
        adversarial_tokens = torch.cat([prefix_relaxed, self.user_prompt_tokens.float(), suffix_relaxed], dim=1)
        adversarial_tokens = torch.nn.Parameter(adversarial_tokens, requires_grad=True)
        self.model.register_parameter("adversarial_tokens", adversarial_tokens)
        return adversarial_tokens
As the learnable matrix contains the prefix, the user prompt, and the suffix, a helpful utility keeps the user part of the matrix fixed during optimization:
class Util:
    @staticmethod
    def overlay(x: torch.Tensor, prefix_len: int, suffix_len: int, new: torch.Tensor) -> torch.Tensor:
        # Copy values from `new` into the prefix and suffix positions only;
        # the user-prompt rows in the middle keep their values from `x`.
        result = x.clone()
        if prefix_len > 0:
            result[:, :prefix_len, :] = new[:, :prefix_len, :]
        if suffix_len > 0:  # guard: with suffix_len == 0, x[:, -0:] selects every row
            result[:, -suffix_len:, :] = new[:, -suffix_len:, :]
        return result
Now, the simplex projection algorithm can be defined as:
class Math:
    @staticmethod
    def simplex_projection(x: torch.Tensor, device) -> torch.Tensor:
        # x = (B, seq_len, vocab_size)
        size = x.size()[-1]
        # Sort each row in descending order and take running sums
        sorted_tokens, _ = torch.sort(x, descending=True, dim=-1)
        prefix_sum = torch.cumsum(sorted_tokens, dim=-1)
        range_tensor = torch.arange(1, size + 1, device=device)
        # Candidate threshold for each k: (sum of the k largest entries - 1) / k
        prefix_averages = (prefix_sum - 1) / range_tensor.float()
        indicators = (sorted_tokens - prefix_averages) > 0
        # rho = number of entries that stay positive after projection
        rho = torch.sum(indicators, dim=-1).unsqueeze(-1)
        delta = (prefix_sum.gather(-1, rho - 1) - 1) / rho.float()
        # Shift by the threshold and clip at zero
        return F.relu(x - delta)
It is applied using:
# class PGD:
    def simplex_projection(self, x: torch.Tensor) -> torch.Tensor:
        projection = Math.simplex_projection(x, self.device)
        return Util.overlay(x, self.num_prefix, self.num_suffix, projection)
It is not immediately obvious how this corresponds to the aforementioned geometric definition. The short paper Simplex Projection derives the algorithm analytically for the equivalent quadratic program using KKT conditions. The quadratic form expresses minimizing Euclidean distance to the simplex.
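For reference, a condensed version of the KKT argument, written out from the standard derivation rather than quoted from that paper. Here \(y\) is the point being projected, \(\tau\) is the multiplier of the sum constraint (the `delta` in the code above), and \(\rho\) corresponds to `rho`.

```latex
\begin{aligned}
&\min_{x \in \mathbb{R}^{n}} \; \tfrac{1}{2}\lVert x - y \rVert_2^2
\quad \text{s.t.} \quad \sum_{i} x_i = 1, \; x_i \ge 0 \\
&\mathcal{L}(x, \tau, \mu) = \tfrac{1}{2}\lVert x - y \rVert_2^2
+ \tau \Big( \sum_{i} x_i - 1 \Big) - \sum_{i} \mu_i x_i, \qquad \mu_i \ge 0 \\
&\text{Stationarity: } x_i - y_i + \tau - \mu_i = 0, \qquad
\text{complementary slackness: } \mu_i x_i = 0 \\
&\Rightarrow \; x_i = \max(y_i - \tau, 0) \\
&\text{With } u_1 \ge \dots \ge u_n \text{ the sorted entries of } y \text{ and }
\rho = \max\Big\{ k : u_k - \tfrac{1}{k}\Big(\sum_{j \le k} u_j - 1\Big) > 0 \Big\}: \\
&\sum_{i} x_i = 1 \;\Rightarrow\; \tau = \frac{1}{\rho}\Big(\sum_{j \le \rho} u_j - 1\Big)
\end{aligned}
```

In words: every surviving entry is shifted by the same amount \(\tau\), every other entry is clipped to zero, and \(\tau\) is fixed by requiring the surviving entries to sum to 1.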
Similarly, entropy projection is implemented as:
# class Math:
    @staticmethod
    def entropy_project(x: torch.Tensor, target_entropy: float, device) -> torch.Tensor:
        # x shape: (batch_size, seq_len, vocab_size)
        # Center: uniform distribution over the currently non-zero entries of each row
        indicator = (x > 0).float()
        num_positive = indicator.sum(dim=-1, keepdim=True)
        center = indicator / num_positive
        # Radius of the sphere around the center on which H_2 equals target_entropy
        radius_sq = 1.0 - target_entropy - (1.0 / num_positive)
        radius = torch.sqrt(torch.clamp(radius_sq, min=0.0))
        diff = x - center
        dist = torch.norm(diff, p=2, dim=-1, keepdim=True)
        # Rows already within the radius are left unchanged
        mask = (radius >= dist).float()
        # Other rows are rescaled onto the sphere, then projected back onto the simplex
        scale = radius / torch.clamp(dist, min=1e-12)
        projection = scale * diff + center
        projection = Math.simplex_projection(projection, device)
        return mask * x + (1 - mask) * projection

# class PGD:
    def entropy_projection(self, x: torch.Tensor, target_entropy: float) -> torch.Tensor:
        # Project to entropy constraint
        projection = Math.entropy_project(x, target_entropy, self.device)
        return Util.overlay(x, self.num_prefix, self.num_suffix, projection)
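To see what the rescaling step does to a single row, here is a plain-Python sketch on a 3-entry vector. It assumes every entry is positive (so the center is uniform over all entries) and that the rescaled row stays non-negative, in which case the trailing simplex projection changes nothing. The function name `rescale_to_entropy` and the numbers are illustrative only.

```python
import math


def tsallis_entropy_q2(p):
    return 1.0 - sum(x * x for x in p)


def rescale_to_entropy(p, target_entropy):
    """Move p along the line through the uniform center until H_2 = target.

    Assumes all entries of p are positive and the result stays non-negative.
    """
    n = len(p)
    center = 1.0 / n
    radius = math.sqrt(max(1.0 - target_entropy - 1.0 / n, 0.0))
    dist = math.sqrt(sum((x - center) ** 2 for x in p))
    if radius >= dist:  # already within the radius: leave unchanged
        return list(p)
    scale = radius / dist
    return [center + scale * (x - center) for x in p]


p = [0.7, 0.2, 0.1]
q = rescale_to_entropy(p, target_entropy=0.6)
print(tsallis_entropy_q2(p), tsallis_entropy_q2(q), sum(q))
```

The row starts with \(H_2 = 0.46\), lies outside the radius for a target of 0.6, and is pulled toward the center until its entropy is 0.6 while still summing to 1.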
Computing the losses makes use of the built-in cross-entropy loss that transformers models return when labels are passed:
# class PGD:
    def compute_loss(
        self, adversarial_tokens: torch.Tensor, target_text: str
    ) -> torch.Tensor:
        # Compute negative log likelihood loss for target text
        # Soft tokens -> embeddings: each row becomes a weighted average of token embeddings
        embeddings = torch.matmul(adversarial_tokens, self.embedding.weight)
        target_ids = self.tokenizer.encode(target_text, return_tensors="pt", add_special_tokens=False).to(self.device)
        with torch.no_grad():
            target_embeddings = self.embedding(target_ids)
        # shape (B, seq_len, D)
        inputs_embeds = torch.cat([embeddings, target_embeddings], dim=1)
        # prepend labels with -100 to ignore loss on prompt tokens; 1 is the sequence dimension
        pad_length = inputs_embeds.size(1) - target_ids.size(1)
        padded_target_ids = F.pad(target_ids, (pad_length, 0), value=IGNORE_LOSS_ID)
        outputs = self.model(inputs_embeds=inputs_embeds, labels=padded_target_ids)
        return outputs.loss

    def compute_hard_loss(
        self, adversarial_tokens: torch.Tensor, target_text: str
    ) -> float:
        with torch.no_grad():
            # Discretize via argmax, then decode and re-encode the tokens
            discretized_tokens = adversarial_tokens.argmax(dim=-1)
            ids = self.tokenizer.encode(self.tokenizer.decode(discretized_tokens.squeeze(0)), return_tensors="pt", add_special_tokens=False).to(self.device)
            one_hot_tokens = F.one_hot(ids, num_classes=self.V).float()
            return self.compute_loss(one_hot_tokens, target_text).item()
Note, the loss computation above implicitly uses teacher forcing, where ground truth target tokens are used instead of the model’s own output in the auto-regressive decoding phase.
In the optimization loop, gradients are masked so that updates apply only to the non-user-prompt portion of the adversarial matrix, and the relevant projections follow each step. Both are done inside a torch.no_grad() context manager so that these computations are not tracked by the autograd engine.
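The label padding and teacher forcing can be illustrated without a model. The sketch below, in plain Python, mimics (as an assumption about the usual causal-LM convention, not a copy of the transformers source) a loss that shifts labels by one position and averages only over labels that are not -100. The probabilities are invented for the example.

```python
import math

IGNORE = -100  # label value excluded from the loss


def masked_next_token_nll(probs, labels):
    """Mean NLL of labels[i + 1] under probs[i], skipping ignored labels.

    probs[i] is the next-token distribution after reading position i,
    so predictions and labels are offset by one position.
    """
    losses = [
        -math.log(probs[i][labels[i + 1]])
        for i in range(len(labels) - 1)
        if labels[i + 1] != IGNORE
    ]
    return sum(losses) / len(losses)


# Toy sequence: 2 prompt positions followed by 2 target tokens (ids 1 and 0)
# over a 3-token vocabulary. Only the target positions carry labels.
labels = [IGNORE, IGNORE, 1, 0]
probs = [
    [0.2, 0.3, 0.5],  # after prompt position 0 (next label is ignored)
    [0.1, 0.8, 0.1],  # after prompt position 1: predicts target token 1
    [0.6, 0.3, 0.1],  # after the true target token 1 is fed in: predicts token 0
    [0.3, 0.3, 0.4],  # after the last token: nothing left to predict
]

loss = masked_next_token_nll(probs, labels)
print(loss)
```

The third row is where teacher forcing shows up: the distribution is conditioned on the true first target token, regardless of what the model would have sampled at that step.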
...
loss = self.compute_loss(adversarial_tokens, target)
loss.backward()
with torch.no_grad():
    # Zero the gradient on the user-prompt rows, keep it on prefix and suffix rows
    gradients = Util.overlay(torch.zeros_like(adversarial_tokens.grad), self.num_prefix, self.num_suffix, adversarial_tokens.grad)
    clipped_gradients = self.clip_norm(gradients, max_norm=20.0)
    adversarial_tokens.grad.copy_(clipped_gradients)
optimizer.step()
scheduler.step()
learning_rate = scheduler.get_last_lr()[0]
with torch.no_grad():
    hard_loss = self.compute_hard_loss(adversarial_tokens, target)
    # Project back onto the simplex, then onto the entropy constraint
    projected = self.simplex_projection(adversarial_tokens)
    target_entropy = self.compute_target_entropy(learning_rate, loss.item(), hard_loss)
    projected = self.entropy_projection(projected, target_entropy)
    adversarial_tokens.copy_(projected)
...
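The overall shape of the loop (gradient step, masking of the fixed rows, projection, final discretization) can be reproduced on a toy problem. The sketch below is plain Python with a made-up linear loss instead of an LLM, and it omits gradient clipping, the scheduler, and the entropy projection, so it is only an illustration of the step-then-project pattern.

```python
def project_to_simplex(v):
    """Sort-based Euclidean projection onto the probability simplex."""
    u = sorted(v, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        t = (cumsum - 1.0) / k
        if u_k - t > 0:
            tau = t
    return [max(x - tau, 0.0) for x in v]


# Toy relaxed "prompt": rows 0 and 2 are learnable, row 1 is a fixed one-hot
# user token. The loss is linear, sum_i <cost[i], X[i]>, so its gradient with
# respect to X[i] is simply cost[i].
cost = [
    [0.9, 0.1, 0.5],
    [0.3, 0.3, 0.3],
    [0.4, 0.8, 0.2],
]
X = [
    [1 / 3, 1 / 3, 1 / 3],
    [0.0, 1.0, 0.0],
    [1 / 3, 1 / 3, 1 / 3],
]
learnable = [True, False, True]
lr = 0.1

for step in range(200):
    for i, row in enumerate(X):
        if not learnable[i]:
            continue  # masked gradient: the user row is never updated
        stepped = [x - lr * g for x, g in zip(row, cost[i])]
        X[i] = project_to_simplex(stepped)  # back onto the simplex

token_ids = [max(range(len(row)), key=row.__getitem__) for row in X]
print(X, token_ids)
```

Each learnable row drifts toward the vocabulary entry with the lowest cost and ends up one-hot, while the fixed row keeps its original token.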
Some key things to ensure for stable training and convergence:
*TODO: I am actively exploring methods in alignment to increase resilience and will update this section.
Optimizing the relaxed objective alone, without aligning it to discretized tokens (i.e., finding adversarial embeddings), is trivial in comparison to finding discrete solutions (i.e., adversarial tokens). However, discrete adversarial tokens exist and are effective in steering open-weight models such as Vicuna-7B and Gemma-2B into generating harmful content.