// Package nsupdate publishes DNS records by running the nsupdate
// command-line tool and reads the current records back with
// github.com/miekg/dns.
package nsupdate

import (
	"bytes"
	"fmt"
	"net"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"unicode/utf8"

	"github.com/miekg/dns"

	"istheinternetonfire.app/internal/config"
)

const (
	// recordTTL is the TTL, in seconds, written into every update line.
	recordTTL = 30
	// maxTXTString is the largest number of bytes a single DNS
	// character-string (one quoted piece of TXT RDATA) may carry.
	maxTXTString = 255
)

// New returns an NsUpdateStruct that authenticates updates with the given
// key and talks to server on port. zone is stored on the result but is not
// read by any function in this file.
func New(keyLabel, keyAlgorithm, keySecret, server string, port int, zone string) NsUpdateStruct {
	return NsUpdateStruct{
		Key: KeyStruct{
			Label:     keyLabel,
			Algorithm: keyAlgorithm,
			Secret:    keySecret,
		},
		Server: server,
		Port:   port,
		Zone:   zone,
	}
}

// Update makes record hold value as its recordType data.
//
// TXT values are encoded before use with encodeTXT. That step turns
// newlines into the literal "%n", splits the text into character-strings of
// at most 255 bytes, and escapes quotes, backslashes and non-printable bytes.
//
// The current RDATA is looked up on c.Server:c.Port. If it already equals
// the desired data, nothing is changed. Otherwise the existing RRset is
// deleted and recreated. A failed lookup is logged and treated as "needs
// update".
func (c NsUpdateStruct) Update(record, recordType, value string) error {
	want := value
	if strings.EqualFold(recordType, "TXT") {
		value = encodeTXT(value)
		// Create wraps value in double quotes. escapeTXT uses the same
		// canonical escaping miekg/dns uses when printing TXT RDATA, so an
		// unchanged record compares equal.
		want = `"` + value + `"`
	}

	current, err := getRdata(c.Server, c.Port, recordType, record)
	if err != nil {
		config.Cfg.Log.Info("unable to get existing record", "record", record, "error", err)
	} else if current == want {
		config.Cfg.Log.Debug("no update necessary")
		return nil
	}

	config.Cfg.Log.Debug("deleting record", "record", record)
	if err := c.Delete(record, recordType); err != nil {
		return err
	}
	config.Cfg.Log.Debug("creating record", "record", record)
	return c.Create(record, recordType, value)
}

// Delete removes the whole recordType RRset at record by running nsupdate.
// It returns an error if record or recordType contains a line break, or if
// nsupdate cannot be run or exits unsuccessfully.
func (c NsUpdateStruct) Delete(record, recordType string) error {
	if err := checkLine(record, recordType); err != nil {
		return err
	}
	update := fmt.Sprintf("update delete %s %d IN %s\n",
		dns.Fqdn(record), recordTTL, strings.ToUpper(recordType))

	stdout, stderr, err := c.runNsupdate(update)
	if err != nil {
		config.Cfg.Log.Error("error deleting record", "error", err, "stderr", stderr, "stdout", stdout)
		return err
	}
	return nil
}

// Create adds value to the recordType RRset at record, with a 30-second TTL,
// by running nsupdate.
//
// value is always wrapped in double quotes. TXT data must therefore already
// be escaped and split, as Update does. NOTE(review): confirm nsupdate accepts
// quoted RDATA for other types such as A.
func (c NsUpdateStruct) Create(record, recordType, value string) error {
	if err := checkLine(record, recordType, value); err != nil {
		return err
	}
	update := fmt.Sprintf("update add %s %d IN %s \"%s\"\n",
		dns.Fqdn(record), recordTTL, strings.ToUpper(recordType), value)

	stdout, stderr, err := c.runNsupdate(update)
	if err != nil {
		config.Cfg.Log.Error("error adding record", "error", err, "stderr", stderr, "stdout", stdout)
		return err
	}
	return nil
}

