@ -11,24 +11,34 @@ import (
"log"
"net"
"sync"
"sync/atomic"
"time"
"gopkg.in/asn1-ber.v1"
)
const (
MessageQuit = 0
MessageRequest = 1
// MessageQuit causes the processMessages loop to exit
MessageQuit = 0
// MessageRequest sends a request to the server
MessageRequest = 1
// MessageResponse receives a response from the server
MessageResponse = 2
MessageFinish = 3
MessageTimeout = 4
// MessageFinish indicates the client considers a particular message ID to be finished
MessageFinish = 3
// MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
MessageTimeout = 4
)
// PacketResponse contains the packet or error encountered reading a response
type PacketResponse struct {
// Packet is the packet read from the server
Packet * ber . Packet
Error error
// Error is an error encountered while reading
Error error
}
// ReadPacket returns the packet or an error
func ( pr * PacketResponse ) ReadPacket ( ) ( * ber . Packet , error ) {
if ( pr == nil ) || ( pr . Packet == nil && pr . Error == nil ) {
return nil , NewError ( ErrorNetwork , errors . New ( "ldap: could not retrieve response" ) )
@ -36,11 +46,31 @@ func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
return pr . Packet , pr . Error
}
type messageContext struct {
id int64
// close(done) should only be called from finishMessage()
done chan struct { }
// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
responses chan * PacketResponse
}
// sendResponse should only be called within the processMessages() loop which
// is also responsible for closing the responses channel.
func ( msgCtx * messageContext ) sendResponse ( packet * PacketResponse ) {
select {
case msgCtx . responses <- packet :
// Successfully sent packet to message handler.
case <- msgCtx . done :
// The request handler is done and will not receive more
// packets.
}
}
type messagePacket struct {
Op int
MessageID int64
Packet * ber . Packet
Channel chan * PacketResponse
Context * messageContext
}
type sendMessageFlags uint
@ -53,19 +83,18 @@ const (
type Conn struct {
conn net . Conn
isTLS bool
isClosing bool
closing uint32
closeErr atomicValue
isStartingTLS bool
Debug debugging
chanConfirm chan bool
chanResults map [ int64 ] chan * PacketResponse
chanConfirm chan struct { }
messageContexts map [ int64 ] * messageContext
chanMessage chan * messagePacket
chanMessageID chan int64
wgSender sync . WaitGroup
wgClose sync . WaitGroup
once sync . Once
outstandingRequests uint
messageMutex sync . Mutex
requestTimeout time . Duration
requestTimeout int64
}
var _ Client = & Conn { }
@ -111,28 +140,39 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
// NewConn returns a new Conn using conn for network I/O.
func NewConn ( conn net . Conn , isTLS bool ) * Conn {
return & Conn {
conn : conn ,
chanConfirm : make ( chan bool ) ,
chanMessageID : make ( chan int64 ) ,
chanMessage : make ( chan * messagePacket , 10 ) ,
chanResults : map [ int64 ] chan * PacketResponse { } ,
requestTimeout : 0 ,
isTLS : isTLS ,
conn : conn ,
chanConfirm : make ( chan struct { } ) ,
chanMessageID : make ( chan int64 ) ,
chanMessage : make ( chan * messagePacket , 10 ) ,
messageContexts : map [ int64 ] * messageContext { } ,
requestTimeout : 0 ,
isTLS : isTLS ,
}
}
// Start initializes goroutines to read responses and process messages
func ( l * Conn ) Start ( ) {
go l . reader ( )
go l . processMessages ( )
l . wgClose . Add ( 1 )
}
// isClosing returns whether or not we're currently closing.
func ( l * Conn ) isClosing ( ) bool {
return atomic . LoadUint32 ( & l . closing ) == 1
}
// setClosing sets the closing value to true
func ( l * Conn ) setClosing ( ) bool {
return atomic . CompareAndSwapUint32 ( & l . closing , 0 , 1 )
}
// Close closes the connection.
func ( l * Conn ) Close ( ) {
l . once . Do ( func ( ) {
l . isClosing = true
l . wgSender . Wait ( )
l . messageMutex . Lock ( )
defer l . messageMutex . Unlock ( )
if l . setClosing ( ) {
l . Debug . Printf ( "Sending quit message and waiting for confirmation" )
l . chanMessage <- & messagePacket { Op : MessageQuit }
<- l . chanConfirm
@ -140,62 +180,56 @@ func (l *Conn) Close() {
l . Debug . Printf ( "Closing network connection" )
if err := l . conn . Close ( ) ; err != nil {
log . Print ( err )
log . Println ( err )
}
l . wgClose . Done ( )
} )
}
l . wgClose . Wait ( )
}
// Sets the time after a request is sent that a MessageTimeout triggers
// SetTimeout set s the time after a request is sent that a MessageTimeout triggers
func ( l * Conn ) SetTimeout ( timeout time . Duration ) {
if timeout > 0 {
l . requestTimeout = timeout
atomic . StoreInt64 ( & l . requestTimeout , int64 ( timeout ) )
}
}
// Returns the next available messageID
func ( l * Conn ) nextMessageID ( ) int64 {
if l . chanMessageID != nil {
if messageID , ok := <- l . chanMessageID ; ok {
return messageID
}
if messageID , ok := <- l . chanMessageID ; ok {
return messageID
}
return 0
}
// StartTLS sends the command to start a TLS session and then creates a new TLS Client
func ( l * Conn ) StartTLS ( config * tls . Config ) error {
messageID := l . nextMessageID ( )
if l . isTLS {
return NewError ( ErrorNetwork , errors . New ( "ldap: already encrypted" ) )
}
packet := ber . Encode ( ber . ClassUniversal , ber . TypeConstructed , ber . TagSequence , nil , "LDAP Request" )
packet . AppendChild ( ber . NewInteger ( ber . ClassUniversal , ber . TypePrimitive , ber . TagInteger , messageID , "MessageID" ) )
packet . AppendChild ( ber . NewInteger ( ber . ClassUniversal , ber . TypePrimitive , ber . TagInteger , l . nextMessageID ( ) , "MessageID" ) )
request := ber . Encode ( ber . ClassApplication , ber . TypeConstructed , ApplicationExtendedRequest , nil , "Start TLS" )
request . AppendChild ( ber . NewString ( ber . ClassContext , ber . TypePrimitive , 0 , "1.3.6.1.4.1.1466.20037" , "TLS Extended Command" ) )
packet . AppendChild ( request )
l . Debug . PrintPacket ( packet )
channel , err := l . sendMessageWithFlags ( packet , startTLS )
msgCtx , err := l . sendMessageWithFlags ( packet , startTLS )
if err != nil {
return err
}
if channel == nil {
return NewError ( ErrorNetwork , errors . New ( "ldap: could not send message" ) )
}
defer l . finishMessage ( msgCtx )
l . Debug . Printf ( "%d: waiting for response" , msgCtx . id )
l . Debug . Printf ( "%d: waiting for response" , messageID )
defer l . finishMessage ( messageID )
packetResponse , ok := <- channel
packetResponse , ok := <- msgCtx . responses
if ! ok {
return NewError ( ErrorNetwork , errors . New ( "ldap: channel closed" ) )
return NewError ( ErrorNetwork , errors . New ( "ldap: response channel closed" ) )
}
packet , err = packetResponse . ReadPacket ( )
l . Debug . Printf ( "%d: got response %p" , messageID , packet )
l . Debug . Printf ( "%d: got response %p" , msgCtx . id , packet )
if err != nil {
return err
}
@ -226,45 +260,51 @@ func (l *Conn) StartTLS(config *tls.Config) error {
return nil
}
func ( l * Conn ) sendMessage ( packet * ber . Packet ) ( chan * PacketResponse , error ) {
func ( l * Conn ) sendMessage ( packet * ber . Packet ) ( * messageContext , error ) {
return l . sendMessageWithFlags ( packet , 0 )
}
func ( l * Conn ) sendMessageWithFlags ( packet * ber . Packet , flags sendMessageFlags ) ( chan * PacketResponse , error ) {
if l . isClosing {
func ( l * Conn ) sendMessageWithFlags ( packet * ber . Packet , flags sendMessageFlags ) ( * messageContext , error ) {
if l . isClosing ( ) {
return nil , NewError ( ErrorNetwork , errors . New ( "ldap: connection closed" ) )
}
l . messageMutex . Lock ( )
l . Debug . Printf ( "flags&startTLS = %d" , flags & startTLS )
if l . isStartingTLS {
l . messageMutex . Unlock ( )
return nil , NewError ( ErrorNetwork , errors . New ( "ldap: connection is in startls phase. " ) )
return nil , NewError ( ErrorNetwork , errors . New ( "ldap: connection is in startls phase" ) )
}
if flags & startTLS != 0 {
if l . outstandingRequests != 0 {
l . messageMutex . Unlock ( )
return nil , NewError ( ErrorNetwork , errors . New ( "ldap: cannot StartTLS with outstanding requests" ) )
} else {
l . isStartingTLS = true
}
l . isStartingTLS = true
}
l . outstandingRequests ++
l . messageMutex . Unlock ( )
out := make ( chan * PacketResponse )
responses := make ( chan * PacketResponse )
messageID := packet . Children [ 0 ] . Value . ( int64 )
message := & messagePacket {
Op : MessageRequest ,
MessageID : packet . Children [ 0 ] . Value . ( int64 ) ,
MessageID : messageID ,
Packet : packet ,
Channel : out ,
Context : & messageContext {
id : messageID ,
done : make ( chan struct { } ) ,
responses : responses ,
} ,
}
l . sendProcessMessage ( message )
return out , nil
return message . Contex t, nil
}
func ( l * Conn ) finishMessage ( messageID int64 ) {
if l . isClosing {
func ( l * Conn ) finishMessage ( msgCtx * messageContext ) {
close ( msgCtx . done )
if l . isClosing ( ) {
return
}
@ -277,18 +317,18 @@ func (l *Conn) finishMessage(messageID int64) {
message := & messagePacket {
Op : MessageFinish ,
MessageID : messageID ,
MessageID : msgCtx . id ,
}
l . sendProcessMessage ( message )
}
func ( l * Conn ) sendProcessMessage ( message * messagePacket ) bool {
if l . isClosing {
l . messageMutex . Lock ( )
defer l . messageMutex . Unlock ( )
if l . isClosing ( ) {
return false
}
l . wgSender . Add ( 1 )
l . chanMessage <- message
l . wgSender . Done ( )
return true
}
@ -297,13 +337,17 @@ func (l *Conn) processMessages() {
if err := recover ( ) ; err != nil {
log . Printf ( "ldap: recovered panic in processMessages: %v" , err )
}
for messageID , channel := range l . chanResults {
for messageID , msgCtx := range l . messageContexts {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l . isClosing ( ) && l . closeErr . Load ( ) != nil {
msgCtx . sendResponse ( & PacketResponse { Error : l . closeErr . Load ( ) . ( error ) } )
}
l . Debug . Printf ( "Closing channel for MessageID %d" , messageID )
close ( channel )
delete ( l . chanResults , messageID )
close ( msgCtx . responses )
delete ( l . messageContex ts, messageID )
}
close ( l . chanMessageID )
l . chanConfirm <- true
close ( l . chanConfirm )
} ( )
@ -312,11 +356,7 @@ func (l *Conn) processMessages() {
select {
case l . chanMessageID <- messageID :
messageID ++
case message , ok := <- l . chanMessage :
if ! ok {
l . Debug . Printf ( "Shutting down - message channel is closed" )
return
}
case message := <- l . chanMessage :
switch message . Op {
case MessageQuit :
l . Debug . Printf ( "Shutting down - quit message received" )
@ -324,24 +364,30 @@ func (l *Conn) processMessages() {
case MessageRequest :
// Add to message list and write to network
l . Debug . Printf ( "Sending message %d" , message . MessageID )
l . chanResults [ message . MessageID ] = message . Channel
buf := message . Packet . Bytes ( )
_ , err := l . conn . Write ( buf )
if err != nil {
l . Debug . Printf ( "Error Sending Message: %s" , err . Error ( ) )
message . Context . sendResponse ( & PacketResponse { Error : fmt . Errorf ( "unable to send request: %s" , err ) } )
close ( message . Context . responses )
break
}
// Only add to messageContexts if we were able to
// successfully write the message.
l . messageContexts [ message . MessageID ] = message . Context
// Add timeout if defined
if l . requestTimeout > 0 {
requestTimeout := time . Duration ( atomic . LoadInt64 ( & l . requestTimeout ) )
if requestTimeout > 0 {
go func ( ) {
defer func ( ) {
if err := recover ( ) ; err != nil {
log . Printf ( "ldap: recovered panic in RequestTimeout: %v" , err )
}
} ( )
time . Sleep ( l . requestTimeout )
time . Sleep ( requestTimeout )
timeoutMessage := & messagePacket {
Op : MessageTimeout ,
MessageID : message . MessageID ,
@ -351,26 +397,26 @@ func (l *Conn) processMessages() {
}
case MessageResponse :
l . Debug . Printf ( "Receiving message %d" , message . MessageID )
if chanResult , ok := l . chanResul ts[ message . MessageID ] ; ok {
chanResult <- & PacketResponse { message . Packet , nil }
if msgCtx , ok := l . messageContex ts[ message . MessageID ] ; ok {
msgCtx . sendResponse ( & PacketResponse { message . Packet , nil } )
} else {
log . Printf ( "Received unexpected message %d, %v" , message . MessageID , l . isClosing )
log . Printf ( "Received unexpected message %d, %v" , message . MessageID , l . isClosing ( ) )
ber . PrintPacket ( message . Packet )
}
case MessageTimeout :
// Handle the timeout by closing the channel
// All reads will return immediately
if chanResult , ok := l . chanResults [ message . MessageID ] ; ok {
chanResult <- & PacketResponse { message . Packet , errors . New ( "ldap: connection timed out" ) }
if msgCtx , ok := l . messageContexts [ message . MessageID ] ; ok {
l . Debug . Printf ( "Receiving message timeout for %d" , message . MessageID )
delete ( l . chanResults , message . MessageID )
close ( chanResult )
msgCtx . sendResponse ( & PacketResponse { message . Packet , errors . New ( "ldap: connection timed out" ) } )
delete ( l . messageContexts , message . MessageID )
close ( msgCtx . responses )
}
case MessageFinish :
l . Debug . Printf ( "Finished message %d" , message . MessageID )
if chanResult , ok := l . chanResul ts[ message . MessageID ] ; ok {
close ( chanResult )
delete ( l . chanResults , message . MessageID )
if msgCtx , ok := l . messageContex ts[ message . MessageID ] ; ok {
delete ( l . messageContexts , message . MessageID )
close ( msgCtx . responses )
}
}
}
@ -396,7 +442,8 @@ func (l *Conn) reader() {
packet , err := ber . ReadPacket ( l . conn )
if err != nil {
// A read error is expected here if we are closing the connection...
if ! l . isClosing {
if ! l . isClosing ( ) {
l . closeErr . Store ( fmt . Errorf ( "unable to read LDAP response packet: %s" , err ) )
l . Debug . Printf ( "reader error: %s" , err . Error ( ) )
}
return
@ -419,6 +466,5 @@ func (l *Conn) reader() {
if ! l . sendProcessMessage ( message ) {
return
}
}
}