diff --git a/internal/server/server.go b/internal/server/server.go index 32663ce..d2b156e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -47,9 +47,7 @@ type updateRequest struct { } // NewRouter constructs the HTTP handler with routes, middleware, logging and metrics. -func NewRouter(cfg *config.Config) *gin.Engine { - logger, _ := zap.NewProduction() - +func NewRouter(cfg *config.Config, logger *zap.Logger, prov pvd.Provider) *gin.Engine { r := gin.New() r.Use(gin.Recovery()) r.Use(func(c *gin.Context) { @@ -106,12 +104,7 @@ func NewRouter(cfg *config.Config) *gin.Engine { c.JSON(http.StatusForbidden, gin.H{"status": "error", "message": "host not authorized"}) return } - prov, err := selectProvider(cfg) - if err != nil { - logger.Error("provider selection failed", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "message": "provider configuration error"}) - return - } + // Provider is now initialized at startup and passed in if err := prov.UpdateRecord(c.Request.Context(), req.Host, ip); err != nil { failedUpdates.Inc() logger.Error("update record failed", zap.Error(err)) @@ -127,39 +120,29 @@ func NewRouter(cfg *config.Config) *gin.Engine { // StartServer initializes Provider and starts the HTTP server. func StartServer(cfg *config.Config) error { - // Provider selection happens within the request handler now to handle potential config errors per request - // We could pre-validate the provider config here, but deferring allows checking file existence/permissions closer to use. - // A simple check that the provider *name* is supported is still useful. - if _, supported := configToProviderName(cfg.Upstream.Provider); !supported { - return fmt.Errorf("unsupported provider name in config: %q", cfg.Upstream.Provider) - } + logger, _ := zap.NewProduction() // Initialize logger once - router := NewRouter(cfg) + // Initialize provider at startup + prov, err := selectProvider(cfg) + if err != nil { + logger.Error("failed to initialize provider", zap.Error(err)) + return fmt.Errorf("provider initialization failed: %w", err) + } + logger.Info("provider initialized successfully", zap.String("provider", cfg.Upstream.Provider)) + + router := NewRouter(cfg, logger, prov) // Pass logger and provider if cfg.Server.TLS.Enabled { + logger.Info("starting TLS server", zap.String("address", cfg.Server.BindAddress)) return router.RunTLS(cfg.Server.BindAddress, cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile) } + logger.Info("starting HTTP server", zap.String("address", cfg.Server.BindAddress)) return router.Run(cfg.Server.BindAddress) } -// configToProviderName checks if a provider name from the config is known. -// This is a simple check before attempting full provider initialization. -func configToProviderName(providerName string) (string, bool) { - switch providerName { - case "hetzner": - return "hetzner", true - default: - return "", false - } -} - // selectProvider returns the configured Provider or an error if initialization fails. func selectProvider(cfg *config.Config) (pvd.Provider, error) { - providerName, supported := configToProviderName(cfg.Upstream.Provider) - if !supported { - return nil, fmt.Errorf("unsupported provider: %s", cfg.Upstream.Provider) - } - - switch providerName { + // configToProviderName logic is effectively duplicated here, safe to remove the separate function if only used here. + switch cfg.Upstream.Provider { case "hetzner": prov, err := hetzner.NewProvider(cfg.Upstream.Hetzner) if err != nil { @@ -168,6 +151,6 @@ func selectProvider(cfg *config.Config) (pvd.Provider, error) { return prov, nil default: // This case should technically be unreachable due to the check above - return nil, fmt.Errorf("internal error: unsupported provider %s passed initial check", providerName) + return nil, fmt.Errorf("internal error: unsupported provider %s passed initial check", cfg.Upstream.Provider) } } diff --git a/internal/server/server_router_test.go b/internal/server/server_router_test.go index 17fcbe1..063ac37 100644 --- a/internal/server/server_router_test.go +++ b/internal/server/server_router_test.go @@ -2,15 +2,27 @@ package server_test import ( "bytes" + "context" // Added for mock provider "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" - "git.cloonar.com/cloonar/updns/internal/config" + "git.cloonar.com/cloonar/updns/internal/config" // Added for mock provider "git.cloonar.com/cloonar/updns/internal/server" + "go.uber.org/zap" // Added for logger ) +// mockProvider is a simple mock for testing router logic without real provider interaction. +type mockProvider struct { + updateErr error // Allows simulating update errors +} + +func (m *mockProvider) UpdateRecord(ctx context.Context, host, ip string) error { + return m.updateErr +} + func newTestConfig(provider string) *config.Config { return &config.Config{ Server: config.ServerConfig{ @@ -32,7 +44,9 @@ func newTestConfig(provider string) *config.Config { } func TestMetricsEndpoint(t *testing.T) { - r := server.NewRouter(newTestConfig("unknown")) + logger := zap.NewNop() + mockProv := &mockProvider{} + r := server.NewRouter(newTestConfig("unknown"), logger, mockProv) req := httptest.NewRequest("GET", "/metrics", nil) w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -42,7 +56,9 @@ func TestMetricsEndpoint(t *testing.T) { } func TestUpdateInvalidJSON(t *testing.T) { - r := server.NewRouter(newTestConfig("unknown")) + logger := zap.NewNop() + mockProv := &mockProvider{} + r := server.NewRouter(newTestConfig("unknown"), logger, mockProv) req := httptest.NewRequest("POST", "/update", bytes.NewBufferString("{invalid")) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() @@ -53,7 +69,9 @@ func TestUpdateInvalidJSON(t *testing.T) { } func TestUpdateUnauthorizedKey(t *testing.T) { - r := server.NewRouter(newTestConfig("unknown")) + logger := zap.NewNop() + mockProv := &mockProvider{} + r := server.NewRouter(newTestConfig("unknown"), logger, mockProv) body := map[string]string{"key": "bad", "secret": "x", "host": "a.example.com"} data, _ := json.Marshal(body) req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data)) @@ -66,7 +84,9 @@ func TestUpdateUnauthorizedKey(t *testing.T) { } func TestUpdateHostForbidden(t *testing.T) { - r := server.NewRouter(newTestConfig("unknown")) + logger := zap.NewNop() + mockProv := &mockProvider{} + r := server.NewRouter(newTestConfig("unknown"), logger, mockProv) body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "bad.example.com"} data, _ := json.Marshal(body) req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data)) @@ -78,8 +98,16 @@ func TestUpdateHostForbidden(t *testing.T) { } } -func TestUpdateProviderNotConfigured(t *testing.T) { - r := server.NewRouter(newTestConfig("unknown")) +// Note: This test case name is slightly misleading now. The router itself doesn't care +// about the provider name from config anymore, as the provider is passed in. +// The provider initialization logic is tested separately or in StartServer tests. +// We keep the test to ensure the update path works correctly when auth passes. +// If the *mock* provider returned an error, we'd get 500. +func TestUpdateProviderInteraction(t *testing.T) { + logger := zap.NewNop() + // Simulate a provider error + mockProv := &mockProvider{updateErr: errors.New("simulated provider update error")} + r := server.NewRouter(newTestConfig("hetzner"), logger, mockProv) // Config provider name doesn't matter here body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "a.example.com"} data, _ := json.Marshal(body) req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data)) @@ -92,7 +120,9 @@ func TestUpdateProviderNotConfigured(t *testing.T) { } func TestUpdateSuccess(t *testing.T) { - r := server.NewRouter(newTestConfig("hetzner")) + logger := zap.NewNop() + mockProv := &mockProvider{} // No error for success case + r := server.NewRouter(newTestConfig("hetzner"), logger, mockProv) body := map[string]string{"key": "client1", "secret": "s3cr3t", "host": "a.example.com", "ip": "1.2.3.4"} data, _ := json.Marshal(body) req := httptest.NewRequest("POST", "/update", bytes.NewBuffer(data))