diff --git a/docs/source/api_ref_data.rst b/docs/source/api_ref_data.rst index b487db8cdf..3f868e2048 100644 --- a/docs/source/api_ref_data.rst +++ b/docs/source/api_ref_data.rst @@ -64,6 +64,7 @@ Converts data from common schema and conversation JSON formats into a list of to ShareGPTToMessages OpenAIToMessages ChosenRejectedToMessages + AlpacaToMessages Collaters --------- diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index 2f490b2025..0f7e00ae99 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -17,6 +17,7 @@ from torchtune.data._converters import get_openai_messages, get_sharegpt_messages from torchtune.data._instruct_templates import InstructTemplate from torchtune.data._messages import ( + AlpacaToMessages, ChosenRejectedToMessages, InputOutputToMessages, Message, @@ -43,6 +44,7 @@ "SummarizeTemplate", "OpenAIToMessages", "ShareGPTToMessages", + "AlpacaToMessages", "truncate", "Message", "validate_messages", diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index fb24d9fbe7..86d40fe084 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -621,3 +621,70 @@ def validate_messages( f"System message at index {i} in messages, but system messages must come first" ) last_turn = message.role + + +class AlpacaToMessages(Transform): + """ + Message transform class for Alpaca-style datasets with "instruction", "input", and "output" + (or equivalent fields specified in column_map) columns. User messages are formed from the + instruction + input columns and assistant messages are formed from the output column. Prompt + templating is conditional on the presence of the "input" column, and thus is handled directly + in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class + due to this custom logic. + + Args: + train_on_input (bool): Whether the model is trained on the user prompt or not. + Default is True. + column_map (Optional[Dict[str, str]]): a mapping to change the expected "instruction", "input", + and "output" column names to the actual column names in the dataset. Default is None, + keeping the default column names. + """ + + def __init__( + self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None + ): + self.train_on_input = train_on_input + self.column_map = column_map + self.template = { + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:\n" + ), + } + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + column_map = self.column_map or {} + key_input = column_map.get("input", "input") + key_instruction = column_map.get("instruction", "instruction") + key_output = column_map.get("output", "output") + + if key_input in sample and sample[key_input]: + prompt = self.template["prompt_input"].format( + instruction=sample[key_instruction], input=sample[key_input] + ) + else: + prompt = self.template["prompt_no_input"].format( + instruction=sample[key_instruction] + ) + + messages = [ + Message( + role="user", + content=prompt, + masked=not self.train_on_input, + eot=True, + ), + Message( + role="assistant", + content=sample[key_output], + masked=False, + eot=True, + ), + ] + return {"messages": messages} diff --git a/torchtune/datasets/_alpaca.py b/torchtune/datasets/_alpaca.py index 981e33316b..47d4155743 100644 --- a/torchtune/datasets/_alpaca.py +++ b/torchtune/datasets/_alpaca.py @@ -6,80 +6,13 @@ from functools import partial -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any, Dict, Optional, Union + +from torchtune.data._messages import AlpacaToMessages -from torchtune.data._messages import Message from torchtune.datasets._packed import PackedDataset from torchtune.datasets._sft import SFTDataset from torchtune.modules.tokenizers import ModelTokenizer -from torchtune.modules.transforms import Transform - - -class AlpacaToMessages(Transform): - """ - Message transform class for Alpaca-style datasets with "instruction", "input", and "output" - (or equivalent fields specified in column_map) columns. User messages are formed from the - instruction + input columns and assistant messages are formed from the output column. Prompt - templating is conditional on the presence of the "input" column, and thus is handled directly - in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class - due to this custom logic. - - Args: - train_on_input (bool): Whether the model is trained on the user prompt or not. - Default is True. - column_map (Optional[Dict[str, str]]): a mapping to change the expected "instruction", "input", - and "output" column names to the actual column names in the dataset. Default is None, - keeping the default column names. - """ - - def __init__( - self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None - ): - self.train_on_input = train_on_input - self.column_map = column_map - self.template = { - "prompt_input": ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - ), - "prompt_no_input": ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:\n" - ), - } - - def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: - column_map = self.column_map or {} - key_input = column_map.get("input", "input") - key_instruction = column_map.get("instruction", "instruction") - key_output = column_map.get("output", "output") - - if key_input in sample and sample[key_input]: - prompt = self.template["prompt_input"].format( - instruction=sample[key_instruction], input=sample[key_input] - ) - else: - prompt = self.template["prompt_no_input"].format( - instruction=sample[key_instruction] - ) - - messages = [ - Message( - role="user", - content=prompt, - masked=not self.train_on_input, - eot=True, - ), - Message( - role="assistant", - content=sample[key_output], - masked=False, - eot=True, - ), - ] - return {"messages": messages} def alpaca_dataset(