Skip to content

Commit 677cc8f

Browse files
authored
v1.0.2: fixes python 3.9 dependencies
1 parent f5edb93 commit 677cc8f

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

conversion/guacamol.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import os
6+
from typing import Union
67

78
import torch
89
from loguru import logger
@@ -108,7 +109,11 @@ def check_smiles_graph_mapping_worker(smile_idx, smile):
108109

109110

110111
def process(
111-
split: str, raw_dir: str, n_jobs: int, limit: int | None, chunk_size: int
112+
split: str,
113+
raw_dir: str,
114+
n_jobs: int,
115+
limit: Union[int, None],
116+
chunk_size: int,
112117
) -> None:
113118
path = os.path.join(raw_dir, f"guacamol_v1_{split}.smiles")
114119
smile_list = [

polygraph/metrics/base/polygraphdiscrepancy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def _descriptions_to_classifier_metric(
251251
variant: Literal["informedness", "jsd"] = "jsd",
252252
classifier: Optional[ClassifierProtocol] = None,
253253
rng: Optional[np.random.Generator] = None,
254-
) -> Tuple[float, int | float]:
254+
) -> Tuple[float, Union[int, float]]:
255255
rng = np.random.default_rng(0) if rng is None else rng
256256

257257
if isinstance(ref_descriptions, csr_array):

pyproject.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "polygraph-benchmark"
7-
version = "1.0.1"
7+
version = "1.0.2"
88
description = "Evaluation benchmarks for graph generative models"
99
readme = "README.md"
1010
authors = [
@@ -13,22 +13,22 @@ authors = [
1313
{ name = "Dexiong Chen", email = "[email protected]" },
1414
{ name = "Karsten Borgwardt", email = "[email protected]" },
1515
]
16-
requires-python = ">=3.7"
16+
requires-python = ">=3.9"
1717
dependencies = [
1818
"numpy>=1.26.4,<3.0",
1919
"torch>=2.4.0,<3.0",
2020
"torch_geometric>=2.6.1,<3.0",
2121
"rich",
22-
"scipy>=1.14.0,<2.0",
22+
"scipy>=1.12.0,<2.0",
2323
"pydantic~=2.11.7",
24-
"networkx>=3.4,<4.0",
24+
"networkx>=3.2,<4.0",
2525
"joblib",
2626
"appdirs",
2727
"loguru",
2828
"rdkit",
2929
"pandas",
3030
"orbit-count",
31-
"numba~=0.61.2",
31+
"numba>=0.60.0,<0.62.0",
3232
"scikit-learn>=1.6.1,<2.0",
3333
"tabpfn==2.0.9",
3434
"fcd~=1.2.2"

tests/test_mmd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@
4848
from polygraph.utils.mmd_utils import mmd_from_gram
4949
from polygraph.metrics.base.metric_interval import MetricInterval
5050

51-
import grakel
52-
5351

5452
class WeisfeilerLehmanMMD2(DescriptorMMD2):
5553
def __init__(self, reference_graphs, iterations=3):
@@ -67,6 +65,8 @@ def __init__(self, reference_graphs, iterations=3):
6765
def grakel_wl_mmd(
6866
reference_graphs, test_graphs, is_parallel=False, iterations=3
6967
):
68+
import grakel
69+
7070
grakel_kernel = grakel.WeisfeilerLehman(n_iter=iterations)
7171
all_graphs = reference_graphs + test_graphs
7272
for g in all_graphs:

0 commit comments

Comments
 (0)