1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import ipaddress
1516import json
17+ import os
18+ import socket
1619from typing import TYPE_CHECKING , Any
20+ from urllib .parse import urlparse
21+
22+ from ..extras .misc import is_env_enabled
23+ from ..extras .packages import is_fastapi_available
24+
25+
26+ if is_fastapi_available ():
27+ from fastapi import HTTPException , status
1728
1829
1930if TYPE_CHECKING :
2031 from pydantic import BaseModel
2132
2233
34+ SAFE_MEDIA_PATH = os .environ .get ("SAFE_MEDIA_PATH" , os .path .join (os .path .dirname (__file__ ), "safe_media" ))
35+ ALLOW_LOCAL_FILES = is_env_enabled ("ALLOW_LOCAL_FILES" , "1" )
36+
37+
2338def dictify (data : "BaseModel" ) -> dict [str , Any ]:
2439 try : # pydantic v2
2540 return data .model_dump (exclude_unset = True )
@@ -32,3 +47,50 @@ def jsonify(data: "BaseModel") -> str:
3247 return json .dumps (data .model_dump (exclude_unset = True ), ensure_ascii = False )
3348 except AttributeError : # pydantic v1
3449 return data .json (exclude_unset = True , ensure_ascii = False )
50+
51+
52+ def check_lfi_path (path : str ) -> None :
53+ """Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe."""
54+ if not ALLOW_LOCAL_FILES :
55+ raise HTTPException (status_code = status .HTTP_403_FORBIDDEN , detail = "Local file access is disabled." )
56+
57+ try :
58+ os .makedirs (SAFE_MEDIA_PATH , exist_ok = True )
59+ real_path = os .path .realpath (path )
60+ safe_path = os .path .realpath (SAFE_MEDIA_PATH )
61+
62+ if not real_path .startswith (safe_path ):
63+ raise HTTPException (
64+ status_code = status .HTTP_403_FORBIDDEN , detail = "File access is restricted to the safe media directory."
65+ )
66+ except Exception :
67+ raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = "Invalid or inaccessible file path." )
68+
69+
70+ def check_ssrf_url (url : str ) -> None :
71+ """Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe."""
72+ try :
73+ parsed_url = urlparse (url )
74+ if parsed_url .scheme not in ["http" , "https" ]:
75+ raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = "Only HTTP/HTTPS URLs are allowed." )
76+
77+ hostname = parsed_url .hostname
78+ if not hostname :
79+ raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = "Invalid URL hostname." )
80+
81+ ip_info = socket .getaddrinfo (hostname , parsed_url .port )
82+ ip_address_str = ip_info [0 ][4 ][0 ]
83+ ip = ipaddress .ip_address (ip_address_str )
84+
85+ if not ip .is_global :
86+ raise HTTPException (
87+ status_code = status .HTTP_403_FORBIDDEN ,
88+ detail = "Access to private or reserved IP addresses is not allowed." ,
89+ )
90+
91+ except socket .gaierror :
92+ raise HTTPException (
93+ status_code = status .HTTP_400_BAD_REQUEST , detail = f"Could not resolve hostname: { parsed_url .hostname } "
94+ )
95+ except Exception as e :
96+ raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = f"Invalid URL: { e } " )
0 commit comments