diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py
index 7423d66..ab6d4de 100644
--- a/cog_safe_push/main.py
+++ b/cog_safe_push/main.py
@@ -365,6 +365,7 @@ def cog_safe_push(
         raise
 
     tasks = []
+    prediction_index = 1
 
     if model_has_versions:
         log.info("Checking schema backwards compatibility")
@@ -388,8 +389,10 @@
                 fuzz_fixed_inputs=fuzz_fixed_inputs,
                 fuzz_disabled_inputs=fuzz_disabled_inputs,
                 fuzz_prompt=fuzz_prompt,
+                prediction_index=prediction_index,
             )
         )
+        prediction_index += 1
 
     if test_cases:
         for inputs, checker in test_cases:
@@ -399,8 +402,10 @@
                     inputs=inputs,
                     checker=checker,
                     predict_timeout=predict_timeout,
+                    prediction_index=prediction_index,
                 )
             )
+            prediction_index += 1
 
     if fuzz_iterations > 0:
         fuzz_inputs_queue = Queue(maxsize=fuzz_iterations)
@@ -420,8 +425,10 @@
                     context=task_context,
                     inputs_queue=fuzz_inputs_queue,
                     predict_timeout=predict_timeout,
+                    prediction_index=prediction_index,
                 )
            )
+            prediction_index += 1
 
     asyncio.run(run_tasks(tasks, parallel=parallel))
 
@@ -443,14 +450,18 @@ async def run_tasks(tasks: list[Task], parallel: int) -> None:
     log.info(f"Running tasks with parallelism {parallel}")
 
     semaphore = asyncio.Semaphore(parallel)
-    errors: list[Exception] = []
+    errors: list[tuple[Exception, int | None]] = []
 
     async def run_with_semaphore(task: Task) -> None:
         async with semaphore:
             try:
                 await task.run()
             except Exception as e:
-                errors.append(e)
+                # Get prediction index if the task has one
+                prediction_index = getattr(task, "prediction_index", None)
+                errors.append((e, prediction_index))
+                prefix = "" if prediction_index is None else f"[{prediction_index}] "
+                log.error(f"{prefix}{e}")
 
     # Create task coroutines and run them concurrently
     task_coroutines = [run_with_semaphore(task) for task in tasks]
@@ -459,11 +470,13 @@ async def run_with_semaphore(task: Task) -> None:
     await asyncio.gather(*task_coroutines, return_exceptions=True)
 
     if errors:
-        # If there are multiple errors, we'll raise the first one
-        # but log all of them
-        for error in errors[1:]:
-            log.error(f"Additional error occurred: {error}")
-        raise errors[0]
+        # Display all errors with their prediction indices
+        log.error(f"💥 Tests finished with {len(errors)} error(s):")
+        for error, prediction_index in errors:
+            prefix = "" if prediction_index is None else f"[{prediction_index}] "
+            log.error(f"* {prefix}{error}")
+
+        raise TaskExecutionError(f"Encountered {len(errors)} task error(s).", errors)
 
 
 def parse_inputs(inputs_list: list[str]) -> dict[str, Any]:
diff --git a/cog_safe_push/predict.py b/cog_safe_push/predict.py
index 8d5525a..3101e12 100644
--- a/cog_safe_push/predict.py
+++ b/cog_safe_push/predict.py
@@ -231,9 +231,11 @@ async def predict(
     train_destination: Model | None,
     inputs: dict,
     timeout_seconds: float,
+    prediction_index: int | None = None,
 ) -> tuple[Any | None, str | None]:
+    prefix = f"[{prediction_index}] " if prediction_index is not None else ""
     log.vv(
-        f"Running {'training' if train else 'prediction'} with inputs:\n{json.dumps(inputs, indent=2)}"
+        f"{prefix}Running {'training' if train else 'prediction'} with inputs:\n{json.dumps(inputs, indent=2)}"
     )
 
     start_time = time.time()
