Browse Source

Merge d274d2cf0e into 13f012fb81

pull/4143/merge
Martin Kjær Jørgensen 4 days ago committed by GitHub
parent
commit
244fafaf19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 32
      cmd/dex/serve.go
  2. 23
      cmd/dex/serve_test.go
  3. 5
      config.yaml.dist
  4. 29
      connector/ldap/ldap.go

32
cmd/dex/serve.go

@ -10,10 +10,12 @@ import (
"net"
"net/http"
"net/http/pprof"
"net/url"
"os"
"os/signal"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync/atomic"
"syscall"
@ -89,6 +91,32 @@ func commandServe() *cobra.Command {
return cmd
}
// try detect the intended socket type from address string
func getSocketType(address string) string {
if h, p, serr := net.SplitHostPort(address); serr == nil {
// if port string is a number, assume tcp
if _, cerr := strconv.Atoi(p); cerr == nil {
return "tcp"
}
// otherwise results in unix socket path
return h
}
if u, perr := url.Parse(address); perr == nil {
if len(u.Scheme) > 0 {
// if scheme is recognized use that
return u.Scheme
} else {
// when parser gets a file path Scheme is
// empty. so default to unix socket.
return "unix"
}
}
// assume unix file path
return "unix"
}
func runServe(options serveOptions) error {
configFile := options.config
configData, err := os.ReadFile(configFile)
@ -454,7 +482,7 @@ func runServe(options serveOptions) error {
logger.Info("listening on", "server", name, "address", c.Telemetry.HTTP)
l, err := net.Listen("tcp", c.Telemetry.HTTP)
l, err := net.Listen(getSocketType(c.Telemetry.HTTP), c.Telemetry.HTTP)
if err != nil {
return fmt.Errorf("listening (%s) on %s: %v", name, c.Telemetry.HTTP, err)
}
@ -487,7 +515,7 @@ func runServe(options serveOptions) error {
logger.Info("listening on", "server", name, "address", c.Web.HTTP)
l, err := net.Listen("tcp", c.Web.HTTP)
l, err := net.Listen(getSocketType(c.Web.HTTP), c.Web.HTTP)
if err != nil {
return fmt.Errorf("listening (%s) on %s: %v", name, c.Web.HTTP, err)
}

23
cmd/dex/serve_test.go

@ -27,3 +27,26 @@ func TestNewLogger(t *testing.T) {
require.Equal(t, (*slog.Logger)(nil), logger)
})
}
func TestGetSocketType(t *testing.T) {
urls := [][]string{
{"tcp://127.0.0.1:8080", "tcp"},
{"unix:///tmp/my.sock", "unix"},
{"127.0.0.1:9000", "tcp"},
{"/var/run/app.sock", "unix"},
{"./socket.sock", "unix"},
{"relative.sock", "unix"},
{"example.com:80", "tcp"},
{"unix://./run/rel.sock", "unix"},
{"[::1]:443", "tcp"},
{"[::FFFF:129.144.52.38]:80", "tcp"},
{"a/b/c", "unix"},
{"/d/e/f", "unix"},
{"localhost:80", "tcp"},
}
for _, url := range urls {
t.Run(url[0], func(t *testing.T) {
require.Equal(t, url[1], getSocketType(url[0]))
})
}
}

5
config.yaml.dist

@ -50,6 +50,8 @@ storage:
# HTTP service configuration
web:
http: 127.0.0.1:5556
# Uncomment to use a UNIX socket endpoint.
# http: /run/dex/dex.sock
# Uncomment to enable HTTPS endpoint.
# https: 127.0.0.1:5554
@ -69,6 +71,9 @@ web:
# telemetry:
# http: 127.0.0.1:5558
# Uncomment to use UNIX socket for telemetry endpoint.
# http: /run/dex/dex-telemetry.sock
# logger:
# level: "debug"
# format: "text" # can also be "json"

29
connector/ldap/ldap.go

@ -255,7 +255,9 @@ func (c *Config) openConnector(logger *slog.Logger) (*ldapConnector, error) {
host string
err error
)
if host, _, err = net.SplitHostPort(c.Host); err != nil {
_, resErr := net.ResolveTCPAddr("tcp", c.Host)
if host, _, err = net.SplitHostPort(c.Host); err != nil && resErr == nil {
host = c.Host
if c.InsecureNoSSL {
c.Host += ":389"
@ -327,13 +329,27 @@ func (c *ldapConnector) do(_ context.Context, f func(c *ldap.Conn) error) error
err error
)
var dialAddr string
if c.InsecureNoSSL {
_, resErr := net.ResolveTCPAddr("tcp", c.Host)
if resErr == nil {
u := url.URL{Scheme: "ldap", Host: c.Host}
dialAddr = u.String()
} else {
// assume UNIX socket
dialAddr = fmt.Sprintf("ldapi:%s", c.Host)
}
} else {
u := url.URL{Scheme: "ldaps", Host: c.Host}
dialAddr = u.String()
}
switch {
case c.InsecureNoSSL:
u := url.URL{Scheme: "ldap", Host: c.Host}
conn, err = ldap.DialURL(u.String())
conn, err = ldap.DialURL(dialAddr)
case c.StartTLS:
u := url.URL{Scheme: "ldap", Host: c.Host}
conn, err = ldap.DialURL(u.String())
conn, err = ldap.DialURL(dialAddr)
if err != nil {
return fmt.Errorf("failed to connect: %v", err)
}
@ -341,8 +357,7 @@ func (c *ldapConnector) do(_ context.Context, f func(c *ldap.Conn) error) error
return fmt.Errorf("start TLS failed: %v", err)
}
default:
u := url.URL{Scheme: "ldaps", Host: c.Host}
conn, err = ldap.DialURL(u.String(), ldap.DialWithTLSConfig(c.tlsConfig))
conn, err = ldap.DialURL(dialAddr, ldap.DialWithTLSConfig(c.tlsConfig))
}
if err != nil {
return fmt.Errorf("failed to connect: %v", err)

Loading…
Cancel
Save