Skip to content
Snippets Groups Projects
Commit b08bfe87 authored by EtienneAr's avatar EtienneAr
Browse files

Add paramter for detection threshold

parent 41ea1f41
No related branches found
No related tags found
No related merge requests found
<launch>
<arg name="dataset" default="ycbv"/>
<arg name="detection_threshold" default="0.7"/>
<arg name="debug" default="false"/>
<arg name="debug_models" default="bop_datasets/$(arg dataset)/models/"/> <!-- Relative path from cosypose local_data directory -->
......@@ -12,6 +13,7 @@
<!-- Start service and camera -->
<include file="$(find ros_cosypose)/launch/singleview_service.launch">
<arg name="dataset" value="$(arg dataset)"/>
<arg name="detection_threshold" value="$(arg detection_threshold)"/>
<arg name="debug" value="$(arg debug)"/>
<arg name="debug_models" value="$(arg debug_models)"/>
......@@ -21,7 +23,7 @@
</include>
<!-- call ros_cosypose in loop and publish on topics -->
<node name="$(anon pose_estimation_loop.py)" pkg="ros_cosypose" type="pose_estimation_loop.py">
<node name="pose_estimation_loop" pkg="ros_cosypose" type="pose_estimation_loop.py">
<param name="camera_name" value="$(arg camera_name)"/>
<param name="allow_tracking" value="$(arg allow_tracking)"/>
</node>
......
<launch>
<arg name="dataset" default="ycbv"/>
<arg name="detection_threshold" default="0.7"/>
<arg name="debug" default="false"/>
<arg name="debug_models" default="bop_datasets/$(arg dataset)/models/"/> <!-- Relative path from cosypose local_data directory -->
......@@ -18,8 +19,9 @@
</group>
<!-- ros_cosypose service -->
<node name="$(anon pose_estimation.py)" pkg="ros_cosypose" type="pose_estimation.py">
<node name="pose_estimation" pkg="ros_cosypose" type="pose_estimation.py">
<param name="dataset" value="$(arg dataset)"/>
<param name="detection_threshold" value="$(arg detection_threshold)"/>
<param name="debug" value="$(arg debug)"/>
<param name="debug_models" value="$(arg debug_models)"/>
</node>
......
......@@ -39,11 +39,6 @@ class PredictionService:
object_detector_run_id=object_detector_run_id,
)
self.detector_kwargs = dict(
one_instance_per_class=False,
detection_th=0.9
)
self.debug_converter = None
if(debug):
self.debug_converter = RVizTranslator(debug_models)
......@@ -79,8 +74,10 @@ class PredictionService:
pose_estimation_prior = PandasTensorCollection(infos=pd.DataFrame(dict(label=labels,batch_im_id=batch_im_ids)),
poses=torch.tensor(poses).float().cuda())
detector_kwargs = dict(one_instance_per_class=False, detection_th=rospy.get_param('~detection_threshold'))
# Predict poses using cosypose
self.predictor([image, ], cameras, pose_estimation_prior=pose_estimation_prior, detector_kwargs=self.detector_kwargs)
self.predictor([image, ], cameras, pose_estimation_prior=pose_estimation_prior, detector_kwargs=detector_kwargs)
pose_predictions = self.predictor.pose_predictions
# transform cosypose objects to ros msg
......@@ -107,6 +104,9 @@ if __name__ == '__main__':
debug = rospy.get_param('~debug', "false")
debug_models = rospy.get_param('~debug_models', "")
if not rospy.has_param('~detection_threshold'):
rospy.set_param('~detection_threshold', 0.7)
prediction_srv = PredictionService(dataset, debug, debug_models)
rospy.Service('pose_estimation', GetPrediction, prediction_srv.pose_estimation)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment