Merge branch 'errcode' of github.com:OpenIMSDK/Open-IM-Server into errcode

This commit is contained in:
wangchuxiao
2023-06-14 12:20:52 +08:00
12 changed files with 216 additions and 88 deletions
+5 -2
View File
@@ -224,8 +224,11 @@ func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error
return c.writeBinaryMsg(resp)
}
func (c *Client) KickOnlineMessage(ctx context.Context) error {
return nil
func (c *Client) KickOnlineMessage() error {
resp := Resp{
ReqIdentifier: WSKickOnlineMsg,
}
return c.writeBinaryMsg(resp)
}
func (c *Client) writeBinaryMsg(resp Resp) error {
+3
View File
@@ -91,6 +91,9 @@ func (c *UserConnContext) GetPlatformID() string {
func (c *UserConnContext) GetOperationID() string {
return c.Req.URL.Query().Get(OperationID)
}
func (c *UserConnContext) GetToken() string {
return c.Req.URL.Query().Get(Token)
}
func (c *UserConnContext) GetBackground() bool {
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
if err != nil {
+8 -1
View File
@@ -2,6 +2,7 @@ package msggateway
import (
"context"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
@@ -17,7 +18,13 @@ import (
)
func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error {
rdb, err := cache.NewRedis()
if err != nil {
return err
}
msgModel := cache.NewMsgCacheModel(rdb)
s.LongConnServer.SetDiscoveryRegistry(client)
s.LongConnServer.SetCacheHandler(msgModel)
msggateway.RegisterMsgGatewayServer(server, s)
return nil
}
@@ -131,7 +138,7 @@ func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOf
for _, v := range req.KickUserIDList {
if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok {
for _, client := range clients {
err := client.KickOnlineMessage(ctx)
err := client.KickOnlineMessage()
if err != nil {
return nil, err
}
+123 -38
View File
@@ -1,13 +1,19 @@
package msggateway
import (
"context"
"errors"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
"net/http"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry"
redis "github.com/go-redis/redis/v8"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify"
@@ -22,7 +28,7 @@ type LongConnServer interface {
GetUserAllCons(userID string) ([]*Client, bool)
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
Validate(s interface{}) error
//SetMessageHandler(msgRpcClient *rpcclient.MsgClient)
SetCacheHandler(cache cache.MsgModel)
SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry)
UnRegister(c *Client)
Compressor
@@ -41,6 +47,7 @@ type WsServer struct {
wsMaxConnNum int64
registerChan chan *Client
unregisterChan chan *Client
kickHandlerChan chan *kickHandler
clients *UserMap
clientPool sync.Pool
onlineUserNum int64
@@ -48,14 +55,23 @@ type WsServer struct {
handshakeTimeout time.Duration
hubServer *Server
validate *validator.Validate
cache cache.MsgModel
Compressor
Encoder
MessageHandler
}
type kickHandler struct {
clientOK bool
oldClients []*Client
newClient *Client
}
func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) {
ws.MessageHandler = NewGrpcHandler(ws.validate, client)
}
func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) {
ws.cache = cache
}
func (ws *WsServer) UnRegister(c *Client) {
ws.unregisterChan <- c
@@ -92,12 +108,13 @@ func NewWsServer(opts ...Option) (*WsServer, error) {
return new(Client)
},
},
registerChan: make(chan *Client, 1000),
unregisterChan: make(chan *Client, 1000),
validate: v,
clients: newUserMap(),
Compressor: NewGzipCompressor(),
Encoder: NewGobEncoder(),
registerChan: make(chan *Client, 1000),
unregisterChan: make(chan *Client, 1000),
kickHandlerChan: make(chan *kickHandler, 1000),
validate: v,
clients: newUserMap(),
Compressor: NewGzipCompressor(),
Encoder: NewGobEncoder(),
}, nil
}
func (ws *WsServer) Run() error {
@@ -109,6 +126,8 @@ func (ws *WsServer) Run() error {
ws.registerClient(client)
case client = <-ws.unregisterChan:
ws.unregisterClient(client)
case onlineInfo := <-ws.kickHandlerChan:
ws.multiTerminalLoginChecker(onlineInfo)
}
}
}()
@@ -119,26 +138,29 @@ func (ws *WsServer) Run() error {
func (ws *WsServer) registerClient(client *Client) {
var (
userOK bool
clientOK bool
cli []*Client
userOK bool
clientOK bool
oldClients []*Client
)
cli, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
ws.clients.Set(client.UserID, client)
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
if !userOK {
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
ws.clients.Set(client.UserID, client)
atomic.AddInt64(&ws.onlineUserNum, 1)
atomic.AddInt64(&ws.onlineUserConnNum, 1)
} else {
i := &kickHandler{
clientOK: clientOK,
oldClients: oldClients,
newClient: client,
}
ws.kickHandlerChan <- i
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
if clientOK { //已经有同平台的连接存在
ws.clients.Set(client.UserID, client)
ws.multiTerminalLoginChecker(cli)
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(cli))
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
atomic.AddInt64(&ws.onlineUserConnNum, 1)
} else {
ws.clients.Set(client.UserID, client)
atomic.AddInt64(&ws.onlineUserConnNum, 1)
}
}
@@ -156,7 +178,47 @@ func getRemoteAdders(client []*Client) string {
return ret
}
func (ws *WsServer) multiTerminalLoginChecker(client []*Client) {
func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
switch config.Config.MultiLoginPolicy {
case constant.DefalutNotKick:
case constant.PCAndOther:
if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC {
return
}
fallthrough
case constant.AllLoginButSameTermKick:
if info.clientOK {
ws.clients.deleteClients(info.newClient.UserID, info.oldClients)
for _, c := range info.oldClients {
err := c.KickOnlineMessage()
if err != nil {
log.ZWarn(c.ctx, "KickOnlineMessage", err)
}
}
m, err := ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID)
if err != nil && err != redis.Nil {
log.ZWarn(info.newClient.ctx, "get token from redis err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
return
}
if m == nil {
log.ZWarn(info.newClient.ctx, "m is nil", errors.New("m is nil"), "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
return
}
log.ZDebug(info.newClient.ctx, "get token from redis", "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID, "tokenMap", m)
for k, _ := range m {
if k != info.newClient.ctx.GetToken() {
m[k] = constant.KickedToken
}
}
log.ZDebug(info.newClient.ctx, "set token map is ", "token map", m, "userID", info.newClient.UserID)
err = ws.cache.SetTokenMapByUidPid(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID, m)
if err != nil {
log.ZWarn(info.newClient.ctx, "SetTokenMapByUidPid err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
return
}
}
}
}
func (ws *WsServer) unregisterClient(client *Client) {
@@ -170,60 +232,83 @@ func (ws *WsServer) unregisterClient(client *Client) {
}
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
context := newContext(w, r)
defer log.ZInfo(context.Background(), "wsHandler", "remote addr", "url", r.URL.String())
connContext := newContext(w, r)
if ws.onlineUserConnNum >= ws.wsMaxConnNum {
httpError(context, errs.ErrConnOverMaxNumLimit)
httpError(connContext, errs.ErrConnOverMaxNumLimit)
return
}
var (
token string
userID string
platformID string
exists bool
compression bool
token string
userID string
platformIDStr string
exists bool
compression bool
)
token, exists = context.Query(Token)
token, exists = connContext.Query(Token)
if !exists {
httpError(context, errs.ErrConnArgsErr)
httpError(connContext, errs.ErrConnArgsErr)
return
}
userID, exists = context.Query(WsUserID)
userID, exists = connContext.Query(WsUserID)
if !exists {
httpError(context, errs.ErrConnArgsErr)
httpError(connContext, errs.ErrConnArgsErr)
return
}
platformID, exists = context.Query(PlatformID)
if !exists || utils.StringToInt(platformID) == 0 {
httpError(context, errs.ErrConnArgsErr)
platformIDStr, exists = connContext.Query(PlatformID)
if !exists {
httpError(connContext, errs.ErrConnArgsErr)
return
}
// log.ZDebug(context2.Background(), "conn", "platformID", platformID)
err := tokenverify.WsVerifyToken(token, userID, platformID)
platformID, err := strconv.Atoi(platformIDStr)
if err != nil {
httpError(context, err)
httpError(connContext, errs.ErrConnArgsErr)
return
}
if err := tokenverify.WsVerifyToken(token, userID, platformID); err != nil {
httpError(connContext, err)
return
}
m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID)
if err != nil {
httpError(connContext, err)
return
}
if v, ok := m[token]; ok {
switch v {
case constant.NormalToken:
case constant.KickedToken:
httpError(connContext, errs.ErrTokenKicked.Wrap())
return
default:
httpError(connContext, errs.ErrTokenUnknown.Wrap())
return
}
} else {
httpError(connContext, errs.ErrTokenNotExist.Wrap())
return
}
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
err = wsLongConn.GenerateLongConn(w, r)
if err != nil {
httpError(context, err)
httpError(connContext, err)
return
}
compressProtoc, exists := context.Query(Compression)
compressProtoc, exists := connContext.Query(Compression)
if exists {
if compressProtoc == GzipCompressionProtocol {
compression = true
}
}
compressProtoc, exists = context.GetHeader(Compression)
compressProtoc, exists = connContext.GetHeader(Compression)
if exists {
if compressProtoc == GzipCompressionProtocol {
compression = true
}
}
client := ws.clientPool.Get().(*Client)
client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws)
client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws)
ws.registerChan <- client
go client.readMessage()
}
+24
View File
@@ -3,6 +3,7 @@ package msggateway
import (
"context"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
"sync"
)
@@ -71,6 +72,29 @@ func (u *UserMap) delete(key string, connRemoteAddr string) (isDeleteUser bool)
}
return existed
}
func (u *UserMap) deleteClients(key string, clients []*Client) (isDeleteUser bool) {
m := utils.SliceToMapAny(clients, func(c *Client) (string, struct{}) {
return c.ctx.GetRemoteAddr(), struct{}{}
})
allClients, existed := u.m.Load(key)
if existed {
oldClients := allClients.([]*Client)
var a []*Client
for _, client := range oldClients {
if _, ok := m[client.ctx.GetRemoteAddr()]; !ok {
a = append(a, client)
}
}
if len(a) == 0 {
u.m.Delete(key)
return true
} else {
u.m.Store(key, a)
return false
}
}
return existed
}
func (u *UserMap) DeleteAll(key string) {
u.m.Delete(key)
}
+4 -4
View File
@@ -42,7 +42,7 @@ func (s *authServer) UserToken(ctx context.Context, req *pbAuth.UserTokenReq) (*
if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil {
return nil, err
}
token, err := s.authDatabase.CreateToken(ctx, req.UserID, constant.PlatformIDToName(int(req.PlatformID)))
token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(req.PlatformID))
if err != nil {
return nil, err
}
@@ -56,7 +56,7 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
if err != nil {
return nil, utils.Wrap(err, "")
}
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UID, claims.Platform)
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
if err != nil {
return nil, err
}
@@ -82,8 +82,8 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbAuth.ParseTokenReq)
if err != nil {
return nil, err
}
resp.UserID = claims.UID
resp.Platform = claims.Platform
resp.UserID = claims.UserID
resp.Platform = constant.PlatformIDToName(claims.PlatformID)
resp.ExpireTimeSeconds = claims.ExpiresAt.Unix()
return resp, nil
}