// runNsupdate runs nsupdate directly (no shell) and returns its captured
// stdout and stderr.
//
// The command script goes to nsupdate's stdin: a server line, then the
// given update line(s), then "send". The key definition travels over an
// inherited pipe that the child sees as fd 3, opened as /dev/fd/3. This
// keeps the secret out of argv and off disk.
func (c NsUpdateStruct) runNsupdate(update string) (string, string, error) {
	keyR, keyW, err := os.Pipe()
	if err != nil {
		return "", "", fmt.Errorf("creating key pipe: %w", err)
	}

	var stdout, stderr bytes.Buffer
	cmd := exec.Command("nsupdate", "-v", "-k", "/dev/fd/3")
	cmd.Stdin = strings.NewReader(c.serverLine() + update + "send\n")
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	cmd.ExtraFiles = []*os.File{keyR} // ExtraFiles[0] becomes fd 3 in the child

	startErr := cmd.Start()
	// The child (if started) holds its own copy of the read end.
	keyR.Close()
	if startErr != nil {
		keyW.Close()
		return "", "", startErr
	}

	_, writeErr := keyW.WriteString(c.keyDefinition())
	// Closing the write end gives nsupdate EOF on the key file.
	keyW.Close()

	err = cmd.Wait()
	if err == nil {
		err = writeErr
	}
	return stdout.String(), stderr.String(), err
}

// keyDefinition renders c.Key in the `key "name" { ... };` form that
// nsupdate -k reads.
func (c NsUpdateStruct) keyDefinition() string {
	return fmt.Sprintf("key \"%s\" {\n\talgorithm \"%s\";\n\tsecret \"%s\";\n};\n",
		c.Key.Label, c.Key.Algorithm, c.Key.Secret)
}

// serverLine returns the nsupdate "server" command for c. It includes
// c.Port when set, so updates go to the same port that lookups use.
func (c NsUpdateStruct) serverLine() string {
	if c.Port > 0 {
		return fmt.Sprintf("server %s %d\n", c.Server, c.Port)
	}
	return fmt.Sprintf("server %s\n", c.Server)
}

// checkLine rejects fields containing a line break. Such a field would end
// the current nsupdate command and let the rest be read as new commands.
func checkLine(fields ...string) error {
	for _, f := range fields {
		if strings.ContainsAny(f, "\r\n") {
			return fmt.Errorf("nsupdate: %q contains a line break", f)
		}
	}
	return nil
}

// encodeTXT turns arbitrary text into the inside of a quoted TXT RDATA
// string, in the form `chunk1" "chunk2`; the outer quotes are added by
// Create.
//
// The steps are:
//   - Newlines become the literal "%n".
//   - The result is split into character-strings of at most maxTXTString
//     bytes, on rune boundaries.
//   - Each piece is escaped with escapeTXT.
//
// Splitting happens before escaping, so an escape sequence is never cut in
// half.
func encodeTXT(value string) string {
	chunks := splitTXT(strings.ReplaceAll(value, "\n", "%n"))
	for i, chunk := range chunks {
		chunks[i] = escapeTXT(chunk)
	}
	return strings.Join(chunks, `" "`)
}

// splitTXT cuts s into pieces of at most maxTXTString bytes without
// splitting a UTF-8 sequence. An empty s yields a single empty piece.
func splitTXT(s string) []string {
	var chunks []string
	start := 0
	for i := 0; i < len(s); {
		_, size := utf8.DecodeRuneInString(s[i:])
		if i+size-start > maxTXTString {
			chunks = append(chunks, s[start:i])
			start = i
		}
		i += size
	}
	return append(chunks, s[start:])
}

// escapeTXT escapes a character-string for use inside double quotes.
// `"` and `\` are backslash-escaped. Bytes outside printable ASCII become
// \DDD decimal escapes, so no raw control character reaches nsupdate.
func escapeTXT(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for i := 0; i < len(s); i++ {
		ch := s[i]
		switch {
		case ch == '"' || ch == '\\':
			b.WriteByte('\\')
			b.WriteByte(ch)
		case ch < ' ' || ch > '~':
			fmt.Fprintf(&b, "\\%03d", ch)
		default:
			b.WriteByte(ch)
		}
	}
	return b.String()
}

