您的位置:首页 > 其它

Regionals 2015 :: Asia - Daejeon A题

2016-04-10 12:23 507 查看
KM算法

点击打开链接 提交地址

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <queue>
#include <string>
#include <string.h>
using namespace std;
const int inf = 1000000000 ;

const int maxn = 508 ;
bool sx[maxn], sy[maxn] ;
int  match[maxn], w[maxn][maxn] ;
int  n , m , lx[maxn] , ly[maxn] ;
//n:左集元素个数; m:右集元素个数
void init (){
memset(w , 0 , sizeof(w)) ;    //不一定要,求最小值一般要初始化为负无穷!
}

int  dfs(int u){
sx[u] = 1 ;
for(int v = 0 ; v < m ; v++){
if(!sy[v] && lx[u] + ly[v] == w[u][v]){
sy[v] = 1 ;
if(match[v] == -1 || dfs(match[v])){
match[v] = u;
return 1 ;
}
}
}
return 0  ;
}

int KM(){
int i , j , k ;
memset(ly , 0 , sizeof(ly));
for(i = 0 ; i < n ; i++){
lx[i] = -inf ;
for(j = 0; j < m; j++) lx[i] = max(lx[i] , w[i][j]) ;
}
memset(match , -1 , sizeof(match)) ;
for(i = 0 ; i < n ; i++){
while(1){
memset(sx , 0 , sizeof(sx)) ;
memset(sy , 0 , sizeof(sy)) ;
if(dfs(i)) break ;
int d = inf;
for(j = 0 ; j < n ; j++){
if(sx[j]){
for(k = 0; k < m; k++)
if(!sy[k]) d = min(d , lx[j]+ly[k]-w[j][k]) ;
}
}
if(d == inf)    //找不到完美匹配
return -1 ;
for(j = 0 ; j < n ; j++){
if(sx[j])   lx[j] -= d ;
}
for(j = 0 ; j < m ; j++){
if(sy[j])  ly[j] += d ;
}
}
}
int sum = 0 , cnt = 0 ;
for(i = 0 ; i < m ; i++){
if(match[i] > -1 && w[match[i]][i] != -inf){
sum += w[match[i]][i];
cnt++ ;
}
}
if(cnt != n)  return -1 ;
return sum;
}

struct E{
int v ,  next ;
}e[500*500*2 + 10] ;
int g[508] ;
int id ;
void add(int u , int v){
e[id].v = v ;
e[id].next = g[u] ;
g[u] = id++ ;
}

int dist[508] ;
bool in[508] ;
void spfa(int start , int N){
queue<int> q ;
memset(in , 0 , sizeof(in)) ;
q.push(start) ;
in[start] = 1 ;
for(int i = 0 ; i <= N+1 ; i++) dist[i] =  10000000 ;
dist[start] = 0 ;
while(! q.empty()){
int u = q.front() ;
q.pop() ;
in[u] = 0 ;
for(int i = g[u] ; i != -1 ; i = e[i].next){
int v = e[i].v ;
if(dist[u] + 1 < dist[v]){
dist[v] = dist[u] + 1 ;
if(! in[v]){
in[v] = 1 ;
q.push(v) ;
}
}
}
}
}

int one[508] , two[508] ;

int  main(){
int u , v   ;
int t , N , M ;
cin>>t ;
while(t--){
cin>>N>>M  ;

id = 0 ;
memset(g , -1 , sizeof(g)) ;

for(int i = 0 ; i < M ; i++){
scanf("%d%d" , &u , &v) ;
add(u , v) ;
add(v , u) ;
}
n = 0 ;
for(int i = 1 ; i <= N ; i++){
scanf("%d" , &u) ;
if(u == 0){
one[n++] = i ;
}
}
n = 0 ;
for(int i = 1 ; i <= N ; i++){
scanf("%d" , &u) ;
if(u == 0){
two[n++] = i ;
}
}

m = n ;
for(int i = 0 ; i < n ; i++){
spfa(one[i] , N) ;
for(int j = 0 ; j < m ; j++)
w[i][j] = -dist[two[j]] ;
}
int ans = -KM() ;
printf("%d\n" , ans) ;

}
return  0 ;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: