#Text classification with Transformer# code submission #937
base: develop
Conversation
| "source": [ | ||
| "import paddle\n", | ||
| "import paddle.nn as nn\n", | ||
| "import paddle.fluid.dygraph as dg\n", |
Paddle 2.0 discourages using fluid; dynamic graph mode is the default development mode.
| "pad_id = word_dict['<pad>']\r\n", | ||
| "embed_dim = 32 # Embedding size for each token\r\n", | ||
| "num_heads = 2 # Number of attention heads\r\n", | ||
| "ff_dim = 32 # Hidden layer size in feed forward network inside transformer\r\n", |
The variable name ff_dim is not very clear.
| " x = self.drop2(x)\r\n", | ||
| " x = self.soft(x)\r\n", | ||
| " return x\r\n", | ||
| "# class MyNet(paddle.nn.Layer):\r\n", |
The commented-out code here can be deleted.
| }, | ||
| "source": [ | ||
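| "After two epochs of training, the model reaches about 85% accuracy. You can improve this further by tuning the parameters, changing the optimization method, and so on." | ||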
| ] |
You could use model.predict to run prediction and print each sentence with its predicted label and actual label, which is easier to read.
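A minimal sketch of the kind of side-by-side display suggested here, in plain Python. It assumes the class probabilities have already been obtained (for example from model.predict) and converted to nested lists. The helper name format_predictions, the id_to_word mapping, and the label_names order are hypothetical and are not taken from the PR; the label order in particular must be checked against the dataset actually used.

```python
def format_predictions(sentences, probs, labels, id_to_word, pad_id, label_names):
    """Pair each decoded sentence with its predicted and actual label.

    sentences:   list of token-id lists (padded with pad_id)
    probs:       list of per-class probability lists, one per sentence
    labels:      list of integer ground-truth labels
    id_to_word:  dict mapping token id -> word (hypothetical reverse vocabulary)
    label_names: sequence mapping class index -> readable name (assumed order)
    """
    rows = []
    for token_ids, prob, label in zip(sentences, probs, labels):
        # Drop padding tokens before turning ids back into words.
        words = [id_to_word[i] for i in token_ids if i != pad_id]
        # The predicted class is the index with the highest probability.
        pred = max(range(len(prob)), key=prob.__getitem__)
        rows.append((" ".join(words), label_names[pred], label_names[label]))
    return rows


if __name__ == "__main__":
    # Toy data for illustration only; not results from the notebook.
    vocab = {0: "<pad>", 1: "great", 2: "movie", 3: "boring"}
    rows = format_predictions(
        sentences=[[1, 2, 0], [3, 2, 0]],
        probs=[[0.1, 0.9], [0.8, 0.2]],
        labels=[1, 1],
        id_to_word=vocab,
        pad_id=0,
        label_names=("negative", "positive"),
    )
    for sentence, predicted, actual in rows:
        print(f"sentence: {sentence} | predicted: {predicted} | actual: {actual}")
```

Keeping the formatting separate from the model call means the same helper works whether the probabilities come from model.predict or from a manual forward pass.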
Made the requested changes and synced the update to AIStudio.
chenxiaozeng
left a comment
2 suggestions.
| "class PointWiseFeedForwardNetwork(nn.Layer):\r\n", | ||
| " def __init__(self, embed_dim, feed_dim):\r\n", | ||
| " super(PointWiseFeedForwardNetwork, self).__init__()\r\n", | ||
| " self.linear1 = pd.fluid.dygraph.Linear(embed_dim, feed_dim, act='relu')\r\n", |
Several uses of fluid need to be changed to nn.
| " loss=nn.CrossEntropyLoss())\r\n", | ||
| "\r\n", | ||
| "# Model training\r\n", | ||
| "model.fit(train_loader,\r\n", |
After training completes, you can call model.predict() to check how the model performs on the test dataset.
| @@ -0,0 +1 @@ | |||
Should this file be deleted?
Made the requested changes and synced the update to AIStudio.
chenxiaozeng
left a comment
LGTM
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "class TransformerBlock(nn.Layer):\r\n", |
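Paddle already provides Transformer APIs: https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/nn/layer/transformer/TransformerEncoder_cn.html#transformerencoder . If the goal is simply to use a Transformer rather than to illustrate how it is implemented, could you use these APIs directly?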
TCChenlong
left a comment
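Besides the issues above, two more points need attention:
1. Version 2.0 has been released; please update to 2.0.
2. The prediction results do not look very good; the network could be optimized further.
Thanks~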
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import paddle as pd\n", |
import paddle
| "source": [ | ||
| "import paddle as pd\n", | ||
| "import paddle.nn as nn\n", | ||
| "import paddle.nn.functional as func\n", |
Writing it this way is not recommended for now.
| "train_dataset = IMDBDataset(train_sents, train_labels)\r\n", | ||
| "test_dataset = IMDBDataset(test_sents, test_labels)\r\n", | ||
| "\r\n", | ||
| "train_loader = pd.io.DataLoader(train_dataset, places=pd.CPUPlace(), return_list=True,\r\n", |
places=pd.CPUPlace() can be removed.
Project URL: https://aistudio.baidu.com/aistudio/projectdetail/1247954