Newer
Older
# @Author: Felix Kramer <kramer>
# @Date: 2021-05-08T20:34:30+02:00
# @Email: kramer@mpi-cbg.de
# @Project: go-with-the-flow
# @Last modified by: Felix Kramer
# @Last modified time: 2021-05-23T23:17:48+02:00
# @License: MIT
# standard types
import networkx as nx
import numpy as np
import plotly.graph_objects as go
#generate interactive plots with plotly and return the respective figures
def plot_networkx(input_graph, **kwargs):
    """Return an interactive plotly figure of a single networkx graph.

    Args:
        input_graph: networkx graph to draw.
        **kwargs: ``network_id``, ``color_nodes`` and ``color_edges``
            override the defaults in ``options`` (which is forwarded to
            ``plot_graph_components``); ``extra_data`` is forwarded as the
            per-trace hover table (default: empty, i.e. no hover labels).
            Any other keyword is ignored.

    Returns:
        A ``go.Figure`` with the legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#a845b5'],
        'color_edges': ['#c762d4'],
    }
    for key, value in kwargs.items():
        if key in options:
            options[key] = value
    # an empty mapping has no keys, so the trace builders skip hover labels
    extra_data = kwargs.get('extra_data', {})

    fig = go.Figure()
    plot_graph_components(fig, options, input_graph, extra_data)
    fig.update_layout(showlegend=False)
    return fig
def plot_networkx_dual(dual_graph, **kwargs):
    """Return an interactive plotly figure drawing every layer of a dual graph.

    Args:
        dual_graph: object exposing an iterable ``layer`` attribute; each
            layer is drawn with its own entry of the colour lists.
        **kwargs: ``network_id``, ``color_nodes`` and ``color_edges``
            override the defaults in ``options``; ``extra_data`` is
            forwarded as the hover table (default: empty). Any other keyword
            is ignored.

    Returns:
        A ``go.Figure`` with the legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#6aa84f', '#a845b5'],
        'color_edges': ['#2BDF94', '#c762d4'],
    }
    for key, value in kwargs.items():
        if key in options:
            options[key] = value
    extra_data = kwargs.get('extra_data', {})

    fig = go.Figure()
    for layer_index, layer in enumerate(dual_graph.layer):
        # presumably used by plot_graph_components to pick this layer's
        # entry from color_nodes / color_edges
        options['network_id'] = layer_index
        # NOTE(review): layers may be wrapper objects holding their networkx
        # graph in `.G`; fall back to the layer itself otherwise — confirm
        # against the dual-graph class.
        layer_graph = getattr(layer, 'G', layer)
        plot_graph_components(fig, options, layer_graph, extra_data)
    fig.update_layout(showlegend=False)
    return fig
# auxiliary functions generating traces for nodes and edges
def get_edge_coords(input_graph, options):
    """Collect the endpoint coordinates of every edge, one list per axis.

    Args:
        input_graph: networkx graph whose nodes carry a 'pos' attribute.
        options: dict with a 'dim' entry. When any node has a position,
            'dim' is overwritten in place with the length of that position,
            so downstream scatter builders choose 2D or 3D accordingly.

    Returns:
        List of ``options['dim']`` lists. For each edge (in
        ``input_graph.edges()`` order) every axis list receives the start
        coordinate, the end coordinate and a ``None`` separator, so plotly
        does not join consecutive edges into one polyline. A graph without
        positions or edges yields empty lists instead of raising IndexError.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        # trust the stored positions over the requested dimension
        options['dim'] = len(next(iter(pos.values())))

    edges = list(input_graph.edges())
    return [
        [value for u, v in edges for value in (pos[u][axis], pos[v][axis], None)]
        for axis in range(options['dim'])
    ]
def get_edge_scatter(edge_xyz, extra_data, options):
    """Build the plotly line trace that draws the edges.

    Args:
        edge_xyz: per-axis coordinate lists as produced by ``get_edge_coords``
            (start, end, None for every edge).
        extra_data: table of per-edge values shown as hover labels, read via
            ``.keys()``, ``.columns`` and ``extra_data[column]`` — presumably a
            pandas DataFrame with one row per edge in graph edge order (TODO
            confirm). ``None`` or an empty table means no hover labels.
        options: dict with 'color' (line colour) and 'dim' (3 selects 3D).

    Returns:
        ``go.Scatter3d`` when ``options['dim'] == 3``, else ``go.Scatter``.
    """
    hover_kwargs = {}
    if extra_data is not None and len(extra_data.keys()) != 0:
        columns = extra_data.columns
        data = [list(extra_data[c]) for c in columns]
        tags = [create_tag(vals, columns) for vals in zip(*data)]
        # each edge occupies three points (start, end, None gap), so repeat
        # every label three times to keep it aligned with its segment
        hover_kwargs = {
            'hoverinfo': 'text',
            'text': [tag for tag in tags for _ in range(3)],
        }

    line = dict(width=5, color=options['color'])
    if options['dim'] == 3:
        return go.Scatter3d(
            x=edge_xyz[0], y=edge_xyz[1], z=edge_xyz[2],
            line=line,
            mode='lines',
            **hover_kwargs)
    return go.Scatter(
        x=edge_xyz[0], y=edge_xyz[1],
        line=line,
        mode='lines',
        **hover_kwargs)
def get_edge_trace(input_graph, extra_data, **kwargs):
    """Return a plotly trace drawing the edges of ``input_graph``.

    Args:
        input_graph: networkx graph whose nodes carry a 'pos' attribute.
        extra_data: hover table forwarded to ``get_edge_scatter``.
        **kwargs: ``color`` (default '#888') and ``dim`` (default 3) override
            the trace options; other keywords are ignored. ``dim`` is replaced
            by ``get_edge_coords`` with the dimension of the stored positions.

    Returns:
        The trace built by ``get_edge_scatter``.
    """
    options = {
        'color': '#888',
        'dim': 3,
    }
    for key, value in kwargs.items():
        if key in options:
            options[key] = value

    edge_xyz = get_edge_coords(input_graph, options)
    return get_edge_scatter(edge_xyz, extra_data, options)
def get_node_coords(input_graph, options):
    """Collect the position of every node, one coordinate list per axis.

    Args:
        input_graph: networkx graph whose nodes carry a 'pos' attribute.
        options: dict with a 'dim' entry. When any node has a position,
            'dim' is overwritten in place with the length of that position,
            so downstream scatter builders choose 2D or 3D accordingly.

    Returns:
        List of ``options['dim']`` lists; list ``i`` holds the i-th
        coordinate of each node in ``input_graph.nodes()`` order. An empty
        graph yields empty lists instead of raising IndexError.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        # trust the stored positions over the requested dimension
        options['dim'] = len(next(iter(pos.values())))

    nodes = list(input_graph.nodes())
    return [[pos[node][axis] for node in nodes] for axis in range(options['dim'])]
