1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- #include " paddle/fluid/lite/kernels/ arm/math/packed_sgemm.h"
15+ #include " paddle/fluid/lite/arm/math/packed_sgemm.h"
1616#include < arm_neon.h>
1717
1818namespace paddle {
1919namespace lite {
20- namespace kernels {
2120namespace arm {
2221namespace math {
2322
@@ -68,7 +67,7 @@ void prepackA(float *out, const float *in, const int ldin, const int m0,
6867 prepackA_8x12 (out, in, ldin, m0, mmax, k0, kmax);
6968 }
7069#else
71- if (ctx->get_arch () == kA73 ) {
70+ if (ctx->arch () == kA73 ) {
7271 if (is_trans) {
7372 prepackA_trans_4x8 (out, in, ldin, m0, mmax, k0, kmax);
7473 } else {
@@ -86,7 +85,7 @@ void prepackA(float *out, const float *in, const int ldin, const int m0,
8685
8786void prepackA (TensorLite *tout, const TensorLite &tin, int m, int k, int group,
8887 bool is_trans, ARMContext *ctx) {
89- int hblock = get_hblock (ctx->get_arch ());
88+ int hblock = get_hblock (ctx->arch ());
9089 int m_roundup = hblock * ((m + hblock - 1 ) / hblock);
9190 int group_size_round_up = ((m_roundup * k + 15 ) / 16 ) * 16 ;
9291 if (tout->numel () < group_size_round_up * group) {
@@ -112,7 +111,7 @@ void sgemm_prepack(const float *A_packed, const float *B, const float *bias,
112111 sgemm_conv_8x12 (A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB,
113112 ctx);
114113#else // armv7
115- if (ctx->get_arch () == kA73 ) {
114+ if (ctx->arch () == kA73 ) {
116115 sgemm_conv_4x8 (A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB,
117116 ctx);
118117 } else {
@@ -1521,8 +1520,8 @@ void sgemm_conv_8x12(const float *A_packed, const float *B, const float *bias,
15211520 bool transB, ARMContext *ctx) {
15221521 size_t l2_cache =
15231522 ctx->l2_cache_size () > 0 ? ctx->l2_cache_size () : 512 * 1024 ;
1524- float *workspace = ctx->get_workspace_data <float >();
1525- int threads = ctx->get_threads ();
1523+ float *workspace = ctx->workspace_data <float >();
1524+ int threads = ctx->threads ();
15261525 // ! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
15271526 int x_block = (l2_cache - (MBLOCK * K)) / (sizeof (float ) * (K + MBLOCK));
15281527 x_block /= NBLOCK;
@@ -2359,8 +2358,8 @@ void sgemm_conv_6x8(const float* A_packed, const float* B, const float* bias,
23592358 bool transB, ARMContext* ctx) {
23602359 size_t l2_cache =
23612360 ctx->l2_cache_size () > 0 ? ctx->l2_cache_size () : 512 * 1024 ;
2362- auto * workspace = ctx->get_workspace_data <float >();
2363- int threads = ctx->get_threads ();
2361+ auto * workspace = ctx->workspace_data <float >();
2362+ int threads = ctx->threads ();
23642363 // ! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
23652364 int x_block =
23662365 (l2_cache - (MBLOCK_OTH * K)) / (sizeof (float ) * (K + MBLOCK_OTH));
@@ -2753,7 +2752,7 @@ void sgemm_conv_4x8(const float* A_packed, const float* B, const float* bias,
27532752 size_t l2_cache =
27542753 ctx->l2_cache_size () > 0 ? ctx->l2_cache_size () : 512 * 1024 ;
27552754 void * workspace = ctx->get_work_space ();
2756- int threads = ctx->get_threads ();
2755+ int threads = ctx->threads ();
27572756 // ! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
27582757 int x_block =
27592758 (l2_cache - (MBLOCK_A73 * K)) / (sizeof (float ) * (K + MBLOCK_A73));
@@ -3046,6 +3045,5 @@ void sgemm_conv_4x8(const float* A_packed, const float* B, const float* bias,
30463045
30473046} // namespace math
30483047} // namespace arm
3049- } // namespace kernels
30503048} // namespace lite
30513049} // namespace paddle
0 commit comments