前端也能理解的高阶算法:容斥原理

语言: CN / TW / HK

theme: v-green highlight: a11y-light


小知识,大挑战!本文正在参与“程序员必备小知识”创作活动。

什么是容斥原理

容斥原理是用于求集合的并集。

看这样一个问题,某班暑假集训,每个同学都会参加,其中参加足球队的有25人,参加排球队的有22人,参加游泳队的有24人,足球、排球都参加的有12人,足球、游泳都参加的有9人,排球、游泳都参加的有8人,三项都参加的有3人,求该班共有多少学生?

我们把问题抽象为,将元素分为 A,B,C 三类,并且种类之间有重复,用韦恩图可以表示为:

image.png

当我们求 A、B、C 的并集,如果用 A+B+C 就会有部分元素被重复计算,我们可以计算 A+B+C - (A∩B+B∩C+A∩C) + A∩B∩C

扩展到多个集合,我们要先将所有单个集合的大小计算出来,然后减去所有两个集合相交的部分,再加回所有三个集合相交的部分,再减去所有四个集合相交的部分,依此类推,一直计算到所有集合相交的部分。这就是容斥原理的基本定义。

image.png

我们可以简单记为 “奇加偶减” 即包含奇数个集合的数据要加到答案,包含偶数个集合的数据要被减去。

回到最开始的问题,可以写出答案 (25+22+24) - (12+9+8) + 3 = 45

LeetCode 实战

1201. 丑数 III

给你四个整数:nabc ,请你设计一个算法来找出第 n 个丑数。

丑数是可以被 a  b  c 整除的 正整数

示例 1 输入:n = 3, a = 2, b = 3, c = 5 输出:4 解释:丑数序列为 2, 3, 4, 5, 6, 8, 9, 10... 其中第 3 个是 4。

此题虽然标的是中等,实际难度大于很多困难,涉及知识点较多。

首先,结合上面的介绍,我们可以得出任意 x 以内的丑数个数,即:

能整除 a 的个数 + 能整除 b 的个数 + 能整除 c 的个数 - 能同时整除 a,b 的个数 - 能同时整除 a,c 的个数 - 能同时整除 b,c 的个数 + 能同时整除 a,b,c 的个数。

要求能同时整除 a,b 的个数,要先求出 ab最小公倍数 lcm(a,b) 然后求 x/lcm(a,b) 就是 x 以内的能同时整除 a,b 的个数。

而要求最小公倍数,需要先求 最大公约数,我们可以通过辗转相除法求最大公约数,然后用两个数的乘积除最大公约数即为最小公倍数。具体代码如下:

js /** * 求a和b的最大公约数 * @param {number} a * @param {number} b * @return {number} */ function gcd(a, b) { return b == 0 ? a : gcd(b, a % b) } /** * 求a和b的最小公倍数 * @param {number} a * @param {number} b * @return {number} */ function lcm(a, b) { return a * b / gcd(a, b) }

然后就是,我们能求出 x 以内丑数个数,那么我们就可以通过二分法,枚举答案去判断区间内的丑数是否大于等于 n 来求出答案。完整代码:

```js function gcd(a, b) { return b == 0 ? a : gcd(b, a % b) } function lcm(a, b) { return a * b / gcd(a, b) } /* * @param {number} n * @param {number} a * @param {number} b * @param {number} c * @return {number} / var nthUglyNumber = function(n, a, b, c) { // 求x以内的丑数个数 function calc(x) { // 包含一个集合的数目 let v1 = Math.floor(x / a) + Math.floor(x / b) + Math.floor(x / c); // 包含两个集合的 let v2 = Math.floor(x / lcm(a, b)) + Math.floor(x / lcm(b, c)) + Math.floor(x / lcm(a, c)); // 包含三个集合的 let v3 = x / lcm(lcm(a, b), c);

    return v1 - v2 + v3;
}
// 二分求解
let l = n, r = 2000000000, ans = l;
while (l <= r) {
    let m = Math.floor((l + r) / 2);
    if (calc(m) >= n) {
        ans = m;
        r = m - 1;
    } else {
        l = m + 1;
    }
}
return ans;

}; ```

参考资料