Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/llama_stack/providers/inline/tool_runtime/rag/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from llama_stack.providers.utils.memory.vector_store import parse_data_url, read_file_uri
from llama_stack_api import (
URL,
Files,
Expand Down Expand Up @@ -60,12 +60,14 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
file_data = data.encode("utf-8")

return file_data, mime_type
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
if doc.content.uri.startswith("file://"):
content, guessed_mime = await read_file_uri(doc.content.uri)
return content, guessed_mime or "application/octet-stream"
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
else:
if isinstance(doc.content, str):
content_str = doc.content
Expand Down
42 changes: 41 additions & 1 deletion src/llama_stack/providers/utils/memory/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
import mimetypes
import os
import re
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.parse import unquote
from urllib.parse import unquote, urlparse

import httpx
import numpy as np
Expand Down Expand Up @@ -54,6 +58,9 @@ class ChunkForDeletion(BaseModel):
RERANKER_TYPE_WEIGHTED = "weighted"
RERANKER_TYPE_NORMALIZED = "normalized"

# Upper bound, in bytes, on a file read through a file:// URI. The limit is
# given in megabytes by LLAMA_STACK_MAX_FILE_URI_SIZE_MB and is 100 when unset.
_max_file_uri_size_mb = int(os.getenv("LLAMA_STACK_MAX_FILE_URI_SIZE_MB", "100"))
MAX_FILE_URI_SIZE_BYTES = _max_file_uri_size_mb * 1024 * 1024


def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
Expand Down Expand Up @@ -131,10 +138,37 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
return ""


async def read_file_uri(uri: str, max_size: int | None = None) -> tuple[bytes, str | None]:
parsed = urlparse(uri)
file_path = unquote(parsed.path)
real_path = os.path.realpath(file_path)
filename = os.path.basename(real_path)

if os.path.isdir(real_path):
raise IsADirectoryError(f"Cannot read directory: {filename}")
if not os.path.isfile(real_path):
raise FileNotFoundError(f"File not found: {filename}")

file_size = os.path.getsize(real_path)
size_limit = max_size if max_size is not None else MAX_FILE_URI_SIZE_BYTES
if file_size > size_limit:
raise ValueError(f"File too large: {file_size} bytes exceeds limit of {size_limit} bytes")

content = await asyncio.to_thread(Path(real_path).read_bytes)
mime_type, _ = mimetypes.guess_type(real_path)
return content, mime_type


async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
if doc.content.uri.startswith("file://"):
content, guessed_mime = await read_file_uri(doc.content.uri)
mime = doc.mime_type or guessed_mime
if mime == "application/pdf":
return parse_pdf(content)
return content.decode("utf-8")
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
Expand All @@ -145,6 +179,12 @@ async def content_from_doc(doc: RAGDocument) -> str:
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
if doc.content.startswith("file://"):
content, guessed_mime = await read_file_uri(doc.content)
mime = doc.mime_type or guessed_mime
if mime == "application/pdf":
return parse_pdf(content)
return content.decode("utf-8")
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
Expand Down
68 changes: 67 additions & 1 deletion tests/unit/providers/utils/memory/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
from llama_stack.providers.utils.memory.vector_store import (
content_from_data_and_mime_type,
content_from_doc,
read_file_uri,
)
from llama_stack_api import URL, RAGDocument, TextContentItem


Expand Down Expand Up @@ -215,3 +221,63 @@ async def test_memory_tool_error_handling():
# processed 2 documents successfully, skipped 1
assert memory_tool.files_api.openai_upload_file.call_count == 2
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2


class TestReadFileUri:
    """Exercises ``read_file_uri`` against real temporary files and directories."""

    @staticmethod
    def _make_temp_file(payload, mode, **named_temp_kwargs):
        """Write *payload* to a persistent ``.txt`` temp file and return its path."""
        with tempfile.NamedTemporaryFile(mode=mode, suffix=".txt", delete=False, **named_temp_kwargs) as handle:
            handle.write(payload)
            return handle.name

    async def test_read_file_uri_basic(self):
        """A plain text file is returned as bytes with a ``text/plain`` MIME type."""
        path = self._make_temp_file("Hello from file URI!", "w")
        try:
            data, guessed = await read_file_uri(f"file://{path}")
            assert data == b"Hello from file URI!"
            assert guessed == "text/plain"
        finally:
            os.unlink(path)

    async def test_read_file_uri_with_spaces(self):
        """Percent-encoded spaces in the URI are decoded before the file is opened."""
        path = self._make_temp_file("Content with spaces in path", "w", prefix="test file ")
        try:
            uri = "file://" + path.replace(" ", "%20")
            data, _ = await read_file_uri(uri)
            assert data == b"Content with spaces in path"
        finally:
            os.unlink(path)

    async def test_read_file_uri_file_not_found(self):
        """A missing file raises FileNotFoundError naming only the file, not its directory."""
        with pytest.raises(FileNotFoundError) as caught:
            await read_file_uri("file:///nonexistent/path/to/secret/file.txt")

        message = str(caught.value)
        assert "file.txt" in message
        assert "/nonexistent/path/to/secret" not in message

    async def test_read_file_uri_directory_error(self):
        """Pointing the URI at a directory raises IsADirectoryError naming that directory."""
        with tempfile.TemporaryDirectory() as directory:
            with pytest.raises(IsADirectoryError) as caught:
                await read_file_uri(f"file://{directory}")

            message = str(caught.value)
            assert os.path.basename(directory) in message

    async def test_read_file_uri_size_limit(self):
        """A file larger than ``max_size`` is rejected with both sizes in the message."""
        path = self._make_temp_file(b"x" * 1000, "wb")  # 1000-byte file
        try:
            with pytest.raises(ValueError) as caught:
                await read_file_uri(f"file://{path}", max_size=500)

            message = str(caught.value)
            assert "too large" in message
            assert "1000 bytes" in message
            assert "500 bytes" in message
        finally:
            os.unlink(path)
Loading