diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9c370bc0f8..410f2bb7c3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -93,22 +93,32 @@ Common contribution types include: `doc`, `code`, `bug`, and `ideas`. See the fu ## Development environment -We use the ["scripts to rule them all"](https://github.blog/engineering/engineering-principles/scripts-to-rule-them-all/) philosophy to manage common tasks across the project. These are mostly backed by a Makefile that contains the implementation. - You'll need the following dependencies installed to build Cog locally: -- [Go](https://golang.org/doc/install): We're targeting 1.24, but you can install the latest version since Go is backwards compatible. If you're using a newer Mac with an M1 chip, be sure to download the `darwin-arm64` installer package. Alternatively you can run `brew install go` which will automatically detect and use the appropriate installer for your system architecture. -- [uv](https://docs.astral.sh/uv/): Python versions and dependencies are managed by uv. + +- [Go](https://golang.org/doc/install): We're targeting 1.23, but you can install the latest version since Go is backwards compatible. If you're using a newer Mac with an M1 chip, be sure to download the `darwin-arm64` installer package. Alternatively you can run `brew install go` which will automatically detect and use the appropriate installer for your system architecture. +- [uv](https://docs.astral.sh/uv/): Python versions and dependencies are managed by uv, both in development and container environments. - [Docker](https://docs.docker.com/desktop) or [OrbStack](https://orbstack.dev) Install the Python dependencies: script/setup -Once you have Go installed you can install the cog binary by running: +Once you have Go installed, run: + + make install + +This will build and install the `cog` binary to `/usr/local/bin/cog`. You can then use it to build and run models. 
+ +## Package Management + +Cog uses [uv](https://docs.astral.sh/uv/) for Python package management, both in development and container environments. This provides: - make install PREFIX=$(go env GOPATH) +- Fast, reliable package installation +- Consistent dependency resolution +- Efficient caching +- Reproducible builds -This installs the `cog` binary to `$GOPATH/bin/cog`. +When building containers, uv is automatically installed and used to install Python packages from requirements.txt files. The cache is mounted at `/srv/r8/uv/cache` to speed up subsequent builds. To run ALL the tests: diff --git a/pkg/dockerfile/standard_generator.go b/pkg/dockerfile/standard_generator.go index 546200cf08..2a68e6c30a 100644 --- a/pkg/dockerfile/standard_generator.go +++ b/pkg/dockerfile/standard_generator.go @@ -92,7 +92,7 @@ func NewStandardGenerator(config *config.Config, dir string, command command.Com Config: config, Dir: dir, GOOS: runtime.GOOS, - GOARCH: runtime.GOOS, + GOARCH: runtime.GOARCH, tmpDir: tmpDir, relativeTmpDir: relativeTmpDir, fileWalker: filepath.Walk, @@ -414,18 +414,17 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq & git \ ca-certificates \ && rm -rf /var/lib/apt/lists/* -` + fmt.Sprintf(` + +ENV UV_CACHE_DIR="/srv/r8/uv/cache" RUN --mount=type=cache,target=/root/.cache/pip curl -s -S -L https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer | bash && \ git clone https://github.com/momo-lab/pyenv-install-latest.git "$(pyenv root)"/plugins/pyenv-install-latest && \ export PYTHON_CONFIGURE_OPTS='--enable-optimizations --with-lto' && \ export PYTHON_CFLAGS='-O3' && \ - pyenv install-latest "%s" && \ - pyenv global $(pyenv install-latest --print "%s") && \ - pip install "wheel<1"`, py, py) + ` + pyenv install-latest "` + py + `" && \ + pyenv global $(pyenv install-latest --print "` + py + `") && \ + curl -LsSf https://astral.sh/uv/install.sh | sh + RUN rm -rf /usr/bin/python3 && ln -s ` + "`realpath 
\\`pyenv which python\\`` /usr/bin/python3 && chmod +x /usr/bin/python3", nil - // for sitePackagesLocation, kind of need to determine which specific version latest is (3.8 -> 3.8.17 or 3.8.18) - // install-latest essentially does pyenv install --list | grep $py | tail -1 - // there are many bad options, but a symlink to $(pyenv prefix) is the least bad one } func (g *StandardGenerator) installCog() (string, error) { @@ -451,7 +450,7 @@ func (g *StandardGenerator) installCog() (string, error) { cmds := []string{ "ENV R8_COG_VERSION=coglet", "ENV R8_PYTHON_VERSION=" + g.Config.Build.PythonVersion, - "RUN pip install " + m.LatestCoglet.URL, + "RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install " + m.LatestCoglet.URL, } return strings.Join(cmds, "\n"), nil } @@ -469,13 +468,13 @@ func (g *StandardGenerator) installCog() (string, error) { if err != nil { return "", err } - pipInstallLine := "RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir" - pipInstallLine += " " + containerPath - pipInstallLine += " 'pydantic>=1.9,<3'" + uvInstallLine := "RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install --no-cache-dir" + uvInstallLine += " " + containerPath + uvInstallLine += " 'pydantic>=1.9,<3'" if g.strip { - pipInstallLine += " && " + StripDebugSymbolsCommand + uvInstallLine += " && " + StripDebugSymbolsCommand } - lines = append(lines, CFlags, pipInstallLine, "ENV CFLAGS=") + lines = append(lines, CFlags, uvInstallLine, "ENV CFLAGS=") return strings.Join(lines, "\n"), nil } @@ -509,14 +508,14 @@ func (g *StandardGenerator) pipInstalls() (string, error) { return "", err } - pipInstallLine := "RUN --mount=type=cache,target=/root/.cache/pip pip install -r " + containerPath + uvInstallLine := "RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install -r " + containerPath if g.strip { - pipInstallLine += " && " + StripDebugSymbolsCommand + uvInstallLine += " && " + StripDebugSymbolsCommand } 
return strings.Join([]string{ copyLine[0], CFlags, - pipInstallLine, + uvInstallLine, "ENV CFLAGS=", }, "\n"), nil } diff --git a/pkg/dockerfile/standard_generator_test.go b/pkg/dockerfile/standard_generator_test.go index c73662a6d6..3ee797d049 100644 --- a/pkg/dockerfile/standard_generator_test.go +++ b/pkg/dockerfile/standard_generator_test.go @@ -47,7 +47,7 @@ func testInstallCog(relativeTmpDir string, stripped bool) string { } return fmt.Sprintf(`COPY %s/%s /tmp/%s ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S" -RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir /tmp/%s 'pydantic>=1.9,<3'%s +RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install --no-cache-dir /tmp/%s 'pydantic>=1.9,<3'%s ENV CFLAGS=`, relativeTmpDir, wheel, wheel, wheel, strippedCall) } @@ -73,13 +73,15 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq & git \ ca-certificates \ && rm -rf /var/lib/apt/lists/* + +ENV UV_CACHE_DIR="/srv/r8/uv/cache" RUN --mount=type=cache,target=/root/.cache/pip curl -s -S -L https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer | bash && \ git clone https://github.com/momo-lab/pyenv-install-latest.git "$(pyenv root)"/plugins/pyenv-install-latest && \ export PYTHON_CONFIGURE_OPTS='--enable-optimizations --with-lto' && \ export PYTHON_CFLAGS='-O3' && \ pyenv install-latest "%s" && \ pyenv global $(pyenv install-latest --print "%s") && \ - pip install "wheel<1" + curl -LsSf https://astral.sh/uv/install.sh | sh `, version, version) } @@ -414,7 +416,7 @@ ENV NVIDIA_DRIVER_CAPABILITIES=all ` + testInstallPython("3.12") + `RUN rm -rf /usr/bin/python3 && ln -s ` + "`realpath \\`pyenv which python\\`` /usr/bin/python3 && chmod +x /usr/bin/python3" + ` COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S" -RUN --mount=type=cache,target=/root/.cache/pip pip install -r 
/tmp/requirements.txt +RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install -r /tmp/requirements.txt ENV CFLAGS= ` + testInstallCog(gen.relativeTmpDir, gen.strip) + ` RUN find / -type f -name "*python*.so" -printf "%h\n" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig @@ -898,3 +900,56 @@ torch==2.3.1 pandas==2.0.3 coglet @ https://github.com/replicate/cog-runtime/releases/download/v0.1.0-alpha31/coglet-0.1.0a31-py3-none-any.whl`, string(requirements)) } + +func TestGenerateDockerfileStripped(t *testing.T) { + tmpDir := t.TempDir() + + conf, err := config.FromYAML([]byte(` +build: + gpu: true + cuda: "11.8" + python_version: "3.12" + system_packages: + - ffmpeg + - cowsay + python_packages: + - torch==2.3.1 + - pandas==2.0.3 + run: + - "cowsay moo" +predict: predict.py:Predictor +`)) + require.NoError(t, err) + require.NoError(t, conf.ValidateAndComplete("")) + command := dockertest.NewMockCommand() + client := registrytest.NewMockRegistryClient() + gen, err := NewStandardGenerator(conf, tmpDir, command, client, true) + require.NoError(t, err) + gen.SetUseCogBaseImage(true) + gen.SetStrip(true) + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") + require.NoError(t, err) + + expected := `#syntax=docker/dockerfile:1.4 +FROM r8.im/replicate/cog-test-weights AS weights +FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3.1 +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/* +COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt +ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S" +RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install -r /tmp/requirements.txt && find / -type f -name "*python*.so" -not -name "*cpython*.so" -exec strip -S {} \; +ENV CFLAGS= +RUN find / -type f -name "*.py[co]" -delete && find / -type f -name "*.py" -exec touch -t 197001010000 {} 
\; && find / -type f -name "*.py" -printf "%h\n" | sort -u | /usr/bin/python3 -m compileall --invalidation-mode timestamp -o 2 -j 0 +RUN cowsay moo +WORKDIR /src +EXPOSE 5000 +CMD ["python", "-m", "cog.server.http"] +COPY . /src` + + require.Equal(t, expected, actual) + + requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt")) + require.NoError(t, err) + require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cu118 +torch==2.3.1 +pandas==2.0.3`, string(requirements)) +} diff --git a/python/cog/server/webhook.py b/python/cog/server/webhook.py index 1aca58ae41..9e71d5bf4d 100644 --- a/python/cog/server/webhook.py +++ b/python/cog/server/webhook.py @@ -1,4 +1,5 @@ import os +from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Set import requests @@ -16,6 +17,10 @@ log = structlog.get_logger(__name__) _response_interval = float(os.environ.get("COG_THROTTLE_RESPONSE_INTERVAL", 0.5)) +_webhook_timeout = float( + os.environ.get("COG_WEBHOOK_TIMEOUT", 10.0) +) # 10 second timeout by default +_webhook_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="webhook") # HACK: signal that we should skip the start webhook when the response interval # is tuned below 100ms. 
This should help us get output sooner for models that @@ -27,11 +32,40 @@ def webhook_caller_filtered( webhook: str, webhook_events_filter: Set[WebhookEvent], ) -> Callable[[Any, WebhookEvent], None]: - upstream_caller = webhook_caller(webhook) + # Create a session for this webhook + default_session = requests_session() + retry_session = requests_session_with_retries() + throttler = ResponseThrottler(response_interval=_response_interval) + + def _send_webhook(response: PredictionResponse, session: requests.Session) -> None: + if PYDANTIC_V2: + dict_response = jsonable_encoder(response.model_dump(exclude_unset=True)) + else: + dict_response = jsonable_encoder(response.dict(exclude_unset=True)) + + try: + session.post(webhook, json=dict_response, timeout=_webhook_timeout) + except requests.exceptions.Timeout: + log.warn("webhook request timed out", webhook=webhook) + except requests.exceptions.RequestException: + log.warn("caught exception while sending webhook", exc_info=True) def caller(response: PredictionResponse, event: WebhookEvent) -> None: - if event in webhook_events_filter: - upstream_caller(response) + if event not in webhook_events_filter: + return + + if not throttler.should_send_response(response): + return + + # Use a separate thread for webhook calls to avoid blocking + if Status.is_terminal(response.status): + # For terminal updates, retry persistently but in background + _webhook_executor.submit(_send_webhook, response, retry_session) + else: + # For other requests, don't retry, and ignore any errors + _webhook_executor.submit(_send_webhook, response, default_session) + + throttler.update_last_sent_response_time() return caller @@ -44,24 +78,32 @@ def webhook_caller(webhook: str) -> Callable[[Any], None]: default_session = requests_session() retry_session = requests_session_with_retries() + def _send_webhook(response: PredictionResponse, session: requests.Session) -> None: + if PYDANTIC_V2: + dict_response = 
jsonable_encoder(response.model_dump(exclude_unset=True)) + else: + dict_response = jsonable_encoder(response.dict(exclude_unset=True)) + + try: + session.post(webhook, json=dict_response, timeout=_webhook_timeout) + except requests.exceptions.Timeout: + log.warn("webhook request timed out", webhook=webhook) + except requests.exceptions.RequestException: + log.warn("caught exception while sending webhook", exc_info=True) + def caller(response: PredictionResponse) -> None: - if throttler.should_send_response(response): - if PYDANTIC_V2: - dict_response = jsonable_encoder( - response.model_dump(exclude_unset=True) - ) - else: - dict_response = jsonable_encoder(response.dict(exclude_unset=True)) - if Status.is_terminal(response.status): - # For terminal updates, retry persistently - retry_session.post(webhook, json=dict_response) - else: - # For other requests, don't retry, and ignore any errors - try: - default_session.post(webhook, json=dict_response) - except requests.exceptions.RequestException: - log.warn("caught exception while sending webhook", exc_info=True) - throttler.update_last_sent_response_time() + if not throttler.should_send_response(response): + return + + # Use a separate thread for webhook calls to avoid blocking + if Status.is_terminal(response.status): + # For terminal updates, retry persistently but in background + _webhook_executor.submit(_send_webhook, response, retry_session) + else: + # For other requests, don't retry, and ignore any errors + _webhook_executor.submit(_send_webhook, response, default_session) + + throttler.update_last_sent_response_time() return caller @@ -84,13 +126,12 @@ def requests_session() -> requests.Session: def requests_session_with_retries() -> requests.Session: - # This session will retry requests up to 12 times, with exponential - # backoff. In total it'll try for up to roughly 320 seconds, providing - # resilience through temporary networking and availability issues. 
+ # This session will retry requests up to 6 times (reduced from 12), with exponential + # backoff. The backoff delays sum to roughly 6 seconds (reduced from 320s). session = requests_session() adapter = HTTPAdapter( max_retries=Retry( - total=12, + total=6, # Reduced from 12 to avoid blocking too long backoff_factor=0.1, status_forcelist=[429, 500, 502, 503, 504], allowed_methods=["POST"], diff --git a/python/tests/server/test_webhook.py b/python/tests/server/test_webhook.py index 8031d52d8d..703bc1dd3e 100644 --- a/python/tests/server/test_webhook.py +++ b/python/tests/server/test_webhook.py @@ -1,155 +1,412 @@ -import requests +import threading +import time +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any, Dict, Optional, Tuple +from unittest.mock import patch + import responses -from responses import registries from cog.schema import PredictionResponse, Status, WebhookEvent from cog.server.webhook import webhook_caller, webhook_caller_filtered -@responses.activate -def test_webhook_caller_basic(): - c = webhook_caller("https://example.com/webhook/123") - - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) - - responses.post( - "https://example.com/webhook/123", - json=payload, - status=200, +class SlowHandler(BaseHTTPRequestHandler): + def do_POST(self): + time.sleep(2) # Simulate slow response + self.send_response(200) + self.end_headers() + + +class ErrorHandler(BaseHTTPRequestHandler): + def do_POST(self): + self.send_response(500) + self.end_headers() + + +class UnreachableHandler(BaseHTTPRequestHandler): + """Handler that simulates connection refused""" + + def do_POST(self): + # Close connection immediately to simulate connection refused + self.wfile.close() + + +def make_prediction_response( + status: Status, output: Optional[Dict[str, Any]] = None +) -> PredictionResponse: + return PredictionResponse( + status=status, +
input={}, # Required field + output=output or {}, ) - c(response) + +def wait_for_webhook_calls(expected_count: int, timeout: float = 2.0) -> None: + """Wait for the expected number of webhook calls to complete""" + start_time = time.time() + while time.time() - start_time < timeout: + # Check if all webhook threads are done + active_threads = [ + t for t in threading.enumerate() if t.name.startswith("webhook") + ] + if len(active_threads) == 0: + break + time.sleep(0.1) @responses.activate -def test_webhook_caller_non_terminal_does_not_retry(): - c = webhook_caller("https://example.com/webhook/123") - - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) - - responses.post( - "https://example.com/webhook/123", - json=payload, - status=429, +def test_webhook_timeout(): + """Test that webhook calls timeout properly and don't block indefinitely""" + # Set a very short timeout for testing + with patch.dict("os.environ", {"COG_WEBHOOK_TIMEOUT": "0.5"}): + responses.add( + responses.POST, + "http://example.com/webhook", + body=lambda request: time.sleep(2) or "OK", # type: ignore # Sleep longer than timeout + status=200, + ) + + prediction = make_prediction_response(Status.SUCCEEDED) + start_time = time.time() + + caller = webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1, timeout=3.0) + + elapsed_time = time.time() - start_time + # Should timeout quickly (within 2 seconds including overhead) + assert elapsed_time < 2.0, f"Webhook call took too long: {elapsed_time}s" + + +@responses.activate +def test_webhook_error_handling(): + """Test that webhook calls handle HTTP errors gracefully""" + responses.add( + responses.POST, + "http://example.com/webhook", + status=500, + ) + + prediction = make_prediction_response(Status.SUCCEEDED) + + # Should not raise an exception + caller = 
webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1) - c(response) + assert len(responses.calls) == 1 -@responses.activate(registry=registries.OrderedRegistry) -def test_webhook_caller_terminal_retries(): - c = webhook_caller("https://example.com/webhook/123") - resps = [] +def test_webhook_connection_refused(): + """Test webhook behavior when connection is refused (simulating service down)""" + # Use a port that's guaranteed to be closed + webhook_url = "http://127.0.0.1:65432/webhook" # Unlikely to be in use - payload = {"status": Status.SUCCEEDED, "output": {"animal": "giraffe"}, "input": {}} - response = PredictionResponse(**payload) + prediction = make_prediction_response(Status.SUCCEEDED) + start_time = time.time() - for _ in range(2): - resps.append( - responses.post( - "https://example.com/webhook/123", - json=payload, - status=429, - ) - ) - resps.append( - responses.post( - "https://example.com/webhook/123", - json=payload, - status=200, - ) + # Should not raise an exception or block indefinitely + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1, timeout=5.0) + + elapsed_time = time.time() - start_time + # Should fail quickly due to connection refused + assert elapsed_time < 15.0, ( + f"Connection refused handling took too long: {elapsed_time}s" ) - c(response) - assert all(r.call_count == 1 for r in resps) +@responses.activate +def test_webhook_retry_behavior(): + """Test that webhook retries work correctly for terminal status""" + call_count = 0 + + def callback(request: Any) -> Tuple[int, Dict[str, str], str]: + nonlocal call_count + call_count += 1 + if call_count < 3: # Fail first 2 attempts + return (500, {}, "Server Error") + return (200, {}, "OK") + + responses.add_callback( + responses.POST, + "http://example.com/webhook", + callback=callback, + ) + 
+ prediction = make_prediction_response(Status.SUCCEEDED) + + caller = webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1, timeout=10.0) + + # Should have retried and eventually succeeded + assert call_count == 3 + assert len(responses.calls) == 3 @responses.activate -def test_webhook_includes_user_agent(): - c = webhook_caller("https://example.com/webhook/123") - - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) - - responses.post( - "https://example.com/webhook/123", - json=payload, +def test_webhook_filtered(): + """Test that webhook_caller_filtered only sends webhooks for specified events""" + responses.add( + responses.POST, + "http://example.com/webhook", status=200, ) - c(response) + prediction = make_prediction_response(Status.SUCCEEDED) + + # Should send webhook for COMPLETED event + caller = webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1) assert len(responses.calls) == 1 - user_agent = responses.calls[0].request.headers["user-agent"] - assert user_agent.startswith("cog-worker/") + # Reset responses + responses.reset() + responses.add( + responses.POST, + "http://example.com/webhook", + status=200, + ) + + # Should NOT send webhook for START event when only COMPLETED is in filter + caller(prediction, WebhookEvent.START) + wait_for_webhook_calls(1) -@responses.activate -def test_webhook_caller_filtered_basic(): - events = WebhookEvent.default_events() - c = webhook_caller_filtered("https://example.com/webhook/123", events) + assert len(responses.calls) == 0 + + +def test_webhook_max_retry_limit(): + """Test that webhooks don't retry indefinitely""" + # Create a server that always returns 500 + server = HTTPServer(("localhost", 0), ErrorHandler) + thread 
= threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + try: + webhook_url = f"http://localhost:{server.server_port}/webhook" + prediction = make_prediction_response(Status.SUCCEEDED) - payload = {"status": Status.PROCESSING, "animal": "giraffe", "input": {}} - response = PredictionResponse(**payload) + start_time = time.time() + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1, timeout=70.0) # Max ~60s for 6 retries + elapsed_time = time.time() - start_time - responses.post( - "https://example.com/webhook/123", - json=payload, + # Should stop retrying after max attempts (~60s with exponential backoff) + assert elapsed_time < 70.0, f"Webhook retries took too long: {elapsed_time}s" + + finally: + server.shutdown() + server.server_close() + thread.join(timeout=1.0) + + +def test_webhook_background_execution(): + """Test that webhooks execute in background threads and don't block main thread""" + # Create a slow server + server = HTTPServer(("localhost", 0), SlowHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + try: + webhook_url = f"http://localhost:{server.server_port}/webhook" + prediction = make_prediction_response(Status.SUCCEEDED) + + start_time = time.time() + + # Make multiple webhook calls + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + for _ in range(3): + caller(prediction, WebhookEvent.COMPLETED) + + # Should return immediately (not wait for webhooks to complete) + immediate_time = time.time() - start_time + assert immediate_time < 0.5, ( + f"Webhook calls blocked main thread: {immediate_time}s" + ) + + # Wait for all webhooks to complete + wait_for_webhook_calls(3, timeout=10.0) + + finally: + server.shutdown() + server.server_close() + thread.join(timeout=1.0) + + +@responses.activate +def test_webhook_user_agent(): + """Test that webhook 
calls include correct user agent""" + responses.add( + responses.POST, + "http://example.com/webhook", status=200, ) - c(response, WebhookEvent.LOGS) + prediction = make_prediction_response(Status.SUCCEEDED) + + caller = webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1) + assert len(responses.calls) == 1 + request = responses.calls[0].request + assert "cog-worker/" in request.headers.get("User-Agent", "") -@responses.activate -def test_webhook_caller_filtered_omits_filtered_events(): - events = {WebhookEvent.COMPLETED} - c = webhook_caller_filtered("https://example.com/webhook/123", events) - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) +def test_webhook_original_bug_scenario(): + """ + Test the original bug scenario: webhook service down causes prediction to get stuck + This test verifies that our fix prevents the issue + """ + # Simulate webhook service being completely down (connection refused) + webhook_url = "http://127.0.0.1:65433/webhook" # Port guaranteed to be closed + + prediction = make_prediction_response(Status.SUCCEEDED) + + # Record start time + start_time = time.time() + + # This should NOT block indefinitely or cause the prediction to get stuck + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + caller(prediction, WebhookEvent.COMPLETED) + + # Wait for webhook call to complete (should fail after retries) + wait_for_webhook_calls(1, timeout=20.0) + + elapsed_time = time.time() - start_time + + # The fix should ensure this completes within reasonable time + # Original bug would cause this to hang for 320+ seconds (5+ minutes) + # With our fix, it should fail within ~15-20 seconds (6 retries with exponential backoff) + # This proves the webhook failures don't block the main thread indefinitely + assert elapsed_time < 25.0, ( + 
f"Webhook failure handling took too long: {elapsed_time}s" + ) + assert elapsed_time > 10.0, ( + f"Webhook should have attempted retries, took only: {elapsed_time}s" + ) + + # Verify that the prediction status would not be stuck in "BUSY" + # (In real usage, the runner would have updated status before webhook call) + assert prediction.status == Status.SUCCEEDED + + +def test_webhook_cancellation_during_failure(): + """ + Test that webhook failures don't prevent cancellation + This simulates the scenario where a prediction needs to be cancelled + while webhook calls are failing + """ + + # Create a server that's very slow to respond + class VerySlowHandler(BaseHTTPRequestHandler): + def do_POST(self): + time.sleep(5) # Very slow response + self.send_response(200) + self.end_headers() + + server = HTTPServer(("localhost", 0), VerySlowHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + try: + webhook_url = f"http://localhost:{server.server_port}/webhook" + + # Start a webhook call that will be slow + prediction = make_prediction_response(Status.PROCESSING) + caller = webhook_caller_filtered( + webhook_url, {WebhookEvent.START, WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.START) + + # Immediately try to "cancel" by updating status + # This should not be blocked by the ongoing webhook call + start_time = time.time() + prediction.status = Status.CANCELED + + # In real usage, this would trigger another webhook call for cancellation + caller(prediction, WebhookEvent.COMPLETED) + + immediate_time = time.time() - start_time + + # Cancellation should be immediate, not blocked by slow webhook + assert immediate_time < 1.0, ( + f"Cancellation was blocked by webhook: {immediate_time}s" + ) + + # Clean up - wait for webhooks to complete or timeout + wait_for_webhook_calls(2, timeout=15.0) + + finally: + server.shutdown() + server.server_close() + thread.join(timeout=1.0) + - c(response, WebhookEvent.LOGS) 
+def test_webhook_thread_pool_limits(): + """Test that webhook thread pool doesn't create unlimited threads""" + initial_thread_count = threading.active_count() + + # Create many webhook calls simultaneously + prediction = make_prediction_response(Status.SUCCEEDED) + + # Use a non-existent URL to make calls fail quickly + webhook_url = "http://127.0.0.1:65434/webhook" + + # Make many concurrent webhook calls + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + for _ in range(20): + caller(prediction, WebhookEvent.COMPLETED) + + # Check thread count hasn't exploded + peak_thread_count = threading.active_count() + thread_increase = peak_thread_count - initial_thread_count + + # Should not create more than the thread pool limit (4) + some overhead + assert thread_increase < 10, f"Too many threads created: {thread_increase}" + + # Wait for all webhooks to complete + wait_for_webhook_calls(20, timeout=10.0) + + # Thread count should return to normal + final_thread_count = threading.active_count() + assert ( + final_thread_count <= initial_thread_count + 4 + ) # Allow for thread pool threads @responses.activate -def test_webhook_caller_connection_errors(): - connerror_resp = responses.Response( +def test_webhook_caller_basic(): + """Test basic webhook_caller functionality (without events)""" + responses.add( responses.POST, - "https://example.com/webhook/123", + "http://example.com/webhook", status=200, ) - connerror_exc = requests.ConnectionError("failed to connect") - connerror_exc.response = connerror_resp - connerror_resp.body = connerror_exc - responses.add(connerror_resp) - - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) - - c = webhook_caller("https://example.com/webhook/123") - # this should not raise an error - c(response) + + prediction = make_prediction_response(Status.SUCCEEDED) + + # webhook_caller doesn't use events, just sends the response + 
caller = webhook_caller("http://example.com/webhook") + caller(prediction) + wait_for_webhook_calls(1) + + assert len(responses.calls) == 1