# @Author:  Felix Kramer <kramer>
# @Date:   2021-05-08T20:34:30+02:00
# @Email:  kramer@mpi-cbg.de
# @Project: go-with-the-flow
# @Last modified by:    Felix Kramer
# @Last modified time: 2021-09-05T00:05:38+02:00
# @License: MIT

# standard types
import networkx as nx
import numpy as np
felix's avatar
felix committed
12
import pandas as pd
felix's avatar
felix committed
13
14
15
16
17
18
19
20
21
22
import plotly.graph_objects as go

#generate interactive plots with plotly and return the respective figures
def plot_networkx(input_graph,**kwargs):
    """Build an interactive plotly figure for a single networkx graph.

    Parameters
    ----------
    input_graph :
        Graph handed to ``add_traces_nodes`` / ``add_traces_edges``.
    **kwargs :
        ``network_id``, ``color_nodes`` and ``color_edges`` override the
        defaults below. ``node_data`` / ``edge_data`` are mappings of
        column name -> values; each is copied into a fresh DataFrame that
        is passed on as the extra (hover) data for nodes / edges.
        Any other keyword is ignored.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure holding the node and edge traces, with the legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#a845b5'],
        'color_edges': ['#c762d4'],
    }
    # Only keys that already exist in the defaults are taken over.
    options.update((k, v) for k, v in kwargs.items() if k in options)

    frames = {}
    for name in ('node_data', 'edge_data'):
        frame = pd.DataFrame()
        if name in kwargs:
            for column, values in kwargs[name].items():
                # pandas objects expose to_numpy(); plain sequences are
                # converted with numpy so they are accepted as well.
                if hasattr(values, 'to_numpy'):
                    frame[column] = values.to_numpy()
                else:
                    frame[column] = np.asarray(values)
        frames[name] = frame

    fig = go.Figure()
    add_traces_nodes(fig, options, input_graph, frames['node_data'])
    add_traces_edges(fig, options, input_graph, frames['edge_data'])

    fig.update_layout(showlegend=False)

    return fig

def plot_networkx_dual(dual_graph,**kwargs):
    """Build one interactive plotly figure from all layers of a dual graph.

    Parameters
    ----------
    dual_graph :
        Object exposing an iterable ``layer``; every element must provide
        ``G``, ``get_nodes_data()`` and ``get_edges_data()``.
    **kwargs :
        ``color_nodes`` / ``color_edges`` override the default per-layer
        colour lists. ``network_id`` is accepted but overwritten with the
        layer index inside the loop. Any other keyword is ignored.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure holding the traces of every layer, with the legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#6aa84f', '#a845b5'],
        'color_edges': ['#2BDF94', '#c762d4'],
    }
    # Only keys that already exist in the defaults are taken over.
    options.update((k, v) for k, v in kwargs.items() if k in options)

    fig = go.Figure()

    for layer_idx, layer in enumerate(dual_graph.layer):

        # network_id is the index into color_nodes / color_edges used by the
        # add_traces_* helpers, so each layer needs a colour entry.
        options['network_id'] = layer_idx
        node_data = layer.get_nodes_data()
        edges_data = layer.get_edges_data()

        add_traces_nodes(fig, options, layer.G, node_data)
        add_traces_edges(fig, options, layer.G, edges_data)

    fig.update_layout(showlegend=False)

    return fig

Felix's avatar
Felix committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# integrate traces into the figure
def add_traces_edges(fig, options, input_graph,extra_data):
    """Add all edge traces of ``input_graph`` to ``fig``.

    The colour is ``options['color_edges'][options['network_id']]``. The
    individual line traces are added first, the midpoint hover trace last.
    """
    edge_color = options['color_edges'][options['network_id']]

    hover_trace = get_edge_mid_trace(input_graph, extra_data, color=edge_color)
    line_traces = get_edge_invd_traces(input_graph, extra_data, color=edge_color)

    for line_trace in line_traces:
        fig.add_trace(line_trace)
    fig.add_trace(hover_trace)

def add_traces_nodes(fig, options, input_graph,extra_data):
    """Add the node trace of ``input_graph`` to ``fig``.

    The colour is ``options['color_nodes'][options['network_id']]``.
    """
    node_color = options['color_nodes'][options['network_id']]
    fig.add_trace(get_node_trace(input_graph, extra_data, color=node_color))

