implement basic ACL for debug http
diff --git a/README b/README
index 940db3f..2147adb 100644
--- a/README
+++ b/README
@@ -43,6 +43,7 @@
- `-listen_address` (default: `127.0.0.1:4200`): where to listen for gRPC requests
- `-debug_address` (default: `127.0.0.1:4201`): where to listen for debug HTTP requests
+ - `-debug_allow_all` (default: false): whether to allow all IP address (vs. localhost) to connect to debug endpoint
Since this library also includes [hspki](https://code.hackerspace.pl/q3k/hspki), you also get all the typical `-hspki_{...}` flags included.
diff --git a/mirko.go b/mirko.go
index 4a5f01f..90959ae 100644
--- a/mirko.go
+++ b/mirko.go
@@ -19,11 +19,13 @@
var (
flagListenAddress string
flagDebugAddress string
+ flagDebugAllowAll bool
)
func init() {
flag.StringVar(&flagListenAddress, "listen_address", "127.0.0.1:4200", "gRPC listen address")
flag.StringVar(&flagDebugAddress, "debug_address", "127.0.0.1:4201", "HTTP debug/status listen address")
+ flag.StringVar(&flagDebugAllowAll, "debug_allow_all", false, "HTTP debug/status available to everyone")
flag.Set("logtostderr", "true")
}
@@ -39,8 +41,28 @@
return &Mirko{}
}
+func authRequest(req *http.Request) (any, sensitive bool) {
+ host, _, err := net.SplitHostPort(req.RemoteAddr)
+ if err != nil {
+ host = req.RemoteAddr
+ }
+
+ if flagDebugAllowAll {
+ return true, true
+ }
+
+ switch host {
+ case "localhost", "127.0.0.1", "::1":
+ return true, true
+ default:
+ return false, false
+ }
+}
+
func (m *Mirko) Listen() error {
grpc.EnableTracing = true
+ trace.AuthRequest = authRequest
+
grpcLis, err := net.Listen("tcp", flagListenAddress)
if err != nil {
return fmt.Errorf("net.Listen: %v", err)
@@ -56,7 +78,14 @@
m.httpMux = http.NewServeMux()
// Canonical URLs
- m.httpMux.HandleFunc("/debug/status", statusz.StatusHandler)
+ m.httpMux.HandleFunc("/debug/status", func(w http.ResponseWriter, r *http.Request) {
+ any, sensitive := authRequest(r)
+ if !any {
+ http.Error(w, "not allowed", http.StatusUnauthorized)
+ return
+ }
+ statusz.StatusHandler(w, r)
+ })
m.httpMux.HandleFunc("/debug/requests", trace.Traces)
// -z legacy URLs