-- test-bench.lua
-- FFI library; used below (ffi.cdef) to register the TH C declarations.
local ffi = require 'ffi'
-- Size of the first tensor dimension passed to creatv(); also printed by
-- benchmark() as the "input size".
local SZ1 = 2
-- Number of times benchmark() invokes the function under test.
local N = 10000000
-- Alternative settings kept for reference. While the SZ2/SZ3 lines stay
-- commented out, nothing visible in this file defines SZ2 or SZ3, so later
-- references to them evaluate to nil and creatv() takes its 1-D branch.
--local N = 10000
--local SZ1 = 500
--local SZ2 = 210
--local SZ3 = 130
-- The return value is not captured: the rest of the script uses the `torch`
-- global.
require 'torch'
-- C declarations registered with the FFI: the THDoubleStorage and
-- THDoubleTensor struct layouts plus the THDoubleTensor_sumall prototype.
-- NOTE(review): these presumably mirror the TH headers of the installed torch
-- build; confirm field order and types against that version before relying
-- on them.
-- The visible part of this file never calls THDoubleTensor_sumall itself.
ffi.cdef[[
typedef struct THAllocator THAllocator;
typedef struct
{
double *data;
long size;
int refcount;
char flag;
THAllocator *allocator;
void *allocatorContext;
} THDoubleStorage;
typedef struct
{
long *size;
long *stride;
int nDimension;
THDoubleStorage *storage;
long storageOffset;
int refcount;
char flag;
} THDoubleTensor;
double THDoubleTensor_sumall(THDoubleTensor *tensor);
]]
--- Build a tensor whose elements are all 0.15.
-- The rank depends on which sizes are given: only SZ1 -> 1-D, SZ1 and SZ2 ->
-- 2-D, all three -> 3-D. A missing SZ2 always selects the 1-D form, even if
-- SZ3 is supplied.
-- @param SZ1 size of the first dimension
-- @param SZ2 optional size of the second dimension
-- @param SZ3 optional size of the third dimension
-- @param isvec when truthy, the result is reshaped to a single dimension
--        holding nElement() entries
-- @return the filled (and possibly reshaped) tensor
local function creatv(SZ1, SZ2, SZ3, isvec)
   local tensor
   if not SZ2 then
      tensor = torch.Tensor(SZ1)
   elseif not SZ3 then
      tensor = torch.Tensor(SZ1, SZ2)
   else
      tensor = torch.Tensor(SZ1, SZ2, SZ3)
   end
   tensor = tensor:fill(0.15)

   if not isvec then
      return tensor
   end
   local count = tensor:nElement()
   return tensor:reshape(count)
end
-- Working tensors: x keeps the shape creatv() gives it, y and z are created
-- with isvec set and therefore reshaped to one dimension.
local x = creatv(SZ1, SZ2, SZ3)
local y, z = creatv(SZ1, SZ2, SZ3, true), creatv(SZ1, SZ2, SZ3, true)
--x = x:transpose(2,3)
-- cdata handles for the three tensors (not used further in the visible code).
local x_p, y_p, z_p = x:cdata(), y:cdata(), z:cdata()
--- Time N consecutive calls of `func` and print the elapsed wall-clock time.
-- Before timing, torch's RNG is reseeded with 1111 and x, y and z are
-- overwritten with the contents of freshly created creatv() tensors, so
-- every benchmark starts from the same data.
-- @param txt label printed in the report header
-- @param func function under test; called with no arguments, N times
-- @param endfunc optional callback run once after the timing report (e.g. to
--        print a result). It may be omitted; previously a missing callback
--        raised "attempt to call a nil value" after the whole run.
function benchmark(txt, func, endfunc)
   torch.manualSeed(1111)
   x:copy(creatv(SZ1, SZ2, SZ3))
   y:copy(creatv(SZ1, SZ2, SZ3))
   z:copy(creatv(SZ1, SZ2, SZ3))

   print('--------------------------------------------')
   print(txt .. ' input size: ' .. SZ1)

   local timer = torch.Timer()
   for _ = 1, N do
      func()
   end
   -- Read the clock right after the loop so that building the report string
   -- is not included in the measurement.
   local elapsed = timer:time().real
   print('time for ' .. N .. ' calls: ', elapsed)

   if endfunc then
      endfunc()
   end
end
-- Benchmark torch.sum over x. The result lives in a local shared by both
-- closures instead of an implicit global, so the timed closure does not
-- write to the global table on every iteration and no stray `sum` global
-- is left behind.
local sum
benchmark(
   'torch7 sumall',
   function()
      sum = torch.sum(x)
   end,
   function()
      print('sum', sum)
   end
)
-- Terminate the interpreter once the benchmark has reported.
os.exit()