From b5aa3b5ef41a41521b49bc3686ad3fc50794745a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 18:02:12 +0100 Subject: [PATCH 01/23] Add bridgeadapter RemoteMessage/RemoteEdit and refactor Introduce a bridgeadapter package that provides generic RemoteMessage, RemoteEdit, NewMessageID, and user-login ID helpers to centralize duplicated remote-event and identifier logic. Replace per-bridge RemoteMessage/RemoteEdit implementations with the shared types (and aliases) and update call sites to use exported struct fields and a LogKey for consistent logging. Refactor identifier creation into MakeUserLoginID/NextUserLoginID. Changes touch bridge code (codex, opencode), connector, and new pkg/bridgeadapter files to remove duplication and standardize message/edit construction and IDs. --- bridges/codex/client.go | 22 ++-- bridges/codex/identifiers.go | 25 +--- bridges/codex/portal_send.go | 11 +- bridges/codex/remote_events.go | 140 +------------------- bridges/codex/stream_transport.go | 11 +- bridges/opencode/host.go | 22 ++-- bridges/opencode/identifiers.go | 26 +--- bridges/opencode/portal_send.go | 11 +- bridges/opencode/remote_events.go | 137 +------------------ bridges/opencode/stream_canonical.go | 11 +- pkg/bridgeadapter/identifier_helpers.go | 28 ++++ pkg/bridgeadapter/remote_events.go | 147 +++++++++++++++++++++ pkg/connector/portal_send.go | 28 ++-- pkg/connector/remote_events.go | 166 ++---------------------- pkg/connector/response_finalization.go | 26 ++-- pkg/connector/stream_transport.go | 14 +- 16 files changed, 293 insertions(+), 532 deletions(-) create mode 100644 pkg/bridgeadapter/remote_events.go diff --git a/bridges/codex/client.go b/bridges/codex/client.go index ec4d9d79..9a632ad1 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -2010,11 +2010,12 @@ func (cc *CodexClient) sendFinalAssistantTurn(ctx context.Context, portal *bridg sender := cc.senderForPortal() cc.UserLogin.QueueRemoteEvent(&CodexRemoteEdit{ - portal: portal.PortalKey, - sender: sender, - targetMessage: state.networkMessageID, - timestamp: time.Now(), - preBuilt: &bridgev2.ConvertedEdit{ + Portal: portal.PortalKey, + Sender: sender, + TargetMessage: state.networkMessageID, + Timestamp: time.Now(), + LogKey: "codex_edit_target", + PreBuilt: &bridgev2.ConvertedEdit{ ModifiedParts: []*bridgev2.ConvertedEditPart{{ Type: event.EventMessage, Content: &event.MessageEventContent{ @@ -2059,11 +2060,12 @@ func (cc *CodexClient) sendContinuationMessage(ctx context.Context, portal *brid } sender := cc.senderForPortal() cc.UserLogin.QueueRemoteEvent(&CodexRemoteMessage{ - portal: portal.PortalKey, - id: newMessageID(), - sender: sender, - timestamp: time.Now(), - preBuilt: &bridgev2.ConvertedMessage{ + Portal: portal.PortalKey, + ID: newMessageID(), + Sender: sender, + Timestamp: time.Now(), + LogKey: "codex_msg_id", + PreBuilt: &bridgev2.ConvertedMessage{ Parts: []*bridgev2.ConvertedMessagePart{{ ID: networkid.PartID("0"), Type: event.EventMessage, diff --git a/bridges/codex/identifiers.go b/bridges/codex/identifiers.go index b2594b98..df22ef07 100644 --- a/bridges/codex/identifiers.go +++ b/bridges/codex/identifiers.go @@ -1,39 +1,22 @@ package codex import ( - "fmt" - "net/url" "strings" "github.com/rs/xid" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" + + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) func makeCodexUserLoginID(mxid id.UserID, ordinal int) networkid.UserLoginID { - escaped := url.PathEscape(string(mxid)) - base := networkid.UserLoginID(fmt.Sprintf("codex:%s", escaped)) - if ordinal <= 1 { - return base - } - return networkid.UserLoginID(fmt.Sprintf("%s:%d", base, ordinal)) + return bridgeadapter.MakeUserLoginID("codex", mxid, ordinal) } func nextCodexUserLoginID(user *bridgev2.User) networkid.UserLoginID { - used := map[string]struct{}{} - for _, existing := range user.GetUserLogins() { - if existing == nil { - continue - } - used[string(existing.ID)] = struct{}{} - } - for ordinal := 1; ; ordinal++ { - loginID := makeCodexUserLoginID(user.MXID, ordinal) - if _, ok := used[string(loginID)]; !ok { - return loginID - } - } + return bridgeadapter.NextUserLoginID(user, "codex") } func generateShortID() string { diff --git a/bridges/codex/portal_send.go b/bridges/codex/portal_send.go index 6e6517a9..3d01fb29 100644 --- a/bridges/codex/portal_send.go +++ b/bridges/codex/portal_send.go @@ -31,11 +31,12 @@ func (cc *CodexClient) sendViaPortal( } sender := cc.senderForPortal() evt := &CodexRemoteMessage{ - portal: portal.PortalKey, - id: msgID, - sender: sender, - timestamp: time.Now(), - preBuilt: converted, + Portal: portal.PortalKey, + ID: msgID, + Sender: sender, + Timestamp: time.Now(), + LogKey: "codex_msg_id", + PreBuilt: converted, } result := cc.UserLogin.QueueRemoteEvent(evt) if !result.Success { diff --git a/bridges/codex/remote_events.go b/bridges/codex/remote_events.go index f17d2297..5e1376ed 100644 --- a/bridges/codex/remote_events.go +++ b/bridges/codex/remote_events.go @@ -1,146 +1,18 @@ package codex import ( - "context" - "fmt" - "time" - - "github.com/google/uuid" - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - "github.com/beeper/ai-bridge/pkg/shared/streamtransport" -) - -// ----------------------------------------------------------------------- -// CodexRemoteMessage — covers plain text, tool call events -// ----------------------------------------------------------------------- - -var ( - _ bridgev2.RemoteMessage = (*CodexRemoteMessage)(nil) - _ bridgev2.RemoteEventWithTimestamp = (*CodexRemoteMessage)(nil) - _ bridgev2.RemoteEventWithStreamOrder = (*CodexRemoteMessage)(nil) + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) -// CodexRemoteMessage is a RemoteMessage for Codex-generated content routed through bridgev2. -type CodexRemoteMessage struct { - portal networkid.PortalKey - id networkid.MessageID - sender bridgev2.EventSender - timestamp time.Time - - // Pre-built event content. - preBuilt *bridgev2.ConvertedMessage -} - -func (m *CodexRemoteMessage) GetType() bridgev2.RemoteEventType { - return bridgev2.RemoteEventMessage -} - -func (m *CodexRemoteMessage) GetPortalKey() networkid.PortalKey { - return m.portal -} - -func (m *CodexRemoteMessage) AddLogContext(c zerolog.Context) zerolog.Context { - return c.Str("codex_msg_id", string(m.id)) -} - -func (m *CodexRemoteMessage) GetSender() bridgev2.EventSender { - return m.sender -} - -func (m *CodexRemoteMessage) GetID() networkid.MessageID { - return m.id -} - -func (m *CodexRemoteMessage) GetTimestamp() time.Time { - if m.timestamp.IsZero() { - return time.Now() - } - return m.timestamp -} - -func (m *CodexRemoteMessage) GetStreamOrder() int64 { - return m.GetTimestamp().UnixMilli() -} - -func (m *CodexRemoteMessage) ConvertMessage(_ context.Context, _ *bridgev2.Portal, _ bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { - return m.preBuilt, nil -} - -// ----------------------------------------------------------------------- -// CodexRemoteEdit — for final streaming edits (m.replace) -// ----------------------------------------------------------------------- - -var ( - _ bridgev2.RemoteEdit = (*CodexRemoteEdit)(nil) - _ bridgev2.RemoteEventWithTimestamp = (*CodexRemoteEdit)(nil) - _ bridgev2.RemoteEventWithStreamOrder = (*CodexRemoteEdit)(nil) -) - -// CodexRemoteEdit is a RemoteEdit for the final streaming response edit. -type CodexRemoteEdit struct { - portal networkid.PortalKey - sender bridgev2.EventSender - targetMessage networkid.MessageID - timestamp time.Time - - // Pre-built edit content. - preBuilt *bridgev2.ConvertedEdit -} - -func (e *CodexRemoteEdit) GetType() bridgev2.RemoteEventType { - return bridgev2.RemoteEventEdit -} - -func (e *CodexRemoteEdit) GetPortalKey() networkid.PortalKey { - return e.portal -} - -func (e *CodexRemoteEdit) AddLogContext(c zerolog.Context) zerolog.Context { - return c.Str("codex_edit_target", string(e.targetMessage)) -} - -func (e *CodexRemoteEdit) GetSender() bridgev2.EventSender { - return e.sender -} - -func (e *CodexRemoteEdit) GetTargetMessage() networkid.MessageID { - return e.targetMessage -} - -func (e *CodexRemoteEdit) GetTimestamp() time.Time { - if e.timestamp.IsZero() { - return time.Now() - } - return e.timestamp -} - -func (e *CodexRemoteEdit) GetStreamOrder() int64 { - return e.GetTimestamp().UnixMilli() -} - -func (e *CodexRemoteEdit) ConvertEdit(_ context.Context, _ *bridgev2.Portal, _ bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { - // Bind existing DB parts to modified parts when Part was left nil at build time. - if e.preBuilt != nil && len(existing) > 0 { - for i, part := range e.preBuilt.ModifiedParts { - if part.Part == nil && i < len(existing) { - part.Part = existing[i] - } - } - } - streamtransport.EnsureDontRenderEdited(e.preBuilt) - return e.preBuilt, nil -} +// CodexRemoteMessage is a type alias for the shared RemoteMessage. +type CodexRemoteMessage = bridgeadapter.RemoteMessage -// ----------------------------------------------------------------------- -// Helpers -// ----------------------------------------------------------------------- +// CodexRemoteEdit is a type alias for the shared RemoteEdit. +type CodexRemoteEdit = bridgeadapter.RemoteEdit // newMessageID generates a unique message ID for Codex remote events. func newMessageID() networkid.MessageID { - return networkid.MessageID(fmt.Sprintf("codex:%s", uuid.NewString())) + return bridgeadapter.NewMessageID("codex") } diff --git a/bridges/codex/stream_transport.go b/bridges/codex/stream_transport.go index 43803e88..a9ba417d 100644 --- a/bridges/codex/stream_transport.go +++ b/bridges/codex/stream_transport.go @@ -28,11 +28,12 @@ func (cc *CodexClient) sendDebouncedStreamEdit(ctx context.Context, portal *brid } sender := cc.senderForPortal() cc.UserLogin.QueueRemoteEvent(&CodexRemoteEdit{ - portal: portal.PortalKey, - sender: sender, - targetMessage: state.networkMessageID, - timestamp: time.Now(), - preBuilt: &bridgev2.ConvertedEdit{ + Portal: portal.PortalKey, + Sender: sender, + TargetMessage: state.networkMessageID, + Timestamp: time.Now(), + LogKey: "codex_edit_target", + PreBuilt: &bridgev2.ConvertedEdit{ ModifiedParts: []*bridgev2.ConvertedEditPart{{ Type: event.EventMessage, Content: &event.MessageEventContent{ diff --git a/bridges/opencode/host.go b/bridges/opencode/host.go index 7b2cfa5d..d5cb7f46 100644 --- a/bridges/opencode/host.go +++ b/bridges/opencode/host.go @@ -151,11 +151,12 @@ func (oc *OpenCodeClient) EmitOpenCodeStreamEvent(ctx context.Context, portal *b }}, } result := oc.UserLogin.QueueRemoteEvent(&OpenCodeRemoteMessage{ - portal: portal.PortalKey, - id: msgID, - sender: sender, - timestamp: time.Now(), - preBuilt: converted, + Portal: portal.PortalKey, + ID: msgID, + Sender: sender, + Timestamp: time.Now(), + LogKey: "opencode_msg_id", + PreBuilt: converted, }) if result.Success && result.EventID != "" { oc.streamMu.Lock() @@ -242,11 +243,12 @@ func (oc *OpenCodeClient) EmitOpenCodeStreamEvent(ctx context.Context, portal *b } sender := oc.SenderForOpenCode(instanceID, false) oc.UserLogin.QueueRemoteEvent(&OpenCodeRemoteEdit{ - portal: portal.PortalKey, - sender: sender, - targetMessage: netMsgID, - timestamp: time.Now(), - preBuilt: &bridgev2.ConvertedEdit{ + Portal: portal.PortalKey, + Sender: sender, + TargetMessage: netMsgID, + Timestamp: time.Now(), + LogKey: "opencode_edit_target", + PreBuilt: &bridgev2.ConvertedEdit{ ModifiedParts: []*bridgev2.ConvertedEditPart{{ Type: event.EventMessage, Content: &event.MessageEventContent{ diff --git a/bridges/opencode/identifiers.go b/bridges/opencode/identifiers.go index 58c83540..9fd29857 100644 --- a/bridges/opencode/identifiers.go +++ b/bridges/opencode/identifiers.go @@ -1,35 +1,17 @@ package opencode import ( - "fmt" - "net/url" - "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" + + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) func makeOpenCodeUserLoginID(mxid id.UserID, ordinal int) networkid.UserLoginID { - escaped := url.PathEscape(string(mxid)) - base := networkid.UserLoginID(fmt.Sprintf("opencode:%s", escaped)) - if ordinal <= 1 { - return base - } - return networkid.UserLoginID(fmt.Sprintf("%s:%d", base, ordinal)) + return bridgeadapter.MakeUserLoginID("opencode", mxid, ordinal) } func nextOpenCodeUserLoginID(user *bridgev2.User) networkid.UserLoginID { - used := map[string]struct{}{} - for _, existing := range user.GetUserLogins() { - if existing == nil { - continue - } - used[string(existing.ID)] = struct{}{} - } - for ordinal := 1; ; ordinal++ { - loginID := makeOpenCodeUserLoginID(user.MXID, ordinal) - if _, ok := used[string(loginID)]; !ok { - return loginID - } - } + return bridgeadapter.NextUserLoginID(user, "opencode") } diff --git a/bridges/opencode/portal_send.go b/bridges/opencode/portal_send.go index 33b530dd..1e6ea463 100644 --- a/bridges/opencode/portal_send.go +++ b/bridges/opencode/portal_send.go @@ -23,11 +23,12 @@ func (oc *OpenCodeClient) sendViaPortal( sender := oc.SenderForOpenCode(instanceID, false) msgID := newOpenCodeMessageID() evt := &OpenCodeRemoteMessage{ - portal: portal.PortalKey, - id: msgID, - sender: sender, - timestamp: time.Now(), - preBuilt: converted, + Portal: portal.PortalKey, + ID: msgID, + Sender: sender, + Timestamp: time.Now(), + LogKey: "opencode_msg_id", + PreBuilt: converted, } result := oc.UserLogin.QueueRemoteEvent(evt) if !result.Success { diff --git a/bridges/opencode/remote_events.go b/bridges/opencode/remote_events.go index 08941abd..be42b4c9 100644 --- a/bridges/opencode/remote_events.go +++ b/bridges/opencode/remote_events.go @@ -1,142 +1,17 @@ package opencode import ( - "context" - "fmt" - "time" - - "github.com/google/uuid" - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - "github.com/beeper/ai-bridge/pkg/shared/streamtransport" -) - -// ----------------------------------------------------------------------- -// OpenCodeRemoteMessage — for sending messages via QueueRemoteEvent -// ----------------------------------------------------------------------- - -var ( - _ bridgev2.RemoteMessage = (*OpenCodeRemoteMessage)(nil) - _ bridgev2.RemoteEventWithTimestamp = (*OpenCodeRemoteMessage)(nil) - _ bridgev2.RemoteEventWithStreamOrder = (*OpenCodeRemoteMessage)(nil) + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) -// OpenCodeRemoteMessage is a RemoteMessage for OpenCode-generated content routed through bridgev2. -type OpenCodeRemoteMessage struct { - portal networkid.PortalKey - id networkid.MessageID - sender bridgev2.EventSender - timestamp time.Time - - preBuilt *bridgev2.ConvertedMessage -} - -func (m *OpenCodeRemoteMessage) GetType() bridgev2.RemoteEventType { - return bridgev2.RemoteEventMessage -} - -func (m *OpenCodeRemoteMessage) GetPortalKey() networkid.PortalKey { - return m.portal -} - -func (m *OpenCodeRemoteMessage) AddLogContext(c zerolog.Context) zerolog.Context { - return c.Str("opencode_msg_id", string(m.id)) -} - -func (m *OpenCodeRemoteMessage) GetSender() bridgev2.EventSender { - return m.sender -} - -func (m *OpenCodeRemoteMessage) GetID() networkid.MessageID { - return m.id -} - -func (m *OpenCodeRemoteMessage) GetTimestamp() time.Time { - if m.timestamp.IsZero() { - return time.Now() - } - return m.timestamp -} - -func (m *OpenCodeRemoteMessage) GetStreamOrder() int64 { - return m.GetTimestamp().UnixMilli() -} - -func (m *OpenCodeRemoteMessage) ConvertMessage(_ context.Context, _ *bridgev2.Portal, _ bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { - return m.preBuilt, nil -} - -// ----------------------------------------------------------------------- -// OpenCodeRemoteEdit — for debounced streaming edits -// ----------------------------------------------------------------------- - -var ( - _ bridgev2.RemoteEdit = (*OpenCodeRemoteEdit)(nil) - _ bridgev2.RemoteEventWithTimestamp = (*OpenCodeRemoteEdit)(nil) - _ bridgev2.RemoteEventWithStreamOrder = (*OpenCodeRemoteEdit)(nil) -) - -// OpenCodeRemoteEdit is a RemoteEdit for OpenCode streaming response edits. -type OpenCodeRemoteEdit struct { - portal networkid.PortalKey - sender bridgev2.EventSender - targetMessage networkid.MessageID - timestamp time.Time - - preBuilt *bridgev2.ConvertedEdit -} - -func (e *OpenCodeRemoteEdit) GetType() bridgev2.RemoteEventType { - return bridgev2.RemoteEventEdit -} - -func (e *OpenCodeRemoteEdit) GetPortalKey() networkid.PortalKey { - return e.portal -} - -func (e *OpenCodeRemoteEdit) AddLogContext(c zerolog.Context) zerolog.Context { - return c.Str("opencode_edit_target", string(e.targetMessage)) -} - -func (e *OpenCodeRemoteEdit) GetSender() bridgev2.EventSender { - return e.sender -} - -func (e *OpenCodeRemoteEdit) GetTargetMessage() networkid.MessageID { - return e.targetMessage -} - -func (e *OpenCodeRemoteEdit) GetTimestamp() time.Time { - if e.timestamp.IsZero() { - return time.Now() - } - return e.timestamp -} - -func (e *OpenCodeRemoteEdit) GetStreamOrder() int64 { - return e.GetTimestamp().UnixMilli() -} - -func (e *OpenCodeRemoteEdit) ConvertEdit(_ context.Context, _ *bridgev2.Portal, _ bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { - if e.preBuilt != nil && len(existing) > 0 { - for i, part := range e.preBuilt.ModifiedParts { - if part.Part == nil && i < len(existing) { - part.Part = existing[i] - } - } - } - streamtransport.EnsureDontRenderEdited(e.preBuilt) - return e.preBuilt, nil -} +// OpenCodeRemoteMessage is a type alias for the shared RemoteMessage. +type OpenCodeRemoteMessage = bridgeadapter.RemoteMessage -// ----------------------------------------------------------------------- -// Helpers -// ----------------------------------------------------------------------- +// OpenCodeRemoteEdit is a type alias for the shared RemoteEdit. +type OpenCodeRemoteEdit = bridgeadapter.RemoteEdit func newOpenCodeMessageID() networkid.MessageID { - return networkid.MessageID(fmt.Sprintf("opencode:%s", uuid.NewString())) + return bridgeadapter.NewMessageID("opencode") } diff --git a/bridges/opencode/stream_canonical.go b/bridges/opencode/stream_canonical.go index 5901484d..ff32cf7b 100644 --- a/bridges/opencode/stream_canonical.go +++ b/bridges/opencode/stream_canonical.go @@ -217,11 +217,12 @@ func (oc *OpenCodeClient) queueFinalStreamEdit(ctx context.Context, portal *brid } sender := oc.SenderForOpenCode(instanceID, false) oc.UserLogin.QueueRemoteEvent(&OpenCodeRemoteEdit{ - portal: portal.PortalKey, - sender: sender, - targetMessage: state.networkMessageID, - timestamp: time.Now(), - preBuilt: &bridgev2.ConvertedEdit{ + Portal: portal.PortalKey, + Sender: sender, + TargetMessage: state.networkMessageID, + Timestamp: time.Now(), + LogKey: "opencode_edit_target", + PreBuilt: &bridgev2.ConvertedEdit{ ModifiedParts: []*bridgev2.ConvertedEditPart{{ Type: event.EventMessage, Content: &event.MessageEventContent{ diff --git a/pkg/bridgeadapter/identifier_helpers.go b/pkg/bridgeadapter/identifier_helpers.go index cabd0dbb..c393892f 100644 --- a/pkg/bridgeadapter/identifier_helpers.go +++ b/pkg/bridgeadapter/identifier_helpers.go @@ -2,6 +2,7 @@ package bridgeadapter import ( "fmt" + "net/url" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" @@ -16,6 +17,33 @@ func HumanUserID(prefix string, loginID networkid.UserLoginID) networkid.UserID return networkid.UserID(prefix + ":" + string(loginID)) } +// MakeUserLoginID creates a login ID in the format "prefix:escaped-mxid[:ordinal]". +func MakeUserLoginID(prefix string, mxid id.UserID, ordinal int) networkid.UserLoginID { + escaped := url.PathEscape(string(mxid)) + base := networkid.UserLoginID(fmt.Sprintf("%s:%s", prefix, escaped)) + if ordinal <= 1 { + return base + } + return networkid.UserLoginID(fmt.Sprintf("%s:%d", base, ordinal)) +} + +// NextUserLoginID finds the next available ordinal for a login ID with the given prefix. +func NextUserLoginID(user *bridgev2.User, prefix string) networkid.UserLoginID { + used := map[string]struct{}{} + for _, existing := range user.GetUserLogins() { + if existing == nil { + continue + } + used[string(existing.ID)] = struct{}{} + } + for ordinal := 1; ; ordinal++ { + loginID := MakeUserLoginID(prefix, user.MXID, ordinal) + if _, ok := used[string(loginID)]; !ok { + return loginID + } + } +} + func SingleLoginFlow(enabled bool, flow bridgev2.LoginFlow) []bridgev2.LoginFlow { if !enabled { return nil diff --git a/pkg/bridgeadapter/remote_events.go b/pkg/bridgeadapter/remote_events.go new file mode 100644 index 00000000..e1ff1c28 --- /dev/null +++ b/pkg/bridgeadapter/remote_events.go @@ -0,0 +1,147 @@ +package bridgeadapter + +import ( + "context" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + + "github.com/beeper/ai-bridge/pkg/shared/streamtransport" +) + +// ----------------------------------------------------------------------- +// RemoteMessage — generic pre-built message for QueueRemoteEvent +// ----------------------------------------------------------------------- + +var ( + _ bridgev2.RemoteMessage = (*RemoteMessage)(nil) + _ bridgev2.RemoteEventWithTimestamp = (*RemoteMessage)(nil) + _ bridgev2.RemoteEventWithStreamOrder = (*RemoteMessage)(nil) +) + +// RemoteMessage is a bridge-agnostic RemoteMessage implementation backed by pre-built content. +type RemoteMessage struct { + Portal networkid.PortalKey + ID networkid.MessageID + Sender bridgev2.EventSender + Timestamp time.Time + PreBuilt *bridgev2.ConvertedMessage + + // LogKey is the zerolog field name used in AddLogContext (e.g. "ai_msg_id", "codex_msg_id"). + LogKey string +} + +func (m *RemoteMessage) GetType() bridgev2.RemoteEventType { + return bridgev2.RemoteEventMessage +} + +func (m *RemoteMessage) GetPortalKey() networkid.PortalKey { + return m.Portal +} + +func (m *RemoteMessage) AddLogContext(c zerolog.Context) zerolog.Context { + return c.Str(m.LogKey, string(m.ID)) +} + +func (m *RemoteMessage) GetSender() bridgev2.EventSender { + return m.Sender +} + +func (m *RemoteMessage) GetID() networkid.MessageID { + return m.ID +} + +func (m *RemoteMessage) GetTimestamp() time.Time { + if m.Timestamp.IsZero() { + return time.Now() + } + return m.Timestamp +} + +func (m *RemoteMessage) GetStreamOrder() int64 { + return m.GetTimestamp().UnixMilli() +} + +func (m *RemoteMessage) ConvertMessage(_ context.Context, _ *bridgev2.Portal, _ bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { + return m.PreBuilt, nil +} + +// ----------------------------------------------------------------------- +// RemoteEdit — generic pre-built edit for QueueRemoteEvent +// ----------------------------------------------------------------------- + +var ( + _ bridgev2.RemoteEdit = (*RemoteEdit)(nil) + _ bridgev2.RemoteEventWithTimestamp = (*RemoteEdit)(nil) + _ bridgev2.RemoteEventWithStreamOrder = (*RemoteEdit)(nil) +) + +// RemoteEdit is a bridge-agnostic RemoteEdit implementation backed by pre-built content. +type RemoteEdit struct { + Portal networkid.PortalKey + Sender bridgev2.EventSender + TargetMessage networkid.MessageID + Timestamp time.Time + PreBuilt *bridgev2.ConvertedEdit + + // LogKey is the zerolog field name used in AddLogContext (e.g. "ai_edit_target", "codex_edit_target"). + LogKey string +} + +func (e *RemoteEdit) GetType() bridgev2.RemoteEventType { + return bridgev2.RemoteEventEdit +} + +func (e *RemoteEdit) GetPortalKey() networkid.PortalKey { + return e.Portal +} + +func (e *RemoteEdit) AddLogContext(c zerolog.Context) zerolog.Context { + return c.Str(e.LogKey, string(e.TargetMessage)) +} + +func (e *RemoteEdit) GetSender() bridgev2.EventSender { + return e.Sender +} + +func (e *RemoteEdit) GetTargetMessage() networkid.MessageID { + return e.TargetMessage +} + +func (e *RemoteEdit) GetTimestamp() time.Time { + if e.Timestamp.IsZero() { + return time.Now() + } + return e.Timestamp +} + +func (e *RemoteEdit) GetStreamOrder() int64 { + return e.GetTimestamp().UnixMilli() +} + +func (e *RemoteEdit) ConvertEdit(_ context.Context, _ *bridgev2.Portal, _ bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { + if e.PreBuilt != nil && len(existing) > 0 { + for i, part := range e.PreBuilt.ModifiedParts { + if part.Part == nil && i < len(existing) { + part.Part = existing[i] + } + } + } + streamtransport.EnsureDontRenderEdited(e.PreBuilt) + return e.PreBuilt, nil +} + +// ----------------------------------------------------------------------- +// NewMessageID — generates a unique message ID with the given prefix +// ----------------------------------------------------------------------- + +// NewMessageID generates a unique message ID in the format "prefix:uuid". +func NewMessageID(prefix string) networkid.MessageID { + return networkid.MessageID(fmt.Sprintf("%s:%s", prefix, uuid.NewString())) +} diff --git a/pkg/connector/portal_send.go b/pkg/connector/portal_send.go index 57e01689..5829d40e 100644 --- a/pkg/connector/portal_send.go +++ b/pkg/connector/portal_send.go @@ -9,6 +9,8 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) func ensureConvertedMessageParts(converted *bridgev2.ConvertedMessage) { @@ -45,12 +47,13 @@ func (oc *AIClient) sendViaPortal( } ensureConvertedMessageParts(converted) sender := oc.senderForPortal(ctx, portal) - evt := &AIRemoteMessage{ - portal: portal.PortalKey, - id: msgID, - sender: sender, - timestamp: time.Now(), - preBuilt: converted, + evt := &bridgeadapter.RemoteMessage{ + Portal: portal.PortalKey, + ID: msgID, + Sender: sender, + Timestamp: time.Now(), + LogKey: "ai_msg_id", + PreBuilt: converted, } result := oc.UserLogin.QueueRemoteEvent(evt) if !result.Success { @@ -73,12 +76,13 @@ func (oc *AIClient) sendEditViaPortal( return fmt.Errorf("invalid portal") } sender := oc.senderForPortal(ctx, portal) - evt := &AIRemoteEdit{ - portal: portal.PortalKey, - sender: sender, - targetMessage: targetMsgID, - timestamp: time.Now(), - preBuilt: converted, + evt := &bridgeadapter.RemoteEdit{ + Portal: portal.PortalKey, + Sender: sender, + TargetMessage: targetMsgID, + Timestamp: time.Now(), + LogKey: "ai_edit_target", + PreBuilt: converted, } result := oc.UserLogin.QueueRemoteEvent(evt) if !result.Success { diff --git a/pkg/connector/remote_events.go b/pkg/connector/remote_events.go index d40fc36b..cfbf7ba8 100644 --- a/pkg/connector/remote_events.go +++ b/pkg/connector/remote_events.go @@ -1,11 +1,8 @@ package connector import ( - "context" - "fmt" "time" - "github.com/google/uuid" "github.com/rs/zerolog" "go.mau.fi/util/variationselector" @@ -14,149 +11,10 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/connector/msgconv" - "github.com/beeper/ai-bridge/pkg/shared/streamtransport" ) -// ----------------------------------------------------------------------- -// AIRemoteMessage — covers plain text, tool call events, tool result events, media -// ----------------------------------------------------------------------- - -var ( - _ bridgev2.RemoteMessage = (*AIRemoteMessage)(nil) - _ bridgev2.RemoteEventWithTimestamp = (*AIRemoteMessage)(nil) - _ bridgev2.RemoteEventWithStreamOrder = (*AIRemoteMessage)(nil) - _ bridgev2.RemoteMessageWithTransactionID = (*AIRemoteMessage)(nil) -) - -// AIMessageVariant identifies what kind of Matrix event this message produces. -type AIMessageVariant int - -const ( - AIMessageText AIMessageVariant = iota // Plain assistant text - AIMessageToolCall // Tool call timeline event - AIMessageToolResult // Tool result timeline event -) - -// AIRemoteMessage is a RemoteMessage for AI-generated content routed through bridgev2. -type AIRemoteMessage struct { - portal networkid.PortalKey - id networkid.MessageID - sender bridgev2.EventSender - timestamp time.Time - txnID networkid.TransactionID - variant AIMessageVariant - - // Pre-built event content — the conversion is done before queuing because - // AI messages are constructed with full knowledge of what to send (no - // lazy resolution needed). This is the same pattern as simplevent.PreConvertedMessage. - preBuilt *bridgev2.ConvertedMessage -} - -func (m *AIRemoteMessage) GetType() bridgev2.RemoteEventType { - return bridgev2.RemoteEventMessage -} - -func (m *AIRemoteMessage) GetPortalKey() networkid.PortalKey { - return m.portal -} - -func (m *AIRemoteMessage) AddLogContext(c zerolog.Context) zerolog.Context { - return c.Str("ai_msg_id", string(m.id)).Int("variant", int(m.variant)) -} - -func (m *AIRemoteMessage) GetSender() bridgev2.EventSender { - return m.sender -} - -func (m *AIRemoteMessage) GetID() networkid.MessageID { - return m.id -} - -func (m *AIRemoteMessage) GetTimestamp() time.Time { - if m.timestamp.IsZero() { - return time.Now() - } - return m.timestamp -} - -func (m *AIRemoteMessage) GetStreamOrder() int64 { - return m.GetTimestamp().UnixMilli() -} - -func (m *AIRemoteMessage) GetTransactionID() networkid.TransactionID { - return m.txnID -} - -func (m *AIRemoteMessage) ConvertMessage(_ context.Context, _ *bridgev2.Portal, _ bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { - return m.preBuilt, nil -} - -// ----------------------------------------------------------------------- -// AIRemoteEdit — for final streaming edits (m.replace) -// ----------------------------------------------------------------------- - -var ( - _ bridgev2.RemoteEdit = (*AIRemoteEdit)(nil) - _ bridgev2.RemoteEventWithTimestamp = (*AIRemoteEdit)(nil) - _ bridgev2.RemoteEventWithStreamOrder = (*AIRemoteEdit)(nil) -) - -// AIRemoteEdit is a RemoteEdit for the final streaming response edit. -type AIRemoteEdit struct { - portal networkid.PortalKey - sender bridgev2.EventSender - targetMessage networkid.MessageID - timestamp time.Time - - // Pre-built edit content. - preBuilt *bridgev2.ConvertedEdit -} - -func (e *AIRemoteEdit) GetType() bridgev2.RemoteEventType { - return bridgev2.RemoteEventEdit -} - -func (e *AIRemoteEdit) GetPortalKey() networkid.PortalKey { - return e.portal -} - -func (e *AIRemoteEdit) AddLogContext(c zerolog.Context) zerolog.Context { - return c.Str("ai_edit_target", string(e.targetMessage)) -} - -func (e *AIRemoteEdit) GetSender() bridgev2.EventSender { - return e.sender -} - -func (e *AIRemoteEdit) GetTargetMessage() networkid.MessageID { - return e.targetMessage -} - -func (e *AIRemoteEdit) GetTimestamp() time.Time { - if e.timestamp.IsZero() { - return time.Now() - } - return e.timestamp -} - -func (e *AIRemoteEdit) GetStreamOrder() int64 { - return e.GetTimestamp().UnixMilli() -} - -func (e *AIRemoteEdit) ConvertEdit(_ context.Context, _ *bridgev2.Portal, _ bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { - // Bind existing DB parts to modified parts when Part was left nil at build time. - if e.preBuilt != nil && len(existing) > 0 { - for i, part := range e.preBuilt.ModifiedParts { - if part.Part == nil && i < len(existing) { - part.Part = existing[i] - } - } - } - streamtransport.EnsureDontRenderEdited(e.preBuilt) - return e.preBuilt, nil -} - // ----------------------------------------------------------------------- // AIRemoteReaction — for AI-sent reactions // ----------------------------------------------------------------------- @@ -290,10 +148,10 @@ func (r *AIRemoteMessageRemove) GetTargetMessage() networkid.MessageID { } // ----------------------------------------------------------------------- -// Constructor helpers — build pre-converted messages for common patterns +// Constructor helpers // ----------------------------------------------------------------------- -// NewAITextMessage creates an AIRemoteMessage for a plain text assistant message. +// NewAITextMessage creates a RemoteMessage for a plain text assistant message. func NewAITextMessage( portal *bridgev2.Portal, login *bridgev2.UserLogin, @@ -301,7 +159,7 @@ func NewAITextMessage( meta *PortalMetadata, agentID string, modelID string, -) *AIRemoteMessage { +) *bridgeadapter.RemoteMessage { rendered := msgconv.BuildPlainMessageContent(msgconv.PlainMessageContentParams{ Text: text, }) @@ -309,13 +167,13 @@ func NewAITextMessage( if agentID != "" { senderID = agentUserID(agentID) } - return &AIRemoteMessage{ - portal: portal.PortalKey, - id: newMessageID(), - sender: bridgev2.EventSender{Sender: senderID, SenderLogin: login.ID}, - timestamp: time.Now(), - variant: AIMessageText, - preBuilt: &bridgev2.ConvertedMessage{ + return &bridgeadapter.RemoteMessage{ + Portal: portal.PortalKey, + ID: bridgeadapter.NewMessageID("ai"), + Sender: bridgev2.EventSender{Sender: senderID, SenderLogin: login.ID}, + Timestamp: time.Now(), + LogKey: "ai_msg_id", + PreBuilt: &bridgev2.ConvertedMessage{ Parts: []*bridgev2.ConvertedMessagePart{{ ID: networkid.PartID("0"), Type: event.EventMessage, @@ -328,5 +186,5 @@ func NewAITextMessage( // newMessageID generates a unique message ID for AI remote events. func newMessageID() networkid.MessageID { - return networkid.MessageID(fmt.Sprintf("ai:%s", uuid.NewString())) + return bridgeadapter.NewMessageID("ai") } diff --git a/pkg/connector/response_finalization.go b/pkg/connector/response_finalization.go index f236fd34..f9c9a963 100644 --- a/pkg/connector/response_finalization.go +++ b/pkg/connector/response_finalization.go @@ -14,6 +14,7 @@ import ( "maunium.net/go/mautrix/id" "github.com/beeper/ai-bridge/pkg/agents" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/connector/msgconv" airuntime "github.com/beeper/ai-bridge/pkg/runtime" "github.com/beeper/ai-bridge/pkg/shared/citations" @@ -78,13 +79,13 @@ func (oc *AIClient) sendContinuationMessage(ctx context.Context, portal *bridgev if agentID != "" { senderID = agentUserID(agentID) } - msg := &AIRemoteMessage{ - portal: portal.PortalKey, - id: newMessageID(), - sender: bridgev2.EventSender{Sender: senderID, SenderLogin: oc.UserLogin.ID}, - timestamp: time.Now(), - variant: AIMessageText, - preBuilt: &bridgev2.ConvertedMessage{ + msg := &bridgeadapter.RemoteMessage{ + Portal: portal.PortalKey, + ID: newMessageID(), + Sender: bridgev2.EventSender{Sender: senderID, SenderLogin: oc.UserLogin.ID}, + Timestamp: time.Now(), + LogKey: "ai_msg_id", + PreBuilt: &bridgev2.ConvertedMessage{ Parts: []*bridgev2.ConvertedMessagePart{{ ID: networkid.PartID("0"), Type: event.EventMessage, @@ -731,11 +732,12 @@ func (oc *AIClient) sendFinalAssistantTurnContent(ctx context.Context, portal *b TopLevelExtra: topLevelExtra, }}, } - oc.UserLogin.QueueRemoteEvent(&AIRemoteEdit{ - portal: portal.PortalKey, - sender: sender, - targetMessage: state.networkMessageID, - preBuilt: editContent, + oc.UserLogin.QueueRemoteEvent(&bridgeadapter.RemoteEdit{ + Portal: portal.PortalKey, + Sender: sender, + TargetMessage: state.networkMessageID, + LogKey: "ai_edit_target", + PreBuilt: editContent, }) oc.recordAgentActivity(ctx, portal, meta) oc.loggerForContext(ctx).Debug(). diff --git a/pkg/connector/stream_transport.go b/pkg/connector/stream_transport.go index 17a052c5..df9dad76 100644 --- a/pkg/connector/stream_transport.go +++ b/pkg/connector/stream_transport.go @@ -7,6 +7,7 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/shared/streamtransport" ) @@ -25,12 +26,13 @@ func (oc *AIClient) sendDebouncedStreamEdit(ctx context.Context, portal *bridgev return nil } sender := oc.senderForPortal(ctx, portal) - oc.UserLogin.QueueRemoteEvent(&AIRemoteEdit{ - portal: portal.PortalKey, - sender: sender, - targetMessage: state.networkMessageID, - timestamp: time.Now(), - preBuilt: &bridgev2.ConvertedEdit{ + oc.UserLogin.QueueRemoteEvent(&bridgeadapter.RemoteEdit{ + Portal: portal.PortalKey, + Sender: sender, + TargetMessage: state.networkMessageID, + Timestamp: time.Now(), + LogKey: "ai_edit_target", + PreBuilt: &bridgev2.ConvertedEdit{ ModifiedParts: []*bridgev2.ConvertedEditPart{{ Type: event.EventMessage, Content: &event.MessageEventContent{ From 252b825626f5ff1dfafa59cd1eefc769755f6ac0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 18:24:43 +0100 Subject: [PATCH 02/23] Use bridgeadapter and consolidate utilities Replace ad-hoc ID and stream helpers with shared utilities and refactor related code. Key changes: - Use bridgeadapter.NewMessageID / MakeUserLoginID / NextUserLoginID across Codex, OpenCode and AI connector; remove duplicated ID helpers. - Switch stream edit building and emission to streamtransport helpers (BuildConvertedEdit, EmitStreamEvent) and consolidate streaming setup. - Replace nowMillis() with time.Now().UnixMilli() and remove the deleted time_util file. - Consolidate OpenCode session mutations via runSessionMutation and simplify part event creation paths. - Refactor token backfill logic in OpenCode to use a generic backfillTokenValue helper. - Centralize MIME->message type logic to shared media package. - Add several small helper functions to pkg/connector: loginDBContext, modelMemberUserInfo/modelJoinMember, topDesktopChatLabels, dispatchMediaUnderstandingFallback, heartbeatSessionPortalCandidate, applyHeartbeatVisibility, populateAudioMessageContent, joinChatText, appendUserIDOption. - Introduce new connector-only tool helper (newConnectorOnlyTool) and replace placeholder Beeper tools with connector-only variants. - Minor import adjustments and other small cleanups across many files to use the new shared helpers. These changes reduce duplication, centralize ID/message/stream handling, and simplify future maintenance. --- bridges/codex/client.go | 23 +-- bridges/codex/connector.go | 2 +- bridges/codex/identifiers.go | 13 -- bridges/codex/login.go | 2 +- bridges/codex/portal_send.go | 4 +- bridges/codex/remote_events.go | 7 - bridges/codex/stream_transport.go | 49 +++--- bridges/codex/streaming_support.go | 3 +- bridges/codex/time_util.go | 7 - bridges/opencode/host.go | 3 +- bridges/opencode/identifiers.go | 17 -- bridges/opencode/login.go | 3 +- .../opencodebridge/backfill_canonical.go | 34 ++-- bridges/opencode/opencodebridge/mime.go | 16 +- .../opencodebridge/opencode_manager.go | 30 ++-- .../opencode/opencodebridge/opencode_parts.go | 32 ++-- bridges/opencode/portal_send.go | 4 +- bridges/opencode/remote_events.go | 6 - bridges/opencode/stream_canonical.go | 20 +-- pkg/agents/tools/beeper_docs.go | 30 +--- pkg/agents/tools/beeper_send_feedback.go | 30 +--- pkg/agents/tools/connector_only.go | 27 +++ pkg/connector/bridge_db.go | 11 ++ pkg/connector/chat.go | 71 ++++---- pkg/connector/desktop_api_sessions.go | 27 ++- pkg/connector/handlematrix.go | 96 ++++++++--- pkg/connector/heartbeat_execute.go | 47 +++-- pkg/connector/heartbeat_visibility.go | 40 ++--- pkg/connector/media_send.go | 33 ++-- pkg/connector/messages.go | 26 +-- pkg/connector/portal_send.go | 2 +- pkg/connector/provider_openai.go | 31 ++-- pkg/connector/remote_events.go | 4 - pkg/connector/response_finalization.go | 54 +----- pkg/connector/scheduler_db.go | 160 +++++++----------- pkg/connector/session_store.go | 11 +- pkg/connector/stream_events.go | 24 ++- pkg/connector/stream_transport.go | 25 +-- pkg/connector/subagent_spawn.go | 52 +++--- pkg/connector/system_events_db.go | 11 +- pkg/connector/tools.go | 27 +-- pkg/runtime/chat_content.go | 70 ++++---- pkg/shared/citations/citations.go | 105 +++++++----- pkg/shared/media/message_type.go | 21 +++ pkg/shared/streamtransport/debounced_edit.go | 21 +++ pkg/shared/streamtransport/session.go | 32 ++++ pkg/shared/streamui/emitter.go | 21 +-- pkg/shared/streamui/recorder.go | 38 ++--- 48 files changed, 654 insertions(+), 768 deletions(-) delete mode 100644 bridges/codex/time_util.go delete mode 100644 bridges/opencode/identifiers.go create mode 100644 pkg/agents/tools/connector_only.go create mode 100644 pkg/shared/media/message_type.go diff --git a/bridges/codex/client.go b/bridges/codex/client.go index 9a632ad1..3df621d8 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -1896,7 +1896,7 @@ func (cc *CodexClient) sendInitialStreamMessage(ctx context.Context, portal *bri "m.mentions": map[string]any{}, } - msgID := newMessageID() + msgID := bridgeadapter.NewMessageID("codex") converted := &bridgev2.ConvertedMessage{ Parts: []*bridgev2.ConvertedMessagePart{{ ID: networkid.PartID("0"), @@ -2015,19 +2015,12 @@ func (cc *CodexClient) sendFinalAssistantTurn(ctx context.Context, portal *bridg TargetMessage: state.networkMessageID, Timestamp: time.Now(), LogKey: "codex_edit_target", - PreBuilt: &bridgev2.ConvertedEdit{ - ModifiedParts: []*bridgev2.ConvertedEditPart{{ - Type: event.EventMessage, - Content: &event.MessageEventContent{ - MsgType: event.MsgText, - Body: rendered.Body, - Format: rendered.Format, - FormattedBody: rendered.FormattedBody, - }, - Extra: map[string]any{"m.mentions": map[string]any{}}, - TopLevelExtra: topLevelExtra, - }}, - }, + PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ + MsgType: event.MsgText, + Body: rendered.Body, + Format: rendered.Format, + FormattedBody: rendered.FormattedBody, + }, topLevelExtra), }) cc.loggerForContext(ctx).Debug(). Str("initial_event_id", state.initialEventID.String()). @@ -2061,7 +2054,7 @@ func (cc *CodexClient) sendContinuationMessage(ctx context.Context, portal *brid sender := cc.senderForPortal() cc.UserLogin.QueueRemoteEvent(&CodexRemoteMessage{ Portal: portal.PortalKey, - ID: newMessageID(), + ID: bridgeadapter.NewMessageID("codex"), Sender: sender, Timestamp: time.Now(), LogKey: "codex_msg_id", diff --git a/bridges/codex/connector.go b/bridges/codex/connector.go index 8a851390..772869ce 100644 --- a/bridges/codex/connector.go +++ b/bridges/codex/connector.go @@ -183,7 +183,7 @@ func (cc *CodexConnector) autoProvisionExistingCodex(ctx context.Context) { } // Use a deterministic instance ID so restarts won't create duplicates. - loginID := makeCodexUserLoginID(mxid, 1) + loginID := bridgeadapter.MakeUserLoginID("codex", mxid, 1) // If this login already exists in the DB (e.g. from a previous run), skip creation. existing, err := cc.br.GetExistingUserLoginByID(ctx, loginID) diff --git a/bridges/codex/identifiers.go b/bridges/codex/identifiers.go index df22ef07..32645203 100644 --- a/bridges/codex/identifiers.go +++ b/bridges/codex/identifiers.go @@ -4,21 +4,8 @@ import ( "strings" "github.com/rs/xid" - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/id" - - "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) -func makeCodexUserLoginID(mxid id.UserID, ordinal int) networkid.UserLoginID { - return bridgeadapter.MakeUserLoginID("codex", mxid, ordinal) -} - -func nextCodexUserLoginID(user *bridgev2.User) networkid.UserLoginID { - return bridgeadapter.NextUserLoginID(user, "codex") -} - func generateShortID() string { return xid.New().String() } diff --git a/bridges/codex/login.go b/bridges/codex/login.go index db724f4a..067a55df 100644 --- a/bridges/codex/login.go +++ b/bridges/codex/login.go @@ -609,7 +609,7 @@ func (cl *CodexLogin) finishLogin(ctx context.Context) (*bridgev2.LoginStep, err persistCtx := cl.backgroundProcessContext() log := cl.logger(persistCtx) - loginID := nextCodexUserLoginID(cl.User) + loginID := bridgeadapter.NextUserLoginID(cl.User, "codex") remoteName := "Codex" dupCount := 0 for _, existing := range cl.User.GetUserLogins() { diff --git a/bridges/codex/portal_send.go b/bridges/codex/portal_send.go index 3d01fb29..0381e916 100644 --- a/bridges/codex/portal_send.go +++ b/bridges/codex/portal_send.go @@ -8,6 +8,8 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" + + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) // sendViaPortal sends a pre-built message through bridgev2's QueueRemoteEvent pipeline. @@ -27,7 +29,7 @@ func (cc *CodexClient) sendViaPortal( return "", msgID, fmt.Errorf("bridge unavailable") } if msgID == "" { - msgID = newMessageID() + msgID = bridgeadapter.NewMessageID("codex") } sender := cc.senderForPortal() evt := &CodexRemoteMessage{ diff --git a/bridges/codex/remote_events.go b/bridges/codex/remote_events.go index 5e1376ed..3c6ec8f0 100644 --- a/bridges/codex/remote_events.go +++ b/bridges/codex/remote_events.go @@ -1,8 +1,6 @@ package codex import ( - "maunium.net/go/mautrix/bridgev2/networkid" - "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) @@ -11,8 +9,3 @@ type CodexRemoteMessage = bridgeadapter.RemoteMessage // CodexRemoteEdit is a type alias for the shared RemoteEdit. type CodexRemoteEdit = bridgeadapter.RemoteEdit - -// newMessageID generates a unique message ID for Codex remote events. -func newMessageID() networkid.MessageID { - return bridgeadapter.NewMessageID("codex") -} diff --git a/bridges/codex/stream_transport.go b/bridges/codex/stream_transport.go index a9ba417d..0efbe694 100644 --- a/bridges/codex/stream_transport.go +++ b/bridges/codex/stream_transport.go @@ -2,7 +2,6 @@ package codex import ( "context" - "strings" "time" "maunium.net/go/mautrix/bridgev2" @@ -33,22 +32,15 @@ func (cc *CodexClient) sendDebouncedStreamEdit(ctx context.Context, portal *brid TargetMessage: state.networkMessageID, Timestamp: time.Now(), LogKey: "codex_edit_target", - PreBuilt: &bridgev2.ConvertedEdit{ - ModifiedParts: []*bridgev2.ConvertedEditPart{{ - Type: event.EventMessage, - Content: &event.MessageEventContent{ - MsgType: event.MsgText, - Body: content.Body, - Format: content.Format, - FormattedBody: content.FormattedBody, - }, - Extra: map[string]any{"m.mentions": map[string]any{}}, - TopLevelExtra: map[string]any{ - "com.beeper.dont_render_edited": true, - "m.mentions": map[string]any{}, - }, - }}, - }, + PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ + MsgType: event.MsgText, + Body: content.Body, + Format: content.Format, + FormattedBody: content.FormattedBody, + }, map[string]any{ + "com.beeper.dont_render_edited": true, + "m.mentions": map[string]any{}, + }), }) return nil } @@ -101,19 +93,16 @@ func (cc *CodexClient) ensureStreamSession(ctx context.Context, portal *bridgev2 } func (cc *CodexClient) emitStreamEvent(ctx context.Context, portal *bridgev2.Portal, state *streamingState, part map[string]any) { - if portal == nil || portal.MXID == "" || state == nil || state.suppressSend { - return - } - if !state.loggedStreamStart { - state.loggedStreamStart = true - cc.loggerForContext(ctx).Info(). - Stringer("room_id", portal.MXID). - Str("turn_id", strings.TrimSpace(state.turnID)). - Msg("Streaming events") - } - session := cc.ensureStreamSession(ctx, portal, state) - if session == nil { + if state == nil { return } - session.EmitPart(ctx, part) + streamtransport.EmitStreamEvent(ctx, portal, streamtransport.StreamEventState{ + TurnID: state.turnID, + SuppressSend: state.suppressSend, + LoggedStart: &state.loggedStreamStart, + EnsureSession: func() *streamtransport.StreamSession { + return cc.ensureStreamSession(ctx, portal, state) + }, + Logger: cc.loggerForContext(ctx), + }, part) } diff --git a/bridges/codex/streaming_support.go b/bridges/codex/streaming_support.go index 41a52665..772ab67b 100644 --- a/bridges/codex/streaming_support.go +++ b/bridges/codex/streaming_support.go @@ -3,6 +3,7 @@ package codex import ( "context" "strings" + "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" @@ -68,7 +69,7 @@ func newStreamingState(_ context.Context, _ *PortalMetadata, sourceEventID id.Ev ui.InitMaps() return &streamingState{ turnID: turnID, - startedAtMs: nowMillis(), + startedAtMs: time.Now().UnixMilli(), firstToken: true, initialEventID: sourceEventID, ui: ui, diff --git a/bridges/codex/time_util.go b/bridges/codex/time_util.go deleted file mode 100644 index 6789f3aa..00000000 --- a/bridges/codex/time_util.go +++ /dev/null @@ -1,7 +0,0 @@ -package codex - -import "time" - -func nowMillis() int64 { - return time.Now().UnixMilli() -} diff --git a/bridges/opencode/host.go b/bridges/opencode/host.go index d5cb7f46..63b813af 100644 --- a/bridges/opencode/host.go +++ b/bridges/opencode/host.go @@ -16,6 +16,7 @@ import ( "maunium.net/go/mautrix/id" "github.com/beeper/ai-bridge/bridges/opencode/opencodebridge" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/connector/msgconv" "github.com/beeper/ai-bridge/pkg/matrixevents" "github.com/beeper/ai-bridge/pkg/shared/streamtransport" @@ -119,7 +120,7 @@ func (oc *OpenCodeClient) EmitOpenCodeStreamEvent(ctx context.Context, portal *b instanceID = pmeta.InstanceID } sender := oc.SenderForOpenCode(instanceID, false) - msgID := newOpenCodeMessageID() + msgID := bridgeadapter.NewMessageID("opencode") uiMessage := msgconv.BuildUIMessage(msgconv.UIMessageParams{ TurnID: turnID, Role: "assistant", diff --git a/bridges/opencode/identifiers.go b/bridges/opencode/identifiers.go deleted file mode 100644 index 9fd29857..00000000 --- a/bridges/opencode/identifiers.go +++ /dev/null @@ -1,17 +0,0 @@ -package opencode - -import ( - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/id" - - "github.com/beeper/ai-bridge/pkg/bridgeadapter" -) - -func makeOpenCodeUserLoginID(mxid id.UserID, ordinal int) networkid.UserLoginID { - return bridgeadapter.MakeUserLoginID("opencode", mxid, ordinal) -} - -func nextOpenCodeUserLoginID(user *bridgev2.User) networkid.UserLoginID { - return bridgeadapter.NextUserLoginID(user, "opencode") -} diff --git a/bridges/opencode/login.go b/bridges/opencode/login.go index 7f01891e..dd59eee8 100644 --- a/bridges/opencode/login.go +++ b/bridges/opencode/login.go @@ -13,6 +13,7 @@ import ( openCodeAPI "github.com/beeper/ai-bridge/bridges/opencode/opencode" "github.com/beeper/ai-bridge/bridges/opencode/opencodebridge" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) var ( @@ -133,7 +134,7 @@ func (ol *OpenCodeLogin) SubmitUserInput(ctx context.Context, input map[string]s return openCodeCompleteStep(existing), nil } - loginID := nextOpenCodeUserLoginID(ol.User) + loginID := bridgeadapter.NextUserLoginID(ol.User, "opencode") login, err := ol.User.NewLogin(ctx, &database.UserLogin{ ID: loginID, diff --git a/bridges/opencode/opencodebridge/backfill_canonical.go b/bridges/opencode/opencodebridge/backfill_canonical.go index 60023793..7832f0aa 100644 --- a/bridges/opencode/opencodebridge/backfill_canonical.go +++ b/bridges/opencode/opencodebridge/backfill_canonical.go @@ -385,36 +385,30 @@ func backfillCost(msg opencode.MessageWithParts) float64 { } func backfillPromptTokens(msg opencode.MessageWithParts) int64 { - if msg.Info.Tokens != nil { - return int64(msg.Info.Tokens.Input) - } - for _, part := range msg.Parts { - if part.Type == "step-finish" && part.Tokens != nil { - return int64(part.Tokens.Input) - } - } - return 0 + return backfillTokenValue(msg, func(tokens opencode.TokenUsage) int64 { + return int64(tokens.Input) + }) } func backfillCompletionTokens(msg opencode.MessageWithParts) int64 { - if msg.Info.Tokens != nil { - return int64(msg.Info.Tokens.Output) - } - for _, part := range msg.Parts { - if part.Type == "step-finish" && part.Tokens != nil { - return int64(part.Tokens.Output) - } - } - return 0 + return backfillTokenValue(msg, func(tokens opencode.TokenUsage) int64 { + return int64(tokens.Output) + }) } func backfillReasoningTokens(msg opencode.MessageWithParts) int64 { + return backfillTokenValue(msg, func(tokens opencode.TokenUsage) int64 { + return int64(tokens.Reasoning) + }) +} + +func backfillTokenValue(msg opencode.MessageWithParts, pick func(opencode.TokenUsage) int64) int64 { if msg.Info.Tokens != nil { - return int64(msg.Info.Tokens.Reasoning) + return pick(*msg.Info.Tokens) } for _, part := range msg.Parts { if part.Type == "step-finish" && part.Tokens != nil { - return int64(part.Tokens.Reasoning) + return pick(*part.Tokens) } } return 0 diff --git a/bridges/opencode/opencodebridge/mime.go b/bridges/opencode/opencodebridge/mime.go index ab5615fa..9526de18 100644 --- a/bridges/opencode/opencodebridge/mime.go +++ b/bridges/opencode/opencodebridge/mime.go @@ -1,21 +1,11 @@ package opencodebridge import ( - "strings" - "maunium.net/go/mautrix/event" + + "github.com/beeper/ai-bridge/pkg/shared/media" ) func messageTypeForMIME(mimeType string) event.MessageType { - mimeType = strings.ToLower(strings.TrimSpace(mimeType)) - switch { - case strings.HasPrefix(mimeType, "image/"): - return event.MsgImage - case strings.HasPrefix(mimeType, "audio/"): - return event.MsgAudio - case strings.HasPrefix(mimeType, "video/"): - return event.MsgVideo - default: - return event.MsgFile - } + return media.MessageTypeForMIME(mimeType) } diff --git a/bridges/opencode/opencodebridge/opencode_manager.go b/bridges/opencode/opencodebridge/opencode_manager.go index a1ae686d..cdd54d45 100644 --- a/bridges/opencode/opencodebridge/opencode_manager.go +++ b/bridges/opencode/opencodebridge/opencode_manager.go @@ -382,31 +382,33 @@ func (m *OpenCodeManager) AbortSession(ctx context.Context, instanceID, sessionI } func (m *OpenCodeManager) CreateSession(ctx context.Context, instanceID, title, directory string) (*opencode.Session, error) { - inst, err := m.requireConnectedInstance(instanceID) - if err != nil { - return nil, err - } - session, err := inst.client.CreateSession(ctx, title, directory) - if err != nil { - if opencode.IsAuthError(err) { - m.setConnected(inst, false) - } - return nil, fmt.Errorf("create session: %w", err) - } - return session, nil + return m.runSessionMutation(ctx, instanceID, "create session", func(inst *openCodeInstance) (*opencode.Session, error) { + return inst.client.CreateSession(ctx, title, directory) + }) } func (m *OpenCodeManager) UpdateSessionTitle(ctx context.Context, instanceID, sessionID, title string) (*opencode.Session, error) { + return m.runSessionMutation(ctx, instanceID, "update session title", func(inst *openCodeInstance) (*opencode.Session, error) { + return inst.client.UpdateSessionTitle(ctx, sessionID, title) + }) +} + +func (m *OpenCodeManager) runSessionMutation( + ctx context.Context, + instanceID string, + action string, + run func(*openCodeInstance) (*opencode.Session, error), +) (*opencode.Session, error) { inst, err := m.requireConnectedInstance(instanceID) if err != nil { return nil, err } - session, err := inst.client.UpdateSessionTitle(ctx, sessionID, title) + session, err := run(inst) if err != nil { if opencode.IsAuthError(err) { m.setConnected(inst, false) } - return nil, fmt.Errorf("update session title: %w", err) + return nil, fmt.Errorf("%s: %w", action, err) } return session, nil } diff --git a/bridges/opencode/opencodebridge/opencode_parts.go b/bridges/opencode/opencodebridge/opencode_parts.go index 5cd91328..598d84e2 100644 --- a/bridges/opencode/opencodebridge/opencode_parts.go +++ b/bridges/opencode/opencodebridge/opencode_parts.go @@ -22,35 +22,31 @@ type openCodePartEvent struct { } func (b *Bridge) emitOpenCodePart(ctx context.Context, portal *bridgev2.Portal, instanceID string, part opencode.Part, fromMe bool) { - if portal == nil || part.ID == "" { - return - } - remote := &simplevent.Message[openCodePartEvent]{ - EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventMessage, - PortalKey: portal.PortalKey, - Sender: b.opencodeSender(instanceID, fromMe), - }, - ID: opencodePartMessageID(part.ID), - Data: openCodePartEvent{InstanceID: instanceID, Part: part}, - ConvertMessageFunc: b.convertOpenCodePartMessage, - } - b.queueRemoteEvent(remote) + b.emitOpenCodePartEvent(portal, instanceID, part, fromMe, bridgev2.RemoteEventMessage) } func (b *Bridge) emitOpenCodePartEdit(ctx context.Context, portal *bridgev2.Portal, instanceID string, part opencode.Part, fromMe bool) { + b.emitOpenCodePartEvent(portal, instanceID, part, fromMe, bridgev2.RemoteEventEdit) +} + +func (b *Bridge) emitOpenCodePartEvent(portal *bridgev2.Portal, instanceID string, part opencode.Part, fromMe bool, eventType bridgev2.RemoteEventType) { if portal == nil || part.ID == "" { return } remote := &simplevent.Message[openCodePartEvent]{ EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventEdit, + Type: eventType, PortalKey: portal.PortalKey, Sender: b.opencodeSender(instanceID, fromMe), }, - TargetMessage: opencodePartMessageID(part.ID), - Data: openCodePartEvent{InstanceID: instanceID, Part: part}, - ConvertEditFunc: b.convertOpenCodePartEdit, + Data: openCodePartEvent{InstanceID: instanceID, Part: part}, + } + if eventType == bridgev2.RemoteEventMessage { + remote.ID = opencodePartMessageID(part.ID) + remote.ConvertMessageFunc = b.convertOpenCodePartMessage + } else { + remote.TargetMessage = opencodePartMessageID(part.ID) + remote.ConvertEditFunc = b.convertOpenCodePartEdit } b.queueRemoteEvent(remote) } diff --git a/bridges/opencode/portal_send.go b/bridges/opencode/portal_send.go index 1e6ea463..d3d3423c 100644 --- a/bridges/opencode/portal_send.go +++ b/bridges/opencode/portal_send.go @@ -8,6 +8,8 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" + + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) // sendViaPortal sends a pre-built message through bridgev2's QueueRemoteEvent pipeline. @@ -21,7 +23,7 @@ func (oc *OpenCodeClient) sendViaPortal( return fmt.Errorf("invalid portal") } sender := oc.SenderForOpenCode(instanceID, false) - msgID := newOpenCodeMessageID() + msgID := bridgeadapter.NewMessageID("opencode") evt := &OpenCodeRemoteMessage{ Portal: portal.PortalKey, ID: msgID, diff --git a/bridges/opencode/remote_events.go b/bridges/opencode/remote_events.go index be42b4c9..31339b23 100644 --- a/bridges/opencode/remote_events.go +++ b/bridges/opencode/remote_events.go @@ -1,8 +1,6 @@ package opencode import ( - "maunium.net/go/mautrix/bridgev2/networkid" - "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) @@ -11,7 +9,3 @@ type OpenCodeRemoteMessage = bridgeadapter.RemoteMessage // OpenCodeRemoteEdit is a type alias for the shared RemoteEdit. type OpenCodeRemoteEdit = bridgeadapter.RemoteEdit - -func newOpenCodeMessageID() networkid.MessageID { - return bridgeadapter.NewMessageID("opencode") -} diff --git a/bridges/opencode/stream_canonical.go b/bridges/opencode/stream_canonical.go index ff32cf7b..498a7945 100644 --- a/bridges/opencode/stream_canonical.go +++ b/bridges/opencode/stream_canonical.go @@ -15,6 +15,7 @@ import ( "github.com/beeper/ai-bridge/pkg/connector/msgconv" "github.com/beeper/ai-bridge/pkg/matrixevents" "github.com/beeper/ai-bridge/pkg/shared/maputil" + "github.com/beeper/ai-bridge/pkg/shared/streamtransport" "github.com/beeper/ai-bridge/pkg/shared/streamui" "github.com/beeper/ai-bridge/pkg/shared/stringutil" ) @@ -222,19 +223,12 @@ func (oc *OpenCodeClient) queueFinalStreamEdit(ctx context.Context, portal *brid TargetMessage: state.networkMessageID, Timestamp: time.Now(), LogKey: "opencode_edit_target", - PreBuilt: &bridgev2.ConvertedEdit{ - ModifiedParts: []*bridgev2.ConvertedEditPart{{ - Type: event.EventMessage, - Content: &event.MessageEventContent{ - MsgType: event.MsgText, - Body: rendered.Body, - Format: rendered.Format, - FormattedBody: rendered.FormattedBody, - }, - Extra: map[string]any{"m.mentions": map[string]any{}}, - TopLevelExtra: topLevelExtra, - }}, - }, + PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ + MsgType: event.MsgText, + Body: rendered.Body, + Format: rendered.Format, + FormattedBody: rendered.FormattedBody, + }, topLevelExtra), }) } diff --git a/pkg/agents/tools/beeper_docs.go b/pkg/agents/tools/beeper_docs.go index 50573477..2051ae4a 100644 --- a/pkg/agents/tools/beeper_docs.go +++ b/pkg/agents/tools/beeper_docs.go @@ -1,27 +1,11 @@ package tools -import ( - "context" - - "github.com/modelcontextprotocol/go-sdk/mcp" - - "github.com/beeper/ai-bridge/pkg/shared/toolspec" -) +import "github.com/beeper/ai-bridge/pkg/shared/toolspec" // BeeperDocsTool is the Beeper help documentation search tool. -var BeeperDocsTool = &Tool{ - Tool: mcp.Tool{ - Name: toolspec.BeeperDocsName, - Description: toolspec.BeeperDocsDescription, - Annotations: &mcp.ToolAnnotations{Title: "Beeper Docs"}, - InputSchema: toolspec.BeeperDocsSchema(), - }, - Type: ToolTypeBuiltin, - Group: GroupWeb, - Execute: executeBeeperDocsPlaceholder, -} - -// executeBeeperDocsPlaceholder is a no-op; real execution happens in the connector. -func executeBeeperDocsPlaceholder(_ context.Context, _ map[string]any) (*Result, error) { - return ErrorResult("beeper_docs", "beeper_docs is only available through the connector"), nil -} +var BeeperDocsTool = newConnectorOnlyTool( + toolspec.BeeperDocsName, + toolspec.BeeperDocsDescription, + "Beeper Docs", + toolspec.BeeperDocsSchema(), +) diff --git a/pkg/agents/tools/beeper_send_feedback.go b/pkg/agents/tools/beeper_send_feedback.go index a8ee7138..3b6a4d0d 100644 --- a/pkg/agents/tools/beeper_send_feedback.go +++ b/pkg/agents/tools/beeper_send_feedback.go @@ -1,27 +1,11 @@ package tools -import ( - "context" - - "github.com/modelcontextprotocol/go-sdk/mcp" - - "github.com/beeper/ai-bridge/pkg/shared/toolspec" -) +import "github.com/beeper/ai-bridge/pkg/shared/toolspec" // BeeperSendFeedbackTool is the Beeper feedback submission tool. -var BeeperSendFeedbackTool = &Tool{ - Tool: mcp.Tool{ - Name: toolspec.BeeperSendFeedbackName, - Description: toolspec.BeeperSendFeedbackDescription, - Annotations: &mcp.ToolAnnotations{Title: "Beeper Send Feedback"}, - InputSchema: toolspec.BeeperSendFeedbackSchema(), - }, - Type: ToolTypeBuiltin, - Group: GroupWeb, - Execute: executeBeeperSendFeedbackPlaceholder, -} - -// executeBeeperSendFeedbackPlaceholder is a no-op; real execution happens in the connector. -func executeBeeperSendFeedbackPlaceholder(_ context.Context, _ map[string]any) (*Result, error) { - return ErrorResult("beeper_send_feedback", "beeper_send_feedback is only available through the connector"), nil -} +var BeeperSendFeedbackTool = newConnectorOnlyTool( + toolspec.BeeperSendFeedbackName, + toolspec.BeeperSendFeedbackDescription, + "Beeper Send Feedback", + toolspec.BeeperSendFeedbackSchema(), +) diff --git a/pkg/agents/tools/connector_only.go b/pkg/agents/tools/connector_only.go new file mode 100644 index 00000000..3694832d --- /dev/null +++ b/pkg/agents/tools/connector_only.go @@ -0,0 +1,27 @@ +package tools + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func newConnectorOnlyTool(name, description, title string, schema map[string]any) *Tool { + return &Tool{ + Tool: mcp.Tool{ + Name: name, + Description: description, + Annotations: &mcp.ToolAnnotations{Title: title}, + InputSchema: schema, + }, + Type: ToolTypeBuiltin, + Group: GroupWeb, + Execute: connectorOnlyPlaceholder(name), + } +} + +func connectorOnlyPlaceholder(toolName string) func(context.Context, map[string]any) (*Result, error) { + return func(_ context.Context, _ map[string]any) (*Result, error) { + return ErrorResult(toolName, toolName+" is only available through the connector"), nil + } +} diff --git a/pkg/connector/bridge_db.go b/pkg/connector/bridge_db.go index 15638332..71e118c6 100644 --- a/pkg/connector/bridge_db.go +++ b/pkg/connector/bridge_db.go @@ -53,3 +53,14 @@ func bridgeDBFromLogin(login *bridgev2.UserLogin) *dbutil.Database { } return nil } + +func loginDBContext(client *AIClient) (*dbutil.Database, string, string) { + if client == nil || client.UserLogin == nil || client.UserLogin.Bridge == nil { + return nil, "", "" + } + db := client.bridgeDB() + if db == nil || client.UserLogin.Bridge.DB == nil { + return nil, "", "" + } + return db, string(client.UserLogin.Bridge.DB.BridgeID), string(client.UserLogin.ID) +} diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index 51e30415..38761984 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -541,17 +541,36 @@ func (oc *AIClient) resolveModelIdentifier(ctx context.Context, modelID string, info := oc.findModelInfo(modelID) return &bridgev2.ResolveIdentifierResponse{ - UserID: userID, - UserInfo: &bridgev2.UserInfo{ - Name: ptr.Ptr(modelContactName(modelID, info)), - IsBot: ptr.Ptr(false), - Identifiers: modelContactIdentifiers(modelID, info), - }, - Ghost: ghost, - Chat: chatResp, + UserID: userID, + UserInfo: modelMemberUserInfo(modelID, info), + Ghost: ghost, + Chat: chatResp, }, nil } +func modelMemberUserInfo(modelID string, info *ModelInfo) *bridgev2.UserInfo { + return &bridgev2.UserInfo{ + Name: ptr.Ptr(modelContactName(modelID, info)), + IsBot: ptr.Ptr(false), + Identifiers: modelContactIdentifiers(modelID, info), + } +} + +func modelJoinMember(loginID networkid.UserLoginID, modelID, modelName string, info *ModelInfo) bridgev2.ChatMember { + return bridgev2.ChatMember{ + EventSender: bridgev2.EventSender{ + Sender: modelUserID(modelID), + SenderLogin: loginID, + }, + Membership: event.MembershipJoin, + UserInfo: modelMemberUserInfo(modelID, info), + MemberEventExtra: map[string]any{ + "displayname": modelName, + "com.beeper.ai.model_id": modelID, + }, + } +} + // createAgentChat creates a new chat room for an agent func (oc *AIClient) createAgentChat(ctx context.Context, agent *agents.AgentDefinition) (*bridgev2.CreateChatResponse, error) { return oc.createAgentChatWithModel(ctx, agent, "", false) @@ -1197,24 +1216,7 @@ func (oc *AIClient) composeChatInfo(title, modelID string) *bridgev2.ChatInfo { }, Membership: event.MembershipJoin, }, - modelUserID(modelID): { - EventSender: bridgev2.EventSender{ - Sender: modelUserID(modelID), - SenderLogin: oc.UserLogin.ID, - }, - Membership: event.MembershipJoin, - UserInfo: &bridgev2.UserInfo{ - Name: ptr.Ptr(modelName), - IsBot: ptr.Ptr(false), - Identifiers: modelContactIdentifiers(modelID, modelInfo), - }, - // Set displayname directly in membership event content - // This works because MemberEventContent.Displayname has omitempty - MemberEventExtra: map[string]any{ - "displayname": modelName, - "com.beeper.ai.model_id": modelID, - }, - }, + modelUserID(modelID): modelJoinMember(oc.UserLogin.ID, modelID, modelName, modelInfo), } return &bridgev2.ChatInfo{ Name: ptr.Ptr(title), @@ -1410,22 +1412,7 @@ func (oc *AIClient) handleModelSwitch(ctx context.Context, portal *bridgev2.Port Membership: event.MembershipLeave, PrevMembership: event.MembershipJoin, }, - modelUserID(newModel): { - EventSender: bridgev2.EventSender{ - Sender: modelUserID(newModel), - SenderLogin: oc.UserLogin.ID, - }, - Membership: event.MembershipJoin, - UserInfo: &bridgev2.UserInfo{ - Name: ptr.Ptr(newModelName), - IsBot: ptr.Ptr(false), - Identifiers: modelContactIdentifiers(newModel, newInfo), - }, - MemberEventExtra: map[string]any{ - "displayname": newModelName, - "com.beeper.ai.model_id": newModel, - }, - }, + modelUserID(newModel): modelJoinMember(oc.UserLogin.ID, newModel, newModelName, newInfo), }, } diff --git a/pkg/connector/desktop_api_sessions.go b/pkg/connector/desktop_api_sessions.go index 2db0bb0b..d2621e98 100644 --- a/pkg/connector/desktop_api_sessions.go +++ b/pkg/connector/desktop_api_sessions.go @@ -564,23 +564,11 @@ func (oc *AIClient) resolveDesktopSessionByLabelWithOptions(ctx context.Context, return exactMatches[0].ID, key, nil } if len(exactMatches) > 1 { - titles := make([]string, 0, len(exactMatches)) - for i, chat := range exactMatches { - if i >= 5 { - break - } - titles = append(titles, describeDesktopChatForLabel(chat, accounts[strings.TrimSpace(chat.AccountID)])) - } + titles := topDesktopChatLabels(exactMatches, accounts) return "", "", fmt.Errorf("%w: label '%s' matched multiple chats (%s)", errDesktopLabelAmbiguous, trimmed, strings.Join(titles, ", ")) } if len(partialMatches) > 0 { - suggestions := make([]string, 0, len(partialMatches)) - for i, chat := range partialMatches { - if i >= 5 { - break - } - suggestions = append(suggestions, describeDesktopChatForLabel(chat, accounts[strings.TrimSpace(chat.AccountID)])) - } + suggestions := topDesktopChatLabels(partialMatches, accounts) return "", "", fmt.Errorf("%w: no exact session found for label '%s'. Top matches: %s. Use the sessionKey from sessions_list for deterministic targeting", errDesktopLabelNotFound, trimmed, strings.Join(suggestions, ", ")) } acctID := strings.TrimSpace(opts.AccountID) @@ -598,6 +586,17 @@ func (oc *AIClient) resolveDesktopSessionByLabelWithOptions(ctx context.Context, return "", "", fmt.Errorf("%w: no session found for label '%s'. Use the sessionKey from sessions_list", errDesktopLabelNotFound, trimmed) } +func topDesktopChatLabels(chats []beeperdesktopapi.Chat, accounts map[string]beeperdesktopapi.Account) []string { + labels := make([]string, 0, min(len(chats), 5)) + for i, chat := range chats { + if i >= 5 { + break + } + labels = append(labels, describeDesktopChatForLabel(chat, accounts[strings.TrimSpace(chat.AccountID)])) + } + return labels +} + func (oc *AIClient) resolveDesktopSessionByLabel(ctx context.Context, instance, label string) (string, string, error) { return oc.resolveDesktopSessionByLabelWithOptions(ctx, instance, label, desktopLabelResolveOptions{}) } diff --git a/pkg/connector/handlematrix.go b/pkg/connector/handlematrix.go index 4e590e3b..16554033 100644 --- a/pkg/connector/handlematrix.go +++ b/pkg/connector/handlematrix.go @@ -785,38 +785,48 @@ func (oc *AIClient) handleMediaMessage( // If model lacks vision but agent supports image understanding, analyze image first. if msgType == event.MsgImage { visionModel, visionFallback := oc.resolveVisionModelForImage(ctx, meta) - if visionFallback && visionModel != "" { - analysisPrompt := buildImageUnderstandingPrompt(caption, hasUserCaption) - description, err := oc.analyzeImageWithModel(ctx, visionModel, string(mediaURL), mimeType, encryptedFile, analysisPrompt) - if err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Image understanding failed") - return nil, messageSendStatusError(err, "Couldn't analyze the image. Try again, or switch to a vision-capable model with !ai model.", "") - } - - combined := buildImageUnderstandingMessage(caption, hasUserCaption, description) - if combined == "" { - return nil, messageSendStatusError(errors.New("image understanding produced empty result"), "Couldn't analyze the image. Try again, or switch to a vision-capable model with !ai model.", "") - } - return dispatchTextOnly(combined) + if resp, err := oc.dispatchMediaUnderstandingFallback( + ctx, + visionModel, + visionFallback, + string(mediaURL), + mimeType, + encryptedFile, + caption, + hasUserCaption, + buildImageUnderstandingPrompt, + oc.analyzeImageWithModel, + buildImageUnderstandingMessage, + "Image understanding failed", + "image understanding produced empty result", + "Couldn't analyze the image. Try again, or switch to a vision-capable model with !ai model.", + dispatchTextOnly, + ); resp != nil || err != nil { + return resp, err } } // If model lacks audio but agent supports audio understanding, analyze audio first. if msgType == event.MsgAudio { audioModel, audioFallback := oc.resolveAudioModelForInput(ctx, meta) - if audioFallback && audioModel != "" { - analysisPrompt := buildAudioUnderstandingPrompt(caption, hasUserCaption) - transcript, err := oc.analyzeAudioWithModel(ctx, audioModel, string(mediaURL), mimeType, encryptedFile, analysisPrompt) - if err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Audio understanding failed") - return nil, messageSendStatusError(err, "Couldn't analyze the audio. Try again, or switch to an audio-capable model with !ai model.", "") - } - - combined := buildAudioUnderstandingMessage(caption, hasUserCaption, transcript) - if combined == "" { - return nil, messageSendStatusError(errors.New("audio understanding produced empty result"), "Couldn't analyze the audio. Try again, or switch to an audio-capable model with !ai model.", "") - } - return dispatchTextOnly(combined) + if resp, err := oc.dispatchMediaUnderstandingFallback( + ctx, + audioModel, + audioFallback, + string(mediaURL), + mimeType, + encryptedFile, + caption, + hasUserCaption, + buildAudioUnderstandingPrompt, + oc.analyzeAudioWithModel, + buildAudioUnderstandingMessage, + "Audio understanding failed", + "audio understanding produced empty result", + "Couldn't analyze the audio. Try again, or switch to an audio-capable model with !ai model.", + dispatchTextOnly, + ); resp != nil || err != nil { + return resp, err } } @@ -886,6 +896,40 @@ func (oc *AIClient) handleMediaMessage( }, nil } +func (oc *AIClient) dispatchMediaUnderstandingFallback( + ctx context.Context, + model string, + fallback bool, + mediaURL string, + mimeType string, + encryptedFile *event.EncryptedFileInfo, + caption string, + hasUserCaption bool, + buildPrompt func(string, bool) string, + analyze func(context.Context, string, string, string, *event.EncryptedFileInfo, string) (string, error), + buildMessage func(string, bool, string) string, + failureLog string, + emptyResult string, + userError string, + dispatchTextOnly func(string) (*bridgev2.MatrixMessageResponse, error), +) (*bridgev2.MatrixMessageResponse, error) { + if !fallback || model == "" { + return nil, nil + } + analysisPrompt := buildPrompt(caption, hasUserCaption) + description, err := analyze(ctx, model, mediaURL, mimeType, encryptedFile, analysisPrompt) + if err != nil { + oc.loggerForContext(ctx).Warn().Err(err).Msg(failureLog) + return nil, messageSendStatusError(err, userError, "") + } + + combined := buildMessage(caption, hasUserCaption, description) + if combined == "" { + return nil, messageSendStatusError(errors.New(emptyResult), userError, "") + } + return dispatchTextOnly(combined) +} + func (oc *AIClient) handleTextFileMessage( ctx context.Context, msg *bridgev2.MatrixMessage, diff --git a/pkg/connector/heartbeat_execute.go b/pkg/connector/heartbeat_execute.go index d60d3a97..14866838 100644 --- a/pkg/connector/heartbeat_execute.go +++ b/pkg/connector/heartbeat_execute.go @@ -316,19 +316,9 @@ func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *Hea } if session == "" || strings.EqualFold(session, "main") || strings.EqualFold(session, "global") || (mainKey != "" && strings.EqualFold(session, mainKey)) { hbSession := oc.resolveHeartbeatSession(agentID, heartbeat) - if hbSession.Entry != nil { - lastChannel := strings.TrimSpace(hbSession.Entry.LastChannel) - lastTo := strings.TrimSpace(hbSession.Entry.LastTo) - if lastTo != "" && strings.HasPrefix(lastTo, "!") && (lastChannel == "" || strings.EqualFold(lastChannel, "matrix")) { - if portal := oc.portalByRoomID(context.Background(), id.RoomID(lastTo)); portal != nil { - if meta := portalMeta(portal); meta != nil && normalizeAgentID(meta.AgentID) != normalizeAgentID(agentID) { - goto mainFallback - } - return portal, portal.MXID.String(), nil - } - } + if portal := oc.heartbeatSessionPortalCandidate(agentID, hbSession); portal != nil { + return portal, portal.MXID.String(), nil } - mainFallback: if portal := oc.lastActivePortal(agentID); portal != nil { return portal, portal.MXID.String(), nil } @@ -343,19 +333,9 @@ func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *Hea } } hbSession := oc.resolveHeartbeatSession(agentID, heartbeat) - if hbSession.Entry != nil { - lastChannel := strings.TrimSpace(hbSession.Entry.LastChannel) - lastTo := strings.TrimSpace(hbSession.Entry.LastTo) - if lastTo != "" && strings.HasPrefix(lastTo, "!") && (lastChannel == "" || strings.EqualFold(lastChannel, "matrix")) { - if portal := oc.portalByRoomID(context.Background(), id.RoomID(lastTo)); portal != nil { - if meta := portalMeta(portal); meta != nil && normalizeAgentID(meta.AgentID) != normalizeAgentID(agentID) { - goto finalFallback - } - return portal, portal.MXID.String(), nil - } - } + if portal := oc.heartbeatSessionPortalCandidate(agentID, hbSession); portal != nil { + return portal, portal.MXID.String(), nil } -finalFallback: if portal := oc.lastActivePortal(agentID); portal != nil { return portal, portal.MXID.String(), nil } @@ -365,6 +345,25 @@ finalFallback: return nil, "", errors.New("no session") } +func (oc *AIClient) heartbeatSessionPortalCandidate(agentID string, session heartbeatSessionResolution) *bridgev2.Portal { + if session.Entry == nil { + return nil + } + lastChannel := strings.TrimSpace(session.Entry.LastChannel) + lastTo := strings.TrimSpace(session.Entry.LastTo) + if lastTo == "" || !strings.HasPrefix(lastTo, "!") || (lastChannel != "" && !strings.EqualFold(lastChannel, "matrix")) { + return nil + } + portal := oc.portalByRoomID(context.Background(), id.RoomID(lastTo)) + if portal == nil { + return nil + } + if meta := portalMeta(portal); meta != nil && normalizeAgentID(meta.AgentID) != normalizeAgentID(agentID) { + return nil + } + return portal +} + func (oc *AIClient) shouldRunHeartbeatForFile(agentID string, reason string) bool { db := oc.bridgeDB() if db == nil || oc.UserLogin == nil || oc.UserLogin.Bridge == nil || oc.UserLogin.Bridge.DB == nil { diff --git a/pkg/connector/heartbeat_visibility.go b/pkg/connector/heartbeat_visibility.go index 1c97715d..9e5274b7 100644 --- a/pkg/connector/heartbeat_visibility.go +++ b/pkg/connector/heartbeat_visibility.go @@ -29,29 +29,25 @@ func resolveHeartbeatVisibility(cfg *Config, channel string) ResolvedHeartbeatVi UseIndicator: defaultHeartbeatVisibility.UseIndicator, } - if defaults != nil && defaults.Heartbeat != nil { - if defaults.Heartbeat.ShowOk != nil { - result.ShowOk = *defaults.Heartbeat.ShowOk - } - if defaults.Heartbeat.ShowAlerts != nil { - result.ShowAlerts = *defaults.Heartbeat.ShowAlerts - } - if defaults.Heartbeat.UseIndicator != nil { - result.UseIndicator = *defaults.Heartbeat.UseIndicator - } - } - - if perChannel != nil && perChannel.Heartbeat != nil { - if perChannel.Heartbeat.ShowOk != nil { - result.ShowOk = *perChannel.Heartbeat.ShowOk - } - if perChannel.Heartbeat.ShowAlerts != nil { - result.ShowAlerts = *perChannel.Heartbeat.ShowAlerts - } - if perChannel.Heartbeat.UseIndicator != nil { - result.UseIndicator = *perChannel.Heartbeat.UseIndicator - } + applyHeartbeatVisibility(&result, defaults.Heartbeat) + if perChannel != nil { + applyHeartbeatVisibility(&result, perChannel.Heartbeat) } return result } + +func applyHeartbeatVisibility(dst *ResolvedHeartbeatVisibility, cfg *ChannelHeartbeatVisibilityConfig) { + if dst == nil || cfg == nil { + return + } + if cfg.ShowOk != nil { + dst.ShowOk = *cfg.ShowOk + } + if cfg.ShowAlerts != nil { + dst.ShowAlerts = *cfg.ShowAlerts + } + if cfg.UseIndicator != nil { + dst.UseIndicator = *cfg.UseIndicator + } +} diff --git a/pkg/connector/media_send.go b/pkg/connector/media_send.go index 9dc4517e..25ecc5db 100644 --- a/pkg/connector/media_send.go +++ b/pkg/connector/media_send.go @@ -71,20 +71,7 @@ func (oc *AIClient) sendGeneratedMedia( } } - if msgType == event.MsgAudio { - if durationMs, waveform := analyzeAudio(data, mimeType); durationMs > 0 || len(waveform) > 0 { - if durationMs > 0 { - info["duration"] = durationMs - } - rawContent["org.matrix.msc1767.audio"] = map[string]any{ - "duration": durationMs, - "waveform": waveform, - } - } - if asVoice { - rawContent["org.matrix.msc3245.voice"] = map[string]any{} - } - } + populateAudioMessageContent(rawContent, info, data, mimeType, asVoice, msgType) if turnID != "" && metadataKey != "" { rawContent[metadataKey] = map[string]any{ @@ -114,3 +101,21 @@ func extensionForMIME(mimeType, defaultExt string, overrides map[string]string) } return defaultExt } + +func populateAudioMessageContent(rawContent map[string]any, info map[string]any, data []byte, mimeType string, asVoice bool, msgType event.MessageType) { + if msgType != event.MsgAudio { + return + } + if durationMs, waveform := analyzeAudio(data, mimeType); durationMs > 0 || len(waveform) > 0 { + if durationMs > 0 { + info["duration"] = durationMs + } + rawContent["org.matrix.msc1767.audio"] = map[string]any{ + "duration": durationMs, + "waveform": waveform, + } + } + if asVoice { + rawContent["org.matrix.msc3245.voice"] = map[string]any{} + } +} diff --git a/pkg/connector/messages.go b/pkg/connector/messages.go index f3bb9c53..c71ca0f6 100644 --- a/pkg/connector/messages.go +++ b/pkg/connector/messages.go @@ -429,26 +429,28 @@ func extractChatSystemText(content openai.ChatCompletionSystemMessageParamConten if content.OfString.Value != "" { return content.OfString.Value } - var parts []string - for _, part := range content.OfArrayOfContentParts { - if strings.TrimSpace(part.Text) != "" { - parts = append(parts, part.Text) - } - } - return strings.Join(parts, "\n") + return joinChatText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartTextParam) string { + return part.Text + }) } func extractChatDeveloperText(content openai.ChatCompletionDeveloperMessageParamContentUnion) string { if content.OfString.Value != "" { return content.OfString.Value } - var parts []string - for _, part := range content.OfArrayOfContentParts { - if strings.TrimSpace(part.Text) != "" { - parts = append(parts, part.Text) + return joinChatText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartTextParam) string { + return part.Text + }) +} + +func joinChatText[T any](parts []T, extract func(T) string) string { + var values []string + for _, part := range parts { + if text := strings.TrimSpace(extract(part)); text != "" { + values = append(values, text) } } - return strings.Join(parts, "\n") + return strings.Join(values, "\n") } func inferPromptMimeTypeFromDataURL(value string) string { diff --git a/pkg/connector/portal_send.go b/pkg/connector/portal_send.go index 5829d40e..f10de140 100644 --- a/pkg/connector/portal_send.go +++ b/pkg/connector/portal_send.go @@ -43,7 +43,7 @@ func (oc *AIClient) sendViaPortal( return "", "", fmt.Errorf("invalid portal") } if msgID == "" { - msgID = newMessageID() + msgID = bridgeadapter.NewMessageID("ai") } ensureConvertedMessageParts(converted) sender := oc.senderForPortal(ctx, portal) diff --git a/pkg/connector/provider_openai.go b/pkg/connector/provider_openai.go index 3a5f824d..ea622ca0 100644 --- a/pkg/connector/provider_openai.go +++ b/pkg/connector/provider_openai.go @@ -59,14 +59,7 @@ func NewOpenAIProviderWithUserID(apiKey, baseURL, userID string, log zerolog.Log opts = append(opts, option.WithBaseURL(baseURL)) } - if userID != "" { - opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { - q := req.URL.Query() - q.Set("user_id", userID) - req.URL.RawQuery = q.Encode() - return next(req) - })) - } + opts = appendUserIDOption(opts, userID) opts = append(opts, option.WithMiddleware(makeRequestTraceMiddleware(log))) client := openai.NewClient(opts...) @@ -82,6 +75,18 @@ func newOutboundRequestID() string { return "abr_" + random.String(12) } +func appendUserIDOption(opts []option.RequestOption, userID string) []option.RequestOption { + if userID == "" { + return opts + } + return append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + q := req.URL.Query() + q.Set("user_id", userID) + req.URL.RawQuery = q.Encode() + return next(req) + })) +} + func makeRequestTraceMiddleware(log zerolog.Logger) option.Middleware { traceLog := log.With().Str("component", "openai_http").Logger() return func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { @@ -164,15 +169,7 @@ func NewOpenAIProviderWithPDFPlugin(apiKey, baseURL, userID, pdfEngine string, h opts = append(opts, option.WithBaseURL(baseURL)) } - // Add user_id query parameter if provided - if userID != "" { - opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { - q := req.URL.Query() - q.Set("user_id", userID) - req.URL.RawQuery = q.Encode() - return next(req) - })) - } + opts = appendUserIDOption(opts, userID) opts = httputil.AppendHeaderOptions(opts, headers) diff --git a/pkg/connector/remote_events.go b/pkg/connector/remote_events.go index cfbf7ba8..ff57e808 100644 --- a/pkg/connector/remote_events.go +++ b/pkg/connector/remote_events.go @@ -184,7 +184,3 @@ func NewAITextMessage( } } -// newMessageID generates a unique message ID for AI remote events. -func newMessageID() networkid.MessageID { - return bridgeadapter.NewMessageID("ai") -} diff --git a/pkg/connector/response_finalization.go b/pkg/connector/response_finalization.go index f9c9a963..b32486d2 100644 --- a/pkg/connector/response_finalization.go +++ b/pkg/connector/response_finalization.go @@ -3,7 +3,6 @@ package connector import ( "context" "encoding/json" - "fmt" "strings" "time" @@ -81,7 +80,7 @@ func (oc *AIClient) sendContinuationMessage(ctx context.Context, portal *bridgev } msg := &bridgeadapter.RemoteMessage{ Portal: portal.PortalKey, - ID: newMessageID(), + ID: bridgeadapter.NewMessageID("ai"), Sender: bridgev2.EventSender{Sender: senderID, SenderLogin: oc.UserLogin.ID}, Timestamp: time.Now(), LogKey: "ai_msg_id", @@ -139,7 +138,7 @@ func (oc *AIClient) sendInitialStreamMessage(ctx context.Context, portal *bridge eventRaw["m.relates_to"] = relatesTo } - msgID := newMessageID() + msgID := bridgeadapter.NewMessageID("ai") converted := &bridgev2.ConvertedMessage{ Parts: []*bridgev2.ConvertedMessagePart{{ ID: networkid.PartID("0"), @@ -539,28 +538,7 @@ func buildSourceParts(cits []citations.SourceCitation, documents []citations.Sou seen := make(map[string]struct{}, len(cits)+len(documents)+len(previews)) appendURL := func(url, title string, providerMetadata map[string]any) { - url = strings.TrimSpace(url) - if url == "" { - return - } - seenKey := "url:" + url - if _, ok := seen[seenKey]; ok { - return - } - seen[seenKey] = struct{}{} - - part := map[string]any{ - "type": "source-url", - "sourceId": fmt.Sprintf("source-%d", len(parts)+1), - "url": url, - } - if title = strings.TrimSpace(title); title != "" { - part["title"] = title - } - if len(providerMetadata) > 0 { - part["providerMetadata"] = providerMetadata - } - parts = append(parts, part) + citations.AppendSourceURLPart(&parts, seen, url, title, providerMetadata) } for _, citation := range cits { @@ -586,31 +564,7 @@ func buildSourceParts(cits []citations.SourceCitation, documents []citations.Sou } for _, doc := range documents { - key := strings.TrimSpace(doc.ID) - if key == "" { - key = strings.TrimSpace(doc.Filename) - } - if key == "" { - key = strings.TrimSpace(doc.Title) - } - if key == "" { - continue - } - seenKey := "doc:" + key - if _, ok := seen[seenKey]; ok { - continue - } - seen[seenKey] = struct{}{} - part := map[string]any{ - "type": "source-document", - "sourceId": fmt.Sprintf("source-%d", len(parts)+1), - "mediaType": doc.MediaType, - "title": doc.Title, - } - if filename := strings.TrimSpace(doc.Filename); filename != "" { - part["filename"] = filename - } - parts = append(parts, part) + citations.AppendSourceDocumentPart(&parts, seen, doc) } for _, preview := range previews { diff --git a/pkg/connector/scheduler_db.go b/pkg/connector/scheduler_db.go index 28e81a11..97a18405 100644 --- a/pkg/connector/scheduler_db.go +++ b/pkg/connector/scheduler_db.go @@ -3,6 +3,7 @@ package connector import ( "context" "database/sql" + "fmt" "strings" "go.mau.fi/util/dbutil" @@ -383,95 +384,19 @@ func flattenHeartbeatActiveHours(cfg *HeartbeatActiveHoursConfig) (string, strin } func loadCronRunKeys(ctx context.Context, scope *schedulerDBScope, jobID string) ([]string, error) { - rows, err := scope.db.Query(ctx, ` - SELECT run_key - FROM ai_cron_job_run_keys - WHERE bridge_id=$1 AND login_id=$2 AND job_id=$3 - ORDER BY run_index - `, scope.bridgeID, scope.loginID, jobID) - if err != nil { - return nil, err - } - defer rows.Close() - - var keys []string - for rows.Next() { - var key string - if err := rows.Scan(&key); err != nil { - return nil, err - } - keys = append(keys, key) - } - return keys, rows.Err() + return loadIndexedRunKeys(ctx, scope, "ai_cron_job_run_keys", "job_id", jobID) } func replaceCronRunKeys(ctx context.Context, scope *schedulerDBScope, jobID string, keys []string) error { - if _, err := scope.db.Exec(ctx, ` - DELETE FROM ai_cron_job_run_keys - WHERE bridge_id=$1 AND login_id=$2 AND job_id=$3 - `, scope.bridgeID, scope.loginID, jobID); err != nil { - return err - } - for idx, key := range keys { - key = strings.TrimSpace(key) - if key == "" { - continue - } - if _, err := scope.db.Exec(ctx, ` - INSERT INTO ai_cron_job_run_keys ( - bridge_id, login_id, job_id, run_index, run_key - ) VALUES ($1, $2, $3, $4, $5) - `, scope.bridgeID, scope.loginID, jobID, idx, key); err != nil { - return err - } - } - return nil + return replaceIndexedRunKeys(ctx, scope, "ai_cron_job_run_keys", "job_id", jobID, keys) } func loadHeartbeatRunKeys(ctx context.Context, scope *schedulerDBScope, agentID string) ([]string, error) { - rows, err := scope.db.Query(ctx, ` - SELECT run_key - FROM ai_managed_heartbeat_run_keys - WHERE bridge_id=$1 AND login_id=$2 AND agent_id=$3 - ORDER BY run_index - `, scope.bridgeID, scope.loginID, agentID) - if err != nil { - return nil, err - } - defer rows.Close() - - var keys []string - for rows.Next() { - var key string - if err := rows.Scan(&key); err != nil { - return nil, err - } - keys = append(keys, key) - } - return keys, rows.Err() + return loadIndexedRunKeys(ctx, scope, "ai_managed_heartbeat_run_keys", "agent_id", agentID) } func replaceHeartbeatRunKeys(ctx context.Context, scope *schedulerDBScope, agentID string, keys []string) error { - if _, err := scope.db.Exec(ctx, ` - DELETE FROM ai_managed_heartbeat_run_keys - WHERE bridge_id=$1 AND login_id=$2 AND agent_id=$3 - `, scope.bridgeID, scope.loginID, agentID); err != nil { - return err - } - for idx, key := range keys { - key = strings.TrimSpace(key) - if key == "" { - continue - } - if _, err := scope.db.Exec(ctx, ` - INSERT INTO ai_managed_heartbeat_run_keys ( - bridge_id, login_id, agent_id, run_index, run_key - ) VALUES ($1, $2, $3, $4, $5) - `, scope.bridgeID, scope.loginID, agentID, idx, key); err != nil { - return err - } - } - return nil + return replaceIndexedRunKeys(ctx, scope, "ai_managed_heartbeat_run_keys", "agent_id", agentID, keys) } func nullableInt64Pointer(value sql.NullInt64) *int64 { @@ -527,47 +452,86 @@ func nullableBoolValue(value *bool) any { } func deleteMissingCronRows(ctx context.Context, scope *schedulerDBScope, keep map[string]struct{}) error { - rows, err := scope.db.Query(ctx, `SELECT job_id FROM ai_cron_jobs WHERE bridge_id=$1 AND login_id=$2`, scope.bridgeID, scope.loginID) + return deleteMissingScopedRows(ctx, scope, keep, "ai_cron_jobs", "job_id", "ai_cron_job_run_keys") +} + +func deleteMissingHeartbeatRows(ctx context.Context, scope *schedulerDBScope, keep map[string]struct{}) error { + return deleteMissingScopedRows(ctx, scope, keep, "ai_managed_heartbeats", "agent_id", "ai_managed_heartbeat_run_keys") +} + +func loadIndexedRunKeys(ctx context.Context, scope *schedulerDBScope, table, idColumn, idValue string) ([]string, error) { + rows, err := scope.db.Query(ctx, fmt.Sprintf(` + SELECT run_key + FROM %s + WHERE bridge_id=$1 AND login_id=$2 AND %s=$3 + ORDER BY run_index + `, table, idColumn), scope.bridgeID, scope.loginID, idValue) if err != nil { - return err + return nil, err } defer rows.Close() + + var keys []string for rows.Next() { - var jobID string - if err := rows.Scan(&jobID); err != nil { - return err + var key string + if err := rows.Scan(&key); err != nil { + return nil, err } - if _, ok := keep[strings.TrimSpace(jobID)]; ok { + keys = append(keys, key) + } + return keys, rows.Err() +} + +func replaceIndexedRunKeys(ctx context.Context, scope *schedulerDBScope, table, idColumn, idValue string, keys []string) error { + if _, err := scope.db.Exec(ctx, fmt.Sprintf(` + DELETE FROM %s + WHERE bridge_id=$1 AND login_id=$2 AND %s=$3 + `, table, idColumn), scope.bridgeID, scope.loginID, idValue); err != nil { + return err + } + for idx, key := range keys { + key = strings.TrimSpace(key) + if key == "" { continue } - if _, err := scope.db.Exec(ctx, `DELETE FROM ai_cron_jobs WHERE bridge_id=$1 AND login_id=$2 AND job_id=$3`, scope.bridgeID, scope.loginID, jobID); err != nil { - return err - } - if _, err := scope.db.Exec(ctx, `DELETE FROM ai_cron_job_run_keys WHERE bridge_id=$1 AND login_id=$2 AND job_id=$3`, scope.bridgeID, scope.loginID, jobID); err != nil { + if _, err := scope.db.Exec(ctx, fmt.Sprintf(` + INSERT INTO %s ( + bridge_id, login_id, %s, run_index, run_key + ) VALUES ($1, $2, $3, $4, $5) + `, table, idColumn), scope.bridgeID, scope.loginID, idValue, idx, key); err != nil { return err } } - return rows.Err() + return nil } -func deleteMissingHeartbeatRows(ctx context.Context, scope *schedulerDBScope, keep map[string]struct{}) error { - rows, err := scope.db.Query(ctx, `SELECT agent_id FROM ai_managed_heartbeats WHERE bridge_id=$1 AND login_id=$2`, scope.bridgeID, scope.loginID) +func deleteMissingScopedRows(ctx context.Context, scope *schedulerDBScope, keep map[string]struct{}, entityTable, idColumn, runKeyTable string) error { + rows, err := scope.db.Query(ctx, fmt.Sprintf( + `SELECT %s FROM %s WHERE bridge_id=$1 AND login_id=$2`, + idColumn, entityTable, + ), scope.bridgeID, scope.loginID) if err != nil { return err } defer rows.Close() for rows.Next() { - var agentID string - if err := rows.Scan(&agentID); err != nil { + var idValue string + if err := rows.Scan(&idValue); err != nil { return err } - if _, ok := keep[strings.TrimSpace(agentID)]; ok { + if _, ok := keep[strings.TrimSpace(idValue)]; ok { continue } - if _, err := scope.db.Exec(ctx, `DELETE FROM ai_managed_heartbeats WHERE bridge_id=$1 AND login_id=$2 AND agent_id=$3`, scope.bridgeID, scope.loginID, agentID); err != nil { + if _, err := scope.db.Exec(ctx, fmt.Sprintf( + `DELETE FROM %s WHERE bridge_id=$1 AND login_id=$2 AND %s=$3`, + entityTable, idColumn, + ), scope.bridgeID, scope.loginID, idValue); err != nil { return err } - if _, err := scope.db.Exec(ctx, `DELETE FROM ai_managed_heartbeat_run_keys WHERE bridge_id=$1 AND login_id=$2 AND agent_id=$3`, scope.bridgeID, scope.loginID, agentID); err != nil { + if _, err := scope.db.Exec(ctx, fmt.Sprintf( + `DELETE FROM %s WHERE bridge_id=$1 AND login_id=$2 AND %s=$3`, + runKeyTable, idColumn, + ), scope.bridgeID, scope.loginID, idValue); err != nil { return err } } diff --git a/pkg/connector/session_store.go b/pkg/connector/session_store.go index ae871987..805e2337 100644 --- a/pkg/connector/session_store.go +++ b/pkg/connector/session_store.go @@ -68,17 +68,14 @@ func sessionStoreLock(ref sessionStoreRef, sessionKey string) *sync.Mutex { } func (oc *AIClient) sessionDBScope() *sessionDBScope { - if oc == nil || oc.UserLogin == nil || oc.UserLogin.Bridge == nil { - return nil - } - db := oc.bridgeDB() - if db == nil || oc.UserLogin.Bridge.DB == nil { + db, bridgeID, loginID := loginDBContext(oc) + if db == nil { return nil } return &sessionDBScope{ db: db, - bridgeID: string(oc.UserLogin.Bridge.DB.BridgeID), - loginID: string(oc.UserLogin.ID), + bridgeID: bridgeID, + loginID: loginID, } } diff --git a/pkg/connector/stream_events.go b/pkg/connector/stream_events.go index 47e74556..0f5445f6 100644 --- a/pkg/connector/stream_events.go +++ b/pkg/connector/stream_events.go @@ -2,7 +2,6 @@ package connector import ( "context" - "strings" "github.com/beeper/ai-bridge/pkg/shared/streamtransport" @@ -59,19 +58,16 @@ func (oc *AIClient) emitStreamEvent( state *streamingState, part map[string]any, ) { - if portal == nil || portal.MXID == "" || state == nil || state.suppressSend { + if state == nil { return } - if !state.loggedStreamStart { - state.loggedStreamStart = true - oc.loggerForContext(ctx).Info(). - Stringer("room_id", portal.MXID). - Str("turn_id", strings.TrimSpace(state.turnID)). - Msg("Streaming events") - } - session := oc.ensureStreamSession(ctx, portal, state) - if session == nil { - return - } - session.EmitPart(ctx, part) + streamtransport.EmitStreamEvent(ctx, portal, streamtransport.StreamEventState{ + TurnID: state.turnID, + SuppressSend: state.suppressSend, + LoggedStart: &state.loggedStreamStart, + EnsureSession: func() *streamtransport.StreamSession { + return oc.ensureStreamSession(ctx, portal, state) + }, + Logger: oc.loggerForContext(ctx), + }, part) } diff --git a/pkg/connector/stream_transport.go b/pkg/connector/stream_transport.go index df9dad76..731a042c 100644 --- a/pkg/connector/stream_transport.go +++ b/pkg/connector/stream_transport.go @@ -32,22 +32,15 @@ func (oc *AIClient) sendDebouncedStreamEdit(ctx context.Context, portal *bridgev TargetMessage: state.networkMessageID, Timestamp: time.Now(), LogKey: "ai_edit_target", - PreBuilt: &bridgev2.ConvertedEdit{ - ModifiedParts: []*bridgev2.ConvertedEditPart{{ - Type: event.EventMessage, - Content: &event.MessageEventContent{ - MsgType: event.MsgText, - Body: content.Body, - Format: content.Format, - FormattedBody: content.FormattedBody, - }, - Extra: map[string]any{"m.mentions": map[string]any{}}, - TopLevelExtra: map[string]any{ - "com.beeper.dont_render_edited": true, - "m.mentions": map[string]any{}, - }, - }}, - }, + PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ + MsgType: event.MsgText, + Body: content.Body, + Format: content.Format, + FormattedBody: content.FormattedBody, + }, map[string]any{ + "com.beeper.dont_render_edited": true, + "m.mentions": map[string]any{}, + }), }) return nil } diff --git a/pkg/connector/subagent_spawn.go b/pkg/connector/subagent_spawn.go index d43681eb..91c97efd 100644 --- a/pkg/connector/subagent_spawn.go +++ b/pkg/connector/subagent_spawn.go @@ -55,33 +55,41 @@ func (oc *AIClient) resolveSubagentAllowlist(ctx context.Context, requesterAgent } func resolveSubagentModel(override string, agent *agents.AgentDefinition, defaults *agents.SubagentConfig) string { - if trimmed := strings.TrimSpace(override); trimmed != "" { - return trimmed - } - if agent != nil && agent.Subagents != nil { - if trimmed := strings.TrimSpace(agent.Subagents.Model); trimmed != "" { - return trimmed - } - } - if defaults != nil { - if trimmed := strings.TrimSpace(defaults.Model); trimmed != "" { - return trimmed - } - } - return "" + return firstNonEmptyTrimmed( + override, + subagentStringValue(agent, func(cfg *agents.SubagentConfig) string { return cfg.Model }), + subagentStringValue(defaults, func(cfg *agents.SubagentConfig) string { return cfg.Model }), + ) } func resolveSubagentThinking(override string, agent *agents.AgentDefinition, defaults *agents.SubagentConfig) string { - if trimmed := strings.TrimSpace(override); trimmed != "" { - return trimmed - } - if agent != nil && agent.Subagents != nil { - if trimmed := strings.TrimSpace(agent.Subagents.Thinking); trimmed != "" { - return trimmed + return firstNonEmptyTrimmed( + override, + subagentStringValue(agent, func(cfg *agents.SubagentConfig) string { return cfg.Thinking }), + subagentStringValue(defaults, func(cfg *agents.SubagentConfig) string { return cfg.Thinking }), + ) +} + +func subagentStringValue(source any, extract func(*agents.SubagentConfig) string) string { + switch cfg := source.(type) { + case *agents.AgentDefinition: + if cfg == nil { + return "" + } + return subagentStringValue(cfg.Subagents, extract) + case *agents.SubagentConfig: + if cfg == nil { + return "" } + return extract(cfg) + default: + return "" } - if defaults != nil { - if trimmed := strings.TrimSpace(defaults.Thinking); trimmed != "" { +} + +func firstNonEmptyTrimmed(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { return trimmed } } diff --git a/pkg/connector/system_events_db.go b/pkg/connector/system_events_db.go index ef90f2b9..af3bd1e1 100644 --- a/pkg/connector/system_events_db.go +++ b/pkg/connector/system_events_db.go @@ -21,17 +21,14 @@ type systemEventsDBScope struct { } func systemEventsScope(client *AIClient) *systemEventsDBScope { - if client == nil || client.UserLogin == nil || client.UserLogin.Bridge == nil { - return nil - } - db := client.bridgeDB() - if db == nil || client.UserLogin.Bridge.DB == nil { + db, bridgeID, loginID := loginDBContext(client) + if db == nil { return nil } return &systemEventsDBScope{ db: db, - bridgeID: string(client.UserLogin.Bridge.DB.BridgeID), - loginID: string(client.UserLogin.ID), + bridgeID: bridgeID, + loginID: loginID, } } diff --git a/pkg/connector/tools.go b/pkg/connector/tools.go index 63c77a3a..fc0faa0d 100644 --- a/pkg/connector/tools.go +++ b/pkg/connector/tools.go @@ -169,17 +169,7 @@ func firstNonEmptyString(values ...any) string { } func messageTypeForMIME(mimeType string) event.MessageType { - mimeType = strings.ToLower(strings.TrimSpace(mimeType)) - switch { - case strings.HasPrefix(mimeType, "image/"): - return event.MsgImage - case strings.HasPrefix(mimeType, "audio/"): - return event.MsgAudio - case strings.HasPrefix(mimeType, "video/"): - return event.MsgVideo - default: - return event.MsgFile - } + return media.MessageTypeForMIME(mimeType) } func resolveMessageMedia(ctx context.Context, btc *BridgeToolContext, bufferInput, mediaInput string) ([]byte, string, error) { @@ -564,20 +554,7 @@ func executeMessageSend(ctx context.Context, args map[string]any, btc *BridgeToo } } - if msgType == event.MsgAudio { - if durationMs, waveform := analyzeAudio(data, mimeType); durationMs > 0 || len(waveform) > 0 { - if durationMs > 0 { - info["duration"] = durationMs - } - rawContent["org.matrix.msc1767.audio"] = map[string]any{ - "duration": durationMs, - "waveform": waveform, - } - } - if asVoice { - rawContent["org.matrix.msc3245.voice"] = map[string]any{} - } - } + populateAudioMessageContent(rawContent, info, data, mimeType, asVoice, msgType) converted := &bridgev2.ConvertedMessage{ Parts: []*bridgev2.ConvertedMessagePart{{ diff --git a/pkg/runtime/chat_content.go b/pkg/runtime/chat_content.go index 57623af8..14e836fb 100644 --- a/pkg/runtime/chat_content.go +++ b/pkg/runtime/chat_content.go @@ -37,74 +37,62 @@ func ExtractSystemContent(content openai.ChatCompletionSystemMessageParamContent if content.OfString.Value != "" { return content.OfString.Value } - if len(content.OfArrayOfContentParts) > 0 { - var sb strings.Builder - for _, part := range content.OfArrayOfContentParts { - sb.WriteString(part.Text) - } - return sb.String() - } - return "" + return joinContentText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartTextParam) string { + return part.Text + }) } func ExtractUserContent(content openai.ChatCompletionUserMessageParamContentUnion) string { if content.OfString.Value != "" { return content.OfString.Value } - if len(content.OfArrayOfContentParts) > 0 { - var sb strings.Builder - for _, part := range content.OfArrayOfContentParts { - if part.OfText != nil { - sb.WriteString(part.OfText.Text) - } + return joinContentText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartUnionParam) string { + if part.OfText == nil { + return "" } - return sb.String() - } - return "" + return part.OfText.Text + }) } func ExtractAssistantContent(content openai.ChatCompletionAssistantMessageParamContentUnion) string { if content.OfString.Value != "" { return content.OfString.Value } - if len(content.OfArrayOfContentParts) > 0 { - var sb strings.Builder - for _, part := range content.OfArrayOfContentParts { - if part.OfText != nil { - sb.WriteString(part.OfText.Text) - } + return joinContentText(content.OfArrayOfContentParts, func(part openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) string { + if part.OfText == nil { + return "" } - return sb.String() - } - return "" + return part.OfText.Text + }) } func ExtractDeveloperContent(content openai.ChatCompletionDeveloperMessageParamContentUnion) string { if content.OfString.Value != "" { return content.OfString.Value } - if len(content.OfArrayOfContentParts) > 0 { - var sb strings.Builder - for _, part := range content.OfArrayOfContentParts { - sb.WriteString(part.Text) - } - return sb.String() - } - return "" + return joinContentText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartTextParam) string { + return part.Text + }) } func ExtractToolContent(content openai.ChatCompletionToolMessageParamContentUnion) string { if content.OfString.Value != "" { return content.OfString.Value } - if len(content.OfArrayOfContentParts) > 0 { - var sb strings.Builder - for _, part := range content.OfArrayOfContentParts { - sb.WriteString(part.Text) - } - return sb.String() + return joinContentText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartTextParam) string { + return part.Text + }) +} + +func joinContentText[T any](parts []T, extract func(T) string) string { + if len(parts) == 0 { + return "" + } + var sb strings.Builder + for _, part := range parts { + sb.WriteString(extract(part)) } - return "" + return sb.String() } // EstimateMessageChars approximates character usage for one prompt message. diff --git a/pkg/shared/citations/citations.go b/pkg/shared/citations/citations.go index 39d29e61..fe775af5 100644 --- a/pkg/shared/citations/citations.go +++ b/pkg/shared/citations/citations.go @@ -134,58 +134,71 @@ func BuildSourceParts(citations []SourceCitation, documents []SourceDocument) [] parts := make([]map[string]any, 0, len(citations)+len(documents)) seen := make(map[string]struct{}, len(citations)+len(documents)) for _, c := range citations { - url := strings.TrimSpace(c.URL) - if url == "" { - continue - } - seenKey := "url:" + url - if _, ok := seen[seenKey]; ok { - continue - } - seen[seenKey] = struct{}{} - p := map[string]any{ - "type": "source-url", - "sourceId": fmt.Sprintf("source-%d", len(parts)+1), - "url": url, - } - if title := strings.TrimSpace(c.Title); title != "" { - p["title"] = title - } - if meta := ProviderMetadata(c); len(meta) > 0 { - p["providerMetadata"] = meta - } - parts = append(parts, p) + AppendSourceURLPart(&parts, seen, c.URL, c.Title, ProviderMetadata(c)) } for _, d := range documents { - key := strings.TrimSpace(d.ID) - if key == "" { - key = strings.TrimSpace(d.Filename) - } - if key == "" { - key = strings.TrimSpace(d.Title) - } - if key == "" { - continue - } - seenKey := "doc:" + key - if _, ok := seen[seenKey]; ok { - continue - } - seen[seenKey] = struct{}{} - p := map[string]any{ - "type": "source-document", - "sourceId": fmt.Sprintf("source-%d", len(parts)+1), - "mediaType": d.MediaType, - "title": d.Title, - } - if fn := strings.TrimSpace(d.Filename); fn != "" { - p["filename"] = fn - } - parts = append(parts, p) + AppendSourceDocumentPart(&parts, seen, d) } return parts } +// AppendSourceURLPart appends a deduplicated source-url part to parts. +func AppendSourceURLPart(parts *[]map[string]any, seen map[string]struct{}, url, title string, providerMetadata map[string]any) { + url = strings.TrimSpace(url) + if url == "" { + return + } + seenKey := "url:" + url + if _, ok := seen[seenKey]; ok { + return + } + seen[seenKey] = struct{}{} + part := map[string]any{ + "type": "source-url", + "sourceId": fmt.Sprintf("source-%d", len(*parts)+1), + "url": url, + } + if title = strings.TrimSpace(title); title != "" { + part["title"] = title + } + if len(providerMetadata) > 0 { + part["providerMetadata"] = providerMetadata + } + *parts = append(*parts, part) +} + +// AppendSourceDocumentPart appends a deduplicated source-document part to parts. +func AppendSourceDocumentPart(parts *[]map[string]any, seen map[string]struct{}, doc SourceDocument) { + key := sourceDocumentKey(doc) + if key == "" { + return + } + seenKey := "doc:" + key + if _, ok := seen[seenKey]; ok { + return + } + seen[seenKey] = struct{}{} + part := map[string]any{ + "type": "source-document", + "sourceId": fmt.Sprintf("source-%d", len(*parts)+1), + "mediaType": doc.MediaType, + "title": doc.Title, + } + if filename := strings.TrimSpace(doc.Filename); filename != "" { + part["filename"] = filename + } + *parts = append(*parts, part) +} + +func sourceDocumentKey(doc SourceDocument) string { + for _, candidate := range []string{doc.ID, doc.Filename, doc.Title} { + if key := strings.TrimSpace(candidate); key != "" { + return key + } + } + return "" +} + // GeneratedFilesToParts converts generated files into stream-event parts. func GeneratedFilesToParts(files []GeneratedFilePart) []map[string]any { if len(files) == 0 { diff --git a/pkg/shared/media/message_type.go b/pkg/shared/media/message_type.go new file mode 100644 index 00000000..f0188dc6 --- /dev/null +++ b/pkg/shared/media/message_type.go @@ -0,0 +1,21 @@ +package media + +import ( + "strings" + + "maunium.net/go/mautrix/event" +) + +func MessageTypeForMIME(mimeType string) event.MessageType { + mimeType = strings.ToLower(strings.TrimSpace(mimeType)) + switch { + case strings.HasPrefix(mimeType, "image/"): + return event.MsgImage + case strings.HasPrefix(mimeType, "audio/"): + return event.MsgAudio + case strings.HasPrefix(mimeType, "video/"): + return event.MsgVideo + default: + return event.MsgFile + } +} diff --git a/pkg/shared/streamtransport/debounced_edit.go b/pkg/shared/streamtransport/debounced_edit.go index ae352be2..3fc328c5 100644 --- a/pkg/shared/streamtransport/debounced_edit.go +++ b/pkg/shared/streamtransport/debounced_edit.go @@ -3,6 +3,7 @@ package streamtransport import ( "strings" + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" ) @@ -49,3 +50,23 @@ func BuildDebouncedEditContent(p DebouncedEditParams) *DebouncedEditContent { Format: rendered.Format, } } + +// BuildConvertedEdit wraps rendered message content into a standard Matrix edit. +func BuildConvertedEdit(content *event.MessageEventContent, topLevelExtra map[string]any) *bridgev2.ConvertedEdit { + if content == nil { + return nil + } + return &bridgev2.ConvertedEdit{ + ModifiedParts: []*bridgev2.ConvertedEditPart{{ + Type: event.EventMessage, + Content: &event.MessageEventContent{ + MsgType: content.MsgType, + Body: content.Body, + Format: content.Format, + FormattedBody: content.FormattedBody, + }, + Extra: map[string]any{"m.mentions": map[string]any{}}, + TopLevelExtra: topLevelExtra, + }}, + } +} diff --git a/pkg/shared/streamtransport/session.go b/pkg/shared/streamtransport/session.go index a8ad0d09..25a4fe86 100644 --- a/pkg/shared/streamtransport/session.go +++ b/pkg/shared/streamtransport/session.go @@ -15,6 +15,14 @@ import ( "github.com/beeper/ai-bridge/pkg/matrixevents" ) +type StreamEventState struct { + TurnID string + SuppressSend bool + LoggedStart *bool + EnsureSession func() *StreamSession + Logger *zerolog.Logger +} + const ( // Fixed debounce interval for fallback post+edit streaming. debounceInterval = 200 * time.Millisecond @@ -93,6 +101,30 @@ func NewStreamSession(params StreamSessionParams) *StreamSession { return s } +// EmitStreamEvent logs the stream start once and emits a part through a session. +func EmitStreamEvent(ctx context.Context, portal *bridgev2.Portal, state StreamEventState, part map[string]any) { + if portal == nil || portal.MXID == "" || state.SuppressSend { + return + } + if state.LoggedStart != nil && !*state.LoggedStart { + *state.LoggedStart = true + if state.Logger != nil { + state.Logger.Info(). + Stringer("room_id", portal.MXID). + Str("turn_id", strings.TrimSpace(state.TurnID)). + Msg("Streaming events") + } + } + if state.EnsureSession == nil { + return + } + session := state.EnsureSession() + if session == nil { + return + } + session.EmitPart(ctx, part) +} + func (s *StreamSession) IsClosed() bool { return s == nil || s.closed.Load() } diff --git a/pkg/shared/streamui/emitter.go b/pkg/shared/streamui/emitter.go index 1fe28b26..7b0ca99b 100644 --- a/pkg/shared/streamui/emitter.go +++ b/pkg/shared/streamui/emitter.go @@ -139,25 +139,22 @@ func (e *Emitter) EmitUIStepFinish(ctx context.Context, portal *bridgev2.Portal) // EnsureUIText sends "text-start" the first time it's called for a turn. func (e *Emitter) EnsureUIText(ctx context.Context, portal *bridgev2.Portal) { - if e.State.UITextID != "" { - return - } - e.State.UITextID = fmt.Sprintf("text-%s", e.State.TurnID) - e.Emit(ctx, portal, map[string]any{ - "type": "text-start", - "id": e.State.UITextID, - }) + e.ensureUIPartStarted(ctx, portal, &e.State.UITextID, "text") } // EnsureUIReasoning sends "reasoning-start" the first time it's called for a turn. func (e *Emitter) EnsureUIReasoning(ctx context.Context, portal *bridgev2.Portal) { - if e.State.UIReasoningID != "" { + e.ensureUIPartStarted(ctx, portal, &e.State.UIReasoningID, "reasoning") +} + +func (e *Emitter) ensureUIPartStarted(ctx context.Context, portal *bridgev2.Portal, idRef *string, partType string) { + if idRef == nil || *idRef != "" { return } - e.State.UIReasoningID = fmt.Sprintf("reasoning-%s", e.State.TurnID) + *idRef = fmt.Sprintf("%s-%s", partType, e.State.TurnID) e.Emit(ctx, portal, map[string]any{ - "type": "reasoning-start", - "id": e.State.UIReasoningID, + "type": partType + "-start", + "id": *idRef, }) } diff --git a/pkg/shared/streamui/recorder.go b/pkg/shared/streamui/recorder.go index a48c7c6c..420c6588 100644 --- a/pkg/shared/streamui/recorder.go +++ b/pkg/shared/streamui/recorder.go @@ -35,15 +35,7 @@ func ApplyChunk(state *UIState, chunk map[string]any) { if partID == "" { return } - part := map[string]any{ - "type": "text", - "text": "", - "state": "streaming", - } - if providerMetadata := jsonutil.DeepCloneMap(jsonutil.ToMap(chunk["providerMetadata"])); len(providerMetadata) > 0 { - part["providerMetadata"] = providerMetadata - } - state.UITextPartIndexByID[partID] = appendPart(state, part) + state.UITextPartIndexByID[partID] = appendPart(state, newStreamingTextPart("text", jsonutil.DeepCloneMap(jsonutil.ToMap(chunk["providerMetadata"])))) case "text-delta": partID := strings.TrimSpace(stringValue(chunk["id"])) if partID == "" { @@ -65,15 +57,7 @@ func ApplyChunk(state *UIState, chunk map[string]any) { if partID == "" { return } - part := map[string]any{ - "type": "reasoning", - "text": "", - "state": "streaming", - } - if providerMetadata := jsonutil.DeepCloneMap(jsonutil.ToMap(chunk["providerMetadata"])); len(providerMetadata) > 0 { - part["providerMetadata"] = providerMetadata - } - state.UIReasoningPartIndexByID[partID] = appendPart(state, part) + state.UIReasoningPartIndexByID[partID] = appendPart(state, newStreamingTextPart("reasoning", jsonutil.DeepCloneMap(jsonutil.ToMap(chunk["providerMetadata"])))) case "reasoning-delta": partID := strings.TrimSpace(stringValue(chunk["id"])) if partID == "" { @@ -283,14 +267,7 @@ func ensureTextPart(state *UIState, partID string, providerMetadata map[string]a if idx, ok := state.UITextPartIndexByID[partID]; ok { return getPartAt(state, idx) } - part := map[string]any{ - "type": "text", - "text": "", - "state": "streaming", - } - if len(providerMetadata) > 0 { - part["providerMetadata"] = providerMetadata - } + part := newStreamingTextPart("text", providerMetadata) state.UITextPartIndexByID[partID] = appendPart(state, part) return part } @@ -299,15 +276,20 @@ func ensureReasoningPart(state *UIState, partID string, providerMetadata map[str if idx, ok := state.UIReasoningPartIndexByID[partID]; ok { return getPartAt(state, idx) } + part := newStreamingTextPart("reasoning", providerMetadata) + state.UIReasoningPartIndexByID[partID] = appendPart(state, part) + return part +} + +func newStreamingTextPart(partType string, providerMetadata map[string]any) map[string]any { part := map[string]any{ - "type": "reasoning", + "type": partType, "text": "", "state": "streaming", } if len(providerMetadata) > 0 { part["providerMetadata"] = providerMetadata } - state.UIReasoningPartIndexByID[partID] = appendPart(state, part) return part } From 246729eb7cdb5c920da70011d38a86f5ad5a1dac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 18:28:41 +0100 Subject: [PATCH 03/23] Refactor: extract helpers and consolidate logic Introduce several helper functions and generic utilities to remove duplicated code and centralize behavior across connector and bridge code. Key changes: - bridges/codex: extracted emitTrimmedProviderToolTextOutput to unify trimmed tool output emission and state updates. - bridges/opencode: moved UI metadata construction into opencodeUIMessageMetadata and simplified merging logic. - pkg/connector/chat: extracted ensureExistingChatPortalReady to consolidate portal readiness, MXID creation and logging. - pkg/connector/error_logging: added logProviderFailure helper to standardize provider error logging for Responses and Chat Completions. - pkg/connector/pending_queue & room_runs: added pendingQueueItemsConflict to centralize duplicate detection logic. - pkg/connector/subagent_conversion: introduced a generic convertSubagentConfig to reduce repetitive conversion code between subagent config types. - pkg/connector/tool_configured: created effectiveToolConfig generic helper to unify loading, token application, and defaulting for tool configs. Overall intent: reduce duplication, improve maintainability, and make logging/config handling more consistent. --- bridges/codex/client.go | 28 +++++++++---- bridges/opencode/stream_canonical.go | 31 ++++++-------- pkg/connector/chat.go | 55 +++++++++++-------------- pkg/connector/error_logging.go | 26 ++++++++---- pkg/connector/pending_queue.go | 22 ++++++---- pkg/connector/response_retry.go | 30 +++++--------- pkg/connector/room_runs.go | 9 +--- pkg/connector/streaming_output_items.go | 50 +++++++++------------- pkg/connector/subagent_conversion.go | 53 ++++++++++++++++-------- pkg/connector/tool_configured.go | 54 +++++++++++++++--------- pkg/runtime/chat_content.go | 24 ++++++----- 11 files changed, 204 insertions(+), 178 deletions(-) diff --git a/bridges/codex/client.go b/bridges/codex/client.go index 3df621d8..fda23bb2 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -1168,12 +1168,9 @@ func (cc *CodexClient) handleItemCompleted(ctx context.Context, portal *bridgev2 Text string `json:"text"` } _ = json.Unmarshal(raw, &it) - text := strings.TrimSpace(it.Text) - if text == "" { + if !cc.emitTrimmedProviderToolTextOutput(ctx, portal, state, itemID, "plan", "text", it.Text) { return } - cc.uiEmitter(state).EmitUIToolOutputAvailable(ctx, portal, itemID, text, true, false) - state.toolCalls = append(state.toolCalls, newProviderToolCall(itemID, "plan", map[string]any{"text": text})) case "enteredReviewMode": var it map[string]any _ = json.Unmarshal(raw, &it) @@ -1184,12 +1181,9 @@ func (cc *CodexClient) handleItemCompleted(ctx context.Context, portal *bridgev2 Review string `json:"review"` } _ = json.Unmarshal(raw, &it) - text := strings.TrimSpace(it.Review) - if text == "" { + if !cc.emitTrimmedProviderToolTextOutput(ctx, portal, state, itemID, "review", "review", it.Review) { return } - cc.uiEmitter(state).EmitUIToolOutputAvailable(ctx, portal, itemID, text, true, false) - state.toolCalls = append(state.toolCalls, newProviderToolCall(itemID, "review", map[string]any{"review": text})) case "contextCompaction": var it map[string]any _ = json.Unmarshal(raw, &it) @@ -1199,6 +1193,24 @@ func (cc *CodexClient) handleItemCompleted(ctx context.Context, portal *bridgev2 } } +func (cc *CodexClient) emitTrimmedProviderToolTextOutput( + ctx context.Context, + portal *bridgev2.Portal, + state *streamingState, + itemID string, + toolName string, + field string, + value string, +) bool { + text := strings.TrimSpace(value) + if text == "" { + return false + } + cc.uiEmitter(state).EmitUIToolOutputAvailable(ctx, portal, itemID, text, true, false) + state.toolCalls = append(state.toolCalls, newProviderToolCall(itemID, toolName, map[string]any{field: text})) + return true +} + func (cc *CodexClient) ensureRPC(ctx context.Context) error { cc.rpcMu.Lock() defer cc.rpcMu.Unlock() diff --git a/bridges/opencode/stream_canonical.go b/bridges/opencode/stream_canonical.go index 498a7945..a58fa893 100644 --- a/bridges/opencode/stream_canonical.go +++ b/bridges/opencode/stream_canonical.go @@ -82,27 +82,21 @@ func (oc *OpenCodeClient) currentCanonicalUIMessage(state *openCodeStreamState) return nil } uiMessage := streamui.SnapshotCanonicalUIMessage(&state.ui) + metadata := opencodeUIMessageMetadata(state) if len(uiMessage) == 0 { return msgconv.BuildUIMessage(msgconv.UIMessageParams{ - TurnID: state.turnID, - Role: "assistant", - Metadata: msgconv.BuildUIMessageMetadata(msgconv.UIMessageMetadataParams{ - TurnID: state.turnID, - AgentID: state.agentID, - Model: state.modelID, - FinishReason: state.finishReason, - PromptTokens: state.promptTokens, - CompletionTokens: state.completionTokens, - ReasoningTokens: state.reasoningTokens, - TotalTokens: state.totalTokens, - StartedAtMs: state.startedAtMs, - CompletedAtMs: state.completedAtMs, - IncludeUsage: true, - }), + TurnID: state.turnID, + Role: "assistant", + Metadata: metadata, }) } - metadata, _ := uiMessage["metadata"].(map[string]any) - uiMessage["metadata"] = msgconv.MergeUIMessageMetadata(metadata, msgconv.BuildUIMessageMetadata(msgconv.UIMessageMetadataParams{ + existingMetadata, _ := uiMessage["metadata"].(map[string]any) + uiMessage["metadata"] = msgconv.MergeUIMessageMetadata(existingMetadata, metadata) + return uiMessage +} + +func opencodeUIMessageMetadata(state *openCodeStreamState) map[string]any { + return msgconv.BuildUIMessageMetadata(msgconv.UIMessageMetadataParams{ TurnID: state.turnID, AgentID: state.agentID, Model: state.modelID, @@ -114,8 +108,7 @@ func (oc *OpenCodeClient) currentCanonicalUIMessage(state *openCodeStreamState) StartedAtMs: state.startedAtMs, CompletedAtMs: state.completedAtMs, IncludeUsage: true, - })) - return uiMessage + }) } func (oc *OpenCodeClient) buildStreamDBMetadata(state *openCodeStreamState) *MessageMetadata { diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index 38761984..be3991b7 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -1888,22 +1888,7 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { if err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to load default chat portal by deterministic key") } else if portal != nil { - loginMeta.DefaultChatPortalID = string(portal.PortalKey.ID) - if err := oc.UserLogin.Save(ctx); err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to persist default chat portal ID") - } - if portal.MXID != "" { - oc.loggerForContext(ctx).Debug().Stringer("portal", portal.PortalKey).Msg("Existing default chat already has MXID") - return nil - } - info := oc.chatInfoFromPortal(ctx, portal) - oc.loggerForContext(ctx).Info().Stringer("portal", portal.PortalKey).Msg("Default chat missing MXID; creating Matrix room") - err := portal.CreateMatrixRoom(ctx, oc.UserLogin, info) - if err != nil { - oc.loggerForContext(ctx).Err(err).Msg("Failed to create Matrix room for default chat") - } - oc.sendWelcomeMessage(ctx, portal) - return err + return oc.ensureExistingChatPortalReady(ctx, loginMeta, portal, "Existing default chat already has MXID", "Default chat missing MXID; creating Matrix room", "Failed to create Matrix room for default chat") } } @@ -1928,22 +1913,7 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { } if defaultPortal != nil { - loginMeta.DefaultChatPortalID = string(defaultPortal.PortalKey.ID) - if err := oc.UserLogin.Save(ctx); err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to persist default chat portal ID") - } - if defaultPortal.MXID != "" { - oc.loggerForContext(ctx).Debug().Stringer("portal", defaultPortal.PortalKey).Msg("Existing chat already has MXID") - return nil - } - info := oc.chatInfoFromPortal(ctx, defaultPortal) - oc.loggerForContext(ctx).Info().Stringer("portal", defaultPortal.PortalKey).Msg("Existing portal missing MXID; creating Matrix room") - err := defaultPortal.CreateMatrixRoom(ctx, oc.UserLogin, info) - if err != nil { - oc.loggerForContext(ctx).Err(err).Msg("Failed to create Matrix room for existing portal") - } - oc.sendWelcomeMessage(ctx, defaultPortal) - return err + return oc.ensureExistingChatPortalReady(ctx, loginMeta, defaultPortal, "Existing chat already has MXID", "Existing portal missing MXID; creating Matrix room", "Failed to create Matrix room for existing portal") } // Create default chat with Beep agent @@ -2025,6 +1995,27 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { return nil } +func (oc *AIClient) ensureExistingChatPortalReady(ctx context.Context, loginMeta *UserLoginMetadata, portal *bridgev2.Portal, readyMsg string, createMsg string, errMsg string) error { + if loginMeta != nil { + loginMeta.DefaultChatPortalID = string(portal.PortalKey.ID) + if err := oc.UserLogin.Save(ctx); err != nil { + oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to persist default chat portal ID") + } + } + if portal.MXID != "" { + oc.loggerForContext(ctx).Debug().Stringer("portal", portal.PortalKey).Msg(readyMsg) + return nil + } + info := oc.chatInfoFromPortal(ctx, portal) + oc.loggerForContext(ctx).Info().Stringer("portal", portal.PortalKey).Msg(createMsg) + err := portal.CreateMatrixRoom(ctx, oc.UserLogin, info) + if err != nil { + oc.loggerForContext(ctx).Err(err).Msg(errMsg) + } + oc.sendWelcomeMessage(ctx, portal) + return err +} + func (oc *AIClient) listAllChatPortals(ctx context.Context) ([]*bridgev2.Portal, error) { // Query all portals and filter by receiver (our login ID) // This works because all our portals have Receiver set to our UserLogin.ID diff --git a/pkg/connector/error_logging.go b/pkg/connector/error_logging.go index 47a0a66f..ee0e6e8d 100644 --- a/pkg/connector/error_logging.go +++ b/pkg/connector/error_logging.go @@ -10,19 +10,31 @@ import ( ) func logResponsesFailure(log zerolog.Logger, err error, params responses.ResponseNewParams, meta *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion, stage string) { - event := log.Error().Err(err).Str("stage", stage) - addRequestSummary(event, meta, messages) - addResponsesParamsSummary(event, params) - addOpenAIErrorFields(event, err) - event.Msg("Responses API failure") + logProviderFailure(log, err, meta, messages, stage, "Responses API failure", func(event *zerolog.Event) { + addResponsesParamsSummary(event, params) + }) } func logChatCompletionsFailure(log zerolog.Logger, err error, params openai.ChatCompletionNewParams, meta *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion, stage string) { + logProviderFailure(log, err, meta, messages, stage, "Chat Completions failure", func(event *zerolog.Event) { + addChatParamsSummary(event, params) + }) +} + +func logProviderFailure( + log zerolog.Logger, + err error, + meta *PortalMetadata, + messages []openai.ChatCompletionMessageParamUnion, + stage string, + msg string, + addSummary func(*zerolog.Event), +) { event := log.Error().Err(err).Str("stage", stage) addRequestSummary(event, meta, messages) - addChatParamsSummary(event, params) + addSummary(event) addOpenAIErrorFields(event, err) - event.Msg("Chat Completions failure") + event.Msg(msg) } func addRequestSummary(event *zerolog.Event, meta *PortalMetadata, messages []openai.ChatCompletionMessageParamUnion) { diff --git a/pkg/connector/pending_queue.go b/pkg/connector/pending_queue.go index ae87a60c..81a0bc26 100644 --- a/pkg/connector/pending_queue.go +++ b/pkg/connector/pending_queue.go @@ -81,13 +81,8 @@ func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, } for _, existing := range queue.items { - if !item.allowDuplicate { - if item.messageID != "" && existing.messageID == item.messageID { - return false - } - if item.messageID == "" && existing.messageID == "" && item.pending.MessageBody != "" && existing.pending.MessageBody == item.pending.MessageBody { - return false - } + if pendingQueueItemsConflict(item, existing) { + return false } } @@ -129,6 +124,19 @@ func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, return true } +func pendingQueueItemsConflict(item pendingQueueItem, existing pendingQueueItem) bool { + if item.allowDuplicate { + return false + } + if item.messageID != "" && existing.messageID == item.messageID { + return true + } + return item.messageID == "" && + existing.messageID == "" && + item.pending.MessageBody != "" && + existing.pending.MessageBody == item.pending.MessageBody +} + func (oc *AIClient) popQueueItems(roomID id.RoomID, count int) []pendingQueueItem { oc.pendingQueuesMu.Lock() defer oc.pendingQueuesMu.Unlock() diff --git a/pkg/connector/response_retry.go b/pkg/connector/response_retry.go index a9d24123..be347445 100644 --- a/pkg/connector/response_retry.go +++ b/pkg/connector/response_retry.go @@ -117,28 +117,10 @@ func (oc *AIClient) responseWithRetry( Summary: summary, WillRetry: true, }) - oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ - Client: oc, - Portal: portal, - Meta: meta, - Phase: integrationruntime.CompactionLifecycleEnd, - Attempt: attempt + 1, - ContextWindowTokens: contextWindow, - RequestedTokens: cle.RequestedTokens, - PromptTokens: tokensAfter, - MessagesBefore: len(currentPrompt), - MessagesAfter: len(compacted), - TokensBefore: tokensBefore, - TokensAfter: tokensAfter, - DroppedCount: decision.DroppedCount, - Reason: decision.Reason, - WillRetry: true, - }) - oc.emitCompactionLifecycle(ctx, integrationruntime.CompactionLifecycleEvent{ + oc.emitCompactionLifecyclePhases(ctx, integrationruntime.CompactionLifecycleEvent{ Client: oc, Portal: portal, Meta: meta, - Phase: integrationruntime.CompactionLifecycleRefresh, Attempt: attempt + 1, ContextWindowTokens: contextWindow, RequestedTokens: cle.RequestedTokens, @@ -150,7 +132,7 @@ func (oc *AIClient) responseWithRetry( DroppedCount: decision.DroppedCount, Reason: decision.Reason, WillRetry: true, - }) + }, integrationruntime.CompactionLifecycleEnd, integrationruntime.CompactionLifecycleRefresh) oc.loggerForContext(ctx).Info(). Int("messages_before", len(currentPrompt)). @@ -240,6 +222,14 @@ func (oc *AIClient) responseWithRetry( return false, terminal } +func (oc *AIClient) emitCompactionLifecyclePhases(ctx context.Context, base integrationruntime.CompactionLifecycleEvent, phases ...integrationruntime.CompactionLifecyclePhase) { + for _, phase := range phases { + event := base + event.Phase = phase + oc.emitCompactionLifecycle(ctx, event) + } +} + func (oc *AIClient) runCompactionPreflightFlushHook( ctx context.Context, portal *bridgev2.Portal, diff --git a/pkg/connector/room_runs.go b/pkg/connector/room_runs.go index 72ed19d7..fa2bef54 100644 --- a/pkg/connector/room_runs.go +++ b/pkg/connector/room_runs.go @@ -112,13 +112,8 @@ func (oc *AIClient) enqueueSteerQueue(roomID id.RoomID, item pendingQueueItem) b return false } for _, existing := range run.steerQueue { - if !item.allowDuplicate { - if item.messageID != "" && existing.messageID == item.messageID { - return false - } - if item.messageID == "" && existing.messageID == "" && item.pending.MessageBody != "" && existing.pending.MessageBody == item.pending.MessageBody { - return false - } + if pendingQueueItemsConflict(item, existing) { + return false } } run.steerQueue = append(run.steerQueue, item) diff --git a/pkg/connector/streaming_output_items.go b/pkg/connector/streaming_output_items.go index 1bba6a22..816fd031 100644 --- a/pkg/connector/streaming_output_items.go +++ b/pkg/connector/streaming_output_items.go @@ -114,38 +114,11 @@ func deriveToolDescriptorForOutputItem(item responses.ResponseOutputItemUnion, s desc.input = map[string]any{} desc.ok = true case "local_shell_call": - desc.callID = strings.TrimSpace(item.CallID) - if desc.callID == "" { - desc.callID = item.ID - } - desc.toolName = "local_shell" - desc.toolType = ToolTypeProvider - desc.providerExecuted = true - desc.dynamic = true - desc.input = responseOutputItemToMap(item) - desc.ok = true + desc = providerDynamicResponseToolDescriptor(item, "local_shell") case "shell_call": - desc.callID = strings.TrimSpace(item.CallID) - if desc.callID == "" { - desc.callID = item.ID - } - desc.toolName = "shell" - desc.toolType = ToolTypeProvider - desc.providerExecuted = true - desc.dynamic = true - desc.input = responseOutputItemToMap(item) - desc.ok = true + desc = providerDynamicResponseToolDescriptor(item, "shell") case "apply_patch_call": - desc.callID = strings.TrimSpace(item.CallID) - if desc.callID == "" { - desc.callID = item.ID - } - desc.toolName = "apply_patch" - desc.toolType = ToolTypeProvider - desc.providerExecuted = true - desc.dynamic = true - desc.input = responseOutputItemToMap(item) - desc.ok = true + desc = providerDynamicResponseToolDescriptor(item, "apply_patch") case "custom_tool_call": desc.callID = strings.TrimSpace(item.CallID) if desc.callID == "" { @@ -196,6 +169,23 @@ func deriveToolDescriptorForOutputItem(item responses.ResponseOutputItemUnion, s return desc } +func providerDynamicResponseToolDescriptor(item responses.ResponseOutputItemUnion, toolName string) responseToolDescriptor { + callID := strings.TrimSpace(item.CallID) + if callID == "" { + callID = item.ID + } + return responseToolDescriptor{ + itemID: item.ID, + callID: callID, + toolName: toolName, + toolType: ToolTypeProvider, + input: responseOutputItemToMap(item), + providerExecuted: true, + dynamic: true, + ok: true, + } +} + func outputItemLooksDenied(item responses.ResponseOutputItemUnion) bool { errorText := strings.ToLower(strings.TrimSpace(item.Error)) if strings.Contains(errorText, "denied") || strings.Contains(errorText, "rejected") { diff --git a/pkg/connector/subagent_conversion.go b/pkg/connector/subagent_conversion.go index fb390fa6..035e8803 100644 --- a/pkg/connector/subagent_conversion.go +++ b/pkg/connector/subagent_conversion.go @@ -8,29 +8,46 @@ import ( ) func subagentsToTools(cfg *agents.SubagentConfig) *tools.SubagentConfig { - if cfg == nil { - return nil - } - out := &tools.SubagentConfig{ - Model: cfg.Model, - Thinking: cfg.Thinking, - } - if len(cfg.AllowAgents) > 0 { - out.AllowAgents = slices.Clone(cfg.AllowAgents) - } - return out + return convertSubagentConfig(cfg, func(model, thinking string, allowAgents []string) *tools.SubagentConfig { + return &tools.SubagentConfig{ + Model: model, + Thinking: thinking, + AllowAgents: allowAgents, + } + }) } func subagentsFromTools(cfg *tools.SubagentConfig) *agents.SubagentConfig { + return convertSubagentConfig(cfg, func(model, thinking string, allowAgents []string) *agents.SubagentConfig { + return &agents.SubagentConfig{ + Model: model, + Thinking: thinking, + AllowAgents: allowAgents, + } + }) +} + +type subagentConfigLike interface { + *agents.SubagentConfig | *tools.SubagentConfig +} + +func convertSubagentConfig[T subagentConfigLike, R any](cfg T, build func(string, string, []string) *R) *R { if cfg == nil { return nil } - out := &agents.SubagentConfig{ - Model: cfg.Model, - Thinking: cfg.Thinking, - } - if len(cfg.AllowAgents) > 0 { - out.AllowAgents = slices.Clone(cfg.AllowAgents) + allowAgents := []string(nil) + switch typed := any(cfg).(type) { + case *agents.SubagentConfig: + if len(typed.AllowAgents) > 0 { + allowAgents = slices.Clone(typed.AllowAgents) + } + return build(typed.Model, typed.Thinking, allowAgents) + case *tools.SubagentConfig: + if len(typed.AllowAgents) > 0 { + allowAgents = slices.Clone(typed.AllowAgents) + } + return build(typed.Model, typed.Thinking, allowAgents) + default: + return nil } - return out } diff --git a/pkg/connector/tool_configured.go b/pkg/connector/tool_configured.go index cf694149..c4c0cd5c 100644 --- a/pkg/connector/tool_configured.go +++ b/pkg/connector/tool_configured.go @@ -14,37 +14,51 @@ import ( // prerequisites like API keys and service initialization. func (oc *AIClient) effectiveSearchConfig(_ context.Context) *search.Config { - var cfg *search.Config - var meta *UserLoginMetadata - var connector *OpenAIConnector - if oc != nil { - connector = oc.connector - if connector != nil { - cfg = mapSearchConfig(connector.Config.Tools.Search) - } - if oc.UserLogin != nil { - meta = loginMetadata(oc.UserLogin) - } - } - cfg = applyLoginTokensToSearchConfig(cfg, meta, connector) - return search.ApplyEnvDefaults(cfg).WithDefaults() + return effectiveToolConfig( + oc, + func(connector *OpenAIConnector) *search.Config { + if connector == nil { + return nil + } + return mapSearchConfig(connector.Config.Tools.Search) + }, + applyLoginTokensToSearchConfig, + func(cfg *search.Config) *search.Config { return search.ApplyEnvDefaults(cfg).WithDefaults() }, + ) } func (oc *AIClient) effectiveFetchConfig(_ context.Context) *fetch.Config { - var cfg *fetch.Config + return effectiveToolConfig( + oc, + func(connector *OpenAIConnector) *fetch.Config { + if connector == nil { + return nil + } + return mapFetchConfig(connector.Config.Tools.Fetch) + }, + applyLoginTokensToFetchConfig, + func(cfg *fetch.Config) *fetch.Config { return fetch.ApplyEnvDefaults(cfg).WithDefaults() }, + ) +} + +func effectiveToolConfig[T any]( + oc *AIClient, + load func(*OpenAIConnector) *T, + applyTokens func(*T, *UserLoginMetadata, *OpenAIConnector) *T, + withDefaults func(*T) *T, +) *T { + var cfg *T var meta *UserLoginMetadata var connector *OpenAIConnector if oc != nil { connector = oc.connector - if connector != nil { - cfg = mapFetchConfig(connector.Config.Tools.Fetch) - } + cfg = load(connector) if oc.UserLogin != nil { meta = loginMetadata(oc.UserLogin) } } - cfg = applyLoginTokensToFetchConfig(cfg, meta, connector) - return fetch.ApplyEnvDefaults(cfg).WithDefaults() + cfg = applyTokens(cfg, meta, connector) + return withDefaults(cfg) } func (oc *AIClient) isWebSearchConfigured(ctx context.Context) (bool, string) { diff --git a/pkg/runtime/chat_content.go b/pkg/runtime/chat_content.go index 14e836fb..6e457710 100644 --- a/pkg/runtime/chat_content.go +++ b/pkg/runtime/chat_content.go @@ -46,11 +46,8 @@ func ExtractUserContent(content openai.ChatCompletionUserMessageParamContentUnio if content.OfString.Value != "" { return content.OfString.Value } - return joinContentText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartUnionParam) string { - if part.OfText == nil { - return "" - } - return part.OfText.Text + return joinOptionalContentText(content.OfArrayOfContentParts, func(part openai.ChatCompletionContentPartUnionParam) *openai.ChatCompletionContentPartTextParam { + return part.OfText }) } @@ -58,11 +55,8 @@ func ExtractAssistantContent(content openai.ChatCompletionAssistantMessageParamC if content.OfString.Value != "" { return content.OfString.Value } - return joinContentText(content.OfArrayOfContentParts, func(part openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) string { - if part.OfText == nil { - return "" - } - return part.OfText.Text + return joinOptionalContentText(content.OfArrayOfContentParts, func(part openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) *openai.ChatCompletionContentPartTextParam { + return part.OfText }) } @@ -95,6 +89,16 @@ func joinContentText[T any](parts []T, extract func(T) string) string { return sb.String() } +func joinOptionalContentText[T any](parts []T, extract func(T) *openai.ChatCompletionContentPartTextParam) string { + return joinContentText(parts, func(part T) string { + textPart := extract(part) + if textPart == nil { + return "" + } + return textPart.Text + }) +} + // EstimateMessageChars approximates character usage for one prompt message. func EstimateMessageChars(msg openai.ChatCompletionMessageParamUnion) int { switch { From 49a01a53d8946d99d1512d7dd19178633ae91af7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 18:30:27 +0100 Subject: [PATCH 04/23] Refactor stream/session helpers and subagent logic Introduce EmitStreamEventWithSession wrapper in streamtransport and update Codex and connector callers to use the simpler call signature. Add runCreateSession and runUpdateSessionTitle helpers in OpenCodeManager to delegate to runSessionMutation, removing duplicated anonymous functions. Consolidate subagent config resolution by adding resolveSubagentConfigValue and updating subagentStringValue to accept a field name (model/thinking). These changes reduce duplication and simplify call sites. --- bridges/codex/stream_transport.go | 19 ++++++------ .../opencodebridge/opencode_manager.go | 20 +++++++++---- pkg/connector/stream_events.go | 19 ++++++------ pkg/connector/subagent_spawn.go | 29 ++++++++++--------- pkg/shared/streamtransport/session.go | 21 ++++++++++++++ 5 files changed, 71 insertions(+), 37 deletions(-) diff --git a/bridges/codex/stream_transport.go b/bridges/codex/stream_transport.go index 0efbe694..577c2424 100644 --- a/bridges/codex/stream_transport.go +++ b/bridges/codex/stream_transport.go @@ -96,13 +96,14 @@ func (cc *CodexClient) emitStreamEvent(ctx context.Context, portal *bridgev2.Por if state == nil { return } - streamtransport.EmitStreamEvent(ctx, portal, streamtransport.StreamEventState{ - TurnID: state.turnID, - SuppressSend: state.suppressSend, - LoggedStart: &state.loggedStreamStart, - EnsureSession: func() *streamtransport.StreamSession { - return cc.ensureStreamSession(ctx, portal, state) - }, - Logger: cc.loggerForContext(ctx), - }, part) + streamtransport.EmitStreamEventWithSession( + ctx, + portal, + state.turnID, + state.suppressSend, + &state.loggedStreamStart, + cc.loggerForContext(ctx), + func() *streamtransport.StreamSession { return cc.ensureStreamSession(ctx, portal, state) }, + part, + ) } diff --git a/bridges/opencode/opencodebridge/opencode_manager.go b/bridges/opencode/opencodebridge/opencode_manager.go index cdd54d45..141546fb 100644 --- a/bridges/opencode/opencodebridge/opencode_manager.go +++ b/bridges/opencode/opencodebridge/opencode_manager.go @@ -382,15 +382,11 @@ func (m *OpenCodeManager) AbortSession(ctx context.Context, instanceID, sessionI } func (m *OpenCodeManager) CreateSession(ctx context.Context, instanceID, title, directory string) (*opencode.Session, error) { - return m.runSessionMutation(ctx, instanceID, "create session", func(inst *openCodeInstance) (*opencode.Session, error) { - return inst.client.CreateSession(ctx, title, directory) - }) + return m.runCreateSession(ctx, instanceID, title, directory) } func (m *OpenCodeManager) UpdateSessionTitle(ctx context.Context, instanceID, sessionID, title string) (*opencode.Session, error) { - return m.runSessionMutation(ctx, instanceID, "update session title", func(inst *openCodeInstance) (*opencode.Session, error) { - return inst.client.UpdateSessionTitle(ctx, sessionID, title) - }) + return m.runUpdateSessionTitle(ctx, instanceID, sessionID, title) } func (m *OpenCodeManager) runSessionMutation( @@ -413,6 +409,18 @@ func (m *OpenCodeManager) runSessionMutation( return session, nil } +func (m *OpenCodeManager) runCreateSession(ctx context.Context, instanceID, title, directory string) (*opencode.Session, error) { + return m.runSessionMutation(ctx, instanceID, "create session", func(inst *openCodeInstance) (*opencode.Session, error) { + return inst.client.CreateSession(ctx, title, directory) + }) +} + +func (m *OpenCodeManager) runUpdateSessionTitle(ctx context.Context, instanceID, sessionID, title string) (*opencode.Session, error) { + return m.runSessionMutation(ctx, instanceID, "update session title", func(inst *openCodeInstance) (*opencode.Session, error) { + return inst.client.UpdateSessionTitle(ctx, sessionID, title) + }) +} + func (m *OpenCodeManager) syncSessions(ctx context.Context, inst *openCodeInstance, sessions []opencode.Session) (int, error) { count := 0 for _, session := range sessions { diff --git a/pkg/connector/stream_events.go b/pkg/connector/stream_events.go index 0f5445f6..4b4075f1 100644 --- a/pkg/connector/stream_events.go +++ b/pkg/connector/stream_events.go @@ -61,13 +61,14 @@ func (oc *AIClient) emitStreamEvent( if state == nil { return } - streamtransport.EmitStreamEvent(ctx, portal, streamtransport.StreamEventState{ - TurnID: state.turnID, - SuppressSend: state.suppressSend, - LoggedStart: &state.loggedStreamStart, - EnsureSession: func() *streamtransport.StreamSession { - return oc.ensureStreamSession(ctx, portal, state) - }, - Logger: oc.loggerForContext(ctx), - }, part) + streamtransport.EmitStreamEventWithSession( + ctx, + portal, + state.turnID, + state.suppressSend, + &state.loggedStreamStart, + oc.loggerForContext(ctx), + func() *streamtransport.StreamSession { return oc.ensureStreamSession(ctx, portal, state) }, + part, + ) } diff --git a/pkg/connector/subagent_spawn.go b/pkg/connector/subagent_spawn.go index 91c97efd..b3aaf04e 100644 --- a/pkg/connector/subagent_spawn.go +++ b/pkg/connector/subagent_spawn.go @@ -55,33 +55,36 @@ func (oc *AIClient) resolveSubagentAllowlist(ctx context.Context, requesterAgent } func resolveSubagentModel(override string, agent *agents.AgentDefinition, defaults *agents.SubagentConfig) string { - return firstNonEmptyTrimmed( - override, - subagentStringValue(agent, func(cfg *agents.SubagentConfig) string { return cfg.Model }), - subagentStringValue(defaults, func(cfg *agents.SubagentConfig) string { return cfg.Model }), - ) + return resolveSubagentConfigValue(override, agent, defaults, "model") } func resolveSubagentThinking(override string, agent *agents.AgentDefinition, defaults *agents.SubagentConfig) string { - return firstNonEmptyTrimmed( - override, - subagentStringValue(agent, func(cfg *agents.SubagentConfig) string { return cfg.Thinking }), - subagentStringValue(defaults, func(cfg *agents.SubagentConfig) string { return cfg.Thinking }), - ) + return resolveSubagentConfigValue(override, agent, defaults, "thinking") } -func subagentStringValue(source any, extract func(*agents.SubagentConfig) string) string { +func resolveSubagentConfigValue(override string, agent *agents.AgentDefinition, defaults *agents.SubagentConfig, field string) string { + return firstNonEmptyTrimmed(override, subagentStringValue(agent, field), subagentStringValue(defaults, field)) +} + +func subagentStringValue(source any, field string) string { switch cfg := source.(type) { case *agents.AgentDefinition: if cfg == nil { return "" } - return subagentStringValue(cfg.Subagents, extract) + return subagentStringValue(cfg.Subagents, field) case *agents.SubagentConfig: if cfg == nil { return "" } - return extract(cfg) + switch field { + case "model": + return cfg.Model + case "thinking": + return cfg.Thinking + default: + return "" + } default: return "" } diff --git a/pkg/shared/streamtransport/session.go b/pkg/shared/streamtransport/session.go index 25a4fe86..8401e67d 100644 --- a/pkg/shared/streamtransport/session.go +++ b/pkg/shared/streamtransport/session.go @@ -125,6 +125,27 @@ func EmitStreamEvent(ctx context.Context, portal *bridgev2.Portal, state StreamE session.EmitPart(ctx, part) } +// EmitStreamEventWithSession is a convenience wrapper for callers that only need +// to provide the common stream state fields. +func EmitStreamEventWithSession( + ctx context.Context, + portal *bridgev2.Portal, + turnID string, + suppressSend bool, + loggedStart *bool, + logger *zerolog.Logger, + ensureSession func() *StreamSession, + part map[string]any, +) { + EmitStreamEvent(ctx, portal, StreamEventState{ + TurnID: turnID, + SuppressSend: suppressSend, + LoggedStart: loggedStart, + EnsureSession: ensureSession, + Logger: logger, + }, part) +} + func (s *StreamSession) IsClosed() bool { return s == nil || s.closed.Load() } From e324fee059550d2da95caf3863cb9cdc15a890f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 18:43:55 +0100 Subject: [PATCH 05/23] Add bridgeadapter and toolspec helpers Introduce reusable helpers to reduce duplication and centralize common logic. - Add pkg/bridgeadapter/helpers.go with BuildMetaTypes and BuildChatInfoWithFallback and update Codex/OpenCode/OpenAI connectors and clients to use them for DB meta types and ChatInfo construction. - Add toolspec helpers (ObjectSchema, StringProperty, NumberProperty, BooleanProperty) in pkg/shared/toolspec/toolspec.go and refactor existing schema functions (Read, Write, TTS, etc.) to use the new helpers. - Update boss tools to use toolspec.ObjectSchema for input schemas. - Refactor streaming logic in pkg/connector: extract recordCompletedToolCall and responseFunctionToolDescriptor to remove repeated code in streaming_chat_completions, streaming_function_calls, and streaming_output_items. - Remove small unused resolveSubagentModel/resolveSubagentThinking wrappers and use resolveSubagentConfigValue directly in subagent spawn logic. Overall this commit cleans up repeated patterns, improves maintainability, and centralizes schema/chat/meta construction logic. --- bridges/codex/client.go | 13 +-- bridges/codex/connector.go | 12 +-- bridges/opencode/connector.go | 12 +-- pkg/agents/tools/boss.go | 70 +++----------- pkg/bridgeadapter/helpers.go | 31 ++++++ pkg/connector/client.go | 14 +-- pkg/connector/connector.go | 12 +-- pkg/connector/streaming_chat_completions.go | 17 +--- pkg/connector/streaming_function_calls.go | 15 ++- pkg/connector/streaming_output_items.go | 40 ++++---- pkg/connector/subagent_spawn.go | 12 +-- pkg/shared/toolspec/toolspec.go | 102 +++++++++----------- 12 files changed, 150 insertions(+), 200 deletions(-) create mode 100644 pkg/bridgeadapter/helpers.go diff --git a/bridges/codex/client.go b/bridges/codex/client.go index fda23bb2..c8fca67a 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -333,18 +333,7 @@ func (cc *CodexClient) IsThisUser(ctx context.Context, userID networkid.UserID) func (cc *CodexClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { meta := portalMeta(portal) - title := meta.Title - if title == "" { - if portal.Name != "" { - title = portal.Name - } else { - title = "Codex" - } - } - return &bridgev2.ChatInfo{ - Name: ptr.Ptr(title), - Topic: ptr.NonZero(portal.Topic), - }, nil + return bridgeadapter.BuildChatInfoWithFallback(meta.Title, portal.Name, "Codex", portal.Topic), nil } func (cc *CodexClient) GetUserInfo(_ context.Context, _ *bridgev2.Ghost) (*bridgev2.UserInfo, error) { diff --git a/bridges/codex/connector.go b/bridges/codex/connector.go index 772869ce..19ac3896 100644 --- a/bridges/codex/connector.go +++ b/bridges/codex/connector.go @@ -291,12 +291,12 @@ func (cc *CodexConnector) GetConfig() (example string, data any, upgrader config } func (cc *CodexConnector) GetDBMetaTypes() database.MetaTypes { - return database.MetaTypes{ - Portal: func() any { return &PortalMetadata{} }, - Message: func() any { return &MessageMetadata{} }, - UserLogin: func() any { return &UserLoginMetadata{} }, - Ghost: func() any { return &GhostMetadata{} }, - } + return bridgeadapter.BuildMetaTypes( + func() any { return &PortalMetadata{} }, + func() any { return &MessageMetadata{} }, + func() any { return &UserLoginMetadata{} }, + func() any { return &GhostMetadata{} }, + ) } func (cc *CodexConnector) LoadUserLogin(_ context.Context, login *bridgev2.UserLogin) error { diff --git a/bridges/opencode/connector.go b/bridges/opencode/connector.go index b10ad54b..36724dd1 100644 --- a/bridges/opencode/connector.go +++ b/bridges/opencode/connector.go @@ -84,12 +84,12 @@ func (oc *OpenCodeConnector) GetConfig() (example string, data any, upgrader con } func (oc *OpenCodeConnector) GetDBMetaTypes() database.MetaTypes { - return database.MetaTypes{ - Portal: func() any { return &PortalMetadata{} }, - Message: func() any { return &MessageMetadata{} }, - UserLogin: func() any { return &UserLoginMetadata{} }, - Ghost: func() any { return &GhostMetadata{} }, - } + return bridgeadapter.BuildMetaTypes( + func() any { return &PortalMetadata{} }, + func() any { return &MessageMetadata{} }, + func() any { return &UserLoginMetadata{} }, + func() any { return &GhostMetadata{} }, + ) } func (oc *OpenCodeConnector) LoadUserLogin(_ context.Context, login *bridgev2.UserLogin) error { diff --git a/pkg/agents/tools/boss.go b/pkg/agents/tools/boss.go index dbfc4338..e63c1655 100644 --- a/pkg/agents/tools/boss.go +++ b/pkg/agents/tools/boss.go @@ -10,6 +10,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/beeper/ai-bridge/pkg/agents/toolpolicy" + "github.com/beeper/ai-bridge/pkg/shared/toolspec" ) // Boss tools for agent management. @@ -157,20 +158,10 @@ var ForkAgentTool = &Tool{ Name: "fork_agent", Description: "Create a copy of an existing agent as a new custom agent", Annotations: &mcp.ToolAnnotations{Title: "Fork Agent"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "source_id": map[string]any{ - "type": "string", - "description": "ID of the agent to copy", - }, - "new_name": map[string]any{ - "type": "string", - "description": "Name for the new agent (defaults to '[Original Name] (Fork)')", - }, - }, - "required": []string{"source_id"}, - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "source_id": toolspec.StringProperty("ID of the agent to copy"), + "new_name": toolspec.StringProperty("Name for the new agent (defaults to '[Original Name] (Fork)')"), + }, "source_id"), }, Type: ToolTypeBuiltin, Group: GroupBuilder, @@ -257,20 +248,10 @@ var RunInternalCommandTool = &Tool{ Name: "run_internal_command", Description: "Run an internal !ai command in a target room", Annotations: &mcp.ToolAnnotations{Title: "Run Internal Command"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "command": map[string]any{ - "type": "string", - "description": "The !ai command to run (with or without prefix)", - }, - "room_id": map[string]any{ - "type": "string", - "description": "Optional target room ID (defaults to the current room)", - }, - }, - "required": []string{"command"}, - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "command": toolspec.StringProperty("The !ai command to run (with or without prefix)"), + "room_id": toolspec.StringProperty("Optional target room ID (defaults to the current room)"), + }, "command"), }, Type: ToolTypeBuiltin, Group: GroupBuilder, @@ -376,32 +357,13 @@ var SessionsSendTool = &Tool{ Name: "sessions_send", Description: "Send a message into another session. Prefer the sessionKey from sessions_list; label is fallback only.", Annotations: &mcp.ToolAnnotations{Title: "Send to Session"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "sessionKey": map[string]any{ - "type": "string", - "description": "Session identifier from sessions_list (preferred canonical target)", - }, - "label": map[string]any{ - "type": "string", - "description": "Session label fallback (can be ambiguous; sessionKey is preferred)", - }, - "agentId": map[string]any{ - "type": "string", - "description": "Agent id filter for label lookups", - }, - "message": map[string]any{ - "type": "string", - "description": "The message to send", - }, - "timeoutSeconds": map[string]any{ - "type": "number", - "description": "Optional timeout for the remote session", - }, - }, - "required": []string{"message"}, - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "sessionKey": toolspec.StringProperty("Session identifier from sessions_list (preferred canonical target)"), + "label": toolspec.StringProperty("Session label fallback (can be ambiguous; sessionKey is preferred)"), + "agentId": toolspec.StringProperty("Agent id filter for label lookups"), + "message": toolspec.StringProperty("The message to send"), + "timeoutSeconds": toolspec.NumberProperty("Optional timeout for the remote session"), + }, "message"), }, Type: ToolTypeBuiltin, Group: GroupSessions, diff --git a/pkg/bridgeadapter/helpers.go b/pkg/bridgeadapter/helpers.go new file mode 100644 index 00000000..20d0d52e --- /dev/null +++ b/pkg/bridgeadapter/helpers.go @@ -0,0 +1,31 @@ +package bridgeadapter + +import ( + "go.mau.fi/util/ptr" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" +) + +func BuildMetaTypes(portal, message, userLogin, ghost func() any) database.MetaTypes { + return database.MetaTypes{ + Portal: portal, + Message: message, + UserLogin: userLogin, + Ghost: ghost, + } +} + +func BuildChatInfoWithFallback(metaTitle, portalName, fallbackTitle, portalTopic string) *bridgev2.ChatInfo { + title := metaTitle + if title == "" { + if portalName != "" { + title = portalName + } else { + title = fallbackTitle + } + } + return &bridgev2.ChatInfo{ + Name: ptr.Ptr(title), + Topic: ptr.NonZero(portalTopic), + } +} diff --git a/pkg/connector/client.go b/pkg/connector/client.go index f7c44662..f26846e9 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -1078,19 +1078,7 @@ func (oc *AIClient) IsThisUser(ctx context.Context, userID networkid.UserID) boo func (oc *AIClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { meta := portalMeta(portal) - title := meta.Title - if title == "" { - if portal.Name != "" { - title = portal.Name - } else { - title = "AI Chat" - } - } - // Use actual portal.Topic, not SystemPrompt (they are separate concepts) - return &bridgev2.ChatInfo{ - Name: ptr.Ptr(title), - Topic: ptr.NonZero(portal.Topic), - }, nil + return bridgeadapter.BuildChatInfoWithFallback(meta.Title, portal.Name, "AI Chat", portal.Topic), nil } func (oc *AIClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index b86e4561..5b3c0f46 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -392,12 +392,12 @@ func (oc *OpenAIConnector) GetConfig() (example string, data any, upgrader confi } func (oc *OpenAIConnector) GetDBMetaTypes() database.MetaTypes { - return database.MetaTypes{ - Portal: func() any { return &PortalMetadata{} }, - Message: func() any { return &MessageMetadata{} }, - UserLogin: func() any { return &UserLoginMetadata{} }, - Ghost: func() any { return &GhostMetadata{} }, - } + return bridgeadapter.BuildMetaTypes( + func() any { return &PortalMetadata{} }, + func() any { return &MessageMetadata{} }, + func() any { return &UserLoginMetadata{} }, + func() any { return &GhostMetadata{} }, + ) } func (oc *OpenAIConnector) LoadUserLogin(ctx context.Context, login *bridgev2.UserLogin) error { diff --git a/pkg/connector/streaming_chat_completions.go b/pkg/connector/streaming_chat_completions.go index 0f249544..e2b1f534 100644 --- a/pkg/connector/streaming_chat_completions.go +++ b/pkg/connector/streaming_chat_completions.go @@ -331,22 +331,7 @@ func (oc *AIClient) streamChatCompletions( } oc.uiEmitter(state).EmitUIToolInputAvailable(ctx, portal, tool.callID, toolName, inputMap, false) - // Track tool call in metadata - completedAt := time.Now().UnixMilli() - resultEventID := oc.sendToolResultEvent(ctx, portal, state, tool, result, resultStatus) - state.toolCalls = append(state.toolCalls, ToolCallMetadata{ - CallID: tool.callID, - ToolName: toolName, - ToolType: string(tool.toolType), - Input: parseToolInputPayload(argsJSON), - Output: map[string]any{"result": result}, - Status: string(ToolStatusCompleted), - ResultStatus: string(resultStatus), - StartedAtMs: tool.startedAtMs, - CompletedAtMs: completedAt, - CallEventID: string(tool.eventID), - ResultEventID: string(resultEventID), - }) + recordCompletedToolCall(ctx, oc, portal, state, tool, toolName, argsJSON, result, resultStatus) if resultStatus == ResultStatusSuccess { collectToolOutputCitations(state, toolName, result) diff --git a/pkg/connector/streaming_function_calls.go b/pkg/connector/streaming_function_calls.go index f17900f6..cc1d01e8 100644 --- a/pkg/connector/streaming_function_calls.go +++ b/pkg/connector/streaming_function_calls.go @@ -254,7 +254,20 @@ func (oc *AIClient) handleFunctionCallArgumentsDone( oc.uiEmitter(state).EmitUIToolOutputError(ctx, portal, tool.callID, result, tool.toolType == ToolTypeProvider) } - // Track tool call in metadata. + recordCompletedToolCall(ctx, oc, portal, state, tool, toolName, argsJSON, result, resultStatus) +} + +func recordCompletedToolCall( + ctx context.Context, + oc *AIClient, + portal *bridgev2.Portal, + state *streamingState, + tool *activeToolCall, + toolName string, + argsJSON string, + result string, + resultStatus ResultStatus, +) { completedAt := time.Now().UnixMilli() resultEventID := oc.sendToolResultEvent(ctx, portal, state, tool, result, resultStatus) state.toolCalls = append(state.toolCalls, ToolCallMetadata{ diff --git a/pkg/connector/streaming_output_items.go b/pkg/connector/streaming_output_items.go index 816fd031..0c456810 100644 --- a/pkg/connector/streaming_output_items.go +++ b/pkg/connector/streaming_output_items.go @@ -70,16 +70,7 @@ func deriveToolDescriptorForOutputItem(item responses.ResponseOutputItemUnion, s } switch item.Type { case "function_call": - desc.callID = strings.TrimSpace(item.CallID) - if desc.callID == "" { - desc.callID = item.ID - } - desc.toolName = strings.TrimSpace(item.Name) - desc.toolType = ToolTypeFunction - desc.providerExecuted = false - desc.dynamic = false - desc.input = parseJSONOrRaw(item.Arguments) - desc.ok = desc.toolName != "" + desc = responseFunctionToolDescriptor(item, false, parseJSONOrRaw(item.Arguments)) case "web_search_call": desc.toolName = ToolNameWebSearch desc.toolType = ToolTypeProvider @@ -120,16 +111,7 @@ func deriveToolDescriptorForOutputItem(item responses.ResponseOutputItemUnion, s case "apply_patch_call": desc = providerDynamicResponseToolDescriptor(item, "apply_patch") case "custom_tool_call": - desc.callID = strings.TrimSpace(item.CallID) - if desc.callID == "" { - desc.callID = item.ID - } - desc.toolName = strings.TrimSpace(item.Name) - desc.toolType = ToolTypeFunction - desc.providerExecuted = false - desc.dynamic = true - desc.input = parseJSONOrRaw(item.Input) - desc.ok = desc.toolName != "" + desc = responseFunctionToolDescriptor(item, true, parseJSONOrRaw(item.Input)) case "mcp_call": desc.toolName = "mcp." + strings.TrimSpace(item.Name) desc.toolType = ToolTypeMCP @@ -169,6 +151,24 @@ func deriveToolDescriptorForOutputItem(item responses.ResponseOutputItemUnion, s return desc } +func responseFunctionToolDescriptor(item responses.ResponseOutputItemUnion, dynamic bool, input any) responseToolDescriptor { + callID := strings.TrimSpace(item.CallID) + if callID == "" { + callID = item.ID + } + toolName := strings.TrimSpace(item.Name) + return responseToolDescriptor{ + itemID: item.ID, + callID: callID, + toolName: toolName, + toolType: ToolTypeFunction, + input: input, + providerExecuted: false, + dynamic: dynamic, + ok: toolName != "", + } +} + func providerDynamicResponseToolDescriptor(item responses.ResponseOutputItemUnion, toolName string) responseToolDescriptor { callID := strings.TrimSpace(item.CallID) if callID == "" { diff --git a/pkg/connector/subagent_spawn.go b/pkg/connector/subagent_spawn.go index b3aaf04e..4d4d4eb2 100644 --- a/pkg/connector/subagent_spawn.go +++ b/pkg/connector/subagent_spawn.go @@ -54,14 +54,6 @@ func (oc *AIClient) resolveSubagentAllowlist(ctx context.Context, requesterAgent return allowAny, allowSet } -func resolveSubagentModel(override string, agent *agents.AgentDefinition, defaults *agents.SubagentConfig) string { - return resolveSubagentConfigValue(override, agent, defaults, "model") -} - -func resolveSubagentThinking(override string, agent *agents.AgentDefinition, defaults *agents.SubagentConfig) string { - return resolveSubagentConfigValue(override, agent, defaults, "thinking") -} - func resolveSubagentConfigValue(override string, agent *agents.AgentDefinition, defaults *agents.SubagentConfig, field string) string { return firstNonEmptyTrimmed(override, subagentStringValue(agent, field), subagentStringValue(defaults, field)) } @@ -264,7 +256,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P if oc.connector != nil && oc.connector.Config.Agents != nil && oc.connector.Config.Agents.Defaults != nil { defaultSubagents = oc.connector.Config.Agents.Defaults.Subagents } - thinkingCandidate := resolveSubagentThinking(thinkingOverride, targetAgent, defaultSubagents) + thinkingCandidate := resolveSubagentConfigValue(thinkingOverride, targetAgent, defaultSubagents, "thinking") thinkingLevel, ok := normalizeThinkingLevel(thinkingCandidate) if !ok { return tools.JSONResult(map[string]any{ @@ -274,7 +266,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P } reasoningEffort := mapThinkingToReasoningEffort(thinkingLevel) - modelCandidate := resolveSubagentModel(modelOverride, targetAgent, defaultSubagents) + modelCandidate := resolveSubagentConfigValue(modelOverride, targetAgent, defaultSubagents, "model") resolvedModel := "" modelWarning := "" diff --git a/pkg/shared/toolspec/toolspec.go b/pkg/shared/toolspec/toolspec.go index 5a65200e..361765cc 100644 --- a/pkg/shared/toolspec/toolspec.go +++ b/pkg/shared/toolspec/toolspec.go @@ -133,42 +133,19 @@ func WebFetchSchema() map[string]any { // ReadSchema returns the JSON schema for the read tool. func ReadSchema() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{ - "type": "string", - "description": "Path to the file to read (relative or absolute)", - }, - "offset": map[string]any{ - "type": "number", - "description": "Line number to start reading from (1-indexed)", - }, - "limit": map[string]any{ - "type": "number", - "description": "Maximum number of lines to read", - }, - }, - "required": []string{"path"}, - } + return ObjectSchema(map[string]any{ + "path": StringProperty("Path to the file to read (relative or absolute)"), + "offset": NumberProperty("Line number to start reading from (1-indexed)"), + "limit": NumberProperty("Maximum number of lines to read"), + }, "path") } // WriteSchema returns the JSON schema for the write tool. func WriteSchema() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{ - "type": "string", - "description": "Path to the file to write (relative or absolute)", - }, - "content": map[string]any{ - "type": "string", - "description": "Content to write to the file", - }, - }, - "required": []string{"path", "content"}, - } + return ObjectSchema(map[string]any{ + "path": StringProperty("Path to the file to write (relative or absolute)"), + "content": StringProperty("Content to write to the file"), + }, "path", "content") } // EditSchema returns the JSON schema for the edit tool. @@ -566,31 +543,44 @@ func ImageGenerateSchema() map[string]any { // TTSSchema returns the JSON schema for the tts tool. func TTSSchema() map[string]any { + return ObjectSchema(map[string]any{ + "async": BooleanProperty("Optional: if true, start TTS in the background and send the audio to the chat when ready (tool returns immediately)."), + "text": StringProperty("Text to convert to speech."), + "voice": StringProperty("Optional: preferred voice (OpenAI voices: alloy, ash, coral, echo, fable, onyx, nova, sage, shimmer)."), + "model": StringProperty("Optional: TTS model (e.g. tts-1-hd, tts-1)."), + "channel": StringProperty("Optional channel id to pick output format (e.g. telegram)."), + }, "text") +} + +func ObjectSchema(properties map[string]any, required ...string) map[string]any { + schema := map[string]any{ + "type": "object", + "properties": properties, + } + if len(required) > 0 { + schema["required"] = required + } + return schema +} + +func StringProperty(description string) map[string]any { return map[string]any{ - "type": "object", - "properties": map[string]any{ - "async": map[string]any{ - "type": "boolean", - "description": "Optional: if true, start TTS in the background and send the audio to the chat when ready (tool returns immediately).", - }, - "text": map[string]any{ - "type": "string", - "description": "Text to convert to speech.", - }, - "voice": map[string]any{ - "type": "string", - "description": "Optional: preferred voice (OpenAI voices: alloy, ash, coral, echo, fable, onyx, nova, sage, shimmer).", - }, - "model": map[string]any{ - "type": "string", - "description": "Optional: TTS model (e.g. tts-1-hd, tts-1).", - }, - "channel": map[string]any{ - "type": "string", - "description": "Optional channel id to pick output format (e.g. telegram).", - }, - }, - "required": []string{"text"}, + "type": "string", + "description": description, + } +} + +func NumberProperty(description string) map[string]any { + return map[string]any{ + "type": "number", + "description": description, + } +} + +func BooleanProperty(description string) map[string]any { + return map[string]any{ + "type": "boolean", + "description": description, } } From fa299bac9d10cdbce1cf27ddf162752461bf5e2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 18:53:36 +0100 Subject: [PATCH 06/23] Centralize canonical parsers and web search Extract common parsing logic and DRY up duplicated code across bridges and connector packages. Added pkg/shared/citations/web_search.go to centralize web-search result parsing and updated callers to use citations.ExtractWebSearchCitations. Moved canonical UI extraction helpers into bridges/opencode/opencodebridge/canonical_extract.go (reasoning text, tool calls, generated files) and introduced BuildDataPartMap to unify data part mapping/streaming. Added bridgeadapter helpers: ParseApprovalDecisionEvent and BuildSystemNotice and replaced local copies with these centralized functions. Introduced cancelAndStopTimer on openCodeInstance and used it in manager methods to simplify instance teardown. Misc. import and call-site adjustments to use the new shared helpers. --- bridges/codex/citations_collect.go | 40 +---- bridges/codex/client.go | 26 +-- bridges/opencode/client.go | 13 +- .../opencodebridge/backfill_canonical.go | 150 +----------------- .../opencodebridge/canonical_extract.go | 94 +++++++++++ .../opencode_canonical_stream.go | 18 ++- .../opencodebridge/opencode_instance_state.go | 15 ++ .../opencodebridge/opencode_manager.go | 32 +--- bridges/opencode/portal_send.go | 15 +- bridges/opencode/stream_canonical.go | 92 +---------- pkg/bridgeadapter/approval_decision.go | 18 ++- pkg/bridgeadapter/helpers.go | 17 ++ pkg/connector/chat.go | 14 +- pkg/connector/source_citations.go | 55 +------ pkg/shared/citations/web_search.go | 61 +++++++ 15 files changed, 239 insertions(+), 421 deletions(-) create mode 100644 bridges/opencode/opencodebridge/canonical_extract.go create mode 100644 pkg/shared/citations/web_search.go diff --git a/bridges/codex/citations_collect.go b/bridges/codex/citations_collect.go index 6e4d6e4c..2b03d452 100644 --- a/bridges/codex/citations_collect.go +++ b/bridges/codex/citations_collect.go @@ -1,7 +1,6 @@ package codex import ( - "encoding/json" neturl "net/url" "path/filepath" "strings" @@ -193,42 +192,5 @@ func extractWebSearchCitationsFromToolOutput(toolName, output string) []citation if normalizeToolAlias(strings.TrimSpace(toolName)) != "websearch" { return nil } - output = strings.TrimSpace(output) - if output == "" || !strings.HasPrefix(output, "{") { - return nil - } - var payload map[string]any - if err := json.Unmarshal([]byte(output), &payload); err != nil { - return nil - } - rawResults, ok := payload["results"].([]any) - if !ok || len(rawResults) == 0 { - return nil - } - result := make([]citations.SourceCitation, 0, len(rawResults)) - for _, item := range rawResults { - m, ok := item.(map[string]any) - if !ok { - continue - } - url, _ := m["url"].(string) - url = strings.TrimSpace(url) - if url == "" { - continue - } - parsedURL, err := neturl.Parse(url) - if err != nil || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") { - continue - } - title, _ := m["title"].(string) - description, _ := m["description"].(string) - siteName, _ := m["siteName"].(string) - result = append(result, citations.SourceCitation{ - URL: url, - Title: strings.TrimSpace(title), - Description: strings.TrimSpace(description), - SiteName: strings.TrimSpace(siteName), - }) - } - return result + return citations.ExtractWebSearchCitations(output) } diff --git a/bridges/codex/client.go b/bridges/codex/client.go index c8fca67a..0eff4cd0 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -1679,21 +1679,10 @@ func (cc *CodexClient) sendSystemNotice(ctx context.Context, portal *bridgev2.Po if portal == nil || portal.MXID == "" || cc.UserLogin == nil || cc.UserLogin.Bridge == nil { return } - converted := &bridgev2.ConvertedMessage{ - Parts: []*bridgev2.ConvertedMessagePart{{ - ID: networkid.PartID("0"), - Type: event.EventMessage, - Content: &event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: strings.TrimSpace(message), - Mentions: &event.Mentions{}, - }, - }}, - } bg := cc.backgroundContext(ctx) sendCtx, cancel := context.WithTimeout(bg, 10*time.Second) defer cancel() - cc.sendViaPortal(sendCtx, portal, converted, "") + cc.sendViaPortal(sendCtx, portal, bridgeadapter.BuildSystemNotice(strings.TrimSpace(message)), "") } func (cc *CodexClient) sendApprovalRequestFallbackEvent( @@ -2269,7 +2258,7 @@ func (cc *CodexClient) handleApprovalRequest( } func (cc *CodexClient) tryApprovalDecisionEvent(ctx context.Context, msg *bridgev2.MatrixMessage) (bool, *bridgev2.MatrixMessageResponse) { - raw, ok := parseCodexApprovalDecision(msg.Event) + raw, ok := bridgeadapter.ParseApprovalDecisionEvent(msg.Event) if !ok { return false, nil } @@ -2295,17 +2284,6 @@ func (cc *CodexClient) tryApprovalDecisionEvent(ctx context.Context, msg *bridge return true, &bridgev2.MatrixMessageResponse{Pending: false} } -func parseCodexApprovalDecision(evt *event.Event) (map[string]any, bool) { - if evt == nil || evt.Content.Raw == nil { - return nil, false - } - raw, ok := evt.Content.Raw["com.beeper.ai.approval_decision"].(map[string]any) - if !ok { - return nil, false - } - return raw, true -} - func (cc *CodexClient) handleCommandApprovalRequest(ctx context.Context, req codexrpc.Request) (any, *codexrpc.RPCError) { return cc.handleApprovalRequest(ctx, req, "commandExecution", func(raw json.RawMessage) map[string]any { var p struct { diff --git a/bridges/opencode/client.go b/bridges/opencode/client.go index 86d45efe..5d141501 100644 --- a/bridges/opencode/client.go +++ b/bridges/opencode/client.go @@ -147,7 +147,7 @@ func (oc *OpenCodeClient) tryApprovalDecisionEvent(ctx context.Context, msg *bri if oc == nil || oc.bridge == nil || msg == nil || msg.Event == nil || msg.Portal == nil { return false, nil } - raw, ok := parseOpenCodeApprovalDecision(msg.Event) + raw, ok := bridgeadapter.ParseApprovalDecisionEvent(msg.Event) if !ok { return false, nil } @@ -169,17 +169,6 @@ func (oc *OpenCodeClient) tryApprovalDecisionEvent(ctx context.Context, msg *bri return true, &bridgev2.MatrixMessageResponse{Pending: false} } -func parseOpenCodeApprovalDecision(evt *event.Event) (map[string]any, bool) { - if evt == nil || evt.Content.Raw == nil { - return nil, false - } - raw, ok := evt.Content.Raw["com.beeper.ai.approval_decision"].(map[string]any) - if !ok { - return nil, false - } - return raw, true -} - func (oc *OpenCodeClient) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { if oc.bridge == nil { return nil diff --git a/bridges/opencode/opencodebridge/backfill_canonical.go b/bridges/opencode/opencodebridge/backfill_canonical.go index 7832f0aa..33459851 100644 --- a/bridges/opencode/opencodebridge/backfill_canonical.go +++ b/bridges/opencode/opencodebridge/backfill_canonical.go @@ -1,14 +1,11 @@ package opencodebridge import ( - "encoding/json" - "slices" "strings" "maunium.net/go/mautrix/event" "github.com/beeper/ai-bridge/bridges/opencode/opencode" - "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/matrixevents" "github.com/beeper/ai-bridge/pkg/shared/streamui" "github.com/beeper/ai-bridge/pkg/shared/stringutil" @@ -85,9 +82,9 @@ func buildCanonicalAssistantBackfill(msg opencode.MessageWithParts, agentID stri CanonicalUIMessage: uiMessage, StartedAtMs: int64(msg.Info.Time.Created), CompletedAtMs: int64(msg.Info.Time.Completed), - ThinkingContent: canonicalReasoningTextBridge(uiMessage), - ToolCalls: canonicalToolCallsBridge(uiMessage), - GeneratedFiles: canonicalGeneratedFilesBridge(uiMessage), + ThinkingContent: CanonicalReasoningText(uiMessage), + ToolCalls: CanonicalToolCalls(uiMessage), + GeneratedFiles: CanonicalGeneratedFiles(uiMessage), }, } } @@ -238,139 +235,13 @@ func canonicalDataPart(part opencode.Part) map[string]any { if strings.TrimSpace(part.ID) == "" { return nil } - data := map[string]any{ - "type": "data-opencode-" + strings.TrimSpace(part.Type), - "id": strings.TrimSpace(part.ID), - } - switch part.Type { - case "step-finish": - if reason := strings.TrimSpace(part.Reason); reason != "" { - data["reason"] = reason - } - if part.Cost != 0 { - data["cost"] = part.Cost - } - case "patch": - if hash := strings.TrimSpace(part.Hash); hash != "" { - data["hash"] = hash - } - if len(part.Files) > 0 { - data["files"] = slices.Clone(part.Files) - } - case "snapshot": - if snapshot := strings.TrimSpace(part.Snapshot); snapshot != "" { - data["snapshot"] = snapshot - } - case "agent": - if name := strings.TrimSpace(part.Name); name != "" { - data["name"] = name - } - case "subtask": - if desc := strings.TrimSpace(part.Description); desc != "" { - data["description"] = desc - } - if prompt := strings.TrimSpace(part.Prompt); prompt != "" { - data["prompt"] = prompt - } - if agent := strings.TrimSpace(part.Agent); agent != "" { - data["agent"] = agent - } - case "retry": - if part.Attempt != 0 { - data["attempt"] = part.Attempt - } - if len(part.Error) > 0 { - data["error"] = string(part.Error) - } - case "compaction": - data["auto"] = part.Auto - default: + data := BuildDataPartMap(part) + if data == nil { return nil } return data } -func canonicalReasoningTextBridge(uiMessage map[string]any) string { - parts, _ := uiMessage["parts"].([]any) - var sb strings.Builder - for _, raw := range parts { - part, ok := raw.(map[string]any) - if !ok || strings.TrimSpace(stringValueBridge(part["type"])) != "reasoning" { - continue - } - text := strings.TrimSpace(stringValueBridge(part["text"])) - if text == "" { - continue - } - if sb.Len() > 0 { - sb.WriteString("\n") - } - sb.WriteString(text) - } - return sb.String() -} - -func canonicalToolCallsBridge(uiMessage map[string]any) []bridgeadapter.ToolCallMetadata { - parts, _ := uiMessage["parts"].([]any) - calls := make([]bridgeadapter.ToolCallMetadata, 0, len(parts)) - for _, raw := range parts { - part, ok := raw.(map[string]any) - if !ok || strings.TrimSpace(stringValueBridge(part["type"])) != "dynamic-tool" { - continue - } - call := bridgeadapter.ToolCallMetadata{ - CallID: strings.TrimSpace(stringValueBridge(part["toolCallId"])), - ToolName: strings.TrimSpace(stringValueBridge(part["toolName"])), - ToolType: "opencode", - Status: strings.TrimSpace(stringValueBridge(part["state"])), - } - if input, ok := part["input"].(map[string]any); ok { - call.Input = input - } - if output, ok := part["output"].(map[string]any); ok { - call.Output = output - } else if output := strings.TrimSpace(stringValueBridge(part["output"])); output != "" { - call.Output = map[string]any{"text": output} - } - switch call.Status { - case "output-available": - call.ResultStatus = "completed" - case "output-error": - call.ResultStatus = "error" - call.ErrorMessage = strings.TrimSpace(stringValueBridge(part["errorText"])) - case "output-denied": - call.ResultStatus = "denied" - case "approval-requested": - call.ResultStatus = "pending_approval" - default: - call.ResultStatus = call.Status - } - if call.CallID != "" { - calls = append(calls, call) - } - } - return calls -} - -func canonicalGeneratedFilesBridge(uiMessage map[string]any) []bridgeadapter.GeneratedFileRef { - parts, _ := uiMessage["parts"].([]any) - files := make([]bridgeadapter.GeneratedFileRef, 0, len(parts)) - for _, raw := range parts { - part, ok := raw.(map[string]any) - if !ok || strings.TrimSpace(stringValueBridge(part["type"])) != "file" { - continue - } - url := strings.TrimSpace(stringValueBridge(part["url"])) - if url == "" { - continue - } - files = append(files, bridgeadapter.GeneratedFileRef{ - URL: url, - MimeType: stringutil.FirstNonEmpty(strings.TrimSpace(stringValueBridge(part["mediaType"])), "application/octet-stream"), - }) - } - return files -} func backfillCost(msg opencode.MessageWithParts) float64 { if msg.Info.Cost != 0 { @@ -418,17 +289,6 @@ func backfillTotalTokens(msg opencode.MessageWithParts) int64 { return backfillPromptTokens(msg) + backfillCompletionTokens(msg) + backfillReasoningTokens(msg) } -func stringValueBridge(raw any) string { - switch value := raw.(type) { - case string: - return value - case json.Number: - return value.String() - default: - return "" - } -} - func buildCanonicalBackfillPart(snapshot canonicalBackfillSnapshot) *event.MessageEventContent { return &event.MessageEventContent{ MsgType: event.MsgText, diff --git a/bridges/opencode/opencodebridge/canonical_extract.go b/bridges/opencode/opencodebridge/canonical_extract.go new file mode 100644 index 00000000..e297ab7d --- /dev/null +++ b/bridges/opencode/opencodebridge/canonical_extract.go @@ -0,0 +1,94 @@ +package opencodebridge + +import ( + "strings" + + "github.com/beeper/ai-bridge/pkg/bridgeadapter" + "github.com/beeper/ai-bridge/pkg/shared/maputil" + "github.com/beeper/ai-bridge/pkg/shared/stringutil" +) + +// CanonicalReasoningText extracts and joins all reasoning-type text from a canonical UI message. +func CanonicalReasoningText(uiMessage map[string]any) string { + parts, _ := uiMessage["parts"].([]any) + var sb strings.Builder + for _, raw := range parts { + part, ok := raw.(map[string]any) + if !ok || maputil.StringArg(part, "type") != "reasoning" { + continue + } + text := maputil.StringArg(part, "text") + if text == "" { + continue + } + if sb.Len() > 0 { + sb.WriteString("\n") + } + sb.WriteString(text) + } + return sb.String() +} + +// CanonicalGeneratedFiles extracts file references from a canonical UI message. +func CanonicalGeneratedFiles(uiMessage map[string]any) []bridgeadapter.GeneratedFileRef { + parts, _ := uiMessage["parts"].([]any) + var refs []bridgeadapter.GeneratedFileRef + for _, raw := range parts { + part, ok := raw.(map[string]any) + if !ok || maputil.StringArg(part, "type") != "file" { + continue + } + url := maputil.StringArg(part, "url") + if url == "" { + continue + } + refs = append(refs, bridgeadapter.GeneratedFileRef{ + URL: url, + MimeType: stringutil.FirstNonEmpty(maputil.StringArg(part, "mediaType"), "application/octet-stream"), + }) + } + return refs +} + +// CanonicalToolCalls extracts tool call metadata from a canonical UI message. +func CanonicalToolCalls(uiMessage map[string]any) []bridgeadapter.ToolCallMetadata { + parts, _ := uiMessage["parts"].([]any) + var calls []bridgeadapter.ToolCallMetadata + for _, raw := range parts { + part, ok := raw.(map[string]any) + if !ok || maputil.StringArg(part, "type") != "dynamic-tool" { + continue + } + call := bridgeadapter.ToolCallMetadata{ + CallID: maputil.StringArg(part, "toolCallId"), + ToolName: maputil.StringArg(part, "toolName"), + ToolType: "opencode", + Status: maputil.StringArg(part, "state"), + } + if input, ok := part["input"].(map[string]any); ok { + call.Input = input + } + if output, ok := part["output"].(map[string]any); ok { + call.Output = output + } else if text := maputil.StringArg(part, "output"); text != "" { + call.Output = map[string]any{"text": text} + } + switch call.Status { + case "output-available": + call.ResultStatus = "completed" + case "output-denied": + call.ResultStatus = "denied" + case "output-error": + call.ResultStatus = "error" + call.ErrorMessage = maputil.StringArg(part, "errorText") + case "approval-requested": + call.ResultStatus = "pending_approval" + default: + call.ResultStatus = call.Status + } + if call.CallID != "" { + calls = append(calls, call) + } + } + return calls +} diff --git a/bridges/opencode/opencodebridge/opencode_canonical_stream.go b/bridges/opencode/opencodebridge/opencode_canonical_stream.go index f8e43669..ea123daa 100644 --- a/bridges/opencode/opencodebridge/opencode_canonical_stream.go +++ b/bridges/opencode/opencodebridge/opencode_canonical_stream.go @@ -101,6 +101,18 @@ func (m *OpenCodeManager) emitDataPartStream(ctx context.Context, inst *openCode if state := inst.partState(part.SessionID, part.ID); state != nil && state.dataStreamSent { return } + data := BuildDataPartMap(part) + if data == nil { + return + } + turnID := partTurnID(part) + m.bridge.emitOpenCodeStreamEvent(ctx, portal, turnID, m.bridge.portalAgentID(portal), data) + inst.markPartDataStreamSent(part.SessionID, part.ID) +} + +// BuildDataPartMap builds a map representation of an opencode data part for streaming or backfill. +// Returns nil for unknown part types. +func BuildDataPartMap(part opencode.Part) map[string]any { data := map[string]any{ "type": "data-opencode-" + strings.TrimSpace(part.Type), "id": part.ID, @@ -147,8 +159,8 @@ func (m *OpenCodeManager) emitDataPartStream(ctx context.Context, inst *openCode } case "compaction": data["auto"] = part.Auto + default: + return nil } - turnID := partTurnID(part) - m.bridge.emitOpenCodeStreamEvent(ctx, portal, turnID, m.bridge.portalAgentID(portal), data) - inst.markPartDataStreamSent(part.SessionID, part.ID) + return data } diff --git a/bridges/opencode/opencodebridge/opencode_instance_state.go b/bridges/opencode/opencodebridge/opencode_instance_state.go index bd666020..52b59caf 100644 --- a/bridges/opencode/opencodebridge/opencode_instance_state.go +++ b/bridges/opencode/opencodebridge/opencode_instance_state.go @@ -74,6 +74,21 @@ type openCodeInstance struct { sendQueue map[string]*openCodeSessionQueue } +// cancelAndStopTimer cancels the instance's event loop and stops its disconnect timer. +func (inst *openCodeInstance) cancelAndStopTimer() { + if inst.cancel != nil { + inst.cancel() + } + inst.cancel = nil + inst.connected = false + inst.disconnectMu.Lock() + if inst.disconnectTimer != nil { + inst.disconnectTimer.Stop() + inst.disconnectTimer = nil + } + inst.disconnectMu.Unlock() +} + // ---------- seen-message helpers ---------- func (inst *openCodeInstance) isSeen(sessionID, messageID string) bool { diff --git a/bridges/opencode/opencodebridge/opencode_manager.go b/bridges/opencode/opencodebridge/opencode_manager.go index 141546fb..adef84df 100644 --- a/bridges/opencode/opencodebridge/opencode_manager.go +++ b/bridges/opencode/opencodebridge/opencode_manager.go @@ -89,17 +89,7 @@ func (m *OpenCodeManager) DisconnectAll() { if inst == nil { continue } - if inst.cancel != nil { - inst.cancel() - } - inst.cancel = nil - inst.connected = false - inst.disconnectMu.Lock() - if inst.disconnectTimer != nil { - inst.disconnectTimer.Stop() - inst.disconnectTimer = nil - } - inst.disconnectMu.Unlock() + inst.cancelAndStopTimer() } m.instances = make(map[string]*openCodeInstance) } @@ -173,15 +163,7 @@ func (m *OpenCodeManager) Connect(ctx context.Context, baseURL, password, userna m.mu.Lock() if existing := m.instances[instanceID]; existing != nil { - if existing.cancel != nil { - existing.cancel() - } - existing.disconnectMu.Lock() - if existing.disconnectTimer != nil { - existing.disconnectTimer.Stop() - existing.disconnectTimer = nil - } - existing.disconnectMu.Unlock() + existing.cancelAndStopTimer() } m.instances[instanceID] = inst m.mu.Unlock() @@ -232,15 +214,7 @@ func (m *OpenCodeManager) RemoveInstance(ctx context.Context, instanceID string) m.mu.Lock() if inst := m.instances[id]; inst != nil { hadInstance = true - if inst.cancel != nil { - inst.cancel() - } - inst.disconnectMu.Lock() - if inst.disconnectTimer != nil { - inst.disconnectTimer.Stop() - inst.disconnectTimer = nil - } - inst.disconnectMu.Unlock() + inst.cancelAndStopTimer() delete(m.instances, id) } m.mu.Unlock() diff --git a/bridges/opencode/portal_send.go b/bridges/opencode/portal_send.go index d3d3423c..e4a26b01 100644 --- a/bridges/opencode/portal_send.go +++ b/bridges/opencode/portal_send.go @@ -6,8 +6,6 @@ import ( "time" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) @@ -49,18 +47,7 @@ func (oc *OpenCodeClient) sendSystemNoticeViaPortal(ctx context.Context, portal if pmeta != nil { instanceID = pmeta.InstanceID } - converted := &bridgev2.ConvertedMessage{ - Parts: []*bridgev2.ConvertedMessagePart{{ - ID: networkid.PartID("0"), - Type: event.EventMessage, - Content: &event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: msg, - Mentions: &event.Mentions{}, - }, - }}, - } - if err := oc.sendViaPortal(ctx, portal, instanceID, converted); err != nil { + if err := oc.sendViaPortal(ctx, portal, instanceID, bridgeadapter.BuildSystemNotice(msg)); err != nil { oc.Log().Warn().Err(err).Msg("Failed to send system notice") } } diff --git a/bridges/opencode/stream_canonical.go b/bridges/opencode/stream_canonical.go index a58fa893..dd9eaa75 100644 --- a/bridges/opencode/stream_canonical.go +++ b/bridges/opencode/stream_canonical.go @@ -11,7 +11,7 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" - "github.com/beeper/ai-bridge/pkg/bridgeadapter" + "github.com/beeper/ai-bridge/bridges/opencode/opencodebridge" "github.com/beeper/ai-bridge/pkg/connector/msgconv" "github.com/beeper/ai-bridge/pkg/matrixevents" "github.com/beeper/ai-bridge/pkg/shared/maputil" @@ -116,7 +116,7 @@ func (oc *OpenCodeClient) buildStreamDBMetadata(state *openCodeStreamState) *Mes return nil } uiMessage := oc.currentCanonicalUIMessage(state) - thinking := canonicalReasoningText(uiMessage) + thinking := opencodebridge.CanonicalReasoningText(uiMessage) return &MessageMetadata{ Role: stringutil.FirstNonEmpty(state.role, "assistant"), Body: stringutil.FirstNonEmpty(state.visible.String(), state.accumulated.String()), @@ -141,8 +141,8 @@ func (oc *OpenCodeClient) buildStreamDBMetadata(state *openCodeStreamState) *Mes StartedAtMs: state.startedAtMs, CompletedAtMs: state.completedAtMs, ThinkingContent: thinking, - ToolCalls: canonicalToolCalls(uiMessage), - GeneratedFiles: canonicalGeneratedFiles(uiMessage), + ToolCalls: opencodebridge.CanonicalToolCalls(uiMessage), + GeneratedFiles: opencodebridge.CanonicalGeneratedFiles(uiMessage), } } @@ -225,87 +225,3 @@ func (oc *OpenCodeClient) queueFinalStreamEdit(ctx context.Context, portal *brid }) } -func canonicalReasoningText(uiMessage map[string]any) string { - parts, _ := uiMessage["parts"].([]any) - var sb strings.Builder - for _, raw := range parts { - part, ok := raw.(map[string]any) - if !ok { - continue - } - if strings.TrimSpace(maputil.StringArg(part, "type")) != "reasoning" { - continue - } - text := maputil.StringArg(part, "text") - if text == "" { - continue - } - if sb.Len() > 0 { - sb.WriteString("\n") - } - sb.WriteString(text) - } - return sb.String() -} - -func canonicalGeneratedFiles(uiMessage map[string]any) []bridgeadapter.GeneratedFileRef { - parts, _ := uiMessage["parts"].([]any) - var refs []bridgeadapter.GeneratedFileRef - for _, raw := range parts { - part, ok := raw.(map[string]any) - if !ok || strings.TrimSpace(maputil.StringArg(part, "type")) != "file" { - continue - } - url := maputil.StringArg(part, "url") - if url == "" { - continue - } - refs = append(refs, bridgeadapter.GeneratedFileRef{ - URL: url, - MimeType: stringutil.FirstNonEmpty(maputil.StringArg(part, "mediaType"), "application/octet-stream"), - }) - } - return refs -} - -func canonicalToolCalls(uiMessage map[string]any) []bridgeadapter.ToolCallMetadata { - parts, _ := uiMessage["parts"].([]any) - var calls []bridgeadapter.ToolCallMetadata - for _, raw := range parts { - part, ok := raw.(map[string]any) - if !ok || strings.TrimSpace(maputil.StringArg(part, "type")) != "dynamic-tool" { - continue - } - call := bridgeadapter.ToolCallMetadata{ - CallID: maputil.StringArg(part, "toolCallId"), - ToolName: maputil.StringArg(part, "toolName"), - ToolType: "opencode", - Status: maputil.StringArg(part, "state"), - } - if input, ok := part["input"].(map[string]any); ok { - call.Input = input - } - if output, ok := part["output"].(map[string]any); ok { - call.Output = output - } else if text := maputil.StringArg(part, "output"); text != "" { - call.Output = map[string]any{"text": text} - } - switch call.Status { - case "output-available": - call.ResultStatus = "completed" - case "output-denied": - call.ResultStatus = "denied" - case "output-error": - call.ResultStatus = "error" - call.ErrorMessage = maputil.StringArg(part, "errorText") - case "approval-requested": - call.ResultStatus = "pending_approval" - default: - call.ResultStatus = call.Status - } - if call.CallID != "" { - calls = append(calls, call) - } - } - return calls -} diff --git a/pkg/bridgeadapter/approval_decision.go b/pkg/bridgeadapter/approval_decision.go index a9e312b9..6225d0fb 100644 --- a/pkg/bridgeadapter/approval_decision.go +++ b/pkg/bridgeadapter/approval_decision.go @@ -1,6 +1,10 @@ package bridgeadapter -import "strings" +import ( + "strings" + + "maunium.net/go/mautrix/event" +) type ApprovalDecisionPayload struct { ApprovalID string @@ -9,6 +13,18 @@ type ApprovalDecisionPayload struct { Reason string } +// ParseApprovalDecisionEvent extracts the approval decision payload from a Matrix event's raw content. +func ParseApprovalDecisionEvent(evt *event.Event) (map[string]any, bool) { + if evt == nil || evt.Content.Raw == nil { + return nil, false + } + raw, ok := evt.Content.Raw["com.beeper.ai.approval_decision"].(map[string]any) + if !ok { + return nil, false + } + return raw, true +} + func ParseApprovalDecision(raw map[string]any) (ApprovalDecisionPayload, bool) { if raw == nil { return ApprovalDecisionPayload{}, false diff --git a/pkg/bridgeadapter/helpers.go b/pkg/bridgeadapter/helpers.go index 20d0d52e..01f2f51d 100644 --- a/pkg/bridgeadapter/helpers.go +++ b/pkg/bridgeadapter/helpers.go @@ -4,6 +4,8 @@ import ( "go.mau.fi/util/ptr" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" ) func BuildMetaTypes(portal, message, userLogin, ghost func() any) database.MetaTypes { @@ -15,6 +17,21 @@ func BuildMetaTypes(portal, message, userLogin, ghost func() any) database.MetaT } } +// BuildSystemNotice creates a ConvertedMessage containing a single MsgNotice part. +func BuildSystemNotice(body string) *bridgev2.ConvertedMessage { + return &bridgev2.ConvertedMessage{ + Parts: []*bridgev2.ConvertedMessagePart{{ + ID: networkid.PartID("0"), + Type: event.EventMessage, + Content: &event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: body, + Mentions: &event.Mentions{}, + }, + }}, + } +} + func BuildChatInfoWithFallback(metaTitle, portalName, fallbackTitle, portalTopic string) *bridgev2.ChatInfo { title := metaTitle if title == "" { diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index be3991b7..eaccaa0d 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -13,6 +13,7 @@ import ( "github.com/beeper/ai-bridge/pkg/agents" "github.com/beeper/ai-bridge/pkg/agents/tools" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/shared/stringutil" "github.com/beeper/ai-bridge/pkg/shared/toolspec" @@ -1751,18 +1752,7 @@ func (oc *AIClient) sendSystemNotice(ctx context.Context, portal *bridgev2.Porta if portal == nil || portal.MXID == "" { return } - converted := &bridgev2.ConvertedMessage{ - Parts: []*bridgev2.ConvertedMessagePart{{ - ID: networkid.PartID("0"), - Type: event.EventMessage, - Content: &event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: message, - Mentions: &event.Mentions{}, - }, - }}, - } - if _, _, err := oc.sendViaPortal(ctx, portal, converted, ""); err != nil { + if _, _, err := oc.sendViaPortal(ctx, portal, bridgeadapter.BuildSystemNotice(message), ""); err != nil { oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to send system notice") } } diff --git a/pkg/connector/source_citations.go b/pkg/connector/source_citations.go index a02752b6..ff563264 100644 --- a/pkg/connector/source_citations.go +++ b/pkg/connector/source_citations.go @@ -1,7 +1,6 @@ package connector import ( - "encoding/json" "mime" "net/url" "path/filepath" @@ -77,57 +76,5 @@ func extractWebSearchCitationsFromToolOutput(toolName, output string) []citation if normalizeToolAlias(strings.TrimSpace(toolName)) != ToolNameWebSearch { return nil } - output = strings.TrimSpace(output) - if output == "" || !strings.HasPrefix(output, "{") { - return nil - } - - var payload map[string]any - if err := json.Unmarshal([]byte(output), &payload); err != nil { - return nil - } - - rawResults, ok := payload["results"].([]any) - if !ok || len(rawResults) == 0 { - return nil - } - - result := make([]citations.SourceCitation, 0, len(rawResults)) - for _, rawResult := range rawResults { - entry, ok := rawResult.(map[string]any) - if !ok { - continue - } - urlStr := maputil.StringArg(entry, "url") - if urlStr == "" { - continue - } - parsed, err := url.Parse(urlStr) - if err != nil { - continue - } - switch parsed.Scheme { - case "http", "https": - default: - continue - } - title := maputil.StringArg(entry, "title") - description := maputil.StringArg(entry, "description") - published := maputil.StringArg(entry, "published") - siteName := maputil.StringArg(entry, "siteName") - author := maputil.StringArg(entry, "author") - image := maputil.StringArg(entry, "image") - favicon := maputil.StringArg(entry, "favicon") - result = append(result, citations.SourceCitation{ - URL: urlStr, - Title: title, - Description: description, - Published: published, - SiteName: siteName, - Author: author, - Image: image, - Favicon: favicon, - }) - } - return result + return citations.ExtractWebSearchCitations(output) } diff --git a/pkg/shared/citations/web_search.go b/pkg/shared/citations/web_search.go new file mode 100644 index 00000000..bbcea98a --- /dev/null +++ b/pkg/shared/citations/web_search.go @@ -0,0 +1,61 @@ +package citations + +import ( + "encoding/json" + "net/url" + "strings" + + "github.com/beeper/ai-bridge/pkg/shared/maputil" +) + +// ExtractWebSearchCitations parses a JSON tool output containing web search results +// and returns the extracted source citations. The output is expected to be a JSON object +// with a "results" array of objects containing url, title, description, etc. +func ExtractWebSearchCitations(output string) []SourceCitation { + output = strings.TrimSpace(output) + if output == "" || !strings.HasPrefix(output, "{") { + return nil + } + + var payload map[string]any + if err := json.Unmarshal([]byte(output), &payload); err != nil { + return nil + } + + rawResults, ok := payload["results"].([]any) + if !ok || len(rawResults) == 0 { + return nil + } + + result := make([]SourceCitation, 0, len(rawResults)) + for _, rawResult := range rawResults { + entry, ok := rawResult.(map[string]any) + if !ok { + continue + } + urlStr := maputil.StringArg(entry, "url") + if urlStr == "" { + continue + } + parsed, err := url.Parse(urlStr) + if err != nil { + continue + } + switch parsed.Scheme { + case "http", "https": + default: + continue + } + result = append(result, SourceCitation{ + URL: urlStr, + Title: maputil.StringArg(entry, "title"), + Description: maputil.StringArg(entry, "description"), + Published: maputil.StringArg(entry, "published"), + SiteName: maputil.StringArg(entry, "siteName"), + Author: maputil.StringArg(entry, "author"), + Image: maputil.StringArg(entry, "image"), + Favicon: maputil.StringArg(entry, "favicon"), + }) + } + return result +} From 4f6c1a9602162cade30ba6c9f075691275a04ff7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 18:58:05 +0100 Subject: [PATCH 07/23] Centralize DM chat info & debounced edits Introduce reusable helpers in pkg/bridgeadapter: BuildDMChatInfo and SendDebouncedStreamEdit (with param structs) to centralize DM ChatInfo construction and debounced edit queuing. Replace duplicated chat/members and debounced-edit logic in bridges (bridges/codex, bridges/opencode) and connector (chat, stream_transport, response_finalization, streaming_persistence) to use the new helpers. Move stream UI message construction into buildStreamUIMessage and simplify imports/removed duplicated code to reduce repetition and prepare for shared reuse. --- bridges/codex/client.go | 47 ++------ bridges/codex/stream_transport.go | 40 ++----- .../opencodebridge/opencode_portal.go | 56 ++------- pkg/bridgeadapter/helpers.go | 106 ++++++++++++++++++ pkg/connector/chat.go | 40 ++----- pkg/connector/response_finalization.go | 21 +--- pkg/connector/stream_transport.go | 40 ++----- pkg/connector/streaming_persistence.go | 31 +---- pkg/connector/streaming_ui_helpers.go | 26 +++++ 9 files changed, 187 insertions(+), 220 deletions(-) diff --git a/bridges/codex/client.go b/bridges/codex/client.go index 0eff4cd0..14d8d22b 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -1482,44 +1482,15 @@ func (cc *CodexClient) composeCodexChatInfo(title string) *bridgev2.ChatInfo { if title == "" { title = "Codex" } - members := bridgev2.ChatMemberMap{ - humanUserID(cc.UserLogin.ID): { - EventSender: bridgev2.EventSender{ - IsFromMe: true, - SenderLogin: cc.UserLogin.ID, - }, - Membership: event.MembershipJoin, - }, - codexGhostID: { - EventSender: bridgev2.EventSender{ - Sender: codexGhostID, - SenderLogin: cc.UserLogin.ID, - }, - Membership: event.MembershipJoin, - UserInfo: &bridgev2.UserInfo{ - Name: ptr.Ptr("Codex"), - IsBot: ptr.Ptr(true), - }, - MemberEventExtra: map[string]any{ - "displayname": "Codex", - }, - }, - } - return &bridgev2.ChatInfo{ - Name: ptr.Ptr(title), - Type: ptr.Ptr(database.RoomTypeDM), - Members: &bridgev2.ChatMemberList{ - IsFull: true, - OtherUserID: codexGhostID, - MemberMap: members, - PowerLevels: &bridgev2.PowerLevelOverrides{ - Events: map[event.Type]int{ - matrixevents.RoomCapabilitiesEventType: 100, - matrixevents.RoomSettingsEventType: 0, - }, - }, - }, - } + return bridgeadapter.BuildDMChatInfo(bridgeadapter.DMChatInfoParams{ + Title: title, + HumanUserID: humanUserID(cc.UserLogin.ID), + LoginID: cc.UserLogin.ID, + BotUserID: codexGhostID, + BotDisplayName: "Codex", + CapabilitiesEvent: matrixevents.RoomCapabilitiesEventType, + SettingsEvent: matrixevents.RoomSettingsEventType, + }) } func (cc *CodexClient) buildSandboxPolicy(cwd string) map[string]any { diff --git a/bridges/codex/stream_transport.go b/bridges/codex/stream_transport.go index 577c2424..32292ab2 100644 --- a/bridges/codex/stream_transport.go +++ b/bridges/codex/stream_transport.go @@ -2,12 +2,11 @@ package codex import ( "context" - "time" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/shared/streamtransport" ) @@ -15,34 +14,17 @@ func (cc *CodexClient) sendDebouncedStreamEdit(ctx context.Context, portal *brid if cc == nil || state == nil || portal == nil { return nil } - content := streamtransport.BuildDebouncedEditContent(streamtransport.DebouncedEditParams{ - PortalMXID: portal.MXID.String(), - Force: force, - SuppressSend: state.suppressSend, - VisibleBody: state.visibleAccumulated.String(), - FallbackBody: state.accumulated.String(), + return bridgeadapter.SendDebouncedStreamEdit(bridgeadapter.SendDebouncedStreamEditParams{ + Login: cc.UserLogin, + Portal: portal, + Sender: cc.senderForPortal(), + NetworkMessageID: state.networkMessageID, + SuppressSend: state.suppressSend, + VisibleBody: state.visibleAccumulated.String(), + FallbackBody: state.accumulated.String(), + LogKey: "codex_edit_target", + Force: force, }) - if content == nil || state.networkMessageID == "" { - return nil - } - sender := cc.senderForPortal() - cc.UserLogin.QueueRemoteEvent(&CodexRemoteEdit{ - Portal: portal.PortalKey, - Sender: sender, - TargetMessage: state.networkMessageID, - Timestamp: time.Now(), - LogKey: "codex_edit_target", - PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ - MsgType: event.MsgText, - Body: content.Body, - Format: content.Format, - FormattedBody: content.FormattedBody, - }, map[string]any{ - "com.beeper.dont_render_edited": true, - "m.mentions": map[string]any{}, - }), - }) - return nil } func (cc *CodexClient) ensureStreamSession(ctx context.Context, portal *bridgev2.Portal, state *streamingState) *streamtransport.StreamSession { diff --git a/bridges/opencode/opencodebridge/opencode_portal.go b/bridges/opencode/opencodebridge/opencode_portal.go index 9bfaf95c..92937162 100644 --- a/bridges/opencode/opencodebridge/opencode_portal.go +++ b/bridges/opencode/opencodebridge/opencode_portal.go @@ -6,12 +6,11 @@ import ( "strings" "github.com/google/uuid" - "go.mau.fi/util/ptr" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/event" "github.com/beeper/ai-bridge/bridges/opencode/opencode" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) func (b *Bridge) ensureOpenCodeSessionPortal(ctx context.Context, inst *openCodeInstance, session opencode.Session) error { @@ -133,49 +132,16 @@ func (b *Bridge) composeOpenCodeChatInfo(title, instanceID string) *bridgev2.Cha if login == nil { return nil } - displayName := b.opencodeDisplayName(instanceID) - ownUserID := b.host.HumanUserID(login.ID) - members := bridgev2.ChatMemberMap{ - ownUserID: { - EventSender: bridgev2.EventSender{ - IsFromMe: true, - SenderLogin: login.ID, - }, - Membership: event.MembershipJoin, - }, - OpenCodeUserID(instanceID): { - EventSender: bridgev2.EventSender{ - Sender: OpenCodeUserID(instanceID), - SenderLogin: login.ID, - }, - Membership: event.MembershipJoin, - UserInfo: &bridgev2.UserInfo{ - Name: ptr.Ptr(displayName), - IsBot: ptr.Ptr(true), - }, - MemberEventExtra: map[string]any{ - "displayname": displayName, - }, - }, - } - - return &bridgev2.ChatInfo{ - Name: ptr.Ptr(title), - Type: ptr.Ptr(database.RoomTypeDM), - Topic: nil, - Members: &bridgev2.ChatMemberList{ - IsFull: true, - OtherUserID: OpenCodeUserID(instanceID), - MemberMap: members, - PowerLevels: &bridgev2.PowerLevelOverrides{ - Events: map[event.Type]int{ - b.host.RoomCapabilitiesEventType(): 100, - b.host.RoomSettingsEventType(): 0, - }, - }, - }, - CanBackfill: true, - } + return bridgeadapter.BuildDMChatInfo(bridgeadapter.DMChatInfoParams{ + Title: title, + HumanUserID: b.host.HumanUserID(login.ID), + LoginID: login.ID, + BotUserID: OpenCodeUserID(instanceID), + BotDisplayName: b.opencodeDisplayName(instanceID), + CanBackfill: true, + CapabilitiesEvent: b.host.RoomCapabilitiesEventType(), + SettingsEvent: b.host.RoomSettingsEventType(), + }) } func (b *Bridge) createOpenCodeSessionChat(ctx context.Context, instanceID, title string, pendingTitle bool) (*bridgev2.CreateChatResponse, error) { diff --git a/pkg/bridgeadapter/helpers.go b/pkg/bridgeadapter/helpers.go index 01f2f51d..755f483e 100644 --- a/pkg/bridgeadapter/helpers.go +++ b/pkg/bridgeadapter/helpers.go @@ -1,11 +1,15 @@ package bridgeadapter import ( + "time" + "go.mau.fi/util/ptr" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" + + "github.com/beeper/ai-bridge/pkg/shared/streamtransport" ) func BuildMetaTypes(portal, message, userLogin, ghost func() any) database.MetaTypes { @@ -32,6 +36,108 @@ func BuildSystemNotice(body string) *bridgev2.ConvertedMessage { } } +// SendDebouncedStreamEditParams holds the parameters for SendDebouncedStreamEdit. +type SendDebouncedStreamEditParams struct { + Login *bridgev2.UserLogin + Portal *bridgev2.Portal + Sender bridgev2.EventSender + NetworkMessageID networkid.MessageID + SuppressSend bool + VisibleBody string + FallbackBody string + LogKey string + Force bool +} + +// SendDebouncedStreamEdit builds and queues a debounced stream edit via the bridge pipeline. +func SendDebouncedStreamEdit(p SendDebouncedStreamEditParams) error { + if p.Login == nil || p.Portal == nil { + return nil + } + content := streamtransport.BuildDebouncedEditContent(streamtransport.DebouncedEditParams{ + PortalMXID: p.Portal.MXID.String(), + Force: p.Force, + SuppressSend: p.SuppressSend, + VisibleBody: p.VisibleBody, + FallbackBody: p.FallbackBody, + }) + if content == nil || p.NetworkMessageID == "" { + return nil + } + p.Login.QueueRemoteEvent(&RemoteEdit{ + Portal: p.Portal.PortalKey, + Sender: p.Sender, + TargetMessage: p.NetworkMessageID, + Timestamp: time.Now(), + LogKey: p.LogKey, + PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ + MsgType: event.MsgText, + Body: content.Body, + Format: content.Format, + FormattedBody: content.FormattedBody, + }, map[string]any{ + "com.beeper.dont_render_edited": true, + "m.mentions": map[string]any{}, + }), + }) + return nil +} + +// DMChatInfoParams holds the parameters for BuildDMChatInfo. +type DMChatInfoParams struct { + Title string + HumanUserID networkid.UserID + LoginID networkid.UserLoginID + BotUserID networkid.UserID + BotDisplayName string + CanBackfill bool + CapabilitiesEvent event.Type + SettingsEvent event.Type +} + +// BuildDMChatInfo creates a ChatInfo for a DM room between a human user and a bot ghost. +func BuildDMChatInfo(p DMChatInfoParams) *bridgev2.ChatInfo { + members := bridgev2.ChatMemberMap{ + p.HumanUserID: { + EventSender: bridgev2.EventSender{ + IsFromMe: true, + SenderLogin: p.LoginID, + }, + Membership: event.MembershipJoin, + }, + p.BotUserID: { + EventSender: bridgev2.EventSender{ + Sender: p.BotUserID, + SenderLogin: p.LoginID, + }, + Membership: event.MembershipJoin, + UserInfo: &bridgev2.UserInfo{ + Name: ptr.Ptr(p.BotDisplayName), + IsBot: ptr.Ptr(true), + }, + MemberEventExtra: map[string]any{ + "displayname": p.BotDisplayName, + }, + }, + } + return &bridgev2.ChatInfo{ + Name: ptr.Ptr(p.Title), + Type: ptr.Ptr(database.RoomTypeDM), + CanBackfill: p.CanBackfill, + Members: &bridgev2.ChatMemberList{ + IsFull: true, + OtherUserID: p.BotUserID, + MemberMap: members, + PowerLevels: &bridgev2.PowerLevelOverrides{ + Events: map[event.Type]int{ + p.CapabilitiesEvent: 100, + p.SettingsEvent: 0, + }, + }, + }, + } +} + func BuildChatInfoWithFallback(metaTitle, portalName, fallbackTitle, portalTopic string) *bridgev2.ChatInfo { title := metaTitle if title == "" { diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index eaccaa0d..11aa37fe 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -1209,34 +1209,18 @@ func (oc *AIClient) composeChatInfo(title, modelID string) *bridgev2.ChatInfo { if title == "" { title = modelName } - members := bridgev2.ChatMemberMap{ - humanUserID(oc.UserLogin.ID): { - EventSender: bridgev2.EventSender{ - IsFromMe: true, - SenderLogin: oc.UserLogin.ID, - }, - Membership: event.MembershipJoin, - }, - modelUserID(modelID): modelJoinMember(oc.UserLogin.ID, modelID, modelName, modelInfo), - } - return &bridgev2.ChatInfo{ - Name: ptr.Ptr(title), - Topic: nil, // Topic managed via Matrix events, not system prompt - Type: ptr.Ptr(database.RoomTypeDM), - Members: &bridgev2.ChatMemberList{ - IsFull: true, - OtherUserID: modelUserID(modelID), - MemberMap: members, - // Set power levels so only bridge bot can modify room_capabilities (100) - // while any user can modify room_settings (0) - PowerLevels: &bridgev2.PowerLevelOverrides{ - Events: map[event.Type]int{ - RoomCapabilitiesEventType: 100, // Only bridge bot - RoomSettingsEventType: 0, // Any user - }, - }, - }, - } + chatInfo := bridgeadapter.BuildDMChatInfo(bridgeadapter.DMChatInfoParams{ + Title: title, + HumanUserID: humanUserID(oc.UserLogin.ID), + LoginID: oc.UserLogin.ID, + BotUserID: modelUserID(modelID), + BotDisplayName: modelName, + CapabilitiesEvent: RoomCapabilitiesEventType, + SettingsEvent: RoomSettingsEventType, + }) + // Override bot member with model-specific UserInfo and extra fields. + chatInfo.Members.MemberMap[modelUserID(modelID)] = modelJoinMember(oc.UserLogin.ID, modelID, modelName, modelInfo) + return chatInfo } func (oc *AIClient) applyAgentChatInfo(chatInfo *bridgev2.ChatInfo, agentID, agentName, modelID string) { diff --git a/pkg/connector/response_finalization.go b/pkg/connector/response_finalization.go index b32486d2..a4e0906d 100644 --- a/pkg/connector/response_finalization.go +++ b/pkg/connector/response_finalization.go @@ -18,7 +18,6 @@ import ( airuntime "github.com/beeper/ai-bridge/pkg/runtime" "github.com/beeper/ai-bridge/pkg/shared/citations" "github.com/beeper/ai-bridge/pkg/shared/streamtransport" - "github.com/beeper/ai-bridge/pkg/shared/streamui" ) const maxSafeEditPayloadBytes = 54 * 1024 @@ -600,25 +599,7 @@ func buildSourceParts(cits []citations.SourceCitation, documents []citations.Sou } func (oc *AIClient) buildFinalEditUIMessage(state *streamingState, meta *PortalMetadata, linkPreviews []*event.BeeperLinkPreview) map[string]any { - if state == nil { - return nil - } - if uiMessage := streamui.SnapshotCanonicalUIMessage(&state.ui); len(uiMessage) > 0 { - metadata, _ := uiMessage["metadata"].(map[string]any) - uiMessage["metadata"] = msgconv.MergeUIMessageMetadata(metadata, oc.buildUIMessageMetadata(state, meta, true)) - return msgconv.AppendUIMessageArtifacts( - uiMessage, - buildSourceParts(state.sourceCitations, state.sourceDocuments, linkPreviews), - citations.GeneratedFilesToParts(state.generatedFiles), - ) - } - return msgconv.BuildUIMessage(msgconv.UIMessageParams{ - TurnID: state.turnID, - Role: "assistant", - Metadata: oc.buildUIMessageMetadata(state, meta, true), - SourceURLs: buildSourceParts(state.sourceCitations, state.sourceDocuments, linkPreviews), - FileParts: citations.GeneratedFilesToParts(state.generatedFiles), - }) + return oc.buildStreamUIMessage(state, meta, linkPreviews) } // sendFinalAssistantTurnContent is a helper for simple mode that sends content without directive processing. diff --git a/pkg/connector/stream_transport.go b/pkg/connector/stream_transport.go index 731a042c..2908d2cc 100644 --- a/pkg/connector/stream_transport.go +++ b/pkg/connector/stream_transport.go @@ -2,45 +2,25 @@ package connector import ( "context" - "time" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/event" "github.com/beeper/ai-bridge/pkg/bridgeadapter" - "github.com/beeper/ai-bridge/pkg/shared/streamtransport" ) func (oc *AIClient) sendDebouncedStreamEdit(ctx context.Context, portal *bridgev2.Portal, state *streamingState, force bool) error { if oc == nil || state == nil || portal == nil { return nil } - content := streamtransport.BuildDebouncedEditContent(streamtransport.DebouncedEditParams{ - PortalMXID: portal.MXID.String(), - Force: force, - SuppressSend: state.suppressSend, - VisibleBody: state.visibleAccumulated.String(), - FallbackBody: state.accumulated.String(), + return bridgeadapter.SendDebouncedStreamEdit(bridgeadapter.SendDebouncedStreamEditParams{ + Login: oc.UserLogin, + Portal: portal, + Sender: oc.senderForPortal(ctx, portal), + NetworkMessageID: state.networkMessageID, + SuppressSend: state.suppressSend, + VisibleBody: state.visibleAccumulated.String(), + FallbackBody: state.accumulated.String(), + LogKey: "ai_edit_target", + Force: force, }) - if content == nil || state.networkMessageID == "" { - return nil - } - sender := oc.senderForPortal(ctx, portal) - oc.UserLogin.QueueRemoteEvent(&bridgeadapter.RemoteEdit{ - Portal: portal.PortalKey, - Sender: sender, - TargetMessage: state.networkMessageID, - Timestamp: time.Now(), - LogKey: "ai_edit_target", - PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ - MsgType: event.MsgText, - Body: content.Body, - Format: content.Format, - FormattedBody: content.FormattedBody, - }, map[string]any{ - "com.beeper.dont_render_edited": true, - "m.mentions": map[string]any{}, - }), - }) - return nil } diff --git a/pkg/connector/streaming_persistence.go b/pkg/connector/streaming_persistence.go index a5b27b4b..d520cdcc 100644 --- a/pkg/connector/streaming_persistence.go +++ b/pkg/connector/streaming_persistence.go @@ -11,9 +11,6 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" "github.com/beeper/ai-bridge/pkg/bridgeadapter" - "github.com/beeper/ai-bridge/pkg/connector/msgconv" - "github.com/beeper/ai-bridge/pkg/shared/citations" - "github.com/beeper/ai-bridge/pkg/shared/streamui" ) // saveAssistantMessage saves the completed assistant message to the database. @@ -152,31 +149,5 @@ func thinkingTokenCount(model string, content string) int { } func (oc *AIClient) buildCanonicalUIMessage(state *streamingState, meta *PortalMetadata) map[string]any { - if state == nil { - return nil - } - if uiMessage := streamui.SnapshotCanonicalUIMessage(&state.ui); len(uiMessage) > 0 { - metadata, _ := uiMessage["metadata"].(map[string]any) - uiMessage["metadata"] = msgconv.MergeUIMessageMetadata(metadata, msgconv.BuildUIMessageMetadata(msgconv.UIMessageMetadataParams{ - TurnID: state.turnID, - AgentID: state.agentID, - Model: oc.effectiveModel(meta), - FinishReason: state.finishReason, - PromptTokens: state.promptTokens, - CompletionTokens: state.completionTokens, - ReasoningTokens: state.reasoningTokens, - StartedAtMs: state.startedAtMs, - FirstTokenAtMs: state.firstTokenAtMs, - CompletedAtMs: state.completedAtMs, - IncludeUsage: true, - })) - return msgconv.AppendUIMessageArtifacts(uiMessage, buildSourceParts(state.sourceCitations, state.sourceDocuments, nil), citations.GeneratedFilesToParts(state.generatedFiles)) - } - return msgconv.BuildUIMessage(msgconv.UIMessageParams{ - TurnID: state.turnID, - Role: "assistant", - Metadata: oc.buildUIMessageMetadata(state, meta, true), - SourceURLs: buildSourceParts(state.sourceCitations, state.sourceDocuments, nil), - FileParts: citations.GeneratedFilesToParts(state.generatedFiles), - }) + return oc.buildStreamUIMessage(state, meta, nil) } diff --git a/pkg/connector/streaming_ui_helpers.go b/pkg/connector/streaming_ui_helpers.go index a6fb0ca5..4d4b67c7 100644 --- a/pkg/connector/streaming_ui_helpers.go +++ b/pkg/connector/streaming_ui_helpers.go @@ -5,7 +5,11 @@ import ( "unicode" "unicode/utf8" + "maunium.net/go/mautrix/event" + "github.com/beeper/ai-bridge/pkg/connector/msgconv" + "github.com/beeper/ai-bridge/pkg/shared/citations" + "github.com/beeper/ai-bridge/pkg/shared/streamui" ) func (oc *AIClient) buildUIMessageMetadata(state *streamingState, meta *PortalMetadata, includeUsage bool) map[string]any { @@ -25,6 +29,28 @@ func (oc *AIClient) buildUIMessageMetadata(state *streamingState, meta *PortalMe }) } +// buildStreamUIMessage constructs the canonical UI message for streaming edits and persistence. +// linkPreviews may be nil for intermediate saves. +func (oc *AIClient) buildStreamUIMessage(state *streamingState, meta *PortalMetadata, linkPreviews []*event.BeeperLinkPreview) map[string]any { + if state == nil { + return nil + } + sourceParts := buildSourceParts(state.sourceCitations, state.sourceDocuments, linkPreviews) + fileParts := citations.GeneratedFilesToParts(state.generatedFiles) + if uiMessage := streamui.SnapshotCanonicalUIMessage(&state.ui); len(uiMessage) > 0 { + metadata, _ := uiMessage["metadata"].(map[string]any) + uiMessage["metadata"] = msgconv.MergeUIMessageMetadata(metadata, oc.buildUIMessageMetadata(state, meta, true)) + return msgconv.AppendUIMessageArtifacts(uiMessage, sourceParts, fileParts) + } + return msgconv.BuildUIMessage(msgconv.UIMessageParams{ + TurnID: state.turnID, + Role: "assistant", + Metadata: oc.buildUIMessageMetadata(state, meta, true), + SourceURLs: sourceParts, + FileParts: fileParts, + }) +} + func mapFinishReason(reason string) string { return msgconv.MapFinishReason(reason) } From 8d688ad91eb7581f20ba16641aad3108fc035a60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 18:59:18 +0100 Subject: [PATCH 08/23] Remove stray blank lines Trim unnecessary blank lines in backfill_canonical.go, stream_canonical.go, and remote_events.go. Pure formatting cleanup with no functional changes. --- bridges/opencode/opencodebridge/backfill_canonical.go | 1 - bridges/opencode/stream_canonical.go | 1 - pkg/connector/remote_events.go | 1 - 3 files changed, 3 deletions(-) diff --git a/bridges/opencode/opencodebridge/backfill_canonical.go b/bridges/opencode/opencodebridge/backfill_canonical.go index 33459851..cf0766ad 100644 --- a/bridges/opencode/opencodebridge/backfill_canonical.go +++ b/bridges/opencode/opencodebridge/backfill_canonical.go @@ -242,7 +242,6 @@ func canonicalDataPart(part opencode.Part) map[string]any { return data } - func backfillCost(msg opencode.MessageWithParts) float64 { if msg.Info.Cost != 0 { return msg.Info.Cost diff --git a/bridges/opencode/stream_canonical.go b/bridges/opencode/stream_canonical.go index dd9eaa75..87d2d0e7 100644 --- a/bridges/opencode/stream_canonical.go +++ b/bridges/opencode/stream_canonical.go @@ -224,4 +224,3 @@ func (oc *OpenCodeClient) queueFinalStreamEdit(ctx context.Context, portal *brid }, topLevelExtra), }) } - diff --git a/pkg/connector/remote_events.go b/pkg/connector/remote_events.go index ff57e808..95ab7898 100644 --- a/pkg/connector/remote_events.go +++ b/pkg/connector/remote_events.go @@ -183,4 +183,3 @@ func NewAITextMessage( }, } } - From fba1eb2316d2f5cf88a242328f87ec13fb9b90dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:08:56 +0100 Subject: [PATCH 09/23] Add bridgeadapter helpers and assorted refactors Add several bridgeadapter utilities and refactor callsites to centralize common logic and modernize helpers. Key changes: - Add bridgeadapter helpers: MatrixMessageID, NewEventID, SendMatrixMessageStatus and LoadOrCreateTypedClient to centralize ID creation, status sending and typed client caching. - Introduce providerJSONToolOutput helper in Codex client to consolidate JSON tool-output handling and side effects. - Add streamtransport.BuildRenderedConvertedEdit and new RenderedMarkdownContent to unify building converted edits. - Replace many manual slice copies with slices.Clone and use exslices.DeduplicateUnsorted for dedupe logic to avoid accidental in-place mutations. - Use strings.Cut / CutPrefix patterns and range-based loops to simplify string/loop logic. - Minor error formatting fixes (use err instead of err.Error) and other small cleanups across connectors, integrators, and runtime code. - Remove some dead/duplicated functions and consolidate builtin tool caching into a safe once/cached initializer. These changes reduce duplication, improve type safety for cached clients, and standardize ID/status handling across the codebase. --- bridges/codex/client.go | 97 +++++++++++--------- bridges/codex/connector.go | 18 +--- bridges/opencode/connector.go | 14 +-- bridges/opencode/stream_canonical.go | 4 +- cmd/bridgectl/main.go | 2 +- pkg/agents/tools/results.go | 2 +- pkg/bridgeadapter/client_cache.go | 47 ++++++++++ pkg/bridgeadapter/helpers.go | 3 +- pkg/bridgeadapter/identifier_helpers.go | 6 ++ pkg/bridgeadapter/status_helpers.go | 13 +++ pkg/connector/agentstore.go | 5 +- pkg/connector/bootstrap_context.go | 2 +- pkg/connector/chat.go | 3 +- pkg/connector/client.go | 2 +- pkg/connector/commands.go | 2 +- pkg/connector/compaction_summarization.go | 3 +- pkg/connector/errors.go | 26 ------ pkg/connector/handleai.go | 12 +-- pkg/connector/handlematrix.go | 12 +-- pkg/connector/internal_dispatch.go | 8 +- pkg/connector/messages.go | 29 +----- pkg/connector/response_retry.go | 3 +- pkg/connector/room_runs.go | 2 +- pkg/connector/streaming_function_calls.go | 11 +-- pkg/connector/subagent_spawn.go | 7 +- pkg/connector/tools.go | 29 +++--- pkg/integrations/memory/flush_tool_loop.go | 3 +- pkg/integrations/memory/overflow_exec.go | 3 +- pkg/integrations/memory/prompt_exec.go | 3 +- pkg/runtime/compaction_overflow.go | 7 +- pkg/shared/streamtransport/converted_edit.go | 21 +++++ pkg/shared/streamtransport/session.go | 2 +- pkg/shared/stringutil/dedupe.go | 21 ++--- pkg/textfs/apply_patch.go | 13 +-- pkg/textfs/truncate.go | 2 +- 35 files changed, 229 insertions(+), 208 deletions(-) create mode 100644 pkg/shared/streamtransport/converted_edit.go diff --git a/bridges/codex/client.go b/bridges/codex/client.go index 14d8d22b..02a68b5f 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -487,7 +487,7 @@ func (cc *CodexClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Ma // Save user message immediately; we return Pending=true. userMsg := &database.Message{ - ID: networkid.MessageID(fmt.Sprintf("mx:%s", string(msg.Event.ID))), + ID: bridgeadapter.MatrixMessageID(msg.Event.ID), MXID: msg.Event.ID, Room: portal.PortalKey, SenderID: humanUserID(cc.UserLogin.ID), @@ -1125,33 +1125,15 @@ func (cc *CodexClient) handleItemCompleted(ctx context.Context, portal *bridgev2 } state.toolCalls = append(state.toolCalls, tc) case "collabToolCall": - var it map[string]any - _ = json.Unmarshal(raw, &it) - cc.uiEmitter(state).EmitUIToolOutputAvailable(ctx, portal, itemID, it, true, false) - newDocs, newFiles := collectToolOutputArtifacts(state, it) - emitNewArtifacts(ctx, portal, cc.uiEmitter(state), newDocs, newFiles) - state.toolCalls = append(state.toolCalls, newProviderToolCall(itemID, "collabToolCall", it)) + cc.emitProviderJSONToolOutput(ctx, portal, state, itemID, "collabToolCall", raw, providerJSONToolOutputOptions{collectArtifacts: true}) case "webSearch": - var it map[string]any - _ = json.Unmarshal(raw, &it) - cc.uiEmitter(state).EmitUIToolOutputAvailable(ctx, portal, itemID, it, true, false) - state.toolCalls = append(state.toolCalls, newProviderToolCall(itemID, "webSearch", it)) - // Extract web search citations and emit source-url stream events. - if outputJSON, err := json.Marshal(it); err == nil { - collectToolOutputCitations(state, "webSearch", string(outputJSON)) - for _, citation := range state.sourceCitations { - cc.uiEmitter(state).EmitUISourceURL(ctx, portal, citation) - } - } - newDocs, newFiles := collectToolOutputArtifacts(state, it) - emitNewArtifacts(ctx, portal, cc.uiEmitter(state), newDocs, newFiles) + cc.emitProviderJSONToolOutput(ctx, portal, state, itemID, "webSearch", raw, providerJSONToolOutputOptions{ + collectArtifacts: true, + collectCitations: true, + appendBeforeSideEffects: true, + }) case "imageView": - var it map[string]any - _ = json.Unmarshal(raw, &it) - cc.uiEmitter(state).EmitUIToolOutputAvailable(ctx, portal, itemID, it, true, false) - newDocs, newFiles := collectToolOutputArtifacts(state, it) - emitNewArtifacts(ctx, portal, cc.uiEmitter(state), newDocs, newFiles) - state.toolCalls = append(state.toolCalls, newProviderToolCall(itemID, "imageView", it)) + cc.emitProviderJSONToolOutput(ctx, portal, state, itemID, "imageView", raw, providerJSONToolOutputOptions{collectArtifacts: true}) case "plan": var it struct { Text string `json:"text"` @@ -1161,10 +1143,7 @@ func (cc *CodexClient) handleItemCompleted(ctx context.Context, portal *bridgev2 return } case "enteredReviewMode": - var it map[string]any - _ = json.Unmarshal(raw, &it) - cc.uiEmitter(state).EmitUIToolOutputAvailable(ctx, portal, itemID, it, true, false) - state.toolCalls = append(state.toolCalls, newProviderToolCall(itemID, "review", it)) + cc.emitProviderJSONToolOutput(ctx, portal, state, itemID, "review", raw, providerJSONToolOutputOptions{}) case "exitedReviewMode": var it struct { Review string `json:"review"` @@ -1174,14 +1153,52 @@ func (cc *CodexClient) handleItemCompleted(ctx context.Context, portal *bridgev2 return } case "contextCompaction": - var it map[string]any - _ = json.Unmarshal(raw, &it) - cc.uiEmitter(state).EmitUIToolOutputAvailable(ctx, portal, itemID, it, true, false) - state.toolCalls = append(state.toolCalls, newProviderToolCall(itemID, "contextCompaction", it)) + cc.emitProviderJSONToolOutput(ctx, portal, state, itemID, "contextCompaction", raw, providerJSONToolOutputOptions{}) cc.sendSystemNoticeOnce(ctx, portal, state, "compaction:completed:"+itemID, "Codex finished compacting context.") } } +type providerJSONToolOutputOptions struct { + collectArtifacts bool + collectCitations bool + appendBeforeSideEffects bool +} + +func (cc *CodexClient) emitProviderJSONToolOutput( + ctx context.Context, + portal *bridgev2.Portal, + state *streamingState, + itemID string, + toolName string, + raw []byte, + opts providerJSONToolOutputOptions, +) { + var it map[string]any + _ = json.Unmarshal(raw, &it) + cc.uiEmitter(state).EmitUIToolOutputAvailable(ctx, portal, itemID, it, true, false) + appendToolCall := func() { + state.toolCalls = append(state.toolCalls, newProviderToolCall(itemID, toolName, it)) + } + if opts.appendBeforeSideEffects { + appendToolCall() + } + if opts.collectCitations { + if outputJSON, err := json.Marshal(it); err == nil { + collectToolOutputCitations(state, toolName, string(outputJSON)) + for _, citation := range state.sourceCitations { + cc.uiEmitter(state).EmitUISourceURL(ctx, portal, citation) + } + } + } + if opts.collectArtifacts { + newDocs, newFiles := collectToolOutputArtifacts(state, it) + emitNewArtifacts(ctx, portal, cc.uiEmitter(state), newDocs, newFiles) + } + if !opts.appendBeforeSideEffects { + appendToolCall() + } +} + func (cc *CodexClient) emitTrimmedProviderToolTextOutput( ctx context.Context, portal *bridgev2.Portal, @@ -1730,23 +1747,20 @@ func (cc *CodexClient) sendApprovalRequestFallbackEvent( } func (cc *CodexClient) sendPendingStatus(ctx context.Context, portal *bridgev2.Portal, evt *event.Event, message string) { - if portal == nil || portal.Bridge == nil || evt == nil { - return - } st := bridgev2.MessageStatus{ Status: event.MessageStatusPending, Message: message, IsCertain: true, } - portal.Bridge.Matrix.SendMessageStatus(ctx, &st, bridgev2.StatusEventInfoFromEvent(evt)) + bridgeadapter.SendMatrixMessageStatus(ctx, portal, evt, st) } func (cc *CodexClient) markMessageSendSuccess(ctx context.Context, portal *bridgev2.Portal, evt *event.Event, state *streamingState) { - if portal == nil || portal.Bridge == nil || evt == nil || state == nil { + if state == nil { return } st := bridgev2.MessageStatus{Status: event.MessageStatusSuccess, IsCertain: true} - portal.Bridge.Matrix.SendMessageStatus(ctx, &st, bridgev2.StatusEventInfoFromEvent(evt)) + bridgeadapter.SendMatrixMessageStatus(ctx, portal, evt, st) } func (cc *CodexClient) acquireRoomIfQueueEmpty(roomID id.RoomID) bool { @@ -1976,8 +1990,7 @@ func (cc *CodexClient) sendFinalAssistantTurn(ctx context.Context, portal *bridg TargetMessage: state.networkMessageID, Timestamp: time.Now(), LogKey: "codex_edit_target", - PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ - MsgType: event.MsgText, + PreBuilt: streamtransport.BuildRenderedConvertedEdit(streamtransport.RenderedMarkdownContent{ Body: rendered.Body, Format: rendered.Format, FormattedBody: rendered.FormattedBody, diff --git a/bridges/codex/connector.go b/bridges/codex/connector.go index 19ac3896..9b6efcc4 100644 --- a/bridges/codex/connector.go +++ b/bridges/codex/connector.go @@ -314,20 +314,14 @@ func (cc *CodexConnector) loadCodexUserLogin(login *bridgev2.UserLogin) error { return nil } - client, err := bridgeadapter.LoadOrCreateClient( + client, err := bridgeadapter.LoadOrCreateTypedClient( &cc.clientsMu, cc.clients, - login.ID, - func(existingAPI bridgev2.NetworkAPI) bool { - existing, ok := existingAPI.(*CodexClient) - if !ok || existing == nil { - return false - } + login, + func(existing *CodexClient, login *bridgev2.UserLogin) { existing.UserLogin = login - login.Client = existing - return true }, - func() (bridgev2.NetworkAPI, error) { + func() (*CodexClient, error) { return newCodexClient(login, cc) }, ) @@ -336,9 +330,7 @@ func (cc *CodexConnector) loadCodexUserLogin(login *bridgev2.UserLogin) error { return nil } login.Client = client - if codexClient, ok := client.(*CodexClient); ok { - codexClient.scheduleBootstrap() - } + client.scheduleBootstrap() return nil } diff --git a/bridges/opencode/connector.go b/bridges/opencode/connector.go index 36724dd1..20b7bce7 100644 --- a/bridges/opencode/connector.go +++ b/bridges/opencode/connector.go @@ -99,20 +99,14 @@ func (oc *OpenCodeConnector) LoadUserLogin(_ context.Context, login *bridgev2.Us return nil } - client, err := bridgeadapter.LoadOrCreateClient( + client, err := bridgeadapter.LoadOrCreateTypedClient( &oc.clientsMu, oc.clients, - login.ID, - func(existingAPI bridgev2.NetworkAPI) bool { - existing, ok := existingAPI.(*OpenCodeClient) - if !ok || existing == nil { - return false - } + login, + func(existing *OpenCodeClient, login *bridgev2.UserLogin) { existing.UserLogin = login - login.Client = existing - return true }, - func() (bridgev2.NetworkAPI, error) { + func() (*OpenCodeClient, error) { return newOpenCodeClient(login, oc) }, ) diff --git a/bridges/opencode/stream_canonical.go b/bridges/opencode/stream_canonical.go index 87d2d0e7..598cead0 100644 --- a/bridges/opencode/stream_canonical.go +++ b/bridges/opencode/stream_canonical.go @@ -8,7 +8,6 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "github.com/beeper/ai-bridge/bridges/opencode/opencodebridge" @@ -216,8 +215,7 @@ func (oc *OpenCodeClient) queueFinalStreamEdit(ctx context.Context, portal *brid TargetMessage: state.networkMessageID, Timestamp: time.Now(), LogKey: "opencode_edit_target", - PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ - MsgType: event.MsgText, + PreBuilt: streamtransport.BuildRenderedConvertedEdit(streamtransport.RenderedMarkdownContent{ Body: rendered.Body, Format: rendered.Format, FormattedBody: rendered.FormattedBody, diff --git a/cmd/bridgectl/main.go b/cmd/bridgectl/main.go index 09384443..26660ebf 100644 --- a/cmd/bridgectl/main.go +++ b/cmd/bridgectl/main.go @@ -1071,7 +1071,7 @@ func setPath(root map[string]any, parts []string, value any) { return } cur := root - for i := 0; i < len(parts)-1; i++ { + for i := range len(parts) - 1 { key := parts[i] next, ok := cur[key] if !ok { diff --git a/pkg/agents/tools/results.go b/pkg/agents/tools/results.go index f4ace030..5430aaed 100644 --- a/pkg/agents/tools/results.go +++ b/pkg/agents/tools/results.go @@ -32,7 +32,7 @@ func ErrorResult(toolName, message string) *Result { func mustJSON(v any) string { data, err := json.Marshal(v) if err != nil { - return fmt.Sprintf(`{"error":"failed to marshal: %s"}`, err.Error()) + return fmt.Sprintf(`{"error":"failed to marshal: %s"}`, err) } return string(data) } diff --git a/pkg/bridgeadapter/client_cache.go b/pkg/bridgeadapter/client_cache.go index 5a9e6ef4..2fd760f6 100644 --- a/pkg/bridgeadapter/client_cache.go +++ b/pkg/bridgeadapter/client_cache.go @@ -2,6 +2,7 @@ package bridgeadapter import ( "context" + "fmt" "maps" "sync" @@ -49,6 +50,52 @@ func LoadOrCreateClient( return client, nil } +// LoadOrCreateTypedClient wraps LoadOrCreateClient with typed reuse/create callbacks. +func LoadOrCreateTypedClient[T bridgev2.NetworkAPI]( + mu *sync.Mutex, + clients map[networkid.UserLoginID]bridgev2.NetworkAPI, + login *bridgev2.UserLogin, + reuse func(T, *bridgev2.UserLogin), + create func() (T, error), +) (T, error) { + var zero T + if login == nil { + return zero, fmt.Errorf("login is nil") + } + client, err := LoadOrCreateClient( + mu, + clients, + login.ID, + func(existingAPI bridgev2.NetworkAPI) bool { + existing, ok := existingAPI.(T) + if !ok { + return false + } + if reuse != nil { + reuse(existing, login) + } + login.Client = existing + return true + }, + func() (bridgev2.NetworkAPI, error) { + client, err := create() + if err != nil { + return nil, err + } + login.Client = client + return client, nil + }, + ) + if err != nil { + return zero, err + } + typed, ok := client.(T) + if !ok { + return zero, fmt.Errorf("unexpected client type %T", client) + } + return typed, nil +} + // RemoveClientFromCache removes a client from the cache by login ID. func RemoveClientFromCache( mu *sync.Mutex, diff --git a/pkg/bridgeadapter/helpers.go b/pkg/bridgeadapter/helpers.go index 755f483e..747fda54 100644 --- a/pkg/bridgeadapter/helpers.go +++ b/pkg/bridgeadapter/helpers.go @@ -70,8 +70,7 @@ func SendDebouncedStreamEdit(p SendDebouncedStreamEditParams) error { TargetMessage: p.NetworkMessageID, Timestamp: time.Now(), LogKey: p.LogKey, - PreBuilt: streamtransport.BuildConvertedEdit(&event.MessageEventContent{ - MsgType: event.MsgText, + PreBuilt: streamtransport.BuildRenderedConvertedEdit(streamtransport.RenderedMarkdownContent{ Body: content.Body, Format: content.Format, FormattedBody: content.FormattedBody, diff --git a/pkg/bridgeadapter/identifier_helpers.go b/pkg/bridgeadapter/identifier_helpers.go index c393892f..7b142808 100644 --- a/pkg/bridgeadapter/identifier_helpers.go +++ b/pkg/bridgeadapter/identifier_helpers.go @@ -4,6 +4,7 @@ import ( "fmt" "net/url" + "github.com/google/uuid" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" @@ -13,6 +14,11 @@ func MatrixMessageID(eventID id.EventID) networkid.MessageID { return networkid.MessageID("mx:" + string(eventID)) } +// NewEventID generates a unique Matrix-style event ID with the given prefix. +func NewEventID(prefix string) id.EventID { + return id.EventID(fmt.Sprintf("$%s-%s", prefix, uuid.NewString())) +} + func HumanUserID(prefix string, loginID networkid.UserLoginID) networkid.UserID { return networkid.UserID(prefix + ":" + string(loginID)) } diff --git a/pkg/bridgeadapter/status_helpers.go b/pkg/bridgeadapter/status_helpers.go index 3f0e8d8a..5812ba4f 100644 --- a/pkg/bridgeadapter/status_helpers.go +++ b/pkg/bridgeadapter/status_helpers.go @@ -1,6 +1,7 @@ package bridgeadapter import ( + "context" "errors" "maunium.net/go/mautrix/bridgev2" @@ -47,3 +48,15 @@ func MessageSendStatusError( } return st } + +func SendMatrixMessageStatus( + ctx context.Context, + portal *bridgev2.Portal, + evt *event.Event, + status bridgev2.MessageStatus, +) { + if portal == nil || portal.Bridge == nil || evt == nil { + return + } + portal.Bridge.Matrix.SendMessageStatus(ctx, &status, bridgev2.StatusEventInfoFromEvent(evt)) +} diff --git a/pkg/connector/agentstore.go b/pkg/connector/agentstore.go index f3d5438b..20fd8611 100644 --- a/pkg/connector/agentstore.go +++ b/pkg/connector/agentstore.go @@ -14,10 +14,9 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" - "github.com/google/uuid" - "github.com/beeper/ai-bridge/pkg/agents" "github.com/beeper/ai-bridge/pkg/agents/tools" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) // AgentStoreAdapter implements agents.AgentStore with UserLogin metadata as source of truth. @@ -383,7 +382,7 @@ func (b *BossStoreAdapter) RunInternalCommand(ctx context.Context, roomID string runCtx := b.store.client.backgroundContext(ctx) logCopy := b.store.client.log.With().Str("mx_command", cmdName).Logger() captureBot := newCaptureMatrixAPI(b.store.client.UserLogin.Bridge.Bot) - eventID := id.EventID(fmt.Sprintf("$internal-%s", uuid.NewString())) + eventID := bridgeadapter.NewEventID("internal") ce := &commands.Event{ Bot: captureBot, Bridge: b.store.client.UserLogin.Bridge, diff --git a/pkg/connector/bootstrap_context.go b/pkg/connector/bootstrap_context.go index baafdf14..72f13d0b 100644 --- a/pkg/connector/bootstrap_context.go +++ b/pkg/connector/bootstrap_context.go @@ -76,7 +76,7 @@ func userMdHasValues(content string) bool { v := strings.TrimSpace(value) // Strip common markdown emphasis markers. Do this a couple times because removing "**" // can expose leading whitespace before another marker like "*...*". - for i := 0; i < 2; i++ { + for range 2 { v = strings.TrimSpace(strings.Trim(v, "*_")) } if strings.HasPrefix(v, "(") && strings.HasSuffix(v, ")") { diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index 11aa37fe..03e80aa9 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -7,7 +7,6 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/rs/zerolog" "go.mau.fi/util/ptr" @@ -1116,7 +1115,7 @@ func (oc *AIClient) copyMessagesToChat( // Create remote message for bridging remoteMsg := &OpenAIRemoteMessage{ PortalKey: destPortal.PortalKey, - ID: networkid.MessageID(fmt.Sprintf("fork:%s", uuid.NewString())), + ID: bridgeadapter.NewMessageID("fork"), Sender: sender, Content: srcMeta.Body, Timestamp: srcMsg.Timestamp, diff --git a/pkg/connector/client.go b/pkg/connector/client.go index f26846e9..b7acb3c7 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -2570,7 +2570,7 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { // Create user message for database userMessage := &database.Message{ - ID: networkid.MessageID(fmt.Sprintf("mx:%s", string(last.Event.ID))), + ID: bridgeadapter.MatrixMessageID(last.Event.ID), MXID: last.Event.ID, Room: last.Portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), diff --git a/pkg/connector/commands.go b/pkg/connector/commands.go index 07346cbe..db44f646 100644 --- a/pkg/connector/commands.go +++ b/pkg/connector/commands.go @@ -86,7 +86,7 @@ func isValidAgentID(agentID string) bool { if agentID == "" { return false } - for i := 0; i < len(agentID); i++ { + for i := range len(agentID) { ch := agentID[i] if (ch < 'a' || ch > 'z') && (ch < '0' || ch > '9') && ch != '-' { return false diff --git a/pkg/connector/compaction_summarization.go b/pkg/connector/compaction_summarization.go index 6bc63bcc..30421fc1 100644 --- a/pkg/connector/compaction_summarization.go +++ b/pkg/connector/compaction_summarization.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math" + "slices" "strings" "time" @@ -623,7 +624,7 @@ func (oc *AIClient) applyCompactionModelSummaryAndRefresh( dropped := selectDroppedCompactionMessages(originalPrompt, compactedPrompt, decision.DroppedCount) if len(dropped) > 0 { model := resolveCompactionSummaryModel(oc.effectiveModel(meta), oc.pruningSummarizationModel()) - allMessages := append([]openai.ChatCompletionMessageParamUnion{}, dropped...) + allMessages := slices.Clone(dropped) allMessages = append(allMessages, compactedPrompt...) adaptive := computeCompactionAdaptiveChunkRatio(allMessages, model, contextWindowTokens) maxChunkTokens := int(math.Floor(float64(contextWindowTokens)*adaptive)) - compactionSummarizationOverhead diff --git a/pkg/connector/errors.go b/pkg/connector/errors.go index 86878251..9e0ca0ed 100644 --- a/pkg/connector/errors.go +++ b/pkg/connector/errors.go @@ -276,29 +276,3 @@ func IsToolSchemaError(err error) bool { } return false } - -// IsToolUniquenessError checks if the error indicates duplicate tool names. -func IsToolUniquenessError(err error) bool { - var apiErr *openai.Error - if errors.As(err, &apiErr) { - if strings.Contains(apiErr.Message, "tools: Tool names must be unique") { - return true - } - raw := apiErr.RawJSON() - if raw != "" && strings.Contains(raw, "tools: Tool names must be unique") { - return true - } - } - return false -} - -// IsNoResponseChunksError checks if the Responses streaming returned no chunks. -func IsNoResponseChunksError(err error) bool { - for err != nil { - if strings.Contains(err.Error(), "No response chunks received") { - return true - } - err = errors.Unwrap(err) - } - return false -} diff --git a/pkg/connector/handleai.go b/pkg/connector/handleai.go index eea2b539..04c6f8d0 100644 --- a/pkg/connector/handleai.go +++ b/pkg/connector/handleai.go @@ -12,6 +12,8 @@ import ( "github.com/openai/openai-go/v3/shared" "github.com/rs/zerolog" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/bridgev2/status" @@ -153,26 +155,20 @@ func (oc *AIClient) setModelTyping(ctx context.Context, portal *bridgev2.Portal, } func (oc *AIClient) sendPendingStatus(ctx context.Context, portal *bridgev2.Portal, evt *event.Event, message string) { - if portal == nil || portal.Bridge == nil || evt == nil { - return - } status := bridgev2.MessageStatus{ Status: event.MessageStatusPending, Message: message, IsCertain: true, } - portal.Bridge.Matrix.SendMessageStatus(ctx, &status, bridgev2.StatusEventInfoFromEvent(evt)) + bridgeadapter.SendMatrixMessageStatus(ctx, portal, evt, status) } func (oc *AIClient) sendSuccessStatus(ctx context.Context, portal *bridgev2.Portal, evt *event.Event) { - if portal == nil || portal.Bridge == nil || evt == nil { - return - } status := bridgev2.MessageStatus{ Status: event.MessageStatusSuccess, IsCertain: true, } - portal.Bridge.Matrix.SendMessageStatus(ctx, &status, bridgev2.StatusEventInfoFromEvent(evt)) + bridgeadapter.SendMatrixMessageStatus(ctx, portal, evt, status) } const autoGreetingDelay = 5 * time.Second diff --git a/pkg/connector/handlematrix.go b/pkg/connector/handlematrix.go index 16554033..1876a367 100644 --- a/pkg/connector/handlematrix.go +++ b/pkg/connector/handlematrix.go @@ -313,7 +313,7 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri } logCtx.Debug().Int("prompt_messages", len(promptContext.Messages)).Msg("Built prompt for inbound message") userMessage := &database.Message{ - ID: networkid.MessageID(fmt.Sprintf("mx:%s", string(eventID))), + ID: bridgeadapter.MatrixMessageID(eventID), MXID: eventID, Room: portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), @@ -722,7 +722,7 @@ func (oc *AIClient) handleMediaMessage( return nil, messageSendStatusError(err, "Couldn't prepare the message. Try again.", "") } userMessage := &database.Message{ - ID: networkid.MessageID(fmt.Sprintf("mx:%s", string(eventID))), + ID: bridgeadapter.MatrixMessageID(eventID), MXID: eventID, Room: portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), @@ -858,7 +858,7 @@ func (oc *AIClient) handleMediaMessage( } userMessage := &database.Message{ - ID: networkid.MessageID(fmt.Sprintf("mx:%s", string(eventID))), + ID: bridgeadapter.MatrixMessageID(eventID), MXID: eventID, Room: portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), @@ -1018,7 +1018,7 @@ func (oc *AIClient) handleTextFileMessage( } userMessage := &database.Message{ - ID: networkid.MessageID(fmt.Sprintf("mx:%s", string(eventID))), + ID: bridgeadapter.MatrixMessageID(eventID), MXID: eventID, Room: portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), @@ -1220,8 +1220,8 @@ func (oc *AIClient) handleToolsCommand( return } - parts := strings.SplitN(arg, " ", 2) - action := strings.ToLower(parts[0]) + action, _, _ := strings.Cut(arg, " ") + action = strings.ToLower(action) switch action { case "list": diff --git a/pkg/connector/internal_dispatch.go b/pkg/connector/internal_dispatch.go index 12a2429d..73763247 100644 --- a/pkg/connector/internal_dispatch.go +++ b/pkg/connector/internal_dispatch.go @@ -3,16 +3,14 @@ package connector import ( "context" "errors" - "fmt" "strings" "time" - "github.com/google/uuid" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" airuntime "github.com/beeper/ai-bridge/pkg/runtime" ) @@ -53,7 +51,7 @@ func (oc *AIClient) dispatchInternalMessage( if src := strings.TrimSpace(source); src != "" { prefix = src } - eventID := id.EventID(fmt.Sprintf("$%s-%s", prefix, uuid.NewString())) + eventID := bridgeadapter.NewEventID(prefix) inboundCtx := oc.resolvePromptInboundContext(ctx, portal, trimmed, eventID) promptCtx := withInboundContext(ctx, inboundCtx) @@ -63,7 +61,7 @@ func (oc *AIClient) dispatchInternalMessage( } userMessage := &database.Message{ - ID: networkid.MessageID(fmt.Sprintf("mx:%s", eventID)), + ID: bridgeadapter.MatrixMessageID(eventID), MXID: eventID, Room: portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), diff --git a/pkg/connector/messages.go b/pkg/connector/messages.go index c71ca0f6..e45bd2bc 100644 --- a/pkg/connector/messages.go +++ b/pkg/connector/messages.go @@ -1,6 +1,7 @@ package connector import ( + "slices" "strings" "github.com/openai/openai-go/v3" @@ -131,27 +132,6 @@ func (m *UnifiedMessage) Text() string { return strings.Join(texts, "\n") } -// HasImages returns true if the message contains image content. -func (m *UnifiedMessage) HasImages() bool { - for _, part := range m.Content { - if part.Type == ContentTypeImage { - return true - } - } - return false -} - -// HasMultimodalContent returns true if the message contains any non-text content. -func (m *UnifiedMessage) HasMultimodalContent() bool { - for _, part := range m.Content { - switch part.Type { - case ContentTypeImage, ContentTypePDF, ContentTypeAudio, ContentTypeVideo: - return true - } - } - return false -} - // Text returns the text content of a canonical prompt message. func (m PromptMessage) Text() string { var texts []string @@ -171,7 +151,7 @@ func (m PromptMessage) Text() string { func ToPromptContext(systemPrompt string, tools []ToolDefinition, messages []UnifiedMessage) PromptContext { ctx := PromptContext{ SystemPrompt: strings.TrimSpace(systemPrompt), - Tools: append([]ToolDefinition(nil), tools...), + Tools: slices.Clone(tools), } systemParts := make([]string, 0, len(messages)) @@ -455,10 +435,11 @@ func joinChatText[T any](parts []T, extract func(T) string) string { func inferPromptMimeTypeFromDataURL(value string) string { value = strings.TrimSpace(value) - if !strings.HasPrefix(value, "data:") { + rest, ok := strings.CutPrefix(value, "data:") + if !ok { return "" } - value = strings.TrimPrefix(value, "data:") + value = rest idx := strings.Index(value, ";") if idx <= 0 { return "" diff --git a/pkg/connector/response_retry.go b/pkg/connector/response_retry.go index be347445..0c26801d 100644 --- a/pkg/connector/response_retry.go +++ b/pkg/connector/response_retry.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math" + "slices" "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" @@ -448,7 +449,7 @@ func (oc *AIClient) truncateOversizedToolResultsForOverflow( } } - out := append([]openai.ChatCompletionMessageParamUnion{}, prompt...) + out := slices.Clone(prompt) truncated := 0 for i, msg := range out { if msg.OfTool == nil { diff --git a/pkg/connector/room_runs.go b/pkg/connector/room_runs.go index fa2bef54..39ef261b 100644 --- a/pkg/connector/room_runs.go +++ b/pkg/connector/room_runs.go @@ -147,7 +147,7 @@ func (oc *AIClient) roomRunStatusEvents(roomID id.RoomID) []*event.Event { return nil } run.mu.Lock() - events := append([]*event.Event(nil), run.statusEvents...) + events := slices.Clone(run.statusEvents) run.mu.Unlock() return events } diff --git a/pkg/connector/streaming_function_calls.go b/pkg/connector/streaming_function_calls.go index cc1d01e8..81b1bee4 100644 --- a/pkg/connector/streaming_function_calls.go +++ b/pkg/connector/streaming_function_calls.go @@ -26,8 +26,7 @@ func (oc *AIClient) processToolMediaResult( logSuffix string, ) (string, ResultStatus) { // TTS audio (AUDIO: prefix) - if strings.HasPrefix(result, TTSResultPrefix) { - audioB64 := strings.TrimPrefix(result, TTSResultPrefix) + if audioB64, ok := strings.CutPrefix(result, TTSResultPrefix); ok { audioData, err := base64.StdEncoding.DecodeString(audioB64) if err != nil { log.Warn().Err(err).Msg("Failed to decode TTS audio" + logSuffix) @@ -51,8 +50,7 @@ func (oc *AIClient) processToolMediaResult( } // Multiple images (IMAGES: prefix) - if strings.HasPrefix(result, ImagesResultPrefix) { - payload := strings.TrimPrefix(result, ImagesResultPrefix) + if payload, ok := strings.CutPrefix(result, ImagesResultPrefix); ok { var images []string if err := json.Unmarshal([]byte(payload), &images); err != nil { log.Warn().Err(err).Msg("Failed to parse generated images payload" + logSuffix) @@ -85,8 +83,7 @@ func (oc *AIClient) processToolMediaResult( } // Single image (IMAGE: prefix) - if strings.HasPrefix(result, ImageResultPrefix) { - imageB64 := strings.TrimPrefix(result, ImageResultPrefix) + if imageB64, ok := strings.CutPrefix(result, ImageResultPrefix); ok { imageData, mimeType, err := decodeBase64Image(imageB64) if err != nil { log.Warn().Err(err).Msg("Failed to decode generated image" + logSuffix) @@ -229,7 +226,7 @@ func (oc *AIClient) handleFunctionCallArgumentsDone( result, err = oc.executeBuiltinTool(toolCtx, portal, toolName, argsJSON) if err != nil { log.Warn().Err(err).Str("tool", toolName).Msg("Tool execution failed" + logSuffix) - result = fmt.Sprintf("Error: %s", err.Error()) + result = fmt.Sprintf("Error: %s", err) resultStatus = ResultStatusError } } diff --git a/pkg/connector/subagent_spawn.go b/pkg/connector/subagent_spawn.go index 4d4d4eb2..b102cc69 100644 --- a/pkg/connector/subagent_spawn.go +++ b/pkg/connector/subagent_spawn.go @@ -12,11 +12,10 @@ import ( "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/id" "github.com/beeper/ai-bridge/pkg/agents" "github.com/beeper/ai-bridge/pkg/agents/tools" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) func normalizeAgentID(value string) string { @@ -345,7 +344,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P } } - eventID := id.EventID(fmt.Sprintf("$subagent-%s", uuid.NewString())) + eventID := bridgeadapter.NewEventID("subagent") promptMessages, err := oc.buildPrompt(ctx, childPortal, childMeta, task, eventID) if err != nil { return tools.JSONResult(map[string]any{ @@ -355,7 +354,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P } userMessage := &database.Message{ - ID: networkid.MessageID(fmt.Sprintf("mx:%s", eventID)), + ID: bridgeadapter.MatrixMessageID(eventID), MXID: eventID, Room: childPortal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), diff --git a/pkg/connector/tools.go b/pkg/connector/tools.go index fc0faa0d..7b8b2c49 100644 --- a/pkg/connector/tools.go +++ b/pkg/connector/tools.go @@ -71,29 +71,24 @@ func GetBridgeToolContext(ctx context.Context) *BridgeToolContext { return contextValue[*BridgeToolContext](ctx, bridgeToolContextKey{}) } -var ( - builtinToolsOnce sync.Once - builtinToolsCached []ToolDefinition - builtinToolsByNameMap map[string]*ToolDefinition -) - -func initBuiltinTools() { - builtinToolsCached = buildBuiltinToolDefinitions() - builtinToolsByNameMap = make(map[string]*ToolDefinition, len(builtinToolsCached)) - for i := range builtinToolsCached { - builtinToolsByNameMap[builtinToolsCached[i].Name] = &builtinToolsCached[i] +var builtinToolsInit = sync.OnceValues(func() ([]ToolDefinition, map[string]*ToolDefinition) { + tools := buildBuiltinToolDefinitions() + byName := make(map[string]*ToolDefinition, len(tools)) + for i := range tools { + byName[tools[i].Name] = &tools[i] } -} + return tools, byName +}) -// The result is computed once and cached for the process lifetime. +// BuiltinTools returns all builtin tool definitions (computed once and cached). func BuiltinTools() []ToolDefinition { - builtinToolsOnce.Do(initBuiltinTools) - return builtinToolsCached + tools, _ := builtinToolsInit() + return tools } func GetBuiltinTool(name string) *ToolDefinition { - builtinToolsOnce.Do(initBuiltinTools) - return builtinToolsByNameMap[name] + _, byName := builtinToolsInit() + return byName[name] } const ToolNameMessage = toolspec.MessageName diff --git a/pkg/integrations/memory/flush_tool_loop.go b/pkg/integrations/memory/flush_tool_loop.go index 0b34c8dd..101aa8f0 100644 --- a/pkg/integrations/memory/flush_tool_loop.go +++ b/pkg/integrations/memory/flush_tool_loop.go @@ -3,6 +3,7 @@ package memory import ( "context" "errors" + "slices" "strings" "time" @@ -54,7 +55,7 @@ func RunFlushToolLoop( flushCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - chat := append([]openai.ChatCompletionMessageParamUnion{}, messages...) + chat := slices.Clone(messages) for range maxTurns { assistant, calls, done, err := deps.NextTurn(flushCtx, model, chat) if err != nil { diff --git a/pkg/integrations/memory/overflow_exec.go b/pkg/integrations/memory/overflow_exec.go index d20b3059..73dd45ed 100644 --- a/pkg/integrations/memory/overflow_exec.go +++ b/pkg/integrations/memory/overflow_exec.go @@ -3,6 +3,7 @@ package memory import ( "context" "fmt" + "slices" "strings" "github.com/openai/openai-go/v3" @@ -128,7 +129,7 @@ func buildFlushPrompt(base []openai.ChatCompletionMessageParamUnion, settings *F if settings == nil { return nil } - trimmed := append([]openai.ChatCompletionMessageParamUnion{}, base...) + trimmed := slices.Clone(base) if strings.TrimSpace(settings.SystemPrompt) != "" { insertAt := 0 for insertAt < len(trimmed) && trimmed[insertAt].OfSystem != nil { diff --git a/pkg/integrations/memory/prompt_exec.go b/pkg/integrations/memory/prompt_exec.go index 791a8a1f..f1dbe8a8 100644 --- a/pkg/integrations/memory/prompt_exec.go +++ b/pkg/integrations/memory/prompt_exec.go @@ -2,6 +2,7 @@ package memory import ( "context" + "slices" "strings" "github.com/openai/openai-go/v3" @@ -54,7 +55,7 @@ func AugmentPrompt( return prompt } contextText := strings.Join(sections, "\n\n") - out := append([]openai.ChatCompletionMessageParamUnion{}, prompt...) + out := slices.Clone(prompt) out = append(out, openai.SystemMessage(contextText)) return out } diff --git a/pkg/runtime/compaction_overflow.go b/pkg/runtime/compaction_overflow.go index a0f5cfae..a7e2908e 100644 --- a/pkg/runtime/compaction_overflow.go +++ b/pkg/runtime/compaction_overflow.go @@ -2,6 +2,7 @@ package runtime import ( "fmt" + "slices" "strings" "github.com/openai/openai-go/v3" @@ -134,7 +135,7 @@ func pruneHistoryForContextSharePrompt( } preambleEnd := preambleEndIndex(prompt) - kept := append([]openai.ChatCompletionMessageParamUnion{}, prompt[preambleEnd:]...) + kept := slices.Clone(prompt[preambleEnd:]) droppedCount := 0 droppedTokens := 0 for len(kept) > 0 && estimatePromptTokensForCompaction(kept) > budgetTokens { @@ -152,7 +153,7 @@ func pruneHistoryForContextSharePrompt( kept = repairOrphanToolResults(rest) } - finalPrompt := append([]openai.ChatCompletionMessageParamUnion{}, prompt[:preambleEnd]...) + finalPrompt := slices.Clone(prompt[:preambleEnd]) finalPrompt = append(finalPrompt, kept...) return historySharePruneResult{ Prompt: finalPrompt, @@ -166,7 +167,7 @@ func pruneHistoryForContextSharePrompt( // CompactPromptOnOverflow applies deterministic compaction + smart truncation for overflow retries. func CompactPromptOnOverflow(input OverflowCompactionInput) OverflowCompactionResult { - workingPrompt := append([]openai.ChatCompletionMessageParamUnion{}, input.Prompt...) + workingPrompt := slices.Clone(input.Prompt) if len(workingPrompt) <= 2 { _, totalChars := PromptTextPayloads(workingPrompt) decision := CompactionDecision{ diff --git a/pkg/shared/streamtransport/converted_edit.go b/pkg/shared/streamtransport/converted_edit.go new file mode 100644 index 00000000..de73704c --- /dev/null +++ b/pkg/shared/streamtransport/converted_edit.go @@ -0,0 +1,21 @@ +package streamtransport + +import ( + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" +) + +type RenderedMarkdownContent struct { + Body string + Format event.Format + FormattedBody string +} + +func BuildRenderedConvertedEdit(rendered RenderedMarkdownContent, topLevelExtra map[string]any) *bridgev2.ConvertedEdit { + return BuildConvertedEdit(&event.MessageEventContent{ + MsgType: event.MsgText, + Body: rendered.Body, + Format: rendered.Format, + FormattedBody: rendered.FormattedBody, + }, topLevelExtra) +} diff --git a/pkg/shared/streamtransport/session.go b/pkg/shared/streamtransport/session.go index 8401e67d..c26e648a 100644 --- a/pkg/shared/streamtransport/session.go +++ b/pkg/shared/streamtransport/session.go @@ -275,7 +275,7 @@ func (s *StreamSession) sendEphemeralWithRetry(ephemeralSender bridgev2.Ephemera } return false } - for i := 0; i < nonFallbackRetryCount; i++ { + for range nonFallbackRetryCount { if s.IsClosed() { return false } diff --git a/pkg/shared/stringutil/dedupe.go b/pkg/shared/stringutil/dedupe.go index b8035d0b..60c37067 100644 --- a/pkg/shared/stringutil/dedupe.go +++ b/pkg/shared/stringutil/dedupe.go @@ -1,6 +1,10 @@ package stringutil -import "strings" +import ( + "strings" + + "go.mau.fi/util/exslices" +) // DedupeStrings returns a deduplicated copy of values, preserving order. // Empty strings and strings that are empty after trimming are skipped. @@ -8,18 +12,11 @@ func DedupeStrings(values []string) []string { if len(values) == 0 { return nil } - seen := make(map[string]struct{}, len(values)) - out := make([]string, 0, len(values)) + var trimmed []string for _, raw := range values { - v := strings.TrimSpace(raw) - if v == "" { - continue - } - if _, ok := seen[v]; ok { - continue + if v := strings.TrimSpace(raw); v != "" { + trimmed = append(trimmed, v) } - seen[v] = struct{}{} - out = append(out, v) } - return out + return exslices.DeduplicateUnsorted(trimmed) } diff --git a/pkg/textfs/apply_patch.go b/pkg/textfs/apply_patch.go index 4576cff9..dec4a42d 100644 --- a/pkg/textfs/apply_patch.go +++ b/pkg/textfs/apply_patch.go @@ -250,8 +250,7 @@ func parseOneHunk(lines []string, lineNumber int) (applyPatchHunk, int, error) { return nil, 0, fmt.Errorf("invalid patch hunk at line %d: empty hunk", lineNumber) } firstLine := strings.TrimSpace(lines[0]) - if strings.HasPrefix(firstLine, addFileMarker) { - targetPath := strings.TrimPrefix(firstLine, addFileMarker) + if targetPath, ok := strings.CutPrefix(firstLine, addFileMarker); ok { contents := "" consumed := 1 for _, addLine := range lines[1:] { @@ -264,19 +263,17 @@ func parseOneHunk(lines []string, lineNumber int) (applyPatchHunk, int, error) { } return addFileHunk{path: targetPath, contents: contents}, consumed, nil } - if strings.HasPrefix(firstLine, deleteFileMarker) { - targetPath := strings.TrimPrefix(firstLine, deleteFileMarker) + if targetPath, ok := strings.CutPrefix(firstLine, deleteFileMarker); ok { return deleteFileHunk{path: targetPath}, 1, nil } - if strings.HasPrefix(firstLine, updateFileMarker) { - targetPath := strings.TrimPrefix(firstLine, updateFileMarker) + if targetPath, ok := strings.CutPrefix(firstLine, updateFileMarker); ok { remaining := lines[1:] consumed := 1 movePath := "" if len(remaining) > 0 { candidate := strings.TrimSpace(remaining[0]) - if strings.HasPrefix(candidate, moveToMarker) { - movePath = strings.TrimPrefix(candidate, moveToMarker) + if mp, ok := strings.CutPrefix(candidate, moveToMarker); ok { + movePath = mp remaining = remaining[1:] consumed++ } diff --git a/pkg/textfs/truncate.go b/pkg/textfs/truncate.go index 76ecd64c..08f0424d 100644 --- a/pkg/textfs/truncate.go +++ b/pkg/textfs/truncate.go @@ -75,7 +75,7 @@ func TruncateHead(content string, maxLines, maxBytes int) Truncation { outputLines := make([]string, 0, maxLines) outputBytes := 0 truncatedBy := "lines" - for i := 0; i < len(lines) && i < maxLines; i++ { + for i := range min(len(lines), maxLines) { line := lines[i] lineBytes := len(line) if i > 0 { From 6b0516da01c7841b37428a8de54f1fa13ed36c22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:11:24 +0100 Subject: [PATCH 10/23] Use toolspec helpers for tool schemas Refactor builtin tool definitions to use shared toolspec helpers and reduce schema duplication. Added EmptyObjectSchema to toolspec and introduced unavailableBuiltinToolSpec + newUnavailableBuiltinTool to centralize creation of unavailable builtin tools. Replaced inline JSON schemas with toolspec.ObjectSchema, StringProperty, NumberProperty, BooleanProperty and EmptyObjectSchema across agents_list, apply_patch, boss, and textfs; updated imports and replaced direct execUnavailable usages with the new constructor where appropriate. --- pkg/agents/tools/agents_list.go | 11 +- pkg/agents/tools/apply_patch.go | 23 ++-- pkg/agents/tools/boss.go | 190 +++++++++----------------------- pkg/agents/tools/textfs.go | 64 ++++++----- pkg/shared/toolspec/toolspec.go | 4 + 5 files changed, 105 insertions(+), 187 deletions(-) diff --git a/pkg/agents/tools/agents_list.go b/pkg/agents/tools/agents_list.go index 4d15512c..f582cf15 100644 --- a/pkg/agents/tools/agents_list.go +++ b/pkg/agents/tools/agents_list.go @@ -1,6 +1,10 @@ package tools -import "github.com/modelcontextprotocol/go-sdk/mcp" +import ( + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/beeper/ai-bridge/pkg/shared/toolspec" +) // AgentsListTool lists agent ids allowed for sessions_spawn. var AgentsListTool = &Tool{ @@ -8,10 +12,7 @@ var AgentsListTool = &Tool{ Name: "agents_list", Description: "List agent ids you can target with sessions_spawn (based on allowlists).", Annotations: &mcp.ToolAnnotations{Title: "Agents List"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{}, - }, + InputSchema: toolspec.EmptyObjectSchema(), }, Type: ToolTypeBuiltin, Group: GroupSessions, diff --git a/pkg/agents/tools/apply_patch.go b/pkg/agents/tools/apply_patch.go index 797a8d34..2e81caea 100644 --- a/pkg/agents/tools/apply_patch.go +++ b/pkg/agents/tools/apply_patch.go @@ -1,19 +1,10 @@ package tools -import ( - "github.com/modelcontextprotocol/go-sdk/mcp" +import "github.com/beeper/ai-bridge/pkg/shared/toolspec" - "github.com/beeper/ai-bridge/pkg/shared/toolspec" -) - -var ApplyPatchTool = &Tool{ - Tool: mcp.Tool{ - Name: toolspec.ApplyPatchName, - Description: toolspec.ApplyPatchDescription, - Annotations: &mcp.ToolAnnotations{Title: "apply_patch"}, - InputSchema: toolspec.ApplyPatchSchema(), - }, - Type: ToolTypeBuiltin, - Group: GroupFS, - Execute: execUnavailable(toolspec.ApplyPatchName), -} +var ApplyPatchTool = newUnavailableBuiltinTool(unavailableBuiltinToolSpec{ + name: toolspec.ApplyPatchName, + description: toolspec.ApplyPatchDescription, + title: "apply_patch", + inputSchema: toolspec.ApplyPatchSchema(), +}) diff --git a/pkg/agents/tools/boss.go b/pkg/agents/tools/boss.go index e63c1655..09bfa620 100644 --- a/pkg/agents/tools/boss.go +++ b/pkg/agents/tools/boss.go @@ -135,18 +135,14 @@ var CreateAgentTool = &Tool{ Name: "create_agent", Description: "Create a new AI agent with custom configuration", Annotations: &mcp.ToolAnnotations{Title: "Create Agent"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "name": map[string]any{"type": "string", "description": "Display name for the agent"}, - "description": map[string]any{"type": "string", "description": "Brief description of what the agent does"}, - "model": map[string]any{"type": "string", "description": "Model ID to use (e.g., 'anthropic/claude-sonnet-4.5'). Leave empty for default."}, - "system_prompt": map[string]any{"type": "string", "description": "Custom system prompt for the agent"}, - "tools": toolPolicySchema(), - "subagents": subagentSchema(), - }, - "required": []string{"name"}, - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "name": toolspec.StringProperty("Display name for the agent"), + "description": toolspec.StringProperty("Brief description of what the agent does"), + "model": toolspec.StringProperty("Model ID to use (e.g., 'anthropic/claude-sonnet-4.5'). Leave empty for default."), + "system_prompt": toolspec.StringProperty("Custom system prompt for the agent"), + "tools": toolPolicySchema(), + "subagents": subagentSchema(), + }, "name"), }, Type: ToolTypeBuiltin, Group: GroupBuilder, @@ -173,19 +169,15 @@ var EditAgentTool = &Tool{ Name: "edit_agent", Description: "Modify an existing custom agent's configuration", Annotations: &mcp.ToolAnnotations{Title: "Edit Agent"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "agent_id": map[string]any{"type": "string", "description": "ID of the agent to edit"}, - "name": map[string]any{"type": "string", "description": "New display name"}, - "description": map[string]any{"type": "string", "description": "New description"}, - "model": map[string]any{"type": "string", "description": "New model ID"}, - "system_prompt": map[string]any{"type": "string", "description": "New system prompt"}, - "tools": toolPolicySchema(), - "subagents": subagentSchema(), - }, - "required": []string{"agent_id"}, - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "agent_id": toolspec.StringProperty("ID of the agent to edit"), + "name": toolspec.StringProperty("New display name"), + "description": toolspec.StringProperty("New description"), + "model": toolspec.StringProperty("New model ID"), + "system_prompt": toolspec.StringProperty("New system prompt"), + "tools": toolPolicySchema(), + "subagents": subagentSchema(), + }, "agent_id"), }, Type: ToolTypeBuiltin, Group: GroupBuilder, @@ -197,16 +189,9 @@ var DeleteAgentTool = &Tool{ Name: "delete_agent", Description: "Delete a custom agent (preset agents cannot be deleted)", Annotations: &mcp.ToolAnnotations{Title: "Delete Agent"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "agent_id": map[string]any{ - "type": "string", - "description": "ID of the agent to delete", - }, - }, - "required": []string{"agent_id"}, - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "agent_id": toolspec.StringProperty("ID of the agent to delete"), + }, "agent_id"), }, Type: ToolTypeBuiltin, Group: GroupBuilder, @@ -218,10 +203,7 @@ var ListAgentsTool = &Tool{ Name: "list_agents", Description: "List all available agents (both preset and custom)", Annotations: &mcp.ToolAnnotations{Title: "List Agents"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{}, - }, + InputSchema: toolspec.EmptyObjectSchema(), }, Type: ToolTypeBuiltin, Group: GroupBuilder, @@ -233,10 +215,7 @@ var ListModelsTool = &Tool{ Name: "list_models", Description: "List all available AI models", Annotations: &mcp.ToolAnnotations{Title: "List Models"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{}, - }, + InputSchema: toolspec.EmptyObjectSchema(), }, Type: ToolTypeBuiltin, Group: GroupSessions, @@ -263,28 +242,12 @@ var ModifyRoomTool = &Tool{ Name: "modify_room", Description: "Modify an existing room's configuration", Annotations: &mcp.ToolAnnotations{Title: "Modify Room"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "room_id": map[string]any{ - "type": "string", - "description": "ID of the room to modify", - }, - "name": map[string]any{ - "type": "string", - "description": "New display name for the room", - }, - "agent_id": map[string]any{ - "type": "string", - "description": "New agent ID to assign to this room", - }, - "system_prompt": map[string]any{ - "type": "string", - "description": "New system prompt override for this room", - }, - }, - "required": []string{"room_id"}, - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "room_id": toolspec.StringProperty("ID of the room to modify"), + "name": toolspec.StringProperty("New display name for the room"), + "agent_id": toolspec.StringProperty("New agent ID to assign to this room"), + "system_prompt": toolspec.StringProperty("New system prompt override for this room"), + }, "room_id"), }, Type: ToolTypeBuiltin, Group: GroupSessions, @@ -296,27 +259,15 @@ var SessionsListTool = &Tool{ Name: "sessions_list", Description: "List sessions with optional filters and last messages.", Annotations: &mcp.ToolAnnotations{Title: "List Sessions"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "kinds": map[string]any{ - "type": "array", - "items": map[string]any{"type": "string"}, - }, - "limit": map[string]any{ - "type": "number", - "description": "Maximum number of sessions to return (default: 50)", - }, - "activeMinutes": map[string]any{ - "type": "number", - "description": "Only include sessions active within this many minutes", - }, - "messageLimit": map[string]any{ - "type": "number", - "description": "Include the last N messages for each session", - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "kinds": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, }, - }, + "limit": toolspec.NumberProperty("Maximum number of sessions to return (default: 50)"), + "activeMinutes": toolspec.NumberProperty("Only include sessions active within this many minutes"), + "messageLimit": toolspec.NumberProperty("Include the last N messages for each session"), + }), }, Type: ToolTypeBuiltin, Group: GroupSessions, @@ -328,24 +279,11 @@ var SessionsHistoryTool = &Tool{ Name: "sessions_history", Description: "Fetch message history for a session. Use the sessionKey from sessions_list.", Annotations: &mcp.ToolAnnotations{Title: "Session History"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "sessionKey": map[string]any{ - "type": "string", - "description": "Session identifier from sessions_list (preferred canonical target)", - }, - "limit": map[string]any{ - "type": "number", - "description": "Maximum number of messages to return (default: 200)", - }, - "includeTools": map[string]any{ - "type": "boolean", - "description": "Whether to include tool calls in the returned history", - }, - }, - "required": []string{"sessionKey"}, - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "sessionKey": toolspec.StringProperty("Session identifier from sessions_list (preferred canonical target)"), + "limit": toolspec.NumberProperty("Maximum number of messages to return (default: 200)"), + "includeTools": toolspec.BooleanProperty("Whether to include tool calls in the returned history"), + }, "sessionKey"), }, Type: ToolTypeBuiltin, Group: GroupSessions, @@ -375,41 +313,19 @@ var SessionsSpawnTool = &Tool{ Name: "sessions_spawn", Description: "Spawn a background sub-agent run in an isolated session and announce the result back to the requester chat.", Annotations: &mcp.ToolAnnotations{Title: "Spawn Session"}, - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "task": map[string]any{ - "type": "string", - "description": "Task description for the sub-agent.", - }, - "label": map[string]any{ - "type": "string", - "description": "Optional label for the sub-agent run.", - }, - "agentId": map[string]any{ - "type": "string", - "description": "Agent ID override for the sub-agent run.", - }, - "model": map[string]any{ - "type": "string", - "description": "Optional model override (provider/model).", - }, - "thinking": map[string]any{ - "type": "string", - "description": "Optional thinking level override.", - }, - "runTimeoutSeconds": map[string]any{ - "type": "number", - "description": "Optional run timeout in seconds.", - }, - "cleanup": map[string]any{ - "type": "string", - "enum": []string{"delete", "keep"}, - "description": "Cleanup policy for the spawned session.", - }, + InputSchema: toolspec.ObjectSchema(map[string]any{ + "task": toolspec.StringProperty("Task description for the sub-agent."), + "label": toolspec.StringProperty("Optional label for the sub-agent run."), + "agentId": toolspec.StringProperty("Agent ID override for the sub-agent run."), + "model": toolspec.StringProperty("Optional model override (provider/model)."), + "thinking": toolspec.StringProperty("Optional thinking level override."), + "runTimeoutSeconds": toolspec.NumberProperty("Optional run timeout in seconds."), + "cleanup": map[string]any{ + "type": "string", + "enum": []string{"delete", "keep"}, + "description": "Cleanup policy for the spawned session.", }, - "required": []string{"task"}, - }, + }, "task"), }, Type: ToolTypeBuiltin, Group: GroupSessions, diff --git a/pkg/agents/tools/textfs.go b/pkg/agents/tools/textfs.go index a280aa26..c2e44114 100644 --- a/pkg/agents/tools/textfs.go +++ b/pkg/agents/tools/textfs.go @@ -14,38 +14,44 @@ func execUnavailable(name string) func(ctx context.Context, input map[string]any } } -var ( - ReadTool = &Tool{ - Tool: mcp.Tool{ - Name: toolspec.ReadName, - Description: toolspec.ReadDescription, - Annotations: &mcp.ToolAnnotations{Title: "Read"}, - InputSchema: toolspec.ReadSchema(), - }, - Type: ToolTypeBuiltin, - Group: GroupFS, - Execute: execUnavailable(toolspec.ReadName), - } - WriteTool = &Tool{ - Tool: mcp.Tool{ - Name: toolspec.WriteName, - Description: toolspec.WriteDescription, - Annotations: &mcp.ToolAnnotations{Title: "Write"}, - InputSchema: toolspec.WriteSchema(), - }, - Type: ToolTypeBuiltin, - Group: GroupFS, - Execute: execUnavailable(toolspec.WriteName), - } - EditTool = &Tool{ +type unavailableBuiltinToolSpec struct { + name string + description string + title string + inputSchema map[string]any +} + +func newUnavailableBuiltinTool(spec unavailableBuiltinToolSpec) *Tool { + return &Tool{ Tool: mcp.Tool{ - Name: toolspec.EditName, - Description: toolspec.EditDescription, - Annotations: &mcp.ToolAnnotations{Title: "Edit"}, - InputSchema: toolspec.EditSchema(), + Name: spec.name, + Description: spec.description, + Annotations: &mcp.ToolAnnotations{Title: spec.title}, + InputSchema: spec.inputSchema, }, Type: ToolTypeBuiltin, Group: GroupFS, - Execute: execUnavailable(toolspec.EditName), + Execute: execUnavailable(spec.name), } +} + +var ( + ReadTool = newUnavailableBuiltinTool(unavailableBuiltinToolSpec{ + name: toolspec.ReadName, + description: toolspec.ReadDescription, + title: "Read", + inputSchema: toolspec.ReadSchema(), + }) + WriteTool = newUnavailableBuiltinTool(unavailableBuiltinToolSpec{ + name: toolspec.WriteName, + description: toolspec.WriteDescription, + title: "Write", + inputSchema: toolspec.WriteSchema(), + }) + EditTool = newUnavailableBuiltinTool(unavailableBuiltinToolSpec{ + name: toolspec.EditName, + description: toolspec.EditDescription, + title: "Edit", + inputSchema: toolspec.EditSchema(), + }) ) diff --git a/pkg/shared/toolspec/toolspec.go b/pkg/shared/toolspec/toolspec.go index 361765cc..65f917a2 100644 --- a/pkg/shared/toolspec/toolspec.go +++ b/pkg/shared/toolspec/toolspec.go @@ -563,6 +563,10 @@ func ObjectSchema(properties map[string]any, required ...string) map[string]any return schema } +func EmptyObjectSchema() map[string]any { + return ObjectSchema(map[string]any{}) +} + func StringProperty(description string) map[string]any { return map[string]any{ "type": "string", From faa919f4752e1f1260b2dcb9c427313fcf5cbed7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:13:39 +0100 Subject: [PATCH 11/23] Use strings.CutPrefix/CutSuffix for trimming Replace various patterns of strings.HasPrefix/HasSuffix + TrimPrefix/TrimSuffix with strings.CutPrefix and strings.CutSuffix for clearer, slightly more efficient trimming and to avoid repeated prefix computations. Changes touch multiple files: - bridges/codex/login.go: use CutPrefix when resolving ~ path to home directory. - bridges/opencode/opencode/events.go: use CutPrefix when extracting SSE "data:" payload. - bridges/opencode/opencodebridge/opencode_media.go: use CutPrefix for file:// handling and URL unescaping. - pkg/agents/heartbeat.go: use CutPrefix when stripping tokens at edges. - pkg/connector/desktop_api_sessions.go: use CutPrefix when parsing desktop session key prefixes. - pkg/connector/token_resolver.go: use CutSuffix when removing proxy service suffixes. - pkg/shared/media/data_uri.go: use CutPrefix to validate and strip the "data:" scheme. - pkg/textfs/apply_patch.go: use CutPrefix when parsing change context markers. No intended behavioral changes; primarily refactors to improve readability and reduce string allocations. --- bridges/codex/login.go | 4 ++-- bridges/opencode/opencode/events.go | 4 ++-- .../opencodebridge/opencode_canonical_stream.go | 3 +-- .../opencode/opencodebridge/opencode_media.go | 16 +++++++--------- pkg/agents/heartbeat.go | 4 ++-- pkg/connector/desktop_api_sessions.go | 8 ++++---- pkg/connector/token_resolver.go | 4 ++-- pkg/shared/media/data_uri.go | 4 ++-- pkg/textfs/apply_patch.go | 4 ++-- 9 files changed, 24 insertions(+), 27 deletions(-) diff --git a/bridges/codex/login.go b/bridges/codex/login.go index 067a55df..204a10f6 100644 --- a/bridges/codex/login.go +++ b/bridges/codex/login.go @@ -733,9 +733,9 @@ func (cl *CodexLogin) resolveCodexHomeBaseDir() string { base = filepath.Join(os.TempDir(), "ai-bridge-codex") } } - if strings.HasPrefix(base, "~"+string(os.PathSeparator)) { + if rest, ok := strings.CutPrefix(base, "~"+string(os.PathSeparator)); ok { if home, err := os.UserHomeDir(); err == nil && strings.TrimSpace(home) != "" { - base = filepath.Join(home, strings.TrimPrefix(base, "~"+string(os.PathSeparator))) + base = filepath.Join(home, rest) } } abs, err := filepath.Abs(base) diff --git a/bridges/opencode/opencode/events.go b/bridges/opencode/opencode/events.go index 5be2bd42..22b18b95 100644 --- a/bridges/opencode/opencode/events.go +++ b/bridges/opencode/opencode/events.go @@ -71,8 +71,8 @@ func (c *Client) StreamEvents(ctx context.Context) (<-chan Event, <-chan error) flush() continue } - if strings.HasPrefix(line, "data:") { - dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:"))) + if d, ok := strings.CutPrefix(line, "data:"); ok { + dataLines = append(dataLines, strings.TrimSpace(d)) } } if err := scanner.Err(); err != nil && ctx.Err() == nil { diff --git a/bridges/opencode/opencodebridge/opencode_canonical_stream.go b/bridges/opencode/opencodebridge/opencode_canonical_stream.go index ea123daa..2224e533 100644 --- a/bridges/opencode/opencodebridge/opencode_canonical_stream.go +++ b/bridges/opencode/opencodebridge/opencode_canonical_stream.go @@ -73,8 +73,7 @@ func (m *OpenCodeManager) syncAssistantTextPart(ctx context.Context, inst *openC inst.appendPartTextContent(part.SessionID, part.ID, kind, text) delivered = text } - } else if text != "" && strings.HasPrefix(text, delivered) && len(text) > len(delivered) { - missing := text[len(delivered):] + } else if missing, ok := strings.CutPrefix(text, delivered); ok && missing != "" { m.bridge.emitOpenCodeStreamEvent(ctx, portal, turnID, agentID, map[string]any{ "type": kind + "-delta", "id": partID, diff --git a/bridges/opencode/opencodebridge/opencode_media.go b/bridges/opencode/opencodebridge/opencode_media.go index 67169a70..525df9c6 100644 --- a/bridges/opencode/opencodebridge/opencode_media.go +++ b/bridges/opencode/opencodebridge/opencode_media.go @@ -96,8 +96,8 @@ func downloadOpenCodeFile(ctx context.Context, fileURL, fallbackMime string, max if strings.HasPrefix(fileURL, "file://") || strings.HasPrefix(fileURL, "/") { pathValue := fileURL - if strings.HasPrefix(pathValue, "file://") { - pathValue = strings.TrimPrefix(pathValue, "file://") + if p, ok := strings.CutPrefix(pathValue, "file://"); ok { + pathValue = p if unescaped, err := url.PathUnescape(pathValue); err == nil { pathValue = unescaped } @@ -158,15 +158,14 @@ func downloadOpenCodeFile(ctx context.Context, fileURL, fallbackMime string, max } func decodeOpenCodeDataURL(raw string) ([]byte, string, error) { - if !strings.HasPrefix(raw, "data:") { + rest, ok := strings.CutPrefix(raw, "data:") + if !ok { return nil, "", errors.New("not a data URL") } - comma := strings.IndexByte(raw, ',') - if comma < 0 { + meta, payload, ok := strings.Cut(rest, ",") + if !ok { return nil, "", errors.New("invalid data URL") } - meta := raw[len("data:"):comma] - payload := raw[comma+1:] isBase64 := strings.Contains(meta, ";base64") mimeType := "" if meta != "" { @@ -187,8 +186,7 @@ func decodeOpenCodeDataURL(raw string) ([]byte, string, error) { } func filenameFromOpenCodeURL(raw string) string { - if strings.HasPrefix(raw, "file://") { - pathValue := strings.TrimPrefix(raw, "file://") + if pathValue, ok := strings.CutPrefix(raw, "file://"); ok { if unescaped, err := url.PathUnescape(pathValue); err == nil { pathValue = unescaped } diff --git a/pkg/agents/heartbeat.go b/pkg/agents/heartbeat.go index b499d1bc..d66bbb34 100644 --- a/pkg/agents/heartbeat.go +++ b/pkg/agents/heartbeat.go @@ -71,8 +71,8 @@ func stripTokenAtEdges(raw string, token string) (string, bool) { for changed { changed = false next := strings.TrimSpace(text) - if strings.HasPrefix(next, token) { - after := strings.TrimLeft(next[len(token):], " \t\r\n") + if after, ok := strings.CutPrefix(next, token); ok { + after = strings.TrimLeft(after, " \t\r\n") text = after didStrip = true changed = true diff --git a/pkg/connector/desktop_api_sessions.go b/pkg/connector/desktop_api_sessions.go index d2621e98..eb4915e7 100644 --- a/pkg/connector/desktop_api_sessions.go +++ b/pkg/connector/desktop_api_sessions.go @@ -136,10 +136,10 @@ func parseDesktopSessionKey(sessionKey string) (string, string, bool) { return "", "", false } var raw string - if strings.HasPrefix(trimmed, desktopSessionKeyPrefix) { - raw = strings.TrimPrefix(trimmed, desktopSessionKeyPrefix) - } else if strings.HasPrefix(trimmed, desktopSessionKeyAliasPrefix) { - raw = strings.TrimPrefix(trimmed, desktopSessionKeyAliasPrefix) + if r, ok := strings.CutPrefix(trimmed, desktopSessionKeyPrefix); ok { + raw = r + } else if r, ok := strings.CutPrefix(trimmed, desktopSessionKeyAliasPrefix); ok { + raw = r } else { return "", "", false } diff --git a/pkg/connector/token_resolver.go b/pkg/connector/token_resolver.go index ac844328..6a671fce 100644 --- a/pkg/connector/token_resolver.go +++ b/pkg/connector/token_resolver.go @@ -92,8 +92,8 @@ func stripProxyServiceSuffix(path string) string { for { changed := false for _, suffix := range []string{"/openrouter/v1", "/openai/v1", "/gemini/v1beta", "/exa"} { - if strings.HasSuffix(trimmed, suffix) { - trimmed = strings.TrimRight(strings.TrimSuffix(trimmed, suffix), "/") + if rest, ok := strings.CutSuffix(trimmed, suffix); ok { + trimmed = strings.TrimRight(rest, "/") changed = true break } diff --git a/pkg/shared/media/data_uri.go b/pkg/shared/media/data_uri.go index b66d646b..0308a061 100644 --- a/pkg/shared/media/data_uri.go +++ b/pkg/shared/media/data_uri.go @@ -11,11 +11,11 @@ import ( // ParseDataURI parses a base64 data URI and returns raw base64 data and mime type. func ParseDataURI(dataURI string) (string, string, error) { // Format: data:[][;base64], - if !strings.HasPrefix(dataURI, "data:") { + rest, ok := strings.CutPrefix(dataURI, "data:") + if !ok { return "", "", errors.New("not a data URI") } - rest := dataURI[5:] metadata, data, ok := strings.Cut(rest, ",") if !ok { return "", "", errors.New("invalid data URI: no comma separator") diff --git a/pkg/textfs/apply_patch.go b/pkg/textfs/apply_patch.go index dec4a42d..359d916f 100644 --- a/pkg/textfs/apply_patch.go +++ b/pkg/textfs/apply_patch.go @@ -312,8 +312,8 @@ func parseUpdateFileChunk(lines []string, lineNumber int, allowMissingContext bo chunk := updateFileChunk{} if lines[0] == emptyChangeContextMarker { startIndex = 1 - } else if strings.HasPrefix(lines[0], changeContextMarker) { - chunk.changeContext = strings.TrimPrefix(lines[0], changeContextMarker) + } else if ctx, ok := strings.CutPrefix(lines[0], changeContextMarker); ok { + chunk.changeContext = ctx chunk.hasContext = true startIndex = 1 } else if !allowMissingContext { From 89719350348708f2b003c11fc258df282610b34e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:14:30 +0100 Subject: [PATCH 12/23] Update main.go --- cmd/bridgectl/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/bridgectl/main.go b/cmd/bridgectl/main.go index 26660ebf..e8d975ef 100644 --- a/cmd/bridgectl/main.go +++ b/cmd/bridgectl/main.go @@ -1218,12 +1218,12 @@ func requiredInstanceArg(args []string) (string, error) { } func expandPath(p string) (string, error) { - if strings.HasPrefix(p, "~/") { + if rest, ok := strings.CutPrefix(p, "~/"); ok { home, err := os.UserHomeDir() if err != nil { return "", err } - p = filepath.Join(home, p[2:]) + p = filepath.Join(home, rest) } return filepath.Abs(p) } From 9483ea2d3efab65798ac1998b37a62f0630128e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:32:18 +0100 Subject: [PATCH 13/23] Add BaseMessageMetadata and central SendViaPortal Introduce bridgeadapter.BaseMessageMetadata to consolidate common message metadata fields and CopyFrom logic. Update MessageMetadata structs across codex, opencode and connector packages to embed the new base type and simplify CopyFrom implementations. Add bridgeadapter.SendViaPortal helper to centralize QueueRemoteEvent send logic and refactor sendViaPortal implementations in codex, opencode and connector to use it. Clean up OpenCode bridge API (export/rename DisplayName, InstanceConfig, EnsureGhostDisplayName, CreateSession/UpdateSessionTitle) and remove duplicated helper functions and unused imports. --- bridges/codex/client.go | 45 +++++----- bridges/codex/metadata.go | 76 ++-------------- bridges/codex/portal_send.go | 39 ++------ .../opencodebridge/backfill_canonical.go | 50 ++++++----- bridges/opencode/opencodebridge/bridge.go | 28 ------ .../opencodebridge/message_metadata.go | 82 +++-------------- .../opencode/opencodebridge/opencode_ghost.go | 4 +- .../opencodebridge/opencode_helpers.go | 9 +- .../opencodebridge/opencode_manager.go | 14 +-- .../opencodebridge/opencode_portal.go | 4 +- bridges/opencode/portal_send.go | 32 ++----- pkg/bridgeadapter/helpers.go | 43 +++++++++ pkg/bridgeadapter/message_metadata.go | 72 +++++++++++++++ pkg/connector/chat.go | 3 +- pkg/connector/client.go | 3 +- pkg/connector/handlematrix.go | 9 +- pkg/connector/integration_host.go | 6 +- pkg/connector/metadata.go | 88 +++---------------- pkg/connector/portal_send.go | 34 ++----- 19 files changed, 247 insertions(+), 394 deletions(-) diff --git a/bridges/codex/client.go b/bridges/codex/client.go index 02a68b5f..1a48dbc0 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -493,8 +493,7 @@ func (cc *CodexClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Ma SenderID: humanUserID(cc.UserLogin.ID), Timestamp: bridgeadapter.MatrixEventTimestamp(msg.Event), Metadata: &MessageMetadata{ - Role: "user", - Body: body, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: "user", Body: body}, }, } if msg.InputTransactionID != "" { @@ -1731,10 +1730,12 @@ func (cc *CodexClient) sendApprovalRequestFallbackEvent( Content: &event.MessageEventContent{MsgType: event.MsgNotice, Body: "Tool approval required"}, Extra: raw, DBMetadata: &MessageMetadata{ - Role: "assistant", + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: "assistant", + CanonicalSchema: "ai-sdk-ui-message-v1", + CanonicalUIMessage: uiMessage, + }, ExcludeFromHistory: true, - CanonicalSchema: "ai-sdk-ui-message-v1", - CanonicalUIMessage: uiMessage, }, }}, } @@ -1878,7 +1879,7 @@ func (cc *CodexClient) sendInitialStreamMessage(ctx context.Context, portal *bri Type: event.EventMessage, Content: &event.MessageEventContent{MsgType: event.MsgText, Body: content}, Extra: eventRaw, - DBMetadata: &MessageMetadata{Role: "assistant", TurnID: turnID}, + DBMetadata: &MessageMetadata{BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: "assistant", TurnID: turnID}}, }}, } @@ -2060,25 +2061,27 @@ func (cc *CodexClient) saveAssistantMessage(ctx context.Context, portal *bridgev } fullMeta := &MessageMetadata{ - Role: "assistant", - Body: state.accumulated.String(), - FinishReason: finishReason, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: "assistant", + Body: state.accumulated.String(), + FinishReason: finishReason, + TurnID: state.turnID, + AgentID: state.agentID, + ToolCalls: state.toolCalls, + StartedAtMs: state.startedAtMs, + CompletedAtMs: state.completedAtMs, + CanonicalSchema: "ai-sdk-ui-message-v1", + CanonicalUIMessage: cc.buildCanonicalUIMessage(state, model, finishReason), + GeneratedFiles: genFiles, + ThinkingContent: state.reasoning.String(), + PromptTokens: state.promptTokens, + CompletionTokens: state.completionTokens, + ReasoningTokens: state.reasoningTokens, + }, Model: model, - TurnID: state.turnID, - AgentID: state.agentID, - ToolCalls: state.toolCalls, - StartedAtMs: state.startedAtMs, FirstTokenAtMs: state.firstTokenAtMs, - CompletedAtMs: state.completedAtMs, HasToolCalls: len(state.toolCalls) > 0, - CanonicalSchema: "ai-sdk-ui-message-v1", - CanonicalUIMessage: cc.buildCanonicalUIMessage(state, model, finishReason), - GeneratedFiles: genFiles, - ThinkingContent: state.reasoning.String(), ThinkingTokenCount: len(strings.Fields(state.reasoning.String())), - PromptTokens: state.promptTokens, - CompletionTokens: state.completionTokens, - ReasoningTokens: state.reasoningTokens, } // If the message was sent via sendViaPortal, the DB row already exists — update it. diff --git a/bridges/codex/metadata.go b/bridges/codex/metadata.go index bd30fa0d..a8f39643 100644 --- a/bridges/codex/metadata.go +++ b/bridges/codex/metadata.go @@ -32,28 +32,14 @@ type PortalMetadata struct { } type MessageMetadata struct { - Role string `json:"role,omitempty"` - Body string `json:"body,omitempty"` - ExcludeFromHistory bool `json:"exclude_from_history,omitempty"` - CompletionID string `json:"completion_id,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - PromptTokens int64 `json:"prompt_tokens,omitempty"` - CompletionTokens int64 `json:"completion_tokens,omitempty"` - Model string `json:"model,omitempty"` - ReasoningTokens int64 `json:"reasoning_tokens,omitempty"` - HasToolCalls bool `json:"has_tool_calls,omitempty"` - Transcript string `json:"transcript,omitempty"` - TurnID string `json:"turn_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - ToolCalls []ToolCallMetadata `json:"tool_calls,omitempty"` - CanonicalSchema string `json:"canonical_schema,omitempty"` - CanonicalUIMessage map[string]any `json:"canonical_ui_message,omitempty"` - StartedAtMs int64 `json:"started_at_ms,omitempty"` - FirstTokenAtMs int64 `json:"first_token_at_ms,omitempty"` - CompletedAtMs int64 `json:"completed_at_ms,omitempty"` - ThinkingContent string `json:"thinking_content,omitempty"` - ThinkingTokenCount int `json:"thinking_token_count,omitempty"` - GeneratedFiles []GeneratedFileRef `json:"generated_files,omitempty"` + bridgeadapter.BaseMessageMetadata + ExcludeFromHistory bool `json:"exclude_from_history,omitempty"` + CompletionID string `json:"completion_id,omitempty"` + Model string `json:"model,omitempty"` + HasToolCalls bool `json:"has_tool_calls,omitempty"` + Transcript string `json:"transcript,omitempty"` + FirstTokenAtMs int64 `json:"first_token_at_ms,omitempty"` + ThinkingTokenCount int `json:"thinking_token_count,omitempty"` } type ToolCallMetadata = bridgeadapter.ToolCallMetadata @@ -71,72 +57,28 @@ func (mm *MessageMetadata) CopyFrom(other any) { if !ok || src == nil { return } - if src.Role != "" { - mm.Role = src.Role - } - if src.Body != "" { - mm.Body = src.Body - } + mm.CopyFromBase(&src.BaseMessageMetadata) if src.ExcludeFromHistory { mm.ExcludeFromHistory = true } if src.CompletionID != "" { mm.CompletionID = src.CompletionID } - if src.FinishReason != "" { - mm.FinishReason = src.FinishReason - } - if src.PromptTokens != 0 { - mm.PromptTokens = src.PromptTokens - } - if src.CompletionTokens != 0 { - mm.CompletionTokens = src.CompletionTokens - } if src.Model != "" { mm.Model = src.Model } - if src.ReasoningTokens != 0 { - mm.ReasoningTokens = src.ReasoningTokens - } if src.HasToolCalls { mm.HasToolCalls = true } if src.Transcript != "" { mm.Transcript = src.Transcript } - if src.TurnID != "" { - mm.TurnID = src.TurnID - } - if src.AgentID != "" { - mm.AgentID = src.AgentID - } - if len(src.ToolCalls) > 0 { - mm.ToolCalls = src.ToolCalls - } - if src.CanonicalSchema != "" { - mm.CanonicalSchema = src.CanonicalSchema - } - if len(src.CanonicalUIMessage) > 0 { - mm.CanonicalUIMessage = src.CanonicalUIMessage - } - if src.StartedAtMs != 0 { - mm.StartedAtMs = src.StartedAtMs - } if src.FirstTokenAtMs != 0 { mm.FirstTokenAtMs = src.FirstTokenAtMs } - if src.CompletedAtMs != 0 { - mm.CompletedAtMs = src.CompletedAtMs - } - if src.ThinkingContent != "" { - mm.ThinkingContent = src.ThinkingContent - } if src.ThinkingTokenCount != 0 { mm.ThinkingTokenCount = src.ThinkingTokenCount } - if len(src.GeneratedFiles) > 0 { - mm.GeneratedFiles = src.GeneratedFiles - } } func loginMetadata(login *bridgev2.UserLogin) *UserLoginMetadata { diff --git a/bridges/codex/portal_send.go b/bridges/codex/portal_send.go index 0381e916..f782caa3 100644 --- a/bridges/codex/portal_send.go +++ b/bridges/codex/portal_send.go @@ -3,7 +3,6 @@ package codex import ( "context" "fmt" - "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" @@ -13,41 +12,21 @@ import ( ) // sendViaPortal sends a pre-built message through bridgev2's QueueRemoteEvent pipeline. -// Handles: intent resolution, ghost room join, send, DB persist. -// Returns the Matrix event ID and the network message ID used. -// If msgID is empty, a new one is generated. func (cc *CodexClient) sendViaPortal( - ctx context.Context, + _ context.Context, portal *bridgev2.Portal, converted *bridgev2.ConvertedMessage, msgID networkid.MessageID, ) (id.EventID, networkid.MessageID, error) { - if portal == nil || portal.MXID == "" { - return "", "", fmt.Errorf("invalid portal") - } - if cc == nil || cc.UserLogin == nil || cc.UserLogin.Bridge == nil { - return "", msgID, fmt.Errorf("bridge unavailable") - } - if msgID == "" { - msgID = bridgeadapter.NewMessageID("codex") - } - sender := cc.senderForPortal() - evt := &CodexRemoteMessage{ - Portal: portal.PortalKey, - ID: msgID, - Sender: sender, - Timestamp: time.Now(), + return bridgeadapter.SendViaPortal(bridgeadapter.SendViaPortalParams{ + Login: cc.UserLogin, + Portal: portal, + Sender: cc.senderForPortal(), + IDPrefix: "codex", LogKey: "codex_msg_id", - PreBuilt: converted, - } - result := cc.UserLogin.QueueRemoteEvent(evt) - if !result.Success { - if result.Error != nil { - return "", msgID, fmt.Errorf("send failed: %w", result.Error) - } - return "", msgID, fmt.Errorf("send failed") - } - return result.EventID, msgID, nil + MsgID: msgID, + Converted: converted, + }) } // getCodexIntentForPortal resolves the Matrix intent for the Codex ghost. diff --git a/bridges/opencode/opencodebridge/backfill_canonical.go b/bridges/opencode/opencodebridge/backfill_canonical.go index cf0766ad..3c181524 100644 --- a/bridges/opencode/opencodebridge/backfill_canonical.go +++ b/bridges/opencode/opencodebridge/backfill_canonical.go @@ -61,30 +61,32 @@ func buildCanonicalAssistantBackfill(msg opencode.MessageWithParts, agentID stri body: body, ui: uiMessage, meta: &MessageMetadata{ - Role: stringutil.FirstNonEmpty(strings.TrimSpace(msg.Info.Role), "assistant"), - Body: body, - SessionID: strings.TrimSpace(msg.Info.SessionID), - MessageID: strings.TrimSpace(msg.Info.ID), - ParentMessageID: strings.TrimSpace(msg.Info.ParentID), - Agent: strings.TrimSpace(msg.Info.Agent), - ModelID: strings.TrimSpace(msg.Info.ModelID), - ProviderID: strings.TrimSpace(msg.Info.ProviderID), - Mode: strings.TrimSpace(msg.Info.Mode), - FinishReason: stringutil.FirstNonEmpty(strings.TrimSpace(msg.Info.Finish), finishReason), - Cost: backfillCost(msg), - PromptTokens: backfillPromptTokens(msg), - CompletionTokens: backfillCompletionTokens(msg), - ReasoningTokens: backfillReasoningTokens(msg), - TotalTokens: backfillTotalTokens(msg), - TurnID: turnID, - AgentID: strings.TrimSpace(agentID), - CanonicalSchema: "ai-sdk-ui-message-v1", - CanonicalUIMessage: uiMessage, - StartedAtMs: int64(msg.Info.Time.Created), - CompletedAtMs: int64(msg.Info.Time.Completed), - ThinkingContent: CanonicalReasoningText(uiMessage), - ToolCalls: CanonicalToolCalls(uiMessage), - GeneratedFiles: CanonicalGeneratedFiles(uiMessage), + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: stringutil.FirstNonEmpty(strings.TrimSpace(msg.Info.Role), "assistant"), + Body: body, + FinishReason: stringutil.FirstNonEmpty(strings.TrimSpace(msg.Info.Finish), finishReason), + PromptTokens: backfillPromptTokens(msg), + CompletionTokens: backfillCompletionTokens(msg), + ReasoningTokens: backfillReasoningTokens(msg), + TurnID: turnID, + AgentID: strings.TrimSpace(agentID), + CanonicalSchema: "ai-sdk-ui-message-v1", + CanonicalUIMessage: uiMessage, + StartedAtMs: int64(msg.Info.Time.Created), + CompletedAtMs: int64(msg.Info.Time.Completed), + ThinkingContent: CanonicalReasoningText(uiMessage), + ToolCalls: CanonicalToolCalls(uiMessage), + GeneratedFiles: CanonicalGeneratedFiles(uiMessage), + }, + SessionID: strings.TrimSpace(msg.Info.SessionID), + MessageID: strings.TrimSpace(msg.Info.ID), + ParentMessageID: strings.TrimSpace(msg.Info.ParentID), + Agent: strings.TrimSpace(msg.Info.Agent), + ModelID: strings.TrimSpace(msg.Info.ModelID), + ProviderID: strings.TrimSpace(msg.Info.ProviderID), + Mode: strings.TrimSpace(msg.Info.Mode), + Cost: backfillCost(msg), + TotalTokens: backfillTotalTokens(msg), }, } } diff --git a/bridges/opencode/opencodebridge/bridge.go b/bridges/opencode/opencodebridge/bridge.go index 15e7b68e..af3e4334 100644 --- a/bridges/opencode/opencodebridge/bridge.go +++ b/bridges/opencode/opencodebridge/bridge.go @@ -77,34 +77,6 @@ func NewBridge(host Host) *Bridge { return bridge } -func (b *Bridge) DisplayName(instanceID string) string { - if b == nil { - return "" - } - return b.opencodeDisplayName(instanceID) -} - -func (b *Bridge) InstanceConfig(instanceID string) *OpenCodeInstance { - if b == nil { - return nil - } - return b.opencodeInstanceConfig(instanceID) -} - -func (b *Bridge) EnsureGhostDisplayName(ctx context.Context, instanceID string) { - if b == nil { - return - } - b.ensureOpenCodeGhostDisplayName(ctx, instanceID) -} - -func (b *Bridge) CreateSessionChat(ctx context.Context, instanceID, title string, pendingTitle bool) (*bridgev2.CreateChatResponse, error) { - if b == nil { - return nil, ErrUnavailable - } - return b.createOpenCodeSessionChat(ctx, instanceID, title, pendingTitle) -} - func (b *Bridge) AbortSession(ctx context.Context, instanceID, sessionID string) error { if b == nil || b.manager == nil { return ErrUnavailable diff --git a/bridges/opencode/opencodebridge/message_metadata.go b/bridges/opencode/opencodebridge/message_metadata.go index d5d02366..8cd7cc44 100644 --- a/bridges/opencode/opencodebridge/message_metadata.go +++ b/bridges/opencode/opencodebridge/message_metadata.go @@ -7,31 +7,17 @@ import ( ) type MessageMetadata struct { - Role string `json:"role,omitempty"` - Body string `json:"body,omitempty"` - SessionID string `json:"session_id,omitempty"` - MessageID string `json:"message_id,omitempty"` - ParentMessageID string `json:"parent_message_id,omitempty"` - Agent string `json:"agent,omitempty"` - ModelID string `json:"model_id,omitempty"` - ProviderID string `json:"provider_id,omitempty"` - Mode string `json:"mode,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - ErrorText string `json:"error_text,omitempty"` - Cost float64 `json:"cost,omitempty"` - PromptTokens int64 `json:"prompt_tokens,omitempty"` - CompletionTokens int64 `json:"completion_tokens,omitempty"` - ReasoningTokens int64 `json:"reasoning_tokens,omitempty"` - TotalTokens int64 `json:"total_tokens,omitempty"` - TurnID string `json:"turn_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - CanonicalSchema string `json:"canonical_schema,omitempty"` - CanonicalUIMessage map[string]any `json:"canonical_ui_message,omitempty"` - StartedAtMs int64 `json:"started_at_ms,omitempty"` - CompletedAtMs int64 `json:"completed_at_ms,omitempty"` - ThinkingContent string `json:"thinking_content,omitempty"` - ToolCalls []bridgeadapter.ToolCallMetadata `json:"tool_calls,omitempty"` - GeneratedFiles []bridgeadapter.GeneratedFileRef `json:"generated_files,omitempty"` + bridgeadapter.BaseMessageMetadata + SessionID string `json:"session_id,omitempty"` + MessageID string `json:"message_id,omitempty"` + ParentMessageID string `json:"parent_message_id,omitempty"` + Agent string `json:"agent,omitempty"` + ModelID string `json:"model_id,omitempty"` + ProviderID string `json:"provider_id,omitempty"` + Mode string `json:"mode,omitempty"` + ErrorText string `json:"error_text,omitempty"` + Cost float64 `json:"cost,omitempty"` + TotalTokens int64 `json:"total_tokens,omitempty"` } type ToolCallMetadata = bridgeadapter.ToolCallMetadata @@ -45,12 +31,7 @@ func (mm *MessageMetadata) CopyFrom(other any) { if !ok || src == nil { return } - if src.Role != "" { - mm.Role = src.Role - } - if src.Body != "" { - mm.Body = src.Body - } + mm.CopyFromBase(&src.BaseMessageMetadata) if src.SessionID != "" { mm.SessionID = src.SessionID } @@ -72,52 +53,13 @@ func (mm *MessageMetadata) CopyFrom(other any) { if src.Mode != "" { mm.Mode = src.Mode } - if src.FinishReason != "" { - mm.FinishReason = src.FinishReason - } if src.ErrorText != "" { mm.ErrorText = src.ErrorText } if src.Cost != 0 { mm.Cost = src.Cost } - if src.PromptTokens != 0 { - mm.PromptTokens = src.PromptTokens - } - if src.CompletionTokens != 0 { - mm.CompletionTokens = src.CompletionTokens - } - if src.ReasoningTokens != 0 { - mm.ReasoningTokens = src.ReasoningTokens - } if src.TotalTokens != 0 { mm.TotalTokens = src.TotalTokens } - if src.TurnID != "" { - mm.TurnID = src.TurnID - } - if src.AgentID != "" { - mm.AgentID = src.AgentID - } - if src.CanonicalSchema != "" { - mm.CanonicalSchema = src.CanonicalSchema - } - if len(src.CanonicalUIMessage) > 0 { - mm.CanonicalUIMessage = src.CanonicalUIMessage - } - if src.StartedAtMs != 0 { - mm.StartedAtMs = src.StartedAtMs - } - if src.CompletedAtMs != 0 { - mm.CompletedAtMs = src.CompletedAtMs - } - if src.ThinkingContent != "" { - mm.ThinkingContent = src.ThinkingContent - } - if len(src.ToolCalls) > 0 { - mm.ToolCalls = src.ToolCalls - } - if len(src.GeneratedFiles) > 0 { - mm.GeneratedFiles = src.GeneratedFiles - } } diff --git a/bridges/opencode/opencodebridge/opencode_ghost.go b/bridges/opencode/opencodebridge/opencode_ghost.go index 19e2e2e8..a8fc66c3 100644 --- a/bridges/opencode/opencodebridge/opencode_ghost.go +++ b/bridges/opencode/opencodebridge/opencode_ghost.go @@ -7,7 +7,7 @@ import ( "maunium.net/go/mautrix/bridgev2" ) -func (b *Bridge) ensureOpenCodeGhostDisplayName(ctx context.Context, instanceID string) { +func (b *Bridge) EnsureGhostDisplayName(ctx context.Context, instanceID string) { if b == nil || b.host == nil { return } @@ -19,7 +19,7 @@ func (b *Bridge) ensureOpenCodeGhostDisplayName(ctx context.Context, instanceID if err != nil || ghost == nil { return } - displayName := b.opencodeDisplayName(instanceID) + displayName := b.DisplayName(instanceID) if ghost.Name == "" || !ghost.NameSet || ghost.Name != displayName { ghost.UpdateInfo(ctx, &bridgev2.UserInfo{ Name: ptr.Ptr(displayName), diff --git a/bridges/opencode/opencodebridge/opencode_helpers.go b/bridges/opencode/opencodebridge/opencode_helpers.go index 4cb755f2..6b02b3a7 100644 --- a/bridges/opencode/opencodebridge/opencode_helpers.go +++ b/bridges/opencode/opencodebridge/opencode_helpers.go @@ -5,7 +5,7 @@ import ( "strings" ) -func (b *Bridge) opencodeInstanceConfig(instanceID string) *OpenCodeInstance { +func (b *Bridge) InstanceConfig(instanceID string) *OpenCodeInstance { if b == nil || b.host == nil { return nil } @@ -16,8 +16,11 @@ func (b *Bridge) opencodeInstanceConfig(instanceID string) *OpenCodeInstance { return meta[instanceID] } -func (b *Bridge) opencodeDisplayName(instanceID string) string { - cfg := b.opencodeInstanceConfig(instanceID) +func (b *Bridge) DisplayName(instanceID string) string { + if b == nil { + return "" + } + cfg := b.InstanceConfig(instanceID) return opencodeLabelFromURL(cfg) } diff --git a/bridges/opencode/opencodebridge/opencode_manager.go b/bridges/opencode/opencodebridge/opencode_manager.go index adef84df..0c69e725 100644 --- a/bridges/opencode/opencodebridge/opencode_manager.go +++ b/bridges/opencode/opencodebridge/opencode_manager.go @@ -169,7 +169,7 @@ func (m *OpenCodeManager) Connect(ctx context.Context, baseURL, password, userna m.mu.Unlock() m.persistInstance(ctx, inst) - m.bridge.ensureOpenCodeGhostDisplayName(ctx, instanceID) + m.bridge.EnsureGhostDisplayName(ctx, instanceID) count, syncErr := m.syncSessions(ctx, inst, sessions) m.startEventLoop(inst) @@ -355,14 +355,6 @@ func (m *OpenCodeManager) AbortSession(ctx context.Context, instanceID, sessionI return nil } -func (m *OpenCodeManager) CreateSession(ctx context.Context, instanceID, title, directory string) (*opencode.Session, error) { - return m.runCreateSession(ctx, instanceID, title, directory) -} - -func (m *OpenCodeManager) UpdateSessionTitle(ctx context.Context, instanceID, sessionID, title string) (*opencode.Session, error) { - return m.runUpdateSessionTitle(ctx, instanceID, sessionID, title) -} - func (m *OpenCodeManager) runSessionMutation( ctx context.Context, instanceID string, @@ -383,13 +375,13 @@ func (m *OpenCodeManager) runSessionMutation( return session, nil } -func (m *OpenCodeManager) runCreateSession(ctx context.Context, instanceID, title, directory string) (*opencode.Session, error) { +func (m *OpenCodeManager) CreateSession(ctx context.Context, instanceID, title, directory string) (*opencode.Session, error) { return m.runSessionMutation(ctx, instanceID, "create session", func(inst *openCodeInstance) (*opencode.Session, error) { return inst.client.CreateSession(ctx, title, directory) }) } -func (m *OpenCodeManager) runUpdateSessionTitle(ctx context.Context, instanceID, sessionID, title string) (*opencode.Session, error) { +func (m *OpenCodeManager) UpdateSessionTitle(ctx context.Context, instanceID, sessionID, title string) (*opencode.Session, error) { return m.runSessionMutation(ctx, instanceID, "update session title", func(inst *openCodeInstance) (*opencode.Session, error) { return inst.client.UpdateSessionTitle(ctx, sessionID, title) }) diff --git a/bridges/opencode/opencodebridge/opencode_portal.go b/bridges/opencode/opencodebridge/opencode_portal.go index 92937162..4c2b0ae5 100644 --- a/bridges/opencode/opencodebridge/opencode_portal.go +++ b/bridges/opencode/opencodebridge/opencode_portal.go @@ -137,14 +137,14 @@ func (b *Bridge) composeOpenCodeChatInfo(title, instanceID string) *bridgev2.Cha HumanUserID: b.host.HumanUserID(login.ID), LoginID: login.ID, BotUserID: OpenCodeUserID(instanceID), - BotDisplayName: b.opencodeDisplayName(instanceID), + BotDisplayName: b.DisplayName(instanceID), CanBackfill: true, CapabilitiesEvent: b.host.RoomCapabilitiesEventType(), SettingsEvent: b.host.RoomSettingsEventType(), }) } -func (b *Bridge) createOpenCodeSessionChat(ctx context.Context, instanceID, title string, pendingTitle bool) (*bridgev2.CreateChatResponse, error) { +func (b *Bridge) CreateSessionChat(ctx context.Context, instanceID, title string, pendingTitle bool) (*bridgev2.CreateChatResponse, error) { if b == nil || b.host == nil { return nil, errors.New("login unavailable") } diff --git a/bridges/opencode/portal_send.go b/bridges/opencode/portal_send.go index e4a26b01..b2d05e9a 100644 --- a/bridges/opencode/portal_send.go +++ b/bridges/opencode/portal_send.go @@ -2,8 +2,6 @@ package opencode import ( "context" - "fmt" - "time" "maunium.net/go/mautrix/bridgev2" @@ -12,32 +10,20 @@ import ( // sendViaPortal sends a pre-built message through bridgev2's QueueRemoteEvent pipeline. func (oc *OpenCodeClient) sendViaPortal( - ctx context.Context, + _ context.Context, portal *bridgev2.Portal, instanceID string, converted *bridgev2.ConvertedMessage, ) error { - if portal == nil || portal.MXID == "" { - return fmt.Errorf("invalid portal") - } - sender := oc.SenderForOpenCode(instanceID, false) - msgID := bridgeadapter.NewMessageID("opencode") - evt := &OpenCodeRemoteMessage{ - Portal: portal.PortalKey, - ID: msgID, - Sender: sender, - Timestamp: time.Now(), + _, _, err := bridgeadapter.SendViaPortal(bridgeadapter.SendViaPortalParams{ + Login: oc.UserLogin, + Portal: portal, + Sender: oc.SenderForOpenCode(instanceID, false), + IDPrefix: "opencode", LogKey: "opencode_msg_id", - PreBuilt: converted, - } - result := oc.UserLogin.QueueRemoteEvent(evt) - if !result.Success { - if result.Error != nil { - return fmt.Errorf("send failed: %w", result.Error) - } - return fmt.Errorf("send failed") - } - return nil + Converted: converted, + }) + return err } // sendSystemNoticeViaPortal is a convenience wrapper for sending MsgNotice via the pipeline. diff --git a/pkg/bridgeadapter/helpers.go b/pkg/bridgeadapter/helpers.go index 747fda54..bdfdff76 100644 --- a/pkg/bridgeadapter/helpers.go +++ b/pkg/bridgeadapter/helpers.go @@ -1,6 +1,7 @@ package bridgeadapter import ( + "fmt" "time" "go.mau.fi/util/ptr" @@ -8,6 +9,7 @@ import ( "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" "github.com/beeper/ai-bridge/pkg/shared/streamtransport" ) @@ -137,6 +139,47 @@ func BuildDMChatInfo(p DMChatInfoParams) *bridgev2.ChatInfo { } } +// SendViaPortalParams holds the parameters for SendViaPortal. +type SendViaPortalParams struct { + Login *bridgev2.UserLogin + Portal *bridgev2.Portal + Sender bridgev2.EventSender + IDPrefix string // e.g. "ai", "codex", "opencode" + LogKey string // zerolog field name, e.g. "ai_msg_id" + MsgID networkid.MessageID + Converted *bridgev2.ConvertedMessage +} + +// SendViaPortal sends a pre-built message through bridgev2's QueueRemoteEvent pipeline. +// If MsgID is empty, a new one is generated using IDPrefix. +func SendViaPortal(p SendViaPortalParams) (id.EventID, networkid.MessageID, error) { + if p.Portal == nil || p.Portal.MXID == "" { + return "", "", fmt.Errorf("invalid portal") + } + if p.Login == nil { + return "", p.MsgID, fmt.Errorf("bridge unavailable") + } + if p.MsgID == "" { + p.MsgID = NewMessageID(p.IDPrefix) + } + evt := &RemoteMessage{ + Portal: p.Portal.PortalKey, + ID: p.MsgID, + Sender: p.Sender, + Timestamp: time.Now(), + LogKey: p.LogKey, + PreBuilt: p.Converted, + } + result := p.Login.QueueRemoteEvent(evt) + if !result.Success { + if result.Error != nil { + return "", p.MsgID, fmt.Errorf("send failed: %w", result.Error) + } + return "", p.MsgID, fmt.Errorf("send failed") + } + return result.EventID, p.MsgID, nil +} + func BuildChatInfoWithFallback(metaTitle, portalName, fallbackTitle, portalTopic string) *bridgev2.ChatInfo { title := metaTitle if title == "" { diff --git a/pkg/bridgeadapter/message_metadata.go b/pkg/bridgeadapter/message_metadata.go index f3e0415a..914c6ac6 100644 --- a/pkg/bridgeadapter/message_metadata.go +++ b/pkg/bridgeadapter/message_metadata.go @@ -1,5 +1,77 @@ package bridgeadapter +// BaseMessageMetadata contains fields common to all bridge MessageMetadata structs. +// Embed this in each bridge's MessageMetadata to share CopyFrom logic. +type BaseMessageMetadata struct { + Role string `json:"role,omitempty"` + Body string `json:"body,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + PromptTokens int64 `json:"prompt_tokens,omitempty"` + CompletionTokens int64 `json:"completion_tokens,omitempty"` + ReasoningTokens int64 `json:"reasoning_tokens,omitempty"` + TurnID string `json:"turn_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + CanonicalSchema string `json:"canonical_schema,omitempty"` + CanonicalUIMessage map[string]any `json:"canonical_ui_message,omitempty"` + StartedAtMs int64 `json:"started_at_ms,omitempty"` + CompletedAtMs int64 `json:"completed_at_ms,omitempty"` + ThinkingContent string `json:"thinking_content,omitempty"` + ToolCalls []ToolCallMetadata `json:"tool_calls,omitempty"` + GeneratedFiles []GeneratedFileRef `json:"generated_files,omitempty"` +} + +// CopyFromBase copies non-zero common fields from src into the receiver. +func (b *BaseMessageMetadata) CopyFromBase(src *BaseMessageMetadata) { + if src == nil { + return + } + if src.Role != "" { + b.Role = src.Role + } + if src.Body != "" { + b.Body = src.Body + } + if src.FinishReason != "" { + b.FinishReason = src.FinishReason + } + if src.PromptTokens != 0 { + b.PromptTokens = src.PromptTokens + } + if src.CompletionTokens != 0 { + b.CompletionTokens = src.CompletionTokens + } + if src.ReasoningTokens != 0 { + b.ReasoningTokens = src.ReasoningTokens + } + if src.TurnID != "" { + b.TurnID = src.TurnID + } + if src.AgentID != "" { + b.AgentID = src.AgentID + } + if src.CanonicalSchema != "" { + b.CanonicalSchema = src.CanonicalSchema + } + if len(src.CanonicalUIMessage) > 0 { + b.CanonicalUIMessage = src.CanonicalUIMessage + } + if src.StartedAtMs != 0 { + b.StartedAtMs = src.StartedAtMs + } + if src.CompletedAtMs != 0 { + b.CompletedAtMs = src.CompletedAtMs + } + if src.ThinkingContent != "" { + b.ThinkingContent = src.ThinkingContent + } + if len(src.ToolCalls) > 0 { + b.ToolCalls = src.ToolCalls + } + if len(src.GeneratedFiles) > 0 { + b.GeneratedFiles = src.GeneratedFiles + } +} + // ToolCallMetadata tracks a tool call within a message. // Both bridges and the connector share this type for JSON-serialized database storage. type ToolCallMetadata struct { diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index 03e80aa9..1a04d232 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -1120,8 +1120,7 @@ func (oc *AIClient) copyMessagesToChat( Content: srcMeta.Body, Timestamp: srcMsg.Timestamp, Metadata: &MessageMetadata{ - Role: srcMeta.Role, - Body: srcMeta.Body, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: srcMeta.Role, Body: srcMeta.Body}, }, } diff --git a/pkg/connector/client.go b/pkg/connector/client.go index b7acb3c7..df2ad2ee 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -2575,8 +2575,7 @@ func (oc *AIClient) handleDebouncedMessages(entries []DebounceEntry) { Room: last.Portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), Metadata: &MessageMetadata{ - Role: "user", - Body: combinedBody, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: "user", Body: combinedBody}, }, Timestamp: time.Now(), } diff --git a/pkg/connector/handlematrix.go b/pkg/connector/handlematrix.go index 1876a367..c3756353 100644 --- a/pkg/connector/handlematrix.go +++ b/pkg/connector/handlematrix.go @@ -318,8 +318,7 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri Room: portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), Metadata: &MessageMetadata{ - Role: "user", - Body: body, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: "user", Body: body}, }, Timestamp: bridgeadapter.MatrixEventTimestamp(msg.Event), } @@ -846,8 +845,10 @@ func (oc *AIClient) handleMediaMessage( } userMeta := &MessageMetadata{ - Role: "user", - Body: oc.buildMatrixInboundBody(ctx, portal, meta, msg.Event, buildMediaMetadataBody(caption, config.bodySuffix, understanding), senderName, roomName, isGroup), + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: "user", + Body: oc.buildMatrixInboundBody(ctx, portal, meta, msg.Event, buildMediaMetadataBody(caption, config.bodySuffix, understanding), senderName, roomName, isGroup), + }, MediaURL: string(mediaURL), MimeType: mimeType, } diff --git a/pkg/connector/integration_host.go b/pkg/connector/integration_host.go index b8323b86..e5f09040 100644 --- a/pkg/connector/integration_host.go +++ b/pkg/connector/integration_host.go @@ -451,7 +451,7 @@ func (h *runtimeIntegrationHost) ResolveAgentID(raw string, fallbackDefault stri return agents.DefaultAgentID } normalized := normalizeAgentID(raw) - if normalized == "" || !h.agentExists(normalized) { + if normalized == "" || !h.AgentExists(normalized) { if fallbackDefault != "" { return normalizeAgentID(fallbackDefault) } @@ -465,10 +465,6 @@ func (h *runtimeIntegrationHost) NormalizeAgentID(raw string) string { } func (h *runtimeIntegrationHost) AgentExists(normalizedID string) bool { - return h.agentExists(normalizedID) -} - -func (h *runtimeIntegrationHost) agentExists(normalizedID string) bool { if h == nil || h.client == nil || h.client.connector == nil { return false } diff --git a/pkg/connector/metadata.go b/pkg/connector/metadata.go index ed8fc2bc..2f1e0f33 100644 --- a/pkg/connector/metadata.go +++ b/pkg/connector/metadata.go @@ -270,48 +270,23 @@ func clonePortalMetadata(src *PortalMetadata) *PortalMetadata { // MessageMetadata keeps a tiny summary of each exchange so we can rebuild // prompts using database history. type MessageMetadata struct { - Role string `json:"role,omitempty"` - Body string `json:"body,omitempty"` + bridgeadapter.BaseMessageMetadata + CompletionID string `json:"completion_id,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - PromptTokens int64 `json:"prompt_tokens,omitempty"` - CompletionTokens int64 `json:"completion_tokens,omitempty"` Model string `json:"model,omitempty"` - ReasoningTokens int64 `json:"reasoning_tokens,omitempty"` HasToolCalls bool `json:"has_tool_calls,omitempty"` Transcript string `json:"transcript,omitempty"` + FirstTokenAtMs int64 `json:"first_token_at_ms,omitempty"` + ThinkingTokenCount int `json:"thinking_token_count,omitempty"` + ExcludeFromHistory bool `json:"exclude_from_history,omitempty"` // Media understanding (OpenClaw-style) MediaUnderstanding []MediaUnderstandingOutput `json:"media_understanding,omitempty"` MediaUnderstandingDecisions []MediaUnderstandingDecision `json:"media_understanding_decisions,omitempty"` - // Turn tracking for the new schema - TurnID string `json:"turn_id,omitempty"` // Unique identifier for this assistant turn - AgentID string `json:"agent_id,omitempty"` // Which agent generated this (for multi-agent rooms) - - // Tool call tracking - ToolCalls []ToolCallMetadata `json:"tool_calls,omitempty"` // List of tool calls in this turn - - // Canonical internal schema payload (AI SDK compatible). - CanonicalSchema string `json:"canonical_schema,omitempty"` // e.g. ai-sdk-ui-message-v1 - CanonicalUIMessage map[string]any `json:"canonical_ui_message,omitempty"` // AI SDK UIMessage-compatible payload - - // Timing information - StartedAtMs int64 `json:"started_at_ms,omitempty"` // Unix ms when generation started - FirstTokenAtMs int64 `json:"first_token_at_ms,omitempty"` // Unix ms of first token - CompletedAtMs int64 `json:"completed_at_ms,omitempty"` // Unix ms when completed - - // Thinking/reasoning content (embedded, not separate) - ThinkingContent string `json:"thinking_content,omitempty"` // Full thinking text - ThinkingTokenCount int `json:"thinking_token_count,omitempty"` // Number of thinking tokens - - // History exclusion - ExcludeFromHistory bool `json:"exclude_from_history,omitempty"` // Exclude from LLM context (e.g., welcome messages) - // Multimodal history: media attached to this message for re-injection into prompts. - MediaURL string `json:"media_url,omitempty"` // mxc:// URL for user-sent media (image, PDF, audio, video) - MimeType string `json:"mime_type,omitempty"` // MIME type of user-sent media - GeneratedFiles []GeneratedFileRef `json:"generated_files,omitempty"` // Files generated by the assistant in this turn + MediaURL string `json:"media_url,omitempty"` // mxc:// URL for user-sent media + MimeType string `json:"mime_type,omitempty"` // MIME type of user-sent media } type GeneratedFileRef = bridgeadapter.GeneratedFileRef @@ -329,65 +304,28 @@ func (mm *MessageMetadata) CopyFrom(other any) { if !ok || src == nil { return } - if src.Role != "" { - mm.Role = src.Role - } - if src.Body != "" { - mm.Body = src.Body - } + mm.CopyFromBase(&src.BaseMessageMetadata) if src.CompletionID != "" { mm.CompletionID = src.CompletionID } - if src.FinishReason != "" { - mm.FinishReason = src.FinishReason - } - if src.PromptTokens != 0 { - mm.PromptTokens = src.PromptTokens - } - if src.CompletionTokens != 0 { - mm.CompletionTokens = src.CompletionTokens - } if src.Model != "" { mm.Model = src.Model } - if src.ReasoningTokens != 0 { - mm.ReasoningTokens = src.ReasoningTokens - } if src.HasToolCalls { mm.HasToolCalls = true } - - // Copy new fields - if src.TurnID != "" { - mm.TurnID = src.TurnID - } - if src.AgentID != "" { - mm.AgentID = src.AgentID - } - if len(src.ToolCalls) > 0 { - mm.ToolCalls = src.ToolCalls - } - if src.CanonicalSchema != "" { - mm.CanonicalSchema = src.CanonicalSchema - } - if len(src.CanonicalUIMessage) > 0 { - mm.CanonicalUIMessage = src.CanonicalUIMessage - } - if src.StartedAtMs != 0 { - mm.StartedAtMs = src.StartedAtMs + if src.Transcript != "" { + mm.Transcript = src.Transcript } if src.FirstTokenAtMs != 0 { mm.FirstTokenAtMs = src.FirstTokenAtMs } - if src.CompletedAtMs != 0 { - mm.CompletedAtMs = src.CompletedAtMs - } - if src.ThinkingContent != "" { - mm.ThinkingContent = src.ThinkingContent - } if src.ThinkingTokenCount != 0 { mm.ThinkingTokenCount = src.ThinkingTokenCount } + if src.ExcludeFromHistory { + mm.ExcludeFromHistory = true + } } var _ database.MetaMerger = (*MessageMetadata)(nil) diff --git a/pkg/connector/portal_send.go b/pkg/connector/portal_send.go index f10de140..db6bd263 100644 --- a/pkg/connector/portal_send.go +++ b/pkg/connector/portal_send.go @@ -30,39 +30,23 @@ func ensureConvertedMessageParts(converted *bridgev2.ConvertedMessage) { converted.Parts = parts } -// Handles: intent resolution, ghost room join, send, DB persist via QueueRemoteEvent. -// Returns the Matrix event ID and the network message ID used. -// If msgID is empty, a new one is generated. +// sendViaPortal sends a pre-built message through bridgev2's QueueRemoteEvent pipeline. func (oc *AIClient) sendViaPortal( ctx context.Context, portal *bridgev2.Portal, converted *bridgev2.ConvertedMessage, msgID networkid.MessageID, ) (id.EventID, networkid.MessageID, error) { - if portal == nil || portal.MXID == "" { - return "", "", fmt.Errorf("invalid portal") - } - if msgID == "" { - msgID = bridgeadapter.NewMessageID("ai") - } ensureConvertedMessageParts(converted) - sender := oc.senderForPortal(ctx, portal) - evt := &bridgeadapter.RemoteMessage{ - Portal: portal.PortalKey, - ID: msgID, - Sender: sender, - Timestamp: time.Now(), + return bridgeadapter.SendViaPortal(bridgeadapter.SendViaPortalParams{ + Login: oc.UserLogin, + Portal: portal, + Sender: oc.senderForPortal(ctx, portal), + IDPrefix: "ai", LogKey: "ai_msg_id", - PreBuilt: converted, - } - result := oc.UserLogin.QueueRemoteEvent(evt) - if !result.Success { - if result.Error != nil { - return "", msgID, fmt.Errorf("send failed: %w", result.Error) - } - return "", msgID, fmt.Errorf("send failed") - } - return result.EventID, msgID, nil + MsgID: msgID, + Converted: converted, + }) } // The targetMsgID is the network message ID of the message to edit. From 6b67d5a1f458d855abee54a778258f3950ce0292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:34:39 +0100 Subject: [PATCH 14/23] Use BaseMessageMetadata in MessageMetadata Refactor MessageMetadata usage to embed bridgeadapter.BaseMessageMetadata across the codebase. Updates constructors and event metadata to move common fields (Role, Body, TurnID, AgentID, token counts, canonical UI/schema, thinking content, tool calls, timestamps, etc.) into BaseMessageMetadata and adjust remaining top-level fields (SessionID, MessageID, Model, etc.) where appropriate. Files updated: bridges/opencode/{host.go,opencodebridge/backfill_canonical.go,stream_canonical.go} and pkg/connector/{handlematrix.go,internal_dispatch.go,response_finalization.go,streaming_persistence.go,subagent_spawn.go,toast.go}. Also adds necessary bridgeadapter imports and minor field reorderings to match the new structure. This centralizes base metadata handling for clearer structure and reuse. --- bridges/opencode/host.go | 12 +++-- .../opencodebridge/backfill_canonical.go | 1 + bridges/opencode/stream_canonical.go | 53 ++++++++++--------- pkg/connector/handlematrix.go | 6 +-- pkg/connector/internal_dispatch.go | 5 +- pkg/connector/response_finalization.go | 2 +- pkg/connector/streaming_persistence.go | 32 +++++------ pkg/connector/subagent_spawn.go | 3 +- pkg/connector/toast.go | 8 +-- 9 files changed, 64 insertions(+), 58 deletions(-) diff --git a/bridges/opencode/host.go b/bridges/opencode/host.go index 63b813af..adbd30c6 100644 --- a/bridges/opencode/host.go +++ b/bridges/opencode/host.go @@ -143,11 +143,13 @@ func (oc *OpenCodeClient) EmitOpenCodeStreamEvent(ctx context.Context, portal *b Content: &event.MessageEventContent{MsgType: event.MsgText, Body: "..."}, Extra: extra, DBMetadata: &MessageMetadata{ - Role: "assistant", - TurnID: turnID, - AgentID: strings.TrimSpace(agentID), - CanonicalSchema: "ai-sdk-ui-message-v1", - CanonicalUIMessage: uiMessage, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: "assistant", + TurnID: turnID, + AgentID: strings.TrimSpace(agentID), + CanonicalSchema: "ai-sdk-ui-message-v1", + CanonicalUIMessage: uiMessage, + }, }, }}, } diff --git a/bridges/opencode/opencodebridge/backfill_canonical.go b/bridges/opencode/opencodebridge/backfill_canonical.go index 3c181524..9b2eefe3 100644 --- a/bridges/opencode/opencodebridge/backfill_canonical.go +++ b/bridges/opencode/opencodebridge/backfill_canonical.go @@ -6,6 +6,7 @@ import ( "maunium.net/go/mautrix/event" "github.com/beeper/ai-bridge/bridges/opencode/opencode" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/matrixevents" "github.com/beeper/ai-bridge/pkg/shared/streamui" "github.com/beeper/ai-bridge/pkg/shared/stringutil" diff --git a/bridges/opencode/stream_canonical.go b/bridges/opencode/stream_canonical.go index 598cead0..b5ff296c 100644 --- a/bridges/opencode/stream_canonical.go +++ b/bridges/opencode/stream_canonical.go @@ -11,6 +11,7 @@ import ( "maunium.net/go/mautrix/format" "github.com/beeper/ai-bridge/bridges/opencode/opencodebridge" + "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/connector/msgconv" "github.com/beeper/ai-bridge/pkg/matrixevents" "github.com/beeper/ai-bridge/pkg/shared/maputil" @@ -117,31 +118,33 @@ func (oc *OpenCodeClient) buildStreamDBMetadata(state *openCodeStreamState) *Mes uiMessage := oc.currentCanonicalUIMessage(state) thinking := opencodebridge.CanonicalReasoningText(uiMessage) return &MessageMetadata{ - Role: stringutil.FirstNonEmpty(state.role, "assistant"), - Body: stringutil.FirstNonEmpty(state.visible.String(), state.accumulated.String()), - SessionID: state.sessionID, - MessageID: state.messageID, - ParentMessageID: state.parentMessageID, - Agent: state.agent, - ModelID: state.modelID, - ProviderID: state.providerID, - Mode: state.mode, - FinishReason: state.finishReason, - ErrorText: state.errorText, - Cost: state.cost, - PromptTokens: state.promptTokens, - CompletionTokens: state.completionTokens, - ReasoningTokens: state.reasoningTokens, - TotalTokens: state.totalTokens, - TurnID: state.turnID, - AgentID: state.agentID, - CanonicalSchema: "ai-sdk-ui-message-v1", - CanonicalUIMessage: uiMessage, - StartedAtMs: state.startedAtMs, - CompletedAtMs: state.completedAtMs, - ThinkingContent: thinking, - ToolCalls: opencodebridge.CanonicalToolCalls(uiMessage), - GeneratedFiles: opencodebridge.CanonicalGeneratedFiles(uiMessage), + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: stringutil.FirstNonEmpty(state.role, "assistant"), + Body: stringutil.FirstNonEmpty(state.visible.String(), state.accumulated.String()), + FinishReason: state.finishReason, + PromptTokens: state.promptTokens, + CompletionTokens: state.completionTokens, + ReasoningTokens: state.reasoningTokens, + TurnID: state.turnID, + AgentID: state.agentID, + CanonicalSchema: "ai-sdk-ui-message-v1", + CanonicalUIMessage: uiMessage, + StartedAtMs: state.startedAtMs, + CompletedAtMs: state.completedAtMs, + ThinkingContent: thinking, + ToolCalls: opencodebridge.CanonicalToolCalls(uiMessage), + GeneratedFiles: opencodebridge.CanonicalGeneratedFiles(uiMessage), + }, + SessionID: state.sessionID, + MessageID: state.messageID, + ParentMessageID: state.parentMessageID, + Agent: state.agent, + ModelID: state.modelID, + ProviderID: state.providerID, + Mode: state.mode, + ErrorText: state.errorText, + Cost: state.cost, + TotalTokens: state.totalTokens, } } diff --git a/pkg/connector/handlematrix.go b/pkg/connector/handlematrix.go index c3756353..f6ade451 100644 --- a/pkg/connector/handlematrix.go +++ b/pkg/connector/handlematrix.go @@ -726,8 +726,7 @@ func (oc *AIClient) handleMediaMessage( Room: portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), Metadata: &MessageMetadata{ - Role: "user", - Body: body, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: "user", Body: body}, }, Timestamp: bridgeadapter.MatrixEventTimestamp(msg.Event), } @@ -1024,8 +1023,7 @@ func (oc *AIClient) handleTextFileMessage( Room: portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), Metadata: &MessageMetadata{ - Role: "user", - Body: combined, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: "user", Body: combined}, }, Timestamp: bridgeadapter.MatrixEventTimestamp(msg.Event), } diff --git a/pkg/connector/internal_dispatch.go b/pkg/connector/internal_dispatch.go index 73763247..0f322064 100644 --- a/pkg/connector/internal_dispatch.go +++ b/pkg/connector/internal_dispatch.go @@ -66,9 +66,8 @@ func (oc *AIClient) dispatchInternalMessage( Room: portal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), Metadata: &MessageMetadata{ - Role: "user", - Body: trimmed, - ExcludeFromHistory: excludeFromHistory, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: "user", Body: trimmed}, + ExcludeFromHistory: excludeFromHistory, }, Timestamp: time.Now(), } diff --git a/pkg/connector/response_finalization.go b/pkg/connector/response_finalization.go index a4e0906d..04f05b6c 100644 --- a/pkg/connector/response_finalization.go +++ b/pkg/connector/response_finalization.go @@ -144,7 +144,7 @@ func (oc *AIClient) sendInitialStreamMessage(ctx context.Context, portal *bridge Type: event.EventMessage, Content: &event.MessageEventContent{MsgType: event.MsgText, Body: content}, Extra: eventRaw, - DBMetadata: &MessageMetadata{Role: "assistant", TurnID: turnID}, + DBMetadata: &MessageMetadata{BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: "assistant", TurnID: turnID}}, }}, } diff --git a/pkg/connector/streaming_persistence.go b/pkg/connector/streaming_persistence.go index d520cdcc..63dd6f20 100644 --- a/pkg/connector/streaming_persistence.go +++ b/pkg/connector/streaming_persistence.go @@ -36,26 +36,28 @@ func (oc *AIClient) saveAssistantMessage( } fullMeta := &MessageMetadata{ - Role: "assistant", - Body: state.accumulated.String(), + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: "assistant", + Body: state.accumulated.String(), + FinishReason: state.finishReason, + TurnID: state.turnID, + AgentID: state.agentID, + ToolCalls: state.toolCalls, + StartedAtMs: state.startedAtMs, + CompletedAtMs: state.completedAtMs, + CanonicalSchema: "ai-sdk-ui-message-v1", + CanonicalUIMessage: oc.buildCanonicalUIMessage(state, meta), + GeneratedFiles: genFiles, + ThinkingContent: state.reasoning.String(), + PromptTokens: state.promptTokens, + CompletionTokens: state.completionTokens, + ReasoningTokens: state.reasoningTokens, + }, CompletionID: state.responseID, - FinishReason: state.finishReason, Model: modelID, - TurnID: state.turnID, - AgentID: state.agentID, - ToolCalls: state.toolCalls, - StartedAtMs: state.startedAtMs, FirstTokenAtMs: state.firstTokenAtMs, - CompletedAtMs: state.completedAtMs, HasToolCalls: len(state.toolCalls) > 0, - CanonicalSchema: "ai-sdk-ui-message-v1", - CanonicalUIMessage: oc.buildCanonicalUIMessage(state, meta), - GeneratedFiles: genFiles, - ThinkingContent: state.reasoning.String(), ThinkingTokenCount: thinkingTokenCount(modelID, state.reasoning.String()), - PromptTokens: state.promptTokens, - CompletionTokens: state.completionTokens, - ReasoningTokens: state.reasoningTokens, } // If the message was sent via sendViaPortal, the DB row already exists — update it. diff --git a/pkg/connector/subagent_spawn.go b/pkg/connector/subagent_spawn.go index b102cc69..0825d34c 100644 --- a/pkg/connector/subagent_spawn.go +++ b/pkg/connector/subagent_spawn.go @@ -359,8 +359,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P Room: childPortal.PortalKey, SenderID: humanUserID(oc.UserLogin.ID), Metadata: &MessageMetadata{ - Role: "user", - Body: task, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: "user", Body: task}, }, Timestamp: time.Now(), } diff --git a/pkg/connector/toast.go b/pkg/connector/toast.go index 830fe009..c2f8ebe7 100644 --- a/pkg/connector/toast.go +++ b/pkg/connector/toast.go @@ -173,10 +173,12 @@ func buildApprovalSnapshotPart(body string, uiMessage map[string]any, toastText Content: &event.MessageEventContent{MsgType: event.MsgNotice, Body: body}, Extra: raw, DBMetadata: &MessageMetadata{ - Role: "assistant", + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: "assistant", + CanonicalSchema: "ai-sdk-ui-message-v1", + CanonicalUIMessage: uiMessage, + }, ExcludeFromHistory: true, - CanonicalSchema: "ai-sdk-ui-message-v1", - CanonicalUIMessage: uiMessage, }, } } From 34de9880443111ff879160b10638dda790da2d6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:37:17 +0100 Subject: [PATCH 15/23] Unify prompt text appending Replace separate appendSystemPromptText/appendDeveloperPromptText helpers with a single appendPromptText(dst *string, text string) helper. Update call sites in pkg/connector/messages.go and pkg/connector/client.go to pass the target prompt field by pointer (e.g. &ctx.SystemPrompt). Behavior preserved: trims input, ignores empty text, and joins existing/prompts with a double newline. This reduces duplication and simplifies prompt handling. --- bridges/codex/client.go | 55 ++++---------------- pkg/bridgeadapter/helpers.go | 69 ++++++++++++++++++++++++++ pkg/connector/client.go | 4 +- pkg/connector/messages.go | 24 +++------ pkg/connector/streaming_persistence.go | 69 ++++---------------------- 5 files changed, 96 insertions(+), 125 deletions(-) diff --git a/bridges/codex/client.go b/bridges/codex/client.go index 1a48dbc0..cc5614fd 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -2084,52 +2084,15 @@ func (cc *CodexClient) saveAssistantMessage(ctx context.Context, portal *bridgev ThinkingTokenCount: len(strings.Fields(state.reasoning.String())), } - // If the message was sent via sendViaPortal, the DB row already exists — update it. - if state.networkMessageID != "" { - receiver := portal.Receiver - if receiver == "" && cc.UserLogin != nil { - receiver = cc.UserLogin.ID - } - var existing *database.Message - var err error - if receiver != "" { - existing, err = cc.UserLogin.Bridge.DB.Message.GetPartByID(ctx, receiver, state.networkMessageID, networkid.PartID("0")) - } - if existing == nil && state.initialEventID != "" { - existing, err = cc.UserLogin.Bridge.DB.Message.GetPartByMXID(ctx, state.initialEventID) - } - if err == nil && existing != nil { - existing.Metadata = fullMeta - if err := cc.UserLogin.Bridge.DB.Message.Update(ctx, existing); err != nil { - log.Warn().Err(err).Str("msg_id", string(existing.ID)).Msg("Failed to update assistant message metadata") - } else { - log.Debug().Str("msg_id", string(existing.ID)).Msg("Updated assistant message metadata") - } - return - } - log.Warn(). - Err(err). - Stringer("mxid", state.initialEventID). - Str("msg_id", string(state.networkMessageID)). - Msg("Could not find existing DB row for update, falling back to insert") - } - if state.initialEventID == "" { - return - } - - assistantMsg := &database.Message{ - ID: bridgeadapter.MatrixMessageID(state.initialEventID), - Room: portal.PortalKey, - SenderID: codexGhostID, - MXID: state.initialEventID, - Timestamp: time.Now(), - Metadata: fullMeta, - } - if err := cc.UserLogin.Bridge.DB.Message.Insert(ctx, assistantMsg); err != nil { - log.Warn().Err(err).Msg("Failed to save assistant message") - } else { - log.Debug().Str("msg_id", string(assistantMsg.ID)).Msg("Saved assistant message to database") - } + bridgeadapter.UpsertAssistantMessage(ctx, bridgeadapter.UpsertAssistantMessageParams{ + Login: cc.UserLogin, + Portal: portal, + SenderID: codexGhostID, + NetworkMessageID: state.networkMessageID, + InitialEventID: state.initialEventID, + Metadata: fullMeta, + Logger: log, + }) } // --- Approvals --- diff --git a/pkg/bridgeadapter/helpers.go b/pkg/bridgeadapter/helpers.go index bdfdff76..9a3cfea8 100644 --- a/pkg/bridgeadapter/helpers.go +++ b/pkg/bridgeadapter/helpers.go @@ -1,9 +1,11 @@ package bridgeadapter import ( + "context" "fmt" "time" + "github.com/rs/zerolog" "go.mau.fi/util/ptr" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" @@ -194,3 +196,70 @@ func BuildChatInfoWithFallback(metaTitle, portalName, fallbackTitle, portalTopic Topic: ptr.NonZero(portalTopic), } } + +// UpsertAssistantMessageParams holds parameters for UpsertAssistantMessage. +type UpsertAssistantMessageParams struct { + Login *bridgev2.UserLogin + Portal *bridgev2.Portal + SenderID networkid.UserID + NetworkMessageID networkid.MessageID + InitialEventID id.EventID + Metadata any // must satisfy database.MetaMerger + Logger zerolog.Logger +} + +// UpsertAssistantMessage updates an existing message's metadata or inserts a new one. +// If NetworkMessageID is set, tries to find and update the existing row first. +// Falls back to inserting a new row keyed by InitialEventID. +func UpsertAssistantMessage(ctx context.Context, p UpsertAssistantMessageParams) { + if p.Login == nil || p.Portal == nil { + return + } + db := p.Login.Bridge.DB.Message + + if p.NetworkMessageID != "" { + receiver := p.Portal.Receiver + if receiver == "" { + receiver = p.Login.ID + } + var existing *database.Message + var err error + if receiver != "" { + existing, err = db.GetPartByID(ctx, receiver, p.NetworkMessageID, networkid.PartID("0")) + } + if existing == nil && p.InitialEventID != "" { + existing, err = db.GetPartByMXID(ctx, p.InitialEventID) + } + if err == nil && existing != nil { + existing.Metadata = p.Metadata + if err := db.Update(ctx, existing); err != nil { + p.Logger.Warn().Err(err).Str("msg_id", string(existing.ID)).Msg("Failed to update assistant message metadata") + } else { + p.Logger.Debug().Str("msg_id", string(existing.ID)).Msg("Updated assistant message metadata") + } + return + } + p.Logger.Warn(). + Err(err). + Stringer("mxid", p.InitialEventID). + Str("msg_id", string(p.NetworkMessageID)). + Msg("Could not find existing DB row for update, falling back to insert") + } + + if p.InitialEventID == "" { + return + } + assistantMsg := &database.Message{ + ID: MatrixMessageID(p.InitialEventID), + Room: p.Portal.PortalKey, + SenderID: p.SenderID, + MXID: p.InitialEventID, + Timestamp: time.Now(), + Metadata: p.Metadata, + } + if err := db.Insert(ctx, assistantMsg); err != nil { + p.Logger.Warn().Err(err).Msg("Failed to insert assistant message to database") + } else { + p.Logger.Debug().Str("msg_id", string(assistantMsg.ID)).Msg("Inserted assistant message to database") + } +} diff --git a/pkg/connector/client.go b/pkg/connector/client.go index df2ad2ee..950780a9 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -2031,7 +2031,7 @@ func (oc *AIClient) buildContextWithLinkContext( isSimple := isSimpleMode(meta) if !isSimple { - appendSystemPromptText(&promptContext, airuntime.BuildInboundMetaSystemPrompt(inboundCtx)) + appendPromptText(&promptContext.SystemPrompt, airuntime.BuildInboundMetaSystemPrompt(inboundCtx)) } finalMessage := strings.TrimSpace(latest) @@ -2167,7 +2167,7 @@ func (oc *AIClient) buildContextWithMedia( isSimple := isSimpleMode(meta) inboundCtx := oc.resolvePromptInboundContext(ctx, portal, caption, eventID) if !isSimple { - appendSystemPromptText(&promptContext, airuntime.BuildInboundMetaSystemPrompt(inboundCtx)) + appendPromptText(&promptContext.SystemPrompt, airuntime.BuildInboundMetaSystemPrompt(inboundCtx)) } captionWithID := strings.TrimSpace(caption) diff --git a/pkg/connector/messages.go b/pkg/connector/messages.go index e45bd2bc..706f0c3c 100644 --- a/pkg/connector/messages.go +++ b/pkg/connector/messages.go @@ -268,9 +268,9 @@ func appendChatMessageToPromptContext(ctx *PromptContext, msg openai.ChatComplet } switch { case msg.OfSystem != nil: - appendSystemPromptText(ctx, extractChatSystemText(msg.OfSystem.Content)) + appendPromptText(&ctx.SystemPrompt, extractChatSystemText(msg.OfSystem.Content)) case msg.OfDeveloper != nil: - appendDeveloperPromptText(ctx, extractChatDeveloperText(msg.OfDeveloper.Content)) + appendPromptText(&ctx.DeveloperPrompt, extractChatDeveloperText(msg.OfDeveloper.Content)) case msg.OfUser != nil: ctx.Messages = append(ctx.Messages, promptMessageFromChatUser(msg.OfUser)) case msg.OfAssistant != nil: @@ -280,28 +280,16 @@ func appendChatMessageToPromptContext(ctx *PromptContext, msg openai.ChatComplet } } -func appendSystemPromptText(ctx *PromptContext, text string) { +func appendPromptText(dst *string, text string) { text = strings.TrimSpace(text) if text == "" { return } - if ctx.SystemPrompt == "" { - ctx.SystemPrompt = text + if *dst == "" { + *dst = text return } - ctx.SystemPrompt = strings.TrimSpace(ctx.SystemPrompt + "\n\n" + text) -} - -func appendDeveloperPromptText(ctx *PromptContext, text string) { - text = strings.TrimSpace(text) - if text == "" { - return - } - if ctx.DeveloperPrompt == "" { - ctx.DeveloperPrompt = text - return - } - ctx.DeveloperPrompt = strings.TrimSpace(ctx.DeveloperPrompt + "\n\n" + text) + *dst = strings.TrimSpace(*dst + "\n\n" + text) } func promptMessageFromChatUser(msg *openai.ChatCompletionUserMessageParam) PromptMessage { diff --git a/pkg/connector/streaming_persistence.go b/pkg/connector/streaming_persistence.go index 63dd6f20..6d385d22 100644 --- a/pkg/connector/streaming_persistence.go +++ b/pkg/connector/streaming_persistence.go @@ -60,38 +60,15 @@ func (oc *AIClient) saveAssistantMessage( ThinkingTokenCount: thinkingTokenCount(modelID, state.reasoning.String()), } - // If the message was sent via sendViaPortal, the DB row already exists — update it. - if state.networkMessageID != "" { - receiver := portal.Receiver - if receiver == "" && oc.UserLogin != nil { - receiver = oc.UserLogin.ID - } - var existing *database.Message - var err error - if receiver != "" { - existing, err = oc.UserLogin.Bridge.DB.Message.GetPartByID(ctx, receiver, state.networkMessageID, networkid.PartID("0")) - } - if existing == nil && state.initialEventID != "" { - existing, err = oc.UserLogin.Bridge.DB.Message.GetPartByMXID(ctx, state.initialEventID) - } - if err == nil && existing != nil { - existing.Metadata = fullMeta - if err := oc.UserLogin.Bridge.DB.Message.Update(ctx, existing); err != nil { - log.Warn().Err(err).Str("msg_id", string(existing.ID)).Msg("Failed to update assistant message metadata") - } else { - log.Debug().Str("msg_id", string(existing.ID)).Msg("Updated assistant message metadata") - } - } else { - log.Warn(). - Err(err). - Stringer("mxid", state.initialEventID). - Str("msg_id", string(state.networkMessageID)). - Msg("Could not find existing DB row for update, falling back to insert") - oc.insertAssistantMessage(ctx, log, portal, state, modelID, fullMeta) - } - } else { - oc.insertAssistantMessage(ctx, log, portal, state, modelID, fullMeta) - } + bridgeadapter.UpsertAssistantMessage(ctx, bridgeadapter.UpsertAssistantMessageParams{ + Login: oc.UserLogin, + Portal: portal, + SenderID: modelUserID(modelID), + NetworkMessageID: state.networkMessageID, + InitialEventID: state.initialEventID, + Metadata: fullMeta, + Logger: log, + }) usageMetaUpdated := false if meta != nil && (state.promptTokens > 0 || state.completionTokens > 0) { @@ -110,33 +87,7 @@ func (oc *AIClient) saveAssistantMessage( oc.notifySessionMutation(ctx, portal, meta, false) } -// insertAssistantMessage is the fallback path for saving assistant messages when no -// pre-existing DB row was created by sendViaPortal. -func (oc *AIClient) insertAssistantMessage( - ctx context.Context, - log zerolog.Logger, - portal *bridgev2.Portal, - state *streamingState, - modelID string, - meta *MessageMetadata, -) { - if state == nil || state.initialEventID == "" { - return - } - assistantMsg := &database.Message{ - ID: bridgeadapter.MatrixMessageID(state.initialEventID), - Room: portal.PortalKey, - SenderID: modelUserID(modelID), - MXID: state.initialEventID, - Timestamp: time.Now(), - Metadata: meta, - } - if err := oc.UserLogin.Bridge.DB.Message.Insert(ctx, assistantMsg); err != nil { - log.Warn().Err(err).Msg("Failed to insert assistant message to database") - } else { - log.Debug().Str("msg_id", string(assistantMsg.ID)).Msg("Inserted assistant message to database") - } -} + func thinkingTokenCount(model string, content string) int { content = strings.TrimSpace(content) From bcad817cd20594a6d968cd4b5f1cce3e913675f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:37:51 +0100 Subject: [PATCH 16/23] Dereference logger and remove unused imports Pass a zerolog.Logger value to the struct by dereferencing log (Logger: *log) and remove unused imports (database, networkid) from streaming_persistence.go. These changes fix type/import issues and clean up compile-time warnings. --- bridges/codex/client.go | 2 +- pkg/connector/streaming_persistence.go | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/bridges/codex/client.go b/bridges/codex/client.go index cc5614fd..e8c836cc 100644 --- a/bridges/codex/client.go +++ b/bridges/codex/client.go @@ -2091,7 +2091,7 @@ func (cc *CodexClient) saveAssistantMessage(ctx context.Context, portal *bridgev NetworkMessageID: state.networkMessageID, InitialEventID: state.initialEventID, Metadata: fullMeta, - Logger: log, + Logger: *log, }) } diff --git a/pkg/connector/streaming_persistence.go b/pkg/connector/streaming_persistence.go index 6d385d22..31d36248 100644 --- a/pkg/connector/streaming_persistence.go +++ b/pkg/connector/streaming_persistence.go @@ -7,8 +7,6 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/bridgev2/networkid" "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) From 69e60f1f70fbd8ca16638126afc676c462001c20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:40:22 +0100 Subject: [PATCH 17/23] Embed BaseMessageMetadata and guard nil bridge Add a nil check for p.Login.Bridge in SendViaPortal to avoid bridge-unavailable panics. Update tests and usages to instantiate MessageMetadata via bridgeadapter.BaseMessageMetadata (add import) so common fields are embedded consistently. Also make minor struct field formatting adjustments and remove stray blank lines in streaming_persistence.go. --- pkg/bridgeadapter/helpers.go | 2 +- pkg/bridgeadapter/message_metadata.go | 28 +++++------ pkg/connector/canonical_history_test.go | 48 +++++++++++-------- pkg/connector/metadata.go | 14 +++--- .../session_transcript_openclaw_test.go | 48 ++++++++++--------- pkg/connector/streaming_persistence.go | 2 - 6 files changed, 75 insertions(+), 67 deletions(-) diff --git a/pkg/bridgeadapter/helpers.go b/pkg/bridgeadapter/helpers.go index 9a3cfea8..19a611d1 100644 --- a/pkg/bridgeadapter/helpers.go +++ b/pkg/bridgeadapter/helpers.go @@ -158,7 +158,7 @@ func SendViaPortal(p SendViaPortalParams) (id.EventID, networkid.MessageID, erro if p.Portal == nil || p.Portal.MXID == "" { return "", "", fmt.Errorf("invalid portal") } - if p.Login == nil { + if p.Login == nil || p.Login.Bridge == nil { return "", p.MsgID, fmt.Errorf("bridge unavailable") } if p.MsgID == "" { diff --git a/pkg/bridgeadapter/message_metadata.go b/pkg/bridgeadapter/message_metadata.go index 914c6ac6..da715e25 100644 --- a/pkg/bridgeadapter/message_metadata.go +++ b/pkg/bridgeadapter/message_metadata.go @@ -3,21 +3,21 @@ package bridgeadapter // BaseMessageMetadata contains fields common to all bridge MessageMetadata structs. // Embed this in each bridge's MessageMetadata to share CopyFrom logic. type BaseMessageMetadata struct { - Role string `json:"role,omitempty"` - Body string `json:"body,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - PromptTokens int64 `json:"prompt_tokens,omitempty"` - CompletionTokens int64 `json:"completion_tokens,omitempty"` - ReasoningTokens int64 `json:"reasoning_tokens,omitempty"` - TurnID string `json:"turn_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - CanonicalSchema string `json:"canonical_schema,omitempty"` + Role string `json:"role,omitempty"` + Body string `json:"body,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + PromptTokens int64 `json:"prompt_tokens,omitempty"` + CompletionTokens int64 `json:"completion_tokens,omitempty"` + ReasoningTokens int64 `json:"reasoning_tokens,omitempty"` + TurnID string `json:"turn_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + CanonicalSchema string `json:"canonical_schema,omitempty"` CanonicalUIMessage map[string]any `json:"canonical_ui_message,omitempty"` - StartedAtMs int64 `json:"started_at_ms,omitempty"` - CompletedAtMs int64 `json:"completed_at_ms,omitempty"` - ThinkingContent string `json:"thinking_content,omitempty"` - ToolCalls []ToolCallMetadata `json:"tool_calls,omitempty"` - GeneratedFiles []GeneratedFileRef `json:"generated_files,omitempty"` + StartedAtMs int64 `json:"started_at_ms,omitempty"` + CompletedAtMs int64 `json:"completed_at_ms,omitempty"` + ThinkingContent string `json:"thinking_content,omitempty"` + ToolCalls []ToolCallMetadata `json:"tool_calls,omitempty"` + GeneratedFiles []GeneratedFileRef `json:"generated_files,omitempty"` } // CopyFromBase copies non-zero common fields from src into the receiver. diff --git a/pkg/connector/canonical_history_test.go b/pkg/connector/canonical_history_test.go index 224f6d86..1a70cf67 100644 --- a/pkg/connector/canonical_history_test.go +++ b/pkg/connector/canonical_history_test.go @@ -3,19 +3,23 @@ package connector import ( "context" "testing" + + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) func TestHistoryMessageBundle_LegacyAssistantFallback(t *testing.T) { oc := &AIClient{} bundle := oc.historyMessageBundle(context.Background(), &MessageMetadata{ - Role: "assistant", - Body: "done", - ToolCalls: []ToolCallMetadata{{ - CallID: "call_1", - ToolName: "Read", - Input: map[string]any{"path": "README.md"}, - Output: map[string]any{"result": "ok"}, - }}, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: "assistant", + Body: "done", + ToolCalls: []ToolCallMetadata{{ + CallID: "call_1", + ToolName: "Read", + Input: map[string]any{"path": "README.md"}, + Output: map[string]any{"result": "ok"}, + }}, + }, }, false) if len(bundle) != 2 { @@ -35,19 +39,21 @@ func TestHistoryMessageBundle_LegacyAssistantFallback(t *testing.T) { func TestHistoryMessageBundle_CanonicalPartsSupportsMapSlices(t *testing.T) { oc := &AIClient{} bundle := oc.historyMessageBundle(context.Background(), &MessageMetadata{ - Role: "assistant", - CanonicalSchema: "ai-sdk-ui-message-v1", - CanonicalUIMessage: map[string]any{ - "role": "assistant", - "parts": []map[string]any{ - {"type": "text", "text": "hello"}, - { - "type": "dynamic-tool", - "toolCallId": "call_1", - "toolName": "Read", - "input": map[string]any{"path": "README.md"}, - "state": "output-available", - "output": map[string]any{"result": "ok"}, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: "assistant", + CanonicalSchema: "ai-sdk-ui-message-v1", + CanonicalUIMessage: map[string]any{ + "role": "assistant", + "parts": []map[string]any{ + {"type": "text", "text": "hello"}, + { + "type": "dynamic-tool", + "toolCallId": "call_1", + "toolName": "Read", + "input": map[string]any{"path": "README.md"}, + "state": "output-available", + "output": map[string]any{"result": "ok"}, + }, }, }, }, diff --git a/pkg/connector/metadata.go b/pkg/connector/metadata.go index 2f1e0f33..5aee49c5 100644 --- a/pkg/connector/metadata.go +++ b/pkg/connector/metadata.go @@ -272,13 +272,13 @@ func clonePortalMetadata(src *PortalMetadata) *PortalMetadata { type MessageMetadata struct { bridgeadapter.BaseMessageMetadata - CompletionID string `json:"completion_id,omitempty"` - Model string `json:"model,omitempty"` - HasToolCalls bool `json:"has_tool_calls,omitempty"` - Transcript string `json:"transcript,omitempty"` - FirstTokenAtMs int64 `json:"first_token_at_ms,omitempty"` - ThinkingTokenCount int `json:"thinking_token_count,omitempty"` - ExcludeFromHistory bool `json:"exclude_from_history,omitempty"` + CompletionID string `json:"completion_id,omitempty"` + Model string `json:"model,omitempty"` + HasToolCalls bool `json:"has_tool_calls,omitempty"` + Transcript string `json:"transcript,omitempty"` + FirstTokenAtMs int64 `json:"first_token_at_ms,omitempty"` + ThinkingTokenCount int `json:"thinking_token_count,omitempty"` + ExcludeFromHistory bool `json:"exclude_from_history,omitempty"` // Media understanding (OpenClaw-style) MediaUnderstanding []MediaUnderstandingOutput `json:"media_understanding,omitempty"` diff --git a/pkg/connector/session_transcript_openclaw_test.go b/pkg/connector/session_transcript_openclaw_test.go index 86601019..2ea6d400 100644 --- a/pkg/connector/session_transcript_openclaw_test.go +++ b/pkg/connector/session_transcript_openclaw_test.go @@ -7,6 +7,8 @@ import ( "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/id" + + "github.com/beeper/ai-bridge/pkg/bridgeadapter" ) func TestStripOpenClawToolResults(t *testing.T) { @@ -101,30 +103,32 @@ func TestBuildOpenClawSessionMessagesFromCanonical(t *testing.T) { MXID: id.EventID("$assistant1"), Timestamp: time.UnixMilli(1730000000000), Metadata: &MessageMetadata{ - Role: "assistant", - CanonicalSchema: "ai-sdk-ui-message-v1", - CanonicalUIMessage: map[string]any{ - "parts": []any{ - map[string]any{"type": "text", "text": "hello"}, - map[string]any{ - "type": "dynamic-tool", - "toolCallId": "call_1", - "toolName": "web_search", - "input": map[string]any{"q": "matrix"}, - "state": "output-available", - "output": map[string]any{"result": "ok"}, + BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{ + Role: "assistant", + CanonicalSchema: "ai-sdk-ui-message-v1", + CanonicalUIMessage: map[string]any{ + "parts": []any{ + map[string]any{"type": "text", "text": "hello"}, + map[string]any{ + "type": "dynamic-tool", + "toolCallId": "call_1", + "toolName": "web_search", + "input": map[string]any{"q": "matrix"}, + "state": "output-available", + "output": map[string]any{"result": "ok"}, + }, }, }, - }, - ToolCalls: []ToolCallMetadata{ - { - CallID: "call_1", - ToolName: "web_search", - Input: map[string]any{"q": "matrix"}, - Output: map[string]any{"result": "ok"}, - ResultStatus: "success", - CallEventID: "$toolcall1", - ResultEventID: "$toolresult1", + ToolCalls: []ToolCallMetadata{ + { + CallID: "call_1", + ToolName: "web_search", + Input: map[string]any{"q": "matrix"}, + Output: map[string]any{"result": "ok"}, + ResultStatus: "success", + CallEventID: "$toolcall1", + ResultEventID: "$toolresult1", + }, }, }, }, diff --git a/pkg/connector/streaming_persistence.go b/pkg/connector/streaming_persistence.go index 31d36248..c9d8a61d 100644 --- a/pkg/connector/streaming_persistence.go +++ b/pkg/connector/streaming_persistence.go @@ -85,8 +85,6 @@ func (oc *AIClient) saveAssistantMessage( oc.notifySessionMutation(ctx, portal, meta, false) } - - func thinkingTokenCount(model string, content string) int { content = strings.TrimSpace(content) if content == "" { From 7acbcf6bba3430ce96c06cb36137ac830aaffafe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 19:47:34 +0100 Subject: [PATCH 18/23] Fix various bugs and improve robustness Multiple fixes across the codebase to improve correctness and nil-safety: adjust cancel/connected ordering in OpenCode timer handling; update ApplyPatch tool title; ensure mustJSON returns a proper JSON-escaped error string; bound NextUserLoginID loop and add a safe fallback; initialize timestamps on RemoteMessage/RemoteEdit when zero and fix in-place assignment of PreBuilt.ModifiedParts; guard addSummary nil call in error logging; validate empty description in media understanding fallback; restrict heartbeat session portal matches by agent metadata; only apply default heartbeat visibility when present; rewrite scoped row deletion to collect ids, check rows.Err, and perform deletes safely; include accumulated and reasoning parts in streaming UI messages; panic on unsupported subagent config types and add missing import; and only include mediaType/title in citation parts when non-empty. These changes prevent panics, ensure correct data, and improve logging and UI behavior. --- .../opencodebridge/opencode_instance_state.go | 2 +- pkg/agents/tools/apply_patch.go | 2 +- pkg/agents/tools/results.go | 3 ++- pkg/bridgeadapter/identifier_helpers.go | 5 ++++- pkg/bridgeadapter/remote_events.go | 10 +++++----- pkg/connector/error_logging.go | 4 +++- pkg/connector/handlematrix.go | 3 +++ pkg/connector/heartbeat_execute.go | 4 +++- pkg/connector/heartbeat_visibility.go | 4 +++- pkg/connector/scheduler_db.go | 15 +++++++++++---- pkg/connector/streaming_ui_helpers.go | 8 ++++++++ pkg/connector/subagent_conversion.go | 3 ++- pkg/shared/citations/citations.go | 12 ++++++++---- 13 files changed, 54 insertions(+), 21 deletions(-) diff --git a/bridges/opencode/opencodebridge/opencode_instance_state.go b/bridges/opencode/opencodebridge/opencode_instance_state.go index 52b59caf..4e14a375 100644 --- a/bridges/opencode/opencodebridge/opencode_instance_state.go +++ b/bridges/opencode/opencodebridge/opencode_instance_state.go @@ -80,8 +80,8 @@ func (inst *openCodeInstance) cancelAndStopTimer() { inst.cancel() } inst.cancel = nil - inst.connected = false inst.disconnectMu.Lock() + inst.connected = false if inst.disconnectTimer != nil { inst.disconnectTimer.Stop() inst.disconnectTimer = nil diff --git a/pkg/agents/tools/apply_patch.go b/pkg/agents/tools/apply_patch.go index 2e81caea..c03c631a 100644 --- a/pkg/agents/tools/apply_patch.go +++ b/pkg/agents/tools/apply_patch.go @@ -5,6 +5,6 @@ import "github.com/beeper/ai-bridge/pkg/shared/toolspec" var ApplyPatchTool = newUnavailableBuiltinTool(unavailableBuiltinToolSpec{ name: toolspec.ApplyPatchName, description: toolspec.ApplyPatchDescription, - title: "apply_patch", + title: "Apply Patch", inputSchema: toolspec.ApplyPatchSchema(), }) diff --git a/pkg/agents/tools/results.go b/pkg/agents/tools/results.go index 5430aaed..6cf4f460 100644 --- a/pkg/agents/tools/results.go +++ b/pkg/agents/tools/results.go @@ -32,7 +32,8 @@ func ErrorResult(toolName, message string) *Result { func mustJSON(v any) string { data, err := json.Marshal(v) if err != nil { - return fmt.Sprintf(`{"error":"failed to marshal: %s"}`, err) + errMsg, _ := json.Marshal(fmt.Sprintf("failed to marshal: %s", err)) + return fmt.Sprintf(`{"error":%s}`, errMsg) } return string(data) } diff --git a/pkg/bridgeadapter/identifier_helpers.go b/pkg/bridgeadapter/identifier_helpers.go index 7b142808..6091d84a 100644 --- a/pkg/bridgeadapter/identifier_helpers.go +++ b/pkg/bridgeadapter/identifier_helpers.go @@ -42,12 +42,15 @@ func NextUserLoginID(user *bridgev2.User, prefix string) networkid.UserLoginID { } used[string(existing.ID)] = struct{}{} } - for ordinal := 1; ; ordinal++ { + for ordinal := 1; ordinal <= len(used)+1; ordinal++ { loginID := MakeUserLoginID(prefix, user.MXID, ordinal) if _, ok := used[string(loginID)]; !ok { return loginID } } + // Should be unreachable: there are at most len(used) occupied ordinals, + // so ordinal len(used)+1 must be free. Fall back to a safe default. + return MakeUserLoginID(prefix, user.MXID, len(used)+1) } func SingleLoginFlow(enabled bool, flow bridgev2.LoginFlow) []bridgev2.LoginFlow { diff --git a/pkg/bridgeadapter/remote_events.go b/pkg/bridgeadapter/remote_events.go index e1ff1c28..b8539ba7 100644 --- a/pkg/bridgeadapter/remote_events.go +++ b/pkg/bridgeadapter/remote_events.go @@ -59,7 +59,7 @@ func (m *RemoteMessage) GetID() networkid.MessageID { func (m *RemoteMessage) GetTimestamp() time.Time { if m.Timestamp.IsZero() { - return time.Now() + m.Timestamp = time.Now() } return m.Timestamp } @@ -116,7 +116,7 @@ func (e *RemoteEdit) GetTargetMessage() networkid.MessageID { func (e *RemoteEdit) GetTimestamp() time.Time { if e.Timestamp.IsZero() { - return time.Now() + e.Timestamp = time.Now() } return e.Timestamp } @@ -127,9 +127,9 @@ func (e *RemoteEdit) GetStreamOrder() int64 { func (e *RemoteEdit) ConvertEdit(_ context.Context, _ *bridgev2.Portal, _ bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { if e.PreBuilt != nil && len(existing) > 0 { - for i, part := range e.PreBuilt.ModifiedParts { - if part.Part == nil && i < len(existing) { - part.Part = existing[i] + for i := range e.PreBuilt.ModifiedParts { + if e.PreBuilt.ModifiedParts[i].Part == nil && i < len(existing) { + e.PreBuilt.ModifiedParts[i].Part = existing[i] } } } diff --git a/pkg/connector/error_logging.go b/pkg/connector/error_logging.go index ee0e6e8d..dd1191c6 100644 --- a/pkg/connector/error_logging.go +++ b/pkg/connector/error_logging.go @@ -32,7 +32,9 @@ func logProviderFailure( ) { event := log.Error().Err(err).Str("stage", stage) addRequestSummary(event, meta, messages) - addSummary(event) + if addSummary != nil { + addSummary(event) + } addOpenAIErrorFields(event, err) event.Msg(msg) } diff --git a/pkg/connector/handlematrix.go b/pkg/connector/handlematrix.go index f6ade451..39b13999 100644 --- a/pkg/connector/handlematrix.go +++ b/pkg/connector/handlematrix.go @@ -922,6 +922,9 @@ func (oc *AIClient) dispatchMediaUnderstandingFallback( oc.loggerForContext(ctx).Warn().Err(err).Msg(failureLog) return nil, messageSendStatusError(err, userError, "") } + if description == "" { + return nil, messageSendStatusError(errors.New(emptyResult), userError, "") + } combined := buildMessage(caption, hasUserCaption, description) if combined == "" { diff --git a/pkg/connector/heartbeat_execute.go b/pkg/connector/heartbeat_execute.go index 14866838..8f65cb11 100644 --- a/pkg/connector/heartbeat_execute.go +++ b/pkg/connector/heartbeat_execute.go @@ -329,7 +329,9 @@ func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *Hea } if strings.HasPrefix(session, "!") { if portal := oc.portalByRoomID(context.Background(), id.RoomID(session)); portal != nil { - return portal, portal.MXID.String(), nil + if meta := portalMeta(portal); meta == nil || normalizeAgentID(meta.AgentID) == normalizeAgentID(agentID) { + return portal, portal.MXID.String(), nil + } } } hbSession := oc.resolveHeartbeatSession(agentID, heartbeat) diff --git a/pkg/connector/heartbeat_visibility.go b/pkg/connector/heartbeat_visibility.go index 9e5274b7..6c1e4186 100644 --- a/pkg/connector/heartbeat_visibility.go +++ b/pkg/connector/heartbeat_visibility.go @@ -29,7 +29,9 @@ func resolveHeartbeatVisibility(cfg *Config, channel string) ResolvedHeartbeatVi UseIndicator: defaultHeartbeatVisibility.UseIndicator, } - applyHeartbeatVisibility(&result, defaults.Heartbeat) + if defaults != nil { + applyHeartbeatVisibility(&result, defaults.Heartbeat) + } if perChannel != nil { applyHeartbeatVisibility(&result, perChannel.Heartbeat) } diff --git a/pkg/connector/scheduler_db.go b/pkg/connector/scheduler_db.go index 97a18405..22dee47d 100644 --- a/pkg/connector/scheduler_db.go +++ b/pkg/connector/scheduler_db.go @@ -513,15 +513,22 @@ func deleteMissingScopedRows(ctx context.Context, scope *schedulerDBScope, keep if err != nil { return err } - defer rows.Close() + var toDelete []string for rows.Next() { var idValue string if err := rows.Scan(&idValue); err != nil { + rows.Close() return err } - if _, ok := keep[strings.TrimSpace(idValue)]; ok { - continue + if _, ok := keep[strings.TrimSpace(idValue)]; !ok { + toDelete = append(toDelete, idValue) } + } + rows.Close() + if err := rows.Err(); err != nil { + return err + } + for _, idValue := range toDelete { if _, err := scope.db.Exec(ctx, fmt.Sprintf( `DELETE FROM %s WHERE bridge_id=$1 AND login_id=$2 AND %s=$3`, entityTable, idColumn, @@ -535,5 +542,5 @@ func deleteMissingScopedRows(ctx context.Context, scope *schedulerDBScope, keep return err } } - return rows.Err() + return nil } diff --git a/pkg/connector/streaming_ui_helpers.go b/pkg/connector/streaming_ui_helpers.go index 4d4b67c7..3f1d1b56 100644 --- a/pkg/connector/streaming_ui_helpers.go +++ b/pkg/connector/streaming_ui_helpers.go @@ -42,9 +42,17 @@ func (oc *AIClient) buildStreamUIMessage(state *streamingState, meta *PortalMeta uiMessage["metadata"] = msgconv.MergeUIMessageMetadata(metadata, oc.buildUIMessageMetadata(state, meta, true)) return msgconv.AppendUIMessageArtifacts(uiMessage, sourceParts, fileParts) } + var parts []map[string]any + if text := state.accumulated.String(); text != "" { + parts = append(parts, map[string]any{"type": "text", "text": text}) + } + if reasoning := state.reasoning.String(); reasoning != "" { + parts = append(parts, map[string]any{"type": "reasoning", "reasoning": reasoning}) + } return msgconv.BuildUIMessage(msgconv.UIMessageParams{ TurnID: state.turnID, Role: "assistant", + Parts: parts, Metadata: oc.buildUIMessageMetadata(state, meta, true), SourceURLs: sourceParts, FileParts: fileParts, diff --git a/pkg/connector/subagent_conversion.go b/pkg/connector/subagent_conversion.go index 035e8803..a06baa36 100644 --- a/pkg/connector/subagent_conversion.go +++ b/pkg/connector/subagent_conversion.go @@ -1,6 +1,7 @@ package connector import ( + "fmt" "slices" "github.com/beeper/ai-bridge/pkg/agents" @@ -48,6 +49,6 @@ func convertSubagentConfig[T subagentConfigLike, R any](cfg T, build func(string } return build(typed.Model, typed.Thinking, allowAgents) default: - return nil + panic(fmt.Sprintf("unsupported subagent config type: %T", cfg)) } } diff --git a/pkg/shared/citations/citations.go b/pkg/shared/citations/citations.go index fe775af5..43bf9f4f 100644 --- a/pkg/shared/citations/citations.go +++ b/pkg/shared/citations/citations.go @@ -179,10 +179,14 @@ func AppendSourceDocumentPart(parts *[]map[string]any, seen map[string]struct{}, } seen[seenKey] = struct{}{} part := map[string]any{ - "type": "source-document", - "sourceId": fmt.Sprintf("source-%d", len(*parts)+1), - "mediaType": doc.MediaType, - "title": doc.Title, + "type": "source-document", + "sourceId": fmt.Sprintf("source-%d", len(*parts)+1), + } + if mediaType := strings.TrimSpace(doc.MediaType); mediaType != "" { + part["mediaType"] = mediaType + } + if title := strings.TrimSpace(doc.Title); title != "" { + part["title"] = title } if filename := strings.TrimSpace(doc.Filename); filename != "" { part["filename"] = filename From 6baaf0f56770dc26dbfa8a3af5d57e31632838b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 21:20:01 +0100 Subject: [PATCH 19/23] Fix token counting, metadata copy, and misc bugs Multiple fixes across bridges and connector logic: - bridges/opencode: backfillTotalTokens now includes Cache.Read/Write when present to correctly account for cached tokens. - bridges/opencode: EnsureGhostDisplayName only updates when necessary and ensures the ghost user is marked as a bot. - pkg/bridgeadapter/helpers: UpsertAssistantMessage separates DB errors by lookup method (by ID vs MXID) and logs both errors; avoids treating a nil existing row as an error when updating metadata. - pkg/bridgeadapter/message_metadata: CopyFromBase now deep-copies maps and slices (CanonicalUIMessage, ToolCalls, GeneratedFiles) to avoid sharing underlying references. - pkg/connector/chat: ensureExistingChatPortalReady returns immediately on portal creation error and returns nil on success. - pkg/connector/response_finalization: sendFinalAssistantTurnContent chooses an edit target from network message or initial event, and skips/logs when no target exists instead of queuing a nil edit. - pkg/connector/subagent_spawn: simplify subagent config resolution by adding subagentModel and subagentThinking helpers and using firstNonEmptyTrimmed, removing the previous generic reflection-like resolver. These changes address correctness (token accounting, bot flag), safety (avoid shared mutable state), clearer error handling/logging, and more robust edit and subagent resolution logic. --- .../opencodebridge/backfill_canonical.go | 6 ++- .../opencode/opencodebridge/opencode_ghost.go | 3 +- pkg/bridgeadapter/helpers.go | 11 ++--- pkg/bridgeadapter/message_metadata.go | 11 +++-- pkg/connector/chat.go | 3 +- pkg/connector/response_finalization.go | 24 +++++++---- pkg/connector/subagent_spawn.go | 42 ++++++++----------- 7 files changed, 57 insertions(+), 43 deletions(-) diff --git a/bridges/opencode/opencodebridge/backfill_canonical.go b/bridges/opencode/opencodebridge/backfill_canonical.go index 9b2eefe3..c7d621a0 100644 --- a/bridges/opencode/opencodebridge/backfill_canonical.go +++ b/bridges/opencode/opencodebridge/backfill_canonical.go @@ -288,7 +288,11 @@ func backfillTokenValue(msg opencode.MessageWithParts, pick func(opencode.TokenU } func backfillTotalTokens(msg opencode.MessageWithParts) int64 { - return backfillPromptTokens(msg) + backfillCompletionTokens(msg) + backfillReasoningTokens(msg) + total := backfillPromptTokens(msg) + backfillCompletionTokens(msg) + backfillReasoningTokens(msg) + if msg.Info.Tokens != nil && msg.Info.Tokens.Cache != nil { + total += int64(msg.Info.Tokens.Cache.Read + msg.Info.Tokens.Cache.Write) + } + return total } func buildCanonicalBackfillPart(snapshot canonicalBackfillSnapshot) *event.MessageEventContent { diff --git a/bridges/opencode/opencodebridge/opencode_ghost.go b/bridges/opencode/opencodebridge/opencode_ghost.go index a8fc66c3..1752c4da 100644 --- a/bridges/opencode/opencodebridge/opencode_ghost.go +++ b/bridges/opencode/opencodebridge/opencode_ghost.go @@ -20,7 +20,8 @@ func (b *Bridge) EnsureGhostDisplayName(ctx context.Context, instanceID string) return } displayName := b.DisplayName(instanceID) - if ghost.Name == "" || !ghost.NameSet || ghost.Name != displayName { + needsUpdate := ghost.Name == "" || !ghost.NameSet || ghost.Name != displayName || !ghost.IsBot + if needsUpdate { ghost.UpdateInfo(ctx, &bridgev2.UserInfo{ Name: ptr.Ptr(displayName), IsBot: ptr.Ptr(true), diff --git a/pkg/bridgeadapter/helpers.go b/pkg/bridgeadapter/helpers.go index 19a611d1..bf9b1506 100644 --- a/pkg/bridgeadapter/helpers.go +++ b/pkg/bridgeadapter/helpers.go @@ -223,14 +223,14 @@ func UpsertAssistantMessage(ctx context.Context, p UpsertAssistantMessageParams) receiver = p.Login.ID } var existing *database.Message - var err error + var errByID, errByMXID error if receiver != "" { - existing, err = db.GetPartByID(ctx, receiver, p.NetworkMessageID, networkid.PartID("0")) + existing, errByID = db.GetPartByID(ctx, receiver, p.NetworkMessageID, networkid.PartID("0")) } if existing == nil && p.InitialEventID != "" { - existing, err = db.GetPartByMXID(ctx, p.InitialEventID) + existing, errByMXID = db.GetPartByMXID(ctx, p.InitialEventID) } - if err == nil && existing != nil { + if existing != nil { existing.Metadata = p.Metadata if err := db.Update(ctx, existing); err != nil { p.Logger.Warn().Err(err).Str("msg_id", string(existing.ID)).Msg("Failed to update assistant message metadata") @@ -240,7 +240,8 @@ func UpsertAssistantMessage(ctx context.Context, p UpsertAssistantMessageParams) return } p.Logger.Warn(). - Err(err). + AnErr("err_by_id", errByID). + AnErr("err_by_mxid", errByMXID). Stringer("mxid", p.InitialEventID). Str("msg_id", string(p.NetworkMessageID)). Msg("Could not find existing DB row for update, falling back to insert") diff --git a/pkg/bridgeadapter/message_metadata.go b/pkg/bridgeadapter/message_metadata.go index da715e25..71bf0269 100644 --- a/pkg/bridgeadapter/message_metadata.go +++ b/pkg/bridgeadapter/message_metadata.go @@ -53,7 +53,10 @@ func (b *BaseMessageMetadata) CopyFromBase(src *BaseMessageMetadata) { b.CanonicalSchema = src.CanonicalSchema } if len(src.CanonicalUIMessage) > 0 { - b.CanonicalUIMessage = src.CanonicalUIMessage + b.CanonicalUIMessage = make(map[string]any, len(src.CanonicalUIMessage)) + for k, v := range src.CanonicalUIMessage { + b.CanonicalUIMessage[k] = v + } } if src.StartedAtMs != 0 { b.StartedAtMs = src.StartedAtMs @@ -65,10 +68,12 @@ func (b *BaseMessageMetadata) CopyFromBase(src *BaseMessageMetadata) { b.ThinkingContent = src.ThinkingContent } if len(src.ToolCalls) > 0 { - b.ToolCalls = src.ToolCalls + b.ToolCalls = make([]ToolCallMetadata, len(src.ToolCalls)) + copy(b.ToolCalls, src.ToolCalls) } if len(src.GeneratedFiles) > 0 { - b.GeneratedFiles = src.GeneratedFiles + b.GeneratedFiles = make([]GeneratedFileRef, len(src.GeneratedFiles)) + copy(b.GeneratedFiles, src.GeneratedFiles) } } diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index 1a04d232..36481c0a 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -1983,9 +1983,10 @@ func (oc *AIClient) ensureExistingChatPortalReady(ctx context.Context, loginMeta err := portal.CreateMatrixRoom(ctx, oc.UserLogin, info) if err != nil { oc.loggerForContext(ctx).Err(err).Msg(errMsg) + return err } oc.sendWelcomeMessage(ctx, portal) - return err + return nil } func (oc *AIClient) listAllChatPortals(ctx context.Context) ([]*bridgev2.Portal, error) { diff --git a/pkg/connector/response_finalization.go b/pkg/connector/response_finalization.go index 04f05b6c..4e784033 100644 --- a/pkg/connector/response_finalization.go +++ b/pkg/connector/response_finalization.go @@ -667,13 +667,23 @@ func (oc *AIClient) sendFinalAssistantTurnContent(ctx context.Context, portal *b TopLevelExtra: topLevelExtra, }}, } - oc.UserLogin.QueueRemoteEvent(&bridgeadapter.RemoteEdit{ - Portal: portal.PortalKey, - Sender: sender, - TargetMessage: state.networkMessageID, - LogKey: "ai_edit_target", - PreBuilt: editContent, - }) + editTarget := state.networkMessageID + if editTarget == "" { + editTarget = bridgeadapter.MatrixMessageID(state.initialEventID) + } + if editTarget == "" { + oc.loggerForContext(ctx).Warn(). + Str("turn_id", state.turnID). + Msg("Skipping final assistant edit: no network or initial event target") + } else { + oc.UserLogin.QueueRemoteEvent(&bridgeadapter.RemoteEdit{ + Portal: portal.PortalKey, + Sender: sender, + TargetMessage: editTarget, + LogKey: "ai_edit_target", + PreBuilt: editContent, + }) + } oc.recordAgentActivity(ctx, portal, meta) oc.loggerForContext(ctx).Debug(). Str("initial_event_id", state.initialEventID.String()). diff --git a/pkg/connector/subagent_spawn.go b/pkg/connector/subagent_spawn.go index 0825d34c..f37f6962 100644 --- a/pkg/connector/subagent_spawn.go +++ b/pkg/connector/subagent_spawn.go @@ -53,32 +53,24 @@ func (oc *AIClient) resolveSubagentAllowlist(ctx context.Context, requesterAgent return allowAny, allowSet } -func resolveSubagentConfigValue(override string, agent *agents.AgentDefinition, defaults *agents.SubagentConfig, field string) string { - return firstNonEmptyTrimmed(override, subagentStringValue(agent, field), subagentStringValue(defaults, field)) +func subagentModel(agent *agents.AgentDefinition, defaults *agents.SubagentConfig) string { + if agent != nil && agent.Subagents != nil && agent.Subagents.Model != "" { + return agent.Subagents.Model + } + if defaults != nil && defaults.Model != "" { + return defaults.Model + } + return "" } -func subagentStringValue(source any, field string) string { - switch cfg := source.(type) { - case *agents.AgentDefinition: - if cfg == nil { - return "" - } - return subagentStringValue(cfg.Subagents, field) - case *agents.SubagentConfig: - if cfg == nil { - return "" - } - switch field { - case "model": - return cfg.Model - case "thinking": - return cfg.Thinking - default: - return "" - } - default: - return "" +func subagentThinking(agent *agents.AgentDefinition, defaults *agents.SubagentConfig) string { + if agent != nil && agent.Subagents != nil && agent.Subagents.Thinking != "" { + return agent.Subagents.Thinking } + if defaults != nil && defaults.Thinking != "" { + return defaults.Thinking + } + return "" } func firstNonEmptyTrimmed(values ...string) string { @@ -255,7 +247,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P if oc.connector != nil && oc.connector.Config.Agents != nil && oc.connector.Config.Agents.Defaults != nil { defaultSubagents = oc.connector.Config.Agents.Defaults.Subagents } - thinkingCandidate := resolveSubagentConfigValue(thinkingOverride, targetAgent, defaultSubagents, "thinking") + thinkingCandidate := firstNonEmptyTrimmed(thinkingOverride, subagentThinking(targetAgent, defaultSubagents)) thinkingLevel, ok := normalizeThinkingLevel(thinkingCandidate) if !ok { return tools.JSONResult(map[string]any{ @@ -265,7 +257,7 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P } reasoningEffort := mapThinkingToReasoningEffort(thinkingLevel) - modelCandidate := resolveSubagentConfigValue(modelOverride, targetAgent, defaultSubagents, "model") + modelCandidate := firstNonEmptyTrimmed(modelOverride, subagentModel(targetAgent, defaultSubagents)) resolvedModel := "" modelWarning := "" From 79aa2fc98efb67a2a1882b354dd47fe28edf4ef4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 21:48:45 +0100 Subject: [PATCH 20/23] Simplify AI room config and provisioning --- pkg/connector/agent_activity.go | 2 +- pkg/connector/agent_display.go | 3 - pkg/connector/agentstore.go | 33 +- pkg/connector/chat.go | 320 +------- pkg/connector/chat_fork_test.go | 21 +- pkg/connector/client.go | 211 ++--- pkg/connector/client_capabilities_test.go | 4 +- pkg/connector/command_registry.go | 17 + pkg/connector/connector.go | 196 +---- pkg/connector/defaults_alignment_test.go | 67 +- pkg/connector/error_logging.go | 4 +- pkg/connector/events.go | 81 +- pkg/connector/handleai.go | 32 +- pkg/connector/heartbeat_delivery.go | 2 +- pkg/connector/heartbeat_execute.go | 4 +- pkg/connector/history_limit_test.go | 6 +- pkg/connector/identifiers.go | 74 +- pkg/connector/inbound_directive_apply.go | 28 +- pkg/connector/inbound_prompt_runtime_test.go | 2 +- pkg/connector/integration_host.go | 12 - pkg/connector/integrations.go | 3 - pkg/connector/metadata.go | 188 +++-- pkg/connector/model_fallback.go | 16 +- pkg/connector/prompt_params.go | 3 - pkg/connector/provisioning.go | 720 ++++++++++++++++-- pkg/connector/provisioning_test.go | 126 +++ pkg/connector/response_finalization_test.go | 2 +- .../room_settings_event_content_test.go | 34 - pkg/connector/scheduler_cron.go | 12 +- pkg/connector/scheduler_rooms.go | 10 +- pkg/connector/session_greeting_test.go | 2 +- pkg/connector/sessions_tools.go | 8 +- pkg/connector/sessions_visibility_test.go | 1 - pkg/connector/simple_mode_prompt.go | 13 +- pkg/connector/simple_mode_prompt_test.go | 18 +- pkg/connector/status_text.go | 42 +- pkg/connector/streaming_chat_completions.go | 2 +- pkg/connector/streaming_continuation.go | 2 +- pkg/connector/streaming_init_test.go | 8 +- pkg/connector/streaming_params.go | 2 +- pkg/connector/streaming_tool_selection.go | 2 +- .../streaming_tool_selection_test.go | 6 +- pkg/connector/system_prompts.go | 14 +- pkg/connector/target_test_helpers_test.go | 21 + pkg/connector/tool_descriptions.go | 2 +- pkg/connector/tool_policy.go | 8 +- pkg/connector/tools.go | 89 +-- pkg/connector/tools_message_actions.go | 6 +- pkg/connector/trace.go | 26 +- 49 files changed, 1289 insertions(+), 1216 deletions(-) create mode 100644 pkg/connector/provisioning_test.go delete mode 100644 pkg/connector/room_settings_event_content_test.go create mode 100644 pkg/connector/target_test_helpers_test.go diff --git a/pkg/connector/agent_activity.go b/pkg/connector/agent_activity.go index 058408ec..a98543cd 100644 --- a/pkg/connector/agent_activity.go +++ b/pkg/connector/agent_activity.go @@ -74,7 +74,7 @@ func (oc *AIClient) lastActivePortal(agentID string) *bridgev2.Portal { portal := oc.portalByRoomID(context.Background(), id.RoomID(room)) // Guard against stale mappings when a room's agent assignment changes. if portal != nil { - if meta := portalMeta(portal); meta != nil && normalizeAgentID(meta.AgentID) != normalizeAgentID(agentID) { + if meta := portalMeta(portal); meta != nil && normalizeAgentID(resolveAgentID(meta)) != normalizeAgentID(agentID) { return nil } } diff --git a/pkg/connector/agent_display.go b/pkg/connector/agent_display.go index f5bb981a..8bcf3ed9 100644 --- a/pkg/connector/agent_display.go +++ b/pkg/connector/agent_display.go @@ -53,9 +53,6 @@ func (oc *AIClient) agentDefaultModel(agent *agents.AgentDefinition) string { if agent == nil { return oc.effectiveModel(nil) } - if override := oc.agentModelOverride(agent.ID); override != "" { - return ResolveAlias(override) - } if agent.Model.Primary != "" { return ResolveAlias(agent.Model.Primary) } diff --git a/pkg/connector/agentstore.go b/pkg/connector/agentstore.go index 20fd8611..3a96d667 100644 --- a/pkg/connector/agentstore.go +++ b/pkg/connector/agentstore.go @@ -522,13 +522,12 @@ func (b *BossStoreAdapter) CreateRoom(ctx context.Context, room tools.RoomData) return "", fmt.Errorf("failed to get created portal: %w", err) } - // Apply custom name and system prompt if provided + // Apply custom room name if provided. pm := portalMeta(portal) originalName := portal.Name originalNameSet := portal.NameSet originalTitle := pm.Title originalTitleGenerated := pm.TitleGenerated - originalSystemPrompt := pm.SystemPrompt if room.Name != "" { pm.Title = room.Name @@ -538,12 +537,6 @@ func (b *BossStoreAdapter) CreateRoom(ctx context.Context, room tools.RoomData) resp.PortalInfo.Name = &room.Name } } - if room.SystemPrompt != "" { - pm.SystemPrompt = room.SystemPrompt - // Note: portal.Topic is NOT set to SystemPrompt - they are separate concepts - // Topic is for display only, SystemPrompt is for LLM context - } - // Create the Matrix room if err := portal.CreateMatrixRoom(ctx, b.store.client.UserLogin, resp.PortalInfo); err != nil { cleanupPortal(ctx, b.store.client, portal, "failed to create Matrix room") @@ -562,13 +555,6 @@ func (b *BossStoreAdapter) CreateRoom(ctx context.Context, room tools.RoomData) pm.TitleGenerated = originalTitleGenerated } } - if room.SystemPrompt != "" { - if err := b.store.client.setRoomSystemPromptNoSave(ctx, portal, room.SystemPrompt); err != nil { - b.store.client.log.Warn().Err(err).Msg("Failed to set room system prompt") - pm.SystemPrompt = originalSystemPrompt - } - } - if err := portal.Save(ctx); err != nil { return "", fmt.Errorf("failed to save room overrides: %w", err) } @@ -597,29 +583,18 @@ func (b *BossStoreAdapter) ModifyRoom(ctx context.Context, roomID string, update if err != nil { return fmt.Errorf("agent '%s' not found: %w", updates.AgentID, err) } - pm.AgentID = agent.ID - pm.Model = "" - modelID := b.store.client.effectiveModel(pm) - pm.Capabilities = getModelCapabilities(modelID, b.store.client.findModelInfo(modelID)) portal.OtherUserID = agentUserID(agent.ID) + pm.ResolvedTarget = resolveTargetFromGhostID(portal.OtherUserID) + modelID := b.store.client.effectiveModel(pm) agentName := b.store.client.resolveAgentDisplayName(ctx, agent) b.store.client.ensureAgentGhostDisplayName(ctx, agent.ID, modelID, agentName) } - if updates.SystemPrompt != "" { - pm.SystemPrompt = updates.SystemPrompt - // Note: portal.Topic is NOT set to SystemPrompt - they are separate concepts - } if updates.Name != "" && portal.MXID != "" { if err := b.store.client.setRoomName(ctx, portal, updates.Name); err != nil { b.store.client.log.Warn().Err(err).Msg("Failed to set Matrix room name") } } - if updates.SystemPrompt != "" && portal.MXID != "" { - if err := b.store.client.setRoomSystemPrompt(ctx, portal, updates.SystemPrompt); err != nil { - b.store.client.log.Warn().Err(err).Msg("Failed to set room system prompt") - } - } return portal.Save(ctx) } @@ -645,7 +620,7 @@ func (b *BossStoreAdapter) ListRooms(ctx context.Context) ([]tools.RoomData, err rooms = append(rooms, tools.RoomData{ ID: roomID, Name: name, - AgentID: pm.AgentID, + AgentID: resolveAgentID(pm), }) } diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index 36481c0a..4a331588 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -37,17 +37,11 @@ const defaultSimpleModeSystemPrompt = "You are a helpful assistant." var ErrDMGhostImmutable = errors.New("can't change the counterpart ghost in a DM") func hasAssignedAgent(meta *PortalMetadata) bool { - if meta == nil { - return false - } - return meta.AgentID != "" + return resolveAgentID(meta) != "" } func hasBossAgent(meta *PortalMetadata) bool { - if meta == nil { - return false - } - return agents.IsBossAgent(meta.AgentID) + return agents.IsBossAgent(resolveAgentID(meta)) } func dmModelSwitchGuidance(targetModel string) string { @@ -593,21 +587,15 @@ func (oc *AIClient) createAgentChatWithModel(ctx context.Context, agent *agents. // Set agent-specific metadata pm := portalMeta(portal) - pm.AgentID = agent.ID - if agent.SystemPrompt != "" { - pm.SystemPrompt = agent.SystemPrompt - } - if agent.ReasoningEffort != "" { - pm.ReasoningEffort = agent.ReasoningEffort - } - if !applyModelOverride { - pm.Model = "" - } agentGhostID := agentUserID(agent.ID) // Update the OtherUserID to be the agent ghost portal.OtherUserID = agentGhostID + pm.ResolvedTarget = resolveTargetFromGhostID(agentGhostID) + if applyModelOverride { + pm.RuntimeModelOverride = ResolveAlias(modelID) + } agentAvatar := strings.TrimSpace(agent.AvatarURL) if agentAvatar == "" { agentAvatar = strings.TrimSpace(agents.DefaultAgentAvatarMXC) @@ -649,14 +637,6 @@ func (oc *AIClient) createNewChat(ctx context.Context, modelID string) (*bridgev } // Keep simple mode chats non-agentic by default. - meta := portalMeta(portal) - if meta != nil && !meta.IsSimpleMode { - meta.IsSimpleMode = true - if err := portal.Save(ctx); err != nil { - return nil, fmt.Errorf("failed to save portal simple mode: %w", err) - } - } - // Rooms created via provisioning (ResolveIdentifier/CreateDM) won't go through our explicit // post-CreateMatrixRoom call sites. Schedule the welcome notice for when the Matrix room exists. oc.scheduleWelcomeMessage(ctx, portal.PortalKey) @@ -696,20 +676,15 @@ func cloneForkPortalMetadata(src *PortalMetadata, slug, title string) *PortalMet if src == nil { return nil } - return &PortalMetadata{ - Model: src.Model, - Slug: slug, - Title: title, - SystemPrompt: src.SystemPrompt, - Temperature: src.Temperature, - MaxContextMessages: src.MaxContextMessages, - MaxCompletionTokens: src.MaxCompletionTokens, - ReasoningEffort: src.ReasoningEffort, - Capabilities: src.Capabilities, - AgentID: src.AgentID, - AgentPrompt: src.AgentPrompt, - IsSimpleMode: src.IsSimpleMode, + clone := &PortalMetadata{ + Slug: slug, + Title: title, + } + if src.ResolvedTarget != nil { + target := *src.ResolvedTarget + clone.ResolvedTarget = &target } + return clone } // initPortalForChat handles common portal initialization logic. @@ -745,14 +720,10 @@ func (oc *AIClient) initPortalForChat(ctx context.Context, opts PortalInitOpts) var pmeta *PortalMetadata if opts.CopyFrom != nil { pmeta = cloneForkPortalMetadata(opts.CopyFrom, slug, title) - modelID = opts.CopyFrom.Model } else { pmeta = &PortalMetadata{ - Model: modelID, - Slug: slug, - Title: title, - SystemPrompt: opts.SystemPrompt, - Capabilities: getModelCapabilities(modelID, oc.findModelInfo(modelID)), + Slug: slug, + Title: title, } } portal.Metadata = pmeta @@ -921,8 +892,7 @@ func (oc *AIClient) handleNewChat( oc.sendSystemNotice(runCtx, portal, err.Error()) return } - modelOverride := meta != nil && meta.Model != "" - oc.createAndOpenAgentChat(runCtx, portal, agent, modelID, modelOverride) + oc.createAndOpenAgentChat(runCtx, portal, agent, modelID, false) return } @@ -1151,14 +1121,6 @@ func (oc *AIClient) createNewSimpleChat(ctx context.Context, modelID string) (*b } // Simple mode rooms are non-agentic. This disables directive processing. - meta := portalMeta(portal) - if meta != nil && !meta.IsSimpleMode { - meta.IsSimpleMode = true - if err := portal.Save(ctx); err != nil { - return nil, nil, err - } - } - return portal, chatInfo, nil } @@ -1208,13 +1170,11 @@ func (oc *AIClient) composeChatInfo(title, modelID string) *bridgev2.ChatInfo { title = modelName } chatInfo := bridgeadapter.BuildDMChatInfo(bridgeadapter.DMChatInfoParams{ - Title: title, - HumanUserID: humanUserID(oc.UserLogin.ID), - LoginID: oc.UserLogin.ID, - BotUserID: modelUserID(modelID), - BotDisplayName: modelName, - CapabilitiesEvent: RoomCapabilitiesEventType, - SettingsEvent: RoomSettingsEventType, + Title: title, + HumanUserID: humanUserID(oc.UserLogin.ID), + LoginID: oc.UserLogin.ID, + BotUserID: modelUserID(modelID), + BotDisplayName: modelName, }) // Override bot member with model-specific UserInfo and extra fields. chatInfo.Members.MemberMap[modelUserID(modelID)] = modelJoinMember(oc.UserLogin.ID, modelID, modelName, modelInfo) @@ -1272,78 +1232,6 @@ func (oc *AIClient) applyAgentChatInfo(chatInfo *bridgev2.ChatInfo, agentID, age chatInfo.Members = members } -// updatePortalConfig applies room settings to portal metadata with optimistic updates. -// If persistence fails, metadata is rolled back to the previous values. -func (oc *AIClient) updatePortalConfig(ctx context.Context, portal *bridgev2.Portal, config *RoomSettingsEventContent) error { - meta := portalMeta(portal) - before := clonePortalMetadata(meta) - - // Track old model for membership change - oldModel := meta.Model - - if config.Model != "" { - if err := oc.validateDMModelSwitch(portal, meta, config.Model); err != nil { - return dmModelSwitchBlockedError(config.Model) - } - } - - // Update only non-empty/non-zero values - if config.Model != "" { - meta.Model = config.Model - // Update capabilities when model changes - meta.Capabilities = getModelCapabilities(config.Model, oc.findModelInfo(config.Model)) - } - if config.SystemPrompt != "" { - meta.SystemPrompt = config.SystemPrompt - } - if config.Temperature != nil { - meta.Temperature = *config.Temperature - } - if config.MaxContextMessages > 0 { - meta.MaxContextMessages = config.MaxContextMessages - } - if config.MaxCompletionTokens > 0 { - meta.MaxCompletionTokens = config.MaxCompletionTokens - } - if config.ReasoningEffort != "" { - meta.ReasoningEffort = config.ReasoningEffort - } - if config.AgentID != "" { - meta.AgentID = config.AgentID - } - - meta.LastRoomStateSync = time.Now().Unix() - - // Persist changes - if err := portal.Save(ctx); err != nil { - if before != nil { - *meta = *before - } - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to save portal after config update") - return err - } - - // Re-broadcast room state to confirm changes to all clients - if err := oc.BroadcastRoomState(ctx, portal); err != nil { - if before != nil { - *meta = *before - if saveErr := portal.Save(ctx); saveErr != nil { - oc.loggerForContext(ctx).Warn().Err(saveErr).Msg("Failed to save rollback portal metadata after state broadcast failure") - } - } - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to re-broadcast room state after config update") - return err - } - - // Handle model switch - generate membership events if model changed. - // This is done after persistence succeeds so optimistic updates can roll back safely. - if config.Model != "" && oldModel != "" && config.Model != oldModel { - oc.handleModelSwitch(ctx, portal, oldModel, config.Model) - } - - return nil -} - // handleModelSwitch generates membership change events when switching models // This creates leave/join events to show the model transition in the room timeline // For agent rooms, it updates the agent ghost metadata. @@ -1562,170 +1450,8 @@ func (oc *AIClient) ensureSingleAIGhost(ctx context.Context, portal *bridgev2.Po // BroadcastRoomState sends current room capabilities and settings to Matrix room state func (oc *AIClient) BroadcastRoomState(ctx context.Context, portal *bridgev2.Portal) error { - if err := oc.broadcastCapabilities(ctx, portal); err != nil { - return err - } - if err := oc.broadcastSettings(ctx, portal); err != nil { - return err - } - // Broadcast command descriptions so clients can discover slash commands. - oc.BroadcastCommandDescriptions(ctx, portal) - return nil -} - -// buildEffectiveSettings builds the effective settings with source explanations -func (oc *AIClient) buildEffectiveSettings(meta *PortalMetadata) *EffectiveSettings { - loginMeta := loginMetadata(oc.UserLogin) - - return &EffectiveSettings{ - Model: oc.getModelWithSource(meta, loginMeta), - SystemPrompt: oc.getPromptWithSource(meta, loginMeta), - Temperature: oc.getTempWithSource(meta, loginMeta), - ReasoningEffort: oc.getReasoningWithSource(meta, loginMeta), - } -} - -func (oc *AIClient) getModelWithSource(meta *PortalMetadata, loginMeta *UserLoginMetadata) SettingExplanation { - if meta != nil && meta.Model != "" { - return SettingExplanation{Value: meta.Model, Source: SourceRoomOverride} - } - if loginMeta.Defaults != nil && loginMeta.Defaults.Model != "" { - return SettingExplanation{Value: loginMeta.Defaults.Model, Source: SourceUserDefault} - } - return SettingExplanation{Value: oc.defaultModelForProvider(), Source: SourceProviderConfig} -} - -func (oc *AIClient) getPromptWithSource(meta *PortalMetadata, loginMeta *UserLoginMetadata) SettingExplanation { - if meta != nil && meta.SystemPrompt != "" { - return SettingExplanation{Value: meta.SystemPrompt, Source: SourceRoomOverride} - } - if loginMeta.Defaults != nil && loginMeta.Defaults.SystemPrompt != "" { - return SettingExplanation{Value: loginMeta.Defaults.SystemPrompt, Source: SourceUserDefault} - } - if oc.connector.Config.DefaultSystemPrompt != "" { - return SettingExplanation{Value: oc.connector.Config.DefaultSystemPrompt, Source: SourceProviderConfig} - } - return SettingExplanation{Value: "", Source: SourceGlobalDefault} -} - -func (oc *AIClient) getTempWithSource(meta *PortalMetadata, loginMeta *UserLoginMetadata) SettingExplanation { - if meta != nil && meta.Temperature > 0 { - return SettingExplanation{Value: meta.Temperature, Source: SourceRoomOverride} - } - if loginMeta.Defaults != nil && loginMeta.Defaults.Temperature != nil { - return SettingExplanation{Value: *loginMeta.Defaults.Temperature, Source: SourceUserDefault} - } - return SettingExplanation{Value: nil, Source: SourceGlobalDefault, Reason: "provider/model default (unset)"} -} - -func (oc *AIClient) getReasoningWithSource(meta *PortalMetadata, loginMeta *UserLoginMetadata) SettingExplanation { - // Check model support first - if meta != nil && !meta.Capabilities.SupportsReasoning { - return SettingExplanation{Value: nil, Source: SourceModelLimit, Reason: "Model does not support reasoning"} - } - if meta != nil && meta.ReasoningEffort != "" { - return SettingExplanation{Value: meta.ReasoningEffort, Source: SourceRoomOverride} - } - if loginMeta.Defaults != nil && loginMeta.Defaults.ReasoningEffort != "" { - return SettingExplanation{Value: loginMeta.Defaults.ReasoningEffort, Source: SourceUserDefault} - } - if meta != nil && meta.Capabilities.SupportsReasoning { - return SettingExplanation{Value: defaultReasoningEffort, Source: SourceGlobalDefault} - } - return SettingExplanation{Value: "", Source: SourceGlobalDefault} -} - -// broadcastCapabilities sends bridge-controlled capabilities to Matrix room state -// This event is protected by power levels (100) so only the bridge bot can modify -func (oc *AIClient) broadcastCapabilities(ctx context.Context, portal *bridgev2.Portal) error { - if portal.MXID == "" { - return errors.New("portal has no Matrix room ID") - } - - meta := portalMeta(portal) - loginMeta := loginMetadata(oc.UserLogin) - - // Refresh stored model capabilities (room capabilities may add image-understanding union separately) - modelCaps := oc.getModelCapabilitiesForMeta(meta) - if meta.Capabilities != modelCaps { - meta.Capabilities = modelCaps - if err := portal.Save(ctx); err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to save portal after capability refresh") - } - } - - roomCaps := oc.getRoomCapabilities(ctx, meta) - - // Build reasoning effort options if model supports reasoning - var reasoningEfforts []ReasoningEffortOption - if roomCaps.SupportsReasoning { - reasoningEfforts = []ReasoningEffortOption{ - {Value: "low", Label: "Low"}, - {Value: "medium", Label: "Medium"}, - {Value: "high", Label: "High"}, - } - } - - content := &RoomCapabilitiesEventContent{ - Capabilities: &roomCaps, - AvailableTools: oc.buildAvailableTools(meta), - ReasoningEffortOptions: reasoningEfforts, - Provider: loginMeta.Provider, - EffectiveSettings: oc.buildEffectiveSettings(meta), - } - - bot := oc.UserLogin.Bridge.Bot - _, err := bot.SendState(ctx, portal.MXID, RoomCapabilitiesEventType, "", &event.Content{ - Parsed: content, - }, time.Time{}) - - if err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to broadcast room capabilities") - return err - } - - // Also update standard room features for clients portal.UpdateCapabilities(ctx, oc.UserLogin, true) - - oc.loggerForContext(ctx).Debug().Str("model", meta.Model).Msg("Broadcasted room capabilities") - return nil -} - -// broadcastSettings sends user-editable settings to Matrix room state -// This event uses normal power levels (0) so users can modify -func (oc *AIClient) broadcastSettings(ctx context.Context, portal *bridgev2.Portal) error { - if portal.MXID == "" { - return errors.New("portal has no Matrix room ID") - } - - meta := portalMeta(portal) - - content := &RoomSettingsEventContent{ - Model: meta.Model, - SystemPrompt: meta.SystemPrompt, - Temperature: &meta.Temperature, - MaxContextMessages: meta.MaxContextMessages, - MaxCompletionTokens: meta.MaxCompletionTokens, - ReasoningEffort: meta.ReasoningEffort, - AgentID: meta.AgentID, - } - - bot := oc.UserLogin.Bridge.Bot - _, err := bot.SendState(ctx, portal.MXID, RoomSettingsEventType, "", &event.Content{ - Parsed: content, - }, time.Time{}) - - if err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to broadcast room settings") - return err - } - - meta.LastRoomStateSync = time.Now().Unix() - if err := portal.Save(ctx); err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to save portal after state broadcast") - } - - oc.loggerForContext(ctx).Debug().Str("model", meta.Model).Msg("Broadcasted room settings") + oc.BroadcastCommandDescriptions(ctx, portal) return nil } diff --git a/pkg/connector/chat_fork_test.go b/pkg/connector/chat_fork_test.go index 8320f7fb..012b336c 100644 --- a/pkg/connector/chat_fork_test.go +++ b/pkg/connector/chat_fork_test.go @@ -4,19 +4,12 @@ import "testing" func TestCloneForkPortalMetadata_PreservesSimpleMode(t *testing.T) { src := &PortalMetadata{ - Model: "openai/gpt-5", - SystemPrompt: "You are helpful.", - Temperature: 0.3, - MaxContextMessages: 42, - MaxCompletionTokens: 2048, - ReasoningEffort: "medium", - Capabilities: ModelCapabilities{ - SupportsToolCalling: true, + GroupActivation: "always", // Legacy field is not copied in fork metadata. + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: modelUserID("openai/gpt-5"), + ModelID: "openai/gpt-5", }, - AgentID: "beeper", - AgentPrompt: "agent prompt", - IsSimpleMode: true, - GroupActivation: "always", // Not copied in fork metadata. } got := cloneForkPortalMetadata(src, "chat-99", "Forked Chat") @@ -29,8 +22,8 @@ func TestCloneForkPortalMetadata_PreservesSimpleMode(t *testing.T) { if got.Title != "Forked Chat" { t.Fatalf("expected title Forked Chat, got %q", got.Title) } - if !got.IsSimpleMode { - t.Fatalf("expected IsSimpleMode=true on forked metadata") + if !isSimpleMode(got) { + t.Fatalf("expected forked metadata to keep resolved simple-mode target") } if got.GroupActivation != "" { t.Fatalf("expected GroupActivation to remain unset in fork metadata copy, got %q", got.GroupActivation) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 950780a9..97bb6786 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -1089,7 +1089,7 @@ func (oc *AIClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost) (*br store := NewAgentStoreAdapter(oc) agent, err := store.GetAgentByID(ctx, agentID) displayName := "Unknown Agent" - modelID := oc.agentModelOverride(agentID) + modelID := "" if err == nil && agent != nil { displayName = oc.resolveAgentDisplayName(ctx, agent) if displayName == "" { @@ -1245,61 +1245,29 @@ func (oc *AIClient) supportsMessageActionsFeature(meta *PortalMetadata) bool { } // effectiveModel returns the full prefixed model ID (e.g., "openai/gpt-5.2") -// Priority: Room → Agent → User → Provider → Global -// Exception: Boss agent rooms always use the Boss agent's model (no overrides) +// based only on the resolved room target. func (oc *AIClient) effectiveModel(meta *PortalMetadata) string { - // Check if an agent is assigned - if meta != nil { - agentID := resolveAgentID(meta) - if agentID != "" { - // Load the agent to get its model + if meta != nil && strings.TrimSpace(meta.RuntimeModelOverride) != "" { + return ResolveAlias(meta.RuntimeModelOverride) + } + if meta != nil && meta.ResolvedTarget != nil { + switch meta.ResolvedTarget.Kind { + case ResolvedTargetModel: + return ResolveAlias(meta.ResolvedTarget.ModelID) + case ResolvedTargetAgent: store := NewAgentStoreAdapter(oc) - agent, err := store.GetAgentByID(context.Background(), agentID) - if err == nil && agent != nil { - // Boss agent rooms always use the Boss model - no overrides allowed - if agents.IsBossAgent(agentID) && agent.Model.Primary != "" { - return ResolveAlias(agent.Model.Primary) - } - // For other agents, room override takes priority, then agent model - if meta.Model != "" { - return ResolveAlias(meta.Model) - } - if override := oc.agentModelOverride(agentID); override != "" { - return ResolveAlias(override) - } - if agent.Model.Primary != "" { - return ResolveAlias(agent.Model.Primary) - } + agent, err := store.GetAgentByID(context.Background(), meta.ResolvedTarget.AgentID) + if err == nil && agent != nil && agent.Model.Primary != "" { + return ResolveAlias(agent.Model.Primary) } + return "" + default: + return "" } } - - // Room-level model override (for rooms without an agent) - if meta != nil && meta.Model != "" { - return ResolveAlias(meta.Model) - } - - // User-level default - loginMeta := loginMetadata(oc.UserLogin) - if loginMeta.Defaults != nil && loginMeta.Defaults.Model != "" { - return ResolveAlias(loginMeta.Defaults.Model) - } - - // Provider default from config return oc.defaultModelForProvider() } -func (oc *AIClient) agentModelOverride(agentID string) string { - if agentID == "" || oc.UserLogin == nil { - return "" - } - loginMeta := loginMetadata(oc.UserLogin) - if loginMeta == nil || loginMeta.AgentModelOverrides == nil { - return "" - } - return strings.TrimSpace(loginMeta.AgentModelOverrides[agentID]) -} - // effectiveModelForAPI returns the actual model name to send to the API // For OpenRouter/Beeper, returns the full model ID (e.g., "openai/gpt-5.2") // For direct providers, strips the prefix (e.g., "openai/gpt-5.2" → "gpt-5.2") @@ -1331,7 +1299,13 @@ func (oc *AIClient) modelIDForAPI(modelID string) string { // defaultModelForProvider returns the configured default model for this login's provider func (oc *AIClient) defaultModelForProvider() string { + if oc == nil || oc.connector == nil || oc.UserLogin == nil { + return DefaultModelOpenRouter + } loginMeta := loginMetadata(oc.UserLogin) + if loginMeta == nil { + return DefaultModelOpenRouter + } providers := oc.connector.Config.Providers switch loginMeta.Provider { @@ -1355,29 +1329,50 @@ func (oc *AIClient) defaultModelForProvider() string { } } -// effectivePrompt returns the system prompt to use -// Priority: Room ? User ? Bridge Config +// effectivePrompt returns the base system prompt to use for non-agent rooms. func (oc *AIClient) effectivePrompt(meta *PortalMetadata) string { - // Room-level override takes priority - var base string - if meta != nil && meta.SystemPrompt != "" { - base = meta.SystemPrompt - } else { - loginMeta := loginMetadata(oc.UserLogin) - if loginMeta.Defaults != nil && loginMeta.Defaults.SystemPrompt != "" { - base = loginMeta.Defaults.SystemPrompt - } else { - base = oc.connector.Config.DefaultSystemPrompt - } - } - gravatarContext := oc.gravatarContext() - if gravatarContext == "" { + base := oc.connector.Config.DefaultSystemPrompt + supplement := oc.profilePromptSupplement() + if supplement == "" { return base } if strings.TrimSpace(base) == "" { - return gravatarContext + return supplement + } + return fmt.Sprintf("%s\n\n%s", base, supplement) +} + +func (oc *AIClient) profilePromptSupplement() string { + if oc == nil || oc.UserLogin == nil { + return strings.TrimSpace(oc.gravatarContext()) + } + loginMeta := loginMetadata(oc.UserLogin) + if loginMeta == nil { + return strings.TrimSpace(oc.gravatarContext()) + } + + var lines []string + if profile := loginMeta.Profile; profile != nil { + if v := strings.TrimSpace(profile.Name); v != "" { + lines = append(lines, "Name: "+v) + } + if v := strings.TrimSpace(profile.Occupation); v != "" { + lines = append(lines, "Occupation: "+v) + } + if v := strings.TrimSpace(profile.AboutUser); v != "" { + lines = append(lines, "About the user: "+v) + } + if v := strings.TrimSpace(profile.CustomInstructions); v != "" { + lines = append(lines, "Custom instructions: "+v) + } + } + if gravatar := strings.TrimSpace(oc.gravatarContext()); gravatar != "" { + lines = append(lines, gravatar) } - return fmt.Sprintf("%s\n\n%s", base, gravatarContext) + if len(lines) == 0 { + return "" + } + return "User profile:\n- " + strings.Join(lines, "\n- ") } // getLinkPreviewConfig returns the link preview configuration, with defaults filled in. @@ -1444,9 +1439,6 @@ func (oc *AIClient) effectiveAgentPrompt(ctx context.Context, portal *bridgev2.P if strings.TrimSpace(agent.SystemPrompt) != "" { extraParts = append(extraParts, strings.TrimSpace(agent.SystemPrompt)) } - if meta != nil && strings.TrimSpace(meta.SystemPrompt) != "" { - extraParts = append(extraParts, strings.TrimSpace(meta.SystemPrompt)) - } extraSystemPrompt := strings.Join(extraParts, "\n\n") // Build params for prompt generation (OpenClaw template) @@ -1464,7 +1456,7 @@ func (oc *AIClient) effectiveAgentPrompt(ctx context.Context, portal *bridgev2.P } } } - params.UserIdentitySupplement = oc.gravatarContext() + params.UserIdentitySupplement = oc.profilePromptSupplement() params.ContextFiles = oc.buildBootstrapContextFiles(ctx, agentID, meta) if meta != nil && strings.TrimSpace(meta.SubagentParentRoomID) != "" { params.PromptMode = agents.PromptModeMinimal @@ -1487,21 +1479,23 @@ func (oc *AIClient) effectiveAgentPrompt(ctx context.Context, portal *bridgev2.P params.ToolSummaries = toolSummaries } - // Build capabilities list from metadata + modelCaps := oc.getModelCapabilitiesForMeta(meta) + + // Build capabilities list from model resolution var caps []string - if meta.Capabilities.SupportsVision { + if modelCaps.SupportsVision { caps = append(caps, "vision") } - if meta.Capabilities.SupportsToolCalling { + if modelCaps.SupportsToolCalling { caps = append(caps, "tools") } - if meta.Capabilities.SupportsReasoning { + if modelCaps.SupportsReasoning { caps = append(caps, "reasoning") } - if meta.Capabilities.SupportsAudio { + if modelCaps.SupportsAudio { caps = append(caps, "audio") } - if meta.Capabilities.SupportsVideo { + if modelCaps.SupportsVideo { caps = append(caps, "video") } @@ -1528,7 +1522,7 @@ func (oc *AIClient) effectiveAgentPrompt(ctx context.Context, portal *bridgev2.P } // Reasoning hints and level - params.ReasoningTagHint = meta.Capabilities.SupportsReasoning && meta.EmitThinking + params.ReasoningTagHint = false params.ReasoningLevel = resolvePromptReasoningLevel(meta) // Default thinking level (OpenClaw-style): low for reasoning-capable models, otherwise off. @@ -1537,31 +1531,13 @@ func (oc *AIClient) effectiveAgentPrompt(ctx context.Context, portal *bridgev2.P return agents.BuildSystemPrompt(params) } -// effectiveTemperature returns the temperature to use. -// Priority: Room → User → Default (unset / provider default). func (oc *AIClient) effectiveTemperature(meta *PortalMetadata) float64 { - if meta != nil && meta.Temperature > 0 { - return meta.Temperature - } - var loginMeta *UserLoginMetadata - if oc != nil && oc.UserLogin != nil { - loginMeta = loginMetadata(oc.UserLogin) - } - if loginMeta != nil && loginMeta.Defaults != nil && loginMeta.Defaults.Temperature != nil { - return *loginMeta.Defaults.Temperature - } return defaultTemperature } // defaultThinkLevel resolves the default think level in an OpenClaw-compatible way: // low for reasoning-capable models, off otherwise. func (oc *AIClient) defaultThinkLevel(meta *PortalMetadata) string { - if meta != nil { - level := strings.ToLower(strings.TrimSpace(meta.ThinkingLevel)) - if level != "" { - return level - } - } switch effort := strings.ToLower(strings.TrimSpace(oc.effectiveReasoningEffort(meta))); effort { case "off", "none": return "off" @@ -1571,39 +1547,25 @@ func (oc *AIClient) defaultThinkLevel(meta *PortalMetadata) string { } return effort } - if meta != nil && meta.Capabilities.SupportsReasoning { + if caps := oc.getModelCapabilitiesForMeta(meta); caps.SupportsReasoning { return "low" } + if modelID := strings.TrimSpace(oc.effectiveModel(meta)); modelID != "" { + if info := oc.findModelInfo(modelID); info != nil && info.SupportsReasoning { + return "low" + } + } return "off" } -// effectiveReasoningEffort returns the reasoning effort to use -// Priority: Room ? User ? "" (none) func (oc *AIClient) effectiveReasoningEffort(meta *PortalMetadata) string { - if meta != nil && !meta.Capabilities.SupportsReasoning { + if !oc.getModelCapabilitiesForMeta(meta).SupportsReasoning { return "" } - if meta != nil && meta.ReasoningEffort != "" { - return meta.ReasoningEffort - } - var loginMeta *UserLoginMetadata - if oc != nil && oc.UserLogin != nil { - loginMeta = loginMetadata(oc.UserLogin) - } - if loginMeta != nil && loginMeta.Defaults != nil && loginMeta.Defaults.ReasoningEffort != "" { - return loginMeta.Defaults.ReasoningEffort - } - if meta != nil && meta.Capabilities.SupportsReasoning { - return defaultReasoningEffort - } - return "" + return defaultReasoningEffort } func (oc *AIClient) historyLimit(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata) int { - if meta != nil && meta.MaxContextMessages > 0 { - return meta.MaxContextMessages - } - isGroup := portal != nil && oc.isGroupChat(ctx, portal) if oc != nil && oc.connector != nil && oc.connector.Config.Messages != nil { if isGroup { @@ -1624,18 +1586,11 @@ func (oc *AIClient) historyLimit(ctx context.Context, portal *bridgev2.Portal, m func (oc *AIClient) effectiveMaxTokens(meta *PortalMetadata) int { var maxTokens int - // 1. Per-room override (highest priority) - if meta != nil && meta.MaxCompletionTokens > 0 { - maxTokens = meta.MaxCompletionTokens + modelID := oc.effectiveModel(meta) + if info := oc.findModelInfo(modelID); info != nil && info.MaxOutputTokens > 0 { + maxTokens = info.MaxOutputTokens } else { - // 2. Model catalog MaxOutputTokens - modelID := oc.effectiveModel(meta) - if info := oc.findModelInfo(modelID); info != nil && info.MaxOutputTokens > 0 { - maxTokens = info.MaxOutputTokens - } else { - // 3. Hardcoded fallback - maxTokens = defaultMaxTokens - } + maxTokens = defaultMaxTokens } // Cap at context window to prevent impossible requests. // When max output tokens >= context window (common for thinking/reasoning diff --git a/pkg/connector/client_capabilities_test.go b/pkg/connector/client_capabilities_test.go index 9f769197..39074f1d 100644 --- a/pkg/connector/client_capabilities_test.go +++ b/pkg/connector/client_capabilities_test.go @@ -15,8 +15,8 @@ func TestGetCapabilities_SimpleModeDisablesReplyEditReaction(t *testing.T) { oc := &AIClient{connector: &OpenAIConnector{}} portal := &bridgev2.Portal{ Portal: &database.Portal{ + OtherUserID: modelUserID("openai/gpt-5"), Metadata: &PortalMetadata{ - IsSimpleMode: true, Capabilities: ModelCapabilities{SupportsToolCalling: true}, }, }, @@ -46,6 +46,7 @@ func TestGetCapabilities_NonSimpleEnablesReplyEditReaction(t *testing.T) { oc := &AIClient{connector: &OpenAIConnector{}} portal := &bridgev2.Portal{ Portal: &database.Portal{ + OtherUserID: agentUserID("beeper"), Metadata: &PortalMetadata{ Capabilities: ModelCapabilities{SupportsToolCalling: true}, }, @@ -68,6 +69,7 @@ func TestGetCapabilities_MessageToolDisabledDisablesReplyEditReaction(t *testing oc := &AIClient{connector: &OpenAIConnector{}} portal := &bridgev2.Portal{ Portal: &database.Portal{ + OtherUserID: agentUserID("beeper"), Metadata: &PortalMetadata{ Capabilities: ModelCapabilities{SupportsToolCalling: true}, DisabledTools: []string{ diff --git a/pkg/connector/command_registry.go b/pkg/connector/command_registry.go index 049dc1c6..55e3432f 100644 --- a/pkg/connector/command_registry.go +++ b/pkg/connector/command_registry.go @@ -20,6 +20,17 @@ import ( var aiCommandRegistry = commandregistry.NewRegistry() var moduleCommandRegisterMu sync.Mutex var moduleCommandsRegistered = map[string]struct{}{} +var allowedUserCommandNames = map[string]struct{}{ + "new": {}, + "reset": {}, + "status": {}, + "stop": {}, +} + +func isUserFacingCommand(name string) bool { + _, ok := allowedUserCommandNames[strings.TrimSpace(strings.ToLower(name))] + return ok +} func registerAICommand(def commandregistry.Definition) *commands.FullHandler { return aiCommandRegistry.Register(def) @@ -101,6 +112,9 @@ func registerCommandsWithOwnerGuard(proc *commands.Processor, cfg *Config, log * if handler == nil || handler.Func == nil { continue } + if !isUserFacingCommand(handler.Name) { + continue + } original := handler.Func handler.Func = func(ce *commands.Event) { senderID := "" @@ -151,6 +165,9 @@ func (oc *AIClient) BroadcastCommandDescriptions(ctx context.Context, portal *br if handler == nil || handler.Name == "" { continue } + if !isUserFacingCommand(handler.Name) { + continue + } stateKey := handler.Name content := buildCommandDescriptionContent(handler) _, err := bot.SendState(ctx, portal.MXID, event.StateMSC4391BotCommand, stateKey, &event.Content{ diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 5b3c0f46..b6fd2646 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -2,13 +2,11 @@ package connector import ( "context" - "encoding/json" "fmt" "strings" "sync" "time" - "github.com/rs/zerolog" "go.mau.fi/util/configupgrade" "go.mau.fi/util/dbutil" @@ -145,7 +143,7 @@ func (oc *OpenAIConnector) SetLocalAIBridgeLogin(userMXID id.UserID, accessToken } } -// registerCustomEventHandlers registers handlers for custom Matrix state events +// registerCustomEventHandlers registers connector-owned event handlers. func (oc *OpenAIConnector) registerCustomEventHandlers() { // Type assert the Matrix connector to get the concrete type with EventProcessor matrixConnector, ok := oc.br.Matrix.(*matrix.Connector) @@ -154,200 +152,10 @@ func (oc *OpenAIConnector) registerCustomEventHandlers() { return } - // Register handler for direct room settings state events - matrixConnector.EventProcessor.On(RoomSettingsEventType, oc.handleRoomSettingsEvent) - - // Register handler for BeeperSendState wrapper events (desktop E2EE state updates) - matrixConnector.EventProcessor.On(event.BeeperSendState, oc.handleBeeperSendStateEvent) - // Register handler for internal scheduler delayed ticks. matrixConnector.EventProcessor.On(ScheduleTickEventType, oc.handleScheduleTickEvent) - oc.br.Log.Info(). - Str("beeper_send_state_type", event.BeeperSendState.Type). - Str("beeper_send_state_class", event.BeeperSendState.Class.Name()). - Msg("Registered room settings event handlers (direct and BeeperSendState)") -} - -// handleRoomSettingsEvent processes Matrix room settings state events from users -func (oc *OpenAIConnector) handleRoomSettingsEvent(ctx context.Context, evt *event.Event) { - log := oc.br.Log.With(). - Str("component", "room_settings_handler"). - Str("room_id", evt.RoomID.String()). - Str("sender", evt.Sender.String()). - Logger() - - // Parse event content - var content RoomSettingsEventContent - if err := json.Unmarshal(evt.Content.VeryRaw, &content); err != nil { - log.Warn().Err(err).Msg("Failed to parse room settings event content") - return - } - - oc.processRoomSettingsContent(ctx, evt, &content, log) -} - -// processRoomSettingsContent handles the common logic for updating portal settings -// Called by both handleRoomSettingsEvent and handleBeeperSendStateEvent -func (oc *OpenAIConnector) processRoomSettingsContent( - ctx context.Context, - evt *event.Event, - content *RoomSettingsEventContent, - log zerolog.Logger, -) { - if evt == nil { - return - } - roomID := evt.RoomID - sender := evt.Sender - // Look up portal by Matrix room ID - portal, err := oc.br.GetPortalByMXID(ctx, roomID) - if err != nil { - log.Err(err).Msg("Failed to get portal for room settings event") - return - } - if portal == nil { - log.Debug().Msg("No portal found for room, ignoring settings event") - return - } - - // Get the user who sent the event and their login - user, err := oc.br.GetUserByMXID(ctx, sender) - if err != nil || user == nil { - log.Warn().Err(err).Msg("Failed to get user for room settings event") - return - } - - // Use getLoginForPortal to find the correct login based on portal's receiver - // This ensures we use the right provider when user has multiple accounts - login := oc.getLoginForPortal(ctx, user, portal) - if login == nil { - log.Warn().Msg("User has no active login, cannot process settings") - return - } - - client, ok := login.Client.(*AIClient) - if !ok || client == nil { - log.Warn().Msg("Invalid client type for user login") - return - } - - // Validate model if specified - if content.Model != "" { - resolved, valid, err := client.resolveModelID(ctx, content.Model) - if err != nil { - log.Warn().Err(err).Str("model", content.Model).Msg("Failed to validate model") - } else if !valid { - log.Warn().Str("model", content.Model).Msg("Invalid model specified, ignoring") - client.sendSystemNotice(ctx, portal, fmt.Sprintf("That model isn't available: %s. Settings weren't applied.", content.Model)) - return - } - content.Model = resolved - } - - // Update portal metadata with optimistic update + rollback behavior. - if err := client.updatePortalConfig(ctx, portal, content); err != nil { - sendStateEventFailureStatus(ctx, portal, evt, err) - log.Warn().Err(err).Msg("Failed to apply room settings state event") - return - } - - sendStateEventSuccessStatus(ctx, portal, evt) - - // Send confirmation notice - var changes []string - if content.Model != "" { - changes = append(changes, fmt.Sprintf("model=%s", content.Model)) - } - if content.Temperature != nil { - changes = append(changes, fmt.Sprintf("temperature=%.2f", *content.Temperature)) - } - if content.MaxContextMessages > 0 { - changes = append(changes, fmt.Sprintf("context=%d messages", content.MaxContextMessages)) - } - if content.MaxCompletionTokens > 0 { - changes = append(changes, fmt.Sprintf("max_tokens=%d", content.MaxCompletionTokens)) - } - if content.SystemPrompt != "" { - changes = append(changes, "system_prompt updated") - } - if content.ReasoningEffort != "" { - changes = append(changes, fmt.Sprintf("reasoning_effort=%s", content.ReasoningEffort)) - } - if len(changes) > 0 { - client.sendSystemNotice(ctx, portal, fmt.Sprintf("Configuration updated: %s", strings.Join(changes, ", "))) - } - - logEvent := log.Info().Str("model", content.Model) - if content.Temperature != nil { - logEvent = logEvent.Float64("temperature", *content.Temperature) - } - logEvent.Msg("Updated room settings from state event") -} - -// handleBeeperSendStateEvent processes com.beeper.send_state wrapper events -// This is used by the desktop client to send state events in encrypted rooms -func (oc *OpenAIConnector) handleBeeperSendStateEvent(ctx context.Context, evt *event.Event) { - log := oc.br.Log.With(). - Str("component", "beeper_send_state_handler"). - Str("room_id", evt.RoomID.String()). - Str("sender", evt.Sender.String()). - Str("event_type", evt.Type.Type). - Str("event_class", evt.Type.Class.Name()). - Logger() - - log.Info().RawJSON("raw_content", evt.Content.VeryRaw).Msg("Received BeeperSendState event") - - // Parse the wrapper content - var wrapperContent event.BeeperSendStateEventContent - if err := json.Unmarshal(evt.Content.VeryRaw, &wrapperContent); err != nil { - log.Debug().Err(err).Msg("Failed to parse BeeperSendState content") - return - } - - // Only process AI room settings events - if wrapperContent.Type != RoomSettingsEventType.Type { - return - } - - log.Debug(). - Str("inner_type", wrapperContent.Type). - Str("state_key", wrapperContent.StateKey). - Msg("Processing BeeperSendState wrapper for AI room settings") - - // Parse the inner room settings content - var content RoomSettingsEventContent - if err := json.Unmarshal(wrapperContent.Content.VeryRaw, &content); err != nil { - log.Warn().Err(err).Msg("Failed to parse inner room settings content") - return - } - - // Reuse existing handler logic with the parsed content - oc.processRoomSettingsContent(ctx, evt, &content, log) -} - -func sendStateEventFailureStatus(ctx context.Context, portal *bridgev2.Portal, evt *event.Event, err error) { - if portal == nil || portal.Bridge == nil || evt == nil || err == nil { - return - } - msgStatus := bridgev2.WrapErrorInStatus(err). - WithStatus(event.MessageStatusRetriable). - WithErrorReason(event.MessageStatusGenericError). - WithMessage("Failed to apply room settings. Your change was rolled back."). - WithIsCertain(true). - WithSendNotice(false) - portal.Bridge.Matrix.SendMessageStatus(ctx, &msgStatus, bridgev2.StatusEventInfoFromEvent(evt)) -} - -func sendStateEventSuccessStatus(ctx context.Context, portal *bridgev2.Portal, evt *event.Event) { - if portal == nil || portal.Bridge == nil || evt == nil { - return - } - msgStatus := bridgev2.MessageStatus{ - Status: event.MessageStatusSuccess, - IsCertain: true, - } - portal.Bridge.Matrix.SendMessageStatus(ctx, &msgStatus, bridgev2.StatusEventInfoFromEvent(evt)) + oc.br.Log.Info().Msg("Registered connector event handlers") } func (oc *OpenAIConnector) GetCapabilities() *bridgev2.NetworkGeneralCapabilities { diff --git a/pkg/connector/defaults_alignment_test.go b/pkg/connector/defaults_alignment_test.go index 86f492aa..1657c8c7 100644 --- a/pkg/connector/defaults_alignment_test.go +++ b/pkg/connector/defaults_alignment_test.go @@ -1,6 +1,11 @@ package connector -import "testing" +import ( + "testing" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" +) func TestEffectiveTemperatureDefaultUnset(t *testing.T) { client := &AIClient{} @@ -10,11 +15,22 @@ func TestEffectiveTemperatureDefaultUnset(t *testing.T) { } func TestDefaultThinkLevelModelAware(t *testing.T) { - client := &AIClient{} + client := &AIClient{ + connector: &OpenAIConnector{}, + UserLogin: &bridgev2.UserLogin{UserLogin: &database.UserLogin{Metadata: &UserLoginMetadata{ + Provider: ProviderOpenRouter, + ModelCache: &ModelCache{Models: []ModelInfo{ + {ID: "openai/o4-mini", SupportsReasoning: true}, + {ID: "openai/gpt-4o-mini", SupportsReasoning: false}, + }}, + }}}, + } reasoningMeta := &PortalMetadata{ - Capabilities: ModelCapabilities{ - SupportsReasoning: true, + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: modelUserID("openai/o4-mini"), + ModelID: "openai/o4-mini", }, } if got := client.defaultThinkLevel(reasoningMeta); got != "low" { @@ -22,8 +38,10 @@ func TestDefaultThinkLevelModelAware(t *testing.T) { } nonReasoningMeta := &PortalMetadata{ - Capabilities: ModelCapabilities{ - SupportsReasoning: false, + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: modelUserID("openai/gpt-4o-mini"), + ModelID: "openai/gpt-4o-mini", }, } if got := client.defaultThinkLevel(nonReasoningMeta); got != "off" { @@ -31,30 +49,27 @@ func TestDefaultThinkLevelModelAware(t *testing.T) { } } -func TestDefaultThinkLevelHonorsExplicitThinkingLevel(t *testing.T) { - client := &AIClient{} - meta := &PortalMetadata{ - ThinkingLevel: "high", - Capabilities: ModelCapabilities{ - SupportsReasoning: true, - }, +func TestDefaultThinkLevelIgnoresLegacyThinkingOverrides(t *testing.T) { + client := &AIClient{ + connector: &OpenAIConnector{}, + UserLogin: &bridgev2.UserLogin{UserLogin: &database.UserLogin{Metadata: &UserLoginMetadata{ + Provider: ProviderOpenRouter, + ModelCache: &ModelCache{Models: []ModelInfo{ + {ID: "openai/o4-mini", SupportsReasoning: true}, + }}, + }}}, } - - if got := client.defaultThinkLevel(meta); got != "high" { - t.Fatalf("expected explicit thinking level to win, got %q", got) - } -} - -func TestDefaultThinkLevelUsesReasoningEffortFallback(t *testing.T) { - client := &AIClient{} meta := &PortalMetadata{ - Capabilities: ModelCapabilities{ - SupportsReasoning: true, - }, + ThinkingLevel: "high", ReasoningEffort: "medium", + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: modelUserID("openai/o4-mini"), + ModelID: "openai/o4-mini", + }, } - if got := client.defaultThinkLevel(meta); got != "medium" { - t.Fatalf("expected medium from reasoning effort, got %q", got) + if got := client.defaultThinkLevel(meta); got != "low" { + t.Fatalf("expected ghost/model-derived low think level, got %q", got) } } diff --git a/pkg/connector/error_logging.go b/pkg/connector/error_logging.go index dd1191c6..6b0f9646 100644 --- a/pkg/connector/error_logging.go +++ b/pkg/connector/error_logging.go @@ -46,9 +46,7 @@ func addRequestSummary(event *zerolog.Event, meta *PortalMetadata, messages []op event.Int("message_count", len(messages)) event.Bool("has_audio", hasAudioContent(messages)) event.Bool("has_multimodal", hasMultimodalContent(messages)) - if meta != nil { - event.Bool("tool_calling", meta.Capabilities.SupportsToolCalling) - } + _ = meta } func addResponsesParamsSummary(event *zerolog.Event, params responses.ResponseNewParams) { diff --git a/pkg/connector/events.go b/pkg/connector/events.go index 8e63c6f7..a1ed4bf0 100644 --- a/pkg/connector/events.go +++ b/pkg/connector/events.go @@ -13,9 +13,6 @@ import ( // init registers custom AI event types with mautrix's TypeMap // so the state store can properly parse them during sync func init() { - event.TypeMap[RoomCapabilitiesEventType] = reflect.TypeOf(RoomCapabilitiesEventContent{}) - event.TypeMap[RoomSettingsEventType] = reflect.TypeOf(RoomSettingsEventContent{}) - event.TypeMap[ModelCapabilitiesEventType] = reflect.TypeOf(ModelCapabilitiesEventContent{}) event.TypeMap[AgentsEventType] = reflect.TypeOf(AgentsEventContent{}) } @@ -25,17 +22,6 @@ var StreamEventMessageType = matrixevents.StreamEventMessageType // CompactionStatusEventType notifies clients about context compaction var CompactionStatusEventType = matrixevents.CompactionStatusEventType -// RoomCapabilitiesEventType is the Matrix state event type for bridge-controlled capabilities -// Protected by power levels (100) so only the bridge bot can modify -var RoomCapabilitiesEventType = matrixevents.RoomCapabilitiesEventType - -// RoomSettingsEventType is the Matrix state event type for user-editable settings -// Normal power level (0) so users can modify -var RoomSettingsEventType = matrixevents.RoomSettingsEventType - -// ModelCapabilitiesEventType is the Matrix state event type for broadcasting available models -var ModelCapabilitiesEventType = matrixevents.ModelCapabilitiesEventType - // AgentsEventType configures active agents in a room var AgentsEventType = matrixevents.AgentsEventType @@ -69,82 +55,29 @@ const ( ToolTypeMCP = matrixevents.ToolTypeMCP ) -// ReasoningEffortOption represents an available reasoning effort level -type ReasoningEffortOption struct { - Value string `json:"value"` // minimal, low, medium, high, xhigh - Label string `json:"label"` // Display name -} - -// SettingSource indicates where a setting value came from +// SettingSource indicates where a setting or availability decision came from. type SettingSource string const ( SourceAgentPolicy SettingSource = "agent_policy" - SourceRoomOverride SettingSource = "room_override" - SourceUserDefault SettingSource = "user_default" SourceProviderConfig SettingSource = "provider_config" SourceGlobalDefault SettingSource = "global_default" SourceModelLimit SettingSource = "model_limitation" SourceProviderLimit SettingSource = "provider_limitation" ) -// SettingExplanation describes why a setting has its current value -type SettingExplanation struct { - Value any `json:"value"` - Source SettingSource `json:"source"` - Reason string `json:"reason,omitempty"` // Only when limited/unavailable -} - -// EffectiveSettings shows current values with source explanations -type EffectiveSettings struct { - Model SettingExplanation `json:"model"` - SystemPrompt SettingExplanation `json:"system_prompt"` - Temperature SettingExplanation `json:"temperature"` - ReasoningEffort SettingExplanation `json:"reasoning_effort"` -} - -// RoomCapabilitiesEventContent represents bridge-controlled room capabilities -// This is protected by power levels (100) so only the bridge bot can modify -type RoomCapabilitiesEventContent struct { - Capabilities *ModelCapabilities `json:"capabilities,omitempty"` - AvailableTools []ToolInfo `json:"available_tools,omitempty"` - ReasoningEffortOptions []ReasoningEffortOption `json:"reasoning_effort_options,omitempty"` - Provider string `json:"provider,omitempty"` - EffectiveSettings *EffectiveSettings `json:"effective_settings,omitempty"` -} - -// RoomSettingsEventContent represents user-editable room settings -// This uses normal power levels (0) so users can modify -type RoomSettingsEventContent struct { - Model string `json:"model,omitempty"` - SystemPrompt string `json:"system_prompt,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - MaxContextMessages int `json:"max_context_messages,omitempty"` - MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` - AgentID string `json:"agent_id,omitempty"` - EmitThinking *bool `json:"emit_thinking,omitempty"` - EmitToolArgs *bool `json:"emit_tool_args,omitempty"` -} - -// ToolInfo describes a tool and its status for room state broadcasting +// ToolInfo describes a tool and its status for internal UI/config rendering. type ToolInfo struct { Name string `json:"name"` - DisplayName string `json:"display_name"` // Human-readable name for UI - Type string `json:"type"` // "builtin", "provider", "plugin", "mcp" + DisplayName string `json:"display_name"` + Type string `json:"type"` Description string `json:"description,omitempty"` Enabled bool `json:"enabled"` - Available bool `json:"available"` // Based on model capabilities and provider - Source SettingSource `json:"source,omitempty"` // Where enabled state came from - Reason string `json:"reason,omitempty"` // Only when limited/unavailable -} - -// ModelCapabilitiesEventContent represents available models and their capabilities -type ModelCapabilitiesEventContent struct { - AvailableModels []ModelInfo `json:"available_models"` + Available bool `json:"available"` + Source SettingSource `json:"source,omitempty"` + Reason string `json:"reason,omitempty"` } -// Tool constants for model capabilities const ( ToolWebSearch = "web_search" ToolFunctionCalling = "function_calling" diff --git a/pkg/connector/handleai.go b/pkg/connector/handleai.go index 04c6f8d0..e4c2ab3c 100644 --- a/pkg/connector/handleai.go +++ b/pkg/connector/handleai.go @@ -219,7 +219,7 @@ func isInternalControlRoom(meta *PortalMetadata) bool { if meta == nil { return false } - return meta.IsBuilderRoom || isModuleInternalRoom(meta) + return isModuleInternalRoom(meta) } func autoGreetingBlockReason(meta *PortalMetadata) string { @@ -369,19 +369,16 @@ func (oc *AIClient) sendWelcomeMessage(ctx context.Context, portal *bridgev2.Por // Still send the welcome notice and schedule greeting; duplicates are preferable to missing UX. } - if meta.AgentID == "" { - displayName := modelContactName(meta.Model, oc.findModelInfo(meta.Model)) + if resolveAgentID(meta) == "" { + modelID := oc.effectiveModel(meta) + displayName := modelContactName(modelID, oc.findModelInfo(modelID)) oc.sendSystemNotice(bgCtx, portal, fmt.Sprintf("You are chatting with %s. AI can make mistakes.", displayName)) } else { oc.sendSystemNotice(bgCtx, portal, "AI can make mistakes.") } - // Ensure initial room state exists for clients (model/settings/capabilities). - // Only broadcast once on first-room initialization. - if meta.LastRoomStateSync == 0 { - if err := oc.BroadcastRoomState(bgCtx, portal); err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to broadcast initial room state") - } + if err := oc.BroadcastRoomState(bgCtx, portal); err != nil { + oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to broadcast room state") } oc.scheduleAutoGreeting(bgCtx, portal) @@ -649,18 +646,9 @@ func (oc *AIClient) setRoomSystemPromptNoSave(ctx context.Context, portal *bridg } func (oc *AIClient) setRoomSystemPromptInternal(ctx context.Context, portal *bridgev2.Portal, prompt string, save bool) error { - if portal.MXID == "" { - return errors.New("portal has no Matrix room ID") - } - - meta := portalMeta(portal) - meta.SystemPrompt = prompt - - if save { - if err := portal.Save(ctx); err != nil { - return fmt.Errorf("failed to save portal: %w", err) - } - oc.loggerForContext(ctx).Debug().Str("prompt_len", fmt.Sprintf("%d", len(prompt))).Msg("Set room system prompt") - } + _ = ctx + _ = portal + _ = prompt + _ = save return nil } diff --git a/pkg/connector/heartbeat_delivery.go b/pkg/connector/heartbeat_delivery.go index 6f177c06..7850970b 100644 --- a/pkg/connector/heartbeat_delivery.go +++ b/pkg/connector/heartbeat_delivery.go @@ -44,7 +44,7 @@ func (oc *AIClient) resolveHeartbeatDeliveryTarget(agentID string, heartbeat *He if target.Portal != nil && target.RoomID != "" { // Stale agent routing guard: skip if portal is now assigned to a // different agent (matches resolveHeartbeatSessionPortal behavior). - if meta := portalMeta(target.Portal); meta != nil && normalizeAgentID(meta.AgentID) != normalizeAgentID(agentID) { + if meta := portalMeta(target.Portal); meta != nil && normalizeAgentID(resolveAgentID(meta)) != normalizeAgentID(agentID) { // Fall through to lastActivePortal / defaultChatPortal. } else { return target diff --git a/pkg/connector/heartbeat_execute.go b/pkg/connector/heartbeat_execute.go index 8f65cb11..f2f76dcf 100644 --- a/pkg/connector/heartbeat_execute.go +++ b/pkg/connector/heartbeat_execute.go @@ -329,7 +329,7 @@ func (oc *AIClient) resolveHeartbeatSessionPortal(agentID string, heartbeat *Hea } if strings.HasPrefix(session, "!") { if portal := oc.portalByRoomID(context.Background(), id.RoomID(session)); portal != nil { - if meta := portalMeta(portal); meta == nil || normalizeAgentID(meta.AgentID) == normalizeAgentID(agentID) { + if meta := portalMeta(portal); meta == nil || normalizeAgentID(resolveAgentID(meta)) == normalizeAgentID(agentID) { return portal, portal.MXID.String(), nil } } @@ -360,7 +360,7 @@ func (oc *AIClient) heartbeatSessionPortalCandidate(agentID string, session hear if portal == nil { return nil } - if meta := portalMeta(portal); meta != nil && normalizeAgentID(meta.AgentID) != normalizeAgentID(agentID) { + if meta := portalMeta(portal); meta != nil && normalizeAgentID(resolveAgentID(meta)) != normalizeAgentID(agentID) { return nil } return portal diff --git a/pkg/connector/history_limit_test.go b/pkg/connector/history_limit_test.go index ec589443..8ae8f695 100644 --- a/pkg/connector/history_limit_test.go +++ b/pkg/connector/history_limit_test.go @@ -8,14 +8,14 @@ import ( "maunium.net/go/mautrix/bridgev2/database" ) -func TestHistoryLimitMetaOverrideWins(t *testing.T) { +func TestHistoryLimitIgnoresLegacyMetaOverride(t *testing.T) { client := &AIClient{} portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test", RoomType: database.RoomTypeGroupDM}} meta := &PortalMetadata{MaxContextMessages: 7} limit := client.historyLimit(context.Background(), portal, meta) - if limit != 7 { - t.Fatalf("expected 7, got %d", limit) + if limit != defaultGroupContextMessages { + t.Fatalf("expected group default %d, got %d", defaultGroupContextMessages, limit) } } diff --git a/pkg/connector/identifiers.go b/pkg/connector/identifiers.go index 45c45783..57ce006b 100644 --- a/pkg/connector/identifiers.go +++ b/pkg/connector/identifiers.go @@ -27,10 +27,32 @@ func nthLoginID(providerSlug string, mxid id.UserID, ordinal int) networkid.User return networkid.UserLoginID(fmt.Sprintf("%s:%d", base, ordinal)) } +func nextLoginID(user *bridgev2.User, providerSlug string, mxid id.UserID) networkid.UserLoginID { + used := map[string]struct{}{} + if user != nil { + for _, existing := range user.GetUserLogins() { + if existing == nil { + continue + } + used[string(existing.ID)] = struct{}{} + } + } + for ordinal := 1; ; ordinal++ { + loginID := nthLoginID(providerSlug, mxid, ordinal) + if _, ok := used[string(loginID)]; !ok { + return loginID + } + } +} + func providerLoginID(provider string, mxid id.UserID, ordinal int) networkid.UserLoginID { return nthLoginID(providerSlug(provider), mxid, ordinal) } +func nextProviderLoginID(user *bridgev2.User, provider string, mxid id.UserID) networkid.UserLoginID { + return nextLoginID(user, providerSlug(provider), mxid) +} + func managedBeeperLoginID(mxid id.UserID) networkid.UserLoginID { return baseLoginID("beeper", mxid) } @@ -105,15 +127,61 @@ func humanUserID(loginID networkid.UserLoginID) networkid.UserID { return bridgeadapter.HumanUserID("openai-user", loginID) } +const ( + ResolvedTargetUnknown = "" + ResolvedTargetModel = "model" + ResolvedTargetAgent = "agent" +) + +type ResolvedTarget struct { + Kind string + GhostID networkid.UserID + ModelID string + AgentID string +} + +func resolveTargetFromGhostID(ghostID networkid.UserID) *ResolvedTarget { + if ghostID == "" { + return nil + } + if modelID := strings.TrimSpace(parseModelFromGhostID(string(ghostID))); modelID != "" { + return &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: ghostID, + ModelID: modelID, + } + } + if agentID, ok := parseAgentFromGhostID(string(ghostID)); ok && strings.TrimSpace(agentID) != "" { + return &ResolvedTarget{ + Kind: ResolvedTargetAgent, + GhostID: ghostID, + AgentID: strings.TrimSpace(agentID), + } + } + return nil +} + +func resolvedAgentIDForGhost(ghostID networkid.UserID) string { + target := resolveTargetFromGhostID(ghostID) + if target == nil { + return "" + } + return target.AgentID +} + func portalMeta(portal *bridgev2.Portal) *PortalMetadata { - return bridgeadapter.EnsurePortalMetadata[PortalMetadata](portal) + meta := bridgeadapter.EnsurePortalMetadata[PortalMetadata](portal) + if meta != nil { + meta.ResolvedTarget = resolveTargetFromGhostID(portal.OtherUserID) + } + return meta } func resolveAgentID(meta *PortalMetadata) string { - if meta == nil { + if meta == nil || meta.ResolvedTarget == nil { return "" } - return meta.AgentID + return meta.ResolvedTarget.AgentID } func messageMeta(msg *database.Message) *MessageMetadata { diff --git a/pkg/connector/inbound_directive_apply.go b/pkg/connector/inbound_directive_apply.go index ecdf14df..ab5de54b 100644 --- a/pkg/connector/inbound_directive_apply.go +++ b/pkg/connector/inbound_directive_apply.go @@ -3,33 +3,13 @@ package connector import "fmt" func applyThinkingLevel(meta *PortalMetadata, level string) { - if meta == nil { - return - } - meta.ThinkingLevel = level - meta.EmitThinking = level != "off" - if level == "minimal" { - meta.ReasoningEffort = "low" - } else if level == "low" || level == "medium" || level == "high" || level == "xhigh" { - meta.ReasoningEffort = level - } + _ = meta + _ = level } func applyReasoningLevel(meta *PortalMetadata, level string) { - if meta == nil { - return - } - if level == "off" { - meta.EmitThinking = false - meta.ReasoningEffort = "" - return - } - if level == "on" { - meta.EmitThinking = true - return - } - meta.EmitThinking = true - meta.ReasoningEffort = level + _ = meta + _ = level } func formatThinkingAck(level string) string { diff --git a/pkg/connector/inbound_prompt_runtime_test.go b/pkg/connector/inbound_prompt_runtime_test.go index 29ce5c50..d39b97cb 100644 --- a/pkg/connector/inbound_prompt_runtime_test.go +++ b/pkg/connector/inbound_prompt_runtime_test.go @@ -70,7 +70,7 @@ func TestBuildPromptWithLinkContext_SimpleModeSkipsInboundRuntimeMetadata(t *tes }, }, } - meta := &PortalMetadata{IsSimpleMode: true} + meta := simpleModeTestMeta("openai/gpt-5") ctx := withInboundContext(context.Background(), airuntime.InboundContext{ Provider: "matrix", Surface: "beeper-matrix", diff --git a/pkg/connector/integration_host.go b/pkg/connector/integration_host.go index e5f09040..a41da4f5 100644 --- a/pkg/connector/integration_host.go +++ b/pkg/connector/integration_host.go @@ -299,18 +299,6 @@ func (h *runtimeIntegrationHost) SetMetaField(meta any, key string, value any) { return } switch key { - case "AgentID": - if v, ok := value.(string); ok { - m.AgentID = v - } - case "Model": - if v, ok := value.(string); ok { - m.Model = strings.TrimSpace(v) - } - case "ReasoningEffort": - if v, ok := value.(string); ok { - m.ReasoningEffort = strings.TrimSpace(v) - } case "DisabledTools": if v, ok := value.([]string); ok { m.DisabledTools = v diff --git a/pkg/connector/integrations.go b/pkg/connector/integrations.go index 20657b63..8b47cff2 100644 --- a/pkg/connector/integrations.go +++ b/pkg/connector/integrations.go @@ -664,9 +664,6 @@ func integrationSessionKind(currentRoomID string, portalRoomID string, meta *Por if strings.TrimSpace(meta.SubagentParentRoomID) != "" { return "other" } - if meta.IsBuilderRoom { - return "other" - } } return "group" } diff --git a/pkg/connector/metadata.go b/pkg/connector/metadata.go index 5aee49c5..9ffa0bdb 100644 --- a/pkg/connector/metadata.go +++ b/pkg/connector/metadata.go @@ -46,7 +46,15 @@ type FileAnnotation struct { CreatedAt int64 `json:"created_at"` // Unix timestamp when cached } -// UserDefaults stores user-level default settings for new chats +type UserProfile struct { + Name string `json:"name,omitempty"` + Occupation string `json:"occupation,omitempty"` + AboutUser string `json:"about_user,omitempty"` + CustomInstructions string `json:"custom_instructions,omitempty"` +} + +// Legacy-only storage type kept so old JSON can still unmarshal during the hard cut. +// New code must not read this. type UserDefaults struct { Model string `json:"model,omitempty"` SystemPrompt string `json:"system_prompt,omitempty"` @@ -110,7 +118,6 @@ type BuiltinAlwaysAllowRule struct { // UserLoginMetadata is stored on each login row to keep per-user settings. type UserLoginMetadata struct { - Persona string `json:"persona,omitempty"` Provider string `json:"provider,omitempty"` // Selected provider (beeper, openai, openrouter) APIKey string `json:"api_key,omitempty"` BaseURL string `json:"base_url,omitempty"` // Per-user API endpoint @@ -121,13 +128,14 @@ type UserLoginMetadata struct { ChatsSynced bool `json:"chats_synced,omitempty"` // True after initial bootstrap completed successfully Gravatar *GravatarState `json:"gravatar,omitempty"` Timezone string `json:"timezone,omitempty"` - ResponsePrefix string `json:"response_prefix,omitempty"` + Profile *UserProfile `json:"profile,omitempty"` + ResponsePrefix string `json:"response_prefix,omitempty"` // Legacy-only. New code must not read this. // FileAnnotationCache stores parsed PDF content from OpenRouter's file-parser plugin // Key is the file hash (SHA256), pruned after 7 days FileAnnotationCache map[string]FileAnnotation `json:"file_annotation_cache,omitempty"` - // User-level defaults for new chats (set via provisioning API) + // Legacy-only. New code must not read this. Defaults *UserDefaults `json:"defaults,omitempty"` // Optional per-login tokens for external services @@ -136,13 +144,10 @@ type UserLoginMetadata struct { // Tool approval rules (e.g. "always allow" decisions for MCP approvals or dangerous builtin tools). ToolApprovals *ToolApprovalsConfig `json:"tool_approvals,omitempty"` - // AgentModelOverrides stores per-agent model overrides (agent ID -> model ID). - AgentModelOverrides map[string]string `json:"agent_model_overrides,omitempty"` - - // Agent Builder room for managing agents - BuilderRoomID networkid.PortalID `json:"builder_room_id,omitempty"` // Custom agents store (source of truth for user-created agents). CustomAgents map[string]*AgentDefinitionContent `json:"custom_agents,omitempty"` + // Legacy-only. New code must not read this. + BuilderRoomID networkid.PortalID `json:"builder_room_id,omitempty"` // Last active room per agent (used for heartbeat delivery). LastActiveRoomByAgent map[string]string `json:"last_active_room_by_agent,omitempty"` // Heartbeat dedupe state per agent. @@ -176,51 +181,54 @@ type GravatarState struct { // PortalMetadata stores per-room tuning knobs for the assistant. type PortalMetadata struct { - Model string `json:"model,omitempty"` // Set from room state - SystemPrompt string `json:"system_prompt,omitempty"` // Set from room state - ResponsePrefix string `json:"response_prefix,omitempty"` // Per-room response prefix override - Temperature float64 `json:"temperature,omitempty"` // Set from room state - MaxContextMessages int `json:"max_context_messages,omitempty"` // Set from room state - MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // Set from room state - ReasoningEffort string `json:"reasoning_effort,omitempty"` // none, low, medium, high, xhigh - Slug string `json:"slug,omitempty"` - Title string `json:"title,omitempty"` - TitleGenerated bool `json:"title_generated,omitempty"` // True if title was auto-generated - WelcomeSent bool `json:"welcome_sent,omitempty"` - AutoGreetingSent bool `json:"auto_greeting_sent,omitempty"` - Capabilities ModelCapabilities `json:"capabilities,omitempty"` - LastRoomStateSync int64 `json:"last_room_state_sync,omitempty"` // Track when we've synced room state - PDFConfig *PDFConfig `json:"pdf_config,omitempty"` // Per-room PDF processing configuration - - EmitThinking bool `json:"emit_thinking,omitempty"` - EmitToolArgs bool `json:"emit_tool_args,omitempty"` - ThinkingLevel string `json:"thinking_level,omitempty"` // off|minimal|low|medium|high|xhigh - VerboseLevel string `json:"verbose_level,omitempty"` // off|on|full - ElevatedLevel string `json:"elevated_level,omitempty"` // off|on|ask|full - GroupActivation string `json:"group_activation,omitempty"` // mention|always - GroupActivationNeedsIntro bool `json:"group_activation_needs_intro,omitempty"` - GroupIntroSent bool `json:"group_intro_sent,omitempty"` - SendPolicy string `json:"send_policy,omitempty"` // allow|deny - SessionResetAt int64 `json:"session_reset_at,omitempty"` - AbortedLastRun bool `json:"aborted_last_run,omitempty"` - CompactionCount int `json:"compaction_count,omitempty"` - SessionBootstrappedAt int64 `json:"session_bootstrapped_at,omitempty"` - SessionBootstrapByAgent map[string]int64 `json:"session_bootstrap_by_agent,omitempty"` - - // Agent-related metadata - AgentID string `json:"agent_id,omitempty"` // Which agent is the ghost for this room - AgentPrompt string `json:"agent_prompt,omitempty"` // Cached prompt for the assigned agent - IsBuilderRoom bool `json:"is_builder_room,omitempty"` // True if this is the Manage AI Chats room (protected from overrides) - IsSimpleMode bool `json:"is_simple_mode,omitempty"` // True if this is a simple mode room (no directive processing) + // Legacy-only selector/tuning fields kept to allow old JSON to unmarshal during + // the hard cut. New code must not read these. + Model string `json:"model,omitempty"` + SystemPrompt string `json:"system_prompt,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + MaxContextMessages int `json:"max_context_messages,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + Capabilities ModelCapabilities `json:"capabilities,omitempty"` + PDFConfig *PDFConfig `json:"pdf_config,omitempty"` + EmitThinking bool `json:"emit_thinking,omitempty"` + EmitToolArgs bool `json:"emit_tool_args,omitempty"` + ThinkingLevel string `json:"thinking_level,omitempty"` + VerboseLevel string `json:"verbose_level,omitempty"` + ElevatedLevel string `json:"elevated_level,omitempty"` + GroupActivation string `json:"group_activation,omitempty"` + GroupActivationNeedsIntro bool `json:"group_activation_needs_intro,omitempty"` + GroupIntroSent bool `json:"group_intro_sent,omitempty"` + SendPolicy string `json:"send_policy,omitempty"` + AgentID string `json:"agent_id,omitempty"` + AgentPrompt string `json:"agent_prompt,omitempty"` + IsBuilderRoom bool `json:"is_builder_room,omitempty"` + IsSimpleMode bool `json:"is_simple_mode,omitempty"` + AckReactionEmoji string `json:"ack_reaction_emoji,omitempty"` + AckReactionRemoveAfter bool `json:"ack_reaction_remove_after,omitempty"` + + Slug string `json:"slug,omitempty"` + Title string `json:"title,omitempty"` + TitleGenerated bool `json:"title_generated,omitempty"` // True if title was auto-generated + WelcomeSent bool `json:"welcome_sent,omitempty"` + AutoGreetingSent bool `json:"auto_greeting_sent,omitempty"` + + SessionResetAt int64 `json:"session_reset_at,omitempty"` + AbortedLastRun bool `json:"aborted_last_run,omitempty"` + CompactionCount int `json:"compaction_count,omitempty"` + SessionBootstrappedAt int64 `json:"session_bootstrapped_at,omitempty"` + SessionBootstrapByAgent map[string]int64 `json:"session_bootstrap_by_agent,omitempty"` + ModuleMeta map[string]any `json:"module_meta,omitempty"` // Generic per-module metadata (e.g., cron room markers, memory flush state) SubagentParentRoomID string `json:"subagent_parent_room_id,omitempty"` // Parent room ID for subagent sessions - // Ack reaction config - similar to OpenClaw's ack reactions - AckReactionEmoji string `json:"ack_reaction_emoji,omitempty"` // Emoji to react with when message received (e.g., "👀", "🤔"). Empty = disabled. - AckReactionRemoveAfter bool `json:"ack_reaction_remove_after,omitempty"` // Remove the ack reaction after replying - // Runtime-only overrides (not persisted) - DisabledTools []string `json:"-"` + DisabledTools []string `json:"-"` + ResolvedTarget *ResolvedTarget `json:"-"` + RuntimeModelOverride string `json:"-"` + + // Legacy-only. New code must not read this. + ResponsePrefix string `json:"response_prefix,omitempty"` // Debounce configuration (0 = use default, -1 = disabled) DebounceMs int `json:"debounce_ms,omitempty"` @@ -231,10 +239,8 @@ type PortalMetadata struct { } -// isSimpleMode reports whether the portal is in simple mode -// (no directive processing, minimal agent chrome). func isSimpleMode(meta *PortalMetadata) bool { - return meta != nil && meta.IsSimpleMode + return meta != nil && meta.ResolvedTarget != nil && meta.ResolvedTarget.Kind == ResolvedTargetModel } func clonePortalMetadata(src *PortalMetadata) *PortalMetadata { @@ -256,6 +262,7 @@ func clonePortalMetadata(src *PortalMetadata) *PortalMetadata { if len(src.DisabledTools) > 0 { clone.DisabledTools = slices.Clone(src.DisabledTools) } + clone.ResolvedTarget = src.ResolvedTarget if src.ModuleMeta != nil { clone.ModuleMeta = make(map[string]any, len(src.ModuleMeta)) @@ -263,6 +270,10 @@ func clonePortalMetadata(src *PortalMetadata) *PortalMetadata { clone.ModuleMeta[k] = jsonutil.DeepCloneAny(v) } } + if src.ResolvedTarget != nil { + target := *src.ResolvedTarget + clone.ResolvedTarget = &target + } return &clone } @@ -272,20 +283,26 @@ func clonePortalMetadata(src *PortalMetadata) *PortalMetadata { type MessageMetadata struct { bridgeadapter.BaseMessageMetadata - CompletionID string `json:"completion_id,omitempty"` - Model string `json:"model,omitempty"` - HasToolCalls bool `json:"has_tool_calls,omitempty"` - Transcript string `json:"transcript,omitempty"` - FirstTokenAtMs int64 `json:"first_token_at_ms,omitempty"` - ThinkingTokenCount int `json:"thinking_token_count,omitempty"` - ExcludeFromHistory bool `json:"exclude_from_history,omitempty"` + CompletionID string `json:"completion_id,omitempty"` + Model string `json:"model,omitempty"` + HasToolCalls bool `json:"has_tool_calls,omitempty"` + Transcript string `json:"transcript,omitempty"` // Media understanding (OpenClaw-style) MediaUnderstanding []MediaUnderstandingOutput `json:"media_understanding,omitempty"` MediaUnderstandingDecisions []MediaUnderstandingDecision `json:"media_understanding_decisions,omitempty"` + // Timing information + FirstTokenAtMs int64 `json:"first_token_at_ms,omitempty"` // Unix ms of first token + + // Thinking/reasoning content (embedded, not separate) + ThinkingTokenCount int `json:"thinking_token_count,omitempty"` // Number of thinking tokens + + // History exclusion + ExcludeFromHistory bool `json:"exclude_from_history,omitempty"` // Exclude from LLM context (e.g., welcome messages) + // Multimodal history: media attached to this message for re-injection into prompts. - MediaURL string `json:"media_url,omitempty"` // mxc:// URL for user-sent media + MediaURL string `json:"media_url,omitempty"` // mxc:// URL for user-sent media (image, PDF, audio, video) MimeType string `json:"mime_type,omitempty"` // MIME type of user-sent media } @@ -304,28 +321,65 @@ func (mm *MessageMetadata) CopyFrom(other any) { if !ok || src == nil { return } - mm.CopyFromBase(&src.BaseMessageMetadata) + if src.Role != "" { + mm.Role = src.Role + } + if src.Body != "" { + mm.Body = src.Body + } if src.CompletionID != "" { mm.CompletionID = src.CompletionID } + if src.FinishReason != "" { + mm.FinishReason = src.FinishReason + } + if src.PromptTokens != 0 { + mm.PromptTokens = src.PromptTokens + } + if src.CompletionTokens != 0 { + mm.CompletionTokens = src.CompletionTokens + } if src.Model != "" { mm.Model = src.Model } + if src.ReasoningTokens != 0 { + mm.ReasoningTokens = src.ReasoningTokens + } if src.HasToolCalls { mm.HasToolCalls = true } - if src.Transcript != "" { - mm.Transcript = src.Transcript + + // Copy new fields + if src.TurnID != "" { + mm.TurnID = src.TurnID + } + if src.AgentID != "" { + mm.AgentID = src.AgentID + } + if len(src.ToolCalls) > 0 { + mm.ToolCalls = src.ToolCalls + } + if src.CanonicalSchema != "" { + mm.CanonicalSchema = src.CanonicalSchema + } + if len(src.CanonicalUIMessage) > 0 { + mm.CanonicalUIMessage = src.CanonicalUIMessage + } + if src.StartedAtMs != 0 { + mm.StartedAtMs = src.StartedAtMs } if src.FirstTokenAtMs != 0 { mm.FirstTokenAtMs = src.FirstTokenAtMs } + if src.CompletedAtMs != 0 { + mm.CompletedAtMs = src.CompletedAtMs + } + if src.ThinkingContent != "" { + mm.ThinkingContent = src.ThinkingContent + } if src.ThinkingTokenCount != 0 { mm.ThinkingTokenCount = src.ThinkingTokenCount } - if src.ExcludeFromHistory { - mm.ExcludeFromHistory = true - } } var _ database.MetaMerger = (*MessageMetadata)(nil) diff --git a/pkg/connector/model_fallback.go b/pkg/connector/model_fallback.go index be93215c..47aa9483 100644 --- a/pkg/connector/model_fallback.go +++ b/pkg/connector/model_fallback.go @@ -28,18 +28,9 @@ func (e *NonFallbackError) Unwrap() error { } // modelFallbackChain returns the model chain to try in order. -// Room-level overrides take priority and disable fallbacks. +// Agent-defined fallbacks are used for agent rooms; model rooms only use their selected model. func (oc *AIClient) modelFallbackChain(ctx context.Context, meta *PortalMetadata) []string { - // Explicit room-level model overrides should not fall back. - if meta != nil && strings.TrimSpace(meta.Model) != "" { - return dedupeModels([]string{ResolveAlias(meta.Model)}) - } - - agentID := "" - if meta != nil { - agentID = meta.AgentID - } - + agentID := resolveAgentID(meta) if agentID != "" { store := NewAgentStoreAdapter(oc) agent, err := store.GetAgentByID(ctx, agentID) @@ -72,8 +63,7 @@ func (oc *AIClient) overrideModel(meta *PortalMetadata, modelID string) *PortalM return nil } metaCopy := *meta - metaCopy.Model = modelID - metaCopy.Capabilities = getModelCapabilities(modelID, oc.findModelInfo(modelID)) + metaCopy.RuntimeModelOverride = ResolveAlias(modelID) return &metaCopy } diff --git a/pkg/connector/prompt_params.go b/pkg/connector/prompt_params.go index 32152b3f..fb4e9472 100644 --- a/pkg/connector/prompt_params.go +++ b/pkg/connector/prompt_params.go @@ -5,8 +5,5 @@ func resolvePromptWorkspaceDir() string { } func resolvePromptReasoningLevel(meta *PortalMetadata) string { - if meta != nil && meta.EmitThinking { - return "on" - } return "" } diff --git a/pkg/connector/provisioning.go b/pkg/connector/provisioning.go index 25f78a7e..31f0ad32 100644 --- a/pkg/connector/provisioning.go +++ b/pkg/connector/provisioning.go @@ -1,23 +1,34 @@ package connector import ( + "context" "encoding/json" + "errors" + "fmt" + "io" "net/http" + "slices" + "strings" + "time" + "github.com/google/uuid" "github.com/rs/zerolog" "go.mau.fi/util/exhttp" "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" + + "github.com/beeper/ai-bridge/pkg/agents" + "github.com/beeper/ai-bridge/pkg/agents/toolpolicy" ) -// ProvisioningAPI handles the provisioning endpoints for user defaults +// ProvisioningAPI handles login-scoped profile, agent, and MCP configuration. type ProvisioningAPI struct { log zerolog.Logger connector *OpenAIConnector prov bridgev2.IProvisioningAPI } -// initProvisioning sets up the provisioning API endpoints +// initProvisioning sets up the provisioning API endpoints. func (oc *OpenAIConnector) initProvisioning() { c, ok := oc.br.Matrix.(bridgev2.MatrixConnectorWithProvisioning) if !ok { @@ -36,13 +47,24 @@ func (oc *OpenAIConnector) initProvisioning() { } r.HandleFunc("GET /v1/models", api.handleListModels) - r.HandleFunc("GET /v1/defaults", api.handleGetDefaults) - r.HandleFunc("PUT /v1/defaults", api.handleSetDefaults) + r.HandleFunc("GET /v1/profile", api.handleGetProfile) + r.HandleFunc("PUT /v1/profile", api.handlePutProfile) + r.HandleFunc("GET /v1/agents", api.handleListAgents) + r.HandleFunc("POST /v1/agents", api.handleCreateAgent) + r.HandleFunc("GET /v1/agents/{agent_id}", api.handleGetAgent) + r.HandleFunc("PUT /v1/agents/{agent_id}", api.handleUpdateAgent) + r.HandleFunc("DELETE /v1/agents/{agent_id}", api.handleDeleteAgent) + r.HandleFunc("GET /v1/mcp/servers", api.handleListMCPServers) + r.HandleFunc("POST /v1/mcp/servers", api.handleCreateMCPServer) + r.HandleFunc("PUT /v1/mcp/servers/{name}", api.handleUpdateMCPServer) + r.HandleFunc("DELETE /v1/mcp/servers/{name}", api.handleDeleteMCPServer) + r.HandleFunc("POST /v1/mcp/servers/{name}/connect", api.handleConnectMCPServer) + r.HandleFunc("POST /v1/mcp/servers/{name}/disconnect", api.handleDisconnectMCPServer) - oc.br.Log.Info().Msg("Registered provisioning API endpoints for user defaults") + oc.br.Log.Info().Msg("Registered provisioning API endpoints for AI profile, agents, and MCP") } -// getLogin gets the user login from the request +// getLogin gets the preferred user login from the request. func (api *ProvisioningAPI) getLogin(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { user := api.prov.GetUser(r) login := api.connector.getPreferredUserLogin(r.Context(), user) @@ -53,13 +75,25 @@ func (api *ProvisioningAPI) getLogin(w http.ResponseWriter, r *http.Request) *br return login } -// handleListModels handles GET /v1/models -func (api *ProvisioningAPI) handleListModels(w http.ResponseWriter, r *http.Request) { +func (api *ProvisioningAPI) getClient(w http.ResponseWriter, r *http.Request) (*bridgev2.UserLogin, *AIClient) { login := api.getLogin(w, r) if login == nil { + return nil, nil + } + client, ok := login.Client.(*AIClient) + if !ok || client == nil { + mautrix.MUnknown.WithMessage("Invalid AI client for login.").Write(w) + return nil, nil + } + return login, client +} + +// handleListModels handles GET /v1/models. +func (api *ProvisioningAPI) handleListModels(w http.ResponseWriter, r *http.Request) { + _, client := api.getClient(w, r) + if client == nil { return } - client := login.Client.(*AIClient) models, err := client.listAvailableModels(r.Context(), false) if err != nil { mautrix.MUnknown.WithMessage("Couldn't list models: %v.", err).Write(w) @@ -68,91 +102,655 @@ func (api *ProvisioningAPI) handleListModels(w http.ResponseWriter, r *http.Requ exhttp.WriteJSONResponse(w, http.StatusOK, map[string]any{"models": models}) } -// handleGetDefaults handles GET /v1/defaults -func (api *ProvisioningAPI) handleGetDefaults(w http.ResponseWriter, r *http.Request) { +type profilePayload struct { + Name *string `json:"name,omitempty"` + Occupation *string `json:"occupation,omitempty"` + AboutUser *string `json:"about_user,omitempty"` + CustomInstructions *string `json:"custom_instructions,omitempty"` + Timezone *string `json:"timezone,omitempty"` +} + +type profileResponse struct { + Name string `json:"name,omitempty"` + Occupation string `json:"occupation,omitempty"` + AboutUser string `json:"about_user,omitempty"` + CustomInstructions string `json:"custom_instructions,omitempty"` + Timezone string `json:"timezone,omitempty"` +} + +func profileResponseFromMeta(meta *UserLoginMetadata) profileResponse { + var resp profileResponse + if meta == nil { + return resp + } + if meta.Profile != nil { + resp.Name = meta.Profile.Name + resp.Occupation = meta.Profile.Occupation + resp.AboutUser = meta.Profile.AboutUser + resp.CustomInstructions = meta.Profile.CustomInstructions + } + resp.Timezone = meta.Timezone + return resp +} + +func applyProfilePayload(meta *UserLoginMetadata, payload profilePayload) error { + if meta == nil { + return errors.New("missing metadata") + } + if payload.Name != nil || payload.Occupation != nil || payload.AboutUser != nil || payload.CustomInstructions != nil { + if meta.Profile == nil { + meta.Profile = &UserProfile{} + } + if payload.Name != nil { + meta.Profile.Name = strings.TrimSpace(*payload.Name) + } + if payload.Occupation != nil { + meta.Profile.Occupation = strings.TrimSpace(*payload.Occupation) + } + if payload.AboutUser != nil { + meta.Profile.AboutUser = strings.TrimSpace(*payload.AboutUser) + } + if payload.CustomInstructions != nil { + meta.Profile.CustomInstructions = strings.TrimSpace(*payload.CustomInstructions) + } + if meta.Profile.Name == "" && meta.Profile.Occupation == "" && meta.Profile.AboutUser == "" && meta.Profile.CustomInstructions == "" { + meta.Profile = nil + } + } + if payload.Timezone != nil { + tz := strings.TrimSpace(*payload.Timezone) + if tz != "" { + if _, err := time.LoadLocation(tz); err != nil { + return fmt.Errorf("invalid timezone: %w", err) + } + } + meta.Timezone = tz + } + return nil +} + +// handleGetProfile handles GET /v1/profile. +func (api *ProvisioningAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) { login := api.getLogin(w, r) if login == nil { return } + exhttp.WriteJSONResponse(w, http.StatusOK, profileResponseFromMeta(loginMetadata(login))) +} + +// handlePutProfile handles PUT /v1/profile. +func (api *ProvisioningAPI) handlePutProfile(w http.ResponseWriter, r *http.Request) { + login := api.getLogin(w, r) + if login == nil { + return + } + var req profilePayload + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + mautrix.MBadJSON.WithMessage("Invalid JSON: %v.", err).Write(w) + return + } meta := loginMetadata(login) - resp := map[string]any{} - if meta.Defaults != nil { - if meta.Defaults.Model != "" { - resp["model"] = meta.Defaults.Model + if err := applyProfilePayload(meta, req); err != nil { + mautrix.MInvalidParam.WithMessage("%v.", err).Write(w) + return + } + if err := login.Save(r.Context()); err != nil { + mautrix.MUnknown.WithMessage("Couldn't save changes: %v.", err).Write(w) + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, profileResponseFromMeta(meta)) +} + +type agentUpsertRequest struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + AvatarURL string `json:"avatar_url,omitempty"` + Model string `json:"model,omitempty"` + ModelFallback []string `json:"model_fallback,omitempty"` + SystemPrompt string `json:"system_prompt,omitempty"` + PromptMode string `json:"prompt_mode,omitempty"` + Tools *toolpolicy.ToolPolicyConfig `json:"tools,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + IdentityName string `json:"identity_name,omitempty"` + IdentityPersona string `json:"identity_persona,omitempty"` + HeartbeatPrompt string `json:"heartbeat_prompt,omitempty"` + MemorySearch any `json:"memory_search,omitempty"` +} + +func writeAgentError(w http.ResponseWriter, err error) { + switch { + case errors.Is(err, agents.ErrAgentNotFound): + mautrix.MNotFound.WithMessage("Agent not found.").Write(w) + case errors.Is(err, agents.ErrAgentIsPreset): + mautrix.MForbidden.WithMessage("Preset agents can't be modified.").Write(w) + case errors.Is(err, agents.ErrMissingAgentID), errors.Is(err, agents.ErrMissingAgentName): + mautrix.MInvalidParam.WithMessage("%v.", err).Write(w) + default: + mautrix.MUnknown.WithMessage("Couldn't process agent: %v.", err).Write(w) + } +} + +func normalizeAgentUpsertRequest(req agentUpsertRequest, pathID string) (*agents.AgentDefinition, error) { + agentID := strings.TrimSpace(pathID) + if agentID == "" { + agentID = strings.TrimSpace(req.ID) + } + if agentID == "" { + agentID = uuid.NewString() + } + content := &AgentDefinitionContent{ + ID: agentID, + Name: strings.TrimSpace(req.Name), + Description: strings.TrimSpace(req.Description), + AvatarURL: strings.TrimSpace(req.AvatarURL), + Model: strings.TrimSpace(req.Model), + ModelFallback: normalizeStringList(req.ModelFallback), + SystemPrompt: strings.TrimSpace(req.SystemPrompt), + PromptMode: strings.TrimSpace(req.PromptMode), + Temperature: req.Temperature, + ReasoningEffort: strings.TrimSpace(req.ReasoningEffort), + IdentityName: strings.TrimSpace(req.IdentityName), + IdentityPersona: strings.TrimSpace(req.IdentityPersona), + HeartbeatPrompt: strings.TrimSpace(req.HeartbeatPrompt), + MemorySearch: req.MemorySearch, + } + content.Tools = req.Tools + return FromAgentDefinitionContent(content), nil +} + +func normalizeStringList(input []string) []string { + if len(input) == 0 { + return nil + } + out := make([]string, 0, len(input)) + for _, item := range input { + item = strings.TrimSpace(item) + if item == "" { + continue } - if meta.Defaults.SystemPrompt != "" { - resp["system_prompt"] = meta.Defaults.SystemPrompt + out = append(out, item) + } + if len(out) == 0 { + return nil + } + return out +} + +func validateAgentModels(ctx context.Context, client *AIClient, agent *agents.AgentDefinition) error { + if agent == nil || client == nil { + return nil + } + models := []string{} + if strings.TrimSpace(agent.Model.Primary) != "" { + models = append(models, strings.TrimSpace(agent.Model.Primary)) + } + models = append(models, normalizeStringList(agent.Model.Fallbacks)...) + for _, model := range models { + resolved, valid, err := client.resolveModelID(ctx, model) + if err != nil { + return err } - if meta.Defaults.Temperature != nil { - resp["temperature"] = meta.Defaults.Temperature + if !valid || resolved == "" { + return fmt.Errorf("invalid model: %s", model) } - if meta.Defaults.ReasoningEffort != "" { - resp["reasoning_effort"] = meta.Defaults.ReasoningEffort + if model == agent.Model.Primary { + agent.Model.Primary = resolved + continue } } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) + if len(agent.Model.Fallbacks) > 0 { + resolvedFallbacks := make([]string, 0, len(agent.Model.Fallbacks)) + for _, fallback := range normalizeStringList(agent.Model.Fallbacks) { + resolved, valid, err := client.resolveModelID(ctx, fallback) + if err != nil { + return err + } + if !valid || resolved == "" { + return fmt.Errorf("invalid model: %s", fallback) + } + resolvedFallbacks = append(resolvedFallbacks, resolved) + } + agent.Model.Fallbacks = resolvedFallbacks + } + return nil } -// ReqSetDefaults is the request body for PUT /v1/defaults -type ReqSetDefaults struct { - Model *string `json:"model,omitempty"` - SystemPrompt *string `json:"system_prompt,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - ReasoningEffort *string `json:"reasoning_effort,omitempty"` +func agentResponse(agent *agents.AgentDefinition) *AgentDefinitionContent { + if agent == nil { + return nil + } + return ToAgentDefinitionContent(agent) } -// handleSetDefaults handles PUT /v1/defaults -func (api *ProvisioningAPI) handleSetDefaults(w http.ResponseWriter, r *http.Request) { - login := api.getLogin(w, r) - if login == nil { +func listAgentsForResponse(ctx context.Context, store *AgentStoreAdapter) ([]*AgentDefinitionContent, error) { + loaded, err := store.LoadAgents(ctx) + if err != nil { + return nil, err + } + ids := make([]string, 0, len(loaded)) + for id := range loaded { + ids = append(ids, id) + } + slices.Sort(ids) + out := make([]*AgentDefinitionContent, 0, len(ids)) + for _, id := range ids { + if agent := loaded[id]; agent != nil { + out = append(out, agentResponse(agent)) + } + } + return out, nil +} + +func (api *ProvisioningAPI) handleListAgents(w http.ResponseWriter, r *http.Request) { + _, client := api.getClient(w, r) + if client == nil { + return + } + items, err := listAgentsForResponse(r.Context(), NewAgentStoreAdapter(client)) + if err != nil { + mautrix.MUnknown.WithMessage("Couldn't list agents: %v.", err).Write(w) return } - var req ReqSetDefaults + exhttp.WriteJSONResponse(w, http.StatusOK, map[string]any{"agents": items}) +} + +func (api *ProvisioningAPI) handleGetAgent(w http.ResponseWriter, r *http.Request) { + _, client := api.getClient(w, r) + if client == nil { + return + } + agentID := strings.TrimSpace(r.PathValue("agent_id")) + agent, err := NewAgentStoreAdapter(client).GetAgentByID(r.Context(), agentID) + if err != nil { + writeAgentError(w, err) + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, agentResponse(agent)) +} + +func (api *ProvisioningAPI) handleCreateAgent(w http.ResponseWriter, r *http.Request) { + _, client := api.getClient(w, r) + if client == nil { + return + } + var req agentUpsertRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { mautrix.MBadJSON.WithMessage("Invalid JSON: %v.", err).Write(w) return } + agent, err := normalizeAgentUpsertRequest(req, "") + if err != nil { + mautrix.MBadJSON.WithMessage("Invalid agent payload: %v.", err).Write(w) + return + } + if err = validateAgentModels(r.Context(), client, agent); err != nil { + mautrix.MInvalidParam.WithMessage("%v.", err).Write(w) + return + } + store := NewAgentStoreAdapter(client) + if existing, err := store.GetAgentByID(r.Context(), agent.ID); err == nil && existing != nil { + mautrix.MInvalidParam.WithMessage("Agent %s already exists.", agent.ID).Write(w) + return + } + if err = store.SaveAgent(r.Context(), agent); err != nil { + writeAgentError(w, err) + return + } + exhttp.WriteJSONResponse(w, http.StatusCreated, agentResponse(agent)) +} +func (api *ProvisioningAPI) handleUpdateAgent(w http.ResponseWriter, r *http.Request) { + _, client := api.getClient(w, r) + if client == nil { + return + } + var req agentUpsertRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + mautrix.MBadJSON.WithMessage("Invalid JSON: %v.", err).Write(w) + return + } + agentID := strings.TrimSpace(r.PathValue("agent_id")) + agent, err := normalizeAgentUpsertRequest(req, agentID) + if err != nil { + mautrix.MBadJSON.WithMessage("Invalid agent payload: %v.", err).Write(w) + return + } + if err = validateAgentModels(r.Context(), client, agent); err != nil { + mautrix.MInvalidParam.WithMessage("%v.", err).Write(w) + return + } + store := NewAgentStoreAdapter(client) + existing, err := store.GetAgentByID(r.Context(), agentID) + if err != nil { + writeAgentError(w, err) + return + } + if existing != nil && existing.IsPreset { + writeAgentError(w, agents.ErrAgentIsPreset) + return + } + if err = store.SaveAgent(r.Context(), agent); err != nil { + writeAgentError(w, err) + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, agentResponse(agent)) +} + +func (api *ProvisioningAPI) handleDeleteAgent(w http.ResponseWriter, r *http.Request) { + _, client := api.getClient(w, r) + if client == nil { + return + } + agentID := strings.TrimSpace(r.PathValue("agent_id")) + if err := NewAgentStoreAdapter(client).DeleteAgent(r.Context(), agentID); err != nil { + writeAgentError(w, err) + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, map[string]any{"deleted": true}) +} + +type mcpServerUpsertRequest struct { + Name string `json:"name,omitempty"` + Transport string `json:"transport,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` + AuthType string `json:"auth_type,omitempty"` + Token string `json:"token,omitempty"` + AuthURL string `json:"auth_url,omitempty"` + Kind string `json:"kind,omitempty"` +} + +type mcpConnectRequest struct { + Token string `json:"token,omitempty"` +} + +type mcpServerResponse struct { + Name string `json:"name"` + Source string `json:"source,omitempty"` + Transport string `json:"transport,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` + AuthType string `json:"auth_type,omitempty"` + TokenSet bool `json:"token_set,omitempty"` + AuthURL string `json:"auth_url,omitempty"` + Connected bool `json:"connected,omitempty"` + Kind string `json:"kind,omitempty"` +} + +func mcpServerResponseFromNamed(server namedMCPServer) mcpServerResponse { + cfg := normalizeMCPServerConfig(server.Config) + return mcpServerResponse{ + Name: server.Name, + Source: server.Source, + Transport: cfg.Transport, + Endpoint: cfg.Endpoint, + Command: cfg.Command, + Args: slices.Clone(cfg.Args), + AuthType: cfg.AuthType, + TokenSet: cfg.Token != "" || cfg.AuthType == "none", + AuthURL: cfg.AuthURL, + Connected: cfg.Connected, + Kind: cfg.Kind, + } +} + +func normalizeMCPRequest(req mcpServerUpsertRequest, pathName string) (string, MCPServerConfig, error) { + name := "" + if strings.TrimSpace(pathName) != "" { + name = normalizeMCPServerName(pathName) + } + if name == "" { + name = normalizeMCPServerName(req.Name) + } + if name == "" { + return "", MCPServerConfig{}, errors.New("server name is required") + } + cfg := normalizeMCPServerConfig(MCPServerConfig{ + Transport: strings.TrimSpace(req.Transport), + Endpoint: strings.TrimSpace(req.Endpoint), + Command: strings.TrimSpace(req.Command), + Args: normalizeStringList(req.Args), + AuthType: strings.TrimSpace(req.AuthType), + Token: strings.TrimSpace(req.Token), + AuthURL: strings.TrimSpace(req.AuthURL), + Kind: strings.TrimSpace(req.Kind), + Connected: false, + }) + if !mcpServerHasTarget(cfg) { + return "", MCPServerConfig{}, errors.New("mcp server target is required") + } + return name, cfg, nil +} + +func validateMCPConfig(client *AIClient, cfg MCPServerConfig) error { + if mcpServerUsesStdio(cfg) && !client.isMCPStdioEnabled() { + return errors.New("stdio MCP servers are disabled") + } + if cfg.Transport == mcpTransportStreamableHTTP && !isLikelyHTTPURL(cfg.Endpoint) { + return errors.New("invalid MCP endpoint") + } + return nil +} + +func resolveNamedMCPServer(client *AIClient, name string) (namedMCPServer, error) { + target, _, err := resolveMCPServerArg(client, []string{name}) + return target, err +} + +func ensureLoginMCPServer(meta *UserLoginMetadata) { + if meta.ServiceTokens == nil { + meta.ServiceTokens = &ServiceTokens{} + } + if meta.ServiceTokens.MCPServers == nil { + meta.ServiceTokens.MCPServers = map[string]MCPServerConfig{} + } +} + +func (api *ProvisioningAPI) handleListMCPServers(w http.ResponseWriter, r *http.Request) { + _, client := api.getClient(w, r) + if client == nil { + return + } + servers := client.configuredMCPServers() + items := make([]mcpServerResponse, 0, len(servers)) + for _, server := range servers { + items = append(items, mcpServerResponseFromNamed(server)) + } + slices.SortFunc(items, func(a, b mcpServerResponse) int { return strings.Compare(a.Name, b.Name) }) + exhttp.WriteJSONResponse(w, http.StatusOK, map[string]any{"servers": items}) +} + +func (api *ProvisioningAPI) handleCreateMCPServer(w http.ResponseWriter, r *http.Request) { + login, client := api.getClient(w, r) + if client == nil { + return + } + var req mcpServerUpsertRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + mautrix.MBadJSON.WithMessage("Invalid JSON: %v.", err).Write(w) + return + } + name, cfg, err := normalizeMCPRequest(req, "") + if err != nil { + mautrix.MInvalidParam.WithMessage("%v.", err).Write(w) + return + } + if err = validateMCPConfig(client, cfg); err != nil { + mautrix.MInvalidParam.WithMessage("%v.", err).Write(w) + return + } meta := loginMetadata(login) - if meta.Defaults == nil { - meta.Defaults = &UserDefaults{} + ensureLoginMCPServer(meta) + if _, exists := meta.ServiceTokens.MCPServers[name]; exists { + mautrix.MInvalidParam.WithMessage("MCP server %s already exists.", name).Write(w) + return } + setLoginMCPServer(meta, name, cfg) + if err = login.Save(r.Context()); err != nil { + mautrix.MUnknown.WithMessage("Couldn't save MCP server: %v.", err).Write(w) + return + } + client.invalidateMCPToolCache() + exhttp.WriteJSONResponse(w, http.StatusCreated, mcpServerResponseFromNamed(namedMCPServer{Name: name, Config: cfg, Source: "login"})) +} - // Validate and apply model - if req.Model != nil { - client := login.Client.(*AIClient) - if valid, _ := client.validateModel(r.Context(), *req.Model); !valid { - mautrix.MInvalidParam.WithMessage("Invalid model: %s.", *req.Model).Write(w) - return - } - meta.Defaults.Model = *req.Model +func (api *ProvisioningAPI) handleUpdateMCPServer(w http.ResponseWriter, r *http.Request) { + login, client := api.getClient(w, r) + if client == nil { + return + } + var req mcpServerUpsertRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + mautrix.MBadJSON.WithMessage("Invalid JSON: %v.", err).Write(w) + return + } + name := strings.TrimSpace(r.PathValue("name")) + _, err := resolveNamedMCPServer(client, name) + if err != nil && err.Error() != "not found" { + mautrix.MInvalidParam.WithMessage("Couldn't resolve MCP server %s.", name).Write(w) + return + } + resolvedName, cfg, err := normalizeMCPRequest(req, name) + if err != nil { + mautrix.MInvalidParam.WithMessage("%v.", err).Write(w) + return } + if err = validateMCPConfig(client, cfg); err != nil { + mautrix.MInvalidParam.WithMessage("%v.", err).Write(w) + return + } + meta := loginMetadata(login) + setLoginMCPServer(meta, resolvedName, cfg) + if err = login.Save(r.Context()); err != nil { + mautrix.MUnknown.WithMessage("Couldn't save MCP server: %v.", err).Write(w) + return + } + client.invalidateMCPToolCache() + exhttp.WriteJSONResponse(w, http.StatusOK, mcpServerResponseFromNamed(namedMCPServer{Name: resolvedName, Config: cfg, Source: "login"})) +} - // Apply other settings - if req.SystemPrompt != nil { - meta.Defaults.SystemPrompt = *req.SystemPrompt +func (api *ProvisioningAPI) handleDeleteMCPServer(w http.ResponseWriter, r *http.Request) { + login, client := api.getClient(w, r) + if client == nil { + return } - if req.Temperature != nil { - if *req.Temperature < 0 || *req.Temperature > 2 { - mautrix.MInvalidParam.WithMessage("Temperature must be between 0 and 2.").Write(w) - return + name := strings.TrimSpace(r.PathValue("name")) + target, err := resolveNamedMCPServer(client, name) + if err != nil { + mautrix.MNotFound.WithMessage("MCP server not found.").Write(w) + return + } + loginServers := client.loginMCPServers() + if _, ok := loginServers[target.Name]; !ok { + mautrix.MForbidden.WithMessage("Config-managed MCP servers can't be deleted here.").Write(w) + return + } + meta := loginMetadata(login) + clearLoginMCPServer(meta, target.Name) + if err = login.Save(r.Context()); err != nil { + mautrix.MUnknown.WithMessage("Couldn't remove MCP server: %v.", err).Write(w) + return + } + client.invalidateMCPToolCache() + exhttp.WriteJSONResponse(w, http.StatusOK, map[string]any{"deleted": true}) +} + +func connectMCPServer(ctx context.Context, client *AIClient, login *bridgev2.UserLogin, name string, tokenOverride string) (namedMCPServer, int, error) { + target, err := resolveNamedMCPServer(client, name) + if err != nil { + return namedMCPServer{}, 0, err + } + cfg := normalizeMCPServerConfig(target.Config) + if tokenOverride != "" && !mcpServerUsesStdio(cfg) { + cfg.Token = strings.TrimSpace(tokenOverride) + if cfg.Token != "" && cfg.AuthType == "none" { + cfg.AuthType = "bearer" + } + } + if !mcpServerHasTarget(cfg) { + return namedMCPServer{}, 0, errors.New("mcp server target is required") + } + if mcpServerNeedsToken(cfg) && cfg.Token == "" { + cfg.Connected = false + setLoginMCPServer(loginMetadata(login), target.Name, cfg) + if err = login.Save(ctx); err != nil { + return namedMCPServer{}, 0, err + } + client.invalidateMCPToolCache() + return namedMCPServer{Name: target.Name, Config: cfg, Source: "login"}, 0, errors.New("mcp server token is required") + } + cfg.Connected = true + count, connectErr := client.verifyMCPServerConnection(ctx, namedMCPServer{Name: target.Name, Config: cfg, Source: "login"}) + if connectErr != nil { + cfg.Connected = false + setLoginMCPServer(loginMetadata(login), target.Name, cfg) + if err = login.Save(ctx); err != nil { + return namedMCPServer{}, 0, err } - meta.Defaults.Temperature = req.Temperature + client.invalidateMCPToolCache() + return namedMCPServer{Name: target.Name, Config: cfg, Source: "login"}, 0, connectErr + } + setLoginMCPServer(loginMetadata(login), target.Name, cfg) + if err = login.Save(ctx); err != nil { + return namedMCPServer{}, 0, err + } + client.invalidateMCPToolCache() + return namedMCPServer{Name: target.Name, Config: cfg, Source: "login"}, count, nil +} + +func (api *ProvisioningAPI) handleConnectMCPServer(w http.ResponseWriter, r *http.Request) { + login, client := api.getClient(w, r) + if client == nil { + return } - if req.ReasoningEffort != nil { - switch *req.ReasoningEffort { - case "", "none", "low", "medium", "high", "xhigh": - meta.Defaults.ReasoningEffort = *req.ReasoningEffort - default: - mautrix.MInvalidParam.WithMessage("reasoning_effort must be one of: none, low, medium, high, xhigh.").Write(w) + var req mcpConnectRequest + if r.Body != nil { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil && !errors.Is(err, io.EOF) { + mautrix.MBadJSON.WithMessage("Invalid JSON: %v.", err).Write(w) return } } - if err := login.Save(r.Context()); err != nil { - mautrix.MUnknown.WithMessage("Couldn't save changes: %v.", err).Write(w) + server, count, err := connectMCPServer(r.Context(), client, login, strings.TrimSpace(r.PathValue("name")), strings.TrimSpace(req.Token)) + if err != nil { + code := http.StatusBadRequest + if mcpCallLikelyAuthError(err) { + code = http.StatusUnauthorized + } else if strings.Contains(err.Error(), "not found") { + code = http.StatusNotFound + } + exhttp.WriteJSONResponse(w, code, map[string]any{ + "error": err.Error(), + "server": mcpServerResponseFromNamed(server), + }) return } + exhttp.WriteJSONResponse(w, http.StatusOK, map[string]any{ + "server": mcpServerResponseFromNamed(server), + "tool_count": count, + }) +} - // Return updated defaults - api.handleGetDefaults(w, r) +func (api *ProvisioningAPI) handleDisconnectMCPServer(w http.ResponseWriter, r *http.Request) { + login, client := api.getClient(w, r) + if client == nil { + return + } + target, err := resolveNamedMCPServer(client, strings.TrimSpace(r.PathValue("name"))) + if err != nil { + mautrix.MNotFound.WithMessage("MCP server not found.").Write(w) + return + } + cfg := normalizeMCPServerConfig(target.Config) + cfg.Connected = false + setLoginMCPServer(loginMetadata(login), target.Name, cfg) + if err = login.Save(r.Context()); err != nil { + mautrix.MUnknown.WithMessage("Couldn't disconnect MCP server: %v.", err).Write(w) + return + } + client.invalidateMCPToolCache() + exhttp.WriteJSONResponse(w, http.StatusOK, mcpServerResponseFromNamed(namedMCPServer{Name: target.Name, Config: cfg, Source: "login"})) } diff --git a/pkg/connector/provisioning_test.go b/pkg/connector/provisioning_test.go new file mode 100644 index 00000000..8ee4075d --- /dev/null +++ b/pkg/connector/provisioning_test.go @@ -0,0 +1,126 @@ +package connector + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/agents/toolpolicy" +) + +func strPtr(v string) *string { + return &v +} + +func TestApplyProfilePayloadSetsAndClearsFields(t *testing.T) { + meta := &UserLoginMetadata{} + err := applyProfilePayload(meta, profilePayload{ + Name: strPtr(" Batuhan "), + Occupation: strPtr(" Product engineer "), + AboutUser: strPtr(" Works on AI tooling "), + CustomInstructions: strPtr(" Be direct "), + Timezone: strPtr("Europe/Amsterdam"), + }) + if err != nil { + t.Fatalf("applyProfilePayload returned error: %v", err) + } + if meta.Profile == nil { + t.Fatalf("expected profile to be initialized") + } + if meta.Profile.Name != "Batuhan" || meta.Profile.Occupation != "Product engineer" || meta.Profile.AboutUser != "Works on AI tooling" || meta.Profile.CustomInstructions != "Be direct" { + t.Fatalf("unexpected profile contents: %+v", meta.Profile) + } + if meta.Timezone != "Europe/Amsterdam" { + t.Fatalf("expected timezone to be stored, got %q", meta.Timezone) + } + + err = applyProfilePayload(meta, profilePayload{ + Name: strPtr(""), + Occupation: strPtr(""), + AboutUser: strPtr(""), + CustomInstructions: strPtr(""), + Timezone: strPtr(""), + }) + if err != nil { + t.Fatalf("applyProfilePayload clear returned error: %v", err) + } + if meta.Profile != nil { + t.Fatalf("expected empty profile to be cleared, got %+v", meta.Profile) + } + if meta.Timezone != "" { + t.Fatalf("expected timezone to be cleared, got %q", meta.Timezone) + } +} + +func TestApplyProfilePayloadRejectsInvalidTimezone(t *testing.T) { + meta := &UserLoginMetadata{} + err := applyProfilePayload(meta, profilePayload{Timezone: strPtr("Mars/Olympus")}) + if err == nil { + t.Fatal("expected invalid timezone error") + } +} + +func TestNormalizeAgentUpsertRequestCreatesDefinition(t *testing.T) { + agent, err := normalizeAgentUpsertRequest(agentUpsertRequest{ + Name: "Helper", + Description: "Useful", + Model: "openai/gpt-5.2", + ModelFallback: []string{" anthropic/claude-sonnet-4.6 ", ""}, + SystemPrompt: "Be useful", + PromptMode: "append", + Tools: &toolpolicy.ToolPolicyConfig{Allow: []string{"web_search"}}, + IdentityName: "Beep", + IdentityPersona: "Helpful assistant", + }, "") + if err != nil { + t.Fatalf("normalizeAgentUpsertRequest returned error: %v", err) + } + if agent == nil { + t.Fatal("expected agent definition") + } + if agent.ID == "" { + t.Fatal("expected generated agent id") + } + if agent.Name != "Helper" { + t.Fatalf("expected name Helper, got %q", agent.Name) + } + if agent.Model.Primary != "openai/gpt-5.2" { + t.Fatalf("expected primary model to be preserved, got %q", agent.Model.Primary) + } + if len(agent.Model.Fallbacks) != 1 || agent.Model.Fallbacks[0] != "anthropic/claude-sonnet-4.6" { + t.Fatalf("unexpected fallback models: %#v", agent.Model.Fallbacks) + } + if agent.Tools == nil || len(agent.Tools.Allow) != 1 || agent.Tools.Allow[0] != "web_search" { + t.Fatalf("expected tools policy to be preserved, got %#v", agent.Tools) + } +} + +func TestNormalizeMCPRequestValidatesAndNormalizes(t *testing.T) { + name, cfg, err := normalizeMCPRequest(mcpServerUpsertRequest{ + Name: " Search ", + Transport: "streamable_http", + Endpoint: "https://example.com/mcp", + AuthType: "bearer", + Token: "secret", + }, "") + if err != nil { + t.Fatalf("normalizeMCPRequest returned error: %v", err) + } + if name != "search" { + t.Fatalf("expected normalized name 'search', got %q", name) + } + if cfg.Transport != mcpTransportStreamableHTTP { + t.Fatalf("expected transport %q, got %q", mcpTransportStreamableHTTP, cfg.Transport) + } + if cfg.Endpoint != "https://example.com/mcp" { + t.Fatalf("expected endpoint to be preserved, got %q", cfg.Endpoint) + } + if cfg.Token != "secret" { + t.Fatalf("expected token to be preserved, got %q", cfg.Token) + } +} + +func TestNormalizeMCPRequestRejectsMissingTarget(t *testing.T) { + _, _, err := normalizeMCPRequest(mcpServerUpsertRequest{Name: "search"}, "") + if err == nil { + t.Fatal("expected missing target error") + } +} diff --git a/pkg/connector/response_finalization_test.go b/pkg/connector/response_finalization_test.go index ed0dfb74..0dcc444f 100644 --- a/pkg/connector/response_finalization_test.go +++ b/pkg/connector/response_finalization_test.go @@ -33,7 +33,7 @@ func TestBuildFinalEditUIMessage_IncludesSourceAndFileParts(t *testing.T) { streamui.ApplyChunk(&state.ui, map[string]any{"type": "text-delta", "id": "text-1", "delta": "hello"}) streamui.ApplyChunk(&state.ui, map[string]any{"type": "text-end", "id": "text-1"}) - ui := oc.buildFinalEditUIMessage(state, &PortalMetadata{Model: "gpt-4o"}, nil) + ui := oc.buildFinalEditUIMessage(state, simpleModeTestMeta("openai/gpt-4o"), nil) if ui == nil { t.Fatalf("expected final edit UI message") } diff --git a/pkg/connector/room_settings_event_content_test.go b/pkg/connector/room_settings_event_content_test.go deleted file mode 100644 index 160e6ffd..00000000 --- a/pkg/connector/room_settings_event_content_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package connector - -import ( - "encoding/json" - "strings" - "testing" -) - -func TestRoomSettingsEventContentUnmarshalAgentID(t *testing.T) { - var content RoomSettingsEventContent - if err := json.Unmarshal([]byte(`{"agent_id":"beeper"}`), &content); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if content.AgentID != "beeper" { - t.Fatalf("expected agent_id to populate AgentID, got %q", content.AgentID) - } -} - -func TestRoomSettingsEventContentMarshalUsesCanonicalAgentID(t *testing.T) { - raw, err := json.Marshal(RoomSettingsEventContent{ - Model: "openai/gpt-5", - AgentID: "beeper", - }) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - encoded := string(raw) - if !strings.Contains(encoded, `"agent_id":"beeper"`) { - t.Fatalf("expected canonical agent_id field, got %s", encoded) - } - if strings.Contains(encoded, "default_agent_id") { - t.Fatalf("did not expect legacy default_agent_id field in %s", encoded) - } -} diff --git a/pkg/connector/scheduler_cron.go b/pkg/connector/scheduler_cron.go index 63d70b17..b5f26878 100644 --- a/pkg/connector/scheduler_cron.go +++ b/pkg/connector/scheduler_cron.go @@ -326,12 +326,12 @@ func (s *schedulerRuntime) executeCronJob(ctx context.Context, record *scheduled if meta == nil { meta = &PortalMetadata{} } - meta.AgentID = normalizedCronAgentID(&record.Job.AgentID) - if model := strings.TrimSpace(record.Job.Payload.Model); model != "" { - meta.Model = model + if portal.OtherUserID == "" { + portal.OtherUserID = agentUserID(normalizedCronAgentID(&record.Job.AgentID)) } - if thinking := strings.TrimSpace(record.Job.Payload.Thinking); thinking != "" { - meta.ReasoningEffort = thinking + meta.ResolvedTarget = resolveTargetFromGhostID(portal.OtherUserID) + if model := strings.TrimSpace(record.Job.Payload.Model); model != "" { + meta.RuntimeModelOverride = ResolveAlias(model) } if record.Job.Delivery != nil && record.Job.Delivery.Mode == integrationcron.DeliveryAnnounce { meta.DisabledTools = appendMissingDisabledTool(meta.DisabledTools, "message") @@ -391,7 +391,7 @@ func (s *schedulerRuntime) resolveCronDeliveryTarget(agentID string, delivery *i return true } meta := portalMeta(portal) - return meta != nil && normalizeAgentID(meta.AgentID) != normalizeAgentID(agentID) + return meta != nil && normalizeAgentID(resolveAgentID(meta)) != normalizeAgentID(agentID) }, LastActiveRoomID: func(agentID string) string { if portal := s.client.lastActivePortal(agentID); portal != nil && portal.MXID != "" { diff --git a/pkg/connector/scheduler_rooms.go b/pkg/connector/scheduler_rooms.go index fb9daa28..ad810d92 100644 --- a/pkg/connector/scheduler_rooms.go +++ b/pkg/connector/scheduler_rooms.go @@ -14,7 +14,6 @@ func (s *schedulerRuntime) ensureCronRoomLocked(ctx context.Context, record *sch } portalID := fmt.Sprintf("cron:%s:%s", normalizeAgentID(record.Job.AgentID), strings.TrimSpace(record.Job.ID)) portal, err := s.getOrCreateScheduledPortal(ctx, portalID, fmt.Sprintf("Cron: %s", strings.TrimSpace(record.Job.Name)), func(meta *PortalMetadata) { - meta.AgentID = normalizeAgentID(record.Job.AgentID) if meta.ModuleMeta == nil { meta.ModuleMeta = make(map[string]any) } @@ -29,6 +28,10 @@ func (s *schedulerRuntime) ensureCronRoomLocked(ctx context.Context, record *sch if err != nil { return err } + portal.OtherUserID = agentUserID(normalizeAgentID(record.Job.AgentID)) + if err := portal.Save(ctx); err != nil { + return err + } record.RoomID = portal.MXID.String() return nil } @@ -39,7 +42,6 @@ func (s *schedulerRuntime) ensureHeartbeatRoomLocked(ctx context.Context, state } portalID := fmt.Sprintf("heartbeat:%s", normalizeAgentID(state.AgentID)) portal, err := s.getOrCreateScheduledPortal(ctx, portalID, fmt.Sprintf("Heartbeat: %s", state.AgentID), func(meta *PortalMetadata) { - meta.AgentID = normalizeAgentID(state.AgentID) if meta.ModuleMeta == nil { meta.ModuleMeta = make(map[string]any) } @@ -54,6 +56,10 @@ func (s *schedulerRuntime) ensureHeartbeatRoomLocked(ctx context.Context, state if err != nil { return err } + portal.OtherUserID = agentUserID(normalizeAgentID(state.AgentID)) + if err := portal.Save(ctx); err != nil { + return err + } state.RoomID = portal.MXID.String() return nil } diff --git a/pkg/connector/session_greeting_test.go b/pkg/connector/session_greeting_test.go index 1a629d1e..032e615f 100644 --- a/pkg/connector/session_greeting_test.go +++ b/pkg/connector/session_greeting_test.go @@ -10,7 +10,7 @@ import ( func TestMaybePrependSessionGreeting(t *testing.T) { ctx := context.Background() - meta := &PortalMetadata{AgentID: "beeper"} + meta := agentModeTestMeta("beeper") prompt := []openai.ChatCompletionMessageParamUnion{} out := maybePrependSessionGreeting(ctx, nil, meta, prompt, zerolog.Nop()) diff --git a/pkg/connector/sessions_tools.go b/pkg/connector/sessions_tools.go index 05c59370..1f70890b 100644 --- a/pkg/connector/sessions_tools.go +++ b/pkg/connector/sessions_tools.go @@ -30,7 +30,7 @@ func shouldExcludeModelVisiblePortal(meta *PortalMetadata) bool { if meta == nil { return false } - if isModuleInternalRoom(meta) || meta.IsBuilderRoom { + if isModuleInternalRoom(meta) { return true } return strings.TrimSpace(meta.SubagentParentRoomID) != "" @@ -125,11 +125,7 @@ func (oc *AIClient) executeSessionsList(ctx context.Context, portal *bridgev2.Po entry["updatedAt"] = updatedAt } if meta != nil { - model := meta.Model - if strings.TrimSpace(model) == "" { - model = oc.effectiveModel(meta) - } - if model != "" { + if model := oc.effectiveModel(meta); model != "" { entry["model"] = model } } diff --git a/pkg/connector/sessions_visibility_test.go b/pkg/connector/sessions_visibility_test.go index 16da5d17..2e7e1934 100644 --- a/pkg/connector/sessions_visibility_test.go +++ b/pkg/connector/sessions_visibility_test.go @@ -12,7 +12,6 @@ func TestShouldExcludeModelVisiblePortal(t *testing.T) { meta PortalMetadata }{ {name: "cron", meta: PortalMetadata{ModuleMeta: map[string]any{"cron": map[string]any{"is_internal_room": true}}}}, - {name: "builder", meta: PortalMetadata{IsBuilderRoom: true}}, {name: "subagent", meta: PortalMetadata{SubagentParentRoomID: "!parent:example.com"}}, } for _, tc := range cases { diff --git a/pkg/connector/simple_mode_prompt.go b/pkg/connector/simple_mode_prompt.go index e16a64d9..9de24340 100644 --- a/pkg/connector/simple_mode_prompt.go +++ b/pkg/connector/simple_mode_prompt.go @@ -16,19 +16,14 @@ import ( // Simple mode uses a single system prompt with only the current time appended. func (oc *AIClient) buildSimpleModeSystemPrompt(meta *PortalMetadata) string { base := defaultSimpleModeSystemPrompt - if meta != nil { - if v := strings.TrimSpace(meta.SystemPrompt); v != "" { - base = v - } - } - timezone, _ := oc.resolveUserTimezone() now := formatCurrentTimeForPrompt(timezone) - lines := []string{ - strings.TrimSpace(base), - "Current time: " + now, + lines := []string{strings.TrimSpace(base)} + if supplement := strings.TrimSpace(oc.profilePromptSupplement()); supplement != "" { + lines = append(lines, supplement) } + lines = append(lines, "Current time: "+now) return strings.TrimSpace(strings.Join(lines, "\n")) } diff --git a/pkg/connector/simple_mode_prompt_test.go b/pkg/connector/simple_mode_prompt_test.go index dd29db0c..03b55c1d 100644 --- a/pkg/connector/simple_mode_prompt_test.go +++ b/pkg/connector/simple_mode_prompt_test.go @@ -19,8 +19,11 @@ func TestSimpleModePrompt_HasSingleSystemPromptWithTimeAndWebSearch(t *testing.T } meta := &PortalMetadata{ - IsSimpleMode: true, - // No SystemPrompt override: should use defaultSimpleModeSystemPrompt. + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: modelUserID("openai/gpt-5.2"), + ModelID: "openai/gpt-5.2", + }, } out, err := client.buildPromptWithLinkContext(context.Background(), nil, meta, "hello", nil, "") @@ -69,9 +72,10 @@ func TestSimpleModePrompt_NoWebSearchHintEvenWhenConfigured(t *testing.T) { } meta := &PortalMetadata{ - IsSimpleMode: true, - Capabilities: ModelCapabilities{ - SupportsToolCalling: true, + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: modelUserID("openai/gpt-5.2"), + ModelID: "openai/gpt-5.2", }, } @@ -112,7 +116,7 @@ func TestSimpleModePrompt_LatestUserMessageUnchanged_NoLinkContext_NoMessageID(t }, } - meta := &PortalMetadata{IsSimpleMode: true} + meta := &PortalMetadata{ResolvedTarget: &ResolvedTarget{Kind: ResolvedTargetModel, GhostID: modelUserID("openai/gpt-5.2"), ModelID: "openai/gpt-5.2"}} latest := "check this: https://example.com" out, err := client.buildPromptWithLinkContext(context.Background(), nil, meta, latest, nil, "$evt") @@ -139,7 +143,7 @@ func TestSimpleModePrompt_LatestUserMessageUnchanged_NoLinkContext_NoMessageID(t func TestBuildMatrixInboundBody_SimpleModeBypassesEnvelopeAndSenderMeta(t *testing.T) { client := &AIClient{} - meta := &PortalMetadata{IsSimpleMode: true} + meta := &PortalMetadata{ResolvedTarget: &ResolvedTarget{Kind: ResolvedTargetModel, GhostID: modelUserID("openai/gpt-5.2"), ModelID: "openai/gpt-5.2"}} got := client.buildMatrixInboundBody(context.Background(), nil, meta, nil, " hi ", "Alice", "Room", true) if got != "hi" { diff --git a/pkg/connector/status_text.go b/pkg/connector/status_text.go index d2a5e09f..09153dac 100644 --- a/pkg/connector/status_text.go +++ b/pkg/connector/status_text.go @@ -81,40 +81,14 @@ func (oc *AIClient) buildStatusText( sb.WriteString(fmt.Sprintf("Group activation: %s\n", activation)) } - thinking := oc.defaultThinkLevel(meta) - reasoning := strings.TrimSpace(meta.ReasoningEffort) - if reasoning == "" { - if meta.EmitThinking { - reasoning = "on" - } else { - reasoning = "off" - } - } - verbose := strings.TrimSpace(meta.VerboseLevel) - if verbose == "" { - verbose = "off" - } - elevated := strings.TrimSpace(meta.ElevatedLevel) - if elevated == "" { - elevated = "off" - } - sendPolicy := normalizeSendPolicyMode(meta.SendPolicy) - if sendPolicy == "" { - sendPolicy = "allow" - } - sendLabel := "on" - if sendPolicy == "deny" { - sendLabel = "off" - } - responseMode := string(oc.getAgentResponseMode(meta)) + caps := oc.getRoomCapabilities(ctx, meta) sb.WriteString(fmt.Sprintf( - "Options: think=%s reasoning=%s verbose=%s elevated=%s send=%s response=%s\n", - thinking, - reasoning, - verbose, - elevated, - sendLabel, - responseMode, + "Features: tools=%t vision=%t audio=%t video=%t pdf=%t\n", + caps.SupportsToolCalling, + caps.SupportsVision, + caps.SupportsAudio, + caps.SupportsVideo, + caps.SupportsPDF, )) queueDepth := 0 @@ -391,7 +365,7 @@ func (oc *AIClient) buildToolsStatusText(meta *PortalMetadata) string { sb.WriteString(fmt.Sprintf(" [%s] %s: %s%s\n", status, tool.Name, desc, reason)) } - if meta != nil && !meta.Capabilities.SupportsToolCalling { + if meta != nil && !oc.getModelCapabilitiesForMeta(meta).SupportsToolCalling { sb.WriteString(fmt.Sprintf("\nNote: Current model (%s) may not support tool calling.\n", oc.effectiveModel(meta))) } diff --git a/pkg/connector/streaming_chat_completions.go b/pkg/connector/streaming_chat_completions.go index e2b1f534..752d3fd1 100644 --- a/pkg/connector/streaming_chat_completions.go +++ b/pkg/connector/streaming_chat_completions.go @@ -72,7 +72,7 @@ func (oc *AIClient) streamChatCompletions( if len(enabledTools) > 0 { params.Tools = append(params.Tools, ToOpenAIChatTools(enabledTools, &oc.log)...) } - if meta.Capabilities.SupportsToolCalling && chatHasAgent { + if oc.getModelCapabilitiesForMeta(meta).SupportsToolCalling && chatHasAgent { if !oc.isBuilderRoom(portal) { var enabledSessions []*tools.Tool for _, tool := range tools.SessionTools() { diff --git a/pkg/connector/streaming_continuation.go b/pkg/connector/streaming_continuation.go index e81eb0cc..d43813cf 100644 --- a/pkg/connector/streaming_continuation.go +++ b/pkg/connector/streaming_continuation.go @@ -93,7 +93,7 @@ func (oc *AIClient) buildContinuationParams( } // Add session tools for non-boss agent rooms (needed for multi-turn tool use) - if meta.Capabilities.SupportsToolCalling && agentID != "" && !(hasBossAgent(meta) || agents.IsBossAgent(agentID)) { + if oc.getModelCapabilitiesForMeta(meta).SupportsToolCalling && agentID != "" && !(hasBossAgent(meta) || agents.IsBossAgent(agentID)) { var enabledSessions []*tools.Tool for _, tool := range tools.SessionTools() { if oc.isToolEnabled(meta, tool.Name) { diff --git a/pkg/connector/streaming_init_test.go b/pkg/connector/streaming_init_test.go index 4d6a72cd..7e4cd76e 100644 --- a/pkg/connector/streaming_init_test.go +++ b/pkg/connector/streaming_init_test.go @@ -13,8 +13,12 @@ import ( func TestPrepareStreamingRun_SimpleModeClearsReplyTarget(t *testing.T) { oc := &AIClient{} meta := &PortalMetadata{ - IsSimpleMode: true, - SendPolicy: "deny", + SendPolicy: "deny", + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: modelUserID("openai/gpt-5.2"), + ModelID: "openai/gpt-5.2", + }, } evt := &event.Event{ ID: id.EventID("$evt"), diff --git a/pkg/connector/streaming_params.go b/pkg/connector/streaming_params.go index dac01b14..b6f94424 100644 --- a/pkg/connector/streaming_params.go +++ b/pkg/connector/streaming_params.go @@ -59,7 +59,7 @@ func (oc *AIClient) buildResponsesAPIParams(ctx context.Context, portal *bridgev log.Debug().Int("count", len(enabledTools)).Msg("Added builtin function tools") } - if meta.Capabilities.SupportsToolCalling && hasAgent { + if oc.getModelCapabilitiesForMeta(meta).SupportsToolCalling && hasAgent { // Add session tools for non-boss rooms if !hasBossAgent(meta) && !oc.isBuilderRoom(portal) { var enabledSessions []*tools.Tool diff --git a/pkg/connector/streaming_tool_selection.go b/pkg/connector/streaming_tool_selection.go index 87924315..080e1a2a 100644 --- a/pkg/connector/streaming_tool_selection.go +++ b/pkg/connector/streaming_tool_selection.go @@ -5,7 +5,7 @@ import "context" // selectedBuiltinToolsForTurn returns builtin tools exposed to the model for a turn. // Simple mode stays minimal: it only exposes web_search when tool-calling is supported. func (oc *AIClient) selectedBuiltinToolsForTurn(ctx context.Context, meta *PortalMetadata) []ToolDefinition { - if meta == nil || !meta.Capabilities.SupportsToolCalling { + if meta == nil || !oc.getModelCapabilitiesForMeta(meta).SupportsToolCalling { return nil } diff --git a/pkg/connector/streaming_tool_selection_test.go b/pkg/connector/streaming_tool_selection_test.go index f66e85cb..91a27e41 100644 --- a/pkg/connector/streaming_tool_selection_test.go +++ b/pkg/connector/streaming_tool_selection_test.go @@ -19,7 +19,11 @@ func TestSelectedBuiltinToolsForTurn_SimpleModeEnablesOnlyWebSearch(t *testing.T } meta := &PortalMetadata{ - IsSimpleMode: true, + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: modelUserID("openai/gpt-5.2"), + ModelID: "openai/gpt-5.2", + }, Capabilities: ModelCapabilities{ SupportsToolCalling: true, }, diff --git a/pkg/connector/system_prompts.go b/pkg/connector/system_prompts.go index 89f705f0..df71e331 100644 --- a/pkg/connector/system_prompts.go +++ b/pkg/connector/system_prompts.go @@ -34,18 +34,8 @@ func buildGroupIntro(roomName string, activation string) string { } func buildVerboseSystemHint(meta *PortalMetadata) string { - if meta == nil { - return "" - } - level := strings.ToLower(strings.TrimSpace(meta.VerboseLevel)) - switch level { - case "on": - return "Verbosity: on. Provide a bit more detail and context when helpful, but stay focused." - case "full": - return "Verbosity: full. Be thorough and detailed. Explain assumptions and reasoning clearly, without unnecessary fluff." - default: - return "" - } + _ = meta + return "" } func buildSessionIdentityHint(portal *bridgev2.Portal, meta *PortalMetadata) string { diff --git a/pkg/connector/target_test_helpers_test.go b/pkg/connector/target_test_helpers_test.go new file mode 100644 index 00000000..05c5361b --- /dev/null +++ b/pkg/connector/target_test_helpers_test.go @@ -0,0 +1,21 @@ +package connector + +func simpleModeTestMeta(modelID string) *PortalMetadata { + return &PortalMetadata{ + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetModel, + GhostID: modelUserID(modelID), + ModelID: modelID, + }, + } +} + +func agentModeTestMeta(agentID string) *PortalMetadata { + return &PortalMetadata{ + ResolvedTarget: &ResolvedTarget{ + Kind: ResolvedTargetAgent, + GhostID: agentUserID(agentID), + AgentID: agentID, + }, + } +} diff --git a/pkg/connector/tool_descriptions.go b/pkg/connector/tool_descriptions.go index db0a1caf..c41c5e6e 100644 --- a/pkg/connector/tool_descriptions.go +++ b/pkg/connector/tool_descriptions.go @@ -11,7 +11,7 @@ func (oc *AIClient) toolDescriptionForPortal(meta *PortalMetadata, toolName stri name := strings.TrimSpace(toolName) switch name { case toolspec.ImageName: - if meta != nil && meta.Capabilities.SupportsVision { + if meta != nil && oc.getModelCapabilitiesForMeta(meta).SupportsVision { return toolspec.ImageDescriptionVisionHint } case toolspec.WebSearchName: diff --git a/pkg/connector/tool_policy.go b/pkg/connector/tool_policy.go index 5bc0dc99..142ab4aa 100644 --- a/pkg/connector/tool_policy.go +++ b/pkg/connector/tool_policy.go @@ -45,12 +45,12 @@ func (oc *AIClient) isToolAvailable(meta *PortalMetadata, toolName string) (bool return available, source, reason } - if !meta.Capabilities.SupportsToolCalling { + if !oc.getModelCapabilitiesForMeta(meta).SupportsToolCalling { return false, SourceModelLimit, "Model does not support tools" } - if agenttools.IsBossTool(toolName) && !(meta.IsBuilderRoom || hasBossAgent(meta)) { - return false, SourceGlobalDefault, "Builder room only" + if agenttools.IsBossTool(toolName) && !hasBossAgent(meta) { + return false, SourceGlobalDefault, "Boss agent only" } // Tool runtime prerequisites (API keys, services, etc.). These are intentionally @@ -190,7 +190,7 @@ func (oc *AIClient) toolNamesForPortal(meta *PortalMetadata) []string { for _, tool := range agenttools.SessionTools() { nameSet[tool.Name] = struct{}{} } - if meta != nil && (meta.IsBuilderRoom || hasBossAgent(meta)) { + if meta != nil && hasBossAgent(meta) { for _, tool := range agenttools.BossTools() { nameSet[tool.Name] = struct{}{} } diff --git a/pkg/connector/tools.go b/pkg/connector/tools.go index 7b8b2c49..a73c91df 100644 --- a/pkg/connector/tools.go +++ b/pkg/connector/tools.go @@ -1542,10 +1542,7 @@ func executeSessionStatus(ctx context.Context, args map[string]any) (string, err dayOfWeek := now.Weekday().String() // Get model info - model := meta.Model - if model == "" { - model = btc.Client.effectiveModel(meta) - } + model := btc.Client.effectiveModel(meta) // Parse provider from model string (format: "provider/model" or just "model") provider := "unknown" @@ -1555,15 +1552,9 @@ func executeSessionStatus(ctx context.Context, args map[string]any) (string, err modelName = parsedModel } - // Get context/token info from metadata - maxContext := meta.MaxContextMessages - if maxContext == 0 { - maxContext = 12 // default - } - maxTokens := meta.MaxCompletionTokens - if maxTokens == 0 { - maxTokens = 512 // default - } + // Get context/token info from the effective runtime only. + maxContext := btc.Client.getModelContextWindow(meta) + maxTokens := btc.Client.effectiveMaxTokens(meta) // Build session info sessionID := string(btc.Portal.PortalKey.ID) @@ -1575,74 +1566,10 @@ func executeSessionStatus(ctx context.Context, args map[string]any) (string, err title = "Untitled" } - // Handle model change if requested (OpenClaw-style "model" alias supported) - var modelChanged string - newModel := "" - if raw, ok := args["set_model"].(string); ok && strings.TrimSpace(raw) != "" { - newModel = strings.TrimSpace(raw) - } else if raw, ok := args["model"].(string); ok && strings.TrimSpace(raw) != "" { - newModel = strings.TrimSpace(raw) - } - - if newModel != "" { - if strings.EqualFold(newModel, "default") || strings.EqualFold(newModel, "reset") { - metaCopy := *meta - metaCopy.Model = "" - effective := btc.Client.effectiveModel(&metaCopy) - if err := btc.Client.validateDMModelSwitch(btc.Portal, meta, effective); err != nil { - return "", dmModelSwitchBlockedError(effective) - } - - // Clear override and recompute capabilities from effective model - meta.Model = "" - effective = btc.Client.effectiveModel(meta) - meta.Capabilities = getModelCapabilities(effective, btc.Client.findModelInfo(effective)) - if err := btc.Portal.Save(ctx); err != nil { - return "", fmt.Errorf("couldn't save model reset: %w", err) - } - btc.Portal.UpdateBridgeInfo(ctx) - btc.Client.ensureGhostDisplayName(ctx, effective) - modelChanged = fmt.Sprintf("\n\nModel reset to %s.", effective) - model = effective - if parsedProvider, parsedModel := splitModelProvider(effective); parsedProvider != "" && parsedModel != "" { - provider = parsedProvider - modelName = parsedModel - } else { - modelName = effective - } - } else { - resolvedModel, valid, err := btc.Client.resolveModelID(ctx, newModel) - if err != nil || !valid || resolvedModel == "" { - return "", fmt.Errorf("invalid model: %s", newModel) - } - if err := btc.Client.validateDMModelSwitch(btc.Portal, meta, resolvedModel); err != nil { - return "", dmModelSwitchBlockedError(resolvedModel) - } - - // Update the model in metadata - meta.Model = resolvedModel - meta.Capabilities = getModelCapabilities(resolvedModel, btc.Client.findModelInfo(resolvedModel)) - // Save portal metadata - if err := btc.Portal.Save(ctx); err != nil { - return "", fmt.Errorf("couldn't save model change: %w", err) - } - btc.Portal.UpdateBridgeInfo(ctx) - btc.Client.ensureGhostDisplayName(ctx, resolvedModel) - modelChanged = fmt.Sprintf("\n\nModel set to %s.", resolvedModel) - model = resolvedModel - if parsedProvider, parsedModel := splitModelProvider(resolvedModel); parsedProvider != "" && parsedModel != "" { - provider = parsedProvider - modelName = parsedModel - } else { - modelName = resolvedModel - } - } - } - // Get agent info if available agentInfo := "" - if meta.AgentID != "" { - agentInfo = fmt.Sprintf("\nAgent: %s", meta.AgentID) + if agentID := resolveAgentID(meta); agentID != "" { + agentInfo = fmt.Sprintf("\nAgent: %s", agentID) } // Build status card similar to OpenClaw @@ -1656,7 +1583,7 @@ Provider: %s Max Context: %d messages Max Tokens: %d -Session: %s + Session: %s Chat: %s%s%s`, timeStr, timezone, now.Format("MST"), dayOfWeek, @@ -1667,7 +1594,7 @@ Chat: %s%s%s`, sessionID, title, agentInfo, - modelChanged, + "", ) return status, nil diff --git a/pkg/connector/tools_message_actions.go b/pkg/connector/tools_message_actions.go index bc5a7357..8b6a56cc 100644 --- a/pkg/connector/tools_message_actions.go +++ b/pkg/connector/tools_message_actions.go @@ -151,11 +151,7 @@ func executeMessageMemberInfo(ctx context.Context, args map[string]any, btc *Bri } else { store := NewAgentStoreAdapter(btc.Client) if agent, err := store.GetAgentByID(ctx, agentID); err == nil && agent != nil && agent.Model.Primary != "" { - if override := btc.Client.agentModelOverride(agentID); override != "" { - modelID = ResolveAlias(override) - } else { - modelID = ResolveAlias(agent.Model.Primary) - } + modelID = ResolveAlias(agent.Model.Primary) } } } diff --git a/pkg/connector/trace.go b/pkg/connector/trace.go index 51c1e702..ec95b91f 100644 --- a/pkg/connector/trace.go +++ b/pkg/connector/trace.go @@ -1,30 +1,16 @@ package connector -import ( - "strings" - - "github.com/beeper/ai-bridge/pkg/shared/stringutil" -) - func traceLevel(meta *PortalMetadata) string { - if meta == nil { - return "off" - } - if level, ok := stringutil.NormalizeEnum(meta.VerboseLevel, verboseLevelAliases); ok { - return level - } - level := strings.ToLower(strings.TrimSpace(meta.VerboseLevel)) - if level == "" { - return "off" - } - return level + _ = meta + return "off" } func traceEnabled(meta *PortalMetadata) bool { - level := traceLevel(meta) - return level == "on" || level == "full" + _ = meta + return false } func traceFull(meta *PortalMetadata) bool { - return traceLevel(meta) == "full" + _ = meta + return false } From 5e7539f2761dccbbb298ed751d2b9ca7344781ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 22:03:38 +0100 Subject: [PATCH 21/23] Remove legacy AI room controls --- pkg/connector/builder.go | 152 --- pkg/connector/chat.go | 9 +- pkg/connector/client.go | 8 + pkg/connector/commands.go | 1323 +------------------ pkg/connector/commands_manage.go | 61 - pkg/connector/commands_mcp.go | 472 ------- pkg/connector/commands_parity.go | 358 ----- pkg/connector/commands_simple.go | 86 -- pkg/connector/desktop_api_helpers.go | 36 + pkg/connector/group_activation.go | 20 +- pkg/connector/handleai.go | 6 - pkg/connector/heartbeat_execute.go | 6 - pkg/connector/integration_host.go | 2 +- pkg/connector/mcp_helpers.go | 181 +++ pkg/connector/metadata.go | 1 + pkg/connector/reasoning_fallback.go | 2 +- pkg/connector/remote_message.go | 4 +- pkg/connector/response_prefix.go | 125 +- pkg/connector/streaming_chat_completions.go | 4 +- pkg/connector/streaming_params.go | 6 +- pkg/connector/streaming_state.go | 3 - pkg/connector/subagent_spawn.go | 9 +- pkg/connector/system_prompts.go | 12 +- pkg/connector/tool_execution.go | 2 +- pkg/connector/typing_queue.go | 3 - 25 files changed, 260 insertions(+), 2631 deletions(-) delete mode 100644 pkg/connector/builder.go delete mode 100644 pkg/connector/commands_manage.go delete mode 100644 pkg/connector/commands_mcp.go delete mode 100644 pkg/connector/commands_simple.go create mode 100644 pkg/connector/desktop_api_helpers.go create mode 100644 pkg/connector/mcp_helpers.go diff --git a/pkg/connector/builder.go b/pkg/connector/builder.go deleted file mode 100644 index 3131403f..00000000 --- a/pkg/connector/builder.go +++ /dev/null @@ -1,152 +0,0 @@ -package connector - -import ( - "context" - "fmt" - - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" - - "github.com/beeper/ai-bridge/pkg/agents" -) - -// Builder room constants -const ( - BuilderRoomSlug = "builder" - BuilderRoomName = "Manage AI Chats" -) - -// ensureBuilderRoom creates or retrieves the "Manage AI Chats" room. -// This special room is where users interact with the Boss agent to manage their agents and rooms. -func (oc *AIClient) ensureBuilderRoom(ctx context.Context) error { - meta := loginMetadata(oc.UserLogin) - - // Check if we already have a Builder room - if meta.BuilderRoomID != "" { - // Verify it still exists - portal, err := oc.UserLogin.Bridge.GetPortalByKey(ctx, networkid.PortalKey{ - ID: meta.BuilderRoomID, - Receiver: oc.UserLogin.ID, - }) - if err == nil && portal != nil && portal.MXID != "" { - oc.loggerForContext(ctx).Debug().Str("room_id", string(meta.BuilderRoomID)).Msg("Manage AI Chats room already exists") - return nil - } - // Room doesn't exist anymore, clear the reference - meta.BuilderRoomID = "" - } - - oc.loggerForContext(ctx).Info().Msg("Creating Manage AI Chats room") - - // Create the Builder room with Boss agent as the ghost - portal, chatInfo, err := oc.createBuilderRoom(ctx) - if err != nil { - return fmt.Errorf("failed to create builder room: %w", err) - } - - // Create Matrix room - if err := portal.CreateMatrixRoom(ctx, oc.UserLogin, chatInfo); err != nil { - cleanupPortal(ctx, oc, portal, "failed to create builder Matrix room") - return fmt.Errorf("failed to create matrix room for builder: %w", err) - } - - // Send welcome message (excluded from LLM history) - oc.sendWelcomeMessage(ctx, portal) - - // Store the Builder room ID - meta.BuilderRoomID = portal.PortalKey.ID - if err := oc.UserLogin.Save(ctx); err != nil { - meta.BuilderRoomID = "" - cleanupPortal(ctx, oc, portal, "failed to save BuilderRoomID") - return fmt.Errorf("failed to save BuilderRoomID: %w", err) - } - - oc.loggerForContext(ctx).Info(). - Str("portal_id", string(portal.PortalKey.ID)). - Str("mxid", string(portal.MXID)). - Msg("Manage AI Chats room created") - - return nil -} - -// createBuilderRoom creates the "Manage AI Chats" room portal and chat info. -func (oc *AIClient) createBuilderRoom(ctx context.Context) (*bridgev2.Portal, *bridgev2.ChatInfo, error) { - bossAgent := agents.GetBossAgent() - - // Use a standard chat initialization with the management room title - opts := PortalInitOpts{ - Title: BuilderRoomName, - } - - portal, chatInfo, err := oc.initPortalForChat(ctx, opts) - if err != nil { - return nil, nil, err - } - - // Set up the portal metadata for the Boss agent - pm := portalMeta(portal) - pm.Slug = BuilderRoomSlug // Override slug to "builder" - pm.AgentID = bossAgent.ID - pm.SystemPrompt = agents.BossSystemPrompt - pm.Model = bossAgent.Model.Primary // Explicit model - always use Boss agent's model - pm.IsBuilderRoom = true // Mark as protected from overrides - - // Use agent ghost for the Boss agent - modelID := pm.Model - if modelID == "" { - modelID = oc.effectiveModel(nil) - } - bossGhostID := agentUserID(bossAgent.ID) - bossDisplayName := oc.resolveAgentDisplayName(ctx, bossAgent) - portal.OtherUserID = bossGhostID - - if chatInfo != nil && chatInfo.Members != nil { - members := chatInfo.Members - if members.MemberMap == nil { - members.MemberMap = make(bridgev2.ChatMemberMap) - } - members.OtherUserID = bossGhostID - humanID := humanUserID(oc.UserLogin.ID) - humanMember := members.MemberMap[humanID] - humanMember.EventSender = bridgev2.EventSender{ - IsFromMe: true, - SenderLogin: oc.UserLogin.ID, - } - bossMember := members.MemberMap[bossGhostID] - bossMember.EventSender = bridgev2.EventSender{ - Sender: bossGhostID, - SenderLogin: oc.UserLogin.ID, - } - bossMember.UserInfo = &bridgev2.UserInfo{ - Name: ptr.Ptr(bossDisplayName), - IsBot: ptr.Ptr(true), - Identifiers: agentContactIdentifiers(bossAgent.ID, modelID, oc.findModelInfo(modelID)), - } - bossMember.MemberEventExtra = map[string]any{ - "displayname": bossDisplayName, - "com.beeper.ai.model_id": modelID, - "com.beeper.ai.agent": bossAgent.ID, - } - members.MemberMap = bridgev2.ChatMemberMap{ - humanID: humanMember, - bossGhostID: bossMember, - } - chatInfo.Members = members - } - - // Re-save portal with updated metadata - if err := portal.Save(ctx); err != nil { - return nil, nil, fmt.Errorf("failed to save portal with agent config: %w", err) - } - oc.ensureAgentGhostDisplayName(ctx, bossAgent.ID, modelID, bossDisplayName) - - return portal, chatInfo, nil -} - -// isBuilderRoom checks if a portal is the Builder room. -func (oc *AIClient) isBuilderRoom(portal *bridgev2.Portal) bool { - meta := loginMetadata(oc.UserLogin) - return meta.BuilderRoomID != "" && portal.PortalKey.ID == meta.BuilderRoomID -} diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index 4a331588..12f9a121 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -1009,13 +1009,13 @@ func (oc *AIClient) createForkedChat( return nil, nil, err } - agentID := sourceMeta.AgentID + agentID := resolveAgentID(sourceMeta) if agentID != "" { pm := portalMeta(portal) - pm.AgentID = agentID modelID := oc.effectiveModel(pm) portal.OtherUserID = agentUserID(agentID) + pm.ResolvedTarget = resolveTargetFromGhostID(portal.OtherUserID) agentName := agentID agentAvatar := "" @@ -1660,14 +1660,11 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { // Set agent-specific metadata pm := portalMeta(portal) - pm.AgentID = beeperAgent.ID - if beeperAgent.SystemPrompt != "" { - pm.SystemPrompt = beeperAgent.SystemPrompt - } // Update the OtherUserID to be the agent ghost agentGhostID := agentUserID(beeperAgent.ID) portal.OtherUserID = agentGhostID + pm.ResolvedTarget = resolveTargetFromGhostID(agentGhostID) if err := portal.Save(ctx); err != nil { oc.loggerForContext(ctx).Err(err).Msg("Failed to save portal with agent config") diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 97bb6786..9ea7eeb4 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -1562,6 +1562,14 @@ func (oc *AIClient) effectiveReasoningEffort(meta *PortalMetadata) string { if !oc.getModelCapabilitiesForMeta(meta).SupportsReasoning { return "" } + if meta != nil { + switch effort := strings.ToLower(strings.TrimSpace(meta.RuntimeReasoning)); effort { + case "low", "medium", "high": + return effort + case "off", "none": + return "" + } + } return defaultReasoningEffort } diff --git a/pkg/connector/commands.go b/pkg/connector/commands.go index db44f646..ffdc2a6b 100644 --- a/pkg/connector/commands.go +++ b/pkg/connector/commands.go @@ -3,32 +3,20 @@ package connector import ( "context" "errors" - "fmt" - "strconv" - "strings" - "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/commands" "maunium.net/go/mautrix/bridgev2/networkid" - "github.com/beeper/ai-bridge/pkg/agents" - "github.com/beeper/ai-bridge/pkg/agents/toolpolicy" "github.com/beeper/ai-bridge/pkg/connector/commandregistry" ) -// HelpSectionAI is the help section for AI-related commands +// HelpSectionAI is the help section for AI-related commands. var HelpSectionAI = commands.HelpSection{ Name: "AI Chat", Order: 30, } -var reservedAgentIDs = map[string]struct{}{ - "none": {}, - "clear": {}, - "boss": {}, -} - func resolveLoginForCommand( ctx context.Context, portal *bridgev2.Portal, @@ -95,1319 +83,20 @@ func isValidAgentID(agentID string) bool { return true } -func splitQuotedArgs(input string) ([]string, error) { - var args []string - var current strings.Builder - var quote rune - escaped := false - - flush := func() { - if current.Len() > 0 { - args = append(args, current.String()) - current.Reset() - } - } - - for _, r := range input { - if escaped { - current.WriteRune(r) - escaped = false - continue - } - - if r == '\\' && quote != '\'' { - escaped = true - continue - } - - if quote != 0 { - if r == quote { - quote = 0 - continue - } - current.WriteRune(r) - continue - } - - switch r { - case '\'', '"': - quote = r - case ' ', '\t', '\n', '\r': - flush() - default: - current.WriteRune(r) - } - } - - if quote != 0 { - return nil, errors.New("unterminated quote") - } - if escaped { - current.WriteRune('\\') - } - flush() - return args, nil -} - -// CommandModel handles the !ai model command -var _ = registerAICommand(commandregistry.Definition{ - Name: "model", - Description: "Get or set the AI model for this chat", - Args: "[_model name_]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnModel, -}) - -func fnModel(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - if len(ce.Args) == 0 { - ce.Reply("Current model: %s", client.effectiveModel(meta)) - return - } - - if rejectBossOverrides(ce, meta, "Can't change the model in a room managed by the Boss agent.") { - return - } - - modelID := strings.TrimSpace(ce.Args[0]) - resolvedModel, valid, err := client.resolveModelID(ce.Ctx, modelID) - if err != nil || !valid || resolvedModel == "" { - ce.Reply("That model isn't available: %s", modelID) - return - } - - agentID := resolveAgentID(meta) - if agentID != "" { - ce.Reply("Can't set the room model while an agent is assigned. Edit the agent instead.") - return - } - - if err := client.validateDMModelSwitch(ce.Portal, meta, resolvedModel); err != nil { - ce.Reply("%s", dmModelSwitchGuidance(resolvedModel)) - return - } - - meta.Model = resolvedModel - meta.Capabilities = getModelCapabilities(resolvedModel, client.findModelInfo(resolvedModel)) - client.savePortalQuiet(ce.Ctx, ce.Portal, "model change") - client.ensureGhostDisplayName(ce.Ctx, resolvedModel) - ce.Reply("Model set to %s.", resolvedModel) -} - -// CommandTemp handles the !ai temp command -var _ = registerAICommand(commandregistry.Definition{ - Name: "temp", - Description: "Get or set the temperature (0-2)", - Args: "[_value_]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnTemp, -}) - -func fnTemp(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - if len(ce.Args) == 0 { - if temp := client.effectiveTemperature(meta); temp > 0 { - ce.Reply("Current temperature: %.2f", temp) - } else { - ce.Reply("Current temperature: provider default (unset)") - } - return - } - - if rejectBossOverrides(ce, meta, "Can't change the temperature in a room managed by the Boss agent.") { - return - } - - var temp float64 - if _, err := fmt.Sscanf(ce.Args[0], "%f", &temp); err != nil || temp < 0 || temp > 2 { - ce.Reply("Invalid temperature. Must be between 0 and 2.") - return - } - - meta.Temperature = temp - client.savePortalQuiet(ce.Ctx, ce.Portal, "temperature change") - if temp > 0 { - ce.Reply("Temperature set to %.2f.", temp) - } else { - ce.Reply("Temperature reset to provider default (unset).") - } -} - -// CommandSystemPrompt handles the !ai system-prompt command -var _ = registerAICommand(commandregistry.Definition{ - Name: "system-prompt", - Description: "Get or set the system prompt (shows full constructed prompt)", - Args: "[_text_]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnSystemPrompt, -}) - -func fnSystemPrompt(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - if len(ce.Args) == 0 { - // Show full constructed prompt (agent + room levels merged) - fullPrompt := client.effectiveAgentPrompt(ce.Ctx, ce.Portal, meta) - if fullPrompt == "" { - fullPrompt = client.effectivePrompt(meta) - } - if fullPrompt == "" { - fullPrompt = "(none)" - } - // Truncate for display - totalLen := len(fullPrompt) - if totalLen > 500 { - fullPrompt = fullPrompt[:500] + "...\n\n(truncated, full prompt is " + strconv.Itoa(totalLen) + " chars)" - } - ce.Reply("Current system prompt:\n%s", fullPrompt) - return - } - - if rejectBossOverrides(ce, meta, "Can't change the system prompt in a room managed by the Boss agent.") { - return - } - - meta.SystemPrompt = ce.RawArgs - client.savePortalQuiet(ce.Ctx, ce.Portal, "system prompt change") - ce.Reply("System prompt updated.") -} - -// CommandContext handles the !ai context command -var _ = registerAICommand(commandregistry.Definition{ - Name: "context", - Description: "Get or set context message limit (1-100)", - Args: "[_count_]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnContext, -}) - -func fnContext(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - if len(ce.Args) == 0 { - ce.Reply("%s", client.buildContextStatus(ce.Ctx, ce.Portal, meta)) - return - } - - var limit int - if _, err := fmt.Sscanf(ce.Args[0], "%d", &limit); err != nil || limit < 1 || limit > 100 { - ce.Reply("Invalid context limit. Must be between 1 and 100.") - return - } - - meta.MaxContextMessages = limit - client.savePortalQuiet(ce.Ctx, ce.Portal, "context change") - ce.Reply("Context limit set to %d messages.", limit) -} - -// CommandTokens handles the !ai tokens command -var _ = registerAICommand(commandregistry.Definition{ - Name: "tokens", - Description: "Get or set max completion tokens (1-16384)", - Args: "[_count_]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnTokens, -}) - -func fnTokens(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - if len(ce.Args) == 0 { - ce.Reply("Current max tokens: %d", client.effectiveMaxTokens(meta)) - return - } - - var tokens int - if _, err := fmt.Sscanf(ce.Args[0], "%d", &tokens); err != nil || tokens < 1 || tokens > 16384 { - ce.Reply("Invalid max tokens. Must be between 1 and 16384.") - return - } - - meta.MaxCompletionTokens = tokens - client.savePortalQuiet(ce.Ctx, ce.Portal, "tokens change") - ce.Reply("Max tokens set to %d.", tokens) -} - -// CommandConfig handles the !ai config command var _ = registerAICommand(commandregistry.Definition{ - Name: "config", - Description: "Show current chat configuration", + Name: "new", + Description: "Create a new chat of the same type (agent or model)", + Args: "[agent ]", Section: HelpSectionAI, RequiresPortal: true, RequiresLogin: true, - Handler: fnConfig, -}) - -// CommandDesktopAPI handles the !ai desktop-api command -var _ = registerAICommand(commandregistry.Definition{ - Name: "desktop-api", - Description: "Manage Beeper Desktop API instances", - Args: " [args]", - Section: HelpSectionAI, - RequiresPortal: false, - RequiresLogin: true, - Handler: fnDesktopAPI, -}) - -const desktopAPIManageUsage = "`!ai desktop-api list` | `!ai desktop-api add [baseURL]` | `!ai desktop-api add [baseURL]` | `!ai desktop-api remove [name]`." - -var _ = registerAICommand(commandregistry.Definition{ - Name: "commands", - Description: "Show AI command groups and recommended command forms", - Section: HelpSectionAI, - RequiresPortal: false, - RequiresLogin: true, - Handler: fnCommands, + Handler: fnNew, }) -func fnCommands(ce *commands.Event) { - ce.Reply( - "AI command groups (preferred forms):\n\n" + - "Core chat:\n" + - "- `!ai status`\n" + - "- `!ai config`\n" + - "- `!ai model [model]`\n" + - "- `!ai temp [0-2]`\n" + - "- `!ai system-prompt [text]`\n" + - "- `!ai context [1-100]`\n" + - "- `!ai tokens [1-16384]`\n" + - "- `!ai tools [on|off] [tool]`\n" + - "- `!ai typing [never|instant|thinking|message|off|reset|interval ]`\n" + - "- `!ai debounce [ms|off|default]`\n\n" + - "Controls:\n" + - "- `!ai think off|minimal|low|medium|high|xhigh`\n" + - "- `!ai verbose on|off|full`\n" + - "- `!ai reasoning off|on|low|medium|high|xhigh`\n" + - "- `!ai elevated off|on|ask|full`\n" + - "- `!ai activation mention|always` (group chats)\n" + - "- `!ai send on|off|inherit`\n" + - "- `!ai queue status|reset| [debounce:] [cap:] [drop:]`\n\n" + - "Session actions:\n" + - "- `!ai approve [reason]`\n" + - "- `!ai new` — New chat of the same type\n" + - "- `!ai reset` — Reset this session/thread\n" + - "- `!ai stop` — Abort the current run\n" + - "- `!ai fork`\n" + - "- `!ai regenerate`\n" + - "- `!ai title [text]`\n" + - "- `!ai timezone [IANA_TZ]`\n\n" + - "Simple Mode:\n" + - "- `!ai simple new [model]` — Create a new AI chat\n" + - "- `!ai simple list` — List available models\n\n" + - "Agents:\n" + - "- `!ai agent [id|none]`\n" + - "- `!ai agents`\n" + - "- `!ai create-agent [model] [system prompt...]`\n" + - "- `!ai delete-agent `\n" + - "- `!ai manage`\n\n" + - "Integrations:\n" + - "- MCP: `!ai mcp ...`\n" + - "- Desktop API: `!ai desktop-api ...`\n" + - "- Gravatar: `!ai gravatar ...`\n\n" + - "Use `!help` for the full command list from the command processor.", - ) -} - -func fnConfig(ce *commands.Event) { +func fnNew(ce *commands.Event) { client, meta, ok := requireClientMeta(ce) if !ok { return } - - roomCaps := client.getRoomCapabilities(ce.Ctx, meta) - tempLabel := "provider default" - if temp := client.effectiveTemperature(meta); temp > 0 { - tempLabel = fmt.Sprintf("%.2f", temp) - } - config := fmt.Sprintf( - "Current configuration:\n• Model: %s\n• Temperature: %s\n• Context: %d messages\n• Max tokens: %d\n• Vision: %v", - client.effectiveModel(meta), tempLabel, client.historyLimit(ce.Ctx, ce.Portal, meta), - client.effectiveMaxTokens(meta), roomCaps.SupportsVision) - ce.Reply(config) -} - -func fnSetDesktopAPIToken(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - login := client.UserLogin - if login == nil { - ce.Reply("You're not signed in. Sign in and try again.") - return - } - meta := loginMetadata(login) - if meta == nil { - ce.Reply("Couldn't load your settings. Try again.") - return - } - - token := strings.TrimSpace(ce.Args[0]) - baseURL := "" - if len(ce.Args) > 1 { - baseURL = strings.TrimSpace(strings.Join(ce.Args[1:], " ")) - } - if token == "" { - ce.Reply("Usage: `!ai desktop-api add [baseURL]`.") - return - } - if meta.ServiceTokens == nil { - meta.ServiceTokens = &ServiceTokens{} - } - meta.ServiceTokens.DesktopAPI = token - if meta.ServiceTokens.DesktopAPIInstances == nil { - meta.ServiceTokens.DesktopAPIInstances = map[string]DesktopAPIInstance{} - } - defaultConfig := meta.ServiceTokens.DesktopAPIInstances[desktopDefaultInstance] - defaultConfig.Token = token - if baseURL != "" { - defaultConfig.BaseURL = baseURL - } - meta.ServiceTokens.DesktopAPIInstances[desktopDefaultInstance] = defaultConfig - if err := login.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't save the Desktop API token: %s", err) - return - } - if baseURL != "" { - ce.Reply("Desktop API token saved (base URL %s)", baseURL) - return - } - ce.Reply("Desktop API token saved") -} - -func fnAddDesktopAPIInstance(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - login := client.UserLogin - if login == nil { - ce.Reply("You're not signed in. Sign in and try again.") - return - } - meta := loginMetadata(login) - if meta == nil { - ce.Reply("Couldn't load your settings. Try again.") - return - } - if len(ce.Args) < 2 { - ce.Reply("Usage: `!ai desktop-api add [baseURL]`.") - return - } - name := normalizeDesktopInstanceName(ce.Args[0]) - if name == "" { - ce.Reply("Instance name is required") - return - } - token := strings.TrimSpace(ce.Args[1]) - if token == "" { - ce.Reply("Token is required") - return - } - baseURL := "" - if len(ce.Args) > 2 { - baseURL = strings.TrimSpace(strings.Join(ce.Args[2:], " ")) - } - if meta.ServiceTokens == nil { - meta.ServiceTokens = &ServiceTokens{} - } - if meta.ServiceTokens.DesktopAPIInstances == nil { - meta.ServiceTokens.DesktopAPIInstances = map[string]DesktopAPIInstance{} - } - config := meta.ServiceTokens.DesktopAPIInstances[name] - config.Token = token - if baseURL != "" { - config.BaseURL = baseURL - } - meta.ServiceTokens.DesktopAPIInstances[name] = config - if name == desktopDefaultInstance { - meta.ServiceTokens.DesktopAPI = token - } - if err := login.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't save the Desktop API instance: %s", err) - return - } - if baseURL != "" { - ce.Reply("Desktop API instance '%s' saved (base URL %s)", name, baseURL) - return - } - ce.Reply("Desktop API instance '%s' saved", name) -} - -func fnRemoveDesktopAPIInstance(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - login := client.UserLogin - if login == nil { - ce.Reply("You're not signed in. Sign in and try again.") - return - } - meta := loginMetadata(login) - if meta == nil { - ce.Reply("Couldn't load your settings. Try again.") - return - } - name := "" - if len(ce.Args) == 0 { - if meta.ServiceTokens == nil || len(meta.ServiceTokens.DesktopAPIInstances) == 0 { - ce.Reply("Desktop API instances: none configured") - return - } - if len(meta.ServiceTokens.DesktopAPIInstances) > 1 { - ce.Reply("Multiple Desktop API instances configured. Provide a name. Use `!ai desktop-api list`.") - return - } - for instanceName := range meta.ServiceTokens.DesktopAPIInstances { - name = instanceName - break - } - } else { - name = normalizeDesktopInstanceName(strings.Join(ce.Args, " ")) - if name == "" { - ce.Reply("Instance name is required") - return - } - } - if meta.ServiceTokens == nil || meta.ServiceTokens.DesktopAPIInstances == nil { - ce.Reply("Desktop API instance '%s' not found", name) - return - } - if _, ok := meta.ServiceTokens.DesktopAPIInstances[name]; !ok { - ce.Reply("Desktop API instance '%s' not found", name) - return - } - delete(meta.ServiceTokens.DesktopAPIInstances, name) - if name == desktopDefaultInstance { - meta.ServiceTokens.DesktopAPI = "" - } - if len(meta.ServiceTokens.DesktopAPIInstances) == 0 { - meta.ServiceTokens.DesktopAPIInstances = nil - } - if err := login.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't remove the Desktop API instance: %s", err) - return - } - ce.Reply("Desktop API instance '%s' removed", name) -} - -func fnListDesktopAPIInstances(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - instances := client.desktopAPIInstances() - if len(instances) == 0 { - ce.Reply("Desktop API instances: none configured") - return - } - lines := make([]string, 0, len(instances)) - for _, name := range client.desktopAPIInstanceNames() { - config := instances[name] - status := "set" - if strings.TrimSpace(config.Token) == "" { - status = "missing token" - } - if strings.TrimSpace(config.BaseURL) != "" { - lines = append(lines, fmt.Sprintf("- %s: %s (base URL %s)", name, status, strings.TrimSpace(config.BaseURL))) - } else { - lines = append(lines, fmt.Sprintf("- %s: %s", name, status)) - } - } - ce.Reply("Desktop API instances:\n%s", strings.Join(lines, "\n")) -} - -func fnDesktopAPI(ce *commands.Event) { - if len(ce.Args) == 0 { - ce.Reply("Usage: %s", desktopAPIManageUsage) - return - } - - sub := strings.ToLower(strings.TrimSpace(ce.Args[0])) - switch sub { - case "list": - ce.Args = ce.Args[1:] - fnListDesktopAPIInstances(ce) - return - case "add": - parsedName, parsedToken, parsedBaseURL, parsedErr := parseDesktopAPIAddArgs(ce.Args[1:]) - if parsedErr != nil { - ce.Reply("Usage: %s", desktopAPIManageUsage) - return - } - if parsedName == "" || parsedName == desktopDefaultInstance { - nextArgs := []string{parsedToken} - if parsedBaseURL != "" { - nextArgs = append(nextArgs, parsedBaseURL) - } - ce.Args = nextArgs - fnSetDesktopAPIToken(ce) - return - } - nextArgs := []string{parsedName, parsedToken} - if parsedBaseURL != "" { - nextArgs = append(nextArgs, parsedBaseURL) - } - ce.Args = nextArgs - fnAddDesktopAPIInstance(ce) - return - case "remove": - ce.Args = ce.Args[1:] - fnRemoveDesktopAPIInstance(ce) - return - default: - ce.Reply("Usage: %s", desktopAPIManageUsage) - } -} - -func parseDesktopAPIAddArgs(args []string) (name, token, baseURL string, err error) { - if len(args) == 0 { - return "", "", "", errors.New("missing args") - } - - trimmed := make([]string, 0, len(args)) - for _, raw := range args { - part := strings.TrimSpace(raw) - if part != "" { - trimmed = append(trimmed, part) - } - } - if len(trimmed) == 0 { - return "", "", "", errors.New("missing args") - } - - if len(trimmed) == 1 { - return "", trimmed[0], "", nil - } - - if len(trimmed) == 2 { - if isLikelyHTTPURL(trimmed[1]) { - return "", trimmed[0], trimmed[1], nil - } - return normalizeDesktopInstanceName(trimmed[0]), trimmed[1], "", nil - } - - return normalizeDesktopInstanceName(trimmed[0]), trimmed[1], strings.TrimSpace(strings.Join(trimmed[2:], " ")), nil -} - -func isLikelyHTTPURL(raw string) bool { - value := strings.TrimSpace(strings.ToLower(raw)) - return strings.HasPrefix(value, "http://") || strings.HasPrefix(value, "https://") -} - -// CommandDebounce handles the !ai debounce command -var _ = registerAICommand(commandregistry.Definition{ - Name: "debounce", - Description: "Get or set message debounce delay (ms), 'off' to disable, 'default' to reset", - Args: "[_delay_|off|default]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnDebounce, -}) - -func fnDebounce(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - if len(ce.Args) == 0 { - // Show current setting - switch { - case meta.DebounceMs < 0: - ce.Reply("Message debouncing is **disabled** for this room") - case meta.DebounceMs == 0: - ce.Reply("Message debounce: **%d ms** (default)", DefaultDebounceMs) - default: - ce.Reply("Message debounce: **%d ms**", meta.DebounceMs) - } - return - } - - arg := strings.ToLower(ce.Args[0]) - switch arg { - case "off", "disable", "disabled": - meta.DebounceMs = -1 - client.savePortalQuiet(ce.Ctx, ce.Portal, "debounce disabled") - ce.Reply("Message debouncing disabled for this room") - case "default", "reset": - meta.DebounceMs = 0 - client.savePortalQuiet(ce.Ctx, ce.Portal, "debounce reset") - ce.Reply("Message debounce reset to default (%d ms)", DefaultDebounceMs) - default: - // Parse as integer - delay, err := strconv.Atoi(arg) - if err != nil || delay < 0 || delay > 10000 { - ce.Reply("Invalid debounce delay. Use a number 0-10000 (ms), 'off', or 'default'.") - return - } - meta.DebounceMs = delay - client.savePortalQuiet(ce.Ctx, ce.Portal, "debounce change") - if delay == 0 { - ce.Reply("Message debounce reset to default (%d ms)", DefaultDebounceMs) - } else { - ce.Reply("Message debounce set to %d ms.", delay) - } - } -} - -// CommandTyping handles the !ai typing command -var _ = registerAICommand(commandregistry.Definition{ - Name: "typing", - Description: "Get or set typing indicator behavior for this chat", - Args: "[never|instant|thinking|message|off|reset|interval ]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnTyping, -}) - -func fnTyping(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - isGroup := client.isGroupChat(ce.Ctx, ce.Portal) - if len(ce.Args) == 0 { - mode := client.resolveTypingMode(meta, &TypingContext{IsGroup: isGroup, WasMentioned: !isGroup}, false) - interval := client.resolveTypingInterval(meta) - response := fmt.Sprintf("Typing: mode=%s interval=%s", mode, formatTypingInterval(interval)) - if meta.TypingMode != "" || meta.TypingIntervalSeconds != nil { - overrideMode := "default" - if meta.TypingMode != "" { - overrideMode = meta.TypingMode - } - overrideInterval := "default" - if meta.TypingIntervalSeconds != nil { - overrideInterval = fmt.Sprintf("%ds", *meta.TypingIntervalSeconds) - } - response = fmt.Sprintf("%s (session override: mode=%s interval=%s)", response, overrideMode, overrideInterval) - } - ce.Reply(response) - return - } - - token := strings.ToLower(strings.TrimSpace(ce.Args[0])) - switch token { - case "reset", "default": - meta.TypingMode = "" - meta.TypingIntervalSeconds = nil - client.savePortalQuiet(ce.Ctx, ce.Portal, "typing reset") - ce.Reply("Typing settings reset to defaults.") - return - case "off": - meta.TypingMode = string(TypingModeNever) - client.savePortalQuiet(ce.Ctx, ce.Portal, "typing mode") - ce.Reply("Typing disabled for this session.") - return - case "interval": - if len(ce.Args) < 2 { - ce.Reply("Usage: `!ai typing interval `") - return - } - seconds, err := parsePositiveInt(ce.Args[1]) - if err != nil || seconds <= 0 { - ce.Reply("Interval must be a positive integer (seconds).") - return - } - meta.TypingIntervalSeconds = &seconds - client.savePortalQuiet(ce.Ctx, ce.Portal, "typing interval") - ce.Reply("Typing interval set to %ds.", seconds) - return - default: - if mode, ok := normalizeTypingMode(token); ok { - meta.TypingMode = string(mode) - client.savePortalQuiet(ce.Ctx, ce.Portal, "typing mode") - ce.Reply("Typing mode set to %s.", mode) - return - } - } - - ce.Reply("Usage: `!ai typing ` | `!ai typing interval ` | `!ai typing off` | `!ai typing reset`") -} - -// CommandTools handles the !ai tools command -var _ = registerAICommand(commandregistry.Definition{ - Name: "tools", - Description: "Enable/disable tools", - Args: "[on|off] [_tool_]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnTools, -}) - -func fnTools(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - // Run async to avoid blocking - go client.handleToolsCommand(ce.Ctx, ce.Portal, meta, ce.RawArgs) -} - -// CommandNew handles the !ai new command -var _ = registerAICommand(commandregistry.Definition{ - Name: "new", - Description: "Create a new chat of the same type (agent or model)", - Args: "[agent ]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnNew, -}) - -func fnNew(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - // Run async go client.handleNewChat(ce.Ctx, nil, ce.Portal, meta, ce.Args) } - -// CommandFork handles the !ai fork command -var _ = registerAICommand(commandregistry.Definition{ - Name: "fork", - Description: "Fork conversation to a new chat", - Args: "[_event_id_]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnFork, -}) - -func fnFork(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - var arg string - if len(ce.Args) > 0 { - arg = ce.Args[0] - } - - // Run async - go client.handleFork(ce.Ctx, nil, ce.Portal, meta, arg) -} - -// CommandRegenerate handles the !ai regenerate command -var _ = registerAICommand(commandregistry.Definition{ - Name: "regenerate", - Description: "Regenerate the last AI response", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnRegenerate, -}) - -func fnRegenerate(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - // Run async - go client.handleRegenerate(ce.Ctx, nil, ce.Portal, meta) -} - -// CommandTitle handles the !ai title command -var _ = registerAICommand(commandregistry.Definition{ - Name: "title", - Description: "Regenerate the chat room title", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnTitle, -}) - -func fnTitle(ce *commands.Event) { - client, _, ok := requireClientMeta(ce) - if !ok { - return - } - - // Run async - go client.handleRegenerateTitle(ce.Ctx, ce.Portal) -} - -// CommandModels handles the !ai models command -var _ = registerAICommand(commandregistry.Definition{ - Name: "models", - Description: "List all available models", - Section: HelpSectionAI, - RequiresLogin: true, - Handler: fnModels, -}) - -func fnModels(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - - // Get portal meta if available (for showing current model) - meta := getPortalMeta(ce) - - models, err := client.listAvailableModels(ce.Ctx, false) - if err != nil { - ce.Reply("Couldn't load models. Try again.") - return - } - - var sb strings.Builder - sb.WriteString("Available models:\n\n") - for _, m := range models { - var caps []string - if m.SupportsVision { - caps = append(caps, "Vision") - } - if m.SupportsReasoning { - caps = append(caps, "Reasoning") - } - if m.SupportsWebSearch { - caps = append(caps, "Web Search") - } - if m.SupportsImageGen { - caps = append(caps, "Image Gen") - } - if m.SupportsToolCalling { - caps = append(caps, "Tools") - } - sb.WriteString(fmt.Sprintf("• **%s** (`%s`)\n", m.Name, m.ID)) - if m.Description != "" { - sb.WriteString(fmt.Sprintf(" %s\n", m.Description)) - } - if len(caps) > 0 { - sb.WriteString(fmt.Sprintf(" %s\n", strings.Join(caps, " · "))) - } - sb.WriteString("\n") - } - - currentModel := client.effectiveModel(meta) - sb.WriteString(fmt.Sprintf("Current: **%s**\nUse `!ai model ` to switch models", currentModel)) - ce.Reply(sb.String()) -} - -// CommandTimezone handles the !ai timezone command -var _ = registerAICommand(commandregistry.Definition{ - Name: "timezone", - Description: "Get or set your timezone for all chats (IANA name)", - Args: "[_timezone_|reset]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnTimezone, -}) - -func fnTimezone(ce *commands.Event) { - client, _, ok := requireClientMeta(ce) - if !ok { - return - } - - loginMeta := loginMetadata(client.UserLogin) - if loginMeta == nil { - ce.Reply("Couldn't load your settings. Try again.") - return - } - - if len(ce.Args) == 0 { - tz := strings.TrimSpace(loginMeta.Timezone) - if tz == "" { - ce.Reply("No timezone set. Use `!ai timezone ` (example: `America/Los_Angeles`).") - return - } - ce.Reply("Timezone: %s", tz) - return - } - - arg := strings.TrimSpace(ce.Args[0]) - switch strings.ToLower(arg) { - case "reset", "default", "clear": - loginMeta.Timezone = "" - if err := client.UserLogin.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't clear the timezone: %s", err.Error()) - return - } - ce.Reply("Timezone cleared. Falling back to UTC unless TZ is set.") - return - default: - tz, _, err := normalizeTimezone(arg) - if err != nil { - ce.Reply("Invalid timezone. Use an IANA name like `America/Los_Angeles` or `Europe/London`.") - return - } - loginMeta.Timezone = tz - if err := client.UserLogin.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't save the timezone: %s", err.Error()) - return - } - ce.Reply("Timezone set to %s.", tz) - } -} - -// CommandGravatar handles the !ai gravatar command -var _ = registerAICommand(commandregistry.Definition{ - Name: "gravatar", - Description: "Fetch or set the Gravatar profile for this login", - Args: "[fetch|set] [email]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnGravatar, -}) - -func fnGravatar(ce *commands.Event) { - client, _, ok := requireClientMeta(ce) - if !ok { - return - } - - if len(ce.Args) == 0 { - loginMeta := loginMetadata(client.UserLogin) - if loginMeta == nil || loginMeta.Gravatar == nil || loginMeta.Gravatar.Primary == nil { - ce.Reply("No Gravatar profile set. Use `!ai gravatar set `.") - return - } - ce.Reply(formatGravatarMarkdown(loginMeta.Gravatar.Primary, "primary")) - return - } - - action := strings.ToLower(strings.TrimSpace(ce.Args[0])) - switch action { - case "fetch": - email := "" - if len(ce.Args) > 1 { - email = ce.Args[1] - } - if strings.TrimSpace(email) == "" { - loginMeta := loginMetadata(client.UserLogin) - if loginMeta != nil && loginMeta.Gravatar != nil && loginMeta.Gravatar.Primary != nil { - email = loginMeta.Gravatar.Primary.Email - } - } - if strings.TrimSpace(email) == "" { - ce.Reply("Email is required. Usage: `!ai gravatar fetch `.") - return - } - profile, err := fetchGravatarProfile(ce.Ctx, email) - if err != nil { - ce.Reply("Couldn't fetch the Gravatar profile: %s", err.Error()) - return - } - ce.Reply(formatGravatarMarkdown(profile, "fetched")) - return - case "set": - if len(ce.Args) < 2 || strings.TrimSpace(ce.Args[1]) == "" { - ce.Reply("Email is required. Usage: `!ai gravatar set `.") - return - } - profile, err := fetchGravatarProfile(ce.Ctx, ce.Args[1]) - if err != nil { - ce.Reply("Couldn't fetch the Gravatar profile: %s", err.Error()) - return - } - state := ensureGravatarState(loginMetadata(client.UserLogin)) - state.Primary = profile - if err := client.UserLogin.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't save the Gravatar profile: %s", err.Error()) - return - } - ce.Reply(formatGravatarMarkdown(profile, "primary set")) - return - default: - ce.Reply("Usage: `!ai gravatar fetch ` or `!ai gravatar set `.") - } -} - -// CommandAgent handles the !ai agent command -var _ = registerAICommand(commandregistry.Definition{ - Name: "agent", - Description: "Get or set the agent for this chat", - Args: "[_agent id_]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnAgent, -}) - -func fnAgent(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - - store := NewAgentStoreAdapter(client) - - if len(ce.Args) == 0 { - // Show current agent - agentID := resolveAgentID(meta) - if agentID == "" { - ce.Reply("No agent configured. Using default model: %s", client.effectiveModel(meta)) - return - } - agent, err := store.GetAgentByID(ce.Ctx, agentID) - if err != nil { - ce.Reply("Current agent ID: %s (not found)", agentID) - return - } - displayName := client.resolveAgentDisplayName(ce.Ctx, agent) - if displayName == "" { - displayName = agent.ID - } - ce.Reply("Current agent: **%s** (`%s`)\n%s", displayName, agent.ID, agent.Description) - return - } - - if rejectBossOverrides(ce, meta, "Can't change the agent in a room managed by the Boss agent.") { - return - } - - // Set agent - agentID := ce.Args[0] - - // Special case: "none" clears the agent - if agentID == "none" || agentID == "clear" { - meta.AgentID = "" - meta.AgentPrompt = "" - modelID := client.effectiveModel(meta) - ce.Portal.OtherUserID = modelUserID(modelID) - client.savePortalQuiet(ce.Ctx, ce.Portal, "agent cleared") - _ = client.BroadcastRoomState(ce.Ctx, ce.Portal) - ce.Reply("Agent cleared. Using default model.") - return - } - - agent, err := store.GetAgentByID(ce.Ctx, agentID) - if err != nil { - ce.Reply("Agent not found: %s", agentID) - return - } - - meta.AgentID = agent.ID - meta.AgentPrompt = agent.SystemPrompt - meta.Model = "" - modelID := client.effectiveModel(meta) - meta.Capabilities = getModelCapabilities(modelID, client.findModelInfo(modelID)) - ce.Portal.OtherUserID = agentUserID(agent.ID) - client.savePortalQuiet(ce.Ctx, ce.Portal, "agent change") - agentName := client.resolveAgentDisplayName(ce.Ctx, agent) - client.ensureAgentGhostDisplayName(ce.Ctx, agent.ID, modelID, agentName) - _ = client.BroadcastRoomState(ce.Ctx, ce.Portal) - displayName := agentName - if displayName == "" { - displayName = agent.ID - } - ce.Reply("Agent set to **%s** (`%s`)", displayName, agent.ID) -} - -// CommandAgents handles the !ai agents command -var _ = registerAICommand(commandregistry.Definition{ - Name: "agents", - Description: "List available agents", - Section: HelpSectionAI, - RequiresLogin: true, - Handler: fnAgents, -}) - -func fnAgents(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - - store := NewAgentStoreAdapter(client) - agentsMap, err := store.LoadAgents(ce.Ctx) - if err != nil { - ce.Reply("Couldn't load agents: %v", err) - return - } - - var sb strings.Builder - sb.WriteString("## Available Agents\n\n") - - // Group by preset vs custom - var presets, custom []string - for id, agent := range agentsMap { - agentName := client.resolveAgentDisplayName(ce.Ctx, agent) - line := fmt.Sprintf("• **%s** (`%s`)", agentName, id) - if agent.Description != "" { - line += fmt.Sprintf(" - %s", agent.Description) - } - if agent.IsPreset { - presets = append(presets, line) - } else { - custom = append(custom, line) - } - } - - if len(presets) > 0 { - sb.WriteString("**Presets:**\n") - for _, line := range presets { - sb.WriteString(line + "\n") - } - sb.WriteString("\n") - } - - if len(custom) > 0 { - sb.WriteString("**Custom:**\n") - for _, line := range custom { - sb.WriteString(line + "\n") - } - sb.WriteString("\n") - } - - sb.WriteString("Use `!ai agent ` to switch agents") - ce.Reply(sb.String()) -} - -// CommandCreateAgent handles the !ai create-agent command -var _ = registerAICommand(commandregistry.Definition{ - Name: "create-agent", - Description: "Create a new custom agent", - Args: " [model] [system prompt...]", - Section: HelpSectionAI, - RequiresLogin: true, - Handler: fnCreateAgent, -}) - -func fnCreateAgent(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - - args := ce.Args - if raw := strings.TrimSpace(ce.RawArgs); raw != "" { - if parsed, err := splitQuotedArgs(raw); err == nil && len(parsed) > 0 { - args = parsed - } - } - - if len(args) < 2 { - ce.Reply("Usage: !ai create-agent [model] [system prompt...]\nExample: !ai create-agent my-helper \"My Helper\" gpt-4o You are a helpful assistant.") - return - } - - agentID := args[0] - agentName := args[1] - - if _, reserved := reservedAgentIDs[agentID]; reserved { - ce.Reply("Agent ID '%s' is reserved. Choose a different ID.", agentID) - return - } - if !isValidAgentID(agentID) { - ce.Reply("Invalid agent ID '%s'. Use only lowercase letters, numbers, and hyphens.", agentID) - return - } - - // Parse optional model and system prompt - var model, systemPrompt string - if len(args) > 2 { - model = args[2] - } - if len(args) > 3 { - systemPrompt = strings.Join(args[3:], " ") - } - - store := NewAgentStoreAdapter(client) - - // Check if agent already exists - if _, err := store.GetAgentByID(ce.Ctx, agentID); err == nil { - ce.Reply("Agent with ID '%s' already exists", agentID) - return - } - - // Create new agent - newAgent := &agents.AgentDefinition{ - ID: agentID, - Name: agentName, - SystemPrompt: systemPrompt, - Tools: &toolpolicy.ToolPolicyConfig{Profile: toolpolicy.ProfileFull}, - IsPreset: false, - CreatedAt: time.Now().Unix(), - UpdatedAt: time.Now().Unix(), - } - if model != "" { - newAgent.Model = agents.ModelConfig{Primary: model} - } - - if err := store.SaveAgent(ce.Ctx, newAgent); err != nil { - ce.Reply("Couldn't create the agent: %v", err) - return - } - - ce.Reply("Created agent: **%s** (`%s`)\nUse `!ai agent %s` to use it", agentName, agentID, agentID) -} - -// CommandDeleteAgent handles the !ai delete-agent command -var _ = registerAICommand(commandregistry.Definition{ - Name: "delete-agent", - Description: "Delete a custom agent", - Args: "", - Section: HelpSectionAI, - RequiresLogin: true, - Handler: fnDeleteAgent, -}) - -func fnDeleteAgent(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - - if len(ce.Args) < 1 { - ce.Reply("Usage: !ai delete-agent ") - return - } - - agentID := ce.Args[0] - store := NewAgentStoreAdapter(client) - - // Check if it's a preset - if agents.IsPreset(agentID) || agents.IsBossAgent(agentID) { - ce.Reply("Can't delete a preset agent: %s", agentID) - return - } - - if err := store.DeleteAgent(ce.Ctx, agentID); err != nil { - ce.Reply("Couldn't delete the agent: %v", err) - return - } - - ce.Reply("Deleted agent: %s", agentID) -} diff --git a/pkg/connector/commands_manage.go b/pkg/connector/commands_manage.go deleted file mode 100644 index c5350a8a..00000000 --- a/pkg/connector/commands_manage.go +++ /dev/null @@ -1,61 +0,0 @@ -package connector - -import ( - "maunium.net/go/mautrix/bridgev2/commands" - "maunium.net/go/mautrix/bridgev2/networkid" - - "github.com/beeper/ai-bridge/pkg/connector/commandregistry" -) - -// CommandManage handles the !ai manage command. -// This creates or opens the Builder room for advanced users to manage custom agents. -var _ = registerAICommand(commandregistry.Definition{ - Name: "manage", - Description: "Open the agent management room (for creating custom agents)", - Section: HelpSectionAI, - RequiresLogin: true, - Handler: fnManage, -}) - -func fnManage(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - - meta := loginMetadata(client.UserLogin) - - // Check if Builder room already exists - if meta.BuilderRoomID != "" { - portalKey := networkid.PortalKey{ - ID: meta.BuilderRoomID, - Receiver: client.UserLogin.ID, - } - portal, err := client.UserLogin.Bridge.GetPortalByKey(ce.Ctx, portalKey) - if err == nil && portal != nil && portal.MXID != "" { - ce.Reply("Agent management room: %s", portal.MXID) - return - } - // Room doesn't exist anymore, will create new one - } - - // Create Builder room on-demand - if err := client.ensureBuilderRoom(ce.Ctx); err != nil { - ce.Reply("Couldn't create the management room: %v", err) - return - } - - // Get the newly created room - meta = loginMetadata(client.UserLogin) - portalKey := networkid.PortalKey{ - ID: meta.BuilderRoomID, - Receiver: client.UserLogin.ID, - } - portal, err := client.UserLogin.Bridge.GetPortalByKey(ce.Ctx, portalKey) - if err != nil || portal == nil || portal.MXID == "" { - ce.Reply("Management room created, but the link isn't available.") - return - } - - ce.Reply("Created agent management room: %s\n\nIn this room you can:\n- Create custom agents\n- Manage existing agents\n- Configure advanced settings", portal.MXID) -} diff --git a/pkg/connector/commands_mcp.go b/pkg/connector/commands_mcp.go deleted file mode 100644 index 7fbd57c2..00000000 --- a/pkg/connector/commands_mcp.go +++ /dev/null @@ -1,472 +0,0 @@ -package connector - -import ( - "context" - "errors" - "fmt" - "strings" - "time" - - "maunium.net/go/mautrix/bridgev2/commands" - - "github.com/beeper/ai-bridge/pkg/connector/commandregistry" -) - -func mcpAddUsage(allowStdio bool) string { - if allowStdio { - return "`!ai mcp add [token] [authType] [authURL]` | `!ai mcp add streamable_http [token] [authType] [authURL]` | `!ai mcp add stdio [args...]`" - } - return "`!ai mcp add [token] [authType] [authURL]` | `!ai mcp add streamable_http [token] [authType] [authURL]`" -} - -func mcpManageUsage(allowStdio bool) string { - return fmt.Sprintf("`!ai mcp list` | %s | `!ai mcp connect [name] [token]` | `!ai mcp disconnect [name]` | `!ai mcp remove [name]`.", mcpAddUsage(allowStdio)) -} - -// CommandMCP handles the !ai mcp command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "mcp", - Description: "Manage MCP servers for this login", - Args: " [args]", - Section: HelpSectionAI, - RequiresLogin: true, - Handler: fnMCPCommand, -}) - -func fnMCPCommand(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - allowStdio := client.isMCPStdioEnabled() - - if len(ce.Args) == 0 { - ce.Reply("Usage: %s", mcpManageUsage(allowStdio)) - return - } - - sub := strings.ToLower(strings.TrimSpace(ce.Args[0])) - switch sub { - case "list": - fnMCPList(ce, client) - return - case "add": - fnMCPAdd(ce, client) - return - case "connect": - fnMCPConnect(ce, client) - return - case "disconnect": - fnMCPDisconnect(ce, client) - return - case "remove": - fnMCPRemove(ce, client) - return - default: - ce.Reply("Usage: %s", mcpManageUsage(allowStdio)) - } -} - -func fnMCPList(ce *commands.Event, client *AIClient) { - servers := client.configuredMCPServers() - if len(servers) == 0 { - ce.Reply("No MCP servers are set up yet. Run `!ai mcp add` to add one.") - return - } - - toolCounts := map[string]int{} - ctx, cancel := context.WithTimeout(ce.Ctx, 3*time.Second) - defer cancel() - defs, err := client.mcpToolDefinitions(ctx) - if err == nil { - for _, def := range defs { - name := client.cachedMCPServerForTool(def.Name) - if name == "" { - continue - } - toolCounts[name]++ - } - } - - lines := make([]string, 0, len(servers)) - for _, server := range servers { - cfg := normalizeMCPServerConfig(server.Config) - status := "disconnected" - if cfg.Connected { - status = "connected" - } - auth := cfg.AuthType - if auth == "" { - auth = "none" - } - token := "missing" - if cfg.Token != "" || cfg.AuthType == "none" { - token = "set" - } - line := fmt.Sprintf("- %s: %s (transport=%s, target=%s, auth=%s, token=%s)", server.Name, status, cfg.Transport, mcpServerTargetLabel(cfg), auth, token) - if count, ok := toolCounts[server.Name]; ok { - line = fmt.Sprintf("%s, tools=%d", line, count) - } - if server.Source == "config" { - line += " [from config]" - } - lines = append(lines, line) - } - ce.Reply("MCP servers:\n%s", strings.Join(lines, "\n")) -} - -func parseMCPHTTPAuthArgs(rest []string) (token, authType, authURL string) { - authType = "bearer" - if len(rest) > 0 { - token = strings.TrimSpace(rest[0]) - } - if len(rest) > 1 { - authType = strings.TrimSpace(rest[1]) - } - if len(rest) > 2 { - authURL = strings.TrimSpace(strings.Join(rest[2:], " ")) - } - return token, authType, authURL -} - -func parseMCPAddArgs(args []string, allowStdio bool) (name string, cfg MCPServerConfig, err error) { - trimmed := make([]string, 0, len(args)) - for _, raw := range args { - part := strings.TrimSpace(raw) - if part != "" { - trimmed = append(trimmed, part) - } - } - if len(trimmed) == 0 { - return "", MCPServerConfig{}, errors.New("missing args") - } - - if len(trimmed) < 2 { - return "", MCPServerConfig{}, errors.New("missing target") - } - name = normalizeMCPServerName(trimmed[0]) - targetIndex := 1 - - rawTransportOrTarget := strings.TrimSpace(trimmed[targetIndex]) - normalizedTransport := normalizeMCPServerTransport(rawTransportOrTarget) - if normalizedTransport == mcpTransportStdio { - if !allowStdio { - return "", MCPServerConfig{}, errors.New("stdio disabled") - } - if len(trimmed) <= targetIndex+1 { - return "", MCPServerConfig{}, errors.New("missing command") - } - cfg = normalizeMCPServerConfig(MCPServerConfig{ - Transport: mcpTransportStdio, - Command: strings.TrimSpace(trimmed[targetIndex+1]), - Args: trimmed[targetIndex+2:], - AuthType: "none", - Connected: false, - Kind: mcpServerKindGeneric, - }) - if cfg.Command == "" { - return "", MCPServerConfig{}, errors.New("missing command") - } - return name, cfg, nil - } - - endpoint := rawTransportOrTarget - rest := trimmed[targetIndex+1:] - if normalizedTransport == mcpTransportStreamableHTTP { - if len(trimmed) <= targetIndex+1 { - return "", MCPServerConfig{}, errors.New("missing endpoint") - } - endpoint = strings.TrimSpace(trimmed[targetIndex+1]) - rest = trimmed[targetIndex+2:] - } - if !isLikelyHTTPURL(endpoint) { - return "", MCPServerConfig{}, errors.New("invalid endpoint") - } - token, authType, authURL := parseMCPHTTPAuthArgs(rest) - cfg = normalizeMCPServerConfig(MCPServerConfig{ - Transport: mcpTransportStreamableHTTP, - Endpoint: endpoint, - Token: token, - AuthType: authType, - AuthURL: authURL, - Connected: false, - Kind: mcpServerKindGeneric, - }) - return name, cfg, nil -} - -func fnMCPAdd(ce *commands.Event, client *AIClient) { - login := client.UserLogin - if login == nil { - ce.Reply("You're not signed in. Sign in and try again.") - return - } - meta := loginMetadata(login) - if meta == nil { - ce.Reply("Couldn't load your settings. Try again.") - return - } - - allowStdio := client.isMCPStdioEnabled() - name, cfg, err := parseMCPAddArgs(ce.Args[1:], allowStdio) - if err != nil { - if err.Error() == "stdio disabled" { - ce.Reply("Stdio MCP servers are disabled by the bridge configuration.") - return - } - ce.Reply("Usage: %s", mcpAddUsage(allowStdio)) - return - } - - if meta.ServiceTokens == nil { - meta.ServiceTokens = &ServiceTokens{} - } - if meta.ServiceTokens.MCPServers == nil { - meta.ServiceTokens.MCPServers = map[string]MCPServerConfig{} - } - meta.ServiceTokens.MCPServers[name] = cfg - if err := login.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't save the MCP server: %s", err) - return - } - client.invalidateMCPToolCache() - - ce.Reply("Saved MCP server '%s' (%s). Connect with `!ai mcp connect %s`.", name, mcpServerTargetLabel(cfg), name) -} - -func resolveMCPServerArg(client *AIClient, args []string) (namedMCPServer, string, error) { - servers := client.configuredMCPServers() - if len(servers) == 0 { - return namedMCPServer{}, "", errors.New("none configured") - } - - if len(args) == 0 { - if len(servers) == 1 { - return servers[0], "", nil - } - return namedMCPServer{}, "", errors.New("ambiguous") - } - - candidate := strings.TrimSpace(args[0]) - for _, server := range servers { - if server.Name == normalizeMCPServerName(candidate) { - token := "" - if len(args) > 1 { - token = strings.TrimSpace(strings.Join(args[1:], " ")) - } - return server, token, nil - } - } - return namedMCPServer{}, "", errors.New("not found") -} - -func sendMCPAuthURLNotice(client *AIClient, ce *commands.Event, server namedMCPServer) { - if strings.TrimSpace(server.Config.AuthURL) == "" { - return - } - message := fmt.Sprintf("Sign in to MCP server '%s': %s", server.Name, server.Config.AuthURL) - if ce != nil && ce.Portal != nil { - client.sendSystemNotice(ce.Ctx, ce.Portal, message) - return - } - if ce != nil { - ce.Reply(message) - } -} - -func (oc *AIClient) verifyMCPServerConnection(ctx context.Context, server namedMCPServer) (int, error) { - if ctx == nil { - ctx = context.Background() - } - callCtx := ctx - var cancel context.CancelFunc - if _, hasDeadline := callCtx.Deadline(); !hasDeadline { - timeout := oc.mcpRequestTimeout() - if timeout > 10*time.Second { - timeout = 10 * time.Second - } - callCtx, cancel = context.WithTimeout(ctx, timeout) - } - if cancel != nil { - defer cancel() - } - defs, err := oc.fetchMCPToolsForServer(callCtx, server) - if err != nil { - return 0, err - } - return len(defs), nil -} - -func setLoginMCPServer(meta *UserLoginMetadata, name string, cfg MCPServerConfig) { - if meta.ServiceTokens == nil { - meta.ServiceTokens = &ServiceTokens{} - } - if meta.ServiceTokens.MCPServers == nil { - meta.ServiceTokens.MCPServers = map[string]MCPServerConfig{} - } - meta.ServiceTokens.MCPServers[name] = normalizeMCPServerConfig(cfg) -} - -func clearLoginMCPServer(meta *UserLoginMetadata, name string) { - if meta == nil || meta.ServiceTokens == nil || meta.ServiceTokens.MCPServers == nil { - return - } - delete(meta.ServiceTokens.MCPServers, name) - if len(meta.ServiceTokens.MCPServers) == 0 { - meta.ServiceTokens.MCPServers = nil - } - if serviceTokensEmpty(meta.ServiceTokens) { - meta.ServiceTokens = nil - } -} - -func fnMCPConnect(ce *commands.Event, client *AIClient) { - login := client.UserLogin - if login == nil { - ce.Reply("You're not signed in. Sign in and try again.") - return - } - meta := loginMetadata(login) - if meta == nil { - ce.Reply("Couldn't load your settings. Try again.") - return - } - - target, tokenOverride, err := resolveMCPServerArg(client, ce.Args[1:]) - if err != nil { - switch err.Error() { - case "none configured": - ce.Reply("No MCP servers are set up yet. Run `!ai mcp add` first.") - case "ambiguous": - ce.Reply("Multiple MCP servers are set up. Include a server name, or run `!ai mcp list`.") - default: - ce.Reply("Couldn't find that MCP server. Run `!ai mcp list`.") - } - return - } - - cfg := normalizeMCPServerConfig(target.Config) - if tokenOverride != "" && !mcpServerUsesStdio(cfg) { - cfg.Token = strings.TrimSpace(tokenOverride) - if cfg.Token != "" && cfg.AuthType == "none" { - cfg.AuthType = "bearer" - } - } - if !mcpServerHasTarget(cfg) { - ce.Reply("MCP server '%s' isn't configured with a target.", target.Name) - return - } - if mcpServerNeedsToken(cfg) && cfg.Token == "" { - cfg.Connected = false - setLoginMCPServer(meta, target.Name, cfg) - if saveErr := login.Save(ce.Ctx); saveErr != nil { - ce.Reply("Couldn't update MCP server '%s': %s", target.Name, saveErr) - return - } - client.invalidateMCPToolCache() - sendMCPAuthURLNotice(client, ce, namedMCPServer{Name: target.Name, Config: cfg, Source: "login"}) - ce.Reply("MCP server '%s' needs a token. Add one: `!ai mcp connect %s `.", target.Name, target.Name) - return - } - - cfg.Connected = true - count, connectErr := client.verifyMCPServerConnection(ce.Ctx, namedMCPServer{Name: target.Name, Config: cfg, Source: "login"}) - if connectErr != nil { - cfg.Connected = false - setLoginMCPServer(meta, target.Name, cfg) - if saveErr := login.Save(ce.Ctx); saveErr != nil { - ce.Reply("Couldn't save MCP server '%s': %s", target.Name, saveErr) - return - } - client.invalidateMCPToolCache() - if mcpCallLikelyAuthError(connectErr) { - sendMCPAuthURLNotice(client, ce, namedMCPServer{Name: target.Name, Config: cfg, Source: "login"}) - } - ce.Reply("Couldn't connect to MCP server '%s': %v", target.Name, connectErr) - return - } - - setLoginMCPServer(meta, target.Name, cfg) - if err := login.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't save MCP server '%s': %s", target.Name, err) - return - } - client.invalidateMCPToolCache() - ce.Reply("Connected to MCP server '%s' (%d tools found).", target.Name, count) -} - -func fnMCPDisconnect(ce *commands.Event, client *AIClient) { - login := client.UserLogin - if login == nil { - ce.Reply("You're not signed in. Sign in and try again.") - return - } - meta := loginMetadata(login) - if meta == nil { - ce.Reply("Couldn't load your settings. Try again.") - return - } - - target, _, err := resolveMCPServerArg(client, ce.Args[1:]) - if err != nil { - switch err.Error() { - case "none configured": - ce.Reply("No MCP servers are set up yet.") - case "ambiguous": - ce.Reply("Multiple MCP servers are set up. Include a server name, or run `!ai mcp list`.") - default: - ce.Reply("Couldn't find that MCP server. Run `!ai mcp list`.") - } - return - } - - cfg := normalizeMCPServerConfig(target.Config) - cfg.Connected = false - setLoginMCPServer(meta, target.Name, cfg) - if err := login.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't disconnect MCP server '%s': %s", target.Name, err) - return - } - client.invalidateMCPToolCache() - ce.Reply("Disconnected from MCP server '%s'.", target.Name) -} - -func fnMCPRemove(ce *commands.Event, client *AIClient) { - login := client.UserLogin - if login == nil { - ce.Reply("You're not signed in. Sign in and try again.") - return - } - meta := loginMetadata(login) - if meta == nil { - ce.Reply("Couldn't load your settings. Try again.") - return - } - - target, _, err := resolveMCPServerArg(client, ce.Args[1:]) - if err != nil { - switch err.Error() { - case "none configured": - ce.Reply("No MCP servers are set up yet.") - case "ambiguous": - ce.Reply("Multiple MCP servers are set up. Include a server name, or run `!ai mcp list`.") - default: - ce.Reply("Couldn't find that MCP server. Run `!ai mcp list`.") - } - return - } - - loginServers := client.loginMCPServers() - if _, ok := loginServers[target.Name]; !ok { - ce.Reply("MCP server '%s' is managed by the bridge configuration and can't be removed here. To override it for this login, run `!ai mcp disconnect %s`.", target.Name, target.Name) - return - } - - clearLoginMCPServer(meta, target.Name) - if err := login.Save(ce.Ctx); err != nil { - ce.Reply("Couldn't remove MCP server '%s': %s", target.Name, err) - return - } - client.invalidateMCPToolCache() - ce.Reply("Removed MCP server '%s'.", target.Name) -} diff --git a/pkg/connector/commands_parity.go b/pkg/connector/commands_parity.go index d8bfee85..bd41e568 100644 --- a/pkg/connector/commands_parity.go +++ b/pkg/connector/commands_parity.go @@ -1,19 +1,14 @@ package connector import ( - "encoding/json" - "fmt" - "strings" "time" "maunium.net/go/mautrix/bridgev2/commands" "github.com/beeper/ai-bridge/pkg/connector/commandregistry" airuntime "github.com/beeper/ai-bridge/pkg/runtime" - "github.com/beeper/ai-bridge/pkg/shared/stringutil" ) -// CommandStatus handles the !ai status command. var _ = registerAICommand(commandregistry.Definition{ Name: "status", Description: "Show current session status", @@ -23,16 +18,6 @@ var _ = registerAICommand(commandregistry.Definition{ Handler: fnStatus, }) -// CommandLastHeartbeat handles the !ai last-heartbeat command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "last-heartbeat", - Description: "Show the last heartbeat event for this login", - Section: HelpSectionAI, - RequiresPortal: false, - RequiresLogin: true, - Handler: fnLastHeartbeat, -}) - func fnStatus(ce *commands.Event) { client, meta, ok := requireClientMeta(ce) if !ok { @@ -43,32 +28,6 @@ func fnStatus(ce *commands.Event) { ce.Reply("%s", client.buildStatusText(ce.Ctx, ce.Portal, meta, isGroup, queueSettings)) } -func fnLastHeartbeat(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - evt := getLastHeartbeatEventForLogin(client.UserLogin) - if evt == nil { - ce.Reply("No heartbeat yet.") - return - } - pretty, err := json.MarshalIndent(evt, "", " ") - if err != nil { - ce.Reply("Failed to serialize last heartbeat: %s", err.Error()) - return - } - // Keep replies bounded; fall back to compact JSON if needed. - if len(pretty) > 8000 { - compact, err2 := json.Marshal(evt) - if err2 == nil { - pretty = compact - } - } - ce.Reply("```json\n%s\n```", string(pretty)) -} - -// CommandReset handles the !ai reset command. var _ = registerAICommand(commandregistry.Definition{ Name: "reset", Description: "Start a new session/thread in this room", @@ -85,8 +44,6 @@ func fnReset(ce *commands.Event) { } meta.SessionResetAt = time.Now().UnixMilli() - meta.GroupIntroSent = false - meta.GroupActivationNeedsIntro = true client.savePortalQuiet(ce.Ctx, ce.Portal, "session reset") client.clearPendingQueue(ce.Portal.MXID) client.cancelRoomRun(ce.Portal.MXID) @@ -94,7 +51,6 @@ func fnReset(ce *commands.Event) { ce.Reply("%s", formatSystemAck("Session reset.")) } -// CommandStop handles the !ai stop command. var _ = registerAICommand(commandregistry.Definition{ Name: "stop", Description: "Abort the current run and clear the pending queue", @@ -112,317 +68,3 @@ func fnStop(ce *commands.Event) { stopped := client.abortRoom(ce.Ctx, ce.Portal, meta) ce.Reply("%s", formatAbortNotice(stopped)) } - -// CommandQueue handles the !ai queue command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "queue", - Description: "Inspect or configure the message queue", - Args: "[status|reset|] [debounce:] [cap:] [drop:]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnQueue, -}) - -func fnQueue(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - portal := ce.Portal - - queueSettings, _, storeRef, sessionKey := client.resolveQueueSettingsForPortal(ce.Ctx, portal, meta, "", airuntime.QueueInlineOptions{}) - - if len(ce.Args) == 0 || strings.EqualFold(strings.TrimSpace(ce.Args[0]), "status") { - ce.Reply("%s", buildQueueStatusLine(queueSettings)) - return - } - - if strings.EqualFold(strings.TrimSpace(ce.Args[0]), "reset") { - if sessionKey != "" { - client.updateSessionEntry(ce.Ctx, storeRef, sessionKey, func(entry sessionEntry) sessionEntry { - entry.QueueMode = "" - entry.QueueDebounceMs = nil - entry.QueueCap = nil - entry.QueueDrop = "" - entry.UpdatedAt = time.Now().UnixMilli() - return entry - }) - } - client.clearPendingQueue(portal.MXID) - queueSettings, _, _, _ = client.resolveQueueSettingsForPortal(ce.Ctx, portal, meta, "", airuntime.QueueInlineOptions{}) - ce.Reply("%s", buildQueueStatusLine(queueSettings)) - return - } - - raw := strings.TrimSpace(strings.Join(ce.Args, " ")) - _, directive := parseQueueDirectiveArgs(raw) - if directive.HasDebounce && directive.DebounceMs == nil { - ce.Reply("Invalid debounce \"%s\". Use ms/s/m (e.g. debounce:1500ms, debounce:2s).", directive.RawDebounce) - return - } - if directive.HasCap && directive.Cap == nil { - ce.Reply("Invalid cap \"%s\". Use a positive integer (e.g. cap:10).", directive.RawCap) - return - } - if directive.HasDrop && directive.DropPolicy == nil { - ce.Reply("Invalid drop policy \"%s\". Use drop:old, drop:new, or drop:summarize.", directive.RawDrop) - return - } - if directive.QueueMode == "" && !directive.HasOptions { - ce.Reply("Usage: `!ai queue [status|reset|] [debounce:] [cap:] [drop:]`") - return - } - - if sessionKey != "" { - client.updateSessionEntry(ce.Ctx, storeRef, sessionKey, func(entry sessionEntry) sessionEntry { - if directive.QueueMode != "" { - entry.QueueMode = string(directive.QueueMode) - } - if directive.DebounceMs != nil { - entry.QueueDebounceMs = directive.DebounceMs - } - if directive.Cap != nil { - entry.QueueCap = directive.Cap - } - if directive.DropPolicy != nil { - entry.QueueDrop = string(*directive.DropPolicy) - } - entry.UpdatedAt = time.Now().UnixMilli() - return entry - }) - } - - queueSettings, _, _, _ = client.resolveQueueSettingsForPortal(ce.Ctx, portal, meta, "", airuntime.QueueInlineOptions{}) - ce.Reply("%s", buildQueueStatusLine(queueSettings)) -} - -// CommandThink handles the !ai think command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "think", - Description: "Get or set thinking level (off|minimal|low|medium|high|xhigh)", - Args: "[level]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnThink, -}) - -func fnThink(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - if len(ce.Args) == 0 { - ce.Reply("Thinking: %s", client.defaultThinkLevel(meta)) - return - } - level, ok := stringutil.NormalizeEnum(ce.Args[0], thinkLevelAliases) - if !ok { - ce.Reply("Usage: `!ai think off|minimal|low|medium|high|xhigh`") - return - } - applyThinkingLevel(meta, level) - client.savePortalQuiet(ce.Ctx, ce.Portal, "think change") - ce.Reply("%s", formatThinkingAck(level)) -} - -// CommandVerbose handles the !ai verbose command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "verbose", - Description: "Get or set verbosity (off|on|full)", - Args: "[level]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnVerbose, -}) - -func fnVerbose(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - if len(ce.Args) == 0 { - current := meta.VerboseLevel - if current == "" { - current = "off" - } - ce.Reply("Verbosity: %s", current) - return - } - level, ok := stringutil.NormalizeEnum(ce.Args[0], verboseLevelAliases) - if !ok { - ce.Reply("Usage: `!ai verbose on|off|full`") - return - } - meta.VerboseLevel = level - client.savePortalQuiet(ce.Ctx, ce.Portal, "verbose change") - ce.Reply("%s", formatVerboseAck(level)) -} - -// CommandReasoning handles the !ai reasoning command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "reasoning", - Description: "Get or set reasoning visibility/effort (off|on|low|medium|high|xhigh)", - Args: "[level]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnReasoning, -}) - -func fnReasoning(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - if len(ce.Args) == 0 { - current := strings.TrimSpace(meta.ReasoningEffort) - if current == "" { - if meta.EmitThinking { - current = "on" - } else { - current = "off" - } - } - ce.Reply("Reasoning: %s", current) - return - } - level, ok := stringutil.NormalizeEnum(ce.Args[0], reasoningLevelAliases) - if !ok { - ce.Reply("Usage: `!ai reasoning off|on|low|medium|high|xhigh`") - return - } - applyReasoningLevel(meta, level) - client.savePortalQuiet(ce.Ctx, ce.Portal, "reasoning change") - ce.Reply("%s", formatReasoningAck(level)) -} - -// CommandElevated handles the !ai elevated command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "elevated", - Description: "Get or set elevated access (off|on|ask|full)", - Args: "[level]", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnElevated, -}) - -func fnElevated(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - if len(ce.Args) == 0 { - current := meta.ElevatedLevel - if current == "" { - current = "off" - } - ce.Reply("Elevated access: %s", current) - return - } - level, ok := stringutil.NormalizeElevatedLevel(ce.Args[0]) - if !ok { - ce.Reply("Usage: `!ai elevated off|on|ask|full`") - return - } - meta.ElevatedLevel = level - client.savePortalQuiet(ce.Ctx, ce.Portal, "elevated change") - ce.Reply("%s", formatElevatedAck(level)) -} - -// CommandActivation handles the !ai activation command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "activation", - Description: "Set group activation policy (mention|always)", - Args: "", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnActivation, -}) - -func fnActivation(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - isGroup := client.isGroupChat(ce.Ctx, ce.Portal) - if !isGroup { - ce.Reply("%s", formatSystemAck("Group activation only applies to group chats.")) - return - } - if len(ce.Args) == 0 { - ce.Reply("%s", formatSystemAck("Usage: `!ai activation mention|always`")) - return - } - level, ok := stringutil.NormalizeEnum(ce.Args[0], groupActivationAliases) - if !ok { - ce.Reply("%s", formatSystemAck("Usage: `!ai activation mention|always`")) - return - } - meta.GroupActivation = level - meta.GroupActivationNeedsIntro = true - meta.GroupIntroSent = false - client.savePortalQuiet(ce.Ctx, ce.Portal, "activation change") - ce.Reply("%s", formatSystemAck(fmt.Sprintf("Group activation set to %s.", level))) -} - -// CommandSend handles the !ai send command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "send", - Description: "Allow/deny sending messages (on|off|inherit)", - Args: "", - Section: HelpSectionAI, - RequiresPortal: true, - RequiresLogin: true, - Handler: fnSend, -}) - -func fnSend(ce *commands.Event) { - client, meta, ok := requireClientMeta(ce) - if !ok { - return - } - if len(ce.Args) == 0 { - ce.Reply("%s", formatSystemAck("Usage: `!ai send on|off|inherit`")) - return - } - mode, ok := stringutil.NormalizeEnum(ce.Args[0], sendPolicyAliases) - if !ok { - ce.Reply("%s", formatSystemAck("Usage: `!ai send on|off|inherit`")) - return - } - if mode == "inherit" { - meta.SendPolicy = "" - } else { - meta.SendPolicy = mode - } - client.savePortalQuiet(ce.Ctx, ce.Portal, "send policy change") - label := mode - if mode == "allow" { - label = "on" - } else if mode == "deny" { - label = "off" - } - ce.Reply("%s", formatSystemAck(fmt.Sprintf("Send policy set to %s.", label))) -} - -// CommandWhoami handles the !ai whoami command. -var _ = registerAICommand(commandregistry.Definition{ - Name: "whoami", - Description: "Show your Matrix user ID", - Section: HelpSectionAI, - RequiresPortal: false, - RequiresLogin: false, - Handler: fnWhoami, -}) - -func fnWhoami(ce *commands.Event) { - if ce == nil || ce.User == nil { - return - } - ce.Reply("You are %s.", ce.User.MXID.String()) -} diff --git a/pkg/connector/commands_simple.go b/pkg/connector/commands_simple.go deleted file mode 100644 index cf7377bb..00000000 --- a/pkg/connector/commands_simple.go +++ /dev/null @@ -1,86 +0,0 @@ -package connector - -import ( - "fmt" - "strings" - - "maunium.net/go/mautrix/bridgev2/commands" - - "github.com/beeper/ai-bridge/pkg/connector/commandregistry" -) - -// CommandSimple handles the !ai simple command with sub-commands. -var _ = registerAICommand(commandregistry.Definition{ - Name: "simple", - Description: "Manage AI chat rooms (new, list)", - Args: "", - Section: HelpSectionAI, - RequiresLogin: true, - Handler: fnSimple, -}) - -func fnSimple(ce *commands.Event) { - client, ok := requireClient(ce) - if !ok { - return - } - - subCmd := "" - if len(ce.Args) > 0 { - subCmd = strings.ToLower(ce.Args[0]) - } - - switch subCmd { - case "new": - var modelID string - if len(ce.Args) > 1 { - resolved, valid, err := client.resolveModelID(ce.Ctx, ce.Args[1]) - if err != nil || !valid || resolved == "" { - ce.Reply("That model isn't available: %s", ce.Args[1]) - return - } - modelID = resolved - } else { - modelID = client.effectiveModel(nil) - } - go client.createAndOpenSimpleChat(ce.Ctx, ce.Portal, modelID) - ce.Reply("Creating AI chat with %s...", modelID) - - case "list": - models, err := client.listAvailableModels(ce.Ctx, false) - if err != nil { - ce.Reply("Couldn't load models.") - return - } - var sb strings.Builder - sb.WriteString("Available models:\n\n") - for _, m := range models { - var caps []string - if m.SupportsVision { - caps = append(caps, "Vision") - } - if m.SupportsReasoning { - caps = append(caps, "Reasoning") - } - if m.SupportsWebSearch { - caps = append(caps, "Web Search") - } - if m.SupportsImageGen { - caps = append(caps, "Image Gen") - } - if m.SupportsToolCalling { - caps = append(caps, "Tools") - } - sb.WriteString(fmt.Sprintf("• **%s** (`%s`)\n", m.Name, m.ID)) - if len(caps) > 0 { - sb.WriteString(fmt.Sprintf(" %s\n", strings.Join(caps, " · "))) - } - sb.WriteString("\n") - } - sb.WriteString("Use `!ai simple new [model]` to create a chat") - ce.Reply(sb.String()) - - default: - ce.Reply("Usage:\n• `!ai simple new [model]` — Create a new AI chat\n• `!ai simple list` — List available models") - } -} diff --git a/pkg/connector/desktop_api_helpers.go b/pkg/connector/desktop_api_helpers.go new file mode 100644 index 00000000..a3cb81ac --- /dev/null +++ b/pkg/connector/desktop_api_helpers.go @@ -0,0 +1,36 @@ +package connector + +import ( + "errors" + "strings" +) + +func parseDesktopAPIAddArgs(args []string) (name, token, baseURL string, err error) { + if len(args) == 0 { + return "", "", "", errors.New("missing args") + } + + trimmed := make([]string, 0, len(args)) + for _, raw := range args { + part := strings.TrimSpace(raw) + if part != "" { + trimmed = append(trimmed, part) + } + } + if len(trimmed) == 0 { + return "", "", "", errors.New("missing args") + } + + if len(trimmed) == 1 { + return "", trimmed[0], "", nil + } + + if len(trimmed) == 2 { + if isLikelyHTTPURL(trimmed[1]) { + return "", trimmed[0], trimmed[1], nil + } + return normalizeDesktopInstanceName(trimmed[0]), trimmed[1], "", nil + } + + return normalizeDesktopInstanceName(trimmed[0]), trimmed[1], strings.TrimSpace(strings.Join(trimmed[2:], " ")), nil +} diff --git a/pkg/connector/group_activation.go b/pkg/connector/group_activation.go index 822b3c60..ac6827a6 100644 --- a/pkg/connector/group_activation.go +++ b/pkg/connector/group_activation.go @@ -1,17 +1,9 @@ package connector -import ( - "strings" - - "github.com/beeper/ai-bridge/pkg/shared/stringutil" -) +import "github.com/beeper/ai-bridge/pkg/shared/stringutil" func (oc *AIClient) resolveGroupActivation(meta *PortalMetadata) string { - if meta != nil { - if normalized, ok := stringutil.NormalizeEnum(meta.GroupActivation, groupActivationAliases); ok { - return normalized - } - } + _ = meta if oc != nil && oc.connector != nil && oc.connector.Config.Messages != nil && oc.connector.Config.Messages.GroupChat != nil { if normalized, ok := stringutil.NormalizeEnum(oc.connector.Config.Messages.GroupChat.Activation, groupActivationAliases); ok { return normalized @@ -21,12 +13,6 @@ func (oc *AIClient) resolveGroupActivation(meta *PortalMetadata) string { } func normalizeSendPolicyMode(raw string) string { - value := strings.ToLower(strings.TrimSpace(raw)) - if value == "deny" || value == "off" { - return "deny" - } - if value == "allow" || value == "on" { - return "allow" - } + _ = raw return "" } diff --git a/pkg/connector/handleai.go b/pkg/connector/handleai.go index e4c2ab3c..c5c9dbdf 100644 --- a/pkg/connector/handleai.go +++ b/pkg/connector/handleai.go @@ -223,15 +223,9 @@ func isInternalControlRoom(meta *PortalMetadata) bool { } func autoGreetingBlockReason(meta *PortalMetadata) string { - sendPolicy := "" - if meta != nil { - sendPolicy = meta.SendPolicy - } switch { case isInternalControlRoom(meta): return "internal-control-room" - case normalizeSendPolicyMode(sendPolicy) == "deny": - return "send-policy-deny" case resolveAgentID(meta) == "": return "no-agent" } diff --git a/pkg/connector/heartbeat_execute.go b/pkg/connector/heartbeat_execute.go index f2f76dcf..36661817 100644 --- a/pkg/connector/heartbeat_execute.go +++ b/pkg/connector/heartbeat_execute.go @@ -148,12 +148,6 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, if promptMeta == nil { promptMeta = &PortalMetadata{} } - promptMeta.AgentID = agentID - if heartbeat != nil && heartbeat.Model != nil { - if model := strings.TrimSpace(*heartbeat.Model); model != "" { - promptMeta.Model = model - } - } responsePrefix := resolveResponsePrefixForHeartbeat(oc, cfg, agentID, promptMeta) hbCfg := &HeartbeatRunConfig{ Reason: reason, diff --git a/pkg/connector/integration_host.go b/pkg/connector/integration_host.go index a41da4f5..aadaa9f1 100644 --- a/pkg/connector/integration_host.go +++ b/pkg/connector/integration_host.go @@ -280,7 +280,7 @@ func (h *runtimeIntegrationHost) IsInternalRoom(meta any) bool { if m == nil { return false } - return m.IsBuilderRoom || isModuleInternalRoom(m) + return isModuleInternalRoom(m) } func (h *runtimeIntegrationHost) PortalMeta(portal any) any { diff --git a/pkg/connector/mcp_helpers.go b/pkg/connector/mcp_helpers.go new file mode 100644 index 00000000..e1160f64 --- /dev/null +++ b/pkg/connector/mcp_helpers.go @@ -0,0 +1,181 @@ +package connector + +import ( + "context" + "errors" + "fmt" + "net/url" + "strings" + "time" +) + +func mcpAddUsage(allowStdio bool) string { + if allowStdio { + return "`!ai mcp add [token] [authType] [authURL]` | `!ai mcp add streamable_http [token] [authType] [authURL]` | `!ai mcp add stdio [args...]`" + } + return "`!ai mcp add [token] [authType] [authURL]` | `!ai mcp add streamable_http [token] [authType] [authURL]`" +} + +func mcpManageUsage(allowStdio bool) string { + return fmt.Sprintf("`!ai mcp list` | %s | `!ai mcp connect [name] [token]` | `!ai mcp disconnect [name]` | `!ai mcp remove [name]`.", mcpAddUsage(allowStdio)) +} + +func isLikelyHTTPURL(raw string) bool { + parsed, err := url.Parse(strings.TrimSpace(raw)) + if err != nil || parsed == nil { + return false + } + return parsed.Scheme == "http" || parsed.Scheme == "https" +} + +func parseMCPHTTPAuthArgs(rest []string) (token, authType, authURL string) { + authType = "bearer" + if len(rest) > 0 { + token = strings.TrimSpace(rest[0]) + } + if len(rest) > 1 { + authType = strings.TrimSpace(rest[1]) + } + if len(rest) > 2 { + authURL = strings.TrimSpace(strings.Join(rest[2:], " ")) + } + return token, authType, authURL +} + +func parseMCPAddArgs(args []string, allowStdio bool) (name string, cfg MCPServerConfig, err error) { + trimmed := make([]string, 0, len(args)) + for _, raw := range args { + part := strings.TrimSpace(raw) + if part != "" { + trimmed = append(trimmed, part) + } + } + if len(trimmed) == 0 { + return "", MCPServerConfig{}, errors.New("missing args") + } + + if len(trimmed) < 2 { + return "", MCPServerConfig{}, errors.New("missing target") + } + name = normalizeMCPServerName(trimmed[0]) + targetIndex := 1 + + rawTransportOrTarget := strings.TrimSpace(trimmed[targetIndex]) + normalizedTransport := normalizeMCPServerTransport(rawTransportOrTarget) + if normalizedTransport == mcpTransportStdio { + if !allowStdio { + return "", MCPServerConfig{}, errors.New("stdio disabled") + } + if len(trimmed) <= targetIndex+1 { + return "", MCPServerConfig{}, errors.New("missing command") + } + cfg = normalizeMCPServerConfig(MCPServerConfig{ + Transport: mcpTransportStdio, + Command: strings.TrimSpace(trimmed[targetIndex+1]), + Args: trimmed[targetIndex+2:], + AuthType: "none", + Connected: false, + Kind: mcpServerKindGeneric, + }) + if cfg.Command == "" { + return "", MCPServerConfig{}, errors.New("missing command") + } + return name, cfg, nil + } + + endpoint := rawTransportOrTarget + rest := trimmed[targetIndex+1:] + if normalizedTransport == mcpTransportStreamableHTTP { + if len(trimmed) <= targetIndex+1 { + return "", MCPServerConfig{}, errors.New("missing endpoint") + } + endpoint = strings.TrimSpace(trimmed[targetIndex+1]) + rest = trimmed[targetIndex+2:] + } + if !isLikelyHTTPURL(endpoint) { + return "", MCPServerConfig{}, errors.New("invalid endpoint") + } + token, authType, authURL := parseMCPHTTPAuthArgs(rest) + cfg = normalizeMCPServerConfig(MCPServerConfig{ + Transport: mcpTransportStreamableHTTP, + Endpoint: endpoint, + Token: token, + AuthType: authType, + AuthURL: authURL, + Connected: false, + Kind: mcpServerKindGeneric, + }) + return name, cfg, nil +} + +func resolveMCPServerArg(client *AIClient, args []string) (namedMCPServer, string, error) { + servers := client.configuredMCPServers() + if len(servers) == 0 { + return namedMCPServer{}, "", errors.New("none configured") + } + + if len(args) == 0 { + if len(servers) == 1 { + return servers[0], "", nil + } + return namedMCPServer{}, "", errors.New("ambiguous") + } + + candidate := strings.TrimSpace(args[0]) + for _, server := range servers { + if server.Name == normalizeMCPServerName(candidate) { + token := "" + if len(args) > 1 { + token = strings.TrimSpace(strings.Join(args[1:], " ")) + } + return server, token, nil + } + } + return namedMCPServer{}, "", errors.New("not found") +} + +func (oc *AIClient) verifyMCPServerConnection(ctx context.Context, server namedMCPServer) (int, error) { + if ctx == nil { + ctx = context.Background() + } + callCtx := ctx + var cancel context.CancelFunc + if _, hasDeadline := callCtx.Deadline(); !hasDeadline { + timeout := oc.mcpRequestTimeout() + if timeout > 10*time.Second { + timeout = 10 * time.Second + } + callCtx, cancel = context.WithTimeout(ctx, timeout) + } + if cancel != nil { + defer cancel() + } + defs, err := oc.fetchMCPToolsForServer(callCtx, server) + if err != nil { + return 0, err + } + return len(defs), nil +} + +func setLoginMCPServer(meta *UserLoginMetadata, name string, cfg MCPServerConfig) { + if meta.ServiceTokens == nil { + meta.ServiceTokens = &ServiceTokens{} + } + if meta.ServiceTokens.MCPServers == nil { + meta.ServiceTokens.MCPServers = map[string]MCPServerConfig{} + } + meta.ServiceTokens.MCPServers[name] = normalizeMCPServerConfig(cfg) +} + +func clearLoginMCPServer(meta *UserLoginMetadata, name string) { + if meta == nil || meta.ServiceTokens == nil || meta.ServiceTokens.MCPServers == nil { + return + } + delete(meta.ServiceTokens.MCPServers, name) + if len(meta.ServiceTokens.MCPServers) == 0 { + meta.ServiceTokens.MCPServers = nil + } + if serviceTokensEmpty(meta.ServiceTokens) { + meta.ServiceTokens = nil + } +} diff --git a/pkg/connector/metadata.go b/pkg/connector/metadata.go index 9ffa0bdb..9d907dfe 100644 --- a/pkg/connector/metadata.go +++ b/pkg/connector/metadata.go @@ -226,6 +226,7 @@ type PortalMetadata struct { DisabledTools []string `json:"-"` ResolvedTarget *ResolvedTarget `json:"-"` RuntimeModelOverride string `json:"-"` + RuntimeReasoning string `json:"-"` // Legacy-only. New code must not read this. ResponsePrefix string `json:"response_prefix,omitempty"` diff --git a/pkg/connector/reasoning_fallback.go b/pkg/connector/reasoning_fallback.go index fd3f63c2..7f58a294 100644 --- a/pkg/connector/reasoning_fallback.go +++ b/pkg/connector/reasoning_fallback.go @@ -34,7 +34,7 @@ func (oc *AIClient) responseWithRetryAndReasoningFallback( if meta != nil && currentLevel != originalLevel { // Clone meta and override reasoning effort metaCopy := *meta - metaCopy.ReasoningEffort = currentLevel + metaCopy.RuntimeReasoning = currentLevel effectiveMeta = &metaCopy oc.loggerForContext(ctx).Info(). Str("original_level", originalLevel). diff --git a/pkg/connector/remote_message.go b/pkg/connector/remote_message.go index 8fb74ddf..7892473c 100644 --- a/pkg/connector/remote_message.go +++ b/pkg/connector/remote_message.go @@ -80,12 +80,10 @@ func (m *OpenAIRemoteMessage) ConvertMessage(ctx context.Context, portal *bridge m.Metadata.Body = m.Content } - // Get model from metadata or portal fallback + // Prefer the message metadata model when present. model := "" if m.Metadata != nil && m.Metadata.Model != "" { model = m.Metadata.Model - } else if portalMeta, ok := portal.Metadata.(*PortalMetadata); ok && portalMeta.Model != "" { - model = portalMeta.Model } var thinkingContent string diff --git a/pkg/connector/response_prefix.go b/pkg/connector/response_prefix.go index 880d8f6b..0793bd1e 100644 --- a/pkg/connector/response_prefix.go +++ b/pkg/connector/response_prefix.go @@ -1,123 +1,16 @@ package connector -import ( - "context" - "strings" - - "github.com/beeper/ai-bridge/pkg/agents" -) - -func resolveChannelResponsePrefix(cfg *Config) string { - if cfg == nil || cfg.Channels == nil { - return "" - } - if cfg.Channels.Matrix != nil { - if trimmed := strings.TrimSpace(cfg.Channels.Matrix.ResponsePrefix); trimmed != "" { - return trimmed - } - } - if cfg.Channels.Defaults != nil { - if trimmed := strings.TrimSpace(cfg.Channels.Defaults.ResponsePrefix); trimmed != "" { - return trimmed - } - } - return "" -} - -func resolveResponsePrefixRaw(oc *AIClient, cfg *Config, meta *PortalMetadata) string { - if meta != nil { - if trimmed := strings.TrimSpace(meta.ResponsePrefix); trimmed != "" { - return trimmed - } - } - if oc != nil && oc.UserLogin != nil { - if login := loginMetadata(oc.UserLogin); login != nil { - if trimmed := strings.TrimSpace(login.ResponsePrefix); trimmed != "" { - return trimmed - } - } - } - if channelPrefix := resolveChannelResponsePrefix(cfg); channelPrefix != "" { - return channelPrefix - } - if cfg == nil || cfg.Messages == nil { - return "" - } - return strings.TrimSpace(cfg.Messages.ResponsePrefix) -} - -func resolveIdentityNameForPrefix(oc *AIClient, agentID string) string { - if oc == nil { - return "" - } - resolved := strings.TrimSpace(agentID) - if resolved == "" { - resolved = agents.DefaultAgentID - } - store := NewAgentStoreAdapter(oc) - if agent, err := store.GetAgentByID(context.Background(), resolved); err == nil && agent != nil { - if agent.Identity != nil && strings.TrimSpace(agent.Identity.Name) != "" { - return strings.TrimSpace(agent.Identity.Name) - } - } - return oc.resolveAgentIdentityName(context.Background(), resolved) -} - -func buildResponsePrefixContext(oc *AIClient, agentID string, meta *PortalMetadata) ResponsePrefixContext { - ctx := ResponsePrefixContext{ - IdentityName: resolveIdentityNameForPrefix(oc, agentID), - } - if oc == nil { - return ctx - } - modelFull := oc.effectiveModel(meta) - if modelFull != "" { - ctx.ModelFull = modelFull - ctx.Model = extractShortModelName(modelFull) - ctx.Provider, _ = splitModelProvider(modelFull) - } - if ctx.Provider == "" { - if login := loginMetadata(oc.UserLogin); login != nil { - ctx.Provider = strings.TrimSpace(login.Provider) - } - } - think := strings.TrimSpace(oc.effectiveReasoningEffort(meta)) - if think == "" { - think = "off" - } - ctx.ThinkingLevel = think - return ctx -} - func resolveResponsePrefixForHeartbeat(oc *AIClient, cfg *Config, agentID string, meta *PortalMetadata) string { - raw := resolveResponsePrefixRaw(oc, cfg, meta) - if raw == "" { - return "" - } - if strings.EqualFold(raw, "auto") { - name := resolveIdentityNameForPrefix(oc, agentID) - if name == "" { - return "" - } - return "[" + name + "]" - } - ctx := buildResponsePrefixContext(oc, agentID, meta) - return resolveResponsePrefixTemplate(raw, ctx) + _ = oc + _ = cfg + _ = agentID + _ = meta + return "" } func resolveResponsePrefixForReply(oc *AIClient, cfg *Config, meta *PortalMetadata) string { - raw := resolveResponsePrefixRaw(oc, cfg, meta) - if raw == "" { - return "" - } - agentID := resolveAgentID(meta) - if strings.EqualFold(raw, "auto") { - name := resolveIdentityNameForPrefix(oc, agentID) - if name == "" { - return "" - } - return "[" + name + "]" - } - ctx := buildResponsePrefixContext(oc, agentID, meta) - return resolveResponsePrefixTemplate(raw, ctx) + _ = oc + _ = cfg + _ = meta + return "" } diff --git a/pkg/connector/streaming_chat_completions.go b/pkg/connector/streaming_chat_completions.go index 752d3fd1..5619a4ab 100644 --- a/pkg/connector/streaming_chat_completions.go +++ b/pkg/connector/streaming_chat_completions.go @@ -73,7 +73,7 @@ func (oc *AIClient) streamChatCompletions( params.Tools = append(params.Tools, ToOpenAIChatTools(enabledTools, &oc.log)...) } if oc.getModelCapabilitiesForMeta(meta).SupportsToolCalling && chatHasAgent { - if !oc.isBuilderRoom(portal) { + if !hasBossAgent(meta) { var enabledSessions []*tools.Tool for _, tool := range tools.SessionTools() { if oc.isToolEnabled(meta, tool.Name) { @@ -84,7 +84,7 @@ func (oc *AIClient) streamChatCompletions( params.Tools = append(params.Tools, bossToolsToChatTools(enabledSessions, &oc.log)...) } } - if hasBossAgent(meta) || oc.isBuilderRoom(portal) { + if hasBossAgent(meta) { var enabledBoss []*tools.Tool for _, tool := range tools.BossTools() { if oc.isToolEnabled(meta, tool.Name) { diff --git a/pkg/connector/streaming_params.go b/pkg/connector/streaming_params.go index b6f94424..65457eda 100644 --- a/pkg/connector/streaming_params.go +++ b/pkg/connector/streaming_params.go @@ -60,8 +60,8 @@ func (oc *AIClient) buildResponsesAPIParams(ctx context.Context, portal *bridgev } if oc.getModelCapabilitiesForMeta(meta).SupportsToolCalling && hasAgent { - // Add session tools for non-boss rooms - if !hasBossAgent(meta) && !oc.isBuilderRoom(portal) { + // Add session tools for non-boss agent rooms. + if !hasBossAgent(meta) { var enabledSessions []*tools.Tool for _, tool := range tools.SessionTools() { if oc.isToolEnabled(meta, tool.Name) { @@ -76,7 +76,7 @@ func (oc *AIClient) buildResponsesAPIParams(ctx context.Context, portal *bridgev } // Add boss tools if this is a Boss room - if hasBossAgent(meta) || oc.isBuilderRoom(portal) { + if hasBossAgent(meta) { var enabledBoss []*tools.Tool for _, tool := range tools.BossTools() { if oc.isToolEnabled(meta, tool.Name) { diff --git a/pkg/connector/streaming_state.go b/pkg/connector/streaming_state.go index 6d9ab2fa..2a9eff95 100644 --- a/pkg/connector/streaming_state.go +++ b/pkg/connector/streaming_state.go @@ -111,9 +111,6 @@ func newStreamingState(ctx context.Context, meta *PortalMetadata, sourceEventID ui: ui, pendingMcpApprovalsSeen: make(map[string]bool), } - if meta != nil && normalizeSendPolicyMode(meta.SendPolicy) == "deny" { - state.suppressSend = true - } if hb := heartbeatRunFromContext(ctx); hb != nil { state.heartbeat = hb.Config state.heartbeatResultCh = hb.ResultCh diff --git a/pkg/connector/subagent_spawn.go b/pkg/connector/subagent_spawn.go index f37f6962..c80f9881 100644 --- a/pkg/connector/subagent_spawn.go +++ b/pkg/connector/subagent_spawn.go @@ -299,15 +299,8 @@ func (oc *AIClient) executeSessionsSpawn(ctx context.Context, portal *bridgev2.P childMeta := portalMeta(childPortal) childMeta.SubagentParentRoomID = portal.MXID.String() - childMeta.SystemPrompt = agents.BuildSubagentSystemPrompt(agents.SubagentPromptParams{ - RequesterSessionKey: portal.MXID.String(), - RequesterChannel: "matrix", - ChildSessionKey: childPortal.MXID.String(), - Label: label, - Task: task, - }) if reasoningEffort != "" { - childMeta.ReasoningEffort = reasoningEffort + childMeta.RuntimeReasoning = reasoningEffort } roomName := resolveSubagentRoomName(label, task) diff --git a/pkg/connector/system_prompts.go b/pkg/connector/system_prompts.go index df71e331..dc3f7683 100644 --- a/pkg/connector/system_prompts.go +++ b/pkg/connector/system_prompts.go @@ -79,15 +79,9 @@ func (oc *AIClient) buildAdditionalSystemPromptsCore( if meta != nil && portal != nil && oc.isGroupChat(ctx, portal) { activation := oc.resolveGroupActivation(meta) - shouldIntro := !meta.GroupIntroSent || meta.GroupActivationNeedsIntro - if shouldIntro { - intro := buildGroupIntro(oc.matrixRoomDisplayName(ctx, portal), activation) - if strings.TrimSpace(intro) != "" { - out = append(out, openai.SystemMessage(intro)) - } - meta.GroupIntroSent = true - meta.GroupActivationNeedsIntro = false - oc.savePortalQuiet(ctx, portal, "group intro") + intro := buildGroupIntro(oc.matrixRoomDisplayName(ctx, portal), activation) + if strings.TrimSpace(intro) != "" { + out = append(out, openai.SystemMessage(intro)) } } diff --git a/pkg/connector/tool_execution.go b/pkg/connector/tool_execution.go index 4879512b..0de42e15 100644 --- a/pkg/connector/tool_execution.go +++ b/pkg/connector/tool_execution.go @@ -128,7 +128,7 @@ func (oc *AIClient) executeBuiltinToolDirect(ctx context.Context, portal *bridge return oc.executeMCPTool(ctx, toolName, args) } // Check if this is a Boss room or a session tool - use boss tool executor - if (meta != nil && hasBossAgent(meta)) || oc.isBuilderRoom(portal) || tools.IsSessionTool(toolName) || tools.IsBossTool(toolName) { + if (meta != nil && hasBossAgent(meta)) || tools.IsSessionTool(toolName) || tools.IsBossTool(toolName) { if result := oc.executeBossTool(ctx, portal, toolName, args); result != nil { return result.Content, result.Error } diff --git a/pkg/connector/typing_queue.go b/pkg/connector/typing_queue.go index fc824f30..cf588a75 100644 --- a/pkg/connector/typing_queue.go +++ b/pkg/connector/typing_queue.go @@ -11,9 +11,6 @@ func (oc *AIClient) startQueueTyping(ctx context.Context, portal *bridgev2.Porta if oc == nil || portal == nil || portal.MXID == "" { return } - if meta != nil && normalizeSendPolicyMode(meta.SendPolicy) == "deny" { - return - } if typingCtx == nil { typingCtx = &TypingContext{IsGroup: oc.isGroupChat(ctx, portal)} } From ad49f94af9e33bc45608fba8f60eaf54e3adb47f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 22:15:48 +0100 Subject: [PATCH 22/23] Remove remaining legacy room metadata --- README.md | 4 +- docs/matrix-ai-matrix-spec-v1.md | 56 +--------------- docs/msc/com.beeper.mscXXXX-commands.md | 39 ++--------- pkg/connector/chat.go | 33 ++++----- pkg/connector/chat_fork_test.go | 4 -- pkg/connector/client.go | 7 +- pkg/connector/client_capabilities_test.go | 10 +-- pkg/connector/defaults_alignment_test.go | 25 ------- pkg/connector/group_activation.go | 5 -- pkg/connector/handleai.go | 17 ----- pkg/connector/heartbeat_context.go | 1 - pkg/connector/heartbeat_execute.go | 2 - pkg/connector/history_limit_test.go | 11 --- pkg/connector/inbound_directive_apply.go | 53 --------------- pkg/connector/integrations_config.go | 14 ++-- .../integrations_example-config.yaml | 8 +-- pkg/connector/metadata.go | 49 ++------------ pkg/connector/response_finalization.go | 24 +------ pkg/connector/response_prefix.go | 16 ----- pkg/connector/response_prefix_template.go | 67 ------------------- pkg/connector/streaming_finish_reason_test.go | 2 +- pkg/connector/streaming_init_test.go | 5 +- pkg/connector/streaming_params.go | 2 +- .../streaming_tool_selection_test.go | 17 +---- pkg/connector/system_prompts_test.go | 2 +- .../tool_availability_configured_test.go | 22 ++++-- pkg/connector/tool_policy_apply_patch_test.go | 28 +++----- pkg/connector/tools_unique_test.go | 2 +- 28 files changed, 76 insertions(+), 449 deletions(-) delete mode 100644 pkg/connector/inbound_directive_apply.go delete mode 100644 pkg/connector/response_prefix.go delete mode 100644 pkg/connector/response_prefix_template.go diff --git a/README.md b/README.md index 5008a18e..9a5588e6 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Batteries included - one click setup (for [Beeper Plus](https://www.beeper.com/p Coming soon to Beeper Desktop as an experiment. Join the [Developer Community](beeper://connect) on [Matrix](https://matrix.to/#/#beeper-developers:beeper.com?via=beeper.com) for early access. -Connect all your chats with one click and manage your inbox with agents. Supports image generation, reminders, web search, and memory. Create basic AI Chats to talk to models with no tools and customizable system prompt. +Connect all your chats with one click and manage your inbox with agents. Supports image generation, reminders, web search, and memory. Create direct model chats for simple conversations or agent chats for richer workflows. Made by humans using agentic coding. @@ -20,7 +20,7 @@ Experimental Matrix ↔ AI bridge for Beeper, built on top of [mautrix/bridgev2] - Per-model chats (each model shows up as its own contact) - Streaming responses - Multimodal input (images, PDFs, audio, video) when supported by the model -- Per-room settings (model, temperature, system prompt, context limits, tools) +- Ghost-based chat targeting for models and agents - Login flows for Beeper, Magic Proxy, or custom (BYOK) - OpenClaw-style memory search (stored in the bridge DB) diff --git a/docs/matrix-ai-matrix-spec-v1.md b/docs/matrix-ai-matrix-spec-v1.md index a554ac25..5a6d4c73 100644 --- a/docs/matrix-ai-matrix-spec-v1.md +++ b/docs/matrix-ai-matrix-spec-v1.md @@ -31,7 +31,7 @@ This document specifies a Matrix transport profile for real-time AI: - primary: ephemeral events (`com.beeper.ai.stream_event` with AI SDK `UIMessageChunk`) - fallback: debounced `m.replace` timeline edits when ephemeral delivery is unavailable - `com.beeper.ai.*` timeline projection events (tool call/result, compaction status, etc). -- `com.beeper.ai.*` state events (room settings/capabilities). +- standard Matrix room features for capability advertising. - Tool approvals (MCP approvals + selected builtin tools). - Auxiliary `com.beeper.ai*` keys used for routing/metadata. @@ -82,9 +82,6 @@ Authoritative identifiers are defined in `pkg/matrixevents/matrixevents.go`. | `m.room.message` | message | timeline | Canonical assistant message carrier (`com.beeper.ai`) | [Canonical](#canonical) | | `com.beeper.ai.stream_event` | ephemeral | ephemeral | Streaming `UIMessageChunk` deltas | [Streaming](#streaming) | | `com.beeper.ai.compaction_status` | message | timeline | Context compaction lifecycle/status | [Projections](#projection-compaction) | -| `com.beeper.ai.room_capabilities` | state | state | Producer-controlled capabilities and effective settings | [State](#state-room-capabilities) | -| `com.beeper.ai.room_settings` | state | state | User-editable room settings | [State](#state-room-settings) | -| `com.beeper.ai.model_capabilities` | state | state | Per-model capabilities (e.g. supported features) | — | | `com.beeper.ai.agents` | state | state | Agent definitions for the room | — | ### Content Keys (Inside Standard Events) @@ -288,56 +285,7 @@ Example: ## State Events -State events broadcast room configuration and capabilities. - - -### `com.beeper.ai.room_capabilities` -Producer-controlled capabilities and effective settings. - -Fields (see `RoomCapabilitiesEventContent` in `pkg/connector/events.go`): -- `capabilities?: ModelCapabilities` -- `available_tools?: ToolInfo[]` -- `reasoning_effort_options?: { value: string, label: string }[]` -- `provider?: string` -- `effective_settings?: object` - -Example: -```json -{ - "capabilities": { - "supports_reasoning": true, - "supports_tool_calling": true - }, - "available_tools": [ - {"name": "web_search", "display_name": "Web Search", "type": "provider", "enabled": true, "available": true} - ], - "provider": "beeper" -} -``` - - -### `com.beeper.ai.room_settings` -User-editable room settings. - -Fields (see `RoomSettingsEventContent` in `pkg/connector/events.go`): -- `model?: string` -- `system_prompt?: string` -- `temperature?: number` -- `max_context_messages?: number` -- `max_completion_tokens?: number` -- `reasoning_effort?: string` -- `agent_id?: string` -- `emit_thinking?: boolean` -- `emit_tool_args?: boolean` - -Example: -```json -{ - "model": "openai/gpt-5", - "temperature": 0.7, - "agent_id": "boss" -} -``` +This bridge no longer uses custom room state for editable AI configuration. Room target selection is determined by ghost identity and membership, while room-level capability advertising uses standard Matrix room features. ## Tool Approvals diff --git a/docs/msc/com.beeper.mscXXXX-commands.md b/docs/msc/com.beeper.mscXXXX-commands.md index 26dce85c..7b3a0a3a 100644 --- a/docs/msc/com.beeper.mscXXXX-commands.md +++ b/docs/msc/com.beeper.mscXXXX-commands.md @@ -8,7 +8,7 @@ This is a profile document, not a new MSC. It specifies which commands ai-bridge ## Motivation -Text-based bot commands (`!ai model gpt-4o`, `!ai reset`) have several problems: +Text-based bot commands (`!ai status`, `!ai reset`) have several problems: - **Undiscoverable:** Users must read documentation or type `!ai help` to learn available commands. There is no in-client autocomplete or parameter hinting. - **Fragile parsing:** Free-text command parsing leads to ambiguous inputs and poor error messages. Typed parameters eliminate this class of bugs. @@ -36,22 +36,6 @@ The bot MUST broadcast one state event per command when it joins a room. The `st ``` ```json -{ - "type": "org.matrix.msc4391.command_description", - "state_key": "model", - "content": { - "description": "Get or set the AI model", - "arguments": { - "model_id": { - "description": "Model identifier (e.g. gpt-4o, claude-sonnet)", - "required": false, - "type": "string" - } - } - } -} -``` - ### Structured Invocation When a client sends a command, it MUST include the `org.matrix.msc4391.command` field in the message content: @@ -61,12 +45,10 @@ When a client sends a command, it MUST include the `org.matrix.msc4391.command` "type": "m.room.message", "content": { "msgtype": "m.text", - "body": "!ai model gpt-4o", + "body": "!ai status", "org.matrix.msc4391.command": { - "command": "model", - "arguments": { - "model_id": "gpt-4o" - } + "command": "status", + "arguments": {} } } } @@ -80,19 +62,10 @@ Commands broadcast by ai-bridge: | Command | Description | Arguments | |---------|-------------|-----------| +| `new` | Create a new chat of the same type | `agent?: string` | | `status` | Show current session status | — | -| `model` | Get or set the AI model | `model_id?: string` | | `reset` | Start a new session/thread | — | | `stop` | Abort current run and clear queue | — | -| `think` | Get or set thinking level | `level?: off\|minimal\|low\|medium\|high\|xhigh` | -| `verbose` | Get or set verbosity | `level?: off\|on\|full` | -| `reasoning` | Get or set reasoning visibility | `level?: off\|on\|low\|medium\|high\|xhigh` | -| `elevated` | Get or set elevated access | `level?: off\|on\|ask\|full` | -| `activation` | Set group activation policy | `policy: mention\|always` | -| `send` | Allow/deny sending messages | `mode: on\|off\|inherit` | -| `queue` | Inspect or configure message queue | `action?: status\|reset\|` | -| `whoami` | Show your Matrix user ID | — | -| `last-heartbeat` | Show last heartbeat event | — | Dynamic commands from integrations and modules are also broadcast as state events. @@ -104,7 +77,7 @@ When both are present, the structured `org.matrix.msc4391.command` field takes p ## Security Considerations -- **Command authorization:** The bot SHOULD check room power levels before executing commands that modify room or session state. Commands like `reset`, `model`, and `elevated` affect all users in the room. +- **Command authorization:** The bot SHOULD check room power levels before executing commands that modify room or session state. - **Argument validation:** The bot MUST validate structured arguments against the published schema before execution. Malformed arguments MUST be rejected with an error message. ## Unstable Prefix diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index 12f9a121..6415565d 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -577,9 +577,8 @@ func (oc *AIClient) createAgentChatWithModel(ctx context.Context, agent *agents. agentName := oc.resolveAgentDisplayName(ctx, agent) portal, chatInfo, err := oc.initPortalForChat(ctx, PortalInitOpts{ - ModelID: modelID, - Title: fmt.Sprintf("Chat with %s", agentName), - SystemPrompt: agent.SystemPrompt, + ModelID: modelID, + Title: fmt.Sprintf("Chat with %s", agentName), }) if err != nil { return nil, err @@ -629,8 +628,7 @@ func (oc *AIClient) createAgentChatWithModel(ctx context.Context, agent *agents. // createNewChat creates a new portal for a specific model func (oc *AIClient) createNewChat(ctx context.Context, modelID string) (*bridgev2.CreateChatResponse, error) { portal, chatInfo, err := oc.initPortalForChat(ctx, PortalInitOpts{ - ModelID: modelID, - SystemPrompt: defaultSimpleModeSystemPrompt, + ModelID: modelID, }) if err != nil { return nil, err @@ -665,11 +663,10 @@ func (oc *AIClient) allocateNextChatIndex(ctx context.Context) (int, error) { // PortalInitOpts contains options for initializing a chat portal type PortalInitOpts struct { - ModelID string - Title string - SystemPrompt string - CopyFrom *PortalMetadata // For forked chats - copies config from source - PortalKey *networkid.PortalKey + ModelID string + Title string + CopyFrom *PortalMetadata // For forked chats - copies config from source + PortalKey *networkid.PortalKey } func cloneForkPortalMetadata(src *PortalMetadata, slug, title string) *PortalMetadata { @@ -737,8 +734,6 @@ func (oc *AIClient) initPortalForChat(ctx context.Context, opts PortalInitOpts) portal.AvatarID = networkid.AvatarID(defaultAvatar) portal.AvatarMXC = id.ContentURIString(defaultAvatar) } - // Note: portal.Topic is NOT set to SystemPrompt - they are separate concepts - if err := portal.Save(ctx); err != nil { return nil, nil, fmt.Errorf("failed to save portal: %w", err) } @@ -876,7 +871,7 @@ func (oc *AIClient) handleNewChat( // No args: create new room of same type if meta == nil { - oc.sendSystemNotice(runCtx, portal, "Couldn't read current room settings.") + oc.sendSystemNotice(runCtx, portal, "Couldn't resolve the current chat target.") return } agentID := resolveAgentID(meta) @@ -1113,8 +1108,7 @@ func (oc *AIClient) copyMessagesToChat( // createNewSimpleChat creates a new simple mode chat portal with the specified model. func (oc *AIClient) createNewSimpleChat(ctx context.Context, modelID string) (*bridgev2.Portal, *bridgev2.ChatInfo, error) { portal, chatInfo, err := oc.initPortalForChat(ctx, PortalInitOpts{ - ModelID: modelID, - SystemPrompt: defaultSimpleModeSystemPrompt, + ModelID: modelID, }) if err != nil { return nil, nil, err @@ -1448,7 +1442,7 @@ func (oc *AIClient) ensureSingleAIGhost(ctx context.Context, portal *bridgev2.Po return nil } -// BroadcastRoomState sends current room capabilities and settings to Matrix room state +// BroadcastRoomState refreshes standard Matrix room capabilities and command descriptions. func (oc *AIClient) BroadcastRoomState(ctx context.Context, portal *bridgev2.Portal) error { portal.UpdateCapabilities(ctx, oc.UserLogin, true) oc.BroadcastCommandDescriptions(ctx, portal) @@ -1627,10 +1621,9 @@ func (oc *AIClient) ensureDefaultChat(ctx context.Context) error { } portal, chatInfo, err := oc.initPortalForChat(ctx, PortalInitOpts{ - ModelID: modelID, - Title: "New AI Chat", - SystemPrompt: beeperAgent.SystemPrompt, - PortalKey: &defaultPortalKey, + ModelID: modelID, + Title: "New AI Chat", + PortalKey: &defaultPortalKey, }) if err != nil { existingPortal, existingErr := oc.UserLogin.Bridge.GetExistingPortalByKey(ctx, defaultPortalKey) diff --git a/pkg/connector/chat_fork_test.go b/pkg/connector/chat_fork_test.go index 012b336c..cf4e3c0e 100644 --- a/pkg/connector/chat_fork_test.go +++ b/pkg/connector/chat_fork_test.go @@ -4,7 +4,6 @@ import "testing" func TestCloneForkPortalMetadata_PreservesSimpleMode(t *testing.T) { src := &PortalMetadata{ - GroupActivation: "always", // Legacy field is not copied in fork metadata. ResolvedTarget: &ResolvedTarget{ Kind: ResolvedTargetModel, GhostID: modelUserID("openai/gpt-5"), @@ -25,7 +24,4 @@ func TestCloneForkPortalMetadata_PreservesSimpleMode(t *testing.T) { if !isSimpleMode(got) { t.Fatalf("expected forked metadata to keep resolved simple-mode target") } - if got.GroupActivation != "" { - t.Fatalf("expected GroupActivation to remain unset in fork metadata copy, got %q", got.GroupActivation) - } } diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 9ea7eeb4..50b059a2 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -1148,8 +1148,7 @@ func updateGhostLastSync(_ context.Context, ghost *bridgev2.Ghost) bool { func (oc *AIClient) GetCapabilities(ctx context.Context, portal *bridgev2.Portal) *event.RoomFeatures { meta := portalMeta(portal) - // Always recompute effective room capabilities to ensure they're up-to-date - // (includes image-understanding union for agent rooms) + // Always recompute effective room capabilities from the resolved room target. modelCaps := oc.getRoomCapabilities(ctx, meta) allowTextFiles := oc.canUseMediaUnderstanding(meta) supportsPDF := modelCaps.SupportsPDF || oc.isOpenRouterProvider() @@ -1411,9 +1410,7 @@ func getLinkPreviewConfig(connectorConfig *Config) LinkPreviewConfig { return config } -// effectiveAgentPrompt returns the system prompt for the agent assigned to the room. -// This uses BuildSystemPrompt to generate a full prompt with room context when an agent is configured. -// Returns empty string if no agent is configured. +// effectiveAgentPrompt returns the resolved agent prompt for the current room target. func (oc *AIClient) effectiveAgentPrompt(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata) string { if meta == nil { return "" diff --git a/pkg/connector/client_capabilities_test.go b/pkg/connector/client_capabilities_test.go index 39074f1d..db40ff0e 100644 --- a/pkg/connector/client_capabilities_test.go +++ b/pkg/connector/client_capabilities_test.go @@ -16,9 +16,7 @@ func TestGetCapabilities_SimpleModeDisablesReplyEditReaction(t *testing.T) { portal := &bridgev2.Portal{ Portal: &database.Portal{ OtherUserID: modelUserID("openai/gpt-5"), - Metadata: &PortalMetadata{ - Capabilities: ModelCapabilities{SupportsToolCalling: true}, - }, + Metadata: simpleModeTestMeta("openai/gpt-5"), }, } @@ -47,9 +45,7 @@ func TestGetCapabilities_NonSimpleEnablesReplyEditReaction(t *testing.T) { portal := &bridgev2.Portal{ Portal: &database.Portal{ OtherUserID: agentUserID("beeper"), - Metadata: &PortalMetadata{ - Capabilities: ModelCapabilities{SupportsToolCalling: true}, - }, + Metadata: agentModeTestMeta("beeper"), }, } @@ -71,7 +67,7 @@ func TestGetCapabilities_MessageToolDisabledDisablesReplyEditReaction(t *testing Portal: &database.Portal{ OtherUserID: agentUserID("beeper"), Metadata: &PortalMetadata{ - Capabilities: ModelCapabilities{SupportsToolCalling: true}, + ResolvedTarget: agentModeTestMeta("beeper").ResolvedTarget, DisabledTools: []string{ ToolNameMessage, }, diff --git a/pkg/connector/defaults_alignment_test.go b/pkg/connector/defaults_alignment_test.go index 1657c8c7..1f1afc3f 100644 --- a/pkg/connector/defaults_alignment_test.go +++ b/pkg/connector/defaults_alignment_test.go @@ -48,28 +48,3 @@ func TestDefaultThinkLevelModelAware(t *testing.T) { t.Fatalf("expected off for non-reasoning models, got %q", got) } } - -func TestDefaultThinkLevelIgnoresLegacyThinkingOverrides(t *testing.T) { - client := &AIClient{ - connector: &OpenAIConnector{}, - UserLogin: &bridgev2.UserLogin{UserLogin: &database.UserLogin{Metadata: &UserLoginMetadata{ - Provider: ProviderOpenRouter, - ModelCache: &ModelCache{Models: []ModelInfo{ - {ID: "openai/o4-mini", SupportsReasoning: true}, - }}, - }}}, - } - meta := &PortalMetadata{ - ThinkingLevel: "high", - ReasoningEffort: "medium", - ResolvedTarget: &ResolvedTarget{ - Kind: ResolvedTargetModel, - GhostID: modelUserID("openai/o4-mini"), - ModelID: "openai/o4-mini", - }, - } - - if got := client.defaultThinkLevel(meta); got != "low" { - t.Fatalf("expected ghost/model-derived low think level, got %q", got) - } -} diff --git a/pkg/connector/group_activation.go b/pkg/connector/group_activation.go index ac6827a6..07a5db1b 100644 --- a/pkg/connector/group_activation.go +++ b/pkg/connector/group_activation.go @@ -11,8 +11,3 @@ func (oc *AIClient) resolveGroupActivation(meta *PortalMetadata) string { } return "mention" } - -func normalizeSendPolicyMode(raw string) string { - _ = raw - return "" -} diff --git a/pkg/connector/handleai.go b/pkg/connector/handleai.go index c5c9dbdf..ce9da56a 100644 --- a/pkg/connector/handleai.go +++ b/pkg/connector/handleai.go @@ -629,20 +629,3 @@ func (oc *AIClient) getModelContextWindow(meta *PortalMetadata) int { // Default for unknown models return 128000 } - -// This is separate from room topic (which is display-only). -func (oc *AIClient) setRoomSystemPrompt(ctx context.Context, portal *bridgev2.Portal, prompt string) error { - return oc.setRoomSystemPromptInternal(ctx, portal, prompt, true) -} - -func (oc *AIClient) setRoomSystemPromptNoSave(ctx context.Context, portal *bridgev2.Portal, prompt string) error { - return oc.setRoomSystemPromptInternal(ctx, portal, prompt, false) -} - -func (oc *AIClient) setRoomSystemPromptInternal(ctx context.Context, portal *bridgev2.Portal, prompt string, save bool) error { - _ = ctx - _ = portal - _ = prompt - _ = save - return nil -} diff --git a/pkg/connector/heartbeat_context.go b/pkg/connector/heartbeat_context.go index 8241b522..c4f49d6f 100644 --- a/pkg/connector/heartbeat_context.go +++ b/pkg/connector/heartbeat_context.go @@ -14,7 +14,6 @@ type HeartbeatRunConfig struct { UseIndicator bool IncludeReasoning bool ExecEvent bool - ResponsePrefix string SessionKey string StoreAgentID string PrevUpdatedAt int64 diff --git a/pkg/connector/heartbeat_execute.go b/pkg/connector/heartbeat_execute.go index 36661817..00c52ade 100644 --- a/pkg/connector/heartbeat_execute.go +++ b/pkg/connector/heartbeat_execute.go @@ -148,7 +148,6 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, if promptMeta == nil { promptMeta = &PortalMetadata{} } - responsePrefix := resolveResponsePrefixForHeartbeat(oc, cfg, agentID, promptMeta) hbCfg := &HeartbeatRunConfig{ Reason: reason, AckMaxChars: resolveHeartbeatAckMaxChars(cfg, heartbeat), @@ -157,7 +156,6 @@ func (oc *AIClient) runHeartbeatOnce(agentID string, heartbeat *HeartbeatConfig, UseIndicator: visibility.UseIndicator, IncludeReasoning: heartbeat != nil && heartbeat.IncludeReasoning != nil && *heartbeat.IncludeReasoning, ExecEvent: hasExecCompletion, - ResponsePrefix: responsePrefix, SessionKey: storeKey, StoreAgentID: sessionResolution.StoreRef.AgentID, PrevUpdatedAt: prevUpdatedAt, diff --git a/pkg/connector/history_limit_test.go b/pkg/connector/history_limit_test.go index 8ae8f695..4bae5b36 100644 --- a/pkg/connector/history_limit_test.go +++ b/pkg/connector/history_limit_test.go @@ -8,17 +8,6 @@ import ( "maunium.net/go/mautrix/bridgev2/database" ) -func TestHistoryLimitIgnoresLegacyMetaOverride(t *testing.T) { - client := &AIClient{} - portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test", RoomType: database.RoomTypeGroupDM}} - meta := &PortalMetadata{MaxContextMessages: 7} - - limit := client.historyLimit(context.Background(), portal, meta) - if limit != defaultGroupContextMessages { - t.Fatalf("expected group default %d, got %d", defaultGroupContextMessages, limit) - } -} - func TestHistoryLimitDefaultsByRoomType(t *testing.T) { client := &AIClient{} diff --git a/pkg/connector/inbound_directive_apply.go b/pkg/connector/inbound_directive_apply.go deleted file mode 100644 index ab5de54b..00000000 --- a/pkg/connector/inbound_directive_apply.go +++ /dev/null @@ -1,53 +0,0 @@ -package connector - -import "fmt" - -func applyThinkingLevel(meta *PortalMetadata, level string) { - _ = meta - _ = level -} - -func applyReasoningLevel(meta *PortalMetadata, level string) { - _ = meta - _ = level -} - -func formatThinkingAck(level string) string { - if level == "off" { - return "Thinking disabled." - } - return fmt.Sprintf("Thinking level set to %s.", level) -} - -func formatVerboseAck(level string) string { - switch level { - case "off": - return formatSystemAck("Verbose logging disabled.") - case "full": - return formatSystemAck("Verbose logging set to full.") - default: - return formatSystemAck("Verbose logging enabled.") - } -} - -func formatReasoningAck(level string) string { - switch level { - case "off": - return formatSystemAck("Reasoning visibility disabled.") - case "stream": - return formatSystemAck("Reasoning stream enabled (Telegram only).") - default: - return formatSystemAck("Reasoning visibility enabled.") - } -} - -func formatElevatedAck(level string) string { - switch level { - case "off": - return formatSystemAck("Elevated mode disabled.") - case "full": - return formatSystemAck("Elevated mode set to full (auto-approve).") - default: - return formatSystemAck("Elevated mode set to ask (approvals may still apply).") - } -} diff --git a/pkg/connector/integrations_config.go b/pkg/connector/integrations_config.go index 95a00a30..3efb6ef3 100644 --- a/pkg/connector/integrations_config.go +++ b/pkg/connector/integrations_config.go @@ -161,15 +161,13 @@ type ChannelsConfig struct { } type ChannelDefaultsConfig struct { - Heartbeat *ChannelHeartbeatVisibilityConfig `yaml:"heartbeat"` - ResponsePrefix string `yaml:"responsePrefix"` + Heartbeat *ChannelHeartbeatVisibilityConfig `yaml:"heartbeat"` } type ChannelConfig struct { - Heartbeat *ChannelHeartbeatVisibilityConfig `yaml:"heartbeat"` - ResponsePrefix string `yaml:"responsePrefix"` - ReplyToMode string `yaml:"replyToMode"` // off|first|all (Matrix) - ThreadReplies string `yaml:"threadReplies"` // off|inbound|always (Matrix) + Heartbeat *ChannelHeartbeatVisibilityConfig `yaml:"heartbeat"` + ReplyToMode string `yaml:"replyToMode"` // off|first|all (Matrix) + ThreadReplies string `yaml:"threadReplies"` // off|inbound|always (Matrix) } type ChannelHeartbeatVisibilityConfig struct { @@ -180,7 +178,6 @@ type ChannelHeartbeatVisibilityConfig struct { // MessagesConfig defines message rendering settings. type MessagesConfig struct { - ResponsePrefix string `yaml:"responsePrefix"` AckReaction string `yaml:"ackReaction"` AckReactionScope string `yaml:"ackReactionScope"` // group-mentions|group-all|direct|all|off|none RemoveAckAfter bool `yaml:"removeAckAfterReply"` @@ -544,7 +541,6 @@ func upgradeConfig(helper configupgrade.Helper) { helper.Copy(configupgrade.Bool, "cron", "enabled") // Messages configuration - helper.Copy(configupgrade.Str, "messages", "responsePrefix") helper.Copy(configupgrade.List, "commands", "ownerAllowFrom") helper.Copy(configupgrade.Str, "messages", "queue", "mode") helper.Copy(configupgrade.Map, "messages", "queue", "byChannel") @@ -575,11 +571,9 @@ func upgradeConfig(helper configupgrade.Helper) { helper.Copy(configupgrade.Bool, "channels", "defaults", "heartbeat", "showOk") helper.Copy(configupgrade.Bool, "channels", "defaults", "heartbeat", "showAlerts") helper.Copy(configupgrade.Bool, "channels", "defaults", "heartbeat", "useIndicator") - helper.Copy(configupgrade.Str, "channels", "defaults", "responsePrefix") helper.Copy(configupgrade.Bool, "channels", "matrix", "heartbeat", "showOk") helper.Copy(configupgrade.Bool, "channels", "matrix", "heartbeat", "showAlerts") helper.Copy(configupgrade.Bool, "channels", "matrix", "heartbeat", "useIndicator") - helper.Copy(configupgrade.Str, "channels", "matrix", "responsePrefix") helper.Copy(configupgrade.Str, "channels", "matrix", "replyToMode") helper.Copy(configupgrade.Str, "channels", "matrix", "threadReplies") diff --git a/pkg/connector/integrations_example-config.yaml b/pkg/connector/integrations_example-config.yaml index 6a28c7f1..3ad59010 100644 --- a/pkg/connector/integrations_example-config.yaml +++ b/pkg/connector/integrations_example-config.yaml @@ -54,10 +54,8 @@ model_cache_duration: 6h # Optional message rendering settings. messages: - # Prefix applied to outbound replies (and heartbeat ok acks). - responsePrefix: "" # History defaults for prompt construction. - # Set 0 to disable, or override per-room with !ai context. + # Set 0 to disable. directChat: historyLimit: 20 groupChat: @@ -94,8 +92,8 @@ tool_approvals: # Optional per-channel overrides. channels: matrix: - # Response prefix override for Matrix rooms. - responsePrefix: "" + # Matrix reply/thread behavior. + replyToMode: "first" # Session configuration. session: diff --git a/pkg/connector/metadata.go b/pkg/connector/metadata.go index 9d907dfe..7280347f 100644 --- a/pkg/connector/metadata.go +++ b/pkg/connector/metadata.go @@ -7,7 +7,6 @@ import ( "go.mau.fi/util/jsontime" "go.mau.fi/util/random" "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/bridgev2/networkid" "github.com/beeper/ai-bridge/pkg/bridgeadapter" "github.com/beeper/ai-bridge/pkg/shared/jsonutil" @@ -53,15 +52,6 @@ type UserProfile struct { CustomInstructions string `json:"custom_instructions,omitempty"` } -// Legacy-only storage type kept so old JSON can still unmarshal during the hard cut. -// New code must not read this. -type UserDefaults struct { - Model string `json:"model,omitempty"` - SystemPrompt string `json:"system_prompt,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` -} - // ServiceTokens stores optional per-login credentials for external services. type ServiceTokens struct { OpenAI string `json:"openai,omitempty"` @@ -129,15 +119,11 @@ type UserLoginMetadata struct { Gravatar *GravatarState `json:"gravatar,omitempty"` Timezone string `json:"timezone,omitempty"` Profile *UserProfile `json:"profile,omitempty"` - ResponsePrefix string `json:"response_prefix,omitempty"` // Legacy-only. New code must not read this. // FileAnnotationCache stores parsed PDF content from OpenRouter's file-parser plugin // Key is the file hash (SHA256), pruned after 7 days FileAnnotationCache map[string]FileAnnotation `json:"file_annotation_cache,omitempty"` - // Legacy-only. New code must not read this. - Defaults *UserDefaults `json:"defaults,omitempty"` - // Optional per-login tokens for external services ServiceTokens *ServiceTokens `json:"service_tokens,omitempty"` @@ -146,8 +132,6 @@ type UserLoginMetadata struct { // Custom agents store (source of truth for user-created agents). CustomAgents map[string]*AgentDefinitionContent `json:"custom_agents,omitempty"` - // Legacy-only. New code must not read this. - BuilderRoomID networkid.PortalID `json:"builder_room_id,omitempty"` // Last active room per agent (used for heartbeat delivery). LastActiveRoomByAgent map[string]string `json:"last_active_room_by_agent,omitempty"` // Heartbeat dedupe state per agent. @@ -179,33 +163,11 @@ type GravatarState struct { Primary *GravatarProfile `json:"primary,omitempty"` } -// PortalMetadata stores per-room tuning knobs for the assistant. +// PortalMetadata stores non-derivable per-room runtime state. type PortalMetadata struct { - // Legacy-only selector/tuning fields kept to allow old JSON to unmarshal during - // the hard cut. New code must not read these. - Model string `json:"model,omitempty"` - SystemPrompt string `json:"system_prompt,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - MaxContextMessages int `json:"max_context_messages,omitempty"` - MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` - Capabilities ModelCapabilities `json:"capabilities,omitempty"` - PDFConfig *PDFConfig `json:"pdf_config,omitempty"` - EmitThinking bool `json:"emit_thinking,omitempty"` - EmitToolArgs bool `json:"emit_tool_args,omitempty"` - ThinkingLevel string `json:"thinking_level,omitempty"` - VerboseLevel string `json:"verbose_level,omitempty"` - ElevatedLevel string `json:"elevated_level,omitempty"` - GroupActivation string `json:"group_activation,omitempty"` - GroupActivationNeedsIntro bool `json:"group_activation_needs_intro,omitempty"` - GroupIntroSent bool `json:"group_intro_sent,omitempty"` - SendPolicy string `json:"send_policy,omitempty"` - AgentID string `json:"agent_id,omitempty"` - AgentPrompt string `json:"agent_prompt,omitempty"` - IsBuilderRoom bool `json:"is_builder_room,omitempty"` - IsSimpleMode bool `json:"is_simple_mode,omitempty"` - AckReactionEmoji string `json:"ack_reaction_emoji,omitempty"` - AckReactionRemoveAfter bool `json:"ack_reaction_remove_after,omitempty"` + AckReactionEmoji string `json:"ack_reaction_emoji,omitempty"` + AckReactionRemoveAfter bool `json:"ack_reaction_remove_after,omitempty"` + PDFConfig *PDFConfig `json:"pdf_config,omitempty"` Slug string `json:"slug,omitempty"` Title string `json:"title,omitempty"` @@ -228,9 +190,6 @@ type PortalMetadata struct { RuntimeModelOverride string `json:"-"` RuntimeReasoning string `json:"-"` - // Legacy-only. New code must not read this. - ResponsePrefix string `json:"response_prefix,omitempty"` - // Debounce configuration (0 = use default, -1 = disabled) DebounceMs int `json:"debounce_ms,omitempty"` diff --git a/pkg/connector/response_finalization.go b/pkg/connector/response_finalization.go index 4e784033..94215526 100644 --- a/pkg/connector/response_finalization.go +++ b/pkg/connector/response_finalization.go @@ -221,12 +221,6 @@ func (oc *AIClient) sendFinalAssistantTurn(ctx context.Context, portal *bridgev2 cleanedContent := airuntime.SanitizeChatMessageForDisplay(directives.Text, false) finalReplyTarget := oc.resolveFinalReplyTarget(meta, state, &directives) - responsePrefix := resolveResponsePrefixForReply(oc, &oc.connector.Config, meta) - if responsePrefix != "" && strings.TrimSpace(cleanedContent) != "" { - if !strings.HasPrefix(cleanedContent, responsePrefix) { - cleanedContent = responsePrefix + " " + cleanedContent - } - } rendered := format.RenderMarkdown(cleanedContent, true, true) if finalReplyTarget.ReplyTo != "" { replyTo := finalReplyTarget.ReplyTo @@ -266,12 +260,6 @@ func (oc *AIClient) sendFinalHeartbeatTurn(ctx context.Context, portal *bridgev2 } shouldSkip = false } - responsePrefix := strings.TrimSpace(hb.ResponsePrefix) - if responsePrefix != "" && strings.TrimSpace(finalText) != "" && !shouldSkip { - if !strings.HasPrefix(finalText, responsePrefix) { - finalText = responsePrefix + " " + finalText - } - } cleaned := strings.TrimSpace(finalText) hasMedia := len(state.pendingImages) > 0 shouldSkipMain := shouldSkip && !hasMedia && !hb.ExecEvent @@ -304,11 +292,7 @@ func (oc *AIClient) sendFinalHeartbeatTurn(ctx context.Context, portal *bridgev2 oc.restoreHeartbeatUpdatedAt(storeRef, hb.SessionKey, hb.PrevUpdatedAt) silent := true if hb.ShowOk && deliverable { - heartbeatOk := agents.HeartbeatToken - if responsePrefix != "" { - heartbeatOk = responsePrefix + " " + agents.HeartbeatToken - } - oc.sendPlainAssistantMessage(ctx, portal, heartbeatOk) + oc.sendPlainAssistantMessage(ctx, portal, agents.HeartbeatToken) silent = false } oc.redactInitialStreamingMessage(ctx, portal, state) @@ -728,11 +712,9 @@ func generateOutboundLinkPreviews(ctx context.Context, text string, intent bridg return UploadPreviewImages(ctx, previewsWithImages, intent, portal.MXID) } -// getAgentResponseMode returns the response mode for the current agent. -// Defaults to ResponseModeNatural if not set. -// IsSimpleMode on the portal overrides all other settings (for simple mode rooms). +// getAgentResponseMode returns the response mode for the current room target. +// Defaults to ResponseModeNatural if no agent-specific mode is configured. func (oc *AIClient) getAgentResponseMode(meta *PortalMetadata) agents.ResponseMode { - // Simple mode flag takes priority (set by simple command) if isSimpleMode(meta) { return agents.ResponseModeSimple } diff --git a/pkg/connector/response_prefix.go b/pkg/connector/response_prefix.go deleted file mode 100644 index 0793bd1e..00000000 --- a/pkg/connector/response_prefix.go +++ /dev/null @@ -1,16 +0,0 @@ -package connector - -func resolveResponsePrefixForHeartbeat(oc *AIClient, cfg *Config, agentID string, meta *PortalMetadata) string { - _ = oc - _ = cfg - _ = agentID - _ = meta - return "" -} - -func resolveResponsePrefixForReply(oc *AIClient, cfg *Config, meta *PortalMetadata) string { - _ = oc - _ = cfg - _ = meta - return "" -} diff --git a/pkg/connector/response_prefix_template.go b/pkg/connector/response_prefix_template.go deleted file mode 100644 index cf819664..00000000 --- a/pkg/connector/response_prefix_template.go +++ /dev/null @@ -1,67 +0,0 @@ -package connector - -import ( - "regexp" - "strings" -) - -// ResponsePrefixContext mirrors OpenClaw's template context. -type ResponsePrefixContext struct { - Model string - ModelFull string - Provider string - ThinkingLevel string - IdentityName string -} - -var responsePrefixTemplatePattern = regexp.MustCompile(`\{([a-zA-Z][a-zA-Z0-9.]*)\}`) -var responsePrefixDateSuffix = regexp.MustCompile(`-\d{8}$`) - -func resolveResponsePrefixTemplate(template string, ctx ResponsePrefixContext) string { - if template == "" { - return "" - } - return responsePrefixTemplatePattern.ReplaceAllStringFunc(template, func(match string) string { - groups := responsePrefixTemplatePattern.FindStringSubmatch(match) - if len(groups) < 2 { - return match - } - varName := strings.ToLower(groups[1]) - switch varName { - case "model": - if ctx.Model != "" { - return ctx.Model - } - case "modelfull": - if ctx.ModelFull != "" { - return ctx.ModelFull - } - case "provider": - if ctx.Provider != "" { - return ctx.Provider - } - case "thinkinglevel", "think": - if ctx.ThinkingLevel != "" { - return ctx.ThinkingLevel - } - case "identity.name", "identityname": - if ctx.IdentityName != "" { - return ctx.IdentityName - } - } - return match - }) -} - -func extractShortModelName(fullModel string) string { - modelPart := strings.TrimSpace(fullModel) - if modelPart == "" { - return "" - } - if idx := strings.LastIndex(modelPart, "/"); idx >= 0 && idx+1 < len(modelPart) { - modelPart = modelPart[idx+1:] - } - modelPart = responsePrefixDateSuffix.ReplaceAllString(modelPart, "") - modelPart = strings.TrimSuffix(modelPart, "-latest") - return modelPart -} diff --git a/pkg/connector/streaming_finish_reason_test.go b/pkg/connector/streaming_finish_reason_test.go index 9472a997..956707fb 100644 --- a/pkg/connector/streaming_finish_reason_test.go +++ b/pkg/connector/streaming_finish_reason_test.go @@ -87,7 +87,7 @@ func TestBuildCanonicalUIMessage_IncludesSourceAndFileParts(t *testing.T) { }}, } - ui := oc.buildCanonicalUIMessage(state, &PortalMetadata{Model: "gpt-4o"}) + ui := oc.buildCanonicalUIMessage(state, simpleModeTestMeta("openai/gpt-4o")) if ui == nil { t.Fatalf("expected canonical message") } diff --git a/pkg/connector/streaming_init_test.go b/pkg/connector/streaming_init_test.go index 7e4cd76e..1f9a4c0d 100644 --- a/pkg/connector/streaming_init_test.go +++ b/pkg/connector/streaming_init_test.go @@ -13,7 +13,6 @@ import ( func TestPrepareStreamingRun_SimpleModeClearsReplyTarget(t *testing.T) { oc := &AIClient{} meta := &PortalMetadata{ - SendPolicy: "deny", ResolvedTarget: &ResolvedTarget{ Kind: ResolvedTargetModel, GhostID: modelUserID("openai/gpt-5.2"), @@ -54,9 +53,7 @@ func TestPrepareStreamingRun_SimpleModeClearsReplyTarget(t *testing.T) { func TestPrepareStreamingRun_NonSimpleKeepsReplyTarget(t *testing.T) { oc := &AIClient{} - meta := &PortalMetadata{ - SendPolicy: "deny", - } + meta := &PortalMetadata{} evt := &event.Event{ ID: id.EventID("$evt"), Sender: id.UserID("@alice:example.com"), diff --git a/pkg/connector/streaming_params.go b/pkg/connector/streaming_params.go index 65457eda..ccd4b3a8 100644 --- a/pkg/connector/streaming_params.go +++ b/pkg/connector/streaming_params.go @@ -35,7 +35,7 @@ func (oc *AIClient) buildResponsesAPIParams(ctx context.Context, portal *bridgev OfInputItemList: input, } - // Add reasoning effort if configured (uses inheritance: room → user → default) + // Add reasoning effort when the resolved target supports it. if reasoningEffort := oc.effectiveReasoningEffort(meta); reasoningEffort != "" { params.Reasoning = shared.ReasoningParam{ Effort: shared.ReasoningEffort(reasoningEffort), diff --git a/pkg/connector/streaming_tool_selection_test.go b/pkg/connector/streaming_tool_selection_test.go index 91a27e41..b6e330d5 100644 --- a/pkg/connector/streaming_tool_selection_test.go +++ b/pkg/connector/streaming_tool_selection_test.go @@ -18,16 +18,7 @@ func TestSelectedBuiltinToolsForTurn_SimpleModeEnablesOnlyWebSearch(t *testing.T }, } - meta := &PortalMetadata{ - ResolvedTarget: &ResolvedTarget{ - Kind: ResolvedTargetModel, - GhostID: modelUserID("openai/gpt-5.2"), - ModelID: "openai/gpt-5.2", - }, - Capabilities: ModelCapabilities{ - SupportsToolCalling: true, - }, - } + meta := simpleModeTestMeta("openai/gpt-5.2") got := client.selectedBuiltinToolsForTurn(context.Background(), meta) if len(got) != 1 { @@ -51,11 +42,7 @@ func TestSelectedBuiltinToolsForTurn_NonAgentNonSimpleGetsNoTools(t *testing.T) }, } - meta := &PortalMetadata{ - Capabilities: ModelCapabilities{ - SupportsToolCalling: true, - }, - } + meta := &PortalMetadata{} got := client.selectedBuiltinToolsForTurn(context.Background(), meta) if len(got) != 0 { diff --git a/pkg/connector/system_prompts_test.go b/pkg/connector/system_prompts_test.go index a583e210..05eda21d 100644 --- a/pkg/connector/system_prompts_test.go +++ b/pkg/connector/system_prompts_test.go @@ -15,7 +15,7 @@ func TestBuildSessionIdentityHint_IncludesRoomIDAndPortalID(t *testing.T) { portal.MXID = id.RoomID("!room:example.org") portal.PortalKey = networkid.PortalKey{ID: networkid.PortalID("portal-123")} - meta := &PortalMetadata{AgentID: "beeper"} + meta := agentModeTestMeta("beeper") got := buildSessionIdentityHint(portal, meta) if got == "" { t.Fatalf("expected non-empty hint") diff --git a/pkg/connector/tool_availability_configured_test.go b/pkg/connector/tool_availability_configured_test.go index 5098ed84..be454bbe 100644 --- a/pkg/connector/tool_availability_configured_test.go +++ b/pkg/connector/tool_availability_configured_test.go @@ -7,6 +7,8 @@ import ( "testing" "github.com/beeper/ai-bridge/pkg/shared/toolspec" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" ) func boolPtr(v bool) *bool { @@ -22,8 +24,11 @@ func TestToolAvailable_WebSearch_RequiresAnyProviderKey(t *testing.T) { }, }, }, + UserLogin: &bridgev2.UserLogin{UserLogin: &database.UserLogin{Metadata: &UserLoginMetadata{ + ModelCache: &ModelCache{Models: []ModelInfo{{ID: "openai/gpt-5.2", SupportsToolCalling: true}}}, + }}}, } - meta := &PortalMetadata{Capabilities: ModelCapabilities{SupportsToolCalling: true}} + meta := simpleModeTestMeta("openai/gpt-5.2") ok, source, reason := oc.isToolAvailable(meta, toolspec.WebSearchName) if ok { @@ -48,8 +53,11 @@ func TestToolAvailable_WebSearch_WithProviderKey(t *testing.T) { }, }, }, + UserLogin: &bridgev2.UserLogin{UserLogin: &database.UserLogin{Metadata: &UserLoginMetadata{ + ModelCache: &ModelCache{Models: []ModelInfo{{ID: "openai/gpt-5.2", SupportsToolCalling: true}}}, + }}}, } - meta := &PortalMetadata{Capabilities: ModelCapabilities{SupportsToolCalling: true}} + meta := simpleModeTestMeta("openai/gpt-5.2") ok, _, reason := oc.isToolAvailable(meta, toolspec.WebSearchName) if !ok { @@ -68,8 +76,11 @@ func TestToolAvailable_WebFetch_DirectDisabledAndNoExaKey(t *testing.T) { }, }, }, + UserLogin: &bridgev2.UserLogin{UserLogin: &database.UserLogin{Metadata: &UserLoginMetadata{ + ModelCache: &ModelCache{Models: []ModelInfo{{ID: "openai/gpt-5.2", SupportsToolCalling: true}}}, + }}}, } - meta := &PortalMetadata{Capabilities: ModelCapabilities{SupportsToolCalling: true}} + meta := simpleModeTestMeta("openai/gpt-5.2") ok, source, reason := oc.isToolAvailable(meta, toolspec.WebFetchName) if ok { @@ -84,8 +95,11 @@ func TestToolAvailable_TTS_PlatformBehavior(t *testing.T) { oc := &AIClient{ connector: &OpenAIConnector{Config: Config{}}, // provider/apiKey intentionally empty + UserLogin: &bridgev2.UserLogin{UserLogin: &database.UserLogin{Metadata: &UserLoginMetadata{ + ModelCache: &ModelCache{Models: []ModelInfo{{ID: "openai/gpt-5.2", SupportsToolCalling: true}}}, + }}}, } - meta := &PortalMetadata{Capabilities: ModelCapabilities{SupportsToolCalling: true}} + meta := simpleModeTestMeta("openai/gpt-5.2") ok, _, reason := oc.isToolAvailable(meta, toolspec.TTSName) if runtime.GOOS == "darwin" { diff --git a/pkg/connector/tool_policy_apply_patch_test.go b/pkg/connector/tool_policy_apply_patch_test.go index e81c951a..cd8e74ff 100644 --- a/pkg/connector/tool_policy_apply_patch_test.go +++ b/pkg/connector/tool_policy_apply_patch_test.go @@ -8,7 +8,12 @@ import ( ) func newTestAIClientWithConfig(cfg Config) *AIClient { - login := &database.UserLogin{Metadata: &UserLoginMetadata{Provider: ProviderOpenAI}} + login := &database.UserLogin{Metadata: &UserLoginMetadata{ + Provider: ProviderOpenAI, + ModelCache: &ModelCache{Models: []ModelInfo{ + {ID: "openai/gpt-5.2", SupportsToolCalling: true}, + }}, + }} userLogin := &bridgev2.UserLogin{UserLogin: login} return &AIClient{ UserLogin: userLogin, @@ -18,12 +23,7 @@ func newTestAIClientWithConfig(cfg Config) *AIClient { func TestApplyPatchAvailability_DisabledByDefault(t *testing.T) { oc := newTestAIClientWithConfig(Config{}) - meta := &PortalMetadata{ - Model: "openai/gpt-5.2", - Capabilities: ModelCapabilities{ - SupportsToolCalling: true, - }, - } + meta := simpleModeTestMeta("openai/gpt-5.2") available, _, _ := oc.isToolAvailable(meta, ToolNameApplyPatch) if available { @@ -42,12 +42,7 @@ func TestApplyPatchAvailability_EnabledWithoutAllowlist(t *testing.T) { }, }, }) - meta := &PortalMetadata{ - Model: "openai/gpt-5.2", - Capabilities: ModelCapabilities{ - SupportsToolCalling: true, - }, - } + meta := simpleModeTestMeta("openai/gpt-5.2") available, _, _ := oc.isToolAvailable(meta, ToolNameApplyPatch) if !available { @@ -67,12 +62,7 @@ func TestApplyPatchAvailability_AllowlistMismatch(t *testing.T) { }, }, }) - meta := &PortalMetadata{ - Model: "openai/gpt-5.2", - Capabilities: ModelCapabilities{ - SupportsToolCalling: true, - }, - } + meta := simpleModeTestMeta("openai/gpt-5.2") available, _, _ := oc.isToolAvailable(meta, ToolNameApplyPatch) if available { diff --git a/pkg/connector/tools_unique_test.go b/pkg/connector/tools_unique_test.go index 4368e2ac..004bb901 100644 --- a/pkg/connector/tools_unique_test.go +++ b/pkg/connector/tools_unique_test.go @@ -21,7 +21,7 @@ func TestToolNamesUnique(t *testing.T) { builtinSeen[tool.Name] = struct{}{} } - // Boss tools (combined with builtin in builder rooms) + // Boss tools (combined with builtin in boss-agent rooms) for _, tool := range agenttools.BossTools() { if tool.Name == "" { t.Fatalf("boss tool has empty name: %+v", tool) From 15a9a388dd7ccfd8df785bc66edcc93b67315363 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 8 Mar 2026 22:38:36 +0100 Subject: [PATCH 23/23] Remove dead legacy connector code --- pkg/connector/chat.go | 445 ------------------ pkg/connector/command_aliases.go | 31 -- pkg/connector/commands_helpers.go | 19 - pkg/connector/connector.go | 28 -- pkg/connector/handlematrix.go | 177 ------- pkg/connector/identifiers.go | 30 -- pkg/connector/parse_utils.go | 18 - pkg/connector/queue_directive.go | 139 ------ pkg/connector/queue_notice.go | 23 - pkg/connector/status_text.go | 94 ---- .../tool_availability_configured_test.go | 3 +- pkg/connector/trace.go | 5 - 12 files changed, 2 insertions(+), 1010 deletions(-) delete mode 100644 pkg/connector/parse_utils.go delete mode 100644 pkg/connector/queue_directive.go delete mode 100644 pkg/connector/queue_notice.go diff --git a/pkg/connector/chat.go b/pkg/connector/chat.go index 6415565d..76f7bf91 100644 --- a/pkg/connector/chat.go +++ b/pkg/connector/chat.go @@ -7,7 +7,6 @@ import ( "strings" "time" - "github.com/rs/zerolog" "go.mau.fi/util/ptr" "github.com/beeper/ai-bridge/pkg/agents" @@ -20,7 +19,6 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/simplevent" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -66,28 +64,6 @@ func modelRedirectTarget(requested, resolved string) networkid.UserID { // validateDMModelSwitch enforces the DM invariant that counterpart ghosts are immutable. // Agent rooms are exempt because the stable counterpart ghost is the agent ghost. -func (oc *AIClient) validateDMModelSwitch(portal *bridgev2.Portal, meta *PortalMetadata, targetModel string) error { - if oc == nil || portal == nil || meta == nil || strings.TrimSpace(targetModel) == "" { - return nil - } - if portal.RoomType != database.RoomTypeDM { - return nil - } - if resolveAgentID(meta) != "" { - return nil - } - currentModel := oc.effectiveModel(meta) - if currentModel == "" || currentModel == targetModel { - return nil - } - currentGhost := modelUserID(currentModel) - targetGhost := modelUserID(targetModel) - if currentGhost == targetGhost { - return nil - } - return fmt.Errorf("%w: %s -> %s", ErrDMGhostImmutable, currentModel, targetModel) -} - // buildAvailableTools returns a list of ToolInfo for all tools based on tool policy. func (oc *AIClient) buildAvailableTools(meta *PortalMetadata) []ToolInfo { names := oc.toolNamesForPortal(meta) @@ -743,90 +719,6 @@ func (oc *AIClient) initPortalForChat(ctx context.Context, opts PortalInitOpts) return portal, chatInfo, nil } -// handleFork creates a new chat and copies messages from the current conversation -func (oc *AIClient) handleFork( - ctx context.Context, - _ *event.Event, - portal *bridgev2.Portal, - meta *PortalMetadata, - arg string, -) { - runCtx := oc.backgroundContext(ctx) - - // 1. Retrieve all messages from current chat - messages, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(runCtx, portal.PortalKey, 10000) - if err != nil { - oc.sendSystemNotice(runCtx, portal, "Couldn't load messages: "+err.Error()) - return - } - - if len(messages) == 0 { - oc.sendSystemNotice(runCtx, portal, "No messages to fork.") - return - } - - // 2. If event ID specified, filter messages up to that point - var messagesToCopy []*database.Message - if arg != "" { - // Validate Matrix event ID format - if !strings.HasPrefix(arg, "$") { - oc.sendSystemNotice(runCtx, portal, "Invalid event ID. Must start with '$'.") - return - } - - // Messages are newest-first, reverse iterate to find target - found := false - for i := len(messages) - 1; i >= 0; i-- { - msg := messages[i] - messagesToCopy = append(messagesToCopy, msg) - - // Check MXID field (Matrix event ID) - if msg.MXID != "" && string(msg.MXID) == arg { - found = true - break - } - // Check message ID format "mx:$eventid" - if strings.HasSuffix(string(msg.ID), arg) { - found = true - break - } - } - - if !found { - oc.sendSystemNotice(runCtx, portal, fmt.Sprintf("Couldn't find event: %s", arg)) - return - } - } else { - // Copy all messages (reverse to get chronological order) - for i := len(messages) - 1; i >= 0; i-- { - messagesToCopy = append(messagesToCopy, messages[i]) - } - } - - // 3. Create new chat with same configuration - newPortal, chatInfo, err := oc.createForkedChat(runCtx, portal, meta) - if err != nil { - oc.sendSystemNotice(runCtx, portal, "Couldn't create the forked chat: "+err.Error()) - return - } - - // 4. Create Matrix room - if err := newPortal.CreateMatrixRoom(runCtx, oc.UserLogin, chatInfo); err != nil { - oc.sendSystemNotice(runCtx, portal, "Couldn't create the room: "+err.Error()) - return - } - - // 5. Copy messages to new chat - copiedCount := oc.copyMessagesToChat(runCtx, newPortal, messagesToCopy) - - // 6. Send notice with link - roomLink := fmt.Sprintf("https://matrix.to/#/%s", newPortal.MXID) - oc.sendSystemNotice(runCtx, portal, fmt.Sprintf( - "Forked %d messages to new chat.\nOpen: %s", - copiedCount, roomLink, - )) -} - // handleNewChat creates a new chat using the current room's agent/model, // or an explicitly provided agent/model. func (oc *AIClient) handleNewChat( @@ -984,127 +876,6 @@ func (oc *AIClient) createAndOpenSimpleChat(ctx context.Context, portal *bridgev )) } -// createForkedChat creates a new portal inheriting config from source -func (oc *AIClient) createForkedChat( - ctx context.Context, - sourcePortal *bridgev2.Portal, - sourceMeta *PortalMetadata, -) (*bridgev2.Portal, *bridgev2.ChatInfo, error) { - sourceTitle := sourceMeta.Title - if sourceTitle == "" { - sourceTitle = sourcePortal.Name - } - title := fmt.Sprintf("%s (Fork)", sourceTitle) - - portal, chatInfo, err := oc.initPortalForChat(ctx, PortalInitOpts{ - Title: title, - CopyFrom: sourceMeta, - }) - if err != nil { - return nil, nil, err - } - - agentID := resolveAgentID(sourceMeta) - if agentID != "" { - pm := portalMeta(portal) - - modelID := oc.effectiveModel(pm) - portal.OtherUserID = agentUserID(agentID) - pm.ResolvedTarget = resolveTargetFromGhostID(portal.OtherUserID) - - agentName := agentID - agentAvatar := "" - // Try preset first - guaranteed to work for built-in agents (like "beeper") - if preset := agents.GetPresetByID(agentID); preset != nil { - agentName = oc.resolveAgentDisplayName(ctx, preset) - agentAvatar = preset.AvatarURL - } else { - // Custom agent - need Matrix state lookup - store := NewAgentStoreAdapter(oc) - if agent, err := store.GetAgentByID(ctx, agentID); err == nil && agent != nil { - agentName = oc.resolveAgentDisplayName(ctx, agent) - agentAvatar = agent.AvatarURL - } - } - if strings.TrimSpace(agentAvatar) == "" { - agentAvatar = strings.TrimSpace(agents.DefaultAgentAvatarMXC) - } - if agentAvatar != "" { - portal.AvatarID = networkid.AvatarID(agentAvatar) - portal.AvatarMXC = id.ContentURIString(agentAvatar) - } - oc.applyAgentChatInfo(chatInfo, agentID, agentName, modelID) - oc.ensureAgentGhostDisplayName(ctx, agentID, modelID, agentName) - - if err := portal.Save(ctx); err != nil { - return nil, nil, err - } - } - - return portal, chatInfo, nil -} - -// copyMessagesToChat queues messages to be bridged to the new chat -// Returns the count of successfully queued messages -func (oc *AIClient) copyMessagesToChat( - ctx context.Context, - destPortal *bridgev2.Portal, - messages []*database.Message, -) int { - copiedCount := 0 - skippedCount := 0 - - for _, srcMsg := range messages { - srcMeta := messageMeta(srcMsg) - if srcMeta == nil || srcMeta.Body == "" { - skippedCount++ - continue - } - - // Determine sender - var sender bridgev2.EventSender - if srcMeta.Role == "user" { - sender = bridgev2.EventSender{ - Sender: humanUserID(oc.UserLogin.ID), - SenderLogin: oc.UserLogin.ID, - IsFromMe: true, - } - } else { - sender = bridgev2.EventSender{ - Sender: srcMsg.SenderID, - SenderLogin: oc.UserLogin.ID, - IsFromMe: false, - } - } - - // Create remote message for bridging - remoteMsg := &OpenAIRemoteMessage{ - PortalKey: destPortal.PortalKey, - ID: bridgeadapter.NewMessageID("fork"), - Sender: sender, - Content: srcMeta.Body, - Timestamp: srcMsg.Timestamp, - Metadata: &MessageMetadata{ - BaseMessageMetadata: bridgeadapter.BaseMessageMetadata{Role: srcMeta.Role, Body: srcMeta.Body}, - }, - } - - oc.UserLogin.QueueRemoteEvent(remoteMsg) - copiedCount++ - } - - // Log if partial copy occurred (some messages were skipped) - if skippedCount > 0 { - oc.loggerForContext(ctx).Warn(). - Int("copied", copiedCount). - Int("skipped", skippedCount). - Int("total", len(messages)). - Msg("Partial fork - some messages were skipped due to missing metadata") - } - - return copiedCount -} - // createNewSimpleChat creates a new simple mode chat portal with the specified model. func (oc *AIClient) createNewSimpleChat(ctx context.Context, modelID string) (*bridgev2.Portal, *bridgev2.ChatInfo, error) { portal, chatInfo, err := oc.initPortalForChat(ctx, PortalInitOpts{ @@ -1226,222 +997,6 @@ func (oc *AIClient) applyAgentChatInfo(chatInfo *bridgev2.ChatInfo, agentID, age chatInfo.Members = members } -// handleModelSwitch generates membership change events when switching models -// This creates leave/join events to show the model transition in the room timeline -// For agent rooms, it updates the agent ghost metadata. -func (oc *AIClient) handleModelSwitch(ctx context.Context, portal *bridgev2.Portal, oldModel, newModel string) { - if oldModel == newModel || oldModel == "" || newModel == "" { - return - } - - meta := portalMeta(portal) - agentID := resolveAgentID(meta) - - // Check if this is an agent room - update agent ghost metadata - if agentID != "" { - oc.handleAgentModelSwitch(ctx, portal, agentID, oldModel, newModel) - return - } - - // For non-agent rooms, use simple mode ghosts - oc.loggerForContext(ctx).Info(). - Str("old_model", oldModel). - Str("new_model", newModel). - Stringer("portal", portal.PortalKey). - Msg("Handling model switch") - - oldInfo := oc.findModelInfo(oldModel) - newInfo := oc.findModelInfo(newModel) - oldModelName := modelContactName(oldModel, oldInfo) - newModelName := modelContactName(newModel, newInfo) - - // Pre-update the new model ghost's profile before queueing the event - // This ensures the ghost has a display name set in its Matrix profile - newGhost, err := oc.UserLogin.Bridge.GetGhostByID(ctx, modelUserID(newModel)) - if err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Str("model", newModel).Msg("Failed to get ghost for model switch") - } else { - oc.ensureGhostDisplayNameWithGhost(ctx, newGhost, newModel, newInfo) - } - - // Create member changes: old model leaves, new model joins - // Use MemberEventExtra to set displayname directly in the membership event - // This works because MemberEventContent.Displayname has omitempty, so our Raw value is preserved - memberChanges := &bridgev2.ChatMemberList{ - MemberMap: bridgev2.ChatMemberMap{ - modelUserID(oldModel): { - EventSender: bridgev2.EventSender{ - Sender: modelUserID(oldModel), - SenderLogin: oc.UserLogin.ID, - }, - Membership: event.MembershipLeave, - PrevMembership: event.MembershipJoin, - }, - modelUserID(newModel): modelJoinMember(oc.UserLogin.ID, newModel, newModelName, newInfo), - }, - } - - // Update portal's OtherUserID to new model - portal.OtherUserID = modelUserID(newModel) - - // Queue the ChatInfoChange event - evt := &simplevent.ChatInfoChange{ - EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventChatInfoChange, - PortalKey: portal.PortalKey, - Timestamp: time.Now(), - LogContext: func(c zerolog.Context) zerolog.Context { - return c.Str("action", "model_switch"). - Str("old_model", oldModel). - Str("new_model", newModel) - }, - }, - ChatInfoChange: &bridgev2.ChatInfoChange{ - MemberChanges: memberChanges, - }, - } - - oc.UserLogin.QueueRemoteEvent(evt) - - // Send a notice about the model change from the bridge bot - notice := fmt.Sprintf("Switched from %s to %s", oldModelName, newModelName) - oc.sendSystemNotice(ctx, portal, notice) - - // Update bridge info and capabilities to resend room features state event with new capabilities - // This ensures the client knows what features the new model supports (vision, audio, etc.) - portal.UpdateBridgeInfo(ctx) - portal.UpdateCapabilities(ctx, oc.UserLogin, true) - - // Ensure only 1 AI ghost in room - if err := oc.ensureSingleAIGhost(ctx, portal); err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure single AI ghost after model switch") - } -} - -// handleAgentModelSwitch handles model switching for agent rooms. -// Keeps a single agent ghost and updates member metadata. -func (oc *AIClient) handleAgentModelSwitch(ctx context.Context, portal *bridgev2.Portal, agentID, oldModel, newModel string) { - // Get the agent to determine display name - store := NewAgentStoreAdapter(oc) - agent, err := store.GetAgentByID(ctx, agentID) - if err != nil || agent == nil { - oc.loggerForContext(ctx).Warn().Err(err).Str("agent", agentID).Msg("Agent not found for model switch") - return - } - - oc.loggerForContext(ctx).Info(). - Str("agent", agentID). - Str("old_model", oldModel). - Str("new_model", newModel). - Stringer("portal", portal.PortalKey). - Msg("Handling agent model switch") - - ghostID := agentUserID(agentID) - agentName := oc.resolveAgentDisplayName(ctx, agent) - displayName := agentName - oldModelName := modelContactName(oldModel, oc.findModelInfo(oldModel)) - newModelName := modelContactName(newModel, oc.findModelInfo(newModel)) - oldGhostID := portal.OtherUserID - - // Update member metadata for the agent ghost - memberMap := bridgev2.ChatMemberMap{ - ghostID: { - EventSender: bridgev2.EventSender{ - Sender: ghostID, - SenderLogin: oc.UserLogin.ID, - }, - Membership: event.MembershipJoin, - UserInfo: &bridgev2.UserInfo{ - Name: ptr.Ptr(displayName), - IsBot: ptr.Ptr(true), - Identifiers: agentContactIdentifiers(agentID, newModel, oc.findModelInfo(newModel)), - }, - MemberEventExtra: map[string]any{ - "displayname": displayName, - "com.beeper.ai.model_id": newModel, - "com.beeper.ai.agent": agentID, - }, - }, - } - if oldGhostID != "" && oldGhostID != ghostID { - memberMap[oldGhostID] = bridgev2.ChatMember{ - EventSender: bridgev2.EventSender{ - Sender: oldGhostID, - SenderLogin: oc.UserLogin.ID, - }, - Membership: event.MembershipLeave, - PrevMembership: event.MembershipJoin, - } - } - memberChanges := &bridgev2.ChatMemberList{MemberMap: memberMap} - - // Update portal's OtherUserID to agent ghost - portal.OtherUserID = ghostID - oc.ensureAgentGhostDisplayName(ctx, agentID, newModel, agentName) - - // Queue the ChatInfoChange event - evt := &simplevent.ChatInfoChange{ - EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventChatInfoChange, - PortalKey: portal.PortalKey, - Timestamp: time.Now(), - LogContext: func(c zerolog.Context) zerolog.Context { - return c.Str("action", "agent_model_switch"). - Str("agent", agentID). - Str("old_model", oldModel). - Str("new_model", newModel) - }, - }, - ChatInfoChange: &bridgev2.ChatInfoChange{ - MemberChanges: memberChanges, - }, - } - - oc.UserLogin.QueueRemoteEvent(evt) - - // Send a notice about the model change - notice := fmt.Sprintf("Switched model from %s to %s", oldModelName, newModelName) - oc.sendSystemNotice(ctx, portal, notice) - - // Update bridge info and capabilities - portal.UpdateBridgeInfo(ctx) - portal.UpdateCapabilities(ctx, oc.UserLogin, true) - - // Ensure only 1 AI ghost in room - if err := oc.ensureSingleAIGhost(ctx, portal); err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("Failed to ensure single AI ghost after agent model switch") - } -} - -// ensureSingleAIGhost ensures only 1 model/agent ghost is in the room at a time. -// Updates portal.OtherUserID if it doesn't match the expected ghost. -func (oc *AIClient) ensureSingleAIGhost(ctx context.Context, portal *bridgev2.Portal) error { - meta := portalMeta(portal) - - // Determine which ghost SHOULD be in the room - var expectedGhostID networkid.UserID - agentID := resolveAgentID(meta) - - modelID := oc.effectiveModel(meta) - if agentID != "" { - expectedGhostID = agentUserID(agentID) - } else { - expectedGhostID = modelUserID(modelID) - } - - // Update portal.OtherUserID if mismatched - if portal.OtherUserID != expectedGhostID { - oc.loggerForContext(ctx).Debug(). - Str("old_ghost", string(portal.OtherUserID)). - Str("new_ghost", string(expectedGhostID)). - Stringer("portal", portal.PortalKey). - Msg("Updating portal OtherUserID to match expected ghost") - portal.OtherUserID = expectedGhostID - return portal.Save(ctx) - } - return nil -} - // BroadcastRoomState refreshes standard Matrix room capabilities and command descriptions. func (oc *AIClient) BroadcastRoomState(ctx context.Context, portal *bridgev2.Portal) error { portal.UpdateCapabilities(ctx, oc.UserLogin, true) diff --git a/pkg/connector/command_aliases.go b/pkg/connector/command_aliases.go index ecccdbe5..5b122c33 100644 --- a/pkg/connector/command_aliases.go +++ b/pkg/connector/command_aliases.go @@ -1,36 +1,5 @@ package connector -var thinkLevelAliases = map[string]string{ - "off": "off", - "on": "low", - "minimal": "minimal", - "low": "low", - "medium": "medium", - "high": "high", - "xhigh": "xhigh", -} - -var verboseLevelAliases = map[string]string{ - "off": "off", - "on": "on", - "full": "full", -} - -var reasoningLevelAliases = map[string]string{ - "off": "off", - "on": "on", - "low": "low", - "medium": "medium", - "high": "high", - "xhigh": "xhigh", -} - -var sendPolicyAliases = map[string]string{ - "on": "allow", - "off": "deny", - "inherit": "inherit", -} - var groupActivationAliases = map[string]string{ "mention": "mention", "always": "always", diff --git a/pkg/connector/commands_helpers.go b/pkg/connector/commands_helpers.go index 2be022e0..771f9a49 100644 --- a/pkg/connector/commands_helpers.go +++ b/pkg/connector/commands_helpers.go @@ -2,8 +2,6 @@ package connector import ( "maunium.net/go/mautrix/bridgev2/commands" - - "github.com/beeper/ai-bridge/pkg/agents" ) func requireClientMeta(ce *commands.Event) (*AIClient, *PortalMetadata, bool) { @@ -15,20 +13,3 @@ func requireClientMeta(ce *commands.Event) (*AIClient, *PortalMetadata, bool) { } return client, meta, true } - -func requireClient(ce *commands.Event) (*AIClient, bool) { - client := getAIClient(ce) - if client == nil { - ce.Reply("Couldn't load AI settings. Try again.") - return nil, false - } - return client, true -} - -func rejectBossOverrides(ce *commands.Event, meta *PortalMetadata, message string) bool { - if agents.IsBossAgent(resolveAgentID(meta)) { - ce.Reply(message) - return true - } - return false -} diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index b6fd2646..7f5bd399 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -242,31 +242,3 @@ func (oc *OpenAIConnector) CreateLogin(ctx context.Context, user *bridgev2.User, } return &OpenAILogin{User: user, Connector: oc, FlowID: flowID}, nil } - -// getLoginForPortal finds the correct user login based on the portal's Receiver. -// This ensures we use the correct provider/API credentials when a user has multiple accounts. -func (oc *OpenAIConnector) getLoginForPortal(ctx context.Context, user *bridgev2.User, portal *bridgev2.Portal) *bridgev2.UserLogin { - if portal == nil { - return oc.getPreferredUserLogin(ctx, user) - } - - // The portal's Receiver field contains the UserLogin ID that owns this portal - receiverID := portal.Receiver - if receiverID == "" { - oc.br.Log.Warn().Stringer("portal", portal.PortalKey).Msg("Portal has no receiver, using default login") - return oc.getPreferredUserLogin(ctx, user) - } - - // Get the specific login that matches the portal's receiver - login, err := oc.br.GetExistingUserLoginByID(ctx, receiverID) - if err != nil || login == nil { - oc.br.Log.Warn(). - Err(err). - Stringer("portal", portal.PortalKey). - Str("receiver", string(receiverID)). - Msg("Failed to get login for portal receiver, using default login") - return oc.getPreferredUserLogin(ctx, user) - } - - return login -} diff --git a/pkg/connector/handlematrix.go b/pkg/connector/handlematrix.go index 39b13999..aa681fb6 100644 --- a/pkg/connector/handlematrix.go +++ b/pkg/connector/handlematrix.go @@ -1207,183 +1207,6 @@ func (oc *AIClient) removeAckReaction(ctx context.Context, portal *bridgev2.Port Msg("Queued ack reaction removal") } -// handleToolsCommand handles the !ai tools command for per-tool management -func (oc *AIClient) handleToolsCommand( - ctx context.Context, - portal *bridgev2.Portal, - meta *PortalMetadata, - arg string, -) { - runCtx := oc.backgroundContext(ctx) - - // No args - show status - if arg == "" { - oc.showToolsStatus(runCtx, portal, meta) - return - } - - action, _, _ := strings.Cut(arg, " ") - action = strings.ToLower(action) - - switch action { - case "list": - oc.showToolsStatus(runCtx, portal, meta) - case "on", "enable", "true", "1", "off", "disable", "false", "0": - oc.sendSystemNotice(runCtx, portal, "Per-tool toggles aren't supported anymore. Update tool policy in agent settings or the global tool_policy config.") - default: - oc.sendSystemNotice(runCtx, portal, "Usage:\n"+ - "• !ai tools - Show current tool status\n"+ - "• !ai tools list - List available tools\n"+ - "Tool toggles are managed by tool policy.") - } -} - -// showToolsStatus displays the current status of all tools -func (oc *AIClient) showToolsStatus(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata) { - oc.sendSystemNotice(ctx, portal, oc.buildToolsStatusText(meta)) -} - -// handleRegenerate regenerates the last AI response -func (oc *AIClient) handleRegenerate( - ctx context.Context, - evt *event.Event, - portal *bridgev2.Portal, - meta *PortalMetadata, -) { - runCtx := oc.backgroundContext(ctx) - - // Get message history - history, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(runCtx, portal.PortalKey, 10) - if err != nil || len(history) == 0 { - oc.sendSystemNotice(runCtx, portal, "No messages to regenerate from.") - return - } - - // Find the last user message - var lastUserMessage *database.Message - for _, msg := range history { - msgMeta := messageMeta(msg) - if msgMeta != nil && msgMeta.Role == "user" { - lastUserMessage = msg - break - } - } - - if lastUserMessage == nil { - oc.sendSystemNotice(runCtx, portal, "No user message found to regenerate from.") - return - } - - userMeta := messageMeta(lastUserMessage) - if userMeta == nil || userMeta.Body == "" { - oc.sendSystemNotice(runCtx, portal, "Can't regenerate: message content isn't available.") - return - } - - oc.sendSystemNotice(runCtx, portal, "Regenerating response...") - - // Build prompt excluding the old assistant response - promptContext, err := oc.buildContextForRegenerate(runCtx, portal, meta, userMeta.Body, lastUserMessage.MXID) - if err != nil { - oc.sendSystemNotice(runCtx, portal, "Couldn't regenerate: "+err.Error()) - return - } - - queueSettings, _, _, _ := oc.resolveQueueSettingsForPortal(runCtx, portal, meta, "", airuntime.QueueInlineOptions{}) - isGroup := oc.isGroupChat(runCtx, portal) - pending := pendingMessage{ - Event: evt, - Portal: portal, - Meta: meta, - Type: pendingTypeRegenerate, - MessageBody: userMeta.Body, - SourceEventID: lastUserMessage.MXID, - Typing: &TypingContext{ - IsGroup: isGroup, - WasMentioned: true, - }, - } - queueItem := pendingQueueItem{ - pending: pending, - messageID: string(evt.ID), - summaryLine: userMeta.Body, - enqueuedAt: time.Now().UnixMilli(), - } - oc.dispatchOrQueueWithStatus(runCtx, evt, portal, meta, queueItem, queueSettings, promptContext) -} - -// handleRegenerateTitle regenerates the current room title from recent messages. -func (oc *AIClient) handleRegenerateTitle( - ctx context.Context, - portal *bridgev2.Portal, -) { - runCtx := oc.backgroundContext(ctx) - - history, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(runCtx, portal.PortalKey, 20) - if err != nil || len(history) == 0 { - oc.sendSystemNotice(runCtx, portal, "No messages to generate a title from.") - return - } - - var lastUserMessage *database.Message - var lastAssistantMessage *database.Message - for _, msg := range history { - msgMeta := messageMeta(msg) - if !shouldIncludeInHistory(msgMeta) { - continue - } - if lastAssistantMessage == nil && msgMeta.Role == "assistant" { - lastAssistantMessage = msg - } - if lastUserMessage == nil && msgMeta.Role == "user" { - lastUserMessage = msg - } - if lastUserMessage != nil && lastAssistantMessage != nil { - break - } - } - - if lastUserMessage == nil { - oc.sendSystemNotice(runCtx, portal, "No user message found to generate a title from.") - return - } - - userMeta := messageMeta(lastUserMessage) - if userMeta == nil || userMeta.Body == "" { - oc.sendSystemNotice(runCtx, portal, "Can't generate a title: message content isn't available.") - return - } - - assistantBody := "" - if lastAssistantMessage != nil { - assistantMeta := messageMeta(lastAssistantMessage) - if assistantMeta != nil { - assistantBody = assistantMeta.Body - } - } - - oc.sendSystemNotice(runCtx, portal, "Regenerating title...") - - title, err := oc.generateRoomTitle(runCtx, userMeta.Body, assistantBody) - if err != nil { - oc.sendSystemNotice(runCtx, portal, "Couldn't generate a title: "+err.Error()) - return - } - - title = strings.TrimSpace(title) - if title == "" { - oc.sendSystemNotice(runCtx, portal, "Couldn't generate a title: empty response.") - return - } - - if err := oc.setRoomName(runCtx, portal, title); err != nil { - oc.sendSystemNotice(runCtx, portal, "Couldn't set the room title: "+err.Error()) - return - } - - oc.sendSystemNotice(runCtx, portal, fmt.Sprintf("Room title updated to: %s", title)) -} - // buildPromptForRegenerate builds a prompt for regeneration, excluding the last assistant message func (oc *AIClient) buildContextForRegenerate( ctx context.Context, diff --git a/pkg/connector/identifiers.go b/pkg/connector/identifiers.go index 57ce006b..3a605a18 100644 --- a/pkg/connector/identifiers.go +++ b/pkg/connector/identifiers.go @@ -27,32 +27,10 @@ func nthLoginID(providerSlug string, mxid id.UserID, ordinal int) networkid.User return networkid.UserLoginID(fmt.Sprintf("%s:%d", base, ordinal)) } -func nextLoginID(user *bridgev2.User, providerSlug string, mxid id.UserID) networkid.UserLoginID { - used := map[string]struct{}{} - if user != nil { - for _, existing := range user.GetUserLogins() { - if existing == nil { - continue - } - used[string(existing.ID)] = struct{}{} - } - } - for ordinal := 1; ; ordinal++ { - loginID := nthLoginID(providerSlug, mxid, ordinal) - if _, ok := used[string(loginID)]; !ok { - return loginID - } - } -} - func providerLoginID(provider string, mxid id.UserID, ordinal int) networkid.UserLoginID { return nthLoginID(providerSlug(provider), mxid, ordinal) } -func nextProviderLoginID(user *bridgev2.User, provider string, mxid id.UserID) networkid.UserLoginID { - return nextLoginID(user, providerSlug(provider), mxid) -} - func managedBeeperLoginID(mxid id.UserID) networkid.UserLoginID { return baseLoginID("beeper", mxid) } @@ -161,14 +139,6 @@ func resolveTargetFromGhostID(ghostID networkid.UserID) *ResolvedTarget { return nil } -func resolvedAgentIDForGhost(ghostID networkid.UserID) string { - target := resolveTargetFromGhostID(ghostID) - if target == nil { - return "" - } - return target.AgentID -} - func portalMeta(portal *bridgev2.Portal) *PortalMetadata { meta := bridgeadapter.EnsurePortalMetadata[PortalMetadata](portal) if meta != nil { diff --git a/pkg/connector/parse_utils.go b/pkg/connector/parse_utils.go deleted file mode 100644 index 501b644a..00000000 --- a/pkg/connector/parse_utils.go +++ /dev/null @@ -1,18 +0,0 @@ -package connector - -import ( - "errors" - "strconv" - "strings" -) - -func parsePositiveInt(raw string) (int, error) { - value, err := strconv.Atoi(strings.TrimSpace(raw)) - if err != nil { - return 0, err - } - if value <= 0 { - return 0, errors.New("value must be positive") - } - return value, nil -} diff --git a/pkg/connector/queue_directive.go b/pkg/connector/queue_directive.go deleted file mode 100644 index ca13fe95..00000000 --- a/pkg/connector/queue_directive.go +++ /dev/null @@ -1,139 +0,0 @@ -package connector - -import ( - "fmt" - "strings" - - airuntime "github.com/beeper/ai-bridge/pkg/runtime" -) - -type queueDirective struct { - QueueMode airuntime.QueueMode - QueueReset bool - RawMode string - DebounceMs *int - Cap *int - DropPolicy *airuntime.QueueDropPolicy - RawDebounce string - RawCap string - RawDrop string - HasOptions bool - HasDebounce bool - HasCap bool - HasDrop bool -} - -func parseQueueDebounce(raw string) *int { - if strings.TrimSpace(raw) == "" { - return nil - } - parsed, err := parseDurationMs(raw, "ms") - if err != nil { - return nil - } - value := int(parsed) - if value < 0 { - return nil - } - return &value -} - -func parseQueueCap(raw string) *int { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return nil - } - value := 0 - if _, err := fmt.Sscanf(trimmed, "%d", &value); err != nil { - return nil - } - if value < 1 { - return nil - } - return &value -} - -func parseQueueDirectiveArgs(raw string) (consumed int, result queueDirective) { - i := 0 - for i < len(raw) && raw[i] <= ' ' { - i++ - } - if i < len(raw) && raw[i] == ':' { - i++ - for i < len(raw) && raw[i] <= ' ' { - i++ - } - } - consumed = i - takeToken := func() string { - if i >= len(raw) { - return "" - } - start := i - for i < len(raw) && raw[i] > ' ' { - i++ - } - token := raw[start:i] - for i < len(raw) && raw[i] <= ' ' { - i++ - } - if token == "" { - return "" - } - consumed = i - return token - } - - for i < len(raw) { - token := takeToken() - if token == "" { - break - } - lowered := strings.ToLower(strings.TrimSpace(token)) - if lowered == "reset" { - result.QueueReset = true - break - } - if strings.HasPrefix(lowered, "debounce:") { - _, value, _ := strings.Cut(token, ":") - if value != "" { - result.RawDebounce = value - result.DebounceMs = parseQueueDebounce(value) - result.HasOptions = true - result.HasDebounce = true - } - continue - } - if strings.HasPrefix(lowered, "cap:") { - _, value, _ := strings.Cut(token, ":") - if value != "" { - result.RawCap = value - result.Cap = parseQueueCap(value) - result.HasOptions = true - result.HasCap = true - } - continue - } - if strings.HasPrefix(lowered, "drop:") { - _, value, _ := strings.Cut(token, ":") - if value != "" { - result.RawDrop = value - if policy, ok := airuntime.NormalizeQueueDropPolicy(value); ok { - result.DropPolicy = &policy - } - result.HasOptions = true - result.HasDrop = true - } - continue - } - if mode, ok := airuntime.NormalizeQueueMode(token); ok { - result.QueueMode = mode - result.RawMode = token - continue - } - break - } - return consumed, result -} - -// NOTE: Slash-style inline `/queue ...` directives are intentionally not supported. diff --git a/pkg/connector/queue_notice.go b/pkg/connector/queue_notice.go deleted file mode 100644 index 5238e99f..00000000 --- a/pkg/connector/queue_notice.go +++ /dev/null @@ -1,23 +0,0 @@ -package connector - -import ( - "fmt" - - airuntime "github.com/beeper/ai-bridge/pkg/runtime" -) - -const queueDirectiveOptionsHint = "modes steer, followup, collect, steer+backlog, interrupt; debounce:, cap:, drop:old|new|summarize" - -func buildQueueStatusLine(settings airuntime.QueueSettings) string { - debounceLabel := fmt.Sprintf("%dms", settings.DebounceMs) - capLabel := fmt.Sprintf("%d", settings.Cap) - dropLabel := string(settings.DropPolicy) - return fmt.Sprintf( - "Current queue settings: mode=%s, debounce=%s, cap=%s, drop=%s.\nOptions: %s.", - settings.Mode, - debounceLabel, - capLabel, - dropLabel, - queueDirectiveOptionsHint, - ) -} diff --git a/pkg/connector/status_text.go b/pkg/connector/status_text.go index 09153dac..6c24244f 100644 --- a/pkg/connector/status_text.go +++ b/pkg/connector/status_text.go @@ -1,14 +1,11 @@ package connector import ( - "cmp" "context" "fmt" - "slices" "strings" "time" - "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" airuntime "github.com/beeper/ai-bridge/pkg/runtime" @@ -182,64 +179,6 @@ func formatTypingInterval(interval time.Duration) string { return fmt.Sprintf("%ds", seconds) } -func (oc *AIClient) buildContextStatus(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata) string { - if meta == nil || portal == nil { - return "Context unavailable" - } - var sb strings.Builder - sb.WriteString("Context\n") - modelID := oc.effectiveModel(meta) - provider := strings.TrimSpace(loginMetadata(oc.UserLogin).Provider) - if provider != "" { - sb.WriteString(fmt.Sprintf("Model: %s/%s\n", provider, modelID)) - } else { - sb.WriteString(fmt.Sprintf("Model: %s\n", modelID)) - } - - contextWindow := oc.getModelContextWindow(meta) - estimate := oc.estimatePromptTokens(ctx, portal, meta) - if estimate > 0 { - sb.WriteString(fmt.Sprintf( - "Prompt estimate: %s/%s (%s)\n", - formatCompactTokens(int64(estimate)), - formatCompactTokens(int64(contextWindow)), - formatPercent(estimate, contextWindow), - )) - } else { - sb.WriteString(fmt.Sprintf("Context window: %s tokens\n", formatCompactTokens(int64(contextWindow)))) - } - - systemPrompt := oc.effectivePrompt(meta) - if systemPrompt != "" { - sysTokens := 0 - if count, err := EstimateTokens([]openai.ChatCompletionMessageParamUnion{openai.SystemMessage(systemPrompt)}, modelID); err == nil { - sysTokens = count - } - sysLine := fmt.Sprintf("System prompt: %d chars", len(systemPrompt)) - if sysTokens > 0 { - sysLine = fmt.Sprintf("%s (%s tokens)", sysLine, formatCompactTokens(int64(sysTokens))) - } - sb.WriteString(sysLine + "\n") - } - - historyLimit := oc.historyLimit(ctx, portal, meta) - historyCount := 0 - if historyLimit > 0 { - if history, err := oc.UserLogin.Bridge.DB.Message.GetLastNInPortal(ctx, portal.PortalKey, historyLimit); err == nil { - historyCount = len(history) - } - } - sb.WriteString(fmt.Sprintf("History limit: %d messages\n", historyLimit)) - sb.WriteString(fmt.Sprintf("History loaded: %d messages\n", historyCount)) - - sb.WriteString(fmt.Sprintf("Compactions: %d\n", meta.CompactionCount)) - - if meta.SessionResetAt > 0 { - sb.WriteString(fmt.Sprintf("Session reset: %s\n", time.UnixMilli(meta.SessionResetAt).Format(time.RFC3339))) - } - return strings.TrimSpace(sb.String()) -} - type assistantUsageSnapshot struct { promptTokens int64 completionTokens int64 @@ -338,36 +277,3 @@ func formatAge(deltaMs int64) string { } return fmt.Sprintf("%dd ago", int(d.Hours()/24)) } - -func (oc *AIClient) buildToolsStatusText(meta *PortalMetadata) string { - var sb strings.Builder - sb.WriteString("Tool Status:\n\n") - - toolsList := oc.buildAvailableTools(meta) - slices.SortFunc(toolsList, func(a, b ToolInfo) int { - return cmp.Compare(a.Name, b.Name) - }) - - sb.WriteString("Tools:\n") - for _, tool := range toolsList { - status := "✗" - if tool.Enabled { - status = "✓" - } - desc := tool.Description - if desc == "" { - desc = tool.DisplayName - } - reason := "" - if !tool.Enabled && tool.Reason != "" { - reason = fmt.Sprintf(" (%s)", tool.Reason) - } - sb.WriteString(fmt.Sprintf(" [%s] %s: %s%s\n", status, tool.Name, desc, reason)) - } - - if meta != nil && !oc.getModelCapabilitiesForMeta(meta).SupportsToolCalling { - sb.WriteString(fmt.Sprintf("\nNote: Current model (%s) may not support tool calling.\n", oc.effectiveModel(meta))) - } - - return strings.TrimSpace(sb.String()) -} diff --git a/pkg/connector/tool_availability_configured_test.go b/pkg/connector/tool_availability_configured_test.go index be454bbe..92642edf 100644 --- a/pkg/connector/tool_availability_configured_test.go +++ b/pkg/connector/tool_availability_configured_test.go @@ -6,9 +6,10 @@ import ( "strings" "testing" - "github.com/beeper/ai-bridge/pkg/shared/toolspec" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" + + "github.com/beeper/ai-bridge/pkg/shared/toolspec" ) func boolPtr(v bool) *bool { diff --git a/pkg/connector/trace.go b/pkg/connector/trace.go index ec95b91f..825e139f 100644 --- a/pkg/connector/trace.go +++ b/pkg/connector/trace.go @@ -1,10 +1,5 @@ package connector -func traceLevel(meta *PortalMetadata) string { - _ = meta - return "off" -} - func traceEnabled(meta *PortalMetadata) bool { _ = meta return false