-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpredict_bmi.py
More file actions
168 lines (135 loc) · 5.54 KB
/
predict_bmi.py
File metadata and controls
168 lines (135 loc) · 5.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#!/usr/bin/env python3
"""
Simple BMI prediction script using trained DenseNet models.
This script demonstrates how to load a pre-trained model and run BMI predictions
on sample images from the Visual Body to BMI dataset.
"""
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import argparse
from model import SEDensenet121, SEDensenet201
from dataset import BMIDataset, load_sample_data
def load_model(model_path, model_type="densenet121", device="cuda"):
"""
Load a trained model from checkpoint.
Args:
model_path: Path to the model checkpoint file
model_type: Type of model ("densenet121" or "densenet201")
device: Device to load model on
Returns:
Loaded model in evaluation mode
"""
if not os.path.exists(model_path):
print(f"Model checkpoint not found at {model_path}")
print("Please download the model weights from the authors:")
print("Contact: rmanichand@ethz.ch or planger@ethz.ch")
return None
# Initialize model
if model_type.lower() == "densenet121":
model = SEDensenet121()
elif model_type.lower() == "densenet201":
model = SEDensenet201()
else:
raise ValueError(f"Unsupported model type: {model_type}")
# Load checkpoint
checkpoint = torch.load(model_path, map_location=device)
if 'state_dict' in checkpoint:
model.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
print(f"Loaded {model_type} model from {model_path}")
return model
def predict_bmi(model, dataloader, device="cuda"):
"""
Run BMI prediction on a dataset.
Args:
model: Trained model
dataloader: DataLoader with images
device: Device for computation
Returns:
List of results with predictions and metadata
"""
results = []
model.eval()
with torch.no_grad():
for batch_idx, (images, bmis, individual_ids) in enumerate(dataloader):
images = images.to(device, dtype=torch.float32)
# Get predictions
outputs = model(images)
predicted_bmis = outputs.cpu().numpy().flatten()
# Store results
for i in range(len(images)):
result = {
'individual_id': individual_ids[i],
'predicted_bmi': predicted_bmis[i],
'actual_bmi': bmis[i].item() if bmis[i] is not None else None,
}
results.append(result)
# Print progress
if result['actual_bmi'] is not None:
error = abs(result['predicted_bmi'] - result['actual_bmi'])
print(f"ID: {result['individual_id']:<10} | "
f"Predicted: {result['predicted_bmi']:.2f} | "
f"Actual: {result['actual_bmi']:.2f} | "
f"Error: {error:.2f}")
else:
print(f"ID: {result['individual_id']:<10} | "
f"Predicted: {result['predicted_bmi']:.2f}")
return results
def main():
parser = argparse.ArgumentParser(description='BMI Prediction using DenseNet')
parser.add_argument('--model_path', type=str, default='weights/best_model.ckpt',
help='Path to model checkpoint')
parser.add_argument('--model_type', type=str, default='densenet121',
choices=['densenet121', 'densenet201'],
help='Type of model to use')
parser.add_argument('--data_dir', type=str, default='data',
help='Directory containing the dataset')
parser.add_argument('--batch_size', type=int, default=1,
help='Batch size for inference')
parser.add_argument('--device', type=str, default='auto',
choices=['auto', 'cuda', 'cpu'],
help='Device to use for computation')
args = parser.parse_args()
# Set device
if args.device == 'auto':
device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
device = args.device
print(f"Using device: {device}")
# Load dataset
print("Loading dataset...")
sample_data = load_sample_data(args.data_dir)
if sample_data is None:
print("Failed to load dataset. Exiting...")
sys.exit(1)
# Create dataset and dataloader
dataset = BMIDataset(sample_data)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
# Load model
print("Loading model...")
model = load_model(args.model_path, args.model_type, device)
if model is None:
print("Failed to load model. Exiting...")
sys.exit(1)
# Run predictions
print("\nRunning BMI predictions...")
print("-" * 80)
results = predict_bmi(model, dataloader, device)
# Calculate summary statistics
if any(r['actual_bmi'] is not None for r in results):
errors = [abs(r['predicted_bmi'] - r['actual_bmi'])
for r in results if r['actual_bmi'] is not None]
print("-" * 80)
print(f"Summary Statistics (n={len(errors)}):")
print(f"Average Error: {sum(errors)/len(errors):.2f}")
print(f"Max Error: {max(errors):.2f}")
print(f"Min Error: {min(errors):.2f}")
print(f"\nCompleted predictions for {len(results)} images.")
if __name__ == "__main__":
main()