iso-bot/engine/screen/template.py

403 lines
No EOL
13 KiB
Python

"""Template matching for UI element detection in game screenshots.
Provides efficient template matching using OpenCV with support for
multiple templates, confidence thresholds, and template management.
"""
from typing import List, Dict, Optional, Tuple, NamedTuple
from pathlib import Path
import logging
from dataclasses import dataclass
import cv2
import numpy as np
logger = logging.getLogger(__name__)
class TemplateMatch(NamedTuple):
    """Represents a template match with position and confidence.

    All coordinates are pixel positions in the image that was searched.
    Being a ``NamedTuple``, instances are immutable; use ``_replace`` to
    derive a modified copy (e.g. to attach a template name).
    """

    # Identifier of the template that produced this match.
    template_name: str
    # Match score read from the template-matching result map.
    confidence: float
    # (x, y) center position of the matched region.
    center: Tuple[int, int]
    # (x, y, width, height) with (x, y) the top-left corner.
    bbox: Tuple[int, int, int, int]
@dataclass
class TemplateInfo:
    """Information about a loaded template.

    Equality is implemented by hand: the ``__eq__`` that ``@dataclass``
    would generate compares all fields as one tuple, and comparing two
    distinct multi-element ``ndarray`` objects that way raises
    ``ValueError`` (the truth value of an array is ambiguous). Instances
    remain unhashable, as they were with the generated ``__eq__``.
    """

    # Identifier the template is registered under.
    name: str
    # Pixel data of the template.
    image: np.ndarray
    # Template width in pixels.
    width: int
    # Template height in pixels.
    height: int
    # Where the template was loaded from, if it came from a file.
    path: Optional[str] = None

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is a TemplateInfo with equal fields.

        The image is compared element-wise with ``np.array_equal`` (which
        also requires identical shapes); the remaining fields use ``==``.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        same_metadata = (
            (self.name, self.width, self.height, self.path)
            == (other.name, other.width, other.height, other.path)
        )
        return same_metadata and bool(np.array_equal(self.image, other.image))
