diff --git a/python/tests/test_DFtoVW.py b/python/tests/test_DFtoVW.py
index eb5b83c9f04..ccceee52488 100644
--- a/python/tests/test_DFtoVW.py
+++ b/python/tests/test_DFtoVW.py
@@ -126,6 +126,17 @@ def test_multiple_named_namespaces_multiple_features_multiple_lines():
     ]
 
 
+def test_multiple_lines_with_weight():
+    df = pd.DataFrame({
+        "y": [1, 2, -1],
+        "w": [2.5, 1.2, 3.75],
+        "x": ["a", "b", "c"]
+    })
+    conv = DFtoVW(df=df, label=SimpleLabel(label="y", weight="w"), features=Feature("x"))
+    lines_list = conv.convert_df()
+    assert lines_list == ['1 2.5 | x=a', '2 1.2 | x=b', '-1 3.75 | x=c']  # one "<y> <w> | x=<x>" string per row
+
+
 # Exception tests for SimpleLabel
 def test_absent_col_error():
     with pytest.raises(ValueError) as value_error:
@@ -162,6 +173,14 @@ def test_wrong_feature_type_error():
     assert expected == str(type_error.value)
 
 
+def test_wrong_weight_type_error():
+    df = pd.DataFrame({"y": [1], "x": [2], "w": ["a"]})  # "w" holds a string, not an int/float
+    with pytest.raises(TypeError) as type_error:
+        DFtoVW(df=df, label=SimpleLabel(label="y", weight="w"), features=Feature("x"))
+    expected = "In argument 'weight' of 'SimpleLabel', column 'w' should be either of the following type(s): 'int', 'float'."
+    assert expected == str(type_error.value)
+
+
 # Tests for MulticlassLabel
 def test_multiclasslabel():
     df = pd.DataFrame({"a": [1], "b": [0.5], "c": [-3]})
diff --git a/python/vowpalwabbit/DFtoVW.py b/python/vowpalwabbit/DFtoVW.py
index a1164952f5b..39eb098cfd0 100644
--- a/python/vowpalwabbit/DFtoVW.py
+++ b/python/vowpalwabbit/DFtoVW.py
@@ -225,20 +225,24 @@ class SimpleLabel(object):
     """The simple label type for the constructor of DFtoVW."""
 
     label = AttributeDescriptor("label", expected_type=(int, float))
+    weight = AttributeDescriptor("weight", expected_type=(int, float))
 
-    def __init__(self, label):
+    def __init__(self, label, weight=None):
         """Initialize a SimpleLabel instance.
 
         Parameters
         ----------
         label : str
             The column name with the label.
+        weight : str, optional
+            The column name with the weight. Defaults to None (no weight).
 
         Returns
         -------
         self : SimpleLabel
         """
         self.label = label
+        self.weight = weight
 
     def process(self, df):
         """Returns the SimpleLabel string representation.
@@ -253,7 +257,10 @@ def process(self, df):
         pandas.Series
             The SimpleLabel string representation.
         """
-        return self.label.get_col(df)
+        out = self.label.get_col(df)
+        if self.weight is not None:  # weight is optional; label alone when unset
+            out += " " + self.weight.get_col(df)
+        return out
 
 
 class MulticlassLabel(object):
@@ -292,12 +299,9 @@ def process(self, df):
         pandas.Series
             The MulticlassLabel string representation.
         """
-        label_col = self.label.get_col(df)
+        out = self.label.get_col(df)
         if self.weight is not None:
-            weight_col = self.weight.get_col(df)
-            out = label_col + " " + weight_col
-        else:
-            out = label_col
+            out += " " + self.weight.get_col(df)
         return out
 
 