swift/llm/utils/dataset.py (15 additions & 7 deletions)

@@ -46,15 +46,23 @@ def _update_fingerprint_mac(*args, **kwargs):
datasets.arrow_dataset.update_fingerprint = _update_fingerprint_mac


-def partialed_map(self, *args, **kwargs):
-    if 'num_proc' not in kwargs:
-        num_proc = os.environ.get('DATASET_MAP_NPROC')
-        kwargs['num_proc'] = int(num_proc) if num_proc else num_proc
-    return self._origin_map(*args, **kwargs)
+def patch_num_proc(func_name: str):
+    _origin_func_name = f'_origin_{func_name}'
+    _old_func = getattr(HfDataset, func_name)
+
+    def new_func(self, *args, **kwargs):
+        if 'num_proc' not in kwargs:
+            num_proc = os.environ.get('DATASET_MAP_NPROC')
+            if num_proc:
+                kwargs['num_proc'] = int(num_proc)
+        return _old_func(self, *args, **kwargs)
+
-datasets.Dataset._origin_map = datasets.Dataset.map
-datasets.Dataset.map = partialed_map
+    setattr(HfDataset, _origin_func_name, _old_func)
+    setattr(HfDataset, func_name, new_func)


+for func_name in ['map', 'filter']:
+    patch_num_proc(func_name)

standard_keys = {
    'query', 'query_role', 'response', 'rejected_response', 'system', 'history', 'history_roles', 'images', 'objects',
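
For readers unfamiliar with this kind of monkey patch, here is a minimal, self-contained sketch of the same pattern as `patch_num_proc` above: wrap a method so that a missing `num_proc` keyword is filled in from the `DATASET_MAP_NPROC` environment variable. The `Catalog`/`scan` names are illustrative stand-ins, not part of ms-swift or the `datasets` library.

```python
# A standalone sketch of the patching pattern, assuming only the standard
# library. `Catalog` and `scan` are illustrative stand-ins.
import os


class Catalog:

    def scan(self, *, num_proc: int = 1) -> str:
        return f'scanning with num_proc={num_proc}'


def patch_num_proc(cls, func_name: str) -> None:
    _old_func = getattr(cls, func_name)

    def new_func(self, *args, **kwargs):
        # Only inject num_proc when the caller did not pass it explicitly.
        if 'num_proc' not in kwargs:
            num_proc = os.environ.get('DATASET_MAP_NPROC')
            if num_proc:
                kwargs['num_proc'] = int(num_proc)
        return _old_func(self, *args, **kwargs)

    # Keep the original method around so the patch can be undone or detected.
    setattr(cls, f'_origin_{func_name}', _old_func)
    setattr(cls, func_name, new_func)


if __name__ == '__main__':
    os.environ['DATASET_MAP_NPROC'] = '4'
    patch_num_proc(Catalog, 'scan')
    print(Catalog().scan())  # scanning with num_proc=4
```

Keeping the original under an `_origin_<name>` attribute mirrors the diff above and makes it easy to patch several methods (here `map` and `filter`) with one helper.
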
swift/llm/utils/preprocess.py (34 additions & 9 deletions)

@@ -1,8 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
import os
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import numpy as np
from datasets import Dataset as HfDataset
from datasets import IterableDataset as HfIterableDataset
from tqdm import tqdm
@@ -30,29 +32,50 @@ def _reduce_columns(cls: type) -> type:
    cls._patching = True

    def new_call_func(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
-        self.column_state = set(['images', 'videos', 'audios'])
+        self.key_mapping = {k: i for i, k in enumerate(self.empty_row.keys())}
+        num_proc = int(os.environ.get('DATASET_MAP_NPROC', '1'))
+        self.shared_shm_name = None
+        shm, buffer = None, None
+        if num_proc > 1:  # multiprocess
+            shm = shared_memory.SharedMemory(create=True, size=len(self.key_mapping))
+            self.shared_shm_name = shm.name
+            buffer = shm.buf
+        self.column_state = np.ndarray((len(self.key_mapping), ), dtype=np.bool_, buffer=buffer)
        dataset = call_func(self, dataset)
        if isinstance(dataset, HfIterableDataset) and dataset.features is None:
            features = next(iter(dataset)).keys()
        else:
            features = dataset.features.keys()
        for k in features:
-            if k not in self.column_state:
+            if k in ['images', 'videos', 'audios']:
+                continue
+            k_i = self.key_mapping.get(k, -1)
+            if k_i == -1 or not self.column_state[k_i]:
                dataset = dataset.remove_columns([k])
+        if shm:
+            shm.close()
+            shm.unlink()
        return dataset

    def new_preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
+        if self.shared_shm_name is not None:  # multiprocess
+            shm = shared_memory.SharedMemory(name=self.shared_shm_name)
+            column_state = np.ndarray((len(self.key_mapping), ), dtype=np.bool_, buffer=shm.buf)
+        else:
+            column_state = self.column_state
        row = preprocess(self, row)
        for k, v in row.items():
+            k_i = self.key_mapping[k]
+            if column_state[k_i]:
+                continue
            if k == 'query_role':
-                if k not in self.column_state and v and v != 'user':
-                    self.column_state.add(k)
+                if v and v != 'user':
+                    column_state[k_i] = True
            elif k == 'history_roles':
-                if k not in self.column_state and v and any(_v[0] != 'user' or _v[1] != 'assistant' for _v in v):
-                    self.column_state.add(k)
-            else:
-                if v:
-                    self.column_state.add(k)
+                if v and any(_v[0] != 'user' or _v[1] != 'assistant' for _v in v):
+                    column_state[k_i] = True
+            elif v:
+                column_state[k_i] = True
        return row

    cls.__call__ = new_call_func
@@ -142,6 +165,7 @@ def __call__(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
        return dataset


+@_reduce_columns
class AlpacaPreprocessor(MediaMixin, RowPreprocessMixin):

    def __init__(self, concat_inst_inp: Optional[Callable[[str, str], str]] = None, **kwargs):
Expand Down Expand Up @@ -194,6 +218,7 @@ def _default_repair_conversations(s: Union[str, Any]) -> Any:
    return s


+@_reduce_columns
class ConversationsPreprocessor(MediaMixin, RowPreprocessMixin):

    def __init__(self,
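
The shared-memory bookkeeping above exists because `map` with `num_proc > 1` runs the preprocess function in worker processes, so in-process state like the old `set` on `self` would never be seen by the parent. A rough standalone sketch of that technique follows; the column names and the `mark_non_empty` worker are made up for illustration, and only numpy plus the standard library are assumed.

```python
# Workers attach to a named SharedMemory block, flip per-column booleans,
# and the parent reads the flags once the pool has finished.
import numpy as np
from multiprocessing import Pool, shared_memory

COLUMNS = ['query', 'response', 'system', 'history']


def mark_non_empty(args):
    shm_name, row = args
    # Attach to the block the parent created and flag non-empty columns.
    shm = shared_memory.SharedMemory(name=shm_name)
    state = np.ndarray((len(COLUMNS), ), dtype=np.bool_, buffer=shm.buf)
    for i, col in enumerate(COLUMNS):
        if row.get(col):
            state[i] = True
    shm.close()


if __name__ == '__main__':
    shm = shared_memory.SharedMemory(create=True, size=len(COLUMNS))
    state = np.ndarray((len(COLUMNS), ), dtype=np.bool_, buffer=shm.buf)
    state[:] = False
    rows = [{'query': 'hi', 'response': 'hello'}, {'query': 'bye', 'response': 'see you'}]
    with Pool(2) as pool:
        pool.map(mark_non_empty, [(shm.name, row) for row in rows])
    print([c for c, used in zip(COLUMNS, state) if used])  # ['query', 'response']
    shm.close()
    shm.unlink()
```

This is essentially what `new_call_func` and `new_preprocess` do with `shared_shm_name` and `column_state`: the parent creates and finally unlinks the block, each worker attaches by name, and columns whose flag never flips are removed at the end.
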
swift/llm/utils/template.py (1 addition & 0 deletions)

@@ -638,6 +638,7 @@ def _concat_context_list(
                new_str_list = [system, query, round0, round1]
                for (old_str, new_str) in zip(old_str_list, new_str_list):
                    if new_str is not None and old_str in context:
+                        assert isinstance(new_str, str), f'new_str: {new_str}'
                        context = context.replace(old_str, new_str)
            if len(context) == 0:
                continue
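
A small illustration of the failure mode this assert guards against, with made-up values: `str.replace` raises a terse `TypeError` when the replacement is not a string, while the assert reports the offending value directly.

```python
# '1' is a valid replacement; 1 is the kind of value the assert is meant to
# flag before str.replace ever sees it.
context = 'Round {{ROUND0}}: {{QUERY}}'
for new_str in ('1', 1):
    try:
        assert isinstance(new_str, str), f'new_str: {new_str}'
        print(context.replace('{{ROUND0}}', new_str))
    except AssertionError as e:
        print(f'AssertionError: {e}')  # AssertionError: new_str: 1
```
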
swift/llm/utils/utils.py (1 addition & 1 deletion)

@@ -431,7 +431,7 @@ def _find_module_list(vision_tower) -> Optional[nn.ModuleList]:
            return
        if isinstance(m, nn.ModuleList) and len(m) >= 10:
            module_lists.append(m)
-    if module_lists is not None:
+    if module_lists:
        return max(module_lists, key=lambda x: len(x))


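
This one-word change is a real bug fix: `module_lists` is always a list at this point, so the old `is not None` check passes even when it is empty and the empty list falls through to `max()`, which raises `ValueError`. A quick sketch of the difference, with `module_lists` standing in for the list the function builds:

```python
# An empty list shows why truthiness is the right check here.
module_lists = []

if module_lists is not None:  # old check: passes even for an empty list
    try:
        max(module_lists, key=lambda x: len(x))
    except ValueError:
        print('max() raised ValueError on an empty list')

if module_lists:  # new check: an empty list is falsy, so max() is never reached
    print(max(module_lists, key=lambda x: len(x)))
```
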