diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index cb675869..efa33860 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -10,10 +10,12 @@ import ( "net" "net/http" "net/http/pprof" + "net/url" "os" "os/signal" "path/filepath" "runtime" + "strconv" "strings" "sync/atomic" "syscall" @@ -89,6 +91,32 @@ func commandServe() *cobra.Command { return cmd } +// try detect the intended socket type from address string +func getSocketType(address string) string { + if h, p, serr := net.SplitHostPort(address); serr == nil { + // if port string is a number, assume tcp + if _, cerr := strconv.Atoi(p); cerr == nil { + return "tcp" + } + // otherwise results in unix socket path + return h + } + + if u, perr := url.Parse(address); perr == nil { + if len(u.Scheme) > 0 { + // if scheme is recognized use that + return u.Scheme + } else { + // when parser gets a file path Scheme is + // empty. so default to unix socket. + return "unix" + } + } + + // assume unix file path + return "unix" +} + func runServe(options serveOptions) error { configFile := options.config configData, err := os.ReadFile(configFile) @@ -454,7 +482,7 @@ func runServe(options serveOptions) error { logger.Info("listening on", "server", name, "address", c.Telemetry.HTTP) - l, err := net.Listen("tcp", c.Telemetry.HTTP) + l, err := net.Listen(getSocketType(c.Telemetry.HTTP), c.Telemetry.HTTP) if err != nil { return fmt.Errorf("listening (%s) on %s: %v", name, c.Telemetry.HTTP, err) } @@ -487,7 +515,7 @@ func runServe(options serveOptions) error { logger.Info("listening on", "server", name, "address", c.Web.HTTP) - l, err := net.Listen("tcp", c.Web.HTTP) + l, err := net.Listen(getSocketType(c.Web.HTTP), c.Web.HTTP) if err != nil { return fmt.Errorf("listening (%s) on %s: %v", name, c.Web.HTTP, err) } diff --git a/cmd/dex/serve_test.go b/cmd/dex/serve_test.go index 9e214480..aa9890d8 100644 --- a/cmd/dex/serve_test.go +++ b/cmd/dex/serve_test.go @@ -27,3 +27,26 @@ func TestNewLogger(t *testing.T) { require.Equal(t, (*slog.Logger)(nil), logger) }) } + +func TestGetSocketType(t *testing.T) { + urls := [][]string{ + {"tcp://127.0.0.1:8080", "tcp"}, + {"unix:///tmp/my.sock", "unix"}, + {"127.0.0.1:9000", "tcp"}, + {"/var/run/app.sock", "unix"}, + {"./socket.sock", "unix"}, + {"relative.sock", "unix"}, + {"example.com:80", "tcp"}, + {"unix://./run/rel.sock", "unix"}, + {"[::1]:443", "tcp"}, + {"[::FFFF:129.144.52.38]:80", "tcp"}, + {"a/b/c", "unix"}, + {"/d/e/f", "unix"}, + {"localhost:80", "tcp"}, + } + for _, url := range urls { + t.Run(url[0], func(t *testing.T) { + require.Equal(t, url[1], getSocketType(url[0])) + }) + } +} diff --git a/config.yaml.dist b/config.yaml.dist index 3f888e08..4e5dcc58 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -50,6 +50,8 @@ storage: # HTTP service configuration web: http: 127.0.0.1:5556 + # Uncomment to use a UNIX socket endpoint. + # http: /run/dex/dex.sock # Uncomment to enable HTTPS endpoint. # https: 127.0.0.1:5554 @@ -69,6 +71,9 @@ web: # telemetry: # http: 127.0.0.1:5558 +# Uncomment to use UNIX socket for telemetry endpoint. +# http: /run/dex/dex-telemetry.sock + # logger: # level: "debug" # format: "text" # can also be "json" diff --git a/connector/ldap/ldap.go b/connector/ldap/ldap.go index 4cb7180e..b9587a79 100644 --- a/connector/ldap/ldap.go +++ b/connector/ldap/ldap.go @@ -255,7 +255,9 @@ func (c *Config) openConnector(logger *slog.Logger) (*ldapConnector, error) { host string err error ) - if host, _, err = net.SplitHostPort(c.Host); err != nil { + _, resErr := net.ResolveTCPAddr("tcp", c.Host) + + if host, _, err = net.SplitHostPort(c.Host); err != nil && resErr == nil { host = c.Host if c.InsecureNoSSL { c.Host += ":389" @@ -327,13 +329,27 @@ func (c *ldapConnector) do(_ context.Context, f func(c *ldap.Conn) error) error err error ) + var dialAddr string + + if c.InsecureNoSSL { + _, resErr := net.ResolveTCPAddr("tcp", c.Host) + if resErr == nil { + u := url.URL{Scheme: "ldap", Host: c.Host} + dialAddr = u.String() + } else { + // assume UNIX socket + dialAddr = fmt.Sprintf("ldapi:%s", c.Host) + } + } else { + u := url.URL{Scheme: "ldaps", Host: c.Host} + dialAddr = u.String() + } + switch { case c.InsecureNoSSL: - u := url.URL{Scheme: "ldap", Host: c.Host} - conn, err = ldap.DialURL(u.String()) + conn, err = ldap.DialURL(dialAddr) case c.StartTLS: - u := url.URL{Scheme: "ldap", Host: c.Host} - conn, err = ldap.DialURL(u.String()) + conn, err = ldap.DialURL(dialAddr) if err != nil { return fmt.Errorf("failed to connect: %v", err) } @@ -341,8 +357,7 @@ func (c *ldapConnector) do(_ context.Context, f func(c *ldap.Conn) error) error return fmt.Errorf("start TLS failed: %v", err) } default: - u := url.URL{Scheme: "ldaps", Host: c.Host} - conn, err = ldap.DialURL(u.String(), ldap.DialWithTLSConfig(c.tlsConfig)) + conn, err = ldap.DialURL(dialAddr, ldap.DialWithTLSConfig(c.tlsConfig)) } if err != nil { return fmt.Errorf("failed to connect: %v", err)