diff --git a/.gitignore b/.gitignore
index a769b8a..7f7ce50 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ credentials.json
 *.out
 data/
 models/
+keys/
diff --git a/print_deployments.py b/01_print_deployments.py
similarity index 95%
rename from print_deployments.py
rename to 01_print_deployments.py
index 903e52d..c208f2d 100644
--- a/print_deployments.py
+++ b/01_print_deployments.py
@@ -92,7 +92,7 @@ def print_deployments(include_inactive=False, subset_countries=None, print_image
         for dep in sorted(all_deps):
             dep_info = [x for x in country_depl if x['deployment_id'] == dep][0]
             print(f"\033[1m - Deployment ID: {dep_info['deployment_id']}, Name: {dep_info['location_name']}, Deployment Key: '{dep_info['location_name']} - {dep_info['camera_id']}'\033[0m")
-            print(f"   Location ID: {dep_info['location_id']}, Latitute: {dep_info['lat']}, Longitute: {dep_info['lon']}, Camera ID: {dep_info['camera_id']}, System ID: {dep_info['system_id']}, Status: {dep_info['status']}")
+            print(f"   Location ID: {dep_info['location_id']}, Country code: {dep_info['country_code'].lower()}, Latitude: {dep_info['lat']}, Longitude: {dep_info['lon']}, Camera ID: {dep_info['camera_id']}, System ID: {dep_info['system_id']}, Status: {dep_info['status']}")
 
             # get the number of images for this deployment
             prefix = f"{dep_info['deployment_id']}/snapshot_images"
diff --git a/02_generate_keys.py b/02_generate_keys.py
new file mode 100644
index 0000000..f19ed65
--- /dev/null
+++ b/02_generate_keys.py
@@ -0,0 +1,82 @@
+import boto3
+import argparse
+import json
+
+def list_s3_keys(bucket_name, deployment_id=""):
+    """
+    List all keys in an S3 bucket under a specific prefix.
+
+    Parameters:
+        bucket_name (str): The name of the S3 bucket.
+        deployment_id (str): The deployment prefix used to filter keys (default: "").
+
+    Returns:
+        list: A list of S3 object keys.
+    """
+    with open("./credentials.json", encoding="utf-8") as config_file:
+        aws_credentials = json.load(config_file)
+
+    session = boto3.Session(
+        aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+        aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+        region_name=aws_credentials["AWS_REGION"],
+    )
+    s3_client = session.client("s3", endpoint_url=aws_credentials["AWS_URL_ENDPOINT"])
+
+    keys = []
+    continuation_token = None
+
+    while True:
+        list_kwargs = {
+            "Bucket": bucket_name,
+            "Prefix": deployment_id,
+        }
+        if continuation_token:
+            list_kwargs["ContinuationToken"] = continuation_token
+
+        response = s3_client.list_objects_v2(**list_kwargs)
+
+        # Add object keys to the list
+        for obj in response.get("Contents", []):
+            keys.append(obj["Key"])
+
+        # Check if there are more objects to list
+        if response.get("IsTruncated"):  # If True, there are more results
+            continuation_token = response["NextContinuationToken"]
+        else:
+            break
+
+    return keys
+
+def save_keys_to_file(keys, output_file):
+    """
+    Save S3 keys to a file, one per line.
+
+    Parameters:
+        keys (list): List of S3 keys.
+        output_file (str): Path to the output file.
+    """
+    with open(output_file, "w") as f:
+        for key in keys:
+            f.write(key + "\n")
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate a file containing S3 keys from a bucket.")
+    parser.add_argument("--bucket", type=str, required=True, help="Name of the S3 bucket.")
+    parser.add_argument("--deployment_id", type=str, default="", help="The deployment id to filter objects. If set to '' then all deployments are used. (default: '')")
+    parser.add_argument("--output_file", type=str, default="s3_keys.txt", help="Output file to save S3 keys.")
+    args = parser.parse_args()
+
+    # List keys from the specified S3 bucket and prefix
+    print(f"Listing keys from bucket '{args.bucket}' with deployment '{args.deployment_id}'...")
+    keys = list_s3_keys(args.bucket, args.deployment_id)
+
+    # Save keys to the output file
+    save_keys_to_file(keys, args.output_file)
+    print(f"Saved {len(keys)} keys to {args.output_file}")
+
+if __name__ == "__main__":
+    main()
diff --git a/03_pre_chop_files.py b/03_pre_chop_files.py
new file mode 100644
index 0000000..d640860
--- /dev/null
+++ b/03_pre_chop_files.py
@@ -0,0 +1,61 @@
+import json
+import argparse
+from math import ceil
+import os
+
+def load_workload(input_file, file_extension):
+    """
+    Load workload from a file. Assumes each line contains an S3 key.
+    """
+    with open(input_file, 'r') as f:
+        all_keys = [line.strip() for line in f.readlines()]
+
+    # keep only keys matching the pipe-separated extensions; if empty, keep all keys
+    if file_extension:
+        extensions = tuple(file_extension.split('|'))
+        subset_keys = [x for x in all_keys if x.endswith(extensions)]
+    else:
+        subset_keys = all_keys
+
+    # remove corrupt keys
+    subset_keys = [x for x in subset_keys if not os.path.basename(x).startswith('$')]
+
+    # remove keys uploaded from the recycle bin (legacy code)
+    subset_keys = [x for x in subset_keys if 'recycle' not in x]
+
+    return subset_keys
+
+def split_workload(keys, chunk_size):
+    """
+    Split a list of keys into chunks of a specified size.
+    """
+    num_chunks = ceil(len(keys) / chunk_size)
+    chunks = {
+        str(i + 1): {"keys": keys[i * chunk_size: (i + 1) * chunk_size]}
+        for i in range(num_chunks)
+    }
+    return chunks
+
+def save_chunks(chunks, output_file):
+    """
+    Save chunks to a JSON file.
+    """
+    with open(output_file, 'w') as f:
+        json.dump(chunks, f, indent=4)
+
+def main():
+    parser = argparse.ArgumentParser(description="Pre-chop S3 workload into manageable chunks.")
+    parser.add_argument("--input_file", type=str, required=True, help="Path to file containing S3 keys, one per line.")
+    parser.add_argument("--file_extension", type=str, default="jpg|jpeg", help="Pipe-separated file extensions to be chunked. If empty, all extensions are used.")
+    parser.add_argument("--chunk_size", type=int, default=100, help="Number of keys per chunk.")
+    parser.add_argument("--output_file", type=str, required=True, help="Path to save the output JSON file.")
+    args = parser.parse_args()
+
+    # Load the workload from the input file
+    keys = load_workload(args.input_file, args.file_extension)
+
+    # Split the workload into chunks
+    chunks = split_workload(keys, args.chunk_size)
+
+    # Save the chunks to a JSON file
+    save_chunks(chunks, args.output_file)
+
+    print(f"Successfully split {len(keys)} keys into {len(chunks)} chunks.")
+    print(f"Chunks saved to {args.output_file}")
+
+if __name__ == "__main__":
+    main()
diff --git a/04_process_chunks.py b/04_process_chunks.py
new file mode 100644
index 0000000..5020393
--- /dev/null
+++ b/04_process_chunks.py
@@ -0,0 +1,217 @@
+import argparse
+import boto3
+import json
+import os
+from utils.inference_scripts import perform_inf
+from boto3.s3.transfer import TransferConfig
+import torch
+from utils.custom_models import load_models
+
+
+# Transfer configuration for optimised S3 download
+transfer_config = TransferConfig(
+    max_concurrency=20,  # Increase the number of concurrent transfers
+    multipart_threshold=8 * 1024 * 1024,  # 8MB
+    max_io_queue=1000,
+    io_chunksize=262144,  # 256KB
+)
+
+def initialise_session(credentials_file="credentials.json"):
+    """
+    Load AWS and API credentials from a configuration file and initialise an AWS session.
+
+    Args:
+        credentials_file (str): Path to the credentials JSON file.
+
+    Returns:
+        boto3.Client: Initialised S3 client.
+    """
+    with open(credentials_file, encoding="utf-8") as config_file:
+        aws_credentials = json.load(config_file)
+    session = boto3.Session(
+        aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+        aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+        region_name=aws_credentials["AWS_REGION"],
+    )
+    client = session.client("s3", endpoint_url=aws_credentials["AWS_URL_ENDPOINT"])
+    return client
+
+
+def download_and_analyse(
+    keys,
+    output_dir,
+    bucket_name,
+    client,
+    remove_image=True,
+    perform_inference=True,
+    localisation_model=None,
+    binary_model=None,
+    order_model=None,
+    order_labels=None,
+    species_model=None,
+    species_labels=None,
+    device=None,
+    order_data_thresholds=None,
+    csv_file="results.csv",
+):
+    """
+    Download images from S3 and perform analysis.
+
+    Args:
+        keys (list): List of S3 keys to process.
+        output_dir (str): Directory to save downloaded files and results.
+        bucket_name (str): S3 bucket name.
+        client (boto3.Client): Initialised S3 client.
+        Other args: Parameters for inference and analysis.
+    """
+    # Ensure output directory exists
+    os.makedirs(output_dir, exist_ok=True)
+
+    for key in keys:
+        local_path = os.path.join(output_dir, os.path.basename(key))
+        print(f"Downloading {key} to {local_path}")
+        client.download_file(bucket_name, key, local_path, Config=transfer_config)
+
+        # Perform image analysis if enabled
+        print(f"Analysing {local_path}")
+        if perform_inference:
+            perform_inf(
+                local_path,
+                bucket_name=bucket_name,
+                loc_model=localisation_model,
+                binary_model=binary_model,
+                order_model=order_model,
+                order_labels=order_labels,
+                regional_model=species_model,
+                regional_category_map=species_labels,
+                device=device,
+                order_data_thresholds=order_data_thresholds,
+                csv_file=csv_file,
+                save_crops=True,
+            )
+        # Remove the image if cleanup is enabled
+        if remove_image:
+            os.remove(local_path)
+
+
+def main(
+    chunk_id,
+    json_file,
+    output_dir,
+    bucket_name,
+    credentials_file="credentials.json",
+    remove_image=True,
+    perform_inference=True,
+    localisation_model=None,
+    binary_model=None,
+    order_model=None,
+    order_labels=None,
+    species_model=None,
+    species_labels=None,
+    device=None,
+    order_data_thresholds=None,
+    csv_file="results.csv",
+):
+    """
+    Main function to process a specific chunk of S3 keys.
+
+    Args:
+        chunk_id (str): ID of the chunk to process (e.g., '1').
+        json_file (str): Path to the JSON file with key chunks.
+        output_dir (str): Directory to save results.
+        bucket_name (str): S3 bucket name.
+        Other args: Parameters for download and analysis.
+    """
+    with open(json_file, "r") as f:
+        chunks = json.load(f)
+
+    if chunk_id not in chunks:
+        raise ValueError(f"Chunk ID {chunk_id} not found in JSON file.")
+
+    client = initialise_session(credentials_file)
+
+    keys = chunks[chunk_id]['keys']
+    download_and_analyse(
+        keys=keys,
+        output_dir=output_dir,
+        bucket_name=bucket_name,
+        client=client,
+        remove_image=remove_image,
+        perform_inference=perform_inference,
+        localisation_model=localisation_model,
+        binary_model=binary_model,
+        order_model=order_model,
+        order_labels=order_labels,
+        species_model=species_model,
+        species_labels=species_labels,
+        device=device,
+        order_data_thresholds=order_data_thresholds,
+        csv_file=csv_file,
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Process a specific chunk of S3 keys.")
+    parser.add_argument("--chunk_id", required=True, help="ID of the chunk to process (e.g., 1, 2, 3).")
+    parser.add_argument("--json_file", required=True, help="Path to the JSON file with key chunks.")
+    parser.add_argument("--output_dir", required=True, help="Directory to save downloaded files and analysis results.", default="./data/")
+    parser.add_argument("--bucket_name", required=True, help="Name of the S3 bucket.")
+    parser.add_argument("--credentials_file", default="credentials.json", help="Path to AWS credentials file.")
+    parser.add_argument("--remove_image", action="store_true", help="Remove images after processing.")
+    parser.add_argument("--perform_inference", action="store_true", help="Enable inference.")
+    parser.add_argument("--localisation_model_path", type=str, help="Path to the localisation model weights.", default="./models/v1_localizmodel_2021-08-17-12-06.pt")
+    parser.add_argument("--binary_model_path", type=str, help="Path to the binary model weights.", default="./models/moth-nonmoth-effv2b3_20220506_061527_30.pth")
+    parser.add_argument("--order_model_path", type=str, help="Path to the order model weights.", default="./models/dhc_best_128.pth")
+    parser.add_argument("--order_labels", type=str, help="Path to the order labels file.")
+    parser.add_argument("--species_model_path", type=str, help="Path to the species model weights.", default="./models/turing-costarica_v03_resnet50_2024-06-04-16-17_state.pt")
+    parser.add_argument("--species_labels", type=str, help="Path to the species labels file.", default="./models/03_costarica_data_category_map.json")
+    parser.add_argument("--device", type=str, default="cpu", help="Device to run inference on (e.g., cpu or cuda).")
+    parser.add_argument("--order_thresholds_path", type=str, help="Path to the order data thresholds file.", default="./models/thresholdsTestTrain.csv")
+    parser.add_argument("--csv_file", default="results.csv", help="Path to save analysis results.")
+
+    args = parser.parse_args()
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        print(
+            "\033[95m\033[1mCuda available, using GPU "
+            + "\N{White Heavy Check Mark}\033[0m\033[0m"
+        )
+    else:
+        device = torch.device("cpu")
+        print(
+            "\033[95m\033[1mCuda not available, using CPU "
+            + "\N{Cross Mark}\033[0m\033[0m"
+        )
+
+    models = load_models(
+        device,
+        getattr(args, 'localisation_model_path'),
+        getattr(args, 'binary_model_path'),
+        getattr(args, 'order_model_path'),
+        getattr(args, 'order_thresholds_path'),
+        getattr(args, 'species_model_path'),
+        getattr(args, 'species_labels')
+    )
+
+    main(
+        chunk_id=args.chunk_id,
+        json_file=args.json_file,
+        output_dir=args.output_dir,
+        bucket_name=args.bucket_name,
+        credentials_file=args.credentials_file,
+        remove_image=args.remove_image,
+        perform_inference=args.perform_inference,
+        localisation_model=models['localisation_model'],
+        binary_model=models['classification_model'],
+        order_model=models['order_model'],
+        order_labels=models['order_model_labels'],
+        order_data_thresholds=models['order_model_thresholds'],
+        species_model=models['species_model'],
+        species_labels=models['species_model_labels'],
+        device=device,
+        csv_file=args.csv_file,
+    )
diff --git a/README.md b/README.md
index 2c0ae4d..1040270 100644
--- a/README.md
+++ b/README.md
@@ -70,7 +70,7 @@ source ~/miniforge3/bin/activate
 conda activate "~/conda_envs/moth_detector_env/"
 ```
 
-Inferences are run by country and deployment site. To run the script, for Costa Rica say, use the following command:
+
 
-### Listing Available Deployments
+The multi-core pipeline is run in several steps:
+
+1. List all available deployments
+2. Generate the key files
+3. Chop the keys into chunks
+4. Analyse the chunks
+
+### 01. Listing Available Deployments
 
 To find information about the available deployments you can use the print_deployments function. For all deployments:
 
@@ -105,6 +112,32 @@ python s3_download_with_inference.py \
   --deployment "Garden - 3F1C4908"
 ```
 
+### 02. Generating the Keys
+
+```bash
+python 02_generate_keys.py --bucket 'cri' --deployment_id 'dep000031' --output_file './keys/dep000031_keys.txt'
+```
+
+### 03. Pre-chop the Keys into Chunks
+
+```bash
+python 03_pre_chop_files.py --input_file './keys/dep000031_keys.txt' --file_extension 'jpg|jpeg' --chunk_size 100 --output_file './keys/dep000031_workload_chunks.json'
+```
+
+### 04. Process the Chunked Files
+
+```bash
+python 04_process_chunks.py \
+  --chunk_id 1 \
+  --json_file './keys/dep000031_workload_chunks.json' \
+  --output_dir './data/dep000031' \
+  --bucket_name 'cri' \
+  --credentials_file './credentials.json' \
+  --perform_inference \
+  --remove_image
+```
+
+
 ## Running with slurm
 
 To run with slurm you need to be logged in on the [scientific nodes](https://help.jasmin.ac.uk/docs/interactive-computing/sci-servers/).
diff --git a/s3_download_with_inference.py b/s3_download_with_inference.py
deleted file mode 100644
index cd58663..0000000
--- a/s3_download_with_inference.py
+++ /dev/null
@@ -1,255 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-This script downloads files from an S3 bucket synchronously and performs
-inference on the images. AWS credentials, S3 bucket name, and UKCEH API
-credentials are loaded from a configuration file (credentials.json).
-"""
-
-import json
-import boto3
-import torch
-import pandas as pd
-import os
-import argparse
-
-from utils.aws_scripts import get_objects, get_deployments
-from utils.custom_models import load_models
-
-
-def download_and_inference(
-    country,
-    deployment,
-    rerun_existing,
-    local_directory_path,
-    perform_inference,
-    remove_image,
-    num_workers,
-):
-    """
-    Display the main menu and handle user interaction.
-    """
-
-    output_cols = [
-        "image_path",
-        "bucket_name",
-        "analysis_datetime",
-        "box_score",
-        "box_label",
-        "x_min",
-        "y_min",
-        "x_max",
-        "y_max",
-        "class_name",
-        "class_confidence",
-        "order_name",
-        "order_confidence",
-        "cropped_image_path",
-    ]
-
-    username = aws_credentials["UKCEH_username"]
-    password = aws_credentials["UKCEH_password"]
-
-    print(f"\033[93m - Removing images after analysis: {remove_image}\033[0m")
-    print(f"\033[93m - Performing inference: {perform_inference}\033[0m")
-    print(f"\033[93m - Rerun existing inferences: {rerun_existing}\033[0m")
-
-    all_deployments = get_deployments(username, password)
-
-    print(f"\033[96m\033[1mAnalysing: {country}\033[0m\033[0m")
-
-    country_deployments = [
-        f"{d['location_name']} - {d['camera_id']}"
-        for d in all_deployments
-        if d["country"] == country and d["status"] == "active"
-    ]
-
-    s3_bucket_name = [
-        d["country_code"]
-        for d in all_deployments
-        if d["country"] == country and d["status"] == "active"
-    ][0].lower()
-
-    if deployment == "All":
-        deps = country_deployments
-    else:
-        deps = [deployment]
-
-    # loop through each deployment
-    for region in deps:
-        location_name, camera_id = region.split(" - ")
-        dep_id = [
-            d["deployment_id"]
-            for d in all_deployments
-            if d["country"] == country
-            and d["location_name"] == location_name
-            and d["camera_id"] == camera_id
-            and d["status"] == "active"
-        ][0]
-        print(f"\033[96m - Deployment: {region}, id: {dep_id} \033[0m")
-
-        # if the file doesnt exist, print headers
-        csv_file = os.path.abspath(
-            f"{local_directory_path}/{dep_id}/{dep_id}_results.csv"
-        )
-        os.makedirs(os.path.dirname(csv_file), exist_ok=True)
-        print(f"\033[93m - Saving results to: {csv_file}\033[0m")
-
-        if not os.path.isfile(csv_file):
-            all_boxes = pd.DataFrame(columns=output_cols)
-            all_boxes.to_csv(csv_file, index=False)
-
-        prefix = f"{dep_id}/snapshot_images"
-        get_objects(
-            session,
-            aws_credentials,
-            s3_bucket_name,
-            prefix,
-            local_directory_path,
-            batch_size=100,
-            perform_inference=perform_inference,
-            remove_image=remove_image,
-            localisation_model=model_loc,
-            binary_model=classification_model,
-            order_model=order_model,
-            order_labels=order_labels,
-            country=country,
-            region=region,
-            device=device,
-            order_data_thresholds=order_data_thresholds,
-            csv_file=csv_file,
-            rerun_existing=rerun_existing,
-            num_workers=num_workers,
-        )
-        print("\N{White Heavy Check Mark}\033[0m\033[0m")
-
-
-if __name__ == "__main__":
-
-    # Use GPU if available
-
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-        print(
-            "\033[95m\033[1mCuda available, using GPU "
-            + "\N{White Heavy Check Mark}\033[0m\033[0m"
-        )
-    else:
-        device = torch.device("cpu")
-        print(
-            "\033[95m\033[1mCuda not available, using CPU "
-            + "\N{Cross Mark}\033[0m\033[0m"
-        )
-
-    print("\033[96m\033[1mInitialising the JASMINE session...\033[0m\033[0m", end="")
-    # Load AWS credentials and S3 bucket name from config file
-    with open("./credentials.json", encoding="utf-8") as config_file:
-        aws_credentials = json.load(config_file)
-    session = boto3.Session(
-        aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
-        aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
-        region_name=aws_credentials["AWS_REGION"],
-    )
-
-    print("\N{White Heavy Check Mark}")
-
-    num_workers = int(os.getenv("SLURM_CPUS_PER_TASK", os.cpu_count()))
-
-    parser = argparse.ArgumentParser(
-        description="Script for downloading and processing images from S3."
-    )
-    parser.add_argument("--country", type=str, help="Specify the country name")
-    parser.add_argument("--deployment", type=str, help="Specify the deployment name")
-    parser.add_argument(
-        "--keep_crops",
-        action=argparse.BooleanOptionalAction,
-        default=False,
-        help="Whether to keep the crops",
-    )
-    parser.add_argument(
-        "--perform_inference",
-        action=argparse.BooleanOptionalAction,
-        default=True,
-        help="Whether to perform the inference",
-    )
-    parser.add_argument(
-        "--remove_image",
-        action=argparse.BooleanOptionalAction,
-        default=True,
-        help="Whether to remove the raw image after inference",
-    )
-    parser.add_argument(
-        "--rerun_existing",
-        action=argparse.BooleanOptionalAction,
-        default=False,
-        help="Whether to rerun images which have already been analysed",
-    )
-    parser.add_argument(
-        "--data_storage_path",
-        type=str,
-        help="The path to scratch data storage",
-        default="./data/",
-    )
-    parser.add_argument(
-        "--binary_model_path",
-        type=str,
-        help="The path to the binary model weights",
-        default="./models/moth-nonmoth-effv2b3_20220506_061527_30.pth",
-    )
-    parser.add_argument(
-        "--order_model_path",
-        type=str,
-        help="The path to the binary model weights",
-        default="./models/dhc_best_128.pth",
-    )
-    parser.add_argument(
-        "--order_threshold_path",
-        type=str,
-        help="The path to the binary model weights",
-        default="./models/thresholdsTestTrain.csv",
-    )
-    parser.add_argument(
-        "--localisation_model_path",
-        type=str,
-        help="The path to the binary model weights",
-        default="./models/v1_localizmodel_2021-08-17-12-06.pt",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=num_workers,
-        help="Number of workers for multi-threaded downloads",
-    )
-
-    args = parser.parse_args()
-
-    print("\033[96m\033[1mLoading in models...\033[0m\033[0m", end="")
-    (
-        model_loc,
-        classification_model,
-        order_model,
-        order_data_thresholds,
-        order_labels,
-    ) = load_models(device, args.localisation_model_path, args.binary_model_path, args.order_model_path, args.order_threshold_path)
-    print("\N{White Heavy Check Mark}")
-
-    # check that the data storage path exists
-    data_storage_path = os.path.abspath(args.data_storage_path)
-    if not os.path.isdir(data_storage_path):
-        os.makedirs(data_storage_path)
-
-    print("\033[93m\033[1m" + "Pipeline parameters" + "\033[0m\033[0m")
-    print(f"\033[93m - Scratch and crops storage: {data_storage_path}\033[0m")
-
-    print(f"\033[93m - Number of workers: {args.num_workers}\033[0m")
-
-    download_and_inference(
-        args.country,
-        args.deployment,
-        args.rerun_existing,
-        data_storage_path,
-        args.perform_inference,
-        args.remove_image,
-        args.num_workers,
-    )
diff --git a/utils/aws_scripts.py b/utils/aws_scripts.py
index 456c654..dc0c641 100644
--- a/utils/aws_scripts.py
+++ b/utils/aws_scripts.py
@@ -1,9 +1,3 @@
-# utils/aws_scripts.py
-
-import requests
-from requests.auth import HTTPBasicAuth
-import tqdm
-from boto3.s3.transfer import TransferConfig
 import os
 from datetime import datetime, timedelta
 from utils.inference_scripts import perform_inf
@@ -17,8 +11,9 @@ import re
 
-def get_deployments(username, password):
-    """Fetch deployments from the API with authentication."""
+def list_objects(session, bucket_name, prefix):
+    """
+    List all objects in an S3 bucket with a specific prefix."""
 
     try:
         url = "https://connect-apps.ceh.ac.uk/ami-data-upload/get-deployments/"
@@ -70,7 +65,7 @@ def download_object(
     s3_client.download_file(bucket_name, key, download_path, Config=transfer_config)
 
     # If crops are saved, define the frequecy
-    
+
     save_crops = True
     print(f" - Saving crops for: {os.path.basename(download_path)}")
@@ -224,7 +219,7 @@ def get_objects(
         results_df = pd.read_csv(csv_file, dtype=str)
         run_images = [re.sub(r'^.*?dep', 'dep', x) for x in results_df['image_path']]
         keys = [x for x in keys if x not in run_images]
-    
+
     # Divide the keys among workers
     chunks = [
         keys[i : i + math.ceil(len(keys) / num_workers)]
diff --git a/utils/custom_models.py b/utils/custom_models.py
index b916fdb..98fa95b 100644
--- a/utils/custom_models.py
+++ b/utils/custom_models.py
@@ -114,10 +114,12 @@ def load_models(device, localisation_model_path, binary_model_path, order_model_
     model_order.eval()
 
-    return (
-        model_loc,
-        classification_model,
-        model_order,
-        order_data_thresholds,
-        order_labels,
-    )
+    return ({
+        'localisation_model': model_loc,
+        'classification_model': classification_model,
+        'species_model': species_model,
+        'species_model_labels': species_category_map,
+        'order_model': model_order,
+        'order_model_thresholds': order_data_thresholds,
+        'order_model_labels': order_labels
+    })
diff --git a/utils/inference_scripts.py b/utils/inference_scripts.py
index dda711d..ffa9493 100644
--- a/utils/inference_scripts.py
+++ b/utils/inference_scripts.py
@@ -60,6 +60,7 @@ def perform_inf(
     order_data_thresholds,
     csv_file,
     save_crops,
+    box_threshold=0.99
 ):
     """
     Perform inferences on an image including:
@@ -101,6 +102,7 @@ def perform_inf(
         "cropped_image_path",
     ]
 
+
     image = Image.open(image_path).convert("RGB")
     original_image = image.copy()
     original_width, original_height = image.size
@@ -112,12 +114,12 @@ def perform_inf(
         columns=all_cols
     )
 
-    # Perform object localization
+    # Perform object localisation
    with torch.no_grad():
-        localization_outputs = loc_model(input_tensor)
+        localisation_outputs = loc_model(input_tensor)
 
     # catch no crops
-    if len(localization_outputs[0]["boxes"]) == 0:
+    if len(localisation_outputs[0]["boxes"]) == 0 or all(localisation_outputs[0]["scores"] < box_threshold):
         df = pd.DataFrame(
             [
                 [
@@ -149,10 +151,10 @@ def perform_inf(
     )
 
     # for each detection
-    for i in range(len(localization_outputs[0]["boxes"])):
-        x_min, y_min, x_max, y_max = localization_outputs[0]["boxes"][i]
-        box_score = localization_outputs[0]["scores"].tolist()[i]
-        box_label = localization_outputs[0]["labels"].tolist()[i]
+    for i in range(len(localisation_outputs[0]["boxes"])):
+        x_min, y_min, x_max, y_max = localisation_outputs[0]["boxes"][i]
+        box_score = localisation_outputs[0]["scores"].tolist()[i]
+        box_label = localisation_outputs[0]["labels"].tolist()[i]
 
         x_min = int(int(x_min) * original_width / 300)
         y_min = int(int(y_min) * original_height / 300)
@@ -162,13 +164,13 @@ def perform_inf(
         box_width = x_max - x_min
         box_height = y_max - y_min
 
-        if box_score < 0.99:
+        if box_score < box_threshold:
             continue
 
         # if box height or width > half the image, skip
        if box_width > original_width / 2 or box_height > original_height / 2:
             continue
-        
+
         # Crop the detected region and perform classification
         cropped_image = original_image.crop((x_min, y_min, x_max, y_max))
         cropped_tensor = transform_species(cropped_image).unsqueeze(0).to(device)
@@ -180,14 +182,14 @@ def perform_inf(
 
         # if save_crops then save the cropped image
         crop_path = ""
-        if order_name == "Coleoptera" or order_name == 'Heteroptera' or order_name == 'Hemiptera':
-
-            if save_crops:
+        if order_name == "Coleoptera" or order_name == 'Heteroptera' or order_name == 'Hemiptera':
+
+            if save_crops:
                 crop_path = image_path.split(".")[0] + f"_crop{i}.jpg"
                 cropped_image.save(crop_path)
                 print(f"Potential beetle: {crop_path}")
-        
+
         # append to csv with pandas
         df = pd.DataFrame(
             [