Skip to content

Commit f76c21c

Browse files
kishorjpjryan93owenthereal
committed
Allow UDP for AWS NLB
Co-authored-by: Patrick Ryan <[email protected]> Co-authored-by: Owen Ou <[email protected]>
1 parent f705d62 commit f76c21c

File tree

3 files changed

+95
-22
lines changed

3 files changed

+95
-22
lines changed

staging/src/k8s.io/legacy-cloud-providers/aws/aws.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3660,9 +3660,10 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS
36603660

36613661
sslPorts := getPortSets(annotations[ServiceAnnotationLoadBalancerSSLPorts])
36623662
for _, port := range apiService.Spec.Ports {
3663-
if port.Protocol != v1.ProtocolTCP {
3664-
return nil, fmt.Errorf("Only TCP LoadBalancer is supported for AWS ELB")
3663+
if err := checkProtocol(port, annotations); err != nil {
3664+
return nil, err
36653665
}
3666+
36663667
if port.NodePort == 0 {
36673668
klog.Errorf("Ignoring port without NodePort defined: %v", port)
36683669
continue
@@ -3682,7 +3683,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS
36823683
}
36833684

36843685
certificateARN := annotations[ServiceAnnotationLoadBalancerCertificate]
3685-
if certificateARN != "" && (sslPorts == nil || sslPorts.numbers.Has(int64(port.Port)) || sslPorts.names.Has(port.Name)) {
3686+
if port.Protocol != v1.ProtocolUDP && certificateARN != "" && (sslPorts == nil || sslPorts.numbers.Has(int64(port.Port)) || sslPorts.names.Has(port.Name)) {
36863687
portMapping.FrontendProtocol = elbv2.ProtocolEnumTls
36873688
portMapping.SSLCertificateARN = certificateARN
36883689
portMapping.SSLPolicy = annotations[ServiceAnnotationLoadBalancerSSLNegotiationPolicy]
@@ -3693,12 +3694,13 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS
36933694
}
36943695

36953696
v2Mappings = append(v2Mappings, portMapping)
3697+
} else {
3698+
listener, err := buildListener(port, annotations, sslPorts)
3699+
if err != nil {
3700+
return nil, err
3701+
}
3702+
listeners = append(listeners, listener)
36963703
}
3697-
listener, err := buildListener(port, annotations, sslPorts)
3698-
if err != nil {
3699-
return nil, err
3700-
}
3701-
listeners = append(listeners, listener)
37023704
}
37033705

37043706
if apiService.Spec.LoadBalancerIP != "" {
@@ -4739,6 +4741,18 @@ func (c *Cloud) nodeNameToProviderID(nodeName types.NodeName) (InstanceID, error
47394741
return KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID()
47404742
}
47414743

4744+
func checkProtocol(port v1.ServicePort, annotations map[string]string) error {
4745+
// nlb supports tcp, udp
4746+
if isNLB(annotations) && (port.Protocol == v1.ProtocolTCP || port.Protocol == v1.ProtocolUDP) {
4747+
return nil
4748+
}
4749+
// elb only supports tcp
4750+
if !isNLB(annotations) && port.Protocol == v1.ProtocolTCP {
4751+
return nil
4752+
}
4753+
return fmt.Errorf("Protocol %s not supported by LoadBalancer", port.Protocol)
4754+
}
4755+
47424756
func setNodeDisk(
47434757
nodeDiskMap map[types.NodeName]map[KubernetesVolumeID]bool,
47444758
volumeID KubernetesVolumeID,

staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,12 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa
185185
}
186186

187187
// actual maps FrontendPort to an elbv2.Listener
188-
actual := map[int64]*elbv2.Listener{}
188+
actual := map[int64]map[string]*elbv2.Listener{}
189189
for _, listener := range listenerDescriptions.Listeners {
190-
actual[*listener.Port] = listener
190+
if actual[*listener.Port] == nil {
191+
actual[*listener.Port] = map[string]*elbv2.Listener{}
192+
}
193+
actual[*listener.Port][*listener.Protocol] = listener
191194
}
192195

193196
actualTargetGroups, err := c.elbv2.DescribeTargetGroups(
@@ -207,10 +210,11 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa
207210
// Handle additions/modifications
208211
for _, mapping := range mappings {
209212
frontendPort := mapping.FrontendPort
213+
frontendProtocol := mapping.FrontendProtocol
210214
nodePort := mapping.TrafficPort
211215

212216
// modifications
213-
if listener, ok := actual[frontendPort]; ok {
217+
if listener, ok := actual[frontendPort][frontendProtocol]; ok {
214218
listenerNeedsModification := false
215219

216220
if aws.StringValue(listener.Protocol) != mapping.FrontendProtocol {
@@ -315,23 +319,27 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa
315319
dirty = true
316320
}
317321

318-
frontEndPorts := map[int64]bool{}
322+
frontEndPorts := map[int64]map[string]bool{}
319323
for i := range mappings {
320-
frontEndPorts[mappings[i].FrontendPort] = true
324+
if frontEndPorts[mappings[i].FrontendPort] == nil {
325+
frontEndPorts[mappings[i].FrontendPort] = map[string]bool{}
326+
}
327+
frontEndPorts[mappings[i].FrontendPort][mappings[i].FrontendProtocol] = true
321328
}
322329

323330
// handle deletions
324-
for port, listener := range actual {
325-
if _, ok := frontEndPorts[port]; !ok {
326-
err := c.deleteListenerV2(listener)
327-
if err != nil {
328-
return nil, err
331+
for port := range actual {
332+
for protocol := range actual[port] {
333+
if _, ok := frontEndPorts[port][protocol]; !ok {
334+
err := c.deleteListenerV2(actual[port][protocol])
335+
if err != nil {
336+
return nil, err
337+
}
338+
dirty = true
329339
}
330-
dirty = true
331340
}
332341
}
333342
}
334-
335343
if err := c.reconcileLBAttributes(aws.StringValue(loadBalancer.LoadBalancerArn), annotations); err != nil {
336344
return nil, err
337345
}
@@ -765,10 +773,14 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[
765773

766774
{
767775
clientPorts := sets.Int64{}
776+
clientProtocol := "tcp"
768777
healthCheckPorts := sets.Int64{}
769778
for _, port := range portMappings {
770779
clientPorts.Insert(port.TrafficPort)
771780
healthCheckPorts.Insert(port.HealthCheckPort)
781+
if port.TrafficProtocol == string(v1.ProtocolUDP) {
782+
clientProtocol = "udp"
783+
}
772784
}
773785
clientRuleAnnotation := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName)
774786
healthRuleAnnotation := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName)
@@ -782,14 +794,14 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[
782794
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, vpcCIDRs); err != nil {
783795
return err
784796
}
785-
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, "tcp", clientPorts, clientCIDRs); err != nil {
797+
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, clientProtocol, clientPorts, clientCIDRs); err != nil {
786798
return err
787799
}
788800
} else {
789801
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", nil, nil); err != nil {
790802
return err
791803
}
792-
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, "tcp", nil, nil); err != nil {
804+
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, clientProtocol, nil, nil); err != nil {
793805
return err
794806
}
795807
}

staging/src/k8s.io/legacy-cloud-providers/aws/aws_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,6 +1371,53 @@ func TestDescribeLoadBalancerOnEnsure(t *testing.T) {
13711371
c.EnsureLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{})
13721372
}
13731373

1374+
func TestCheckProtocol(t *testing.T) {
1375+
tests := []struct {
1376+
name string
1377+
annotations map[string]string
1378+
port v1.ServicePort
1379+
wantErr error
1380+
}{
1381+
{
1382+
name: "TCP with ELB",
1383+
annotations: make(map[string]string),
1384+
port: v1.ServicePort{Protocol: v1.ProtocolTCP, Port: int32(8080)},
1385+
wantErr: nil,
1386+
},
1387+
{
1388+
name: "TCP with NLB",
1389+
annotations: map[string]string{ServiceAnnotationLoadBalancerType: "nlb"},
1390+
port: v1.ServicePort{Protocol: v1.ProtocolTCP, Port: int32(8080)},
1391+
wantErr: nil,
1392+
},
1393+
{
1394+
name: "UDP with ELB",
1395+
annotations: make(map[string]string),
1396+
port: v1.ServicePort{Protocol: v1.ProtocolUDP, Port: int32(8080)},
1397+
wantErr: fmt.Errorf("Protocol UDP not supported by load balancer"),
1398+
},
1399+
{
1400+
name: "UDP with NLB",
1401+
annotations: map[string]string{ServiceAnnotationLoadBalancerType: "nlb"},
1402+
port: v1.ServicePort{Protocol: v1.ProtocolUDP, Port: int32(8080)},
1403+
wantErr: nil,
1404+
},
1405+
}
1406+
for _, test := range tests {
1407+
tt := test
1408+
t.Run(tt.name, func(t *testing.T) {
1409+
t.Parallel()
1410+
err := checkProtocol(tt.port, tt.annotations)
1411+
if tt.wantErr != nil && err == nil {
1412+
t.Errorf("Expected error: want=%s got =%s", tt.wantErr, err)
1413+
}
1414+
if tt.wantErr == nil && err != nil {
1415+
t.Errorf("Unexpected error: want=%s got =%s", tt.wantErr, err)
1416+
}
1417+
})
1418+
}
1419+
}
1420+
13741421
func TestBuildListener(t *testing.T) {
13751422
tests := []struct {
13761423
name string

0 commit comments

Comments
 (0)