# draw_networkx.py
# @Author:  Felix Kramer <kramer>
# @Date:   2021-05-08T20:34:30+02:00
# @Email:  kramer@mpi-cbg.de
# @Project: go-with-the-flow
# @Last modified by:    Felix Kramer
# @Last modified time: 2021-05-23T23:17:48+02:00
# @License: MIT

# standard types
import networkx as nx
import numpy as np
import pandas as pd
import plotly.graph_objects as go

#generate interactive plots with plotly and return the respective figures
def plot_networkx(input_graph,**kwargs):
    """Plot a single networkx graph as an interactive plotly figure.

    Keyword arguments whose name matches a plotting option override it:

    * ``network_id`` -- index into the colour lists (default 0),
    * ``color_nodes`` -- list of node colours (default ``['#a845b5']``),
    * ``color_edges`` -- list of edge colours (default ``['#c762d4']``).

    Every other keyword argument becomes a column of ``extra_data``, which is
    forwarded to the edge trace and rendered as hover text on the edges
    (presumably one value per edge, in ``input_graph.edges()`` order).

    Returns:
        A ``go.Figure`` holding one edge trace and one node trace, with the
        legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#a845b5'],
        'color_edges': ['#c762d4'],
    }
    extra_data = pd.DataFrame()

    for key, value in kwargs.items():
        if key in options:
            options[key] = value
        else:
            # unknown keywords are treated as extra per-edge data columns
            extra_data[key] = value

    fig = go.Figure()
    plot_graph_components(fig, options, input_graph, extra_data)
    fig.update_layout(showlegend=False)

    return fig

def plot_networkx_dual(dual_graph,**kwargs):
    """Plot every layer of a multilayer graph into one plotly figure.

    Each element of ``dual_graph.layer`` is expected to expose its networkx
    graph as ``.G``. Layer ``i`` is drawn with ``color_nodes[i]`` and
    ``color_edges[i]``, so the colour lists need at least as many entries as
    there are layers (the defaults cover two layers).

    Keyword arguments ``network_id``, ``color_nodes`` and ``color_edges``
    override the defaults; all other keyword arguments are ignored.

    Returns:
        A ``go.Figure`` with one edge and one node trace per layer and the
        legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#6aa84f', '#a845b5'],
        'color_edges': ['#2BDF94', '#c762d4'],
    }

    for key, value in kwargs.items():
        if key in options:
            options[key] = value

    fig = go.Figure()
    # plot_graph_components requires an extra_data argument; the dual plot
    # has no per-edge hover data, so pass an empty DataFrame (previously the
    # argument was missing and every call raised TypeError).
    no_extra_data = pd.DataFrame()
    for i, layer in enumerate(dual_graph.layer):
        options['network_id'] = i
        plot_graph_components(fig, options, layer.G, no_extra_data)
    fig.update_layout(showlegend=False)

    return fig

