跳转至

拉格朗日插值

例题 Luogu P4781【模板】拉格朗日插值

给出 \(n\) 个点对 \((x_i,y_i)\)\(k\),且 \(\forall i,j\)\(i\neq j \iff x_i\neq x_j\)\(f(x_i)\equiv y_i\pmod{998244353}\)\(\deg(f(x))<n\)(定义 \(\deg(0)=-\infty\)),求 \(f(k)\bmod{998244353}\)

方法 1:差分法

差分法适用于 \(x_i=i\) 的情况。

如,用差分法求某三次多项式 \(f(x)=\sum_{i=0}^{3} a_ix^i\) 的多项式形式,已知 \(f(1)\)\(f(6)\) 的值分别为 \(1, 5, 14, 30, 55, 91\)

\[ \begin{array}{cccccccccccc} 1 & & 5 & & 14 & & 30 & & 55 & & 91 & \\ & 4 & & 9 & & 16 & & 25 & & 36 & \\ & & 5 & & 7 & & 9 & & 11 & \\ & & & 2 & & 2 & & 2 & \\ \end{array} \]

第一行为 \(f(x)\) 的连续的前 \(n\) 项;之后的每一行为之前一行中对应的相邻两项之差。观察到,如果这样操作的次数足够多(前提是 \(f(x)\) 为多项式),最终总会返回一个定值。

计算出第 \(i-1\) 阶差分的首项为 \(\sum_{j=1}^{i}(-1)^{i+j}\binom{i-1}{j-1}f(j)\),第 \(i-1\) 阶差分的首项对 \(f(k)\) 的贡献为 \(\binom{k-1}{i-1}\) 次。

\[ f(k)=\sum_{i=1}^n\binom{k-1}{i-1}\sum_{j=1}^{i}(-1)^{i+j}\binom{i-1}{j-1}f(j) \]

时间复杂度为 \(O(n^2)\)。这种方法对给出的点的限制性较强。

方法 2:待定系数法

\(f(x)=\sum_{i=0}^{n-1} a_ix^i\) 将每个 \(x_i\) 代入 \(f(x)\),有 \(f(x_i)=y_i\),这样就可以得到一个由 \(n\)\(n\) 元一次方程所组成的方程组,然后使用 高斯消元 解该方程组求出每一项 \(a_i\),即确定了 \(f(x)\) 的表达式。

如果您不知道什么是高斯消元,请看 高斯消元

时间复杂度 \(O(n^3)\),对给出点的坐标无要求。

方法 3:拉格朗日插值法

多项式部分简介 里我们已经定义了多项式除法。

那么我们会有:

\[ f(x)\equiv f(a)\pmod{(x-a)} \]

因为 \(f(x)-f(a)=(a_0-a_0)+a_1(x^1-a^1)+a_1(x^2-a^2)+\cdots +a_n(x^n-a^n)\),显然有 \((x-a)\) 这个因式。

这样我们就可以列一个关于 \(f(x)\) 的多项式线性同余方程组:

\[ \begin{cases} f(x)\equiv y_1\pmod{(x-x_1)}\\ f(x)\equiv y_2\pmod{(x-x_2)}\\ \vdots\\ f(x)\equiv y_n\pmod{(x-x_n)} \end{cases} \]

\[ \begin{aligned} M(x)&=\prod_{i=1}^n{(x-x_i)},\\ m_i(x)&=\dfrac M{x-x_i} \end{aligned} \]

\(m_i(x)\) 在模 \((x-x_i)\) 意义下的乘法逆元为

\[ m_i(x_i)^{-1}=\prod_{j\ne i}{(x_i-x_j)^{-1}} \]

\[ \begin{aligned} f(x)&\equiv\sum_{i=1}^n{y_i\left(m_i(x)\right)\left(m_i(x_i)^{-1}\right)}&\pmod{M(x)}\\ &\equiv\sum_{i=1}^n{y_i\prod_{j\ne i}{\dfrac {x-x_j}{x_i-x_j}}}&\pmod{M(x)} \end{aligned} \]

