# Source code for analysis.PreProcess

import numpy as np
import pandas as pd
from tqdm import tqdm
import os

from astropy.io import fits
from astropy.table import Table, vstack
from astroquery.gaia import Gaia
from astropy import units as u
from astropy.table import join

from scipy.stats import norm
import warnings
from astropy.units import UnitsWarning

import logging
logging.basicConfig(level=logging.INFO)

def convert_to_astropy_table(data):
    """
    Convert input data to an Astropy Table.

    Parameters
    ----------
    data : str, np.ndarray, pd.DataFrame, or Table
        Input data. One of:

        - a file path whose extension (matched case-insensitively) is
          ``.csv``, ``.fits``/``.fit`` or ``.txt``/``.dat``;
        - a NumPy recarray or structured array;
        - a Pandas DataFrame;
        - an Astropy Table.

    Returns
    -------
    Table
        The converted table. A Table input is returned as-is (the same
        object, not a copy).

    Raises
    ------
    ValueError
        If the file extension is unsupported, or if reading the file fails
        (the underlying exception is chained as ``__cause__``).
    TypeError
        If the input data type is unsupported.
    """
    if isinstance(data, str):
        # Lower-case only for the extension test; the path itself is passed
        # to the reader untouched.
        file_ext = os.path.splitext(data.lower())[1]

        if file_ext == ".csv":
            read_format = "csv"
        elif file_ext in (".fits", ".fit"):
            read_format = "fits"
        elif file_ext in (".txt", ".dat"):
            read_format = "ascii"
        else:
            # Raised outside the try block so it is not re-labelled as a
            # read failure by the handler below.
            raise ValueError(
                f"Unsupported file type: {file_ext}. "
                "Supported types: CSV, FITS, TXT, DAT."
            )

        try:
            if read_format == "fits":
                # Suppress unit warnings only when reading FITS
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=UnitsWarning)
                    return Table.read(data, format="fits")
            return Table.read(data, format=read_format)
        except Exception as e:
            # Keep the single ValueError contract for callers, but preserve
            # the original traceback for debugging.
            raise ValueError(f"Failed to read file {data}: {e}") from e

    # Already an Astropy Table: return as-is
    if isinstance(data, Table):
        return data

    if isinstance(data, pd.DataFrame):
        return Table.from_pandas(data)

    # recarrays, plus plain ndarrays that carry named fields
    if isinstance(data, np.recarray) or (
        isinstance(data, np.ndarray) and data.dtype.names is not None
    ):
        return Table(data)

    raise TypeError(
        "Unsupported data type. Must be an Astropy Table, NumPy recarray, "
        "Pandas DataFrame, or a valid file path."
    )
