Skip to content

Commit d94f51b

Browse files
authored
Merge 0.0.8.post1 into main (#122)
* Fix cuda context bug caused by PreTrainedModel (#113) * Fix cuda context bug caused by PreTrainedModel Signed-off-by: Vibhu Jawa <vjawa@nvidia.com> * Flake8 style fixes Signed-off-by: Vibhu Jawa <vjawa@nvidia.com> --------- Signed-off-by: Vibhu Jawa <vjawa@nvidia.com> * update version (#115) --------- Signed-off-by: Vibhu Jawa <vjawa@nvidia.com>
1 parent c8ba17a commit d94f51b

File tree

3 files changed

+5
-8
lines changed

3 files changed

+5
-8
lines changed

crossfit/backend/torch/hf/memory_curve_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
1516

1617
import joblib
1718
import numpy as np
1819
import torch
20+
import transformers
1921
from sklearn.linear_model import LinearRegression
2022
from tqdm import tqdm
21-
from transformers import PreTrainedModel
2223

2324
from crossfit.utils.model_adapter import adapt_model_input
2425
from crossfit.utils.torch_utils import (
@@ -29,7 +30,7 @@
2930

3031

3132
def fit_memory_estimate_curve(
32-
model: PreTrainedModel,
33+
model: "transformers.PreTrainedModel",
3334
path_or_name: str,
3435
start_batch_size: int = 1,
3536
end_batch_size: int = 2048,

examples/dask_aggregate_bench.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,7 @@
5353
columns = [f"I{i}" for i in range(1, ncolumns + 1)]
5454
if groupby:
5555
columns += groupby if isinstance(groupby, list) else [groupby]
56-
ddf = dd.read_parquet(
57-
path,
58-
blocksize=blocksize,
59-
columns=columns,
60-
)
56+
ddf = dd.read_parquet(path, blocksize=blocksize, columns=columns)
6157
print(f"\nddf: {ddf}\n")
6258

6359
# Aggregate moments (mean, var, std)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from setuptools import find_packages, setup
2020

21-
VERSION = "0.0.8"
21+
VERSION = "0.0.8.post1"
2222

2323

2424
def get_long_description():

0 commit comments

Comments
 (0)