Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 133 additions & 31 deletions tools/speech_data_explorer/data_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
}
comparison_mode = False


# parse table filter queries
def split_filter_part(filter_part):
for op in filter_operators:
Expand All @@ -83,7 +84,8 @@ def split_filter_part(filter_part):
def parse_args():
parser = argparse.ArgumentParser(description='Speech Data Explorer')
parser.add_argument(
'manifest', help='path to JSON manifest file',
'manifest',
help='path to JSON manifest file',
)
parser.add_argument('--vocab', help='optional vocabulary to highlight OOV words')
parser.add_argument('--port', default='8050', help='serving port for establishing connection')
Expand Down Expand Up @@ -219,7 +221,9 @@ def load_data(
match_vocab_2 = defaultdict(lambda: 0)

def append_data(
data_filename, estimate_audio, field_name='pred_text',
data_filename,
estimate_audio,
field_name='pred_text',
):
data = []
wer_dist = 0.0
Expand Down Expand Up @@ -703,7 +707,9 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
class_name='border-end',
),
dbc.Col(
html.Div('Word Match Rate (WMR), %', className='text-secondary'), width=3, class_name='border-end',
html.Div('Word Match Rate (WMR), %', className='text-secondary'),
width=3,
class_name='border-end',
),
dbc.Col(html.Div('Mean Word Accuracy, %', className='text-secondary'), width=3),
],
Expand All @@ -713,7 +719,9 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
[
dbc.Col(
html.H5(
'{:.2f}'.format(wer), className='text-center p-1', style={'color': 'green', 'opacity': 0.7},
'{:.2f}'.format(wer),
className='text-center p-1',
style={'color': 'green', 'opacity': 0.7},
),
width=3,
class_name='border-end',
Expand All @@ -727,14 +735,18 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
),
dbc.Col(
html.H5(
'{:.2f}'.format(wmr), className='text-center p-1', style={'color': 'green', 'opacity': 0.7},
'{:.2f}'.format(wmr),
className='text-center p-1',
style={'color': 'green', 'opacity': 0.7},
),
width=3,
class_name='border-end',
),
dbc.Col(
html.H5(
'{:.2f}'.format(mwa), className='text-center p-1', style={'color': 'green', 'opacity': 0.7},
'{:.2f}'.format(mwa),
className='text-center p-1',
style={'color': 'green', 'opacity': 0.7},
),
width=3,
),
Expand All @@ -745,19 +757,30 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
stats_layout += [
dbc.Row(dbc.Col(html.H5(children='Alphabet'), class_name='text-secondary'), class_name='mt-3'),
dbc.Row(
dbc.Col(html.Div('{}'.format(sorted(alphabet))),), class_name='mt-2 bg-light font-monospace rounded border'
dbc.Col(
html.Div('{}'.format(sorted(alphabet))),
),
class_name='mt-2 bg-light font-monospace rounded border',
),
]
for k in figures_hist:
stats_layout += [
dbc.Row(dbc.Col(html.H5(figures_hist[k][0]), class_name='text-secondary'), class_name='mt-3'),
dbc.Row(dbc.Col(dcc.Graph(id='duration-graph', figure=figures_hist[k][1]),),),
dbc.Row(
dbc.Col(
dcc.Graph(id='duration-graph', figure=figures_hist[k][1]),
),
),
]

if metrics_available:
stats_layout += [
dbc.Row(dbc.Col(html.H5('Word accuracy distribution'), class_name='text-secondary'), class_name='mt-3'),
dbc.Row(dbc.Col(dcc.Graph(id='word-acc-graph', figure=figure_word_acc),),),
dbc.Row(
dbc.Col(
dcc.Graph(id='word-acc-graph', figure=figure_word_acc),
),
),
]

wordstable_columns = [{'name': 'Word', 'id': 'word'}, {'name': 'Count', 'id': 'count'}]
Expand Down Expand Up @@ -786,12 +809,21 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
sort_by=[{'column_id': 'word', 'direction': 'asc'}],
style_cell={'maxWidth': 0, 'textAlign': 'left'},
style_header={'color': 'text-primary'},
css=[{'selector': '.dash-filter--case', 'rule': 'display: none'},],
css=[
{'selector': '.dash-filter--case', 'rule': 'display: none'},
],
),
),
class_name='m-2',
),
dbc.Row(dbc.Col([html.Button('Download Vocabulary', id='btn_csv'), dcc.Download(id='download-vocab-csv'),]),),
dbc.Row(
dbc.Col(
[
html.Button('Download Vocabulary', id='btn_csv'),
dcc.Download(id='download-vocab-csv'),
]
),
),
]


