命令行版的ChatGPT(修改版)

本帖最后由 CrLf 于 2023-7-29 00:17 编辑

命令行下调用OpenAI接口,从标准输入中读取用户输入并将其发送到GPT模型,再将响应写入标准输出。因原版默认是UTF8,所以我改成默认以GBK编码读取输入,并增加 --utf8 开关兼容utf8编码。

原版GitHub:https://github.com/pdfinn/sgpt

用法:
sgpt -k <API_KEY> -i <INSTRUCTION> [-t TEMPERATURE] [-m MODEL] [-s SEPARATOR] [-u] [-d]
参数说明:
短参数 长参数 环境变量 描述 默认值
-k --api_key SGPT_API_KEY 配置OpenAI的API KEY
-i --instruction SGPT_INSTRUCTION 系统指令,用于补充一些背景信息或要求
-t --temperature SGPT_TEMPERATURE 温度值,范围是0~2,数值越高,给出的答案越有想象力但也更倾向于编造 0.5
-m --model SGPT_MODEL 所采用的模型 gpt-3.5-turbo
-s --separator SGPT_SEPARATOR 输入内容的分隔符,每段内容单独发送一次请求 \n
-u --utf8 SGPT_UTF8 以UTF8编码解读输入内容(该参数由CrLf添加,使默认编码是GBK) false
-d --debug SGPT_DEBUG 启用调试模式,将输出很多调试信息 false

