diff --git a/drone_pilot/__init__.py b/drone_pilot/__init__.py new file mode 100644 index 0000000..939a53f --- /dev/null +++ b/drone_pilot/__init__.py @@ -0,0 +1 @@ +# Init for drone_pilot package \ No newline at end of file diff --git a/drone_pilot/config.py b/drone_pilot/config.py new file mode 100644 index 0000000..bc47fda --- /dev/null +++ b/drone_pilot/config.py @@ -0,0 +1,28 @@ +# drone_pilot/config.py +class Config: + WIN_NAME = "Tello AI Pilot v2.0 (High Speed)" + WIDTH, HEIGHT = 1024, 720 + + TARGET_ALTITUDE = 1.5 + TARGET_FACE_SIZE = 180 + TARGET_PERSON_SIZE = 400 + ALT_THRESHOLD = 0.12 + + YAW_GAIN = 0.12 # Reduced for smoother rotation + FORWARD_GAIN = 1.5 # Kept high for fast pursuit + ALT_GAIN = 40 + + DEPTH_THRESHOLD = 0.90 + OBSTACLE_TOF_CM = 70 + FACE_DEADZONE = 30 # Slightly larger deadzone for stability + FACE_ROT_ONLY = 100 + PERSON_CONF_THRESHOLD = 0.5 + SMOOTHING_ALPHA = 0.35 # High directness, but slightly more damped than extreme + +class Colors: + GREEN = (0, 255, 0) + RED = (0, 0, 255) + BLUE = (255, 0, 0) + WHITE = (255, 255, 255) + BLACK = (0, 0, 0) + HUD_BG = (10, 10, 10) diff --git a/drone_pilot/flight.py b/drone_pilot/flight.py new file mode 100644 index 0000000..13280d9 --- /dev/null +++ b/drone_pilot/flight.py @@ -0,0 +1,126 @@ +import time +import numpy as np +from typing import List, Tuple, Dict +from .config import Config + +class FlightController: + def __init__(self): + self.last_sent_rc = [0, 0, 0, 0] + self.smooth_face = None + self.search_start = time.time() + self.status = "INITIALIZING" + + # Memory for lost targets + self.last_target_side = 0 # -1 for left, 1 for right + self.lost_time = 0 + + def calculate(self, + faces: List[Tuple], + is_manual: bool, + emergency_stop: bool, + is_locked: bool, + locked_person: Tuple, + current_height: float, + target_altitude: float, + tof: int, + zones: Dict[str, bool], + zone_scores: Dict[str, float], + manual_rc: Tuple[int, int, int, int]) -> Tuple[Tuple[int, int, int, int], str]: + + lr, fb, ud, yv = 0, 0, 0, 0 + + # Face smoothing for UI/Visuals + if len(faces) > 0: + target = max(faces, key=lambda f: f[2] * f[3]) + if self.smooth_face is None: self.smooth_face = target + else: + self.smooth_face = tuple(int(self.smooth_face[i]*0.8 + target[i]*0.2) for i in range(4)) + else: + self.smooth_face = None + + if emergency_stop: + self.status = "EMERGENCY STOP" + return (0, 0, 0, 0), self.status + + # Obstacle Avoidance (always active if flying) + center_blocked = zones["CENTER"] or tof < Config.OBSTACLE_TOF_CM + if center_blocked: + self.status = "AVOIDING OBSTACLE" + yv = 80 if zone_scores["LEFT"] < zone_scores["RIGHT"] else -80 + fb = -30 + return self._smooth(lr, fb, ud, yv) + + if is_manual: + self.status = "MANUAL CONTROL" + lr, fb, m_ud, yv = manual_rc + if abs(m_ud) > 0: ud = m_ud + return self._smooth(lr, fb, ud, yv) + + # AI LOGIC + if is_locked: + if locked_person is not None: + # Target is visible -> Normal Pursuit + self.search_start = time.time() + self.lost_time = 0 + (x, y, w, h) = locked_person + center_x = x + w // 2 + err_x = center_x - (Config.WIDTH // 2) + + # Remember which side it was on + self.last_target_side = 1 if err_x > 0 else -1 + + # Rotation (Yaw) - FULL SPEED CAPABLE + if abs(err_x) > Config.FACE_DEADZONE: + yv = int(np.clip(Config.YAW_GAIN * err_x, -100, 100)) + + # Forward/Backward pursuit - EXTREME SPEED + alignment_factor = max(0.4, 1.0 - (abs(err_x) / Config.FACE_ROT_ONLY)) + target_fb = int(np.clip(Config.FORWARD_GAIN * (Config.TARGET_PERSON_SIZE - w), -90, 90)) + fb = int(target_fb * alignment_factor) + + self.status = "PURSUIT: EXTREME" + else: + # Target is LOST -> Rapid Search logic + if self.lost_time == 0: self.lost_time = time.time() + elapsed = time.time() - self.lost_time + + if elapsed < 10.0: # Search longer and faster + yv = 80 * self.last_target_side + self.status = f"LOST TARGET: RAPID SCAN {'RIGHT' if self.last_target_side > 0 else 'LEFT'}" + else: + self.status = "TARGET LOST: AGGRESSIVE PATROL" + yv = 60 + + elif self.smooth_face is not None: + # Face found but not locked + (x, y, w, h) = self.smooth_face + err_x = (x + w // 2) - (Config.WIDTH // 2) + if abs(err_x) > Config.FACE_DEADZONE: + yv = int(np.clip(Config.YAW_GAIN * err_x, -80, 80)) + self.status = "AWAITING LOCK" + else: + # Patrol mode - faster + elapsed = (time.time() - self.search_start) % 6.0 + if elapsed < 2.0: + self.status = "PATROL: DASH" + fb = 40 + else: + self.status = "PATROL: FAST SPIN" + yv = 60 + + return self._smooth(lr, fb, ud, yv) + + def _smooth(self, lr, fb, ud, yv): + alpha = Config.SMOOTHING_ALPHA + slr = int(self.last_sent_rc[0] * (1-alpha) + lr * alpha) + sfb = int(self.last_sent_rc[1] * (1-alpha) + fb * alpha) + sud = int(self.last_sent_rc[2] * (1-alpha) + ud * alpha) + syv = int(self.last_sent_rc[3] * (1-alpha) + yv * alpha) + + if abs(slr) < 2: slr = 0 + if abs(sfb) < 2: sfb = 0 + if abs(sud) < 2: sud = 0 + if abs(syv) < 2: syv = 0 + + self.last_sent_rc = [slr, sfb, sud, syv] + return (slr, sfb, sud, syv), self.status diff --git a/drone_pilot/main.py b/drone_pilot/main.py new file mode 100644 index 0000000..585abca --- /dev/null +++ b/drone_pilot/main.py @@ -0,0 +1,302 @@ +import cv2 +import numpy as np +import threading +import time +import traceback +from tello_sim_client import TelloSimClient +from .config import Config, Colors +from .vision import DroneVision +from .flight import FlightController +from .ui import HUD + +class FaceTrackingApp: + def __init__(self, use_real_tello: bool = False): + print(f"\n[System] Initializing Tello AI Pilot (Real Tello: {use_real_tello})") + + self.tello = TelloSimClient(use_real_tello=use_real_tello) + self.tello.connect() + self.tello.streamon() + + self.vision = DroneVision() + self.flight_controller = FlightController() + + # State Management + self.is_running = True + self.is_manual = True + self.is_locked = False + self.is_taking_off = False + self.is_flying = False + self.takeoff_error = False + self.locked_person = None + self.locked_person_features = None + self.lock_trigger = False + self.emergency_stop = False + + # Manual Override States + self.m_lr, self.m_fb, self.m_ud, self.m_yv = 0, 0, 0, 0 + self.is_rotating = False + + self.current_height = 0.0 + self.current_yaw = 0.0 + self.target_altitude = Config.TARGET_ALTITUDE + self.last_rc_time = 0.0 + self._last_heartbeat = 0.0 + self._prev_rc = [0, 0, 0, 0] + + # AI Threading + self.ai_lock = threading.Lock() + self.latest_frame = None + self.ai_results = { + "faces": [], + "persons": [], + "depth_map_vis": None, + "zones": {"LEFT": False, "CENTER": False, "RIGHT": False}, + "zone_scores": {"LEFT": 0.0, "CENTER": 0.0, "RIGHT": 0.0}, + "reid_target_box": None + } + + cv2.namedWindow(Config.WIN_NAME, cv2.WINDOW_NORMAL) + cv2.resizeWindow(Config.WIN_NAME, Config.WIDTH, Config.HEIGHT) + cv2.setMouseCallback(Config.WIN_NAME, self._on_mouse) + + def _on_mouse(self, event, x, y, flags, param): + if event == cv2.EVENT_LBUTTONDOWN: + with self.ai_lock: + faces = self.ai_results["faces"] + persons = self.ai_results["persons"] + frame = self.latest_frame.copy() if self.latest_frame is not None else None + + clicked_target = None + for pbox in persons: + px, py, pw, ph = pbox + if px < x < px+pw and py < y < py+ph: + clicked_target = pbox; break + + if not clicked_target: + for fbox in faces: + fx, fy, fw, fh = fbox + if fx < x < fx+fw and fy < y < fy+fh: + clicked_target = fbox; break + + if clicked_target and frame is not None: + feat = self.vision.extract_person_features(frame, clicked_target) + if feat is not None: + print(f"[System] Manual Lock-ON at ({x}, {y})") + self.locked_person = clicked_target + self.locked_person_features = feat + self.is_locked = True + self.lock_trigger = False + self.is_manual = False + else: + print("[System] Manual Unlock") + self.is_locked = False + self.locked_person = None + self.locked_person_features = None + + def update_telemetry(self): + try: + self.current_yaw = float(self.tello.get_yaw()) + raw_h = float(self.tello.get_height()) + self.current_height = raw_h if abs(raw_h) < 10 else raw_h / 100.0 + except: pass + + def ai_worker(self): + while self.is_running: + try: + frame_to_process = None + with self.ai_lock: + if self.latest_frame is not None: + frame_to_process = self.latest_frame.copy() + + if frame_to_process is not None: + faces = self.vision.detect_faces(frame_to_process) + persons = self.vision.detect_persons(frame_to_process) + self.vision.estimate_depth_and_radar(frame_to_process) + + reid_match = None + if self.is_locked and self.locked_person_features is not None: + max_sim = 0 + for pbox in persons: + feat = self.vision.extract_person_features(frame_to_process, pbox) + sim = self.vision.compare_features(self.locked_person_features, feat) + if sim > max_sim: + max_sim = sim + reid_match = pbox + + if max_sim > 0.75: + new_feat = self.vision.extract_person_features(frame_to_process, reid_match) + if new_feat is not None: + self.locked_person_features = self.locked_person_features * 0.95 + new_feat * 0.05 + else: + reid_match = None + + with self.ai_lock: + self.ai_results["faces"] = faces + self.ai_results["persons"] = persons + self.ai_results["depth_map_vis"] = self.vision.depth_map_vis + self.ai_results["zones"] = self.vision.zones.copy() + self.ai_results["zone_scores"] = self.vision.zone_scores.copy() + self.ai_results["reid_target_box"] = reid_match + + time.sleep(0.01) + except Exception: + time.sleep(1) + + def run(self): + ai_thread = threading.Thread(target=self.ai_worker, daemon=True) + ai_thread.start() + + try: + while self.is_running: + fr = self.tello.get_frame_read() + if fr is None or fr.frame is None: + time.sleep(0.01); continue + + frame = cv2.resize(fr.frame.copy(), (Config.WIDTH, Config.HEIGHT)) + with self.ai_lock: + self.latest_frame = frame.copy() + + self.update_telemetry() + with self.ai_lock: + faces = self.ai_results["faces"] + persons = self.ai_results["persons"] + depth_map_vis = self.ai_results["depth_map_vis"] + zones = self.ai_results["zones"] + zone_scores = self.ai_results["zone_scores"] + reid_box = self.ai_results["reid_target_box"] + + if self.is_locked and reid_box is not None: + self.locked_person = reid_box + + if self.lock_trigger and len(faces) > 0 and len(persons) > 0: + (fx, fy, fw, fh) = max(faces, key=lambda f: f[2]*f[3]) + fcx, fcy = fx + fw//2, fy + fh//2 + for pbox in persons: + px, py, pw, ph = pbox + if px < fcx < px+pw and py < fcy < py+ph: + feat = self.vision.extract_person_features(frame, pbox) + if feat is not None: + self.locked_person = pbox + self.locked_person_features = feat + self.is_locked = True + self.lock_trigger = False + break + + # UI + if self.is_locked and self.locked_person: + (x,y,w,h) = self.locked_person + cv2.rectangle(frame, (x,y), (x+w,y+h), Colors.BLUE, 3) + cv2.putText(frame, "TARGET LOCKED", (x,y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, Colors.BLUE, 2) + elif self.flight_controller.smooth_face is not None: + (x,y,w,h) = self.flight_controller.smooth_face + cv2.rectangle(frame, (x,y), (x+w,y+h), Colors.GREEN, 2) + cv2.putText(frame, "FACE FOUND", (x,y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, Colors.GREEN, 2) + + key = cv2.waitKey(1) & 0xFF + self._handle_input(key) + + if self.is_flying and not self.is_taking_off and not self.emergency_stop: + try: tof = int(self.tello.get_distance_tof()) + except: tof = 1000 + + active_manual_rc = (self.m_lr, self.m_fb, self.m_ud, 60 if self.is_rotating else self.m_yv) + + rc, status = self.flight_controller.calculate( + faces=faces, + is_manual=self.is_manual, + emergency_stop=self.emergency_stop, + is_locked=self.is_locked, + locked_person=self.locked_person, + current_height=self.current_height, + target_altitude=0.0, + tof=tof, + zones=zones, + zone_scores=zone_scores, + manual_rc=active_manual_rc + ) + + # Throttle and optimized sending + now = time.time() + if now - self.last_rc_time >= 0.1: + changed = any(abs(rc[i] - self._prev_rc[i]) > 1 for i in range(4)) + heartbeat = now - self._last_heartbeat >= 0.5 + if changed or heartbeat: + self.tello.send_rc_control(*rc) + self._prev_rc = list(rc) + self._last_heartbeat = now + self.last_rc_time = now + else: + if self.is_taking_off: status = "TAKING OFF..." + elif self.takeoff_error: status = "TAKEOFF ERROR (LOW BAT?)" + elif self.emergency_stop: status = "STOPPED" + else: status = "READY / ON GROUND" + + try: bat = int(self.tello.get_battery()) + except: bat = 0 + + HUD.draw(frame, status, self.emergency_stop, self.is_locked, self.lock_trigger, + self.current_height, 0.0, self.current_yaw, bat, depth_map_vis, zones) + + cv2.imshow(Config.WIN_NAME, frame) + + except Exception: + traceback.print_exc() + finally: + self.is_running = False + try: self.tello.land() + except: pass + self.tello.end() + cv2.destroyAllWindows() + + def _handle_takeoff(self): + if self.is_taking_off or self.is_flying: return + + self.is_taking_off = True + self.takeoff_error = False + def _task(): + try: + print("[Flight] Takeoff command sent to SDK...") + self.tello.takeoff() + time.sleep(3) + self.is_flying = True + print("[Flight] In Air. Logic active.") + except Exception as e: + print(f"[Flight Error] Takeoff failed: {e}") + self.takeoff_error = True + finally: + self.is_taking_off = False + + threading.Thread(target=_task, daemon=True).start() + + def _handle_input(self, key: int): + if key == 13: self.is_running = False + elif key == 32: self.emergency_stop = not self.emergency_stop + elif key == ord('m'): self.is_manual = not self.is_manual + elif key == ord('k'): + self.lock_trigger = not self.lock_trigger + self.is_locked = False + self.locked_person_features = None + elif key == ord('t'): + if not self.is_flying and not self.is_taking_off: + self.emergency_stop = False + self._handle_takeoff() + elif key == ord('l'): + self.is_flying = False + self.is_taking_off = False + try: self.tello.land() + except: pass + elif key == ord('1'): self.is_rotating = not self.is_rotating + + # Reset manual speed + self.m_lr, self.m_fb, self.m_ud, self.m_yv = 0, 0, 0, 0 + + if self.is_manual and not self.emergency_stop: + s = 100 # Maximum manual speed + if key == ord('w'): self.m_fb = s + elif key == ord('s'): self.m_fb = -s + elif key == ord('a'): self.m_lr = -s + elif key == ord('d'): self.m_lr = s + elif key == ord('r'): self.m_ud = s + elif key == ord('f'): self.m_ud = -s + elif key == ord('e'): self.m_yv = s + elif key == ord('z'): self.m_yv = -s diff --git a/drone_pilot/ui.py b/drone_pilot/ui.py new file mode 100644 index 0000000..f4d32c5 --- /dev/null +++ b/drone_pilot/ui.py @@ -0,0 +1,52 @@ +import cv2 +import numpy as np +from typing import Dict +from .config import Config, Colors + +class HUD: + @staticmethod + def draw(frame: np.ndarray, + status: str, + emergency_stop: bool, + is_locked: bool, + lock_trigger: bool, + current_height: float, + target_altitude: float, + current_yaw: float, + bat: int, + depth_map_vis: np.ndarray, + zones: Dict[str, bool]): + + cv2.rectangle(frame, (0, 0), (Config.WIDTH, 90), Colors.HUD_BG, -1) + cv2.rectangle(frame, (0, Config.HEIGHT-90), (Config.WIDTH, Config.HEIGHT), Colors.HUD_BG, -1) + + stat_color = Colors.RED if "AVOIDING" in status or emergency_stop else Colors.GREEN + if is_locked: stat_color = Colors.BLUE + + cv2.putText(frame, f"STATUS: {status}", (25, 40), cv2.FONT_HERSHEY_DUPLEX, 1.1, stat_color, 2) + cv2.putText(frame, f"ALT: {current_height:.1f}m | BAT: {bat}% | YAW: {int(current_yaw)}o", (25, 75), cv2.FONT_HERSHEY_SIMPLEX, 0.7, Colors.WHITE, 2) + + if depth_map_vis is not None: + mini = cv2.resize(depth_map_vis, (200, 150)) + frame[100:250, 25:225] = mini + cv2.rectangle(frame, (25, 100), (225, 250), Colors.WHITE, 2) + cv2.putText(frame, "AI DEPTH", (25, 275), cv2.FONT_HERSHEY_SIMPLEX, 0.6, Colors.WHITE, 2) + + radar_x = Config.WIDTH - 220 + for i, name in enumerate(["LEFT", "CENTER", "RIGHT"]): + short_name = name[0] + active = zones[name] + col = Colors.RED if active else Colors.GREEN + cv2.rectangle(frame, (radar_x + i*70, 100), (radar_x + (i+1)*60, 150), col, -1 if active else 2) + cv2.putText(frame, short_name, (radar_x + 20 + i*70, 135), cv2.FONT_HERSHEY_DUPLEX, 0.8, Colors.WHITE, 2) + + stop_col = Colors.RED if not emergency_stop else (0, 255, 255) + cv2.rectangle(frame, (Config.WIDTH - 250, Config.HEIGHT - 80), (Config.WIDTH - 25, Config.HEIGHT - 10), stop_col, -1) + cv2.putText(frame, "STOP", (Config.WIDTH - 200, Config.HEIGHT - 35), cv2.FONT_HERSHEY_DUPLEX, 0.8, Colors.WHITE, 2) + + keys = ["T:Takeoff", "L:Land", "M:Manual", "K:LockOn", "R:Up", "F:Down", "1:Rotate", "Space:Stop", "ENTER:Quit"] + for i, k in enumerate(keys): + cv2.putText(frame, k, (25 + (i%5)*200, Config.HEIGHT - 55 + (i//5)*35), cv2.FONT_HERSHEY_SIMPLEX, 0.6, Colors.WHITE, 2) + + if lock_trigger: + cv2.putText(frame, "LOCK TRIGGER ACTIVE: FINDING FACE...", (25, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, Colors.BLUE, 2) diff --git a/drone_pilot/vision.py b/drone_pilot/vision.py new file mode 100644 index 0000000..5de508e --- /dev/null +++ b/drone_pilot/vision.py @@ -0,0 +1,218 @@ +import cv2 +import torch +import numpy as np +import onnxruntime as ort +from PIL import Image +from typing import List, Tuple, Dict +from .config import Config + +class DroneVision: + def __init__(self): + # 1. Face Detection: YuNet (ONNX) + Haar Cascade Fallback + self.face_detector = None + try: + self.face_detector = cv2.FaceDetectorYN.create( + "models/face_detection_yunet.onnx", + "", + (320, 320), + 0.4, + 0.3, + 5000 + ) + print("[AI] YuNet Face Detector loaded.") + except Exception as e: + print(f"[AI Warning] YuNet failed to load: {e}") + + self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') + self.profile_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_profileface.xml') + + # 2. ONNX Runtime Sessions + try: + providers = ['CPUExecutionProvider'] + if 'CUDAExecutionProvider' in ort.get_available_providers(): + providers.insert(0, 'CUDAExecutionProvider') + + self.depth_session = ort.InferenceSession("models/midas_small.onnx", providers=providers) + self.reid_session = ort.InferenceSession("models/reid_mobilenet.onnx", providers=providers) + print(f"[AI] Depth & ReID (ONNX) initialized.") + except Exception as e: + print(f"[AI Error] ONNX initialization failed: {e}") + self.depth_session = None + self.reid_session = None + + # 3. Person Detection (SSD Lite Torch) + try: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + from torchvision.models.detection import ssdlite320_mobilenet_v3_large, SSDLite320_MobileNet_V3_Large_Weights + weights = SSDLite320_MobileNet_V3_Large_Weights.DEFAULT + self.person_model = ssdlite320_mobilenet_v3_large(weights=weights, box_score_thresh=Config.PERSON_CONF_THRESHOLD).to(self.device) + self.person_model.eval() + self.person_preprocess = weights.transforms() + except Exception as e: + print(f"[AI Error] Person detector failed: {e}") + self.person_model = None + + self.depth_map_vis = None + self.zones = {"LEFT": False, "CENTER": False, "RIGHT": False} + self.zone_scores = {"LEFT": 0.0, "CENTER": 0.0, "RIGHT": 0.0} + + def _detect_haar(self, img_gray: np.ndarray) -> List[Tuple]: + faces = list(self.face_cascade.detectMultiScale(img_gray, 1.1, 7, minSize=(30, 30))) + profiles = self.profile_cascade.detectMultiScale(img_gray, 1.1, 8, minSize=(35, 35)) + for p in profiles: + is_new = True + px, py, pw, ph = p + for (fx, fy, fw, fh) in faces: + if abs(px - fx) < fw/2 and abs(py - fy) < fh/2: + is_new = False; break + if is_new: faces.append(tuple(p)) + return faces + + def detect_faces(self, frame: np.ndarray) -> List[Tuple]: + h, w = frame.shape[:2] + faces = [] + + # YuNet Detection + if self.face_detector is not None: + try: + self.face_detector.setInputSize((w, h)) + _, detections = self.face_detector.detect(frame) + if detections is not None: + for det in detections: + faces.append(tuple(det[0:4].astype(int))) + except: pass + + # Haar Fallback + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + gray = cv2.GaussianBlur(gray, (3, 3), 0) + haar_faces = self._detect_haar(gray) + + for hf in haar_faces: + is_new = True + hx, hy, hw, hh = hf + for (fx, fy, fw, fh) in faces: + if abs(hx - fx) < fw/2 and abs(hy - fy) < fh/2: + is_new = False; break + if is_new: faces.append(hf) + + # Center Zoom for extreme distance + zh, zw = int(h * 0.4), int(w * 0.4) + cy, cx = h // 2, w // 2 + y1, y2 = cy - zh // 2, cy + zh // 2 + x1, x2 = cx - zw // 2, cx + zw // 2 + center_crop = frame[y1:y2, x1:x2] + center_upscaled = cv2.resize(center_crop, (zw * 2, zh * 2), interpolation=cv2.INTER_LANCZOS4) + + # In Zoom mode, use YuNet if available, otherwise Haar + if self.face_detector is not None: + try: + uh, uw = center_upscaled.shape[:2] + self.face_detector.setInputSize((uw, uh)) + _, zoom_detections = self.face_detector.detect(center_upscaled) + if zoom_detections is not None: + for det in zoom_detections: + zx, zy, zw_f, zh_f = det[0:4] + rx, ry = int(zx / 2) + x1, int(zy / 2) + y1 + rw, rh = int(zw_f / 2), int(zh_f / 2) + is_new = True + for (fx, fy, fw, fh) in faces: + if abs(rx - fx) < fw/2 and abs(ry - fy) < fh/2: + is_new = False; break + if is_new: faces.append((rx, ry, rw, rh)) + except: pass + + # Always run Haar on zoom for robustness + zoom_gray = cv2.cvtColor(center_upscaled, cv2.COLOR_BGR2GRAY) + zoom_haar = self._detect_haar(zoom_gray) + for (zx, zy, zw_f, zh_f) in zoom_haar: + rx, ry = int(zx / 2) + x1, int(zy / 2) + y1 + rw, rh = int(zw_f / 2), int(zh_f / 2) + is_new = True + for (fx, fy, fw, fh) in faces: + if abs(rx - fx) < fw/2 and abs(ry - fy) < fh/2: + is_new = False; break + if is_new: faces.append((rx, ry, rw, rh)) + + return faces + + def detect_persons(self, frame: np.ndarray) -> List[Tuple]: + if self.person_model is None: return [] + img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(img_rgb) + input_tensor = self.person_preprocess(pil_img).to(self.device).unsqueeze(0) + with torch.no_grad(): + output = self.person_model(input_tensor)[0] + + persons = [] + for i in range(len(output['labels'])): + if output['labels'][i] == 1 and output['scores'][i] > Config.PERSON_CONF_THRESHOLD: + box = output['boxes'][i].cpu().numpy().astype(int) + persons.append((box[0], box[1], box[2]-box[0], box[3]-box[1])) + return persons + + def extract_person_features(self, frame: np.ndarray, box: Tuple) -> np.ndarray: + if self.reid_session is None: return None + try: + x, y, w, h = box + x1, y1 = max(0, x), max(0, y) + x2, y2 = min(frame.shape[1], x+w), min(frame.shape[0], y+h) + if x2 <= x1 or y2 <= y1: return None + + crop = frame[y1:y2, x1:x2] + img_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB) + img_resized = cv2.resize(img_rgb, (224, 224)) + img_float = img_resized.astype(np.float32) / 255.0 + + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + img_norm = (img_float - mean) / std + + img_input = np.transpose(img_norm, (2, 0, 1)).astype(np.float32) + img_input = np.expand_dims(img_input, axis=0) + + ort_inputs = {self.reid_session.get_inputs()[0].name: img_input} + features = self.reid_session.run(None, ort_inputs)[0] + return features.flatten() + except: return None + + def compare_features(self, feat1: np.ndarray, feat2: np.ndarray) -> float: + if feat1 is None or feat2 is None: return 0.0 + try: + norm1 = np.linalg.norm(feat1) + norm2 = np.linalg.norm(feat2) + if norm1 == 0 or norm2 == 0: return 0.0 + return np.dot(feat1, feat2) / (norm1 * norm2) + except: return 0.0 + + def estimate_depth_and_radar(self, frame: np.ndarray): + if self.depth_session is None: return + try: + h, w = frame.shape[:2] + slice_y1, slice_y2 = int(h*0.25), int(h*0.75) + analysis_area = frame[slice_y1:slice_y2, :] + img_rgb = cv2.cvtColor(analysis_area, cv2.COLOR_BGR2RGB) + + img_resized = cv2.resize(img_rgb, (256, 256)) + img_float = img_resized.astype(np.float32) / 255.0 + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + img_norm = (img_float - mean) / std + img_input = np.transpose(img_norm, (2, 0, 1)).astype(np.float32) + img_input = np.expand_dims(img_input, axis=0) + + ort_inputs = {self.depth_session.get_inputs()[0].name: img_input} + pred = self.depth_session.run(None, ort_inputs)[0][0] + + pred_upscaled = cv2.resize(pred, (analysis_area.shape[1], analysis_area.shape[0]), interpolation=cv2.INTER_CUBIC) + out_norm = cv2.normalize(pred_upscaled, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) + self.depth_map_vis = cv2.applyColorMap(out_norm, cv2.COLORMAP_MAGMA) + depth_map = pred_upscaled / (np.max(pred_upscaled) + 1e-5) + + dh, dw = depth_map.shape + zone_w = dw // 3 + for i, name in enumerate(["LEFT", "CENTER", "RIGHT"]): + zx1, zx2 = i * zone_w, (i + 1) * zone_w + score = np.mean(depth_map[dh//4:3*dh//4, zx1:zx2]) + self.zone_scores[name] = score + self.zones[name] = score > Config.DEPTH_THRESHOLD + except: pass diff --git a/face_tracking_orbit.py b/face_tracking_orbit.py deleted file mode 100644 index 006493e..0000000 --- a/face_tracking_orbit.py +++ /dev/null @@ -1,265 +0,0 @@ -import cv2 -import numpy as np -from tello_sim_client import TelloSimClient -import time -import torch # Deep Learning -import torchvision.transforms as T - -# --- Configuration --- -WINDOW_WIDTH = 960 -WINDOW_HEIGHT = 720 -TARGET_FACE_SIZE = 120 -FACE_CENTER_X = WINDOW_WIDTH // 2 -FACE_CENTER_Y = WINDOW_HEIGHT // 2 - -TARGET_ALTITUDE = 1.5 # meters -KP_ALTITUDE = 60 - -# Multi-Zone Radar & DL Config -OBSTACLE_THRESHOLD_TOF = 75 -DEPTH_THRESHOLD = 0.75 # Höherer Schwellenwert (weniger empfindlich) - -# --- Colors (BGR) --- -COLOR_GREEN = (0, 255, 0) -COLOR_RED = (0, 0, 255) -COLOR_BLUE = (255, 0, 0) -COLOR_WHITE = (255, 255, 255) - -class FaceTrackingApp: - def __init__(self): - print("[App] Initializing TelloSimClient...") - self.tello = TelloSimClient() - self.tello.connect() - self.tello.streamon() - - # Load Face Detector - cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml' - self.face_cascade = cv2.CascadeClassifier(cascade_path) - - # Load MiDaS Model - print("[AI] Loading MiDaS DL Model...") - try: - self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - self.model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small").to(self.device) - self.model.eval() - self.transform = torch.hub.load("intel-isl/MiDaS", "transforms").small_transform - except Exception as e: - print(f"[AI] Error: {e}") - self.model = None - - self.is_running = True - self.is_manual = False - self.is_orbiting = False - - self.zones = {"LEFT": False, "CENTER": False, "RIGHT": False} - self.zone_scores = {"LEFT": 0.0, "CENTER": 0.0, "RIGHT": 0.0} - self.status = "INIT" - self.depth_map_vis = None - - # Position Tracking - self.pos_x, self.pos_z = 0.0, 0.0 - self.path_history = [(0.0, 0.0)] - self.last_update_time = time.time() - self.current_yaw = 0.0 - self.current_height = 0.0 - - # Velocity Tracking - self.last_sent_rc = [0, 0, 0, 0] # lr, fb, ud, yv - self.m_lr, self.m_fb, self.m_ud, self.m_yv = 0, 0, 0, 0 - - self.avoidance_active = False - self.avoidance_dir = 0 # -1 for right, 1 for left - self.center_blocked = False - self.search_start = time.time() - - cv2.namedWindow("Tello AI Pilot (Stable Mode)", cv2.WINDOW_NORMAL) - cv2.resizeWindow("Tello AI Pilot (Stable Mode)", WINDOW_WIDTH, WINDOW_HEIGHT) - - def estimate_depth(self, frame): - if self.model is None: return None - # Wir analysieren nur den mittleren Streifen (kein Boden, kein Himmel) - analysis_area = frame[200:520, :] # Fokus auf Augenhöhe - img = cv2.cvtColor(analysis_area, cv2.COLOR_BGR2RGB) - input_batch = self.transform(img).to(self.device) - with torch.no_grad(): - pred = self.model(input_batch) - pred = torch.nn.functional.interpolate(pred.unsqueeze(1), size=analysis_area.shape[:2], mode="bicubic", align_corners=False).squeeze() - out = pred.cpu().numpy() - out_norm = cv2.normalize(out, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) - self.depth_map_vis = cv2.applyColorMap(out_norm, cv2.COLORMAP_MAGMA) - return out / (np.max(out) + 1e-5) - - def run_radar_scan(self, depth_map): - if depth_map is None: return - h, w = depth_map.shape - zone_w = w // 3 - for i, name in enumerate(["LEFT", "CENTER", "RIGHT"]): - zx_start, zx_end = i * zone_w, (i + 1) * zone_w - # Durchschnittliche Tiefe im Zentrum der Zone - score = np.mean(depth_map[h//4 : 3*h//4, zx_start:zx_end]) - self.zone_scores[name] = score - self.zones[name] = score > DEPTH_THRESHOLD - - def update_telemetry(self, frame): - try: - self.current_yaw = float(self.tello.get_yaw()) - raw_height = float(self.tello.get_height()) - # Auto-Scale detection (Simulator = m, Real = cm) - self.current_height = raw_height if abs(raw_height) < 10 else raw_height / 100.0 - - dm = self.estimate_depth(frame) - if dm is not None: self.run_radar_scan(dm) - - # Dead Reckoning - now = time.time() - dt = now - self.last_update_time - self.last_update_time = now - - # Simulator Scale (0.6 ist ein guter Annäherungswert für cm/s) - sf = 0.6 - rad = np.radians(self.current_yaw) - lr, fb = self.last_sent_rc[0], self.last_sent_rc[1] - vx = fb * np.sin(rad) + lr * np.cos(rad) - vz = fb * np.cos(rad) - lr * np.sin(rad) - self.pos_x += vx * sf * dt - self.pos_z += vz * sf * dt - - if np.linalg.norm(np.array([self.pos_x, self.pos_z]) - np.array(self.path_history[-1])) > 0.5: - self.path_history.append((self.pos_x, self.pos_z)) - except: pass - - def process_flight(self, faces): - lr, fb, ud, yv = 0, 0, 0, 0 - - # Terrain Following (Altitude in m) - if not self.is_manual: - ud = int(np.clip((TARGET_ALTITUDE - self.current_height) * KP_ALTITUDE, -40, 40)) - - # --- AI PILOT --- - # TOF Sensor Check - try: tof = int(self.tello.get_distance_tof()) - except: tof = 1000 - - self.center_blocked = self.zones["CENTER"] or tof < OBSTACLE_THRESHOLD_TOF - - if self.center_blocked: - self.status = "!!! OBSTACLE !!!" - fb = 0 - - # Stick to a direction if already avoiding - if not self.avoidance_active: - self.avoidance_dir = 1 if self.zone_scores["LEFT"] < self.zone_scores["RIGHT"] else -1 - self.avoidance_active = True - - yv = 45 * self.avoidance_dir - elif self.avoidance_active: - self.status = "RECOVERING" - fb = 15; yv = 20 * self.avoidance_dir - if not self.zones["CENTER"]: - self.avoidance_active = False - self.avoidance_dir = 0 - - elif self.is_manual: - self.status = "MANUAL" - lr, fb, ud, yv = self.m_lr, self.m_fb, self.m_ud, self.m_yv - - elif len(faces) > 0: - self.search_start = time.time() - (x, y, w, h) = max(faces, key=lambda f: f[2] * f[3]) - yv_err = (x + w // 2) - FACE_CENTER_X - yv = int(np.clip(0.18 * yv_err, -60, 60)) - fb_raw = int(np.clip(0.5 * (TARGET_FACE_SIZE - w), -50, 50)) - - # Sanftere Reduzierung bei Drehung - fb = int(fb_raw * 0.5) if abs(yv_err) > 60 else fb_raw - - if self.is_orbiting: - self.status = "AI ORBIT"; lr = 35 - else: - self.status = "AI TRACKING" - else: - # SUCHE: "Walk and Look" Muster - elapsed = (time.time() - self.search_start) % 4.0 - if elapsed < 2.8: - self.status = "SEARCH: FORWARD" - fb = 35; yv = 8 - else: - self.status = "SEARCH: TURNING" - fb = 5; yv = 35 - - self.last_sent_rc = [lr, fb, ud, yv] - self.tello.send_rc_control(lr, fb, ud, yv) - - def draw_hud(self, frame, b): - try: - if hasattr(self.tello, 'is_flying'): - state = "FLYING" if self.tello.is_flying else "LANDED" - else: - state = self.tello.get_current_state() - except: state = "ERROR" - - if self.depth_map_vis is not None: - mini = cv2.resize(self.depth_map_vis, (160, 120)) - frame[60:180, 20:180] = mini - cv2.rectangle(frame, (20, 60), (180, 180), COLOR_WHITE, 1) - - cv2.rectangle(frame, (0, 0), (WINDOW_WIDTH, 60), (0,0,0), -1) - cv2.putText(frame, f"{self.status} | {state.upper()} | ALT: {self.current_height:.1f}m | BAT: {b}%", (20, 40), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, COLOR_GREEN if not self.center_blocked else COLOR_RED, 2) - - cv2.putText(frame, f"RC: {self.last_sent_rc}", (20, 70), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLOR_WHITE, 1) - - # Radar Score Debugging - for i, name in enumerate(["LEFT", "CENTER", "RIGHT"]): - col = COLOR_RED if self.zones[name] else COLOR_GREEN - val = self.zone_scores[name] - cv2.putText(frame, f"{name[:1]}:{val:.2f}", (WINDOW_WIDTH - 250 + i*80, 40), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, col, 1) - - def run(self): - try: - self.last_update_time = time.time() - while self.is_running: - fr = self.tello.get_frame_read() - if fr.frame is None: continue - frame = cv2.resize(fr.frame.copy(), (WINDOW_WIDTH, WINDOW_HEIGHT)) - self.update_telemetry(frame) - - # Face Detection (auf Originalbild) - faces = self.face_cascade.detectMultiScale(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY), 1.2, 8) - - key = cv2.waitKey(1) & 0xFF - if key == ord('q'): self.is_running = False - elif key == ord('m'): self.is_manual = not self.is_manual - elif key == ord('o'): self.is_orbiting = not self.is_orbiting - elif key == ord('t'): - print("[App] Manual Takeoff Triggered") - self.tello.takeoff() - elif key == ord('l'): self.tello.land() - - if self.is_manual: - s = 60 - self.m_lr, self.m_fb, self.m_ud, self.m_yv = 0, 0, 0, 0 - if key == ord('w'): self.m_fb = s - elif key == ord('s'): self.m_fb = -s - elif key == ord('a'): self.m_lr = -s - elif key == ord('d'): self.m_lr = s - elif key == ord('r'): self.m_ud = s # R = UP - elif key == ord('f'): self.m_ud = -s # F = DOWN - elif key == ord('e'): self.m_yv = s # E = YAW RIGHT - elif key == ord('z'): self.m_yv = -s # Z = YAW LEFT - - self.process_flight(faces) - try: bat = int(self.tello.get_battery()) - except: bat = 0 - - self.draw_hud(frame, bat) - cv2.imshow("Tello AI Pilot (Stable Mode)", frame) - except Exception as e: print(f"Error: {e}") - finally: - self.tello.land(); self.tello.end(); cv2.destroyAllWindows() - -if __name__ == "__main__": - app = FaceTrackingApp() - app.run() diff --git a/run.py b/run.py new file mode 100644 index 0000000..e6cf1c1 --- /dev/null +++ b/run.py @@ -0,0 +1,6 @@ +from drone_pilot.main import FaceTrackingApp + +if __name__ == "__main__": + # Change use_real_tello=True when flying the physical drone + app = FaceTrackingApp(use_real_tello=True) + app.run() diff --git a/tello_sim/command_server.py b/tello_sim/command_server.py index cda05a7..30e4523 100644 --- a/tello_sim/command_server.py +++ b/tello_sim/command_server.py @@ -30,7 +30,7 @@ class CommandServer: try: test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - test_socket.bind(('localhost', port)) + test_socket.bind(('127.0.0.1', port)) test_socket.close() return True except OSError: @@ -88,7 +88,7 @@ class CommandServer: self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: - self.server_socket.bind(('localhost', 9999)) # Port number for communication + self.server_socket.bind(('127.0.0.1', 9999)) # Port number for communication self.server_socket.listen(5) print("[Command Listener] Listening on port 9999...") except OSError as e: diff --git a/tello_sim/tello_drone_sim.py b/tello_sim/tello_drone_sim.py index 4ac4ebf..267c5f1 100644 --- a/tello_sim/tello_drone_sim.py +++ b/tello_sim/tello_drone_sim.py @@ -60,4 +60,10 @@ class TelloDroneSim: raise def update(self) -> None: - self._ursina_adapter.tick() \ No newline at end of file + self._ursina_adapter.tick() + +if __name__ == "__main__": + sim = TelloDroneSim() + def update(): + sim.update() + sim.start() diff --git a/tello_sim/ursina_adapter.py b/tello_sim/ursina_adapter.py index 418cd88..5f07537 100644 --- a/tello_sim/ursina_adapter.py +++ b/tello_sim/ursina_adapter.py @@ -4,7 +4,6 @@ from OpenGL.GL import glReadPixels, GL_RGBA, GL_UNSIGNED_BYTE import numpy as np from typing import Literal import cv2 -import numpy as np from ursina import ( Ursina, window, @@ -175,6 +174,15 @@ class UrsinaAdapter(): cast_shadow=True ) + self.walking_person = Entity( + model='entities/business_man.glb', + scale=7.3, + position=(0, 13.0, 40), + rotation=(0, 90, 0), + collider='box', + cast_shadow=True + ) + self.light1 = Entity( model='entities/street_light.glb', scale=(4, 6.5, 5), @@ -303,6 +311,12 @@ class UrsinaAdapter(): self.max_roll = 20 self.tilt_smoothness = 0.05 + # RC control state + self.rc_lr = 0.0 + self.rc_fb = 0.0 + self.rc_ud = 0.0 + self.rc_yaw = 0.0 + self.create_meters() def run(self): @@ -478,6 +492,7 @@ class UrsinaAdapter(): return int(self.velocity.x * 3.6) def get_speed_y(self) -> int: + from ursina import time as ursina_time current_time = time() elapsed_time = current_time - self.last_time @@ -585,35 +600,90 @@ class UrsinaAdapter(): self.emergency() def update_movement(self) -> None: - self.velocity += self.acceleration + import ursina + dt = ursina.time.dt + t = time() + + # Update walking person movement + # Moves between -15 and 15 on X axis + walk_speed = 2.0 + old_x = self.walking_person.x + self.walking_person.x = sin(t * 0.5) * 15 - if self.velocity is None: - raise Exception("Velocity is None") + # Rotate person based on direction + if self.walking_person.x > old_x: + self.walking_person.rotation_y = 90 + else: + self.walking_person.rotation_y = -90 - if self.velocity.length() > self.max_speed: - self.velocity = self.velocity.normalized() * self.max_speed + # Apply RC control in world space based on local drone orientation + if self.is_flying: + # Tello RC values are -100 to 100. + speed_mult_h = 0.18 # Increased: 100 -> 18.0 m/s for faster forward + speed_mult_v = 0.05 # Vertical: 100 -> 5.0 m/s + + # Robust local-to-world conversion using yaw angle + rad = np.radians(self.drone.rotation_y) + sin_y = np.sin(rad) + cos_y = np.cos(rad) + + # vx = right-axis component, vz = forward-axis component + vx = (self.rc_fb * speed_mult_h * sin_y) + (self.rc_lr * speed_mult_h * cos_y) + vz = (self.rc_fb * speed_mult_h * cos_y) - (self.rc_lr * speed_mult_h * sin_y) + vy = self.rc_ud * speed_mult_v + + target_velocity = Vec3(vx, vy, vz) + + # Smoothly interpolate current velocity towards target + self.velocity = lerp(self.velocity, target_velocity, 10 * dt) + + # Apply yaw rotation (degrees per second) + # Reduced significantly: multiplier from 3.5 to 1.2 + self.drone.rotation_y += self.rc_yaw * 1.2 * dt + + # Set target tilt angles for visual feedback + self.pitch_angle = (self.rc_fb / 100.0) * self.max_pitch + self.roll_angle = (self.rc_lr / 100.0) * self.max_roll - self.velocity *= self.drag - new_position = self.drone.position + self.velocity - hit_info = raycast(self.drone.position, self.velocity.normalized(), distance=self.velocity.length(), ignore=(self.drone,)) # type: ignore + # Acceleration (m/s^2) from move() commands + self.velocity += self.acceleration * dt + + if self.velocity.length() > 20.0: # Max speed 20 m/s + self.velocity = self.velocity.normalized() * 20.0 - if not hit_info.hit: - self.drone.position = new_position + # Apply drag when no active RC input + if abs(self.rc_fb) < 1 and abs(self.rc_lr) < 1 and abs(self.rc_ud) < 1: + self.velocity = lerp(self.velocity, Vec3(0,0,0), 3 * dt) + + # Frame-independent position update + new_position = self.drone.position + self.velocity * dt + + # Collision check + if self.velocity.length() > 0.01: + hit_info = raycast(self.drone.position + Vec3(0,0.1,0), self.velocity.normalized(), + distance=self.velocity.length() * dt + 0.1, ignore=(self.drone,)) # type: ignore + if not hit_info.hit: + self.drone.position = new_position + else: + self.drone.position = new_position - if self.drone.y < 3: - self.drone.y = 3 + if self.drone.y < 2.6: # Ground level + self.drone.y = 2.6 + if self.velocity.y < 0: self.velocity.y = 0 self.acceleration = Vec3(0, 0, 0) # Apply pitch and roll to the drone self.drone.rotation_x = lerp(self.drone.rotation_x, self.pitch_angle, self.tilt_smoothness) self.drone.rotation_z = lerp(self.drone.rotation_z, self.roll_angle, self.tilt_smoothness) - current_time = time() - dt = current_time - self.last_time_accel + + from time import time as wall_time + current_time = wall_time() + dt_accel = current_time - self.last_time_accel - if dt > 0: + if dt_accel > 0: velocity_change = self.velocity - self.last_velocity_accel - self.calculated_acceleration = velocity_change / dt # type: ignore + self.calculated_acceleration = velocity_change / dt_accel # type: ignore self.last_velocity_accel = Vec3(self.velocity.x, self.velocity.y, self.velocity.z) self.last_time_accel = current_time @@ -697,15 +767,11 @@ class UrsinaAdapter(): self.drone.rotation_z = lerp(self.drone.rotation_z, self.roll_angle, self.tilt_smoothness) def send_rc_control(self, left_right_velocity_ms: float, forward_backward_velocity_ms: float, up_down_velocity_ms: float, yaw_velocity_ms: float): - - self.velocity = Vec3( - -left_right_velocity_ms / 100, # LEFT/RIGHT mapped to X - up_down_velocity_ms / 100, # UP/DOWN mapped to Y - forward_backward_velocity_ms / 100 # FORWARD/BACKWARD mapped to Z - ) - - self.drone.rotation_y += -yaw_velocity_ms * 0.05 # Smooth yaw rotation - print(f"[RC Control] Velocities set -> LR: {left_right_velocity_ms}, FB: {forward_backward_velocity_ms}, UD: {up_down_velocity_ms}, Yaw: {yaw_velocity_ms}") + # Only store the target RC values to be applied in the main thread (update_movement) + self.rc_lr = left_right_velocity_ms + self.rc_fb = forward_backward_velocity_ms + self.rc_ud = up_down_velocity_ms + self.rc_yaw = yaw_velocity_ms @staticmethod def map_coords(x: float, y: float, z: float) -> Vec3: @@ -892,7 +958,8 @@ class UrsinaAdapter(): rolling = False if self.bezier_mode: - t_now = time() + from time import time as wall_time + t_now = wall_time() elapsed = t_now - self.bezier_start_time # type: ignore t = min(1.0, elapsed / self.bezier_duration) @@ -929,4 +996,5 @@ class UrsinaAdapter(): self.update_movement() self.update_pitch_roll() - \ No newline at end of file + def update(self): + self.tick() diff --git a/tello_sim_client.py b/tello_sim_client.py index 04c3ff0..5d64fc0 100644 --- a/tello_sim_client.py +++ b/tello_sim_client.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -import logging import subprocess import platform import sys @@ -9,255 +8,141 @@ import cv2 import os import numpy as np +try: + from djitellopy import Tello + HAS_TELLO_LIB = True +except ImportError: + HAS_TELLO_LIB = False + @dataclass class BackgroundFrameRead(): frame: cv2.typing.MatLike class TelloSimClient: - def __init__(self, host='localhost', port=9999, auto_start_simulation=True): + def __init__(self, host='127.0.0.1', port=9999, auto_start_simulation=True, use_real_tello=False): self.host = host self.port = port - self.latest_frame = None - if auto_start_simulation and not self._check_simulation_running(): - self._start_simulation() - print("[Wrapper] Starting Tello Simulation...") - self._wait_for_simulation() + self.use_real_tello = use_real_tello + self.real_drone = None + + if self.use_real_tello: + if HAS_TELLO_LIB: + print("[Client] Connecting to REAL Tello Hardware...") + self.real_drone = Tello() + else: + print("[Error] djitellopy not found! Falling back to Simulator.") + self.use_real_tello = False - def _start_simulation(self): - sim_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'tello_drone_sim.py')) - if platform.system() == "Windows": - command = f'start cmd /k python "{sim_path}"' - print("[DEBUG] Launching simulation command:", command) - subprocess.Popen(command, shell=True) - elif platform.system() == "Linux": - subprocess.Popen(['gnome-terminal', '--', 'python3', 'tello_drone_sim.py']) - elif platform.system() == "Darwin": - subprocess.Popen(['ls']) - subprocess.Popen(['pwd']) - python_path = os.path.join(os.path.dirname(sys.executable), 'python3') - print("Running python3 from path:", python_path) - subprocess.Popen([python_path, './tello_drone_sim.py'], cwd=os.getcwd(), - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, - start_new_session=True) + if not self.use_real_tello: + if auto_start_simulation and not self._check_simulation_running(): + self._start_simulation() + self._wait_for_simulation() + + def connect(self): + if self.real_drone: self.real_drone.connect() + else: self._send_command('connect') + + def takeoff(self): + if self.real_drone: self.real_drone.takeoff() + else: self._send_command('takeoff') + + def land(self): + if self.real_drone: self.real_drone.land() + else: self._send_command('land') + + def streamon(self): + if self.real_drone: self.real_drone.streamon() + else: self._send_command('streamon') + + def streamoff(self): + if self.real_drone: self.real_drone.streamoff() + else: self._send_command('streamoff') + + def get_frame_read(self): + if self.real_drone: return self.real_drone.get_frame_read() + return self # Simulator acts as frame provider via TCP in get_frame_read-style + + @property + def frame(self): + """Helper for simulator to match BackgroundFrameRead structure""" + return self._receive_tcp_frame() + + def send_rc_control(self, lr, fb, ud, yaw): + if self.real_drone: self.real_drone.send_rc_control(int(lr), int(fb), int(ud), int(yaw)) + else: self._send_command(f"send_rc_control {lr} {fb} {ud} {yaw}") + + def get_yaw(self): + if self.real_drone: return self.real_drone.get_yaw() + return float(self._request_data('get_yaw') or 0) + + def get_height(self): + if self.real_drone: return self.real_drone.get_height() + return float(self._request_data('get_height') or 0) + + def get_distance_tof(self): + if self.real_drone: return self.real_drone.get_distance_tof() + return float(self._request_data('get_distance_tof') or 0) + + def get_battery(self): + if self.real_drone: return self.real_drone.get_battery() + return int(self._request_data('get_battery') or 0) + + def end(self): + if self.real_drone: + self.real_drone.streamoff() else: - raise OSError("Unsupported OS for launching terminal simulation.") + self._send_command('end') + + # --- Internal Simulator Logic --- + def _start_simulation(self): + sim_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'tello_sim', 'tello_drone_sim.py')) + print(f"[Wrapper] Launching Simulator: {sim_path}") + clean_env = os.environ.copy() + # Set Qt platform to xcb for better compatibility on Linux/Wayland + clean_env["QT_QPA_PLATFORM"] = "xcb" + # Clean PyCharm debugger vars + for k in list(clean_env.keys()): + if k.startswith('PYDEVD'): del clean_env[k] + subprocess.Popen([sys.executable, sim_path], cwd=os.path.dirname(sim_path), + env=clean_env, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True) def _check_simulation_running(self): try: - with socket.create_connection((self.host, self.port), timeout=1): - return True - except (ConnectionRefusedError, socket.timeout, OSError) as ex: - logging.error("[Wrapper] Simulation is not running.", ex) - return False + with socket.create_connection((self.host, self.port), timeout=0.5): return True + except: return False - def _wait_for_simulation(self, timeout=30): - print("[Wrapper] Waiting for simulation to become ready...") - start_time = time.time() - while time.time() - start_time < timeout: - if self._check_simulation_running(): - print("[Wrapper] Simulation is now ready!") - return + def _wait_for_simulation(self, timeout=15): + start = time.time() + while time.time() - start < timeout: + if self._check_simulation_running(): return time.sleep(1) - raise TimeoutError("[Error] Simulation did not become ready in time.") + print("[Warning] Simulator not responding.") - def _send_command(self, command: str): + def _send_command(self, cmd): try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect((self.host, self.port)) - s.send(command.encode()) - except ConnectionRefusedError: - print(f"[Error] Unable to connect to the simulation at {self.host}:{self.port}") - - def get_frame_read(self) -> BackgroundFrameRead: - """Get the latest frame directly from the simulator over TCP.""" + s.connect((self.host, self.port)); s.send(cmd.encode()) + except: pass + + def _request_data(self, cmd): + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect((self.host, self.port)); s.send(cmd.encode()) + return s.recv(1024).decode() + except: return None + + def _receive_tcp_frame(self): try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.connect((self.host, self.port)) s.send(b'get_latest_frame') - - # Receive frame size (4 bytes) size_data = s.recv(4) - if len(size_data) != 4: - print("[Error] Failed to receive frame size") - return BackgroundFrameRead(frame=np.zeros([360, 640, 3], dtype=np.uint8)) - - frame_size = int.from_bytes(size_data, byteorder='big') - - # If size is 0, no frame available - if frame_size == 0: - print("[Debug] No frame available from simulator") - return BackgroundFrameRead(frame=np.zeros([360, 640, 3], dtype=np.uint8)) - - # Receive the frame data - frame_data = b'' - bytes_received = 0 - while bytes_received < frame_size: - chunk = s.recv(min(4096, frame_size - bytes_received)) - if not chunk: - break - frame_data += chunk - bytes_received += len(chunk) - - # Decode the frame from PNG bytes - if len(frame_data) == frame_size: - nparr = np.frombuffer(frame_data, np.uint8) - image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) - if image is not None: - # Return frame in BGR format (OpenCV's native format) - # Users should convert to RGB if needed for display - return BackgroundFrameRead(frame=image) - - print("[Error] Failed to decode frame data") - return BackgroundFrameRead(frame=np.zeros([360, 640, 3], dtype=np.uint8)) - - except ConnectionRefusedError: - print(f"[Error] Unable to connect to the simulation at {self.host}:{self.port}") - return BackgroundFrameRead(frame=np.zeros([360, 640, 3], dtype=np.uint8)) - except Exception as e: - print(f"[Error] Failed to get frame: {e}") - return BackgroundFrameRead(frame=np.zeros([360, 640, 3], dtype=np.uint8)) - - def _request_data(self, command): - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect((self.host, self.port)) - s.send(command.encode()) - return s.recv(4096).decode() - except ConnectionRefusedError: - print(f"[Error] Unable to retrieve '{command}' from {self.host}:{self.port}") - return "N/A" - - def wait_until_motion_complete(self): - while self._request_data("get_is_moving") == "True": - time.sleep(0.1) - - def get_battery(self): - return self._request_data('get_battery') - - def get_distance_tof(self): - return self._request_data('get_distance_tof') - - def get_height(self): - return self._request_data('get_height') - - def get_flight_time(self): - return self._request_data('get_flight_time') - - def get_speed_x(self): - return self._request_data('get_speed_x') - - def get_speed_y(self): - return self._request_data('get_speed_y') - - def get_speed_z(self): - return self._request_data('get_speed_z') - - def get_acceleration_x(self): - return self._request_data('get_acceleration_x') - - def get_acceleration_y(self): - return self._request_data('get_acceleration_y') - - def get_acceleration_z(self): - return self._request_data('get_acceleration_z') - - def get_pitch(self): - return self._request_data('get_pitch') - - def get_roll(self): - return self._request_data('get_roll') - - def get_yaw(self): - return self._request_data('get_yaw') - - def query_attitude(self): - return self._request_data('query_attitude') - - def get_current_state(self): - return self._request_data('get_current_state') - - def connect(self): - self._send_command('connect') - - def takeoff(self): - self._send_command('takeoff') - - def land(self): - self._send_command('land') - - def rotate_clockwise(self, degrees): - self._send_command(f'rotate_cw {degrees}') - - def rotate_counter_clockwise(self, degrees): - self._send_command(f'rotate_ccw {degrees}') - - def streamon(self): - self._send_command('streamon') - - def streamoff(self): - self._send_command('streamoff') - - def capture_frame(self): - self._send_command('capture_frame') - - def emergency(self): - self._send_command('emergency') - - def move_forward(self, distance): - self._send_command(f'forward {distance}') - - def move_back(self, distance): - self._send_command(f'backward {distance}') - - def move_left(self, distance): - self._send_command(f'left {distance}') - - def move_right(self, distance): - self._send_command(f'right {distance}') - - def move_up(self, distance): - self._send_command(f'up {distance}') - - def move_down(self, distance): - self._send_command(f'down {distance}') - - def flip_left(self): - self._send_command('flip_left') - - def flip_right(self): - self._send_command('flip_right') - - def flip_forward(self): - self._send_command('flip_forward') - - def flip_back(self): - self._send_command('flip_back') - - def go_xyz_speed(self, x, y, z, speed): - self._send_command(f"go {x} {y} {z} {speed}") - - def curve_xyz_speed(self, x1, y1, z1, x2, y2, z2, speed): - self._send_command(f"curve {x1} {y1} {z1} {x2} {y2} {z2} {speed}") - - def set_speed(self, speed): - self._send_command(f"set_speed {speed}") - - def send_rc_control(self, left_right_velocity, forward_backward_velocity, up_down_velocity, yaw_velocity): - self._send_command(f"send_rc_control {left_right_velocity} {forward_backward_velocity} {up_down_velocity} {yaw_velocity}") - - - def end(self): - self._send_command('end') - - def get_info(self): - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect((self.host, self.port)) - s.send(b'get_info') - return s.recv(4096).decode() - except ConnectionRefusedError: - print(f"[Error] Unable to retrieve info from {self.host}:{self.port}") - return "{}" - - def initiate_throw_takeoff(self): - self._send_command('throw_takeoff') \ No newline at end of file + if not size_data: return None + size = int.from_bytes(size_data, 'big') + data = b'' + while len(data) < size: + chunk = s.recv(min(size - len(data), 4096)) + if not chunk: break + data += chunk + return cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR) + except: return None