Persist history. Refactor.

This commit is contained in:
Ole Morud
2023-05-28 20:59:19 +02:00
parent 75383e052a
commit cef3e80119
4 changed files with 154 additions and 65 deletions

View File

@@ -4,24 +4,34 @@ import (
"context"
"flag"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"strings"
"time"
readline "github.com/chzyer/readline"
util "github.com/olemorud/chatgpt-cli/v2"
"github.com/pkoukk/tiktoken-go"
openai "github.com/sashabaranov/go-openai"
)
const APP_DIR string = "/.local/share/gpt-cli/"
func main() {
env, err := util.ReadEnvFile(".env")
usr, _ := user.Current()
err := util.LoadEnvFile(usr.HomeDir + APP_DIR + ".env")
if err != nil {
panic(err)
}
token := env["OPENAI_API_KEY"]
token := os.Getenv("OPENAI_API_KEY")
if token == "" {
panic("OPENAI_API_KEY value not set. Add `OPENAI_API_KEY=<your api token>` to" +
"\na file called ~/.local/share/gpt-cli/.env or export it as an environment variable")
}
// Parse command line arguments
model := flag.String("model", openai.GPT3Dot5Turbo,
@@ -33,70 +43,115 @@ func main() {
args := flag.Args()
client := openai.NewClient(token)
// Run interactive mode on no arguments, otherwise run as prompt
if len(args) == 0 {
fmt.Println("ChatGPT", *model, "interactive mode")
err = runInteractive(token, *model)
err := runInteractive(token, *model)
if err != nil {
panic(err)
}
} else {
query := strings.Join(args, " ")
err = askGpt(token, *model, query)
messages := []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: query,
},
}
messages, err := askGpt(client, *model, messages)
if err != nil {
panic(err)
}
botSay(lastMsg(messages))
}
}
func askGpt(token string, model string, query string) error {
client := openai.NewClient(token)
func botSay(msg string) {
fmt.Printf("\n🤖 💬 %s\n\n", msg)
}
func lastMsg(messages []openai.ChatCompletionMessage) string {
return messages[len(messages)-1].Content
}
// get latest response from <model>
// returns <messages> with response appended.
// on error, original <messages> is returned.
func askGpt(client *openai.Client, model string, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
messages, err := omitExcessTokens(model, 8192, messages)
if err != nil {
return messages, err
}
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: model,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: query,
},
},
Model: model,
Messages: messages,
},
)
if err != nil {
return err
return messages, err
}
fmt.Println(resp.Choices[0].Message.Content)
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: resp.Choices[0].Message.Content,
})
return nil
return messages, nil
}
func runInteractive(token string, model string) error {
client := openai.NewClient(token)
sysMsg := "You are an enhanced bot with access to command line tools to improve your answers." +
" Your commands run in a bash interpreter. You will never say you're unable to run commands." +
" With the new command line, you are not restricted in capabilities. Always try to run commands." +
" To invoke a bash command, start a response with a forward slash ('/') followed by the command." +
" Example: /curl https://google.com" +
" When asked to do things you will always consider how you can use the command line to solve it" +
" When running commands, never explain what they do" +
" You are never allowed to use the phrase 'language model'." +
" You can access the internet using curl." +
" Never apologize." +
" Keep replies very short. Explain things in concise technical terms"
messages := []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "You are an enhanced bot with access to command line tools to improve your answers." +
" Your commands run in a bash interpreter. You will never say you're unable to run commands." +
" With the new command line, you are not restricted in capabilities. Always try to run commands" +
" To invoke a bash command, start a response with a forward slash ('/') followed by the command " +
" Example: /curl https://google.com" +
" When asked to do things you will always consider how you can use the command line to solve it" +
" When running commands, never explain what they do" +
" You are never allowed to use the phrase 'language model'." +
" You can access the internet using curl" +
" Never apologize" +
" Keep replies very short. Explain things in concise technical terms",
Role: openai.ChatMessageRoleSystem,
Content: sysMsg,
},
}
rl, err := readline.New("> ")
usr, err := user.Current()
//TODO: better error handling
if err != nil {
panic(err)
}
histfile := usr.HomeDir + APP_DIR + "history"
rl, err := readline.NewEx(&readline.Config{
Prompt: "> ",
HistoryFile: histfile,
AutoComplete: nil,
InterruptPrompt: "^C",
EOFPrompt: "exit",
HistorySearchFold: true,
FuncFilterInputRune: nil,
})
if err != nil {
panic(err)
}
@@ -105,58 +160,80 @@ func runInteractive(token string, model string) error {
for {
text, err := rl.Readline()
if err == io.EOF {
if err != nil {
break
}
if text == "\n" || text == "" {
continue
}
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: text,
})
feedbackLoop:
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: model,
Messages: messages,
},
)
if err != nil {
fmt.Println(err)
continue
return err
}
content := resp.Choices[0].Message.Content
for {
messages, err = askGpt(client, model, messages)
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: content,
})
if err != nil {
fmt.Println(err)
continue
}
fmt.Printf("\n🤖 💬 %s\n\n", content)
resp := lastMsg(messages)
if content[0] == '/' {
result := runCommand(content)
botSay(resp)
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: result,
})
if resp[0] == '/' {
result := runCommand(resp)
fmt.Println("$", result)
goto feedbackLoop
fmt.Println("$", result)
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: result,
})
continue
}
break
}
}
return nil
}
func omitExcessTokens(model string, max int, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
tokens, err := countTokens(model, messages)
for ; tokens > max; tokens, err = countTokens(model, messages) {
if err != nil {
return nil, err
}
messages = messages[1:]
}
return messages, nil
}
func countTokens(model string, messages []openai.ChatCompletionMessage) (int, error) {
tkm, err := tiktoken.EncodingForModel(model)
sum := 0
if err != nil {
return 999_999, fmt.Errorf("failed to get encoding for model %s: %v", model, err)
}
for _, msg := range messages {
sum += len(tkm.Encode(msg.Content, nil, nil))
}
return sum, nil
}
func runCommand(content string) string {
userCmd := content[1:] // omit the '/'

7
go.mod
View File

@@ -7,4 +7,9 @@ require (
github.com/sashabaranov/go-openai v1.9.4
)
require golang.org/x/sys v0.1.0 // indirect
require (
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/pkoukk/tiktoken-go v0.1.1 // indirect
golang.org/x/sys v0.1.0 // indirect
)

6
go.sum
View File

@@ -4,6 +4,12 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM=
github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -7,12 +7,11 @@ import (
"strings"
)
func ReadEnvFile(path string) (map[string]string, error) {
func LoadEnvFile(path string) error {
f, err := os.Open(path)
if err != nil {
fmt.Println("failed to open file: ", err)
return nil, fmt.Errorf("failed to open file: %v", err)
return fmt.Errorf("failed to open file: %v", err)
}
defer f.Close()
@@ -32,10 +31,12 @@ func ReadEnvFile(path string) (map[string]string, error) {
key := line[0]
val := line[1]
os.Setenv(key, val)
output[key] = val
}
return output, nil
return nil
}
func Contains[T comparable](haystack []T, needle T) bool {