Skip to content

Commit d30d295

Browse files
Zuulopenstack-gerrit
authored andcommitted
Merge "Improve the SG RPC callback security_group_info_for_ports" into stable/2023.1
2 parents 1a5dbb7 + 29d5570 commit d30d295

File tree

4 files changed

+45
-13
lines changed

4 files changed

+45
-13
lines changed

neutron/api/rpc/handlers/securitygroups_rpc.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,10 @@ def _select_sg_ids_for_ports(self, context, ports):
434434
for sg_id in p['security_group_ids']))
435435
return [(sg_id, ) for sg_id in sg_ids]
436436

437-
def _is_security_group_stateful(self, context, sg_id):
438-
sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id)
439-
return sg.stateful
437+
def _get_sgs_stateful_flag(self, context, sg_ids):
438+
sgs_stateful = {}
439+
for sg_id in sg_ids:
440+
sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id)
441+
sgs_stateful[sg_id] = sg.stateful
442+
443+
return sgs_stateful

neutron/db/securitygroups_rpc_base.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,10 @@ def security_group_info_for_ports(self, context, ports):
211211
# this set will be serialized into a list by rpc code
212212
remote_address_group_info[remote_ag_id][ethertype] = set()
213213
direction = rule_in_db['direction']
214-
stateful = self._is_security_group_stateful(context,
215-
security_group_id)
216214
rule_dict = {
217215
'direction': direction,
218216
'ethertype': ethertype,
219-
'stateful': stateful}
217+
}
220218

221219
for key in ('protocol', 'port_range_min', 'port_range_max',
222220
'remote_ip_prefix', 'remote_group_id',
@@ -234,6 +232,13 @@ def security_group_info_for_ports(self, context, ports):
234232
if rule_dict not in sg_info['security_groups'][security_group_id]:
235233
sg_info['security_groups'][security_group_id].append(
236234
rule_dict)
235+
236+
# Populate the security group "stateful" flag in the SGs list of rules.
237+
for sg_id, stateful in self._get_sgs_stateful_flag(
238+
context, sg_info['security_groups'].keys()).items():
239+
for rule in sg_info['security_groups'][sg_id]:
240+
rule['stateful'] = stateful
241+
237242
# Update the security groups info if they don't have any rules
238243
sg_ids = self._select_sg_ids_for_ports(context, ports)
239244
for (sg_id, ) in sg_ids:
@@ -427,13 +432,13 @@ def _select_sg_ids_for_ports(self, context, ports):
427432
"""
428433
raise NotImplementedError()
429434

430-
def _is_security_group_stateful(self, context, sg_id):
431-
"""Return whether the security group is stateful or not.
435+
def _get_sgs_stateful_flag(self, context, sg_id):
436+
"""Return the security groups stateful flag.
432437
433-
Return True if the security group associated with the given ID
434-
is stateful, else False.
438+
Returns a dictionary with the SG ID as key and the stateful flag:
439+
{sg_1: True, sg_2: False, ...}
435440
"""
436-
return True
441+
raise NotImplementedError()
437442

438443

439444
class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
@@ -530,5 +535,5 @@ def _select_ips_for_remote_address_group(self, context,
530535
return ips_by_group
531536

532537
@db_api.retry_if_session_inactive()
533-
def _is_security_group_stateful(self, context, sg_id):
534-
return sg_obj.SecurityGroup.get_sg_by_id(context, sg_id).stateful
538+
def _get_sgs_stateful_flag(self, context, sg_ids):
539+
return sg_obj.SecurityGroup.get_sgs_stateful_flag(context, sg_ids)

neutron/objects/securitygroup.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ def get_bound_project_ids(cls, context, obj_id):
133133
security_group_ids=[obj_id])
134134
return {port.project_id for port in port_objs}
135135

136+
@classmethod
137+
@db_api.CONTEXT_READER
138+
def get_sgs_stateful_flag(cls, context, sg_ids):
139+
query = context.session.query(cls.db_model.id, cls.db_model.stateful)
140+
query = query.filter(cls.db_model.id.in_(sg_ids))
141+
return dict(query.all())
142+
136143

137144
@base.NeutronObjectRegistry.register
138145
class DefaultSecurityGroup(base.NeutronDbObject):

neutron/tests/unit/objects/test_securitygroup.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,22 @@ def test_get_objects_no_synth(self):
210210
self.assertEqual(len(sg_obj.rules), 0)
211211
self.assertIsNone(listed_objs[0].rules)
212212

213+
def test_get_sgs_stateful_flag(self):
214+
for obj in self.objs:
215+
obj.create()
216+
217+
sg_ids = tuple(sg.id for sg in self.objs)
218+
sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag(
219+
self.context, sg_ids)
220+
for sg_id, stateful in sgs_stateful.items():
221+
for obj in (obj for obj in self.objs if obj.id == sg_id):
222+
self.assertEqual(obj.stateful, stateful)
223+
224+
sg_ids = sg_ids + ('random_id_not_present', )
225+
sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag(
226+
self.context, sg_ids)
227+
self.assertEqual(len(self.objs), len(sgs_stateful))
228+
213229

214230
class DefaultSecurityGroupIfaceObjTestCase(test_base.BaseObjectIfaceTestCase):
215231

0 commit comments

Comments
 (0)