Source code for classification_logic

import time
from pipert import Routine
from pipert.core import Message
from pipert.core.routine import RoutineTypes
from queue import Empty
import torch
import torchvision


[docs]class ClassificationLogic(Routine):
[docs] routine_type = RoutineTypes.PROCESSING
def __init__(self, in_queue, out_queue, weights, *args, **kwargs): super().__init__(*args, **kwargs) self.in_queue = in_queue self.out_queue = out_queue self.weights = weights self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.net = torchvision.models.resnet50(pretrained=False) self.transform = torchvision.transforms.Compose( [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) self.net = self.net.to(self.device) self.net.fc = torch.nn.Linear(self.net.fc.in_features, 2) chkpt = torch.load(self.weights, map_location=self.device) chkpt['state_dict'] = \ {k[4:]: v for k, v in chkpt['state_dict'].items() if self.net.state_dict()[k[4:]].numel() == v.numel()} self.net.load_state_dict(chkpt['state_dict'], strict=False)
[docs] def main_logic(self, *args, **kwargs): try: frame_msg = self.in_queue.get(block=False) frame = frame_msg.get_payload() frame = self.transform(frame) frame = frame.to(self.device) frame = frame.unsqueeze(0) pred = self.net(frame) pred = torch.nn.functional.softmax(pred, dim=1).detach()[0, 1].item() pred = str(round(pred, 2)) try: self.out_queue.get(block=False) self.state.dropped += 1 except Empty: pass pred_msg = Message(pred, frame_msg.source_address) self.out_queue.put(pred_msg, block=False) return True except Empty: time.sleep(0) return False
[docs] def setup(self, *args, **kwargs): self.state.dropped = 0
[docs] def cleanup(self, *args, **kwargs): # del self.model, self.device, self.classes, self.colors pass
@staticmethod
[docs] def get_constructor_parameters(): dicts = Routine.get_constructor_parameters() dicts.update({ "in_queue": "QueueIn", "out_queue": "QueueOut", }) return dicts
[docs] def does_routine_use_queue(self, queue): return (self.in_queue == queue) or (self.out_queue == queue)