
Commit 47d1448

Ashwin Ramesh authored and dzier committed
Changed serialization library to json for security
1 parent 3d242a9 commit 47d1448

File tree: 11 files changed (+148, -39 lines)
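Why this change is a security fix, briefly: pickle.load executes whatever constructors and reducers are embedded in a checkpoint, so loading a tampered .ckpt file can run arbitrary code, whereas json.load only ever produces dicts, lists, strings, numbers, booleans, and None. The commit therefore gives each stateful object explicit serialize()/deserialize() methods that round-trip through plain JSON-compatible values. Below is a minimal sketch of that pattern with made-up names (Metric, default_encode); it mirrors the approach in the diff but is not code from this commit.

import json

class Metric:
    # Toy stand-in for a Record-like object (hypothetical).
    def __init__(self, value=0):
        self.value = value

    def serialize(self):
        # Emit only JSON-friendly builtins: no classes, no code objects.
        return self.__dict__

    def deserialize(self, metric_dict):
        for key, val in metric_dict.items():
            setattr(self, key, val)

def default_encode(obj):
    # Fallback handed to json.dump for anything that is not a builtin type.
    return obj.serialize() if hasattr(obj, 'serialize') else obj.__dict__

state = {'throughput': Metric(42)}
text = json.dumps(state, default=default_encode)   # '{"throughput": {"value": 42}}'

restored = Metric()
restored.deserialize(json.loads(text)['throughput'])
assert restored.value == 42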

model_analyzer/perf_analyzer/perf_config.py

Lines changed: 4 additions & 0 deletions
@@ -95,6 +95,10 @@ def update_config(self, params=None):
         for key in params:
             self[key] = params[key]

+    def deserialize(self, perf_config_dict):
+        for key, val in perf_config_dict.items():
+            setattr(self, key, val)
+
     def representation(self):
         """
         Returns

model_analyzer/record/record.py

Lines changed: 7 additions & 1 deletion
@@ -90,7 +90,6 @@ class Record(metaclass=RecordType):
     This class is used for representing
     records
     """
-
    def __init__(self, value, timestamp):
         """
         Parameters
@@ -144,6 +143,13 @@ def tag(self):
         the name tag of the record type.
         """

+    def serialize(self):
+        return (self.tag, self.__dict__)
+
+    def deserialize(self, record_dict):
+        for key, val in record_dict.items():
+            setattr(self, key, val)
+
     def value(self):
         """
         This method returns the value of recorded metric
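A side note on the serialize() return type above (an observation, not text from the commit): the (tag, __dict__) tuple is written out by json.dump as a two-element JSON array, which is why the deserializers elsewhere in this commit unpack each entry as [tag, record_dict]. A tiny stdlib-only illustration with an invented tag and attribute names:

import json

serialized = ('perf_latency', {'_value': 3.5, '_timestamp': 0})  # illustrative tag/attrs
text = json.dumps(serialized)        # the tuple becomes a JSON array: '["perf_latency", {...}]'
tag, record_dict = json.loads(text)  # comes back as a 2-element list and unpacks cleanly
assert tag == 'perf_latency' and record_dict['_value'] == 3.5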

model_analyzer/reports/report_manager.py

Lines changed: 1 addition & 1 deletion
@@ -604,7 +604,7 @@ def _get_gpu_stats(self, measurements):
         for measurement in measurements:
             for gpu_uuid, gpu_info in self._gpu_info.items():
                 if gpu_uuid in measurement.gpus_used():
-                    gpu_name = (gpu_info['name']).decode('ascii')
+                    gpu_name = gpu_info['name']
                     max_memory = round(gpu_info['total_memory'] / (2**30), 1)
                     if gpu_name not in gpu_dict:
                         gpu_dict[gpu_name] = max_memory

model_analyzer/result/measurement.py

Lines changed: 37 additions & 1 deletion
@@ -14,7 +14,9 @@

 from functools import total_ordering
 import logging
+from model_analyzer.perf_analyzer.perf_config import PerfAnalyzerConfig

+from model_analyzer.record.record import RecordType
 from model_analyzer.model_analyzer_exceptions \
     import TritonModelAnalyzerException

@@ -39,10 +41,44 @@ def __init__(self, gpu_data, non_gpu_data, perf_config):

         # average values over all GPUs
         self._gpu_data = gpu_data
-        self._avg_gpu_data = self._average_list(list(self._gpu_data.values()))
         self._non_gpu_data = non_gpu_data
         self._perf_config = perf_config

+        self._avg_gpu_data = self._average_list(list(self._gpu_data.values()))
+        self._gpu_data_from_tag = {
+            type(metric).tag: metric
+            for metric in self._avg_gpu_data
+        }
+        self._non_gpu_data_from_tag = {
+            type(metric).tag: metric
+            for metric in self._non_gpu_data
+        }
+
+    def deserialize(self, measurement_dict):
+        # Deserialize gpu_data
+        for gpu_uuid, gpu_data_list in measurement_dict['_gpu_data'].items():
+            metric_list = []
+            for [tag, record_dict] in gpu_data_list:
+                record_type = RecordType.get(tag)
+                record = record_type(0)
+                record.deserialize(record_dict)
+                metric_list.append(record)
+            self._gpu_data[gpu_uuid] = metric_list
+
+        # non gpu data
+        self._non_gpu_data = []
+        for [tag, record_dict] in measurement_dict['_non_gpu_data']:
+            record_type = RecordType.get(tag)
+            record = record_type(0)
+            record.deserialize(record_dict)
+            self._non_gpu_data.append(record)
+
+        # perf config
+        self._perf_config = PerfAnalyzerConfig()
+        self._perf_config.deserialize(measurement_dict['_perf_config'])
+
+        # Compute contingent data structures
+        self._avg_gpu_data = self._average_list(list(self._gpu_data.values()))
         self._gpu_data_from_tag = {
             type(metric).tag: metric
             for metric in self._avg_gpu_data
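For orientation, deserialize() above expects the attribute-keyed dict that json.dump produces when it falls back to an object's __dict__ or serialize(). A hedged sketch of that shape, with invented GPU UUIDs, tags, record attributes, and config contents:

measurement_dict = {
    '_gpu_data': {
        # gpu_uuid -> list of [tag, record_dict] pairs
        'GPU-0000': [['gpu_used_memory', {'_value': 1024, '_timestamp': 0}]],
    },
    # flat list of [tag, record_dict] pairs
    '_non_gpu_data': [['perf_throughput', {'_value': 300.0, '_timestamp': 0}]],
    # plain dict of PerfAnalyzerConfig attributes
    '_perf_config': {'model-name': 'resnet50'},
}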

model_analyzer/state/analyzer_state.py

Lines changed: 46 additions & 3 deletions
@@ -12,19 +12,62 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from model_analyzer.model_analyzer_exceptions \
-    import TritonModelAnalyzerException
+from model_analyzer.triton.model.model_config import ModelConfig
+from model_analyzer.result.measurement import Measurement
+from model_analyzer.record.record import RecordType


 class AnalyzerState:
     """
     All the state information needed by
     Model Analyzer in one place
     """
-
     def __init__(self):
         self._state_dict = {}

+    def serialize(self):
+        return self._state_dict
+
+    def deserialize(self, state_dict):
+        # Fill results
+        self._state_dict['ResultManager.results'] = {}
+        for model_name in state_dict['ResultManager.results']:
+            self._state_dict['ResultManager.results'][model_name] = {}
+            for model_config_name in state_dict['ResultManager.results'][
+                    model_name]:
+                model_config_dict, measurements = state_dict[
+                    'ResultManager.results'][model_name][model_config_name]
+
+                # Deserialize model config
+                model_config = ModelConfig(None)
+                model_config.deserialize(model_config_dict)
+
+                # Deserialize measurements
+                measurements_dict = {}
+                for measurement_key, measurement_dict in measurements.items():
+                    measurement = Measurement({}, [], None)
+                    measurement.deserialize(measurement_dict)
+                    measurements_dict[measurement_key] = measurement
+                self._state_dict['ResultManager.results'][model_name][
+                    model_config_name] = (model_config, measurements_dict)
+
+        # Server data
+        self._state_dict['ResultManager.server_only_data'] = {}
+        for gpu_uuid, gpu_data_list in state_dict[
+                'ResultManager.server_only_data'].items():
+            metric_list = []
+            for [tag, record_dict] in gpu_data_list:
+                record_type = RecordType.get(tag)
+                record = record_type(0)
+                record.deserialize(record_dict)
+                metric_list.append(record)
+            self._state_dict['ResultManager.server_only_data'][
+                gpu_uuid] = metric_list
+
+        # GPU data
+        self._state_dict['MetricsManager.gpus'] = state_dict[
+            'MetricsManager.gpus']
+
     def get(self, name):
         if name in self._state_dict:
             return self._state_dict[name]
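The top-level keys that deserialize() handles are taken from the code above; the leaf values below are placeholders standing in for the dicts produced by ModelConfig.serialize(), Measurement attributes, and Record.serialize(). Names such as resnet50 are illustrative only:

model_config_dict = {'cpu_only': False}
measurement_dict = {'_gpu_data': {}, '_non_gpu_data': [], '_perf_config': {}}
record_dict = {'_value': 0, '_timestamp': 0}

state_dict = {
    'ResultManager.results': {
        'resnet50': {                                      # model name
            'resnet50_config_0': [model_config_dict,       # JSON turns the (config, measurements)
                                  {'key0': measurement_dict}],  # tuple into a 2-element list
        },
    },
    'ResultManager.server_only_data': {
        'GPU-0000': [['gpu_used_memory', record_dict]],    # gpu_uuid -> [[tag, record_dict], ...]
    },
    'MetricsManager.gpus': {'GPU-0000': 'Tesla V100'},     # copied through unchanged
}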

model_analyzer/state/analyzer_state_manager.py

Lines changed: 14 additions & 6 deletions
@@ -20,7 +20,7 @@

 import signal
 import logging
-import pickle
+import json
 import os
 import glob

@@ -105,19 +105,27 @@ def load_checkpoint(self):
         if os.path.exists(latest_checkpoint_file):
             logging.info(
                 f"Loaded checkpoint from file {latest_checkpoint_file}")
-            with open(latest_checkpoint_file, 'rb') as f:
+            with open(latest_checkpoint_file, 'r') as f:
                 try:
-                    state = pickle.load(f)
+
+                    self._current_state.deserialize(json.load(f))
                 except EOFError:
                     raise TritonModelAnalyzerException(
                         f'Checkpoint file {latest_checkpoint_file} is'
                         ' empty or corrupted. Remove it from checkpoint'
                         ' directory.')
-            self._current_state = state
             self._starting_fresh_run = False
         else:
             logging.info("No checkpoint file found, starting a fresh run.")

+    def default_encode(self, obj):
+        if isinstance(obj, bytes):
+            return obj.decode('utf-8')
+        elif hasattr(obj, 'serialize'):
+            return obj.serialize()
+        else:
+            return obj.__dict__
+
     def save_checkpoint(self):
         """
         Saves the state of the model analyzer to disk
@@ -132,8 +140,8 @@ def save_checkpoint(self):
         ckpt_filename = os.path.join(self._checkpoint_dir,
                                      f"{self._checkpoint_index}.ckpt")
         if self._state_changed:
-            with open(ckpt_filename, 'wb') as f:
-                pickle.dump(self._current_state, f)
+            with open(ckpt_filename, 'w') as f:
+                json.dump(self._current_state, f, default=self.default_encode)
             logging.info(f"Saved checkpoint to {ckpt_filename}.")

         self._checkpoint_index += 1
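One behavioral detail worth keeping in mind when reading the load path above (an observation, not part of the commit): an empty pickle stream raises EOFError, but an empty or truncated JSON file raises json.JSONDecodeError, a ValueError subclass, so the retained except EOFError clause targets the old failure mode. A quick stdlib-only check:

import io
import json
import pickle

try:
    pickle.load(io.BytesIO(b''))      # empty pickle stream
except EOFError as e:
    print('pickle:', type(e).__name__)   # EOFError

try:
    json.load(io.StringIO(''))        # empty JSON document
except json.JSONDecodeError as e:
    print('json:', type(e).__name__)     # JSONDecodeError (subclass of ValueError)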

model_analyzer/triton/model/model_config.py

Lines changed: 18 additions & 6 deletions
@@ -28,7 +28,6 @@ class ModelConfig:
     """
     A class that encapsulates all the metadata about a Triton model.
     """
-
     def __init__(self, model_config):
         """
         Parameters
@@ -39,6 +38,11 @@ def __init__(self, model_config):
         self._model_config = model_config
         self._cpu_only = False

+    def serialize(self):
+        model_config_dict = json_format.MessageToDict(self._model_config)
+        model_config_dict['cpu_only'] = self._cpu_only
+        return model_config_dict
+
     def __getstate__(self):
         """
         Allows serialization of
@@ -49,6 +53,13 @@ def __getstate__(self):
         model_config_dict['cpu_only'] = self._cpu_only
         return model_config_dict

+    def deserialize(self, model_config_dict):
+        self._cpu_only = model_config_dict['cpu_only']
+        del model_config_dict['cpu_only']
+        protobuf_message = json_format.ParseDict(
+            model_config_dict, model_config_pb2.ModelConfig())
+        self._model_config = protobuf_message
+
     def __setstate__(self, model_config_dict):
         """
         Allows deserialization of
@@ -57,8 +68,8 @@ def __setstate__(self, model_config_dict):

         self._cpu_only = model_config_dict['cpu_only']
         del model_config_dict['cpu_only']
-        protobuf_message = json_format.ParseDict(model_config_dict,
-                                                 model_config_pb2.ModelConfig())
+        protobuf_message = json_format.ParseDict(
+            model_config_dict, model_config_pb2.ModelConfig())
         self._model_config = protobuf_message

     @staticmethod
@@ -114,8 +125,8 @@ def create_from_dictionary(model_dict):
         ModelConfig
         """

-        protobuf_message = json_format.ParseDict(model_dict,
-                                                 model_config_pb2.ModelConfig())
+        protobuf_message = json_format.ParseDict(
+            model_dict, model_config_pb2.ModelConfig())

         return ModelConfig(protobuf_message)

@@ -159,7 +170,8 @@ def cpu_only(self):

         return self._cpu_only

-    def write_config_to_file(self, model_path, src_model_path, last_model_path):
+    def write_config_to_file(self, model_path, src_model_path,
+                             last_model_path):
         """
         Writes a protobuf config file.

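The serialize()/deserialize() pair above leans on protobuf's json_format helpers, MessageToDict and ParseDict, which convert any generated message to and from a plain dict. A self-contained round trip using a stock protobuf message as a stand-in (the real code uses ModelConfig from model_config_pb2, which works the same way):

from google.protobuf import json_format
from google.protobuf.descriptor_pb2 import FileDescriptorProto

msg = FileDescriptorProto(name='model.proto', package='demo')
as_dict = json_format.MessageToDict(msg)                 # {'name': 'model.proto', 'package': 'demo'}
restored = json_format.ParseDict(as_dict, FileDescriptorProto())
assert restored.package == 'demo'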

qa/L0_server_launch_modes/test.sh

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ CLIENT_PROTOCOLS="http grpc"
 TRITON_DOCKER_IMAGE="nvcr.io/nvidia/tritonserver:21.05-py3"

 mkdir $CHECKPOINT_DIRECTORY
-cp $CHECKPOINT_REPOSITORY/server_launch_modes.ckpt $CHECKPOINT_DIRECTORY/0.ckpt
+# cp $CHECKPOINT_REPOSITORY/server_launch_modes.ckpt $CHECKPOINT_DIRECTORY/0.ckpt

 # Run the model-analyzer, both client protocols
 RET=0

tests/mocks/mock_pickle.py renamed to tests/mocks/mock_json.py

Lines changed: 12 additions & 12 deletions
@@ -16,41 +16,41 @@
 from unittest.mock import Mock, MagicMock, patch


-class MockPickleMethods(MockBase):
+class MockJSONMethods(MockBase):
     """
     Mocks the methods for the os module
     """
     def __init__(self):
-        pickle_attrs = {'load': MagicMock(), 'dump': MagicMock()}
-        self.patcher_pickle = patch(
-            'model_analyzer.state.analyzer_state_manager.pickle',
-            Mock(**pickle_attrs))
+        json_attrs = {'load': MagicMock(), 'dump': MagicMock()}
+        self.patcher_json = patch(
+            'model_analyzer.state.analyzer_state_manager.json',
+            Mock(**json_attrs))
         super().__init__()

     def start(self):
         """
         start the patchers
         """

-        self.pickle_mock = self.patcher_pickle.start()
+        self.json_mock = self.patcher_json.start()

     def _fill_patchers(self):
         """
         Fills the patcher list for destruction
         """

-        self._patchers.append(self.patcher_pickle)
+        self._patchers.append(self.patcher_json)

-    def set_pickle_load_return_value(self, value):
+    def set_json_load_return_value(self, value):
         """
-        Sets the return value for pickle load
+        Sets the return value for json load
         """

-        self.pickle_mock.load.return_value = value
+        self.json_mock.load.return_value = value

-    def set_pickle_load_side_effect(self, effect):
+    def set_json_load_side_effect(self, effect):
         """
         Sets a side effet
         """

-        self.pickle_mock.load.side_effect = effect
+        self.json_mock.load.side_effect = effect
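The patch target string above matters: unittest.mock.patch replaces the name where it is looked up, so patching 'model_analyzer.state.analyzer_state_manager.json' swaps out the json module as seen by the state manager, not the global module. A toy demonstration of the same idea against the stdlib (patching json.load directly, purely for illustration):

import json
from unittest.mock import patch

with patch('json.load', return_value={'MetricsManager.gpus': {}}) as load_mock:
    assert json.load(None) == {'MetricsManager.gpus': {}}   # the real parser never runs
load_mock.assert_called_once_with(None)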

tests/test_report_manager.py

Lines changed: 4 additions & 4 deletions
@@ -28,7 +28,7 @@
 from .mocks.mock_io import MockIOMethods
 from .mocks.mock_matplotlib import MockMatplotlibMethods
 from .mocks.mock_os import MockOSMethods
-from .mocks.mock_pickle import MockPickleMethods
+from .mocks.mock_json import MockJSONMethods

 from .common.test_utils import construct_measurement
 from .common import test_result_collector as trc
@@ -115,8 +115,8 @@ def setUp(self):
         self.io_mock.start()
         self.matplotlib_mock = MockMatplotlibMethods()
         self.matplotlib_mock.start()
-        self.pickle_mock = MockPickleMethods()
-        self.pickle_mock.start()
+        self.json_mock = MockJSONMethods()
+        self.json_mock.start()

     def test_add_results(self):
         for mode in ['online', 'offline']:
@@ -221,7 +221,7 @@ def tearDown(self):
         self.matplotlib_mock.stop()
         self.io_mock.stop()
         self.os_mock.stop()
-        self.pickle_mock.stop()
+        self.json_mock.stop()


 if __name__ == '__main__':
