BZOJ 4231 回忆树(KMP+AC自动机+fail树+树状数组)

这是一篇关于BZOJ 4231题目的解题报告,涉及在回忆树中查找字符串出现次数的问题。利用KMP算法处理路径上的字符串比较,通过AC自动机处理单路径上的字符串,结合fail树和树状数组来优化解决方案,实现高效的时间复杂度。

题目

Description

回忆树是树。
具体来说,是n个点n-1条边的无向连通图,点标号为1~n,每条边上有一个字符(出于简化目的,我们认为只有小写字母)。
对一棵回忆树来说,回忆当然是少不了的。
一次回忆是这样的:你想起过往,触及心底…唔,不对,我们要说题目。
这题中我们认为回忆是这样的:给定2个点u,v(u可能等于v)和一个非空字符串s,问从u到v的简单路径上的所有边按照到u的距离从小到大的顺序排列后,边上的字符依次拼接形成的字符串中给定的串s出现了多少次。

Input

第一行2个整数,依次为树中点的个数n和回忆的次数m。
接下来n-1行,每行2个整数u、v和1个小写字母c,表示回忆树的点u、v之间有一条边,边上的字符为c
接下来2m行表示m次回忆,每次回忆2行:第1行2个整数u、v,第2行给出回忆的字符串s。

Output

对于每次回忆,输出串s出现的次数。

Sample Input

12 3
1 2 w
2 3 w
3 4 x
4 5 w
5 6 w
6 7 x
7 8 w
8 9 w
9 10 x
10 11 w
11 12 w
1 7
wwx
1 12
www
1 12

Sample Output

2
0
8

HINT

对于100%的数据,n<=100000,m<=100000,询问串的总长<=300000

Time Limit: 10 Sec Memory Limit: 256 MB

题解

将每条路径分为向上的路径和向下的路径,则字符串s出现的情况有两种,分情况讨论

1.s经过了LCA(u,v):此时字符串长度不超过2|s|,用KMP进行比较
2.s只存在一条路径上:将所有s的正串和反串分别建立AC自动机,DFS原树,经过一个节点时,在AC自动机上对应的节点处+1,遍历完子树后-1,答案即为fail树上s对应的的子树权值和

求出fail树的DFS序,用树状数组维护子树权值

时间复杂度:O((n+m)logn+|s|)O((n+m)logn+∑|s|)