felix's avatar
felix committed
91
# auxiliary functions generating traces for nodes and edges
Felix's avatar
Felix committed
92
def get_edge_mid_trace(input_graph,extra_data, **kwargs):
    """Create the hover trace located at the midpoint of every edge.

    Parameters
    ----------
    input_graph :
        networkx graph; node coordinates are read from the 'pos' attribute.
    extra_data :
        Table whose columns are turned into hover text by ``set_hover_info``.
    **kwargs :
        ``color`` (marker colour), ``dim`` (fallback dimension, replaced by
        the length of the first 'pos' entry when one exists) and
        ``edge_list`` (edges to use instead of ``input_graph.edges()``).
        Any other keyword is ignored.

    Returns
    -------
    The scatter trace produced by ``get_hover_scatter_from_template`` with
    coordinates and hover information filled in.
    """
    options = {
        'color': '#888',
        'dim': 3,
        # Listed here so the kwargs filter below lets it through; without a
        # default entry the option could never be set by a caller.
        'edge_list': None,
    }
    options.update((k, v) for k, v in kwargs.items() if k in options)

    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        # The dimension follows the data; keep the configured value only
        # when there is no position to inspect.
        options['dim'] = len(next(iter(pos.values())))

    edges = options['edge_list']
    if edges is None:
        edges = input_graph.edges()

    middle_node_trace = get_hover_scatter_from_template(options)

    midpoints = [[] for _ in range(options['dim'])]
    for edge in edges:
        start = pos[edge[0]]
        end = pos[edge[1]]
        for axis, coords in enumerate(midpoints):
            coords.append((start[axis] + end[axis]) / 2.)

    set_hover_info(middle_node_trace, midpoints, extra_data)

    return middle_node_trace

def set_hover_info(trace,XYZ,extra_data):
    """Write coordinates and hover text into ``trace`` in place.

    Parameters
    ----------
    trace :
        Object supporting item assignment (e.g. a plotly scatter trace).
    XYZ :
        Sequence of per-axis coordinate lists; fewer than three entries
        means only 'x' and 'y' are set, otherwise 'x', 'y' and 'z'.
    extra_data :
        Table with ``columns``; each row becomes one hover tag via
        ``create_tag``. With no columns, hover is switched off instead.
    """
    tags = ['x', 'y'] if len(XYZ) < 3 else ['x', 'y', 'z']
    for i, tag in enumerate(tags):
        trace[tag] = XYZ[i]

    if len(extra_data.columns) != 0:
        columns = [list(extra_data[c]) for c in extra_data.columns]
        # zip(*columns) walks the table row by row.
        trace['text'] = [
            create_tag(row, extra_data.columns) for row in zip(*columns)
        ]
    else:
        trace['hoverinfo'] = 'none'

def get_hover_scatter_from_template(options):
    """Return an empty, fully transparent marker trace used only for hover.

    ``options['dim'] == 3`` yields a ``go.Scatter3d`` (opacity set on the
    trace), any other value a ``go.Scatter`` (opacity set on the marker).
    ``options['color']`` is used as the marker colour.
    """
    if options['dim'] == 3:
        middle_node_trace = go.Scatter3d(
            x=[],
            y=[],
            z=[],
            text=[],
            mode='markers',
            hoverinfo='text',
            opacity=0,
            marker=dict(color=options['color'])
        )
    else:
        middle_node_trace = go.Scatter(
            x=[],
            y=[],
            text=[],
            mode='markers',
            hoverinfo='text',
            marker=go.scatter.Marker(
                opacity=0,
                color=options['color']
            )
        )

    return middle_node_trace
felix's avatar
felix committed
167

Felix's avatar
Felix committed
168
def get_edge_invd_traces(input_graph,extra_data, **kwargs):
    """Create one line trace per edge.

    Parameters
    ----------
    input_graph :
        networkx graph; node coordinates are read from the 'pos' attribute.
    extra_data :
        Container that may hold a 'weight' entry; its i-th value (in
        iteration order) is used as the line width of the i-th edge.
        Without it every edge gets width 5.
    **kwargs :
        ``color`` (line colour), ``dim`` (fallback dimension, replaced by
        the length of the first 'pos' entry when one exists) and
        ``edge_list`` (edges to use instead of ``input_graph.edges()``).
        Any other keyword is ignored.

    Returns
    -------
    list
        The traces produced by ``get_line_from_template``, one per edge.
    """
    options = {
        'color': '#888',
        'dim': 3,
        # Listed here so the kwargs filter below lets it through; without a
        # default entry the option could never be set by a caller.
        'edge_list': None,
    }
    options.update((k, v) for k, v in kwargs.items() if k in options)

    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        options['dim'] = len(next(iter(pos.values())))

    edges = options['edge_list']
    if edges is None:
        edges = input_graph.edges()

    # Materialise the weights once so the lookup below is positional (edge
    # order) and does not depend on the index labels of extra_data.
    weights = list(extra_data['weight']) if 'weight' in extra_data else None

    trace_list = []
    for i, edge in enumerate(edges):
        width = 5. if weights is None else weights[i]

        # Per-edge copy: the shared options dict stays untouched.
        trace = get_line_from_template(dict(options, weight=width))
        set_edge_info(trace, pos[edge[0]], pos[edge[1]])
        trace_list.append(trace)

    return trace_list

