diff --git a/sdgx/models/LLM/base.py b/sdgx/models/LLM/base.py index 07e6f109..4bddac09 100644 --- a/sdgx/models/LLM/base.py +++ b/sdgx/models/LLM/base.py @@ -1,3 +1,5 @@ +import pandas as pd + from sdgx.exceptions import SynthesizerInitError from sdgx.models.base import SynthesizerModel from sdgx.utils import logger @@ -87,13 +89,38 @@ def _check_access_type(self): raise SynthesizerInitError("Duplicate data access type found.") def _form_columns_description(self): - """ - We believe that giving information about a column helps improve data quality. - - Currently, we leave this function to Good First Issue until March 2024, if unclaimed we will implement it quickly. - """ - raise NotImplementedError + df = self.raw_data # 确保 self.raw_data 是一个 pandas.DataFrame + desc_lines = [] + + for col in df.columns: + series = df[col] + dtype = series.dtype + + if pd.api.types.is_numeric_dtype(dtype): + line = ( + f'Column "{col}": type {dtype}, ' + f"min {series.min()}, max {series.max()}, " + f"mean {series.mean():.2f}, std {series.std():.2f}." + ) + elif pd.api.types.is_datetime64_any_dtype(dtype): + line = ( + f'Column "{col}": type datetime, ' + f'from {series.min().strftime("%Y-%m-%d")}, ' + f'to {series.max().strftime("%Y-%m-%d")}.' + ) + elif pd.api.types.is_categorical_dtype(series) or series.nunique() < 20: + values = series.unique() + line = ( + f'Column "{col}": type category, ' + f'{len(values)} categories: {list(values[:5])}{"..." if len(values) > 5 else ""}.' + ) + else: + line = f'Column "{col}": type {dtype}.' + + desc_lines.append(line) + + return "\n".join(desc_lines) def _form_message_with_offtable_features(self): """