Skip to content
Snippets Groups Projects
draw_networkx.py 4.16 KiB
Newer Older
# @Author:  Felix Kramer <kramer>
# @Date:   2021-05-08T20:34:30+02:00
# @Email:  kramer@mpi-cbg.de
# @Project: go-with-the-flow
# @Last modified by:    Felix Kramer
# @Last modified time: 2021-05-23T23:17:48+02:00
# @License: MIT

# third-party dependencies
import networkx as nx
import numpy as np
import plotly.graph_objects as go

# Generate interactive plots with plotly and return the respective figures.
def plot_networkx(input_graph, **kwargs):
    """Build a plotly figure showing the nodes and edges of one graph.

    Parameters
    ----------
    input_graph : graph whose nodes carry a ``'pos'`` attribute
        (read by ``get_node_coords`` / ``get_edge_coords``).
    **kwargs : optional overrides for ``network_id``, ``color_nodes`` and
        ``color_edges``. Any other keyword is silently ignored.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure with one edge trace and one node trace, legend hidden.
    """
    defaults = {
        'network_id': 0,
        'color_nodes': ['#a845b5'],
        'color_edges': ['#c762d4'],
    }
    # Only recognised keys are taken from kwargs; unknown keys are dropped.
    settings = {key: kwargs.get(key, fallback) for key, fallback in defaults.items()}

    figure = go.Figure()
    plot_nodes_edges(figure, settings, input_graph)
    figure.update_layout(showlegend=False)
    return figure

def plot_networkx_dual(dual_graph, **kwargs):
    """Build a plotly figure overlaying every layer of a multi-layer graph.

    Each entry of ``dual_graph.layer`` is drawn with its position in that
    sequence used as ``network_id``, i.e. as the index into the colour palettes.

    Parameters
    ----------
    dual_graph : object exposing ``.layer``, an iterable of objects that each
        have a ``.G`` graph attribute (assumed networkx graphs with ``'pos'``
        node attributes -- confirm against the caller).
    **kwargs : optional overrides for ``color_nodes`` and ``color_edges``
        (``network_id`` is accepted but overwritten per layer). Any other
        keyword is silently ignored.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure with an edge and a node trace per layer, legend hidden.
    """
    defaults = {
        'network_id': 0,
        'color_nodes': ['#6aa84f', '#a845b5'],
        'color_edges': ['#2BDF94', '#c762d4'],
    }
    settings = {key: kwargs.get(key, fallback) for key, fallback in defaults.items()}

    figure = go.Figure()
    for layer_index, layer in enumerate(dual_graph.layer):
        # The layer index selects this layer's colours from the palettes.
        settings['network_id'] = layer_index
        plot_nodes_edges(figure, settings, layer.G)
    figure.update_layout(showlegend=False)
    return figure

# Auxiliary functions generating traces for nodes and edges.
def get_edge_coords(input_graph, options):
    """Collect per-axis edge coordinates suitable for a plotly line trace.

    Node positions are read from the ``'pos'`` node attribute. When at least
    one position exists, ``options['dim']`` is overwritten in place with the
    length of the first position found. That length decides how many axes
    are returned and which trace type is built later.

    Parameters
    ----------
    input_graph : networkx graph whose edge endpoints carry ``'pos'``.
    options : dict with an integer ``'dim'`` entry (mutated, see above).

    Returns
    -------
    list of ``options['dim']`` lists
        For every edge the start coordinate, the end coordinate and ``None``
        are appended to each axis list. The ``None`` entries break the line
        between consecutive edges. A graph without positioned nodes yields
        empty lists instead of raising.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        # Infer dimensionality from the first position without copying all values.
        options['dim'] = len(next(iter(pos.values())))

    edge_xyz = [[] for _ in range(options['dim'])]
    for u, v in input_graph.edges():
        start, end = pos[u], pos[v]
        for axis, coords in enumerate(edge_xyz):
            coords.extend((start[axis], end[axis], None))

    return edge_xyz

def get_edge_scatter(edge_xyz, options):
    """Wrap per-axis edge coordinates in a plotly line trace.

    Parameters
    ----------
    edge_xyz : sequence of coordinate lists, one per axis.
    options : dict providing ``'dim'`` and ``'color'``.

    Returns
    -------
    ``go.Scatter3d`` when ``options['dim'] == 3`` (using axes 0-2), otherwise
    ``go.Scatter`` using axes 0 and 1. Lines are width 5 with hover disabled.
    """
    three_d = options['dim'] == 3
    trace_args = dict(x=edge_xyz[0], y=edge_xyz[1])
    if three_d:
        trace_args['z'] = edge_xyz[2]
        trace_type = go.Scatter3d
    else:
        trace_type = go.Scatter

    return trace_type(
        line=dict(width=5, color=options['color']),
        hoverinfo='none',
        mode='lines',
        **trace_args,
    )

def get_edge_trace(input_graph, **kwargs):
    """Return a plotly line trace for the edges of *input_graph*.

    Recognised keywords are ``color`` (default ``'#888'``) and ``dim``
    (default 3). ``get_edge_coords`` may replace ``dim`` with the
    dimensionality of the node positions. Other keywords are ignored.
    """
    defaults = {'color': '#888', 'dim': 3}
    settings = {key: kwargs.get(key, fallback) for key, fallback in defaults.items()}

    return get_edge_scatter(get_edge_coords(input_graph, settings), settings)

def get_node_coords(input_graph, options):
    """Collect per-axis node coordinates suitable for a plotly marker trace.

    Node positions are read from the ``'pos'`` node attribute. When at least
    one position exists, ``options['dim']`` is overwritten in place with the
    length of the first position found.

    Parameters
    ----------
    input_graph : networkx graph whose nodes carry ``'pos'``.
    options : dict with an integer ``'dim'`` entry (mutated, see above).

    Returns
    -------
    list of ``options['dim']`` lists
        Axis ``i`` holds the ``i``-th coordinate of every node, in
        ``input_graph.nodes()`` order. A graph without positioned nodes yields
        empty lists instead of raising.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        # Infer dimensionality from the first position without copying all values.
        options['dim'] = len(next(iter(pos.values())))

    node_xyz = [[] for _ in range(options['dim'])]
    for node in input_graph.nodes():
        point = pos[node]
        for axis, coords in enumerate(node_xyz):
            coords.append(point[axis])

    return node_xyz

