timm util code - YingkunZhou/EdgeTransformerBench GitHub Wiki
import timm
import bisect

# Parameter count, in millions, of every model timm can build, in
# timm.list_models() order.
params = [
    sum(p.numel() for p in timm.create_model(i).parameters()) / 1000000.0
    for i in timm.list_models()
]

# Append one value per line. The trailing newline keeps a later append from
# being glued onto this run's last value (the file is opened in 'a' mode).
with open('timm-1.0.12.log', 'a') as fh:
    if params:
        fh.write('\n'.join(str(p) for p in params) + '\n')

# n[k] counts models with 5*k < p <= 5*(k+1); n[0] also takes p <= 5 and
# anything above 50M parameters is not counted.
edges = [5 * (k + 1) for k in range(10)]
n = [0] * len(edges)
for p in params:
    k = bisect.bisect_left(edges, p)
    if k < len(edges):
        n[k] += 1
# Reload parameter counts (one float per line) from timm.log.
# NOTE(review): presumably produced by the parameter-count step, which writes
# 'timm-1.0.12.log' - confirm the file name. Blank lines are skipped instead
# of crashing float('').
with open('timm.log', 'r') as fh:
    params = [float(line) for line in fh if line.strip()]
# Veto substrings: a name containing any of these is reported as non-CNN by
# check_cnn() even if it also contains a CNN keyword below
# (e.g. 'bat_resnext' contains 'resnext').
vit_keword = ['eca_', 'bat_resnext']

# Substrings marking a model name as belonging to a CNN family.
cnn_keyword = [
    'convnext', 'darknet', 'densenet', 'dla', 'dpn', 'efficientnet',
    'vovnet', 'fbnet', 'gernet', 'ghostnet', 'hardcorenas', 'hrnet',
    'lcnet', 'resnet', 'mixnet', 'mnasnet',
    'mobilenet_', 'mobilenetv1', 'mobilenetv2', 'mobilenetv3', 'mobileone',
    'nf_regnet', 'regnet', 'repghostnet', 'repvgg',
    'resnest', 'resnext', 'rexnet', 'spnasnet', 'tinynet',
]


def check_cnn(name):
    """Classify a timm model name as CNN (True) or not (False).

    Any ``vit_keword`` substring in *name* forces False; otherwise the result
    is True exactly when some ``cnn_keyword`` substring occurs in *name*.
    """
    if any(veto in name for veto in vit_keword):
        return False
    return any(marker in name for marker in cnn_keyword)
# Raw lines of results-imagenet.csv, read once; check_size() scans them.
with open('results-imagenet.csv', 'r') as fh:
    refer = fh.readlines()
def check_size(name, rows=None):
    """Return the input resolution (as a string) to use for timm model *name*.

    Lookup order:
      1. the first line of *rows* (default: the module-level ``refer`` lines)
         whose first comma-separated field, with any ``.tag`` suffix removed,
         equals *name* exactly -> that line's second comma-separated field
         (presumably the img_size column - confirm against the CSV header);
      2. a trailing ``_224`` / ``_256`` in the model name;
      3. otherwise print *name* so it is visible, and fall back to ``'224'``.

    Exact matching replaces the old substring test, which let e.g.
    'resnet18' hit a row for 'skresnet18' or 'resnet18d' first and return
    that other model's size.
    """
    if rows is None:
        rows = refer
    for row in rows:
        fields = row.split(',')
        if len(fields) > 1 and fields[0].strip().partition('.')[0] == name:
            return fields[1]
    suffix = name.split('_')[-1]
    if suffix == '224' or suffix == '256':
        return suffix
    print(name)
    return '224'
import timm
# Build the list of small (<= 15M params) timm models, each labelled 'cnn' or
# 'vit' by check_cnn(). A '-------cnn-------' / '-------vit-------' separator
# is inserted whenever the model-family prefix changes. Entries have the form
# 'kind,name,params,size', with size coming from check_size().
with open('timm.log', 'r') as fh:
    # One count per line; assumed to follow timm.list_models() order.
    params = [float(line) for line in fh if line.strip()]

models = timm.list_models()
# zip() would silently truncate. A length mismatch means the log was not
# produced from this model list, so every pairing would be suspect.
if len(models) != len(params):
    raise ValueError(
        f'timm.log has {len(params)} entries but timm.list_models() '
        f'returned {len(models)}; regenerate timm.log for this timm version'
    )

# Names containing one of these substrings are grouped under it. The first
# match wins, so 'resnetv2' must stay ahead of 'resnet'. Other names use the
# text before their first underscore.
_FAMILY_KEYS = ('densenet', 'dla', 'dpn', 'halonet', 'resnest', 'resnetv2', 'resnet')


def _family_prefix(model_name):
    """Return the grouping prefix that decides where separators go."""
    for key in _FAMILY_KEYS:
        if key in model_name:
            return key
    return model_name.split('_')[0]


candidate = []
prev = ''
for name, param in zip(models, params):
    # 'not <=' rather than '>' keeps NaN counts excluded, as before.
    if 'test_' in name or not param <= 15:
        continue
    prefix = _family_prefix(name)
    if 'tf_' in name:
        # Names containing 'tf_' inherit the previous prefix, so they never
        # open a new separator section.
        prefix = prev
    item = name + ',' + str(param) + ',' + check_size(name)
    kind = 'cnn' if check_cnn(name) else 'vit'
    if prefix != prev:
        candidate.append(f'-------{kind}-------')
    candidate.append(kind + ',' + item)
    prev = prefix

with open('timm-15M.log', 'w') as fh:
    fh.write('\n'.join(candidate))
import subprocess
# Run python/convert.py --get-metrics for every model line of
# timm-1.0.12-15M.log. Lines are presumably 'kind,name,params,size' as
# produced by the candidate step; separator lines contain '----'.
with open("timm-1.0.12-15M.log") as fh:
    lines = fh.readlines()
for l in lines:
    if "----" in l or not l.strip():
        continue
    w = l.strip().split(',')
    model = w[1]
    size = w[-1]
    # An argv list instead of a shell string means file contents are never
    # shell-parsed. stderr is discarded, as the original `2>/dev/null` did.
    # Failures are ignored (best effort), matching the original.
    subprocess.run(
        [
            "python", "python/convert.py",
            "--extern-model", f"{model},{size}",
            "--non-pretrained", "--get-metrics",
        ],
        stderr=subprocess.DEVNULL,
    )
Bucket the per-model values logged in tmp.log:
import bisect

# Histogram of the numeric lines in tmp.log, handled per line as follows:
#   - lines starting with 'C' set the current model name (the last
#     whitespace-separated token);
#   - lines starting with '[' are ignored;
#   - blank lines are skipped;
#   - every other line is parsed as a float and counted into n:
#     n[0]: <= 1000, n[1]: (1000, 2000], n[2]: (2000, 3000],
#     n[3]: (3000, 4000], n[4]: > 4000.
with open('tmp.log', 'r') as fh:
    lines = fh.readlines()
edges = [1000, 2000, 3000, 4000]
n = [0] * (len(edges) + 1)
name = ''
# (model name, value) pairs for ad-hoc inspection, e.g.
# [m for m, g in gmacs_by_model if g > 2000].
gmacs_by_model = []
for l in lines:
    if not l.strip():
        continue
    if l[0] == 'C':
        name = l.strip().split()[-1]
    elif l[0] == '[':
        continue
    else:
        gmacs = float(l.strip())
        gmacs_by_model.append((name, gmacs))
        n[bisect.bisect_left(edges, gmacs)] += 1