From dae8f62372c23b54b5c72ff93398180f150cdd77 Mon Sep 17 00:00:00 2001
From: sandyfirst <13823237161@139.com>
Date: Wed, 14 May 2025 22:07:52 +0800
Subject: [PATCH 1/2] add _form_columns_description

---
Review notes:
- Restored a docstring for _form_columns_description.
- Datetime columns with no valid values (empty or all-NaT) no longer raise
  from NaT.strftime(); they are reported as having no valid values.
- Replaced pd.api.types.is_categorical_dtype (deprecated in recent pandas
  2.x releases) with an isinstance check against pd.CategoricalDtype.
- Translated the inline comment to English.
- The blob ids on the "index" lines predate these review changes.

 sdgx/models/LLM/base.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 47 insertions(+), 8 deletions(-)

diff --git a/sdgx/models/LLM/base.py b/sdgx/models/LLM/base.py
index 07e6f109..4ac8d3c2 100644
--- a/sdgx/models/LLM/base.py
+++ b/sdgx/models/LLM/base.py
@@ -1,7 +1,7 @@
 from sdgx.exceptions import SynthesizerInitError
 from sdgx.models.base import SynthesizerModel
 from sdgx.utils import logger
-
+import pandas as pd
 
 class LLMBaseModel(SynthesizerModel):
     """
@@ -87,13 +87,52 @@ def _check_access_type(self):
             raise SynthesizerInitError("Duplicate data access type found.")
 
     def _form_columns_description(self):
-        """
-        We believe that giving information about a column helps improve data quality.
-
-        Currently, we leave this function to Good First Issue until March 2024, if unclaimed we will implement it quickly.
-        """
-
-        raise NotImplementedError
+        """
+        Build a plain-text description of every column in ``self.raw_data``.
+
+        One line is generated per column, chosen by the column's dtype:
+
+        - numeric: dtype, min, max, mean and standard deviation;
+        - datetime: earliest and latest date formatted as ``YYYY-MM-DD``,
+          or a note that the column holds no valid dates;
+        - categorical dtype, or fewer than 20 distinct values: the number
+          of distinct values and a preview of the first five;
+        - anything else: the dtype only.
+
+        Returns:
+            str: The per-column lines joined with newlines.
+        """
+
+        df = self.raw_data  # expected to be a pandas.DataFrame
+        desc_lines = []
+
+        for col in df.columns:
+            series = df[col]
+            dtype = series.dtype
+
+            if pd.api.types.is_numeric_dtype(dtype):
+                line = (f'Column "{col}": type {dtype}, '
+                        f'min {series.min()}, max {series.max()}, '
+                        f'mean {series.mean():.2f}, std {series.std():.2f}.')
+            elif pd.api.types.is_datetime64_any_dtype(dtype):
+                earliest, latest = series.min(), series.max()
+                if pd.isna(earliest):
+                    # Empty or all-NaT column: min() is NaT, and NaT.strftime() raises.
+                    line = f'Column "{col}": type datetime, no valid values.'
+                else:
+                    line = (f'Column "{col}": type datetime, '
+                            f'from {earliest.strftime("%Y-%m-%d")}, '
+                            f'to {latest.strftime("%Y-%m-%d")}.')
+            elif isinstance(dtype, pd.CategoricalDtype) or series.nunique() < 20:
+                values = series.unique()
+                line = (f'Column "{col}": type category, '
+                        f'{len(values)} categories: {list(values[:5])}{"..." if len(values) > 5 else ""}.')
+            else:
+                line = f'Column "{col}": type {dtype}.'
+
+            desc_lines.append(line)
+
+        return "\n".join(desc_lines)
 
     def _form_message_with_offtable_features(self):
         """
From 6a10f35817e97f34dbaf3de2dd3e6f1779088335 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 14 May 2025 14:20:05 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
Review notes:
- Regenerated against the revised 1/2. The former trailing-whitespace hunk
  is gone because the revised method no longer contains that line.
- The blob ids on the "index" line predate these review changes.

 sdgx/models/LLM/base.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/sdgx/models/LLM/base.py b/sdgx/models/LLM/base.py
index 4ac8d3c2..4bddac09 100644
--- a/sdgx/models/LLM/base.py
+++ b/sdgx/models/LLM/base.py
@@ -1,7 +1,9 @@
+import pandas as pd
+
 from sdgx.exceptions import SynthesizerInitError
 from sdgx.models.base import SynthesizerModel
 from sdgx.utils import logger
-import pandas as pd
+
 
 class LLMBaseModel(SynthesizerModel):
     """
@@ -111,22 +113,28 @@ def _form_columns_description(self):
             dtype = series.dtype
 
             if pd.api.types.is_numeric_dtype(dtype):
-                line = (f'Column "{col}": type {dtype}, '
-                        f'min {series.min()}, max {series.max()}, '
-                        f'mean {series.mean():.2f}, std {series.std():.2f}.')
+                line = (
+                    f'Column "{col}": type {dtype}, '
+                    f"min {series.min()}, max {series.max()}, "
+                    f"mean {series.mean():.2f}, std {series.std():.2f}."
+                )
             elif pd.api.types.is_datetime64_any_dtype(dtype):
                 earliest, latest = series.min(), series.max()
                 if pd.isna(earliest):
                     # Empty or all-NaT column: min() is NaT, and NaT.strftime() raises.
                     line = f'Column "{col}": type datetime, no valid values.'
                 else:
-                    line = (f'Column "{col}": type datetime, '
-                            f'from {earliest.strftime("%Y-%m-%d")}, '
-                            f'to {latest.strftime("%Y-%m-%d")}.')
+                    line = (
+                        f'Column "{col}": type datetime, '
+                        f'from {earliest.strftime("%Y-%m-%d")}, '
+                        f'to {latest.strftime("%Y-%m-%d")}.'
+                    )
             elif isinstance(dtype, pd.CategoricalDtype) or series.nunique() < 20:
                 values = series.unique()
-                line = (f'Column "{col}": type category, '
-                        f'{len(values)} categories: {list(values[:5])}{"..." if len(values) > 5 else ""}.')
+                line = (
+                    f'Column "{col}": type category, '
+                    f'{len(values)} categories: {list(values[:5])}{"..." if len(values) > 5 else ""}.'
+                )
             else:
                 line = f'Column "{col}": type {dtype}.'
 