Skip to content

Commit 16eaf1a

Browse files
authored
Update DVCLive examples in Setup (#4224)
1 parent 5e131d9 commit 16eaf1a

File tree

8 files changed

+111
-22
lines changed

8 files changed

+111
-22
lines changed

webview/src/setup/components/App.test.tsx

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -918,14 +918,21 @@ describe('App', () => {
918918
expect(setupDVCButton).toBeInTheDocument()
919919
expect(setupDVCButton).toBeVisible()
920920

921-
expect(screen.getByText('demo')).toBeInTheDocument()
922-
expect(screen.getAllByText('-')).toHaveLength(2)
921+
const remotesSection = screen.getByTestId('remotes-section-details')
923922

924-
expect(screen.getByText('example-get-started')).toBeInTheDocument()
925-
expect(screen.getByText('drive')).toBeInTheDocument()
926-
expect(screen.getByText('gdrive://appDataFolder')).toBeInTheDocument()
927-
expect(screen.getByText('storage')).toBeInTheDocument()
928-
expect(screen.getByText('s3://some-bucket')).toBeInTheDocument()
923+
expect(within(remotesSection).getByText('demo')).toBeInTheDocument()
924+
expect(within(remotesSection).getAllByText('-')).toHaveLength(2)
925+
expect(
926+
within(remotesSection).getByText('example-get-started')
927+
).toBeInTheDocument()
928+
expect(within(remotesSection).getByText('drive')).toBeInTheDocument()
929+
expect(
930+
within(remotesSection).getByText('gdrive://appDataFolder')
931+
).toBeInTheDocument()
932+
expect(within(remotesSection).getByText('storage')).toBeInTheDocument()
933+
expect(
934+
within(remotesSection).getByText('s3://some-bucket')
935+
).toBeInTheDocument()
929936
})
930937
})
931938
})

webview/src/setup/components/experiments/DvcLiveExamples.tsx

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/* eslint-disable @typescript-eslint/no-unsafe-call */
22
import React from 'react'
33
import styles from './styles.module.scss'
4+
import { OtherFrameworks } from './OtherFrameworks'
45
import pyTorch from '../../snippets/pyTorch.py'
56
import huggingFace from '../../snippets/huggingFace.py'
67
import keras from '../../snippets/keras.py'
@@ -17,6 +18,10 @@ export const DvcLiveExamples: React.FC = () => {
1718
<Panels
1819
className={styles.dvcLiveExamples}
1920
panels={[
21+
{
22+
children: <PythonCodeBlock>{pythonApi.toString()}</PythonCodeBlock>,
23+
title: 'Python API'
24+
},
2025
{
2126
children: <PythonCodeBlock>{pyTorch.toString()}</PythonCodeBlock>,
2227
title: 'PyTorch Lightning'
@@ -30,8 +35,8 @@ export const DvcLiveExamples: React.FC = () => {
3035
title: 'Keras'
3136
},
3237
{
33-
children: <PythonCodeBlock>{pythonApi.toString()}</PythonCodeBlock>,
34-
title: 'General Python API'
38+
children: <OtherFrameworks />,
39+
title: 'Other'
3540
}
3641
]}
3742
/>
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import React from 'react'
2+
import styles from './styles.module.scss'
3+
4+
export const OtherFrameworks = () => (
5+
<div className={styles.otherFrameworks}>
6+
These frameworks are also supported:
7+
<ul>
8+
<li>
9+
<a href="https://dvc.org/doc/dvclive/ml-frameworks/catalyst">
10+
Catalyst
11+
</a>
12+
</li>
13+
<li>
14+
<a href="https://dvc.org/doc/dvclive/ml-frameworks/fastai">Fast.ai</a>
15+
</li>
16+
<li>
17+
<a href="https://dvc.org/doc/dvclive/ml-frameworks/lightgbm">
18+
LightGBM
19+
</a>
20+
</li>
21+
<li>
22+
<a href="https://dvc.org/doc/dvclive/ml-frameworks/mmcv">MMCV</a>
23+
</li>
24+
<li>
25+
<a href="https://dvc.org/doc/dvclive/ml-frameworks/optuna">Optuna</a>
26+
</li>
27+
<li>
28+
<a href="https://dvc.org/doc/dvclive/ml-frameworks/pytorch">PyTorch</a>
29+
</li>
30+
<li>
31+
<a href="https://dvc.org/doc/dvclive/ml-frameworks/tensorflow">
32+
TensorFlow
33+
</a>
34+
</li>
35+
<li>
36+
<a href="https://dvc.org/doc/dvclive/ml-frameworks/xgboost">XGBoost</a>
37+
</li>
38+
</ul>
39+
</div>
40+
)

webview/src/setup/components/experiments/styles.module.scss

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,15 @@
88
margin: 0 auto;
99
}
1010
}
11+
12+
.otherFrameworks {
13+
white-space: pre-wrap;
14+
background-color: transparent !important;
15+
color: $watermark-color;
16+
font-family: sans-serif;
17+
letter-spacing: 0.04em;
18+
line-height: 1.6;
19+
padding: 0;
20+
text-align: left;
21+
width: max-content;
22+
}

