From c1e071d49e97ac46ddf0b0169d2c3876d9d65dbd Mon Sep 17 00:00:00 2001 From: KatrionaGoldmann Date: Wed, 27 Nov 2024 10:20:00 +0000 Subject: [PATCH 1/3] multithreading --- s3_download_with_inference.py | 11 ++++++ utils/aws_scripts.py | 65 ++++++++++++++++------------------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/s3_download_with_inference.py b/s3_download_with_inference.py index cc3bede..68dd416 100644 --- a/s3_download_with_inference.py +++ b/s3_download_with_inference.py @@ -159,6 +159,8 @@ def download_and_inference( ) print("\N{White Heavy Check Mark}") + num_workers = int(os.getenv("SLURM_CPUS_PER_TASK", os.cpu_count())) + parser = argparse.ArgumentParser( description="Script for downloading and processing images from S3." ) @@ -236,6 +238,12 @@ def download_and_inference( help="The path to the binary model weights", default="./models/v1_localizmodel_2021-08-17-12-06.pt", ) + parser.add_argument( + "--num_workers", + type=int, + default=num_workers, + help="Number of workers for multi-threaded downloads", + ) args = parser.parse_args() @@ -268,6 +276,8 @@ def download_and_inference( print("\033[93m - Not keeping crops\033[0m") crops_interval = None + print(f"\033[93m - Number of workers: {args.num_workers}\033[0m") + download_and_inference( args.country, args.deployment, @@ -276,4 +286,5 @@ def download_and_inference( data_storage_path, args.perform_inference, args.remove_image, + args.num_workers, ) diff --git a/utils/aws_scripts.py b/utils/aws_scripts.py index 4209176..83c2b34 100644 --- a/utils/aws_scripts.py +++ b/utils/aws_scripts.py @@ -9,6 +9,11 @@ from utils.inference_scripts import perform_inf import pandas as pd import sys +import tqdm +import math +from functools import partial +from concurrent.futures import ThreadPoolExecutor +import re def get_deployments(username, password): @@ -210,6 +215,7 @@ def get_objects( csv_file="results.csv", rerun_existing=False, crops_interval=None, + num_workers=1, ): """ Fetch objects from the S3 bucket and download them synchronously in batches. 
@@ -224,10 +230,6 @@ def get_objects( operation_parameters = {"Bucket": bucket_name, "Prefix": prefix} page_iterator = paginator.paginate(**operation_parameters) - progress_bar = tqdm.tqdm( - total=total_files, desc="Download files from server synchronously" - ) - if crops_interval is not None: t = first_dt intervals = [] @@ -242,39 +244,31 @@ def get_objects( if os.path.basename(page.get("Contents", [])[0]["Key"]).startswith("$"): print(f'{page.get("Contents", [])[0]["Key"]} is suspected corrupt, skipping') continue - + for obj in page.get("Contents", []): keys.append(obj["Key"]) - if len(keys) >= batch_size: - download_batch( - s3_client, - bucket_name, - keys, - local_path, - perform_inference, - remove_image, - localisation_model, - binary_model, - order_model, - order_labels, - species_model, - species_labels, - country, - region, - device, - order_data_thresholds, - csv_file, - rerun_existing, - intervals, - ) - keys = [] - progress_bar.update(batch_size) - if keys: + # don't rerun previously analysed images + results_df = pd.read_csv(csv_file, dtype=str) + run_images = [re.sub(r'^.*?dep', 'dep', x) for x in results_df['image_path']] + keys = [x for x in keys if x not in run_images] + + # Divide the keys among workers + chunks = [ + keys[i : i + math.ceil(len(keys) / num_workers)] + for i in range(0, len(keys), math.ceil(len(keys) / num_workers)) + ] + + # Shared progress bar + progress_bar = tqdm.tqdm(total=total_files, desc=f"Download files for {os.path.basename(csv_file).replace('_results.csv', '')}") + + def process_chunk(chunk): + for i in range(0, len(chunk), batch_size): + batch_keys = chunk[i : i + batch_size] download_batch( s3_client, bucket_name, - keys, + batch_keys, local_path, perform_inference, remove_image, @@ -282,16 +276,17 @@ def get_objects( binary_model, order_model, order_labels, - species_model, - species_labels, country, region, device, order_data_thresholds, csv_file, rerun_existing, - intervals, ) - progress_bar.update(len(keys)) + progress_bar.update(len(batch_keys)) + + # Use ThreadPoolExecutor instead of multiprocessing + with ThreadPoolExecutor(max_workers=num_workers) as executor: + executor.map(process_chunk, chunks) progress_bar.close() From 2d91db01c8b768ad9bfbde4662bafeee55003c5a Mon Sep 17 00:00:00 2001 From: KatrionaGoldmann Date: Wed, 27 Nov 2024 17:11:59 +0000 Subject: [PATCH 2/3] generate keys file for batch processing --- ..._deployments.py => 01_print_deployments.py | 2 +- 02_generate_keys.py | 82 +++++++++++++++++++ README.md | 9 +- 3 files changed, 91 insertions(+), 2 deletions(-) rename print_deployments.py => 01_print_deployments.py (95%) create mode 100644 02_generate_keys.py diff --git a/print_deployments.py b/01_print_deployments.py similarity index 95% rename from print_deployments.py rename to 01_print_deployments.py index 903e52d..c208f2d 100644 --- a/print_deployments.py +++ b/01_print_deployments.py @@ -92,7 +92,7 @@ def print_deployments(include_inactive=False, subset_countries=None, print_image for dep in sorted(all_deps): dep_info = [x for x in country_depl if x['deployment_id'] == dep][0] print(f"\033[1m - Deployment ID: {dep_info['deployment_id']}, Name: {dep_info['location_name']}, Deployment Key: '{dep_info['location_name']} - {dep_info['camera_id']}'\033[0m") - print(f" Location ID: {dep_info['location_id']}, Latitute: {dep_info['lat']}, Longitute: {dep_info['lon']}, Camera ID: {dep_info['camera_id']}, System ID: {dep_info['system_id']}, Status: {dep_info['status']}") + print(f" Location ID: {dep_info['location_id']}, 
Country code: {dep_info['country_code'].lower()}, Latitute: {dep_info['lat']}, Longitute: {dep_info['lon']}, Camera ID: {dep_info['camera_id']}, System ID: {dep_info['system_id']}, Status: {dep_info['status']}") # get the number of images for this deployment prefix = f"{dep_info['deployment_id']}/snapshot_images" diff --git a/02_generate_keys.py b/02_generate_keys.py new file mode 100644 index 0000000..0e00c5a --- /dev/null +++ b/02_generate_keys.py @@ -0,0 +1,82 @@ +import boto3 +import argparse +import json + +def list_s3_keys(bucket_name, prefix=""): + """ + List all keys in an S3 bucket under a specific prefix. + + Parameters: + bucket_name (str): The name of the S3 bucket. + prefix (str): The prefix to filter keys (default: ""). + + Returns: + list: A list of S3 object keys. + """ + with open("./credentials.json", encoding="utf-8") as config_file: + aws_credentials = json.load(config_file) + + + session = boto3.Session( + aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"], + region_name=aws_credentials["AWS_REGION"], + ) + s3_client = session.client("s3", endpoint_url=aws_credentials["AWS_URL_ENDPOINT"]) + + keys = [] + continuation_token = None + + while True: + list_kwargs = { + "Bucket": bucket_name, + "Prefix": prefix, + } + if continuation_token: + list_kwargs["ContinuationToken"] = continuation_token + + response = s3_client.list_objects_v2(**list_kwargs) + + # Add object keys to the list + for obj in response.get("Contents", []): + keys.append(obj["Key"]) + + # Check if there are more objects to list + if response.get("IsTruncated"): # If True, there are more results + continuation_token = response["NextContinuationToken"] + else: + break + + return keys + +def save_keys_to_file(keys, output_file): + """ + Save S3 keys to a file, one per line. + + Parameters: + keys (list): List of S3 keys. + output_file (str): Path to the output file. + """ + with open(output_file, "w") as f: + for key in keys: + f.write(key + "\n") + +def main(): + parser = argparse.ArgumentParser(description="Generate a file containing S3 keys from a bucket.") + parser.add_argument("--bucket", type=str, required=True, help="Name of the S3 bucket.") + parser.add_argument("--prefix", type=str, default="", help="Prefix to filter objects (default: '')") + parser.add_argument("--output_file", type=str, default="s3_keys.txt", help="Output file to save S3 keys.") + args = parser.parse_args() + + + + # List keys from the specified S3 bucket and prefix + print(f"Listing keys from bucket '{args.bucket}' with prefix '{args.prefix}'...") + keys = list_s3_keys(args.bucket, args.prefix) + + # Save keys to the output file + save_keys_to_file(keys, args.output_file) + print(f"Saved {len(keys)} keys to {args.output_file}") + +if __name__ == "__main__": + main() diff --git a/README.md b/README.md index 1bf7e4b..65a2326 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ Load the conda env: ```bash source ~/miniforge3/bin/activate -conda activate "~/moth_detector_env/" +conda activate "~/conda_envs/moth_detector_env/" ``` Inferences are run by country and deployment site. 
To run the script, for Costa Rica say, use the following command: @@ -104,6 +104,13 @@ python s3_download_with_inference.py \ --deployment "Garden - 3F1C4908" ``` +### Generating the Keys + +```bash +python 02_generate_keys.py --bucket 'cri' --prefix '' --output_file s3_keys.txt +``` + + ## Running with slurm To run with slurm you need to be logged in on the [scientific nodes](https://help.jasmin.ac.uk/docs/interactive-computing/sci-servers/). From 9fc77a6b5a25ac81d8e6f1d29a7e9e026396c416 Mon Sep 17 00:00:00 2001 From: KatrionaGoldmann Date: Fri, 29 Nov 2024 17:31:35 +0000 Subject: [PATCH 3/3] chunk processing --- .gitignore | 1 + 02_generate_keys.py | 10 +- 03_pre_chop_files.py | 61 +++++++ 04_process_chunks.py | 217 ++++++++++++++++++++++ README.md | 36 +++- s3_download_with_inference.py | 290 ----------------------------- utils/aws_scripts.py | 333 ++++++---------------------------- utils/custom_models.py | 18 +- utils/inference_scripts.py | 30 ++- 9 files changed, 398 insertions(+), 598 deletions(-) create mode 100644 03_pre_chop_files.py create mode 100644 04_process_chunks.py delete mode 100644 s3_download_with_inference.py diff --git a/.gitignore b/.gitignore index a769b8a..7f7ce50 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ credentials.json *.out data/ models/ +keys/ diff --git a/02_generate_keys.py b/02_generate_keys.py index 0e00c5a..f19ed65 100644 --- a/02_generate_keys.py +++ b/02_generate_keys.py @@ -2,7 +2,7 @@ import argparse import json -def list_s3_keys(bucket_name, prefix=""): +def list_s3_keys(bucket_name, deployment_id=""): """ List all keys in an S3 bucket under a specific prefix. @@ -30,7 +30,7 @@ def list_s3_keys(bucket_name, prefix=""): while True: list_kwargs = { "Bucket": bucket_name, - "Prefix": prefix, + "Prefix": deployment_id, } if continuation_token: list_kwargs["ContinuationToken"] = continuation_token @@ -64,15 +64,15 @@ def save_keys_to_file(keys, output_file): def main(): parser = argparse.ArgumentParser(description="Generate a file containing S3 keys from a bucket.") parser.add_argument("--bucket", type=str, required=True, help="Name of the S3 bucket.") - parser.add_argument("--prefix", type=str, default="", help="Prefix to filter objects (default: '')") + parser.add_argument("--deployment_id", type=str, default="", help="The deployment id to filter objects. If set to '' then all deployments are used. (default: '')") parser.add_argument("--output_file", type=str, default="s3_keys.txt", help="Output file to save S3 keys.") args = parser.parse_args() # List keys from the specified S3 bucket and prefix - print(f"Listing keys from bucket '{args.bucket}' with prefix '{args.prefix}'...") - keys = list_s3_keys(args.bucket, args.prefix) + print(f"Listing keys from bucket '{args.bucket}' with deployment '{args.deployment_id}'...") + keys = list_s3_keys(args.bucket, args.deployment_id) # Save keys to the output file save_keys_to_file(keys, args.output_file) diff --git a/03_pre_chop_files.py b/03_pre_chop_files.py new file mode 100644 index 0000000..d640860 --- /dev/null +++ b/03_pre_chop_files.py @@ -0,0 +1,61 @@ +import json +import argparse +from math import ceil +import os + +def load_workload(input_file, file_extension): + """ + Load workload from a file. Assumes each line contains an S3 key. 
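+
+    Args:
+        input_file (str): Path to a text file containing one S3 key per line.
+        file_extension (str): Suffix filter applied with str.endswith.
+
+    Returns:
+        list: Matching keys, excluding suspected-corrupt files (basenames
+            starting with '$') and legacy recycle-bin uploads.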
+ """ + with open(input_file, 'r') as f: + all_keys = [line.strip() for line in f.readlines()] + subset_keys = [x for x in all_keys if x.endswith(file_extension)] + + # remove corrupt keys + subset_keys = [x for x in subset_keys if not os.path.basename(x).startswith('$')] + + # remove keys uploaded from the recycle bin (legacy code) + subset_keys = [x for x in subset_keys if not 'recycle' in x] + + return subset_keys + +def split_workload(keys, chunk_size): + """ + Split a list of keys into chunks of a specified size. + """ + num_chunks = ceil(len(keys) / chunk_size) + chunks = { + str(i + 1): {"keys": keys[i * chunk_size: (i + 1) * chunk_size]} + for i in range(num_chunks) + } + return chunks + +def save_chunks(chunks, output_file): + """ + Save chunks to a JSON file. + """ + with open(output_file, 'w') as f: + json.dump(chunks, f, indent=4) + +def main(): + parser = argparse.ArgumentParser(description="Pre-chop S3 workload into manageable chunks.") + parser.add_argument("--input_file", type=str, required=True, help="Path to file containing S3 keys, one per line.") + parser.add_argument("--file_extension", type=str, required=True, default="jpg|jpeg", help="File extensions to be chuncked. If empty, all extensions used.") + parser.add_argument("--chunk_size", type=int, default=100, help="Number of keys per chunk.") + parser.add_argument("--output_file", type=str, required=True, help="Path to save the output JSON file.") + args = parser.parse_args() + + # Load the workload from the input file + keys = load_workload(args.input_file, args.file_extension) + + # Split the workload into chunks + chunks = split_workload(keys, args.chunk_size) + + # Save the chunks to a JSON file + save_chunks(chunks, args.output_file) + + print(f"Successfully split {len(keys)} keys into {len(chunks)} chunks.") + print(f"Chunks saved to {args.output_file}") + +if __name__ == "__main__": + main() diff --git a/04_process_chunks.py b/04_process_chunks.py new file mode 100644 index 0000000..5020393 --- /dev/null +++ b/04_process_chunks.py @@ -0,0 +1,217 @@ +import argparse +import boto3 +import json +import os +from utils.inference_scripts import perform_inf +from boto3.s3.transfer import TransferConfig +import torch +from utils.custom_models import load_models + + +# Transfer configuration for optimised S3 download +transfer_config = TransferConfig( + max_concurrency=20, # Increase the number of concurrent transfers + multipart_threshold=8 * 1024 * 1024, # 8MB + max_io_queue=1000, + io_chunksize=262144, # 256KB +) + +def initialise_session(credentials_file="credentials.json"): + """ + Load AWS and API credentials from a configuration file and initialise an AWS session. + + Args: + credentials_file (str): Path to the credentials JSON file. + + Returns: + boto3.Client: Initialised S3 client. 
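+            The client is created against the endpoint URL
+            (AWS_URL_ENDPOINT) given in the credentials file.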
+ """ + with open(credentials_file, encoding="utf-8") as config_file: + aws_credentials = json.load(config_file) + session = boto3.Session( + aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"], + region_name=aws_credentials["AWS_REGION"], + ) + client = session.client("s3", endpoint_url=aws_credentials["AWS_URL_ENDPOINT"]) + return client + + +def download_and_analyse( + keys, + output_dir, + bucket_name, + client, + remove_image=True, + perform_inference=True, + localisation_model=None, + binary_model=None, + order_model=None, + order_labels=None, + species_model=None, + species_labels=None, + device=None, + order_data_thresholds=None, + csv_file="results.csv", +): + """ + Download images from S3 and perform analysis. + + Args: + keys (list): List of S3 keys to process. + output_dir (str): Directory to save downloaded files and results. + bucket_name (str): S3 bucket name. + client (boto3.Client): Initialised S3 client. + Other args: Parameters for inference and analysis. + """ + # Ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + + for key in keys: + local_path = os.path.join(output_dir, os.path.basename(key)) + print(f"Downloading {key} to {local_path}") + client.download_file(bucket_name, key, local_path, Config=transfer_config) + + # Perform image analysis if enabled + print(f"Analysing {local_path}") + if perform_inference: + perform_inf( + local_path, + bucket_name=bucket_name, + loc_model=localisation_model, + binary_model=binary_model, + order_model=order_model, + order_labels=order_labels, + regional_model=species_model, + regional_category_map=species_labels, + device=device, + order_data_thresholds=order_data_thresholds, + csv_file=csv_file, + save_crops=True, + ) + # Remove the image if cleanup is enabled + if remove_image: + os.remove(local_path) + + +def main( + chunk_id, + json_file, + output_dir, + bucket_name, + credentials_file="credentials.json", + remove_image=True, + perform_inference=True, + localisation_model=None, + binary_model=None, + order_model=None, + order_labels=None, + species_model=None, + species_labels=None, + device=None, + order_data_thresholds=None, + csv_file="results.csv", +): + """ + Main function to process a specific chunk of S3 keys. + + Args: + chunk_id (str): ID of the chunk to process (e.g., chunk_0). + json_file (str): Path to the JSON file with key chunks. + output_dir (str): Directory to save results. + bucket_name (str): S3 bucket name. + Other args: Parameters for download and analysis. 
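+
+    Raises:
+        ValueError: If chunk_id is not present in the JSON file.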
+ """ + with open(json_file, "r") as f: + chunks = json.load(f) + + if chunk_id not in chunks: + raise ValueError(f"Chunk ID {chunk_id} not found in JSON file.") + + client = initialise_session(credentials_file) + + keys = chunks[chunk_id]['keys'] + download_and_analyse( + keys=keys, + output_dir=output_dir, + bucket_name=bucket_name, + client=client, + remove_image=remove_image, + perform_inference=perform_inference, + localisation_model=localisation_model, + binary_model=binary_model, + order_model=order_model, + order_labels=order_labels, + species_model=species_model, + species_labels=species_labels, + device=device, + order_data_thresholds=order_data_thresholds, + csv_file=csv_file, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process a specific chunk of S3 keys.") + parser.add_argument("--chunk_id", required=True, help="ID of the chunk to process (e.g., 0, 1, 2, 3).") + parser.add_argument("--json_file", required=True, help="Path to the JSON file with key chunks.") + parser.add_argument("--output_dir", required=True, help="Directory to save downloaded files and analysis results.", default="./data/") + parser.add_argument("--bucket_name", required=True, help="Name of the S3 bucket.") + parser.add_argument("--credentials_file", default="credentials.json", help="Path to AWS credentials file.") + parser.add_argument("--remove_image", action="store_true", help="Remove images after processing.") + parser.add_argument("--perform_inference", action="store_true", help="Enable inference.") + parser.add_argument("--localisation_model_path", type=str, help="Path to the localisation model weights.", default="./models/v1_localizmodel_2021-08-17-12-06.pt") + parser.add_argument("--binary_model_path", type=str, help="Path to the binary model weights.", default="./models/moth-nonmoth-effv2b3_20220506_061527_30.pth") + parser.add_argument("--order_model_path", type=str, help="Path to the order model weights.", default="./models/dhc_best_128.pth") + parser.add_argument("--order_labels", type=str, help="Path to the order labels file.") + parser.add_argument("--species_model_path", type=str, help="Path to the species model weights.", default="./models/turing-costarica_v03_resnet50_2024-06-04-16-17_state.pt") + parser.add_argument("--species_labels", type=str, help="Path to the species labels file.", + default="./models/03_costarica_data_category_map.json") + parser.add_argument("--device", type=str, default="cpu", help="Device to run inference on (e.g., cpu or cuda).") + parser.add_argument("--order_thresholds_path", type=str, help="Path to the order data thresholds file.", default="./models/thresholdsTestTrain.csv") + parser.add_argument("--csv_file", default="results.csv", help="Path to save analysis results.") + + + args = parser.parse_args() + + if torch.cuda.is_available(): + device = torch.device("cuda") + print( + "\033[95m\033[1mCuda available, using GPU " + + "\N{White Heavy Check Mark}\033[0m\033[0m" + ) + else: + device = torch.device("cpu") + print( + "\033[95m\033[1mCuda not available, using CPU " + + "\N{Cross Mark}\033[0m\033[0m" + ) + + models = load_models( + device, + getattr(args, 'localisation_model_path'), + getattr(args, 'binary_model_path'), + getattr(args, 'order_model_path'), + getattr(args, 'order_thresholds_path'), + getattr(args, 'species_model_path'), + getattr(args, 'species_labels') + ) + + + main( + chunk_id=args.chunk_id, + json_file=args.json_file, + output_dir=args.output_dir, + bucket_name=args.bucket_name, + 
credentials_file=args.credentials_file, + remove_image=args.remove_image, + perform_inference=args.perform_inference, + localisation_model=models['localisation_model'], + binary_model=models['classification_model'], + order_model=models['order_model'], + order_labels=models['order_model_labels'], + order_data_thresholds=models['order_model_thresholds'], + species_model=models['species_model'], + species_labels=models['species_model_labels'], + device=device, + csv_file=args.csv_file, + ) diff --git a/README.md b/README.md index 65a2326..9d6c5f1 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ source ~/miniforge3/bin/activate conda activate "~/conda_envs/moth_detector_env/" ``` -Inferences are run by country and deployment site. To run the script, for Costa Rica say, use the following command: + -### Listing Available Deployments +The multi-core pipeline is run in several steps: + +1. Listing All Available Deployments +2. Generate Key Files +3. Chop the keys into chunks +4. Analyse the chunks + +### 01. Listing Available Deployments To find information about the available deployments you can use the print_deployments function. For all deployments: @@ -104,10 +111,29 @@ python s3_download_with_inference.py \ --deployment "Garden - 3F1C4908" ``` -### Generating the Keys +### 02. Generating the Keys + +```bash +python 02_generate_keys.py --bucket 'cri' --deployment_id 'dep000031' --output_file './keys/dep000031_keys.txt' +``` + +### 03. Pre-chop the Keys into Chunks + +```bash +python 03_pre_chop_files.py --input_file './keys/dep000031_keys.txt' --file_extension 'jpg|jpeg' --chunk_size 100 --output_file './keys/dep000031_workload_chunks.json' +``` + +### 04. Process the Chunked Files ```bash -python 02_generate_keys.py --bucket 'cri' --prefix '' --output_file s3_keys.txt +python 04_process_chunks.py \ + --chunk_id 1 \ + --json_file './keys/dep000031_workload_chunks.json' \ + --output_dir './data/dep000031' \ + --bucket_name 'cri' \ + --credentials_file './credentials.json' \ + --perform_inference \ + --remove_image ``` diff --git a/s3_download_with_inference.py b/s3_download_with_inference.py deleted file mode 100644 index 68dd416..0000000 --- a/s3_download_with_inference.py +++ /dev/null @@ -1,290 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -This script downloads files from an S3 bucket synchronously and performs -inference on the images. AWS credentials, S3 bucket name, and UKCEH API -credentials are loaded from a configuration file (credentials.json). -""" - -import json -import boto3 -import torch -import pandas as pd -import os -import argparse - -from utils.aws_scripts import get_objects, get_deployments -from utils.custom_models import load_models - - -def download_and_inference( - country, - deployment, - crops_interval, - rerun_existing, - local_directory_path, - perform_inference, - remove_image, -): - """ - Display the main menu and handle user interaction. 
- """ - - output_cols = [ - "image_path", - "bucket_name", - "analysis_datetime", - "box_score", - "box_label", - "x_min", - "y_min", - "x_max", - "y_max", - "class_name", - "class_confidence", - "order_name", - "order_confidence", - "species_name", - "species_confidence", - "cropped_image_path", - ] - - username = aws_credentials["UKCEH_username"] - password = aws_credentials["UKCEH_password"] - - print(f"\033[93m - Removing images after analysis: {remove_image}\033[0m") - print(f"\033[93m - Performing inference: {perform_inference}\033[0m") - print(f"\033[93m - Rerun existing inferences: {rerun_existing}\033[0m") - - all_deployments = get_deployments(username, password) - - print(f"\033[96m\033[1mAnalysing: {country}\033[0m\033[0m") - - country_deployments = [ - f"{d['location_name']} - {d['camera_id']}" - for d in all_deployments - if d["country"] == country and d["status"] == "active" - ] - - s3_bucket_name = [ - d["country_code"] - for d in all_deployments - if d["country"] == country and d["status"] == "active" - ][0].lower() - - if deployment == "All": - deps = country_deployments - else: - deps = [deployment] - - # loop through each deployment - for region in deps: - location_name, camera_id = region.split(" - ") - dep_id = [ - d["deployment_id"] - for d in all_deployments - if d["country"] == country - and d["location_name"] == location_name - and d["camera_id"] == camera_id - and d["status"] == "active" - ][0] - print(f"\033[96m - Deployment: {region}, id: {dep_id} \033[0m") - - # if the file doesnt exist, print headers - csv_file = os.path.abspath( - f"{local_directory_path}/{dep_id}/{dep_id}_results.csv" - ) - os.makedirs(os.path.dirname(csv_file), exist_ok=True) - print(f"\033[93m - Saving results to: {csv_file}\033[0m") - - if not os.path.isfile(csv_file): - all_boxes = pd.DataFrame(columns=output_cols) - all_boxes.to_csv(csv_file, index=False) - - prefix = f"{dep_id}/snapshot_images" - get_objects( - session, - aws_credentials, - s3_bucket_name, - prefix, - local_directory_path, - batch_size=100, - perform_inference=perform_inference, - remove_image=remove_image, - localisation_model=model_loc, - binary_model=classification_model, - order_model=order_model, - order_labels=order_labels, - species_model=regional_model, - species_labels=regional_category_map, - country=country, - region=region, - device=device, - order_data_thresholds=order_data_thresholds, - csv_file=csv_file, - rerun_existing=rerun_existing, - crops_interval=crops_interval, - ) - print("\N{White Heavy Check Mark}\033[0m\033[0m") - - -if __name__ == "__main__": - - # Use GPU if available - - if torch.cuda.is_available(): - device = torch.device("cuda") - print( - "\033[95m\033[1mCuda available, using GPU " - + "\N{White Heavy Check Mark}\033[0m\033[0m" - ) - else: - device = torch.device("cpu") - print( - "\033[95m\033[1mCuda not available, using CPU " - + "\N{Cross Mark}\033[0m\033[0m" - ) - - - - print("\033[96m\033[1mInitialising the JASMINE session...\033[0m\033[0m", end="") - # Load AWS credentials and S3 bucket name from config file - with open("./credentials.json", encoding="utf-8") as config_file: - aws_credentials = json.load(config_file) - session = boto3.Session( - aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"], - aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"], - region_name=aws_credentials["AWS_REGION"], - ) - print("\N{White Heavy Check Mark}") - - num_workers = int(os.getenv("SLURM_CPUS_PER_TASK", os.cpu_count())) - - parser = argparse.ArgumentParser( - description="Script 
for downloading and processing images from S3." - ) - parser.add_argument("--country", type=str, help="Specify the country name") - parser.add_argument("--deployment", type=str, help="Specify the deployment name") - parser.add_argument( - "--keep_crops", - action=argparse.BooleanOptionalAction, - default=False, - help="Whether to keep the crops", - ) - parser.add_argument( - "--perform_inference", - action=argparse.BooleanOptionalAction, - default=True, - help="Whether to perform the inference", - ) - parser.add_argument( - "--remove_image", - action=argparse.BooleanOptionalAction, - default=True, - help="Whether to remove the raw image after inference", - ) - parser.add_argument( - "--crops_interval", - type=str, - help="The interval for which to preserve the crops", - default=10, - ) - parser.add_argument( - "--rerun_existing", - action=argparse.BooleanOptionalAction, - default=False, - help="Whether to rerun images which have already been analysed", - ) - parser.add_argument( - "--data_storage_path", - type=str, - help="The path to scratch data storage", - default="./data/", - ) - parser.add_argument( - "--regional_model_path", - type=str, - help="The path to the regional models wights file", - default="./models/turing-costarica_v03_resnet50_2024-06-04-16-17_state.pt", - ) - parser.add_argument( - "--regional_map_path", - type=str, - help="The path to the category map", - default="./models/03_costarica_data_category_map.json", - ) - parser.add_argument( - "--binary_model_path", - type=str, - help="The path to the binary model weights", - default="./models/moth-nonmoth-effv2b3_20220506_061527_30.pth", - ) - parser.add_argument( - "--order_model_path", - type=str, - help="The path to the binary model weights", - default="./models/dhc_best_128.pth", - ) - parser.add_argument( - "--order_threshold_path", - type=str, - help="The path to the binary model weights", - default="./models/thresholdsTestTrain.csv", - ) - parser.add_argument( - "--localisation_model_path", - type=str, - help="The path to the binary model weights", - default="./models/v1_localizmodel_2021-08-17-12-06.pt", - ) - parser.add_argument( - "--num_workers", - type=int, - default=num_workers, - help="Number of workers for multi-threaded downloads", - ) - - args = parser.parse_args() - - print("\033[96m\033[1mLoading in models...\033[0m\033[0m", end="") - ( - model_loc, - classification_model, - regional_model, - regional_category_map, - order_model, - order_data_thresholds, - order_labels, - ) = load_models(device, args.localisation_model_path, args.binary_model_path, args.order_model_path, args.order_threshold_path, args.regional_model_path, args.regional_map_path) - print("\N{White Heavy Check Mark}") - - - - # check that the data storage path exists - data_storage_path = os.path.abspath(args.data_storage_path) - if not os.path.isdir(data_storage_path): - os.makedirs(data_storage_path) - - print("\033[93m\033[1m" + "Pipeline parameters" + "\033[0m\033[0m") - print(f"\033[93m - Scratch and crops storage: {data_storage_path}\033[0m") - - if args.keep_crops: - crops_interval = args.crops_interval - print(f"\033[93m - Keeping crops every {crops_interval}mins\033[0m") - else: - print("\033[93m - Not keeping crops\033[0m") - crops_interval = None - - print(f"\033[93m - Number of workers: {args.num_workers}\033[0m") - - download_and_inference( - args.country, - args.deployment, - int(crops_interval), - args.rerun_existing, - data_storage_path, - args.perform_inference, - args.remove_image, - args.num_workers, - ) diff --git 
a/utils/aws_scripts.py b/utils/aws_scripts.py index 83c2b34..adfb292 100644 --- a/utils/aws_scripts.py +++ b/utils/aws_scripts.py @@ -1,292 +1,79 @@ -# utils/aws_scripts.py - -import requests -from requests.auth import HTTPBasicAuth -import tqdm -from boto3.s3.transfer import TransferConfig import os -from datetime import datetime, timedelta -from utils.inference_scripts import perform_inf -import pandas as pd -import sys -import tqdm -import math -from functools import partial -from concurrent.futures import ThreadPoolExecutor -import re - - -def get_deployments(username, password): - """Fetch deployments from the API with authentication.""" - - try: - url = "https://connect-apps.ceh.ac.uk/ami-data-upload/get-deployments/" - response = requests.get( - url, auth=HTTPBasicAuth(username, password), timeout=600 - ) - response.raise_for_status() - return response.json() - except requests.exceptions.HTTPError as err: - print(f"HTTP Error: {err}") - if response.status_code == 401: - print("Wrong username or password. Try again!") - sys.exit(1) - except Exception as err: - print(f"Error: {err}") - sys.exit(1) +import boto3 +from botocore.exceptions import ClientError +from tqdm import tqdm -def download_object( - s3_client, - bucket_name, - key, - download_path, - perform_inference=False, - remove_image=False, - localisation_model=None, - binary_model=None, - order_model=None, - order_labels=None, - species_model=None, - species_labels=None, - country="UK", - region="UKCEH", - device=None, - order_data_thresholds=None, - csv_file="results.csv", - intervals=None, -): - """ - Download a single object from S3 synchronously. +def list_objects(session, bucket_name, prefix): """ + List all objects in an S3 bucket with a specific prefix. - # Configure the transfer to optimize the download - transfer_config = TransferConfig( - max_concurrency=20, # Increase the number of concurrent transfers - multipart_threshold=8 * 1024 * 1024, # 8MB - max_io_queue=1000, - io_chunksize=262144, # 256KB - ) + Args: + session (boto3.Session): Authenticated AWS session. + bucket_name (str): Name of the S3 bucket. + prefix (str): Prefix for filtering objects. + Returns: + list: List of object keys in the bucket matching the prefix. 
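+            If listing fails part-way, the keys collected before the error
+            are returned.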
+ """ + s3_client = session.client("s3") + object_keys = [] try: - s3_client.download_file(bucket_name, key, download_path, Config=transfer_config) - - # If crops are saved, define the frequecy - if intervals: - path_df = get_datetime_from_string(os.path.basename(download_path)) - save_crops = path_df in intervals - else: - save_crops = False - if save_crops: - print(f" - Saving crops for: {os.path.basename(download_path)}") + paginator = s3_client.get_paginator("list_objects_v2") + for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix): + if "Contents" in page: + object_keys.extend(obj["Key"] for obj in page["Contents"]) + except ClientError as e: + print(f"\033[91mError listing objects: {e}\033[0m") + return object_keys - if perform_inference: - perform_inf( - download_path, - bucket_name=bucket_name, - loc_model=localisation_model, - binary_model=binary_model, - order_model=order_model, - order_labels=order_labels, - regional_model=species_model, - regional_category_map=species_labels, - country=country, - region=region, - device=device, - order_data_thresholds=order_data_thresholds, - csv_file=csv_file, - save_crops=save_crops, - ) - if remove_image: - os.remove(download_path) - except Exception as e: - print( - f"\033[91m\033[1m Error downloading {bucket_name}/{key}: {e}\033[0m\033[0m" - ) - -def get_datetime_from_string(input): - in_sting = input.replace("-snapshot.jpg", "") - dt = in_sting.split("-")[-1] - dt = datetime.strptime(dt, "%Y%m%d%H%M%S") - return dt - - -def download_batch( - s3_client, - bucket_name, - keys, - local_path, - perform_inference=False, - remove_image=False, - localisation_model=None, - binary_model=None, - order_model=None, - order_labels=None, - species_model=None, - species_labels=None, - country="UK", - region="UKCEH", - device=None, - order_data_thresholds=None, - csv_file="results.csv", - rerun_existing=False, - intervals=None, -): +def download_object(session, bucket_name, key, local_path, retries=3): """ - Download a batch of objects from S3. - """ - - existing_df = pd.read_csv(csv_file, dtype="unicode") - - for key in keys: - file_path, filename = os.path.split(key) + Download a single object from S3. - os.makedirs(os.path.join(local_path, file_path), exist_ok=True) - download_path = os.path.join(local_path, file_path, filename) + Args: + session (boto3.Session): Authenticated AWS session. + bucket_name (str): S3 bucket name. + key (str): Key of the object to download. + local_path (str): Local directory to save the object. + retries (int): Number of retries on failure. - # check if file is in csv_file 'path' column - if not rerun_existing: - if existing_df["image_path"].str.contains(download_path).any(): - print( - f"{os.path.basename(download_path)} has already been processed. Skipping..." - ) - continue - - download_object( - s3_client, - bucket_name, - key, - download_path, - perform_inference, - remove_image, - localisation_model, - binary_model, - order_model, - order_labels, - species_model, - species_labels, - country, - region, - device, - order_data_thresholds, - csv_file, - intervals, - ) - - -def count_files(s3_client, bucket_name, prefix): - """ - Count number of files for a given prefix. + Returns: + str: Local file path of the downloaded object. 
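+
+    Raises:
+        RuntimeError: If the object cannot be downloaded after the given
+            number of retries.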
""" - paginator = s3_client.get_paginator("list_objects_v2") - operation_parameters = {"Bucket": bucket_name, "Prefix": prefix} - page_iterator = paginator.paginate(**operation_parameters) + s3_client = session.client("s3") + local_file_path = os.path.join(local_path, key) + os.makedirs(os.path.dirname(local_file_path), exist_ok=True) - count = 0 - all_keys = [] - for page in page_iterator: - if not os.path.basename(page.get("Contents", [])[0]["Key"]).startswith("$"): - count += page.get("KeyCount", 0) - file_i = page.get("Contents", [])[0]["Key"] - all_keys = all_keys + [file_i] - return count, all_keys + for attempt in range(retries): + try: + s3_client.download_file(bucket_name, key, local_file_path) + return local_file_path + except ClientError as e: + print(f"\033[93mRetry {attempt + 1}/{retries} - Error downloading {key}: {e}\033[0m") + raise RuntimeError(f"Failed to download {key} after {retries} attempts.") -def get_objects( - session, - aws_credentials, - bucket_name, - prefix, - local_path, - batch_size=100, - perform_inference=False, - remove_image=False, - localisation_model=None, - binary_model=None, - order_model=None, - order_labels=None, - species_model=None, - species_labels=None, - country="UK", - region="UKCEH", - device=None, - order_data_thresholds=None, - csv_file="results.csv", - rerun_existing=False, - crops_interval=None, - num_workers=1, -): +def download_batch(session, bucket_name, keys, local_path, retries=3): """ - Fetch objects from the S3 bucket and download them synchronously in batches. - """ - s3_client = session.client("s3", endpoint_url=aws_credentials["AWS_URL_ENDPOINT"]) - - total_files, all_keys = count_files(s3_client, bucket_name, prefix) - first_dt = get_datetime_from_string(os.path.basename(all_keys[0])) - last_dt = get_datetime_from_string(os.path.basename(all_keys[-1] - - paginator = s3_client.get_paginator("list_objects_v2") - operation_parameters = {"Bucket": bucket_name, "Prefix": prefix} - page_iterator = paginator.paginate(**operation_parameters) - - if crops_interval is not None: - t = first_dt - intervals = [] - while t < last_dt: - intervals = intervals + [t] - t = t + timedelta(minutes=crops_interval) - else: - intervals = None - - keys = [] - for page in page_iterator: - if os.path.basename(page.get("Contents", [])[0]["Key"]).startswith("$"): - print(f'{page.get("Contents", [])[0]["Key"]} is suspected corrupt, skipping') - continue - - for obj in page.get("Contents", []): - keys.append(obj["Key"]) - - # don't rerun previously analysed images - results_df = pd.read_csv(csv_file, dtype=str) - run_images = [re.sub(r'^.*?dep', 'dep', x) for x in results_df['image_path']] - keys = [x for x in keys if x not in run_images] - - # Divide the keys among workers - chunks = [ - keys[i : i + math.ceil(len(keys) / num_workers)] - for i in range(0, len(keys), math.ceil(len(keys) / num_workers)) - ] - - # Shared progress bar - progress_bar = tqdm.tqdm(total=total_files, desc=f"Download files for {os.path.basename(csv_file).replace('_results.csv', '')}") - - def process_chunk(chunk): - for i in range(0, len(chunk), batch_size): - batch_keys = chunk[i : i + batch_size] - download_batch( - s3_client, - bucket_name, - batch_keys, - local_path, - perform_inference, - remove_image, - localisation_model, - binary_model, - order_model, - order_labels, - country, - region, - device, - order_data_thresholds, - csv_file, - rerun_existing, - ) - progress_bar.update(len(batch_keys)) + Download a batch of objects from S3. 
- # Use ThreadPoolExecutor instead of multiprocessing - with ThreadPoolExecutor(max_workers=num_workers) as executor: - executor.map(process_chunk, chunks) + Args: + session (boto3.Session): Authenticated AWS session. + bucket_name (str): S3 bucket name. + keys (list): List of object keys to download. + local_path (str): Local directory to save objects. + retries (int): Number of retries for each object. - progress_bar.close() + Returns: + list: List of local file paths of successfully downloaded objects. + """ + local_files = [] + for key in tqdm(keys, desc="Downloading batch"): + try: + local_file = download_object(session, bucket_name, key, local_path, retries) + local_files.append(local_file) + except RuntimeError as e: + print(f"\033[91mSkipping {key}: {e}\033[0m") + return local_files diff --git a/utils/custom_models.py b/utils/custom_models.py index c16b956..689e582 100644 --- a/utils/custom_models.py +++ b/utils/custom_models.py @@ -129,12 +129,12 @@ def load_models(device, localisation_model_path, binary_model_path, order_model_ species_model.load_state_dict(state_dict) species_model.eval() - return ( - model_loc, - classification_model, - species_model, - species_category_map, - model_order, - order_data_thresholds, - order_labels, - ) + return ({ + 'localisation_model': model_loc, + 'classification_model': classification_model, + 'species_model': species_model, + 'species_model_labels': species_category_map, + 'order_model': model_order, + 'order_model_thresholds': order_data_thresholds, + 'order_model_labels': order_labels + }) diff --git a/utils/inference_scripts.py b/utils/inference_scripts.py index a65fc70..e72e334 100644 --- a/utils/inference_scripts.py +++ b/utils/inference_scripts.py @@ -5,9 +5,6 @@ import numpy as np from datetime import datetime -# from utils.custom_models import Resnet50_species, ResNet50_order, load_models - - def classify_species(image_tensor, regional_model, regional_category_map): """ Classify the species of the moth using the regional model. 
@@ -76,12 +73,11 @@ def perform_inf( order_labels, regional_model, regional_category_map, - country, - region, device, order_data_thresholds, csv_file, save_crops, + box_threshold=0.99 ): """ Perform inferences on an image including: @@ -126,6 +122,7 @@ def perform_inf( "cropped_image_path", ] + image = Image.open(image_path).convert("RGB") original_image = image.copy() original_width, original_height = image.size @@ -137,12 +134,12 @@ def perform_inf( columns=all_cols ) - # Perform object localization + # Perform object localisation with torch.no_grad(): - localization_outputs = loc_model(input_tensor) - + localisation_outputs = loc_model(input_tensor) + # catch no crops - if len(localization_outputs[0]["boxes"]) == 0: + if len(localisation_outputs[0]["boxes"]) == 0 or all(localisation_outputs[0]["scores"] < box_threshold): df = pd.DataFrame( [ [ @@ -175,11 +172,11 @@ def perform_inf( ) # for each detection - for i in range(len(localization_outputs[0]["boxes"])): - x_min, y_min, x_max, y_max = localization_outputs[0]["boxes"][i] - box_score = localization_outputs[0]["scores"].tolist()[i] - box_label = localization_outputs[0]["labels"].tolist()[i] - + for i in range(len(localisation_outputs[0]["boxes"])): + x_min, y_min, x_max, y_max = localisation_outputs[0]["boxes"][i] + box_score = localisation_outputs[0]["scores"].tolist()[i] + box_label = localisation_outputs[0]["labels"].tolist()[i] + x_min = int(int(x_min) * original_width / 300) y_min = int(int(y_min) * original_height / 300) x_max = int(int(x_max) * original_width / 300) @@ -188,7 +185,7 @@ def perform_inf( box_width = x_max - x_min box_height = y_max - y_min - if box_score < 0.99: + if box_score < box_threshold: continue # if box height or width > half the image, skip @@ -234,7 +231,6 @@ def perform_inf( f"order: {order_name}, binary: {class_name}", fill="red", ) - draw.text((x_min, y_max), str(box_score), fill="black") # append to csv with pandas @@ -261,7 +257,9 @@ def perform_inf( ], columns=all_cols, ) + all_boxes = pd.concat([all_boxes, df]) + df.to_csv( f'{csv_file}', mode="a",