def galah_filter(star_data_in, dynamics_data_in, gaia_data_in, save_path=None):
    """
    Apply quality, orbital and distance-uncertainty cuts to a GALAH sample.

    The GALAH table is first cut on spectroscopic quality and abundance
    quality, then the matching rows of the dynamics and Gaia tables are
    attached and a second set of cuts on orbit and distance uncertainty is
    applied.

    Parameters
    ----------
    star_data_in : str, Table, np.recarray, or pd.DataFrame
        GALAH stellar data (file path or in-memory table). A Table passed in
        is not modified; derived columns are added to a shallow copy.
    dynamics_data_in : str, Table, np.recarray, or pd.DataFrame
        Dynamics table with ``sobject_id``, ``Energy``, ``Energy_5``,
        ``Energy_95``, ``ecc``, ``R_ap``, ``J_R``, ``L_Z`` and ``J_Z``.
    gaia_data_in : str, Table, np.recarray, or pd.DataFrame
        Gaia table with ``sobject_id``, ``r_med_photogeo``, ``r_lo_photogeo``,
        ``r_hi_photogeo`` and ``phot_g_mean_mag``.
    save_path : str, optional
        If given, the filtered table is written there as FITS (overwriting).

    Returns
    -------
    Table
        The filtered GALAH table with the dynamics columns (``Energy``,
        ``e_Energy``, ``Eccen``, ``R_ap``, ``J_R``, ``L_Z``, ``J_Z``), the
        Gaia columns listed above and the derived ratios ``Mg_Cu``,
        ``Mg_Mn``, ``Ba_Eu`` (with ``e_`` errors) attached.

    Raises
    ------
    ValueError
        If a required column is missing, or if after the stage 1 cuts the
        dynamics/Gaia rows selected by ``sobject_id`` do not match the GALAH
        rows one-to-one (missing or duplicated IDs).

    Notes
    -----
    Stage 1 (GALAH only):

    - ``flag_sp == 0``, ``snr_c3_iraf > 30``, ``logg < 3.0``.
    - For each of ``fe_h``, ``alpha_fe``, ``Na_fe``, ``Al_fe``, ``Mn_fe``,
      ``Y_fe``, ``Ba_fe``, ``Eu_fe``: ``flag_X == 0`` and ``e_X < 0.2``.
    - For each derived ratio ``Mg_Cu``, ``Mg_Mn``, ``Ba_Eu``: the value is
      computed as the difference of the two ``*_fe`` columns if the column is
      absent, the error as their quadrature sum if ``e_<ratio>`` is absent,
      and the flag test falls back to both parent flags being 0 if
      ``flag_<ratio>`` is absent. Require flag OK and ``e_<ratio> < 0.2``.

    Stage 2 (after attaching dynamics and Gaia columns):

    - ``Eccen > 0.85``, ``Energy < 0``, ``R_ap > 5``.
    - ``r_hi_photogeo - r_med_photogeo < 1500`` and
      ``r_med_photogeo - r_lo_photogeo < 1500`` (the original author's
      comment describes this as 1.5 kpc, i.e. the columns are presumably
      in parsec).

    ``e_Energy`` is derived from the 5th/95th percentiles assuming a normal
    error distribution.
    """
    abundance_err_max = 0.2

    # Ensure every input is an Astropy Table
    star_data = convert_to_astropy_table(star_data_in)
    dynamics_data = convert_to_astropy_table(dynamics_data_in)
    gaia_data = convert_to_astropy_table(gaia_data_in)

    # Derived ratio columns are added below. When the caller handed us a
    # Table, add them to a shallow copy (shared data, separate column list)
    # instead of mutating the caller's object.
    if star_data is star_data_in:
        star_data = star_data.copy(copy_data=False)

    # Store initial number of stars
    initial_star_count = len(star_data)
    logging.info(f"Initial number of stars: {initial_star_count}")

    # ------------------ REQUIRED KEYS CHECK ------------------
    required_keys = {
        "star_data": [
            "sobject_id", "flag_sp", "snr_c3_iraf", "logg",
            "flag_fe_h", "e_fe_h", "flag_alpha_fe", "e_alpha_fe",
            "flag_Na_fe", "e_Na_fe", "flag_Al_fe", "e_Al_fe",
            "flag_Mn_fe", "e_Mn_fe", "flag_Y_fe", "e_Y_fe",
            "flag_Ba_fe", "e_Ba_fe", "flag_Eu_fe", "e_Eu_fe",
            "flag_Mg_fe", "e_Mg_fe", "flag_Cu_fe", "e_Cu_fe",
            "Mg_fe", "Cu_fe", "Mn_fe", "Ba_fe", "Eu_fe",
        ],
        "dynamics_data": [
            "sobject_id", "Energy", "Energy_5", "Energy_95", "ecc", "R_ap",
            "J_R", "L_Z", "J_Z",
        ],
        # phot_g_mean_mag is copied into the output below, so validate it
        # here rather than failing later with a bare KeyError.
        "gaia_data": [
            "sobject_id", "r_med_photogeo", "r_lo_photogeo", "r_hi_photogeo",
            "phot_g_mean_mag",
        ],
    }

    def check_missing_keys(dataset, dataset_name):
        missing_keys = [
            key for key in required_keys[dataset_name]
            if key not in dataset.colnames
        ]
        if missing_keys:
            raise ValueError(
                f"Missing required columns in {dataset_name}: {missing_keys}"
            )

    check_missing_keys(star_data, "star_data")
    check_missing_keys(dynamics_data, "dynamics_data")
    check_missing_keys(gaia_data, "gaia_data")

    # ------------------ Begin Chemical Filtering (GALAH) ------------------
    # 1. Stellar-parameter flag, 2. signal-to-noise, 3. surface gravity
    stage1_masks = [
        star_data['flag_sp'] == 0,
        star_data['snr_c3_iraf'] > 30,
        star_data['logg'] < 3.0,
    ]

    # 4. Element abundance filters: flag must be 0 and error below threshold
    for element in ("fe_h", "alpha_fe", "Na_fe", "Al_fe",
                    "Mn_fe", "Y_fe", "Ba_fe", "Eu_fe"):
        flag_ok = star_data[f'flag_{element}'] == 0
        err_ok = star_data[f'e_{element}'] < abundance_err_max
        stage1_masks.append(flag_ok & err_ok)

    # 5. Derived ratio filters: [Mg/Cu], [Mg/Mn], [Ba/Eu]
    for numer, denom in (("Mg", "Cu"), ("Mg", "Mn"), ("Ba", "Eu")):
        ratio = f"{numer}_{denom}"

        # Only fill in what the catalogue does not already provide
        if ratio not in star_data.colnames:
            star_data[ratio] = star_data[f'{numer}_fe'] - star_data[f'{denom}_fe']
        if f'e_{ratio}' not in star_data.colnames:
            star_data[f'e_{ratio}'] = np.sqrt(
                star_data[f'e_{numer}_fe']**2 + star_data[f'e_{denom}_fe']**2
            )

        if f'flag_{ratio}' not in star_data.colnames:
            # No dedicated flag: both parent abundances must be unflagged
            numer_flag_ok = star_data[f'flag_{numer}_fe'] == 0
            denom_flag_ok = star_data[f'flag_{denom}_fe'] == 0
            flag_ok = numer_flag_ok & denom_flag_ok
        else:
            flag_ok = star_data[f'flag_{ratio}'] == 0

        err_ok = star_data[f'e_{ratio}'] < abundance_err_max
        stage1_masks.append(flag_ok & err_ok)

    # ------------------ Apply stage 1 filters ------------------
    stage1_keep = stage1_masks[0]
    for mask in stage1_masks[1:]:
        stage1_keep = stage1_keep & mask
    star_data = star_data[stage1_keep]

    # ------------------ Process tables so they can be combined ------------------
    # Order remaining stars by object ID
    star_data = star_data[np.argsort(star_data['sobject_id'])]

    # Keep only the dynamics / Gaia rows whose ID survived stage 1
    dynamics_filter = np.isin(dynamics_data['sobject_id'], star_data['sobject_id'])
    gaia_filter = np.isin(gaia_data['sobject_id'], star_data['sobject_id'])
    dynamics_data = dynamics_data[dynamics_filter]
    gaia_data = gaia_data[gaia_filter]

    # Order them by object ID so rows line up with star_data
    dynamics_data = dynamics_data[np.argsort(dynamics_data['sobject_id'])]
    gaia_data = gaia_data[np.argsort(gaia_data['sobject_id'])]

    # The columns are copied positionally below, so the three tables must
    # match row-for-row; missing or duplicated IDs are reported, not repaired.
    if len(star_data) != len(dynamics_data) or len(star_data) != len(gaia_data):
        raise ValueError("Mismatch in number of rows between filtered star data, dynamics data, and Gaia data.")

    if (not np.array_equal(star_data['sobject_id'], dynamics_data['sobject_id'])
            or not np.array_equal(star_data['sobject_id'], gaia_data['sobject_id'])):
        raise ValueError("sobject_id mismatch between datasets. Ensure they have the same order and unique IDs.")

    # ------------------ Add dynamics data to central GALAH table ------------------
    star_data['Energy'] = dynamics_data['Energy']

    # Assume the energy error is normally distributed: convert the
    # 5th-95th percentile width into one standard deviation.
    star_data['e_Energy'] = (
        (dynamics_data['Energy_95'] - dynamics_data['Energy_5'])
        / (norm.ppf(0.95) - norm.ppf(0.05))
    )

    # Eccentricity (renamed), apocenter and action variables
    star_data['Eccen'] = dynamics_data['ecc']
    for column in ('R_ap', 'J_R', 'L_Z', 'J_Z'):
        star_data[column] = dynamics_data[column]

    # ------------------ Add Gaia data to central GALAH table ------------------
    # Photogeometric distances are used rather than purely photometric ones
    for column in ('r_med_photogeo', 'r_lo_photogeo', 'r_hi_photogeo',
                   'phot_g_mean_mag'):
        star_data[column] = gaia_data[column]

    # ------------------ Filter Eccentricity, Energy and Apocenter ------------------
    ecc_filter = star_data['Eccen'] > 0.85
    energy_filter = star_data['Energy'] < 0
    apocenter_filter = star_data['R_ap'] > 5

    # ------------------ Filter for distance uncertainty ------------------
    # Upper and lower uncertainties are tested separately
    dist_err_filter_hi = (star_data['r_hi_photogeo'] - star_data['r_med_photogeo']) < 1500
    dist_err_filter_lo = (star_data['r_med_photogeo'] - star_data['r_lo_photogeo']) < 1500

    # ------------------ Apply stage 2 filters ------------------
    star_data = star_data[ecc_filter & energy_filter & apocenter_filter
                          & dist_err_filter_hi & dist_err_filter_lo]

    # Store final number of stars
    final_star_count = len(star_data)
    logging.info(f"Final number of stars: {final_star_count}")
    # An empty input table has no meaningful retained fraction
    if initial_star_count:
        logging.info(f"Fraction retained: {final_star_count / initial_star_count:.2%}")

    # ------------------ Save filtered data if path provided ------------------
    if save_path:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UnitsWarning)
            star_data.write(save_path, format="fits", overwrite=True)
        logging.info(f"Filtered dataset saved to {save_path}")

    return star_data