又因为 \(\deg\left(f(x)\right)<n\) 所以在模 \(M(x)\) 意义下 \(f(x)\) 就是唯一的,即:

\[ f(x)=\sum_{i=1}^n{y_i\prod_{j\ne i}{\dfrac {x-x_j}{x_i-x_j}}} \]

这就是拉格朗日插值的表达式。

通常意义下拉格朗日插值的一种推导

由于要求构造一个函数 \(f(x)\) 过点 \(P_1(x_1, y_1), P_2(x_2,y_2),\cdots,P_n(x_n,y_n)\)。首先设第 \(i\) 个点在 \(x\) 轴上的投影为 \(P_i^{\prime}(x_i,0)\)

考虑构造 \(n\) 个函数 \(f_1(x), f_2(x), \cdots, f_n(x)\),使得对于第 \(i\) 个函数 \(f_i(x)\),其图像过 \(\begin{cases}P_j^{\prime}(x_j,0),(j\neq i)\\P_i(x_i,y_i)\end{cases}\),则可知题目所求的函数 \(f(x)=\sum\limits_{i=1}^nf_i(x)\)

那么可以设 \(f_i(x)=a\cdot\prod_{j\neq i}(x-x_j)\),将点 \(P_i(x_i,y_i)\) 代入可以知道 \(a=\dfrac{y_i}{\prod_{j\neq i} (x_i-x_j)}\),所以

\(f_i(x)=y_i\cdot\dfrac{\prod_{j\neq i} (x-x_j)}{\prod_{j\neq i} (x_i-x_j)}=y_i\cdot\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j}\)

那么我们就可以从另一个角度推导出通常意义下(而非模意义下)拉格朗日插值的式子为:

\(f(x)=\sum_{i=1}^ny_i\cdot\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j}\)

代码实现

因为在固定模 \(998244353\) 意义下运算,计算乘法逆元的时间复杂度我们在这里暂且认为是常数时间。

#include <exception>
#include <iostream>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>

template <unsigned int Mod>
class Fp {
  static_assert(static_cast<int>(Mod) > 1);

 public:
  Fp() : v_() {}

  Fp(int v) : v_(safe_mod(v)) {}

  static unsigned int safe_mod(int v) {
    v %= static_cast<int>(Mod);
    return v < 0 ? v + static_cast<int>(Mod) : v;
  }

  unsigned int value() const { return v_; }

  Fp operator-() const { return Fp(Mod - v_); }

  Fp pow(int e) const {
    if (e < 0) return inv().pow(-e);
    for (Fp x(*this), res(1);; x *= x) {
      if (e & 1) res *= x;
      if ((e >>= 1) == 0) return res;
    }
  }

  Fp inv() const {
    int x1 = 1, x3 = 0, a = v_, b = Mod;
    while (b != 0) {
      int q = a / b, x1_old = x1, a_old = a;
      x1 = x3, x3 = x1_old - x3 * q, a = b, b = a_old - b * q;
    }
    return Fp(x1);
  }

  Fp &operator+=(const Fp &rhs) {
    if ((v_ += rhs.v_) >= Mod) v_ -= Mod;
    return *this;
  }

  Fp &operator-=(const Fp &rhs) {
    if ((v_ += Mod - rhs.v_) >= Mod) v_ -= Mod;
    return *this;
  }

  Fp &operator*=(const Fp &rhs) {
    v_ = static_cast<unsigned long long>(v_) * rhs.v_ % Mod;
    return *this;
  }

  Fp &operator/=(const Fp &rhs) { return operator*=(rhs.inv()); }

  void swap(Fp &rhs) {
    unsigned int v = v_;
    v_ = rhs.v_, rhs.v_ = v;
  }

  friend Fp operator+(const Fp &lhs, const Fp &rhs) { return Fp(lhs) += rhs; }

