403 lines
No EOL
13 KiB
Python
403 lines
No EOL
13 KiB
Python
"""Template matching for UI element detection in game screenshots.
|
|
|
|
Provides efficient template matching using OpenCV with support for
|
|
multiple templates, confidence thresholds, and template management.
|
|
"""
|
|
|
|
from typing import List, Dict, Optional, Tuple, NamedTuple
|
|
from pathlib import Path
|
|
import logging
|
|
from dataclasses import dataclass
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
# Module-scoped logger, named after this module so handlers/levels can be
# configured per module from the application's logging setup.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TemplateMatch(NamedTuple):
    """A single detection of a template inside a larger image.

    Fields are positional in the order below, so a match can also be unpacked
    as ``name, score, center, bbox = match``.
    """

    # Identifier of the template that produced this hit (may be "" when the
    # match was produced without a template name attached).
    template_name: str
    # Match score reported by the matcher.
    confidence: float
    # Centre of the matched region as (x, y) pixel coordinates.
    center: Tuple[int, int]
    # Matched region as (x, y, width, height); (x, y) is the top-left corner.
    bbox: Tuple[int, int, int, int]
|
|
|
|
|
|
@dataclass
class TemplateInfo:
    """Pixel data plus bookkeeping for one template held by the manager."""

    # Lookup key under which the template is registered.
    name: str
    # Template pixels as returned by the loader.
    image: np.ndarray
    # Template width in pixels.
    width: int
    # Template height in pixels.
    height: int
    # File the template was read from, if it came from disk.
    path: Optional[str] = None