def apogee_filter(star_data_in, SQL=False, save_path=None):
    """
    Apply quality, abundance and orbital cuts to an APOGEE sample.

    Optionally (``SQL=True``) Gaia photogeometric distances are fetched with
    ``astroquery.gaia`` and stars with large distance uncertainties are
    removed as well.

    Parameters
    ----------
    star_data_in : str, Table, np.recarray, or pd.DataFrame
        APOGEE stellar data (file path or in-memory table). A Table passed in
        is not modified; derived columns are added to a shallow copy.
    SQL : bool, optional
        If True, query ``external.gaiaedr3_distance`` for ``r_med_photogeo``,
        ``r_lo_photogeo`` and ``r_hi_photogeo`` using the IDs in
        ``GAIAEDR3_SOURCE_ID`` (preferred) or ``dr3_source_id``, drop stars
        with no returned distance, and apply the distance-uncertainty cut.
        Defaults to False.
    save_path : str, optional
        If given, the filtered table is written there as FITS (overwriting).

    Returns
    -------
    Table
        The filtered table. With ``SQL=True`` it is sorted by the Gaia ID
        column and carries the three distance columns.

    Raises
    ------
    ValueError
        If a required column is missing, if ``SQL=True`` and no Gaia ID
        column is present, or if the returned Gaia IDs do not match the
        remaining stars one-to-one (e.g. duplicated source IDs).

    Notes
    -----
    Stage 1:

    - ``extratarg == 0`` (labelled "Main Red Stars" by the original author)
      and ``logg < 3.0``.
    - ``fe_h_flag == 0`` and ``fe_h_err < 0.1``; ``al_fe_flag == 0`` and
      ``al_fe_err < 0.1``; ``ce_fe_flag == 0`` and ``ce_fe_err < 0.15``.
    - [Mg/Mn]: ``mg_mn`` is computed as ``mg_fe - mn_fe`` if absent and
      ``mg_mn_err`` as the quadrature sum of the parent errors if absent.
      Require ``mg_mn_flag == 0`` (or both parent flags 0 when that column is
      absent) and ``mg_mn_err < 0.1``.
    - [alpha/Fe]: require ``alpha_fe_flag == 0``, or, when that column is
      absent, the O, Mg, Si, Ca and Ti flags all 0. Require
      ``alpha_m_err < 0.1``, or, when that column is absent, each of the five
      element errors ``< 0.1``.

    Stage 2: ``ecc_50 > 0.85`` and ``E_50 < 0``.

    Stage 3 (``SQL=True`` only): ``r_hi_photogeo - r_med_photogeo < 1500``
    and ``r_med_photogeo - r_lo_photogeo < 1500``.

    After filtering, the number of rows failing each individual mask is
    logged. Stage 2 and stage 3 counts are relative to the rows that reached
    that stage.
    """
    alpha_elements = ("o", "mg", "si", "ca", "ti")
    distance_columns = ('r_med_photogeo', 'r_lo_photogeo', 'r_hi_photogeo')

    # Ensure the input is an Astropy Table
    star_data = convert_to_astropy_table(star_data_in)

    # Derived columns are added below. When the caller handed us a Table,
    # add them to a shallow copy instead of mutating the caller's object.
    if star_data is star_data_in:
        star_data = star_data.copy(copy_data=False)

    # Store initial number of stars
    initial_star_count = len(star_data)
    logging.info(f"Initial number of stars: {initial_star_count}")

    # ------------------ REQUIRED KEYS CHECK ------------------
    # NOTE(review): `alpha_fe_err` is required here but the error cut below
    # reads `alpha_m_err`; confirm which column is intended.
    required_keys = [
        "extratarg", "logg",
        "fe_h", "fe_h_err", "fe_h_flag",
        "al_fe", "al_fe_err", "al_fe_flag",
        "ce_fe", "ce_fe_err", "ce_fe_flag",
        "mg_fe", "mg_fe_err", "mg_fe_flag",
        "mn_fe", "mn_fe_err", "mn_fe_flag",
        "alpha_fe_err", "ecc_50", "E_50",
    ]

    has_alpha_flag = 'alpha_fe_flag' in star_data.colnames
    has_alpha_err = 'alpha_m_err' in star_data.colnames

    # The per-element alpha columns are only read when the combined column is
    # absent, so they are only required in that case.
    if not has_alpha_flag:
        required_keys.extend(
            f"{el}_fe_flag" for el in alpha_elements
            if f"{el}_fe_flag" not in required_keys
        )
    if not has_alpha_err:
        required_keys.extend(
            f"{el}_fe_err" for el in alpha_elements
            if f"{el}_fe_err" not in required_keys
        )

    # With SQL=True a Gaia source ID column is needed for the distance query
    if SQL:
        if "GAIAEDR3_SOURCE_ID" in star_data.colnames:
            gaia_id_col = "GAIAEDR3_SOURCE_ID"
        elif "dr3_source_id" in star_data.colnames:
            gaia_id_col = "dr3_source_id"
        else:
            raise ValueError("SQL=True, but no Gaia source ID column (`GAIAEDR3_SOURCE_ID` or `dr3_source_id`) found in dataset.")
        required_keys.append(gaia_id_col)

    missing_keys = [key for key in required_keys if key not in star_data.colnames]
    if missing_keys:
        raise ValueError(f"Missing required columns in star_data: {missing_keys}")

    # ------------------ Filtering Data ------------------
    # 1. Main Red Stars filter
    mrs_filter = star_data['extratarg'] == 0

    # 2. Giants only
    rg_filter = star_data['logg'] < 3.0

    # 3. Element abundance filters
    fe_h_filter = (star_data['fe_h_flag'] == 0) & (star_data['fe_h_err'] < 0.1)
    al_fe_filter = (star_data['al_fe_flag'] == 0) & (star_data['al_fe_err'] < 0.1)
    ce_fe_filter = (star_data['ce_fe_flag'] == 0) & (star_data['ce_fe_err'] < 0.15)

    # [Mg/Mn]: fill in value / error only when the catalogue lacks them
    if 'mg_mn' not in star_data.colnames:
        star_data['mg_mn'] = star_data['mg_fe'] - star_data['mn_fe']

    if 'mg_mn_flag' not in star_data.colnames:
        mg_mn_flag_filter = (star_data['mg_fe_flag'] == 0) & (star_data['mn_fe_flag'] == 0)
    else:
        mg_mn_flag_filter = star_data['mg_mn_flag'] == 0

    if 'mg_mn_err' not in star_data.colnames:
        star_data['mg_mn_err'] = np.sqrt(star_data['mg_fe_err']**2 + star_data['mn_fe_err']**2)
    mg_mn_err_filter = star_data['mg_mn_err'] < 0.1
    mg_mn_filter = mg_mn_flag_filter & mg_mn_err_filter

    # [alpha/Fe] flag: combined column if present, else every element flag.
    # The per-element masks are kept for the diagnostics at the end; the dict
    # stays empty when the combined column was used.
    alpha_element_flag_filters = {}
    if not has_alpha_flag:
        for el in alpha_elements:
            alpha_element_flag_filters[el] = star_data[f'{el}_fe_flag'] == 0
        alpha_fe_flag_filter = alpha_element_flag_filters[alpha_elements[0]]
        for el in alpha_elements[1:]:
            alpha_fe_flag_filter = alpha_fe_flag_filter & alpha_element_flag_filters[el]
    else:
        alpha_fe_flag_filter = star_data['alpha_fe_flag'] == 0

    # [alpha/Fe] error: combined column if present, else every element error
    if not has_alpha_err:
        alpha_fe_err_filter = star_data[f'{alpha_elements[0]}_fe_err'] < 0.1
        for el in alpha_elements[1:]:
            alpha_fe_err_filter = alpha_fe_err_filter & (star_data[f'{el}_fe_err'] < 0.1)
    else:
        alpha_fe_err_filter = star_data['alpha_m_err'] < 0.1

    alpha_fe_filter = alpha_fe_flag_filter & alpha_fe_err_filter

    # ------------------ Applying Stage 1 Filters ------------------
    star_data = star_data[mrs_filter & rg_filter & fe_h_filter & al_fe_filter
                          & ce_fe_filter & mg_mn_filter & alpha_fe_filter]

    # ------------------ Filter Eccentricity and Energy ------------------
    # Rows with NaN in either column fail the comparison and are dropped too.
    ecc_filter = star_data['ecc_50'] > 0.85
    energy_filter = star_data['E_50'] < 0

    # ------------------ Apply stage 2 filters ------------------
    star_data = star_data[ecc_filter & energy_filter]

    # ------------------ SQL-Based Gaia Distance Query ------------------
    if SQL:
        logging.info("Querying Gaia for distances...")

        gaia_ids = np.array(star_data[gaia_id_col])

        if len(gaia_ids) == 0:
            # Nothing to query: np.array_split / vstack cannot handle zero
            # chunks, so just attach empty distance columns.
            logging.info("No stars remain after the quality and orbital cuts; skipping Gaia query.")
            for column in distance_columns:
                star_data[column] = np.empty(0, dtype=float)
        else:
            # Split the IDs into chunks so each synchronous query stays small
            query_size = 750
            n_queries = int(np.ceil(len(gaia_ids) / query_size))
            indiv_queries = np.array_split(gaia_ids, n_queries)

            list_query_results = []
            missing_ids_set = set()

            for query in tqdm(indiv_queries, desc="Processing Queries"):
                gaia_id_list = ", ".join(query.astype(str))

                distance_query = f"""
                SELECT source_id, r_med_photogeo, r_lo_photogeo, r_hi_photogeo
                FROM external.gaiaedr3_distance
                WHERE source_id IN ({gaia_id_list});
                """

                job = Gaia.launch_job(distance_query)
                results = job.get_results()

                # Remember the IDs Gaia returned no distance for
                query_ids = set(query)
                returned_ids = set(results['source_id'])
                missing_ids_set.update(query_ids - returned_ids)

                list_query_results.append(results)

            all_query_results = vstack(list_query_results)
            missing_gaia_ids = np.array(list(missing_ids_set))

            # Remove stars with missing Gaia data
            missing_ids_position = np.isin(gaia_ids, missing_gaia_ids)
            star_data = Table(star_data[~missing_ids_position])

            # Sort both tables by Gaia ID so rows line up positionally
            all_query_results.sort('source_id')
            star_data.sort(gaia_id_col)

            if not np.array_equal(star_data[gaia_id_col], all_query_results['source_id']):
                raise ValueError("Mismatch in GAIA IDs - Ensure the IDs match before merging.")

            all_query_results = Table(all_query_results)

            # Merge Gaia distances into the main dataset
            for column in distance_columns:
                star_data[column] = all_query_results[column]

        # ------------------ Distance uncertainty cut ------------------
        dist_err_filter_hi = (star_data['r_hi_photogeo'] - star_data['r_med_photogeo']) < 1500
        dist_err_filter_lo = (star_data['r_med_photogeo'] - star_data['r_lo_photogeo']) < 1500
        star_data = star_data[dist_err_filter_hi & dist_err_filter_lo]

    # ------------------ Report and save ------------------
    final_star_count = len(star_data)
    logging.info(f"Final number of stars: {final_star_count}")
    # An empty input table has no meaningful retained fraction
    if initial_star_count:
        logging.info(f"Fraction retained: {final_star_count / initial_star_count:.2%}")

    if save_path:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UnitsWarning)
            star_data.write(save_path, format="fits", overwrite=True)
        logging.info(f"Filtered dataset saved to {save_path}")

    logging.info("\n=== Filter Diagnostics: Stars Rejected by Each Criterion ===")

    # Diagnostic counts for each mask (each mask counted independently)
    filters = {
        "Main Red Stars (extratarg == 0)": mrs_filter,
        "logg < 3.0": rg_filter,
        "[Fe/H] quality": fe_h_filter,
        "[Al/Fe] quality": al_fe_filter,
        "[Ce/Fe] quality": ce_fe_filter,
        "[Mg/Mn] quality": mg_mn_filter,
        "[alpha/Fe] Overall Cut": alpha_fe_filter,
    }
    # Per-element alpha flags exist only when `alpha_fe_flag` was absent
    for el, mask in alpha_element_flag_filters.items():
        filters[f"[alpha/Fe] - [{el.capitalize()}/Fe]"] = mask
    filters["Eccentricity > 0.85"] = ecc_filter
    filters["Energy < 0"] = energy_filter

    for name, mask in filters.items():
        n_failed = len(mask) - np.sum(mask)
        logging.info(f"{name:30s}{n_failed:4d} stars removed")

    # SQL-based filtering
    if SQL:
        n_sql_failed = len(dist_err_filter_hi) - np.sum(dist_err_filter_hi & dist_err_filter_lo)
        logging.info(f"Gaia SQL distance cut → {n_sql_failed:4d} stars removed")

    return star_data