package main

import (
	"flag"
	"io/ioutil"
	"log"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/hashicorp/logutils"
	"gopkg.in/yaml.v3"
)

// getEnvString looks up the environment variable env and returns its value.
// When the variable is unset or holds the empty string, def is returned
// instead, so an explicitly empty variable is treated the same as a missing one.
func getEnvString(env, def string) (val string) { //nolint:deadcode
	defer timeTrack(time.Now(), "getEnvString")

	if fromEnv := os.Getenv(env); fromEnv != "" {
		return fromEnv
	}
	return def
}

// getEnvInt returns the integer value of the environment variable env.
// When the variable is unset or empty, def is returned. A value that is set
// but cannot be parsed by strconv.Atoi terminates the program via log.Fatalf;
// the message names the variable, quotes its value and includes the parse
// error so the misconfiguration is easy to locate.
func getEnvInt(env string, def int) (ret int) {
	defer timeTrack(time.Now(), "getEnvInt")

	val := os.Getenv(env)
	if val == "" {
		return def
	}

	n, err := strconv.Atoi(val)
	if err != nil {
		log.Fatalf("[ERROR] Environment variable is not numeric: %v=%q: %v\n", env, val, err)
	}

	return n
}

// initialize sets up the global config. It:
//   - resolves the time zones used for the zone serial,
//   - registers command-line flags whose defaults come from environment
//     variables (the variable name is shown in parentheses in each flag's usage),
//   - parses the flags and applies the log level filter,
//   - optionally replaces config.Config with the contents of a YAML config file,
//   - appends the NS1/NS2 name-servers and normalizes the SOA e-mail address.
//
// It terminates the process when a numeric environment variable is malformed,
// when the config file cannot be read or decoded, or when no name-server has
// been configured.
func initialize() {
	defer timeTrack(time.Now(), "initialize")

	var err error

	// Never keep the nil result of a failed lookup: Time.In(nil) panics, which
	// would crash the TIMESTAMP default below on hosts without tzdata.
	if config.TimeZone, err = time.LoadLocation("America/Chicago"); err != nil {
		log.Printf("[WARNING] unable to load time zone America/Chicago, using UTC: %v\n", err)
		config.TimeZone = time.UTC
	}
	// time.LoadLocation("UTC") always returns time.UTC, so use it directly.
	config.TimeZoneUTC = time.UTC

	// defaultSerial renders the current time in config.TimeZone using the
	// layout "0601021504", i.e. YYMMDDhhmm.
	defaultSerial := func() string {
		return time.Now().In(config.TimeZone).Format("0601021504")
	}

	// read command line options
	var (
		logLevel int
		ns1, ns2 string
	)

	// log configuration
	flag.IntVar(&logLevel,
		"log",
		getEnvInt("LOG_LEVEL", 50),
		"(LOG_LEVEL)\nlog level")
	// http client configuration
	flag.IntVar(&config.HTTPClientRequestTimeout,
		"client-req-to",
		getEnvInt("HTTP_CLIENT_REQUEST_TIMEOUT", 60),
		"(HTTP_CLIENT_REQUEST_TIMEOUT)\ntime in seconds for the internal http client to complete a request")
	flag.IntVar(&config.HTTPClientConnectTimeout,
		"client-con-to",
		getEnvInt("HTTP_CLIENT_CONNECT_TIMEOUT", 5),
		"(HTTP_CLIENT_CONNECT_TIMEOUT)\ntime in seconds for the internal http client connection timeout")
	flag.IntVar(&config.HTTPClientTLSHandshakeTimeout,
		"client-tls-to",
		getEnvInt("HTTP_CLIENT_TLS_TIMEOUT", 5),
		"(HTTP_CLIENT_TLS_TIMEOUT)\ntime in seconds for the internal http client to complete a tls handshake")
	flag.IntVar(&config.HTTPClientIdleTimeout,
		"client-idle-to",
		getEnvInt("HTTP_CLIENT_IDLE_TIMEOUT", 5),
		"(HTTP_CLIENT_IDLE_TIMEOUT)\ntime in seconds that the internal http client will keep a connection open when idle")
	// Bind Config
	flag.StringVar(&config.Config.ZoneConfig.TTL,
		"bind-ttl",
		getEnvString("TTL", "1h"),
		"(TTL)\nBind zone time to live")
	flag.StringVar(&config.Config.ZoneConfig.Domain,
		"bind-domain",
		getEnvString("DOMAIN", "example.com"),
		"(DOMAIN)\nBind zone base domain")
	flag.StringVar(&config.Config.ZoneConfig.Email,
		"bind-email",
		getEnvString("EMAIL", "domain-admin@example.com"),
		"(EMAIL)\nBind zone authority e-mail address")
	flag.StringVar(&config.Config.ZoneConfig.Serial,
		"bind-timestamp",
		getEnvString("TIMESTAMP", defaultSerial()),
		"(TIMESTAMP)\nBind zone serial number")
	flag.StringVar(&config.Config.ZoneConfig.Refresh,
		"bind-refresh",
		getEnvString("REFRESH", "1h"),
		"(REFRESH)\nBind zone refresh time")
	flag.StringVar(&config.Config.ZoneConfig.Retry,
		"bind-retry",
		getEnvString("RETRY", "30m"),
		"(RETRY)\nBind zone retry time")
	flag.StringVar(&config.Config.ZoneConfig.Expire,
		"bind-expire",
		getEnvString("EXPIRE", "1w"),
		"(EXPIRE)\nBind zone expire time")
	flag.StringVar(&config.Config.ZoneConfig.Minimum,
		"bind-minimum",
		getEnvString("MINIMUM", "1h"),
		"(MINIMUM)\nBind zone minimum time")
	flag.StringVar(&ns1,
		"bind-ns1",
		getEnvString("NS1", ""),
		"(NS1)\nBind zone primary name-server")
	flag.StringVar(&ns2,
		"bind-ns2",
		getEnvString("NS2", ""),
		"(NS2)\nBind zone secondary name-server")
	// output file; the usage text names the env var actually read (OUTPUT)
	flag.StringVar(&config.BindOutputFileName,
		"output",
		getEnvString("OUTPUT", "./response-policy.bind"),
		"(OUTPUT)\nWrite local file to filename")
	flag.StringVar(&config.ConfigFileLocation,
		"config-file",
		getEnvString("CONFIG_FILE", ""),
		"(CONFIG_FILE)\nRead configuration from file")
	flag.Parse()

	// Map the numeric log level onto logutils level names, in steps of 20.
	switch {
	case logLevel <= 20:
		config.Log.SetMinLevel(logutils.LogLevel("ERROR"))
	case logLevel <= 40:
		config.Log.SetMinLevel(logutils.LogLevel("WARNING"))
	case logLevel <= 60:
		config.Log.SetMinLevel(logutils.LogLevel("INFO"))
	case logLevel <= 80:
		config.Log.SetMinLevel(logutils.LogLevel("DEBUG"))
	default:
		config.Log.SetMinLevel(logutils.LogLevel("TRACE"))
	}
	log.SetOutput(config.Log)

	// print current configuration (values from flags/env, before any config file is applied)
	log.Printf("[DEBUG] configuration value set: LOG_LEVEL                   = %v\n", strconv.Itoa(logLevel))
	log.Printf("[DEBUG] configuration value set: HTTP_CLIENT_REQUEST_TIMEOUT = %v\n", strconv.Itoa(config.HTTPClientRequestTimeout))
	log.Printf("[DEBUG] configuration value set: HTTP_CLIENT_CONNECT_TIMEOUT = %v\n", strconv.Itoa(config.HTTPClientConnectTimeout))
	log.Printf("[DEBUG] configuration value set: HTTP_CLIENT_TLS_TIMEOUT     = %v\n", strconv.Itoa(config.HTTPClientTLSHandshakeTimeout))
	log.Printf("[DEBUG] configuration value set: HTTP_CLIENT_IDLE_TIMEOUT    = %v\n", strconv.Itoa(config.HTTPClientIdleTimeout))
	log.Printf("[DEBUG] configuration value set: TTL                         = %v\n", config.Config.ZoneConfig.TTL)
	log.Printf("[DEBUG] configuration value set: DOMAIN                      = %v\n", config.Config.ZoneConfig.Domain)
	log.Printf("[DEBUG] configuration value set: EMAIL                       = %v\n", config.Config.ZoneConfig.Email)
	log.Printf("[DEBUG] configuration value set: TIMESTAMP                   = %v\n", config.Config.ZoneConfig.Serial)
	log.Printf("[DEBUG] configuration value set: REFRESH                     = %v\n", config.Config.ZoneConfig.Refresh)
	log.Printf("[DEBUG] configuration value set: RETRY                       = %v\n", config.Config.ZoneConfig.Retry)
	log.Printf("[DEBUG] configuration value set: EXPIRE                      = %v\n", config.Config.ZoneConfig.Expire)
	log.Printf("[DEBUG] configuration value set: MINIMUM                     = %v\n", config.Config.ZoneConfig.Minimum)
	log.Printf("[DEBUG] configuration value set: NS1                         = %v\n", ns1)
	log.Printf("[DEBUG] configuration value set: NS2                         = %v\n", ns2)
	log.Printf("[DEBUG] configuration value set: OUTPUT                      = %v\n", config.BindOutputFileName)
	log.Printf("[DEBUG] configuration value set: CONFIG_FILE                 = %v\n", config.ConfigFileLocation)

	// read config file; when given, it replaces config.Config wholesale
	if config.ConfigFileLocation != "" {
		if config.Config, err = readConfigFile(config.ConfigFileLocation); err != nil {
			log.Fatalf("[FATAL] Invalid config file: %v\n", err)
		}
		if config.Config.ZoneConfig.Serial == "" {
			config.Config.ZoneConfig.Serial = defaultSerial()
		}
	}

	// set bind-config nameservers (flags/env are appended after any from the config file)
	if ns1 != "" {
		config.Config.ZoneConfig.NameServers = append(config.Config.ZoneConfig.NameServers, ns1)
	}
	if ns2 != "" {
		config.Config.ZoneConfig.NameServers = append(config.Config.ZoneConfig.NameServers, ns2)
	}

	if len(config.Config.ZoneConfig.NameServers) == 0 {
		log.Printf("[ERROR] A primary name-server must be identified.")
		flag.PrintDefaults()
		os.Exit(1)
	}

	// bind does not use "@", so we convert it to a "."
	config.Config.ZoneConfig.Email = strings.ReplaceAll(config.Config.ZoneConfig.Email, "@", ".")

	log.Printf("[DEBUG] Initialization Complete\n")
}

// readConfigFile reads the YAML document at configFileLocation and decodes it
// into a configFileStruct. A read error yields the zero value; a decode error
// is returned alongside the struct, which may already be partially populated.
// Errors are passed through without additional context.
func readConfigFile(configFileLocation string) (configFileStruct, error) {
	defer timeTrack(time.Now(), "readConfigFile")

	var cfg configFileStruct

	raw, err := ioutil.ReadFile(configFileLocation)
	if err == nil {
		err = yaml.Unmarshal(raw, &cfg)
	}

	return cfg, err
}