"""Template matching for UI element detection in game screenshots.

Provides efficient template matching using OpenCV with support for multiple
templates, confidence thresholds, and template management.
"""

import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import cv2
import numpy as np

logger = logging.getLogger(__name__)


class TemplateMatch(NamedTuple):
    """Represents a template match with position and confidence."""

    template_name: str
    confidence: float
    center: Tuple[int, int]  # (x, y) center position
    bbox: Tuple[int, int, int, int]  # (x, y, width, height)


@dataclass
class TemplateInfo:
    """Information about a loaded template."""

    name: str
    image: np.ndarray
    width: int
    height: int
    path: Optional[str] = None


class TemplateMatcher:
    """Core template matching functionality.

    Match scores are always compared against the threshold as "higher is
    better". For ``cv2.TM_SQDIFF_NORMED``, where OpenCV reports 0.0 for a
    perfect match, the raw result is converted to ``1 - result`` first so the
    same convention holds. The unnormalized methods (``TM_SQDIFF``,
    ``TM_CCORR``, ``TM_CCOEFF``) produce unbounded scores, so a 0.0-1.0
    threshold is not meaningful for them; prefer the ``*_NORMED`` variants.
    """

    def __init__(self, method: int = cv2.TM_CCOEFF_NORMED, threshold: float = 0.8):
        """Initialize template matcher.

        Args:
            method: OpenCV template matching method (``cv2.TM_*``).
            threshold: Default minimum confidence threshold (0.0 to 1.0 for
                the normalized methods).
        """
        self.method = method
        self.threshold = threshold

    @staticmethod
    def _to_gray(image: np.ndarray) -> np.ndarray:
        """Return a single-channel version of a non-empty image.

        Accepts 2-D grayscale, ``(H, W, 1)``, 3-channel BGR and 4-channel
        BGRA arrays.
        """
        if image.ndim == 2:
            return image
        channels = image.shape[2]
        if channels == 1:
            return image[:, :, 0]
        if channels == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def match_template(self, image: np.ndarray, template: np.ndarray,
                       threshold: Optional[float] = None,
                       overlap_threshold: float = 0.3) -> List[TemplateMatch]:
        """Match a single template in an image.

        Both inputs are converted to grayscale before matching.

        Args:
            image: Source image to search in.
            template: Template image to find.
            threshold: Confidence threshold override (defaults to
                ``self.threshold``).
            overlap_threshold: IoU above which a lower-confidence match is
                discarded as a duplicate during non-maximum suppression.

        Returns:
            Matches found, ordered by confidence (highest first). The
            ``template_name`` field is left empty. Returns an empty list when
            either input is empty or the template is larger than the image.
        """
        if threshold is None:
            threshold = self.threshold

        # cv2.cvtColor / cv2.matchTemplate raise on empty arrays.
        if image.size == 0 or template.size == 0:
            return []

        image_gray = self._to_gray(image)
        template_gray = self._to_gray(template)

        template_h, template_w = template_gray.shape[:2]
        image_h, image_w = image_gray.shape[:2]
        # cv2.matchTemplate requires the template to fit inside the image; an
        # upscaled template (see match_multiple_scales) may not.
        if template_h > image_h or template_w > image_w:
            return []

        result = cv2.matchTemplate(image_gray, template_gray, self.method)
        if self.method == cv2.TM_SQDIFF_NORMED:
            # SQDIFF is "lower is better"; flip so 1.0 means a perfect match.
            result = 1.0 - result

        ys, xs = np.nonzero(result >= threshold)
        scores = result[ys, xs]
        # .tolist() yields plain Python ints/floats, matching TemplateMatch's
        # declared field types instead of leaking numpy scalars.
        matches = [
            TemplateMatch(
                "",
                score,
                (x + template_w // 2, y + template_h // 2),
                (x, y, template_w, template_h),
            )
            for x, y, score in zip(xs.tolist(), ys.tolist(), scores.tolist())
        ]

        # Collapse the cluster of neighbouring hits around each real match.
        return self._apply_nms(matches, overlap_threshold=overlap_threshold)

    def match_multiple_scales(self, image: np.ndarray, template: np.ndarray,
                              scales: Optional[List[float]] = None,
                              threshold: Optional[float] = None,
                              min_size: int = 10) -> List[TemplateMatch]:
        """Match a template at multiple scales.

        Args:
            image: Source image.
            template: Template image.
            scales: Scale factors to try (default 0.8-1.2 in 0.1 steps).
            threshold: Confidence threshold override.
            min_size: Scaled templates narrower or shorter than this many
                pixels are skipped.

        Returns:
            Matches from all scales after cross-scale non-maximum suppression,
            ordered by confidence (highest first).
        """
        if scales is None:
            scales = [0.8, 0.9, 1.0, 1.1, 1.2]

        all_matches: List[TemplateMatch] = []
        for scale in scales:
            new_width = int(template.shape[1] * scale)
            new_height = int(template.shape[0] * scale)
            if new_width < min_size or new_height < min_size:
                continue  # Skip very small templates

            scaled_template = cv2.resize(template, (new_width, new_height))
            all_matches.extend(self.match_template(image, scaled_template, threshold))

        # The same on-screen element is often found at several adjacent scales.
        return self._apply_nms(all_matches, overlap_threshold=0.5)

    def _apply_nms(self, matches: List[TemplateMatch],
                   overlap_threshold: float = 0.3) -> List[TemplateMatch]:
        """Apply greedy non-maximum suppression to overlapping matches.

        Args:
            matches: Template matches in any order.
            overlap_threshold: Maximum allowed IoU with an already-kept match.

        Returns:
            Kept matches, ordered by confidence (highest first).
        """
        if not matches:
            return matches

        kept: List[TemplateMatch] = []
        for match in sorted(matches, key=lambda m: m.confidence, reverse=True):
            if not any(self._calculate_overlap(match, other) > overlap_threshold
                       for other in kept):
                kept.append(match)
        return kept

    def _calculate_overlap(self, match1: TemplateMatch, match2: TemplateMatch) -> float:
        """Return the intersection-over-union (IoU) of two matches' bboxes.

        Args:
            match1: First match.
            match2: Second match.

        Returns:
            Overlap ratio in the range 0.0 to 1.0.
        """
        x1, y1, w1, h1 = match1.bbox
        x2, y2, w2, h2 = match2.bbox

        left = max(x1, x2)
        right = min(x1 + w1, x2 + w2)
        top = max(y1, y2)
        bottom = min(y1 + h1, y2 + h2)
        if left >= right or top >= bottom:
            return 0.0

        intersection = (right - left) * (bottom - top)
        union = w1 * h1 + w2 * h2 - intersection
        return intersection / union if union > 0 else 0.0


class TemplateManager:
    """Manages a collection of templates for game UI detection."""

    def __init__(self, template_dir: Optional[Union[str, Path]] = None):
        """Initialize template manager.

        Args:
            template_dir: Directory containing template images (``str`` or
                ``Path``). When it is an existing directory, its templates
                are loaded immediately.
        """
        if template_dir:
            template_dir = Path(template_dir)
        self.template_dir = template_dir
        self.templates: Dict[str, TemplateInfo] = {}
        self.matcher = TemplateMatcher()

        if template_dir and template_dir.is_dir():
            self.load_templates_from_directory(template_dir)

    def load_template(self, name: str, image_path: Path) -> bool:
        """Load a single template from file.

        An existing template with the same name is replaced.

        Args:
            name: Template identifier.
            image_path: Path to template image.

        Returns:
            True if loaded successfully.
        """
        try:
            # cv2.imread signals most failures by returning None, not raising.
            image = cv2.imread(str(image_path))
            if image is None:
                logger.error(f"Could not load template image: {image_path}")
                return False

            height, width = image.shape[:2]
            self.templates[name] = TemplateInfo(
                name=name,
                image=image,
                width=width,
                height=height,
                path=str(image_path),
            )
            logger.info(f"Loaded template '{name}' ({width}x{height})")
            return True
        except Exception as e:
            # Best-effort loading: report (with traceback) and keep going.
            logger.exception(f"Failed to load template '{name}': {e}")
            return False

    def load_templates_from_directory(self, directory: Path,
                                      pattern: str = "*.png") -> int:
        """Load all templates matching ``pattern`` from a directory.

        Each file's stem becomes its template name. Files are loaded in
        sorted order so that results and logs are deterministic.

        Args:
            directory: Directory containing template images.
            pattern: Glob pattern selecting template files.

        Returns:
            Number of templates loaded.
        """
        directory = Path(directory)
        loaded_count = sum(
            1
            for image_path in sorted(directory.glob(pattern))
            if self.load_template(image_path.stem, image_path)
        )
        logger.info(f"Loaded {loaded_count} templates from {directory}")
        return loaded_count

    def find_template(self, image: np.ndarray, template_name: str,
                      threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find a specific template in an image.

        Args:
            image: Source image.
            template_name: Name of template to find.
            threshold: Confidence threshold override.

        Returns:
            Matches found (highest confidence first), each tagged with
            ``template_name``; empty if the template is not loaded.
        """
        template_info = self.templates.get(template_name)
        if template_info is None:
            logger.warning(f"Template '{template_name}' not found")
            return []

        matches = self.matcher.match_template(image, template_info.image, threshold)
        return [match._replace(template_name=template_name) for match in matches]

    def find_any_template(self, image: np.ndarray,
                          template_names: Optional[List[str]] = None,
                          threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find any of the specified templates in an image.

        Args:
            image: Source image.
            template_names: Template names to search for, or None for all.
            threshold: Confidence threshold override.

        Returns:
            All matches found, sorted by confidence (highest first).
        """
        if template_names is None:
            template_names = list(self.templates.keys())

        all_matches: List[TemplateMatch] = []
        for template_name in template_names:
            all_matches.extend(self.find_template(image, template_name, threshold))

        all_matches.sort(key=lambda m: m.confidence, reverse=True)
        return all_matches

    def wait_for_template(self, capture_func: Callable[[], np.ndarray],
                          template_name: str, timeout: float = 10.0,
                          check_interval: float = 0.5,
                          threshold: Optional[float] = None) -> Optional[TemplateMatch]:
        """Poll screenshots until a template appears or the timeout expires.

        At least one capture is always checked, and a final check happens at
        the deadline.

        Args:
            capture_func: Zero-argument callable returning a screenshot.
            template_name: Template to wait for.
            timeout: Maximum wait time in seconds.
            check_interval: Time between checks in seconds.
            threshold: Confidence threshold override.

        Returns:
            Best match from the first capture containing the template, or
            None on timeout.
        """
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while True:
            matches = self.find_template(capture_func(), template_name, threshold)
            if matches:
                return matches[0]  # Best match (list is confidence-sorted)

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None
            # Never sleep past the deadline.
            time.sleep(min(check_interval, remaining))

    def get_template_info(self, template_name: str) -> Optional[TemplateInfo]:
        """Get information about a loaded template.

        Args:
            template_name: Name of template.

        Returns:
            TemplateInfo object, or None if not loaded.
        """
        return self.templates.get(template_name)

    def list_templates(self) -> List[str]:
        """Get the names of all loaded templates.

        Returns:
            List of template names.
        """
        return list(self.templates.keys())

    def create_debug_image(self, image: np.ndarray,
                           matches: List[TemplateMatch]) -> np.ndarray:
        """Create a debug image showing template matches.

        The input image is not modified. Grayscale input is converted to BGR
        so that the coloured annotations are actually visible.

        Args:
            image: Original image.
            matches: Matches to highlight.

        Returns:
            Copy of the image with bounding boxes, centers and labels drawn.
        """
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            debug_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        else:
            debug_img = image.copy()

        for match in matches:
            x, y, w, h = match.bbox

            # Bounding box
            cv2.rectangle(debug_img, (x, y), (x + w, y + h), (0, 255, 0), 2)

            # Center point
            center_x, center_y = match.center
            cv2.circle(debug_img, (center_x, center_y), 5, (255, 0, 0), -1)

            # Template name and confidence; keep the text baseline inside
            # the image when the box touches the top edge.
            label = f"{match.template_name}: {match.confidence:.2f}"
            label_y = max(y - 10, 15)
            cv2.putText(debug_img, label, (x, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        return debug_img