diff --git a/src/middleware/__tests__/auth.test.ts b/src/middleware/__tests__/auth.test.ts index 1d2e587..aa98403 100644 --- a/src/middleware/__tests__/auth.test.ts +++ b/src/middleware/__tests__/auth.test.ts @@ -1,29 +1,21 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' -import { authMiddleware } from '../auth.js' -// Mock the keys service vi.mock('../../services/keys.js', () => ({ isValidKey: vi.fn(), getKeyInfo: vi.fn(), })) -const { isValidKey, getKeyInfo } = await import('../../services/keys.js') -const mockIsValidKey = vi.mocked(isValidKey) -const mockGetKeyInfo = vi.mocked(getKeyInfo) +import { authMiddleware } from '../auth.js' +import { isValidKey, getKeyInfo } from '../../services/keys.js' -function mockReq(overrides: any = {}): any { - return { - headers: {}, - query: {}, - ...overrides, - } -} - -function mockRes(): any { - const res: any = {} - res.status = vi.fn().mockReturnValue(res) - res.json = vi.fn().mockReturnValue(res) - return res +function mockReqResNext(overrides: Partial<{ headers: any; query: any }> = {}) { + const req = { headers: {}, query: {}, ...overrides } as any + const res = { + status: vi.fn().mockReturnThis(), + json: vi.fn().mockReturnThis(), + } as any + const next = vi.fn() + return { req, res, next } } describe('authMiddleware', () => { @@ -31,88 +23,75 @@ describe('authMiddleware', () => { vi.clearAllMocks() }) - it('should return 401 when no API key is provided', async () => { - const req = mockReq() - const res = mockRes() - const next = vi.fn() - + it('returns 401 when no API key is provided', async () => { + const { req, res, next } = mockReqResNext() await authMiddleware(req, res, next) - expect(res.status).toHaveBeenCalledWith(401) expect(res.json).toHaveBeenCalledWith(expect.objectContaining({ error: expect.stringContaining('Missing API key') })) expect(next).not.toHaveBeenCalled() }) - it('should extract key from Bearer authorization header', async () => { - mockIsValidKey.mockResolvedValueOnce(true) - mockGetKeyInfo.mockResolvedValueOnce({ key: 'snap_abc123', tier: 'pro', email: 'test@test.com', createdAt: '2024-01-01' }) - - const req = mockReq({ headers: { authorization: 'Bearer snap_abc123' } }) - const res = mockRes() - const next = vi.fn() - + it('extracts key from Authorization Bearer header', async () => { + vi.mocked(isValidKey).mockResolvedValue(true) + vi.mocked(getKeyInfo).mockResolvedValue({ key: 'test-key', tier: 'free', email: 'a@b.com', createdAt: '' }) + const { req, res, next } = mockReqResNext({ headers: { authorization: 'Bearer test-key' } }) await authMiddleware(req, res, next) - - expect(mockIsValidKey).toHaveBeenCalledWith('snap_abc123') + expect(isValidKey).toHaveBeenCalledWith('test-key') expect(next).toHaveBeenCalled() expect(req.apiKeyInfo).toBeDefined() }) - it('should extract key from X-API-Key header', async () => { - mockIsValidKey.mockResolvedValueOnce(true) - mockGetKeyInfo.mockResolvedValueOnce({ key: 'snap_xyz789', tier: 'starter', email: 'a@b.com', createdAt: '2024-01-01' }) - - const req = mockReq({ headers: { 'x-api-key': 'snap_xyz789' } }) - const res = mockRes() - const next = vi.fn() - + it('extracts key from X-API-Key header', async () => { + vi.mocked(isValidKey).mockResolvedValue(true) + vi.mocked(getKeyInfo).mockResolvedValue({ key: 'xkey', tier: 'pro', email: 'a@b.com', createdAt: '' }) + const { req, res, next } = mockReqResNext({ headers: { 'x-api-key': 'xkey' } }) await authMiddleware(req, res, next) - - expect(mockIsValidKey).toHaveBeenCalledWith('snap_xyz789') + expect(isValidKey).toHaveBeenCalledWith('xkey') expect(next).toHaveBeenCalled() }) - it('should extract key from query parameter', async () => { - mockIsValidKey.mockResolvedValueOnce(true) - mockGetKeyInfo.mockResolvedValueOnce({ key: 'snap_qp1', tier: 'business', email: 'c@d.com', createdAt: '2024-01-01' }) - - const req = mockReq({ query: { key: 'snap_qp1' } }) - const res = mockRes() - const next = vi.fn() - + it('extracts key from query parameter', async () => { + vi.mocked(isValidKey).mockResolvedValue(true) + vi.mocked(getKeyInfo).mockResolvedValue({ key: 'qkey', tier: 'starter', email: 'a@b.com', createdAt: '' }) + const { req, res, next } = mockReqResNext({ query: { key: 'qkey' } }) await authMiddleware(req, res, next) - - expect(mockIsValidKey).toHaveBeenCalledWith('snap_qp1') + expect(isValidKey).toHaveBeenCalledWith('qkey') expect(next).toHaveBeenCalled() }) - it('should return 403 for invalid API key', async () => { - mockIsValidKey.mockResolvedValueOnce(false) - - const req = mockReq({ headers: { 'x-api-key': 'invalid_key' } }) - const res = mockRes() - const next = vi.fn() - + it('prefers Bearer header over X-API-Key and query', async () => { + vi.mocked(isValidKey).mockResolvedValue(true) + vi.mocked(getKeyInfo).mockResolvedValue({ key: 'bearer-key', tier: 'free', email: 'a@b.com', createdAt: '' }) + const { req, res, next } = mockReqResNext({ + headers: { authorization: 'Bearer bearer-key', 'x-api-key': 'other' }, + query: { key: 'another' }, + }) await authMiddleware(req, res, next) + expect(isValidKey).toHaveBeenCalledWith('bearer-key') + }) + it('returns 403 for invalid API key', async () => { + vi.mocked(isValidKey).mockResolvedValue(false) + const { req, res, next } = mockReqResNext({ headers: { authorization: 'Bearer bad-key' } }) + await authMiddleware(req, res, next) expect(res.status).toHaveBeenCalledWith(403) - expect(res.json).toHaveBeenCalledWith(expect.objectContaining({ error: expect.stringContaining('Invalid API key') })) + expect(res.json).toHaveBeenCalledWith({ error: 'Invalid API key' }) expect(next).not.toHaveBeenCalled() }) - it('should prefer Bearer header over X-API-Key and query', async () => { - mockIsValidKey.mockResolvedValueOnce(true) - mockGetKeyInfo.mockResolvedValueOnce({ key: 'snap_bearer', tier: 'pro', email: 'e@f.com', createdAt: '2024-01-01' }) - - const req = mockReq({ - headers: { authorization: 'Bearer snap_bearer', 'x-api-key': 'snap_xapi' }, - query: { key: 'snap_query' } - }) - const res = mockRes() - const next = vi.fn() - + it('attaches apiKeyInfo to request on success', async () => { + const info = { key: 'k', tier: 'business' as const, email: 'x@y.com', createdAt: '2025-01-01' } + vi.mocked(isValidKey).mockResolvedValue(true) + vi.mocked(getKeyInfo).mockResolvedValue(info) + const { req, res, next } = mockReqResNext({ headers: { authorization: 'Bearer k' } }) await authMiddleware(req, res, next) + expect(req.apiKeyInfo).toEqual(info) + }) - expect(mockIsValidKey).toHaveBeenCalledWith('snap_bearer') + it('returns 401 when Authorization header is not Bearer', async () => { + const { req, res, next } = mockReqResNext({ headers: { authorization: 'Basic abc123' } }) + await authMiddleware(req, res, next) + expect(res.status).toHaveBeenCalledWith(401) + expect(next).not.toHaveBeenCalled() }) }) diff --git a/src/middleware/__tests__/compression.test.ts b/src/middleware/__tests__/compression.test.ts new file mode 100644 index 0000000..dde95ca --- /dev/null +++ b/src/middleware/__tests__/compression.test.ts @@ -0,0 +1,65 @@ +import { describe, it, expect } from 'vitest' +import express from 'express' +import request from 'supertest' +import { compressionMiddleware } from '../compression.js' + +function createApp() { + const app = express() + app.use(compressionMiddleware) + + app.get('/text', (_req, res) => { + // Send enough data to exceed 1024 byte threshold + res.type('text/html').send('x'.repeat(2000)) + }) + + app.get('/small', (_req, res) => { + res.type('text/html').send('small') + }) + + app.get('/image', (_req, res) => { + res.type('image/png').send(Buffer.alloc(2000)) + }) + + app.get('/json', (_req, res) => { + res.type('application/json').send(JSON.stringify({ data: 'y'.repeat(2000) })) + }) + + return app +} + +describe('compressionMiddleware', () => { + it('compresses text responses above threshold', async () => { + const res = await request(createApp()) + .get('/text') + .set('Accept-Encoding', 'gzip') + expect(res.headers['content-encoding']).toBe('gzip') + }) + + it('does not compress responses below threshold', async () => { + const res = await request(createApp()) + .get('/small') + .set('Accept-Encoding', 'gzip') + expect(res.headers['content-encoding']).toBeUndefined() + }) + + it('does not compress image responses', async () => { + const res = await request(createApp()) + .get('/image') + .set('Accept-Encoding', 'gzip') + expect(res.headers['content-encoding']).toBeUndefined() + }) + + it('compresses JSON responses above threshold', async () => { + const res = await request(createApp()) + .get('/json') + .set('Accept-Encoding', 'gzip') + expect(res.headers['content-encoding']).toBe('gzip') + }) + + it('does not compress when client does not accept gzip', async () => { + const res = await request(createApp()) + .get('/text') + .set('Accept-Encoding', 'identity') + expect(res.headers['content-encoding']).toBeUndefined() + }) +}) diff --git a/src/middleware/__tests__/usage.test.ts b/src/middleware/__tests__/usage.test.ts new file mode 100644 index 0000000..1d902ee --- /dev/null +++ b/src/middleware/__tests__/usage.test.ts @@ -0,0 +1,127 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' + +vi.mock('../../services/db.js', () => ({ + queryWithRetry: vi.fn(), + connectWithRetry: vi.fn(), +})) + +vi.mock('../../services/keys.js', () => ({ + getTierLimit: vi.fn(), +})) + +vi.mock('../../services/logger.js', () => ({ + default: { info: vi.fn(), error: vi.fn(), warn: vi.fn() }, +})) + +// Must import after mocks +import { usageMiddleware, loadUsageData, getUsageForKey } from '../usage.js' +import { queryWithRetry } from '../../services/db.js' +import { getTierLimit } from '../../services/keys.js' + +function mockReqResNext(apiKeyInfo?: any) { + const req = { apiKeyInfo } as any + const res = { + status: vi.fn().mockReturnThis(), + json: vi.fn().mockReturnThis(), + setHeader: vi.fn(), + } as any + const next = vi.fn() + return { req, res, next } +} + +describe('usageMiddleware', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('calls next when no apiKeyInfo on request', () => { + const { req, res, next } = mockReqResNext() + usageMiddleware(req, res, next) + expect(next).toHaveBeenCalled() + expect(res.status).not.toHaveBeenCalled() + }) + + it('tracks usage and sets headers', () => { + vi.mocked(getTierLimit).mockReturnValue(100) + const { req, res, next } = mockReqResNext({ key: 'new-key', tier: 'free' }) + usageMiddleware(req, res, next) + expect(next).toHaveBeenCalled() + expect(res.setHeader).toHaveBeenCalledWith('X-Usage-Count', '1') + expect(res.setHeader).toHaveBeenCalledWith('X-Usage-Limit', '100') + }) + + it('increments count on repeated calls', () => { + vi.mocked(getTierLimit).mockReturnValue(100) + const { req: req1, res: res1, next: next1 } = mockReqResNext({ key: 'inc-key', tier: 'free' }) + usageMiddleware(req1, res1, next1) + + const { req: req2, res: res2, next: next2 } = mockReqResNext({ key: 'inc-key', tier: 'free' }) + usageMiddleware(req2, res2, next2) + expect(res2.setHeader).toHaveBeenCalledWith('X-Usage-Count', '2') + }) + + it('returns 429 when usage limit is reached', () => { + vi.mocked(getTierLimit).mockReturnValue(1) + // First call uses up the limit + const { req: req1, res: res1, next: next1 } = mockReqResNext({ key: 'limit-key', tier: 'free' }) + usageMiddleware(req1, res1, next1) + expect(next1).toHaveBeenCalled() + + // Second call should be rate limited + const { req: req2, res: res2, next: next2 } = mockReqResNext({ key: 'limit-key', tier: 'free' }) + usageMiddleware(req2, res2, next2) + expect(res2.status).toHaveBeenCalledWith(429) + expect(res2.json).toHaveBeenCalledWith(expect.objectContaining({ + error: expect.stringContaining('Monthly limit reached'), + usage: 1, + limit: 1, + })) + expect(next2).not.toHaveBeenCalled() + }) + + it('resets count for a new month', () => { + vi.mocked(getTierLimit).mockReturnValue(1) + // Use a unique key to avoid state from other tests + const { req: req1, res: res1, next: next1 } = mockReqResNext({ key: 'month-key', tier: 'free' }) + usageMiddleware(req1, res1, next1) + + // Simulate month change using fake timers + vi.useFakeTimers() + vi.setSystemTime(new Date('2099-02-15T12:00:00Z')) + + const { req: req2, res: res2, next: next2 } = mockReqResNext({ key: 'month-key', tier: 'free' }) + usageMiddleware(req2, res2, next2) + expect(next2).toHaveBeenCalled() + expect(res2.setHeader).toHaveBeenCalledWith('X-Usage-Count', '1') + + vi.useRealTimers() + }) +}) + +describe('loadUsageData', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('loads usage data from database', async () => { + vi.mocked(queryWithRetry).mockResolvedValue({ + rows: [{ key: 'db-key', count: 42, month_key: '2026-03' }], + } as any) + await loadUsageData() + const record = getUsageForKey('db-key') + expect(record).toEqual({ count: 42, monthKey: '2026-03' }) + }) + + it('handles database errors gracefully', async () => { + vi.mocked(queryWithRetry).mockRejectedValue(new Error('DB down')) + await loadUsageData() + // Should not throw, usage map should be empty + expect(getUsageForKey('any-key')).toBeUndefined() + }) +}) + +describe('getUsageForKey', () => { + it('returns undefined for unknown key', () => { + expect(getUsageForKey('nonexistent-key-xyz')).toBeUndefined() + }) +})