
题目描述
给定两个整数数组 poly1 和 poly2,其中每个数组中下标 i 的元素表示多项式中 \(x^i\) 的系数。
设 A(x) 和 B(x) 分别是 poly1 和 poly2 表示的多项式。
返回一个长度为 (poly1.length + poly2.length - 1) 的整数数组 result,表示乘积多项式 R(x) = A(x) * B(x) 的系数,其中 result[i] 表示 R(x) 中 \(x^i\) 的系数。
示例 1:
输入:poly1 = [3,2,5], poly2 = [1,4]
输出:[3,14,13,20]
解释:
A(x) = 3 + 2x + 5x^2 且 B(x) = 1 + 4x
R(x) = (3 + 2x + 5x^2) * (1 + 4x)
R(x) = 3 * 1 + (3 * 4 + 2 * 1)x + (2 * 4 + 5 * 1)x^2 + (5 * 4)x^3
R(x) = 3 + 14x + 13x^2 + 20x^3
因此,result = [3, 14, 13, 20]。
示例 2:
输入:poly1 = [1,0,-2], poly2 = [-1]
输出:[-1,0,2]
解释:
A(x) = 1 + 0x - 2x^2 且 B(x) = -1
R(x) = (1 + 0x - 2x^2) * (-1)
R(x) = -1 + 0x + 2x^2
因此,result = [-1, 0, 2]。
示例 3:
输入:poly1 = [1,5,-3], poly2 = [-4,2,0]
输出:[-4,-18,22,-6,0]
解释:
A(x) = 1 + 5x - 3x^2 且 B(x) = -4 + 2x + 0x^2
R(x) = (1 + 5x - 3x^2) * (-4 + 2x + 0x^2)
R(x) = 1 * -4 + (1 * 2 + 5 * -4)x + (5 * 2 + -3 * -4)x^2 + (-3 * 2)x^3 + 0x^4
R(x) = -4 - 18x + 22x^2 - 6x^3 + 0x^4
因此,result = [-4, -18, 22, -6, 0]。
提示:
1 <= poly1.length, poly2.length <= 5 * 10^4
-10^3 <= poly1[i], poly2[i] <= 10^3
poly1 与 poly2 至少包含一个非零系数。
解法
方法一:FFT
我们可以使用快速傅里叶变换(FFT)来高效地计算两个多项式的乘积。FFT 是一种高效的算法,可以在 \(O(n \log n)\) 的时间复杂度内计算多项式的乘积。
具体步骤如下:
- 补足长度:将结果长度 \(m = |A|+|B|-1\) 向上取最近的 2 的幂 \(n\),便于分治 FFT。
- FFT 变换:分别对两条系数序列做正变换(invert=False)。
- 逐点相乘:对应频域元素相乘。
- 逆 FFT:对乘积序列做逆变换(invert=True),并把实部四舍五入取整得到最终系数。
时间复杂度 \(O(n \log n)\),空间复杂度 \(O(n)\)。其中 \(n\) 是不小于结果长度 \(m = |A|+|B|-1\) 的最小的 2 的幂。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class Solution:
    """Multiply two polynomials given as coefficient lists, using an iterative FFT."""

    def multiply(self, poly1: List[int], poly2: List[int]) -> List[int]:
        """Return the coefficients of the product of two polynomials.

        Args:
            poly1: Coefficients of the first polynomial; index i holds the
                coefficient of x**i.
            poly2: Coefficients of the second polynomial, same layout.

        Returns:
            A list of ``len(poly1) + len(poly2) - 1`` integers, or ``[]`` when
            either input is empty.
        """
        if not poly1 or not poly2:
            return []

        # 1. Length of the product, padded up to the next power of two so the
        #    radix-2 transform can be applied.
        m = len(poly1) + len(poly2) - 1
        n = 1
        while n < m:
            n <<= 1

        # 2. Zero-pad both inputs to length n.
        fa = [complex(c) for c in poly1] + [0j] * (n - len(poly1))
        fb = [complex(c) for c in poly2] + [0j] * (n - len(poly2))

        # 3. Forward transforms.
        self._fft(fa, invert=False)
        self._fft(fb, invert=False)

        # 4. Point-wise product in the frequency domain.
        for i in range(n):
            fa[i] *= fb[i]

        # 5. Inverse transform; the real parts, rounded, are the coefficients.
        self._fft(fa, invert=True)
        return [round(fa[i].real) for i in range(m)]

    def _fft(self, a: List[complex], invert: bool) -> None:
        """Transform ``a`` in place with an iterative radix-2 FFT.

        Args:
            a: Sequence to transform; its length must be a power of two.
            invert: ``False`` for the forward transform, ``True`` for the
                inverse transform (which also divides every element by n).
        """
        n = len(a)

        # Bit-reversal permutation: j tracks the bit-reversed value of i.
        j = 0
        for i in range(1, n):
            bit = n >> 1
            while j & bit:
                j ^= bit
                bit >>= 1
            j ^= bit
            if i < j:
                a[i], a[j] = a[j], a[i]

        sign = -1.0 if invert else 1.0
        size = 2
        while size <= n:
            half = size >> 1
            step = sign * 2.0 * math.pi / size
            # Each root is computed directly from cos/sin instead of by
            # repeated multiplication (w *= wlen), so rounding error does not
            # accumulate across the up-to-n/2 butterflies of a stage.
            roots = [
                complex(math.cos(step * k), math.sin(step * k))
                for k in range(half)
            ]
            for start in range(0, n, size):
                for k in range(half):
                    lo = start + k
                    hi = lo + half
                    u = a[lo]
                    v = a[hi] * roots[k]
                    a[lo] = u + v
                    a[hi] = u - v
            size <<= 1

        # The inverse transform carries a 1/n normalisation.
        if invert:
            for i in range(n):
                a[i] /= n
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92 | class Solution {
public long[] multiply(int[] poly1, int[] poly2) {
if (poly1 == null || poly2 == null || poly1.length == 0 || poly2.length == 0) {
return new long[0];
}
int m = poly1.length + poly2.length - 1;
int n = 1;
while (n < m) n <<= 1;
Complex[] fa = new Complex[n];
Complex[] fb = new Complex[n];
for (int i = 0; i < n; i++) {
fa[i] = new Complex(i < poly1.length ? poly1[i] : 0, 0);
fb[i] = new Complex(i < poly2.length ? poly2[i] : 0, 0);
}
fft(fa, false);
fft(fb, false);
for (int i = 0; i < n; i++) {
fa[i] = fa[i].mul(fb[i]);
}
fft(fa, true);
long[] res = new long[m];
for (int i = 0; i < m; i++) {
res[i] = Math.round(fa[i].re);
}
return res;
}
private static void fft(Complex[] a, boolean invert) {
int n = a.length;
for (int i = 1, j = 0; i < n; i++) {
int bit = n >>> 1;
while ((j & bit) != 0) {
j ^= bit;
bit >>>= 1;
}
j ^= bit;
if (i < j) {
Complex tmp = a[i];
a[i] = a[j];
a[j] = tmp;
}
}
for (int len = 2; len <= n; len <<= 1) {
double ang = 2 * Math.PI / len * (invert ? -1 : 1);
Complex wlen = new Complex(Math.cos(ang), Math.sin(ang));
for (int i = 0; i < n; i += len) {
Complex w = new Complex(1, 0);
int half = len >>> 1;
for (int j = 0; j < half; j++) {
Complex u = a[i + j];
Complex v = a[i + j + half].mul(w);
a[i + j] = u.add(v);
a[i + j + half] = u.sub(v);
w = w.mul(wlen);
}
}
}
if (invert) {
for (int i = 0; i < n; i++) {
a[i].re /= n;
a[i].im /= n;
}
}
}
private static final class Complex {
double re, im;
Complex(double re, double im) {
this.re = re;
this.im = im;
}
Complex add(Complex o) {
return new Complex(re + o.re, im + o.im);
}
Complex sub(Complex o) {
return new Complex(re - o.re, im - o.im);
}
Complex mul(Complex o) {
return new Complex(re * o.re - im * o.im, re * o.im + im * o.re);
}
}
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class Solution {
    using cd = std::complex<double>;

    /// @brief In-place iterative radix-2 FFT.
    /// @param a      Sequence to transform; its size must be a power of two.
    /// @param invert false for the forward transform; true for the inverse
    ///               transform, which also divides every element by the size.
    static void fft(std::vector<cd>& a, bool invert) {
        const int n = static_cast<int>(a.size());

        // Bit-reversal permutation: j tracks the bit-reversed value of i.
        for (int i = 1, j = 0; i < n; ++i) {
            int bit = n >> 1;
            for (; j & bit; bit >>= 1) j ^= bit;
            j ^= bit;
            if (i < j) std::swap(a[i], a[j]);
        }

        // acos(-1) instead of M_PI: the macro is not part of standard C++.
        const double pi = std::acos(-1.0);
        const double sign = invert ? -1.0 : 1.0;

        std::vector<cd> roots;
        roots.reserve(a.size() / 2);
        for (int len = 2; len <= n; len <<= 1) {
            const int half = len >> 1;
            const double step = sign * 2.0 * pi / len;
            // Roots are computed directly rather than by repeated multiplication
            // (w *= wlen), so rounding error does not accumulate within a stage.
            roots.clear();
            for (int k = 0; k < half; ++k) roots.emplace_back(std::polar(1.0, step * k));

            for (int start = 0; start < n; start += len) {
                for (int k = 0; k < half; ++k) {
                    const cd u = a[start + k];
                    const cd v = a[start + k + half] * roots[k];
                    a[start + k] = u + v;
                    a[start + k + half] = u - v;
                }
            }
        }

        // The inverse transform carries a 1/n normalisation.
        if (invert) {
            for (cd& x : a) x /= static_cast<double>(n);
        }
    }

public:
    /// @brief Multiplies two polynomials given as coefficient vectors.
    /// @param poly1 Coefficients of the first polynomial; index i holds the coefficient of x^i.
    /// @param poly2 Coefficients of the second polynomial, same layout.
    /// @return Coefficients of the product (poly1.size() + poly2.size() - 1 entries),
    ///         or an empty vector when either input is empty.
    std::vector<long long> multiply(std::vector<int>& poly1, std::vector<int>& poly2) {
        if (poly1.empty() || poly2.empty()) return {};

        const int len1 = static_cast<int>(poly1.size());
        const int len2 = static_cast<int>(poly2.size());
        const int m = len1 + len2 - 1;

        // Pad to the next power of two so the radix-2 transform applies.
        int n = 1;
        while (n < m) n <<= 1;

        // Value-initialised to zero, so only the leading coefficients need filling.
        std::vector<cd> fa(n), fb(n);
        for (int i = 0; i < len1; ++i) fa[i] = cd(poly1[i], 0.0);
        for (int i = 0; i < len2; ++i) fb[i] = cd(poly2[i], 0.0);

        fft(fa, false);
        fft(fb, false);
        for (int i = 0; i < n; ++i) fa[i] *= fb[i];  // point-wise product
        fft(fa, true);

        std::vector<long long> res(m);
        for (int i = 0; i < m; ++i) res[i] = std::llround(fa[i].real());
        return res;
    }
};
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72 | func multiply(poly1 []int, poly2 []int) []int64 {
if len(poly1) == 0 || len(poly2) == 0 {
return []int64{}
}
m := len(poly1) + len(poly2) - 1
n := 1
for n < m {
n <<= 1
}
fa := make([]complex128, n)
fb := make([]complex128, n)
for i := 0; i < len(poly1); i++ {
fa[i] = complex(float64(poly1[i]), 0)
}
for i := 0; i < len(poly2); i++ {
fb[i] = complex(float64(poly2[i]), 0)
}
fft(fa, false)
fft(fb, false)
for i := 0; i < n; i++ {
fa[i] *= fb[i]
}
fft(fa, true)
res := make([]int64, m)
for i := 0; i < m; i++ {
res[i] = int64(math.Round(real(fa[i])))
}
return res
}
func fft(a []complex128, invert bool) {
n := len(a)
for i, j := 1, 0; i < n; i++ {
bit := n >> 1
for ; j&bit != 0; bit >>= 1 {
j ^= bit
}
j ^= bit
if i < j {
a[i], a[j] = a[j], a[i]
}
}
for length := 2; length <= n; length <<= 1 {
angle := 2 * math.Pi / float64(length)
if invert {
angle = -angle
}
wlen := cmplx.Rect(1, angle)
for i := 0; i < n; i += length {
w := complex(1, 0)
half := length >> 1
for j := 0; j < half; j++ {
u := a[i+j]
v := a[i+j+half] * w
a[i+j] = u + v
a[i+j+half] = u - v
w *= wlen
}
}
}
if invert {
for i := range a {
a[i] /= complex(float64(n), 0)
}
}
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/**
 * Multiplies two polynomials given as coefficient arrays.
 *
 * @param poly1 Coefficients of the first polynomial; index i holds the coefficient of x^i.
 * @param poly2 Coefficients of the second polynomial, same layout.
 * @returns The product's coefficients (`poly1.length + poly2.length - 1` entries),
 *          or `[]` when either input is empty.
 */
export function multiply(poly1: number[], poly2: number[]): number[] {
    const n1 = poly1.length;
    const n2 = poly2.length;
    if (!n1 || !n2) return [];
    const m = n1 + n2 - 1;

    // When the shorter operand has at most 64 terms, use the direct
    // schoolbook product instead of the FFT.
    if (Math.min(n1, n2) <= 64) {
        const res = new Array<number>(m).fill(0);
        for (let i = 0; i < n1; ++i) {
            for (let j = 0; j < n2; ++j) res[i + j] += poly1[i] * poly2[j];
        }
        return res.map(v => Math.round(v));
    }

    // Pad to the next power of two so the radix-2 transform applies.
    let n = 1;
    while (n < m) n <<= 1;

    // Typed arrays are zero-filled, so only the leading coefficients are written.
    const reA = new Float64Array(n);
    const imA = new Float64Array(n);
    const reB = new Float64Array(n);
    const imB = new Float64Array(n);
    for (let i = 0; i < n1; ++i) reA[i] = poly1[i];
    for (let i = 0; i < n2; ++i) reB[i] = poly2[i];

    fft(reA, imA, false);
    fft(reB, imB, false);

    // Point-wise complex product in the frequency domain.
    for (let i = 0; i < n; ++i) {
        const a = reA[i];
        const b = imA[i];
        const c = reB[i];
        const d = imB[i];
        reA[i] = a * c - b * d;
        imA[i] = a * d + b * c;
    }

    fft(reA, imA, true);

    const out = new Array<number>(m);
    for (let i = 0; i < m; ++i) {
        // Math.round of a tiny negative value yields -0; "+ 0" normalises it to +0.
        out[i] = Math.round(reA[i]) + 0;
    }
    return out;
}

/**
 * In-place iterative radix-2 FFT over the complex sequence (re[i], im[i]).
 *
 * @param re     Real parts; length must be a power of two.
 * @param im     Imaginary parts; same length as `re`.
 * @param invert `false` for the forward transform, `true` for the inverse
 *               transform (which also divides every element by n).
 */
function fft(re: Float64Array, im: Float64Array, invert: boolean): void {
    const n = re.length;

    // Bit-reversal permutation: j tracks the bit-reversed value of i.
    for (let i = 1, j = 0; i < n; ++i) {
        let bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) {
            const tRe = re[i];
            re[i] = re[j];
            re[j] = tRe;
            const tIm = im[i];
            im[i] = im[j];
            im[j] = tIm;
        }
    }

    const sign = invert ? -1 : 1;
    const wRe = new Float64Array(n >> 1);
    const wIm = new Float64Array(n >> 1);
    for (let len = 2; len <= n; len <<= 1) {
        const half = len >> 1;
        const step = (sign * 2 * Math.PI) / len;
        // Roots are computed directly from cos/sin rather than by repeated
        // multiplication, so rounding error does not accumulate within a stage.
        for (let k = 0; k < half; ++k) {
            wRe[k] = Math.cos(step * k);
            wIm[k] = Math.sin(step * k);
        }
        for (let start = 0; start < n; start += len) {
            for (let k = 0; k < half; ++k) {
                const lo = start + k;
                const hi = lo + half;
                const vRe = re[hi] * wRe[k] - im[hi] * wIm[k];
                const vIm = re[hi] * wIm[k] + im[hi] * wRe[k];
                const uRe = re[lo];
                const uIm = im[lo];
                re[lo] = uRe + vRe;
                im[lo] = uIm + vIm;
                re[hi] = uRe - vRe;
                im[hi] = uIm - vIm;
            }
        }
    }

    // The inverse transform carries a 1/n normalisation.
    if (invert) {
        for (let i = 0; i < n; ++i) {
            re[i] /= n;
            im[i] /= n;
        }
    }
}
|