Skip to content

Commit 4771f7b

Browse files
committed
pkg/infrastructure/azure: add and remove bootstrap ssh security rule
https://issues.redhat.com/browse/CORS-3302
1 parent 41a112d commit 4771f7b

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

pkg/infrastructure/azure/azure.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ type Provider struct {
4545
StorageAccount *armstorage.Account
4646
StorageClientFactory *armstorage.ClientFactory
4747
StorageAccountKeys []armstorage.AccountKey
48+
NetworkClientFactory *armnetwork.ClientFactory
4849
Tags map[string]*string
4950
lbBackendAddressPool *armnetwork.BackendAddressPool
5051
}
@@ -53,6 +54,7 @@ var _ clusterapi.PreProvider = (*Provider)(nil)
5354
var _ clusterapi.InfraReadyProvider = (*Provider)(nil)
5455
var _ clusterapi.PostProvider = (*Provider)(nil)
5556
var _ clusterapi.IgnitionProvider = (*Provider)(nil)
57+
var _ clusterapi.PostDestroyer = (*Provider)(nil)
5658

5759
// Name returns the name of the provider.
5860
func (p *Provider) Name() string {
@@ -486,6 +488,7 @@ func (p *Provider) InfraReady(ctx context.Context, in clusterapi.InfraReadyInput
486488
p.StorageAccount = storageAccount
487489
p.StorageClientFactory = storageClientFactory
488490
p.StorageAccountKeys = storageAccountKeys
491+
p.NetworkClientFactory = networkClientFactory
489492
p.lbBackendAddressPool = lbBap
490493

491494
if err := createDNSEntries(ctx, in, extLBFQDN, resourceGroupName); err != nil {
@@ -544,6 +547,33 @@ func (p *Provider) PostProvision(ctx context.Context, in clusterapi.PostProvisio
544547
if err = associateVMToBackendPool(ctx, *vmInput); err != nil {
545548
return fmt.Errorf("failed to associate control plane VMs with external load balancer: %w", err)
546549
}
550+
551+
if err = addSecurityGroupRule(ctx, &securityGroupInput{
552+
resourceGroupName: p.ResourceGroupName,
553+
securityGroupName: fmt.Sprintf("%s-nsg", in.InfraID),
554+
securityRuleName: "ssh_in",
555+
securityRulePort: "22",
556+
securityGroupsClient: p.NetworkClientFactory.NewSecurityGroupsClient(),
557+
}); err != nil {
558+
return fmt.Errorf("failed to add security rule: %w", err)
559+
}
560+
}
561+
562+
return nil
563+
}
564+
565+
// PostDestroy removes SSH access from the network security rules after the
566+
// bootstrap machine is destroyed.
567+
func (p *Provider) PostDestroy(ctx context.Context, in clusterapi.PostDestroyerInput) error {
568+
err := deleteSecurityGroupRule(ctx, &securityGroupInput{
569+
resourceGroupName: p.ResourceGroupName,
570+
securityGroupName: fmt.Sprintf("%s-nsg", in.Metadata.InfraID),
571+
securityRuleName: "ssh_in",
572+
securityRulePort: "22",
573+
securityGroupsClient: p.NetworkClientFactory.NewSecurityGroupsClient(),
574+
})
575+
if err != nil {
576+
return fmt.Errorf("failed to delete security rule: %w", err)
547577
}
548578

549579
return nil

pkg/infrastructure/azure/network.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
99
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4"
1010
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2"
11+
"k8s.io/utils/ptr"
1112
)
1213

1314
type lbInput struct {
@@ -37,6 +38,14 @@ type vmInput struct {
3738
nicClient *armnetwork.InterfacesClient
3839
}
3940

41+
type securityGroupInput struct {
42+
resourceGroupName string
43+
securityGroupName string
44+
securityRuleName string
45+
securityRulePort string
46+
securityGroupsClient *armnetwork.SecurityGroupsClient
47+
}
48+
4049
func createPublicIP(ctx context.Context, in *pipInput) (*armnetwork.PublicIPAddress, error) {
4150
pollerResp, err := in.pipClient.BeginCreateOrUpdate(
4251
ctx,
@@ -266,3 +275,65 @@ func associateVMToBackendPool(ctx context.Context, in vmInput) error {
266275
}
267276
return nil
268277
}
278+
279+
func addSecurityGroupRule(ctx context.Context, in *securityGroupInput) error {
280+
securityGroupResp, err := in.securityGroupsClient.Get(ctx, in.resourceGroupName, in.securityGroupName, nil)
281+
if err != nil {
282+
return fmt.Errorf("failed to get security group: %w", err)
283+
}
284+
securityGroup := securityGroupResp.SecurityGroup
285+
286+
priority := int32(100)
287+
for _, securityRule := range securityGroup.Properties.SecurityRules {
288+
if *securityRule.Properties.Priority >= priority {
289+
priority = *securityRule.Properties.Priority + 1
290+
}
291+
}
292+
// Assume inbound tcp connections from any port to destination port for now
293+
securityGroup.Properties.SecurityRules = append(securityGroup.Properties.SecurityRules,
294+
&armnetwork.SecurityRule{
295+
Name: ptr.To(in.securityRuleName),
296+
Properties: &armnetwork.SecurityRulePropertiesFormat{
297+
Access: ptr.To(armnetwork.SecurityRuleAccessAllow),
298+
Direction: ptr.To(armnetwork.SecurityRuleDirectionInbound),
299+
Protocol: ptr.To(armnetwork.SecurityRuleProtocolTCP),
300+
DestinationAddressPrefix: ptr.To("*"),
301+
DestinationPortRange: ptr.To(in.securityRulePort),
302+
Priority: ptr.To[int32](priority),
303+
SourceAddressPrefix: ptr.To("*"),
304+
SourcePortRange: ptr.To("*"),
305+
},
306+
},
307+
)
308+
309+
_, err = in.securityGroupsClient.BeginCreateOrUpdate(ctx, in.resourceGroupName, in.securityGroupName, securityGroup, nil)
310+
if err != nil {
311+
return fmt.Errorf("failed to add security rule: %w", err)
312+
}
313+
314+
return nil
315+
}
316+
317+
func deleteSecurityGroupRule(ctx context.Context, in *securityGroupInput) error {
318+
securityGroupResp, err := in.securityGroupsClient.Get(ctx, in.resourceGroupName, in.securityGroupName, nil)
319+
if err != nil {
320+
return fmt.Errorf("failed to get security group: %w", err)
321+
}
322+
securityGroup := securityGroupResp.SecurityGroup
323+
324+
var securityRules []*armnetwork.SecurityRule
325+
for _, securityRule := range securityGroup.Properties.SecurityRules {
326+
if *securityRule.Name == in.securityRuleName {
327+
continue
328+
}
329+
securityRules = append(securityRules, securityRule)
330+
}
331+
securityGroup.Properties.SecurityRules = securityRules
332+
333+
_, err = in.securityGroupsClient.BeginCreateOrUpdate(ctx, in.resourceGroupName, in.securityGroupName, securityGroup, nil)
334+
if err != nil {
335+
return fmt.Errorf("failed to update security group: %w", err)
336+
}
337+
338+
return nil
339+
}

0 commit comments

Comments
 (0)