多校JLU 10題JLUCPC(ツリーDP)



問題テストの時は問題が気持ち悪くて10問もできなかったのですが、本番では6問も鋭くできたチームがいて本当にすごいです
これは一番水の問題ですが、私にはまだ難しいので、問題解を見て、長い間引っ張ってやっと分かりました.
コード最適化後JOJは0.57 s走った.(午前中に分かったかと思うと、午後の試合が出てきた)
先輩がくれた問題解:
ツリーDP
ツリーは連通で無方向エッジであるため、1をルートと仮定し、ノード1から各サブツリーDFSへ今回のDFSではノードごとに数量node[i]を記録iをルートとするサブツリー(iノードを含む)のT[k]とsum[i]をルートとするサブツリー(iノードを含む)に、各ノードからルートへの経路長*T[k]の和を記録する
したがって、1回目のDFSで得られたsum[1]は、1をthe most convenient locationの答えとし、異なるノードをthe most convenient locationとする場合を列挙し始める
    1   /\  2   3  /4
例えば2ノードに移行すると、2ノードをthe most convenient locationとする答えは
Sum[1] –  sum_node[2] * g[1][2] + ( sum_node[1] –  sum_node[2]) * g[1][2] = Sum[2]

(g[1][2]は1から2の辺の重みを表す)
次に4ノードをthe most convenient locationとする答えは
 Sum[2] –  sum_node[4] * g[2][4] + ( sum_node[1] –  sum_node[4]) * g[2][4] = Sum[4]

このようにDFSを一度でいい
 
#include <cstdio>
#include <string.h>

const int maxn=100005;

struct Edge{
    int v,next,w;
}edge[2*maxn];

int cnt,n,head[maxn];
long long node[maxn],sum[maxn],sum_node[maxn];
bool vis[maxn];

void addedge(int u,int v,int w){
    edge[cnt].v=v;
    edge[cnt].w=w;
    edge[cnt].next=head[u];
    head[u]=cnt++;
    edge[cnt].v=u;
    edge[cnt].w=w;
    edge[cnt].next=head[v];
    head[v]=cnt++;
}

void dfs (int u)
{
    int v,i,t=head[u],w;
    vis[u]=true ;
    sum_node[u]=node[u];
    for (;~t;t=edge[t].next)
    {
        v=edge[t].v;
        if(vis[v])continue;
        w=edge[t].w;
        dfs(v);
        sum_node[u]+=sum_node[v];
        sum[0]+=w*sum_node[v];
    }
}

void dfs2 (int u)
{
    int v,t,w;
    vis[u]=true ;
    for (t=head[u]; ~t ; t=edge[t].next)
    {
        v=edge[t].v;
        if(vis[v])continue;
        w=edge[t].w;
        sum[v]=sum[u]+(sum_node[0]-(sum_node[v]<<1))*w;
        dfs2(v);
    }
}

inline void init ()
{
    memset (head , -1 ,sizeof(head));
    memset (vis , 0 , sizeof(vis));
    sum[0]=0;
    cnt=0;
}

long long sovle ()
{
    int  i;
    dfs(0);
    memset (vis , false , sizeof(vis));
    dfs2(0);
    long long min_num=sum[0];
    for (i=0 ; i<n ; ++i)
    if(min_num>sum[i]) min_num=sum[i];
    return min_num;
}

int main ()
{
    int i,u,v,w;
    while (~scanf("%d",&n))
    {
        init ();
        for (i=0 ; i<n ; ++i)
         scanf("%d",node+i);
        for (i=1 ; i<n ; ++i)
        {
            scanf("%d%d%d",&u,&v,&w);
            u--;v--;
            addedge (u,v,w);
        }
        printf("%lld
",sovle()); } return 0; }