@@ -261,7 +263,7 @@ async def predict(
         else:
             raise
 
-    log.v(f"Prediction URL: https://replicate.com/p/{prediction.id}")
+    log.v(f"{prefix}Prediction URL: https://replicate.com/p/{prediction.id}")
 
     while prediction.status not in ["succeeded", "failed", "canceled"]:
         await asyncio.sleep(0.5)
@@ -272,13 +274,13 @@ async def predict(
     duration = time.time() - start_time
 
     if prediction.status == "failed":
-        log.v(f"Got error: {prediction.error} ({duration:.2f} sec)")
+        log.v(f"{prefix}Got error: {prediction.error} ({duration:.2f} sec)")
         return None, prediction.error
 
     output = prediction.output
     if _has_output_iterator_array_type(version):
         output = "".join(cast("list[str]", output))
-    log.v(f"Got output: {truncate(output)} ({duration:.2f} sec)")
+    log.v(f"{prefix}Got output: {truncate(output)} ({duration:.2f} sec)")
 
     return output, None
 
diff --git a/cog_safe_push/tasks.py b/cog_safe_push/tasks.py
index 924fe48..af0ff32 100644
--- a/cog_safe_push/tasks.py
+++ b/cog_safe_push/tasks.py
@@ -27,6 +27,7 @@ class CheckOutputsMatch(Task):
     fuzz_fixed_inputs: dict[str, Any]
     fuzz_disabled_inputs: list[str]
     fuzz_prompt: str | None
+    prediction_index: int | None = None
 
     async def run(self) -> None:
         if self.first_test_case_inputs is not None:
@@ -50,8 +51,11 @@ async def run(self) -> None:
                 fuzz_prompt=self.fuzz_prompt,
             )
 
+        prefix = (
+            f"[{self.prediction_index}] " if self.prediction_index is not None else ""
+        )
         log.v(
-            f"Checking outputs match between existing version and test version, with inputs: {inputs}"
+            f"{prefix}Checking outputs match between existing version and test version, with inputs: {inputs}"
         )
         test_output, test_error = await predict(
             model=self.context.test_model,
@@ -59,6 +63,7 @@ async def run(self) -> None:
             train_destination=self.context.train_destination,
             inputs=inputs,
             timeout_seconds=self.timeout_seconds,
+            prediction_index=self.prediction_index,
         )
         output, error = await predict(
             model=self.context.model,
@@ -66,21 +71,22 @@ async def run(self) -> None:
             train_destination=self.context.train_destination,
             inputs=inputs,
             timeout_seconds=self.timeout_seconds,
+            prediction_index=self.prediction_index,
         )
 
         if test_error is not None:
             raise OutputsDontMatchError(
-                f"Existing version raised an error: {test_error}"
+                f"{prefix}Existing version raised an error: {test_error}"
             )
         if error is not None:
-            raise OutputsDontMatchError(f"New version raised an error: {error}")
+            raise OutputsDontMatchError(f"{prefix}New version raised an error: {error}")
 
         matches, match_error = await outputs_match(
             test_output, output, is_deterministic
         )
         if not matches:
             raise OutputsDontMatchError(
-                f"Outputs don't match:\n\ntest output:\n{test_output}\n\nmodel output:\n{output}\n\n{match_error}"
+                f"{prefix}Outputs don't match:\n\ntest output:\n{test_output}\n\nmodel output:\n{output}\n\n{match_error}"
             )
 
 
@@ -90,15 +96,20 @@ class RunTestCase(Task):
     inputs: dict[str, Any]
     checker: OutputChecker
     predict_timeout: int
+    prediction_index: int | None = None
 
     async def run(self) -> None:
-        log.v(f"Running test case with inputs: {self.inputs}")
+        prefix = (
+            f"[{self.prediction_index}] " if self.prediction_index is not None else ""
+        )
+        log.v(f"{prefix}Running test case with inputs: {self.inputs}")
         output, error = await predict(
             model=self.context.test_model,
             train=self.context.is_train(),
             train_destination=self.context.train_destination,
             inputs=self.inputs,
             timeout_seconds=self.predict_timeout,
+            prediction_index=self.prediction_index,
         )
         await self.checker(output, error)
 
@@ -138,11 +149,15 @@ class FuzzModel(Task):
     context: TaskContext
     inputs_queue: Queue[dict[str, Any]]
     predict_timeout: int
+    prediction_index: int | None = None
 
     async def run(self) -> None:
         inputs = await asyncio.wait_for(self.inputs_queue.get(), timeout=60)
 
-        log.v(f"Fuzzing with inputs: {inputs}")
+        prefix = (
+            f"[{self.prediction_index}] " if self.prediction_index is not None else ""
+        )
+        log.v(f"{prefix}Fuzzing with inputs: {inputs}")
         try:
             output, error = await predict(
                 model=self.context.test_model,
@@ -150,13 +165,14 @@ async def run(self) -> None:
                 train_destination=self.context.train_destination,
                 inputs=inputs,
                 timeout_seconds=self.predict_timeout,
+                prediction_index=self.prediction_index,
             )
         except PredictionTimeoutError:
-            raise FuzzError("Prediction timed out")
+            raise FuzzError(f"{prefix}Prediction timed out")
 
         if error is not None:
-            raise FuzzError(f"Prediction raised an error: {error}")
+            raise FuzzError(f"{prefix}Prediction raised an error: {error}")
         if not output:
-            raise FuzzError("No output")
+            raise FuzzError(f"{prefix}No output")
         if error is not None:
-            raise FuzzError(f"Prediction failed: {error}")
+            raise FuzzError(f"{prefix}Prediction failed: {error}")