読者です 読者をやめる 読者になる 読者になる

東方算程譚

επιστημηがヨタをこく、弾幕とは無縁のCUDAなタワゴト

zip_iteratorによるSoAのソート

ふたつの配列: int x[N], y[N] があって、このふたつを連動してソートしたい、ソート順は「xの小さい順だけど、xが同じときはyの小さい順」...あるある。 CUDA世界ではデータを SoA(Structure of Array)で構成することが多いのでなおさらよくあるシチュエーションです。

thrust::sort()はひとつの配列をソートするのは簡単だけど、複数の配列を連動してソートさせるにはどーすりゃいいんだと。

thrust::zip_iteratorっていう小賢しいiteratorが用意されています。 複数のiteratorを一本のzip_iteratorにまとめることができ、zip_iteratorの指す先が移動するとまとめられたiteratorそれぞれが同じだけ移動します。 加えてzip_iteratorに対しoperator*()で値を取り出すと各iteratorの指す先の値がtupleにまとめられて返ってくるですよ。

#include <thrust/iterator/zip_iterator.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>

#include <random>
#include <algorithm>
#include <numeric>

#include <iterator>
#include <iostream>

int main() {
  using namespace std;

  const int N = 90;
  thrust::host_vector<int> x(N);
  thrust::host_vector<int> y(N);

  { // テキトーな組を生成する
    mt19937 gen;
    iota(begin(x), begin(x)+N/2, 10);
    iota(begin(x)+N/2, end(x), 10);
    iota(begin(y), end(y), 10);
    shuffle(begin(x), end(x), gen);
    shuffle(begin(y), end(y), gen);
  }

  cout << "--- before:\n";
  for ( int i = 0; i < N; ++i ) { cout << x[i] << '-' << y[i] << ' '; }
  cout << endl;

  // deviceにコピー
  thrust::device_vector<int> dx = x;
  thrust::device_vector<int> dy = y;

  // dx, dy のbegin/end() をzip_iteratorでまとめ、
  auto first = thrust::make_zip_iterator(thrust::make_tuple(begin(dx), begin(dy)));
  auto last  = thrust::make_zip_iterator(thrust::make_tuple(end(dx)  , end(dy)  ));

  // dx,dy を連動してソートする
  thrust::sort(first, last); 

  // hostに書き戻し
  x = dx;
  y = dy;
  cout << "--- after(ascending):\n";
  for ( int i = 0; i < N; ++i ) { cout << x[i] << '-' << y[i] << ' '; }
  cout << endl;

  // もう一度、今度は(比較ファンクタを与えて)降順で。
  thrust::sort(first, last, 
               [] __device__ (const auto& a, const auto& b) -> bool { return b < a; });

  x = dx;
  y = dy;
  cout << "--- after(descending):\n";
  for ( int i = 0; i < N; ++i ) { cout << x[i] << '-' << y[i] << ' '; }
  cout << endl;
}

f:id:Episteme:20161128195150p:plain