ウェーブレット行列を実装した

元のデータに対して十分小さいサイズでありながら各種操作を高速に処理でき、文字列だけでなく2次元データやグラフデータまで表現できるウェーブレット行列を実装してみた。書籍「高速文字列解析の世界」やブログ記事を読んで、やっとのことで実装できた。
ウェーブレット行列の各操作のオーダーの表記では、文字集合のサイズをσ、文字列長をnとしている。

2014/8/25:プログラム修正

// Returns the number of set bits in the 32-bit word x.
// Branch-free parallel count: neighbouring fields are summed in place, doubling
// the field width each step until one total remains.
inline int popCount(unsigned int x){
	// 2-bit fields: each field becomes the count of its two bits.
	x -= (x>>1) & 0x55555555;
	// 4-bit fields.
	x = (x & 0x33333333) + ((x>>2) & 0x33333333);
	// 8-bit fields (a byte count is at most 8, so one mask after the add is enough).
	x = (x + (x>>4)) & 0x0f0f0f0f;
	// Fold the four byte counts into the lowest byte.
	x += x>>8;
	x += x>>16;
	// The total is at most 32, so only the low 6 bits are meaningful.
	return x & 0x3f;
}
// Returns the bit index (0 = least significant) of the (k+1)-th set bit of pop1,
// counting from the least significant side.
// k is expected to satisfy 0 <= k < (number of set bits); other values are not
// rejected and yield an index that does not correspond to a matching bit.
inline int kthRightmostPop(unsigned int pop1,int k){
	// counts[w] holds, in every field of width 2^w bits, the number of set bits
	// of pop1 inside that field.
	unsigned int counts[5];
	counts[0] = pop1;
	counts[1] = (counts[0]>>1 & 0x55555555)+(counts[0] & 0x55555555);
	counts[2] = (counts[1]>>2 & 0x33333333)+(counts[1] & 0x33333333);
	counts[3] = (counts[2]>>4 & 0x0f0f0f0f)+(counts[2] & 0x0f0f0f0f);
	counts[4] = (counts[3]>>8 & 0x00ff00ff)+(counts[3] & 0x00ff00ff);
	// Mask that extracts a single field of width 2^w.
	static const unsigned int fieldMask[5] = {0x1,0x3,0xf,0xff,0xffff};

	// The comparisons are done in unsigned arithmetic, as implicit conversion
	// would do when comparing an int against an unsigned count.
	unsigned int remaining = static_cast<unsigned int>(k);
	int pos = 0;
	// Binary descent: look at the lower half of the current window; if it does
	// not contain enough set bits, skip it and continue in the upper half.
	for(int w=4;w>=0;w--){
		const unsigned int lower = (counts[w]>>pos) & fieldMask[w];
		if(lower <= remaining){
			remaining -= lower;
			pos |= 1<<w;
		}
	}
	return pos;
}


// Succinct bit vector with rank/select support.
// Memory usage: about 2n bits (n payload bits plus one 32-bit cumulative count
// per 32-bit word).
// Usage: init(size), set(k) for every 1 bit, build(), then access/rank/select.
// Positions passed to set/access/rank are not range-checked.
class BitVector{
	int n;      // number of addressable bits
	int blocks; // number of 32-bit words; (n>>5)+1 so that rank(n) stays in range
	int ones;   // total number of set bits, computed by build()
	std::vector<unsigned int> B; // payload, bit k lives in B[k>>5] at bit (k&31)
	std::vector<int> r;          // r[i] = number of 1 bits in words [0,i)
public:
	BitVector():n(0),blocks(0),ones(0){}
	BitVector(int size):n(0),blocks(0),ones(0){
		init(size);
	}
	// Resizes to `size` bits and clears every bit and every cached count.
	void init(int size){
		n = size;
		blocks = (n>>5)+1;
		ones = 0;
		B.assign(blocks ,0);
		r.assign(blocks ,0);
	}
	// Sets bit k to 1. build() must be called (again) before rank/select.
	void set(int k){
		// 1u: shifting a signed 1 by 31 would touch the sign bit.
		B[k>>5] |= 1u<<(k&31);
	}
	// Recomputes the per-word cumulative counts and the total number of 1 bits.
	void build(){
		int sum = 0;
		for(int i=0;i<blocks;i++){
			r[i] = sum;
			sum += popCount(B[i]);
		}
		ones = sum;
	}
	// Returns the value of bit k.
	bool access(int k)const{
		return ((B[k>>5]>>(k&31)) & 1u) != 0;
	}
	// Number of 1 bits in [0,k). Valid for 0 <= k <= n.
	int rank(int k)const{
		// The mask is built in unsigned arithmetic: (1<<31)-1 on a signed int
		// overflows, which is undefined behaviour.
		const unsigned int below = (1u<<(k&31))-1u;
		return r[k>>5] + popCount(B[k>>5] & below);
	}
	// Position of the (k+1)-th 1 bit, or -1 when k is negative or there are
	// not that many 1 bits.
	// O(log n)
	int select1(int k)const{
		if(k<0 || ones<=k)return -1;
		// Find the last word lb with r[lb] <= k; it contains the wanted bit.
		int lb=0,ub=blocks;
		while(ub-lb>1){
			int m = (lb+ub)/2;
			if(k<r[m])ub=m;
			else lb=m;
		}
		return lb<<5 | kthRightmostPop(B[lb],k-r[lb]);
	}
	// Position of the (k+1)-th 0 bit, or -1 when k is negative or there are
	// not that many 0 bits inside [0,n).
	// O(log n)
	int select0(int k)const{
		if(k<0 || n-ones<=k)return -1;
		// (m<<5)-r[m] is the number of 0 bits in words [0,m).
		int lb=0,ub=blocks;
		while(ub-lb>1){
			int m = (lb+ub)/2;
			if(k<(m<<5)-r[m])ub=m;
			else lb=m;
		}
		return lb<<5 | kthRightmostPop(~B[lb],k-((lb<<5)-r[lb]));
	}
};


