hdu 4747 Mex

7320 ワード

http://acm.hdu.edu.cn/showproblem.php?pid=4747
入力した配列をa[]とし,1からnまで遍歴する必要があり,iまで遍歴すると仮定すると, 遍歴中にiからjまで現れなかった最小自然数をb[j]で表す
まずnから1スキャンして1から各点までのb[j]値を求める
そしてa[]を遍歴することは実際には現在のa[i]を絶えず除去することであり、例えばa[3]を除去する場合、残りのb[4]---b[n]は4から他の後続点まで形成された区間に現れなかった最小自然数を表す
iからnまで、b[]の値は常に単調に増加していることを知っておく必要があります.
現在のa[i]を削除するたびにb[]配列に影響を及ぼします.
次のa[i]と等しい数が現れる位置をrとすると、a[i]を取り除くことはrおよびr以降のb[]に影響を及ぼさない
iと r間の影響を受けるセグメントb[]がa[i]以上であるセグメントは(l,r)であり、このセグメント内のb[]はいずれもa[i]以上であると仮定する.
a[i]を除いた影響は,このセグメント内のb[]がa[i]に等しいことである.
rを見つけて事前にマークして、lを探して更新するセグメント(l,r)は2つの方法があります
1,二分でlを見つけ,更新セグメント(l,r)を巡回する.    このようにコードは比較的に短くて、同じく比較的に分かりやすくて、しかし比較的に時間がかかります
2、線分ツリーのメンテナンス                                このようにコードの量は比較的に大きくて、しかし時間が少なくて、線分の木の解法は比較的に標準的であるべきです
2つのコード:
#include<iostream>

#include<cstdio>

#include<algorithm>

#include<string>

#include<cstring>

#include<cmath>

#include<set>

#include<vector>

#include<list>

#include<stack>

#include<queue>

#include<map>



using namespace std;



typedef long long ll;

typedef pair<int,int> pp;



const int INF=0x3f3f3f3f;



const int N=200002;

bool exist[N];

int a[N],next[N],f[N];

int b[N];

int bsh(int l,int r,int k)

{

    while(l<=r)

    {

        int mid=(l+r)>>1;

        if(b[mid]<=k) l=mid+1;

        else r=mid-1;

    }

    return r;

}

int main()

{

    //freopen("data.in","r",stdin);

    int n;

    while(scanf("%d",&n)!=EOF)

    {

       if(n==0) break;

       for(int i=1;i<=n;++i)

       scanf("%d",&a[i]);

       for(int i=0;i<=n;++i)

       f[i]=n+1;

       for(int i=n;i>=1;--i)

       if(a[i]<n)

       {

           next[i]=f[a[i]];

           f[a[i]]=i;

       }

       ll ans=0;

       memset(exist,false,sizeof(exist));

       ll tmp=0;int l=0;

       for(int i=1;i<=n;++i)

       {

           if(a[i]<n)

           {

               exist[a[i]]=true;

               while(exist[l]) ++l;

           }

           b[i]=l;

           tmp+=b[i];

       }

       ans=tmp;

       for(int i=1;i<n;++i)

       {

           if(a[i]<n)

           {

               int r=next[i];

               int l=bsh(i,r-1,a[i]);

               for(int j=l+1;j<r;++j)

               {

                   tmp-=(b[j]-a[i]);

                   b[j]=a[i];

               }

           }

           tmp-=b[i];

           ans+=tmp;

       }

       cout<<ans<<endl;

    }

    return 0;

}



#include<iostream>

#include<cstdio>

#include<algorithm>

#include<string>

#include<cstring>

#include<cmath>

#include<set>

#include<vector>

#include<list>

#include<stack>

#include<queue>

#include<map>



using namespace std;



typedef long long ll;

typedef pair<int,int> pp;



const int INF=0x3f3f3f3f;



const int N=200002;

bool exist[N];

int a[N],next[N],f[N];

int b[N];

struct node

{

    int l,r,k,least;

