跳转至

3549. 两个多项式相乘 🔒

题目描述

给定两个整数数组 poly1 和 poly2,其中每个数组中下标 i 的元素表示多项式中 x^i 的系数。

设 A(x) 和 B(x) 分别是 poly1 和 poly2 表示的多项式。

返回一个长度为 (poly1.length + poly2.length - 1) 的整数数组 result 表示乘积多项式 R(x) = A(x) * B(x) 的系数,其中 result[i] 表示 R(x) 中 x^i 的系数。

 

示例 1:

输入:poly1 = [3,2,5], poly2 = [1,4]

输出:[3,14,13,20]

解释:

  • A(x) = 3 + 2x + 5x^2 且 B(x) = 1 + 4x
  • R(x) = (3 + 2x + 5x^2) * (1 + 4x)
  • R(x) = 3 * 1 + (3 * 4 + 2 * 1)x + (2 * 4 + 5 * 1)x^2 + (5 * 4)x^3
  • R(x) = 3 + 14x + 13x^2 + 20x^3
  • 因此,result = [3, 14, 13, 20]

示例 2:

输入:poly1 = [1,0,-2], poly2 = [-1]

输出:[-1,0,2]

解释:

  • A(x) = 1 + 0x - 2x^2 且 B(x) = -1
  • R(x) = (1 + 0x - 2x^2) * (-1)
  • R(x) = -1 + 0x + 2x^2
  • 因此,result = [-1, 0, 2]

示例 3:

输入:poly1 = [1,5,-3], poly2 = [-4,2,0]

输出:[-4,-18,22,-6,0]

解释:

  • A(x) = 1 + 5x - 3x^2 且 B(x) = -4 + 2x + 0x^2
  • R(x) = (1 + 5x - 3x^2) * (-4 + 2x + 0x^2)
  • R(x) = 1 * -4 + (1 * 2 + 5 * -4)x + (5 * 2 + -3 * -4)x^2 + (-3 * 2)x^3 + 0x^4
  • R(x) = -4 - 18x + 22x^2 - 6x^3 + 0x^4
  • 因此,result = [-4, -18, 22, -6, 0]

 

提示:

  • 1 <= poly1.length, poly2.length <= 5 * 10^4
  • -10^3 <= poly1[i], poly2[i] <= 10^3
  • poly1 与 poly2 至少包含一个非零系数。

解法

方法一:FFT

我们可以使用快速傅里叶变换(FFT)来高效地计算两个多项式的乘积。FFT 是一种高效的算法,可以在 \(O(n \log n)\) 的时间复杂度内计算多项式的乘积。

具体步骤如下:

  1. 补足长度 将结果长度 \(m = |A|+|B|-1\) 向上取最近的 2 的幂 \(n\),便于分治 FFT。
  2. FFT 变换 分别对两条系数序列做正变换(invert=False)。
  3. 逐点相乘 对应频域元素相乘。
  4. 逆 FFT 对乘积序列做逆变换(invert=True),并把实部四舍五入取整得到最终系数。