  friend Fp operator-(const Fp &lhs, const Fp &rhs) { return Fp(lhs) -= rhs; }

  friend Fp operator*(const Fp &lhs, const Fp &rhs) { return Fp(lhs) *= rhs; }

  friend Fp operator/(const Fp &lhs, const Fp &rhs) { return Fp(lhs) /= rhs; }

  friend bool operator==(const Fp &lhs, const Fp &rhs) {
    return lhs.v_ == rhs.v_;
  }

  friend bool operator!=(const Fp &lhs, const Fp &rhs) {
    return lhs.v_ != rhs.v_;
  }

  friend std::istream &operator>>(std::istream &lhs, Fp &rhs) {
    int v;
    lhs >> v;
    rhs = Fp(v);
    return lhs;
  }

  friend std::ostream &operator<<(std::ostream &lhs, const Fp &rhs) {
    return lhs << rhs.v_;
  }

 private:
  unsigned int v_;
};

template <typename T>
class Poly : public std::vector<T> {
 public:
  using std::vector<T>::vector;  // 使用继承的构造函数

  bool is_zero() const { return deg() == -1; }

  void shrink() { this->resize(std::max(deg() + 1, 1)); }

  int deg()
      const {  // 多项式的次数,当多项式为零时度数为 -1 而不是一般定义的负无穷
    int d = static_cast<int>(this->size()) - 1;
    const T z;
    while (d >= 0 && this->operator[](d) == z) --d;
    return d;
  }

  T leading_coeff() const {
    int d = deg();
    return d == -1 ? T() : this->operator[](d);
  }

  Poly operator-() const {
    Poly res;
    res.reserve(this->size());
    for (auto &&i : *this) res.emplace_back(-i);
    res.shrink();
    return res;
  }

  Poly &operator+=(const Poly &rhs) {
    if (this->size() < rhs.size()) this->resize(rhs.size());
    for (int i = 0, e = static_cast<int>(rhs.size()); i != e; ++i)
      this->operator[](i) += rhs[i];
    shrink();
    return *this;
  }

  Poly &operator-=(const Poly &rhs) {
    if (this->size() < rhs.size()) this->resize(rhs.size());
    for (int i = 0, e = static_cast<int>(rhs.size()); i != e; ++i)
      this->operator[](i) -= rhs[i];
    shrink();
    return *this;
  }

  Poly &operator*=(const Poly &rhs) {
    int n = deg(), m = rhs.deg();
    if (n == -1 || m == -1) return operator=(Poly{0});
    Poly res(n + m + 1);
    for (int i = 0; i <= n; ++i)
      for (int j = 0; j <= m; ++j) res[i + j] += this->operator[](i) * rhs[j];
    return operator=(res);
  }

  Poly &operator/=(const Poly &rhs) {
    int n = deg(), m = rhs.deg(), q = n - m;
    if (m == -1) throw std::runtime_error("Division by zero");
    if (q <= -1) return operator=(Poly{0});
    Poly res(q + 1);
    const T iv = 1 / rhs.leading_coeff();
    for (int i = q; i >= 0; --i)
      if ((res[i] = this->operator[](n--) * iv) != T())
        for (int j = 0; j != m; ++j) this->operator[](i + j) -= res[i] * rhs[j];
    return operator=(res);
  }

  Poly &operator%=(const Poly &rhs) {
    int n = deg(), m = rhs.deg(), q = n - m;
    if (m == -1) throw std::runtime_error("Division by zero");
    const T iv = 1 / rhs.leading_coeff();
    for (int i = q; i >= 0; --i)
      if (T res = this->operator[](n--) * iv; res != T())
        for (int j = 0; j <= m; ++j) this->operator[](i + j) -= res * rhs[j];
    shrink();
    return *this;
  }

