Cam shift tracker.cpp v1

From wikidb
Jump to: navigation, search

//*************************************************************************
//  CamShift Tracker for ROS
// 
//  Copyright (C) 2012-15, Ed C. Epp
//  http://zdome.net/wiki/index.php/ROS
//
//  Based on camshiftdemo.cpp found in:
//      github.com/Itseez/opencv/tree/2.4/samples/cpp
//
//  This program is free software: you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation, either version 3 of the License, or
//  (at your option) any later version.
//
//  This program is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with this program.  If not, see <http://www.gnu.org/licenses/>.
//*************************************************************************

#include <algorithm>
#include <cstdlib>
#include <iostream>

#include <ros/ros.h>
#include <cv_bridge/cv_bridge.h>
#include <geometry_msgs/Point.h>
#include <image_transport/image_transport.h>
#include <sensor_msgs/image_encodings.h>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/video/tracking.hpp>

//****************************************************************************
//          on_mouse function
//****************************************************************************

// Mouse-selection state: written by on_mouse() below and read by the
// tracker's image callback.
CvRect  our_selection;               // drag start on button-down; normalized box on button-up
int     our_object_selected = 0;     // 1 once a non-empty region has been selected
int     our_new_object_selected = 0; // 1 until the tracker has consumed the latest selection
int     our_dragging = 0;            // 1 if the latest mouse event was a left-button drag
CvPoint our_drag_location;           // current mouse position during a drag

//----------------------------------------------------------------------------
// HighGUI mouse callback for the tracking window.
//
//   event - HighGUI event code (CV_EVENT_*)
//   x, y  - mouse position in window pixel coordinates
//   flags - bit mask of CV_EVENT_FLAG_* (buttons and modifier keys)
//   param - user data pointer (unused)
//
// Left-button down records the drag start; moving with the left button held
// records the drag end; left-button up turns the drag into a normalized
// rectangle (top-left corner + positive size). A click without any drag
// produces an empty rectangle and is ignored, because CamShift cannot start
// from a zero-area search window.
//----------------------------------------------------------------------------
void on_mouse (int event, int x, int y, int flags, void* param)
{
    (void) param;  // unused

    our_dragging = 0;
    switch (event)
    {
        case CV_EVENT_LBUTTONDOWN:
            our_selection = cvRect (x, y, 0, 0);
            break;

        case CV_EVENT_MOUSEMOVE:
            // flags is a bit mask: test only the left-button bit so a drag
            // is still recognized while a modifier key is held down.
            if (flags & CV_EVENT_FLAG_LBUTTON)
            {
                our_dragging = 1;
                our_drag_location = cvPoint (x, y);
            }
            break;

        case CV_EVENT_LBUTTONUP:
            // Size must be computed before x/y are overwritten with the
            // top-left corner.
            our_selection.width  = std::abs (x - our_selection.x);
            our_selection.height = std::abs (y - our_selection.y);
            our_selection.x      = MIN (x, our_selection.x);
            our_selection.y      = MIN (y, our_selection.y);
            if (our_selection.width > 0 && our_selection.height > 0)
                our_object_selected = our_new_object_selected = 1;
            break;

        default:
            break;
    }
}

//****************************************************************************
//          help function
//****************************************************************************

//----------------------------------------------------------------------------
// Prints a short description of the node and its keyboard commands to
// standard output. The text describes where this node actually gets its
// images (the /usb_cam/image_raw subscription) and where it publishes
// results (the "tracker" topic). It takes no command-line arguments.
//----------------------------------------------------------------------------
static void help()
{
    std::cout << "\nThis is a demo that shows mean-shift based tracking\n"
            "You select a color object such as your face and it tracks it.\n"
            "Images are read from the ROS topic /usb_cam/image_raw and the\n"
            "tracked position is published on the 'tracker' topic\n"
            "(x, y = ellipse center, z = ellipse width).\n"
            "Usage: \n"
            "   run as a ROS node (e.g. with rosrun); no arguments are used\n";

    std::cout << "\n\nHot keys: \n"
            "\tESC - quit the program\n"
            "\tc - stop the tracking\n"
            "\tb - switch to/from backprojection view\n"
            "\th - show/hide object histogram\n"
            "To initialize tracking, select the object with mouse\n";
}

//****************************************************************************
//          CamShaftTracker class
//****************************************************************************

class CamShaftTracker
{

protected:

    // ROS related
    ros::NodeHandle                 my_handler;
    image_transport::ImageTransport my_image_transport;
    image_transport::Subscriber     my_image_subscriber;
    cv_bridge::CvImagePtr           my_webcam_img;
    ros::Publisher                  my_tracker_publisher;

