enable mkldnn_concat layer#5705
Conversation
| outFmt = format::nc; | ||
| } else { | ||
| outFmt = format::nchw; | ||
| } |
There was a problem hiding this comment.
我看has16c, has8c和hasnc只用了一次,所以107-130行能简化成:
if (inputs[i]->getFormat() == format::nc && oc_ % 16 == 0) {
outFmt = format::nChw16c;
}
...
There was a problem hiding this comment.
不是的,在108行这块地方,是初始化这些bool,代表input的所有layer中有一个就设为true了。同时,要先等初始化完,然后在根据先后顺序判断需要的格式,所以不好放在一起做掉。
| channels_.resize(inputLayers_.size()); | ||
| channels_[0] = ic; | ||
| // need change the output channel, so use oc_ instead | ||
| // TODO(TJ): change API, use &oc |
There was a problem hiding this comment.
这段注释没懂,change API,什么API呢? change the output channel成什么样呢?
There was a problem hiding this comment.
这里指的是 reshape这个函数,因为只有oc没有用应用,但是定义的时候考虑到oc不希望被改的,但是这个layer是可以改oc的,但是改起来会牵扯到别的很多layer,所以后面马上可以统一改为引用。
There was a problem hiding this comment.
可以。本来是想这个时候改的,但是因为会牵扯到很多layer,不适合在concat这个PR里面提,所以决定分开提。
| bool has8c = false, has16c = false, hasnc = false; | ||
| for (size_t i = 0; i < inputs.size(); i++) { | ||
| // resetInValue will use ic_ so temporary change as current input's channel | ||
| // TODO(TJ): change ic_ as vector then can remove channels_ |
There was a problem hiding this comment.
so temporary change as? 缺宾语,change 什么 as current input's channel?
There was a problem hiding this comment.
resetInValue will use ic_ so temporary change as ic_ current input's channel
同上,这里也是有一个TODO的。这里的注释也是要去掉的,所以没注意太多。
准备马上就要去掉了,如果需要也可以补好。
| VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName() | ||
| << ": " << inGrads_[i]->getFormat() << "<<<"; | ||
| } | ||
| } |
There was a problem hiding this comment.
这几个print函数和基类的区别在什么地方呢?我看就多打印了<< i << ", " << inputLayers_[i]->getName()。是不是增强下基类函数的功能,就不用重写了呢
除了MKLDNNAddtoLayer有单独实现这部分功能,其他layer都没有。
There was a problem hiding this comment.
是的,只有Addto和concat需要多个input输入,所以重写了,但是同上面的TODO,这些都是想加到MKLDNNLayer里面的,所以先override了。
There was a problem hiding this comment.
TODO的话,能不能在issue或者project的card里记一下。已经有挺多TODO了。
| std::vector<std::shared_ptr<mkldnn::primitive>>& prims, | ||
| std::vector<MKLDNNMatrixPtr>& inputs, | ||
| MKLDNNMatrixPtr& out); | ||
| }; |
There was a problem hiding this comment.
.h文件这里的一系列reset*函数,是不是可以都去掉。如果是继承了父类,只要在.cpp里直接实现即可。其他layer同。
There was a problem hiding this comment.
这里是不可以的,子类还是在头文件里面声明要override掉父类的某个函数。
fix #5728