|
12 | 12 | from raglite._search import vector_search |
13 | 13 |
|
14 | 14 |
|
15 | | -def update_query_adapter( # noqa: PLR0915, C901 |
| 15 | +def update_query_adapter( # noqa: C901, PLR0915 |
16 | 16 | *, |
17 | 17 | max_triplets: int = 4096, |
18 | 18 | max_triplets_per_eval: int = 64, |
@@ -78,59 +78,63 @@ def update_query_adapter( # noqa: PLR0915, C901 |
78 | 78 | evals = session.exec( |
79 | 79 | select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval)) |
80 | 80 | ).all() |
81 | | - if len(evals) * max_triplets_per_eval < len(chunk_embedding.embedding): |
82 | | - error_message = "First run `insert_evals()` to generate sufficient evals." |
| 81 | + # Raise an error if there aren't enough evals to compute the query adapter.
| 82 | + embedding_dim = len(chunk_embedding.embedding) |
| 83 | + required_evals = np.ceil(embedding_dim / max_triplets_per_eval) - len(evals) |
| 84 | + if required_evals > 0: |
| 85 | + error_message = f"First run `insert_evals()` to generate {required_evals} more evals." |
83 | 86 | raise ValueError(error_message) |
84 | 87 | # Loop over the evals to generate (q, p, n) triplets. |
85 | | - Q = np.zeros((0, len(chunk_embedding.embedding))) # noqa: N806 |
| 88 | + Q = np.zeros((0, embedding_dim)) # noqa: N806 |
86 | 89 | P = np.zeros_like(Q) # noqa: N806 |
87 | 90 | N = np.zeros_like(Q) # noqa: N806 |
88 | 91 | for eval_ in tqdm( |
89 | 92 | evals, desc="Extracting triplets from evals", unit="eval", dynamic_ncols=True |
90 | 93 | ): |
91 | 94 | # Embed the question. |
92 | | - question_embedding = embed_strings([eval_.question], config=config)[0] |
| 95 | + question_embedding = embed_strings([eval_.question], config=config) |
93 | 96 | # Retrieve chunks that would be used to answer the question. |
94 | 97 | chunk_ids, _ = vector_search( |
95 | | - question_embedding, num_results=optimize_top_k, config=config_no_query_adapter |
| 98 | + question_embedding[0], num_results=optimize_top_k, config=config_no_query_adapter |
96 | 99 | ) |
97 | 100 | retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all() |
98 | | - # Extract (q, p, n) triplets by comparing the retrieved chunks with the eval. |
| 101 | + retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id)) |
| 102 | + # Extract (q, p, n) triplets from the eval. |
99 | 103 | num_triplets = 0 |
100 | 104 | for i, retrieved_chunk in enumerate(retrieved_chunks): |
101 | | - # Select irrelevant chunks. |
| 105 | + # Only loop over irrelevant chunks. |
102 | 106 | if retrieved_chunk.id not in eval_.chunk_ids: |
103 | | - # Look up all positive chunks (each represented by the mean of its multi-vector |
104 | | - # embedding) that are ranked lower than this negative one (represented by the |
105 | | - # embedding in the multi-vector embedding that best matches the query). |
106 | | - p_mean = [ |
107 | | - np.mean(chunk.embedding_matrix, axis=0, keepdims=True) |
108 | | - for chunk in retrieved_chunks[i + 1 :] |
109 | | - if chunk is not None and chunk.id in eval_.chunk_ids |
| 107 | + continue |
| 108 | + irrelevant_chunk = retrieved_chunk |
| 109 | + # Grab the negative chunk embedding of this irrelevant chunk. |
| 110 | + n_top = irrelevant_chunk.embedding_matrix[ |
| 111 | + [np.argmax(irrelevant_chunk.embedding_matrix @ question_embedding.T)] |
| 112 | + ] |
| 113 | + # Grab the positive chunk embeddings that are ranked lower than the negative one. |
| 114 | + p_top = [ |
| 115 | + chunk.embedding_matrix[ |
| 116 | + [np.argmax(chunk.embedding_matrix @ question_embedding.T)] |
110 | 117 | ] |
111 | | - n_top = retrieved_chunk.embedding_matrix[ |
112 | | - np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T), |
113 | | - np.newaxis, |
114 | | - :, |
115 | | - ] |
116 | | - # Filter out any (p, n, q) triplets for which the mean positive embedding ranks |
117 | | - # higher than the top negative one. |
118 | | - p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0] |
119 | | - if not p_mean: |
120 | | - continue |
121 | | - # Stack the (p, n, q) triplets. |
122 | | - p = np.vstack(p_mean) |
123 | | - n = np.repeat(n_top, p.shape[0], axis=0) |
124 | | - q = np.repeat(question_embedding, p.shape[0], axis=0) |
125 | | - num_triplets += p.shape[0] |
126 | | - # Append the (query, positive, negative) tuples to the Q, P, N matrices. |
127 | | - Q = np.vstack([Q, q]) # noqa: N806 |
128 | | - P = np.vstack([P, p]) # noqa: N806 |
129 | | - N = np.vstack([N, n]) # noqa: N806 |
130 | | - # Check if we have sufficient triplets for this eval. |
131 | | - if num_triplets >= max_triplets_per_eval: |
132 | | - break |
133 | | - # Check if we have sufficient triplets to compute the query adapter. |
| 118 | + for chunk in retrieved_chunks[i + 1 :] # Chunks that are ranked lower. |
| 119 | + if chunk is not None and chunk.id in eval_.chunk_ids |
| 120 | + ] |
| 121 | + # Ensure that we only have (q, p, n) triplets for which p is ranked lower than n. |
| 122 | + p_top = [p for p in p_top if (n_top - p) @ question_embedding.T > 0] |
| 123 | + if not p_top: |
| 124 | + continue |
| 125 | + # Stack the (q, p, n) triplets. |
| 126 | + p = np.vstack(p_top) |
| 127 | + n = np.repeat(n_top, p.shape[0], axis=0) |
| 128 | + q = np.repeat(question_embedding, p.shape[0], axis=0) |
| 129 | + num_triplets += p.shape[0] |
| 130 | + # Append the (q, p, n) triplets to the Q, P, N matrices. |
| 131 | + Q = np.vstack([Q, q]) # noqa: N806 |
| 132 | + P = np.vstack([P, p]) # noqa: N806 |
| 133 | + N = np.vstack([N, n]) # noqa: N806 |
| 134 | + # Stop if we have enough triplets for this eval. |
| 135 | + if num_triplets >= max_triplets_per_eval: |
| 136 | + break |
| 137 | + # Stop if we have enough triplets to compute the query adapter. |
134 | 138 | if Q.shape[0] > max_triplets: |
135 | 139 | Q, P, N = Q[:max_triplets, :], P[:max_triplets, :], N[:max_triplets, :] # noqa: N806 |
136 | 140 | break |
|
0 commit comments