東方算程譚

επιστημηがヨタをこく、弾幕とは無縁のCUDAなタワゴト

Thrust で 生ポを扱うには

Windows/Visual Studio版ではおなじみの「配列の足し算」: c[i] = a[i] + b[i] (i = 0, 1, ...) を Thrust使って書いてみました。

/*
  DO NOT FORGET: --expt-extended-lambda option to nvcc
 */

#include <cstdio>
#include <cassert>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/transform.h>

void addWithThrust(thrust::host_vector<int>& c, 
                   const thrust::host_vector<int>& a, 
                   const thrust::host_vector<int>& b);

int main() {
  using namespace std;

  const int arraySize = 5;
  const int a_data[arraySize] = {  1,  2,  3,  4,  5 };
  const int b_data[arraySize] = { 10, 20, 30, 40, 50 };

  thrust::host_vector<int> a(a_data, a_data + arraySize);
  thrust::host_vector<int> b(b_data, b_data + arraySize);
  thrust::host_vector<int> c(arraySize);

  // Add vectors in parallel.
  addWithThrust(c, a, b);

  printf("{1,2,3,4,5} + {10,20,30,40,50} = {%d,%d,%d,%d,%d}\n",
        c[0], c[1], c[2], c[3], c[4]);

  cudaDeviceReset();
}

void addWithThrust(      thrust::host_vector<int>& c, 
                   const thrust::host_vector<int>& a, 
                   const thrust::host_vector<int>& b) {
  using namespace std;

  assert( a.size() == b.size() );
  assert( a.size() == c.size() );

  thrust::device_vector<int> dev_a = a; // copy Host to Device
  thrust::device_vector<int> dev_b = b; // copy Host to Device
  thrust::device_vector<int> dev_c(c.size());

  thrust::transform(begin(dev_a), end(dev_a), // dev_a for input
                    begin(dev_b),             // dev_b for input
                    begin(dev_c),             // dev_c for output
                    [] __device__ (int x, int y) -> int { return x + y; }); // z = x + y

  c = dev_c; // copy Device to Host
}

楽ちんですねー、thrust::host_vector/device_vector にメモリ管理を任せてしまえば cudaMallocはコンストラクタ、cudaMemcpyoperator=()がやってくれるのでコードが実に涼しげです。

Thrustのコードを追いかけてみたところ、thrust:host_vectorのメモリ確保/解放は new/deletethrust::device_vectorcudaMalloc/cudaFreeが使われてるようです。

ここまで簡単になるんだからHost/Deviceで共用できるManaged Memory使えばもっと楽できんじゃないかと試してみたですよ。

#include <cstdio>
#include <cassert>

#include <algorithm>

#include <cuda_runtime.h>

#include <thrust/transform.h>

void addWithThrust(int* c,
                   const int* a,
                   const int* b,
                   int size);

int main() {
  using namespace std;

  const int arraySize = 5;
  const int a_data[arraySize] = {  1,  2,  3,  4,  5 };
  const int b_data[arraySize] = { 10, 20, 30, 40, 50 };

  int* a; cudaMallocManaged(&a, arraySize*sizeof(int)); std::copy( a_data, a_data + arraySize, a);
  int* b; cudaMallocManaged(&b, arraySize*sizeof(int)); std::copy( b_data, b_data + arraySize, b);
  int* c; cudaMallocManaged(&c, arraySize*sizeof(int)); std::fill( c, c+arraySize, 0);

  // Add vectors in parallel.
  addWithThrust(c, a, b, arraySize);

  printf("{1,2,3,4,5} + {10,20,30,40,50} = {%d,%d,%d,%d,%d}\n",
        c[0], c[1], c[2], c[3], c[4]);

  cudaFree(a);
  cudaFree(b);
  cudaFree(c);

  cudaDeviceReset();
}

void addWithThrust(int* c,
                   const int* a,
                   const int* b,
                   int size) {

  thrust::transform(a, a+size, // a for input
                    b,         // b for input
                    c,         // c for output
                    [] __device__ (int x, int y) -> int { return x + y; }); // z = x + y
  cudaDeviceSynchronize();

}

f:id:Episteme:20161121192437p:plain

あらら、ぜーんぜん動いてません。原因はココ:

  thrust::transform(a, a+size, // a for input
                    b,         // b for input
                    c,         // c for output
                    [] __device__ (int x, int y) -> int { return x + y; }); // z = x + y

Thrustが提供する多くの関数は、引き渡されるイテレータの型に応じてHostで行うかDeviceで行うかを静的に判断しています。このコードではイテレータとして生のポインタを渡してるんですけど、Thrustではイテレータが(生の)ポインタであったならHostで実行すべきものと判断してるっポいです。

解決策はふたつ。ひとつは第一引数に実行ポリシーを明示的に与えること。

#include <thrust/execution_policy.h>
...
void addWithThrust(int* c,
                   const int* a,
                   const int* b,
                   int size) {

  thrust::transform(thrust::device,
                    a, a+size, // a for input
                    b,         // b for input
                    c,         // c for output
                    [] __host__ __device__ (int x, int y) -> int { return x + y; }); // z = x + y
  cudaDeviceSynchronize();
}

もうひとつは、生のポインタを thrust::device_ptr に変換すること。

