Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tap_postgres/connection_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def from_tap_config(cls, config: Mapping[str, Any]) -> ConnectionParameters:
options=_build_options_from_tap_config(config),
)

def with_host_and_port(self, host: str, port: int) -> ConnectionParameters:
def with_host_and_port(self, *, host: str, port: int) -> ConnectionParameters:
"""Return a new ConnectionParameters with the given host and port.

Args:
Expand Down
4 changes: 2 additions & 2 deletions tap_postgres/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,8 +610,8 @@ def ssh_tunnel_connect(

# Swap the URL to use the tunnel
return connection_parameters.with_host_and_port(
host=connection_parameters.host,
port=connection_parameters.port,
host=self.ssh_tunnel.local_bind_host,
port=self.ssh_tunnel.local_bind_port, # type: ignore[arg-type]
)

def clean_up(self) -> None:
Expand Down
70 changes: 63 additions & 7 deletions tests/test_connection_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,16 @@ def test_connection_parameters_from_sqlalchemy_url_parses_fields(
}


def test_connection_parameters_from_sqlalchemy_url_defaults_port_and_keeps_ssl_paths(
tmp_path: Path,
) -> None:
def test_from_sqlalchemy_url_default_port() -> None:
cfg = {
"sqlalchemy_url": ("postgresql://user:pass@localhost/mydb"),
}

parameters = ConnectionParameters.from_tap_config(cfg)
assert parameters.port == 5432 # noqa: PLR2004


def test_from_sqlalchemy_url_and_preserves_ssl_paths(tmp_path: Path) -> None:
rootcert = tmp_path / "root.crt"
rootcert.write_text("CA", encoding="utf-8")

Expand All @@ -154,16 +161,14 @@ def test_connection_parameters_from_sqlalchemy_url_defaults_port_and_keeps_ssl_p
}

parameters = ConnectionParameters.from_tap_config(cfg)

assert parameters.port == 5432 # noqa: PLR2004
assert parameters.options == {
"application_name": "tap_postgres",
"sslmode": "verify-full",
"sslrootcert": str(rootcert),
}


def test_connection_parameters_renders_as_sqlalchemy_url(tmp_path: Path) -> None:
def test_renders_as_sqlalchemy_url(tmp_path: Path) -> None:
cfg = _base_config(tmp_path)
cfg.update(
{
Expand All @@ -181,7 +186,7 @@ def test_connection_parameters_renders_as_sqlalchemy_url(tmp_path: Path) -> None
)


def test_connection_parameters_renders_as_psycopg2_dsn(tmp_path: Path) -> None:
def test_renders_as_psycopg2_dsn(tmp_path: Path) -> None:
cfg = _base_config(tmp_path)
cfg.update(
{
Expand All @@ -197,3 +202,54 @@ def test_connection_parameters_renders_as_psycopg2_dsn(tmp_path: Path) -> None:
"host=localhost port=5432 dbname=postgres user=postgres password=postgres "
"application_name=tap_postgres sslmode=require"
)


def test_with_host_and_port():
"""Unit test for ConnectionParameters.with_host_and_port method."""
# Create original connection parameters (pointing to remote database)
original = ConnectionParameters(
host="remote-db.example.com",
port=5432,
database="testdb",
user="testuser",
password="testpass",
options={"sslmode": "require", "application_name": "tap_postgres"},
)

ssh_tunnel_host = "127.0.0.1"
ssh_tunnel_port = 12345

# Simulate what ssh_tunnel_connect does: update to tunnel's local bind address
updated = original.with_host_and_port(
host=ssh_tunnel_host, # tunnel's local_bind_host
port=ssh_tunnel_port, # tunnel's local_bind_port
)

assert updated.host == ssh_tunnel_host
assert updated.port == ssh_tunnel_port

# Verify other parameters are preserved
assert updated.database == original.database == "testdb"
assert updated.user == original.user == "testuser"
assert updated.password == original.password == "testpass"
assert (
updated.options
== original.options
== {"sslmode": "require", "application_name": "tap_postgres"}
)

# Verify original parameters are unchanged (immutability check)
assert original.host == "remote-db.example.com"
assert original.port == 5432 # noqa: PLR2004

# Verify the connection strings use the tunnel address
sqlalchemy_url = updated.render_as_sqlalchemy_url()
psycopg2_dsn = updated.render_as_psycopg2_dsn()

assert ssh_tunnel_host in sqlalchemy_url
assert str(ssh_tunnel_port) in sqlalchemy_url
assert original.host not in sqlalchemy_url

assert f"host={ssh_tunnel_host}" in psycopg2_dsn
assert f"port={ssh_tunnel_port}" in psycopg2_dsn
assert original.host not in psycopg2_dsn
2 changes: 1 addition & 1 deletion tests/test_ssh_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@


def test_ssh_tunnel():
"""We expect the SSH environment to already be up"""
"""We expect the SSH environment to already be up."""
tap = TapPostgres(config=SAMPLE_CONFIG)
tap.sync_all()