|
| 1 | +''' |
| 2 | +MIT License |
| 3 | +
|
| 4 | +Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). |
| 5 | +Project: Harmony (https://harmonydata.ac.uk) |
| 6 | +Maintainer: Thomas Wood (https://fastdatascience.com) |
| 7 | +
|
| 8 | +Permission is hereby granted, free of charge, to any person obtaining a copy |
| 9 | +of this software and associated documentation files (the "Software"), to deal |
| 10 | +in the Software without restriction, including without limitation the rights |
| 11 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 12 | +copies of the Software, and to permit persons to whom the Software is |
| 13 | +furnished to do so, subject to the following conditions: |
| 14 | +
|
| 15 | +The above copyright notice and this permission notice shall be included in all |
| 16 | +copies or substantial portions of the Software. |
| 17 | +
|
| 18 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 19 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 20 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 21 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 22 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 23 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 24 | +SOFTWARE. |
| 25 | +
|
| 26 | +''' |
| 27 | + |
| 28 | +import base64 |
| 29 | +import hashlib |
| 30 | +import requests |
| 31 | +import ssl |
| 32 | +import urllib.parse |
| 33 | +import uuid |
| 34 | +from datetime import datetime, timedelta |
| 35 | +from harmony.parsing.wrapper_all_parsers import convert_files_to_instruments |
| 36 | +from harmony.schemas.errors.base import BadRequestError, ForbiddenError, ConflictError, SomethingWrongError |
| 37 | +from harmony.schemas.requests.text import RawFile, Instrument, FileType |
| 38 | +from pathlib import Path |
| 39 | +from requests.adapters import HTTPAdapter |
| 40 | +from typing import List, Dict |
| 41 | + |
# Hard limits applied to every download attempt.
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB
DOWNLOAD_TIMEOUT = 30  # seconds
MAX_REDIRECTS = 5  # redirect cap for a single download
ALLOWED_SCHEMES = {'https'}  # HTTPS only; plain HTTP is rejected up front
RATE_LIMIT_REQUESTS = 60  # requests per min
RATE_LIMIT_WINDOW = 60  # seconds

# Maps an HTTP Content-Type header value to the project's FileType enum.
MIME_TO_FILE_TYPE = {
    'application/pdf': FileType.pdf,
    'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': FileType.xlsx,
    'text/plain': FileType.txt,
    'text/csv': FileType.csv,
    'application/vnd.openxmlformats-officedocument.wordprocessingml.document': FileType.docx
}