  std::pair<Poly, Poly> div_mod(const Poly &rhs) const {
    int n = deg(), m = rhs.deg(), q = n - m;
    if (m == -1) throw std::runtime_error("Division by zero");
    if (q <= -1) return std::make_pair(Poly{0}, Poly(*this));
    const T iv = 1 / rhs.leading_coeff();
    Poly quo(q + 1), rem(*this);
    for (int i = q; i >= 0; --i)
      if ((quo[i] = rem[n--] * iv) != T())
        for (int j = 0; j <= m; ++j) rem[i + j] -= quo[i] * rhs[j];
    rem.shrink();
    return std::make_pair(quo, rem);  // (quotient, remainder)
  }

  T eval(const T &pt) const {
    T res;
    for (int i = deg(); i >= 0; --i) res = res * pt + this->operator[](i);
    return res;
  }

  friend Poly operator+(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) += rhs;
  }

  friend Poly operator-(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) -= rhs;
  }

  friend Poly operator*(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) *= rhs;
  }

  friend Poly operator/(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) /= rhs;
  }

  friend Poly operator%(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) %= rhs;
  }

  friend bool operator==(const Poly &lhs, const Poly &rhs) {
    int d = lhs.deg();
    if (d != rhs.deg()) return false;
    for (; d >= 0; --d)
      if (lhs[d] != rhs[d]) return false;
    return true;
  }

  friend bool operator!=(const Poly &lhs, const Poly &rhs) {
    return !(lhs == rhs);
  }

  friend std::ostream &operator<<(std::ostream &lhs, const Poly &rhs) {
    int s = 0, e = static_cast<int>(rhs.size());
    lhs << '[';
    for (auto &&i : rhs) {
      lhs << i;
      if (s >= 1) lhs << 'x';
      if (s > 1) lhs << '^' << s;
      if (++s != e) lhs << " + ";
    }
    return lhs << ']';
  }
};

template <typename T>
Poly<T> lagrange_interpolation(const std::vector<T> &x,
                               const std::vector<T> &y) {
  if (x.size() != y.size()) throw std::runtime_error("x.size() != y.size()");
  const int n = static_cast<int>(x.size());
  Poly<T> M = {T(1)}, f;
  for (int i = 0; i != n; ++i) M *= Poly<T>{-x[i], T(1)};
  for (int i = 0; i != n; ++i) {
    auto m = M / Poly<T>{-x[i], T(1)};
    f += Poly<T>{y[i] / m.eval(x[i])} * m;
  }
  return f;
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  using Z = Fp<998244353>;
  int n;
  Z k;
  std::cin >> n >> k;
  std::vector<Z> x(n), y(n);
  for (int i = 0; i != n; ++i) std::cin >> x[i] >> y[i];
  std::cout << lagrange_interpolation(x, y).eval(k) << std::endl;
  return 0;
}

本题中只用求出 \(f(k)\) 的值,所以在计算上式的过程中直接将 \(k\) 代入即可。

\[ f(k)=\sum_{i=1}^{n}y_i\prod_{j\neq i }\frac{k-x_j}{x_i-x_j} \]

本题中,还需要求解逆元。如果先分别计算出分子和分母,再将分子乘进分母的逆元,累加进最后的答案,时间复杂度的瓶颈就不会在求逆元上,时间复杂度为 \(O(n^2)\)

横坐标是连续整数的拉格朗日插值

如果已知点的横坐标是连续整数,我们可以做到 \(O(n)\) 插值。

设要求 \(n\) 次多项式为 \(f(x)\),我们已知 \(f(1),\cdots,f(n+1)\)\(1\le i\le n+1\)),考虑代入上面的插值公式:

\[ \begin{aligned} f(x)&=\sum\limits_{i=1}^{n+1}y_i\prod\limits_{j\ne i}\frac{x-x_j}{x_i-x_j}\\ &=\sum\limits_{i=1}^{n+1}y_i\prod\limits_{j\ne i}\frac{x-j}{i-j} \end{aligned} \]

后面的累乘可以分子分母分别考虑,不难得到分子为:

\[ \dfrac{\prod\limits_{j=1}^{n+1}(x-j)}{x-i} \]

