Initial project structure: reusable isometric bot engine with D2R implementation
This commit is contained in:
commit
e0282a7111
44 changed files with 3433 additions and 0 deletions
23
engine/screen/__init__.py
Normal file
23
engine/screen/__init__.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""Screen reading components for visual game state detection.
|
||||
|
||||
This module provides tools for capturing, analyzing, and extracting information
|
||||
from game screenshots without requiring memory access or game modification.
|
||||
|
||||
Components:
|
||||
- capture: Screenshot capture using various backends
|
||||
- ocr: Optical Character Recognition for text extraction
|
||||
- template: Template matching for UI element detection
|
||||
"""
|
||||
|
||||
from .capture import ScreenCapture, ScreenRegion
|
||||
from .ocr import OCREngine, TextDetector
|
||||
from .template import TemplateManager, TemplateMatcher
|
||||
|
||||
__all__ = [
|
||||
"ScreenCapture",
|
||||
"ScreenRegion",
|
||||
"OCREngine",
|
||||
"TextDetector",
|
||||
"TemplateManager",
|
||||
"TemplateMatcher",
|
||||
]
|
||||
220
engine/screen/capture.py
Normal file
220
engine/screen/capture.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""Screen capture utilities for taking game screenshots.
|
||||
|
||||
Provides efficient screenshot capture using multiple backends (mss, PIL)
|
||||
with support for specific regions and window targeting.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
import time
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageGrab
|
||||
import mss
|
||||
import cv2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ScreenRegion:
    """A rectangular screen area given by its top-left corner and size."""

    x: int       # left edge, pixels
    y: int       # top edge, pixels
    width: int   # region width, pixels
    height: int  # region height, pixels

    @property
    def bounds(self) -> Tuple[int, int, int, int]:
        """Return region as (left, top, right, bottom) tuple."""
        left, top = self.x, self.y
        return (left, top, left + self.width, top + self.height)

    @property
    def mss_bounds(self) -> Dict[str, int]:
        """Return region as the monitor-dict format expected by mss.grab()."""
        box = {
            "top": self.y,
            "left": self.x,
            "width": self.width,
            "height": self.height,
        }
        return box
|
||||
|
||||
|
||||
class ScreenCapture:
    """High-performance screen capture with multiple backends.

    Uses the ``mss`` library by default and silently downgrades to
    ``PIL.ImageGrab`` if MSS fails to initialize. All captures are returned
    as BGR numpy arrays so they can be fed directly to OpenCV.
    """

    def __init__(self, backend: str = "mss", monitor: int = 1):
        """Initialize screen capture.

        Args:
            backend: Capture backend ("mss" or "pil")
            monitor: Monitor number to capture from (1-indexed)
        """
        self.backend = backend
        self.monitor = monitor
        # MSS handle; stays None for the "pil" backend (and after MSS failure).
        self._mss_instance: Optional[mss.mss] = None
        # Cached monitor geometry dict used for full-screen grabs.
        self._monitor_info: Optional[Dict[str, int]] = None

        if backend == "mss":
            self._initialize_mss()

    def _initialize_mss(self) -> None:
        """Initialize MSS backend.

        On any failure this logs and downgrades the instance to the "pil"
        backend instead of raising.
        """
        try:
            self._mss_instance = mss.mss()
            # NOTE(review): mss convention is that monitors[0] is the combined
            # virtual screen, which is why monitor numbers here are 1-indexed
            # — confirm against the mss documentation.
            monitors = self._mss_instance.monitors

            if self.monitor >= len(monitors):
                logger.warning(f"Monitor {self.monitor} not found, using primary")
                self.monitor = 1

            self._monitor_info = monitors[self.monitor]
            logger.info(f"Initialized MSS capture for monitor {self.monitor}: "
                        f"{self._monitor_info['width']}x{self._monitor_info['height']}")

        except Exception as e:
            logger.error(f"Failed to initialize MSS: {e}")
            self.backend = "pil"

    def capture_screen(self, region: Optional[ScreenRegion] = None) -> np.ndarray:
        """Capture screenshot of screen or region.

        Args:
            region: Specific region to capture, or None for full screen

        Returns:
            Screenshot as numpy array in BGR format (for OpenCV compatibility).
            On any capture failure a 100x100 black image is returned, so
            callers never receive None or an exception.
        """
        try:
            if self.backend == "mss":
                return self._capture_mss(region)
            else:
                return self._capture_pil(region)

        except Exception as e:
            logger.error(f"Screen capture failed: {e}")
            # Fallback to empty image
            return np.zeros((100, 100, 3), dtype=np.uint8)

    def _capture_mss(self, region: Optional[ScreenRegion]) -> np.ndarray:
        """Capture using MSS backend.

        Raises:
            RuntimeError: If MSS was never (successfully) initialized; the
                caller (capture_screen) converts this into the black-image
                fallback.
        """
        if not self._mss_instance:
            raise RuntimeError("MSS not initialized")

        if region:
            monitor = region.mss_bounds
        else:
            monitor = self._monitor_info or self._mss_instance.monitors[self.monitor]

        # ScreenShot.rgb exposes the grabbed pixels as packed RGB bytes
        # (mss stores BGRA internally; the .rgb property converts).
        screenshot = self._mss_instance.grab(monitor)
        img_array = np.frombuffer(screenshot.rgb, dtype=np.uint8)
        img_array = img_array.reshape((screenshot.height, screenshot.width, 3))

        # Convert RGB to BGR for OpenCV
        return cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

    def _capture_pil(self, region: Optional[ScreenRegion]) -> np.ndarray:
        """Capture using PIL backend."""
        if region:
            bbox = region.bounds
        else:
            bbox = None

        # PIL returns RGB format
        screenshot = ImageGrab.grab(bbox=bbox)
        img_array = np.array(screenshot)

        # Convert RGB to BGR for OpenCV
        return cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

    def save_screenshot(self, filename: str, region: Optional[ScreenRegion] = None) -> bool:
        """Save screenshot to file.

        Args:
            filename: Output filename (extension selects the image format
                via cv2.imwrite)
            region: Region to capture, or None for full screen

        Returns:
            True if successful, False otherwise
        """
        try:
            img = self.capture_screen(region)
            return cv2.imwrite(filename, img)

        except Exception as e:
            logger.error(f"Failed to save screenshot: {e}")
            return False

    def get_screen_size(self) -> Tuple[int, int]:
        """Get screen dimensions.

        Returns:
            (width, height) tuple
        """
        if self.backend == "mss" and self._monitor_info:
            return (self._monitor_info["width"], self._monitor_info["height"])
        else:
            # Use PIL as fallback — grabs a full screenshot just to read its size.
            screenshot = ImageGrab.grab()
            return screenshot.size

    def find_window(self, window_title: str) -> Optional[ScreenRegion]:
        """Find window by title and return its region.

        Args:
            window_title: Partial or full window title to search for

        Returns:
            ScreenRegion if window found, None otherwise

        Note:
            This is a placeholder - actual implementation would use
            platform-specific window enumeration (e.g., Windows API, X11)
        """
        # TODO: Implement window finding
        logger.warning("Window finding not implemented yet")
        return None

    def benchmark_capture(self, iterations: int = 100) -> Dict[str, Any]:
        """Benchmark capture performance.

        Args:
            iterations: Number of captures to perform

        Returns:
            Performance statistics: backend name, iteration count, total and
            average wall time, and effective frames per second.
        """
        logger.info(f"Benchmarking {self.backend} backend ({iterations} iterations)")

        start_time = time.perf_counter()

        for _ in range(iterations):
            self.capture_screen()

        end_time = time.perf_counter()
        total_time = end_time - start_time
        avg_time = total_time / iterations
        fps = iterations / total_time

        stats = {
            "backend": self.backend,
            "iterations": iterations,
            "total_time": total_time,
            "avg_time_ms": avg_time * 1000,
            "fps": fps,
        }

        logger.info(f"Benchmark results: {avg_time*1000:.2f}ms avg, {fps:.1f} FPS")
        return stats

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit; releases the MSS handle if one was opened.

        Returns None implicitly, so any in-flight exception propagates.
        """
        if self._mss_instance:
            self._mss_instance.close()
