Skip to content

Commit 660760c

Browse files
committed
add colab notebook for xlnet eval on squad 2.0
1 parent eab6401 commit 660760c

File tree

1 file changed

+283
-0
lines changed

1 file changed

+283
-0
lines changed
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
{
2+
"nbformat": 4,
3+
"nbformat_minor": 0,
4+
"metadata": {
5+
"colab": {
6+
"name": "tutorial-predict-pipeline.ipynb",
7+
"version": "0.3.2",
8+
"provenance": []
9+
},
10+
"language_info": {
11+
"codemirror_mode": {
12+
"name": "ipython",
13+
"version": 3
14+
},
15+
"file_extension": ".py",
16+
"mimetype": "text/x-python",
17+
"name": "python",
18+
"nbconvert_exporter": "python",
19+
"pygments_lexer": "ipython3",
20+
"version": "3.7.3"
21+
},
22+
"kernelspec": {
23+
"name": "python3",
24+
"display_name": "Python 3"
25+
},
26+
"accelerator": "GPU"
27+
},
28+
"cells": [
29+
{
30+
"cell_type": "code",
31+
"metadata": {
32+
"id": "zNtCqwveFjcK",
33+
"colab_type": "code",
34+
"outputId": "6a94d325-b50a-4874-a999-59702327dcbe",
35+
"colab": {
36+
"base_uri": "https://localhost:8080/",
37+
"height": 151
38+
}
39+
},
40+
"source": [
41+
"!git clone https://github.com/cdqa-suite/cdQA.git"
42+
],
43+
"execution_count": 1,
44+
"outputs": [
45+
{
46+
"output_type": "stream",
47+
"text": [
48+
"Cloning into 'cdQA'...\n",
49+
"remote: Enumerating objects: 61, done.\u001b[K\n",
50+
"remote: Counting objects: 100% (61/61), done.\u001b[K\n",
51+
"remote: Compressing objects: 100% (49/49), done.\u001b[K\n",
52+
"remote: Total 1138 (delta 28), reused 35 (delta 12), pack-reused 1077\u001b[K\n",
53+
"Receiving objects: 100% (1138/1138), 441.88 KiB | 1.10 MiB/s, done.\n",
54+
"Resolving deltas: 100% (686/686), done.\n"
55+
],
56+
"name": "stdout"
57+
}
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"metadata": {
63+
"id": "v2XvXm4bFp7h",
64+
"colab_type": "code",
65+
"colab": {}
66+
},
67+
"source": [
68+
"import os\n",
69+
"cwd = os.getcwd()\n",
70+
"os.chdir(\"cdQA\")"
71+
],
72+
"execution_count": 0,
73+
"outputs": []
74+
},
75+
{
76+
"cell_type": "code",
77+
"metadata": {
78+
"id": "5jBtSKczGF38",
79+
"colab_type": "code",
80+
"outputId": "d657fe20-985d-4fc8-b435-794e17f77748",
81+
"colab": {
82+
"base_uri": "https://localhost:8080/",
83+
"height": 55
84+
}
85+
},
86+
"source": [
87+
"!git checkout sync-huggingface"
88+
],
89+
"execution_count": 3,
90+
"outputs": [
91+
{
92+
"output_type": "stream",
93+
"text": [
94+
"Branch 'sync-huggingface' set up to track remote branch 'sync-huggingface' from 'origin'.\n",
95+
"Switched to a new branch 'sync-huggingface'\n"
96+
],
97+
"name": "stdout"
98+
}
99+
]
100+
},
101+
{
102+
"cell_type": "code",
103+
"metadata": {
104+
"id": "DHl2HUX1GRd6",
105+
"colab_type": "code",
106+
"outputId": "625ba318-f7e5-4f24-98a5-cb7cbacbf175",
107+
"colab": {
108+
"base_uri": "https://localhost:8080/",
109+
"height": 170
110+
}
111+
},
112+
"source": [
113+
"!pip install -q -e ."
114+
],
115+
"execution_count": 4,
116+
"outputs": [
117+
{
118+
"output_type": "stream",
119+
"text": [
120+
"\u001b[K |████████████████████████████████| 133kB 4.2MB/s \n",
121+
"\u001b[K |████████████████████████████████| 163kB 43.6MB/s \n",
122+
"\u001b[K |████████████████████████████████| 225kB 45.6MB/s \n",
123+
"\u001b[K |████████████████████████████████| 655kB 35.1MB/s \n",
124+
"\u001b[K |████████████████████████████████| 1.0MB 37.8MB/s \n",
125+
"\u001b[?25h Building wheel for tika (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
126+
" Building wheel for wget (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
127+
" Building wheel for regex (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
128+
],
129+
"name": "stdout"
130+
}
131+
]
132+
},
133+
{
134+
"cell_type": "code",
135+
"metadata": {
136+
"id": "_NWD3P6qH_8_",
137+
"colab_type": "code",
138+
"colab": {}
139+
},
140+
"source": [
141+
"import wget\n",
142+
"\n",
143+
"squad_urls = [\n",
144+
" 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',\n",
145+
" 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',\n",
146+
" 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',\n",
147+
" 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json'\n",
148+
"]\n",
149+
"\n",
150+
"for squad_url in squad_urls:\n",
151+
" wget.download(url=squad_url, out='.')"
152+
],
153+
"execution_count": 0,
154+
"outputs": []
155+
},
156+
{
157+
"cell_type": "code",
158+
"metadata": {
159+
"id": "ylorIsqLz_J3",
160+
"colab_type": "code",
161+
"outputId": "e6efed6f-551a-41da-9f18-62ed0417830d",
162+
"colab": {
163+
"base_uri": "https://localhost:8080/",
164+
"height": 649
165+
}
166+
},
167+
"source": [
168+
"!wget https://github.com/cdqa-suite/cdQA/releases/download/XLNet_cased_vCPU/pytorch_model.bin\n",
169+
"!wget https://github.com/cdqa-suite/cdQA/releases/download/XLNet_cased_vCPU/config.json"
170+
],
171+
"execution_count": 6,
172+
"outputs": [
173+
{
174+
"output_type": "stream",
175+
"text": [
176+
"--2019-09-01 16:22:00-- https://github.com/cdqa-suite/cdQA/releases/download/XLNet_cased_vCPU/pytorch_model.bin\n",
177+
"Resolving github.com (github.com)... 192.30.253.113\n",
178+
"Connecting to github.com (github.com)|192.30.253.113|:443... connected.\n",
179+
"HTTP request sent, awaiting response... 302 Found\n",
180+
"Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/165645094/96b5db80-aa35-11e9-8147-fbf9e537f61c?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20190901%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20190901T162200Z&X-Amz-Expires=300&X-Amz-Signature=3137e708a0e6d08e1ae399eb69fcd41e2f44a7464aa35fdb2d0643b8f5e2b628&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3Dpytorch_model.bin&response-content-type=application%2Foctet-stream [following]\n",
181+
"--2019-09-01 16:22:00-- https://github-production-release-asset-2e65be.s3.amazonaws.com/165645094/96b5db80-aa35-11e9-8147-fbf9e537f61c?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20190901%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20190901T162200Z&X-Amz-Expires=300&X-Amz-Signature=3137e708a0e6d08e1ae399eb69fcd41e2f44a7464aa35fdb2d0643b8f5e2b628&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3Dpytorch_model.bin&response-content-type=application%2Foctet-stream\n",
182+
"Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.217.38.20\n",
183+
"Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.217.38.20|:443... connected.\n",
184+
"HTTP request sent, awaiting response... 200 OK\n",
185+
"Length: 476375014 (454M) [application/octet-stream]\n",
186+
"Saving to: ‘pytorch_model.bin’\n",
187+
"\n",
188+
"pytorch_model.bin 100%[===================>] 454.31M 16.5MB/s in 30s \n",
189+
"\n",
190+
"2019-09-01 16:22:31 (15.4 MB/s) - ‘pytorch_model.bin’ saved [476375014/476375014]\n",
191+
"\n",
192+
"--2019-09-01 16:22:33-- https://github.com/cdqa-suite/cdQA/releases/download/XLNet_cased_vCPU/config.json\n",
193+
"Resolving github.com (github.com)... 192.30.253.113\n",
194+
"Connecting to github.com (github.com)|192.30.253.113|:443... connected.\n",
195+
"HTTP request sent, awaiting response... 302 Found\n",
196+
"Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/165645094/96b5db80-aa35-11e9-84be-890f3b56af43?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20190901%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20190901T162234Z&X-Amz-Expires=300&X-Amz-Signature=3344f2dcc2a5f06990fbf79137b80686ca2847d86467219e419baa69d4ed33c7&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3Dconfig.json&response-content-type=application%2Foctet-stream [following]\n",
197+
"--2019-09-01 16:22:34-- https://github-production-release-asset-2e65be.s3.amazonaws.com/165645094/96b5db80-aa35-11e9-84be-890f3b56af43?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20190901%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20190901T162234Z&X-Amz-Expires=300&X-Amz-Signature=3344f2dcc2a5f06990fbf79137b80686ca2847d86467219e419baa69d4ed33c7&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3Dconfig.json&response-content-type=application%2Foctet-stream\n",
198+
"Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.130.115\n",
199+
"Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.130.115|:443... connected.\n",
200+
"HTTP request sent, awaiting response... 200 OK\n",
201+
"Length: 641 [application/octet-stream]\n",
202+
"Saving to: ‘config.json’\n",
203+
"\n",
204+
"config.json 100%[===================>] 641 --.-KB/s in 0s \n",
205+
"\n",
206+
"2019-09-01 16:22:35 (43.7 MB/s) - ‘config.json’ saved [641/641]\n",
207+
"\n"
208+
],
209+
"name": "stdout"
210+
}
211+
]
212+
},
213+
{
214+
"cell_type": "code",
215+
"metadata": {
216+
"ExecuteTime": {
217+
"end_time": "2019-06-25T14:21:08.091797Z",
218+
"start_time": "2019-06-25T14:21:03.027877Z"
219+
},
220+
"id": "umJkmO9HFf3L",
221+
"colab_type": "code",
222+
"outputId": "bd5330f0-3027-4316-dcad-6937235a3911",
223+
"colab": {
224+
"base_uri": "https://localhost:8080/",
225+
"height": 133
226+
}
227+
},
228+
"source": [
229+
"import os\n",
230+
"import torch\n",
231+
"import joblib\n",
232+
"from cdqa.reader.reader_sklearn import Reader\n",
233+
"\n",
234+
"reader = Reader(model_type='xlnet',\n",
235+
" model_name_or_path='xlnet-base-cased',\n",
236+
" fp16=False,\n",
237+
" output_dir='.',\n",
238+
" no_cuda=False,\n",
239+
" pretrained_model_path='.')"
240+
],
241+
"execution_count": 7,
242+
"outputs": [
243+
{
244+
"output_type": "stream",
245+
"text": [
246+
"/usr/local/lib/python3.6/dist-packages/sklearn/externals/joblib/__init__.py:15: DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.\n",
247+
" warnings.warn(msg, category=DeprecationWarning)\n",
248+
"100%|██████████| 641/641 [00:00<00:00, 319039.86B/s]\n",
249+
"100%|██████████| 798011/798011 [00:01<00:00, 720164.52B/s]\n",
250+
"100%|██████████| 467042463/467042463 [00:38<00:00, 12212631.71B/s]\n"
251+
],
252+
"name": "stderr"
253+
}
254+
]
255+
},
256+
{
257+
"cell_type": "code",
258+
"metadata": {
259+
"id": "AViocaq-gnQk",
260+
"colab_type": "code",
261+
"outputId": "1a26fc7b-e900-42fc-bfca-2118a2cb880f",
262+
"colab": {
263+
"base_uri": "https://localhost:8080/",
264+
"height": 36
265+
}
266+
},
267+
"source": [
268+
"# Evaluate the XLNet reader on the SQuAD 2.0 dev set (dev-v2.0.json, downloaded above)\n",
269+
"reader.evaluate(X='dev-v2.0.json')"
270+
],
271+
"execution_count": 0,
272+
"outputs": [
273+
{
274+
"output_type": "stream",
275+
"text": [
276+
"Evaluating: 7%|▋ | 103/1569 [01:25<20:25, 1.20it/s]"
277+
],
278+
"name": "stderr"
279+
}
280+
]
281+
}
282+
]
283+
}

0 commit comments

Comments
 (0)