|
1 | | -# Copyright 2023 NVIDIA CORPORATION |
| 1 | +# Copyright 2023-2024 NVIDIA CORPORATION |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
18 | 18 | from dask_cuda import LocalCUDACluster |
19 | 19 | from distributed import Client, LocalCluster |
20 | 20 |
|
21 | | -from crossfit.calculate.frame import MetricFrame |
22 | | -from crossfit.dask.calculate import calculate_per_col as calculate_dask |
23 | | -from crossfit.stats.continuous.stats import ContinuousStats |
| 21 | +import crossfit as cf |
| 22 | +from crossfit.backend.dask.aggregate import aggregate |
| 23 | +from crossfit.metric.continuous.moments import Moments |
24 | 24 |
|
25 | 25 | # Benchmark assumes Criteo dataset. |
26 | 26 | # Low-cardinality columns: |
27 | 27 | # {C6:4, C9:64, C13:11, C16:155, C17:4, C19:15, C25:109, C26:37} |
28 | 28 |
|
29 | 29 | # Options |
30 | | -path = "/raid/dask-space/criteo/crit_pq_int" |
| 30 | +path = "/datasets/rzamora/crit_pq_int" |
31 | 31 | backend = "cudf" |
32 | | -split_row_groups = 10 |
33 | | -ncolumns = 10 |
| 32 | +blocksize = "265MiB" |
| 33 | +ncolumns = 4 |
34 | 34 | groupby = None |
35 | 35 | use_cluster = True |
36 | 36 |
|
|
55 | 55 | columns += groupby if isinstance(groupby, list) else [groupby] |
56 | 56 | ddf = dd.read_parquet( |
57 | 57 | path, |
58 | | - split_row_groups=split_row_groups, |
| 58 | + blocksize=blocksize, |
59 | 59 | columns=columns, |
60 | 60 | ) |
61 | 61 | print(f"\nddf: {ddf}\n") |
62 | 62 |
|
63 | | - # Calculate continuous stats |
64 | | - metric = ContinuousStats() |
| 63 | + # Aggregate moments (mean, var, std) |
| 64 | + agg = cf.Aggregator(Moments(axis=0), per_column=True) |
65 | 65 | t0 = time.time() |
66 | | - mf: MetricFrame = calculate_dask(metric, ddf, groupby=groupby) |
| 66 | + result = aggregate(ddf, agg, to_frame=True) |
67 | 67 | tf = time.time() |
68 | 68 | print(f"\nWall Time: {tf-t0} seconds\n") |
69 | 69 |
|
70 | 70 | # View result |
71 | | - assert isinstance(mf, MetricFrame) |
72 | | - result = mf.result() |
73 | 71 | print(f"Result:\n{result}\n") |
74 | 72 | print(f"Type: {type(result)}\n") |
75 | 73 |
|
|
0 commit comments