Skip to content

Commit 25601d9

Browse files
replaced deprecated Pydantic v1 validators
1 parent a453237 commit 25601d9

File tree

3 files changed

+52
-158
lines changed

3 files changed

+52
-158
lines changed

contentctl/objects/correlation_search.py

Lines changed: 28 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -218,117 +218,52 @@ class CorrelationSearch(BaseModel):
218218
logger: logging.Logger = Field(default_factory=get_logger)
219219

220220
# The search name (e.g. "ESCU - Windows Modify Registry EnableLinkedConnections - Rule")
221-
name: Optional[str] = None
221+
@computed_field
222+
@property
223+
def name(self) -> str:
224+
return f"ESCU - {self.detection.name} - Rule"
222225

223226
# The path to the saved search on the Splunk instance
224-
splunk_path: Optional[str] = None
227+
@computed_field
228+
@property
229+
def splunk_path(self) -> str:
230+
return f"/saved/searches/{self.name}"
225231

226232
# A model of the saved search as provided by splunklib
227-
saved_search: Optional[splunklib.SavedSearch] = None
233+
@computed_field
234+
@property
235+
def saved_search(self) -> splunklib.SavedSearch | None:
236+
return splunklib.SavedSearch(
237+
self.service,
238+
self.splunk_path,
239+
)
228240

229241
# The set of indexes to clear on cleanup
230242
indexes_to_purge: set[str] = set()
231243

232244
# The risk analysis adaptive response action (if defined)
233-
risk_analysis_action: Union[RiskAnalysisAction, None] = None
245+
@computed_field
246+
@property
247+
def risk_analysis_action(self) -> RiskAnalysisAction | None:
248+
if not self.saved_search.content:
249+
return None
250+
return CorrelationSearch._get_risk_analysis_action(self.saved_search.content)
234251

235252
# The notable adaptive response action (if defined)
236-
notable_action: Union[NotableAction, None] = None
253+
@computed_field
254+
@property
255+
def notable_action(self) -> NotableAction | None:
256+
if not self.saved_search.content:
257+
return None
258+
return CorrelationSearch._get_notable_action(self.saved_search.content)
237259

238260
# The list of risk events found
239261
_risk_events: Optional[list[RiskEvent]] = PrivateAttr(default=None)
240262

241263
# The list of notable events found
242264
_notable_events: Optional[list[NotableEvent]] = PrivateAttr(default=None)
265+
model_config = ConfigDict(arbitrary_types_allowed=True, extra='forbid')
243266

244-
class Config:
245-
# needed to allow fields w/ types like SavedSearch
246-
arbitrary_types_allowed = True
247-
# We want to have more ridgid typing
248-
extra = 'forbid'
249-
250-
@validator("name", always=True)
251-
@classmethod
252-
def _convert_detection_to_search_name(cls, v, values) -> str:
253-
"""
254-
Validate name and derive if None
255-
"""
256-
if "detection" not in values:
257-
raise ValueError("detection missing; name is dependent on detection")
258-
259-
expected_name = f"ESCU - {values['detection'].name} - Rule"
260-
if v is not None and v != expected_name:
261-
raise ValueError(
262-
"name must be derived from detection; leave as None and it will be derived automatically"
263-
)
264-
return expected_name
265-
266-
@validator("splunk_path", always=True)
267-
@classmethod
268-
def _derive_splunk_path(cls, v, values) -> str:
269-
"""
270-
Validate splunk_path and derive if None
271-
"""
272-
if "name" not in values:
273-
raise ValueError("name missing; splunk_path is dependent on name")
274-
275-
expected_path = f"saved/searches/{values['name']}"
276-
if v is not None and v != expected_path:
277-
raise ValueError(
278-
"splunk_path must be derived from name; leave as None and it will be derived automatically"
279-
)
280-
return f"saved/searches/{values['name']}"
281-
282-
@validator("saved_search", always=True)
283-
@classmethod
284-
def _instantiate_saved_search(cls, v, values) -> str:
285-
"""
286-
Ensure saved_search was initialized as None and derive
287-
"""
288-
if "splunk_path" not in values or "service" not in values:
289-
raise ValueError("splunk_path or service missing; saved_search is dependent on both")
290-
291-
if v is not None:
292-
raise ValueError(
293-
"saved_search must be derived from the service and splunk_path; leave as None and it will be derived "
294-
"automatically"
295-
)
296-
return splunklib.SavedSearch(
297-
values['service'],
298-
values['splunk_path'],
299-
)
300-
301-
@validator("risk_analysis_action", always=True)
302-
@classmethod
303-
def _init_risk_analysis_action(cls, v, values) -> Optional[RiskAnalysisAction]:
304-
"""
305-
Initialize risk_analysis_action
306-
"""
307-
if "saved_search" not in values:
308-
raise ValueError("saved_search missing; risk_analysis_action is dependent on saved_search")
309-
310-
if v is not None:
311-
raise ValueError(
312-
"risk_analysis_action must be derived from the saved_search; leave as None and it will be derived "
313-
"automatically"
314-
)
315-
return CorrelationSearch._get_risk_analysis_action(values['saved_search'].content)
316-
317-
@validator("notable_action", always=True)
318-
@classmethod
319-
def _init_notable_action(cls, v, values) -> Optional[NotableAction]:
320-
"""
321-
Initialize notable_action
322-
"""
323-
if "saved_search" not in values:
324-
raise ValueError("saved_search missing; notable_action is dependent on saved_search")
325-
326-
if v is not None:
327-
raise ValueError(
328-
"notable_action must be derived from the saved_search; leave as None and it will be derived "
329-
"automatically"
330-
)
331-
return CorrelationSearch._get_notable_action(values['saved_search'].content)
332267

