mirror of https://github.com/dexidp/dex.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
158 lines
3.2 KiB
158 lines
3.2 KiB
package ent |
|
|
|
import ( |
|
"context" |
|
"crypto/sha256" |
|
"crypto/tls" |
|
"crypto/x509" |
|
"database/sql" |
|
"fmt" |
|
"log/slog" |
|
"net" |
|
"os" |
|
"strconv" |
|
"time" |
|
|
|
entSQL "entgo.io/ent/dialect/sql" |
|
"github.com/go-sql-driver/mysql" // Register mysql driver. |
|
|
|
"github.com/dexidp/dex/storage" |
|
"github.com/dexidp/dex/storage/ent/client" |
|
"github.com/dexidp/dex/storage/ent/db" |
|
) |
|
|
|
const ( |
|
// MySQL SSL modes |
|
mysqlSSLTrue = "true" |
|
mysqlSSLFalse = "false" |
|
mysqlSSLSkipVerify = "skip-verify" |
|
mysqlSSLCustom = "custom" |
|
) |
|
|
|
// MySQL options for creating an SQL db. |
|
type MySQL struct { |
|
NetworkDB |
|
|
|
SSL SSL `json:"ssl"` |
|
|
|
params map[string]string |
|
} |
|
|
|
// Open always returns a new in sqlite3 storage. |
|
func (m *MySQL) Open(logger *slog.Logger) (storage.Storage, error) { |
|
logger.Debug("experimental ent-based storage driver is enabled") |
|
drv, err := m.driver() |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
databaseClient := client.NewDatabase( |
|
client.WithClient(db.NewClient(db.Driver(drv))), |
|
client.WithHasher(sha256.New), |
|
// Set tx isolation leve for each transaction as dex does for postgres |
|
client.WithTxIsolationLevel(sql.LevelSerializable), |
|
) |
|
|
|
if err := databaseClient.Schema().Create(context.TODO()); err != nil { |
|
return nil, err |
|
} |
|
|
|
return databaseClient, nil |
|
} |
|
|
|
func (m *MySQL) driver() (*entSQL.Driver, error) { |
|
var tlsConfig string |
|
|
|
switch { |
|
case m.SSL.CAFile != "" || m.SSL.CertFile != "" || m.SSL.KeyFile != "": |
|
if err := m.makeTLSConfig(); err != nil { |
|
return nil, fmt.Errorf("failed to make TLS config: %v", err) |
|
} |
|
tlsConfig = mysqlSSLCustom |
|
case m.SSL.Mode == "": |
|
tlsConfig = mysqlSSLTrue |
|
default: |
|
tlsConfig = m.SSL.Mode |
|
} |
|
|
|
drv, err := entSQL.Open("mysql", m.dsn(tlsConfig)) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
if m.MaxIdleConns == 0 { |
|
/* Override default behavior to fix https://github.com/dexidp/dex/issues/1608 */ |
|
drv.DB().SetMaxIdleConns(0) |
|
} else { |
|
drv.DB().SetMaxIdleConns(m.MaxIdleConns) |
|
} |
|
|
|
return drv, nil |
|
} |
|
|
|
func (m *MySQL) dsn(tlsConfig string) string { |
|
cfg := mysql.Config{ |
|
User: m.User, |
|
Passwd: m.Password, |
|
DBName: m.Database, |
|
AllowNativePasswords: true, |
|
|
|
Timeout: time.Second * time.Duration(m.ConnectionTimeout), |
|
|
|
TLSConfig: tlsConfig, |
|
|
|
ParseTime: true, |
|
Params: make(map[string]string), |
|
} |
|
|
|
if m.Host != "" { |
|
if m.Host[0] != '/' { |
|
cfg.Net = "tcp" |
|
cfg.Addr = m.Host |
|
|
|
if m.Port != 0 { |
|
cfg.Addr = net.JoinHostPort(m.Host, strconv.Itoa(int(m.Port))) |
|
} |
|
} else { |
|
cfg.Net = "unix" |
|
cfg.Addr = m.Host |
|
} |
|
} |
|
|
|
for k, v := range m.params { |
|
cfg.Params[k] = v |
|
} |
|
|
|
return cfg.FormatDSN() |
|
} |
|
|
|
func (m *MySQL) makeTLSConfig() error { |
|
cfg := &tls.Config{} |
|
|
|
if m.SSL.CAFile != "" { |
|
rootCertPool := x509.NewCertPool() |
|
|
|
pem, err := os.ReadFile(m.SSL.CAFile) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { |
|
return fmt.Errorf("failed to append PEM") |
|
} |
|
cfg.RootCAs = rootCertPool |
|
} |
|
|
|
if m.SSL.CertFile != "" && m.SSL.KeyFile != "" { |
|
clientCert := make([]tls.Certificate, 0, 1) |
|
certs, err := tls.LoadX509KeyPair(m.SSL.CertFile, m.SSL.KeyFile) |
|
if err != nil { |
|
return err |
|
} |
|
clientCert = append(clientCert, certs) |
|
cfg.Certificates = clientCert |
|
} |
|
|
|
mysql.RegisterTLSConfig(mysqlSSLCustom, cfg) |
|
return nil |
|
}
|
|
|