class TemplateMatcher:
    """Core template matching functionality.

    Thin wrapper around ``cv2.matchTemplate`` that thresholds the result
    map, turns every surviving location into a ``TemplateMatch`` and
    removes overlapping hits with greedy Non-Maximum Suppression.

    Thresholds are interpreted as "minimum confidence", i.e. higher scores
    are better. ``cv2.TM_SQDIFF_NORMED`` (where lower is better) is
    inverted to fit that convention.

    NOTE(review): the unnormalized ``cv2.TM_SQDIFF`` method is passed
    through unchanged; its scores are not on a 0..1 scale, so the default
    threshold is not meaningful for it.
    """

    def __init__(self, method: int = cv2.TM_CCOEFF_NORMED,
                 threshold: float = 0.8):
        """Initialize template matcher.

        Args:
            method: OpenCV template matching method
            threshold: Minimum confidence threshold (0.0 to 1.0)
        """
        self.method = method
        self.threshold = threshold

    @staticmethod
    def _to_gray(img: np.ndarray) -> np.ndarray:
        """Return *img* as single-channel; 3-dim input is converted from BGR."""
        if len(img.shape) == 3:
            return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        return img

    def match_template(self, image: np.ndarray, template: np.ndarray,
                       threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Match single template in image.

        Args:
            image: Source image to search in
            template: Template image to find
            threshold: Confidence threshold override

        Returns:
            List of non-overlapping matches, highest confidence first.
            Matches carry an empty ``template_name``. An empty list is
            returned when the template does not fit inside the image.
        """
        if threshold is None:
            threshold = self.threshold

        image_gray = self._to_gray(image)
        template_gray = self._to_gray(template)

        image_h, image_w = image_gray.shape[:2]
        template_h, template_w = template_gray.shape[:2]

        # cv2.matchTemplate rejects a template that is empty or larger than
        # the searched image; report "no match" instead of raising.
        if (template_h == 0 or template_w == 0
                or template_h > image_h or template_w > image_w):
            logger.debug(
                "Template (%dx%d) does not fit in image (%dx%d); skipping",
                template_w, template_h, image_w, image_h,
            )
            return []

        result = cv2.matchTemplate(image_gray, template_gray, self.method)

        # For TM_SQDIFF_NORMED the best match is the minimum (0.0), so flip
        # it into a "higher is better" confidence before thresholding.
        if self.method == cv2.TM_SQDIFF_NORMED:
            result = 1.0 - result

        # np.where yields (rows, cols), i.e. (y, x).
        ys, xs = np.where(result >= threshold)

        matches = []
        # .tolist() gives plain Python ints, so the tuples stored in
        # TemplateMatch do not carry NumPy scalar types.
        for x, y in zip(xs.tolist(), ys.tolist()):
            confidence = float(result[y, x])
            center = (x + template_w // 2, y + template_h // 2)
            bbox = (x, y, template_w, template_h)
            matches.append(TemplateMatch("", confidence, center, bbox))

        # Neighbouring pixels around a true hit all score above threshold;
        # keep only the best of each overlapping cluster.
        return self._apply_nms(matches, overlap_threshold=0.3)

    def match_multiple_scales(self, image: np.ndarray, template: np.ndarray,
                              scales: Optional[List[float]] = None,
                              threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Match template at multiple scales.

        Args:
            image: Source image
            template: Template image
            scales: List of scale factors to try; defaults to
                ``[0.8, 0.9, 1.0, 1.1, 1.2]``
            threshold: Confidence threshold

        Returns:
            List of matches at all scales, de-duplicated across scales.
        """
        if scales is None:
            scales = [0.8, 0.9, 1.0, 1.1, 1.2]

        all_matches = []
        for scale in scales:
            new_width = int(template.shape[1] * scale)
            new_height = int(template.shape[0] * scale)
            if new_width < 10 or new_height < 10:
                continue  # Skip very small templates

            scaled_template = cv2.resize(template, (new_width, new_height))
            # Scales that make the template larger than the image simply
            # contribute no matches (match_template returns []).
            all_matches.extend(
                self.match_template(image, scaled_template, threshold)
            )

        # The same element is typically found at several adjacent scales.
        return self._apply_nms(all_matches, overlap_threshold=0.5)

    def _apply_nms(self, matches: List[TemplateMatch],
                   overlap_threshold: float = 0.3) -> List[TemplateMatch]:
        """Apply Non-Maximum Suppression to remove overlapping matches.

        Matches are visited from highest to lowest confidence; a match is
        kept only if its overlap with every already-kept match is at most
        ``overlap_threshold``.

        Args:
            matches: List of template matches
            overlap_threshold: Maximum allowed overlap ratio

        Returns:
            Filtered list of matches, highest confidence first
        """
        if not matches:
            return matches

        ranked = sorted(matches, key=lambda m: m.confidence, reverse=True)

        kept: List[TemplateMatch] = []
        for candidate in ranked:
            overlaps_kept = any(
                self._calculate_overlap(candidate, other) > overlap_threshold
                for other in kept
            )
            if not overlaps_kept:
                kept.append(candidate)
        return kept

    def _calculate_overlap(self, match1: TemplateMatch, match2: TemplateMatch) -> float:
        """Calculate overlap ratio between two matches.

        The ratio is intersection area divided by union area of the two
        bounding boxes.

        Args:
            match1: First match
            match2: Second match

        Returns:
            Overlap ratio (0.0 to 1.0)
        """
        x1, y1, w1, h1 = match1.bbox
        x2, y2, w2, h2 = match2.bbox

        # Intersection rectangle
        left = max(x1, x2)
        right = min(x1 + w1, x2 + w2)
        top = max(y1, y2)
        bottom = min(y1 + h1, y2 + h2)

        if left >= right or top >= bottom:
            return 0.0

        intersection = (right - left) * (bottom - top)
        union = w1 * h1 + w2 * h2 - intersection
        return intersection / union if union > 0 else 0.0
class TemplateManager:
    """Manages a collection of templates for game UI detection.

    Templates are kept in ``self.templates`` keyed by name and are matched
    through a ``TemplateMatcher`` created with its default settings.
    """

    def __init__(self, template_dir: Optional[Path] = None):
        """Initialize template manager.

        Args:
            template_dir: Directory containing template images. A plain
                string path is accepted as well. When the directory
                exists, every ``*.png`` in it is loaded immediately.
        """
        # Coerce so that a str path does not fail on the Path-only calls below.
        self.template_dir = Path(template_dir) if template_dir else None
        self.templates: Dict[str, TemplateInfo] = {}
        self.matcher = TemplateMatcher()

        if self.template_dir is not None:
            if self.template_dir.is_dir():
                self.load_templates_from_directory(self.template_dir)
            else:
                logger.warning(
                    "Template directory not found: %s", self.template_dir
                )

    def load_template(self, name: str, image_path: Path) -> bool:
        """Load single template from file.

        An existing template registered under the same name is replaced.

        Args:
            name: Template identifier
            image_path: Path to template image

        Returns:
            True if loaded successfully
        """
        try:
            # cv2.imread signals an unreadable file by returning None.
            image = cv2.imread(str(image_path))
            if image is None:
                logger.error("Could not load template image: %s", image_path)
                return False

            height, width = image.shape[:2]
            self.templates[name] = TemplateInfo(
                name=name,
                image=image,
                width=width,
                height=height,
                path=str(image_path),
            )
            logger.info("Loaded template '%s' (%dx%d)", name, width, height)
            return True
        except Exception as e:
            # Best effort: one bad file must not abort loading the rest.
            logger.error("Failed to load template '%s': %s", name, e)
            return False

    def load_templates_from_directory(self, directory: Path) -> int:
        """Load all templates from directory.

        Only ``*.png`` files directly inside the directory are considered;
        each is registered under its file name without extension.

        Args:
            directory: Directory containing template images

        Returns:
            Number of templates loaded
        """
        # Sorted so the load order does not depend on the filesystem.
        loaded_count = sum(
            1
            for image_path in sorted(directory.glob("*.png"))
            if self.load_template(image_path.stem, image_path)
        )
        logger.info("Loaded %d templates from %s", loaded_count, directory)
        return loaded_count

    def find_template(self, image: np.ndarray, template_name: str,
                      threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find specific template in image.

        Args:
            image: Source image
            template_name: Name of template to find
            threshold: Confidence threshold override

        Returns:
            List of matches found, each tagged with ``template_name``;
            empty if the template is not loaded.
        """
        template_info = self.templates.get(template_name)
        if template_info is None:
            logger.warning("Template '%s' not found", template_name)
            return []

        matches = self.matcher.match_template(
            image, template_info.image, threshold
        )
        # The matcher does not know template names; attach it here.
        return [m._replace(template_name=template_name) for m in matches]

    def find_any_template(self, image: np.ndarray,
                          template_names: Optional[List[str]] = None,
                          threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find any of the specified templates in image.

        Args:
            image: Source image
            template_names: List of template names to search for, or None for all
            threshold: Confidence threshold override

        Returns:
            List of all matches found, highest confidence first
        """
        if template_names is None:
            template_names = list(self.templates)

        all_matches = []
        for template_name in template_names:
            all_matches.extend(
                self.find_template(image, template_name, threshold)
            )

        all_matches.sort(key=lambda m: m.confidence, reverse=True)
        return all_matches

    def wait_for_template(self, capture_func, template_name: str,
                          timeout: float = 10.0, check_interval: float = 0.5,
                          threshold: Optional[float] = None) -> Optional[TemplateMatch]:
        """Wait for template to appear on screen.

        The screen is always checked at least once, even with a zero
        timeout, and the wait never sleeps past the deadline. If the
        template is not loaded, ``None`` is returned immediately instead
        of polling until the timeout.

        Args:
            capture_func: Function that returns screenshot
            template_name: Template to wait for
            timeout: Maximum wait time in seconds
            check_interval: Time between checks in seconds
            threshold: Confidence threshold override

        Returns:
            Best match found, or None if timeout
        """
        import time

        if template_name not in self.templates:
            logger.warning("Template '%s' not found", template_name)
            return None

        # Monotonic clock: unaffected by system clock adjustments.
        deadline = time.monotonic() + timeout
        while True:
            matches = self.find_template(
                capture_func(), template_name, threshold
            )
            if matches:
                return matches[0]  # Matches are ordered best first

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None
            time.sleep(min(check_interval, remaining))

    def get_template_info(self, template_name: str) -> Optional[TemplateInfo]:
        """Get information about loaded template.

        Args:
            template_name: Name of template

        Returns:
            TemplateInfo object or None if not found
        """
        return self.templates.get(template_name)

    def list_templates(self) -> List[str]:
        """Get list of all loaded template names.

        Returns:
            List of template names
        """
        return list(self.templates)

    def create_debug_image(self, image: np.ndarray, matches: List[TemplateMatch]) -> np.ndarray:
        """Create debug image showing template matches.

        Draws each match's bounding box, center point and a
        "name: confidence" label onto a copy of the image; the input is
        not modified.

        Args:
            image: Original image
            matches: List of matches to highlight

        Returns:
            Debug image with matches drawn
        """
        debug_img = image.copy()

        for match in matches:
            # Plain ints: coordinates may arrive as NumPy scalars.
            x, y, w, h = (int(v) for v in match.bbox)
            center_x, center_y = (int(v) for v in match.center)

            # Bounding box
            cv2.rectangle(debug_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
            # Center point
            cv2.circle(debug_img, (center_x, center_y), 5, (255, 0, 0), -1)

            # Label goes above the box; if the box touches the top edge the
            # text would be drawn off-canvas, so place it below instead.
            label = f"{match.template_name}: {match.confidence:.2f}"
            label_y = y - 10 if y >= 25 else y + h + 20
            cv2.putText(debug_img, label, (x, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        return debug_img