|
|
|
|
|
|
class TemplateMatcher:
    """Core template matching functionality built on ``cv2.matchTemplate``.

    Scores are compared against the threshold with "higher is better"
    semantics. ``cv2.TM_SQDIFF_NORMED`` (where 0 is a perfect match) is
    flipped to ``1 - score`` so it follows the same convention. The
    non-normalised methods (``TM_SQDIFF``, ``TM_CCORR``, ``TM_CCOEFF``)
    produce unbounded scores, so a 0.0-1.0 threshold is only meaningful for
    the ``*_NORMED`` methods. ``TM_SQDIFF`` is lower-is-better and is NOT
    flipped here.
    """

    def __init__(self, method: int = cv2.TM_CCOEFF_NORMED,
                 threshold: float = 0.8):
        """Initialize template matcher.

        Args:
            method: OpenCV template matching method (a ``cv2.TM_*`` constant).
            threshold: Minimum confidence threshold (0.0 to 1.0 for the
                normalised methods).
        """
        self.method = method
        self.threshold = threshold

    @staticmethod
    def _to_gray(img: np.ndarray) -> np.ndarray:
        """Return a single-channel version of ``img`` suitable for matching.

        Accepts 2-D grayscale arrays, ``(H, W, 1)`` arrays, 3-channel BGR and
        4-channel BGRA images.
        """
        if img.ndim == 2:
            return img
        channels = img.shape[2]
        if channels == 1:
            # cvtColor rejects single-channel input for BGR2GRAY; just drop
            # the trailing axis instead.
            return img[:, :, 0]
        if channels == 4:
            return cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    def match_template(self, image: np.ndarray, template: np.ndarray,
                       threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find every occurrence of ``template`` in ``image``.

        Both inputs are converted to grayscale before matching. Every location
        whose score reaches the threshold is collected. The set is then thinned
        with non-maximum suppression, where IoU > 0.3 counts as the same hit.

        Args:
            image: Source image to search in.
            template: Template image to find.
            threshold: Confidence threshold override. Defaults to
                ``self.threshold``.

        Returns:
            Matches sorted by descending confidence, with ``template_name``
            left empty. Coordinates are built-in ``int`` and confidence is a
            built-in ``float``. An empty list is returned when the template
            is empty or larger than the image in either dimension, a case in
            which ``cv2.matchTemplate`` would raise.
        """
        if threshold is None:
            threshold = self.threshold

        image_gray = self._to_gray(image)
        template_gray = self._to_gray(template)

        template_h, template_w = template_gray.shape[:2]
        image_h, image_w = image_gray.shape[:2]
        # matchTemplate requires the template to fit inside the image. This is
        # easily violated by the up-scaled templates in match_multiple_scales.
        if (template_h == 0 or template_w == 0
                or template_h > image_h or template_w > image_w):
            return []

        result = cv2.matchTemplate(image_gray, template_gray, self.method)
        if self.method == cv2.TM_SQDIFF_NORMED:
            # For SQDIFF_NORMED, 0 is a perfect match. Flip it so that
            # "score >= threshold" keeps the best locations, not the worst.
            result = 1.0 - result

        # np.nonzero yields (rows, cols) = (y, x) in row-major order.
        ys, xs = np.nonzero(result >= threshold)

        # Convert NumPy scalars to built-ins so TemplateMatch holds the
        # declared int/float types (JSON-serialisable, cv2-drawing safe).
        matches = [
            TemplateMatch(
                "",
                float(result[y, x]),
                (int(x) + template_w // 2, int(y) + template_h // 2),
                (int(x), int(y), template_w, template_h),
            )
            for y, x in zip(ys, xs)
        ]

        # Neighbouring pixels around a real hit all pass the threshold.
        # Collapse them into one match with non-maximum suppression.
        return self._apply_nms(matches, overlap_threshold=0.3)

    def match_multiple_scales(self, image: np.ndarray, template: np.ndarray,
                              scales: Optional[List[float]] = None,
                              threshold: Optional[float] = None,
                              min_size: int = 10) -> List[TemplateMatch]:
        """Match template at multiple scales.

        Args:
            image: Source image.
            template: Template image.
            scales: Scale factors to try. Defaults to
                ``[0.8, 0.9, 1.0, 1.1, 1.2]``.
            threshold: Confidence threshold override.
            min_size: Scaled templates narrower or shorter than this many
                pixels are skipped.

        Returns:
            Matches from all scales, merged with non-maximum suppression
            (IoU > 0.5 counts as a duplicate) and sorted by descending
            confidence. Scales at which the template would exceed the image
            size contribute no matches.
        """
        if scales is None:
            scales = [0.8, 0.9, 1.0, 1.1, 1.2]

        base_h, base_w = template.shape[:2]
        all_matches: List[TemplateMatch] = []

        for scale in scales:
            new_width = int(base_w * scale)
            new_height = int(base_h * scale)

            if new_width < min_size or new_height < min_size:
                continue  # Too small to match reliably.

            scaled_template = cv2.resize(template, (new_width, new_height))
            all_matches.extend(
                self.match_template(image, scaled_template, threshold))

        # The same UI element is usually found at several adjacent scales.
        return self._apply_nms(all_matches, overlap_threshold=0.5)

    def _apply_nms(self, matches: List[TemplateMatch],
                   overlap_threshold: float = 0.3) -> List[TemplateMatch]:
        """Greedy non-maximum suppression over template matches.

        Matches are visited from highest to lowest confidence. A match is kept
        only if its IoU with every already-kept match is at most
        ``overlap_threshold``.

        Args:
            matches: Candidate matches, in any order.
            overlap_threshold: Maximum allowed IoU with a kept match.

        Returns:
            Kept matches, sorted by descending confidence.
        """
        kept: List[TemplateMatch] = []
        for match in sorted(matches, key=lambda m: m.confidence, reverse=True):
            if all(self._calculate_overlap(match, other) <= overlap_threshold
                   for other in kept):
                kept.append(match)
        return kept

    def _calculate_overlap(self, match1: TemplateMatch, match2: TemplateMatch) -> float:
        """Return the intersection-over-union (IoU) of two matches' boxes.

        Args:
            match1: First match.
            match2: Second match.

        Returns:
            IoU in [0.0, 1.0]. Returns 0.0 for disjoint boxes or a zero-area
            union.
        """
        x1, y1, w1, h1 = match1.bbox
        x2, y2, w2, h2 = match2.bbox

        # Intersection rectangle.
        left = max(x1, x2)
        right = min(x1 + w1, x2 + w2)
        top = max(y1, y2)
        bottom = min(y1 + h1, y2 + h2)

        if left >= right or top >= bottom:
            return 0.0

        intersection = (right - left) * (bottom - top)
        union = w1 * h1 + w2 * h2 - intersection

        return intersection / union if union > 0 else 0.0
|
|
|
|
|
|
class TemplateManager:
    """Manages a collection of named templates for game UI detection."""

    def __init__(self, template_dir: Optional[Path] = None):
        """Initialize template manager.

        Args:
            template_dir: Directory containing template images. A ``str``
                path is accepted too. If it names an existing directory, the
                templates in it are loaded immediately.
        """
        self.template_dir = Path(template_dir) if template_dir else None
        self.templates: Dict[str, TemplateInfo] = {}
        self.matcher = TemplateMatcher()

        if self.template_dir is not None and self.template_dir.is_dir():
            self.load_templates_from_directory(self.template_dir)

    def load_template(self, name: str, image_path: Path) -> bool:
        """Load a single template image from disk and register it.

        An existing template registered under the same ``name`` is replaced.

        Args:
            name: Template identifier.
            image_path: Path to template image.

        Returns:
            True if loaded successfully. False if the file could not be read;
            the failure is logged, not raised.
        """
        try:
            image = cv2.imread(str(image_path))
        except Exception as e:  # Best-effort loader: log and report failure.
            logger.error("Failed to load template '%s': %s", name, e)
            return False

        if image is None:
            logger.error("Could not load template image: %s", image_path)
            return False

        height, width = image.shape[:2]
        self.templates[name] = TemplateInfo(
            name=name,
            image=image,
            width=width,
            height=height,
            path=str(image_path),
        )

        logger.info("Loaded template '%s' (%dx%d)", name, width, height)
        return True

    def load_templates_from_directory(
            self, directory: Path,
            patterns: Tuple[str, ...] = ("*.png",)) -> int:
        """Load all matching template images from a directory.

        Each file's stem (file name without extension) becomes its template
        name. Files are processed in sorted order, so the load order is
        deterministic when several files share a stem (e.g. ``ok.png`` and
        ``ok.jpg``): the one loaded last wins.

        Args:
            directory: Directory containing template images.
            patterns: Glob patterns selecting files to load. Defaults to PNG
                files only. A single pattern string is also accepted.

        Returns:
            Number of templates loaded.
        """
        directory = Path(directory)
        if isinstance(patterns, str):
            patterns = (patterns,)

        # A set removes duplicates when patterns overlap. Sorting fixes the
        # otherwise filesystem-dependent glob order.
        image_paths = sorted({
            path
            for pattern in patterns
            for path in directory.glob(pattern)
            if path.is_file()
        })

        loaded_count = 0
        for image_path in image_paths:
            if self.load_template(image_path.stem, image_path):
                loaded_count += 1

        logger.info("Loaded %d templates from %s", loaded_count, directory)
        return loaded_count

    def find_template(self, image: np.ndarray, template_name: str,
                      threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find a specific named template in an image.

        Args:
            image: Source image.
            template_name: Name of template to find.
            threshold: Confidence threshold override.

        Returns:
            Matches found, each tagged with ``template_name``, in the order
            returned by the matcher. Returns an empty list (and logs a
            warning) for an unknown template name.
        """
        template_info = self.templates.get(template_name)
        if template_info is None:
            logger.warning("Template '%s' not found", template_name)
            return []

        matches = self.matcher.match_template(image, template_info.image, threshold)
        return [match._replace(template_name=template_name) for match in matches]

    def find_any_template(self, image: np.ndarray,
                          template_names: Optional[List[str]] = None,
                          threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find any of the specified templates in an image.

        Args:
            image: Source image.
            template_names: Template names to search for, or None for all
                loaded templates.
            threshold: Confidence threshold override.

        Returns:
            Matches for all requested templates, sorted by descending
            confidence.
        """
        if template_names is None:
            template_names = list(self.templates)

        all_matches: List[TemplateMatch] = []
        for template_name in template_names:
            all_matches.extend(self.find_template(image, template_name, threshold))

        all_matches.sort(key=lambda m: m.confidence, reverse=True)
        return all_matches

    def wait_for_template(self, capture_func, template_name: str,
                          timeout: float = 10.0, check_interval: float = 0.5,
                          threshold: Optional[float] = None) -> Optional[TemplateMatch]:
        """Poll screenshots until a template appears or the timeout expires.

        Args:
            capture_func: Zero-argument callable returning a screenshot. A
                ``None`` frame is treated as "nothing found yet".
            template_name: Template to wait for.
            timeout: Maximum wait time in seconds.
            check_interval: Time between checks in seconds.
            threshold: Confidence threshold override.

        Returns:
            The first match (as ordered by ``find_template``) from the first
            frame containing the template. Returns None on timeout, or
            immediately if the template is not loaded.
        """
        import time

        if template_name not in self.templates:
            # Polling could never succeed; fail fast instead of spinning
            # (and logging a warning) for the whole timeout.
            logger.warning("Template '%s' not found", template_name)
            return None

        # monotonic() is immune to wall-clock adjustments, unlike time().
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            image = capture_func()
            if image is not None:
                matches = self.find_template(image, template_name, threshold)
                if matches:
                    return matches[0]

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Don't oversleep past the deadline.
            time.sleep(min(check_interval, remaining))

        return None

    def get_template_info(self, template_name: str) -> Optional[TemplateInfo]:
        """Get information about a loaded template.

        Args:
            template_name: Name of template.

        Returns:
            TemplateInfo object or None if not found.
        """
        return self.templates.get(template_name)

    def list_templates(self) -> List[str]:
        """Get the names of all loaded templates.

        Returns:
            List of template names, in insertion order.
        """
        return list(self.templates.keys())

    def create_debug_image(self, image: np.ndarray, matches: List[TemplateMatch]) -> np.ndarray:
        """Draw matches onto a copy of an image for visual debugging.

        Each match gets a green bounding box, a filled center dot and a white
        "name: confidence" label. The input image is not modified.

        Args:
            image: Original image.
            matches: Matches to highlight.

        Returns:
            Debug image with matches drawn.
        """
        debug_img = image.copy()

        for match in matches:
            # cv2 drawing functions want plain ints.
            x, y, w, h = (int(v) for v in match.bbox)
            center_x, center_y = (int(v) for v in match.center)

            cv2.rectangle(debug_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
            cv2.circle(debug_img, (center_x, center_y), 5, (255, 0, 0), -1)

            label = f"{match.template_name}: {match.confidence:.2f}"
            # putText's origin is the text baseline. For matches near the top
            # edge, keep the label on-screen instead of at a negative y.
            # 15 px roughly fits this font at scale 0.6.
            label_y = max(y - 10, 15)
            cv2.putText(debug_img, label, (x, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        return debug_img