From cef3e80119c2c6516bef0505f0771fbc2085772e Mon Sep 17 00:00:00 2001 From: Ole Morud Date: Sun, 28 May 2023 20:59:19 +0200 Subject: [PATCH] Persist history. Refactor. --- cmd/ask.go | 197 +++++++++++++++++++++++++++++++++++++---------------- go.mod | 7 +- go.sum | 6 ++ util.go | 9 +-- 4 files changed, 154 insertions(+), 65 deletions(-) diff --git a/cmd/ask.go b/cmd/ask.go index 49a27c0..23bdabf 100644 --- a/cmd/ask.go +++ b/cmd/ask.go @@ -4,24 +4,34 @@ import ( "context" "flag" "fmt" - "io" + "os" "os/exec" + "os/user" "strings" "time" readline "github.com/chzyer/readline" util "github.com/olemorud/chatgpt-cli/v2" + "github.com/pkoukk/tiktoken-go" openai "github.com/sashabaranov/go-openai" ) +const APP_DIR string = "/.local/share/gpt-cli/" + func main() { - env, err := util.ReadEnvFile(".env") + usr, _ := user.Current() + err := util.LoadEnvFile(usr.HomeDir + APP_DIR + ".env") if err != nil { panic(err) } - token := env["OPENAI_API_KEY"] + token := os.Getenv("OPENAI_API_KEY") + + if token == "" { + panic("OPENAI_API_KEY value not set. Add `OPENAI_API_KEY=` to" + + "\na file called ~/.local/share/gpt-cli/.env or export it as an environment variable") + } // Parse command line arguments model := flag.String("model", openai.GPT3Dot5Turbo, @@ -33,70 +43,115 @@ func main() { args := flag.Args() + client := openai.NewClient(token) + // Run interactive mode on no arguments, otherwise run as prompt if len(args) == 0 { fmt.Println("ChatGPT", *model, "interactive mode") - err = runInteractive(token, *model) + err := runInteractive(token, *model) if err != nil { panic(err) } } else { query := strings.Join(args, " ") - err = askGpt(token, *model, query) + + messages := []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: query, + }, + } + + messages, err := askGpt(client, *model, messages) if err != nil { panic(err) } + + botSay(lastMsg(messages)) } } -func askGpt(token string, model string, query string) error { - client := openai.NewClient(token) +func botSay(msg string) { + fmt.Printf("\nšŸ¤– šŸ’¬ %s\n\n", msg) +} + +func lastMsg(messages []openai.ChatCompletionMessage) string { + return messages[len(messages)-1].Content +} + +// get latest response from +// returns with response appended. +// on error, original is returned. +func askGpt(client *openai.Client, model string, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) { + messages, err := omitExcessTokens(model, 8192, messages) + + if err != nil { + return messages, err + } resp, err := client.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ - Model: model, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: query, - }, - }, + Model: model, + Messages: messages, }, ) if err != nil { - return err + return messages, err } - fmt.Println(resp.Choices[0].Message.Content) + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: resp.Choices[0].Message.Content, + }) - return nil + return messages, nil } func runInteractive(token string, model string) error { client := openai.NewClient(token) + sysMsg := "You are an enhanced bot with access to command line tools to improve your answers." + + " Your commands run in a bash interpreter. You will never say you're unable to run commands." + + " With the new command line, you are not restricted in capabilities. Always try to run commands." + + " To invoke a bash command, start a response with a forward slash ('/') followed by the command." + + " Example: /curl https://google.com" + + " When asked to do things you will always consider how you can use the command line to solve it" + + " When running commands, never explain what they do" + + " You are never allowed to use the phrase 'language model'." + + " You can access the internet using curl." + + " Never apologize." + + " Keep replies very short. Explain things in concise technical terms" + messages := []openai.ChatCompletionMessage{ { - Role: openai.ChatMessageRoleSystem, - Content: "You are an enhanced bot with access to command line tools to improve your answers." + - " Your commands run in a bash interpreter. You will never say you're unable to run commands." + - " With the new command line, you are not restricted in capabilities. Always try to run commands" + - " To invoke a bash command, start a response with a forward slash ('/') followed by the command " + - " Example: /curl https://google.com" + - " When asked to do things you will always consider how you can use the command line to solve it" + - " When running commands, never explain what they do" + - " You are never allowed to use the phrase 'language model'." + - " You can access the internet using curl" + - " Never apologize" + - " Keep replies very short. Explain things in concise technical terms", + Role: openai.ChatMessageRoleSystem, + Content: sysMsg, }, } - rl, err := readline.New("> ") + usr, err := user.Current() + //TODO: better error handling + if err != nil { + panic(err) + } + + histfile := usr.HomeDir + APP_DIR + "history" + + rl, err := readline.NewEx(&readline.Config{ + Prompt: "> ", + HistoryFile: histfile, + AutoComplete: nil, + InterruptPrompt: "^C", + EOFPrompt: "exit", + + HistorySearchFold: true, + FuncFilterInputRune: nil, + }) + if err != nil { panic(err) } @@ -105,58 +160,80 @@ func runInteractive(token string, model string) error { for { text, err := rl.Readline() - if err == io.EOF { + if err != nil { break } - if text == "\n" || text == "" { - continue - } - messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, Content: text, }) - feedbackLoop: - resp, err := client.CreateChatCompletion( - context.Background(), - openai.ChatCompletionRequest{ - Model: model, - Messages: messages, - }, - ) - if err != nil { - fmt.Println(err) - continue + return err } - content := resp.Choices[0].Message.Content + for { + messages, err = askGpt(client, model, messages) - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, - Content: content, - }) + if err != nil { + fmt.Println(err) + continue + } - fmt.Printf("\nšŸ¤– šŸ’¬ %s\n\n", content) + resp := lastMsg(messages) - if content[0] == '/' { - result := runCommand(content) + botSay(resp) - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleUser, - Content: result, - }) + if resp[0] == '/' { + result := runCommand(resp) - fmt.Println("$", result) - goto feedbackLoop + fmt.Println("$", result) + + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: result, + }) + + continue + } + + break } } return nil } +func omitExcessTokens(model string, max int, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) { + tokens, err := countTokens(model, messages) + + for ; tokens > max; tokens, err = countTokens(model, messages) { + if err != nil { + return nil, err + } + + messages = messages[1:] + } + + return messages, nil +} + +func countTokens(model string, messages []openai.ChatCompletionMessage) (int, error) { + tkm, err := tiktoken.EncodingForModel(model) + sum := 0 + + if err != nil { + return 999_999, fmt.Errorf("failed to get encoding for model %s: %v", model, err) + } + + for _, msg := range messages { + sum += len(tkm.Encode(msg.Content, nil, nil)) + } + + return sum, nil +} + func runCommand(content string) string { userCmd := content[1:] // omit the '/' diff --git a/go.mod b/go.mod index 02212fc..2dac89d 100644 --- a/go.mod +++ b/go.mod @@ -7,4 +7,9 @@ require ( github.com/sashabaranov/go-openai v1.9.4 ) -require golang.org/x/sys v0.1.0 // indirect +require ( + github.com/dlclark/regexp2 v1.8.1 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/pkoukk/tiktoken-go v0.1.1 // indirect + golang.org/x/sys v0.1.0 // indirect +) diff --git a/go.sum b/go.sum index f6f1b13..e1666d7 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,12 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= +github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo= +github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw= github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM= github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/util.go b/util.go index 0b704db..e86c7ea 100644 --- a/util.go +++ b/util.go @@ -7,12 +7,11 @@ import ( "strings" ) -func ReadEnvFile(path string) (map[string]string, error) { +func LoadEnvFile(path string) error { f, err := os.Open(path) if err != nil { - fmt.Println("failed to open file: ", err) - return nil, fmt.Errorf("failed to open file: %v", err) + return fmt.Errorf("failed to open file: %v", err) } defer f.Close() @@ -32,10 +31,12 @@ func ReadEnvFile(path string) (map[string]string, error) { key := line[0] val := line[1] + os.Setenv(key, val) + output[key] = val } - return output, nil + return nil } func Contains[T comparable](haystack []T, needle T) bool {