Merge branch 'main' into localcache

# Conflicts:
#	go.mod
#	go.sum
#	internal/msggateway/hub_server.go
#	internal/push/consumer_init.go
#	internal/push/offlinepush/fcm/push.go
#	internal/push/offlinepush/getui/push.go
#	internal/push/offlinepush/jpush/push.go
#	internal/push/push_handler.go
#	internal/push/push_rpc_server.go
#	internal/push/push_to_client.go
#	internal/rpc/friend/friend.go
#	internal/rpc/msg/server.go
#	internal/rpc/msg/verify.go
#	pkg/common/config/config.go
#	pkg/common/db/cache/conversation.go
#	pkg/common/db/cache/meta_cache.go
#	pkg/common/db/cache/msg.go
#	pkg/common/db/localcache/conversation.go
#	pkg/common/db/localcache/group.go
#	pkg/rpcclient/conversation.go
#	pkg/rpcclient/group.go
This commit is contained in:
withchao
2024-03-07 11:59:53 +08:00
326 changed files with 7738 additions and 3878 deletions
+2 -7
View File
@@ -22,9 +22,8 @@ import (
"time"
"github.com/dtm-labs/rockscache"
"github.com/redis/go-redis/v9"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/redis/go-redis/v9"
)
const (
@@ -49,11 +48,7 @@ type BlackCacheRedis struct {
blackDB relationtb.BlackModelInterface
}
func NewBlackCacheRedis(
rdb redis.UniversalClient,
blackDB relationtb.BlackModelInterface,
options rockscache.Options,
) BlackCache {
func NewBlackCacheRedis(rdb redis.UniversalClient, blackDB relationtb.BlackModelInterface, options rockscache.Options) BlackCache {
rcClient := rockscache.NewClient(rdb, options)
mc := NewMetaCacheRedis(rcClient)
b := config.Config.LocalCache.Friend
+26 -28
View File
@@ -24,12 +24,10 @@ import (
"strings"
"time"
"github.com/dtm-labs/rockscache"
"github.com/redis/go-redis/v9"
"github.com/OpenIMSDK/tools/utils"
"github.com/dtm-labs/rockscache"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/redis/go-redis/v9"
)
const (
@@ -227,16 +225,16 @@ func (c *ConversationRedisCache) DelConversations(ownerUserID string, conversati
return cache
}
func (c *ConversationRedisCache) getConversationIndex(convsation *relationtb.ConversationModel, keys []string) (int, error) {
key := c.getConversationKey(convsation.OwnerUserID, convsation.ConversationID)
for _i, _key := range keys {
if _key == key {
return _i, nil
}
}
// func (c *ConversationRedisCache) getConversationIndex(convsation *relationtb.ConversationModel, keys []string) (int, error) {
// key := c.getConversationKey(convsation.OwnerUserID, convsation.ConversationID)
// for _i, _key := range keys {
// if _key == key {
// return _i, nil
// }
// }
return 0, errors.New("not found key:" + key + " in keys")
}
// return 0, errors.New("not found key:" + key + " in keys")
// }
func (c *ConversationRedisCache) GetConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*relationtb.ConversationModel, error) {
//var keys []string
@@ -340,7 +338,7 @@ func (c *ConversationRedisCache) DelSuperGroupRecvMsgNotNotifyUserIDsHash(groupI
return cache
}
func (c *ConversationRedisCache) getUserAllHasReadSeqsIndex(conversationID string, conversationIDs []string) (int, error) {
/* func (c *ConversationRedisCache) getUserAllHasReadSeqsIndex(conversationID string, conversationIDs []string) (int, error) {
for _i, _conversationID := range conversationIDs {
if _conversationID == conversationID {
return _i, nil
@@ -348,21 +346,21 @@ func (c *ConversationRedisCache) getUserAllHasReadSeqsIndex(conversationID strin
}
return 0, errors.New("not found key:" + conversationID + " in keys")
}
} */
//func (c *ConversationRedisCache) GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error) {
// conversationIDs, err := c.GetUserConversationIDs(ctx, ownerUserID)
// if err != nil {
// return nil, err
// }
// var keys []string
// for _, conversarionID := range conversationIDs {
// keys = append(keys, c.getConversationHasReadSeqKey(ownerUserID, conversarionID))
// }
// return batchGetCacheMap(ctx, c.rcClient, keys, conversationIDs, c.expireTime, c.getUserAllHasReadSeqsIndex, func(ctx context.Context) (map[string]int64, error) {
// return c.conversationDB.GetUserAllHasReadSeqs(ctx, ownerUserID)
// })
//}
/* func (c *ConversationRedisCache) GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error) {
conversationIDs, err := c.GetUserConversationIDs(ctx, ownerUserID)
if err != nil {
return nil, err
}
var keys []string
for _, conversarionID := range conversationIDs {
keys = append(keys, c.getConversationHasReadSeqKey(ownerUserID, conversarionID))
}
return batchGetCacheMap(ctx, c.rcClient, keys, conversationIDs, c.expireTime, c.getUserAllHasReadSeqsIndex, func(ctx context.Context) (map[string]int64, error) {
return c.conversationDB.GetUserAllHasReadSeqs(ctx, ownerUserID)
})
} */
func (c *ConversationRedisCache) DelUserAllHasReadSeqs(ownerUserID string, conversationIDs ...string) ConversationCache {
cache := c.NewCache()
+2 -4
View File
@@ -21,12 +21,10 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"time"
"github.com/dtm-labs/rockscache"
"github.com/redis/go-redis/v9"
"github.com/OpenIMSDK/tools/utils"
"github.com/dtm-labs/rockscache"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/redis/go-redis/v9"
)
const (
+2 -6
View File
@@ -23,15 +23,11 @@ import (
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/dtm-labs/rockscache"
"github.com/redis/go-redis/v9"
"github.com/OpenIMSDK/tools/utils"
"github.com/dtm-labs/rockscache"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/redis/go-redis/v9"
)
const (
+22 -21
View File
@@ -22,12 +22,10 @@ import (
"strings"
"time"
"github.com/redis/go-redis/v9"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/mw/specialerror"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/redis/go-redis/v9"
)
var (
@@ -40,32 +38,32 @@ const (
)
// NewRedis Initialize redis connection.
func NewRedis() (redis.UniversalClient, error) {
func NewRedis(config *config.GlobalConfig) (redis.UniversalClient, error) {
if redisClient != nil {
return redisClient, nil
}
// Read configuration from environment variables
overrideConfigFromEnv()
overrideConfigFromEnv(config)
if len(config.Config.Redis.Address) == 0 {
return nil, errors.New("redis address is empty")
if len(config.Redis.Address) == 0 {
return nil, errs.Wrap(errors.New("redis address is empty"))
}
specialerror.AddReplace(redis.Nil, errs.ErrRecordNotFound)
var rdb redis.UniversalClient
if len(config.Config.Redis.Address) > 1 || config.Config.Redis.ClusterMode {
if len(config.Redis.Address) > 1 || config.Redis.ClusterMode {
rdb = redis.NewClusterClient(&redis.ClusterOptions{
Addrs: config.Config.Redis.Address,
Username: config.Config.Redis.Username,
Password: config.Config.Redis.Password, // no password set
Addrs: config.Redis.Address,
Username: config.Redis.Username,
Password: config.Redis.Password, // no password set
PoolSize: 50,
MaxRetries: maxRetry,
})
} else {
rdb = redis.NewClient(&redis.Options{
Addr: config.Config.Redis.Address[0],
Username: config.Config.Redis.Username,
Password: config.Config.Redis.Password,
Addr: config.Redis.Address[0],
Username: config.Redis.Username,
Password: config.Redis.Password,
DB: 0, // use default DB
PoolSize: 100, // connection pool size
MaxRetries: maxRetry,
@@ -77,30 +75,33 @@ func NewRedis() (redis.UniversalClient, error) {
defer cancel()
err = rdb.Ping(ctx).Err()
if err != nil {
return nil, fmt.Errorf("redis ping %w", err)
errMsg := fmt.Sprintf("address:%s, username:%s, password:%s, clusterMode:%t, enablePipeline:%t", config.Redis.Address, config.Redis.Username,
config.Redis.Password, config.Redis.ClusterMode, config.Redis.EnablePipeline)
return nil, errs.Wrap(err, errMsg)
}
redisClient = rdb
return rdb, err
}
// overrideConfigFromEnv overrides configuration fields with environment variables if present.
func overrideConfigFromEnv() {
func overrideConfigFromEnv(config *config.GlobalConfig) {
if envAddr := os.Getenv("REDIS_ADDRESS"); envAddr != "" {
if envPort := os.Getenv("REDIS_PORT"); envPort != "" {
addresses := strings.Split(envAddr, ",")
for i, addr := range addresses {
addresses[i] = addr + ":" + envPort
}
config.Config.Redis.Address = addresses
config.Redis.Address = addresses
} else {
config.Config.Redis.Address = strings.Split(envAddr, ",")
config.Redis.Address = strings.Split(envAddr, ",")
}
}
if envUser := os.Getenv("REDIS_USERNAME"); envUser != "" {
config.Config.Redis.Username = envUser
config.Redis.Username = envUser
}
if envPass := os.Getenv("REDIS_PASSWORD"); envPass != "" {
config.Config.Redis.Password = envPass
config.Redis.Password = envPass
}
}
+8 -10
View File
@@ -19,15 +19,14 @@ import (
"encoding/json"
"errors"
"github.com/redis/go-redis/v9"
"fmt"
"time"
"github.com/OpenIMSDK/tools/mw/specialerror"
"github.com/dtm-labs/rockscache"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/mw/specialerror"
"github.com/OpenIMSDK/tools/utils"
"github.com/dtm-labs/rockscache"
)
const (
@@ -151,14 +150,14 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
}
bs, err := json.Marshal(t)
if err != nil {
return "", utils.Wrap(err, "")
return "", errs.Wrap(err, "marshal failed")
}
write = true
return string(bs), nil
})
if err != nil {
return t, err
return t, errs.Wrap(err)
}
if write {
return t, nil
@@ -168,9 +167,8 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
}
err = json.Unmarshal([]byte(v), &t)
if err != nil {
log.ZError(ctx, "cache json.Unmarshal failed", err, "key", key, "value", v, "expire", expire)
return t, utils.Wrap(err, "")
errInfo := fmt.Sprintf("cache json.Unmarshal failed, key:%s, value:%s, expire:%s", key, v, expire)
return t, errs.Wrap(err, errInfo)
}
return t, nil
@@ -234,7 +232,7 @@ func batchGetCache2[T any, K comparable](
if errs.ErrRecordNotFound.Is(specialerror.ErrCode(errs.Unwrap(err))) {
continue
}
return nil, err
return nil, errs.Wrap(err)
}
res = append(res, val)
}
+64 -57
View File
@@ -21,22 +21,16 @@ import (
"strconv"
"time"
"golang.org/x/sync/errgroup"
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
"github.com/OpenIMSDK/tools/errs"
"github.com/gogo/protobuf/jsonpb"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/utils"
"github.com/gogo/protobuf/jsonpb"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
"github.com/redis/go-redis/v9"
"golang.org/x/sync/errgroup"
)
const (
@@ -128,14 +122,14 @@ type MsgModel interface {
UnLockMessageTypeKey(ctx context.Context, clientMsgID string, TypeKey string) error
}
func NewMsgCacheModel(client redis.UniversalClient) MsgModel {
rcClient := rockscache.NewClient(client, rockscache.NewDefaultOptions())
return &msgCache{metaCache: NewMetaCacheRedis(rcClient), rdb: client}
func NewMsgCacheModel(client redis.UniversalClient, config *config.GlobalConfig) MsgModel {
return &msgCache{rdb: client, config: config}
}
type msgCache struct {
metaCache
rdb redis.UniversalClient
rdb redis.UniversalClient
config *config.GlobalConfig
}
func (c *msgCache) getMaxSeqKey(conversationID string) string {
@@ -155,11 +149,15 @@ func (c *msgCache) getConversationUserMinSeqKey(conversationID, userID string) s
}
func (c *msgCache) setSeq(ctx context.Context, conversationID string, seq int64, getkey func(conversationID string) string) error {
return utils.Wrap1(c.rdb.Set(ctx, getkey(conversationID), seq, 0).Err())
return errs.Wrap(c.rdb.Set(ctx, getkey(conversationID), seq, 0).Err())
}
func (c *msgCache) getSeq(ctx context.Context, conversationID string, getkey func(conversationID string) string) (int64, error) {
return utils.Wrap2(c.rdb.Get(ctx, getkey(conversationID)).Int64())
val, err := c.rdb.Get(ctx, getkey(conversationID)).Int64()
if err != nil {
return 0, errs.Wrap(err)
}
return val, nil
}
func (c *msgCache) getSeqs(ctx context.Context, items []string, getkey func(s string) string) (m map[string]int64, err error) {
@@ -216,7 +214,11 @@ func (c *msgCache) GetMinSeq(ctx context.Context, conversationID string) (int64,
}
func (c *msgCache) GetConversationUserMinSeq(ctx context.Context, conversationID string, userID string) (int64, error) {
return utils.Wrap2(c.rdb.Get(ctx, c.getConversationUserMinSeqKey(conversationID, userID)).Int64())
val, err := c.rdb.Get(ctx, c.getConversationUserMinSeqKey(conversationID, userID)).Int64()
if err != nil {
return 0, errs.Wrap(err)
}
return val, nil
}
func (c *msgCache) GetConversationUserMinSeqs(ctx context.Context, conversationID string, userIDs []string) (m map[string]int64, err error) {
@@ -226,7 +228,7 @@ func (c *msgCache) GetConversationUserMinSeqs(ctx context.Context, conversationI
}
func (c *msgCache) SetConversationUserMinSeq(ctx context.Context, conversationID string, userID string, minSeq int64) error {
return utils.Wrap1(c.rdb.Set(ctx, c.getConversationUserMinSeqKey(conversationID, userID), minSeq, 0).Err())
return errs.Wrap(c.rdb.Set(ctx, c.getConversationUserMinSeqKey(conversationID, userID), minSeq, 0).Err())
}
func (c *msgCache) SetConversationUserMinSeqs(ctx context.Context, conversationID string, seqs map[string]int64) (err error) {
@@ -242,7 +244,7 @@ func (c *msgCache) SetUserConversationsMinSeqs(ctx context.Context, userID strin
}
func (c *msgCache) SetHasReadSeq(ctx context.Context, userID string, conversationID string, hasReadSeq int64) error {
return utils.Wrap1(c.rdb.Set(ctx, c.getHasReadSeqKey(conversationID, userID), hasReadSeq, 0).Err())
return errs.Wrap(c.rdb.Set(ctx, c.getHasReadSeqKey(conversationID, userID), hasReadSeq, 0).Err())
}
func (c *msgCache) SetHasReadSeqs(ctx context.Context, conversationID string, hasReadSeqs map[string]int64) error {
@@ -264,12 +266,15 @@ func (c *msgCache) GetHasReadSeqs(ctx context.Context, userID string, conversati
}
func (c *msgCache) GetHasReadSeq(ctx context.Context, userID string, conversationID string) (int64, error) {
return utils.Wrap2(c.rdb.Get(ctx, c.getHasReadSeqKey(conversationID, userID)).Int64())
val, err := c.rdb.Get(ctx, c.getHasReadSeqKey(conversationID, userID)).Int64()
if err != nil {
return 0, err
}
return val, nil
}
func (c *msgCache) AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error {
key := uidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err())
}
@@ -312,7 +317,7 @@ func (c *msgCache) allMessageCacheKey(conversationID string) string {
}
func (c *msgCache) GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) {
if config.Config.Redis.EnablePipeline {
if c.config.Redis.EnablePipeline {
return c.PipeGetMessagesBySeq(ctx, conversationID, seqs)
}
@@ -413,7 +418,7 @@ func (c *msgCache) ParallelGetMessagesBySeq(ctx context.Context, conversationID
}
func (c *msgCache) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) {
if config.Config.Redis.EnablePipeline {
if c.config.Redis.EnablePipeline {
return c.PipeSetMessageToCache(ctx, conversationID, msgs)
}
return c.ParallelSetMessageToCache(ctx, conversationID, msgs)
@@ -424,16 +429,16 @@ func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID str
for _, msg := range msgs {
s, err := msgprocessor.Pb2String(msg)
if err != nil {
return 0, errs.Wrap(err, "pb.marshal")
return 0, err
}
key := c.getMessageCacheKey(conversationID, msg.Seq)
_ = pipe.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second)
_ = pipe.Set(ctx, key, s, time.Duration(c.config.MsgCacheTimeout)*time.Second)
}
results, err := pipe.Exec(ctx)
if err != nil {
return 0, errs.Wrap(err, "pipe.set")
return 0, errs.Wrap(err)
}
for _, res := range results {
@@ -458,7 +463,7 @@ func (c *msgCache) ParallelSetMessageToCache(ctx context.Context, conversationID
}
key := c.getMessageCacheKey(conversationID, msg.Seq)
if err := c.rdb.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil {
if err := c.rdb.Set(ctx, key, s, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil {
return errs.Wrap(err)
}
return nil
@@ -467,7 +472,7 @@ func (c *msgCache) ParallelSetMessageToCache(ctx context.Context, conversationID
err := wg.Wait()
if err != nil {
return 0, err
return 0, errs.Wrap(err, "wg.Wait failed")
}
return len(msgs), nil
@@ -493,10 +498,10 @@ func (c *msgCache) UserDeleteMsgs(ctx context.Context, conversationID string, se
if err != nil {
return errs.Wrap(err)
}
if err := c.rdb.Expire(ctx, delUserListKey, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil {
if err := c.rdb.Expire(ctx, delUserListKey, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil {
return errs.Wrap(err)
}
if err := c.rdb.Expire(ctx, userDelListKey, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil {
if err := c.rdb.Expire(ctx, userDelListKey, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil {
return errs.Wrap(err)
}
}
@@ -601,7 +606,7 @@ func (c *msgCache) DelUserDeleteMsgsList(ctx context.Context, conversationID str
}
func (c *msgCache) DeleteMessages(ctx context.Context, conversationID string, seqs []int64) error {
if config.Config.Redis.EnablePipeline {
if c.config.Redis.EnablePipeline {
return c.PipeDeleteMessages(ctx, conversationID, seqs)
}
@@ -683,7 +688,7 @@ func (c *msgCache) DelMsgFromCache(ctx context.Context, userID string, seqs []in
if err != nil {
return errs.Wrap(err)
}
if err := c.rdb.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil {
if err := c.rdb.Set(ctx, key, s, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil {
return errs.Wrap(err)
}
}
@@ -696,7 +701,11 @@ func (c *msgCache) SetGetuiToken(ctx context.Context, token string, expireTime i
}
func (c *msgCache) GetGetuiToken(ctx context.Context) (string, error) {
return utils.Wrap2(c.rdb.Get(ctx, getuiToken).Result())
val, err := c.rdb.Get(ctx, getuiToken).Result()
if err != nil {
return "", errs.Wrap(err)
}
return val, nil
}
func (c *msgCache) SetGetuiTaskID(ctx context.Context, taskID string, expireTime int64) error {
@@ -704,7 +713,11 @@ func (c *msgCache) SetGetuiTaskID(ctx context.Context, taskID string, expireTime
}
func (c *msgCache) GetGetuiTaskID(ctx context.Context) (string, error) {
return utils.Wrap2(c.rdb.Get(ctx, getuiTaskID).Result())
val, err := c.rdb.Get(ctx, getuiTaskID).Result()
if err != nil {
return "", errs.Wrap(err)
}
return val, nil
}
func (c *msgCache) SetSendMsgStatus(ctx context.Context, id string, status int32) error {
@@ -722,7 +735,11 @@ func (c *msgCache) SetFcmToken(ctx context.Context, account string, platformID i
}
func (c *msgCache) GetFcmToken(ctx context.Context, account string, platformID int) (string, error) {
return utils.Wrap2(c.rdb.Get(ctx, FCM_TOKEN+account+":"+strconv.Itoa(platformID)).Result())
val, err := c.rdb.Get(ctx, FCM_TOKEN+account+":"+strconv.Itoa(platformID)).Result()
if err != nil {
return "", errs.Wrap(err)
}
return val, nil
}
func (c *msgCache) DelFcmToken(ctx context.Context, account string, platformID int) error {
@@ -740,7 +757,8 @@ func (c *msgCache) SetUserBadgeUnreadCountSum(ctx context.Context, userID string
}
func (c *msgCache) GetUserBadgeUnreadCountSum(ctx context.Context, userID string) (int, error) {
return utils.Wrap2(c.rdb.Get(ctx, userBadgeUnreadCountSum+userID).Int())
val, err := c.rdb.Get(ctx, userBadgeUnreadCountSum+userID).Int()
return val, errs.Wrap(err)
}
func (c *msgCache) LockMessageTypeKey(ctx context.Context, clientMsgID string, TypeKey string) error {
@@ -773,42 +791,31 @@ func (c *msgCache) getMessageReactionExPrefix(clientMsgID string, sessionType in
func (c *msgCache) JudgeMessageReactionExist(ctx context.Context, clientMsgID string, sessionType int32) (bool, error) {
n, err := c.rdb.Exists(ctx, c.getMessageReactionExPrefix(clientMsgID, sessionType)).Result()
if err != nil {
return false, utils.Wrap(err, "")
return false, errs.Wrap(err)
}
return n > 0, nil
}
func (c *msgCache) SetMessageTypeKeyValue(
ctx context.Context,
clientMsgID string,
sessionType int32,
typeKey, value string,
) error {
func (c *msgCache) SetMessageTypeKeyValue(ctx context.Context, clientMsgID string, sessionType int32, typeKey, value string) error {
return errs.Wrap(c.rdb.HSet(ctx, c.getMessageReactionExPrefix(clientMsgID, sessionType), typeKey, value).Err())
}
func (c *msgCache) SetMessageReactionExpire(ctx context.Context, clientMsgID string, sessionType int32, expiration time.Duration) (bool, error) {
return utils.Wrap2(c.rdb.Expire(ctx, c.getMessageReactionExPrefix(clientMsgID, sessionType), expiration).Result())
val, err := c.rdb.Expire(ctx, c.getMessageReactionExPrefix(clientMsgID, sessionType), expiration).Result()
return val, errs.Wrap(err)
}
func (c *msgCache) GetMessageTypeKeyValue(ctx context.Context, clientMsgID string, sessionType int32, typeKey string) (string, error) {
return utils.Wrap2(c.rdb.HGet(ctx, c.getMessageReactionExPrefix(clientMsgID, sessionType), typeKey).Result())
val, err := c.rdb.HGet(ctx, c.getMessageReactionExPrefix(clientMsgID, sessionType), typeKey).Result()
return val, errs.Wrap(err)
}
func (c *msgCache) GetOneMessageAllReactionList(
ctx context.Context,
clientMsgID string,
sessionType int32,
) (map[string]string, error) {
return utils.Wrap2(c.rdb.HGetAll(ctx, c.getMessageReactionExPrefix(clientMsgID, sessionType)).Result())
func (c *msgCache) GetOneMessageAllReactionList(ctx context.Context, clientMsgID string, sessionType int32) (map[string]string, error) {
val, err := c.rdb.HGetAll(ctx, c.getMessageReactionExPrefix(clientMsgID, sessionType)).Result()
return val, errs.Wrap(err)
}
func (c *msgCache) DeleteOneMessageKey(
ctx context.Context,
clientMsgID string,
sessionType int32,
subKey string,
) error {
func (c *msgCache) DeleteOneMessageKey(ctx context.Context, clientMsgID string, sessionType int32, subKey string) error {
return errs.Wrap(c.rdb.HDel(ctx, c.getMessageReactionExPrefix(clientMsgID, sessionType), subKey).Err())
}
+1 -2
View File
@@ -20,10 +20,9 @@ import (
"time"
"github.com/dtm-labs/rockscache"
"github.com/redis/go-redis/v9"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/redis/go-redis/v9"
)
type ObjectCache interface {
+11 -15
View File
@@ -24,16 +24,12 @@ import (
"strconv"
"time"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/user"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/dtm-labs/rockscache"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/redis/go-redis/v9"
)
@@ -206,13 +202,13 @@ func (u *UserCacheRedis) SetUserStatus(ctx context.Context, userID string, statu
Status: constant.Online,
PlatformIDs: []int32{platformID},
}
jsonData, err2 := json.Marshal(&onlineStatus)
if err2 != nil {
return errs.Wrap(err2)
jsonData, err := json.Marshal(&onlineStatus)
if err != nil {
return errs.Wrap(err)
}
_, err2 = u.rdb.HSet(ctx, key, userID, string(jsonData)).Result()
if err2 != nil {
return errs.Wrap(err2)
_, err = u.rdb.HSet(ctx, key, userID, string(jsonData)).Result()
if err != nil {
return errs.Wrap(err)
}
u.rdb.Expire(ctx, key, userOlineStatusExpireTime)
@@ -286,9 +282,9 @@ func (u *UserCacheRedis) refreshStatusOffline(ctx context.Context, userID string
func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string, platformID int32, isNil bool, err error, result, key string) error {
var onlineStatus user.OnlineStatus
if !isNil {
err2 := json.Unmarshal([]byte(result), &onlineStatus)
err := json.Unmarshal([]byte(result), &onlineStatus)
if err != nil {
return errs.Wrap(err2)
return errs.Wrap(err)
}
onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID))
} else {
@@ -298,7 +294,7 @@ func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string,
onlineStatus.UserID = userID
newjsonData, err := json.Marshal(&onlineStatus)
if err != nil {
return errs.Wrap(err)
return errs.Wrap(err, "json.Marshal failed")
}
_, err = u.rdb.HSet(ctx, key, userID, string(newjsonData)).Result()
if err != nil {
+17 -22
View File
@@ -17,45 +17,39 @@ package controller
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/golang-jwt/jwt/v4"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/tokenverify"
"github.com/OpenIMSDK/tools/utils"
"github.com/golang-jwt/jwt/v4"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
)
type AuthDatabase interface {
// 结果为空 不返回错误
// If the result is empty, no error is returned.
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
// 创建token
// Create token
CreateToken(ctx context.Context, userID string, platformID int) (string, error)
}
type authDatabase struct {
cache cache.MsgModel
cache cache.MsgModel
accessSecret string
accessExpire int64
config *config.GlobalConfig
}
func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int64) AuthDatabase {
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire}
func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int64, config *config.GlobalConfig) AuthDatabase {
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, config: config}
}
// 结果为空 不返回错误.
func (a *authDatabase) GetTokensWithoutError(
ctx context.Context,
userID string,
platformID int,
) (map[string]int, error) {
// If the result is empty.
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
return a.cache.GetTokensWithoutError(ctx, userID, platformID)
}
// 创建token.
// Create Token.
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID)
if err != nil {
@@ -63,22 +57,23 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
}
var deleteTokenKey []string
for k, v := range tokens {
_, err = tokenverify.GetClaimFromToken(k, authverify.Secret())
_, err = tokenverify.GetClaimFromToken(k, authverify.Secret(a.config.Secret))
if err != nil || v != constant.NormalToken {
deleteTokenKey = append(deleteTokenKey, k)
}
}
if len(deleteTokenKey) != 0 {
err := a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
if err != nil {
return "", err
}
}
claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(a.accessSecret))
if err != nil {
return "", utils.Wrap(err, "")
return "", errs.Wrap(err, "token.SignedString")
}
return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken)
}
+13 -15
View File
@@ -17,24 +17,22 @@ package controller
import (
"context"
"github.com/OpenIMSDK/tools/pagination"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/pagination"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
type BlackDatabase interface {
// Create 增加黑名单
// Create add BlackList
Create(ctx context.Context, blacks []*relation.BlackModel) (err error)
// Delete 删除黑名单
// Delete delete BlackList
Delete(ctx context.Context, blacks []*relation.BlackModel) (err error)
// FindOwnerBlacks 获取黑名单列表
// FindOwnerBlacks get BlackList list
FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*relation.BlackModel, err error)
FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*relation.BlackModel, err error)
// CheckIn 检查user2是否在user1的黑名单列表中(inUser1Blacks==true) 检查user1是否在user2的黑名单列表中(inUser2Blacks==true)
// CheckIn Check whether user2 is in the black list of user1 (inUser1Blacks==true) Check whether user1 is in the black list of user2 (inUser2Blacks==true)
CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error)
}
@@ -47,7 +45,7 @@ func NewBlackDatabase(black relation.BlackModelInterface, cache cache.BlackCache
return &blackDatabase{black, cache}
}
// Create 增加黑名单.
// Create Add Blacklist.
func (b *blackDatabase) Create(ctx context.Context, blacks []*relation.BlackModel) (err error) {
if err := b.black.Create(ctx, blacks); err != nil {
return err
@@ -55,7 +53,7 @@ func (b *blackDatabase) Create(ctx context.Context, blacks []*relation.BlackMode
return b.deleteBlackIDsCache(ctx, blacks)
}
// Delete 删除黑名单.
// Delete Delete Blacklist.
func (b *blackDatabase) Delete(ctx context.Context, blacks []*relation.BlackModel) (err error) {
if err := b.black.Delete(ctx, blacks); err != nil {
return err
@@ -63,6 +61,7 @@ func (b *blackDatabase) Delete(ctx context.Context, blacks []*relation.BlackMode
return b.deleteBlackIDsCache(ctx, blacks)
}
// FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) deleteBlackIDsCache(ctx context.Context, blacks []*relation.BlackModel) (err error) {
cache := b.cache.NewCache()
for _, black := range blacks {
@@ -71,16 +70,13 @@ func (b *blackDatabase) deleteBlackIDsCache(ctx context.Context, blacks []*relat
return cache.ExecDel(ctx)
}
// FindOwnerBlacks 获取黑名单列表.
// FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*relation.BlackModel, err error) {
return b.black.FindOwnerBlacks(ctx, ownerUserID, pagination)
}
// CheckIn 检查user2是否在user1的黑名单列表中(inUser1Blacks==true) 检查user1是否在user2的黑名单列表中(inUser2Blacks==true).
func (b *blackDatabase) CheckIn(
ctx context.Context,
userID1, userID2 string,
) (inUser1Blacks bool, inUser2Blacks bool, err error) {
// FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error) {
userID1BlackIDs, err := b.cache.GetBlackIDs(ctx, userID1)
if err != nil {
return
@@ -93,10 +89,12 @@ func (b *blackDatabase) CheckIn(
return utils.IsContain(userID2, userID1BlackIDs), utils.IsContain(userID1, userID2BlackIDs), nil
}
// FindBlackIDs Get Blacklist List.
func (b *blackDatabase) FindBlackIDs(ctx context.Context, ownerUserID string) (blackIDs []string, err error) {
return b.cache.GetBlackIDs(ctx, ownerUserID)
}
// FindBlackInfos Get Blacklist List.
func (b *blackDatabase) FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*relation.BlackModel, err error) {
return b.black.FindOwnerBlackInfos(ctx, ownerUserID, userIDs)
}
+30 -24
View File
@@ -18,46 +18,52 @@ import (
"context"
"time"
"github.com/OpenIMSDK/tools/pagination"
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/pagination"
"github.com/OpenIMSDK/tools/tx"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
)
type ConversationDatabase interface {
// UpdateUserConversationFiled 更新用户该会话的属性信息
UpdateUsersConversationFiled(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error
// CreateConversation 创建一批新的会话
// UpdateUsersConversationField updates the properties of a conversation for specified users.
UpdateUsersConversationField(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error
// CreateConversation creates a batch of new conversations.
CreateConversation(ctx context.Context, conversations []*relationtb.ConversationModel) error
// SyncPeerUserPrivateConversation 同步对端私聊会话内部保证事务操作
// SyncPeerUserPrivateConversationTx ensures transactional operation while syncing private conversations between peers.
SyncPeerUserPrivateConversationTx(ctx context.Context, conversation []*relationtb.ConversationModel) error
// FindConversations 根据会话ID获取某个用户的多个会话
// FindConversations retrieves multiple conversations of a user by conversation IDs.
FindConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*relationtb.ConversationModel, error)
// FindRecvMsgNotNotifyUserIDs 获取超级大群开启免打扰的用户ID
//FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error)
// GetUserAllConversation 获取一个用户在服务器上所有的会话
// GetUserAllConversation fetches all conversations of a user on the server.
GetUserAllConversation(ctx context.Context, ownerUserID string) ([]*relationtb.ConversationModel, error)
// SetUserConversations 设置用户多个会话属性,如果会话不存在则创建,否则更新,内部保证原子性
// SetUserConversations sets multiple conversation properties for a user, creates new conversations if they do not exist, or updates them otherwise. This operation is atomic.
SetUserConversations(ctx context.Context, ownerUserID string, conversations []*relationtb.ConversationModel) error
// SetUsersConversationFiledTx 设置多个用户会话关于某个字段的更新操作,如果会话不存在则创建,否则更新,内部保证事务操作
SetUsersConversationFiledTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, filedMap map[string]any) error
// SetUsersConversationFieldTx updates a specific field for multiple users' conversations, creating new conversations if they do not exist, or updates them otherwise. This operation is
// transactional.
SetUsersConversationFieldTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, fieldMap map[string]any) error
// CreateGroupChatConversation creates a group chat conversation for the specified group ID and user IDs.
CreateGroupChatConversation(ctx context.Context, groupID string, userIDs []string) error
// GetConversationIDs retrieves conversation IDs for a given user.
GetConversationIDs(ctx context.Context, userID string) ([]string, error)
// GetUserConversationIDsHash gets the hash of conversation IDs for a given user.
GetUserConversationIDsHash(ctx context.Context, ownerUserID string) (hash uint64, err error)
// GetAllConversationIDs fetches all conversation IDs.
GetAllConversationIDs(ctx context.Context) ([]string, error)
// GetAllConversationIDsNumber returns the number of all conversation IDs.
GetAllConversationIDsNumber(ctx context.Context) (int64, error)
// PageConversationIDs paginates through conversation IDs based on the specified pagination settings.
PageConversationIDs(ctx context.Context, pagination pagination.Pagination) (conversationIDs []string, err error)
//GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error)
// GetConversationsByConversationID retrieves conversations by their IDs.
GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*relationtb.ConversationModel, error)
// GetConversationIDsNeedDestruct fetches conversations that need to be destructed based on specific criteria.
GetConversationIDsNeedDestruct(ctx context.Context) ([]*relationtb.ConversationModel, error)
// GetConversationNotReceiveMessageUserIDs gets user IDs for users in a conversation who have not received messages.
GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error)
//GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error)
//FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error)
}
func NewConversationDatabase(conversation relationtb.ConversationModelInterface, cache cache.ConversationCache, tx tx.CtxTx) ConversationDatabase {
@@ -74,7 +80,7 @@ type conversationDatabase struct {
tx tx.CtxTx
}
func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, filedMap map[string]any) (err error) {
func (c *conversationDatabase) SetUsersConversationFieldTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, fieldMap map[string]any) (err error) {
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.NewCache()
if conversation.GroupID != "" {
@@ -85,27 +91,27 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context,
return err
}
if len(haveUserIDs) > 0 {
_, err = c.conversationDB.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, filedMap)
_, err = c.conversationDB.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, fieldMap)
if err != nil {
return err
}
cache = cache.DelUsersConversation(conversation.ConversationID, haveUserIDs...)
if _, ok := filedMap["has_read_seq"]; ok {
if _, ok := fieldMap["has_read_seq"]; ok {
for _, userID := range haveUserIDs {
cache = cache.DelUserAllHasReadSeqs(userID, conversation.ConversationID)
}
}
if _, ok := filedMap["recv_msg_opt"]; ok {
if _, ok := fieldMap["recv_msg_opt"]; ok {
cache = cache.DelConversationNotReceiveMessageUserIDs(conversation.ConversationID)
}
}
NotUserIDs := utils.DifferenceString(haveUserIDs, userIDs)
log.ZDebug(ctx, "SetUsersConversationFiledTx", "NotUserIDs", NotUserIDs, "haveUserIDs", haveUserIDs, "userIDs", userIDs)
log.ZDebug(ctx, "SetUsersConversationFieldTx", "NotUserIDs", NotUserIDs, "haveUserIDs", haveUserIDs, "userIDs", userIDs)
var conversations []*relationtb.ConversationModel
now := time.Now()
for _, v := range NotUserIDs {
temp := new(relationtb.ConversationModel)
if err := utils.CopyStructFields(temp, conversation); err != nil {
if err = utils.CopyStructFields(temp, conversation); err != nil {
return err
}
temp.OwnerUserID = v
@@ -123,7 +129,7 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context,
})
}
func (c *conversationDatabase) UpdateUsersConversationFiled(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error {
func (c *conversationDatabase) UpdateUsersConversationField(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error {
_, err := c.conversationDB.UpdateByMap(ctx, userIDs, conversationID, args)
if err != nil {
return err
+47 -27
View File
@@ -16,17 +16,16 @@ package controller
import (
"context"
"fmt"
"time"
"github.com/OpenIMSDK/tools/pagination"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/mcontext"
"github.com/OpenIMSDK/tools/pagination"
"github.com/OpenIMSDK/tools/tx"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
@@ -89,20 +88,30 @@ func NewFriendDatabase(friend relation.FriendModelInterface, friendRequest relat
return &friendDatabase{friend: friend, friendRequest: friendRequest, cache: cache, tx: tx}
}
// ok 检查user2是否在user1的好友列表中(inUser1Friends==true) 检查user1是否在user2的好友列表中(inUser2Friends==true).
// CheckIn verifies if user2 is in user1's friend list (inUser1Friends returns true) and
// if user1 is in user2's friend list (inUser2Friends returns true).
func (f *friendDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Friends bool, inUser2Friends bool, err error) {
// Retrieve friend IDs of userID1 from the cache
userID1FriendIDs, err := f.cache.GetFriendIDs(ctx, userID1)
if err != nil {
err = fmt.Errorf("error retrieving friend IDs for user %s: %w", userID1, err)
return
}
// Retrieve friend IDs of userID2 from the cache
userID2FriendIDs, err := f.cache.GetFriendIDs(ctx, userID2)
if err != nil {
err = fmt.Errorf("error retrieving friend IDs for user %s: %w", userID2, err)
return
}
return utils.IsContain(userID2, userID1FriendIDs), utils.IsContain(userID1, userID2FriendIDs), nil
// Check if userID2 is in userID1's friend list and vice versa
inUser1Friends = utils.IsContain(userID2, userID1FriendIDs)
inUser2Friends = utils.IsContain(userID1, userID2FriendIDs)
return inUser1Friends, inUser2Friends, nil
}
// 增加或者更新好友申请 如果之前有记录则更新,没有记录则新增.
// AddFriendRequest adds or updates a friend request.
func (f *friendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error {
_, err := f.friendRequest.Take(ctx, fromUserID, toUserID)
@@ -126,11 +135,11 @@ func (f *friendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUse
})
}
// (1)先判断是否在好友表 (在不在都不返回错误) (2)对于不在好友列表的 插入即可.
// (1) First determine whether it is in the friends list (in or out does not return an error) (2) for not in the friends list can be inserted.
func (f *friendDatabase) BecomeFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, addSource int32) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error {
cache := f.cache.NewCache()
// 先find 找出重复的 去掉重复的
// User find friends
fs1, err := f.friend.FindFriends(ctx, ownerUserID, friendUserIDs)
if err != nil {
return err
@@ -170,26 +179,37 @@ func (f *friendDatabase) BecomeFriends(ctx context.Context, ownerUserID string,
})
}
// 拒绝好友申请 (1)检查是否有申请记录且为未处理状态 (没有记录返回错误) (2)修改申请记录 已拒绝.
func (f *friendDatabase) RefuseFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) (err error) {
// RefuseFriendRequest rejects a friend request. It first checks for an existing, unprocessed request.
// If no such request exists, it returns an error. Otherwise, it marks the request as refused.
func (f *friendDatabase) RefuseFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) error {
// Attempt to retrieve the friend request from the database.
fr, err := f.friendRequest.Take(ctx, friendRequest.FromUserID, friendRequest.ToUserID)
if err != nil {
return err
return fmt.Errorf("failed to retrieve friend request from %s to %s: %w", friendRequest.FromUserID, friendRequest.ToUserID, err)
}
// Check if the friend request has already been handled.
if fr.HandleResult != 0 {
return errs.ErrArgs.Wrap("the friend request has been processed")
return fmt.Errorf("friend request from %s to %s has already been processed", friendRequest.FromUserID, friendRequest.ToUserID)
}
log.ZDebug(ctx, "refuse friend request", "friendRequest db", fr, "friendRequest arg", friendRequest)
// Log the action of refusing the friend request for debugging and auditing purposes.
log.ZDebug(ctx, "Refusing friend request", map[string]interface{}{
"DB_FriendRequest": fr,
"Arg_FriendRequest": friendRequest,
})
// Mark the friend request as refused and update the handle time.
friendRequest.HandleResult = constant.FriendResponseRefuse
friendRequest.HandleTime = time.Now()
err = f.friendRequest.Update(ctx, friendRequest)
if err != nil {
return err
if err := f.friendRequest.Update(ctx, friendRequest); err != nil {
return fmt.Errorf("failed to update friend request from %s to %s as refused: %w", friendRequest.FromUserID, friendRequest.ToUserID, err)
}
return nil
}
// AgreeFriendRequest 同意好友申请 (1)检查是否有申请记录且为未处理状态 (没有记录返回错误) (2)检查是否好友(不返回错误) (3) 建立双向好友关系(存在的忽略).
// AgreeFriendRequest accepts a friend request. It first checks for an existing, unprocessed request.
func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error {
defer log.ZDebug(ctx, "return line")
@@ -227,10 +247,10 @@ func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *
return err
}
existsMap := utils.SliceSet(utils.Slice(exists, func(friend *relation.FriendModel) [2]string {
return [...]string{friend.OwnerUserID, friend.FriendUserID} // 自己 - 好友
return [...]string{friend.OwnerUserID, friend.FriendUserID} // My - Friend
}))
var adds []*relation.FriendModel
if _, ok := existsMap[[...]string{friendRequest.ToUserID, friendRequest.FromUserID}]; !ok { // 自己 - 好友
if _, ok := existsMap[[...]string{friendRequest.ToUserID, friendRequest.FromUserID}]; !ok { // My - Friend
adds = append(
adds,
&relation.FriendModel{
@@ -241,7 +261,7 @@ func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *
},
)
}
if _, ok := existsMap[[...]string{friendRequest.FromUserID, friendRequest.ToUserID}]; !ok { // 好友 - 自己
if _, ok := existsMap[[...]string{friendRequest.FromUserID, friendRequest.ToUserID}]; !ok { // My - Friend
adds = append(
adds,
&relation.FriendModel{
@@ -261,7 +281,7 @@ func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *
})
}
// 删除好友 外部判断是否好友关系.
// Delete removes a friend relationship. It is assumed that the external caller has verified the friendship status.
func (f *friendDatabase) Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) (err error) {
if err := f.friend.Delete(ctx, ownerUserID, friendUserIDs); err != nil {
return err
@@ -269,7 +289,7 @@ func (f *friendDatabase) Delete(ctx context.Context, ownerUserID string, friendU
return f.cache.DelFriendIDs(append(friendUserIDs, ownerUserID)...).ExecDel(ctx)
}
// 更新好友备注 零值也支持.
// UpdateRemark updates the remark for a friend. Zero value for remark is also supported.
func (f *friendDatabase) UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error) {
if err := f.friend.UpdateRemark(ctx, ownerUserID, friendUserID, remark); err != nil {
return err
@@ -277,27 +297,27 @@ func (f *friendDatabase) UpdateRemark(ctx context.Context, ownerUserID, friendUs
return f.cache.DelFriend(ownerUserID, friendUserID).ExecDel(ctx)
}
// 获取ownerUserID的好友列表 无结果不返回错误.
// PageOwnerFriends retrieves the list of friends for the ownerUserID. It does not return an error if the result is empty.
func (f *friendDatabase) PageOwnerFriends(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendModel, err error) {
return f.friend.FindOwnerFriends(ctx, ownerUserID, pagination)
}
// friendUserID在哪些人的好友列表中.
// PageInWhoseFriends identifies in whose friend lists the friendUserID appears.
func (f *friendDatabase) PageInWhoseFriends(ctx context.Context, friendUserID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendModel, err error) {
return f.friend.FindInWhoseFriends(ctx, friendUserID, pagination)
}
// 获取我发出去的好友申请 无结果不返回错误.
// PageFriendRequestFromMe retrieves friend requests sent by me. It does not return an error if the result is empty.
func (f *friendDatabase) PageFriendRequestFromMe(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendRequestModel, err error) {
return f.friendRequest.FindFromUserID(ctx, userID, pagination)
}
// 获取我收到的的好友申请 无结果不返回错误.
// PageFriendRequestToMe retrieves friend requests received by me. It does not return an error if the result is empty.
func (f *friendDatabase) PageFriendRequestToMe(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendRequestModel, err error) {
return f.friendRequest.FindToUserID(ctx, userID, pagination)
}
// 获取某人指定好友的信息 如果有好友不存在,也返回错误.
// FindFriendsWithError retrieves specified friends' information for ownerUserID. Returns an error if any friend does not exist.
func (f *friendDatabase) FindFriendsWithError(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*relation.FriendModel, err error) {
friends, err = f.friend.FindFriends(ctx, ownerUserID, friendUserIDs)
if err != nil {
+45 -15
View File
@@ -18,60 +18,90 @@ import (
"context"
"time"
"github.com/OpenIMSDK/tools/pagination"
"github.com/dtm-labs/rockscache"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/pagination"
"github.com/OpenIMSDK/tools/tx"
"github.com/OpenIMSDK/tools/utils"
"github.com/redis/go-redis/v9"
"github.com/dtm-labs/rockscache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/redis/go-redis/v9"
)
type GroupDatabase interface {
// Group
// CreateGroup creates new groups along with their members.
CreateGroup(ctx context.Context, groups []*relationtb.GroupModel, groupMembers []*relationtb.GroupMemberModel) error
// TakeGroup retrieves a single group by its ID.
TakeGroup(ctx context.Context, groupID string) (group *relationtb.GroupModel, err error)
// FindGroup retrieves multiple groups by their IDs.
FindGroup(ctx context.Context, groupIDs []string) (groups []*relationtb.GroupModel, err error)
// SearchGroup searches for groups based on a keyword and pagination settings, returns total count and groups.
SearchGroup(ctx context.Context, keyword string, pagination pagination.Pagination) (int64, []*relationtb.GroupModel, error)
// UpdateGroup updates the properties of a group identified by its ID.
UpdateGroup(ctx context.Context, groupID string, data map[string]any) error
DismissGroup(ctx context.Context, groupID string, deleteMember bool) error // 解散群,并删除群成员
// DismissGroup disbands a group and optionally removes its members based on the deleteMember flag.
DismissGroup(ctx context.Context, groupID string, deleteMember bool) error
// TakeGroupMember retrieves a specific group member by group ID and user ID.
TakeGroupMember(ctx context.Context, groupID string, userID string) (groupMember *relationtb.GroupMemberModel, err error)
// TakeGroupOwner retrieves the owner of a group by group ID.
TakeGroupOwner(ctx context.Context, groupID string) (*relationtb.GroupMemberModel, error)
FindGroupMembers(ctx context.Context, groupID string, userIDs []string) (groupMembers []*relationtb.GroupMemberModel, err error) // *
FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) (groupMembers []*relationtb.GroupMemberModel, err error) // *
FindGroupMemberRoleLevels(ctx context.Context, groupID string, roleLevels []int32) (groupMembers []*relationtb.GroupMemberModel, err error) // *
FindGroupMemberAll(ctx context.Context, groupID string) (groupMembers []*relationtb.GroupMemberModel, err error) // *
// FindGroupMembers retrieves members of a group filtered by user IDs.
FindGroupMembers(ctx context.Context, groupID string, userIDs []string) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupMemberUser retrieves groups that a user is a member of, filtered by group IDs.
FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupMemberRoleLevels retrieves group members filtered by their role levels within a group.
FindGroupMemberRoleLevels(ctx context.Context, groupID string, roleLevels []int32) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupMemberAll retrieves all members of a group.
FindGroupMemberAll(ctx context.Context, groupID string) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupsOwner retrieves the owners for multiple groups.
FindGroupsOwner(ctx context.Context, groupIDs []string) ([]*relationtb.GroupMemberModel, error)
// FindGroupMemberUserID retrieves the user IDs of all members in a group.
FindGroupMemberUserID(ctx context.Context, groupID string) ([]string, error)
// FindGroupMemberNum retrieves the number of members in a group.
FindGroupMemberNum(ctx context.Context, groupID string) (uint32, error)
// FindUserManagedGroupID retrieves group IDs managed by a user.
FindUserManagedGroupID(ctx context.Context, userID string) (groupIDs []string, err error)
// PageGroupRequest paginates through group requests for specified groups.
PageGroupRequest(ctx context.Context, groupIDs []string, pagination pagination.Pagination) (int64, []*relationtb.GroupRequestModel, error)
// GetGroupRoleLevelMemberIDs retrieves user IDs of group members with a specific role level.
GetGroupRoleLevelMemberIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error)
// PageGetJoinGroup paginates through groups that a user has joined.
PageGetJoinGroup(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*relationtb.GroupMemberModel, err error)
// PageGetGroupMember paginates through members of a group.
PageGetGroupMember(ctx context.Context, groupID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*relationtb.GroupMemberModel, err error)
// SearchGroupMember searches for group members based on a keyword, group ID, and pagination settings.
SearchGroupMember(ctx context.Context, keyword string, groupID string, pagination pagination.Pagination) (int64, []*relationtb.GroupMemberModel, error)
// HandlerGroupRequest processes a group join request with a specified result.
HandlerGroupRequest(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32, member *relationtb.GroupMemberModel) error
// DeleteGroupMember removes specified users from a group.
DeleteGroupMember(ctx context.Context, groupID string, userIDs []string) error
// MapGroupMemberUserID maps group IDs to their members' simplified user IDs.
MapGroupMemberUserID(ctx context.Context, groupIDs []string) (map[string]*relationtb.GroupSimpleUserID, error)
// MapGroupMemberNum maps group IDs to their member count.
MapGroupMemberNum(ctx context.Context, groupIDs []string) (map[string]uint32, error)
TransferGroupOwner(ctx context.Context, groupID string, oldOwnerUserID, newOwnerUserID string, roleLevel int32) error // 转让群
// TransferGroupOwner transfers the ownership of a group to another user.
TransferGroupOwner(ctx context.Context, groupID string, oldOwnerUserID, newOwnerUserID string, roleLevel int32) error
// UpdateGroupMember updates properties of a group member.
UpdateGroupMember(ctx context.Context, groupID string, userID string, data map[string]any) error
// UpdateGroupMembers batch updates properties of group members.
UpdateGroupMembers(ctx context.Context, data []*relationtb.BatchUpdateGroupMember) error
// GroupRequest
// CreateGroupRequest creates new group join requests.
CreateGroupRequest(ctx context.Context, requests []*relationtb.GroupRequestModel) error
// TakeGroupRequest retrieves a specific group join request.
TakeGroupRequest(ctx context.Context, groupID string, userID string) (*relationtb.GroupRequestModel, error)
// FindGroupRequests retrieves multiple group join requests.
FindGroupRequests(ctx context.Context, groupID string, userIDs []string) ([]*relationtb.GroupRequestModel, error)
// PageGroupRequestUser paginates through group join requests made by a user.
PageGroupRequestUser(ctx context.Context, userID string, pagination pagination.Pagination) (int64, []*relationtb.GroupRequestModel, error)
// 获取群总数
// CountTotal counts the total number of groups as of a certain date.
CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// 获取范围内群增量
// CountRangeEverydayTotal counts the daily group creation total within a specified date range.
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
// DeleteGroupMemberHash deletes the hash entries for group members in specified groups.
DeleteGroupMemberHash(ctx context.Context, groupIDs []string) error
}
+73 -51
View File
@@ -21,26 +21,20 @@ import (
"time"
"github.com/OpenIMSDK/protocol/constant"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
"github.com/redis/go-redis/v9"
pbmsg "github.com/OpenIMSDK/protocol/msg"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"go.mongodb.org/mongo-driver/mongo"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/convert"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
unrelationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation"
"github.com/openimsdk/open-im-server/v3/pkg/common/kafka"
pbmsg "github.com/OpenIMSDK/protocol/msg"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/mongo"
)
const (
@@ -48,33 +42,33 @@ const (
updateKeyRevoke
)
// CommonMsgDatabase defines the interface for message database operations.
type CommonMsgDatabase interface {
// 批量插入消息
// BatchInsertChat2DB inserts a batch of messages into the database for a specific conversation.
BatchInsertChat2DB(ctx context.Context, conversationID string, msgs []*sdkws.MsgData, currentMaxSeq int64) error
// 撤回消息
// RevokeMsg revokes a message in a conversation.
RevokeMsg(ctx context.Context, conversationID string, seq int64, revoke *unrelationtb.RevokeModel) error
// mark as read
// MarkSingleChatMsgsAsRead marks messages as read for a single chat by sequence numbers.
MarkSingleChatMsgsAsRead(ctx context.Context, userID string, conversationID string, seqs []int64) error
// 刪除redis中消息缓存
// DeleteMessagesFromCache deletes message caches from Redis by sequence numbers.
DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error
// DelUserDeleteMsgsList deletes user's message deletion list.
DelUserDeleteMsgsList(ctx context.Context, conversationID string, seqs []int64)
// incrSeq然后批量插入缓存
// BatchInsertChat2Cache increments the sequence number and then batch inserts messages into the cache.
BatchInsertChat2Cache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (seq int64, isNewConversation bool, err error)
// 通过seqList获取mongo中写扩散消息
// GetMsgBySeqsRange retrieves messages from MongoDB by a range of sequence numbers.
GetMsgBySeqsRange(ctx context.Context, userID string, conversationID string, begin, end, num, userMaxSeq int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error)
// 通过seqList获取大群在 mongo里面的消息
// GetMsgBySeqs retrieves messages for large groups from MongoDB by sequence numbers.
GetMsgBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error)
// 删除会话消息重置最小seq, remainTime为消息保留的时间单位秒,超时消息删除, 传0删除所有消息(此方法不删除redis cache)
// DeleteConversationMsgsAndSetMinSeq deletes conversation messages and resets the minimum sequence number. If `remainTime` is 0, all messages are deleted (this method does not delete Redis
// cache).
DeleteConversationMsgsAndSetMinSeq(ctx context.Context, conversationID string, remainTime int64) error
// 用户标记删除过期消息返回标记删除的seq列表
// UserMsgsDestruct marks messages for deletion based on destruct time and returns a list of sequence numbers for marked messages.
UserMsgsDestruct(ctx context.Context, userID string, conversationID string, destructTime int64, lastMsgDestructTime time.Time) (seqs []int64, err error)
// 用户根据seq删除消息
// DeleteUserMsgsBySeqs allows a user to delete messages based on sequence numbers.
DeleteUserMsgsBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) error
// 物理删除消息置空
// DeleteMsgsPhysicalBySeqs physically deletes messages by emptying them based on sequence numbers.
DeleteMsgsPhysicalBySeqs(ctx context.Context, conversationID string, seqs []int64) error
SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error
GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error)
GetMaxSeq(ctx context.Context, conversationID string) (int64, error)
@@ -126,21 +120,49 @@ type CommonMsgDatabase interface {
ConvertMsgsDocLen(ctx context.Context, conversationIDs []string)
}
func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheModel cache.MsgModel) CommonMsgDatabase {
func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheModel cache.MsgModel, config *config.GlobalConfig) (CommonMsgDatabase, error) {
producerConfig := &kafka.ProducerConfig{
ProducerAck: config.Kafka.ProducerAck,
CompressType: config.Kafka.CompressType,
Username: config.Kafka.Username,
Password: config.Kafka.Password,
}
var tlsConfig *kafka.TLSConfig
if config.Kafka.TLS != nil {
tlsConfig = &kafka.TLSConfig{
CACrt: config.Kafka.TLS.CACrt,
ClientCrt: config.Kafka.TLS.ClientCrt,
ClientKey: config.Kafka.TLS.ClientKey,
ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd,
InsecureSkipVerify: false,
}
}
producerToRedis, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.LatestMsgToRedis.Topic, producerConfig, tlsConfig)
if err != nil {
return nil, err
}
producerToMongo, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToMongo.Topic, producerConfig, tlsConfig)
if err != nil {
return nil, err
}
producerToPush, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToPush.Topic, producerConfig, tlsConfig)
if err != nil {
return nil, err
}
return &commonMsgDatabase{
msgDocDatabase: msgDocModel,
cache: cacheModel,
producer: kafka.NewKafkaProducer(config.Config.Kafka.Addr, config.Config.Kafka.LatestMsgToRedis.Topic),
producerToMongo: kafka.NewKafkaProducer(config.Config.Kafka.Addr, config.Config.Kafka.MsgToMongo.Topic),
producerToPush: kafka.NewKafkaProducer(config.Config.Kafka.Addr, config.Config.Kafka.MsgToPush.Topic),
}
producer: producerToRedis,
producerToMongo: producerToMongo,
producerToPush: producerToPush,
}, nil
}
func InitCommonMsgDatabase(rdb redis.UniversalClient, database *mongo.Database) CommonMsgDatabase {
cacheModel := cache.NewMsgCacheModel(rdb)
func InitCommonMsgDatabase(rdb redis.UniversalClient, database *mongo.Database, config *config.GlobalConfig) (CommonMsgDatabase, error) {
cacheModel := cache.NewMsgCacheModel(rdb, config)
msgDocModel := unrelation.NewMsgMongoDriver(database)
CommonMsgDatabase := NewCommonMsgDatabase(msgDocModel, cacheModel)
return CommonMsgDatabase
return NewCommonMsgDatabase(msgDocModel, cacheModel, config)
}
type commonMsgDatabase struct {
@@ -189,7 +211,7 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
}
num := db.msg.GetSingleGocMsgNum()
// num = 100
for i, field := range fields { // 检查类型
for i, field := range fields { // Check the type of the field
var ok bool
switch key {
case updateKeyMsg:
@@ -207,7 +229,7 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
return errs.ErrInternalServer.Wrap("field type is invalid")
}
}
// 返回值为true表示数据库存在该文档,false表示数据库不存在该文档
// Returns true if the document exists in the database, false if the document does not exist in the database
updateMsgModel := func(seq int64, i int) (bool, error) {
var (
res *mongo.UpdateResult
@@ -229,21 +251,21 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
}
tryUpdate := true
for i := 0; i < len(fields); i++ {
seq := firstSeq + int64(i) // 当前seq
seq := firstSeq + int64(i) // Current sequence number
if tryUpdate {
matched, err := updateMsgModel(seq, i)
if err != nil {
return err
}
if matched {
continue // 匹配到了,继续下一个(不一定修改)
continue // The current data has been updated, skip the current data
}
}
doc := unrelationtb.MsgDocModel{
DocID: db.msg.GetDocID(conversationID, seq),
Msg: make([]*unrelationtb.MsgInfoModel, num),
}
var insert int // 插入的数量
var insert int // Inserted data number
for j := i; j < len(fields); j++ {
seq = firstSeq + int64(j)
if db.msg.GetDocID(conversationID, seq) != doc.DocID {
@@ -272,14 +294,14 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
}
if err := db.msgDocDatabase.Create(ctx, &doc); err != nil {
if mongo.IsDuplicateKeyError(err) {
i-- // 存在并发,重试当前数据
tryUpdate = true // 以修改模式
i-- // already inserted
tryUpdate = true // next block use update mode
continue
}
return err
}
tryUpdate = false // 当前以插入成功,下一块优先插入模式
i += insert - 1 // 跳过已插入的数据
tryUpdate = false // The current block is inserted successfully, and the next block is inserted preferentially
i += insert - 1 // Skip the inserted data
}
return nil
}
@@ -392,12 +414,12 @@ func (db *commonMsgDatabase) BatchInsertChat2Cache(ctx context.Context, conversa
log.ZError(ctx, "db.cache.SetMaxSeq error", err, "conversationID", conversationID)
prommetrics.SeqSetFailedCounter.Inc()
}
err2 := db.cache.SetHasReadSeqs(ctx, conversationID, userSeqMap)
err = db.cache.SetHasReadSeqs(ctx, conversationID, userSeqMap)
if err != nil {
log.ZError(ctx, "SetHasReadSeqs error", err2, "userSeqMap", userSeqMap, "conversationID", conversationID)
log.ZError(ctx, "SetHasReadSeqs error", err, "userSeqMap", userSeqMap, "conversationID", conversationID)
prommetrics.SeqSetFailedCounter.Inc()
}
return lastMaxSeq, isNew, utils.Wrap(err, "")
return lastMaxSeq, isNew, errs.Wrap(err)
}
func (db *commonMsgDatabase) getMsgBySeqs(ctx context.Context, userID, conversationID string, seqs []int64) (totalMsgs []*sdkws.MsgData, err error) {
@@ -743,7 +765,7 @@ func (db *commonMsgDatabase) UserMsgsDestruct(ctx context.Context, userID string
log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index)
}
}
// 获取报错,或者获取不到了,物理删除并且返回seq delMongoMsgsPhysical(delStruct.delDocIDList), 结束递归
// If an error is reported, or the error cannot be obtained, it is physically deleted and seq delMongoMsgsPhysical(delStruct.delDocIDList) is returned to end the recursion
break
}
index++
@@ -798,7 +820,7 @@ func (d *delMsgRecursionStruct) getSetMinSeq() int64 {
// index 0....19(del) 20...69
// seq 70
// set minSeq 21
// recursion 删除list并且返回设置的最小seq.
// recursion deletes the list and returns the set minimum seq.
func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversationID string, index int64, delStruct *delMsgRecursionStruct, remainTime int64) (int64, error) {
// find from oldest list
msgDocModel, err := db.msgDocDatabase.GetMsgDocModelByIndex(ctx, conversationID, index, 1)
@@ -810,7 +832,7 @@ func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversatio
log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index)
}
}
// 获取报错,或者获取不到了,物理删除并且返回seq delMongoMsgsPhysical(delStruct.delDocIDList), 结束递归
// If an error is reported, or the error cannot be obtained, it is physically deleted and seq delMongoMsgsPhysical(delStruct.delDocIDList) is returned to end the recursion
err = db.msgDocDatabase.DeleteDocs(ctx, delStruct.delDocIDs)
if err != nil {
return 0, err
@@ -835,7 +857,7 @@ func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversatio
}
}
if len(delMsgIndexs) > 0 {
if err := db.msgDocDatabase.DeleteMsgsInOneDocByIndex(ctx, msgDocModel.DocID, delMsgIndexs); err != nil {
if err = db.msgDocDatabase.DeleteMsgsInOneDocByIndex(ctx, msgDocModel.DocID, delMsgIndexs); err != nil {
log.ZError(ctx, "deleteMsgRecursion DeleteMsgsInOneDocByIndex failed", err, "conversationID", conversationID, "index", index)
}
delStruct.minSeq = int64(msgDocModel.Msg[delMsgIndexs[len(delMsgIndexs)-1]].Msg.Seq)
+27 -25
View File
@@ -33,27 +33,28 @@ import (
)
func Test_BatchInsertChat2DB(t *testing.T) {
config.Config.Mongo.Address = []string{"192.168.44.128:37017"}
// config.Config.Mongo.Timeout = 60
config.Config.Mongo.Database = "openIM"
// config.Config.Mongo.Source = "admin"
config.Config.Mongo.Username = "root"
config.Config.Mongo.Password = "openIM123"
config.Config.Mongo.MaxPoolSize = 100
config.Config.RetainChatRecords = 3650
config.Config.ChatRecordsClearTime = "0 2 * * 3"
conf := config.NewGlobalConfig()
conf.Mongo.Address = []string{"192.168.44.128:37017"}
// conf.Mongo.Timeout = 60
conf.Mongo.Database = "openIM"
// conf.Mongo.Source = "admin"
conf.Mongo.Username = "root"
conf.Mongo.Password = "openIM123"
conf.Mongo.MaxPoolSize = 100
conf.RetainChatRecords = 3650
conf.ChatRecordsClearTime = "0 2 * * 3"
mongo, err := unrelation.NewMongo()
mongo, err := unrelation.NewMongo(conf)
if err != nil {
t.Fatal(err)
}
err = mongo.GetDatabase().Client().Ping(context.Background(), nil)
err = mongo.GetDatabase(conf.Mongo.Database).Client().Ping(context.Background(), nil)
if err != nil {
panic(err)
}
db := &commonMsgDatabase{
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase()),
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase(conf.Mongo.Database)),
}
//ctx := context.Background()
@@ -70,7 +71,7 @@ func Test_BatchInsertChat2DB(t *testing.T) {
//}
_ = db.BatchInsertChat2DB
c := mongo.GetDatabase().Collection("msg")
c := mongo.GetDatabase(conf.Mongo.Database).Collection("msg")
ch := make(chan int)
rand.Seed(time.Now().UnixNano())
@@ -144,26 +145,27 @@ func Test_BatchInsertChat2DB(t *testing.T) {
}
func GetDB() *commonMsgDatabase {
config.Config.Mongo.Address = []string{"203.56.175.233:37017"}
// config.Config.Mongo.Timeout = 60
config.Config.Mongo.Database = "openim_v3"
// config.Config.Mongo.Source = "admin"
config.Config.Mongo.Username = "root"
config.Config.Mongo.Password = "openIM123"
config.Config.Mongo.MaxPoolSize = 100
config.Config.RetainChatRecords = 3650
config.Config.ChatRecordsClearTime = "0 2 * * 3"
conf := config.NewGlobalConfig()
conf.Mongo.Address = []string{"203.56.175.233:37017"}
// conf.Mongo.Timeout = 60
conf.Mongo.Database = "openim_v3"
// conf.Mongo.Source = "admin"
conf.Mongo.Username = "root"
conf.Mongo.Password = "openIM123"
conf.Mongo.MaxPoolSize = 100
conf.RetainChatRecords = 3650
conf.ChatRecordsClearTime = "0 2 * * 3"
mongo, err := unrelation.NewMongo()
mongo, err := unrelation.NewMongo(conf)
if err != nil {
panic(err)
}
err = mongo.GetDatabase().Client().Ping(context.Background(), nil)
err = mongo.GetDatabase(conf.Mongo.Database).Client().Ping(context.Background(), nil)
if err != nil {
panic(err)
}
return &commonMsgDatabase{
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase()),
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase(conf.Mongo.Database)),
}
}
+1 -2
View File
@@ -19,12 +19,11 @@ import (
"path/filepath"
"time"
"github.com/redis/go-redis/v9"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3/cont"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/redis/go-redis/v9"
)
type S3Database interface {
+1 -8
View File
@@ -19,7 +19,6 @@ import (
"time"
"github.com/OpenIMSDK/tools/pagination"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
@@ -63,13 +62,7 @@ func NewThirdDatabase(cache cache.MsgModel, logdb relation.LogInterface) ThirdDa
return &thirdDatabase{cache: cache, logdb: logdb}
}
func (t *thirdDatabase) FcmUpdateToken(
ctx context.Context,
account string,
platformID int,
fcmToken string,
expireTime int64,
) error {
func (t *thirdDatabase) FcmUpdateToken(ctx context.Context, account string, platformID int, fcmToken string, expireTime int64) error {
return t.cache.SetFcmToken(ctx, account, platformID, fcmToken, expireTime)
}
+4 -9
View File
@@ -18,19 +18,14 @@ import (
"context"
"time"
"github.com/OpenIMSDK/protocol/user"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/pagination"
"github.com/OpenIMSDK/tools/tx"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/OpenIMSDK/protocol/user"
unrelationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
unrelationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
)
type UserDatabase interface {
View File
+1 -2
View File
@@ -19,11 +19,10 @@ import (
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/OpenIMSDK/tools/pagination"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
func NewBlackMongo(db *mongo.Database) (relation.BlackModelInterface, error) {
+3 -3
View File
@@ -19,13 +19,13 @@ import (
"time"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/OpenIMSDK/tools/pagination"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
func NewConversationMongo(db *mongo.Database) (*ConversationMgo, error) {
@@ -38,7 +38,7 @@ func NewConversationMongo(db *mongo.Database) (*ConversationMgo, error) {
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, err
return nil, errs.Wrap(err)
}
return &ConversationMgo{coll: coll}, nil
}
+2 -4
View File
@@ -19,12 +19,10 @@ import (
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/OpenIMSDK/tools/pagination"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/mongo/options"
)
// FriendMgo implements FriendModelInterface using MongoDB as the storage backend.
+2 -4
View File
@@ -19,12 +19,10 @@ import (
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/OpenIMSDK/tools/pagination"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewFriendRequestMongo(db *mongo.Database) (relation.FriendRequestModelInterface, error) {
+3 -3
View File
@@ -18,13 +18,13 @@ import (
"context"
"time"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/OpenIMSDK/tools/pagination"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
func NewGroupMongo(db *mongo.Database) (relation.GroupModelInterface, error) {
@@ -36,7 +36,7 @@ func NewGroupMongo(db *mongo.Database) (relation.GroupModelInterface, error) {
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, err
return nil, errs.Wrap(err)
}
return &GroupMgo{coll: coll}, nil
}
+3 -3
View File
@@ -18,13 +18,13 @@ import (
"context"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/OpenIMSDK/tools/pagination"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
func NewGroupMember(db *mongo.Database) (relation.GroupMemberModelInterface, error) {
@@ -37,7 +37,7 @@ func NewGroupMember(db *mongo.Database) (relation.GroupMemberModelInterface, err
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, err
return nil, errs.Wrap(err)
}
return &GroupMemberMgo{coll: coll}, nil
}
+3 -3
View File
@@ -17,13 +17,13 @@ package mgo
import (
"context"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/OpenIMSDK/tools/pagination"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
func NewGroupRequestMgo(db *mongo.Database) (relation.GroupRequestModelInterface, error) {
@@ -36,7 +36,7 @@ func NewGroupRequestMgo(db *mongo.Database) (relation.GroupRequestModelInterface
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, err
return nil, errs.Wrap(err)
}
return &GroupRequestMgo{coll: coll}, nil
}
+1 -2
View File
@@ -20,11 +20,10 @@ import (
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/OpenIMSDK/tools/pagination"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
func NewLogMongo(db *mongo.Database) (relation.LogInterface, error) {
+3 -3
View File
@@ -16,13 +16,13 @@ package mgo
import (
"context"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
func NewS3Mongo(db *mongo.Database) (relation.ObjectInfoModelInterface, error) {
@@ -34,7 +34,7 @@ func NewS3Mongo(db *mongo.Database) (relation.ObjectInfoModelInterface, error) {
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, err
return nil, errs.Wrap(err)
}
return &S3Mongo{coll: coll}, nil
}
+10 -12
View File
@@ -20,15 +20,13 @@ import (
"github.com/OpenIMSDK/protocol/user"
"github.com/OpenIMSDK/tools/errs"
"go.mongodb.org/mongo-driver/bson/primitive"
"github.com/OpenIMSDK/tools/mgoutil"
"github.com/OpenIMSDK/tools/pagination"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
func NewUserMongo(db *mongo.Database) (relation.UserModelInterface, error) {
@@ -40,7 +38,7 @@ func NewUserMongo(db *mongo.Database) (relation.UserModelInterface, error) {
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, err
return nil, errs.Wrap(err)
}
return &UserMgo{coll: coll}, nil
}
@@ -159,7 +157,7 @@ func (u *UserMgo) AddUserCommand(ctx context.Context, userID string, Type int32,
}
_, err := collection.InsertOne(ctx, doc)
return err
return errs.Wrap(err)
}
func (u *UserMgo) DeleteUserCommand(ctx context.Context, userID string, Type int32, UUID string) error {
@@ -172,7 +170,7 @@ func (u *UserMgo) DeleteUserCommand(ctx context.Context, userID string, Type int
// No records found to update
return errs.Wrap(errs.ErrRecordNotFound)
}
return err
return errs.Wrap(err)
}
func (u *UserMgo) UpdateUserCommand(ctx context.Context, userID string, Type int32, UUID string, val map[string]any) error {
if len(val) == 0 {
@@ -186,7 +184,7 @@ func (u *UserMgo) UpdateUserCommand(ctx context.Context, userID string, Type int
result, err := collection.UpdateOne(ctx, filter, update)
if err != nil {
return err
return errs.Wrap(err)
}
if result.MatchedCount == 0 {
@@ -235,7 +233,7 @@ func (u *UserMgo) GetUserCommand(ctx context.Context, userID string, Type int32)
}
if err := cursor.Err(); err != nil {
return nil, err
return nil, errs.Wrap(err)
}
return commands, nil
@@ -246,7 +244,7 @@ func (u *UserMgo) GetAllUserCommand(ctx context.Context, userID string) ([]*user
cursor, err := collection.Find(ctx, filter)
if err != nil {
return nil, err
return nil, errs.Wrap(err)
}
defer cursor.Close(ctx)
@@ -263,7 +261,7 @@ func (u *UserMgo) GetAllUserCommand(ctx context.Context, userID string) ([]*user
}
if err := cursor.Decode(&document); err != nil {
return nil, err
return nil, errs.Wrap(err)
}
commandInfo := &user.AllCommandInfoResp{
@@ -278,7 +276,7 @@ func (u *UserMgo) GetAllUserCommand(ctx context.Context, userID string) ([]*user
}
if err := cursor.Err(); err != nil {
return nil, err
return nil, errs.Wrap(err)
}
return commands, nil
}
+20 -6
View File
@@ -15,10 +15,24 @@
package cont
const (
hashPath = "openim/data/hash/"
tempPath = "openim/temp/"
DirectPath = "openim/direct"
UploadTypeMultipart = 1 // 分片上传
UploadTypePresigned = 2 // 预签名上传
partSeparator = ","
// hashPath defines the storage path for hash data within the 'openim' directory.
hashPath = "openim/data/hash/"
// tempPath specifies the directory for temporary files in the 'openim' structure.
tempPath = "openim/temp/"
// DirectPath indicates the directory for direct uploads or access within the 'openim' structure.
DirectPath = "openim/direct"
// UploadTypeMultipart represents the identifier for multipart uploads,
// allowing large files to be uploaded in chunks.
UploadTypeMultipart = 1
// UploadTypePresigned signifies the use of presigned URLs for uploads,
// facilitating secure, authorized file transfers without requiring direct access to the storage credentials.
UploadTypePresigned = 2
// partSeparator is used as a delimiter in multipart upload processes,
// separating individual file parts.
partSeparator = ","
)
+7 -10
View File
@@ -24,13 +24,10 @@ import (
"strings"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/google/uuid"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/google/uuid"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
)
@@ -106,7 +103,7 @@ func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64
partNumber++
}
if maxParts > 0 && partNumber > 0 && partNumber < maxParts {
return nil, errors.New(fmt.Sprintf("too many parts: %d", partNumber))
return nil, fmt.Errorf("too many parts: %d", partNumber)
}
if info, err := c.StatObject(ctx, c.HashPath(hash)); err == nil {
return nil, &HashAlreadyExistsError{Object: info}
@@ -114,7 +111,7 @@ func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64
return nil, err
}
if size <= partSize {
// 预签名上传
// Pre-signed upload
key := path.Join(tempPath, c.NowPath(), fmt.Sprintf("%s_%d_%s.presigned", hash, size, c.UUID()))
rawURL, err := c.impl.PresignedPutObject(ctx, key, expire)
if err != nil {
@@ -139,7 +136,7 @@ func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64
},
}, nil
} else {
// 分片上传
// Fragment upload
upload, err := c.impl.InitiateMultipartUpload(ctx, c.HashPath(hash))
if err != nil {
return nil, err
@@ -206,7 +203,7 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa
ETag: part,
}
}
// todo: 验证大小
// todo: Validation size
result, err := c.impl.CompleteMultipartUpload(ctx, upload.ID, upload.Key, parts)
if err != nil {
return nil, err
@@ -225,7 +222,7 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa
if md5val := hex.EncodeToString(md5Sum[:]); md5val != upload.Hash {
return nil, errs.ErrArgs.Wrap(fmt.Sprintf("md5 mismatching %s != %s", md5val, upload.Hash))
}
// 防止在这个时候,并发操作,导致文件被覆盖
// Prevents concurrent operations at this time that cause files to be overwritten
copyInfo, err := c.impl.CopyObject(ctx, uploadInfo.Key, upload.Key+"."+c.UUID())
if err != nil {
return nil, err
+8 -3
View File
@@ -17,9 +17,14 @@ package cont
import "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
type InitiateUploadResult struct {
UploadID string `json:"uploadID"` // 上传ID
PartSize int64 `json:"partSize"` // 分片大小
Sign *s3.AuthSignResult `json:"sign"` // 分片信息
// UploadID uniquely identifies the upload session for tracking and management purposes.
UploadID string `json:"uploadID"`
// PartSize specifies the size of each part in a multipart upload. This is relevant for breaking down large uploads into manageable pieces.
PartSize int64 `json:"partSize"`
// Sign contains the authentication and signature information necessary for securely uploading each part. This could include signed URLs or tokens.
Sign *s3.AuthSignResult `json:"sign"`
}
type UploadResult struct {
+14 -11
View File
@@ -29,10 +29,9 @@ import (
"strings"
"time"
"github.com/tencentyun/cos-go-sdk-v5"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/OpenIMSDK/tools/errs"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
"github.com/tencentyun/cos-go-sdk-v5"
)
const (
@@ -51,13 +50,15 @@ const (
const successCode = http.StatusOK
const (
videoSnapshotImagePng = "png"
videoSnapshotImageJpg = "jpg"
)
type Config struct {
BucketURL string
SecretID string
SecretKey string
SessionToken string
PublicRead bool
}
func NewCos() (s3.Interface, error) {
conf := config.Config.Object.Cos
func NewCos(conf Config) (s3.Interface, error) {
u, err := url.Parse(conf.BucketURL)
if err != nil {
panic(err)
@@ -70,6 +71,7 @@ func NewCos() (s3.Interface, error) {
},
})
return &Cos{
publicRead: conf.PublicRead,
copyURL: u.Host + "/",
client: client,
credential: client.GetCredential(),
@@ -77,6 +79,7 @@ func NewCos() (s3.Interface, error) {
}
type Cos struct {
publicRead bool
copyURL string
client *cos.Client
credential *cos.Credential
@@ -227,7 +230,7 @@ func (c *Cos) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO
}
func (c *Cos) IsNotFound(err error) bool {
switch e := err.(type) {
switch e := errs.Unwrap(err).(type) {
case *cos.ErrorResponse:
return e.Response.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey"
default:
@@ -328,7 +331,7 @@ func (c *Cos) AccessURL(ctx context.Context, name string, expire time.Duration,
}
func (c *Cos) getPresignedURL(ctx context.Context, name string, expire time.Duration, opt *cos.PresignedURLOptions) (*url.URL, error) {
if !config.Config.Object.Cos.PublicRead {
if !c.publicRead {
return c.client.Object.GetPresignedURL(ctx, http.MethodGet, name, c.credential.SecretID, c.credential.SecretKey, expire, opt)
}
return c.client.Object.GetObjectURL(name), nil
+9 -9
View File
@@ -47,27 +47,27 @@ func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image {
imgWidth := bounds.Max.X
imgHeight := bounds.Max.Y
// 计算缩放比例
// Calculating scaling
scaleWidth := float64(maxWidth) / float64(imgWidth)
scaleHeight := float64(maxHeight) / float64(imgHeight)
// 如果都为0,则不缩放,返回原始图片
// If both are 0, then no scaling is done and the original image is returned
if maxWidth == 0 && maxHeight == 0 {
return img
}
// 如果宽度和高度都大于0,则选择较小的缩放比例,以保持宽高比
// If both width and height are greater than 0, select a smaller zoom ratio to maintain the aspect ratio
if maxWidth > 0 && maxHeight > 0 {
scale := scaleWidth
if scaleHeight < scaleWidth {
scale = scaleHeight
}
// 计算缩略图尺寸
// Calculate Thumbnail Size
thumbnailWidth := int(float64(imgWidth) * scale)
thumbnailHeight := int(float64(imgHeight) * scale)
// 使用"image"库的Resample方法生成缩略图
// Thumbnails are generated using the Resample method of the "image" library.
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
for y := 0; y < thumbnailHeight; y++ {
for x := 0; x < thumbnailWidth; x++ {
@@ -80,12 +80,12 @@ func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image {
return thumbnail
}
// 如果只指定了宽度或高度,则根据最大不超过的规则生成缩略图
// If only width or height is specified, thumbnails are generated based on the maximum not to exceed rule
if maxWidth > 0 {
thumbnailWidth := maxWidth
thumbnailHeight := int(float64(imgHeight) * scaleWidth)
// 使用"image"库的Resample方法生成缩略图
// Thumbnails are generated using the Resample method of the "image" library.
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
for y := 0; y < thumbnailHeight; y++ {
for x := 0; x < thumbnailWidth; x++ {
@@ -102,7 +102,7 @@ func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image {
thumbnailWidth := int(float64(imgWidth) * scaleHeight)
thumbnailHeight := maxHeight
// 使用"image"库的Resample方法生成缩略图
// Thumbnails are generated using the Resample method of the "image" library.
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
for y := 0; y < thumbnailHeight; y++ {
for x := 0; x < thumbnailWidth; x++ {
@@ -115,6 +115,6 @@ func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image {
return thumbnail
}
// 默认情况下,返回原始图片
// By default, the original image is returned
return img
}
+36 -30
View File
@@ -29,14 +29,12 @@ import (
"time"
"unsafe"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/minio/minio-go/v7/pkg/signer"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
)
@@ -45,7 +43,7 @@ const (
)
const (
minPartSize int64 = 1024 * 1024 * 5 // 1MB
minPartSize int64 = 1024 * 1024 * 5 // 5MB
maxPartSize int64 = 1024 * 1024 * 1024 * 5 // 5GB
maxNumSize int64 = 10000
)
@@ -59,13 +57,23 @@ const (
const successCode = http.StatusOK
func NewMinio(cache cache.MinioCache) (s3.Interface, error) {
u, err := url.Parse(config.Config.Object.Minio.Endpoint)
type Config struct {
Bucket string
Endpoint string
AccessKeyID string
SecretAccessKey string
SessionToken string
SignEndpoint string
PublicRead bool
}
func NewMinio(cache cache.MinioCache, conf Config) (s3.Interface, error) {
u, err := url.Parse(conf.Endpoint)
if err != nil {
return nil, err
}
opts := &minio.Options{
Creds: credentials.NewStaticV4(config.Config.Object.Minio.AccessKeyID, config.Config.Object.Minio.SecretAccessKey, config.Config.Object.Minio.SessionToken),
Creds: credentials.NewStaticV4(conf.AccessKeyID, conf.SecretAccessKey, conf.SessionToken),
Secure: u.Scheme == "https",
}
client, err := minio.New(u.Host, opts)
@@ -73,26 +81,27 @@ func NewMinio(cache cache.MinioCache) (s3.Interface, error) {
return nil, err
}
m := &Minio{
bucket: config.Config.Object.Minio.Bucket,
conf: conf,
bucket: conf.Bucket,
core: &minio.Core{Client: client},
lock: &sync.Mutex{},
init: false,
cache: cache,
}
if config.Config.Object.Minio.SignEndpoint == "" || config.Config.Object.Minio.SignEndpoint == config.Config.Object.Minio.Endpoint {
if conf.SignEndpoint == "" || conf.SignEndpoint == conf.Endpoint {
m.opts = opts
m.sign = m.core.Client
m.prefix = u.Path
u.Path = ""
config.Config.Object.Minio.Endpoint = u.String()
m.signEndpoint = config.Config.Object.Minio.Endpoint
conf.Endpoint = u.String()
m.signEndpoint = conf.Endpoint
} else {
su, err := url.Parse(config.Config.Object.Minio.SignEndpoint)
su, err := url.Parse(conf.SignEndpoint)
if err != nil {
return nil, err
}
m.opts = &minio.Options{
Creds: credentials.NewStaticV4(config.Config.Object.Minio.AccessKeyID, config.Config.Object.Minio.SecretAccessKey, config.Config.Object.Minio.SessionToken),
Creds: credentials.NewStaticV4(conf.AccessKeyID, conf.SecretAccessKey, conf.SessionToken),
Secure: su.Scheme == "https",
}
m.sign, err = minio.New(su.Host, m.opts)
@@ -101,8 +110,8 @@ func NewMinio(cache cache.MinioCache) (s3.Interface, error) {
}
m.prefix = su.Path
su.Path = ""
config.Config.Object.Minio.SignEndpoint = su.String()
m.signEndpoint = config.Config.Object.Minio.SignEndpoint
conf.SignEndpoint = su.String()
m.signEndpoint = conf.SignEndpoint
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -113,6 +122,7 @@ func NewMinio(cache cache.MinioCache) (s3.Interface, error) {
}
type Minio struct {
conf Config
bucket string
signEndpoint string
location string
@@ -134,31 +144,30 @@ func (m *Minio) initMinio(ctx context.Context) error {
if m.init {
return nil
}
conf := config.Config.Object.Minio
exists, err := m.core.Client.BucketExists(ctx, conf.Bucket)
exists, err := m.core.Client.BucketExists(ctx, m.conf.Bucket)
if err != nil {
return fmt.Errorf("check bucket exists error: %w", err)
}
if !exists {
if err := m.core.Client.MakeBucket(ctx, conf.Bucket, minio.MakeBucketOptions{}); err != nil {
if err = m.core.Client.MakeBucket(ctx, m.conf.Bucket, minio.MakeBucketOptions{}); err != nil {
return fmt.Errorf("make bucket error: %w", err)
}
}
if conf.PublicRead {
if m.conf.PublicRead {
policy := fmt.Sprintf(
`{"Version": "2012-10-17","Statement": [{"Action": ["s3:GetObject","s3:PutObject"],"Effect": "Allow","Principal": {"AWS": ["*"]},"Resource": ["arn:aws:s3:::%s/*"],"Sid": ""}]}`,
conf.Bucket,
m.conf.Bucket,
)
if err := m.core.Client.SetBucketPolicy(ctx, conf.Bucket, policy); err != nil {
if err = m.core.Client.SetBucketPolicy(ctx, m.conf.Bucket, policy); err != nil {
return err
}
}
m.location, err = m.core.Client.GetBucketLocation(ctx, conf.Bucket)
m.location, err = m.core.Client.GetBucketLocation(ctx, m.conf.Bucket)
if err != nil {
return err
}
func() {
if conf.SignEndpoint == "" || conf.SignEndpoint == conf.Endpoint {
if m.conf.SignEndpoint == "" || m.conf.SignEndpoint == m.conf.Endpoint {
return
}
defer func() {
@@ -178,7 +187,7 @@ func (m *Minio) initMinio(ctx context.Context) error {
blc := reflect.ValueOf(m.sign).Elem().FieldByName("bucketLocCache")
vblc := reflect.New(reflect.PtrTo(blc.Type()))
*(*unsafe.Pointer)(vblc.UnsafePointer()) = unsafe.Pointer(blc.UnsafeAddr())
vblc.Elem().Elem().Interface().(interface{ Set(string, string) }).Set(conf.Bucket, m.location)
vblc.Elem().Elem().Interface().(interface{ Set(string, string) }).Set(m.conf.Bucket, m.location)
}()
m.init = true
return nil
@@ -343,10 +352,7 @@ func (m *Minio) CopyObject(ctx context.Context, src string, dst string) (*s3.Cop
}
func (m *Minio) IsNotFound(err error) bool {
if err == nil {
return false
}
switch e := err.(type) {
switch e := errs.Unwrap(err).(type) {
case minio.ErrorResponse:
return e.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey"
case *minio.ErrorResponse:
@@ -399,7 +405,7 @@ func (m *Minio) PresignedGetObject(ctx context.Context, name string, expire time
rawURL *url.URL
err error
)
if config.Config.Object.Minio.PublicRead {
if m.conf.PublicRead {
rawURL, err = makeTargetURL(m.sign, m.bucket, name, m.location, false, query)
} else {
rawURL, err = m.sign.PresignedGetObject(ctx, m.bucket, name, expire, query)
+3 -3
View File
@@ -31,7 +31,6 @@ import (
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/minio/minio-go/v7"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
)
@@ -82,7 +81,8 @@ func (m *Minio) getImageThumbnailURL(ctx context.Context, name string, expire ti
}
key, err := m.cache.GetThumbnailKey(ctx, name, opt.Format, opt.Width, opt.Height, func(ctx context.Context) (string, error) {
if img == nil {
reader, err := m.core.Client.GetObject(ctx, m.bucket, name, minio.GetObjectOptions{})
var reader *minio.Object
reader, err = m.core.Client.GetObject(ctx, m.bucket, name, minio.GetObjectOptions{})
if err != nil {
return "", err
}
@@ -103,7 +103,7 @@ func (m *Minio) getImageThumbnailURL(ctx context.Context, name string, expire ti
err = gif.Encode(buf, thumbnail, nil)
}
cacheKey := filepath.Join(imageThumbnailPath, info.Etag, fmt.Sprintf("image_w%d_h%d.%s", opt.Width, opt.Height, opt.Format))
if _, err := m.core.Client.PutObject(ctx, m.bucket, cacheKey, buf, int64(buf.Len()), minio.PutObjectOptions{}); err != nil {
if _, err = m.core.Client.PutObject(ctx, m.bucket, cacheKey, buf, int64(buf.Len()), minio.PutObjectOptions{}); err != nil {
return "", err
}
return cacheKey, nil
+31 -27
View File
@@ -30,9 +30,8 @@ import (
"strings"
"time"
"github.com/OpenIMSDK/tools/errs"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
)
@@ -52,15 +51,19 @@ const (
const successCode = http.StatusOK
const (
videoSnapshotImagePng = "png"
videoSnapshotImageJpg = "jpg"
)
type Config struct {
Endpoint string
Bucket string
BucketURL string
AccessKeyID string
AccessKeySecret string
SessionToken string
PublicRead bool
}
func NewOSS() (s3.Interface, error) {
conf := config.Config.Object.Oss
func NewOSS(conf Config) (s3.Interface, error) {
if conf.BucketURL == "" {
return nil, errors.New("bucket url is empty")
return nil, errs.Wrap(errors.New("bucket url is empty"))
}
client, err := oss.New(conf.Endpoint, conf.AccessKeyID, conf.AccessKeySecret)
if err != nil {
@@ -68,7 +71,7 @@ func NewOSS() (s3.Interface, error) {
}
bucket, err := client.Bucket(conf.Bucket)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "ali-oss bucket error")
}
if conf.BucketURL[len(conf.BucketURL)-1] != '/' {
conf.BucketURL += "/"
@@ -78,6 +81,7 @@ func NewOSS() (s3.Interface, error) {
bucket: bucket,
credentials: client.Config.GetCredentials(),
um: *(*urlMaker)(reflect.ValueOf(bucket.Client.Conn).Elem().FieldByName("url").UnsafePointer()),
publicRead: conf.PublicRead,
}, nil
}
@@ -86,6 +90,7 @@ type OSS struct {
bucket *oss.Bucket
credentials oss.Credentials
um urlMaker
publicRead bool
}
func (o *OSS) Engine() string {
@@ -138,10 +143,10 @@ func (o *OSS) CompleteMultipartUpload(ctx context.Context, uploadID string, name
func (o *OSS) PartSize(ctx context.Context, size int64) (int64, error) {
if size <= 0 {
return 0, errors.New("size must be greater than 0")
return 0, errs.Wrap(errors.New("size must be greater than 0"))
}
if size > maxPartSize*maxNumSize {
return 0, fmt.Errorf("OSS size must be less than the maximum allowed limit")
return 0, errs.Wrap(errors.New("size must be less than the maximum allowed limit"))
}
if size <= minPartSize*maxNumSize {
return minPartSize, nil
@@ -196,25 +201,25 @@ func (o *OSS) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, erro
}
res := &s3.ObjectInfo{Key: name}
if res.ETag = strings.ToLower(strings.ReplaceAll(header.Get("ETag"), `"`, ``)); res.ETag == "" {
return nil, errors.New("StatObject etag not found")
return nil, errs.Wrap(errors.New("StatObject etag not found"))
}
if contentLengthStr := header.Get("Content-Length"); contentLengthStr == "" {
return nil, errors.New("StatObject content-length not found")
} else {
res.Size, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("StatObject content-length parse error: %w", err)
return nil, errs.Wrap(err, "StatObject content-length parse error")
}
if res.Size < 0 {
return nil, errors.New("StatObject content-length must be greater than 0")
return nil, errs.Wrap(errors.New("StatObject content-length must be greater than 0"))
}
}
if lastModified := header.Get("Last-Modified"); lastModified == "" {
return nil, errors.New("StatObject last-modified not found")
return nil, errs.Wrap(errors.New("StatObject last-modified not found"))
} else {
res.LastModified, err = time.Parse(http.TimeFormat, lastModified)
if err != nil {
return nil, fmt.Errorf("StatObject last-modified parse error: %w", err)
return nil, errs.Wrap(err, "StatObject last-modified parse error")
}
}
return res, nil
@@ -227,7 +232,7 @@ func (o *OSS) DeleteObject(ctx context.Context, name string) error {
func (o *OSS) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyObjectInfo, error) {
result, err := o.bucket.CopyObject(src, dst)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "CopyObject error")
}
return &s3.CopyObjectInfo{
Key: dst,
@@ -236,7 +241,7 @@ func (o *OSS) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO
}
func (o *OSS) IsNotFound(err error) bool {
switch e := err.(type) {
switch e := errs.Unwrap(err).(type) {
case oss.ServiceError:
return e.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey"
case *oss.ServiceError:
@@ -261,7 +266,7 @@ func (o *OSS) ListUploadedParts(ctx context.Context, uploadID string, name strin
Bucket: o.bucket.BucketName,
}, oss.MaxUploads(100), oss.MaxParts(maxParts), oss.PartNumberMarker(partNumberMarker))
if err != nil {
return nil, err
return nil, errs.Wrap(err, "ListUploadedParts error")
}
res := &s3.ListUploadedPartsResult{
Key: result.Key,
@@ -282,11 +287,10 @@ func (o *OSS) ListUploadedParts(ctx context.Context, uploadID string, name strin
}
func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) {
publicRead := config.Config.Object.Oss.PublicRead
var opts []oss.Option
if opt != nil {
if opt.Image != nil {
// 文档地址: https://help.aliyun.com/zh/oss/user-guide/resize-images-4?spm=a2c4g.11186623.0.0.4b3b1e4fWW6yji
// Docs Address: https://help.aliyun.com/zh/oss/user-guide/resize-images-4?spm=a2c4g.11186623.0.0.4b3b1e4fWW6yji
var format string
switch opt.Image.Format {
case
@@ -310,7 +314,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration,
process += ",format," + format
opts = append(opts, oss.Process(process))
}
if !publicRead {
if !o.publicRead {
if opt.ContentType != "" {
opts = append(opts, oss.ResponseContentType(opt.ContentType))
}
@@ -324,12 +328,12 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration,
} else if expire < time.Second {
expire = time.Second
}
if !publicRead {
if !o.publicRead {
return o.bucket.SignURL(name, http.MethodGet, int64(expire/time.Second), opts...)
}
rawParams, err := oss.GetRawParams(opts)
if err != nil {
return "", err
return "", errs.Wrap(err, "AccessURL error")
}
params := getURLParams(*o.bucket.Client.Conn, rawParams)
return getURL(o.um, o.bucket.BucketName, name, params).String(), nil
@@ -351,12 +355,12 @@ func (o *OSS) FormData(ctx context.Context, name string, size int64, contentType
}
policyJson, err := json.Marshal(policy)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "Marshal json error")
}
policyStr := base64.StdEncoding.EncodeToString(policyJson)
h := hmac.New(sha1.New, []byte(o.credentials.GetAccessKeySecret()))
if _, err := io.WriteString(h, policyStr); err != nil {
return nil, err
return nil, errs.Wrap(err, "WriteString error")
}
fd := &s3.FormData{
URL: o.bucketURL,
+2 -2
View File
@@ -46,8 +46,8 @@ type GroupModelInterface interface {
Find(ctx context.Context, groupIDs []string) (groups []*GroupModel, err error)
Take(ctx context.Context, groupID string) (group *GroupModel, err error)
Search(ctx context.Context, keyword string, pagination pagination.Pagination) (total int64, groups []*GroupModel, err error)
// 获取群总数
// Get Group total quantity
CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// 获取范围内群增量
// Get Group total quantity every day
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
}
+2 -3
View File
@@ -19,7 +19,6 @@ import (
"time"
"github.com/OpenIMSDK/protocol/user"
"github.com/OpenIMSDK/tools/pagination"
)
@@ -62,9 +61,9 @@ type UserModelInterface interface {
Exist(ctx context.Context, userID string) (exist bool, err error)
GetAllUserID(ctx context.Context, pagination pagination.Pagination) (count int64, userIDs []string, err error)
GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error)
// 获取用户总数
// Get user total quantity
CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// 获取范围内用户增量
// Get user total quantity every day
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
//CRUD user command
AddUserCommand(ctx context.Context, userID string, Type int32, UUID string, value string, ex string) error
+1 -3
View File
@@ -20,10 +20,8 @@ import (
"time"
"github.com/OpenIMSDK/protocol/msg"
"go.mongodb.org/mongo-driver/mongo"
"github.com/OpenIMSDK/protocol/sdkws"
"go.mongodb.org/mongo-driver/mongo"
)
const (
+28 -30
View File
@@ -21,16 +21,13 @@ import (
"strings"
"time"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/mw/specialerror"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/mw/specialerror"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
)
const (
@@ -39,13 +36,14 @@ const (
)
type Mongo struct {
db *mongo.Client
db *mongo.Client
config *config.GlobalConfig
}
// NewMongo Initialize MongoDB connection.
func NewMongo() (*Mongo, error) {
func NewMongo(config *config.GlobalConfig) (*Mongo, error) {
specialerror.AddReplace(mongo.ErrNoDocuments, errs.ErrRecordNotFound)
uri := buildMongoURI()
uri := buildMongoURI(config)
var mongoClient *mongo.Client
var err error
@@ -56,26 +54,27 @@ func NewMongo() (*Mongo, error) {
defer cancel()
mongoClient, err = mongo.Connect(ctx, options.Client().ApplyURI(uri))
if err == nil {
return &Mongo{db: mongoClient}, nil
if err = mongoClient.Ping(ctx, nil); err != nil {
return nil, errs.Wrap(err, uri)
}
return &Mongo{db: mongoClient, config: config}, nil
}
if shouldRetry(err) {
fmt.Printf("Failed to connect to MongoDB, retrying: %s\n", err)
time.Sleep(time.Second) // exponential backoff could be implemented here
continue
}
return nil, err
}
return nil, err
return nil, errs.Wrap(err, uri)
}
func buildMongoURI() string {
func buildMongoURI(config *config.GlobalConfig) string {
uri := os.Getenv("MONGO_URI")
if uri != "" {
return uri
}
if config.Config.Mongo.Uri != "" {
return config.Config.Mongo.Uri
if config.Mongo.Uri != "" {
return config.Mongo.Uri
}
username := os.Getenv("MONGO_OPENIM_USERNAME")
@@ -86,29 +85,28 @@ func buildMongoURI() string {
maxPoolSize := os.Getenv("MONGO_MAX_POOL_SIZE")
if username == "" {
username = config.Config.Mongo.Username
username = config.Mongo.Username
}
if password == "" {
password = config.Config.Mongo.Password
password = config.Mongo.Password
}
if address == "" {
address = strings.Join(config.Config.Mongo.Address, ",")
address = strings.Join(config.Mongo.Address, ",")
} else if port != "" {
address = fmt.Sprintf("%s:%s", address, port)
}
if database == "" {
database = config.Config.Mongo.Database
database = config.Mongo.Database
}
if maxPoolSize == "" {
maxPoolSize = fmt.Sprint(config.Config.Mongo.MaxPoolSize)
maxPoolSize = fmt.Sprint(config.Mongo.MaxPoolSize)
}
uriFormat := "mongodb://%s/%s?maxPoolSize=%s"
if username != "" && password != "" {
uriFormat = "mongodb://%s:%s@%s/%s?maxPoolSize=%s"
return fmt.Sprintf(uriFormat, username, password, address, database, maxPoolSize)
return fmt.Sprintf("mongodb://%s:%s@%s/%s?maxPoolSize=%s", username, password, address, database, maxPoolSize)
}
return fmt.Sprintf(uriFormat, address, database, maxPoolSize)
return fmt.Sprintf("mongodb://%s/%s?maxPoolSize=%s", address, database, maxPoolSize)
}
func shouldRetry(err error) bool {
@@ -124,8 +122,8 @@ func (m *Mongo) GetClient() *mongo.Client {
}
// GetDatabase returns the specific database from MongoDB.
func (m *Mongo) GetDatabase() *mongo.Database {
return m.db.Database(config.Config.Mongo.Database)
func (m *Mongo) GetDatabase(database string) *mongo.Database {
return m.db.Database(database)
}
// CreateMsgIndex creates an index for messages in MongoDB.
@@ -135,7 +133,7 @@ func (m *Mongo) CreateMsgIndex() error {
// createMongoIndex creates an index in a MongoDB collection.
func (m *Mongo) createMongoIndex(collection string, isUnique bool, keys ...string) error {
db := m.GetDatabase().Collection(collection)
db := m.GetDatabase(m.config.Mongo.Database).Collection(collection)
opts := options.CreateIndexes().SetMaxTime(10 * time.Second)
indexView := db.Indexes()
@@ -150,7 +148,7 @@ func (m *Mongo) createMongoIndex(collection string, isUnique bool, keys ...strin
_, err := indexView.CreateOne(context.Background(), index, opts)
if err != nil {
return utils.Wrap(err, "CreateIndex")
return errs.Wrap(err, "CreateIndex")
}
return nil
}
+76 -134
View File
@@ -21,23 +21,17 @@ import (
"fmt"
"time"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/protocol/msg"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/msg"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
table "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"google.golang.org/protobuf/proto"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/utils"
table "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
)
var ErrMsgListNotExist = errors.New("user not have msg in mongoDB")
@@ -79,7 +73,7 @@ func (m *MsgMongoDriver) UpdateMsg(
update := bson.M{"$set": bson.M{field: value}}
res, err := m.MsgCollection.UpdateOne(ctx, filter, update)
if err != nil {
return nil, utils.Wrap(err, "")
return nil, errs.Wrap(err)
}
return res, nil
}
@@ -106,7 +100,7 @@ func (m *MsgMongoDriver) PushUnique(
}
res, err := m.MsgCollection.UpdateOne(ctx, filter, update)
if err != nil {
return nil, utils.Wrap(err, "")
return nil, errs.Wrap(err)
}
return res, nil
}
@@ -118,22 +112,16 @@ func (m *MsgMongoDriver) UpdateMsgContent(ctx context.Context, docID string, ind
bson.M{"$set": bson.M{fmt.Sprintf("msgs.%d.msg", index): msg}},
)
if err != nil {
return utils.Wrap(err, "")
return errs.Wrap(err)
}
return nil
}
func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc(
ctx context.Context,
docID string,
msg *sdkws.MsgData,
seqIndex int,
status int32,
) error {
func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc(ctx context.Context, docID string, msg *sdkws.MsgData, seqIndex int, status int32) error {
msg.Status = status
bytes, err := proto.Marshal(msg)
if err != nil {
return utils.Wrap(err, "")
return errs.Wrap(err)
}
_, err = m.MsgCollection.UpdateOne(
ctx,
@@ -141,7 +129,7 @@ func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc(
bson.M{"$set": bson.M{fmt.Sprintf("msgs.%d.msg", seqIndex): bytes}},
)
if err != nil {
return utils.Wrap(err, "")
return errs.Wrap(err, fmt.Sprintf("docID is %s, seqIndex is %d", docID, seqIndex))
}
return nil
}
@@ -167,12 +155,12 @@ func (m *MsgMongoDriver) GetMsgDocModelByIndex(
findOpts,
)
if err != nil {
return nil, utils.Wrap(err, "")
return nil, errs.Wrap(err, fmt.Sprintf("conversationID is %s", conversationID))
}
var msgs []table.MsgDocModel
err = cursor.All(ctx, &msgs)
if err != nil {
return nil, utils.Wrap(err, fmt.Sprintf("cursor is %s", cursor.Current.String()))
return nil, errs.Wrap(err, fmt.Sprintf("cursor is %s", cursor.Current.String()))
}
if len(msgs) > 0 {
return &msgs[0], nil
@@ -223,7 +211,7 @@ func (m *MsgMongoDriver) DeleteMsgsInOneDocByIndex(ctx context.Context, docID st
}
_, err := m.MsgCollection.UpdateMany(ctx, bson.M{"doc_id": docID}, updates)
if err != nil {
return utils.Wrap(err, "")
return errs.Wrap(err, fmt.Sprintf("docID is %s, indexes is %v", docID, indexes))
}
return nil
}
@@ -247,47 +235,42 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
indexs = append(indexs, m.model.GetMsgIndex(seq))
}
pipeline := mongo.Pipeline{
{
{"$match", bson.D{
{"doc_id", docID},
}},
},
{
{"$project", bson.D{
{"_id", 0},
{"doc_id", 1},
{"msgs", bson.D{
{"$map", bson.D{
{"input", indexs},
{"as", "index"},
{"in", bson.D{
{"$let", bson.D{
{"vars", bson.D{
{"currentMsg", bson.D{
{"$arrayElemAt", []string{"$msgs", "$$index"}},
}},
bson.D{{Key: "$match", Value: bson.D{
{Key: "doc_id", Value: docID},
}}},
bson.D{{Key: "$project", Value: bson.D{
{Key: "_id", Value: 0},
{Key: "doc_id", Value: 1},
{Key: "msgs", Value: bson.D{
{Key: "$map", Value: bson.D{
{Key: "input", Value: indexs},
{Key: "as", Value: "index"},
{Key: "in", Value: bson.D{
{Key: "$let", Value: bson.D{
{Key: "vars", Value: bson.D{
{Key: "currentMsg", Value: bson.D{
{Key: "$arrayElemAt", Value: bson.A{"$msgs", "$$index"}},
}},
{"in", bson.D{
{"$cond", bson.D{
{"if", bson.D{
{"$in", []string{userID, "$$currentMsg.del_list"}},
}},
{"then", nil},
{"else", "$$currentMsg"},
}},
{Key: "in", Value: bson.D{
{Key: "$cond", Value: bson.D{
{Key: "if", Value: bson.D{
{Key: "$in", Value: bson.A{userID, "$$currentMsg.del_list"}},
}},
{Key: "then", Value: nil},
{Key: "else", Value: "$$currentMsg"},
}},
}},
}},
}},
}},
}},
},
{
{"$project", bson.D{
{"msgs.del_list", 0},
}},
},
}}},
bson.D{{Key: "$project", Value: bson.D{
{Key: "msgs.del_list", Value: 0},
}}},
}
cur, err := m.MsgCollection.Aggregate(ctx, pipeline)
if err != nil {
return nil, errs.Wrap(err)
@@ -295,7 +278,7 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
defer cur.Close(ctx)
var msgDocModel []table.MsgDocModel
if err := cur.All(ctx, &msgDocModel); err != nil {
return nil, errs.Wrap(err)
return nil, errs.Wrap(err, fmt.Sprintf("docID is %s, seqs is %v", docID, seqs))
}
if len(msgDocModel) == 0 {
return nil, errs.Wrap(mongo.ErrNoDocuments)
@@ -322,14 +305,14 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
}
data, err := json.Marshal(&revokeContent)
if err != nil {
return nil, err
return nil, errs.Wrap(err, fmt.Sprintf("docID is %s, seqs is %v", docID, seqs))
}
elem := sdkws.NotificationElem{
Detail: string(data),
}
content, err := json.Marshal(&elem)
if err != nil {
return nil, err
return nil, errs.Wrap(err, fmt.Sprintf("docID is %s, seqs is %v", docID, seqs))
}
msg.Msg.ContentType = constant.MsgRevokeNotification
msg.Msg.Content = string(content)
@@ -342,17 +325,12 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
func (m *MsgMongoDriver) IsExistDocID(ctx context.Context, docID string) (bool, error) {
count, err := m.MsgCollection.CountDocuments(ctx, bson.M{"doc_id": docID})
if err != nil {
return false, errs.Wrap(err)
return false, errs.Wrap(err, fmt.Sprintf("docID is %s", docID))
}
return count > 0, nil
}
func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead(
ctx context.Context,
userID string,
docID string,
indexes []int64,
) error {
func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead(ctx context.Context, userID string, docID string, indexes []int64) error {
updates := []mongo.WriteModel{}
for _, index := range indexes {
filter := bson.M{
@@ -372,7 +350,7 @@ func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead(
updates = append(updates, updateModel)
}
_, err := m.MsgCollection.BulkWrite(ctx, updates)
return err
return errs.Wrap(err, fmt.Sprintf("docID is %s, indexes is %v", docID, indexes))
}
// RangeUserSendCount
@@ -668,7 +646,7 @@ func (m *MsgMongoDriver) RangeUserSendCount(
"$dateToString": bson.M{
"format": "%Y-%m-%d",
"date": bson.M{
"$toDate": "$$item.msg.send_time", // 毫秒时间戳
"$toDate": "$$item.msg.send_time", // Millisecond timestamp
},
},
},
@@ -801,7 +779,7 @@ func (m *MsgMongoDriver) RangeUserSendCount(
}
defer cur.Close(ctx)
var result []Result
if err := cur.All(ctx, &result); err != nil {
if err = cur.All(ctx, &result); err != nil {
return 0, 0, nil, nil, errs.Wrap(err)
}
if len(result) == 0 {
@@ -917,7 +895,7 @@ func (m *MsgMongoDriver) RangeGroupSendCount(
"$dateToString": bson.M{
"format": "%Y-%m-%d",
"date": bson.M{
"$toDate": "$$item.msg.send_time", // 毫秒时间戳
"$toDate": "$$item.msg.send_time", // Millisecond timestamp
},
},
},
@@ -1050,7 +1028,7 @@ func (m *MsgMongoDriver) RangeGroupSendCount(
}
defer cur.Close(ctx)
var result []Result
if err := cur.All(ctx, &result); err != nil {
if err = cur.All(ctx, &result); err != nil {
return 0, 0, nil, nil, errs.Wrap(err)
}
if len(result) == 0 {
@@ -1082,6 +1060,7 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
var pipe mongo.Pipeline
condition := bson.A{}
if req.SendTime != "" {
// Changed to keyed fields for bson.M to avoid govet errors
condition = append(condition, bson.M{"$eq": bson.A{bson.M{"$dateToString": bson.M{"format": "%Y-%m-%d", "date": bson.M{"$toDate": "$$item.msg.send_time"}}}, req.SendTime}})
}
if req.MsgType != 0 {
@@ -1098,62 +1077,26 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
}
or := bson.A{
bson.M{
"doc_id": bson.M{
"$regex": "^si_",
"$options": "i",
},
},
bson.M{"doc_id": bson.M{"$regex": "^si_", "$options": "i"}},
bson.M{"doc_id": bson.M{"$regex": "^g_", "$options": "i"}},
bson.M{"doc_id": bson.M{"$regex": "^sg_", "$options": "i"}},
}
or = append(or,
bson.M{
"doc_id": bson.M{
"$regex": "^g_",
"$options": "i",
},
},
bson.M{
"doc_id": bson.M{
"$regex": "^sg_",
"$options": "i",
},
},
)
// Use bson.D with keyed fields to specify the order explicitly
pipe = mongo.Pipeline{
{
{"$match", bson.D{
{
"$or", or,
},
{{"$match", bson.D{{Key: "$or", Value: or}}}},
{{"$project", bson.D{
{Key: "msgs", Value: bson.D{
{Key: "$filter", Value: bson.D{
{Key: "input", Value: "$msgs"},
{Key: "as", Value: "item"},
{Key: "cond", Value: bson.D{{Key: "$and", Value: condition}}},
}},
}},
},
{
{"$project", bson.D{
{
"msgs", bson.D{
{
"$filter", bson.D{
{"input", "$msgs"},
{"as", "item"},
{
"cond", bson.D{
{"$and", condition},
},
},
},
},
},
},
{"doc_id", 1},
}},
},
{
{"$unwind", bson.M{"path": "$msgs"}},
},
{
{"$sort", bson.M{"msgs.msg.send_time": -1}},
},
{Key: "doc_id", Value: 1},
}}},
{{"$unwind", bson.M{"path": "$msgs"}}},
{{"$sort", bson.M{"msgs.msg.send_time": -1}}},
}
cursor, err := m.MsgCollection.Aggregate(ctx, pipe)
if err != nil {
@@ -1166,12 +1109,12 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
var msgsDocs []docModel
err = cursor.All(ctx, &msgsDocs)
if err != nil {
return 0, nil, err
return 0, nil, errs.Wrap(err, "cursor.All msgsDocs")
}
log.ZDebug(ctx, "query mongoDB", "result", msgsDocs)
msgs := make([]*table.MsgInfoModel, 0)
for index := range msgsDocs {
msgInfo := msgsDocs[index].Msg
for _, doc := range msgsDocs {
msgInfo := doc.Msg
if msgInfo == nil || msgInfo.Msg == nil {
continue
}
@@ -1191,14 +1134,12 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
}
data, err := json.Marshal(&revokeContent)
if err != nil {
return 0, nil, err
}
elem := sdkws.NotificationElem{
Detail: string(data),
return 0, nil, errs.Wrap(err, "json.Marshal revokeContent")
}
elem := sdkws.NotificationElem{Detail: string(data)}
content, err := json.Marshal(&elem)
if err != nil {
return 0, nil, err
return 0, nil, errs.Wrap(err, "json.Marshal elem")
}
msgInfo.Msg.ContentType = constant.MsgRevokeNotification
msgInfo.Msg.Content = string(content)
@@ -1209,7 +1150,8 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
n := int32(len(msgs))
if start >= n {
return n, []*table.MsgInfoModel{}, nil
} else if start+req.Pagination.ShowNumber < n {
}
if start+req.Pagination.ShowNumber < n {
msgs = msgs[start : start+req.Pagination.ShowNumber]
} else {
msgs = msgs[start:]
+1 -2
View File
@@ -19,10 +19,9 @@ import (
"fmt"
"github.com/OpenIMSDK/tools/log"
table "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
table "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
)
func (m *MsgMongoDriver) ConvertMsgsDocLen(ctx context.Context, conversationIDs []string) {
+3 -5
View File
@@ -18,12 +18,10 @@ import (
"context"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
)
// prefixes and suffixes.
@@ -65,7 +63,7 @@ func (u *UserMongoDriver) AddSubscriptionList(ctx context.Context, userID string
}
// iterate over aggregated results
for cursor.Next(ctx) {
err := cursor.Decode(&cnt)
err = cursor.Decode(&cnt)
if err != nil {
return errs.Wrap(err)
}
@@ -119,7 +117,7 @@ func (u *UserMongoDriver) AddSubscriptionList(ctx context.Context, userID string
opts,
)
if err != nil {
return utils.Wrap(err, "transaction failed")
return errs.Wrap(err, "transaction failed")
}
}
return nil