类欧几里得算法

我们要快速求以下函数的值:

f

先考虑 $a\ge c$ 或 $b\ge c$ 的情况。我们考虑将 $a$ 和 $b$ 对 $c$ 取模。

此时 $a$ 变为 $a\bmod c$,$b$ 变为 $b\bmod c$。我们接下来考虑 $a<c,b<c$ 的情况。设 $m=\lfloor \frac{an+b}{c} \rfloor$,则

我们推导 $[j<\lfloor \frac{ai+b}{c} \rfloor]$ 的值:

所以 $[j<\lfloor \frac{ai+b}{c} \rfloor]=[i> \lfloor\frac{cj+c-b-1}{a}\rfloor]$。我们继续推导:

最后,边界情况是 $a=0$。显然此时 $f(a,b,c,n)=(n+1)\lfloor \frac{b}{c} \rfloor$。综上:

我们观察式子中 $a,c$ 位置上的变化。在第一个式子中,$a$ 对 $c$ 取模。在第二个式子中,$a,c$ 交换了位置。这样交替进行,直到 $a=0$ 的过程,与欧几里得算法中辗转相除的过程相同。因此直接递归的时间复杂度是 $O(\log n)$。

g

接下来的推推导过程与 $f$ 类似,但是复杂了很多。

先考虑 $a\ge c$ 或 $b\ge c$ 的情况

然后考虑 $a<c,b<c$ 的情况。省略了一些在推导 $f$ 时推导过的东西。

我们设 $t=\lfloor\frac{cj+c-b-1}{a}\rfloor$。继续推导

最后考虑边界情况 $a=0$。显然此时 $g(a,b,c,n)=\frac{n(n+1)}{2}\lfloor \frac{b}{c} \rfloor$。综上:

h

先考虑 $a\ge c$ 或 $b\ge c$ 的情况

整理得:

然后考虑 $a<c,b<c$ 的情况。平方不太好处理,我们考虑平方分成两个部分。

所以:

我们接着推导式子的前半部分:

带入到原式,得

最后考虑边界情况 $a=0$。显然此时 $h(a,b,c,n)=(n+1)\lfloor \frac{b}{c} \rfloor^2$。综上:

实现

如果 $f,g,h$ 分开实现,时间复杂度会无法保证。例如,在求 $g(a,b,c,n)$ 时,如果满足 $a<c,b<c$,那么程序会同时调用 $f$ 和 $h$,导致递归形成很多分支,时间复杂度也就无法保证。

所以我们将 $f,g,h$ 放在同一个函数实现。代码没有任何难度,抄上面的公式即可。公式很复杂,注意不要抄错了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
const int mod = 998244353, inv2 = 499122177, inv6 = 166374059;
struct S
{
ll f, g, h;
S(ll f, ll g, ll h) : f(f), g(g), h(h) {}
};
S solve(ll a, ll b, ll c, ll n)
{
if (a == 0) return S((n + 1) * (b / c) % mod, n * (n + 1) / 2 % mod * (b / c) % mod, (n + 1) * (b / c) % mod * (b / c) % mod);
if (a >= c || b >= c)
{
S x = solve(a % c, b % c, c, n);
ll f = (x.f + n * (n + 1) / 2 % mod * (a / c) % mod + (n + 1) * (b / c) % mod) % mod;
ll g = (x.g + (n * (n + 1) % mod * (2 * n + 1) % mod * inv6 % mod) * (a / c) % mod + n * (n + 1) / 2 % mod * (b / c) % mod) % mod;
ll h = (x.h + 2 * x.f % mod * (b / c) % mod + 2 * x.g % mod * (a / c) % mod + (n * (n + 1) % mod * (2 * n + 1) % mod * inv6 % mod) * (a / c) % mod * (a / c) % mod + (n + 1) * (b / c) % mod * (b / c) % mod + n * (n + 1) % mod * (a / c) % mod * (b / c) % mod) % mod;
return S(f, g, h);
}
else
{
ll m = (a * n + b) / c;
S x = solve(c, c - b - 1, a, m - 1);
ll f = (n * m % mod - x.f + mod) % mod;
ll g = ((m * n % mod * (n + 1) % mod - x.h - x.f) % mod + mod) % mod * inv2 % mod;
ll h = ((n * m % mod * (m + 1) % mod - 2 * x.f - 2 * x.g - f) % mod + mod) % mod;
return S(f, g, h);
}
}