Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/common/providers/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,24 @@ def connect(cls):
host, port = cls.get_host_port()
if not verify_connection(host, port):
raise ConnectionError(f"Ray is not listening on {host}:{port}")

logger.info(f"Connecting to Ray at {cls.ray_url}...")
Copy link
Member

@MichaelRipa MichaelRipa Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logging is useful for me, so I'm adding it back in

ray.init(logging_level="error", address=cls.ray_url)
logger.info("Connected to Ray")

@classmethod
def connected(cls) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just so you are aware, I profiled this method and the added ray.get_actor() is significantly slower (high variance, but the slowest was just over 5 seconds).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which is fine, because connected is only called if there's some critical error in the first place. I think it's good to know you're truly connected if you can access the controller. It also means that if you run a test script right after bringing it up, you won't get the error saying the controller does not exist.

return ray.is_initialized() and cls.is_listening()

connected = ray.is_initialized() and cls.is_listening()

if connected:

try:
ray.get_actor("Controller", namespace="NDIF")
except:
return False
else:
return True

return False


@classmethod
def reset(cls):
Expand Down
7 changes: 0 additions & 7 deletions src/common/providers/socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def to_env(cls) -> dict:
@classmethod
@retry
def connect(cls):
logger.info(f"Connecting to API at {cls.api_url}...")
if cls.sio is None:
logger.debug("Creating new socketio client")
cls.sio = socketio.SimpleClient(reconnection_attempts=10)
Expand All @@ -42,17 +41,11 @@ def connect(cls):
)
# Wait for connection to be fully established
time.sleep(0.1)
logger.info("Connected to API")

@classmethod
def disconnect(cls):
logger.debug("SioProvider.disconnect() called")
if cls.sio is not None:
logger.info("Disconnecting socketio client")
cls.sio.disconnect()
logger.debug("Socketio client disconnected")
else:
logger.debug("No socketio client to disconnect")

@classmethod
def connected(cls) -> bool:
Expand Down
3 changes: 3 additions & 0 deletions src/common/schema/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class ObjectStorageMixin(BaseModel):
@classmethod
def object_name(cls, id: str):
return f"{id}.{cls._file_extension}"

def _url(self, client: boto3.client, expires_in: int = 3600 * 6) -> str:
    """Return a presigned ``get_object`` URL for this object's stored payload.

    Args:
        client: Client whose ``generate_presigned_url`` is used to sign the URL.
        expires_in: Lifetime of the URL in seconds. Defaults to six hours,
            the value that was previously hard-coded.

    Returns:
        The URL returned by ``client.generate_presigned_url``.
    """
    return client.generate_presigned_url(
        'get_object',
        Params={
            'Bucket': self._bucket_name,
            'Key': self.object_name(self.id),
        },
        ExpiresIn=expires_in,
    )

def _save(self, client: boto3.client, data: BytesIO, content_type: str, bucket_name: str = None) -> None:
bucket_name = self._bucket_name if bucket_name is None else bucket_name
Expand Down
11 changes: 7 additions & 4 deletions src/common/schema/result.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import ClassVar
from pydantic import ConfigDict
from pydantic import ConfigDict
from .mixins import ObjectStorageMixin

from ..providers.objectstore import ObjectStoreProvider
class BackendResultModel(ObjectStorageMixin):

model_config = ConfigDict(extra='allow')
model_config = ConfigDict(extra='allow', validate_assignment=False, frozen=False, arbitrary_types_allowed=True, str_strip_whitespace=False, strict=False)

_bucket_name: ClassVar[str] = "dev-ndif-results"
_file_extension: ClassVar[str] = "pt"

def url(self) -> str:
    """Return the URL for this result, signed with the shared object store client.

    Delegates to ``_url`` with ``ObjectStoreProvider.object_store`` as the client.
    """
    store_client = ObjectStoreProvider.object_store
    return self._url(store_client)

108 changes: 16 additions & 92 deletions src/services/api/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,22 @@
import redis
import socketio
import uvicorn
from fastapi import BackgroundTasks, Depends, FastAPI, Request
from fastapi import BackgroundTasks, Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from fastapi_cache.decorator import cache
from fastapi_socketio import SocketManager
from nnsight.schema.response import ResponseModel
from prometheus_fastapi_instrumentator import Instrumentator

from .types import REQUEST_ID, SESSION_ID
from nnsight.schema.response import ResponseModel

from .logging import set_logger
from .types import REQUEST_ID, SESSION_ID

logger = set_logger("API")

from .dependencies import validate_request
from .metrics import NetworkStatusMetric
from .providers.objectstore import ObjectStoreProvider
from .schema import (BackendRequestModel, BackendResponseModel,
BackendResultModel)


from .schema import BackendRequestModel, BackendResponseModel

# Init FastAPI app
app = FastAPI()
Expand Down Expand Up @@ -64,7 +59,7 @@
@app.post("/request")
async def request(
background_tasks: BackgroundTasks,
backend_request: BackendRequestModel = Depends(validate_request)
backend_request: BackendRequestModel = Depends(validate_request),
) -> BackendResponseModel:
"""Endpoint to submit request. See src/common/schema/request.py to see the headers and data that are validated and populated.

Expand Down Expand Up @@ -108,36 +103,6 @@ async def request(
return response


# @app.delete("/request/{request_id}")
# async def delete_request(request_id: str):
# """Delete a submitted request, provided it is either queued or running"""
# try:
# endpoint = f"http://{os.environ.get('QUEUE_URL')}/queue/{request_id}"
# async with httpx.AsyncClient() as client:
# response = await client.delete(endpoint)
# response.raise_for_status()
# return {"message": f"Request {request_id} successfully submitted for deletion!"}
# except httpx.HTTPStatusError as e:
# # Handle HTTP errors from the queue service
# if e.response is not None and e.response.status_code == 404:
# raise HTTPException(status_code=404, detail=f"Request {request_id} not found")
# elif e.response is not None and e.response.status_code == 500:
# # Try to extract the error message from the queue service
# try:
# error_detail = e.response.json().get('detail', str(e))
# except:
# error_detail = str(e)
# raise HTTPException(status_code=500, detail=f"Failed to delete request: {error_detail}")
# else:
# status = e.response.status_code if e.response is not None else 500
# raise HTTPException(status_code=status, detail=str(e))
# except httpx.RequestError as e:
# # Handle connection errors, timeouts, etc.
# raise HTTPException(status_code=503, detail=f"Queue service unavailable: {e}")
# except Exception as e:
# raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")


@sm.on("connect")
async def connect(session_id: SESSION_ID, environ: Dict):
params = environ.get("QUERY_STRING")
Expand All @@ -149,14 +114,18 @@ async def connect(session_id: SESSION_ID, environ: Dict):


@sm.on("blocking_response")
async def blocking_response(session_id: SESSION_ID, client_session_id: SESSION_ID, data: Any):
async def blocking_response(
session_id: SESSION_ID, client_session_id: SESSION_ID, data: Any
):

await sm.emit("blocking_response", data=data, to=client_session_id)


@sm.on("stream")
async def stream(session_id: SESSION_ID, client_session_id: SESSION_ID, data: bytes, job_id: str):

async def stream(
session_id: SESSION_ID, client_session_id: SESSION_ID, data: bytes, job_id: str
):

await sm.enter_room(session_id, job_id)

await blocking_response(session_id, client_session_id, data)
Expand All @@ -183,65 +152,20 @@ async def response(id: REQUEST_ID) -> BackendResponseModel:
return BackendResponseModel.load(ObjectStoreProvider.object_store, id)


@app.get("/result/{id}")
async def result(id: REQUEST_ID) -> BackendResultModel:
    """Stream the stored result for ``id`` back to the client.

    The result is loaded from the object store in streaming mode and sent in
    8192-byte chunks. When the stream finishes (or the generator is closed),
    the stored result, response and request objects for ``id`` are deleted.

    Args:
        id: ID of request/response.

    Returns:
        StreamingResponse: ``application/octet-stream`` body carrying the
        result bytes, with the total size reported in a length header.
    """

    # Open a readable handle on the stored bytes and learn their total size.
    body, total_bytes = BackendResultModel.load(
        ObjectStoreProvider.object_store, id, stream=True
    )

    def iter_chunks():
        try:
            chunk = body.read(8192)
            while chunk:
                yield chunk
                chunk = body.read(8192)
        finally:
            body.close()

            # Clean up everything stored under this id once streaming ends.
            BackendResultModel.delete(ObjectStoreProvider.object_store, id)
            BackendResponseModel.delete(ObjectStoreProvider.object_store, id)
            BackendRequestModel.delete(ObjectStoreProvider.object_store, id)

    # Inform the client of the total size of the result in bytes.
    return StreamingResponse(
        content=iter_chunks(),
        media_type="application/octet-stream",
        headers={"Content-length": str(total_bytes)},
    )


@app.get("/ping", status_code=200)
async def ping():
"""Endpoint to check if the server is online.
"""
"""Endpoint to check if the server is online."""
return "pong"


@app.get("/status", status_code=200)
async def status():
    """Request a status report through Redis and return the unpickled reply.

    Pushes this process's PID onto the ``status`` list, then blocks until a
    value can be popped from the list named after that PID.
    """
    # The PID string doubles as the name of the list the reply is read from.
    reply_key = str(os.getpid())

    await redis_client.lpush("status", reply_key)
    reply = await redis_client.brpop(reply_key)

    # NOTE(review): pickle.loads trusts whatever was placed on the Redis list;
    # confirm that only trusted services can write to it.
    return pickle.loads(reply[1])



if __name__ == "__main__":
Expand Down
10 changes: 9 additions & 1 deletion src/services/api/src/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ async def validate_python_version(python_version: str) -> str:
Raises:
HTTPException: If the Python version is missing or incompatible.
"""

if DEV_MODE:
return python_version

server_python_version = '.'.join(sys.version.split('.')[0:2]) # e.g. 3.12
user_python_version = '.'.join(python_version.split('.')[0:2])

Expand Down Expand Up @@ -85,6 +89,10 @@ async def validate_nnsight_version(nnsight_version: str) -> str:
Raises:
HTTPException: If the nnsight version is missing or incompatible.
"""

if DEV_MODE:
return nnsight_version

if nnsight_version == '':
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
Expand Down Expand Up @@ -143,7 +151,7 @@ async def validate_request(raw_request: Request) -> BackendRequestModel:
nnsight_version = raw_request.headers.get("nnsight-version", "")
python_version = raw_request.headers.get("python-version", "")

# Validate using existing dependency functions (call them directly, not as dependencies)
# Validate using existing dependency functions (call them directly, not as dependencies)
await authenticate_api_key(api_key)
await validate_nnsight_version(nnsight_version)
await validate_python_version(python_version)
Expand Down
4 changes: 2 additions & 2 deletions src/services/api/src/gunicorn.conf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from multiprocessing import Process
from src.queue.coordinator import Coordinator
from src.queue.dispatcher import Dispatcher

def on_starting(server):

Process(target=Coordinator.start, daemon=False).start()
Process(target=Dispatcher.start, daemon=False).start()
Loading