2024-01-14 18:40:47 -06:00

432 lines
9.0 KiB
Go

package nsupdate
import (
"bytes"
"fmt"
"net"
"strconv"
"strings"
"os/exec"
"github.com/miekg/dns"
"istheinternetonfire.app/internal/config"
)
// New returns a zero-value NsUpdateStruct.
func New() NsUpdateStruct {
	var n NsUpdateStruct
	return n
}
// Update makes the recordType record for record hold value.
//
// The current value is looked up on c.Server:c.Port first:
//   - if the lookup returns any error, the record is created;
//   - if the lookup result equals value, nothing is sent;
//   - otherwise the record is deleted and then created again.
//
// The first error returned by Delete or Create is passed back to the caller.
func (c NsUpdateStruct) Update(record, recordType, value string) error {
	current, lookupErr := getRecord(c.Server, c.Port, recordType, record)

	switch {
	case lookupErr != nil:
		// Every lookup failure is handled as "record does not exist yet".
		config.Cfg.Log.Debug("creating record", "record", record)
		return c.Create(record, recordType, value)
	case current == value:
		config.Cfg.Log.Debug("no update necessary")
		return nil
	}

	// The stored value differs: replace it with a delete followed by an add.
	config.Cfg.Log.Debug("deleting record", "record", record)
	if delErr := c.Delete(record, recordType); delErr != nil {
		return delErr
	}
	config.Cfg.Log.Debug("creating record", "record", record)
	return c.Create(record, recordType, value)
}
// Delete asks the server to remove the recordType records owned by record by
// sending an "update delete" command through nsupdate.
//
// record is given without its trailing dot; recordType is upper-cased before
// use. When nsupdate fails, its output is logged and the error from running
// the command is returned.
func (c NsUpdateStruct) Delete(record, recordType string) error {
	update := fmt.Sprintf("update delete %s. 30 IN %s", record, strings.ToUpper(recordType))
	stdout, stderr, err := c.runNsupdate(update)
	if err != nil {
		config.Cfg.Log.Error("error deleting record", "error", err, "stderr", stderr, "stdout", stdout)
		return err
	}
	return nil
}

// Create asks the server to add a recordType record for record holding value,
// with a TTL of 30 seconds, by sending an "update add" command through
// nsupdate.
//
// record is given without its trailing dot; recordType is upper-cased before
// use. When nsupdate fails, its output is logged and the error from running
// the command is returned.
func (c NsUpdateStruct) Create(record, recordType, value string) error {
	update := fmt.Sprintf("update add %s. 30 IN %s %s", record, strings.ToUpper(recordType), value)
	stdout, stderr, err := c.runNsupdate(update)
	if err != nil {
		config.Cfg.Log.Error("error adding record", "error", err, "stderr", stderr, "stdout", stdout)
		return err
	}
	return nil
}

