Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ jobs:
run: tox -e pyroma
- name: Check code quality
run: tox -e flake8
- name: Check typing
run: tox -e mypy
docs:
name: Documentation
runs-on: ubuntu-latest
Expand Down
16 changes: 6 additions & 10 deletions chemicalx/data/batchgenerator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A module for the batch generator class."""

import math
from typing import Iterable, Iterator, List, Optional
from typing import Iterable, Iterator, Optional, Sequence

import numpy as np
import pandas as pd
Expand All @@ -27,10 +27,9 @@ def __init__(
context_features: bool,
drug_features: bool,
drug_molecules: bool,
labels: bool,
context_feature_set: Optional[ContextFeatureSet],
drug_feature_set: Optional[DrugFeatureSet],
labeled_triples: Optional[LabeledTriples],
labeled_triples: LabeledTriples,
):
"""Initialize a batch generator.

Expand All @@ -39,7 +38,6 @@ def __init__(
context_features: Indicator whether the batch should include biological context features.
drug_features: Indicator whether the batch should include drug features.
drug_molecules: Indicator whether the batch should include drug molecules.
labels: Indicator whether the batch should include drug pair labels.
context_feature_set: A context feature set for feature generation.
drug_feature_set: A drug feature set for feature generation.
labeled_triples: A labeled triples object used to generate batches.
Expand All @@ -48,7 +46,6 @@ def __init__(
self.context_features = context_features
self.drug_features = drug_features
self.drug_molecules = drug_molecules
self.labels = labels
self.context_feature_set = context_feature_set
self.drug_feature_set = drug_feature_set
self.labeled_triples = labeled_triples
Expand Down Expand Up @@ -89,16 +86,15 @@ def _get_drug_molecules(self, drug_identifiers: Iterable[str]) -> Optional[Packe
return None
return self.drug_feature_set.get_molecules(drug_identifiers)

def _transform_labels(self, labels: List):
@classmethod
def _transform_labels(cls, labels: Sequence[float]) -> torch.FloatTensor:
"""Transform the labels from a chunk of the labeled triples frame.

Args:
labels (pd.Series): The drug pair binary labels.
labels: The drug pair binary labels.
Returns:
labels (torch.FloatTensor): The label target vector as a column vector.
labels : The label target vector as a column vector.
"""
if not self.labels:
return None
return torch.FloatTensor(np.array(labels).reshape(-1, 1))

def generate_batch(self, batch_frame: pd.DataFrame) -> DrugPairBatch:
Expand Down
122 changes: 4 additions & 118 deletions chemicalx/data/contextfeatureset.py
Original file line number Diff line number Diff line change
@@ -1,131 +1,18 @@
"""A module for the context feature set class."""

from typing import Dict, Iterable
from collections import UserDict
from typing import Iterable, Mapping

import numpy as np
import torch

__all__ = [
"ContextFeatureSet",
]


class ContextFeatureSet(dict):
class ContextFeatureSet(UserDict, Mapping[str, torch.FloatTensor]):
"""Context feature set for biological/chemical context feature vectors."""

def __setitem__(self, context: str, features: np.ndarray) -> None:
"""Set the feature vector for a biological context key.

Args:
context: Biological or chemical context identifier.
features: Feature vector for the context.
"""
self.__dict__[context] = torch.FloatTensor(features)

def __getitem__(self, context: str) -> torch.FloatTensor:
"""Get the feature vector for a biological context key.

Args:
context: Biological or chemical context identifier.
Returns:
: The feature vector corresponding to the key.
"""
return self.__dict__[context]

def __len__(self) -> int:
"""Get the number of biological/chemical contexts.

Returns:
: The number of contexts.
"""
return len(self.__dict__)

def __delitem__(self, context: str) -> None:
"""Delete the feature vector for a biological context key.

Args:
context: Biological or chemical context identifier.
"""
del self.__dict__[context]

def clear(self) -> "ContextFeatureSet":
"""Delete all the contexts from the context feature set.

Returns:
: An empty context feature set.
"""
return self.__dict__.clear()

def has_context(self, context: str) -> bool:
"""Check whether a context feature set contains a context.

Args:
context: Biological or chemical context identifier.
Returns:
: Boolean describing whether the context is in the context set.
"""
return context in self.__dict__

def update(self, data: Dict[str, np.ndarray]):
"""Update a dictionary of context keys - feature vector values to a context set.

Args:
data (dict): A dictionary of context keys with feature vector values.
Returns:
ContextFeatureSet: The updated context feature set.
"""
return self.__dict__.update({context: torch.FloatTensor(features) for context, features in data.items()})

def keys(self):
"""Get the list of biological / chemical contexts in a feature set.

Returns:
list: An iterator of context identifiers.
"""
return self.__dict__.keys()

def values(self):
"""Get the iterator of context feature vectors.

Returns:
list: Feature vector iterator.
"""
return self.__dict__.values()

def items(self):
"""Get the iterator of tuples containing context identifier - feature vector pairs.

Returns:
list: An iterator of (context - feature vector) tuples.
"""
return self.__dict__.items()

def __contains__(self, context: str) -> bool:
"""Check if the context is in keys.

Args:
context (str): A context identifier.
Returns:
bool: An indicator whether the context is in the context feature set.
"""
return context in self.__dict__

def __iter__(self):
"""Iterate over the context feature set.

Returns:
iterable: An iterable of the context feature set.
"""
return iter(self.__dict__)

def get_context_count(self) -> int:
"""Get the number of biological contexts.

Returns:
: The number of contexts.
"""
return len(self.__dict__)

def get_feature_matrix(self, contexts: Iterable[str]) -> torch.FloatTensor:
"""Get the feature matrix for a list of contexts.

Expand All @@ -134,5 +21,4 @@ def get_feature_matrix(self, contexts: Iterable[str]) -> torch.FloatTensor:
Returns:
features: A matrix of context features.
"""
features = torch.cat([self.__dict__[context] for context in contexts])
return features
return torch.cat([self.data[context] for context in contexts])
32 changes: 14 additions & 18 deletions chemicalx/data/datasetloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
import torch

from .batchgenerator import BatchGenerator
from .contextfeatureset import ContextFeatureSet
Expand Down Expand Up @@ -44,8 +45,8 @@ def get_generators(
context_features: bool,
drug_features: bool,
drug_molecules: bool,
labels: bool,
**kwargs,
train_size: Optional[float] = None,
random_state: Optional[int] = None,
) -> Tuple[BatchGenerator, BatchGenerator]:
"""Generate a pre-stratified pair of batch generators."""
return cast(
Expand All @@ -56,10 +57,12 @@ def get_generators(
context_features=context_features,
drug_features=drug_features,
drug_molecules=drug_molecules,
labels=labels,
labeled_triples=labeled_triples,
)
for labeled_triples in self.get_labeled_triples().train_test_split(**kwargs)
for labeled_triples in self.get_labeled_triples().train_test_split(
train_size=train_size,
random_state=random_state,
)
),
)

Expand All @@ -69,7 +72,6 @@ def get_generator(
context_features: bool,
drug_features: bool,
drug_molecules: bool,
labels: bool,
labeled_triples: Optional[LabeledTriples] = None,
) -> BatchGenerator:
"""Initialize a batch generator.
Expand All @@ -88,7 +90,6 @@ def get_generator(
context_features=context_features,
drug_features=drug_features,
drug_molecules=drug_molecules,
labels=labels,
context_feature_set=self.get_context_features() if context_features else None,
drug_feature_set=self.get_drug_features() if drug_features else None,
labeled_triples=self.get_labeled_triples() if labeled_triples is None else labeled_triples,
Expand Down Expand Up @@ -143,10 +144,8 @@ def get_context_features(self):
"""
path = self.generate_path("context_set.json")
raw_data = self.load_raw_json_data(path)
raw_data = {k: np.array(v).reshape(1, -1) for k, v in raw_data.items()}
context_feature_set = ContextFeatureSet()
context_feature_set.update(raw_data)
return context_feature_set
raw_data = {k: torch.FloatTensor(np.array(v).reshape(1, -1)) for k, v in raw_data.items()}
return ContextFeatureSet(raw_data)

@property
def num_contexts(self) -> int:
Expand All @@ -169,11 +168,10 @@ def get_drug_features(self):
path = self.generate_path("drug_set.json")
raw_data = self.load_raw_json_data(path)
raw_data = {
k: {"smiles": v["smiles"], "features": np.array(v["features"]).reshape(1, -1)} for k, v in raw_data.items()
key: {"smiles": value["smiles"], "features": np.array(value["features"]).reshape(1, -1)}
for key, value in raw_data.items()
}
drug_feature_set = DrugFeatureSet()
drug_feature_set.update(raw_data)
return drug_feature_set
return DrugFeatureSet.from_dict(raw_data)

@property
def num_drugs(self) -> int:
Expand All @@ -194,10 +192,8 @@ def get_labeled_triples(self):
labeled_triples (LabeledTriples): The labeled triples in the dataset.
"""
path = self.generate_path("labeled_triples.csv")
raw_data = self.load_raw_csv_data(path)
labeled_triples = LabeledTriples()
labeled_triples.update_from_pandas(raw_data)
return labeled_triples
df = self.load_raw_csv_data(path)
return LabeledTriples(df)

@property
def num_labeled_triples(self) -> int:
Expand Down
Loading