Serge Bazanski | ebe6075 | 2021-09-16 11:28:00 +0200 | [diff] [blame] | 1 | package main |
| 2 | |
| 3 | import ( |
| 4 | "fmt" |
| 5 | "io" |
| 6 | "net/http" |
| 7 | "net/http/httptest" |
| 8 | "net/url" |
| 9 | "testing" |
| 10 | ) |
| 11 | |
| 12 | func TestForward(t *testing.T) { |
| 13 | // Test backend which proudly proclaims the value of the X-Forwarded-For header it received. |
| 14 | backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 15 | fmt.Fprintf(w, "hello %s %s\n", r.Host, r.Header.Get("X-Forwarded-For")) |
| 16 | })) |
| 17 | defer backendServer.Close() |
| 18 | rpURL, err := url.Parse(backendServer.URL) |
| 19 | if err != nil { |
| 20 | t.Fatalf("parsing test backend URL failed: %v", err) |
| 21 | } |
| 22 | |
| 23 | // Configure and run proxy. |
| 24 | flagUpstream = rpURL.Host |
| 25 | flagUpstreamHost = "example.com" |
| 26 | flagDownstreamHost = "matrix.example.com" |
| 27 | proxy := httptest.NewServer(newProxy()) |
| 28 | defer proxy.Close() |
| 29 | |
| 30 | // Run through a few tests. |
| 31 | for i, te := range []struct { |
| 32 | headers map[string]string |
| 33 | host string |
| 34 | want string |
| 35 | }{ |
| 36 | { |
| 37 | // 0: expected to succeed |
| 38 | headers: map[string]string{ |
| 39 | "Hscloud-Nic-Source-IP": "1.2.3.4", |
| 40 | "Hscloud-Nic-Source-Port": "1337", |
| 41 | }, |
| 42 | host: "matrix.example.com", |
| 43 | want: "hello example.com 1.2.3.4:1337, 127.0.0.1\n", |
| 44 | }, |
| 45 | { |
| 46 | // 1: expected to succeed |
| 47 | host: "matrix.example.com", |
| 48 | want: "hello example.com 127.0.0.1\n", |
| 49 | }, |
| 50 | { |
| 51 | // 2: expected to succeed |
| 52 | host: "matrix.example.com:443", |
| 53 | want: "hello example.com 127.0.0.1\n", |
| 54 | }, |
| 55 | { |
| 56 | // 3: expected to fail |
| 57 | host: "example.com", |
| 58 | want: "invalid host\n", |
| 59 | }, |
| 60 | } { |
| 61 | req, _ := http.NewRequest("GET", proxy.URL, nil) |
| 62 | req.Host = te.host |
| 63 | for k, v := range te.headers { |
| 64 | req.Header.Set(k, v) |
| 65 | } |
| 66 | |
| 67 | resp, err := http.DefaultClient.Do(req) |
| 68 | if err != nil { |
| 69 | t.Fatalf("Get failed: %v", err) |
| 70 | } |
| 71 | |
| 72 | b, err := io.ReadAll(resp.Body) |
| 73 | if err != nil { |
| 74 | t.Fatalf("Read failed: %v", err) |
| 75 | } |
| 76 | resp.Body.Close() |
| 77 | |
| 78 | if want, got := te.want, string(b); want != got { |
| 79 | t.Errorf("%d: wrong response from upstream, wanted %q, got %q", i, want, got) |
| 80 | } |
| 81 | } |
| 82 | |
| 83 | } |