团队里的深度学习工作逐渐开始落地,很多工作做得并不细致,做个阶段性的总结,后续腾出精力,再对其做更深入的挖掘。
Embedding向量的训练数据慢的问题
tf的Variable更新需要在cpu和gpu中copy数据,大部分的情况,这个都没有问题。然后我们场景的id Embedding的量级对比图像或者NLP大很多,因此需要copy的数据较大。看了下tracing的统计,平均要花100ms,导致我们的gpu利用率很低

Stack Overflow上有很多建议,我们尝试改成了constant,没有了这个copy,速度快了很多

但这样只能预训练Embedding向量了,TF上直接训练几百万的向量,gpu显存压力很大,而google以前开源的word2vec,做起来效率很高,遗憾的是只能基于word2vec的结构做了。
github上也有这个issue,还有一些其他方案https://github.com/tensorflow/tensorflow/issues/4495。
最终的方案可能要基于google着篇论文的思路来解决这个问题:https://arxiv.org/pdf/1602.02215.pdf。如果解决了由于embedding量级太大而导致oom的问题,则可以在更复杂的模型结构中让embedding向量和模型一起训练,目前我们还没有精力去尝试。该论文有tf的代码支持:https://github.com/tensorflow/models/tree/master/research/swivel
https://github.com/src-d/tensorflow-swivel
奇怪的Embedding_lookup性能问题
我们将训练好的模型部署在线上,由于线上都是java系统,采用tf的java版本。发现压测过程中,最耗时的是id的Embedding过程,lookup耗时很长。
我们分别对比了将id传入模型做embedding_lookup而后做inference,以及直接将Embedding向量在内存里存好,取出来去做inference,性能差距在几倍。
一度怀疑是压测有问题,不理解lookup为何会这么耗时,如果是hash表的查询,是很快的,对比在外部存储embedding向量,也是要查询的,不应该会慢。但几次测试的情况都如此,且用后一种方式,线上RT问题环节了很多。
由于大家的环境差距大,不好直接对比,但如果有明显不一致的结论,欢迎大家指正,共同讨论。
attention带来巨大的RT损失
Attention在模型结构上非常make sensor,但对线上inference的RT影响很大,在推荐排序领域,ROI较低。
TF graph不能大于2G
Protobuf的限制,参考:https://github.com/tensorflow/tensorflow/issues/6117。
对大的graph不友好,而电商里面最重要的商品id,量级较大,在embedding的后会产生较大的graph,要么减少id的个数,要么把embedding用其他模块预训练了。