    // Image related
    cv::Mat         my_image;
    cv::Mat         my_hsv;
    cv::Mat         my_hue;
    cv::Mat         my_mask;
    cv::Mat         my_histogram;
    cv::Mat         my_histogram_img;
    cv::Mat         my_backproject;

    // Image range
    int             my_vmin;
    int             my_vmax;
    int             my_smin;
  
    // Histogram related
    int             my_hist_dims;
    float*          my_hist_ranges_array;

    // CamShift related
    cv::Rect        my_track_window;  // must be global or program faults
    CvBox2D         my_track_box;
    CvConnectedComp my_track_comp;

    // Execution state
    bool            my_back_project_mode;
    bool            my_show_histogram;

public:

    // ---------- CamShaftTracker ------------------------------------------

    CamShaftTracker (ros::NodeHandle & nh):my_handler (nh), 
                                           my_image_transport (my_handler)
    {
        // Image initialization
        my_image         = cv::Mat::zeros(200, 320, CV_8UC3);
        my_hsv           = cv::Mat::zeros(200, 320, CV_8UC3);
        my_hue           = cv::Mat::zeros(200, 320, CV_8UC3);
        my_mask          = cv::Mat::zeros(200, 320, CV_8UC3);
        my_histogram     = cv::Mat::zeros(200, 320, CV_8UC3);
        my_histogram_img = cv::Mat::zeros(200, 320, CV_8UC3);

        // Listen for raw image message topic messages and setup callback
        my_image_subscriber = my_image_transport.subscribe (
                                  "/usb_cam/image_raw", 1, 
                                  &CamShaftTracker::imageCallback, this);

        // Open HighGUI Windows
        cv::namedWindow ("CamShiftTracking", 1);
        cvSetMouseCallback( "CamShiftTracking", on_mouse, 0 );
        cvMoveWindow("CamShiftTracking", 100, 100);
        cv::namedWindow ("Histogram", 1);
        cvMoveWindow("Histogram", 900, 100);

        // Initialize ranges
        my_vmin = 10;
        my_vmax = 256;
        my_smin = 30;

        // Initialize histogram
        my_hist_dims            = 16;
        my_hist_ranges_array    = new float[2];
        my_hist_ranges_array[0] = 0;
        my_hist_ranges_array[1] = 180;

	// Initialize execution state
        my_back_project_mode = false;
	my_show_histogram = true;

        // Create publisher
        my_tracker_publisher = my_handler.advertise<geometry_msgs::Point>
                       ("tracker", 1000);
    }
 
    // ---------- imageCallback ----------------------------------------------