333268
@property
334269
def earliest_time(self) -> str:

contentctl/objects/risk_analysis_action.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any
22
import json
33

4-
from pydantic import BaseModel, validator
4+
from pydantic import BaseModel, field_validator
55

66
from contentctl.objects.risk_object import RiskObject
77
from contentctl.objects.threat_object import ThreatObject
@@ -21,32 +21,32 @@ class RiskAnalysisAction(BaseModel):
2121
risk_objects: list[RiskObject]
2222
message: str
2323

24-
@validator("message", always=True, pre=True)
24+
@field_validator("message", mode="before")
2525
@classmethod
26-
def _validate_message(cls, v, values) -> str:
26+
def _validate_message(cls, message) -> str:
2727
"""
2828
Validate splunk_path and derive if None
2929
"""
30-
if v is None:
30+
if message is None:
3131
raise ValueError(
3232
"RiskAnalysisAction.message is a required field, cannot be None. Check the "
3333
"detection YAML definition to ensure a message is defined"
3434
)
3535

36-
if not isinstance(v, str):
36+
if not isinstance(message, str):
3737
raise ValueError(
3838
"RiskAnalysisAction.message must be a string. Check the detection YAML definition "
3939
"to ensure message is defined as a string"
4040
)
4141

42-
if len(v.strip()) < 1:
42+
if len(message.strip()) < 1:
4343
raise ValueError(
4444
"RiskAnalysisAction.message must be a meaningful string, with a length greater than"
4545
"or equal to 1 (once stripped of trailing/leading whitespace). Check the detection "
4646
"YAML definition to ensure message is defined as a meanigful string"
4747
)
4848

49-
return v
49+
return message
5050

5151
@classmethod
5252
def parse_from_dict(cls, dict_: dict[str, Any]) -> "RiskAnalysisAction":

contentctl/objects/ssa_detection_tags.py

Lines changed: 17 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
2-
import re
32
from typing import List
4-
from pydantic import BaseModel, validator, ValidationError, model_validator, Field
3+
from pydantic import BaseModel, computed_field, constr, field_validator, model_validator, Field
54

