这个题目还是很有意思,它没有给出解决问题所需的全部信息,而是只给了部分信息,来猜测正确答案是什么。在一定的概率下会猜测到正确的结果。

1 问题描述

从2到M中随机选择N个数,允许重复。
再从这N个数中随机选择它的一个子集,计算出这个子集中元素的乘积$p_1$。
重复上一步,直到得到了K个乘积$p_1,p_2,…,p_k$。
要求:给出这K个乘积,推测原来的N个数
具体的输入输出要求,不详细介绍了,到codejam网站上看去吧

2 解决方案

从2到M中选择N个数,若考虑顺序那么就有$(M-1)^N$种,若不考虑顺序就大大减少了。

对于题目的第二个输入N=12, M=8, 不考虑顺序就只有18564种,不算太大。
那么,我们的任务就是根据这K个乘积,选出概率最大的一种组合
现在我们计算在乘积为$p_1,p_2,…,p_k$条件下,原组合为C的概率: $$ P(C|p_1,p_2,...,p_k) = \frac{P(C)*P(p_1,p_2,...,p_k|C)}{P(p_1,p_2,...,p_k)} $$ 这个就是贝叶斯概率公式,一般形式写作:$ P(A|B)=\frac{P(A)P(B|A)}{P(B)} $ 机器学习中有一类重要的分类方法就是基于这个贝叶斯公式。

由于我们的目标是找出概率最大的C,所以不必把这个公式中的所有的项都计算出来,例如分母为P(p1,p2,…,pk)对于特定的K个乘积在所有的组合下都是不变的。
只需要计算$P(C)$和\(P(p_1,p_2,\ldots,p_k|C)\), 对于由于组合C,由于不考虑顺序那么,每个C被取得概率就不等了,比如234会出现6次,而333只会出现一次,计算这个P(C): $$ P(C)= \frac{N!/(c_2!c_3!\ldots c_m!)}{(M-1)^N}$$ 其中$c_2,c_3,\ldots,c_i$表示组合中数字i出现的次数

$$ P(p_1,p_2,\ldots,p_k|C) = P(p_1|C)P(p_2|C)\ldots P(p_k|C)$$

其中$P(p|C) = \frac{元素积为p的子集个数}{2^N}$
枚举C的所有子集就可以计算出上述$P(p|C)$。
现在就知道怎么实现了:对于所有的组合计算条件概率,从中选出概率最大的一个组合。 由于会有很多组乘积,需要进行推测,所以这里将会有大量的重复计算,所以,最好将P和P(p|C)全部提前计算出来。

3 实现和结果分析

 

3.1 预计算

预计算:计算出某个组合的概率,和这种组合下子集积的概率

# coding: UTF-8
import cPickle as pickle
from array import array
from sys import stdout
from sys import stderr
from random import randint

def fact(n):
    if n == 0:
        return 1
    return n * fact(n-1)

def probability(nums):
    p = fact(len(nums))
    i = 1
    c = 1
    while i < len(nums):
        if (nums[i] == nums[i-1]):
            c += 1
        else:
            p /= fact(c)
            c = 1
        i += 1
    p /= fact(c)
    return p

gcount = 0
def pre_compute(N, M):
    fo = open('dump.dat', 'w')
    nums = array('i', xrange(N))
    pset = {}

    def subset(d, p):
        if d >= N:
            if p in pset:
                pset[p] += 1
            else:
                pset[p] = 1
        else:
            subset(d+1, p)
            subset(d+1, p*nums[d])

    def search(d, lt):
        global gcount
        if d >= N:
            gcount += 1
            if (gcount % 100) == 0:
                print "generated %d.." % gcount
            pickle.dump(nums, fo)
            pickle.dump(probability(nums), fo)
            pset.clear()
            subset(0, 1)
            pickle.dump(pset, fo)
        else:
            for i in xrange(lt, M+1):
                nums[d] = i
                search(d+1, i)

    search(0, 2)
    fo.close()

预计算大概需要1m多,而载入只需要2s左右,还是节省了很多时间的

载入的代码:

def load():
    fo = open("dump.dat")
    count = 0
    table = []
    while True:
        try:
            nums=pickle.load(fo)
            prob=pickle.load(fo)
            pset=pickle.load(fo)
            table.append([nums, prob, pset])
            count += 1
            if (count % 1000) == 0:
                stderr.write("loaded %d..\n" % count)
        except EOFError:
            break
    return table

3.2 处理过程,遍历所有的组合找到概率最大的组合

def process_case(case):
    table = load()

    (R, N, M, K) = map(int, raw_input().split())
    print "Case #%d:" % case
    for r in xrange(R):
        products = map(int, raw_input().split())
        max_prob = -1
        nums = []
        for item in table:
            prob = item[1]
            pset = item[2]
            for p in products:
                if p not in pset:
                    prob = 0
                    break
                prob *= pset[p]
            if prob > max_prob:
                max_prob = prob
                nums = item[0]

        for n in nums:
            stdout.write(str(n))
        stdout.write("\n")

3.3 自己构造输入输出

其实这个题是可以自己构造输入的,然后自己写个judge的程序, 对于第二种输入,阈值要求是1120,自己构造输入,然后用上面的程序得到猜测结果, 大概猜对了1300左右。

构造输入和judge代码如下:

# 生成样例数据
def generate(R, N, M, K):
    fi = open('input.txt', 'w')
    fa = open('right-answer.txt', 'w')
    fi.write("1\n%d %d %d %d\n" % (R, N, M, K))
    nums = array('i', xrange(N))
    for i in xrange(R):
        for j in xrange(N):
            nums[j] = randint(2, M)
            fa.write("%d" % nums[j])
        fa.write('\n')
        for j in xrange(K):
            p = 1
            for n in nums:
                if randint(0, 1) == 1:
                    p *= n
            fi.write("%d " % p)
        fi.write('\n')
    fi.close()
    fa.close()

def judge(submit_file, answer_file, R):
    fs = open(submit_file)
    fa = open(answer_file)

    fs.readline()
    count = 0
    for i in xrange(R):
        ls = list(fs.readline().strip())
        la = list(fa.readline().strip())
        ls.sort()
        la.sort()
        if ls == la:
            count += 1
    fs.close()
    fa.close()
    return count

4 总结

谷歌的出题思路跟技术的发展趋势是一致的,基于统计的机器学习方法有很广泛的应用。 这个题就是使用贝叶斯公式(条件概率公式)进行推断。


ps1: 上面的python代码在我的机器上运行3m,而题目要求是四分钟内提交答案,时间有点紧,勉强够用。使用C++会更快的。不过,我后来找到了一个叫pypy的python解释器,它的速度比python官方的解释器快好多,一分钟就算出结果了。
ps2: 吐嘈下python

def f():
    x = 1
    def g():
        x += 1
    g()

调用f()就会出错,直到python3才给出个nonlocal关键字解决这个问题。
关于这点,我只想说,设计语言也太不专业了。。。

Date: 2013-05-10 Fri

Author: liyongmou

Org version 7.9.2 with Emacs version 24

Validate XHTML 1.0

 

posted on 2013-05-10 21:54  yongmou-  阅读(899)  评论(0编辑  收藏  举报