# draw_networkx.py
# @Author:  Felix Kramer <kramer>
# @Date:   2021-05-08T20:34:30+02:00
# @Email:  kramer@mpi-cbg.de
# @Project: go-with-the-flow
# @Last modified by:    Felix Kramer
# @Last modified time: 2021-06-25T16:34:29+02:00
# @License: MIT

# third-party dependencies
import networkx as nx
import numpy as np
# tabular hover data (pandas) and figure objects (plotly)
import pandas as pd
import plotly.graph_objects as go

# Generate interactive plots with plotly and return the respective figures.
def plot_networkx(input_graph, **kwargs):
    """Build an interactive plotly figure for a single networkx graph.

    Nodes are drawn from the ``'pos'`` node attribute. Every edge becomes its
    own line trace, plus one shared invisible marker trace at the edge
    midpoints that carries the edge hover text.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        **kwargs: Optional settings.
            network_id (int): index into the color lists (default 0).
            color_nodes (list[str]): node colors, indexed by ``network_id``.
            color_edges (list[str]): edge colors, indexed by ``network_id``.
            node_data: DataFrame (or mapping of name -> pandas Series) with
                per-node values shown as hover text.
            edge_data: DataFrame (or mapping of name -> pandas Series) with
                per-edge values shown as hover text; a ``'weight'`` column
                also sets the line widths.

    Returns:
        plotly.graph_objects.Figure with the legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#a845b5'],
        'color_edges': ['#c762d4'],
    }
    options.update({k: v for k, v in kwargs.items() if k in options})

    # Copy the supplied columns as plain arrays into fresh DataFrames, so the
    # caller's index is dropped and rows are matched to nodes/edges by position.
    node_data = pd.DataFrame()
    edge_data = pd.DataFrame()
    if 'node_data' in kwargs:
        for key, values in kwargs['node_data'].items():
            node_data[key] = values.to_numpy()
    if 'edge_data' in kwargs:
        for key, values in kwargs['edge_data'].items():
            edge_data[key] = values.to_numpy()

    fig = go.Figure()
    add_traces_nodes(fig, options, input_graph, node_data)
    add_traces_edges(fig, options, input_graph, edge_data)
    fig.update_layout(showlegend=False)

    return fig

def plot_networkx_dual(dual_graph, **kwargs):
    """Build one plotly figure overlaying every layer of a multilayer graph.

    Args:
        dual_graph: object exposing ``layer``, an iterable of layer objects,
            each providing ``G`` (a networkx graph with ``'pos'`` node
            attributes), ``get_nodes_data()`` and ``get_edges_data()``.
        **kwargs: Optional ``color_nodes`` / ``color_edges`` overrides. Both
            lists are indexed by layer position, so they need at least as many
            entries as there are layers. ``network_id`` is accepted but is
            overwritten with the layer index.

    Returns:
        plotly.graph_objects.Figure with the legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#6aa84f', '#a845b5'],
        'color_edges': ['#2BDF94', '#c762d4'],
    }
    options.update({k: v for k, v in kwargs.items() if k in options})

    fig = go.Figure()
    for i, layer in enumerate(dual_graph.layer):
        # The layer index selects this layer's colors from the option lists.
        options['network_id'] = i
        nodes_data = layer.get_nodes_data()
        edges_data = layer.get_edges_data()

        add_traces_nodes(fig, options, layer.G, nodes_data)
        add_traces_edges(fig, options, layer.G, edges_data)

    fig.update_layout(showlegend=False)

    return fig

# Integrate traces into the figure.
def add_traces_edges(fig, options, input_graph, extra_data):
    """Add the edge traces of ``input_graph`` to ``fig`` in place.

    One line trace per edge is added first, followed by a single invisible
    marker trace at the edge midpoints that carries the hover text.

    Args:
        fig: plotly figure to extend.
        options: dict with ``'network_id'`` (index into ``'color_edges'``)
            and ``'color_edges'`` (list of colors).
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        extra_data: per-edge hover data (DataFrame), rows in edge order;
            may be empty.
    """
    color = options['color_edges'][options['network_id']]

    edge_mid_trace = get_edge_mid_trace(input_graph, extra_data, color=color)
    edge_invd_traces = get_edge_invd_traces(input_graph, extra_data, color=color)

    for trace in edge_invd_traces:
        fig.add_trace(trace)

    fig.add_trace(edge_mid_trace)

def add_traces_nodes(fig, options, input_graph, extra_data):
    """Append a single marker trace for the nodes of ``input_graph`` to ``fig``.

    The node color is ``options['color_nodes'][options['network_id']]``;
    ``extra_data`` supplies optional per-node hover text.
    """
    palette = options['color_nodes']
    chosen = palette[options['network_id']]
    fig.add_trace(get_node_trace(input_graph, extra_data, color=chosen))