webview/src/setup/snippets/huggingFace.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,12 @@
22

33
...
44

5+
trainer = Trainer(
6+
model, args,
7+
train_dataset=train_data,
8+
eval_dataset=eval_data,
9+
tokenizer=tokenizer,
10+
compute_metrics=compute_metrics,
11+
)
512
trainer.add_callback(DVCLiveCallback(save_dvc_exp=True))
6-
trainer.train()
13+
trainer.train()

webview/src/setup/snippets/keras.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
...
44

55
model.fit(
6-
train_dataset, validation_data=validation_dataset,
7-
callbacks=[DVCLiveCallback(save_dvc_exp=True)])
6+
train_dataset, epochs=num_epochs,
7+
validation_data=validation_dataset,
8+
callbacks=[DVCLiveCallback(save_dvc_exp=True)])
Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
1+
import lightning.pytorch as pl
12
from dvclive.lightning import DVCLiveLogger
23

34
...
45

5-
trainer = Trainer(logger=DVCLiveLogger(save_dvc_exp=True))
6-
trainer.fit(model)
6+
class LitModule(pl.LightningModule):
7+
def __init__(self, layer_1_dim=128, learning_rate=1e-2):
8+
super().__init__()
9+
# layer_1_dim and learning_rate will be logged by DVCLive
10+
self.save_hyperparameters()
11+
12+
def training_step(self, batch, batch_idx):
13+
metric = ...
14+
# See Output Format bellow
15+
self.log("train_metric", metric, on_step=False, on_epoch=True)
16+
17+
dvclive_logger = DVCLiveLogger(save_dvc_exp=True)
18+
19+
model = LitModule()
20+
trainer = pl.Trainer(logger=dvclive_logger)
21+
trainer.fit(model)
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import random
2+
import sys
13
from dvclive import Live
24

35
with Live(save_dvc_exp=True) as live:
4-
live.log_param("epochs", NUM_EPOCHS)
5-
6-
for epoch in range(NUM_EPOCHS):
7-
train_model(...)
8-
metrics = evaluate_model(...)
9-
for metric_name, value in metrics.items():
10-
live.log_metric(metric_name, value)
11-
live.next_step()
6+
epochs = int(sys.argv[1])
7+
live.log_param("epochs", epochs)
8+
for epoch in range(epochs):
9+
live.log_metric("train/accuracy", epoch + random.random())
10+
live.log_metric("train/loss", epochs - epoch - random.random())
11+
live.log_metric("val/accuracy",epoch + random.random() )
12+
live.log_metric("val/loss", epochs - epoch - random.random())
13+
live.next_step()

0 commit comments

Comments
 (0)