|
||||
346
engine/screen/ocr.py
Normal file
346
engine/screen/ocr.py
Normal file
|
|
@ -0,0 +1,346 @@
|
|||
"""OCR (Optical Character Recognition) for extracting text from screenshots.
|
||||
|
||||
Provides text detection and extraction capabilities using pytesseract
|
||||
with preprocessing for better accuracy in game environments.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Tuple, NamedTuple
|
||||
import logging
|
||||
import re
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextMatch(NamedTuple):
    """Represents detected text with position and confidence."""
    # Recognized string (whitespace-stripped by the producing OCREngine methods).
    text: str
    # Tesseract confidence score; compared against OCRConfig.min_confidence.
    confidence: float
    # NOTE(review): bbox is in *preprocessed*-image coordinates — when
    # OCRConfig.scale_factor != 1 these are scaled, not raw-screen, pixels.
    bbox: Tuple[int, int, int, int]  # (x, y, width, height)
|
||||
|
||||
|
||||
class OCRConfig:
    """Tunable settings for OCR processing.

    All attributes are plain instance fields; adjust them after construction
    to customize behavior.
    """

    def __init__(self):
        # --- Tesseract invocation ---
        self.language = "eng"
        self.tesseract_config = "--oem 3 --psm 6"  # Default config
        self.min_confidence = 30.0

        # --- Image preprocessing toggles ---
        self.preprocess = True
        self.denoise = True
        self.contrast_enhance = True
        self.scale_factor = 2.0

        # --- Text filtering ---
        self.min_text_length = 1
        self.filter_patterns = [
            # Alphanumeric with common punctuation
            r'^[a-zA-Z0-9\s\-_:.,/]+$',
        ]
|
||||
|
||||
|
||||
class OCREngine:
    """OCR engine for text extraction from game screenshots."""

    def __init__(self, config: Optional[OCRConfig] = None):
        """Initialize OCR engine.

        Args:
            config: OCR configuration, or None for defaults

        Raises:
            RuntimeError: If the tesseract binary is missing or unusable.
        """
        self.config = config or OCRConfig()
        self._verify_tesseract()

    def _verify_tesseract(self) -> None:
        """Verify tesseract installation.

        Raises:
            RuntimeError: If pytesseract cannot reach the tesseract binary.
        """
        try:
            pytesseract.get_tesseract_version()
            logger.info("Tesseract initialized successfully")
        except Exception as e:
            logger.error(f"Tesseract not found or not working: {e}")
            raise RuntimeError("Tesseract OCR is required but not available")

    def extract_text(self, image: np.ndarray, region: Optional[Tuple[int, int, int, int]] = None) -> str:
        """Extract all text from image.

        Args:
            image: Input image as numpy array (BGR, as produced by ScreenCapture)
            region: Optional (x, y, width, height) region to process

        Returns:
            Extracted text as string; empty string on any OCR failure.
        """
        processed_img = self._preprocess_image(image, region)

        try:
            text = pytesseract.image_to_string(
                processed_img,
                lang=self.config.language,
                config=self.config.tesseract_config
            )

            return self._clean_text(text)

        except Exception as e:
            logger.error(f"OCR extraction failed: {e}")
            return ""

    def find_text(self, image: np.ndarray, search_text: str,
                  case_sensitive: bool = False) -> List[TextMatch]:
        """Find specific text in image with positions.

        Matching is substring-based per OCR word/box, not whole-word.

        Args:
            image: Input image as numpy array
            search_text: Text to search for
            case_sensitive: Whether search should be case sensitive

        Returns:
            List of TextMatch objects for found text; empty list on failure.
        """
        processed_img = self._preprocess_image(image)

        try:
            # Get detailed OCR data (parallel lists indexed by detection box)
            data = pytesseract.image_to_data(
                processed_img,
                lang=self.config.language,
                config=self.config.tesseract_config,
                output_type=pytesseract.Output.DICT
            )

            matches = []
            search_lower = search_text.lower() if not case_sensitive else search_text

            for i in range(len(data['text'])):
                text = data['text'][i].strip()
                # NOTE(review): tesseract reports conf -1 for structural
                # (non-text) boxes; those fall below min_confidence and are
                # skipped — confirm against pytesseract image_to_data docs.
                confidence = float(data['conf'][i])

                if confidence < self.config.min_confidence:
                    continue

                text_to_match = text.lower() if not case_sensitive else text

                if search_lower in text_to_match:
                    # bbox is in preprocessed-image coordinates
                    # (see _preprocess_image scaling).
                    bbox = (
                        data['left'][i],
                        data['top'][i],
                        data['width'][i],
                        data['height'][i]
                    )

                    matches.append(TextMatch(text, confidence, bbox))

            return matches

        except Exception as e:
            logger.error(f"Text search failed: {e}")
            return []

    def get_text_regions(self, image: np.ndarray) -> List[TextMatch]:
        """Get all text regions with positions and confidence.

        Args:
            image: Input image as numpy array

        Returns:
            List of TextMatch objects for all detected text that passes the
            confidence, length, and pattern filters; empty list on failure.
        """
        processed_img = self._preprocess_image(image)

        try:
            data = pytesseract.image_to_data(
                processed_img,
                lang=self.config.language,
                config=self.config.tesseract_config,
                output_type=pytesseract.Output.DICT
            )

            text_regions = []

            for i in range(len(data['text'])):
                text = data['text'][i].strip()
                confidence = float(data['conf'][i])

                # Drop low-confidence and too-short detections.
                if (confidence < self.config.min_confidence or
                    len(text) < self.config.min_text_length):
                    continue

                if not self._passes_text_filters(text):
                    continue

                bbox = (
                    data['left'][i],
                    data['top'][i],
                    data['width'][i],
                    data['height'][i]
                )

                text_regions.append(TextMatch(text, confidence, bbox))

            return text_regions

        except Exception as e:
            logger.error(f"Text region detection failed: {e}")
            return []

    def _preprocess_image(self, image: np.ndarray,
                          region: Optional[Tuple[int, int, int, int]] = None) -> Image.Image:
        """Preprocess image for better OCR accuracy.

        Pipeline: optional crop -> grayscale -> upscale -> denoise -> CLAHE.

        Args:
            image: Input image as numpy array; assumed 3-channel BGR — a
                grayscale input would make cv2.COLOR_BGR2GRAY fail.
            region: Optional region to extract. No bounds checking is done;
                numpy slicing silently clips out-of-range coordinates.

        Returns:
            Preprocessed PIL Image
        """
        # Extract region if specified
        if region:
            x, y, w, h = region
            image = image[y:y+h, x:x+w]

        if not self.config.preprocess:
            return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        # Convert to grayscale
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Scale up for better OCR — detected bboxes will be in these
        # scaled coordinates.
        if self.config.scale_factor > 1.0:
            height, width = gray.shape
            new_width = int(width * self.config.scale_factor)
            new_height = int(height * self.config.scale_factor)
            gray = cv2.resize(gray, (new_width, new_height), interpolation=cv2.INTER_CUBIC)

        # Denoise
        if self.config.denoise:
            gray = cv2.fastNlMeansDenoising(gray)

        # Enhance contrast
        if self.config.contrast_enhance:
            # Use CLAHE (Contrast Limited Adaptive Histogram Equalization)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
            gray = clahe.apply(gray)

        # Convert back to PIL Image
        return Image.fromarray(gray)

    def _clean_text(self, text: str) -> str:
        """Clean extracted text.

        Args:
            text: Raw extracted text

        Returns:
            Cleaned text with collapsed whitespace and common OCR artifacts
            removed
        """
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text.strip())

        # Remove common OCR artifacts
        text = re.sub(r'[|¦]', 'I', text)  # Vertical bars to I
        text = re.sub(r'[{}]', '', text)  # Remove braces

        return text

    def _passes_text_filters(self, text: str) -> bool:
        """Check if text passes configured filters.

        Args:
            text: Text to check

        Returns:
            True if no filters are configured, or if ANY pattern matches
            (patterns are OR-ed, anchored at the start via re.match)
        """
        if not self.config.filter_patterns:
            return True

        for pattern in self.config.filter_patterns:
            if re.match(pattern, text):
                return True

        return False
|
||||
|
||||
|
||||
class TextDetector:
    """High-level text detection interface built on OCREngine."""

    def __init__(self, ocr_config: Optional[OCRConfig] = None):
        """Create a detector backed by a fresh OCREngine.

        Args:
            ocr_config: OCR configuration, or None for defaults
        """
        self.ocr = OCREngine(ocr_config)
        # Reserved for memoizing OCR results per screenshot; currently unused.
        self.text_cache: Dict[str, List[TextMatch]] = {}

    def contains_text(self, image: np.ndarray, text: str,
                      case_sensitive: bool = False) -> bool:
        """Check whether the given text appears anywhere in the image.

        Args:
            image: Input image
            text: Text to search for
            case_sensitive: Case sensitive search

        Returns:
            True if text found
        """
        return bool(self.ocr.find_text(image, text, case_sensitive))

    def wait_for_text(self, capture_func, text: str, timeout: float = 10.0,
                      check_interval: float = 0.5) -> bool:
        """Poll the screen until the given text shows up or time runs out.

        Args:
            capture_func: Function that returns a screenshot
            text: Text to wait for
            timeout: Maximum wait time in seconds
            check_interval: Time between checks in seconds

        Returns:
            True if the text appeared, False on timeout
        """
        import time

        deadline = time.time() + timeout
        while time.time() < deadline:
            if self.contains_text(capture_func(), text):
                return True
            time.sleep(check_interval)

        return False

    def get_ui_text(self, image: np.ndarray) -> Dict[str, str]:
        """Extract common UI text elements.

        Args:
            image: Input image

        Returns:
            Dictionary mapping UI element names to detected text
        """
        # Placeholder for game-specific UI extraction; a real implementation
        # would read fixed regions for health, mana, inventory, etc.
        ui_text = {}
        for region in self.ocr.get_text_regions(image):
            lowered = region.text.lower()
            # Categorize by keyword for now; extend with more UI elements.
            if "health" in lowered:
                ui_text["health"] = region.text
            elif "mana" in lowered:
                ui_text["mana"] = region.text

        return ui_text
|
||||
403
engine/screen/template.py
Normal file
403
engine/screen/template.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
"""Template matching for UI element detection in game screenshots.
|
||||
|
||||
Provides efficient template matching using OpenCV with support for
|
||||
multiple templates, confidence thresholds, and template management.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Tuple, NamedTuple
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TemplateMatch(NamedTuple):
    """Represents a template match with position and confidence."""
    # Name of the matched template; empty string for anonymous matches
    # produced directly by TemplateMatcher.match_template.
    template_name: str
    # Matching score from cv2.matchTemplate (scale depends on the method used).
    confidence: float
    center: Tuple[int, int]  # (x, y) center position
    bbox: Tuple[int, int, int, int]  # (x, y, width, height)
