Skip to content

Commit 23a4fcc

Browse files
committed
Add go code for a proxy server
1 parent 84338ea commit 23a4fcc

File tree

3 files changed

+348
-0
lines changed

3 files changed

+348
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name: Deploy Serverless function
2+
on: push
3+
4+
jobs:
5+
deploy:
6+
runs-on: ubuntu-latest
7+
steps:
8+
- uses: actions/checkout@v4
9+
- uses: goodsmileduck/yandex-serverless-action@v2
10+
with:
11+
token: ${{ secrets.YC_IAM_TOKEN }}
12+
function_id: ${{ vars.YC_FUNCTION_ID }}
13+
runtime: 'golang123'
14+
memory: '128'
15+
execution_timeout: 15
16+
entrypoint: 'main.handleProxy'
17+
environment: PROXY_USER=${{ vars.PROXY_USER }},PROXY_PASS=${{ secrets.PROXY_PASS }}
18+
source: '.'
19+
exclude: '.github/,terraform/'

main.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package main
2+
3+
import (
4+
"io"
5+
"log"
6+
"log/slog"
7+
"net"
8+
"net/http"
9+
"os"
10+
"time"
11+
)
12+
13+
// Main handler
14+
func handleProxy(w http.ResponseWriter, r *http.Request) {
15+
// Auth first
16+
if !checkAuth(r) {
17+
w.Header().Set("Proxy-Authenticate", `Basic realm="Restricted"`)
18+
http.Error(w, "Proxy Authentication Required", http.StatusProxyAuthRequired)
19+
return
20+
}
21+
22+
if r.Method == http.MethodConnect {
23+
handleTunneling(w, r)
24+
} else {
25+
handleHTTP(w, r)
26+
}
27+
}
28+
29+
// Very simple auth check (Basic Auth)
30+
func checkAuth(r *http.Request) bool {
31+
// Look for Proxy-Authorization header
32+
auth := r.Header.Get("Proxy-Authorization")
33+
if auth == "" {
34+
return false
35+
}
36+
37+
// Reuse Go's parsing logic by making a fake request
38+
req := &http.Request{Header: http.Header{"Authorization": []string{auth}}}
39+
40+
username, password, ok := req.BasicAuth()
41+
if !ok {
42+
return false
43+
}
44+
45+
// Get expected credentials from the environment
46+
expectedUser := os.Getenv("PROXY_USER")
47+
expectedPass := os.Getenv("PROXY_PASS")
48+
49+
return username == expectedUser && password == expectedPass
50+
}
51+
52+
// Handle HTTPS tunneling (CONNECT)
53+
func handleTunneling(w http.ResponseWriter, r *http.Request) {
54+
destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
55+
if err != nil {
56+
http.Error(w, err.Error(), http.StatusServiceUnavailable)
57+
58+
return
59+
}
60+
61+
defer callAndLogError(destConn.Close)
62+
63+
// Write 200 Connection Established to the client
64+
w.WriteHeader(http.StatusOK)
65+
66+
// Hijack the connection to get the raw TCP stream
67+
hijacker, ok := w.(http.Hijacker)
68+
if !ok {
69+
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
70+
71+
return
72+
}
73+
74+
clientConn, _, err := hijacker.Hijack()
75+
if err != nil {
76+
http.Error(w, err.Error(), http.StatusServiceUnavailable)
77+
78+
return
79+
}
80+
81+
defer callAndLogError(clientConn.Close)
82+
83+
// Bidirectional copy between client and destination
84+
go func() {
85+
_, err = io.Copy(destConn, clientConn)
86+
if err != nil {
87+
slog.Error(err.Error())
88+
}
89+
}()
90+
91+
_, err = io.Copy(clientConn, destConn)
92+
if err != nil {
93+
slog.Error(err.Error())
94+
}
95+
}
96+
97+
// HTTP proxy handler
98+
func handleHTTP(w http.ResponseWriter, r *http.Request) {
99+
// Create a new request based on the incoming one
100+
outReq, err := http.NewRequest(r.Method, r.RequestURI, r.Body)
101+
if err != nil {
102+
http.Error(w, "Error creating request", http.StatusInternalServerError)
103+
104+
return
105+
}
106+
107+
outReq.Header = r.Header.Clone()
108+
109+
// Use http.DefaultTransport to perform the request
110+
resp, err := http.DefaultTransport.RoundTrip(outReq)
111+
if err != nil {
112+
http.Error(w, "Error forwarding request: "+err.Error(), http.StatusBadGateway)
113+
114+
return
115+
}
116+
117+
defer callAndLogError(resp.Body.Close)
118+
119+
// Copy response headers
120+
for key, values := range resp.Header {
121+
for _, value := range values {
122+
w.Header().Add(key, value)
123+
}
124+
}
125+
126+
// Write status code
127+
w.WriteHeader(resp.StatusCode)
128+
129+
// Copy body
130+
_, err = io.Copy(w, resp.Body)
131+
if err != nil {
132+
slog.Error(err.Error())
133+
}
134+
}
135+
136+
func callAndLogError(f func() error) {
137+
if err := f(); err != nil {
138+
slog.Error(err.Error())
139+
}
140+
}
141+
142+
func main() {
143+
server := &http.Server{
144+
Addr: ":8080", // listen on port 8080
145+
Handler: http.HandlerFunc(handleProxy),
146+
}
147+
148+
// Check if port is available
149+
ln, err := net.Listen("tcp", server.Addr)
150+
if err != nil {
151+
log.Fatalf("Could not listen on %s: %v", server.Addr, err)
152+
}
153+
154+
log.Printf("Proxy server listening on %s", server.Addr)
155+
156+
if err = server.Serve(ln); err != nil {
157+
log.Fatalf("Server failed: %v", err)
158+
}
159+
}