// lookupRR queries server:port for query and returns the first answer whose
// type is queryType. Answers of other types, such as a leading CNAME, are
// skipped. It returns an error if the type is unknown, the exchange fails,
// the response code is not NOERROR, or no matching answer exists.
func lookupRR(server string, port int, queryType, query string) (dns.RR, error) {
	config.Cfg.Log.Debug("looking up dns record", "query", query, "type", queryType)
	recordType, err := getRecordType(queryType)
	if err != nil {
		return nil, err
	}

	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(query), recordType)
	m.RecursionDesired = true

	r, _, err := new(dns.Client).Exchange(m, net.JoinHostPort(server, strconv.Itoa(port)))
	if err != nil {
		return nil, err
	}
	if r.Rcode != dns.RcodeSuccess {
		return nil, fmt.Errorf("lookup of %s %s failed: %s", query, queryType, dns.RcodeToString[r.Rcode])
	}
	for _, a := range r.Answer {
		if a.Header().Rrtype == recordType {
			return a, nil
		}
	}
	return nil, fmt.Errorf("no result")
}

// getRecord returns the full presentation form (name, TTL, class, type,
// RDATA) of the first queryType answer for query from server:port.
func getRecord(server string, port int, queryType, query string) (string, error) {
	rr, err := lookupRR(server, port, queryType, query)
	if err != nil {
		return "", err
	}
	return getRecordFromResult(rr, queryType)
}

// getRdata is like getRecord but returns only the RDATA part of the record,
// without the owner/TTL/class/type header.
func getRdata(server string, port int, queryType, query string) (string, error) {
	rr, err := lookupRR(server, port, queryType, query)
	if err != nil {
		return "", err
	}
	return strings.TrimPrefix(rr.String(), rr.Header().String()), nil
}

// getRecordFromResult returns a's presentation string if a is a record of
// queryType, and an "invalid record type" error otherwise.
func getRecordFromResult(a dns.RR, queryType string) (string, error) {
	recordType, ok := recordTypes[strings.ToLower(queryType)]
	if !ok || a == nil || a.Header().Rrtype != recordType {
		return "", fmt.Errorf("invalid record type")
	}
	return a.String(), nil
}

// recordTypes maps the lower-case record type names this package accepts to
// their DNS type codes.
var recordTypes = map[string]uint16{
	"a":          dns.TypeA,
	"aaaa":       dns.TypeAAAA,
	"afsdb":      dns.TypeAFSDB,
	"apl":        dns.TypeAPL,
	"caa":        dns.TypeCAA,
	"cdnskey":    dns.TypeCDNSKEY,
	"cds":        dns.TypeCDS,
	"cert":       dns.TypeCERT,
	"cname":      dns.TypeCNAME,
	"csync":      dns.TypeCSYNC,
	"dhcid":      dns.TypeDHCID,
	"dlv":        dns.TypeDLV,
	"dname":      dns.TypeDNAME,
	"dnskey":     dns.TypeDNSKEY,
	"ds":         dns.TypeDS,
	"eui48":      dns.TypeEUI48,
	"eui64":      dns.TypeEUI64,
	"hinfo":      dns.TypeHINFO,
	"hip":        dns.TypeHIP,
	"https":      dns.TypeHTTPS,
	"ipseckey":   dns.TypeIPSECKEY,
	"key":        dns.TypeKEY,
	"kx":         dns.TypeKX,
	"loc":        dns.TypeLOC,
	"mx":         dns.TypeMX,
	"naptr":      dns.TypeNAPTR,
	"ns":         dns.TypeNS,
	"nsec":       dns.TypeNSEC,
	"nsec3":      dns.TypeNSEC3,
	"nsec3param": dns.TypeNSEC3PARAM,
	"openpgpkey": dns.TypeOPENPGPKEY,
	"ptr":        dns.TypePTR,
	"rrsig":      dns.TypeRRSIG,
	"rp":         dns.TypeRP,
	"sig":        dns.TypeSIG,
	"smimea":     dns.TypeSMIMEA,
	"soa":        dns.TypeSOA,
	"srv":        dns.TypeSRV,
	"sshfp":      dns.TypeSSHFP,
	"svcb":       dns.TypeSVCB,
	"ta":         dns.TypeTA,
	"tkey":       dns.TypeTKEY,
	"tlsa":       dns.TypeTLSA,
	"tsig":       dns.TypeTSIG,
	"txt":        dns.TypeTXT,
	"uri":        dns.TypeURI,
	"zonemd":     dns.TypeZONEMD,
}

// getRecordType maps a case-insensitive record type name to its DNS type
// code, or returns an "invalid record type" error for unsupported names.
func getRecordType(queryType string) (uint16, error) {
	if recordType, ok := recordTypes[strings.ToLower(queryType)]; ok {
		return recordType, nil
	}
	return uint16(0), fmt.Errorf("invalid record type")
}