commit d3452021749df13b97bc2cb2155bce1690c2c323 Author: nhyatt Date: Wed Nov 16 17:50:30 2022 -0600 initial commit diff --git a/cmd/getvaultpw/main.go b/cmd/getvaultpw/main.go new file mode 100644 index 0000000..d1e42ad --- /dev/null +++ b/cmd/getvaultpw/main.go @@ -0,0 +1,55 @@ +package main + +import ( + "fmt" + "log" + "time" + + "getvaultpw/internal/config" + "getvaultpw/internal/protectString" + "getvaultpw/internal/vault" +) + +func main() { + // initialize application configuration + cfg := config.Init() + + defer func(cfg *config.Config) { + if err := cfg.WriteConfig(cfg.ConfigFile); err != nil { + log.Fatalf("[WARNING] Unable to update configuration file: %v", err) + } + log.Println("[DEBUG] shutdown sequence complete") + }(cfg) + + // if we are passed a vault instance and a password we need to update the password for a environment + if len(cfg.VaultInstance) != 0 && len(cfg.VaultPass) != 0 { + cfg.ConfigFileData.VaultEnvironment[cfg.VaultInstance].EPass = protectString.Encrypt(cfg.VaultPass, cfg.ConfigFileData.Salt) + } + + // get the password for a secret + if len(cfg.SecretID) != 0 { + for k, v := range cfg.ConfigFileData.Credential { + if k == cfg.SecretID { + if (v.TimeStamp+(60*60*2)) < time.Now().UTC().Unix() || len(cfg.ConfigFileData.Credential[k].Cache) == 0 { + o, err := vault.GetCredential(cfg.ConfigFileData.VaultEnvironment[v.VaultEnv].Host, + cfg.ConfigFileData.VaultEnvironment[v.VaultEnv].User, + protectString.Decrypt(cfg.ConfigFileData.VaultEnvironment[v.VaultEnv].EPass, cfg.ConfigFileData.Salt), + v.ID, + v.Path) + + if err != nil { + log.Fatalf("[ERROR] %v", err) + } + cfg.ConfigFileData.Credential[k].Cache = protectString.Encrypt(o, cfg.ConfigFileData.Salt) + cfg.ConfigFileData.Credential[k].TimeStamp = time.Now().UTC().Unix() + fmt.Println(o) + } else { + o := protectString.Decrypt(cfg.ConfigFileData.Credential[k].Cache, cfg.ConfigFileData.Salt) + fmt.Println(o) + } + } else { + log.Fatalf("[ERROR] Unable to find secret (%s) in config file: %s", cfg.SecretID, cfg.ConfigFile) + } + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..040af0c --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module getvaultpw + +go 1.19 + +require ( + github.com/hashicorp/logutils v1.0.0 + gopkg.in/yaml.v2 v2.4.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..1a67e6f --- /dev/null +++ b/go.sum @@ -0,0 +1,5 @@ +github.com/hashicorp/logutils v1.0.0 h1:dLEQVugN8vlakKOUE3ihGLTZJRB4j+M2cdTm/ORI65Y= +github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..29e82f6 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,109 @@ +package config + +import ( + "log" + "os" + "reflect" + "strconv" + "time" + + "github.com/hashicorp/logutils" +) + +type Config struct { + // time configuration + TimeFormat string `env:"TIME_FORMAT" default:"2006-01-02 15:04:05"` + TimeZoneLocal string `env:"TIME_ZONE" default:"America/Chicago"` + TZoneLocal *time.Location `ignored:"true"` + TZoneUTC *time.Location `ignored:"true"` + + // logging + LogLevel int `env:"LOG_LEVEL" default:"0"` + Log *logutils.LevelFilter `ignored:"true"` + + // configuration file + ConfigFile string `ignored:"true" env:"CONFIG_FILE" required:"false"` + ConfigFileExample string `ignored:"true"` + ConfigFileData ConfigFileStruct `ignored:"true"` + + // misc + VaultUser string `env:"VAULT_USER" required:"false" default:""` + VaultPass string `env:"VAULT_PASS" required:"false" default:""` + SecretID string `env:"VAULT_SECRET_ID" required:"false" default:""` + VaultInstance string `env:"VAULT_INSTANCE" required:"false" default:"dev"` +} + +// DefaultConfig initializes the config variable for use with a prepared set of defaults. +func DefaultConfig() *Config { + home, err := os.UserHomeDir() + if err != nil { + log.Fatalf("[FATAL] Unable to determine user home directory for config file: %v", err) + } + + return &Config{ + Log: &logutils.LevelFilter{ + Levels: []logutils.LogLevel{"TRACE", "DEBUG", "INFO", "WARNING", "ERROR"}, + Writer: os.Stderr, + }, + ConfigFile: home + "/.getvaultpw.yml", + ConfigFileExample: `--- +salt: exampleSalt + +vaultEnvironment: + dev: + host: vault.dev.example.com + user: userName + test: + host: vault.test.example.com + user: userName + stage: + host: vault.stage.example.com + user: userName + prod: + host: vault.prod.example.com + user: userName + +credentials: + serviceAccountNumberOne: + vaultPath: /secrets/serviceAccountNumberOne/credentials + vaultID: password + vaultInstance: dev + serviceAccountNumberTwo: + vaultPath: /secrets/serviceAccountNumberTwo/credentials + vaultID: password +`, + } +} + +func (cfg *Config) setLogLevel() { + switch { + case cfg.LogLevel <= 20: + cfg.Log.SetMinLevel(logutils.LogLevel("ERROR")) + case cfg.LogLevel > 20 && cfg.LogLevel <= 40: + cfg.Log.SetMinLevel(logutils.LogLevel("WARNING")) + case cfg.LogLevel > 40 && cfg.LogLevel <= 60: + cfg.Log.SetMinLevel(logutils.LogLevel("INFO")) + case cfg.LogLevel > 60 && cfg.LogLevel <= 80: + cfg.Log.SetMinLevel(logutils.LogLevel("DEBUG")) + case cfg.LogLevel > 80: + cfg.Log.SetMinLevel(logutils.LogLevel("TRACE")) + } + log.SetOutput(cfg.Log) +} + +func (cfg *Config) printRunningConfig(cfgInfo []StructInfo) { + log.Printf("[DEBUG] Current Running Configuration Values:") + for _, info := range cfgInfo { + switch info.Type.String() { + case "string": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*string) + log.Printf("[DEBUG]\t%s\t\t= %s\n", info.Alt, *p) + case "bool": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*bool) + log.Printf("[DEBUG]\t%s\t\t= %s\n", info.Alt, strconv.FormatBool(*p)) + case "int": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int) + log.Printf("[DEBUG]\t%s\t\t= %s\n", info.Alt, strconv.FormatInt(int64(*p), 10)) + } + } +} diff --git a/internal/config/configFile.go b/internal/config/configFile.go new file mode 100644 index 0000000..d21ec57 --- /dev/null +++ b/internal/config/configFile.go @@ -0,0 +1,43 @@ +package config + +import ( + "log" + "os" + + "gopkg.in/yaml.v2" +) + +type ConfigFileStruct struct { + Salt string `yaml:"salt"` + VaultEnvironment map[string]*VaultEnvStruct `yaml:"vaultEnvironment"` + Credential map[string]*CredentialStruct `yaml:"credentials"` +} + +type VaultEnvStruct struct { + Host string `yaml:"host"` + User string `yaml:"user"` + EPass string `yaml:"encryptedPassword"` +} + +type CredentialStruct struct { + Path string `yaml:"vaultPath"` + ID string `yaml:"vaultID"` + TimeStamp int64 `yaml:"timestamp"` + Cache string `yaml:"cachedValue"` + VaultEnv string `yaml:"vaultInstance"` +} + +func (cfg *Config) WriteConfig(dest string) error { + file, err := os.OpenFile(dest, os.O_WRONLY, 0600) + if err != nil { + log.Fatalf("[ERROR] Unable to open the config file %s: %v", dest, err) + } + defer file.Close() + + data, _ := yaml.Marshal(cfg.ConfigFileData) + if _, err = file.Write(data); err != nil { + log.Fatalf("[ERROR] Unable to update config file %s: %v", cfg.ConfigFile, err) + } + + return nil +} diff --git a/internal/config/envconfig.go b/internal/config/envconfig.go new file mode 100644 index 0000000..4bf5bb4 --- /dev/null +++ b/internal/config/envconfig.go @@ -0,0 +1,111 @@ +package config + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +type StructInfo struct { + Name string + Alt string + Key string + Field reflect.Value + Tags reflect.StructTag + Type reflect.Type + DefaultValue interface{} +} + +func getStructInfo(spec interface{}) ([]StructInfo, error) { + s := reflect.ValueOf(spec) + + if s.Kind() != reflect.Pointer { + return []StructInfo{}, fmt.Errorf("getStructInfo() was sent a %s instead of a pointer to a struct.\n", s.Kind()) + } + + s = s.Elem() + if s.Kind() != reflect.Struct { + return []StructInfo{}, fmt.Errorf("getStructInfo() was sent a %s instead of a struct.\n", s.Kind()) + } + typeOfSpec := s.Type() + + infos := make([]StructInfo, 0, s.NumField()) + for i := 0; i < s.NumField(); i++ { + f := s.Field(i) + ftype := typeOfSpec.Field(i) + + ignored, _ := strconv.ParseBool(ftype.Tag.Get("ignored")) + if !f.CanSet() || ignored { + continue + } + + for f.Kind() == reflect.Pointer { + if f.IsNil() { + if f.Type().Elem().Kind() != reflect.Struct { + break + } + f.Set(reflect.New(f.Type().Elem())) + } + f = f.Elem() + } + + info := StructInfo{ + Name: ftype.Name, + Alt: strings.ToUpper(ftype.Tag.Get("env")), + Key: ftype.Name, + Field: f, + Tags: ftype.Tag, + Type: ftype.Type, + } + if info.Alt != "" { + info.Key = info.Alt + } + info.Key = strings.ToUpper(info.Key) + if ftype.Tag.Get("default") != "" { + v, err := typeConversion(ftype.Type.String(), ftype.Tag.Get("default")) + if err != nil { + return []StructInfo{}, err + } + info.DefaultValue = v + } + infos = append(infos, info) + } + return infos, nil +} + +func typeConversion(t, v string) (interface{}, error) { + switch t { + case "string": + return v, nil + case "int": + return strconv.ParseInt(v, 10, 0) + case "int8": + return strconv.ParseInt(v, 10, 8) + case "int16": + return strconv.ParseInt(v, 10, 16) + case "int32": + return strconv.ParseInt(v, 10, 32) + case "int64": + return strconv.ParseInt(v, 10, 64) + case "uint": + return strconv.ParseUint(v, 10, 0) + case "uint16": + return strconv.ParseUint(v, 10, 16) + case "uint32": + return strconv.ParseUint(v, 10, 32) + case "uint64": + return strconv.ParseUint(v, 10, 64) + case "float32": + return strconv.ParseFloat(v, 32) + case "float64": + return strconv.ParseFloat(v, 64) + case "complex64": + return strconv.ParseComplex(v, 64) + case "complex128": + return strconv.ParseComplex(v, 128) + case "bool": + return strconv.ParseBool(v) + } + return nil, fmt.Errorf("Unable to identify type.") +} diff --git a/internal/config/initialize.go b/internal/config/initialize.go new file mode 100644 index 0000000..c72ddc2 --- /dev/null +++ b/internal/config/initialize.go @@ -0,0 +1,161 @@ +package config + +import ( + "flag" + "log" + "os" + "reflect" + "strconv" + "time" + + "crypto/rand" + "encoding/hex" + "io/ioutil" + + "gopkg.in/yaml.v2" +) + +// getEnvString returns string from environment variable +func getEnvString(env, def string) (val string) { //nolint:deadcode + val = os.Getenv(env) + + if val == "" { + return def + } + + return +} + +// getEnvInt returns int from environment variable +func getEnvInt(env string, def int) (ret int) { + val := os.Getenv(env) + + if val == "" { + return def + } + + ret, err := strconv.Atoi(val) + if err != nil { + log.Fatalf("[ERROR] Environment variable is not numeric: %v\n", env) + } + + return +} + +// getEnvBool returns boolean from environment variable +func getEnvBool(env string, def bool) bool { + var ( + err error + retVal bool + val = os.Getenv(env) + ) + + if len(val) == 0 { + return def + } else { + retVal, err = strconv.ParseBool(val) + if err != nil { + log.Fatalf("[ERROR] Environment variable is not boolean: %v\n", env) + } + } + + return retVal +} + +// Init initializes the application configuration by reading default values from the struct's tags +// and environment variables. Tags processed by this process are as follows: +// `ignored:"true" env:"ENVIRONMENT_VARIABLE" default:"default value"` +func Init() *Config { + var cryptovault string + cfg := DefaultConfig() + + cfgInfo, err := getStructInfo(cfg) + if err != nil { + log.Fatalf("[FATAL] %v", err) + } + + for _, info := range cfgInfo { + switch info.Type.String() { + case "string": + var dv string + + if info.DefaultValue != nil { + dv = info.DefaultValue.(string) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*string) + flag.StringVar(p, info.Name, getEnvString(info.Name, dv), "("+info.Key+")") + case "bool": + var dv bool + + if info.DefaultValue != nil { + dv = info.DefaultValue.(bool) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*bool) + flag.BoolVar(p, info.Name, getEnvBool(info.Name, dv), "("+info.Key+")") + case "int": + var dv int + + if info.DefaultValue != nil { + dv = int(info.DefaultValue.(int64)) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int) + flag.IntVar(p, info.Name, getEnvInt(info.Name, dv), "("+info.Key+")") + } + } + + flag.StringVar(&cryptovault, "label", "", "Deprecated support feature. Please use -SecretID.\nCryptovault compatibility feature used to identify the ID of the secret") + flag.Parse() + + if len(cryptovault) != 0 { + cfg.SecretID = cryptovault + } + + // set logging level + cfg.setLogLevel() + + // check if configuration file is present, if not create the example config file + if _, err := os.Stat(cfg.ConfigFile); os.IsNotExist(err) { + log.Printf("[WARNING] No configuration file present. Creating example configuration file at %s", cfg.ConfigFile) + file, err := os.Create(cfg.ConfigFile) + if err != nil { + log.Fatalf("[ERROR] Unable to create the config file %s: %v", cfg.ConfigFile, err) + } + defer file.Close() + + file.Write([]byte(cfg.ConfigFileExample)) + } + + // readConfiguration File + fileData, err := ioutil.ReadFile(cfg.ConfigFile) + if err != nil { + log.Fatalf("[ERROR] Unable to read the config file %s: %v", cfg.ConfigFile, err) + } + err = yaml.Unmarshal(fileData, &cfg.ConfigFileData) + if err != nil { + log.Printf("[ERROR] Unable to process the config file %s: %v", cfg.ConfigFile, err) + } + + // make sure there is a secure salt + if cfg.ConfigFileData.Salt == "" || cfg.ConfigFileData.Salt == "exampleSalt" { + s := make([]byte, 32) + rand.Read(s) + cfg.ConfigFileData.Salt = hex.EncodeToString(s) + } + + // timezone & format configuration + cfg.TZoneUTC, _ = time.LoadLocation("UTC") + if err != nil { + log.Fatalf("[ERROR] Unable to parse timezone string. Please use one of the timezone database values listed here: %s", "https://en.wikipedia.org/wiki/List_of_tz_database_time_zones") + } + cfg.TZoneLocal, err = time.LoadLocation(cfg.TimeZoneLocal) + if err != nil { + log.Fatalf("[ERROR] Unable to parse timezone string. Please use one of the timezone database values listed here: %s", "https://en.wikipedia.org/wiki/List_of_tz_database_time_zones") + } + time.Now().Format(cfg.TimeFormat) + + // print running config + cfg.printRunningConfig(cfgInfo) + + log.Println("[INFO] initialization complete") + return cfg +} diff --git a/internal/httpclient/httpclient.go b/internal/httpclient/httpclient.go new file mode 100644 index 0000000..4585831 --- /dev/null +++ b/internal/httpclient/httpclient.go @@ -0,0 +1,137 @@ +package httpclient + +import ( + "bytes" + "errors" + "net" + "strconv" + "time" + + "io/ioutil" + "net/http" +) + +// HTTPClient is an interface for initializing the http client library. +type HTTPClient struct { + Client *http.Client + Data *bytes.Buffer + Headers map[string]string + + Username string + Password string +} + +// DefaultClient is a function for defining a basic HTTP client with standard timeouts. +func DefaultClient() *HTTPClient { + return &HTTPClient{ + Client: &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + Dial: (&net.Dialer{ + Timeout: 5 * time.Second, + }).Dial, + TLSHandshakeTimeout: 5 * time.Second, + IdleConnTimeout: 300 * time.Second, + }, + }, + } +} + +// NewClient Create an HTTPClient with a user-provided net/http.Client +func NewClient(httpClient *http.Client) *HTTPClient { + return &HTTPClient{Client: httpClient} +} + +// SetBasicAuth is a chaining function to set the username and password for basic +// authentication +func (c *HTTPClient) SetBasicAuth(username, password string) *HTTPClient { + c.Username = username + c.Password = password + + return c +} + +// SetPostData is a chaining function to set POST/PUT/PATCH data +func (c *HTTPClient) SetPostData(data string) *HTTPClient { + c.Data = bytes.NewBufferString(data) + + return c +} + +// SetHeader is a chaining function to set arbitrary HTTP Headers +func (c *HTTPClient) SetHeader(label string, value string) *HTTPClient { + if c.Headers == nil { + c.Headers = map[string]string{} + } + + c.Headers[label] = value + + return c +} + +// Get calls the net.http GET operation +func (c *HTTPClient) Get(url string) ([]byte, error) { + return c.do(url, http.MethodGet) +} + +// Patch calls the net.http PATCH operation +func (c *HTTPClient) Patch(url string) ([]byte, error) { + return c.do(url, http.MethodPatch) +} + +// Post calls the net.http POST operation +func (c *HTTPClient) Post(url string) ([]byte, error) { + return c.do(url, http.MethodPost) +} + +// Put calls the net.http PUT operation +func (c *HTTPClient) Put(url string) ([]byte, error) { + return c.do(url, http.MethodPut) +} + +func (c *HTTPClient) do(url string, method string) ([]byte, error) { + var ( + req *http.Request + res *http.Response + output []byte + err error + ) + + // NewRequest knows that c.data is typed *bytes.Buffer and will SEGFAULT + // if c.data is nil. So we create a request using nil when c.data is nil + if c.Data != nil { + req, err = http.NewRequest(method, url, c.Data) + } else { + req, err = http.NewRequest(method, url, nil) + } + if err != nil { + return nil, err + } + + if (len(c.Username) > 0) && (len(c.Password) > 0) { + req.SetBasicAuth(c.Username, c.Password) + } + + if c.Headers != nil { + for label, value := range c.Headers { + req.Header.Set(label, value) + } + } + + if res, err = c.Client.Do(req); err != nil { + return nil, err + } + + defer res.Body.Close() + + if output, err = ioutil.ReadAll(res.Body); err != nil { + return nil, err + } + + // check status + if res.StatusCode < 200 || res.StatusCode >= 300 { + return nil, errors.New("non-successful status code received [" + strconv.Itoa(res.StatusCode) + "]") + } + + return output, nil +} diff --git a/internal/httpclient/httpclient_test.go b/internal/httpclient/httpclient_test.go new file mode 100644 index 0000000..f300f64 --- /dev/null +++ b/internal/httpclient/httpclient_test.go @@ -0,0 +1,281 @@ +package httpclient + +import ( + "fmt" + "testing" + + "encoding/json" + "io/ioutil" + "net/http" + "net/http/httptest" +) + +type Data struct { + Greeting string `json:"greeting"` + Headers map[string]string `json:"headers"` + Method string `json:"method"` + Username string `json:"username"` + Password string `json:"password"` + PostData string `json:"postdata"` +} + +var ( + greeting = "Hello world" + postData = "Test data" + authUser = "testuser" + authPass = "testpass" + headerLabel = "Test-Header" + headerValue = "Test-Value" +) + +func httpTestHandler(w http.ResponseWriter, r *http.Request) { + var ( + b []byte + user string + pass string + body []byte + ) + + data := Data{ + Greeting: greeting, + Headers: map[string]string{}, + Method: r.Method, + } + + user, pass, ok := r.BasicAuth() + if ok { + data.Username = user + data.Password = pass + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + fmt.Fprint(w, "ioutil.ReadAll failed") + } + data.PostData = string(body) + + for h := range r.Header { + data.Headers[h] = r.Header.Get(h) + } + + b, err = json.MarshalIndent(data, "", " ") + if err != nil { + fmt.Fprint(w, "Json marshal failed somehow") + } + fmt.Fprint(w, string(b)) +} + +func checkMethod(t *testing.T, data Data, method string) { + if data.Method != method { + t.Errorf("data.Method(%s) != method(%s)", data.Method, method) + } + t.Log("checkMethod() success") +} + +func checkGreeting(t *testing.T, data Data) { + if data.Greeting != greeting { + t.Errorf("data.Greeting(%s) != greeting(%s)", data.Greeting, greeting) + } + t.Log("checkGreeting() success") +} + +func checkBasicAuth(t *testing.T, data Data) { + if data.Username != authUser { + t.Errorf("data.Username(%s) != authUser(%s)", data.Username, authUser) + } + if data.Password != authPass { + t.Errorf("data.Password(%s) != authPass(%s)", data.Password, authPass) + } + t.Log("checkBasicAuth() success") +} + +func checkPostData(t *testing.T, data Data) { + if data.PostData != postData { + t.Errorf("data.PostData(%s) != postData(%s)", data.PostData, postData) + } + t.Log("checkPostData() success") +} + +func TestGet(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().Get(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodGet) + checkGreeting(t, data) +} + +func TestGetAuth(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetBasicAuth(authUser, authPass).Get(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodGet) + checkGreeting(t, data) + checkBasicAuth(t, data) +} + +func TestPut(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetPostData(postData).Put(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPut) + checkGreeting(t, data) + checkPostData(t, data) +} + +func TestPutAuth(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetBasicAuth(authUser, authPass).SetPostData(postData).Put(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPut) + checkGreeting(t, data) + checkBasicAuth(t, data) + checkPostData(t, data) +} + +func TestPost(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetPostData(postData).Post(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPost) + checkGreeting(t, data) + checkPostData(t, data) +} + +func TestPostAuth(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetBasicAuth(authUser, authPass).SetPostData(postData).Post(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPost) + checkGreeting(t, data) + checkBasicAuth(t, data) + checkPostData(t, data) +} + +func TestPatch(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetPostData(postData).Patch(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPatch) + checkGreeting(t, data) + checkPostData(t, data) +} + +func TestPatchAuth(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetBasicAuth(authUser, authPass).SetPostData(postData).Patch(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPatch) + checkGreeting(t, data) + checkBasicAuth(t, data) + checkPostData(t, data) +} + +func TestSetHeader(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetHeader(headerLabel, headerValue).Get(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodGet) + checkGreeting(t, data) + if data.Headers[headerLabel] != headerValue { + t.Errorf("SetHeader values not set in header: %+v", data.Headers) + } +} \ No newline at end of file diff --git a/internal/protectString/decrypt.go b/internal/protectString/decrypt.go new file mode 100644 index 0000000..6532d51 --- /dev/null +++ b/internal/protectString/decrypt.go @@ -0,0 +1,39 @@ +package protectString + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/hex" + "fmt" +) + +func Decrypt(encryptedString string, keyString string) (decryptedString string) { + key, _ := hex.DecodeString(keyString) + enc, _ := hex.DecodeString(encryptedString) + + //Create a new Cipher Block from the key + block, err := aes.NewCipher(key) + if err != nil { + panic(err.Error()) + } + + //Create a new GCM + aesGCM, err := cipher.NewGCM(block) + if err != nil { + panic(err.Error()) + } + + //Get the nonce size + nonceSize := aesGCM.NonceSize() + + //Extract the nonce from the encrypted data + nonce, ciphertext := enc[:nonceSize], enc[nonceSize:] + + //Decrypt the data + plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil) + if err != nil { + panic(err.Error()) + } + + return fmt.Sprintf("%s", plaintext) +} diff --git a/internal/protectString/encrypt.go b/internal/protectString/encrypt.go new file mode 100644 index 0000000..063e1b1 --- /dev/null +++ b/internal/protectString/encrypt.go @@ -0,0 +1,37 @@ +package protectString + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/hex" + "fmt" + "io" +) + +func Encrypt(stringToEncrypt string, keyString string) (encryptedString string) { + key, _ := hex.DecodeString(keyString) + plaintext := []byte(stringToEncrypt) + + //Create a new Cipher Block from the key + block, err := aes.NewCipher(key) + if err != nil { + panic(err.Error()) + } + + //Create a new GCM + aesGCM, err := cipher.NewGCM(block) + if err != nil { + panic(err.Error()) + } + + //Create a nonce. Nonce should be from GCM + nonce := make([]byte, aesGCM.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + panic(err.Error()) + } + + //Encrypt the data using aesGCM.Seal + ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil) + return fmt.Sprintf("%x", ciphertext) +} diff --git a/internal/vault/vault.go b/internal/vault/vault.go new file mode 100644 index 0000000..ed7fc97 --- /dev/null +++ b/internal/vault/vault.go @@ -0,0 +1,116 @@ +package vault + +import ( + "fmt" + "log" + "regexp" + "time" + + "encoding/json" + + "getvaultpw/internal/httpclient" +) + +type authRespStruct struct { + Auth struct { + Accessor string `json:"accessor"` + ClientToken string `json:"client_token"` + EntityID string `json:"entity_id"` + LeaseDuration int64 `json:"lease_duration"` + Metadata struct { + Username string `json:"username"` + } `json:"metadata"` + MfaRequirement interface{} `json:"mfa_requirement"` + NumUses int64 `json:"num_uses"` + Orphan bool `json:"orphan"` + Policies []string `json:"policies"` + Renewable bool `json:"renewable"` + TokenPolicies []string `json:"token_policies"` + TokenType string `json:"token_type"` + } `json:"auth"` + Data struct{} `json:"data"` + LeaseDuration int64 `json:"lease_duration"` + LeaseID string `json:"lease_id"` + Renewable bool `json:"renewable"` + RequestID string `json:"request_id"` + Warnings interface{} `json:"warnings"` + WrapInfo interface{} `json:"wrap_info"` +} + +type secretV2Struct struct { + RequestID string `json:"request_id"` + LeaseID string `json:"lease_id"` + Renewable bool `json:"renewable"` + LeaseDuration int `json:"lease_duration"` + Data struct { + Data map[string]string `json:"data"` + Metadata struct { + CreatedTime time.Time `json:"created_time"` + CustomMetadata interface{} `json:"custom_metadata"` + DeletionTime string `json:"deletion_time"` + Destroyed bool `json:"destroyed"` + Version int `json:"version"` + } `json:"metadata"` + } `json:"data"` + WrapInfo interface{} `json:"wrap_info"` + Warnings interface{} `json:"warnings"` + Auth interface{} `json:"auth"` +} + +func login(host, user, pass string) (string, error) { + c := httpclient.DefaultClient() + c.SetHeader("Accept", "application/json") + c.SetHeader("Content-Type", "application/json") + c.SetPostData(fmt.Sprintf("{ \"password\":\"%s\"}", pass)) + + log.Printf("[TRACE] LOGIN URL : %s", fmt.Sprintf("%s/v1/auth/ldap/login/%s", host, user)) + log.Printf("[TRACE] LOGIN USER : %s", user) + log.Printf("[TRACE] LOGIN PASS : %s", pass) + o, err := c.Post(fmt.Sprintf("%s/v1/auth/ldap/login/%s", host, user)) + if err != nil { + return "", err + } + + var output authRespStruct + if err := json.Unmarshal(o, &output); err != nil { + return "", err + } + return output.Auth.ClientToken, nil +} + +func GetCredential(host, user, pass, store, path string) (string, error) { + token, err := login(host, user, pass) + if err != nil { + return "", err + } + + c := httpclient.DefaultClient() + c.SetHeader("Accept", "application/json") + c.SetHeader("Content-Type", "application/json") + c.SetHeader("X-Vault-Token", token) + + log.Printf("[TRACE] SECRET URL : %s", fmt.Sprintf("%s/v1/%s/data/%s", host, store, path)) + log.Printf("[TRACE] SECRET TOKEN: %s", token) + o, err := c.Get(fmt.Sprintf("%s/v1/%s/data/%s", host, store, path)) + if err != nil { + return "", err + } + + var output secretV2Struct + if err := json.Unmarshal(o, &output); err != nil { + return "", err + } + + for k, v := range output.Data.Data { + r, err := regexp.Compile(`(p|P)(a|A)(s|S)(s|S)((w|W)(o|O)(r|R)(d|D))?`) + if err != nil { + return "", err + } + + if r.Match([]byte(k)) { + return v, nil + } + } + + return "", fmt.Errorf("no password credential located in secret store") +}