您的位置:首页 > 其它

Kundu and Tree

2015-05-29 15:55 323 查看
Problem Statement

Russian

Chinese

Kundu是树的一个真正爱好者。树是一个包含N个点和N-1条边的连通图。今天,当他得到一棵树的时候,他给每条边涂上红色(‘r’)和黑色(’b’)之中的一种颜色。他有兴趣知道有多少个节点的三元组(a,b,c)满足在节点a到节点b、节点b到节点c和节点c到节点a的路径上,每条路径都至少有一条边是红色的。 请注意(a,b,c), (b,a,c)以及所有其他排列被认为是相同的三元组。 如果结果不小于 109 +
7, 输出结果对109 +
7取余的结果(%)。

输入格式

第一行包含一个整数N,测试数据的组数。 接下来的N-1行是边和颜色的表示,一组整数表示边,后面跟有边的颜色。颜色用一个小写英文字母来表示,是红色(‘r’)或者黑色(‘b’)。整数与整数、颜色字母之间由一个空格分隔。

输出格式

输出一个整数,表示三元组的数目。

*原始条件

1 ≤ N ≤
105

节点编号从1到 N。

输入样例
[code]5
1 2 b
2 3 r
3 4 r
4 5 b


*输出样例
[code]4


解释

给定的树如此



(2,3,4) 是一个满足条件的三元组,因为从2到3,3到4和2到4的每条路径都至少有一条红色的边。

(2,3,5), (1,3,4)和(1,3,5)是满足条件的其他三元组。

并查集+组合数学

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <queue>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll mod = 1000000000 + 7;
const int maxn = 100010;
int parent[maxn];

void init_set() {
    memset(parent, -1, sizeof(parent));
}

int find_set(int x) {
    return parent[x] < 0 ?  x : parent[x] = find_set(parent[x]);
}

void union_set(int x, int y) {
    int r1 = find_set(x);
    int r2 = find_set(y);
    if(r1 == r2) return ;
    if(parent[r1] < parent[r2]) {
        parent[r1] += parent[r2];
        parent[r2] = r1;
    } else {
        parent[r2] += parent[r1];
        parent[r1] = r2;
    }
    return ;
}

ll C(ll a, ll b) {
    ll res = 1LL;
    for(ll i = 0; i < b; ++i) {
        res *= (a - i);
        res /= (i + 1);
    }
    return res % mod;
}

int main() {

    //freopen("aa.in", "r", stdin);

    int n;
    int u, v;
    ll ans = 0;
    char col;
    vector<int> vec;
    init_set();
    scanf("%d", &n);
    for(int i = 1; i < n; ++i) {
        scanf("%d %d %c", &u, &v, &col);
        if(col == 'b') {
            union_set(u, v);
        }
    }
    for(int i = 1; i <= n; ++i) {
        if(parent[i] < 0) {
            vec.push_back(-parent[i]);
        }
    }
    int len = vec.size();
    for(int i = 0; i < len; ++i) {
        if(vec[i] > 1) {
            ans = (ans + C(vec[i], 2) * (n - vec[i])) % mod;
        }
        if(vec[i] > 2) {
            ans = (ans + C(vec[i], 3)) % mod;
        }
    }
    ans = ((C(n, 3) - ans) % mod + mod) % mod;
    printf("%lld\n", ans);
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: