Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmd/picoclaw/cmd_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ func gatewayCmd() {

fmt.Println("\nShutting down...")
cancel()
if cp, ok := provider.(providers.SessionProvider); ok {
cp.Close()
}
healthServer.Stop(context.Background())
deviceService.Stop()
heartbeatService.Stop()
Expand Down
71 changes: 54 additions & 17 deletions pkg/providers/github_copilot_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,72 +4,109 @@ import (
"context"
"encoding/json"
"fmt"
"sync"

copilot "github.com/github/copilot-sdk/go"
)

type GitHubCopilotProvider struct {
uri string
connectMode string // `stdio` or `grpc``
connectMode string // "stdio" or "grpc"

client *copilot.Client
session *copilot.Session

mu sync.Mutex
}

func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) {
var session *copilot.Session
if connectMode == "" {
connectMode = "grpc"
}
switch connectMode {

switch connectMode {
case "stdio":
// todo
// TODO:
return nil, fmt.Errorf("stdio mode not implemented")
case "grpc":
client := copilot.NewClient(&copilot.ClientOptions{
CLIUrl: uri,
})
if err := client.Start(context.Background()); err != nil {
return nil, fmt.Errorf(
"Can't connect to Github Copilot, https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details",
"can't connect to Github Copilot: %w; `https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server` for details",
err,
)
}
defer client.Stop()
session, _ = client.CreateSession(context.Background(), &copilot.SessionConfig{

session, err := client.CreateSession(context.Background(), &copilot.SessionConfig{
Model: model,
Hooks: &copilot.SessionHooks{},
})
if err != nil {

client.Stop()
return nil, fmt.Errorf("create session failed: %w", err)
}

return &GitHubCopilotProvider{
uri: uri,
connectMode: connectMode,
client: client,
session: session,
}, nil
default:
return nil, fmt.Errorf("unknown connect mode: %s", connectMode)
}
}

return &GitHubCopilotProvider{
uri: uri,
connectMode: connectMode,
session: session,
}, nil
func (p *GitHubCopilotProvider) Close() {
p.mu.Lock()
defer p.mu.Unlock()
Comment on lines +64 to +65
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it maybe be better if you lock only if the client is not nil?
Just to be sure to locking only if there's a real reason.

if p.client != nil {
p.client.Stop()
p.client = nil
p.session = nil
}
}

// Chat sends a chat request to GitHub Copilot
func (p *GitHubCopilotProvider) Chat(
ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any,
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
type tempMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
out := make([]tempMessage, 0, len(messages))

for _, msg := range messages {
out = append(out, tempMessage{
Role: msg.Role,
Content: msg.Content,
})
}

fullcontent, _ := json.Marshal(out)
fullcontent, err := json.Marshal(out)
if err != nil {
return nil, fmt.Errorf("marshal messages: %w", err)
}
p.mu.Lock()
defer p.mu.Unlock()

content, _ := p.session.Send(ctx, copilot.MessageOptions{
resp, err := p.session.SendAndWait(ctx, copilot.MessageOptions{
Prompt: string(fullcontent),
})
if err != nil {
return nil, err
}

var content string
if resp != nil && resp.Data.Content != nil {
content = *resp.Data.Content
}

return &LLMResponse{
FinishReason: "stop",
Expand Down
5 changes: 5 additions & 0 deletions pkg/providers/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ type LLMProvider interface {
GetDefaultModel() string
}

type SessionProvider interface {
LLMProvider
Close()
}

// FailoverReason classifies why an LLM request failed for fallback decisions.
type FailoverReason string

Expand Down
20 changes: 10 additions & 10 deletions pkg/tools/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ import (
type mockRegistryTool struct {
name string
desc string
params map[string]interface{}
params map[string]any
result *ToolResult
}

func (m *mockRegistryTool) Name() string { return m.name }
func (m *mockRegistryTool) Description() string { return m.desc }
func (m *mockRegistryTool) Parameters() map[string]interface{} { return m.params }
func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]interface{}) *ToolResult {
func (m *mockRegistryTool) Name() string { return m.name }
func (m *mockRegistryTool) Description() string { return m.desc }
func (m *mockRegistryTool) Parameters() map[string]any { return m.params }
func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolResult {
return m.result
}

Expand Down Expand Up @@ -51,7 +51,7 @@ func newMockTool(name, desc string) *mockRegistryTool {
return &mockRegistryTool{
name: name,
desc: desc,
params: map[string]interface{}{"type": "object"},
params: map[string]any{"type": "object"},
result: SilentResult("ok"),
}
}
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestToolRegistry_Execute_Success(t *testing.T) {
r.Register(&mockRegistryTool{
name: "greet",
desc: "says hello",
params: map[string]interface{}{},
params: map[string]any{},
result: SilentResult("hello"),
})

Expand Down Expand Up @@ -203,7 +203,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
if defs[0]["type"] != "function" {
t.Errorf("expected type 'function', got %v", defs[0]["type"])
}
fn, ok := defs[0]["function"].(map[string]interface{})
fn, ok := defs[0]["function"].(map[string]any)
if !ok {
t.Fatal("expected 'function' key to be a map")
}
Expand All @@ -217,7 +217,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {

func TestToolRegistry_ToProviderDefs(t *testing.T) {
r := NewToolRegistry()
params := map[string]interface{}{"type": "object", "properties": map[string]interface{}{}}
params := map[string]any{"type": "object", "properties": map[string]any{}}
r.Register(&mockRegistryTool{
name: "beta",
desc: "tool B",
Expand Down Expand Up @@ -310,7 +310,7 @@ func TestToolToSchema(t *testing.T) {
if schema["type"] != "function" {
t.Errorf("expected type 'function', got %v", schema["type"])
}
fn, ok := schema["function"].(map[string]interface{})
fn, ok := schema["function"].(map[string]any)
if !ok {
t.Fatal("expected 'function' to be a map")
}
Expand Down