diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..aec1491 --- /dev/null +++ b/.gitignore @@ -0,0 +1,67 @@ +# Application created directories +output/ + +# Visual Studio Code +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launce.json +!.vscode/extensions.json +!.vscode/*.code-snippets +.history/ +*.vsix + +# GoLang +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +*.out +go.work + +# General +.DS_Store +.AppleDouble +.LSOverride +# Icon must end with two \r +Icon + + +# Thumbnails +._* +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db +# Dump file +*.stackdump +# Folder config file +[Dd]esktop.ini +# Recycle Bin used on file shares +$RECYCLE.BIN/ +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp +# Windows shortcuts +*.lnk \ No newline at end of file diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..cb7d85d --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,58 @@ +linters: + disable-all: true + enable: + # default linters + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - unused + # project linters + - asasalint + - asciicheck + - bodyclose + - contextcheck + - dupl + - durationcheck + - errchkjson + - gocheckcompilerdirectives + - gocognit + - goconst + - gocritic + - godox + - goimports + - gosec + - grouper + - importas + - misspell + - musttag + - nestif + - nilerr + - nilnil + - prealloc + - reassign + - tagalign + - tenv + - unconvert + - unparam + - usestdlibvars + - wastedassign + - whitespace + fast: true +linter-settings: + tagalign: + order: + - json + - yaml + - yml + - toml + - mapstructure + - binding + - validate + - env + - default + - ignored + - required + - secret + - info \ No newline at end of file diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..b32e3d1 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "golang.go" + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..abdad2f --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,29 @@ +{ + "go.useLanguageServer": true, + "go.vetOnSave": "package", + "go.lintOnSave": "package", + "go.formatTool": "goimports", + "go.lintTool": "golangci-lint", + "go.lintFlags": [ + "--fast" + ], + + "[go]": { + "editor.detectIndentation": false, + "editor.tabSize": 2, + "editor.insertSpaces": false, + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": true + } + }, + + "cSpell.words": [ + "ftype", + "nolint", + "goconst", + "TZUTC", + "webserver", + "gocognit" + ] +} \ No newline at end of file diff --git a/assets/embed.go b/assets/embed.go new file mode 100644 index 0000000..f267eb4 --- /dev/null +++ b/assets/embed.go @@ -0,0 +1,6 @@ +package assets + +import "embed" + +//go:embed html/* +var EmbedHTML embed.FS diff --git a/assets/html/css/style.css b/assets/html/css/style.css new file mode 100644 index 0000000..11fcded --- /dev/null +++ b/assets/html/css/style.css @@ -0,0 +1,41 @@ +body { + background-color: #ffffff; + color: #000000; + font-size:14pt; + line-height:1.5em; + font-family:"Myriad Pro", "Trebuchet MS", Helvetica, sans-serif; + width: 44em; + margin:4ex 0 12ex 5%; +} +.fire { + font-size: 40pt; + color: #ff0000; +} +.always { + font-size: 25pt; + color: #0000ff; +} +.safe { + font-size: 40pt; + color: #00af00; +} +P.little { + line-height:1em; + font-size: 35pt; + color: #f97b04; +} +small { + font-size: 10pt; +} +A:link { + color: #aa0000; +} +A:visited { + color: #606060; +} +A:active { + color: #ffffff; +} +img.c1 { + border:0;width:88px;height:31px +} diff --git a/assets/html/index.tplt b/assets/html/index.tplt new file mode 100644 index 0000000..acd1b5b --- /dev/null +++ b/assets/html/index.tplt @@ -0,0 +1,49 @@ + + + + + + + + + + + Is The Internet On Fire? + + + + + Is The Internet On Fire? + + + + [txtdig +short txt istheinternetonfire.app [json] + +
+ + {{- if gt (len .CVEs) 0 }} +

+ Yes!
+ It's always something.
+

+
What's Burning?
+ {{- range .CVEs }} + {{ .CveID | ToUpper }} - {{ .Product }} - {{ .ShortDescription }}
+ {{- end }} + {{ else }} +

+ Nope!
+

