# @Author:  Felix Kramer <kramer>
# @Date:   2021-05-08T20:34:30+02:00
# @Email:  kramer@mpi-cbg.de
# @Project: go-with-the-flow
# @Last modified by:    Felix Kramer
# @Last modified time: 2021-09-05T17:02:35+02:00
# @License: MIT

# standard types
import networkx as nx
import numpy as np
import pandas as pd
import plotly.graph_objects as go

#generate interactive plots with plotly and return the respective figures
def plot_networkx(input_graph, **kwargs):
    """Build an interactive plotly figure for a single graph.

    Recognised keyword arguments (anything else is ignored):
        network_id: index used to pick entries from the lists below
            (default 0).
        color_nodes: list of node colours (default ['#a845b5']).
        color_edges: list of edge colours (default ['#c762d4']).
        markersize: list of node marker sizes (default [2]).
        node_data: mapping of column name -> object exposing ``to_numpy()``;
            copied column by column into a fresh DataFrame that is handed
            to the node trace builder.
        edge_data: same as ``node_data``, handed to the edge trace builder.

    Returns:
        The ``go.Figure`` containing the node trace followed by the edge
        traces, with the legend switched off.
    """
    settings = {
        'network_id': 0,
        'color_nodes': ['#a845b5'],
        'color_edges': ['#c762d4'],
        'markersize': [2],
    }
    # Only known option names may be overridden by the caller.
    settings.update(
        {name: value for name, value in kwargs.items() if name in settings}
    )

    # Copy the optional per-node / per-edge columns into fresh frames.
    frames = {}
    for frame_key in ('node_data', 'edge_data'):
        frame = pd.DataFrame()
        if frame_key in kwargs:
            for column, values in kwargs[frame_key].items():
                frame[column] = values.to_numpy()
        frames[frame_key] = frame

    figure = go.Figure()
    add_traces_nodes(figure, settings, input_graph, frames['node_data'])
    add_traces_edges(figure, settings, input_graph, frames['edge_data'])
    figure.update_layout(showlegend=False)

    return figure

def plot_networkx_dual(dual_graph, **kwargs):
    """Build one plotly figure showing every layer of a dual graph.

    ``dual_graph.layer`` is iterated; each layer must expose the graph as
    ``G`` plus ``get_nodes_data(**kwargs)`` and ``get_edges_data(**kwargs)``,
    whose results are handed to the node and edge trace builders.

    Recognised keyword arguments: ``network_id``, ``color_nodes``,
    ``color_edges`` and ``markersize`` (the last three are lists indexed by
    the layer position). ``network_id`` is overwritten with the layer index
    on every iteration. All keyword arguments, recognised or not, are also
    forwarded unchanged to the layer data getters.

    Returns:
        The ``go.Figure`` with the traces of all layers and no legend.
    """
    settings = {
        'network_id': 0,
        'color_nodes': ['#6aa84f', '#a845b5'],
        'color_edges': ['#2BDF94', '#c762d4'],
        'markersize': [2, 2],
    }
    for name in kwargs:
        if name in settings:
            settings[name] = kwargs[name]

    figure = go.Figure()

    for layer_index, layer in enumerate(dual_graph.layer):
        # The layer position selects the colour / marker size entries.
        settings['network_id'] = layer_index
        layer_nodes = layer.get_nodes_data(**kwargs)
        layer_edges = layer.get_edges_data(**kwargs)

        add_traces_nodes(figure, settings, layer.G, layer_nodes)
        add_traces_edges(figure, settings, layer.G, layer_edges)

    figure.update_layout(showlegend=False)

    return figure

# integrate traces into the figure
def add_traces_edges(fig, options, input_graph, extra_data):
    """Add all edge traces of ``input_graph`` to ``fig``.

    The edge colour is ``options['color_edges'][options['network_id']]``.
    One line trace per edge is added first, followed by the single
    midpoint marker trace that carries the hover information.
    """
    edge_color = options['color_edges'][options['network_id']]

    hover_trace = get_edge_mid_trace(input_graph, extra_data, color=edge_color)
    line_traces = get_edge_invd_traces(input_graph, extra_data, color=edge_color)

    for line_trace in line_traces:
        fig.add_trace(line_trace)
    fig.add_trace(hover_trace)

def add_traces_nodes(fig, options, input_graph, extra_data):
    """Add the node marker trace of ``input_graph`` to ``fig``.

    Colour and marker size are taken from ``options['color_nodes']`` and
    ``options['markersize']`` at position ``options['network_id']``.
    """
    layer = options['network_id']
    fig.add_trace(
        get_node_trace(
            input_graph,
            extra_data,
            color=options['color_nodes'][layer],
            markersize=options['markersize'][layer],
        )
    )

