diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 9a3b6aa197..3620e3d22c 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -15,6 +15,17 @@ import ( "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" + dch "github.com/sipeed/picoclaw/pkg/channels/discord" + _ "github.com/sipeed/picoclaw/pkg/channels/feishu" + _ "github.com/sipeed/picoclaw/pkg/channels/line" + _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" + _ "github.com/sipeed/picoclaw/pkg/channels/onebot" + _ "github.com/sipeed/picoclaw/pkg/channels/qq" + slackch "github.com/sipeed/picoclaw/pkg/channels/slack" + tgramch "github.com/sipeed/picoclaw/pkg/channels/telegram" + _ "github.com/sipeed/picoclaw/pkg/channels/wecom" + _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/devices" @@ -128,19 +139,19 @@ func gatewayCmd() { if transcriber != nil { if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { - if tc, ok := telegramChannel.(*channels.TelegramChannel); ok { + if tc, ok := telegramChannel.(*tgramch.TelegramChannel); ok { tc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Telegram channel") } } if discordChannel, ok := channelManager.GetChannel("discord"); ok { - if dc, ok := discordChannel.(*channels.DiscordChannel); ok { + if dc, ok := discordChannel.(*dch.DiscordChannel); ok { dc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Discord channel") } } if slackChannel, ok := channelManager.GetChannel("slack"); ok { - if sc, ok := slackChannel.(*channels.SlackChannel); ok { + if sc, ok := slackChannel.(*slackch.SlackChannel); ok { sc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Slack channel") } diff --git a/pkg/channels/base.go b/pkg/channels/base.go index cd6419ebb9..5d77c6c0d9 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -3,6 +3,7 @@ package channels import ( "context" "strings" + "sync/atomic" "github.com/sipeed/picoclaw/pkg/bus" ) @@ -19,7 +20,7 @@ type Channel interface { type BaseChannel struct { config any bus *bus.MessageBus - running bool + running atomic.Bool name string allowList []string } @@ -30,7 +31,6 @@ func NewBaseChannel(name string, config any, bus *bus.MessageBus, allowList []st bus: bus, name: name, allowList: allowList, - running: false, } } @@ -39,7 +39,7 @@ func (c *BaseChannel) Name() string { } func (c *BaseChannel) IsRunning() bool { - return c.running + return c.running.Load() } func (c *BaseChannel) IsAllowed(senderID string) bool { @@ -98,6 +98,6 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st c.bus.PublishInbound(msg) } -func (c *BaseChannel) setRunning(running bool) { - c.running = running +func (c *BaseChannel) SetRunning(running bool) { + c.running.Store(running) } diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go similarity index 96% rename from pkg/channels/dingtalk.go rename to pkg/channels/dingtalk/dingtalk.go index 662fba3b72..afc0de47f8 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -1,7 +1,7 @@ // PicoClaw - Ultra-lightweight personal AI agent // DingTalk channel implementation using Stream Mode -package channels +package dingtalk import ( "context" @@ -12,6 +12,7 @@ import ( "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" @@ -20,7 +21,7 @@ import ( // DingTalkChannel implements the Channel interface for DingTalk (钉钉) // It uses WebSocket for receiving messages via stream mode and API for sending type DingTalkChannel struct { - *BaseChannel + *channels.BaseChannel config config.DingTalkConfig clientID string clientSecret string @@ -37,7 +38,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("dingtalk client_id and client_secret are required") } - base := NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom) return &DingTalkChannel{ BaseChannel: base, @@ -70,7 +71,7 @@ func (c *DingTalkChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to start stream client: %w", err) } - c.setRunning(true) + c.SetRunning(true) logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") return nil } @@ -87,7 +88,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { c.streamClient.Close() } - c.setRunning(false) + c.SetRunning(false) logger.InfoC("dingtalk", "DingTalk channel stopped") return nil } diff --git a/pkg/channels/dingtalk/init.go b/pkg/channels/dingtalk/init.go new file mode 100644 index 0000000000..5f49bce8c1 --- /dev/null +++ b/pkg/channels/dingtalk/init.go @@ -0,0 +1,13 @@ +package dingtalk + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("dingtalk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewDingTalkChannel(cfg.Channels.DingTalk, b) + }) +} diff --git a/pkg/channels/discord.go b/pkg/channels/discord/discord.go similarity index 98% rename from pkg/channels/discord.go rename to pkg/channels/discord/discord.go index 20f3b267ca..b83ac28fd8 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord/discord.go @@ -1,4 +1,4 @@ -package channels +package discord import ( "context" @@ -11,6 +11,7 @@ import ( "github.com/bwmarrin/discordgo" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" @@ -23,7 +24,7 @@ const ( ) type DiscordChannel struct { - *BaseChannel + *channels.BaseChannel session *discordgo.Session config config.DiscordConfig transcriber *voice.GroqTranscriber @@ -39,7 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC return nil, fmt.Errorf("failed to create discord session: %w", err) } - base := NewBaseChannel("discord", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom) return &DiscordChannel{ BaseChannel: base, @@ -80,7 +81,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to open discord session: %w", err) } - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, @@ -92,7 +93,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { func (c *DiscordChannel) Stop(ctx context.Context) error { logger.InfoC("discord", "Stopping Discord bot") - c.setRunning(false) + c.SetRunning(false) // Stop all typing goroutines before closing session c.typingMu.Lock() diff --git a/pkg/channels/discord/init.go b/pkg/channels/discord/init.go new file mode 100644 index 0000000000..15a5398040 --- /dev/null +++ b/pkg/channels/discord/init.go @@ -0,0 +1,13 @@ +package discord + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewDiscordChannel(cfg.Channels.Discord, b) + }) +} diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go new file mode 100644 index 0000000000..e8a0577411 --- /dev/null +++ b/pkg/channels/feishu/common.go @@ -0,0 +1,9 @@ +package feishu + +// stringValue safely dereferences a *string pointer. +func stringValue(v *string) string { + if v == nil { + return "" + } + return *v +} diff --git a/pkg/channels/feishu_32.go b/pkg/channels/feishu/feishu_32.go similarity index 79% rename from pkg/channels/feishu_32.go rename to pkg/channels/feishu/feishu_32.go index 5109b81956..14711e49ea 100644 --- a/pkg/channels/feishu_32.go +++ b/pkg/channels/feishu/feishu_32.go @@ -1,25 +1,24 @@ //go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64 -package channels +package feishu import ( "context" "errors" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" ) // FeishuChannel is a stub implementation for 32-bit architectures type FeishuChannel struct { - *BaseChannel + *channels.BaseChannel } // NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { - return nil, errors.New( - "feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config", - ) + return nil, errors.New("feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config") } // Start is a stub method to satisfy the Channel interface diff --git a/pkg/channels/feishu_64.go b/pkg/channels/feishu/feishu_64.go similarity index 96% rename from pkg/channels/feishu_64.go rename to pkg/channels/feishu/feishu_64.go index 42e74980f8..aa4e141c43 100644 --- a/pkg/channels/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -1,6 +1,6 @@ //go:build amd64 || arm64 || riscv64 || mips64 || ppc64 -package channels +package feishu import ( "context" @@ -15,13 +15,14 @@ import ( larkws "github.com/larksuite/oapi-sdk-go/v3/ws" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) type FeishuChannel struct { - *BaseChannel + *channels.BaseChannel config config.FeishuConfig client *lark.Client wsClient *larkws.Client @@ -31,7 +32,7 @@ type FeishuChannel struct { } func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { - base := NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom) return &FeishuChannel{ BaseChannel: base, @@ -60,7 +61,7 @@ func (c *FeishuChannel) Start(ctx context.Context) error { wsClient := c.wsClient c.mu.Unlock() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("feishu", "Feishu channel started (websocket mode)") go func() { @@ -83,7 +84,7 @@ func (c *FeishuChannel) Stop(ctx context.Context) error { c.wsClient = nil c.mu.Unlock() - c.setRunning(false) + c.SetRunning(false) logger.InfoC("feishu", "Feishu channel stopped") return nil } @@ -218,10 +219,3 @@ func extractFeishuMessageContent(message *larkim.EventMessage) string { return *message.Content } - -func stringValue(v *string) string { - if v == nil { - return "" - } - return *v -} diff --git a/pkg/channels/feishu/init.go b/pkg/channels/feishu/init.go new file mode 100644 index 0000000000..7e5a62daec --- /dev/null +++ b/pkg/channels/feishu/init.go @@ -0,0 +1,13 @@ +package feishu + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("feishu", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewFeishuChannel(cfg.Channels.Feishu, b) + }) +} diff --git a/pkg/channels/line/init.go b/pkg/channels/line/init.go new file mode 100644 index 0000000000..9265575ccf --- /dev/null +++ b/pkg/channels/line/init.go @@ -0,0 +1,13 @@ +package line + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("line", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewLINEChannel(cfg.Channels.LINE, b) + }) +} diff --git a/pkg/channels/line.go b/pkg/channels/line/line.go similarity index 98% rename from pkg/channels/line.go rename to pkg/channels/line/line.go index 44134996fe..4e1d0dfd34 100644 --- a/pkg/channels/line.go +++ b/pkg/channels/line/line.go @@ -1,4 +1,4 @@ -package channels +package line import ( "bytes" @@ -16,6 +16,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" @@ -41,7 +42,7 @@ type replyTokenEntry struct { // using the LINE Messaging API with HTTP webhook for receiving messages // and REST API for sending messages. type LINEChannel struct { - *BaseChannel + *channels.BaseChannel config config.LINEConfig httpServer *http.Server botUserID string // Bot's user ID @@ -59,7 +60,7 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha return nil, fmt.Errorf("line channel_secret and channel_access_token are required") } - base := NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom) return &LINEChannel{ BaseChannel: base, @@ -111,7 +112,7 @@ func (c *LINEChannel) Start(ctx context.Context) error { } }() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("line", "LINE channel started (Webhook Mode)") return nil } @@ -168,7 +169,7 @@ func (c *LINEChannel) Stop(ctx context.Context) error { } } - c.setRunning(false) + c.SetRunning(false) logger.InfoC("line", "LINE channel stopped") return nil } diff --git a/pkg/channels/maixcam/init.go b/pkg/channels/maixcam/init.go new file mode 100644 index 0000000000..5a269b22ba --- /dev/null +++ b/pkg/channels/maixcam/init.go @@ -0,0 +1,13 @@ +package maixcam + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("maixcam", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewMaixCamChannel(cfg.Channels.MaixCam, b) + }) +} diff --git a/pkg/channels/maixcam.go b/pkg/channels/maixcam/maixcam.go similarity index 96% rename from pkg/channels/maixcam.go rename to pkg/channels/maixcam/maixcam.go index 34ce62b20f..a7bff55e05 100644 --- a/pkg/channels/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -1,4 +1,4 @@ -package channels +package maixcam import ( "context" @@ -8,12 +8,13 @@ import ( "sync" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" ) type MaixCamChannel struct { - *BaseChannel + *channels.BaseChannel config config.MaixCamConfig listener net.Listener clients map[net.Conn]bool @@ -28,7 +29,7 @@ type MaixCamMessage struct { } func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) { - base := NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom) return &MaixCamChannel{ BaseChannel: base, @@ -47,7 +48,7 @@ func (c *MaixCamChannel) Start(ctx context.Context) error { } c.listener = listener - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("maixcam", "MaixCam server listening", map[string]any{ "host": c.config.Host, @@ -70,7 +71,7 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) { default: conn, err := c.listener.Accept() if err != nil { - if c.running { + if c.IsRunning() { logger.ErrorCF("maixcam", "Failed to accept connection", map[string]any{ "error": err.Error(), }) @@ -185,7 +186,7 @@ func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { func (c *MaixCamChannel) Stop(ctx context.Context) error { logger.InfoC("maixcam", "Stopping MaixCam channel") - c.setRunning(false) + c.SetRunning(false) if c.listener != nil { c.listener.Close() diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 75edaf49e3..7baef058c6 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -43,163 +43,81 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error return m, nil } +// initChannel is a helper that looks up a factory by name and creates the channel. +func (m *Manager) initChannel(name, displayName string) { + f, ok := getFactory(name) + if !ok { + logger.WarnCF("channels", "Factory not registered", map[string]any{ + "channel": displayName, + }) + return + } + logger.DebugCF("channels", "Attempting to initialize channel", map[string]any{ + "channel": displayName, + }) + ch, err := f(m.config, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize channel", map[string]any{ + "channel": displayName, + "error": err.Error(), + }) + } else { + m.channels[name] = ch + logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ + "channel": displayName, + }) + } +} + func (m *Manager) initChannels() error { logger.InfoC("channels", "Initializing channel manager") if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" { - logger.DebugC("channels", "Attempting to initialize Telegram channel") - telegram, err := NewTelegramChannel(m.config, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["telegram"] = telegram - logger.InfoC("channels", "Telegram channel enabled successfully") - } + m.initChannel("telegram", "Telegram") } if m.config.Channels.WhatsApp.Enabled && m.config.Channels.WhatsApp.BridgeURL != "" { - logger.DebugC("channels", "Attempting to initialize WhatsApp channel") - whatsapp, err := NewWhatsAppChannel(m.config.Channels.WhatsApp, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WhatsApp channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["whatsapp"] = whatsapp - logger.InfoC("channels", "WhatsApp channel enabled successfully") - } + m.initChannel("whatsapp", "WhatsApp") } if m.config.Channels.Feishu.Enabled { - logger.DebugC("channels", "Attempting to initialize Feishu channel") - feishu, err := NewFeishuChannel(m.config.Channels.Feishu, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Feishu channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["feishu"] = feishu - logger.InfoC("channels", "Feishu channel enabled successfully") - } + m.initChannel("feishu", "Feishu") } if m.config.Channels.Discord.Enabled && m.config.Channels.Discord.Token != "" { - logger.DebugC("channels", "Attempting to initialize Discord channel") - discord, err := NewDiscordChannel(m.config.Channels.Discord, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Discord channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["discord"] = discord - logger.InfoC("channels", "Discord channel enabled successfully") - } + m.initChannel("discord", "Discord") } if m.config.Channels.MaixCam.Enabled { - logger.DebugC("channels", "Attempting to initialize MaixCam channel") - maixcam, err := NewMaixCamChannel(m.config.Channels.MaixCam, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize MaixCam channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["maixcam"] = maixcam - logger.InfoC("channels", "MaixCam channel enabled successfully") - } + m.initChannel("maixcam", "MaixCam") } if m.config.Channels.QQ.Enabled { - logger.DebugC("channels", "Attempting to initialize QQ channel") - qq, err := NewQQChannel(m.config.Channels.QQ, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize QQ channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["qq"] = qq - logger.InfoC("channels", "QQ channel enabled successfully") - } + m.initChannel("qq", "QQ") } if m.config.Channels.DingTalk.Enabled && m.config.Channels.DingTalk.ClientID != "" { - logger.DebugC("channels", "Attempting to initialize DingTalk channel") - dingtalk, err := NewDingTalkChannel(m.config.Channels.DingTalk, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize DingTalk channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["dingtalk"] = dingtalk - logger.InfoC("channels", "DingTalk channel enabled successfully") - } + m.initChannel("dingtalk", "DingTalk") } if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" { - logger.DebugC("channels", "Attempting to initialize Slack channel") - slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["slack"] = slackCh - logger.InfoC("channels", "Slack channel enabled successfully") - } + m.initChannel("slack", "Slack") } if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" { - logger.DebugC("channels", "Attempting to initialize LINE channel") - line, err := NewLINEChannel(m.config.Channels.LINE, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["line"] = line - logger.InfoC("channels", "LINE channel enabled successfully") - } + m.initChannel("line", "LINE") } if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" { - logger.DebugC("channels", "Attempting to initialize OneBot channel") - onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["onebot"] = onebot - logger.InfoC("channels", "OneBot channel enabled successfully") - } + m.initChannel("onebot", "OneBot") } if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" { - logger.DebugC("channels", "Attempting to initialize WeCom channel") - wecom, err := NewWeComBotChannel(m.config.Channels.WeCom, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WeCom channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["wecom"] = wecom - logger.InfoC("channels", "WeCom channel enabled successfully") - } + m.initChannel("wecom", "WeCom") } if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" { - logger.DebugC("channels", "Attempting to initialize WeCom App channel") - wecomApp, err := NewWeComAppChannel(m.config.Channels.WeComApp, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WeCom App channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["wecom_app"] = wecomApp - logger.InfoC("channels", "WeCom App channel enabled successfully") - } + m.initChannel("wecom_app", "WeCom App") } logger.InfoCF("channels", "Channel initialization completed", map[string]any{ diff --git a/pkg/channels/onebot/init.go b/pkg/channels/onebot/init.go new file mode 100644 index 0000000000..84c06dfd67 --- /dev/null +++ b/pkg/channels/onebot/init.go @@ -0,0 +1,13 @@ +package onebot + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("onebot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewOneBotChannel(cfg.Channels.OneBot, b) + }) +} diff --git a/pkg/channels/onebot.go b/pkg/channels/onebot/onebot.go similarity index 99% rename from pkg/channels/onebot.go rename to pkg/channels/onebot/onebot.go index cee8ad9d33..3d2e64e2a3 100644 --- a/pkg/channels/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -1,4 +1,4 @@ -package channels +package onebot import ( "context" @@ -14,6 +14,7 @@ import ( "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" @@ -21,7 +22,7 @@ import ( ) type OneBotChannel struct { - *BaseChannel + *channels.BaseChannel config config.OneBotConfig conn *websocket.Conn ctx context.Context @@ -98,7 +99,7 @@ type oneBotMessageSegment struct { } func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { - base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom) const dedupSize = 1024 return &OneBotChannel{ @@ -159,7 +160,7 @@ func (c *OneBotChannel) Start(ctx context.Context) error { } } - c.setRunning(true) + c.SetRunning(true) logger.InfoC("onebot", "OneBot channel started successfully") return nil @@ -346,7 +347,7 @@ func (c *OneBotChannel) reconnectLoop() { func (c *OneBotChannel) Stop(ctx context.Context) error { logger.InfoC("onebot", "Stopping OneBot channel") - c.setRunning(false) + c.SetRunning(false) if c.cancel != nil { c.cancel() diff --git a/pkg/channels/qq/init.go b/pkg/channels/qq/init.go new file mode 100644 index 0000000000..15b9550896 --- /dev/null +++ b/pkg/channels/qq/init.go @@ -0,0 +1,13 @@ +package qq + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("qq", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewQQChannel(cfg.Channels.QQ, b) + }) +} diff --git a/pkg/channels/qq.go b/pkg/channels/qq/qq.go similarity index 96% rename from pkg/channels/qq.go rename to pkg/channels/qq/qq.go index e66cac533e..2a95bbd060 100644 --- a/pkg/channels/qq.go +++ b/pkg/channels/qq/qq.go @@ -1,4 +1,4 @@ -package channels +package qq import ( "context" @@ -14,12 +14,13 @@ import ( "golang.org/x/oauth2" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" ) type QQChannel struct { - *BaseChannel + *channels.BaseChannel config config.QQConfig api openapi.OpenAPI tokenSource oauth2.TokenSource @@ -31,7 +32,7 @@ type QQChannel struct { } func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) { - base := NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom) return &QQChannel{ BaseChannel: base, @@ -90,11 +91,11 @@ func (c *QQChannel) Start(ctx context.Context) error { logger.ErrorCF("qq", "WebSocket session error", map[string]any{ "error": err.Error(), }) - c.setRunning(false) + c.SetRunning(false) } }() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("qq", "QQ bot started successfully") return nil @@ -102,7 +103,7 @@ func (c *QQChannel) Start(ctx context.Context) error { func (c *QQChannel) Stop(ctx context.Context) error { logger.InfoC("qq", "Stopping QQ bot") - c.setRunning(false) + c.SetRunning(false) if c.cancel != nil { c.cancel() diff --git a/pkg/channels/registry.go b/pkg/channels/registry.go new file mode 100644 index 0000000000..36a05bf3eb --- /dev/null +++ b/pkg/channels/registry.go @@ -0,0 +1,32 @@ +package channels + +import ( + "sync" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// ChannelFactory is a constructor function that creates a Channel from config and message bus. +// Each channel subpackage registers one or more factories via init(). +type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error) + +var ( + factoriesMu sync.RWMutex + factories = map[string]ChannelFactory{} +) + +// RegisterFactory registers a named channel factory. Called from subpackage init() functions. +func RegisterFactory(name string, f ChannelFactory) { + factoriesMu.Lock() + defer factoriesMu.Unlock() + factories[name] = f +} + +// getFactory looks up a channel factory by name. +func getFactory(name string) (ChannelFactory, bool) { + factoriesMu.RLock() + defer factoriesMu.RUnlock() + f, ok := factories[name] + return f, ok +} diff --git a/pkg/channels/slack/init.go b/pkg/channels/slack/init.go new file mode 100644 index 0000000000..c131bb2916 --- /dev/null +++ b/pkg/channels/slack/init.go @@ -0,0 +1,13 @@ +package slack + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("slack", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewSlackChannel(cfg.Channels.Slack, b) + }) +} diff --git a/pkg/channels/slack.go b/pkg/channels/slack/slack.go similarity index 98% rename from pkg/channels/slack.go rename to pkg/channels/slack/slack.go index f7359cd6d6..cafe53103d 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack/slack.go @@ -1,4 +1,4 @@ -package channels +package slack import ( "context" @@ -13,6 +13,7 @@ import ( "github.com/slack-go/slack/socketmode" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" @@ -20,7 +21,7 @@ import ( ) type SlackChannel struct { - *BaseChannel + *channels.BaseChannel config config.SlackConfig api *slack.Client socketClient *socketmode.Client @@ -49,7 +50,7 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack socketClient := socketmode.New(api) - base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) return &SlackChannel{ BaseChannel: base, @@ -92,7 +93,7 @@ func (c *SlackChannel) Start(ctx context.Context) error { } }() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("slack", "Slack channel started (Socket Mode)") return nil } @@ -104,7 +105,7 @@ func (c *SlackChannel) Stop(ctx context.Context) error { c.cancel() } - c.setRunning(false) + c.SetRunning(false) logger.InfoC("slack", "Slack channel stopped") return nil } diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack/slack_test.go similarity index 99% rename from pkg/channels/slack_test.go rename to pkg/channels/slack/slack_test.go index 3707c2703a..30e0d2d73a 100644 --- a/pkg/channels/slack_test.go +++ b/pkg/channels/slack/slack_test.go @@ -1,4 +1,4 @@ -package channels +package slack import ( "testing" diff --git a/pkg/channels/telegram/init.go b/pkg/channels/telegram/init.go new file mode 100644 index 0000000000..ac87bb805e --- /dev/null +++ b/pkg/channels/telegram/init.go @@ -0,0 +1,13 @@ +package telegram + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewTelegramChannel(cfg, b) + }) +} diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram/telegram.go similarity index 98% rename from pkg/channels/telegram.go rename to pkg/channels/telegram/telegram.go index a0a1c8d0a8..7619440e21 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -1,4 +1,4 @@ -package channels +package telegram import ( "context" @@ -17,6 +17,7 @@ import ( tu "github.com/mymmrac/telego/telegoutil" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" @@ -24,7 +25,7 @@ import ( ) type TelegramChannel struct { - *BaseChannel + *channels.BaseChannel bot *telego.Bot commands TelegramCommander config *config.Config @@ -72,7 +73,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann return nil, fmt.Errorf("failed to create telegram bot: %w", err) } - base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) + base := channels.NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) return &TelegramChannel{ BaseChannel: base, @@ -125,7 +126,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return c.handleMessage(ctx, &message) }, th.AnyMessage()) - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("telegram", "Telegram bot connected", map[string]any{ "username": c.bot.Username(), }) @@ -142,7 +143,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error { func (c *TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") - c.setRunning(false) + c.SetRunning(false) return nil } diff --git a/pkg/channels/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go similarity index 99% rename from pkg/channels/telegram_commands.go rename to pkg/channels/telegram/telegram_commands.go index a084b641b6..f17912260e 100644 --- a/pkg/channels/telegram_commands.go +++ b/pkg/channels/telegram/telegram_commands.go @@ -1,4 +1,4 @@ -package channels +package telegram import ( "context" diff --git a/pkg/channels/wecom_app.go b/pkg/channels/wecom/app.go similarity index 96% rename from pkg/channels/wecom_app.go rename to pkg/channels/wecom/app.go index 715c487079..f3557d60f9 100644 --- a/pkg/channels/wecom_app.go +++ b/pkg/channels/wecom/app.go @@ -1,8 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom App (企业微信自建应用) channel implementation -// Supports receiving messages via webhook callback and sending messages proactively - -package channels +package wecom import ( "bytes" @@ -18,6 +14,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" @@ -29,7 +26,7 @@ const ( // WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用) type WeComAppChannel struct { - *BaseChannel + *channels.BaseChannel config config.WeComAppConfig server *http.Server accessToken string @@ -123,7 +120,7 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") } - base := NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom) return &WeComAppChannel{ BaseChannel: base, @@ -170,7 +167,7 @@ func (c *WeComAppChannel) Start(ctx context.Context) error { Handler: mux, } - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{ "address": addr, "path": webhookPath, @@ -202,7 +199,7 @@ func (c *WeComAppChannel) Stop(ctx context.Context) error { c.server.Shutdown(shutdownCtx) } - c.setRunning(false) + c.SetRunning(false) logger.InfoC("wecom_app", "WeCom App channel stopped") return nil } @@ -279,7 +276,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{ "token": c.config.Token, "msg_signature": msgSignature, @@ -298,7 +295,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons "encoding_aes_key": c.config.EncodingAESKey, "corp_id": c.config.CorpID, }) - decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) + decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) if err != nil { logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{ "error": err.Error(), @@ -357,7 +354,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { logger.WarnC("wecom_app", "Message signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return @@ -365,7 +362,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp // Decrypt message with CorpID verification // For WeCom App (自建应用), receiveid should be corp_id - decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) if err != nil { logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{ "error": err.Error(), diff --git a/pkg/channels/wecom_app_test.go b/pkg/channels/wecom/app_test.go similarity index 95% rename from pkg/channels/wecom_app_test.go rename to pkg/channels/wecom/app_test.go index abf15c52b0..5420949de7 100644 --- a/pkg/channels/wecom_app_test.go +++ b/pkg/channels/wecom/app_test.go @@ -1,7 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom App (企业微信自建应用) channel tests - -package channels +package wecom import ( "bytes" @@ -197,7 +194,7 @@ func TestWeComAppVerifySignature(t *testing.T) { msgEncrypt := "test_message" expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) - if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { t.Error("valid signature should pass verification") } }) @@ -207,7 +204,7 @@ func TestWeComAppVerifySignature(t *testing.T) { nonce := "test_nonce" msgEncrypt := "test_message" - if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { t.Error("invalid signature should fail verification") } }) @@ -221,7 +218,7 @@ func TestWeComAppVerifySignature(t *testing.T) { } chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) - if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { t.Error("empty token should skip verification and return true") } }) @@ -243,7 +240,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { plainText := "hello world" encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -268,7 +265,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { t.Fatalf("failed to encrypt test message: %v", err) } - result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -286,7 +283,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { } ch, _ := NewWeComAppChannel(cfg, msgBus) - _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid base64, got nil") } @@ -301,7 +298,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { } ch, _ := NewWeComAppChannel(cfg, msgBus) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid AES key, got nil") } @@ -319,7 +316,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { // Encrypt a very short message that results in ciphertext less than block size shortData := make([]byte, 8) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) + _, err := decryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for short ciphertext, got nil") } @@ -361,7 +358,7 @@ func TestWeComAppPKCS7Unpad(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7UnpadWeCom(tt.input) + result, err := pkcs7Unpad(tt.input) if tt.expected == nil { // This case should return an error if err == nil { @@ -852,6 +849,28 @@ func TestWeComAppMessageStructures(t *testing.T) { } }) + t.Run("WeComImageMessage structure", func(t *testing.T) { + msg := WeComImageMessage{ + ToUser: "user123", + MsgType: "image", + AgentID: 1000002, + } + msg.Image.MediaID = "media_123456" + + if msg.Image.MediaID != "media_123456" { + t.Errorf("Image.MediaID = %q, want %q", msg.Image.MediaID, "media_123456") + } + if msg.ToUser != "user123" { + t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") + } + if msg.MsgType != "image" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } + }) + t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { jsonData := `{ "errcode": 0, diff --git a/pkg/channels/wecom.go b/pkg/channels/wecom/bot.go similarity index 73% rename from pkg/channels/wecom.go rename to pkg/channels/wecom/bot.go index f8daf89de2..17ee2107ff 100644 --- a/pkg/channels/wecom.go +++ b/pkg/channels/wecom/bot.go @@ -1,28 +1,19 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom Bot (企业微信智能机器人) channel implementation -// Uses webhook callback mode for receiving messages and webhook API for sending replies - -package channels +package wecom import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" "encoding/json" "encoding/xml" "fmt" "io" "net/http" - "sort" "strings" "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" @@ -31,7 +22,7 @@ import ( // WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人) // Uses webhook callback mode - simpler than WeCom App but only supports passive replies type WeComBotChannel struct { - *BaseChannel + *channels.BaseChannel config config.WeComConfig server *http.Server ctx context.Context @@ -96,7 +87,7 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We return nil, fmt.Errorf("wecom token and webhook_url are required") } - base := NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom) return &WeComBotChannel{ BaseChannel: base, @@ -133,7 +124,7 @@ func (c *WeComBotChannel) Start(ctx context.Context) error { Handler: mux, } - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{ "address": addr, "path": webhookPath, @@ -165,7 +156,7 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error { c.server.Shutdown(shutdownCtx) } - c.setRunning(false) + c.SetRunning(false) logger.InfoC("wecom", "WeCom Bot channel stopped") return nil } @@ -219,7 +210,7 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { logger.WarnC("wecom", "Signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return @@ -228,7 +219,7 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons // Decrypt echostr // For AIBOT (智能机器人), receiveid should be empty string "" // Reference: https://developer.work.weixin.qq.com/document/path/101033 - decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") + decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") if err != nil { logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{ "error": err.Error(), @@ -281,7 +272,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { logger.WarnC("wecom", "Message signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return @@ -290,7 +281,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp // Decrypt message // For AIBOT (智能机器人), receiveid should be empty string "" // Reference: https://developer.work.weixin.qq.com/document/path/101033 - decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") if err != nil { logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{ "error": err.Error(), @@ -477,129 +468,3 @@ func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(status) } - -// WeCom common utilities for both WeCom Bot and WeCom App -// The following functions were moved from wecom_common.go - -// WeComVerifySignature verifies the message signature for WeCom -// This is a common function used by both WeCom Bot and WeCom App -func WeComVerifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { - if token == "" { - return true // Skip verification if token is not set - } - - // Sort parameters - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - - // Concatenate - str := strings.Join(params, "") - - // SHA1 hash - hash := sha1.Sum([]byte(str)) - expectedSignature := fmt.Sprintf("%x", hash) - - return expectedSignature == msgSignature -} - -// WeComDecryptMessage decrypts the encrypted message using AES -// This is a common function used by both WeCom Bot and WeCom App -// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id -func WeComDecryptMessage(encryptedMsg, encodingAESKey string) (string, error) { - return WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, "") -} - -// WeComDecryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid -// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. -func WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { - if encodingAESKey == "" { - // No encryption, return as is (base64 decode) - decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", err - } - return string(decoded), nil - } - - // Decode AES key (base64) - aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") - if err != nil { - return "", fmt.Errorf("failed to decode AES key: %w", err) - } - - // Decode encrypted message - cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", fmt.Errorf("failed to decode message: %w", err) - } - - // AES decrypt - block, err := aes.NewCipher(aesKey) - if err != nil { - return "", fmt.Errorf("failed to create cipher: %w", err) - } - - if len(cipherText) < aes.BlockSize { - return "", fmt.Errorf("ciphertext too short") - } - - // IV is the first 16 bytes of AESKey - iv := aesKey[:aes.BlockSize] - mode := cipher.NewCBCDecrypter(block, iv) - plainText := make([]byte, len(cipherText)) - mode.CryptBlocks(plainText, cipherText) - - // Remove PKCS7 padding - plainText, err = pkcs7UnpadWeCom(plainText) - if err != nil { - return "", fmt.Errorf("failed to unpad: %w", err) - } - - // Parse message structure - // Format: random(16) + msg_len(4) + msg + receiveid - if len(plainText) < 20 { - return "", fmt.Errorf("decrypted message too short") - } - - msgLen := binary.BigEndian.Uint32(plainText[16:20]) - if int(msgLen) > len(plainText)-20 { - return "", fmt.Errorf("invalid message length") - } - - msg := plainText[20 : 20+msgLen] - - // Verify receiveid if provided - if receiveid != "" && len(plainText) > 20+int(msgLen) { - actualReceiveID := string(plainText[20+msgLen:]) - if actualReceiveID != receiveid { - return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) - } - } - - return string(msg), nil -} - -// pkcs7UnpadWeCom removes PKCS7 padding with validation -// WeCom uses block size of 32 (not standard AES block size of 16) -const wecomBlockSize = 32 - -func pkcs7UnpadWeCom(data []byte) ([]byte, error) { - if len(data) == 0 { - return data, nil - } - padding := int(data[len(data)-1]) - // WeCom uses 32-byte block size for PKCS7 padding - if padding == 0 || padding > wecomBlockSize { - return nil, fmt.Errorf("invalid padding size: %d", padding) - } - if padding > len(data) { - return nil, fmt.Errorf("padding size larger than data") - } - // Verify all padding bytes - for i := 0; i < padding; i++ { - if data[len(data)-1-i] != byte(padding) { - return nil, fmt.Errorf("invalid padding byte at position %d", i) - } - } - return data[:len(data)-padding], nil -} diff --git a/pkg/channels/wecom_test.go b/pkg/channels/wecom/bot_test.go similarity index 95% rename from pkg/channels/wecom_test.go rename to pkg/channels/wecom/bot_test.go index 8afa7e8c33..328b145c2d 100644 --- a/pkg/channels/wecom_test.go +++ b/pkg/channels/wecom/bot_test.go @@ -1,7 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom Bot (企业微信智能机器人) channel tests - -package channels +package wecom import ( "bytes" @@ -177,7 +174,7 @@ func TestWeComBotVerifySignature(t *testing.T) { msgEncrypt := "test_message" expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) - if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { t.Error("valid signature should pass verification") } }) @@ -187,7 +184,7 @@ func TestWeComBotVerifySignature(t *testing.T) { nonce := "test_nonce" msgEncrypt := "test_message" - if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { t.Error("invalid signature should fail verification") } }) @@ -202,7 +199,7 @@ func TestWeComBotVerifySignature(t *testing.T) { config: cfgEmpty, } - if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { t.Error("empty token should skip verification and return true") } }) @@ -223,7 +220,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { plainText := "hello world" encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -247,7 +244,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { t.Fatalf("failed to encrypt test message: %v", err) } - result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -264,7 +261,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid base64, got nil") } @@ -278,7 +275,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid AES key, got nil") } @@ -320,20 +317,20 @@ func TestWeComBotPKCS7Unpad(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7UnpadWeCom(tt.input) + result, err := pkcs7Unpad(tt.input) if tt.expected == nil { // This case should return an error if err == nil { - t.Errorf("pkcs7UnpadWeCom() expected error for invalid padding, got result: %v", result) + t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) } return } if err != nil { - t.Errorf("pkcs7UnpadWeCom() unexpected error: %v", err) + t.Errorf("pkcs7Unpad() unexpected error: %v", err) return } if !bytes.Equal(result, tt.expected) { - t.Errorf("pkcs7UnpadWeCom() = %v, want %v", result, tt.expected) + t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) } }) } diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go new file mode 100644 index 0000000000..3c1629577a --- /dev/null +++ b/pkg/channels/wecom/common.go @@ -0,0 +1,134 @@ +package wecom + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "sort" + "strings" +) + +// blockSize is the PKCS7 block size used by WeCom (32) +const blockSize = 32 + +// verifySignature verifies the message signature for WeCom +// This is a common function used by both WeCom Bot and WeCom App +func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { + if token == "" { + return true // Skip verification if token is not set + } + + // Sort parameters + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + + // Concatenate + str := strings.Join(params, "") + + // SHA1 hash + hash := sha1.Sum([]byte(str)) + expectedSignature := fmt.Sprintf("%x", hash) + + return expectedSignature == msgSignature +} + +// decryptMessage decrypts the encrypted message using AES +// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id +func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) { + return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "") +} + +// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid +// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. +func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { + if encodingAESKey == "" { + // No encryption, return as is (base64 decode) + decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", err + } + return string(decoded), nil + } + + // Decode AES key (base64) + aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return "", fmt.Errorf("failed to decode AES key: %w", err) + } + + // Decode encrypted message + cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", fmt.Errorf("failed to decode message: %w", err) + } + + // AES decrypt + block, err := aes.NewCipher(aesKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + if len(cipherText) < aes.BlockSize { + return "", fmt.Errorf("ciphertext too short") + } + + // IV is the first 16 bytes of AESKey + iv := aesKey[:aes.BlockSize] + mode := cipher.NewCBCDecrypter(block, iv) + plainText := make([]byte, len(cipherText)) + mode.CryptBlocks(plainText, cipherText) + + // Remove PKCS7 padding + plainText, err = pkcs7Unpad(plainText) + if err != nil { + return "", fmt.Errorf("failed to unpad: %w", err) + } + + // Parse message structure + // Format: random(16) + msg_len(4) + msg + receiveid + if len(plainText) < 20 { + return "", fmt.Errorf("decrypted message too short") + } + + msgLen := binary.BigEndian.Uint32(plainText[16:20]) + if int(msgLen) > len(plainText)-20 { + return "", fmt.Errorf("invalid message length") + } + + msg := plainText[20 : 20+msgLen] + + // Verify receiveid if provided + if receiveid != "" && len(plainText) > 20+int(msgLen) { + actualReceiveID := string(plainText[20+msgLen:]) + if actualReceiveID != receiveid { + return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) + } + } + + return string(msg), nil +} + +// pkcs7Unpad removes PKCS7 padding with validation +func pkcs7Unpad(data []byte) ([]byte, error) { + if len(data) == 0 { + return data, nil + } + padding := int(data[len(data)-1]) + // WeCom uses 32-byte block size for PKCS7 padding + if padding == 0 || padding > blockSize { + return nil, fmt.Errorf("invalid padding size: %d", padding) + } + if padding > len(data) { + return nil, fmt.Errorf("padding size larger than data") + } + // Verify all padding bytes + for i := 0; i < padding; i++ { + if data[len(data)-1-i] != byte(padding) { + return nil, fmt.Errorf("invalid padding byte at position %d", i) + } + } + return data[:len(data)-padding], nil +} diff --git a/pkg/channels/wecom/init.go b/pkg/channels/wecom/init.go new file mode 100644 index 0000000000..3ef1ecdf37 --- /dev/null +++ b/pkg/channels/wecom/init.go @@ -0,0 +1,16 @@ +package wecom + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWeComBotChannel(cfg.Channels.WeCom, b) + }) + channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWeComAppChannel(cfg.Channels.WeComApp, b) + }) +} diff --git a/pkg/channels/whatsapp/init.go b/pkg/channels/whatsapp/init.go new file mode 100644 index 0000000000..d9c2669c32 --- /dev/null +++ b/pkg/channels/whatsapp/init.go @@ -0,0 +1,13 @@ +package whatsapp + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("whatsapp", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWhatsAppChannel(cfg.Channels.WhatsApp, b) + }) +} diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go similarity index 95% rename from pkg/channels/whatsapp.go rename to pkg/channels/whatsapp/whatsapp.go index 958d850bb7..7e8f13ab6d 100644 --- a/pkg/channels/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -1,4 +1,4 @@ -package channels +package whatsapp import ( "context" @@ -11,12 +11,13 @@ import ( "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/utils" ) type WhatsAppChannel struct { - *BaseChannel + *channels.BaseChannel conn *websocket.Conn config config.WhatsAppConfig url string @@ -25,7 +26,7 @@ type WhatsAppChannel struct { } func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) { - base := NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom) return &WhatsAppChannel{ BaseChannel: base, @@ -51,7 +52,7 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error { c.connected = true c.mu.Unlock() - c.setRunning(true) + c.SetRunning(true) log.Println("WhatsApp channel connected") go c.listen(ctx) @@ -73,7 +74,7 @@ func (c *WhatsAppChannel) Stop(ctx context.Context) error { } c.connected = false - c.setRunning(false) + c.SetRunning(false) return nil }