|
||||
|
||||
|
||||
@dataclass
class TemplateInfo:
    """Information about a loaded template."""
    name: str            # template identifier (file stem when loaded from disk)
    image: np.ndarray    # template pixels as loaded by cv2.imread (BGR)
    width: int           # image width in pixels
    height: int          # image height in pixels
    path: Optional[str] = None  # source file path, if loaded from a file
|
||||
|
||||
|
||||
class TemplateMatcher:
    """Core template matching functionality."""

    def __init__(self, method: int = cv2.TM_CCOEFF_NORMED,
                 threshold: float = 0.8):
        """Initialize template matcher.

        Args:
            method: OpenCV template matching method
            threshold: Minimum confidence threshold (0.0 to 1.0)
        """
        self.method = method
        self.threshold = threshold

    def match_template(self, image: np.ndarray, template: np.ndarray,
                       threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Match single template in image.

        NOTE(review): the `result >= threshold` filter assumes a method where
        higher scores are better (true for the TM_CCOEFF_NORMED default);
        TM_SQDIFF-style methods would need inverted logic — confirm before
        passing those.

        Args:
            image: Source image to search in (BGR or grayscale)
            template: Template image to find (BGR or grayscale)
            threshold: Confidence threshold override

        Returns:
            List of matches found (template_name left empty), deduplicated
            with NMS and sorted by descending confidence
        """
        if threshold is None:
            threshold = self.threshold

        # Convert to grayscale if needed
        if len(image.shape) == 3:
            image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            image_gray = image

        if len(template.shape) == 3:
            template_gray = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
        else:
            template_gray = template

        # Perform template matching
        result = cv2.matchTemplate(image_gray, template_gray, self.method)

        # Find matches above threshold
        locations = np.where(result >= threshold)

        matches = []
        template_h, template_w = template_gray.shape

        for pt in zip(*locations[::-1]):  # Switch x and y
            x, y = pt
            confidence = result[y, x]

            center = (x + template_w // 2, y + template_h // 2)
            bbox = (x, y, template_w, template_h)

            matches.append(TemplateMatch("", confidence, center, bbox))

        # Remove overlapping matches (Non-Maximum Suppression)
        matches = self._apply_nms(matches, overlap_threshold=0.3)

        return matches

    def match_multiple_scales(self, image: np.ndarray, template: np.ndarray,
                              scales: Optional[List[float]] = None,
                              threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Match template at multiple scales.

        Args:
            image: Source image
            template: Template image
            scales: List of scale factors to try (default: 0.8-1.2 in 0.1 steps)
            threshold: Confidence threshold

        Returns:
            List of matches at all scales, deduplicated with NMS
        """
        if scales is None:
            scales = [0.8, 0.9, 1.0, 1.1, 1.2]

        all_matches = []

        for scale in scales:
            # Scale template
            new_width = int(template.shape[1] * scale)
            new_height = int(template.shape[0] * scale)

            if new_width < 10 or new_height < 10:
                continue  # Skip very small templates

            scaled_template = cv2.resize(template, (new_width, new_height))

            # Find matches at this scale
            matches = self.match_template(image, scaled_template, threshold)
            all_matches.extend(matches)

        # Apply NMS across all scales (looser overlap than single-scale)
        all_matches = self._apply_nms(all_matches, overlap_threshold=0.5)

        return all_matches

    def _apply_nms(self, matches: List[TemplateMatch],
                   overlap_threshold: float = 0.3) -> List[TemplateMatch]:
        """Apply Non-Maximum Suppression to remove overlapping matches.

        Greedy O(n^2) NMS: keep the highest-confidence match, then discard
        any later match whose overlap with a kept one exceeds the threshold.

        Args:
            matches: List of template matches
            overlap_threshold: Maximum allowed overlap ratio

        Returns:
            Filtered list of matches, sorted by descending confidence
        """
        if not matches:
            return matches

        # Sort by confidence (highest first)
        matches = sorted(matches, key=lambda x: x.confidence, reverse=True)

        filtered_matches = []

        for match in matches:
            # Check if this match overlaps significantly with any kept match
            is_duplicate = False

            for kept_match in filtered_matches:
                if self._calculate_overlap(match, kept_match) > overlap_threshold:
                    is_duplicate = True
                    break

            if not is_duplicate:
                filtered_matches.append(match)

        return filtered_matches

    def _calculate_overlap(self, match1: TemplateMatch, match2: TemplateMatch) -> float:
        """Calculate overlap ratio between two matches.

        This is intersection-over-union (IoU) of the two bounding boxes.

        Args:
            match1: First match
            match2: Second match

        Returns:
            Overlap ratio (0.0 to 1.0)
        """
        x1, y1, w1, h1 = match1.bbox
        x2, y2, w2, h2 = match2.bbox

        # Calculate intersection
        left = max(x1, x2)
        right = min(x1 + w1, x2 + w2)
        top = max(y1, y2)
        bottom = min(y1 + h1, y2 + h2)

        if left >= right or top >= bottom:
            return 0.0

        intersection = (right - left) * (bottom - top)
        area1 = w1 * h1
        area2 = w2 * h2
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0
|
||||
|
||||
|
||||
class TemplateManager:
    """Manages a collection of templates for game UI detection."""

    def __init__(self, template_dir: Optional[Path] = None):
        """Initialize template manager.

        Args:
            template_dir: Directory containing template images; when given
                and it exists, all *.png files in it are loaded eagerly.
        """
        self.template_dir = template_dir
        # name -> TemplateInfo for every successfully loaded template
        self.templates: Dict[str, TemplateInfo] = {}
        self.matcher = TemplateMatcher()

        if template_dir and template_dir.exists():
            self.load_templates_from_directory(template_dir)

    def load_template(self, name: str, image_path: Path) -> bool:
        """Load single template from file.

        Replaces any previously loaded template with the same name.

        Args:
            name: Template identifier
            image_path: Path to template image

        Returns:
            True if loaded successfully
        """
        try:
            image = cv2.imread(str(image_path))
            # cv2.imread returns None (no exception) for unreadable files
            if image is None:
                logger.error(f"Could not load template image: {image_path}")
                return False

            height, width = image.shape[:2]

            self.templates[name] = TemplateInfo(
                name=name,
                image=image,
                width=width,
                height=height,
                path=str(image_path)
            )

            logger.info(f"Loaded template '{name}' ({width}x{height})")
            return True

        except Exception as e:
            logger.error(f"Failed to load template '{name}': {e}")
            return False

    def load_templates_from_directory(self, directory: Path) -> int:
        """Load all templates from directory.

        Only *.png files are considered, non-recursively; each template is
        named after its file stem.

        Args:
            directory: Directory containing template images

        Returns:
            Number of templates loaded
        """
        loaded_count = 0

        for image_path in directory.glob("*.png"):
            template_name = image_path.stem
            if self.load_template(template_name, image_path):
                loaded_count += 1

        logger.info(f"Loaded {loaded_count} templates from {directory}")
        return loaded_count

    def find_template(self, image: np.ndarray, template_name: str,
                      threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find specific template in image.

        Args:
            image: Source image
            template_name: Name of template to find
            threshold: Confidence threshold override

        Returns:
            List of matches found; empty list if the template is unknown
        """
        if template_name not in self.templates:
            logger.warning(f"Template '{template_name}' not found")
            return []

        template_info = self.templates[template_name]
        matches = self.matcher.match_template(image, template_info.image, threshold)

        # Set template name in matches (TemplateMatch is immutable, so rebuild)
        named_matches = []
        for match in matches:
            named_match = TemplateMatch(
                template_name=template_name,
                confidence=match.confidence,
                center=match.center,
                bbox=match.bbox
            )
            named_matches.append(named_match)

        return named_matches

    def find_any_template(self, image: np.ndarray,
                          template_names: Optional[List[str]] = None,
                          threshold: Optional[float] = None) -> List[TemplateMatch]:
        """Find any of the specified templates in image.

        Args:
            image: Source image
            template_names: List of template names to search for, or None for all
            threshold: Confidence threshold override

        Returns:
            List of all matches found, sorted by confidence (best first)
        """
        if template_names is None:
            template_names = list(self.templates.keys())

        all_matches = []

        for template_name in template_names:
            matches = self.find_template(image, template_name, threshold)
            all_matches.extend(matches)

        # Sort by confidence
        all_matches.sort(key=lambda x: x.confidence, reverse=True)

        return all_matches

    def wait_for_template(self, capture_func, template_name: str,
                          timeout: float = 10.0, check_interval: float = 0.5,
                          threshold: Optional[float] = None) -> Optional[TemplateMatch]:
        """Wait for template to appear on screen.

        Args:
            capture_func: Function that returns screenshot
            template_name: Template to wait for
            timeout: Maximum wait time in seconds
            check_interval: Time between checks in seconds
            threshold: Confidence threshold override

        Returns:
            First match found, or None if timeout
        """
        import time

        start_time = time.time()

        while time.time() - start_time < timeout:
            image = capture_func()
            matches = self.find_template(image, template_name, threshold)

            if matches:
                return matches[0]  # Return best match (NMS keeps sorted order)

            time.sleep(check_interval)

        return None

    def get_template_info(self, template_name: str) -> Optional[TemplateInfo]:
        """Get information about loaded template.

        Args:
            template_name: Name of template

        Returns:
            TemplateInfo object or None if not found
        """
        return self.templates.get(template_name)

    def list_templates(self) -> List[str]:
        """Get list of all loaded template names.

        Returns:
            List of template names
        """
        return list(self.templates.keys())

    def create_debug_image(self, image: np.ndarray, matches: List[TemplateMatch]) -> np.ndarray:
        """Create debug image showing template matches.

        The input image is not modified; annotations are drawn on a copy.

        Args:
            image: Original image
            matches: List of matches to highlight

        Returns:
            Debug image with matches drawn
        """
        debug_img = image.copy()

        for match in matches:
            x, y, w, h = match.bbox

            # Draw bounding box
            cv2.rectangle(debug_img, (x, y), (x + w, y + h), (0, 255, 0), 2)

            # Draw center point
            center_x, center_y = match.center
            cv2.circle(debug_img, (center_x, center_y), 5, (255, 0, 0), -1)

            # Draw template name and confidence
            label = f"{match.template_name}: {match.confidence:.2f}"
            cv2.putText(debug_img, label, (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        return debug_img
|
||||
Loading…
Add table
Add a link
Reference in a new issue