Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions pkg/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/url"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
)
Expand Down Expand Up @@ -92,10 +93,13 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
server.Shutdown(ctx)
}()

fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL)

if err := openBrowser(authURL); err != nil {
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
}

fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code")
fmt.Println("Waiting for authentication in browser...")

select {
Expand All @@ -114,6 +118,57 @@ type callbackResult struct {
err error
}

type deviceCodeResponse struct {
DeviceAuthID string
UserCode string
Interval int
}

func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) {
var raw struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
Interval json.RawMessage `json:"interval"`
}

if err := json.Unmarshal(body, &raw); err != nil {
return deviceCodeResponse{}, err
}

interval, err := parseFlexibleInt(raw.Interval)
if err != nil {
return deviceCodeResponse{}, err
}

return deviceCodeResponse{
DeviceAuthID: raw.DeviceAuthID,
UserCode: raw.UserCode,
Interval: interval,
}, nil
}

func parseFlexibleInt(raw json.RawMessage) (int, error) {
if len(raw) == 0 || string(raw) == "null" {
return 0, nil
}

var interval int
if err := json.Unmarshal(raw, &interval); err == nil {
return interval, nil
}

var intervalStr string
if err := json.Unmarshal(raw, &intervalStr); err == nil {
intervalStr = strings.TrimSpace(intervalStr)
if intervalStr == "" {
return 0, nil
}
return strconv.Atoi(intervalStr)
}

return 0, fmt.Errorf("invalid integer value: %s", string(raw))
}

func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"client_id": cfg.ClientID,
Expand All @@ -134,12 +189,8 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}

var deviceResp struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
Interval int `json:"interval"`
}
if err := json.Unmarshal(body, &deviceResp); err != nil {
deviceResp, err := parseDeviceCodeResponse(body)
if err != nil {
return nil, fmt.Errorf("parsing device code response: %w", err)
}

Expand Down
40 changes: 40 additions & 0 deletions pkg/auth/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,43 @@ func TestOpenAIOAuthConfig(t *testing.T) {
t.Errorf("Port = %d, want 1455", cfg.Port)
}
}

func TestParseDeviceCodeResponseIntervalAsNumber(t *testing.T) {
body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":5}`)

resp, err := parseDeviceCodeResponse(body)
if err != nil {
t.Fatalf("parseDeviceCodeResponse() error: %v", err)
}

if resp.DeviceAuthID != "abc" {
t.Errorf("DeviceAuthID = %q, want %q", resp.DeviceAuthID, "abc")
}
if resp.UserCode != "DEF-1234" {
t.Errorf("UserCode = %q, want %q", resp.UserCode, "DEF-1234")
}
if resp.Interval != 5 {
t.Errorf("Interval = %d, want %d", resp.Interval, 5)
}
}

func TestParseDeviceCodeResponseIntervalAsString(t *testing.T) {
body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"5"}`)

resp, err := parseDeviceCodeResponse(body)
if err != nil {
t.Fatalf("parseDeviceCodeResponse() error: %v", err)
}

if resp.Interval != 5 {
t.Errorf("Interval = %d, want %d", resp.Interval, 5)
}
}

func TestParseDeviceCodeResponseInvalidInterval(t *testing.T) {
body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"abc"}`)

if _, err := parseDeviceCodeResponse(body); err == nil {
t.Fatal("expected error for invalid interval")
}
}