Skip to content

Commit cc85a02

Browse files
committed
DLPy example
1 parent 1cfe7fc commit cc85a02

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

examples/register_sas_dlpy_model.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#!/usr/bin/env python
2+
# encoding: utf-8
3+
#
4+
# Copyright © 2019, SAS Institute Inc., Cary, NC, USA. All Rights Reserved.
5+
# SPDX-License-Identifier: Apache-2.0
6+
7+
import swat
8+
from dlpy.applications import Sequential
9+
from dlpy.layers import Dense, InputLayer, OutputLayer
10+
from sasctl import Session
11+
from sasctl.tasks import register_model, publish_model
12+
13+
14+
# Connect to the CAS server
15+
s = swat.CAS('hostname', 5570, 'username', 'password')
16+
17+
# Upload the training data to CAS
18+
tbl = s.upload('data/iris.csv').casTable
19+
20+
# Construct a simple neural network
21+
model = Sequential(conn=s, model_table='dlpy_model')
22+
model.add(InputLayer())
23+
model.add(Dense(n=64))
24+
model.add(Dense(n=32))
25+
model.add(OutputLayer(n=3))
26+
27+
# Train on the sample
28+
model.fit(data=tbl,
29+
inputs=['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'],
30+
target='Species',
31+
max_epochs=50,
32+
lr=0.001)
33+
34+
# Export the model as an ASTORE and get a reference to the new ASTORE table
35+
s.deeplearn.dlexportmodel(modelTable=model.model_table, initWeights=model.model_weights, casout='astore_table')
36+
astore = s.CASTable('astore_table')
37+
38+
# Connect to the SAS environment
39+
with Session('hostname', 'username', 'password'):
40+
# Register the trained model by providing:
41+
# - the ASTORE containing the model
42+
# - a name for the model
43+
# - a name for the project
44+
#
45+
# NOTE: the force=True option will create the project if it does not exist.
46+
model = register_model(astore, 'Deep Learning', 'Iris', force=True)
47+
48+
# Publish the model to SAS® Micro Analytic Service (MAS). Specifically to
49+
# the default MAS service "maslocal".
50+
module = publish_model(model, 'maslocal')
51+
52+
# sasctl wraps the published module with Python methods corresponding to
53+
# the various steps defined in the module (like "predict").
54+
response = module.score(SepalLength=5.1, SepalWidth=3.5,
55+
PetalLength=1.4, PetalWidth=0.2)
56+
57+
s.terminate()

0 commit comments

Comments
 (0)