Events and Dependencies - rshipley160/learn-cuda GitHub Wiki
Previous: Basic Synchronization Methods
As we alluded to in the previous tutorial, this one covers events and how they can be used to order the execution of asynchronous streams so that no data hazards are introduced.
Setup
To help us understand events and other concepts moving forward, we are going to work through a brand new problem: the quadratic formula. Specifically, we are going to implement a vectorized version of the quadratic formula that can calculate the solutions to several hundred quadratic equations at once. Our algorithm will also use multiple streams so that independent parts of the computation can run in parallel.
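For reference, the formula we are vectorizing is the standard quadratic formula, which gives the roots of $ax^2 + bx + c = 0$:

$$x = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}$$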
How do we parallelize the quadratic formula?
The first way we can parallelize the quadratic formula is by operating on many sets of operands at once, i.e., vectorizing the operation, as is common practice when parallelizing work on a GPU. To do this, we simply store our operands in arrays rather than keeping track of individual values, and each thread then uses its unique ID to determine which operands it works on.
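As a minimal sketch of this indexing pattern, the globalIndex() call used by the kernels below can be thought of as a small device helper that computes a unique one-dimensional index for each thread (the repository's actual definition may differ slightly):

__device__ int globalIndex() {
    // Standard 1-D global thread index: block offset plus thread offset within the block
    return blockIdx.x * blockDim.x + threadIdx.x;
}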
Next, let's determine which parts of the quadratic formula can be performed in parallel with one another. To do this, we can create a computational graph that represents the quadratic formula. Computational graphs are simple diagrams that show the flow of a complex computation as a series of smaller, interconnected operations, with operations represented as nodes and the inputs and outputs of those operations represented as edges. If you look at the formula and diagram to the right, you should fairly quickly be able to begin "reading" the graph and see that it does indeed represent the quadratic formula.
The graph is important because it lets us quickly see dependencies between operations as well as opportunities for concurrent execution. Dependencies appear as the vertical links in the graph: operations lower in the graph depend on operations farther up. Similarly, operations that sit at the same vertical level of the graph can be performed concurrently as soon as the operations they depend on have completed. This allows our initial two operations, calculating b squared and -4ac, to be done at the same time, as well as the later addition and subtraction of b and tmp.
tmp is used as temporary storage after the square root operation so that we do not have to further synchronize the work being done. If we were to use one of our solution arrays sol1 or sol2, we would no longer be able to do the addition and subtraction that come after the square root in parallel because it would introduce a data hazard - one of the operations could potentially modify the data the other is using as an input!
Arithmetic kernels
As we said in the last section, each of the nodes in the graph represents a discrete arithmetic function that is being performed on all of the elements in the input arrays, i.e., a GPU kernel. Because of this, we need a separate kernel for each type of arithmetic function we want to perform.
While each of these kernels is simple, there are quite a few of them, so to keep our code organized we are going to separate the functional kernels from the program driver code by putting them in the file vector_arithmetic.cuh. Here is a look at some of the arithmetic kernels in our new file:
// Element-wise sum: out[i] = a[i] + b[i]
__global__ void elementwiseSum(float *a, float *b, float *out, int numElements) {
    int id = globalIndex();
    if (id < numElements)
        out[id] = a[id] + b[id];
}

// Element-wise difference: difference[i] = minuend[i] - subtrahend[i]
__global__ void elementwiseDifference(float *minuend, float *subtrahend, float *difference, int numElements) {
    int id = globalIndex();
    if (id < numElements)
        difference[id] = minuend[id] - subtrahend[id];
}

// Element-wise square root: out[i] = sqrt(a[i]), with non-positive inputs clamped to 0
__global__ void elementwiseSqrt(float *a, float *out, int numElements) {
    int id = globalIndex();
    if (id < numElements) {
        if (a[id] <= 0)
            out[id] = 0;
        else
            out[id] = sqrtf(a[id]);
    }
}
As you can see, they are all very simple, which is why we are not discussing them in detail. Also included in this header file are the array fill kernel and the globalIndex helper we've seen before. You can check out the full header file in the repository folder for this tutorial.
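The driver code below also uses a couple of kernels that aren't shown above, elementScalarProduct and elementwiseQuotient. As a rough sketch of what they might look like based on how they are called (the exact implementations live in the repository header, and the role of the scalar argument here is an assumption):

// Sketch only: multiply two arrays element-wise and scale by a constant,
// so that elementScalarProduct(a, c, -4, out, n) computes out[i] = -4 * a[i] * c[i]
__global__ void elementScalarProduct(float *a, float *b, float scalar, float *out, int numElements) {
    int id = globalIndex();
    if (id < numElements)
        out[id] = scalar * a[id] * b[id];
}

// Sketch only: divide one array by another element-wise and scale by a constant,
// so that elementwiseQuotient(sol1, a, 0.5, out, n) computes out[i] = 0.5 * sol1[i] / a[i],
// which is how the final division by 2a is expressed in the driver code
__global__ void elementwiseQuotient(float *numerator, float *denominator, float scale, float *out, int numElements) {
    int id = globalIndex();
    if (id < numElements)
        out[id] = scale * numerator[id] / denominator[id];
}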
Synchronization Using Events
Now that we've introduced the problem and the elements we are going to use to solve it, let's go about actually implementing our concurrent quadratic solver using events.
First, let's just lay out our kernels in sequential order:
// Concurrent
elementwiseProduct<<<gridSize, BLOCK_SIZE>>>(b, b, sol1, NUM_ELEMENTS);
elementScalarProduct<<<gridSize, BLOCK_SIZE>>>(a, c, -4, sol2, NUM_ELEMENTS);
// Wait on product calculations to complete
elementwiseSum<<<gridSize, BLOCK_SIZE>>>(sol1, sol2, sol1, NUM_ELEMENTS);
elementwiseSqrt<<<gridSize, BLOCK_SIZE>>>(sol1, tmp, NUM_ELEMENTS);
// Concurrent, wait on sqrt before starting
elementwiseDifference<<<gridSize, BLOCK_SIZE>>>(b, tmp, sol1, NUM_ELEMENTS);
elementwiseSum<<<gridSize, BLOCK_SIZE>>>(b, tmp, sol2, NUM_ELEMENTS);
//Concurrent
elementwiseQuotient<<<gridSize, BLOCK_SIZE>>>(sol1, a, 0.5, sol1, NUM_ELEMENTS); // wait on last difference to complete before running
elementwiseQuotient<<<gridSize, BLOCK_SIZE>>>(sol2, a, 0.5, sol2, NUM_ELEMENTS); // wait on last sum to complete before running
You should be able to compare this to the graph above and come to the conclusion that this is in fact the same algorithm laid out by the computational graph in the first section. In addition to laying out the computation kernels, I have also added comments that point out which kernels can run concurrently and what they must wait on before executing, which you can also verify against the computational graph.
Now let's set about adding the concurrency (and thus the synchronization) to this quadratic solver. To start, we need to create two streams, because we only have two concurrent operations at a time, and two events to signal completion of some task on each stream. I've named the streams bPlus and bMinus in reference to the penultimate stage of the quadratic formula, in which one solution comes from adding b and the square root of the discriminant and the other comes from subtracting it. Similarly, I've named the events bPlusComplete and bMinusComplete because they will be associated with the completion of some task on each stream.
cudaStream_t bMinus;
cudaStream_t bPlus;
cudaStreamCreate(&bMinus);
cudaStreamCreate(&bPlus);
cudaEvent_t bPlusComplete;
cudaEvent_t bMinusComplete;
cudaEventCreate(&bPlusComplete);
cudaEventCreate(&bMinusComplete);
With our streams created, we can now make all of the tasks we identified as concurrent actually concurrent by placing them on different streams:
// Concurrent
elementwiseProduct<<<gridSize, BLOCK_SIZE, 0, bMinus>>>(b, b, sol1, NUM_ELEMENTS);
elementScalarProduct<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(a, c, -4, sol2, NUM_ELEMENTS);
// Wait on product calculations to complete
elementwiseSum<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(sol1, sol2, sol1, NUM_ELEMENTS);
elementwiseSqrt<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(sol1, tmp, NUM_ELEMENTS);
// Concurrent, wait on sqrt before starting
elementwiseDifference<<<gridSize, BLOCK_SIZE, 0, bMinus>>>(b, tmp, sol1, NUM_ELEMENTS);
elementwiseSum<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(b, tmp, sol2, NUM_ELEMENTS);
//Concurrent
elementwiseQuotient<<<gridSize, BLOCK_SIZE, 0, bMinus>>>(sol1, a, 0.5, sol1, NUM_ELEMENTS); // wait on last difference to complete before running
elementwiseQuotient<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(sol2, a, 0.5, sol2, NUM_ELEMENTS); // wait on last sum to complete before running
Now we have implemented some level of concurrency, but at this point race conditions can run rampant in our code - for example, the first elementwiseSum may be reached by bPlus before the elementwiseProduct kernel on the bMinus stream has been able to complete, which would cause the result to be inaccurate.
Enter events: just as we discussed when timing our performance experiments early on, events are markers in the timeline of a stream's execution. In addition to using them to time the completion of a task, we can also instruct a stream to wait on an event in another stream's timeline before continuing, giving us a more fine-grained, peer-to-peer synchronization method than stream or device synchronization allow. Unlike both of those methods, it also has the benefit of not blocking the host thread.
To synchronize the first elementwiseSum kernel after our concurrent product kernels, we need the sum kernel to wait on both the bPlus and bMinus streams to finish their multiplication tasks. This happens implicitly for the bPlus stream, because the sum kernel is added to the bPlus stream after the product kernel, and all tasks in the same stream are executed sequentially. That means we only have to explicitly synchronize the sum kernel with the product kernel on the bMinus stream.
Recording Events
First, we have to record an event that occurs after the bMinus stream has completed its task and is ready to move on. We do this by recording the event in the bMinus stream. Note that the second parameter of cudaEventRecord is bMinus instead of 0: this is because we want to record the event in the bMinus stream rather than in the default stream.
// Concurrent
elementwiseProduct<<<gridSize, BLOCK_SIZE, 0, bMinus>>>(b, b, sol1, NUM_ELEMENTS);
elementScalarProduct<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(a, c, -4, sol2, NUM_ELEMENTS);
cudaEventRecord(bMinusComplete, bMinus); // Signal bMinus completion of product kernel
Waiting on Events
Now we have to make the next kernel launch on bPlus actually wait for the event we just recorded; recording the event by itself does not cause this to happen. We do this by calling cudaStreamWaitEvent just before the item that requires synchronization. You'll notice that cudaStreamWaitEvent takes two main parameters: a stream and an event (a third flags parameter defaults to 0). The stream is the one that should delay its execution, and the event is the event to be waited on. By passing it bPlus and bMinusComplete, we are effectively making bPlus wait on the bMinus stream to complete a task. The same method can be used to make streams wait on each other in a peer-to-peer fashion, as in this case, as well as in many-to-one or one-to-many configurations.
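For example, a single recorded event can gate several streams at once. Here is a small sketch of a one-to-many configuration; the names producerStream, workerStreams, inputReady, and NUM_WORKERS are hypothetical and not part of the quadratic solver:

// Hypothetical one-to-many example: several worker streams wait on one event
// recorded in a producer stream before launching their own work.
cudaEventRecord(inputReady, producerStream);
for (int i = 0; i < NUM_WORKERS; i++) {
    cudaStreamWaitEvent(workerStreams[i], inputReady, 0);
    // ... enqueue kernels on workerStreams[i] here ...
}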
// Wait on product calculations to complete
cudaStreamWaitEvent(bPlus, bMinusComplete);
elementwiseSum<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(sol1, sol2, sol1, NUM_ELEMENTS);
We repeat this process, although reversed, in order to make the bMinus stream wait upon the completion of the square root kernel later on:
elementwiseSqrt<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(sol1, tmp, NUM_ELEMENTS);
cudaEventRecord(bPlusComplete, bPlus);
// Concurrent, wait on sqrt before starting
cudaStreamWaitEvent(bMinus, bPlusComplete);
elementwiseDifference<<<gridSize, BLOCK_SIZE, 0, bMinus>>>(b, tmp, sol1, NUM_ELEMENTS);
elementwiseSum<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(b, tmp, sol2, NUM_ELEMENTS);
All in all, our program driver code now looks a bit like this:
cudaStream_t bMinus;
cudaStream_t bPlus;
cudaStreamCreate(&bMinus);
cudaStreamCreate(&bPlus);
cudaEvent_t bPlusComplete;
cudaEvent_t bMinusComplete;
cudaEventCreate(&bPlusComplete);
cudaEventCreate(&bMinusComplete);
// Concurrent
elementwiseProduct<<<gridSize, BLOCK_SIZE, 0, bMinus>>>(b, b, sol1, NUM_ELEMENTS);
elementScalarProduct<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(a, c, -4, sol2, NUM_ELEMENTS);
cudaEventRecord(bMinusComplete, bMinus); // Signal bMinus completion of product kernel
// Wait on product calculations to complete
cudaStreamWaitEvent(bPlus, bMinusComplete);
elementwiseSum<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(sol1, sol2, sol1, NUM_ELEMENTS);
elementwiseSqrt<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(sol1, tmp, NUM_ELEMENTS);
cudaEventRecord(bPlusComplete, bPlus);
// Concurrent, wait on sqrt before starting
cudaStreamWaitEvent(bMinus, bPlusComplete);
elementwiseDifference<<<gridSize, BLOCK_SIZE, 0, bMinus>>>(b, tmp, sol1, NUM_ELEMENTS);
elementwiseSum<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(b, tmp, sol2, NUM_ELEMENTS);
//Concurrent
elementwiseQuotient<<<gridSize, BLOCK_SIZE, 0, bMinus>>>(sol1, a, 0.5, sol1, NUM_ELEMENTS); // wait on last difference to complete before running
elementwiseQuotient<<<gridSize, BLOCK_SIZE, 0, bPlus>>>(sol2, a, 0.5, sol2, NUM_ELEMENTS); // wait on last sum to complete before running
It should be noted that the final kernel in each stream does not need to be explicitly synchronized, because the kernels before them are already properly synchronized and, as mentioned previously, all kernels enqueued in a stream execute sequentially.
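One housekeeping detail not shown in the snippet above, which the full example in the repository may handle differently: once the results have been consumed, the streams and events should be cleaned up.

// After all work has been enqueued, wait for both streams to finish,
// then release the events and streams.
cudaStreamSynchronize(bMinus);
cudaStreamSynchronize(bPlus);
cudaEventDestroy(bPlusComplete);
cudaEventDestroy(bMinusComplete);
cudaStreamDestroy(bPlus);
cudaStreamDestroy(bMinus);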
Feel free to explore the full example on your own to get a better sense of how the synchronization works; otherwise, this tutorial is complete. Check out the next article to see what kind of performance gains event-based synchronization can offer over the other methods we have learned.