# Auxiliary functions generating the traces for nodes and edges.
def get_edge_mid_trace(input_graph, extra_data, **kwargs):
    """Return an invisible marker trace placed at every edge midpoint.

    The markers only carry per-edge hover text; coordinates and text are
    filled in by ``set_hover_info``.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        extra_data: per-edge hover data (DataFrame), rows in edge order;
            may be empty.
        **kwargs: ``color`` (default ``'#888'``) and ``dim`` (default 3).
            ``dim`` is replaced by the length of the stored node positions.

    Returns:
        ``go.Scatter3d`` when positions are 3D, otherwise ``go.Scatter``.
    """
    options = {
        'color': '#888',
        'dim': 3,
    }
    options.update({k: v for k, v in kwargs.items() if k in options})

    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        # The stored positions decide the dimension, not the option value.
        options['dim'] = len(next(iter(pos.values())))

    # NOTE: options only holds 'color'/'dim' after the filter above, so the
    # graph's own edges are what is actually used here.
    edges = options.get('edge_list', input_graph.edges())

    middle_node_trace = get_hover_scatter_from_template(options)

    coords = [[] for _ in range(options['dim'])]
    for edge in edges:
        start, end = pos[edge[0]], pos[edge[1]]
        for axis, values in enumerate(coords):
            values.append((start[axis] + end[axis]) / 2.)

    set_hover_info(middle_node_trace, coords, extra_data)

    return middle_node_trace

def set_hover_info(trace, XYZ, extra_data):
    """Fill ``trace`` in place with coordinates and per-point hover text.

    Args:
        trace: plotly scatter trace (or any item-assignable object).
        XYZ: list of coordinate lists, one per axis. With fewer than three
            axes only ``x`` and ``y`` are set.
        extra_data: DataFrame whose rows become hover labels via
            ``create_tag`` (assigned to ``trace['text']``). If it has no
            columns, hovering is disabled with ``hoverinfo='none'``.
    """
    tags = ['x', 'y', 'z'] if len(XYZ) >= 3 else ['x', 'y']
    for tag, values in zip(tags, XYZ):
        trace[tag] = values

    if len(extra_data.keys()) != 0:
        # One label per row; columns give the "name: value" pairs.
        trace['text'] = [
            create_tag(row, extra_data.columns)
            for row in extra_data.itertuples(index=False, name=None)
        ]
    else:
        trace['hoverinfo'] = 'none'

def get_hover_scatter_from_template(options):
    """Return an empty, fully transparent marker trace used to carry hover text.

    Args:
        options: dict with ``'dim'`` (3 selects ``go.Scatter3d``, any other
            value ``go.Scatter``) and ``'color'`` (marker color).

    Returns:
        A trace with empty coordinate and text lists, ``mode='markers'`` and
        ``hoverinfo='text'``; the data is filled in afterwards.
    """
    if options['dim'] == 3:
        # 3D: opacity is set on the whole trace.
        return go.Scatter3d(
            x=[],
            y=[],
            z=[],
            text=[],
            mode='markers',
            hoverinfo='text',
            opacity=0,
            marker=dict(color=options['color']),
        )

    # 2D: opacity is set on the markers.
    return go.Scatter(
        x=[],
        y=[],
        text=[],
        mode='markers',
        hoverinfo='text',
        marker=go.scatter.Marker(
            opacity=0,
            color=options['color'],
        ),
    )
def get_edge_invd_traces(input_graph, extra_data, **kwargs):
    """Return one individual line trace per edge of ``input_graph``.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        extra_data: per-edge data (DataFrame), rows in edge order. If it has
            a ``'weight'`` column, its i-th entry is the line width of the
            i-th edge; otherwise every line has width 5.
        **kwargs: ``color`` (default ``'#888'``) and ``dim`` (default 3).
            ``dim`` is replaced by the length of the stored node positions.

    Returns:
        list of ``go.Scatter3d`` / ``go.Scatter`` line traces, in edge order.
    """
    options = {
        'color': '#888',
        'dim': 3,
    }
    options.update({k: v for k, v in kwargs.items() if k in options})

    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        # The stored positions decide the dimension, not the option value.
        options['dim'] = len(next(iter(pos.values())))

    # NOTE: options only holds 'color'/'dim' after the filter above, so the
    # graph's own edges are what is actually used here.
    edges = options.get('edge_list', input_graph.edges())

    # Convert the weights once and index them by position: the i-th weight
    # belongs to the i-th edge regardless of the DataFrame's index labels
    # (a label lookup like extra_data['weight'][i] breaks on other indices).
    weights = np.asarray(extra_data['weight']) if 'weight' in extra_data else None

    trace_list = []
    for i, edge in enumerate(edges):
        options['weight'] = 5. if weights is None else weights[i]
        trace = get_line_from_template(options)
        set_edge_info(trace, pos[edge[0]], pos[edge[1]])
        trace_list.append(trace)

    return trace_list

