mirror of https://github.com/dexidp/dex.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
181 lines
4.1 KiB
181 lines
4.1 KiB
package ent |
|
|
|
import ( |
|
"io" |
|
"log/slog" |
|
"os" |
|
"strconv" |
|
"testing" |
|
|
|
"github.com/stretchr/testify/require" |
|
|
|
"github.com/dexidp/dex/storage" |
|
"github.com/dexidp/dex/storage/conformance" |
|
) |
|
|
|
const ( |
|
PostgresEntHostEnv = "DEX_POSTGRES_ENT_HOST" |
|
PostgresEntPortEnv = "DEX_POSTGRES_ENT_PORT" |
|
PostgresEntDatabaseEnv = "DEX_POSTGRES_ENT_DATABASE" |
|
PostgresEntUserEnv = "DEX_POSTGRES_ENT_USER" |
|
PostgresEntPasswordEnv = "DEX_POSTGRES_ENT_PASSWORD" |
|
) |
|
|
|
func postgresTestConfig(host string, port uint64) *Postgres { |
|
return &Postgres{ |
|
NetworkDB: NetworkDB{ |
|
Database: getenv(PostgresEntDatabaseEnv, "postgres"), |
|
User: getenv(PostgresEntUserEnv, "postgres"), |
|
Password: getenv(PostgresEntPasswordEnv, "postgres"), |
|
Host: host, |
|
Port: uint16(port), |
|
}, |
|
SSL: SSL{ |
|
Mode: pgSSLDisable, // Postgres container doesn't support SSL. |
|
}, |
|
} |
|
} |
|
|
|
func newPostgresStorage(host string, port uint64) storage.Storage { |
|
logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) |
|
|
|
cfg := postgresTestConfig(host, port) |
|
s, err := cfg.Open(logger) |
|
if err != nil { |
|
panic(err) |
|
} |
|
return s |
|
} |
|
|
|
func TestPostgres(t *testing.T) { |
|
host := os.Getenv(PostgresEntHostEnv) |
|
if host == "" { |
|
t.Skipf("test environment variable %s not set, skipping", PostgresEntHostEnv) |
|
} |
|
|
|
port := uint64(5432) |
|
if rawPort := os.Getenv(PostgresEntPortEnv); rawPort != "" { |
|
var err error |
|
|
|
port, err = strconv.ParseUint(rawPort, 10, 32) |
|
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err) |
|
} |
|
|
|
newStorage := func() storage.Storage { |
|
return newPostgresStorage(host, port) |
|
} |
|
conformance.RunTests(t, newStorage) |
|
conformance.RunTransactionTests(t, newStorage) |
|
} |
|
|
|
func TestPostgresDSN(t *testing.T) { |
|
tests := []struct { |
|
name string |
|
cfg *Postgres |
|
desiredDSN string |
|
}{ |
|
{ |
|
name: "Host port", |
|
cfg: &Postgres{ |
|
NetworkDB: NetworkDB{ |
|
Host: "localhost", |
|
Port: uint16(5432), |
|
}, |
|
}, |
|
desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'", |
|
}, |
|
{ |
|
name: "Host with port", |
|
cfg: &Postgres{ |
|
NetworkDB: NetworkDB{ |
|
Host: "localhost:5432", |
|
}, |
|
}, |
|
desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'", |
|
}, |
|
{ |
|
name: "Host ipv6 with port", |
|
cfg: &Postgres{ |
|
NetworkDB: NetworkDB{ |
|
Host: "[a:b:c:d]:5432", |
|
}, |
|
}, |
|
desiredDSN: "connect_timeout=0 host='a:b:c:d' port=5432 sslmode='verify-full'", |
|
}, |
|
{ |
|
name: "Credentials and timeout", |
|
cfg: &Postgres{ |
|
NetworkDB: NetworkDB{ |
|
Database: "test", |
|
User: "test", |
|
Password: "test", |
|
ConnectionTimeout: 5, |
|
}, |
|
}, |
|
desiredDSN: "connect_timeout=5 user='test' password='test' dbname='test' sslmode='verify-full'", |
|
}, |
|
{ |
|
name: "SSL", |
|
cfg: &Postgres{ |
|
SSL: SSL{ |
|
Mode: pgSSLRequire, |
|
CAFile: "/ca.crt", |
|
KeyFile: "/cert.crt", |
|
CertFile: "/cert.key", |
|
}, |
|
}, |
|
desiredDSN: "connect_timeout=0 sslmode='require' sslrootcert='/ca.crt' sslcert='/cert.key' sslkey='/cert.crt'", |
|
}, |
|
} |
|
|
|
for _, tt := range tests { |
|
t.Run(tt.name, func(t *testing.T) { |
|
require.Equal(t, tt.desiredDSN, tt.cfg.dsn()) |
|
}) |
|
} |
|
} |
|
|
|
func TestPostgresDriver(t *testing.T) { |
|
host := os.Getenv(PostgresEntHostEnv) |
|
if host == "" { |
|
t.Skipf("test environment variable %s not set, skipping", PostgresEntHostEnv) |
|
} |
|
|
|
port := uint64(5432) |
|
if rawPort := os.Getenv(PostgresEntPortEnv); rawPort != "" { |
|
var err error |
|
|
|
port, err = strconv.ParseUint(rawPort, 10, 32) |
|
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err) |
|
} |
|
|
|
tests := []struct { |
|
name string |
|
cfg func() *Postgres |
|
desiredConns int |
|
}{ |
|
{ |
|
name: "Defaults", |
|
cfg: func() *Postgres { return postgresTestConfig(host, port) }, |
|
desiredConns: 5, |
|
}, |
|
{ |
|
name: "Tune", |
|
cfg: func() *Postgres { |
|
cfg := postgresTestConfig(host, port) |
|
cfg.MaxOpenConns = 101 |
|
return cfg |
|
}, |
|
desiredConns: 101, |
|
}, |
|
} |
|
|
|
for _, tt := range tests { |
|
t.Run(tt.name, func(t *testing.T) { |
|
drv, err := tt.cfg().driver() |
|
require.NoError(t, err) |
|
|
|
require.Equal(t, tt.desiredConns, drv.DB().Stats().MaxOpenConnections) |
|
}) |
|
} |
|
}
|
|
|