diff --git a/messaging/Contexts.go b/messaging/Contexts.go index 6e25b70..2e369ba 100644 --- a/messaging/Contexts.go +++ b/messaging/Contexts.go @@ -6,43 +6,51 @@ import ( ) type Contexts struct { - locker sync.RWMutex - connections map[string]*neffos.NSConn - callbacksOfAdd []func(string, *neffos.NSConn) - callbacksOfRemove []func(string) - authenticator func(map[string]string) (string, error) + locker sync.RWMutex + connectionOfContexts map[string]*neffos.NSConn + connectionId2ContextID map[string]string + callbacksOfAdd []func(string, *neffos.NSConn) + callbacksOfRemove []func(string) + authenticator func(map[string]string) (string, error) } func newContexts() *Contexts { return &Contexts{ - connections: make(map[string]*neffos.NSConn, 10), - callbacksOfAdd: make([]func(string, *neffos.NSConn), 0, 10), - callbacksOfRemove: make([]func(string), 0, 10), + connectionOfContexts: make(map[string]*neffos.NSConn, 10), + connectionId2ContextID: make(map[string]string, 10), + callbacksOfAdd: make([]func(string, *neffos.NSConn), 0, 10), + callbacksOfRemove: make([]func(string), 0, 10), } } -func (contexts *Contexts) add(id string, connection *neffos.NSConn) { +func (contexts *Contexts) add(contextID string, connection *neffos.NSConn) { locker.Lock() defer locker.Unlock() - contexts.connections[id] = connection + contexts.connectionOfContexts[contextID] = connection + contexts.connectionId2ContextID[connection.Conn.ID()] = contextID for _, callback := range contexts.callbacksOfAdd { - callback(id, connection) + callback(contextID, connection) } } -func (contexts *Contexts) remove(id string) { +func (contexts *Contexts) remove(connectionId string) { locker.Lock() defer locker.Unlock() - delete(contexts.connections, id) + contextID, has := contexts.connectionId2ContextID[connectionId] + if !has { + return + } + delete(contexts.connectionId2ContextID, connectionId) + delete(contexts.connectionOfContexts, contextID) for _, callback := range contexts.callbacksOfRemove { - callback(id) + callback(contextID) } } -func (contexts *Contexts) Get(id string) *neffos.NSConn { +func (contexts *Contexts) Get(contextID string) *neffos.NSConn { locker.RLock() defer locker.RUnlock() - connection, ok := contexts.connections[id] + connection, ok := contexts.connectionOfContexts[contextID] if ok { return connection } else { diff --git a/messaging/messaging.go b/messaging/messaging.go index c8bb80e..6de479d 100644 --- a/messaging/messaging.go +++ b/messaging/messaging.go @@ -55,16 +55,18 @@ func Register(namespaceName string, eventMapping neffos.Events) *Contexts { } delete(headers, "token") fmt.Println(token) - id := connection.Conn.ID() // Todo 需要根据 token 解析出 user id + contextID := connection.Conn.ID() // Todo 需要根据 token 解析出 user id if contexts.authenticator != nil { - var err error - id, err = contexts.authenticator(headers) + newContextID, err := contexts.authenticator(headers) if err != nil { connection.Emit("onLogin", wrapError(err.Error())) return nil } + if newContextID != "" { + contextID = newContextID + } } - contexts.add(id, connection) + contexts.add(contextID, connection) connection.Emit("onLogin", wrapError("")) return nil }