add KI
This commit is contained in:
1
drone_pilot/__init__.py
Normal file
1
drone_pilot/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Init for drone_pilot package
|
||||
28
drone_pilot/config.py
Normal file
28
drone_pilot/config.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# drone_pilot/config.py

class Config:
    """Static tuning constants for the Tello AI pilot.

    Groups window geometry, tracking targets, control-loop gains and
    obstacle-avoidance thresholds. All values are class attributes read
    directly (e.g. ``Config.WIDTH``); nothing is instantiated.
    """

    # Display window title and size (pixels).
    WIN_NAME = "Tello AI Pilot v2.0 (High Speed)"
    WIDTH, HEIGHT = 1024, 720

    # Desired hover altitude in metres and target on-screen sizes
    # (pixels) for a tracked face / person bounding box.
    TARGET_ALTITUDE = 1.5
    TARGET_FACE_SIZE = 180
    TARGET_PERSON_SIZE = 400
    # Altitude error tolerance (metres) — NOTE(review): not referenced by
    # the visible flight code; presumably used by an altitude hold elsewhere.
    ALT_THRESHOLD = 0.12

    # Proportional gains for the yaw / forward / altitude channels.
    YAW_GAIN = 0.12  # Reduced for smoother rotation
    FORWARD_GAIN = 1.5  # Kept high for fast pursuit
    ALT_GAIN = 40

    # Obstacle avoidance: normalized depth score above which a zone is
    # "blocked", and time-of-flight distance (cm) that triggers avoidance.
    DEPTH_THRESHOLD = 0.90
    OBSTACLE_TOF_CM = 70
    # Horizontal pixel error below which no yaw correction is applied.
    FACE_DEADZONE = 30  # Slightly larger deadzone for stability
    # Pixel error at which forward speed is fully de-rated in favor of rotation.
    FACE_ROT_ONLY = 100
    # Minimum detector confidence for accepting a person detection.
    PERSON_CONF_THRESHOLD = 0.5
    # Exponential smoothing factor for outgoing RC commands (higher = more direct).
    SMOOTHING_ALPHA = 0.35  # High directness, but slightly more damped than extreme
|
||||
class Colors:
    """BGR color tuples for OpenCV drawing (note: OpenCV uses BGR, not RGB)."""

    GREEN = (0, 255, 0)
    RED = (0, 0, 255)
    BLUE = (255, 0, 0)
    WHITE = (255, 255, 255)
    BLACK = (0, 0, 0)
    # Near-black background for the HUD bars.
    HUD_BG = (10, 10, 10)