# auxiliary functions generating traces for nodes and edges
def get_edge_mid_trace(input_graph, extra_data, **kwargs):
    """Create the marker trace placed at the midpoint of every edge.

    The trace comes from ``get_hover_scatter_from_template`` and is filled
    by ``set_hover_info``: its coordinates are the edge midpoints computed
    from the ``'pos'`` node attribute, and ``extra_data`` supplies the
    hover text.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        extra_data: per-edge table forwarded to ``set_hover_info``.
        **kwargs: only ``color`` is honoured (default ``'#888'``); every
            other keyword is ignored.

    Returns:
        The midpoint marker trace (3D when the positions have three
        components, 2D otherwise).
    """
    options = {
        'color': '#888',
    }
    options.update({k: v for k, v in kwargs.items() if k in options})

    pos = nx.get_node_attributes(input_graph, 'pos')
    # The dimension follows the stored coordinates. A graph without any
    # 'pos' attribute keeps the 3D default instead of raising IndexError.
    dim = len(next(iter(pos.values()))) if pos else 3

    middle_node_trace = get_hover_scatter_from_template(dim, options)

    XYZ = [[] for _ in range(dim)]
    for edge in input_graph.edges():

        XYZ_0 = pos[edge[0]]
        XYZ_1 = pos[edge[1]]

        # Component-wise midpoint of the two end points.
        for i, xi in enumerate(XYZ):
            xi.append((XYZ_0[i] + XYZ_1[i]) / 2.)

    set_hover_info(middle_node_trace, XYZ, extra_data)

    return middle_node_trace

def set_hover_info(trace, XYZ, extra_data):
    """Write coordinates and hover text into ``trace`` in place.

    ``XYZ`` holds one coordinate list per axis. With fewer than three
    entries only ``'x'`` and ``'y'`` are assigned, otherwise ``'x'``,
    ``'y'`` and ``'z'``. If ``extra_data`` has no keys, hovering is
    disabled (``hoverinfo='none'``); otherwise one tag per row is built
    from all columns of ``extra_data`` and stored as ``trace['text']``.
    """
    axis_names = ['x', 'y'] if len(XYZ) < 3 else ['x', 'y', 'z']
    for axis_index, axis_name in enumerate(axis_names):
        trace[axis_name] = XYZ[axis_index]

    # Nothing to display: switch hovering off and stop here.
    if len(extra_data.keys()) == 0:
        trace['hoverinfo'] = 'none'
        return

    columns = extra_data.columns
    per_column = [list(extra_data[name]) for name in columns]
    # Transpose the column lists into rows, one hover tag per row.
    trace['text'] = [create_tag(row, columns) for row in zip(*per_column)]

def get_hover_scatter_from_template(dim, options):
    """Return an empty, fully transparent marker trace used for hover text.

    For ``dim == 3`` a ``go.Scatter3d`` is created with trace-level
    ``opacity=0`` and ``options`` as the marker dict. For any other ``dim``
    a ``go.Scatter`` is created whose ``go.scatter.Marker`` receives
    ``opacity=0`` together with ``options``.
    """
    if dim != 3:
        return go.Scatter(
            x=[],
            y=[],
            text=[],
            mode='markers',
            hoverinfo='text',
            marker=go.scatter.Marker(opacity=0, **options),
        )

    return go.Scatter3d(
        x=[],
        y=[],
        z=[],
        text=[],
        mode='markers',
        hoverinfo='text',
        opacity=0,
        marker=dict(**options),
    )

def get_edge_invd_traces(input_graph, extra_data, **kwargs):
    """Create one line trace per edge of ``input_graph``.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        extra_data: per-edge table; if it contains ``'weight'``, the value
            ``extra_data['weight'][i]`` becomes the width of the i-th edge,
            otherwise the width is 5.
        **kwargs: only ``color`` is honoured (default ``'#888'``). A string
            is used as the colour of every edge. Anything else is indexed
            per edge and drawn on the ``'plasma'`` colorscale, with
            ``cmin``/``cmax`` set to the minimum/maximum of the values.

    Returns:
        List of line traces in the order of ``input_graph.edges()``.

    NOTE(review): for non-string colours ``colorscale``/``cmin``/``cmax``
    are put into the line dict for 2D graphs as well - confirm that
    ``go.Scatter`` lines accept them.
    """
    options = {
        'color': '#888',
    }
    options.update({k: v for k, v in kwargs.items() if k in options})

    # isinstance (rather than an exact type comparison) keeps str
    # subclasses such as numpy.str_ on the single-colour path.
    colorful = not isinstance(options['color'], str)
    if colorful:
        options['colorscale'] = 'plasma'
        options['cmin'] = np.min(options['color'])
        options['cmax'] = np.max(options['color'])

    pos = nx.get_node_attributes(input_graph, 'pos')
    # The dimension follows the stored coordinates. A graph without any
    # 'pos' attribute keeps the 3D default instead of raising IndexError.
    dim = len(next(iter(pos.values()))) if pos else 3

    trace_list = []
    aux_option = dict(options)
    for i, edge in enumerate(input_graph.edges()):

        aux_option['width'] = 5.
        if 'weight' in extra_data:
            aux_option['width'] = extra_data['weight'][i]

        if colorful:
            # Same value at both end points of the segment.
            aux_option['color'] = [options['color'][i]] * 2

        trace = get_line_from_template(dim, aux_option)
        XYZ_0 = input_graph.nodes[edge[0]]['pos']
        XYZ_1 = input_graph.nodes[edge[1]]['pos']

        set_edge_info(trace, XYZ_0, XYZ_1)
        trace_list.append(trace)

    return trace_list

