Skip to content

Commit 0e84a1c

Browse files
Revert "Implement MatchMaker model (#67)"
This reverts commit e025146.
1 parent e025146 commit 0e84a1c

File tree

3 files changed

+6
-140
lines changed

3 files changed

+6
-140
lines changed

chemicalx/models/matchmaker.py

Lines changed: 4 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,14 @@
11
"""An implementation of the MatchMaker model."""
22

3-
import torch
4-
import torch.nn.functional as F # noqa:N812
5-
6-
from chemicalx.data import DrugPairBatch
7-
from chemicalx.models import Model
3+
from .base import UnimplementedModel
84

95
__all__ = [
106
"MatchMaker",
117
]
128

139

14-
class MatchMaker(Model):
15-
"""An implementation of the MatchMaker model from [matchmaker]_.
10+
class MatchMaker(UnimplementedModel):
11+
"""An implementation of the MatchMaker model.
1612
17-
.. [matchmaker] `MatchMaker: A Deep Learning Framework for Drug Synergy Prediction
18-
<https://www.biorxiv.org/content/10.1101/2020.05.24.113241v3.full>`_
13+
.. seealso:: https://github.com/AstraZeneca/chemicalx/issues/23
1914
"""
20-
21-
def __init__(
22-
self,
23-
*,
24-
context_channels: int,
25-
drug_channels: int,
26-
input_hidden_channels: int = 32,
27-
middle_hidden_channels: int = 32,
28-
final_hidden_channels: int = 32,
29-
out_channels: int = 1,
30-
dropout_rate: float = 0.5,
31-
):
32-
"""Instantiate the MatchMaker model.
33-
34-
:param context_channels: The number of context features.
35-
:param drug_channels: The number of drug features.
36-
:param input_hidden_channels: The number of hidden layer neurons in the input layer.
37-
:param middle_hidden_channels: The number of hidden layer neurons in the middle layer.
38-
:param final_hidden_channels: The number of hidden layer neurons in the final layer.
39-
:param out_channels: The number of output channels.
40-
:param dropout_rate: The rate of dropout before the scoring head is used.
41-
"""
42-
super().__init__()
43-
self.encoder = torch.nn.Linear(drug_channels + context_channels, input_hidden_channels)
44-
self.hidden_first = torch.nn.Linear(input_hidden_channels, middle_hidden_channels)
45-
self.hidden_second = torch.nn.Linear(middle_hidden_channels, middle_hidden_channels)
46-
self.dropout = torch.nn.Dropout(dropout_rate)
47-
self.scoring_head_first = torch.nn.Linear(2 * middle_hidden_channels, final_hidden_channels)
48-
self.scoring_head_second = torch.nn.Linear(final_hidden_channels, out_channels)
49-
50-
def unpack(self, batch: DrugPairBatch):
51-
"""Return the context features, left drug features, and right drug features."""
52-
return (
53-
batch.context_features,
54-
batch.drug_features_left,
55-
batch.drug_features_right,
56-
)
57-
58-
def _forward_hidden(self, tensor: torch.FloatTensor) -> torch.FloatTensor:
59-
hidden = self.encoder(tensor)
60-
hidden = F.relu(hidden)
61-
hidden = self.dropout(hidden)
62-
hidden = self.hidden_first(hidden)
63-
hidden = F.relu(hidden)
64-
hidden = self.dropout(hidden)
65-
hidden = self.hidden_second(hidden)
66-
return hidden
67-
68-
def _forward_hidden_merged(self, tensor: torch.FloatTensor) -> torch.FloatTensor:
69-
hidden = self.scoring_head_first(tensor)
70-
hidden = F.relu(hidden)
71-
hidden = self.dropout(hidden)
72-
hidden = self.scoring_head_second(hidden)
73-
hidden = torch.sigmoid(hidden)
74-
return hidden
75-
76-
def forward(
77-
self,
78-
context_features: torch.FloatTensor,
79-
drug_features_left: torch.FloatTensor,
80-
drug_features_right: torch.FloatTensor,
81-
) -> torch.FloatTensor:
82-
"""
83-
Run a forward pass of the MatchMaker model model.
84-
85-
Args:
86-
context_features (torch.FloatTensor): A matrix of biological context features.
87-
drug_features_left (torch.FloatTensor): A matrix of head drug features.
88-
drug_features_right (torch.FloatTensor): A matrix of tail drug features.
89-
Returns:
90-
hidden (torch.FloatTensor): A column vector of predicted synergy scores.
91-
"""
92-
# The left drug
93-
hidden_left = torch.cat([context_features, drug_features_left], dim=1)
94-
hidden_left = self._forward_hidden(hidden_left)
95-
96-
# The right drug
97-
hidden_right = torch.cat([context_features, drug_features_right], dim=1)
98-
hidden_right = self._forward_hidden(hidden_right)
99-
100-
# Merged
101-
hidden_merged = torch.cat([hidden_left, hidden_right], dim=1)
102-
hidden_merged = self._forward_hidden_merged(hidden_merged)
103-
104-
return hidden_merged

examples/matchmaker_examples.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

tests/unit/test_models.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -202,22 +202,5 @@ def test_deepdds(self):
202202

203203
def test_matchmaker(self):
204204
"""Test MatchMaker."""
205-
model = MatchMaker(
206-
context_channels=self.loader.context_channels,
207-
drug_channels=self.loader.drug_channels,
208-
input_hidden_channels=32,
209-
middle_hidden_channels=16,
210-
final_hidden_channels=16,
211-
dropout_rate=0.5,
212-
)
213-
214-
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
215-
model.train()
216-
loss = torch.nn.BCELoss()
217-
for batch in self.generator:
218-
optimizer.zero_grad()
219-
prediction = model(batch.context_features, batch.drug_features_left, batch.drug_features_right)
220-
output = loss(prediction, batch.labels)
221-
output.backward()
222-
optimizer.step()
223-
assert prediction.shape[0] == batch.labels.shape[0]
205+
model = MatchMaker(x=2)
206+
assert model.x == 2

0 commit comments

Comments
 (0)