Skip to content

Commit d3e063a

Browse files
authored
Merge pull request #308 from splunk/cmcginley/mathieugonzales_replace_deprecated_pydantic_validators
Refactoring for formatting and some logical error correction
2 parents 9f30e62 + 31b4b21 commit d3e063a

File tree

8 files changed

+128
-79
lines changed

8 files changed

+128
-79
lines changed

contentctl/actions/detection_testing/infrastructures/DetectionTestingInfrastructure.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ class SetupTestGroupResults(BaseModel):
4848
success: bool = True
4949
duration: float = 0
5050
start_time: float
51-
model_config = ConfigDict(arbitrary_types_allowed=True)
51+
model_config = ConfigDict(
52+
arbitrary_types_allowed=True
53+
)
5254

5355

5456
class CleanupTestGroupResults(BaseModel):
@@ -89,7 +91,9 @@ class DetectionTestingInfrastructure(BaseModel, abc.ABC):
8991
_conn: client.Service = PrivateAttr()
9092
pbar: tqdm.tqdm = None
9193
start_time: Optional[float] = None
92-
model_config = ConfigDict(arbitrary_types_allowed=True)
94+
model_config = ConfigDict(
95+
arbitrary_types_allowed=True
96+
)
9397

9498
def __init__(self, **data):
9599
super().__init__(**data)

contentctl/actions/detection_testing/views/DetectionTestingViewWeb.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
from bottle import template, Bottle, ServerAdapter
2-
from contentctl.actions.detection_testing.views.DetectionTestingView import (
3-
DetectionTestingView,
4-
)
1+
from threading import Thread
52

3+
from bottle import template, Bottle, ServerAdapter
64
from wsgiref.simple_server import make_server, WSGIRequestHandler
75
import jinja2
86
import webbrowser
9-
from threading import Thread
107
from pydantic import ConfigDict
118

9+
from contentctl.actions.detection_testing.views.DetectionTestingView import (
10+
DetectionTestingView,
11+
)
12+
1213
DEFAULT_WEB_UI_PORT = 7999
1314

1415
STATUS_TEMPLATE = """
@@ -101,7 +102,9 @@ def log_exception(*args, **kwargs):
101102
class DetectionTestingViewWeb(DetectionTestingView):
102103
bottleApp: Bottle = Bottle()
103104
server: SimpleWebServer = SimpleWebServer(host="0.0.0.0", port=DEFAULT_WEB_UI_PORT)
104-
model_config = ConfigDict(arbitrary_types_allowed=True)
105+
model_config = ConfigDict(
106+
arbitrary_types_allowed=True
107+
)
105108

106109
def setup(self):
107110
self.bottleApp.route("/", callback=self.showStatus)

contentctl/enrichments/cve_enrichment.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,12 @@ def url(self)->str:
3232
class CveEnrichment(BaseModel):
3333
use_enrichment: bool = True
3434
cve_api_obj: Union[CVESearch,None] = None
35-
35+
3636
# Arbitrary_types are allowed to let us use the CVESearch Object
37-
model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
37+
model_config = ConfigDict(
38+
arbitrary_types_allowed=True,
39+
frozen=True
40+
)
3841

3942
@staticmethod
4043
def getCveEnrichment(config:validate, timeout_seconds:int=10, force_disable_enrichment:bool=True)->CveEnrichment:

contentctl/objects/base_test_result.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import Enum
33

44
from pydantic import ConfigDict, BaseModel
5-
from splunklib.data import Record
5+
from splunklib.data import Record # type: ignore
66

77
from contentctl.helper.utils import Utils
88

@@ -52,9 +52,12 @@ class BaseTestResult(BaseModel):
5252

5353
# The Splunk endpoint URL
5454
sid_link: Union[None, str] = None
55-
55+
5656
# Needed to allow for embedding of Exceptions in the model
57-
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
57+
model_config = ConfigDict(
58+
validate_assignment=True,
59+
arbitrary_types_allowed=True
60+
)
5861

5962
@property
6063
def passed(self) -> bool:

contentctl/objects/correlation_search.py

Lines changed: 84 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import logging
22
import time
33
import json
4-
from typing import Union, Optional, Any
4+
from typing import Any
55
from enum import Enum
6+
from functools import cached_property
67

78
from pydantic import ConfigDict, BaseModel, computed_field, Field, PrivateAttr
89
from splunklib.results import JSONResultsReader, Message # type: ignore
@@ -15,7 +16,7 @@
1516
from contentctl.objects.base_test_result import TestResultStatus
1617
from contentctl.objects.integration_test_result import IntegrationTestResult
1718
from contentctl.actions.detection_testing.progress_bar import (
18-
format_pbar_string,
19+
format_pbar_string, # type: ignore
1920
TestReportingType,
2021
TestingStates
2122
)
@@ -178,12 +179,14 @@ class PbarData(BaseModel):
178179
:param fq_test_name: the fully qualifed (fq) test name ("<detection_name>:<test_name>") used for logging
179180
:param start_time: the start time used for logging
180181
"""
181-
pbar: tqdm
182+
pbar: tqdm # type: ignore
182183
fq_test_name: str
183184
start_time: float
184-
185+
185186
# needed to support the tqdm type
186-
model_config = ConfigDict(arbitrary_types_allowed=True)
187+
model_config = ConfigDict(
188+
arbitrary_types_allowed=True
189+
)
187190

