ALDS1_7_B のC++コードを読み解き、Rustで再実装する

60903 ワード

二分木の表現

各ノードに、次の情報を持たせる。

  • 親の情報
  • 自分の子のうち右の子の情報
  • 左の子の情報
struct Node {
    int parent;  // index of the parent node, or NIL (-1) if there is none
    int left;    // index of the left child, or NIL (-1) if there is none
    int right;   // index of the right child, or NIL (-1) if there is none
};

C++のコード

#include<cstdio>
#define MAX 10000
#define NIL -1

// Links of one tree node, expressed as indices into the global array T.
// A value of NIL (-1) means the corresponding neighbour does not exist.
struct Node
{
    int parent; // NIL for the root
    int left;   // NIL when there is no left child
    int right;  // NIL when there is no right child
};

Node T[MAX];  // T[i]: parent/left/right links of node i
int n;        // number of nodes, read by main()
int D[MAX];   // D[i]: depth of node i, written by setDepth()
int H[MAX];   // H[i]: height of node i, written by setHeight()

void setDepth(int u, int d) {
    if (u==NIL) return;
    D[u] = d;
    setDepth(T[u].left, d+1);
    setDepth(T[u].right, d+1);
}

// Returns the height of the subtree rooted at u (number of edges on the
// longest downward path to a leaf) and records it in H[u]; every descendant's
// entry in H is filled in along the way.
// u must be a real node index: T[u] is read without a NIL check.
int setHeight(int u) {
    int leftHeight = (T[u].left == NIL) ? 0 : setHeight(T[u].left) + 1;
    int rightHeight = (T[u].right == NIL) ? 0 : setHeight(T[u].right) + 1;
    // Keep the taller of the two subtrees.
    H[u] = (leftHeight < rightHeight) ? rightHeight : leftHeight;
    return H[u];
}

// Returns the other child of u's parent, or NIL when u is the root or has no
// sibling (its parent has only one child).
int getSibling(int u) {
    const int p = T[u].parent;
    if (p == NIL) return NIL;  // the root has no parent, hence no sibling

    // Whichever child slot of the parent is occupied by someone other than u.
    const int l = T[p].left;
    const int r = T[p].right;
    if (l != NIL && l != u) return l;
    if (r != NIL && r != u) return r;
    return NIL;
}

void print(int u) {
    printf("node %d: ",u);
    printf("parent = %d, ",T[u].parent);
    printf("sibling = %d, ", getSibling(u));
    int deg = 0;
    if ( T[u].left != NIL) deg++ ;
    if ( T[u].right != NIL) deg++ ;
    printf("degree = %d, ",deg);
    printf("depth = %d, ",D[u]);
    printf("height = %d, ",H[u]);

    if (T[u].parent == NIL) {
        printf("root\n");
    } else if (T[u].left == NIL && T[u].right ==NIL) {
        printf("leaf\n");
    } else {
        printf("internal node\n");
    }
}

int main(){
    int v,l,r, root=0;
    scanf("%d", &n);

    for (int i = 0; i < n; i++) T[i].parent = NIL;
    for (int i = 0; i < n; i++) {
        scanf("%d %d %d", &v, &l, &r);
        T[v].left=l;
        T[v].right=r;
        if (l!=NIL) T[l].parent = v;
        if (r!=NIL) T[r].parent = v;
    }

    for (int i = 0; i < n; i++) if (T[i].parent==NIL) root = i;

    setDepth(root, 0);
    setHeight(root);

    for (int i=0; i<n; i++) print(i);
    return 0;
    
}

Rustでの再実装

Node の parent, left, right を Option<usize> にしたことで、「ノードが存在しない」状態を None で表現できるようになった。その反面、パターンマッチなどの処理が増えた。競技プログラミングでは NIL = -1 のような番兵値を使う方が速く書けるので実用的だろうが、本来は Option 型を使うべきだと考えている。

use std::io;

/// Reads one line from standard input and parses its whitespace-separated
/// tokens as `T`.
///
/// Tokens that fail to parse are skipped. Splitting on any whitespace rather
/// than a single `' '` means a tab-separated or multiply-spaced line is still
/// split into its individual numbers. With `split(' ')`, a tab-separated line
/// became a single unparseable token and produced an empty `Vec`.
///
/// # Panics
/// Panics if reading from standard input fails.
fn read<T: std::str::FromStr>() -> Vec<T> {
    let mut buf = String::new();
    io::stdin()
        .read_line(&mut buf)
        .expect("failed to read a line from stdin");
    buf.split_whitespace()
        .filter_map(|token| token.parse().ok())
        .collect()
}

/// One vertex of the binary tree, identified by its index in `BinaryTree::node`.
///
/// Each link is `None` when that neighbour does not exist.
#[derive(Debug, Clone)]
struct Node {
    parent: Option<usize>,
    left: Option<usize>,
    right: Option<usize>,
}

impl Node {
    /// Replaces the parent link.
    fn set_parent(&mut self, parent: Option<usize>) {
        self.parent = parent;
    }

    /// Replaces the left-child link.
    fn set_left(&mut self, left: Option<usize>) {
        self.left = left;
    }

    /// Replaces the right-child link.
    fn set_right(&mut self, right: Option<usize>) {
        self.right = right;
    }

