Skip to content

Commit 1ee3de4

Browse files
authored
Update dask "compute" benchmark (#50)
* update dask benchmark to use aggregator API
* rename file
* update comment
1 parent 06afeee commit 1ee3de4

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed
Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 NVIDIA CORPORATION
1+
# Copyright 2023-2024 NVIDIA CORPORATION
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -18,19 +18,19 @@
1818
from dask_cuda import LocalCUDACluster
1919
from distributed import Client, LocalCluster
2020

21-
from crossfit.calculate.frame import MetricFrame
22-
from crossfit.dask.calculate import calculate_per_col as calculate_dask
23-
from crossfit.stats.continuous.stats import ContinuousStats
21+
import crossfit as cf
22+
from crossfit.backend.dask.aggregate import aggregate
23+
from crossfit.metric.continuous.moments import Moments
2424

2525
# Benchmark assumes Criteo dataset.
2626
# Low-cardinality columns:
2727
# {C6:4, C9:64, C13:11, C16:155, C17:4, C19:15, C25:109, C26:37}
2828

2929
# Options
30-
path = "/raid/dask-space/criteo/crit_pq_int"
30+
path = "/datasets/rzamora/crit_pq_int"
3131
backend = "cudf"
32-
split_row_groups = 10
33-
ncolumns = 10
32+
blocksize = "265MiB"
33+
ncolumns = 4
3434
groupby = None
3535
use_cluster = True
3636

@@ -55,21 +55,19 @@
5555
columns += groupby if isinstance(groupby, list) else [groupby]
5656
ddf = dd.read_parquet(
5757
path,
58-
split_row_groups=split_row_groups,
58+
blocksize=blocksize,
5959
columns=columns,
6060
)
6161
print(f"\nddf: {ddf}\n")
6262

63-
# Calculate continuous stats
64-
metric = ContinuousStats()
63+
# Aggregate moments (mean, var, std)
64+
agg = cf.Aggregator(Moments(axis=0), per_column=True)
6565
t0 = time.time()
66-
mf: MetricFrame = calculate_dask(metric, ddf, groupby=groupby)
66+
result = aggregate(ddf, agg, to_frame=True)
6767
tf = time.time()
6868
print(f"\nWall Time: {tf-t0} seconds\n")
6969

7070
# View result
71-
assert isinstance(mf, MetricFrame)
72-
result = mf.result()
7371
print(f"Result:\n{result}\n")
7472
print(f"Type: {type(result)}\n")
7573

0 commit comments

Comments
 (0)