-
Notifications
You must be signed in to change notification settings - Fork 100
Implement MatchMaker model #67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
97206fe
Implement MatchMaker model
andrejlamov cfc5dff
Fix typo
andrejlamov 7fcbf66
Fix typo again
andrejlamov 2047464
Split forward into helper methods to remove duplication
andrejlamov a6a69bd
Lint them all
andrejlamov 9d30302
More lint
andrejlamov 9a81380
Remove comments from example
andrejlamov 5133f74
Add ref to matchmaker
andrejlamov File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,104 @@ | ||
| """An implementation of the MatchMaker model.""" | ||
|
|
||
| from .base import UnimplementedModel | ||
| import torch | ||
| import torch.nn.functional as F # noqa:N812 | ||
|
|
||
| from chemicalx.data import DrugPairBatch | ||
| from chemicalx.models import Model | ||
|
|
||
| __all__ = [ | ||
| "MatchMaker", | ||
| ] | ||
|
|
||
|
|
||
| class MatchMaker(UnimplementedModel): | ||
| """An implementation of the MatchMaker model. | ||
| class MatchMaker(Model): | ||
| """An implementation of the MatchMaker model from [matchmaker]_. | ||
|
|
||
| .. seealso:: https://github.com/AstraZeneca/chemicalx/issues/23 | ||
| .. [matchmaker] `MatchMaker: A Deep Learning Framework for Drug Synergy Prediction | ||
| <https://www.biorxiv.org/content/10.1101/2020.05.24.113241v3.full>`_ | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| context_channels: int, | ||
| drug_channels: int, | ||
| input_hidden_channels: int = 32, | ||
| middle_hidden_channels: int = 32, | ||
| final_hidden_channels: int = 32, | ||
| out_channels: int = 1, | ||
| dropout_rate: float = 0.5, | ||
| ): | ||
| """Instantiate the MatchMaker model. | ||
|
|
||
| :param context_channels: The number of context features. | ||
| :param drug_channels: The number of drug features. | ||
| :param input_hidden_channels: The number of hidden layer neurons in the input layer. | ||
| :param middle_hidden_channels: The number of hidden layer neurons in the middle layer. | ||
| :param final_hidden_channels: The number of hidden layer neurons in the final layer. | ||
| :param out_channels: The number of output channels. | ||
| :param dropout_rate: The rate of dropout before the scoring head is used. | ||
| """ | ||
| super().__init__() | ||
| self.encoder = torch.nn.Linear(drug_channels + context_channels, input_hidden_channels) | ||
| self.hidden_first = torch.nn.Linear(input_hidden_channels, middle_hidden_channels) | ||
| self.hidden_second = torch.nn.Linear(middle_hidden_channels, middle_hidden_channels) | ||
| self.dropout = torch.nn.Dropout(dropout_rate) | ||
| self.scoring_head_first = torch.nn.Linear(2 * middle_hidden_channels, final_hidden_channels) | ||
| self.scoring_head_second = torch.nn.Linear(final_hidden_channels, out_channels) | ||
|
|
||
| def unpack(self, batch: DrugPairBatch): | ||
| """Return the context features, left drug features, and right drug features.""" | ||
| return ( | ||
| batch.context_features, | ||
| batch.drug_features_left, | ||
| batch.drug_features_right, | ||
| ) | ||
|
|
||
| def _forward_hidden(self, tensor: torch.FloatTensor) -> torch.FloatTensor: | ||
| hidden = self.encoder(tensor) | ||
| hidden = F.relu(hidden) | ||
| hidden = self.dropout(hidden) | ||
| hidden = self.hidden_first(hidden) | ||
| hidden = F.relu(hidden) | ||
| hidden = self.dropout(hidden) | ||
| hidden = self.hidden_second(hidden) | ||
| return hidden | ||
|
|
||
| def _forward_hidden_merged(self, tensor: torch.FloatTensor) -> torch.FloatTensor: | ||
| hidden = self.scoring_head_first(tensor) | ||
| hidden = F.relu(hidden) | ||
| hidden = self.dropout(hidden) | ||
| hidden = self.scoring_head_second(hidden) | ||
| hidden = torch.sigmoid(hidden) | ||
| return hidden | ||
|
|
||
| def forward( | ||
| self, | ||
| context_features: torch.FloatTensor, | ||
| drug_features_left: torch.FloatTensor, | ||
| drug_features_right: torch.FloatTensor, | ||
| ) -> torch.FloatTensor: | ||
| """ | ||
| Run a forward pass of the MatchMaker model model. | ||
|
|
||
| Args: | ||
| context_features (torch.FloatTensor): A matrix of biological context features. | ||
| drug_features_left (torch.FloatTensor): A matrix of head drug features. | ||
| drug_features_right (torch.FloatTensor): A matrix of tail drug features. | ||
| Returns: | ||
| hidden (torch.FloatTensor): A column vector of predicted synergy scores. | ||
| """ | ||
| # The left drug | ||
| hidden_left = torch.cat([context_features, drug_features_left], dim=1) | ||
| hidden_left = self._forward_hidden(hidden_left) | ||
|
|
||
| # The right drug | ||
| hidden_right = torch.cat([context_features, drug_features_right], dim=1) | ||
| hidden_right = self._forward_hidden(hidden_right) | ||
|
|
||
| # Merged | ||
| hidden_merged = torch.cat([hidden_left, hidden_right], dim=1) | ||
| hidden_merged = self._forward_hidden_merged(hidden_merged) | ||
|
|
||
| return hidden_merged | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| """Example with MatchMaker.""" | ||
|
|
||
| from chemicalx import pipeline | ||
| from chemicalx.data import DrugCombDB | ||
| from chemicalx.models import MatchMaker | ||
|
|
||
|
|
||
| def main(): | ||
| """Train and evaluate the MatchMaker model.""" | ||
| dataset = DrugCombDB() | ||
| model = MatchMaker(context_channels=dataset.context_channels, drug_channels=dataset.drug_channels) | ||
|
|
||
| results = pipeline( | ||
| dataset=dataset, | ||
| model=model, | ||
| batch_size=5120, | ||
| epochs=100, | ||
| context_features=True, | ||
| drug_features=True, | ||
| drug_molecules=False, | ||
| metrics=["roc_auc"], | ||
| ) | ||
| results.summarize() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since the code is exactly duplicated between the left and right drug (besides the input), consider splitting this into a helper function