def set_edge_info(trace, XYZ_0, XYZ_1):
    """Store the segment from ``XYZ_0`` to ``XYZ_1`` in ``trace`` in place.

    Every axis receives ``[start, end, None]``. Points with fewer than three
    coordinates only set ``x`` and ``y``.
    """
    axes = ('x', 'y') if len(XYZ_0) < 3 else ('x', 'y', 'z')
    for axis_idx, axis in enumerate(axes):
        start, end = XYZ_0[axis_idx], XYZ_1[axis_idx]
        trace[axis] = [start, end, None]

def get_line_from_template(options):
    """Return an empty line trace styled from ``options``.

    Args:
        options: dict with ``'dim'`` (3 selects ``go.Scatter3d``, any other
            value ``go.Scatter``), ``'color'`` and ``'weight'`` (line width).

    Returns:
        A trace with empty coordinates, ``mode='lines'`` and hover disabled;
        the coordinates are filled in afterwards.
    """
    line_style = dict(color=options['color'], width=options['weight'])

    if options['dim'] == 3:
        return go.Scatter3d(
            x=[],
            y=[],
            z=[],
            mode='lines',
            line=line_style,
            hoverinfo='none',
        )

    return go.Scatter(
        x=[],
        y=[],
        mode='lines',
        line=line_style,
        hoverinfo='none',
    )

def get_node_coords(input_graph, options):
    """Collect node coordinates as one list per axis.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        options: dict with ``'dim'``. It is updated in place to the length of
            the stored positions, so later consumers of ``options`` see the
            real dimension. An optional ``'node_list'`` restricts the output
            to those nodes, in that order.

    Returns:
        ``[xs, ys]`` or ``[xs, ys, zs]`` -- node coordinates grouped by axis.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        # Guarded so a graph without positioned nodes yields empty lists
        # instead of an IndexError.
        options['dim'] = len(next(iter(pos.values())))

    # A node subset comes from 'node_list' (previously this wrongly read
    # options['edge_list'], raising KeyError).
    nodes = options.get('node_list', input_graph.nodes())

    node_xyz = [[] for _ in range(options['dim'])]
    for node in nodes:
        xyz = pos[node]
        for axis, values in enumerate(node_xyz):
            values.append(xyz[axis])

    return node_xyz

def get_node_scatter(node_xyz, extra_data, options):
    """Return a marker trace for the nodes located at ``node_xyz``.

    Args:
        node_xyz: coordinate lists per axis (as built by ``get_node_coords``).
        extra_data: per-node hover data (DataFrame), rows in node order. If it
            has no columns, hovering is disabled.
        options: dict with ``'dim'`` (3 selects ``go.Scatter3d``, any other
            value ``go.Scatter``) and ``'color'``.

    Returns:
        ``go.Scatter3d`` or ``go.Scatter`` with size-2 markers.
    """
    hoverinfo = 'none'
    hovertext = ''
    if len(extra_data.keys()) != 0:
        hoverinfo = 'text'
        # One label per row; columns give the "name: value" pairs.
        hovertext = [
            create_tag(row, extra_data.columns)
            for row in extra_data.itertuples(index=False, name=None)
        ]

    marker = dict(size=2, line_width=2, color=options['color'])

    if options['dim'] == 3:
        return go.Scatter3d(
            x=node_xyz[0],
            y=node_xyz[1],
            z=node_xyz[2],
            mode='markers',
            hoverinfo=hoverinfo,
            hovertext=hovertext,
            marker=marker,
        )

    return go.Scatter(
        x=node_xyz[0],
        y=node_xyz[1],
        mode='markers',
        hoverinfo=hoverinfo,
        hovertext=hovertext,
        marker=marker,
    )

def get_node_trace(input_graph, extra_data, **kwargs):
    """Return the marker trace for all nodes of ``input_graph``.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        extra_data: per-node hover data (DataFrame); may be empty.
        **kwargs: ``color`` (default ``'#888'``) and ``dim`` (default 3).
            ``dim`` is replaced by the length of the stored node positions.

    Returns:
        ``go.Scatter3d`` or ``go.Scatter`` (see ``get_node_scatter``).
    """
    options = {
        'color': '#888',
        'dim': 3,
    }
    options.update({k: v for k, v in kwargs.items() if k in options})

    # get_node_coords also updates options['dim'] to match the positions,
    # which then selects the 2D or 3D trace type below.
    node_xyz = get_node_coords(input_graph, options)

    return get_node_scatter(node_xyz, extra_data, options)

def create_tag(vals, columns):
    """Format one row of hover data as ``"<column>: <value><br>"`` entries.

    Args:
        vals: indexable sequence of values, aligned with ``columns``.
        columns: iterable of column names.

    Returns:
        The concatenated label; every entry (including the last) ends with
        ``<br>``. An empty ``columns`` gives ``''``.

    >>> create_tag((1, 'a'), ['id', 'kind'])
    'id: 1<br>kind: a<br>'
    """
    return ''.join(f'{col!s}: {vals[i]!s}<br>' for i, col in enumerate(columns))