def set_edge_info(trace,XYZ_0,XYZ_1):
    """Write the two end points of an edge into ``trace`` in place.

    Each axis entry becomes ``[start, end, None]``. Only 'x' and 'y' are
    set when ``XYZ_0`` has fewer than three components, otherwise 'z' too.
    """
    axes = ('x', 'y') if len(XYZ_0) < 3 else ('x', 'y', 'z')
    for axis_idx, axis in enumerate(axes):
        trace[axis] = [XYZ_0[axis_idx], XYZ_1[axis_idx], None]

def get_line_from_template(options):
    """Return an empty line trace styled from ``options``.

    Reads ``options['color']`` and ``options['weight']`` for the line
    style. ``options['dim'] == 3`` yields a ``go.Scatter3d``, any other
    value a ``go.Scatter``. Hover is disabled on the trace.
    """
    line_style = dict(color=options['color'], width=options['weight'])

    if options['dim'] == 3:
        trace = go.Scatter3d(
            x=[],
            y=[],
            z=[],
            mode='lines',
            line=line_style,
            hoverinfo='none'
        )
    else:
        trace = go.Scatter(
            x=[],
            y=[],
            mode='lines',
            line=line_style,
            hoverinfo='none'
        )

    return trace
felix's avatar
felix committed
237
238
239
240
241
242
243
244

def get_node_coords(input_graph,options):
    """Collect node coordinates as one list per axis.

    Parameters
    ----------
    input_graph :
        networkx graph; node coordinates are read from the 'pos' attribute.
    options : dict
        ``options['dim']`` is overwritten in place with the length of the
        first 'pos' entry when one exists (callers read it afterwards).
        An optional ``options['node_list']`` restricts and orders the nodes;
        otherwise ``input_graph.nodes()`` is used.

    Returns
    -------
    list of lists
        ``options['dim']`` lists, the i-th holding the i-th coordinate of
        every selected node.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')
    if pos:
        options['dim'] = len(next(iter(pos.values())))

    dim = options['dim']
    node_xyz = [[] for _ in range(dim)]

    # Previously this branch read options['edge_list'], which raised
    # KeyError (or picked edges) whenever a node_list was supplied.
    if 'node_list' in options:
        nodes = options['node_list']
    else:
        nodes = input_graph.nodes()

    for node in nodes:
        xyz = pos[node]
        for axis in range(dim):
            node_xyz[axis].append(xyz[axis])

    return node_xyz

felix's avatar
felix committed
260
def get_node_scatter(node_xyz,extra_data,options):
    """Build the marker trace for the nodes.

    Parameters
    ----------
    node_xyz :
        Per-axis coordinate lists (two or three entries are used).
    extra_data :
        Table with ``columns``; each row becomes one hover tag via
        ``create_tag``. With no columns, hover is disabled.
    options : dict
        ``options['dim'] == 3`` selects ``go.Scatter3d``, anything else
        ``go.Scatter``; ``options['color']`` is the marker colour.

    Returns
    -------
    The configured plotly scatter trace.
    """
    mode = 'none'
    hover = ''

    if len(extra_data.columns) != 0:
        mode = 'text'
        columns = [list(extra_data[c]) for c in extra_data.columns]
        # zip(*columns) walks the table row by row.
        hover = [create_tag(row, extra_data.columns) for row in zip(*columns)]

    # Settings shared by the 2D and 3D variants.
    common = dict(
        mode='markers',
        hoverinfo=mode,
        hovertext=hover,
        marker=dict(
            size=2,
            line_width=2,
            color=options['color'],
        ),
    )

    if options['dim'] == 3:
        node_trace = go.Scatter3d(
            x=node_xyz[0], y=node_xyz[1], z=node_xyz[2],
            **common
        )
    else:
        node_trace = go.Scatter(
            x=node_xyz[0], y=node_xyz[1],
            **common
        )

    return node_trace

felix's avatar
felix committed
295
def get_node_trace(input_graph,extra_data,**kwargs):
    """Create the marker trace for all nodes of ``input_graph``.

    Parameters
    ----------
    input_graph :
        Graph passed to ``get_node_coords``.
    extra_data :
        Hover data passed to ``get_node_scatter``.
    **kwargs :
        ``color`` and ``dim`` override the defaults below; any other
        keyword is ignored.

    Returns
    -------
    The trace returned by ``get_node_scatter``.
    """
    options = {
        'color': '#888',
        'dim': 3,
    }
    # Only keys that already exist in the defaults are taken over.
    options.update((k, v) for k, v in kwargs.items() if k in options)

    # get_node_coords may adjust options['dim'] in place, so it has to run
    # before the scatter is built from the same options dict.
    node_xyz = get_node_coords(input_graph, options)

    return get_node_scatter(node_xyz, extra_data, options)

felix's avatar
felix committed
311
312
313
314
315
316
317
def create_tag(vals,columns):
    """Format one row as hover text: ``'<column>: <value><br>'`` per column.

    ``vals[i]`` is paired with the i-th entry of ``columns``; the pieces
    are concatenated in column order.
    """
    return ''.join(
        f'{column!s}: {vals[i]!s}<br>' for i, column in enumerate(columns)
    )