Skip to content

Commit 02d75c8

Browse files
committed
First draft of function inspector in action
1 parent 018b760 commit 02d75c8

File tree

1 file changed

+105
-99
lines changed

1 file changed

+105
-99
lines changed

src/sasctl/pzmm/write_json_files.py

Lines changed: 105 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,9 @@
1010
import math
1111
import numpy as np
1212
from scipy.stats import kendalltau, gamma
13-
import types
1413
import pickle
1514
import pickletools
16-
import os
1715

18-
# %%
1916
class JSONFiles:
2017
@classmethod
2118
def writeVarJSON(cls, inputData, isInput=True, jPath=Path.cwd()):
@@ -305,7 +302,7 @@ def writeFileMetadataJSON(cls, modelPrefix, jPath=Path.cwd(), isH2OModel=False):
305302

306303
@classmethod
307304
def writeBaseFitStat(
308-
self, csvPath=None, jPath=Path.cwd(), userInput=False, tupleList=None
305+
cls, csvPath=None, jPath=Path.cwd(), userInput=False, tupleList=None
309306
):
310307
"""
311308
Writes a JSON file to display fit statistics for the model in SAS Open Model Manager.
@@ -376,7 +373,7 @@ def writeBaseFitStat(
376373
]
377374

378375
nullJSONPath = Path(__file__).resolve().parent / "null_dmcas_fitstat.json"
379-
nullJSONDict = self.readJSONFile(nullJSONPath)
376+
nullJSONDict = cls.readJSONFile(nullJSONPath)
380377

381378
dataMap = [{}, {}, {}]
382379
for i in range(3):
@@ -386,19 +383,19 @@ def writeBaseFitStat(
386383
for paramTuple in tupleList:
387384
# ignore incorrectly formatted input arguments
388385
if type(paramTuple) == tuple and len(paramTuple) == 3:
389-
paramName = self.formatParameter(paramTuple[0])
386+
paramName = cls.formatParameter(paramTuple[0])
390387
if paramName not in validParams:
391388
continue
392389
if type(paramTuple[2]) == str:
393-
dataRole = self.convertDataRole(paramTuple[2])
390+
dataRole = cls.convertDataRole(paramTuple[2])
394391
else:
395392
dataRole = paramTuple[2]
396393
dataMap[dataRole - 1]["dataMap"][paramName] = paramTuple[1]
397394

398395
if userInput:
399396
while True:
400397
paramName = input("Parameter name: ")
401-
paramName = self.formatParameter(paramName)
398+
paramName = cls.formatParameter(paramName)
402399
if paramName not in validParams:
403400
print("Not a valid parameter. Please see documentation.")
404401
if input("More parameters? (Y/N)") == "N":
@@ -408,7 +405,7 @@ def writeBaseFitStat(
408405
dataRole = input("Data role: ")
409406

410407
if type(dataRole) is str:
411-
dataRole = self.convertDataRole(dataRole)
408+
dataRole = cls.convertDataRole(dataRole)
412409
dataMap[dataRole - 1]["dataMap"][paramName] = paramValue
413410

414411
if input("More parameters? (Y/N)") == "N":
@@ -418,11 +415,11 @@ def writeBaseFitStat(
418415
csvData = pd.read_csv(csvPath)
419416
for i, row in enumerate(csvData.values):
420417
paramName, paramValue, dataRole = row
421-
paramName = self.formatParameter(paramName)
418+
paramName = cls.formatParameter(paramName)
422419
if paramName not in validParams:
423420
continue
424421
if type(dataRole) is str:
425-
dataRole = self.convertDataRole(dataRole)
422+
dataRole = cls.convertDataRole(dataRole)
426423
dataMap[dataRole - 1]["dataMap"][paramName] = paramValue
427424

428425
outJSON = nullJSONDict
@@ -439,7 +436,7 @@ def writeBaseFitStat(
439436

440437
@classmethod
441438
def calculateFitStat(
442-
self, validateData=None, trainData=None, testData=None, jPath=Path.cwd()
439+
cls, validateData=None, trainData=None, testData=None, jPath=Path.cwd()
443440
):
444441
"""
445442
Calculates fit statistics from user data and predictions and then writes to
@@ -499,7 +496,7 @@ def calculateFitStat(
499496
)
500497

501498
nullJSONPath = Path(__file__).resolve().parent / "null_dmcas_fitstat.json"
502-
nullJSONDict = self.readJSONFile(nullJSONPath)
499+
nullJSONDict = cls.readJSONFile(nullJSONPath)
503500

504501
dataSets = [[[None], [None]], [[None], [None]], [[None], [None]]]
505502

