mirror of
https://github.com/openimsdk/open-im-server.git
synced 2026-04-28 14:29:19 +08:00
feat: implement server-initiated heartbeat in msgGateway module (#2404)
* feat: implement send ping msg when platform is web in gateway. * add context life cycle control. * feat: implement heartbeat logic in msggateway. * update heartbeat logic. * update to correct method name and comment. * update initiate heartbeat logic. * rename ws_server * update writePingMsg logic * update log level to warn.
This commit is contained in:
@@ -0,0 +1,445 @@
|
||||
// Copyright © 2023 OpenIM. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msggateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/webhook"
|
||||
pbAuth "github.com/openimsdk/protocol/auth"
|
||||
"github.com/openimsdk/tools/mcontext"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient"
|
||||
"github.com/openimsdk/protocol/constant"
|
||||
"github.com/openimsdk/protocol/msggateway"
|
||||
"github.com/openimsdk/tools/discovery"
|
||||
"github.com/openimsdk/tools/errs"
|
||||
"github.com/openimsdk/tools/log"
|
||||
"github.com/openimsdk/tools/utils/stringutil"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type LongConnServer interface {
|
||||
Run(done chan error) error
|
||||
wsHandler(w http.ResponseWriter, r *http.Request)
|
||||
GetUserAllCons(userID string) ([]*Client, bool)
|
||||
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
|
||||
Validate(s any) error
|
||||
SetDiscoveryRegistry(client discovery.SvcDiscoveryRegistry, config *Config)
|
||||
KickUserConn(client *Client) error
|
||||
UnRegister(c *Client)
|
||||
SetKickHandlerInfo(i *kickHandler)
|
||||
Compressor
|
||||
Encoder
|
||||
MessageHandler
|
||||
}
|
||||
|
||||
type WsServer struct {
|
||||
msgGatewayConfig *Config
|
||||
port int
|
||||
wsMaxConnNum int64
|
||||
registerChan chan *Client
|
||||
unregisterChan chan *Client
|
||||
kickHandlerChan chan *kickHandler
|
||||
clients *UserMap
|
||||
clientPool sync.Pool
|
||||
onlineUserNum atomic.Int64
|
||||
onlineUserConnNum atomic.Int64
|
||||
handshakeTimeout time.Duration
|
||||
writeBufferSize int
|
||||
validate *validator.Validate
|
||||
userClient *rpcclient.UserRpcClient
|
||||
authClient *rpcclient.Auth
|
||||
disCov discovery.SvcDiscoveryRegistry
|
||||
Compressor
|
||||
Encoder
|
||||
MessageHandler
|
||||
webhookClient *webhook.Client
|
||||
}
|
||||
|
||||
type kickHandler struct {
|
||||
clientOK bool
|
||||
oldClients []*Client
|
||||
newClient *Client
|
||||
}
|
||||
|
||||
func (ws *WsServer) SetDiscoveryRegistry(disCov discovery.SvcDiscoveryRegistry, config *Config) {
|
||||
ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, &config.Share.RpcRegisterName)
|
||||
u := rpcclient.NewUserRpcClient(disCov, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID)
|
||||
ws.authClient = rpcclient.NewAuth(disCov, config.Share.RpcRegisterName.Auth)
|
||||
ws.userClient = &u
|
||||
ws.disCov = disCov
|
||||
}
|
||||
|
||||
func (ws *WsServer) SetUserOnlineStatus(ctx context.Context, client *Client, status int32) {
|
||||
err := ws.userClient.SetUserStatus(ctx, client.UserID, status, client.PlatformID)
|
||||
if err != nil {
|
||||
log.ZWarn(ctx, "SetUserStatus err", err)
|
||||
}
|
||||
switch status {
|
||||
case constant.Online:
|
||||
ws.webhookAfterUserOnline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOnline, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID())
|
||||
case constant.Offline:
|
||||
ws.webhookAfterUserOffline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOffline, client.UserID, client.PlatformID, client.ctx.GetConnID())
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *WsServer) UnRegister(c *Client) {
|
||||
ws.unregisterChan <- c
|
||||
}
|
||||
|
||||
func (ws *WsServer) Validate(_ any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WsServer) GetUserAllCons(userID string) ([]*Client, bool) {
|
||||
return ws.clients.GetAll(userID)
|
||||
}
|
||||
|
||||
func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) {
|
||||
return ws.clients.Get(userID, platform)
|
||||
}
|
||||
|
||||
func NewWsServer(msgGatewayConfig *Config, opts ...Option) (*WsServer, error) {
|
||||
var config configs
|
||||
for _, o := range opts {
|
||||
o(&config)
|
||||
}
|
||||
v := validator.New()
|
||||
return &WsServer{
|
||||
msgGatewayConfig: msgGatewayConfig,
|
||||
port: config.port,
|
||||
wsMaxConnNum: config.maxConnNum,
|
||||
writeBufferSize: config.writeBufferSize,
|
||||
handshakeTimeout: config.handshakeTimeout,
|
||||
clientPool: sync.Pool{
|
||||
New: func() any {
|
||||
return new(Client)
|
||||
},
|
||||
},
|
||||
registerChan: make(chan *Client, 1000),
|
||||
unregisterChan: make(chan *Client, 1000),
|
||||
kickHandlerChan: make(chan *kickHandler, 1000),
|
||||
validate: v,
|
||||
clients: newUserMap(),
|
||||
Compressor: NewGzipCompressor(),
|
||||
Encoder: NewGobEncoder(),
|
||||
webhookClient: webhook.NewWebhookClient(msgGatewayConfig.WebhooksConfig.URL),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ws *WsServer) Run(done chan error) error {
|
||||
var (
|
||||
client *Client
|
||||
netErr error
|
||||
shutdownDone = make(chan struct{}, 1)
|
||||
)
|
||||
|
||||
server := http.Server{Addr: ":" + stringutil.IntToString(ws.port), Handler: nil}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-shutdownDone:
|
||||
return
|
||||
case client = <-ws.registerChan:
|
||||
ws.registerClient(client)
|
||||
case client = <-ws.unregisterChan:
|
||||
ws.unregisterClient(client)
|
||||
case onlineInfo := <-ws.kickHandlerChan:
|
||||
ws.multiTerminalLoginChecker(onlineInfo.clientOK, onlineInfo.oldClients, onlineInfo.newClient)
|
||||
}
|
||||
}
|
||||
}()
|
||||
netDone := make(chan struct{}, 1)
|
||||
go func() {
|
||||
http.HandleFunc("/", ws.wsHandler)
|
||||
err := server.ListenAndServe()
|
||||
defer close(netDone)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
netErr = errs.WrapMsg(err, "ws start err", server.Addr)
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
var err error
|
||||
select {
|
||||
case err = <-done:
|
||||
sErr := server.Shutdown(ctx)
|
||||
if sErr != nil {
|
||||
return errs.WrapMsg(sErr, "shutdown err")
|
||||
}
|
||||
close(shutdownDone)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-netDone:
|
||||
}
|
||||
return netErr
|
||||
|
||||
}
|
||||
|
||||
var concurrentRequest = 3
|
||||
|
||||
func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *Client) error {
|
||||
conns, err := ws.disCov.GetConns(ctx, ws.msgGatewayConfig.Share.RpcRegisterName.MessageGateway)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wg := errgroup.Group{}
|
||||
wg.SetLimit(concurrentRequest)
|
||||
|
||||
// Online push user online message to other node
|
||||
for _, v := range conns {
|
||||
v := v
|
||||
log.ZDebug(ctx, " sendUserOnlineInfoToOtherNode conn ", "target", v.Target())
|
||||
if v.Target() == ws.disCov.GetSelfConnTarget() {
|
||||
log.ZDebug(ctx, "Filter out this node", "node", v.Target())
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Go(func() error {
|
||||
msgClient := msggateway.NewMsgGatewayClient(v)
|
||||
_, err := msgClient.MultiTerminalLoginCheck(ctx, &msggateway.MultiTerminalLoginCheckReq{
|
||||
UserID: client.UserID,
|
||||
PlatformID: int32(client.PlatformID), Token: client.token,
|
||||
})
|
||||
if err != nil {
|
||||
log.ZWarn(ctx, "MultiTerminalLoginCheck err", err, "node", v.Target())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
_ = wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WsServer) SetKickHandlerInfo(i *kickHandler) {
|
||||
ws.kickHandlerChan <- i
|
||||
}
|
||||
|
||||
func (ws *WsServer) registerClient(client *Client) {
|
||||
var (
|
||||
userOK bool
|
||||
clientOK bool
|
||||
oldClients []*Client
|
||||
)
|
||||
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
|
||||
if !userOK {
|
||||
ws.clients.Set(client.UserID, client)
|
||||
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
|
||||
prommetrics.OnlineUserGauge.Add(1)
|
||||
ws.onlineUserNum.Add(1)
|
||||
ws.onlineUserConnNum.Add(1)
|
||||
} else {
|
||||
ws.multiTerminalLoginChecker(clientOK, oldClients, client)
|
||||
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
|
||||
if clientOK {
|
||||
ws.clients.Set(client.UserID, client)
|
||||
// There is already a connection to the platform
|
||||
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID",
|
||||
client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
|
||||
ws.onlineUserConnNum.Add(1)
|
||||
} else {
|
||||
ws.clients.Set(client.UserID, client)
|
||||
ws.onlineUserConnNum.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
log.ZDebug(client.ctx, "ws.msgGatewayConfig.Discovery.Enable", ws.msgGatewayConfig.Discovery.Enable)
|
||||
|
||||
if ws.msgGatewayConfig.Discovery.Enable != "k8s" {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = ws.sendUserOnlineInfoToOtherNode(client.ctx, client)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ws.SetUserOnlineStatus(client.ctx, client, constant.Online)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
log.ZInfo(
|
||||
client.ctx,
|
||||
"user online",
|
||||
"online user Num",
|
||||
ws.onlineUserNum.Load(),
|
||||
"online user conn Num",
|
||||
ws.onlineUserConnNum.Load(),
|
||||
)
|
||||
}
|
||||
|
||||
func getRemoteAdders(client []*Client) string {
|
||||
var ret string
|
||||
for i, c := range client {
|
||||
if i == 0 {
|
||||
ret = c.ctx.GetRemoteAddr()
|
||||
} else {
|
||||
ret += "@" + c.ctx.GetRemoteAddr()
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (ws *WsServer) KickUserConn(client *Client) error {
|
||||
ws.clients.deleteClients(client.UserID, []*Client{client})
|
||||
return client.KickOnlineMessage()
|
||||
}
|
||||
|
||||
func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) {
|
||||
switch ws.msgGatewayConfig.MsgGateway.MultiLoginPolicy {
|
||||
case constant.DefalutNotKick:
|
||||
case constant.PCAndOther:
|
||||
if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC {
|
||||
return
|
||||
}
|
||||
fallthrough
|
||||
case constant.AllLoginButSameTermKick:
|
||||
if !clientOK {
|
||||
return
|
||||
}
|
||||
ws.clients.deleteClients(newClient.UserID, oldClients)
|
||||
for _, c := range oldClients {
|
||||
err := c.KickOnlineMessage()
|
||||
if err != nil {
|
||||
log.ZWarn(c.ctx, "KickOnlineMessage", err)
|
||||
}
|
||||
}
|
||||
ctx := mcontext.WithMustInfoCtx(
|
||||
[]string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(),
|
||||
constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()},
|
||||
)
|
||||
if _, err := ws.authClient.InvalidateToken(ctx, newClient.token, newClient.UserID, newClient.PlatformID); err != nil {
|
||||
log.ZWarn(newClient.ctx, "InvalidateToken err", err, "userID", newClient.UserID,
|
||||
"platformID", newClient.PlatformID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *WsServer) unregisterClient(client *Client) {
|
||||
defer ws.clientPool.Put(client)
|
||||
isDeleteUser := ws.clients.delete(client.UserID, client.ctx.GetRemoteAddr())
|
||||
if isDeleteUser {
|
||||
ws.onlineUserNum.Add(-1)
|
||||
prommetrics.OnlineUserGauge.Dec()
|
||||
}
|
||||
ws.onlineUserConnNum.Add(-1)
|
||||
ws.SetUserOnlineStatus(client.ctx, client, constant.Offline)
|
||||
log.ZInfo(client.ctx, "user offline", "close reason", client.closedErr, "online user Num",
|
||||
ws.onlineUserNum.Load(), "online user conn Num",
|
||||
ws.onlineUserConnNum.Load(),
|
||||
)
|
||||
}
|
||||
|
||||
// validateRespWithRequest checks if the response matches the expected userID and platformID.
|
||||
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
|
||||
userID := ctx.GetUserID()
|
||||
platformID := stringutil.StringToInt32(ctx.GetPlatformID())
|
||||
if resp.UserID != userID {
|
||||
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
|
||||
}
|
||||
if resp.PlatformID != platformID {
|
||||
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a new connection context
|
||||
connContext := newContext(w, r)
|
||||
|
||||
// Check if the current number of online user connections exceeds the maximum limit
|
||||
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
|
||||
// If it exceeds the maximum connection number, return an error via HTTP and stop processing
|
||||
httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
|
||||
return
|
||||
}
|
||||
|
||||
// Parse essential arguments (e.g., user ID, Token)
|
||||
err := connContext.ParseEssentialArgs()
|
||||
if err != nil {
|
||||
// If there's an error during parsing, return an error via HTTP and stop processing
|
||||
|
||||
httpError(connContext, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Call the authentication client to parse the Token obtained from the context
|
||||
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
|
||||
if err != nil {
|
||||
// If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag
|
||||
shouldSendError := connContext.ShouldSendResp()
|
||||
if shouldSendError {
|
||||
// Create a WebSocket connection object and attempt to send the error message via WebSocket
|
||||
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
||||
if err := wsLongConn.RespondWithError(err, w, r); err == nil {
|
||||
// If the error message is successfully sent via WebSocket, stop processing
|
||||
return
|
||||
}
|
||||
}
|
||||
// If sending via WebSocket is not required or fails, return the error via HTTP and stop processing
|
||||
httpError(connContext, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the authentication response matches the request (e.g., user ID and platform ID)
|
||||
err = ws.validateRespWithRequest(connContext, resp)
|
||||
if err != nil {
|
||||
// If validation fails, return an error via HTTP and stop processing
|
||||
httpError(connContext, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a WebSocket long connection object
|
||||
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
||||
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
|
||||
//If the creation of the long connection fails, the error is handled internally during the handshake process.
|
||||
log.ZWarn(connContext, "long connection fails", err)
|
||||
return
|
||||
} else {
|
||||
// Check if a normal response should be sent via WebSocket
|
||||
shouldSendSuccessResp := connContext.ShouldSendResp()
|
||||
if shouldSendSuccessResp {
|
||||
// Attempt to send a success message through WebSocket
|
||||
if err := wsLongConn.RespondWithSuccess(); err != nil {
|
||||
// If the success message is successfully sent, end further processing
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection
|
||||
client := ws.clientPool.Get().(*Client)
|
||||
client.ResetClient(connContext, wsLongConn, ws)
|
||||
|
||||
// Register the client with the server and start message processing
|
||||
ws.registerChan <- client
|
||||
go client.readMessage()
|
||||
}
|
||||
Reference in New Issue
Block a user