51 lines
1.3 KiB
Go
51 lines
1.3 KiB
Go
package tokens
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"paraclub-ai-mailer/internal/logger"
|
|
|
|
"github.com/pkoukk/tiktoken-go"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
type TokenCounter struct {
|
|
encoding *tiktoken.Tiktoken
|
|
}
|
|
|
|
// New creates a token counter with cl100k_base encoding (GPT-4/Claude compatible)
|
|
func New() (*TokenCounter, error) {
|
|
enc, err := tiktoken.GetEncoding("cl100k_base")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get tiktoken encoding: %w", err)
|
|
}
|
|
return &TokenCounter{encoding: enc}, nil
|
|
}
|
|
|
|
// CountString counts tokens in a single string
|
|
func (tc *TokenCounter) CountString(text string) int {
|
|
tokens := tc.encoding.Encode(text, nil, nil)
|
|
return len(tokens)
|
|
}
|
|
|
|
// EstimateFullRequest estimates total tokens for the complete API request
|
|
// Includes: system prompt + user prompt + message overhead
|
|
func (tc *TokenCounter) EstimateFullRequest(systemPrompt, userPrompt string) int {
|
|
systemTokens := tc.CountString(systemPrompt)
|
|
userTokens := tc.CountString(userPrompt)
|
|
|
|
// Add overhead for message structure (~100 tokens for JSON formatting, role labels, etc.)
|
|
overhead := 100
|
|
|
|
total := systemTokens + userTokens + overhead
|
|
|
|
logger.WithFields(logrus.Fields{
|
|
"systemTokens": systemTokens,
|
|
"userTokens": userTokens,
|
|
"overhead": overhead,
|
|
"total": total,
|
|
}).Debug("Token estimation breakdown")
|
|
|
|
return total
|
|
}
|