今回は簡単な例題として、以前作成した「二分探索木」をテンプレートを使って書き直してみましょう。クラス名は Tree とします。なお、C++の標準ライブラリには <map> や <set> などの平衡二分木が用意されているので、私たちが単純な二分木を作る必要はありませんが、テンプレートのお勉強ということで、あえてプログラムを作ってみましょう。
最初は、節 Node と 二分木 Tree のクラステンプレートを定義します。
リスト : 節の定義
template<class T> class Node {
T item;
Node* left;
Node* right;
public:
explicit Node(const T& x)
: item(x), left(0), right(0) { }
explicit Node(T&& x)
: item(move(x)), left(0), right(0) { }
T& get_item() { return item; }
Node* get_left() const { return left; }
Node* get_right() const { return right; }
void put_item(const T& x) { item = x; }
void put_item(T&& x) { item = move(x); }
void put_left(Node* l) { left = l; }
void put_right(Node* r) { right = r; }
};
コンストラクタは引数に const T& を受け取るものと右辺値参照 (T&&) を受け取るものの 2 つを用意します。二分探索木の場合、要素 item の値を書き換えると、二分探索木の条件を満たさなくなる危険性がありますが、move を使いたい場合があるので、返り値は T& とします。あとはデータ型を int から T に変更するだけです。
リスト : 二分探索木の定義
template<class T> class Tree {
Node<T>* root;
public:
Tree() : root(0) { }
~Tree() { destroy_node(root); }
// コピーコンストラクタ
Tree(const Tree& tree) {
root = copy_node(tree.root);
}
// 代入演算子
Tree& operator=(const Tree& tree) {
if (this != &tree) {
destroy_node(root);
root = copy_node(tree.root);
}
return *this;
}
// ムーブコンストラクタ
Tree(Tree&& tree) : root(tree.root) {
tree.root = 0;
}
// ムーブ代入演算子
Tree& operator=(Tree&& tree) {
if (this != &tree) {
destroy_node(root);
root = tree.root;
tree.root = 0;
}
return *this;
}
// メンバ関数
bool empty() const { return !root; }
bool search(const T& x) const { return search_node(x, root); }
const T& min() const {
if (!root) throw std::runtime_error("Tree::min empty tree");
return search_min(root);
}
const T& max() const {
if (!root) throw std::runtime_error("Tree::max empty tree");
return search_max(root);
}
void insert(const T& x) { root = insert_node(x, root); }
void insert(T&& x) { root = insert_node(forward<T>(x), root); }
void remove(const T& x) { root = delete_node(x, root); }
void remove_min() {
if (root) root = delete_min(root);
}
void remove_max() {
if (root) root = delete_max(root);
}
//
// イテレータ (省略)
//
};
Tree のテンプレート仮引数は T なので、節のデータ型は Node<T> になります。メンバ変数 root のデータ型は Node<T>* になります。あとは、ムーブコンストラクタとムーブ代入演算子を定義して、データ型を int から T に変更します。このとき、要素の値 (item) を返すメンバ関数 max と min は、返り値の型を const T& とすることに注意してください。
それから、引数に右辺値参照を受け取るメンバ関数 insert を追加します。これに対応するため、作業用関数 insert_node に右辺値参照を受け取る関数を追加します。insert_node を呼び出すとき、引数 x に forward を適用することをお忘れなく。
次は insert_node に右辺値参照を受け取る関数を追加します。
リスト : データの挿入 (右辺値参照)
template<class T>
Node<T>* insert_node(T&& x, Node<T>* node)
{
if (!node) return new Node<T>(forward<T>(x));
if (x < node->get_item())
node->put_left(insert_node(forward<T>(x), node->get_left()));
else if (x > node->get_item())
node->put_right(insert_node(forward<T>(x), node->get_right()));
return node;
}
引数 x は右辺値参照なので、Node のコンストラクタを呼び出すときと、insert_node を再帰呼び出しするときは、引数 x に forward を適用してください。これで、引数 x の所有権を新しい節に移動することができます。
もう一つ、データを削除する作業用関数 delete_node を修正します。
リスト : データの削除
template<class T>
Node<T>* delete_node(const T& x, Node<T>* node)
{
if (!node) return node;
if (x == node->get_item()) {
if (!node->get_left()) {
Node<T>* x = node->get_right();
delete node;
return x;
}
if (!node->get_right()) {
Node<T>* x = node->get_left();
delete node;
return x;
}
node->put_item(move(search_min(node->get_right())));
node->put_right(delete_min(node->get_right()));
} else if (x < node->get_item())
node->put_left(delete_node(x, node->get_left()));
else
node->put_right(delete_node(x, node->get_right()));
return node;
}
右部分木から最小値を探して、それを item にセットします。このとき、search_min の返り値に move を適用して、ムーブ操作が定義されていれば、それを行うようにします。作業用関数 search_min と search_max の返り値の型は const T& ではなく、T& とすることに注意してください。
次はイテレータを作ります。
リスト : イテレータ
class Iterator : public iterator<forward_iterator_tag, T> {
vector<Node<T>*> stack;
// 次の node へ進める
void next_node(Node<T>* node) {
while (node) {
stack.push_back(node);
node = node->get_left();
}
}
public:
Iterator(Tree* tree, bool end) {
if (!end) next_node(tree->root);
}
// 間接参照
const T& operator*() const {
return stack.back()->get_item();
}
const T* operator->() const {
return &(stack.back()->get_item());
}
// 前置の ++ 演算子
Iterator& operator++() {
Node<T>* node = stack.back();
stack.pop_back();
next_node(node->get_right());
return *this;
}
// 後置の ++ 演算子
Iterator operator++(int n) {
Iterator iter(*this);
Node<T>* node = stack.back();
stack.pop_back();
next_node(node->get_right());
return iter;
}
// 比較演算子
bool operator==(const Iterator& iter) {
return stack == iter.stack;
}
bool operator!=(const Iterator& iter) {
return stack != iter.stack;
}
};
Iterator begin() { return Iterator(this, false); }
Iterator end() { return Iterator(this, true); }
イテレータの基本的な操作は以前作成した IntVec と同じです。ただし、たどってきた節は vector<Node<T>*> stack に格納します。この場合、単純な配列よりも vector を使ったほうが簡単です。
ところで、ムーブコンストラクタとムーブ代入演算子は、宣言されていないとコンパイラが自動的に生成してくれます。ただし、デストラクタ、コピーコンストラクタ、代入演算子のどれか一つでも宣言されていると、自動生成されないことに注意してください。
デフォルトの動作は非静的メンバ変数をムーブすることなので、Iterator のメンバ変数 stack を vector で定義すると、ムーブ操作にも対応できるのでとても便利です。
それから、間接参照の演算子 * と -> の返り値を const で修飾します。これで節の値 (item) を書き換えることができなくなります。あとのプログラムは簡単なので説明は割愛します。詳細はプログラムリストをお読みください。
それでは簡単なテストを行ってみましょう。
リスト : 二分探索木のテスト
int main()
{
vector<int> a = {5, 7, 3, 4, 2, 1, 8, 6, 9};
vector<int> b = {15, 17, 13, 14, 12, 11, 18, 16, 19};
Tree<int> tree_a;
Tree<int> tree_b;
for (int x : a) tree_a.insert(x);
for (int x : tree_a) cout << x << " ";
cout << endl;
for (int x : b) tree_b.insert(x);
for (int x : tree_b) cout << x << " ";
cout << endl;
{
Tree<int> tree_c = tree_a;
tree_a = tree_b;
tree_b = tree_c;
for (int x : tree_a) cout << x << " ";
cout << endl;
for (int x : tree_b) cout << x << " ";
cout << endl;
}
{
Tree<int> tree_c = move(tree_a);
tree_a = move(tree_b);
tree_b = move(tree_c);
for (int x : tree_a) cout << x << " ";
cout << endl;
for (int x : tree_b) cout << x << " ";
cout << endl;
}
for (int x = 0; x <= 10; x++)
cout << tree_a.search(x) << " ";
cout << endl;
for (auto iter = tree_a.begin(); iter != tree_a.end(); ++iter)
cout << *iter << " ";
cout << endl;
auto iter = tree_b.begin();
while (iter != tree_b.end())
cout << *iter++ << " ";
cout << endl;
for_each(tree_a.begin(), tree_a.end(), [](int x){ cout << x << " "; });
cout << endl;
for (int i = 0; i < 9; i++) {
tree_a.remove(a[i]);
for_each(tree_a.begin(), tree_a.end(), [](int x){ cout << x << " "; });
cout << endl;
}
while (!tree_b.empty()) {
cout << tree_b.min() << endl;
cout << tree_b.max() << endl;
tree_b.remove_min();
tree_b.remove_max();
for (int x : tree_b) cout << x << " ";
cout << endl;
}
Tree<string> tree_c;
tree_c.insert("foo");
tree_c.insert("bar");
tree_c.insert("baz");
tree_c.insert("oops");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
tree_c.remove("foo");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
tree_c.remove("bar");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
tree_c.remove("oops");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
tree_c.remove("baz");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
}
$ clang++ tree.cpp $ ./a.out 1 2 3 4 5 6 7 8 9 11 12 13 14 15 16 17 18 19 11 12 13 14 15 16 17 18 19 1 2 3 4 5 6 7 8 9 1 2 3 4 5 6 7 8 9 11 12 13 14 15 16 17 18 19 0 1 1 1 1 1 1 1 1 1 0 1 2 3 4 5 6 7 8 9 11 12 13 14 15 16 17 18 19 1 2 3 4 5 6 7 8 9 1 2 3 4 6 7 8 9 1 2 3 4 6 8 9 1 2 4 6 8 9 1 2 6 8 9 1 6 8 9 6 8 9 6 9 9 11 19 12 13 14 15 16 17 18 12 18 13 14 15 16 17 13 17 14 15 16 14 16 15 15 15 bar baz foo oops bar baz oops baz oops baz
コピーコンストラクタ、代入演算子、ムーブコンストラクタ、ムーブ代入演算子は正常に動作しています。イテレータを実装すると、範囲 for 文や、for_each() など algorithm の関数も利用することができます。また、int 以外にも string を指定すると、文字列を格納することができます。興味のある方はいろいろ試してみてください。
//
// tree.cpp : 二分探索木 (テンプレート版)
//
// Copyright (C) 2015-2023 Makoto Hiroi
//
#include <iostream>
#include <stdexcept>
#include <iterator>
#include <vector>
#include <algorithm>
using namespace std;
// 節
template<class T> class Node {
T item;
Node* left;
Node* right;
public:
explicit Node(const T& x)
: item(x), left(0), right(0) { }
explicit Node(T&& x)
: item(move(x)), left(0), right(0) { }
T& get_item() { return item; }
Node* get_left() const { return left; }
Node* get_right() const { return right; }
void put_item(const T& x) { item = x; }
void put_item(T&& x) { item = move(x); }
void put_left(Node* l) { left = l; }
void put_right(Node* r) { right = r; }
};
// 探索
template<class T>
bool search_node(const T& x, Node<T>* node)
{
while (node) {
if (x == node->get_item()) return true;
else if (x < node->get_item())
node = node->get_left();
else
node = node->get_right();
}
return false;
}
// 挿入
template<class T>
Node<T>* insert_node(const T& x, Node<T>* node)
{
if (!node) return new Node<T>(x);
if (x < node->get_item())
node->put_left(insert_node(x, node->get_left()));
else if (x > node->get_item())
node->put_right(insert_node(x, node->get_right()));
return node;
}
template<class T>
Node<T>* insert_node(T&& x, Node<T>* node)
{
if (!node) return new Node<T>(forward<T>(x));
if (x < node->get_item())
node->put_left(insert_node(forward<T>(x), node->get_left()));
else if (x > node->get_item())
node->put_right(insert_node(forward<T>(x), node->get_right()));
return node;
}
// 最小値を探す
template<class T>
T& search_min(Node<T>* node)
{
while (node->get_left()) node = node->get_left();
return node->get_item();
}
// 最小値の節を削除
template<class T>
Node<T>* delete_min(Node<T>* node)
{
if (!node->get_left()) {
Node<T>* x = node->get_right();
delete node;
return x;
}
node->put_left(delete_min(node->get_left()));
return node;
}
// 最大値を探す
template<class T>
T& search_max(Node<T>* node)
{
while (node->get_right()) node = node->get_right();
return node->get_item();
}
// 最大値の節を削除
template<class T>
Node<T>* delete_max(Node<T>* node)
{
if (!node->get_right()) {
Node<T>* x = node->get_left();
delete node;
return x;
}
node->put_right(delete_max(node->get_right()));
return node;
}
// 削除
template<class T>
Node<T>* delete_node(const T& x, Node<T>* node)
{
if (!node) return node;
if (x == node->get_item()) {
if (!node->get_left()) {
Node<T>* x = node->get_right();
delete node;
return x;
}
if (!node->get_right()) {
Node<T>* x = node->get_left();
delete node;
return x;
}
node->put_item(move(search_min(node->get_right())));
node->put_right(delete_min(node->get_right()));
} else if (x < node->get_item())
node->put_left(delete_node(x, node->get_left()));
else
node->put_right(delete_node(x, node->get_right()));
return node;
}
// コピー
template<class T>
Node<T>* copy_node(Node<T>* node)
{
if (!node) return 0;
Node<T>* new_node = new Node<T>(node->get_item());
new_node->put_left(copy_node(node->get_left()));
new_node->put_right(copy_node(node->get_right()));
return new_node;
}
// 廃棄
template<class T>
void destroy_node(Node<T>* node)
{
if (node) {
destroy_node(node->get_left());
destroy_node(node->get_right());
delete node;
}
}
// 二分探索木
template<class T> class Tree {
Node<T>* root;
public:
Tree() : root(0) { }
~Tree() { destroy_node(root); }
// コピーコンストラクタ
Tree(const Tree& tree) {
root = copy_node(tree.root);
}
// 代入演算子
Tree& operator=(const Tree& tree) {
if (this != &tree) {
destroy_node(root);
root = copy_node(tree.root);
}
return *this;
}
// ムーブコンストラクタ
Tree(Tree&& tree) : root(tree.root) {
tree.root = 0;
}
// ムーブ代入演算子
Tree& operator=(Tree&& tree) {
if (this != &tree) {
destroy_node(root);
root = tree.root;
tree.root = 0;
}
return *this;
}
// メンバ関数
bool empty() const { return !root; }
bool search(const T& x) const { return search_node(x, root); }
const T& min() const {
if (!root) throw std::runtime_error("Tree::min empty tree");
return search_min(root);
}
const T& max() const {
if (!root) throw std::runtime_error("Tree::max empty tree");
return search_max(root);
}
void insert(const T& x) { root = insert_node(x, root); }
void insert(T&& x) { root = insert_node(forward<T>(x), root); }
void remove(const T& x) { root = delete_node(x, root); }
void remove_min() {
if (root) root = delete_min(root);
}
void remove_max() {
if (root) root = delete_max(root);
}
// イテレータ
class Iterator : public iterator<forward_iterator_tag, T> {
vector<Node<T>*> stack;
// 次の node へ進める
void next_node(Node<T>* node) {
while (node) {
stack.push_back(node);
node = node->get_left();
}
}
public:
Iterator(Tree* tree, bool end) {
if (!end) next_node(tree->root);
}
// 間接参照
const T& operator*() const {
return stack.back()->get_item();
}
const T* operator->() const {
return &(stack.back()->get_item());
}
// 前置の ++ 演算子
Iterator& operator++() {
Node<T>* node = stack.back();
stack.pop_back();
next_node(node->get_right());
return *this;
}
// 後置の ++ 演算子
Iterator operator++(int n) {
Iterator iter(*this);
Node<T>* node = stack.back();
stack.pop_back();
next_node(node->get_right());
return iter;
}
// 比較演算子
bool operator==(const Iterator& iter) {
return stack == iter.stack;
}
bool operator!=(const Iterator& iter) {
return stack != iter.stack;
}
};
Iterator begin() { return Iterator(this, false); }
Iterator end() { return Iterator(this, true); }
};
int main()
{
vector<int> a = {5, 7, 3, 4, 2, 1, 8, 6, 9};
vector<int> b = {15, 17, 13, 14, 12, 11, 18, 16, 19};
Tree<int> tree_a;
Tree<int> tree_b;
for (int x : a) tree_a.insert(x);
for (int x : tree_a) cout << x << " ";
cout << endl;
for (int x : b) tree_b.insert(x);
for (int x : tree_b) cout << x << " ";
cout << endl;
{
Tree<int> tree_c = tree_a;
tree_a = tree_b;
tree_b = tree_c;
for (int x : tree_a) cout << x << " ";
cout << endl;
for (int x : tree_b) cout << x << " ";
cout << endl;
}
{
Tree<int> tree_c = move(tree_a);
tree_a = move(tree_b);
tree_b = move(tree_c);
for (int x : tree_a) cout << x << " ";
cout << endl;
for (int x : tree_b) cout << x << " ";
cout << endl;
}
for (int x = 0; x <= 10; x++)
cout << tree_a.search(x) << " ";
cout << endl;
for (auto iter = tree_a.begin(); iter != tree_a.end(); ++iter)
cout << *iter << " ";
cout << endl;
auto iter = tree_b.begin();
while (iter != tree_b.end())
cout << *iter++ << " ";
cout << endl;
for_each(tree_a.begin(), tree_a.end(), [](int x){ cout << x << " "; });
cout << endl;
for (int i = 0; i < 9; i++) {
tree_a.remove(a[i]);
for_each(tree_a.begin(), tree_a.end(), [](int x){ cout << x << " "; });
cout << endl;
}
while (!tree_b.empty()) {
cout << tree_b.min() << endl;
cout << tree_b.max() << endl;
tree_b.remove_min();
tree_b.remove_max();
for (int x : tree_b) cout << x << " ";
cout << endl;
}
Tree<string> tree_c;
tree_c.insert("foo");
tree_c.insert("bar");
tree_c.insert("baz");
tree_c.insert("oops");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
tree_c.remove("foo");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
tree_c.remove("bar");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
tree_c.remove("oops");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
tree_c.remove("baz");
for (auto& x : tree_c) cout << x << " ";
cout << endl;
}