1. 程式人生 > >題解【luogu2045 方格取數遊戲加強版】

題解【luogu2045 方格取數遊戲加強版】

Description

給出一個 \(n*n\) 的矩陣,每一格有一個非負整數 \(A_{i,j}\) ,(\(A_{i,j} <= 1000\))現在從 \((1,1)\) 出發,可以往右或者往下走,最後到達 \((n,n)\) ,每達到一格,把該格子的數取出來,該格子的數就變成 \(0\) ,這樣一共走 \(K\) 次,現在要求 \(K\) 次所達到的方格的數的和最大

Solution

一條邊 \((a,b)\) 表示容量為 \(a\) ,費用為 \(b\)
把每個點拆成兩個點,入點和出點。入點用來接受邊,出點用來發出邊
源點向 \((1,1)\) 連一條邊 \((k,0)\)\((n,n)\)

向匯點連一條 \((k,0)\) ,表示可以走 \(k\)
每個點往他的右和下分別連一條 \((\infty, 0)\) 表示聯通關係
每個點的入點與出點之間連兩條邊 \((1,x)\)\((\infty, 0)\)\(x\) 是該點的權值。
這是因為每個點只能取一次。
然後跑一遍最大費用最大流就完事啦
小技巧:把費用取負然後跑最小費用最大流

Code

#include <bits/stdc++.h>
using namespace std;
const int INF = 1000000000;
const int N = 550; 
int n, m, cnt, vis[N * N * 3], dis[N * N * 3]; 
int S, T, k, pre[N * N * 3], f[N * N * 3];  
struct node {
  int d, sid, tid; 
}a[N][N];
struct edge {
  int v, w, f; edge *next, *rev;
}pool[N * N * 2], *head[N * N * 3], *r[N * N * 3]; 
inline void addedge(int u, int v, int f, int w) {
  edge *p = &pool[++cnt], *q = &pool[++cnt]; 
  p->v = v, p->f = f, p->w = w,  p->next = head[u], head[u] = p; p->rev = q; 
  q->v = u, q->f = 0, q->w = -w, q->next = head[v], head[v] = q; q->rev = p;
}
inline bool spfa() {
  for(int i = S; i <= T; i++) r[i] = NULL, dis[i] = INF, vis[i] = 0, pre[i] = -1; 
  queue <int> Q; Q.push(S); vis[S] = 1; dis[S] = 0; f[S] = INF; 
  while(!Q.empty()) {
    int u = Q.front(); Q.pop(); vis[u] = 0; 
    for(edge *p = head[u]; p; p = p->next) {
      int v = p->v; 
      if(p->f > 0 && dis[v] > dis[u] + p->w) {
        dis[v] = dis[u] + p->w;
        pre[v] = u, r[v] = p; 
        f[v] = min(f[u], p->f); 
        if(!vis[v]) vis[v] = 1, Q.push(v); 
      }
    }
  } 
  return pre[T] != -1; 
}
inline int MCMF() {
  int ans = 0;  
  while(spfa()) {
    for(int i = T; i != S; i = pre[i]) {
      r[i]->f -= f[T]; r[i]->rev->f += f[T];  
    } ans += dis[T] * f[T];
  } return ans; 
}
int main() { 
  scanf("%d %d", &n, &k); S = 0, T = 2 * n * n + 1; 
  addedge(S, 1, k, 0); 
  for(int i = 1; i <= n; i++) {
    for(int j = 1; j <= n; j++) {
      int x; scanf("%d", &x); 
      int id = (i - 1) * n + j; 
      a[i][j].sid = 2 * id - 1;
      a[i][j].tid = 2 * id; 
      addedge(a[i][j].sid, a[i][j].tid, 1, -x); 
      addedge(a[i][j].sid, a[i][j].tid, INF, 0); 
    }
  } 
  for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++) {
      if(i < n) addedge(a[i][j].tid, a[i + 1][j].sid, INF, 0);
      if(j < n) addedge(a[i][j].tid, a[i][j + 1].sid, INF, 0); 
    }
  addedge(a[n][n].tid, T, k, 0); 
  printf("%d\n", -MCMF());
  return 0; 
}