def set_edge_info(trace, XYZ_0, XYZ_1):
    """Store the two end points of an edge in ``trace`` in place.

    For every axis the trace receives ``[start, end, None]``. Only
    ``'x'`` and ``'y'`` are written when ``XYZ_0`` has fewer than three
    components, otherwise ``'z'`` is written as well.
    """
    axes = ('x', 'y') if len(XYZ_0) < 3 else ('x', 'y', 'z')

    for axis_index, axis in enumerate(axes):
        start = XYZ_0[axis_index]
        end = XYZ_1[axis_index]
        trace[axis] = [start, end, None]

def get_line_from_template(dim, options):
    """Return an empty line trace styled by ``options``.

    ``options`` is copied into the ``line`` dict of the trace. A
    ``go.Scatter3d`` is returned for ``dim == 3`` and a ``go.Scatter`` for
    every other ``dim``; hovering is disabled on both.
    """
    line_style = dict(**options)

    if dim == 3:
        return go.Scatter3d(
            x=[],
            y=[],
            z=[],
            mode='lines',
            line=line_style,
            hoverinfo='none',
        )

    return go.Scatter(
        x=[],
        y=[],
        mode='lines',
        line=line_style,
        hoverinfo='none',
    )

def get_node_trace(input_graph, extra_data, **kwargs):
    """Create the marker trace for all nodes of ``input_graph``.

    Recognised keyword arguments are ``color`` (default ``'#888'``),
    ``dim`` (default 3) and ``markersize`` (default 2); others are ignored.
    The coordinates come from ``get_node_coords``, which may adjust
    ``dim`` to the stored positions, and the trace itself is assembled by
    ``get_node_scatter`` using ``extra_data`` for the hover text.
    """
    settings = {
        'color': '#888',
        'dim': 3,
        'markersize': 2,
    }
    settings.update(
        {name: value for name, value in kwargs.items() if name in settings}
    )

    coordinates = get_node_coords(input_graph, settings)

    return get_node_scatter(coordinates, extra_data, settings)

def get_node_coords(input_graph, options):
    """Collect the node positions as one coordinate list per axis.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        options: dict with a ``'dim'`` entry. It is updated in place to
            the number of components of the first stored position, so
            callers can read the detected dimension afterwards. An
            optional ``'node_list'`` entry restricts and orders the nodes
            that are read; by default all nodes of the graph are used.

    Returns:
        ``options['dim']`` lists, the i-th holding the i-th coordinate of
        every selected node.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')
    # Adopt the dimension of the stored coordinates. Without any 'pos'
    # attribute the configured dimension is kept instead of raising an
    # IndexError on the empty mapping.
    if pos:
        options['dim'] = len(next(iter(pos.values())))
    dim = options['dim']

    node_xyz = [[] for _ in range(dim)]

    # The original read options['edge_list'] here, which raised a KeyError
    # (or silently used the wrong list) whenever 'node_list' was supplied.
    if 'node_list' in options:
        nodes = options['node_list']
    else:
        nodes = input_graph.nodes()

    for node in nodes:
        xyz = pos[node]
        for i in range(dim):
            node_xyz[i].append(xyz[i])

    return node_xyz

def get_node_scatter(node_xyz, extra_data, options):
    """Assemble the node marker trace from coordinates and hover data.

    If ``extra_data`` has no keys, hovering is disabled
    (``hoverinfo='none'``, empty hover text). Otherwise one tag per row is
    built from all columns of ``extra_data`` and shown as hover text. The
    trace is a ``go.Scatter3d`` when ``options['dim'] == 3`` and a
    ``go.Scatter`` otherwise; marker size and colour come from
    ``options['markersize']`` and ``options['color']``.
    """
    if len(extra_data.keys()) != 0:
        hover_mode = 'text'
        per_column = [list(extra_data[name]) for name in extra_data.columns]
        hover_text = [
            create_tag(row, extra_data.columns) for row in zip(*per_column)
        ]
    else:
        hover_mode = 'none'
        hover_text = ''

    is_3d = options['dim'] == 3

    coords = dict(x=node_xyz[0], y=node_xyz[1])
    if is_3d:
        coords['z'] = node_xyz[2]

    marker_style = dict(
        size=options['markersize'],
        line_width=2,
        color=options['color'],
    )

    # Both trace types take the same styling; only the axes differ.
    trace_type = go.Scatter3d if is_3d else go.Scatter

    return trace_type(
        mode='markers',
        hoverinfo=hover_mode,
        hovertext=hover_text,
        marker=marker_style,
        **coords,
    )

def create_tag(vals, columns):
    """Format one row of values as an HTML hover tag.

    Every column contributes ``'<column>: <value><br>'``, where the value
    is taken from ``vals`` at the position of the column. Returns the
    concatenation in column order (an empty string for no columns).
    """
    return ''.join(
        str(name) + ': ' + str(vals[position]) + '<br>'
        for position, name in enumerate(columns)
    )