Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions kag/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from typing import Any, Union
from jinja2 import Environment, FileSystemLoader, Template
from stat import S_IWUSR as OWNER_WRITE_PERMISSION
from tenacity import retry, stop_after_attempt
from tenacity import retry, stop_after_attempt, wait_exponential
from aiolimiter import AsyncLimiter

reset = "\033[0m"
Expand Down Expand Up @@ -279,16 +279,25 @@ def generate_hash_id(value):
return hasher.hexdigest()


@retry(stop=stop_after_attempt(3), reraise=True)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=2, max=10),
reraise=True,
)
def download_from_http(url: str, dest: str = None) -> str:
"""Downloads a file from an HTTP URL and saves it to a temporary directory.

This function uses the requests library to download a file from the specified
HTTP URL and saves it to the system's temporary directory. After the download
is complete, it returns the local path of the downloaded file.

The function includes retry logic with exponential backoff to handle transient
network errors and service unavailability (e.g., MinIO 503 errors).

Args:
url (str): The HTTP URL of the file to be downloaded.
dest (str, optional): The destination path for the downloaded file.
If not specified, a temporary file will be created.

Returns:
str: The local path of the downloaded file.
Expand Down
101 changes: 101 additions & 0 deletions tests/unit/common/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import os
import tempfile
from unittest.mock import patch, Mock
import requests
import pytest

from kag.common.utils import download_from_http


def test_download_from_http_success():
    """Test successful download from HTTP URL when no destination is given.

    NOTE: this test issues a live HTTP request, so it needs network access.
    """
    # Use a small test file from W3C
    url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
    result = download_from_http(url)

    try:
        assert os.path.exists(result)
        assert result.endswith("dummy.pdf")
        assert os.path.getsize(result) > 0
    finally:
        # Clean up even when one of the assertions above fails, so a broken
        # run does not leave the downloaded file behind.
        if os.path.exists(result):
            os.remove(result)


def test_download_from_http_with_dest():
    """Test download with specified destination.

    NOTE: this test issues a live HTTP request, so it needs network access.
    """
    url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"

    # A private temporary directory avoids clashing with other files (or
    # parallel test runs) in the shared system temp dir, and it is removed
    # automatically even when an assertion fails.
    with tempfile.TemporaryDirectory() as dest_dir:
        dest_path = os.path.join(dest_dir, "test_download.pdf")

        result = download_from_http(url, dest_path)

        assert result == dest_path
        assert os.path.exists(result)
        assert os.path.getsize(result) > 0


@patch("kag.common.utils.requests.get")
def test_download_from_http_retry_on_503(mock_get):
    """Test that download_from_http retries on 503 errors with exponential backoff.

    The first two mocked responses fail with an HTTPError and the third one
    succeeds, so the call is expected to succeed on the third attempt.
    """
    # Mock response for 503 error
    mock_response_503 = Mock()
    mock_response_503.status_code = 503
    mock_response_503.raise_for_status.side_effect = requests.exceptions.HTTPError(
        "503 Server Error: Service Unavailable"
    )

    # Mock successful response; return_value (instead of a lambda with a fixed
    # parameter name) accepts whatever arguments iter_content is called with.
    mock_response_success = Mock()
    mock_response_success.status_code = 200
    mock_response_success.iter_content.return_value = [b"test content"]

    # First two calls return 503, third call succeeds
    mock_get.side_effect = [mock_response_503, mock_response_503, mock_response_success]

    # tenacity's default sleep delegates to time.sleep via tenacity.nap; patch
    # it so the backoff between attempts does not slow the test down.
    with patch("tenacity.nap.time.sleep") as mock_sleep:
        # Should succeed after retries
        result = download_from_http("http://example.com/test.txt")

    try:
        # Verify it was called 3 times (2 failures + 1 success)
        assert mock_get.call_count == 3
        # A wait must have been requested between the failed attempts.
        assert mock_sleep.called
        assert os.path.exists(result)
    finally:
        # Clean up even when one of the assertions above fails.
        if os.path.exists(result):
            os.remove(result)


@patch("kag.common.utils.requests.get")
def test_download_from_http_max_retries_exceeded(mock_get):
    """Test that download_from_http raises error after max retries.

    Every mocked response fails with an HTTPError, so the final error is
    expected to propagate to the caller once all attempts are used up.
    """
    # Mock response that always returns 503
    mock_response = Mock()
    mock_response.status_code = 503
    mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
        "503 Server Error: Service Unavailable"
    )

    mock_get.return_value = mock_response

    # tenacity's default sleep delegates to time.sleep via tenacity.nap; patch
    # it so the backoff between attempts does not slow the test down.
    with patch("tenacity.nap.time.sleep"):
        # Should raise HTTPError after 3 attempts
        with pytest.raises(requests.exceptions.HTTPError):
            download_from_http("http://example.com/test.txt")

    # Verify it was called 3 times (max retries)
    assert mock_get.call_count == 3