1
1
from __future__ import annotations
2
- import re
3
2
from typing import List
4
- from pydantic import BaseModel , validator , ValidationError , model_validator , Field
3
+ from pydantic import BaseModel , computed_field , constr , field_validator , model_validator , Field
5
4
6
5
from contentctl .objects .mitre_attack_enrichment import MitreAttackEnrichment
7
6
from contentctl .objects .constants import *
@@ -13,17 +12,19 @@ class SSADetectionTags(BaseModel):
13
12
analytic_story : list
14
13
asset_type : str
15
14
automated_detection_testing : str = None
16
- cis20 : list = None
17
- confidence : int
18
- impact : int
15
+ cis20 : list [ constr ( pattern = r"^CIS (\d|1\d|20)$" )] = None #DO NOT match leading zeroes and ensure no extra characters before or after the string
16
+ confidence : int = Field (..., ge = 1 , le = 100 )
17
+ impact : int = Field (..., ge = 1 , le = 100 )
19
18
kill_chain_phases : list = None
20
19
message : str
21
- mitre_attack_id : list = None
20
+ mitre_attack_id : list [ constr ( pattern = r"^T[0-9]{4}$" )] = None
22
21
nist : list = None
23
22
observable : list
24
23
product : List [SecurityContentProductName ] = Field (...,min_length = 1 )
25
24
required_fields : list
26
- risk_score : int
25
+ @computed_field
26
+ def risk_score (self ) -> int :
27
+ return round ((self .confidence * self .impact )/ 100 )
27
28
security_domain : str
28
29
risk_severity : str = None
29
30
cve : list = None
@@ -51,16 +52,9 @@ class SSADetectionTags(BaseModel):
51
52
annotations : dict = None
52
53
53
54
54
- @validator ('cis20' )
55
- def tags_cis20 (cls , v , values ):
56
- pattern = r'^CIS ([\d|1\d|20)$' #DO NOT match leading zeroes and ensure no extra characters before or after the string
57
- for value in v :
58
- if not re .match (pattern , value ):
59
- raise ValueError (f"CIS control '{ value } ' is not a valid Control ('CIS 1' -> 'CIS 20'): { values ['name' ]} " )
60
- return v
61
55
62
- @validator ('nist' )
63
- def tags_nist (cls , v , values ):
56
+ @field_validator ('nist' , mode = 'before ' )
57
+ def tags_nist (cls , nist ):
64
58
# Sourced Courtest of NIST: https://www.nist.gov/system/files/documents/cyberframework/cybersecurity-framework-021214.pdf (Page 19)
65
59
IDENTIFY = [f'ID.{ category } ' for category in ["AM" , "BE" , "GV" , "RA" , "RM" ] ]
66
60
PROTECT = [f'PR.{ category } ' for category in ["AC" , "AT" , "DS" , "IP" , "MA" , "PT" ]]
@@ -70,53 +64,18 @@ def tags_nist(cls, v, values):
70
64
ALL_NIST_CATEGORIES = IDENTIFY + PROTECT + DETECT + RESPOND + RECOVER
71
65
72
66
73
- for value in v :
74
- if not value in ALL_NIST_CATEGORIES :
67
+ for value in nist :
68
+ if value not in ALL_NIST_CATEGORIES :
75
69
raise ValueError (f"NIST Category '{ value } ' is not a valid category" )
76
- return v
70
+ return nist
77
71
78
- @validator ('confidence' )
79
- def tags_confidence (cls , v , values ):
80
- v = int (v )
81
- if not (v > 0 and v <= 100 ):
82
- raise ValueError ('confidence score is out of range 1-100.' )
83
- else :
84
- return v
85
-
86
-
87
- @validator ('impact' )
88
- def tags_impact (cls , v , values ):
89
- if not (v > 0 and v <= 100 ):
90
- raise ValueError ('impact score is out of range 1-100.' )
91
- else :
92
- return v
93
-
94
- @validator ('kill_chain_phases' )
95
- def tags_kill_chain_phases (cls , v , values ):
72
+ @field_validator ('kill_chain_phases' )
73
+ def tags_kill_chain_phases (cls , kill_chain_phases ):
96
74
valid_kill_chain_phases = SES_KILL_CHAIN_MAPPINGS .keys ()
97
- for value in v :
75
+ for value in kill_chain_phases :
98
76
if value not in valid_kill_chain_phases :
99
77
raise ValueError ('kill chain phase not valid. Valid options are ' + str (valid_kill_chain_phases ))
100
- return v
101
-
102
- @validator ('mitre_attack_id' )
103
- def tags_mitre_attack_id (cls , v , values ):
104
- pattern = 'T[0-9]{4}'
105
- for value in v :
106
- if not re .match (pattern , value ):
107
- raise ValueError ('Mitre Attack ID are not following the pattern Txxxx:' )
108
- return v
109
-
110
-
111
-
112
- @validator ('risk_score' )
113
- def tags_calculate_risk_score (cls , v , values ):
114
- calculated_risk_score = round (values ['impact' ] * values ['confidence' ] / 100 )
115
- if calculated_risk_score != int (v ):
116
- raise ValueError (f"Risk Score must be calculated as round(confidence * impact / 100)"
117
- f"\n Expected risk_score={ calculated_risk_score } , found risk_score={ int (v )} : { values ['name' ]} " )
118
- return v
119
-
78
+ return kill_chain_phases
120
79
121
80
@model_validator (mode = "after" )
122
81
def tags_observable (self ):
0 commit comments