Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/llama_stack/providers/inline/tool_runtime/rag/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import io
import mimetypes
from typing import Any
from urllib.parse import urlparse

import httpx
from fastapi import UploadFile
from pydantic import TypeAdapter

from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from llama_stack.providers.utils.memory.vector_store import ALLOW_FILE_URI, parse_data_url, read_file_uri
from llama_stack_api import (
URL,
Files,
Expand Down Expand Up @@ -49,7 +50,9 @@
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parsed = urlparse(doc.content.uri)

if parsed.scheme == "data":
parts = parse_data_url(doc.content.uri)
mime_type = parts["mimetype"]
data = parts["data"]
Expand All @@ -60,12 +63,22 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
file_data = data.encode("utf-8")

return file_data, mime_type
else:
elif parsed.scheme == "file":
if not ALLOW_FILE_URI:
log.warning(
f"Attempt to use file:// URI blocked. LLAMA_STACK_ALLOW_FILE_URI is not set. URI: {doc.content.uri}"
)
raise ValueError("file:// URIs are not allowed. Please use the Files API (/v1/files) to upload files.")
content, guessed_mime = await read_file_uri(doc.content.uri)
return content, guessed_mime or "application/octet-stream"
elif parsed.scheme in ("http", "https"):
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
else:
raise ValueError(f"Unsupported URI scheme: {parsed.scheme}")
else:
if isinstance(doc.content, str):
content_str = doc.content
Expand Down
77 changes: 60 additions & 17 deletions src/llama_stack/providers/utils/memory/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
import mimetypes
import os
import re
import stat
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.parse import unquote
from urllib.parse import unquote, urlparse

import httpx
import numpy as np
Expand Down Expand Up @@ -54,6 +59,10 @@ class ChunkForDeletion(BaseModel):
RERANKER_TYPE_WEIGHTED = "weighted"
RERANKER_TYPE_NORMALIZED = "normalized"

# Maximum file size for file:// URIs (default 100MB, configurable via env)
# Evaluated once at import time; a non-integer LLAMA_STACK_MAX_FILE_URI_SIZE_MB
# makes int() raise ValueError while this module is being imported.
MAX_FILE_URI_SIZE_BYTES = int(os.environ.get("LLAMA_STACK_MAX_FILE_URI_SIZE_MB", "100")) * 1024 * 1024
# Opt-in gate for reading file:// URIs. Also evaluated once at import time; it is
# True only when LLAMA_STACK_ALLOW_FILE_URI is "true", "1" or "yes" (case-insensitive)
# and False when the variable is unset.
ALLOW_FILE_URI = os.environ.get("LLAMA_STACK_ALLOW_FILE_URI", "false").lower() in ("true", "1", "yes")


def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
Expand Down Expand Up @@ -131,29 +140,63 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
return ""


async def read_file_uri(uri: str, max_size: int | None = None) -> tuple[bytes, str | None]:
parsed = urlparse(uri)
file_path = unquote(parsed.path)
real_path = os.path.realpath(file_path)
filename = os.path.basename(real_path)

try:
file_stat = os.stat(real_path)
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {filename}") from None

if stat.S_ISDIR(file_stat.st_mode):
raise IsADirectoryError(f"Cannot read directory: {filename}")
if not stat.S_ISREG(file_stat.st_mode):
raise ValueError(f"Not a regular file: {filename}")

file_size = file_stat.st_size
size_limit = max_size if max_size is not None else MAX_FILE_URI_SIZE_BYTES
if file_size > size_limit:
raise ValueError(f"File too large: {file_size} bytes exceeds limit of {size_limit} bytes")

content = await asyncio.to_thread(Path(real_path).read_bytes)
mime_type, _ = mimetypes.guess_type(real_path)
return content, mime_type


async def content_from_doc(doc: RAGDocument) -> str:
    """Return the text content of a RAGDocument.

    ``doc.content`` may be a ``URL`` object, a plain string, or interleaved
    content. URL objects and strings are inspected for a ``data:``, ``file://``
    or ``http(s)://`` prefix and resolved accordingly; a string with none of
    those prefixes is returned unchanged as literal text.

    Raises:
        ValueError: If a ``file://`` URI is used while ``ALLOW_FILE_URI`` is
            off, or if a ``URL`` object carries an unsupported scheme.
        httpx.HTTPStatusError: If an HTTP(S) fetch returns an error status.
    """
    if isinstance(doc.content, URL):
        uri = doc.content.uri
    elif isinstance(doc.content, str):
        uri = doc.content
    else:
        # Neither URL nor str: delegate to the interleaved-content converter.
        return interleaved_content_as_str(doc.content)

    if uri.startswith("data:"):
        return content_from_data(uri)

    if uri.startswith("file://"):
        if not ALLOW_FILE_URI:
            raise ValueError(
                "file:// URIs disabled. Use Files API (/v1/files) instead, or set LLAMA_STACK_ALLOW_FILE_URI=true."
            )
        content, guessed_mime = await read_file_uri(uri)
        # An explicit mime type on the document wins over the filename-based guess.
        mime = doc.mime_type or guessed_mime
        return parse_pdf(content) if mime == "application/pdf" else content.decode("utf-8")

    if uri.startswith(("http://", "https://")):
        async with httpx.AsyncClient() as client:
            r = await client.get(uri)
        # Fail on 4xx/5xx instead of returning an error page as document text;
        # this matches the HTTP branch of raw_data_from_doc.
        r.raise_for_status()
        if doc.mime_type == "application/pdf":
            return parse_pdf(r.content)
        return r.text

    # A plain string without a recognised prefix is literal document text.
    if isinstance(doc.content, str):
        return doc.content

    raise ValueError(f"Unsupported URL scheme: {uri}")


def make_overlapped_chunks(
Expand Down
139 changes: 138 additions & 1 deletion tests/unit/providers/utils/memory/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
from llama_stack.providers.utils.memory.vector_store import (
content_from_data_and_mime_type,
content_from_doc,
read_file_uri,
)
from llama_stack_api import URL, RAGDocument, TextContentItem


Expand Down Expand Up @@ -215,3 +221,134 @@ async def test_memory_tool_error_handling():
# processed 2 documents successfully, skipped 1
assert memory_tool.files_api.openai_upload_file.call_count == 2
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2


@pytest.mark.parametrize(
    "content_factory",
    [
        lambda uri: uri,  # string content
        lambda uri: URL(uri=uri),  # URL object content
    ],
    ids=["string_content", "url_object_content"],
)
async def test_file_uri_rejected_by_default(monkeypatch, content_factory):
    """Test that file:// URIs are rejected when the ALLOW_FILE_URI gate is off (both string and URL object)."""
    from llama_stack.providers.utils.memory import vector_store

    # Patch the module attribute instead of reloading the module: a reload
    # rebinds every module-level object and is never undone, so it leaks state
    # into other tests. monkeypatch restores the attribute at teardown.
    monkeypatch.setattr(vector_store, "ALLOW_FILE_URI", False)

    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
        f.write("Test content")
        temp_path = f.name

    try:
        file_uri = f"file://{temp_path}"
        mock_doc = RAGDocument(document_id="test", content=content_factory(file_uri))

        with pytest.raises(ValueError) as exc_info:
            await vector_store.content_from_doc(mock_doc)

        # The message must point users at the supported alternatives.
        error_message = str(exc_info.value)
        assert "file:// URIs disabled" in error_message
        assert "Files API" in error_message
        assert "/v1/files" in error_message
        assert "LLAMA_STACK_ALLOW_FILE_URI" in error_message
    finally:
        os.unlink(temp_path)


async def test_file_uri_allowed_when_enabled(monkeypatch, tmp_path):
    """Test that file:// URIs work when the ALLOW_FILE_URI gate is on."""
    from llama_stack.providers.utils.memory import vector_store

    # Patch the module attribute instead of setting the env var and reloading:
    # the reload is never undone, so it would leave file:// access enabled for
    # every test that runs afterwards. monkeypatch restores the attribute at teardown.
    monkeypatch.setattr(vector_store, "ALLOW_FILE_URI", True)

    # Create a temp file
    test_file = tmp_path / "test.txt"
    test_file.write_text("Hello from allowed file URI!")

    file_uri = f"file://{test_file}"
    mock_doc = RAGDocument(document_id="test", content=file_uri)

    result = await vector_store.content_from_doc(mock_doc)
    assert result == "Hello from allowed file URI!"


class TestReadFileUri:
    """Unit tests for read_file_uri.

    read_file_uri does not consult the ALLOW_FILE_URI gate, so these tests need
    no environment setup. Files are created under pytest's tmp_path, which is
    cleaned up automatically.
    """

    async def test_read_file_uri_basic(self, tmp_path):
        """Content is returned as bytes and the mime type is guessed from the suffix."""
        target = tmp_path / "basic.txt"
        target.write_text("Hello from file URI!")

        content, mime_type = await read_file_uri(f"file://{target}")

        assert content == b"Hello from file URI!"
        assert mime_type == "text/plain"

    async def test_read_file_uri_with_spaces(self, tmp_path):
        """Percent-encoded spaces in the URI path are decoded before the file is opened."""
        target = tmp_path / "test file with spaces.txt"
        target.write_text("Content with spaces in path")

        encoded_uri = f"file://{str(target).replace(' ', '%20')}"
        content, mime_type = await read_file_uri(encoded_uri)

        assert content == b"Content with spaces in path"

    async def test_read_file_uri_file_not_found(self):
        """Test FileNotFoundError with sanitized message (no path disclosure)."""
        with pytest.raises(FileNotFoundError) as exc_info:
            await read_file_uri("file:///nonexistent/path/to/secret/file.txt")

        error_message = str(exc_info.value)
        assert "file.txt" in error_message
        assert "/nonexistent/path/to/secret" not in error_message

    async def test_read_file_uri_directory_error(self, tmp_path):
        """A directory target raises IsADirectoryError naming only the directory's base name."""
        target_dir = tmp_path / "a_directory"
        target_dir.mkdir()

        with pytest.raises(IsADirectoryError) as exc_info:
            await read_file_uri(f"file://{target_dir}")

        error_message = str(exc_info.value)
        assert target_dir.name in error_message

    async def test_read_file_uri_size_limit(self, tmp_path):
        """A file larger than max_size is rejected and the message reports both sizes."""
        target = tmp_path / "big.txt"
        target.write_bytes(b"x" * 1000)  # 1000-byte file

        with pytest.raises(ValueError) as exc_info:
            await read_file_uri(f"file://{target}", max_size=500)

        error_message = str(exc_info.value)
        assert "too large" in error_message
        assert "1000 bytes" in error_message
        assert "500 bytes" in error_message
Loading