diff --git a/src/hitachione/tools/company_filtering_tool/test.py b/src/hitachione/tools/company_filtering_tool/test.py index 38cb2ae..c5e3bb1 100644 --- a/src/hitachione/tools/company_filtering_tool/test.py +++ b/src/hitachione/tools/company_filtering_tool/test.py @@ -13,6 +13,7 @@ # Load .env file (project root is 5 levels up from this file) load_dotenv(Path(__file__).resolve().parents[4] / ".env") +import os from tool import ( find_relevant_symbols, find_relevant_sp500_symbols, @@ -34,7 +35,8 @@ def test_symbol_extraction(): print_section("Testing Symbol Extraction from Weaviate") print("Extracting all unique tickers from Weaviate collection...") - print("(Iterates through Hitachi_finance_news collection)\n") + collection = os.getenv('WEAVIATE_COLLECTION_NAME', 'hitachi-finance-news') + print(f"(Iterates through {collection} collection)\n") try: import time diff --git a/src/hitachione/tools/company_filtering_tool/tool.py b/src/hitachione/tools/company_filtering_tool/tool.py index 7037cb0..c3eefc1 100644 --- a/src/hitachione/tools/company_filtering_tool/tool.py +++ b/src/hitachione/tools/company_filtering_tool/tool.py @@ -1,7 +1,7 @@ """ Tool for finding relevant stock symbols from the Weaviate financial news knowledge base. -This tool queries the Weaviate `Hitachi_finance_news` collection to retrieve unique +This tool queries the Weaviate financial news collection to retrieve unique stock tickers and uses an LLM to filter them based on user queries. """ @@ -28,8 +28,8 @@ _cached_companies: dict[str, str] | None = None # ticker -> company name _client_manager = None -# Weaviate collection name -WEAVIATE_COLLECTION = "Hitachi_finance_news" +# Weaviate collection name (from WEAVIATE_COLLECTION_NAME env var) +WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION_NAME", "Hitachi_finance_news") def get_client_manager() -> AsyncClientManager: @@ -71,7 +71,7 @@ def get_all_symbols() -> List[str]: """ Get all unique stock tickers from the Weaviate knowledge base. - Iterates through the Hitachi_finance_news collection and collects + Iterates through the Weaviate collection and collects unique ticker symbols and their corresponding company names. Returns: @@ -254,7 +254,7 @@ def find_relevant_symbols(query: str, use_llm_filter: bool = True) -> List[str]: "name": "find_relevant_symbols", "description": ( "Find relevant stock symbols from the Weaviate financial news knowledge base " - "(Hitachi_finance_news collection). The tool uses an LLM internally to filter " + "The tool uses an LLM internally to filter " "symbols based on the query, returning only symbols that match the specified " "criteria (sector, industry, time period, ranking, etc.)." ), diff --git a/src/hitachione/tools/performance_analysis_tool/test.py b/src/hitachione/tools/performance_analysis_tool/test.py new file mode 100644 index 0000000..0cfaced --- /dev/null +++ b/src/hitachione/tools/performance_analysis_tool/test.py @@ -0,0 +1,289 @@ +""" +Test harness for the Performance Analysis Tool (Weaviate-backed). + +Run: + cd src/hitachione/tools/performance_analysis_tool + python3 test.py all # full suite + python3 test.py data # data retrieval only (no LLM) + python3 test.py analyse # full analysis (requires LLM key) + python3 test.py schema # show tool schema + python3 test.py interactive # interactive ticker input +""" + +import json +import os +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +# Load .env from project root +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +from tool import ( + TOOL_SCHEMA, + analyse_stock_performance, + get_ticker_data, +) + +# ── Tickers known to exist in the Weaviate collection ── +KNOWN_TICKERS = ["AAPL", "AMZN", "GOOGL", "JPM", "META", "MSFT", "NVDA", "TSLA", "V", "WMT"] +UNKNOWN_TICKER = "ZZZZZ" + + +# ────────────────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────────────────── + +def _section(title: str) -> None: + print("\n" + "=" * 80) + print(f" {title}") + print("=" * 80 + "\n") + + +def _pass(msg: str) -> None: + print(f" ✓ {msg}") + + +def _fail(msg: str) -> None: + print(f" ✗ {msg}") + + +# ────────────────────────────────────────────────────────────────────────── +# Tests +# ────────────────────────────────────────────────────────────────────────── + +def test_tool_schema() -> None: + """Validate the tool schema structure.""" + _section("Tool Schema") + print(json.dumps(TOOL_SCHEMA, indent=2)) + + assert TOOL_SCHEMA["type"] == "function" + fn = TOOL_SCHEMA["function"] + assert fn["name"] == "analyse_stock_performance" + assert "ticker" in fn["parameters"]["properties"] + assert "ticker" in fn["parameters"]["required"] + _pass("Schema is valid") + + +def test_data_retrieval_known_ticker() -> None: + """Retrieve data for known tickers and verify structure.""" + _section("Data Retrieval — Known Tickers") + + for ticker in ["AAPL", "TSLA", "JPM"]: + t0 = time.time() + data = get_ticker_data(ticker) + elapsed = time.time() - t0 + + assert isinstance(data, dict), f"Expected dict, got {type(data)}" + for key in ("price_data", "earnings", "news"): + assert key in data, f"Missing key '{key}' for {ticker}" + + total = sum(len(v) for v in data.values()) + _pass( + f"{ticker}: {len(data['price_data'])} price, " + f"{len(data['earnings'])} earnings, " + f"{len(data['news'])} news ({elapsed:.2f}s, {total} total)" + ) + + # At least one data source should have records + assert total > 0, f"No data returned for known ticker {ticker}" + + +def test_data_retrieval_unknown_ticker() -> None: + """Verify graceful handling of an unknown ticker.""" + _section("Data Retrieval — Unknown Ticker") + + data = get_ticker_data(UNKNOWN_TICKER) + total = sum(len(v) for v in data.values()) + assert total == 0, f"Expected 0 records for {UNKNOWN_TICKER}, got {total}" + _pass(f"{UNKNOWN_TICKER}: 0 records as expected") + + +def test_data_retrieval_case_insensitive() -> None: + """Verify ticker is uppercased automatically.""" + _section("Data Retrieval — Case Insensitivity") + + data_upper = get_ticker_data("AAPL") + data_lower = get_ticker_data("aapl") + + assert len(data_upper["price_data"]) == len(data_lower["price_data"]), \ + "Price data count differs between 'AAPL' and 'aapl'" + _pass("'AAPL' and 'aapl' return same results") + + +def test_price_data_fields() -> None: + """Verify price records contain expected fields.""" + _section("Price Data — Field Validation") + + data = get_ticker_data("TSLA") + if not data["price_data"]: + _fail("No price data for TSLA") + return + + expected_fields = {"date", "open", "high", "low", "close"} + for i, rec in enumerate(data["price_data"][:3]): + present = set(rec.keys()) & expected_fields + assert present == expected_fields, ( + f"Record {i} missing fields: {expected_fields - present}" + ) + _pass(f"First {min(3, len(data['price_data']))} records have all OHLC fields") + + +def test_price_data_sorted() -> None: + """Verify price records are sorted by date.""" + _section("Price Data — Sort Order") + + data = get_ticker_data("GOOGL") + dates = [r["date"] for r in data["price_data"] if "date" in r] + assert dates == sorted(dates), "Price data is not sorted by date" + _pass(f"GOOGL: {len(dates)} price records sorted correctly") + + +def test_analyse_unknown_ticker() -> None: + """Full analysis on unknown ticker returns None score.""" + _section("Full Analysis — Unknown Ticker") + + result = analyse_stock_performance(UNKNOWN_TICKER) + assert result["ticker"] == UNKNOWN_TICKER + assert result["performance_score"] is None + assert result["data_summary"]["price_records"] == 0 + _pass(f"{UNKNOWN_TICKER}: score=None, outlook={result['outlook']}") + + +def test_analyse_known_ticker() -> None: + """Full analysis on a known ticker returns valid structure.""" + _section("Full Analysis — Known Tickers (LLM)") + + api_key = ( + os.getenv("OPENAI_API_KEY") + or os.getenv("GEMINI_API_KEY") + or os.getenv("GOOGLE_API_KEY") + ) + if not api_key: + print(" ⚠️ No LLM API key — skipping LLM analysis tests") + return + + for ticker in ["AAPL", "NVDA"]: + t0 = time.time() + result = analyse_stock_performance(ticker) + elapsed = time.time() - t0 + + assert isinstance(result, dict) + assert result["ticker"] == ticker + assert isinstance(result["performance_score"], int) + assert 1 <= result["performance_score"] <= 10 + assert result["outlook"] in ("Bullish", "Bearish", "Volatile", "Sideways") + assert len(result["justification"]) > 20 + assert result["data_summary"]["price_records"] > 0 + + _pass( + f"{ticker}: score={result['performance_score']}, " + f"outlook={result['outlook']}, {elapsed:.1f}s" + ) + print(f" Justification: {result['justification'][:120]}...") + + +def test_analyse_multiple_tickers() -> None: + """Analyse several tickers to confirm consistency.""" + _section("Full Analysis — All Known Tickers (LLM)") + + api_key = ( + os.getenv("OPENAI_API_KEY") + or os.getenv("GEMINI_API_KEY") + or os.getenv("GOOGLE_API_KEY") + ) + if not api_key: + print(" ⚠️ No LLM API key — skipping") + return + + for ticker in KNOWN_TICKERS: + result = analyse_stock_performance(ticker) + score = result["performance_score"] + outlook = result["outlook"] + ds = result["data_summary"] + _pass( + f"{ticker:5s}: score={score:>2}, outlook={outlook:8s} " + f"(price={ds['price_records']}, earn={ds['earnings_records']}, " + f"news={ds['news_records']})" + ) + + +# ────────────────────────────────────────────────────────────────────────── +# Interactive mode +# ────────────────────────────────────────────────────────────────────────── + +def interactive() -> None: + _section("Interactive Mode") + print(f"Available tickers: {', '.join(KNOWN_TICKERS)}") + print("Enter a ticker (or 'quit' to exit)\n") + + while True: + try: + ticker = input("Ticker> ").strip() + if ticker.lower() in ("quit", "exit", "q"): + break + if not ticker: + continue + + print(f"\nAnalysing {ticker.upper()}...") + t0 = time.time() + result = analyse_stock_performance(ticker) + elapsed = time.time() - t0 + + print(json.dumps(result, indent=2)) + print(f"({elapsed:.1f}s)\n") + + except KeyboardInterrupt: + print("\n") + break + except Exception as e: + print(f" ✗ Error: {e}\n") + + +# ────────────────────────────────────────────────────────────────────────── +# Main +# ────────────────────────────────────────────────────────────────────────── + +def main() -> None: + print("\n" + "=" * 80) + print(" Performance Analysis Tool (Weaviate) — Test Harness") + print("=" * 80) + + if len(sys.argv) > 1: + mode = sys.argv[1].lower() + else: + print("\nModes:") + print(" 1. all — Run all tests") + print(" 2. data — Data retrieval tests only (no LLM)") + print(" 3. analyse — Full analysis tests (requires LLM)") + print(" 4. schema — Display tool schema") + print(" 5. interactive — Interactive ticker input") + + choice = input("\nSelect (1-5) or press Enter for 'all': ").strip() + mode = {"1": "all", "2": "data", "3": "analyse", "4": "schema", "5": "interactive"}.get(choice, "all") + + if mode in ("all", "schema"): + test_tool_schema() + + if mode in ("all", "data"): + test_data_retrieval_known_ticker() + test_data_retrieval_unknown_ticker() + test_data_retrieval_case_insensitive() + test_price_data_fields() + test_price_data_sorted() + + if mode in ("all", "analyse"): + test_analyse_unknown_ticker() + test_analyse_known_ticker() + test_analyse_multiple_tickers() + + if mode == "interactive": + interactive() + + _section("Test Harness Complete") + + +if __name__ == "__main__": + main() diff --git a/src/hitachione/tools/performance_analysis_tool/tool.py b/src/hitachione/tools/performance_analysis_tool/tool.py new file mode 100644 index 0000000..5ae74ce --- /dev/null +++ b/src/hitachione/tools/performance_analysis_tool/tool.py @@ -0,0 +1,318 @@ +""" +Tool for analysing stock performance using the Weaviate knowledge base. + +Queries the Weaviate financial news collection for price history, earnings +transcripts, and financial news, then uses an LLM to produce a performance +rating score (1-10), future outlook, and justification. +""" + +import asyncio +import json +import os +from pathlib import Path +from typing import Any, List + +import weaviate +from weaviate.auth import AuthApiKey +from weaviate.classes.query import Filter +from dotenv import load_dotenv + +# Load .env from project root +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) +from utils.client_manager import AsyncClientManager + +# --------------------------------------------------------------------------- +# Weaviate helpers +# --------------------------------------------------------------------------- + +WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION_NAME", "Hitachi_finance_news") + + +def _get_weaviate_sync_client(): + """Create a synchronous Weaviate client from environment variables.""" + http_host = os.getenv("WEAVIATE_HTTP_HOST", "localhost") + api_key = os.getenv("WEAVIATE_API_KEY", "") + + if http_host.endswith(".weaviate.cloud"): + return weaviate.connect_to_weaviate_cloud( + cluster_url=f"https://{http_host}", + auth_credentials=AuthApiKey(api_key), + ) + + return weaviate.connect_to_custom( + http_host=http_host, + http_port=int(os.getenv("WEAVIATE_HTTP_PORT", "8080")), + http_secure=os.getenv("WEAVIATE_HTTP_SECURE", "false").lower() == "true", + grpc_host=os.getenv("WEAVIATE_GRPC_HOST", "localhost"), + grpc_port=int(os.getenv("WEAVIATE_GRPC_PORT", "50051")), + grpc_secure=os.getenv("WEAVIATE_GRPC_SECURE", "false").lower() == "true", + auth_credentials=AuthApiKey(api_key), + ) + + +# --------------------------------------------------------------------------- +# Data retrieval +# --------------------------------------------------------------------------- + + +def get_ticker_data(ticker: str) -> dict[str, list[dict]]: + """Retrieve all Weaviate data for a given ticker, grouped by source. + + Returns a dict with keys ``price_data``, ``earnings``, ``news`` — each + containing a list of property dicts. + """ + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + + ticker_filter = Filter.by_property("ticker").equal(ticker.upper()) + + # --- price data (stock_market) --- + price_response = col.query.fetch_objects( + filters=( + ticker_filter + & Filter.by_property("dataset_source").equal("stock_market") + ), + limit=100, + return_properties=[ + "date", "open", "high", "low", "close", "volume", "text", + ], + ) + price_data = [ + {k: v for k, v in obj.properties.items() if v is not None} + for obj in price_response.objects + ] + + # --- earnings transcripts --- + earnings_response = col.query.fetch_objects( + filters=( + ticker_filter + & Filter.by_property("dataset_source").equal( + "sp500_earnings_transcripts" + ) + ), + limit=100, + return_properties=[ + "date", "quarter", "fiscal_year", "fiscal_quarter", "text", + "title", + ], + ) + earnings = [ + {k: v for k, v in obj.properties.items() if v is not None} + for obj in earnings_response.objects + ] + + # --- news (bloomberg + yahoo) --- + news_response = col.query.fetch_objects( + filters=( + ticker_filter + & ( + Filter.by_property("dataset_source").equal( + "bloomberg_financial_news" + ) + | Filter.by_property("dataset_source").equal( + "yahoo_finance_news" + ) + ) + ), + limit=100, + return_properties=["date", "title", "text", "category"], + ) + # Also grab news that *mention* this ticker (mentioned_companies) + mentioned_response = col.query.fetch_objects( + filters=Filter.by_property("mentioned_companies").contains_any( + [ticker.upper()] + ), + limit=50, + return_properties=["date", "title", "text", "category"], + ) + + seen_titles: set[str] = set() + news: list[dict] = [] + for obj in list(news_response.objects) + list(mentioned_response.objects): + props = {k: v for k, v in obj.properties.items() if v is not None} + title = props.get("title", "") + if title not in seen_titles: + seen_titles.add(title) + news.append(props) + + return { + "price_data": sorted(price_data, key=lambda d: d.get("date", "")), + "earnings": sorted(earnings, key=lambda d: d.get("date", "")), + "news": sorted(news, key=lambda d: d.get("date", "")), + } + + finally: + client.close() + + +# --------------------------------------------------------------------------- +# LLM-based performance scoring +# --------------------------------------------------------------------------- + +_client_manager = None + + +def _get_client_manager() -> AsyncClientManager: + global _client_manager + if _client_manager is None: + _client_manager = AsyncClientManager() + return _client_manager + + +async def _analyse_with_llm(ticker: str, data: dict[str, list[dict]]) -> dict: + """Send retrieved data to an LLM and get a structured performance analysis.""" + cm = _get_client_manager() + + # Build context sections + price_summary = "\n".join( + d.get("text", json.dumps(d)) for d in data["price_data"] + ) or "No price data available." + + earnings_summary = "\n---\n".join( + d.get("text", json.dumps(d)) for d in data["earnings"] + ) or "No earnings data available." + + news_summary = "\n---\n".join( + f"[{d.get('date','')}] {d.get('title','')}: {str(d.get('text',''))[:500]}" + for d in data["news"] + ) or "No news articles available." + + prompt = f"""You are a Stock Performance Analyst. Analyse the ticker "{ticker}" using ONLY the data provided below. + +## Price History +{price_summary} + +## Earnings Transcripts +{earnings_summary} + +## News Articles +{news_summary} + +Based on the data above, produce a JSON object (and NOTHING else) with exactly these keys: + +{{ + "ticker": "{ticker}", + "performance_score": , + "outlook": "", + "justification": "<2-4 sentence explanation citing specific data points>" +}} + +Scoring guide: + 1-4 → Negative (declining price, poor earnings, negative news) + 5 → Neutral + 6-10 → Positive (rising price, strong earnings, positive sentiment) +""" + + response = await cm.openai_client.chat.completions.create( + model=cm.configs.default_worker_model, + messages=[{"role": "user", "content": prompt}], + temperature=0, + ) + + content = response.choices[0].message.content.strip() + + # Extract JSON from potential markdown code fences + if "```json" in content: + content = content.split("```json")[1].split("```")[0].strip() + elif "```" in content: + content = content.split("```")[1].split("```")[0].strip() + + return json.loads(content) + + +def _run_async(coro): + """Run an async coroutine, handling nested event loops.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, coro).result() + return asyncio.run(coro) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def analyse_stock_performance(ticker: str) -> dict: + """Analyse a stock's performance using Weaviate knowledge base data. + + Retrieves price history, earnings transcripts, and news articles from the + Weaviate financial news collection, then uses an LLM to produce a + structured performance analysis. + + Parameters + ---------- + ticker : str + Stock ticker symbol (e.g. ``"AAPL"``, ``"TSLA"``). + + Returns + ------- + dict + A dictionary with keys: + - ``ticker`` (str) + - ``performance_score`` (int, 1–10) + - ``outlook`` (str, one of Bullish/Bearish/Volatile/Sideways) + - ``justification`` (str) + - ``data_summary`` (dict with counts of price/earnings/news records) + """ + ticker = ticker.upper().strip() + data = get_ticker_data(ticker) + + if not any(data.values()): + return { + "ticker": ticker, + "performance_score": None, + "outlook": "Unknown", + "justification": f"No data found for ticker {ticker} in the knowledge base.", + "data_summary": {"price_records": 0, "earnings_records": 0, "news_records": 0}, + } + + analysis = _run_async(_analyse_with_llm(ticker, data)) + analysis["data_summary"] = { + "price_records": len(data["price_data"]), + "earnings_records": len(data["earnings"]), + "news_records": len(data["news"]), + } + return analysis + + +# --------------------------------------------------------------------------- +# OpenAI tool schema +# --------------------------------------------------------------------------- + +TOOL_SCHEMA = { + "type": "function", + "function": { + "name": "analyse_stock_performance", + "description": ( + "Analyse a stock's performance using the Weaviate financial knowledge " + "base. Returns a performance score (1-10), " + "future outlook, and justification based on price history, earnings " + "transcripts, and news articles." + ), + "parameters": { + "type": "object", + "properties": { + "ticker": { + "type": "string", + "description": "Stock ticker symbol, e.g. 'AAPL', 'TSLA', 'GOOGL'.", + } + }, + "required": ["ticker"], + }, + }, +} + +TOOL_IMPLEMENTATIONS = { + "analyse_stock_performance": analyse_stock_performance, +} \ No newline at end of file