diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index b3b09c0b..2489f0be 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -10,10 +10,12 @@ import ( "net" "net/http" "net/http/pprof" + "net/url" "os" "os/signal" "path/filepath" "runtime" + "strconv" "strings" "sync/atomic" "syscall" @@ -90,13 +92,28 @@ func commandServe() *cobra.Command { // try detect the intended socket type from address string func getSocketType(address string) string { - _, resErr := net.ResolveTCPAddr("tcp", address) - if resErr == nil { - return "tcp" - } else { - // assume UNIX socket - return "unix" + if h, p, serr := net.SplitHostPort(address); serr == nil { + // if port string is a number, assume tcp + if _, cerr := strconv.Atoi(p); cerr == nil { + return "tcp" + } + // otherwise results in unix socket path + return h + } + + if u, perr := url.Parse(address); perr == nil { + if len(u.Scheme) > 0 { + // if scheme is recognized use that + return u.Scheme + } else { + // when parser gets a file path Scheme is + // empty. so default to unix socket. + return "unix" + } } + + // assume unix file path + return "unix" } func runServe(options serveOptions) error { diff --git a/cmd/dex/serve_test.go b/cmd/dex/serve_test.go index 9e214480..aa9890d8 100644 --- a/cmd/dex/serve_test.go +++ b/cmd/dex/serve_test.go @@ -27,3 +27,26 @@ func TestNewLogger(t *testing.T) { require.Equal(t, (*slog.Logger)(nil), logger) }) } + +func TestGetSocketType(t *testing.T) { + urls := [][]string{ + {"tcp://127.0.0.1:8080", "tcp"}, + {"unix:///tmp/my.sock", "unix"}, + {"127.0.0.1:9000", "tcp"}, + {"/var/run/app.sock", "unix"}, + {"./socket.sock", "unix"}, + {"relative.sock", "unix"}, + {"example.com:80", "tcp"}, + {"unix://./run/rel.sock", "unix"}, + {"[::1]:443", "tcp"}, + {"[::FFFF:129.144.52.38]:80", "tcp"}, + {"a/b/c", "unix"}, + {"/d/e/f", "unix"}, + {"localhost:80", "tcp"}, + } + for _, url := range urls { + t.Run(url[0], func(t *testing.T) { + require.Equal(t, url[1], getSocketType(url[0])) + }) + } +}