summaryrefslogtreecommitdiff
path: root/subprojects
diff options
context:
space:
mode:
authorOlivier CrĂȘte <olivier.crete@collabora.com>2024-01-25 00:02:30 -0500
committerOlivier CrĂȘte <olivier.crete@collabora.com>2024-02-02 18:47:52 -0500
commit3325a10f573361faf90e02119f92c7a27fe04135 (patch)
tree7f85323b762f0c1993e0224c35b9aeb89b7e7dde /subprojects
parent5e1291fd86b636689ee012eef0cd79505495c6bb (diff)
onnx: Port SSD detector to C
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6001>
Diffstat (limited to 'subprojects')
-rw-r--r--subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp228
-rw-r--r--subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h83
-rw-r--r--subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.c (renamed from subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp)290
-rw-r--r--subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h10
-rw-r--r--subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp4
-rw-r--r--subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h6
-rw-r--r--subprojects/gst-plugins-bad/ext/onnx/meson.build3
-rw-r--r--subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h20
8 files changed, 264 insertions, 380 deletions
diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp
deleted file mode 100644
index dbd4b30843..0000000000
--- a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp
+++ /dev/null
@@ -1,228 +0,0 @@
-/*
- * GStreamer gstreamer-objectdetectorutils
- * Copyright (C) 2023 Collabora Ltd
- *
- * gstobjectdetectorutils.cpp
- *
- * This library is free software; you can redistribute it and/or
- * modify it under the terms of the GNU Library General Public
- * License as published by the Free Software Foundation; either
- * version 2 of the License, or (at your option) any later version.
- *
- * This library is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
- * Library General Public License for more details.
- *
- * You should have received a copy of the GNU Library General Public
- * License along with this library; if not, write to the
- * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
- * Boston, MA 02110-1301, USA.
- */
-
-#include "gstobjectdetectorutils.h"
-
-#include <gio/gio.h>
-
-
-char **
-read_labels (const char * labels_file)
-{
- GPtrArray *array;
- GFile *file = g_file_new_for_path (labels_file);
- GFileInputStream *file_stream;
- GDataInputStream *data_stream;
- GError *error = NULL;
- gchar *line;
-
- file_stream = g_file_read (file, NULL, &error);
- g_object_unref (file);
- if (!file_stream) {
- GST_WARNING ("Could not open file %s: %s\n", labels_file,
- error->message);
- g_clear_error (&error);
- return NULL;
- }
-
- data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
- g_object_unref (file_stream);
-
- array = g_ptr_array_new();
-
- while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
- &error)))
- g_ptr_array_add (array, line);
-
- g_object_unref (data_stream);
-
- if (error) {
- GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
- g_ptr_array_free (array, TRUE);
- g_clear_error (&error);
- return NULL;
- }
-
- if (array->len == 0) {
- g_ptr_array_free (array, TRUE);
- return NULL;
- }
-
- g_ptr_array_add (array, NULL);
-
- return (char **) g_ptr_array_free (array, FALSE);
-}
-
-
-GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
- float _y0, float _width, float _height):
-label (lbl),
-score (score),
-x0 (_x0),
-y0 (_y0),
-width (_width),
-height (_height)
-{
-}
-
-GstMlBoundingBox::GstMlBoundingBox ():
-GstMlBoundingBox ("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)
-{
-}
-
-namespace GstObjectDetectorUtils
-{
-
- GstObjectDetectorUtils::GstObjectDetectorUtils ()
- {
- }
-
-
- std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
- int32_t h, GstTensorMeta * tmeta, gchar **labels,
- float scoreThreshold)
- {
-
- auto classIndex = gst_tensor_meta_get_index_from_id (tmeta,
- g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
- if (classIndex == GST_TENSOR_MISSING_ID) {
- GST_ERROR ("Missing class tensor id");
- return std::vector < GstMlBoundingBox > ();
- }
- auto type = tmeta->tensor[classIndex].type;
- return (type == GST_TENSOR_TYPE_FLOAT32) ?
- doRun < float >(w, h, tmeta, labels, scoreThreshold)
- : doRun < int >(w, h, tmeta, labels, scoreThreshold);
- }
-
- template < typename T > std::vector < GstMlBoundingBox >
- GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
- GstTensorMeta * tmeta, char **labels, float scoreThreshold)
- {
- std::vector < GstMlBoundingBox > boundingBoxes;
- GstMapInfo map_info[GstObjectDetectorMaxNodes];
- GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
- gint index;
- T *numDetections = nullptr, *bboxes = nullptr, *scores =
- nullptr, *labelIndex = nullptr;
-
- // number of detections
- index = gst_tensor_meta_get_index_from_id (tmeta,
- g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
- if (index == GST_TENSOR_MISSING_ID) {
- GST_WARNING ("Missing tensor data for tensor index %d", index);
- goto cleanup;
- }
- memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
- if (!memory[index]) {
- GST_WARNING ("Missing tensor data for tensor index %d", index);
- goto cleanup;
- }
- if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
- GST_WARNING ("Failed to map tensor memory for index %d", index);
- goto cleanup;
- }
- numDetections = (T *) map_info[index].data;
-
- // bounding boxes
- index =
- gst_tensor_meta_get_index_from_id (tmeta,
- g_quark_from_static_string(GST_MODEL_OBJECT_DETECTOR_BOXES));
- if (index == GST_TENSOR_MISSING_ID) {
- GST_WARNING ("Missing tensor data for tensor index %d", index);
- goto cleanup;
- }
- memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
- if (!memory[index]) {
- GST_WARNING ("Failed to map tensor memory for index %d", index);
- goto cleanup;
- }
- if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
- GST_ERROR ("Failed to map GstMemory");
- goto cleanup;
- }
- bboxes = (T *) map_info[index].data;
-
- // scores
- index =
- gst_tensor_meta_get_index_from_id (tmeta,
- g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
- if (index == GST_TENSOR_MISSING_ID) {
- GST_ERROR ("Missing scores tensor id");
- goto cleanup;
- }
- memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
- if (!memory[index]) {
- GST_WARNING ("Missing tensor data for tensor index %d", index);
- goto cleanup;
- }
- if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
- GST_ERROR ("Failed to map GstMemory");
- goto cleanup;
- }
- scores = (T *) map_info[index].data;
-
- // optional label
- labelIndex = nullptr;
- index =
- gst_tensor_meta_get_index_from_id (tmeta,
- g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
- if (index != GST_TENSOR_MISSING_ID) {
- memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
- if (!memory[index]) {
- GST_WARNING ("Missing tensor data for tensor index %d", index);
- goto cleanup;
- }
- if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
- GST_ERROR ("Failed to map GstMemory");
- goto cleanup;
- }
- labelIndex = (T *) map_info[index].data;
- }
-
- for (int i = 0; i < numDetections[0]; ++i) {
- if (scores[i] > scoreThreshold) {
- std::string label = "";
-
- if (labels && labelIndex)
- label = labels[(int)labelIndex[i] - 1];
- auto score = scores[i];
- auto y0 = bboxes[i * 4] * h;
- auto x0 = bboxes[i * 4 + 1] * w;
- auto bheight = bboxes[i * 4 + 2] * h - y0;
- auto bwidth = bboxes[i * 4 + 3] * w - x0;
- boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
- bheight));
- }
- }
-
- cleanup:
- for (int i = 0; i < GstObjectDetectorMaxNodes; ++i) {
- if (memory[i])
- gst_memory_unmap (memory[i], map_info + i);
-
- }
-
- return boundingBoxes;
- }
-
-}
diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h
deleted file mode 100644
index 2e7a83995a..0000000000
--- a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * GStreamer gstreamer-objectdetectorutils
- * Copyright (C) 2023 Collabora Ltd
- *
- * gstobjectdetectorutils.h
- *
- * This library is free software; you can redistribute it and/or
- * modify it under the terms of the GNU Library General Public
- * License as published by the Free Software Foundation; either
- * version 2 of the License, or (at your option) any later version.
- *
- * This library is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
- * Library General Public License for more details.
- *
- * You should have received a copy of the GNU Library General Public
- * License along with this library; if not, write to the
- * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
- * Boston, MA 02110-1301, USA.
- */
-#ifndef __GST_OBJECT_DETECTOR_UTILS_H__
-#define __GST_OBJECT_DETECTOR_UTILS_H__
-
-#include <gst/gst.h>
-#include <string>
-#include <vector>
-
-#include "gstml.h"
-#include "tensor/gsttensormeta.h"
-
-char ** read_labels (const char * labels_file);
-
-/* Object detection tensor id strings */
-#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
-#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
-#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
-#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
-
-
-/**
- * GstMlBoundingBox:
- *
- * @label label
- * @score detection confidence
- * @x0 top left hand x coordinate
- * @y0 top left hand y coordinate
- * @width width
- * @height height
- *
- * Since: 1.20
- */
-struct GstMlBoundingBox {
- GstMlBoundingBox(std::string lbl, float score, float _x0, float _y0,
- float _width, float _height);
- GstMlBoundingBox();
- std::string label;
- float score;
- float x0;
- float y0;
- float width;
- float height;
-};
-
-namespace GstObjectDetectorUtils {
- const int GstObjectDetectorMaxNodes = 4;
- class GstObjectDetectorUtils {
- public:
- GstObjectDetectorUtils(void);
- ~GstObjectDetectorUtils(void) = default;
- std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
- GstTensorMeta *tmeta,
- char **labels,
- float scoreThreshold);
- private:
- template < typename T > std::vector < GstMlBoundingBox >
- doRun(int32_t w, int32_t h,
- GstTensorMeta *tmeta, char **labels,
- float scoreThreshold);
- };
-}
-
-#endif /* __GST_OBJECT_DETECTOR_UTILS_H__ */
diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.c
index 5c69579323..f89ab28f4f 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp
+++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.c
@@ -44,19 +44,23 @@
#include "config.h"
#endif
-
#include "gstssdobjectdetector.h"
-#include "gstobjectdetectorutils.h"
+
+#include <gio/gio.h>
#include <gst/gst.h>
#include <gst/video/video.h>
-
#include <gst/analytics/analytics.h>
#include "tensor/gsttensormeta.h"
+/* Object detection tensor id strings */
+#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
+#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
+#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
+#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
+
GST_DEBUG_CATEGORY_STATIC (ssd_object_detector_debug);
#define GST_CAT_DEFAULT ssd_object_detector_debug
-#define GST_ODUTILS_MEMBER( self ) ((GstObjectDetectorUtils::GstObjectDetectorUtils *) (self->odutils))
GST_ELEMENT_REGISTER_DEFINE (ssd_object_detector, "ssdobjectdetector",
GST_RANK_PRIMARY, GST_TYPE_SSD_OBJECT_DETECTOR);
@@ -68,7 +72,7 @@ enum
PROP_SCORE_THRESHOLD,
};
-#define GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
+#define GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
static GstStaticPadTemplate gst_ssd_object_detector_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
@@ -97,7 +101,8 @@ static gboolean
gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps);
-G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST_TYPE_BASE_TRANSFORM);
+G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector,
+ GST_TYPE_BASE_TRANSFORM);
static void
gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
@@ -135,7 +140,8 @@ gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
g_param_spec_float ("score-threshold",
"Score threshold",
"Threshold for deciding when to remove boxes based on score",
- 0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
+ 0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD,
+ (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "objectdetector",
@@ -155,7 +161,6 @@ gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
static void
gst_ssd_object_detector_init (GstSsdObjectDetector * self)
{
- self->odutils = new GstObjectDetectorUtils::GstObjectDetectorUtils ();
}
static void
@@ -164,12 +169,58 @@ gst_ssd_object_detector_finalize (GObject * object)
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
g_free (self->label_file);
- g_strfreev (self->labels);
- delete GST_ODUTILS_MEMBER (self);
+ g_clear_pointer (&self->labels, g_array_unref);
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
}
+static GArray *
+read_labels (const char *labels_file)
+{
+ GArray *array;
+ GFile *file = g_file_new_for_path (labels_file);
+ GFileInputStream *file_stream;
+ GDataInputStream *data_stream;
+ GError *error = NULL;
+ gchar *line;
+
+ file_stream = g_file_read (file, NULL, &error);
+ g_object_unref (file);
+ if (!file_stream) {
+ GST_WARNING ("Could not open file %s: %s\n", labels_file, error->message);
+ g_clear_error (&error);
+ return NULL;
+ }
+
+ data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
+ g_object_unref (file_stream);
+
+ array = g_array_new (FALSE, FALSE, sizeof (GQuark));
+
+ while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
+ &error))) {
+ GQuark label = g_quark_from_string (line);
+ g_array_append_val (array, label);
+ g_free (line);
+ }
+
+ g_object_unref (data_stream);
+
+ if (error) {
+ GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
+ g_array_free (array, TRUE);
+ g_clear_error (&error);
+ return NULL;
+ }
+
+ if (array->len == 0) {
+ g_array_free (array, TRUE);
+ return NULL;
+ }
+
+ return array;
+}
+
static void
gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
const GValue * value, GParamSpec * pspec)
@@ -179,21 +230,21 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
switch (prop_id) {
case PROP_LABEL_FILE:
- {
- gchar **labels;
-
- filename = g_value_get_string (value);
- labels = read_labels (filename);
-
- if (labels) {
- g_free (self->label_file);
- self->label_file = g_strdup (filename);
- g_strfreev (self->labels);
- self->labels = labels;
- } else {
- GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
- }
+ {
+ GArray *labels;
+
+ filename = g_value_get_string (value);
+ labels = read_labels (filename);
+
+ if (labels) {
+ g_free (self->label_file);
+ self->label_file = g_strdup (filename);
+ g_clear_pointer (&self->labels, g_array_unref);
+ self->labels = labels;
+ } else {
+ GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
}
+ }
break;
case PROP_SCORE_THRESHOLD:
GST_OBJECT_LOCK (self);
@@ -259,7 +310,8 @@ gst_ssd_object_detector_get_tensor_meta (GstSsdObjectDetector * object_detector,
gint clasesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
- if (boxesIndex == GST_TENSOR_MISSING_ID || scoresIndex == GST_TENSOR_MISSING_ID
+ if (boxesIndex == GST_TENSOR_MISSING_ID
+ || scoresIndex == GST_TENSOR_MISSING_ID
|| numDetectionsIndex == GST_TENSOR_MISSING_ID)
continue;
@@ -300,13 +352,175 @@ gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
return GST_FLOW_OK;
}
+#define DEFINE_GET_FUNC(TYPE, MAX) \
+ static gboolean \
+ get_ ## TYPE ## _at_index (GstTensor *tensor, GstMapInfo *map, \
+ guint index, TYPE * out) \
+ { \
+ switch (tensor->type) { \
+ case GST_TENSOR_TYPE_FLOAT32: { \
+ float *f = (float *) map->data; \
+ if (sizeof(*f) * (index + 1) > map->size) \
+ return FALSE; \
+ *out = f[index]; \
+ break; \
+ } \
+ case GST_TENSOR_TYPE_UINT32: { \
+ guint32 *u = (guint32 *) map->data; \
+ if (sizeof(*u) * (index + 1) > map->size) \
+ return FALSE; \
+ *out = u[index]; \
+ break; \
+ } \
+ default: \
+ GST_ERROR ("Only float32 and int32 tensors are understood"); \
+ return FALSE; \
+ } \
+ return TRUE; \
+ }
+
+DEFINE_GET_FUNC (guint32, UINT32_MAX)
+ DEFINE_GET_FUNC (float, FLOAT_MAX)
+#undef DEFINE_GET_FUNC
+ static void
+ extract_bounding_boxes (GstSsdObjectDetector * self, gsize w, gsize h,
+ GstAnalyticsRelationMeta * rmeta, GstTensorMeta * tmeta)
+{
+ gint classes_index;
+ gint boxes_index;
+ gint scores_index;
+ gint numdetect_index;
+
+ GstMapInfo boxes_map = GST_MAP_INFO_INIT;
+ GstMapInfo numdetect_map = GST_MAP_INFO_INIT;
+ GstMapInfo scores_map = GST_MAP_INFO_INIT;
+ GstMapInfo classes_map = GST_MAP_INFO_INIT;
+
+ guint num_detections = 0;
+
+ classes_index = gst_tensor_meta_get_index_from_id (tmeta,
+ g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
+ numdetect_index = gst_tensor_meta_get_index_from_id (tmeta,
+ g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
+ scores_index = gst_tensor_meta_get_index_from_id (tmeta,
+ g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
+ boxes_index = gst_tensor_meta_get_index_from_id (tmeta,
+ g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES));
+
+ if (numdetect_index == GST_TENSOR_MISSING_ID
+ || scores_index == GST_TENSOR_MISSING_ID
+ || numdetect_index == GST_TENSOR_MISSING_ID) {
+ GST_WARNING ("Missing tensor data expected for SSD model");
+ return;
+ }
+
+ if (!gst_buffer_map (tmeta->tensor[numdetect_index].data, &numdetect_map,
+ GST_MAP_READ)) {
+ GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
+ numdetect_index);
+ goto cleanup;
+ }
+
+ if (!gst_buffer_map (tmeta->tensor[boxes_index].data, &boxes_map,
+ GST_MAP_READ)) {
+ GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
+ boxes_index);
+ goto cleanup;
+ }
+
+ if (!gst_buffer_map (tmeta->tensor[scores_index].data, &scores_map,
+ GST_MAP_READ)) {
+ GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
+ scores_index);
+ goto cleanup;
+ }
+
+ if (classes_index != GST_TENSOR_MISSING_ID &&
+ !gst_buffer_map (tmeta->tensor[classes_index].data, &classes_map,
+ GST_MAP_READ)) {
+ GST_DEBUG_OBJECT (self, "Failed to map tensor memory for index %d",
+ classes_index);
+ }
+
+
+ if (!get_guint32_at_index (&tmeta->tensor[numdetect_index], &numdetect_map,
+ 0, &num_detections)) {
+ GST_ERROR_OBJECT (self, "Failed to get the number of detections");
+ goto cleanup;
+ }
+
+
+ GST_LOG_OBJECT (self, "Model claims %d detections", num_detections);
+
+ for (int i = 0; i < num_detections; i++) {
+ float score;
+ float x, y, bwidth, bheight;
+ gint x_i, y_i, bwidth_i, bheight_i;
+ guint32 bclass;
+ GQuark label = 0;
+ GstAnalyticsODMtd odmtd;
+
+ if (!get_float_at_index (&tmeta->tensor[numdetect_index], &scores_map,
+ i, &score))
+ continue;
+
+ GST_LOG_OBJECT (self, "Detection %u score is %f", i, score);
+ if (score < self->score_threshold)
+ continue;
+
+ if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
+ i * 4, &y))
+ continue;
+ if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
+ i * 4 + 1, &x))
+ continue;
+ if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
+ i * 4 + 2, &bheight))
+ continue;
+ if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
+ i * 4 + 3, &bwidth))
+ continue;
+
+ if (self->labels && classes_map.memory &&
+ get_guint32_at_index (&tmeta->tensor[classes_index], &classes_map,
+ i, &bclass)) {
+ if (bclass < self->labels->len)
+ label = g_array_index (self->labels, GQuark, bclass);
+ }
+
+ x_i = x * w;
+ y_i = y * h;
+ bheight_i = (bheight * h) - y_i;
+ bwidth_i = (bwidth * w) - x_i;
+
+ if (gst_analytics_relation_meta_add_od_mtd (rmeta, label,
+ x_i, y_i, bwidth_i, bheight_i, score, &odmtd))
+ GST_DEBUG_OBJECT (self,
+ "Object detected with label : %s, score: %f, bound box: %dx%d at (%d,%d)",
+ g_quark_to_string (label), score, bwidth_i, bheight_i, x_i, y_i);
+ else
+ GST_WARNING_OBJECT (self, "Could not add detection to meta");
+ }
+
+cleanup:
+
+ if (numdetect_map.memory)
+ gst_buffer_unmap (tmeta->tensor[numdetect_index].data, &numdetect_map);
+ if (classes_map.memory)
+ gst_buffer_unmap (tmeta->tensor[classes_index].data, &classes_map);
+ if (scores_map.memory)
+ gst_buffer_unmap (tmeta->tensor[scores_index].data, &scores_map);
+ if (boxes_map.memory)
+ gst_buffer_unmap (tmeta->tensor[boxes_index].data, &boxes_map);
+}
+
+
static gboolean
gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
{
- GstTensorMeta *tmeta = NULL;
- GstAnalyticsODMtd odmtd;
- GstAnalyticsRelationMeta *rmeta;
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);
+ GstTensorMeta *tmeta;
+ GstAnalyticsRelationMeta *rmeta;
// get all tensor metas
tmeta = gst_ssd_object_detector_get_tensor_meta (self, buf);
@@ -315,25 +529,11 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
return TRUE;
} else {
rmeta = gst_buffer_add_analytics_relation_meta (buf);
- g_return_val_if_fail (rmeta != NULL, FALSE);
+ g_assert (rmeta);
}
- std::vector < GstMlBoundingBox > boxes =
- GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
- self->video_info.height, tmeta, self->labels,
- self->score_threshold);
-
- for (auto & b:boxes) {
- if (gst_analytics_relation_meta_add_od_mtd (rmeta,
- g_quark_from_string(b.label.c_str ()), b.x0, b.y0, b.width, b.height,
- b.score, &odmtd)) {
- GST_DEBUG_OBJECT (self,
- "Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f)",
- b.label.c_str (), b.score, b.x0, b.y0, b.x0 + b.width, b.y0 + b.height);
- } else {
- GST_ERROR_OBJECT (self, "Failed to add object detection analytics-meta");
- }
- }
+ extract_bounding_boxes (self, self->video_info.width,
+ self->video_info.height, rmeta, tmeta);
return TRUE;
}
diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h
index b101650ba4..9bc89aa85a 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h
+++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h
@@ -37,14 +37,11 @@ G_DECLARE_FINAL_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST, SSD_OB
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL "label"
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE "score"
-/**
+/*
* GstSsdObjectDetector:
*
* @label_file label file
* @score_threshold score threshold
- * @confidence_threshold confidence threshold
- * @iou_threhsold iou threshold
- * @od_ptr opaque pointer to GstOd object detection implementation
*
* Since: 1.20
*/
@@ -52,11 +49,8 @@ struct _GstSsdObjectDetector
{
GstBaseTransform basetransform;
gchar *label_file;
- gchar **labels;
+ GArray *labels;
gfloat score_threshold;
- gfloat confidence_threshold;
- gfloat iou_threshold;
- gpointer odutils;
GstVideoInfo video_info;
};
diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
index 37060dd543..ba5ae83ca5 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
+++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
@@ -96,7 +96,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return inputImageFormat;
}
- void GstOnnxClient::setInputImageDatatype(GstTensorType datatype)
+ void GstOnnxClient::setInputImageDatatype(GstTensorDataType datatype)
{
inputDatatype = datatype;
switch (inputDatatype) {
@@ -144,7 +144,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return inputTensorScale;
}
- GstTensorType GstOnnxClient::getInputImageDatatype(void)
+ GstTensorDataType GstOnnxClient::getInputImageDatatype(void)
{
return inputDatatype;
}
diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h
index 0e4f50d68e..bdec9f1a3b 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h
+++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h
@@ -56,7 +56,7 @@ namespace GstOnnxNamespace {
bool hasSession(void);
void setInputImageFormat(GstMlInputImageFormat format);
GstMlInputImageFormat getInputImageFormat(void);
- GstTensorType getInputImageDatatype(void);
+ GstTensorDataType getInputImageDatatype(void);
void setInputImageOffset (float offset);
float getInputImageOffset ();
void setInputImageScale (float offset);
@@ -73,7 +73,7 @@ namespace GstOnnxNamespace {
private:
GstElement *debug_parent;
- void setInputImageDatatype (GstTensorType datatype);
+ void setInputImageDatatype (GstTensorDataType datatype);
template < typename T>
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
@@ -91,7 +91,7 @@ namespace GstOnnxNamespace {
std::vector < Ort::AllocatedStringPtr > outputNames;
std::vector < GQuark > outputIds;
GstMlInputImageFormat inputImageFormat;
- GstTensorType inputDatatype;
+ GstTensorDataType inputDatatype;
size_t inputDatatypeSize;
bool fixedInputImageSize;
float inputTensorOffset;
diff --git a/subprojects/gst-plugins-bad/ext/onnx/meson.build b/subprojects/gst-plugins-bad/ext/onnx/meson.build
index 11f93eaa3f..ceb0b9a85c 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/meson.build
+++ b/subprojects/gst-plugins-bad/ext/onnx/meson.build
@@ -15,8 +15,7 @@ endif
if onnxrt_dep.found()
gstonnx = library('gstonnx',
'gstonnx.c',
- 'decoders/gstobjectdetectorutils.cpp',
- 'decoders/gstssdobjectdetector.cpp',
+ 'decoders/gstssdobjectdetector.c',
'gstonnxinference.cpp',
'gstonnxclient.cpp',
'tensor/gsttensormeta.c',
diff --git a/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h
index 49983308f4..d7911312d2 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h
+++ b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h
@@ -25,7 +25,7 @@
#include <gst/gst.h>
/**
- * GstTensorType:
+ * GstTensorDataType:
*
* @GST_TENSOR_TYPE_INT4 signed 4 bit integer tensor data
* @GST_TENSOR_TYPE_INT8 signed 8 bit integer tensor data
@@ -42,9 +42,11 @@
* @GST_TENSOR_TYPE_FLOAT64 64 bit floating point tensor data
* @GST_TENSOR_TYPE_BFLOAT16 "brain" 16 bit floating point tensor data
*
+ * Describe the type of data contain in the tensor.
+ *
* Since: 1.24
*/
-typedef enum _GstTensorType
+typedef enum _GstTensorDataType
{
GST_TENSOR_TYPE_INT4,
GST_TENSOR_TYPE_INT8,
@@ -60,17 +62,17 @@ typedef enum _GstTensorType
GST_TENSOR_TYPE_FLOAT32,
GST_TENSOR_TYPE_FLOAT64,
GST_TENSOR_TYPE_BFLOAT16,
-} GstTensorType;
+} GstTensorDataType;
/**
* GstTensor:
*
- * @id unique tensor identifier
- * @num_dims number of tensor dimensions
- * @dims tensor dimensions
- * @type @ref GstTensorType of tensor data
- * @data @ref GstBuffer holding tensor data
+ * @id: semantically identify the contents of the tensor
+ * @num_dims: number of tensor dimensions
+ * @dims: tensor dimensions
+ * @type: #GstTensorDataType of tensor data
+ * @data: #GstBuffer holding tensor data
*
* Since: 1.24
*/
@@ -79,7 +81,7 @@ typedef struct _GstTensor
GQuark id;
gint num_dims;
int64_t *dims;
- GstTensorType type;
+ GstTensorDataType data_type;
GstBuffer *data;
} GstTensor;