@@ -193,4 +193,51 @@ def test_load_csv_with_metadata(csv_file: str):
193193 "question" : "q2" ,
194194 "dataset" : "dataset2" ,
195195 "content" : "content2" ,
196- }
196+ }
197+
198+ def test_retrieve_docs_from_vectordb (mocker ):
199+ # Create mock CompanyContent objects
200+ mock_company_contents = [
201+ CompanyContent (
202+ id = 1 ,
203+ category = "cat1" ,
204+ subcategory = "subcat1" ,
205+ question = "What is cat1?" ,
206+ content = "Content for cat1" ,
207+ embedding = "0.1 0.2 0.3" , # Mock embedding data
208+ dataset = "dataset1" ,
209+ link = "http://link1"
210+ ),
211+ CompanyContent (
212+ id = 2 ,
213+ category = "cat2" ,
214+ subcategory = "subcat2" ,
215+ question = "What is cat2?" ,
216+ content = "Content for cat2" ,
217+ embedding = "0.4 0.5 0.6" , # Mock embedding data
218+ dataset = "dataset2" ,
219+ link = "http://link2"
220+ )
221+ ]
222+
223+ # Mock the query and its all() method
224+ mock_query : Mock = mocker .patch ("load_csv.session" )
225+ load_csv .session .query .return_value .all .return_value = mock_company_contents
226+
227+ # Call the function to test
228+ docs = load_csv .retrieve_docs_from_vectordb ()
229+
230+ # Check that the output is as expected
231+ assert len (docs ) == 2
232+ assert docs [0 ].page_content == "Content for cat1"
233+ assert docs [0 ].metadata == {
234+ "category" : "cat1" ,
235+ "subcategory" : "subcat1" ,
236+ "question" : "What is cat1?"
237+ }
238+ assert docs [1 ].page_content == "Content for cat2"
239+ assert docs [1 ].metadata == {
240+ "category" : "cat2" ,
241+ "subcategory" : "subcat2" ,
242+ "question" : "What is cat2?"
243+ }
0 commit comments