@@ -598,7 +595,7 @@ def calculateFitStat(
598595

599596
@classmethod
600597
def generateROCLiftStat(
601-
self,
598+
cls,
602599
targetName,
603600
targetValue,
604601
swatConn,
@@ -656,10 +653,10 @@ def generateROCLiftStat(
656653
)
657654

658655
nullJSONROCPath = Path(__file__).resolve().parent / "null_dmcas_roc.json"
659-
nullJSONROCDict = self.readJSONFile(nullJSONROCPath)
656+
nullJSONROCDict = cls.readJSONFile(nullJSONROCPath)
660657

661658
nullJSONLiftPath = Path(__file__).resolve().parent / "null_dmcas_lift.json"
662-
nullJSONLiftDict = self.readJSONFile(nullJSONLiftPath)
659+
nullJSONLiftDict = cls.readJSONFile(nullJSONLiftPath)
663660

664661
dataSets = [pd.DataFrame(), pd.DataFrame(), pd.DataFrame()]
665662
columns = ["actual", "predict"]
@@ -965,25 +962,85 @@ def convertDataRole(self, dataRole):
965962

966963
return conversion
967964

968-
def getCurrentScopedImports(self):
965+
@classmethod
def createRequirementsJSON(cls, jPath=Path.cwd()):
    """
    Searches the model directory for Python scripts and pickle files and determines
    their Python package dependencies. Found dependencies are then matched to the package
    version found in the current working environment. Then the package and version are
    written to a requirements.json file.

    WARNING:
    The methods utilized in this function can determine package dependencies from provided
    scripts and pickle files, but CANNOT determine the required package versions without
    being in the development environment in which they were originally created.

    This function works best when run in the model development environment and is likely to
    throw errors if run in another environment (and/or produce incorrect package versions).
    When using this function outside of the model development environment, it is
    recommended that the user adjust the requirements.json file's package versions
    to match the model development environment.

    Parameters
    ----------
    jPath : str, optional
        The path to a Python project, by default Path.cwd(). Note that the default is
        evaluated once, when the class is defined.

    Yields
    ------
    requirements.json : file
        JSON file used to create a specific Python environment in a SAS Model Manager published
        container. The file holds a single JSON array with one install step per package.
    """

    # Flatten the per-file dependency collections into one set of package names.
    picklePackages = set()
    for pickleFile in cls.getPickleFile(jPath):
        picklePackages.update(cls.getDependenciesFromPickleFile(pickleFile))

    # Tolerate a missing result so the union below cannot fail on None.
    codeDependencies = cls.getCodeDependencies(jPath) or []

    # NOTE(review): packageList is not consumed below because getLocalPackageVersion
    # is called without arguments; presumably it should receive this list -- confirm
    # against the getLocalPackageVersion signature.
    packageList = sorted(picklePackages.union(codeDependencies))
    packageAndVersion = cls.getLocalPackageVersion()

    installSteps = [
        {
            "step": "install " + package,
            "command": "pip install " + package + "==" + version,
        }
        for package, version in packageAndVersion
    ]

    # Open for writing and dump one array so the file is valid JSON as a whole.
    with open(Path(jPath) / "requirements.json", "w") as file:
        file.write(json.dumps(installSteps, indent=4))
1018+
1019+
def getCodeDependencies(self, jPath, debug=False):
    """
    Determines the package dependencies of the Python scripts found directly in a
    model directory.

    The source of every ``*.py`` file in jPath is concatenated with the source of the
    functionInspector utility module and a short driver, and the result is executed in
    a fresh namespace. The driver calls findDependencies(), which is expected to be
    defined by the functionInspector source (not visible from this method).

    WARNING:
    The scripts are executed with exec(), so only run this on trusted model code.
    The namespace's __name__ is set to "__main__" so that the driver's guard fires;
    any ``if __name__ == "__main__":`` section in the scripts will run as well.

    Parameters
    ----------
    jPath : str
        The path to the directory containing the Python scripts.
    debug : bool, optional
        Sets the logging level of the executed code to DEBUG instead of INFO, by
        default False.

    Returns
    -------
    list
        The ``dependencies`` value produced by findDependencies(), converted to a
        list. Empty if no value was produced.
    """
    from ..utils import functionInspector
    import inspect

    # Gather the script sources in a stable (sorted) order.
    sources = []
    for scriptPath in sorted(Path(jPath).glob("*.py")):
        with open(scriptPath, "r") as code:
            sources.append(code.read())

    sources.append(inspect.getsource(functionInspector))

    driver = "\n".join(
        [
            "import logging",
            'if __name__ == "__main__":',
            "    debug = {!r}".format(bool(debug)),
            "    logLevel = logging.DEBUG if debug else logging.INFO",
            '    logging.basicConfig(level=logLevel, format="%(levelname)s: %(message)s")',
            "",
            "    symbols, dependencies = findDependencies()",
            "    print(dependencies)",
            "",
        ]
    )
    sources.append(driver)

    # Join with newlines so a file lacking a trailing newline cannot fuse with the
    # first statement of the next source.
    execCode = "\n".join(sources)

    # A single dict serves as both globals and locals, so functions defined by the
    # executed code can see each other and the results can be read back afterwards.
    namespace = {"__name__": "__main__"}
    exec(execCode, namespace)

    dependencies = namespace.get("dependencies")
    return [] if dependencies is None else list(dependencies)
9871044
def getPickleFile(self, pPath):
9881045
"""
9891046
Given a file path, retrieve the pickle file(s).
@@ -1025,60 +1082,11 @@ def getDependenciesFromPickleFile(self, pickleFile):
10251082
obj = pickle.load(openfile)
10261083
dumps = pickle.dumps(obj)
10271084

1028-
modules = {mod.split(".")[0] for mod, _ in self.getNames(dumps)}
1085+
modules = {mod.split(".")[0] for mod, _ in self.getPackageNames(dumps)}
1086+
modules.discard("builtins")
10291087
return modules
10301088

1031-
@classmethod
1032-
def createRequirementsJSON(self, jPath=Path.cwd()):
1033-
"""
1034-
Searches the root of the project for all Python modules and writes them to a requirements.json file.
1035-
1036-
Parameters
1037-
----------
1038-
jPath : str, optional
1039-
The path to a Python project, by default Path.cwd().
1040-
"""
1041-
1042-
module_version_map = {}
1043-
pickle_files = self.get_pickle_file(jPath)
1044-
requirements_txt_file = os.path.join(jPath, "requirements.txt")
1045-
with open(requirements_txt_file, "r") as f:
1046-
modules_requirements_txt = set()
1047-
for pickle_file in pickle_files:
1048-
modules_pickle = self.get_modules_from_pickle_file(pickle_file)
1049-
for line in f:
1050-
module_parts = line.rstrip().split("==")
1051-
module = module_parts[0]
1052-
version = module_parts[1]
1053-
module_version_map[module] = version
1054-
modules_requirements_txt.add(module)
1055-
pip_name_list = list(modules_requirements_txt.union(modules_pickle))
1056-
1057-
for item in pip_name_list:
1058-
if item in module_version_map:
1059-
if module_version_map[item] == "0.0.0":
1060-
print(
1061-
"Warning: No pip install name found for package: "
1062-
+ item.split("==")[0]
1063-
)
1064-
pip_name_list.remove(item)
1065-
1066-
j = json.dumps(
1067-
[
1068-
{
1069-
"step": "install " + i,
1070-
"command": "pip install " + i + "==" + module_version_map[i],
1071-
}
1072-
if i in module_version_map
1073-
else {"step": "install " + i, "command": "pip install " + i}
1074-
for i in pip_name_list
1075-
],
1076-
indent=4,
1077-
)
1078-
with open(os.path.join(jPath, "requirements.json"), "w") as file:
1079-
print(j, file=file)
1080-
1081-
def getNames(self, stream):
1089+
def getPackageNames(self, stream):
    """
    Generates (module, class_name) tuples from a pickle stream. Extracts all class names referenced
    by GLOBAL and STACK_GLOBAL opcodes.

    Parameters
    ----------
    stream : file-like object or bytes
        A file like object or string containing the pickle.

    Yields
    ------
    tuple
        Generated (module, class_name) tuples.
    """

    # The memo is keyed by integer index; a dict accepts PUT/BINPUT/LONG_BINPUT
    # assignments at indices that have not been seen yet.
    stack, markstack, memo = [], [], {}
    mark = pickletools.markobject

    # Step through the pickle opcodes, simulating the stack and mark scheme only as far
    # as needed to recover the two names consumed by STACK_GLOBAL.
    for opcode, arg, pos in pickletools.genops(stream):

        before, after = opcode.stack_before, opcode.stack_after
        numtopop = len(before)

        if opcode.name == "GLOBAL":
            # pickletools reports the argument as "module name" joined by one space
            yield tuple(arg.split(" ", 1))
        elif opcode.name == "STACK_GLOBAL":
            yield (stack[-2], stack[-1])
        elif mark in before or (opcode.name == "POP" and stack and stack[-1] is mark):
            # Unwind everything pushed since the most recent MARK, then the mark itself.
            markstack.pop()
            while stack[-1] is not mark:
                stack.pop()
            stack.pop()
            try:
                numtopop = before.index(mark)
            except ValueError:
                numtopop = 0
        elif opcode.name in {"PUT", "BINPUT", "LONG_BINPUT", "MEMOIZE"}:
            if opcode.name == "MEMOIZE":
                # MEMOIZE stores at the next free memo index
                memo[len(memo)] = stack[-1]
            else:
                memo[arg] = stack[-1]
            numtopop, after = 0, []  # MEMOIZE and PUT opcodes do not pop the stack
        elif opcode.name in {"GET", "BINGET", "LONG_BINGET"}:
            arg = memo[arg]

        if numtopop:
            del stack[-numtopop:]
        if mark in after:
            markstack.append(pos)

        # Push the concrete argument when the opcode yields exactly one value from it;
        # otherwise push the opcode's placeholder stack objects.
        if len(after) == 1 and opcode.arg is not None:
            stack.append(arg)
        else:
            stack.extend(after)

0 commit comments

Comments
 (0)