Commit 3b843ee

small change in parallelism to be compatible with older Julia versions, small type stability fix in cg
1 parent 5dc5a2d commit 3b843ee

10 files changed: +203 −91

examples/ConstraintSetupExamples.jl

Lines changed: 24 additions & 24 deletions
@@ -27,16 +27,16 @@ else
 #error("download escalator video from http://cvxr.com/tfocs/demos/rpca/escalator_data.mat")
 end

-mtrue = read(file, "X")
-n1 = convert(Integer,read(file, "m"))
-n2 = convert(Integer,read(file, "n"))
-m_mat = convert(Array{TF,2},mtrue)
-m_tensor = convert(Array{TF,3},reshape(mtrue,n1,n2,Integer(200)))
+mtrue = read(file, "X");
+n1 = convert(Integer,read(file, "m"));
+n2 = convert(Integer,read(file, "n"));
+m_mat = convert(Array{TF,2},mtrue);
+m_tensor = convert(Array{TF,3},reshape(mtrue,n1,n2,Integer(200)));

 #computational grid for the video
-comp_grid = compgrid((1f0,1f0,1f0),(size(m_tensor,1),size(m_tensor,2), size(m_tensor,3)))
+comp_grid = compgrid((1f0,1f0,1f0),(size(m_tensor,1),size(m_tensor,2), size(m_tensor,3)));

-comp_grid_time_slice = compgrid((1f0,1f0),(size(m_tensor,1),size(m_tensor,2)))
+comp_grid_time_slice = compgrid((1f0,1f0),(size(m_tensor,1),size(m_tensor,2)));


 ######################################################################
@@ -52,7 +52,7 @@ options = PARSDMM_options()
 #l1 (total variation) constraints (in one direction)

 #find a reasonable value for the l1-ball
-(TD_OP, AtA_diag, dense, TD_n) = get_TD_operator(comp_grid,"D_z",options.FL)
+(TD_OP, AtA_diag, dense, TD_n) = get_TD_operator(comp_grid,"D_z",options.FL);
 TV_z = norm(TD_OP*vec(m_tensor),1)

 m_min = 0.0
@@ -61,15 +61,15 @@ set_type = "l1"
 TD_OP = "D_z"
 app_mode = ("tensor","")
 custom_TD_OP = ([],false)
-push!(constraint, set_definitions(set_type,TD_OP,m_min,m_max,app_mode,custom_TD_OP))
+push!(constraint, set_definitions(set_type,TD_OP,m_min,m_max,app_mode,custom_TD_OP));

-(P_sub,TD_OP,set_Prop) = setup_constraints(constraint,comp_grid,options.FL)
-(TD_OP,AtA,l,y) = PARSDMM_precompute_distribute(TD_OP,set_Prop,comp_grid,options)
+(P_sub,TD_OP,set_Prop) = setup_constraints(constraint,comp_grid,options.FL);
+(TD_OP,AtA,l,y) = PARSDMM_precompute_distribute(TD_OP,set_Prop,comp_grid,options);

 @time (x,log_PARSDMM) = PARSDMM(vec(m_tensor),AtA,TD_OP,set_Prop,P_sub,comp_grid,options);
 @time (x,log_PARSDMM) = PARSDMM(vec(m_tensor),AtA,TD_OP,set_Prop,P_sub,comp_grid,options);

-m_proj = reshape(x,comp_grid.n)
+m_proj = reshape(x,comp_grid.n);

 figure();
 for i=1:comp_grid.n[3]
@@ -83,23 +83,23 @@ end
 ######## (anisotropic) total-variation on the time-derivative ########
 ######## using JOLI operators ########

-options = PARSDMM_options()
+options = PARSDMM_options();

 #TV operator per time-slice
-(TV, AtA_diag, dense, TD_n) = get_TD_operator(comp_grid_time_slice,"TV",options.FL)
+(TV, AtA_diag, dense, TD_n) = get_TD_operator(comp_grid_time_slice,"TV",options.FL);
 #time derivative over the time-slices
-(D, AtA_diag, dense, TD_n) = get_TD_operator(comp_grid,"D_z",options.FL)
+(D, AtA_diag, dense, TD_n) = get_TD_operator(comp_grid,"D_z",options.FL);

-CustomOP_explicit_sparse = kron(TV, SparseMatrixCSC{TF}(LinearAlgebra.I, comp_grid.n[3]-1,comp_grid.n[3]-1))* D
+CustomOP_explicit_sparse = kron(TV, SparseMatrixCSC{TF}(LinearAlgebra.I, comp_grid.n[3]-1,comp_grid.n[3]-1))* D;

-D = joMatrix(D)
-TV = joMatrix(TV)
-CustomOP_JOLI = joKron(TV, joEye(comp_grid.n[3]-1,DDT=Float32,RDT=Float32))* D
+D = joMatrix(D);
+TV = joMatrix(TV);
+CustomOP_JOLI = joKron(TV, joEye(comp_grid.n[3]-1,DDT=Float32,RDT=Float32))* D;

 ## Solve using JOLI ##

 #initialize constraints
-constraint = Vector{SetIntersectionProjection.set_definitions}()
+constraint = Vector{SetIntersectionProjection.set_definitions}();

 m_min = 0.0
 m_max = 0.1*norm(CustomOP_JOLI*vec(m_tensor),1)
@@ -122,25 +122,25 @@ CustomOP_JOLI = joKron(TV, joEye(comp_grid.n[3]-1,DDT=Float32,RDT=Float32))* D
 @time (x_joli,log_PARSDMM) = PARSDMM(vec(m_tensor),AtA,TD_OP,set_Prop,P_sub,comp_grid,options);

 ## solve using explicit sparse array ##
-constraint = Vector{SetIntersectionProjection.set_definitions}()
+constraint = Vector{SetIntersectionProjection.set_definitions}();

 m_min = 0.0
 m_max = 0.025*norm(CustomOP_explicit_sparse*vec(m_tensor),1)
 set_type = "l1"
 TD_OP = "identity"
 app_mode = ("matrix","")
 custom_TD_OP = (CustomOP_explicit_sparse,false)
-push!(constraint, set_definitions(set_type,TD_OP,m_min,m_max,app_mode,custom_TD_OP))
+push!(constraint, set_definitions(set_type,TD_OP,m_min,m_max,app_mode,custom_TD_OP));

 options=PARSDMM_options()
-(P_sub,TD_OP,set_Prop) = setup_constraints(constraint,comp_grid,options.FL)
+(P_sub,TD_OP,set_Prop) = setup_constraints(constraint,comp_grid,options.FL);

 #set properties of custom operator
 set_Prop.AtA_diag[1] = false
 set_Prop.dense[1] = false
 set_Prop.banded[1] = true

-(TD_OP,AtA,l,y) = PARSDMM_precompute_distribute(TD_OP,set_Prop,comp_grid,options)
+(TD_OP,AtA,l,y) = PARSDMM_precompute_distribute(TD_OP,set_Prop,comp_grid,options);

 @time (x_sp,log_PARSDMM) = PARSDMM(vec(m_tensor),AtA,TD_OP,set_Prop,P_sub,comp_grid,options);
 @time (x_sp,log_PARSDMM) = PARSDMM(vec(m_tensor),AtA,TD_OP,set_Prop,P_sub,comp_grid,options);
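Note: CustomOP_explicit_sparse and CustomOP_JOLI above encode the same linear map; the JOLI version composes it lazily instead of materializing the Kronecker product. A minimal sketch of that equivalence on toy operators (the toy TV and D below are illustrative stand-ins; only the JOLI calls already used in this diff are assumed):

using SparseArrays, LinearAlgebra, JOLI

TF = Float32
TV = sparse(TF[1 -1 0; 0 1 -1])        # toy 2x3 difference operator (stand-in)
D  = SparseMatrixCSC{TF}(I, 6, 6)      # toy 6x6 "time derivative" (stand-in)

# Explicit: materialize kron(TV, I) as one sparse matrix, then multiply.
explicit = kron(TV, SparseMatrixCSC{TF}(I, 2, 2)) * D

# Lazy: JOLI composes the operators; nothing is materialized.
lazy = joKron(joMatrix(TV), joEye(2, DDT=TF, RDT=TF)) * joMatrix(D)

v = rand(TF, 6)
explicit * v ≈ lazy * v                # same action on any vector

The lazy form trades memory for per-application cost, which is why the example sets set_Prop.banded[1] = true by hand for the explicit variant but not the JOLI one.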

examples/projection_intersection_2D.jl

Lines changed: 21 additions & 18 deletions
@@ -114,7 +114,7 @@ options.parallel = true

 @time (x2,log_PARSDMM) = PARSDMM(m,AtA,TD_OP,set_Prop,P_sub,comp_grid,options);
 @time (x2,log_PARSDMM) = PARSDMM(m,AtA,TD_OP,set_Prop,P_sub,comp_grid,options);
-#@time (x,log_PARSDMM) = PARSDMM(m,AtA,TD_OP,set_Prop,P_sub,comp_grid,options);
+@time (x,log_PARSDMM) = PARSDMM(m,AtA,TD_OP,set_Prop,P_sub,comp_grid,options);

 #plot
 figure();
@@ -125,27 +125,30 @@ savefig("projected_model_ParallelPARSDMM.png",bbox_inches="tight")
 #print timings in terminal
 log_PARSDMM.timing

-# #use multilevel-serial (2-levels)
-# options.parallel = false
+#use multilevel-serial (2-levels)
+options.parallel = false

-# #2 levels, the gird point spacing at level 2 is 3X that of the original (level 1) grid
-# n_levels = 2
-# coarsening_factor = 3
+#2 levels, the grid point spacing at level 2 is 3X that of the original (level 1) grid
+n_levels = 2
+coarsening_factor = 3

-# #set up all required quantities for each level
-# #(m_levels,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels)=setup_multi_level_PARSDMM(m,n_levels,coarsening_factor,comp_grid,constraint,options)
-# (TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels)=setup_multi_level_PARSDMM(m,n_levels,coarsening_factor,comp_grid,constraint,options)
+#set up all required quantities for each level
+#(m_levels,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels)=setup_multi_level_PARSDMM(m,n_levels,coarsening_factor,comp_grid,constraint,options)
+(TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels)=setup_multi_level_PARSDMM(m,n_levels,coarsening_factor,comp_grid,constraint,options)

-# println("")
-# println("PARSDMM multilevel-serial (bounds and bounds on D_z):")
-# @time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);
-# @time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);
-# @time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);
+println("")
+println("PARSDMM multilevel-serial (bounds and bounds on D_z):")
+@time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);
+@time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);
+@time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);

-# figure();
-# subplot(2,1,1);imshow(permutedims(reshape(m,(comp_grid.n[1],comp_grid.n[2])),[2,1]),cmap="jet",vmin=vmi,vmax=vma,extent=[0, xmax, zmax, 0]); title("model to project")
-# subplot(2,1,2);imshow(permutedims(reshape(x,(comp_grid.n[1],comp_grid.n[2])),[2,1]),cmap="jet",vmin=vmi,vmax=vma,extent=[0, xmax, zmax, 0]); title("Projection (bounds and bounds on D_z)")
-# savefig("projected_model_MultilevelSerialPARSDMM.png",bbox_inches="tight")
+figure();
+subplot(2,1,1);imshow(permutedims(reshape(m,(comp_grid.n[1],comp_grid.n[2])),[2,1]),cmap="jet",vmin=vmi,vmax=vma,extent=[0, xmax, zmax, 0]); title("model to project")
+subplot(2,1,2);imshow(permutedims(reshape(x,(comp_grid.n[1],comp_grid.n[2])),[2,1]),cmap="jet",vmin=vmi,vmax=vma,extent=[0, xmax, zmax, 0]); title("Projection (bounds and bounds on D_z)")
+savefig("projected_model_MultilevelSerialPARSDMM.png",bbox_inches="tight")
+
+#print timings in terminal, note that for the multi-level version, the timing is only for the final level on the original grid
+log_PARSDMM.timing

 # #now use multi-level with parallel PARSDMM
 # options.parallel=true

examples/projection_intersection_3D.jl

Lines changed: 7 additions & 7 deletions
@@ -102,13 +102,13 @@ subplot(3, 3, 8);plot(log_PARSDMM.gamma) ;title("gamma")
 subplot(3, 3, 9);semilogy(log_PARSDMM.evol_x) ;title("x evolution")

 #plot
-m_plot = reshape(m,comp_grid.n)
+m_plot = reshape(m,comp_grid.n);
 figure();
 subplot(3,1,1);imshow(m_plot[:,:,Int64(round(comp_grid.n[3]/2))],cmap="jet",vmin=vmi,vmax=vma,extent=[0, xmax, ymax, 0]); title("model to project x-y slice")
 subplot(3,1,2);imshow(permutedims(m_plot[:,Int64(round(comp_grid.n[2]/2)),:],[2,1]),cmap="jet",vmin=vmi,vmax=vma,extent=[0, xmax, zmax, 0]); title("model to project x-z slice")
 subplot(3,1,3);imshow(permutedims(m_plot[Int64(round(comp_grid.n[1]/2)),:,:],[2,1]),cmap="jet",vmin=vmi,vmax=vma,extent=[0, ymax, zmax, 0]); title("model to project y-z slice")

-x_plot = reshape(x,comp_grid.n)
+x_plot = reshape(x,comp_grid.n);
 figure();
 subplot(3,1,1);imshow(x_plot[:,:,Int64(round(comp_grid.n[3]/2))],cmap="jet",vmin=vmi,vmax=vma,extent=[0, xmax, ymax, 0]); title("projected model x-y slice")
 subplot(3,1,2);imshow(permutedims(x_plot[:,Int64(round(comp_grid.n[2]/2)),:],[2,1]),cmap="jet",vmin=vmi,vmax=vma,extent=[0, xmax, zmax, 0]); title("projected model x-z slice")
@@ -121,13 +121,13 @@ n_levels = 2
 coarsening_factor = 3

 #set up all required quantities for each level
-(TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels) = setup_multi_level_PARSDMM(m,n_levels,coarsening_factor,comp_grid,constraint,options)
+(TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels) = setup_multi_level_PARSDMM(m,n_levels,coarsening_factor,comp_grid,constraint,options);

 println("")
 println("PARSDMM multilevel-serial (bounds and bounds on D_z):")
-@time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options)
-@time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options)
-@time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options)
+@time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);
+@time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);
+@time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);

 #parallel single level
 println("")
@@ -147,7 +147,7 @@ n_levels = 2
 coarsening_factor = 3

 #set up all required quantities for each level
-(TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels) = setup_multi_level_PARSDMM(m,n_levels,coarsening_factor,comp_grid,constraint,options)
+(TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels) = setup_multi_level_PARSDMM(m,n_levels,coarsening_factor,comp_grid,constraint,options);

 println("PARSDMM multilevel-parallel (bounds and bounds on D_z):")
 @time (x,log_PARSDMM) = PARSDMM_multi_level(m,TD_OP_levels,AtA_levels,P_sub_levels,set_Prop_levels,comp_grid_levels,options);

src/CDS_MVp.jl

Lines changed: 5 additions & 5 deletions
@@ -1,4 +1,4 @@
-export CDS_MVp, CDS_MVp2, CDS_MVp3, CDS_MVp4
+export CDS_MVp

 """
 compute single-thread matrix vector product with vector x, output is vector y: y=A*x
@@ -19,9 +19,9 @@ function CDS_MVp(
 r0 = max(1, 1-d)
 r1 = min(N, N-d)
 c0 = max(1, 1+d)
-for r = r0 : r1
-c = r - r0 + c0 #original
-@inbounds y[r] = y[r] + R[r,i] * x[c]#original
+  for r = r0 : r1
+    c = r - r0 + c0 #original
+    @inbounds y[r] = y[r] + R[r,i] * x[c]#original
 end
 end
 return y
@@ -51,7 +51,7 @@ end
 # return y
 # end

-# function CDS_MVp(
+# function CDS_MVp2(
 # N ::Integer,
 # ndiags ::Integer,
 # R ::Array{TF,2},
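For context, the kernel above multiplies a banded matrix stored in compressed diagonal storage (CDS/DIA): each column of R holds one diagonal, and an offset vector gives each diagonal's distance from the main diagonal. A self-contained sketch of the same loop structure on a toy tridiagonal matrix (the names cds_mvp, R, and offsets are illustrative, not the package's API):

# Minimal CDS (diagonal-storage) matrix-vector product, mirroring CDS_MVp's loop.
# R[r,i] stores A[r, r+offsets[i]]; out-of-band entries of R are never read.
function cds_mvp(N::Int, R::Matrix{Float64}, offsets::Vector{Int}, x::Vector{Float64})
    y = zeros(Float64, N)
    for i = 1:length(offsets)
        d  = offsets[i]
        r0 = max(1, 1 - d)   # first row where this diagonal has an entry
        r1 = min(N, N - d)   # last row where this diagonal has an entry
        c0 = max(1, 1 + d)   # column paired with row r0
        for r = r0:r1
            c = r - r0 + c0
            @inbounds y[r] += R[r, i] * x[c]
        end
    end
    return y
end

# Example: the 1-D Laplacian-like tridiagonal A = tridiag(-1, 2, -1), N = 5.
N = 5
R = repeat([-1.0 2.0 -1.0], N)   # one row of stored diagonals per matrix row
offsets = [-1, 0, 1]
y = cds_mvp(N, R, offsets, ones(N))   # == [1, 0, 0, 0, 1], same as A*ones(5)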

src/PARSDMM.jl

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ for i=1:maxit #main loop
 # x-minimization
 @timeit to "argmin x" begin
 copy!(x_old,x);
-(x,iter,relres,x_solve_tol_ref) = argmin_x(Q,rhs,x,x_solve_tol_ref,i,log_PARSDMM,Q_offsets,Ax_out)
+(x,iter,relres,x_solve_tol_ref) = argmin_x(Q,rhs,x,x_solve_tol_ref,i,log_PARSDMM,Q_offsets,Ax_out,comp_grid)
 log_PARSDMM.cg_it[i] = iter
 log_PARSDMM.cg_relres[i] = relres
 end #end timer for argmin x

src/PARSDMM_initialize.jl

Lines changed: 7 additions & 4 deletions
@@ -86,10 +86,13 @@ function PARSDMM_initialize(
 m = [m ; zeros(TF,length(m)) ]
 end
 if parallel
-#f_compute_rel_feas = (feasibility_initial,TD_OP,P_sub) -> compute_relative_feasibility(m,feasibility_initial,TD_OP,P_sub)
-#[@spawnat pid f_compute_rel_feas for pid in 2:nworkers()]
-#feasibility_initial = pmap(f_compute_rel_feas, feasibility_initial,TD_OP,P_sub; distributed=true, batch_size=1, on_error=nothing, retry_delays=[], retry_check=nothing)
-feasibility_initial = pmap((feasibility_initial,TD_OP,P_sub) -> compute_relative_feasibility(m,feasibility_initial,TD_OP,P_sub) , feasibility_initial,TD_OP,P_sub; distributed=true, batch_size=1, on_error=nothing, retry_delays=[], retry_check=nothing)
+feasibility_initial = distribute(feasibility_initial)
+[@sync @spawnat pid m for pid in P_sub.pids]
+[@sync @spawnat pid compute_relative_feasibility(m,feasibility_initial[:L],TD_OP[:L],P_sub[:L]) for pid in P_sub.pids]
+feasibility_initial = convert(Vector{TF},feasibility_initial)
+
+#using pmap
+#feasibility_initial = pmap((feasibility_initial,TD_OP,P_sub) -> compute_relative_feasibility(m,feasibility_initial,TD_OP,P_sub) , feasibility_initial,TD_OP,P_sub; distributed=true, batch_size=1, on_error=nothing, retry_delays=[], retry_check=nothing)
 else
 for ii = 1:length(P_sub)
 feasibility_initial[ii] = norm(P_sub[ii](TD_OP[ii]*m) .- TD_OP[ii]*m) ./ (norm(TD_OP[ii]*m)+(100*eps(TF)));
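This is the "change in parallelism" from the commit message: the pmap call with keyword arguments is replaced by the DistributedArrays.jl idiom of distributing a vector, letting each worker update its local part ([:L] / localpart) in place, and converting back. A standalone sketch of that idiom, with a toy work! standing in for compute_relative_feasibility (addprocs and work! are illustrative only):

using Distributed
addprocs(2)                               # toy setup; the package manages its own workers
@everywhere using DistributedArrays

@everywhere work!(v) = (v .= v .^ 2; v)   # stand-in for compute_relative_feasibility

a  = Float64.(1:8)
da = distribute(a)                        # scatter the vector over the workers

@sync for pid in procs(da)                # each worker mutates only its local chunk
    @spawnat pid work!(localpart(da))
end

a = convert(Vector{Float64}, da)          # gather the result back on the master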

src/SetIntersectionProjection.jl

Lines changed: 5 additions & 8 deletions
@@ -15,6 +15,9 @@ using JOLI
 using JOLI.FFTW, JOLI.Wavelets
 using SortingAlgorithms
 using TimerOutputs
+# using Flux
+# using NNlib
+# using CUDA

 export log_type_PARSDMM, set_properties, PARSDMM_options, set_definitions

@@ -39,6 +42,7 @@ include("CDS_MVp.jl")
 include("CDS_MVp_MT.jl")
 include("CDS_MVp_MT_subfunc.jl")
 include("CDS_scaled_add!.jl")
+include("CDS2stencil.jl")

 #scripts for parallelism
 include("update_y_l_parallel.jl")
@@ -53,7 +57,7 @@ include("setup_multi_level_PARSDMM.jl")
 include("constraint2coarse.jl")
 include("interpolate_y_l.jl")

-#scripts for setting up constraints, projetors, linear operators
+#scripts for setting up constraints, projectors, linear operators
 include("default_PARSDMM_options.jl")
 include("convert_options!.jl")
 include("get_discrete_Grad.jl");
@@ -97,13 +101,6 @@ mutable struct log_type_PARSDMM
 cg_it :: Vector{Integer}
 cg_relres :: Vector{Real}
 timing :: TimerOutput
-# T_cg :: Real
-# T_stop :: Real
-# T_ini :: Real
-# T_rhs :: Real
-# T_adjust_rho_gamma:: Real
-# T_y_l_upd :: Real
-# T_Q_upd :: Real
 end

 @with_kw mutable struct PARSDMM_options
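The commented-out Flux/NNlib/CUDA dependencies and the new include("CDS2stencil.jl") point at the experimental GPU path visible in src/argmin_x.jl below: when every stored diagonal of a CDS matrix holds a single constant (a constant-coefficient operator away from boundaries), Q*x is equivalent to a stencil/convolution. CDS2stencil itself is not shown in this commit; the toy 1-D illustration of the underlying equivalence below uses only illustrative names:

# A constant-coefficient tridiagonal operator applied two ways:
# (1) as a matrix-vector product, (2) as a 3-tap stencil sweep.
using LinearAlgebra

n = 8
A = Tridiagonal(fill(-1.0, n-1), fill(2.0, n), fill(-1.0, n-1))
x = rand(n)

# Stencil taps taken from the diagonals' constant values, offsets -1, 0, +1.
w = [-1.0, 2.0, -1.0]

y_stencil = zeros(n)
for r = 1:n, (k, d) in enumerate(-1:1)
    c = r + d
    if 1 <= c <= n                 # zero boundary: skip taps that fall off the grid
        y_stencil[r] += w[k] * x[c]
    end
end

@assert y_stencil ≈ A * x          # interior and boundary rows agree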

src/argmin_x.jl

Lines changed: 36 additions & 2 deletions
@@ -11,7 +11,8 @@ function argmin_x(
 i ::Integer,
 log_PARSDMM,
 Q_offsets=[],
-Ax_out=zeros(TF,length(x)) ::Vector{TF}
+Ax_out=zeros(TF,length(x)) ::Vector{TF},
+comp_grid=[]
 ) where {TF<:Real}

 #Initialize
@@ -20,8 +21,14 @@
 iter = 0

 if typeof(Q)==Array{TF,2} #set up multi-threaded matrix-vector product in compressed diagonal storage format
-Af1(in) = (fill!(Ax_out,TF(0)); CDS_MVp_MT(size(Q,1),size(Q,2),Q,Q_offsets,in,Ax_out); return Ax_out)
+
+# w = CDS2stencil(Q, Q_offsets, comp_grid.n)
+# w = w |> gpu
+# cdims = DenseConvDims(reshape(x, comp_grid.n ...,1,1), w, padding=Int(size(w,1)-1)/2)

+Af1(in) = Ax_CDS_MT(in, Ax_out, Q, Q_offsets)
+#Af1(in) = Ax_stencil_cuda(in, Ax_out, comp_grid, w, cdims)
+
 #determine what relative residual CG needs to reach
 if i<3 #i is the PARSDMM iteration counter
 x_solve_tol_ref = TF(max(0.1*norm(Af1(x)-rhs)/norm(rhs),10*eps(TF))) #10% of current relative residual
@@ -30,6 +37,7 @@
 end

 (x,flag,relres,iter) = cg(Af1,rhs,tol=x_solve_tol_ref,maxIter=1000,x=x,out=0);
+#(x,flag,relres,iter) = cg(Q,Q_offsets,rhs, x_solve_tol_ref, 1000, x, 0)

 elseif typeof(Q)==SparseMatrixCSC{TF,Int64} #CG with native Julia sparse CSC format MPVs
@@ -60,3 +68,29 @@

 return x,iter,relres,x_solve_tol_ref
 end
+
+function Ax_CDS_MT(in::Vector{TF}, Ax_out::Vector{TF}, Q::Array{TF,2}, Q_offsets::Vector{Int}) where TF <: Real
+
+fill!(Ax_out,TF(0))
+CDS_MVp_MT(size(Q,1),size(Q,2),Q,Q_offsets,in,Ax_out);
+
+return Ax_out
+end
+
+# function Ax_stencil_cuda(in::Vector{TF}, Ax_out::Vector{TF}, comp_grid, w::CUDA.CuArray{TF, 5}, cdims) where TF <: Real
+# in = in |> gpu
+# Ax_out = Ax_out |>gpu
+# fill!(Ax_out,TF(0))
+# Ax_out = reshape(Ax_out,comp_grid.n ...,1,1)
+# in = reshape(in,comp_grid.n ...,1,1)
+# conv!(Ax_out, in, w, cdims)
+
+# # #for now, boundary conditions are an issue, manually correct it
+# # kernel_size = size(w,1)
+
+# Ax_out = Ax_out |> cpu
+
+# return vec(Ax_out)
+
+# end
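The new Ax_CDS_MT helper is the "small type stability fix in cg" from the commit message: the multi-statement closure it replaces captures Ax_out, Q, and Q_offsets, and closure captures can end up boxed (Julia issue #15276), so every matrix-vector product inside cg may hit dynamic dispatch. Routing the call through a function whose arguments carry concrete types restores inference. A toy sketch of the pattern (illustrative names, not the package's code):

using LinearAlgebra

# Function barrier: all state arrives as concretely-typed arguments,
# so the body infers and compiles to tight code.
function apply_op!(out::Vector{T}, A::Matrix{T}, v::Vector{T}) where {T<:Real}
    fill!(out, zero(T))          # the kernel accumulates, so zero the buffer first
    mul!(out, A, v, true, true)  # out += A*v (5-arg mul!, like CDS_MVp_MT)
    return out
end

A   = rand(4, 4)
buf = zeros(4)

# Thin wrapper handed to the iterative solver; it only forwards its captures,
# mirroring Af1(in) = Ax_CDS_MT(in, Ax_out, Q, Q_offsets) in the diff.
op(v) = apply_op!(buf, A, v)

op(ones(4)) ≈ A * ones(4)    # true

# @code_warntype apply_op!(buf, A, ones(4))   # confirms concrete inference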
