|
| 1 | +"""CloudBench-specific configuration/hyperparameters.""" |
| 2 | + |
| 3 | +import os |
| 4 | +from weaviate.classes.config import VectorDistances, Configure |
| 5 | +from weaviate.collections.classes.config_vector_index import VectorFilterStrategy |
| 6 | + |
| 7 | +from imsearch_eval.framework.interfaces import Config |
| 8 | + |
| 9 | + |
class CloudBenchConfig(Config):
    """Configuration for CloudBench benchmark (cloud/atmospheric image retrieval).

    Every setting is read from an environment variable when the instance is
    created; the second argument of each lookup in ``__init__`` is the value
    used when that variable is not set.
    """

    @staticmethod
    def _env_int(name, default):
        """Read environment variable *name* as an integer.

        Accepts plain integer text (``"500"``) and integral float notation
        (``"1e12"``, ``"40000.0"``), so values can be written the same way as
        the defaults in this class.

        Args:
            name: Name of the environment variable.
            default: Value returned (converted with ``int``) when the variable
                is not set.

        Returns:
            The parsed integer.

        Raises:
            ValueError: If the variable is set but is not an integral number.
        """
        raw = os.environ.get(name)
        if raw is None:
            return int(default)
        try:
            return int(raw)
        except ValueError:
            # Fall back to float parsing so "1e12"-style values work;
            # float() itself raises ValueError for non-numeric text.
            as_float = float(raw)
            if not as_float.is_integer():
                raise ValueError(
                    f"{name} must be an integer, got {raw!r}"
                ) from None
            return int(as_float)

    def __init__(self):
        """Initialize CloudBench configuration from environment variables."""
        env_int = self._env_int

        # dataset parameters
        self.cloudbench_dataset = os.environ.get(
            "CLOUDBENCH_DATASET", "sagecontinuum/CloudBench"
        )
        self.sample_size = env_int("SAMPLE_SIZE", 0)
        self.seed = env_int("SEED", 42)
        self._hf_token = os.environ.get("HF_TOKEN", "")

        # Upload parameters
        # Boolean flags are enabled only by the (case-insensitive) text "true".
        self._upload_to_s3 = os.environ.get("UPLOAD_TO_S3", "false").lower() == "true"
        self._s3_bucket = os.environ.get("S3_BUCKET", "sage_imsearch")
        self._s3_prefix = os.environ.get("S3_PREFIX", "dev-metrics/cloudbench")
        self._s3_endpoint = os.environ.get(
            "S3_ENDPOINT", "http://rook-ceph-rgw-nautiluss3.rook"
        )
        self._s3_access_key = os.environ.get("S3_ACCESS_KEY", "")
        self._s3_secret_key = os.environ.get("S3_SECRET_KEY", "")
        self._s3_secure = os.environ.get("S3_SECURE", "false").lower() == "true"
        self._image_results_file = os.environ.get(
            "IMAGE_RESULTS_FILE", "image_search_results.csv"
        )
        self._query_eval_metrics_file = os.environ.get(
            "QUERY_EVAL_METRICS_FILE", "query_eval_metrics.csv"
        )
        self._config_values_file = os.environ.get(
            "CONFIG_VALUES_FILE", "config_values.csv"
        )

        # Weaviate parameters (ports are kept as strings, as read)
        self._weaviate_host = os.environ.get("WEAVIATE_HOST", "127.0.0.1")
        self._weaviate_port = os.environ.get("WEAVIATE_PORT", "8080")
        self._weaviate_grpc_port = os.environ.get("WEAVIATE_GRPC_PORT", "50051")
        self._collection_name = os.environ.get("COLLECTION_NAME", "CloudBench")

        # model provider parameters
        self._llm_model_provider = os.environ.get(
            "LLM_MODEL_PROVIDER", "triton"
        ).lower()

        # Triton parameters
        self._triton_host = os.environ.get("TRITON_HOST", "triton")
        self._triton_port = os.environ.get("TRITON_PORT", "8001")

        # Workers parameters
        self._workers = env_int("WORKERS", 5)
        self._image_batch_size = env_int("IMAGE_BATCH_SIZE", 25)
        self._query_batch_size = env_int("QUERY_BATCH_SIZE", 5)

        # Logging parameters
        self._log_level = os.environ.get("LOG_LEVEL", "INFO").upper()

        # Weaviate HNSW hyperparameters
        # Enum-valued settings are looked up by upper-cased member name, so an
        # unknown name raises AttributeError.
        # NOTE: the "hsnw_" spelling of some attributes below is kept on
        # purpose; renaming them would break code that reads these attributes.
        self.hnsw_dist_metric = getattr(
            VectorDistances, os.environ.get("HNSW_DIST_METRIC", "COSINE").upper()
        )
        self.hnsw_ef = env_int("HNSW_EF", -1)
        self.hnsw_ef_construction = env_int("HNSW_EF_CONSTRUCTION", 100)
        self.hnsw_maxConnections = env_int("HNSW_MAX_CONNECTIONS", 50)
        self.hsnw_dynamicEfMax = env_int("HNSW_DYNAMIC_EF_MAX", 500)
        self.hsnw_dynamicEfMin = env_int("HNSW_DYNAMIC_EF_MIN", 200)
        self.hnsw_ef_factor = env_int("HNSW_EF_FACTOR", 20)
        self.hsnw_filterStrategy = getattr(
            VectorFilterStrategy,
            os.environ.get("HNSW_FILTER_STRATEGY", "ACORN").upper(),
        )
        self.hnsw_flatSearchCutoff = env_int("HNSW_FLAT_SEARCH_CUTOFF", 40000)
        self.hnsw_vector_cache_max_objects = env_int(
            "HNSW_VECTOR_CACHE_MAX_OBJECTS", 10**12
        )
        self.hnsw_quantizer = Configure.VectorIndex.Quantizer.pq(
            training_limit=env_int("HNSW_QUANTIZER_TRAINING_LIMIT", 500000)
        )

        # Query parameters
        self.query_method = os.environ.get("QUERY_METHOD", "clip_hybrid_query")
        self.target_vector = os.environ.get("TARGET_VECTOR", "clip")
        self.response_limit = env_int("RESPONSE_LIMIT", 50)
        self.advanced_query_parameters = {
            "alpha": float(os.environ.get("QUERY_ALPHA", 0.4)),
            "query_properties": ["caption"],
            "autocut_jumps": env_int("AUTOCUT_JUMPS", 0),
            "rerank_prop": os.environ.get("RERANK_PROP", "caption"),
            "clip_alpha": float(os.environ.get("CLIP_ALPHA", 0.7)),
        }

        # Caption prompts (same as Firebench)
        # The literal is deliberately flush-left: its text is used verbatim.
        default_prompt = """
role:
You are a world-class Scientific Image Captioning Expert.

context:
You will be shown a scientific image captured by edge devices. Your goal is to analyze its content and significance in detail.

task:
Generate exactly one scientifically detailed caption that accurately describes what is visible in the image and its scientific relevance.
Make it as detailed as possible. Also extract text and numbers from the images.

constraints:
- Only return:
 1. A single caption.
 2. a list of 15 keywords relevant to the image.
- Do not include any additional text, explanations, or formatting.

format:
 caption: <your_scientific_caption_here>
 keywords: <keyword1>, <keyword2>, ...
"""
        self.gemma3_prompt = os.environ.get("GEMMA3_PROMPT", default_prompt)

    @staticmethod
    def is_nrp_key_set():
        """Check if NRP API key is set.

        Returns:
            True when the NRP_API_KEY environment variable is set to a
            non-empty value, so the result can be used in a condition.

        Raises:
            ValueError: If NRP_API_KEY is unset or empty.
        """
        if not os.environ.get("NRP_API_KEY"):
            raise ValueError("NRP_API_KEY is not set")
        return True
0 commit comments