    ll sum;

}tr[N*4];

void build(int x,int l,int r)

{

    tr[x].l=l;tr[x].r=r;tr[x].k=-1;

    if(l==r)

    {

        tr[x].least=b[l];

        tr[x].sum=b[l];

        return ;

    }

    int mid=(l+r)>>1;

    build((x<<1),l,mid);

    build((x<<1)|1,mid+1,r);

    tr[x].least=min(tr[x<<1].least,tr[(x<<1)|1].least);

    tr[x].sum=(tr[x<<1].sum+tr[(x<<1)|1].sum);

}

void update(int x,int l,int r,int k)

{

    if(l>r) return ;

    if(tr[x].l==l&&tr[x].r==r)

    {

        tr[x].least=k;

        tr[x].k=k;

        tr[x].sum=(ll)k*(tr[x].r-tr[x].l+1);

        return ;

    }

    if(tr[x].k!=-1)

    {

        tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k;

        tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1);

        tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k;

        tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1);

        tr[x].k=-1;

    }

    int mid=(tr[x].l+tr[x].r)>>1;

    if(r<=mid)

    update(x<<1,l,r,k);

    else if(l>mid)

    update((x<<1)|1,l,r,k);

    else

    {

        update(x<<1,l,mid,k);

        update((x<<1)|1,mid+1,r,k);

    }

    tr[x].least=min(tr[x<<1].least,tr[(x<<1)|1].least);

    tr[x].sum=(tr[x<<1].sum+tr[(x<<1)|1].sum);

    tr[x].k=-1;

}

int get(int x,int l,int r,int w)

{

    if(tr[x].l==tr[x].r)

    {

        if(tr[x].least>w)

        return (l-1);

        return l;

    }

    if(tr[x].k!=-1)

    {

        tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k;

        tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1);

        tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k;

        tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1);

        tr[x].k=-1;

    }

    int mid=(tr[x].l+tr[x].r)>>1;

    if(r<=mid)

    return get(x<<1,l,r,w);

    else if(l>mid)

    return get((x<<1)|1,l,r,w);

    else

    {

        if(tr[(x<<1)|1].least<=w)

        return get((x<<1)|1,mid+1,r,w);

        else

        return get(x<<1,l,mid,w);

    }

}

ll gsum(int x,int l,int r)

{

    if(l>r) return 0;



    if(tr[x].l==l&&tr[x].r==r)

    return tr[x].sum;

    if(tr[x].k!=-1)

    {

        tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k;

        tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1);

        tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k;

        tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1);

        tr[x].k=-1;

    }

    int mid=(tr[x].l+tr[x].r)>>1;

    if(r<=mid)

    return gsum(x<<1,l,r);

    else if(l>mid)

    return gsum((x<<1)|1,l,r);

    else

    return gsum(x<<1,l,mid)+gsum((x<<1)|1,mid+1,r);

}

int main()

{

    int n;

    while(scanf("%d",&n)!=EOF)

    {

       if(n==0) break;

       for(int i=1;i<=n;++i)

       scanf("%d",&a[i]);

       for(int i=0;i<=n;++i)

       f[i]=n+1;

       for(int i=n;i>=1;--i)

       if(a[i]<n)

       {

           next[i]=f[a[i]];

           f[a[i]]=i;

       }

       ll ans=0;

       memset(exist,false,sizeof(exist));

       int l=0;

       for(int i=1;i<=n;++i)

       {

           if(a[i]<n)

           {

               exist[a[i]]=true;

               while(exist[l]) ++l;

           }

           b[i]=l;

       }

       build(1,1,n);

       ans+=gsum(1,1,n);

       for(int i=1;i<n;++i)

       {

           if(a[i]<n)

           {

               int r=next[i];

               int l=get(1,i,r-1,a[i]);

               update(1,l+1,r-1,a[i]);

           }

           ans+=gsum(1,i+1,n);

       }

       cout<<ans<<endl;

    }

    return 0;

}