时间复杂度 \(O(n \log n)\),空间复杂度 \(O(n)\)。其中 \(n\) 是不小于 \(|A|+|B|-1\) 的最小的 2 的幂,即补零后的序列长度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class Solution:
    """Multiply two polynomials given as coefficient lists, using an FFT."""

    def multiply(self, poly1: List[int], poly2: List[int]) -> List[int]:
        """Return the coefficients of the product of ``poly1`` and ``poly2``.

        Index ``i`` of each list holds the coefficient of ``x**i``. The
        result has ``len(poly1) + len(poly2) - 1`` entries; an empty list is
        returned when either input is empty.
        """
        if not (poly1 and poly2):
            return []

        # Length of the product, and the power-of-two size the transform
        # works on.
        out_len = len(poly1) + len(poly2) - 1
        size = 1
        while size < out_len:
            size *= 2

        # Lift both coefficient lists to complex and zero-pad to `size`.
        spec_a = [complex(c) for c in poly1]
        spec_a.extend([0j] * (size - len(poly1)))
        spec_b = [complex(c) for c in poly2]
        spec_b.extend([0j] * (size - len(poly2)))

        # Forward transforms, pointwise product, inverse transform.
        self._fft(spec_a, invert=False)
        self._fft(spec_b, invert=False)
        product = [x * y for x, y in zip(spec_a, spec_b)]
        self._fft(product, invert=True)

        # Real parts are integers up to floating-point noise; round them.
        return [int(round(z.real)) for z in product[:out_len]]

    def _fft(self, a: List[complex], invert: bool) -> None:
        """Transform ``a`` in place.

        ``invert=True`` uses the opposite rotation direction and divides
        every entry by ``len(a)``. ``multiply`` always passes a list whose
        length is a power of two.
        """
        size = len(a)

        # Bit-reversal permutation: `rev` tracks the bit-reversed value of
        # `idx` by incrementing from the most significant bit downwards.
        rev = 0
        for idx in range(1, size):
            mask = size >> 1
            while rev & mask:
                rev ^= mask
                mask >>= 1
            rev ^= mask
            if idx < rev:
                a[idx], a[rev] = a[rev], a[idx]

        # Butterfly passes over blocks of width 2, 4, 8, ...
        direction = -1 if invert else 1
        span = 2
        while span <= size:
            theta = 2 * math.pi / span * direction
            step = complex(math.cos(theta), math.sin(theta))
            half = span // 2
            for start in range(0, size, span):
                twiddle = 1 + 0j
                for lo in range(start, start + half):
                    hi = lo + half
                    even = a[lo]
                    odd = a[hi] * twiddle
                    a[lo] = even + odd
                    a[hi] = even - odd
                    twiddle *= step
            span *= 2

        # The inverse transform is scaled by 1 / size.
        if invert:
            for idx in range(size):
                a[idx] /= size
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class Solution {
    /**
     * Multiplies two polynomials given as coefficient arrays, where index
     * {@code i} holds the coefficient of {@code x^i}.
     *
     * @param poly1 coefficients of the first polynomial
     * @param poly2 coefficients of the second polynomial
     * @return coefficients of the product, of length
     *         {@code poly1.length + poly2.length - 1}; an empty array when
     *         either input is {@code null} or empty
     */
    public long[] multiply(int[] poly1, int[] poly2) {
        boolean missing = poly1 == null || poly2 == null;
        if (missing || poly1.length == 0 || poly2.length == 0) {
            return new long[0];
        }

        final int resultLen = poly1.length + poly2.length - 1;
        // Smallest power of two that can hold the whole product.
        int size = 1;
        while (size < resultLen) {
            size *= 2;
        }

        Complex[] specA = lift(poly1, size);
        Complex[] specB = lift(poly2, size);

        fft(specA, false);
        fft(specB, false);
        for (int k = 0; k < size; k++) {
            specA[k] = specA[k].mul(specB[k]);
        }
        fft(specA, true);

        // Real parts are integers up to floating-point noise; round them.
        long[] coeffs = new long[resultLen];
        for (int k = 0; k < resultLen; k++) {
            coeffs[k] = Math.round(specA[k].re);
        }
        return coeffs;
    }

    /** Copies {@code coeffs} into a complex array of length {@code size}, zero-padded. */
    private static Complex[] lift(int[] coeffs, int size) {
        Complex[] out = new Complex[size];
        for (int k = 0; k < size; k++) {
            double value = k < coeffs.length ? coeffs[k] : 0;
            out[k] = new Complex(value, 0);
        }
        return out;
    }

    /**
     * Transforms {@code a} in place. With {@code invert} set, the rotation
     * direction is reversed and every entry is divided by {@code a.length}.
     * {@code multiply} always passes an array whose length is a power of two.
     */
    private static void fft(Complex[] a, boolean invert) {
        final int size = a.length;

        // Bit-reversal permutation: rev follows the bit-reversed value of idx.
        int rev = 0;
        for (int idx = 1; idx < size; idx++) {
            int mask = size >>> 1;
            while ((rev & mask) != 0) {
                rev ^= mask;
                mask >>>= 1;
            }
            rev ^= mask;
            if (idx < rev) {
                Complex held = a[idx];
                a[idx] = a[rev];
                a[rev] = held;
            }
        }

        // Butterfly passes over blocks of width 2, 4, 8, ...
        for (int span = 2; span <= size; span <<= 1) {
            final int half = span >>> 1;
            final double theta = 2 * Math.PI / span * (invert ? -1 : 1);
            final Complex step = new Complex(Math.cos(theta), Math.sin(theta));

            for (int base = 0; base < size; base += span) {
                Complex twiddle = new Complex(1, 0);
                for (int lo = base; lo < base + half; lo++) {
                    int hi = lo + half;
                    Complex even = a[lo];
                    Complex odd = a[hi].mul(twiddle);
                    a[lo] = even.add(odd);
                    a[hi] = even.sub(odd);
                    twiddle = twiddle.mul(step);
                }
            }
        }

        // The inverse transform is scaled by 1 / size.
        if (invert) {
            for (int k = 0; k < size; k++) {
                a[k] = new Complex(a[k].re / size, a[k].im / size);
            }
        }
    }

    /** Minimal complex number with the three operations the transform needs. */
    private static final class Complex {
        double re, im;

        Complex(double re, double im) {
            this.re = re;
            this.im = im;
        }

        /** Returns {@code this + o}. */
        Complex add(Complex o) {
            return new Complex(re + o.re, im + o.im);
        }

        /** Returns {@code this - o}. */
        Complex sub(Complex o) {
            return new Complex(re - o.re, im - o.im);
        }

        /** Returns {@code this * o}. */
        Complex mul(Complex o) {
            double real = re * o.re - im * o.im;
            double imag = re * o.im + im * o.re;
            return new Complex(real, imag);
        }
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class Solution {
    using cd = complex<double>;

    // Transforms `a` in place with an iterative FFT. When `invert` is true the
    // rotation direction is reversed and every entry is divided by a.size().
    // multiply() always passes a vector whose size is a power of two.
    void fft(vector<cd>& a, bool invert) {
        const int n = static_cast<int>(a.size());
        // M_PI is a POSIX extension, not ISO C++ (it is absent on MSVC unless
        // _USE_MATH_DEFINES is set), so derive pi portably instead.
        const double pi = std::acos(-1.0);

        // Bit-reversal permutation: j follows the bit-reversed value of i.
        for (int i = 1, j = 0; i < n; ++i) {
            int bit = n >> 1;
            for (; j & bit; bit >>= 1) j ^= bit;
            j ^= bit;
            if (i < j) swap(a[i], a[j]);
        }

        // Butterfly passes over blocks of width 2, 4, 8, ...
        for (int len = 2; len <= n; len <<= 1) {
            const double ang = 2 * pi / len * (invert ? -1 : 1);
            const cd wlen(cos(ang), sin(ang));
            const int half = len >> 1;
            for (int i = 0; i < n; i += len) {
                cd w(1, 0);
                for (int j = 0; j < half; ++j) {
                    cd u = a[i + j];
                    cd v = a[i + j + half] * w;
                    a[i + j] = u + v;
                    a[i + j + half] = u - v;
                    w *= wlen;
                }
            }
        }

        // The inverse transform is scaled by 1 / n.
        if (invert)
            for (cd& x : a) x /= n;
    }

public:
    // Returns the coefficients of poly1 * poly2, where index i of each vector
    // holds the coefficient of x^i. The result has
    // poly1.size() + poly2.size() - 1 entries; it is empty when either input
    // is empty.
    vector<long long> multiply(vector<int>& poly1, vector<int>& poly2) {
        if (poly1.empty() || poly2.empty()) return {};

        // Explicit casts avoid signed/unsigned comparisons in the loops below.
        const int len1 = static_cast<int>(poly1.size());
        const int len2 = static_cast<int>(poly2.size());
        const int m = len1 + len2 - 1;

        // Smallest power of two that can hold the whole product.
        int n = 1;
        while (n < m) n <<= 1;

        // Vectors are value-initialised to zero, so only the leading
        // coefficients need to be filled in.
        vector<cd> fa(n), fb(n);
        for (int i = 0; i < len1; ++i) fa[i] = cd(poly1[i], 0);
        for (int i = 0; i < len2; ++i) fb[i] = cd(poly2[i], 0);

        fft(fa, false);
        fft(fb, false);
        for (int i = 0; i < n; ++i) fa[i] *= fb[i];
        fft(fa, true);

        // Real parts are integers up to floating-point noise; round them.
        vector<long long> res(m);
        for (int i = 0; i < m; ++i) res[i] = llround(fa[i].real());
        return res;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// multiply returns the coefficients of poly1 * poly2, where index i of each
// slice holds the coefficient of x^i. The result has
// len(poly1)+len(poly2)-1 entries; it is empty when either input is empty.
func multiply(poly1 []int, poly2 []int) []int64 {
    len1, len2 := len(poly1), len(poly2)
    if len1 == 0 || len2 == 0 {
        return []int64{}
    }

    // Smallest power of two that can hold the whole product.
    outLen := len1 + len2 - 1
    size := 1
    for size < outLen {
        size *= 2
    }

    // make zero-fills the slices, so only the leading coefficients are set.
    specA := make([]complex128, size)
    for i, c := range poly1 {
        specA[i] = complex(float64(c), 0)
    }
    specB := make([]complex128, size)
    for i, c := range poly2 {
        specB[i] = complex(float64(c), 0)
    }

    // Forward transforms, pointwise product, inverse transform.
    fft(specA, false)
    fft(specB, false)
    for i := range specA {
        specA[i] *= specB[i]
    }
    fft(specA, true)

    // Real parts are integers up to floating-point noise; round them.
    coeffs := make([]int64, outLen)
    for i := range coeffs {
        coeffs[i] = int64(math.Round(real(specA[i])))
    }
    return coeffs
}

func fft(a []complex128, invert bool) {
    n := len(a)
    for i, j := 1, 0; i < n; i++ {
        bit := n >> 1
        for ; j&bit != 0; bit >>= 1 {
            j ^= bit
        }
        j ^= bit
        if i < j {
            a[i], a[j] = a[j], a[i]
        }
    }

    for length := 2; length <= n; length <<= 1 {
        angle := 2 * math.Pi / float64(length)
        if invert {
            angle = -angle
        }
        wlen := cmplx.Rect(1, angle)
        for i := 0; i < n; i += length {
            w := complex(1, 0)
            half := length >> 1
            for j := 0; j < half; j++ {
                u := a[i+j]
                v := a[i+j+half] * w
                a[i+j] = u + v
                a[i+j+half] = u - v
                w *= wlen
            }
        }
    }

    if invert {
        for i := range a {
            a[i] /= complex(float64(n), 0)
        }
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/**
 * Returns the coefficients of poly1 * poly2, where index i of each array holds
 * the coefficient of x^i. The result has poly1.length + poly2.length - 1
 * entries; it is empty when either input is empty.
 *
 * When the shorter input has at most 64 coefficients the product is computed
 * directly; otherwise it goes through the FFT.
 */
export function multiply(poly1: number[], poly2: number[]): number[] {
    const lenA = poly1.length;
    const lenB = poly2.length;
    if (lenA === 0 || lenB === 0) {
        return [];
    }

    const outLen = lenA + lenB - 1;

    // Schoolbook multiplication for short operands.
    if (Math.min(lenA, lenB) <= 64) {
        const acc = new Array<number>(outLen).fill(0);
        for (let a = 0; a < lenA; ++a) {
            for (let b = 0; b < lenB; ++b) {
                acc[a + b] += poly1[a] * poly2[b];
            }
        }
        const rounded: number[] = [];
        for (const value of acc) {
            rounded.push(Math.round(value));
        }
        return rounded;
    }

    // Smallest power of two that can hold the whole product.
    let size = 1;
    while (size < outLen) {
        size <<= 1;
    }

    // Typed arrays start zero-filled, so only the real parts of the leading
    // coefficients need to be written.
    const aRe = new Float64Array(size);
    const aIm = new Float64Array(size);
    const bRe = new Float64Array(size);
    const bIm = new Float64Array(size);
    for (let k = 0; k < lenA; ++k) {
        aRe[k] = poly1[k];
    }
    for (let k = 0; k < lenB; ++k) {
        bRe[k] = poly2[k];
    }

    fft(aRe, aIm, false);
    fft(bRe, bIm, false);

    // Pointwise complex product, stored back into the first operand.
    for (let k = 0; k < size; ++k) {
        const xr = aRe[k];
        const xi = aIm[k];
        const yr = bRe[k];
        const yi = bIm[k];
        aRe[k] = xr * yr - xi * yi;
        aIm[k] = xr * yi + xi * yr;
    }

    fft(aRe, aIm, true);

    // Real parts are integers up to floating-point noise; round them.
    const coeffs = new Array<number>(outLen);
    for (let k = 0; k < outLen; ++k) {
        coeffs[k] = Math.round(aRe[k]);
    }
    return coeffs;
}

/**
 * Transforms the complex sequence held in `re` / `im` in place. With `invert`
 * set, the rotation direction is reversed and every entry is divided by the
 * length. The bit-reversal step assumes the length is a power of two.
 */
function fft(re: Float64Array, im: Float64Array, invert: boolean): void {
    const size = re.length;

    // Bit-reversal permutation: rev follows the bit-reversed value of idx.
    let rev = 0;
    for (let idx = 1; idx < size; ++idx) {
        let mask = size >> 1;
        while (rev & mask) {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if (idx < rev) {
            const heldRe = re[idx];
            re[idx] = re[rev];
            re[rev] = heldRe;
            const heldIm = im[idx];
            im[idx] = im[rev];
            im[rev] = heldIm;
        }
    }

    // Butterfly passes over blocks of width 2, 4, 8, ...
    for (let span = 2; span <= size; span <<= 1) {
        const half = span >> 1;
        const theta = ((2 * Math.PI) / span) * (invert ? -1 : 1);
        const stepRe = Math.cos(theta);
        const stepIm = Math.sin(theta);

        for (let base = 0; base < size; base += span) {
            let twRe = 1;
            let twIm = 0;
            for (let lo = base; lo < base + half; ++lo) {
                const hi = lo + half;
                const evenRe = re[lo];
                const evenIm = im[lo];
                const rawRe = re[hi];
                const rawIm = im[hi];
                // Rotate the upper element by the current twiddle factor.
                const oddRe = rawRe * twRe - rawIm * twIm;
                const oddIm = rawRe * twIm + rawIm * twRe;
                re[lo] = evenRe + oddRe;
                im[lo] = evenIm + oddIm;
                re[hi] = evenRe - oddRe;
                im[hi] = evenIm - oddIm;
                // Advance the twiddle factor by one step.
                const advancedRe = twRe * stepRe - twIm * stepIm;
                twIm = twRe * stepIm + twIm * stepRe;
                twRe = advancedRe;
            }
        }
    }

    // The inverse transform is scaled by 1 / size.
    if (invert) {
        for (let k = 0; k < size; ++k) {
            re[k] /= size;
            im[k] /= size;
        }
    }
}

评论