-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathanalysis_classify_intent.py
More file actions
50 lines (39 loc) · 2.32 KB
/
Copy pathanalysis_classify_intent.py
File metadata and controls
50 lines (39 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import pandas as pd
from sklearn.metrics import accuracy_score
def _most_common_label(predictions):
    """Return the modal value of *predictions*, or a placeholder if it has no mode."""
    modes = predictions.mode()
    if modes.empty:
        return 'No Common Misclassification'
    return modes.iloc[0]


def _most_common_share(predictions):
    """Return the share of *predictions* equal to its modal value (0.0 if no mode)."""
    modes = predictions.mode()
    if modes.empty:
        return 0.0
    return (predictions == modes.iloc[0]).sum() / len(predictions)


def calculate_misclassification_statistics(data):
    """Compute per-category misclassification statistics.

    Args:
        data: DataFrame with a ``category`` column, a ``predicted_category``
            column, and a ``misclassified`` column that is used as a boolean
            row mask.

    Returns:
        A 3-tuple of Series indexed by category:

        * fraction of each category's rows that are flagged as misclassified
          (0.0 for categories that have no misclassified rows);
        * the most frequent ``predicted_category`` among each category's
          misclassified rows;
        * the share of each category's misclassified rows that carry that
          most frequent prediction.

        The last two only contain categories with at least one misclassified
        row.
    """
    # Filter to only misclassified samples
    misclassified_data = data[data['misclassified']]
    wrong_predictions = misclassified_data.groupby('category')['predicted_category']

    # Find the most common misclassification for each category
    most_common_misclassification = wrong_predictions.agg(_most_common_label)

    # Calculate the fraction of misclassified prompts per category.
    # Categories without any misclassified row are missing from the counts,
    # so the aligned division yields NaN for them; report those as 0.0.
    misclassification_counts = misclassified_data.groupby('category').size()
    total_counts = data['category'].value_counts()
    misclassification_fraction = (misclassification_counts / total_counts).fillna(0.0)

    # Calculate the fraction of most common misclassifications over total
    # misclassifications. Aggregating the single column avoids handing the
    # grouping column to DataFrameGroupBy.apply.
    most_common_misclassification_fraction = wrong_predictions.agg(_most_common_share)
    # Keep the result unnamed, as the previous apply-based version returned it.
    most_common_misclassification_fraction.name = None

    return misclassification_fraction, most_common_misclassification, most_common_misclassification_fraction
def main(classified_data_file='output/classified_data_gpt_small.csv'):
    """Load classified prompts from a CSV file and print misclassification statistics.

    Args:
        classified_data_file: Path of the CSV file to analyse. Defaults to
            the path this script previously hard-coded, so calling ``main()``
            without arguments behaves as before.

    When the file has a ``split`` column, only rows whose split is ``'test'``
    are analysed. Accuracy is printed only when both the ``category`` and
    ``predicted_category`` columns are present.
    """
    # Load the classified data
    data = pd.read_csv(classified_data_file)

    # Check if 'split' column exists and filter for the test set if it does
    if 'split' in data.columns:
        data = data[data['split'] == 'test']

    # Calculate the accuracy (optional, based on available data)
    if 'predicted_category' in data.columns and 'category' in data.columns:
        accuracy = accuracy_score(data['category'], data['predicted_category'])
        print(f"Classification Accuracy: {accuracy}")

    # Calculate misclassification statistics
    (
        misclassification_fraction,
        most_common_misclassification,
        most_common_misclassification_fraction,
    ) = calculate_misclassification_statistics(data)

    # Print the results
    print("Fraction of Misclassified Prompts per Category:")
    print(misclassification_fraction)

    print("\nMost Common Misclassification for each Category:")
    print(most_common_misclassification)

    print("\nFraction of Most Common Misclassification over Total Misclassifications per Category:")
    print(most_common_misclassification_fraction)
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()