julia_global - HazyResearch/dimmwitted GitHub Wiki

Julia Support: Can my gradient function access global variables, e.g., the stepsize?

In the logistic regression example we just showed here, the gradient function uses the same stepsize (i.e., 0.00001) in all iterations. In real applications, we might want to use a different stepsize in each iteration; more generally, we might want to use global variables inside the gradient function. In this tutorial, we show how to do this. The code can be found here.

Pre-requisites: To follow this tutorial, we assume that you are already familiar with the Julia walkthrough, and know how to write a logistic regression model whose loss and gradient functions do not access global variables.

What We Cannot Do

As of now, you cannot use the following approach, which may seem natural:

stepsize = 0.00001
function grad(row::Array{Cdouble,1}, model::Array{Cdouble,1})
  global stepsize   # does not work when grad is invoked from DimmWitted
end

The reason is discussed in Julia's documentation. Therefore, DimmWitted provides the following workaround.

The Current Workaround

As a workaround for accessing global variables, we extend the signature of the loss and gradient functions to take one more input: an array of an arbitrary immutable type. For example, to pass the stepsize, we can define

immutable SHARED_DATA
	stepsize::Cdouble
	decay::Cdouble
end
# a one-element array that holds the shared state
shared_data = Array(SHARED_DATA, 1)
shared_data[1] = SHARED_DATA(0.00001, 0.99)
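
Because SHARED_DATA is an immutable type, its fields cannot be modified in place; to change the stepsize later, construct a new instance and store it back into the one-element array. For example (the new value here is only an illustration):

# replace the whole element to "update" the shared state
shared_data[1] = SHARED_DATA(0.000005, 0.99)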

Accordingly, the gradient function becomes

function grad(row::Array, model::Array, _shared_data::Array{SHARED_DATA,1})
	const stepsize = _shared_data[1].stepsize
	const label = row[length(row)]   # the last entry of each row is the label
	const nfeat = length(model)
	# dot product between the features and the model
	d = 0.0
	for i = 1:nfeat
		d = d + row[i]*model[i]
	end
	d = exp(-d)
	# stochastic gradient step for logistic regression
	Z = stepsize * (-label + 1.0/(1.0+d))
	for i = 1:nfeat
		model[i] = model[i] - row[i] * Z
	end
	return 1.0
end

Note that this gradient function takes an additional input variable called _shared_data.
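
The loss function must follow the same extended schema, even if it does not use the shared data. As a minimal sketch, assuming the same logistic loss as in the walkthrough (it is registered with register_row2 in the same way, as shown below):

function loss(row::Array, model::Array, _shared_data::Array{SHARED_DATA,1})
	const label = row[length(row)]
	const nfeat = length(model)
	# dot product between the features and the model
	d = 0.0
	for i = 1:nfeat
		d = d + row[i]*model[i]
	end
	# logistic loss for this row; _shared_data is accepted but unused here
	return (-label * d + log(exp(d) + 1.0))
end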

When creating the DimmWitted object, you need to pass the shared_data object we just created as the last argument:

dw = DimmWitted.open(examples, model, 
                DimmWitted.MR_PERMACHINE,    
                DimmWitted.DR_SHARDING,      
                DimmWitted.AC_ROW, shared_data)

The last twist is that you can no longer use register_row to register the function; instead, you need to use a function called register_row2:

handle_grad = DimmWitted.register_row2(dw, grad)
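
Putting everything together, you can now vary the stepsize across epochs by replacing the shared_data element before each pass over the data. A minimal sketch, assuming the same DimmWitted.exec-based training loop as in the logistic regression walkthrough (the 10-epoch decay schedule is our illustration):

stepsize = 0.00001
decay = 0.99
for iepoch = 1:10
	# install this epoch's stepsize, then take one pass over the data
	shared_data[1] = SHARED_DATA(stepsize, decay)
	DimmWitted.exec(dw, handle_grad)
	stepsize = stepsize * decay
end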