def get_node_scatter(node_xyz, options):
    """Wrap per-axis node coordinates in a plotly marker trace.

    Parameters
    ----------
    node_xyz : sequence of coordinate lists, one per axis.
    options : dict providing ``'dim'`` and ``'color'``.

    Returns
    -------
    ``go.Scatter3d`` when ``options['dim'] == 3`` (using axes 0-2), otherwise
    ``go.Scatter`` using axes 0 and 1. Markers are size 2 with outline width 2
    and hover disabled.
    """
    three_d = options['dim'] == 3
    trace_args = dict(x=node_xyz[0], y=node_xyz[1])
    if three_d:
        trace_args['z'] = node_xyz[2]
        trace_type = go.Scatter3d
    else:
        trace_type = go.Scatter

    marker_style = dict(size=2, line_width=2, color=options['color'])
    return trace_type(
        mode='markers',
        hoverinfo='none',
        marker=marker_style,
        **trace_args,
    )

def get_node_trace(input_graph, **kwargs):
    """Return a plotly marker trace for the nodes of *input_graph*.

    Recognised keywords are ``color`` (default ``'#888'``) and ``dim``
    (default 3). ``get_node_coords`` may replace ``dim`` with the
    dimensionality of the node positions. Other keywords are ignored.
    """
    defaults = {'color': '#888', 'dim': 3}
    settings = {key: kwargs.get(key, fallback) for key, fallback in defaults.items()}

    coordinates = get_node_coords(input_graph, settings)
    return get_node_scatter(coordinates, settings)

# integrate traces into the figure
def plot_nodes_edges(fig, options, input_graph):
    """Add an edge trace and then a node trace for *input_graph* to *fig*.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure, mutated via ``add_trace``.
    options : dict with ``'network_id'`` (int) and the colour palettes
        ``'color_edges'`` and ``'color_nodes'``.
    input_graph : graph passed on to ``get_edge_trace`` / ``get_node_trace``.

    ``network_id`` selects a colour from each palette and wraps around the
    palette length. Any number of networks can therefore share a short
    palette instead of raising ``IndexError``. A bare colour string is
    accepted as a one-colour palette.

    Raises
    ------
    IndexError
        If a palette is an empty sequence.
    """
    idx = options['network_id']

    def pick_color(palette):
        # A plain string is one colour, not a sequence of per-character colours.
        if isinstance(palette, str):
            return palette
        # len() rather than truthiness so numpy arrays of colours also work.
        if len(palette) == 0:
            raise IndexError('colour palette is empty')
        return palette[idx % len(palette)]

    # Build both traces before touching the figure, so a failure leaves it unchanged.
    edge_trace = get_edge_trace(input_graph, color=pick_color(options['color_edges']))
    node_trace = get_node_trace(input_graph, color=pick_color(options['color_nodes']))
    fig.add_trace(edge_trace)
    fig.add_trace(node_trace)