Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions llm/config/qwen/emb_argument.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@
"unified_checkpoint": true,
"use_flash_attention": true,
"amp_custom_black_list": "elementwise_div",
"release_grads": true
}
"release_grads": true,
"loss_type": "contrastive"
}
8 changes: 8 additions & 0 deletions llm/utils/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,11 @@ class EmbeddingArgument:
default=None,
metadata={"help": "The dims for matryoshka training."},
)
loss_type: str = field(
default="contrastive",
metadata={"help": "The type of loss computation."},
)
inf_cl_head_dim: int = field(
default=64,
metadata={"help": "The size of the head dimension when gpu ops are set as 'inf_cl'."},
)
1 change: 1 addition & 0 deletions ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# limitations under the License.

from .flash import cal_flash_loss
from .inf_cl_loss import *
from .ring import cal_inf_loss, cal_ring_loss
61 changes: 61 additions & 0 deletions ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import paddle
import paddle.nn as nn

from .ring import cal_inf_loss

__all__ = ["Simple_Inf_cl_loss", "Matryoshka_Inf_cl_loss"]


class Simple_Inf_cl_loss(nn.Layer):
    """InfoNCE-style contrastive loss computed with the memory-efficient
    ``inf_cl`` kernels (see ``cal_inf_loss`` in ``.ring``).

    Args:
        inf_cl_head_dim (int): Head dimension used by the ``inf_cl`` kernel
            to tile the similarity computation. Defaults to 64.
    """

    def __init__(self, inf_cl_head_dim=64):
        super().__init__()
        self.head_dim = inf_cl_head_dim

    def forward(self, q_reps, p_reps):
        """Compute the inf_cl contrastive loss.

        Args:
            q_reps (Tensor): Query embeddings, shape ``[num_queries, dim]``.
            p_reps (Tensor): Passage embeddings, shape
                ``[num_queries * group_size, dim]``. Each query owns a
                contiguous group of ``group_size`` passages; the passage at
                index ``i * group_size`` is treated as the positive for
                query ``i`` (assumption implied by the label layout below —
                confirm against the data collator).

        Returns:
            Tensor: The scalar loss produced by ``cal_inf_loss``.
        """
        # Number of passages per query; assumes p_reps rows are grouped
        # query-by-query in the same order as q_reps.
        group_size = p_reps.shape[0] // q_reps.shape[0]
        # Positive index for query i is the first passage of its group.
        labels = paddle.arange(q_reps.shape[0], dtype="int64") * group_size
        # scale=None lets the kernel fall back to its default temperature.
        return cal_inf_loss(q_reps, p_reps, labels=labels, scale=None, head_dim=self.head_dim)


class Matryoshka_Inf_cl_loss(nn.Layer):
    """inf_cl contrastive loss with optional Matryoshka (nested-prefix)
    training: the loss is accumulated over several truncated embedding
    widths so that every prefix of the embedding remains usable.

    Args:
        embedding_matryoshka_dims (Optional[List[int]]): Prefix widths to
            train on. When ``None`` or empty, the full embeddings are used
            unchanged (no extra normalization).
        inf_cl_head_dim (int): Head dimension forwarded to the underlying
            ``Simple_Inf_cl_loss``. Defaults to 64.
    """

    def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl_head_dim=64):
        super().__init__()
        self.embedding_matryoshka_dims = (
            [] if embedding_matryoshka_dims is None else embedding_matryoshka_dims
        )
        self.loss_fn = Simple_Inf_cl_loss(inf_cl_head_dim)

    def forward(self, q_reps, p_reps):
        """Sum the inf_cl loss over each configured matryoshka width, or
        compute it once on the full embeddings when none are configured."""
        if not self.embedding_matryoshka_dims:
            # No matryoshka widths: plain inf_cl loss on the raw embeddings.
            return self.loss_fn(q_reps, p_reps)

        total = 0.0
        for width in self.embedding_matryoshka_dims:
            # Truncate both sides to the leading `width` features, then
            # re-normalize so each prefix lies on the unit sphere before
            # the similarity-based loss is computed.
            q_prefix = nn.functional.normalize(q_reps[:, :width], axis=-1)
            p_prefix = nn.functional.normalize(p_reps[:, :width], axis=-1)
            total += self.loss_fn(q_prefix, p_prefix)
        return total
20 changes: 16 additions & 4 deletions paddlenlp/trl/embedding_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from paddle.base import core
from paddle.distributed import fleet

from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import (
from paddlenlp_kernel.triton.inf_cl.inf_cl_loss import (

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This package is not installed by default, so the import needs to be wrapped in a try/except.

Matryoshka_Inf_cl_loss,
Simple_Inf_cl_loss,
)
from paddlenlp.trainer import Trainer
from paddlenlp.transformers.contrastive_loss import (
MatryoshkaContrastiveLoss,
Expand All @@ -44,11 +48,19 @@ def __init__(self, model_args, **kwargs):
self.accum_rng_states["hybrid"] = []

if model_args.embedding_matryoshka_dims is not None and len(model_args.embedding_matryoshka_dims) > 0:
self.loss_fn = MatryoshkaContrastiveLoss(
model_args.embedding_temperature, model_args.embedding_matryoshka_dims
)
if model_args.loss_type == "inf_cl":
self.embedding_negatives_cross_device = False
self.loss_fn = Matryoshka_Inf_cl_loss(model_args.embedding_matryoshka_dims, model_args.inf_cl_head_dim)
elif model_args.loss_type == "contrastive":
self.loss_fn = MatryoshkaContrastiveLoss(
model_args.embedding_temperature, model_args.embedding_matryoshka_dims
)
else:
self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature)
if model_args.loss_type == "inf_cl":
self.embedding_negatives_cross_device = False
self.loss_fn = Simple_Inf_cl_loss(model_args.inf_cl_head_dim)
elif model_args.loss_type == "contrastive":
self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature)

def clear_memory(self):
self.accum_q_features.clear()
Expand Down