(树链剖分 + 主席树)
题意:
给出两棵树,大小分别为\(n1\),\(n2\), 树上的结点权值为\(weight_i\)
同一棵树上的结点权值各不相同,不同树上的结点权值可以出现重复 每次查询 \(u1\) \(v1\) \(u2\) \(v2\) 第一棵树上\(u1\) 到 \(v1\)的路径上所有结点权值组成的集合\(S1\) 第二棵树上\(u2\) 到 \(v2\)的路径上所有结点权值组成的集合\(S2\) 求\(S1\) 与 \(S2\) 的交集\(1 <= n1,n2 <= 10^{5}\)
\(q <= 50000\)\(0 <= weight_i <= 10^{9}\)\(1 <= u1,v1 <= n1, 1 <= u2,v2 <= n2\)思路:
先考虑序列上的问题
将第二个序列的权值映射为相同权值在第一个序列中出现的下标 那么对于\(L1,R1,L2,R2\)来说 就是查询第二个序列中区间\([L2,R2]\)权值在\([L1,R1]\)的数字个数 这是经典问题,可以用主席树解决现在考虑树上问题,
路径映射到区间是不连续的,用树链剖分处理,每段都是连续的,就可以用主席树来查询 主席树查询也是要求连续的,如果同样用树链剖分来处理复杂度有3个log了 由于不带修改,如果用dfs序的方式建主席树 就可以用root[u],root[v],root[lca]这些根的信息来处理表示u到v上的所有路径信息了 这样复杂度就变成两个log了ps: 码力太弱了,写不动,写了大概两个半小时才写完,中间顺带复习了一波树链剖分
#include#define LL long long#define P pair #define ls(i) seg[i].lc#define rs(i) seg[i].rcusing namespace std;namespace IO { const int MX = 4e7; //1e7 占用内存 11000kb char buf[MX]; int c, sz; void begin() { c = 0; sz = fread(buf, 1, MX, stdin);//一次性全部读入 } inline bool read(int &t) { while (c < sz && buf[c] != '-' && (buf[c] < '0' || buf[c] > '9')) c++; if (c >= sz) return false;//若读完整个缓冲块则退出 bool flag = 0; if(buf[c] == '-') flag = 1, c++; for(t = 0; c < sz && '0' <= buf[c] && buf[c] <= '9'; c++) t = t * 10 + buf[c] - '0'; if(flag) t = -t; return true; }}void read(int &x){ x = 0; char c = getchar(); while(c < '0' || c > '9') c = getchar(); while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();}const int N = 1e5 + 10;int n1,n2,x;int value2[N];map pos;struct Edge{int v,nxt;};///主席树部分struct T{ int lc,rc,cnt;}seg[N * 50];int root[N],tot;void seg_init(){ root[0] = tot = 0; seg[0].lc = seg[0].rc = seg[0].cnt = 0;}void update(int &rt,int l,int r,int v){ seg[++tot] = seg[rt]; rt = tot; seg[rt].cnt++; if(l >= r) return ; int m = l + r>>1; if(v <= m) update(ls(rt),l,m,v); else update(rs(rt),m+1,r,v);}int query(int rt1,int rt2,int rt3,int l,int r,int pos){ if(r <= pos) return seg[rt1].cnt + seg[rt2].cnt - 2 * seg[rt3].cnt; if(l > pos) return 0; int m = l + r >>1; return query(ls(rt1),ls(rt2),ls(rt3),l,m,pos) + query(rs(rt1),rs(rt2),rs(rt3),m+1,r,pos);}int calc(int rt1,int rt2,int rt3,int L,int R){ return query(rt1,rt2,rt3,1,n1,R) - query(rt1,rt2,rt3,1,n1,L-1);}///dfs插入部分Edge e[2 * N];int head[N],EN;int dep[N];int fa[N][25];void edge_init(){ EN = 0; memset(head,-1,sizeof(head));}void add(int u,int v){ e[EN].v = v,e[EN].nxt = head[u]; head[u] = EN++;}void dfs(int u,int f,int d){ dep[u] = d,fa[u][0] = f; root[u] = root[f]; if(value2[u]) update(root[u],1,n1,value2[u]); for(int i = 1;i <= 20;i++) fa[u][i] = fa[fa[u][i-1]][i-1]; for(int i = head[u];~i;i = e[i].nxt){ int v = e[i].v; if(v != f) dfs(v, u, d + 1); }}int lca(int u,int v){ if(dep[u] < dep[v]) swap(u,v); int d = dep[u] - dep[v]; for(int i = 20;i >= 0 && u != v;i--) if(d & (1< = 0;i--) if(fa[u][i] != fa[v][i]) u = fa[u][i],v = fa[v][i]; return fa[u][0];}struct HLD{///树链剖分处理 int siz[N],dep[N],fa[N]; int top[N];///u所在链的第一个结点 int son[N];///以u为根的重儿子 int r[N];///每个位置是什么结点 int w[N];///结点u在哪个位置 int z; int head[N],EN; Edge e[2 * N]; void add(int u,int v){ e[EN].v = v,e[EN].nxt = head[u]; head[u] = EN++; } void init(){ EN = z = dep[1] = 0; memset(head,-1,sizeof(head)); } void dfs(int u,int f,int d){ fa[u] = f,dep[u] = d; siz[u] = 1,son[u] = 0; for(int i = head[u];~i;i = e[i].nxt){ int v = e[i].v; if(v!=f){ dfs(v,u,d+1); if(siz[v] > siz[son[u]]) son[u] = v; siz[u] += siz[v]; } } } void build(int u,int tp,int f){///当前结点所在链的第一个结点 w[u] = ++z,r[z] = u,top[u] = tp; if(son[u]) build(son[u],tp,u); for(int i = head[u];~i;i = e[i].nxt){ if(e[i].v != f && e[i].v != son[u]) build(e[i].v,e[i].v,u); } } int solve(int u,int v,int c,int d){ int cd = lca(c,d); int ans = 0; int f1 = top[u],f2 = top[v]; while(f1 != f2){ if(dep[f1] < dep[f2]) swap(f1,f2),swap(u,v); ans += calc(root[c],root[d],root[cd],w[f1],w[u]); if(value2[cd] >= w[f1] && value2[cd] <= w[u]) ans++; u = fa[f1],f1 = top[u]; } if(dep[u] > dep[v]) swap(u,v); ans += calc(root[c],root[d],root[cd],w[u],w[v]); if(value2[cd] >= w[u] && value2[cd] <= w[v]) ans++; return ans; }}hld;void init(){ edge_init(); hld.init(); seg_init(); pos.clear();}int main(){ IO::begin(); while(IO::read(n1)){ init(); for(int i = 2;i <= n1;i++){ IO::read(x); hld.add(x,i); } hld.dfs(1,0,0); hld.build(1,1,0); for(int i = 1;i <= n1;i++) { IO::read(x); pos[x] = i; } IO::read(n2); for(int i = 2;i <= n2;i++){ IO::read(x); add(x,i); } for(int i = 1;i <= n2;i++) { IO::read(x); if(pos[x]) value2[i] = hld.w[pos[x]]; else value2[i] = 0; } dfs(1,0,0); int a,b,c,d,q; IO::read(q); while(q--){ IO::read(a),IO::read(b),IO::read(c),IO::read(d); printf("%d\n",hld.solve(a,b,c,d)); } } return 0;}