diff --git a/config/config.example.json b/config/config.example.json index d885ef94b0..e98df3dcde 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -58,6 +58,7 @@ "discord": { "enabled": false, "token": "YOUR_DISCORD_BOT_TOKEN", + "proxy": "", "allow_from": [], "group_trigger": { "mention_only": false diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index cd6a2560f8..1de910c834 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -3,12 +3,15 @@ package discord import ( "context" "fmt" + "net/http" + "net/url" "os" "strings" "sync" "time" "github.com/bwmarrin/discordgo" + "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC return nil, fmt.Errorf("failed to create discord session: %w", err) } + if err := applyDiscordProxy(session, cfg.Proxy); err != nil { + return nil, err + } base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(2000), channels.WithGroupTrigger(cfg.GroupTrigger), @@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func() func (c *DiscordChannel) downloadAttachment(url, filename string) string { return utils.DownloadFile(url, filename, utils.DownloadOptions{ LoggerPrefix: "discord", + ProxyURL: c.config.Proxy, }) } +func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error { + var proxyFunc func(*http.Request) (*url.URL, error) + if proxyAddr != "" { + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err) + } + proxyFunc = http.ProxyURL(proxyURL) + } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" { + proxyFunc = http.ProxyFromEnvironment + } + + if proxyFunc == nil { + return nil + } + + transport := &http.Transport{Proxy: proxyFunc} + session.Client = &http.Client{ + Timeout: sendTimeout, + Transport: transport, + } + + if session.Dialer != nil { + dialerCopy := *session.Dialer + dialerCopy.Proxy = proxyFunc + session.Dialer = &dialerCopy + } else { + session.Dialer = &websocket.Dialer{Proxy: proxyFunc} + } + + return nil +} + // stripBotMention removes the bot mention from the message content. // Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). func (c *DiscordChannel) stripBotMention(text string) string { diff --git a/pkg/channels/discord/discord_test.go b/pkg/channels/discord/discord_test.go new file mode 100644 index 0000000000..0cd5328f40 --- /dev/null +++ b/pkg/channels/discord/discord_test.go @@ -0,0 +1,91 @@ +package discord + +import ( + "net/http" + "net/url" + "testing" + + "github.com/bwmarrin/discordgo" +) + +func TestApplyDiscordProxy_CustomProxy(t *testing.T) { + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil { + t.Fatalf("applyDiscordProxy() error: %v", err) + } + + req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + + restProxy := session.Client.Transport.(*http.Transport).Proxy + restProxyURL, err := restProxy(req) + if err != nil { + t.Fatalf("rest proxy func error: %v", err) + } + if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want { + t.Fatalf("REST proxy = %q, want %q", got, want) + } + + wsProxyURL, err := session.Dialer.Proxy(req) + if err != nil { + t.Fatalf("ws proxy func error: %v", err) + } + if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want { + t.Fatalf("WS proxy = %q, want %q", got, want) + } +} + +func TestApplyDiscordProxy_FromEnvironment(t *testing.T) { + t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888") + t.Setenv("http_proxy", "http://127.0.0.1:8888") + t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888") + t.Setenv("https_proxy", "http://127.0.0.1:8888") + t.Setenv("ALL_PROXY", "") + t.Setenv("all_proxy", "") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, ""); err != nil { + t.Fatalf("applyDiscordProxy() error: %v", err) + } + + req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + + gotURL, err := session.Dialer.Proxy(req) + if err != nil { + t.Fatalf("ws proxy func error: %v", err) + } + + wantURL, err := url.Parse("http://127.0.0.1:8888") + if err != nil { + t.Fatalf("url.Parse() error: %v", err) + } + if gotURL.String() != wantURL.String() { + t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String()) + } +} + +func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) { + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, "://bad-proxy"); err == nil { + t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 9f4769de4a..a0027bdddd 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -258,6 +258,7 @@ type FeishuConfig struct { type DiscordConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` diff --git a/pkg/utils/media.go b/pkg/utils/media.go index a34889fb81..3e1c5d88e1 100644 --- a/pkg/utils/media.go +++ b/pkg/utils/media.go @@ -3,6 +3,7 @@ package utils import ( "io" "net/http" + "net/url" "os" "path/filepath" "strings" @@ -52,11 +53,12 @@ type DownloadOptions struct { Timeout time.Duration ExtraHeaders map[string]string LoggerPrefix string + ProxyURL string } // DownloadFile downloads a file from URL to a local temp directory. // Returns the local file path or empty string on error. -func DownloadFile(url, filename string, opts DownloadOptions) string { +func DownloadFile(urlStr, filename string, opts DownloadOptions) string { // Set defaults if opts.Timeout == 0 { opts.Timeout = 60 * time.Second @@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName) // Create HTTP request - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest("GET", urlStr, nil) if err != nil { logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{ "error": err.Error(), @@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { } client := &http.Client{Timeout: opts.Timeout} + if opts.ProxyURL != "" { + proxyURL, parseErr := url.Parse(opts.ProxyURL) + if parseErr != nil { + logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{ + "error": parseErr.Error(), + "proxy": opts.ProxyURL, + }) + return "" + } + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + } resp, err := client.Do(req) if err != nil { logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{ "error": err.Error(), - "url": url, + "url": urlStr, }) return "" } @@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { if resp.StatusCode != http.StatusOK { logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{ "status": resp.StatusCode, - "url": url, + "url": urlStr, }) return "" }