1111import requests
1212from astropy import units as u
1313from astropy .coordinates import SkyCoord
14+ from astropy .io import ascii
1415from astropy .table import Table , vstack
1516
1617try :
2223 Gaia .ROW_LIMIT = 0
2324
2425from ..utils import makedirs_if_needed
26+ from ..utils .overlap_checker import is_within
2527from .core import DownloadableBase , FitsTable
2628
2729_HAS_CASJOBS_ = True
@@ -180,6 +182,14 @@ class SdssQuery(DownloadableBase):
180182 WHERE n.objID = p.objID
181183 """
182184
185+ _spec_query_template = """
186+ SELECT s.specObjID as OBJID, s.bestObjID, s.ra as RA, s.dec as DEC, s.survey,
187+ ISNULL(s.z, -1) as SPEC_Z, ISNULL(s.zErr, -1) as SPEC_Z_ERR, ISNULL(s.zWarning, -1) as SPEC_Z_WARN
188+ FROM dbo.fGetNearbySpecObjEq({ra:.10g}, {dec:.10g}, {r_arcmin:.10g}) n, SpecObj s
189+ INTO mydb.{db_table_name}
190+ WHERE n.specObjID = s.specObjID
191+ """
192+
183193 def __init__ (
184194 self ,
185195 ra ,
@@ -190,6 +200,8 @@ def __init__(
190200 user = None ,
191201 password = None ,
192202 default_use_sciserver = True ,
203+ specs_only = False ,
204+ sciserver_via_csv = False ,
193205 ):
194206
195207 self .sciserver_user = user or os .getenv ("SCISERVER_USER" )
@@ -198,12 +210,7 @@ def __init__(
198210 self .casjobs_pass = password or os .getenv ("CASJOBS_PW" )
199211
200212 self .use_sciserver = False
201- if (
202- default_use_sciserver
203- and _HAS_SCISERVER_
204- and self .sciserver_user
205- and self .sciserver_pass
206- ):
213+ if default_use_sciserver and _HAS_SCISERVER_ and self .sciserver_user and self .sciserver_pass :
207214 self .use_sciserver = True
208215
209216 if self .use_sciserver :
@@ -212,8 +219,9 @@ def __init__(
212219 if not db_table_name :
213220 db_table_name = "SAGA" + get_random_string (4 )
214221 self .db_table_name = re .sub ("[^A-Za-z]" , "" , db_table_name )
215- self .query = self .construct_query (ra , dec , radius , self .db_table_name )
222+ self .query = self .construct_query (ra , dec , radius , self .db_table_name , specs_only )
216223 self .context = context
224+ self .sciserver_via_csv = sciserver_via_csv
217225
218226 def download_as_file (self , file_path , overwrite = False , compress = True ):
219227 if os .path .isfile (file_path ) and not overwrite :
@@ -227,6 +235,7 @@ def download_as_file(self, file_path, overwrite=False, compress=True):
227235 context = self .context ,
228236 username = self .sciserver_user ,
229237 password = self .sciserver_pass ,
238+ via_csv = self .sciserver_via_csv ,
230239 )
231240 else :
232241 self .run_casjobs_with_casjobs (
@@ -240,7 +249,7 @@ def download_as_file(self, file_path, overwrite=False, compress=True):
240249 )
241250
242251 @classmethod
243- def construct_query (cls , ra , dec , radius = 1.0 , db_table_name = None ):
252+ def construct_query (cls , ra , dec , radius = 1.0 , db_table_name = None , specs_only = False ):
244253 """
245254 Generates the query to send to the SDSS to get the full SDSS catalog around
246255 a target.
@@ -261,21 +270,16 @@ def construct_query(cls, ra, dec, radius=1.0, db_table_name=None):
261270 The SQL query to send to the SDSS skyserver
262271 """
263272
264- select_into_mydb = True
265- if db_table_name is None :
266- db_table_name = "TO_BE_REMOVED"
267- select_into_mydb = False
268-
269- # pylint: disable=possibly-unused-variable
270- ra = ensure_deg (ra )
271- dec = ensure_deg (dec )
272- r_arcmin = ensure_deg (radius ) * 60.0
273-
274- # ``**locals()`` means "use the local variable names to fill the template"
275- q = cls ._query_template .format (** locals ())
276- q = re .sub (r"[^\S\n]+" , " " , q ).strip ()
277- if not select_into_mydb :
278- q = q .replace ("INTO mydb.{}" .format (db_table_name ), "" )
273+ params = {
274+ "ra" : ensure_deg (ra ),
275+ "dec" : ensure_deg (dec ),
276+ "r_arcmin" : ensure_deg (radius ) * 60.0 ,
277+ "db_table_name" : (db_table_name or "__TO_BE_REMOVED__" ),
278+ }
279+ query_template = cls ._spec_query_template if specs_only else cls ._query_template
280+ q = query_template .format (** params )
281+ q = re .sub (r"\s+" , " " , q ).strip ()
282+ q = q .replace ("INTO mydb.__TO_BE_REMOVED__" , "" )
279283 return q
280284
281285 @staticmethod
@@ -324,9 +328,7 @@ def run_casjobs_with_casjobs(
324328 "POST" ,
325329 )
326330
327- job_id = cjob .submit (
328- query , context = context , task_name = "casjobs_" + db_table_name , estimate = 1
329- )
331+ job_id = cjob .submit (query , context = context , task_name = "casjobs_" + db_table_name , estimate = 1 )
330332 print (
331333 time .strftime ("[%m/%d %H:%M:%S]" ),
332334 "casjob ({}) submitted..." .format (db_table_name ),
@@ -356,7 +358,13 @@ def run_casjobs_with_casjobs(
356358
357359 @staticmethod
358360 def run_casjobs_with_sciserver (
359- query , output_path , compress = True , context = "DR14" , username = None , password = None
361+ query ,
362+ output_path ,
363+ compress = True ,
364+ context = "DR14" ,
365+ username = None ,
366+ password = None ,
367+ via_csv = False ,
360368 ):
361369 """
362370 Run a single casjobs and download casjobs output using SciServer
@@ -388,10 +396,14 @@ def run_casjobs_with_sciserver(
388396 if not (_HAS_SCISERVER_ and username and password ):
389397 raise ValueError ("You are not setup to run casjobs with SciServer" )
390398 SciServer .Authentication .login (username , password )
391- r = SciServer .CasJobs .executeQuery (query , context = context , format = "fits" )
399+ return_format = "csv" if via_csv else "fits"
400+ r = SciServer .CasJobs .executeQuery (query , context = context , format = return_format )
392401 file_open = gzip .open if compress else open
393402 with file_open (output_path , "wb" ) as f_out :
394- shutil .copyfileobj (r , f_out )
403+ if via_csv :
404+ ascii .read (r , format = "csv" ).write (f_out , format = "fits" )
405+ else :
406+ shutil .copyfileobj (r , f_out )
395407
396408
397409class DesQuery (DownloadableBase ):
@@ -553,7 +565,7 @@ def __init__(
553565 ra ,
554566 dec ,
555567 radius = 1.0 ,
556- decals_dr = "dr7 " ,
568+ decals_dr = "dr9 " ,
557569 decals_base_dir = "/global/project/projectdirs/cosmo/data/legacysurvey" ,
558570 ):
559571
@@ -607,61 +619,25 @@ def get_ra_dec_range(cls, filename):
607619 ra_max , dec_max = cls .brickname_to_ra_dec (bmax )
608620 return ra_min , ra_max , dec_min , dec_max
609621
610- @staticmethod
611- def is_within (ra , dec , ra_min , ra_max , dec_min , dec_max , margin_ra = 0 , margin_dec = 0 ):
612- return (
613- (ra_min - margin_ra <= ra )
614- & (ra_max + margin_ra >= ra )
615- & (dec_min - margin_dec <= dec )
616- & (dec_max + margin_dec >= dec )
617- )
618-
619- @staticmethod
620- def annotate_catalog (d ):
621- for band in "grz" :
622- BAND = band .upper ()
623- d [band + "_mag" ] = 22.5 - 2.5 * np .log10 (
624- d ["FLUX_" + BAND ] / d ["MW_TRANSMISSION_" + BAND ]
625- )
626- d [band + "_err" ] = (
627- 2.5
628- / np .log (10 )
629- / (d ["FLUX_" + BAND ] / d ["MW_TRANSMISSION_" + BAND ])
630- / np .sqrt (d ["FLUX_IVAR_" + BAND ])
631- )
632- return d
633-
634622 def get_decals_catalog (self ):
623+ center_coord = SkyCoord (self .ra , self .dec , unit = "deg" )
635624 output = []
636625 for sweep_dir in self .sweep_dirs :
637626 for filename in sorted (os .listdir (sweep_dir )):
638627 if not filename .startswith ("sweep-" ) or not filename .endswith (".fits" ):
639628 continue
640- if not self .is_within (
641- self .ra ,
642- self .dec ,
643- * self .get_ra_dec_range (filename ),
644- margin_ra = (self .radius * 1.01 / max (np .cos (np .deg2rad (self .dec )), 1.0e-8 )),
645- margin_dec = (self .radius * 1.01 ),
646- ):
629+ if not is_within (self .ra , self .dec , * self .get_ra_dec_range (filename ), margin = self .radius ):
647630 continue
648-
649631 d = FitsTable (os .path .join (sweep_dir , filename )).read ()
650- sep = (
651- SkyCoord (d ["RA" ], d ["DEC" ], unit = "deg" )
652- .separation (SkyCoord (self .ra , self .dec , unit = "deg" ))
653- .deg
654- )
655- d = d [sep <= self .radius ]
656- if len (d ):
657- output .append (d )
658- del d
632+ mask = SkyCoord (d ["RA" ], d ["DEC" ], unit = "deg" ).separation (center_coord ).deg <= self .radius
633+ if mask .any ():
634+ output .append (d [mask ])
635+ del d , mask
659636
660637 if not output :
661638 return Table ()
662639
663- output = vstack (output , "exact" )
664- return self .annotate_catalog (output )
640+ return vstack (output , "exact" )
665641
666642 def download_as_file (self , file_path , overwrite = False , compress = True ):
667643 if os .path .isfile (file_path ) and not overwrite :
0 commit comments