Expand Down Expand Up @@ -922,7 +954,12 @@ def update_wordstable(page_current, sort_by, filter_query):
)
]
samples_layout += [
dbc.Row(dbc.Col(html.Audio(id='player', controls=True),), class_name='mt-3 '),
dbc.Row(
dbc.Col(
html.Audio(id='player', controls=True),
),
class_name='mt-3 ',
),
dbc.Row(dbc.Col(dcc.Graph(id='signal-graph')), class_name='mt-3'),
]

Expand Down Expand Up @@ -1064,7 +1101,16 @@ def draw_vocab(Ox, Oy, color, size, data, dot_spacing='no', rad=0.01):
Oy == 'accuracy_model_' + model_name_1 and Ox == 'accuracy_model_' + model_name_2
):
fig.add_shape(
type="line", x0=0, y0=0, x1=100, y1=100, line=dict(color="MediumPurple", width=1, dash="dot",)
type="line",
x0=0,
y0=0,
x1=100,
y1=100,
line=dict(
color="MediumPurple",
width=1,
dash="dot",
),
)

return fig
Expand Down Expand Up @@ -1176,7 +1222,10 @@ def display_query(query):
),
dcc.Input(id='radius', placeholder='Enter radius of spacing (std is 0.01)'),
html.Hr(),
dcc.Input(id='filter-query-input', placeholder='Enter filter query',),
dcc.Input(
id='filter-query-input',
placeholder='Enter filter query',
),
],
style={'width': '200%', 'display': 'inline-block', 'float': 'middle'},
),
Expand All @@ -1194,7 +1243,11 @@ def display_query(query):
html.Hr(),
html.Div(id='datatable-query-structure', style={'whitespace': 'pre'}),
html.Hr(),
dbc.Row(dbc.Col(dcc.Graph(id='voc_graph'),),),
dbc.Row(
dbc.Col(
dcc.Graph(id='voc_graph'),
),
),
html.Hr(),
],
id='wrd_lvl',
Expand Down Expand Up @@ -1250,7 +1303,12 @@ def display_query(query):
{'selector': '.column-header--hide', 'rule': 'display: none'},
],
),
dbc.Row(dbc.Col(html.Audio(id='player-1', controls=True),), class_name='mt-3'),
dbc.Row(
dbc.Col(
html.Audio(id='player-1', controls=True),
),
class_name='mt-3',
),
]
)
),
Expand Down Expand Up @@ -1280,14 +1338,22 @@ def display_query(query):
[
html.Div(
[
dbc.Row(dbc.Col(dcc.Graph(id='utt_graph'),),),
dbc.Row(
dbc.Col(
dcc.Graph(id='utt_graph'),
),
),
html.Hr(),
dcc.Input(id='clicked_aidopath', style={'width': '100%'}),
html.Hr(),
dcc.Input(id='my-output-1', style={'display': 'none'}), # hidden field: set by the utt_graph click callback, read by the selected_rows callback
]
),
html.Div([dbc.Row(dbc.Col(dcc.Graph(id='signal-graph-1')), class_name='mt-3'),]),
html.Div(
[
dbc.Row(dbc.Col(dcc.Graph(id='signal-graph-1')), class_name='mt-3'),
]
),
],
id='down_thing',
style={'display': 'block'},
Expand Down Expand Up @@ -1384,15 +1450,17 @@ def show_hide_element(visibility_state):