    /// Returns the parent's index.
    ///
    /// # Panics
    /// Panics with the message `parent` when the node has no parent.
    fn get_parent(&self) -> usize {
        match self.parent {
            Some(p) => p,
            None => panic!("parent"),
        }
    }
}


/// A binary tree stored as parallel arrays indexed by node id.
///
/// `depth` and `height` start at 0 for every node and are filled in by
/// `setDepth` and `setHeight`.
#[derive(Debug, Clone)]
struct BinaryTree {
    node: Vec<Node>,
    depth: Vec<i32>,
    height: Vec<i32>,
}

// The camelCase method names mirror the C++ original and are kept so that
// existing callers continue to compile; silence the resulting lint.
#[allow(non_snake_case)]
impl BinaryTree {
    /// Creates `node_num` unlinked nodes with all depths and heights set to 0.
    fn new(node_num: usize) -> Self {
        BinaryTree {
            node: vec![
                Node {
                    parent: None,
                    left: None,
                    right: None,
                };
                node_num
            ],
            depth: vec![0; node_num],
            height: vec![0; node_num],
        }
    }

    /// Records depth `d` for `index` and `d + 1`, `d + 2`, ... for its descendants.
    ///
    /// # Panics
    /// Panics if `index` or any linked child index is out of range.
    fn setDepth(&mut self, index: usize, d: i32) {
        self.depth[index] = d;
        if let Some(left) = self.node[index].left {
            self.setDepth(left, d + 1);
        }
        if let Some(right) = self.node[index].right {
            self.setDepth(right, d + 1);
        }
    }

    /// Computes the height of the subtree rooted at `index` and returns it.
    ///
    /// The height is the number of edges on the longest downward path to a
    /// leaf. It is also stored in `height` for every node of that subtree.
    ///
    /// # Panics
    /// Panics if `index` or any linked child index is out of range.
    fn setHeight(&mut self, index: usize) -> usize {
        let left = match self.node[index].left {
            Some(child) => self.setHeight(child) + 1,
            None => 0,
        };
        let right = match self.node[index].right {
            Some(child) => self.setHeight(child) + 1,
            None => 0,
        };
        let h = left.max(right);
        // A height is at most (node count - 1); the cast assumes that fits in i32.
        self.height[index] = h as i32;
        h
    }

    /// Returns the other child of `index`'s parent.
    ///
    /// Returns `None` if `index` is the root or its parent has no other child.
    fn getSibling(&self, index: usize) -> Option<usize> {
        let parent = &self.node[self.node[index].parent?];
        parent
            .left
            .filter(|&child| child != index)
            .or(parent.right.filter(|&child| child != index))
    }
}

/// Maps the input's `-1` "no node" marker to `None` and any other value to
/// `Some(x as usize)`.
///
/// Only `-1` is treated as the marker; other negative values go through the
/// `as` cast unchanged and wrap to very large indices.
fn none_cheker(x: i32) -> Option<usize> {
    if x == -1 {
        None
    } else {
        Some(x as usize)
    }
}

/// Prints one line describing node `index`.
///
/// The line contains its parent and sibling (`-1` when absent), degree
/// (number of children), depth, height, and kind. The kind is `root` for a
/// parentless node, else `leaf` for a childless node, else `internal node`.
fn printer(btree: &BinaryTree, index: usize) {
    // Absent links are printed as -1, the NIL value of the C++ version.
    let or_nil = |link: Option<usize>| -> i32 {
        match link {
            Some(id) => id as i32,
            None => -1,
        }
    };
    let node = &btree.node[index];

    print!("node {}: ", index);
    print!("parent = {}, ", or_nil(node.parent));
    print!("sibling = {}, ", or_nil(btree.getSibling(index)));

    let degree = [node.left, node.right]
        .iter()
        .filter(|child| child.is_some())
        .count();
    print!("degree = {}, ", degree);
    print!("depth = {}, ", btree.depth[index]);
    print!("height = {}, ", btree.height[index]);

    if node.parent.is_none() {
        println!("root");
    } else if node.left.is_none() && node.right.is_none() {
        println!("leaf");
    } else {
        println!("internal node");
    }
}


/// Reads a binary tree from standard input and prints a description of every node.
///
/// The first line holds the node count `n`. Each of the next `n` lines holds
/// `id left right`, where `-1` marks a missing child.
fn main() {
    let node_num = *read::<usize>()
        .first()
        .expect("first line must contain the number of nodes");
    let mut btree = BinaryTree::new(node_num);

    for _ in 0..node_num {
        let inputs = read::<i32>();
        let v = inputs[0] as usize;
        // -1 becomes None; any other value becomes Some(index).
        let l = none_cheker(inputs[1]);
        let r = none_cheker(inputs[2]);

        btree.node[v].set_left(l);
        btree.node[v].set_right(r);

        if let Some(left) = l {
            btree.node[left].set_parent(Some(v));
        }
        if let Some(right) = r {
            btree.node[right].set_parent(Some(v));
        }
    }

    // The root is the node that never appears as a child. If several qualify,
    // the last one wins, as before. An empty tree has no root, so depth and
    // height are only computed when one exists. This replaces the old
    // setHeight(0) call made before the input was read, which panicked for n == 0.
    if let Some(root) = (0..node_num)
        .rev()
        .find(|&i| btree.node[i].parent.is_none())
    {
        btree.setDepth(root, 0);
        btree.setHeight(root);
    }

    for i in 0..node_num {
        printer(&btree, i);
    }
}