Skip to content

Commit 16e812d

Browse files
committed
Add generate_compose tests
1 parent ddc8cab commit 16e812d

File tree

3 files changed

+140
-2
lines changed

3 files changed

+140
-2
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ Usage: ./bitcoin-on-local.sh start|stop|renew|draw|scenario|draw [output_file]
7373
This tool requires :
7474

7575
- Docker (compose v2)
76-
- Python (>= 3.12)[^1]
76+
- Python (>= 3.11)[^1]
7777

78-
[^1]: The tool may work on older versions but has not been tested on it.
78+
[^1]: The tool has been tested on Python 3.11 and 3.12
7979

8080
All other dependencies will be installed with pip. See [requirements.txt](./requirements.txt).
8181

py/generate_compose.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# this file generates a docker-compose.yml file based on the provided configuration
22

33
import random
4+
import os
45

56
from config import (
67
NODE_NUMBER,
@@ -130,6 +131,9 @@ def export_data(all_ports: dict, node_names: list, output_dir: str = 'data'):
130131
output_dir (str): Subdirectory of /docker to store the .env files. Defaults to 'data'.
131132
"""
132133

134+
# ensure the output directory exists
135+
os.makedirs(f"docker/{output_dir}", exist_ok=True)
136+
133137
# export node names :
134138
output_file_names = f"docker/{output_dir}/.env.node_names"
135139
with open(output_file_names, 'w') as file:
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import pytest
2+
import random
3+
import builtins
4+
import sys
5+
import pathlib
6+
7+
# Patch sys.path so we can import the module directly from py/
8+
import importlib.util
9+
10+
MODULE_PATH = str(pathlib.Path(__file__).parent.parent.parent / "py" / "generate_compose.py")
11+
spec = importlib.util.spec_from_file_location("generate_compose", MODULE_PATH)
12+
generate_compose = importlib.util.module_from_spec(spec)
13+
sys.modules["generate_compose"] = generate_compose
14+
spec.loader.exec_module(generate_compose)
15+
16+
def test_generate_names_basic():
17+
assert generate_compose.generate_names(3, "node") == ["node_1", "node_2", "node_3"]
18+
assert generate_compose.generate_names(0, "n") == []
19+
assert generate_compose.generate_names(1, "bitcoin") == ["bitcoin_1"]
20+
21+
def test_compute_ports_basic():
22+
ports = generate_compose.compute_ports(2, 1000, 2000, "n")
23+
assert ports == {"n_1": (1000, 2000), "n_2": (1002, 2002)}
24+
ports = generate_compose.compute_ports(0, 100, 200, "x")
25+
assert ports == {}
26+
27+
def test_generate_peers_deterministic(monkeypatch):
28+
# Patch random to make the test deterministic
29+
monkeypatch.setattr(random, "randint", lambda a, b: 1)
30+
monkeypatch.setattr(random, "sample", lambda l, n: l[:n])
31+
names = ["n1", "n2", "n3"]
32+
peers = generate_compose.generate_peers(names, 2)
33+
for k, v in peers.items():
34+
assert len(v) == 1
35+
assert k not in v
36+
assert set(v).issubset(set(names) - {k})
37+
38+
def test_generate_peers_max_peers(monkeypatch):
39+
# If max_peers > available, should not exceed available
40+
monkeypatch.setattr(random, "randint", lambda a, b: b)
41+
monkeypatch.setattr(random, "sample", lambda l, n: l[:n])
42+
names = ["a", "b", "c", "d"]
43+
peers = generate_compose.generate_peers(names, 10)
44+
for k, v in peers.items():
45+
assert len(v) == len(names) - 1
46+
47+
def test_generate_command_addnode_and_logging(tmp_path, monkeypatch):
48+
# Prepare a fake template file
49+
template = (
50+
"user={RPCUSER}\n"
51+
"pass={RPCPASSWORD}\n"
52+
"max={MAXCONNECTIONS}\n"
53+
"rpc={RPCPORT}\n"
54+
"p2p={P2PPORT}\n"
55+
"{ADDNODE}\n"
56+
)
57+
template_path = tmp_path / "cmd.template"
58+
template_path.write_text(template)
59+
60+
# Patch logging flags
61+
monkeypatch.setattr(generate_compose, "LOG_NET_ENABLED", True)
62+
monkeypatch.setattr(generate_compose, "LOG_MEMPOOL_ENABLED", True)
63+
64+
all_ports = {"n2": (1002, 2002), "n3": (1004, 2004)}
65+
peers = ["n2", "n3"]
66+
result = generate_compose.generate_command(
67+
str(template_path),
68+
rpc_user="alice",
69+
rpc_password="pw",
70+
max_peers=5,
71+
rpc_port=1000,
72+
p2p_port=2000,
73+
peers=peers,
74+
all_ports=all_ports,
75+
)
76+
assert "user=alice" in result
77+
assert "-addnode=n2:2002" in result
78+
assert "-addnode=n3:2004" in result
79+
assert "-debug=net" in result
80+
assert "-debug=mempool" in result
81+
82+
def test_generate_command_template_extension(tmp_path):
83+
# Should raise if not .template
84+
with pytest.raises(ValueError):
85+
generate_compose.generate_command(
86+
"notemplate.txt", "u", "p", 1, 2, 3, [], {}
87+
)
88+
89+
def test_generate_command_no_logging(tmp_path, monkeypatch):
90+
template_path = tmp_path / "cmd.template"
91+
template_path.write_text("{ADDNODE}")
92+
monkeypatch.setattr(generate_compose, "LOG_NET_ENABLED", False)
93+
monkeypatch.setattr(generate_compose, "LOG_MEMPOOL_ENABLED", False)
94+
all_ports = {"n2": (1002, 2002)}
95+
peers = ["n2"]
96+
result = generate_compose.generate_command(
97+
str(template_path),
98+
rpc_user="a",
99+
rpc_password="b",
100+
max_peers=1,
101+
rpc_port=1,
102+
p2p_port=2,
103+
peers=peers,
104+
all_ports=all_ports,
105+
)
106+
assert "-addnode=n2:2002" in result
107+
assert "-debug" not in result
108+
109+
def test_export_data_creates_files(tmp_path, monkeypatch):
110+
# Patch print to suppress output
111+
monkeypatch.chdir(tmp_path)
112+
all_ports = {"n1": (1001, 2001), "n2": (1003, 2003)}
113+
node_names = ["n1", "n2"]
114+
names_path = tmp_path / "docker" / "data" / ".env.node_names"
115+
ports_path = tmp_path / "docker" / "data" / ".env.rpc_ports"
116+
generate_compose.export_data(all_ports, node_names, output_dir="data")
117+
assert names_path.exists()
118+
assert ports_path.exists()
119+
names_content = names_path.read_text()
120+
ports_content = ports_path.read_text()
121+
assert "n1" in names_content and "n2" in names_content
122+
assert "N1_RPC_PORT=1001" in ports_content
123+
assert "N2_RPC_PORT=1003" in ports_content
124+
125+
def test_export_data_prints(monkeypatch, tmp_path):
126+
# Patch print to capture output
127+
monkeypatch.chdir(tmp_path)
128+
printed = []
129+
monkeypatch.setattr(builtins, "print", lambda *a, **k: printed.append(a[0]))
130+
all_ports = {"n1": (1, 2)}
131+
node_names = ["n1"]
132+
generate_compose.export_data(all_ports, node_names, output_dir="data")
133+
assert any("Node names exported" in s for s in printed)
134+
assert any("RPC ports exported" in s for s in printed)

0 commit comments

Comments
 (0)