+ {{- end }} +
+ + Inspiration for this site was taken directly from istheinternetonfire.com by @jschauma. +
+ Updated by @nhyatt. +
+ Source located on GiTea +
+ + diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ea1b0dd --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module istheinternetonfire.app + +go 1.21.6 diff --git a/internal/cisa/cisa.go b/internal/cisa/cisa.go new file mode 100644 index 0000000..549ff52 --- /dev/null +++ b/internal/cisa/cisa.go @@ -0,0 +1,45 @@ +package cisa + +import ( + "encoding/json" + "net/http" + "sync" + "time" + + "istheinternetonfire.app/internal/config" + "istheinternetonfire.app/internal/httpclient" +) + +var ( + mu sync.Mutex + Cisa CisaJSON +) + +func Read() CisaJSON { + mu.Lock() + o := Cisa + mu.Unlock() + return o +} + +func Start() { + for { + c := httpclient.NewClient(http.DefaultClient) + d, err := c.Get(config.Cfg.RemoteURL) + if err != nil { + time.Sleep(time.Second * 120) + continue + } + + mu.Lock() + if err := json.Unmarshal(d, &Cisa); err != nil { + mu.Unlock() + time.Sleep(time.Second * 120) + continue + } + mu.Unlock() + + config.Cfg.Log.Info("obtained remote data") + time.Sleep(time.Second * time.Duration(config.Cfg.RefreshSeconds)) + } +} diff --git a/internal/cisa/struct-cisa.go b/internal/cisa/struct-cisa.go new file mode 100644 index 0000000..0bd2525 --- /dev/null +++ b/internal/cisa/struct-cisa.go @@ -0,0 +1,22 @@ +package cisa + +type CisaJSON struct { + CatalogVersion string `json:"catalogVersion"` + Count int `json:"count"` + DateReleased string `json:"dateReleased"` + Title string `json:"title"` + Vulnerabilities []VulStruct `json:"vulnerabilities"` +} + +type VulStruct struct { + CveID string `json:"cveID"` + DateAdded string `json:"dateAdded"` + DueDate string `json:"dueDate"` + KnownRansomwareCampaignUse string `json:"knownRansomwareCampaignUse"` + Notes string `json:"notes"` + Product string `json:"product"` + RequiredAction string `json:"requiredAction"` + ShortDescription string `json:"shortDescription"` + VendorProject string `json:"vendorProject"` + VulnerabilityName string `json:"vulnerabilityName"` +} \ No newline at end of file diff --git a/internal/config/envconfig.go b/internal/config/envconfig.go new file mode 100644 index 0000000..73c1786 --- /dev/null +++ b/internal/config/envconfig.go @@ -0,0 +1,241 @@ +package config + +import ( + "flag" + "fmt" + "os" + "reflect" + "strconv" + "strings" +) + +type structInfo struct { + Name string + Alt string + Info string + Key string + Field reflect.Value + Tags reflect.StructTag + Type reflect.Type + DefaultValue interface{} + Secret interface{} +} + +func getEnv[t string | bool | int | int64 | float64](env string, def t) (t, error) { + val := os.Getenv(env) + if len(val) == 0 { + return def, nil + } + + output := *new(t) + switch (interface{})(def).(type) { + case string: + v, err := typeConversion("string", val) + if err != nil { + return (interface{})(false).(t), err + } + output = v.(t) + case bool: + v, err := typeConversion("bool", val) + if err != nil { + return (interface{})(false).(t), err + } + output = v.(t) + case int: + v, err := typeConversion("int", val) + if err != nil { + return (interface{})(int(0)).(t), err + } + output = (interface{})(int(v.(int64))).(t) + case int64: + v, err := typeConversion("int64", val) + if err != nil { + return (interface{})(int64(0)).(t), err + } + output = v.(t) + case float64: + v, err := typeConversion("float64", val) + if err != nil { + return (interface{})(float64(0)).(t), err + } + output = v.(t) + } + + return output, nil +} + +func getStructInfo(spec interface{}) ([]structInfo, error) { + s := reflect.ValueOf(spec) + + if s.Kind() != reflect.Pointer { + return []structInfo{}, fmt.Errorf("getStructInfo() was sent a %s instead of a pointer to a struct.\n", s.Kind()) + } + + s = s.Elem() + if s.Kind() != reflect.Struct { + return []structInfo{}, fmt.Errorf("getStructInfo() was sent a %s instead of a struct.\n", s.Kind()) + } + typeOfSpec := s.Type() + + infos := make([]structInfo, 0, s.NumField()) + for i := 0; i < s.NumField(); i++ { + f := s.Field(i) + ftype := typeOfSpec.Field(i) + + ignored, _ := strconv.ParseBool(ftype.Tag.Get("ignored")) + if !f.CanSet() || ignored { + continue + } + + for f.Kind() == reflect.Pointer { + if f.IsNil() { + if f.Type().Elem().Kind() != reflect.Struct { + break + } + f.Set(reflect.New(f.Type().Elem())) + } + f = f.Elem() + } + + secret, err := typeConversion(ftype.Type.String(), ftype.Tag.Get("secret")) + if err != nil { + secret = false + } + + var desc string + if len(ftype.Tag.Get("info")) != 0 { + desc = fmt.Sprintf("(%s) %s", strings.ToUpper(ftype.Tag.Get("env")), ftype.Tag.Get("info")) + } else { + desc = fmt.Sprintf("(%s)", strings.ToUpper(ftype.Tag.Get("env"))) + } + + info := structInfo{ + Name: ftype.Name, + Alt: strings.ToUpper(ftype.Tag.Get("env")), + Info: desc, + Key: ftype.Name, + Field: f, + Tags: ftype.Tag, + Type: ftype.Type, + Secret: secret, + } + if info.Alt != "" { + info.Key = info.Alt + } + info.Key = strings.ToUpper(info.Key) + if ftype.Tag.Get("default") != "" { + v, err := typeConversion(ftype.Type.String(), ftype.Tag.Get("default")) + if err != nil { + return []structInfo{}, err + } + info.DefaultValue = v + } + infos = append(infos, info) + } + return infos, nil +} + +func typeConversion(t, v string) (interface{}, error) { + switch t { + case "string": //nolint:goconst + return v, nil + case "int": //nolint:goconst + return strconv.ParseInt(v, 10, 0) + case "int8": + return strconv.ParseInt(v, 10, 8) + case "int16": + return strconv.ParseInt(v, 10, 16) + case "int32": + return strconv.ParseInt(v, 10, 32) + case "int64": + return strconv.ParseInt(v, 10, 64) + case "uint": + return strconv.ParseUint(v, 10, 0) + case "uint16": + return strconv.ParseUint(v, 10, 16) + case "uint32": + return strconv.ParseUint(v, 10, 32) + case "uint64": + return strconv.ParseUint(v, 10, 64) + case "float32": + return strconv.ParseFloat(v, 32) + case "float64": + return strconv.ParseFloat(v, 64) + case "complex64": + return strconv.ParseComplex(v, 64) + case "complex128": + return strconv.ParseComplex(v, 128) + case "bool": //nolint:goconst + return strconv.ParseBool(v) + } + return nil, fmt.Errorf("Unable to identify type.") +} + +func (cfg *Config) parseFlags(cfgInfo []structInfo) error { //nolint:gocognit + for _, info := range cfgInfo { + switch info.Type.String() { + case "string": + var dv string + + if info.DefaultValue != nil { + dv = info.DefaultValue.(string) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*string) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.StringVar(p, info.Name, retVal, info.Info) + case "bool": + var dv bool + + if info.DefaultValue != nil { + dv = info.DefaultValue.(bool) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*bool) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.BoolVar(p, info.Name, retVal, info.Info) + case "int": + var dv int + + if info.DefaultValue != nil { + dv = int(info.DefaultValue.(int64)) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.IntVar(p, info.Name, retVal, info.Info) + case "int64": + var dv int64 + + if info.DefaultValue != nil { + dv = info.DefaultValue.(int64) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int64) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.Int64Var(p, info.Name, retVal, info.Info) + case "float64": + var dv float64 + + if info.DefaultValue != nil { + dv = info.DefaultValue.(float64) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*float64) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.Float64Var(p, info.Name, retVal, info.Info) + } + } + flag.Parse() + return nil +} diff --git a/internal/config/initialize.go b/internal/config/initialize.go new file mode 100644 index 0000000..c6506bf --- /dev/null +++ b/internal/config/initialize.go @@ -0,0 +1,37 @@ +package config + +import ( + "log" + "os" + "time" +) + +var Cfg Config + +func Init() { + Cfg = New() + + cfgInfo, err := getStructInfo(&Cfg) + if err != nil { + log.Fatalf("Unable to initialize program: %v", err) + } + + // get command line flags + if err := Cfg.parseFlags(cfgInfo); err != nil { + log.Fatalf("Unable to initialize program: %v", err) + } + + // set logging Level + setLogLevel(&Cfg) + + // set timezone & time format + Cfg.TZUTC, _ = time.LoadLocation("UTC") + Cfg.TZLocal, err = time.LoadLocation(Cfg.TimeZoneLocal) + if err != nil { + Cfg.Log.Error("Unable to parse timezone string", "error", err) + os.Exit(1) + } + + // print running config + printRunningConfig(&Cfg, cfgInfo) +} diff --git a/internal/config/struct-config.go b/internal/config/struct-config.go new file mode 100644 index 0000000..24ec803 --- /dev/null +++ b/internal/config/struct-config.go @@ -0,0 +1,85 @@ +package config + +import ( + "log/slog" + "os" + "reflect" + "strconv" + "time" +) + +type Config struct { + // time configuration + TimeFormat string `default:"2006-01-02 15:04:05" env:"time_format"` + TimeZoneLocal string `default:"America/Chicago" env:"time_zone"` + TZLocal *time.Location `ignored:"true"` + TZUTC *time.Location `ignored:"true"` + + // logging + LogLevel int `default:"50" env:"log_level"` + Log *slog.Logger `ignored:"true"` + SLogLevel *slog.LevelVar `ignored:"true"` + + // webserver + WebServerPort int `default:"8080" env:"webserver_port"` + WebServerIP string `default:"0.0.0.0" env:"webserver_ip"` + WebServerReadTimeout int `default:"5" env:"webserver_read_timeout"` + WebServerWriteTimeout int `default:"1" env:"webserver_write_timeout"` + WebServerIdleTimeout int `default:"2" env:"webserver_idle_timeout"` + + // cisa + RemoteURL string `default:"https://www.cisa.gov/sites/default/files/feeds/known_exploited_vulnerabilities.json?sort_by=field_date_added" env:"remote_url"` + RefreshSeconds int `default:"14400" env:"refresh_seconds"` +} + +// New initializes the config variable for use with a prepared set of defaults. +func New() Config { + cfg := Config{ + SLogLevel: new(slog.LevelVar), + } + + cfg.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: cfg.SLogLevel, + })) + + return cfg +} + +func setLogLevel(cfg *Config) { + switch { + // error + case cfg.LogLevel <= 20: + cfg.SLogLevel.Set(slog.LevelError) + cfg.Log.Info("Log level updated", "level", slog.LevelError) + // warning + case cfg.LogLevel > 20 && cfg.LogLevel <= 40: + cfg.SLogLevel.Set(slog.LevelWarn) + cfg.Log.Info("Log level updated", "level", slog.LevelWarn) + // info + case cfg.LogLevel > 40 && cfg.LogLevel <= 60: + cfg.SLogLevel.Set(slog.LevelInfo) + cfg.Log.Info("Log level updated", "level", slog.LevelInfo) + // debug + case cfg.LogLevel > 60: + cfg.SLogLevel.Set(slog.LevelDebug) + cfg.Log.Info("Log level updated", "level", slog.LevelDebug) + } + // set default logger + slog.SetDefault(cfg.Log) +} + +func printRunningConfig(cfg *Config, cfgInfo []structInfo) { + for _, info := range cfgInfo { + switch info.Type.String() { + case "string": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*string) + cfg.Log.Debug("Running Configuration", info.Alt, *p) + case "bool": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*bool) + cfg.Log.Debug("Running Configuration", info.Alt, strconv.FormatBool(*p)) + case "int": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int) + cfg.Log.Debug("Running Configuration", info.Alt, strconv.FormatInt(int64(*p), 10)) + } + } +} diff --git a/internal/httpclient/httpclient.go b/internal/httpclient/httpclient.go new file mode 100644 index 0000000..ec0cef8 --- /dev/null +++ b/internal/httpclient/httpclient.go @@ -0,0 +1,165 @@ +package httpclient + +import ( + "bytes" + "compress/gzip" + "compress/zlib" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "net/http" +) + +// HTTPClient is an interface for initializing the http client library. +type HTTPClient struct { + Client *http.Client + Data *bytes.Buffer + Headers map[string]string + + Username string + Password string +} + +// DefaultClient is a function for defining a basic HTTP client with standard timeouts. +func DefaultClient() *HTTPClient { + return &HTTPClient{ + Client: &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + Dial: (&net.Dialer{ + Timeout: 5 * time.Second, + }).Dial, + TLSHandshakeTimeout: 5 * time.Second, + IdleConnTimeout: 300 * time.Second, + }, + }, + } +} + +// NewClient Create an HTTPClient with a user-provided net/http.Client +func NewClient(httpClient *http.Client) *HTTPClient { + return &HTTPClient{Client: httpClient} +} + +// SetBasicAuth is a chaining function to set the username and password for basic +// authentication +func (c *HTTPClient) SetBasicAuth(username, password string) *HTTPClient { + c.Username = username + c.Password = password + + return c +} + +// SetPostData is a chaining function to set POST/PUT/PATCH data +func (c *HTTPClient) SetPostData(data string) *HTTPClient { + c.Data = bytes.NewBufferString(data) + + return c +} + +// SetHeader is a chaining function to set arbitrary HTTP Headers +func (c *HTTPClient) SetHeader(label string, value string) *HTTPClient { + if c.Headers == nil { + c.Headers = map[string]string{} + } + + c.Headers[label] = value + + return c +} + +// Get calls the net.http GET operation +func (c *HTTPClient) Get(url string) ([]byte, error) { + return c.do(url, http.MethodGet) +} + +// Patch calls the net.http PATCH operation +func (c *HTTPClient) Patch(url string) ([]byte, error) { + return c.do(url, http.MethodPatch) +} + +// Post calls the net.http POST operation +func (c *HTTPClient) Post(url string) ([]byte, error) { + return c.do(url, http.MethodPost) +} + +// Put calls the net.http PUT operation +func (c *HTTPClient) Put(url string) ([]byte, error) { + return c.do(url, http.MethodPut) +} + +func (c *HTTPClient) do(url string, method string) ([]byte, error) { + var ( + req *http.Request + res *http.Response + output []byte + err error + ) + + // NewRequest knows that c.data is typed *bytes.Buffer and will SEGFAULT + // if c.data is nil. So we create a request using nil when c.data is nil + if c.Data != nil { + req, err = http.NewRequest(method, url, c.Data) + } else { + req, err = http.NewRequest(method, url, nil) + } + if err != nil { + return nil, err + } + + if (len(c.Username) > 0) && (len(c.Password) > 0) { + req.SetBasicAuth(c.Username, c.Password) + } + + if c.Headers != nil { + for label, value := range c.Headers { + req.Header.Set(label, value) + } + } + + if res, err = c.Client.Do(req); err != nil { + return nil, err + } + + defer res.Body.Close() + + if output, err = io.ReadAll(res.Body); err != nil { + return nil, err + } + + // check status + if res.StatusCode < 200 || res.StatusCode >= 300 { + return nil, errors.New("non-successful status code received [" + strconv.Itoa(res.StatusCode) + "]") + } + + // gzip encoding + if strings.EqualFold(res.Header.Get("Content-Encoding"), "gzip") || strings.EqualFold(res.Header.Get("Content-Encoding"), "x-gzip") { + compHandler, err := gzip.NewReader(bytes.NewReader(output)) + if err != nil { + return nil, fmt.Errorf("unable to uncompress response: %v", err) + } + output, err = io.ReadAll(compHandler) + if err != nil { + return nil, fmt.Errorf("unable to uncompress response: %v", err) + } + } + + // deflate encoding + if strings.EqualFold(res.Header.Get("Content-Encoding"), "deflate") { + compHandler, err := zlib.NewReader(bytes.NewReader(output)) + if err != nil { + return nil, fmt.Errorf("unable to uncompress response: %v", err) + } + output, err = io.ReadAll(compHandler) + if err != nil { + return nil, fmt.Errorf("unable to uncompress response: %v", err) + } + } + + return output, nil +} diff --git a/internal/httpclient/httpclient_test.go b/internal/httpclient/httpclient_test.go new file mode 100644 index 0000000..4ae88c3 --- /dev/null +++ b/internal/httpclient/httpclient_test.go @@ -0,0 +1,281 @@ +package httpclient + +import ( + "fmt" + "io" + "testing" + + "encoding/json" + "net/http" + "net/http/httptest" +) + +type Data struct { + Greeting string `json:"greeting"` + Headers map[string]string `json:"headers"` + Method string `json:"method"` + Username string `json:"username"` + Password string `json:"password"` + PostData string `json:"postdata"` +} + +var ( + greeting = "Hello world" + postData = "Test data" + authUser = "testuser" + authPass = "testpass" + headerLabel = "Test-Header" + headerValue = "Test-Value" +) + +func httpTestHandler(w http.ResponseWriter, r *http.Request) { + var ( + b []byte + user string + pass string + body []byte + ) + + data := Data{ + Greeting: greeting, + Headers: map[string]string{}, + Method: r.Method, + } + + user, pass, ok := r.BasicAuth() + if ok { + data.Username = user + data.Password = pass + } + + body, err := io.ReadAll(r.Body) + if err != nil { + fmt.Fprint(w, "io.ReadAll failed") + } + data.PostData = string(body) + + for h := range r.Header { + data.Headers[h] = r.Header.Get(h) + } + + b, err = json.MarshalIndent(data, "", " ") + if err != nil { + fmt.Fprint(w, "Json marshal failed somehow") + } + fmt.Fprint(w, string(b)) +} + +func checkMethod(t *testing.T, data Data, method string) { + if data.Method != method { + t.Errorf("data.Method(%s) != method(%s)", data.Method, method) + } + t.Log("checkMethod() success") +} + +func checkGreeting(t *testing.T, data Data) { + if data.Greeting != greeting { + t.Errorf("data.Greeting(%s) != greeting(%s)", data.Greeting, greeting) + } + t.Log("checkGreeting() success") +} + +func checkBasicAuth(t *testing.T, data Data) { + if data.Username != authUser { + t.Errorf("data.Username(%s) != authUser(%s)", data.Username, authUser) + } + if data.Password != authPass { + t.Errorf("data.Password(%s) != authPass(%s)", data.Password, authPass) + } + t.Log("checkBasicAuth() success") +} + +func checkPostData(t *testing.T, data Data) { + if data.PostData != postData { + t.Errorf("data.PostData(%s) != postData(%s)", data.PostData, postData) + } + t.Log("checkPostData() success") +} + +func TestGet(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().Get(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodGet) + checkGreeting(t, data) +} + +func TestGetAuth(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetBasicAuth(authUser, authPass).Get(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodGet) + checkGreeting(t, data) + checkBasicAuth(t, data) +} + +func TestPut(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetPostData(postData).Put(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPut) + checkGreeting(t, data) + checkPostData(t, data) +} + +func TestPutAuth(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetBasicAuth(authUser, authPass).SetPostData(postData).Put(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPut) + checkGreeting(t, data) + checkBasicAuth(t, data) + checkPostData(t, data) +} + +func TestPost(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetPostData(postData).Post(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPost) + checkGreeting(t, data) + checkPostData(t, data) +} + +func TestPostAuth(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetBasicAuth(authUser, authPass).SetPostData(postData).Post(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPost) + checkGreeting(t, data) + checkBasicAuth(t, data) + checkPostData(t, data) +} + +func TestPatch(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetPostData(postData).Patch(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPatch) + checkGreeting(t, data) + checkPostData(t, data) +} + +func TestPatchAuth(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetBasicAuth(authUser, authPass).SetPostData(postData).Patch(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodPatch) + checkGreeting(t, data) + checkBasicAuth(t, data) + checkPostData(t, data) +} + +func TestSetHeader(t *testing.T) { + var data Data + + ts := httptest.NewServer(http.HandlerFunc(httpTestHandler)) + defer ts.Close() + + output, err := DefaultClient().SetHeader(headerLabel, headerValue).Get(ts.URL) + if err != nil { + t.Error(err) + } + + if err = json.Unmarshal(output, &data); err != nil { + t.Error(err) + } + + checkMethod(t, data, http.MethodGet) + checkGreeting(t, data) + if data.Headers[headerLabel] != headerValue { + t.Errorf("SetHeader values not set in header: %+v", data.Headers) + } +} diff --git a/internal/webserver/httpServer.go b/internal/webserver/httpServer.go new file mode 100644 index 0000000..44ca712 --- /dev/null +++ b/internal/webserver/httpServer.go @@ -0,0 +1,118 @@ +package webserver + +import ( + "fmt" + "log" + "regexp" + "strconv" + "strings" + "time" + + "compress/gzip" + "net/http" + + "istheinternetonfire.app/assets" + "istheinternetonfire.app/internal/config" +) + +const ( + TYPE_APPLICATION_PEM string = "application/x-pem-file" + TYPE_APPLICATION_JSON string = "application/json" + TYPE_AUDIO_MPEG string = "audio/mpeg" + TYPE_FONT_WOFF string = "font/woff" + TYPE_FONT_WOFF2 string = "font/woff2" + TYPE_IMAGE_JPG string = "image/jpg" + TYPE_IMAGE_PNG string = "image/png" + TYPE_TEXT_CSS string = "text/css" + TYPE_TEXT_HTML string = "text/html" + TYPE_TEXT_JS string = "text/javascript" + TYPE_TEXT_PLAIN string = "text/plain" + TYPE_TEXT_RAW string = "text/raw" +) + +var validFiles map[string]string = map[string]string{ + "/robots.txt": TYPE_TEXT_PLAIN, + "/apple-touch-icon.png": TYPE_IMAGE_PNG, + "/favicon.ico": TYPE_IMAGE_PNG, + "/favicon-16x16.png": TYPE_IMAGE_PNG, + "/favicon-32x32.png": TYPE_IMAGE_PNG, + "/js/bootstrap.bundle.min.js": TYPE_TEXT_JS, + "/js/bootstrap.bundle.min.js.map": TYPE_APPLICATION_JSON, + "/js/jquery.min.js": TYPE_TEXT_JS, +} + +func isValidReq(file string) (string, error) { + for f, t := range validFiles { + if file == f { + return t, nil + } + } + + return "", fmt.Errorf("Invalid file requested: %s", file) +} + +func httpAccessLog(req *http.Request) { + config.Cfg.Log.Debug("http request", "method", req.Method, "remote-address", req.RemoteAddr, "request-uri", req.RequestURI) +} + +func crossSiteOrigin(w http.ResponseWriter) { + w.Header().Add("Access-Control-Allow-Origin", "*") + w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS") +} + +func Start() { + path := http.NewServeMux() + + connection := &http.Server{ + Addr: config.Cfg.WebServerIP + ":" + strconv.FormatInt(int64(config.Cfg.WebServerPort), 10), + Handler: path, + ReadTimeout: time.Duration(config.Cfg.WebServerReadTimeout) * time.Second, + WriteTimeout: time.Duration(config.Cfg.WebServerWriteTimeout) * time.Second, + IdleTimeout: time.Duration(config.Cfg.WebServerIdleTimeout) * time.Second, + } + + path.HandleFunc("/", webRoot) + + if err := connection.ListenAndServe(); err != nil { + config.Cfg.Log.Error("unable to start webserver", "error", err) + } +} + +func webRoot(w http.ResponseWriter, r *http.Request) { + httpAccessLog(r) + crossSiteOrigin(w) + + if strings.ToLower(r.Method) != "get" { + config.Cfg.Log.Debug("http invalid method", "url", r.URL.Path, "expected", "GET", "received", r.Method) + tmpltError(w, http.StatusBadRequest, "Invalid http method.") + return + } + + if r.URL.Path == "/" { + tmpltWebRoot(w, r) + } else { + cType, err := isValidReq(r.URL.Path) + if err != nil { + config.Cfg.Log.Debug("request not found", "url", r.URL.Path) + tmpltStatusNotFound(w, r.URL.Path) + return + } + + w.Header().Add("Content-Type", cType) + o, err := assets.EmbedHTML.ReadFile("html" + r.URL.Path) + if err != nil { + log.Printf("[ERROR] Unable to read local embedded file data: %v\n", err) + tmpltError(w, http.StatusInternalServerError, "Server unable to retrieve file data.") + return + } + + if regexp.MustCompile(`gzip`).Match([]byte(r.Header.Get("Accept-Encoding"))) { + w.Header().Add("Content-Encoding", "gzip") + gw := gzip.NewWriter(w) + defer gw.Close() + gw.Write(o) + } else { + w.Write(o) + } + } +} diff --git a/internal/webserver/httpTemplate.go b/internal/webserver/httpTemplate.go new file mode 100644 index 0000000..17bf5f6 --- /dev/null +++ b/internal/webserver/httpTemplate.go @@ -0,0 +1,102 @@ +package webserver + +import ( + "bytes" + "strings" + "time" + + "encoding/json" + "net/http" + "text/template" + + "istheinternetonfire.app/assets" + "istheinternetonfire.app/internal/cisa" + "istheinternetonfire.app/internal/config" +) + +type webErrStruct struct { + Error bool `json:"error" yaml:"error"` + ErrorMsg string `json:"error_message" yaml:"errorMessage"` +} + +func tmpltError(w http.ResponseWriter, serverStatus int, message string) { + var ( + output []byte + o = webErrStruct{ + Error: true, + ErrorMsg: message, + } + err error + ) + + w.Header().Add("Content-Type", "application/json") + output, err = json.MarshalIndent(o, "", " ") + if err != nil { + config.Cfg.Log.Warn("marshal error", "error", err) + w.WriteHeader(serverStatus) + w.Write(output) //nolint:errcheck + } +} + +func tmpltWebRoot(w http.ResponseWriter, r *http.Request) { + tmplt, err := template.New("index.tplt").Funcs(template.FuncMap{ + "ToUpper": strings.ToUpper, + }).ParseFS( + assets.EmbedHTML, + "html/index.tplt", + "html/css/style.css", + ) + if err != nil { + config.Cfg.Log.Debug("unable to parse html template", "error", err) + tmpltError(w, http.StatusInternalServerError, "Template Parse Error.") + return + } + + var ( + msgBuffer bytes.Buffer + cves []cisa.VulStruct + ) + + c := cisa.Read() + for _, i := range c.Vulnerabilities { + t, _ := time.Parse("2006-01-02", i.DateAdded) + if t.After(time.Now().Add(-time.Hour * 720)) { + cves = append(cves, i) + } + } + + if err := tmplt.Execute(&msgBuffer, struct { + CVEs []cisa.VulStruct + }{ + CVEs: cves[len(cves)-3:], + }); err != nil { + config.Cfg.Log.Debug("unable to execute html template", err) + tmpltError(w, http.StatusInternalServerError, "Template Parse Error.") + return + } + + w.Write(msgBuffer.Bytes()) +} + +func tmpltStatusNotFound(w http.ResponseWriter, path string) { + tmplt, err := template.ParseFS(assets.EmbedHTML, "html/file-not-found.tplt") + if err != nil { + config.Cfg.Log.Debug("unable to parse html template", err) + tmpltError(w, http.StatusInternalServerError, "Template Parse Error.") + return + } + + var msgBuffer bytes.Buffer + if err := tmplt.Execute(&msgBuffer, struct { + Title string + ErrorCode int + }{ + Title: path, + ErrorCode: http.StatusNotFound, + }); err != nil { + config.Cfg.Log.Debug("unable to execute html template", err) + tmpltError(w, http.StatusInternalServerError, "Template Parse Error.") + return + } + w.Write(msgBuffer.Bytes()) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..0174dcb --- /dev/null +++ b/main.go @@ -0,0 +1,38 @@ +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + + "istheinternetonfire.app/internal/cisa" + "istheinternetonfire.app/internal/config" + "istheinternetonfire.app/internal/webserver" +) + +func forever() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + sig := <-c + log.Printf("[WARNING] shutting down, detected signal: %s", sig) +} + +func main() { + // initialize all parameters + config.Init() + + // configure shutdown sequence + defer func() { + log.Printf("[TRACE] shutdown sequence complete") + }() + + // start webserver + go webserver.Start() + + // get remote data + go cisa.Start() + + forever() +}