Source code for pipert.contrib.sort

from pipert.contrib.sort_tracker.sort import Sort
import torch
from pipert.utils.structures import Instances, Boxes
import numpy as np
from pipert.core.component import BaseComponent
from queue import Queue
import argparse
from urllib.parse import urlparse
from pipert.core.routine import Routine
from pipert.core.mini_logics import Message2Redis, MessageFromRedis
from pipert.core import QueueHandler
import os


[docs]class InstancesSort(Sort): def __init__(self, max_age: int = 1, min_hits: int = None, window_size: int = None, percent_seen: float = None, verbose: bool = False): super().__init__(max_age, min_hits, window_size, percent_seen, verbose)
[docs] def update_instances(self, instances: Instances): im_size = instances.image_size tracks = None if len(instances): boxes = instances.get("pred_boxes").tensor.cpu().numpy() scores = instances.get("scores").cpu().unsqueeze(1).numpy() pred_classes = instances.get("pred_classes").cpu().unsqueeze(1).numpy() dets = np.concatenate((boxes, scores, pred_classes), axis=1) tracks = self.update(dets) ret_tracks = Instances(im_size) if tracks is not None: tracks = torch.tensor(tracks) ret_tracks.set("pred_boxes", Boxes(tracks[:, :4])) ret_tracks.set("scores", tracks[:, 4]) if tracks.shape[0] != 0: ret_tracks.set("pred_classes", tracks[:, 5].round().int()) ret_tracks.set("track_ids", tracks[:, -1].int()) return ret_tracks
[docs]class SORTLogic(Routine): def __init__(self, in_queue, out_queue, component_name, *args, **kwargs): super().__init__(component_name=component_name) self.in_queue = QueueHandler(in_queue) self.out_queue = QueueHandler(out_queue) self.sort = InstancesSort(*args, **kwargs)
[docs] def main_logic(self, *args, **kwargs): pred_msg = self.in_queue.non_blocking_get() if pred_msg: instances = pred_msg.get_payload() new_instances = self.sort.update_instances(instances) pred_msg.update_payload(new_instances) success = self.out_queue.deque_non_blocking_put(pred_msg) return success else: return False
[docs] def setup(self, *args, **kwargs): self.state.dropped = 0
[docs] def cleanup(self, *args, **kwargs): pass
[docs]class SORTComponent(BaseComponent): def __init__(self, endpoint, in_key, out_key, redis_url, name, maxlen=100, *args, **kwargs): super().__init__(endpoint, name) # TODO: should queue maxsize be configurable? self.in_queue = Queue(maxsize=1) self.out_queue = Queue(maxsize=1) t_get_meta = MessageFromRedis(in_key, redis_url, self.in_queue, name="get_frames", component_name=self.name).as_thread() self.register_routine(t_get_meta) t_sort = SORTLogic(self.in_queue, self.out_queue, self.name, *args, **kwargs).as_thread() self.register_routine(t_sort) t_upload_meta = Message2Redis(out_key, redis_url, self.out_queue, maxlen, name="upload_redis", component_name=self.name).as_thread() self.register_routine(t_upload_meta)
if __name__ == '__main__':
[docs] parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', help='Input stream key name', type=str, default='camera:2') parser.add_argument('-o', '--output', help='Output stream key name', type=str, default='camera:3') parser.add_argument('-u', '--url', help='Redis URL', type=str, default='redis://127.0.0.1:6379') parser.add_argument('-z', '--zpc', help='zpc port', type=str, default='4247') parser.add_argument('--maxlen', help='Maximum length of output stream', type=int, default=100) # max_age: int = 1, min_hits: int = None, window_size: int = None, percent_seen parser.add_argument('--max-age', type=int, default=1, help='object confidence threshold') parser.add_argument('--min-hits', type=int, help='iou threshold for non-maximum suppression') parser.add_argument('--window-size', type=int, default='mp4v', help='output video codec (verify ffmpeg support)') parser.add_argument('--percent-seen', type=float, help='output video codec (verify ffmpeg support)') opt = parser.parse_args() # url = urlparse(opts.url) url = os.environ.get('REDIS_URL') url = urlparse(url) if url is not None else urlparse(opt.url) zpc = SORTComponent(f"tcp://0.0.0.0:{opt.zpc}", opt.input, opt.output, url, "SORTComponent", opt.maxlen, opt.max_age, opt.min_hits, opt.window_size, opt.percent_seen) print(f"run {zpc.name}") zpc.run() print(f"Killed {zpc.name}")