Skip to content
This repository was archived by the owner on Sep 18, 2025. It is now read-only.

Commit 005b8ac

Browse files
committed
initial working agent
1 parent e7258e3 commit 005b8ac

File tree

6 files changed

+201
-22
lines changed

6 files changed

+201
-22
lines changed

cmd/root.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,6 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
109109
}
110110
}
111111

112-
// Execute adds all child commands to the root command and sets flags appropriately.
113-
// This is called by main.main(). It only needs to happen once to the rootCmd.
114112
func Execute() {
115113
err := rootCmd.Execute()
116114
if err != nil {
@@ -131,13 +129,14 @@ func loadConfig() {
131129

132130
// LLM
133131
viper.SetDefault("models.big", string(models.DefaultBigModel))
134-
viper.SetDefault("models.little", string(models.DefaultLittleModel))
132+
viper.SetDefault("models.small", string(models.DefaultLittleModel))
135133
viper.SetDefault("providers.openai.key", os.Getenv("OPENAI_API_KEY"))
136134
viper.SetDefault("providers.anthropic.key", os.Getenv("ANTHROPIC_API_KEY"))
135+
viper.SetDefault("providers.groq.key", os.Getenv("GROQ_API_KEY"))
137136
viper.SetDefault("providers.common.max_tokens", 4000)
138137

139138
viper.SetDefault("agents.default", "coder")
140-
//
139+
141140
viper.ReadInConfig()
142141

143142
workdir, err := os.Getwd()

internal/llm/agent/title.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
6+
"github.com/cloudwego/eino/schema"
7+
"github.com/kujtimiihoxha/termai/internal/llm/models"
8+
"github.com/spf13/viper"
9+
)
10+
11+
func GenerateTitle(ctx context.Context, content string) (string, error) {
12+
model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small")))
13+
if err != nil {
14+
return "", err
15+
}
16+
out, err := model.Generate(
17+
ctx,
18+
[]*schema.Message{
19+
schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with
20+
- ensure it is not more than 80 characters long
21+
- the title should be a summary of the user's message
22+
- do not use quotes or colons
23+
- the entire text you return will be used as the title`),
24+
schema.UserMessage(content),
25+
},
26+
)
27+
if err != nil {
28+
return "", err
29+
}
30+
return out.Content, nil
31+
}

internal/llm/llm.go

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/cloudwego/eino/schema"
1212
"github.com/google/uuid"
1313
"github.com/kujtimiihoxha/termai/internal/llm/agent"
14+
"github.com/kujtimiihoxha/termai/internal/llm/models"
1415
"github.com/kujtimiihoxha/termai/internal/logging"
1516
"github.com/kujtimiihoxha/termai/internal/message"
1617
"github.com/kujtimiihoxha/termai/internal/pubsub"
@@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
8889
}
8990

9091
log.Printf("Request: %s", content)
91-
agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
92+
currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
9293
if err != nil {
9394
s.Publish(AgentErrorEvent, AgentEvent{
9495
ID: id,
@@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
110111
for _, m := range history {
111112
messages = append(messages, &m.MessageData)
112113
}
114+
113115
builder := callbacks.NewHandlerBuilder()
114116
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
115117
i, ok := input.(*eModel.CallbackInput)
@@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
140142
return ctx
141143
})
142144

143-
out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
145+
out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
144146
if err != nil {
145147
s.Publish(AgentErrorEvent, AgentEvent{
146148
ID: id,
@@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
153155
return
154156
}
155157
usage := out.ResponseMeta.Usage
158+
s.messages.Create(sessionID, *out)
156159
if usage != nil {
157160
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
158161
session, err := s.sessions.Get(sessionID)
@@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
170173
session.PromptTokens += int64(usage.PromptTokens)
171174
session.CompletionTokens += int64(usage.CompletionTokens)
172175
// TODO: calculate cost
176+
model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
177+
session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
178+
float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
179+
var newTitle string
180+
if len(history) == 1 {
181+
// first message generate the title
182+
newTitle, err = agent.GenerateTitle(s.ctx, content)
183+
if err != nil {
184+
s.Publish(AgentErrorEvent, AgentEvent{
185+
ID: id,
186+
Type: AgentMessageTypeError,
187+
AgentID: RootAgent,
188+
MessageID: "",
189+
SessionID: sessionID,
190+
Content: err.Error(),
191+
})
192+
return
193+
}
194+
}
195+
if newTitle != "" {
196+
session.Title = newTitle
197+
}
198+
173199
_, err = s.sessions.Save(session)
174200
if err != nil {
175201
s.Publish(AgentErrorEvent, AgentEvent{
@@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
183209
return
184210
}
185211
}
186-
s.messages.Create(sessionID, *out)
187212
}
188213

189214
func (s *service) SendRequest(sessionID string, content string) {

internal/llm/models/models.go

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package models
33
import (
44
"context"
55
"errors"
6+
"log"
67

78
"github.com/cloudwego/eino-ext/components/model/claude"
89
"github.com/cloudwego/eino-ext/components/model/openai"
@@ -16,10 +17,12 @@ type (
1617
)
1718

1819
type Model struct {
19-
ID ModelID `json:"id"`
20-
Name string `json:"name"`
21-
Provider ModelProvider `json:"provider"`
22-
APIModel string `json:"api_model"` // Actual value used when calling the API
20+
ID ModelID `json:"id"`
21+
Name string `json:"name"`
22+
Provider ModelProvider `json:"provider"`
23+
APIModel string `json:"api_model"`
24+
CostPer1MIn float64 `json:"cost_per_1m_in"`
25+
CostPer1MOut float64 `json:"cost_per_1m_out"`
2326
}
2427

2528
const (
@@ -52,6 +55,9 @@ const (
5255
// Meta
5356
Llama3 ModelID = "llama-3"
5457
Llama270B ModelID = "llama-2-70b"
58+
// GROQ
59+
GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
60+
GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b"
5561
)
5662

5763
const (
@@ -61,6 +67,7 @@ const (
6167
ProviderXAI ModelProvider = "xai"
6268
ProviderDeepSeek ModelProvider = "deepseek"
6369
ProviderMeta ModelProvider = "meta"
70+
ProviderGroq ModelProvider = "groq"
6471
)
6572

6673
var SupportedModels = map[ModelID]Model{
@@ -72,10 +79,12 @@ var SupportedModels = map[ModelID]Model{
7279
APIModel: "gpt-4o",
7380
},
7481
GPT4oMini: {
75-
ID: GPT4oMini,
76-
Name: "GPT-4o Mini",
77-
Provider: ProviderOpenAI,
78-
APIModel: "gpt-4o-mini",
82+
ID: GPT4oMini,
83+
Name: "GPT-4o Mini",
84+
Provider: ProviderOpenAI,
85+
APIModel: "gpt-4o-mini",
86+
CostPer1MIn: 0.150,
87+
CostPer1MOut: 0.600,
7988
},
8089
GPT45: {
8190
ID: GPT45,
@@ -172,10 +181,25 @@ var SupportedModels = map[ModelID]Model{
172181
Provider: ProviderMeta,
173182
APIModel: "llama-2-70b",
174183
},
184+
185+
// GROQ
186+
GroqLlama3SpecDec: {
187+
ID: GroqLlama3SpecDec,
188+
Name: "GROQ LLaMA 3 SpecDec",
189+
Provider: ProviderGroq,
190+
APIModel: "llama-3.3-70b-specdec",
191+
},
192+
GroqQwen32BCoder: {
193+
ID: GroqQwen32BCoder,
194+
Name: "GROQ Qwen 2.5 Coder 32B",
195+
Provider: ProviderGroq,
196+
APIModel: "qwen-2.5-coder-32b",
197+
},
175198
}
176199

177200
func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
178201
provider := SupportedModels[model].Provider
202+
log.Printf("Provider: %s", provider)
179203
maxTokens := viper.GetInt("providers.common.max_tokens")
180204
switch provider {
181205
case ProviderOpenAI:
@@ -191,6 +215,14 @@ func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
191215
MaxTokens: maxTokens,
192216
})
193217

218+
case ProviderGroq:
219+
return openai.NewChatModel(ctx, &openai.ChatModelConfig{
220+
BaseURL: "https://api.groq.com/openai/v1",
221+
APIKey: viper.GetString("providers.groq.key"),
222+
Model: string(SupportedModels[model].APIModel),
223+
MaxTokens: &maxTokens,
224+
})
225+
194226
}
195227
return nil, errors.New("unsupported provider")
196228
}

internal/tui/components/repl/messages.go

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,33 @@
11
package repl
22

33
import (
4+
"github.com/charmbracelet/bubbles/key"
5+
"github.com/charmbracelet/bubbles/viewport"
46
tea "github.com/charmbracelet/bubbletea"
57
"github.com/charmbracelet/lipgloss"
68
"github.com/kujtimiihoxha/termai/internal/app"
79
"github.com/kujtimiihoxha/termai/internal/message"
810
"github.com/kujtimiihoxha/termai/internal/pubsub"
911
"github.com/kujtimiihoxha/termai/internal/session"
12+
"github.com/kujtimiihoxha/termai/internal/tui/layout"
1013
)
1114

15+
type MessagesCmp interface {
16+
tea.Model
17+
layout.Focusable
18+
layout.Bordered
19+
layout.Sizeable
20+
layout.Bindings
21+
}
22+
1223
type messagesCmp struct {
1324
app *app.App
1425
messages []message.Message
1526
session session.Session
16-
}
17-
18-
func (m *messagesCmp) Init() tea.Cmd {
19-
return nil
27+
viewport viewport.Model
28+
width int
29+
height int
30+
focused bool
2031
}
2132

2233
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -25,6 +36,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
2536
if msg.Type == pubsub.CreatedEvent {
2637
m.messages = append(m.messages, msg.Payload)
2738
}
39+
case pubsub.Event[session.Session]:
40+
if msg.Type == pubsub.UpdatedEvent {
41+
if m.session.ID == msg.Payload.ID {
42+
m.session = msg.Payload
43+
}
44+
}
2845
case SelectedSessionMsg:
2946
m.session, _ = m.app.Sessions.Get(msg.SessionID)
3047
m.messages, _ = m.app.Messages.List(m.session.ID)
@@ -40,7 +57,55 @@ func (i *messagesCmp) View() string {
4057
return lipgloss.JoinVertical(lipgloss.Top, stringMessages...)
4158
}
4259

43-
func NewMessagesCmp(app *app.App) tea.Model {
60+
// BindingKeys implements MessagesCmp.
61+
func (m *messagesCmp) BindingKeys() []key.Binding {
62+
return []key.Binding{}
63+
}
64+
65+
// Blur implements MessagesCmp.
66+
func (m *messagesCmp) Blur() tea.Cmd {
67+
m.focused = false
68+
return nil
69+
}
70+
71+
// BorderText implements MessagesCmp.
72+
func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
73+
title := m.session.Title
74+
if len(title) > 20 {
75+
title = title[:20] + "..."
76+
}
77+
return map[layout.BorderPosition]string{
78+
layout.TopLeftBorder: title,
79+
}
80+
}
81+
82+
// Focus implements MessagesCmp.
83+
func (m *messagesCmp) Focus() tea.Cmd {
84+
m.focused = true
85+
return nil
86+
}
87+
88+
// GetSize implements MessagesCmp.
89+
func (m *messagesCmp) GetSize() (int, int) {
90+
return m.width, m.height
91+
}
92+
93+
// IsFocused implements MessagesCmp.
94+
func (m *messagesCmp) IsFocused() bool {
95+
return m.focused
96+
}
97+
98+
// SetSize implements MessagesCmp.
99+
func (m *messagesCmp) SetSize(width int, height int) {
100+
m.width = width
101+
m.height = height
102+
}
103+
104+
func (m *messagesCmp) Init() tea.Cmd {
105+
return nil
106+
}
107+
108+
func NewMessagesCmp(app *app.App) MessagesCmp {
44109
return &messagesCmp{
45110
app: app,
46111
messages: []message.Message{},

0 commit comments

Comments
 (0)