Skip to content

Commit 95b7188

Browse files
d3do-23d3dohiyouga
authored
Merge commit from fork
* fix lfi and ssrf * move utils to common --------- Co-authored-by: d3do <[email protected]> Co-authored-by: hiyouga <[email protected]>
1 parent d5bb4e6 commit 95b7188

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

src/llamafactory/api/chat.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
2727
from ..extras.misc import is_env_enabled
2828
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
29-
from .common import dictify, jsonify
29+
from .common import check_lfi_path, check_ssrf_url, dictify, jsonify
3030
from .protocol import (
3131
ChatCompletionMessage,
3232
ChatCompletionResponse,
@@ -121,8 +121,10 @@ def _process_request(
121121
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
122122
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
123123
elif os.path.isfile(image_url): # local file
124+
check_lfi_path(image_url)
124125
image_stream = open(image_url, "rb")
125126
else: # web uri
127+
check_ssrf_url(image_url)
126128
image_stream = requests.get(image_url, stream=True).raw
127129

128130
images.append(Image.open(image_stream).convert("RGB"))
@@ -132,8 +134,10 @@ def _process_request(
132134
if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video
133135
video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
134136
elif os.path.isfile(video_url): # local file
137+
check_lfi_path(video_url)
135138
video_stream = video_url
136139
else: # web uri
140+
check_ssrf_url(video_url)
137141
video_stream = requests.get(video_url, stream=True).raw
138142

139143
videos.append(video_stream)
@@ -143,8 +147,10 @@ def _process_request(
143147
if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio
144148
audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
145149
elif os.path.isfile(audio_url): # local file
150+
check_lfi_path(audio_url)
146151
audio_stream = audio_url
147152
else: # web uri
153+
check_ssrf_url(audio_url)
148154
audio_stream = requests.get(audio_url, stream=True).raw
149155

150156
audios.append(audio_stream)

src/llamafactory/api/common.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,29 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import ipaddress
1516
import json
17+
import os
18+
import socket
1619
from typing import TYPE_CHECKING, Any
20+
from urllib.parse import urlparse
21+
22+
from ..extras.misc import is_env_enabled
23+
from ..extras.packages import is_fastapi_available
24+
25+
26+
if is_fastapi_available():
27+
from fastapi import HTTPException, status
1728

1829

1930
if TYPE_CHECKING:
2031
from pydantic import BaseModel
2132

2233

34+
# Directory that local media files are restricted to. Taken from the SAFE_MEDIA_PATH
# environment variable when set, otherwise a "safe_media" folder next to this module.
SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
35+
# Switch for local file access, read from the ALLOW_LOCAL_FILES environment variable with "1"
# passed as the fallback (presumably "enabled unless overridden" - confirm against is_env_enabled).
ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")
36+
37+
2338
def dictify(data: "BaseModel") -> dict[str, Any]:
2439
try: # pydantic v2
2540
return data.model_dump(exclude_unset=True)
@@ -32,3 +47,50 @@ def jsonify(data: "BaseModel") -> str:
3247
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
3348
except AttributeError: # pydantic v1
3449
return data.json(exclude_unset=True, ensure_ascii=False)
50+
51+
52+
def check_lfi_path(path: str) -> None:
    """Reject local file paths that resolve outside the safe media directory.

    Args:
        path: Candidate local file path taken from an API request.

    Raises:
        HTTPException: 403 if local file access is disabled or the resolved path lies
            outside ``SAFE_MEDIA_PATH``; 400 if the path cannot be resolved or compared.
    """
    if not ALLOW_LOCAL_FILES:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")

    # Only the path resolution lives inside the ``try``: the 403 below must not be
    # swallowed by the broad handler and turned into a generic 400.
    try:
        os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
        real_path = os.path.realpath(path)  # resolves symlinks and ".." segments
        safe_path = os.path.realpath(SAFE_MEDIA_PATH)
        # Compare whole path components. A plain ``startswith`` would also accept
        # sibling directories such as ``<safe_path>_evil/secret``.
        is_inside = os.path.commonpath([safe_path, real_path]) == safe_path
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path."
        ) from err

    if not is_inside:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
        )
68+
69+
70+
def check_ssrf_url(url: str) -> None:
    """Reject URLs that are not plain HTTP(S) or that resolve to a non-global address.

    Args:
        url: Candidate remote URL taken from an API request.

    Raises:
        HTTPException: 400 if the URL is malformed, uses another scheme, has no
            hostname, or cannot be resolved; 403 if any resolved address is not
            globally routable (private, loopback, link-local, reserved, ...).
    """
    # NOTE(review): this validates the addresses seen at check time only. The caller
    # resolves the hostname again when it performs the request, so DNS rebinding and
    # HTTP redirects to internal hosts are not covered by this function.
    try:
        parsed_url = urlparse(url)
        if parsed_url.scheme not in ("http", "https"):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")

        hostname = parsed_url.hostname
        if not hostname:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")

        # A hostname may resolve to several addresses (A and AAAA records). Validate
        # every one of them, since the HTTP client is free to connect to any.
        # The "%scope" suffix of scoped IPv6 addresses is dropped before parsing.
        resolved = {info[4][0].split("%", 1)[0] for info in socket.getaddrinfo(hostname, parsed_url.port)}

        # Fail closed: an empty result is treated like an unsafe one.
        if not resolved or not all(ipaddress.ip_address(addr).is_global for addr in resolved):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access to private or reserved IP addresses is not allowed.",
            )

    except HTTPException:
        # Deliberate rejections above keep their own status code and message instead
        # of being re-wrapped as "Invalid URL" by the generic handler below.
        raise
    except socket.gaierror:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}"
        )
    except Exception as e:
        # E.g. a non-numeric or out-of-range port makes ``parsed_url.port`` raise ValueError.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")

0 commit comments

Comments
 (0)