5 changes: 5 additions & 0 deletions apps/railwaycopilot/.gitignore
@@ -0,0 +1,5 @@
**/__pycache__
**/__init__.py
.env
.DS_Store
**/.DS_Store
10 changes: 10 additions & 0 deletions apps/railwaycopilot/backend/Dockerfile
@@ -0,0 +1,10 @@
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

CMD ["streamlit", "run", "app.py", "--server.address=0.0.0.0", "--server.port=8501"]
83 changes: 83 additions & 0 deletions apps/railwaycopilot/backend/app.py
@@ -0,0 +1,83 @@
import os
import streamlit as st

from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings
from langchain.docstore.document import Document

from rail_rag.config import (
    MONGODB_URI,
    EMBED_MODEL,
    CHAT_MODEL,
)
from rail_rag.index_utils import get_mongo_collection
from rail_rag.retriever import MongoAtlasRetriever
from rail_rag.ui import render_prompt_lab
from rail_rag.generation import run_generation
from rail_rag.classifier import classify_text

# --- Streamlit UI chrome ---
st.set_page_config(page_title="Rail Ops & Safety Assistant", page_icon="🚆", layout="wide")
st.title("🚆 Rail Operations & Safety Assistant (MongoDB + LangChain + Mistral)")

# Sidebar: Prompt Lab (returns all user choices + composed system prompt)
lab = render_prompt_lab()

if not os.getenv("MISTRAL_API_KEY"):
    st.error("Missing `MISTRAL_API_KEY` in environment.")
    st.stop()

if not MONGODB_URI:
    st.error("Missing `MONGODB_URI` in environment.")
    st.stop()

colA, colB, colC = st.columns([2, 1, 1])
with colA:
    q = st.text_input(
        "Ask a question (e.g., 'What must a signaller do when going off duty?')",
        "",
    )
with colB:
    top_k = st.slider("Top-K chunks", 1, 10, 4, 1)
with colC:
    show_debug = st.toggle("Show debug", value=False)

# Connect resources (MongoDB collection + embeddings + retriever)
try:
    collection = get_mongo_collection()
except Exception as e:
    st.exception(e)
    st.stop()

embedder = MistralAIEmbeddings(model=EMBED_MODEL)
retriever = MongoAtlasRetriever(collection=collection, embedder=embedder, k=top_k)

llm = ChatMistralAI(model=CHAT_MODEL)

if q:
    try:
        retrieved = retriever.invoke(q)

        if show_debug:
            with st.expander("🔎 Retrieved docs (debug)"):
                for i, d in enumerate(retrieved, 1):
                    st.write(f"{i}. meta = {d.metadata}")
                    st.write((d.page_content or "")[:300] + "…")

        if not retrieved:
            st.warning(
                "No documents retrieved. "
                "Check MongoDB URI / DB / collection / vector index / field names."
            )
            st.stop()

        # Full prompt-building + A/B + rendering (answers + sources)
        run_generation(
            question=q,
            retrieved=retrieved,
            chat_model_name=CHAT_MODEL,
            lab=lab,
        )

    except Exception as e:
        st.exception(e)
        st.stop()
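Note that rail_rag/retriever.py is not included in this diff, so the exact query that MongoAtlasRetriever runs is not visible here. The snippet below is a minimal sketch of the kind of Atlas $vectorSearch call it presumably wraps; the function name, the numCandidates heuristic, and the score projection are illustrative assumptions rather than the PR's actual code, while the index and field names follow config.py.

from langchain_core.documents import Document

def vector_search(collection, embedder, query: str, k: int = 4):
    # Hypothetical sketch of a $vectorSearch retrieval step (not the PR's retriever.py).
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",           # index name created by ingest_rulebook.py
                "path": "content_vector",          # VEC_KEY from config.py
                "queryVector": embedder.embed_query(query),
                "numCandidates": max(50, k * 10),  # assumed over-fetch factor
                "limit": k,
            }
        },
        {"$project": {"_id": 0, "content": 1, "source": 1, "page": 1,
                      "score": {"$meta": "vectorSearchScore"}}},
    ]
    return [
        Document(
            page_content=doc.get("content", ""),
            metadata={"source": doc.get("source"), "page": doc.get("page"), "score": doc.get("score")},
        )
        for doc in collection.aggregate(pipeline)
    ]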
214 changes: 214 additions & 0 deletions apps/railwaycopilot/backend/ingest_rulebook.py
@@ -0,0 +1,214 @@
import os
from glob import glob
from typing import List, Dict
import numpy as np
import requests
import certifi

from pymongo import MongoClient
from pymongo.errors import OperationFailure
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader

MONGODB_URI = os.getenv("MONGODB_URI")
DB_NAME = os.getenv("DB_NAME", "rail_ops")
COLL_NAME = os.getenv("COLLECTION_NAME", "rulebook_chunks")

EMBED_MODEL = os.getenv("MISTRAL_EMBED_MODEL", "mistral-embed")
EMBED_DIM = 1024
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")

CORPUS_DIR = os.getenv("CORPUS_DIR", "corpus")

# Field names
TEXT_KEY = "content"
VEC_KEY = "content_vector"
SRC_KEY = "source"
PAGE_KEY = "page"


class SimpleMistralEmbedder:
    def __init__(self, model: str, api_key: str):
        if not api_key:
            raise RuntimeError("Missing MISTRAL_API_KEY in environment.")
        self.model = model
        self.api_key = api_key
        self.url = "https://api.mistral.ai/v1/embeddings"
        self.session = requests.Session()
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        payload = {"model": self.model, "input": texts}
        r = self.session.post(self.url, headers=self.headers, json=payload, timeout=60)
        try:
            j = r.json()
        except Exception:
            raise RuntimeError(f"Embeddings HTTP {r.status_code}: {r.text[:500]}")
        # Accept multiple shapes to be version-tolerant
        if "data" in j and isinstance(j["data"], list):
            return [item["embedding"] for item in j["data"]]
        if "embeddings" in j and isinstance(j["embeddings"], list):
            return j["embeddings"]
        if "error" in j:
            raise RuntimeError(f"Mistral embeddings error: {j['error']}")
        raise RuntimeError(f"Unexpected embeddings response shape: {str(j)[:500]}")

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]
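
# Editorial note (not part of the PR): embed_documents() above sends every chunk in a
# single request, which can exceed request-size limits for a large rulebook corpus.
# One hedged alternative is to batch the calls; the helper below is a sketch under the
# assumption that a batch of 64 texts fits comfortably in one request.
def embed_documents_batched(embedder: SimpleMistralEmbedder, texts: List[str],
                            batch_size: int = 64) -> List[List[float]]:
    vectors: List[List[float]] = []
    for i in range(0, len(texts), batch_size):
        vectors.extend(embedder.embed_documents(texts[i:i + batch_size]))
    return vectors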

# ---------------------------------------------------------------------
# Data loading & chunking
# ---------------------------------------------------------------------
def load_docs(corpus_dir: str):
    docs = []
    for p in glob(os.path.join(corpus_dir, "*.pdf")):
        for d in PyPDFLoader(p).load():
            d.metadata[SRC_KEY] = os.path.basename(p)
            d.metadata[PAGE_KEY] = d.metadata.get("page")
            docs.append(d)
    return docs

def chunk_docs(docs):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=120,
        add_start_index=True,
    )
    return splitter.split_documents(docs)

# ---------------------------------------------------------------------
# Build Mongo-ready documents
# ---------------------------------------------------------------------
def build_records(chunks, embedder: SimpleMistralEmbedder) -> List[Dict]:
    texts = [c.page_content or "" for c in chunks]
    vectors = embedder.embed_documents(texts)
    recs = []
    for i, c in enumerate(chunks):
        vec = vectors[i]
        # Ensure correct dtype/dim for safety
        if len(vec) != EMBED_DIM:
            raise ValueError(f"Unexpected embedding dim {len(vec)} (expected {EMBED_DIM})")
        # Mongo expects an array of numbers
        vec = [float(x) for x in vec]

        rec = {
            TEXT_KEY: c.page_content or "",
            VEC_KEY: vec,
            SRC_KEY: c.metadata.get(SRC_KEY),
        }
        page_val = c.metadata.get(PAGE_KEY)
        if page_val is not None:
            try:
                rec[PAGE_KEY] = int(page_val)
            except Exception:
                rec[PAGE_KEY] = -1
        recs.append(rec)
    return recs

# ---------------------------------------------------------------------
# Ensure Atlas Vector Search index exists (vectorSearch)
# ---------------------------------------------------------------------
def ensure_vector_index(coll, index_name="vector_index"):
    """
    Creates a Vector Search index on content_vector if it doesn't already exist.
    """

    print(f"[info] Checking existing search indexes on {coll.full_name}…")

    existing = []
    try:
        existing = list(coll.aggregate([{"$listSearchIndexes": {}}]))
    except OperationFailure as e:
        print(f"[warn] $listSearchIndexes not supported or failed: {e}")
    except Exception as e:
        print(f"[warn] Unexpected error listing search indexes: {e}")

    for idx in existing:
        if idx.get("name") == index_name:
            print(f"[info] Search index '{index_name}' already exists.")
            return

    print(f"[info] Creating VECTOR SEARCH index '{index_name}'…")

    definition = {
        "name": index_name,
        "type": "vectorSearch",
        "definition": {
            "fields": [
                {
                    "type": "vector",
                    "path": "content_vector",
                    "numDimensions": 1024,
                    "similarity": "cosine",
                },
                {
                    "type": "filter",
                    "path": "source",
                },
                {
                    "type": "filter",
                    "path": "page",
                },
            ]
        },
    }

    try:
        result = coll.database.command({
            "createSearchIndexes": coll.name,
            "indexes": [definition],
        })
        print(f"[info] createSearchIndexes result: {result}")
    except Exception as e:
        print(f"[error] Failed to create search index '{index_name}': {e}")

# ---------------------------------------------------------------------
def main():
    if not MISTRAL_API_KEY:
        raise SystemExit("Missing MISTRAL_API_KEY in environment!")
    if not MONGODB_URI:
        raise SystemExit("Missing MONGODB_URI in environment!")

    # 1) Load & chunk PDFs
    docs = load_docs(CORPUS_DIR)
    if not docs:
        raise SystemExit(f"No PDFs found in '{CORPUS_DIR}'")
    chunks = chunk_docs(docs)

    # 2) Embed
    embedder = SimpleMistralEmbedder(model=EMBED_MODEL, api_key=MISTRAL_API_KEY)
    records = build_records(chunks, embedder)

    # 3) Connect to MongoDB Atlas
    client = MongoClient(MONGODB_URI, tlsCAFile=certifi.where())
    db = client[DB_NAME]
    coll = db[COLL_NAME]

    # 4) Create / ensure vector index
    try:
        ensure_vector_index(coll, index_name="vector_index")
    except Exception as e:
        # If running locally (no Atlas) or on an older server this may fail; ingestion can still proceed.
        print(f"[warn] Could not ensure vector index now: {e}")

    # 5) Fresh load: optional cleanup for a clean re-ingest
    if os.getenv("FRESH_LOAD", "true").lower() in ("1", "true", "yes"):
        coll.delete_many({})

    # 6) Insert records
    if records:
        # Insert in batches
        BATCH = 500
        for i in range(0, len(records), BATCH):
            coll.insert_many(records[i:i + BATCH])
        print(f"[✅] Ingested {len(records)} chunks into '{DB_NAME}.{COLL_NAME}'")
    else:
        print("[ℹ️] No records to insert.")

if __name__ == "__main__":
    main()
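A quick way to sanity-check an ingest run is to count the inserted chunks and spot-check one record. A minimal sketch, assuming the same environment variables and field names as above:

# Hypothetical post-ingest check (not part of this PR).
import os
import certifi
from pymongo import MongoClient

client = MongoClient(os.getenv("MONGODB_URI"), tlsCAFile=certifi.where())
coll = client[os.getenv("DB_NAME", "rail_ops")][os.getenv("COLLECTION_NAME", "rulebook_chunks")]
print("chunks:", coll.count_documents({}))
print("sample:", coll.find_one({}, {"content": 1, "source": 1, "page": 1, "_id": 0}))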
41 changes: 41 additions & 0 deletions apps/railwaycopilot/backend/rail_rag/classifier.py
@@ -0,0 +1,41 @@
from langchain_mistralai import ChatMistralAI
import json

CLASSIFIER_SYSTEM_PROMPT = """You are a classification assistant for rail operations and safety.
Classify the input into one of these intents:
- informational
- procedural
- compliance
- safety_critical
- other
Respond ONLY in JSON like:
{"intent": "..."}.
"""

llm_classifier = ChatMistralAI(model="mistral-small-latest", temperature=0.0)

def classify_text(text: str) -> dict:
    messages = [
        ("system", CLASSIFIER_SYSTEM_PROMPT),
        ("human", text),
    ]
    result = llm_classifier.invoke(messages)
    raw = result.content.strip()

    # Try to parse JSON; if it fails, fall back to dict with string
    try:
        parsed = json.loads(raw)
        if isinstance(parsed, dict):
            return parsed
        else:
            return {"intent": str(parsed)}
    except Exception:
        # fallback: sometimes the LLM returns plain text or partial JSON
        if raw.startswith("{") and raw.endswith("}"):
            # slightly malformed JSON, try to clean quotes
            raw = raw.replace("'", '"')
            try:
                return json.loads(raw)
            except Exception:
                pass
        return {"intent": raw}
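classify_text is imported in app.py but not yet wired into the question flow. A minimal usage sketch follows; the returned intent is only an illustration of the expected shape, not a guaranteed model output.

# Hypothetical usage (requires MISTRAL_API_KEY in the environment).
result = classify_text("What must a signaller do when going off duty?")
print(result)  # e.g. {"intent": "procedural"}; the exact intent depends on the model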
22 changes: 22 additions & 0 deletions apps/railwaycopilot/backend/rail_rag/config.py
@@ -0,0 +1,22 @@
import os

MONGODB_URI = os.getenv("MONGODB_URI")
MONGO_DB_NAME = os.getenv("DB_NAME", "rail_ops")
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME", "rulebook_chunks")
VECTOR_INDEX_NAME = os.getenv("VECTOR_INDEX_NAME", "vector_index")

# --- Mistral models ---
EMBED_MODEL = os.getenv("MISTRAL_EMBED_MODEL", "mistral-embed")
CHAT_MODEL = os.getenv("MISTRAL_CHAT_MODEL", "mistral-small-latest")

# --- Field names ---
TEXT_KEY = "content"
VEC_KEY = "content_vector"
SRC_KEY = "source"
PAGE_KEY = "page"

SYSTEM_PROMPT = """You are a Rail Operations & Safety assistant.
Answer ONLY using the provided context.
If the answer is not in the context, say “I don’t have that in the documents.”
Cite sources as (filename p.page). Be concise and correct. Do not reveal internal reasoning steps.
"""