From 4819f9256919d5acbde54c784da6bb69f22d3e0e Mon Sep 17 00:00:00 2001 From: Dominik Polakovics Date: Fri, 25 Apr 2025 21:24:59 +0200 Subject: [PATCH] feat: add posibility to use token file in hetzner config --- example-config.yaml | 5 +- internal/config/config.go | 3 +- internal/provider/hetzner/hetzner.go | 147 +++++++++++------- internal/provider/hetzner/hetzner_api_test.go | 61 ++++++-- internal/provider/hetzner/hetzner_test.go | 8 +- internal/server/server.go | 48 ++++-- 6 files changed, 189 insertions(+), 83 deletions(-) diff --git a/example-config.yaml b/example-config.yaml index 19605fc..d98cbdd 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -7,7 +7,10 @@ server: upstream: provider: hetzner hetzner: - api_token: "YOUR_HETZNER_API_TOKEN" + # Provide the API token directly + # api_token: "YOUR_HETZNER_API_TOKEN" + # OR provide the path to a file containing the token + api_token_file: "/path/to/your/hetzner_token.txt" clients: client1: secret: "s3cr3t123" diff --git a/internal/config/config.go b/internal/config/config.go index 2c8013e..8035977 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,7 +24,8 @@ type TLSConfig struct { } type HetznerConfig struct { - APIToken string `mapstructure:"api_token"` + APIToken string `mapstructure:"api_token"` + APITokenFile string `mapstructure:"api_token_file"` } type UpstreamConfig struct { diff --git a/internal/provider/hetzner/hetzner.go b/internal/provider/hetzner/hetzner.go index 72476f5..c155c0c 100644 --- a/internal/provider/hetzner/hetzner.go +++ b/internal/provider/hetzner/hetzner.go @@ -6,8 +6,10 @@ import ( "encoding/json" "fmt" "net/http" + "os" "strings" + "git.cloonar.com/cloonar/updns/internal/config" pvd "git.cloonar.com/cloonar/updns/internal/provider" ) @@ -33,8 +35,8 @@ type record struct { Name string `json:"name"` Value string `json:"value"` Type string `json:"type"` - ZoneID string `json:"zone_id"` - TTL int `json:"ttl"` + ZoneID string `json:"zone_id"` + TTL int `json:"ttl"` } type recordsResponse struct { @@ -42,12 +44,47 @@ type recordsResponse struct { } // NewProvider creates a Hetzner DNS provider using the official API. -func NewProvider(token string) pvd.Provider { - return &provider{token: token, client: http.DefaultClient, apiBaseURL: defaultAPIBase} +func NewProvider(cfg config.HetznerConfig) (pvd.Provider, error) { + var token string + + hasToken := cfg.APIToken != "" + hasTokenFile := cfg.APITokenFile != "" + + if hasToken && hasTokenFile { + return nil, fmt.Errorf("hetzner config: provide api_token or api_token_file, not both") + } + if !hasToken && !hasTokenFile { + return nil, fmt.Errorf("hetzner config: api_token or api_token_file must be provided") + } + + if hasTokenFile { + tokenBytes, err := os.ReadFile(cfg.APITokenFile) + if err != nil { + return nil, fmt.Errorf("reading hetzner token file %q: %w", cfg.APITokenFile, err) + } + token = strings.TrimSpace(string(tokenBytes)) + } else { + token = cfg.APIToken + } + + if token == "" { + // This case might happen if the file exists but is empty + return nil, fmt.Errorf("hetzner api token is empty") + } + + return &provider{token: token, client: http.DefaultClient, apiBaseURL: defaultAPIBase}, nil } // NewProviderWithURL creates a Hetzner provider with a custom API base URL (for testing). +// Note: This testing helper still requires a direct token string. func NewProviderWithURL(token, apiBaseURL string) pvd.Provider { + // Basic validation for the test helper + if token == "" { + panic("NewProviderWithURL requires a non-empty token for testing") + } + if apiBaseURL == "" { + panic("NewProviderWithURL requires a non-empty apiBaseURL for testing") + } return &provider{token: token, client: http.DefaultClient, apiBaseURL: apiBaseURL} } @@ -64,7 +101,7 @@ func (p *provider) UpdateRecord(ctx context.Context, domain, ip string) error { } else { zoneName = strings.Join(parts[len(parts)-2:], ".") } - subdomain := strings.Join(parts[:len(parts)-2], ".") + subdomain := strings.Join(parts[:len(parts)-2], ".") // Fetch zone ID zonesURL := fmt.Sprintf("%s/zones?name=%s", p.apiBaseURL, zoneName) req, _ := http.NewRequestWithContext(ctx, http.MethodGet, zonesURL, nil) @@ -111,57 +148,57 @@ func (p *provider) UpdateRecord(ctx context.Context, domain, ip string) error { } if recID == "" { // return fmt.Errorf("record %s not found", domain) - // Create new record - // Cut the last 2 parts of the domain name - createURL := fmt.Sprintf("%s/records", p.apiBaseURL) - body := record{ - Name: subdomain, - Type: "A", - Value: ip, - TTL: 60, - ZoneID: zoneID, - } - buf := &bytes.Buffer{} - if err := json.NewEncoder(buf).Encode(body); err != nil { - return fmt.Errorf("encode create body: %w", err) - } - req, _ = http.NewRequestWithContext(ctx, http.MethodPost, createURL, buf) - req.Header.Set("Auth-API-Token", p.token) - req.Header.Set("Content-Type", "application/json") - resp, err = p.client.Do(req) - if err != nil { - return fmt.Errorf("create record: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("create API status: %s", resp.Status) - } + // Create new record + // Cut the last 2 parts of the domain name + createURL := fmt.Sprintf("%s/records", p.apiBaseURL) + body := record{ + Name: subdomain, + Type: "A", + Value: ip, + TTL: 60, + ZoneID: zoneID, + } + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(body); err != nil { + return fmt.Errorf("encode create body: %w", err) + } + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, createURL, buf) + req.Header.Set("Auth-API-Token", p.token) + req.Header.Set("Content-Type", "application/json") + resp, err = p.client.Do(req) + if err != nil { + return fmt.Errorf("create record: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("create API status: %s", resp.Status) + } } else { - // Update record value - updateURL := fmt.Sprintf("%s/records/%s", p.apiBaseURL, recID) - body := record{ - Name: subdomain, - Type: "A", - Value: ip, - TTL: 60, - ZoneID: zoneID, - } - buf := &bytes.Buffer{} - if err := json.NewEncoder(buf).Encode(body); err != nil { - return fmt.Errorf("encode update body: %w", err) - } - req, _ = http.NewRequestWithContext(ctx, http.MethodPut, updateURL, buf) - req.Header.Set("Auth-API-Token", p.token) - req.Header.Set("Content-Type", "application/json") - resp, err = p.client.Do(req) - if err != nil { - return fmt.Errorf("update record: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("update API status: %s", resp.Status) - } - } + // Update record value + updateURL := fmt.Sprintf("%s/records/%s", p.apiBaseURL, recID) + body := record{ + Name: subdomain, + Type: "A", + Value: ip, + TTL: 60, + ZoneID: zoneID, + } + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(body); err != nil { + return fmt.Errorf("encode update body: %w", err) + } + req, _ = http.NewRequestWithContext(ctx, http.MethodPut, updateURL, buf) + req.Header.Set("Auth-API-Token", p.token) + req.Header.Set("Content-Type", "application/json") + resp, err = p.client.Do(req) + if err != nil { + return fmt.Errorf("update record: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("update API status: %s", resp.Status) + } + } return nil } diff --git a/internal/provider/hetzner/hetzner_api_test.go b/internal/provider/hetzner/hetzner_api_test.go index 29f9596..71e0c6d 100644 --- a/internal/provider/hetzner/hetzner_api_test.go +++ b/internal/provider/hetzner/hetzner_api_test.go @@ -68,24 +68,61 @@ func TestUpdateRecordZoneNotFound(t *testing.T) { } } -func TestUpdateRecordRecordNotFound(t *testing.T) { +func TestUpdateRecordRecordNotFoundCreates(t *testing.T) { + domain := "new.example.com" + ip := "1.1.1.1" + zoneName := "example.com" + zoneID := "zone-abc" + postCalled := false + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/zones": - json.NewEncoder(w).Encode(map[string]interface{}{ - "zones": []map[string]string{{"id": "z", "name": "example.com"}}, - }) - case "/records": - json.NewEncoder(w).Encode(map[string]interface{}{"records": []map[string]string{}}) + switch { + case r.Method == http.MethodGet && r.URL.Path == "/zones": + // Find zone + resp := map[string]interface{}{ + "zones": []map[string]string{{"id": zoneID, "name": zoneName}}, + } + json.NewEncoder(w).Encode(resp) + case r.Method == http.MethodGet && r.URL.Path == "/records": + // Find no records + resp := map[string]interface{}{ + "records": []map[string]string{}, + } + json.NewEncoder(w).Encode(resp) + case r.Method == http.MethodPost && r.URL.Path == "/records": + // Expect creation + postCalled = true + body, _ := io.ReadAll(r.Body) + var payload map[string]interface{} // Use interface{} for mixed types (TTL is int) + json.Unmarshal(body, &payload) + if payload["name"] != "new" { // Name should be the subdomain part + t.Errorf("expected create name 'new', got %s", payload["name"]) + } + if payload["value"] != ip { + t.Errorf("expected create value %s, got %s", ip, payload["value"]) + } + if payload["type"] != "A" { + t.Errorf("expected create type 'A', got %s", payload["type"]) + } + if payload["zone_id"] != zoneID { + t.Errorf("expected create zone_id %s, got %s", zoneID, payload["zone_id"]) + } + // Respond with created record details (optional, but good practice) + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]interface{}{"record": payload}) default: - http.Error(w, "not found", http.StatusNotFound) + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + http.Error(w, "unexpected request", http.StatusInternalServerError) } })) defer ts.Close() provider := hetzner.NewProviderWithURL("token", ts.URL) - err := provider.UpdateRecord(context.Background(), "missing.example.com", "1.1.1.1") - if err == nil || !strings.Contains(err.Error(), "record missing.example.com not found") { - t.Fatalf("expected record not found error, got %v", err) + err := provider.UpdateRecord(context.Background(), domain, ip) + if err != nil { + t.Fatalf("expected successful creation, but got error: %v", err) + } + if !postCalled { + t.Fatalf("expected POST /records to be called for creation, but it wasn't") } } diff --git a/internal/provider/hetzner/hetzner_test.go b/internal/provider/hetzner/hetzner_test.go index 5968f20..767d0bf 100644 --- a/internal/provider/hetzner/hetzner_test.go +++ b/internal/provider/hetzner/hetzner_test.go @@ -3,10 +3,16 @@ package hetzner_test import ( "testing" + "git.cloonar.com/cloonar/updns/internal/config" pvd "git.cloonar.com/cloonar/updns/internal/provider" "git.cloonar.com/cloonar/updns/internal/provider/hetzner" ) func TestNewProviderImplementsInterface(t *testing.T) { - var _ pvd.Provider = hetzner.NewProvider("token") + cfg := config.HetznerConfig{APIToken: "test-token"} + prov, err := hetzner.NewProvider(cfg) + if err != nil { + t.Fatalf("NewProvider failed: %v", err) + } + var _ pvd.Provider = prov } diff --git a/internal/server/server.go b/internal/server/server.go index 915922b..32663ce 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -106,9 +106,10 @@ func NewRouter(cfg *config.Config) *gin.Engine { c.JSON(http.StatusForbidden, gin.H{"status": "error", "message": "host not authorized"}) return } - prov, ok := selectProvider(cfg) - if !ok { - c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "message": "provider not configured"}) + prov, err := selectProvider(cfg) + if err != nil { + logger.Error("provider selection failed", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "message": "provider configuration error"}) return } if err := prov.UpdateRecord(c.Request.Context(), req.Host, ip); err != nil { @@ -126,12 +127,12 @@ func NewRouter(cfg *config.Config) *gin.Engine { // StartServer initializes Provider and starts the HTTP server. func StartServer(cfg *config.Config) error { - prov, ok := selectProvider(cfg) - if !ok { - return fmt.Errorf("unsupported provider: %s", cfg.Upstream.Provider) + // Provider selection happens within the request handler now to handle potential config errors per request + // We could pre-validate the provider config here, but deferring allows checking file existence/permissions closer to use. + // A simple check that the provider *name* is supported is still useful. + if _, supported := configToProviderName(cfg.Upstream.Provider); !supported { + return fmt.Errorf("unsupported provider name in config: %q", cfg.Upstream.Provider) } - // drop unused to avoid compile error - _ = prov router := NewRouter(cfg) if cfg.Server.TLS.Enabled { @@ -140,12 +141,33 @@ func StartServer(cfg *config.Config) error { return router.Run(cfg.Server.BindAddress) } -// selectProvider returns the configured Provider or false if unsupported. -func selectProvider(cfg *config.Config) (pvd.Provider, bool) { - switch cfg.Upstream.Provider { +// configToProviderName checks if a provider name from the config is known. +// This is a simple check before attempting full provider initialization. +func configToProviderName(providerName string) (string, bool) { + switch providerName { case "hetzner": - return hetzner.NewProvider(cfg.Upstream.Hetzner.APIToken), true + return "hetzner", true default: - return nil, false + return "", false + } +} + +// selectProvider returns the configured Provider or an error if initialization fails. +func selectProvider(cfg *config.Config) (pvd.Provider, error) { + providerName, supported := configToProviderName(cfg.Upstream.Provider) + if !supported { + return nil, fmt.Errorf("unsupported provider: %s", cfg.Upstream.Provider) + } + + switch providerName { + case "hetzner": + prov, err := hetzner.NewProvider(cfg.Upstream.Hetzner) + if err != nil { + return nil, fmt.Errorf("initializing hetzner provider: %w", err) + } + return prov, nil + default: + // This case should technically be unreachable due to the check above + return nil, fmt.Errorf("internal error: unsupported provider %s passed initial check", providerName) } }