diff --git a/results/generate_csv_results.py b/results/generate_csv_results.py index 04cf710a..70fd4588 100644 --- a/results/generate_csv_results.py +++ b/results/generate_csv_results.py @@ -62,13 +62,13 @@ def diff(base_df, test_csv): test_df['rank_diff'] = rank_diff test_df['param_count'] = test_df['param_count'].map('{:,.2f}'.format) - test_df.sort_values('top1', ascending=False, inplace=True) + test_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True) test_df.to_csv(test_csv, index=False, float_format='%.3f') for base_results, test_results in results.items(): base_df = pd.read_csv(base_results) - base_df.sort_values('top1', ascending=False, inplace=True) + base_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True) for test_csv in test_results: diff(base_df, test_csv) base_df['param_count'] = base_df['param_count'].map('{:,.2f}'.format)