diff --git a/.gitignore b/.gitignore index 75ec3f0..e689df1 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,67 @@ -.vscode/* \ No newline at end of file +# Application created directories +output/ + +# Visual Studio Code +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launce.json +!.vscode/extensions.json +!.vscode/*.code-snippets +.history/ +*.vsix + +# GoLang +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +*.out +go.work + +# General +.DS_Store +.AppleDouble +.LSOverride +# Icon must end with two \r +Icon + + +# Thumbnails +._* +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db +# Dump file +*.stackdump +# Folder config file +[Dd]esktop.ini +# Recycle Bin used on file shares +$RECYCLE.BIN/ +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp +# Windows shortcuts +*.lnk diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..7a02377 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,58 @@ +linters: + disable-all: true + enable: + # default linters + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - unused + # project linters + - asasalint + - asciicheck + - bodyclose + - contextcheck + - dupl + - durationcheck + - errchkjson + - gocheckcompilerdirectives + - gocognit + - goconst + - gocritic + - godox + - goimports + - gosec + - grouper + - importas + - misspell + - musttag + - nestif + - nilerr + - nilnil + - prealloc + - reassign + - tagalign + - tenv + - unconvert + - unparam + - usestdlibvars + - wastedassign + - whitespace + fast: true +linter-settings: + tagalign: + order: + - json + - yaml + - yml + - toml + - mapstructure + - binding + - validate + - env + - default + - ignored + - required + - secret + - info diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..b32e3d1 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "golang.go" + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..7c70a28 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,22 @@ +{ + "go.useLanguageServer": true, + "go.vetOnSave": "package", + "go.lintOnSave": "package", + "go.formatTool": "goimports", + "go.lintTool": "golangci-lint", + "go.lintFlags": [ + "--fast" + ], + + "[go]": { + "editor.detectIndentation": false, + "editor.tabSize": 2, + "editor.insertSpaces": false, + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": true + } + }, + + "cSpell.words": [] +} diff --git a/app b/app deleted file mode 100755 index 25318e1..0000000 Binary files a/app and /dev/null differ diff --git a/config/config.yaml b/assets/config/config.yaml similarity index 90% rename from config/config.yaml rename to assets/config/config.yaml index 2252883..7ec14a4 100644 --- a/config/config.yaml +++ b/assets/config/config.yaml @@ -39,25 +39,32 @@ allowList: - ^ip6-allnodes$ - ^ip6-allrouters$ - ^ip6-allhosts$ -- (^|\.)thepiratebay\.org$ -- (^|\.)sendgrid\.net$ -- (^|\.)googleadservices\.com$ -- (^|\.)doubleclick\.net$ -- (^|\.)sailthru\.com$ -- (^|\.)magiskmanager\.com$ - (^|\.)apiservices\.krxd\.net$ -- (^|\.)logfiles\.zoom\.us$ -- (^|\.)logfiles-va\.zoom\.us$ -- (^|\.)nest\.com$ +- (^|\.)app-measurement\.com$ +- (^|\.)assets\.adobedtm\.com$ +- (^|\.)brandify\.com$ - (^|\.)clients.\.google\.com$ -- (^|\.)login\.live\.com$ -- (^|\.)unagi\.amazon\.com$ -- (^|\.)unagi-na\.amazon\.com$ +- (^|\.)doubleclick\.net$ - (^|\.)duckduckgo\.com$ +- (^|\.)ghostery\.net$ +- (^|\.)googleadservices\.com$ +- (^|\.)kochava\.com$ +- (^|\.)logfiles-va\.zoom\.us$ +- (^|\.)logfiles\.zoom\.us$ +- (^|\.)login\.live\.com$ +- (^|\.)magiskmanager\.com$ - (^|\.)msn\.com$ +- (^|\.)nest\.com$ - (^|\.)nexusrules\.officeapps\.live\.com$ - (^|\.)playfabapi\.com$ +- (^|\.)sailthru\.com$ +- (^|\.)sendgrid\.net$ +- (^|\.)tealiumiq\.com$ +- (^|\.)thepiratebay\.org$ +- (^|\.)unagi-na\.amazon\.com$ +- (^|\.)unagi\.amazon\.com$ - (^|\.)vercel-dns\.com$ +- ^\w+-\d{4}\.\w+-msedge\.net$ denyList: - jindlecleanings.xyz -- "*.jindlecleanings.xyz" \ No newline at end of file +- "*.jindlecleanings.xyz" diff --git a/assets/embed.go b/assets/embed.go new file mode 100644 index 0000000..e060e48 --- /dev/null +++ b/assets/embed.go @@ -0,0 +1,9 @@ +package assets + +import _ "embed" + +//go:embed config/config.yaml +var Config []byte + +//go:embed templates/bind-record.named +var BindRecord []byte diff --git a/assets/templates/bind-record.named b/assets/templates/bind-record.named new file mode 100644 index 0000000..fbd1529 --- /dev/null +++ b/assets/templates/bind-record.named @@ -0,0 +1,23 @@ +{{- $domain := .Domain -}} +$TTL {{ or .TTL "1h" }} +@ IN SOA {{ $domain }}. {{ or .Email "domain-admin" }}. ( + {{ or .Serial "0000000000" }} ; Serial + {{ or .Refresh "1h" }} ; Refresh + {{ or .Retry "30m" }} ; Retry + {{ or .Expire "1w" }} ; Expire + {{ or .Minimum "1h" }} ; Minimum +) + +; +; Name Servers +; +{{- range .NameServers }} + IN NS {{ . }}. +{{- end }} + +; +; Addresses +; +{{- range .BlockedDomains }} +{{ . }} IN CNAME blocked.{{ $domain }}. +{{- end }} diff --git a/cmd/bind/build-bind.go b/cmd/bind/build-bind.go index 47d82e6..66e35d9 100644 --- a/cmd/bind/build-bind.go +++ b/cmd/bind/build-bind.go @@ -2,64 +2,34 @@ package main import ( "bytes" - "log" - "os" - "time" "text/template" + + "pihole-blocklist/bind/assets" + "pihole-blocklist/bind/internal/config" ) -func buildBindResponsePolicyFile() { - defer timeTrack(time.Now(), "buildBindResponsePolicyFile") - +func buildBindResponsePolicyFile() error { var ( output bytes.Buffer ) - outputTemplate := `{{- $domain := .Domain -}} -$TTL {{ or .TTL "1h" }} -@ IN SOA {{ $domain }}. {{ or .Email "domain-admin" }}. ( - {{ or .Serial "0000000000" }} ; Serial - {{ or .Refresh "1h" }} ; Refresh - {{ or .Retry "30m" }} ; Retry - {{ or .Expire "1w" }} ; Expire - {{ or .Minimum "1h" }} ; Minimum -) + outputTemplate := assets.BindRecord -; -; Name Servers -; -{{- range .NameServers }} - IN NS {{ . }}. -{{- end }} - -; -; Addresses -; -{{- range .BlockedDomains }} -{{ . }} IN CNAME blocked.{{ $domain }}. -{{- end }} -` - - t, err := template.New("response-policy-zone").Parse(outputTemplate) + t, err := template.New("response-policy-zone").Parse(string(outputTemplate)) if err != nil { - log.Fatalf("[FATAL] Unable to parse template (%s): %v\n", "response-policy-zone", err) + return err } - if err := t.Execute(&output, config.Config.ZoneConfig); err != nil { - log.Fatalf("[FATAL] Unable to generate template output: %v\n", err) + if err := t.Execute(&output, cfg.ConfigFile.ZoneConfig); err != nil { + return err } - fileWriter, err := os.Create(config.BindOutputFileName) + bytesWritten, err := config.WriteFile(cfg.BindOutputFileName, output.Bytes()) if err != nil { - log.Fatalf("[FATAL] Unable to open file (%s) for writing: %v", config.BindOutputFileName, err) - } - defer fileWriter.Close() - - bytesWritten, err := fileWriter.Write(output.Bytes()) - if err != nil { - log.Fatalf("[FATAL] Unable to write to file (%s): %v", config.BindOutputFileName, err) + return err } - log.Printf("[DEBUG] Wrote %d bytes to %s.\n", bytesWritten, config.BindOutputFileName) + cfg.Log.Debug("file created", "file", cfg.BindOutputFileName, "bytes", bytesWritten) + return nil } diff --git a/cmd/bind/cleanup.go b/cmd/bind/cleanup.go index 7b56e8f..d572be2 100644 --- a/cmd/bind/cleanup.go +++ b/cmd/bind/cleanup.go @@ -1,15 +1,11 @@ package main import ( - "log" "regexp" "sort" - "time" ) func cleanBadDomains(domains []string) []string { - defer timeTrack(time.Now(), "cleanBadDomains") - // remove duplicates total := len(domains) all := make(map[string]bool) @@ -21,7 +17,7 @@ func cleanBadDomains(domains []string) []string { } } domains = list - log.Printf("[INFO] Duplicate items removed: %d\n", total-len(domains)) + cfg.Log.Info("hosts removed from blocklist", "reason", "duplicate", "hosts", total-len(domains)) // remove hosts that are too long total = len(domains) @@ -33,18 +29,18 @@ func cleanBadDomains(domains []string) []string { list = append(list, blocklistItem) } domains = list - log.Printf("[INFO] Hosts with too many characters removed: %d\n", total-len(domains)) + cfg.Log.Info("hosts removed from blocklist", "reason", "too many characters", "hosts", total-len(domains)) // remove allow-listed matches total = len(domains) // filter out bad regex goodAllowedItemList := []string{} - for _, allowedItem := range config.Config.AllowLists { + for _, allowedItem := range cfg.ConfigFile.AllowLists { _, err := regexp.Compile(allowedItem) if err != nil { - log.Printf("[ERROR] Allow list item (%s) is not valid regex: %v\n", allowedItem, err) - break + cfg.Log.Error("unable to parse allow list item", "error", err, "regex", allowedItem) + continue } goodAllowedItemList = append(goodAllowedItemList, allowedItem) } @@ -54,7 +50,7 @@ func cleanBadDomains(domains []string) []string { addEntry := true for _, allowedItem := range goodAllowedItemList { if regexp.MustCompile(allowedItem).MatchString(v) { - log.Printf("[DEBUG] Removing allowed matching item: %s\n", v) + cfg.Log.Debug("hosts removed from blocklist", "reason", "allowed host", "match string", allowedItem, "host", v) addEntry = false } } @@ -63,9 +59,9 @@ func cleanBadDomains(domains []string) []string { } } domains = list - log.Printf("[INFO] Allowed hosts removed: %d\n", total-len(domains)) + cfg.Log.Info("hosts removed from blocklist", "hosts", total-len(domains)) - log.Printf("[INFO] Total domains in list at end: %d.\n", len(domains)) + cfg.Log.Info("total domains in list", "hosts", len(domains)) sort.Strings(domains) return domains } diff --git a/cmd/bind/config.go b/cmd/bind/config.go deleted file mode 100644 index a9f7935..0000000 --- a/cmd/bind/config.go +++ /dev/null @@ -1,103 +0,0 @@ -package main - -import ( - "os" - "time" - - "github.com/hashicorp/logutils" -) - -type configStructure struct { - // time configuration - TimeFormat string - TimeZone *time.Location - TimeZoneUTC *time.Location - - // logging - Log *logutils.LevelFilter - - // HTTP Client timeout configurations - HTTPClientRequestTimeout int - HTTPClientConnectTimeout int - HTTPClientTLSHandshakeTimeout int - HTTPClientIdleTimeout int - - // Output Filename - BindOutputFileName string - - // Config - ConfigFileLocation string - Config configFileStruct -} - -type configFileStruct struct { - ZoneConfig struct { - TTL string `yaml:"timeToLive"` - Domain string `yaml:"baseDomain"` - Email string `yaml:"emailAddress"` - Serial string `yaml:"zoneSerialNumber"` - Refresh string `yaml:"zoneRefresh"` - Retry string `yaml:"zoneRetry"` - Expire string `yaml:"zoneExpire"` - Minimum string `yaml:"zoneMinimum"` - NameServers []string `yaml:"nameServers"` - BlockedDomains []string `yaml:"blockedDomains"` - } `yaml:"zoneConfig"` - Sources struct { - HostFileURLs []string `yaml:"hostFileURLs"` - DomainListURLs []string `yaml:"domainListURLs"` - } `yaml:"sources"` - AllowLists []string `yaml:"allowList"` - DenyList []string `yaml:"denyList"` -} - -var config = configStructure{ - TimeFormat: "2006-01-02 15:04:05", - Log: &logutils.LevelFilter{ - Levels: []logutils.LogLevel{"TRACE", "DEBUG", "INFO", "WARNING", "ERROR"}, - Writer: os.Stderr, - }, - - // Nice blocklist location: https://firebog.net/ - // Default Blocklist - Config: configFileStruct{ - Sources: struct { - HostFileURLs []string `yaml:"hostFileURLs"` - DomainListURLs []string `yaml:"domainListURLs"` - }{ - HostFileURLs: []string{ - //"https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts", - //"http://sysctl.org/cameleon/hosts", - //"https://raw.githubusercontent.com/DandelionSprout/adfilt/master/Alternate%20versions%20Anti-Malware%20List/AntiMalwareHosts.txt", - //"https://raw.githubusercontent.com/FadeMind/hosts.extras/master/add.Risk/hosts", - }, - DomainListURLs: []string{ - //"https://s3.amazonaws.com/lists.disconnect.me/simple_tracking.txt", - //"https://s3.amazonaws.com/lists.disconnect.me/simple_malvertising.txt", - //"https://s3.amazonaws.com/lists.disconnect.me/simple_ad.txt", - //"https://v.firebog.net/hosts/Prigent-Crypto.txt", - //"https://phishing.army/download/phishing_army_blocklist_extended.txt", - //"https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-malware.txt", - //"https://raw.githubusercontent.com/Spam404/lists/master/main-blacklist.txt", - //"https://dbl.oisd.nl/", - //"https://osint.digitalside.it/Threat-Intel/lists/latestdomains.txt", - }, - }, - AllowLists: []string{ - // localhosts included in blocklists for some reason - `localhost`, - `localhost.localdomain`, - `local`, - `broadcasthost`, - `localhost`, - `ip6-localhost`, - `ip6-loopback`, - `localhost`, - `ip6-localnet`, - `ip6-mcastprefix`, - `ip6-allnodes`, - `ip6-allrouters`, - `ip6-allhosts`, - }, - }, -} diff --git a/cmd/bind/get-remote-data.go b/cmd/bind/get-remote-data.go index c76d9a0..997c53c 100644 --- a/cmd/bind/get-remote-data.go +++ b/cmd/bind/get-remote-data.go @@ -1,29 +1,27 @@ package main import ( - "log" - "pihole-blocklist/v2/internal/httpclient" "time" + + "pihole-blocklist/bind/internal/httpclient" ) func getListData() []string { - defer timeTrack(time.Now(), "getListData") - var badDomains []string listSimple := make(chan []string) listComplex := make(chan []string) - log.Printf("[INFO] Downloading blocklists\n") + cfg.Log.Info("downloading blocklists") // Get Simple Blocklists go func() { - data := getData(config.Config.Sources.DomainListURLs) + data := getData(cfg.ConfigFile.Sources.DomainListURLs) domains := parseSimple(data) listSimple <- domains }() // Get Host File Blocklists go func() { - data := getData(config.Config.Sources.HostFileURLs) + data := getData(cfg.ConfigFile.Sources.HostFileURLs) domains := parseComplex(data) listComplex <- domains }() @@ -38,9 +36,9 @@ func getListData() []string { select { case simple = <-listSimple: simpleFinished = true - log.Printf("[INFO] All simple lists have been retrieved.\n") + cfg.Log.Info("all simple lists downloaded") case complex = <-listComplex: - log.Printf("[INFO] All complex lists have been retrieved.\n") + cfg.Log.Info("all complex lists downloaded") complexFinished = true default: time.Sleep(time.Millisecond * 100) @@ -49,28 +47,26 @@ func getListData() []string { if simpleFinished && complexFinished { badDomains = append(badDomains, simple...) badDomains = append(badDomains, complex...) - log.Printf("[INFO] Number of domains detected: %d\n", len(badDomains)) + cfg.Log.Info("domains retrieved", "hosts", len(badDomains)) break } } // append deny list items to list of blocked domains - badDomains = append(badDomains, config.Config.DenyList...) + badDomains = append(badDomains, cfg.ConfigFile.DenyList...) return badDomains } func getData(urls []string) []byte { - defer timeTrack(time.Now(), "getData") - - var listData []byte + listData := make([]byte, 0, len(urls)+1) for _, u := range urls { - log.Printf("[TRACE] Downloading URL: %s\n", u) + cfg.Log.Debug("downloading", "url", u) c := httpclient.DefaultClient() data, err := c.Get(u) if err != nil { - log.Printf("[ERROR] Unable to get remote content from URL (%s): %v", u, err) + cfg.Log.Error("unable to get remote content", "error", err, "url", err) } listData = append(listData, data...) // add newline to the end of data, you know, for funzies diff --git a/cmd/bind/init.go b/cmd/bind/init.go deleted file mode 100644 index fb49997..0000000 --- a/cmd/bind/init.go +++ /dev/null @@ -1,212 +0,0 @@ -package main - -import ( - "flag" - "io/ioutil" - "log" - "os" - "strconv" - "strings" - "time" - - "github.com/hashicorp/logutils" - "gopkg.in/yaml.v3" -) - -// getEnvString returns string from environment variable -func getEnvString(env, def string) (val string) { //nolint:deadcode - defer timeTrack(time.Now(), "getEnvString") - - val = os.Getenv(env) - - if val == "" { - return def - } - - return -} - -// getEnvInt returns int from environment variable -func getEnvInt(env string, def int) (ret int) { - defer timeTrack(time.Now(), "getEnvInt") - - val := os.Getenv(env) - - if val == "" { - return def - } - - ret, err := strconv.Atoi(val) - if err != nil { - log.Fatalf("[ERROR] Environment variable is not numeric: %v\n", env) - } - - return -} - -func initialize() { - defer timeTrack(time.Now(), "initialize") - - config.TimeZone, _ = time.LoadLocation("America/Chicago") - config.TimeZoneUTC, _ = time.LoadLocation("UTC") - - // read command line options - var ( - logLevel int - ns1, ns2 string - ) - - // log configuration - flag.IntVar(&logLevel, - "log", - getEnvInt("LOG_LEVEL", 50), - "(LOG_LEVEL)\nlog level") - // http client configuration - flag.IntVar(&config.HTTPClientRequestTimeout, - "client-req-to", - getEnvInt("HTTP_CLIENT_REQUEST_TIMEOUT", 60), - "(HTTP_CLIENT_REQUEST_TIMEOUT)\ntime in seconds for the internal http client to complete a request") - flag.IntVar(&config.HTTPClientConnectTimeout, - "client-con-to", - getEnvInt("HTTP_CLIENT_CONNECT_TIMEOUT", 5), - "(HTTP_CLIENT_CONNECT_TIMEOUT)\ntime in seconds for the internal http client connection timeout") - flag.IntVar(&config.HTTPClientTLSHandshakeTimeout, - "client-tls-to", - getEnvInt("HTTP_CLIENT_TLS_TIMEOUT", 5), - "(HTTP_CLIENT_TLS_TIMEOUT)\ntime in seconds for the internal http client to complete a tls handshake") - flag.IntVar(&config.HTTPClientIdleTimeout, - "client-idle-to", - getEnvInt("HTTP_CLIENT_IDLE_TIMEOUT", 5), - "(HTTP_CLIENT_IDLE_TIMEOUT)\ntime in seconds that the internal http client will keep a connection open when idle") - // Bind Config - flag.StringVar(&config.Config.ZoneConfig.TTL, - "bind-ttl", - getEnvString("TTL", "1h"), - "(TTL)\nBind zone time to live") - flag.StringVar(&config.Config.ZoneConfig.Domain, - "bind-domain", - getEnvString("DOMAIN", "example.com"), - "(DOMAIN)\nBind zone base domain") - flag.StringVar(&config.Config.ZoneConfig.Email, - "bind-email", - getEnvString("EMAIL", "domain-admin@example.com"), - "(EMAIL)\nBind zone authority e-mail address") - flag.StringVar(&config.Config.ZoneConfig.Serial, - "bind-timestamp", - getEnvString("TIMESTAMP", time.Now().In(config.TimeZone).Format("0601021504")), - "(TIMESTAMP)\nBind zone serial number") - flag.StringVar(&config.Config.ZoneConfig.Refresh, - "bind-refresh", - getEnvString("REFRESH", "1h"), - "(REFRESH)\nBind zone refresh time") - flag.StringVar(&config.Config.ZoneConfig.Retry, - "bind-retry", - getEnvString("RETRY", "30m"), - "(RETRY)\nBind zone retry time") - flag.StringVar(&config.Config.ZoneConfig.Expire, - "bind-expire", - getEnvString("EXPIRE", "1w"), - "(EXPIRE)\nBind zone expire time") - flag.StringVar(&config.Config.ZoneConfig.Minimum, - "bind-minimum", - getEnvString("MINIMUM", "1h"), - "(MINIMUM)\nBind zone minimum time") - flag.StringVar(&ns1, - "bind-ns1", - getEnvString("NS1", ""), - "(NS1)\nBind zone primary name-server") - flag.StringVar(&ns2, - "bind-ns2", - getEnvString("NS2", ""), - "(NS2)\nBind zone secondary name-server") - // output file - flag.StringVar(&config.BindOutputFileName, - "output", - getEnvString("OUTPUT", "./response-policy.bind"), - "(FILENAME)\nWrite local file to filename") - flag.StringVar(&config.ConfigFileLocation, - "config-file", - getEnvString("CONFIG_FILE", ""), - "(CONFIG_FILE)\nRead configuration from file") - flag.Parse() - - // set logging level - switch { - case logLevel <= 20: - config.Log.SetMinLevel(logutils.LogLevel("ERROR")) - case logLevel > 20 && logLevel <= 40: - config.Log.SetMinLevel(logutils.LogLevel("WARNING")) - case logLevel > 40 && logLevel <= 60: - config.Log.SetMinLevel(logutils.LogLevel("INFO")) - case logLevel > 60 && logLevel <= 80: - config.Log.SetMinLevel(logutils.LogLevel("DEBUG")) - case logLevel > 80: - config.Log.SetMinLevel(logutils.LogLevel("TRACE")) - } - log.SetOutput(config.Log) - - // print current configuration - log.Printf("[DEBUG] configuration value set: LOG_LEVEL = %v\n", strconv.Itoa(logLevel)) - log.Printf("[DEBUG] configuration value set: HTTP_CLIENT_REQUEST_TIMEOUT = %v\n", strconv.Itoa(config.HTTPClientRequestTimeout)) - log.Printf("[DEBUG] configuration value set: HTTP_CLIENT_CONNECT_TIMEOUT = %v\n", strconv.Itoa(config.HTTPClientConnectTimeout)) - log.Printf("[DEBUG] configuration value set: HTTP_CLIENT_TLS_TIMEOUT = %v\n", strconv.Itoa(config.HTTPClientTLSHandshakeTimeout)) - log.Printf("[DEBUG] configuration value set: HTTP_CLIENT_IDLE_TIMEOUT = %v\n", strconv.Itoa(config.HTTPClientIdleTimeout)) - log.Printf("[DEBUG] configuration value set: TTL = %v\n", config.Config.ZoneConfig.TTL) - log.Printf("[DEBUG] configuration value set: DOMAIN = %v\n", config.Config.ZoneConfig.Domain) - log.Printf("[DEBUG] configuration value set: EMAIL = %v\n", config.Config.ZoneConfig.Email) - log.Printf("[DEBUG] configuration value set: TIMESTAMP = %v\n", config.Config.ZoneConfig.Serial) - log.Printf("[DEBUG] configuration value set: REFRESH = %v\n", config.Config.ZoneConfig.Refresh) - log.Printf("[DEBUG] configuration value set: RETRY = %v\n", config.Config.ZoneConfig.Retry) - log.Printf("[DEBUG] configuration value set: EXPIRE = %v\n", config.Config.ZoneConfig.Expire) - log.Printf("[DEBUG] configuration value set: MINIMUM = %v\n", config.Config.ZoneConfig.Minimum) - log.Printf("[DEBUG] configuration value set: NS1 = %v\n", ns1) - log.Printf("[DEBUG] configuration value set: NS2 = %v\n", ns2) - log.Printf("[DEBUG] configuration value set: CONFIG_FILE = %v\n", config.ConfigFileLocation) - - // read config file - var err error - if config.ConfigFileLocation != "" { - if config.Config, err = readConfigFile(config.ConfigFileLocation); err != nil { - log.Fatalf("[FATAL] Invalid config file: %v\n", err) - } - if config.Config.ZoneConfig.Serial == "" { - config.Config.ZoneConfig.Serial = time.Now().In(config.TimeZone).Format("0601021504") - } - } - - // set bind-config nameservers - if ns1 != "" { - config.Config.ZoneConfig.NameServers = append(config.Config.ZoneConfig.NameServers, ns1) - } - if ns2 != "" { - config.Config.ZoneConfig.NameServers = append(config.Config.ZoneConfig.NameServers, ns2) - } - - if len(config.Config.ZoneConfig.NameServers) == 0 { - log.Printf("[ERROR] A primary name-server must be identified.") - flag.PrintDefaults() - os.Exit(1) - } - - // bind does not use "@", so we convert it to a "." - config.Config.ZoneConfig.Email = strings.Replace(config.Config.ZoneConfig.Email, "@", ".", -1) - - log.Printf("[DEBUG] Initialization Complete\n") -} - -func readConfigFile(configFileLocation string) (configFileStruct, error) { - defer timeTrack(time.Now(), "readConfigFile") - - var output configFileStruct - - rd, err := ioutil.ReadFile(configFileLocation) - if err != nil { - return output, err - } - - if err := yaml.Unmarshal(rd, &output); err != nil { - return output, err - } - - return output, nil -} diff --git a/cmd/bind/main.go b/cmd/bind/main.go index 38760aa..0c5a5e8 100644 --- a/cmd/bind/main.go +++ b/cmd/bind/main.go @@ -1,13 +1,26 @@ package main +import ( + "pihole-blocklist/bind/internal/config" + "time" +) + +var cfg config.Config + func main() { - initialize() + cfg = config.Init() + + // Set the zone serial number + cfg.ConfigFile.ZoneConfig.Serial = time.Now().In(cfg.TZLocal).Format("0601021504") // get remote URL data badDomains := getListData() // clean-up - config.Config.ZoneConfig.BlockedDomains = cleanBadDomains(badDomains) + cfg.ConfigFile.ZoneConfig.BlockedDomains = cleanBadDomains(badDomains) - buildBindResponsePolicyFile() + // write file + if err := buildBindResponsePolicyFile(); err != nil { + cfg.Log.Error("unable to write file", "error", err, "path", cfg.BindOutputFileName) + } } diff --git a/cmd/bind/parsing-complex.go b/cmd/bind/parsing-complex.go index 513b218..d07f49a 100644 --- a/cmd/bind/parsing-complex.go +++ b/cmd/bind/parsing-complex.go @@ -3,17 +3,13 @@ package main import ( "bufio" "bytes" - "log" "regexp" "strings" - "time" "github.com/asaskevich/govalidator" ) func parseComplex(data []byte) []string { - defer timeTrack(time.Now(), "parseComplex") - var domains []string // convert data to reader for line-by-line reading @@ -42,7 +38,7 @@ func parseComplex(data []byte) []string { if govalidator.IsDNSName(lineItems[1]) { domains = append(domains, lineItems[1]) } else { - log.Printf("[TRACE] Domain is not valid: %s\n", lineItems[0]) + cfg.Log.Debug("host invalid", "host", lineItems[0]) } } } diff --git a/cmd/bind/parsing-simple.go b/cmd/bind/parsing-simple.go index b161112..c76c440 100644 --- a/cmd/bind/parsing-simple.go +++ b/cmd/bind/parsing-simple.go @@ -3,17 +3,13 @@ package main import ( "bufio" "bytes" - "log" "regexp" "strings" - "time" "github.com/asaskevich/govalidator" ) func parseSimple(data []byte) []string { - defer timeTrack(time.Now(), "parseSimple") - var domains []string // convert data to reader for line-by-line reading @@ -42,7 +38,7 @@ func parseSimple(data []byte) []string { if govalidator.IsDNSName(lineItems[0]) { domains = append(domains, lineItems[0]) } else { - log.Printf("[TRACE] Domain is not valid: %s\n", lineItems[0]) + cfg.Log.Debug("host invalid", "host", lineItems[0]) } } } diff --git a/cmd/bind/supporting-functions.go b/cmd/bind/supporting-functions.go deleted file mode 100644 index 98eb1f0..0000000 --- a/cmd/bind/supporting-functions.go +++ /dev/null @@ -1,16 +0,0 @@ -package main - -import ( - "log" - "time" -) - -func timeTrack(start time.Time, name string) { - elapsed := time.Since(start) - log.Printf("[DEBUG] Function %s took %s\n", name, elapsed) -} - -func removeStringFromSlice(s []string, i int) []string { - s[i] = s[len(s)-1] - return s[:len(s)-1] -} diff --git a/go.mod b/go.mod index 8c51e79..02dceed 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,8 @@ -module pihole-blocklist/v2 +module pihole-blocklist/bind go 1.18 require ( github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 - github.com/hashicorp/logutils v1.0.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index d78e584..0d53b0c 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= -github.com/hashicorp/logutils v1.0.0 h1:dLEQVugN8vlakKOUE3ihGLTZJRB4j+M2cdTm/ORI65Y= -github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/config/envconfig.go b/internal/config/envconfig.go new file mode 100644 index 0000000..73c1786 --- /dev/null +++ b/internal/config/envconfig.go @@ -0,0 +1,241 @@ +package config + +import ( + "flag" + "fmt" + "os" + "reflect" + "strconv" + "strings" +) + +type structInfo struct { + Name string + Alt string + Info string + Key string + Field reflect.Value + Tags reflect.StructTag + Type reflect.Type + DefaultValue interface{} + Secret interface{} +} + +func getEnv[t string | bool | int | int64 | float64](env string, def t) (t, error) { + val := os.Getenv(env) + if len(val) == 0 { + return def, nil + } + + output := *new(t) + switch (interface{})(def).(type) { + case string: + v, err := typeConversion("string", val) + if err != nil { + return (interface{})(false).(t), err + } + output = v.(t) + case bool: + v, err := typeConversion("bool", val) + if err != nil { + return (interface{})(false).(t), err + } + output = v.(t) + case int: + v, err := typeConversion("int", val) + if err != nil { + return (interface{})(int(0)).(t), err + } + output = (interface{})(int(v.(int64))).(t) + case int64: + v, err := typeConversion("int64", val) + if err != nil { + return (interface{})(int64(0)).(t), err + } + output = v.(t) + case float64: + v, err := typeConversion("float64", val) + if err != nil { + return (interface{})(float64(0)).(t), err + } + output = v.(t) + } + + return output, nil +} + +func getStructInfo(spec interface{}) ([]structInfo, error) { + s := reflect.ValueOf(spec) + + if s.Kind() != reflect.Pointer { + return []structInfo{}, fmt.Errorf("getStructInfo() was sent a %s instead of a pointer to a struct.\n", s.Kind()) + } + + s = s.Elem() + if s.Kind() != reflect.Struct { + return []structInfo{}, fmt.Errorf("getStructInfo() was sent a %s instead of a struct.\n", s.Kind()) + } + typeOfSpec := s.Type() + + infos := make([]structInfo, 0, s.NumField()) + for i := 0; i < s.NumField(); i++ { + f := s.Field(i) + ftype := typeOfSpec.Field(i) + + ignored, _ := strconv.ParseBool(ftype.Tag.Get("ignored")) + if !f.CanSet() || ignored { + continue + } + + for f.Kind() == reflect.Pointer { + if f.IsNil() { + if f.Type().Elem().Kind() != reflect.Struct { + break + } + f.Set(reflect.New(f.Type().Elem())) + } + f = f.Elem() + } + + secret, err := typeConversion(ftype.Type.String(), ftype.Tag.Get("secret")) + if err != nil { + secret = false + } + + var desc string + if len(ftype.Tag.Get("info")) != 0 { + desc = fmt.Sprintf("(%s) %s", strings.ToUpper(ftype.Tag.Get("env")), ftype.Tag.Get("info")) + } else { + desc = fmt.Sprintf("(%s)", strings.ToUpper(ftype.Tag.Get("env"))) + } + + info := structInfo{ + Name: ftype.Name, + Alt: strings.ToUpper(ftype.Tag.Get("env")), + Info: desc, + Key: ftype.Name, + Field: f, + Tags: ftype.Tag, + Type: ftype.Type, + Secret: secret, + } + if info.Alt != "" { + info.Key = info.Alt + } + info.Key = strings.ToUpper(info.Key) + if ftype.Tag.Get("default") != "" { + v, err := typeConversion(ftype.Type.String(), ftype.Tag.Get("default")) + if err != nil { + return []structInfo{}, err + } + info.DefaultValue = v + } + infos = append(infos, info) + } + return infos, nil +} + +func typeConversion(t, v string) (interface{}, error) { + switch t { + case "string": //nolint:goconst + return v, nil + case "int": //nolint:goconst + return strconv.ParseInt(v, 10, 0) + case "int8": + return strconv.ParseInt(v, 10, 8) + case "int16": + return strconv.ParseInt(v, 10, 16) + case "int32": + return strconv.ParseInt(v, 10, 32) + case "int64": + return strconv.ParseInt(v, 10, 64) + case "uint": + return strconv.ParseUint(v, 10, 0) + case "uint16": + return strconv.ParseUint(v, 10, 16) + case "uint32": + return strconv.ParseUint(v, 10, 32) + case "uint64": + return strconv.ParseUint(v, 10, 64) + case "float32": + return strconv.ParseFloat(v, 32) + case "float64": + return strconv.ParseFloat(v, 64) + case "complex64": + return strconv.ParseComplex(v, 64) + case "complex128": + return strconv.ParseComplex(v, 128) + case "bool": //nolint:goconst + return strconv.ParseBool(v) + } + return nil, fmt.Errorf("Unable to identify type.") +} + +func (cfg *Config) parseFlags(cfgInfo []structInfo) error { //nolint:gocognit + for _, info := range cfgInfo { + switch info.Type.String() { + case "string": + var dv string + + if info.DefaultValue != nil { + dv = info.DefaultValue.(string) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*string) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.StringVar(p, info.Name, retVal, info.Info) + case "bool": + var dv bool + + if info.DefaultValue != nil { + dv = info.DefaultValue.(bool) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*bool) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.BoolVar(p, info.Name, retVal, info.Info) + case "int": + var dv int + + if info.DefaultValue != nil { + dv = int(info.DefaultValue.(int64)) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.IntVar(p, info.Name, retVal, info.Info) + case "int64": + var dv int64 + + if info.DefaultValue != nil { + dv = info.DefaultValue.(int64) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int64) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.Int64Var(p, info.Name, retVal, info.Info) + case "float64": + var dv float64 + + if info.DefaultValue != nil { + dv = info.DefaultValue.(float64) + } + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*float64) + retVal, err := getEnv(info.Alt, dv) + if err != nil { + return err + } + flag.Float64Var(p, info.Name, retVal, info.Info) + } + } + flag.Parse() + return nil +} diff --git a/internal/config/file-operations.go b/internal/config/file-operations.go new file mode 100644 index 0000000..3b5a308 --- /dev/null +++ b/internal/config/file-operations.go @@ -0,0 +1,52 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" +) + +func FileExists(path string) bool { + if _, err := os.Stat(path); os.IsNotExist(err) { + return false + } + return true +} + +func ReadFile(path string) ([]byte, error) { + var output []byte + if !FileExists(path) { + return []byte{}, fmt.Errorf("Unable to read file, file does not exist: %s", path) + } + + output, err := os.ReadFile(path) + if err != nil { + return []byte{}, fmt.Errorf("Unable to read file, %v: %s", err, path) + } + return output, nil +} + +func WriteFile(path string, data []byte) (int, error) { + dir := filepath.Dir(path) + + if err := os.MkdirAll(dir, 0755); err != nil { + return 0, fmt.Errorf("Unable to create parent directory, %v: %s", err, dir) + } + + fh, err := os.Create(path) + if err != nil { + return 0, fmt.Errorf("Unable to open file for writing, %v: %s", err, path) + } + defer fh.Close() + + bs, err := fh.Write(data) + if err != nil { + return 0, fmt.Errorf("Unable to write file, %v: %s", err, path) + } + + if err := fh.Sync(); err != nil { + return 0, fmt.Errorf("Unable to sync file to disk, %v: %s", err, path) + } + + return bs, nil +} diff --git a/internal/config/initialize.go b/internal/config/initialize.go new file mode 100644 index 0000000..df4a783 --- /dev/null +++ b/internal/config/initialize.go @@ -0,0 +1,64 @@ +package config + +import ( + "fmt" + "os" + "pihole-blocklist/bind/assets" + "time" + + "gopkg.in/yaml.v3" +) + +func Init() Config { + cfg := New() + + cfgInfo, err := getStructInfo(&cfg) + if err != nil { + panic(fmt.Sprintf("Unable to initialize program: %v", err)) + } + + // get command line flags + if err := cfg.parseFlags(cfgInfo); err != nil { + panic(fmt.Sprintf("Unable to initialize program: %v", err)) + } + + // set logging Level + setLogLevel(&cfg) + + // set timezone & time format + cfg.TZUTC, _ = time.LoadLocation("UTC") + cfg.TZLocal, err = time.LoadLocation(cfg.TimeZoneLocal) + if err != nil { + cfg.Log.Error("Unable to parse timezone string", "error", err) + os.Exit(1) + } + + // check config file + if !FileExists(cfg.ConfigFileLocation) { + if _, err := WriteFile(cfg.ConfigFileLocation, assets.Config); err != nil { + cfg.Log.Error(err.Error()) + os.Exit(1) + } + cfg.Log.Error("Unable to locate configuration file, an example config file has been written", "path", cfg.ConfigFileLocation) + os.Exit(1) + } + + // read config + cfData, err := ReadFile(cfg.ConfigFileLocation) + if err != nil { + cfg.Log.Error("Unable to read config file", "error", err) + os.Exit(1) + } + + // unmarshal config file + if err := yaml.Unmarshal(cfData, &cfg.ConfigFile); err != nil { + cfg.Log.Error("Unable to read config file contents", "error", err) + os.Exit(1) + } + + // print running config + printRunningConfig(&cfg, cfgInfo) + + // return configuration + return cfg +} diff --git a/internal/config/struct-config.go b/internal/config/struct-config.go new file mode 100644 index 0000000..d5a358c --- /dev/null +++ b/internal/config/struct-config.go @@ -0,0 +1,108 @@ +package config + +import ( + "log/slog" + "os" + "reflect" + "strconv" + "time" +) + +type Config struct { + // time configuration + TimeFormat string `default:"2006-01-02 15:04:05" env:"time_format"` + TimeZoneLocal string `default:"America/Chicago" env:"time_zone"` + TZLocal *time.Location `ignored:"true"` + TZUTC *time.Location `ignored:"true"` + + // logging + LogLevel int `default:"50" env:"log_level"` + Log *slog.Logger `ignored:"true"` + SLogLevel *slog.LevelVar `ignored:"true"` + + // HTTP Client timeout configurations + HTTPClientRequestTimeout int `default:"60" env:"HTTP_CLIENT_REQUEST_TIMEOUT"` + HTTPClientConnectTimeout int `default:"5" env:"HTTP_CLIENT_CONNECT_TIMEOUT"` + HTTPClientTLSHandshakeTimeout int `default:"5" env:"HTTP_CLIENT_TLS_TIMEOUT"` + HTTPClientIdleTimeout int `default:"5" env:"HTTP_CLIENT_IDLE_TIMEOUT"` + + // Output Filename + BindOutputFileName string `default:"./response-policy.bind" env:"OUTPUT"` + + // Config + ConfigFileLocation string `default:"./config.yaml" env:"CONFIG_FILE"` + ConfigFile configFileStruct +} + +type configFileStruct struct { + ZoneConfig struct { + BlockedDomains []string `yaml:"blockedDomains"` + Domain string `yaml:"baseDomain"` + Email string `yaml:"emailAddress"` + Expire string `yaml:"zoneExpire"` + Minimum string `yaml:"zoneMinimum"` + NameServers []string `yaml:"nameServers"` + Refresh string `yaml:"zoneRefresh"` + Retry string `yaml:"zoneRetry"` + Serial string `yaml:"zoneSerialNumber"` + TTL string `yaml:"timeToLive"` + } `yaml:"zoneConfig"` + Sources struct { + DomainListURLs []string `yaml:"domainListURLs"` + HostFileURLs []string `yaml:"hostFileURLs"` + } `yaml:"sources"` + AllowLists []string `yaml:"allowList"` + DenyList []string `yaml:"denyList"` +} + +// New initializes the config variable for use with a prepared set of defaults. +func New() Config { + cfg := Config{ + SLogLevel: new(slog.LevelVar), + } + + cfg.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: cfg.SLogLevel, + })) + + return cfg +} + +func setLogLevel(cfg *Config) { + switch { + // error + case cfg.LogLevel <= 20: + cfg.SLogLevel.Set(slog.LevelError) + cfg.Log.Info("Log level updated", "level", slog.LevelError) + // warning + case cfg.LogLevel > 20 && cfg.LogLevel <= 40: + cfg.SLogLevel.Set(slog.LevelWarn) + cfg.Log.Info("Log level updated", "level", slog.LevelWarn) + // info + case cfg.LogLevel > 40 && cfg.LogLevel <= 60: + cfg.SLogLevel.Set(slog.LevelInfo) + cfg.Log.Info("Log level updated", "level", slog.LevelInfo) + // debug + case cfg.LogLevel > 60: + cfg.SLogLevel.Set(slog.LevelDebug) + cfg.Log.Info("Log level updated", "level", slog.LevelDebug) + } + // set default logger + slog.SetDefault(cfg.Log) +} + +func printRunningConfig(cfg *Config, cfgInfo []structInfo) { + for _, info := range cfgInfo { + switch info.Type.String() { + case "string": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*string) + cfg.Log.Debug("Running Configuration", info.Alt, *p) + case "bool": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*bool) + cfg.Log.Debug("Running Configuration", info.Alt, strconv.FormatBool(*p)) + case "int": + p := reflect.ValueOf(cfg).Elem().FieldByName(info.Name).Addr().Interface().(*int) + cfg.Log.Debug("Running Configuration", info.Alt, strconv.FormatInt(int64(*p), 10)) + } + } +}