File tree Expand file tree Collapse file tree 4 files changed +14
-17
lines changed
Expand file tree Collapse file tree 4 files changed +14
-17
lines changed Original file line number Diff line number Diff line change 1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
1516
1617import joblib
1718import numpy as np
1819import torch
20+ import transformers
1921from sklearn .linear_model import LinearRegression
2022from tqdm import tqdm
21- from transformers import PreTrainedModel
2223
2324from crossfit .utils .model_adapter import adapt_model_input
2425from crossfit .utils .torch_utils import (
2930
3031
3132def fit_memory_estimate_curve (
32- model : PreTrainedModel ,
33+ model : "transformers. PreTrainedModel" ,
3334 path_or_name : str ,
3435 start_batch_size : int = 1 ,
3536 end_batch_size : int = 2048 ,
Original file line number Diff line number Diff line change @@ -172,12 +172,12 @@ def to_pytrec(self, is_run=False):
172172
173173 qrel = {}
174174 for i in range (self .indices .shape [0 ]):
175- query_id = f"q{ i + 1 } "
175+ query_id = f"q{ i + 1 } "
176176 qrel [query_id ] = {}
177177
178178 row = sparse_matrix [i ]
179179 for j , score in zip (row .indices , row .data ):
180- doc_id = f"d{ j + 1 } "
180+ doc_id = f"d{ j + 1 } "
181181 qrel [query_id ][doc_id ] = int (score ) if is_run else float (score )
182182
183183 return qrel
Original file line number Diff line number Diff line change 5353 columns = [f"I{ i } " for i in range (1 , ncolumns + 1 )]
5454 if groupby :
5555 columns += groupby if isinstance (groupby , list ) else [groupby ]
56- ddf = dd .read_parquet (
57- path ,
58- blocksize = blocksize ,
59- columns = columns ,
60- )
56+ ddf = dd .read_parquet (path , blocksize = blocksize , columns = columns )
6157 print (f"\n ddf: { ddf } \n " )
6258
6359 # Aggregate moments (mean, var, std)
6460 agg = cf .Aggregator (Moments (axis = 0 ), per_column = True )
6561 t0 = time .time ()
6662 result = aggregate (ddf , agg , to_frame = True )
6763 tf = time .time ()
68- print (f"\n Wall Time: { tf - t0 } seconds\n " )
64+ print (f"\n Wall Time: { tf - t0 } seconds\n " )
6965
7066 # View result
7167 print (f"Result:\n { result } \n " )
7672 t0 = time .time ()
7773 std = ddf .groupby (groupby ).std ().compute ()
7874 tf = time .time ()
79- print (f"\n ddf.groupby().std() takes { tf - t0 } seconds, and returns:\n " )
75+ print (f"\n ddf.groupby().std() takes { tf - t0 } seconds, and returns:\n " )
8076 print (f"\n { std } \n " )
8177 else :
8278 # Compare to ddf.std()
8379 t0 = time .time ()
8480 std = ddf .std ().compute ()
8581 tf = time .time ()
86- print (f"\n ddf.std() takes { tf - t0 } seconds, and returns:\n " )
82+ print (f"\n ddf.std() takes { tf - t0 } seconds, and returns:\n " )
8783 print (f"\n { std } \n " )
Original file line number Diff line number Diff line change @@ -24,13 +24,13 @@ def create_qrel(relevance_scores, ids=None):
2424
2525 qrel = {}
2626 for i , query_scores in enumerate (relevance_scores ):
27- query_id = ids [i ] if ids is not None else f"q{ i + 1 } "
27+ query_id = ids [i ] if ids is not None else f"q{ i + 1 } "
2828 qrel [query_id ] = {}
2929 for j , score in enumerate (query_scores ):
3030 _score = int (score .item ())
3131
3232 if _score > 0 :
33- doc_id = f"d{ j + 1 } "
33+ doc_id = f"d{ j + 1 } "
3434 qrel [query_id ][doc_id ] = int (score .item ())
3535
3636 return qrel
@@ -41,10 +41,10 @@ def create_run(predicted_scores, ids=None):
4141
4242 run = {}
4343 for i , query_scores in enumerate (predicted_scores ):
44- query_id = ids [i ] if ids is not None else f"q{ i + 1 } "
44+ query_id = ids [i ] if ids is not None else f"q{ i + 1 } "
4545 run [query_id ] = {}
4646 for j , score in enumerate (query_scores ):
47- doc_id = f"d{ j + 1 } "
47+ doc_id = f"d{ j + 1 } "
4848 run [query_id ][doc_id ] = float (score .item ())
4949
5050 return run
@@ -60,6 +60,6 @@ def create_results(metric_arrays):
6060 for k , v in metric_arrays .items ():
6161 q_out [k ] = float (v [i ])
6262
63- outputs [f"q{ i + 1 } " ] = q_out
63+ outputs [f"q{ i + 1 } " ] = q_out
6464
6565 return outputs
You can’t perform that action at this time.
0 commit comments