CrLf修改后的源码:
  package main
   
  import (
  "bufio"
  "encoding/json"
  "fmt"
  "github.com/spf13/pflag"
  "github.com/spf13/viper"
  "io"
  "io/ioutil"
  "log"
  "net/http"
  "os"
  "strconv"
  "strings"
   
  // mod by CrLf 添加必要的模块
  "bytes"
  "golang.org/x/text/encoding/simplifiedchinese"
  "golang.org/x/text/transform"
   
  )
   
  // mod by CrLf 用于将UTF8转码为GBK
  // UTF-8 转 GBK
  func Utf8ToGbk(s []byte) ([]byte, error) {
  reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewEncoder())
  d, e := ioutil.ReadAll(reader)
  if e != nil {
  return nil, e
  }
  return d, nil
  }
   
  func GbkToUtf8(s []byte) ([]byte, error) {
      reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewDecoder())
      d, e := ioutil.ReadAll(reader)
      if e != nil {
          return nil, e
      }
      return d, nil
  }
   
   
  type OpenAIResponse struct {
  Choices []struct {
  Text    string `json:"text,omitempty"`
  Message struct {
  Role    string `json:"role,omitempty"`
  Content string `json:"content,omitempty"`
  } `json:"message,omitempty"`
  } `json:"choices"`
  }
   
  // mod by CrLf 声明utf8变量
  var utf8 *bool
  var debug *bool
   
  func init() {
  // mod by CrLf 去除重复的提醒
   
  // envUTF8 := os.Getenv("SGPT_UTF8")
  // envDebug := os.Getenv("SGPT_DEBUG")
  // utf8 = pflag.Bool("u", parseBoolWithDefault(envUTF8, false), "Enable UTF8 input")
  // debug = pflag.Bool("d", parseBoolWithDefault(envDebug, false), "Enable debug output")
  }
   
  func main() {
  // Default values
  defaultTemperature := 0.5
  defaultModel := "gpt-3.5-turbo"
   
  // Check environment variables
  envApiKey := os.Getenv("SGPT_API_KEY")
  envInstruction := os.Getenv("SGPT_INSTRUCTION")
  envTemperature, err := strconv.ParseFloat(os.Getenv("SGPT_TEMPERATURE"), 64)
  if err != nil {
  envTemperature = defaultTemperature
  }
  envModel := os.Getenv("SGPT_MODEL")
  envSeparator := os.Getenv("SGPT_SEPARATOR")
   
  // mod by CrLf 增加对环境变量 SGPT_UTF8 的支持
  envUTF8 := parseBoolWithDefault(os.Getenv("SGPT_UTF8"), false)
  envDebug := parseBoolWithDefault(os.Getenv("SGPT_DEBUG"), false)
   
  // Command line arguments
  apiKey := pflag.StringP("api_key", "k", envApiKey, "OpenAI API key")
  instruction := pflag.StringP("instruction", "i", envInstruction, "Instruction for the GPT model")
  temperature := pflag.Float64P("temperature", "t", envTemperature, "Temperature for the GPT model")
  model := pflag.StringP("model", "m", envModel, "GPT model to use")
  defaulSeparator := "\n"
  separator := pflag.StringP("separator", "s", envSeparator, "Separator character for input")
  if *separator == "" {
  *separator = defaulSeparator
  }
   
  // mod by CrLf 增加对参数 --utf8 或 -u 的支持
  utf8 = pflag.BoolP("utf8", "u", envUTF8, "Enable UTF8 input")
  debug = pflag.BoolP("debug", "d", envDebug, "Enable debug output")
  pflag.Parse()
   
  // Read the configuration file
  viper.SetConfigName("sgpt")
  viper.AddConfigPath(".")
  viper.AddConfigPath("$HOME/.sgpt")
  viper.SetConfigType("yaml")
   
  err = viper.ReadInConfig()
   
  // mod by CrLf 默认屏蔽无用警告,仅在debug模式下展示
  if _, ok := err.(viper.ConfigFileNotFoundError); ok {
  debugOutput(*debug, "Warning: Config file not found: %v", err)
  } else if err != nil {
  debugOutput(*debug, "Warning: Error reading config file: %v", err)
  }
   
  // Set default values and bind configuration values to flags
  viper.SetDefault("model", defaultModel)
  viper.SetDefault("temperature", defaultTemperature)
  viper.BindPFlag("api_key", pflag.Lookup("k"))
  viper.BindPFlag("instruction", pflag.Lookup("i"))
  viper.BindPFlag("model", pflag.Lookup("m"))
  viper.BindPFlag("temperature", pflag.Lookup("t"))
  viper.BindPFlag("separator", pflag.Lookup("s"))
  viper.BindPFlag("debug", pflag.Lookup("d"))
   
  // Use default values if neither flags nor environment variables are set
  if *model == "" {
  *model = defaultModel
  }
   
  if *apiKey == "" {
  log.Fatal("API key is required")
  }
   
   
  // Read input from stdin continuously
  // mod by CrLf 根据utf8开关的启禁用状态判断以utf8还是gbk读取stdin
  var reader io.RuneReader
  if *utf8 {
  reader = bufio.NewReader(os.Stdin)
  } else {
  byteInput, _ := io.ReadAll(os.Stdin)
  gbkBytes, _ := GbkToUtf8(byteInput)
  reader = bytes.NewReader(gbkBytes)
  }
   
  var inputBuffer strings.Builder
   
  for {
  inputChar, _, err := reader.ReadRune()
  if err == io.EOF {
  input := inputBuffer.String()
  if input != "" {
  response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model)
  if err != nil {
  log.Fatal(err)
  }
  fmt.Println(response)
  }
  break
  }
  if err != nil {
  log.Fatal(err)
  }
   
  if string(inputChar) == *separator {
  input := inputBuffer.String()
  inputBuffer.Reset()
   
  response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model)
  if err != nil {
  log.Fatal(err)
  }
   
  fmt.Println(response)
  } else {
  inputBuffer.WriteRune(inputChar)
  }
  }
  }
   
  func debugOutput(debug bool, format string, a ...interface{}) {
  if debug {
  log.Printf(format, a...)
  }
  }
   
  func parseFloatWithDefault(value string, defaultValue float64) float64 {
  if value == "" {
  return defaultValue
  }
  parsedValue, err := strconv.ParseFloat(value, 64)
  if err != nil {
  log.Printf("Warning: Failed to parse float value: %v", err)
  return defaultValue
  }
  return parsedValue
  }
   
  func parseBoolWithDefault(value string, defaultValue bool) bool {
  if value == "" {
  return defaultValue
  }
  parsedValue, err := strconv.ParseBool(value)
  if err != nil {
  log.Printf("Warning: Failed to parse bool value: %v", err)
  return defaultValue
  }
  return parsedValue
  }
   
  func callOpenAI(apiKey, instruction, input string, temperature float64, model string) (string, error) {
  var url string
  var jsonData []byte
  var err error
   
  switch model {
  case "gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo":
  url = "https://api.openai.com/v1/chat/completions"
   
  // Prepare JSON data for GPT-4 models
  messages := []map[string]string{
  {"role": "system", "content": instruction},
  {"role": "user", "content": input},
  }
   
  jsonData, err = json.Marshal(map[string]interface{}{
  "model":       model,
  "messages":    messages,
  "temperature": temperature,
  "max_tokens":  100,
  "stop":        []string{"\n"},
  })
   
  case "text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001":
  url = "https://api.openai.com/v1/completions"
   
  // Prepare JSON data for GPT-3 models
  prompt := instruction + " " + input
  jsonData, err = json.Marshal(map[string]interface{}{
  "model":       model,
  "prompt":      prompt,
  "temperature": temperature,
  "max_tokens":  100,
  "stop":        []string{"\n"},
  })
   
  case "whisper-1":
  url = "https://api.openai.com/v1/audio/transcriptions"
  default:
  return "", fmt.Errorf("unsupported model: %s", model)
  }
   
  if err != nil {
  return "", err
  }
   
  data := strings.NewReader(string(jsonData))
   
  req, err := http.NewRequest("POST", url, data)
  if err != nil {
  return "", err
  }
   
  req.Header.Set("Content-Type", "application/json")
  req.Header.Set("Authorization", "Bearer "+apiKey)
   
  client := &http.Client{}
  resp, err := client.Do(req)
  if err != nil {
  return "", err
  }
  defer resp.Body.Close()
   
  body, err := ioutil.ReadAll(resp.Body)
  if err != nil {
  return "", err
  }
   
  debugOutput(*debug, "API response: %s\n", string(body))
   
  var openAIResponse OpenAIResponse
  err = json.Unmarshal(body, &openAIResponse)
  if err != nil {
  return "", err
  }
   
  if len(openAIResponse.Choices) == 0 {
  debugOutput(*debug, "API response: %s\n", string(body))
  debugOutput(*debug, "HTTP status code: %s\n", strconv.Itoa(resp.StatusCode))
  return "", fmt.Errorf("no choices returned from the API")
  }
   
  assistantMessage := ""
  for _, choice := range openAIResponse.Choices {
  if choice.Message.Role == "assistant" {
  assistantMessage = strings.TrimSpace(choice.Message.Content)
  break
  }
  if choice.Text != "" {
  assistantMessage = strings.TrimSpace(choice.Text)
  break
  }
  }
   
  if assistantMessage == "" {
  return "", fmt.Errorf("no assistant message found in the API response")
  }
   
  return assistantMessage, nil
  }COPY

编译后的下载地址:http://bcn.bathome.net/s/tool/index.html?key=sgpt

---------------------------------------------------------------------------------------

本帖最后由 CrLf 于 2023-7-29 00:13 编辑

举个例子:

  echo 柬埔寨在哪里|sgpt.exe --api_key "***这里是你的openai_api_key***" --instruction "请用中文回答:" --model "gpt-3.5-turbo"
  :: 回答为:柬埔寨位于东南亚,东临越南,南接泰国,西邻泰国和洞朗,北界老挝。

如果要传入非GBK字符,请 chcp 65001 后使用 --utf8 开关

 

 

出处:http://bbs.bathome.net/thread-66919-1-2.html

posted on 2024-07-08 17:01  jack_Meng  阅读(25)  评论(0)  编辑  收藏  举报

导航