Files
open-im-server/internal/msggateway/n_ws_server.go
T

319 lines
8.9 KiB
Go
Raw Normal View History

2023-03-08 18:39:18 +08:00
package msggateway
2023-02-14 21:08:36 +08:00
import (
2023-06-14 10:15:58 +08:00
"context"
2023-02-14 21:08:36 +08:00
"errors"
2023-04-23 19:50:42 +08:00
"net/http"
2023-06-14 11:00:44 +08:00
"strconv"
2023-04-23 19:50:42 +08:00
"sync"
"sync/atomic"
"time"
2023-06-15 18:25:13 +08:00
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
2023-06-11 12:37:36 +08:00
"github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry"
2023-06-15 18:25:13 +08:00
"github.com/redis/go-redis/v9"
2023-06-11 12:37:36 +08:00
2023-03-23 15:14:50 +08:00
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
2023-03-16 10:46:06 +08:00
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify"
"github.com/OpenIMSDK/Open-IM-Server/pkg/errs"
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
2023-02-16 16:32:31 +08:00
"github.com/go-playground/validator/v10"
2023-02-14 21:08:36 +08:00
)
2023-03-08 18:39:18 +08:00
type LongConnServer interface {
Run() error
wsHandler(w http.ResponseWriter, r *http.Request)
GetUserAllCons(userID string) ([]*Client, bool)
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
Validate(s interface{}) error
2023-06-14 09:58:10 +08:00
SetCacheHandler(cache cache.MsgModel)
2023-05-25 20:57:19 +08:00
SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry)
2023-03-08 18:39:18 +08:00
UnRegister(c *Client)
Compressor
Encoder
MessageHandler
}
2023-02-15 19:57:16 +08:00
// bufferPool hands out byte slices for reuse. A freshly created entry is a
// zeroed slice whose length and capacity are both 1024.
var bufferPool = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, 1024)
		return buf
	},
}
2023-02-16 16:32:31 +08:00
2023-02-14 21:08:36 +08:00
type WsServer struct {
2023-03-24 16:39:33 +08:00
port int
wsMaxConnNum int64
registerChan chan *Client
unregisterChan chan *Client
2023-06-14 09:58:10 +08:00
kickHandlerChan chan *kickHandler
2023-03-24 16:39:33 +08:00
clients *UserMap
clientPool sync.Pool
onlineUserNum int64
onlineUserConnNum int64
handshakeTimeout time.Duration
hubServer *Server
validate *validator.Validate
2023-06-14 09:58:10 +08:00
cache cache.MsgModel
2023-03-08 18:39:18 +08:00
Compressor
Encoder
MessageHandler
}
2023-06-14 09:58:10 +08:00
// kickHandler carries what multiTerminalLoginChecker needs to resolve a
// login by a user who already has live connections.
type kickHandler struct {
	// clientOK reports whether a connection on the new client's platform
	// already existed (third result of UserMap.Get in registerClient).
	clientOK bool
	// oldClients are the pre-existing connections returned by UserMap.Get.
	oldClients []*Client
	// newClient is the connection that has just been registered.
	newClient *Client
}
2023-03-08 18:39:18 +08:00
2023-05-25 20:57:19 +08:00
func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) {
ws.MessageHandler = NewGrpcHandler(ws.validate, client)
2023-03-23 12:05:25 +08:00
}
2023-06-14 09:58:10 +08:00
// SetCacheHandler sets the cache the server uses to read and update token
// maps during connection checks and multi-login kicks.
func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) {
	ws.cache = cache
}
2023-03-23 12:05:25 +08:00
2023-03-08 18:39:18 +08:00
// UnRegister queues c for removal by the Run event loop. The send blocks
// while unregisterChan's buffer is full.
func (ws *WsServer) UnRegister(c *Client) {
	queue := ws.unregisterChan
	queue <- c
}
func (ws *WsServer) Validate(s interface{}) error {
2023-03-15 14:38:11 +08:00
return nil
2023-03-08 18:39:18 +08:00
}
// GetUserAllCons returns every connection held for userID, plus the
// presence flag reported by the client map.
func (ws *WsServer) GetUserAllCons(userID string) ([]*Client, bool) {
	conns, found := ws.clients.GetAll(userID)
	return conns, found
}
func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) {
return ws.clients.Get(userID, platform)
2023-02-14 21:08:36 +08:00
}
2023-03-08 18:39:18 +08:00
func NewWsServer(opts ...Option) (*WsServer, error) {
2023-02-14 21:08:36 +08:00
var config configs
for _, o := range opts {
o(&config)
}
if config.port < 1024 {
return nil, errors.New("port not allow to listen")
}
2023-03-08 18:39:18 +08:00
v := validator.New()
2023-02-14 21:08:36 +08:00
return &WsServer{
2023-02-16 16:32:31 +08:00
port: config.port,
wsMaxConnNum: config.maxConnNum,
handshakeTimeout: config.handshakeTimeout,
2023-02-14 21:08:36 +08:00
clientPool: sync.Pool{
New: func() interface{} {
return new(Client)
},
},
2023-06-14 09:58:10 +08:00
registerChan: make(chan *Client, 1000),
unregisterChan: make(chan *Client, 1000),
kickHandlerChan: make(chan *kickHandler, 1000),
validate: v,
clients: newUserMap(),
Compressor: NewGzipCompressor(),
Encoder: NewGobEncoder(),
2023-02-14 21:08:36 +08:00
}, nil
}
func (ws *WsServer) Run() error {
2023-02-15 19:57:16 +08:00
var client *Client
go func() {
for {
select {
case client = <-ws.registerChan:
ws.registerClient(client)
2023-02-16 16:32:31 +08:00
case client = <-ws.unregisterChan:
ws.unregisterClient(client)
2023-06-14 09:58:10 +08:00
case onlineInfo := <-ws.kickHandlerChan:
ws.multiTerminalLoginChecker(onlineInfo)
2023-02-15 19:57:16 +08:00
}
}
}()
2023-05-04 15:06:23 +08:00
http.HandleFunc("/", ws.wsHandler)
// http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {})
2023-02-16 16:32:31 +08:00
return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening
2023-02-15 19:57:16 +08:00
}
func (ws *WsServer) registerClient(client *Client) {
var (
2023-06-14 09:58:10 +08:00
userOK bool
clientOK bool
oldClients []*Client
2023-02-15 19:57:16 +08:00
)
2023-06-14 09:58:10 +08:00
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
2023-02-22 21:06:55 +08:00
if !userOK {
2023-06-14 15:07:53 +08:00
ws.clients.Set(client.UserID, client)
2023-06-07 11:00:07 +08:00
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
2023-02-16 16:32:31 +08:00
atomic.AddInt64(&ws.onlineUserNum, 1)
atomic.AddInt64(&ws.onlineUserConnNum, 1)
2023-03-23 15:14:50 +08:00
2023-02-22 21:06:55 +08:00
} else {
2023-06-14 09:58:10 +08:00
i := &kickHandler{
clientOK: clientOK,
oldClients: oldClients,
newClient: client,
}
ws.kickHandlerChan <- i
2023-06-07 11:00:07 +08:00
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
2023-06-14 15:07:53 +08:00
if clientOK {
ws.clients.Set(client.UserID, client)
//已经有同平台的连接存在
2023-06-14 09:58:10 +08:00
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
2023-06-02 17:05:42 +08:00
atomic.AddInt64(&ws.onlineUserConnNum, 1)
2023-02-22 21:06:55 +08:00
} else {
2023-06-14 15:07:53 +08:00
ws.clients.Set(client.UserID, client)
2023-02-16 16:32:31 +08:00
atomic.AddInt64(&ws.onlineUserConnNum, 1)
}
2023-02-15 19:57:16 +08:00
}
2023-03-23 15:14:50 +08:00
log.ZInfo(client.ctx, "user online", "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum)
2023-02-16 16:32:31 +08:00
}
2023-06-04 14:42:36 +08:00
func getRemoteAdders(client []*Client) string {
var ret string
for i, c := range client {
if i == 0 {
ret = c.ctx.GetRemoteAddr()
} else {
2023-06-07 15:11:35 +08:00
ret += "@" + c.ctx.GetRemoteAddr()
2023-06-04 14:42:36 +08:00
}
}
return ret
}
2023-02-16 16:32:31 +08:00
2023-06-14 09:58:10 +08:00
func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
switch config.Config.MultiLoginPolicy {
case constant.DefalutNotKick:
case constant.PCAndOther:
if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC {
return
}
fallthrough
case constant.AllLoginButSameTermKick:
if info.clientOK {
2023-06-14 11:53:12 +08:00
ws.clients.deleteClients(info.newClient.UserID, info.oldClients)
2023-06-14 09:58:10 +08:00
for _, c := range info.oldClients {
err := c.KickOnlineMessage()
if err != nil {
2023-06-14 11:53:12 +08:00
log.ZWarn(c.ctx, "KickOnlineMessage", err)
2023-06-14 09:58:10 +08:00
}
}
2023-06-14 11:53:12 +08:00
m, err := ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID)
if err != nil && err != redis.Nil {
log.ZWarn(info.newClient.ctx, "get token from redis err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
return
}
if m == nil {
log.ZWarn(info.newClient.ctx, "m is nil", errors.New("m is nil"), "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
return
}
log.ZDebug(info.newClient.ctx, "get token from redis", "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID, "tokenMap", m)
for k, _ := range m {
if k != info.newClient.ctx.GetToken() {
m[k] = constant.KickedToken
}
}
log.ZDebug(info.newClient.ctx, "set token map is ", "token map", m, "userID", info.newClient.UserID)
err = ws.cache.SetTokenMapByUidPid(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID, m)
if err != nil {
log.ZWarn(info.newClient.ctx, "SetTokenMapByUidPid err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
return
}
2023-06-14 09:58:10 +08:00
}
}
2023-02-16 16:32:31 +08:00
}
func (ws *WsServer) unregisterClient(client *Client) {
2023-03-08 18:39:18 +08:00
defer ws.clientPool.Put(client)
2023-06-07 11:00:07 +08:00
isDeleteUser := ws.clients.delete(client.UserID, client.ctx.GetRemoteAddr())
2023-02-22 21:06:55 +08:00
if isDeleteUser {
2023-02-16 16:32:31 +08:00
atomic.AddInt64(&ws.onlineUserNum, -1)
2023-02-15 19:57:16 +08:00
}
2023-02-16 16:32:31 +08:00
atomic.AddInt64(&ws.onlineUserConnNum, -1)
2023-03-24 16:39:33 +08:00
log.ZInfo(client.ctx, "user offline", "close reason", client.closedErr, "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum)
2023-02-15 19:57:16 +08:00
}
2023-02-16 16:32:31 +08:00
2023-02-14 21:08:36 +08:00
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
2023-06-14 10:15:58 +08:00
connContext := newContext(w, r)
2023-02-16 16:32:31 +08:00
if ws.onlineUserConnNum >= ws.wsMaxConnNum {
2023-06-14 10:15:58 +08:00
httpError(connContext, errs.ErrConnOverMaxNumLimit)
2023-02-16 16:32:31 +08:00
return
}
var (
2023-06-14 11:00:44 +08:00
token string
userID string
platformIDStr string
exists bool
compression bool
2023-02-16 16:32:31 +08:00
)
2023-06-14 10:15:58 +08:00
token, exists = connContext.Query(Token)
2023-02-16 16:32:31 +08:00
if !exists {
2023-06-14 10:15:58 +08:00
httpError(connContext, errs.ErrConnArgsErr)
2023-02-16 16:32:31 +08:00
return
}
2023-06-14 10:15:58 +08:00
userID, exists = connContext.Query(WsUserID)
2023-02-16 16:32:31 +08:00
if !exists {
2023-06-14 10:15:58 +08:00
httpError(connContext, errs.ErrConnArgsErr)
2023-02-16 16:32:31 +08:00
return
}
2023-06-14 11:00:44 +08:00
platformIDStr, exists = connContext.Query(PlatformID)
if !exists {
2023-06-14 10:15:58 +08:00
httpError(connContext, errs.ErrConnArgsErr)
2023-02-16 16:32:31 +08:00
return
}
2023-06-14 11:00:44 +08:00
platformID, err := strconv.Atoi(platformIDStr)
2023-02-16 16:32:31 +08:00
if err != nil {
2023-06-14 11:00:44 +08:00
httpError(connContext, errs.ErrConnArgsErr)
return
}
if err := tokenverify.WsVerifyToken(token, userID, platformID); err != nil {
2023-06-14 10:15:58 +08:00
httpError(connContext, err)
return
}
m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID)
if err != nil {
httpError(connContext, err)
return
}
if v, ok := m[token]; ok {
switch v {
case constant.NormalToken:
case constant.KickedToken:
httpError(connContext, errs.ErrTokenKicked.Wrap())
return
default:
httpError(connContext, errs.ErrTokenUnknown.Wrap())
return
}
} else {
httpError(connContext, errs.ErrTokenNotExist.Wrap())
2023-02-16 16:32:31 +08:00
return
}
2023-03-24 16:39:33 +08:00
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
2023-02-16 16:32:31 +08:00
err = wsLongConn.GenerateLongConn(w, r)
if err != nil {
2023-06-14 10:15:58 +08:00
httpError(connContext, err)
2023-02-16 16:32:31 +08:00
return
}
2023-06-14 10:15:58 +08:00
compressProtoc, exists := connContext.Query(Compression)
2023-02-16 16:32:31 +08:00
if exists {
2023-03-08 18:39:18 +08:00
if compressProtoc == GzipCompressionProtocol {
2023-02-16 16:32:31 +08:00
compression = true
}
}
2023-06-14 10:15:58 +08:00
compressProtoc, exists = connContext.GetHeader(Compression)
2023-02-16 16:32:31 +08:00
if exists {
2023-03-08 18:39:18 +08:00
if compressProtoc == GzipCompressionProtocol {
2023-02-16 16:32:31 +08:00
compression = true
2023-02-14 21:08:36 +08:00
}
}
2023-02-22 21:06:55 +08:00
client := ws.clientPool.Get().(*Client)
2023-06-14 10:15:58 +08:00
client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws)
2023-02-16 16:32:31 +08:00
ws.registerChan <- client
go client.readMessage()
2023-02-14 21:08:36 +08:00
}