Persist history. Refactor.

This commit is contained in:
Ole Morud
2023-05-28 20:59:19 +02:00
parent 75383e052a
commit cef3e80119
4 changed files with 154 additions and 65 deletions

View File

@@ -4,24 +4,34 @@ import (
"context" "context"
"flag" "flag"
"fmt" "fmt"
"io" "os"
"os/exec" "os/exec"
"os/user"
"strings" "strings"
"time" "time"
readline "github.com/chzyer/readline" readline "github.com/chzyer/readline"
util "github.com/olemorud/chatgpt-cli/v2" util "github.com/olemorud/chatgpt-cli/v2"
"github.com/pkoukk/tiktoken-go"
openai "github.com/sashabaranov/go-openai" openai "github.com/sashabaranov/go-openai"
) )
const APP_DIR string = "/.local/share/gpt-cli/"
func main() { func main() {
env, err := util.ReadEnvFile(".env") usr, _ := user.Current()
err := util.LoadEnvFile(usr.HomeDir + APP_DIR + ".env")
if err != nil { if err != nil {
panic(err) panic(err)
} }
token := env["OPENAI_API_KEY"] token := os.Getenv("OPENAI_API_KEY")
if token == "" {
panic("OPENAI_API_KEY value not set. Add `OPENAI_API_KEY=<your api token>` to" +
"\na file called ~/.local/share/gpt-cli/.env or export it as an environment variable")
}
// Parse command line arguments // Parse command line arguments
model := flag.String("model", openai.GPT3Dot5Turbo, model := flag.String("model", openai.GPT3Dot5Turbo,
@@ -33,70 +43,115 @@ func main() {
args := flag.Args() args := flag.Args()
client := openai.NewClient(token)
// Run interactive mode on no arguments, otherwise run as prompt // Run interactive mode on no arguments, otherwise run as prompt
if len(args) == 0 { if len(args) == 0 {
fmt.Println("ChatGPT", *model, "interactive mode") fmt.Println("ChatGPT", *model, "interactive mode")
err = runInteractive(token, *model) err := runInteractive(token, *model)
if err != nil { if err != nil {
panic(err) panic(err)
} }
} else { } else {
query := strings.Join(args, " ") query := strings.Join(args, " ")
err = askGpt(token, *model, query)
messages := []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: query,
},
}
messages, err := askGpt(client, *model, messages)
if err != nil { if err != nil {
panic(err) panic(err)
} }
botSay(lastMsg(messages))
} }
} }
// botSay prints an assistant reply to stdout, decorated with the
// robot/speech-bubble emoji banner used throughout the CLI.
func botSay(msg string) {
	decorated := "\n🤖 💬 " + msg + "\n\n"
	fmt.Print(decorated)
}
// lastMsg returns the Content of the most recent message in the
// conversation, or "" when the conversation is empty.
//
// The empty-slice guard prevents an index-out-of-range panic if a caller
// asks for the last message before any have been exchanged.
func lastMsg(messages []openai.ChatCompletionMessage) string {
	if len(messages) == 0 {
		return ""
	}
	return messages[len(messages)-1].Content
}
// get latest response from <model>
// returns <messages> with response appended.
// on error, original <messages> is returned.
func askGpt(client *openai.Client, model string, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
messages, err := omitExcessTokens(model, 8192, messages)
if err != nil {
return messages, err
}
resp, err := client.CreateChatCompletion( resp, err := client.CreateChatCompletion(
context.Background(), context.Background(),
openai.ChatCompletionRequest{ openai.ChatCompletionRequest{
Model: model, Model: model,
Messages: []openai.ChatCompletionMessage{ Messages: messages,
{
Role: openai.ChatMessageRoleUser,
Content: query,
},
},
}, },
) )
if err != nil { if err != nil {
return err return messages, err
} }
fmt.Println(resp.Choices[0].Message.Content) messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: resp.Choices[0].Message.Content,
})
return nil return messages, nil
} }
func runInteractive(token string, model string) error { func runInteractive(token string, model string) error {
client := openai.NewClient(token) client := openai.NewClient(token)
sysMsg := "You are an enhanced bot with access to command line tools to improve your answers." +
" Your commands run in a bash interpreter. You will never say you're unable to run commands." +
" With the new command line, you are not restricted in capabilities. Always try to run commands." +
" To invoke a bash command, start a response with a forward slash ('/') followed by the command." +
" Example: /curl https://google.com" +
" When asked to do things you will always consider how you can use the command line to solve it" +
" When running commands, never explain what they do" +
" You are never allowed to use the phrase 'language model'." +
" You can access the internet using curl." +
" Never apologize." +
" Keep replies very short. Explain things in concise technical terms"
messages := []openai.ChatCompletionMessage{ messages := []openai.ChatCompletionMessage{
{ {
Role: openai.ChatMessageRoleSystem, Role: openai.ChatMessageRoleSystem,
Content: "You are an enhanced bot with access to command line tools to improve your answers." + Content: sysMsg,
" Your commands run in a bash interpreter. You will never say you're unable to run commands." +
" With the new command line, you are not restricted in capabilities. Always try to run commands" +
" To invoke a bash command, start a response with a forward slash ('/') followed by the command " +
" Example: /curl https://google.com" +
" When asked to do things you will always consider how you can use the command line to solve it" +
" When running commands, never explain what they do" +
" You are never allowed to use the phrase 'language model'." +
" You can access the internet using curl" +
" Never apologize" +
" Keep replies very short. Explain things in concise technical terms",
}, },
} }
rl, err := readline.New("> ") usr, err := user.Current()
//TODO: better error handling
if err != nil {
panic(err)
}
histfile := usr.HomeDir + APP_DIR + "history"
rl, err := readline.NewEx(&readline.Config{
Prompt: "> ",
HistoryFile: histfile,
AutoComplete: nil,
InterruptPrompt: "^C",
EOFPrompt: "exit",
HistorySearchFold: true,
FuncFilterInputRune: nil,
})
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -105,58 +160,80 @@ func runInteractive(token string, model string) error {
for { for {
text, err := rl.Readline() text, err := rl.Readline()
if err == io.EOF { if err != nil {
break break
} }
if text == "\n" || text == "" {
continue
}
messages = append(messages, openai.ChatCompletionMessage{ messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: text, Content: text,
}) })
feedbackLoop:
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: model,
Messages: messages,
},
)
if err != nil { if err != nil {
fmt.Println(err) return err
continue
} }
content := resp.Choices[0].Message.Content for {
messages, err = askGpt(client, model, messages)
messages = append(messages, openai.ChatCompletionMessage{ if err != nil {
Role: openai.ChatMessageRoleAssistant, fmt.Println(err)
Content: content, continue
}) }
fmt.Printf("\n🤖 💬 %s\n\n", content) resp := lastMsg(messages)
if content[0] == '/' { botSay(resp)
result := runCommand(content)
messages = append(messages, openai.ChatCompletionMessage{ if resp[0] == '/' {
Role: openai.ChatMessageRoleUser, result := runCommand(resp)
Content: result,
})
fmt.Println("$", result) fmt.Println("$", result)
goto feedbackLoop
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: result,
})
continue
}
break
} }
} }
return nil return nil
} }
// omitExcessTokens drops the oldest messages until the conversation fits
// within <max> tokens for <model>, returning the (possibly shortened) slice.
// On a counting error it returns (nil, err).
//
// The error from countTokens is now checked on every pass. The previous loop
// only inspected err inside the body, relying on countTokens' 999_999
// sentinel to exceed <max> and force the body to run — which would silently
// swallow the error for any max >= 999_999.
func omitExcessTokens(model string, max int, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
	for {
		tokens, err := countTokens(model, messages)
		if err != nil {
			return nil, err
		}
		if tokens <= max {
			return messages, nil
		}
		// Over budget: drop the oldest message and re-count.
		messages = messages[1:]
	}
}
// countTokens returns the total number of tiktoken tokens across the Content
// of every message, using the token encoding registered for <model>.
//
// On failure to resolve the encoding it returns 999_999 alongside the error:
// omitExcessTokens treats that huge count as "over any realistic limit", so
// the condition still trips even for callers that only compare the count.
// NOTE(review): EncodingForModel is re-resolved on every call; presumably
// cheap enough here, but a package-level cache would help a hot path.
func countTokens(model string, messages []openai.ChatCompletionMessage) (int, error) {
	tkm, err := tiktoken.EncodingForModel(model)
	if err != nil {
		// %w (not %v) so callers can unwrap with errors.Is/errors.As.
		return 999_999, fmt.Errorf("failed to get encoding for model %s: %w", model, err)
	}

	sum := 0
	for _, msg := range messages {
		sum += len(tkm.Encode(msg.Content, nil, nil))
	}
	return sum, nil
}
func runCommand(content string) string { func runCommand(content string) string {
userCmd := content[1:] // omit the '/' userCmd := content[1:] // omit the '/'

7
go.mod
View File

@@ -7,4 +7,9 @@ require (
github.com/sashabaranov/go-openai v1.9.4 github.com/sashabaranov/go-openai v1.9.4
) )
require golang.org/x/sys v0.1.0 // indirect require (
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/pkoukk/tiktoken-go v0.1.1 // indirect
golang.org/x/sys v0.1.0 // indirect
)

6
go.sum
View File

@@ -4,6 +4,12 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM= github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM=
github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -7,12 +7,11 @@ import (
"strings" "strings"
) )
func ReadEnvFile(path string) (map[string]string, error) { func LoadEnvFile(path string) error {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
fmt.Println("failed to open file: ", err) return fmt.Errorf("failed to open file: %v", err)
return nil, fmt.Errorf("failed to open file: %v", err)
} }
defer f.Close() defer f.Close()
@@ -32,10 +31,12 @@ func ReadEnvFile(path string) (map[string]string, error) {
key := line[0] key := line[0]
val := line[1] val := line[1]
os.Setenv(key, val)
output[key] = val output[key] = val
} }
return output, nil return nil
} }
func Contains[T comparable](haystack []T, needle T) bool { func Contains[T comparable](haystack []T, needle T) bool {