# Fallback: map the URL path's extension when the MIME type is unrecognised.
EXT_TO_FILE_TYPE = {
    '.pdf': FileType.pdf,
    '.xlsx': FileType.xlsx,
    '.txt': FileType.txt,
    '.csv': FileType.csv,
    '.docx': FileType.docx
}
| 64 | + |
| 65 | + |
class URLDownloader:
    """Download files over HTTPS with safety, policy and rate-limit checks.

    Responsibilities:
      * URL validation: HTTPS-only, no fragments, no path traversal,
        internal/loopback hosts blocked.
      * Per-domain sliding-window rate limiting (in-memory, per instance).
      * SSL certificate presence/expiry check on the live connection.
      * Refusal based on robots/legal response headers.
      * Content-Type / extension mapping to a supported FileType.
    """

    def __init__(self):
        # Per-domain request timestamps for the sliding-window rate limiter.
        self.rate_limit_storage: Dict[str, List[datetime]] = {}
        self.session = requests.Session()
        self.session.mount('https://', HTTPAdapter(max_retries=3))
        self.session.verify = True
        # Fix: MAX_REDIRECTS was defined but never enforced; applying it to
        # the session makes requests raise TooManyRedirects beyond the cap.
        self.session.max_redirects = MAX_REDIRECTS

    def _check_rate_limit(self, domain: str) -> None:
        """Record one request for `domain`; raise ConflictError when more than
        RATE_LIMIT_REQUESTS landed within the last RATE_LIMIT_WINDOW seconds."""
        now = datetime.now()
        cutoff = now - timedelta(seconds=RATE_LIMIT_WINDOW)

        # Keep only the timestamps still inside the sliding window.
        recent = [ts for ts in self.rate_limit_storage.get(domain, []) if ts > cutoff]
        self.rate_limit_storage[domain] = recent

        if len(recent) >= RATE_LIMIT_REQUESTS:
            raise ConflictError("Rate limit exceeded")

        recent.append(now)

    def _validate_url(self, url: str) -> None:
        """Validate `url` before any network access.

        Raises:
            BadRequestError: non-HTTPS scheme, malformed domain, fragment
                present, or any unexpected parsing failure.
            ForbiddenError: path traversal or internal/loopback target host.
        """
        try:
            parsed = urllib.parse.urlparse(url)

            if parsed.scheme not in ALLOWED_SCHEMES:
                raise BadRequestError("URL must use HTTPS")

            if not parsed.netloc or '.' not in parsed.netloc:
                raise BadRequestError("Invalid domain")

            if '..' in parsed.path or '//' in parsed.path:
                raise ForbiddenError("Path traversal detected")

            if parsed.fragment:
                raise BadRequestError("URL fragments not supported")

            # Fix: compare the hostname (port stripped, lowercased by urllib)
            # so e.g. "https://127.0.0.1:8443/x" cannot bypass the block list;
            # the old raw-netloc comparison let any explicit port through.
            blocked_hosts = {'localhost', '127.0.0.1', '0.0.0.0'}
            if (parsed.hostname or '') in blocked_hosts:
                raise ForbiddenError("Access to internal domains blocked")
        except (BadRequestError, ForbiddenError):
            # Fix: re-raise our own validation errors unchanged. Previously the
            # blanket `except Exception` caught them and rewrapped everything
            # (including ForbiddenError) as BadRequestError, losing the error
            # type that callers distinguish on.
            raise
        except Exception as e:
            raise BadRequestError(f"Invalid URL: {str(e)}")

    def _validate_ssl(self, response: requests.Response) -> None:
        """Check the peer presented a certificate and it has not expired.

        NOTE(review): reaching into response.raw.connection.sock is
        urllib3-implementation-specific and assumes the connection is still
        open (requires stream=True) — confirm against the pinned urllib3
        version before relying on this in production.
        """
        cert = response.raw.connection.sock.getpeercert()
        if not cert:
            raise ForbiddenError("Invalid SSL certificate")

        # 'notAfter' is e.g. "Jun  1 12:00:00 2030 GMT"; convert to epoch secs.
        not_after = ssl.cert_time_to_seconds(cert['notAfter'])
        if datetime.fromtimestamp(not_after) < datetime.now():
            raise ForbiddenError("Expired SSL certificate")

    def _check_legal_headers(self, response: requests.Response) -> None:
        """Refuse content whose response headers signal access restrictions."""
        if response.headers.get('X-Robots-Tag', '').lower() == 'noindex':
            raise ForbiddenError("Access not allowed by robots directive")

        if 'X-Copyright' in response.headers:
            raise ForbiddenError("Content is copyright protected")

        if 'X-Terms-Of-Service' in response.headers:
            raise ForbiddenError("Terms of service acceptance required")

    def _validate_content_type(self, url: str, content_type: str) -> FileType:
        """Resolve the downloaded file's FileType.

        Tries the MIME type first (parameters such as ';charset=' stripped),
        then falls back to the URL path's extension.

        Raises:
            BadRequestError: neither the MIME type nor the extension is supported.
        """
        try:
            content_type = content_type.split(';')[0].lower()

            if content_type in MIME_TO_FILE_TYPE:
                return MIME_TO_FILE_TYPE[content_type]

            ext = Path(urllib.parse.urlparse(url).path).suffix.lower()
            if ext in EXT_TO_FILE_TYPE:
                return EXT_TO_FILE_TYPE[ext]

            raise BadRequestError(f"Unsupported file type: {content_type}")
        except BadRequestError:
            raise
        except Exception as e:
            raise BadRequestError(f"Error validating content type: {str(e)}")

    def download(self, url: str) -> RawFile:
        """Download `url` and return it as a RawFile.

        Binary types (pdf/xlsx/docx) are encoded as a base64 data URL; text
        types are decoded as UTF-8.

        Raises:
            BadRequestError / ForbiddenError / ConflictError: validation,
                policy, size and rate-limit failures (see helpers).
            SomethingWrongError: timeouts, transport errors, anything else.
        """
        try:
            self._validate_url(url)
            domain = urllib.parse.urlparse(url).netloc
            self._check_rate_limit(domain)

            response = self.session.get(
                url,
                timeout=DOWNLOAD_TIMEOUT,
                stream=True,  # keep the socket open so _validate_ssl can inspect it
                verify=True,
                allow_redirects=True,
                headers={
                    'User-Agent': 'HarmonyBot/1.0 (+https://harmonydata.ac.uk)',
                    'Accept': ', '.join(MIME_TO_FILE_TYPE.keys())
                }
            )
            response.raise_for_status()

            self._validate_ssl(response)
            self._check_legal_headers(response)

            content_length = response.headers.get('content-length')
            if content_length and int(content_length) > MAX_FILE_SIZE:
                raise ForbiddenError(f"File too large: {content_length} bytes (max {MAX_FILE_SIZE})")

            file_type = self._validate_content_type(url, response.headers.get('content-type', ''))

            hasher = hashlib.sha256()
            # Fix: bytearray gives O(n) accumulation; `bytes += chunk` was quadratic.
            buffer = bytearray()
            for chunk in response.iter_content(chunk_size=8192):
                buffer.extend(chunk)
                # Fix: enforce the size cap while streaming too — servers may
                # omit or understate the Content-Length header.
                if len(buffer) > MAX_FILE_SIZE:
                    raise ForbiddenError(f"File too large: exceeds {MAX_FILE_SIZE} bytes")
                hasher.update(chunk)
            content = bytes(buffer)

            if file_type in (FileType.pdf, FileType.xlsx, FileType.docx):
                # Fix: .get() instead of [] so a missing Content-Type header
                # cannot raise KeyError here.
                content_str = (f"data:{response.headers.get('content-type', '')};base64,"
                               + base64.b64encode(content).decode('ascii'))
            else:
                content_str = content.decode('utf-8')

            return RawFile(
                file_id=str(uuid.uuid4()),
                file_name=Path(urllib.parse.urlparse(url).path).name or "downloaded_file",
                file_type=file_type,
                content=content_str,
                metadata={
                    'content_hash': hasher.hexdigest(),
                    'download_timestamp': datetime.now().isoformat(),
                    'source_url': url
                }
            )

        except (BadRequestError, ForbiddenError, ConflictError):
            raise
        except requests.Timeout:
            raise SomethingWrongError("Download timeout")
        except requests.TooManyRedirects:
            raise ForbiddenError("Too many redirects")
        except requests.RequestException as e:
            # Map common HTTP status codes onto the project's error types.
            if e.response is not None:
                if e.response.status_code == 401:
                    raise ForbiddenError("Resource requires authentication")
                elif e.response.status_code == 403:
                    raise ForbiddenError("Access forbidden")
                elif e.response.status_code == 429:
                    raise ConflictError("Rate limit exceeded")
            raise SomethingWrongError(f"Download error: {str(e)}")
        except Exception as e:
            raise SomethingWrongError(f"Unexpected error: {str(e)}")
| 216 | + |
| 217 | + |
def load_instruments_from_url(url: str) -> List[Instrument]:
    """Fetch the file at `url` and convert it into Harmony instruments."""
    raw_file = URLDownloader().download(url)
    return convert_files_to_instruments([raw_file])
0 commit comments