@app.callback(
[Output('datatable-advanced-filtering-2', 'page_current'), Output('my-output-1', 'value')],
[Input('utt_graph', 'clickData'),],
[
Input('utt_graph', 'clickData'),
],
)
def real_select_click(hoverData):
if hoverData is not None:
path = str(hoverData['points'][0]['customdata'][-1])
for t in range(len(data_with_metrics)):
if data_with_metrics[t]['audio_filepath'] == path:
ind = t
s = t #% 5
s = t # % 5
sel = s
pg = math.ceil(ind // 5)
return pg, sel
Expand All @@ -1401,7 +1469,8 @@ def real_select_click(hoverData):


@app.callback(
[Output('datatable-advanced-filtering-2', 'selected_rows')], [Input('my-output-1', 'value')],
[Output('datatable-advanced-filtering-2', 'selected_rows')],
[Input('my-output-1', 'value')],
)
def real_select_click(num):
s = num
Expand Down Expand Up @@ -1448,7 +1517,18 @@ def draw_table_with_metrics(met, hoverData, data_virt):
'audio_filepath': True,
},
) #'numwords': True,
fig.add_shape(type="line", x0=0, y0=0, x1=100, y1=100, line=dict(color="Red", width=1, dash="dot",))
fig.add_shape(
type="line",
x0=0,
y0=0,
x1=100,
y1=100,
line=dict(
color="Red",
width=1,
dash="dot",
),
)
fig.update_layout(clickmode='event+select')
fig.update_traces(marker_size=10)
path = None
Expand Down Expand Up @@ -1542,11 +1622,14 @@ def nav_click(url):
else:
return [stats_layout, True, False, False]


else:

@app.callback(
[Output('page-content', 'children'), Output('stats_link', 'active'), Output('samples_link', 'active'),],
[
Output('page-content', 'children'),
Output('stats_link', 'active'),
Output('samples_link', 'active'),
],
[Input('url', 'pathname')],
)
def nav_click(url):
Expand Down Expand Up @@ -1577,9 +1660,16 @@ def show_item(idx, data):
return [data[idx[0]][k] for k in data_with_metrics[0]]


@app.callback(Output('_diff', 'srcDoc'), [Input('datatable', 'selected_rows'), Input('datatable', 'data'),])
@app.callback(
Output('_diff', 'srcDoc'),
[
Input('datatable', 'selected_rows'),
Input('datatable', 'data'),
],
)
def show_diff(
idx, data,
idx,
data,
):
if len(idx) == 0:
raise PreventUpdate
Expand All @@ -1605,10 +1695,14 @@ def show_diff(

@app.callback(
Output('__diff', 'srcDoc'),
[Input('datatable-advanced-filtering-2', 'selected_rows'), Input('datatable-advanced-filtering-2', 'data'),],
[
Input('datatable-advanced-filtering-2', 'selected_rows'),
Input('datatable-advanced-filtering-2', 'data'),
],
)
def show_diff(
idx, data,
idx,
data,
):
if len(idx) == 0:
raise PreventUpdate
Expand Down Expand Up @@ -1664,7 +1758,11 @@ def plot_signal(idx, data):
figs.add_trace(
go.Heatmap(
z=s_db,
colorscale=[[0, 'rgb(30,62,62)'], [0.5, 'rgb(30,128,128)'], [1, 'rgb(30,255,30)'],],
colorscale=[
[0, 'rgb(30,62,62)'],
[0.5, 'rgb(30,128,128)'],
[1, 'rgb(30,255,30)'],
],
colorbar=dict(yanchor='middle', lenmode='fraction', y=0.2, len=0.5, ticksuffix=' dB'),
dx=time_stride,
dy=fs / n_fft / 1000,
Expand Down Expand Up @@ -1720,7 +1818,11 @@ def plot_signal(idx, data):
figs.add_trace(
go.Heatmap(
z=s_db,
colorscale=[[0, 'rgb(30,62,62)'], [0.5, 'rgb(30,128,128)'], [1, 'rgb(30,255,30)'],],
colorscale=[
[0, 'rgb(30,62,62)'],
[0.5, 'rgb(30,128,128)'],
[1, 'rgb(30,255,30)'],
],
colorbar=dict(yanchor='middle', lenmode='fraction', y=0.2, len=0.5, ticksuffix=' dB'),
dx=time_stride,
dy=fs / n_fft / 1000,
Expand Down Expand Up @@ -1789,4 +1891,4 @@ def update_player(idx, data):


if __name__ == '__main__':
app.run_server(host='0.0.0.0', port=args.port, debug=args.debug)
app.run(host='0.0.0.0', port=args.port, debug=args.debug)
Loading