def get_node_scatter(node_xyz, extra_data, options):
    """Build the plotly marker trace that draws the nodes.

    Args:
        node_xyz: per-axis coordinate lists as produced by ``get_node_coords``.
        extra_data: accepted for symmetry with ``get_edge_scatter``; not used
            here (node hover is disabled).
        options: dict with 'color' (marker colour) and 'dim' (3 selects 3D).

    Returns:
        ``go.Scatter3d`` when ``options['dim'] == 3``, else ``go.Scatter``.
    """
    marker = dict(
        size=2,
        line_width=2,
        color=options['color'])
    if options['dim'] == 3:
        return go.Scatter3d(
            x=node_xyz[0], y=node_xyz[1], z=node_xyz[2],
            mode='markers',
            hoverinfo='none',
            marker=marker)
    return go.Scatter(
        x=node_xyz[0], y=node_xyz[1],
        mode='markers',
        hoverinfo='none',
        marker=marker)
def get_node_trace(input_graph, extra_data, **kwargs):
    """Return a plotly trace drawing the nodes of ``input_graph``.

    Args:
        input_graph: networkx graph whose nodes carry a 'pos' attribute.
        extra_data: forwarded unchanged to ``get_node_scatter``.
        **kwargs: ``color`` (default '#888') and ``dim`` (default 3) override
            the trace options; other keywords are ignored. ``dim`` is replaced
            by ``get_node_coords`` with the dimension of the stored positions.

    Returns:
        The trace built by ``get_node_scatter``.
    """
    options = {
        'color': '#888',
        'dim': 3,
    }
    for key, value in kwargs.items():
        if key in options:
            options[key] = value

    node_xyz = get_node_coords(input_graph, options)
    return get_node_scatter(node_xyz, extra_data, options)
def create_tag(vals, columns):
    """Format one row of values as an HTML hover label.

    Args:
        vals: indexable row of values, positionally matching ``columns``.
        columns: iterable of column names.

    Returns:
        ``"<column>: <value><br>"`` for every column, concatenated; an empty
        string when ``columns`` is empty. Raises IndexError if ``vals`` is
        shorter than ``columns``.
    """
    # str.join avoids the quadratic cost of repeated string concatenation
    return ''.join(
        str(column) + ': ' + str(vals[i]) + '<br>'
        for i, column in enumerate(columns)
    )
def plot_graph_components(fig, options, input_graph, extra_data):
    """Add the edge and node traces of ``input_graph`` to ``fig`` in place.

    Args:
        fig: plotly figure receiving the traces.
        options: dict with 'network_id' (index into the colour lists),
            'color_edges' and 'color_nodes'.
        input_graph: networkx graph to draw.
        extra_data: hover table forwarded to both trace builders.
    """
    # network_id selects this graph's colours (the dual plotter sets it per layer)
    idx = options['network_id']
    edge_trace = get_edge_trace(input_graph, extra_data, color=options['color_edges'][idx])
    node_trace = get_node_trace(input_graph, extra_data, color=options['color_nodes'][idx])
    # edges are added first, presumably so node markers draw on top of them
    fig.add_trace(edge_trace)
    fig.add_trace(node_trace)