julia_immutable - HazyResearch/dimmwitted GitHub Wiki

Julia Support: Can I use non-primitive data types, e.g., structs, in my data?

Yes, you can. In this tutorial, we walk through an example that uses a non-primitive type for both the data and the model. The takeaway is that as long as your data type is immutable, you can use it in DimmWitted in the same way as the primitive types. The code used in this tutorial can be found here.

Prerequisites. This tutorial assumes that you are already familiar with the Julia walkthrough and know how to write a logistic regression model in which both the data and the model are of type Cdouble.

Defining Data Type

The very first step in using a non-primitive type is to define that type. DimmWitted can use any immutable type that you define. For example,

immutable DoublePair
	d1::Cdouble
	d2::Cdouble
end

This code creates a type called DoublePair that consists of two Cdouble fields. After defining this type, we can generate the data in much the same way as in the Julia walkthrough. For example, one application is to train two logistic regression models with complementary labels in a single pass over the data. In the code below, the last column (index nfeat+1) of each row holds the label pair: (1.0, 0.0) with probability about 0.2, and (0.0, 1.0) otherwise.

nexp = 100000
nfeat = 1024
examples = Array(DoublePair, nexp, nfeat+1)
for row = 1:nexp
	for col = 1:nfeat
		examples[row, col] = DoublePair(1.0,1.0)
	end
	if rand() > 0.8
		examples[row, nfeat+1] = DoublePair(1.0,0.0)
	else
		examples[row, nfeat+1] = DoublePair(0.0,1.0)
	end
end
model = DoublePair[DoublePair(0.0,0.0) for i = 1:nfeat]
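The data generation above is written for pre-1.0 Julia. On Julia 1.x, where `immutable` and `Array(T, dims...)` are no longer accepted, a minimal equivalent sketch looks like the following; the sizes are reduced and the seed is fixed only so that it runs quickly and reproducibly. It also checks the property the experiment relies on: the two labels in each row are complementary, and label1 is 1 in roughly 20% of rows.

```julia
using Random

struct DoublePair   # Julia 1.x spelling of `immutable DoublePair`
    d1::Cdouble
    d2::Cdouble
end

Random.seed!(1234)   # fixed seed so the check is reproducible
nexp = 10_000        # smaller than the tutorial's sizes, for a quick run
nfeat = 16

# Julia 1.x replacement for `Array(DoublePair, nexp, nfeat+1)`
examples = Array{DoublePair}(undef, nexp, nfeat + 1)
for row = 1:nexp
    for col = 1:nfeat
        examples[row, col] = DoublePair(1.0, 1.0)   # every feature is 1.0
    end
    # label pair in the last column: (1, 0) with probability ~0.2, else (0, 1)
    examples[row, nfeat + 1] = rand() > 0.8 ? DoublePair(1.0, 0.0) : DoublePair(0.0, 1.0)
end
model = [DoublePair(0.0, 0.0) for i = 1:nfeat]

# Check the labels: complementary in every row, label1 = 1 in ~20% of rows
labels = examples[:, end]
frac1 = count(p -> p.d1 == 1.0, labels) / nexp
println("fraction of rows with label1 = 1: ", frac1)
```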

Defining Functions

With the new data type, writing a function such as the loss only requires changing its signature and reading the fields of each element, for example:

function loss(row::Array{DoublePair,1}, model::Array{DoublePair,1})
	const label1 = row[length(row)].d1
	const label2 = row[length(row)].d2
	const nfeat = length(model)
	d1 = 0.0
	d2 = 0.0
	for i = 1:nfeat
		d1 = d1 + row[i].d1*model[i].d1
		d2 = d2 + row[i].d2*model[i].d2
	end
	v1 = (-label1 * d1 + log(exp(d1) + 1.0))
	v2 = (-label2 * d2 + log(exp(d2) + 1.0))
	return v1 + v2
end

Compared with the Julia walkthrough, this function computes two logistic losses, one per model, and returns their sum. You can write the gradient function in the same way as the loss function.
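A gradient function for the pair type can be sketched along the same lines. The sketch is in Julia 1.x syntax and assumes a contract the tutorial does not spell out: the gradient function (here given the hypothetical name grad) takes the same arguments as loss, updates model in place, and returns a Cdouble. The step size STEP is hypothetical, and the loop at the end is a serial stand-in for DimmWitted, not its API, used only to check that the updates move the two weight sums in opposite directions.

```julia
using Random

struct DoublePair   # Julia 1.x spelling of `immutable DoublePair`
    d1::Cdouble
    d2::Cdouble
end

const STEP = 0.0005   # hypothetical step size

# One SGD step for both models, mirroring `loss`. Assumed contract: same
# arguments as `loss`, update `model` in place, return a Cdouble.
function grad(row::Vector{DoublePair}, model::Vector{DoublePair})
    label1 = row[end].d1
    label2 = row[end].d2
    nfeat = length(model)
    d1 = 0.0
    d2 = 0.0
    for i = 1:nfeat
        d1 += row[i].d1 * model[i].d1
        d2 += row[i].d2 * model[i].d2
    end
    # derivative of -y*d + log(1 + exp(d)) with respect to d is sigmoid(d) - y
    g1 = 1.0 / (1.0 + exp(-d1)) - label1
    g2 = 1.0 / (1.0 + exp(-d2)) - label2
    for i = 1:nfeat
        # DoublePair is immutable, so store a new value instead of mutating
        model[i] = DoublePair(model[i].d1 - STEP * g1 * row[i].d1,
                              model[i].d2 - STEP * g2 * row[i].d2)
    end
    return 1.0
end

# Hypothetical serial stand-in for DimmWitted: stream synthetic rows
Random.seed!(7)
nfeat = 16
model = [DoublePair(0.0, 0.0) for i = 1:nfeat]
row = [DoublePair(1.0, 1.0) for i = 1:(nfeat + 1)]   # features are all 1.0
for t = 1:20_000
    row[end] = rand() > 0.8 ? DoublePair(1.0, 0.0) : DoublePair(0.0, 1.0)
    grad(row, model)
end
sum1 = sum(p.d1 for p in model)
sum2 = sum(p.d2 for p in model)
println("SUM OF MODEL1: ", sum1)
println("SUM OF MODEL2: ", sum2)
```

Because label2 = 1 - label1 and both models see identical features, the two updates mirror each other, so this serial loop ends with sum2 ≈ -sum1.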

Getting the result!

After you define the functions, you do not need to change anything else to run DimmWitted! To validate the result, we can check the output of the following code after training:

sum1 = 0.0
sum2 = 0.0
for i = 1:length(model)
	sum1 = sum1 + model[i].d1
	sum2 = sum2 + model[i].d2
end
println("SUM OF MODEL1: ", sum1)
println("SUM OF MODEL2: ", sum2)

The output should be close to

SUM OF MODEL1: -1.2475164859764791
SUM OF MODEL2: 1.2474677445561093

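This is consistent with the synthetic data we generated: since every feature is 1.0, each model's prediction depends only on the sum of its weights. Label1 is 1 in only about 20% of the rows, which drives the first sum negative, and label2 is its complement, which drives the second sum positive by about the same amount.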
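The target values can also be worked out in closed form. With every feature equal to 1.0, a model's score on every row is the sum S of its weights; if a fraction p of the rows have label 1, the expected per-row loss and its minimizer are:

$$
\bar{L}(S) = -pS + \log\left(1 + e^{S}\right), \qquad
\bar{L}'(S) = \frac{e^{S}}{1 + e^{S}} - p = 0
\quad\Longrightarrow\quad
S^{*} = \log\frac{p}{1 - p}
$$

$$
S_1^{*} = \log\frac{0.2}{0.8} = \log 0.25 \approx -1.386, \qquad
S_2^{*} = \log\frac{0.8}{0.2} = \log 4 \approx 1.386
$$

The reported sums have these signs and are nearly symmetric, but are somewhat smaller in magnitude than S*. That would be consistent with, for example, training that stops before full convergence; the tutorial does not say how many epochs it runs.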

Possible Pitfalls

There are a couple of things you need to keep in mind:

  • In DimmWitted v0.01, you must use immutable. For now, you cannot use type, tuple, or other ways of constructing composite types.
  • Because immutable values cannot be modified in place, you cannot write things like
model[i].d1 = 5.0

Instead, you need to construct a new value and store it back in the array:

model[i] = DoublePair(5.0, model[i].d2)
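Both points can be checked directly. The sketch below uses Julia 1.x syntax, where the immutable keyword is spelled struct and type is spelled mutable struct; MutablePair, with_d1, and with_d2 are hypothetical names used only for illustration. It shows that an immutable pair of Cdoubles is a plain bits type (stored inline in arrays, 16 bytes per element) while the mutable version is not, which is plausibly why DimmWitted can treat the immutable type like a primitive (an assumption; the tutorial does not say). It also shows that assigning to a field fails and how to update an element by replacing it.

```julia
# Julia 1.x spelling of `immutable DoublePair`
struct DoublePair
    d1::Cdouble
    d2::Cdouble
end

# Hypothetical mutable counterpart (Julia 1.x spelling of `type`)
mutable struct MutablePair
    d1::Cdouble
    d2::Cdouble
end

# An immutable type with only bits-type fields is itself a bits type,
# so an array stores its values inline, two Cdoubles per element.
println(isbitstype(DoublePair), " ", sizeof(DoublePair))  # true 16
println(isbitstype(MutablePair))                          # false

model = [DoublePair(0.0, 0.0) for i = 1:4]

# Assigning to a field of an immutable value raises an error.
failed = try
    model[1].d1 = 5.0
    false
catch err
    println("error: ", sprint(showerror, err))
    true
end

# Hypothetical helpers that build a new pair differing in one field.
with_d1(p::DoublePair, x) = DoublePair(x, p.d2)
with_d2(p::DoublePair, x) = DoublePair(p.d1, x)

model[1] = with_d1(model[1], 5.0)   # same as DoublePair(5.0, model[1].d2)
model[2] = with_d2(model[2], -1.0)
println(model[1:2])
```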