代码

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<cctype>
#include<cstring>
#include<stack>
#include<queue>
#include<set>
#include<map>
#include<vector>
#define LL long long
using namespace std;
const int maxn=100005,maxs=600005,SIZE=26;
int n,m,sums=1,cc,ccc,x,y,z,zz;
char s[maxs],ss[maxs],sss[maxn],ssss[maxn],S[SIZE];
struct data
{
    int x,y,z,L,R,ans;
}b[maxn];
struct tree
{
    int to,next;char w;
}a[maxn<<1];
int first[maxn],np=0;
struct ques
{
    int to,next,w,z;
}c[maxn<<2];
int head[maxn],sz=0;
const int l=1<<15;
char buffer[l],*SS,*TT;
inline char Get_Char()
{
    if(SS==TT)
      {
        TT=(SS=buffer)+fread(buffer,1,l,stdin);
        if(SS==TT) return EOF;
      }
    return *SS++;
}
inline void read(int &x)
{
    char ch;x=0;int F=0;ch=Get_Char();
    while(!isdigit(ch)) {if(ch=='-') F=1;ch=Get_Char();}
    while(isdigit(ch)) {x=x*10+ch-'0';ch=Get_Char();}
    if(F) x=-x;return;
}
inline void read(LL &x)
{
    char ch;x=0;int F=0;ch=Get_Char();
    while(!isdigit(ch)) {if(ch=='-') F=1;ch=Get_Char();}
    while(isdigit(ch)) {x=x*10+ch-'0';ch=Get_Char();}
    if(F) x=-x;return;
}
void read(char *s)
{
    char ch;ch=Get_Char();
    while(!isalpha(ch)) ch=Get_Char();int t=0;
    while(isalpha(ch)) {s[t++]=ch;ch=Get_Char();}
    s[t]=0;return;
}
void write(int x)
{
    if(!x) putchar('0');
    else
      {
        char s[20];int cnt=0,F=0;if(x<0) x=-x,F=1;
        while(x) s[++cnt]=x%10+'0',x/=10;if(F) putchar('-');
        for(int i=cnt;i>=1;i--) putchar(s[i]);
      }
    putchar('\n');return;
}
void write(LL x)
{
    if(!x) putchar('0');
    else
      {
        char s[20];int cnt=0,F=0;if(x<0) x=-x,F=1;
        while(x) s[++cnt]=x%10+'0',x/=10;if(F) putchar('-');
        for(int i=cnt;i>=1;i--) putchar(s[i]);
      }
    putchar('\n');return;
}
void add(int x,int y,char w)
{
    a[++np]=(tree){y,first[x],w};
    first[x]=np;
    return;
}
void insert(int x,int y,int f,int z)
{
    c[++sz]=(ques){y,head[x],f,z};
    head[x]=sz;
    return;
}
struct KMP
{
    int fail[maxs];
    void getfail(char *s)
      {
        int n=strlen(s),i=0,j=-1;
        fail[0]=-1;
        while(i<n)
          {
            while(j>=0 && s[i]!=s[j]) j=fail[j];
            i++;j++;
            fail[i]=j;
          }
        return;
      }
    int kmp(char *s,char *ss)
      {
        int n=strlen(s),m=strlen(ss);
        int i=0,j=0,cnt=0;
        getfail(s);
        while(i<m)
          {
            while(j>=0 && s[j]!=ss[i]) j=fail[j];
            i++;j++;
            if(j==n) cnt++,j=fail[j];
          }
        return cnt;
      }
}h;
struct BIT
{
    int c[maxs];
    int lowbit(int x)
      {
        return x&(-x);
      }
    void update(int x,int y)
      {
        while(x<=sums)
          {
            c[x]+=y;
            x+=lowbit(x);
          }
        return;
      }
    int find(int x)
      {
        int sum=0;
        while(x)
          {
            sum+=c[x];
            x-=lowbit(x);
          }
        return sum;
      }
    int query(int x,int y)
      {
        int l,r;
        l=find(x-1);
        r=find(y);
        return r-l;
      }
}f,ff;
struct AC
{
    int chd[maxs][SIZE],fail[maxs],val[maxs],last[maxs],rt,np;
    int q[maxs],front,rear;
    struct tree
      {
        int to,next;
      }a[maxs<<1];
    int first[maxs],sz;
    int L[maxs],R[maxs],pos;
    AC()
      {
        rt=np=sz=front=rear=pos=0;
        memset(chd,0,sizeof(chd));
        memset(fail,0,sizeof(fail));
        memset(val,0,sizeof(val));
        memset(first,0,sizeof(first));
        memset(last,0,sizeof(last));
      }
    void insert(char *s)
      {
        int l=strlen(s),p=rt;
        for(int i=0;i<l;i++)
          {
            if(!chd[p][s[i]-'a']) chd[p][s[i]-'a']=++np;
            p=chd[p][s[i]-'a'];
          }
        val[p]++;
        return;
      }
    void getfail()
      {
        for(int i=0;i<SIZE;i++)
          if(chd[rt][i]) q[rear++]=chd[rt][i];
        while(front!=rear)
          {
            int u=q[front++];
            for(int i=0;i<SIZE;i++)
              {
                int v=chd[u][i];
                if(!v) {chd[u][i]=chd[fail[u]][i];continue;}
                q[rear++]=v;
                int t=fail[u];
                while(t && !chd[t][i]) t=fail[t];
                fail[v]=chd[t][i];
                last[v]=val[fail[v]]?fail[v]:last[fail[v]];
              }
          }
        return;
      }
    void add(int x,int y)
      {
        a[++sz]=(tree){y,first[x]};
        first[x]=sz;
        return;
      }
    void gettree()
      {
        for(int i=1;i<=np;i++) add(i,fail[i]),add(fail[i],i);
        return;
      }
    void DFS(int i,int f)
      {
        L[i]=++pos;
        for(int j=first[i];j;j=a[j].next)
          if(a[j].to!=f)
            DFS(a[j].to,i);
        R[i]=pos;
        return;
      }
}g,gg;
int Findg(int x)
{
    cc=0;
    for(int i=b[x].L;i<=b[x].R;i++) ss[cc++]=s[i];ss[cc]=0;
    int p=g.rt;
    for(int i=0;i<cc;i++) p=g.chd[p][ss[i]-'a'];
    return p;
}
int Findgg(int x)
{
    cc=0;
    for(int i=b[x].R;i>=b[x].L;i--) ss[cc++]=s[i];ss[cc]=0;
    int p=gg.rt;
    for(int i=0;i<cc;i++) p=gg.chd[p][ss[i]-'a'];
    return p;
}
void DFS(int i,int fa,int p,int q)
{
    if(g.val[p]) f.update(g.L[p],1);
    else if(g.last[p]) f.update(g.L[g.last[p]],1);
    if(gg.val[q]) ff.update(gg.L[q],1);
    else if(gg.last[q]) ff.update(gg.L[gg.last[q]],1);
    int r;
    for(int j=head[i];j;j=c[j].next)
      if(c[j].z>0)
        {
          r=Findg(c[j].to);
          b[c[j].to].ans+=c[j].w*f.query(g.L[r],g.R[r]);
        }
      else
        {
          r=Findgg(c[j].to);
          b[c[j].to].ans+=c[j].w*ff.query(gg.L[r],gg.R[r]);
        }
    for(int j=first[i];j;j=a[j].next)
      if(a[j].to!=fa)
        DFS(a[j].to,i,g.chd[p][a[j].w-'a'],gg.chd[q][a[j].w-'a']);
    if(g.val[p]) f.update(g.L[p],-1);
    else if(g.last[p]) f.update(g.L[g.last[p]],-1);
    if(gg.val[q]) ff.update(gg.L[q],-1);
    else if(gg.last[q]) ff.update(gg.L[gg.last[q]],-1);
    return;
}
int deep[maxn],fa[maxn][20];char fd[maxn];
void _DFS(int i,int f,int dp)
{
    fa[i][0]=f;deep[i]=dp;
    for(int j=1;j<=18;j++)
      fa[i][j]=fa[fa[i][j-1]][j-1];
    for(int j=first[i];j;j=a[j].next)
      if(a[j].to!=f)
        _DFS(a[j].to,i,dp+1),fd[a[j].to]=a[j].w;
    return;
}
int LCA(int x,int y)
{
    if(deep[x]<deep[y]) swap(x,y);
    int z=deep[x]-deep[y];
    for(int j=18;j>=0;j--)
      if(z&(1<<j)) x=fa[x][j];
    if(x==y) return x;
    for(int j=18;j>=0;j--)
      if(fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
    return fa[x][0];
}
void solve()
{
    read(n);read(m);
    for(int i=1;i<n;i++) read(x),read(y),read(S),add(x,y,S[0]),add(y,x,S[0]);
    _DFS(1,0,1);b[0].R=-1;
    for(int i=1;i<=m;i++)
      {
        read(b[i].x);read(b[i].y);read(s+b[i-1].R+1);
        b[i].z=LCA(b[i].x,b[i].y);b[i].L=b[i-1].R+1;
        b[i].R=b[i].L+strlen(s+b[i].L)-1;sums+=b[i].R-b[i].L+1;cc=ccc=0;
        z=b[i].x;
        if(deep[b[i].x]-deep[b[i].z]>=b[i].R-b[i].L+1)
          {
            for(int j=18;j>=0;j--)
              if(deep[fa[z][j]]-deep[b[i].z]>=b[i].R-b[i].L+1) z=fa[z][j];
            insert(b[i].x,i,1,-1);
            z=fa[z][0];insert(z,i,-1,-1);
          }
        while(deep[z]>deep[b[i].z]) sss[ccc++]=fd[z],z=fa[z][0];
        z=b[i].y;
        if(deep[b[i].y]-deep[b[i].z]>=b[i].R-b[i].L+1)
          {
            for(int j=18;j>=0;j--)
              if(deep[fa[z][j]]-deep[b[i].z]>=b[i].R-b[i].L+1) z=fa[z][j];
            insert(b[i].y,i,1,1);
            z=fa[z][0];insert(z,i,-1,1);
          }
        zz=z;while(deep[z]>deep[b[i].z]) ccc++,z=fa[z][0];
        z=zz;while(deep[z]>deep[b[i].z]) sss[ccc-(deep[zz]-deep[fa[z][0]])]=fd[z],z=fa[z][0];sss[ccc]=0;
        for(int j=b[i].L;j<=b[i].R;j++) ss[cc++]=s[j];ss[cc]=0;cc=0;
        for(int j=b[i].R;j>=b[i].L;j--) ssss[cc++]=s[j];ssss[cc]=0;
        b[i].ans+=h.kmp(ss,sss);
        g.insert(ss);gg.insert(ssss);
      }
    g.getfail();g.gettree(); g.DFS(0,0);
    gg.getfail();gg.gettree();gg.DFS(0,0);
    DFS(1,0,g.rt,gg.rt);
    for(int i=1;i<=m;i++) write(b[i].ans);
    return;
}
int main()
{
    solve();
    return 0;
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值