// runNsupdate sends a single update command, signed with the key named
// "update" (c.Key.Algorithm / c.Key.Secret), using the nsupdate binary found
// on PATH. It returns what the process wrote to stdout and stderr together
// with the error from running it.
//
// The commands are written to nsupdate's standard input rather than spliced
// into a shell script. No shell ever interprets "$", backticks or backslashes
// in record data, no bash-only syntax is required, and the TSIG secret does
// not appear in a process argument list.
func (c NsUpdateStruct) runNsupdate(update string) (string, string, error) {
	server := fmt.Sprintf("server %s", c.Server)
	if c.Port > 0 {
		// Send the update to the same port Update uses for its lookup. A
		// non-positive port keeps nsupdate's default.
		server += " " + strconv.Itoa(c.Port)
	}

	lines := []string{
		server,
		fmt.Sprintf("key %s:update %s", c.Key.Algorithm, c.Key.Secret),
		update,
		"send",
	}
	// nsupdate reads one command per line, so an embedded line break would
	// smuggle in an extra command. The offending text is not echoed because
	// it may contain the secret.
	for _, line := range lines {
		if strings.ContainsAny(line, "\r\n") {
			return "", "", fmt.Errorf("nsupdate input must not contain line breaks")
		}
	}

	var stdout, stderr bytes.Buffer
	// -v is kept from the previous invocation (nsupdate: use TCP).
	cmd := exec.Command("nsupdate", "-v")
	cmd.Stdin = strings.NewReader(strings.Join(lines, "\n") + "\n")
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	err := cmd.Run()
	return stdout.String(), stderr.String(), err
}
// getRecord asks the DNS server at server:port for the queryType records of
// query and returns the first answer that getRecordFromResult accepts for
// that type.
//
// query may be given with or without a trailing dot. An error is returned
// when queryType is not supported, the exchange fails, the server answers
// with a non-success rcode, or no answer of the requested type is present.
func getRecord(server string, port int, queryType, query string) (string, error) {
	config.Cfg.Log.Debug("looking up dns record", "query", query, "type", queryType)

	recordType, err := getRecordType(queryType)
	if err != nil {
		return "", err
	}

	m := new(dns.Msg)
	// dns.Fqdn only appends the root dot when it is missing, so names that
	// are already fully qualified are not turned into "name..".
	m.SetQuestion(dns.Fqdn(query), recordType)
	m.RecursionDesired = true

	c := new(dns.Client)
	r, _, err := c.Exchange(m, net.JoinHostPort(server, strconv.Itoa(port)))
	if err != nil {
		return "", err
	}
	if r.Rcode != dns.RcodeSuccess {
		// err is nil at this point; build a real error so callers do not
		// mistake a failed query for an empty record.
		return "", fmt.Errorf("dns query for %q failed: %s", query, dns.RcodeToString[r.Rcode])
	}

	// Skip answers of other types (for example a CNAME placed ahead of the
	// requested record) instead of failing on the first one.
	for _, a := range r.Answer {
		if res, err := getRecordFromResult(a, queryType); err == nil {
			return res, nil
		}
	}
	return "", fmt.Errorf("no result")
}
// getRecordFromResult returns the record data of a in presentation format
// (for an A record, just the address) when a's concrete type matches
// queryType. queryType is case-insensitive.
//
// Only the data is returned, not the owner/TTL/class/type header that
// a.String() starts with, so the result can be compared directly with the
// value handed to an "update add" command.
//
// An "invalid record type" error is returned when queryType is not one of the
// names listed below or when a is not of the requested type (including nil).
func getRecordFromResult(a dns.RR, queryType string) (string, error) {
	var ok bool
	switch strings.ToLower(queryType) {
	case "a":
		_, ok = a.(*dns.A)
	case "aaaa":
		_, ok = a.(*dns.AAAA)
	case "afsdb":
		_, ok = a.(*dns.AFSDB)
	case "apl":
		_, ok = a.(*dns.APL)
	case "caa":
		_, ok = a.(*dns.CAA)
	case "cdnskey":
		_, ok = a.(*dns.CDNSKEY)
	case "cds":
		_, ok = a.(*dns.CDS)
	case "cert":
		_, ok = a.(*dns.CERT)
	case "cname":
		_, ok = a.(*dns.CNAME)
	case "csync":
		_, ok = a.(*dns.CSYNC)
	case "dhcid":
		_, ok = a.(*dns.DHCID)
	case "dlv":
		_, ok = a.(*dns.DLV)
	case "dname":
		_, ok = a.(*dns.DNAME)
	case "dnskey":
		_, ok = a.(*dns.DNSKEY)
	case "ds":
		_, ok = a.(*dns.DS)
	case "eui48":
		_, ok = a.(*dns.EUI48)
	case "eui64":
		_, ok = a.(*dns.EUI64)
	case "hinfo":
		_, ok = a.(*dns.HINFO)
	case "hip":
		_, ok = a.(*dns.HIP)
	case "https":
		_, ok = a.(*dns.HTTPS)
	case "ipseckey":
		_, ok = a.(*dns.IPSECKEY)
	case "key":
		_, ok = a.(*dns.KEY)
	case "kx":
		_, ok = a.(*dns.KX)
	case "loc":
		_, ok = a.(*dns.LOC)
	case "mx":
		_, ok = a.(*dns.MX)
	case "naptr":
		_, ok = a.(*dns.NAPTR)
	case "ns":
		_, ok = a.(*dns.NS)
	case "nsec":
		_, ok = a.(*dns.NSEC)
	case "nsec3":
		_, ok = a.(*dns.NSEC3)
	case "nsec3param":
		_, ok = a.(*dns.NSEC3PARAM)
	case "openpgpkey":
		_, ok = a.(*dns.OPENPGPKEY)
	case "ptr":
		_, ok = a.(*dns.PTR)
	case "rrsig":
		_, ok = a.(*dns.RRSIG)
	case "rp":
		_, ok = a.(*dns.RP)
	case "sig":
		_, ok = a.(*dns.SIG)
	case "smimea":
		_, ok = a.(*dns.SMIMEA)
	case "soa":
		_, ok = a.(*dns.SOA)
	case "srv":
		_, ok = a.(*dns.SRV)
	case "sshfp":
		_, ok = a.(*dns.SSHFP)
	case "svcb":
		_, ok = a.(*dns.SVCB)
	case "ta":
		_, ok = a.(*dns.TA)
	case "tkey":
		_, ok = a.(*dns.TKEY)
	case "tlsa":
		_, ok = a.(*dns.TLSA)
	case "tsig":
		_, ok = a.(*dns.TSIG)
	case "txt":
		_, ok = a.(*dns.TXT)
	case "uri":
		_, ok = a.(*dns.URI)
	case "zonemd":
		_, ok = a.(*dns.ZONEMD)
	}
	if !ok {
		return "", fmt.Errorf("invalid record type")
	}
	// a.String() is the header followed by the record data; drop the header.
	return strings.TrimPrefix(a.String(), a.Header().String()), nil
}
// getRecordType translates a case-insensitive record type name such as "a"
// or "TXT" into the matching dns.Type* code. Names that are not listed below
// yield a zero code and an "invalid record type" error.
func getRecordType(queryType string) (uint16, error) {
	switch strings.ToLower(queryType) {
	case "a":
		return dns.TypeA, nil
	case "aaaa":
		return dns.TypeAAAA, nil
	case "afsdb":
		return dns.TypeAFSDB, nil
	case "apl":
		return dns.TypeAPL, nil
	case "caa":
		return dns.TypeCAA, nil
	case "cdnskey":
		return dns.TypeCDNSKEY, nil
	case "cds":
		return dns.TypeCDS, nil
	case "cert":
		return dns.TypeCERT, nil
	case "cname":
		return dns.TypeCNAME, nil
	case "csync":
		return dns.TypeCSYNC, nil
	case "dhcid":
		return dns.TypeDHCID, nil
	case "dlv":
		return dns.TypeDLV, nil
	case "dname":
		return dns.TypeDNAME, nil
	case "dnskey":
		return dns.TypeDNSKEY, nil
	case "ds":
		return dns.TypeDS, nil
	case "eui48":
		return dns.TypeEUI48, nil
	case "eui64":
		return dns.TypeEUI64, nil
	case "hinfo":
		return dns.TypeHINFO, nil
	case "hip":
		return dns.TypeHIP, nil
	case "https":
		return dns.TypeHTTPS, nil
	case "ipseckey":
		return dns.TypeIPSECKEY, nil
	case "key":
		return dns.TypeKEY, nil
	case "kx":
		return dns.TypeKX, nil
	case "loc":
		return dns.TypeLOC, nil
	case "mx":
		return dns.TypeMX, nil
	case "naptr":
		return dns.TypeNAPTR, nil
	case "ns":
		return dns.TypeNS, nil
	case "nsec":
		return dns.TypeNSEC, nil
	case "nsec3":
		return dns.TypeNSEC3, nil
	case "nsec3param":
		return dns.TypeNSEC3PARAM, nil
	case "openpgpkey":
		return dns.TypeOPENPGPKEY, nil
	case "ptr":
		return dns.TypePTR, nil
	case "rrsig":
		return dns.TypeRRSIG, nil
	case "rp":
		return dns.TypeRP, nil
	case "sig":
		return dns.TypeSIG, nil
	case "smimea":
		return dns.TypeSMIMEA, nil
	case "soa":
		return dns.TypeSOA, nil
	case "srv":
		return dns.TypeSRV, nil
	case "sshfp":
		return dns.TypeSSHFP, nil
	case "svcb":
		return dns.TypeSVCB, nil
	case "ta":
		return dns.TypeTA, nil
	case "tkey":
		return dns.TypeTKEY, nil
	case "tlsa":
		return dns.TypeTLSA, nil
	case "tsig":
		return dns.TypeTSIG, nil
	case "txt":
		return dns.TypeTXT, nil
	case "uri":
		return dns.TypeURI, nil
	case "zonemd":
		return dns.TypeZONEMD, nil
	}
	return 0, fmt.Errorf("invalid record type")
}