fix: initialize dns provider just once at startup
This commit is contained in:
@@ -47,9 +47,7 @@ type updateRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewRouter constructs the HTTP handler with routes, middleware, logging and metrics.
|
// NewRouter constructs the HTTP handler with routes, middleware, logging and metrics.
|
||||||
func NewRouter(cfg *config.Config) *gin.Engine {
|
func NewRouter(cfg *config.Config, logger *zap.Logger, prov pvd.Provider) *gin.Engine {
|
||||||
logger, _ := zap.NewProduction()
|
|
||||||
|
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
r.Use(gin.Recovery())
|
r.Use(gin.Recovery())
|
||||||
r.Use(func(c *gin.Context) {
|
r.Use(func(c *gin.Context) {
|
||||||
@@ -106,12 +104,7 @@ func NewRouter(cfg *config.Config) *gin.Engine {
|
|||||||
c.JSON(http.StatusForbidden, gin.H{"status": "error", "message": "host not authorized"})
|
c.JSON(http.StatusForbidden, gin.H{"status": "error", "message": "host not authorized"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
prov, err := selectProvider(cfg)
|
// Provider is now initialized at startup and passed in
|
||||||
if err != nil {
|
|
||||||
logger.Error("provider selection failed", zap.Error(err))
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "message": "provider configuration error"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := prov.UpdateRecord(c.Request.Context(), req.Host, ip); err != nil {
|
if err := prov.UpdateRecord(c.Request.Context(), req.Host, ip); err != nil {
|
||||||
failedUpdates.Inc()
|
failedUpdates.Inc()
|
||||||
logger.Error("update record failed", zap.Error(err))
|
logger.Error("update record failed", zap.Error(err))
|
||||||
@@ -127,39 +120,29 @@ func NewRouter(cfg *config.Config) *gin.Engine {
|
|||||||
|
|
||||||
// StartServer initializes Provider and starts the HTTP server.
|
// StartServer initializes Provider and starts the HTTP server.
|
||||||
func StartServer(cfg *config.Config) error {
|
func StartServer(cfg *config.Config) error {
|
||||||
// Provider selection happens within the request handler now to handle potential config errors per request
|
logger, _ := zap.NewProduction() // Initialize logger once
|
||||||
// We could pre-validate the provider config here, but deferring allows checking file existence/permissions closer to use.
|
|
||||||
// A simple check that the provider *name* is supported is still useful.
|
|
||||||
if _, supported := configToProviderName(cfg.Upstream.Provider); !supported {
|
|
||||||
return fmt.Errorf("unsupported provider name in config: %q", cfg.Upstream.Provider)
|
|
||||||
}
|
|
||||||
|
|
||||||
router := NewRouter(cfg)
|
// Initialize provider at startup
|
||||||
|
prov, err := selectProvider(cfg)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to initialize provider", zap.Error(err))
|
||||||
|
return fmt.Errorf("provider initialization failed: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("provider initialized successfully", zap.String("provider", cfg.Upstream.Provider))
|
||||||
|
|
||||||
|
router := NewRouter(cfg, logger, prov) // Pass logger and provider
|
||||||
if cfg.Server.TLS.Enabled {
|
if cfg.Server.TLS.Enabled {
|
||||||
|
logger.Info("starting TLS server", zap.String("address", cfg.Server.BindAddress))
|
||||||
return router.RunTLS(cfg.Server.BindAddress, cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
|
return router.RunTLS(cfg.Server.BindAddress, cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
|
||||||
}
|
}
|
||||||
|
logger.Info("starting HTTP server", zap.String("address", cfg.Server.BindAddress))
|
||||||
return router.Run(cfg.Server.BindAddress)
|
return router.Run(cfg.Server.BindAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// configToProviderName checks if a provider name from the config is known.
|
|
||||||
// This is a simple check before attempting full provider initialization.
|
|
||||||
func configToProviderName(providerName string) (string, bool) {
|
|
||||||
switch providerName {
|
|
||||||
case "hetzner":
|
|
||||||
return "hetzner", true
|
|
||||||
default:
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectProvider returns the configured Provider or an error if initialization fails.
|
// selectProvider returns the configured Provider or an error if initialization fails.
|
||||||
func selectProvider(cfg *config.Config) (pvd.Provider, error) {
|
func selectProvider(cfg *config.Config) (pvd.Provider, error) {
|
||||||
providerName, supported := configToProviderName(cfg.Upstream.Provider)
|
// configToProviderName logic is effectively duplicated here, safe to remove the separate function if only used here.
|
||||||
if !supported {
|
switch cfg.Upstream.Provider {
|
||||||
return nil, fmt.Errorf("unsupported provider: %s", cfg.Upstream.Provider)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch providerName {
|
|
||||||
case "hetzner":
|
case "hetzner":
|
||||||
prov, err := hetzner.NewProvider(cfg.Upstream.Hetzner)
|
prov, err := hetzner.NewProvider(cfg.Upstream.Hetzner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -168,6 +151,6 @@ func selectProvider(cfg *config.Config) (pvd.Provider, error) {
|
|||||||
return prov, nil
|
return prov, nil
|
||||||
default:
|
default:
|
||||||
// This case should technically be unreachable due to the check above
|
// This case should technically be unreachable due to the check above
|
||||||
return nil, fmt.Errorf("internal error: unsupported provider %s passed initial check", providerName)
|
return nil, fmt.Errorf("internal error: unsupported provider %s passed initial check", cfg.Upstream.Provider)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,27 @@ package server_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context" // Added for mock provider
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.cloonar.com/cloonar/updns/internal/config"
|
"git.cloonar.com/cloonar/updns/internal/config" // Added for mock provider
|
||||||
"git.cloonar.com/cloonar/updns/internal/server"
|
"git.cloonar.com/cloonar/updns/internal/server"
|
||||||
|
"go.uber.org/zap" // Added for logger
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// mockProvider is a simple mock for testing router logic without real provider interaction.
|
||||||
|
type mockProvider struct {
|
||||||
|
updateErr error // Allows simulating update errors
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProvider) UpdateRecord(ctx context.Context, host, ip string) error {
|
||||||
|
return m.updateErr
|
||||||
|
}
|
||||||
|
|
||||||
func newTestConfig(provider string) *config.Config {
|
func newTestConfig(provider string) *config.Config {
|
||||||
return &config.Config{
|
return &config.Config{
|
||||||
Server: config.ServerConfig{
|
Server: config.ServerConfig{
|
||||||
@@ -32,7 +44,9 @@ func newTestConfig(provider string) *config.Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMetricsEndpoint(t *testing.T) {
|
func TestMetricsEndpoint(t *testing.T) {
|
||||||
r := server.NewRouter(newTestConfig("unknown"))
|
logger := zap.NewNop()
|
||||||
|
mockProv := &mockProvider{}
|
||||||
|
r := server.NewRouter(newTestConfig("unknown"), logger, mockProv)
|
||||||
req := httptest.NewRequest("GET", "/metrics", nil)
|
req := httptest.NewRequest("GET", "/metrics", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
@@ -42,7 +56,9 @@ func TestMetricsEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateInvalidJSON(t *testing.T) {
|
func TestUpdateInvalidJSON(t *testing.T) {
|
||||||
r := server.NewRouter(newTestConfig("unknown"))
|
logger := zap.NewNop()
|
||||||
|
mockProv := &mockProvider{}
|
||||||
|
r := server.NewRouter(newTestConfig("unknown"), logger, mockProv)
|
||||||
req := httptest.NewRequest("POST", "/update", bytes.NewBufferString("{invalid"))
|
req := httptest.NewRequest("POST", "/update", bytes.NewBufferString("{invalid"))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -53,7 +69,9 @@ func TestUpdateInvalidJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateUnauthorizedKey(t *testing.T) {
|
func TestUpdateUnauthorizedKey(t *testing.T) {
|
||||||
r := server.NewRouter(newTestConfig("unknown"))
|
logger := zap.NewNop()
|
||||||
|
mockProv := &mockProvider{}
|
||||||
|
r := server.NewRouter(newTestConfig("unknown"), logger, mockProv)
|
||||||
body := map[string]string{"key": "bad", "secret": "x", "host": "a.example.com"}
|
body := map[string]string{"key": "bad", "secret": "x", "host": "a.example.com"}
|
||||||
data, _ := json.Marshal(body)
|
data, _ := json.Marshal(body)
|
||||||
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
||||||
@@ -66,7 +84,9 @@ func TestUpdateUnauthorizedKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateHostForbidden(t *testing.T) {
|
func TestUpdateHostForbidden(t *testing.T) {
|
||||||
r := server.NewRouter(newTestConfig("unknown"))
|
logger := zap.NewNop()
|
||||||
|
mockProv := &mockProvider{}
|
||||||
|
r := server.NewRouter(newTestConfig("unknown"), logger, mockProv)
|
||||||
body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "bad.example.com"}
|
body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "bad.example.com"}
|
||||||
data, _ := json.Marshal(body)
|
data, _ := json.Marshal(body)
|
||||||
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
||||||
@@ -78,8 +98,16 @@ func TestUpdateHostForbidden(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateProviderNotConfigured(t *testing.T) {
|
// Note: This test case name is slightly misleading now. The router itself doesn't care
|
||||||
r := server.NewRouter(newTestConfig("unknown"))
|
// about the provider name from config anymore, as the provider is passed in.
|
||||||
|
// The provider initialization logic is tested separately or in StartServer tests.
|
||||||
|
// We keep the test to ensure the update path works correctly when auth passes.
|
||||||
|
// If the *mock* provider returned an error, we'd get 500.
|
||||||
|
func TestUpdateProviderInteraction(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
// Simulate a provider error
|
||||||
|
mockProv := &mockProvider{updateErr: errors.New("simulated provider update error")}
|
||||||
|
r := server.NewRouter(newTestConfig("hetzner"), logger, mockProv) // Config provider name doesn't matter here
|
||||||
body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "a.example.com"}
|
body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "a.example.com"}
|
||||||
data, _ := json.Marshal(body)
|
data, _ := json.Marshal(body)
|
||||||
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
||||||
@@ -92,7 +120,9 @@ func TestUpdateProviderNotConfigured(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateSuccess(t *testing.T) {
|
func TestUpdateSuccess(t *testing.T) {
|
||||||
r := server.NewRouter(newTestConfig("hetzner"))
|
logger := zap.NewNop()
|
||||||
|
mockProv := &mockProvider{} // No error for success case
|
||||||
|
r := server.NewRouter(newTestConfig("hetzner"), logger, mockProv)
|
||||||
body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "a.example.com", "ip": "1.2.3.4"}
|
body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "a.example.com", "ip": "1.2.3.4"}
|
||||||
data, _ := json.Marshal(body)
|
data, _ := json.Marshal(body)
|
||||||
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))
|
||||||
|
|||||||
Reference in New Issue
Block a user