feat: Integrate Comprehensive E2E Testing for GoChat (#1906)

* feat: create e2e test readme

Signed-off-by: Xinwei Xiong (cubxxw) <3293172751nss@gmail.com>

* feat: fix markdown file

* feat: add openim make lint

* feat: add git chglog pull request

* feat: add git chglog pull request

* fix: fix openim api err code

* fix: fix openim api err code

* fix: fix openim api err code

* feat: Improve CICD

* feat: Combining GitHub and Google Workspace for Effective Project Management'

* feat: fix openim tools error code

* feat: fix openim tools error code

* feat: add openim error handle

* feat: add openim error handle

* feat: optimize tim white prom code return err

* feat: fix openim tools error code

* style: format openim server code style

* feat: add openim optimize commit code

* feat: add openim optimize commit code

* feat: add openim auto format code

* feat: add openim auto format code

* feat: add openim auto format code

* feat: add openim auto format code

* feat: add openim auto format code

* feat: format openim code

* feat: Some of the notes were translated

* feat: Some of the notes were translated

* feat: update openim server code

* feat: optimize openim reset code

* feat: optimize openim reset code

---------

Signed-off-by: Xinwei Xiong (cubxxw) <3293172751nss@gmail.com>
This commit is contained in:
Xinwei Xiong
2024-03-04 12:12:14 +08:00
committed by GitHub
parent 1ef26b29a7
commit 853ac47e42
131 changed files with 1133 additions and 881 deletions
+1
View File
@@ -67,6 +67,7 @@ type LocationElem struct {
Longitude float64 `mapstructure:"longitude" validate:"required"`
Latitude float64 `mapstructure:"latitude" validate:"required"`
}
type CustomElem struct {
Data string `mapstructure:"data" validate:"required"`
Description string `mapstructure:"description"`
+2 -1
View File
@@ -44,7 +44,7 @@ func CheckAccessV3(ctx context.Context, ownerUserID string) (err error) {
if opUserID == ownerUserID {
return nil
}
return errs.ErrNoPermission.Wrap(utils.GetSelfFuncName())
return errs.Wrap(errs.ErrNoPermission, "CheckAccessV3: no permission for user "+opUserID)
}
func IsAppManagerUid(ctx context.Context) bool {
@@ -61,6 +61,7 @@ func CheckAdmin(ctx context.Context) error {
}
return errs.ErrNoPermission.Wrap(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx)))
}
func CheckIMAdmin(ctx context.Context) error {
if utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.IMAdmin.UserID) {
return nil
+26 -5
View File
@@ -15,6 +15,9 @@
package cmd
import (
"errors"
"fmt"
"github.com/OpenIMSDK/protocol/constant"
"github.com/spf13/cobra"
@@ -32,17 +35,35 @@ func NewApiCmd() *ApiCmd {
return ret
}
// AddApi configures the API command to run with specified ports for the API and Prometheus monitoring.
// It ensures error handling for port retrieval and only proceeds if both port numbers are successfully obtained.
func (a *ApiCmd) AddApi(f func(port int, promPort int) error) {
a.Command.RunE = func(cmd *cobra.Command, args []string) error {
return f(a.getPortFlag(cmd), a.getPrometheusPortFlag(cmd))
port, err := a.getPortFlag(cmd)
if err != nil {
return err
}
promPort, err := a.getPrometheusPortFlag(cmd)
if err != nil {
return err
}
return f(port, promPort)
}
}
func (a *ApiCmd) GetPortFromConfig(portType string) int {
func (a *ApiCmd) GetPortFromConfig(portType string) (int, error) {
if portType == constant.FlagPort {
return config2.Config.Api.OpenImApiPort[0]
if len(config2.Config.Api.OpenImApiPort) > 0 {
return config2.Config.Api.OpenImApiPort[0], nil
}
return 0, errors.New("API port configuration is empty or missing")
} else if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.ApiPrometheusPort[0]
if len(config2.Config.Prometheus.ApiPrometheusPort) > 0 {
return config2.Config.Prometheus.ApiPrometheusPort[0], nil
}
return 0, errors.New("Prometheus port configuration is empty or missing")
}
return 0
return 0, fmt.Errorf("unknown port type: %s", portType)
}
+4
View File
@@ -36,3 +36,7 @@ func (c *CronTaskCmd) Exec(f func() error) error {
c.addRunE(f)
return c.Execute()
}
func (c *CronTaskCmd) GetPortFromConfig(portType string) (int, error) {
return 0, nil
}
+44 -15
View File
@@ -15,13 +15,15 @@
package cmd
import (
"log"
"errors"
"github.com/spf13/cobra"
"github.com/OpenIMSDK/protocol/constant"
"github.com/openimsdk/open-im-server/v3/internal/msggateway"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/errs"
v3config "github.com/openimsdk/open-im-server/v3/pkg/common/config"
)
@@ -39,20 +41,32 @@ func (m *MsgGatewayCmd) AddWsPortFlag() {
m.Command.Flags().IntP(constant.FlagWsPort, "w", 0, "ws server listen port")
}
func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) int {
func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) (int, error) {
port, err := cmd.Flags().GetInt(constant.FlagWsPort)
if err != nil {
log.Println("Error getting ws port flag:", err)
return 0, errs.Wrap(err, "error getting ws port flag")
}
if port == 0 {
port = m.PortFromConfig(constant.FlagWsPort)
port, _ = m.PortFromConfig(constant.FlagWsPort)
}
return port
return port, nil
}
func (m *MsgGatewayCmd) addRunE() {
m.Command.RunE = func(cmd *cobra.Command, args []string) error {
return msggateway.RunWsAndServer(m.getPortFlag(cmd), m.getWsPortFlag(cmd), m.getPrometheusPortFlag(cmd))
wsPort, err := m.getWsPortFlag(cmd)
if err != nil {
return errs.Wrap(err, "failed to get WS port flag")
}
port, err := m.getPortFlag(cmd)
if err != nil {
return err
}
prometheusPort, err := m.getPrometheusPortFlag(cmd)
if err != nil {
return err
}
return msggateway.RunWsAndServer(port, wsPort, prometheusPort)
}
}
@@ -61,18 +75,33 @@ func (m *MsgGatewayCmd) Exec() error {
return m.Execute()
}
func (m *MsgGatewayCmd) GetPortFromConfig(portType string) int {
func (m *MsgGatewayCmd) GetPortFromConfig(portType string) (int, error) {
var port int
var exists bool
switch portType {
case constant.FlagWsPort:
return v3config.Config.LongConnSvr.OpenImWsPort[0]
if len(v3config.Config.LongConnSvr.OpenImWsPort) > 0 {
port = v3config.Config.LongConnSvr.OpenImWsPort[0]
exists = true
}
case constant.FlagPort:
return v3config.Config.LongConnSvr.OpenImMessageGatewayPort[0]
if len(v3config.Config.LongConnSvr.OpenImMessageGatewayPort) > 0 {
port = v3config.Config.LongConnSvr.OpenImMessageGatewayPort[0]
exists = true
}
case constant.FlagPrometheusPort:
return v3config.Config.Prometheus.MessageGatewayPrometheusPort[0]
default:
return 0
if len(v3config.Config.Prometheus.MessageGatewayPrometheusPort) > 0 {
port = v3config.Config.Prometheus.MessageGatewayPrometheusPort[0]
exists = true
}
}
if !exists {
return 0, errs.Wrap(errors.New("port type '%s' not found in configuration"), portType)
}
return port, nil
}
+1 -1
View File
@@ -44,7 +44,7 @@ func TestMsgGatewayCmd_GetPortFromConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.portType, func(t *testing.T) {
got := msgGatewayCmd.GetPortFromConfig(tt.portType)
got, _ := msgGatewayCmd.GetPortFromConfig(tt.portType)
assert.Equal(t, tt.want, got)
})
}
+13 -5
View File
@@ -37,7 +37,11 @@ func NewMsgTransferCmd() *MsgTransferCmd {
func (m *MsgTransferCmd) addRunE() {
m.Command.RunE = func(cmd *cobra.Command, args []string) error {
return msgtransfer.StartTransfer(m.getPrometheusPortFlag(cmd))
prometheusPort, err := m.getPrometheusPortFlag(cmd)
if err != nil {
return err
}
return msgtransfer.StartTransfer(prometheusPort)
}
}
@@ -46,14 +50,18 @@ func (m *MsgTransferCmd) Exec() error {
return m.Execute()
}
func (m *MsgTransferCmd) GetPortFromConfig(portType string) int {
func (m *MsgTransferCmd) GetPortFromConfig(portType string) (int, error) {
if portType == constant.FlagPort {
return 0
return 0, nil
} else if portType == constant.FlagPrometheusPort {
n := m.getTransferProgressFlagValue()
return config2.Config.Prometheus.MessageTransferPrometheusPort[n]
if n < len(config2.Config.Prometheus.MessageTransferPrometheusPort) {
return config2.Config.Prometheus.MessageTransferPrometheusPort[n], nil
}
return 0, fmt.Errorf("index out of range for MessageTransferPrometheusPort with index %d", n)
}
return 0
return 0, fmt.Errorf("unknown port type: %s", portType)
}
func (m *MsgTransferCmd) AddTransferProgressFlag() {
+2 -1
View File
@@ -18,6 +18,7 @@ import (
"github.com/spf13/cobra"
"github.com/openimsdk/open-im-server/v3/internal/tools"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
)
type MsgUtilsCmd struct {
@@ -137,7 +138,7 @@ func (s *SeqCmd) GetSeqCmd() *cobra.Command {
s.Command.Run = func(cmdLines *cobra.Command, args []string) {
_, err := tools.InitMsgTool()
if err != nil {
panic(err)
util.ExitWithError(err)
}
userID := s.getUserIDFlag(cmdLines)
superGroupID := s.getSuperGroupIDFlag(cmdLines)
+33 -19
View File
@@ -19,6 +19,8 @@ import (
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/OpenIMSDK/tools/errs"
"github.com/spf13/cobra"
"github.com/OpenIMSDK/protocol/constant"
@@ -28,8 +30,9 @@ import (
)
type RootCmdPt interface {
GetPortFromConfig(portType string) int
GetPortFromConfig(portType string) (int, error)
}
type RootCmd struct {
Command cobra.Command
Name string
@@ -77,7 +80,7 @@ func (rc *RootCmd) persistentPreRun(cmd *cobra.Command, opts ...func(*CmdOpts))
cmdOpts := rc.applyOptions(opts...)
if err := rc.initializeLogger(cmdOpts); err != nil {
return fmt.Errorf("failed to initialize from config: %w", err)
return errs.Wrap(err, "failed to initialize logger")
}
return nil
@@ -130,31 +133,41 @@ func (r *RootCmd) AddPortFlag() {
r.Command.Flags().IntP(constant.FlagPort, "p", 0, "server listen port")
}
func (r *RootCmd) getPortFlag(cmd *cobra.Command) int {
func (r *RootCmd) getPortFlag(cmd *cobra.Command) (int, error) {
port, err := cmd.Flags().GetInt(constant.FlagPort)
if err != nil {
fmt.Println("Error getting ws port flag:", err)
// Wrapping the error with additional context
return 0, errs.Wrap(err, "error getting port flag")
}
if port == 0 {
port = r.PortFromConfig(constant.FlagPort)
port, _ = r.PortFromConfig(constant.FlagPort)
// port, err := r.PortFromConfig(constant.FlagPort)
// if err != nil {
// // Optionally wrap the error if it's an internal error needing context
// return 0, errs.Wrap(err, "error getting port from config")
// }
}
return port
return port, nil
}
func (r *RootCmd) GetPortFlag() int {
return r.port
// // GetPortFlag returns the port flag.
func (r *RootCmd) GetPortFlag() (int, error) {
return r.port, nil
}
func (r *RootCmd) AddPrometheusPortFlag() {
r.Command.Flags().IntP(constant.FlagPrometheusPort, "", 0, "server prometheus listen port")
}
func (r *RootCmd) getPrometheusPortFlag(cmd *cobra.Command) int {
port, _ := cmd.Flags().GetInt(constant.FlagPrometheusPort)
if port == 0 {
port = r.PortFromConfig(constant.FlagPrometheusPort)
func (r *RootCmd) getPrometheusPortFlag(cmd *cobra.Command) (int, error) {
port, err := cmd.Flags().GetInt(constant.FlagPrometheusPort)
if err != nil || port == 0 {
port, err = r.PortFromConfig(constant.FlagPrometheusPort)
if err != nil {
return 0, err
}
}
return port
return port, nil
}
func (r *RootCmd) GetPrometheusPortFlag() int {
@@ -175,10 +188,11 @@ func (r *RootCmd) AddCommand(cmds ...*cobra.Command) {
r.Command.AddCommand(cmds...)
}
func (r *RootCmd) GetPortFromConfig(portType string) int {
return 0
}
func (r *RootCmd) PortFromConfig(portType string) int {
return r.cmdItf.GetPortFromConfig(portType)
func (r *RootCmd) PortFromConfig(portType string) (int, error) {
// Retrieve the port and cache it
port, err := r.cmdItf.GetPortFromConfig(portType)
if err != nil {
return 0, err
}
return port, nil
}
+67 -64
View File
@@ -16,11 +16,14 @@ package cmd
import (
"errors"
"fmt"
"github.com/OpenIMSDK/protocol/constant"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/OpenIMSDK/tools/errs"
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/OpenIMSDK/tools/discoveryregistry"
@@ -39,78 +42,78 @@ func NewRpcCmd(name string) *RpcCmd {
}
func (a *RpcCmd) Exec() error {
a.Command.Run = func(cmd *cobra.Command, args []string) {
a.port = a.getPortFlag(cmd)
a.prometheusPort = a.getPrometheusPortFlag(cmd)
a.Command.RunE = func(cmd *cobra.Command, args []string) error {
portFlag, err := a.getPortFlag(cmd)
if err != nil {
return err
}
a.port = portFlag
prometheusPort, err := a.getPrometheusPortFlag(cmd)
if err != nil {
return err
}
a.prometheusPort = prometheusPort
return nil
}
return a.Execute()
}
func (a *RpcCmd) StartSvr(name string, rpcFn func(discov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error) error {
if a.GetPortFlag() == 0 {
return errors.New("port is required")
portFlag, err := a.GetPortFlag()
if err != nil {
return err
} else {
a.port = portFlag
}
return startrpc.Start(a.GetPortFlag(), name, a.GetPrometheusPortFlag(), rpcFn)
return startrpc.Start(portFlag, name, a.GetPrometheusPortFlag(), rpcFn)
}
func (a *RpcCmd) GetPortFromConfig(portType string) int {
switch a.Name {
case RpcPushServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImPushPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.PushPrometheusPort[0]
}
case RpcAuthServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImAuthPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.AuthPrometheusPort[0]
}
case RpcConversationServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImConversationPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.ConversationPrometheusPort[0]
}
case RpcFriendServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImFriendPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.FriendPrometheusPort[0]
}
case RpcGroupServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImGroupPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.GroupPrometheusPort[0]
}
case RpcMsgServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImMessagePort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.MessagePrometheusPort[0]
}
case RpcThirdServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImThirdPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.ThirdPrometheusPort[0]
}
case RpcUserServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImUserPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.UserPrometheusPort[0]
func (a *RpcCmd) GetPortFromConfig(portType string) (int, error) {
portConfigMap := map[string]map[string]int{
RpcPushServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImPushPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.PushPrometheusPort[0],
},
RpcAuthServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImAuthPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.AuthPrometheusPort[0],
},
RpcConversationServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImConversationPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.ConversationPrometheusPort[0],
},
RpcFriendServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImFriendPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.FriendPrometheusPort[0],
},
RpcGroupServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImGroupPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.GroupPrometheusPort[0],
},
RpcMsgServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImMessagePort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.MessagePrometheusPort[0],
},
RpcThirdServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImThirdPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.ThirdPrometheusPort[0],
},
RpcUserServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImUserPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.UserPrometheusPort[0],
},
}
if portMap, ok := portConfigMap[a.Name]; ok {
if port, ok := portMap[portType]; ok {
return port, nil
} else {
return 0, errs.Wrap(errors.New("port type not found"), fmt.Sprintf("Failed to get port for %s", a.Name))
}
}
return 0
return 0, errs.Wrap(fmt.Errorf("server name '%s' not found", a.Name), "Failed to get port configuration")
}
+64 -33
View File
@@ -21,6 +21,7 @@ import (
"path/filepath"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/errs"
"gopkg.in/yaml.v3"
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
@@ -36,32 +37,38 @@ const (
DefaultFolderPath = "../config/"
)
// return absolude path join ../config/, this is k8s container config path.
func GetDefaultConfigPath() string {
// GetDefaultConfigPath returns the absolute path to the default configuration directory
// relative to the executable's location. It is intended for use in Kubernetes container configurations.
// Errors are returned to the caller to allow for flexible error handling.
func GetDefaultConfigPath() (string, error) {
executablePath, err := os.Executable()
if err != nil {
fmt.Println("GetDefaultConfigPath error:", err.Error())
return ""
return "", errs.Wrap(err, "failed to get executable path")
}
// Calculate the config path as a directory relative to the executable's location
configPath, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../config/"))
if err != nil {
fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err)
os.Exit(1)
return "", errs.Wrap(err, "failed to get output directory")
}
return configPath
return configPath, nil
}
// getProjectRoot returns the absolute path of the project root directory.
func GetProjectRoot() string {
executablePath, _ := os.Executable()
// GetProjectRoot returns the absolute path of the project root directory by navigating up from the directory
// containing the executable. It provides a detailed error if the path cannot be determined.
func GetProjectRoot() (string, error) {
executablePath, err := os.Executable()
if err != nil {
return "", errs.Wrap(err, "failed to retrieve executable path")
}
// Attempt to compute the project root by navigating up from the executable's directory
projectRoot, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../../../../.."))
if err != nil {
fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err)
os.Exit(1)
return "", err
}
return projectRoot
return projectRoot, nil
}
func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options {
@@ -83,42 +90,66 @@ func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options {
return opts
}
// initConfig loads configuration from a specified path into the provided config structure.
// If the specified config file does not exist, it attempts to load from the project's default "config" directory.
// It logs informative messages regarding the configuration path being used.
func initConfig(config any, configName, configFolderPath string) error {
configFolderPath = filepath.Join(configFolderPath, configName)
_, err := os.Stat(configFolderPath)
configFilePath := filepath.Join(configFolderPath, configName)
_, err := os.Stat(configFilePath)
if err != nil {
if !os.IsNotExist(err) {
fmt.Println("stat config path error:", err.Error())
return fmt.Errorf("stat config path error: %w", err)
return errs.Wrap(err, fmt.Sprintf("failed to check existence of config file at path: %s", configFilePath))
}
configFolderPath = filepath.Join(GetProjectRoot(), "config", configName)
fmt.Println("flag's path,enviment's path,default path all is not exist,using project path:", configFolderPath)
var projectRoot string
projectRoot, err = GetProjectRoot()
if err != nil {
return err
}
configFilePath = filepath.Join(projectRoot, "config", configName)
fmt.Printf("Configuration file not found at specified path. Falling back to project path: %s\n", configFilePath)
}
data, err := os.ReadFile(configFolderPath)
if err != nil {
return fmt.Errorf("read file error: %w", err)
}
if err = yaml.Unmarshal(data, config); err != nil {
return fmt.Errorf("unmarshal yaml error: %w", err)
}
fmt.Println("The path of the configuration file to start the process:", configFolderPath)
data, err := os.ReadFile(configFilePath)
if err != nil {
// Wrap and return the error if reading the configuration file fails.
return errs.Wrap(err, fmt.Sprintf("failed to read configuration file at path: %s", configFilePath))
}
if err = yaml.Unmarshal(data, config); err != nil {
// Wrap and return the error if unmarshalling the YAML configuration fails.
return errs.Wrap(err, "failed to unmarshal YAML configuration")
}
fmt.Printf("Configuration file loaded successfully from path: %s\n", configFilePath)
return nil
}
// InitConfig initializes the application configuration by loading it from a specified folder path.
// If the folder path is not provided, it attempts to use the OPENIMCONFIG environment variable,
// and as a fallback, it uses the default configuration path. It loads both the main configuration
// and notification configuration, wrapping errors for better context.
func InitConfig(configFolderPath string) error {
// Use the provided config folder path, or fallback to environment variable or default path
if configFolderPath == "" {
envConfigPath := os.Getenv("OPENIMCONFIG")
if envConfigPath != "" {
configFolderPath = envConfigPath
} else {
configFolderPath = GetDefaultConfigPath()
configFolderPath = os.Getenv("OPENIMCONFIG")
if configFolderPath == "" {
var err error
configFolderPath, err = GetDefaultConfigPath()
if err != nil {
return err
}
}
}
// Initialize the main configuration
if err := initConfig(&Config, FileName, configFolderPath); err != nil {
return err
}
return initConfig(&Config.Notification, NotificationFileName, configFolderPath)
// Initialize the notification configuration
if err := initConfig(&Config.Notification, NotificationFileName, configFolderPath); err != nil {
return err
}
return nil
}
+2 -2
View File
@@ -31,7 +31,7 @@ func TestGetDefaultConfigPath(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetDefaultConfigPath(); got != tt.want {
if got, _ := GetDefaultConfigPath(); got != tt.want {
t.Errorf("GetDefaultConfigPath() = %v, want %v", got, tt.want)
}
})
@@ -47,7 +47,7 @@ func TestGetProjectRoot(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetProjectRoot(); got != tt.want {
if got, _ := GetProjectRoot(); got != tt.want {
t.Errorf("GetProjectRoot() = %v, want %v", got, tt.want)
}
})
+1 -5
View File
@@ -23,11 +23,7 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
)
func BlackDB2Pb(
ctx context.Context,
blackDBs []*relation.BlackModel,
f func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error),
) (blackPbs []*sdk.BlackInfo, err error) {
func BlackDB2Pb(ctx context.Context, blackDBs []*relation.BlackModel, f func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error)) (blackPbs []*sdk.BlackInfo, err error) {
if len(blackDBs) == 0 {
return nil, nil
}
+2 -7
View File
@@ -53,11 +53,7 @@ func FriendDB2Pb(ctx context.Context, friendDB *relation.FriendModel,
}, nil
}
func FriendsDB2Pb(
ctx context.Context,
friendsDB []*relation.FriendModel,
getUsers func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error),
) (friendsPb []*sdkws.FriendInfo, err error) {
func FriendsDB2Pb(ctx context.Context, friendsDB []*relation.FriendModel, getUsers func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error)) (friendsPb []*sdkws.FriendInfo, err error) {
if len(friendsDB) == 0 {
return nil, nil
}
@@ -89,8 +85,7 @@ func FriendsDB2Pb(
}
func FriendRequestDB2Pb(
ctx context.Context,
func FriendRequestDB2Pb(ctx context.Context,
friendRequests []*relation.FriendRequestModel,
getUsers func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error),
) ([]*sdkws.FriendRequest, error) {
+1 -5
View File
@@ -46,11 +46,7 @@ type BlackCacheRedis struct {
blackDB relationtb.BlackModelInterface
}
func NewBlackCacheRedis(
rdb redis.UniversalClient,
blackDB relationtb.BlackModelInterface,
options rockscache.Options,
) BlackCache {
func NewBlackCacheRedis(rdb redis.UniversalClient, blackDB relationtb.BlackModelInterface, options rockscache.Options) BlackCache {
rcClient := rockscache.NewClient(rdb, options)
return &BlackCacheRedis{
+9 -7
View File
@@ -49,7 +49,7 @@ func NewRedis() (redis.UniversalClient, error) {
overrideConfigFromEnv()
if len(config.Config.Redis.Address) == 0 {
return nil, errs.Wrap(errors.New("redis address is empty"))
return nil, errs.Wrap(errors.New("redis address is empty"), "Redis configuration error")
}
specialerror.AddReplace(redis.Nil, errs.ErrRecordNotFound)
var rdb redis.UniversalClient
@@ -65,9 +65,9 @@ func NewRedis() (redis.UniversalClient, error) {
rdb = redis.NewClient(&redis.Options{
Addr: config.Config.Redis.Address[0],
Username: config.Config.Redis.Username,
Password: config.Config.Redis.Password,
DB: 0, // use default DB
PoolSize: 100, // connection pool size
Password: config.Config.Redis.Password, // no password set
DB: 0, // use default DB
PoolSize: 100, // connection pool size
MaxRetries: maxRetry,
})
}
@@ -77,9 +77,9 @@ func NewRedis() (redis.UniversalClient, error) {
defer cancel()
err = rdb.Ping(ctx).Err()
if err != nil {
uriFormat := "address:%s, username:%s, password:%s, clusterMode:%t, enablePipeline:%t"
errMsg := fmt.Sprintf(uriFormat, config.Config.Redis.Address, config.Config.Redis.Username, config.Config.Redis.Password, config.Config.Redis.ClusterMode, config.Config.Redis.EnablePipeline)
return nil, errs.Wrap(err, errMsg)
uriFormat := "address:%v, username:%s, clusterMode:%t, enablePipeline:%t"
errMsg := fmt.Sprintf(uriFormat, config.Config.Redis.Address, config.Config.Redis.Username, config.Config.Redis.ClusterMode, config.Config.Redis.EnablePipeline)
return nil, errs.Wrap(err, "Redis connection failed: %s", errMsg)
}
redisClient = rdb
return rdb, err
@@ -98,9 +98,11 @@ func overrideConfigFromEnv() {
config.Config.Redis.Address = strings.Split(envAddr, ",")
}
}
if envUser := os.Getenv("REDIS_USERNAME"); envUser != "" {
config.Config.Redis.Username = envUser
}
if envPass := os.Getenv("REDIS_PASSWORD"); envPass != "" {
config.Config.Redis.Password = envPass
}
+3 -11
View File
@@ -134,7 +134,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
}
bs, err := json.Marshal(t)
if err != nil {
return "", errs.Wrap(err)
return "", errs.Wrap(err, "marshal failed")
}
write = true
@@ -152,8 +152,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
err = json.Unmarshal([]byte(v), &t)
if err != nil {
log.ZError(ctx, "cache json.Unmarshal failed", err, "key", key, "value", v, "expire", expire)
return t, errs.Wrap(err)
return t, errs.Wrap(err, "unmarshal failed")
}
return t, nil
@@ -197,14 +196,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
// return tArrays, nil
//}
func batchGetCache2[T any, K comparable](
ctx context.Context,
rcClient *rockscache.Client,
expire time.Duration,
keys []K,
keyFn func(key K) string,
fns func(ctx context.Context, key K) (T, error),
) ([]T, error) {
func batchGetCache2[T any, K comparable](ctx context.Context, rcClient *rockscache.Client, expire time.Duration, keys []K, keyFn func(key K) string, fns func(ctx context.Context, key K) (T, error)) ([]T, error) {
if len(keys) == 0 {
return nil, nil
}
+3 -4
View File
@@ -24,12 +24,11 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
"github.com/OpenIMSDK/tools/errs"
"github.com/gogo/protobuf/jsonpb"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/utils"
@@ -433,7 +432,7 @@ func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID str
for _, msg := range msgs {
s, err := msgprocessor.Pb2String(msg)
if err != nil {
return 0, errs.Wrap(err, "pb.marshal")
return 0, err
}
key := c.getMessageCacheKey(conversationID, msg.Seq)
@@ -442,7 +441,7 @@ func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID str
results, err := pipe.Exec(ctx)
if err != nil {
return 0, errs.Wrap(err, "pipe.set")
return 0, errs.Wrap(err)
}
for _, res := range results {
+4 -8
View File
@@ -66,11 +66,7 @@ type UserCacheRedis struct {
rcClient *rockscache.Client
}
func NewUserCacheRedis(
rdb redis.UniversalClient,
userDB relationtb.UserModelInterface,
options rockscache.Options,
) UserCache {
func NewUserCacheRedis(rdb redis.UniversalClient, userDB relationtb.UserModelInterface, options rockscache.Options) UserCache {
rcClient := rockscache.NewClient(rdb, options)
return &UserCacheRedis{
@@ -282,8 +278,8 @@ func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string,
var onlineStatus user.OnlineStatus
if !isNil {
err2 := json.Unmarshal([]byte(result), &onlineStatus)
if err != nil {
return errs.Wrap(err2)
if err2 != nil {
return errs.Wrap(err, "json.Unmarshal failed")
}
onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID))
} else {
@@ -293,7 +289,7 @@ func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string,
onlineStatus.UserID = userID
newjsonData, err := json.Marshal(&onlineStatus)
if err != nil {
return errs.Wrap(err)
return errs.Wrap(err, "json.Marshal failed")
}
_, err = u.rdb.HSet(ctx, key, userID, string(newjsonData)).Result()
if err != nil {
+6 -10
View File
@@ -30,9 +30,9 @@ import (
)
type AuthDatabase interface {
// 结果为空 不返回错误
// If the result is empty, no error is returned.
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
// 创建token
// Create token
CreateToken(ctx context.Context, userID string, platformID int) (string, error)
}
@@ -47,16 +47,12 @@ func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire}
}
// 结果为空 不返回错误.
func (a *authDatabase) GetTokensWithoutError(
ctx context.Context,
userID string,
platformID int,
) (map[string]int, error) {
// If the result is empty.
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
return a.cache.GetTokensWithoutError(ctx, userID, platformID)
}
// 创建token.
// Create Token.
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID)
if err != nil {
@@ -80,7 +76,7 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(a.accessSecret))
if err != nil {
return "", errs.Wrap(err)
return "", errs.Wrap(err, "token.SignedString")
}
return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken)
}
+12 -12
View File
@@ -27,14 +27,14 @@ import (
)
type BlackDatabase interface {
// Create 增加黑名单
// Create add BlackList
Create(ctx context.Context, blacks []*relation.BlackModel) (err error)
// Delete 删除黑名单
// Delete delete BlackList
Delete(ctx context.Context, blacks []*relation.BlackModel) (err error)
// FindOwnerBlacks 获取黑名单列表
// FindOwnerBlacks get BlackList list
FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*relation.BlackModel, err error)
FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*relation.BlackModel, err error)
// CheckIn 检查user2是否在user1的黑名单列表中(inUser1Blacks==true) 检查user1是否在user2的黑名单列表中(inUser2Blacks==true)
// CheckIn Check whether user2 is in the black list of user1 (inUser1Blacks==true) Check whether user1 is in the black list of user2 (inUser2Blacks==true)
CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error)
}
@@ -47,7 +47,7 @@ func NewBlackDatabase(black relation.BlackModelInterface, cache cache.BlackCache
return &blackDatabase{black, cache}
}
// Create 增加黑名单.
// Create Add Blacklist.
func (b *blackDatabase) Create(ctx context.Context, blacks []*relation.BlackModel) (err error) {
if err := b.black.Create(ctx, blacks); err != nil {
return err
@@ -55,7 +55,7 @@ func (b *blackDatabase) Create(ctx context.Context, blacks []*relation.BlackMode
return b.deleteBlackIDsCache(ctx, blacks)
}
// Delete 删除黑名单.
// Delete Delete Blacklist.
func (b *blackDatabase) Delete(ctx context.Context, blacks []*relation.BlackModel) (err error) {
if err := b.black.Delete(ctx, blacks); err != nil {
return err
@@ -63,6 +63,7 @@ func (b *blackDatabase) Delete(ctx context.Context, blacks []*relation.BlackMode
return b.deleteBlackIDsCache(ctx, blacks)
}
// FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) deleteBlackIDsCache(ctx context.Context, blacks []*relation.BlackModel) (err error) {
cache := b.cache.NewCache()
for _, black := range blacks {
@@ -71,16 +72,13 @@ func (b *blackDatabase) deleteBlackIDsCache(ctx context.Context, blacks []*relat
return cache.ExecDel(ctx)
}
// FindOwnerBlacks 获取黑名单列表.
// FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*relation.BlackModel, err error) {
return b.black.FindOwnerBlacks(ctx, ownerUserID, pagination)
}
// CheckIn 检查user2是否在user1的黑名单列表中(inUser1Blacks==true) 检查user1是否在user2的黑名单列表中(inUser2Blacks==true).
func (b *blackDatabase) CheckIn(
ctx context.Context,
userID1, userID2 string,
) (inUser1Blacks bool, inUser2Blacks bool, err error) {
// FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error) {
userID1BlackIDs, err := b.cache.GetBlackIDs(ctx, userID1)
if err != nil {
return
@@ -93,10 +91,12 @@ func (b *blackDatabase) CheckIn(
return utils.IsContain(userID2, userID1BlackIDs), utils.IsContain(userID1, userID2BlackIDs), nil
}
// FindBlackIDs Get Blacklist List.
func (b *blackDatabase) FindBlackIDs(ctx context.Context, ownerUserID string) (blackIDs []string, err error) {
return b.cache.GetBlackIDs(ctx, ownerUserID)
}
// FindBlackInfos Get Blacklist List.
func (b *blackDatabase) FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*relation.BlackModel, err error) {
return b.black.FindOwnerBlackInfos(ctx, ownerUserID, userIDs)
}
+26 -18
View File
@@ -32,32 +32,40 @@ import (
)
type ConversationDatabase interface {
// UpdateUserConversationFiled 更新用户该会话的属性信息
UpdateUsersConversationFiled(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error
// CreateConversation 创建一批新的会话
// UpdateUsersConversationField updates the properties of a conversation for specified users.
UpdateUsersConversationField(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error
// CreateConversation creates a batch of new conversations.
CreateConversation(ctx context.Context, conversations []*relationtb.ConversationModel) error
// SyncPeerUserPrivateConversation 同步对端私聊会话内部保证事务操作
// SyncPeerUserPrivateConversationTx ensures transactional operation while syncing private conversations between peers.
SyncPeerUserPrivateConversationTx(ctx context.Context, conversation []*relationtb.ConversationModel) error
// FindConversations 根据会话ID获取某个用户的多个会话
// FindConversations retrieves multiple conversations of a user by conversation IDs.
FindConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*relationtb.ConversationModel, error)
// FindRecvMsgNotNotifyUserIDs 获取超级大群开启免打扰的用户ID
//FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error)
// GetUserAllConversation 获取一个用户在服务器上所有的会话
// GetUserAllConversation fetches all conversations of a user on the server.
GetUserAllConversation(ctx context.Context, ownerUserID string) ([]*relationtb.ConversationModel, error)
// SetUserConversations 设置用户多个会话属性,如果会话不存在则创建,否则更新,内部保证原子性
// SetUserConversations sets multiple conversation properties for a user, creates new conversations if they do not exist, or updates them otherwise. This operation is atomic.
SetUserConversations(ctx context.Context, ownerUserID string, conversations []*relationtb.ConversationModel) error
// SetUsersConversationFiledTx 设置多个用户会话关于某个字段的更新操作,如果会话不存在则创建,否则更新,内部保证事务操作
SetUsersConversationFiledTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, filedMap map[string]any) error
// SetUsersConversationFieldTx updates a specific field for multiple users' conversations, creating new conversations if they do not exist, or updates them otherwise. This operation is transactional.
SetUsersConversationFieldTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, fieldMap map[string]any) error
// CreateGroupChatConversation creates a group chat conversation for the specified group ID and user IDs.
CreateGroupChatConversation(ctx context.Context, groupID string, userIDs []string) error
// GetConversationIDs retrieves conversation IDs for a given user.
GetConversationIDs(ctx context.Context, userID string) ([]string, error)
// GetUserConversationIDsHash gets the hash of conversation IDs for a given user.
GetUserConversationIDsHash(ctx context.Context, ownerUserID string) (hash uint64, err error)
// GetAllConversationIDs fetches all conversation IDs.
GetAllConversationIDs(ctx context.Context) ([]string, error)
// GetAllConversationIDsNumber returns the number of all conversation IDs.
GetAllConversationIDsNumber(ctx context.Context) (int64, error)
// PageConversationIDs paginates through conversation IDs based on the specified pagination settings.
PageConversationIDs(ctx context.Context, pagination pagination.Pagination) (conversationIDs []string, err error)
//GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error)
// GetConversationsByConversationID retrieves conversations by their IDs.
GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*relationtb.ConversationModel, error)
// GetConversationIDsNeedDestruct fetches conversations that need to be destructed based on specific criteria.
GetConversationIDsNeedDestruct(ctx context.Context) ([]*relationtb.ConversationModel, error)
// GetConversationNotReceiveMessageUserIDs gets user IDs for users in a conversation who have not received messages.
GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error)
//GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error)
//FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error)
}
func NewConversationDatabase(conversation relationtb.ConversationModelInterface, cache cache.ConversationCache, tx tx.CtxTx) ConversationDatabase {
@@ -74,7 +82,7 @@ type conversationDatabase struct {
tx tx.CtxTx
}
func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, filedMap map[string]any) (err error) {
func (c *conversationDatabase) SetUsersConversationFieldTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, fieldMap map[string]any) (err error) {
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.NewCache()
if conversation.GroupID != "" {
@@ -85,22 +93,22 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context,
return err
}
if len(haveUserIDs) > 0 {
_, err = c.conversationDB.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, filedMap)
_, err = c.conversationDB.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, fieldMap)
if err != nil {
return err
}
cache = cache.DelUsersConversation(conversation.ConversationID, haveUserIDs...)
if _, ok := filedMap["has_read_seq"]; ok {
if _, ok := fieldMap["has_read_seq"]; ok {
for _, userID := range haveUserIDs {
cache = cache.DelUserAllHasReadSeqs(userID, conversation.ConversationID)
}
}
if _, ok := filedMap["recv_msg_opt"]; ok {
if _, ok := fieldMap["recv_msg_opt"]; ok {
cache = cache.DelConversationNotReceiveMessageUserIDs(conversation.ConversationID)
}
}
NotUserIDs := utils.DifferenceString(haveUserIDs, userIDs)
log.ZDebug(ctx, "SetUsersConversationFiledTx", "NotUserIDs", NotUserIDs, "haveUserIDs", haveUserIDs, "userIDs", userIDs)
log.ZDebug(ctx, "SetUsersConversationFieldTx", "NotUserIDs", NotUserIDs, "haveUserIDs", haveUserIDs, "userIDs", userIDs)
var conversations []*relationtb.ConversationModel
now := time.Now()
for _, v := range NotUserIDs {
@@ -123,7 +131,7 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context,
})
}
func (c *conversationDatabase) UpdateUsersConversationFiled(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error {
func (c *conversationDatabase) UpdateUsersConversationField(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error {
_, err := c.conversationDB.UpdateByMap(ctx, userIDs, conversationID, args)
if err != nil {
return err
+46 -24
View File
@@ -16,6 +16,7 @@ package controller
import (
"context"
"fmt"
"time"
"github.com/OpenIMSDK/tools/pagination"
@@ -89,20 +90,30 @@ func NewFriendDatabase(friend relation.FriendModelInterface, friendRequest relat
return &friendDatabase{friend: friend, friendRequest: friendRequest, cache: cache, tx: tx}
}
// ok 检查user2是否在user1的好友列表中(inUser1Friends==true) 检查user1是否在user2的好友列表中(inUser2Friends==true).
// CheckIn verifies if user2 is in user1's friend list (inUser1Friends returns true) and
// if user1 is in user2's friend list (inUser2Friends returns true).
func (f *friendDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Friends bool, inUser2Friends bool, err error) {
// Retrieve friend IDs of userID1 from the cache
userID1FriendIDs, err := f.cache.GetFriendIDs(ctx, userID1)
if err != nil {
err = fmt.Errorf("error retrieving friend IDs for user %s: %w", userID1, err)
return
}
// Retrieve friend IDs of userID2 from the cache
userID2FriendIDs, err := f.cache.GetFriendIDs(ctx, userID2)
if err != nil {
err = fmt.Errorf("error retrieving friend IDs for user %s: %w", userID2, err)
return
}
return utils.IsContain(userID2, userID1FriendIDs), utils.IsContain(userID1, userID2FriendIDs), nil
// Check if userID2 is in userID1's friend list and vice versa
inUser1Friends = utils.IsContain(userID2, userID1FriendIDs)
inUser2Friends = utils.IsContain(userID1, userID2FriendIDs)
return inUser1Friends, inUser2Friends, nil
}
// 增加或者更新好友申请 如果之前有记录则更新,没有记录则新增.
// AddFriendRequest adds or updates a friend request.
func (f *friendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error {
_, err := f.friendRequest.Take(ctx, fromUserID, toUserID)
@@ -126,11 +137,11 @@ func (f *friendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUse
})
}
// (1)先判断是否在好友表 (在不在都不返回错误) (2)对于不在好友列表的 插入即可.
// (1) First determine whether it is in the friends list (in or out does not return an error) (2) for not in the friends list can be inserted.
func (f *friendDatabase) BecomeFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, addSource int32) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error {
cache := f.cache.NewCache()
// 先find 找出重复的 去掉重复的
// User find friends
fs1, err := f.friend.FindFriends(ctx, ownerUserID, friendUserIDs)
if err != nil {
return err
@@ -170,26 +181,37 @@ func (f *friendDatabase) BecomeFriends(ctx context.Context, ownerUserID string,
})
}
// 拒绝好友申请 (1)检查是否有申请记录且为未处理状态 (没有记录返回错误) (2)修改申请记录 已拒绝.
func (f *friendDatabase) RefuseFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) (err error) {
// RefuseFriendRequest rejects a friend request. It first checks for an existing, unprocessed request.
// If no such request exists, it returns an error. Otherwise, it marks the request as refused.
func (f *friendDatabase) RefuseFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) error {
// Attempt to retrieve the friend request from the database.
fr, err := f.friendRequest.Take(ctx, friendRequest.FromUserID, friendRequest.ToUserID)
if err != nil {
return err
return fmt.Errorf("failed to retrieve friend request from %s to %s: %w", friendRequest.FromUserID, friendRequest.ToUserID, err)
}
// Check if the friend request has already been handled.
if fr.HandleResult != 0 {
return errs.ErrArgs.Wrap("the friend request has been processed")
return fmt.Errorf("friend request from %s to %s has already been processed", friendRequest.FromUserID, friendRequest.ToUserID)
}
log.ZDebug(ctx, "refuse friend request", "friendRequest db", fr, "friendRequest arg", friendRequest)
// Log the action of refusing the friend request for debugging and auditing purposes.
log.ZDebug(ctx, "Refusing friend request", map[string]interface{}{
"DB_FriendRequest": fr,
"Arg_FriendRequest": friendRequest,
})
// Mark the friend request as refused and update the handle time.
friendRequest.HandleResult = constant.FriendResponseRefuse
friendRequest.HandleTime = time.Now()
err = f.friendRequest.Update(ctx, friendRequest)
if err != nil {
return err
if err := f.friendRequest.Update(ctx, friendRequest); err != nil {
return fmt.Errorf("failed to update friend request from %s to %s as refused: %w", friendRequest.FromUserID, friendRequest.ToUserID, err)
}
return nil
}
// AgreeFriendRequest 同意好友申请 (1)检查是否有申请记录且为未处理状态 (没有记录返回错误) (2)检查是否好友(不返回错误) (3) 建立双向好友关系(存在的忽略).
// AgreeFriendRequest accepts a friend request. It first checks for an existing, unprocessed request.
func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error {
defer log.ZDebug(ctx, "return line")
@@ -227,10 +249,10 @@ func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *
return err
}
existsMap := utils.SliceSet(utils.Slice(exists, func(friend *relation.FriendModel) [2]string {
return [...]string{friend.OwnerUserID, friend.FriendUserID} // 自己 - 好友
return [...]string{friend.OwnerUserID, friend.FriendUserID} // My - Friend
}))
var adds []*relation.FriendModel
if _, ok := existsMap[[...]string{friendRequest.ToUserID, friendRequest.FromUserID}]; !ok { // 自己 - 好友
if _, ok := existsMap[[...]string{friendRequest.ToUserID, friendRequest.FromUserID}]; !ok { // My - Friend
adds = append(
adds,
&relation.FriendModel{
@@ -241,7 +263,7 @@ func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *
},
)
}
if _, ok := existsMap[[...]string{friendRequest.FromUserID, friendRequest.ToUserID}]; !ok { // 好友 - 自己
if _, ok := existsMap[[...]string{friendRequest.FromUserID, friendRequest.ToUserID}]; !ok { // My - Friend
adds = append(
adds,
&relation.FriendModel{
@@ -261,7 +283,7 @@ func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *
})
}
// 删除好友 外部判断是否好友关系.
// Delete removes a friend relationship. It is assumed that the external caller has verified the friendship status.
func (f *friendDatabase) Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) (err error) {
if err := f.friend.Delete(ctx, ownerUserID, friendUserIDs); err != nil {
return err
@@ -269,7 +291,7 @@ func (f *friendDatabase) Delete(ctx context.Context, ownerUserID string, friendU
return f.cache.DelFriendIDs(append(friendUserIDs, ownerUserID)...).ExecDel(ctx)
}
// 更新好友备注 零值也支持.
// UpdateRemark updates the remark for a friend. Zero value for remark is also supported.
func (f *friendDatabase) UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error) {
if err := f.friend.UpdateRemark(ctx, ownerUserID, friendUserID, remark); err != nil {
return err
@@ -277,27 +299,27 @@ func (f *friendDatabase) UpdateRemark(ctx context.Context, ownerUserID, friendUs
return f.cache.DelFriend(ownerUserID, friendUserID).ExecDel(ctx)
}
// 获取ownerUserID的好友列表 无结果不返回错误.
// PageOwnerFriends retrieves the list of friends for the ownerUserID. It does not return an error if the result is empty.
func (f *friendDatabase) PageOwnerFriends(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendModel, err error) {
return f.friend.FindOwnerFriends(ctx, ownerUserID, pagination)
}
// friendUserID在哪些人的好友列表中.
// PageInWhoseFriends identifies in whose friend lists the friendUserID appears.
func (f *friendDatabase) PageInWhoseFriends(ctx context.Context, friendUserID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendModel, err error) {
return f.friend.FindInWhoseFriends(ctx, friendUserID, pagination)
}
// 获取我发出去的好友申请 无结果不返回错误.
// PageFriendRequestFromMe retrieves friend requests sent by me. It does not return an error if the result is empty.
func (f *friendDatabase) PageFriendRequestFromMe(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendRequestModel, err error) {
return f.friendRequest.FindFromUserID(ctx, userID, pagination)
}
// 获取我收到的的好友申请 无结果不返回错误.
// PageFriendRequestToMe retrieves friend requests received by me. It does not return an error if the result is empty.
func (f *friendDatabase) PageFriendRequestToMe(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendRequestModel, err error) {
return f.friendRequest.FindToUserID(ctx, userID, pagination)
}
// 获取某人指定好友的信息 如果有好友不存在,也返回错误.
// FindFriendsWithError retrieves specified friends' information for ownerUserID. Returns an error if any friend does not exist.
func (f *friendDatabase) FindFriendsWithError(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*relation.FriendModel, err error) {
friends, err = f.friend.FindFriends(ctx, ownerUserID, friendUserIDs)
if err != nil {
+42 -10
View File
@@ -31,47 +31,79 @@ import (
)
type GroupDatabase interface {
// Group
// CreateGroup creates new groups along with their members.
CreateGroup(ctx context.Context, groups []*relationtb.GroupModel, groupMembers []*relationtb.GroupMemberModel) error
// TakeGroup retrieves a single group by its ID.
TakeGroup(ctx context.Context, groupID string) (group *relationtb.GroupModel, err error)
// FindGroup retrieves multiple groups by their IDs.
FindGroup(ctx context.Context, groupIDs []string) (groups []*relationtb.GroupModel, err error)
// SearchGroup searches for groups based on a keyword and pagination settings, returns total count and groups.
SearchGroup(ctx context.Context, keyword string, pagination pagination.Pagination) (int64, []*relationtb.GroupModel, error)
// UpdateGroup updates the properties of a group identified by its ID.
UpdateGroup(ctx context.Context, groupID string, data map[string]any) error
DismissGroup(ctx context.Context, groupID string, deleteMember bool) error // 解散群,并删除群成员
// DismissGroup disbands a group and optionally removes its members based on the deleteMember flag.
DismissGroup(ctx context.Context, groupID string, deleteMember bool) error
// TakeGroupMember retrieves a specific group member by group ID and user ID.
TakeGroupMember(ctx context.Context, groupID string, userID string) (groupMember *relationtb.GroupMemberModel, err error)
// TakeGroupOwner retrieves the owner of a group by group ID.
TakeGroupOwner(ctx context.Context, groupID string) (*relationtb.GroupMemberModel, error)
FindGroupMembers(ctx context.Context, groupID string, userIDs []string) (groupMembers []*relationtb.GroupMemberModel, err error) // *
FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) (groupMembers []*relationtb.GroupMemberModel, err error) // *
FindGroupMemberRoleLevels(ctx context.Context, groupID string, roleLevels []int32) (groupMembers []*relationtb.GroupMemberModel, err error) // *
FindGroupMemberAll(ctx context.Context, groupID string) (groupMembers []*relationtb.GroupMemberModel, err error) // *
// FindGroupMembers retrieves members of a group filtered by user IDs.
FindGroupMembers(ctx context.Context, groupID string, userIDs []string) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupMemberUser retrieves groups that a user is a member of, filtered by group IDs.
FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupMemberRoleLevels retrieves group members filtered by their role levels within a group.
FindGroupMemberRoleLevels(ctx context.Context, groupID string, roleLevels []int32) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupMemberAll retrieves all members of a group.
FindGroupMemberAll(ctx context.Context, groupID string) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupsOwner retrieves the owners for multiple groups.
FindGroupsOwner(ctx context.Context, groupIDs []string) ([]*relationtb.GroupMemberModel, error)
// FindGroupMemberUserID retrieves the user IDs of all members in a group.
FindGroupMemberUserID(ctx context.Context, groupID string) ([]string, error)
// FindGroupMemberNum retrieves the number of members in a group.
FindGroupMemberNum(ctx context.Context, groupID string) (uint32, error)
// FindUserManagedGroupID retrieves group IDs managed by a user.
FindUserManagedGroupID(ctx context.Context, userID string) (groupIDs []string, err error)
// PageGroupRequest paginates through group requests for specified groups.
PageGroupRequest(ctx context.Context, groupIDs []string, pagination pagination.Pagination) (int64, []*relationtb.GroupRequestModel, error)
// GetGroupRoleLevelMemberIDs retrieves user IDs of group members with a specific role level.
GetGroupRoleLevelMemberIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error)
// PageGetJoinGroup paginates through groups that a user has joined.
PageGetJoinGroup(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*relationtb.GroupMemberModel, err error)
// PageGetGroupMember paginates through members of a group.
PageGetGroupMember(ctx context.Context, groupID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*relationtb.GroupMemberModel, err error)
// SearchGroupMember searches for group members based on a keyword, group ID, and pagination settings.
SearchGroupMember(ctx context.Context, keyword string, groupID string, pagination pagination.Pagination) (int64, []*relationtb.GroupMemberModel, error)
// HandlerGroupRequest processes a group join request with a specified result.
HandlerGroupRequest(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32, member *relationtb.GroupMemberModel) error
// DeleteGroupMember removes specified users from a group.
DeleteGroupMember(ctx context.Context, groupID string, userIDs []string) error
// MapGroupMemberUserID maps group IDs to their members' simplified user IDs.
MapGroupMemberUserID(ctx context.Context, groupIDs []string) (map[string]*relationtb.GroupSimpleUserID, error)
// MapGroupMemberNum maps group IDs to their member count.
MapGroupMemberNum(ctx context.Context, groupIDs []string) (map[string]uint32, error)
TransferGroupOwner(ctx context.Context, groupID string, oldOwnerUserID, newOwnerUserID string, roleLevel int32) error // 转让群
// TransferGroupOwner transfers the ownership of a group to another user.
TransferGroupOwner(ctx context.Context, groupID string, oldOwnerUserID, newOwnerUserID string, roleLevel int32) error
// UpdateGroupMember updates properties of a group member.
UpdateGroupMember(ctx context.Context, groupID string, userID string, data map[string]any) error
// UpdateGroupMembers batch updates properties of group members.
UpdateGroupMembers(ctx context.Context, data []*relationtb.BatchUpdateGroupMember) error
// GroupRequest
// CreateGroupRequest creates new group join requests.
CreateGroupRequest(ctx context.Context, requests []*relationtb.GroupRequestModel) error
// TakeGroupRequest retrieves a specific group join request.
TakeGroupRequest(ctx context.Context, groupID string, userID string) (*relationtb.GroupRequestModel, error)
// FindGroupRequests retrieves multiple group join requests.
FindGroupRequests(ctx context.Context, groupID string, userIDs []string) ([]*relationtb.GroupRequestModel, error)
// PageGroupRequestUser paginates through group join requests made by a user.
PageGroupRequestUser(ctx context.Context, userID string, pagination pagination.Pagination) (int64, []*relationtb.GroupRequestModel, error)
// 获取群总数
// CountTotal counts the total number of groups as of a certain date.
CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// 获取范围内群增量
// CountRangeEverydayTotal counts the daily group creation total within a specified date range.
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
// DeleteGroupMemberHash deletes the hash entries for group members in specified groups.
DeleteGroupMemberHash(ctx context.Context, groupIDs []string) error
}
+25 -26
View File
@@ -48,33 +48,32 @@ const (
updateKeyRevoke
)
// CommonMsgDatabase defines the interface for message database operations.
type CommonMsgDatabase interface {
// 批量插入消息
// BatchInsertChat2DB inserts a batch of messages into the database for a specific conversation.
BatchInsertChat2DB(ctx context.Context, conversationID string, msgs []*sdkws.MsgData, currentMaxSeq int64) error
// 撤回消息
// RevokeMsg revokes a message in a conversation.
RevokeMsg(ctx context.Context, conversationID string, seq int64, revoke *unrelationtb.RevokeModel) error
// mark as read
// MarkSingleChatMsgsAsRead marks messages as read for a single chat by sequence numbers.
MarkSingleChatMsgsAsRead(ctx context.Context, userID string, conversationID string, seqs []int64) error
// 刪除redis中消息缓存
// DeleteMessagesFromCache deletes message caches from Redis by sequence numbers.
DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error
// DelUserDeleteMsgsList deletes user's message deletion list.
DelUserDeleteMsgsList(ctx context.Context, conversationID string, seqs []int64)
// incrSeq然后批量插入缓存
// BatchInsertChat2Cache increments the sequence number and then batch inserts messages into the cache.
BatchInsertChat2Cache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (seq int64, isNewConversation bool, err error)
// 通过seqList获取mongo中写扩散消息
// GetMsgBySeqsRange retrieves messages from MongoDB by a range of sequence numbers.
GetMsgBySeqsRange(ctx context.Context, userID string, conversationID string, begin, end, num, userMaxSeq int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error)
// 通过seqList获取大群在 mongo里面的消息
// GetMsgBySeqs retrieves messages for large groups from MongoDB by sequence numbers.
GetMsgBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error)
// 删除会话消息重置最小seq, remainTime为消息保留的时间单位秒,超时消息删除, 传0删除所有消息(此方法不删除redis cache)
// DeleteConversationMsgsAndSetMinSeq deletes conversation messages and resets the minimum sequence number. If `remainTime` is 0, all messages are deleted (this method does not delete Redis cache).
DeleteConversationMsgsAndSetMinSeq(ctx context.Context, conversationID string, remainTime int64) error
// 用户标记删除过期消息返回标记删除的seq列表
// UserMsgsDestruct marks messages for deletion based on destruct time and returns a list of sequence numbers for marked messages.
UserMsgsDestruct(ctx context.Context, userID string, conversationID string, destructTime int64, lastMsgDestructTime time.Time) (seqs []int64, err error)
// 用户根据seq删除消息
// DeleteUserMsgsBySeqs allows a user to delete messages based on sequence numbers.
DeleteUserMsgsBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) error
// 物理删除消息置空
// DeleteMsgsPhysicalBySeqs physically deletes messages by emptying them based on sequence numbers.
DeleteMsgsPhysicalBySeqs(ctx context.Context, conversationID string, seqs []int64) error
SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error
GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error)
GetMaxSeq(ctx context.Context, conversationID string) (int64, error)
@@ -200,7 +199,7 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
}
num := db.msg.GetSingleGocMsgNum()
// num = 100
for i, field := range fields { // 检查类型
for i, field := range fields { // Check the type of the field
var ok bool
switch key {
case updateKeyMsg:
@@ -218,7 +217,7 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
return errs.ErrInternalServer.Wrap("field type is invalid")
}
}
// 返回值为true表示数据库存在该文档,false表示数据库不存在该文档
// Returns true if the document exists in the database, false if the document does not exist in the database
updateMsgModel := func(seq int64, i int) (bool, error) {
var (
res *mongo.UpdateResult
@@ -240,21 +239,21 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
}
tryUpdate := true
for i := 0; i < len(fields); i++ {
seq := firstSeq + int64(i) // 当前seq
seq := firstSeq + int64(i) // Current sequence number
if tryUpdate {
matched, err := updateMsgModel(seq, i)
if err != nil {
return err
}
if matched {
continue // 匹配到了,继续下一个(不一定修改)
continue // The current data has been updated, skip the current data
}
}
doc := unrelationtb.MsgDocModel{
DocID: db.msg.GetDocID(conversationID, seq),
Msg: make([]*unrelationtb.MsgInfoModel, num),
}
var insert int // 插入的数量
var insert int // Inserted data number
for j := i; j < len(fields); j++ {
seq = firstSeq + int64(j)
if db.msg.GetDocID(conversationID, seq) != doc.DocID {
@@ -283,14 +282,14 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
}
if err := db.msgDocDatabase.Create(ctx, &doc); err != nil {
if mongo.IsDuplicateKeyError(err) {
i-- // 存在并发,重试当前数据
tryUpdate = true // 以修改模式
i-- // already inserted
tryUpdate = true // next block use update mode
continue
}
return err
}
tryUpdate = false // 当前以插入成功,下一块优先插入模式
i += insert - 1 // 跳过已插入的数据
tryUpdate = false // The current block is inserted successfully, and the next block is inserted preferentially
i += insert - 1 // Skip the inserted data
}
return nil
}
@@ -754,7 +753,7 @@ func (db *commonMsgDatabase) UserMsgsDestruct(ctx context.Context, userID string
log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index)
}
}
// 获取报错,或者获取不到了,物理删除并且返回seq delMongoMsgsPhysical(delStruct.delDocIDList), 结束递归
// If an error is reported, or the error cannot be obtained, it is physically deleted and seq delMongoMsgsPhysical(delStruct.delDocIDList) is returned to end the recursion
break
}
index++
@@ -809,7 +808,7 @@ func (d *delMsgRecursionStruct) getSetMinSeq() int64 {
// index 0....19(del) 20...69
// seq 70
// set minSeq 21
// recursion 删除list并且返回设置的最小seq.
// recursion deletes the list and returns the set minimum seq.
func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversationID string, index int64, delStruct *delMsgRecursionStruct, remainTime int64) (int64, error) {
// find from oldest list
msgDocModel, err := db.msgDocDatabase.GetMsgDocModelByIndex(ctx, conversationID, index, 1)
@@ -821,7 +820,7 @@ func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversatio
log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index)
}
}
// 获取报错,或者获取不到了,物理删除并且返回seq delMongoMsgsPhysical(delStruct.delDocIDList), 结束递归
// If an error is reported, or the error cannot be obtained, it is physically deleted and seq delMongoMsgsPhysical(delStruct.delDocIDList) is returned to end the recursion
err = db.msgDocDatabase.DeleteDocs(ctx, delStruct.delDocIDs)
if err != nil {
return 0, err
+1 -7
View File
@@ -63,13 +63,7 @@ func NewThirdDatabase(cache cache.MsgModel, logdb relation.LogInterface) ThirdDa
return &thirdDatabase{cache: cache, logdb: logdb}
}
func (t *thirdDatabase) FcmUpdateToken(
ctx context.Context,
account string,
platformID int,
fcmToken string,
expireTime int64,
) error {
func (t *thirdDatabase) FcmUpdateToken(ctx context.Context, account string, platformID int, fcmToken string, expireTime int64) error {
return t.cache.SetFcmToken(ctx, account, platformID, fcmToken, expireTime)
}
+2 -2
View File
@@ -58,10 +58,10 @@ func NewAWS() (s3.Interface, error) {
credential := credentials.NewStaticCredentials(
conf.AccessKeyID, // accessKey
conf.AccessKeySecret, // secretKey
"") // sts的临时凭证
"") // stoken
sess, err := session.NewSession(&aws.Config{
Region: aws.String(conf.Region), // 桶所在的区域
Region: aws.String(conf.Region), // The area where the bucket is located
Credentials: credential,
})
+20 -6
View File
@@ -15,10 +15,24 @@
package cont
const (
hashPath = "openim/data/hash/"
tempPath = "openim/temp/"
DirectPath = "openim/direct"
UploadTypeMultipart = 1 // 分片上传
UploadTypePresigned = 2 // 预签名上传
partSeparator = ","
// hashPath defines the storage path for hash data within the 'openim' directory.
hashPath = "openim/data/hash/"
// tempPath specifies the directory for temporary files in the 'openim' structure.
tempPath = "openim/temp/"
// DirectPath indicates the directory for direct uploads or access within the 'openim' structure.
DirectPath = "openim/direct"
// UploadTypeMultipart represents the identifier for multipart uploads,
// allowing large files to be uploaded in chunks.
UploadTypeMultipart = 1
// UploadTypePresigned signifies the use of presigned URLs for uploads,
// facilitating secure, authorized file transfers without requiring direct access to the storage credentials.
UploadTypePresigned = 2
// partSeparator is used as a delimiter in multipart upload processes,
// separating individual file parts.
partSeparator = ","
)
+4 -4
View File
@@ -114,7 +114,7 @@ func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64
return nil, err
}
if size <= partSize {
// 预签名上传
// Pre-signed upload
key := path.Join(tempPath, c.NowPath(), fmt.Sprintf("%s_%d_%s.presigned", hash, size, c.UUID()))
rawURL, err := c.impl.PresignedPutObject(ctx, key, expire)
if err != nil {
@@ -139,7 +139,7 @@ func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64
},
}, nil
} else {
// 分片上传
// Fragment upload
upload, err := c.impl.InitiateMultipartUpload(ctx, c.HashPath(hash))
if err != nil {
return nil, err
@@ -206,7 +206,7 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa
ETag: part,
}
}
// todo: 验证大小
// todo: Validation size
result, err := c.impl.CompleteMultipartUpload(ctx, upload.ID, upload.Key, parts)
if err != nil {
return nil, err
@@ -225,7 +225,7 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa
if md5val := hex.EncodeToString(md5Sum[:]); md5val != upload.Hash {
return nil, errs.ErrArgs.Wrap(fmt.Sprintf("md5 mismatching %s != %s", md5val, upload.Hash))
}
// 防止在这个时候,并发操作,导致文件被覆盖
// Prevents concurrent operations at this time that cause files to be overwritten
copyInfo, err := c.impl.CopyObject(ctx, uploadInfo.Key, upload.Key+"."+c.UUID())
if err != nil {
return nil, err
+8 -3
View File
@@ -17,9 +17,14 @@ package cont
import "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
type InitiateUploadResult struct {
UploadID string `json:"uploadID"` // 上传ID
PartSize int64 `json:"partSize"` // 分片大小
Sign *s3.AuthSignResult `json:"sign"` // 分片信息
// UploadID uniquely identifies the upload session for tracking and management purposes.
UploadID string `json:"uploadID"`
// PartSize specifies the size of each part in a multipart upload. This is relevant for breaking down large uploads into manageable pieces.
PartSize int64 `json:"partSize"`
// Sign contains the authentication and signature information necessary for securely uploading each part. This could include signed URLs or tokens.
Sign *s3.AuthSignResult `json:"sign"`
}
type UploadResult struct {
+33 -61
View File
@@ -42,79 +42,51 @@ func ImageWidthHeight(img image.Image) (int, int) {
return bounds.X, bounds.Y
}
// resizeImage resizes an image to a specified maximum width and height, maintaining the aspect ratio.
// If both maxWidth and maxHeight are set to 0, the original image is returned.
// If both are non-zero, the image is scaled to fit within the constraints while maintaining aspect ratio.
// If only one of maxWidth or maxHeight is non-zero, the image is scaled accordingly.
func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image {
bounds := img.Bounds()
imgWidth := bounds.Max.X
imgHeight := bounds.Max.Y
imgWidth, imgHeight := bounds.Dx(), bounds.Dy()
// 计算缩放比例
scaleWidth := float64(maxWidth) / float64(imgWidth)
scaleHeight := float64(maxHeight) / float64(imgHeight)
// 如果都为0,则不缩放,返回原始图片
// Return original image if no resizing is needed.
if maxWidth == 0 && maxHeight == 0 {
return img
}
// 如果宽度和高度都大于0,则选择较小的缩放比例,以保持宽高比
var scale float64 = 1
if maxWidth > 0 && maxHeight > 0 {
scale := scaleWidth
if scaleHeight < scaleWidth {
scale = scaleHeight
}
// 计算缩略图尺寸
thumbnailWidth := int(float64(imgWidth) * scale)
thumbnailHeight := int(float64(imgHeight) * scale)
// 使用"image"库的Resample方法生成缩略图
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
for y := 0; y < thumbnailHeight; y++ {
for x := 0; x < thumbnailWidth; x++ {
srcX := int(float64(x) / scale)
srcY := int(float64(y) / scale)
thumbnail.Set(x, y, img.At(srcX, srcY))
}
}
return thumbnail
scaleWidth := float64(maxWidth) / float64(imgWidth)
scaleHeight := float64(maxHeight) / float64(imgHeight)
// Choose the smaller scale to fit both constraints.
scale = min(scaleWidth, scaleHeight)
} else if maxWidth > 0 {
scale = float64(maxWidth) / float64(imgWidth)
} else if maxHeight > 0 {
scale = float64(maxHeight) / float64(imgHeight)
}
// 如果只指定了宽度或高度,则根据最大不超过的规则生成缩略图
if maxWidth > 0 {
thumbnailWidth := maxWidth
thumbnailHeight := int(float64(imgHeight) * scaleWidth)
newWidth := int(float64(imgWidth) * scale)
newHeight := int(float64(imgHeight) * scale)
// 使用"image"库的Resample方法生成缩略图
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
for y := 0; y < thumbnailHeight; y++ {
for x := 0; x < thumbnailWidth; x++ {
srcX := int(float64(x) / scaleWidth)
srcY := int(float64(y) / scaleWidth)
thumbnail.Set(x, y, img.At(srcX, srcY))
}
// Resize the image by creating a new image and manually copying pixels.
thumbnail := image.NewRGBA(image.Rect(0, 0, newWidth, newHeight))
for y := 0; y < newHeight; y++ {
for x := 0; x < newWidth; x++ {
srcX := int(float64(x) / scale)
srcY := int(float64(y) / scale)
thumbnail.Set(x, y, img.At(srcX, srcY))
}
return thumbnail
}
if maxHeight > 0 {
thumbnailWidth := int(float64(imgWidth) * scaleHeight)
thumbnailHeight := maxHeight
// 使用"image"库的Resample方法生成缩略图
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
for y := 0; y < thumbnailHeight; y++ {
for x := 0; x < thumbnailWidth; x++ {
srcX := int(float64(x) / scaleHeight)
srcY := int(float64(y) / scaleHeight)
thumbnail.Set(x, y, img.At(srcX, srcY))
}
}
return thumbnail
}
// 默认情况下,返回原始图片
return img
return thumbnail
}
// min returns the smaller of x or y.
func min(x, y float64) float64 {
if x < y {
return x
}
return y
}
+16 -15
View File
@@ -30,6 +30,7 @@ import (
"strings"
"time"
"github.com/OpenIMSDK/tools/errs"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
@@ -60,7 +61,7 @@ const successCode = http.StatusOK
func NewOSS() (s3.Interface, error) {
conf := config.Config.Object.Oss
if conf.BucketURL == "" {
return nil, errors.New("bucket url is empty")
return nil, errs.Wrap(errors.New("bucket url is empty"))
}
client, err := oss.New(conf.Endpoint, conf.AccessKeyID, conf.AccessKeySecret)
if err != nil {
@@ -68,7 +69,7 @@ func NewOSS() (s3.Interface, error) {
}
bucket, err := client.Bucket(conf.Bucket)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "ali-oss bucket error")
}
if conf.BucketURL[len(conf.BucketURL)-1] != '/' {
conf.BucketURL += "/"
@@ -138,10 +139,10 @@ func (o *OSS) CompleteMultipartUpload(ctx context.Context, uploadID string, name
func (o *OSS) PartSize(ctx context.Context, size int64) (int64, error) {
if size <= 0 {
return 0, errors.New("size must be greater than 0")
return 0, errs.Wrap(errors.New("size must be greater than 0"))
}
if size > maxPartSize*maxNumSize {
return 0, fmt.Errorf("OSS size must be less than the maximum allowed limit")
return 0, errs.Wrap(errors.New("size must be less than the maximum allowed limit"))
}
if size <= minPartSize*maxNumSize {
return minPartSize, nil
@@ -196,25 +197,25 @@ func (o *OSS) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, erro
}
res := &s3.ObjectInfo{Key: name}
if res.ETag = strings.ToLower(strings.ReplaceAll(header.Get("ETag"), `"`, ``)); res.ETag == "" {
return nil, errors.New("StatObject etag not found")
return nil, errs.Wrap(errors.New("StatObject etag not found"))
}
if contentLengthStr := header.Get("Content-Length"); contentLengthStr == "" {
return nil, errors.New("StatObject content-length not found")
} else {
res.Size, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("StatObject content-length parse error: %w", err)
return nil, errs.Wrap(err, "StatObject content-length parse error")
}
if res.Size < 0 {
return nil, errors.New("StatObject content-length must be greater than 0")
return nil, errs.Wrap(errors.New("StatObject content-length must be greater than 0"))
}
}
if lastModified := header.Get("Last-Modified"); lastModified == "" {
return nil, errors.New("StatObject last-modified not found")
return nil, errs.Wrap(errors.New("StatObject last-modified not found"))
} else {
res.LastModified, err = time.Parse(http.TimeFormat, lastModified)
if err != nil {
return nil, fmt.Errorf("StatObject last-modified parse error: %w", err)
return nil, errs.Wrap(err, "StatObject last-modified parse error")
}
}
return res, nil
@@ -227,7 +228,7 @@ func (o *OSS) DeleteObject(ctx context.Context, name string) error {
func (o *OSS) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyObjectInfo, error) {
result, err := o.bucket.CopyObject(src, dst)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "CopyObject error")
}
return &s3.CopyObjectInfo{
Key: dst,
@@ -261,7 +262,7 @@ func (o *OSS) ListUploadedParts(ctx context.Context, uploadID string, name strin
Bucket: o.bucket.BucketName,
}, oss.MaxUploads(100), oss.MaxParts(maxParts), oss.PartNumberMarker(partNumberMarker))
if err != nil {
return nil, err
return nil, errs.Wrap(err, "ListUploadedParts error")
}
res := &s3.ListUploadedPartsResult{
Key: result.Key,
@@ -286,7 +287,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration,
var opts []oss.Option
if opt != nil {
if opt.Image != nil {
// 文档地址: https://help.aliyun.com/zh/oss/user-guide/resize-images-4?spm=a2c4g.11186623.0.0.4b3b1e4fWW6yji
// Docs Address: https://help.aliyun.com/zh/oss/user-guide/resize-images-4?spm=a2c4g.11186623.0.0.4b3b1e4fWW6yji
var format string
switch opt.Image.Format {
case
@@ -329,7 +330,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration,
}
rawParams, err := oss.GetRawParams(opts)
if err != nil {
return "", err
return "", errs.Wrap(err, "AccessURL error")
}
params := getURLParams(*o.bucket.Client.Conn, rawParams)
return getURL(o.um, o.bucket.BucketName, name, params).String(), nil
@@ -351,12 +352,12 @@ func (o *OSS) FormData(ctx context.Context, name string, size int64, contentType
}
policyJson, err := json.Marshal(policy)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "Marshal json error")
}
policyStr := base64.StdEncoding.EncodeToString(policyJson)
h := hmac.New(sha1.New, []byte(o.credentials.GetAccessKeySecret()))
if _, err := io.WriteString(h, policyStr); err != nil {
return nil, err
return nil, errs.Wrap(err, "WriteString error")
}
fd := &s3.FormData{
URL: o.bucketURL,
+2 -2
View File
@@ -46,8 +46,8 @@ type GroupModelInterface interface {
Find(ctx context.Context, groupIDs []string) (groups []*GroupModel, err error)
Take(ctx context.Context, groupID string) (group *GroupModel, err error)
Search(ctx context.Context, keyword string, pagination pagination.Pagination) (total int64, groups []*GroupModel, err error)
// 获取群总数
// Get Group total quantity
CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// 获取范围内群增量
// Get Group total quantity every day
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
}
+2 -2
View File
@@ -62,9 +62,9 @@ type UserModelInterface interface {
Exist(ctx context.Context, userID string) (exist bool, err error)
GetAllUserID(ctx context.Context, pagination pagination.Pagination) (count int64, userIDs []string, err error)
GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error)
// 获取用户总数
// Get user total quantity
CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// 获取范围内用户增量
// Get user total quantity every day
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
//CRUD user command
AddUserCommand(ctx context.Context, userID string, Type int32, UUID string, value string, ex string) error
+38 -86
View File
@@ -21,8 +21,6 @@ import (
"fmt"
"time"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/protocol/msg"
"github.com/OpenIMSDK/protocol/constant"
@@ -35,6 +33,7 @@ import (
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
table "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
)
@@ -122,13 +121,7 @@ func (m *MsgMongoDriver) UpdateMsgContent(ctx context.Context, docID string, ind
return nil
}
func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc(
ctx context.Context,
docID string,
msg *sdkws.MsgData,
seqIndex int,
status int32,
) error {
func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc(ctx context.Context, docID string, msg *sdkws.MsgData, seqIndex int, status int32) error {
msg.Status = status
bytes, err := proto.Marshal(msg)
if err != nil {
@@ -140,7 +133,7 @@ func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc(
bson.M{"$set": bson.M{fmt.Sprintf("msgs.%d.msg", seqIndex): bytes}},
)
if err != nil {
return errs.Wrap(err)
return errs.Wrap(err, fmt.Sprintf("docID is %s, seqIndex is %d", docID, seqIndex))
}
return nil
}
@@ -166,7 +159,7 @@ func (m *MsgMongoDriver) GetMsgDocModelByIndex(
findOpts,
)
if err != nil {
return nil, errs.Wrap(err)
return nil, errs.Wrap(err, fmt.Sprintf("conversationID is %s", conversationID))
}
var msgs []table.MsgDocModel
err = cursor.All(ctx, &msgs)
@@ -222,7 +215,7 @@ func (m *MsgMongoDriver) DeleteMsgsInOneDocByIndex(ctx context.Context, docID st
}
_, err := m.MsgCollection.UpdateMany(ctx, bson.M{"doc_id": docID}, updates)
if err != nil {
return errs.Wrap(err)
return errs.Wrap(err, fmt.Sprintf("docID is %s, indexes is %v", docID, indexes))
}
return nil
}
@@ -289,7 +282,7 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
defer cur.Close(ctx)
var msgDocModel []table.MsgDocModel
if err := cur.All(ctx, &msgDocModel); err != nil {
return nil, errs.Wrap(err)
return nil, errs.Wrap(err, fmt.Sprintf("docID is %s, seqs is %v", docID, seqs))
}
if len(msgDocModel) == 0 {
return nil, errs.Wrap(mongo.ErrNoDocuments)
@@ -316,14 +309,14 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
}
data, err := json.Marshal(&revokeContent)
if err != nil {
return nil, err
return nil, errs.Wrap(err, fmt.Sprintf("docID is %s, seqs is %v", docID, seqs))
}
elem := sdkws.NotificationElem{
Detail: string(data),
}
content, err := json.Marshal(&elem)
if err != nil {
return nil, err
return nil, errs.Wrap(err, fmt.Sprintf("docID is %s, seqs is %v", docID, seqs))
}
msg.Msg.ContentType = constant.MsgRevokeNotification
msg.Msg.Content = string(content)
@@ -336,17 +329,12 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
func (m *MsgMongoDriver) IsExistDocID(ctx context.Context, docID string) (bool, error) {
count, err := m.MsgCollection.CountDocuments(ctx, bson.M{"doc_id": docID})
if err != nil {
return false, errs.Wrap(err)
return false, errs.Wrap(err, fmt.Sprintf("docID is %s", docID))
}
return count > 0, nil
}
func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead(
ctx context.Context,
userID string,
docID string,
indexes []int64,
) error {
func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead(ctx context.Context, userID string, docID string, indexes []int64) error {
updates := []mongo.WriteModel{}
for _, index := range indexes {
filter := bson.M{
@@ -366,7 +354,7 @@ func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead(
updates = append(updates, updateModel)
}
_, err := m.MsgCollection.BulkWrite(ctx, updates)
return err
return errs.Wrap(err, fmt.Sprintf("docID is %s, indexes is %v", docID, indexes))
}
// RangeUserSendCount
@@ -662,7 +650,7 @@ func (m *MsgMongoDriver) RangeUserSendCount(
"$dateToString": bson.M{
"format": "%Y-%m-%d",
"date": bson.M{
"$toDate": "$$item.msg.send_time", // 毫秒时间戳
"$toDate": "$$item.msg.send_time", // Millisecond timestamp
},
},
},
@@ -911,7 +899,7 @@ func (m *MsgMongoDriver) RangeGroupSendCount(
"$dateToString": bson.M{
"format": "%Y-%m-%d",
"date": bson.M{
"$toDate": "$$item.msg.send_time", // 毫秒时间戳
"$toDate": "$$item.msg.send_time", // Millisecond timestamp
},
},
},
@@ -1076,6 +1064,7 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
var pipe mongo.Pipeline
condition := bson.A{}
if req.SendTime != "" {
// Changed to keyed fields for bson.M to avoid govet errors
condition = append(condition, bson.M{"$eq": bson.A{bson.M{"$dateToString": bson.M{"format": "%Y-%m-%d", "date": bson.M{"$toDate": "$$item.msg.send_time"}}}, req.SendTime}})
}
if req.MsgType != 0 {
@@ -1092,62 +1081,26 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
}
or := bson.A{
bson.M{
"doc_id": bson.M{
"$regex": "^si_",
"$options": "i",
},
},
bson.M{"doc_id": bson.M{"$regex": "^si_", "$options": "i"}},
bson.M{"doc_id": bson.M{"$regex": "^g_", "$options": "i"}},
bson.M{"doc_id": bson.M{"$regex": "^sg_", "$options": "i"}},
}
or = append(or,
bson.M{
"doc_id": bson.M{
"$regex": "^g_",
"$options": "i",
},
},
bson.M{
"doc_id": bson.M{
"$regex": "^sg_",
"$options": "i",
},
},
)
// Use bson.D with keyed fields to specify the order explicitly
pipe = mongo.Pipeline{
{
{"$match", bson.D{
{
"$or", or,
},
{{"$match", bson.D{{Key: "$or", Value: or}}}},
{{"$project", bson.D{
{Key: "msgs", Value: bson.D{
{Key: "$filter", Value: bson.D{
{Key: "input", Value: "$msgs"},
{Key: "as", Value: "item"},
{Key: "cond", Value: bson.D{{Key: "$and", Value: condition}}},
}},
}},
},
{
{"$project", bson.D{
{
"msgs", bson.D{
{
"$filter", bson.D{
{"input", "$msgs"},
{"as", "item"},
{
"cond", bson.D{
{"$and", condition},
},
},
},
},
},
},
{"doc_id", 1},
}},
},
{
{"$unwind", bson.M{"path": "$msgs"}},
},
{
{"$sort", bson.M{"msgs.msg.send_time": -1}},
},
{Key: "doc_id", Value: 1},
}}},
{{"$unwind", bson.M{"path": "$msgs"}}},
{{"$sort", bson.M{"msgs.msg.send_time": -1}}},
}
cursor, err := m.MsgCollection.Aggregate(ctx, pipe)
if err != nil {
@@ -1160,12 +1113,12 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
var msgsDocs []docModel
err = cursor.All(ctx, &msgsDocs)
if err != nil {
return 0, nil, err
return 0, nil, errs.Wrap(err, "cursor.All msgsDocs")
}
log.ZDebug(ctx, "query mongoDB", "result", msgsDocs)
msgs := make([]*table.MsgInfoModel, 0)
for index := range msgsDocs {
msgInfo := msgsDocs[index].Msg
for _, doc := range msgsDocs {
msgInfo := doc.Msg
if msgInfo == nil || msgInfo.Msg == nil {
continue
}
@@ -1185,14 +1138,12 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
}
data, err := json.Marshal(&revokeContent)
if err != nil {
return 0, nil, err
}
elem := sdkws.NotificationElem{
Detail: string(data),
return 0, nil, errs.Wrap(err, "json.Marshal revokeContent")
}
elem := sdkws.NotificationElem{Detail: string(data)}
content, err := json.Marshal(&elem)
if err != nil {
return 0, nil, err
return 0, nil, errs.Wrap(err, "json.Marshal elem")
}
msgInfo.Msg.ContentType = constant.MsgRevokeNotification
msgInfo.Msg.Content = string(content)
@@ -1203,7 +1154,8 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
n := int32(len(msgs))
if start >= n {
return n, []*table.MsgInfoModel{}, nil
} else if start+req.Pagination.ShowNumber < n {
}
if start+req.Pagination.ShowNumber < n {
msgs = msgs[start : start+req.Pagination.ShowNumber]
} else {
msgs = msgs[start:]
@@ -105,7 +105,7 @@ func (cd *ConnDirect) GetConns(ctx context.Context,
}
if len(connections) == 0 {
return nil, fmt.Errorf("no connections found for service: %s", serviceName)
return nil, errs.Wrap(errors.New("no connections found for service"), "serviceName", serviceName)
}
return connections, nil
}
@@ -155,10 +155,11 @@ func (cd *ConnDirect) dialService(ctx context.Context, address string, opts ...g
conn, err := grpc.DialContext(ctx, cd.resolverDirect.Scheme()+":///"+address, options...)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "address", address)
}
return conn, nil
}
func (cd *ConnDirect) dialServiceWithoutResolver(ctx context.Context, address string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
options := append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
conn, err := grpc.DialContext(ctx, address, options...)
@@ -24,6 +24,7 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister/zookeeper"
"github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/errs"
)
// NewDiscoveryRegister creates a new service discovery and registry client based on the provided environment type.
@@ -41,6 +42,6 @@ func NewDiscoveryRegister(envType string) (discoveryregistry.SvcDiscoveryRegistr
case "direct":
return direct.NewConnDirect()
default:
return nil, errors.New("envType not correct")
return nil, errs.Wrap(errors.New("envType not correct"))
}
}
@@ -39,7 +39,7 @@ func TestNewDiscoveryRegister(t *testing.T) {
expectedResult bool
}{
{"zookeeper", false, true},
{"k8s", false, true}, // 假设 k8s 配置也已正确设置
{"k8s", false, true}, // Assume that the k8s configuration is also set up correctly
{"direct", false, true},
{"invalid", true, false},
}
+8 -6
View File
@@ -66,12 +66,12 @@ func Post(ctx context.Context, url string, header map[string]string, data any, t
jsonStr, err := json.Marshal(data)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "Post: JSON marshal failed")
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonStr))
if err != nil {
return nil, err
return nil, errs.Wrap(err, "Post: NewRequestWithContext failed")
}
if operationID, _ := ctx.Value(constant.OperationID).(string); operationID != "" {
@@ -84,13 +84,13 @@ func Post(ctx context.Context, url string, header map[string]string, data any, t
resp, err := client.Do(req)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "Post: client.Do failed")
}
defer resp.Body.Close()
result, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "Post: ReadAll failed")
}
return result, nil
@@ -102,7 +102,10 @@ func PostReturn(ctx context.Context, url string, header map[string]string, input
return err
}
err = json.Unmarshal(b, output)
return err
if err != nil {
return errs.Wrap(err, "PostReturn: JSON unmarshal failed")
}
return nil
}
func callBackPostReturn(ctx context.Context, url, command string, input interface{}, output callbackstruct.CallbackResp, callbackConfig config.CallBackConfig) error {
@@ -127,7 +130,6 @@ func callBackPostReturn(ctx context.Context, url, command string, input interfac
}
if err := output.Parse(); err != nil {
log.ZWarn(ctx, "callback parse failed", err, "url", url, "input", input, "response", string(b))
return err
}
log.ZInfo(ctx, "callback success", "url", url, "input", input, "response", string(b))
return nil
+20 -16
View File
@@ -17,9 +17,8 @@ package kafka
import (
"sync"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/IBM/sarama"
"github.com/OpenIMSDK/tools/errs"
)
type Consumer struct {
@@ -30,28 +29,33 @@ type Consumer struct {
Consumer sarama.Consumer
}
func NewKafkaConsumer(addr []string, topic string) *Consumer {
p := Consumer{}
p.Topic = topic
p.addr = addr
consumerConfig := sarama.NewConfig()
if config.Config.Kafka.Username != "" && config.Config.Kafka.Password != "" {
consumerConfig.Net.SASL.Enable = true
consumerConfig.Net.SASL.User = config.Config.Kafka.Username
consumerConfig.Net.SASL.Password = config.Config.Kafka.Password
func NewKafkaConsumer(addr []string, topic string, kafkaConfig *sarama.Config) (*Consumer, error) {
p := Consumer{
Topic: topic,
addr: addr,
}
SetupTLSConfig(consumerConfig)
consumer, err := sarama.NewConsumer(p.addr, consumerConfig)
if kafkaConfig.Net.SASL.User != "" && kafkaConfig.Net.SASL.Password != "" {
kafkaConfig.Net.SASL.Enable = true
}
err := SetupTLSConfig(kafkaConfig)
if err != nil {
panic(err.Error())
return nil, err
}
consumer, err := sarama.NewConsumer(p.addr, kafkaConfig)
if err != nil {
return nil, errs.Wrap(err, "NewKafkaConsumer: creating consumer failed")
}
p.Consumer = consumer
partitionList, err := consumer.Partitions(p.Topic)
if err != nil {
panic(err.Error())
return nil, errs.Wrap(err, "NewKafkaConsumer: getting partitions failed")
}
p.PartitionList = partitionList
return &p
return &p, nil
}
+2 -3
View File
@@ -90,11 +90,10 @@ func NewKafkaProducer(addr []string, topic string) (*Producer, error) {
for i := 0; i <= maxRetry; i++ {
p.producer, err = sarama.NewSyncProducer(p.addr, p.config)
if err == nil {
return &p, nil
return &p, errs.Wrap(err)
}
time.Sleep(1 * time.Second) // Wait before retrying
}
// Panic if unable to create producer after retries
if err != nil {
return nil, errs.Wrap(errors.New("failed to create Kafka producer: " + err.Error()))
@@ -179,7 +178,7 @@ func (p *Producer) SendMessage(ctx context.Context, key string, msg proto.Messag
// Attach context metadata as headers
header, err := GetMQHeaderWithContext(ctx)
if err != nil {
return 0, 0, errs.Wrap(err)
return 0, 0, err
}
kMsg.Headers = header
+7 -2
View File
@@ -26,16 +26,21 @@ import (
)
// SetupTLSConfig set up the TLS config from config file.
func SetupTLSConfig(cfg *sarama.Config) {
func SetupTLSConfig(cfg *sarama.Config) error {
if config.Config.Kafka.TLS != nil {
cfg.Net.TLS.Enable = true
cfg.Net.TLS.Config = tls.NewTLSConfig(
tlsConfig, err := tls.NewTLSConfig(
config.Config.Kafka.TLS.ClientCrt,
config.Config.Kafka.TLS.ClientKey,
config.Config.Kafka.TLS.CACrt,
[]byte(config.Config.Kafka.TLS.ClientKeyPwd),
)
if err != nil {
return err
}
cfg.Net.TLS.Config = tlsConfig
}
return nil
}
// getEnvOrConfig returns the value of the environment variable if it exists,
-1
View File
@@ -24,7 +24,6 @@ import (
)
func NewGrpcPromObj(cusMetrics []prometheus.Collector) (*prometheus.Registry, *gp.ServerMetrics, error) {
////////////////////////////////////////////////////////
reg := prometheus.NewRegistry()
grpcMetrics := gp.NewServerMetrics()
grpcMetrics.EnableHandlingTimeHistogram()
+1 -1
View File
@@ -70,7 +70,7 @@ func Start(
defer listener.Close()
client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery)
if err != nil {
return errs.Wrap(err)
return err
}
defer client.Close()
Executable → Regular
+21 -15
View File
@@ -21,6 +21,8 @@ import (
"errors"
"os"
"github.com/OpenIMSDK/tools/errs"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
)
@@ -49,37 +51,41 @@ func readEncryptablePEMBlock(path string, pwd []byte) ([]byte, error) {
}
// NewTLSConfig setup the TLS config from general config file.
func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte) *tls.Config {
tlsConfig := tls.Config{}
func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte) (*tls.Config, error) {
var tlsConfig tls.Config
if clientCertFile != "" && clientKeyFile != "" {
certPEMBlock, err := os.ReadFile(clientCertFile)
if err != nil {
panic(err)
return nil, errs.Wrap(err, "NewTLSConfig: failed to read client cert file")
}
keyPEMBlock, err := readEncryptablePEMBlock(clientKeyFile, keyPwd)
if err != nil {
panic(err)
return nil, err
}
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
panic(err)
return nil, errs.Wrap(err, "NewTLSConfig: failed to create X509 key pair")
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
caCert, err := os.ReadFile(caCertFile)
if err != nil {
panic(err)
if caCertFile != "" {
caCert, err := os.ReadFile(caCertFile)
if err != nil {
return nil, errs.Wrap(err, "NewTLSConfig: failed to read CA cert file")
}
caCertPool := x509.NewCertPool()
if ok := caCertPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("NewTLSConfig: not a valid CA cert")
}
tlsConfig.RootCAs = caCertPool
}
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(caCert)
if !ok {
panic(errors.New("not a valid CA cert"))
}
tlsConfig.RootCAs = caCertPool
tlsConfig.InsecureSkipVerify = config.Config.Kafka.TLS.InsecureSkipVerify
return &tlsConfig
return &tlsConfig, nil
}
+2 -1
View File
@@ -20,6 +20,7 @@ import (
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"google.golang.org/protobuf/proto"
)
@@ -188,7 +189,7 @@ func (s MsgBySeq) Swap(i, j int) {
func Pb2String(pb proto.Message) (string, error) {
s, err := proto.Marshal(pb)
if err != nil {
return "", err
return "", errs.Wrap(err)
}
return string(s), nil
}
+2 -1
View File
@@ -23,12 +23,13 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
)
func NewAuth(discov discoveryregistry.SvcDiscoveryRegistry) *Auth {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImAuthName)
if err != nil {
panic(err)
util.ExitWithError(err)
}
client := auth.NewAuthClient(conn)
return &Auth{discov: discov, conn: conn, Client: client}
+4 -6
View File
@@ -24,6 +24,8 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/errs"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
)
@@ -36,7 +38,7 @@ type Conversation struct {
func NewConversation(discov discoveryregistry.SvcDiscoveryRegistry) *Conversation {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImConversationName)
if err != nil {
panic(err)
util.ExitWithError(err)
}
client := pbconversation.NewConversationClient(conn)
return &Conversation{discov: discov, conn: conn, Client: client}
@@ -114,11 +116,7 @@ func (c *ConversationRpcClient) GetConversationsByConversationID(ctx context.Con
return resp.Conversations, nil
}
func (c *ConversationRpcClient) GetConversations(
ctx context.Context,
ownerUserID string,
conversationIDs []string,
) ([]*pbconversation.Conversation, error) {
func (c *ConversationRpcClient) GetConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*pbconversation.Conversation, error) {
if len(conversationIDs) == 0 {
return nil, nil
}
+3 -2
View File
@@ -24,6 +24,7 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
)
type Friend struct {
@@ -35,7 +36,7 @@ type Friend struct {
func NewFriend(discov discoveryregistry.SvcDiscoveryRegistry) *Friend {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImFriendName)
if err != nil {
panic(err)
util.ExitWithError(err)
}
client := friend.NewFriendClient(conn)
return &Friend{discov: discov, conn: conn, Client: client}
@@ -62,7 +63,7 @@ func (f *FriendRpcClient) GetFriendsInfo(
return
}
// possibleFriendUserID是否在userID的好友中.
// possibleFriendUserID Is PossibleFriendUserId's friends.
func (f *FriendRpcClient) IsFriend(ctx context.Context, possibleFriendUserID, userID string) (bool, error) {
resp, err := f.Client.IsFriend(ctx, &friend.IsFriendReq{UserID1: userID, UserID2: possibleFriendUserID})
if err != nil {
+2 -1
View File
@@ -28,6 +28,7 @@ import (
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
)
type Group struct {
@@ -39,7 +40,7 @@ type Group struct {
func NewGroup(discov discoveryregistry.SvcDiscoveryRegistry) *Group {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImGroupName)
if err != nil {
panic(err)
util.ExitWithError(err)
}
client := group.NewGroupClient(conn)
return &Group{discov: discov, conn: conn, Client: client}
+19 -3
View File
@@ -147,14 +147,24 @@ func NewMessageRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) MessageR
return MessageRpcClient(*NewMessage(discov))
}
// SendMsg sends a message through the gRPC client and returns the response.
// It wraps any encountered error for better error handling and context understanding.
func (m *MessageRpcClient) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
resp, err := m.Client.SendMsg(ctx, req)
return resp, err
if err != nil {
return nil, err
}
return resp, nil
}
// GetMaxSeq retrieves the maximum sequence number from the gRPC client.
// Errors during the gRPC call are wrapped to provide additional context.
func (m *MessageRpcClient) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) {
resp, err := m.Client.GetMaxSeq(ctx, req)
return resp, err
if err != nil {
return nil, err
}
return resp, nil
}
func (m *MessageRpcClient) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) {
@@ -181,9 +191,15 @@ func (m *MessageRpcClient) GetMsgByConversationIDs(ctx context.Context, docIDs [
return resp.MsgDatas, err
}
// PullMessageBySeqList retrieves messages by their sequence numbers using the gRPC client.
// It directly forwards the request to the gRPC client and returns the response along with any error encountered.
func (m *MessageRpcClient) PullMessageBySeqList(ctx context.Context, req *sdkws.PullMessageBySeqsReq) (*sdkws.PullMessageBySeqsResp, error) {
resp, err := m.Client.PullMessageBySeqs(ctx, req)
return resp, err
if err != nil {
// Wrap the error to provide more context if the gRPC call fails.
return nil, err
}
return resp, nil
}
func (m *MessageRpcClient) GetConversationMaxSeq(ctx context.Context, conversationID string) (int64, error) {
+1 -3
View File
@@ -31,7 +31,7 @@ func NewConversationNotificationSender(msgRpcClient *rpcclient.MessageRpcClient)
return &ConversationNotificationSender{rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient))}
}
// SetPrivate调用.
// SetPrivate invote.
func (c *ConversationNotificationSender) ConversationSetPrivateNotification(ctx context.Context, sendID, recvID string,
isPrivateChat bool, conversationID string,
) error {
@@ -45,7 +45,6 @@ func (c *ConversationNotificationSender) ConversationSetPrivateNotification(ctx
return c.Notification(ctx, sendID, recvID, constant.ConversationPrivateChatNotification, tips)
}
// 会话改变.
func (c *ConversationNotificationSender) ConversationChangeNotification(ctx context.Context, userID string, conversationIDs []string) error {
tips := &sdkws.ConversationUpdateTips{
UserID: userID,
@@ -55,7 +54,6 @@ func (c *ConversationNotificationSender) ConversationChangeNotification(ctx cont
return c.Notification(ctx, userID, userID, constant.ConversationChangeNotification, tips)
}
// 会话未读数同步.
func (c *ConversationNotificationSender) ConversationUnreadChangeNotification(
ctx context.Context,
userID, conversationID string,
+1 -1
View File
@@ -31,7 +31,7 @@ import (
type FriendNotificationSender struct {
*rpcclient.NotificationSender
// 找不到报错
// Target not found err
getUsersInfo func(ctx context.Context, userIDs []string) ([]CommonUser, error)
// db controller
db controller.FriendDatabase
+3 -5
View File
@@ -23,6 +23,7 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
)
type Push struct {
@@ -34,7 +35,7 @@ type Push struct {
func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImPushName)
if err != nil {
panic(err)
util.ExitWithError(err)
}
return &Push{
discov: discov,
@@ -49,9 +50,6 @@ func NewPushRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) PushRpcClie
return PushRpcClient(*NewPush(discov))
}
func (p *PushRpcClient) DelUserPushToken(
ctx context.Context,
req *push.DelUserPushTokenReq,
) (*push.DelUserPushTokenResp, error) {
func (p *PushRpcClient) DelUserPushToken(ctx context.Context, req *push.DelUserPushTokenReq) (*push.DelUserPushTokenResp, error) {
return p.Client.DelUserPushToken(ctx, req)
}
+24 -15
View File
@@ -24,8 +24,10 @@ import (
"github.com/OpenIMSDK/protocol/third"
"github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/errs"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
)
type Third struct {
@@ -38,35 +40,42 @@ type Third struct {
func NewThird(discov discoveryregistry.SvcDiscoveryRegistry) *Third {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImThirdName)
if err != nil {
panic(err)
util.ExitWithError(err)
}
client := third.NewThirdClient(conn)
minioClient, err := minioInit()
if err != nil {
panic(err)
util.ExitWithError(err)
}
return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient}
}
func minioInit() (*minio.Client, error) {
minioClient := &minio.Client{}
initUrl := config.Config.Object.Minio.Endpoint
minioUrl, err := url.Parse(initUrl)
// Retrieve MinIO configuration details
endpoint := config.Config.Object.Minio.Endpoint
accessKeyID := config.Config.Object.Minio.AccessKeyID
secretAccessKey := config.Config.Object.Minio.SecretAccessKey
// Parse the MinIO URL to determine if the connection should be secure
minioURL, err := url.Parse(endpoint)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "minioInit: failed to parse MinIO endpoint URL")
}
// Determine the security of the connection based on the scheme
secure := minioURL.Scheme == "https"
// Setup MinIO client options
opts := &minio.Options{
Creds: credentials.NewStaticV4(config.Config.Object.Minio.AccessKeyID, config.Config.Object.Minio.SecretAccessKey, ""),
// Region: config.Config.Credential.Minio.Location,
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
Secure: secure,
}
if minioUrl.Scheme == "http" {
opts.Secure = false
} else if minioUrl.Scheme == "https" {
opts.Secure = true
}
minioClient, err = minio.New(minioUrl.Host, opts)
// Initialize MinIO client
minioClient, err := minio.New(minioURL.Host, opts)
if err != nil {
return nil, err
return nil, errs.Wrap(err, "minioInit: failed to create MinIO client")
}
return minioClient, nil
}
+2 -1
View File
@@ -19,6 +19,7 @@ import (
"strings"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
"google.golang.org/grpc"
@@ -42,7 +43,7 @@ type User struct {
func NewUser(discov discoveryregistry.SvcDiscoveryRegistry) *User {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImUserName)
if err != nil {
panic(err)
util.ExitWithError(err)
}
client := user.NewUserClient(conn)
return &User{Discov: discov, Client: client, conn: conn}
+5 -3
View File
@@ -18,6 +18,8 @@ import (
"fmt"
"os"
"path/filepath"
"github.com/OpenIMSDK/tools/errs"
)
// OutDir creates the absolute path name from path and checks path exists.
@@ -25,16 +27,16 @@ import (
func OutDir(path string) (string, error) {
outDir, err := filepath.Abs(path)
if err != nil {
return "", err
return "", errs.Wrap(err, "output directory %s does not exist", path)
}
stat, err := os.Stat(outDir)
if err != nil {
return "", err
return "", errs.Wrap(err, "output directory %s does not exist", outDir)
}
if !stat.IsDir() {
return "", fmt.Errorf("output directory %s is not a directory", outDir)
return "", errs.Wrap(err, "output directory %s is not a directory", outDir)
}
outDir += "/"
return outDir, nil