分母的 \(i-j\) 累乘可以拆成两段阶乘来算:

\[ (-1)^{n+1-i}\cdot(i-1)!\cdot(n+1-i)! \]

于是横坐标为 \(1,\cdots,n+1\) 的插值公式:

\[ f(x)=\sum\limits_{i=1}^{n+1}y_i\cdot\frac{\prod\limits_{j=1}^{n+1}(x-j)}{(x-i)\cdot(-1)^{n+1-i}\cdot(i-1)!\cdot(n+1-i)!} \]

预处理 \((x-i)\) 前后缀积、阶乘阶乘逆,然后代入这个式子,复杂度为 \(O(n)\)

例题 CF622F The Sum of the k-th Powers

给出 \(n,k\),求 \(\sum\limits_{i=1}^ni^k\)\(10^9+7\) 取模的值。

本题中,答案是一个 \(k+1\) 次多项式,因此我们可以线性筛出 \(1^i,\cdots,(k+2)^i\) 的值然后进行 \(O(n)\) 插值。

也可以通过组合数学相关知识由差分法的公式推得下式:

\[ f(x)=\sum_{i=1}^{n+1}\binom{x-1}{i-1}\sum_{j=1}^{i}(-1)^{i+j}\binom{i-1}{j-1}y_{j}=\sum\limits_{i=1}^{n+1}y_i\cdot\frac{\prod\limits_{j=1}^{n+1}(x-j)}{(x-i)\cdot(-1)^{n+1-i}\cdot(i-1)!\cdot(n+1-i)!} \]
代码实现
// By: Luogu@rui_er(122461)
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 5, mod = 1e9 + 7;

int n, k, tab[N], p[N], pcnt, f[N], pre[N], suf[N], fac[N], inv[N], ans;

int qpow(int x, int y) {
  int ans = 1;
  for (; y; y >>= 1, x = 1LL * x * x % mod)
    if (y & 1) ans = 1LL * ans * x % mod;
  return ans;
}

void sieve(int lim) {
  f[1] = 1;
  for (int i = 2; i <= lim; i++) {
    if (!tab[i]) {
      p[++pcnt] = i;
      f[i] = qpow(i, k);
    }
    for (int j = 1; j <= pcnt && 1LL * i * p[j] <= lim; j++) {
      tab[i * p[j]] = 1;
      f[i * p[j]] = 1LL * f[i] * f[p[j]] % mod;
      if (!(i % p[j])) break;
    }
  }
  for (int i = 2; i <= lim; i++) f[i] = (f[i - 1] + f[i]) % mod;
}

int main() {
  scanf("%d%d", &n, &k);
  sieve(k + 2);
  if (n <= k + 2) return printf("%d\n", f[n]) & 0;
  pre[0] = suf[k + 3] = 1;
  for (int i = 1; i <= k + 2; i++) pre[i] = 1LL * pre[i - 1] * (n - i) % mod;
  for (int i = k + 2; i >= 1; i--) suf[i] = 1LL * suf[i + 1] * (n - i) % mod;
  fac[0] = inv[0] = fac[1] = inv[1] = 1;
  for (int i = 2; i <= k + 2; i++) {
    fac[i] = 1LL * fac[i - 1] * i % mod;
    inv[i] = 1LL * (mod - mod / i) * inv[mod % i] % mod;
  }
  for (int i = 2; i <= k + 2; i++) inv[i] = 1LL * inv[i - 1] * inv[i] % mod;
  for (int i = 1; i <= k + 2; i++) {
    int P = 1LL * pre[i - 1] * suf[i + 1] % mod;
    int Q = 1LL * inv[i - 1] * inv[k + 2 - i] % mod;
    int mul = ((k + 2 - i) & 1) ? -1 : 1;
    ans = (ans + 1LL * (Q * mul + mod) % mod * P % mod * f[i] % mod) % mod;
  }
  printf("%d\n", ans);
  return 0;
}