    void imageCallback (const sensor_msgs::ImageConstPtr & msg_ptr)
    {
        const float* hist_ranges = my_hist_ranges_array;

        // Copy in the webcam image
        try
        {
            my_webcam_img = cv_bridge::toCvCopy(msg_ptr, 
                            sensor_msgs::image_encodings::BGR8);
        }
        catch (cv_bridge::Exception error)
        {
            ROS_ERROR ("cv_bridge Input Error: %s", error.what());
        }

	// Convert the image to cv::Mat
        my_image = cv::Mat (my_webcam_img->image).clone ();

        // Convert Input image from BGR to HSV 
        cv::cvtColor (my_image, my_hsv, cv::COLOR_BGR2HSV);

        my_backproject  = cv::Mat (my_webcam_img->image).clone ();

        // only do this is we have selected an tracking area
        if (our_object_selected)
        {
            int bin_w;
            int vmin = my_vmin;
            int vmax = my_vmax;
            geometry_msgs::Point msg;

 	    inRange( my_hsv, cv::Scalar(0,my_smin,MIN(vmin,vmax)),
		     cv::Scalar(180,256,MAX(vmin,vmax)), my_mask );
	    int ch[] = {0, 0};
	    my_hue.create(my_hsv.size(), my_hsv.depth());
	    mixChannels(&my_hsv, 1, &my_hue, 1, ch, 1);

            // only do this when the tracking area changes
            if (our_new_object_selected)
            {
 		// Set up the histrogram matrix
		cv::Mat roi(my_hue, our_selection);
		cv::Mat mask_roi(my_mask, our_selection);
		cv::calcHist (&roi, 1, 0, mask_roi, my_histogram, 1, 
                              &my_hist_dims, &hist_ranges);
		cv::normalize(my_histogram, my_histogram, 0, 255, CV_MINMAX);

                our_new_object_selected = 0;
                my_track_window = our_selection;

		// Set up the histrogram image
		my_histogram_img = cv::Scalar::all(0);
		bin_w = my_histogram_img.cols / my_hist_dims;
		cv::Mat buf(1, my_hist_dims, CV_8UC3);
		for (int i = 0; i < my_hist_dims; i++)
		    buf.at<cv::Vec3b>(i) = 
                        cv::Vec3b(cv::saturate_cast<uchar>
                                 (i*180./my_hist_dims), 255, 255);
		cvtColor(buf, buf, CV_HSV2BGR);

		// Draw the historgram bars in the histogram image
                for( int i = 0; i < my_hist_dims; i++ )
                {
		    int val = cv::saturate_cast<int>
                        (my_histogram.at<float>(i) *my_histogram_img.rows/255);
                    rectangle( my_histogram_img, 
                           cv::Point(i*bin_w,my_histogram_img.rows),
                           cv::Point((i+1)*bin_w,my_histogram_img.rows - val),
                           cv::Scalar(buf.at<cv::Vec3b>(i)), -1, 8 );
                }
            }

            // do the CAM Shift computations
	    cv::calcBackProject( &my_hue, 1, 0, my_histogram, 
                                   my_backproject, &hist_ranges);
            my_backproject &= my_mask;

	    cv::RotatedRect my_track_box = cv::CamShift(
                              my_backproject, my_track_window,
			      cv::TermCriteria( 
				   CV_TERMCRIT_EPS | CV_TERMCRIT_ITER, 
                                   10, 1 ));
            if( my_track_window.area() <= 1 )
            {
                int cols = my_backproject.cols, 
                    rows = my_backproject.rows, 
                    r    = (MIN(cols, rows) + 5)/6;
                my_track_window = cv::Rect(my_track_window.x - r, 
                                           my_track_window.y - r,
                                           my_track_window.x + r, 
                                           my_track_window.y + r) &
		              cv::Rect(0, 0, cols, rows);
            }

            if(my_back_project_mode )
	        cvtColor( my_backproject, my_image, cv::COLOR_GRAY2BGR );

            ellipse( my_image, my_track_box, cv::Scalar(0,0,255), 3, CV_AA );
 
            // publish ellipse location
            msg.x = my_track_box.center.x;
            msg.y = my_track_box.center.y;
            msg.z = my_track_box.size.width;
            //ROS_INFO ("--- x: %f  y: %f  width: %f ---\n", 
            //            msg.x, msg.y, msg.z);
            my_tracker_publisher.publish(msg);
        }
        // Display a negative region when mouse dragging
        if( our_dragging )
        {
            int x      = MIN (our_drag_location.x, our_selection.x);
            int y      = MIN (our_drag_location.y, our_selection.y);
            int width  = abs (our_drag_location.x - our_selection.x);
            int height = abs (our_drag_location.y - our_selection.y);

	    cv::Mat roi(my_image, cvRect(x, y, width, height));
            bitwise_not(roi, roi);
            //ROS_INFO("Our Dragging: x: %d  y: %d  width: %d  height: %d", 
            //         x, y, width, height);
        }

        // Display Input image and histogram
	cv::imshow ("CamShiftTracking", my_image);
	if (my_show_histogram)
	{
	    cv::imshow ("Histogram", my_histogram_img);
	}

	// Handle keyboard commands
        char c = (char)cv::waitKey(10);
        if( c == 27 )
	    exit(0);
        switch(c)
        {
        case 'b':
            my_back_project_mode =! my_back_project_mode;
            break;
        case 'c':
            our_object_selected = 0;
            my_histogram_img = cv::Scalar::all(0);
            break;
        case 'h':
            my_show_histogram = !my_show_histogram;
            if( !my_show_histogram )
	        cv::destroyWindow( "Histogram" );
            else
	    {
	        cv::namedWindow( "Histogram", 1 );
                cvMoveWindow("Histogram", 900, 100);
	    }
            break;
        default:
            ;
        }
     }

};

//****************************************************************************
//          main
//****************************************************************************

int main (int argc, char **argv)
{
    // Usage and hot-key summary come first.
    help ();

    // Set up ROS under the node name "CAM_Shift_Tracking" and get a handle
    // through which the tracker creates its topics.
    ros::init (argc, argv, "CAM_Shift_Tracking");
    ros::NodeHandle node;

    // The tracker wires up its subscriber and publisher when constructed;
    // from then on all work happens in its image callback, which
    // ros::spin() keeps dispatching until the node shuts down.
    CamShaftTracker tracker (node);
    ros::spin ();

    return 0;
}