Initial project structure: reusable isometric bot engine with D2R implementation
This commit is contained in:
commit
e0282a7111
44 changed files with 3433 additions and 0 deletions
403
engine/screen/template.py
Normal file
403
engine/screen/template.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
"""Template matching for UI element detection in game screenshots.
|
||||
|
||||
Provides efficient template matching using OpenCV with support for
|
||||
multiple templates, confidence thresholds, and template management.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Tuple, NamedTuple
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# Module-level logger named after this module; no handlers or levels are
# configured here.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TemplateMatch(NamedTuple):
    """Represents a template match with position and confidence.

    Instances are immutable tuples whose field order is
    ``(template_name, confidence, center, bbox)``.
    """

    # Name of the template that matched. TemplateMatcher.match_template fills
    # this with "" and TemplateManager.find_template substitutes the real name.
    template_name: str
    # Score read from the cv2.matchTemplate result map at the match location.
    confidence: float
    # (x, y) center position: top-left corner plus half the template size.
    center: Tuple[int, int]
    # (x, y, width, height), where (x, y) is the top-left corner of the match.
    bbox: Tuple[int, int, int, int]
|
||||
|
||||
|
||||
@dataclass
class TemplateInfo:
    """Information about a loaded template."""

    # Identifier under which the template is stored in TemplateManager.templates.
    name: str
    # Pixel data; TemplateManager.load_template stores the cv2.imread result here.
    image: np.ndarray
    # Template width in pixels (image.shape[1] when loaded by TemplateManager).
    width: int
    # Template height in pixels (image.shape[0] when loaded by TemplateManager).
    height: int
    # Source file path as a string; None when not supplied.
    path: Optional[str] = None
    # NOTE(review): the dataclass-generated __eq__ also compares `image`;
    # comparing instances that hold different multi-element arrays is expected
    # to raise ValueError (ambiguous array truth value) - confirm before
    # relying on `==` between TemplateInfo objects.
|
||||
|
||||
|
||||
class TemplateMatcher:
    """Core template matching functionality.

    Runs ``cv2.matchTemplate`` on grayscale versions of the inputs, keeps
    every location whose score reaches the threshold and collapses
    overlapping hits with greedy non-maximum suppression (NMS).

    NOTE: thresholding assumes that a higher score means a better match.
    The ``cv2.TM_SQDIFF`` variants (where lower is better) are not handled.
    """

    def __init__(self, method: int = cv2.TM_CCOEFF_NORMED,
                 threshold: float = 0.8):
        """Initialize template matcher.

        Args:
            method: OpenCV template matching method
            threshold: Minimum confidence threshold (0.0 to 1.0)
        """
        self.method = method
        self.threshold = threshold

    @staticmethod
    def _to_gray(image: np.ndarray) -> np.ndarray:
        """Return a single-channel 2-D version of ``image``.

        Arrays that are not 3-dimensional are returned unchanged.
        """
        if image.ndim != 3:
            return image
        if image.shape[2] == 1:
            # Already single channel; cvtColor(BGR2GRAY) would reject it.
            return image.reshape(image.shape[:2])
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def match_template(self, image: np.ndarray, template: np.ndarray,
                       threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Match single template in image.

        Args:
            image: Source image to search in
            template: Template image to find
            threshold: Confidence threshold override; ``None`` uses the
                matcher's default

        Returns:
            Matches that survive NMS (overlap threshold 0.3), ordered by
            descending confidence. ``template_name`` is left empty. An empty
            list is returned when the template does not fit inside the image.
        """
        if threshold is None:
            threshold = self.threshold

        image_gray = self._to_gray(image)
        template_gray = self._to_gray(template)

        template_h, template_w = template_gray.shape
        image_h, image_w = image_gray.shape[:2]

        # cv2.matchTemplate raises when the template is empty or larger than
        # the searched image; treat that as "no match" instead of crashing
        # (this happens naturally when up-scaling templates).
        if (template_h == 0 or template_w == 0
                or template_h > image_h or template_w > image_w):
            return []

        result = cv2.matchTemplate(image_gray, template_gray, self.method)

        # Keep finite scores only: an infinite score would otherwise pass the
        # threshold test and surface as a bogus top-confidence match.
        rows, cols = np.where(np.isfinite(result) & (result >= threshold))

        half_w = template_w // 2
        half_h = template_h // 2

        matches = []
        for row, col in zip(rows, cols):
            # Convert NumPy scalars to native Python numbers so the match
            # honours its declared int/float field types.
            x, y = int(col), int(row)
            matches.append(TemplateMatch(
                "",
                float(result[row, col]),
                (x + half_w, y + half_h),
                (x, y, template_w, template_h),
            ))

        # Remove overlapping matches (Non-Maximum Suppression)
        return self._apply_nms(matches, overlap_threshold=0.3)

    def match_multiple_scales(self, image: np.ndarray, template: np.ndarray,
                              scales: Optional[List[float]] = None,
                              threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Match template at multiple scales.

        Args:
            image: Source image
            template: Template image
            scales: List of scale factors to try; ``None`` uses
                0.8, 0.9, 1.0, 1.1 and 1.2
            threshold: Confidence threshold

        Returns:
            Matches from all scales after a final NMS pass (overlap
            threshold 0.5), ordered by descending confidence. Scales that
            shrink the template below 10 pixels on either side are skipped.
        """
        if scales is None:
            scales = [0.8, 0.9, 1.0, 1.1, 1.2]

        base_h, base_w = template.shape[:2]
        all_matches = []

        for scale in scales:
            new_width = int(base_w * scale)
            new_height = int(base_h * scale)

            if new_width < 10 or new_height < 10:
                continue  # Skip very small templates

            if new_width == base_w and new_height == base_h:
                # Same size: use the original pixels, no resampling needed.
                scaled_template = template
            else:
                scaled_template = cv2.resize(template, (new_width, new_height))

            all_matches.extend(
                self.match_template(image, scaled_template, threshold)
            )

        # Apply NMS across all scales
        return self._apply_nms(all_matches, overlap_threshold=0.5)

    def _apply_nms(self, matches: List[TemplateMatch],
                   overlap_threshold: float = 0.3) -> List[TemplateMatch]:
        """Apply Non-Maximum Suppression to remove overlapping matches.

        Matches are visited from highest to lowest confidence; a match is
        kept only if its overlap with every already-kept match does not
        exceed ``overlap_threshold``.

        Args:
            matches: List of template matches
            overlap_threshold: Maximum allowed overlap ratio

        Returns:
            Filtered list of matches, ordered by descending confidence
        """
        if not matches:
            return matches

        ranked = sorted(matches, key=lambda m: m.confidence, reverse=True)

        kept: List[TemplateMatch] = []
        for candidate in ranked:
            overlaps_kept = any(
                self._calculate_overlap(candidate, other) > overlap_threshold
                for other in kept
            )
            if not overlaps_kept:
                kept.append(candidate)

        return kept

    def _calculate_overlap(self, match1: TemplateMatch, match2: TemplateMatch) -> float:
        """Calculate overlap ratio between two matches.

        The ratio is intersection area divided by union area of the two
        bounding boxes.

        Args:
            match1: First match
            match2: Second match

        Returns:
            Overlap ratio (0.0 to 1.0)
        """
        x1, y1, w1, h1 = match1.bbox
        x2, y2, w2, h2 = match2.bbox

        # Intersection rectangle
        left = max(x1, x2)
        right = min(x1 + w1, x2 + w2)
        top = max(y1, y2)
        bottom = min(y1 + h1, y2 + h2)

        if left >= right or top >= bottom:
            return 0.0

        intersection = (right - left) * (bottom - top)
        union = w1 * h1 + w2 * h2 - intersection

        return intersection / union if union > 0 else 0.0
|
||||
|
||||
|
||||
class TemplateManager:
    """Manages a collection of templates for game UI detection.

    Templates are stored by name in ``self.templates`` and matched through
    a default-configured ``TemplateMatcher`` held in ``self.matcher``.
    """

    def __init__(self, template_dir: Optional[Path] = None):
        """Initialize template manager.

        Args:
            template_dir: Directory containing template images. When it is
                an existing directory, its templates are loaded immediately.
        """
        # Accept plain string paths as well; a str has no .is_dir().
        if template_dir is not None and not isinstance(template_dir, Path):
            template_dir = Path(template_dir)

        self.template_dir = template_dir
        self.templates: Dict[str, TemplateInfo] = {}
        self.matcher = TemplateMatcher()

        if template_dir is not None and template_dir.is_dir():
            self.load_templates_from_directory(template_dir)

    def load_template(self, name: str, image_path: Path) -> bool:
        """Load single template from file.

        An existing template with the same name is replaced.

        Args:
            name: Template identifier
            image_path: Path to template image

        Returns:
            True if loaded successfully
        """
        try:
            image = cv2.imread(str(image_path))
            if image is None:
                logger.error(f"Could not load template image: {image_path}")
                return False

            height, width = image.shape[:2]

            self.templates[name] = TemplateInfo(
                name=name,
                image=image,
                width=width,
                height=height,
                path=str(image_path),
            )

            logger.info(f"Loaded template '{name}' ({width}x{height})")
            return True

        except Exception as e:
            # Loading is best-effort: report the failure and let the caller
            # continue with the remaining templates.
            logger.error(f"Failed to load template '{name}': {e}")
            return False

    def load_templates_from_directory(self, directory: Path) -> int:
        """Load all templates from directory.

        Every ``*.png`` file directly inside ``directory`` is loaded under
        its file stem (file name without extension).

        Args:
            directory: Directory containing template images

        Returns:
            Number of templates loaded
        """
        directory = Path(directory)
        loaded_count = 0

        # Sort so that load order (and the log output) is deterministic.
        for image_path in sorted(directory.glob("*.png")):
            if self.load_template(image_path.stem, image_path):
                loaded_count += 1

        logger.info(f"Loaded {loaded_count} templates from {directory}")
        return loaded_count

    def find_template(self, image: np.ndarray, template_name: str,
                      threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find specific template in image.

        Args:
            image: Source image
            template_name: Name of template to find
            threshold: Confidence threshold override

        Returns:
            List of matches found, each tagged with ``template_name``;
            empty when the template is not loaded
        """
        template_info = self.templates.get(template_name)
        if template_info is None:
            logger.warning(f"Template '{template_name}' not found")
            return []

        matches = self.matcher.match_template(image, template_info.image, threshold)

        # The matcher leaves the name empty; stamp it on every match.
        return [match._replace(template_name=template_name) for match in matches]

    def find_any_template(self, image: np.ndarray,
                          template_names: Optional[List[str]] = None,
                          threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find any of the specified templates in image.

        Args:
            image: Source image
            template_names: List of template names to search for, or None for all
            threshold: Confidence threshold override

        Returns:
            List of all matches found, ordered by descending confidence
        """
        if template_names is None:
            template_names = list(self.templates)

        all_matches: List[TemplateMatch] = []
        for template_name in template_names:
            all_matches.extend(self.find_template(image, template_name, threshold))

        # Sort by confidence
        all_matches.sort(key=lambda m: m.confidence, reverse=True)
        return all_matches

    def wait_for_template(self, capture_func, template_name: str,
                          timeout: float = 10.0, check_interval: float = 0.5,
                          threshold: Optional[float] = None) -> Optional[TemplateMatch]:
        """Wait for template to appear on screen.

        The screen is always checked at least once, even when ``timeout``
        is zero or negative.

        Args:
            capture_func: Function that returns screenshot
            template_name: Template to wait for
            timeout: Maximum wait time in seconds
            check_interval: Time between checks in seconds
            threshold: Confidence threshold override

        Returns:
            Best match found, or None on timeout or when the template is
            not loaded
        """
        import time

        # A template that is not loaded can never match; fail fast instead
        # of polling (and logging a warning) until the timeout expires.
        if template_name not in self.templates:
            logger.warning(f"Template '{template_name}' not found")
            return None

        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout

        while True:
            image = capture_func()

            # A failed capture (None) is treated as "nothing found yet".
            if image is not None:
                matches = self.find_template(image, template_name, threshold)
                if matches:
                    return matches[0]  # Matches are sorted best-first

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None

            # Never sleep past the deadline.
            time.sleep(max(0.0, min(check_interval, remaining)))

    def get_template_info(self, template_name: str) -> Optional[TemplateInfo]:
        """Get information about loaded template.

        Args:
            template_name: Name of template

        Returns:
            TemplateInfo object or None if not found
        """
        return self.templates.get(template_name)

    def list_templates(self) -> List[str]:
        """Get list of all loaded template names.

        Returns:
            List of template names
        """
        return list(self.templates)

    def create_debug_image(self, image: np.ndarray, matches: List[TemplateMatch]) -> np.ndarray:
        """Create debug image showing template matches.

        The input image is not modified; drawing happens on a copy.

        Args:
            image: Original image
            matches: List of matches to highlight

        Returns:
            Debug image with matches drawn
        """
        debug_img = image.copy()

        for match in matches:
            # OpenCV drawing calls expect plain ints.
            x, y, w, h = (int(v) for v in match.bbox)
            center_x, center_y = (int(v) for v in match.center)

            # Draw bounding box
            cv2.rectangle(debug_img, (x, y), (x + w, y + h), (0, 255, 0), 2)

            # Draw center point
            cv2.circle(debug_img, (center_x, center_y), 5, (255, 0, 0), -1)

            # Draw template name and confidence. Clamp the baseline so the
            # label of a match near the top edge stays on the canvas.
            label = f"{match.template_name}: {match.confidence:.2f}"
            label_y = max(y - 10, 20)
            cv2.putText(debug_img, label, (x, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        return debug_img
|
||||
Loading…
Add table
Add a link
Reference in a new issue