Chapter 05 Shared memory and constant memory
- 1. Shared memory
- 2. Shared memory application
- 3. Constant memory
- 4. Warp shuffle instruction
- A fixed amount of shared memory is allocated to each thread block when it starts executing. This shared memory address space is shared by all threads in a thread block. Its contents have the same lifetime as the thread block in which it was created.
- Shared memory accesses are issued per warp. Ideally, each request to access shared memory by a warp is serviced in one transaction. In the worst case, each request to shared memory is executed sequentially in 32 unique transactions.
- If multiple threads access the same word in shared memory, the word is fetched once and multicast to the requesting threads.
- Shared memory is partitioned among all resident thread blocks on an SM. Therefore, shared memory is a critical resource that limits device parallelism.
You can allocate shared memory variables either statically or dynamically.
Shared memory can also be declared as either local to a CUDA kernel or globally in a CUDA source code file. If declared inside a kernel function, the scope of the variable is local to the kernel. If declared outside of any kernels in a file, the scope of this variable is global to all kernels.
- allocating statically
__shared__ float var[size_y][size_x];
- allocating dynamically
If the size of shared memory is unknown at compile time, you can declare an un-sized array with the extern keyword.
extern __shared__ float var[];
You need to dynamically allocate shared memory at each kernel invocation by specifying the desired size in bytes as a third argument.
kernel_name<<<grid, block, n * sizeof(float)>>>(...)
Note: you can only declare 1D arrays dynamically.
- What are banks? To achieve high memory bandwidth, shared memory is divided into 32 equally-sized memory modules, called banks, which can be accessed simultaneously. There are 32 banks because there are 32 threads in a warp.
- How does this influence performance? Depending on the compute capability of a GPU, the addresses of shared memory are mapped to different banks in different patterns (more on this later). If a shared memory load or store operation issued by a warp does not access more than one memory location per bank, the operation can be serviced by one memory transaction. Otherwise, the operation is serviced by multiple memory transactions, thereby decreasing memory bandwidth utilization.
When multiple addresses in a shared memory request fall into the same memory bank, a bank conflict occurs, causing the request to be replayed.
Three typical situations occur when a request to shared memory is issued by a warp:
- Parallel access: multiple addresses accessed across multiple banks; optimally, a conflict-free shared memory access is performed when every address is in a separate bank.
- Serial access: (the worst pattern) multiple addresses accessed within the same bank
- Broadcast access: a single address read in a single bank; one memory transaction is executed, and the accessed word is broadcast to all requesting threads; the bandwidth utilization is poor because only a small number of bytes are read.
Shared memory bank width defines which shared memory addresses fall into which shared memory banks. Bank width varies across devices; the two typical cases are 4-byte and 8-byte banks. For 4-byte banks, successive 32-bit words map to successive banks; for 8-byte banks, successive 64-bit words do. The bank index is computed as: bank index = (byte address / bank width) % 32. For example, with a 4-byte bank width, byte addresses 0-3 map to bank 0, addresses 4-7 map to bank 1, ..., addresses 124-127 map to bank 31, and address 128 wraps back to bank 0 (see the small sketch after the examples below).
Note:
- A bank conflict does not occur when two threads from the same warp access the same address. For read access: the word is broadcast to the requesting threads. For write access: the word is written by only one of the threads (which thread performs the write is undefined).
- Similarly, for 8-byte (64-bit) bank width, if two threads access any sub-word within the same 64-bit word, there will be no bank conflict.
Example of conflict-free access (8-byte bank width; each cell in the figure is a 4-byte word): two threads access words in the same bank but within the same 8-byte word.
Example of bank conflict: three threads access the same bank, and the addresses fall in three different 8-byte words.
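A small host-side sketch of the bank index calculation above, assuming the 4-byte bank mode:
#include <stdio.h>

int main(void) {
    const int bankWidth = 4;                 // bytes per bank (4-byte mode)
    const int numBanks  = 32;
    const int byteAddr[] = {0, 4, 60, 124, 128};
    for (int i = 0; i < 5; i++) {
        // bank index = (byte address / bank width) % 32
        printf("byte address %3d -> bank %2d\n", byteAddr[i], (byteAddr[i] / bankWidth) % numBanks);
    }
    return 0;
}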
Memory padding is one way to avoid bank conflicts.
Suppose you have only five shared memory banks. If all threads access different locations in bank 0, a five-way bank conflict occurs. One way to resolve this type of bank conflict is to add a word of padding after every N elements (N = # banks). The words that used to all belong to bank 0 are now spread across different banks because of the padding.
Note
- The padded elements are never used, causing a waste of memory.
- You need to recalculate array indices to make sure you access the correct data elements.
- When the memory bank width varies, the memory padding pattern may need to change accordingly.
Some devices support both 4-byte and 8-byte bank widths (e.g. Kepler, where the default is 4-byte mode).
The access mode can be queried using the following CUDA runtime API function:
cudaError_t cudaDeviceGetSharedMemConfig(cudaSharedMemConfig *pConfig);
The result returned in pConfig will be either cudaSharedMemBankSizeFourByte or cudaSharedMemBankSizeEightByte.
The access mode can be configured using:
cudaError_t cudaDeviceSetSharedMemConfig(cudaSharedMemConfig config);
config can be any of the following three values: cudaSharedMemBankSizeDefault, cudaSharedMemBankSizeFourByte, or cudaSharedMemBankSizeEightByte (a usage sketch follows the note below).
Note
- Changing the shared memory configuration between kernel launches might require an implicit device synchronization point.
- Changing the shared memory bank size will not increase shared memory usage or affect occupancy of kernels, but it might have a major effect on performance.
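A minimal sketch of querying and changing the bank size (only meaningful on devices, such as Kepler, that support a configurable bank width; error checking omitted):
cudaSharedMemConfig config;
cudaDeviceGetSharedMemConfig(&config);   // query the current bank size mode
if (config == cudaSharedMemBankSizeFourByte) {
    // e.g. switch to 8-byte banks before launching a kernel that mainly stores doubles in shared memory
    cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
}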
For Kepler, shared memory and the L1 cache share the same 64 KB of on-chip storage. CUDA provides two methods for configuring the split between L1 cache and shared memory: per-device configuration and per-kernel configuration.
Note: For Kepler devices, L1 cache is used for register spills, so you should prefer more L1 cache when a kernel uses more registers.
- per-device configuration
cudaError_t cudaDeviceSetCacheConfig(cudaFuncCache cacheConfig);
cacheConfig can be one of:
cudaFuncCachePreferNone   // no preference (default)
cudaFuncCachePreferShared // prefer 48 KB shared memory and 16 KB L1 cache
cudaFuncCachePreferL1     // prefer 48 KB L1 cache and 16 KB shared memory
cudaFuncCachePreferEqual  // prefer 32 KB L1 cache and 32 KB shared memory
The CUDA runtime makes a best effort to use the requested device on-chip memory configuration, but it is free to choose a different configuration if required to execute a kernel function. A per-kernel configuration can also override the device-wide setting (see below).
- per-kernel configuration
cudaError_t cudaFuncSetCacheConfig(const void *func, cudaFuncCache cacheConfig);
Launching a kernel with a different preference than the most recent preference setting might result in implicit device synchronization.
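A minimal sketch of both configuration methods; heavyRegisterKernel is a hypothetical kernel used only for illustration, and error checking is omitted:
__global__ void heavyRegisterKernel(float *data) {
    // placeholder body; imagine a kernel that spills many registers
    if (data) data[threadIdx.x] += 1.0f;
}

float *d_data;
cudaMalloc(&d_data, 128 * sizeof(float));

// per-device configuration: favor shared memory for all subsequent kernels
cudaDeviceSetCacheConfig(cudaFuncCachePreferShared);

// per-kernel configuration: this kernel spills registers, so prefer a larger L1 cache
cudaFuncSetCacheConfig((const void *)heavyRegisterKernel, cudaFuncCachePreferL1);
heavyRegisterKernel<<<1, 128>>>(d_data);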
Maxwell and Pascal provide dedicated space for the shared memory of each SM, since the functionality of the L1 and texture caches has been merged. This increases the shared memory space available per SM compared to Kepler: GP100 offers 64 KB of shared memory per SM, and GP104 provides 96 KB per SM.
Shared memory can be simultaneously accessed by multiple threads within a thread block, so intra-block synchronization is necessary to avoid race conditions.
In general, there are two basic approaches to synchronization: barriers and memory fences. At a barrier, all calling threads wait for all other calling threads to reach the barrier point. At a memory fence, all calling threads stall until all of their modifications to memory are visible to all other calling threads.
First of all, we need to know CUDA's weakly-ordered memory model:
- The order in which a GPU thread writes data to (or reads data from) different memory spaces (shared memory, global memory, page-locked host memory, or the memory of a peer device) is not necessarily the order in which those accesses appear in the source code.
- The order in which a thread’s writes become visible to other threads may not match the actual order in which those writes were performed.
You can specify a barrier point in a kernel by calling the function __syncthreads()
What does __syncthreads do?
- acting as a barrier point: threads in a block must wait until all threads have reached that point
- ensuring that all global and shared memory accesses made by these threads prior to the barrier point are visible to all threads in the same block
Note: be careful of using __syncthreads in conditional code
It is only valid to use __syncthreads inside a conditional if the condition is guaranteed to evaluate identically across the entire thread block. Otherwise execution is likely to hang or produce unintended side effects. For example, the following code segment may cause the threads in a block to wait for each other indefinitely, because they never all reach the same barrier point (a safe pattern is sketched after the example).
if (threadIdx.x % 2 == 0) {
    __syncthreads();
} else {
    __syncthreads();
}
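A minimal sketch of a safer pattern, assuming the divergent work can be separated from the barrier: keep the branch-specific work inside the conditional, but call __syncthreads() unconditionally so that every thread in the block reaches the same barrier point.
if (threadIdx.x % 2 == 0) {
    // work done only by even-numbered threads
} else {
    // work done only by odd-numbered threads
}
// every thread in the block reaches this single barrier
__syncthreads();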
There is no inter-block synchronization. If a CUDA kernel requires global synchronization across blocks, you can split the kernel apart at the synchronization point and perform multiple kernel launches. Because each successive kernel launch must wait for the preceding one to complete (more precisely, CUDA operations issued to the same stream always serialize), this produces an implicit global barrier.
Memory fence functions ensure that any memory write before the fence is visible to other threads after the fence. Note that memory fences do not perform any thread synchronization, and so it is not necessary for all threads in a block to actually execute this instruction.
There are three variants of memory fences depending on the desired scope: block, grid, or system (a usage sketch follows the list).
- __threadfence_block() creates a memory fence within a thread block
- __threadfence() creates a memory fence at the grid level; it stalls the calling thread until all of its writes to global memory are visible to all threads in the same grid.
- __threadfence_system() sets a memory fence across the system (including host and device); it stalls the calling thread to ensure all its writes to global memory, page-locked host memory, and the memory of other devices are visible to all threads in all devices and host threads.
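A minimal sketch of what the grid-level fence guarantees, loosely following the classic two-thread example (the variable names X and Y are illustrative): thread 0 publishes two writes, and the fences enforce the order in which another thread may observe them.
__device__ int X = 1;
__device__ int Y = 2;

__global__ void fenceDemo(int *out) {
    if (threadIdx.x == 0) {
        X = 10;
        __threadfence();   // the write to X becomes visible grid-wide before the write to Y
        Y = 20;
    } else if (threadIdx.x == 1) {
        int b = Y;
        __threadfence();   // orders the two reads performed by this thread
        int a = X;
        // if b == 20 is observed, a is guaranteed to be 10
        out[0] = a;
        out[1] = b;
    }
}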
Declaring a variable in global or shared memory with the volatile qualifier prevents compiler optimizations that might temporarily cache its value in registers or local memory. With the volatile qualifier, the compiler assumes that the variable’s value can be changed or used at any time by any other thread. Therefore, any reference to this variable is compiled into a global memory read or global memory write instruction that skips the cache.
A simple demo: allocate shared memory for a 32x32 integer array, using block(32, 32) and grid(1, 1); write data into shared memory, then load it from shared memory and write it to global memory.
#define BDIMX 32
#define BDIMY 32
- Statically declared shared memory - row-based vs. col-based
There are two different layouts for writing data to shared memory: tile[threadIdx.y][threadIdx.x] (row-based) or tile[threadIdx.x][threadIdx.y] (col-based). The former is conflict-free, while the latter causes a 32-way conflict.
__global__ void writeRowReadRow (int *out) {
// static shared memory
__shared__ int tile[BDIMY][BDIMX];
// mapping from thread index to global memory index
// assuming only one block
int idx = threadIdx.y * blockDim.x + threadIdx.x;
// shared memory store operation
tile[threadIdx.y][threadIdx.x] = idx;
// wait for all threads to complete
__syncthreads();
// shared memory load operation
// global memory store operation
out[idx] = tile[threadIdx.y][threadIdx.x];
}
__global__ void writeColReadCol (int *out) {
// static shared memory
__shared__ int tile[BDIMY][BDIMX];
// mapping from thread index to global memory index
// assuming only one block
int idx = threadIdx.y * blockDim.x + threadIdx.x;
// shared memory store operation
tile[threadIdx.x][threadIdx.y] = idx;
// wait for all threads to complete
__syncthreads();
// shared memory load operation
// global memory store operation
out[idx] = tile[threadIdx.x][threadIdx.y];
}
nvprof --metrics shared_load_transactions_per_request,shared_store_transactions_per_request ./sharedMemSquare
The profiling results clearly show the bank conflicts: the col-based kernel needs about 32 shared memory transactions per request (a 32-way conflict), while the row-based kernel needs only one.
- Dynamically declared shared memory
__global__ void writeRowReadRowDynamic(int *out) {
// dynamic shared memory
extern __shared__ int tile[];
// mapping from thread index to global memory index
// assuming only one block
int idx = threadIdx.y * blockDim.x + threadIdx.x;
// shared memory store operation
tile[idx] = idx;
// wait for all threads to complete
__syncthreads();
// shared memory load operation
// global memory store operation
out[idx] = tile[idx];
}
__global__ void writeColReadColDynamic(int *out) {
// dynamic shared memory
extern __shared__ int tile[];
// mapping from thread index to global memory index
// assuming only one block
int idx = threadIdx.y * blockDim.x + threadIdx.x;
// shared memory store
int colIdx = threadIdx.x * blockDim.y + threadIdx.y; // col-based index
tile[colIdx] = idx;
// wait for all threads to complete
__syncthreads();
// shared memory load operation
// global memory store operation
out[idx] = tile[colIdx];
}
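The dynamic versions must be launched with the shared memory size in bytes as the third argument of the execution configuration, for example (d_out is the device output buffer):
dim3 block(BDIMX, BDIMY);
dim3 grid(1, 1);
writeRowReadRowDynamic<<<grid, block, BDIMX * BDIMY * sizeof(int)>>>(d_out);
writeColReadColDynamic<<<grid, block, BDIMX * BDIMY * sizeof(int)>>>(d_out);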
- Memory padding
Now use memory padding to resolve the conflict caused by col-based access:
#define PADDING 1
__global__ void writeColReadColPad(int *out) {
// static shared memory with memory padding
__shared__ int tile[BDIMY][BDIMX + PADDING];
// mapping from thread index to global memory index
// assuming only one block
int idx = threadIdx.y * blockDim.x + threadIdx.x;
// shared memory store operation
tile[threadIdx.x][threadIdx.y] = idx;
// wait for all threads to complete
__syncthreads();
// shared memory load operation
// global memory store operation
out[idx] = tile[threadIdx.x][threadIdx.y];
}
__global__ void writeColReadColDynamicPad(int *out) {
// dynamic shared memory
extern __shared__ int tile[];
// mapping from thread index to global memory index
// assuming only one block
int idx = threadIdx.y * blockDim.x + threadIdx.x;
// shared memory store operation
int colIdx = threadIdx.x * (blockDim.y + PADDING) + threadIdx.y; // col-based index
tile[colIdx] = idx;
// wait for all threads to complete
__syncthreads();
// shared memory load operation
// global memory store operation
out[idx] = tile[colIdx];
}
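The dynamically allocated padded version also needs a correspondingly larger allocation at launch time, because the padding elements are part of the array:
writeColReadColDynamicPad<<<grid, block, BDIMX * (BDIMY + PADDING) * sizeof(int)>>>(d_out);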
A similar set of experiments can be run with a rectangular (non-square) shared memory array. As with square shared memory, we compare:
- statically vs. dynamically declared
- row-based vs. col-based
- memory padding
Note that in this demo, the padding should be 2 in order to achieve conflict-free access. (The analysis is easy, do it yourself :D)
The benefits of using shared memory include:
- caching data on-chip, thereby reducing the number of global memory accesses (demo: parallel reduction with shared memory)
- avoiding non-coalesced global memory access (demo: matrix transpose using shared memory)
Using shared memory improves performance because it reduces global memory traffic (shared memory is on-chip and has much lower latency than global memory). This can be verified using the gld_transactions and gst_transactions metrics.
But keep in mind that shared memory is a limited resource; allocating too much of it reduces the number of blocks that can be resident on each SM.
- baseline without shared memory (complete unrolling; unrolling warps)
__global__ void reduceGlobalMem(int *g_idata, int * g_odata, const int n) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= n) return;
const int tid = threadIdx.x;
// convert global data pointer to the local pointer of this block
int *idata = g_idata + blockIdx.x * blockDim.x;
// in-place reduction and complete unroll
if (blockDim.x >= 1024 && tid < 512) idata[tid] += idata[tid + 512];
__syncthreads();
if (blockDim.x >= 512 && tid < 256) idata[tid] += idata[tid + 256];
__syncthreads();
if (blockDim.x >= 256 && tid < 128) idata[tid] += idata[tid + 128];
__syncthreads();
if (blockDim.x >= 128 && tid < 64) idata[tid] += idata[tid + 64];
__syncthreads();
// unrolling warp
if (tid < 32) {
volatile int *vmem = idata;
vmem[tid] += vmem[tid + 32];
vmem[tid] += vmem[tid + 16];
vmem[tid] += vmem[tid + 8];
vmem[tid] += vmem[tid + 4];
vmem[tid] += vmem[tid + 2];
vmem[tid] += vmem[tid + 1];
}
if (tid == 0) g_odata[blockIdx.x] = idata[0];
}
- using shared memory
#define BLOCKSIZE 128
__global__ void reduceSharedMem(int *g_idata, int * g_odata, const int n) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= n) return;
const int tid = threadIdx.x;
// convert global data pointer to the local pointer of this block
int *idata = g_idata + blockIdx.x * blockDim.x;
// shared memory
__shared__ int smem[BLOCKSIZE];
smem[tid] = idata[tid];
__syncthreads();
// in-place reduction and complete unroll
if (blockDim.x >= 1024 && tid < 512) smem[tid] += smem[tid + 512];
__syncthreads();
if (blockDim.x >= 512 && tid < 256) smem[tid] += smem[tid + 256];
__syncthreads();
if (blockDim.x >= 256 && tid < 128) smem[tid] += smem[tid + 128];
__syncthreads();
if (blockDim.x >= 128 && tid < 64) smem[tid] += smem[tid + 64];
__syncthreads();
// unrolling warp
if (tid < 32) {
volatile int *vmem = smem;
vmem[tid] += vmem[tid + 32];
vmem[tid] += vmem[tid + 16];
vmem[tid] += vmem[tid + 8];
vmem[tid] += vmem[tid + 4];
vmem[tid] += vmem[tid + 2];
vmem[tid] += vmem[tid + 1];
}
if (tid == 0) g_odata[blockIdx.x] = smem[0];
}
The main benefits of loop unrolling are twofold: it increases global memory throughput by exposing more independent memory operations per thread, and it reduces the number of global memory store transactions.
- baseline without shared memory
__global__ void reduceGlobalMemUnroll8(int *g_idata, int * g_odata, const int n) {
const int idx = blockIdx.x * blockDim.x * 8 + threadIdx.x;
if (idx >= n) return;
const int tid = threadIdx.x;
// convert global data pointer to the local pointer of this block
int *idata = g_idata + blockIdx.x * blockDim.x * 8;
// unrolling 8
if (idx + 7 * blockDim.x < n) {
int a1 = g_idata[idx];
int a2 = g_idata[idx + blockDim.x];
int a3 = g_idata[idx + 2 * blockDim.x];
int a4 = g_idata[idx + 3 * blockDim.x];
int a5 = g_idata[idx + 4 * blockDim.x];
int a6 = g_idata[idx + 5 * blockDim.x];
int a7 = g_idata[idx + 6 * blockDim.x];
int a8 = g_idata[idx + 7 * blockDim.x];
g_idata[idx] = a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8;
}
__syncthreads();
// in-place reduction and complete unroll
if (blockDim.x >= 1024 && tid < 512) idata[tid] += idata[tid + 512];
__syncthreads();
if (blockDim.x >= 512 && tid < 256) idata[tid] += idata[tid + 256];
__syncthreads();
if (blockDim.x >= 256 && tid < 128) idata[tid] += idata[tid + 128];
__syncthreads();
if (blockDim.x >= 128 && tid < 64) idata[tid] += idata[tid + 64];
__syncthreads();
// unrolling warp
if (tid < 32) {
volatile int *vmem = idata;
vmem[tid] += vmem[tid + 32];
vmem[tid] += vmem[tid + 16];
vmem[tid] += vmem[tid + 8];
vmem[tid] += vmem[tid + 4];
vmem[tid] += vmem[tid + 2];
vmem[tid] += vmem[tid + 1];
}
if (tid == 0) g_odata[blockIdx.x] = idata[0];
}
- using shared memory
__global__ void reduceSharedMemUnroll8(int *g_idata, int * g_odata, const int n) {
const int idx = blockIdx.x * blockDim.x * 8 + threadIdx.x;
if (idx >= n) return;
const int tid = threadIdx.x;
// unrolling 8
__shared__ int smem[BLOCKSIZE];
int temp = 0;
if (idx + 7 * blockDim.x < n) {
int a1 = g_idata[idx];
int a2 = g_idata[idx + blockDim.x];
int a3 = g_idata[idx + 2 * blockDim.x];
int a4 = g_idata[idx + 3 * blockDim.x];
int a5 = g_idata[idx + 4 * blockDim.x];
int a6 = g_idata[idx + 5 * blockDim.x];
int a7 = g_idata[idx + 6 * blockDim.x];
int a8 = g_idata[idx + 7 * blockDim.x];
temp = a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8;
}
smem[tid] = temp;
__syncthreads();
// in-place reduction and complete unroll
if (blockDim.x >= 1024 && tid < 512) smem[tid] += smem[tid + 512];
__syncthreads();
if (blockDim.x >= 512 && tid < 256) smem[tid] += smem[tid + 256];
__syncthreads();
if (blockDim.x >= 256 && tid < 128) smem[tid] += smem[tid + 128];
__syncthreads();
if (blockDim.x >= 128 && tid < 64) smem[tid] += smem[tid + 64];
__syncthreads();
// unrolling warp
if (tid < 32) {
volatile int *vmem = smem;
vmem[tid] += vmem[tid + 32];
vmem[tid] += vmem[tid + 16];
vmem[tid] += vmem[tid + 8];
vmem[tid] += vmem[tid + 4];
vmem[tid] += vmem[tid + 2];
vmem[tid] += vmem[tid + 1];
}
if (tid == 0) g_odata[blockIdx.x] = smem[0];
}
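A minimal host-side harness for the unroll-8 kernels, with hypothetical sizes and names, to show how the grid shrinks by a factor of 8 and how the per-block partial sums are accumulated on the host:
int n = 1 << 24;                                      // number of elements (a multiple of 8 * BLOCKSIZE, so no tail handling is needed)
size_t nBytes = n * sizeof(int);

int *h_idata = (int *)malloc(nBytes);
for (int i = 0; i < n; i++) h_idata[i] = 1;           // simple test data

dim3 block(BLOCKSIZE);
dim3 grid((n + block.x * 8 - 1) / (block.x * 8));     // each block handles 8 * BLOCKSIZE elements

int *d_idata, *d_odata;
cudaMalloc(&d_idata, nBytes);
cudaMalloc(&d_odata, grid.x * sizeof(int));
cudaMemcpy(d_idata, h_idata, nBytes, cudaMemcpyHostToDevice);

reduceSharedMemUnroll8<<<grid, block>>>(d_idata, d_odata, n);

int *h_odata = (int *)malloc(grid.x * sizeof(int));
cudaMemcpy(h_odata, d_odata, grid.x * sizeof(int), cudaMemcpyDeviceToHost);
int gpuSum = 0;
for (unsigned int i = 0; i < grid.x; i++) gpuSum += h_odata[i];   // final sum on the host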
#define BLOCKX 32
#define BLOCKY 16
__global__ void transposeShared(float *in, float *out, int nx, int ny) {
// shared memory
__shared__ float smem[BLOCKY][BLOCKX];
// global memory index for original matrix
int ix = blockDim.x * blockIdx.x + threadIdx.x;
int iy = blockDim.y * blockIdx.y + threadIdx.y;
// transposed index in shared memory
int irow = (threadIdx.y * blockDim.x + threadIdx.x) % blockDim.y;
int icol = (threadIdx.y * blockDim.x + threadIdx.x) / blockDim.y;
// global memory index for transposed matrix
int ox = blockDim.y * blockIdx.y + irow;
int oy = blockDim.x * blockIdx.x + icol;
if (ix < nx && iy < ny) {
smem[threadIdx.y][threadIdx.x] = in[iy * nx + ix];
__syncthreads();
out[oy * ny + ox] = smem[irow][icol];
}
}
... ...
dim3 block(BLOCKX, BLOCKY);
dim3 grid((nx + block.x - 1) / block.x, (ny + block.y - 1) / block.y);
transposeShared<<<grid, block>>>(d_in, d_out, nx, ny);
All other aforementioned things (memory padding, dynamic shared memory) can also be applied here.
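For instance, padding the shared memory tile removes the bank conflicts of the column-wise reads; a minimal sketch with two columns of padding (consistent with the padding-of-2 note in the rectangular shared memory demo above, assuming 4-byte bank mode):
#define TPAD 2
__global__ void transposeSharedPad(float *in, float *out, int nx, int ny) {
    // two extra columns of padding so that column-wise reads of smem hit distinct banks
    __shared__ float smem[BLOCKY][BLOCKX + TPAD];
    int ix = blockDim.x * blockIdx.x + threadIdx.x;
    int iy = blockDim.y * blockIdx.y + threadIdx.y;
    int irow = (threadIdx.y * blockDim.x + threadIdx.x) % blockDim.y;
    int icol = (threadIdx.y * blockDim.x + threadIdx.x) / blockDim.y;
    int ox = blockDim.y * blockIdx.y + irow;
    int oy = blockDim.x * blockIdx.x + icol;
    if (ix < nx && iy < ny) {
        smem[threadIdx.y][threadIdx.x] = in[iy * nx + ix];
        __syncthreads();
        out[oy * ny + ox] = smem[irow][icol];
    }
}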
The full code is here
- Constant memory is read-only from device code (kernels), but it is both readable and writable from the host.
- Constant memory resides in device DRAM (like global memory) and has a dedicated on-chip cache. Like the L1 cache and shared memory, reading from the per-SM constant cache has a much lower latency than reading directly from constant memory.
- It is best if all threads in a warp access the same location in constant memory. Accesses to different addresses by threads within a warp are serialized.
- Constant variables must be declared in global scope with the qualifier __constant__, and the values are set using the function:
cudaError_t cudaMemcpyToSymbol(const void *symbol, const void *src, size_t count, size_t offset = 0, cudaMemcpyKind kind = cudaMemcpyHostToDevice);
- The lifespan of constant variables is the whole application. They are accessible from all threads within a grid and from the host.
- Background
- Implementation - using constant memory
In the above formula, the coefficients c0, c1, c2, and c3 are the same across all threads and are never modified. This makes them excellent candidates for constant memory because (1) they are read-only, and (2) every thread in a warp references the same constant memory location at the same time.
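A minimal sketch of how such coefficients could be placed in constant memory (the array name coef and the values are placeholders for illustration):
__constant__ float coef[4];   // c0 .. c3, read-only in device code

void setupCoefConstantMem() {
    const float h_coef[4] = {1.0f, 2.0f, 3.0f, 4.0f};   // placeholder values
    // copy the host coefficients to the __constant__ symbol (host -> device)
    cudaMemcpyToSymbol(coef, h_coef, 4 * sizeof(float));
}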
- Comparing with the read-only cache
Global memory's read-only cache is a separate cache with separate memory bandwidth from normal global memory reads.
Generally, the read-only cache is better for scattered reads than the L1 cache, but it is not well optimized for uniform reads (all threads in a warp reading the same address). For the 1D stencil, constant memory yields better performance than the read-only cache (a sketch of using the read-only cache follows).
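A minimal sketch of routing global loads through the read-only cache, either by qualifying the pointer with const __restrict__ or by using the __ldg() intrinsic (the kernel and names are hypothetical):
__global__ void scaleReadOnly(float *out, const float * __restrict__ in, const int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        // const __restrict__ lets the compiler use the read-only data cache;
        // __ldg() forces the read-only path explicitly
        out[idx] = 2.0f * __ldg(&in[idx]);
    }
}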
Official tutorial - Using CUDA Warp-Level Primitives
- The shuffle instruction
Starting with the Kepler family of GPUs (compute capability 3.0 or higher), the shuffle instruction was introduced to allow threads in the same warp to directly read another thread’s register.
Advantages:
- exchanging data with each other directly, rather than going through shared or global memory
- lower latency than shared memory
- not consuming extra memory to perform a data exchange
- An important concept: lane
A lane simply refers to a single thread within a warp. Each lane in a warp is uniquely identified by a lane index in the range [0,31].
Example: 1D block
warpID = threadIdx.x / 32
laneID = threadIdx.x % 32
Before going on:
- For the following functions, T can be int, unsigned int, long, unsigned long, long long, unsigned long long, float or double. With the cuda_fp16.h header included, T can also be __half or __half2.
- For each function, the set of threads that participates in invoking the primitive is specified by a 32-bit mask. The semantics of this mask are easy to misunderstand, so be sure to read the official tutorial linked above.
T __shfl_sync(unsigned int mask, T var, int srcLane, int width = warpSize);
This function sends var from the thread whose lane index is srcLane to the other threads, and returns that thread's value of var. When width = warpSize (the default), the sharing scope is the whole warp; width can also be set to any power of 2 between 2 and 32, in which case the warp is divided into segments of that size and a separate shuffle operation is performed in each segment.
e.g. int y = __shfl_sync(0xffffffff, x, 3, 16); threads 0 through 15 receive the value of x from thread 3, and threads 16 through 31 receive the value of x from thread 19.
T __shfl_up_sync(unsigned int mask, T var, unsigned int delta, int width = warpSize);
This function calculates the source lane index by subtracting delta from the caller’s lane index.
Note: There is no wrap around with __shfl_up_sync, so the lowest delta threads in a warp will be unchanged.
T __shfl_down_sync(unsigned int mask, T var, unsigned int delta, int width = warpSize);
This is the exact opposite of __shfl_up_sync: the source lane index is calculated by adding delta to the caller's lane index, so the highest delta threads in a warp (or segment) are unchanged.
T __shfl_xor_sync(unsigned int mask, T var, int laneMask, int width = warpSize);
The intrinsic instruction calculates a source lane index by performing a bitwise XOR of the caller’s lane index with laneMask. The value held by the source thread is returned.
This instruction facilitates a butterfly addressing pattern, which is especially useful for warp-level reductions.
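A minimal sketch of a common use of the butterfly pattern, a warp-level sum reduction (names are illustrative):
__inline__ __device__ int warpReduceSum(int val) {
    // each step exchanges values between lanes whose indices differ in one bit;
    // after log2(32) = 5 steps every lane holds the sum over the whole warp
    for (int laneMask = 16; laneMask > 0; laneMask >>= 1) {
        val += __shfl_xor_sync(0xffffffff, val, laneMask);
    }
    return val;
}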