Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions inference_perf/datagen/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,16 @@ class InferenceData(BaseModel):

class DataGenerator(ABC):
"""Abstract base class for data generators."""
apiType: APIType

@abstractmethod
def __init__(self, *args: Tuple[int, ...]) -> None:
pass
"""Abstract base class for data generators."""
def __init__(self, apiType: APIType) -> None:
    """Store the API type this generator will produce data for.

    Args:
        apiType: API type requested by the caller; it must be one of the
            values returned by ``get_supported_apis()``.

    Raises:
        Exception: If ``apiType`` is not among the supported API types.
    """
    supported_apis = self.get_supported_apis()
    if apiType in supported_apis:
        self.apiType = apiType
        return
    raise Exception(f"Unsupported API type {apiType}")

@abstractmethod
def get_api(self) -> APIType:
def get_supported_apis(self) -> List[APIType]:
    """Return the API types a concrete generator is able to serve.

    ``__init__`` calls this before ``self.apiType`` is assigned, so
    implementations should not depend on instance state set up later.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

@abstractmethod
Expand Down
28 changes: 22 additions & 6 deletions inference_perf/datagen/hf_sharegpt_datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import DataGenerator, InferenceData, ChatCompletionData, ChatMessage
from .base import DataGenerator, InferenceData, CompletionData, ChatCompletionData, ChatMessage
from inference_perf.config import APIType
from typing import Generator
from typing import Generator, List
from datasets import load_dataset


class HFShareGPTDataGenerator(DataGenerator):
def __init__(self) -> None:
def __init__(self, apiType: APIType) -> None:
super().__init__(apiType)
self.sharegpt_dataset = iter(
load_dataset(
"anon8231489123/ShareGPT_Vicuna_unfiltered",
Expand All @@ -34,8 +35,8 @@ def __init__(self) -> None:
# initialize data collection
next(self.sharegpt_dataset)

def get_api(self) -> APIType:
return APIType.Chat
def get_supported_apis(self) -> List[APIType]:
    """Return the API types this generator accepts: chat and completion."""
    supported_apis: List[APIType] = [APIType.Chat, APIType.Completion]
    return supported_apis

def get_data(self) -> Generator[InferenceData, None, None]:
if self.sharegpt_dataset is not None:
Expand All @@ -48,7 +49,20 @@ def get_data(self) -> Generator[InferenceData, None, None]:
or len(data[self.data_key]) == 0
):
continue
else:

if self.apiType == APIType.Completion:
try:
prompt = data[self.data_key][0].get(self.content_key)
if not prompt:
continue
yield InferenceData(
type=APIType.Completion,
data=CompletionData(prompt=prompt),
)
except (KeyError, TypeError) as e:
print(f"Skipping invalid completion data: {e}")
continue
elif self.apiType == APIType.Chat:
yield InferenceData(
type=APIType.Chat,
chat=ChatCompletionData(
Expand All @@ -58,3 +72,5 @@ def get_data(self) -> Generator[InferenceData, None, None]:
]
),
)
else:
raise Exception("Unsupported API type")
18 changes: 12 additions & 6 deletions inference_perf/datagen/mock_datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import DataGenerator, InferenceData, CompletionData
from typing import Generator
from typing import Generator, List
from inference_perf.config import APIType


class MockDataGenerator(DataGenerator):
    """Data generator that emits an endless stream of synthetic completion prompts.

    Only ``APIType.Completion`` is supported; the base-class constructor
    rejects any other API type.
    """

    def __init__(self, apiType: APIType) -> None:
        """Validate and store ``apiType`` via the base class.

        Args:
            apiType: API type the generated data is intended for.

        Raises:
            Exception: If ``apiType`` is not in ``get_supported_apis()``.
        """
        super().__init__(apiType)

    def get_supported_apis(self) -> List[APIType]:
        """Return the API types this generator can serve (completion only)."""
        return [APIType.Completion]

    def get_data(self) -> Generator[InferenceData, None, None]:
        """Yield completion requests with prompts ``"text1"``, ``"text2"``, ...

        The generator never terminates on its own; the consumer decides when
        to stop pulling items.

        Raises:
            Exception: If ``self.apiType`` is not ``APIType.Completion`` when
                the next item is requested.
        """
        i = 0
        while True:
            i += 1
            # Re-checked for every item so a changed apiType is caught on the next pull.
            if self.apiType != APIType.Completion:
                raise Exception("Unsupported API type")
            yield InferenceData(data=CompletionData(prompt="text" + str(i)))

6 changes: 3 additions & 3 deletions inference_perf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def main_cli() -> None:

# Define DataGenerator
if config.data:
    # Construct only the generator that was selected. Each constructor checks
    # the API type against its own supported list, so eagerly building a
    # MockDataGenerator (completion only) would raise for a chat API even
    # when the ShareGPT generator, which accepts chat, was requested.
    if config.data.type == DataGenType.ShareGPT:
        datagen = HFShareGPTDataGenerator(config.vllm.api)
    else:
        datagen = MockDataGenerator(config.vllm.api)
else:
    raise Exception("data config missing")

Expand Down