spf-to-ip4-list/spf-to-ip4-list.go

116 lines
2.8 KiB
Go

package main
import (
"fmt"
"net"
"os"
"regexp"
"strings"
)
// getSPFRecord looks up the TXT records published at domain and returns the
// first one that is an SPF version 1 record (see isSPFRecord).
//
// It returns an error if the TXT lookup fails (the resolver error is wrapped,
// so errors.As can still reach a *net.DNSError) or if none of the TXT records
// is an SPF record.
func getSPFRecord(domain string) (string, error) {
	txtRecords, err := net.LookupTXT(domain)
	if err != nil {
		return "", fmt.Errorf("error looking up TXT records for %s: %w", domain, err)
	}
	for _, record := range txtRecords {
		if isSPFRecord(record) {
			return record, nil
		}
	}
	return "", fmt.Errorf("no SPF record found for %s", domain)
}

// isSPFRecord reports whether a TXT record is an SPF version 1 record: its
// version section must be exactly "v=spf1", terminated by a space or by the
// end of the record (RFC 7208 §4.5). A bare prefix test would also accept
// unrelated records such as "v=spf10 ...".
func isSPFRecord(record string) bool {
	const version = "v=spf1"
	return record == version || strings.HasPrefix(record, version+" ")
}
// spfPatterns maps each key of the collectSpfRecordDetails result to the
// regular expression whose first capture group extracts that key's values
// from an SPF record. The patterns are compiled once, at program start.
//
// Each pattern must start at the beginning of a whitespace-separated term,
// optionally preceded by a +, -, ~ or ? qualifier. This keeps, for example,
// the "a:1" inside "ip6:2001:db8:a:1::/64" from being read as an a mechanism.
// Matching is case-insensitive because SPF mechanism and modifier names are,
// and because domain names are often published in mixed case. Without (?i),
// "include:_spf.Example.com" would be truncated to "_spf.".
var spfPatterns = map[string]*regexp.Regexp{
	"ipv4":      regexp.MustCompile(`(?i)(?:^|\s)[+\-~?]?ip4:([\d./]+)`),
	"ipv6":      regexp.MustCompile(`(?i)(?:^|\s)[+\-~?]?ip6:([a-f0-9:/]+)`),
	"includes":  regexp.MustCompile(`(?i)(?:^|\s)[+\-~?]?include:([a-z.\-_0-9]+)`),
	"redirects": regexp.MustCompile(`(?i)(?:^|\s)redirect=([a-z.\-_0-9]+)`),
	"a":         regexp.MustCompile(`(?i)(?:^|\s)[+\-~?]?a:([a-z.\-_0-9]+)`),
}

// collectSpfRecordDetails fetches domain's SPF record and splits it into
// the values of its ip4 ("ipv4"), ip6 ("ipv6"), include ("includes"),
// redirect= ("redirects") and a: ("a") terms. Values keep their order of
// appearance in the record. Keys with no matching terms are absent from the
// map. Only the record of domain itself is read; includes and redirects are
// not followed.
//
// It returns the error from getSPFRecord if the record cannot be retrieved.
func collectSpfRecordDetails(domain string) (map[string][]string, error) {
	spfRecord, err := getSPFRecord(domain)
	if err != nil {
		return nil, err
	}
	return parseSPFRecord(spfRecord), nil
}

// parseSPFRecord applies every spfPatterns entry to record and collects each
// match's captured value under that entry's key.
func parseSPFRecord(record string) map[string][]string {
	result := make(map[string][]string)
	for key, re := range spfPatterns {
		for _, match := range re.FindAllStringSubmatch(record, -1) {
			result[key] = append(result[key], match[1])
		}
	}
	return result
}
// collectAllItemsRecursively returns the itemType values (a key of the
// collectSpfRecordDetails result, e.g. "includes" or "redirects") found in
// domain's SPF record. It then recursively adds the values found by reading
// each of those values as a domain.
//
// Every distinct value is listed and expanded once. Values are compared
// case-insensitively and ignoring a trailing dot. The starting domain counts
// as already seen. This makes include/redirect cycles (a.example includes
// b.example, which includes a.example, or a domain including itself)
// terminate; without it the recursion never ends. Domains reached along
// several paths are also not re-fetched or re-listed.
//
// Any failure to fetch a record aborts the walk and returns that error.
func collectAllItemsRecursively(domain, itemType string) ([]string, error) {
	seen := map[string]bool{strings.ToLower(strings.TrimSuffix(domain, ".")): true}
	return collectItemsOnce(domain, itemType, seen)
}

// collectItemsOnce is the recursive worker for collectAllItemsRecursively.
// seen holds the normalized names of every value already listed, and is
// updated as new values are found.
func collectItemsOnce(domain, itemType string, seen map[string]bool) ([]string, error) {
	items, err := collectSpfRecordDetails(domain)
	if err != nil {
		return nil, err
	}

	// List this record's new values first, then each one's descendants,
	// so the output keeps the "direct items before nested items" order.
	var fresh []string
	for _, item := range items[itemType] {
		key := strings.ToLower(strings.TrimSuffix(item, "."))
		if seen[key] {
			continue
		}
		seen[key] = true
		fresh = append(fresh, item)
	}

	allItems := append([]string(nil), fresh...)
	for _, item := range fresh {
		subItems, err := collectItemsOnce(item, itemType, seen)
		if err != nil {
			return nil, err
		}
		allItems = append(allItems, subItems...)
	}
	return allItems, nil
}
func getAllSpfIPs(domain, ipVersion string) ([]string, error) {
redirects, err := collectAllItemsRecursively(domain, "redirects")
if err != nil {
return nil, err
}
includes, err := collectAllItemsRecursively(domain, "includes")
if err != nil {
return nil, err
}
items, err := collectSpfRecordDetails(domain)
if err != nil {
return nil, err
}
allIPs := items[ipVersion]
for _, item := range append(redirects, includes...) {
subIPs, err := getAllSpfIPs(item, ipVersion)
if err != nil {
return nil, err
}
allIPs = append(allIPs, subIPs...)
}
return allIPs, nil
}
// main prints, one per line, every IPv4 range reachable from the SPF record
// of the domain given as the first command-line argument (see getAllSpfIPs).
//
// Usage and lookup errors are written to stderr and the process exits
// non-zero (2 for bad usage, 1 for a lookup failure). This keeps error text
// out of stdout, and lets scripts consuming the address list detect failure.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "Usage: spf_ips <domain>")
		os.Exit(2)
	}
	domain := os.Args[1]
	ipv4s, err := getAllSpfIPs(domain, "ipv4")
	if err != nil {
		fmt.Fprintln(os.Stderr, "Error:", err)
		os.Exit(1)
	}
	for _, ip := range ipv4s {
		fmt.Println(ip)
	}
}