mirror of https://github.com/dexidp/dex.git
18 changed files with 1162 additions and 364 deletions
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,129 @@
|
||||
package pq |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
"database/sql/driver" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
) |
||||
|
||||
// Implement the "QueryerContext" interface
|
||||
func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { |
||||
list := make([]driver.Value, len(args)) |
||||
for i, nv := range args { |
||||
list[i] = nv.Value |
||||
} |
||||
finish := cn.watchCancel(ctx) |
||||
r, err := cn.query(query, list) |
||||
if err != nil { |
||||
if finish != nil { |
||||
finish() |
||||
} |
||||
return nil, err |
||||
} |
||||
r.finish = finish |
||||
return r, nil |
||||
} |
||||
|
||||
// Implement the "ExecerContext" interface
|
||||
func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { |
||||
list := make([]driver.Value, len(args)) |
||||
for i, nv := range args { |
||||
list[i] = nv.Value |
||||
} |
||||
|
||||
if finish := cn.watchCancel(ctx); finish != nil { |
||||
defer finish() |
||||
} |
||||
|
||||
return cn.Exec(query, list) |
||||
} |
||||
|
||||
// Implement the "ConnBeginTx" interface
|
||||
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { |
||||
var mode string |
||||
|
||||
switch sql.IsolationLevel(opts.Isolation) { |
||||
case sql.LevelDefault: |
||||
// Don't touch mode: use the server's default
|
||||
case sql.LevelReadUncommitted: |
||||
mode = " ISOLATION LEVEL READ UNCOMMITTED" |
||||
case sql.LevelReadCommitted: |
||||
mode = " ISOLATION LEVEL READ COMMITTED" |
||||
case sql.LevelRepeatableRead: |
||||
mode = " ISOLATION LEVEL REPEATABLE READ" |
||||
case sql.LevelSerializable: |
||||
mode = " ISOLATION LEVEL SERIALIZABLE" |
||||
default: |
||||
return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) |
||||
} |
||||
|
||||
if opts.ReadOnly { |
||||
mode += " READ ONLY" |
||||
} else { |
||||
mode += " READ WRITE" |
||||
} |
||||
|
||||
tx, err := cn.begin(mode) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
cn.txnFinish = cn.watchCancel(ctx) |
||||
return tx, nil |
||||
} |
||||
|
||||
func (cn *conn) watchCancel(ctx context.Context) func() { |
||||
if done := ctx.Done(); done != nil { |
||||
finished := make(chan struct{}) |
||||
go func() { |
||||
select { |
||||
case <-done: |
||||
_ = cn.cancel() |
||||
finished <- struct{}{} |
||||
case <-finished: |
||||
} |
||||
}() |
||||
return func() { |
||||
select { |
||||
case <-finished: |
||||
case finished <- struct{}{}: |
||||
} |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (cn *conn) cancel() error { |
||||
c, err := dial(cn.dialer, cn.opts) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer c.Close() |
||||
|
||||
{ |
||||
can := conn{ |
||||
c: c, |
||||
} |
||||
err = can.ssl(cn.opts) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
w := can.writeBuf(0) |
||||
w.int32(80877102) // cancel request code
|
||||
w.int32(cn.processID) |
||||
w.int32(cn.secretKey) |
||||
|
||||
if err := can.sendStartupPacket(w); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
// Read until EOF to ensure that the server received the cancel.
|
||||
{ |
||||
_, err := io.Copy(ioutil.Discard, c) |
||||
return err |
||||
} |
||||
} |
||||
@ -0,0 +1,43 @@
|
||||
// +build go1.10
|
||||
|
||||
package pq |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql/driver" |
||||
) |
||||
|
||||
// Connector represents a fixed configuration for the pq driver with a given
|
||||
// name. Connector satisfies the database/sql/driver Connector interface and
|
||||
// can be used to create any number of DB Conn's via the database/sql OpenDB
|
||||
// function.
|
||||
//
|
||||
// See https://golang.org/pkg/database/sql/driver/#Connector.
|
||||
// See https://golang.org/pkg/database/sql/#OpenDB.
|
||||
type connector struct { |
||||
name string |
||||
} |
||||
|
||||
// Connect returns a connection to the database using the fixed configuration
|
||||
// of this Connector. Context is not used.
|
||||
func (c *connector) Connect(_ context.Context) (driver.Conn, error) { |
||||
return (&Driver{}).Open(c.name) |
||||
} |
||||
|
||||
// Driver returnst the underlying driver of this Connector.
|
||||
func (c *connector) Driver() driver.Driver { |
||||
return &Driver{} |
||||
} |
||||
|
||||
var _ driver.Connector = &connector{} |
||||
|
||||
// NewConnector returns a connector for the pq driver in a fixed configuration
|
||||
// with the given name. The returned connector can be used to create any number
|
||||
// of equivalent Conn's. The returned connector is intended to be used with
|
||||
// database/sql.OpenDB.
|
||||
//
|
||||
// See https://golang.org/pkg/database/sql/driver/#Connector.
|
||||
// See https://golang.org/pkg/database/sql/#OpenDB.
|
||||
func NewConnector(name string) (driver.Connector, error) { |
||||
return &connector{name: name}, nil |
||||
} |
||||
@ -0,0 +1,93 @@
|
||||
package pq |
||||
|
||||
import ( |
||||
"math" |
||||
"reflect" |
||||
"time" |
||||
|
||||
"github.com/lib/pq/oid" |
||||
) |
||||
|
||||
const headerSize = 4 |
||||
|
||||
type fieldDesc struct { |
||||
// The object ID of the data type.
|
||||
OID oid.Oid |
||||
// The data type size (see pg_type.typlen).
|
||||
// Note that negative values denote variable-width types.
|
||||
Len int |
||||
// The type modifier (see pg_attribute.atttypmod).
|
||||
// The meaning of the modifier is type-specific.
|
||||
Mod int |
||||
} |
||||
|
||||
func (fd fieldDesc) Type() reflect.Type { |
||||
switch fd.OID { |
||||
case oid.T_int8: |
||||
return reflect.TypeOf(int64(0)) |
||||
case oid.T_int4: |
||||
return reflect.TypeOf(int32(0)) |
||||
case oid.T_int2: |
||||
return reflect.TypeOf(int16(0)) |
||||
case oid.T_varchar, oid.T_text: |
||||
return reflect.TypeOf("") |
||||
case oid.T_bool: |
||||
return reflect.TypeOf(false) |
||||
case oid.T_date, oid.T_time, oid.T_timetz, oid.T_timestamp, oid.T_timestamptz: |
||||
return reflect.TypeOf(time.Time{}) |
||||
case oid.T_bytea: |
||||
return reflect.TypeOf([]byte(nil)) |
||||
default: |
||||
return reflect.TypeOf(new(interface{})).Elem() |
||||
} |
||||
} |
||||
|
||||
func (fd fieldDesc) Name() string { |
||||
return oid.TypeName[fd.OID] |
||||
} |
||||
|
||||
func (fd fieldDesc) Length() (length int64, ok bool) { |
||||
switch fd.OID { |
||||
case oid.T_text, oid.T_bytea: |
||||
return math.MaxInt64, true |
||||
case oid.T_varchar, oid.T_bpchar: |
||||
return int64(fd.Mod - headerSize), true |
||||
default: |
||||
return 0, false |
||||
} |
||||
} |
||||
|
||||
func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { |
||||
switch fd.OID { |
||||
case oid.T_numeric, oid.T__numeric: |
||||
mod := fd.Mod - headerSize |
||||
precision = int64((mod >> 16) & 0xffff) |
||||
scale = int64(mod & 0xffff) |
||||
return precision, scale, true |
||||
default: |
||||
return 0, 0, false |
||||
} |
||||
} |
||||
|
||||
// ColumnTypeScanType returns the value type that can be used to scan types into.
|
||||
func (rs *rows) ColumnTypeScanType(index int) reflect.Type { |
||||
return rs.colTyps[index].Type() |
||||
} |
||||
|
||||
// ColumnTypeDatabaseTypeName return the database system type name.
|
||||
func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { |
||||
return rs.colTyps[index].Name() |
||||
} |
||||
|
||||
// ColumnTypeLength returns the length of the column type if the column is a
|
||||
// variable length type. If the column is not a variable length type ok
|
||||
// should return false.
|
||||
func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { |
||||
return rs.colTyps[index].Length() |
||||
} |
||||
|
||||
// ColumnTypePrecisionScale should return the precision and scale for decimal
|
||||
// types. If not applicable, ok should be false.
|
||||
func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { |
||||
return rs.colTyps[index].PrecisionScale() |
||||
} |
||||
@ -0,0 +1,175 @@
|
||||
package pq |
||||
|
||||
import ( |
||||
"crypto/tls" |
||||
"crypto/x509" |
||||
"io/ioutil" |
||||
"net" |
||||
"os" |
||||
"os/user" |
||||
"path/filepath" |
||||
) |
||||
|
||||
// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
|
||||
// related settings. The function is nil when no upgrade should take place.
|
||||
func ssl(o values) (func(net.Conn) (net.Conn, error), error) { |
||||
verifyCaOnly := false |
||||
tlsConf := tls.Config{} |
||||
switch mode := o["sslmode"]; mode { |
||||
// "require" is the default.
|
||||
case "", "require": |
||||
// We must skip TLS's own verification since it requires full
|
||||
// verification since Go 1.3.
|
||||
tlsConf.InsecureSkipVerify = true |
||||
|
||||
// From http://www.postgresql.org/docs/current/static/libpq-ssl.html:
|
||||
//
|
||||
// Note: For backwards compatibility with earlier versions of
|
||||
// PostgreSQL, if a root CA file exists, the behavior of
|
||||
// sslmode=require will be the same as that of verify-ca, meaning the
|
||||
// server certificate is validated against the CA. Relying on this
|
||||
// behavior is discouraged, and applications that need certificate
|
||||
// validation should always use verify-ca or verify-full.
|
||||
if sslrootcert, ok := o["sslrootcert"]; ok { |
||||
if _, err := os.Stat(sslrootcert); err == nil { |
||||
verifyCaOnly = true |
||||
} else { |
||||
delete(o, "sslrootcert") |
||||
} |
||||
} |
||||
case "verify-ca": |
||||
// We must skip TLS's own verification since it requires full
|
||||
// verification since Go 1.3.
|
||||
tlsConf.InsecureSkipVerify = true |
||||
verifyCaOnly = true |
||||
case "verify-full": |
||||
tlsConf.ServerName = o["host"] |
||||
case "disable": |
||||
return nil, nil |
||||
default: |
||||
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) |
||||
} |
||||
|
||||
err := sslClientCertificates(&tlsConf, o) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
err = sslCertificateAuthority(&tlsConf, o) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Accept renegotiation requests initiated by the backend.
|
||||
//
|
||||
// Renegotiation was deprecated then removed from PostgreSQL 9.5, but
|
||||
// the default configuration of older versions has it enabled. Redshift
|
||||
// also initiates renegotiations and cannot be reconfigured.
|
||||
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient |
||||
|
||||
return func(conn net.Conn) (net.Conn, error) { |
||||
client := tls.Client(conn, &tlsConf) |
||||
if verifyCaOnly { |
||||
err := sslVerifyCertificateAuthority(client, &tlsConf) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
return client, nil |
||||
}, nil |
||||
} |
||||
|
||||
// sslClientCertificates adds the certificate specified in the "sslcert" and
|
||||
// "sslkey" settings, or if they aren't set, from the .postgresql directory
|
||||
// in the user's home directory. The configured files must exist and have
|
||||
// the correct permissions.
|
||||
func sslClientCertificates(tlsConf *tls.Config, o values) error { |
||||
// user.Current() might fail when cross-compiling. We have to ignore the
|
||||
// error and continue without home directory defaults, since we wouldn't
|
||||
// know from where to load them.
|
||||
user, _ := user.Current() |
||||
|
||||
// In libpq, the client certificate is only loaded if the setting is not blank.
|
||||
//
|
||||
// https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037
|
||||
sslcert := o["sslcert"] |
||||
if len(sslcert) == 0 && user != nil { |
||||
sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") |
||||
} |
||||
// https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045
|
||||
if len(sslcert) == 0 { |
||||
return nil |
||||
} |
||||
// https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054
|
||||
if _, err := os.Stat(sslcert); os.IsNotExist(err) { |
||||
return nil |
||||
} else if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// In libpq, the ssl key is only loaded if the setting is not blank.
|
||||
//
|
||||
// https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222
|
||||
sslkey := o["sslkey"] |
||||
if len(sslkey) == 0 && user != nil { |
||||
sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") |
||||
} |
||||
|
||||
if len(sslkey) > 0 { |
||||
if err := sslKeyPermissions(sslkey); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
cert, err := tls.LoadX509KeyPair(sslcert, sslkey) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
tlsConf.Certificates = []tls.Certificate{cert} |
||||
return nil |
||||
} |
||||
|
||||
// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting.
|
||||
func sslCertificateAuthority(tlsConf *tls.Config, o values) error { |
||||
// In libpq, the root certificate is only loaded if the setting is not blank.
|
||||
//
|
||||
// https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951
|
||||
if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { |
||||
tlsConf.RootCAs = x509.NewCertPool() |
||||
|
||||
cert, err := ioutil.ReadFile(sslrootcert) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { |
||||
return fmterrorf("couldn't parse pem in sslrootcert") |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// sslVerifyCertificateAuthority carries out a TLS handshake to the server and
|
||||
// verifies the presented certificate against the CA, i.e. the one specified in
|
||||
// sslrootcert or the system CA if sslrootcert was not specified.
|
||||
func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error { |
||||
err := client.Handshake() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
certs := client.ConnectionState().PeerCertificates |
||||
opts := x509.VerifyOptions{ |
||||
DNSName: client.ConnectionState().ServerName, |
||||
Intermediates: x509.NewCertPool(), |
||||
Roots: tlsConf.RootCAs, |
||||
} |
||||
for i, cert := range certs { |
||||
if i == 0 { |
||||
continue |
||||
} |
||||
opts.Intermediates.AddCert(cert) |
||||
} |
||||
_, err = certs[0].Verify(opts) |
||||
return err |
||||
} |
||||
@ -0,0 +1,20 @@
|
||||
// +build !windows
|
||||
|
||||
package pq |
||||
|
||||
import "os" |
||||
|
||||
// sslKeyPermissions checks the permissions on user-supplied ssl key files.
|
||||
// The key file should have very little access.
|
||||
//
|
||||
// libpq does not check key file permissions on Windows.
|
||||
func sslKeyPermissions(sslkey string) error { |
||||
info, err := os.Stat(sslkey) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if info.Mode().Perm()&0077 != 0 { |
||||
return ErrSSLKeyHasWorldPermissions |
||||
} |
||||
return nil |
||||
} |
||||
@ -0,0 +1,9 @@
|
||||
// +build windows
|
||||
|
||||
package pq |
||||
|
||||
// sslKeyPermissions checks the permissions on user-supplied ssl key files.
|
||||
// The key file should have very little access.
|
||||
//
|
||||
// libpq does not check key file permissions on Windows.
|
||||
func sslKeyPermissions(string) error { return nil } |
||||
@ -0,0 +1,23 @@
|
||||
package pq |
||||
|
||||
import ( |
||||
"encoding/hex" |
||||
"fmt" |
||||
) |
||||
|
||||
// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format.
|
||||
func decodeUUIDBinary(src []byte) ([]byte, error) { |
||||
if len(src) != 16 { |
||||
return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) |
||||
} |
||||
|
||||
dst := make([]byte, 36) |
||||
dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' |
||||
hex.Encode(dst[0:], src[0:4]) |
||||
hex.Encode(dst[9:], src[4:6]) |
||||
hex.Encode(dst[14:], src[6:8]) |
||||
hex.Encode(dst[19:], src[8:10]) |
||||
hex.Encode(dst[24:], src[10:16]) |
||||
|
||||
return dst, nil |
||||
} |
||||
Loading…
Reference in new issue