Skip to content

Commit 110bc8b

Browse files
committed
neptune integration update
1 parent 4adfb95 commit 110bc8b

File tree

2 files changed

+28
-15
lines changed

2 files changed

+28
-15
lines changed

examples/neptune.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# TO START:
2-
# pip install neptune-client, livelossplot
2+
# pip install neptune, livelossplot
33
# export environment variables
44
# enjoy results
55

@@ -13,22 +13,26 @@
1313

1414

1515
def main():
16-
api_token = os.environ.get('NEPTUNE_API_TOKEN')
17-
project_qualified_name = os.environ.get('NEPTUNE_PROJECT_NAME')
18-
logger = NeptuneLogger(api_token=api_token, project_qualified_name=project_qualified_name)
16+
api_token = os.environ.get("NEPTUNE_API_TOKEN")
17+
project_qualified_name = os.environ.get("NEPTUNE_PROJECT")
18+
logger = NeptuneLogger(
19+
api_token=api_token,
20+
project_qualified_name=project_qualified_name,
21+
tags=["example", "livelossplot"],
22+
)
1923
liveplot = PlotLosses(outputs=[logger])
2024
for i in range(20):
2125
liveplot.update(
2226
{
23-
'accuracy': 1 - np.random.rand() / (i + 2.),
24-
'val_accuracy': 1 - np.random.rand() / (i + 0.5),
25-
'mse': 1. / (i + 2.),
26-
'val_mse': 1. / (i + 0.5)
27+
"accuracy": 1 - np.random.rand() / (i + 2.0),
28+
"val_accuracy": 1 - np.random.rand() / (i + 0.5),
29+
"mse": 1.0 / (i + 2.0),
30+
"val_mse": 1.0 / (i + 0.5),
2731
}
2832
)
2933
liveplot.send()
30-
sleep(.5)
34+
sleep(0.5)
3135

3236

33-
if __name__ == '__main__':
37+
if __name__ == "__main__":
3438
main()

livelossplot/outputs/neptune_logger.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,33 @@ class NeptuneLogger(BaseOutput):
77
"""See: https://github.com/neptune-ai/neptune-client
88
YOUR_API_TOKEN and USERNAME/PROJECT_NAME
99
"""
10-
def __init__(self, api_token: Optional[str] = None, project_qualified_name: Optional[str] = None, **kwargs):
10+
11+
def __init__(
12+
self,
13+
api_token: Optional[str] = None,
14+
project_qualified_name: Optional[str] = None,
15+
**kwargs
16+
):
1117
"""Set secrets and create experiment
1218
Args:
1319
api_token: your api token, you can create NEPTUNE_API_TOKEN environment variable instead
1420
project_qualified_name: <user>/<project>, you can create NEPTUNE_PROJECT environment variable instead
1521
**kwargs: keyword args, that will be passed to create_experiment function
1622
"""
1723
import neptune
24+
1825
self.neptune = neptune
19-
self.neptune.init(api_token=api_token, project_qualified_name=project_qualified_name)
20-
self.experiment = self.neptune.create_experiment(**kwargs)
26+
self.run = self.neptune.init_run(
27+
api_token=api_token, project=project_qualified_name, **kwargs
28+
)
2129

2230
def close(self):
2331
"""Close connection"""
24-
self.neptune.stop()
32+
if hasattr(self, "run"):
33+
self.run.stop()
2534

2635
def send(self, logger: MainLogger):
2736
"""Send metrics collected in last step to neptune server"""
2837
for name, log_items in logger.log_history.items():
2938
last_log_item = log_items[-1]
30-
self.neptune.send_metric(name, x=last_log_item.step, y=last_log_item.value)
39+
self.run[name].append(value=last_log_item.value, step=last_log_item.step)

0 commit comments

Comments
 (0)