-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdeep_learning_model.py
More file actions
154 lines (125 loc) · 4.55 KB
/
deep_learning_model.py
File metadata and controls
154 lines (125 loc) · 4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
Deep Learning Model for Bearing Fault Classification
1D CNN approach - works directly on raw signals
"""
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
def build_1d_cnn(input_shape, num_classes):
    """
    Build a 1D convolutional neural network for bearing fault detection.

    The network stacks four Conv1D blocks (64, 128, 256 and 128 filters with
    kernel sizes 10, 7, 5 and 3), each followed by batch normalization. The
    first three blocks end with max pooling and dropout; the fourth ends with
    global average pooling. Two dense layers (128 and 64 units), each followed
    by dropout, feed a softmax output layer with num_classes units.

    Args:
        input_shape: (signal_length, channels) - e.g., (2048, 1).
            signal_length may be None; otherwise it must be at least 61.
        num_classes: Number of bearing conditions (e.g., 3)

    Returns:
        Uncompiled Keras Sequential model. The caller is responsible for
        calling model.compile() before training.

    Raises:
        ValueError: If input_shape is not a 2-element shape, or if
            signal_length is too short for the convolution/pooling stack.
    """
    # Smallest length that survives the unpadded conv/pool stack:
    # 61 -conv10-> 52 -pool-> 26 -conv7-> 20 -pool-> 10 -conv5-> 6 -pool-> 3
    # -conv3-> 1. Keep in sync with the kernel/pool sizes below.
    min_signal_length = 61

    # Validate before touching Keras so the caller gets a readable error
    # instead of a negative-dimension failure from inside a layer.
    try:
        signal_length, _ = input_shape
    except (TypeError, ValueError):
        raise ValueError(
            "input_shape must be (signal_length, channels), "
            f"got {input_shape!r}"
        ) from None
    if signal_length is not None and signal_length < min_signal_length:
        raise ValueError(
            f"signal_length must be at least {min_signal_length} samples for "
            f"this architecture, got {signal_length}"
        )

    model = models.Sequential([
        # Explicit Input layer; passing input_shape= to the first layer is
        # deprecated in current Keras releases.
        layers.Input(shape=input_shape),

        # First Conv Block
        layers.Conv1D(64, kernel_size=10, activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling1D(pool_size=2),
        layers.Dropout(0.2),

        # Second Conv Block
        layers.Conv1D(128, kernel_size=7, activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling1D(pool_size=2),
        layers.Dropout(0.2),

        # Third Conv Block
        layers.Conv1D(256, kernel_size=5, activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling1D(pool_size=2),
        layers.Dropout(0.3),

        # Fourth Conv Block (global pooling collapses the time axis)
        layers.Conv1D(128, kernel_size=3, activation='relu'),
        layers.BatchNormalization(),
        layers.GlobalAveragePooling1D(),

        # Dense Layers
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(64, activation='relu'),
        layers.Dropout(0.3),

        # Output Layer
        layers.Dense(num_classes, activation='softmax'),
    ])
    return model
def train_deep_learning_model(
    X_train,
    X_val,
    y_train,
    y_val,
    num_classes,
    *,
    epochs=100,
    batch_size=32,
    learning_rate=0.001,
    early_stop_patience=15,
    lr_patience=5,
):
    """
    Train 1D CNN model on bearing data.

    Builds the network with build_1d_cnn, compiles it with the Adam optimizer
    and categorical cross-entropy loss, and fits it with early stopping
    (best weights restored) and learning-rate reduction, both monitoring
    val_loss.

    Args:
        X_train: Training signals (samples, signal_length)
        X_val: Validation signals
        y_train: Training labels (one-hot encoded)
        y_val: Validation labels (one-hot encoded)
        num_classes: Number of classes
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size passed to model.fit.
        learning_rate: Initial learning rate for the Adam optimizer.
        early_stop_patience: Epochs without val_loss improvement before
            training stops.
        lr_patience: Epochs without val_loss improvement before the learning
            rate is halved.

    Returns:
        trained_model, history

    Raises:
        ValueError: If y_train or y_val exposes a shape that is not
            (samples, num_classes).
    """
    # Fail fast on labels that cannot work with categorical_crossentropy
    # (e.g. integer class ids) instead of erroring deep inside model.fit.
    for name, labels in (("y_train", y_train), ("y_val", y_val)):
        label_shape = getattr(labels, "shape", None)
        if label_shape is None:
            continue  # no shape attribute; leave the check to Keras
        label_shape = tuple(label_shape)
        if len(label_shape) != 2 or label_shape[1] not in (None, num_classes):
            raise ValueError(
                f"{name} must be one-hot encoded with shape "
                f"(samples, {num_classes}), got shape {label_shape}"
            )

    # Reshape for CNN: (samples, timesteps, channels)
    X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
    X_val = X_val.reshape(X_val.shape[0], X_val.shape[1], 1)
    input_shape = (X_train.shape[1], 1)

    print("\nBuilding 1D CNN...")
    print(f" Input shape: {input_shape}")
    print(f" Number of classes: {num_classes}")

    # Build model
    model = build_1d_cnn(input_shape, num_classes)

    # Compile
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # Print model summary
    model.summary()

    # Callbacks
    early_stop = callbacks.EarlyStopping(
        monitor='val_loss',
        patience=early_stop_patience,
        restore_best_weights=True,
        verbose=1,
    )
    reduce_lr = callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=lr_patience,
        min_lr=1e-7,
        verbose=1,
    )

    # Train
    print("\nTraining CNN...")
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[early_stop, reduce_lr],
        verbose=1,
    )
    return model, history
def plot_training_history(history, output_dir):
    """
    Plot training history.

    Draws accuracy and loss curves side by side, saves the figure as
    cnn_training_history.png inside output_dir (created if missing), prints
    the saved path and shows the figure.

    Args:
        history: Object whose .history attribute maps metric names
            ('accuracy', 'loss' and optionally 'val_accuracy', 'val_loss')
            to per-epoch values.
        output_dir: Directory the PNG is written to.
    """
    import os  # local import: used only to create the output directory

    metrics = history.history
    out_path = f'{output_dir}/cnn_training_history.png'

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        panels = (
            (ax1, 'accuracy', 'Accuracy', 'CNN Training Accuracy'),
            (ax2, 'loss', 'Loss', 'CNN Training Loss'),
        )
        for ax, key, ylabel, title in panels:
            ax.plot(metrics[key], label='Train', linewidth=2)
            # Validation curves exist only if the model was fit with
            # validation data; skip them otherwise instead of raising.
            val_key = f'val_{key}'
            if val_key in metrics:
                ax.plot(metrics[val_key], label='Validation', linewidth=2)
            ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
            ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()

        # savefig does not create missing directories.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        print(f" ✓ Training history saved: {out_path}")
        plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open
        # figures. In interactive mode show() returns immediately, so the
        # figure is left open there for the user to inspect.
        if not plt.isinteractive():
            plt.close(fig)