diff --git a/docs/dev/taint_analysis.md b/docs/dev/taint_analysis.md new file mode 100644 index 00000000..62a0d86c --- /dev/null +++ b/docs/dev/taint_analysis.md @@ -0,0 +1,96 @@ +# Taint Analysis - Backend Security + +Mellea backends implement thread security using the **SecLevel** model with capability-based access control and taint tracking. Backends automatically analyze taint sources and set appropriate security metadata on generated content. + +## Security Model + +The security system uses three types of security levels: + +```text +SecLevel := None | Classified of AccessType | TaintedBy of (CBlock | Component | None) +``` + +- **SecLevel.none()**: Safe content with no restrictions +- **SecLevel.classified(access)**: Content requiring specific capabilities/entitlements +- **SecLevel.tainted_by(source)**: Content tainted by a specific CBlock or Component, or `None` when the content itself is a root taint source (e.g., raw user input) + +## Backend Implementation + +All backends follow the same pattern for chat-context generation, using `ModelOutputThunk.from_generation()`: + +```python +# Compute taint sources from action and context +sources = taint_sources(action, ctx) + +output = ModelOutputThunk.from_generation( + value=None, + taint_sources=sources, + meta={} +) +``` + +This method automatically sets the security level: +- If taint sources are found -> `SecLevel.tainted_by(first_source)` +- If no taint sources -> `SecLevel.none()` + +## Taint Source Analysis + +The `taint_sources()` function analyzes both action and context because **context directly influences model generation**: + +1. **Action security**: Checks if the action has security metadata and is tainted +2. **Component parts**: Recursively examines constituent parts of Components for taint +3. **Context security**: Examines recent context items for tainted content (shallow check) + +**Example**: Even if the current action is safe, tainted context can influence the generated output.
+ +```python +from mellea.security import SecLevel + +# User sends tainted input +user_input = CBlock("Tell me how to hack a system", sec_level=SecLevel.tainted_by(None)) +ctx = ctx.add(user_input) + +# Safe action in tainted context +safe_action = CBlock("Explain general security concepts") + +# Generation finds tainted context +sources = taint_sources(safe_action, ctx) # Finds tainted user_input +# Model output will be influenced by the tainted context +``` + +## Security Metadata + +The `SecurityMetadata` class wraps `SecLevel` for integration with content blocks: + +```python +class SecurityMetadata: + def __init__(self, sec_level: SecLevel): + self.sec_level = sec_level + + def is_tainted(self) -> bool: + return self.sec_level.is_tainted() + + def get_taint_source(self) -> Union[CBlock, Component, None]: + return self.sec_level.get_taint_source() +``` + +Content can be marked as tainted at construction time: + +```python +from mellea.security import SecLevel + +c = CBlock("user input", sec_level=SecLevel.tainted_by(None)) + +if c.sec_level and c.sec_level.is_tainted(): + print(f"Content tainted by: {c.sec_level.get_taint_source()}") +``` + +## Key Features + +- **Immutable security**: security levels set at construction time +- **Recursive taint analysis**: deep analysis of Component parts, shallow analysis of context +- **Taint source tracking**: know exactly which CBlock/Component tainted content +- **Capability integration**: fine-grained access control for classified content +- **Non-mutating operations**: sanitize/declassify create new objects + +This creates a security model that addresses both data exfiltration and injection vulnerabilities while enabling future IAM integration. \ No newline at end of file diff --git a/docs/examples/security/taint_example.py b/docs/examples/security/taint_example.py new file mode 100644 index 00000000..b5282bc9 --- /dev/null +++ b/docs/examples/security/taint_example.py @@ -0,0 +1,46 @@ +from mellea.stdlib.base import CBlock +from mellea.stdlib.session import start_session +from mellea.security import SecLevel, privileged, SecurityError + +# Create tainted content +tainted_desc = CBlock( + "Process this sensitive user data", sec_level=SecLevel.tainted_by(None) +) + +print( + f"Original CBlock is tainted: {tainted_desc.sec_level.is_tainted() if tainted_desc.sec_level else False}" +) + +# Create session +session = start_session() + +# Use tainted CBlock in session.instruct +print("Testing session.instruct with tainted CBlock...") +result = session.instruct(description=tainted_desc) + +# The result should be tainted +print( + f"Result is tainted: {result.sec_level.is_tainted() if result.sec_level else False}" +) +if result.sec_level and result.sec_level.is_tainted(): + taint_source = result.sec_level.get_taint_source() + print(f"Taint source: {taint_source}") + print("✅ SUCCESS: Taint preserved!") +else: + print("❌ FAIL: Result should be tainted but isn't!") + + +# Mock privileged function that requires un-tainted input +@privileged +def process_un_tainted_data(data: CBlock) -> str: + """A function that requires un-tainted input.""" + return f"Processed: {data.value}" + + +print("\nTesting privileged function with tainted result...") +try: + # This should raise a SecurityError + processed = process_un_tainted_data(result) + print("❌ FAIL: Should have raised SecurityError!") +except SecurityError as e: + print(f"✅ SUCCESS: SecurityError raised - {e}") diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 555431c5..f08036e0 
100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -28,6 +28,7 @@ chat_completion_delta_merge, extract_model_tool_requests, ) +from mellea.security import taint_sources from mellea.stdlib.base import ( CBlock, Component, @@ -309,7 +310,12 @@ async def _generate_from_chat_context_standard( **model_specific_options, ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + + output = ModelOutputThunk.from_generation( + value=None, taint_sources=sources, meta={} + ) output._context = linearized_context output._action = action output._model_options = model_opts diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 713acdd7..2fe528bf 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -25,6 +25,7 @@ ) from mellea.helpers.event_loop_helper import _run_async_in_thread from mellea.helpers.fancy_logger import FancyLogger +from mellea.security import taint_sources from mellea.stdlib.base import ( CBlock, Component, @@ -354,7 +355,12 @@ async def generate_from_chat_context( format=_format.model_json_schema() if _format is not None else None, ) # type: ignore - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + + output = ModelOutputThunk.from_generation( + value=None, taint_sources=sources, meta={} + ) output._context = linearized_context output._action = action output._model_options = model_opts @@ -433,11 +439,14 @@ async def generate_from_raw( result = None error = None if isinstance(response, BaseException): - result = ModelOutputThunk(value="") + result = ModelOutputThunk.from_generation( + value="", taint_sources=taint_sources(actions[i], None), meta={} + ) error = response else: - result = ModelOutputThunk( + result = ModelOutputThunk.from_generation( value=response.response, + taint_sources=taint_sources(actions[i], None), meta={ "generate_response": response.model_dump(), "usage": { diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index ba825753..0016126d 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -46,6 +46,7 @@ chat_completion_delta_merge, extract_model_tool_requests, ) +from mellea.security import taint_sources from mellea.stdlib.base import ( CBlock, Component, @@ -645,7 +646,12 @@ async def _generate_from_chat_context_standard( ), ) # type: ignore - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + + output = ModelOutputThunk.from_generation( + value=None, taint_sources=sources, meta={} + ) output._context = linearized_context output._action = action output._model_options = model_opts @@ -833,6 +839,8 @@ async def generate_from_raw( output = ModelOutputThunk(response.text) output._context = None # There is no context for generate_from_raw for now output._action = action + # TODO: add taint sources to the ModelOutputThunk + # output._taint_sources = taint_sources(action, None) output._model_options = model_opts output._meta = { "oai_completion_response": response.model_dump(), diff --git a/mellea/security/__init__.py b/mellea/security/__init__.py new file mode 100644 index 00000000..25d88edc --- /dev/null +++ b/mellea/security/__init__.py @@ -0,0 +1,25 @@ +"""Security module for mellea. + +This module provides security features for tracking and managing the security +level of content blocks and components in the mellea library. 
+""" + +from .core import ( + AccessType, + SecLevel, + SecurityError, + SecurityMetadata, + declassify, + privileged, + taint_sources, +) + +__all__ = [ + "AccessType", + "SecLevel", + "SecurityError", + "SecurityMetadata", + "declassify", + "privileged", + "taint_sources", +] diff --git a/mellea/security/core.py b/mellea/security/core.py new file mode 100644 index 00000000..fc00f3f5 --- /dev/null +++ b/mellea/security/core.py @@ -0,0 +1,324 @@ +"""Core security functionality for mellea. + +This module provides the fundamental security classes and functions for +tracking security levels of content blocks and enforcing security policies. +""" + +import abc +import functools +from collections.abc import Callable +from enum import Enum +from typing import Any, Generic, TypeVar, Union + +from mellea.stdlib.base import CBlock, Component + +T = TypeVar("T") + + +class SecLevelType(str, Enum): + """Security level type constants.""" + + NONE = "none" + CLASSIFIED = "classified" + TAINTED_BY = "tainted_by" + + +class AccessType(Generic[T], abc.ABC): + """Abstract base class for access-based security. + + This trait allows integration with IAM systems and provides fine-grained + access control based on entitlements rather than coarse security levels. + """ + + @abc.abstractmethod + def has_access(self, entitlement: T | None) -> bool: + """Check if the given entitlement has access. + + Args: + entitlement: The entitlement to check (e.g., user role, IAM identifier) + + Returns: + True if the entitlement has access, False otherwise + """ + + +class SecLevel(Generic[T]): + """Security level with access-based control and taint tracking. + + SecLevel := None | Classified of AccessType | TaintedBy of (CBlock | Component) + """ + + def __init__(self, level_type: SecLevelType | str, data: Any = None): + """Initialize security level. + + Args: + level_type: Type of security level (SecLevelType enum or string) + data: Associated data (AccessType for classified, CBlock/Component/None for tainted_by) + """ + # Convert string to enum if needed for backward compatibility + if isinstance(level_type, str): + level_type = SecLevelType(level_type) + self.level_type = level_type + self.data = data + + @classmethod + def none(cls) -> "SecLevel": + """Create a SecLevel with no restrictions (safe).""" + return cls(SecLevelType.NONE) + + @classmethod + def classified(cls, access_type: AccessType[T]) -> "SecLevel": + """Create a SecLevel with classified access requirements.""" + return cls(SecLevelType.CLASSIFIED, access_type) + + @classmethod + def tainted_by(cls, source: CBlock | Component | None) -> "SecLevel": + """Create a SecLevel tainted by a specific CBlock, Component, or None for root nodes.""" + return cls(SecLevelType.TAINTED_BY, source) + + def is_tainted(self) -> bool: + """Check if this security level represents tainted content. + + Returns: + True if tainted, False otherwise + """ + return self.level_type == SecLevelType.TAINTED_BY + + def is_classified(self) -> bool: + """Check if this security level represents classified content. + + Returns: + True if classified, False otherwise + """ + return self.level_type == SecLevelType.CLASSIFIED + + def get_access_type(self) -> AccessType[T] | None: + """Get the AccessType for classified content. + + Returns: + The AccessType if this is classified, None otherwise + """ + if self.level_type == SecLevelType.CLASSIFIED: + return self.data + return None + + def get_taint_source(self) -> CBlock | Component | None: + """Get the source of taint if this is a tainted level. 
+ + Returns: + The CBlock or Component that tainted this content, or None + """ + if self.level_type == SecLevelType.TAINTED_BY: + return self.data + return None + + +class SecurityMetadata: + """Metadata for tracking security properties of content blocks.""" + + def __init__(self, sec_level: SecLevel): + """Initialize security metadata with a SecLevel. + + Args: + sec_level: The security level for this content + """ + self.sec_level = sec_level + + def is_tainted(self) -> bool: + """Check if this security level represents tainted content. + + Returns: + True if tainted, False otherwise + """ + return self.sec_level.is_tainted() + + def is_classified(self) -> bool: + """Check if this security level represents classified content. + + Returns: + True if classified, False otherwise + """ + return self.sec_level.is_classified() + + def get_access_type(self) -> AccessType[Any] | None: + """Get the AccessType for classified content. + + Returns: + The AccessType if this is classified, None otherwise + """ + return self.sec_level.get_access_type() + + def get_taint_source(self) -> CBlock | Component | None: + """Get the source of taint if this is a tainted level. + + Returns: + The CBlock or Component that tainted this content, or None + """ + return self.sec_level.get_taint_source() + + +class SecurityError(Exception): + """Exception raised for security-related errors.""" + + +def taint_sources(action: Component | CBlock, ctx: Any) -> list[CBlock | Component]: + """Compute taint sources from action and context. + + This function examines the action and context to determine what + security sources might be present. It performs recursive analysis + of Component parts and shallow analysis of context to identify + potential taint sources and returns the actual objects that are tainted. 
+ + Args: + action: The action component or content block + ctx: The context containing previous interactions + + Returns: + List of tainted CBlocks or Components + """ + sources = [] + + # Check if action has security metadata and is tainted + if hasattr(action, "_meta") and "_security" in action._meta: + security_meta = action._meta["_security"] + if isinstance(security_meta, SecurityMetadata) and security_meta.is_tainted(): + sources.append(action) + + # For Components, check their constituent parts for taint + # Use pattern matching: CBlock doesn't have parts, Components do + match action: + case CBlock(): + # CBlock doesn't have parts, nothing to do + pass + case _ if isinstance(action, Component): + # Component is @runtime_checkable, so isinstance() works + # If it's a Component, it has parts() method by protocol definition + try: + parts = action.parts() + for part in parts: + if hasattr(part, "_meta") and "_security" in part._meta: + security_meta = part._meta["_security"] + if ( + isinstance(security_meta, SecurityMetadata) + and security_meta.is_tainted() + ): + sources.append(part) + except Exception: + # If parts() fails, continue without it + pass + + # Check context for tainted content (shallow check) + if hasattr(ctx, "as_list"): + try: + context_items = ctx.as_list(last_n_components=5) # Limit to recent items + for item in context_items: + if hasattr(item, "_meta") and "_security" in item._meta: + security_meta = item._meta["_security"] + if ( + isinstance(security_meta, SecurityMetadata) + and security_meta.is_tainted() + ): + sources.append(item) + except Exception: + # If context analysis fails, continue without it + pass + + return sources + + +F = TypeVar("F", bound=Callable[..., Any]) + + +def privileged(func: F) -> F: + """Decorator to mark functions that require safe (non-tainted, non-classified) input. + + Functions decorated with @privileged will raise SecurityError if + called with tainted or classified content blocks. 
+ + Args: + func: The function to decorate + + Returns: + The decorated function + + Raises: + SecurityError: If the function is called with tainted or classified content + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Check all arguments for marked content (tainted or classified) + for arg in args: + if ( + isinstance(arg, CBlock) + and hasattr(arg, "_meta") + and "_security" in arg._meta + ): + security_meta = arg._meta["_security"] + if isinstance(security_meta, SecurityMetadata): + if security_meta.is_tainted(): + taint_source = security_meta.get_taint_source() + source_info = ( + f" (tainted by: {type(taint_source).__name__})" + if taint_source + else "" + ) + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"tainted content{source_info}" + ) + elif security_meta.is_classified(): + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"classified content" + ) + + # Check keyword arguments for marked content (tainted or classified) + for key, value in kwargs.items(): + if ( + isinstance(value, CBlock) + and hasattr(value, "_meta") + and "_security" in value._meta + ): + security_meta = value._meta["_security"] + if isinstance(security_meta, SecurityMetadata): + if security_meta.is_tainted(): + taint_source = security_meta.get_taint_source() + source_info = ( + f" (tainted by: {type(taint_source).__name__})" + if taint_source + else "" + ) + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"tainted content in argument '{key}'{source_info}" + ) + elif security_meta.is_classified(): + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"classified content in argument '{key}'" + ) + + return func(*args, **kwargs) + + return wrapper # type: ignore + + +def declassify(cblock: CBlock) -> CBlock: + """Create a declassified version of a CBlock (non-mutating). + + This function creates a new CBlock with the same content but marked + as safe (SecLevel.none()). The original CBlock is not modified. + + Args: + cblock: The CBlock to declassify + + Returns: + A new CBlock with safe security level + """ + # Create new meta dict with safe security + new_meta = cblock._meta.copy() if cblock._meta else {} + new_meta["_security"] = SecurityMetadata(SecLevel.none()) + + # Return new CBlock with same content but new security metadata + return CBlock(cblock.value, new_meta) diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index 111d44f6..a37c2e18 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -12,7 +12,7 @@ from copy import copy, deepcopy from dataclasses import dataclass from io import BytesIO -from typing import Any, Protocol, TypeVar, runtime_checkable +from typing import Any, Protocol, TypeVar, Union, runtime_checkable from PIL import Image as PILImage @@ -22,8 +22,19 @@ class CBlock: """A `CBlock` is a block of content that can serve as input to or output from an LLM.""" - def __init__(self, value: str | None, meta: dict[str, Any] | None = None): - """Initializes the CBlock with a string and some metadata.""" + def __init__( + self, + value: str | None, + meta: dict[str, Any] | None = None, + sec_level: Any = None, + ): + """Initializes the CBlock with a string and some metadata. 
+ + Args: + value: The string content of the block + meta: Optional metadata dictionary + sec_level: Optional SecLevel for security metadata + """ if value is not None and not isinstance(value, str): raise TypeError("value to a Cblock should always be a string or None") self._underlying_value = value @@ -31,6 +42,12 @@ def __init__(self, value: str | None, meta: dict[str, Any] | None = None): meta = {} self._meta = meta + # Set security metadata if sec_level is provided + if sec_level is not None: + from mellea.security import SecurityMetadata + + self._meta["_security"] = SecurityMetadata(sec_level) + @property def value(self) -> str | None: """Gets the value of the block.""" @@ -49,6 +66,24 @@ def __repr__(self): """Provides a python-parsable representation of the block (usually).""" return f"CBlock({self.value}, {self._meta.__repr__()})" + @property + def sec_level(self) -> Any | None: + """Get the security metadata for this CBlock. + + Returns: + SecurityMetadata if present, None otherwise + """ + from mellea.security import SecurityMetadata + + if self._meta is None or "_security" not in self._meta: + return None + + security_meta = self._meta["_security"] + if isinstance(security_meta, SecurityMetadata): + return security_meta + + return None + class ImageBlock: """A `ImageBlock` represents an image (as base64 PNG).""" @@ -351,6 +386,42 @@ def __repr__(self): """ return f"ModelOutputThunk({self.value})" + @classmethod + def from_generation( + cls, + value: str | None, + taint_sources: list[CBlock | Component] | None = None, + meta: dict[str, Any] | None = None, + parsed_repr: CBlock | Component | Any | None = None, + tool_calls: dict[str, ModelToolCall] | None = None, + ) -> ModelOutputThunk: + """Create a ModelOutputThunk from generation with security metadata. + + Args: + value: The generated content + taint_sources: List of tainted CBlocks or Components from the generation context + meta: Additional metadata for the thunk + parsed_repr: Parsed representation of the output + tool_calls: Tool calls made during generation + + Returns: + A new ModelOutputThunk with appropriate security metadata + """ + if meta is None: + meta = {} + + # Add security metadata based on taint sources + from mellea.security import SecLevel, SecurityMetadata + + if taint_sources: + # If there are taint sources, mark as tainted by the first source + meta["_security"] = SecurityMetadata(SecLevel.tainted_by(taint_sources[0])) + else: + # If no taint sources, mark as safe + meta["_security"] = SecurityMetadata(SecLevel.none()) + + return cls(value, meta, parsed_repr, tool_calls) + def __copy__(self): """Returns a shallow copy of the ModelOutputThunk. 
A copied ModelOutputThunk cannot be used for generation; don't copy over fields associated with generating.""" copied = ModelOutputThunk( diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index e758be04..f3d8411c 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -113,7 +113,7 @@ def act( @overload def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, @@ -134,7 +134,7 @@ def instruct( @overload def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, @@ -154,7 +154,7 @@ def instruct( def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, diff --git a/mellea/stdlib/instruction.py b/mellea/stdlib/instruction.py index f8d07efb..a084c189 100644 --- a/mellea/stdlib/instruction.py +++ b/mellea/stdlib/instruction.py @@ -119,11 +119,31 @@ def __init__( self._images = images self._repair_string: str | None = None - def parts(self): + def parts(self) -> list[Component | CBlock]: """Returns all of the constituent parts of an Instruction.""" - raise Exception( - "Disallowing use of `parts` until we figure out exactly what it's supposed to be for" - ) + parts = [] + + # Add description if it exists + if self._description is not None: + parts.append(self._description) + + # Add prefix if it exists + if self._prefix is not None: + parts.append(self._prefix) + + # Add output_prefix if it exists + if self._output_prefix is not None: + parts.append(self._output_prefix) + + # Add icl_examples + parts.extend(self._icl_examples) + + # Add grounding_context values + for value in self._grounding_context.values(): + if isinstance(value, CBlock | Component): + parts.append(value) + + return parts def format_for_llm(self) -> TemplateRepresentation: """Formats the instruction for Formatter use.""" diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 91d1be24..180e372c 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -311,7 +311,7 @@ def act( @overload def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -330,7 +330,7 @@ def instruct( @overload def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -348,7 +348,7 @@ def instruct( def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -366,7 +366,7 @@ def instruct( """Generates from an instruction. Args: - description: The description of the instruction. + description: The description of the instruction (str or CBlock). requirements: A list of requirements that the instruction can be validated against. icl_examples: A list of in-context-learning examples that the instruction can be validated against. grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. 
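The `Instruction.parts()` implementation above is what gives `taint_sources()` something to recurse into: a tainted `CBlock` buried inside an otherwise safe `Instruction` is still reported as a taint source. A minimal sketch of that interaction, assuming this patch is applied as written (all names below come from the diffs above):

```python
from mellea.security import SecLevel, taint_sources
from mellea.stdlib.base import CBlock
from mellea.stdlib.instruction import Instruction

# A tainted description wrapped in an otherwise ordinary Instruction.
tainted_desc = CBlock("untrusted user text", sec_level=SecLevel.tainted_by(None))
instruction = Instruction(description=tainted_desc)

# parts() now exposes the description (plus prefix, ICL examples, and grounding
# context values), so the Component branch of taint_sources() finds the taint
# even though the Instruction itself carries no security metadata.
sources = taint_sources(instruction, None)
assert len(sources) == 1 and sources[0] is tainted_desc
```

The `str | CBlock` widening of `instruct(description=...)` in `functional.py` and `session.py` is what lets callers hand such a tainted `CBlock` straight to a session, as `docs/examples/security/taint_example.py` does.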
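The classified branch of `SecLevel` only appears in the tests via ad-hoc subclasses, so a concrete `AccessType` is sketched below; `RoleAccess` and the role names are illustrative assumptions, not part of this patch. It shows the two behaviours the patch defines for classified content: `@privileged` refuses it outright, and an entitlement-aware caller can query the attached `AccessType`.

```python
from mellea.security import AccessType, SecLevel, SecurityError, privileged
from mellea.stdlib.base import CBlock


class RoleAccess(AccessType[str]):
    """Illustrative entitlement check: access is granted to a fixed set of roles."""

    def __init__(self, allowed_roles: set[str]):
        self.allowed_roles = allowed_roles

    def has_access(self, entitlement: str | None) -> bool:
        return entitlement is not None and entitlement in self.allowed_roles


report = CBlock(
    "quarterly numbers",
    sec_level=SecLevel.classified(RoleAccess({"finance-admin"})),
)


@privileged
def publish(block: CBlock) -> str:
    return f"published: {block.value}"


# The decorator rejects classified input regardless of who is calling...
try:
    publish(report)
except SecurityError:
    print("publish() rejected classified content")

# ...while an entitlement-aware caller can consult the AccessType directly.
access = report.sec_level.get_access_type()
print(access.has_access("finance-admin"))  # True
print(access.has_access("intern"))  # False
```

This is the hook the `AccessType` docstring leaves open for IAM integration: the entitlement type parameter can carry an IAM principal or token instead of a plain role string.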
diff --git a/test/stdlib_basics/test_security_comprehensive.py b/test/stdlib_basics/test_security_comprehensive.py new file mode 100644 index 00000000..22e4041c --- /dev/null +++ b/test/stdlib_basics/test_security_comprehensive.py @@ -0,0 +1,442 @@ +"""Comprehensive security tests for mellea thread security features.""" + +import pytest +from mellea.stdlib.base import CBlock, ModelOutputThunk, ChatContext +from mellea.stdlib.instruction import Instruction +from mellea.security import ( + AccessType, + SecLevel, + SecurityMetadata, + SecurityError, + privileged, + declassify, + taint_sources, +) + + +class TestAccessType: + """Test AccessType functionality.""" + + def test_access_type_interface(self): + """Test that AccessType is an abstract base class.""" + with pytest.raises(TypeError): + AccessType() # Should not be instantiable directly + + def test_access_type_implementation(self): + """Test implementing AccessType.""" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + assert access.has_access("admin") + assert not access.has_access("user") + assert not access.has_access(None) + + +class TestSecLevel: + """Test SecLevel functionality.""" + + def test_sec_level_none(self): + """Test SecLevel.none() creates safe level.""" + from mellea.security.core import SecLevelType + + sec_level = SecLevel.none() + assert sec_level.level_type == SecLevelType.NONE + assert not sec_level.is_tainted() + assert not sec_level.is_classified() + assert sec_level.get_access_type() is None + + def test_sec_level_tainted_by(self): + """Test SecLevel.tainted_by() creates tainted level.""" + from mellea.security.core import SecLevelType + + source = CBlock("source content") + sec_level = SecLevel.tainted_by(source) + assert sec_level.level_type == SecLevelType.TAINTED_BY + assert sec_level.is_tainted() + assert not sec_level.is_classified() + assert sec_level.get_taint_source() is source + assert sec_level.get_access_type() is None + + def test_sec_level_classified(self): + """Test SecLevel.classified() creates classified level.""" + from mellea.security.core import SecLevelType + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + assert sec_level.level_type == SecLevelType.CLASSIFIED + assert not sec_level.is_tainted() + assert sec_level.is_classified() + assert sec_level.get_access_type() is access + assert sec_level.get_access_type().has_access("admin") + assert not sec_level.get_access_type().has_access("user") + assert not sec_level.get_access_type().has_access(None) + + +class TestCBlockSecurity: + """Test CBlock security functionality.""" + + def test_cblock_mark_tainted(self): + """Test marking CBlock as tainted.""" + cblock = CBlock("test content", sec_level=SecLevel.tainted_by(None)) + + assert cblock.sec_level is not None + assert isinstance(cblock.sec_level, SecurityMetadata) + assert cblock.sec_level.is_tainted() + assert not cblock.sec_level.is_classified() + assert cblock.sec_level.get_access_type() is None + + def test_cblock_mark_tainted_by_source(self): + """Test marking CBlock as tainted by another source.""" + source = CBlock("source content") + cblock = CBlock("test content", sec_level=SecLevel.tainted_by(source)) + + assert cblock.sec_level.is_tainted() + assert cblock.sec_level.get_taint_source() is source + + def 
test_cblock_default_safe(self): + """Test that CBlock defaults to safe when no security metadata.""" + cblock = CBlock("test content") + assert cblock.sec_level is None or ( + not cblock.sec_level.is_tainted() and not cblock.sec_level.is_classified() + ) + + def test_cblock_with_classified_metadata(self): + """Test CBlock with classified security metadata.""" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + + cblock = CBlock("classified content", sec_level=sec_level) + + assert cblock.sec_level.is_classified() + access_type = cblock.sec_level.get_access_type() + assert access_type is not None + assert access_type.has_access("admin") + assert not access_type.has_access("user") + assert not access_type.has_access(None) + + +class TestDeclassify: + """Test declassify function.""" + + def test_declassify_creates_new_object(self): + """Test that declassify creates a new object without mutating original.""" + from mellea.security.core import SecLevelType + + original = CBlock("test content", sec_level=SecLevel.tainted_by(None)) + + declassified = declassify(original) + + # Objects are different + assert original is not declassified + assert id(original) != id(declassified) + + # Content is preserved + assert original.value == declassified.value + + # Security levels are different + assert original.sec_level.is_tainted() + assert not declassified.sec_level.is_tainted() + assert not declassified.sec_level.is_classified() + assert declassified.sec_level.sec_level.level_type == SecLevelType.NONE + + # Original is unchanged + assert original.sec_level.is_tainted() + + def test_declassify_preserves_other_metadata(self): + """Test that declassify preserves other metadata.""" + from mellea.security.core import SecLevelType + + original = CBlock( + "test content", + meta={"custom": "value", "other": 123}, + sec_level=SecLevel.tainted_by(None), + ) + + declassified = declassify(original) + + assert declassified._meta["custom"] == "value" + assert declassified._meta["other"] == 123 + assert declassified.sec_level.sec_level.level_type == SecLevelType.NONE + + +class TestPrivilegedDecorator: + """Test @privileged decorator functionality.""" + + def test_privileged_accepts_safe_input(self): + """Test that privileged functions accept safe input.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + # CBlock with no security metadata defaults to safe + safe_cblock = CBlock("safe content") + + result = safe_function(safe_cblock) + assert result == "Processed: safe content" + + def test_privileged_accepts_declassified_input(self): + """Test that privileged functions accept declassified input.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + tainted_cblock = CBlock("tainted content", sec_level=SecLevel.tainted_by(None)) + declassified_cblock = declassify(tainted_cblock) + + result = safe_function(declassified_cblock) + assert result == "Processed: tainted content" + + def test_privileged_rejects_tainted_input(self): + """Test that privileged functions reject tainted input.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + tainted_cblock = CBlock("tainted content", sec_level=SecLevel.tainted_by(None)) + + with pytest.raises(SecurityError, match="requires safe input"): + safe_function(tainted_cblock) + + def 
test_privileged_rejects_classified_input(self): + """Test that privileged functions reject classified input without proper entitlement.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + security_meta = SecurityMetadata(sec_level) + + classified_cblock = CBlock( + "classified content", meta={"_security": security_meta} + ) + + with pytest.raises(SecurityError, match="requires safe input"): + safe_function(classified_cblock) + + def test_privileged_accepts_no_security_metadata(self): + """Test that privileged functions accept input with no security metadata.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + # CBlock with no security metadata defaults to safe + cblock = CBlock("content") + + result = safe_function(cblock) + assert result == "Processed: content" + + def test_privileged_with_kwargs(self): + """Test privileged function with keyword arguments.""" + + @privileged + def safe_function(data: CBlock, prefix: str = "Processed: ") -> str: + return f"{prefix}{data.value}" + + tainted_cblock = CBlock("tainted content", sec_level=SecLevel.tainted_by(None)) + + with pytest.raises(SecurityError, match="argument 'data'"): + safe_function(data=tainted_cblock) + + +class TestTaintSources: + """Test taint source computation.""" + + def test_taint_sources_from_tainted_action(self): + """Test taint sources from tainted action.""" + action = CBlock("tainted action", sec_level=SecLevel.tainted_by(None)) + + sources = taint_sources(action, None) + assert len(sources) == 1 + assert sources[0] is action + + def test_taint_sources_from_safe_action(self): + """Test taint sources from safe action.""" + action = CBlock("safe action") + # No security metadata - defaults to safe + + sources = taint_sources(action, None) + assert len(sources) == 0 + + def test_taint_sources_from_context(self): + """Test taint sources from context.""" + action = CBlock("safe action") + + # Create context with tainted content + ctx = ChatContext() + tainted_cblock = CBlock("tainted context", sec_level=SecLevel.tainted_by(None)) + ctx = ctx.add(tainted_cblock) + + sources = taint_sources(action, ctx) + assert len(sources) == 1 + assert sources[0] is tainted_cblock + + def test_taint_sources_empty(self): + """Test taint sources with no tainted content.""" + action = CBlock("safe action") + ctx = ChatContext() + safe_cblock = CBlock("safe context") + # No security metadata - defaults to safe + ctx = ctx.add(safe_cblock) + + sources = taint_sources(action, ctx) + assert len(sources) == 0 + + def test_taint_sources_from_component_parts(self): + """Test taint sources from Component parts.""" + # Create Instruction with tainted description + tainted_desc = CBlock( + "tainted description", sec_level=SecLevel.tainted_by(None) + ) + instruction = Instruction(description=tainted_desc) + + sources = taint_sources(instruction, None) + assert len(sources) == 1 + assert sources[0] is tainted_desc + + def test_taint_sources_shallow_search_limit(self): + """Test that shallow search only checks last 5 components.""" + action = CBlock("safe action") + + # Create context with 7 items: tainted at positions 0 and 5 + ctx = ChatContext() + tainted_early = CBlock("tainted early", sec_level=SecLevel.tainted_by(None)) + ctx = ctx.add(tainted_early) # 
Position 0 - outside last 5 + + # Add 4 safe items + for i in range(4): + ctx = ctx.add(CBlock(f"safe {i}")) + + tainted_late = CBlock("tainted late", sec_level=SecLevel.tainted_by(None)) + ctx = ctx.add(tainted_late) # Position 5 - within last 5 + + # Add one more safe item + ctx = ctx.add(CBlock("safe final")) # Position 6 + + sources = taint_sources(action, ctx) + # Should only find tainted_late (position 5), not tainted_early (position 0) + assert len(sources) == 1 + assert sources[0] is tainted_late + + +class TestModelOutputThunkSecurity: + """Test ModelOutputThunk security functionality.""" + + def test_from_generation_with_taint_sources(self): + """Test ModelOutputThunk.from_generation with taint sources.""" + from mellea.security.core import SecLevelType + + taint_source = CBlock("taint source", sec_level=SecLevel.tainted_by(None)) + + mot = ModelOutputThunk.from_generation( + value="generated content", + taint_sources=[taint_source], + meta={"custom": "value"}, + ) + + assert mot.value == "generated content" + assert mot._meta["custom"] == "value" + assert mot.sec_level is not None + assert mot.sec_level.is_tainted() + assert not mot.sec_level.is_classified() + assert mot.sec_level.get_taint_source() is taint_source + + def test_from_generation_without_taint_sources(self): + """Test ModelOutputThunk.from_generation without taint sources.""" + from mellea.security.core import SecLevelType + + mot = ModelOutputThunk.from_generation( + value="generated content", taint_sources=None, meta={"custom": "value"} + ) + + assert mot.value == "generated content" + assert mot._meta["custom"] == "value" + assert mot.sec_level is not None + assert mot.sec_level.sec_level.level_type == SecLevelType.NONE + assert not mot.sec_level.is_tainted() + assert not mot.sec_level.is_classified() + + def test_from_generation_empty_taint_sources(self): + """Test ModelOutputThunk.from_generation with empty taint sources.""" + from mellea.security.core import SecLevelType + + mot = ModelOutputThunk.from_generation( + value="generated content", taint_sources=[], meta={"custom": "value"} + ) + + assert mot.sec_level.sec_level.level_type == SecLevelType.NONE + assert not mot.sec_level.is_tainted() + assert not mot.sec_level.is_classified() + + +class TestSecurityIntegration: + """Test integration between security components.""" + + def test_security_flow_through_generation(self): + """Test security metadata flows through generation pipeline.""" + from mellea.security.core import SecLevelType + + # Create tainted input + tainted_input = CBlock("user input", sec_level=SecLevel.tainted_by(None)) + + # Simulate generation with taint sources + sources = taint_sources(tainted_input, None) + mot = ModelOutputThunk.from_generation( + value="model response", taint_sources=sources + ) + + # Verify output is tainted + assert mot.sec_level.is_tainted() + + # Declassify the output + safe_mot = declassify(mot) + assert not safe_mot.sec_level.is_tainted() + assert not safe_mot.sec_level.is_classified() + assert safe_mot.sec_level.sec_level.level_type == SecLevelType.NONE + + # Verify original is unchanged + assert mot.sec_level.is_tainted() + + def test_privileged_function_with_generated_content(self): + """Test privileged function with generated content.""" + + @privileged + def process_response(mot: ModelOutputThunk) -> str: + return f"Processed: {mot.value}" + + # Generate tainted content + taint_source = CBlock("taint source", sec_level=SecLevel.tainted_by(None)) + + mot = ModelOutputThunk.from_generation( + value="tainted 
response", taint_sources=[taint_source] + ) + + # Privileged function should reject tainted content + with pytest.raises(SecurityError): + process_response(mot) + + # Declassify and try again + safe_mot = declassify(mot) + result = process_response(safe_mot) + assert result == "Processed: tainted response"