-
Notifications
You must be signed in to change notification settings - Fork 100
Implement MatchMaker model #67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
97206fe
cfc5dff
7fcbf66
2047464
a6a69bd
9d30302
9a81380
5133f74
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,108 @@ | ||
| """An implementation of the MatchMaker model.""" | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F # noqa:N812 | ||
|
|
||
| from chemicalx.data import DrugPairBatch | ||
| from chemicalx.models import Model | ||
|
|
||
| from .base import UnimplementedModel | ||
|
|
||
| __all__ = [ | ||
| "MatchMaker", | ||
| ] | ||
|
|
||
|
|
||
| class MatchMaker(UnimplementedModel): | ||
| class MatchMaker(Model): | ||
| """An implementation of the MatchMaker model. | ||
|
|
||
| .. seealso:: https://github.com/AstraZeneca/chemicalx/issues/23 | ||
| .. [matchmaker] `MatchMaker: A Deep Learning Framework for Drug Synergy Prediction | ||
| <https://www.biorxiv.org/content/10.1101/2020.05.24.113241v3.full>`_ | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| context_channels: int, | ||
| drug_channels: int, | ||
| input_hidden_channels: int = 32, | ||
| middle_hidden_channels: int = 32, | ||
| final_hidden_channels: int = 32, | ||
| out_channels: int = 1, | ||
| dropout_rate: float = 0.5, | ||
| ): | ||
| """Instantiate the MatchMaker model. | ||
|
|
||
| :param context_channels: The number of context features. | ||
| :param drug_channels: The number of drug features. | ||
| :param input_hidden_channels: The number of hidden layer neurons in the input layer. | ||
| :param middle_hidden_channels: The number of hidden layer neurons in the middle layer. | ||
| :param final_hidden_channels: The number of hidden layer neurons in the final layer. | ||
| :param out_channels: The number of output channels. | ||
| :param dropout_rate: The rate of dropout before the scoring head is used. | ||
| """ | ||
|
|
||
|
||
| super().__init__() | ||
| self.encoder = torch.nn.Linear(drug_channels + context_channels, input_hidden_channels) | ||
| self.hidden_first = torch.nn.Linear(input_hidden_channels, middle_hidden_channels) | ||
| self.hidden_second = torch.nn.Linear(middle_hidden_channels, middle_hidden_channels) | ||
| self.dropout = torch.nn.Dropout(dropout_rate) | ||
| self.scoring_head_first = torch.nn.Linear(2 * middle_hidden_channels, final_hidden_channels) | ||
| self.scoring_head_second = torch.nn.Linear(final_hidden_channels, out_channels) | ||
|
|
||
| def unpack(self, batch: DrugPairBatch): | ||
| """Return the context features, left drug features, and right drug features.""" | ||
| return ( | ||
| batch.context_features, | ||
| batch.drug_features_left, | ||
| batch.drug_features_right, | ||
| ) | ||
|
|
||
| def _forward_hidden(self, tensor: torch.FloatTensor) -> torch.FloatTensor: | ||
| hidden = self.encoder(tensor) | ||
| hidden = F.relu(hidden) | ||
| hidden = self.dropout(hidden) | ||
| hidden = self.hidden_first(hidden) | ||
| hidden = F.relu(hidden) | ||
| hidden = self.dropout(hidden) | ||
| hidden = self.hidden_second(hidden) | ||
| return hidden | ||
|
|
||
| def _forward_hidden_merged(self, tensor: torch.FloatTensor) -> torch.FloatTensor: | ||
| hidden = self.scoring_head_first(tensor) | ||
| hidden = F.relu(hidden) | ||
| hidden = self.dropout(hidden) | ||
| hidden = self.scoring_head_second(hidden) | ||
| hidden = torch.sigmoid(hidden) | ||
| return hidden | ||
|
|
||
| def forward( | ||
| self, | ||
| context_features: torch.FloatTensor, | ||
| drug_features_left: torch.FloatTensor, | ||
| drug_features_right: torch.FloatTensor, | ||
| ) -> torch.FloatTensor: | ||
| """ | ||
| Run a forward pass of the MatchMaker model model. | ||
|
|
||
| Args: | ||
| context_features (torch.FloatTensor): A matrix of biological context features. | ||
| drug_features_left (torch.FloatTensor): A matrix of head drug features. | ||
| drug_features_right (torch.FloatTensor): A matrix of tail drug features. | ||
| Returns: | ||
| hidden (torch.FloatTensor): A column vector of predicted synergy scores. | ||
| """ | ||
|
|
||
| # The left drug | ||
| hidden_left = torch.cat([context_features, drug_features_left], dim=1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since the code is exactly duplicated between the left and right drug (besides the input), consider splitting this into a helper function |
||
| hidden_left = self._forward_hidden(hidden_left) | ||
|
|
||
| # The right drug | ||
| hidden_right = torch.cat([context_features, drug_features_right], dim=1) | ||
| hidden_right = self._forward_hidden(hidden_right) | ||
|
|
||
| # Merged | ||
| hidden_merged = torch.cat([hidden_left, hidden_right], dim=1) | ||
| hidden_merged = self._forward_hidden_merged(hidden_merged) | ||
|
|
||
| return hidden_merged | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| """Example with MatchMaker.""" | ||
|
|
||
| from chemicalx import pipeline | ||
| from chemicalx.data import DrugCombDB | ||
| from chemicalx.models import MatchMaker | ||
|
|
||
|
|
||
| def main(): | ||
| """Train and evaluate the MatchMaker model.""" | ||
| dataset = DrugCombDB() | ||
| model = MatchMaker(context_channels=dataset.context_channels, drug_channels=dataset.drug_channels) | ||
|
|
||
| results = pipeline( | ||
| dataset=dataset, | ||
| model=model, | ||
| batch_size=5120, | ||
| epochs=100, | ||
| context_features=True, | ||
| drug_features=True, | ||
| drug_molecules=False, | ||
| metrics=["roc_auc"], | ||
| ) | ||
| results.summarize() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
|
|
||
| # ctx_features = torch.FloatTensor(np.random.uniform(0, 1, (1000, ctx_chs))) | ||
| # drug_features_left = torch.FloatTensor(np.random.uniform(0, 1, (1000, drug_chs))) | ||
| # drug_features_right = torch.FloatTensor(np.random.uniform(0, 1, (1000, drug_chs))) | ||
|
|
||
| # model.forward(ctx_features, drug_features_left, drug_features_right) |
Uh oh!
There was an error while loading. Please reload this page.