diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index bbec090..6d1bbaa 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -32,6 +32,8 @@ jobs: run: tox -e pyroma - name: Check code quality run: tox -e flake8 + - name: Check typing + run: tox -e mypy docs: name: Documentation runs-on: ubuntu-latest diff --git a/chemicalx/data/batchgenerator.py b/chemicalx/data/batchgenerator.py index 9e5d50a..1bcbfe4 100644 --- a/chemicalx/data/batchgenerator.py +++ b/chemicalx/data/batchgenerator.py @@ -1,7 +1,7 @@ """A module for the batch generator class.""" import math -from typing import Iterable, Iterator, List, Optional +from typing import Iterable, Iterator, Optional, Sequence import numpy as np import pandas as pd @@ -27,10 +27,9 @@ def __init__( context_features: bool, drug_features: bool, drug_molecules: bool, - labels: bool, context_feature_set: Optional[ContextFeatureSet], drug_feature_set: Optional[DrugFeatureSet], - labeled_triples: Optional[LabeledTriples], + labeled_triples: LabeledTriples, ): """Initialize a batch generator. @@ -39,7 +38,6 @@ def __init__( context_features: Indicator whether the batch should include biological context features. drug_features: Indicator whether the batch should include drug features. drug_molecules: Indicator whether the batch should include drug molecules - labels: Indicator whether the batch should include drug pair labels. context_feature_set: A context feature set for feature generation. drug_feature_set: A drug feature set for feature generation. labeled_triples: A labeled triples object used to generate batches. @@ -48,7 +46,6 @@ def __init__( self.context_features = context_features self.drug_features = drug_features self.drug_molecules = drug_molecules - self.labels = labels self.context_feature_set = context_feature_set self.drug_feature_set = drug_feature_set self.labeled_triples = labeled_triples @@ -89,16 +86,15 @@ def _get_drug_molecules(self, drug_identifiers: Iterable[str]) -> Optional[Packe return None return self.drug_feature_set.get_molecules(drug_identifiers) - def _transform_labels(self, labels: List): + @classmethod + def _transform_labels(cls, labels: Sequence[float]) -> torch.FloatTensor: """Transform the labels from a chunk of the labeled triples frame. Args: - labels (pd.Series): The drug pair binary labels. + labels: The drug pair binary labels. Returns: - labels (torch.FloatTensor): The label target vector as a column vector. + labels : The label target vector as a column vector. """ - if not self.labels: - return None return torch.FloatTensor(np.array(labels).reshape(-1, 1)) def generate_batch(self, batch_frame: pd.DataFrame) -> DrugPairBatch: diff --git a/chemicalx/data/contextfeatureset.py b/chemicalx/data/contextfeatureset.py index c82a4ef..35b8738 100644 --- a/chemicalx/data/contextfeatureset.py +++ b/chemicalx/data/contextfeatureset.py @@ -1,8 +1,8 @@ """A module for the context feature set class.""" -from typing import Dict, Iterable +from collections import UserDict +from typing import Iterable, Mapping -import numpy as np import torch __all__ = [ @@ -10,122 +10,9 @@ ] -class ContextFeatureSet(dict): +class ContextFeatureSet(UserDict, Mapping[str, torch.FloatTensor]): """Context feature set for biological/chemical context feature vectors.""" - def __setitem__(self, context: str, features: np.ndarray) -> None: - """Set the feature vector for a biological context key. - - Args: - context: Biological or chemical context identifier. - features: Feature vector for the context. - """ - self.__dict__[context] = torch.FloatTensor(features) - - def __getitem__(self, context: str) -> torch.FloatTensor: - """Get the feature vector for a biological context key. - - Args: - context: Biological or chemical context identifier. - Returns: - : The feature vector corresponding to the key. - """ - return self.__dict__[context] - - def __len__(self) -> int: - """Get the number of biological/chemical contexts. - - Returns: - : The number of contexts. - """ - return len(self.__dict__) - - def __delitem__(self, context: str) -> None: - """Delete the feature vector for a biological context key. - - Args: - context: Biological or chemical context identifier. - """ - del self.__dict__[context] - - def clear(self) -> "ContextFeatureSet": - """Delete all the contexts from the context feature set. - - Returns: - : An empty context feature set. - """ - return self.__dict__.clear() - - def has_context(self, context: str) -> bool: - """Check whether a context feature set contains a context. - - Args: - context: Biological or chemical context identifier. - Returns: - : Boolean describing whether the context is in the context set. - """ - return context in self.__dict__ - - def update(self, data: Dict[str, np.ndarray]): - """Update a dictionary of context keys - feature vector values to a context set. - - Args: - data (dict): A dictionary of context keys with feature vector values. - Returns: - ContextFeatureSet: The updated context feature set. - """ - return self.__dict__.update({context: torch.FloatTensor(features) for context, features in data.items()}) - - def keys(self): - """Get the list of biological / chemical contexts in a feature set. - - Returns: - list: An iterator of context identifiers. - """ - return self.__dict__.keys() - - def values(self): - """Get the iterator of context feature vectors. - - Returns: - list: Feature vector iterator. - """ - return self.__dict__.values() - - def items(self): - """Get the iterator of tuples containing context identifier - feature vector pairs. - - Returns: - list: An iterator of (context - feature vector) tuples. - """ - return self.__dict__.items() - - def __contains__(self, context: str) -> bool: - """Check if the context is in keys. - - Args: - context (str): A context identifier. - Returns: - bool: An indicator whether the context is in the context feature set. - """ - return context in self.__dict__ - - def __iter__(self): - """Iterate over the context feature set. - - Returns: - iterable: An iterable of the context feature set. - """ - return iter(self.__dict__) - - def get_context_count(self) -> int: - """Get the number of biological contexts. - - Returns: - : The number of contexts. - """ - return len(self.__dict__) - def get_feature_matrix(self, contexts: Iterable[str]) -> torch.FloatTensor: """Get the feature matrix for a list of contexts. @@ -134,5 +21,4 @@ def get_feature_matrix(self, contexts: Iterable[str]) -> torch.FloatTensor: Return: features: A matrix of context features. """ - features = torch.cat([self.__dict__[context] for context in contexts]) - return features + return torch.cat([self.data[context] for context in contexts]) diff --git a/chemicalx/data/datasetloader.py b/chemicalx/data/datasetloader.py index c213c53..a8201d2 100644 --- a/chemicalx/data/datasetloader.py +++ b/chemicalx/data/datasetloader.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd +import torch from .batchgenerator import BatchGenerator from .contextfeatureset import ContextFeatureSet @@ -44,8 +45,8 @@ def get_generators( context_features: bool, drug_features: bool, drug_molecules: bool, - labels: bool, - **kwargs, + train_size: Optional[float] = None, + random_state: Optional[int] = None, ) -> Tuple[BatchGenerator, BatchGenerator]: """Generate a pre-stratified pair of batch generators.""" return cast( @@ -56,10 +57,12 @@ def get_generators( context_features=context_features, drug_features=drug_features, drug_molecules=drug_molecules, - labels=labels, labeled_triples=labeled_triples, ) - for labeled_triples in self.get_labeled_triples().train_test_split(**kwargs) + for labeled_triples in self.get_labeled_triples().train_test_split( + train_size=train_size, + random_state=random_state, + ) ), ) @@ -69,7 +72,6 @@ def get_generator( context_features: bool, drug_features: bool, drug_molecules: bool, - labels: bool, labeled_triples: Optional[LabeledTriples] = None, ) -> BatchGenerator: """Initialize a batch generator. @@ -88,7 +90,6 @@ def get_generator( context_features=context_features, drug_features=drug_features, drug_molecules=drug_molecules, - labels=labels, context_feature_set=self.get_context_features() if context_features else None, drug_feature_set=self.get_drug_features() if drug_features else None, labeled_triples=self.get_labeled_triples() if labeled_triples is None else labeled_triples, @@ -143,10 +144,8 @@ def get_context_features(self): """ path = self.generate_path("context_set.json") raw_data = self.load_raw_json_data(path) - raw_data = {k: np.array(v).reshape(1, -1) for k, v in raw_data.items()} - context_feature_set = ContextFeatureSet() - context_feature_set.update(raw_data) - return context_feature_set + raw_data = {k: torch.FloatTensor(np.array(v).reshape(1, -1)) for k, v in raw_data.items()} + return ContextFeatureSet(raw_data) @property def num_contexts(self) -> int: @@ -169,11 +168,10 @@ def get_drug_features(self): path = self.generate_path("drug_set.json") raw_data = self.load_raw_json_data(path) raw_data = { - k: {"smiles": v["smiles"], "features": np.array(v["features"]).reshape(1, -1)} for k, v in raw_data.items() + key: {"smiles": value["smiles"], "features": np.array(value["features"]).reshape(1, -1)} + for key, value in raw_data.items() } - drug_feature_set = DrugFeatureSet() - drug_feature_set.update(raw_data) - return drug_feature_set + return DrugFeatureSet.from_dict(raw_data) @property def num_drugs(self) -> int: @@ -194,10 +192,8 @@ def get_labeled_triples(self): labeled_triples (LabeledTriples): The labeled triples in the dataset. """ path = self.generate_path("labeled_triples.csv") - raw_data = self.load_raw_csv_data(path) - labeled_triples = LabeledTriples() - labeled_triples.update_from_pandas(raw_data) - return labeled_triples + df = self.load_raw_csv_data(path) + return LabeledTriples(df) @property def num_labeled_triples(self) -> int: diff --git a/chemicalx/data/drugfeatureset.py b/chemicalx/data/drugfeatureset.py index 30c1462..aec0d8a 100644 --- a/chemicalx/data/drugfeatureset.py +++ b/chemicalx/data/drugfeatureset.py @@ -1,6 +1,7 @@ """A module for the drug feature set.""" -from typing import Dict, Iterable, Union +from collections import UserDict +from typing import Dict, Iterable, Mapping, Union import torch from torchdrug.data import Graph, Molecule, PackedGraph @@ -10,150 +11,38 @@ ] -class DrugFeatureSet(dict): +class DrugFeatureSet(UserDict, Mapping[str, Mapping[str, Union[torch.FloatTensor, Molecule]]]): """Drug feature set for compounds.""" - def __setitem__(self, drug: str, features: Dict[str, Union[str, torch.FloatTensor]]): - """Set the features for a compound key. - - Args: - drug (str): Drug identifier. - features (dict): Dictionary of smiles string and molecular features. - """ - self.__dict__[drug] = {} - self.__dict__[drug]["features"] = torch.FloatTensor(features["features"]) - self.__dict__[drug]["molecule"] = Molecule.from_smiles(features["smiles"]) - - def __getitem__(self, drug: str): - """Get the features for a drug key. - - Args: - drug (str): Drug identifier. - Returns: - dict: The drug features corresponding to the key. - """ - return self.__dict__[drug] - - def __len__(self): - """Get the number of drugs. - - Returns: - int: The number of drugs. - """ - return len(self.__dict__) - - def __delitem__(self, drug: str): - """Delete the features for a drug key. - - Args: - drug (str): Drug identifier. - """ - del self.__dict__[drug] - - def clear(self): - """Delete all the drugs from the drug feature set. - - Returns: - DrugFeatureSet: An empty drug feature set. - """ - return self.__dict__.clear() - - def has_drug(self, drug: str): - """Check whether a drug feature set contains a drug. - - Args: - drug (str): Drug identifier. - Returns: - bool: Boolean describing whether the drug is in the drug set. - """ - return drug in self.__dict__ - - def update(self, data: Dict[str, Dict]): - """Add a dictionary of drug keys - feature dictionaries to a drug set. - - Args: - data (dict): A dictionary of drug keys with feature dictionaries. - Returns: - DrugFeatureSet: The updated drug feature set. - """ - return self.__dict__.update( + @classmethod + def from_dict(cls, data: Dict[str, Dict]) -> "DrugFeatureSet": + """Generate a drug feature set from a data dictionary.""" + return cls( { - drug: { + key: { "features": torch.FloatTensor(features["features"]), "molecule": Molecule.from_smiles(features["smiles"]), } - for drug, features in data.items() + for key, features in data.items() } ) - def keys(self): - """Get the drugs in a feature set. - - Returns: - list: An iterator of drug identifiers. - """ - return self.__dict__.keys() - - def values(self): - """Get the iterator of drug features. - - Returns: - list: Feature iterator. - """ - return self.__dict__.values() - - def items(self): - """Get the iterator of tuples containing drug identifier - feature pairs. - - Returns: - list: An iterator of (drug - feature dictionary) tuples. - """ - return self.__dict__.items() - - def __contains__(self, drug: str): - """Check if the drug is in the drug feature set. - - Args: - drug (str): A drug identifier. - Returns: - bool: An indicator whether the drug is in the drug feature set. - """ - return drug in self.__dict__ - - def __iter__(self): - """Iterate over the drug feature set. - - Returns: - iterable: An iterable of the drug feature set. - """ - return iter(self.__dict__) - - def get_drug_count(self) -> int: - """Get the number of drugs. - - Returns: - int: The number of drugs. - """ - return len(self.__dict__) - def get_feature_matrix(self, drugs: Iterable[str]) -> torch.FloatTensor: """Get the drug feature matrix for a list of drugs. Args: - drugs (list): A list of drug identifiers. + drugs: A list of drug identifiers. Return: - features (torch.FloatTensor): A matrix of drug features. + : A matrix of drug features. """ - features = torch.cat([self.__dict__[drug]["features"] for drug in drugs]) - return features + return torch.cat([self.data[drug]["features"] for drug in drugs]) def get_molecules(self, drugs: Iterable[str]) -> PackedGraph: """Get the molecular structures. Args: - drugs (list): A list of drug identifiers. + drugs: A list of drug identifiers. Return: - molecules (torch.PackedGraph): The molecules batched together for message passing. + : The molecules batched together for message passing. """ - molecules = Graph.pack([self.__dict__[drug]["molecule"] for drug in drugs]) - return molecules + return Graph.pack([self.data[drug]["molecule"] for drug in drugs]) diff --git a/chemicalx/data/labeledtriples.py b/chemicalx/data/labeledtriples.py index 6da7193..9ecacf8 100644 --- a/chemicalx/data/labeledtriples.py +++ b/chemicalx/data/labeledtriples.py @@ -1,6 +1,6 @@ """A module for the labeled triples class.""" -from typing import List, Tuple +from typing import ClassVar, Iterable, Mapping, Optional, Sequence, Tuple, Union import pandas as pd from sklearn.model_selection import train_test_split @@ -11,11 +11,14 @@ class LabeledTriples: """Labeled triples for drug pair scoring.""" - def __init__(self): + columns: ClassVar[Sequence[str]] = ("drug_1", "drug_2", "context", "label") + dtype: ClassVar[Mapping[str, type]] = {"drug_1": str, "drug_2": str, "context": str, "label": float} + + def __init__(self, data: Union[pd.DataFrame, Iterable[Sequence]]): """Initialize the labeled triples object.""" - self.columns = ["drug_1", "drug_2", "context", "label"] - self.types = {"drug_1": str, "drug_2": str, "context": str, "label": float} - self.data = pd.DataFrame(columns=self.columns).astype(self.types) + if not isinstance(data, pd.DataFrame): + data = pd.DataFrame(data, columns=self.columns).astype(self.dtype) + self.data = data def __len__(self) -> int: """Get the number of triples.""" @@ -25,37 +28,16 @@ def drop_duplicates(self): """Drop the duplicated entries.""" self.data = self.data.drop_duplicates() - def update_from_pandas(self, data: pd.DataFrame): - """ - Update the labeled triples from a dataframe. - - Args: - data (pd.DataFrame): A dataframe of labeled triples. - """ - self.data = pd.concat([self.data, data]) - - def update_from_list(self, data: List[List]): - """ - Update the labeled triples from a list. - - Args: - data (list): A list of labeled triples. - """ - data = pd.DataFrame(data, columns=self.columns) - self.data = pd.concat([self.data, data]) - - def __add__(self, value): + def __add__(self, value: "LabeledTriples") -> "LabeledTriples": """ Add the triples in two LabeledTriples objects together - syntactic sugar for '+'. Args: - value (LabeledTriples): Another LabeledTriples object for the addition. + value: Another LabeledTriples object for the addition. Returns: - new_triples (LabeledTriples): A LabeledTriples object after the addition. + : A LabeledTriples object after the addition. """ - new_triples = LabeledTriples() - new_triples.update_from_pandas(pd.concat([self.data, value.data])) - return new_triples + return LabeledTriples(pd.concat([self.data, value.data])) def get_drug_count(self) -> int: """ @@ -131,20 +113,18 @@ def get_negative_rate(self) -> float: """ return 1.0 - self.data["label"].mean() - def train_test_split(self, train_size: float = 0.8, random_state: int = 42) -> Tuple: + def train_test_split( + self, train_size: Optional[float] = None, random_state: Optional[int] = 42 + ) -> Tuple["LabeledTriples", "LabeledTriples"]: """ Split the LabeledTriples object for training and testing. Args: - train_size (float): The ratio of training triples. Default is 0.8. - random_state (int): The random seed. Default is 42. + train_size: The ratio of training triples. Default is 0.8 if None is passed. + random_state: The random seed. Default is 42. Set to none for no fixed seed. Returns train_labeled_triples (LabeledTriples): The training triples. test_labeled_triples (LabeledTriples): The testing triples. """ - train_data, test_data = train_test_split(self.data, train_size=train_size, random_state=random_state) - train_labeled_triples = LabeledTriples() - test_labeled_triples = LabeledTriples() - train_labeled_triples.update_from_pandas(train_data) - test_labeled_triples.update_from_pandas(test_data) - return train_labeled_triples, test_labeled_triples + train_data, test_data = train_test_split(self.data, train_size=train_size or 0.8, random_state=random_state) + return LabeledTriples(train_data), LabeledTriples(test_data) diff --git a/chemicalx/pipeline.py b/chemicalx/pipeline.py index 37e4a16..25f8317 100644 --- a/chemicalx/pipeline.py +++ b/chemicalx/pipeline.py @@ -77,7 +77,8 @@ def pipeline( context_features: bool, drug_features: bool, drug_molecules: bool, - labels: bool, + train_size: Optional[float] = None, + random_state: Optional[int] = None, ) -> Result: """Run the training and evaluation pipeline. @@ -113,8 +114,10 @@ def pipeline( Indicator whether the batch should include drug features. :param drug_molecules: Indicator whether the batch should include drug molecules - :param labels: - Indicator whether the batch should include drug pair labels. + :param train_size: + The ratio of training triples. Default is 0.8 if None is passed. + :param random_state: + The random seed for splitting the triples. Default is 42. Set to none for no fixed seed. :returns: A result object with the trained model and evaluation results """ @@ -124,7 +127,8 @@ def pipeline( context_features=context_features, drug_features=drug_features, drug_molecules=drug_molecules, - labels=labels, + train_size=train_size, + random_state=random_state, ) model = model_resolver.make(model, model_kwargs) diff --git a/examples/deepsynergy_examples.py b/examples/deepsynergy_examples.py index 4ca1432..09ced70 100644 --- a/examples/deepsynergy_examples.py +++ b/examples/deepsynergy_examples.py @@ -17,7 +17,6 @@ def main(): context_features=True, drug_features=True, drug_molecules=False, - labels=True, ) results.summarize() diff --git a/examples/epgcnds_examples.py b/examples/epgcnds_examples.py index 7e02376..f5ff40c 100644 --- a/examples/epgcnds_examples.py +++ b/examples/epgcnds_examples.py @@ -18,7 +18,6 @@ def main(): context_features=True, drug_features=True, drug_molecules=True, - labels=True, ) results.summarize() diff --git a/tests/unit/test_batching.py b/tests/unit/test_batching.py index 3ed1560..3818c36 100644 --- a/tests/unit/test_batching.py +++ b/tests/unit/test_batching.py @@ -15,9 +15,6 @@ class TestGeneratorDrugCombDB(unittest.TestCase): def setUpClass(cls) -> None: """Set up the class with a dataset loader.""" cls.loader = DatasetLoader("drugcombdb") - cls.drug_feature_set = cls.loader.get_drug_features() - cls.context_feature_set = cls.loader.get_context_features() - cls.labeled_triples = cls.loader.get_labeled_triples() def test_all_true(self): """Test sizes of drug features during batch generation.""" @@ -26,8 +23,7 @@ def test_all_true(self): context_features=True, drug_features=True, drug_molecules=True, - labels=True, - labeled_triples=self.labeled_triples, + labeled_triples=self.loader.get_labeled_triples(), ) for batch in generator: assert batch.drug_features_left.shape[1] == 256 @@ -40,11 +36,10 @@ def test_set_all_false(self): context_features=False, drug_features=False, drug_molecules=False, - labels=False, - labeled_triples=self.labeled_triples, + labeled_triples=self.loader.get_labeled_triples(), ) for batch in generator: assert batch.drug_features_left is None assert batch.drug_molecules_left is None - assert batch.labels is None + assert batch.labels is not None assert batch.context_features is None diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py index 868e53b..83cde28 100644 --- a/tests/unit/test_dataset.py +++ b/tests/unit/test_dataset.py @@ -1,6 +1,7 @@ """Tests for datasets.""" import unittest +from typing import ClassVar from chemicalx.data import DatasetLoader @@ -8,94 +9,102 @@ class TestDrugComb(unittest.TestCase): """A test case for DrugComb.""" + loader: ClassVar[DatasetLoader] + @classmethod def setUpClass(cls) -> None: """Set up the test case.""" - cls.dataset_loader = DatasetLoader("drugcomb") + cls.loader = DatasetLoader("drugcomb") def test_get_context_features(self): """Test the number of context features.""" - context_feature_set = self.dataset_loader.get_context_features() + context_feature_set = self.loader.get_context_features() assert len(context_feature_set) == 288 def test_get_drug_features(self): """Test the number of drug features.""" - drug_feature_set = self.dataset_loader.get_drug_features() + drug_feature_set = self.loader.get_drug_features() assert len(drug_feature_set) == 4146 def test_get_labeled_triples(self): """Test the shape of the labeled triples.""" - labeled_triples = self.dataset_loader.get_labeled_triples() + labeled_triples = self.loader.get_labeled_triples() assert labeled_triples.data.shape == (659333, 4) class TestDrugCombDB(unittest.TestCase): """A test case for DrugCombDB.""" + loader: ClassVar[DatasetLoader] + @classmethod def setUpClass(cls) -> None: """Set up the test case.""" - cls.dataset_loader = DatasetLoader("drugcombdb") + cls.loader = DatasetLoader("drugcombdb") def test_get_context_features(self): """Test the number of context features.""" - context_feature_set = self.dataset_loader.get_context_features() + context_feature_set = self.loader.get_context_features() assert len(context_feature_set) == 112 def test_get_drug_features(self): """Test the number of drug features.""" - drug_feature_set = self.dataset_loader.get_drug_features() + drug_feature_set = self.loader.get_drug_features() assert len(drug_feature_set) == 2956 def test_get_labeled_triples(self): """Test the shape of the labeled triples.""" - labeled_triples = self.dataset_loader.get_labeled_triples() + labeled_triples = self.loader.get_labeled_triples() assert labeled_triples.data.shape == (191391, 4) class TestDeepDDI(unittest.TestCase): """A test case for DeepDDI.""" + loader: ClassVar[DatasetLoader] + @classmethod def setUpClass(cls) -> None: """Set up the test case.""" - cls.dataset_loader = DatasetLoader("drugbankddi") + cls.loader = DatasetLoader("drugbankddi") def test_get_context_features(self): """Test the number of context features.""" - context_feature_set = self.dataset_loader.get_context_features() + context_feature_set = self.loader.get_context_features() assert len(context_feature_set) == 86 def test_get_drug_features(self): """Test the number of drug features.""" - drug_feature_set = self.dataset_loader.get_drug_features() + drug_feature_set = self.loader.get_drug_features() assert len(drug_feature_set) == 1706 def test_get_labeled_triples(self): """Test the shape of the labeled triples.""" - labeled_triples = self.dataset_loader.get_labeled_triples() + labeled_triples = self.loader.get_labeled_triples() assert labeled_triples.data.shape == (575307, 4) class TestTwoSides(unittest.TestCase): """A test case for TwoSides.""" + loader: ClassVar[DatasetLoader] + @classmethod def setUpClass(cls) -> None: """Set up the test case.""" - cls.dataset_loader = DatasetLoader("twosides") + cls.loader = DatasetLoader("twosides") def test_get_context_features(self): """Test the number of context features.""" - context_feature_set = self.dataset_loader.get_context_features() + context_feature_set = self.loader.get_context_features() assert len(context_feature_set) == 10 def test_get_drug_features(self): """Test the number of drug features.""" - drug_feature_set = self.dataset_loader.get_drug_features() + drug_feature_set = self.loader.get_drug_features() assert len(drug_feature_set) == 644 def test_get_labeled_triples(self): """Test the shape of the labeled triples.""" - labeled_triples = self.dataset_loader.get_labeled_triples() + labeled_triples = self.loader.get_labeled_triples() assert labeled_triples.data.shape == (499582, 4) diff --git a/tests/unit/test_datastructures.py b/tests/unit/test_datastructures.py index b5b3443..062ef9f 100644 --- a/tests/unit/test_datastructures.py +++ b/tests/unit/test_datastructures.py @@ -4,6 +4,8 @@ import numpy as np import pandas as pd +import torch +from torchdrug.data import Molecule from chemicalx.data import ContextFeatureSet, DrugFeatureSet, LabeledTriples @@ -21,7 +23,6 @@ def test_get(self): """Test getting data.""" assert self.context_feature_set["context_2"].shape == (1, 3) assert "context_2" in self.context_feature_set - assert self.context_feature_set.has_context("context_2") def test_delete(self): """Test deleting data.""" @@ -35,17 +36,14 @@ def test_len(self): def test_contexts_features(self): """Get the number of elements.""" + assert len(self.context_feature_set) == 2 assert len(list(self.context_feature_set.keys())) == 2 assert len(list(self.context_feature_set.values())) == 2 assert len(list(self.context_feature_set.items())) == 2 - def test_basic_statistics(self): - """Test the number of contexts.""" - assert self.context_feature_set.get_context_count() == 2 - def test_update_and_delete(self): """Test updating and deleting entries.""" - self.context_feature_set.update({"context_3": np.array([[1.1, 2.2, 3.4]])}) + self.context_feature_set["context_3"] = torch.FloatTensor(np.array([[1.1, 2.2, 3.4]])) assert len(self.context_feature_set) == 3 del self.context_feature_set["context_3"] assert len(self.context_feature_set) == 2 @@ -67,15 +65,17 @@ class TestDrugFeatureSet(unittest.TestCase): def setUp(self): """Set up the test case.""" - self.drug_feature_set = DrugFeatureSet() - self.drug_feature_set["drug_1"] = {"smiles": "CN=C=O", "features": np.array([[0.0, 1.7, 2.3]])} - self.drug_feature_set["drug_2"] = {"smiles": "[Cu+2].[O-]S(=O)(=O)[O-]", "features": np.array([[1, 0, 8]])} + self.drug_feature_set = DrugFeatureSet.from_dict( + { + "drug_1": {"smiles": "CN=C=O", "features": np.array([[0.0, 1.7, 2.3]])}, + "drug_2": {"smiles": "[Cu+2].[O-]S(=O)(=O)[O-]", "features": np.array([[1, 0, 8]])}, + } + ) def test_get(self): """Test getting data.""" assert self.drug_feature_set["drug_1"]["features"].shape == (1, 3) assert "drug_2" in self.drug_feature_set - assert self.drug_feature_set.has_drug("drug_2") def test_delete(self): """Test deleting data.""" @@ -89,19 +89,17 @@ def test_len(self): def test_drug_features(self): """Get the number of elements.""" + assert len(self.drug_feature_set) == 2 assert len(list(self.drug_feature_set.keys())) == 2 assert len(list(self.drug_feature_set.values())) == 2 assert len(list(self.drug_feature_set.items())) == 2 - def test_basic_statistics(self): - """Test the number of drugs.""" - assert self.drug_feature_set.get_drug_count() == 2 - def test_update_and_delete(self): """Test updating and deleting entries.""" - self.drug_feature_set.update( - {"drug_3": {"smiles": " CN1C=NC2=C1C(=O)N(C(=O)N2C)C", "features": np.array([[1.1, 2.2, 3.4]])}} - ) + self.drug_feature_set["drug_3"] = { + "molecule": Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C"), + "features": torch.FloatTensor(np.array([[1.1, 2.2, 3.4]])), + } assert len(self.drug_feature_set) == 3 del self.drug_feature_set["drug_3"] assert len(self.drug_feature_set) == 2 @@ -128,17 +126,14 @@ class TestLabeledTriples(unittest.TestCase): def setUp(self): """Set up the test case.""" - self.labeled_triples = LabeledTriples() - self.other_labeled_triples = LabeledTriples() - data = pd.DataFrame( [["drug_a", "drug_b", "context_a", 1.0], ["drug_b", "drug_c", "context_b", 0.0]], columns=["drug_1", "drug_2", "context", "label"], ) - self.labeled_triples.update_from_pandas(data) + self.labeled_triples = LabeledTriples(data) data = [["drug_a", "drug_b", "context_a", 1.0], ["drug_a", "drug_c", "context_b", 0.0]] - self.other_labeled_triples.update_from_list(data) + self.other_labeled_triples = LabeledTriples(data) def test_from_pandas(self): """Test loading from pandas.""" diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 2bc796d..f39fcf2 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -2,6 +2,7 @@ import inspect import unittest +from typing import ClassVar import torch from class_resolver import Resolver @@ -30,6 +31,8 @@ class TestPipeline(unittest.TestCase): """Test the unified training and evaluation pipeline.""" + loader: ClassVar[DatasetLoader] + @classmethod def setUpClass(cls) -> None: """Set up the test case with a dataset.""" @@ -46,7 +49,6 @@ def test_train_context(self): context_features=True, drug_features=True, drug_molecules=False, - labels=True, ) self.assertIsInstance(results.roc_auc, float) @@ -62,7 +64,6 @@ def test_train_contextless(self): context_features=True, drug_features=True, drug_molecules=True, - labels=True, ) self.assertIsInstance(results.roc_auc, float) @@ -105,6 +106,8 @@ def test_defaults(self): class TestModels(unittest.TestCase): """A test case for models.""" + loader: ClassVar[DatasetLoader] + @classmethod def setUpClass(cls) -> None: """Set up the test case with a dataset.""" @@ -117,7 +120,6 @@ def setUp(self): context_features=True, drug_features=True, drug_molecules=True, - labels=True, train_size=0.005, )