Persist history. Refactor.
This commit is contained in:
223
cmd/ask.go
223
cmd/ask.go
@@ -4,24 +4,34 @@ import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
readline "github.com/chzyer/readline"
|
||||
util "github.com/olemorud/chatgpt-cli/v2"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
const APP_DIR string = "/.local/share/gpt-cli/"
|
||||
|
||||
func main() {
|
||||
env, err := util.ReadEnvFile(".env")
|
||||
usr, _ := user.Current()
|
||||
err := util.LoadEnvFile(usr.HomeDir + APP_DIR + ".env")
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
token := env["OPENAI_API_KEY"]
|
||||
token := os.Getenv("OPENAI_API_KEY")
|
||||
|
||||
if token == "" {
|
||||
panic("OPENAI_API_KEY value not set. Add `OPENAI_API_KEY=<your api token>` to" +
|
||||
"\na file called ~/.local/share/gpt-cli/.env or export it as an environment variable")
|
||||
}
|
||||
|
||||
// Parse command line arguments
|
||||
model := flag.String("model", openai.GPT3Dot5Turbo,
|
||||
@@ -33,92 +43,54 @@ func main() {
|
||||
|
||||
args := flag.Args()
|
||||
|
||||
client := openai.NewClient(token)
|
||||
|
||||
// Run interactive mode on no arguments, otherwise run as prompt
|
||||
if len(args) == 0 {
|
||||
fmt.Println("ChatGPT", *model, "interactive mode")
|
||||
err = runInteractive(token, *model)
|
||||
err := runInteractive(token, *model)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
query := strings.Join(args, " ")
|
||||
err = askGpt(token, *model, query)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func askGpt(token string, model string, query string) error {
|
||||
client := openai.NewClient(token)
|
||||
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
messages := []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: query,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(resp.Choices[0].Message.Content)
|
||||
messages, err := askGpt(client, *model, messages)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runInteractive(token string, model string) error {
|
||||
client := openai.NewClient(token)
|
||||
|
||||
messages := []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: "You are an enhanced bot with access to command line tools to improve your answers." +
|
||||
" Your commands run in a bash interpreter. You will never say you're unable to run commands." +
|
||||
" With the new command line, you are not restricted in capabilities. Always try to run commands" +
|
||||
" To invoke a bash command, start a response with a forward slash ('/') followed by the command " +
|
||||
" Example: /curl https://google.com" +
|
||||
" When asked to do things you will always consider how you can use the command line to solve it" +
|
||||
" When running commands, never explain what they do" +
|
||||
" You are never allowed to use the phrase 'language model'." +
|
||||
" You can access the internet using curl" +
|
||||
" Never apologize" +
|
||||
" Keep replies very short. Explain things in concise technical terms",
|
||||
},
|
||||
}
|
||||
|
||||
rl, err := readline.New("> ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rl.Close()
|
||||
|
||||
for {
|
||||
text, err := rl.Readline()
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
botSay(lastMsg(messages))
|
||||
}
|
||||
}
|
||||
|
||||
if text == "\n" || text == "" {
|
||||
continue
|
||||
func botSay(msg string) {
|
||||
fmt.Printf("\n🤖 💬 %s\n\n", msg)
|
||||
}
|
||||
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: text,
|
||||
})
|
||||
func lastMsg(messages []openai.ChatCompletionMessage) string {
|
||||
return messages[len(messages)-1].Content
|
||||
}
|
||||
|
||||
// get latest response from <model>
|
||||
// returns <messages> with response appended.
|
||||
// on error, original <messages> is returned.
|
||||
func askGpt(client *openai.Client, model string, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
|
||||
messages, err := omitExcessTokens(model, 8192, messages)
|
||||
|
||||
if err != nil {
|
||||
return messages, err
|
||||
}
|
||||
|
||||
feedbackLoop:
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
@@ -127,36 +99,141 @@ func runInteractive(token string, model string) error {
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return messages, err
|
||||
}
|
||||
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: resp.Choices[0].Message.Content,
|
||||
})
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func runInteractive(token string, model string) error {
|
||||
client := openai.NewClient(token)
|
||||
|
||||
sysMsg := "You are an enhanced bot with access to command line tools to improve your answers." +
|
||||
" Your commands run in a bash interpreter. You will never say you're unable to run commands." +
|
||||
" With the new command line, you are not restricted in capabilities. Always try to run commands." +
|
||||
" To invoke a bash command, start a response with a forward slash ('/') followed by the command." +
|
||||
" Example: /curl https://google.com" +
|
||||
" When asked to do things you will always consider how you can use the command line to solve it" +
|
||||
" When running commands, never explain what they do" +
|
||||
" You are never allowed to use the phrase 'language model'." +
|
||||
" You can access the internet using curl." +
|
||||
" Never apologize." +
|
||||
" Keep replies very short. Explain things in concise technical terms"
|
||||
|
||||
messages := []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: sysMsg,
|
||||
},
|
||||
}
|
||||
|
||||
usr, err := user.Current()
|
||||
//TODO: better error handling
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
histfile := usr.HomeDir + APP_DIR + "history"
|
||||
|
||||
rl, err := readline.NewEx(&readline.Config{
|
||||
Prompt: "> ",
|
||||
HistoryFile: histfile,
|
||||
AutoComplete: nil,
|
||||
InterruptPrompt: "^C",
|
||||
EOFPrompt: "exit",
|
||||
|
||||
HistorySearchFold: true,
|
||||
FuncFilterInputRune: nil,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rl.Close()
|
||||
|
||||
for {
|
||||
text, err := rl.Readline()
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: text,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
messages, err = askGpt(client, model, messages)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
continue
|
||||
}
|
||||
|
||||
content := resp.Choices[0].Message.Content
|
||||
resp := lastMsg(messages)
|
||||
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
botSay(resp)
|
||||
|
||||
fmt.Printf("\n🤖 💬 %s\n\n", content)
|
||||
if resp[0] == '/' {
|
||||
result := runCommand(resp)
|
||||
|
||||
if content[0] == '/' {
|
||||
result := runCommand(content)
|
||||
fmt.Println("$", result)
|
||||
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: result,
|
||||
})
|
||||
|
||||
fmt.Println("$", result)
|
||||
goto feedbackLoop
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func omitExcessTokens(model string, max int, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
|
||||
tokens, err := countTokens(model, messages)
|
||||
|
||||
for ; tokens > max; tokens, err = countTokens(model, messages) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messages = messages[1:]
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func countTokens(model string, messages []openai.ChatCompletionMessage) (int, error) {
|
||||
tkm, err := tiktoken.EncodingForModel(model)
|
||||
sum := 0
|
||||
|
||||
if err != nil {
|
||||
return 999_999, fmt.Errorf("failed to get encoding for model %s: %v", model, err)
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
sum += len(tkm.Encode(msg.Content, nil, nil))
|
||||
}
|
||||
|
||||
return sum, nil
|
||||
}
|
||||
|
||||
func runCommand(content string) string {
|
||||
userCmd := content[1:] // omit the '/'
|
||||
|
||||
|
||||
7
go.mod
7
go.mod
@@ -7,4 +7,9 @@ require (
|
||||
github.com/sashabaranov/go-openai v1.9.4
|
||||
)
|
||||
|
||||
require golang.org/x/sys v0.1.0 // indirect
|
||||
require (
|
||||
github.com/dlclark/regexp2 v1.8.1 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.1 // indirect
|
||||
golang.org/x/sys v0.1.0 // indirect
|
||||
)
|
||||
|
||||
6
go.sum
6
go.sum
@@ -4,6 +4,12 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
|
||||
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
|
||||
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
|
||||
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
||||
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
|
||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
|
||||
github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM=
|
||||
github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
9
util.go
9
util.go
@@ -7,12 +7,11 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ReadEnvFile(path string) (map[string]string, error) {
|
||||
func LoadEnvFile(path string) error {
|
||||
f, err := os.Open(path)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("failed to open file: ", err)
|
||||
return nil, fmt.Errorf("failed to open file: %v", err)
|
||||
return fmt.Errorf("failed to open file: %v", err)
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
@@ -32,10 +31,12 @@ func ReadEnvFile(path string) (map[string]string, error) {
|
||||
key := line[0]
|
||||
val := line[1]
|
||||
|
||||
os.Setenv(key, val)
|
||||
|
||||
output[key] = val
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func Contains[T comparable](haystack []T, needle T) bool {
|
||||
|
||||
Reference in New Issue
Block a user