BOJ 1761頂点の距離


https://www.acmicpc.net/problem/1761
2秒、128 MBメモリ
input :
  • N(2 ≤ N ≤ 40,000)
  • N-1行:ツリーに接続された2点と距離
  • M(1 ≤ M ≤ 10,000)
  • M行:一対の入力
  • output :
  • 出力2ノード間の距離
  • 条件:

  • N個の頂点からなるツリー

  • M個の2節の点の対を入力する時、2つのノードの間の距離を出力します
  • 改善されたLCAを用いて解決しようとしたが,parentに距離を追加することで3次元配列を使用することは想像以上にメモリを消費した.
    40,000*log(40,000)*2を作ると全てのケースが発生すると思いますが、間違った計算のようです.
    とにかく、LCAを使ってそれぞれ自分の親を探す方法で問題を解決することができます.
    彼らの深さが異なると、両親が異なると、彼らは自分の両親に更新されます.
    距離を合わせてこそ、問題は欲しい答えを得ることができる.
    import sys
    sys.setrecursionlimit(100000)
    
    def dfs(node, deep):
        visit[node] = 1
        depth[node] = deep
    
        for next_node, cost in graph[node]:
            if visit[next_node] == 1:
                continue
            parent[next_node] = [node, cost]
            dfs(next_node, deep + 1)
    
    def lca(a, b):
        ans = 0
    
        # 언제나 b가 더 깊은 곳에 있도록 만듬
        if depth[a] > depth[b]:
            a, b = b, a
    
        while depth[a] != depth[b]:
            ans += parent[b][1]
            b = parent[b][0]
    
        if a == b:
            return ans
    
        while parent[a][0] != parent[b][0]:
            ans += parent[b][1]
            ans += parent[a][1]
            b = parent[b][0]
            a = parent[a][0]
    
        ans += parent[a][1] + parent[b][1]
        return ans
    
    n = int(sys.stdin.readline())
    graph = [[] for _ in range(n + 1)]
    visit, depth = [0] * (n + 1), [0] * (n + 1)
    parent = [[0, 0] for _ in range(n + 1)]
    root = [i for i in range(n + 1)]
    
    for _ in range(n - 1):
        a, b, cost = map(int, sys.stdin.readline().split())
        graph[a].append((b, cost))
        graph[b].append((a, cost))
    
        # root 노드를 찾아서 모든 노드의 깊이를 구하기 위한 과정
        if root[a] > root[b]:
            root[a] = root[b]
        else:
            root[b] = root[a]
    
    dfs(root[1], 0)
    
    m = int(sys.stdin.readline())
    for _ in range(m):
        u, v = map(int, sys.stdin.readline().split())
        print(lca(u, v))