-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathN-planes-flow.py
More file actions
executable file
·340 lines (300 loc) · 11.3 KB
/
N-planes-flow.py
File metadata and controls
executable file
·340 lines (300 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
#!/usr/bin/env python
# takes a bunch of slices and aligns them
# a derivative of https://github.com/google-research/sofima/blob/main/notebooks/em_alignment.ipynb
# this first part just does the GPU intensive stuff
import sys
import os
import argparse
from concurrent import futures
import time
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from connectomics.common import bounding_box
from sofima import flow_field
from sofima import flow_utils
from sofima import map_utils
from sofima import mesh
from datetime import datetime
from skimage.transform import downscale_local_mean
import importlib
debug = False  # save intermediate steps (extra .npy dumps at the end of the script)

# Parse command line arguments.
parser = argparse.ArgumentParser(
    description="computes the flow fields between a pair of slices - GPU intensive"
)
parser.add_argument(
    "data_loader",
    help="Data loader module name, e.g., data-test-2-planes"
)
parser.add_argument(
    "basepath",
    help="filepath to stitched planes"
)
parser.add_argument(
    "min_z",
    type=int,
    help="lower bound on the planes to align"
)
parser.add_argument(
    "max_z",
    type=int,
    help="upper bound on the planes to align"
)
parser.add_argument(
    "scale",
    type=int,
    help="",
)
parser.add_argument(
    "high_pass",
    type=int,
    help="",
)
parser.add_argument(
    "uniform",
    type=int,
    help="",
)
parser.add_argument(
    "threshold",
    type=int,
    help="",
)
parser.add_argument(
    "close",
    type=int,
    help="",
)
parser.add_argument(
    "patch_size",
    type=str,
    help="Side length of (square) patch for processing (in pixels, e.g., 32)",
)
parser.add_argument(
    "stride",
    type=str,
    help="Distance of adjacent patches (in pixels, e.g., 8)"
)
parser.add_argument(
    "scales",
    help="the spatial resolutions to use when computing the flow field"
)
parser.add_argument(
    "k0",
    type=float,
    help="spring constant for inter-section springs"
)
parser.add_argument(
    "k",
    type=float,
    help="spring constant for intra-section springs"
)
parser.add_argument(
    "batch_size",
    type=int,
    help="how many patches to process simultaneously",
)
parser.add_argument(
    "write_metadata",
    type=int,
    help="whether to write the zarr metadata for not"
)
args = parser.parse_args()

# Unpack parsed arguments into the module-level names used by the rest of the script.
data_loader = args.data_loader
basepath = args.basepath
min_z = args.min_z
max_z = args.max_z
scale = args.scale
high_pass = args.high_pass
uniform = args.uniform
threshold = args.threshold
close = args.close
patch_size = args.patch_size
stride = args.stride
scales = args.scales
k0 = args.k0
k = args.k
batch_size = args.batch_size
write_metadata = args.write_metadata

# Echo every parameter so the run's configuration is captured in the log
# (output is identical to the former one-print-per-parameter lines).
for _name in ("data_loader", "basepath", "min_z", "max_z", "scale", "high_pass",
              "uniform", "threshold", "close", "patch_size", "stride", "scales",
              "k0", "k", "batch_size", "write_metadata"):
    print(_name, "=", getattr(args, _name))

# patch_size/stride/scales are comma-separated lists; patch_size and stride
# must pair up, one (patch_size, stride) entry per refinement pass.
patch_size_int = [int(x) for x in args.patch_size.split(',')]
stride_int = [int(x) for x in args.stride.split(',')]
scales_int = [int(x) for x in scales.split(',')]
if len(patch_size_int) != len(stride_int):
    print("lengths of patch_size and stride must be equal")
    # BUG FIX: bare exit() terminated with status 0, hiding the failure from
    # any driving pipeline; exit non-zero so the error is detectable.
    sys.exit(1)
# NOTE(review): os.path.basename does not strip a ".py" suffix — this expects a
# bare module name (e.g. "data-test-2-planes"), possibly with a leading path.
data = importlib.import_module(os.path.basename(data_loader))
params = 'scale'+str(scale)+'.high_pass'+str(high_pass)+'.uniform'+str(uniform)+'.threshold'+str(threshold)+'.close'+str(close)
def _compute_flow(scales, patch_size, stride, prev_flows=None, pre_stride=None):
    """Compute patch-wise flow fields between consecutive sections.

    Args:
      scales: iterable of downsampling exponents; scale s operates on data
        downscaled by 2**s per axis (s == 0 is full resolution).
      patch_size: side length of the (square) correlation patch, in pixels.
      stride: distance between adjacent patches, in pixels.
      prev_flows: optional dict {scale: [flow, ...]} from a coarser pass, used
        to pre-target the correlation search of this pass.
      pre_stride: stride of the pass that produced prev_flows (required when
        prev_flows is given).

    Returns:
      Dict mapping each scale to a list of per-pair flow arrays, one entry per
      z in [min_z+1, max_z+1] (flow from section z-1 to section z).
    """
    mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
    flows = {s: [] for s in scales}
    # Load the first section (and its mask) at every requested scale.
    _prev = data.load_data(basepath, min_z, 0)
    prev = {s: _prev if s == 0 else downscale_local_mean(_prev, (2**s, 2**s)) for s in scales}
    _prev_mask = data.load_mask(basepath, params, min_z)
    # Masks are subsampled (not averaged) so they stay boolean.
    prev_mask = {s: _prev_mask if s == 0 or _prev_mask is None else _prev_mask[::2**s, ::2**s] for s in scales}
    if 0 not in scales:
        # Full-resolution copies are not needed; free them early.
        del _prev
        del _prev_mask
    fs = []
    fs_mask = []
    with futures.ThreadPoolExecutor() as tpe:
        # Prefetch the next sections to memory so that we don't have to wait for them
        # to load when the GPU becomes available.
        for z in range(min_z+1, max_z+2):
            fs.append(tpe.submit(lambda z=z: data.load_data(basepath, z, 0)))
            fs_mask.append(tpe.submit(lambda z=z: data.load_mask(basepath, params, z)))
        # Reverse so that .pop() below consumes futures in submission (z) order.
        fs = fs[::-1]
        fs_mask = fs_mask[::-1]
        for z in range(min_z+1, max_z+2):
            print(datetime.now(), 'z =', z)
            _curr = fs.pop().result()
            curr = {s: _curr if s == 0 else downscale_local_mean(_curr, (2**s, 2**s)) for s in scales}
            _curr_mask = fs_mask.pop().result()
            curr_mask = {s: _curr_mask if s == 0 or _curr_mask is None else _curr_mask[::2**s, ::2**s] for s in scales}
            if 0 not in scales:
                del _curr
                del _curr_mask
            # The batch size is a parameter which impacts the efficiency of the computation (but
            # not its result). It has to be large enough for the computation to fully utilize the
            # available GPU capacity, but small enough so that the batch fits in GPU RAM.
            for s in scales:
                flows[s].append(mfc.flow_field(prev[s], curr[s],
                                               (patch_size, patch_size),
                                               (stride, stride),
                                               batch_size=batch_size,
                                               pre_mask=prev_mask[s],
                                               post_mask=curr_mask[s],
                                               mask_only_for_patch_selection=True,
                                               pre_targeting_field=prev_flows[s][z-(min_z+1)][:2, ::] if prev_flows else None,
                                               pre_targeting_step=(pre_stride, pre_stride) if pre_stride else None))
            prev = curr
            # BUG FIX: prev_mask was never advanced, so every pair after the
            # first was masked with section min_z's mask instead of the mask of
            # the actual previous section.
            prev_mask = curr_mask
    return flows
print(datetime.now(), 'computing flow')
# First pass uses the first (patch_size, stride) pair; each subsequent pass is
# pre-targeted by the flow field produced by the previous pass.
fNx = _compute_flow(scales_int, patch_size_int[0], stride_int[0])
print("sum of flows = ", sum([np.nansum(np.abs(x)) for x in fNx.values()]))
for i in range(1, len(patch_size_int)):
    fNx = _compute_flow(scales_int, patch_size_int[i], stride_int[i], fNx, stride_int[i-1])
    print("sum of flows = ", sum([np.nansum(np.abs(x)) for x in fNx.values()]))
print(datetime.now(), 'cleaning flow')
fN = {}
# Pad to account for the edges of the images where there is insufficient context
# to estimate flow. Loop-invariant, so hoisted out of the loop; the module-level
# name is intentional — the debug dump at the end of the script reuses it.
pad = patch_size_int[-1] // 2 // stride_int[-1]
for s in scales_int:
    # Convert to [channels, z, y, x].
    flows = np.transpose(np.array(fNx[s]), [1, 0, 2, 3])
    flows = np.pad(flows, [[0, 0], [0, 0], [pad, pad], [pad, pad]], constant_values=np.nan)
    fN[s] = flow_utils.clean_flow(flows, min_peak_ratio=1.6, min_peak_sharpness=1.6,
                                  max_magnitude=80, max_deviation=20)
# NOTE(review): the two triple-quoted blocks below are disabled debug-plotting
# code. They reference variables (flows1x, flows2x, f1, f2) that are not
# defined anywhere in this script — confirm/update before re-enabling.
'''
f, ax = plt.subplots(2, 4, figsize=(16, 8))
ax[0,0].hist(flows1x[0, 0, ...][~np.isnan(flows1x[0, 0, ...])])
ax[0,1].hist(f1[0, 0, ...][~np.isnan(f1[0, 0, ...])])
ax[1,0].hist(flows2x[0, 0, ...][~np.isnan(flows2x[0, 0, ...])])
ax[1,1].hist(f2[0, 0, ...][~np.isnan(f2[0, 0, ...])])
ax[0,2].hist(flows1x[1, 0, ...][~np.isnan(flows1x[1, 0, ...])])
ax[0,3].hist(f1[1, 0, ...][~np.isnan(f1[1, 0, ...])])
ax[1,2].hist(flows2x[1, 0, ...][~np.isnan(flows2x[1, 0, ...])])
ax[1,3].hist(f2[1, 0, ...][~np.isnan(f2[1, 0, ...])])
plt.tight_layout()
plt.savefig("flows-f-hist.tif", dpi=100)
'''
'''
# Plot the horizontal component of the flow vector, before (left) and after (right) filtering
f, ax = plt.subplots(2, 4, figsize=(16, 8))
ax[0,0].imshow(flows1x[0, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[0,0].title.set_text('H flows1x')
ax[0,1].imshow(f1[0, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[0,1].title.set_text('H f1')
ax[1,0].imshow(flows2x[0, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[1,0].title.set_text('H flows2x')
ax[1,1].imshow(f2[0, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[1,1].title.set_text('H f2')
ax[0,2].imshow(flows1x[1, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[0,2].title.set_text('V flows1x')
ax[0,3].imshow(f1[1, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[0,3].title.set_text('V f1')
ax[1,2].imshow(flows2x[1, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[1,2].title.set_text('V flows2x')
ax[1,3].imshow(f2[1, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[1,3].title.set_text('V f2')
plt.tight_layout()
plt.savefig("flows-f.tif", dpi=600)
'''
print(datetime.now(), 'reconciling flow')
from scipy import interpolate  # NOTE(review): unused in this script — candidate for removal
# Upsample every coarser flow field to the finest requested resolution so that
# all of them can be reconciled into a single flow.
fN_hires = {}
s_min = min(scales_int)
scale_min = 1 / (2**s_min)
# Target bounding box: the x/y extent of the finest-scale flow (z extent 1;
# sections are resampled one at a time).
boxMx = bounding_box.BoundingBox(start=(0, 0, 0),
                                 size=(fN[s_min].shape[-1], fN[s_min].shape[-2], 1))
for s in scales_int:
    if s == 0:
        # Already at full resolution; nothing to resample.
        fN_hires[0] = fN[0]
        continue
    # Renamed from `scale` to stop shadowing the command-line `scale` argument.
    s_scale = 1 / (2**s)
    boxNx = bounding_box.BoundingBox(start=(0, 0, 0),
                                     size=(fN[s].shape[-1], fN[s].shape[-2], 1))
    for z in range(fN[s].shape[1]):
        print(datetime.now(), 's =', s, ', z =', z)
        # Upsample and scale spatial components.
        resampled = map_utils.resample_map(
            fN[s][:, z:z + 1, ...],
            boxNx, boxMx, 1 / s_scale, 1 / scale_min)
        if s not in fN_hires:
            # Allocate the full-z output array lazily, once the resampled
            # spatial shape is known.
            fN_hires[s] = np.zeros((resampled.shape[0], fN[s].shape[1], *resampled.shape[2:]))
        # Dividing by s_scale (= 1/2**s) rescales vector magnitudes to full resolution.
        fN_hires[s][:, z:z + 1, ...] = resampled / s_scale
final_flow = flow_utils.reconcile_flows(tuple(fN_hires[k] for k in scales_int),
                                        max_gradient=0, max_deviation=20, min_patch_size=400)
# NOTE(review): the two triple-quoted blocks below are disabled debug-plotting
# code. They reference variables (f1, f2, f2_hires) that are not defined
# anywhere in this script — confirm/update before re-enabling.
'''
# Plot (left to right): high res. flow, upsampled low res. flow, combined flow to use for alignment.
f, ax = plt.subplots(2, 3, figsize=(7.5, 5))
ax[0,0].imshow(f1[0, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[0,0].title.set_text('H f1')
ax[0,1].imshow(f2_hires[0, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[0,1].title.set_text('H f2hi')
ax[0,2].imshow(final_flow[0, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[0,2].title.set_text('H final')
ax[1,0].imshow(f1[1, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[1,0].title.set_text('V f1')
ax[1,1].imshow(f2_hires[1, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[1,1].title.set_text('V f2hi')
ax[1,2].imshow(final_flow[1, 0, ...], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[1,2].title.set_text('V final')
plt.tight_layout()
plt.savefig("flows-hi-final-b.tif", dpi=300)
'''
'''
f, ax = plt.subplots(1, 2, figsize=(7.5, 5))
ax[0].imshow(f2[0, 0, 300:400, 0:100], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[0].title.set_text('H f2')
ax[1].imshow(f2_hires[0, 0, 600:800, 0:200], cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[1].title.set_text('H f2hi')
plt.tight_layout()
plt.savefig("flows-f2-f2hi-d.tif", dpi=300)
'''
# Rebind `params` to encode the flow parameters for output filenames (this
# intentionally replaces the mask-loading `params` defined earlier — masks are
# no longer needed from here on).
params = 'patch'+patch_size+'.stride'+stride+'.scales'+args.scales.replace(",",'')+'.k0'+str(k0)+'.k'+str(k)
data.save_flow(final_flow, min_z+1, max_z+1, basepath, params, write_metadata)
if debug:
    # Recompute the edge padding locally instead of relying on the leftover
    # loop variable from the cleaning stage (same value; makes this block
    # self-contained).
    pad = patch_size_int[-1] // 2 // stride_int[-1]
    for s in scales_int:
        # Raw (uncleaned) flows, converted to [channels, z, y, x] and padded
        # exactly as in the cleaning stage, for offline comparison.
        flows = np.transpose(np.array(fNx[s]), [1, 0, 2, 3])
        flows = np.pad(flows, [[0, 0], [0, 0], [pad, pad], [pad, pad]], constant_values=np.nan)
        np.save(os.path.join(basepath, 'fNx.s'+str(s)+'.'+params+'.npy'), flows)
        np.save(os.path.join(basepath, 'fN.s'+str(s)+'.'+params+'.npy'), fN[s])
        np.save(os.path.join(basepath, 'fN_hires.s'+str(s)+'.'+params+'.npy'), fN_hires[s])