//ウェーブレット行列
//Σ=[A-Za-z]
class WaveletMatrix{
	static const int BITLEN = 6;//文字のビット長
	int len;
	BitVector bv[BITLEN];
	int encode(char c)const{//6bit
		if('A'<=c&&c<='Z')return c-'A';
		return c-'a'+('Z'-'A'+1);
	}
	char decode(int n)const{
		if(0<=n&&n<26)return n+'A';
		return n-26+'a';
	}
	struct Node{
		int height,s,e,code;
		Node(){}
		Node(int a,int b,int c,int d):height(a),s(b),e(c),code(d){};
		bool operator <(Node a)const{return e-s<a.e-a.s;}
	};
public:
	int length()const{
		return len;
	}
	WaveletMatrix(const string &str){
		init(str);
	}
	//O(n logσ)
	void init(const string &str){
		len = str.size();
		for(int i=0;i<BITLEN;i++){
			bv[i].init(str.size());
		}
		int *bucket;
		bucket = new int[2*len];
		int *cur,*next;
		cur = bucket;
		next = bucket+len;
		int rank0[BITLEN]={0};
		for(int i=0;i<len;i++){
			int n = encode(str[i]);
			cur[i] = n;
			for(int j=BITLEN-1;j>=0;j--){
				if((n&1<<j)==0)rank0[j]++;
			}
		}
		for(int i=BITLEN-1;;i--){
			for(int j=0;j<len;j++){
				if(cur[j]&1<<i)bv[i].set(j);
			}
			bv[i].build();
			if(i==0)break;
			int idx0=0,idx1=rank0[i];
			for(int j=0;j<len;j++){
				if(cur[j]&1<<i)next[idx1++]=cur[j];
				else next[idx0++]=cur[j];
			}
			swap(cur,next);
		}
		delete[] bucket;
	}
	//O(logσ)
	char access(int k)const{
		int code=0;
		for(int i=BITLEN-1;i>0;i--){
			if(bv[i].access(k)){
				code |= 1<<i;
				k = len-bv[i].rank(len)+bv[i].rank(k);
			}else{
				k = k-bv[i].rank(k);
			}
		}
		if(bv[0].access(k))code |= 1;
		return decode(code);
	}
	//[s,e)中のcの個数
	//O(logσ)
	int rank(char c,int s,int e)const{
		int n = encode(c);
		for(int i=BITLEN-1;i>=0;i--){
			int ssum = bv[i].rank(s);
			int esum = bv[i].rank(e);
			if(n&1<<i){
				s = len-bv[i].rank(len) + ssum;
				e = s + esum-ssum;
			}else{
				s = s-ssum;
				e = e-esum;
			}
		}
		return e-s;
	}
	//k+1番目のcの位置
	//O(log n logσ)
	int select(char c,int k)const{
		int n = encode(c);
		int s=0,e=len;
		for(int i=BITLEN-1;i>=0;i--){
			int ssum = bv[i].rank(s);
			int esum = bv[i].rank(e);
			if(n&1<<i){
				s = len-bv[i].rank(len) + ssum;
				e = s + esum-ssum;
			}else{
				s = s-ssum;
				e = e-esum;
			}
		}
		int p = s+k;
		if(e<=p)return -1;
		for(int i=0;i<BITLEN;i++){
			if(n&1<<i)p = bv[i].select1(p-(len-bv[i].rank(len)));
			else p = bv[i].select0(p);
		}
		return p;
	}
	//[s,e)中で出現頻度が多い順にk個返す
	//O(min(e-s,σ)logσ) ,頻度が偏っていればO(klogσ)
	vector<pair<char,int> > topk(int s,int e,int k)const{
		vector<pair<char,int> > res;
		priority_queue<Node> pq;
		pq.push(Node(BITLEN-1,s,e,0));
		while(!pq.empty() && 0<=k){
			Node a = pq.top();
			pq.pop();
			if(a.height==-1){
				res.push_back(make_pair(decode(a.code),a.e-a.s));
				k--;
				continue;
			}
			int ssum = bv[a.height].rank(a.s);
			int esum = bv[a.height].rank(a.e);
			int num1 = esum-ssum;
			int num0 = a.e-a.s-num1;
			int s = a.s-ssum;
			pq.push(Node(a.height-1,s,s+num0,a.code));
			s = len-bv[a.height].rank(len) + ssum;
			pq.push(Node(a.height-1,s,s+num1,a.code|1<<a.height));
		}
		return res;
	}
	//[s,e)中のr+1番目に大きい文字
	//O(logσ)
	char quantile(int s,int e,int r)const{
		int code=0;
		for(int i=BITLEN-1;i>=0;i--){
			int ssum = bv[i].rank(s);
			int esum = bv[i].rank(e);
			int num1 = esum-ssum;
			if(num1<=r){
				r -= num1;
				s = s-ssum;
				e = e-esum;
			}else{
				code |= 1<<i;
				s = len-bv[i].rank(len) + ssum;
				e = s + num1;
			}
			if(s==e)return '\0';
		}
		return decode(code);
	}
	//[s,e)中のx未満の文字の個数
	int rank_lt(int s,int e,char x)const{
		int n = encode(x);
		int res=0;
		for(int i=BITLEN-1;i>=0;i--){
			int ssum = bv[i].rank(s);
			int esum = bv[i].rank(e);
			if(n&1<<i){
				res += e-s-(esum-ssum);
				s = len-bv[i].rank(len) + ssum;
				e = s + esum-ssum;
			}else{
				s = s-ssum;
				e = e-esum;
			}
			if(s==e)return res;
		}
		return res;
	}
	//[s,e)中の x<=c<y となる文字の個数
	//O(logσ)
	int rangefreq(int s,int e,char x,char y)const{
		return rank_lt(s,e,y)-rank_lt(s,e,x);
	}
	//[s,e)中に出現する文字を大きい順に頻度と共にk個返す
	//O(k logσ)
	vector<pair<char,int> > rangemaxk(int s,int e,int k)const{
		Node sta[BITLEN+1];
		int sp=0;
		vector<pair<char,int> > res;
		sta[sp++] = Node(BITLEN-1,s,e,0);
		while(sp && 0<=k){
			Node a = sta[--sp];
			if(a.height==-1){
				res.push_back(make_pair(decode(a.code),a.e-a.s));
				k--;
				continue;
			}
			int ssum = bv[a.height].rank(a.s);
			int esum = bv[a.height].rank(a.e);
			int num1 = esum-ssum;
			int num0 = a.e-a.s-num1;
			int s = a.s-ssum;
			if(num0)sta[sp++] = Node(a.height-1,s,s+num0,a.code);
			s = len-bv[a.height].rank(len) + ssum;
			if(num1)sta[sp++] = Node(a.height-1,s,s+num1,a.code|1<<a.height);
		}
		return res;
	}
	//[s,e)中に出現する文字を小さい順に頻度と共にk個返す
	//O(k logσ)
	vector<pair<char,int> > rangemink(int s,int e,int k)const{
		Node sta[BITLEN+1];
		int sp=0;
		vector<pair<char,int> > res;
		sta[sp++] = Node(BITLEN-1,s,e,0);
		while(sp && 0<=k){
			Node a = sta[--sp];
			if(a.height==-1){
				res.push_back(make_pair(decode(a.code),a.e-a.s));
				k--;
				continue;
			}
			int ssum = bv[a.height].rank(a.s);
			int esum = bv[a.height].rank(a.e);
			int num1 = esum-ssum;
			int num0 = a.e-a.s-num1;
			int s = len-bv[a.height].rank(len) + ssum;
			if(num1)sta[sp++] = Node(a.height-1,s,s+num1,a.code|1<<a.height);
			s = a.s-ssum;
			if(num0)sta[sp++] = Node(a.height-1,s,s+num0,a.code);
		}
		return res;
	}
	//[s,e)中のx<=c<yとなる文字cを頻度と共に列挙する
	//返す文字種類をkとすると O(k logσ)
	vector<pair<char,int> > rangelist(int s,int e,char x,char y)const{
		int ub = encode(y);
		int lb = encode(x);
		Node sta[BITLEN+1];
		int sp=0;
		vector<pair<char,int> > res;
		sta[sp++] = Node(BITLEN-1,0,len,0);
		while(sp){
			Node a = sta[--sp];
			if(a.height==-1){
				res.push_back(make_pair(decode(a.code),a.e-a.s));
				continue;
			}
			int ssum = bv[a.height].rank(a.s);
			int esum = bv[a.height].rank(a.e);
			int num1 = esum-ssum;
			int num0 = a.e-a.s-num1;
			int s = len-bv[a.height].rank(len) + ssum;
			if((a.code|1<<a.height)<ub && num1)sta[sp++] = Node(a.height-1,s,s+num1,a.code|1<<a.height);
			s = a.s-ssum;
			if(lb<=(a.code) && num0)sta[sp++] = Node(a.height-1,s,s+num0,a.code);
		}
		return res;
	}
};