140 lines
4.7 KiB
Python
140 lines
4.7 KiB
Python
"""Object and UI element detection using computer vision.

Provides high-level detection for game elements using template matching,
color filtering, and contour analysis.
"""
|
|
|
|
from typing import List, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
import logging
|
|
|
|
import numpy as np
|
|
import cv2
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class Detection:
    """A rectangular on-screen region together with a score and a label.

    ``x``/``y`` and ``width``/``height`` describe the rectangle; the
    ``confidence`` and ``label`` values are stored exactly as supplied.
    """

    x: int
    y: int
    width: int
    height: int
    confidence: float
    label: str = ""

    @property
    def center(self) -> Tuple[int, int]:
        """Return the ``(x, y)`` midpoint of the box (floor division)."""
        mid_x = self.x + self.width // 2
        mid_y = self.y + self.height // 2
        return (mid_x, mid_y)

    @property
    def bounds(self) -> Tuple[int, int, int, int]:
        """Return the box as ``(x, y, x + width, y + height)``."""
        left, top = self.x, self.y
        right = left + self.width
        bottom = top + self.height
        return (left, top, right, bottom)
|
|
|
|
|
|
class ElementDetector:
    """Detects game UI elements and objects via computer vision.

    Templates are registered by name with :meth:`load_template` and searched
    for with :meth:`find_template` / :meth:`find_all_templates`. Colored
    regions are located with :meth:`find_by_color`.
    """

    def __init__(self, confidence_threshold: float = 0.8):
        """Create a detector.

        Args:
            confidence_threshold: Minimum score a template match must reach
                to be reported by the template-matching methods.
        """
        self.confidence_threshold = confidence_threshold
        self._templates: dict[str, np.ndarray] = {}

    def load_template(self, name: str, image_path: str) -> None:
        """Load a template image from disk and register it under ``name``.

        An existing template with the same name is replaced.

        Raises:
            FileNotFoundError: If ``cv2.imread`` could not load ``image_path``.
        """
        template = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if template is None:
            raise FileNotFoundError(f"Template not found: {image_path}")
        self._templates[name] = template
        logger.debug("Loaded template '%s': %s", name, template.shape)

    def _match_scores(
        self, screen: np.ndarray, template: np.ndarray, method: int,
    ) -> Optional[np.ndarray]:
        """Return a "higher is better" score map, or None if matching is impossible."""
        tpl_h, tpl_w = template.shape[:2]
        if screen.shape[0] < tpl_h or screen.shape[1] < tpl_w:
            # A template larger than the screen cannot occur inside it, and
            # matchTemplate gives no meaningful result for that input.
            logger.debug(
                "Template %s does not fit in screen %s",
                template.shape, screen.shape,
            )
            return None

        result = cv2.matchTemplate(screen, template, method)
        if method in (cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED):
            # Squared-difference methods score the best match LOWEST; flip
            # them so every method can be thresholded the same way.
            result = 1.0 - result
        return result

    def find_template(
        self, screen: np.ndarray, template_name: str,
        method: int = cv2.TM_CCOEFF_NORMED,
    ) -> Optional[Detection]:
        """Find the best match of a template in the screen image.

        Args:
            screen: Image to search. Templates are loaded with
                ``cv2.IMREAD_COLOR``, so ``screen`` needs the same channel
                layout and dtype for ``cv2.matchTemplate`` to accept it.
            template_name: Name previously passed to :meth:`load_template`.
            method: ``cv2.TM_*`` comparison method. The threshold is only
                meaningful for the ``*_NORMED`` variants; squared-difference
                scores are reported as ``1 - value``.

        Returns:
            The best match if its score reaches ``confidence_threshold``,
            otherwise ``None`` (also for unknown templates or a template
            larger than the screen).
        """
        template = self._templates.get(template_name)
        if template is None:
            logger.error("Unknown template: %s", template_name)
            return None

        scores = self._match_scores(screen, template, method)
        if scores is None:
            return None

        _, best_score, _, best_loc = cv2.minMaxLoc(scores)
        if best_score >= self.confidence_threshold:
            tpl_h, tpl_w = template.shape[:2]
            return Detection(
                x=int(best_loc[0]), y=int(best_loc[1]),
                width=tpl_w, height=tpl_h,
                confidence=float(best_score), label=template_name,
            )
        return None

    def find_all_templates(
        self, screen: np.ndarray, template_name: str,
        method: int = cv2.TM_CCOEFF_NORMED,
    ) -> List[Detection]:
        """Find all matches of a template above the confidence threshold.

        Arguments are interpreted as in :meth:`find_template`. Nearby hits
        are collapsed by :meth:`_nms`, using half of the template's smaller
        side as the suppression distance.

        Returns:
            Surviving detections, highest confidence first; an empty list
            for unknown templates or a template larger than the screen.
        """
        template = self._templates.get(template_name)
        if template is None:
            logger.error("Unknown template: %s", template_name)
            return []

        scores = self._match_scores(screen, template, method)
        if scores is None:
            return []

        tpl_h, tpl_w = template.shape[:2]
        ys, xs = np.nonzero(scores >= self.confidence_threshold)
        # .tolist() yields plain Python ints/floats, so Detection fields do
        # not carry numpy scalar types.
        detections = [
            Detection(
                x=x, y=y, width=tpl_w, height=tpl_h,
                confidence=score, label=template_name,
            )
            for x, y, score in zip(
                xs.tolist(), ys.tolist(), scores[ys, xs].tolist()
            )
        ]

        # Non-maximum suppression (simple distance-based)
        return self._nms(detections, distance_threshold=min(tpl_w, tpl_h) // 2)

    def find_by_color(
        self, screen: np.ndarray, lower_hsv: Tuple[int, int, int],
        upper_hsv: Tuple[int, int, int], min_area: int = 100,
        label: str = "",
    ) -> List[Detection]:
        """Find objects by HSV color range.

        Args:
            screen: Image to search; converted with ``cv2.COLOR_BGR2HSV``,
                so it is expected to be in BGR channel order.
            lower_hsv: Inclusive lower HSV bound passed to ``cv2.inRange``.
            upper_hsv: Inclusive upper HSV bound passed to ``cv2.inRange``.
            min_area: Minimum contour area for a region to be reported.
            label: Label attached to every returned detection.

        Returns:
            One detection per external contour with area >= ``min_area``.
            ``confidence`` is the contour area divided by the area of its
            bounding rectangle.
        """
        hsv = cv2.cvtColor(screen, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(hsv, np.array(lower_hsv), np.array(upper_hsv))

        # findContours returns (contours, hierarchy) on OpenCV 2.x/4.x but
        # (image, contours, hierarchy) on 3.x; contours is always
        # second-to-last.
        contours = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]

        detections = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < min_area:
                continue
            x, y, w, h = cv2.boundingRect(contour)
            detections.append(Detection(
                x=x, y=y, width=w, height=h,
                confidence=area / (w * h), label=label,
            ))

        return detections

    def _nms(self, detections: List[Detection], distance_threshold: int) -> List[Detection]:
        """Simple non-maximum suppression by distance.

        Detections are visited from highest to lowest confidence. One is
        dropped when its center lies within ``distance_threshold`` (strictly
        less, on both axes) of an already kept detection. The input list is
        not modified.
        """
        # sorted() rather than list.sort() so the caller's list keeps its order.
        ranked = sorted(detections, key=lambda d: d.confidence, reverse=True)
        kept: List[Detection] = []

        for det in ranked:
            cx, cy = det.center
            too_close = any(
                abs(cx - k.center[0]) < distance_threshold
                and abs(cy - k.center[1]) < distance_threshold
                for k in kept
            )
            if not too_close:
                kept.append(det)

        return kept
|