188191

189192
class CorrelationSearch(BaseModel):
@@ -196,78 +199,110 @@ class CorrelationSearch(BaseModel):
196199
:param pbar_data: the encapsulated info needed for logging w/ pbar
197200
:param test_index: the index attack data is forwarded to for testing (optionally used in cleanup)
198201
"""
199-
## The following three fields are explicitly needed at instantiation # noqa: E266
200-
201202
# the detection associated with the correlation search (e.g. "Windows Modify Registry EnableLinkedConnections")
202-
detection: Detection
203+
detection: Detection = Field(...)
203204

204205
# a Service instance representing a connection to a Splunk instance
205-
service: splunklib.Service
206+
service: splunklib.Service = Field(...)
206207

207208
# the encapsulated info needed for logging w/ pbar
208-
pbar_data: PbarData
209-
210-
## The following field is optional for instantiation # noqa: E266
209+
pbar_data: PbarData = Field(...)
211210

212211
# The index attack data is sent to; can be None if we are relying on the caller to do our
213212
# cleanup of this index
214-
test_index: Optional[str] = Field(default=None, min_length=1)
215-
216-
## All remaining fields can be derived from other fields or have intentional defaults that # noqa: E266
217-
## should not be changed (validators should prevent instantiating some of these fields directly # noqa: E266
218-
## to prevent undefined behavior) # noqa: E266
213+
test_index: str | None = Field(default=None, min_length=1)
219214

220215
# The logger to use (logs all go to a null pipe unless ENABLE_LOGGING is set to True, so as not
221216
# to conflict w/ tqdm)
222-
logger: logging.Logger = Field(default_factory=get_logger)
217+
logger: logging.Logger = Field(default_factory=get_logger, init=False)
218+
219+
# The set of indexes to clear on cleanup
220+
indexes_to_purge: set[str] = Field(default=set(), init=False)
221+
222+
# The risk analysis adaptive response action (if defined)
223+
_risk_analysis_action: RiskAnalysisAction | None = PrivateAttr(default=None)
224+
225+
# The notable adaptive response action (if defined)
226+
_notable_action: NotableAction | None = PrivateAttr(default=None)
227+
228+
# The list of risk events found
229+
_risk_events: list[RiskEvent] | None = PrivateAttr(default=None)
230+
231+
# The list of notable events found
232+
_notable_events: list[NotableEvent] | None = PrivateAttr(default=None)
233+
234+
# Need arbitrary types to allow fields w/ types like SavedSearch; we also want to forbid
235+
# unexpected fields
236+
model_config = ConfigDict(
237+
arbitrary_types_allowed=True,
238+
extra='forbid'
239+
)
240+
241+
def model_post_init(self, __context: Any) -> None:
242+
super().model_post_init(__context)
243+
244+
# Parse the initial values for the risk/notable actions
245+
self._parse_risk_and_notable_actions()
223246

224-
# The search name (e.g. "ESCU - Windows Modify Registry EnableLinkedConnections - Rule")
225247
@computed_field
226-
@property
248+
@cached_property
227249
def name(self) -> str:
250+
"""
251+
The search name (e.g. "ESCU - Windows Modify Registry EnableLinkedConnections - Rule")
252+
253+
:returns: the search name
254+
:rtype: str
255+
"""
228256
return f"ESCU - {self.detection.name} - Rule"
229257

230-
# The path to the saved search on the Splunk instance
231258
@computed_field
232-
@property
259+
@cached_property
233260
def splunk_path(self) -> str:
261+
"""
262+
The path to the saved search on the Splunk instance
263+
264+
:returns: the search path
265+
:rtype: str
266+
"""
234267
return f"/saved/searches/{self.name}"
235268

236-
# A model of the saved search as provided by splunklib
237269
@computed_field
238-
@property
239-
def saved_search(self) -> splunklib.SavedSearch | None:
270+
@cached_property
271+
def saved_search(self) -> splunklib.SavedSearch:
272+
"""
273+
A model of the saved search as provided by splunklib
274+
275+
:returns: the SavedSearch object
276+
:rtype: :class:`splunklib.client.SavedSearch`
277+
"""
240278
return splunklib.SavedSearch(
241279
self.service,
242280
self.splunk_path,
243281
)
244282

245-
# The set of indexes to clear on cleanup
246-
indexes_to_purge: set[str] = set()
247-
248-
# The risk analysis adaptive response action (if defined)
283+
# TODO (cmcginley): need to make this refreshable
249284
@computed_field
250285
@property
251286
def risk_analysis_action(self) -> RiskAnalysisAction | None:
252-
if not self.saved_search.content:
253-
return None
254-
return CorrelationSearch._get_risk_analysis_action(self.saved_search.content)
287+
"""
288+
The risk analysis adaptive response action (if defined)
255289
256-
# The notable adaptive response action (if defined)
290+
:returns: the RiskAnalysisAction object, if it exists
291+
:rtype: :class:`contentctl.objects.risk_analysis_action.RiskAnalysisAction` | None
292+
"""
293+
return self._risk_analysis_action
294+
295+
# TODO (cmcginley): need to make this refreshable
257296
@computed_field
258297
@property
259298
def notable_action(self) -> NotableAction | None:
260-
if not self.saved_search.content:
261-
return None
262-
return CorrelationSearch._get_notable_action(self.saved_search.content)
263-
264-
# The list of risk events found
265-
_risk_events: Optional[list[RiskEvent]] = PrivateAttr(default=None)
266-
267-
# The list of notable events found
268-
_notable_events: Optional[list[NotableEvent]] = PrivateAttr(default=None)
269-
model_config = ConfigDict(arbitrary_types_allowed=True, extra='forbid')
299+
"""
300+
The notable adaptive response action (if defined)
270301
302+
:returns: the NotableAction object, if it exists
303+
:rtype: :class:`contentctl.objects.notable_action.NotableAction` | None
304+
"""
305+
return self._notable_action
271306

272307
@property
273308
def earliest_time(self) -> str:
@@ -327,7 +362,7 @@ def has_notable_action(self) -> bool:
327362
return self.notable_action is not None
328363

329364
@staticmethod
330-
def _get_risk_analysis_action(content: dict[str, Any]) -> Optional[RiskAnalysisAction]:
365+
def _get_risk_analysis_action(content: dict[str, Any]) -> RiskAnalysisAction | None:
331366
"""
332367
Given the saved search content, parse the risk analysis action
333368
:param content: a dict of strings to values
@@ -341,7 +376,7 @@ def _get_risk_analysis_action(content: dict[str, Any]) -> Optional[RiskAnalysisA
341376
return None
342377

343378
@staticmethod
344-
def _get_notable_action(content: dict[str, Any]) -> Optional[NotableAction]:
379+
def _get_notable_action(content: dict[str, Any]) -> NotableAction | None:
345380
"""
346381
Given the saved search content, parse the notable action
347382
:param content: a dict of strings to values
@@ -365,10 +400,6 @@ def _get_relevant_observables(observables: list[Observable]) -> list[Observable]
365400
relevant.append(observable)
366401
return relevant
367402

368-
# TODO (PEX-484): ideally, we could handle this and the following init w/ a call to
369-
# model_post_init, so that all the logic is encapsulated w/in _parse_risk_and_notable_actions
370-
# but that is a pydantic v2 feature (see the init validators for risk/notable actions):
371-
# https://docs.pydantic.dev/latest/api/base_model/#pydantic.main.BaseModel.model_post_init
372403
def _parse_risk_and_notable_actions(self) -> None:
373404
"""Parses the risk/notable metadata we care about from self.saved_search.content
374405
@@ -379,12 +410,12 @@ def _parse_risk_and_notable_actions(self) -> None:
379410
unpacked to be anything other than a singleton
380411
"""
381412
# grab risk details if present
382-
self.risk_analysis_action = CorrelationSearch._get_risk_analysis_action(
413+
self._risk_analysis_action = CorrelationSearch._get_risk_analysis_action(
383414
self.saved_search.content # type: ignore
384415
)
385416

386417
# grab notable details if present
387-
self.notable_action = CorrelationSearch._get_notable_action(self.saved_search.content) # type: ignore
418+
self._notable_action = CorrelationSearch._get_notable_action(self.saved_search.content) # type: ignore
388419

389420
def refresh(self) -> None:
390421
"""Refreshes the metadata in the SavedSearch entity, and re-parses the fields we care about
@@ -672,7 +703,7 @@ def validate_risk_events(self) -> None:
672703
# TODO (#250): Re-enable and refactor code that validates the specific risk counts
673704
# Validate risk events in aggregate; we should have an equal amount of risk events for each
674705
# relevant observable, and the total count should match the total number of events
675-
# individual_count: Optional[int] = None
706+
# individual_count: int | None = None
676707
# total_count = 0
677708
# for observable_str in observable_counts:
678709
# self.logger.debug(
@@ -736,7 +767,7 @@ def test(self, max_sleep: int = TimeoutConfig.MAX_SLEEP.value, raise_on_exc: boo
736767
)
737768

738769
# initialize result as None
739-
result: Optional[IntegrationTestResult] = None
770+
result: IntegrationTestResult | None = None
740771

741772
# keep track of time slept and number of attempts for exponential backoff (base 2)
742773
elapsed_sleep_time = 0

contentctl/objects/notable_event.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ class NotableEvent(BaseModel):
1010

1111
# The search ID that found that generated this risk event
1212
orig_sid: str
13-
13+
1414
# Allowing fields that aren't explicitly defined to be passed since some of the risk event's
1515
# fields vary depending on the SPL which generated them
16-
model_config = ConfigDict(extra='allow')
16+
model_config = ConfigDict(
17+
extra='allow'
18+
)
1719

1820
def validate_against_detection(self, detection: Detection) -> None:
1921
raise NotImplementedError()

contentctl/objects/risk_analysis_action.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,30 +23,30 @@ class RiskAnalysisAction(BaseModel):
2323

2424
@field_validator("message", mode="before")
2525
@classmethod
26-
def _validate_message(cls, message) -> str:
26+
def _validate_message(cls, v: Any) -> str:
2727
"""
28-
Validate splunk_path and derive if None
28+
Validate message and derive if None
2929
"""
30-
if message is None:
30+
if v is None:
3131
raise ValueError(
3232
"RiskAnalysisAction.message is a required field, cannot be None. Check the "
3333
"detection YAML definition to ensure a message is defined"
3434
)
3535

36-
if not isinstance(message, str):
36+
if not isinstance(v, str):
3737
raise ValueError(
3838
"RiskAnalysisAction.message must be a string. Check the detection YAML definition "
3939
"to ensure message is defined as a string"
4040
)
4141

42-
if len(message.strip()) < 1:
42+
if len(v.strip()) < 1:
4343
raise ValueError(
4444
"RiskAnalysisAction.message must be a meaningful string, with a length greater than"
4545
"or equal to 1 (once stripped of trailing/leading whitespace). Check the detection "
4646
"YAML definition to ensure message is defined as a meanigful string"
4747
)
4848

49-
return message
49+
return v
5050

5151
@classmethod
5252
def parse_from_dict(cls, dict_: dict[str, Any]) -> "RiskAnalysisAction":

0 commit comments

Comments
 (0)