Skip to content

Commit b6a6db2

Browse files
committed
feat: Enhance URL safety checks and integrate is_safe_url across image handling #3397
1 parent 7a4bf07 commit b6a6db2

File tree

4 files changed

+73
-30
lines changed

4 files changed

+73
-30
lines changed

etc/unittest/backend.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import unittest
44
import asyncio
5-
from unittest.mock import MagicMock
5+
import socket
6+
from unittest.mock import MagicMock, patch
67
from g4f.errors import MissingRequirementsError
78
try:
89
from g4f.gui.server.backend_api import Backend_Api
@@ -43,6 +44,24 @@ def test_get_providers(self):
4344
self.assertIsInstance(response, list)
4445
self.assertTrue(len(response) > 0)
4546

47+
@patch('g4f.gui.server.backend_api.socket.getaddrinfo')
48+
def test_is_safe_url_with_backslash_confusion(self, mock_getaddrinfo):
49+
mock_getaddrinfo.return_value = [(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('127.0.0.1', 0))]
50+
from g4f.gui.server.backend_api import _is_safe_url
51+
self.assertFalse(_is_safe_url('http://127.0.0.1:6666\\@www.baidu.com'))
52+
53+
@patch('g4f.gui.server.backend_api.socket.getaddrinfo')
54+
def test_is_safe_url_blocks_private(self, mock_getaddrinfo):
55+
mock_getaddrinfo.return_value = [(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('127.0.0.1', 0))]
56+
from g4f.gui.server.backend_api import _is_safe_url
57+
self.assertFalse(_is_safe_url('http://127.0.0.1'))
58+
59+
@patch('g4f.gui.server.backend_api.socket.getaddrinfo')
60+
def test_is_safe_url_allows_public(self, mock_getaddrinfo):
61+
mock_getaddrinfo.return_value = [(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('8.8.8.8', 0))]
62+
from g4f.gui.server.backend_api import _is_safe_url
63+
self.assertTrue(_is_safe_url('http://example.com'))
64+
4665
def test_search(self):
4766
if not has_search:
4867
self.skipTest("import error")

g4f/gui/server/backend_api.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@
1111
import shutil
1212
import random
1313
import datetime
14-
import ipaddress
15-
import socket
1614
from hashlib import sha256
17-
from urllib.parse import quote_plus, urlparse
15+
from urllib.parse import quote_plus
1816
from functools import lru_cache
1917
from flask import Flask, Response, redirect, request, jsonify, send_from_directory
2018
from werkzeug.exceptions import NotFound
@@ -46,7 +44,7 @@
4644
from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_tempfile
4745
from ...tools.run_tools import iter_run_tools
4846
from ...errors import ModelNotFoundError, ProviderNotFoundError, MissingAuthError, RateLimitError
49-
from ...image import is_allowed_extension, process_image, MEDIA_TYPE_MAP
47+
from ...image import is_allowed_extension, process_image, MEDIA_TYPE_MAP, is_safe_url as _is_safe_url
5048
from ...cookies import get_cookies_dir
5149
from ...image.copy_images import secure_filename, get_source_url, get_media_dir, copy_media
5250
from ...client.service import get_model_and_provider
@@ -60,30 +58,6 @@
6058

6159
_DATE_RE = re.compile(r'^\d{4}-\d{2}-\d{2}$')
6260

63-
def _is_safe_url(url: str) -> bool:
64-
"""Return True only for http/https URLs that do not point to private/loopback/reserved addresses."""
65-
try:
66-
parsed = urlparse(url)
67-
if parsed.scheme not in ("http", "https"):
68-
return False
69-
hostname = parsed.hostname
70-
if hostname is None:
71-
return False
72-
# Resolve all IP addresses for the hostname and reject if any is non-public.
73-
# Validating all addresses reduces the window for DNS rebinding attacks.
74-
addr_infos = socket.getaddrinfo(hostname, None)
75-
if not addr_infos:
76-
return False
77-
for addr_info in addr_infos:
78-
addr = ipaddress.ip_address(addr_info[4][0])
79-
if (addr.is_private or addr.is_loopback or addr.is_link_local
80-
or addr.is_reserved or addr.is_multicast or addr.is_unspecified):
81-
return False
82-
except Exception as e:
83-
logger.debug("URL safety check failed for %r: %s", url, e)
84-
return False
85-
return True
86-
8761
def safe_iter_generator(generator: Generator) -> Generator:
8862
start = next(generator)
8963
def iter_generator():

