Skip to content

Commit 770d261

Browse files
authored
Fix cuda context bug caused by PreTrainedModel (#113)
* Fix cuda context bug caused by PreTrainedModel Signed-off-by: Vibhu Jawa <vjawa@nvidia.com> * Flake8 style fixes Signed-off-by: Vibhu Jawa <vjawa@nvidia.com> --------- Signed-off-by: Vibhu Jawa <vjawa@nvidia.com>
1 parent ef21071 commit 770d261

File tree

4 files changed

+14
-17
lines changed

4 files changed

+14
-17
lines changed

crossfit/backend/torch/hf/memory_curve_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
1516

1617
import joblib
1718
import numpy as np
1819
import torch
20+
import transformers
1921
from sklearn.linear_model import LinearRegression
2022
from tqdm import tqdm
21-
from transformers import PreTrainedModel
2223

2324
from crossfit.utils.model_adapter import adapt_model_input
2425
from crossfit.utils.torch_utils import (
@@ -29,7 +30,7 @@
2930

3031

3132
def fit_memory_estimate_curve(
32-
model: PreTrainedModel,
33+
model: "transformers.PreTrainedModel",
3334
path_or_name: str,
3435
start_batch_size: int = 1,
3536
end_batch_size: int = 2048,

crossfit/data/sparse/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,12 @@ def to_pytrec(self, is_run=False):
172172

173173
qrel = {}
174174
for i in range(self.indices.shape[0]):
175-
query_id = f"q{i+1}"
175+
query_id = f"q{i + 1}"
176176
qrel[query_id] = {}
177177

178178
row = sparse_matrix[i]
179179
for j, score in zip(row.indices, row.data):
180-
doc_id = f"d{j+1}"
180+
doc_id = f"d{j + 1}"
181181
qrel[query_id][doc_id] = int(score) if is_run else float(score)
182182

183183
return qrel

examples/dask_aggregate_bench.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,15 @@
5353
columns = [f"I{i}" for i in range(1, ncolumns + 1)]
5454
if groupby:
5555
columns += groupby if isinstance(groupby, list) else [groupby]
56-
ddf = dd.read_parquet(
57-
path,
58-
blocksize=blocksize,
59-
columns=columns,
60-
)
56+
ddf = dd.read_parquet(path, blocksize=blocksize, columns=columns)
6157
print(f"\nddf: {ddf}\n")
6258

6359
# Aggregate moments (mean, var, std)
6460
agg = cf.Aggregator(Moments(axis=0), per_column=True)
6561
t0 = time.time()
6662
result = aggregate(ddf, agg, to_frame=True)
6763
tf = time.time()
68-
print(f"\nWall Time: {tf-t0} seconds\n")
64+
print(f"\nWall Time: {tf - t0} seconds\n")
6965

7066
# View result
7167
print(f"Result:\n{result}\n")
@@ -76,12 +72,12 @@
7672
t0 = time.time()
7773
std = ddf.groupby(groupby).std().compute()
7874
tf = time.time()
79-
print(f"\nddf.groupby().std() takes {tf-t0} seconds, and returns:\n")
75+
print(f"\nddf.groupby().std() takes {tf - t0} seconds, and returns:\n")
8076
print(f"\n{std}\n")
8177
else:
8278
# Compare to ddf.std()
8379
t0 = time.time()
8480
std = ddf.std().compute()
8581
tf = time.time()
86-
print(f"\nddf.std() takes {tf-t0} seconds, and returns:\n")
82+
print(f"\nddf.std() takes {tf - t0} seconds, and returns:\n")
8783
print(f"\n{std}\n")

tests/pytrec_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ def create_qrel(relevance_scores, ids=None):
2424

2525
qrel = {}
2626
for i, query_scores in enumerate(relevance_scores):
27-
query_id = ids[i] if ids is not None else f"q{i+1}"
27+
query_id = ids[i] if ids is not None else f"q{i + 1}"
2828
qrel[query_id] = {}
2929
for j, score in enumerate(query_scores):
3030
_score = int(score.item())
3131

3232
if _score > 0:
33-
doc_id = f"d{j+1}"
33+
doc_id = f"d{j + 1}"
3434
qrel[query_id][doc_id] = int(score.item())
3535

3636
return qrel
@@ -41,10 +41,10 @@ def create_run(predicted_scores, ids=None):
4141

4242
run = {}
4343
for i, query_scores in enumerate(predicted_scores):
44-
query_id = ids[i] if ids is not None else f"q{i+1}"
44+
query_id = ids[i] if ids is not None else f"q{i + 1}"
4545
run[query_id] = {}
4646
for j, score in enumerate(query_scores):
47-
doc_id = f"d{j+1}"
47+
doc_id = f"d{j + 1}"
4848
run[query_id][doc_id] = float(score.item())
4949

5050
return run
@@ -60,6 +60,6 @@ def create_results(metric_arrays):
6060
for k, v in metric_arrays.items():
6161
q_out[k] = float(v[i])
6262

63-
outputs[f"q{i+1}"] = q_out
63+
outputs[f"q{i + 1}"] = q_out
6464

6565
return outputs

0 commit comments

Comments
 (0)