feat: change max token handling
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"paraclub-ai-mailer/internal/fetcher"
|
||||
"paraclub-ai-mailer/internal/imap"
|
||||
"paraclub-ai-mailer/internal/logger"
|
||||
"paraclub-ai-mailer/internal/tokens"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -56,6 +57,12 @@ func main() {
|
||||
|
||||
fetcher := fetcher.New()
|
||||
aiProcessor := ai.New(cfg.AI)
|
||||
|
||||
tokenCounter, err := tokens.New()
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Failed to initialize token counter")
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.Debug("All components initialized successfully")
|
||||
|
||||
// Setup signal handling for graceful shutdown
|
||||
@@ -76,7 +83,7 @@ func main() {
|
||||
return
|
||||
case <-ticker.C:
|
||||
logger.Debug("Processing tick started")
|
||||
processEmails(imapClient, fetcher, aiProcessor, cfg)
|
||||
processEmails(imapClient, fetcher, aiProcessor, tokenCounter, cfg)
|
||||
logger.Debug("Processing tick completed")
|
||||
}
|
||||
}
|
||||
@@ -89,7 +96,7 @@ func main() {
|
||||
logger.Info("Application shutdown complete")
|
||||
}
|
||||
|
||||
func processEmails(imapClient *imap.IMAPClient, fetcher *fetcher.Fetcher, aiProcessor *ai.AI, cfg *config.Config) {
|
||||
func processEmails(imapClient *imap.IMAPClient, fetcher *fetcher.Fetcher, aiProcessor *ai.AI, tokenCounter *tokens.TokenCounter, cfg *config.Config) {
|
||||
logger.Debug("Starting email processing cycle")
|
||||
|
||||
// Fetch unprocessed emails
|
||||
@@ -118,44 +125,55 @@ func processEmails(imapClient *imap.IMAPClient, fetcher *fetcher.Fetcher, aiProc
|
||||
// Process each email
|
||||
var processedCount, errorCount, skippedCount int
|
||||
for _, email := range emails {
|
||||
emailBodySize := len(email.Body)
|
||||
logger.WithFields(logrus.Fields{
|
||||
"subject": email.Subject,
|
||||
"from": email.From,
|
||||
"messageId": email.ID,
|
||||
"bodySizeBytes": emailBodySize,
|
||||
}).Info("Processing email")
|
||||
|
||||
// Check email size limit
|
||||
if cfg.Processing.MaxEmailSizeBytes > 0 && emailBodySize > cfg.Processing.MaxEmailSizeBytes {
|
||||
// Extract clean email content (removes attachments, MIME boundaries, headers, converts HTML to text)
|
||||
cleanEmailContent := imap.ExtractMessageContent(email.Body)
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"subject": email.Subject,
|
||||
"cleanSize": len(cleanEmailContent),
|
||||
}).Debug("Extracted clean email content")
|
||||
|
||||
// Calculate token count for validation (use English as default to avoid API call)
|
||||
// Language detection will happen during actual GenerateReply
|
||||
systemPrompt := aiProcessor.BuildSystemPrompt("English")
|
||||
userPrompt := aiProcessor.BuildUserPrompt(contextContent, cleanEmailContent)
|
||||
estimatedTokens := tokenCounter.EstimateFullRequest(systemPrompt, userPrompt)
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"subject": email.Subject,
|
||||
"estimatedTokens": estimatedTokens,
|
||||
"maxTokens": cfg.Processing.MaxTokens,
|
||||
}).Debug("Calculated token estimate for email")
|
||||
|
||||
// Check token limit
|
||||
if cfg.Processing.MaxTokens > 0 && estimatedTokens > cfg.Processing.MaxTokens {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"subject": email.Subject,
|
||||
"from": email.From,
|
||||
"bodySizeBytes": emailBodySize,
|
||||
"maxSizeBytes": cfg.Processing.MaxEmailSizeBytes,
|
||||
}).Warn("Email body exceeds size limit, skipping")
|
||||
"estimatedTokens": estimatedTokens,
|
||||
"maxTokens": cfg.Processing.MaxTokens,
|
||||
}).Warn("Email exceeds token limit, marking as AI-processed but keeping in inbox")
|
||||
|
||||
skippedCount++
|
||||
|
||||
// Mark as AI-processed to prevent reprocessing
|
||||
// Mark as AI-processed to prevent reprocessing, but DON'T move the email
|
||||
if markErr := imapClient.MarkAsAIProcessed(email); markErr != nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"subject": email.Subject,
|
||||
"error": markErr,
|
||||
}).Error("Failed to mark oversized email as AI-processed")
|
||||
} else {
|
||||
logger.WithField("subject", email.Subject).Info("Marked oversized email as AI-processed (email remains in inbox)")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract clean email content (remove MIME boundaries, headers, etc.)
|
||||
cleanEmailContent := imap.ExtractMessageContent(email.Body)
|
||||
cleanContentSize := len(cleanEmailContent)
|
||||
logger.WithFields(logrus.Fields{
|
||||
"subject": email.Subject,
|
||||
"rawSize": emailBodySize,
|
||||
"cleanSize": cleanContentSize,
|
||||
"sizeReduction": emailBodySize - cleanContentSize,
|
||||
}).Debug("Extracted clean email content")
|
||||
|
||||
// Generate AI response with clean content
|
||||
response, err := aiProcessor.GenerateReply(cleanEmailContent, contextContent)
|
||||
if err != nil {
|
||||
|
||||
@@ -24,7 +24,10 @@ polling:
|
||||
interval: "5m" # Examples: "30s", "1m", "1h"
|
||||
|
||||
processing:
|
||||
max_email_size_bytes: 102400 # Maximum email body size in bytes (100KB), 0 = no limit
|
||||
max_tokens: 12000 # Maximum total tokens for API request (system + context + email)
|
||||
# Recommended: 8000-12000 for most models
|
||||
# Set to 0 for no limit
|
||||
# Note: Attachments are automatically stripped before counting
|
||||
skip_junk_emails: false # Skip emails marked as junk/spam (not yet implemented)
|
||||
|
||||
logging:
|
||||
|
||||
@@ -51,7 +51,7 @@ type LoggingConfig struct {
|
||||
}
|
||||
|
||||
type ProcessingConfig struct {
|
||||
MaxEmailSizeBytes int `yaml:"max_email_size_bytes"` // Maximum email body size in bytes (0 = no limit)
|
||||
MaxTokens int `yaml:"max_tokens"` // Maximum total tokens for API request (0 = no limit)
|
||||
SkipJunkEmails bool `yaml:"skip_junk_emails"` // Skip emails marked as junk/spam
|
||||
}
|
||||
|
||||
@@ -125,7 +125,7 @@ func Load(path string) (*Config, error) {
|
||||
"pollingInterval": config.Polling.Interval,
|
||||
"loggingLevel": config.Logging.Level,
|
||||
"loggingFilePath": config.Logging.FilePath,
|
||||
"maxEmailSizeBytes": config.Processing.MaxEmailSizeBytes,
|
||||
"maxTokens": config.Processing.MaxTokens,
|
||||
"skipJunkEmails": config.Processing.SkipJunkEmails,
|
||||
}).Debug("Configuration loaded successfully")
|
||||
|
||||
|
||||
3
go.mod
3
go.mod
@@ -10,7 +10,10 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
)
|
||||
|
||||
7
go.sum
7
go.sum
@@ -1,6 +1,8 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/emersion/go-imap v1.2.1 h1:+s9ZjMEjOB8NzZMVTM3cCenz2JrQIGGo5j1df19WjTA=
|
||||
github.com/emersion/go-imap v1.2.1/go.mod h1:Qlx1FSx2FTxjnjWpIlVNEuX+ylerZQNFE5NsmKFSejY=
|
||||
github.com/emersion/go-message v0.15.0/go.mod h1:wQUEfE+38+7EW8p8aZ96ptg6bAb1iwdgej19uXASlE4=
|
||||
@@ -8,6 +10,10 @@ github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTe
|
||||
github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 h1:oP4q0fw+fOSWn3DfFi4EXdT+B+gTtzx8GC9xsc26Znk=
|
||||
github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
|
||||
github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594/go.mod h1:aqO8z8wPrjkscevZJFVE1wXJrLpC5LtJG7fqLOsPb2U=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
@@ -15,6 +21,7 @@ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVs
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -43,6 +43,31 @@ func New(cfg config.AIConfig) *AI {
|
||||
}
|
||||
}
|
||||
|
||||
// BuildSystemPrompt creates the system prompt for a given language
|
||||
// Exposed for token counting without making an API call
|
||||
func (a *AI) BuildSystemPrompt(lang string) string {
|
||||
return fmt.Sprintf(`You are a helpful assistant who responds to emails.
|
||||
Your primary goal is to answer the user's query (found in the 'Email Body') by primarily using the information available in the 'Additional Context' and your general knowledge.
|
||||
While the 'Email Body' provides the question, your answer should be synthesized from the context and your understanding, not by directly repeating or solely relying on unverified information from the 'Email Body' itself.
|
||||
|
||||
Instructions:
|
||||
- Language: Your response must be entirely in %s, regardless of the language used in the email content or context.
|
||||
- Format: CRITICAL: Your reply MUST be raw HTML. Use appropriate HTML tags for structure and styling. For example, wrap paragraphs in <p>...</p> tags and use <br> for line breaks if needed within a paragraph. Even a short sentence must be wrapped in HTML (e.g., <p>Yes.</p>).
|
||||
- Markdown: Do NOT wrap the HTML in markdown code blocks (e.g., %s).
|
||||
- Extraneous Text: Do not include a subject line. Do not include explanations, commentary, or any extra text that is not part of the direct answer.
|
||||
- Closing: Avoid generic closing statements like "If you have further questions...". Focus solely on answering the email.
|
||||
`, lang, "```html ... ```")
|
||||
}
|
||||
|
||||
// BuildUserPrompt creates the user message with context and email content
|
||||
func (a *AI) BuildUserPrompt(contextContent map[string]string, emailContent string) string {
|
||||
var context string
|
||||
for url, content := range contextContent {
|
||||
context += fmt.Sprintf("\nContext from %s:\n%s\n", url, content)
|
||||
}
|
||||
return fmt.Sprintf("### Additional Context:\n%s\n\n### Email Body:\n%s", context, emailContent)
|
||||
}
|
||||
|
||||
func (a *AI) detectLanguage(emailContent string) (string, error) {
|
||||
logger.WithField("emailContentLength", len(emailContent)).Debug("Starting language detection")
|
||||
|
||||
@@ -77,32 +102,11 @@ func (a *AI) GenerateReply(emailContent string, contextContent map[string]string
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prepare context from all URLs
|
||||
var context string
|
||||
for url, content := range contextContent {
|
||||
context += fmt.Sprintf("\nContext from %s:\n%s\n", url, content)
|
||||
logger.WithFields(logrus.Fields{
|
||||
"url": url,
|
||||
"contentLength": len(content),
|
||||
}).Debug("Added context from URL")
|
||||
}
|
||||
// Build prompts using exposed methods
|
||||
systemMsg := a.BuildSystemPrompt(lang)
|
||||
userMsg := a.BuildUserPrompt(contextContent, emailContent)
|
||||
|
||||
// Prepare the system message with language-specific instruction
|
||||
systemMsg := fmt.Sprintf(`You are a helpful assistant who responds to emails.
|
||||
Your primary goal is to answer the user's query (found in the 'Email Body') by primarily using the information available in the 'Additional Context' and your general knowledge.
|
||||
While the 'Email Body' provides the question, your answer should be synthesized from the context and your understanding, not by directly repeating or solely relying on unverified information from the 'Email Body' itself.
|
||||
|
||||
Instructions:
|
||||
- Language: Your response must be entirely in %s, regardless of the language used in the email content or context.
|
||||
- Format: CRITICAL: Your reply MUST be raw HTML. Use appropriate HTML tags for structure and styling. For example, wrap paragraphs in <p>...</p> tags and use <br> for line breaks if needed within a paragraph. Even a short sentence must be wrapped in HTML (e.g., <p>Yes.</p>).
|
||||
- Markdown: Do NOT wrap the HTML in markdown code blocks (e.g., %s).
|
||||
- Extraneous Text: Do not include a subject line. Do not include explanations, commentary, or any extra text that is not part of the direct answer.
|
||||
- Closing: Avoid generic closing statements like "If you have further questions...". Focus solely on answering the email.
|
||||
`, lang, "```html ... ```")
|
||||
logger.WithFields(logrus.Fields{
|
||||
"systemprompt": systemMsg,
|
||||
}).Debug("Generating system prompt")
|
||||
userMsg := fmt.Sprintf("### Additional Context:\n%s\n\n### Email Body:\n%s", context, emailContent)
|
||||
logger.Debug("Generated system and user prompts")
|
||||
|
||||
messages := []Message{
|
||||
{Role: "system", Content: systemMsg},
|
||||
|
||||
@@ -373,7 +373,8 @@ func handleMultipartMessage(reader io.Reader, boundary string) string {
|
||||
}
|
||||
|
||||
mReader := multipart.NewReader(reader, boundary)
|
||||
var textContent string
|
||||
var textPlainContent string
|
||||
var textHTMLContent string
|
||||
partIndex := 0
|
||||
|
||||
for {
|
||||
@@ -387,39 +388,75 @@ func handleMultipartMessage(reader io.Reader, boundary string) string {
|
||||
}
|
||||
|
||||
contentType := part.Header.Get("Content-Type")
|
||||
contentDisposition := part.Header.Get("Content-Disposition")
|
||||
contentTransferEncoding := strings.ToLower(part.Header.Get("Content-Transfer-Encoding"))
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"partIndex": partIndex,
|
||||
"partContentType": contentType,
|
||||
"partDisposition": contentDisposition,
|
||||
"partTransferEncoding": contentTransferEncoding,
|
||||
"partHeaders": part.Header,
|
||||
}).Debug("handleMultipartMessage: Processing part")
|
||||
|
||||
if strings.HasPrefix(contentType, "text/plain") {
|
||||
// Skip attachments
|
||||
if strings.HasPrefix(contentDisposition, "attachment") {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"partIndex": partIndex,
|
||||
"filename": part.FileName(),
|
||||
}).Debug("Skipping attachment part")
|
||||
partIndex++
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip non-text content types (images, videos, applications, etc.)
|
||||
if !strings.HasPrefix(contentType, "text/plain") &&
|
||||
!strings.HasPrefix(contentType, "text/html") {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"partIndex": partIndex,
|
||||
"contentType": contentType,
|
||||
}).Debug("Skipping non-text content type")
|
||||
partIndex++
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle quoted-printable encoding
|
||||
var partReader io.Reader = part
|
||||
if contentTransferEncoding == "quoted-printable" {
|
||||
partReader = quotedprintable.NewReader(part)
|
||||
}
|
||||
// Add handling for "base64" if needed in the future
|
||||
// else if contentTransferEncoding == "base64" {
|
||||
// partReader = base64.NewDecoder(base64.StdEncoding, part)
|
||||
// }
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if _, err := buf.ReadFrom(partReader); err != nil {
|
||||
logger.WithError(err).WithField("partIndex", partIndex).Debug("Failed to read from partReader in multipart")
|
||||
continue // Or handle error more robustly
|
||||
partIndex++
|
||||
continue
|
||||
}
|
||||
textContent = buf.String()
|
||||
// Assuming we only care about the first text/plain part found
|
||||
// If multiple text/plain parts could exist and need concatenation, this logic would need adjustment.
|
||||
break
|
||||
|
||||
// Store text/plain and text/html separately
|
||||
if strings.HasPrefix(contentType, "text/plain") {
|
||||
textPlainContent = buf.String()
|
||||
logger.WithField("textPlainLength", len(textPlainContent)).Debug("Found text/plain part")
|
||||
} else if strings.HasPrefix(contentType, "text/html") {
|
||||
textHTMLContent = buf.String()
|
||||
logger.WithField("textHTMLLength", len(textHTMLContent)).Debug("Found text/html part")
|
||||
}
|
||||
|
||||
partIndex++
|
||||
}
|
||||
|
||||
logger.WithField("textContentResult", textContent).Debug("handleMultipartMessage: Returning textContent")
|
||||
return textContent
|
||||
// Prefer text/plain over text/html
|
||||
if textPlainContent != "" {
|
||||
logger.Debug("handleMultipartMessage: Returning text/plain content")
|
||||
return textPlainContent
|
||||
}
|
||||
|
||||
if textHTMLContent != "" {
|
||||
logger.Debug("handleMultipartMessage: Converting text/html to plain text")
|
||||
return htmlToPlainText(textHTMLContent)
|
||||
}
|
||||
|
||||
logger.Debug("handleMultipartMessage: No text content found")
|
||||
return ""
|
||||
}
|
||||
|
||||
func handleSinglePartMessage(reader io.Reader) string {
|
||||
@@ -434,6 +471,48 @@ func handleSinglePartMessage(reader io.Reader) string {
|
||||
return content
|
||||
}
|
||||
|
||||
// htmlToPlainText converts HTML content to plain text by extracting text nodes
|
||||
func htmlToPlainText(htmlContent string) string {
|
||||
logger.WithField("htmlLength", len(htmlContent)).Debug("Converting HTML to plain text")
|
||||
|
||||
// Simple HTML tag stripping - removes all HTML tags and extracts text
|
||||
var result strings.Builder
|
||||
inTag := false
|
||||
|
||||
for _, char := range htmlContent {
|
||||
switch char {
|
||||
case '<':
|
||||
inTag = true
|
||||
case '>':
|
||||
inTag = false
|
||||
default:
|
||||
if !inTag {
|
||||
result.WriteRune(char)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
plainText := result.String()
|
||||
|
||||
// Clean up excessive whitespace
|
||||
plainText = strings.ReplaceAll(plainText, "\r\n", "\n")
|
||||
plainText = strings.ReplaceAll(plainText, "\r", "\n")
|
||||
|
||||
// Replace multiple spaces with single space
|
||||
lines := strings.Split(plainText, "\n")
|
||||
var cleanLines []string
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed != "" {
|
||||
cleanLines = append(cleanLines, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
result2 := strings.Join(cleanLines, "\n")
|
||||
logger.WithField("plainTextLength", len(result2)).Debug("HTML converted to plain text")
|
||||
return result2
|
||||
}
|
||||
|
||||
func cleanMessageContent(content string, performHeaderStripping bool) string {
|
||||
logger.WithField("inputContentLength", len(content)).Debug("cleanMessageContent: Starting")
|
||||
logger.WithField("performHeaderStripping", performHeaderStripping).Debug("cleanMessageContent: performHeaderStripping flag")
|
||||
|
||||
50
internal/tokens/tokens.go
Normal file
50
internal/tokens/tokens.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package tokens
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"paraclub-ai-mailer/internal/logger"
|
||||
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type TokenCounter struct {
|
||||
encoding *tiktoken.Tiktoken
|
||||
}
|
||||
|
||||
// New creates a token counter with cl100k_base encoding (GPT-4/Claude compatible)
|
||||
func New() (*TokenCounter, error) {
|
||||
enc, err := tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get tiktoken encoding: %w", err)
|
||||
}
|
||||
return &TokenCounter{encoding: enc}, nil
|
||||
}
|
||||
|
||||
// CountString counts tokens in a single string
|
||||
func (tc *TokenCounter) CountString(text string) int {
|
||||
tokens := tc.encoding.Encode(text, nil, nil)
|
||||
return len(tokens)
|
||||
}
|
||||
|
||||
// EstimateFullRequest estimates total tokens for the complete API request
|
||||
// Includes: system prompt + user prompt + message overhead
|
||||
func (tc *TokenCounter) EstimateFullRequest(systemPrompt, userPrompt string) int {
|
||||
systemTokens := tc.CountString(systemPrompt)
|
||||
userTokens := tc.CountString(userPrompt)
|
||||
|
||||
// Add overhead for message structure (~100 tokens for JSON formatting, role labels, etc.)
|
||||
overhead := 100
|
||||
|
||||
total := systemTokens + userTokens + overhead
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"systemTokens": systemTokens,
|
||||
"userTokens": userTokens,
|
||||
"overhead": overhead,
|
||||
"total": total,
|
||||
}).Debug("Token estimation breakdown")
|
||||
|
||||
return total
|
||||
}
|
||||
Reference in New Issue
Block a user