feat: Implement configuration management and DNS provider integration
- Added configuration management using Viper in internal/config/config.go - Implemented ClientConfig, ServerConfig, TLSConfig, HetznerConfig, UpstreamConfig, and main Config struct. - Created LoadConfig function to read and validate configuration files. - Developed Hetzner DNS provider in internal/provider/hetzner/hetzner.go with methods for updating DNS records. - Added comprehensive unit tests for configuration loading and Hetzner provider functionality. - Established HTTP server with metrics and update endpoint in internal/server/server.go. - Implemented request handling, authorization, and error management in the server. - Created integration tests for the Hetzner provider API interactions. - Removed legacy dynamic DNS integration tests in favor of the new API-based approach.
This commit is contained in:
62
internal/config/config.go
Normal file
62
internal/config/config.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
Exact []string `mapstructure:"exact"`
|
||||
Wildcard []string `mapstructure:"wildcard"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
BindAddress string `mapstructure:"bind_address"`
|
||||
TLS TLSConfig `mapstructure:"tls"`
|
||||
}
|
||||
|
||||
type TLSConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
CertFile string `mapstructure:"cert_file"`
|
||||
KeyFile string `mapstructure:"key_file"`
|
||||
}
|
||||
|
||||
type HetznerConfig struct {
|
||||
APIToken string `mapstructure:"api_token"`
|
||||
}
|
||||
|
||||
type UpstreamConfig struct {
|
||||
Provider string `mapstructure:"provider"`
|
||||
Hetzner HetznerConfig `mapstructure:"hetzner"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Upstream UpstreamConfig `mapstructure:"upstream"`
|
||||
Clients map[string]ClientConfig `mapstructure:"clients"`
|
||||
}
|
||||
|
||||
// LoadConfig reads the file at path (yaml, json, toml) into Config and validates it.
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
v := viper.New()
|
||||
v.SetConfigFile(path)
|
||||
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
return nil, fmt.Errorf("reading config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("parsing config: %w", err)
|
||||
}
|
||||
|
||||
for name, client := range cfg.Clients {
|
||||
if len(client.Exact) == 0 && len(client.Wildcard) == 0 {
|
||||
return nil, fmt.Errorf("client %q must have at least one of exact or wildcard", name)
|
||||
}
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
64
internal/config/config_test.go
Normal file
64
internal/config/config_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.cloonar.com/cloonar/updns/internal/config"
|
||||
)
|
||||
|
||||
func TestLoadConfig_Success(t *testing.T) {
|
||||
content := `
|
||||
server:
|
||||
bind_address: ":9090"
|
||||
tls:
|
||||
enabled: true
|
||||
cert_file: "cert.pem"
|
||||
key_file: "key.pem"
|
||||
upstream:
|
||||
provider: hetzner
|
||||
hetzner:
|
||||
api_token: "token123"
|
||||
clients:
|
||||
clientA:
|
||||
secret: "sec"
|
||||
exact:
|
||||
- "foo.com"
|
||||
wildcard:
|
||||
- "bar.com"
|
||||
`
|
||||
tmp := filepath.Join(os.TempDir(), "config_test.yaml")
|
||||
if err := os.WriteFile(tmp, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tmp)
|
||||
|
||||
cfg, err := config.LoadConfig(tmp)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if cfg.Server.BindAddress != ":9090" {
|
||||
t.Errorf("expected bind_address :9090, got %s", cfg.Server.BindAddress)
|
||||
}
|
||||
if !cfg.Server.TLS.Enabled {
|
||||
t.Error("expected TLS enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_Failure(t *testing.T) {
|
||||
content := `
|
||||
clients:
|
||||
clientB:
|
||||
secret: "sec"
|
||||
`
|
||||
tmp := filepath.Join(os.TempDir(), "config_fail.yaml")
|
||||
if err := os.WriteFile(tmp, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tmp)
|
||||
|
||||
if _, err := config.LoadConfig(tmp); err == nil {
|
||||
t.Fatal("expected error for missing fields, got nil")
|
||||
}
|
||||
}
|
||||
132
internal/provider/hetzner/hetzner.go
Normal file
132
internal/provider/hetzner/hetzner.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package hetzner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
pvd "git.cloonar.com/cloonar/updns/internal/provider"
|
||||
)
|
||||
|
||||
const defaultAPIBase = "https://dns.hetzner.com"
|
||||
|
||||
type provider struct {
|
||||
token string
|
||||
client *http.Client
|
||||
apiBaseURL string
|
||||
}
|
||||
|
||||
type zone struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type zonesResponse struct {
|
||||
Zones []zone `json:"zones"`
|
||||
}
|
||||
|
||||
type record struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type recordsResponse struct {
|
||||
Records []record `json:"records"`
|
||||
}
|
||||
|
||||
// NewProvider creates a Hetzner DNS provider using the official API.
|
||||
func NewProvider(token string) pvd.Provider {
|
||||
return &provider{token: token, client: http.DefaultClient, apiBaseURL: defaultAPIBase}
|
||||
}
|
||||
|
||||
// NewProviderWithURL creates a Hetzner provider with a custom API base URL (for testing).
|
||||
func NewProviderWithURL(token, apiBaseURL string) pvd.Provider {
|
||||
return &provider{token: token, client: http.DefaultClient, apiBaseURL: apiBaseURL}
|
||||
}
|
||||
|
||||
// UpdateRecord updates the DNS record for the given domain to the provided IP.
|
||||
func (p *provider) UpdateRecord(ctx context.Context, domain, ip string) error {
|
||||
// Determine zone name (last two labels, or single label for two-part domains)
|
||||
parts := strings.Split(domain, ".")
|
||||
if len(parts) < 2 {
|
||||
return fmt.Errorf("invalid domain: %s", domain)
|
||||
}
|
||||
var zoneName string
|
||||
if len(parts) == 2 {
|
||||
zoneName = parts[1]
|
||||
} else {
|
||||
zoneName = strings.Join(parts[len(parts)-2:], ".")
|
||||
}
|
||||
// Fetch zone ID
|
||||
zonesURL := fmt.Sprintf("%s/zones?name=%s", p.apiBaseURL, zoneName)
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, zonesURL, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+p.token)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch zones: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("hetzner update failed with status: %s", resp.Status)
|
||||
}
|
||||
var zr zonesResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&zr); err != nil {
|
||||
return fmt.Errorf("parsing zones response: %w", err)
|
||||
}
|
||||
if len(zr.Zones) == 0 {
|
||||
return fmt.Errorf("zone %s not found", zoneName)
|
||||
}
|
||||
zoneID := zr.Zones[0].ID
|
||||
|
||||
// Fetch records in zone
|
||||
recsURL := fmt.Sprintf("%s/records?zone_id=%s", p.apiBaseURL, zoneID)
|
||||
req, _ = http.NewRequestWithContext(ctx, http.MethodGet, recsURL, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+p.token)
|
||||
resp, err = p.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch records: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("hetzner update failed with status: %s", resp.Status)
|
||||
}
|
||||
var rr recordsResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&rr); err != nil {
|
||||
return fmt.Errorf("parsing records response: %w", err)
|
||||
}
|
||||
var recID string
|
||||
for _, rec := range rr.Records {
|
||||
if rec.Name == domain {
|
||||
recID = rec.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
if recID == "" {
|
||||
return fmt.Errorf("record %s not found", domain)
|
||||
}
|
||||
|
||||
// Update record value
|
||||
updateURL := fmt.Sprintf("%s/records/%s", p.apiBaseURL, recID)
|
||||
body := map[string]string{"value": ip}
|
||||
buf := &bytes.Buffer{}
|
||||
if err := json.NewEncoder(buf).Encode(body); err != nil {
|
||||
return fmt.Errorf("encode update body: %w", err)
|
||||
}
|
||||
req, _ = http.NewRequestWithContext(ctx, http.MethodPut, updateURL, buf)
|
||||
req.Header.Set("Authorization", "Bearer "+p.token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err = p.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update record: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("update API status: %s", resp.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
91
internal/provider/hetzner/hetzner_api_test.go
Normal file
91
internal/provider/hetzner/hetzner_api_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package hetzner_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.cloonar.com/cloonar/updns/internal/provider/hetzner"
|
||||
)
|
||||
|
||||
func TestUpdateRecordFullAPILifecycle(t *testing.T) {
|
||||
domain := "test.example.com"
|
||||
ip := "1.2.3.4"
|
||||
zoneName := "example.com"
|
||||
zoneID := "zone-123"
|
||||
recID := "rec-456"
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/zones":
|
||||
// Query zones by name
|
||||
resp := map[string]interface{}{
|
||||
"zones": []map[string]string{{"id": zoneID, "name": zoneName}},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/records":
|
||||
// Query records by zone_id
|
||||
resp := map[string]interface{}{
|
||||
"records": []map[string]string{{"id": recID, "name": domain, "value": "0.0.0.0", "type": "A"}},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
case r.Method == http.MethodPut && r.URL.Path == "/records/"+recID:
|
||||
// Validate update payload
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var payload map[string]string
|
||||
json.Unmarshal(body, &payload)
|
||||
if payload["value"] != ip {
|
||||
t.Errorf("expected update value %s, got %s", ip, payload["value"])
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := hetzner.NewProviderWithURL("token", ts.URL)
|
||||
if err := provider.UpdateRecord(context.Background(), domain, ip); err != nil {
|
||||
t.Fatalf("full API lifecycle failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRecordZoneNotFound(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"zones": []map[string]string{}})
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := hetzner.NewProviderWithURL("token", ts.URL)
|
||||
err := provider.UpdateRecord(context.Background(), "nozone.example", "1.1.1.1")
|
||||
if err == nil || !strings.Contains(err.Error(), "zone example not found") {
|
||||
t.Fatalf("expected zone not found error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRecordRecordNotFound(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/zones":
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"zones": []map[string]string{{"id": "z", "name": "example.com"}},
|
||||
})
|
||||
case "/records":
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"records": []map[string]string{}})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := hetzner.NewProviderWithURL("token", ts.URL)
|
||||
err := provider.UpdateRecord(context.Background(), "missing.example.com", "1.1.1.1")
|
||||
if err == nil || !strings.Contains(err.Error(), "record missing.example.com not found") {
|
||||
t.Fatalf("expected record not found error, got %v", err)
|
||||
}
|
||||
}
|
||||
4
internal/provider/hetzner/hetzner_integration_test.go
Normal file
4
internal/provider/hetzner/hetzner_integration_test.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package hetzner_test
|
||||
|
||||
// Legacy dynamic DNS integration tests have been removed.
|
||||
// The Hetzner provider now uses the official DNS API; see hetzner_api_test.go for coverage.
|
||||
12
internal/provider/hetzner/hetzner_test.go
Normal file
12
internal/provider/hetzner/hetzner_test.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package hetzner_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
pvd "git.cloonar.com/cloonar/updns/internal/provider"
|
||||
"git.cloonar.com/cloonar/updns/internal/provider/hetzner"
|
||||
)
|
||||
|
||||
func TestNewProviderImplementsInterface(t *testing.T) {
|
||||
var _ pvd.Provider = hetzner.NewProvider("token")
|
||||
}
|
||||
8
internal/provider/provider.go
Normal file
8
internal/provider/provider.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package provider
|
||||
|
||||
import "context"
|
||||
|
||||
// Provider defines the interface for updating DNS records.
|
||||
type Provider interface {
|
||||
UpdateRecord(ctx context.Context, domain, ip string) error
|
||||
}
|
||||
22
internal/provider/provider_test.go
Normal file
22
internal/provider/provider_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package provider_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.cloonar.com/cloonar/updns/internal/provider"
|
||||
)
|
||||
|
||||
// mockProvider is a dummy implementation for testing the Provider interface.
|
||||
type mockProvider struct{}
|
||||
|
||||
func (m *mockProvider) UpdateRecord(ctx context.Context, domain, ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestProviderInterfaceCompliance(t *testing.T) {
|
||||
var p provider.Provider = &mockProvider{}
|
||||
if err := p.UpdateRecord(context.Background(), "example.com", "1.2.3.4"); err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
151
internal/server/server.go
Normal file
151
internal/server/server.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.cloonar.com/cloonar/updns/internal/config"
|
||||
pvd "git.cloonar.com/cloonar/updns/internal/provider"
|
||||
"git.cloonar.com/cloonar/updns/internal/provider/hetzner"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
totalUpdates = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "updns_total_updates",
|
||||
Help: "Total number of update requests",
|
||||
})
|
||||
successUpdates = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "updns_success_updates",
|
||||
Help: "Number of successful updates",
|
||||
})
|
||||
failedUpdates = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "updns_failed_updates",
|
||||
Help: "Number of failed updates",
|
||||
})
|
||||
exactAuth = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "updns_exact_auth",
|
||||
Help: "Number of updates authorized via exact match",
|
||||
})
|
||||
wildcardAuth = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "updns_wildcard_auth",
|
||||
Help: "Number of updates authorized via wildcard",
|
||||
})
|
||||
)
|
||||
|
||||
type updateRequest struct {
|
||||
Key string `json:"key" binding:"required"`
|
||||
Secret string `json:"secret" binding:"required"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
IP string `json:"ip"`
|
||||
}
|
||||
|
||||
// NewRouter constructs the HTTP handler with routes, middleware, logging and metrics.
|
||||
func NewRouter(cfg *config.Config) *gin.Engine {
|
||||
logger, _ := zap.NewProduction()
|
||||
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
c.Next()
|
||||
logger.Info("request",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.Int("status", c.Writer.Status()),
|
||||
zap.Duration("duration", time.Since(start)),
|
||||
)
|
||||
})
|
||||
|
||||
r.GET("/metrics", gin.WrapH(promhttp.Handler()))
|
||||
|
||||
r.POST("/update", func(c *gin.Context) {
|
||||
totalUpdates.Inc()
|
||||
var req updateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
failedUpdates.Inc()
|
||||
logger.Error("invalid request", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
ip := req.IP
|
||||
if ip == "" {
|
||||
ip = c.ClientIP()
|
||||
}
|
||||
clientCfg, ok := cfg.Clients[req.Key]
|
||||
if !ok || req.Secret != clientCfg.Secret {
|
||||
failedUpdates.Inc()
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"status": "error", "message": "invalid key or secret"})
|
||||
return
|
||||
}
|
||||
authorized := false
|
||||
for _, h := range clientCfg.Exact {
|
||||
if req.Host == h {
|
||||
authorized = true
|
||||
exactAuth.Inc()
|
||||
break
|
||||
}
|
||||
}
|
||||
if !authorized {
|
||||
for _, base := range clientCfg.Wildcard {
|
||||
if req.Host == base || strings.HasSuffix(req.Host, "."+base) {
|
||||
authorized = true
|
||||
wildcardAuth.Inc()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !authorized {
|
||||
failedUpdates.Inc()
|
||||
c.JSON(http.StatusForbidden, gin.H{"status": "error", "message": "host not authorized"})
|
||||
return
|
||||
}
|
||||
prov, ok := selectProvider(cfg)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "message": "provider not configured"})
|
||||
return
|
||||
}
|
||||
if err := prov.UpdateRecord(c.Request.Context(), req.Host, ip); err != nil {
|
||||
failedUpdates.Inc()
|
||||
logger.Error("update record failed", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "message": "update failed"})
|
||||
return
|
||||
}
|
||||
successUpdates.Inc()
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "message": "Record updated"})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// StartServer initializes Provider and starts the HTTP server.
|
||||
func StartServer(cfg *config.Config) error {
|
||||
prov, ok := selectProvider(cfg)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported provider: %s", cfg.Upstream.Provider)
|
||||
}
|
||||
// drop unused to avoid compile error
|
||||
_ = prov
|
||||
|
||||
router := NewRouter(cfg)
|
||||
if cfg.Server.TLS.Enabled {
|
||||
return router.RunTLS(cfg.Server.BindAddress, cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
|
||||
}
|
||||
return router.Run(cfg.Server.BindAddress)
|
||||
}
|
||||
|
||||
// selectProvider returns the configured Provider or false if unsupported.
|
||||
func selectProvider(cfg *config.Config) (pvd.Provider, bool) {
|
||||
switch cfg.Upstream.Provider {
|
||||
case "hetzner":
|
||||
return hetzner.NewProvider(cfg.Upstream.Hetzner.APIToken), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
105
internal/server/server_router_test.go
Normal file
105
internal/server/server_router_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.cloonar.com/cloonar/updns/internal/config"
|
||||
"git.cloonar.com/cloonar/updns/internal/server"
|
||||
)
|
||||
|
||||
func newTestConfig(provider string) *config.Config {
|
||||
return &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
BindAddress: ":0",
|
||||
TLS: config.TLSConfig{Enabled: false},
|
||||
},
|
||||
Upstream: config.UpstreamConfig{
|
||||
Provider: provider,
|
||||
Hetzner: config.HetznerConfig{APIToken: "token"},
|
||||
},
|
||||
Clients: map[string]config.ClientConfig{
|
||||
"client1": {
|
||||
Secret: "s3cr3t",
|
||||
Exact: []string{"a.example.com"},
|
||||
Wildcard: []string{"example.net"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsEndpoint(t *testing.T) {
|
||||
r := server.NewRouter(newTestConfig("unknown"))
|
||||
req := httptest.NewRequest("GET", "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 OK, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInvalidJSON(t *testing.T) {
|
||||
r := server.NewRouter(newTestConfig("unknown"))
|
||||
req := httptest.NewRequest("POST", "/update", bytes.NewBufferString("{invalid"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 BadRequest, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUnauthorizedKey(t *testing.T) {
|
||||
r := server.NewRouter(newTestConfig("unknown"))
|
||||
body := map[string]string{"key": "bad", "secret": "x", "host": "a.example.com"}
|
||||
data, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 Unauthorized, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHostForbidden(t *testing.T) {
|
||||
r := server.NewRouter(newTestConfig("unknown"))
|
||||
body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "bad.example.com"}
|
||||
data, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("expected 403 Forbidden, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateProviderNotConfigured(t *testing.T) {
|
||||
r := server.NewRouter(newTestConfig("unknown"))
|
||||
body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "a.example.com"}
|
||||
data, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 InternalServerError, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSuccess(t *testing.T) {
|
||||
r := server.NewRouter(newTestConfig("hetzner"))
|
||||
body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "a.example.com", "ip": "1.2.3.4"}
|
||||
data, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 OK, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
31
internal/server/server_test.go
Normal file
31
internal/server/server_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.cloonar.com/cloonar/updns/internal/config"
|
||||
"git.cloonar.com/cloonar/updns/internal/server"
|
||||
)
|
||||
|
||||
func TestStartServerUnsupportedProvider(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
BindAddress: "127.0.0.1:0",
|
||||
TLS: config.TLSConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
Upstream: config.UpstreamConfig{
|
||||
Provider: "unknown",
|
||||
},
|
||||
Clients: map[string]config.ClientConfig{},
|
||||
}
|
||||
err := server.StartServer(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unsupported provider, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unsupported provider") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user