g4f/image/__init__.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,20 @@
44
import re
55
import io
66
import base64
7+
import socket
8+
import ipaddress
79
from io import BytesIO
810
from pathlib import Path
911
from typing import Optional
1012
from urllib.parse import urlparse
1113

1214
import requests
1315

16+
try:
17+
from urllib3.util import parse_url as urllib3_parse_url
18+
except ImportError:
19+
urllib3_parse_url = None
20+
1421
try:
1522
from PIL import Image, ImageOps
1623
has_requirements = True
@@ -103,6 +110,45 @@ def is_allowed_extension(filename: str) -> Optional[str]:
103110
return None
104111
return EXTENSIONS_MAP[extension]
105112

113+
114+
def is_safe_url(url: str) -> bool:
115+
"""Return True only for http/https URLs that do not point to private/loopback/reserved addresses."""
116+
try:
117+
parsed = urlparse(url)
118+
119+
if parsed.scheme not in ("http", "https"):
120+
return False
121+
122+
if "\\" in url:
123+
return False
124+
125+
hostname = parsed.hostname
126+
if hostname is None:
127+
return False
128+
129+
if urllib3_parse_url is not None:
130+
parsed_urllib3 = urllib3_parse_url(url)
131+
if parsed_urllib3.host and parsed_urllib3.host != hostname:
132+
return False
133+
hostname = parsed_urllib3.host or hostname
134+
135+
if hostname is None:
136+
return False
137+
138+
addr_infos = socket.getaddrinfo(hostname, None)
139+
if not addr_infos:
140+
return False
141+
142+
for addr_info in addr_infos:
143+
addr = ipaddress.ip_address(addr_info[4][0])
144+
if (addr.is_private or addr.is_loopback or addr.is_link_local
145+
or addr.is_reserved or addr.is_multicast or addr.is_unspecified):
146+
return False
147+
except Exception:
148+
return False
149+
return True
150+
151+
106152
def is_data_an_media(data, filename: str = None) -> str:
107153
content_type = is_data_an_audio(data, filename)
108154
if content_type is not None:
@@ -378,6 +424,8 @@ def to_bytes(image: ImageType) -> bytes:
378424
is_data_uri_an_image(image)
379425
return extract_data_uri(image)
380426
elif image.startswith("http://") or image.startswith("https://"):
427+
if not is_safe_url(image):
428+
raise ValueError("Invalid or disallowed media URL")
381429
path: str = urlparse(image).path
382430
if path.startswith("/files/"):
383431
path = get_bucket_dir(*path.split("/")[2:])

g4f/image/copy_images.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
1717
from ..tools.files import secure_filename
1818
from ..providers.response import ImageResponse, AudioResponse, VideoResponse, quote_url
19-
from . import is_accepted_format, extract_data_uri
19+
from . import is_accepted_format, extract_data_uri, is_safe_url
2020
from .. import debug
2121

2222
# Directory for storing generated media files
@@ -170,6 +170,8 @@ async def copy_image(image: str, target: str = None) -> str:
170170
with open(target_path, "wb") as f:
171171
f.write(extract_data_uri(image))
172172
elif not os.path.exists(target_path) or os.lstat(target_path).st_size <= 0:
173+
if not is_safe_url(image):
174+
raise ValueError("Invalid or disallowed media URL")
173175
# Use aiohttp to fetch the image
174176
async with session.get(image, ssl=ssl) as response:
175177
response.raise_for_status()

0 commit comments

Comments
 (0)