#include <thrust/device_ptr.h>
...
void addWithThrust(int* c,
                   const int* a,
                   const int* b,
                   int size) {

  thrust::device_ptr<int> pa(const_cast<int*>(a));
  thrust::device_ptr<int> pb(const_cast<int*>(b));
  thrust::device_ptr<int> pc(c);
  thrust::transform(pa, pa+size, // a for input
                    pb,          // b for input
                    pc,          // c for output
                    [] __host__ __device__ (int x, int y) -> int { return x + y; }); // z = x + y
  cudaDeviceSynchronize();

}

COO-format から CSR-format への変換

m行n列の行列、成分数は m*n個。m,nが大きくなるとメモリの使用量がシャレになりません。けどもほとんどの成分が0であるなら、0成分を省略することでメモリ消費をぐっと抑えることができます。

ほとんどの成分が0のスカスカな行列を疎行列(Sparse matrix)と称します。その疎行列を対象にした各種演算をCUDAでやらせるライブラリがcuSPARSEです。

疎行列を表現するには (行index, 列index, 成分) の組を非0成分数だけ並べればよいですな。SoA(Structure of Array)なレイアウトだと

  int   rowInd[nnz]; // 行index
  int   colInd[nnz]; // 列index
  float val[nnz];    // 成分

こんなカンジ、nnzは非0成分の数(number of nonzero)。この形式をCOO-format(Coodinate format)っていいます。

くSPARSEが提供する演算には疎行列を引数に与えるんですけど、上記COO-formatのままでは引数に与えることができず、CSR-format(Compressed Sparse Row format)に変換せにゃなりません。その方法。

まずCOO-formatの各組を行(row),列(col)の順に昇順でソートします。行(row)の小さい順、行が同じなら列(col)の小さい順。当然成分(val)も連動して入れ替えを行います。

こんな元データが

f:id:Episteme:20161116191559p:plain

このようにソートされます。

f:id:Episteme:20161116191605p:plain

これが row-major でソートされた COO-format。

行indexは同じ値が連続しますね、row-majorにソートしたんだから当然ですけど。 同じ値が連続するってことはすなわち冗長であり圧縮できるってことです。

rowInd[nnz] を rowPtr[m+1] に圧縮した結果がコチラ。

f:id:Episteme:20161116191611p:plain

これがCSR-format(Compressed Sparse matrix Row format)。

COO-format から CSR-format への変換関数群は cuSPARSE の中に用意されています。

#include <cuda_runtime.h>
#include <cusparse.h>
#include <algorithm>
#include <numeric>
#include <random>
#include <iostream>
#include <iomanip>

using namespace std;

int main() {
  const int m   = 10; // 行数
  const int n   = 20; // 列数
  const int nnz = 30; // 非0成分数(numner of non-zero)

  // 疎行列のCOO(Coodinate format)
  int*   rowInd;
  int*   colInd;
  float* cooVal;
  cudaMallocManaged(&rowInd, nnz*sizeof(int));
  cudaMallocManaged(&colInd, nnz*sizeof(int));
  cudaMallocManaged(&cooVal, nnz*sizeof(float));

  { // 乱数で疎行列を生成する
    int inx[m*n];
    iota(begin(inx), end(inx), 0);
    shuffle(begin(inx), end(inx), mt19937());
    for ( int i = 0; i < nnz; ++i ) {
      rowInd[i] = inx[i]/n;
      colInd[i] = inx[i]%n;
      cooVal[i] = rowInd[i] + colInd[i]*0.01f;
    }
  }

  cout << "COO-format\n"
          "# ( rowInd, colInd, val)\n";
  for ( int i = 0; i< nnz; ++i ) {
    cout << setw(3) << i << " ( "
         << setw(3) << rowInd[i] << ',' 
         << setw(3) << colInd[i] << ", " 
         << cooVal[i] 
         << " )" << endl;
  }
  cout << endl;

  // sort後の成分領域を確保
  float* csrVal;
  cudaMallocManaged(&csrVal, nnz*sizeof(float));

  // rowPtr(圧縮されたrowInd)領域を確保 大きさは m+1
  int* rowPtr;
  cudaMallocManaged(&rowPtr, (m+1)*sizeof(int));

  // cuSPARSE 初期化
  cusparseHandle_t handle;
  cusparseCreate(&handle);

  // coosortに必要なバッファの大きさを求め、
  size_t bufferSize;
  cusparseXcoosort_bufferSizeExt(handle, m, n, nnz, rowInd, colInd, &bufferSize);
  // バッファを確保する
  void* sortBuffer;
  cudaMalloc(&sortBuffer, bufferSize);

  // permutation領域を確保し、
  int* permutation;
  cudaMalloc(&permutation, nnz*sizeof(int));
  // 0, 1, 2,... で埋める
  cusparseCreateIdentityPermutation(handle, nnz, permutation);

  // 行、列の順で昇順にソートし、
  cusparseXcoosortByRow(handle, m, n, nnz, rowInd, colInd, permutation, sortBuffer);
  cudaFree(sortBuffer);

  // 成分がソート順に対応するようpermutationを基に配置する
  cusparseSgthr(handle, nnz, cooVal, csrVal, permutation, CUSPARSE_INDEX_BASE_ZERO);
  cudaFree(permutation);

  // rowIndを圧縮してrwoPtrに求める
  cusparseXcoo2csr(handle, rowInd, nnz, m, rowPtr, CUSPARSE_INDEX_BASE_ZERO);
  cudaDeviceSynchronize();

  cout << "COO-format row-major sorted\n"
          "# ( rowInd, colInd, val)\n";
  for ( int i = 0; i< nnz; ++i ) {
    cout << setw(3) << i << " ( "
         << setw(3) << rowInd[i] << ',' 
         << setw(3) << colInd[i] << ", " 
         << csrVal[i] 
         << " )" << endl;
  }
  cout << endl;

  cout << "CSR-format\n"
          "# ( rowPtr, colInd, val)\n";
  for ( int i = 0; i< nnz; ++i ) {
    cout << setw(3) << i; ;
    if ( i <= m )
      cout << " ( " << setw(3) << rowPtr[i] << ',';
    else
      cout << "     ( ";
    cout << setw(3) << colInd[i] << ", " 
         << csrVal[i] 
         << " )" << endl;
  }
  cout << endl;

  cudaFree(rowInd);
  cudaFree(rowPtr);
  cudaFree(colInd);
  cudaFree(cooVal);
  cudaFree(csrVal);
  cusparseDestroy(handle);
  cudaDeviceReset();

}

