Merge remote-tracking branch 'origin/main' into v3.8-js-sdk-only

# Conflicts:
#	go.mod
#	go.sum
#	pkg/common/storage/controller/auth.go
This commit is contained in:
withchao
2024-10-25 17:31:26 +08:00
36 changed files with 967 additions and 373 deletions
+22 -4
View File
@@ -361,11 +361,29 @@ type AfterConfig struct {
}
type Share struct {
Secret string `mapstructure:"secret"`
RpcRegisterName RpcRegisterName `mapstructure:"rpcRegisterName"`
IMAdminUserID []string `mapstructure:"imAdminUserID"`
MultiLoginPolicy int `mapstructure:"multiLoginPolicy"`
Secret string `mapstructure:"secret"`
RpcRegisterName RpcRegisterName `mapstructure:"rpcRegisterName"`
IMAdminUserID []string `mapstructure:"imAdminUserID"`
MultiLogin MultiLogin `mapstructure:"multiLogin"`
}
type MultiLogin struct {
Policy int `mapstructure:"policy"`
MaxNumOneEnd int `mapstructure:"maxNumOneEnd"`
CustomizeLoginNum struct {
IOS int `mapstructure:"ios"`
Android int `mapstructure:"android"`
Windows int `mapstructure:"windows"`
OSX int `mapstructure:"osx"`
Web int `mapstructure:"web"`
MiniWeb int `mapstructure:"miniWeb"`
Linux int `mapstructure:"linux"`
APad int `mapstructure:"aPad"`
IPad int `mapstructure:"iPad"`
Admin int `mapstructure:"admin"`
} `mapstructure:"customizeLoginNum"`
}
type RpcRegisterName struct {
User string `mapstructure:"user"`
Friend string `mapstructure:"friend"`
+11
View File
@@ -0,0 +1,11 @@
package cache
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
)
type ApplicationCache interface {
LatestVersion(ctx context.Context, platform string) (*model.Application, error)
DeleteCache(ctx context.Context, platforms []string) error
}
+9
View File
@@ -0,0 +1,9 @@
package cachekey
const (
ApplicationLatestVersion = "APPLICATION_LATEST_VERSION:"
)
func GetApplicationLatestVersionKey(platform string) string {
return ApplicationLatestVersion + platform
}
+18 -1
View File
@@ -1,6 +1,9 @@
package cachekey
import "github.com/openimsdk/protocol/constant"
import (
"github.com/openimsdk/protocol/constant"
"strings"
)
const (
UidPidToken = "UID_PID_TOKEN_STATUS:"
@@ -9,3 +12,17 @@ const (
func GetTokenKey(userID string, platformID int) string {
return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
}
func GetAllPlatformTokenKey(userID string) []string {
res := make([]string, len(constant.PlatformID2Name))
for k := range constant.PlatformID2Name {
res[k-1] = GetTokenKey(userID, k)
}
return res
}
func GetPlatformIDByTokenKey(key string) int {
splitKey := strings.Split(key, ":")
platform := splitKey[len(splitKey)-1]
return constant.PlatformNameToID(platform)
}
+43
View File
@@ -0,0 +1,43 @@
package redis
import (
"context"
"github.com/dtm-labs/rockscache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
"time"
)
func NewApplicationRedisCache(db database.Application, rdb redis.UniversalClient) *ApplicationRedisCache {
return &ApplicationRedisCache{
db: db,
rcClient: rockscache.NewClient(rdb, *GetRocksCacheOptions()),
deleter: NewBatchDeleterRedis(rdb, GetRocksCacheOptions(), nil),
expireTime: time.Hour * 24 * 7,
}
}
type ApplicationRedisCache struct {
db database.Application
rcClient *rockscache.Client
deleter *BatchDeleterRedis
expireTime time.Duration
}
func (a *ApplicationRedisCache) LatestVersion(ctx context.Context, platform string) (*model.Application, error) {
return getCache(ctx, a.rcClient, cachekey.GetApplicationLatestVersionKey(platform), a.expireTime, func(ctx context.Context) (*model.Application, error) {
return a.db.LatestVersion(ctx, platform)
})
}
func (a *ApplicationRedisCache) DeleteCache(ctx context.Context, platforms []string) error {
if len(platforms) == 0 {
return nil
}
return a.deleter.ExecDelWithKeys(ctx, datautil.Slice(platforms, func(platform string) string {
return cachekey.GetApplicationLatestVersionKey(platform)
}))
}
+50 -14
View File
@@ -1,17 +1,3 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package redis
import (
@@ -21,6 +7,7 @@ import (
"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
"strconv"
"sync"
"time"
)
@@ -67,6 +54,43 @@ func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p
return mm, nil
}
func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
var (
res = make(map[int]map[string]int)
resLock = sync.Mutex{}
)
keys := cachekey.GetAllPlatformTokenKey(userID)
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
pipe := c.rdb.Pipeline()
mapRes := make([]*redis.MapStringStringCmd, len(keys))
for i, key := range keys {
mapRes[i] = pipe.HGetAll(ctx, key)
}
_, err := pipe.Exec(ctx)
if err != nil {
return err
}
for i, m := range mapRes {
mm := make(map[string]int)
for k, v := range m.Val() {
state, err := strconv.Atoi(v)
if err != nil {
return errs.WrapMsg(err, "redis token value is not int", "value", v, "userID", userID)
}
mm[k] = state
}
resLock.Lock()
res[cachekey.GetPlatformIDByTokenKey(keys[i])] = mm
resLock.Unlock()
}
return nil
}); err != nil {
return nil, err
}
return res, nil
}
func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
mm := make(map[string]any)
for k, v := range m {
@@ -75,6 +99,18 @@ func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, pla
return errs.Wrap(c.rdb.HSet(ctx, cachekey.GetTokenKey(userID, platformID), mm).Err())
}
func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]int) error {
pipe := c.rdb.Pipeline()
for k, v := range tokens {
pipe.HSet(ctx, k, v)
}
_, err := pipe.Exec(ctx)
if err != nil {
return errs.Wrap(err)
}
return nil
}
func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error {
return errs.Wrap(c.rdb.HDel(ctx, cachekey.GetTokenKey(userID, platformID), fields...).Err())
}
+2
View File
@@ -9,6 +9,8 @@ type TokenModel interface {
// SetTokenFlagEx set token and flag with expire time
SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error)
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]int) error
DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
}
@@ -0,0 +1,69 @@
package controller
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type ApplicationDatabase interface {
LatestVersion(ctx context.Context, platform string) (*model.Application, error)
AddVersion(ctx context.Context, val *model.Application) error
UpdateVersion(ctx context.Context, id primitive.ObjectID, update map[string]any) error
DeleteVersion(ctx context.Context, id []primitive.ObjectID) error
PageVersion(ctx context.Context, platforms []string, page pagination.Pagination) (int64, []*model.Application, error)
}
func NewApplicationDatabase(db database.Application, cache cache.ApplicationCache) ApplicationDatabase {
return &applicationDatabase{db: db, cache: cache}
}
type applicationDatabase struct {
db database.Application
cache cache.ApplicationCache
}
func (a *applicationDatabase) LatestVersion(ctx context.Context, platform string) (*model.Application, error) {
return a.cache.LatestVersion(ctx, platform)
}
func (a *applicationDatabase) AddVersion(ctx context.Context, val *model.Application) error {
if err := a.db.AddVersion(ctx, val); err != nil {
return err
}
return a.cache.DeleteCache(ctx, []string{val.Platform})
}
func (a *applicationDatabase) UpdateVersion(ctx context.Context, id primitive.ObjectID, update map[string]any) error {
platforms, err := a.db.FindPlatform(ctx, []primitive.ObjectID{id})
if err != nil {
return err
}
if err := a.db.UpdateVersion(ctx, id, update); err != nil {
return err
}
if p, ok := update["platform"]; ok {
if val, ok := p.(string); ok {
platforms = append(platforms, val)
}
}
return a.cache.DeleteCache(ctx, platforms)
}
func (a *applicationDatabase) DeleteVersion(ctx context.Context, id []primitive.ObjectID) error {
platforms, err := a.db.FindPlatform(ctx, id)
if err != nil {
return err
}
if err := a.db.DeleteVersion(ctx, id); err != nil {
return err
}
return a.cache.DeleteCache(ctx, platforms)
}
func (a *applicationDatabase) PageVersion(ctx context.Context, platforms []string, page pagination.Pagination) (int64, []*model.Application, error) {
return a.db.PageVersion(ctx, platforms, page)
}
+193 -60
View File
@@ -1,26 +1,12 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"github.com/openimsdk/tools/log"
"github.com/golang-jwt/jwt/v4"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/tokenverify"
@@ -32,18 +18,41 @@ type AuthDatabase interface {
// Create token
CreateToken(ctx context.Context, userID string, platformID int) (string, error)
BatchSetTokenMapByUidPid(ctx context.Context, tokens []string) error
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
}
type authDatabase struct {
cache cache.TokenModel
accessSecret string
accessExpire int64
multiLoginPolicy int
type multiLoginConfig struct {
Policy int
MaxNumOneEnd int
CustomizeLoginNum map[int]int
}
func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, policy int) AuthDatabase {
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLoginPolicy: policy}
type authDatabase struct {
cache cache.TokenModel
accessSecret string
accessExpire int64
multiLogin multiLoginConfig
}
func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin) AuthDatabase {
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLogin: multiLoginConfig{
Policy: multiLogin.Policy,
MaxNumOneEnd: multiLogin.MaxNumOneEnd,
CustomizeLoginNum: map[int]int{
constant.IOSPlatformID: multiLogin.CustomizeLoginNum.IOS,
constant.AndroidPlatformID: multiLogin.CustomizeLoginNum.Android,
constant.WindowsPlatformID: multiLogin.CustomizeLoginNum.Windows,
constant.OSXPlatformID: multiLogin.CustomizeLoginNum.OSX,
constant.WebPlatformID: multiLogin.CustomizeLoginNum.Web,
constant.MiniWebPlatformID: multiLogin.CustomizeLoginNum.MiniWeb,
constant.LinuxPlatformID: multiLogin.CustomizeLoginNum.Linux,
constant.AndroidPadPlatformID: multiLogin.CustomizeLoginNum.APad,
constant.IPadPlatformID: multiLogin.CustomizeLoginNum.IPad,
constant.AdminPlatformID: multiLogin.CustomizeLoginNum.Admin,
},
}}
}
// If the result is empty.
@@ -55,22 +64,38 @@ func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, p
return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m)
}
func (a *authDatabase) BatchSetTokenMapByUidPid(ctx context.Context, tokens []string) error {
setMap := make(map[string]map[string]int)
for _, token := range tokens {
claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret(a.accessSecret))
key := cachekey.GetTokenKey(claims.UserID, claims.PlatformID)
if err != nil {
continue
} else {
if v, ok := setMap[key]; ok {
v[token] = constant.KickedToken
} else {
setMap[key] = map[string]int{
token: constant.KickedToken,
}
}
}
}
if err := a.cache.BatchSetTokenMapByUidPid(ctx, setMap); err != nil {
return err
}
return nil
}
// Create Token.
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
// todo: get all platform token
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID)
tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID)
if err != nil {
return "", err
}
var deleteTokenKey []string
var kickedTokenKey []string
for k, v := range tokens {
t, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret))
if err != nil || v != constant.NormalToken {
deleteTokenKey = append(deleteTokenKey, k)
} else if a.checkKickToken(ctx, platformID, t) {
kickedTokenKey = append(kickedTokenKey, k)
}
deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID)
if err != nil {
return "", err
}
if len(deleteTokenKey) != 0 {
err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
@@ -78,16 +103,6 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
return "", err
}
}
const adminTokenMaxNum = 30
if platformID == constant.AdminPlatformID {
if len(kickedTokenKey) > adminTokenMaxNum {
kickedTokenKey = kickedTokenKey[:len(kickedTokenKey)-adminTokenMaxNum]
} else {
kickedTokenKey = nil
}
}
if len(kickedTokenKey) != 0 {
for _, k := range kickedTokenKey {
err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken)
@@ -111,22 +126,140 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
return tokenString, nil
}
func (a *authDatabase) checkKickToken(ctx context.Context, platformID int, token *tokenverify.Claims) bool {
switch a.multiLoginPolicy {
case constant.DefalutNotKick:
return false
case constant.PCAndOther:
if constant.PlatformIDToClass(platformID) == constant.TerminalPC ||
constant.PlatformIDToClass(token.PlatformID) == constant.TerminalPC {
return false
func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) ([]string, []string, error) {
// todo: Move the logic for handling old data to another location.
var (
loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0
deleteToken = make([]string, 0)
kickToken = make([]string, 0)
adminToken = make([]string, 0)
unkickTerminal = ""
)
for plfID, tks := range tokens {
for k, v := range tks {
_, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret))
if err != nil || v != constant.NormalToken {
deleteToken = append(deleteToken, k)
} else {
if plfID != constant.AdminPlatformID {
loginTokenMap[plfID] = append(loginTokenMap[plfID], k)
} else {
adminToken = append(adminToken, k)
}
}
}
return true
case constant.AllLoginButSameTermKick:
if platformID == token.PlatformID {
return true
}
return false
default:
return false
}
switch a.multiLogin.Policy {
case constant.DefalutNotKick:
for plt, ts := range loginTokenMap {
l := len(ts)
if platformID == plt {
l++
}
limit := a.multiLogin.MaxNumOneEnd
if l > limit {
kickToken = append(kickToken, ts[:l-limit]...)
}
}
case constant.AllLoginButSameTermKick:
for plt, ts := range loginTokenMap {
kickToken = append(kickToken, ts[:len(ts)-1]...)
if plt == platformID {
kickToken = append(kickToken, ts[len(ts)-1])
}
}
case constant.SingleTerminalLogin:
for _, ts := range loginTokenMap {
kickToken = append(kickToken, ts...)
}
case constant.WebAndOther:
unkickTerminal = constant.WebPlatformStr
fallthrough
case constant.PCAndOther:
if unkickTerminal == "" {
unkickTerminal = constant.TerminalPC
}
if constant.PlatformIDToClass(platformID) != unkickTerminal {
for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal {
kickToken = append(kickToken, ts...)
}
}
} else {
var (
preKick []string
isReserve = true
)
for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal {
// Keep a token from another end
if isReserve {
isReserve = false
kickToken = append(kickToken, ts[:len(ts)-1]...)
preKick = append(preKick, ts[len(ts)-1])
continue
} else {
// Prioritize keeping Android
if plt == constant.AndroidPlatformID {
kickToken = append(kickToken, preKick...)
kickToken = append(kickToken, ts[:len(ts)-1]...)
} else {
kickToken = append(kickToken, ts...)
}
}
}
}
}
case constant.PcMobileAndWeb:
var (
reserved = make(map[string]bool)
)
for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) {
kickToken = append(kickToken, ts...)
} else {
if !reserved[constant.PlatformIDToClass(plt)] {
reserved[constant.PlatformIDToClass(plt)] = true
kickToken = append(kickToken, ts[:len(ts)-1]...)
continue
} else {
kickToken = append(kickToken, ts...)
}
}
}
case constant.Customize:
if a.multiLogin.CustomizeLoginNum[platformID] <= 0 {
return nil, nil, errs.New("Do not allow login on this end").Wrap()
}
for plt, ts := range loginTokenMap {
l := len(ts)
if platformID == plt {
l++
}
// a.multiLogin.CustomizeLoginNum[platformID] must > 0
limit := min(a.multiLogin.CustomizeLoginNum[plt], a.multiLogin.MaxNumOneEnd)
if l > limit {
kickToken = append(kickToken, ts[:l-limit]...)
}
}
default:
return nil, nil, errs.New("unknown multiLogin policy").Wrap()
}
var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd
if a.multiLogin.Policy == constant.Customize {
adminTokenMaxNum = a.multiLogin.CustomizeLoginNum[constant.AdminPlatformID]
}
l := len(adminToken)
if platformID == constant.AdminPlatformID {
l++
}
if l > adminTokenMaxNum {
kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...)
}
return deleteToken, kickToken, nil
}
@@ -16,9 +16,10 @@ package controller
import (
"context"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
@@ -194,7 +195,7 @@ func (c *conversationDatabase) SyncPeerUserPrivateConversationTx(ctx context.Con
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.CloneConversationCache()
for _, conversation := range conversations {
cache = cache.DelConversationVersionUserIDs(conversation.OwnerUserID)
cache = cache.DelConversationVersionUserIDs(conversation.OwnerUserID, conversation.UserID)
for _, v := range [][2]string{{conversation.OwnerUserID, conversation.UserID}, {conversation.UserID, conversation.OwnerUserID}} {
ownerUserID := v[0]
userID := v[1]
@@ -2,6 +2,7 @@ package controller
import (
"context"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
@@ -83,6 +84,9 @@ func (db *msgTransferDatabase) BatchInsertChat2DB(ctx context.Context, conversat
IOSBadgeCount: msg.OfflinePushInfo.IOSBadgeCount,
}
}
if msg.Status == constant.MsgStatusSending {
msg.Status = constant.MsgStatusSendSuccess
}
msgs[i] = &model.MsgDataModel{
SendID: msg.SendID,
RecvID: msg.RecvID,
@@ -0,0 +1,17 @@
package database
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type Application interface {
LatestVersion(ctx context.Context, platform string) (*model.Application, error)
AddVersion(ctx context.Context, val *model.Application) error
UpdateVersion(ctx context.Context, id primitive.ObjectID, update map[string]any) error
DeleteVersion(ctx context.Context, id []primitive.ObjectID) error
PageVersion(ctx context.Context, platforms []string, page pagination.Pagination) (int64, []*model.Application, error)
FindPlatform(ctx context.Context, id []primitive.ObjectID) ([]string, error)
}
@@ -0,0 +1,82 @@
package mgo
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewApplicationMgo(db *mongo.Database) (*ApplicationMgo, error) {
coll := db.Collection("application")
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{
{Key: "platform", Value: 1},
{Key: "version", Value: 1},
},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{
{Key: "latest", Value: -1},
},
},
})
if err != nil {
return nil, err
}
return &ApplicationMgo{coll: coll}, nil
}
type ApplicationMgo struct {
coll *mongo.Collection
}
func (a *ApplicationMgo) sort() any {
return bson.D{{"latest", -1}, {"_id", -1}}
}
func (a *ApplicationMgo) LatestVersion(ctx context.Context, platform string) (*model.Application, error) {
return mongoutil.FindOne[*model.Application](ctx, a.coll, bson.M{"platform": platform}, options.FindOne().SetSort(a.sort()))
}
func (a *ApplicationMgo) AddVersion(ctx context.Context, val *model.Application) error {
if val.ID.IsZero() {
val.ID = primitive.NewObjectID()
}
return mongoutil.InsertMany(ctx, a.coll, []*model.Application{val})
}
func (a *ApplicationMgo) UpdateVersion(ctx context.Context, id primitive.ObjectID, update map[string]any) error {
if len(update) == 0 {
return nil
}
return mongoutil.UpdateOne(ctx, a.coll, bson.M{"_id": id}, bson.M{"$set": update}, true)
}
func (a *ApplicationMgo) DeleteVersion(ctx context.Context, id []primitive.ObjectID) error {
if len(id) == 0 {
return nil
}
return mongoutil.DeleteMany(ctx, a.coll, bson.M{"_id": bson.M{"$in": id}})
}
func (a *ApplicationMgo) PageVersion(ctx context.Context, platforms []string, page pagination.Pagination) (int64, []*model.Application, error) {
filter := bson.M{}
if len(platforms) > 0 {
filter["platform"] = bson.M{"$in": platforms}
}
return mongoutil.FindPage[*model.Application](ctx, a.coll, filter, page, options.Find().SetSort(a.sort()))
}
func (a *ApplicationMgo) FindPlatform(ctx context.Context, id []primitive.ObjectID) ([]string, error) {
if len(id) == 0 {
return nil, nil
}
return mongoutil.Find[string](ctx, a.coll, bson.M{"_id": bson.M{"$in": id}}, options.Find().SetProjection(bson.M{"_id": 0, "platform": 1}))
}
+17
View File
@@ -0,0 +1,17 @@
package model
import (
"go.mongodb.org/mongo-driver/bson/primitive"
"time"
)
type Application struct {
ID primitive.ObjectID `bson:"_id"`
Platform string `bson:"platform"`
Version string `bson:"version"`
Url string `bson:"url"`
Text string `bson:"text"`
Force bool `bson:"force"`
Latest bool `bson:"latest"`
CreateTime time.Time `bson:"create_time"`
}