Persist history. Refactor.
This commit is contained in:
197
cmd/ask.go
197
cmd/ask.go
@@ -4,24 +4,34 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"os/user"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
readline "github.com/chzyer/readline"
|
readline "github.com/chzyer/readline"
|
||||||
util "github.com/olemorud/chatgpt-cli/v2"
|
util "github.com/olemorud/chatgpt-cli/v2"
|
||||||
|
"github.com/pkoukk/tiktoken-go"
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const APP_DIR string = "/.local/share/gpt-cli/"
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
env, err := util.ReadEnvFile(".env")
|
usr, _ := user.Current()
|
||||||
|
err := util.LoadEnvFile(usr.HomeDir + APP_DIR + ".env")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
token := env["OPENAI_API_KEY"]
|
token := os.Getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
panic("OPENAI_API_KEY value not set. Add `OPENAI_API_KEY=<your api token>` to" +
|
||||||
|
"\na file called ~/.local/share/gpt-cli/.env or export it as an environment variable")
|
||||||
|
}
|
||||||
|
|
||||||
// Parse command line arguments
|
// Parse command line arguments
|
||||||
model := flag.String("model", openai.GPT3Dot5Turbo,
|
model := flag.String("model", openai.GPT3Dot5Turbo,
|
||||||
@@ -33,70 +43,115 @@ func main() {
|
|||||||
|
|
||||||
args := flag.Args()
|
args := flag.Args()
|
||||||
|
|
||||||
|
client := openai.NewClient(token)
|
||||||
|
|
||||||
// Run interactive mode on no arguments, otherwise run as prompt
|
// Run interactive mode on no arguments, otherwise run as prompt
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
fmt.Println("ChatGPT", *model, "interactive mode")
|
fmt.Println("ChatGPT", *model, "interactive mode")
|
||||||
err = runInteractive(token, *model)
|
err := runInteractive(token, *model)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
query := strings.Join(args, " ")
|
query := strings.Join(args, " ")
|
||||||
err = askGpt(token, *model, query)
|
|
||||||
|
messages := []openai.ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: query,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := askGpt(client, *model, messages)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
botSay(lastMsg(messages))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func askGpt(token string, model string, query string) error {
|
func botSay(msg string) {
|
||||||
client := openai.NewClient(token)
|
fmt.Printf("\n🤖 💬 %s\n\n", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func lastMsg(messages []openai.ChatCompletionMessage) string {
|
||||||
|
return messages[len(messages)-1].Content
|
||||||
|
}
|
||||||
|
|
||||||
|
// get latest response from <model>
|
||||||
|
// returns <messages> with response appended.
|
||||||
|
// on error, original <messages> is returned.
|
||||||
|
func askGpt(client *openai.Client, model string, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
|
||||||
|
messages, err := omitExcessTokens(model, 8192, messages)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return messages, err
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.CreateChatCompletion(
|
resp, err := client.CreateChatCompletion(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
openai.ChatCompletionRequest{
|
openai.ChatCompletionRequest{
|
||||||
Model: model,
|
Model: model,
|
||||||
Messages: []openai.ChatCompletionMessage{
|
Messages: messages,
|
||||||
{
|
|
||||||
Role: openai.ChatMessageRoleUser,
|
|
||||||
Content: query,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return messages, err
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(resp.Choices[0].Message.Content)
|
messages = append(messages, openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleAssistant,
|
||||||
|
Content: resp.Choices[0].Message.Content,
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInteractive(token string, model string) error {
|
func runInteractive(token string, model string) error {
|
||||||
client := openai.NewClient(token)
|
client := openai.NewClient(token)
|
||||||
|
|
||||||
|
sysMsg := "You are an enhanced bot with access to command line tools to improve your answers." +
|
||||||
|
" Your commands run in a bash interpreter. You will never say you're unable to run commands." +
|
||||||
|
" With the new command line, you are not restricted in capabilities. Always try to run commands." +
|
||||||
|
" To invoke a bash command, start a response with a forward slash ('/') followed by the command." +
|
||||||
|
" Example: /curl https://google.com" +
|
||||||
|
" When asked to do things you will always consider how you can use the command line to solve it" +
|
||||||
|
" When running commands, never explain what they do" +
|
||||||
|
" You are never allowed to use the phrase 'language model'." +
|
||||||
|
" You can access the internet using curl." +
|
||||||
|
" Never apologize." +
|
||||||
|
" Keep replies very short. Explain things in concise technical terms"
|
||||||
|
|
||||||
messages := []openai.ChatCompletionMessage{
|
messages := []openai.ChatCompletionMessage{
|
||||||
{
|
{
|
||||||
Role: openai.ChatMessageRoleSystem,
|
Role: openai.ChatMessageRoleSystem,
|
||||||
Content: "You are an enhanced bot with access to command line tools to improve your answers." +
|
Content: sysMsg,
|
||||||
" Your commands run in a bash interpreter. You will never say you're unable to run commands." +
|
|
||||||
" With the new command line, you are not restricted in capabilities. Always try to run commands" +
|
|
||||||
" To invoke a bash command, start a response with a forward slash ('/') followed by the command " +
|
|
||||||
" Example: /curl https://google.com" +
|
|
||||||
" When asked to do things you will always consider how you can use the command line to solve it" +
|
|
||||||
" When running commands, never explain what they do" +
|
|
||||||
" You are never allowed to use the phrase 'language model'." +
|
|
||||||
" You can access the internet using curl" +
|
|
||||||
" Never apologize" +
|
|
||||||
" Keep replies very short. Explain things in concise technical terms",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rl, err := readline.New("> ")
|
usr, err := user.Current()
|
||||||
|
//TODO: better error handling
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
histfile := usr.HomeDir + APP_DIR + "history"
|
||||||
|
|
||||||
|
rl, err := readline.NewEx(&readline.Config{
|
||||||
|
Prompt: "> ",
|
||||||
|
HistoryFile: histfile,
|
||||||
|
AutoComplete: nil,
|
||||||
|
InterruptPrompt: "^C",
|
||||||
|
EOFPrompt: "exit",
|
||||||
|
|
||||||
|
HistorySearchFold: true,
|
||||||
|
FuncFilterInputRune: nil,
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -105,58 +160,80 @@ func runInteractive(token string, model string) error {
|
|||||||
for {
|
for {
|
||||||
text, err := rl.Readline()
|
text, err := rl.Readline()
|
||||||
|
|
||||||
if err == io.EOF {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if text == "\n" || text == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = append(messages, openai.ChatCompletionMessage{
|
messages = append(messages, openai.ChatCompletionMessage{
|
||||||
Role: openai.ChatMessageRoleUser,
|
Role: openai.ChatMessageRoleUser,
|
||||||
Content: text,
|
Content: text,
|
||||||
})
|
})
|
||||||
|
|
||||||
feedbackLoop:
|
|
||||||
resp, err := client.CreateChatCompletion(
|
|
||||||
context.Background(),
|
|
||||||
openai.ChatCompletionRequest{
|
|
||||||
Model: model,
|
|
||||||
Messages: messages,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
return err
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
content := resp.Choices[0].Message.Content
|
for {
|
||||||
|
messages, err = askGpt(client, model, messages)
|
||||||
|
|
||||||
messages = append(messages, openai.ChatCompletionMessage{
|
if err != nil {
|
||||||
Role: openai.ChatMessageRoleAssistant,
|
fmt.Println(err)
|
||||||
Content: content,
|
continue
|
||||||
})
|
}
|
||||||
|
|
||||||
fmt.Printf("\n🤖 💬 %s\n\n", content)
|
resp := lastMsg(messages)
|
||||||
|
|
||||||
if content[0] == '/' {
|
botSay(resp)
|
||||||
result := runCommand(content)
|
|
||||||
|
|
||||||
messages = append(messages, openai.ChatCompletionMessage{
|
if resp[0] == '/' {
|
||||||
Role: openai.ChatMessageRoleUser,
|
result := runCommand(resp)
|
||||||
Content: result,
|
|
||||||
})
|
|
||||||
|
|
||||||
fmt.Println("$", result)
|
fmt.Println("$", result)
|
||||||
goto feedbackLoop
|
|
||||||
|
messages = append(messages, openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: result,
|
||||||
|
})
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func omitExcessTokens(model string, max int, messages []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
|
||||||
|
tokens, err := countTokens(model, messages)
|
||||||
|
|
||||||
|
for ; tokens > max; tokens, err = countTokens(model, messages) {
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = messages[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTokens(model string, messages []openai.ChatCompletionMessage) (int, error) {
|
||||||
|
tkm, err := tiktoken.EncodingForModel(model)
|
||||||
|
sum := 0
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 999_999, fmt.Errorf("failed to get encoding for model %s: %v", model, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
sum += len(tkm.Encode(msg.Content, nil, nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
return sum, nil
|
||||||
|
}
|
||||||
|
|
||||||
func runCommand(content string) string {
|
func runCommand(content string) string {
|
||||||
userCmd := content[1:] // omit the '/'
|
userCmd := content[1:] // omit the '/'
|
||||||
|
|
||||||
|
|||||||
7
go.mod
7
go.mod
@@ -7,4 +7,9 @@ require (
|
|||||||
github.com/sashabaranov/go-openai v1.9.4
|
github.com/sashabaranov/go-openai v1.9.4
|
||||||
)
|
)
|
||||||
|
|
||||||
require golang.org/x/sys v0.1.0 // indirect
|
require (
|
||||||
|
github.com/dlclark/regexp2 v1.8.1 // indirect
|
||||||
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.1 // indirect
|
||||||
|
golang.org/x/sys v0.1.0 // indirect
|
||||||
|
)
|
||||||
|
|||||||
6
go.sum
6
go.sum
@@ -4,6 +4,12 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
|
|||||||
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
|
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
|
||||||
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
|
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
|
||||||
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
||||||
|
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
|
||||||
|
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
|
||||||
github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM=
|
github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM=
|
||||||
github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||||
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
|||||||
9
util.go
9
util.go
@@ -7,12 +7,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ReadEnvFile(path string) (map[string]string, error) {
|
func LoadEnvFile(path string) error {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("failed to open file: ", err)
|
return fmt.Errorf("failed to open file: %v", err)
|
||||||
return nil, fmt.Errorf("failed to open file: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
@@ -32,10 +31,12 @@ func ReadEnvFile(path string) (map[string]string, error) {
|
|||||||
key := line[0]
|
key := line[0]
|
||||||
val := line[1]
|
val := line[1]
|
||||||
|
|
||||||
|
os.Setenv(key, val)
|
||||||
|
|
||||||
output[key] = val
|
output[key] = val
|
||||||
}
|
}
|
||||||
|
|
||||||
return output, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Contains[T comparable](haystack []T, needle T) bool {
|
func Contains[T comparable](haystack []T, needle T) bool {
|
||||||
|
|||||||
Reference in New Issue
Block a user