Skip to content

Commit 0e5e29f

Browse files
authored
Merge pull request #948 from circulon/fix/947
Fix/947
2 parents c4959b4 + 5a2edeb commit 0e5e29f

18 files changed

Lines changed: 126 additions & 62 deletions

orm.sqlite3

40 KB
Binary file not shown.

pytest.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[pytest]
2-
env =
3-
D:DB_CONFIG_PATH=config/test-database
2+
env =
3+
D:DB_CONFIG_PATH=tests/integrations/config/database

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ black
33
isort
44
faker
55
pytest
6+
pytest-env
67
pytest-cov
78
pymysql
89
inflection>=0.3

src/masoniteorm/connections/ConnectionFactory.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from ..config import load_config
2+
from .ConnectionResolver import ConnectionResolver
23

34

45
class ConnectionFactory:
56
"""Class for controlling the registration and creation of connection types."""
67

78
_connections = {}
89

9-
def __init__(self, config_path=None):
10+
def __init__(self, config_path=None, resolver=None):
1011
self.config_path = config_path
12+
self._resolver: ConnectionResolver = resolver
1113

1214
@classmethod
1315
def register(cls, key, connection):
@@ -35,18 +37,16 @@ def make(self, key):
3537
Returns:
3638
masoniteorm.connection.BaseConnection -- Returns an instance of a BaseConnection class.
3739
"""
40+
if not self._resolver:
41+
self._resolver = load_config(config_path=self.config_path).DB
3842

39-
DB = load_config(config_path=self.config_path).DB
40-
41-
connections = DB.get_connection_details()
42-
43+
connections = self._resolver.get_connection_details()
4344
if key == "default":
4445
connection_details = connections.get(connections.get("default"))
4546
connection = self._connections.get(
4647
connection_details.get("driver")
4748
)
4849
else:
49-
connection_details = connections.get(key)
5050
connection = self._connections.get(key)
5151

5252
if connection:

src/masoniteorm/connections/ConnectionResolver.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,24 @@
22

33

44
class ConnectionResolver:
5-
_connection_details = {}
65
_connections = {}
76
_morph_map = {}
87

9-
def __init__(self, config_path=None):
8+
def __init__(self, config_path=None, connection_details=None):
109
from ..connections import (
10+
ConnectionFactory,
1111
MSSQLConnection,
1212
MySQLConnection,
1313
PostgresConnection,
1414
SQLiteConnection,
1515
)
1616

1717
self.config_path = config_path
18-
from ..connections import ConnectionFactory
18+
self._connection_details = connection_details or {}
1919

20-
self.connection_factory = ConnectionFactory(config_path=config_path)
20+
self.connection_factory = ConnectionFactory(
21+
config_path=config_path, resolver=self
22+
)
2123
self.register(SQLiteConnection)
2224
self.register(PostgresConnection)
2325
self.register(MySQLConnection)
@@ -28,7 +30,7 @@ def morph_map(self, map):
2830
return self
2931

3032
def set_connection_details(self, connection_details):
31-
self.__class__._connection_details = connection_details
33+
self._connection_details = connection_details
3234
return self
3335

3436
def get_connection_details(self):

src/masoniteorm/models/Model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ def get_columns(cls):
392392
return list(cls.first().__attributes__.keys())
393393

394394
def get_connection_details(self):
395-
DB = load_config().DB
396-
return DB.get_connection_details()
395+
resolver = load_config().DB
396+
return resolver.get_connection_details()
397397

398398
def boot(self):
399399
if not self._booted:

src/masoniteorm/query/QueryBuilder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from ..collection.Collection import Collection
77
from ..config import load_config
8+
from ..connections import ConnectionResolver
89
from ..exceptions import (
910
HTTP404,
1011
ConnectionNotRegistered,
@@ -107,8 +108,8 @@ def __init__(
107108
self.set_action("select")
108109

109110
if not self._connection_details:
110-
DB = load_config(config_path=self.config_path).DB
111-
self._connection_details = DB.get_connection_details()
111+
resolver = load_config(config_path=self.config_path).DB
112+
self._connection_details = resolver.get_connection_details()
112113

113114
self.on(connection)
114115

@@ -398,8 +399,6 @@ def method(*args, **kwargs):
398399
)
399400

400401
def on(self, connection):
401-
DB = load_config(self.config_path).DB
402-
403402
if connection == "default":
404403
self.connection = self._connection_details.get("default")
405404
else:
@@ -413,7 +412,10 @@ def on(self, connection):
413412
self._connection_driver = self._connection_details.get(
414413
self.connection
415414
).get("driver")
416-
self.connection_class = DB.connection_factory.make(
415+
resolver = ConnectionResolver(
416+
connection_details=self._connection_details
417+
)
418+
self.connection_class = resolver.connection_factory.make(
417419
self._connection_driver
418420
)
419421

src/masoniteorm/schema/Schema.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,23 +87,22 @@ def on(self, connection_key):
8787
Returns:
8888
cls
8989
"""
90-
DB = load_config(config_path=self.config_path).DB
91-
90+
resolver = load_config(config_path=self.config_path).DB
91+
self.connection_details = resolver.get_connection_details()
9292
if connection_key == "default":
9393
self.connection = self.connection_details.get("default")
94+
else:
95+
self.connection = connection_key
9496

95-
connection_detail = self._connection_driver = (
96-
self.connection_details.get(self.connection)
97-
)
98-
97+
connection_detail = self.connection_details.get(self.connection)
9998
if connection_detail:
10099
self._connection_driver = connection_detail.get("driver")
101100
else:
102101
raise ConnectionNotRegistered(
103102
f"Could not find the '{connection_key}' connection details"
104103
)
105104

106-
self.connection_class = DB.connection_factory.make(
105+
self.connection_class = resolver.connection_factory.make(
107106
self._connection_driver
108107
)
109108

tests/config/test_db_url.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,12 @@ def test_using_it_with_connection_resolver(self):
117117
assert config.get("port") == 3306
118118
assert config.get("host") == "localhost"
119119
assert config.get("log_queries")
120-
# reset connection resolver to default for other tests to continue working
121-
from tests.integrations.config.database import DATABASES
122120

123-
ConnectionResolver().set_connection_details(DATABASES)
121+
inline_resolver = ConnectionResolver(connection_details=TEST_DATABASES)
122+
inline_config = inline_resolver.get_connection_details().get("test")
123+
assert inline_config.get("database") == "orm"
124+
assert inline_config.get("user") == "root"
125+
assert inline_config.get("password") == ""
126+
assert inline_config.get("port") == 3306
127+
assert inline_config.get("host") == "localhost"
128+
assert inline_config.get("log_queries")

tests/integrations/config/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119
},
120120
}
121121

122-
DB = ConnectionResolver().set_connection_details(DATABASES)
122+
DB = ConnectionResolver(connection_details=DATABASES)
123123

124124
logger = logging.getLogger("masoniteorm.connection.queries")
125125
logger.setLevel(logging.DEBUG)

0 commit comments

Comments
 (0)