mirror of
https://github.com/openimsdk/open-im-server.git
synced 2026-05-02 16:15:59 +08:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4f865f83c1 |
@@ -92,13 +92,12 @@ jobs:
|
|||||||
contents: write
|
contents: write
|
||||||
env:
|
env:
|
||||||
SDK_DIR: openim-sdk-core
|
SDK_DIR: openim-sdk-core
|
||||||
NOTIFICATION_CONFIG_PATH: config/notification.yml
|
CONFIG_PATH: config/notification.yml
|
||||||
SHARE_CONFIG_PATH: config/share.yml
|
# pull-requests: write
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ ubuntu-latest ]
|
||||||
go_version: ["1.22.x"]
|
go_version: [ "1.22.x" ]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Server repository
|
- name: Checkout Server repository
|
||||||
@@ -107,8 +106,7 @@ jobs:
|
|||||||
- name: Checkout SDK repository
|
- name: Checkout SDK repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
repository: "openimsdk/openim-sdk-core"
|
repository: 'openimsdk/openim-sdk-core'
|
||||||
ref: "main"
|
|
||||||
path: ${{ env.SDK_DIR }}
|
path: ${{ env.SDK_DIR }}
|
||||||
|
|
||||||
- name: Set up Go ${{ matrix.go_version }}
|
- name: Set up Go ${{ matrix.go_version }}
|
||||||
@@ -121,11 +119,15 @@ jobs:
|
|||||||
go install github.com/magefile/mage@latest
|
go install github.com/magefile/mage@latest
|
||||||
go mod download
|
go mod download
|
||||||
|
|
||||||
|
- name: Install yq
|
||||||
|
run: |
|
||||||
|
sudo wget https://github.com/mikefarah/yq/releases/download/v4.34.1/yq_linux_amd64 -O /usr/bin/yq
|
||||||
|
sudo chmod +x /usr/bin/yq
|
||||||
|
|
||||||
- name: Modify Server Configuration
|
- name: Modify Server Configuration
|
||||||
run: |
|
run: |
|
||||||
yq e '.groupCreated.isSendMsg = true' -i ${{ env.NOTIFICATION_CONFIG_PATH }}
|
yq e '.groupCreated.unreadCount = true' -i ${{ env.CONFIG_PATH }}
|
||||||
yq e '.friendApplicationApproved.isSendMsg = true' -i ${{ env.NOTIFICATION_CONFIG_PATH }}
|
yq e '.friendApplicationApproved.unreadCount = true' -i ${{ env.CONFIG_PATH }}
|
||||||
yq e '.secret = 123456' -i ${{ env.SHARE_CONFIG_PATH }}
|
|
||||||
|
|
||||||
- name: Start Server Services
|
- name: Start Server Services
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ package msggateway
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -63,7 +64,7 @@ type PingPongHandler func(string) error
|
|||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
w *sync.Mutex
|
w *sync.Mutex
|
||||||
conn ClientConn
|
conn LongConn
|
||||||
PlatformID int `json:"platformID"`
|
PlatformID int `json:"platformID"`
|
||||||
IsCompress bool `json:"isCompress"`
|
IsCompress bool `json:"isCompress"`
|
||||||
UserID string `json:"userID"`
|
UserID string `json:"userID"`
|
||||||
@@ -82,7 +83,7 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResetClient updates the client's state with new connection and context information.
|
// ResetClient updates the client's state with new connection and context information.
|
||||||
func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServer LongConnServer) {
|
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) {
|
||||||
c.w = new(sync.Mutex)
|
c.w = new(sync.Mutex)
|
||||||
c.conn = conn
|
c.conn = conn
|
||||||
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
|
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
|
||||||
@@ -109,6 +110,22 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServ
|
|||||||
c.subUserIDs = make(map[string]struct{})
|
c.subUserIDs = make(map[string]struct{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) pingHandler(appData string) error {
|
||||||
|
if err := c.conn.SetReadDeadline(pongWait); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.ZDebug(c.ctx, "ping Handler Success.", "appData", appData)
|
||||||
|
return c.writePongMsg(appData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) pongHandler(_ string) error {
|
||||||
|
if err := c.conn.SetReadDeadline(pongWait); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// readMessage continuously reads messages from the connection.
|
// readMessage continuously reads messages from the connection.
|
||||||
func (c *Client) readMessage() {
|
func (c *Client) readMessage() {
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -119,25 +136,52 @@ func (c *Client) readMessage() {
|
|||||||
c.close()
|
c.close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
c.conn.SetReadLimit(maxMessageSize)
|
||||||
|
_ = c.conn.SetReadDeadline(pongWait)
|
||||||
|
c.conn.SetPongHandler(c.pongHandler)
|
||||||
|
c.conn.SetPingHandler(c.pingHandler)
|
||||||
|
c.activeHeartbeat(c.hbCtx)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
log.ZDebug(c.ctx, "readMessage")
|
log.ZDebug(c.ctx, "readMessage")
|
||||||
message, returnErr := c.conn.ReadMessage()
|
messageType, message, returnErr := c.conn.ReadMessage()
|
||||||
if returnErr != nil {
|
if returnErr != nil {
|
||||||
log.ZWarn(c.ctx, "readMessage", returnErr)
|
log.ZWarn(c.ctx, "readMessage", returnErr, "messageType", messageType)
|
||||||
c.closedErr = returnErr
|
c.closedErr = returnErr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.ZDebug(c.ctx, "readMessage", "messageType", messageType)
|
||||||
if c.closed.Load() {
|
if c.closed.Load() {
|
||||||
// The scenario where the connection has just been closed, but the coroutine has not exited
|
// The scenario where the connection has just been closed, but the coroutine has not exited
|
||||||
c.closedErr = ErrConnClosed
|
c.closedErr = ErrConnClosed
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
parseDataErr := c.handleMessage(message)
|
switch messageType {
|
||||||
if parseDataErr != nil {
|
case MessageBinary:
|
||||||
c.closedErr = parseDataErr
|
_ = c.conn.SetReadDeadline(pongWait)
|
||||||
|
parseDataErr := c.handleMessage(message)
|
||||||
|
if parseDataErr != nil {
|
||||||
|
c.closedErr = parseDataErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case MessageText:
|
||||||
|
_ = c.conn.SetReadDeadline(pongWait)
|
||||||
|
parseDataErr := c.handlerTextMessage(message)
|
||||||
|
if parseDataErr != nil {
|
||||||
|
c.closedErr = parseDataErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case PingMessage:
|
||||||
|
err := c.writePongMsg("")
|
||||||
|
log.ZError(c.ctx, "writePongMsg", err)
|
||||||
|
|
||||||
|
case CloseMessage:
|
||||||
|
c.closedErr = ErrClientClosed
|
||||||
return
|
return
|
||||||
|
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -312,13 +356,109 @@ func (c *Client) writeBinaryMsg(resp Resp) error {
|
|||||||
c.w.Lock()
|
c.w.Lock()
|
||||||
defer c.w.Unlock()
|
defer c.w.Unlock()
|
||||||
|
|
||||||
|
err = c.conn.SetWriteDeadline(writeWait)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if c.IsCompress {
|
if c.IsCompress {
|
||||||
resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf)
|
resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf)
|
||||||
if compressErr != nil {
|
if compressErr != nil {
|
||||||
return compressErr
|
return compressErr
|
||||||
}
|
}
|
||||||
return c.conn.WriteMessage(resultBuf)
|
return c.conn.WriteMessage(MessageBinary, resultBuf)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.conn.WriteMessage(encodedBuf)
|
return c.conn.WriteMessage(MessageBinary, encodedBuf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Actively initiate Heartbeat when platform in Web.
|
||||||
|
func (c *Client) activeHeartbeat(ctx context.Context) {
|
||||||
|
if c.PlatformID == constant.WebPlatformID {
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.ZPanic(ctx, "activeHeartbeat Panic", errs.ErrPanic(r))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
log.ZDebug(ctx, "server initiative send heartbeat start.")
|
||||||
|
ticker := time.NewTicker(pingPeriod)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := c.writePingMsg(); err != nil {
|
||||||
|
log.ZWarn(c.ctx, "send Ping Message error.", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-c.hbCtx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (c *Client) writePingMsg() error {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.w.Lock()
|
||||||
|
defer c.w.Unlock()
|
||||||
|
|
||||||
|
err := c.conn.SetWriteDeadline(writeWait)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.conn.WriteMessage(PingMessage, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) writePongMsg(appData string) error {
|
||||||
|
log.ZDebug(c.ctx, "write Pong Msg in Server", "appData", appData)
|
||||||
|
if c.closed.Load() {
|
||||||
|
log.ZWarn(c.ctx, "is closed in server", nil, "appdata", appData, "closed err", c.closedErr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.w.Lock()
|
||||||
|
defer c.w.Unlock()
|
||||||
|
|
||||||
|
err := c.conn.SetWriteDeadline(writeWait)
|
||||||
|
if err != nil {
|
||||||
|
log.ZWarn(c.ctx, "SetWriteDeadline in Server have error", errs.Wrap(err), "writeWait", writeWait, "appData", appData)
|
||||||
|
return errs.Wrap(err)
|
||||||
|
}
|
||||||
|
err = c.conn.WriteMessage(PongMessage, []byte(appData))
|
||||||
|
if err != nil {
|
||||||
|
log.ZWarn(c.ctx, "Write Message have error", errs.Wrap(err), "Pong msg", PongMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs.Wrap(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handlerTextMessage(b []byte) error {
|
||||||
|
var msg TextMessage
|
||||||
|
if err := json.Unmarshal(b, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch msg.Type {
|
||||||
|
case TextPong:
|
||||||
|
return nil
|
||||||
|
case TextPing:
|
||||||
|
msg.Type = TextPong
|
||||||
|
msgData, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.w.Lock()
|
||||||
|
defer c.w.Unlock()
|
||||||
|
if err := c.conn.SetWriteDeadline(writeWait); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.conn.WriteMessage(MessageText, msgData)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("not support message type %s", msg.Type)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,212 +0,0 @@
|
|||||||
package msggateway
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
|
|
||||||
"github.com/openimsdk/tools/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrWriteFull = fmt.Errorf("websocket write buffer full,close connection")
|
|
||||||
|
|
||||||
type ClientConn interface {
|
|
||||||
ReadMessage() ([]byte, error)
|
|
||||||
WriteMessage(message []byte) error
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
type websocketMessage struct {
|
|
||||||
MessageType int
|
|
||||||
Data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWebSocketClientConn(conn *websocket.Conn, readLimit int64, readTimeout time.Duration, pingInterval time.Duration) ClientConn {
|
|
||||||
c := &websocketClientConn{
|
|
||||||
readTimeout: readTimeout,
|
|
||||||
conn: conn,
|
|
||||||
writer: make(chan *websocketMessage, 256),
|
|
||||||
done: make(chan struct{}),
|
|
||||||
}
|
|
||||||
if readLimit > 0 {
|
|
||||||
c.conn.SetReadLimit(readLimit)
|
|
||||||
}
|
|
||||||
c.conn.SetPingHandler(c.pingHandler)
|
|
||||||
c.conn.SetPongHandler(c.pongHandler)
|
|
||||||
|
|
||||||
go c.loopSend()
|
|
||||||
if pingInterval > 0 {
|
|
||||||
go c.doPing(pingInterval)
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
type websocketClientConn struct {
|
|
||||||
readTimeout time.Duration
|
|
||||||
conn *websocket.Conn
|
|
||||||
writer chan *websocketMessage
|
|
||||||
done chan struct{}
|
|
||||||
err atomic.Pointer[error]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) ReadMessage() ([]byte, error) {
|
|
||||||
buf, err := c.readMessage()
|
|
||||||
if err != nil {
|
|
||||||
return nil, c.closeBy(fmt.Errorf("read message %w", err))
|
|
||||||
}
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) WriteMessage(message []byte) error {
|
|
||||||
return c.writeMessage(websocket.BinaryMessage, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) Close() error {
|
|
||||||
_ = c.closeBy(fmt.Errorf("websocket connection closed"))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) closeBy(err error) error {
|
|
||||||
if !c.err.CompareAndSwap(nil, &err) {
|
|
||||||
return *c.err.Load()
|
|
||||||
}
|
|
||||||
close(c.done)
|
|
||||||
log.ZWarn(context.Background(), "websocket connection closed", err, "remoteAddr", c.conn.RemoteAddr(),
|
|
||||||
"chan length", len(c.writer))
|
|
||||||
_ = c.conn.Close()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) writeMessage(messageType int, data []byte) error {
|
|
||||||
if errPtr := c.err.Load(); errPtr != nil {
|
|
||||||
return *errPtr
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case c.writer <- &websocketMessage{MessageType: messageType, Data: data}:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return c.closeBy(ErrWriteFull)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) loopSend() {
|
|
||||||
var err error
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
return
|
|
||||||
case msg := <-c.writer:
|
|
||||||
switch msg.MessageType {
|
|
||||||
case websocket.TextMessage, websocket.BinaryMessage:
|
|
||||||
err = c.conn.WriteMessage(msg.MessageType, msg.Data)
|
|
||||||
default:
|
|
||||||
err = c.conn.WriteControl(msg.MessageType, msg.Data, time.Time{})
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
_ = c.closeBy(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) setReadDeadline() error {
|
|
||||||
deadline := time.Now().Add(c.readTimeout)
|
|
||||||
return c.conn.SetReadDeadline(deadline)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) readMessage() ([]byte, error) {
|
|
||||||
for {
|
|
||||||
if err := c.setReadDeadline(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
messageType, buf, err := c.conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
switch messageType {
|
|
||||||
case websocket.BinaryMessage:
|
|
||||||
return buf, nil
|
|
||||||
case websocket.TextMessage:
|
|
||||||
if err := c.onReadTextMessage(buf); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
case websocket.PingMessage:
|
|
||||||
if err := c.pingHandler(string(buf)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
case websocket.PongMessage:
|
|
||||||
if err := c.pongHandler(string(buf)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
case websocket.CloseMessage:
|
|
||||||
if len(buf) == 0 {
|
|
||||||
return nil, errors.New("websocket connection closed by peer")
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("websocket connection closed by peer, data %s", string(buf))
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown websocket message type %d", messageType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) onReadTextMessage(buf []byte) error {
|
|
||||||
var msg struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Body json.RawMessage `json:"body"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(buf, &msg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
switch msg.Type {
|
|
||||||
case TextPong:
|
|
||||||
return nil
|
|
||||||
case TextPing:
|
|
||||||
msg.Type = TextPong
|
|
||||||
msgData, err := json.Marshal(msg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.writeMessage(websocket.TextMessage, msgData)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("not support text message type %s", msg.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) pingHandler(appData string) error {
|
|
||||||
log.ZWarn(context.Background(), "ping handler recv ping", nil, "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
|
|
||||||
if err := c.setReadDeadline(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err := c.conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second*1))
|
|
||||||
if err != nil {
|
|
||||||
log.ZWarn(context.Background(), "ping handler write pong error", err, "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
|
|
||||||
}
|
|
||||||
log.ZWarn(context.Background(), "ping handler write pong success", nil, "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) pongHandler(string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *websocketClientConn) doPing(d time.Duration) {
|
|
||||||
ticker := time.NewTicker(d)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.done:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := c.writeMessage(websocket.PingMessage, nil); err != nil {
|
|
||||||
_ = c.closeBy(fmt.Errorf("send ping %w", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
// Copyright © 2023 OpenIM. All rights reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package msggateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/openimsdk/tools/apiresp"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/openimsdk/tools/errs"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LongConn interface {
|
||||||
|
// Close this connection
|
||||||
|
Close() error
|
||||||
|
// WriteMessage Write message to connection,messageType means data type,can be set binary(2) and text(1).
|
||||||
|
WriteMessage(messageType int, message []byte) error
|
||||||
|
// ReadMessage Read message from connection.
|
||||||
|
ReadMessage() (int, []byte, error)
|
||||||
|
// SetReadDeadline sets the read deadline on the underlying network connection,
|
||||||
|
// after a read has timed out, will return an error.
|
||||||
|
SetReadDeadline(timeout time.Duration) error
|
||||||
|
// SetWriteDeadline sets to write deadline when send message,when read has timed out,will return error.
|
||||||
|
SetWriteDeadline(timeout time.Duration) error
|
||||||
|
// Dial Try to dial a connection,url must set auth args,header can control compress data
|
||||||
|
Dial(urlStr string, requestHeader http.Header) (*http.Response, error)
|
||||||
|
// IsNil Whether the connection of the current long connection is nil
|
||||||
|
IsNil() bool
|
||||||
|
// SetConnNil Set the connection of the current long connection to nil
|
||||||
|
SetConnNil()
|
||||||
|
// SetReadLimit sets the maximum size for a message read from the peer.bytes
|
||||||
|
SetReadLimit(limit int64)
|
||||||
|
SetPongHandler(handler PingPongHandler)
|
||||||
|
SetPingHandler(handler PingPongHandler)
|
||||||
|
// GenerateLongConn Check the connection of the current and when it was sent are the same
|
||||||
|
GenerateLongConn(w http.ResponseWriter, r *http.Request) error
|
||||||
|
}
|
||||||
|
type GWebSocket struct {
|
||||||
|
protocolType int
|
||||||
|
conn *websocket.Conn
|
||||||
|
handshakeTimeout time.Duration
|
||||||
|
writeBufferSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGWebSocket(protocolType int, handshakeTimeout time.Duration, wbs int) *GWebSocket {
|
||||||
|
return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, writeBufferSize: wbs}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) Close() error {
|
||||||
|
return d.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
upgrader := &websocket.Upgrader{
|
||||||
|
HandshakeTimeout: d.handshakeTimeout,
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
}
|
||||||
|
if d.writeBufferSize > 0 { // default is 4kb.
|
||||||
|
upgrader.WriteBufferSize = d.writeBufferSize
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
// The upgrader.Upgrade method usually returns enough error messages to diagnose problems that may occur during the upgrade
|
||||||
|
return errs.WrapMsg(err, "GenerateLongConn: WebSocket upgrade failed")
|
||||||
|
}
|
||||||
|
d.conn = conn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) WriteMessage(messageType int, message []byte) error {
|
||||||
|
// d.setSendConn(d.conn)
|
||||||
|
return d.conn.WriteMessage(messageType, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) {
|
||||||
|
// d.sendConn = sendConn
|
||||||
|
//}
|
||||||
|
|
||||||
|
func (d *GWebSocket) ReadMessage() (int, []byte, error) {
|
||||||
|
return d.conn.ReadMessage()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error {
|
||||||
|
return d.conn.SetReadDeadline(time.Now().Add(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error {
|
||||||
|
if timeout <= 0 {
|
||||||
|
return errs.New("timeout must be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO SetWriteDeadline Future add error handling
|
||||||
|
if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
|
||||||
|
return errs.WrapMsg(err, "GWebSocket.SetWriteDeadline failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) {
|
||||||
|
conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader)
|
||||||
|
if err != nil {
|
||||||
|
return httpResp, errs.WrapMsg(err, "GWebSocket.Dial failed", "url", urlStr)
|
||||||
|
}
|
||||||
|
d.conn = conn
|
||||||
|
return httpResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) IsNil() bool {
|
||||||
|
return d.conn == nil
|
||||||
|
//
|
||||||
|
// if d.conn != nil {
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
// return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) SetConnNil() {
|
||||||
|
d.conn = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) SetReadLimit(limit int64) {
|
||||||
|
d.conn.SetReadLimit(limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) SetPongHandler(handler PingPongHandler) {
|
||||||
|
d.conn.SetPongHandler(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) SetPingHandler(handler PingPongHandler) {
|
||||||
|
d.conn.SetPingHandler(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) RespondWithError(err error, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
if err := d.GenerateLongConn(w, r); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(apiresp.ParseError(err))
|
||||||
|
if err != nil {
|
||||||
|
_ = d.Close()
|
||||||
|
return errs.WrapMsg(err, "json marshal failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.WriteMessage(MessageText, data); err != nil {
|
||||||
|
_ = d.Close()
|
||||||
|
return errs.WrapMsg(err, "WriteMessage failed")
|
||||||
|
}
|
||||||
|
_ = d.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *GWebSocket) RespondWithSuccess() error {
|
||||||
|
data, err := json.Marshal(apiresp.ParseError(nil))
|
||||||
|
if err != nil {
|
||||||
|
_ = d.Close()
|
||||||
|
return errs.WrapMsg(err, "json marshal failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.WriteMessage(MessageText, data); err != nil {
|
||||||
|
_ = d.Close()
|
||||||
|
return errs.WrapMsg(err, "WriteMessage failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -2,17 +2,13 @@ package msggateway
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
|
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
|
||||||
"github.com/openimsdk/tools/apiresp"
|
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
|
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
|
||||||
@@ -46,7 +42,6 @@ type LongConnServer interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type WsServer struct {
|
type WsServer struct {
|
||||||
websocket *websocket.Upgrader
|
|
||||||
msgGatewayConfig *Config
|
msgGatewayConfig *Config
|
||||||
port int
|
port int
|
||||||
wsMaxConnNum int64
|
wsMaxConnNum int64
|
||||||
@@ -136,13 +131,9 @@ func NewWsServer(msgGatewayConfig *Config, opts ...Option) *WsServer {
|
|||||||
o(&config)
|
o(&config)
|
||||||
}
|
}
|
||||||
//userRpcClient := rpcclient.NewUserRpcClient(client, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID)
|
//userRpcClient := rpcclient.NewUserRpcClient(client, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID)
|
||||||
upgrader := &websocket.Upgrader{
|
|
||||||
HandshakeTimeout: config.handshakeTimeout,
|
|
||||||
CheckOrigin: func(r *http.Request) bool { return true },
|
|
||||||
}
|
|
||||||
v := validator.New()
|
v := validator.New()
|
||||||
return &WsServer{
|
return &WsServer{
|
||||||
websocket: upgrader,
|
|
||||||
msgGatewayConfig: msgGatewayConfig,
|
msgGatewayConfig: msgGatewayConfig,
|
||||||
port: config.port,
|
port: config.port,
|
||||||
wsMaxConnNum: config.maxConnNum,
|
wsMaxConnNum: config.maxConnNum,
|
||||||
@@ -458,29 +449,6 @@ func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.P
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WsServer) handlerError(ctx *UserConnContext, w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
if !ctx.ShouldSendResp() {
|
|
||||||
httpError(ctx, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// the browser cannot get the response of upgrade failure
|
|
||||||
data, err := json.Marshal(apiresp.ParseError(err))
|
|
||||||
if err != nil {
|
|
||||||
log.ZError(ctx, "json marshal failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
conn, upgradeErr := ws.websocket.Upgrade(w, r, nil)
|
|
||||||
if upgradeErr != nil {
|
|
||||||
log.ZWarn(ctx, "websocket upgrade failed", upgradeErr, "respErr", err, "resp", string(data))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
|
||||||
log.ZWarn(ctx, "WriteMessage failed", err, "respErr", err, "resp", string(data))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
// Create a new connection context
|
// Create a new connection context
|
||||||
connContext := newContext(w, r)
|
connContext := newContext(w, r)
|
||||||
@@ -488,7 +456,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Check if the current number of online user connections exceeds the maximum limit
|
// Check if the current number of online user connections exceeds the maximum limit
|
||||||
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
|
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
|
||||||
// If it exceeds the maximum connection number, return an error via HTTP and stop processing
|
// If it exceeds the maximum connection number, return an error via HTTP and stop processing
|
||||||
ws.handlerError(connContext, w, r, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
|
httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -496,14 +464,26 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
err := connContext.ParseEssentialArgs()
|
err := connContext.ParseEssentialArgs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If there's an error during parsing, return an error via HTTP and stop processing
|
// If there's an error during parsing, return an error via HTTP and stop processing
|
||||||
ws.handlerError(connContext, w, r, err)
|
|
||||||
|
httpError(connContext, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the authentication client to parse the Token obtained from the context
|
// Call the authentication client to parse the Token obtained from the context
|
||||||
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
|
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ws.handlerError(connContext, w, r, err)
|
// If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag
|
||||||
|
shouldSendError := connContext.ShouldSendResp()
|
||||||
|
if shouldSendError {
|
||||||
|
// Create a WebSocket connection object and attempt to send the error message via WebSocket
|
||||||
|
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
||||||
|
if err := wsLongConn.RespondWithError(err, w, r); err == nil {
|
||||||
|
// If the error message is successfully sent via WebSocket, stop processing
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If sending via WebSocket is not required or fails, return the error via HTTP and stop processing
|
||||||
|
httpError(connContext, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -511,24 +491,32 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
err = ws.validateRespWithRequest(connContext, resp)
|
err = ws.validateRespWithRequest(connContext, resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If validation fails, return an error via HTTP and stop processing
|
// If validation fails, return an error via HTTP and stop processing
|
||||||
ws.handlerError(connContext, w, r, err)
|
httpError(connContext, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn, err := ws.websocket.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.ZWarn(connContext, "websocket upgrade failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
|
|
||||||
|
|
||||||
var pingInterval time.Duration
|
log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
|
||||||
if connContext.GetPlatformID() == strconv.Itoa(constant.WebPlatformID) {
|
// Create a WebSocket long connection object
|
||||||
pingInterval = pingPeriod
|
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
||||||
|
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
|
||||||
|
//If the creation of the long connection fails, the error is handled internally during the handshake process.
|
||||||
|
log.ZWarn(connContext, "long connection fails", err)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// Check if a normal response should be sent via WebSocket
|
||||||
|
shouldSendSuccessResp := connContext.ShouldSendResp()
|
||||||
|
if shouldSendSuccessResp {
|
||||||
|
// Attempt to send a success message through WebSocket
|
||||||
|
if err := wsLongConn.RespondWithSuccess(); err != nil {
|
||||||
|
// If the success message is successfully sent, end further processing
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection
|
// Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection
|
||||||
client := ws.clientPool.Get().(*Client)
|
client := ws.clientPool.Get().(*Client)
|
||||||
client.ResetClient(connContext, NewWebSocketClientConn(conn, maxMessageSize, pongWait, pingInterval), ws)
|
client.ResetClient(connContext, wsLongConn, ws)
|
||||||
|
|
||||||
// Register the client with the server and start message processing
|
// Register the client with the server and start message processing
|
||||||
ws.registerChan <- client
|
ws.registerChan <- client
|
||||||
|
|||||||
@@ -476,15 +476,14 @@ func (g *NotificationSender) GroupApplicationAcceptedNotification(ctx context.Co
|
|||||||
if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil {
|
if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
uid := g.uuid()
|
tips := &sdkws.GroupApplicationAcceptedTips{
|
||||||
|
Group: group,
|
||||||
|
OpUser: opUser,
|
||||||
|
HandleMsg: req.HandledMsg,
|
||||||
|
Uuid: g.uuid(),
|
||||||
|
Request: request,
|
||||||
|
}
|
||||||
for _, userID := range append(userIDs, req.FromUserID) {
|
for _, userID := range append(userIDs, req.FromUserID) {
|
||||||
tips := &sdkws.GroupApplicationAcceptedTips{
|
|
||||||
Group: group,
|
|
||||||
OpUser: opUser,
|
|
||||||
HandleMsg: req.HandledMsg,
|
|
||||||
Uuid: uid,
|
|
||||||
Request: request,
|
|
||||||
}
|
|
||||||
if userID == req.FromUserID {
|
if userID == req.FromUserID {
|
||||||
tips.ReceiverAs = applicantReceiver
|
tips.ReceiverAs = applicantReceiver
|
||||||
} else {
|
} else {
|
||||||
@@ -521,15 +520,14 @@ func (g *NotificationSender) GroupApplicationRejectedNotification(ctx context.Co
|
|||||||
if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil {
|
if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
uid := g.uuid()
|
tips := &sdkws.GroupApplicationRejectedTips{
|
||||||
|
Group: group,
|
||||||
|
OpUser: opUser,
|
||||||
|
HandleMsg: req.HandledMsg,
|
||||||
|
Uuid: g.uuid(),
|
||||||
|
Request: request,
|
||||||
|
}
|
||||||
for _, userID := range append(userIDs, req.FromUserID) {
|
for _, userID := range append(userIDs, req.FromUserID) {
|
||||||
tips := &sdkws.GroupApplicationRejectedTips{
|
|
||||||
Group: group,
|
|
||||||
OpUser: opUser,
|
|
||||||
HandleMsg: req.HandledMsg,
|
|
||||||
Uuid: uid,
|
|
||||||
Request: request,
|
|
||||||
}
|
|
||||||
if userID == req.FromUserID {
|
if userID == req.FromUserID {
|
||||||
tips.ReceiverAs = applicantReceiver
|
tips.ReceiverAs = applicantReceiver
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
+7
-14
@@ -49,7 +49,7 @@ func New[V any](opts ...Option) Cache[V] {
|
|||||||
if opt.expirationEvict {
|
if opt.expirationEvict {
|
||||||
return lru.NewExpirationLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
|
return lru.NewExpirationLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
|
||||||
} else {
|
} else {
|
||||||
return lru.NewLazyLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
|
return lru.NewLayLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if opt.localSlotNum == 1 {
|
if opt.localSlotNum == 1 {
|
||||||
@@ -72,18 +72,11 @@ type cache[V any] struct {
|
|||||||
|
|
||||||
func (c *cache[V]) onEvict(key string, value V) {
|
func (c *cache[V]) onEvict(key string, value V) {
|
||||||
if c.link != nil {
|
if c.link != nil {
|
||||||
// Do not delete other keys while the underlying LRU still holds its lock;
|
lks := c.link.Del(key)
|
||||||
// defer linked deletions to avoid re-entering the same slot and deadlocking.
|
for k := range lks {
|
||||||
if lks := c.link.Del(key); len(lks) > 0 {
|
if key != k { // prevent deadlock
|
||||||
go c.delLinked(key, lks)
|
c.local.Del(k)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cache[V]) delLinked(src string, keys map[string]struct{}) {
|
|
||||||
for k := range keys {
|
|
||||||
if src != k {
|
|
||||||
c.local.Del(k)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -110,7 +103,7 @@ func (c *cache[V]) Get(ctx context.Context, key string, fetch func(ctx context.C
|
|||||||
func (c *cache[V]) GetLink(ctx context.Context, key string, fetch func(ctx context.Context) (V, error), link ...string) (V, error) {
|
func (c *cache[V]) GetLink(ctx context.Context, key string, fetch func(ctx context.Context) (V, error), link ...string) (V, error) {
|
||||||
if c.local != nil {
|
if c.local != nil {
|
||||||
return c.local.Get(key, func() (V, error) {
|
return c.local.Get(key, func() (V, error) {
|
||||||
if len(link) > 0 && c.link != nil {
|
if len(link) > 0 {
|
||||||
c.link.Link(key, link...)
|
c.link.Link(key, link...)
|
||||||
}
|
}
|
||||||
return fetch(ctx)
|
return fetch(ctx)
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/localcache/lru"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestName(t *testing.T) {
|
func TestName(t *testing.T) {
|
||||||
@@ -93,68 +91,3 @@ func TestName(t *testing.T) {
|
|||||||
t.Log("del", del.Load())
|
t.Log("del", del.Load())
|
||||||
// 137.35s
|
// 137.35s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test deadlock scenario when eviction callback deletes a linked key that hashes to the same slot.
|
|
||||||
func TestCacheEvictDeadlock(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
c := New[string](WithLocalSlotNum(1), WithLocalSlotSize(1), WithLazy())
|
|
||||||
|
|
||||||
if _, err := c.GetLink(ctx, "k1", func(ctx context.Context) (string, error) {
|
|
||||||
return "v1", nil
|
|
||||||
}, "k2"); err != nil {
|
|
||||||
t.Fatalf("seed cache failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer close(done)
|
|
||||||
_, _ = c.GetLink(ctx, "k2", func(ctx context.Context) (string, error) {
|
|
||||||
return "v2", nil
|
|
||||||
}, "k1")
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
// expected to finish quickly; current implementation deadlocks here.
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("GetLink deadlocked during eviction of linked key")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExpirationLRUGetBatch(t *testing.T) {
|
|
||||||
l := lru.NewExpirationLRU[string, string](2, time.Minute, time.Second*5, EmptyTarget{}, nil)
|
|
||||||
|
|
||||||
keys := []string{"a", "b"}
|
|
||||||
values, err := l.GetBatch(keys, func(keys []string) (map[string]string, error) {
|
|
||||||
res := make(map[string]string)
|
|
||||||
for _, k := range keys {
|
|
||||||
res[k] = k + "_v"
|
|
||||||
}
|
|
||||||
return res, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if len(values) != len(keys) {
|
|
||||||
t.Fatalf("expected %d values, got %d", len(keys), len(values))
|
|
||||||
}
|
|
||||||
for _, k := range keys {
|
|
||||||
if v, ok := values[k]; !ok || v != k+"_v" {
|
|
||||||
t.Fatalf("unexpected value for %s: %q, ok=%v", k, v, ok)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// second batch should hit cache
|
|
||||||
values, err = l.GetBatch(keys, func(keys []string) (map[string]string, error) {
|
|
||||||
t.Fatalf("should not fetch on cache hit")
|
|
||||||
return nil, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error on cache hit: %v", err)
|
|
||||||
}
|
|
||||||
for _, k := range keys {
|
|
||||||
if v, ok := values[k]; !ok || v != k+"_v" {
|
|
||||||
t.Fatalf("unexpected cached value for %s: %q, ok=%v", k, v, ok)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -52,53 +52,8 @@ type ExpirationLRU[K comparable, V any] struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (x *ExpirationLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
|
func (x *ExpirationLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
|
||||||
var (
|
//TODO implement me
|
||||||
err error
|
panic("implement me")
|
||||||
results = make(map[K]V)
|
|
||||||
misses = make([]K, 0, len(keys))
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, key := range keys {
|
|
||||||
x.lock.Lock()
|
|
||||||
v, ok := x.core.Get(key)
|
|
||||||
x.lock.Unlock()
|
|
||||||
if ok {
|
|
||||||
x.target.IncrGetHit()
|
|
||||||
v.lock.RLock()
|
|
||||||
results[key] = v.value
|
|
||||||
if v.err != nil && err == nil {
|
|
||||||
err = v.err
|
|
||||||
}
|
|
||||||
v.lock.RUnlock()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
misses = append(misses, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(misses) == 0 {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
|
|
||||||
fetchValues, fetchErr := fetch(misses)
|
|
||||||
if fetchErr != nil && err == nil {
|
|
||||||
err = fetchErr
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, val := range fetchValues {
|
|
||||||
results[key] = val
|
|
||||||
if fetchErr != nil {
|
|
||||||
x.target.IncrGetFailed()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
x.target.IncrGetSuccess()
|
|
||||||
item := &expirationLruItem[V]{value: val}
|
|
||||||
x.lock.Lock()
|
|
||||||
x.core.Add(key, item)
|
|
||||||
x.lock.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// any keys not returned from fetch remain absent (no cache write)
|
|
||||||
return results, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *ExpirationLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
|
func (x *ExpirationLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
|
||||||
|
|||||||
@@ -21,25 +21,25 @@ import (
|
|||||||
"github.com/hashicorp/golang-lru/v2/simplelru"
|
"github.com/hashicorp/golang-lru/v2/simplelru"
|
||||||
)
|
)
|
||||||
|
|
||||||
type lazyLruItem[V any] struct {
|
type layLruItem[V any] struct {
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
expires int64
|
expires int64
|
||||||
err error
|
err error
|
||||||
value V
|
value V
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLazyLRU[K comparable, V any](size int, successTTL, failedTTL time.Duration, target Target, onEvict EvictCallback[K, V]) *LazyLRU[K, V] {
|
func NewLayLRU[K comparable, V any](size int, successTTL, failedTTL time.Duration, target Target, onEvict EvictCallback[K, V]) *LayLRU[K, V] {
|
||||||
var cb simplelru.EvictCallback[K, *lazyLruItem[V]]
|
var cb simplelru.EvictCallback[K, *layLruItem[V]]
|
||||||
if onEvict != nil {
|
if onEvict != nil {
|
||||||
cb = func(key K, value *lazyLruItem[V]) {
|
cb = func(key K, value *layLruItem[V]) {
|
||||||
onEvict(key, value.value)
|
onEvict(key, value.value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
core, err := simplelru.NewLRU[K, *lazyLruItem[V]](size, cb)
|
core, err := simplelru.NewLRU[K, *layLruItem[V]](size, cb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return &LazyLRU[K, V]{
|
return &LayLRU[K, V]{
|
||||||
core: core,
|
core: core,
|
||||||
successTTL: successTTL,
|
successTTL: successTTL,
|
||||||
failedTTL: failedTTL,
|
failedTTL: failedTTL,
|
||||||
@@ -47,15 +47,15 @@ func NewLazyLRU[K comparable, V any](size int, successTTL, failedTTL time.Durati
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type LazyLRU[K comparable, V any] struct {
|
type LayLRU[K comparable, V any] struct {
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
core *simplelru.LRU[K, *lazyLruItem[V]]
|
core *simplelru.LRU[K, *layLruItem[V]]
|
||||||
successTTL time.Duration
|
successTTL time.Duration
|
||||||
failedTTL time.Duration
|
failedTTL time.Duration
|
||||||
target Target
|
target Target
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *LazyLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
|
func (x *LayLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
|
||||||
x.lock.Lock()
|
x.lock.Lock()
|
||||||
v, ok := x.core.Get(key)
|
v, ok := x.core.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
@@ -68,7 +68,7 @@ func (x *LazyLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
|
|||||||
return value, err
|
return value, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
v = &lazyLruItem[V]{}
|
v = &layLruItem[V]{}
|
||||||
x.core.Add(key, v)
|
x.core.Add(key, v)
|
||||||
v.lock.Lock()
|
v.lock.Lock()
|
||||||
x.lock.Unlock()
|
x.lock.Unlock()
|
||||||
@@ -88,15 +88,15 @@ func (x *LazyLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
|
|||||||
return v.value, v.err
|
return v.value, v.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *LazyLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
|
func (x *LayLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
once sync.Once
|
once sync.Once
|
||||||
)
|
)
|
||||||
|
|
||||||
res := make(map[K]V)
|
res := make(map[K]V)
|
||||||
queries := make([]K, 0, len(keys))
|
queries := make([]K, 0)
|
||||||
|
setVs := make(map[K]*layLruItem[V])
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
x.lock.Lock()
|
x.lock.Lock()
|
||||||
v, ok := x.core.Get(key)
|
v, ok := x.core.Get(key)
|
||||||
@@ -118,20 +118,14 @@ func (x *LazyLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)
|
|||||||
}
|
}
|
||||||
queries = append(queries, key)
|
queries = append(queries, key)
|
||||||
}
|
}
|
||||||
|
values, err1 := fetch(queries)
|
||||||
if len(queries) == 0 {
|
if err1 != nil {
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
|
|
||||||
values, fetchErr := fetch(queries)
|
|
||||||
if fetchErr != nil {
|
|
||||||
once.Do(func() {
|
once.Do(func() {
|
||||||
err = fetchErr
|
err = err1
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, val := range values {
|
for key, val := range values {
|
||||||
v := &lazyLruItem[V]{}
|
v := &layLruItem[V]{}
|
||||||
v.value = val
|
v.value = val
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -141,7 +135,7 @@ func (x *LazyLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)
|
|||||||
v.expires = time.Now().Add(x.failedTTL).UnixMilli()
|
v.expires = time.Now().Add(x.failedTTL).UnixMilli()
|
||||||
x.target.IncrGetFailed()
|
x.target.IncrGetFailed()
|
||||||
}
|
}
|
||||||
|
setVs[key] = v
|
||||||
x.lock.Lock()
|
x.lock.Lock()
|
||||||
x.core.Add(key, v)
|
x.core.Add(key, v)
|
||||||
x.lock.Unlock()
|
x.lock.Unlock()
|
||||||
@@ -151,29 +145,29 @@ func (x *LazyLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)
|
|||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
//func (x *LazyLRU[K, V]) Has(key K) bool {
|
//func (x *LayLRU[K, V]) Has(key K) bool {
|
||||||
// x.lock.Lock()
|
// x.lock.Lock()
|
||||||
// defer x.lock.Unlock()
|
// defer x.lock.Unlock()
|
||||||
// return x.core.Contains(key)
|
// return x.core.Contains(key)
|
||||||
//}
|
//}
|
||||||
|
|
||||||
func (x *LazyLRU[K, V]) Set(key K, value V) {
|
func (x *LayLRU[K, V]) Set(key K, value V) {
|
||||||
x.lock.Lock()
|
x.lock.Lock()
|
||||||
defer x.lock.Unlock()
|
defer x.lock.Unlock()
|
||||||
x.core.Add(key, &lazyLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
|
x.core.Add(key, &layLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *LazyLRU[K, V]) SetHas(key K, value V) bool {
|
func (x *LayLRU[K, V]) SetHas(key K, value V) bool {
|
||||||
x.lock.Lock()
|
x.lock.Lock()
|
||||||
defer x.lock.Unlock()
|
defer x.lock.Unlock()
|
||||||
if x.core.Contains(key) {
|
if x.core.Contains(key) {
|
||||||
x.core.Add(key, &lazyLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
|
x.core.Add(key, &layLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *LazyLRU[K, V]) Del(key K) bool {
|
func (x *LayLRU[K, V]) Del(key K) bool {
|
||||||
x.lock.Lock()
|
x.lock.Lock()
|
||||||
ok := x.core.Remove(key)
|
ok := x.core.Remove(key)
|
||||||
x.lock.Unlock()
|
x.lock.Unlock()
|
||||||
@@ -185,6 +179,6 @@ func (x *LazyLRU[K, V]) Del(key K) bool {
|
|||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *LazyLRU[K, V]) Stop() {
|
func (x *LayLRU[K, V]) Stop() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,16 +3,15 @@ package rpccache
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
|
||||||
|
"github.com/openimsdk/protocol/constant"
|
||||||
|
"github.com/openimsdk/protocol/user"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
|
|
||||||
"github.com/openimsdk/protocol/constant"
|
|
||||||
"github.com/openimsdk/protocol/user"
|
|
||||||
|
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
|
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/localcache"
|
"github.com/openimsdk/open-im-server/v3/pkg/localcache"
|
||||||
"github.com/openimsdk/open-im-server/v3/pkg/localcache/lru"
|
"github.com/openimsdk/open-im-server/v3/pkg/localcache/lru"
|
||||||
@@ -47,7 +46,7 @@ func NewOnlineCache(client *rpcli.UserClient, group *GroupLocalCache, rdb redis.
|
|||||||
case false:
|
case false:
|
||||||
log.ZDebug(ctx, "fullUserCache is false")
|
log.ZDebug(ctx, "fullUserCache is false")
|
||||||
x.lruCache = lru.NewSlotLRU(1024, localcache.LRUStringHash, func() lru.LRU[string, []int32] {
|
x.lruCache = lru.NewSlotLRU(1024, localcache.LRUStringHash, func() lru.LRU[string, []int32] {
|
||||||
return lru.NewLazyLRU[string, []int32](2048, cachekey.OnlineExpire/2, time.Second*3, localcache.EmptyTarget{}, func(key string, value []int32) {})
|
return lru.NewLayLRU[string, []int32](2048, cachekey.OnlineExpire/2, time.Second*3, localcache.EmptyTarget{}, func(key string, value []int32) {})
|
||||||
})
|
})
|
||||||
x.CurrentPhase.Store(DoSubscribeOver)
|
x.CurrentPhase.Store(DoSubscribeOver)
|
||||||
x.Cond.Broadcast()
|
x.Cond.Broadcast()
|
||||||
|
|||||||
+1
-1
@@ -1 +1 @@
|
|||||||
v3.8.3-patch.13
|
v3.8.3-patch.12
|
||||||
Reference in New Issue
Block a user