Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions ilink/context_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package ilink

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
)

var contextStoreMu sync.Mutex

type contextTokenData struct {
Tokens map[string]string `json:"tokens"`
}

func contextTokenPath(botID string) (string, error) {
dir, err := AccountsDir()
if err != nil {
return "", err
}
return filepath.Join(dir, NormalizeAccountID(botID)+".contexts.json"), nil
}

// SaveContextToken stores the latest iLink context token for a user.
func SaveContextToken(botID, userID, token string) error {
if botID == "" || userID == "" || token == "" {
return nil
}

contextStoreMu.Lock()
defer contextStoreMu.Unlock()

path, err := contextTokenPath(botID)
if err != nil {
return err
}

data := contextTokenData{Tokens: map[string]string{}}
if raw, err := os.ReadFile(path); err == nil {
_ = json.Unmarshal(raw, &data)
}
if data.Tokens == nil {
data.Tokens = map[string]string{}
}
data.Tokens[userID] = token

if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return fmt.Errorf("create context token dir: %w", err)
}

raw, err := json.MarshalIndent(data, "", " ")
if err != nil {
return fmt.Errorf("marshal context tokens: %w", err)
}
if err := os.WriteFile(path, raw, 0o600); err != nil {
return fmt.Errorf("write context tokens: %w", err)
}
return nil
}

// LoadContextToken returns the latest cached iLink context token for a user.
func LoadContextToken(botID, userID string) (string, error) {
if botID == "" || userID == "" {
return "", nil
}

contextStoreMu.Lock()
defer contextStoreMu.Unlock()

path, err := contextTokenPath(botID)
if err != nil {
return "", err
}

raw, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return "", nil
}
return "", fmt.Errorf("read context tokens: %w", err)
}

var data contextTokenData
if err := json.Unmarshal(raw, &data); err != nil {
return "", fmt.Errorf("parse context tokens: %w", err)
}
return data.Tokens[userID], nil
}

// ClearContextTokens removes cached iLink context tokens for a bot account.
func ClearContextTokens(botID string) error {
if botID == "" {
return nil
}

contextStoreMu.Lock()
defer contextStoreMu.Unlock()

path, err := contextTokenPath(botID)
if err != nil {
return err
}
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove context tokens: %w", err)
}
return nil
}
56 changes: 56 additions & 0 deletions ilink/context_store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package ilink

import (
"os"
"path/filepath"
"testing"
)

func TestContextTokenStoreRoundTrip(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)

botID := "bot@example"
userID := "user@im.wechat"
token := "context-token"

if err := SaveContextToken(botID, userID, token); err != nil {
t.Fatalf("SaveContextToken() error = %v", err)
}

got, err := LoadContextToken(botID, userID)
if err != nil {
t.Fatalf("LoadContextToken() error = %v", err)
}
if got != token {
t.Fatalf("LoadContextToken() = %q, want %q", got, token)
}

path := filepath.Join(home, ".weclaw", "accounts", "bot-example.contexts.json")
if _, err := os.Stat(path); err != nil {
t.Fatalf("context token file was not written: %v", err)
}
}

func TestClearContextTokens(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)

botID := "bot@example"
userID := "user@im.wechat"

if err := SaveContextToken(botID, userID, "context-token"); err != nil {
t.Fatalf("SaveContextToken() error = %v", err)
}
if err := ClearContextTokens(botID); err != nil {
t.Fatalf("ClearContextTokens() error = %v", err)
}

got, err := LoadContextToken(botID, userID)
if err != nil {
t.Fatalf("LoadContextToken() error = %v", err)
}
if got != "" {
t.Fatalf("LoadContextToken() = %q, want empty", got)
}
}
3 changes: 3 additions & 0 deletions ilink/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ func (m *Monitor) Run(ctx context.Context) error {
log.Printf("[monitor] session expired, resetting sync buf")
m.getUpdatesBuf = ""
m.saveBuf()
if err := ClearContextTokens(m.client.BotID()); err != nil {
log.Printf("[monitor] failed to clear context tokens: %v", err)
}
} else {
// Sync buf already empty but still getting session expired:
// the bot token itself has expired. The user needs to re-login.
Expand Down
9 changes: 6 additions & 3 deletions messaging/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ type Handler struct {
customAliases map[string]string // custom alias -> agent name (from config)
factory AgentFactory
saveDefault SaveDefaultFunc
contextTokens sync.Map // map[userID]contextToken
saveDir string // directory to save images/files to
seenMsgs sync.Map // map[int64]time.Time — dedup by message_id
contextTokens sync.Map // map[userID]contextToken
saveDir string // directory to save images/files to
seenMsgs sync.Map // map[int64]time.Time — dedup by message_id
}

// NewHandler creates a new message handler.
Expand Down Expand Up @@ -299,6 +299,9 @@ func (h *Handler) HandleMessage(ctx context.Context, client *ilink.Client, msg i

// Store context token for this user
h.contextTokens.Store(msg.FromUserID, msg.ContextToken)
if err := ilink.SaveContextToken(client.BotID(), msg.FromUserID, msg.ContextToken); err != nil {
log.Printf("[handler] failed to save context token for %s: %v", msg.FromUserID, err)
}

// Generate a clientID for this reply (used to correlate typing → finish)
clientID := NewClientID()
Expand Down
7 changes: 7 additions & 0 deletions messaging/media.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ func sendMediaData(ctx context.Context, client *ilink.Client, toUserID, fileName
if fileName == "" {
fileName = "file"
}
if contextToken == "" {
token, err := ilink.LoadContextToken(client.BotID(), toUserID)
if err != nil {
return fmt.Errorf("load context token: %w", err)
}
contextToken = token
}

cdnMediaType, itemType := classifyMedia(contentType, source)

Expand Down
7 changes: 7 additions & 0 deletions messaging/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ func SendTextReply(ctx context.Context, client *ilink.Client, toUserID, text, co
if clientID == "" {
clientID = NewClientID()
}
if contextToken == "" {
token, err := ilink.LoadContextToken(client.BotID(), toUserID)
if err != nil {
return fmt.Errorf("load context token: %w", err)
}
contextToken = token
}

// Convert markdown to plain text for WeChat display
plainText := MarkdownToPlainText(text)
Expand Down