# auxiliary functions generating traces for nodes and edges
def get_edge_coords(input_graph,options):
    """Collect the endpoint coordinates of every edge, one list per axis.

    Node positions are read from the ``'pos'`` node attribute. If at least
    one node is positioned, ``options['dim']`` is overwritten in place with
    the length of the first position, so the caller can choose a 2D or 3D
    trace. For a graph without positions the given ``'dim'`` is kept and
    empty axes are returned instead of raising ``IndexError``.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        options: mutable dict with a ``'dim'`` entry (updated in place).

    Returns:
        ``options['dim']`` lists. For each edge, each list receives the two
        endpoint coordinates followed by ``None``, which keeps plotly from
        joining consecutive edges into one polyline.

    Raises:
        KeyError: if an edge endpoint has no ``'pos'`` attribute.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')

    if pos:
        options['dim'] = len(next(iter(pos.values())))
    edge_xyz = [[] for _ in range(options['dim'])]

    for u, v in input_graph.edges():
        xyz_0 = pos[u]
        xyz_1 = pos[v]
        for axis, coords in enumerate(edge_xyz):
            coords.extend((xyz_0[axis], xyz_1[axis], None))

    return edge_xyz

def get_edge_scatter(edge_xyz,extra_data,options):
    """Build the plotly line trace for a graph's edges.

    Args:
        edge_xyz: per-axis coordinate lists in the layout produced by
            ``get_edge_coords``: start, end and a ``None`` separator per edge.
        extra_data: DataFrame expected to hold one row per edge, in edge
            order. When it has columns, each edge gets a hover label listing
            that row's values; otherwise hovering is disabled.
        options: dict providing ``'dim'`` and ``'color'``.

    Returns:
        ``go.Scatter3d`` when ``options['dim'] == 3``, else ``go.Scatter``.
    """
    mode = 'none'
    hover = ''

    if len(extra_data.columns) != 0:
        mode = 'text'
        rows = zip(*(list(extra_data[c]) for c in extra_data.columns))
        tags = [create_tag(vals, extra_data.columns) for vals in rows]
        # plotly maps hovertext entries to trace points, not to edges. Each
        # edge spans three points (start, end, None separator), so repeat
        # every row's tag three times to keep labels aligned with their edge.
        hover = [tag for tag in tags for _ in range(3)]

    trace_kwargs = dict(
        x=edge_xyz[0], y=edge_xyz[1],
        line=dict(width=5, color=options['color']),
        hoverinfo=mode,
        hovertext=hover,
        mode='lines',
    )
    if options['dim'] == 3:
        return go.Scatter3d(z=edge_xyz[2], **trace_kwargs)
    return go.Scatter(**trace_kwargs)

def get_edge_trace(input_graph,extra_data, **kwargs):
    """Return the plotly trace that draws the edges of ``input_graph``.

    Recognised keyword arguments are ``color`` (default ``'#888'``) and
    ``dim`` (default 3; ``get_edge_coords`` may overwrite it with the length
    of the node positions). Any other keyword argument is silently ignored.
    ``extra_data`` is passed through to ``get_edge_scatter``.
    """
    settings = {'color': '#888', 'dim': 3}
    settings.update({key: val for key, val in kwargs.items() if key in settings})

    coords = get_edge_coords(input_graph, settings)
    return get_edge_scatter(coords, extra_data, settings)

def get_node_coords(input_graph,options):
    """Collect the coordinates of every node, one list per axis.

    Positions come from the ``'pos'`` node attribute. If at least one node is
    positioned, ``options['dim']`` is overwritten in place with the length of
    the first position, so the caller can choose a 2D or 3D trace. For a
    graph without positions the given ``'dim'`` is kept and empty axes are
    returned instead of raising ``IndexError``.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        options: mutable dict with a ``'dim'`` entry (updated in place).

    Returns:
        ``options['dim']`` lists; list ``i`` holds the ``i``-th coordinate of
        every node, in ``input_graph.nodes()`` order.

    Raises:
        KeyError: if a node has no ``'pos'`` attribute.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')

    if pos:
        options['dim'] = len(next(iter(pos.values())))

    nodes = list(input_graph.nodes())
    return [[pos[node][axis] for node in nodes] for axis in range(options['dim'])]

def get_node_scatter(node_xyz,extra_data,options):
    """Build the plotly marker trace for a graph's nodes.

    ``extra_data`` is accepted (presumably to mirror ``get_edge_scatter``'s
    signature) but is not used: node markers never show hover text.

    Returns:
        ``go.Scatter3d`` when ``options['dim'] == 3``, else ``go.Scatter``,
        drawing small markers coloured with ``options['color']``.
    """
    marker_style = dict(size=2, line_width=2, color=options['color'])
    shared = dict(
        x=node_xyz[0],
        y=node_xyz[1],
        mode='markers',
        hoverinfo='none',
        marker=marker_style,
    )

    if options['dim'] != 3:
        return go.Scatter(**shared)
    return go.Scatter3d(z=node_xyz[2], **shared)

def get_node_trace(input_graph,extra_data,**kwargs):
    """Return the plotly trace that draws the nodes of ``input_graph``.

    Recognised keyword arguments are ``color`` (default ``'#888'``) and
    ``dim`` (default 3; ``get_node_coords`` may overwrite it with the length
    of the node positions). Any other keyword argument is silently ignored.
    ``extra_data`` is passed through to ``get_node_scatter``.
    """
    settings = {'color': '#888', 'dim': 3}
    overrides = {key: val for key, val in kwargs.items() if key in settings}
    settings.update(overrides)

    coords = get_node_coords(input_graph, settings)
    return get_node_scatter(coords, extra_data, settings)

def create_tag(vals,columns):
    """Format one row of values as a hover label.

    Each column name is paired with the value at the same position in
    ``vals`` and rendered as ``"<column>: <value>"``. Every entry ends with
    ``<br>``, which plotly renders as a line break. Values beyond
    ``len(columns)`` are ignored.
    """
    return ''.join(
        f'{column!s}: {vals[position]!s}<br>'
        for position, column in enumerate(columns)
    )
# integrate traces into the figure
def plot_graph_components(fig,options,input_graph,extra_data=None):
    """Append the edge trace and node trace of ``input_graph`` to ``fig``.

    Args:
        fig: plotly figure that receives the traces (modified in place).
        options: dict with ``'network_id'`` (index into the colour lists),
            ``'color_edges'`` and ``'color_nodes'``.
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        extra_data: optional DataFrame forwarded to the edge trace for hover
            text. ``None`` (the default) is treated as an empty DataFrame,
            so callers without extra data may omit it.
    """
    if extra_data is None:
        extra_data = pd.DataFrame()

    idx = options['network_id']
    edge_trace = get_edge_trace(input_graph, extra_data, color=options['color_edges'][idx])
    node_trace = get_node_trace(input_graph, extra_data, color=options['color_nodes'][idx])
    # edges are added before nodes, so later-added node markers sit on top
    fig.add_trace(edge_trace)
    fig.add_trace(node_trace)