Visual Studio 2013 + CUDA 7.5 から 2015+8.0 へ

今年9月末、CUDA 8.0がリリースされました。以前の版では Visual Studio 2013までしか対応してなくて、CUDA 8.0でようやくVS2015で使えるようになりました。

CUDA 7.5で作ったVS2013 projectを VS2015 に食わすと、こんなエラーが現れたりします:

f:id:Episteme:20161114193655p:plain

VS2015でサポートされるCUDAは8.0だけなんで、そこにCUDA 7.5の設定を含んだproject食わすんだから"そんな設定知らへんよ!"ってことですわ。

VS2013がマシン上に残っていればそいつでprojectを開けて 7.5→8.0 に変更したのち、そのprojectをVS2015に食わすんですが...VS2013を un-install しちゃってるとアチャーなことになります。てか実際やらかしました(テヘ

こんなときどーするか、VS2015で読み込みに失敗したprojectを右-clickし、~.vcxproj を編集します。~.vcxprojはXMLよーするにただのテキスト・ファイル。中から文字列 "CUDA 7.5" を見つけ(たぶん2か所あるハズ)、"CUDA 8.0" に書き換えて、project右-clickして再読み込みしてみてくださいな。

f:id:Episteme:20161114193725p:plain

NPP : Canny変換(そのに)

Canny変換は明るさの変化点を見つけることで輪郭を検出します。 カラー画像で色は違うけど明るさの同じ領域が接しているとモノクロ化したときに明るさに変化がないため輪郭が検出できなくなるんですね。

カラー画像をRGB3枚の画像にバラし、それぞれにCanny変換をかけて再合成してみました。

/*
 * DO NOT FORGET nvcc option : --expt-extended-lambda
 */

// std
#include <iostream>

// OpenCV
#include <opencv2/opencv.hpp>

// CUDA
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <npp.h>

// カーネル関数 二次元のtransform 
//    dst[y][x] = fun(src[y][x]) 
//       where : 0 <= x < width, 0 <= y < height
template<typename T, typename U, typename Function>
__global__ void kernel_transform2D(unsigned int  width, unsigned int height, 
                                        const T* src,         size_t src_pitch,
                                              U* dst,         size_t dst_pitch,
                                       Function  fun) {
  unsigned int x = blockDim.x * blockIdx.x + threadIdx.x;
  unsigned int y = blockDim.y * blockIdx.y + threadIdx.y;
  if ( x < width && y < height ) {
    U* dst_ptr = ((U*)((char*)dst + dst_pitch*y)) + x;
    *dst_ptr = fun(((const T*)((const char*)src + src_pitch*y))[x], *dst_ptr);
  }
}

void color2gray(unsigned int  width, unsigned int height, 
                      uchar3* src,         size_t src_pitch,
                       uchar* dst,         size_t dst_pitch) {
  kernel_transform2D<<<dim3((width+31)/32, (height+7)/8), dim3(32,8)>>>(
    width, height, 
    src, src_pitch, 
    dst, dst_pitch,
    [] __device__ (const uchar3 v, uchar) -> uchar { 
       int t = (v.x + v.y*7 + v.z*2)/10; 
       if ( t <   0 ) t = 0; 
       if ( t > 255 ) t = 255; 
       return (uchar)t; 
    }
  );
}

void color2gray_channel(unsigned int  width, unsigned int height, 
                        uchar3* src, size_t src_pitch,
                        uchar*  dst, size_t dst_pitch,
                       int    channel) {
  kernel_transform2D<<<dim3((width+31)/32, (height+7)/8), dim3(32,8)>>>(
    width, height, 
    src, src_pitch, 
    dst, dst_pitch,
    [=] __device__ (const uchar3 v, uchar) -> uchar 
      { return ((const uchar*)&v)[channel]; }
  );
}

void gray2color_channel(unsigned int  width, unsigned int height, 
                        uchar*  src, size_t src_pitch,
                        uchar3* dst, size_t dst_pitch,
                        int channel) {
  kernel_transform2D<<<dim3((width+31)/32, (height+7)/8), dim3(32,8)>>>(
    width, height, 
    src, src_pitch, 
    dst, dst_pitch,
    [=] __device__ (const uchar v, uchar3 c) -> uchar3 
      { uchar3 t = c; ((uchar*)&t)[channel] = v; return t; }
  );
}

int main(int argc, char *argv[]) {
  cv::VideoCapture camera(0);

  cv::namedWindow("original", CV_WINDOW_AUTOSIZE);
  cv::namedWindow("canny", CV_WINDOW_AUTOSIZE);

  cv::Mat frame;
  cv::Mat canny;

  uchar3* d_frame;
  uchar*  d_gray_base;
  uchar*  d_canny_base;

  uchar*  d_gray[3];
  uchar*  d_canny[3];
  size_t  d_frame_pitch;
  size_t  d_gray_pitch;
  size_t  d_canny_pitch;
  Npp8u*  d_buffer;
  NppiSize size;

 
  // 一発目のキャプチャでフレームのサイズがわかるから
  // (そして多分その後ずっと変わらんだろから)
  // それを基にdevice-memoryを確保
  camera >> frame;

  size.width = (int)frame.size().width;
  size.height = (int)frame.size().height;

  cudaMallocPitch(&d_frame,      &d_frame_pitch, size.width*sizeof(uchar3), size.height);
  cudaMallocPitch(&d_gray_base,  &d_gray_pitch,  size.width,                size.height*3);
  cudaMallocPitch(&d_canny_base, &d_canny_pitch, size.width,                size.height*3);

  for ( size_t i = 0; i < 3; ++i ) {
    d_gray[i]  = d_gray_base  + d_gray_pitch *size.height*i;
    d_canny[i] = d_canny_base + d_canny_pitch*size.height*i;
  }

  // Cannyに引き渡すパラメータ
  NppiSize  nroi    = size;
  NppiPoint noffset = { 0, 0 };
  // 以下のパラメータはイイカンジになるよう適宜調整。
  Npp16s                 nlow_threshold  = 50;
  Npp16s                 nhigh_threshold = 150;
  NppiDifferentialKernel nkernel   = NPP_FILTER_SOBEL;
  NppiMaskSize           nmasksize = NPP_MASK_SIZE_3_X_3;


  // Canny変換に必要なバッファを確保
  {
  int buffer_size;
  nppiFilterCannyBorderGetBufferSize(size, &buffer_size);
  cudaMalloc(&d_buffer, buffer_size);
  }

  canny = frame.clone();
  std::cout 
    << "width,height   = " << size.width << ',' << size.height  
    << "\nstep           = " << frame.step 
    << "\ndepth, channel = " << frame.depth() << ',' << frame.channels()
    << "\n***** [ESC] to exit. *****\n";

  while ( cv::waitKey(10) != 0x1b ) {
    // [1] 画像を frame にキャプチャ
    camera >> frame;
    cv::imshow("original", frame);

    // [2] frame から d_frame へコピー
    cudaMemcpy2D(d_frame, d_frame_pitch, frame.data, frame.step, 
                 size.width*sizeof(uchar3), size.height, cudaMemcpyDefault);

    for ( int i = 0; i < 3; ++i ) {
      // [3] d_frame をモノクロ化して d_gray へ
      color2gray_channel(size.width, size.height, d_frame, d_frame_pitch, d_gray[i], d_gray_pitch,i);

      // [4] d_gray に Canny変換カマして d_canny へ
      nppiFilterCannyBorder_8u_C1R(d_gray[i],  (int)d_gray_pitch,  size, noffset,
                                   d_canny[i], (int)d_canny_pitch, nroi,
                                   nkernel, nmasksize,
                                   nlow_threshold, nhigh_threshold,
                                   nppiNormL2, NPP_BORDER_REPLICATE, 
                                   d_buffer);

      // [5] d_canny をカラー化(RGBを同じ値にするだけ)して d_frame へ
      gray2color_channel(size.width, size.height, d_canny[i], d_canny_pitch, d_frame, d_frame_pitch, i);
    }

    // [6] d_frame を canny へコピー
    cudaMemcpy2D(canny.data, canny.step, d_frame, d_frame_pitch, 
                 size.width*sizeof(uchar3), size.height, cudaMemcpyDefault);

    // [7] 描画!
    cv::imshow("canny", canny);
  }

  // あとしまつ
  cudaFree(d_frame);
  cudaFree(d_gray_base);
  cudaFree(d_canny_base);
  cudaFree(d_buffer);
}

こんなんができましたよ。

f:id:Episteme:20161112000224p:plain

NPP : Canny変換

NPP(NVIDIA Performance Primitive) の中に Canny変換 を見つけました。
どうやら CUDA 8.0 で新たに追加されたみたいです。

Canny変換は画像の輪郭を抽出するもので、Sobel/Scharr変換よりシャープな輪郭線を描いてくれます。 Sobel/Scharr変換で得られた輝度勾配の稜線を見つけてくれるってゆーか。

早速試してみました。 OpenCV 3.1 を使ってWeb-cameraからの画像のキャプチャと描画を行います。ダンドリはこんな。

f:id:Episteme:20161110181058p:plain

  1. Web-cameraからキャプチャした画像を
  2. Device-memoryにコピー
  3. モノクロ化し
  4. Canny変換を施します。
  5. 変換後のモノクロ画像をカラー化(RGBに同じ値を入れるだけ)し
  6. Hostに書き戻して
  7. 描画!
/*
 * DO NOT FORGET nvcc option : --expt-extended-lambda
 */

// std
#include <iostream>

// OpenCV
#include <opencv2/opencv.hpp>

// CUDA
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <npp.h>

// カーネル関数 二次元のtransform 
//    dst[y][x] = fun(src[y][x]) 
//       where : 0 <= x < width, 0 <= y < height
template<typename T, typename U, typename Function>
__global__ void kernel_transform2D(unsigned int  width, unsigned int height, 
                                        const T* src,         size_t src_pitch,
                                              U* dst,         size_t dst_pitch,
                                       Function  fun) {
  unsigned int x = blockDim.x * blockIdx.x + threadIdx.x;
  unsigned int y = blockDim.y * blockIdx.y + threadIdx.y;
  if ( x < width && y < height ) {
    ((U*)((char*)dst + dst_pitch*y))[x] = fun(((const T*)((const char*)src + src_pitch*y))[x]);
  }
}

void color2gray(unsigned int  width, unsigned int height, 
                      uchar3* src,         size_t src_pitch,
                       uchar* dst,         size_t dst_pitch) {
  kernel_transform2D<<<dim3((width+31)/32, (height+7)/8), dim3(32,8)>>>(
    width, height, 
    src, src_pitch, 
    dst, dst_pitch,
    [] __device__ (const uchar3 v) -> uchar { 
       int t = (v.x + v.y*7 + v.z*2)/10; 
       if ( t <   0 ) t = 0; 
       if ( t > 255 ) t = 255; 
       return (uchar)t; 
    }
  );
}

void gray2color(unsigned int  width, unsigned int height, 
                       uchar* src,         size_t src_pitch,
                      uchar3* dst,         size_t dst_pitch) {
  kernel_transform2D<<<dim3((width+31)/32, (height+7)/8), dim3(32,8)>>>(
    width, height, 
    src, src_pitch, 
    dst, dst_pitch,
    [] __device__ (const uchar v) -> uchar3 { return make_uchar3(v,v,v); }
  );
}

int main(int argc, char *argv[]) {
  cv::VideoCapture camera(0);

  cv::namedWindow("original", CV_WINDOW_AUTOSIZE);
  cv::namedWindow("canny", CV_WINDOW_AUTOSIZE);

  cv::Mat frame;
  cv::Mat canny;

  uchar3* d_frame;
  uchar*  d_gray;
  uchar*  d_canny;
  size_t  d_frame_pitch;
  size_t  d_gray_pitch;
  size_t  d_canny_pitch;
  Npp8u*  d_buffer;
  NppiSize size;

 
  // 一発目のキャプチャでフレームのサイズがわかるから
  // (そして多分その後ずっと変わらんだろから)
  // それを基にdevice-memoryを確保
  camera >> frame;
  size.width = (int)frame.size().width;
  size.height = (int)frame.size().height;

  cudaMallocPitch(&d_frame, &d_frame_pitch, size.width*sizeof(uchar3), size.height);
  cudaMallocPitch(&d_gray,  &d_gray_pitch,  size.width,                size.height);
  cudaMallocPitch(&d_canny, &d_canny_pitch, size.width,                size.height);

  // Cannyに引き渡すパラメータ
  NppiSize  nroi    = size;
  NppiPoint noffset = { 0, 0 };
  Npp16s    nlow_threshold  = 50;  // これと
  Npp16s    nhigh_threshold = 150; // これは適宜調整。

  // Canny変換に必要なバッファを確保
  {
  int buffer_size;
  nppiFilterCannyBorderGetBufferSize(size, &buffer_size);
  cudaMalloc(&d_buffer, buffer_size);
  }

  canny = frame.clone();
  std::cout 
    << "width,height   = " << size.width << ',' << size.height  
    << "\nstep           = " << frame.step 
    << "\ndepth, channel = " << frame.depth() << ',' << frame.channels()
    << "\n***** [ESC] to exit. *****\n";

  while ( cv::waitKey(10) != 0x1b ) {

    // [1] 画像を frame にキャプチャ
    camera >> frame;
    cv::imshow("original", frame);

    // [2] frame から d_frame へコピー
    cudaMemcpy2D(d_frame, d_frame_pitch, frame.data, frame.step, 
                 size.width*sizeof(uchar3), size.height, cudaMemcpyDefault);
    // [3] d_frame をモノクロ化して d_gray へ
    color2gray(size.width, size.height, d_frame, d_frame_pitch, d_gray, d_gray_pitch);

    // [4] d_gray に Canny変換カマして d_canny へ
    nppiFilterCannyBorder_8u_C1R(d_gray,  (int)d_gray_pitch,  size, noffset,
                                 d_canny, (int)d_canny_pitch, nroi,
                                 NPP_FILTER_SOBEL, NPP_MASK_SIZE_3_X_3,
                                 nlow_threshold, nhigh_threshold,
                                 nppiNormL2, NPP_BORDER_REPLICATE, 
                                 d_buffer);

    // [5] d_canny をカラー化(RGBを同じ値にするだけ)して d_frame へ
    gray2color(size.width, size.height, d_canny, d_canny_pitch, d_frame, d_frame_pitch);

    // [6] d_frame を canny へコピー
    cudaMemcpy2D(canny.data, canny.step, d_frame, d_frame_pitch, 
                 size.width*sizeof(uchar3), size.height, cudaMemcpyDefault);

    // [7] 描画!
    cv::imshow("canny", canny);
  }

  // あとしまつ
  cudaFree(d_frame);
  cudaFree(d_gray);
  cudaFree(d_canny);
  cudaFree(d_buffer);
}

こんな輪郭線を描いてくれます;

f:id:Episteme:20161110181141p:plain

カスタム・アロケータ

CUDAでは目的/用途に応じて様々なメモリの確保/解放APIが用意されています。

  1. ピン留めされたHost-memory : cudaMallocHost / cudaFreeHost
  2. Host/Device双方で共用できるManaged-memory : cudaMallocManaged / cudaFree
  3. Device-memory : cudaMalloc : cudaFree

Device側はさておき、Host側は上記1.2および通常のnew[] / delete[] の3種のメモリ確保/解放を使い分けることになります。

C++屋が(可変長)配列を扱う際。日常的にstd::vectorのお世話になるのですが、std::vector<T>は(デフォルトで)new[]/delete[]が内部的なメモリ管理に使われます。

このメモリ管理をcudaMallocHost/cudaFreeHostあるいはcudaMallocManaged/cudaFreeに差し替えることができれば ピン留めされたvectorHost/Deviceの双方でで共用できるvector が使えて便利。

ってわけで実装しました。ついでに unnique_device_ptr と cuda runtime 例外も。

/* cuda_except.h */
#ifndef CUDA_EXCEPT_H_
#define CUDA_EXCEPT_H_

#include <cuda_runtime.h>
#include <stdexcept>

namespace cu {

class cuda_error : public std::runtime_error {
  cudaError_t err_;
public:
  cuda_error(cudaError_t error) : std::runtime_error(cudaGetErrorString(error)), err_(error) {}
  cudaError_t code() const { return err_; }
  const char* name() const { return cudaGetErrorName(err_); }
};

}
#endif
/* cuda_allocator.h */
#ifndef CUDA_ALLOCATOR_H_
#define CUDA_ALLOCATOR_H_

#include "cuda_except.h"

namespace cu {

template <class T>
struct host_allocator {
  typedef T value_type;
  host_allocator() noexcept {} //default ctor not required by STL
  template<class U> host_allocator(const host_allocator<U>&) noexcept {}
  template<class U> bool operator==(const host_allocator<U>&) const noexcept { return true; }
  template<class U> bool operator!=(const host_allocator<U>&) const noexcept { return false; }
  T* allocate(const size_t n) const;
  void deallocate(T* const p) const noexcept;
  void deallocate(T* const p, size_t) const noexcept { deallocate(p); }
};

template <class T>
struct managed_allocator {
  typedef T value_type;
  managed_allocator() noexcept {} //default ctor not required by STL
  template<class U> managed_allocator(const managed_allocator<U>&) noexcept {}
  template<class U> bool operator==(const managed_allocator<U>&) const noexcept { return true; }
  template<class U> bool operator!=(const managed_allocator<U>&) const noexcept { return false; }
  T* allocate(const size_t n) const;
  void deallocate(T* const p) const noexcept;
  void deallocate(T* const p, size_t) const noexcept { deallocate(p); }
};

template <class T>
struct device_allocator {
  typedef T value_type;
  device_allocator() noexcept {} //default ctor not required by STL
  template<class U> device_allocator(const device_allocator<U>&) noexcept {}
  template<class U> bool operator==(const device_allocator<U>&) const noexcept { return true; }
  template<class U> bool operator!=(const device_allocator<U>&) const noexcept { return false; }
  T* allocate(const size_t n) const;
  void deallocate(T* const p) const noexcept;
  void deallocate(T* const p, size_t) const noexcept { deallocate(p); }
};

#include <cuda_runtime.h>

template <class T>
T* host_allocator<T>::allocate(const size_t n) const {
  if ( n == 0 ) return nullptr;
  if ( n > static_cast<size_t>(-1) / sizeof(T) ) throw std::bad_array_new_length();
  void* pv = nullptr;
  cudaError_t err = cudaMallocHost(&pv, n*sizeof(T));
  if ( err != cudaSuccess ) throw cuda_error(err);
  return static_cast<T*>(pv);
}

template<class T> 
void host_allocator<T>::deallocate(T * const p) const noexcept{
  cudaError_t err = cudaFreeHost(p);
//if ( err != cudaSuccess ) throw cuda_error(err);
}

template <class T>
T* managed_allocator<T>::allocate(const size_t n) const {
  if ( n == 0 ) return nullptr;
  if ( n > static_cast<size_t>(-1) / sizeof(T) ) throw std::bad_array_new_length();
  void* pv = nullptr;
  cudaError_t err = cudaMallocManaged(&pv, n*sizeof(T));
  if ( err != cudaSuccess ) throw cuda_error(err);
  return static_cast<T*>(pv);
}

template<class T> 
void managed_allocator<T>::deallocate(T * const p) const noexcept {
  cudaError_t err = cudaFree(p);
//if ( err != cudaSuccess ) throw cuda_error(err);
}

template <class T>
T* device_allocator<T>::allocate(const size_t n) const {
  if ( n == 0 ) return nullptr;
  if ( n > static_cast<size_t>(-1) / sizeof(T) ) throw std::bad_array_new_length();
  void* pv = nullptr;
  cudaError_t err = cudaMalloc(&pv, n*sizeof(T));
  if ( err != cudaSuccess ) throw cuda_error(err);
  return static_cast<T*>(pv);
}

template<class T> 
void device_allocator<T>::deallocate(T * const p) const noexcept {
  cudaError_t err = cudaFree(p);
//if ( err != cudaSuccess ) throw cuda_error(err);
}

}
#endif
/* unique_device_ptr.h */
#ifndef UNIQUE_DEVICE_PTR_H_
#define UNIQUE_DEVICE_PTR_H_

#include <cuda_runtime.h>
#include <memory>

namespace cu {

template<typename T> struct device_delete {
  device_delete() noexcept = default;
  void operator()(T* ptr) const { cudaFree(ptr); }
};

template<typename T> struct device_delete<T[]> {
  device_delete() noexcept = default;
  void operator()(T* ptr) const { cudaFree(ptr); }
};

template<typename T>
using device_unique_ptr = std::unique_ptr<T,device_delete<T>>;

}

#endif

Windows版CUDA Toolkitではおなじみの配列の足し算をre-writeしてみました。メモリ管理とエラー処理がぐっと楽になります。

#include "cuda_runtime.h"
#include "device_launch_parameters.h"

#include <stdio.h>

#include "cuda_allocator.h"
#include "unique_device_ptr.h"

#include <vector>

template<typename T> using host_vector = std::vector<T, cu::host_allocator<T>>;

void addWithCuda(int *c, const int *a, const int *b, unsigned int size);

__global__ void addKernel(int *c, const int *a, const int *b) {
    int i = threadIdx.x;
    c[i] = a[i] + b[i];
}

inline void cuda_check(cudaError_t status) {
  if (status != cudaSuccess) throw cu::cuda_error(status);
}

int main() {
  try {
    const int arraySize = 5;
    host_vector<int> a = {  1,  2,  3,  4,  5 };
    host_vector<int> b = { 10, 20, 30, 40, 50 };
    host_vector<int> c(arraySize, 0);

    // Add vectors in parallel.
    addWithCuda(c.data(), a.data(), b.data(), arraySize);

    printf("{1,2,3,4,5} + {10,20,30,40,50} = {%d,%d,%d,%d,%d}\n",
        c[0], c[1], c[2], c[3], c[4]);
  } catch ( const cu::cuda_error& er ) {
    fprintf(stderr, "%s : %s\n", er.name(), er.what());
  }

   // cudaDeviceReset must be called before exiting in order for profiling and
  // tracing tools such as Nsight and Visual Profiler to show complete traces.
  cudaError_t cudaStatus = cudaDeviceReset();
  if (cudaStatus != cudaSuccess) {
    fprintf(stderr, "cudaDeviceReset failed!");
    return 1;
  }

}

// Helper function for using CUDA to add vectors in parallel.
void addWithCuda(int *c, const int *a, const int *b, unsigned int size) {
  cudaError_t cudaStatus;

  cu::device_allocator<int> alloc;

  // Choose which GPU to run on, change this on a multi-GPU system.
  cudaStatus = cudaSetDevice(0);
  cuda_check(cudaStatus);

  cu::device_unique_ptr<int[]> dev_a(alloc.allocate(size));
  cu::device_unique_ptr<int[]> dev_b(alloc.allocate(size));
  cu::device_unique_ptr<int[]> dev_c(alloc.allocate(size));

  cudaStatus = cudaMemcpyAsync(dev_a.get(), a, size*sizeof(int), cudaMemcpyHostToDevice);
  cuda_check(cudaStatus);
  cudaStatus = cudaMemcpyAsync(dev_b.get(), b, size*sizeof(int), cudaMemcpyHostToDevice);
  cuda_check(cudaStatus);

  // Launch a kernel on the GPU with one thread for each element.
  addKernel<<<1, size>>>(dev_c.get(), dev_a.get(), dev_b.get());

  // Check for any errors launching the kernel
  cudaStatus = cudaGetLastError();
  cuda_check(cudaStatus);
    
  cudaStatus = cudaMemcpyAsync(c, dev_c.get(), size*sizeof(int), cudaMemcpyDeviceToHost);
  cuda_check(cudaStatus);

  // cudaDeviceSynchronize waits for the kernel to finish, and returns
  // any errors encountered during the launch.
  cudaStatus = cudaDeviceSynchronize();
  cuda_check(cudaStatus);

}

cuFFT: フーリエ変換で雑音を消す

フーリエ変換(Fourier Transform)は信号処理/解析のド定番。

周期を持ったあらゆる波は異なる周波数のサイン波の重ね合わせで作り出すことができ、与えられた波形から、その成分(サイン波)を割り出すのがフーリエ変換、CUDAにはフーリエ変換ライブラリ: cuFFT が入ってます。軽く使ってみましょうね。

まず入力波を作ります。350,400,450Hzのサイン波を重ね合わせ、さらに一様乱数で生成したノイズを乗せましょう。

  const size_t N = 44100U; // データ数(44.1kHzサンプリングでの1秒分)

  vector<float> signal(N);
  vector<float> h_in(N);

  // 振幅 ±2 のホワイト・ノイズ
  mt19937 mt;
  uniform_real_distribution<float> rnd(-2.0f, 2.0f);

  // 350,400,450Hzのサイン波にノイズを乗せる
  float omega = 2.0f * 3.1416f / N;
  for ( unsigned int i = 0; i < N; ++i ) {
    signal[i] = 
       sinf(omega * 350.0f * (float)i) * 1.0f +
       sinf(omega * 400.0f * (float)i) * 0.8f +
       sinf(omega * 450.0f * (float)i) * 0.6f ;
    h_in[i]   = signal[i] + rnd(mt);
  }

こんなのができました。'赤'はサイン波の合成、それにノイズを乗せたのが'青'です。

f:id:Episteme:20161028220303p:plain

この(ノイズまみれの)信号にフーリエ変換を施します。

  // device-memoryの確保(入/出力兼用)
  float* d_real = nullptr;
  cudaMalloc(&d_real, N*sizeof(float));
  float2* d_cplx = reinterpret_cast<float2*>(d_real);

  // フーリエ変換
  cudaMemcpy(d_real, h_in.data(), N*sizeof(float), cudaMemcpyHostToDevice);

  cufftHandle plan_f;
  cufftPlan1d(&plan_f, N, CUFFT_R2C, 1); // Real to Complex (forward)
  cufftExecR2C(plan_f, d_real, d_cplx);

  vector<float2> h_mid(N/2); // スペクトル(フーリエ変換の結果)
  cudaMemcpy(h_mid.data(), d_cplx, N*sizeof(float), cudaMemcpyDeviceToHost);

変換結果がコレ。

f:id:Episteme:20161028220331p:plain

350,400,450に大きなピークが見られますね。ノイズは様々な周波数の波がちょっとずつ重なったものなのでグラフの底に貼りつく'モジョモジョ'した部分に現れます。

で、このデータから300Hz以下と500Hz以上の部分をばっさり削ってしまいます。帯域フィルタ(band-pass filter)ってやつです。

  // band-pass filter
  // 300Hz以下/500Hz以上の信号をカットする
  cudaMemset(d_cplx     , 0,      300U  * sizeof(float2));
  cudaMemset(d_cplx+500U, 0, (N/2-500U) * sizeof(float2));

よーするに邪魔なノイズ成分の多くを削り取ったことになります。

しかるのち逆フーリエ変換をかけて、周波数軸から時間軸に戻します。

  // 逆フーリエ変換
  cufftHandle plan_i;
  cufftPlan1d(&plan_i, N, CUFFT_C2R, 1); // Complex to Real (inverse)
  cufftExecC2R(plan_i, d_cplx, d_real);

  // 結果の出力
  vector<float>  h_out(N);
  cudaMemcpy(h_out.data(), d_real, N*sizeof(float), cudaMemcpyDeviceToHost);

結果がコレ。ノイズが消えました♪

f:id:Episteme:20161028220401p:plain

コチラ↓が全コード:

/*
 * Noise Reduction with cuFFT
 */
#include <cuda_runtime.h>
#include <cufft.h>

#include <iostream>
#include <random>
#include <vector>
#include <cmath>

using namespace std;

int main() {

  const size_t N = 44100U; // データ数(44.1kHzサンプリングでの1秒分)

  vector<float> signal(N);
  vector<float> h_in(N);

  // 振幅 ±2 のホワイト・ノイズ
  mt19937 mt;
  uniform_real_distribution<float> rnd(-2.0f, 2.0f);

  // 350,400,450Hzのサイン波にノイズを乗せる
  float omega = 2.0f * 3.1416f / N;
  for ( unsigned int i = 0; i < N; ++i ) {
    signal[i] = 
       sinf(omega * 350.0f * (float)i) * 1.0f +
       sinf(omega * 400.0f * (float)i) * 0.8f +
       sinf(omega * 450.0f * (float)i) * 0.6f ;
    h_in[i]   = signal[i] + rnd(mt);
  }

  // device-memoryの確保(入/出力兼用)
  float* d_real = nullptr;
  cudaMalloc(&d_real, N*sizeof(float));
  float2* d_cplx = reinterpret_cast<float2*>(d_real);

  // フーリエ変換
  cudaMemcpy(d_real, h_in.data(), N*sizeof(float), cudaMemcpyHostToDevice);

  cufftHandle plan_f;
  cufftPlan1d(&plan_f, N, CUFFT_R2C, 1); // Real to Complex (forward)
  cufftExecR2C(plan_f, d_real, d_cplx);

  vector<float2> h_mid(N/2); // スペクトル(フーリエ変換の結果)
  cudaMemcpy(h_mid.data(), d_cplx, N*sizeof(float), cudaMemcpyDeviceToHost);

  // band-pass filter
  // 300Hz以下/500Hz以上の信号をカットする
  cudaMemset(d_cplx     , 0,      300U  * sizeof(float2));
  cudaMemset(d_cplx+500U, 0, (N/2-500U) * sizeof(float2));

  // 逆フーリエ変換
  cufftHandle plan_i;
  cufftPlan1d(&plan_i, N, CUFFT_C2R, 1); // Complex to Real (inverse)
  cufftExecC2R(plan_i, d_cplx, d_real);

  // 結果の出力
  vector<float>  h_out(N);
  cudaMemcpy(h_out.data(), d_real, N*sizeof(float), cudaMemcpyDeviceToHost);

  cout << "signal, noised, processed, spectrum" << endl;
  for ( unsigned int i = 0; i < 500; ++i ) {
    cout << signal[i] << ',' 
         << h_in[i] << ',' 
         << h_out[i]/N << ',' 
         << cuCabsf(h_mid[i]) << endl;
  } 

  cudaFree(d_real);
  cufftDestroy(plan_f);
  cufftDestroy(plan_i);

}

(original: 2015-05-21 #36 #37)