From ce4b4a11ffd086d245f8ad556d8643312ba0df68 Mon Sep 17 00:00:00 2001 From: nhyatt Date: Fri, 2 May 2025 20:57:03 -0500 Subject: [PATCH] some library updates and adds support for adblock lists --- .vscode/settings.json | 2 +- assets/config/config.yaml | 14 +- cmd/bind/build-bind.go | 10 +- cmd/bind/cleanup.go | 14 +- cmd/bind/get-remote-data.go | 33 ++- cmd/bind/main.go | 14 +- cmd/bind/parsing-adblock.go | 53 +++++ cmd/bind/parsing-complex.go | 4 +- cmd/bind/parsing-simple.go | 4 +- go.mod | 15 +- go.sum | 8 + .../{config => common}/file-operations.go | 2 +- internal/config/envconfig.go | 66 +++--- internal/config/envconfig_test.go | 191 ++++++++++++++++++ internal/config/initialize.go | 41 ++-- internal/config/struct-config.go | 89 ++++---- internal/config/struct-config_test.go | 40 ++++ internal/log/logging.go | 154 ++++++++++++++ internal/log/logging_test.go | 97 +++++++++ 19 files changed, 712 insertions(+), 139 deletions(-) create mode 100644 cmd/bind/parsing-adblock.go rename internal/{config => common}/file-operations.go (98%) create mode 100644 internal/config/envconfig_test.go create mode 100644 internal/config/struct-config_test.go create mode 100644 internal/log/logging.go create mode 100644 internal/log/logging_test.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 7c70a28..83ffa7b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,7 +14,7 @@ "editor.insertSpaces": false, "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" } }, diff --git a/assets/config/config.yaml b/assets/config/config.yaml index 0b01da7..d1679aa 100644 --- a/assets/config/config.yaml +++ b/assets/config/config.yaml @@ -57,7 +57,10 @@ sources: - "https://raw.githubusercontent.com/dibdot/DoH-IP-blocklists/master/doh-domains.txt" - "https://s3.amazonaws.com/lists.disconnect.me/simple_tracking.txt" - "https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-malware.txt" - - "https://dbl.oisd.nl/" + - "https://big.oisd.nl/" + adBlockURLs: + - "https://raw.githubusercontent.com/hagezi/dns-blocklists/main/adblock/pro.txt" + - "https://raw.githubusercontent.com/hagezi/dns-blocklists/main/adblock/tif.txt" allowList: - ^localhost$ - ^localhost\.localdomain$ @@ -75,6 +78,7 @@ allowList: - (^|\.)assets\.adobedtm\.com$ - (^|\.)brandify\.com$ - (^|\.)clients.\.google\.com$ +- (^|\.)cpng\.lol$ - (^|\.)doubleclick\.net$ - (^|\.)duckduckgo\.com$ - (^|\.)ghostery\.net$ @@ -95,7 +99,11 @@ allowList: - (^|\.)unagi-na\.amazon\.com$ - (^|\.)unagi\.amazon\.com$ - (^|\.)vercel-dns\.com$ -- ^\w+-\d{4}\.\w+-msedge\.net$ +- (^|\.)launchdarkly\.com$ +- (^|\.)mimojp\.store$ +- ^\w+-\d{4}\.\w+-msedge\.net$ +- ^ctldl\.windowsupdate\.com$ +- ^settings-win\.data\.microsoft\.com$ denyList: - jindlecleanings.xyz -- "*.jindlecleanings.xyz" +- "*.jindlecleanings.xyz" \ No newline at end of file diff --git a/cmd/bind/build-bind.go b/cmd/bind/build-bind.go index 66e35d9..daa96b8 100644 --- a/cmd/bind/build-bind.go +++ b/cmd/bind/build-bind.go @@ -5,8 +5,10 @@ import ( "text/template" - "pihole-blocklist/bind/assets" - "pihole-blocklist/bind/internal/config" + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/assets" + + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/common" + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" ) func buildBindResponsePolicyFile() error { @@ -25,11 +27,11 @@ func buildBindResponsePolicyFile() error { return err } - bytesWritten, err := config.WriteFile(cfg.BindOutputFileName, output.Bytes()) + bytesWritten, err := common.WriteFile(cfg.BindOutputFileName, output.Bytes()) if err != nil { return err } - cfg.Log.Debug("file created", "file", cfg.BindOutputFileName, "bytes", bytesWritten) + log.Debug("file created", "file", cfg.BindOutputFileName, "bytes", bytesWritten) return nil } diff --git a/cmd/bind/cleanup.go b/cmd/bind/cleanup.go index d572be2..50240bf 100644 --- a/cmd/bind/cleanup.go +++ b/cmd/bind/cleanup.go @@ -3,6 +3,8 @@ package main import ( "regexp" "sort" + + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" ) func cleanBadDomains(domains []string) []string { @@ -17,7 +19,7 @@ func cleanBadDomains(domains []string) []string { } } domains = list - cfg.Log.Info("hosts removed from blocklist", "reason", "duplicate", "hosts", total-len(domains)) + log.Info("hosts removed from blocklist", "reason", "duplicate", "hosts", total-len(domains)) // remove hosts that are too long total = len(domains) @@ -29,7 +31,7 @@ func cleanBadDomains(domains []string) []string { list = append(list, blocklistItem) } domains = list - cfg.Log.Info("hosts removed from blocklist", "reason", "too many characters", "hosts", total-len(domains)) + log.Info("hosts removed from blocklist", "reason", "too many characters", "hosts", total-len(domains)) // remove allow-listed matches total = len(domains) @@ -39,7 +41,7 @@ func cleanBadDomains(domains []string) []string { for _, allowedItem := range cfg.ConfigFile.AllowLists { _, err := regexp.Compile(allowedItem) if err != nil { - cfg.Log.Error("unable to parse allow list item", "error", err, "regex", allowedItem) + log.Error("unable to parse allow list item", "error", err, "regex", allowedItem) continue } goodAllowedItemList = append(goodAllowedItemList, allowedItem) @@ -50,7 +52,7 @@ func cleanBadDomains(domains []string) []string { addEntry := true for _, allowedItem := range goodAllowedItemList { if regexp.MustCompile(allowedItem).MatchString(v) { - cfg.Log.Debug("hosts removed from blocklist", "reason", "allowed host", "match string", allowedItem, "host", v) + log.Debug("hosts removed from blocklist", "reason", "allowed host", "match string", allowedItem, "host", v) addEntry = false } } @@ -59,9 +61,9 @@ func cleanBadDomains(domains []string) []string { } } domains = list - cfg.Log.Info("hosts removed from blocklist", "hosts", total-len(domains)) + log.Info("hosts removed from blocklist", "hosts", total-len(domains)) - cfg.Log.Info("total domains in list", "hosts", len(domains)) + log.Info("total domains in list", "hosts", len(domains)) sort.Strings(domains) return domains } diff --git a/cmd/bind/get-remote-data.go b/cmd/bind/get-remote-data.go index 997c53c..e9334d7 100644 --- a/cmd/bind/get-remote-data.go +++ b/cmd/bind/get-remote-data.go @@ -3,15 +3,17 @@ package main import ( "time" - "pihole-blocklist/bind/internal/httpclient" + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/httpclient" + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" ) func getListData() []string { var badDomains []string listSimple := make(chan []string) listComplex := make(chan []string) + listAdBlock := make(chan []string) - cfg.Log.Info("downloading blocklists") + log.Info("downloading blocklists") // Get Simple Blocklists go func() { data := getData(cfg.ConfigFile.Sources.DomainListURLs) @@ -26,28 +28,39 @@ func getListData() []string { listComplex <- domains }() + // Get AdBlock Blocklists + go func() { + data := getData(cfg.ConfigFile.Sources.AdBlockURLs) + domains := parseAdBlock(data) + listAdBlock <- domains + }() + // Wait for all downloads to finish var ( - simple, complex []string - simpleFinished, complexFinished bool + simple, complex, adblock []string + simpleFinished, complexFinished, adBlockFinished bool ) for { select { case simple = <-listSimple: + log.Info("all simple lists downloaded") simpleFinished = true - cfg.Log.Info("all simple lists downloaded") case complex = <-listComplex: - cfg.Log.Info("all complex lists downloaded") + log.Info("all complex lists downloaded") complexFinished = true + case adblock = <-listAdBlock: + log.Info("all adblock lists downloaded") + adBlockFinished = true default: time.Sleep(time.Millisecond * 100) } - if simpleFinished && complexFinished { + if simpleFinished && complexFinished && adBlockFinished { badDomains = append(badDomains, simple...) badDomains = append(badDomains, complex...) - cfg.Log.Info("domains retrieved", "hosts", len(badDomains)) + badDomains = append(badDomains, adblock...) + log.Info("domains retrieved", "hosts", len(badDomains)) break } } @@ -62,11 +75,11 @@ func getData(urls []string) []byte { listData := make([]byte, 0, len(urls)+1) for _, u := range urls { - cfg.Log.Debug("downloading", "url", u) + log.Debug("downloading", "url", u) c := httpclient.DefaultClient() data, err := c.Get(u) if err != nil { - cfg.Log.Error("unable to get remote content", "error", err, "url", err) + log.Error("unable to get remote content", "error", err, "url", err) } listData = append(listData, data...) // add newline to the end of data, you know, for funzies diff --git a/cmd/bind/main.go b/cmd/bind/main.go index 6560f5e..afa7c06 100644 --- a/cmd/bind/main.go +++ b/cmd/bind/main.go @@ -1,8 +1,11 @@ package main import ( - "pihole-blocklist/bind/internal/config" + "os" "time" + + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/config" + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" ) var cfg config.Config @@ -11,7 +14,12 @@ func main() { // Blocklist Collection // https://firebog.net/ - cfg = config.Init() + var err error + + if cfg, err = config.Init(); err != nil { + log.Error("error initializing program", "error", err) + os.Exit(1) + } // Set the zone serial number cfg.ConfigFile.ZoneConfig.Serial = time.Now().In(cfg.TZLocal).Format("0601021504") @@ -24,6 +32,6 @@ func main() { // write file if err := buildBindResponsePolicyFile(); err != nil { - cfg.Log.Error("unable to write file", "error", err, "path", cfg.BindOutputFileName) + log.Error("unable to write file", "error", err, "path", cfg.BindOutputFileName) } } diff --git a/cmd/bind/parsing-adblock.go b/cmd/bind/parsing-adblock.go new file mode 100644 index 0000000..950962a --- /dev/null +++ b/cmd/bind/parsing-adblock.go @@ -0,0 +1,53 @@ +package main + +import ( + "bufio" + "bytes" + "regexp" + + "github.com/asaskevich/govalidator" + + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" +) + +func parseAdBlock(data []byte) []string { + var domains []string + + // convert data to reader for line-by-line reading + r := bytes.NewReader(data) + + // process combined files line-by-line + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + + // skip lines with the AdBlock header + if regexp.MustCompile(`^\[Adblock Plus\]`).MatchString(line) { + continue + } + + // skip lines where the first non-whitespace character is '#' or '//' + if regexp.MustCompile(`^(\s+)?(#|\/\/|\!)`).MatchString(line) { + continue + } + + // skip lines with no characters and/or whitespace only + if regexp.MustCompile(`^(\s+)?$`).MatchString(line) { + continue + } + + // remove line header + d := regexp.MustCompile(`^\|\|`).ReplaceAllString(line, "") + + // remove line footer + d = regexp.MustCompile(`\^$`).ReplaceAllString(d, "") + + if govalidator.IsDNSName(d) { + domains = append(domains, d) + } else { + log.Debug("host invalid", "host", d) + } + } + + return domains +} diff --git a/cmd/bind/parsing-complex.go b/cmd/bind/parsing-complex.go index 6306edb..cb7f9a4 100644 --- a/cmd/bind/parsing-complex.go +++ b/cmd/bind/parsing-complex.go @@ -7,6 +7,8 @@ import ( "strings" "github.com/asaskevich/govalidator" + + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" ) func parseComplex(data []byte) []string { @@ -38,7 +40,7 @@ func parseComplex(data []byte) []string { if govalidator.IsDNSName(lineItems[1]) { domains = append(domains, lineItems[1]) } else { - cfg.Log.Debug("host invalid", "host", lineItems[0]) + log.Debug("host invalid", "host", lineItems[0]) } } } diff --git a/cmd/bind/parsing-simple.go b/cmd/bind/parsing-simple.go index 71b775e..037bea9 100644 --- a/cmd/bind/parsing-simple.go +++ b/cmd/bind/parsing-simple.go @@ -7,6 +7,8 @@ import ( "strings" "github.com/asaskevich/govalidator" + + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" ) func parseSimple(data []byte) []string { @@ -38,7 +40,7 @@ func parseSimple(data []byte) []string { if govalidator.IsDNSName(lineItems[0]) { domains = append(domains, lineItems[0]) } else { - cfg.Log.Debug("host invalid", "host", lineItems[0]) + log.Debug("host invalid", "host", lineItems[0]) } } } diff --git a/go.mod b/go.mod index 02dceed..447e694 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,17 @@ -module pihole-blocklist/bind +module gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator -go 1.18 +go 1.21.0 + +toolchain go1.24.2 require ( github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 - gopkg.in/yaml.v3 v3.0.1 + github.com/goccy/go-yaml v1.17.1 + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0d53b0c..469beab 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,13 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= +github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/config/file-operations.go b/internal/common/file-operations.go similarity index 98% rename from internal/config/file-operations.go rename to internal/common/file-operations.go index 3b5a308..43abdf0 100644 --- a/internal/config/file-operations.go +++ b/internal/common/file-operations.go @@ -1,4 +1,4 @@ -package config +package common import ( "fmt" diff --git a/internal/config/envconfig.go b/internal/config/envconfig.go index 73c1786..d5414e6 100644 --- a/internal/config/envconfig.go +++ b/internal/config/envconfig.go @@ -18,7 +18,7 @@ type structInfo struct { Tags reflect.StructTag Type reflect.Type DefaultValue interface{} - Secret interface{} + Secret bool } func getEnv[t string | bool | int | int64 | float64](env string, def t) (t, error) { @@ -77,6 +77,10 @@ func getStructInfo(spec interface{}) ([]structInfo, error) { } typeOfSpec := s.Type() + return parseStructInfo(s, typeOfSpec) +} + +func parseStructInfo(s reflect.Value, typeOfSpec reflect.Type) ([]structInfo, error) { infos := make([]structInfo, 0, s.NumField()) for i := 0; i < s.NumField(); i++ { f := s.Field(i) @@ -87,17 +91,7 @@ func getStructInfo(spec interface{}) ([]structInfo, error) { continue } - for f.Kind() == reflect.Pointer { - if f.IsNil() { - if f.Type().Elem().Kind() != reflect.Struct { - break - } - f.Set(reflect.New(f.Type().Elem())) - } - f = f.Elem() - } - - secret, err := typeConversion(ftype.Type.String(), ftype.Tag.Get("secret")) + secret, err := strconv.ParseBool(ftype.Tag.Get("secret")) if err != nil { secret = false } @@ -110,31 +104,41 @@ func getStructInfo(spec interface{}) ([]structInfo, error) { } info := structInfo{ - Name: ftype.Name, - Alt: strings.ToUpper(ftype.Tag.Get("env")), - Info: desc, - Key: ftype.Name, - Field: f, - Tags: ftype.Tag, - Type: ftype.Type, - Secret: secret, - } - if info.Alt != "" { - info.Key = info.Alt - } - info.Key = strings.ToUpper(info.Key) - if ftype.Tag.Get("default") != "" { - v, err := typeConversion(ftype.Type.String(), ftype.Tag.Get("default")) - if err != nil { - return []structInfo{}, err - } - info.DefaultValue = v + Alt: strings.ToUpper(ftype.Tag.Get("env")), + DefaultValue: getDefault(ftype), + Field: f, + Info: desc, + Key: getAlt(ftype), + Name: ftype.Name, + Secret: secret, + Tags: ftype.Tag, + Type: ftype.Type, } + infos = append(infos, info) } + return infos, nil } +func getAlt(ftype reflect.StructField) string { + if len(ftype.Tag.Get("env")) > 0 { + return strings.ToUpper(ftype.Tag.Get("env")) + } + return strings.ToUpper(ftype.Name) +} + +func getDefault(ftype reflect.StructField) interface{} { + if ftype.Tag.Get("default") != "" { + v, err := typeConversion(ftype.Type.String(), ftype.Tag.Get("default")) + if err != nil { + return nil + } + return v + } + return nil +} + func typeConversion(t, v string) (interface{}, error) { switch t { case "string": //nolint:goconst diff --git a/internal/config/envconfig_test.go b/internal/config/envconfig_test.go new file mode 100644 index 0000000..917b65b --- /dev/null +++ b/internal/config/envconfig_test.go @@ -0,0 +1,191 @@ +package config + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mock_config struct { + NoTags string + Ignored string `ignored:"true"` + Info string `info:"This is an info string."` + Secret string `secret:"true"` + Env string `env:"test_env"` + Default_string string `default:"This is a default string."` + Default_bool bool `default:"true"` + Default_int int `default:"100"` + Default_int64 int64 `default:"100"` + Default_float64 float64 `default:"100.001"` +} + +func TestGetEnv(t *testing.T) { + var ( + expected_string string = "This is a default string." + expected_bool bool = true + expected_int int = 100 + expected_int64 int64 = 100 + expected_float64 float64 = 100.001 + expected_unset_default string = "This is a default value." + ) + + // string + t.Setenv("TEST_STRING", expected_string) + test_string, err := getEnv("TEST_STRING", "This is a default string.") + assert.NoError(t, err) + assert.Equal(t, expected_string, test_string) + + // bool + _, err = getEnv("TEST_STRING", expected_bool) + assert.Error(t, err) + t.Setenv("TEST_BOOL", strconv.FormatBool(expected_bool)) + test_bool, err := getEnv("TEST_BOOL", expected_bool) + assert.NoError(t, err) + assert.Equal(t, expected_bool, test_bool) + + // int + _, err = getEnv("TEST_STRING", expected_int) + assert.Error(t, err) + t.Setenv("TEST_INT", strconv.FormatInt(int64(expected_int), 10)) + test_int, err := getEnv("TEST_INT", expected_int) + assert.NoError(t, err) + assert.Equal(t, expected_int, test_int) + + // int64 + _, err = getEnv("TEST_STRING", expected_int64) + assert.Error(t, err) + t.Setenv("TEST_INT64", strconv.FormatInt(expected_int64, 10)) + test_int64, err := getEnv("TEST_INT", expected_int64) + assert.NoError(t, err) + assert.Equal(t, expected_int64, test_int64) + + // float64 + _, err = getEnv("TEST_STRING", expected_float64) + assert.Error(t, err) + t.Setenv("TEST_INT", strconv.FormatFloat(expected_float64, 'f', 3, 64)) + test_float64, err := getEnv("TEST_INT", expected_float64) + assert.NoError(t, err) + assert.Equal(t, expected_float64, test_float64) + + // unset or missing environment variable + test_unset, err := getEnv("TEST_DEFAULT", expected_unset_default) + assert.NoError(t, err) + assert.Equal(t, expected_unset_default, test_unset) +} + +func TestGetStructInfo(t *testing.T) { + test_config := mock_config{ + NoTags: "notags", + Ignored: "ignored", + Secret: "secret", + } + + cfgInfo, err := getStructInfo(&test_config) + assert.NoError(t, err) + + for _, v := range cfgInfo { + switch v.Name { + case "Info": + assert.Equal(t, "() This is an info string.", v.Info) + case "Secret": + assert.Equal(t, true, v.Secret) + case "Env": + assert.Equal(t, "TEST_ENV", v.Alt) + case "Default_value": + assert.Equal(t, "This is a default string.", v.DefaultValue) + } + } +} + +func TestTypeConversion(t *testing.T) { + var ( + expected_string string = "This is a default string." + expected_int int = 100 + expected_int8 int8 = 100 + expected_int16 int16 = 100 + expected_int32 int32 = 100 + expected_int64 int64 = 100 + expected_uint uint = 100 + expected_uint16 uint16 = 100 + expected_uint32 uint32 = 100 + expected_uint64 uint64 = 100 + expected_float32 float32 = 100.001 + expected_float64 float64 = 100.001 + expected_bool bool = true + ) + + // string + output_string, err := typeConversion("string", expected_string) + assert.NoError(t, err) + assert.Equal(t, expected_string, output_string) + + // int + output_int, err := typeConversion("int", strconv.FormatInt(int64(expected_int), 10)) + assert.NoError(t, err) + assert.Equal(t, expected_int, int(output_int.(int64))) + + // int8 + output_int8, err := typeConversion("int8", strconv.FormatInt(int64(expected_int8), 10)) + assert.NoError(t, err) + assert.Equal(t, expected_int8, int8(output_int8.(int64))) // nolint: gosec + + // int16 + output_int16, err := typeConversion("int16", strconv.FormatInt(int64(expected_int16), 10)) + assert.NoError(t, err) + assert.Equal(t, expected_int16, int16(output_int16.(int64))) // nolint: gosec + + // int32 + output_int32, err := typeConversion("int32", strconv.FormatInt(int64(expected_int32), 10)) + assert.NoError(t, err) + assert.Equal(t, expected_int32, int32(output_int32.(int64))) // nolint: gosec + + // int64 + output_int64, err := typeConversion("int64", strconv.FormatInt(expected_int64, 10)) + assert.NoError(t, err) + assert.Equal(t, expected_int64, output_int64) + + // uint + output_uint, err := typeConversion("uint", strconv.FormatInt(int64(expected_uint), 10)) + assert.NoError(t, err) + assert.Equal(t, expected_uint, uint(output_uint.(uint64))) // nolint: gosec + + // uint16 + output_uint16, err := typeConversion("uint16", strconv.FormatInt(int64(expected_uint16), 10)) + assert.NoError(t, err) + assert.Equal(t, expected_uint16, uint16(output_uint16.(uint64))) // nolint: gosec + + // uint32 + output_uint32, err := typeConversion("uint32", strconv.FormatInt(int64(expected_uint32), 10)) + assert.NoError(t, err) + assert.Equal(t, expected_uint32, uint32(output_uint32.(uint64))) // nolint: gosec + + // uint64 + output_uint64, err := typeConversion("uint64", strconv.FormatInt(int64(expected_uint64), 10)) + assert.NoError(t, err) + assert.Equal(t, expected_uint64, output_uint64) + + // float32 + output_float32, err := typeConversion("float32", strconv.FormatFloat(float64(expected_float32), 'f', 3, 64)) + assert.NoError(t, err) + assert.Equal(t, expected_float32, float32(output_float32.(float64))) // nolint: gosec + + // float64 + output_float64, err := typeConversion("float64", strconv.FormatFloat(expected_float64, 'f', 3, 64)) + assert.NoError(t, err) + assert.Equal(t, expected_float64, output_float64) + + // bool + output_bool, err := typeConversion("bool", strconv.FormatBool(expected_bool)) + assert.NoError(t, err) + assert.Equal(t, expected_bool, output_bool) +} + +func TestParseFlags(t *testing.T) { + test_config := Config{} + + cfgInfo, err := getStructInfo(&test_config) + assert.NoError(t, err) + + assert.NoError(t, test_config.parseFlags(cfgInfo)) +} diff --git a/internal/config/initialize.go b/internal/config/initialize.go index df4a783..b32a845 100644 --- a/internal/config/initialize.go +++ b/internal/config/initialize.go @@ -2,63 +2,62 @@ package config import ( "fmt" - "os" - "pihole-blocklist/bind/assets" "time" - "gopkg.in/yaml.v3" + "github.com/goccy/go-yaml" + + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/assets" + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/common" + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" ) -func Init() Config { +func Init() (Config, error) { cfg := New() + // parse config structure cfgInfo, err := getStructInfo(&cfg) if err != nil { - panic(fmt.Sprintf("Unable to initialize program: %v", err)) + return Config{}, err } // get command line flags if err := cfg.parseFlags(cfgInfo); err != nil { - panic(fmt.Sprintf("Unable to initialize program: %v", err)) + return Config{}, err } // set logging Level - setLogLevel(&cfg) + log.Init("text") + log.SetNumericLevel(cfg.LogLevel) // set timezone & time format cfg.TZUTC, _ = time.LoadLocation("UTC") cfg.TZLocal, err = time.LoadLocation(cfg.TimeZoneLocal) if err != nil { - cfg.Log.Error("Unable to parse timezone string", "error", err) - os.Exit(1) + return Config{}, err } // check config file - if !FileExists(cfg.ConfigFileLocation) { - if _, err := WriteFile(cfg.ConfigFileLocation, assets.Config); err != nil { - cfg.Log.Error(err.Error()) - os.Exit(1) + if !common.FileExists(cfg.ConfigFileLocation) { + if _, err := common.WriteFile(cfg.ConfigFileLocation, assets.Config); err != nil { + return Config{}, err } - cfg.Log.Error("Unable to locate configuration file, an example config file has been written", "path", cfg.ConfigFileLocation) - os.Exit(1) + return Config{}, fmt.Errorf("Unable to locate configuration file, an example config file has been written to %s", cfg.ConfigFileLocation) } // read config - cfData, err := ReadFile(cfg.ConfigFileLocation) + cfData, err := common.ReadFile(cfg.ConfigFileLocation) if err != nil { - cfg.Log.Error("Unable to read config file", "error", err) - os.Exit(1) + return Config{}, err } // unmarshal config file if err := yaml.Unmarshal(cfData, &cfg.ConfigFile); err != nil { - cfg.Log.Error("Unable to read config file contents", "error", err) - os.Exit(1) + return Config{}, err } // print running config printRunningConfig(&cfg, cfgInfo) // return configuration - return cfg + return cfg, nil } diff --git a/internal/config/struct-config.go b/internal/config/struct-config.go index d5a358c..fc65f61 100644 --- a/internal/config/struct-config.go +++ b/internal/config/struct-config.go @@ -1,13 +1,19 @@ package config import ( - "log/slog" - "os" "reflect" "strconv" "time" + + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" ) +// Config uses struct tags to configure the application. +// (default) Default value to be used if unset or not defined. +// (ignored) Don't process the current tag. +// (info) String to be presented to the user on -help use. +// (secret) If set to true, hide the value from being output on start-up. +// (env) environment variable to be used if not set on command line. type Config struct { // time configuration TimeFormat string `default:"2006-01-02 15:04:05" env:"time_format"` @@ -16,21 +22,20 @@ type Config struct { TZUTC *time.Location `ignored:"true"` // logging - LogLevel int `default:"50" env:"log_level"` - Log *slog.Logger `ignored:"true"` - SLogLevel *slog.LevelVar `ignored:"true"` + LogLevel int `default:"50" env:"log_level"` - // HTTP Client timeout configurations - HTTPClientRequestTimeout int `default:"60" env:"HTTP_CLIENT_REQUEST_TIMEOUT"` - HTTPClientConnectTimeout int `default:"5" env:"HTTP_CLIENT_CONNECT_TIMEOUT"` - HTTPClientTLSHandshakeTimeout int `default:"5" env:"HTTP_CLIENT_TLS_TIMEOUT"` - HTTPClientIdleTimeout int `default:"5" env:"HTTP_CLIENT_IDLE_TIMEOUT"` + // webserver + WebServerPort int `default:"8080" env:"webserver_port"` + WebServerIP string `default:"0.0.0.0" env:"webserver_ip"` + WebServerReadTimeout int `default:"5" env:"webserver_read_timeout"` + WebServerWriteTimeout int `default:"1" env:"webserver_write_timeout"` + WebServerIdleTimeout int `default:"2" env:"webserver_idle_timeout"` // Output Filename - BindOutputFileName string `default:"./response-policy.bind" env:"OUTPUT"` + BindOutputFileName string `default:"./response-policy.bind" env:"output"` // Config - ConfigFileLocation string `default:"./config.yaml" env:"CONFIG_FILE"` + ConfigFileLocation string `default:"./config.yaml" env:"config_file"` ConfigFile configFileStruct } @@ -48,6 +53,7 @@ type configFileStruct struct { TTL string `yaml:"timeToLive"` } `yaml:"zoneConfig"` Sources struct { + AdBlockURLs []string `yaml:"adBlockURLs"` DomainListURLs []string `yaml:"domainListURLs"` HostFileURLs []string `yaml:"hostFileURLs"` } `yaml:"sources"` @@ -57,52 +63,27 @@ type configFileStruct struct { // New initializes the config variable for use with a prepared set of defaults. func New() Config { - cfg := Config{ - SLogLevel: new(slog.LevelVar), - } - - cfg.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: cfg.SLogLevel, - })) - - return cfg -} - -func setLogLevel(cfg *Config) { - switch { - // error - case cfg.LogLevel <= 20: - cfg.SLogLevel.Set(slog.LevelError) - cfg.Log.Info("Log level updated", "level", slog.LevelError) - // warning - case cfg.LogLevel > 20 && cfg.LogLevel <= 40: - cfg.SLogLevel.Set(slog.LevelWarn) - cfg.Log.Info("Log level updated", "level", slog.LevelWarn) - // info - case cfg.LogLevel > 40 && cfg.LogLevel <= 60: - cfg.SLogLevel.Set(slog.LevelInfo) - cfg.Log.Info("Log level updated", "level", slog.LevelInfo) - // debug - case cfg.LogLevel > 60: - cfg.SLogLevel.Set(slog.LevelDebug) - cfg.Log.Info("Log level updated", "level", slog.LevelDebug) - } - // set default logger - slog.SetDefault(cfg.Log) + return Config{} } func printRunningConfig(cfg *Config, cfgInfo []structInfo) { + var logRunningConfiguration string = "Running Configuration" + for _, info := range cfgInfo { - switch info.Type.String() { - case "string": - p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*string) - cfg.Log.Debug("Running Configuration", info.Alt, *p) - case "bool": - p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*bool) - cfg.Log.Debug("Running Configuration", info.Alt, strconv.FormatBool(*p)) - case "int": - p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int) - cfg.Log.Debug("Running Configuration", info.Alt, strconv.FormatInt(int64(*p), 10)) + if info.Secret { + log.Debug(logRunningConfiguration, info.Name, "REDACTED") + } else { + switch info.Type.String() { + case "string": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*string) + log.Debug(logRunningConfiguration, info.Alt, *p) + case "bool": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*bool) + log.Debug(logRunningConfiguration, info.Alt, strconv.FormatBool(*p)) + case "int": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int) + log.Debug(logRunningConfiguration, info.Alt, strconv.FormatInt(int64(*p), 10)) + } } } } diff --git a/internal/config/struct-config_test.go b/internal/config/struct-config_test.go new file mode 100644 index 0000000..dd2e514 --- /dev/null +++ b/internal/config/struct-config_test.go @@ -0,0 +1,40 @@ +package config + +import ( + "bytes" + "log/slog" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "gitlab.smoothnet.org/nhyatt/bind-response-policy-zone-creator/internal/log" +) + +func slogToBuffer() (*bytes.Buffer, *slog.Logger) { + buf := new(bytes.Buffer) + return buf, slog.New( + slog.NewTextHandler( + buf, + &slog.HandlerOptions{ + Level: log.LevelTrace, + }, + ), + ) +} + +func TestPrintRunningConfig(t *testing.T) { + buf, l := slogToBuffer() + log.L.Log = l + + c := New() + cfgInfo, err := getStructInfo(&c) + assert.NoError(t, err) + printRunningConfig(&c, cfgInfo) + + assert.Contains(t, buf.String(), "Running Configuration") +} + +func TestNew(t *testing.T) { + c := New() + assert.Equal(t, "config.Config", reflect.TypeOf(c).String()) +} diff --git a/internal/log/logging.go b/internal/log/logging.go new file mode 100644 index 0000000..023092d --- /dev/null +++ b/internal/log/logging.go @@ -0,0 +1,154 @@ +package log + +import ( + "context" + "log/slog" + "os" +) + +const ( + LevelTrace = slog.Level(-8) + LevelFatal = slog.Level(12) +) + +type Log struct { + Ctx context.Context + Log *slog.Logger + SLogLevel slog.LevelVar +} + +var ( + // LevelNames set the names associated with custom logging levels. + LevelNames = map[slog.Leveler]string{ + LevelTrace: "TRACE", + LevelFatal: "FATAL", + } + // L is the global interface used for calling the logger subfunctions. + L = Log{} +) + +func Init(writer string) { + slogOptions := &slog.HandlerOptions{ + Level: &L.SLogLevel, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey { + a.Value = slog.StringValue(a.Value.Time().Format("2006-01-02T15:04:05.000-0700")) + a.Key = "ts" + } + if a.Key == slog.LevelKey { + level := a.Value.Any().(slog.Level) + levelLabel, exists := LevelNames[level] + if !exists { + levelLabel = level.String() + } + + a.Value = slog.StringValue(levelLabel) + } + + return a + }, + } + + // Initialize SLog and translate new logging levels + switch writer { + case "json": + L.Log = slog.New(slog.NewJSONHandler(os.Stdout, slogOptions)) + default: + L.Log = slog.New(slog.NewTextHandler(os.Stdout, slogOptions)) + } + // create context + L.Ctx = context.Background() +} + +// SetNumericLevel will set the log level based on a number from 1-100. +// The larger the number the more verbose the logs. +// +// 1-20 = Fatal, 21-40 = Error, 41-60 = Warn, 61-80 = Info, 81-99 = Debug, +// and 100 = Trace. +func SetNumericLevel(level int) { + var llu string = "Log Level Updated" + + switch { + // fatal + case level <= 20: + L.SLogLevel.Set(LevelFatal) + Info(llu, "level", LevelFatal) + // error + case level > 20 && level <= 40: + L.SLogLevel.Set(slog.LevelError) + Info(llu, "level", slog.LevelError) + // warning + case level > 40 && level <= 60: + L.SLogLevel.Set(slog.LevelWarn) + Info(llu, "level", slog.LevelWarn) + // info + case level > 60 && level <= 80: + L.SLogLevel.Set(slog.LevelInfo) + Info(llu, "level", slog.LevelInfo) + // debug + case level > 80 && level <= 99: + L.SLogLevel.Set(slog.LevelDebug) + Info(llu, "level", slog.LevelDebug) + // trace + case level > 99: + L.SLogLevel.Set(LevelTrace) + Info(llu, "level", LevelTrace) + } + + // set default logger + slog.SetDefault(L.Log) +} + +func Fatal(msg string, attrs ...interface{}) { + L.Log.Log( + L.Ctx, + LevelFatal, + msg, + attrs..., + ) +} + +func Error(msg string, attrs ...interface{}) { + L.Log.Log( + L.Ctx, + slog.LevelError, + msg, + attrs..., + ) +} + +func Warn(msg string, attrs ...interface{}) { + L.Log.Log( + L.Ctx, + slog.LevelWarn, + msg, + attrs..., + ) +} + +func Info(msg string, attrs ...interface{}) { + L.Log.Log( + L.Ctx, + slog.LevelInfo, + msg, + attrs..., + ) +} + +func Debug(msg string, attrs ...interface{}) { + L.Log.Log( + L.Ctx, + slog.LevelDebug, + msg, + attrs..., + ) +} + +func Trace(msg string, attrs ...interface{}) { + L.Log.Log( + L.Ctx, + LevelTrace, + msg, + attrs..., + ) +} diff --git a/internal/log/logging_test.go b/internal/log/logging_test.go new file mode 100644 index 0000000..ca6c10f --- /dev/null +++ b/internal/log/logging_test.go @@ -0,0 +1,97 @@ +package log +import ( + "bytes" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" +) + +func slogToBuffer() (*bytes.Buffer, *slog.Logger) { + buf := new(bytes.Buffer) + return buf, slog.New( + slog.NewTextHandler( + buf, + &slog.HandlerOptions{ + Level: LevelTrace, + }, + ), + ) +} + +func TestSetLogLevel(t *testing.T) { + Init("text") + + for _, i := range []int{0, 21, 41, 61, 81, 101} { + SetNumericLevel(i) + + switch i { + case 0: + assert.Equal(t, LevelFatal, L.SLogLevel.Level()) + case 21: + assert.Equal(t, slog.LevelError, L.SLogLevel.Level()) + case 41: + assert.Equal(t, slog.LevelWarn, L.SLogLevel.Level()) + case 61: + assert.Equal(t, slog.LevelInfo, L.SLogLevel.Level()) + case 81: + assert.Equal(t, slog.LevelDebug, L.SLogLevel.Level()) + case 101: + assert.Equal(t, LevelTrace, L.SLogLevel.Level()) + } + } +} + +func TestFatal(t *testing.T) { + buf, log := slogToBuffer() + L.Log = log + + Fatal("TEST Message") + assert.Contains(t, buf.String(), "TEST Message") + assert.Contains(t, buf.String(), "level=ERROR+4") +} + +func TestError(t *testing.T) { + buf, log := slogToBuffer() + L.Log = log + + Error("TEST Message") + assert.Contains(t, buf.String(), "TEST Message") + assert.Contains(t, buf.String(), "level=ERROR") +} + +func TestWarn(t *testing.T) { + buf, log := slogToBuffer() + L.Log = log + + Warn("TEST Message") + assert.Contains(t, buf.String(), "TEST Message") + assert.Contains(t, buf.String(), "level=WARN") +} + +func TestInfo(t *testing.T) { + buf, log := slogToBuffer() + L.Log = log + + Info("TEST Message") + assert.Contains(t, buf.String(), "TEST Message") + assert.Contains(t, buf.String(), "level=INFO") +} + +func TestDebug(t *testing.T) { + buf, log := slogToBuffer() + L.Log = log + + Debug("TEST Message") + assert.Contains(t, buf.String(), "TEST Message") + assert.Contains(t, buf.String(), "level=DEBUG") +} + +func TestTrace(t *testing.T) { + buf, log := slogToBuffer() + L.Log = log + + Trace("TEST Message") + assert.Contains(t, buf.String(), "TEST Message") + assert.Contains(t, buf.String(), "level=DEBUG-4") +}