65
from contentctl.objects.mitre_attack_enrichment import MitreAttackEnrichment
76
from contentctl.objects.constants import *
@@ -13,17 +12,19 @@ class SSADetectionTags(BaseModel):
1312
analytic_story: list
1413
asset_type: str
1514
automated_detection_testing: str = None
16-
cis20: list = None
17-
confidence: int
18-
impact: int
15+
cis20: list[constr(pattern=r"^CIS (\d|1\d|20)$")] = None #DO NOT match leading zeroes and ensure no extra characters before or after the string
16+
confidence: int = Field(..., ge=1, le=100)
17+
impact: int = Field(..., ge=1, le=100)
1918
kill_chain_phases: list = None
2019
message: str
21-
mitre_attack_id: list = None
20+
mitre_attack_id: list[constr(pattern=r"^T[0-9]{4}$")] = None
2221
nist: list = None
2322
observable: list
2423
product: List[SecurityContentProductName] = Field(...,min_length=1)
2524
required_fields: list
26-
risk_score: int
25+
@computed_field
26+
def risk_score(self) -> int:
27+
return round((self.confidence * self.impact)/100)
2728
security_domain: str
2829
risk_severity: str = None
2930
cve: list = None
@@ -51,16 +52,9 @@ class SSADetectionTags(BaseModel):
5152
annotations: dict = None
5253

5354

54-
@validator('cis20')
55-
def tags_cis20(cls, v, values):
56-
pattern = r'^CIS ([\d|1\d|20)$' #DO NOT match leading zeroes and ensure no extra characters before or after the string
57-
for value in v:
58-
if not re.match(pattern, value):
59-
raise ValueError(f"CIS control '{value}' is not a valid Control ('CIS 1' -> 'CIS 20'): {values['name']}")
60-
return v
6155

62-
@validator('nist')
63-
def tags_nist(cls, v, values):
56+
@field_validator('nist', mode='before')
57+
def tags_nist(cls, nist):
6458
# Sourced Courtest of NIST: https://www.nist.gov/system/files/documents/cyberframework/cybersecurity-framework-021214.pdf (Page 19)
6559
IDENTIFY = [f'ID.{category}' for category in ["AM", "BE", "GV", "RA", "RM"] ]
6660
PROTECT = [f'PR.{category}' for category in ["AC", "AT", "DS", "IP", "MA", "PT"]]
@@ -70,53 +64,18 @@ def tags_nist(cls, v, values):
7064
ALL_NIST_CATEGORIES = IDENTIFY + PROTECT + DETECT + RESPOND + RECOVER
7165

7266

73-
for value in v:
74-
if not value in ALL_NIST_CATEGORIES:
67+
for value in nist:
68+
if value not in ALL_NIST_CATEGORIES:
7569
raise ValueError(f"NIST Category '{value}' is not a valid category")
76-
return v
70+
return nist
7771

78-
@validator('confidence')
79-
def tags_confidence(cls, v, values):
80-
v = int(v)
81-
if not (v > 0 and v <= 100):
82-
raise ValueError('confidence score is out of range 1-100.' )
83-
else:
84-
return v
85-
86-
87-
@validator('impact')
88-
def tags_impact(cls, v, values):
89-
if not (v > 0 and v <= 100):
90-
raise ValueError('impact score is out of range 1-100.')
91-
else:
92-
return v
93-
94-
@validator('kill_chain_phases')
95-
def tags_kill_chain_phases(cls, v, values):
72+
@field_validator('kill_chain_phases')
73+
def tags_kill_chain_phases(cls, kill_chain_phases):
9674
valid_kill_chain_phases = SES_KILL_CHAIN_MAPPINGS.keys()
97-
for value in v:
75+
for value in kill_chain_phases:
9876
if value not in valid_kill_chain_phases:
9977
raise ValueError('kill chain phase not valid. Valid options are ' + str(valid_kill_chain_phases))
100-
return v
101-
102-
@validator('mitre_attack_id')
103-
def tags_mitre_attack_id(cls, v, values):
104-
pattern = 'T[0-9]{4}'
105-
for value in v:
106-
if not re.match(pattern, value):
107-
raise ValueError('Mitre Attack ID are not following the pattern Txxxx:' )
108-
return v
109-
110-
111-
112-
@validator('risk_score')
113-
def tags_calculate_risk_score(cls, v, values):
114-
calculated_risk_score = round(values['impact'] * values['confidence'] / 100)
115-
if calculated_risk_score != int(v):
116-
raise ValueError(f"Risk Score must be calculated as round(confidence * impact / 100)"
117-
f"\n Expected risk_score={calculated_risk_score}, found risk_score={int(v)}: {values['name']}")
118-
return v
119-
78+
return kill_chain_phases
12079

12180
@model_validator(mode="after")
12281
def tags_observable(self):

0 commit comments

Comments
 (0)