import numpy as np
import pandas as pd
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
import dtale.global_state as global_state
from dtale.dash_application.charts import (
build_chart,
bar_builder,
build_axes,
chart_builder_passthru,
)
from dtale.dash_application.utils import get_data_id
from dtale.dash_application.layout.utils import (
build_cols,
build_option,
build_selections,
AGGS,
)
from dtale.utils import (
build_query,
classify_type,
dict_merge,
find_dtype,
is_app_root_defined,
json_date,
json_float,
make_list,
run_query,
get_dtypes,
)
from dtale.charts.utils import (
MAX_GROUPS,
ZAXIS_CHARTS,
build_group_inputs_filter,
convert_date_val_to_date,
)
[docs]def build_histogram(data_id, col, query, point_filter):
data = run_query(
global_state.get_data(data_id),
query,
global_state.get_context_variables(data_id),
)
data = data.query(build_group_inputs_filter(data, [point_filter]))
s = data[~pd.isnull(data[col])][col]
hist_data, hist_labels = np.histogram(s, bins=10)
hist_labels = list(map(lambda x: json_float(x, precision=3), hist_labels[1:]))
axes_builder = build_axes(
None,
"Bins",
dict(type="single", data={}),
dict(Frequency=0),
dict(Frequency=max(hist_data)),
data=pd.DataFrame(dict(Frequency=hist_data, Bins=hist_labels)),
)
hist_data = dict(data={"all": dict(x=hist_labels, Frequency=hist_data)})
bars = bar_builder(
hist_data,
"Bins",
["Frequency"],
axes_builder,
chart_builder_passthru,
modal=True,
)
bars.figure["layout"]["xaxis"]["type"] = "category"
bars.figure["layout"]["title"]["text"] = "Histogram of {} ({} data points)".format(
col, len(s)
)
return bars
[docs]def build_drilldown_title(data_id, all_inputs, click_point, props, val_prop):
data = global_state.get_data(data_id)
def _build_val(col, val):
if classify_type(find_dtype(data[col])) == "D":
return json_date(convert_date_val_to_date(val))
return val
if "text" in click_point: # Heatmaps
strs = []
for dim in click_point["text"].split("<br>"):
prop, val = dim.split(": ")
strs.append("{} ({})".format(prop, val))
return "Drilldown for: {}".format(", ".join(strs))
strs = []
frame_col = all_inputs.get("animate_by")
if frame_col:
strs.append("{} ({})".format(frame_col, click_point.get("customdata")))
for prop in props:
prop = make_list(prop)
val_key = prop[0]
if click_point.get(val_key) is not None:
col = make_list(all_inputs.get(prop[-1]))[0]
strs.append(
"{} ({})".format(col, _build_val(col, click_point.get(val_key)))
)
val_prop = make_list(val_prop)
val_key = val_prop[0]
val_col = make_list(all_inputs.get(val_prop[-1]))[0]
agg = AGGS[all_inputs.get("agg") or "raw"]
strs.append(
"{} {} ({})".format(agg, val_col, _build_val(val_col, click_point.get(val_key)))
)
return "Drilldown for: {}".format(", ".join(strs))
[docs]def init_callbacks(dash_app):
def toggle_modal(close1, close2, click_data, is_open, inputs, drilldowns_on):
if not drilldowns_on:
raise PreventUpdate
if (inputs.get("agg") or "raw") == "raw":
raise PreventUpdate
if close1 or close2 or click_data:
return not is_open
return is_open
def build_x_dropdown(is_open, pathname, inputs, chart_inputs, yaxis_data, map_data):
if not is_open:
raise PreventUpdate
df = global_state.get_data(get_data_id(pathname))
all_inputs = combine_inputs(
dash_app, inputs, chart_inputs, yaxis_data, map_data
)
chart_type, x, y, z, group, map_val, animate_by = (
all_inputs.get(p)
for p in ["chart_type", "x", "y", "z", "group", "map_val", "animate_by"]
)
if chart_type == "maps":
if all_inputs.get("map_type") == "choropleth":
x = all_inputs["loc"]
else:
x = "lat_lon"
x_options = build_selections(map_val, animate_by)
else:
x_options = build_selections(z or y, animate_by)
col_opts = list(build_cols(df.columns, get_dtypes(df)))
x_options = [build_option(c, l) for c, l in col_opts if c not in x_options]
return x_options, x
def update_click_data(
click_data,
pathname,
inputs,
chart_inputs,
yaxis_data,
map_data,
current_click_data,
drilldowns_on,
):
if not drilldowns_on:
raise PreventUpdate
all_inputs = combine_inputs(
dash_app, inputs, chart_inputs, yaxis_data, map_data
)
if (all_inputs.get("agg") or "raw") == "raw":
raise PreventUpdate
if not click_data:
raise PreventUpdate
data_id = get_data_id(pathname)
chart_type = all_inputs.get("chart_type")
click_data = click_data or current_click_data
click_point = next((p for p in click_data.get("points", [])), None)
if chart_type == "maps":
props = [["location", "loc"], "lat", "lon"]
val = ["z", "map_val"]
else:
props = ["x"]
val = "y"
if click_point.get("z"):
props = ["x", "y"]
val = "z"
header = build_drilldown_title(data_id, all_inputs, click_point, props, val)
return click_data, header
def load_drilldown_content(
_click_data_ts,
drilldown_type,
drilldown_x,
pathname,
inputs,
chart_inputs,
yaxis_data,
map_data,
click_data,
drilldowns_on,
):
if not drilldowns_on:
raise PreventUpdate
all_inputs = combine_inputs(
dash_app, inputs, chart_inputs, yaxis_data, map_data
)
agg = all_inputs.get("agg") or "raw"
chart_type = all_inputs.get("chart_type")
frame_col = all_inputs.get("animate_by")
all_inputs.pop("animate_by", None)
if agg == "raw":
raise PreventUpdate
if drilldown_x is None and chart_type != "maps":
raise PreventUpdate
if click_data:
click_point = next((p for p in click_data.get("points", [])), None)
if click_point:
data_id = get_data_id(pathname)
curr_settings = global_state.get_settings(data_id) or {}
query = build_query(
data_id, all_inputs.get("query") or curr_settings.get("query")
)
x_col = all_inputs.get("x")
y_col = next((y2 for y2 in make_list(all_inputs.get("y"))), None)
if chart_type in ZAXIS_CHARTS:
x, y, z, frame = (
click_point.get(p) for p in ["x", "y", "z", "customdata"]
)
if chart_type == "heatmap":
click_point_vals = {}
for dim in click_point["text"].split("<br>"):
prop, val = dim.split(": ")
click_point_vals[prop] = val
x, y, frame = (
click_point_vals.get(p) for p in [x_col, y_col, frame_col]
)
point_filter = {x_col: x, y_col: y}
if frame_col:
point_filter[frame_col] = frame
if drilldown_type == "histogram":
z_col = next(
(z2 for z2 in make_list(all_inputs.get("z"))), None
)
hist_chart = build_histogram(
data_id, z_col, query, point_filter
)
return hist_chart, dict(display="none")
else:
xy_query = build_group_inputs_filter(
global_state.get_data(data_id),
[point_filter],
)
if not query:
query = xy_query
else:
query = "({}) and ({})".format(query, xy_query)
all_inputs["query"] = query
all_inputs["chart_type"] = drilldown_type
all_inputs["agg"] = "raw"
all_inputs["modal"] = True
all_inputs["x"] = drilldown_x
all_inputs["y"] = [all_inputs["z"]]
chart, _, _ = build_chart(get_data_id(pathname), **all_inputs)
return chart, None
elif chart_type == "maps":
map_type = all_inputs.get("map_type")
point_filter = {}
if frame_col:
point_filter[frame_col] = click_point["customdata"]
if map_type == "choropleth":
point_filter[all_inputs["loc"]] = click_point["location"]
elif map_type == "scattergeo":
lat, lon = (click_point.get(p) for p in ["lat", "lon"])
point_filter[all_inputs["lat"]] = lat
point_filter[all_inputs["lon"]] = lon
map_val = all_inputs["map_val"]
if drilldown_type == "histogram":
hist_chart = build_histogram(
data_id, map_val, query, point_filter
)
return hist_chart, dict(display="none")
else:
map_query = build_group_inputs_filter(
global_state.get_data(data_id),
[point_filter],
)
if not query:
query = map_query
else:
query = "({}) and ({})".format(query, map_query)
all_inputs["query"] = query
all_inputs["chart_type"] = drilldown_type
all_inputs["agg"] = "raw"
all_inputs["modal"] = True
data_id = get_data_id(pathname)
data = global_state.get_data(data_id)
all_inputs["x"] = drilldown_x
x_style = None
if map_type != "choropleth":
all_inputs["x"] = "lat_lon"
lat, lon = (all_inputs.get(p) for p in ["lat", "lon"])
data.loc[:, "lat_lon"] = (
data[lat].astype(str) + "|" + data[lon].astype(str)
)
x_style = dict(display="none")
all_inputs["y"] = [map_val]
chart, _, _ = build_chart(data_id, data=data, **all_inputs)
return chart, x_style
else:
x_filter = click_point.get("x")
point_filter = {x_col: x_filter}
if frame_col:
point_filter[frame_col] = click_point.get("customdata")
if drilldown_type == "histogram":
hist_chart = build_histogram(
data_id, y_col, query, point_filter
)
return hist_chart, dict(display="none")
else:
x_query = build_group_inputs_filter(
global_state.get_data(data_id),
[point_filter],
)
if not query:
query = x_query
else:
query = "({}) and ({})".format(query, x_query)
all_inputs["query"] = query
all_inputs["chart_type"] = drilldown_type
all_inputs["agg"] = "raw"
all_inputs["modal"] = True
all_inputs["x"] = drilldown_x
chart, _, _ = build_chart(get_data_id(pathname), **all_inputs)
return chart, None
return None, dict(display="none")
for i in range(1, MAX_GROUPS + 1):
dash_app.callback(
Output("drilldown-modal-{}".format(i), "is_open"),
[
Input("close-drilldown-modal-{}".format(i), "n_clicks"),
Input("close-drilldown-modal-header-{}".format(i), "n_clicks"),
Input("chart-{}".format(i), "clickData"),
],
[
State("drilldown-modal-{}".format(i), "is_open"),
State("input-data", "data"),
State("drilldown-toggle", "on"),
],
)(toggle_modal)
dash_app.callback(
[
Output("chart-click-data-{}".format(i), "data"),
Output("drilldown-modal-header-{}".format(i), "children"),
],
[Input("chart-{}".format(i), "clickData")],
[
State("url", "pathname"),
State("input-data", "data"),
State("chart-input-data", "data"),
State("yaxis-data", "data"),
State("map-input-data", "data"),
State("chart-click-data-{}".format(i), "data"),
State("drilldown-toggle", "on"),
],
)(update_click_data)
dash_app.callback(
[
Output("drilldown-x-dropdown-{}".format(i), "options"),
Output("drilldown-x-dropdown-{}".format(i), "value"),
],
[Input("drilldown-modal-{}".format(i), "is_open")],
[
State("url", "pathname"),
State("input-data", "data"),
State("chart-input-data", "data"),
State("yaxis-data", "data"),
State("map-input-data", "data"),
],
)(build_x_dropdown)
dash_app.callback(
[
Output("drilldown-content-{}".format(i), "children"),
Output("drilldown-x-input-{}".format(i), "style"),
],
[
Input("chart-click-data-{}".format(i), "modified_timestamp"),
Input("drilldown-chart-type-{}".format(i), "value"),
Input("drilldown-x-dropdown-{}".format(i), "value"),
],
[
State("url", "pathname"),
State("input-data", "data"),
State("chart-input-data", "data"),
State("yaxis-data", "data"),
State("map-input-data", "data"),
State("chart-click-data-{}".format(i), "data"),
State("drilldown-toggle", "on"),
],
)(load_drilldown_content)