94 lines
3.5 KiB
Python
94 lines
3.5 KiB
Python
import json
|
|
import os.path
|
|
import re
|
|
import zipfile
|
|
from typing import NamedTuple, Optional, Iterable, List
|
|
|
|
from numpy import ndarray
|
|
|
|
from arkite_booster.cv2util import box_mask
|
|
from arkite_booster.util import FloatPoint
|
|
from arkite_booster.util import bounding_box_centre
|
|
|
|
# One box as produced by load_boxes().
# typename and fields are passed positionally: typing.NamedTuple made them
# positional-only in Python 3.9, so passing them as keywords raises TypeError
# there. The positional form works on every version that has typing.NamedTuple.
# Field order is part of the tuple interface (positional construction and
# unpacking), so it must not change.
ArkiteBox = NamedTuple('ArkiteBox', [
    ('name', str),                 # value of the box's 'Name' entry
    ('points', List[FloatPoint]),  # outline vertices, (X, Y) of each 'DataPoints' entry
    ('depth', int),                # value of the box's 'Depth' entry
    ('distance', int),             # value of the box's 'Distance' entry
    ('centre', FloatPoint),        # bounding_box_centre(points)
    ('mask', ndarray),             # box_mask(points)
])
|
|
|
|
# Member name, inside a .hef archive (read with zipfile), of the JSON box list.
HEF_BOX_LIST_FILE = 'data/Model.Entities.Box.json'

# Names of calibration boxes: "Depth Cal <n>", "IR Cal <n>", "Depth Cal - Center"
# or "IR Cal - Center", followed by anything. load_boxes() applies it with
# .match(), i.e. anchored at the start of the name only.
CALIBRATION_BOX_REGEX = re.compile(r'(Depth|IR) Cal ([0-9]+|- Center).*')

# Names starting with this prefix are grab boxes; all other names are treated
# as container boxes by load_boxes().
GRAB_BOX_PREFIX = 'grab_'
|
|
|
|
|
|
def load_boxes(model_entities_box_path: str,
               project_ids: Optional[Iterable[int]] = None,
               box_ids: Optional[Iterable[int]] = None,
               keep_grab_boxes: bool = True,
               keep_container_boxes: bool = True,
               skip_calibration_boxes: bool = True) -> Iterable[ArkiteBox]:
    """Load and filter the boxes of an Arkite box list.

    Args:
        model_entities_box_path: either a ``.json`` box list or a ``.hef``
            archive holding the list at ``HEF_BOX_LIST_FILE``. Relative paths
            are resolved against the current working directory.
        project_ids: if given, keep only boxes whose ``ProjectId`` is in it.
        box_ids: if given, keep only boxes whose ``Id`` is in it.
        keep_grab_boxes: keep boxes whose name starts with ``GRAB_BOX_PREFIX``.
        keep_container_boxes: keep boxes whose name does NOT start with
            ``GRAB_BOX_PREFIX``.
        skip_calibration_boxes: drop boxes whose name matches
            ``CALIBRATION_BOX_REGEX``.

    Returns:
        A list of ``ArkiteBox``, in the order they appear in the box list.

    Raises:
        ValueError: if the path ends neither in ``.json`` nor in ``.hef``.
    """
    model_entities_box_path = os.path.abspath(model_entities_box_path)
    all_boxes = _read_box_list(model_entities_box_path)

    # Materialise the id filters once. Membership is tested for every box; a
    # one-shot iterable such as a generator would otherwise be partially
    # consumed by the first test and silently drop boxes, and a list would
    # cost a linear scan per box.
    wanted_projects = None if project_ids is None else frozenset(project_ids)
    wanted_box_ids = None if box_ids is None else frozenset(box_ids)

    boxes = []
    for box_json in all_boxes:
        if wanted_projects is not None and box_json['ProjectId'] not in wanted_projects:
            continue
        if wanted_box_ids is not None and box_json['Id'] not in wanted_box_ids:
            continue

        box_name = box_json['Name']
        if skip_calibration_boxes and CALIBRATION_BOX_REGEX.match(box_name) is not None:
            continue

        # Every box is either a grab box (prefixed name) or a container box
        # (any other name); each kind can be dropped independently.
        is_grab_box = box_name.startswith(GRAB_BOX_PREFIX)
        if is_grab_box and not keep_grab_boxes:
            continue
        if not is_grab_box and not keep_container_boxes:
            continue

        points = [(coordinates["X"], coordinates["Y"]) for coordinates in box_json["DataPoints"]]
        mask = box_mask(points)
        centre = bounding_box_centre(points)

        boxes.append(ArkiteBox(name=box_name,
                               depth=box_json['Depth'],
                               distance=box_json['Distance'],
                               points=points,
                               mask=mask,
                               centre=centre))

    return boxes


def _read_box_list(path: str) -> list:
    """Return the decoded JSON box list stored at *path*.

    *path* is either a plain ``.json`` file or a ``.hef`` zip archive that
    contains the list as the member ``HEF_BOX_LIST_FILE``.

    Raises:
        ValueError: if *path* ends neither in ``.json`` nor in ``.hef``.
    """
    if path.endswith('.json'):
        with open(path, 'rb') as mebf:
            raw = mebf.read()
    elif path.endswith('.hef'):
        with zipfile.ZipFile(path, mode='r') as hef:
            raw = hef.read(HEF_BOX_LIST_FILE)
    else:
        raise ValueError('Unknown file format: ' + path)

    # Decode both sources identically, independent of the platform's locale
    # encoding. 'utf-8-sig' reads plain UTF-8 unchanged and also drops a
    # leading byte-order mark, which json.loads() rejects in a str.
    return json.loads(raw.decode('utf-8-sig'))
|
|
|
|
|
|
def boxes_for_use_case(box_metadata_root_dir: str,
                       use_case: str,
                       keep_grab_boxes: bool = True,
                       keep_container_boxes: bool = True,
                       skip_calibration_boxes: bool = True) -> Iterable[ArkiteBox]:
    """Ad-hoc method to match the detection library with additional `.hef` files provided by Stijn.

    The last path component of *use_case* names the archive: boxes are loaded
    from ``<box_metadata_root_dir>/<name>.hef`` via ``load_boxes`` with no
    project or box id filter.

    Args:
        box_metadata_root_dir: directory holding the ``.hef`` files; relative
            paths are resolved against the current working directory.
        use_case: use-case identifier or path; both ``/`` and ``\\`` are
            accepted as separators and trailing separators are ignored.
        keep_grab_boxes, keep_container_boxes, skip_calibration_boxes:
            forwarded unchanged to ``load_boxes``.

    Returns:
        The boxes returned by ``load_boxes`` for that archive.
    """
    box_metadata_root_dir = os.path.abspath(box_metadata_root_dir)

    # Normalise separators and drop trailing ones, so 'scenes/assembly/' and
    # 'scenes\\assembly' both yield 'assembly' instead of an empty name
    # (which would point at '<root>/.hef').
    use_case_name = use_case.replace('\\', '/').rstrip('/').rsplit('/', 1)[-1]
    use_case_hef = os.path.join(box_metadata_root_dir, use_case_name + '.hef')

    return load_boxes(use_case_hef,
                      keep_grab_boxes=keep_grab_boxes,
                      keep_container_boxes=keep_container_boxes,
                      skip_calibration_boxes=skip_calibration_boxes)
|