原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ33.html
题解
首先我们把问题转化成处理一个数组 ans ,其中 ans[i] 表示 d(u,a) 和 d(v,a) 同时为 i 的倍数的 (u,v) 个数。(最后求答案的时候只要莫比乌斯反演回来就好了。)
注意一下我的代码中对于 (u,v) 有祖先关系的是分开考虑的。
先点分治。
对于一个点分中心 x ,我们把答案分两部分考虑。
1. 在子树 x 中满足 LCA(u,v) = x 的 (u,v) 对于答案的贡献。
2. u,v 其中一个点在子树 x 中,另一个不在。
第一部分非常好求,不加赘述。
第二部分,我们考虑定义一个阀值 S ,我们预处理出 Smod[i][j] 表示 子树 x 中,到 x 的距离 mod i = j 的点的个数。这样,我们就可以 O(1) 得到 在子树 x 中,到达 x 的某一个祖先的距离为 i 的倍数的点的个数 。这样,我们就可以在 $O(nS)$ 的复杂度内求出对于 $ans[i](i\leq S)$ 的贡献。 对于 i>S 的,我们可以直接暴力计算 在子树 x 中,到达 x 的某一个祖先的距离为 i 的倍数的点的个数 ,复杂度为 $O(n^2/S)$ 。取 $S = O(\sqrt{n})$ 最优。故处理一个点分中心的复杂度为 $O(n\sqrt{n})$ (假设当前连通块大小为 n)。
所以总的时间复杂度为 $O(n\sqrt{n})$ 。
代码
#includeusing namespace std;typedef long long LL;LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x;}const int N=200005,M=500;int n;vector e[N];LL ans[N],ans2[N];int depth[N],fa[N];void dfs(int x,int pre,int d){ fa[x]=pre,depth[x]=d; for (auto y : e[x]) if (y!=pre) dfs(y,x,d+1);}int vis[N],size[N],Size;int Maxsize[N],rt;void get_root(int x,int pre){ size[x]=1,Maxsize[x]=0; for (auto y : e[x]) if (y!=pre&&!vis[y]){ get_root(y,x); size[x]+=size[y]; Maxsize[x]=max(Maxsize[x],size[y]); } Maxsize[x]=max(Maxsize[x],Size-size[x]); if (!rt||Maxsize[rt]>Maxsize[x]) rt=x;}vector d[N];void get_size(int x,int pre){ size[x]=1; for (auto y : e[x]) if (y!=pre&&!vis[y]) get_size(y,x),size[x]+=size[y];}void getd(int x,int pre,int d,vector &v){ while (d>=(int)v.size()) v.push_back(0); v[d]++; for (auto y : e[x]) if (y!=pre&&!vis[y]) getd(y,x,d+1,v);}LL S[N];LL Smod[M][M];void solve(int x){ rt=0; get_root(x,0); assert(rt!=0); vis[x=rt]=1; for (int i=0;i<=Size;i++) S[i]=0; int Mx=0; for (auto y : e[x]) if (!vis[y]){ get_size(y,0); if (depth[y] =1;i--) for (int j=i<<1;j<=Mx;j+=i) S[i]-=S[j]; S[0]++; int base=(int)(0.4*sqrt(Mx)+0.5); for (int i=1;i<=base;i++){ for (int j=0;j =1;i--) for (int j=i<<1;j<=n;j+=i) ans[i]-=ans[j]; for (int i=1;i<=n;i++) ans2[depth[i]]++; for (int i=n;i>=1;i--) ans2[i]+=ans2[i+1]; for (int i=1;i