main_test.go

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"encoding/base64"
6+
"errors"
7+
"io"
8+
"net"
9+
"net/http"
10+
"net/http/httptest"
11+
"testing"
12+
)
13+
14+
func TestHandleRequestAndRedirect(t *testing.T) {
15+
tests := []struct {
16+
name string
17+
requestMethod string
18+
requestBody string
19+
requestHeaders map[string]string
20+
mockResponse *http.Response
21+
mockError error
22+
expectedCode int
23+
expectedBody string
24+
dropCreds bool
25+
}{
26+
{
27+
name: "successful http request",
28+
requestMethod: http.MethodGet,
29+
mockResponse: &http.Response{
30+
StatusCode: http.StatusOK,
31+
Body: io.NopCloser(bytes.NewBufferString("success")),
32+
},
33+
expectedCode: http.StatusOK,
34+
expectedBody: "success",
35+
},
36+
{
37+
name: "bad gateway on forward error",
38+
requestMethod: http.MethodGet,
39+
mockError: errors.New("dial tcp connection failed"),
40+
expectedCode: http.StatusBadGateway,
41+
expectedBody: "Error forwarding request: dial tcp connection failed\n",
42+
},
43+
{
44+
name: "request with wrong credentials",
45+
requestMethod: http.MethodGet,
46+
expectedCode: http.StatusProxyAuthRequired,
47+
expectedBody: "Proxy Authentication Required\n",
48+
dropCreds: true,
49+
},
50+
}
51+
52+
for _, tt := range tests {
53+
t.Run(tt.name, func(t *testing.T) {
54+
originalTransport := http.DefaultTransport
55+
defer func() { http.DefaultTransport = originalTransport }()
56+
57+
http.DefaultTransport = &mockTransport{
58+
response: tt.mockResponse,
59+
err: tt.mockError,
60+
}
61+
62+
req := httptest.NewRequest(tt.requestMethod, "/test", bytes.NewBufferString(tt.requestBody))
63+
for k, v := range tt.requestHeaders {
64+
req.Header.Add(k, v)
65+
}
66+
67+
if !tt.dropCreds {
68+
req.Header.Add(
69+
"Proxy-Authorization",
70+
"Basic "+base64.StdEncoding.EncodeToString([]byte("test-user:valid-pass")),
71+
)
72+
}
73+
74+
rr := httptest.NewRecorder()
75+
76+
t.Setenv("PROXY_USER", "test-user")
77+
t.Setenv("PROXY_PASS", "valid-pass")
78+
79+
handleProxy(rr, req)
80+
81+
res := rr.Result()
82+
83+
t.Cleanup(callAndLogCleanup(t, res.Body.Close))
84+
85+
if res.StatusCode != tt.expectedCode {
86+
t.Errorf("expected status %d, got %d", tt.expectedCode, res.StatusCode)
87+
}
88+
89+
body, _ := io.ReadAll(res.Body)
90+
if string(body) != tt.expectedBody {
91+
t.Errorf("expected body %q, got %q", tt.expectedBody, string(body))
92+
}
93+
94+
if tt.mockResponse != nil {
95+
for key, values := range tt.mockResponse.Header {
96+
for _, value := range values {
97+
if rr.Header().Get(key) != value {
98+
t.Errorf("expected header %q with value %q, got %q", key, value, rr.Header().Get(key))
99+
}
100+
}
101+
}
102+
}
103+
})
104+
}
105+
}
106+
107+
type mockTransport struct {
108+
response *http.Response
109+
err error
110+
}
111+
112+
func (m *mockTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
113+
return m.response, m.err
114+
}
115+
116+
func callAndLogCleanup(t *testing.T, f func() error) func() {
117+
t.Helper()
118+
119+
return func() {
120+
if err := f(); err != nil {
121+
t.Log(err)
122+
}
123+
}
124+
}
125+
126+
// Start a simple TCP echo server
127+
func startEchoServer(t *testing.T) (addr string, closeFn func()) {
128+
ln, err := net.Listen("tcp", "127.0.0.1:0") // random free port
129+
if err != nil {
130+
t.Fatal(err)
131+
}
132+
133+
go func() {
134+
for {
135+
conn, err := ln.Accept()
136+
if err != nil {
137+
return
138+
}
139+
140+
go func(c net.Conn) {
141+
defer callAndLogError(c.Close)
142+
143+
_, err = io.Copy(c, c) // echo back
144+
if err != nil {
145+
t.Error(err)
146+
}
147+
}(conn)
148+
}
149+
}()
150+
151+
return ln.Addr().String(), func() { callAndLogError(ln.Close) }
152+
}
153+
154+
func TestHandleTunneling(t *testing.T) {
155+
// Start upstream echo server
156+
targetAddr, closeEcho := startEchoServer(t)
157+
t.Cleanup(closeEcho)
158+
159+
// Fake CONNECT request to proxy
160+
req := httptest.NewRequest(http.MethodConnect, targetAddr, nil)
161+
w := httptest.NewRecorder()
162+
163+
// Run tunneling handler
164+
handleTunneling(w, req)
165+
166+
// The proxy should reply with 200 (connection established)
167+
if w.Code != http.StatusOK {
168+
t.Fatalf("expected 200, got %d", w.Code)
169+
}
170+
}

0 commit comments

Comments
 (0)