|
||||
126
drone_pilot/flight.py
Normal file
126
drone_pilot/flight.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Dict
|
||||
from .config import Config
|
||||
|
||||
class FlightController:
    """Turns detection results and telemetry into smoothed RC commands.

    Stateful: keeps a smoothed face box for the UI, timers for search/patrol
    behavior, memory of which side a lost target was last seen on, and the
    last RC command sent (for exponential smoothing in :meth:`_smooth`).
    """

    def __init__(self):
        # Last smoothed RC command [lr, fb, ud, yv]; basis for smoothing.
        self.last_sent_rc = [0, 0, 0, 0]
        # Exponentially smoothed face box (x, y, w, h) or None when no face.
        self.smooth_face = None
        # Timestamp used to phase the patrol dash/spin cycle.
        self.search_start = time.time()
        self.status = "INITIALIZING"

        # Memory for lost targets
        self.last_target_side = 0  # -1 for left, 1 for right
        self.lost_time = 0

    def calculate(self,
                  faces: List[Tuple],
                  is_manual: bool,
                  emergency_stop: bool,
                  is_locked: bool,
                  locked_person: Tuple,
                  current_height: float,
                  target_altitude: float,
                  tof: int,
                  zones: Dict[str, bool],
                  zone_scores: Dict[str, float],
                  manual_rc: Tuple[int, int, int, int]) -> Tuple[Tuple[int, int, int, int], str]:
        """Compute the next RC command and a human-readable status string.

        Priority order: emergency stop > obstacle avoidance > manual control
        > locked-target pursuit > face alignment > patrol. Returns
        ``((lr, fb, ud, yv), status)``; all channels are smoothed except the
        emergency-stop hard zero.

        NOTE(review): ``current_height`` and ``target_altitude`` are accepted
        but not used by the visible logic — altitude hold appears unimplemented.
        """

        lr, fb, ud, yv = 0, 0, 0, 0

        # Face smoothing for UI/Visuals — 80/20 exponential blend toward the
        # largest detected face; reset to None when no face is visible.
        if len(faces) > 0:
            target = max(faces, key=lambda f: f[2] * f[3])
            if self.smooth_face is None: self.smooth_face = target
            else:
                self.smooth_face = tuple(int(self.smooth_face[i]*0.8 + target[i]*0.2) for i in range(4))
        else:
            self.smooth_face = None

        # Highest priority: freeze all channels immediately (unsmoothed zero).
        if emergency_stop:
            self.status = "EMERGENCY STOP"
            return (0, 0, 0, 0), self.status

        # Obstacle Avoidance (always active if flying): blocked center depth
        # zone or close time-of-flight reading -> back off and rotate toward
        # the side with the lower (farther) depth score.
        center_blocked = zones["CENTER"] or tof < Config.OBSTACLE_TOF_CM
        if center_blocked:
            self.status = "AVOIDING OBSTACLE"
            yv = 80 if zone_scores["LEFT"] < zone_scores["RIGHT"] else -80
            fb = -30
            return self._smooth(lr, fb, ud, yv)

        # Manual mode passes the operator's sticks through (still smoothed).
        if is_manual:
            self.status = "MANUAL CONTROL"
            lr, fb, m_ud, yv = manual_rc
            if abs(m_ud) > 0: ud = m_ud
            return self._smooth(lr, fb, ud, yv)

        # AI LOGIC
        if is_locked:
            if locked_person is not None:
                # Target is visible -> Normal Pursuit
                self.search_start = time.time()
                self.lost_time = 0
                (x, y, w, h) = locked_person
                center_x = x + w // 2
                err_x = center_x - (Config.WIDTH // 2)

                # Remember which side it was on
                self.last_target_side = 1 if err_x > 0 else -1

                # Rotation (Yaw) - FULL SPEED CAPABLE
                if abs(err_x) > Config.FACE_DEADZONE:
                    yv = int(np.clip(Config.YAW_GAIN * err_x, -100, 100))

                # Forward/Backward pursuit - EXTREME SPEED
                # Forward speed is de-rated (down to 40%) the more the target
                # is off-center, so the drone rotates before it charges.
                alignment_factor = max(0.4, 1.0 - (abs(err_x) / Config.FACE_ROT_ONLY))
                target_fb = int(np.clip(Config.FORWARD_GAIN * (Config.TARGET_PERSON_SIZE - w), -90, 90))
                fb = int(target_fb * alignment_factor)

                self.status = "PURSUIT: EXTREME"
            else:
                # Target is LOST -> Rapid Search logic
                if self.lost_time == 0: self.lost_time = time.time()
                elapsed = time.time() - self.lost_time

                if elapsed < 10.0:  # Search longer and faster
                    # Spin toward the side where the target was last seen.
                    yv = 80 * self.last_target_side
                    self.status = f"LOST TARGET: RAPID SCAN {'RIGHT' if self.last_target_side > 0 else 'LEFT'}"
                else:
                    self.status = "TARGET LOST: AGGRESSIVE PATROL"
                    yv = 60

        elif self.smooth_face is not None:
            # Face found but not locked — just keep it centered with yaw.
            (x, y, w, h) = self.smooth_face
            err_x = (x + w // 2) - (Config.WIDTH // 2)
            if abs(err_x) > Config.FACE_DEADZONE:
                yv = int(np.clip(Config.YAW_GAIN * err_x, -80, 80))
            self.status = "AWAITING LOCK"
        else:
            # Patrol mode - faster: 6-second cycle of 2s dash + 4s spin.
            elapsed = (time.time() - self.search_start) % 6.0
            if elapsed < 2.0:
                self.status = "PATROL: DASH"
                fb = 40
            else:
                self.status = "PATROL: FAST SPIN"
                yv = 60

        return self._smooth(lr, fb, ud, yv)

    def _smooth(self, lr, fb, ud, yv):
        """Exponentially smooth the four RC channels against the last command.

        Applies ``SMOOTHING_ALPHA`` per channel, zeroes values below a
        magnitude of 2 (deadband against drift), stores the result as the new
        baseline, and returns ``((lr, fb, ud, yv), status)``.
        """
        alpha = Config.SMOOTHING_ALPHA
        slr = int(self.last_sent_rc[0] * (1-alpha) + lr * alpha)
        sfb = int(self.last_sent_rc[1] * (1-alpha) + fb * alpha)
        sud = int(self.last_sent_rc[2] * (1-alpha) + ud * alpha)
        syv = int(self.last_sent_rc[3] * (1-alpha) + yv * alpha)

        # Deadband: kill tiny residual commands so the drone settles.
        if abs(slr) < 2: slr = 0
        if abs(sfb) < 2: sfb = 0
        if abs(sud) < 2: sud = 0
        if abs(syv) < 2: syv = 0

        self.last_sent_rc = [slr, sfb, sud, syv]
        return (slr, sfb, sud, syv), self.status
|
||||
302
drone_pilot/main.py
Normal file
302
drone_pilot/main.py
Normal file
@@ -0,0 +1,302 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from tello_sim_client import TelloSimClient
|
||||
from .config import Config, Colors
|
||||
from .vision import DroneVision
|
||||
from .flight import FlightController
|
||||
from .ui import HUD
|
||||
|
||||
class FaceTrackingApp:
    """Main application: video loop, AI worker thread, input handling and HUD.

    Owns the Tello client, a :class:`DroneVision` detector stack, a
    :class:`FlightController`, and shares detection results between the
    AI worker thread and the UI/main loop through ``ai_results`` guarded
    by ``ai_lock``.
    """

    def __init__(self, use_real_tello: bool = False):
        print(f"\n[System] Initializing Tello AI Pilot (Real Tello: {use_real_tello})")

        # Drone connection (simulator or real hardware behind the same API).
        self.tello = TelloSimClient(use_real_tello=use_real_tello)
        self.tello.connect()
        self.tello.streamon()

        self.vision = DroneVision()
        self.flight_controller = FlightController()

        # State Management
        self.is_running = True
        self.is_manual = True
        self.is_locked = False
        self.is_taking_off = False
        self.is_flying = False
        self.takeoff_error = False
        self.locked_person = None            # current locked box (x, y, w, h)
        self.locked_person_features = None   # ReID feature vector of the lock
        self.lock_trigger = False            # armed: lock onto next face/person pair
        self.emergency_stop = False

        # Manual Override States
        self.m_lr, self.m_fb, self.m_ud, self.m_yv = 0, 0, 0, 0
        self.is_rotating = False

        self.current_height = 0.0
        self.current_yaw = 0.0
        self.target_altitude = Config.TARGET_ALTITUDE
        # RC send throttling / heartbeat bookkeeping (see run()).
        self.last_rc_time = 0.0
        self._last_heartbeat = 0.0
        self._prev_rc = [0, 0, 0, 0]

        # AI Threading: latest_frame and ai_results are shared with
        # ai_worker() and must only be touched under ai_lock.
        self.ai_lock = threading.Lock()
        self.latest_frame = None
        self.ai_results = {
            "faces": [],
            "persons": [],
            "depth_map_vis": None,
            "zones": {"LEFT": False, "CENTER": False, "RIGHT": False},
            "zone_scores": {"LEFT": 0.0, "CENTER": 0.0, "RIGHT": 0.0},
            "reid_target_box": None
        }

        cv2.namedWindow(Config.WIN_NAME, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(Config.WIN_NAME, Config.WIDTH, Config.HEIGHT)
        cv2.setMouseCallback(Config.WIN_NAME, self._on_mouse)

    def _on_mouse(self, event, x, y, flags, param):
        """Left click: lock onto the clicked person (or face) box, switching
        to autonomous mode; clicking empty space releases the lock."""
        if event == cv2.EVENT_LBUTTONDOWN:
            with self.ai_lock:
                faces = self.ai_results["faces"]
                persons = self.ai_results["persons"]
                frame = self.latest_frame.copy() if self.latest_frame is not None else None

            # Prefer a person box containing the click; fall back to faces.
            clicked_target = None
            for pbox in persons:
                px, py, pw, ph = pbox
                if px < x < px+pw and py < y < py+ph:
                    clicked_target = pbox; break

            if not clicked_target:
                for fbox in faces:
                    fx, fy, fw, fh = fbox
                    if fx < x < fx+fw and fy < y < fy+fh:
                        clicked_target = fbox; break

            if clicked_target and frame is not None:
                # Only lock if a ReID feature can be extracted for re-finding.
                feat = self.vision.extract_person_features(frame, clicked_target)
                if feat is not None:
                    print(f"[System] Manual Lock-ON at ({x}, {y})")
                    self.locked_person = clicked_target
                    self.locked_person_features = feat
                    self.is_locked = True
                    self.lock_trigger = False
                    self.is_manual = False
            else:
                print("[System] Manual Unlock")
                self.is_locked = False
                self.locked_person = None
                self.locked_person_features = None

    def update_telemetry(self):
        """Refresh yaw and height from the drone; best-effort (errors ignored)."""
        try:
            self.current_yaw = float(self.tello.get_yaw())
            raw_h = float(self.tello.get_height())
            # Heuristic unit fix: values >= 10 are assumed to be centimetres.
            self.current_height = raw_h if abs(raw_h) < 10 else raw_h / 100.0
        except: pass

    def ai_worker(self):
        """Background thread: run detectors on the latest frame and publish
        results (plus a ReID re-match of the locked target) into ai_results."""
        while self.is_running:
            try:
                # Snapshot the newest frame under the lock, process outside it.
                frame_to_process = None
                with self.ai_lock:
                    if self.latest_frame is not None:
                        frame_to_process = self.latest_frame.copy()

                if frame_to_process is not None:
                    faces = self.vision.detect_faces(frame_to_process)
                    persons = self.vision.detect_persons(frame_to_process)
                    self.vision.estimate_depth_and_radar(frame_to_process)

                    # Re-identify the locked person among current detections.
                    reid_match = None
                    if self.is_locked and self.locked_person_features is not None:
                        max_sim = 0
                        for pbox in persons:
                            feat = self.vision.extract_person_features(frame_to_process, pbox)
                            sim = self.vision.compare_features(self.locked_person_features, feat)
                            if sim > max_sim:
                                max_sim = sim
                                reid_match = pbox

                        if max_sim > 0.75:
                            # Confident match: slowly adapt the stored features
                            # (95/5 blend) to lighting/pose changes.
                            new_feat = self.vision.extract_person_features(frame_to_process, reid_match)
                            if new_feat is not None:
                                self.locked_person_features = self.locked_person_features * 0.95 + new_feat * 0.05
                        else:
                            reid_match = None

                    with self.ai_lock:
                        self.ai_results["faces"] = faces
                        self.ai_results["persons"] = persons
                        self.ai_results["depth_map_vis"] = self.vision.depth_map_vis
                        self.ai_results["zones"] = self.vision.zones.copy()
                        self.ai_results["zone_scores"] = self.vision.zone_scores.copy()
                        self.ai_results["reid_target_box"] = reid_match

                time.sleep(0.01)
            except Exception:
                # Best-effort worker: back off after any failure and retry.
                time.sleep(1)

    def run(self):
        """Main loop: grab frames, draw UI, handle keys, and (when flying)
        compute and send throttled RC commands. Lands and cleans up on exit."""
        ai_thread = threading.Thread(target=self.ai_worker, daemon=True)
        ai_thread.start()

        try:
            while self.is_running:
                fr = self.tello.get_frame_read()
                if fr is None or fr.frame is None:
                    time.sleep(0.01); continue

                frame = cv2.resize(fr.frame.copy(), (Config.WIDTH, Config.HEIGHT))
                with self.ai_lock:
                    self.latest_frame = frame.copy()

                self.update_telemetry()
                # Snapshot the AI worker's latest results for this iteration.
                with self.ai_lock:
                    faces = self.ai_results["faces"]
                    persons = self.ai_results["persons"]
                    depth_map_vis = self.ai_results["depth_map_vis"]
                    zones = self.ai_results["zones"]
                    zone_scores = self.ai_results["zone_scores"]
                    reid_box = self.ai_results["reid_target_box"]

                # Follow the ReID re-match when the lock is active.
                if self.is_locked and reid_box is not None:
                    self.locked_person = reid_box

                # Armed lock trigger: lock onto the person box containing the
                # center of the largest detected face.
                if self.lock_trigger and len(faces) > 0 and len(persons) > 0:
                    (fx, fy, fw, fh) = max(faces, key=lambda f: f[2]*f[3])
                    fcx, fcy = fx + fw//2, fy + fh//2
                    for pbox in persons:
                        px, py, pw, ph = pbox
                        if px < fcx < px+pw and py < fcy < py+ph:
                            feat = self.vision.extract_person_features(frame, pbox)
                            if feat is not None:
                                self.locked_person = pbox
                                self.locked_person_features = feat
                                self.is_locked = True
                                self.lock_trigger = False
                            break

                # UI
                if self.is_locked and self.locked_person:
                    (x,y,w,h) = self.locked_person
                    cv2.rectangle(frame, (x,y), (x+w,y+h), Colors.BLUE, 3)
                    cv2.putText(frame, "TARGET LOCKED", (x,y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, Colors.BLUE, 2)
                elif self.flight_controller.smooth_face is not None:
                    (x,y,w,h) = self.flight_controller.smooth_face
                    cv2.rectangle(frame, (x,y), (x+w,y+h), Colors.GREEN, 2)
                    cv2.putText(frame, "FACE FOUND", (x,y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, Colors.GREEN, 2)

                key = cv2.waitKey(1) & 0xFF
                self._handle_input(key)

                if self.is_flying and not self.is_taking_off and not self.emergency_stop:
                    # Time-of-flight distance; assume far (1000) on failure.
                    try: tof = int(self.tello.get_distance_tof())
                    except: tof = 1000

                    # Continuous-rotate toggle overrides the manual yaw stick.
                    active_manual_rc = (self.m_lr, self.m_fb, self.m_ud, 60 if self.is_rotating else self.m_yv)

                    rc, status = self.flight_controller.calculate(
                        faces=faces,
                        is_manual=self.is_manual,
                        emergency_stop=self.emergency_stop,
                        is_locked=self.is_locked,
                        locked_person=self.locked_person,
                        current_height=self.current_height,
                        target_altitude=0.0,
                        tof=tof,
                        zones=zones,
                        zone_scores=zone_scores,
                        manual_rc=active_manual_rc
                    )

                    # Throttle and optimized sending: at most 10 Hz, and only
                    # when the command changed or a 0.5 s heartbeat is due.
                    now = time.time()
                    if now - self.last_rc_time >= 0.1:
                        changed = any(abs(rc[i] - self._prev_rc[i]) > 1 for i in range(4))
                        heartbeat = now - self._last_heartbeat >= 0.5
                        if changed or heartbeat:
                            self.tello.send_rc_control(*rc)
                            self._prev_rc = list(rc)
                            self._last_heartbeat = now
                        self.last_rc_time = now
                else:
                    # Grounded / transitioning: synthesize a status for the HUD.
                    if self.is_taking_off: status = "TAKING OFF..."
                    elif self.takeoff_error: status = "TAKEOFF ERROR (LOW BAT?)"
                    elif self.emergency_stop: status = "STOPPED"
                    else: status = "READY / ON GROUND"

                try: bat = int(self.tello.get_battery())
                except: bat = 0

                HUD.draw(frame, status, self.emergency_stop, self.is_locked, self.lock_trigger,
                         self.current_height, 0.0, self.current_yaw, bat, depth_map_vis, zones)

                cv2.imshow(Config.WIN_NAME, frame)

        except Exception:
            traceback.print_exc()
        finally:
            # Always try to land and release resources on the way out.
            self.is_running = False
            try: self.tello.land()
            except: pass
            self.tello.end()
            cv2.destroyAllWindows()

    def _handle_takeoff(self):
        """Start a non-blocking takeoff in a daemon thread (no-op if already
        flying or mid-takeoff); sets takeoff_error on failure."""
        if self.is_taking_off or self.is_flying: return

        self.is_taking_off = True
        self.takeoff_error = False
        def _task():
            # Runs in its own thread so the UI loop keeps rendering.
            try:
                print("[Flight] Takeoff command sent to SDK...")
                self.tello.takeoff()
                time.sleep(3)
                self.is_flying = True
                print("[Flight] In Air. Logic active.")
            except Exception as e:
                print(f"[Flight Error] Takeoff failed: {e}")
                self.takeoff_error = True
            finally:
                self.is_taking_off = False

        threading.Thread(target=_task, daemon=True).start()

    def _handle_input(self, key: int):
        """Map a cv2.waitKey code to app actions; refresh manual stick state.

        ENTER quits, SPACE toggles emergency stop, m toggles manual mode,
        k arms/disarms the lock trigger, t takes off, l lands, 1 toggles
        continuous rotation; WASD/RF/EZ set manual speeds while in manual mode.
        """
        if key == 13: self.is_running = False
        elif key == 32: self.emergency_stop = not self.emergency_stop
        elif key == ord('m'): self.is_manual = not self.is_manual
        elif key == ord('k'):
            self.lock_trigger = not self.lock_trigger
            self.is_locked = False
            self.locked_person_features = None
        elif key == ord('t'):
            if not self.is_flying and not self.is_taking_off:
                self.emergency_stop = False
                self._handle_takeoff()
        elif key == ord('l'):
            self.is_flying = False
            self.is_taking_off = False
            try: self.tello.land()
            except: pass
        elif key == ord('1'): self.is_rotating = not self.is_rotating

        # Reset manual speed — sticks are re-asserted every frame the key
        # repeats, so releasing a key stops the drone.
        self.m_lr, self.m_fb, self.m_ud, self.m_yv = 0, 0, 0, 0

        if self.is_manual and not self.emergency_stop:
            s = 100  # Maximum manual speed
            if key == ord('w'): self.m_fb = s
            elif key == ord('s'): self.m_fb = -s
            elif key == ord('a'): self.m_lr = -s
            elif key == ord('d'): self.m_lr = s
            elif key == ord('r'): self.m_ud = s
            elif key == ord('f'): self.m_ud = -s
            elif key == ord('e'): self.m_yv = s
            elif key == ord('z'): self.m_yv = -s
|
||||
52
drone_pilot/ui.py
Normal file
52
drone_pilot/ui.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
from .config import Config, Colors
|
||||
|
||||
class HUD:
    """Stateless helpers that draw the heads-up display onto a BGR frame."""

    @staticmethod
    def draw(frame: np.ndarray,
             status: str,
             emergency_stop: bool,
             is_locked: bool,
             lock_trigger: bool,
             current_height: float,
             target_altitude: float,
             current_yaw: float,
             bat: int,
             depth_map_vis: np.ndarray,
             zones: Dict[str, bool]):
        """Draw status bars, depth preview, obstacle radar, stop box and key help.

        Mutates ``frame`` in place and returns None. ``target_altitude`` is
        accepted for interface stability but not currently rendered.
        """
        # Opaque top and bottom HUD bars.
        cv2.rectangle(frame, (0, 0), (Config.WIDTH, 90), Colors.HUD_BG, -1)
        cv2.rectangle(frame, (0, Config.HEIGHT-90), (Config.WIDTH, Config.HEIGHT), Colors.HUD_BG, -1)

        # Status color: red for danger states, blue when locked, else green.
        stat_color = Colors.RED if "AVOIDING" in status or emergency_stop else Colors.GREEN
        if is_locked: stat_color = Colors.BLUE

        cv2.putText(frame, f"STATUS: {status}", (25, 40), cv2.FONT_HERSHEY_DUPLEX, 1.1, stat_color, 2)
        cv2.putText(frame, f"ALT: {current_height:.1f}m | BAT: {bat}% | YAW: {int(current_yaw)}o", (25, 75), cv2.FONT_HERSHEY_SIMPLEX, 0.7, Colors.WHITE, 2)

        # Mini depth-map preview (200x150) blitted below the top bar.
        if depth_map_vis is not None:
            mini = cv2.resize(depth_map_vis, (200, 150))
            frame[100:250, 25:225] = mini
            cv2.rectangle(frame, (25, 100), (225, 250), Colors.WHITE, 2)
            cv2.putText(frame, "AI DEPTH", (25, 275), cv2.FONT_HERSHEY_SIMPLEX, 0.6, Colors.WHITE, 2)

        # Obstacle radar: three cells on a 70-px pitch, each 60 px wide.
        # BUGFIX: the right edge previously used `radar_x + (i+1)*60`, which
        # shrank successive cells to 60/50/40 px and drifted them off the
        # 70-px pitch used by the left edge and the labels; a fixed 60-px
        # width keeps all three cells uniform and aligned with their letters.
        radar_x = Config.WIDTH - 220
        for i, name in enumerate(["LEFT", "CENTER", "RIGHT"]):
            short_name = name[0]
            active = zones[name]
            col = Colors.RED if active else Colors.GREEN
            # Filled cell when the zone is blocked, outline otherwise.
            cv2.rectangle(frame, (radar_x + i*70, 100), (radar_x + i*70 + 60, 150), col, -1 if active else 2)
            cv2.putText(frame, short_name, (radar_x + 20 + i*70, 135), cv2.FONT_HERSHEY_DUPLEX, 0.8, Colors.WHITE, 2)

        # Emergency-stop indicator: yellow while engaged, red otherwise.
        stop_col = Colors.RED if not emergency_stop else (0, 255, 255)
        cv2.rectangle(frame, (Config.WIDTH - 250, Config.HEIGHT - 80), (Config.WIDTH - 25, Config.HEIGHT - 10), stop_col, -1)
        cv2.putText(frame, "STOP", (Config.WIDTH - 200, Config.HEIGHT - 35), cv2.FONT_HERSHEY_DUPLEX, 0.8, Colors.WHITE, 2)

        # Keyboard legend, 5 entries per row in the bottom bar.
        keys = ["T:Takeoff", "L:Land", "M:Manual", "K:LockOn", "R:Up", "F:Down", "1:Rotate", "Space:Stop", "ENTER:Quit"]
        for i, k in enumerate(keys):
            cv2.putText(frame, k, (25 + (i%5)*200, Config.HEIGHT - 55 + (i//5)*35), cv2.FONT_HERSHEY_SIMPLEX, 0.6, Colors.WHITE, 2)

        if lock_trigger:
            cv2.putText(frame, "LOCK TRIGGER ACTIVE: FINDING FACE...", (25, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, Colors.BLUE, 2)
|
||||
218
drone_pilot/vision.py
Normal file
218
drone_pilot/vision.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from PIL import Image
|
||||
from typing import List, Tuple, Dict
|
||||
from .config import Config
|
||||
|
||||
class DroneVision:
    """Detector stack: faces (YuNet + Haar), persons (SSDLite), monocular
    depth (MiDaS small via ONNX Runtime) and person ReID features.

    Each backend is loaded best-effort in ``__init__``; missing models leave
    the corresponding attribute as None and the matching methods degrade
    gracefully (empty lists / None / no-op).
    """

    def __init__(self):
        # 1. Face Detection: YuNet (ONNX) + Haar Cascade Fallback
        self.face_detector = None
        try:
            # Args: model path, config, input size, score threshold,
            # NMS threshold, top_k (cv2.FaceDetectorYN.create signature).
            self.face_detector = cv2.FaceDetectorYN.create(
                "models/face_detection_yunet.onnx",
                "",
                (320, 320),
                0.4,
                0.3,
                5000
            )
            print("[AI] YuNet Face Detector loaded.")
        except Exception as e:
            print(f"[AI Warning] YuNet failed to load: {e}")

        # Haar cascades always available from the OpenCV data directory.
        self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        self.profile_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_profileface.xml')

        # 2. ONNX Runtime Sessions
        try:
            # Prefer CUDA when available, fall back to CPU.
            providers = ['CPUExecutionProvider']
            if 'CUDAExecutionProvider' in ort.get_available_providers():
                providers.insert(0, 'CUDAExecutionProvider')

            self.depth_session = ort.InferenceSession("models/midas_small.onnx", providers=providers)
            self.reid_session = ort.InferenceSession("models/reid_mobilenet.onnx", providers=providers)
            print(f"[AI] Depth & ReID (ONNX) initialized.")
        except Exception as e:
            print(f"[AI Error] ONNX initialization failed: {e}")
            self.depth_session = None
            self.reid_session = None

        # 3. Person Detection (SSD Lite Torch)
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            from torchvision.models.detection import ssdlite320_mobilenet_v3_large, SSDLite320_MobileNet_V3_Large_Weights
            weights = SSDLite320_MobileNet_V3_Large_Weights.DEFAULT
            self.person_model = ssdlite320_mobilenet_v3_large(weights=weights, box_score_thresh=Config.PERSON_CONF_THRESHOLD).to(self.device)
            self.person_model.eval()
            self.person_preprocess = weights.transforms()
        except Exception as e:
            print(f"[AI Error] Person detector failed: {e}")
            self.person_model = None

        # Latest depth visualization and per-zone obstacle state, refreshed
        # by estimate_depth_and_radar().
        self.depth_map_vis = None
        self.zones = {"LEFT": False, "CENTER": False, "RIGHT": False}
        self.zone_scores = {"LEFT": 0.0, "CENTER": 0.0, "RIGHT": 0.0}

    def _detect_haar(self, img_gray: np.ndarray) -> List[Tuple]:
        """Run frontal + profile Haar cascades on a grayscale image.

        Profile hits that roughly overlap an existing frontal box (within
        half a box width/height) are discarded as duplicates.
        """
        faces = list(self.face_cascade.detectMultiScale(img_gray, 1.1, 7, minSize=(30, 30)))
        profiles = self.profile_cascade.detectMultiScale(img_gray, 1.1, 8, minSize=(35, 35))
        for p in profiles:
            is_new = True
            px, py, pw, ph = p
            for (fx, fy, fw, fh) in faces:
                if abs(px - fx) < fw/2 and abs(py - fy) < fh/2:
                    is_new = False; break
            if is_new: faces.append(tuple(p))
        return faces

    def detect_faces(self, frame: np.ndarray) -> List[Tuple]:
        """Detect faces as (x, y, w, h) boxes in full-frame coordinates.

        Three passes, each deduplicated against earlier results:
        1) YuNet on the full frame (if loaded),
        2) Haar cascades on the full frame,
        3) YuNet + Haar on a 2x-upscaled 40% center crop, to catch small
           faraway faces; those boxes are mapped back to frame coordinates.
        """
        h, w = frame.shape[:2]
        faces = []

        # YuNet Detection
        if self.face_detector is not None:
            try:
                self.face_detector.setInputSize((w, h))
                _, detections = self.face_detector.detect(frame)
                if detections is not None:
                    for det in detections:
                        # First four values of a YuNet row are x, y, w, h.
                        faces.append(tuple(det[0:4].astype(int)))
            except: pass

        # Haar Fallback
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        haar_faces = self._detect_haar(gray)

        # Keep only Haar boxes that don't overlap a YuNet box.
        for hf in haar_faces:
            is_new = True
            hx, hy, hw, hh = hf
            for (fx, fy, fw, fh) in faces:
                if abs(hx - fx) < fw/2 and abs(hy - fy) < fh/2:
                    is_new = False; break
            if is_new: faces.append(hf)

        # Center Zoom for extreme distance: crop the middle 40% and upscale 2x.
        zh, zw = int(h * 0.4), int(w * 0.4)
        cy, cx = h // 2, w // 2
        y1, y2 = cy - zh // 2, cy + zh // 2
        x1, x2 = cx - zw // 2, cx + zw // 2
        center_crop = frame[y1:y2, x1:x2]
        center_upscaled = cv2.resize(center_crop, (zw * 2, zh * 2), interpolation=cv2.INTER_LANCZOS4)

        # In Zoom mode, use YuNet if available, otherwise Haar
        if self.face_detector is not None:
            try:
                uh, uw = center_upscaled.shape[:2]
                self.face_detector.setInputSize((uw, uh))
                _, zoom_detections = self.face_detector.detect(center_upscaled)
                if zoom_detections is not None:
                    for det in zoom_detections:
                        zx, zy, zw_f, zh_f = det[0:4]
                        # Map zoomed coords back: halve (undo 2x) and offset by crop origin.
                        rx, ry = int(zx / 2) + x1, int(zy / 2) + y1
                        rw, rh = int(zw_f / 2), int(zh_f / 2)
                        is_new = True
                        for (fx, fy, fw, fh) in faces:
                            if abs(rx - fx) < fw/2 and abs(ry - fy) < fh/2:
                                is_new = False; break
                        if is_new: faces.append((rx, ry, rw, rh))
            except: pass

        # Always run Haar on zoom for robustness
        zoom_gray = cv2.cvtColor(center_upscaled, cv2.COLOR_BGR2GRAY)
        zoom_haar = self._detect_haar(zoom_gray)
        for (zx, zy, zw_f, zh_f) in zoom_haar:
            rx, ry = int(zx / 2) + x1, int(zy / 2) + y1
            rw, rh = int(zw_f / 2), int(zh_f / 2)
            is_new = True
            for (fx, fy, fw, fh) in faces:
                if abs(rx - fx) < fw/2 and abs(ry - fy) < fh/2:
                    is_new = False; break
            if is_new: faces.append((rx, ry, rw, rh))

        return faces

    def detect_persons(self, frame: np.ndarray) -> List[Tuple]:
        """Detect persons with SSDLite; returns (x, y, w, h) boxes.

        Filters to COCO label 1 (person) above PERSON_CONF_THRESHOLD and
        converts the model's (x1, y1, x2, y2) output to (x, y, w, h).
        Returns [] when the detector failed to load.
        """
        if self.person_model is None: return []
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(img_rgb)
        input_tensor = self.person_preprocess(pil_img).to(self.device).unsqueeze(0)
        with torch.no_grad():
            output = self.person_model(input_tensor)[0]

        persons = []
        for i in range(len(output['labels'])):
            if output['labels'][i] == 1 and output['scores'][i] > Config.PERSON_CONF_THRESHOLD:
                box = output['boxes'][i].cpu().numpy().astype(int)
                persons.append((box[0], box[1], box[2]-box[0], box[3]-box[1]))
        return persons

    def extract_person_features(self, frame: np.ndarray, box: Tuple) -> np.ndarray:
        """Extract a flat ReID feature vector for the frame crop given by
        ``box`` (x, y, w, h); returns None when the crop is empty, the ReID
        session is missing, or any step fails.
        """
        if self.reid_session is None: return None
        try:
            x, y, w, h = box
            # Clamp the box to the frame bounds.
            x1, y1 = max(0, x), max(0, y)
            x2, y2 = min(frame.shape[1], x+w), min(frame.shape[0], y+h)
            if x2 <= x1 or y2 <= y1: return None

            crop = frame[y1:y2, x1:x2]
            img_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            img_resized = cv2.resize(img_rgb, (224, 224))
            img_float = img_resized.astype(np.float32) / 255.0

            # Standard ImageNet mean/std normalization.
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            img_norm = (img_float - mean) / std

            # HWC -> NCHW for the ONNX model.
            img_input = np.transpose(img_norm, (2, 0, 1)).astype(np.float32)
            img_input = np.expand_dims(img_input, axis=0)

            ort_inputs = {self.reid_session.get_inputs()[0].name: img_input}
            features = self.reid_session.run(None, ort_inputs)[0]
            return features.flatten()
        except: return None

    def compare_features(self, feat1: np.ndarray, feat2: np.ndarray) -> float:
        """Cosine similarity of two feature vectors; 0.0 for None/zero/error."""
        if feat1 is None or feat2 is None: return 0.0
        try:
            norm1 = np.linalg.norm(feat1)
            norm2 = np.linalg.norm(feat2)
            if norm1 == 0 or norm2 == 0: return 0.0
            return np.dot(feat1, feat2) / (norm1 * norm2)
        except: return 0.0

    def estimate_depth_and_radar(self, frame: np.ndarray):
        """Run monocular depth on the middle horizontal band of ``frame`` and
        update ``depth_map_vis`` (colorized preview), ``zone_scores`` (mean
        normalized depth per LEFT/CENTER/RIGHT third) and ``zones`` (score >
        DEPTH_THRESHOLD -> blocked). No-op on missing session or any error.
        """
        if self.depth_session is None: return
        try:
            h, w = frame.shape[:2]
            # Analyze only the vertical middle half of the image (skip sky/floor).
            slice_y1, slice_y2 = int(h*0.25), int(h*0.75)
            analysis_area = frame[slice_y1:slice_y2, :]
            img_rgb = cv2.cvtColor(analysis_area, cv2.COLOR_BGR2RGB)

            # ImageNet-normalized 256x256 NCHW input for the depth model.
            img_resized = cv2.resize(img_rgb, (256, 256))
            img_float = img_resized.astype(np.float32) / 255.0
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            img_norm = (img_float - mean) / std
            img_input = np.transpose(img_norm, (2, 0, 1)).astype(np.float32)
            img_input = np.expand_dims(img_input, axis=0)

            ort_inputs = {self.depth_session.get_inputs()[0].name: img_input}
            pred = self.depth_session.run(None, ort_inputs)[0][0]

            # Upscale the prediction back to the analysis area and colorize.
            pred_upscaled = cv2.resize(pred, (analysis_area.shape[1], analysis_area.shape[0]), interpolation=cv2.INTER_CUBIC)
            out_norm = cv2.normalize(pred_upscaled, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
            self.depth_map_vis = cv2.applyColorMap(out_norm, cv2.COLORMAP_MAGMA)
            # Normalize to [0, 1]; higher values mean closer (MiDaS-style
            # inverse depth — presumably; verify against model docs).
            depth_map = pred_upscaled / (np.max(pred_upscaled) + 1e-5)

            # Score the central band of each horizontal third.
            dh, dw = depth_map.shape
            zone_w = dw // 3
            for i, name in enumerate(["LEFT", "CENTER", "RIGHT"]):
                zx1, zx2 = i * zone_w, (i + 1) * zone_w
                score = np.mean(depth_map[dh//4:3*dh//4, zx1:zx2])
                self.zone_scores[name] = score
                self.zones[name] = score > Config.DEPTH_THRESHOLD
        except: pass
|
||||
Reference in New Issue
Block a user