|
1 | 1 | """An implementation of the MatchMaker model.""" |
2 | 2 |
|
3 | | -import torch |
4 | | -import torch.nn.functional as F # noqa:N812 |
5 | | - |
6 | | -from chemicalx.data import DrugPairBatch |
7 | | -from chemicalx.models import Model |
| 3 | +from .base import UnimplementedModel |
8 | 4 |
|
9 | 5 | __all__ = [ |
10 | 6 | "MatchMaker", |
11 | 7 | ] |
12 | 8 |
|
13 | 9 |
|
14 | | -class MatchMaker(Model): |
15 | | - """An implementation of the MatchMaker model from [matchmaker]_. |
| 10 | +class MatchMaker(UnimplementedModel): |
| 11 | + """An implementation of the MatchMaker model. |
16 | 12 |
|
17 | | - .. [matchmaker] `MatchMaker: A Deep Learning Framework for Drug Synergy Prediction |
18 | | - <https://www.biorxiv.org/content/10.1101/2020.05.24.113241v3.full>`_ |
| 13 | + .. seealso:: https://github.com/AstraZeneca/chemicalx/issues/23 |
19 | 14 | """ |
20 | | - |
21 | | - def __init__( |
22 | | - self, |
23 | | - *, |
24 | | - context_channels: int, |
25 | | - drug_channels: int, |
26 | | - input_hidden_channels: int = 32, |
27 | | - middle_hidden_channels: int = 32, |
28 | | - final_hidden_channels: int = 32, |
29 | | - out_channels: int = 1, |
30 | | - dropout_rate: float = 0.5, |
31 | | - ): |
32 | | - """Instantiate the MatchMaker model. |
33 | | -
|
34 | | - :param context_channels: The number of context features. |
35 | | - :param drug_channels: The number of drug features. |
36 | | - :param input_hidden_channels: The number of hidden layer neurons in the input layer. |
37 | | - :param middle_hidden_channels: The number of hidden layer neurons in the middle layer. |
38 | | - :param final_hidden_channels: The number of hidden layer neurons in the final layer. |
39 | | - :param out_channels: The number of output channels. |
40 | | - :param dropout_rate: The rate of dropout before the scoring head is used. |
41 | | - """ |
42 | | - super().__init__() |
43 | | - self.encoder = torch.nn.Linear(drug_channels + context_channels, input_hidden_channels) |
44 | | - self.hidden_first = torch.nn.Linear(input_hidden_channels, middle_hidden_channels) |
45 | | - self.hidden_second = torch.nn.Linear(middle_hidden_channels, middle_hidden_channels) |
46 | | - self.dropout = torch.nn.Dropout(dropout_rate) |
47 | | - self.scoring_head_first = torch.nn.Linear(2 * middle_hidden_channels, final_hidden_channels) |
48 | | - self.scoring_head_second = torch.nn.Linear(final_hidden_channels, out_channels) |
49 | | - |
50 | | - def unpack(self, batch: DrugPairBatch): |
51 | | - """Return the context features, left drug features, and right drug features.""" |
52 | | - return ( |
53 | | - batch.context_features, |
54 | | - batch.drug_features_left, |
55 | | - batch.drug_features_right, |
56 | | - ) |
57 | | - |
58 | | - def _forward_hidden(self, tensor: torch.FloatTensor) -> torch.FloatTensor: |
59 | | - hidden = self.encoder(tensor) |
60 | | - hidden = F.relu(hidden) |
61 | | - hidden = self.dropout(hidden) |
62 | | - hidden = self.hidden_first(hidden) |
63 | | - hidden = F.relu(hidden) |
64 | | - hidden = self.dropout(hidden) |
65 | | - hidden = self.hidden_second(hidden) |
66 | | - return hidden |
67 | | - |
68 | | - def _forward_hidden_merged(self, tensor: torch.FloatTensor) -> torch.FloatTensor: |
69 | | - hidden = self.scoring_head_first(tensor) |
70 | | - hidden = F.relu(hidden) |
71 | | - hidden = self.dropout(hidden) |
72 | | - hidden = self.scoring_head_second(hidden) |
73 | | - hidden = torch.sigmoid(hidden) |
74 | | - return hidden |
75 | | - |
76 | | - def forward( |
77 | | - self, |
78 | | - context_features: torch.FloatTensor, |
79 | | - drug_features_left: torch.FloatTensor, |
80 | | - drug_features_right: torch.FloatTensor, |
81 | | - ) -> torch.FloatTensor: |
82 | | - """ |
83 | | - Run a forward pass of the MatchMaker model model. |
84 | | -
|
85 | | - Args: |
86 | | - context_features (torch.FloatTensor): A matrix of biological context features. |
87 | | - drug_features_left (torch.FloatTensor): A matrix of head drug features. |
88 | | - drug_features_right (torch.FloatTensor): A matrix of tail drug features. |
89 | | - Returns: |
90 | | - hidden (torch.FloatTensor): A column vector of predicted synergy scores. |
91 | | - """ |
92 | | - # The left drug |
93 | | - hidden_left = torch.cat([context_features, drug_features_left], dim=1) |
94 | | - hidden_left = self._forward_hidden(hidden_left) |
95 | | - |
96 | | - # The right drug |
97 | | - hidden_right = torch.cat([context_features, drug_features_right], dim=1) |
98 | | - hidden_right = self._forward_hidden(hidden_right) |
99 | | - |
100 | | - # Merged |
101 | | - hidden_merged = torch.cat([hidden_left, hidden_right], dim=1) |
102 | | - hidden_merged = self._forward_hidden_merged(hidden_merged) |
103 | | - |
104 | | - return hidden_merged |
0 commit comments