Compare commits

..

1 Commits

Author SHA1 Message Date
github-actions[bot] 94ef75a19e Update version to v3.8.3-patch.11 2025-09-30 10:16:49 +00:00
21 changed files with 819 additions and 933 deletions
+12 -10
View File
@@ -92,13 +92,12 @@ jobs:
contents: write contents: write
env: env:
SDK_DIR: openim-sdk-core SDK_DIR: openim-sdk-core
NOTIFICATION_CONFIG_PATH: config/notification.yml CONFIG_PATH: config/notification.yml
SHARE_CONFIG_PATH: config/share.yml # pull-requests: write
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ ubuntu-latest ]
go_version: ["1.22.x"] go_version: [ "1.22.x" ]
steps: steps:
- name: Checkout Server repository - name: Checkout Server repository
@@ -107,8 +106,7 @@ jobs:
- name: Checkout SDK repository - name: Checkout SDK repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
repository: "openimsdk/openim-sdk-core" repository: 'openimsdk/openim-sdk-core'
ref: "main"
path: ${{ env.SDK_DIR }} path: ${{ env.SDK_DIR }}
- name: Set up Go ${{ matrix.go_version }} - name: Set up Go ${{ matrix.go_version }}
@@ -121,11 +119,15 @@ jobs:
go install github.com/magefile/mage@latest go install github.com/magefile/mage@latest
go mod download go mod download
- name: Install yq
run: |
sudo wget https://github.com/mikefarah/yq/releases/download/v4.34.1/yq_linux_amd64 -O /usr/bin/yq
sudo chmod +x /usr/bin/yq
- name: Modify Server Configuration - name: Modify Server Configuration
run: | run: |
yq e '.groupCreated.isSendMsg = true' -i ${{ env.NOTIFICATION_CONFIG_PATH }} yq e '.groupCreated.unreadCount = true' -i ${{ env.CONFIG_PATH }}
yq e '.friendApplicationApproved.isSendMsg = true' -i ${{ env.NOTIFICATION_CONFIG_PATH }} yq e '.friendApplicationApproved.unreadCount = true' -i ${{ env.CONFIG_PATH }}
yq e '.secret = 123456' -i ${{ env.SHARE_CONFIG_PATH }}
- name: Start Server Services - name: Start Server Services
run: | run: |
+1 -2
View File
@@ -1,8 +1,7 @@
address: [ localhost:16379 ] address: [ localhost:16379 ]
username: username:
password: openIM123 password: openIM123
# redisMode can be "cluster", "sentinel", or "standalone" clusterMode: false
redisMode: "standalone"
db: 0 db: 0
maxRetry: 10 maxRetry: 10
poolSize: 100 poolSize: 100
+2 -2
View File
@@ -12,8 +12,8 @@ require (
github.com/gorilla/websocket v1.5.1 github.com/gorilla/websocket v1.5.1
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/mitchellh/mapstructure v1.5.0 github.com/mitchellh/mapstructure v1.5.0
github.com/openimsdk/protocol v0.0.73-alpha.18 github.com/openimsdk/protocol v0.0.73-alpha.12
github.com/openimsdk/tools v0.0.50-alpha.106 github.com/openimsdk/tools v0.0.50-alpha.103
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_golang v1.18.0
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
+2 -4
View File
@@ -349,10 +349,8 @@ github.com/openimsdk/gomake v0.0.15-alpha.5 h1:eEZCEHm+NsmcO3onXZPIUbGFCYPYbsX5b
github.com/openimsdk/gomake v0.0.15-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= github.com/openimsdk/gomake v0.0.15-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI=
github.com/openimsdk/protocol v0.0.73-alpha.12 h1:2NYawXeHChYUeSme6QJ9pOLh+Empce2WmwEtbP4JvKk= github.com/openimsdk/protocol v0.0.73-alpha.12 h1:2NYawXeHChYUeSme6QJ9pOLh+Empce2WmwEtbP4JvKk=
github.com/openimsdk/protocol v0.0.73-alpha.12/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw= github.com/openimsdk/protocol v0.0.73-alpha.12/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw=
github.com/openimsdk/protocol v0.0.73-alpha.18 h1:LXmDFx3KnMd2mN0/S3Q2U33Ft/DHvplSsINO0/bto/c= github.com/openimsdk/tools v0.0.50-alpha.103 h1:jYvI86cWiVu8a8iw1panw+pwIiStuUHF76h3fxA6ESI=
github.com/openimsdk/protocol v0.0.73-alpha.18/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw= github.com/openimsdk/tools v0.0.50-alpha.103/go.mod h1:qCExFBqXpQBMzZck3XGIFwivBayAn2KNqB3WAd++IJw=
github.com/openimsdk/tools v0.0.50-alpha.106 h1:gaqU08IbRxOdL16ZEQNyYLhCTf7K1HNkrik8KZLA+BM=
github.com/openimsdk/tools v0.0.50-alpha.106/go.mod h1:x9i/e+WJFW4tocy6RNJQ9NofQiP3KJ1Y576/06TqOG4=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
+151 -10
View File
@@ -16,6 +16,7 @@ package msggateway
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -30,6 +31,7 @@ import (
"github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log" "github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext" "github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/stringutil"
) )
var ( var (
@@ -62,7 +64,7 @@ type PingPongHandler func(string) error
type Client struct { type Client struct {
w *sync.Mutex w *sync.Mutex
conn ClientConn conn LongConn
PlatformID int `json:"platformID"` PlatformID int `json:"platformID"`
IsCompress bool `json:"isCompress"` IsCompress bool `json:"isCompress"`
UserID string `json:"userID"` UserID string `json:"userID"`
@@ -81,10 +83,10 @@ type Client struct {
} }
// ResetClient updates the client's state with new connection and context information. // ResetClient updates the client's state with new connection and context information.
func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServer LongConnServer) { func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) {
c.w = new(sync.Mutex) c.w = new(sync.Mutex)
c.conn = conn c.conn = conn
c.PlatformID = ctx.GetPlatformID() c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
c.IsCompress = ctx.GetCompression() c.IsCompress = ctx.GetCompression()
c.IsBackground = ctx.GetBackground() c.IsBackground = ctx.GetBackground()
c.UserID = ctx.GetUserID() c.UserID = ctx.GetUserID()
@@ -108,6 +110,22 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServ
c.subUserIDs = make(map[string]struct{}) c.subUserIDs = make(map[string]struct{})
} }
func (c *Client) pingHandler(appData string) error {
if err := c.conn.SetReadDeadline(pongWait); err != nil {
return err
}
log.ZDebug(c.ctx, "ping Handler Success.", "appData", appData)
return c.writePongMsg(appData)
}
func (c *Client) pongHandler(_ string) error {
if err := c.conn.SetReadDeadline(pongWait); err != nil {
return err
}
return nil
}
// readMessage continuously reads messages from the connection. // readMessage continuously reads messages from the connection.
func (c *Client) readMessage() { func (c *Client) readMessage() {
defer func() { defer func() {
@@ -118,25 +136,52 @@ func (c *Client) readMessage() {
c.close() c.close()
}() }()
c.conn.SetReadLimit(maxMessageSize)
_ = c.conn.SetReadDeadline(pongWait)
c.conn.SetPongHandler(c.pongHandler)
c.conn.SetPingHandler(c.pingHandler)
c.activeHeartbeat(c.hbCtx)
for { for {
log.ZDebug(c.ctx, "readMessage") log.ZDebug(c.ctx, "readMessage")
message, returnErr := c.conn.ReadMessage() messageType, message, returnErr := c.conn.ReadMessage()
if returnErr != nil { if returnErr != nil {
log.ZWarn(c.ctx, "readMessage", returnErr) log.ZWarn(c.ctx, "readMessage", returnErr, "messageType", messageType)
c.closedErr = returnErr c.closedErr = returnErr
return return
} }
log.ZDebug(c.ctx, "readMessage", "messageType", messageType)
if c.closed.Load() { if c.closed.Load() {
// The scenario where the connection has just been closed, but the coroutine has not exited // The scenario where the connection has just been closed, but the coroutine has not exited
c.closedErr = ErrConnClosed c.closedErr = ErrConnClosed
return return
} }
parseDataErr := c.handleMessage(message) switch messageType {
if parseDataErr != nil { case MessageBinary:
c.closedErr = parseDataErr _ = c.conn.SetReadDeadline(pongWait)
parseDataErr := c.handleMessage(message)
if parseDataErr != nil {
c.closedErr = parseDataErr
return
}
case MessageText:
_ = c.conn.SetReadDeadline(pongWait)
parseDataErr := c.handlerTextMessage(message)
if parseDataErr != nil {
c.closedErr = parseDataErr
return
}
case PingMessage:
err := c.writePongMsg("")
log.ZError(c.ctx, "writePongMsg", err)
case CloseMessage:
c.closedErr = ErrClientClosed
return return
default:
} }
} }
} }
@@ -311,13 +356,109 @@ func (c *Client) writeBinaryMsg(resp Resp) error {
c.w.Lock() c.w.Lock()
defer c.w.Unlock() defer c.w.Unlock()
err = c.conn.SetWriteDeadline(writeWait)
if err != nil {
return err
}
if c.IsCompress { if c.IsCompress {
resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf) resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf)
if compressErr != nil { if compressErr != nil {
return compressErr return compressErr
} }
return c.conn.WriteMessage(resultBuf) return c.conn.WriteMessage(MessageBinary, resultBuf)
} }
return c.conn.WriteMessage(encodedBuf) return c.conn.WriteMessage(MessageBinary, encodedBuf)
}
// Actively initiate Heartbeat when platform in Web.
func (c *Client) activeHeartbeat(ctx context.Context) {
if c.PlatformID == constant.WebPlatformID {
go func() {
defer func() {
if r := recover(); r != nil {
log.ZPanic(ctx, "activeHeartbeat Panic", errs.ErrPanic(r))
}
}()
log.ZDebug(ctx, "server initiative send heartbeat start.")
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.writePingMsg(); err != nil {
log.ZWarn(c.ctx, "send Ping Message error.", err)
return
}
case <-c.hbCtx.Done():
return
}
}
}()
}
}
func (c *Client) writePingMsg() error {
if c.closed.Load() {
return nil
}
c.w.Lock()
defer c.w.Unlock()
err := c.conn.SetWriteDeadline(writeWait)
if err != nil {
return err
}
return c.conn.WriteMessage(PingMessage, nil)
}
func (c *Client) writePongMsg(appData string) error {
log.ZDebug(c.ctx, "write Pong Msg in Server", "appData", appData)
if c.closed.Load() {
log.ZWarn(c.ctx, "is closed in server", nil, "appdata", appData, "closed err", c.closedErr)
return nil
}
c.w.Lock()
defer c.w.Unlock()
err := c.conn.SetWriteDeadline(writeWait)
if err != nil {
log.ZWarn(c.ctx, "SetWriteDeadline in Server have error", errs.Wrap(err), "writeWait", writeWait, "appData", appData)
return errs.Wrap(err)
}
err = c.conn.WriteMessage(PongMessage, []byte(appData))
if err != nil {
log.ZWarn(c.ctx, "Write Message have error", errs.Wrap(err), "Pong msg", PongMessage)
}
return errs.Wrap(err)
}
func (c *Client) handlerTextMessage(b []byte) error {
var msg TextMessage
if err := json.Unmarshal(b, &msg); err != nil {
return err
}
switch msg.Type {
case TextPong:
return nil
case TextPing:
msg.Type = TextPong
msgData, err := json.Marshal(msg)
if err != nil {
return err
}
c.w.Lock()
defer c.w.Unlock()
if err := c.conn.SetWriteDeadline(writeWait); err != nil {
return err
}
return c.conn.WriteMessage(MessageText, msgData)
default:
return fmt.Errorf("not support message type %s", msg.Type)
}
} }
-212
View File
@@ -1,212 +0,0 @@
package msggateway
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/openimsdk/tools/log"
)
var ErrWriteFull = fmt.Errorf("websocket write buffer full,close connection")
type ClientConn interface {
ReadMessage() ([]byte, error)
WriteMessage(message []byte) error
Close() error
}
type websocketMessage struct {
MessageType int
Data []byte
}
func NewWebSocketClientConn(conn *websocket.Conn, readLimit int64, readTimeout time.Duration, pingInterval time.Duration) ClientConn {
c := &websocketClientConn{
readTimeout: readTimeout,
conn: conn,
writer: make(chan *websocketMessage, 256),
done: make(chan struct{}),
}
if readLimit > 0 {
c.conn.SetReadLimit(readLimit)
}
c.conn.SetPingHandler(c.pingHandler)
c.conn.SetPongHandler(c.pongHandler)
go c.loopSend()
if pingInterval > 0 {
go c.doPing(pingInterval)
}
return c
}
type websocketClientConn struct {
readTimeout time.Duration
conn *websocket.Conn
writer chan *websocketMessage
done chan struct{}
err atomic.Pointer[error]
}
func (c *websocketClientConn) ReadMessage() ([]byte, error) {
buf, err := c.readMessage()
if err != nil {
return nil, c.closeBy(fmt.Errorf("read message %w", err))
}
return buf, nil
}
func (c *websocketClientConn) WriteMessage(message []byte) error {
return c.writeMessage(websocket.BinaryMessage, message)
}
func (c *websocketClientConn) Close() error {
_ = c.closeBy(fmt.Errorf("websocket connection closed"))
return nil
}
func (c *websocketClientConn) closeBy(err error) error {
if !c.err.CompareAndSwap(nil, &err) {
return *c.err.Load()
}
close(c.done)
log.ZWarn(context.Background(), "websocket connection closed", err, "remoteAddr", c.conn.RemoteAddr(),
"chan length", len(c.writer))
_ = c.conn.Close()
return err
}
func (c *websocketClientConn) writeMessage(messageType int, data []byte) error {
if errPtr := c.err.Load(); errPtr != nil {
return *errPtr
}
select {
case c.writer <- &websocketMessage{MessageType: messageType, Data: data}:
return nil
default:
return c.closeBy(ErrWriteFull)
}
}
func (c *websocketClientConn) loopSend() {
var err error
for {
select {
case <-c.done:
return
case msg := <-c.writer:
switch msg.MessageType {
case websocket.TextMessage, websocket.BinaryMessage:
err = c.conn.WriteMessage(msg.MessageType, msg.Data)
default:
err = c.conn.WriteControl(msg.MessageType, msg.Data, time.Time{})
}
if err != nil {
_ = c.closeBy(err)
return
}
}
}
}
func (c *websocketClientConn) setReadDeadline() error {
deadline := time.Now().Add(c.readTimeout)
return c.conn.SetReadDeadline(deadline)
}
func (c *websocketClientConn) readMessage() ([]byte, error) {
for {
if err := c.setReadDeadline(); err != nil {
return nil, err
}
messageType, buf, err := c.conn.ReadMessage()
if err != nil {
return nil, err
}
switch messageType {
case websocket.BinaryMessage:
return buf, nil
case websocket.TextMessage:
if err := c.onReadTextMessage(buf); err != nil {
return nil, err
}
case websocket.PingMessage:
if err := c.pingHandler(string(buf)); err != nil {
return nil, err
}
case websocket.PongMessage:
if err := c.pongHandler(string(buf)); err != nil {
return nil, err
}
case websocket.CloseMessage:
if len(buf) == 0 {
return nil, errors.New("websocket connection closed by peer")
}
return nil, fmt.Errorf("websocket connection closed by peer, data %s", string(buf))
default:
return nil, fmt.Errorf("unknown websocket message type %d", messageType)
}
}
}
func (c *websocketClientConn) onReadTextMessage(buf []byte) error {
var msg struct {
Type string `json:"type"`
Body json.RawMessage `json:"body"`
}
if err := json.Unmarshal(buf, &msg); err != nil {
return err
}
switch msg.Type {
case TextPong:
return nil
case TextPing:
msg.Type = TextPong
msgData, err := json.Marshal(msg)
if err != nil {
return err
}
return c.writeMessage(websocket.TextMessage, msgData)
default:
return fmt.Errorf("not support text message type %s", msg.Type)
}
}
func (c *websocketClientConn) pingHandler(appData string) error {
log.ZWarn(context.Background(), "ping handler recv ping", nil, "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
if err := c.setReadDeadline(); err != nil {
return err
}
err := c.conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second*1))
if err != nil {
log.ZWarn(context.Background(), "ping handler write pong error", err, "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
}
log.ZWarn(context.Background(), "ping handler write pong success", nil, "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
return nil
}
func (c *websocketClientConn) pongHandler(string) error {
return nil
}
func (c *websocketClientConn) doPing(d time.Duration) {
ticker := time.NewTicker(d)
defer ticker.Stop()
for {
select {
case <-c.done:
return
case <-ticker.C:
if err := c.writeMessage(websocket.PingMessage, nil); err != nil {
_ = c.closeBy(fmt.Errorf("send ping %w", err))
return
}
}
}
}
+82 -133
View File
@@ -15,31 +15,18 @@
package msggateway package msggateway
import ( import (
"encoding/base64" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"encoding/json"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"time" "time"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/constant"
"github.com/openimsdk/tools/utils/encrypt" "github.com/openimsdk/tools/utils/encrypt"
"github.com/openimsdk/tools/utils/stringutil"
"github.com/openimsdk/tools/utils/timeutil" "github.com/openimsdk/tools/utils/timeutil"
) )
type UserConnContextInfo struct {
Token string `json:"token"`
UserID string `json:"userID"`
PlatformID int `json:"platformID"`
OperationID string `json:"operationID"`
Compression string `json:"compression"`
SDKType string `json:"sdkType"`
SendResponse bool `json:"sendResponse"`
Background bool `json:"background"`
}
type UserConnContext struct { type UserConnContext struct {
RespWriter http.ResponseWriter RespWriter http.ResponseWriter
Req *http.Request Req *http.Request
@@ -47,7 +34,6 @@ type UserConnContext struct {
Method string Method string
RemoteAddr string RemoteAddr string
ConnID string ConnID string
info *UserConnContextInfo
} }
func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) { func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) {
@@ -71,7 +57,7 @@ func (c *UserConnContext) Value(key any) any {
case constant.ConnID: case constant.ConnID:
return c.GetConnID() return c.GetConnID()
case constant.OpUserPlatform: case constant.OpUserPlatform:
return c.GetPlatformID() return constant.PlatformIDToName(stringutil.StringToInt(c.GetPlatformID()))
case constant.RemoteAddr: case constant.RemoteAddr:
return c.RemoteAddr return c.RemoteAddr
default: default:
@@ -96,91 +82,30 @@ func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnCont
func newTempContext() *UserConnContext { func newTempContext() *UserConnContext {
return &UserConnContext{ return &UserConnContext{
Req: &http.Request{URL: &url.URL{}}, Req: &http.Request{URL: &url.URL{}},
info: &UserConnContextInfo{},
} }
} }
func (c *UserConnContext) ParseEssentialArgs() error {
query := c.Req.URL.Query()
if data := query.Get("v"); data != "" {
return c.parseByJson(data)
} else {
return c.parseByQuery(query, c.Req.Header)
}
}
func (c *UserConnContext) parseByQuery(query url.Values, header http.Header) error {
info := UserConnContextInfo{
Token: query.Get(Token),
UserID: query.Get(WsUserID),
OperationID: query.Get(OperationID),
Compression: query.Get(Compression),
SDKType: query.Get(SDKType),
}
platformID, err := strconv.Atoi(query.Get(PlatformID))
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
}
info.PlatformID = platformID
if val := query.Get(SendResponse); val != "" {
ok, err := strconv.ParseBool(val)
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("isMsgResp is not bool")
}
info.SendResponse = ok
}
if info.Compression == "" {
info.Compression = header.Get(Compression)
}
background, err := strconv.ParseBool(query.Get(BackgroundStatus))
if err != nil {
return err
}
info.Background = background
return c.checkInfo(&info)
}
func (c *UserConnContext) parseByJson(data string) error {
reqInfo, err := base64.RawURLEncoding.DecodeString(data)
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("data is not base64")
}
var info UserConnContextInfo
if err := json.Unmarshal(reqInfo, &info); err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("data is not json", "info", err.Error())
}
return c.checkInfo(&info)
}
func (c *UserConnContext) checkInfo(info *UserConnContextInfo) error {
if info.OperationID == "" {
return servererrs.ErrConnArgsErr.WrapMsg("operationID is empty")
}
if info.Token == "" {
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
}
if info.UserID == "" {
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
}
if _, ok := constant.PlatformID2Name[info.PlatformID]; !ok {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is invalid")
}
switch info.SDKType {
case "":
info.SDKType = GoSDK
case GoSDK, JsSDK:
default:
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is invalid")
}
c.info = info
return nil
}
func (c *UserConnContext) GetRemoteAddr() string { func (c *UserConnContext) GetRemoteAddr() string {
return c.RemoteAddr return c.RemoteAddr
} }
func (c *UserConnContext) Query(key string) (string, bool) {
var value string
if value = c.Req.URL.Query().Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) GetHeader(key string) (string, bool) {
var value string
if value = c.Req.Header.Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) SetHeader(key, value string) { func (c *UserConnContext) SetHeader(key, value string) {
c.RespWriter.Header().Set(key, value) c.RespWriter.Header().Set(key, value)
} }
@@ -194,69 +119,93 @@ func (c *UserConnContext) GetConnID() string {
} }
func (c *UserConnContext) GetUserID() string { func (c *UserConnContext) GetUserID() string {
if c == nil || c.info == nil { return c.Req.URL.Query().Get(WsUserID)
return ""
}
return c.info.UserID
} }
func (c *UserConnContext) GetPlatformID() int { func (c *UserConnContext) GetPlatformID() string {
if c == nil || c.info == nil { return c.Req.URL.Query().Get(PlatformID)
return 0
}
return c.info.PlatformID
} }
func (c *UserConnContext) GetOperationID() string { func (c *UserConnContext) GetOperationID() string {
if c == nil || c.info == nil { return c.Req.URL.Query().Get(OperationID)
return ""
}
return c.info.OperationID
} }
func (c *UserConnContext) SetOperationID(operationID string) { func (c *UserConnContext) SetOperationID(operationID string) {
if c.info == nil { values := c.Req.URL.Query()
c.info = &UserConnContextInfo{} values.Set(OperationID, operationID)
} c.Req.URL.RawQuery = values.Encode()
c.info.OperationID = operationID
} }
func (c *UserConnContext) GetToken() string { func (c *UserConnContext) GetToken() string {
if c == nil || c.info == nil { return c.Req.URL.Query().Get(Token)
return ""
}
return c.info.Token
} }
func (c *UserConnContext) GetCompression() bool { func (c *UserConnContext) GetCompression() bool {
return c != nil && c.info != nil && c.info.Compression == GzipCompressionProtocol compression, exists := c.Query(Compression)
if exists && compression == GzipCompressionProtocol {
return true
} else {
compression, exists := c.GetHeader(Compression)
if exists && compression == GzipCompressionProtocol {
return true
}
}
return false
} }
func (c *UserConnContext) GetSDKType() string { func (c *UserConnContext) GetSDKType() string {
if c == nil || c.info == nil { sdkType := c.Req.URL.Query().Get(SDKType)
return GoSDK if sdkType == "" {
} sdkType = GoSDK
switch c.info.SDKType {
case "", GoSDK:
return GoSDK
case JsSDK:
return JsSDK
default:
return ""
} }
return sdkType
} }
func (c *UserConnContext) ShouldSendResp() bool { func (c *UserConnContext) ShouldSendResp() bool {
return c != nil && c.info != nil && c.info.SendResponse errResp, exists := c.Query(SendResponse)
if exists {
b, err := strconv.ParseBool(errResp)
if err != nil {
return false
} else {
return b
}
}
return false
} }
func (c *UserConnContext) SetToken(token string) { func (c *UserConnContext) SetToken(token string) {
if c.info == nil { c.Req.URL.RawQuery = Token + "=" + token
c.info = &UserConnContextInfo{}
}
c.info.Token = token
} }
func (c *UserConnContext) GetBackground() bool { func (c *UserConnContext) GetBackground() bool {
return c != nil && c.info != nil && c.info.Background b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
if err != nil {
return false
}
return b
}
func (c *UserConnContext) ParseEssentialArgs() error {
_, exists := c.Query(Token)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
}
_, exists = c.Query(WsUserID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
}
platformIDStr, exists := c.Query(PlatformID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
}
_, err := strconv.Atoi(platformIDStr)
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
}
switch sdkType, _ := c.Query(SDKType); sdkType {
case "", GoSDK, JsSDK:
default:
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is not go or js")
}
return nil
} }
+179
View File
@@ -0,0 +1,179 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"encoding/json"
"net/http"
"time"
"github.com/openimsdk/tools/apiresp"
"github.com/gorilla/websocket"
"github.com/openimsdk/tools/errs"
)
type LongConn interface {
// Close this connection
Close() error
// WriteMessage Write message to connection,messageType means data type,can be set binary(2) and text(1).
WriteMessage(messageType int, message []byte) error
// ReadMessage Read message from connection.
ReadMessage() (int, []byte, error)
// SetReadDeadline sets the read deadline on the underlying network connection,
// after a read has timed out, will return an error.
SetReadDeadline(timeout time.Duration) error
// SetWriteDeadline sets to write deadline when send message,when read has timed out,will return error.
SetWriteDeadline(timeout time.Duration) error
// Dial Try to dial a connection,url must set auth args,header can control compress data
Dial(urlStr string, requestHeader http.Header) (*http.Response, error)
// IsNil Whether the connection of the current long connection is nil
IsNil() bool
// SetConnNil Set the connection of the current long connection to nil
SetConnNil()
// SetReadLimit sets the maximum size for a message read from the peer.bytes
SetReadLimit(limit int64)
SetPongHandler(handler PingPongHandler)
SetPingHandler(handler PingPongHandler)
// GenerateLongConn Check the connection of the current and when it was sent are the same
GenerateLongConn(w http.ResponseWriter, r *http.Request) error
}
type GWebSocket struct {
protocolType int
conn *websocket.Conn
handshakeTimeout time.Duration
writeBufferSize int
}
func newGWebSocket(protocolType int, handshakeTimeout time.Duration, wbs int) *GWebSocket {
return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, writeBufferSize: wbs}
}
func (d *GWebSocket) Close() error {
return d.conn.Close()
}
func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error {
upgrader := &websocket.Upgrader{
HandshakeTimeout: d.handshakeTimeout,
CheckOrigin: func(r *http.Request) bool { return true },
}
if d.writeBufferSize > 0 { // default is 4kb.
upgrader.WriteBufferSize = d.writeBufferSize
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
// The upgrader.Upgrade method usually returns enough error messages to diagnose problems that may occur during the upgrade
return errs.WrapMsg(err, "GenerateLongConn: WebSocket upgrade failed")
}
d.conn = conn
return nil
}
func (d *GWebSocket) WriteMessage(messageType int, message []byte) error {
// d.setSendConn(d.conn)
return d.conn.WriteMessage(messageType, message)
}
// func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) {
// d.sendConn = sendConn
//}
func (d *GWebSocket) ReadMessage() (int, []byte, error) {
return d.conn.ReadMessage()
}
func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error {
return d.conn.SetReadDeadline(time.Now().Add(timeout))
}
func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error {
if timeout <= 0 {
return errs.New("timeout must be greater than 0")
}
// TODO SetWriteDeadline Future add error handling
if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
return errs.WrapMsg(err, "GWebSocket.SetWriteDeadline failed")
}
return nil
}
func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) {
conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader)
if err != nil {
return httpResp, errs.WrapMsg(err, "GWebSocket.Dial failed", "url", urlStr)
}
d.conn = conn
return httpResp, nil
}
func (d *GWebSocket) IsNil() bool {
return d.conn == nil
//
// if d.conn != nil {
// return false
// }
// return true
}
func (d *GWebSocket) SetConnNil() {
d.conn = nil
}
func (d *GWebSocket) SetReadLimit(limit int64) {
d.conn.SetReadLimit(limit)
}
func (d *GWebSocket) SetPongHandler(handler PingPongHandler) {
d.conn.SetPongHandler(handler)
}
func (d *GWebSocket) SetPingHandler(handler PingPongHandler) {
d.conn.SetPingHandler(handler)
}
func (d *GWebSocket) RespondWithError(err error, w http.ResponseWriter, r *http.Request) error {
if err := d.GenerateLongConn(w, r); err != nil {
return err
}
data, err := json.Marshal(apiresp.ParseError(err))
if err != nil {
_ = d.Close()
return errs.WrapMsg(err, "json marshal failed")
}
if err := d.WriteMessage(MessageText, data); err != nil {
_ = d.Close()
return errs.WrapMsg(err, "WriteMessage failed")
}
_ = d.Close()
return nil
}
func (d *GWebSocket) RespondWithSuccess() error {
data, err := json.Marshal(apiresp.ParseError(nil))
if err != nil {
_ = d.Close()
return errs.WrapMsg(err, "json marshal failed")
}
if err := d.WriteMessage(MessageText, data); err != nil {
_ = d.Close()
return errs.WrapMsg(err, "WriteMessage failed")
}
return nil
}
+37 -56
View File
@@ -2,16 +2,13 @@ package msggateway
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/gorilla/websocket"
"github.com/openimsdk/open-im-server/v3/pkg/rpcli" "github.com/openimsdk/open-im-server/v3/pkg/rpcli"
"github.com/openimsdk/tools/apiresp"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
@@ -29,8 +26,6 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
var wsSuccessResponse, _ = json.Marshal(&apiresp.ApiResponse{})
type LongConnServer interface { type LongConnServer interface {
Run(done chan error) error Run(done chan error) error
wsHandler(w http.ResponseWriter, r *http.Request) wsHandler(w http.ResponseWriter, r *http.Request)
@@ -47,7 +42,6 @@ type LongConnServer interface {
} }
type WsServer struct { type WsServer struct {
websocket *websocket.Upgrader
msgGatewayConfig *Config msgGatewayConfig *Config
port int port int
wsMaxConnNum int64 wsMaxConnNum int64
@@ -137,13 +131,9 @@ func NewWsServer(msgGatewayConfig *Config, opts ...Option) *WsServer {
o(&config) o(&config)
} }
//userRpcClient := rpcclient.NewUserRpcClient(client, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID) //userRpcClient := rpcclient.NewUserRpcClient(client, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID)
upgrader := &websocket.Upgrader{
HandshakeTimeout: config.handshakeTimeout,
CheckOrigin: func(r *http.Request) bool { return true },
}
v := validator.New() v := validator.New()
return &WsServer{ return &WsServer{
websocket: upgrader,
msgGatewayConfig: msgGatewayConfig, msgGatewayConfig: msgGatewayConfig,
port: config.port, port: config.port,
wsMaxConnNum: config.maxConnNum, wsMaxConnNum: config.maxConnNum,
@@ -449,39 +439,16 @@ func (ws *WsServer) unregisterClient(client *Client) {
// validateRespWithRequest checks if the response matches the expected userID and platformID. // validateRespWithRequest checks if the response matches the expected userID and platformID.
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error { func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
userID := ctx.GetUserID() userID := ctx.GetUserID()
platformID := ctx.GetPlatformID() platformID := stringutil.StringToInt32(ctx.GetPlatformID())
if resp.UserID != userID { if resp.UserID != userID {
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID)) return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
} }
if int(resp.PlatformID) != platformID { if resp.PlatformID != platformID {
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID)) return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID))
} }
return nil return nil
} }
func (ws *WsServer) handlerError(ctx *UserConnContext, w http.ResponseWriter, r *http.Request, err error) {
if !ctx.ShouldSendResp() {
httpError(ctx, err)
return
}
// the browser cannot get the response of upgrade failure
data, err := json.Marshal(apiresp.ParseError(err))
if err != nil {
log.ZError(ctx, "json marshal failed", err)
return
}
conn, upgradeErr := ws.websocket.Upgrade(w, r, nil)
if upgradeErr != nil {
log.ZWarn(ctx, "websocket upgrade failed", upgradeErr, "respErr", err, "resp", string(data))
return
}
defer conn.Close()
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
log.ZWarn(ctx, "WriteMessage failed", err, "respErr", err, "resp", string(data))
return
}
}
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
// Create a new connection context // Create a new connection context
connContext := newContext(w, r) connContext := newContext(w, r)
@@ -489,7 +456,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
// Check if the current number of online user connections exceeds the maximum limit // Check if the current number of online user connections exceeds the maximum limit
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum { if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
// If it exceeds the maximum connection number, return an error via HTTP and stop processing // If it exceeds the maximum connection number, return an error via HTTP and stop processing
ws.handlerError(connContext, w, r, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit")) httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
return return
} }
@@ -497,14 +464,26 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
err := connContext.ParseEssentialArgs() err := connContext.ParseEssentialArgs()
if err != nil { if err != nil {
// If there's an error during parsing, return an error via HTTP and stop processing // If there's an error during parsing, return an error via HTTP and stop processing
ws.handlerError(connContext, w, r, err)
httpError(connContext, err)
return return
} }
// Call the authentication client to parse the Token obtained from the context // Call the authentication client to parse the Token obtained from the context
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken()) resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
if err != nil { if err != nil {
ws.handlerError(connContext, w, r, err) // If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag
shouldSendError := connContext.ShouldSendResp()
if shouldSendError {
// Create a WebSocket connection object and attempt to send the error message via WebSocket
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.RespondWithError(err, w, r); err == nil {
// If the error message is successfully sent via WebSocket, stop processing
return
}
}
// If sending via WebSocket is not required or fails, return the error via HTTP and stop processing
httpError(connContext, err)
return return
} }
@@ -512,30 +491,32 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
err = ws.validateRespWithRequest(connContext, resp) err = ws.validateRespWithRequest(connContext, resp)
if err != nil { if err != nil {
// If validation fails, return an error via HTTP and stop processing // If validation fails, return an error via HTTP and stop processing
ws.handlerError(connContext, w, r, err) httpError(connContext, err)
return return
} }
conn, err := ws.websocket.Upgrade(w, r, nil)
if err != nil {
log.ZWarn(connContext, "websocket upgrade failed", err)
return
}
if connContext.ShouldSendResp() {
if err := conn.WriteMessage(websocket.TextMessage, wsSuccessResponse); err != nil {
log.ZWarn(connContext, "WriteMessage first response", err)
return
}
}
log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
var pingInterval time.Duration log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
if connContext.GetPlatformID() == constant.WebPlatformID { // Create a WebSocket long connection object
pingInterval = pingPeriod wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
//If the creation of the long connection fails, the error is handled internally during the handshake process.
log.ZWarn(connContext, "long connection fails", err)
return
} else {
// Check if a normal response should be sent via WebSocket
shouldSendSuccessResp := connContext.ShouldSendResp()
if shouldSendSuccessResp {
// Attempt to send a success message through WebSocket
if err := wsLongConn.RespondWithSuccess(); err != nil {
// If the success message is successfully sent, end further processing
return
}
}
} }
// Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection // Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection
client := ws.clientPool.Get().(*Client) client := ws.clientPool.Get().(*Client)
client.ResetClient(connContext, NewWebSocketClientConn(conn, maxMessageSize, pongWait, pingInterval), ws) client.ResetClient(connContext, wsLongConn, ws)
// Register the client with the server and start message processing // Register the client with the server and start message processing
ws.registerChan <- client ws.registerChan <- client
+1 -1
View File
@@ -23,7 +23,7 @@ func (c *conversationServer) GetFullOwnerConversationIDs(ctx context.Context, re
conversationIDs = nil conversationIDs = nil
} }
return &conversation.GetFullOwnerConversationIDsResp{ return &conversation.GetFullOwnerConversationIDsResp{
Version: uint64(vl.Version), Version: idHash,
VersionID: vl.ID.Hex(), VersionID: vl.ID.Hex(),
Equal: req.IdHash == idHash, Equal: req.IdHash == idHash,
ConversationIDs: conversationIDs, ConversationIDs: conversationIDs,
File diff suppressed because it is too large Load Diff
+14 -16
View File
@@ -476,15 +476,14 @@ func (g *NotificationSender) GroupApplicationAcceptedNotification(ctx context.Co
if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil { if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil {
return return
} }
uid := g.uuid() tips := &sdkws.GroupApplicationAcceptedTips{
Group: group,
OpUser: opUser,
HandleMsg: req.HandledMsg,
Uuid: g.uuid(),
Request: request,
}
for _, userID := range append(userIDs, req.FromUserID) { for _, userID := range append(userIDs, req.FromUserID) {
tips := &sdkws.GroupApplicationAcceptedTips{
Group: group,
OpUser: opUser,
HandleMsg: req.HandledMsg,
Uuid: uid,
Request: request,
}
if userID == req.FromUserID { if userID == req.FromUserID {
tips.ReceiverAs = applicantReceiver tips.ReceiverAs = applicantReceiver
} else { } else {
@@ -521,15 +520,14 @@ func (g *NotificationSender) GroupApplicationRejectedNotification(ctx context.Co
if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil { if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil {
return return
} }
uid := g.uuid() tips := &sdkws.GroupApplicationRejectedTips{
Group: group,
OpUser: opUser,
HandleMsg: req.HandledMsg,
Uuid: g.uuid(),
Request: request,
}
for _, userID := range append(userIDs, req.FromUserID) { for _, userID := range append(userIDs, req.FromUserID) {
tips := &sdkws.GroupApplicationRejectedTips{
Group: group,
OpUser: opUser,
HandleMsg: req.HandledMsg,
Uuid: uid,
Request: request,
}
if userID == req.FromUserID { if userID == req.FromUserID {
tips.ReceiverAs = applicantReceiver tips.ReceiverAs = applicantReceiver
} else { } else {
+13 -22
View File
@@ -2,7 +2,6 @@ package group
import ( import (
"context" "context"
"errors"
"github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion"
"github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/authverify"
@@ -13,24 +12,23 @@ import (
pbgroup "github.com/openimsdk/protocol/group" pbgroup "github.com/openimsdk/protocol/group"
"github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/protocol/sdkws"
"github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext" "github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil" "github.com/openimsdk/tools/utils/datautil"
) )
const versionSyncLimit = 500 const versionSyncLimit = 500
func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) { func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) {
userIDs, err := s.db.FindGroupMemberUserID(ctx, req.GroupID) userIDs, err := g.db.FindGroupMemberUserID(ctx, req.GroupID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !authverify.IsAppManagerUid(ctx, s.config.Share.IMAdminUserID) { if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !datautil.Contain(mcontext.GetOpUserID(ctx), userIDs...) { if !datautil.Contain(mcontext.GetOpUserID(ctx), userIDs...) {
return nil, errs.ErrNoPermission.WrapMsg("op user not in group") return nil, errs.ErrNoPermission.WrapMsg("op user not in group")
} }
} }
vl, err := s.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -39,7 +37,7 @@ func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgrou
userIDs = nil userIDs = nil
} }
return &pbgroup.GetFullGroupMemberUserIDsResp{ return &pbgroup.GetFullGroupMemberUserIDsResp{
Version: uint64(vl.Version), Version: idHash,
VersionID: vl.ID.Hex(), VersionID: vl.ID.Hex(),
Equal: req.IdHash == idHash, Equal: req.IdHash == idHash,
UserIDs: userIDs, UserIDs: userIDs,
@@ -63,7 +61,7 @@ func (s *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetF
groupIDs = nil groupIDs = nil
} }
return &pbgroup.GetFullJoinGroupIDsResp{ return &pbgroup.GetFullJoinGroupIDsResp{
Version: uint64(vl.Version), Version: idHash,
VersionID: vl.ID.Hex(), VersionID: vl.ID.Hex(),
Equal: req.IdHash == idHash, Equal: req.IdHash == idHash,
GroupIDs: groupIDs, GroupIDs: groupIDs,
@@ -148,8 +146,8 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou
return resp, nil return resp, nil
} }
func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) { func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.UserID, g.config.Share.IMAdminUserID); err != nil {
return nil, err return nil, err
} }
opt := incrversion.Option[*sdkws.GroupInfo, pbgroup.GetIncrementalJoinGroupResp]{ opt := incrversion.Option[*sdkws.GroupInfo, pbgroup.GetIncrementalJoinGroupResp]{
@@ -157,9 +155,9 @@ func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.
VersionKey: req.UserID, VersionKey: req.UserID,
VersionID: req.VersionID, VersionID: req.VersionID,
VersionNumber: req.Version, VersionNumber: req.Version,
Version: s.db.FindJoinIncrVersion, Version: g.db.FindJoinIncrVersion,
CacheMaxVersion: s.db.FindMaxJoinGroupVersionCache, CacheMaxVersion: g.db.FindMaxJoinGroupVersionCache,
Find: s.getGroupsInfo, Find: g.getGroupsInfo,
Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*sdkws.GroupInfo, full bool) *pbgroup.GetIncrementalJoinGroupResp { Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*sdkws.GroupInfo, full bool) *pbgroup.GetIncrementalJoinGroupResp {
return &pbgroup.GetIncrementalJoinGroupResp{ return &pbgroup.GetIncrementalJoinGroupResp{
VersionID: version.ID.Hex(), VersionID: version.ID.Hex(),
@@ -174,29 +172,22 @@ func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.
return opt.Build() return opt.Build()
} }
func (s *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (*pbgroup.BatchGetIncrementalGroupMemberResp, error) { func (g *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (*pbgroup.BatchGetIncrementalGroupMemberResp, error) {
var num int var num int
resp := make(map[string]*pbgroup.GetIncrementalGroupMemberResp) resp := make(map[string]*pbgroup.GetIncrementalGroupMemberResp)
for _, memberReq := range req.ReqList { for _, memberReq := range req.ReqList {
if _, ok := resp[memberReq.GroupID]; ok { if _, ok := resp[memberReq.GroupID]; ok {
continue continue
} }
memberResp, err := s.GetIncrementalGroupMember(ctx, memberReq) memberResp, err := g.GetIncrementalGroupMember(ctx, memberReq)
if err != nil { if err != nil {
if errors.Is(err, servererrs.ErrDismissedAlready) {
log.ZWarn(ctx, "Failed to get incremental group member", err, "groupID", memberReq.GroupID, "request", memberReq)
continue
}
return nil, err return nil, err
} }
resp[memberReq.GroupID] = memberResp resp[memberReq.GroupID] = memberResp
num += len(memberResp.Insert) + len(memberResp.Update) + len(memberResp.Delete) num += len(memberResp.Insert) + len(memberResp.Update) + len(memberResp.Delete)
if num >= versionSyncLimit { if num >= versionSyncLimit {
break break
} }
} }
return &pbgroup.BatchGetIncrementalGroupMemberResp{RespList: resp}, nil return &pbgroup.BatchGetIncrementalGroupMemberResp{RespList: resp}, nil
} }
+1 -1
View File
@@ -56,7 +56,7 @@ func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.G
userIDs = nil userIDs = nil
} }
return &relation.GetFullFriendUserIDsResp{ return &relation.GetFullFriendUserIDsResp{
Version: uint64(vl.Version), Version: idHash,
VersionID: vl.ID.Hex(), VersionID: vl.ID.Hex(),
Equal: req.IdHash == idHash, Equal: req.IdHash == idHash,
UserIDs: userIDs, UserIDs: userIDs,
+14 -14
View File
@@ -354,13 +354,13 @@ type User struct {
} }
type Redis struct { type Redis struct {
Address []string `mapstructure:"address"` Address []string `mapstructure:"address"`
Username string `mapstructure:"username"` Username string `mapstructure:"username"`
Password string `mapstructure:"password"` Password string `mapstructure:"password"`
RedisMode string `mapstructure:"redisMode"` ClusterMode bool `mapstructure:"clusterMode"`
DB int `mapstructure:"db"` DB int `mapstructure:"db"`
MaxRetry int `mapstructure:"maxRetry"` MaxRetry int `mapstructure:"maxRetry"`
PoolSize int `mapstructure:"poolSize"` PoolSize int `mapstructure:"poolSize"`
} }
type BeforeConfig struct { type BeforeConfig struct {
@@ -511,13 +511,13 @@ func (m *Mongo) Build() *mongoutil.Config {
func (r *Redis) Build() *redisutil.Config { func (r *Redis) Build() *redisutil.Config {
return &redisutil.Config{ return &redisutil.Config{
RedisMode: r.RedisMode, ClusterMode: r.ClusterMode,
Address: r.Address, Address: r.Address,
Username: r.Username, Username: r.Username,
Password: r.Password, Password: r.Password,
DB: r.DB, DB: r.DB,
MaxRetry: r.MaxRetry, MaxRetry: r.MaxRetry,
PoolSize: r.PoolSize, PoolSize: r.PoolSize,
} }
} }
+7 -14
View File
@@ -49,7 +49,7 @@ func New[V any](opts ...Option) Cache[V] {
if opt.expirationEvict { if opt.expirationEvict {
return lru.NewExpirationLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict) return lru.NewExpirationLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
} else { } else {
return lru.NewLazyLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict) return lru.NewLayLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
} }
} }
if opt.localSlotNum == 1 { if opt.localSlotNum == 1 {
@@ -72,18 +72,11 @@ type cache[V any] struct {
func (c *cache[V]) onEvict(key string, value V) { func (c *cache[V]) onEvict(key string, value V) {
if c.link != nil { if c.link != nil {
// Do not delete other keys while the underlying LRU still holds its lock; lks := c.link.Del(key)
// defer linked deletions to avoid re-entering the same slot and deadlocking. for k := range lks {
if lks := c.link.Del(key); len(lks) > 0 { if key != k { // prevent deadlock
go c.delLinked(key, lks) c.local.Del(k)
} }
}
}
func (c *cache[V]) delLinked(src string, keys map[string]struct{}) {
for k := range keys {
if src != k {
c.local.Del(k)
} }
} }
} }
@@ -110,7 +103,7 @@ func (c *cache[V]) Get(ctx context.Context, key string, fetch func(ctx context.C
func (c *cache[V]) GetLink(ctx context.Context, key string, fetch func(ctx context.Context) (V, error), link ...string) (V, error) { func (c *cache[V]) GetLink(ctx context.Context, key string, fetch func(ctx context.Context) (V, error), link ...string) (V, error) {
if c.local != nil { if c.local != nil {
return c.local.Get(key, func() (V, error) { return c.local.Get(key, func() (V, error) {
if len(link) > 0 && c.link != nil { if len(link) > 0 {
c.link.Link(key, link...) c.link.Link(key, link...)
} }
return fetch(ctx) return fetch(ctx)
-67
View File
@@ -22,8 +22,6 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"github.com/openimsdk/open-im-server/v3/pkg/localcache/lru"
) )
func TestName(t *testing.T) { func TestName(t *testing.T) {
@@ -93,68 +91,3 @@ func TestName(t *testing.T) {
t.Log("del", del.Load()) t.Log("del", del.Load())
// 137.35s // 137.35s
} }
// Test deadlock scenario when eviction callback deletes a linked key that hashes to the same slot.
func TestCacheEvictDeadlock(t *testing.T) {
ctx := context.Background()
c := New[string](WithLocalSlotNum(1), WithLocalSlotSize(1), WithLazy())
if _, err := c.GetLink(ctx, "k1", func(ctx context.Context) (string, error) {
return "v1", nil
}, "k2"); err != nil {
t.Fatalf("seed cache failed: %v", err)
}
done := make(chan struct{})
go func() {
defer close(done)
_, _ = c.GetLink(ctx, "k2", func(ctx context.Context) (string, error) {
return "v2", nil
}, "k1")
}()
select {
case <-done:
// expected to finish quickly; current implementation deadlocks here.
case <-time.After(time.Second):
t.Fatal("GetLink deadlocked during eviction of linked key")
}
}
func TestExpirationLRUGetBatch(t *testing.T) {
l := lru.NewExpirationLRU[string, string](2, time.Minute, time.Second*5, EmptyTarget{}, nil)
keys := []string{"a", "b"}
values, err := l.GetBatch(keys, func(keys []string) (map[string]string, error) {
res := make(map[string]string)
for _, k := range keys {
res[k] = k + "_v"
}
return res, nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(values) != len(keys) {
t.Fatalf("expected %d values, got %d", len(keys), len(values))
}
for _, k := range keys {
if v, ok := values[k]; !ok || v != k+"_v" {
t.Fatalf("unexpected value for %s: %q, ok=%v", k, v, ok)
}
}
// second batch should hit cache
values, err = l.GetBatch(keys, func(keys []string) (map[string]string, error) {
t.Fatalf("should not fetch on cache hit")
return nil, nil
})
if err != nil {
t.Fatalf("unexpected error on cache hit: %v", err)
}
for _, k := range keys {
if v, ok := values[k]; !ok || v != k+"_v" {
t.Fatalf("unexpected cached value for %s: %q, ok=%v", k, v, ok)
}
}
}
+2 -47
View File
@@ -52,53 +52,8 @@ type ExpirationLRU[K comparable, V any] struct {
} }
func (x *ExpirationLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) { func (x *ExpirationLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
var ( //TODO implement me
err error panic("implement me")
results = make(map[K]V)
misses = make([]K, 0, len(keys))
)
for _, key := range keys {
x.lock.Lock()
v, ok := x.core.Get(key)
x.lock.Unlock()
if ok {
x.target.IncrGetHit()
v.lock.RLock()
results[key] = v.value
if v.err != nil && err == nil {
err = v.err
}
v.lock.RUnlock()
continue
}
misses = append(misses, key)
}
if len(misses) == 0 {
return results, err
}
fetchValues, fetchErr := fetch(misses)
if fetchErr != nil && err == nil {
err = fetchErr
}
for key, val := range fetchValues {
results[key] = val
if fetchErr != nil {
x.target.IncrGetFailed()
continue
}
x.target.IncrGetSuccess()
item := &expirationLruItem[V]{value: val}
x.lock.Lock()
x.core.Add(key, item)
x.lock.Unlock()
}
// any keys not returned from fetch remain absent (no cache write)
return results, err
} }
func (x *ExpirationLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) { func (x *ExpirationLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
+25 -31
View File
@@ -21,25 +21,25 @@ import (
"github.com/hashicorp/golang-lru/v2/simplelru" "github.com/hashicorp/golang-lru/v2/simplelru"
) )
type lazyLruItem[V any] struct { type layLruItem[V any] struct {
lock sync.Mutex lock sync.Mutex
expires int64 expires int64
err error err error
value V value V
} }
func NewLazyLRU[K comparable, V any](size int, successTTL, failedTTL time.Duration, target Target, onEvict EvictCallback[K, V]) *LazyLRU[K, V] { func NewLayLRU[K comparable, V any](size int, successTTL, failedTTL time.Duration, target Target, onEvict EvictCallback[K, V]) *LayLRU[K, V] {
var cb simplelru.EvictCallback[K, *lazyLruItem[V]] var cb simplelru.EvictCallback[K, *layLruItem[V]]
if onEvict != nil { if onEvict != nil {
cb = func(key K, value *lazyLruItem[V]) { cb = func(key K, value *layLruItem[V]) {
onEvict(key, value.value) onEvict(key, value.value)
} }
} }
core, err := simplelru.NewLRU[K, *lazyLruItem[V]](size, cb) core, err := simplelru.NewLRU[K, *layLruItem[V]](size, cb)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return &LazyLRU[K, V]{ return &LayLRU[K, V]{
core: core, core: core,
successTTL: successTTL, successTTL: successTTL,
failedTTL: failedTTL, failedTTL: failedTTL,
@@ -47,15 +47,15 @@ func NewLazyLRU[K comparable, V any](size int, successTTL, failedTTL time.Durati
} }
} }
type LazyLRU[K comparable, V any] struct { type LayLRU[K comparable, V any] struct {
lock sync.Mutex lock sync.Mutex
core *simplelru.LRU[K, *lazyLruItem[V]] core *simplelru.LRU[K, *layLruItem[V]]
successTTL time.Duration successTTL time.Duration
failedTTL time.Duration failedTTL time.Duration
target Target target Target
} }
func (x *LazyLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) { func (x *LayLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
x.lock.Lock() x.lock.Lock()
v, ok := x.core.Get(key) v, ok := x.core.Get(key)
if ok { if ok {
@@ -68,7 +68,7 @@ func (x *LazyLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
return value, err return value, err
} }
} else { } else {
v = &lazyLruItem[V]{} v = &layLruItem[V]{}
x.core.Add(key, v) x.core.Add(key, v)
v.lock.Lock() v.lock.Lock()
x.lock.Unlock() x.lock.Unlock()
@@ -88,15 +88,15 @@ func (x *LazyLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
return v.value, v.err return v.value, v.err
} }
func (x *LazyLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) { func (x *LayLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
var ( var (
err error err error
once sync.Once once sync.Once
) )
res := make(map[K]V) res := make(map[K]V)
queries := make([]K, 0, len(keys)) queries := make([]K, 0)
setVs := make(map[K]*layLruItem[V])
for _, key := range keys { for _, key := range keys {
x.lock.Lock() x.lock.Lock()
v, ok := x.core.Get(key) v, ok := x.core.Get(key)
@@ -118,20 +118,14 @@ func (x *LazyLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)
} }
queries = append(queries, key) queries = append(queries, key)
} }
values, err1 := fetch(queries)
if len(queries) == 0 { if err1 != nil {
return res, err
}
values, fetchErr := fetch(queries)
if fetchErr != nil {
once.Do(func() { once.Do(func() {
err = fetchErr err = err1
}) })
} }
for key, val := range values { for key, val := range values {
v := &lazyLruItem[V]{} v := &layLruItem[V]{}
v.value = val v.value = val
if err == nil { if err == nil {
@@ -141,7 +135,7 @@ func (x *LazyLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)
v.expires = time.Now().Add(x.failedTTL).UnixMilli() v.expires = time.Now().Add(x.failedTTL).UnixMilli()
x.target.IncrGetFailed() x.target.IncrGetFailed()
} }
setVs[key] = v
x.lock.Lock() x.lock.Lock()
x.core.Add(key, v) x.core.Add(key, v)
x.lock.Unlock() x.lock.Unlock()
@@ -151,29 +145,29 @@ func (x *LazyLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)
return res, err return res, err
} }
//func (x *LazyLRU[K, V]) Has(key K) bool { //func (x *LayLRU[K, V]) Has(key K) bool {
// x.lock.Lock() // x.lock.Lock()
// defer x.lock.Unlock() // defer x.lock.Unlock()
// return x.core.Contains(key) // return x.core.Contains(key)
//} //}
func (x *LazyLRU[K, V]) Set(key K, value V) { func (x *LayLRU[K, V]) Set(key K, value V) {
x.lock.Lock() x.lock.Lock()
defer x.lock.Unlock() defer x.lock.Unlock()
x.core.Add(key, &lazyLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()}) x.core.Add(key, &layLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
} }
func (x *LazyLRU[K, V]) SetHas(key K, value V) bool { func (x *LayLRU[K, V]) SetHas(key K, value V) bool {
x.lock.Lock() x.lock.Lock()
defer x.lock.Unlock() defer x.lock.Unlock()
if x.core.Contains(key) { if x.core.Contains(key) {
x.core.Add(key, &lazyLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()}) x.core.Add(key, &layLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
return true return true
} }
return false return false
} }
func (x *LazyLRU[K, V]) Del(key K) bool { func (x *LayLRU[K, V]) Del(key K) bool {
x.lock.Lock() x.lock.Lock()
ok := x.core.Remove(key) ok := x.core.Remove(key)
x.lock.Unlock() x.lock.Unlock()
@@ -185,6 +179,6 @@ func (x *LazyLRU[K, V]) Del(key K) bool {
return ok return ok
} }
func (x *LazyLRU[K, V]) Stop() { func (x *LayLRU[K, V]) Stop() {
} }
+4 -5
View File
@@ -3,16 +3,15 @@ package rpccache
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/user"
"math/rand" "math/rand"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/user"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/open-im-server/v3/pkg/localcache" "github.com/openimsdk/open-im-server/v3/pkg/localcache"
"github.com/openimsdk/open-im-server/v3/pkg/localcache/lru" "github.com/openimsdk/open-im-server/v3/pkg/localcache/lru"
@@ -47,7 +46,7 @@ func NewOnlineCache(client *rpcli.UserClient, group *GroupLocalCache, rdb redis.
case false: case false:
log.ZDebug(ctx, "fullUserCache is false") log.ZDebug(ctx, "fullUserCache is false")
x.lruCache = lru.NewSlotLRU(1024, localcache.LRUStringHash, func() lru.LRU[string, []int32] { x.lruCache = lru.NewSlotLRU(1024, localcache.LRUStringHash, func() lru.LRU[string, []int32] {
return lru.NewLazyLRU[string, []int32](2048, cachekey.OnlineExpire/2, time.Second*3, localcache.EmptyTarget{}, func(key string, value []int32) {}) return lru.NewLayLRU[string, []int32](2048, cachekey.OnlineExpire/2, time.Second*3, localcache.EmptyTarget{}, func(key string, value []int32) {})
}) })
x.CurrentPhase.Store(DoSubscribeOver) x.CurrentPhase.Store(DoSubscribeOver)
x.Cond.Broadcast() x.Cond.Broadcast()
+1 -1
View File
@@ -1 +1 @@
v3.8.3-patch.15 v3.8.3-patch.11