Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion chemicalx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from chemicalx.models import ( # noqa:F401,F403
caster,
deepcci,
deepddi,
deepdds,
deepdrug,
Expand Down
2 changes: 0 additions & 2 deletions chemicalx/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from .base import Model, UnimplementedModel
from .caster import CASTER
from .deepcci import DeepCCI
from .deepddi import DeepDDI
from .deepdds import DeepDDS
from .deepdrug import DeepDrug
Expand All @@ -22,7 +21,6 @@
"UnimplementedModel",
# Implementations
"CASTER",
"DeepCCI",
"DeepDDI",
"DeepDDS",
"DeepDrug",
Expand Down
14 changes: 0 additions & 14 deletions chemicalx/models/deepcci.py

This file was deleted.

8 changes: 5 additions & 3 deletions chemicalx/models/deepsynergy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ class DeepSynergy(Model):
def __init__(
self,
*,
context_channels: int,
drug_channels: int,
context_channels: int = 128,
drug_channels: int = 128,
input_hidden_channels: int = 32,
middle_hidden_channels: int = 32,
final_hidden_channels: int = 32,
out_channels: int = 1,
dropout_rate: float = 0.5,
):
"""Instantiate the DeepSynergy model.
Expand All @@ -35,14 +36,15 @@ def __init__(
:param input_hidden_channels: The number of hidden layer neurons in the input layer.
:param middle_hidden_channels: The number of hidden layer neurons in the middle layer.
:param final_hidden_channels: The number of hidden layer neurons in the final layer.
:param out_channels: The number of output channels.
:param dropout_rate: The rate of dropout before the scoring head is used.
"""
super(DeepSynergy, self).__init__()
self.encoder = torch.nn.Linear(drug_channels + drug_channels + context_channels, input_hidden_channels)
self.hidden_first = torch.nn.Linear(input_hidden_channels, middle_hidden_channels)
self.hidden_second = torch.nn.Linear(middle_hidden_channels, final_hidden_channels)
self.dropout = torch.nn.Dropout(dropout_rate)
self.scoring_head = torch.nn.Linear(final_hidden_channels, 1)
self.scoring_head = torch.nn.Linear(final_hidden_channels, out_channels)

def unpack(self, batch: DrugPairBatch):
"""Return the context features, left drug features, and right drug features."""
Expand Down
11 changes: 7 additions & 4 deletions chemicalx/models/epgcnds.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,21 @@ class EPGCNDS(Model):
<https://ojs.aaai.org/index.php/AAAI/article/view/7236>`_
"""

def __init__(self, *, in_channels: int, hidden_channels: int = 32, out_channels: int = 16):
def __init__(
self, *, in_channels: int = 128, hidden_channels: int = 32, middle_channels: int = 16, out_channels: int = 1
):
"""Instantiate the EPGCN-DS model.

:param in_channels: The number of molecular features.
:param hidden_channels: The number of graph convolutional filters.
:param out_channels: The number of hidden layer neurons in the last layer.
:param middle_channels: The number of hidden layer neurons in the last layer.
:param out_channels: The number of output channels.
"""
super(EPGCNDS, self).__init__()
self.graph_convolution_in = GraphConvolutionalNetwork(in_channels, hidden_channels)
self.graph_convolution_out = GraphConvolutionalNetwork(hidden_channels, out_channels)
self.graph_convolution_out = GraphConvolutionalNetwork(hidden_channels, middle_channels)
self.mean_readout = MeanReadout()
self.final = torch.nn.Linear(out_channels, 1)
self.final = torch.nn.Linear(middle_channels, out_channels)

def unpack(self, batch: DrugPairBatch):
"""Return the left molecular graph and right molecular graph."""
Expand Down
6 changes: 0 additions & 6 deletions tests/unit/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
MHCADDI,
MRGNN,
SSIDDI,
DeepCCI,
DeepDDI,
DeepDDS,
DeepDrug,
Expand Down Expand Up @@ -112,11 +111,6 @@ def test_ssiddi(self):
model = SSIDDI(x=2)
assert model.x == 2

def test_deepcci(self):
"""Test DeepCCI."""
model = DeepCCI(x=2)
assert model.x == 2

def test_deepddi(self):
"""Test DeepDDI."""
model = DeepDDI(x=2)
Expand Down