|
| 1 | +""" |
| 2 | +Utility functions for OpenAI API adapter. |
| 3 | +""" |
| 4 | + |
| 5 | +import logging |
| 6 | +from typing import Dict, List |
| 7 | + |
| 8 | +import jinja2.nodes |
| 9 | +import transformers.utils.chat_template_utils as hf_chat_utils |
| 10 | + |
| 11 | +logger = logging.getLogger(__name__) |
| 12 | + |
| 13 | +# ============================================================================ |
| 14 | +# JINJA TEMPLATE CONTENT FORMAT DETECTION |
| 15 | +# ============================================================================ |
| 16 | +# |
| 17 | +# This adapts vLLM's approach for detecting chat template content format: |
| 18 | +# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313 |
| 19 | +# - Analyzes Jinja template AST to detect content iteration patterns |
| 20 | +# - 'openai' format: templates with {%- for content in message['content'] -%} loops |
| 21 | +# - 'string' format: templates that expect simple string content |
| 22 | +# - Processes content accordingly to match template expectations |
| 23 | + |
| 24 | + |
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Tell whether ``node`` is a bare read of the variable ``varname``.

    Only ``jinja2.nodes.Name`` nodes can match, e.g. ``{{ varname }}``; the
    name must be used in "load" context and be equal to ``varname``. Every
    other node type yields False.
    """
    if not isinstance(node, jinja2.nodes.Name):
        return False

    loads_variable = node.ctx == "load"
    return loads_variable and node.name == varname
| 30 | + |
| 31 | + |
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Tell whether ``node`` reads ``key`` from the variable ``varname``.

    Two spellings are recognised:

    - subscript form ``{{ varname['key'] }}`` (a ``Getitem`` node whose
      argument is a constant equal to ``key``)
    - dotted form ``{{ varname.key }}`` (a ``Getattr`` node whose attribute
      is ``key``)

    In both cases the object being accessed must itself be a plain read of
    ``varname``. Any other node type yields False.
    """
    if isinstance(node, jinja2.nodes.Getitem):
        # Subscript access: the target must be the variable and the index a
        # constant matching the requested key.
        if not _is_var_access(node.node, varname):
            return False
        if not isinstance(node.arg, jinja2.nodes.Const):
            return False
        return node.arg.value == key

    if isinstance(node, jinja2.nodes.Getattr):
        # Dotted access: the target must be the variable and the attribute
        # name must match the requested key.
        if not _is_var_access(node.node, varname):
            return False
        return node.attr == key

    return False
| 45 | + |
| 46 | + |
def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str = None,
) -> bool:
    """Tell whether ``node`` ultimately accesses ``varname`` or ``varname[key]``.

    Wrapper nodes are peeled off first: filters, tests and slice subscripts
    are all looked through so that the innermost expression is the one that
    gets checked. When ``key`` is truthy the innermost expression must be an
    access of ``key`` on ``varname``; otherwise it must be a plain read of
    ``varname``.
    """
    inner = node
    while True:
        if isinstance(inner, jinja2.nodes.Filter):
            # A filter without an operand cannot refer to the variable.
            if inner.node is None:
                return False
            inner = inner.node
        elif isinstance(inner, jinja2.nodes.Test):
            inner = inner.node
        elif isinstance(inner, jinja2.nodes.Getitem) and isinstance(
            inner.arg, jinja2.nodes.Slice
        ):
            # Slicing keeps the same underlying sequence, so look through it.
            inner = inner.node
        else:
            break

    if key:
        return _is_attr_access(inner, varname, key)
    return _is_var_access(inner, varname)
| 66 | + |
| 67 | + |
def _try_extract_ast(chat_template: str):
    """Parse ``chat_template`` into a Jinja AST, or return None on failure.

    The template is compiled through the Hugging Face chat-template helper
    and then parsed with the environment of the compiled template. Any
    exception raised along the way is logged at debug level and swallowed so
    that callers can fall back to a default.
    """
    try:
        compiled_template = hf_chat_utils._compile_jinja_template(chat_template)
        template_ast = compiled_template.environment.parse(chat_template)
    except Exception as e:
        logger.debug(f"Error when compiling Jinja template: {e}")
        return None
    else:
        return template_ast
| 76 | + |
| 77 | + |
def detect_template_content_format(chat_template: str) -> str:
    """
    Classify a chat template as expecting 'string' or 'openai' content.

    - 'openai': the template contains a ``for`` loop that iterates over the
      ``content`` of ``message`` (e.g. ``{%- for content in message['content'] -%}``),
      possibly wrapped in filters, tests or slices
    - 'string': anything else

    'string' is also the fallback when the template cannot be parsed or when
    walking its AST raises.
    """
    template_ast = _try_extract_ast(chat_template)
    if template_ast is None:
        return "string"

    try:
        # Stop at the first loop whose iterable is message['content'] (or an
        # equivalent spelling of it).
        iterates_over_content = any(
            _is_var_or_elems_access(for_node.iter, "message", "content")
            for for_node in template_ast.find_all(jinja2.nodes.For)
        )
    except Exception as e:
        logger.debug(f"Error when parsing AST of Jinja template: {e}")
        return "string"

    return "openai" if iterates_over_content else "string"
| 106 | + |
| 107 | + |
def process_content_for_template_format(
    msg_dict: dict,
    content_format: str,
    image_data: list,
    audio_data: list,
    modalities: list,
) -> dict:
    """
    Reshape a message's content to suit the detected template format.

    Args:
        msg_dict: Message dictionary; its "content" may be a list of parts
        content_format: 'openai' keeps a structured list; any other value is
            treated as 'string' and flattens the content to text
        image_data: List that receives the URL of every image part (openai only)
        audio_data: List that receives the URL of every audio part (openai only)
        modalities: List that receives each image part's truthy "modalities"
            value (openai only)

    Returns:
        A new message dictionary without None-valued entries. ``msg_dict``
        itself is not modified.
    """
    parts = msg_dict.get("content")

    if not isinstance(parts, list):
        # Content is not a list of parts: only strip None-valued entries.
        return {key: value for key, value in msg_dict.items() if value is not None}

    if content_format != "openai":
        # String format: keep only the text parts, joined by single spaces.
        # Image/audio parts are dropped here because the template does not
        # expect structured content.
        texts = [
            part["text"]
            for part in parts
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        flattened = msg_dict.copy()
        flattened["content"] = " ".join(texts)
        return {key: value for key, value in flattened.items() if value is not None}

    # OpenAI format: keep a structured list, collecting media URLs on the side
    # and replacing media parts with bare type markers.
    normalized_parts = []
    for part in parts:
        if not isinstance(part, dict):
            # Entries that are not dicts are not carried over.
            continue

        part_type = part.get("type")
        if part_type == "image_url":
            image_data.append(part["image_url"]["url"])
            part_modalities = part.get("modalities")
            if part_modalities:
                modalities.append(part_modalities)
            normalized_parts.append({"type": "image"})
        elif part_type == "audio_url":
            audio_data.append(part["audio_url"]["url"])
            normalized_parts.append({"type": "audio"})
        else:
            # Text and any other part types pass through untouched.
            normalized_parts.append(part)

    result = {
        key: value
        for key, value in msg_dict.items()
        if key != "content" and value is not None
    }
    result["content"] = normalized_parts
    return result
0 commit comments